├── examples ├── todo │ ├── structure.sql │ ├── README.md │ └── main.go ├── url_shortener │ ├── structure.sql │ ├── README.md │ └── main.go ├── README.md └── chat │ ├── README.md │ └── main.go ├── .gitignore ├── messages.go ├── go.mod ├── example_json_test.go ├── log ├── testingadapter │ └── adapter.go ├── kitlogadapter │ └── adapter.go ├── logrusadapter │ └── adapter.go ├── zapadapter │ └── adapter.go ├── log15adapter │ └── adapter.go └── zerologadapter │ ├── adapter.go │ └── adapter_test.go ├── pgxpool ├── doc.go ├── batch_results.go ├── conn_test.go ├── tx_test.go ├── bench_test.go ├── rows.go ├── stat.go ├── tx.go ├── conn.go └── common_test.go ├── LICENSE ├── .github └── workflows │ └── ci.yml ├── ci └── setup_test.bash ├── pgbouncer_test.go ├── go_stdlib.go ├── logger.go ├── stdlib └── bench_test.go ├── example_custom_type_test.go ├── large_objects.go ├── extended_query_builder.go ├── copy_from.go ├── helper_test.go ├── batch.go ├── large_objects_test.go ├── internal └── sanitize │ ├── sanitize_test.go │ └── sanitize.go ├── values.go ├── README.md ├── CHANGELOG.md ├── rows.go ├── doc.go ├── tx.go └── copy_from_test.go /examples/todo/structure.sql: -------------------------------------------------------------------------------- 1 | create table tasks ( 2 | id serial primary key, 3 | description text not null 4 | ); 5 | -------------------------------------------------------------------------------- /examples/url_shortener/structure.sql: -------------------------------------------------------------------------------- 1 | create table shortened_urls ( 2 | id text primary key, 3 | url text not null 4 | ); -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Compiled Object files, Static and Dynamic libs (Shared Objects) 2 | *.o 3 | *.a 4 | *.so 5 | 6 | # Folders 7 | _obj 8 | _test 9 | 10 | # Architecture specific extensions/prefixes 11 | *.[568vq] 12 | [568vq].out 13 | 14 | *.cgo1.go 15 | *.cgo2.c 16 | _cgo_defun.c 17 | _cgo_gotypes.go 18 | _cgo_export.* 19 | 20 | _testmain.go 21 | 22 | *.exe 23 | 24 | .envrc 25 | -------------------------------------------------------------------------------- /examples/README.md: -------------------------------------------------------------------------------- 1 | # Examples 2 | 3 | * chat is a command line chat program using listen/notify. 4 | * todo is a command line todo list that demonstrates basic CRUD actions. 5 | * url_shortener contains a simple example of using pgx in a web context. 6 | * [Tern](https://github.com/jackc/tern) is a migration tool that uses pgx. 7 | * [The Pithy Reader](https://github.com/jackc/tpr) is a RSS aggregator that uses pgx. 
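All of these examples share the same connection pattern: read the connection string from `DATABASE_URL` (or the standard `PG*` variables) and connect with `pgx.Connect` or `pgxpool.Connect`. A minimal sketch of that shared pattern, assuming a reachable database behind `DATABASE_URL`:

```go
package main

import (
	"context"
	"fmt"
	"os"

	"github.com/jackc/pgx/v4"
)

func main() {
	// The connection string may be a URL or DSN; unset settings fall back to the PG* environment variables.
	conn, err := pgx.Connect(context.Background(), os.Getenv("DATABASE_URL"))
	if err != nil {
		fmt.Fprintln(os.Stderr, "Unable to connect to database:", err)
		os.Exit(1)
	}
	defer conn.Close(context.Background())

	var greeting string
	if err := conn.QueryRow(context.Background(), "select 'hello, pgx'").Scan(&greeting); err != nil {
		fmt.Fprintln(os.Stderr, "Query failed:", err)
		os.Exit(1)
	}
	fmt.Println(greeting)
}
```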
8 | -------------------------------------------------------------------------------- /messages.go: -------------------------------------------------------------------------------- 1 | package pgx 2 | 3 | import ( 4 | "database/sql/driver" 5 | 6 | "github.com/jackc/pgtype" 7 | ) 8 | 9 | func convertDriverValuers(args []interface{}) ([]interface{}, error) { 10 | for i, arg := range args { 11 | switch arg := arg.(type) { 12 | case pgtype.BinaryEncoder: 13 | case pgtype.TextEncoder: 14 | case driver.Valuer: 15 | v, err := callValuerValue(arg) 16 | if err != nil { 17 | return nil, err 18 | } 19 | args[i] = v 20 | } 21 | } 22 | return args, nil 23 | } 24 | -------------------------------------------------------------------------------- /examples/chat/README.md: -------------------------------------------------------------------------------- 1 | # Description 2 | 3 | This is a sample chat program implemented using PostgreSQL's listen/notify 4 | functionality with pgx. 5 | 6 | Start multiple instances of this program connected to the same database to chat 7 | between them. 8 | 9 | ## Connection configuration 10 | 11 | The database connection is configured via DATABASE_URL and standard PostgreSQL environment variables (PGHOST, PGUSER, etc.) 12 | 13 | You can either export them then run chat: 14 | 15 | export PGHOST=/private/tmp 16 | ./chat 17 | 18 | Or you can prefix the chat execution with the environment variables: 19 | 20 | PGHOST=/private/tmp ./chat 21 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/jackc/pgx/v4 2 | 3 | go 1.13 4 | 5 | require ( 6 | github.com/Masterminds/semver/v3 v3.1.1 7 | github.com/cockroachdb/apd v1.1.0 8 | github.com/go-kit/log v0.1.0 9 | github.com/gofrs/uuid v4.0.0+incompatible 10 | github.com/jackc/pgconn v1.12.1 11 | github.com/jackc/pgio v1.0.0 12 | github.com/jackc/pgproto3/v2 v2.3.0 13 | github.com/jackc/pgtype v1.11.0 14 | github.com/jackc/puddle v1.2.1 15 | github.com/rs/zerolog v1.15.0 16 | github.com/shopspring/decimal v1.2.0 17 | github.com/sirupsen/logrus v1.4.2 18 | github.com/stretchr/testify v1.7.0 19 | go.uber.org/zap v1.13.0 20 | gopkg.in/inconshreveable/log15.v2 v2.0.0-20180818164646-67afb5ed74ec 21 | ) 22 | -------------------------------------------------------------------------------- /example_json_test.go: -------------------------------------------------------------------------------- 1 | package pgx_test 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "os" 7 | 8 | "github.com/jackc/pgx/v4" 9 | ) 10 | 11 | func Example_JSON() { 12 | conn, err := pgx.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) 13 | if err != nil { 14 | fmt.Printf("Unable to establish connection: %v", err) 15 | return 16 | } 17 | 18 | type person struct { 19 | Name string `json:"name"` 20 | Age int `json:"age"` 21 | } 22 | 23 | input := person{ 24 | Name: "John", 25 | Age: 42, 26 | } 27 | 28 | var output person 29 | 30 | err = conn.QueryRow(context.Background(), "select $1::json", input).Scan(&output) 31 | if err != nil { 32 | fmt.Println(err) 33 | return 34 | } 35 | 36 | fmt.Println(output.Name, output.Age) 37 | // Output: 38 | // John 42 39 | } 40 | -------------------------------------------------------------------------------- /examples/url_shortener/README.md: -------------------------------------------------------------------------------- 1 | # Description 2 | 3 | This is a sample REST URL shortener service 
implemented using pgx as the connector to a PostgreSQL data store. 4 | 5 | # Usage 6 | 7 | Create a PostgreSQL database and run structure.sql into it to create the necessary data schema. 8 | 9 | Configure the database connection with `DATABASE_URL` or standard PostgreSQL (`PG*`) environment variables or 10 | 11 | Run main.go: 12 | 13 | ``` 14 | go run main.go 15 | ``` 16 | 17 | ## Create or Update a Shortened URL 18 | 19 | ``` 20 | curl -X PUT -d 'http://www.google.com' http://localhost:8080/google 21 | ``` 22 | 23 | ## Get a Shortened URL 24 | 25 | ``` 26 | curl http://localhost:8080/google 27 | ``` 28 | 29 | ## Delete a Shortened URL 30 | 31 | ``` 32 | curl -X DELETE http://localhost:8080/google 33 | ``` 34 | -------------------------------------------------------------------------------- /log/testingadapter/adapter.go: -------------------------------------------------------------------------------- 1 | // Package testingadapter provides a logger that writes to a test or benchmark 2 | // log. 3 | package testingadapter 4 | 5 | import ( 6 | "context" 7 | "fmt" 8 | 9 | "github.com/jackc/pgx/v4" 10 | ) 11 | 12 | // TestingLogger interface defines the subset of testing.TB methods used by this 13 | // adapter. 14 | type TestingLogger interface { 15 | Log(args ...interface{}) 16 | } 17 | 18 | type Logger struct { 19 | l TestingLogger 20 | } 21 | 22 | func NewLogger(l TestingLogger) *Logger { 23 | return &Logger{l: l} 24 | } 25 | 26 | func (l *Logger) Log(ctx context.Context, level pgx.LogLevel, msg string, data map[string]interface{}) { 27 | logArgs := make([]interface{}, 0, 2+len(data)) 28 | logArgs = append(logArgs, level, msg) 29 | for k, v := range data { 30 | logArgs = append(logArgs, fmt.Sprintf("%s=%v", k, v)) 31 | } 32 | l.l.Log(logArgs...) 33 | } 34 | -------------------------------------------------------------------------------- /pgxpool/doc.go: -------------------------------------------------------------------------------- 1 | // Package pgxpool is a concurrency-safe connection pool for pgx. 2 | /* 3 | pgxpool implements a nearly identical interface to pgx connections. 4 | 5 | Establishing a Connection 6 | 7 | The primary way of establishing a connection is with `pgxpool.Connect`. 8 | 9 | pool, err := pgxpool.Connect(context.Background(), os.Getenv("DATABASE_URL")) 10 | 11 | The database connection string can be in URL or DSN format. PostgreSQL settings, pgx settings, and pool settings can be 12 | specified here. In addition, a config struct can be created by `ParseConfig` and modified before establishing the 13 | connection with `ConnectConfig`. 14 | 15 | config, err := pgxpool.ParseConfig(os.Getenv("DATABASE_URL")) 16 | if err != nil { 17 | // ... 
18 | } 19 | config.AfterConnect = func(ctx context.Context, conn *pgx.Conn) error { 20 | // do something with every new connection 21 | } 22 | 23 | pool, err := pgxpool.ConnectConfig(context.Background(), config) 24 | */ 25 | package pgxpool 26 | -------------------------------------------------------------------------------- /log/kitlogadapter/adapter.go: -------------------------------------------------------------------------------- 1 | package kitlogadapter 2 | 3 | import ( 4 | "context" 5 | 6 | "github.com/go-kit/log" 7 | kitlevel "github.com/go-kit/log/level" 8 | "github.com/jackc/pgx/v4" 9 | ) 10 | 11 | type Logger struct { 12 | l log.Logger 13 | } 14 | 15 | func NewLogger(l log.Logger) *Logger { 16 | return &Logger{l: l} 17 | } 18 | 19 | func (l *Logger) Log(ctx context.Context, level pgx.LogLevel, msg string, data map[string]interface{}) { 20 | logger := l.l 21 | for k, v := range data { 22 | logger = log.With(logger, k, v) 23 | } 24 | 25 | switch level { 26 | case pgx.LogLevelTrace: 27 | logger.Log("PGX_LOG_LEVEL", level, "msg", msg) 28 | case pgx.LogLevelDebug: 29 | kitlevel.Debug(logger).Log("msg", msg) 30 | case pgx.LogLevelInfo: 31 | kitlevel.Info(logger).Log("msg", msg) 32 | case pgx.LogLevelWarn: 33 | kitlevel.Warn(logger).Log("msg", msg) 34 | case pgx.LogLevelError: 35 | kitlevel.Error(logger).Log("msg", msg) 36 | default: 37 | logger.Log("INVALID_PGX_LOG_LEVEL", level, "error", msg) 38 | } 39 | } 40 | -------------------------------------------------------------------------------- /log/logrusadapter/adapter.go: -------------------------------------------------------------------------------- 1 | // Package logrusadapter provides a logger that writes to a github.com/sirupsen/logrus.Logger 2 | // log. 3 | package logrusadapter 4 | 5 | import ( 6 | "context" 7 | 8 | "github.com/jackc/pgx/v4" 9 | "github.com/sirupsen/logrus" 10 | ) 11 | 12 | type Logger struct { 13 | l logrus.FieldLogger 14 | } 15 | 16 | func NewLogger(l logrus.FieldLogger) *Logger { 17 | return &Logger{l: l} 18 | } 19 | 20 | func (l *Logger) Log(ctx context.Context, level pgx.LogLevel, msg string, data map[string]interface{}) { 21 | var logger logrus.FieldLogger 22 | if data != nil { 23 | logger = l.l.WithFields(data) 24 | } else { 25 | logger = l.l 26 | } 27 | 28 | switch level { 29 | case pgx.LogLevelTrace: 30 | logger.WithField("PGX_LOG_LEVEL", level).Debug(msg) 31 | case pgx.LogLevelDebug: 32 | logger.Debug(msg) 33 | case pgx.LogLevelInfo: 34 | logger.Info(msg) 35 | case pgx.LogLevelWarn: 36 | logger.Warn(msg) 37 | case pgx.LogLevelError: 38 | logger.Error(msg) 39 | default: 40 | logger.WithField("INVALID_PGX_LOG_LEVEL", level).Error(msg) 41 | } 42 | } 43 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2013-2021 Jack Christensen 2 | 3 | MIT License 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining 6 | a copy of this software and associated documentation files (the 7 | "Software"), to deal in the Software without restriction, including 8 | without limitation the rights to use, copy, modify, merge, publish, 9 | distribute, sublicense, and/or sell copies of the Software, and to 10 | permit persons to whom the Software is furnished to do so, subject to 11 | the following conditions: 12 | 13 | The above copyright notice and this permission notice shall be 14 | included in all copies or substantial portions of the Software. 
15 | 16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, 17 | EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF 18 | MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND 19 | NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE 20 | LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION 21 | OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION 22 | WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 23 | -------------------------------------------------------------------------------- /log/zapadapter/adapter.go: -------------------------------------------------------------------------------- 1 | // Package zapadapter provides a logger that writes to a go.uber.org/zap.Logger. 2 | package zapadapter 3 | 4 | import ( 5 | "context" 6 | 7 | "github.com/jackc/pgx/v4" 8 | "go.uber.org/zap" 9 | "go.uber.org/zap/zapcore" 10 | ) 11 | 12 | type Logger struct { 13 | logger *zap.Logger 14 | } 15 | 16 | func NewLogger(logger *zap.Logger) *Logger { 17 | return &Logger{logger: logger.WithOptions(zap.AddCallerSkip(1))} 18 | } 19 | 20 | func (pl *Logger) Log(ctx context.Context, level pgx.LogLevel, msg string, data map[string]interface{}) { 21 | fields := make([]zapcore.Field, len(data)) 22 | i := 0 23 | for k, v := range data { 24 | fields[i] = zap.Any(k, v) 25 | i++ 26 | } 27 | 28 | switch level { 29 | case pgx.LogLevelTrace: 30 | pl.logger.Debug(msg, append(fields, zap.Stringer("PGX_LOG_LEVEL", level))...) 31 | case pgx.LogLevelDebug: 32 | pl.logger.Debug(msg, fields...) 33 | case pgx.LogLevelInfo: 34 | pl.logger.Info(msg, fields...) 35 | case pgx.LogLevelWarn: 36 | pl.logger.Warn(msg, fields...) 37 | case pgx.LogLevelError: 38 | pl.logger.Error(msg, fields...) 39 | default: 40 | pl.logger.Error(msg, append(fields, zap.Stringer("PGX_LOG_LEVEL", level))...) 41 | } 42 | } 43 | -------------------------------------------------------------------------------- /log/log15adapter/adapter.go: -------------------------------------------------------------------------------- 1 | // Package log15adapter provides a logger that writes to a github.com/inconshreveable/log15.Logger 2 | // log. 3 | package log15adapter 4 | 5 | import ( 6 | "context" 7 | 8 | "github.com/jackc/pgx/v4" 9 | ) 10 | 11 | // Log15Logger interface defines the subset of 12 | // github.com/inconshreveable/log15.Logger that this adapter uses. 13 | type Log15Logger interface { 14 | Debug(msg string, ctx ...interface{}) 15 | Info(msg string, ctx ...interface{}) 16 | Warn(msg string, ctx ...interface{}) 17 | Error(msg string, ctx ...interface{}) 18 | Crit(msg string, ctx ...interface{}) 19 | } 20 | 21 | type Logger struct { 22 | l Log15Logger 23 | } 24 | 25 | func NewLogger(l Log15Logger) *Logger { 26 | return &Logger{l: l} 27 | } 28 | 29 | func (l *Logger) Log(ctx context.Context, level pgx.LogLevel, msg string, data map[string]interface{}) { 30 | logArgs := make([]interface{}, 0, len(data)) 31 | for k, v := range data { 32 | logArgs = append(logArgs, k, v) 33 | } 34 | 35 | switch level { 36 | case pgx.LogLevelTrace: 37 | l.l.Debug(msg, append(logArgs, "PGX_LOG_LEVEL", level)...) 38 | case pgx.LogLevelDebug: 39 | l.l.Debug(msg, logArgs...) 40 | case pgx.LogLevelInfo: 41 | l.l.Info(msg, logArgs...) 42 | case pgx.LogLevelWarn: 43 | l.l.Warn(msg, logArgs...) 44 | case pgx.LogLevelError: 45 | l.l.Error(msg, logArgs...) 46 | default: 47 | l.l.Error(msg, append(logArgs, "INVALID_PGX_LOG_LEVEL", level)...) 
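// An unrecognized level is still emitted (as an error with the raw level value
// attached) rather than being silently dropped.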
48 | } 49 | } 50 | -------------------------------------------------------------------------------- /pgxpool/batch_results.go: -------------------------------------------------------------------------------- 1 | package pgxpool 2 | 3 | import ( 4 | "github.com/jackc/pgconn" 5 | "github.com/jackc/pgx/v4" 6 | ) 7 | 8 | type errBatchResults struct { 9 | err error 10 | } 11 | 12 | func (br errBatchResults) Exec() (pgconn.CommandTag, error) { 13 | return nil, br.err 14 | } 15 | 16 | func (br errBatchResults) Query() (pgx.Rows, error) { 17 | return errRows{err: br.err}, br.err 18 | } 19 | 20 | func (br errBatchResults) QueryFunc(scans []interface{}, f func(pgx.QueryFuncRow) error) (pgconn.CommandTag, error) { 21 | return nil, br.err 22 | } 23 | 24 | func (br errBatchResults) QueryRow() pgx.Row { 25 | return errRow{err: br.err} 26 | } 27 | 28 | func (br errBatchResults) Close() error { 29 | return br.err 30 | } 31 | 32 | type poolBatchResults struct { 33 | br pgx.BatchResults 34 | c *Conn 35 | } 36 | 37 | func (br *poolBatchResults) Exec() (pgconn.CommandTag, error) { 38 | return br.br.Exec() 39 | } 40 | 41 | func (br *poolBatchResults) Query() (pgx.Rows, error) { 42 | return br.br.Query() 43 | } 44 | 45 | func (br *poolBatchResults) QueryFunc(scans []interface{}, f func(pgx.QueryFuncRow) error) (pgconn.CommandTag, error) { 46 | return br.br.QueryFunc(scans, f) 47 | } 48 | 49 | func (br *poolBatchResults) QueryRow() pgx.Row { 50 | return br.br.QueryRow() 51 | } 52 | 53 | func (br *poolBatchResults) Close() error { 54 | err := br.br.Close() 55 | if br.c != nil { 56 | br.c.Release() 57 | br.c = nil 58 | } 59 | return err 60 | } 61 | -------------------------------------------------------------------------------- /.github/workflows/ci.yml: -------------------------------------------------------------------------------- 1 | name: CI 2 | 3 | on: 4 | push: 5 | branches: [ master ] 6 | pull_request: 7 | branches: [ master ] 8 | 9 | jobs: 10 | 11 | test: 12 | name: Test 13 | runs-on: ubuntu-20.04 14 | 15 | strategy: 16 | matrix: 17 | go-version: [1.16, 1.17] 18 | pg-version: [10, 11, 12, 13, 14, cockroachdb] 19 | include: 20 | - pg-version: 10 21 | pgx-test-database: postgres://pgx_md5:secret@127.0.0.1/pgx_test 22 | - pg-version: 11 23 | pgx-test-database: postgres://pgx_md5:secret@127.0.0.1/pgx_test 24 | - pg-version: 12 25 | pgx-test-database: postgres://pgx_md5:secret@127.0.0.1/pgx_test 26 | - pg-version: 13 27 | pgx-test-database: postgres://pgx_md5:secret@127.0.0.1/pgx_test 28 | - pg-version: 14 29 | pgx-test-database: postgres://pgx_md5:secret@127.0.0.1/pgx_test 30 | - pg-version: cockroachdb 31 | pgx-test-database: "postgresql://root@127.0.0.1:26257/pgx_test?sslmode=disable&experimental_enable_temp_tables=on" 32 | 33 | steps: 34 | 35 | - name: Set up Go 1.x 36 | uses: actions/setup-go@v2 37 | with: 38 | go-version: ${{ matrix.go-version }} 39 | 40 | - name: Check out code into the Go module directory 41 | uses: actions/checkout@v2 42 | 43 | - name: Setup database server for testing 44 | run: ci/setup_test.bash 45 | env: 46 | PGVERSION: ${{ matrix.pg-version }} 47 | 48 | - name: Test 49 | run: go test -race ./... 
50 | env: 51 | PGX_TEST_DATABASE: ${{ matrix.pgx-test-database }} 52 | -------------------------------------------------------------------------------- /examples/chat/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "bufio" 5 | "context" 6 | "fmt" 7 | "os" 8 | 9 | "github.com/jackc/pgx/v4/pgxpool" 10 | ) 11 | 12 | var pool *pgxpool.Pool 13 | 14 | func main() { 15 | var err error 16 | pool, err = pgxpool.Connect(context.Background(), os.Getenv("DATABASE_URL")) 17 | if err != nil { 18 | fmt.Fprintln(os.Stderr, "Unable to connect to database:", err) 19 | os.Exit(1) 20 | } 21 | 22 | go listen() 23 | 24 | fmt.Println(`Type a message and press enter. 25 | 26 | This message should appear in any other chat instances connected to the same 27 | database. 28 | 29 | Type "exit" to quit.`) 30 | 31 | scanner := bufio.NewScanner(os.Stdin) 32 | for scanner.Scan() { 33 | msg := scanner.Text() 34 | if msg == "exit" { 35 | os.Exit(0) 36 | } 37 | 38 | _, err = pool.Exec(context.Background(), "select pg_notify('chat', $1)", msg) 39 | if err != nil { 40 | fmt.Fprintln(os.Stderr, "Error sending notification:", err) 41 | os.Exit(1) 42 | } 43 | } 44 | if err := scanner.Err(); err != nil { 45 | fmt.Fprintln(os.Stderr, "Error scanning from stdin:", err) 46 | os.Exit(1) 47 | } 48 | } 49 | 50 | func listen() { 51 | conn, err := pool.Acquire(context.Background()) 52 | if err != nil { 53 | fmt.Fprintln(os.Stderr, "Error acquiring connection:", err) 54 | os.Exit(1) 55 | } 56 | defer conn.Release() 57 | 58 | _, err = conn.Exec(context.Background(), "listen chat") 59 | if err != nil { 60 | fmt.Fprintln(os.Stderr, "Error listening to chat channel:", err) 61 | os.Exit(1) 62 | } 63 | 64 | for { 65 | notification, err := conn.Conn().WaitForNotification(context.Background()) 66 | if err != nil { 67 | fmt.Fprintln(os.Stderr, "Error waiting for notification:", err) 68 | os.Exit(1) 69 | } 70 | 71 | fmt.Println("PID:", notification.PID, "Channel:", notification.Channel, "Payload:", notification.Payload) 72 | } 73 | } 74 | -------------------------------------------------------------------------------- /pgxpool/conn_test.go: -------------------------------------------------------------------------------- 1 | package pgxpool_test 2 | 3 | import ( 4 | "context" 5 | "os" 6 | "testing" 7 | 8 | "github.com/jackc/pgx/v4/pgxpool" 9 | "github.com/stretchr/testify/require" 10 | ) 11 | 12 | func TestConnExec(t *testing.T) { 13 | t.Parallel() 14 | 15 | pool, err := pgxpool.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) 16 | require.NoError(t, err) 17 | defer pool.Close() 18 | 19 | c, err := pool.Acquire(context.Background()) 20 | require.NoError(t, err) 21 | defer c.Release() 22 | 23 | testExec(t, c) 24 | } 25 | 26 | func TestConnQuery(t *testing.T) { 27 | t.Parallel() 28 | 29 | pool, err := pgxpool.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) 30 | require.NoError(t, err) 31 | defer pool.Close() 32 | 33 | c, err := pool.Acquire(context.Background()) 34 | require.NoError(t, err) 35 | defer c.Release() 36 | 37 | testQuery(t, c) 38 | } 39 | 40 | func TestConnQueryRow(t *testing.T) { 41 | t.Parallel() 42 | 43 | pool, err := pgxpool.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) 44 | require.NoError(t, err) 45 | defer pool.Close() 46 | 47 | c, err := pool.Acquire(context.Background()) 48 | require.NoError(t, err) 49 | defer c.Release() 50 | 51 | testQueryRow(t, c) 52 | } 53 | 54 | func TestConnSendBatch(t *testing.T) { 55 | 
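// Exercises pgx batch support through a pooled connection via the shared testSendBatch helper.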
t.Parallel() 56 | 57 | pool, err := pgxpool.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) 58 | require.NoError(t, err) 59 | defer pool.Close() 60 | 61 | c, err := pool.Acquire(context.Background()) 62 | require.NoError(t, err) 63 | defer c.Release() 64 | 65 | testSendBatch(t, c) 66 | } 67 | 68 | func TestConnCopyFrom(t *testing.T) { 69 | t.Parallel() 70 | 71 | pool, err := pgxpool.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) 72 | require.NoError(t, err) 73 | defer pool.Close() 74 | 75 | c, err := pool.Acquire(context.Background()) 76 | require.NoError(t, err) 77 | defer c.Release() 78 | 79 | testCopyFrom(t, c) 80 | } 81 | -------------------------------------------------------------------------------- /pgxpool/tx_test.go: -------------------------------------------------------------------------------- 1 | package pgxpool_test 2 | 3 | import ( 4 | "context" 5 | "os" 6 | "testing" 7 | 8 | "github.com/jackc/pgx/v4/pgxpool" 9 | "github.com/stretchr/testify/require" 10 | ) 11 | 12 | func TestTxExec(t *testing.T) { 13 | t.Parallel() 14 | 15 | pool, err := pgxpool.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) 16 | require.NoError(t, err) 17 | defer pool.Close() 18 | 19 | tx, err := pool.Begin(context.Background()) 20 | require.NoError(t, err) 21 | defer tx.Rollback(context.Background()) 22 | 23 | testExec(t, tx) 24 | } 25 | 26 | func TestTxQuery(t *testing.T) { 27 | t.Parallel() 28 | 29 | pool, err := pgxpool.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) 30 | require.NoError(t, err) 31 | defer pool.Close() 32 | 33 | tx, err := pool.Begin(context.Background()) 34 | require.NoError(t, err) 35 | defer tx.Rollback(context.Background()) 36 | 37 | testQuery(t, tx) 38 | } 39 | 40 | func TestTxQueryRow(t *testing.T) { 41 | t.Parallel() 42 | 43 | pool, err := pgxpool.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) 44 | require.NoError(t, err) 45 | defer pool.Close() 46 | 47 | tx, err := pool.Begin(context.Background()) 48 | require.NoError(t, err) 49 | defer tx.Rollback(context.Background()) 50 | 51 | testQueryRow(t, tx) 52 | } 53 | 54 | func TestTxSendBatch(t *testing.T) { 55 | t.Parallel() 56 | 57 | pool, err := pgxpool.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) 58 | require.NoError(t, err) 59 | defer pool.Close() 60 | 61 | tx, err := pool.Begin(context.Background()) 62 | require.NoError(t, err) 63 | defer tx.Rollback(context.Background()) 64 | 65 | testSendBatch(t, tx) 66 | } 67 | 68 | func TestTxCopyFrom(t *testing.T) { 69 | t.Parallel() 70 | 71 | pool, err := pgxpool.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) 72 | require.NoError(t, err) 73 | defer pool.Close() 74 | 75 | tx, err := pool.Begin(context.Background()) 76 | require.NoError(t, err) 77 | defer tx.Rollback(context.Background()) 78 | 79 | testCopyFrom(t, tx) 80 | } 81 | -------------------------------------------------------------------------------- /examples/todo/README.md: -------------------------------------------------------------------------------- 1 | # Description 2 | 3 | This is a sample todo list implemented using pgx as the connector to a 4 | PostgreSQL data store. 5 | 6 | # Usage 7 | 8 | Create a PostgreSQL database and run structure.sql into it to create the 9 | necessary data schema. 
10 | 11 | Example: 12 | 13 | createdb todo 14 | psql todo < structure.sql 15 | 16 | Build todo: 17 | 18 | go build 19 | 20 | ## Connection configuration 21 | 22 | The database connection is configured via DATABASE_URL and standard PostgreSQL environment variables (PGHOST, PGUSER, etc.) 23 | 24 | You can either export them then run todo: 25 | 26 | export PGDATABASE=todo 27 | ./todo list 28 | 29 | Or you can prefix the todo execution with the environment variables: 30 | 31 | PGDATABASE=todo ./todo list 32 | 33 | ## Add a todo item 34 | 35 | ./todo add 'Learn go' 36 | 37 | ## List tasks 38 | 39 | ./todo list 40 | 41 | ## Update a task 42 | 43 | ./todo update 1 'Learn more go' 44 | 45 | ## Delete a task 46 | 47 | ./todo remove 1 48 | 49 | # Example Setup and Execution 50 | 51 | jack@hk-47~/dev/go/src/github.com/jackc/pgx/examples/todo$ createdb todo 52 | jack@hk-47~/dev/go/src/github.com/jackc/pgx/examples/todo$ psql todo < structure.sql 53 | Expanded display is used automatically. 54 | Timing is on. 55 | CREATE TABLE 56 | Time: 6.363 ms 57 | jack@hk-47~/dev/go/src/github.com/jackc/pgx/examples/todo$ go build 58 | jack@hk-47~/dev/go/src/github.com/jackc/pgx/examples/todo$ export PGDATABASE=todo 59 | jack@hk-47~/dev/go/src/github.com/jackc/pgx/examples/todo$ ./todo list 60 | jack@hk-47~/dev/go/src/github.com/jackc/pgx/examples/todo$ ./todo add 'Learn Go' 61 | jack@hk-47~/dev/go/src/github.com/jackc/pgx/examples/todo$ ./todo list 62 | 1. Learn Go 63 | jack@hk-47~/dev/go/src/github.com/jackc/pgx/examples/todo$ ./todo update 1 'Learn more Go' 64 | jack@hk-47~/dev/go/src/github.com/jackc/pgx/examples/todo$ ./todo list 65 | 1. Learn more Go 66 | jack@hk-47~/dev/go/src/github.com/jackc/pgx/examples/todo$ ./todo remove 1 67 | jack@hk-47~/dev/go/src/github.com/jackc/pgx/examples/todo$ ./todo list 68 | -------------------------------------------------------------------------------- /pgxpool/bench_test.go: -------------------------------------------------------------------------------- 1 | package pgxpool_test 2 | 3 | import ( 4 | "context" 5 | "os" 6 | "testing" 7 | 8 | "github.com/jackc/pgx/v4" 9 | "github.com/jackc/pgx/v4/pgxpool" 10 | "github.com/stretchr/testify/require" 11 | ) 12 | 13 | func BenchmarkAcquireAndRelease(b *testing.B) { 14 | pool, err := pgxpool.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) 15 | require.NoError(b, err) 16 | defer pool.Close() 17 | 18 | b.ResetTimer() 19 | for i := 0; i < b.N; i++ { 20 | c, err := pool.Acquire(context.Background()) 21 | if err != nil { 22 | b.Fatal(err) 23 | } 24 | c.Release() 25 | } 26 | } 27 | 28 | func BenchmarkMinimalPreparedSelectBaseline(b *testing.B) { 29 | config, err := pgxpool.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) 30 | require.NoError(b, err) 31 | 32 | config.AfterConnect = func(ctx context.Context, c *pgx.Conn) error { 33 | _, err := c.Prepare(ctx, "ps1", "select $1::int8") 34 | return err 35 | } 36 | 37 | db, err := pgxpool.ConnectConfig(context.Background(), config) 38 | require.NoError(b, err) 39 | 40 | conn, err := db.Acquire(context.Background()) 41 | require.NoError(b, err) 42 | defer conn.Release() 43 | 44 | var n int64 45 | 46 | b.ResetTimer() 47 | for i := 0; i < b.N; i++ { 48 | err = conn.QueryRow(context.Background(), "ps1", i).Scan(&n) 49 | if err != nil { 50 | b.Fatal(err) 51 | } 52 | 53 | if n != int64(i) { 54 | b.Fatalf("expected %d, got %d", i, n) 55 | } 56 | } 57 | } 58 | 59 | func BenchmarkMinimalPreparedSelect(b *testing.B) { 60 | config, err := pgxpool.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) 61 | 
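// Unlike the baseline benchmark above, each query here goes through the pool itself,
// so every iteration also includes the internal acquire/release overhead.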
require.NoError(b, err) 62 | 63 | config.AfterConnect = func(ctx context.Context, c *pgx.Conn) error { 64 | _, err := c.Prepare(ctx, "ps1", "select $1::int8") 65 | return err 66 | } 67 | 68 | db, err := pgxpool.ConnectConfig(context.Background(), config) 69 | require.NoError(b, err) 70 | 71 | var n int64 72 | 73 | b.ResetTimer() 74 | for i := 0; i < b.N; i++ { 75 | err = db.QueryRow(context.Background(), "ps1", i).Scan(&n) 76 | if err != nil { 77 | b.Fatal(err) 78 | } 79 | 80 | if n != int64(i) { 81 | b.Fatalf("expected %d, got %d", i, n) 82 | } 83 | } 84 | } 85 | -------------------------------------------------------------------------------- /ci/setup_test.bash: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | set -eux 3 | 4 | if [[ "${PGVERSION-}" =~ ^[0-9.]+$ ]] 5 | then 6 | sudo apt-get remove -y --purge postgresql libpq-dev libpq5 postgresql-client-common postgresql-common 7 | sudo rm -rf /var/lib/postgresql 8 | wget --quiet -O - https://www.postgresql.org/media/keys/ACCC4CF8.asc | sudo apt-key add - 9 | sudo sh -c "echo deb http://apt.postgresql.org/pub/repos/apt/ $(lsb_release -cs)-pgdg main $PGVERSION >> /etc/apt/sources.list.d/postgresql.list" 10 | sudo apt-get update -qq 11 | sudo apt-get -y -o Dpkg::Options::=--force-confdef -o Dpkg::Options::="--force-confnew" install postgresql-$PGVERSION postgresql-server-dev-$PGVERSION postgresql-contrib-$PGVERSION 12 | sudo chmod 777 /etc/postgresql/$PGVERSION/main/pg_hba.conf 13 | echo "local all postgres trust" > /etc/postgresql/$PGVERSION/main/pg_hba.conf 14 | echo "local all all trust" >> /etc/postgresql/$PGVERSION/main/pg_hba.conf 15 | echo "host all pgx_md5 127.0.0.1/32 md5" >> /etc/postgresql/$PGVERSION/main/pg_hba.conf 16 | sudo chmod 777 /etc/postgresql/$PGVERSION/main/postgresql.conf 17 | if $(dpkg --compare-versions $PGVERSION ge 9.6) ; then 18 | echo "wal_level='logical'" >> /etc/postgresql/$PGVERSION/main/postgresql.conf 19 | echo "max_wal_senders=5" >> /etc/postgresql/$PGVERSION/main/postgresql.conf 20 | echo "max_replication_slots=5" >> /etc/postgresql/$PGVERSION/main/postgresql.conf 21 | fi 22 | sudo /etc/init.d/postgresql restart 23 | 24 | psql -U postgres -c 'create database pgx_test' 25 | psql -U postgres pgx_test -c 'create extension hstore' 26 | psql -U postgres pgx_test -c 'create domain uint64 as numeric(20,0)' 27 | psql -U postgres -c "create user pgx_md5 SUPERUSER PASSWORD 'secret'" 28 | psql -U postgres -c "create user `whoami`" 29 | fi 30 | 31 | if [[ "${PGVERSION-}" =~ ^cockroach ]] 32 | then 33 | wget -qO- https://binaries.cockroachdb.com/cockroach-v20.2.5.linux-amd64.tgz | tar xvz 34 | sudo mv cockroach-v20.2.5.linux-amd64/cockroach /usr/local/bin/ 35 | cockroach start-single-node --insecure --background --listen-addr=localhost 36 | cockroach sql --insecure -e 'create database pgx_test' 37 | fi 38 | 39 | if [ "${CRATEVERSION-}" != "" ] 40 | then 41 | docker run \ 42 | -p "6543:5432" \ 43 | -d \ 44 | crate:"$CRATEVERSION" \ 45 | crate \ 46 | -Cnetwork.host=0.0.0.0 \ 47 | -Ctransport.host=localhost \ 48 | -Clicense.enterprise=false 49 | fi 50 | -------------------------------------------------------------------------------- /pgbouncer_test.go: -------------------------------------------------------------------------------- 1 | package pgx_test 2 | 3 | import ( 4 | "context" 5 | "os" 6 | "testing" 7 | 8 | "github.com/jackc/pgconn" 9 | "github.com/jackc/pgconn/stmtcache" 10 | "github.com/jackc/pgx/v4" 11 | "github.com/stretchr/testify/assert" 12 | 
"github.com/stretchr/testify/require" 13 | ) 14 | 15 | func TestPgbouncerStatementCacheDescribe(t *testing.T) { 16 | connString := os.Getenv("PGX_TEST_PGBOUNCER_CONN_STRING") 17 | if connString == "" { 18 | t.Skipf("Skipping due to missing environment variable %v", "PGX_TEST_PGBOUNCER_CONN_STRING") 19 | } 20 | 21 | config := mustParseConfig(t, connString) 22 | config.BuildStatementCache = func(conn *pgconn.PgConn) stmtcache.Cache { 23 | return stmtcache.New(conn, stmtcache.ModeDescribe, 1024) 24 | } 25 | 26 | testPgbouncer(t, config, 10, 100) 27 | } 28 | 29 | func TestPgbouncerSimpleProtocol(t *testing.T) { 30 | connString := os.Getenv("PGX_TEST_PGBOUNCER_CONN_STRING") 31 | if connString == "" { 32 | t.Skipf("Skipping due to missing environment variable %v", "PGX_TEST_PGBOUNCER_CONN_STRING") 33 | } 34 | 35 | config := mustParseConfig(t, connString) 36 | config.BuildStatementCache = nil 37 | config.PreferSimpleProtocol = true 38 | 39 | testPgbouncer(t, config, 10, 100) 40 | } 41 | 42 | func testPgbouncer(t *testing.T, config *pgx.ConnConfig, workers, iterations int) { 43 | doneChan := make(chan struct{}) 44 | 45 | for i := 0; i < workers; i++ { 46 | go func() { 47 | defer func() { doneChan <- struct{}{} }() 48 | conn, err := pgx.ConnectConfig(context.Background(), config) 49 | require.Nil(t, err) 50 | defer closeConn(t, conn) 51 | 52 | for i := 0; i < iterations; i++ { 53 | var i32 int32 54 | var i64 int64 55 | var f32 float32 56 | var s string 57 | var s2 string 58 | err = conn.QueryRow(context.Background(), "select 1::int4, 2::int8, 3::float4, 'hi'::text").Scan(&i32, &i64, &f32, &s) 59 | require.NoError(t, err) 60 | assert.Equal(t, int32(1), i32) 61 | assert.Equal(t, int64(2), i64) 62 | assert.Equal(t, float32(3), f32) 63 | assert.Equal(t, "hi", s) 64 | 65 | err = conn.QueryRow(context.Background(), "select 1::int8, 2::float4, 'bye'::text, 4::int4, 'whatever'::text").Scan(&i64, &f32, &s, &i32, &s2) 66 | require.NoError(t, err) 67 | assert.Equal(t, int64(1), i64) 68 | assert.Equal(t, float32(2), f32) 69 | assert.Equal(t, "bye", s) 70 | assert.Equal(t, int32(4), i32) 71 | assert.Equal(t, "whatever", s2) 72 | } 73 | }() 74 | } 75 | 76 | for i := 0; i < workers; i++ { 77 | <-doneChan 78 | } 79 | 80 | } 81 | -------------------------------------------------------------------------------- /pgxpool/rows.go: -------------------------------------------------------------------------------- 1 | package pgxpool 2 | 3 | import ( 4 | "github.com/jackc/pgconn" 5 | "github.com/jackc/pgproto3/v2" 6 | "github.com/jackc/pgx/v4" 7 | ) 8 | 9 | type errRows struct { 10 | err error 11 | } 12 | 13 | func (errRows) Close() {} 14 | func (e errRows) Err() error { return e.err } 15 | func (errRows) CommandTag() pgconn.CommandTag { return nil } 16 | func (errRows) FieldDescriptions() []pgproto3.FieldDescription { return nil } 17 | func (errRows) Next() bool { return false } 18 | func (e errRows) Scan(dest ...interface{}) error { return e.err } 19 | func (e errRows) Values() ([]interface{}, error) { return nil, e.err } 20 | func (e errRows) RawValues() [][]byte { return nil } 21 | 22 | type errRow struct { 23 | err error 24 | } 25 | 26 | func (e errRow) Scan(dest ...interface{}) error { return e.err } 27 | 28 | type poolRows struct { 29 | r pgx.Rows 30 | c *Conn 31 | err error 32 | } 33 | 34 | func (rows *poolRows) Close() { 35 | rows.r.Close() 36 | if rows.c != nil { 37 | rows.c.Release() 38 | rows.c = nil 39 | } 40 | } 41 | 42 | func (rows *poolRows) Err() error { 43 | if rows.err != nil { 44 | return rows.err 45 | 
} 46 | return rows.r.Err() 47 | } 48 | 49 | func (rows *poolRows) CommandTag() pgconn.CommandTag { 50 | return rows.r.CommandTag() 51 | } 52 | 53 | func (rows *poolRows) FieldDescriptions() []pgproto3.FieldDescription { 54 | return rows.r.FieldDescriptions() 55 | } 56 | 57 | func (rows *poolRows) Next() bool { 58 | if rows.err != nil { 59 | return false 60 | } 61 | 62 | n := rows.r.Next() 63 | if !n { 64 | rows.Close() 65 | } 66 | return n 67 | } 68 | 69 | func (rows *poolRows) Scan(dest ...interface{}) error { 70 | err := rows.r.Scan(dest...) 71 | if err != nil { 72 | rows.Close() 73 | } 74 | return err 75 | } 76 | 77 | func (rows *poolRows) Values() ([]interface{}, error) { 78 | values, err := rows.r.Values() 79 | if err != nil { 80 | rows.Close() 81 | } 82 | return values, err 83 | } 84 | 85 | func (rows *poolRows) RawValues() [][]byte { 86 | return rows.r.RawValues() 87 | } 88 | 89 | type poolRow struct { 90 | r pgx.Row 91 | c *Conn 92 | err error 93 | } 94 | 95 | func (row *poolRow) Scan(dest ...interface{}) error { 96 | if row.err != nil { 97 | return row.err 98 | } 99 | 100 | err := row.r.Scan(dest...) 101 | if row.c != nil { 102 | row.c.Release() 103 | } 104 | return err 105 | } 106 | -------------------------------------------------------------------------------- /pgxpool/stat.go: -------------------------------------------------------------------------------- 1 | package pgxpool 2 | 3 | import ( 4 | "time" 5 | 6 | "github.com/jackc/puddle" 7 | ) 8 | 9 | // Stat is a snapshot of Pool statistics. 10 | type Stat struct { 11 | s *puddle.Stat 12 | newConnsCount int64 13 | lifetimeDestroyCount int64 14 | idleDestroyCount int64 15 | } 16 | 17 | // AcquireCount returns the cumulative count of successful acquires from the pool. 18 | func (s *Stat) AcquireCount() int64 { 19 | return s.s.AcquireCount() 20 | } 21 | 22 | // AcquireDuration returns the total duration of all successful acquires from 23 | // the pool. 24 | func (s *Stat) AcquireDuration() time.Duration { 25 | return s.s.AcquireDuration() 26 | } 27 | 28 | // AcquiredConns returns the number of currently acquired connections in the pool. 29 | func (s *Stat) AcquiredConns() int32 { 30 | return s.s.AcquiredResources() 31 | } 32 | 33 | // CanceledAcquireCount returns the cumulative count of acquires from the pool 34 | // that were canceled by a context. 35 | func (s *Stat) CanceledAcquireCount() int64 { 36 | return s.s.CanceledAcquireCount() 37 | } 38 | 39 | // ConstructingConns returns the number of conns with construction in progress in 40 | // the pool. 41 | func (s *Stat) ConstructingConns() int32 { 42 | return s.s.ConstructingResources() 43 | } 44 | 45 | // EmptyAcquireCount returns the cumulative count of successful acquires from the pool 46 | // that waited for a resource to be released or constructed because the pool was 47 | // empty. 48 | func (s *Stat) EmptyAcquireCount() int64 { 49 | return s.s.EmptyAcquireCount() 50 | } 51 | 52 | // IdleConns returns the number of currently idle conns in the pool. 53 | func (s *Stat) IdleConns() int32 { 54 | return s.s.IdleResources() 55 | } 56 | 57 | // MaxConns returns the maximum size of the pool. 58 | func (s *Stat) MaxConns() int32 { 59 | return s.s.MaxResources() 60 | } 61 | 62 | // TotalConns returns the total number of resources currently in the pool. 63 | // The value is the sum of ConstructingConns, AcquiredConns, and 64 | // IdleConns. 
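// For example, a pool with 1 connection being constructed, 2 acquired, and 3 idle
// reports a TotalConns of 6.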
65 | func (s *Stat) TotalConns() int32 { 66 | return s.s.TotalResources() 67 | } 68 | 69 | // NewConnsCount returns the cumulative count of new connections opened. 70 | func (s *Stat) NewConnsCount() int64 { 71 | return s.newConnsCount 72 | } 73 | 74 | // MaxLifetimeDestroyCount returns the cumulative count of connections destroyed 75 | // because they exceeded MaxConnLifetime. 76 | func (s *Stat) MaxLifetimeDestroyCount() int64 { 77 | return s.lifetimeDestroyCount 78 | } 79 | 80 | // MaxIdleDestroyCount returns the cumulative count of connections destroyed because 81 | // they exceeded MaxConnIdleTime. 82 | func (s *Stat) MaxIdleDestroyCount() int64 { 83 | return s.idleDestroyCount 84 | } 85 | -------------------------------------------------------------------------------- /go_stdlib.go: -------------------------------------------------------------------------------- 1 | package pgx 2 | 3 | import ( 4 | "database/sql/driver" 5 | "reflect" 6 | ) 7 | 8 | // This file contains code copied from the Go standard library due to the 9 | // required function not being public. 10 | 11 | // Copyright (c) 2009 The Go Authors. All rights reserved. 12 | 13 | // Redistribution and use in source and binary forms, with or without 14 | // modification, are permitted provided that the following conditions are 15 | // met: 16 | 17 | // * Redistributions of source code must retain the above copyright 18 | // notice, this list of conditions and the following disclaimer. 19 | // * Redistributions in binary form must reproduce the above 20 | // copyright notice, this list of conditions and the following disclaimer 21 | // in the documentation and/or other materials provided with the 22 | // distribution. 23 | // * Neither the name of Google Inc. nor the names of its 24 | // contributors may be used to endorse or promote products derived from 25 | // this software without specific prior written permission. 26 | 27 | // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 28 | // "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 29 | // LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 30 | // A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT 31 | // OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 32 | // SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 33 | // LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 34 | // DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 35 | // THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 36 | // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 37 | // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 38 | 39 | // From database/sql/convert.go 40 | 41 | var valuerReflectType = reflect.TypeOf((*driver.Valuer)(nil)).Elem() 42 | 43 | // callValuerValue returns vr.Value(), with one exception: 44 | // If vr.Value is an auto-generated method on a pointer type and the 45 | // pointer is nil, it would panic at runtime in the panicwrap 46 | // method. Treat it like nil instead. 47 | // Issue 8415. 48 | // 49 | // This is so people can implement driver.Value on value types and 50 | // still use nil pointers to those types to mean nil/NULL, just like 51 | // string/*string. 52 | // 53 | // This function is mirrored in the database/sql/driver package. 
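// For example, if a value type T implements driver.Valuer with a value receiver,
// passing a nil *T as a query argument yields a nil driver.Value here instead of
// panicking inside the auto-generated pointer method.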
54 | func callValuerValue(vr driver.Valuer) (v driver.Value, err error) { 55 | if rv := reflect.ValueOf(vr); rv.Kind() == reflect.Ptr && 56 | rv.IsNil() && 57 | rv.Type().Elem().Implements(valuerReflectType) { 58 | return nil, nil 59 | } 60 | return vr.Value() 61 | } 62 | -------------------------------------------------------------------------------- /examples/url_shortener/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "context" 5 | "io/ioutil" 6 | "net/http" 7 | "os" 8 | 9 | "github.com/jackc/pgx/v4" 10 | "github.com/jackc/pgx/v4/log/log15adapter" 11 | "github.com/jackc/pgx/v4/pgxpool" 12 | log "gopkg.in/inconshreveable/log15.v2" 13 | ) 14 | 15 | var db *pgxpool.Pool 16 | 17 | func getUrlHandler(w http.ResponseWriter, req *http.Request) { 18 | var url string 19 | err := db.QueryRow(context.Background(), "select url from shortened_urls where id=$1", req.URL.Path).Scan(&url) 20 | switch err { 21 | case nil: 22 | http.Redirect(w, req, url, http.StatusSeeOther) 23 | case pgx.ErrNoRows: 24 | http.NotFound(w, req) 25 | default: 26 | http.Error(w, "Internal server error", http.StatusInternalServerError) 27 | } 28 | } 29 | 30 | func putUrlHandler(w http.ResponseWriter, req *http.Request) { 31 | id := req.URL.Path 32 | var url string 33 | if body, err := ioutil.ReadAll(req.Body); err == nil { 34 | url = string(body) 35 | } else { 36 | http.Error(w, "Internal server error", http.StatusInternalServerError) 37 | return 38 | } 39 | 40 | if _, err := db.Exec(context.Background(), `insert into shortened_urls(id, url) values ($1, $2) 41 | on conflict (id) do update set url=excluded.url`, id, url); err == nil { 42 | w.WriteHeader(http.StatusOK) 43 | } else { 44 | http.Error(w, "Internal server error", http.StatusInternalServerError) 45 | } 46 | } 47 | 48 | func deleteUrlHandler(w http.ResponseWriter, req *http.Request) { 49 | if _, err := db.Exec(context.Background(), "delete from shortened_urls where id=$1", req.URL.Path); err == nil { 50 | w.WriteHeader(http.StatusOK) 51 | } else { 52 | http.Error(w, "Internal server error", http.StatusInternalServerError) 53 | } 54 | } 55 | 56 | func urlHandler(w http.ResponseWriter, req *http.Request) { 57 | switch req.Method { 58 | case "GET": 59 | getUrlHandler(w, req) 60 | 61 | case "PUT": 62 | putUrlHandler(w, req) 63 | 64 | case "DELETE": 65 | deleteUrlHandler(w, req) 66 | 67 | default: 68 | w.Header().Add("Allow", "GET, PUT, DELETE") 69 | w.WriteHeader(http.StatusMethodNotAllowed) 70 | } 71 | } 72 | 73 | func main() { 74 | logger := log15adapter.NewLogger(log.New("module", "pgx")) 75 | 76 | poolConfig, err := pgxpool.ParseConfig(os.Getenv("DATABASE_URL")) 77 | if err != nil { 78 | log.Crit("Unable to parse DATABASE_URL", "error", err) 79 | os.Exit(1) 80 | } 81 | 82 | poolConfig.ConnConfig.Logger = logger 83 | 84 | db, err = pgxpool.ConnectConfig(context.Background(), poolConfig) 85 | if err != nil { 86 | log.Crit("Unable to create connection pool", "error", err) 87 | os.Exit(1) 88 | } 89 | 90 | http.HandleFunc("/", urlHandler) 91 | 92 | log.Info("Starting URL shortener on localhost:8080") 93 | err = http.ListenAndServe("localhost:8080", nil) 94 | if err != nil { 95 | log.Crit("Unable to start web server", "error", err) 96 | os.Exit(1) 97 | } 98 | } 99 | -------------------------------------------------------------------------------- /logger.go: -------------------------------------------------------------------------------- 1 | package pgx 2 | 3 | import ( 4 | "context" 5 | 
"encoding/hex" 6 | "errors" 7 | "fmt" 8 | ) 9 | 10 | // The values for log levels are chosen such that the zero value means that no 11 | // log level was specified. 12 | const ( 13 | LogLevelTrace = 6 14 | LogLevelDebug = 5 15 | LogLevelInfo = 4 16 | LogLevelWarn = 3 17 | LogLevelError = 2 18 | LogLevelNone = 1 19 | ) 20 | 21 | // LogLevel represents the pgx logging level. See LogLevel* constants for 22 | // possible values. 23 | type LogLevel int 24 | 25 | func (ll LogLevel) String() string { 26 | switch ll { 27 | case LogLevelTrace: 28 | return "trace" 29 | case LogLevelDebug: 30 | return "debug" 31 | case LogLevelInfo: 32 | return "info" 33 | case LogLevelWarn: 34 | return "warn" 35 | case LogLevelError: 36 | return "error" 37 | case LogLevelNone: 38 | return "none" 39 | default: 40 | return fmt.Sprintf("invalid level %d", ll) 41 | } 42 | } 43 | 44 | // Logger is the interface used to get logging from pgx internals. 45 | type Logger interface { 46 | // Log a message at the given level with data key/value pairs. data may be nil. 47 | Log(ctx context.Context, level LogLevel, msg string, data map[string]interface{}) 48 | } 49 | 50 | // LoggerFunc is a wrapper around a function to satisfy the pgx.Logger interface 51 | type LoggerFunc func(ctx context.Context, level LogLevel, msg string, data map[string]interface{}) 52 | 53 | // Log delegates the logging request to the wrapped function 54 | func (f LoggerFunc) Log(ctx context.Context, level LogLevel, msg string, data map[string]interface{}) { 55 | f(ctx, level, msg, data) 56 | } 57 | 58 | // LogLevelFromString converts log level string to constant 59 | // 60 | // Valid levels: 61 | // trace 62 | // debug 63 | // info 64 | // warn 65 | // error 66 | // none 67 | func LogLevelFromString(s string) (LogLevel, error) { 68 | switch s { 69 | case "trace": 70 | return LogLevelTrace, nil 71 | case "debug": 72 | return LogLevelDebug, nil 73 | case "info": 74 | return LogLevelInfo, nil 75 | case "warn": 76 | return LogLevelWarn, nil 77 | case "error": 78 | return LogLevelError, nil 79 | case "none": 80 | return LogLevelNone, nil 81 | default: 82 | return 0, errors.New("invalid log level") 83 | } 84 | } 85 | 86 | func logQueryArgs(args []interface{}) []interface{} { 87 | logArgs := make([]interface{}, 0, len(args)) 88 | 89 | for _, a := range args { 90 | switch v := a.(type) { 91 | case []byte: 92 | if len(v) < 64 { 93 | a = hex.EncodeToString(v) 94 | } else { 95 | a = fmt.Sprintf("%x (truncated %d bytes)", v[:64], len(v)-64) 96 | } 97 | case string: 98 | if len(v) > 64 { 99 | a = fmt.Sprintf("%s (truncated %d bytes)", v[:64], len(v)-64) 100 | } 101 | } 102 | logArgs = append(logArgs, a) 103 | } 104 | 105 | return logArgs 106 | } 107 | -------------------------------------------------------------------------------- /examples/todo/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "os" 7 | "strconv" 8 | 9 | "github.com/jackc/pgx/v4" 10 | ) 11 | 12 | var conn *pgx.Conn 13 | 14 | func main() { 15 | var err error 16 | conn, err = pgx.Connect(context.Background(), os.Getenv("DATABASE_URL")) 17 | if err != nil { 18 | fmt.Fprintf(os.Stderr, "Unable to connection to database: %v\n", err) 19 | os.Exit(1) 20 | } 21 | 22 | if len(os.Args) == 1 { 23 | printHelp() 24 | os.Exit(0) 25 | } 26 | 27 | switch os.Args[1] { 28 | case "list": 29 | err = listTasks() 30 | if err != nil { 31 | fmt.Fprintf(os.Stderr, "Unable to list tasks: %v\n", err) 32 | os.Exit(1) 33 | } 34 | 
35 | case "add": 36 | err = addTask(os.Args[2]) 37 | if err != nil { 38 | fmt.Fprintf(os.Stderr, "Unable to add task: %v\n", err) 39 | os.Exit(1) 40 | } 41 | 42 | case "update": 43 | n, err := strconv.ParseInt(os.Args[2], 10, 32) 44 | if err != nil { 45 | fmt.Fprintf(os.Stderr, "Unable convert task_num into int32: %v\n", err) 46 | os.Exit(1) 47 | } 48 | err = updateTask(int32(n), os.Args[3]) 49 | if err != nil { 50 | fmt.Fprintf(os.Stderr, "Unable to update task: %v\n", err) 51 | os.Exit(1) 52 | } 53 | 54 | case "remove": 55 | n, err := strconv.ParseInt(os.Args[2], 10, 32) 56 | if err != nil { 57 | fmt.Fprintf(os.Stderr, "Unable convert task_num into int32: %v\n", err) 58 | os.Exit(1) 59 | } 60 | err = removeTask(int32(n)) 61 | if err != nil { 62 | fmt.Fprintf(os.Stderr, "Unable to remove task: %v\n", err) 63 | os.Exit(1) 64 | } 65 | 66 | default: 67 | fmt.Fprintln(os.Stderr, "Invalid command") 68 | printHelp() 69 | os.Exit(1) 70 | } 71 | } 72 | 73 | func listTasks() error { 74 | rows, _ := conn.Query(context.Background(), "select * from tasks") 75 | 76 | for rows.Next() { 77 | var id int32 78 | var description string 79 | err := rows.Scan(&id, &description) 80 | if err != nil { 81 | return err 82 | } 83 | fmt.Printf("%d. %s\n", id, description) 84 | } 85 | 86 | return rows.Err() 87 | } 88 | 89 | func addTask(description string) error { 90 | _, err := conn.Exec(context.Background(), "insert into tasks(description) values($1)", description) 91 | return err 92 | } 93 | 94 | func updateTask(itemNum int32, description string) error { 95 | _, err := conn.Exec(context.Background(), "update tasks set description=$1 where id=$2", description, itemNum) 96 | return err 97 | } 98 | 99 | func removeTask(itemNum int32) error { 100 | _, err := conn.Exec(context.Background(), "delete from tasks where id=$1", itemNum) 101 | return err 102 | } 103 | 104 | func printHelp() { 105 | fmt.Print(`Todo pgx demo 106 | 107 | Usage: 108 | 109 | todo list 110 | todo add task 111 | todo update task_num item 112 | todo remove task_num 113 | 114 | Example: 115 | 116 | todo add 'Learn Go' 117 | todo list 118 | `) 119 | } 120 | -------------------------------------------------------------------------------- /log/zerologadapter/adapter.go: -------------------------------------------------------------------------------- 1 | // Package zerologadapter provides a logger that writes to a github.com/rs/zerolog. 2 | package zerologadapter 3 | 4 | import ( 5 | "context" 6 | 7 | "github.com/jackc/pgx/v4" 8 | "github.com/rs/zerolog" 9 | ) 10 | 11 | type Logger struct { 12 | logger zerolog.Logger 13 | withFunc func(context.Context, zerolog.Context) zerolog.Context 14 | fromContext bool 15 | skipModule bool 16 | } 17 | 18 | // option options for configuring the logger when creating a new logger. 19 | type option func(logger *Logger) 20 | 21 | // WithContextFunc adds possibility to get request scoped values from the 22 | // ctx.Context before logging lines. 23 | func WithContextFunc(withFunc func(context.Context, zerolog.Context) zerolog.Context) option { 24 | return func(logger *Logger) { 25 | logger.withFunc = withFunc 26 | } 27 | } 28 | 29 | // WithoutPGXModule disables adding module:pgx to the default logger context. 30 | func WithoutPGXModule() option { 31 | return func(logger *Logger) { 32 | logger.skipModule = true 33 | } 34 | } 35 | 36 | // NewLogger accepts a zerolog.Logger as input and returns a new custom pgx 37 | // logging facade as output. 
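// A typical wiring, assuming DATABASE_URL points at a reachable server:
//
//	connConfig, err := pgx.ParseConfig(os.Getenv("DATABASE_URL"))
//	if err != nil {
//		// handle error
//	}
//	connConfig.Logger = zerologadapter.NewLogger(zerolog.New(os.Stdout))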
38 | func NewLogger(logger zerolog.Logger, options ...option) *Logger { 39 | l := Logger{ 40 | logger: logger, 41 | } 42 | l.init(options) 43 | return &l 44 | } 45 | 46 | // NewContextLogger creates logger that extracts the zerolog.Logger from the 47 | // context.Context by using `zerolog.Ctx`. The zerolog.DefaultContextLogger will 48 | // be used if no logger is associated with the context. 49 | func NewContextLogger(options ...option) *Logger { 50 | l := Logger{ 51 | fromContext: true, 52 | } 53 | l.init(options) 54 | return &l 55 | } 56 | 57 | func (pl *Logger) init(options []option) { 58 | for _, opt := range options { 59 | opt(pl) 60 | } 61 | if !pl.skipModule { 62 | pl.logger = pl.logger.With().Str("module", "pgx").Logger() 63 | } 64 | } 65 | 66 | func (pl *Logger) Log(ctx context.Context, level pgx.LogLevel, msg string, data map[string]interface{}) { 67 | var zlevel zerolog.Level 68 | switch level { 69 | case pgx.LogLevelNone: 70 | zlevel = zerolog.NoLevel 71 | case pgx.LogLevelError: 72 | zlevel = zerolog.ErrorLevel 73 | case pgx.LogLevelWarn: 74 | zlevel = zerolog.WarnLevel 75 | case pgx.LogLevelInfo: 76 | zlevel = zerolog.InfoLevel 77 | case pgx.LogLevelDebug: 78 | zlevel = zerolog.DebugLevel 79 | default: 80 | zlevel = zerolog.DebugLevel 81 | } 82 | 83 | var zctx zerolog.Context 84 | if pl.fromContext { 85 | logger := zerolog.Ctx(ctx) 86 | zctx = logger.With() 87 | } else { 88 | zctx = pl.logger.With() 89 | } 90 | if pl.withFunc != nil { 91 | zctx = pl.withFunc(ctx, zctx) 92 | } 93 | 94 | pgxlog := zctx.Logger() 95 | event := pgxlog.WithLevel(zlevel) 96 | if event.Enabled() { 97 | if pl.fromContext && !pl.skipModule { 98 | event.Str("module", "pgx") 99 | } 100 | event.Fields(data).Msg(msg) 101 | } 102 | } 103 | -------------------------------------------------------------------------------- /stdlib/bench_test.go: -------------------------------------------------------------------------------- 1 | package stdlib_test 2 | 3 | import ( 4 | "database/sql" 5 | "fmt" 6 | "os" 7 | "strconv" 8 | "strings" 9 | "testing" 10 | "time" 11 | ) 12 | 13 | func getSelectRowsCounts(b *testing.B) []int64 { 14 | var rowCounts []int64 15 | { 16 | s := os.Getenv("PGX_BENCH_SELECT_ROWS_COUNTS") 17 | if s != "" { 18 | for _, p := range strings.Split(s, " ") { 19 | n, err := strconv.ParseInt(p, 10, 64) 20 | if err != nil { 21 | b.Fatalf("Bad PGX_BENCH_SELECT_ROWS_COUNTS value: %v", err) 22 | } 23 | rowCounts = append(rowCounts, n) 24 | } 25 | } 26 | } 27 | 28 | if len(rowCounts) == 0 { 29 | rowCounts = []int64{1, 10, 100, 1000} 30 | } 31 | 32 | return rowCounts 33 | } 34 | 35 | type BenchRowSimple struct { 36 | ID int32 37 | FirstName string 38 | LastName string 39 | Sex string 40 | BirthDate time.Time 41 | Weight int32 42 | Height int32 43 | UpdateTime time.Time 44 | } 45 | 46 | func BenchmarkSelectRowsScanSimple(b *testing.B) { 47 | db := openDB(b) 48 | defer closeDB(b, db) 49 | 50 | rowCounts := getSelectRowsCounts(b) 51 | 52 | for _, rowCount := range rowCounts { 53 | b.Run(fmt.Sprintf("%d rows", rowCount), func(b *testing.B) { 54 | br := &BenchRowSimple{} 55 | for i := 0; i < b.N; i++ { 56 | rows, err := db.Query("select n, 'Adam', 'Smith ' || n, 'male', '1952-06-16'::date, 258, 72, '2001-01-28 01:02:03-05'::timestamptz from generate_series(1, $1) n", rowCount) 57 | if err != nil { 58 | b.Fatal(err) 59 | } 60 | 61 | for rows.Next() { 62 | rows.Scan(&br.ID, &br.FirstName, &br.LastName, &br.Sex, &br.BirthDate, &br.Weight, &br.Height, &br.UpdateTime) 63 | } 64 | 65 | if rows.Err() != nil { 66 | 
b.Fatal(rows.Err()) 67 | } 68 | } 69 | }) 70 | } 71 | } 72 | 73 | type BenchRowNull struct { 74 | ID sql.NullInt32 75 | FirstName sql.NullString 76 | LastName sql.NullString 77 | Sex sql.NullString 78 | BirthDate sql.NullTime 79 | Weight sql.NullInt32 80 | Height sql.NullInt32 81 | UpdateTime sql.NullTime 82 | } 83 | 84 | func BenchmarkSelectRowsScanNull(b *testing.B) { 85 | db := openDB(b) 86 | defer closeDB(b, db) 87 | 88 | rowCounts := getSelectRowsCounts(b) 89 | 90 | for _, rowCount := range rowCounts { 91 | b.Run(fmt.Sprintf("%d rows", rowCount), func(b *testing.B) { 92 | br := &BenchRowSimple{} 93 | for i := 0; i < b.N; i++ { 94 | rows, err := db.Query("select n, 'Adam', 'Smith ' || n, 'male', '1952-06-16'::date, 258, 72, '2001-01-28 01:02:03-05'::timestamptz from generate_series(100000, 100000 + $1) n", rowCount) 95 | if err != nil { 96 | b.Fatal(err) 97 | } 98 | 99 | for rows.Next() { 100 | rows.Scan(&br.ID, &br.FirstName, &br.LastName, &br.Sex, &br.BirthDate, &br.Weight, &br.Height, &br.UpdateTime) 101 | } 102 | 103 | if rows.Err() != nil { 104 | b.Fatal(rows.Err()) 105 | } 106 | } 107 | }) 108 | } 109 | } 110 | -------------------------------------------------------------------------------- /example_custom_type_test.go: -------------------------------------------------------------------------------- 1 | package pgx_test 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "os" 7 | "regexp" 8 | "strconv" 9 | 10 | "github.com/jackc/pgtype" 11 | "github.com/jackc/pgx/v4" 12 | ) 13 | 14 | var pointRegexp *regexp.Regexp = regexp.MustCompile(`^\((.*),(.*)\)$`) 15 | 16 | // Point represents a point that may be null. 17 | type Point struct { 18 | X, Y float64 // Coordinates of point 19 | Status pgtype.Status 20 | } 21 | 22 | func (dst *Point) Set(src interface{}) error { 23 | return fmt.Errorf("cannot convert %v to Point", src) 24 | } 25 | 26 | func (dst *Point) Get() interface{} { 27 | switch dst.Status { 28 | case pgtype.Present: 29 | return dst 30 | case pgtype.Null: 31 | return nil 32 | default: 33 | return dst.Status 34 | } 35 | } 36 | 37 | func (src *Point) AssignTo(dst interface{}) error { 38 | return fmt.Errorf("cannot assign %v to %T", src, dst) 39 | } 40 | 41 | func (dst *Point) DecodeText(ci *pgtype.ConnInfo, src []byte) error { 42 | if src == nil { 43 | *dst = Point{Status: pgtype.Null} 44 | return nil 45 | } 46 | 47 | s := string(src) 48 | match := pointRegexp.FindStringSubmatch(s) 49 | if match == nil { 50 | return fmt.Errorf("Received invalid point: %v", s) 51 | } 52 | 53 | x, err := strconv.ParseFloat(match[1], 64) 54 | if err != nil { 55 | return fmt.Errorf("Received invalid point: %v", s) 56 | } 57 | y, err := strconv.ParseFloat(match[2], 64) 58 | if err != nil { 59 | return fmt.Errorf("Received invalid point: %v", s) 60 | } 61 | 62 | *dst = Point{X: x, Y: y, Status: pgtype.Present} 63 | 64 | return nil 65 | } 66 | 67 | func (src *Point) String() string { 68 | if src.Status == pgtype.Null { 69 | return "null point" 70 | } 71 | 72 | return fmt.Sprintf("%.1f, %.1f", src.X, src.Y) 73 | } 74 | 75 | func Example_CustomType() { 76 | conn, err := pgx.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) 77 | if err != nil { 78 | fmt.Printf("Unable to establish connection: %v", err) 79 | return 80 | } 81 | defer conn.Close(context.Background()) 82 | 83 | if conn.PgConn().ParameterStatus("crdb_version") != "" { 84 | // Skip test / example when running on CockroachDB which doesn't support the point type. Since an example can't be 85 | // skipped fake success instead. 
86 | fmt.Println("null point") 87 | fmt.Println("1.5, 2.5") 88 | return 89 | } 90 | 91 | // Override registered handler for point 92 | conn.ConnInfo().RegisterDataType(pgtype.DataType{ 93 | Value: &Point{}, 94 | Name: "point", 95 | OID: 600, 96 | }) 97 | 98 | p := &Point{} 99 | err = conn.QueryRow(context.Background(), "select null::point").Scan(p) 100 | if err != nil { 101 | fmt.Println(err) 102 | return 103 | } 104 | fmt.Println(p) 105 | 106 | err = conn.QueryRow(context.Background(), "select point(1.5,2.5)").Scan(p) 107 | if err != nil { 108 | fmt.Println(err) 109 | return 110 | } 111 | fmt.Println(p) 112 | // Output: 113 | // null point 114 | // 1.5, 2.5 115 | } 116 | -------------------------------------------------------------------------------- /log/zerologadapter/adapter_test.go: -------------------------------------------------------------------------------- 1 | package zerologadapter_test 2 | 3 | import ( 4 | "bytes" 5 | "context" 6 | "testing" 7 | 8 | "github.com/jackc/pgx/v4" 9 | "github.com/jackc/pgx/v4/log/zerologadapter" 10 | "github.com/rs/zerolog" 11 | ) 12 | 13 | func TestLogger(t *testing.T) { 14 | 15 | t.Run("default", func(t *testing.T) { 16 | var buf bytes.Buffer 17 | zlogger := zerolog.New(&buf) 18 | logger := zerologadapter.NewLogger(zlogger) 19 | logger.Log(context.Background(), pgx.LogLevelInfo, "hello", map[string]interface{}{"one": "two"}) 20 | const want = `{"level":"info","module":"pgx","one":"two","message":"hello"} 21 | ` 22 | got := buf.String() 23 | if got != want { 24 | t.Errorf("%s != %s", got, want) 25 | } 26 | }) 27 | 28 | t.Run("disable pgx module", func(t *testing.T) { 29 | var buf bytes.Buffer 30 | zlogger := zerolog.New(&buf) 31 | logger := zerologadapter.NewLogger(zlogger, zerologadapter.WithoutPGXModule()) 32 | logger.Log(context.Background(), pgx.LogLevelInfo, "hello", nil) 33 | const want = `{"level":"info","message":"hello"} 34 | ` 35 | got := buf.String() 36 | if got != want { 37 | t.Errorf("%s != %s", got, want) 38 | } 39 | }) 40 | 41 | t.Run("from context", func(t *testing.T) { 42 | var buf bytes.Buffer 43 | zlogger := zerolog.New(&buf) 44 | ctx := zlogger.WithContext(context.Background()) 45 | logger := zerologadapter.NewContextLogger() 46 | logger.Log(ctx, pgx.LogLevelInfo, "hello", map[string]interface{}{"one": "two"}) 47 | const want = `{"level":"info","module":"pgx","one":"two","message":"hello"} 48 | ` 49 | 50 | got := buf.String() 51 | if got != want { 52 | t.Log(got) 53 | t.Log(want) 54 | t.Errorf("%s != %s", got, want) 55 | } 56 | }) 57 | 58 | var buf bytes.Buffer 59 | type key string 60 | var ck key 61 | zlogger := zerolog.New(&buf) 62 | logger := zerologadapter.NewLogger(zlogger, 63 | zerologadapter.WithContextFunc(func(ctx context.Context, logWith zerolog.Context) zerolog.Context { 64 | // You can use zerolog.hlog.IDFromCtx(ctx) or even 65 | // zerolog.log.Ctx(ctx) to fetch the whole logger instance from the 66 | // context if you want. 
67 | id, ok := ctx.Value(ck).(string) 68 | if ok { 69 | logWith = logWith.Str("req_id", id) 70 | } 71 | return logWith 72 | }), 73 | ) 74 | 75 | t.Run("no request id", func(t *testing.T) { 76 | buf.Reset() 77 | ctx := context.Background() 78 | logger.Log(ctx, pgx.LogLevelInfo, "hello", nil) 79 | const want = `{"level":"info","module":"pgx","message":"hello"} 80 | ` 81 | got := buf.String() 82 | if got != want { 83 | t.Errorf("%s != %s", got, want) 84 | } 85 | }) 86 | 87 | t.Run("with request id", func(t *testing.T) { 88 | buf.Reset() 89 | ctx := context.WithValue(context.Background(), ck, "1") 90 | logger.Log(ctx, pgx.LogLevelInfo, "hello", map[string]interface{}{"two": "2"}) 91 | const want = `{"level":"info","module":"pgx","req_id":"1","two":"2","message":"hello"} 92 | ` 93 | got := buf.String() 94 | if got != want { 95 | t.Errorf("%s != %s", got, want) 96 | } 97 | }) 98 | } 99 | -------------------------------------------------------------------------------- /pgxpool/tx.go: -------------------------------------------------------------------------------- 1 | package pgxpool 2 | 3 | import ( 4 | "context" 5 | 6 | "github.com/jackc/pgconn" 7 | "github.com/jackc/pgx/v4" 8 | ) 9 | 10 | // Tx represents a database transaction acquired from a Pool. 11 | type Tx struct { 12 | t pgx.Tx 13 | c *Conn 14 | } 15 | 16 | // Begin starts a pseudo nested transaction implemented with a savepoint. 17 | func (tx *Tx) Begin(ctx context.Context) (pgx.Tx, error) { 18 | return tx.t.Begin(ctx) 19 | } 20 | 21 | func (tx *Tx) BeginFunc(ctx context.Context, f func(pgx.Tx) error) error { 22 | return tx.t.BeginFunc(ctx, f) 23 | } 24 | 25 | // Commit commits the transaction and returns the associated connection back to the Pool. Commit will return ErrTxClosed 26 | // if the Tx is already closed, but is otherwise safe to call multiple times. If the commit fails with a rollback status 27 | // (e.g. the transaction was already in a broken state) then ErrTxCommitRollback will be returned. 28 | func (tx *Tx) Commit(ctx context.Context) error { 29 | err := tx.t.Commit(ctx) 30 | if tx.c != nil { 31 | tx.c.Release() 32 | tx.c = nil 33 | } 34 | return err 35 | } 36 | 37 | // Rollback rolls back the transaction and returns the associated connection back to the Pool. Rollback will return ErrTxClosed 38 | // if the Tx is already closed, but is otherwise safe to call multiple times. Hence, defer tx.Rollback() is safe even if 39 | // tx.Commit() will be called first in a non-error condition. 40 | func (tx *Tx) Rollback(ctx context.Context) error { 41 | err := tx.t.Rollback(ctx) 42 | if tx.c != nil { 43 | tx.c.Release() 44 | tx.c = nil 45 | } 46 | return err 47 | } 48 | 49 | func (tx *Tx) CopyFrom(ctx context.Context, tableName pgx.Identifier, columnNames []string, rowSrc pgx.CopyFromSource) (int64, error) { 50 | return tx.t.CopyFrom(ctx, tableName, columnNames, rowSrc) 51 | } 52 | 53 | func (tx *Tx) SendBatch(ctx context.Context, b *pgx.Batch) pgx.BatchResults { 54 | return tx.t.SendBatch(ctx, b) 55 | } 56 | 57 | func (tx *Tx) LargeObjects() pgx.LargeObjects { 58 | return tx.t.LargeObjects() 59 | } 60 | 61 | // Prepare creates a prepared statement with name and sql. If the name is empty, 62 | // an anonymous prepared statement will be used. sql can contain placeholders 63 | // for bound parameters. These placeholders are referenced positionally as $1, $2, etc. 64 | // 65 | // Prepare is idempotent; i.e. it is safe to call Prepare multiple times with the same 66 | // name and sql arguments. 
This allows a code path to Prepare and Query/Exec without 67 | // needing to first check whether the statement has already been prepared. 68 | func (tx *Tx) Prepare(ctx context.Context, name, sql string) (*pgconn.StatementDescription, error) { 69 | return tx.t.Prepare(ctx, name, sql) 70 | } 71 | 72 | func (tx *Tx) Exec(ctx context.Context, sql string, arguments ...interface{}) (pgconn.CommandTag, error) { 73 | return tx.t.Exec(ctx, sql, arguments...) 74 | } 75 | 76 | func (tx *Tx) Query(ctx context.Context, sql string, args ...interface{}) (pgx.Rows, error) { 77 | return tx.t.Query(ctx, sql, args...) 78 | } 79 | 80 | func (tx *Tx) QueryRow(ctx context.Context, sql string, args ...interface{}) pgx.Row { 81 | return tx.t.QueryRow(ctx, sql, args...) 82 | } 83 | 84 | func (tx *Tx) QueryFunc(ctx context.Context, sql string, args []interface{}, scans []interface{}, f func(pgx.QueryFuncRow) error) (pgconn.CommandTag, error) { 85 | return tx.t.QueryFunc(ctx, sql, args, scans, f) 86 | } 87 | 88 | func (tx *Tx) Conn() *pgx.Conn { 89 | return tx.t.Conn() 90 | } 91 | -------------------------------------------------------------------------------- /large_objects.go: -------------------------------------------------------------------------------- 1 | package pgx 2 | 3 | import ( 4 | "context" 5 | "errors" 6 | "io" 7 | ) 8 | 9 | // LargeObjects is a structure used to access the large objects API. It is only valid within the transaction where it 10 | // was created. 11 | // 12 | // For more details see: http://www.postgresql.org/docs/current/static/largeobjects.html 13 | type LargeObjects struct { 14 | tx Tx 15 | } 16 | 17 | type LargeObjectMode int32 18 | 19 | const ( 20 | LargeObjectModeWrite LargeObjectMode = 0x20000 21 | LargeObjectModeRead LargeObjectMode = 0x40000 22 | ) 23 | 24 | // Create creates a new large object. If oid is zero, the server assigns an unused OID. 25 | func (o *LargeObjects) Create(ctx context.Context, oid uint32) (uint32, error) { 26 | err := o.tx.QueryRow(ctx, "select lo_create($1)", oid).Scan(&oid) 27 | return oid, err 28 | } 29 | 30 | // Open opens an existing large object with the given mode. ctx will also be used for all operations on the opened large 31 | // object. 32 | func (o *LargeObjects) Open(ctx context.Context, oid uint32, mode LargeObjectMode) (*LargeObject, error) { 33 | var fd int32 34 | err := o.tx.QueryRow(ctx, "select lo_open($1, $2)", oid, mode).Scan(&fd) 35 | if err != nil { 36 | return nil, err 37 | } 38 | return &LargeObject{fd: fd, tx: o.tx, ctx: ctx}, nil 39 | } 40 | 41 | // Unlink removes a large object from the database. 42 | func (o *LargeObjects) Unlink(ctx context.Context, oid uint32) error { 43 | var result int32 44 | err := o.tx.QueryRow(ctx, "select lo_unlink($1)", oid).Scan(&result) 45 | if err != nil { 46 | return err 47 | } 48 | 49 | if result != 1 { 50 | return errors.New("failed to remove large object") 51 | } 52 | 53 | return nil 54 | } 55 | 56 | // A LargeObject is a large object stored on the server. It is only valid within the transaction that it was initialized 57 | // in. It uses the context it was initialized with for all operations. It implements these interfaces: 58 | // 59 | // io.Writer 60 | // io.Reader 61 | // io.Seeker 62 | // io.Closer 63 | type LargeObject struct { 64 | ctx context.Context 65 | tx Tx 66 | fd int32 67 | } 68 | 69 | // Write writes p to the large object and returns the number of bytes written and an error if not all of p was written. 
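// Editorial sketch of the typical large object workflow (not part of the original file).
// The API is transaction-scoped: obtain LargeObjects from a Tx, create and open an
// object, then read and write through the returned *LargeObject before committing:
//
//	tx, err := conn.Begin(ctx)
//	lo := tx.LargeObjects()
//	oid, err := lo.Create(ctx, 0)                           // server assigns an unused OID
//	obj, err := lo.Open(ctx, oid, pgx.LargeObjectModeWrite) // open for writing
//	_, err = obj.Write([]byte("hello"))
//	err = obj.Close()
//	err = tx.Commit(ctx)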
70 | func (o *LargeObject) Write(p []byte) (int, error) { 71 | var n int 72 | err := o.tx.QueryRow(o.ctx, "select lowrite($1, $2)", o.fd, p).Scan(&n) 73 | if err != nil { 74 | return n, err 75 | } 76 | 77 | if n < 0 { 78 | return 0, errors.New("failed to write to large object") 79 | } 80 | 81 | return n, nil 82 | } 83 | 84 | // Read reads up to len(p) bytes into p returning the number of bytes read. 85 | func (o *LargeObject) Read(p []byte) (int, error) { 86 | var res []byte 87 | err := o.tx.QueryRow(o.ctx, "select loread($1, $2)", o.fd, len(p)).Scan(&res) 88 | copy(p, res) 89 | if err != nil { 90 | return len(res), err 91 | } 92 | 93 | if len(res) < len(p) { 94 | err = io.EOF 95 | } 96 | return len(res), err 97 | } 98 | 99 | // Seek moves the current location pointer to the new location specified by offset. 100 | func (o *LargeObject) Seek(offset int64, whence int) (n int64, err error) { 101 | err = o.tx.QueryRow(o.ctx, "select lo_lseek64($1, $2, $3)", o.fd, offset, whence).Scan(&n) 102 | return n, err 103 | } 104 | 105 | // Tell returns the current read or write location of the large object descriptor. 106 | func (o *LargeObject) Tell() (n int64, err error) { 107 | err = o.tx.QueryRow(o.ctx, "select lo_tell64($1)", o.fd).Scan(&n) 108 | return n, err 109 | } 110 | 111 | // Truncate the large object to size. 112 | func (o *LargeObject) Truncate(size int64) (err error) { 113 | _, err = o.tx.Exec(o.ctx, "select lo_truncate64($1, $2)", o.fd, size) 114 | return err 115 | } 116 | 117 | // Close the large object descriptor. 118 | func (o *LargeObject) Close() error { 119 | _, err := o.tx.Exec(o.ctx, "select lo_close($1)", o.fd) 120 | return err 121 | } 122 | -------------------------------------------------------------------------------- /extended_query_builder.go: -------------------------------------------------------------------------------- 1 | package pgx 2 | 3 | import ( 4 | "database/sql/driver" 5 | "fmt" 6 | "reflect" 7 | 8 | "github.com/jackc/pgtype" 9 | ) 10 | 11 | type extendedQueryBuilder struct { 12 | paramValues [][]byte 13 | paramValueBytes []byte 14 | paramFormats []int16 15 | resultFormats []int16 16 | } 17 | 18 | func (eqb *extendedQueryBuilder) AppendParam(ci *pgtype.ConnInfo, oid uint32, arg interface{}) error { 19 | f := chooseParameterFormatCode(ci, oid, arg) 20 | eqb.paramFormats = append(eqb.paramFormats, f) 21 | 22 | v, err := eqb.encodeExtendedParamValue(ci, oid, f, arg) 23 | if err != nil { 24 | return err 25 | } 26 | eqb.paramValues = append(eqb.paramValues, v) 27 | 28 | return nil 29 | } 30 | 31 | func (eqb *extendedQueryBuilder) AppendResultFormat(f int16) { 32 | eqb.resultFormats = append(eqb.resultFormats, f) 33 | } 34 | 35 | // Reset readies eqb to build another query. 
36 | func (eqb *extendedQueryBuilder) Reset() { 37 | eqb.paramValues = eqb.paramValues[0:0] 38 | eqb.paramValueBytes = eqb.paramValueBytes[0:0] 39 | eqb.paramFormats = eqb.paramFormats[0:0] 40 | eqb.resultFormats = eqb.resultFormats[0:0] 41 | 42 | if cap(eqb.paramValues) > 64 { 43 | eqb.paramValues = make([][]byte, 0, 64) 44 | } 45 | 46 | if cap(eqb.paramValueBytes) > 256 { 47 | eqb.paramValueBytes = make([]byte, 0, 256) 48 | } 49 | 50 | if cap(eqb.paramFormats) > 64 { 51 | eqb.paramFormats = make([]int16, 0, 64) 52 | } 53 | if cap(eqb.resultFormats) > 64 { 54 | eqb.resultFormats = make([]int16, 0, 64) 55 | } 56 | } 57 | 58 | func (eqb *extendedQueryBuilder) encodeExtendedParamValue(ci *pgtype.ConnInfo, oid uint32, formatCode int16, arg interface{}) ([]byte, error) { 59 | if arg == nil { 60 | return nil, nil 61 | } 62 | 63 | refVal := reflect.ValueOf(arg) 64 | argIsPtr := refVal.Kind() == reflect.Ptr 65 | 66 | if argIsPtr && refVal.IsNil() { 67 | return nil, nil 68 | } 69 | 70 | if eqb.paramValueBytes == nil { 71 | eqb.paramValueBytes = make([]byte, 0, 128) 72 | } 73 | 74 | var err error 75 | var buf []byte 76 | pos := len(eqb.paramValueBytes) 77 | 78 | if arg, ok := arg.(string); ok { 79 | return []byte(arg), nil 80 | } 81 | 82 | if formatCode == TextFormatCode { 83 | if arg, ok := arg.(pgtype.TextEncoder); ok { 84 | buf, err = arg.EncodeText(ci, eqb.paramValueBytes) 85 | if err != nil { 86 | return nil, err 87 | } 88 | if buf == nil { 89 | return nil, nil 90 | } 91 | eqb.paramValueBytes = buf 92 | return eqb.paramValueBytes[pos:], nil 93 | } 94 | } else if formatCode == BinaryFormatCode { 95 | if arg, ok := arg.(pgtype.BinaryEncoder); ok { 96 | buf, err = arg.EncodeBinary(ci, eqb.paramValueBytes) 97 | if err != nil { 98 | return nil, err 99 | } 100 | if buf == nil { 101 | return nil, nil 102 | } 103 | eqb.paramValueBytes = buf 104 | return eqb.paramValueBytes[pos:], nil 105 | } 106 | } 107 | 108 | if argIsPtr { 109 | // We have already checked that arg is not pointing to nil, 110 | // so it is safe to dereference here. 111 | arg = refVal.Elem().Interface() 112 | return eqb.encodeExtendedParamValue(ci, oid, formatCode, arg) 113 | } 114 | 115 | if dt, ok := ci.DataTypeForOID(oid); ok { 116 | value := dt.Value 117 | err := value.Set(arg) 118 | if err != nil { 119 | { 120 | if arg, ok := arg.(driver.Valuer); ok { 121 | v, err := callValuerValue(arg) 122 | if err != nil { 123 | return nil, err 124 | } 125 | return eqb.encodeExtendedParamValue(ci, oid, formatCode, v) 126 | } 127 | } 128 | 129 | return nil, err 130 | } 131 | 132 | return eqb.encodeExtendedParamValue(ci, oid, formatCode, value) 133 | } 134 | 135 | // There is no data type registered for the destination OID, but maybe there is data type registered for the arg 136 | // type. If so use it's text encoder (if available). 
137 | if dt, ok := ci.DataTypeForValue(arg); ok { 138 | value := dt.Value 139 | if textEncoder, ok := value.(pgtype.TextEncoder); ok { 140 | err := value.Set(arg) 141 | if err != nil { 142 | return nil, err 143 | } 144 | 145 | buf, err = textEncoder.EncodeText(ci, eqb.paramValueBytes) 146 | if err != nil { 147 | return nil, err 148 | } 149 | if buf == nil { 150 | return nil, nil 151 | } 152 | eqb.paramValueBytes = buf 153 | return eqb.paramValueBytes[pos:], nil 154 | } 155 | } 156 | 157 | if strippedArg, ok := stripNamedType(&refVal); ok { 158 | return eqb.encodeExtendedParamValue(ci, oid, formatCode, strippedArg) 159 | } 160 | return nil, SerializationError(fmt.Sprintf("Cannot encode %T into oid %v - %T must implement Encoder or be converted to a string", arg, oid, arg)) 161 | } 162 | -------------------------------------------------------------------------------- /pgxpool/conn.go: -------------------------------------------------------------------------------- 1 | package pgxpool 2 | 3 | import ( 4 | "context" 5 | "sync/atomic" 6 | 7 | "github.com/jackc/pgconn" 8 | "github.com/jackc/pgx/v4" 9 | "github.com/jackc/puddle" 10 | ) 11 | 12 | // Conn is an acquired *pgx.Conn from a Pool. 13 | type Conn struct { 14 | res *puddle.Resource 15 | p *Pool 16 | } 17 | 18 | // Release returns c to the pool it was acquired from. Once Release has been called, other methods must not be called. 19 | // However, it is safe to call Release multiple times. Subsequent calls after the first will be ignored. 20 | func (c *Conn) Release() { 21 | if c.res == nil { 22 | return 23 | } 24 | 25 | conn := c.Conn() 26 | res := c.res 27 | c.res = nil 28 | 29 | if conn.IsClosed() || conn.PgConn().IsBusy() || conn.PgConn().TxStatus() != 'I' { 30 | res.Destroy() 31 | // Signal to the health check to run since we just destroyed a connections 32 | // and we might be below minConns now 33 | c.p.triggerHealthCheck() 34 | return 35 | } 36 | 37 | // If the pool is consistently being used, we might never get to check the 38 | // lifetime of a connection since we only check idle connections in checkConnsHealth 39 | // so we also check the lifetime here and force a health check 40 | if c.p.isExpired(res) { 41 | atomic.AddInt64(&c.p.lifetimeDestroyCount, 1) 42 | res.Destroy() 43 | // Signal to the health check to run since we just destroyed a connections 44 | // and we might be below minConns now 45 | c.p.triggerHealthCheck() 46 | return 47 | } 48 | 49 | if c.p.afterRelease == nil { 50 | res.Release() 51 | return 52 | } 53 | 54 | go func() { 55 | if c.p.afterRelease(conn) { 56 | res.Release() 57 | } else { 58 | res.Destroy() 59 | // Signal to the health check to run since we just destroyed a connections 60 | // and we might be below minConns now 61 | c.p.triggerHealthCheck() 62 | } 63 | }() 64 | } 65 | 66 | // Hijack assumes ownership of the connection from the pool. Caller is responsible for closing the connection. Hijack 67 | // will panic if called on an already released or hijacked connection. 68 | func (c *Conn) Hijack() *pgx.Conn { 69 | if c.res == nil { 70 | panic("cannot hijack already released or hijacked connection") 71 | } 72 | 73 | conn := c.Conn() 74 | res := c.res 75 | c.res = nil 76 | 77 | res.Hijack() 78 | 79 | return conn 80 | } 81 | 82 | func (c *Conn) Exec(ctx context.Context, sql string, arguments ...interface{}) (pgconn.CommandTag, error) { 83 | return c.Conn().Exec(ctx, sql, arguments...) 
84 | } 85 | 86 | func (c *Conn) Query(ctx context.Context, sql string, args ...interface{}) (pgx.Rows, error) { 87 | return c.Conn().Query(ctx, sql, args...) 88 | } 89 | 90 | func (c *Conn) QueryRow(ctx context.Context, sql string, args ...interface{}) pgx.Row { 91 | return c.Conn().QueryRow(ctx, sql, args...) 92 | } 93 | 94 | func (c *Conn) QueryFunc(ctx context.Context, sql string, args []interface{}, scans []interface{}, f func(pgx.QueryFuncRow) error) (pgconn.CommandTag, error) { 95 | return c.Conn().QueryFunc(ctx, sql, args, scans, f) 96 | } 97 | 98 | func (c *Conn) SendBatch(ctx context.Context, b *pgx.Batch) pgx.BatchResults { 99 | return c.Conn().SendBatch(ctx, b) 100 | } 101 | 102 | func (c *Conn) CopyFrom(ctx context.Context, tableName pgx.Identifier, columnNames []string, rowSrc pgx.CopyFromSource) (int64, error) { 103 | return c.Conn().CopyFrom(ctx, tableName, columnNames, rowSrc) 104 | } 105 | 106 | // Begin starts a transaction block from the *Conn without explicitly setting a transaction mode (see BeginTx with TxOptions if transaction mode is required). 107 | func (c *Conn) Begin(ctx context.Context) (pgx.Tx, error) { 108 | return c.Conn().Begin(ctx) 109 | } 110 | 111 | // BeginTx starts a transaction block from the *Conn with txOptions determining the transaction mode. 112 | func (c *Conn) BeginTx(ctx context.Context, txOptions pgx.TxOptions) (pgx.Tx, error) { 113 | return c.Conn().BeginTx(ctx, txOptions) 114 | } 115 | 116 | func (c *Conn) BeginFunc(ctx context.Context, f func(pgx.Tx) error) error { 117 | return c.Conn().BeginFunc(ctx, f) 118 | } 119 | 120 | func (c *Conn) BeginTxFunc(ctx context.Context, txOptions pgx.TxOptions, f func(pgx.Tx) error) error { 121 | return c.Conn().BeginTxFunc(ctx, txOptions, f) 122 | } 123 | 124 | func (c *Conn) Ping(ctx context.Context) error { 125 | return c.Conn().Ping(ctx) 126 | } 127 | 128 | func (c *Conn) Conn() *pgx.Conn { 129 | return c.connResource().conn 130 | } 131 | 132 | func (c *Conn) connResource() *connResource { 133 | return c.res.Value().(*connResource) 134 | } 135 | 136 | func (c *Conn) getPoolRow(r pgx.Row) *poolRow { 137 | return c.connResource().getPoolRow(c, r) 138 | } 139 | 140 | func (c *Conn) getPoolRows(r pgx.Rows) *poolRows { 141 | return c.connResource().getPoolRows(c, r) 142 | } 143 | -------------------------------------------------------------------------------- /copy_from.go: -------------------------------------------------------------------------------- 1 | package pgx 2 | 3 | import ( 4 | "bytes" 5 | "context" 6 | "fmt" 7 | "io" 8 | "time" 9 | 10 | "github.com/jackc/pgconn" 11 | "github.com/jackc/pgio" 12 | ) 13 | 14 | // CopyFromRows returns a CopyFromSource interface over the provided rows slice 15 | // making it usable by *Conn.CopyFrom. 16 | func CopyFromRows(rows [][]interface{}) CopyFromSource { 17 | return ©FromRows{rows: rows, idx: -1} 18 | } 19 | 20 | type copyFromRows struct { 21 | rows [][]interface{} 22 | idx int 23 | } 24 | 25 | func (ctr *copyFromRows) Next() bool { 26 | ctr.idx++ 27 | return ctr.idx < len(ctr.rows) 28 | } 29 | 30 | func (ctr *copyFromRows) Values() ([]interface{}, error) { 31 | return ctr.rows[ctr.idx], nil 32 | } 33 | 34 | func (ctr *copyFromRows) Err() error { 35 | return nil 36 | } 37 | 38 | // CopyFromSlice returns a CopyFromSource interface over a dynamic func 39 | // making it usable by *Conn.CopyFrom. 
40 | func CopyFromSlice(length int, next func(int) ([]interface{}, error)) CopyFromSource { 41 | return ©FromSlice{next: next, idx: -1, len: length} 42 | } 43 | 44 | type copyFromSlice struct { 45 | next func(int) ([]interface{}, error) 46 | idx int 47 | len int 48 | err error 49 | } 50 | 51 | func (cts *copyFromSlice) Next() bool { 52 | cts.idx++ 53 | return cts.idx < cts.len 54 | } 55 | 56 | func (cts *copyFromSlice) Values() ([]interface{}, error) { 57 | values, err := cts.next(cts.idx) 58 | if err != nil { 59 | cts.err = err 60 | } 61 | return values, err 62 | } 63 | 64 | func (cts *copyFromSlice) Err() error { 65 | return cts.err 66 | } 67 | 68 | // CopyFromSource is the interface used by *Conn.CopyFrom as the source for copy data. 69 | type CopyFromSource interface { 70 | // Next returns true if there is another row and makes the next row data 71 | // available to Values(). When there are no more rows available or an error 72 | // has occurred it returns false. 73 | Next() bool 74 | 75 | // Values returns the values for the current row. 76 | Values() ([]interface{}, error) 77 | 78 | // Err returns any error that has been encountered by the CopyFromSource. If 79 | // this is not nil *Conn.CopyFrom will abort the copy. 80 | Err() error 81 | } 82 | 83 | type copyFrom struct { 84 | conn *Conn 85 | tableName Identifier 86 | columnNames []string 87 | rowSrc CopyFromSource 88 | readerErrChan chan error 89 | } 90 | 91 | func (ct *copyFrom) run(ctx context.Context) (int64, error) { 92 | quotedTableName := ct.tableName.Sanitize() 93 | cbuf := &bytes.Buffer{} 94 | for i, cn := range ct.columnNames { 95 | if i != 0 { 96 | cbuf.WriteString(", ") 97 | } 98 | cbuf.WriteString(quoteIdentifier(cn)) 99 | } 100 | quotedColumnNames := cbuf.String() 101 | 102 | sd, err := ct.conn.Prepare(ctx, "", fmt.Sprintf("select %s from %s", quotedColumnNames, quotedTableName)) 103 | if err != nil { 104 | return 0, err 105 | } 106 | 107 | r, w := io.Pipe() 108 | doneChan := make(chan struct{}) 109 | 110 | go func() { 111 | defer close(doneChan) 112 | 113 | // Purposely NOT using defer w.Close(). See https://github.com/golang/go/issues/24283. 114 | buf := ct.conn.wbuf 115 | 116 | buf = append(buf, "PGCOPY\n\377\r\n\000"...) 
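// Editorial note on the header being built above: the 11-byte "PGCOPY\n\377\r\n\000"
// signature starts PostgreSQL's binary COPY header, and the two 32-bit zero values
// appended on the next lines are the flags field and the header extension length,
// completing the 19-byte header.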
117 | buf = pgio.AppendInt32(buf, 0) 118 | buf = pgio.AppendInt32(buf, 0) 119 | 120 | moreRows := true 121 | for moreRows { 122 | var err error 123 | moreRows, buf, err = ct.buildCopyBuf(buf, sd) 124 | if err != nil { 125 | w.CloseWithError(err) 126 | return 127 | } 128 | 129 | if ct.rowSrc.Err() != nil { 130 | w.CloseWithError(ct.rowSrc.Err()) 131 | return 132 | } 133 | 134 | if len(buf) > 0 { 135 | _, err = w.Write(buf) 136 | if err != nil { 137 | w.Close() 138 | return 139 | } 140 | } 141 | 142 | buf = buf[:0] 143 | } 144 | 145 | w.Close() 146 | }() 147 | 148 | startTime := time.Now() 149 | 150 | commandTag, err := ct.conn.pgConn.CopyFrom(ctx, r, fmt.Sprintf("copy %s ( %s ) from stdin binary;", quotedTableName, quotedColumnNames)) 151 | 152 | r.Close() 153 | <-doneChan 154 | 155 | rowsAffected := commandTag.RowsAffected() 156 | endTime := time.Now() 157 | if err == nil { 158 | if ct.conn.shouldLog(LogLevelInfo) { 159 | ct.conn.log(ctx, LogLevelInfo, "CopyFrom", map[string]interface{}{"tableName": ct.tableName, "columnNames": ct.columnNames, "time": endTime.Sub(startTime), "rowCount": rowsAffected}) 160 | } 161 | } else if ct.conn.shouldLog(LogLevelError) { 162 | ct.conn.log(ctx, LogLevelError, "CopyFrom", map[string]interface{}{"err": err, "tableName": ct.tableName, "columnNames": ct.columnNames, "time": endTime.Sub(startTime)}) 163 | } 164 | 165 | return rowsAffected, err 166 | } 167 | 168 | func (ct *copyFrom) buildCopyBuf(buf []byte, sd *pgconn.StatementDescription) (bool, []byte, error) { 169 | 170 | for ct.rowSrc.Next() { 171 | values, err := ct.rowSrc.Values() 172 | if err != nil { 173 | return false, nil, err 174 | } 175 | if len(values) != len(ct.columnNames) { 176 | return false, nil, fmt.Errorf("expected %d values, got %d values", len(ct.columnNames), len(values)) 177 | } 178 | 179 | buf = pgio.AppendInt16(buf, int16(len(ct.columnNames))) 180 | for i, val := range values { 181 | buf, err = encodePreparedStatementArgument(ct.conn.connInfo, buf, sd.Fields[i].DataTypeOID, val) 182 | if err != nil { 183 | return false, nil, err 184 | } 185 | } 186 | 187 | if len(buf) > 65536 { 188 | return true, buf, nil 189 | } 190 | } 191 | 192 | return false, buf, nil 193 | } 194 | 195 | // CopyFrom uses the PostgreSQL copy protocol to perform bulk data insertion. 196 | // It returns the number of rows copied and an error. 197 | // 198 | // CopyFrom requires all values use the binary format. Almost all types 199 | // implemented by pgx use the binary format by default. Types implementing 200 | // Encoder can only be used if they encode to the binary format. 
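// Illustrative usage of CopyFrom with an in-memory row source (an editorial sketch, not
// part of the original file; the people table and its columns are hypothetical):
//
//	rows := [][]interface{}{
//		{"John", int32(36)},
//		{"Jane", int32(29)},
//	}
//	copyCount, err := conn.CopyFrom(
//		ctx,
//		pgx.Identifier{"people"},
//		[]string{"name", "age"},
//		pgx.CopyFromRows(rows),
//	)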
201 | func (c *Conn) CopyFrom(ctx context.Context, tableName Identifier, columnNames []string, rowSrc CopyFromSource) (int64, error) { 202 | ct := ©From{ 203 | conn: c, 204 | tableName: tableName, 205 | columnNames: columnNames, 206 | rowSrc: rowSrc, 207 | readerErrChan: make(chan error), 208 | } 209 | 210 | return ct.run(ctx) 211 | } 212 | -------------------------------------------------------------------------------- /helper_test.go: -------------------------------------------------------------------------------- 1 | package pgx_test 2 | 3 | import ( 4 | "context" 5 | "os" 6 | "testing" 7 | 8 | "github.com/stretchr/testify/assert" 9 | 10 | "github.com/jackc/pgconn" 11 | "github.com/jackc/pgx/v4" 12 | "github.com/stretchr/testify/require" 13 | ) 14 | 15 | func testWithAndWithoutPreferSimpleProtocol(t *testing.T, f func(t *testing.T, conn *pgx.Conn)) { 16 | t.Run("SimpleProto", 17 | func(t *testing.T) { 18 | config, err := pgx.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) 19 | require.NoError(t, err) 20 | 21 | config.PreferSimpleProtocol = true 22 | conn, err := pgx.ConnectConfig(context.Background(), config) 23 | require.NoError(t, err) 24 | defer func() { 25 | err := conn.Close(context.Background()) 26 | require.NoError(t, err) 27 | }() 28 | 29 | f(t, conn) 30 | 31 | ensureConnValid(t, conn) 32 | }, 33 | ) 34 | 35 | t.Run("DefaultProto", 36 | func(t *testing.T) { 37 | config, err := pgx.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) 38 | require.NoError(t, err) 39 | 40 | conn, err := pgx.ConnectConfig(context.Background(), config) 41 | require.NoError(t, err) 42 | defer func() { 43 | err := conn.Close(context.Background()) 44 | require.NoError(t, err) 45 | }() 46 | 47 | f(t, conn) 48 | 49 | ensureConnValid(t, conn) 50 | }, 51 | ) 52 | } 53 | 54 | func mustConnectString(t testing.TB, connString string) *pgx.Conn { 55 | conn, err := pgx.Connect(context.Background(), connString) 56 | if err != nil { 57 | t.Fatalf("Unable to establish connection: %v", err) 58 | } 59 | return conn 60 | } 61 | 62 | func mustParseConfig(t testing.TB, connString string) *pgx.ConnConfig { 63 | config, err := pgx.ParseConfig(connString) 64 | require.Nil(t, err) 65 | return config 66 | } 67 | 68 | func mustConnect(t testing.TB, config *pgx.ConnConfig) *pgx.Conn { 69 | conn, err := pgx.ConnectConfig(context.Background(), config) 70 | if err != nil { 71 | t.Fatalf("Unable to establish connection: %v", err) 72 | } 73 | return conn 74 | } 75 | 76 | func closeConn(t testing.TB, conn *pgx.Conn) { 77 | err := conn.Close(context.Background()) 78 | if err != nil { 79 | t.Fatalf("conn.Close unexpectedly failed: %v", err) 80 | } 81 | } 82 | 83 | func mustExec(t testing.TB, conn *pgx.Conn, sql string, arguments ...interface{}) (commandTag pgconn.CommandTag) { 84 | var err error 85 | if commandTag, err = conn.Exec(context.Background(), sql, arguments...); err != nil { 86 | t.Fatalf("Exec unexpectedly failed with %v: %v", sql, err) 87 | } 88 | return 89 | } 90 | 91 | // Do a simple query to ensure the connection is still usable 92 | func ensureConnValid(t *testing.T, conn *pgx.Conn) { 93 | var sum, rowCount int32 94 | 95 | rows, err := conn.Query(context.Background(), "select generate_series(1,$1)", 10) 96 | if err != nil { 97 | t.Fatalf("conn.Query failed: %v", err) 98 | } 99 | defer rows.Close() 100 | 101 | for rows.Next() { 102 | var n int32 103 | rows.Scan(&n) 104 | sum += n 105 | rowCount++ 106 | } 107 | 108 | if rows.Err() != nil { 109 | t.Fatalf("conn.Query failed: %v", err) 110 | } 111 | 112 | if rowCount != 10 { 113 | 
t.Error("Select called onDataRow wrong number of times") 114 | } 115 | if sum != 55 { 116 | t.Error("Wrong values returned") 117 | } 118 | } 119 | 120 | func assertConfigsEqual(t *testing.T, expected, actual *pgx.ConnConfig, testName string) { 121 | if !assert.NotNil(t, expected) { 122 | return 123 | } 124 | if !assert.NotNil(t, actual) { 125 | return 126 | } 127 | 128 | assert.Equalf(t, expected.Logger, actual.Logger, "%s - Logger", testName) 129 | assert.Equalf(t, expected.LogLevel, actual.LogLevel, "%s - LogLevel", testName) 130 | assert.Equalf(t, expected.ConnString(), actual.ConnString(), "%s - ConnString", testName) 131 | // Can't test function equality, so just test that they are set or not. 132 | assert.Equalf(t, expected.BuildStatementCache == nil, actual.BuildStatementCache == nil, "%s - BuildStatementCache", testName) 133 | assert.Equalf(t, expected.PreferSimpleProtocol, actual.PreferSimpleProtocol, "%s - PreferSimpleProtocol", testName) 134 | 135 | assert.Equalf(t, expected.Host, actual.Host, "%s - Host", testName) 136 | assert.Equalf(t, expected.Database, actual.Database, "%s - Database", testName) 137 | assert.Equalf(t, expected.Port, actual.Port, "%s - Port", testName) 138 | assert.Equalf(t, expected.User, actual.User, "%s - User", testName) 139 | assert.Equalf(t, expected.Password, actual.Password, "%s - Password", testName) 140 | assert.Equalf(t, expected.ConnectTimeout, actual.ConnectTimeout, "%s - ConnectTimeout", testName) 141 | assert.Equalf(t, expected.RuntimeParams, actual.RuntimeParams, "%s - RuntimeParams", testName) 142 | 143 | // Can't test function equality, so just test that they are set or not. 144 | assert.Equalf(t, expected.ValidateConnect == nil, actual.ValidateConnect == nil, "%s - ValidateConnect", testName) 145 | assert.Equalf(t, expected.AfterConnect == nil, actual.AfterConnect == nil, "%s - AfterConnect", testName) 146 | 147 | if assert.Equalf(t, expected.TLSConfig == nil, actual.TLSConfig == nil, "%s - TLSConfig", testName) { 148 | if expected.TLSConfig != nil { 149 | assert.Equalf(t, expected.TLSConfig.InsecureSkipVerify, actual.TLSConfig.InsecureSkipVerify, "%s - TLSConfig InsecureSkipVerify", testName) 150 | assert.Equalf(t, expected.TLSConfig.ServerName, actual.TLSConfig.ServerName, "%s - TLSConfig ServerName", testName) 151 | } 152 | } 153 | 154 | if assert.Equalf(t, len(expected.Fallbacks), len(actual.Fallbacks), "%s - Fallbacks", testName) { 155 | for i := range expected.Fallbacks { 156 | assert.Equalf(t, expected.Fallbacks[i].Host, actual.Fallbacks[i].Host, "%s - Fallback %d - Host", testName, i) 157 | assert.Equalf(t, expected.Fallbacks[i].Port, actual.Fallbacks[i].Port, "%s - Fallback %d - Port", testName, i) 158 | 159 | if assert.Equalf(t, expected.Fallbacks[i].TLSConfig == nil, actual.Fallbacks[i].TLSConfig == nil, "%s - Fallback %d - TLSConfig", testName, i) { 160 | if expected.Fallbacks[i].TLSConfig != nil { 161 | assert.Equalf(t, expected.Fallbacks[i].TLSConfig.InsecureSkipVerify, actual.Fallbacks[i].TLSConfig.InsecureSkipVerify, "%s - Fallback %d - TLSConfig InsecureSkipVerify", testName) 162 | assert.Equalf(t, expected.Fallbacks[i].TLSConfig.ServerName, actual.Fallbacks[i].TLSConfig.ServerName, "%s - Fallback %d - TLSConfig ServerName", testName) 163 | } 164 | } 165 | } 166 | } 167 | } 168 | 169 | func skipCockroachDB(t testing.TB, conn *pgx.Conn, msg string) { 170 | if conn.PgConn().ParameterStatus("crdb_version") != "" { 171 | t.Skip(msg) 172 | } 173 | } 174 | 
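// Editorial example (a sketch, not one of the repository's files): how a test in this
// package would typically use the helpers above to run the same assertions against both
// the simple and the extended protocol. TestSelectOne is a hypothetical name.
//
//	func TestSelectOne(t *testing.T) {
//		testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, conn *pgx.Conn) {
//			var n int32
//			err := conn.QueryRow(context.Background(), "select 1").Scan(&n)
//			require.NoError(t, err)
//			require.EqualValues(t, 1, n)
//		})
//	}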
-------------------------------------------------------------------------------- /batch.go: -------------------------------------------------------------------------------- 1 | package pgx 2 | 3 | import ( 4 | "context" 5 | "errors" 6 | "fmt" 7 | 8 | "github.com/jackc/pgconn" 9 | ) 10 | 11 | type batchItem struct { 12 | query string 13 | arguments []interface{} 14 | } 15 | 16 | // Batch queries are a way of bundling multiple queries together to avoid 17 | // unnecessary network round trips. 18 | type Batch struct { 19 | items []*batchItem 20 | } 21 | 22 | // Queue queues a query to batch b. query can be an SQL query or the name of a prepared statement. 23 | func (b *Batch) Queue(query string, arguments ...interface{}) { 24 | b.items = append(b.items, &batchItem{ 25 | query: query, 26 | arguments: arguments, 27 | }) 28 | } 29 | 30 | // Len returns number of queries that have been queued so far. 31 | func (b *Batch) Len() int { 32 | return len(b.items) 33 | } 34 | 35 | type BatchResults interface { 36 | // Exec reads the results from the next query in the batch as if the query has been sent with Conn.Exec. 37 | Exec() (pgconn.CommandTag, error) 38 | 39 | // Query reads the results from the next query in the batch as if the query has been sent with Conn.Query. 40 | Query() (Rows, error) 41 | 42 | // QueryRow reads the results from the next query in the batch as if the query has been sent with Conn.QueryRow. 43 | QueryRow() Row 44 | 45 | // QueryFunc reads the results from the next query in the batch as if the query has been sent with Conn.QueryFunc. 46 | QueryFunc(scans []interface{}, f func(QueryFuncRow) error) (pgconn.CommandTag, error) 47 | 48 | // Close closes the batch operation. This must be called before the underlying connection can be used again. Any error 49 | // that occurred during a batch operation may have made it impossible to resyncronize the connection with the server. 50 | // In this case the underlying connection will have been closed. Close is safe to call multiple times. 51 | Close() error 52 | } 53 | 54 | type batchResults struct { 55 | ctx context.Context 56 | conn *Conn 57 | mrr *pgconn.MultiResultReader 58 | err error 59 | b *Batch 60 | ix int 61 | closed bool 62 | } 63 | 64 | // Exec reads the results from the next query in the batch as if the query has been sent with Exec. 
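// Illustrative end-to-end batch usage (an editorial sketch, not part of the original
// file; the ledger table is hypothetical):
//
//	batch := &pgx.Batch{}
//	batch.Queue("insert into ledger(description, amount) values($1, $2)", "coffee", 3)
//	batch.Queue("select count(*) from ledger")
//	br := conn.SendBatch(context.Background(), batch)
//	_, err := br.Exec()                 // result of the insert
//	var rowCount int64
//	err = br.QueryRow().Scan(&rowCount) // result of the select
//	err = br.Close()                    // required before the connection is reused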
65 | func (br *batchResults) Exec() (pgconn.CommandTag, error) { 66 | if br.err != nil { 67 | return nil, br.err 68 | } 69 | if br.closed { 70 | return nil, fmt.Errorf("batch already closed") 71 | } 72 | 73 | query, arguments, _ := br.nextQueryAndArgs() 74 | 75 | if !br.mrr.NextResult() { 76 | err := br.mrr.Close() 77 | if err == nil { 78 | err = errors.New("no result") 79 | } 80 | if br.conn.shouldLog(LogLevelError) { 81 | br.conn.log(br.ctx, LogLevelError, "BatchResult.Exec", map[string]interface{}{ 82 | "sql": query, 83 | "args": logQueryArgs(arguments), 84 | "err": err, 85 | }) 86 | } 87 | return nil, err 88 | } 89 | 90 | commandTag, err := br.mrr.ResultReader().Close() 91 | 92 | if err != nil { 93 | if br.conn.shouldLog(LogLevelError) { 94 | br.conn.log(br.ctx, LogLevelError, "BatchResult.Exec", map[string]interface{}{ 95 | "sql": query, 96 | "args": logQueryArgs(arguments), 97 | "err": err, 98 | }) 99 | } 100 | } else if br.conn.shouldLog(LogLevelInfo) { 101 | br.conn.log(br.ctx, LogLevelInfo, "BatchResult.Exec", map[string]interface{}{ 102 | "sql": query, 103 | "args": logQueryArgs(arguments), 104 | "commandTag": commandTag, 105 | }) 106 | } 107 | 108 | return commandTag, err 109 | } 110 | 111 | // Query reads the results from the next query in the batch as if the query has been sent with Query. 112 | func (br *batchResults) Query() (Rows, error) { 113 | query, arguments, ok := br.nextQueryAndArgs() 114 | if !ok { 115 | query = "batch query" 116 | } 117 | 118 | if br.err != nil { 119 | return &connRows{err: br.err, closed: true}, br.err 120 | } 121 | 122 | if br.closed { 123 | alreadyClosedErr := fmt.Errorf("batch already closed") 124 | return &connRows{err: alreadyClosedErr, closed: true}, alreadyClosedErr 125 | } 126 | 127 | rows := br.conn.getRows(br.ctx, query, arguments) 128 | 129 | if !br.mrr.NextResult() { 130 | rows.err = br.mrr.Close() 131 | if rows.err == nil { 132 | rows.err = errors.New("no result") 133 | } 134 | rows.closed = true 135 | 136 | if br.conn.shouldLog(LogLevelError) { 137 | br.conn.log(br.ctx, LogLevelError, "BatchResult.Query", map[string]interface{}{ 138 | "sql": query, 139 | "args": logQueryArgs(arguments), 140 | "err": rows.err, 141 | }) 142 | } 143 | 144 | return rows, rows.err 145 | } 146 | 147 | rows.resultReader = br.mrr.ResultReader() 148 | return rows, nil 149 | } 150 | 151 | // QueryFunc reads the results from the next query in the batch as if the query has been sent with Conn.QueryFunc. 152 | func (br *batchResults) QueryFunc(scans []interface{}, f func(QueryFuncRow) error) (pgconn.CommandTag, error) { 153 | if br.closed { 154 | return nil, fmt.Errorf("batch already closed") 155 | } 156 | 157 | rows, err := br.Query() 158 | if err != nil { 159 | return nil, err 160 | } 161 | defer rows.Close() 162 | 163 | for rows.Next() { 164 | err = rows.Scan(scans...) 165 | if err != nil { 166 | return nil, err 167 | } 168 | 169 | err = f(rows) 170 | if err != nil { 171 | return nil, err 172 | } 173 | } 174 | 175 | if err := rows.Err(); err != nil { 176 | return nil, err 177 | } 178 | 179 | return rows.CommandTag(), nil 180 | } 181 | 182 | // QueryRow reads the results from the next query in the batch as if the query has been sent with QueryRow. 183 | func (br *batchResults) QueryRow() Row { 184 | rows, _ := br.Query() 185 | return (*connRow)(rows.(*connRows)) 186 | 187 | } 188 | 189 | // Close closes the batch operation. Any error that occurred during a batch operation may have made it impossible to 190 | // resyncronize the connection with the server. 
In this case the underlying connection will have been closed. 191 | func (br *batchResults) Close() error { 192 | if br.err != nil { 193 | return br.err 194 | } 195 | 196 | if br.closed { 197 | return nil 198 | } 199 | br.closed = true 200 | 201 | // log any queries that haven't yet been logged by Exec or Query 202 | for { 203 | query, args, ok := br.nextQueryAndArgs() 204 | if !ok { 205 | break 206 | } 207 | 208 | if br.conn.shouldLog(LogLevelInfo) { 209 | br.conn.log(br.ctx, LogLevelInfo, "BatchResult.Close", map[string]interface{}{ 210 | "sql": query, 211 | "args": logQueryArgs(args), 212 | }) 213 | } 214 | } 215 | 216 | return br.mrr.Close() 217 | } 218 | 219 | func (br *batchResults) nextQueryAndArgs() (query string, args []interface{}, ok bool) { 220 | if br.b != nil && br.ix < len(br.b.items) { 221 | bi := br.b.items[br.ix] 222 | query = bi.query 223 | args = bi.arguments 224 | ok = true 225 | br.ix++ 226 | } 227 | return 228 | } 229 | -------------------------------------------------------------------------------- /large_objects_test.go: -------------------------------------------------------------------------------- 1 | package pgx_test 2 | 3 | import ( 4 | "context" 5 | "io" 6 | "os" 7 | "testing" 8 | "time" 9 | 10 | "github.com/jackc/pgconn" 11 | "github.com/jackc/pgx/v4" 12 | ) 13 | 14 | func TestLargeObjects(t *testing.T) { 15 | t.Parallel() 16 | 17 | ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) 18 | defer cancel() 19 | 20 | conn, err := pgx.Connect(ctx, os.Getenv("PGX_TEST_DATABASE")) 21 | if err != nil { 22 | t.Fatal(err) 23 | } 24 | 25 | skipCockroachDB(t, conn, "Server does not support large objects") 26 | 27 | tx, err := conn.Begin(ctx) 28 | if err != nil { 29 | t.Fatal(err) 30 | } 31 | 32 | testLargeObjects(t, ctx, tx) 33 | } 34 | 35 | func TestLargeObjectsPreferSimpleProtocol(t *testing.T) { 36 | t.Parallel() 37 | 38 | ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) 39 | defer cancel() 40 | 41 | config, err := pgx.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) 42 | if err != nil { 43 | t.Fatal(err) 44 | } 45 | 46 | config.PreferSimpleProtocol = true 47 | 48 | conn, err := pgx.ConnectConfig(ctx, config) 49 | if err != nil { 50 | t.Fatal(err) 51 | } 52 | 53 | skipCockroachDB(t, conn, "Server does not support large objects") 54 | 55 | tx, err := conn.Begin(ctx) 56 | if err != nil { 57 | t.Fatal(err) 58 | } 59 | 60 | testLargeObjects(t, ctx, tx) 61 | } 62 | 63 | func testLargeObjects(t *testing.T, ctx context.Context, tx pgx.Tx) { 64 | lo := tx.LargeObjects() 65 | 66 | id, err := lo.Create(ctx, 0) 67 | if err != nil { 68 | t.Fatal(err) 69 | } 70 | 71 | obj, err := lo.Open(ctx, id, pgx.LargeObjectModeRead|pgx.LargeObjectModeWrite) 72 | if err != nil { 73 | t.Fatal(err) 74 | } 75 | 76 | n, err := obj.Write([]byte("testing")) 77 | if err != nil { 78 | t.Fatal(err) 79 | } 80 | if n != 7 { 81 | t.Errorf("Expected n to be 7, got %d", n) 82 | } 83 | 84 | pos, err := obj.Seek(1, 0) 85 | if err != nil { 86 | t.Fatal(err) 87 | } 88 | if pos != 1 { 89 | t.Errorf("Expected pos to be 1, got %d", pos) 90 | } 91 | 92 | res := make([]byte, 6) 93 | n, err = obj.Read(res) 94 | if err != nil { 95 | t.Fatal(err) 96 | } 97 | if string(res) != "esting" { 98 | t.Errorf(`Expected res to be "esting", got %q`, res) 99 | } 100 | if n != 6 { 101 | t.Errorf("Expected n to be 6, got %d", n) 102 | } 103 | 104 | n, err = obj.Read(res) 105 | if err != io.EOF { 106 | t.Errorf("Expected io.EOF, got %v", err) 107 | } 108 | if n != 0 { 109 | t.Errorf("Expected n to be 0, got
%d", n) 110 | } 111 | 112 | pos, err = obj.Tell() 113 | if err != nil { 114 | t.Fatal(err) 115 | } 116 | if pos != 7 { 117 | t.Errorf("Expected pos to be 7, got %d", pos) 118 | } 119 | 120 | err = obj.Truncate(1) 121 | if err != nil { 122 | t.Fatal(err) 123 | } 124 | 125 | pos, err = obj.Seek(-1, 2) 126 | if err != nil { 127 | t.Fatal(err) 128 | } 129 | if pos != 0 { 130 | t.Errorf("Expected pos to be 0, got %d", pos) 131 | } 132 | 133 | res = make([]byte, 2) 134 | n, err = obj.Read(res) 135 | if err != io.EOF { 136 | t.Errorf("Expected err to be io.EOF, got %v", err) 137 | } 138 | if n != 1 { 139 | t.Errorf("Expected n to be 1, got %d", n) 140 | } 141 | if res[0] != 't' { 142 | t.Errorf("Expected res[0] to be 't', got %v", res[0]) 143 | } 144 | 145 | err = obj.Close() 146 | if err != nil { 147 | t.Fatal(err) 148 | } 149 | 150 | err = lo.Unlink(ctx, id) 151 | if err != nil { 152 | t.Fatal(err) 153 | } 154 | 155 | _, err = lo.Open(ctx, id, pgx.LargeObjectModeRead) 156 | if e, ok := err.(*pgconn.PgError); !ok || e.Code != "42704" { 157 | t.Errorf("Expected undefined_object error (42704), got %#v", err) 158 | } 159 | } 160 | 161 | func TestLargeObjectsMultipleTransactions(t *testing.T) { 162 | t.Parallel() 163 | 164 | ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) 165 | defer cancel() 166 | 167 | conn, err := pgx.Connect(ctx, os.Getenv("PGX_TEST_DATABASE")) 168 | if err != nil { 169 | t.Fatal(err) 170 | } 171 | 172 | skipCockroachDB(t, conn, "Server does support large objects") 173 | 174 | tx, err := conn.Begin(ctx) 175 | if err != nil { 176 | t.Fatal(err) 177 | } 178 | 179 | lo := tx.LargeObjects() 180 | 181 | id, err := lo.Create(ctx, 0) 182 | if err != nil { 183 | t.Fatal(err) 184 | } 185 | 186 | obj, err := lo.Open(ctx, id, pgx.LargeObjectModeWrite) 187 | if err != nil { 188 | t.Fatal(err) 189 | } 190 | 191 | n, err := obj.Write([]byte("testing")) 192 | if err != nil { 193 | t.Fatal(err) 194 | } 195 | if n != 7 { 196 | t.Errorf("Expected n to be 7, got %d", n) 197 | } 198 | 199 | // Commit the first transaction 200 | err = tx.Commit(ctx) 201 | if err != nil { 202 | t.Fatal(err) 203 | } 204 | 205 | // IMPORTANT: Use the same connection for another query 206 | query := `select n from generate_series(1,10) n` 207 | rows, err := conn.Query(ctx, query) 208 | if err != nil { 209 | t.Fatal(err) 210 | } 211 | rows.Close() 212 | 213 | // Start a new transaction 214 | tx2, err := conn.Begin(ctx) 215 | if err != nil { 216 | t.Fatal(err) 217 | } 218 | 219 | lo2 := tx2.LargeObjects() 220 | 221 | // Reopen the large object in the new transaction 222 | obj2, err := lo2.Open(ctx, id, pgx.LargeObjectModeRead|pgx.LargeObjectModeWrite) 223 | if err != nil { 224 | t.Fatal(err) 225 | } 226 | 227 | pos, err := obj2.Seek(1, 0) 228 | if err != nil { 229 | t.Fatal(err) 230 | } 231 | if pos != 1 { 232 | t.Errorf("Expected pos to be 1, got %d", pos) 233 | } 234 | 235 | res := make([]byte, 6) 236 | n, err = obj2.Read(res) 237 | if err != nil { 238 | t.Fatal(err) 239 | } 240 | if string(res) != "esting" { 241 | t.Errorf(`Expected res to be "esting", got %q`, res) 242 | } 243 | if n != 6 { 244 | t.Errorf("Expected n to be 6, got %d", n) 245 | } 246 | 247 | n, err = obj2.Read(res) 248 | if err != io.EOF { 249 | t.Error("Expected io.EOF, go nil") 250 | } 251 | if n != 0 { 252 | t.Errorf("Expected n to be 0, got %d", n) 253 | } 254 | 255 | pos, err = obj2.Tell() 256 | if err != nil { 257 | t.Fatal(err) 258 | } 259 | if pos != 7 { 260 | t.Errorf("Expected pos to be 7, got %d", pos) 261 | } 262 | 
263 | err = obj2.Truncate(1) 264 | if err != nil { 265 | t.Fatal(err) 266 | } 267 | 268 | pos, err = obj2.Seek(-1, 2) 269 | if err != nil { 270 | t.Fatal(err) 271 | } 272 | if pos != 0 { 273 | t.Errorf("Expected pos to be 0, got %d", pos) 274 | } 275 | 276 | res = make([]byte, 2) 277 | n, err = obj2.Read(res) 278 | if err != io.EOF { 279 | t.Errorf("Expected err to be io.EOF, got %v", err) 280 | } 281 | if n != 1 { 282 | t.Errorf("Expected n to be 1, got %d", n) 283 | } 284 | if res[0] != 't' { 285 | t.Errorf("Expected res[0] to be 't', got %v", res[0]) 286 | } 287 | 288 | err = obj2.Close() 289 | if err != nil { 290 | t.Fatal(err) 291 | } 292 | 293 | err = lo2.Unlink(ctx, id) 294 | if err != nil { 295 | t.Fatal(err) 296 | } 297 | 298 | _, err = lo2.Open(ctx, id, pgx.LargeObjectModeRead) 299 | if e, ok := err.(*pgconn.PgError); !ok || e.Code != "42704" { 300 | t.Errorf("Expected undefined_object error (42704), got %#v", err) 301 | } 302 | } 303 | -------------------------------------------------------------------------------- /internal/sanitize/sanitize_test.go: -------------------------------------------------------------------------------- 1 | package sanitize_test 2 | 3 | import ( 4 | "testing" 5 | "time" 6 | 7 | "github.com/jackc/pgx/v4/internal/sanitize" 8 | ) 9 | 10 | func TestNewQuery(t *testing.T) { 11 | successTests := []struct { 12 | sql string 13 | expected sanitize.Query 14 | }{ 15 | { 16 | sql: "select 42", 17 | expected: sanitize.Query{Parts: []sanitize.Part{"select 42"}}, 18 | }, 19 | { 20 | sql: "select $1", 21 | expected: sanitize.Query{Parts: []sanitize.Part{"select ", 1}}, 22 | }, 23 | { 24 | sql: "select 'quoted $42', $1", 25 | expected: sanitize.Query{Parts: []sanitize.Part{"select 'quoted $42', ", 1}}, 26 | }, 27 | { 28 | sql: `select "doubled quoted $42", $1`, 29 | expected: sanitize.Query{Parts: []sanitize.Part{`select "doubled quoted $42", `, 1}}, 30 | }, 31 | { 32 | sql: "select 'foo''bar', $1", 33 | expected: sanitize.Query{Parts: []sanitize.Part{"select 'foo''bar', ", 1}}, 34 | }, 35 | { 36 | sql: `select "foo""bar", $1`, 37 | expected: sanitize.Query{Parts: []sanitize.Part{`select "foo""bar", `, 1}}, 38 | }, 39 | { 40 | sql: "select '''', $1", 41 | expected: sanitize.Query{Parts: []sanitize.Part{"select '''', ", 1}}, 42 | }, 43 | { 44 | sql: `select """", $1`, 45 | expected: sanitize.Query{Parts: []sanitize.Part{`select """", `, 1}}, 46 | }, 47 | { 48 | sql: "select $1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11", 49 | expected: sanitize.Query{Parts: []sanitize.Part{"select ", 1, ", ", 2, ", ", 3, ", ", 4, ", ", 5, ", ", 6, ", ", 7, ", ", 8, ", ", 9, ", ", 10, ", ", 11}}, 50 | }, 51 | { 52 | sql: `select "adsf""$1""adsf", $1, 'foo''$$12bar', $2, '$3'`, 53 | expected: sanitize.Query{Parts: []sanitize.Part{`select "adsf""$1""adsf", `, 1, `, 'foo''$$12bar', `, 2, `, '$3'`}}, 54 | }, 55 | { 56 | sql: `select E'escape string\' $42', $1`, 57 | expected: sanitize.Query{Parts: []sanitize.Part{`select E'escape string\' $42', `, 1}}, 58 | }, 59 | { 60 | sql: `select e'escape string\' $42', $1`, 61 | expected: sanitize.Query{Parts: []sanitize.Part{`select e'escape string\' $42', `, 1}}, 62 | }, 63 | { 64 | sql: `select /* a baby's toy */ 'barbie', $1`, 65 | expected: sanitize.Query{Parts: []sanitize.Part{`select /* a baby's toy */ 'barbie', `, 1}}, 66 | }, 67 | { 68 | sql: `select /* *_* */ $1`, 69 | expected: sanitize.Query{Parts: []sanitize.Part{`select /* *_* */ `, 1}}, 70 | }, 71 | { 72 | sql: `select 42 /* /* /* 42 */ */ */, $1`, 73 | expected: sanitize.Query{Parts: 
[]sanitize.Part{`select 42 /* /* /* 42 */ */ */, `, 1}}, 74 | }, 75 | { 76 | sql: "select -- a baby's toy\n'barbie', $1", 77 | expected: sanitize.Query{Parts: []sanitize.Part{"select -- a baby's toy\n'barbie', ", 1}}, 78 | }, 79 | { 80 | sql: "select 42 -- is a Deep Thought's favorite number", 81 | expected: sanitize.Query{Parts: []sanitize.Part{"select 42 -- is a Deep Thought's favorite number"}}, 82 | }, 83 | { 84 | sql: "select 42, -- \\nis a Deep Thought's favorite number\n$1", 85 | expected: sanitize.Query{Parts: []sanitize.Part{"select 42, -- \\nis a Deep Thought's favorite number\n", 1}}, 86 | }, 87 | { 88 | sql: "select 42, -- \\nis a Deep Thought's favorite number\r$1", 89 | expected: sanitize.Query{Parts: []sanitize.Part{"select 42, -- \\nis a Deep Thought's favorite number\r", 1}}, 90 | }, 91 | } 92 | 93 | for i, tt := range successTests { 94 | query, err := sanitize.NewQuery(tt.sql) 95 | if err != nil { 96 | t.Errorf("%d. %v", i, err) 97 | } 98 | 99 | if len(query.Parts) == len(tt.expected.Parts) { 100 | for j := range query.Parts { 101 | if query.Parts[j] != tt.expected.Parts[j] { 102 | t.Errorf("%d. expected part %d to be %v but it was %v", i, j, tt.expected.Parts[j], query.Parts[j]) 103 | } 104 | } 105 | } else { 106 | t.Errorf("%d. expected query parts to be %v but it was %v", i, tt.expected.Parts, query.Parts) 107 | } 108 | } 109 | } 110 | 111 | func TestQuerySanitize(t *testing.T) { 112 | successfulTests := []struct { 113 | query sanitize.Query 114 | args []interface{} 115 | expected string 116 | }{ 117 | { 118 | query: sanitize.Query{Parts: []sanitize.Part{"select 42"}}, 119 | args: []interface{}{}, 120 | expected: `select 42`, 121 | }, 122 | { 123 | query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}}, 124 | args: []interface{}{int64(42)}, 125 | expected: `select 42`, 126 | }, 127 | { 128 | query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}}, 129 | args: []interface{}{float64(1.23)}, 130 | expected: `select 1.23`, 131 | }, 132 | { 133 | query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}}, 134 | args: []interface{}{true}, 135 | expected: `select true`, 136 | }, 137 | { 138 | query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}}, 139 | args: []interface{}{[]byte{0, 1, 2, 3, 255}}, 140 | expected: `select '\x00010203ff'`, 141 | }, 142 | { 143 | query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}}, 144 | args: []interface{}{nil}, 145 | expected: `select null`, 146 | }, 147 | { 148 | query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}}, 149 | args: []interface{}{"foobar"}, 150 | expected: `select 'foobar'`, 151 | }, 152 | { 153 | query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}}, 154 | args: []interface{}{"foo'bar"}, 155 | expected: `select 'foo''bar'`, 156 | }, 157 | { 158 | query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}}, 159 | args: []interface{}{`foo\'bar`}, 160 | expected: `select 'foo\''bar'`, 161 | }, 162 | { 163 | query: sanitize.Query{Parts: []sanitize.Part{"insert ", 1}}, 164 | args: []interface{}{time.Date(2020, time.March, 1, 23, 59, 59, 999999999, time.UTC)}, 165 | expected: `insert '2020-03-01 23:59:59.999999Z'`, 166 | }, 167 | } 168 | 169 | for i, tt := range successfulTests { 170 | actual, err := tt.query.Sanitize(tt.args...) 171 | if err != nil { 172 | t.Errorf("%d. %v", i, err) 173 | continue 174 | } 175 | 176 | if tt.expected != actual { 177 | t.Errorf("%d. 
expected %s, but got %s", i, tt.expected, actual) 178 | } 179 | } 180 | 181 | errorTests := []struct { 182 | query sanitize.Query 183 | args []interface{} 184 | expected string 185 | }{ 186 | { 187 | query: sanitize.Query{Parts: []sanitize.Part{"select ", 1, ", ", 2}}, 188 | args: []interface{}{int64(42)}, 189 | expected: `insufficient arguments`, 190 | }, 191 | { 192 | query: sanitize.Query{Parts: []sanitize.Part{"select 'foo'"}}, 193 | args: []interface{}{int64(42)}, 194 | expected: `unused argument: 0`, 195 | }, 196 | { 197 | query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}}, 198 | args: []interface{}{42}, 199 | expected: `invalid arg type: int`, 200 | }, 201 | } 202 | 203 | for i, tt := range errorTests { 204 | _, err := tt.query.Sanitize(tt.args...) 205 | if err == nil || err.Error() != tt.expected { 206 | t.Errorf("%d. expected error %v, got %v", i, tt.expected, err) 207 | } 208 | } 209 | } 210 | -------------------------------------------------------------------------------- /internal/sanitize/sanitize.go: -------------------------------------------------------------------------------- 1 | package sanitize 2 | 3 | import ( 4 | "bytes" 5 | "encoding/hex" 6 | "fmt" 7 | "strconv" 8 | "strings" 9 | "time" 10 | "unicode/utf8" 11 | ) 12 | 13 | // Part is either a string or an int. A string is raw SQL. An int is a 14 | // argument placeholder. 15 | type Part interface{} 16 | 17 | type Query struct { 18 | Parts []Part 19 | } 20 | 21 | func (q *Query) Sanitize(args ...interface{}) (string, error) { 22 | argUse := make([]bool, len(args)) 23 | buf := &bytes.Buffer{} 24 | 25 | for _, part := range q.Parts { 26 | var str string 27 | switch part := part.(type) { 28 | case string: 29 | str = part 30 | case int: 31 | argIdx := part - 1 32 | if argIdx >= len(args) { 33 | return "", fmt.Errorf("insufficient arguments") 34 | } 35 | arg := args[argIdx] 36 | switch arg := arg.(type) { 37 | case nil: 38 | str = "null" 39 | case int64: 40 | str = strconv.FormatInt(arg, 10) 41 | case float64: 42 | str = strconv.FormatFloat(arg, 'f', -1, 64) 43 | case bool: 44 | str = strconv.FormatBool(arg) 45 | case []byte: 46 | str = QuoteBytes(arg) 47 | case string: 48 | str = QuoteString(arg) 49 | case time.Time: 50 | str = arg.Truncate(time.Microsecond).Format("'2006-01-02 15:04:05.999999999Z07:00:00'") 51 | default: 52 | return "", fmt.Errorf("invalid arg type: %T", arg) 53 | } 54 | argUse[argIdx] = true 55 | default: 56 | return "", fmt.Errorf("invalid Part type: %T", part) 57 | } 58 | buf.WriteString(str) 59 | } 60 | 61 | for i, used := range argUse { 62 | if !used { 63 | return "", fmt.Errorf("unused argument: %d", i) 64 | } 65 | } 66 | return buf.String(), nil 67 | } 68 | 69 | func NewQuery(sql string) (*Query, error) { 70 | l := &sqlLexer{ 71 | src: sql, 72 | stateFn: rawState, 73 | } 74 | 75 | for l.stateFn != nil { 76 | l.stateFn = l.stateFn(l) 77 | } 78 | 79 | query := &Query{Parts: l.parts} 80 | 81 | return query, nil 82 | } 83 | 84 | func QuoteString(str string) string { 85 | return "'" + strings.ReplaceAll(str, "'", "''") + "'" 86 | } 87 | 88 | func QuoteBytes(buf []byte) string { 89 | return `'\x` + hex.EncodeToString(buf) + "'" 90 | } 91 | 92 | type sqlLexer struct { 93 | src string 94 | start int 95 | pos int 96 | nested int // multiline comment nesting level. 
97 | stateFn stateFn 98 | parts []Part 99 | } 100 | 101 | type stateFn func(*sqlLexer) stateFn 102 | 103 | func rawState(l *sqlLexer) stateFn { 104 | for { 105 | r, width := utf8.DecodeRuneInString(l.src[l.pos:]) 106 | l.pos += width 107 | 108 | switch r { 109 | case 'e', 'E': 110 | nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:]) 111 | if nextRune == '\'' { 112 | l.pos += width 113 | return escapeStringState 114 | } 115 | case '\'': 116 | return singleQuoteState 117 | case '"': 118 | return doubleQuoteState 119 | case '$': 120 | nextRune, _ := utf8.DecodeRuneInString(l.src[l.pos:]) 121 | if '0' <= nextRune && nextRune <= '9' { 122 | if l.pos-l.start > 0 { 123 | l.parts = append(l.parts, l.src[l.start:l.pos-width]) 124 | } 125 | l.start = l.pos 126 | return placeholderState 127 | } 128 | case '-': 129 | nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:]) 130 | if nextRune == '-' { 131 | l.pos += width 132 | return oneLineCommentState 133 | } 134 | case '/': 135 | nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:]) 136 | if nextRune == '*' { 137 | l.pos += width 138 | return multilineCommentState 139 | } 140 | case utf8.RuneError: 141 | if l.pos-l.start > 0 { 142 | l.parts = append(l.parts, l.src[l.start:l.pos]) 143 | l.start = l.pos 144 | } 145 | return nil 146 | } 147 | } 148 | } 149 | 150 | func singleQuoteState(l *sqlLexer) stateFn { 151 | for { 152 | r, width := utf8.DecodeRuneInString(l.src[l.pos:]) 153 | l.pos += width 154 | 155 | switch r { 156 | case '\'': 157 | nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:]) 158 | if nextRune != '\'' { 159 | return rawState 160 | } 161 | l.pos += width 162 | case utf8.RuneError: 163 | if l.pos-l.start > 0 { 164 | l.parts = append(l.parts, l.src[l.start:l.pos]) 165 | l.start = l.pos 166 | } 167 | return nil 168 | } 169 | } 170 | } 171 | 172 | func doubleQuoteState(l *sqlLexer) stateFn { 173 | for { 174 | r, width := utf8.DecodeRuneInString(l.src[l.pos:]) 175 | l.pos += width 176 | 177 | switch r { 178 | case '"': 179 | nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:]) 180 | if nextRune != '"' { 181 | return rawState 182 | } 183 | l.pos += width 184 | case utf8.RuneError: 185 | if l.pos-l.start > 0 { 186 | l.parts = append(l.parts, l.src[l.start:l.pos]) 187 | l.start = l.pos 188 | } 189 | return nil 190 | } 191 | } 192 | } 193 | 194 | // placeholderState consumes a placeholder value. The $ must have already has 195 | // already been consumed. The first rune must be a digit. 
196 | func placeholderState(l *sqlLexer) stateFn { 197 | num := 0 198 | 199 | for { 200 | r, width := utf8.DecodeRuneInString(l.src[l.pos:]) 201 | l.pos += width 202 | 203 | if '0' <= r && r <= '9' { 204 | num *= 10 205 | num += int(r - '0') 206 | } else { 207 | l.parts = append(l.parts, num) 208 | l.pos -= width 209 | l.start = l.pos 210 | return rawState 211 | } 212 | } 213 | } 214 | 215 | func escapeStringState(l *sqlLexer) stateFn { 216 | for { 217 | r, width := utf8.DecodeRuneInString(l.src[l.pos:]) 218 | l.pos += width 219 | 220 | switch r { 221 | case '\\': 222 | _, width = utf8.DecodeRuneInString(l.src[l.pos:]) 223 | l.pos += width 224 | case '\'': 225 | nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:]) 226 | if nextRune != '\'' { 227 | return rawState 228 | } 229 | l.pos += width 230 | case utf8.RuneError: 231 | if l.pos-l.start > 0 { 232 | l.parts = append(l.parts, l.src[l.start:l.pos]) 233 | l.start = l.pos 234 | } 235 | return nil 236 | } 237 | } 238 | } 239 | 240 | func oneLineCommentState(l *sqlLexer) stateFn { 241 | for { 242 | r, width := utf8.DecodeRuneInString(l.src[l.pos:]) 243 | l.pos += width 244 | 245 | switch r { 246 | case '\\': 247 | _, width = utf8.DecodeRuneInString(l.src[l.pos:]) 248 | l.pos += width 249 | case '\n', '\r': 250 | return rawState 251 | case utf8.RuneError: 252 | if l.pos-l.start > 0 { 253 | l.parts = append(l.parts, l.src[l.start:l.pos]) 254 | l.start = l.pos 255 | } 256 | return nil 257 | } 258 | } 259 | } 260 | 261 | func multilineCommentState(l *sqlLexer) stateFn { 262 | for { 263 | r, width := utf8.DecodeRuneInString(l.src[l.pos:]) 264 | l.pos += width 265 | 266 | switch r { 267 | case '/': 268 | nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:]) 269 | if nextRune == '*' { 270 | l.pos += width 271 | l.nested++ 272 | } 273 | case '*': 274 | nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:]) 275 | if nextRune != '/' { 276 | continue 277 | } 278 | 279 | l.pos += width 280 | if l.nested == 0 { 281 | return rawState 282 | } 283 | l.nested-- 284 | 285 | case utf8.RuneError: 286 | if l.pos-l.start > 0 { 287 | l.parts = append(l.parts, l.src[l.start:l.pos]) 288 | l.start = l.pos 289 | } 290 | return nil 291 | } 292 | } 293 | } 294 | 295 | // SanitizeSQL replaces placeholder values with args. It quotes and escapes args 296 | // as necessary. This function is only safe when standard_conforming_strings is 297 | // on. 298 | func SanitizeSQL(sql string, args ...interface{}) (string, error) { 299 | query, err := NewQuery(sql) 300 | if err != nil { 301 | return "", err 302 | } 303 | return query.Sanitize(args...) 
304 | } 305 | -------------------------------------------------------------------------------- /values.go: -------------------------------------------------------------------------------- 1 | package pgx 2 | 3 | import ( 4 | "database/sql/driver" 5 | "fmt" 6 | "math" 7 | "reflect" 8 | "time" 9 | 10 | "github.com/jackc/pgio" 11 | "github.com/jackc/pgtype" 12 | ) 13 | 14 | // PostgreSQL format codes 15 | const ( 16 | TextFormatCode = 0 17 | BinaryFormatCode = 1 18 | ) 19 | 20 | // SerializationError occurs on failure to encode or decode a value 21 | type SerializationError string 22 | 23 | func (e SerializationError) Error() string { 24 | return string(e) 25 | } 26 | 27 | func convertSimpleArgument(ci *pgtype.ConnInfo, arg interface{}) (interface{}, error) { 28 | if arg == nil { 29 | return nil, nil 30 | } 31 | 32 | refVal := reflect.ValueOf(arg) 33 | if refVal.Kind() == reflect.Ptr && refVal.IsNil() { 34 | return nil, nil 35 | } 36 | 37 | switch arg := arg.(type) { 38 | 39 | // https://github.com/jackc/pgx/issues/409 Changed JSON and JSONB to surface 40 | // []byte to database/sql instead of string. But that caused problems with the 41 | // simple protocol because the driver.Valuer case got taken before the 42 | // pgtype.TextEncoder case. And driver.Valuer needed to be first in the usual 43 | // case because of https://github.com/jackc/pgx/issues/339. So instead we 44 | // special case JSON and JSONB. 45 | case *pgtype.JSON: 46 | buf, err := arg.EncodeText(ci, nil) 47 | if err != nil { 48 | return nil, err 49 | } 50 | if buf == nil { 51 | return nil, nil 52 | } 53 | return string(buf), nil 54 | case *pgtype.JSONB: 55 | buf, err := arg.EncodeText(ci, nil) 56 | if err != nil { 57 | return nil, err 58 | } 59 | if buf == nil { 60 | return nil, nil 61 | } 62 | return string(buf), nil 63 | 64 | case driver.Valuer: 65 | return callValuerValue(arg) 66 | case pgtype.TextEncoder: 67 | buf, err := arg.EncodeText(ci, nil) 68 | if err != nil { 69 | return nil, err 70 | } 71 | if buf == nil { 72 | return nil, nil 73 | } 74 | return string(buf), nil 75 | case float32: 76 | return float64(arg), nil 77 | case float64: 78 | return arg, nil 79 | case bool: 80 | return arg, nil 81 | case time.Duration: 82 | return fmt.Sprintf("%d microsecond", int64(arg)/1000), nil 83 | case time.Time: 84 | return arg, nil 85 | case string: 86 | return arg, nil 87 | case []byte: 88 | return arg, nil 89 | case int8: 90 | return int64(arg), nil 91 | case int16: 92 | return int64(arg), nil 93 | case int32: 94 | return int64(arg), nil 95 | case int64: 96 | return arg, nil 97 | case int: 98 | return int64(arg), nil 99 | case uint8: 100 | return int64(arg), nil 101 | case uint16: 102 | return int64(arg), nil 103 | case uint32: 104 | return int64(arg), nil 105 | case uint64: 106 | if arg > math.MaxInt64 { 107 | return nil, fmt.Errorf("arg too big for int64: %v", arg) 108 | } 109 | return int64(arg), nil 110 | case uint: 111 | if uint64(arg) > math.MaxInt64 { 112 | return nil, fmt.Errorf("arg too big for int64: %v", arg) 113 | } 114 | return int64(arg), nil 115 | } 116 | 117 | if dt, found := ci.DataTypeForValue(arg); found { 118 | v := dt.Value 119 | err := v.Set(arg) 120 | if err != nil { 121 | return nil, err 122 | } 123 | buf, err := v.(pgtype.TextEncoder).EncodeText(ci, nil) 124 | if err != nil { 125 | return nil, err 126 | } 127 | if buf == nil { 128 | return nil, nil 129 | } 130 | return string(buf), nil 131 | } 132 | 133 | if refVal.Kind() == reflect.Ptr { 134 | arg = refVal.Elem().Interface() 135 | return 
convertSimpleArgument(ci, arg) 136 | } 137 | 138 | if strippedArg, ok := stripNamedType(&refVal); ok { 139 | return convertSimpleArgument(ci, strippedArg) 140 | } 141 | return nil, SerializationError(fmt.Sprintf("Cannot encode %T in simple protocol - %T must implement driver.Valuer, pgtype.TextEncoder, or be a native type", arg, arg)) 142 | } 143 | 144 | func encodePreparedStatementArgument(ci *pgtype.ConnInfo, buf []byte, oid uint32, arg interface{}) ([]byte, error) { 145 | if arg == nil { 146 | return pgio.AppendInt32(buf, -1), nil 147 | } 148 | 149 | switch arg := arg.(type) { 150 | case pgtype.BinaryEncoder: 151 | sp := len(buf) 152 | buf = pgio.AppendInt32(buf, -1) 153 | argBuf, err := arg.EncodeBinary(ci, buf) 154 | if err != nil { 155 | return nil, err 156 | } 157 | if argBuf != nil { 158 | buf = argBuf 159 | pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) 160 | } 161 | return buf, nil 162 | case pgtype.TextEncoder: 163 | sp := len(buf) 164 | buf = pgio.AppendInt32(buf, -1) 165 | argBuf, err := arg.EncodeText(ci, buf) 166 | if err != nil { 167 | return nil, err 168 | } 169 | if argBuf != nil { 170 | buf = argBuf 171 | pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) 172 | } 173 | return buf, nil 174 | case string: 175 | buf = pgio.AppendInt32(buf, int32(len(arg))) 176 | buf = append(buf, arg...) 177 | return buf, nil 178 | } 179 | 180 | refVal := reflect.ValueOf(arg) 181 | 182 | if refVal.Kind() == reflect.Ptr { 183 | if refVal.IsNil() { 184 | return pgio.AppendInt32(buf, -1), nil 185 | } 186 | arg = refVal.Elem().Interface() 187 | return encodePreparedStatementArgument(ci, buf, oid, arg) 188 | } 189 | 190 | if dt, ok := ci.DataTypeForOID(oid); ok { 191 | value := dt.Value 192 | err := value.Set(arg) 193 | if err != nil { 194 | { 195 | if arg, ok := arg.(driver.Valuer); ok { 196 | v, err := callValuerValue(arg) 197 | if err != nil { 198 | return nil, err 199 | } 200 | return encodePreparedStatementArgument(ci, buf, oid, v) 201 | } 202 | } 203 | 204 | return nil, err 205 | } 206 | 207 | sp := len(buf) 208 | buf = pgio.AppendInt32(buf, -1) 209 | argBuf, err := value.(pgtype.BinaryEncoder).EncodeBinary(ci, buf) 210 | if err != nil { 211 | return nil, err 212 | } 213 | if argBuf != nil { 214 | buf = argBuf 215 | pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) 216 | } 217 | return buf, nil 218 | } 219 | 220 | if strippedArg, ok := stripNamedType(&refVal); ok { 221 | return encodePreparedStatementArgument(ci, buf, oid, strippedArg) 222 | } 223 | return nil, SerializationError(fmt.Sprintf("Cannot encode %T into oid %v - %T must implement Encoder or be converted to a string", arg, oid, arg)) 224 | } 225 | 226 | // chooseParameterFormatCode determines the correct format code for an 227 | // argument to a prepared statement. It defaults to TextFormatCode if no 228 | // determination can be made. 
229 | func chooseParameterFormatCode(ci *pgtype.ConnInfo, oid uint32, arg interface{}) int16 { 230 | switch arg := arg.(type) { 231 | case pgtype.ParamFormatPreferrer: 232 | return arg.PreferredParamFormat() 233 | case pgtype.BinaryEncoder: 234 | return BinaryFormatCode 235 | case string, *string, pgtype.TextEncoder: 236 | return TextFormatCode 237 | } 238 | 239 | return ci.ParamFormatCodeForOID(oid) 240 | } 241 | 242 | func stripNamedType(val *reflect.Value) (interface{}, bool) { 243 | switch val.Kind() { 244 | case reflect.Int: 245 | convVal := int(val.Int()) 246 | return convVal, reflect.TypeOf(convVal) != val.Type() 247 | case reflect.Int8: 248 | convVal := int8(val.Int()) 249 | return convVal, reflect.TypeOf(convVal) != val.Type() 250 | case reflect.Int16: 251 | convVal := int16(val.Int()) 252 | return convVal, reflect.TypeOf(convVal) != val.Type() 253 | case reflect.Int32: 254 | convVal := int32(val.Int()) 255 | return convVal, reflect.TypeOf(convVal) != val.Type() 256 | case reflect.Int64: 257 | convVal := int64(val.Int()) 258 | return convVal, reflect.TypeOf(convVal) != val.Type() 259 | case reflect.Uint: 260 | convVal := uint(val.Uint()) 261 | return convVal, reflect.TypeOf(convVal) != val.Type() 262 | case reflect.Uint8: 263 | convVal := uint8(val.Uint()) 264 | return convVal, reflect.TypeOf(convVal) != val.Type() 265 | case reflect.Uint16: 266 | convVal := uint16(val.Uint()) 267 | return convVal, reflect.TypeOf(convVal) != val.Type() 268 | case reflect.Uint32: 269 | convVal := uint32(val.Uint()) 270 | return convVal, reflect.TypeOf(convVal) != val.Type() 271 | case reflect.Uint64: 272 | convVal := uint64(val.Uint()) 273 | return convVal, reflect.TypeOf(convVal) != val.Type() 274 | case reflect.String: 275 | convVal := val.String() 276 | return convVal, reflect.TypeOf(convVal) != val.Type() 277 | } 278 | 279 | return nil, false 280 | } 281 | -------------------------------------------------------------------------------- /pgxpool/common_test.go: -------------------------------------------------------------------------------- 1 | package pgxpool_test 2 | 3 | import ( 4 | "context" 5 | "testing" 6 | "time" 7 | 8 | "github.com/jackc/pgx/v4/pgxpool" 9 | 10 | "github.com/jackc/pgconn" 11 | "github.com/jackc/pgx/v4" 12 | "github.com/stretchr/testify/assert" 13 | "github.com/stretchr/testify/require" 14 | ) 15 | 16 | // Conn.Release is an asynchronous process that returns immediately. There is no signal when the actual work is 17 | // completed. To test something that relies on the actual work for Conn.Release being completed we must simply wait. 18 | // This function wraps the sleep so there is more meaning for the callers. 
19 | func waitForReleaseToComplete() { 20 | time.Sleep(500 * time.Millisecond) 21 | } 22 | 23 | type execer interface { 24 | Exec(ctx context.Context, sql string, arguments ...interface{}) (pgconn.CommandTag, error) 25 | } 26 | 27 | func testExec(t *testing.T, db execer) { 28 | results, err := db.Exec(context.Background(), "set time zone 'America/Chicago'") 29 | require.NoError(t, err) 30 | assert.EqualValues(t, "SET", results) 31 | } 32 | 33 | type queryer interface { 34 | Query(ctx context.Context, sql string, args ...interface{}) (pgx.Rows, error) 35 | } 36 | 37 | func testQuery(t *testing.T, db queryer) { 38 | var sum, rowCount int32 39 | 40 | rows, err := db.Query(context.Background(), "select generate_series(1,$1)", 10) 41 | require.NoError(t, err) 42 | 43 | for rows.Next() { 44 | var n int32 45 | rows.Scan(&n) 46 | sum += n 47 | rowCount++ 48 | } 49 | 50 | assert.NoError(t, rows.Err()) 51 | assert.Equal(t, int32(10), rowCount) 52 | assert.Equal(t, int32(55), sum) 53 | } 54 | 55 | type queryRower interface { 56 | QueryRow(ctx context.Context, sql string, args ...interface{}) pgx.Row 57 | } 58 | 59 | func testQueryRow(t *testing.T, db queryRower) { 60 | var what, who string 61 | err := db.QueryRow(context.Background(), "select 'hello', $1::text", "world").Scan(&what, &who) 62 | assert.NoError(t, err) 63 | assert.Equal(t, "hello", what) 64 | assert.Equal(t, "world", who) 65 | } 66 | 67 | type sendBatcher interface { 68 | SendBatch(context.Context, *pgx.Batch) pgx.BatchResults 69 | } 70 | 71 | func testSendBatch(t *testing.T, db sendBatcher) { 72 | batch := &pgx.Batch{} 73 | batch.Queue("select 1") 74 | batch.Queue("select 2") 75 | 76 | br := db.SendBatch(context.Background(), batch) 77 | 78 | var err error 79 | var n int32 80 | err = br.QueryRow().Scan(&n) 81 | assert.NoError(t, err) 82 | assert.EqualValues(t, 1, n) 83 | 84 | err = br.QueryRow().Scan(&n) 85 | assert.NoError(t, err) 86 | assert.EqualValues(t, 2, n) 87 | 88 | err = br.Close() 89 | assert.NoError(t, err) 90 | } 91 | 92 | type copyFromer interface { 93 | CopyFrom(context.Context, pgx.Identifier, []string, pgx.CopyFromSource) (int64, error) 94 | } 95 | 96 | func testCopyFrom(t *testing.T, db interface { 97 | execer 98 | queryer 99 | copyFromer 100 | }) { 101 | _, err := db.Exec(context.Background(), `create temporary table foo(a int2, b int4, c int8, d varchar, e text, f date, g timestamptz)`) 102 | require.NoError(t, err) 103 | 104 | tzedTime := time.Date(2010, 2, 3, 4, 5, 6, 0, time.Local) 105 | 106 | inputRows := [][]interface{}{ 107 | {int16(0), int32(1), int64(2), "abc", "efg", time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC), tzedTime}, 108 | {nil, nil, nil, nil, nil, nil, nil}, 109 | } 110 | 111 | copyCount, err := db.CopyFrom(context.Background(), pgx.Identifier{"foo"}, []string{"a", "b", "c", "d", "e", "f", "g"}, pgx.CopyFromRows(inputRows)) 112 | assert.NoError(t, err) 113 | assert.EqualValues(t, len(inputRows), copyCount) 114 | 115 | rows, err := db.Query(context.Background(), "select * from foo") 116 | assert.NoError(t, err) 117 | 118 | var outputRows [][]interface{} 119 | for rows.Next() { 120 | row, err := rows.Values() 121 | if err != nil { 122 | t.Errorf("Unexpected error for rows.Values(): %v", err) 123 | } 124 | outputRows = append(outputRows, row) 125 | } 126 | 127 | assert.NoError(t, rows.Err()) 128 | assert.Equal(t, inputRows, outputRows) 129 | } 130 | 131 | func assertConfigsEqual(t *testing.T, expected, actual *pgxpool.Config, testName string) { 132 | if !assert.NotNil(t, expected) { 133 | return 134 | } 
135 | if !assert.NotNil(t, actual) { 136 | return 137 | } 138 | 139 | assert.Equalf(t, expected.ConnString(), actual.ConnString(), "%s - ConnString", testName) 140 | 141 | // Can't test function equality, so just test that they are set or not. 142 | assert.Equalf(t, expected.AfterConnect == nil, actual.AfterConnect == nil, "%s - AfterConnect", testName) 143 | assert.Equalf(t, expected.BeforeAcquire == nil, actual.BeforeAcquire == nil, "%s - BeforeAcquire", testName) 144 | assert.Equalf(t, expected.AfterRelease == nil, actual.AfterRelease == nil, "%s - AfterRelease", testName) 145 | 146 | assert.Equalf(t, expected.MaxConnLifetime, actual.MaxConnLifetime, "%s - MaxConnLifetime", testName) 147 | assert.Equalf(t, expected.MaxConnIdleTime, actual.MaxConnIdleTime, "%s - MaxConnIdleTime", testName) 148 | assert.Equalf(t, expected.MaxConns, actual.MaxConns, "%s - MaxConns", testName) 149 | assert.Equalf(t, expected.MinConns, actual.MinConns, "%s - MinConns", testName) 150 | assert.Equalf(t, expected.HealthCheckPeriod, actual.HealthCheckPeriod, "%s - HealthCheckPeriod", testName) 151 | assert.Equalf(t, expected.LazyConnect, actual.LazyConnect, "%s - LazyConnect", testName) 152 | 153 | assertConnConfigsEqual(t, expected.ConnConfig, actual.ConnConfig, testName) 154 | } 155 | 156 | func assertConnConfigsEqual(t *testing.T, expected, actual *pgx.ConnConfig, testName string) { 157 | if !assert.NotNil(t, expected) { 158 | return 159 | } 160 | if !assert.NotNil(t, actual) { 161 | return 162 | } 163 | 164 | assert.Equalf(t, expected.Logger, actual.Logger, "%s - Logger", testName) 165 | assert.Equalf(t, expected.LogLevel, actual.LogLevel, "%s - LogLevel", testName) 166 | assert.Equalf(t, expected.ConnString(), actual.ConnString(), "%s - ConnString", testName) 167 | 168 | // Can't test function equality, so just test that they are set or not. 169 | assert.Equalf(t, expected.BuildStatementCache == nil, actual.BuildStatementCache == nil, "%s - BuildStatementCache", testName) 170 | 171 | assert.Equalf(t, expected.PreferSimpleProtocol, actual.PreferSimpleProtocol, "%s - PreferSimpleProtocol", testName) 172 | 173 | assert.Equalf(t, expected.Host, actual.Host, "%s - Host", testName) 174 | assert.Equalf(t, expected.Database, actual.Database, "%s - Database", testName) 175 | assert.Equalf(t, expected.Port, actual.Port, "%s - Port", testName) 176 | assert.Equalf(t, expected.User, actual.User, "%s - User", testName) 177 | assert.Equalf(t, expected.Password, actual.Password, "%s - Password", testName) 178 | assert.Equalf(t, expected.ConnectTimeout, actual.ConnectTimeout, "%s - ConnectTimeout", testName) 179 | assert.Equalf(t, expected.RuntimeParams, actual.RuntimeParams, "%s - RuntimeParams", testName) 180 | 181 | // Can't test function equality, so just test that they are set or not. 
182 | assert.Equalf(t, expected.ValidateConnect == nil, actual.ValidateConnect == nil, "%s - ValidateConnect", testName) 183 | assert.Equalf(t, expected.AfterConnect == nil, actual.AfterConnect == nil, "%s - AfterConnect", testName) 184 | 185 | if assert.Equalf(t, expected.TLSConfig == nil, actual.TLSConfig == nil, "%s - TLSConfig", testName) { 186 | if expected.TLSConfig != nil { 187 | assert.Equalf(t, expected.TLSConfig.InsecureSkipVerify, actual.TLSConfig.InsecureSkipVerify, "%s - TLSConfig InsecureSkipVerify", testName) 188 | assert.Equalf(t, expected.TLSConfig.ServerName, actual.TLSConfig.ServerName, "%s - TLSConfig ServerName", testName) 189 | } 190 | } 191 | 192 | if assert.Equalf(t, len(expected.Fallbacks), len(actual.Fallbacks), "%s - Fallbacks", testName) { 193 | for i := range expected.Fallbacks { 194 | assert.Equalf(t, expected.Fallbacks[i].Host, actual.Fallbacks[i].Host, "%s - Fallback %d - Host", testName, i) 195 | assert.Equalf(t, expected.Fallbacks[i].Port, actual.Fallbacks[i].Port, "%s - Fallback %d - Port", testName, i) 196 | 197 | if assert.Equalf(t, expected.Fallbacks[i].TLSConfig == nil, actual.Fallbacks[i].TLSConfig == nil, "%s - Fallback %d - TLSConfig", testName, i) { 198 | if expected.Fallbacks[i].TLSConfig != nil { 199 | assert.Equalf(t, expected.Fallbacks[i].TLSConfig.InsecureSkipVerify, actual.Fallbacks[i].TLSConfig.InsecureSkipVerify, "%s - Fallback %d - TLSConfig InsecureSkipVerify", testName) 200 | assert.Equalf(t, expected.Fallbacks[i].TLSConfig.ServerName, actual.Fallbacks[i].TLSConfig.ServerName, "%s - Fallback %d - TLSConfig ServerName", testName) 201 | } 202 | } 203 | } 204 | } 205 | } 206 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | [![](https://godoc.org/github.com/jackc/pgx?status.svg)](https://pkg.go.dev/github.com/jackc/pgx/v4) 2 | [![Build Status](https://travis-ci.org/jackc/pgx.svg)](https://travis-ci.org/jackc/pgx) 3 | 4 | # pgx - PostgreSQL Driver and Toolkit 5 | 6 | pgx is a pure Go driver and toolkit for PostgreSQL. 7 | 8 | pgx aims to be low-level, fast, and performant, while also enabling PostgreSQL-specific features that the standard `database/sql` package does not allow for. 9 | 10 | The driver component of pgx can be used alongside the standard `database/sql` package. 11 | 12 | The toolkit component is a related set of packages that implement PostgreSQL functionality such as parsing the wire protocol 13 | and type mapping between PostgreSQL and Go. These underlying packages can be used to implement alternative drivers, 14 | proxies, load balancers, logical replication clients, etc. 15 | 16 | The current release of `pgx v4` requires Go modules. To use the previous version, checkout and vendor the `v3` branch. 
17 | 
18 | ## Example Usage
19 | 
20 | ```go
21 | package main
22 | 
23 | import (
24 |     "context"
25 |     "fmt"
26 |     "os"
27 | 
28 |     "github.com/jackc/pgx/v4"
29 | )
30 | 
31 | func main() {
32 |     // urlExample := "postgres://username:password@localhost:5432/database_name"
33 |     conn, err := pgx.Connect(context.Background(), os.Getenv("DATABASE_URL"))
34 |     if err != nil {
35 |         fmt.Fprintf(os.Stderr, "Unable to connect to database: %v\n", err)
36 |         os.Exit(1)
37 |     }
38 |     defer conn.Close(context.Background())
39 | 
40 |     var name string
41 |     var weight int64
42 |     err = conn.QueryRow(context.Background(), "select name, weight from widgets where id=$1", 42).Scan(&name, &weight)
43 |     if err != nil {
44 |         fmt.Fprintf(os.Stderr, "QueryRow failed: %v\n", err)
45 |         os.Exit(1)
46 |     }
47 | 
48 |     fmt.Println(name, weight)
49 | }
50 | ```
51 | 
52 | See the [getting started guide](https://github.com/jackc/pgx/wiki/Getting-started-with-pgx) for more information.
53 | 
54 | ## Choosing Between the pgx and database/sql Interfaces
55 | 
56 | It is recommended to use the pgx interface if:
57 | 1. The application only targets PostgreSQL.
58 | 2. No other libraries that require `database/sql` are in use.
59 | 
60 | The pgx interface is faster and exposes more features.
61 | 
62 | The `database/sql` interface only allows the underlying driver to return or receive the following types: `int64`,
63 | `float64`, `bool`, `[]byte`, `string`, `time.Time`, or `nil`. Handling other types requires implementing the
64 | `database/sql.Scanner` and the `database/sql/driver.Valuer` interfaces, which require transmission of values in text format. The binary format can be substantially faster, which is what the pgx interface uses.
65 | 
66 | ## Features
67 | 
68 | pgx supports many features beyond what is available through `database/sql`:
69 | 
70 | * Support for approximately 70 different PostgreSQL types
71 | * Automatic statement preparation and caching
72 | * Batch queries
73 | * Single round-trip query mode
74 | * Full TLS connection control
75 | * Binary format support for custom types (allows for much quicker encoding/decoding)
76 | * COPY protocol support for faster bulk data loads
77 | * Extendable logging support including built-in support for `log15adapter`, [`logrus`](https://github.com/sirupsen/logrus), [`zap`](https://github.com/uber-go/zap), and [`zerolog`](https://github.com/rs/zerolog)
78 | * Connection pool with after-connect hook for arbitrary connection setup
79 | * Listen / notify
80 | * Conversion of PostgreSQL arrays to Go slice mappings for integers, floats, and strings
81 | * Hstore support
82 | * JSON and JSONB support
83 | * Maps `inet` and `cidr` PostgreSQL types to `net.IPNet` and `net.IP`
84 | * Large object support
85 | * NULL mapping to Null* struct or pointer to pointer
86 | * Supports `database/sql.Scanner` and `database/sql/driver.Valuer` interfaces for custom types
87 | * Notice response handling
88 | * Simulated nested transactions with savepoints
89 | 
90 | ## Performance
91 | 
92 | There are three areas in particular where pgx can provide a significant performance advantage over the standard
93 | `database/sql` interface and other drivers:
94 | 
95 | 1. PostgreSQL-specific types - Types such as arrays can be parsed much more quickly because pgx uses the binary format.
96 | 2. Automatic statement preparation and caching - pgx will prepare and cache statements by default. This can provide a
97 | significant free improvement to code that does not explicitly use prepared statements.
Under certain workloads, it can 98 | perform nearly 3x the number of queries per second. 99 | 3. Batched queries - Multiple queries can be batched together to minimize network round trips. 100 | 101 | ## Testing 102 | 103 | pgx tests naturally require a PostgreSQL database. It will connect to the database specified in the `PGX_TEST_DATABASE` environment 104 | variable. The `PGX_TEST_DATABASE` environment variable can either be a URL or DSN. In addition, the standard `PG*` environment 105 | variables will be respected. Consider using [direnv](https://github.com/direnv/direnv) to simplify environment variable 106 | handling. 107 | 108 | ### Example Test Environment 109 | 110 | Connect to your PostgreSQL server and run: 111 | 112 | ``` 113 | create database pgx_test; 114 | ``` 115 | 116 | Connect to the newly-created database and run: 117 | 118 | ``` 119 | create domain uint64 as numeric(20,0); 120 | ``` 121 | 122 | Now, you can run the tests: 123 | 124 | ``` 125 | PGX_TEST_DATABASE="host=/var/run/postgresql database=pgx_test" go test ./... 126 | ``` 127 | 128 | In addition, there are tests specific for PgBouncer that will be executed if `PGX_TEST_PGBOUNCER_CONN_STRING` is set. 129 | 130 | ## Supported Go and PostgreSQL Versions 131 | 132 | pgx supports the same versions of Go and PostgreSQL that are supported by their respective teams. For [Go](https://golang.org/doc/devel/release.html#policy) that is the two most recent major releases and for [PostgreSQL](https://www.postgresql.org/support/versioning/) the major releases in the last 5 years. This means pgx supports Go 1.16 and higher and PostgreSQL 10 and higher. pgx also is tested against the latest version of [CockroachDB](https://www.cockroachlabs.com/product/). 133 | 134 | ## Version Policy 135 | 136 | pgx follows semantic versioning for the documented public API on stable releases. `v4` is the latest stable major version. 137 | 138 | ## PGX Family Libraries 139 | 140 | pgx is the head of a family of PostgreSQL libraries. Many of these can be used independently. Many can also be accessed 141 | from pgx for lower-level control. 142 | 143 | ### [github.com/jackc/pgconn](https://github.com/jackc/pgconn) 144 | 145 | `pgconn` is a lower-level PostgreSQL database driver that operates at nearly the same level as the C library `libpq`. 146 | 147 | ### [github.com/jackc/pgx/v4/pgxpool](https://github.com/jackc/pgx/tree/master/pgxpool) 148 | 149 | `pgxpool` is a connection pool for pgx. pgx is entirely decoupled from its default pool implementation. This means that pgx can be used with a different pool or without any pool at all. 150 | 151 | ### [github.com/jackc/pgx/v4/stdlib](https://github.com/jackc/pgx/tree/master/stdlib) 152 | 153 | This is a `database/sql` compatibility layer for pgx. pgx can be used as a normal `database/sql` driver, but at any time, the native interface can be acquired for more performance or PostgreSQL specific functionality. 154 | 155 | ### [github.com/jackc/pgtype](https://github.com/jackc/pgtype) 156 | 157 | Over 70 PostgreSQL types are supported including `uuid`, `hstore`, `json`, `bytea`, `numeric`, `interval`, `inet`, and arrays. These types support `database/sql` interfaces and are usable outside of pgx. They are fully tested in pgx and pq. They also support a higher performance interface when used with the pgx driver. 158 | 159 | ### [github.com/jackc/pgproto3](https://github.com/jackc/pgproto3) 160 | 161 | pgproto3 provides standalone encoding and decoding of the PostgreSQL v3 wire protocol. 
This is useful for implementing very low level PostgreSQL tooling. 162 | 163 | ### [github.com/jackc/pglogrepl](https://github.com/jackc/pglogrepl) 164 | 165 | pglogrepl provides functionality to act as a client for PostgreSQL logical replication. 166 | 167 | ### [github.com/jackc/pgmock](https://github.com/jackc/pgmock) 168 | 169 | pgmock offers the ability to create a server that mocks the PostgreSQL wire protocol. This is used internally to test pgx by purposely inducing unusual errors. pgproto3 and pgmock together provide most of the foundational tooling required to implement a PostgreSQL proxy or MitM (such as for a custom connection pooler). 170 | 171 | ### [github.com/jackc/tern](https://github.com/jackc/tern) 172 | 173 | tern is a stand-alone SQL migration system. 174 | 175 | ### [github.com/jackc/pgerrcode](https://github.com/jackc/pgerrcode) 176 | 177 | pgerrcode contains constants for the PostgreSQL error codes. 178 | 179 | ## 3rd Party Libraries with PGX Support 180 | 181 | ### [github.com/georgysavva/scany](https://github.com/georgysavva/scany) 182 | 183 | Library for scanning data from a database into Go structs and more. 184 | 185 | ### [https://github.com/otan/gopgkrb5](https://github.com/otan/gopgkrb5) 186 | 187 | Adds GSSAPI / Kerberos authentication support. 188 | 189 | ### [https://github.com/vgarvardt/pgx-google-uuid](https://github.com/vgarvardt/pgx-google-uuid) 190 | 191 | Adds support for [`github.com/google/uuid`](https://github.com/google/uuid). 192 | -------------------------------------------------------------------------------- /CHANGELOG.md: -------------------------------------------------------------------------------- 1 | # 4.16.1 (May 7, 2022) 2 | 3 | * Upgrade pgconn to v1.12.1 4 | * Fix explicitly prepared statements with describe statement cache mode 5 | 6 | # 4.16.0 (April 21, 2022) 7 | 8 | * Upgrade pgconn to v1.12.0 9 | * Upgrade pgproto3 to v2.3.0 10 | * Upgrade pgtype to v1.11.0 11 | * Fix: Do not panic when context cancelled while getting statement from cache. 12 | * Fix: Less memory pinning from old Rows. 13 | * Fix: Support '\r' line ending when sanitizing SQL comment. 
14 | * Add pluggable GSSAPI support (Oliver Tan) 15 | 16 | # 4.15.0 (February 7, 2022) 17 | 18 | * Upgrade to pgconn v1.11.0 19 | * Upgrade to pgtype v1.10.0 20 | * Upgrade puddle to v1.2.1 21 | * Make BatchResults.Close safe to be called multiple times 22 | 23 | # 4.14.1 (November 28, 2021) 24 | 25 | * Upgrade pgtype to v1.9.1 (fixes unintentional change to timestamp binary decoding) 26 | * Start pgxpool background health check after initial connections 27 | 28 | # 4.14.0 (November 20, 2021) 29 | 30 | * Upgrade pgconn to v1.10.1 31 | * Upgrade pgproto3 to v2.2.0 32 | * Upgrade pgtype to v1.9.0 33 | * Upgrade puddle to v1.2.0 34 | * Add QueryFunc to BatchResults 35 | * Add context options to zerologadapter (Thomas Frössman) 36 | * Add zerologadapter.NewContextLogger (urso) 37 | * Eager initialize minpoolsize on connect (Daniel) 38 | * Unpin memory used by large queries immediately after use 39 | 40 | # 4.13.0 (July 24, 2021) 41 | 42 | * Trimmed pseudo-dependencies in Go modules from other packages tests 43 | * Upgrade pgconn -- context cancellation no longer will return a net.Error 44 | * Support time durations for simple protocol (Michael Darr) 45 | 46 | # 4.12.0 (July 10, 2021) 47 | 48 | * ResetSession hook is called before a connection is reused from pool for another query (Dmytro Haranzha) 49 | * stdlib: Add RandomizeHostOrderFunc (dkinder) 50 | * stdlib: add OptionBeforeConnect (dkinder) 51 | * stdlib: Do not reuse ConnConfig strings (Andrew Kimball) 52 | * stdlib: implement Conn.ResetSession (Jonathan Amsterdam) 53 | * Upgrade pgconn to v1.9.0 54 | * Upgrade pgtype to v1.8.0 55 | 56 | # 4.11.0 (March 25, 2021) 57 | 58 | * Add BeforeConnect callback to pgxpool.Config (Robert Froehlich) 59 | * Add Ping method to pgxpool.Conn (davidsbond) 60 | * Added a kitlog level log adapter (Fabrice Aneche) 61 | * Make ScanArgError public to allow identification of offending column (Pau Sanchez) 62 | * Add *pgxpool.AcquireFunc 63 | * Add BeginFunc and BeginTxFunc 64 | * Add prefer_simple_protocol to connection string 65 | * Add logging on CopyFrom (Patrick Hemmer) 66 | * Add comment support when sanitizing SQL queries (Rusakow Andrew) 67 | * Do not panic on double close of pgxpool.Pool (Matt Schultz) 68 | * Avoid panic on SendBatch on closed Tx (Matt Schultz) 69 | * Update pgconn to v1.8.1 70 | * Update pgtype to v1.7.0 71 | 72 | # 4.10.1 (December 19, 2020) 73 | 74 | * Fix panic on Query error with nil stmtcache. 75 | 76 | # 4.10.0 (December 3, 2020) 77 | 78 | * Add CopyFromSlice to simplify CopyFrom usage (Egon Elbre) 79 | * Remove broken prepared statements from stmtcache (Ethan Pailes) 80 | * stdlib: consider any Ping error as fatal 81 | * Update puddle to v1.1.3 - this fixes an issue where concurrent Acquires can hang when a connection cannot be established 82 | * Update pgtype to v1.6.2 83 | 84 | # 4.9.2 (November 3, 2020) 85 | 86 | The underlying library updates fix an issue where appending to a scanned slice could corrupt other data. 87 | 88 | * Update pgconn to v1.7.2 89 | * Update pgproto3 to v2.0.6 90 | 91 | # 4.9.1 (October 31, 2020) 92 | 93 | * Update pgconn to v1.7.1 94 | * Update pgtype to v1.6.1 95 | * Fix SendBatch of all prepared statements with statement cache disabled 96 | 97 | # 4.9.0 (September 26, 2020) 98 | 99 | * pgxpool now waits for connection cleanup to finish before making room in pool for another connection. This prevents temporarily exceeding max pool size. 
100 | * Fix when scanning a column to nil to skip it on the first row but scanning it to a real value on a subsequent row. 101 | * Fix prefer simple protocol with prepared statements. (Jinzhu) 102 | * Fix FieldDescriptions not being available on Rows before calling Next the first time. 103 | * Various minor fixes in updated versions of pgconn, pgtype, and puddle. 104 | 105 | # 4.8.1 (July 29, 2020) 106 | 107 | * Update pgconn to v1.6.4 108 | * Fix deadlock on error after CommandComplete but before ReadyForQuery 109 | * Fix panic on parsing DSN with trailing '=' 110 | 111 | # 4.8.0 (July 22, 2020) 112 | 113 | * All argument types supported by native pgx should now also work through database/sql 114 | * Update pgconn to v1.6.3 115 | * Update pgtype to v1.4.2 116 | 117 | # 4.7.2 (July 14, 2020) 118 | 119 | * Improve performance of Columns() (zikaeroh) 120 | * Fix fatal Commit() failure not being considered fatal 121 | * Update pgconn to v1.6.2 122 | * Update pgtype to v1.4.1 123 | 124 | # 4.7.1 (June 29, 2020) 125 | 126 | * Fix stdlib decoding error with certain order and combination of fields 127 | 128 | # 4.7.0 (June 27, 2020) 129 | 130 | * Update pgtype to v1.4.0 131 | * Update pgconn to v1.6.1 132 | * Update puddle to v1.1.1 133 | * Fix context propagation with Tx commit and Rollback (georgysavva) 134 | * Add lazy connect option to pgxpool (georgysavva) 135 | * Fix connection leak if pgxpool.BeginTx() fail (Jean-Baptiste Bronisz) 136 | * Add native Go slice support for strings and numbers to simple protocol 137 | * stdlib add default timeouts for Conn.Close() and Stmt.Close() (georgysavva) 138 | * Assorted performance improvements especially with large result sets 139 | * Fix close pool on not lazy connect failure (Yegor Myskin) 140 | * Add Config copy (georgysavva) 141 | * Support SendBatch with Simple Protocol (Jordan Lewis) 142 | * Better error logging on rows close (Igor V. 
Kozinov) 143 | * Expose stdlib.Conn.Conn() to enable database/sql.Conn.Raw() 144 | * Improve unknown type support for database/sql 145 | * Fix transaction commit failure closing connection 146 | 147 | # 4.6.0 (March 30, 2020) 148 | 149 | * stdlib: Bail early if preloading rows.Next() results in rows.Err() (Bas van Beek) 150 | * Sanitize time to microsecond accuracy (Andrew Nicoll) 151 | * Update pgtype to v1.3.0 152 | * Update pgconn to v1.5.0 153 | * Update golang.org/x/crypto for security fix 154 | * Implement "verify-ca" SSL mode 155 | 156 | # 4.5.0 (March 7, 2020) 157 | 158 | * Update to pgconn v1.4.0 159 | * Fixes QueryRow with empty SQL 160 | * Adds PostgreSQL service file support 161 | * Add Len() to *pgx.Batch (WGH) 162 | * Better logging for individual batch items (Ben Bader) 163 | 164 | # 4.4.1 (February 14, 2020) 165 | 166 | * Update pgconn to v1.3.2 - better default read buffer size 167 | * Fix race in CopyFrom 168 | 169 | # 4.4.0 (February 5, 2020) 170 | 171 | * Update puddle to v1.1.0 - fixes possible deadlock when acquire is cancelled 172 | * Update pgconn to v1.3.1 - fixes CopyFrom deadlock when multiple NoticeResponse received during copy 173 | * Update pgtype to v1.2.0 174 | * Add MaxConnIdleTime to pgxpool (Patrick Ellul) 175 | * Add MinConns to pgxpool (Patrick Ellul) 176 | * Fix: stdlib.ReleaseConn closes connections left in invalid state 177 | 178 | # 4.3.0 (January 23, 2020) 179 | 180 | * Fix Rows.Values panic when unable to decode 181 | * Add Rows.Values support for unknown types 182 | * Add DriverContext support for stdlib (Alex Gaynor) 183 | * Update pgproto3 to v2.0.1 to never return an io.EOF as it would be misinterpreted by database/sql. Instead return io.UnexpectedEOF. 184 | 185 | # 4.2.1 (January 13, 2020) 186 | 187 | * Update pgconn to v1.2.1 (fixes context cancellation data race introduced in v1.2.0)) 188 | 189 | # 4.2.0 (January 11, 2020) 190 | 191 | * Update pgconn to v1.2.0. 192 | * Update pgtype to v1.1.0. 193 | * Return error instead of panic when wrong number of arguments passed to Exec. (malstoun) 194 | * Fix large objects functionality when PreferSimpleProtocol = true. 195 | * Restore GetDefaultDriver which existed in v3. (Johan Brandhorst) 196 | * Add RegisterConnConfig to stdlib which replaces the removed RegisterDriverConfig from v3. 197 | 198 | # 4.1.2 (October 22, 2019) 199 | 200 | * Fix dbSavepoint.Begin recursive self call 201 | * Upgrade pgtype to v1.0.2 - fix scan pointer to pointer 202 | 203 | # 4.1.1 (October 21, 2019) 204 | 205 | * Fix pgxpool Rows.CommandTag() infinite loop / typo 206 | 207 | # 4.1.0 (October 12, 2019) 208 | 209 | ## Potentially Breaking Changes 210 | 211 | Technically, two changes are breaking changes, but in practice these are extremely unlikely to break existing code. 212 | 213 | * Conn.Begin and Conn.BeginTx return a Tx interface instead of the internal dbTx struct. This is necessary for the Conn.Begin method to signature as other methods that begin a transaction. 214 | * Add Conn() to Tx interface. This is necessary to allow code using a Tx to access the *Conn (and pgconn.PgConn) on which the Tx is executing. 215 | 216 | ## Fixes 217 | 218 | * Releasing a busy connection closes the connection instead of returning an unusable connection to the pool 219 | * Do not mutate config.Config.OnNotification in connect 220 | 221 | # 4.0.1 (September 19, 2019) 222 | 223 | * Fix statement cache cleanup. 224 | * Corrected daterange OID. 225 | * Fix Tx when committing or rolling back multiple times in certain cases. 
226 | * Improve documentation.
227 | 
228 | # 4.0.0 (September 14, 2019)
229 | 
230 | v4 is a major release with many significant changes, some of which are breaking changes. The most significant are
231 | included below.
232 | 
233 | * Simplified establishing a connection with a connection string.
234 | * All potentially blocking operations now require a context.Context. The non-context aware functions have been removed.
235 | * OIDs are hard-coded for known types. This saves the query on connection.
236 | * Context cancellations while network activity is in progress are now always fatal. Previously, it was sometimes recoverable. This led to increased complexity in pgx itself and in application code.
237 | * Go modules are required.
238 | * Errors are now implemented in the Go 1.13 style.
239 | * `Rows` and `Tx` are now interfaces.
240 | * The connection pool has been decoupled from pgx and is now a separate, included package (github.com/jackc/pgx/v4/pgxpool).
241 | * pgtype has been spun off to a separate package (github.com/jackc/pgtype).
242 | * pgproto3 has been spun off to a separate package (github.com/jackc/pgproto3/v2).
243 | * Logical replication support has been spun off to a separate package (github.com/jackc/pglogrepl).
244 | * Lower-level PostgreSQL functionality is now implemented in a separate package (github.com/jackc/pgconn).
245 | * Tests are now configured with environment variables.
246 | * Conn has an automatic statement cache by default.
247 | * Batch interface has been simplified.
248 | * QueryArgs has been removed.
249 | 
-------------------------------------------------------------------------------- /rows.go: --------------------------------------------------------------------------------
1 | package pgx
2 | 
3 | import (
4 |     "context"
5 |     "errors"
6 |     "fmt"
7 |     "time"
8 | 
9 |     "github.com/jackc/pgconn"
10 |     "github.com/jackc/pgproto3/v2"
11 |     "github.com/jackc/pgtype"
12 | )
13 | 
14 | // Rows is the result set returned from *Conn.Query. Rows must be closed before
15 | // the *Conn can be used again. Rows are closed by explicitly calling Close(),
16 | // calling Next() until it returns false, or when a fatal error occurs.
17 | //
18 | // Once a Rows is closed the only methods that may be called are Close(), Err(), and CommandTag().
19 | //
20 | // Rows is an interface instead of a struct to allow tests to mock Query. However,
21 | // adding a method to an interface is technically a breaking change. Because of this
22 | // the Rows interface is partially excluded from semantic version requirements.
23 | // Methods will not be removed or changed, but new methods may be added.
24 | type Rows interface {
25 |     // Close closes the rows, making the connection ready for use again. It is safe
26 |     // to call Close after rows is already closed.
27 |     Close()
28 | 
29 |     // Err returns any error that occurred while reading.
30 |     Err() error
31 | 
32 |     // CommandTag returns the command tag from this query. It is only available after Rows is closed.
33 |     CommandTag() pgconn.CommandTag
34 | 
35 |     FieldDescriptions() []pgproto3.FieldDescription
36 | 
37 |     // Next prepares the next row for reading. It returns true if there is another
38 |     // row and false if no more rows are available. It automatically closes rows
39 |     // when all rows are read.
40 |     Next() bool
41 | 
42 |     // Scan reads the values from the current row into dest values positionally.
43 |     // dest can include pointers to core types, values implementing the Scanner
44 |     // interface, and nil. nil will skip the value entirely.
It is an error to 45 | // call Scan without first calling Next() and checking that it returned true. 46 | Scan(dest ...interface{}) error 47 | 48 | // Values returns the decoded row values. As with Scan(), it is an error to 49 | // call Values without first calling Next() and checking that it returned 50 | // true. 51 | Values() ([]interface{}, error) 52 | 53 | // RawValues returns the unparsed bytes of the row values. The returned [][]byte is only valid until the next Next 54 | // call or the Rows is closed. However, the underlying byte data is safe to retain a reference to and mutate. 55 | RawValues() [][]byte 56 | } 57 | 58 | // Row is a convenience wrapper over Rows that is returned by QueryRow. 59 | // 60 | // Row is an interface instead of a struct to allow tests to mock QueryRow. However, 61 | // adding a method to an interface is technically a breaking change. Because of this 62 | // the Row interface is partially excluded from semantic version requirements. 63 | // Methods will not be removed or changed, but new methods may be added. 64 | type Row interface { 65 | // Scan works the same as Rows. with the following exceptions. If no 66 | // rows were found it returns ErrNoRows. If multiple rows are returned it 67 | // ignores all but the first. 68 | Scan(dest ...interface{}) error 69 | } 70 | 71 | // connRow implements the Row interface for Conn.QueryRow. 72 | type connRow connRows 73 | 74 | func (r *connRow) Scan(dest ...interface{}) (err error) { 75 | rows := (*connRows)(r) 76 | 77 | if rows.Err() != nil { 78 | return rows.Err() 79 | } 80 | 81 | if !rows.Next() { 82 | if rows.Err() == nil { 83 | return ErrNoRows 84 | } 85 | return rows.Err() 86 | } 87 | 88 | rows.Scan(dest...) 89 | rows.Close() 90 | return rows.Err() 91 | } 92 | 93 | type rowLog interface { 94 | shouldLog(lvl LogLevel) bool 95 | log(ctx context.Context, lvl LogLevel, msg string, data map[string]interface{}) 96 | } 97 | 98 | // connRows implements the Rows interface for Conn.Query. 
99 | type connRows struct { 100 | ctx context.Context 101 | logger rowLog 102 | connInfo *pgtype.ConnInfo 103 | values [][]byte 104 | rowCount int 105 | err error 106 | commandTag pgconn.CommandTag 107 | startTime time.Time 108 | sql string 109 | args []interface{} 110 | closed bool 111 | conn *Conn 112 | 113 | resultReader *pgconn.ResultReader 114 | multiResultReader *pgconn.MultiResultReader 115 | 116 | scanPlans []pgtype.ScanPlan 117 | } 118 | 119 | func (rows *connRows) FieldDescriptions() []pgproto3.FieldDescription { 120 | return rows.resultReader.FieldDescriptions() 121 | } 122 | 123 | func (rows *connRows) Close() { 124 | if rows.closed { 125 | return 126 | } 127 | 128 | rows.closed = true 129 | 130 | if rows.resultReader != nil { 131 | var closeErr error 132 | rows.commandTag, closeErr = rows.resultReader.Close() 133 | if rows.err == nil { 134 | rows.err = closeErr 135 | } 136 | } 137 | 138 | if rows.multiResultReader != nil { 139 | closeErr := rows.multiResultReader.Close() 140 | if rows.err == nil { 141 | rows.err = closeErr 142 | } 143 | } 144 | 145 | if rows.logger != nil { 146 | endTime := time.Now() 147 | 148 | if rows.err == nil { 149 | if rows.logger.shouldLog(LogLevelInfo) { 150 | rows.logger.log(rows.ctx, LogLevelInfo, "Query", map[string]interface{}{"sql": rows.sql, "args": logQueryArgs(rows.args), "time": endTime.Sub(rows.startTime), "rowCount": rows.rowCount}) 151 | } 152 | } else { 153 | if rows.logger.shouldLog(LogLevelError) { 154 | rows.logger.log(rows.ctx, LogLevelError, "Query", map[string]interface{}{"err": rows.err, "sql": rows.sql, "time": endTime.Sub(rows.startTime), "args": logQueryArgs(rows.args)}) 155 | } 156 | if rows.err != nil && rows.conn.stmtcache != nil { 157 | rows.conn.stmtcache.StatementErrored(rows.sql, rows.err) 158 | } 159 | } 160 | } 161 | } 162 | 163 | func (rows *connRows) CommandTag() pgconn.CommandTag { 164 | return rows.commandTag 165 | } 166 | 167 | func (rows *connRows) Err() error { 168 | return rows.err 169 | } 170 | 171 | // fatal signals an error occurred after the query was sent to the server. It 172 | // closes the rows automatically. 
173 | func (rows *connRows) fatal(err error) { 174 | if rows.err != nil { 175 | return 176 | } 177 | 178 | rows.err = err 179 | rows.Close() 180 | } 181 | 182 | func (rows *connRows) Next() bool { 183 | if rows.closed { 184 | return false 185 | } 186 | 187 | if rows.resultReader.NextRow() { 188 | rows.rowCount++ 189 | rows.values = rows.resultReader.Values() 190 | return true 191 | } else { 192 | rows.Close() 193 | return false 194 | } 195 | } 196 | 197 | func (rows *connRows) Scan(dest ...interface{}) error { 198 | ci := rows.connInfo 199 | fieldDescriptions := rows.FieldDescriptions() 200 | values := rows.values 201 | 202 | if len(fieldDescriptions) != len(values) { 203 | err := fmt.Errorf("number of field descriptions must equal number of values, got %d and %d", len(fieldDescriptions), len(values)) 204 | rows.fatal(err) 205 | return err 206 | } 207 | if len(fieldDescriptions) != len(dest) { 208 | err := fmt.Errorf("number of field descriptions must equal number of destinations, got %d and %d", len(fieldDescriptions), len(dest)) 209 | rows.fatal(err) 210 | return err 211 | } 212 | 213 | if rows.scanPlans == nil { 214 | rows.scanPlans = make([]pgtype.ScanPlan, len(values)) 215 | for i := range dest { 216 | rows.scanPlans[i] = ci.PlanScan(fieldDescriptions[i].DataTypeOID, fieldDescriptions[i].Format, dest[i]) 217 | } 218 | } 219 | 220 | for i, dst := range dest { 221 | if dst == nil { 222 | continue 223 | } 224 | 225 | err := rows.scanPlans[i].Scan(ci, fieldDescriptions[i].DataTypeOID, fieldDescriptions[i].Format, values[i], dst) 226 | if err != nil { 227 | err = ScanArgError{ColumnIndex: i, Err: err} 228 | rows.fatal(err) 229 | return err 230 | } 231 | } 232 | 233 | return nil 234 | } 235 | 236 | func (rows *connRows) Values() ([]interface{}, error) { 237 | if rows.closed { 238 | return nil, errors.New("rows is closed") 239 | } 240 | 241 | values := make([]interface{}, 0, len(rows.FieldDescriptions())) 242 | 243 | for i := range rows.FieldDescriptions() { 244 | buf := rows.values[i] 245 | fd := &rows.FieldDescriptions()[i] 246 | 247 | if buf == nil { 248 | values = append(values, nil) 249 | continue 250 | } 251 | 252 | if dt, ok := rows.connInfo.DataTypeForOID(fd.DataTypeOID); ok { 253 | value := dt.Value 254 | 255 | switch fd.Format { 256 | case TextFormatCode: 257 | decoder, ok := value.(pgtype.TextDecoder) 258 | if !ok { 259 | decoder = &pgtype.GenericText{} 260 | } 261 | err := decoder.DecodeText(rows.connInfo, buf) 262 | if err != nil { 263 | rows.fatal(err) 264 | } 265 | values = append(values, decoder.(pgtype.Value).Get()) 266 | case BinaryFormatCode: 267 | decoder, ok := value.(pgtype.BinaryDecoder) 268 | if !ok { 269 | decoder = &pgtype.GenericBinary{} 270 | } 271 | err := decoder.DecodeBinary(rows.connInfo, buf) 272 | if err != nil { 273 | rows.fatal(err) 274 | } 275 | values = append(values, value.Get()) 276 | default: 277 | rows.fatal(errors.New("Unknown format code")) 278 | } 279 | } else { 280 | switch fd.Format { 281 | case TextFormatCode: 282 | decoder := &pgtype.GenericText{} 283 | err := decoder.DecodeText(rows.connInfo, buf) 284 | if err != nil { 285 | rows.fatal(err) 286 | } 287 | values = append(values, decoder.Get()) 288 | case BinaryFormatCode: 289 | decoder := &pgtype.GenericBinary{} 290 | err := decoder.DecodeBinary(rows.connInfo, buf) 291 | if err != nil { 292 | rows.fatal(err) 293 | } 294 | values = append(values, decoder.Get()) 295 | default: 296 | rows.fatal(errors.New("Unknown format code")) 297 | } 298 | } 299 | 300 | if rows.Err() != nil { 301 | return nil, 
rows.Err() 302 | } 303 | } 304 | 305 | return values, rows.Err() 306 | } 307 | 308 | func (rows *connRows) RawValues() [][]byte { 309 | return rows.values 310 | } 311 | 312 | type ScanArgError struct { 313 | ColumnIndex int 314 | Err error 315 | } 316 | 317 | func (e ScanArgError) Error() string { 318 | return fmt.Sprintf("can't scan into dest[%d]: %v", e.ColumnIndex, e.Err) 319 | } 320 | 321 | func (e ScanArgError) Unwrap() error { 322 | return e.Err 323 | } 324 | 325 | // ScanRow decodes raw row data into dest. It can be used to scan rows read from the lower level pgconn interface. 326 | // 327 | // connInfo - OID to Go type mapping. 328 | // fieldDescriptions - OID and format of values 329 | // values - the raw data as returned from the PostgreSQL server 330 | // dest - the destination that values will be decoded into 331 | func ScanRow(connInfo *pgtype.ConnInfo, fieldDescriptions []pgproto3.FieldDescription, values [][]byte, dest ...interface{}) error { 332 | if len(fieldDescriptions) != len(values) { 333 | return fmt.Errorf("number of field descriptions must equal number of values, got %d and %d", len(fieldDescriptions), len(values)) 334 | } 335 | if len(fieldDescriptions) != len(dest) { 336 | return fmt.Errorf("number of field descriptions must equal number of destinations, got %d and %d", len(fieldDescriptions), len(dest)) 337 | } 338 | 339 | for i, d := range dest { 340 | if d == nil { 341 | continue 342 | } 343 | 344 | err := connInfo.Scan(fieldDescriptions[i].DataTypeOID, fieldDescriptions[i].Format, values[i], d) 345 | if err != nil { 346 | return ScanArgError{ColumnIndex: i, Err: err} 347 | } 348 | } 349 | 350 | return nil 351 | } 352 | -------------------------------------------------------------------------------- /doc.go: -------------------------------------------------------------------------------- 1 | // Package pgx is a PostgreSQL database driver. 2 | /* 3 | pgx provides lower level access to PostgreSQL than the standard database/sql. It remains as similar to the database/sql 4 | interface as possible while providing better speed and access to PostgreSQL specific features. Import 5 | github.com/jackc/pgx/v4/stdlib to use pgx as a database/sql compatible driver. 6 | 7 | Establishing a Connection 8 | 9 | The primary way of establishing a connection is with `pgx.Connect`. 10 | 11 | conn, err := pgx.Connect(context.Background(), os.Getenv("DATABASE_URL")) 12 | 13 | The database connection string can be in URL or DSN format. Both PostgreSQL settings and pgx settings can be specified 14 | here. In addition, a config struct can be created by `ParseConfig` and modified before establishing the connection with 15 | `ConnectConfig`. 16 | 17 | config, err := pgx.ParseConfig(os.Getenv("DATABASE_URL")) 18 | if err != nil { 19 | // ... 20 | } 21 | config.Logger = log15adapter.NewLogger(log.New("module", "pgx")) 22 | 23 | conn, err := pgx.ConnectConfig(context.Background(), config) 24 | 25 | Connection Pool 26 | 27 | `*pgx.Conn` represents a single connection to the database and is not concurrency safe. Use sub-package pgxpool for a 28 | concurrency safe connection pool. 29 | 30 | Query Interface 31 | 32 | pgx implements Query and Scan in the familiar database/sql style. 33 | 34 | var sum int32 35 | 36 | // Send the query to the server. The returned rows MUST be closed 37 | // before conn can be used again. 
38 | rows, err := conn.Query(context.Background(), "select generate_series(1,$1)", 10) 39 | if err != nil { 40 | return err 41 | } 42 | 43 | // rows.Close is called by rows.Next when all rows are read 44 | // or an error occurs in Next or Scan. So it may optionally be 45 | // omitted if nothing in the rows.Next loop can panic. It is 46 | // safe to close rows multiple times. 47 | defer rows.Close() 48 | 49 | // Iterate through the result set 50 | for rows.Next() { 51 | var n int32 52 | err = rows.Scan(&n) 53 | if err != nil { 54 | return err 55 | } 56 | sum += n 57 | } 58 | 59 | // Any errors encountered by rows.Next or rows.Scan will be returned here 60 | if rows.Err() != nil { 61 | return rows.Err() 62 | } 63 | 64 | // No errors found - do something with sum 65 | 66 | pgx also implements QueryRow in the same style as database/sql. 67 | 68 | var name string 69 | var weight int64 70 | err := conn.QueryRow(context.Background(), "select name, weight from widgets where id=$1", 42).Scan(&name, &weight) 71 | if err != nil { 72 | return err 73 | } 74 | 75 | Use Exec to execute a query that does not return a result set. 76 | 77 | commandTag, err := conn.Exec(context.Background(), "delete from widgets where id=$1", 42) 78 | if err != nil { 79 | return err 80 | } 81 | if commandTag.RowsAffected() != 1 { 82 | return errors.New("No row found to delete") 83 | } 84 | 85 | QueryFunc can be used to execute a callback function for every row. This is often easier to use than Query. 86 | 87 | var sum, n int32 88 | _, err = conn.QueryFunc( 89 | context.Background(), 90 | "select generate_series(1,$1)", 91 | []interface{}{10}, 92 | []interface{}{&n}, 93 | func(pgx.QueryFuncRow) error { 94 | sum += n 95 | return nil 96 | }, 97 | ) 98 | if err != nil { 99 | return err 100 | } 101 | 102 | Base Type Mapping 103 | 104 | pgx maps between all common base types directly between Go and PostgreSQL. In particular: 105 | 106 | Go PostgreSQL 107 | ----------------------- 108 | string varchar 109 | text 110 | 111 | // Integers are automatically converted to any other integer type if 112 | // it can be done without overflow or underflow. 113 | int8 114 | int16 smallint 115 | int32 int 116 | int64 bigint 117 | int 118 | uint8 119 | uint16 120 | uint32 121 | uint64 122 | uint 123 | 124 | // Floats are strict and do not automatically convert like integers. 125 | float32 float4 126 | float64 float8 127 | 128 | time.Time date 129 | timestamp 130 | timestamptz 131 | 132 | []byte bytea 133 | 134 | 135 | Null Mapping 136 | 137 | pgx can map nulls in two ways. The first is that package pgtype provides types that have a data field and a status field. 138 | They work in a similar fashion to database/sql. The second is to use a pointer to a pointer. 139 | 140 | var foo pgtype.Varchar 141 | var bar *string 142 | err := conn.QueryRow(context.Background(), "select foo, bar from widgets where id=$1", 42).Scan(&foo, &bar) 143 | if err != nil { 144 | return err 145 | } 146 | 147 | Array Mapping 148 | 149 | pgx maps between int16, int32, int64, float32, float64, and string Go slices and the equivalent PostgreSQL array type. 150 | Go slices of native types do not support nulls, so if a PostgreSQL array that contains a null is read into a native Go 151 | slice an error will occur. The pgtype package includes many more array types for PostgreSQL types that do not directly 152 | map to native Go types. 153 | 154 | JSON and JSONB Mapping 155 | 156 | pgx includes built-in support to marshal and unmarshal between Go types and the PostgreSQL JSON and JSONB types.
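As a quick, hedged illustration of the Array Mapping and JSON and JSONB Mapping rules above (the queries, values, and variable names here are made up for this sketch and are not taken from the pgx sources):

	// A PostgreSQL array can be scanned directly into a Go slice of a
	// compatible native type, and a Go slice can be passed as a query argument.
	var tags []string
	err := conn.QueryRow(context.Background(), "select array['red','green','blue']::text[]").Scan(&tags)
	if err != nil {
		return err
	}

	// A Go value is marshaled to json/jsonb on the way in and unmarshaled
	// into the scan destination on the way out.
	var doc map[string]interface{}
	err = conn.QueryRow(context.Background(), "select $1::jsonb", map[string]interface{}{"color": "red"}).Scan(&doc)
	if err != nil {
		return err
	}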
157 | 158 | Inet and CIDR Mapping 159 | 160 | pgx maps between net.IPNet and the PostgreSQL inet and cidr types. In addition, as a convenience pgx will encode 161 | from a net.IP; it will assume a /32 netmask for IPv4 and a /128 for IPv6. 162 | 163 | Custom Type Support 164 | 165 | pgx includes support for the common data types like integers, floats, strings, dates, and times that have direct 166 | mappings between Go and SQL. In addition, pgx uses the github.com/jackc/pgtype library to support more types. See 167 | the documentation for that library for instructions on how to implement custom types. 168 | 169 | See example_custom_type_test.go for an example of a custom type for the PostgreSQL point type. 170 | 171 | pgx also includes support for custom types implementing the database/sql.Scanner and database/sql/driver.Valuer 172 | interfaces. 173 | 174 | If pgx cannot natively encode a type and that type is a renamed type (e.g. type MyTime time.Time) pgx will attempt 175 | to encode the underlying type. While this is usually the desired behavior, it can produce surprising results if the 176 | underlying type implements database/sql interfaces and the renamed type implements pgx interfaces, or vice versa. It 177 | is recommended that this situation be avoided by implementing pgx interfaces on the renamed type. 178 | 179 | Composite types and row values 180 | 181 | Row values and composite types are represented as pgtype.Record (https://pkg.go.dev/github.com/jackc/pgtype?tab=doc#Record). 182 | It is possible to get values of your custom type by implementing the BinaryDecoder interface. Decoding into 183 | pgtype.Record first can simplify the process by avoiding dealing with the raw protocol directly. 184 | 185 | For example: 186 | 187 | type MyType struct { 188 | a int // NULL will cause a decoding error 189 | b *string // there can be NULL in this position in SQL 190 | } 191 | 192 | func (t *MyType) DecodeBinary(ci *pgtype.ConnInfo, src []byte) error { 193 | r := pgtype.Record{ 194 | Fields: []pgtype.Value{&pgtype.Int4{}, &pgtype.Text{}}, 195 | } 196 | 197 | if err := r.DecodeBinary(ci, src); err != nil { 198 | return err 199 | } 200 | 201 | if r.Status != pgtype.Present { 202 | return errors.New("BUG: decoding should not be called on NULL value") 203 | } 204 | 205 | a := r.Fields[0].(*pgtype.Int4) 206 | b := r.Fields[1].(*pgtype.Text) 207 | 208 | // type compatibility is checked by AssignTo 209 | // only lossless assignments will succeed 210 | if err := a.AssignTo(&t.a); err != nil { 211 | return err 212 | } 213 | 214 | // AssignTo also deals with null value handling 215 | if err := b.AssignTo(&t.b); err != nil { 216 | return err 217 | } 218 | return nil 219 | } 220 | 221 | result := MyType{} 222 | err := conn.QueryRow(context.Background(), "select row(1, 'foo'::text)", pgx.QueryResultFormats{pgx.BinaryFormatCode}).Scan(&result) 223 | 224 | Raw Bytes Mapping 225 | 226 | []byte passed as arguments to Query, QueryRow, and Exec are passed unmodified to PostgreSQL. 227 | 228 | Transactions 229 | 230 | Transactions are started by calling Begin.
231 | 232 | tx, err := conn.Begin(context.Background()) 233 | if err != nil { 234 | return err 235 | } 236 | // Rollback is safe to call even if the tx is already closed, so if 237 | // the tx commits successfully, this is a no-op 238 | defer tx.Rollback(context.Background()) 239 | 240 | _, err = tx.Exec(context.Background(), "insert into foo(id) values (1)") 241 | if err != nil { 242 | return err 243 | } 244 | 245 | err = tx.Commit(context.Background()) 246 | if err != nil { 247 | return err 248 | } 249 | 250 | The Tx returned from Begin also implements the Begin method. This can be used to implement pseudo nested transactions. 251 | These are internally implemented with savepoints. 252 | 253 | Use BeginTx to control the transaction mode. 254 | 255 | BeginFunc and BeginTxFunc are variants that begin a transaction, execute a function, and commit or roll back the 256 | transaction depending on the return value of the function. These can be simpler and less error-prone to use. 257 | 258 | err = conn.BeginFunc(context.Background(), func(tx pgx.Tx) error { 259 | _, err := tx.Exec(context.Background(), "insert into foo(id) values (1)") 260 | return err 261 | }) 262 | if err != nil { 263 | return err 264 | } 265 | 266 | Prepared Statements 267 | 268 | Prepared statements can be manually created with the Prepare method. However, this is rarely necessary because pgx 269 | includes an automatic statement cache by default. Queries run through the normal Query, QueryRow, and Exec functions are 270 | automatically prepared on first execution and the prepared statement is reused on subsequent executions. See ParseConfig 271 | for information on how to customize or disable the statement cache. 272 | 273 | Copy Protocol 274 | 275 | Use CopyFrom to efficiently insert multiple rows at a time using the PostgreSQL copy protocol. CopyFrom accepts a 276 | CopyFromSource interface. If the data is already in a [][]interface{} use CopyFromRows to wrap it in a CopyFromSource 277 | interface. Or implement CopyFromSource to avoid buffering the entire data set in memory. 278 | 279 | rows := [][]interface{}{ 280 | {"John", "Smith", int32(36)}, 281 | {"Jane", "Doe", int32(29)}, 282 | } 283 | 284 | copyCount, err := conn.CopyFrom( 285 | context.Background(), 286 | pgx.Identifier{"people"}, 287 | []string{"first_name", "last_name", "age"}, 288 | pgx.CopyFromRows(rows), 289 | ) 290 | 291 | When you already have a typed array, using CopyFromSlice can be more convenient. 292 | 293 | rows := []User{ 294 | {"John", "Smith", 36}, 295 | {"Jane", "Doe", 29}, 296 | } 297 | 298 | copyCount, err := conn.CopyFrom( 299 | context.Background(), 300 | pgx.Identifier{"people"}, 301 | []string{"first_name", "last_name", "age"}, 302 | pgx.CopyFromSlice(len(rows), func(i int) ([]interface{}, error) { 303 | return []interface{}{rows[i].FirstName, rows[i].LastName, rows[i].Age}, nil 304 | }), 305 | ) 306 | 307 | CopyFrom can be faster than an insert with as few as 5 rows. 308 | 309 | Listen and Notify 310 | 311 | pgx can listen to the PostgreSQL notification system with the `Conn.WaitForNotification` method. It blocks until a 312 | notification is received or the context is canceled. 313 | 314 | _, err := conn.Exec(context.Background(), "listen channelname") 315 | if err != nil { 316 | return err 317 | } 318 | 319 | if notification, err := conn.WaitForNotification(context.Background()); err == nil { 320 | // do something with notification 321 | } 322 | 323 | 324 | Logging 325 | 326 | pgx defines a simple logger interface.
Connections optionally accept a logger that satisfies this interface. Set 327 | LogLevel to control logging verbosity. Adapters for github.com/inconshreveable/log15, github.com/sirupsen/logrus, 328 | go.uber.org/zap, github.com/rs/zerolog, and the testing log are provided in the log directory. 329 | 330 | Lower Level PostgreSQL Functionality 331 | 332 | pgx is implemented on top of github.com/jackc/pgconn a lower level PostgreSQL driver. The Conn.PgConn() method can be 333 | used to access this lower layer. 334 | 335 | PgBouncer 336 | 337 | pgx is compatible with PgBouncer in two modes. One is when the connection has a statement cache in "describe" mode. The 338 | other is when the connection is using the simple protocol. This can be set with the PreferSimpleProtocol config option. 339 | */ 340 | package pgx 341 | -------------------------------------------------------------------------------- /tx.go: -------------------------------------------------------------------------------- 1 | package pgx 2 | 3 | import ( 4 | "bytes" 5 | "context" 6 | "errors" 7 | "fmt" 8 | "strconv" 9 | 10 | "github.com/jackc/pgconn" 11 | ) 12 | 13 | // TxIsoLevel is the transaction isolation level (serializable, repeatable read, read committed or read uncommitted) 14 | type TxIsoLevel string 15 | 16 | // Transaction isolation levels 17 | const ( 18 | Serializable TxIsoLevel = "serializable" 19 | RepeatableRead TxIsoLevel = "repeatable read" 20 | ReadCommitted TxIsoLevel = "read committed" 21 | ReadUncommitted TxIsoLevel = "read uncommitted" 22 | ) 23 | 24 | // TxAccessMode is the transaction access mode (read write or read only) 25 | type TxAccessMode string 26 | 27 | // Transaction access modes 28 | const ( 29 | ReadWrite TxAccessMode = "read write" 30 | ReadOnly TxAccessMode = "read only" 31 | ) 32 | 33 | // TxDeferrableMode is the transaction deferrable mode (deferrable or not deferrable) 34 | type TxDeferrableMode string 35 | 36 | // Transaction deferrable modes 37 | const ( 38 | Deferrable TxDeferrableMode = "deferrable" 39 | NotDeferrable TxDeferrableMode = "not deferrable" 40 | ) 41 | 42 | // TxOptions are transaction modes within a transaction block 43 | type TxOptions struct { 44 | IsoLevel TxIsoLevel 45 | AccessMode TxAccessMode 46 | DeferrableMode TxDeferrableMode 47 | } 48 | 49 | var emptyTxOptions TxOptions 50 | 51 | func (txOptions TxOptions) beginSQL() string { 52 | if txOptions == emptyTxOptions { 53 | return "begin" 54 | } 55 | buf := &bytes.Buffer{} 56 | buf.WriteString("begin") 57 | if txOptions.IsoLevel != "" { 58 | fmt.Fprintf(buf, " isolation level %s", txOptions.IsoLevel) 59 | } 60 | if txOptions.AccessMode != "" { 61 | fmt.Fprintf(buf, " %s", txOptions.AccessMode) 62 | } 63 | if txOptions.DeferrableMode != "" { 64 | fmt.Fprintf(buf, " %s", txOptions.DeferrableMode) 65 | } 66 | 67 | return buf.String() 68 | } 69 | 70 | var ErrTxClosed = errors.New("tx is closed") 71 | 72 | // ErrTxCommitRollback occurs when an error has occurred in a transaction and 73 | // Commit() is called. PostgreSQL accepts COMMIT on aborted transactions, but 74 | // it is treated as ROLLBACK. 75 | var ErrTxCommitRollback = errors.New("commit unexpectedly resulted in rollback") 76 | 77 | // Begin starts a transaction. Unlike database/sql, the context only affects the begin command. i.e. there is no 78 | // auto-rollback on context cancellation. 
79 | func (c *Conn) Begin(ctx context.Context) (Tx, error) { 80 | return c.BeginTx(ctx, TxOptions{}) 81 | } 82 | 83 | // BeginTx starts a transaction with txOptions determining the transaction mode. Unlike database/sql, the context only 84 | // affects the begin command. i.e. there is no auto-rollback on context cancellation. 85 | func (c *Conn) BeginTx(ctx context.Context, txOptions TxOptions) (Tx, error) { 86 | _, err := c.Exec(ctx, txOptions.beginSQL()) 87 | if err != nil { 88 | // begin should never fail unless there is an underlying connection issue or 89 | // a context timeout. In either case, the connection is possibly broken. 90 | c.die(errors.New("failed to begin transaction")) 91 | return nil, err 92 | } 93 | 94 | return &dbTx{conn: c}, nil 95 | } 96 | 97 | // BeginFunc starts a transaction and calls f. If f does not return an error the transaction is committed. If f returns 98 | // an error the transaction is rolled back. The context will be used when executing the transaction control statements 99 | // (BEGIN, ROLLBACK, and COMMIT) but does not otherwise affect the execution of f. 100 | func (c *Conn) BeginFunc(ctx context.Context, f func(Tx) error) (err error) { 101 | return c.BeginTxFunc(ctx, TxOptions{}, f) 102 | } 103 | 104 | // BeginTxFunc starts a transaction with txOptions determining the transaction mode and calls f. If f does not return 105 | // an error the transaction is committed. If f returns an error the transaction is rolled back. The context will be 106 | // used when executing the transaction control statements (BEGIN, ROLLBACK, and COMMIT) but does not otherwise affect 107 | // the execution of f. 108 | func (c *Conn) BeginTxFunc(ctx context.Context, txOptions TxOptions, f func(Tx) error) (err error) { 109 | var tx Tx 110 | tx, err = c.BeginTx(ctx, txOptions) 111 | if err != nil { 112 | return err 113 | } 114 | defer func() { 115 | rollbackErr := tx.Rollback(ctx) 116 | if rollbackErr != nil && !errors.Is(rollbackErr, ErrTxClosed) { 117 | err = rollbackErr 118 | } 119 | }() 120 | 121 | fErr := f(tx) 122 | if fErr != nil { 123 | _ = tx.Rollback(ctx) // ignore rollback error as there is already an error to return 124 | return fErr 125 | } 126 | 127 | return tx.Commit(ctx) 128 | } 129 | 130 | // Tx represents a database transaction. 131 | // 132 | // Tx is an interface instead of a struct to enable connection pools to be implemented without relying on internal pgx 133 | // state, to support pseudo-nested transactions with savepoints, and to allow tests to mock transactions. However, 134 | // adding a method to an interface is technically a breaking change. If new methods are added to Conn it may be 135 | // desirable to add them to Tx as well. Because of this the Tx interface is partially excluded from semantic version 136 | // requirements. Methods will not be removed or changed, but new methods may be added. 137 | type Tx interface { 138 | // Begin starts a pseudo nested transaction. 139 | Begin(ctx context.Context) (Tx, error) 140 | 141 | // BeginFunc starts a pseudo nested transaction and executes f. If f does not return an err the pseudo nested 142 | // transaction will be committed. If it does then it will be rolled back. 143 | BeginFunc(ctx context.Context, f func(Tx) error) (err error) 144 | 145 | // Commit commits the transaction if this is a real transaction or releases the savepoint if this is a pseudo nested 146 | // transaction. Commit will return ErrTxClosed if the Tx is already closed, but is otherwise safe to call multiple 147 | // times. 
If the commit fails with a rollback status (e.g. the transaction was already in a broken state) then 148 | // ErrTxCommitRollback will be returned. 149 | Commit(ctx context.Context) error 150 | 151 | // Rollback rolls back the transaction if this is a real transaction or rolls back to the savepoint if this is a 152 | // pseudo nested transaction. Rollback will return ErrTxClosed if the Tx is already closed, but is otherwise safe to 153 | // call multiple times. Hence, a defer tx.Rollback() is safe even if tx.Commit() will be called first in a non-error 154 | // condition. Any other failure of a real transaction will result in the connection being closed. 155 | Rollback(ctx context.Context) error 156 | 157 | CopyFrom(ctx context.Context, tableName Identifier, columnNames []string, rowSrc CopyFromSource) (int64, error) 158 | SendBatch(ctx context.Context, b *Batch) BatchResults 159 | LargeObjects() LargeObjects 160 | 161 | Prepare(ctx context.Context, name, sql string) (*pgconn.StatementDescription, error) 162 | 163 | Exec(ctx context.Context, sql string, arguments ...interface{}) (commandTag pgconn.CommandTag, err error) 164 | Query(ctx context.Context, sql string, args ...interface{}) (Rows, error) 165 | QueryRow(ctx context.Context, sql string, args ...interface{}) Row 166 | QueryFunc(ctx context.Context, sql string, args []interface{}, scans []interface{}, f func(QueryFuncRow) error) (pgconn.CommandTag, error) 167 | 168 | // Conn returns the underlying *Conn that on which this transaction is executing. 169 | Conn() *Conn 170 | } 171 | 172 | // dbTx represents a database transaction. 173 | // 174 | // All dbTx methods return ErrTxClosed if Commit or Rollback has already been 175 | // called on the dbTx. 176 | type dbTx struct { 177 | conn *Conn 178 | err error 179 | savepointNum int64 180 | closed bool 181 | } 182 | 183 | // Begin starts a pseudo nested transaction implemented with a savepoint. 184 | func (tx *dbTx) Begin(ctx context.Context) (Tx, error) { 185 | if tx.closed { 186 | return nil, ErrTxClosed 187 | } 188 | 189 | tx.savepointNum++ 190 | _, err := tx.conn.Exec(ctx, "savepoint sp_"+strconv.FormatInt(tx.savepointNum, 10)) 191 | if err != nil { 192 | return nil, err 193 | } 194 | 195 | return &dbSimulatedNestedTx{tx: tx, savepointNum: tx.savepointNum}, nil 196 | } 197 | 198 | func (tx *dbTx) BeginFunc(ctx context.Context, f func(Tx) error) (err error) { 199 | if tx.closed { 200 | return ErrTxClosed 201 | } 202 | 203 | var savepoint Tx 204 | savepoint, err = tx.Begin(ctx) 205 | if err != nil { 206 | return err 207 | } 208 | defer func() { 209 | rollbackErr := savepoint.Rollback(ctx) 210 | if rollbackErr != nil && !errors.Is(rollbackErr, ErrTxClosed) { 211 | err = rollbackErr 212 | } 213 | }() 214 | 215 | fErr := f(savepoint) 216 | if fErr != nil { 217 | _ = savepoint.Rollback(ctx) // ignore rollback error as there is already an error to return 218 | return fErr 219 | } 220 | 221 | return savepoint.Commit(ctx) 222 | } 223 | 224 | // Commit commits the transaction. 225 | func (tx *dbTx) Commit(ctx context.Context) error { 226 | if tx.closed { 227 | return ErrTxClosed 228 | } 229 | 230 | commandTag, err := tx.conn.Exec(ctx, "commit") 231 | tx.closed = true 232 | if err != nil { 233 | if tx.conn.PgConn().TxStatus() != 'I' { 234 | _ = tx.conn.Close(ctx) // already have error to return 235 | } 236 | return err 237 | } 238 | if string(commandTag) == "ROLLBACK" { 239 | return ErrTxCommitRollback 240 | } 241 | 242 | return nil 243 | } 244 | 245 | // Rollback rolls back the transaction. 
Rollback will return ErrTxClosed if the 246 | // Tx is already closed, but is otherwise safe to call multiple times. Hence, a 247 | // defer tx.Rollback() is safe even if tx.Commit() will be called first in a 248 | // non-error condition. 249 | func (tx *dbTx) Rollback(ctx context.Context) error { 250 | if tx.closed { 251 | return ErrTxClosed 252 | } 253 | 254 | _, err := tx.conn.Exec(ctx, "rollback") 255 | tx.closed = true 256 | if err != nil { 257 | // A rollback failure leaves the connection in an undefined state 258 | tx.conn.die(fmt.Errorf("rollback failed: %w", err)) 259 | return err 260 | } 261 | 262 | return nil 263 | } 264 | 265 | // Exec delegates to the underlying *Conn 266 | func (tx *dbTx) Exec(ctx context.Context, sql string, arguments ...interface{}) (commandTag pgconn.CommandTag, err error) { 267 | return tx.conn.Exec(ctx, sql, arguments...) 268 | } 269 | 270 | // Prepare delegates to the underlying *Conn 271 | func (tx *dbTx) Prepare(ctx context.Context, name, sql string) (*pgconn.StatementDescription, error) { 272 | if tx.closed { 273 | return nil, ErrTxClosed 274 | } 275 | 276 | return tx.conn.Prepare(ctx, name, sql) 277 | } 278 | 279 | // Query delegates to the underlying *Conn 280 | func (tx *dbTx) Query(ctx context.Context, sql string, args ...interface{}) (Rows, error) { 281 | if tx.closed { 282 | // Because checking for errors can be deferred to the *Rows, build one with the error 283 | err := ErrTxClosed 284 | return &connRows{closed: true, err: err}, err 285 | } 286 | 287 | return tx.conn.Query(ctx, sql, args...) 288 | } 289 | 290 | // QueryRow delegates to the underlying *Conn 291 | func (tx *dbTx) QueryRow(ctx context.Context, sql string, args ...interface{}) Row { 292 | rows, _ := tx.Query(ctx, sql, args...) 293 | return (*connRow)(rows.(*connRows)) 294 | } 295 | 296 | // QueryFunc delegates to the underlying *Conn. 297 | func (tx *dbTx) QueryFunc(ctx context.Context, sql string, args []interface{}, scans []interface{}, f func(QueryFuncRow) error) (pgconn.CommandTag, error) { 298 | if tx.closed { 299 | return nil, ErrTxClosed 300 | } 301 | 302 | return tx.conn.QueryFunc(ctx, sql, args, scans, f) 303 | } 304 | 305 | // CopyFrom delegates to the underlying *Conn 306 | func (tx *dbTx) CopyFrom(ctx context.Context, tableName Identifier, columnNames []string, rowSrc CopyFromSource) (int64, error) { 307 | if tx.closed { 308 | return 0, ErrTxClosed 309 | } 310 | 311 | return tx.conn.CopyFrom(ctx, tableName, columnNames, rowSrc) 312 | } 313 | 314 | // SendBatch delegates to the underlying *Conn 315 | func (tx *dbTx) SendBatch(ctx context.Context, b *Batch) BatchResults { 316 | if tx.closed { 317 | return &batchResults{err: ErrTxClosed} 318 | } 319 | 320 | return tx.conn.SendBatch(ctx, b) 321 | } 322 | 323 | // LargeObjects returns a LargeObjects instance for the transaction. 324 | func (tx *dbTx) LargeObjects() LargeObjects { 325 | return LargeObjects{tx: tx} 326 | } 327 | 328 | func (tx *dbTx) Conn() *Conn { 329 | return tx.conn 330 | } 331 | 332 | // dbSimulatedNestedTx represents a simulated nested transaction implemented by a savepoint. 333 | type dbSimulatedNestedTx struct { 334 | tx Tx 335 | savepointNum int64 336 | closed bool 337 | } 338 | 339 | // Begin starts a pseudo nested transaction implemented with a savepoint. 
340 | func (sp *dbSimulatedNestedTx) Begin(ctx context.Context) (Tx, error) { 341 | if sp.closed { 342 | return nil, ErrTxClosed 343 | } 344 | 345 | return sp.tx.Begin(ctx) 346 | } 347 | 348 | func (sp *dbSimulatedNestedTx) BeginFunc(ctx context.Context, f func(Tx) error) (err error) { 349 | if sp.closed { 350 | return ErrTxClosed 351 | } 352 | 353 | return sp.tx.BeginFunc(ctx, f) 354 | } 355 | 356 | // Commit releases the savepoint essentially committing the pseudo nested transaction. 357 | func (sp *dbSimulatedNestedTx) Commit(ctx context.Context) error { 358 | if sp.closed { 359 | return ErrTxClosed 360 | } 361 | 362 | _, err := sp.Exec(ctx, "release savepoint sp_"+strconv.FormatInt(sp.savepointNum, 10)) 363 | sp.closed = true 364 | return err 365 | } 366 | 367 | // Rollback rolls back to the savepoint essentially rolling back the pseudo nested transaction. Rollback will return 368 | // ErrTxClosed if the dbSavepoint is already closed, but is otherwise safe to call multiple times. Hence, a defer sp.Rollback() 369 | // is safe even if sp.Commit() will be called first in a non-error condition. 370 | func (sp *dbSimulatedNestedTx) Rollback(ctx context.Context) error { 371 | if sp.closed { 372 | return ErrTxClosed 373 | } 374 | 375 | _, err := sp.Exec(ctx, "rollback to savepoint sp_"+strconv.FormatInt(sp.savepointNum, 10)) 376 | sp.closed = true 377 | return err 378 | } 379 | 380 | // Exec delegates to the underlying Tx 381 | func (sp *dbSimulatedNestedTx) Exec(ctx context.Context, sql string, arguments ...interface{}) (commandTag pgconn.CommandTag, err error) { 382 | if sp.closed { 383 | return nil, ErrTxClosed 384 | } 385 | 386 | return sp.tx.Exec(ctx, sql, arguments...) 387 | } 388 | 389 | // Prepare delegates to the underlying Tx 390 | func (sp *dbSimulatedNestedTx) Prepare(ctx context.Context, name, sql string) (*pgconn.StatementDescription, error) { 391 | if sp.closed { 392 | return nil, ErrTxClosed 393 | } 394 | 395 | return sp.tx.Prepare(ctx, name, sql) 396 | } 397 | 398 | // Query delegates to the underlying Tx 399 | func (sp *dbSimulatedNestedTx) Query(ctx context.Context, sql string, args ...interface{}) (Rows, error) { 400 | if sp.closed { 401 | // Because checking for errors can be deferred to the *Rows, build one with the error 402 | err := ErrTxClosed 403 | return &connRows{closed: true, err: err}, err 404 | } 405 | 406 | return sp.tx.Query(ctx, sql, args...) 407 | } 408 | 409 | // QueryRow delegates to the underlying Tx 410 | func (sp *dbSimulatedNestedTx) QueryRow(ctx context.Context, sql string, args ...interface{}) Row { 411 | rows, _ := sp.Query(ctx, sql, args...) 412 | return (*connRow)(rows.(*connRows)) 413 | } 414 | 415 | // QueryFunc delegates to the underlying Tx. 
416 | func (sp *dbSimulatedNestedTx) QueryFunc(ctx context.Context, sql string, args []interface{}, scans []interface{}, f func(QueryFuncRow) error) (pgconn.CommandTag, error) { 417 | if sp.closed { 418 | return nil, ErrTxClosed 419 | } 420 | 421 | return sp.tx.QueryFunc(ctx, sql, args, scans, f) 422 | } 423 | 424 | // CopyFrom delegates to the underlying *Conn 425 | func (sp *dbSimulatedNestedTx) CopyFrom(ctx context.Context, tableName Identifier, columnNames []string, rowSrc CopyFromSource) (int64, error) { 426 | if sp.closed { 427 | return 0, ErrTxClosed 428 | } 429 | 430 | return sp.tx.CopyFrom(ctx, tableName, columnNames, rowSrc) 431 | } 432 | 433 | // SendBatch delegates to the underlying *Conn 434 | func (sp *dbSimulatedNestedTx) SendBatch(ctx context.Context, b *Batch) BatchResults { 435 | if sp.closed { 436 | return &batchResults{err: ErrTxClosed} 437 | } 438 | 439 | return sp.tx.SendBatch(ctx, b) 440 | } 441 | 442 | func (sp *dbSimulatedNestedTx) LargeObjects() LargeObjects { 443 | return LargeObjects{tx: sp} 444 | } 445 | 446 | func (sp *dbSimulatedNestedTx) Conn() *Conn { 447 | return sp.tx.Conn() 448 | } 449 | -------------------------------------------------------------------------------- /copy_from_test.go: -------------------------------------------------------------------------------- 1 | package pgx_test 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "os" 7 | "reflect" 8 | "testing" 9 | "time" 10 | 11 | "github.com/jackc/pgconn" 12 | "github.com/jackc/pgx/v4" 13 | "github.com/stretchr/testify/require" 14 | ) 15 | 16 | func TestConnCopyFromSmall(t *testing.T) { 17 | t.Parallel() 18 | 19 | conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) 20 | defer closeConn(t, conn) 21 | 22 | mustExec(t, conn, `create temporary table foo( 23 | a int2, 24 | b int4, 25 | c int8, 26 | d varchar, 27 | e text, 28 | f date, 29 | g timestamptz 30 | )`) 31 | 32 | tzedTime := time.Date(2010, 2, 3, 4, 5, 6, 0, time.Local) 33 | 34 | inputRows := [][]interface{}{ 35 | {int16(0), int32(1), int64(2), "abc", "efg", time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC), tzedTime}, 36 | {nil, nil, nil, nil, nil, nil, nil}, 37 | } 38 | 39 | copyCount, err := conn.CopyFrom(context.Background(), pgx.Identifier{"foo"}, []string{"a", "b", "c", "d", "e", "f", "g"}, pgx.CopyFromRows(inputRows)) 40 | if err != nil { 41 | t.Errorf("Unexpected error for CopyFrom: %v", err) 42 | } 43 | if int(copyCount) != len(inputRows) { 44 | t.Errorf("Expected CopyFrom to return %d copied rows, but got %d", len(inputRows), copyCount) 45 | } 46 | 47 | rows, err := conn.Query(context.Background(), "select * from foo") 48 | if err != nil { 49 | t.Errorf("Unexpected error for Query: %v", err) 50 | } 51 | 52 | var outputRows [][]interface{} 53 | for rows.Next() { 54 | row, err := rows.Values() 55 | if err != nil { 56 | t.Errorf("Unexpected error for rows.Values(): %v", err) 57 | } 58 | outputRows = append(outputRows, row) 59 | } 60 | 61 | if rows.Err() != nil { 62 | t.Errorf("Unexpected error for rows.Err(): %v", rows.Err()) 63 | } 64 | 65 | if !reflect.DeepEqual(inputRows, outputRows) { 66 | t.Errorf("Input rows and output rows do not equal: %v -> %v", inputRows, outputRows) 67 | } 68 | 69 | ensureConnValid(t, conn) 70 | } 71 | 72 | func TestConnCopyFromSliceSmall(t *testing.T) { 73 | t.Parallel() 74 | 75 | conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) 76 | defer closeConn(t, conn) 77 | 78 | mustExec(t, conn, `create temporary table foo( 79 | a int2, 80 | b int4, 81 | c int8, 82 | d varchar, 83 | e text, 84 | f date, 85 
| g timestamptz 86 | )`) 87 | 88 | tzedTime := time.Date(2010, 2, 3, 4, 5, 6, 0, time.Local) 89 | 90 | inputRows := [][]interface{}{ 91 | {int16(0), int32(1), int64(2), "abc", "efg", time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC), tzedTime}, 92 | {nil, nil, nil, nil, nil, nil, nil}, 93 | } 94 | 95 | copyCount, err := conn.CopyFrom(context.Background(), pgx.Identifier{"foo"}, []string{"a", "b", "c", "d", "e", "f", "g"}, 96 | pgx.CopyFromSlice(len(inputRows), func(i int) ([]interface{}, error) { 97 | return inputRows[i], nil 98 | })) 99 | if err != nil { 100 | t.Errorf("Unexpected error for CopyFrom: %v", err) 101 | } 102 | if int(copyCount) != len(inputRows) { 103 | t.Errorf("Expected CopyFrom to return %d copied rows, but got %d", len(inputRows), copyCount) 104 | } 105 | 106 | rows, err := conn.Query(context.Background(), "select * from foo") 107 | if err != nil { 108 | t.Errorf("Unexpected error for Query: %v", err) 109 | } 110 | 111 | var outputRows [][]interface{} 112 | for rows.Next() { 113 | row, err := rows.Values() 114 | if err != nil { 115 | t.Errorf("Unexpected error for rows.Values(): %v", err) 116 | } 117 | outputRows = append(outputRows, row) 118 | } 119 | 120 | if rows.Err() != nil { 121 | t.Errorf("Unexpected error for rows.Err(): %v", rows.Err()) 122 | } 123 | 124 | if !reflect.DeepEqual(inputRows, outputRows) { 125 | t.Errorf("Input rows and output rows do not equal: %v -> %v", inputRows, outputRows) 126 | } 127 | 128 | ensureConnValid(t, conn) 129 | } 130 | 131 | func TestConnCopyFromLarge(t *testing.T) { 132 | t.Parallel() 133 | 134 | conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) 135 | defer closeConn(t, conn) 136 | 137 | skipCockroachDB(t, conn, "Skipping due to known server issue: (https://github.com/cockroachdb/cockroach/issues/52722)") 138 | 139 | mustExec(t, conn, `create temporary table foo( 140 | a int2, 141 | b int4, 142 | c int8, 143 | d varchar, 144 | e text, 145 | f date, 146 | g timestamptz, 147 | h bytea 148 | )`) 149 | 150 | tzedTime := time.Date(2010, 2, 3, 4, 5, 6, 0, time.Local) 151 | 152 | inputRows := [][]interface{}{} 153 | 154 | for i := 0; i < 10000; i++ { 155 | inputRows = append(inputRows, []interface{}{int16(0), int32(1), int64(2), "abc", "efg", time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC), tzedTime, []byte{111, 111, 111, 111}}) 156 | } 157 | 158 | copyCount, err := conn.CopyFrom(context.Background(), pgx.Identifier{"foo"}, []string{"a", "b", "c", "d", "e", "f", "g", "h"}, pgx.CopyFromRows(inputRows)) 159 | if err != nil { 160 | t.Errorf("Unexpected error for CopyFrom: %v", err) 161 | } 162 | if int(copyCount) != len(inputRows) { 163 | t.Errorf("Expected CopyFrom to return %d copied rows, but got %d", len(inputRows), copyCount) 164 | } 165 | 166 | rows, err := conn.Query(context.Background(), "select * from foo") 167 | if err != nil { 168 | t.Errorf("Unexpected error for Query: %v", err) 169 | } 170 | 171 | var outputRows [][]interface{} 172 | for rows.Next() { 173 | row, err := rows.Values() 174 | if err != nil { 175 | t.Errorf("Unexpected error for rows.Values(): %v", err) 176 | } 177 | outputRows = append(outputRows, row) 178 | } 179 | 180 | if rows.Err() != nil { 181 | t.Errorf("Unexpected error for rows.Err(): %v", rows.Err()) 182 | } 183 | 184 | if !reflect.DeepEqual(inputRows, outputRows) { 185 | t.Errorf("Input rows and output rows do not equal") 186 | } 187 | 188 | ensureConnValid(t, conn) 189 | } 190 | 191 | func TestConnCopyFromEnum(t *testing.T) { 192 | t.Parallel() 193 | 194 | conn := mustConnectString(t, 
os.Getenv("PGX_TEST_DATABASE")) 195 | defer closeConn(t, conn) 196 | 197 | ctx := context.Background() 198 | tx, err := conn.Begin(ctx) 199 | require.NoError(t, err) 200 | defer tx.Rollback(ctx) 201 | 202 | _, err = tx.Exec(ctx, `drop type if exists color`) 203 | require.NoError(t, err) 204 | 205 | _, err = tx.Exec(ctx, `drop type if exists fruit`) 206 | require.NoError(t, err) 207 | 208 | _, err = tx.Exec(ctx, `create type color as enum ('blue', 'green', 'orange')`) 209 | require.NoError(t, err) 210 | 211 | _, err = tx.Exec(ctx, `create type fruit as enum ('apple', 'orange', 'grape')`) 212 | require.NoError(t, err) 213 | 214 | _, err = tx.Exec(ctx, `create table foo( 215 | a text, 216 | b color, 217 | c fruit, 218 | d color, 219 | e fruit, 220 | f text 221 | )`) 222 | require.NoError(t, err) 223 | 224 | inputRows := [][]interface{}{ 225 | {"abc", "blue", "grape", "orange", "orange", "def"}, 226 | {nil, nil, nil, nil, nil, nil}, 227 | } 228 | 229 | copyCount, err := conn.CopyFrom(ctx, pgx.Identifier{"foo"}, []string{"a", "b", "c", "d", "e", "f"}, pgx.CopyFromRows(inputRows)) 230 | require.NoError(t, err) 231 | require.EqualValues(t, len(inputRows), copyCount) 232 | 233 | rows, err := conn.Query(ctx, "select * from foo") 234 | require.NoError(t, err) 235 | 236 | var outputRows [][]interface{} 237 | for rows.Next() { 238 | row, err := rows.Values() 239 | require.NoError(t, err) 240 | outputRows = append(outputRows, row) 241 | } 242 | 243 | require.NoError(t, rows.Err()) 244 | 245 | if !reflect.DeepEqual(inputRows, outputRows) { 246 | t.Errorf("Input rows and output rows do not equal: %v -> %v", inputRows, outputRows) 247 | } 248 | 249 | ensureConnValid(t, conn) 250 | } 251 | 252 | func TestConnCopyFromJSON(t *testing.T) { 253 | t.Parallel() 254 | 255 | conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) 256 | defer closeConn(t, conn) 257 | 258 | for _, typeName := range []string{"json", "jsonb"} { 259 | if _, ok := conn.ConnInfo().DataTypeForName(typeName); !ok { 260 | return // No JSON/JSONB type -- must be running against old PostgreSQL 261 | } 262 | } 263 | 264 | mustExec(t, conn, `create temporary table foo( 265 | a json, 266 | b jsonb 267 | )`) 268 | 269 | inputRows := [][]interface{}{ 270 | {map[string]interface{}{"foo": "bar"}, map[string]interface{}{"bar": "quz"}}, 271 | {nil, nil}, 272 | } 273 | 274 | copyCount, err := conn.CopyFrom(context.Background(), pgx.Identifier{"foo"}, []string{"a", "b"}, pgx.CopyFromRows(inputRows)) 275 | if err != nil { 276 | t.Errorf("Unexpected error for CopyFrom: %v", err) 277 | } 278 | if int(copyCount) != len(inputRows) { 279 | t.Errorf("Expected CopyFrom to return %d copied rows, but got %d", len(inputRows), copyCount) 280 | } 281 | 282 | rows, err := conn.Query(context.Background(), "select * from foo") 283 | if err != nil { 284 | t.Errorf("Unexpected error for Query: %v", err) 285 | } 286 | 287 | var outputRows [][]interface{} 288 | for rows.Next() { 289 | row, err := rows.Values() 290 | if err != nil { 291 | t.Errorf("Unexpected error for rows.Values(): %v", err) 292 | } 293 | outputRows = append(outputRows, row) 294 | } 295 | 296 | if rows.Err() != nil { 297 | t.Errorf("Unexpected error for rows.Err(): %v", rows.Err()) 298 | } 299 | 300 | if !reflect.DeepEqual(inputRows, outputRows) { 301 | t.Errorf("Input rows and output rows do not equal: %v -> %v", inputRows, outputRows) 302 | } 303 | 304 | ensureConnValid(t, conn) 305 | } 306 | 307 | type clientFailSource struct { 308 | count int 309 | err error 310 | } 311 | 312 | func (cfs 
*clientFailSource) Next() bool { 313 | cfs.count++ 314 | return cfs.count < 100 315 | } 316 | 317 | func (cfs *clientFailSource) Values() ([]interface{}, error) { 318 | if cfs.count == 3 { 319 | cfs.err = fmt.Errorf("client error") 320 | return nil, cfs.err 321 | } 322 | return []interface{}{make([]byte, 100000)}, nil 323 | } 324 | 325 | func (cfs *clientFailSource) Err() error { 326 | return cfs.err 327 | } 328 | 329 | func TestConnCopyFromFailServerSideMidway(t *testing.T) { 330 | t.Parallel() 331 | 332 | conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) 333 | defer closeConn(t, conn) 334 | 335 | mustExec(t, conn, `create temporary table foo( 336 | a int4, 337 | b varchar not null 338 | )`) 339 | 340 | inputRows := [][]interface{}{ 341 | {int32(1), "abc"}, 342 | {int32(2), nil}, // this row should trigger a failure 343 | {int32(3), "def"}, 344 | } 345 | 346 | copyCount, err := conn.CopyFrom(context.Background(), pgx.Identifier{"foo"}, []string{"a", "b"}, pgx.CopyFromRows(inputRows)) 347 | if err == nil { 348 | t.Errorf("Expected CopyFrom return error, but it did not") 349 | } 350 | if _, ok := err.(*pgconn.PgError); !ok { 351 | t.Errorf("Expected CopyFrom return pgx.PgError, but instead it returned: %v", err) 352 | } 353 | if copyCount != 0 { 354 | t.Errorf("Expected CopyFrom to return 0 copied rows, but got %d", copyCount) 355 | } 356 | 357 | rows, err := conn.Query(context.Background(), "select * from foo") 358 | if err != nil { 359 | t.Errorf("Unexpected error for Query: %v", err) 360 | } 361 | 362 | var outputRows [][]interface{} 363 | for rows.Next() { 364 | row, err := rows.Values() 365 | if err != nil { 366 | t.Errorf("Unexpected error for rows.Values(): %v", err) 367 | } 368 | outputRows = append(outputRows, row) 369 | } 370 | 371 | if rows.Err() != nil { 372 | t.Errorf("Unexpected error for rows.Err(): %v", rows.Err()) 373 | } 374 | 375 | if len(outputRows) != 0 { 376 | t.Errorf("Expected 0 rows, but got %v", outputRows) 377 | } 378 | 379 | mustExec(t, conn, "truncate foo") 380 | 381 | ensureConnValid(t, conn) 382 | } 383 | 384 | type failSource struct { 385 | count int 386 | } 387 | 388 | func (fs *failSource) Next() bool { 389 | time.Sleep(time.Millisecond * 100) 390 | fs.count++ 391 | return fs.count < 100 392 | } 393 | 394 | func (fs *failSource) Values() ([]interface{}, error) { 395 | if fs.count == 3 { 396 | return []interface{}{nil}, nil 397 | } 398 | return []interface{}{make([]byte, 100000)}, nil 399 | } 400 | 401 | func (fs *failSource) Err() error { 402 | return nil 403 | } 404 | 405 | func TestConnCopyFromFailServerSideMidwayAbortsWithoutWaiting(t *testing.T) { 406 | t.Parallel() 407 | 408 | conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) 409 | defer closeConn(t, conn) 410 | 411 | mustExec(t, conn, `create temporary table foo( 412 | a bytea not null 413 | )`) 414 | 415 | startTime := time.Now() 416 | 417 | copyCount, err := conn.CopyFrom(context.Background(), pgx.Identifier{"foo"}, []string{"a"}, &failSource{}) 418 | if err == nil { 419 | t.Errorf("Expected CopyFrom return error, but it did not") 420 | } 421 | if _, ok := err.(*pgconn.PgError); !ok { 422 | t.Errorf("Expected CopyFrom return pgx.PgError, but instead it returned: %v", err) 423 | } 424 | if copyCount != 0 { 425 | t.Errorf("Expected CopyFrom to return 0 copied rows, but got %d", copyCount) 426 | } 427 | 428 | endTime := time.Now() 429 | copyTime := endTime.Sub(startTime) 430 | if copyTime > time.Second { 431 | t.Errorf("Failing CopyFrom shouldn't have taken so long: %v", copyTime) 
432 | } 433 | 434 | rows, err := conn.Query(context.Background(), "select * from foo") 435 | if err != nil { 436 | t.Errorf("Unexpected error for Query: %v", err) 437 | } 438 | 439 | var outputRows [][]interface{} 440 | for rows.Next() { 441 | row, err := rows.Values() 442 | if err != nil { 443 | t.Errorf("Unexpected error for rows.Values(): %v", err) 444 | } 445 | outputRows = append(outputRows, row) 446 | } 447 | 448 | if rows.Err() != nil { 449 | t.Errorf("Unexpected error for rows.Err(): %v", rows.Err()) 450 | } 451 | 452 | if len(outputRows) != 0 { 453 | t.Errorf("Expected 0 rows, but got %v", outputRows) 454 | } 455 | 456 | ensureConnValid(t, conn) 457 | } 458 | 459 | type slowFailRaceSource struct { 460 | count int 461 | } 462 | 463 | func (fs *slowFailRaceSource) Next() bool { 464 | time.Sleep(time.Millisecond) 465 | fs.count++ 466 | return fs.count < 1000 467 | } 468 | 469 | func (fs *slowFailRaceSource) Values() ([]interface{}, error) { 470 | if fs.count == 500 { 471 | return []interface{}{nil, nil}, nil 472 | } 473 | return []interface{}{1, make([]byte, 1000)}, nil 474 | } 475 | 476 | func (fs *slowFailRaceSource) Err() error { 477 | return nil 478 | } 479 | 480 | func TestConnCopyFromSlowFailRace(t *testing.T) { 481 | t.Parallel() 482 | 483 | conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) 484 | defer closeConn(t, conn) 485 | 486 | mustExec(t, conn, `create temporary table foo( 487 | a int not null, 488 | b bytea not null 489 | )`) 490 | 491 | copyCount, err := conn.CopyFrom(context.Background(), pgx.Identifier{"foo"}, []string{"a", "b"}, &slowFailRaceSource{}) 492 | if err == nil { 493 | t.Errorf("Expected CopyFrom return error, but it did not") 494 | } 495 | if _, ok := err.(*pgconn.PgError); !ok { 496 | t.Errorf("Expected CopyFrom return pgx.PgError, but instead it returned: %v", err) 497 | } 498 | if copyCount != 0 { 499 | t.Errorf("Expected CopyFrom to return 0 copied rows, but got %d", copyCount) 500 | } 501 | 502 | ensureConnValid(t, conn) 503 | } 504 | 505 | func TestConnCopyFromCopyFromSourceErrorMidway(t *testing.T) { 506 | t.Parallel() 507 | 508 | conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) 509 | defer closeConn(t, conn) 510 | 511 | mustExec(t, conn, `create temporary table foo( 512 | a bytea not null 513 | )`) 514 | 515 | copyCount, err := conn.CopyFrom(context.Background(), pgx.Identifier{"foo"}, []string{"a"}, &clientFailSource{}) 516 | if err == nil { 517 | t.Errorf("Expected CopyFrom return error, but it did not") 518 | } 519 | if copyCount != 0 { 520 | t.Errorf("Expected CopyFrom to return 0 copied rows, but got %d", copyCount) 521 | } 522 | 523 | rows, err := conn.Query(context.Background(), "select * from foo") 524 | if err != nil { 525 | t.Errorf("Unexpected error for Query: %v", err) 526 | } 527 | 528 | var outputRows [][]interface{} 529 | for rows.Next() { 530 | row, err := rows.Values() 531 | if err != nil { 532 | t.Errorf("Unexpected error for rows.Values(): %v", err) 533 | } 534 | outputRows = append(outputRows, row) 535 | } 536 | 537 | if rows.Err() != nil { 538 | t.Errorf("Unexpected error for rows.Err(): %v", rows.Err()) 539 | } 540 | 541 | if len(outputRows) != 0 { 542 | t.Errorf("Expected 0 rows, but got %v", len(outputRows)) 543 | } 544 | 545 | ensureConnValid(t, conn) 546 | } 547 | 548 | type clientFinalErrSource struct { 549 | count int 550 | } 551 | 552 | func (cfs *clientFinalErrSource) Next() bool { 553 | cfs.count++ 554 | return cfs.count < 5 555 | } 556 | 557 | func (cfs *clientFinalErrSource) Values() 
([]interface{}, error) { 558 | return []interface{}{make([]byte, 100000)}, nil 559 | } 560 | 561 | func (cfs *clientFinalErrSource) Err() error { 562 | return fmt.Errorf("final error") 563 | } 564 | 565 | func TestConnCopyFromCopyFromSourceErrorEnd(t *testing.T) { 566 | t.Parallel() 567 | 568 | conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) 569 | defer closeConn(t, conn) 570 | 571 | mustExec(t, conn, `create temporary table foo( 572 | a bytea not null 573 | )`) 574 | 575 | copyCount, err := conn.CopyFrom(context.Background(), pgx.Identifier{"foo"}, []string{"a"}, &clientFinalErrSource{}) 576 | if err == nil { 577 | t.Errorf("Expected CopyFrom return error, but it did not") 578 | } 579 | if copyCount != 0 { 580 | t.Errorf("Expected CopyFrom to return 0 copied rows, but got %d", copyCount) 581 | } 582 | 583 | rows, err := conn.Query(context.Background(), "select * from foo") 584 | if err != nil { 585 | t.Errorf("Unexpected error for Query: %v", err) 586 | } 587 | 588 | var outputRows [][]interface{} 589 | for rows.Next() { 590 | row, err := rows.Values() 591 | if err != nil { 592 | t.Errorf("Unexpected error for rows.Values(): %v", err) 593 | } 594 | outputRows = append(outputRows, row) 595 | } 596 | 597 | if rows.Err() != nil { 598 | t.Errorf("Unexpected error for rows.Err(): %v", rows.Err()) 599 | } 600 | 601 | if len(outputRows) != 0 { 602 | t.Errorf("Expected 0 rows, but got %v", outputRows) 603 | } 604 | 605 | ensureConnValid(t, conn) 606 | } 607 | --------------------------------------------------------------------------------