├── .github └── workflows │ ├── codeql-analysis.yml │ └── test.yml ├── .gitignore ├── LICENSE.md ├── README.md ├── TESTS.md ├── array.go ├── array_test.go ├── auth └── kerberos │ ├── go.mod │ ├── go.sum │ ├── krb.go │ ├── krb_unix.go │ └── krb_windows.go ├── bench_test.go ├── buf.go ├── buf_test.go ├── certs ├── Makefile ├── README ├── bogus_root.crt ├── postgresql.cnf ├── postgresql.crt ├── postgresql.key ├── root.cnf ├── root.crt ├── server.cnf ├── server.crt └── server.key ├── conn.go ├── conn_go115.go ├── conn_go18.go ├── conn_go19.go ├── conn_go19_test.go ├── conn_test.go ├── connector.go ├── connector_example_test.go ├── connector_test.go ├── copy.go ├── copy_test.go ├── doc.go ├── encode.go ├── encode_test.go ├── error.go ├── example └── listen │ └── doc.go ├── go.mod ├── go18_test.go ├── go19_test.go ├── hstore ├── hstore.go └── hstore_test.go ├── issues_test.go ├── krb.go ├── notice.go ├── notice_example_test.go ├── notice_test.go ├── notify.go ├── notify_test.go ├── oid ├── doc.go ├── gen.go └── types.go ├── rows.go ├── rows_test.go ├── scram └── scram.go ├── ssl.go ├── ssl_permissions.go ├── ssl_permissions_test.go ├── ssl_test.go ├── ssl_windows.go ├── url.go ├── url_test.go ├── user_other.go ├── user_posix.go ├── user_windows.go ├── uuid.go └── uuid_test.go /.github/workflows/codeql-analysis.yml: -------------------------------------------------------------------------------- 1 | name: "CodeQL" 2 | 3 | on: [push, pull_request] 4 | 5 | jobs: 6 | analyze: 7 | name: Analyze 8 | runs-on: ubuntu-latest 9 | permissions: 10 | actions: read 11 | contents: read 12 | security-events: write 13 | 14 | steps: 15 | - name: Checkout repo 16 | uses: actions/checkout@v3 17 | 18 | - name: Initialize CodeQL 19 | uses: github/codeql-action/init@v1 20 | with: 21 | languages: 'go' 22 | 23 | - name: CodeQL Analysis 24 | uses: github/codeql-action/analyze@v1 25 | -------------------------------------------------------------------------------- /.github/workflows/test.yml: 
-------------------------------------------------------------------------------- 1 | name: Test 2 | 3 | on: [push, pull_request] 4 | 5 | jobs: 6 | test: 7 | runs-on: ubuntu-latest 8 | strategy: 9 | fail-fast: false 10 | matrix: 11 | postgres: 12 | - '15' 13 | - '14' 14 | - '13' 15 | - '12' 16 | - '11' 17 | - '10' 18 | - '9.6' 19 | go: 20 | - '1.20' 21 | - '1.19' 22 | - '1.18' 23 | - '1.17' 24 | - '1.16' 25 | - '1.15' 26 | - '1.14' 27 | steps: 28 | - name: setup postgres pre-reqs 29 | run: | 30 | mkdir init 31 | cat < init/root.crt 32 | -----BEGIN CERTIFICATE----- 33 | MIIEBjCCAu6gAwIBAgIJAPizR+OD14YnMA0GCSqGSIb3DQEBCwUAMF4xCzAJBgNV 34 | BAYTAlVTMQ8wDQYDVQQIDAZOZXZhZGExEjAQBgNVBAcMCUxhcyBWZWdhczEaMBgG 35 | A1UECgwRZ2l0aHViLmNvbS9saWIvcHExDjAMBgNVBAMMBXBxIENBMB4XDTIxMDkw 36 | MjAxNTUwMloXDTMxMDkwMzAxNTUwMlowXjELMAkGA1UEBhMCVVMxDzANBgNVBAgM 37 | Bk5ldmFkYTESMBAGA1UEBwwJTGFzIFZlZ2FzMRowGAYDVQQKDBFnaXRodWIuY29t 38 | L2xpYi9wcTEOMAwGA1UEAwwFcHEgQ0EwggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAw 39 | ggEKAoIBAQDb9d6sjdU6GdibGrXRMOHREH3MRUS8T4TFqGgPEGVDP/V5bAZlBSGP 40 | AN0o9DTyVLcbQpBt8zMTw9KeIzIIe5NIVkSmA16lw/YckGhOM+kZIkiDuE6qt5Ia 41 | OQCRMdXkZ8ejG/JUu+rHU8FJZL8DE+jyYherzdjkeVAQ7JfzxAwW2Dl7T/47g337 42 | Pwmf17AEb8ibSqmXyUN7R5NhJQs+hvaYdNagzdx91E1H+qlyBvmiNeasUQljLvZ+ 43 | Y8wAuU79neA+d09O4PBiYwV17rSP6SZCeGE3oLZviL/0KM9Xig88oB+2FmvQ6Zxa 44 | L7SoBlqS+5pBZwpH7eee/wCIKAnJtMAJAgMBAAGjgcYwgcMwDwYDVR0TAQH/BAUw 45 | AwEB/zAdBgNVHQ4EFgQUfIXEczahbcM2cFrwclJF7GbdajkwgZAGA1UdIwSBiDCB 46 | hYAUfIXEczahbcM2cFrwclJF7GbdajmhYqRgMF4xCzAJBgNVBAYTAlVTMQ8wDQYD 47 | VQQIDAZOZXZhZGExEjAQBgNVBAcMCUxhcyBWZWdhczEaMBgGA1UECgwRZ2l0aHVi 48 | LmNvbS9saWIvcHExDjAMBgNVBAMMBXBxIENBggkA+LNH44PXhicwDQYJKoZIhvcN 49 | AQELBQADggEBABFyGgSz2mHVJqYgX1Y+7P+MfKt83cV2uYDGYvXrLG2OGiCilVul 50 | oTBG+8omIMSHOsQZvWMpA5H0tnnlQHrKpKpUyKkSL+Wv5GL0UtBmHX7mVRiaK2l4 51 | q2BjRaQUitp/FH4NSdXtVrMME5T1JBBZHsQkNL3cNRzRKwY/Vj5UGEDxDS7lILUC 52 | e01L4oaK0iKQn4beALU+TvKoAHdPvoxpPpnhkF5ss9HmdcvRktJrKZemDJZswZ7/ 53 | 
+omx8ZPIYYUH5VJJYYE88S7guAt+ZaKIUlel/t6xPbo2ZySFSg9u1uB99n+jTo3L 54 | 1rAxFnN3FCX2jBqgP29xMVmisaN5k04UmyI= 55 | -----END CERTIFICATE----- 56 | CONF 57 | cat < init/server.crt 58 | -----BEGIN CERTIFICATE----- 59 | MIIDqzCCApOgAwIBAgIJAPiewLrOyYipMA0GCSqGSIb3DQEBCwUAMF4xCzAJBgNV 60 | BAYTAlVTMQ8wDQYDVQQIDAZOZXZhZGExEjAQBgNVBAcMCUxhcyBWZWdhczEaMBgG 61 | A1UECgwRZ2l0aHViLmNvbS9saWIvcHExDjAMBgNVBAMMBXBxIENBMB4XDTIxMDkw 62 | MjAxNTUwMloXDTMxMDkwMzAxNTUwMlowTjELMAkGA1UEBhMCVVMxDzANBgNVBAgM 63 | Bk5ldmFkYTESMBAGA1UEBwwJTGFzIFZlZ2FzMRowGAYDVQQKDBFnaXRodWIuY29t 64 | L2xpYi9wcTCCASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEBAKf6H4UzmANN 65 | QiQJe92Mf3ETMYmpZKNNO9DPEHyNLIkag+XwMrBTdcCK0mLvsNCYpXuBN6703KCd 66 | WAFOeMmj7gOsWtvjt5Xm6bRHLgegekXzcG/jDwq/wyzeDzr/YkITuIlG44Lf9lhY 67 | FLwiHlHOWHnwrZaEh6aU//02aQkzyX5INeXl/3TZm2G2eIH6AOxOKOU27MUsyVSQ 68 | 5DE+SDKGcRP4bElueeQWvxAXNMZYb7sVSDdfHI3zr32K4k/tC8x0fZJ5XN/dvl4t 69 | 4N4MrYlmDO5XOrb/gQH1H4iu6+5EMDfZYab4fkThnNFdfFqu4/8Scv7KZ8mWqpKM 70 | fGAjEPctQi0CAwEAAaN8MHowHQYDVR0OBBYEFENExPbmDyFB2AJUdbMvVyhlNPD5 71 | MAkGA1UdEwQCMAAwCwYDVR0PBAQDAgWgMBMGA1UdEQQMMAqCCHBvc3RncmVzMCwG 72 | CWCGSAGG+EIBDQQfFh1PcGVuU1NMIEdlbmVyYXRlZCBDZXJ0aWZpY2F0ZTANBgkq 73 | hkiG9w0BAQsFAAOCAQEAMRVbV8RiEsmp9HAtnVCZmRXMIbgPGrqjeSwk586s4K8v 74 | BSqNCqxv6s5GfCRmDYiqSqeuCVDtUJS1HsTmbxVV7Ke71WMo+xHR1ICGKOa8WGCb 75 | TGsuicG5QZXWaxeMOg4s0qpKmKko0d1aErdVsanU5dkrVS7D6729Ffnzu4lwApk6 76 | invAB67p8u7sojwqRq5ce0vRaG+YFylTrWomF9kauEb8gKbQ9Xc7QfX+h+UH/mq9 77 | Nvdj8LOHp6/82bZdnsYUOtV4lS1IA/qzeXpqBphxqfWabD1yLtkyJyImZKq8uIPp 78 | 0CG4jhObPdWcCkXD6bg3QK3mhwlC79OtFgxWmldCRQ== 79 | -----END CERTIFICATE----- 80 | CONF 81 | cat < init/server.key 82 | -----BEGIN PRIVATE KEY----- 83 | MIIEvQIBADANBgkqhkiG9w0BAQEFAASCBKcwggSjAgEAAoIBAQCn+h+FM5gDTUIk 84 | CXvdjH9xEzGJqWSjTTvQzxB8jSyJGoPl8DKwU3XAitJi77DQmKV7gTeu9NygnVgB 85 | TnjJo+4DrFrb47eV5um0Ry4HoHpF83Bv4w8Kv8Ms3g86/2JCE7iJRuOC3/ZYWBS8 86 | Ih5Rzlh58K2WhIemlP/9NmkJM8l+SDXl5f902ZthtniB+gDsTijlNuzFLMlUkOQx 87 | 
PkgyhnET+GxJbnnkFr8QFzTGWG+7FUg3XxyN8699iuJP7QvMdH2SeVzf3b5eLeDe 88 | DK2JZgzuVzq2/4EB9R+IruvuRDA32WGm+H5E4ZzRXXxaruP/EnL+ymfJlqqSjHxg 89 | IxD3LUItAgMBAAECggEAOE2naQ9tIZYw2EFxikZApVcooJrtx6ropMnzHbx4NBB2 90 | K4mChAXFj184u77ZxmGT/jzGvFcI6LE0wWNbK0NOUV7hKZk/fPhkV3AQZrAMrAu4 91 | IVi7PwAd3JkmA8F8XuebUDA5rDGDsgL8GD9baFJA58abeLs9eMGyuF4XgOUh4bip 92 | hgHa76O2rcDWNY5HZqqRslw75FzlYkB0PCts/UJxSswj70kTTihyOhDlrm2TnyxI 93 | ne54UbGRrpfs9wiheSGLjDG81qZToBHQDwoAnjjZhu1VCaBISuGbgZrxyyRyqdnn 94 | xPW+KczMv04XyvF7v6Pz+bUEppalLXGiXnH5UtWvZQKBgQDTPCdMpNE/hwlq4nAw 95 | Kf42zIBWfbnMLVWYoeDiAOhtl9XAUAXn76xe6Rvo0qeAo67yejdbJfRq3HvGyw+q 96 | 4PS8r9gXYmLYIPQxSoLL5+rFoBCN3qFippfjLB1j32mp7+15KjRj8FF2r6xIN8fu 97 | XatSRsaqmvCWYLDRv/rbHnxwkwKBgQDLkyfFLF7BtwtPWKdqrwOM7ip1UKh+oDBS 98 | vkCQ08aEFRBU7T3jChsx5GbaW6zmsSBwBwcrHclpSkz7n3aq19DDWObJR2p80Fma 99 | rsXeIcvtEpkvT3pVX268P5d+XGs1kxgFunqTysG9yChW+xzcs5MdKBzuMPPn7rL8 100 | MKAzdar6PwKBgEypkzW8x3h/4Moa3k6MnwdyVs2NGaZheaRIc95yJ+jGZzxBjrMr 101 | h+p2PbvU4BfO0AqOkpKRBtDVrlJqlggVVp04UHvEKE16QEW3Xhr0037f5cInX3j3 102 | Lz6yXwRFLAsR2aTUzWjL6jTh8uvO2s/GzQuyRh3a16Ar/WBShY+K0+zjAoGATnLT 103 | xZjWnyHRmu8X/PWakamJ9RFzDPDgDlLAgM8LVgTj+UY/LgnL9wsEU6s2UuP5ExKy 104 | QXxGDGwUhHar/SQTj+Pnc7Mwpw6HKSOmnnY5po8fNusSwml3O9XppEkrC0c236Y/ 105 | 7EobJO5IFVTJh4cv7vFxTJzSsRL8KFD4uzvh+nMCgYEAqY8NBYtIgNJA2B6C6hHF 106 | +bG7v46434ZHFfGTmMQwzE4taVg7YRnzYESAlvK4bAP5ZXR90n7GRGFhrXzoMZ38 107 | r0bw/q9rV+ReGda7/Bjf7ciCKiq0RODcHtf4IaskjPXCoQRGJtgCPLhWPfld6g9v 108 | /HTvO96xv9e3eG/PKSPog94= 109 | -----END PRIVATE KEY----- 110 | CONF 111 | cat < init/hba.sh 112 | cat < /var/lib/postgresql/data/pg_hba.conf 113 | local all all trust 114 | host all postgres all trust 115 | hostnossl all pqgossltest all reject 116 | hostnossl all pqgosslcert all reject 117 | hostssl all pqgossltest all trust 118 | hostssl all pqgosslcert all cert 119 | host all all all trust 120 | EOF 121 | CONF 122 | sudo chown 999:999 ./init/* 123 | sudo chmod 600 ./init/* 124 | 125 | - name: start postgres 126 | 
run: | 127 | docker run -d \ 128 | --name pg \ 129 | -p 5432:5432 \ 130 | -v $(pwd)/init:/init \ 131 | -e POSTGRES_PASSWORD=unused \ 132 | -e POSTGRES_USER=postgres \ 133 | postgres:${{ matrix.postgres }} \ 134 | -c ssl=on \ 135 | -c ssl_ca_file=/init/root.crt \ 136 | -c ssl_cert_file=/init/server.crt \ 137 | -c ssl_key_file=/init/server.key 138 | 139 | - name: configure postgres 140 | run: | 141 | n=0 142 | until [ "$n" -ge 10 ] 143 | do 144 | docker exec pg pg_isready -h localhost && break 145 | n=$((n+1)) 146 | echo waiting for postgres to be ready... 147 | sleep 1 148 | done 149 | docker exec pg bash /init/hba.sh 150 | n=0 151 | until [ "$n" -ge 10 ] 152 | do 153 | docker exec pg su postgres -c '/usr/lib/postgresql/${{ matrix.postgres }}/bin/pg_ctl reload' && break 154 | n=$((n+1)) 155 | echo waiting for postgres to reload... 156 | sleep 1 157 | done 158 | 159 | - name: setup hosts 160 | run: echo '127.0.0.1 postgres' | sudo tee -a /etc/hosts 161 | 162 | - name: create db/roles 163 | run: | 164 | n=0 165 | until [ "$n" -ge 10 ] 166 | do 167 | docker exec pg pg_isready -h localhost && break 168 | n=$((n+1)) 169 | echo waiting for postgres to be ready... 170 | sleep 1 171 | done 172 | docker exec pg createdb -h localhost -U postgres pqgotest 173 | docker exec pg createuser -h localhost -U postgres -DRS pqgossltest 174 | docker exec pg createuser -h localhost -U postgres -DRS pqgosslcert 175 | 176 | - name: check out code into the Go module directory 177 | uses: actions/checkout@v3 178 | 179 | - name: set up go 180 | uses: actions/setup-go@v4 181 | with: 182 | go-version: ${{ matrix.go }} 183 | id: go 184 | 185 | - name: set key perms 186 | run: sudo chmod 600 certs/postgresql.key 187 | 188 | - name: run tests 189 | env: 190 | PGUSER: postgres 191 | PGHOST: localhost 192 | PGPORT: 5432 193 | PQGOSSLTESTS: 1 194 | PQSSLCERTTEST_PATH: certs 195 | run: | 196 | PQTEST_BINARY_PARAMETERS=no go test -race -v ./... 197 | PQTEST_BINARY_PARAMETERS=yes go test -race -v ./... 
198 | 199 | - name: install goimports 200 | run: go get golang.org/x/tools/cmd/goimports 201 | 202 | - name: install staticcheck 203 | run: | 204 | wget https://github.com/dominikh/go-tools/releases/latest/download/staticcheck_linux_amd64.tar.gz -O - | tar -xz staticcheck 205 | 206 | - name: run goimports 207 | run: | 208 | goimports -d -e . | awk '{ print } END { exit NR == 0 ? 0 : 1 }' 209 | 210 | - name: run staticcheck 211 | run: ./staticcheck/staticcheck -go 1.13 ./... 212 | 213 | - name: build 214 | run: go build -v . 215 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .db 2 | *.test 3 | *~ 4 | *.swp 5 | .idea 6 | .vscode -------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | Copyright (c) 2011-2013, 'pq' Contributors 2 | Portions Copyright (C) 2011 Blake Mizerany 3 | 4 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 5 | 6 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 7 | 8 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 9 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # pq - A pure Go postgres driver for Go's database/sql package 2 | 3 | [![GoDoc](https://godoc.org/github.com/lib/pq?status.svg)](https://pkg.go.dev/github.com/lib/pq?tab=doc) 4 | 5 | ## Install 6 | 7 | go get github.com/lib/pq 8 | 9 | ## Features 10 | 11 | * SSL 12 | * Handles bad connections for `database/sql` 13 | * Scan `time.Time` correctly (i.e. `timestamp[tz]`, `time[tz]`, `date`) 14 | * Scan binary blobs correctly (i.e. `bytea`) 15 | * Package for `hstore` support 16 | * COPY FROM support 17 | * pq.ParseURL for converting URLs to connection strings for sql.Open. 18 | * Many libpq-compatible environment variables 19 | * Unix socket support 20 | * Notifications: `LISTEN`/`NOTIFY` 21 | * pgpass support 22 | * GSS (Kerberos) auth 23 | 24 | ## Tests 25 | 26 | `go test` is used for testing. See [TESTS.md](TESTS.md) for more details. 27 | 28 | ## Status 29 | 30 | This package is currently in maintenance mode, which means: 31 | 1. It generally does not accept new features. 32 | 2. It does accept bug fixes and version compatibility changes provided by the community. 33 | 3. Maintainers usually do not resolve reported issues. 34 | 4. Community members are encouraged to help each other with reported issues. 35 | 36 | For users that require new features or reliable resolution of reported bugs, we recommend using [pgx](https://github.com/jackc/pgx) which is under active development.
37 | -------------------------------------------------------------------------------- /TESTS.md: -------------------------------------------------------------------------------- 1 | # Tests 2 | 3 | ## Running Tests 4 | 5 | `go test` is used for testing. A running PostgreSQL 6 | server is required, with the ability to log in. By 7 | default the tests connect to the database "pqgotest" on 8 | "localhost", but these can be overridden using [environment 9 | variables](https://www.postgresql.org/docs/current/libpq-envars.html). 10 | 11 | Example: 12 | 13 | PGHOST=/run/postgresql go test 14 | 15 | ## Benchmarks 16 | 17 | A benchmark suite can be run as part of the tests: 18 | 19 | go test -bench . 20 | 21 | ## Example setup (Docker) 22 | 23 | Run a postgres container: 24 | 25 | ``` 26 | docker run -p 5432:5432 -e POSTGRES_HOST_AUTH_METHOD=trust postgres 27 | ``` 28 | 29 | Run tests: 30 | 31 | ``` 32 | PGHOST=localhost PGPORT=5432 PGUSER=postgres PGSSLMODE=disable PGDATABASE=postgres go test 33 | ``` 34 | -------------------------------------------------------------------------------- /auth/kerberos/go.mod: -------------------------------------------------------------------------------- 1 | module github.com/lib/pq/auth/kerberos 2 | 3 | go 1.13 4 | 5 | require ( 6 | github.com/alexbrainman/sspi v0.0.0-20180613141037-e580b900e9f5 7 | github.com/jcmturner/gokrb5/v8 v8.2.0 8 | ) 9 | -------------------------------------------------------------------------------- /auth/kerberos/go.sum: -------------------------------------------------------------------------------- 1 | github.com/alexbrainman/sspi v0.0.0-20180613141037-e580b900e9f5 h1:P5U+E4x5OkVEKQDklVPmzs71WM56RTTRqV4OrDC//Y4= 2 | github.com/alexbrainman/sspi v0.0.0-20180613141037-e580b900e9f5/go.mod h1:976q2ETgjT2snVCf2ZaBnyBbVoPERGjUz+0sofzEfro= 3 | github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8= 4 | github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 5 |
github.com/gorilla/securecookie v1.1.1 h1:miw7JPhV+b/lAHSXz4qd/nN9jRiAFV5FwjeKyCS8BvQ= 6 | github.com/gorilla/securecookie v1.1.1/go.mod h1:ra0sb63/xPlUeL+yeDciTfxMRAA+MP+HVt/4epWDjd4= 7 | github.com/gorilla/sessions v1.2.0 h1:S7P+1Hm5V/AT9cjEcUD5uDaQSX0OE577aCXgoaKpYbQ= 8 | github.com/gorilla/sessions v1.2.0/go.mod h1:dk2InVEVJ0sfLlnXv9EAgkf6ecYs/i80K/zI+bUmuGM= 9 | github.com/hashicorp/go-uuid v1.0.2 h1:cfejS+Tpcp13yd5nYHWDI6qVCny6wyX2Mt5SGur2IGE= 10 | github.com/hashicorp/go-uuid v1.0.2/go.mod h1:6SBZvOh/SIDV7/2o3Jml5SYk/TvGqwFJ/bN7x4byOro= 11 | github.com/jcmturner/aescts/v2 v2.0.0 h1:9YKLH6ey7H4eDBXW8khjYslgyqG2xZikXP0EQFKrle8= 12 | github.com/jcmturner/aescts/v2 v2.0.0/go.mod h1:AiaICIRyfYg35RUkr8yESTqvSy7csK90qZ5xfvvsoNs= 13 | github.com/jcmturner/dnsutils/v2 v2.0.0 h1:lltnkeZGL0wILNvrNiVCR6Ro5PGU/SeBvVO/8c/iPbo= 14 | github.com/jcmturner/dnsutils/v2 v2.0.0/go.mod h1:b0TnjGOvI/n42bZa+hmXL+kFJZsFT7G4t3HTlQ184QM= 15 | github.com/jcmturner/gofork v1.0.0 h1:J7uCkflzTEhUZ64xqKnkDxq3kzc96ajM1Gli5ktUem8= 16 | github.com/jcmturner/gofork v1.0.0/go.mod h1:MK8+TM0La+2rjBD4jE12Kj1pCCxK7d2LK/UM3ncEo0o= 17 | github.com/jcmturner/goidentity/v6 v6.0.1 h1:VKnZd2oEIMorCTsFBnJWbExfNN7yZr3EhJAxwOkZg6o= 18 | github.com/jcmturner/goidentity/v6 v6.0.1/go.mod h1:X1YW3bgtvwAXju7V3LCIMpY0Gbxyjn/mY9zx4tFonSg= 19 | github.com/jcmturner/gokrb5/v8 v8.2.0 h1:lzPl/30ZLkTveYsYZPKMcgXc8MbnE6RsTd4F9KgiLtk= 20 | github.com/jcmturner/gokrb5/v8 v8.2.0/go.mod h1:T1hnNppQsBtxW0tCHMHTkAt8n/sABdzZgZdoFrZaZNM= 21 | github.com/jcmturner/rpc/v2 v2.0.2 h1:gMB4IwRXYsWw4Bc6o/az2HJgFUA1ffSh90i26ZJ6Xl0= 22 | github.com/jcmturner/rpc/v2 v2.0.2/go.mod h1:VUJYCIDm3PVOEHw8sgt091/20OJjskO/YJki3ELg/Hc= 23 | github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= 24 | github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= 25 | github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= 26 | github.com/stretchr/testify 
v1.4.0 h1:2E4SXV/wtOkTonXsotYi4li6zVWxYlZuYNCXe9XRJyk= 27 | github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= 28 | golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= 29 | golang.org/x/crypto v0.0.0-20200117160349-530e935923ad h1:Jh8cai0fqIK+f6nG0UgPW5wFk8wmiMhM3AyciDBdtQg= 30 | golang.org/x/crypto v0.0.0-20200117160349-530e935923ad/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= 31 | golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= 32 | golang.org/x/net v0.0.0-20200114155413-6afb5195e5aa h1:F+8P+gmewFQYRk6JoLQLwjBCTu3mcIURZfNkVweuRKA= 33 | golang.org/x/net v0.0.0-20200114155413-6afb5195e5aa/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= 34 | golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= 35 | golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= 36 | golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= 37 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= 38 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= 39 | gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw= 40 | gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= 41 | -------------------------------------------------------------------------------- /auth/kerberos/krb.go: -------------------------------------------------------------------------------- 1 | package kerberos 2 | 3 | import ( 4 | "net" 5 | "strings" 6 | ) 7 | 8 | /* 9 | * Find the A record associated with a hostname 10 | * In general, hostnames supplied to the driver should be 11 | * canonicalized because the KDC usually only has one 12 | * principal and not one per potential alias of a 
host. 13 | */ 14 | func canonicalizeHostname(host string) (string, error) { 15 | canon := host 16 | 17 | name, err := net.LookupCNAME(host) 18 | if err != nil { 19 | return "", err 20 | } 21 | 22 | name = strings.TrimSuffix(name, ".") 23 | 24 | if name != "" { 25 | canon = name 26 | } 27 | 28 | return canon, nil 29 | } 30 | -------------------------------------------------------------------------------- /auth/kerberos/krb_unix.go: -------------------------------------------------------------------------------- 1 | //go:build !windows 2 | // +build !windows 3 | 4 | package kerberos 5 | 6 | import ( 7 | "fmt" 8 | "os" 9 | "os/user" 10 | "strings" 11 | 12 | "github.com/jcmturner/gokrb5/v8/client" 13 | "github.com/jcmturner/gokrb5/v8/config" 14 | "github.com/jcmturner/gokrb5/v8/credentials" 15 | "github.com/jcmturner/gokrb5/v8/spnego" 16 | ) 17 | 18 | /* 19 | * UNIX Kerberos support, using jcmturner's pure-go 20 | * implementation 21 | */ 22 | 23 | // GSS implements the pq.GSS interface. 24 | type GSS struct { 25 | cli *client.Client 26 | } 27 | 28 | // NewGSS creates a new GSS provider. 
29 | func NewGSS() (*GSS, error) { 30 | g := &GSS{} 31 | err := g.init() 32 | 33 | if err != nil { 34 | return nil, err 35 | } 36 | 37 | return g, nil 38 | } 39 | 40 | func (g *GSS) init() error { 41 | cfgPath, ok := os.LookupEnv("KRB5_CONFIG") 42 | if !ok { 43 | cfgPath = "/etc/krb5.conf" 44 | } 45 | 46 | cfg, err := config.Load(cfgPath) 47 | if err != nil { 48 | return err 49 | } 50 | 51 | u, err := user.Current() 52 | if err != nil { 53 | return err 54 | } 55 | 56 | ccpath := "/tmp/krb5cc_" + u.Uid 57 | 58 | ccname := os.Getenv("KRB5CCNAME") 59 | if strings.HasPrefix(ccname, "FILE:") { 60 | ccpath = strings.SplitN(ccname, ":", 2)[1] 61 | } 62 | 63 | ccache, err := credentials.LoadCCache(ccpath) 64 | if err != nil { 65 | return err 66 | } 67 | 68 | cl, err := client.NewFromCCache(ccache, cfg, client.DisablePAFXFAST(true)) 69 | if err != nil { 70 | return err 71 | } 72 | 73 | cl.Login() 74 | 75 | g.cli = cl 76 | 77 | return nil 78 | } 79 | 80 | // GetInitToken implements the GSS interface. 81 | func (g *GSS) GetInitToken(host string, service string) ([]byte, error) { 82 | 83 | // Resolve the hostname down to an 'A' record, if required (usually, it is) 84 | if g.cli.Config.LibDefaults.DNSCanonicalizeHostname { 85 | var err error 86 | host, err = canonicalizeHostname(host) 87 | if err != nil { 88 | return nil, err 89 | } 90 | } 91 | 92 | spn := service + "/" + host 93 | 94 | return g.GetInitTokenFromSpn(spn) 95 | } 96 | 97 | // GetInitTokenFromSpn implements the GSS interface. 
98 | func (g *GSS) GetInitTokenFromSpn(spn string) ([]byte, error) { 99 | s := spnego.SPNEGOClient(g.cli, spn) 100 | 101 | st, err := s.InitSecContext() 102 | if err != nil { 103 | return nil, fmt.Errorf("kerberos error (InitSecContext): %s", err.Error()) 104 | } 105 | 106 | b, err := st.Marshal() 107 | if err != nil { 108 | return nil, fmt.Errorf("kerberos error (Marshaling token): %s", err.Error()) 109 | } 110 | 111 | return b, nil 112 | } 113 | 114 | // Continue implements the GSS interface. 115 | func (g *GSS) Continue(inToken []byte) (done bool, outToken []byte, err error) { 116 | t := &spnego.SPNEGOToken{} 117 | err = t.Unmarshal(inToken) 118 | if err != nil { 119 | return true, nil, fmt.Errorf("kerberos error (Unmarshaling token): %s", err.Error()) 120 | } 121 | 122 | state := t.NegTokenResp.State() 123 | if state != spnego.NegStateAcceptCompleted { 124 | return true, nil, fmt.Errorf("kerberos: expected state 'Completed' - got %d", state) 125 | } 126 | 127 | return true, nil, nil 128 | } 129 | -------------------------------------------------------------------------------- /auth/kerberos/krb_windows.go: -------------------------------------------------------------------------------- 1 | //go:build windows 2 | // +build windows 3 | 4 | package kerberos 5 | 6 | import ( 7 | "github.com/alexbrainman/sspi" 8 | "github.com/alexbrainman/sspi/negotiate" 9 | ) 10 | 11 | // GSS implements the pq.GSS interface. 12 | type GSS struct { 13 | creds *sspi.Credentials 14 | ctx *negotiate.ClientContext 15 | } 16 | 17 | // NewGSS creates a new GSS provider. 18 | func NewGSS() (*GSS, error) { 19 | g := &GSS{} 20 | err := g.init() 21 | 22 | if err != nil { 23 | return nil, err 24 | } 25 | 26 | return g, nil 27 | } 28 | 29 | func (g *GSS) init() error { 30 | creds, err := negotiate.AcquireCurrentUserCredentials() 31 | if err != nil { 32 | return err 33 | } 34 | 35 | g.creds = creds 36 | return nil 37 | } 38 | 39 | // GetInitToken implements the GSS interface. 
40 | func (g *GSS) GetInitToken(host string, service string) ([]byte, error) { 41 | 42 | host, err := canonicalizeHostname(host) 43 | if err != nil { 44 | return nil, err 45 | } 46 | 47 | spn := service + "/" + host 48 | 49 | return g.GetInitTokenFromSpn(spn) 50 | } 51 | 52 | // GetInitTokenFromSpn implements the GSS interface. 53 | func (g *GSS) GetInitTokenFromSpn(spn string) ([]byte, error) { 54 | ctx, token, err := negotiate.NewClientContext(g.creds, spn) 55 | if err != nil { 56 | return nil, err 57 | } 58 | 59 | g.ctx = ctx 60 | 61 | return token, nil 62 | } 63 | 64 | // Continue implements the GSS interface. 65 | func (g *GSS) Continue(inToken []byte) (done bool, outToken []byte, err error) { 66 | return g.ctx.Update(inToken) 67 | } 68 | -------------------------------------------------------------------------------- /bench_test.go: -------------------------------------------------------------------------------- 1 | package pq 2 | 3 | import ( 4 | "bufio" 5 | "bytes" 6 | "context" 7 | "database/sql" 8 | "database/sql/driver" 9 | "io" 10 | "math/rand" 11 | "net" 12 | "runtime" 13 | "strconv" 14 | "strings" 15 | "sync" 16 | "testing" 17 | "time" 18 | 19 | "github.com/lib/pq/oid" 20 | ) 21 | 22 | var ( 23 | selectStringQuery = "SELECT '" + strings.Repeat("0123456789", 10) + "'" 24 | selectSeriesQuery = "SELECT generate_series(1, 100)" 25 | ) 26 | 27 | func BenchmarkSelectString(b *testing.B) { 28 | var result string 29 | benchQuery(b, selectStringQuery, &result) 30 | } 31 | 32 | func BenchmarkSelectSeries(b *testing.B) { 33 | var result int 34 | benchQuery(b, selectSeriesQuery, &result) 35 | } 36 | 37 | func benchQuery(b *testing.B, query string, result interface{}) { 38 | b.StopTimer() 39 | db := openTestConn(b) 40 | defer db.Close() 41 | b.StartTimer() 42 | 43 | for i := 0; i < b.N; i++ { 44 | benchQueryLoop(b, db, query, result) 45 | } 46 | } 47 | 48 | func benchQueryLoop(b *testing.B, db *sql.DB, query string, result interface{}) { 49 | rows, err := 
db.Query(query) 50 | if err != nil { 51 | b.Fatal(err) 52 | } 53 | defer rows.Close() 54 | for rows.Next() { 55 | err = rows.Scan(result) 56 | if err != nil { 57 | b.Fatal("failed to scan", err) 58 | } 59 | } 60 | } 61 | 62 | // reading from circularConn yields content[:prefixLen] once, followed by 63 | // content[prefixLen:] over and over again. It never returns EOF. 64 | type circularConn struct { 65 | content string 66 | prefixLen int 67 | pos int 68 | net.Conn // for all other net.Conn methods that will never be called 69 | } 70 | 71 | func (r *circularConn) Read(b []byte) (n int, err error) { 72 | n = copy(b, r.content[r.pos:]) 73 | r.pos += n 74 | if r.pos >= len(r.content) { 75 | r.pos = r.prefixLen 76 | } 77 | return 78 | } 79 | 80 | func (r *circularConn) Write(b []byte) (n int, err error) { return len(b), nil } 81 | 82 | func (r *circularConn) Close() error { return nil } 83 | 84 | func fakeConn(content string, prefixLen int) *conn { 85 | c := &circularConn{content: content, prefixLen: prefixLen} 86 | return &conn{buf: bufio.NewReader(c), c: c} 87 | } 88 | 89 | // This benchmark is meant to be the same as BenchmarkSelectString, but takes 90 | // out some of the factors this package can't control. The numbers are less noisy, 91 | // but also the costs of network communication aren't accurately represented. 
92 | func BenchmarkMockSelectString(b *testing.B) { 93 | b.StopTimer() 94 | // taken from a recorded run of BenchmarkSelectString 95 | // See: http://www.postgresql.org/docs/current/static/protocol-message-formats.html 96 | const response = "1\x00\x00\x00\x04" + 97 | "t\x00\x00\x00\x06\x00\x00" + 98 | "T\x00\x00\x00!\x00\x01?column?\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02\xc1\xff\xfe\xff\xff\xff\xff\x00\x00" + 99 | "Z\x00\x00\x00\x05I" + 100 | "2\x00\x00\x00\x04" + 101 | "D\x00\x00\x00n\x00\x01\x00\x00\x00d0123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789" + 102 | "C\x00\x00\x00\rSELECT 1\x00" + 103 | "Z\x00\x00\x00\x05I" + 104 | "3\x00\x00\x00\x04" + 105 | "Z\x00\x00\x00\x05I" 106 | c := fakeConn(response, 0) 107 | b.StartTimer() 108 | 109 | for i := 0; i < b.N; i++ { 110 | benchMockQuery(b, c, selectStringQuery) 111 | } 112 | } 113 | 114 | var seriesRowData = func() string { 115 | var buf bytes.Buffer 116 | for i := 1; i <= 100; i++ { 117 | digits := byte(2) 118 | if i >= 100 { 119 | digits = 3 120 | } else if i < 10 { 121 | digits = 1 122 | } 123 | buf.WriteString("D\x00\x00\x00") 124 | buf.WriteByte(10 + digits) 125 | buf.WriteString("\x00\x01\x00\x00\x00") 126 | buf.WriteByte(digits) 127 | buf.WriteString(strconv.Itoa(i)) 128 | } 129 | return buf.String() 130 | }() 131 | 132 | func BenchmarkMockSelectSeries(b *testing.B) { 133 | b.StopTimer() 134 | var response = "1\x00\x00\x00\x04" + 135 | "t\x00\x00\x00\x06\x00\x00" + 136 | "T\x00\x00\x00!\x00\x01?column?\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02\xc1\xff\xfe\xff\xff\xff\xff\x00\x00" + 137 | "Z\x00\x00\x00\x05I" + 138 | "2\x00\x00\x00\x04" + 139 | seriesRowData + 140 | "C\x00\x00\x00\x0fSELECT 100\x00" + 141 | "Z\x00\x00\x00\x05I" + 142 | "3\x00\x00\x00\x04" + 143 | "Z\x00\x00\x00\x05I" 144 | c := fakeConn(response, 0) 145 | b.StartTimer() 146 | 147 | for i := 0; i < b.N; i++ { 148 | benchMockQuery(b, c, selectSeriesQuery) 149 | } 150 | } 151 | 152 | func 
benchMockQuery(b *testing.B, c *conn, query string) { 153 | stmt, err := c.Prepare(query) 154 | if err != nil { 155 | b.Fatal(err) 156 | } 157 | defer stmt.Close() 158 | rows, err := stmt.(driver.StmtQueryContext).QueryContext(context.Background(), nil) 159 | if err != nil { 160 | b.Fatal(err) 161 | } 162 | defer rows.Close() 163 | var dest [1]driver.Value 164 | for { 165 | if err := rows.Next(dest[:]); err != nil { 166 | if err == io.EOF { 167 | break 168 | } 169 | b.Fatal(err) 170 | } 171 | } 172 | } 173 | 174 | func BenchmarkPreparedSelectString(b *testing.B) { 175 | var result string 176 | benchPreparedQuery(b, selectStringQuery, &result) 177 | } 178 | 179 | func BenchmarkPreparedSelectSeries(b *testing.B) { 180 | var result int 181 | benchPreparedQuery(b, selectSeriesQuery, &result) 182 | } 183 | 184 | func benchPreparedQuery(b *testing.B, query string, result interface{}) { 185 | b.StopTimer() 186 | db := openTestConn(b) 187 | defer db.Close() 188 | stmt, err := db.Prepare(query) 189 | if err != nil { 190 | b.Fatal(err) 191 | } 192 | defer stmt.Close() 193 | b.StartTimer() 194 | 195 | for i := 0; i < b.N; i++ { 196 | benchPreparedQueryLoop(b, db, stmt, result) 197 | } 198 | } 199 | 200 | func benchPreparedQueryLoop(b *testing.B, db *sql.DB, stmt *sql.Stmt, result interface{}) { 201 | rows, err := stmt.Query() 202 | if err != nil { 203 | b.Fatal(err) 204 | } 205 | if !rows.Next() { 206 | rows.Close() 207 | b.Fatal("no rows") 208 | } 209 | defer rows.Close() 210 | for rows.Next() { 211 | err = rows.Scan(&result) 212 | if err != nil { 213 | b.Fatal("failed to scan") 214 | } 215 | } 216 | } 217 | 218 | // See the comment for BenchmarkMockSelectString. 
219 | func BenchmarkMockPreparedSelectString(b *testing.B) { 220 | b.StopTimer() 221 | const parseResponse = "1\x00\x00\x00\x04" + 222 | "t\x00\x00\x00\x06\x00\x00" + 223 | "T\x00\x00\x00!\x00\x01?column?\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02\xc1\xff\xfe\xff\xff\xff\xff\x00\x00" + 224 | "Z\x00\x00\x00\x05I" 225 | const responses = parseResponse + 226 | "2\x00\x00\x00\x04" + 227 | "D\x00\x00\x00n\x00\x01\x00\x00\x00d0123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789" + 228 | "C\x00\x00\x00\rSELECT 1\x00" + 229 | "Z\x00\x00\x00\x05I" 230 | c := fakeConn(responses, len(parseResponse)) 231 | 232 | stmt, err := c.Prepare(selectStringQuery) 233 | if err != nil { 234 | b.Fatal(err) 235 | } 236 | b.StartTimer() 237 | 238 | for i := 0; i < b.N; i++ { 239 | benchPreparedMockQuery(b, c, stmt) 240 | } 241 | } 242 | 243 | func BenchmarkMockPreparedSelectSeries(b *testing.B) { 244 | b.StopTimer() 245 | const parseResponse = "1\x00\x00\x00\x04" + 246 | "t\x00\x00\x00\x06\x00\x00" + 247 | "T\x00\x00\x00!\x00\x01?column?\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02\xc1\xff\xfe\xff\xff\xff\xff\x00\x00" + 248 | "Z\x00\x00\x00\x05I" 249 | var responses = parseResponse + 250 | "2\x00\x00\x00\x04" + 251 | seriesRowData + 252 | "C\x00\x00\x00\x0fSELECT 100\x00" + 253 | "Z\x00\x00\x00\x05I" 254 | c := fakeConn(responses, len(parseResponse)) 255 | 256 | stmt, err := c.Prepare(selectSeriesQuery) 257 | if err != nil { 258 | b.Fatal(err) 259 | } 260 | b.StartTimer() 261 | 262 | for i := 0; i < b.N; i++ { 263 | benchPreparedMockQuery(b, c, stmt) 264 | } 265 | } 266 | 267 | func benchPreparedMockQuery(b *testing.B, c *conn, stmt driver.Stmt) { 268 | rows, err := stmt.(driver.StmtQueryContext).QueryContext(context.Background(), nil) 269 | if err != nil { 270 | b.Fatal(err) 271 | } 272 | defer rows.Close() 273 | var dest [1]driver.Value 274 | for { 275 | if err := rows.Next(dest[:]); err != nil { 276 | if err == io.EOF { 277 | break 278 | } 279 | 
b.Fatal(err) 280 | } 281 | } 282 | } 283 | 284 | func BenchmarkEncodeInt64(b *testing.B) { 285 | for i := 0; i < b.N; i++ { 286 | encode(¶meterStatus{}, int64(1234), oid.T_int8) 287 | } 288 | } 289 | 290 | func BenchmarkEncodeFloat64(b *testing.B) { 291 | for i := 0; i < b.N; i++ { 292 | encode(¶meterStatus{}, 3.14159, oid.T_float8) 293 | } 294 | } 295 | 296 | var testByteString = []byte("abcdefghijklmnopqrstuvwxyz") 297 | 298 | func BenchmarkEncodeByteaHex(b *testing.B) { 299 | for i := 0; i < b.N; i++ { 300 | encode(¶meterStatus{serverVersion: 90000}, testByteString, oid.T_bytea) 301 | } 302 | } 303 | func BenchmarkEncodeByteaEscape(b *testing.B) { 304 | for i := 0; i < b.N; i++ { 305 | encode(¶meterStatus{serverVersion: 84000}, testByteString, oid.T_bytea) 306 | } 307 | } 308 | 309 | func BenchmarkEncodeBool(b *testing.B) { 310 | for i := 0; i < b.N; i++ { 311 | encode(¶meterStatus{}, true, oid.T_bool) 312 | } 313 | } 314 | 315 | var testTimestamptz = time.Date(2001, time.January, 1, 0, 0, 0, 0, time.Local) 316 | 317 | func BenchmarkEncodeTimestamptz(b *testing.B) { 318 | for i := 0; i < b.N; i++ { 319 | encode(¶meterStatus{}, testTimestamptz, oid.T_timestamptz) 320 | } 321 | } 322 | 323 | var testIntBytes = []byte("1234") 324 | 325 | func BenchmarkDecodeInt64(b *testing.B) { 326 | for i := 0; i < b.N; i++ { 327 | decode(¶meterStatus{}, testIntBytes, oid.T_int8, formatText) 328 | } 329 | } 330 | 331 | var testFloatBytes = []byte("3.14159") 332 | 333 | func BenchmarkDecodeFloat64(b *testing.B) { 334 | for i := 0; i < b.N; i++ { 335 | decode(¶meterStatus{}, testFloatBytes, oid.T_float8, formatText) 336 | } 337 | } 338 | 339 | var testBoolBytes = []byte{'t'} 340 | 341 | func BenchmarkDecodeBool(b *testing.B) { 342 | for i := 0; i < b.N; i++ { 343 | decode(¶meterStatus{}, testBoolBytes, oid.T_bool, formatText) 344 | } 345 | } 346 | 347 | func TestDecodeBool(t *testing.T) { 348 | db := openTestConn(t) 349 | rows, err := db.Query("select true") 350 | if err != nil { 
351 | t.Fatal(err) 352 | } 353 | rows.Close() 354 | } 355 | 356 | var testTimestamptzBytes = []byte("2013-09-17 22:15:32.360754-07") 357 | 358 | func BenchmarkDecodeTimestamptz(b *testing.B) { 359 | for i := 0; i < b.N; i++ { 360 | decode(¶meterStatus{}, testTimestamptzBytes, oid.T_timestamptz, formatText) 361 | } 362 | } 363 | 364 | func BenchmarkDecodeTimestamptzMultiThread(b *testing.B) { 365 | oldProcs := runtime.GOMAXPROCS(0) 366 | defer runtime.GOMAXPROCS(oldProcs) 367 | runtime.GOMAXPROCS(runtime.NumCPU()) 368 | globalLocationCache = newLocationCache() 369 | 370 | f := func(wg *sync.WaitGroup, loops int) { 371 | defer wg.Done() 372 | for i := 0; i < loops; i++ { 373 | decode(¶meterStatus{}, testTimestamptzBytes, oid.T_timestamptz, formatText) 374 | } 375 | } 376 | 377 | wg := &sync.WaitGroup{} 378 | b.ResetTimer() 379 | for j := 0; j < 10; j++ { 380 | wg.Add(1) 381 | go f(wg, b.N/10) 382 | } 383 | wg.Wait() 384 | } 385 | 386 | func BenchmarkLocationCache(b *testing.B) { 387 | globalLocationCache = newLocationCache() 388 | for i := 0; i < b.N; i++ { 389 | globalLocationCache.getLocation(rand.Intn(10000)) 390 | } 391 | } 392 | 393 | func BenchmarkLocationCacheMultiThread(b *testing.B) { 394 | oldProcs := runtime.GOMAXPROCS(0) 395 | defer runtime.GOMAXPROCS(oldProcs) 396 | runtime.GOMAXPROCS(runtime.NumCPU()) 397 | globalLocationCache = newLocationCache() 398 | 399 | f := func(wg *sync.WaitGroup, loops int) { 400 | defer wg.Done() 401 | for i := 0; i < loops; i++ { 402 | globalLocationCache.getLocation(rand.Intn(10000)) 403 | } 404 | } 405 | 406 | wg := &sync.WaitGroup{} 407 | b.ResetTimer() 408 | for j := 0; j < 10; j++ { 409 | wg.Add(1) 410 | go f(wg, b.N/10) 411 | } 412 | wg.Wait() 413 | } 414 | 415 | // Stress test the performance of parsing results from the wire. 
416 | func BenchmarkResultParsing(b *testing.B) { 417 | b.StopTimer() 418 | 419 | db := openTestConn(b) 420 | defer db.Close() 421 | _, err := db.Exec("BEGIN") 422 | if err != nil { 423 | b.Fatal(err) 424 | } 425 | 426 | b.StartTimer() 427 | for i := 0; i < b.N; i++ { 428 | res, err := db.Query("SELECT generate_series(1, 50000)") 429 | if err != nil { 430 | b.Fatal(err) 431 | } 432 | res.Close() 433 | } 434 | } 435 | -------------------------------------------------------------------------------- /buf.go: -------------------------------------------------------------------------------- 1 | package pq 2 | 3 | import ( 4 | "bytes" 5 | "encoding/binary" 6 | 7 | "github.com/lib/pq/oid" 8 | ) 9 | 10 | type readBuf []byte 11 | 12 | func (b *readBuf) int32() (n int) { 13 | n = int(int32(binary.BigEndian.Uint32(*b))) 14 | *b = (*b)[4:] 15 | return 16 | } 17 | 18 | func (b *readBuf) oid() (n oid.Oid) { 19 | n = oid.Oid(binary.BigEndian.Uint32(*b)) 20 | *b = (*b)[4:] 21 | return 22 | } 23 | 24 | // N.B: this is actually an unsigned 16-bit integer, unlike int32 25 | func (b *readBuf) int16() (n int) { 26 | n = int(binary.BigEndian.Uint16(*b)) 27 | *b = (*b)[2:] 28 | return 29 | } 30 | 31 | func (b *readBuf) string() string { 32 | i := bytes.IndexByte(*b, 0) 33 | if i < 0 { 34 | errorf("invalid message format; expected string terminator") 35 | } 36 | s := (*b)[:i] 37 | *b = (*b)[i+1:] 38 | return string(s) 39 | } 40 | 41 | func (b *readBuf) next(n int) (v []byte) { 42 | v = (*b)[:n] 43 | *b = (*b)[n:] 44 | return 45 | } 46 | 47 | func (b *readBuf) byte() byte { 48 | return b.next(1)[0] 49 | } 50 | 51 | type writeBuf struct { 52 | buf []byte 53 | pos int 54 | } 55 | 56 | func (b *writeBuf) int32(n int) { 57 | x := make([]byte, 4) 58 | binary.BigEndian.PutUint32(x, uint32(n)) 59 | b.buf = append(b.buf, x...) 60 | } 61 | 62 | func (b *writeBuf) int16(n int) { 63 | x := make([]byte, 2) 64 | binary.BigEndian.PutUint16(x, uint16(n)) 65 | b.buf = append(b.buf, x...) 
66 | } 67 | 68 | func (b *writeBuf) string(s string) { 69 | b.buf = append(append(b.buf, s...), '\000') 70 | } 71 | 72 | func (b *writeBuf) byte(c byte) { 73 | b.buf = append(b.buf, c) 74 | } 75 | 76 | func (b *writeBuf) bytes(v []byte) { 77 | b.buf = append(b.buf, v...) 78 | } 79 | 80 | func (b *writeBuf) wrap() []byte { 81 | p := b.buf[b.pos:] 82 | binary.BigEndian.PutUint32(p, uint32(len(p))) 83 | return b.buf 84 | } 85 | 86 | func (b *writeBuf) next(c byte) { 87 | p := b.buf[b.pos:] 88 | binary.BigEndian.PutUint32(p, uint32(len(p))) 89 | b.pos = len(b.buf) + 1 90 | b.buf = append(b.buf, c, 0, 0, 0, 0) 91 | } 92 | -------------------------------------------------------------------------------- /buf_test.go: -------------------------------------------------------------------------------- 1 | package pq 2 | 3 | import "testing" 4 | 5 | func Benchmark_writeBuf_string(b *testing.B) { 6 | var buf writeBuf 7 | const s = "foo" 8 | 9 | b.ReportAllocs() 10 | b.ResetTimer() 11 | 12 | for i := 0; i < b.N; i++ { 13 | buf.string(s) 14 | buf.buf = buf.buf[:0] 15 | } 16 | } 17 | -------------------------------------------------------------------------------- /certs/Makefile: -------------------------------------------------------------------------------- 1 | .PHONY: all root-ssl server-ssl client-ssl 2 | 3 | # Rebuilds self-signed root/server/client certs/keys in a consistent way 4 | all: root-ssl server-ssl client-ssl 5 | rm -f .srl 6 | 7 | root-ssl: 8 | openssl req -new -sha256 -nodes -newkey rsa:2048 \ 9 | -config ./certs/root.cnf \ 10 | -keyout /tmp/root.key \ 11 | -out /tmp/root.csr 12 | openssl x509 -req -days 3653 -sha256 \ 13 | -in /tmp/root.csr \ 14 | -extfile /etc/ssl/openssl.cnf -extensions v3_ca \ 15 | -signkey /tmp/root.key \ 16 | -out ./certs/root.crt 17 | 18 | server-ssl: 19 | openssl req -new -sha256 -nodes -newkey rsa:2048 \ 20 | -config ./certs/server.cnf \ 21 | -keyout ./certs/server.key \ 22 | -out /tmp/server.csr 23 | openssl x509 -req -days 3653 -sha256 \ 
24 | -extfile ./certs/server.cnf -extensions req_ext \ 25 | -CA ./certs/root.crt -CAkey /tmp/root.key -CAcreateserial \ 26 | -in /tmp/server.csr \ 27 | -out ./certs/server.crt 28 | 29 | client-ssl: 30 | openssl req -new -sha256 -nodes -newkey rsa:2048 \ 31 | -config ./certs/postgresql.cnf \ 32 | -keyout ./certs/postgresql.key \ 33 | -out /tmp/postgresql.csr 34 | openssl x509 -req -days 3653 -sha256 \ 35 | -CA ./certs/root.crt -CAkey /tmp/root.key -CAcreateserial \ 36 | -in /tmp/postgresql.csr \ 37 | -out ./certs/postgresql.crt 38 | -------------------------------------------------------------------------------- /certs/README: -------------------------------------------------------------------------------- 1 | This directory contains certificates and private keys for testing some 2 | SSL-related functionality in Travis. Do NOT use these certificates for 3 | anything other than testing. 4 | -------------------------------------------------------------------------------- /certs/bogus_root.crt: -------------------------------------------------------------------------------- 1 | -----BEGIN CERTIFICATE----- 2 | MIIDBjCCAe6gAwIBAgIQSnDYp/Naet9HOZljF5PuwDANBgkqhkiG9w0BAQsFADAr 3 | MRIwEAYDVQQKEwlDb2Nrcm9hY2gxFTATBgNVBAMTDENvY2tyb2FjaCBDQTAeFw0x 4 | NjAyMDcxNjQ0MzdaFw0xNzAyMDYxNjQ0MzdaMCsxEjAQBgNVBAoTCUNvY2tyb2Fj 5 | aDEVMBMGA1UEAxMMQ29ja3JvYWNoIENBMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8A 6 | MIIBCgKCAQEAxdln3/UdgP7ayA/G1kT7upjLe4ERwQjYQ25q0e1+vgsB5jhiirxJ 7 | e0+WkhhYu/mwoSAXzvlsbZ2PWFyfdanZeD/Lh6SvIeWXVVaPcWVWL1TEcoN2jr5+ 8 | E85MMHmbbmaT2he8s6br2tM/UZxyTQ2XRprIzApbDssyw1c0Yufcpu3C6267FLEl 9 | IfcWrzDhnluFhthhtGXv3ToD8IuMScMC5qlKBXtKmD1B5x14ngO/ecNJ+OlEi0HU 10 | mavK4KWgI2rDXRZ2EnCpyTZdkc3kkRnzKcg653oOjMDRZdrhfIrha+Jq38ACsUmZ 11 | Su7Sp5jkIHOCO8Zg+l6GKVSq37dKMapD8wIDAQABoyYwJDAOBgNVHQ8BAf8EBAMC 12 | AuQwEgYDVR0TAQH/BAgwBgEB/wIBATANBgkqhkiG9w0BAQsFAAOCAQEAwZ2Tu0Yu 13 | rrSVdMdoPEjT1IZd+5OhM/SLzL0ddtvTithRweLHsw2lDQYlXFqr24i3UGZJQ1sp 14 | 
cqSrNwswgLUQT3vWyTjmM51HEb2vMYWKmjZ+sBQYAUP1CadrN/+OTfNGnlF1+B4w 15 | IXOzh7EvQmJJnNybLe4a/aRvj1NE2n8Z898B76SVU9WbfKKz8VwLzuIPDqkKcZda 16 | lMy5yzthyztV9YjcWs2zVOUGZvGdAhDrvZuUq6mSmxrBEvR2LBOggmVf3tGRT+Ls 17 | lW7c9Lrva5zLHuqmoPP07A+vuI9a0D1X44jwGDuPWJ5RnTOQ63Uez12mKNjqleHw 18 | DnkwNanuO8dhAA== 19 | -----END CERTIFICATE----- 20 | -------------------------------------------------------------------------------- /certs/postgresql.cnf: -------------------------------------------------------------------------------- 1 | [req] 2 | distinguished_name = req_distinguished_name 3 | prompt = no 4 | 5 | [req_distinguished_name] 6 | C = US 7 | ST = Nevada 8 | L = Las Vegas 9 | O = github.com/lib/pq 10 | CN = pqgosslcert 11 | -------------------------------------------------------------------------------- /certs/postgresql.crt: -------------------------------------------------------------------------------- 1 | -----BEGIN CERTIFICATE----- 2 | MIIDPjCCAiYCCQD4nsC6zsmIqjANBgkqhkiG9w0BAQsFADBeMQswCQYDVQQGEwJV 3 | UzEPMA0GA1UECAwGTmV2YWRhMRIwEAYDVQQHDAlMYXMgVmVnYXMxGjAYBgNVBAoM 4 | EWdpdGh1Yi5jb20vbGliL3BxMQ4wDAYDVQQDDAVwcSBDQTAeFw0yMTA5MDIwMTU1 5 | MDJaFw0zMTA5MDMwMTU1MDJaMGQxCzAJBgNVBAYTAlVTMQ8wDQYDVQQIDAZOZXZh 6 | ZGExEjAQBgNVBAcMCUxhcyBWZWdhczEaMBgGA1UECgwRZ2l0aHViLmNvbS9saWIv 7 | cHExFDASBgNVBAMMC3BxZ29zc2xjZXJ0MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8A 8 | MIIBCgKCAQEAx0ucPVUNCrVmbyithwWrmmZ1dGudBwhSyDB6af4z5Cr+S6dx2SRU 9 | UGUw3Lv+z+tUqQ7hJj0oNddIQeYKl/Tt6JPpZsQfERP/cUGedtyt7HnCKobBL+0B 10 | NvHnDIUiIL4LgfiZK4DWJkGmm7nTHo/7qKAw60vCMLUW98DC0Xhlk9MHYG+e9Zai 11 | 3G0vY2X6DUYcSmzBI3JakFEgMZTQg3ofUQMz8TYeK3/DYadLXkl08d18LL3Dnefx 12 | 0xRuBPNTa2tLfVnFkfFi6Z9xVB/WhG6+X4OLnO85v5xUOGTV+g154iR7FOkrrl5F 13 | lEUBj+yaIoTRi+MyZ/oYqWwQUDYS3+Te9wIDAQABMA0GCSqGSIb3DQEBCwUAA4IB 14 | AQCCJpwUWCx7xfXv3vH3LQcffZycyRHYPgTCbiQw3x9aBb77jUAh5O6lEj/W0nx2 15 | SCTEsCsRSAiFwfUb+g/AFCW84dELRWmf38eoqACebLymqnvxyZA+O87yu07XyFZR 16 | TnmbDMzZgsyWWGwS3JoGFk+ibWY4AImYQnSJO8Pi0kZ37ngbAyJ3RtDhhEQJWw/Q 17 | 
D04p3uky/ea7Gyz0QTx5o40n4gq7nEzF1OS6IHozM840J5aZrxRiXEa56fsmJHmI 18 | IGyI07SGlWJ15r1wc8lB+8ilnAqH1QQlYzTIW0Q4NZE7n3uQg1EVuueGiGO2ex2/ 19 | he9lDiJfOQuPuLbOxzctP9v9 20 | -----END CERTIFICATE----- 21 | -------------------------------------------------------------------------------- /certs/postgresql.key: -------------------------------------------------------------------------------- 1 | -----BEGIN PRIVATE KEY----- 2 | MIIEvAIBADANBgkqhkiG9w0BAQEFAASCBKYwggSiAgEAAoIBAQDHS5w9VQ0KtWZv 3 | KK2HBauaZnV0a50HCFLIMHpp/jPkKv5Lp3HZJFRQZTDcu/7P61SpDuEmPSg110hB 4 | 5gqX9O3ok+lmxB8RE/9xQZ523K3secIqhsEv7QE28ecMhSIgvguB+JkrgNYmQaab 5 | udMej/uooDDrS8IwtRb3wMLReGWT0wdgb571lqLcbS9jZfoNRhxKbMEjclqQUSAx 6 | lNCDeh9RAzPxNh4rf8Nhp0teSXTx3XwsvcOd5/HTFG4E81Nra0t9WcWR8WLpn3FU 7 | H9aEbr5fg4uc7zm/nFQ4ZNX6DXniJHsU6SuuXkWURQGP7JoihNGL4zJn+hipbBBQ 8 | NhLf5N73AgMBAAECggEAHLNY1sRO0oH5NHzpMI6yfdPPimqM/JxIP6grmOQQ2QUQ 9 | BhkhHiJLOiC4frFcKtk7IfWQmw8noUlVkJfuYp/VOy9B55jK2IzGtqq6hWeWbH3E 10 | Zpdtbtd021LO8VCi75Au3BLPDCLLtEq0Ea0bKEWX+lrHcLtCRf1uR1OtOrlZ94Wl 11 | DUhm7YJC4cS1bi6Kdf03R+fw2oFi7/QdywcT4ow032jGWOly/Jl7bSHZK7xLtM/i 12 | 9HfMwmusD/iuz7mtLU7VCpnlKZm6MfS5D427ybW8MruuiZEtQJ6QtRIrHBHk93aK 13 | Op0tjJ6tMav1UsJzgVz9+uWILE9l0AjAa4AvbfNzEQKBgQD8mma9SLQPtBb6cXuT 14 | CQgjE4vyph8mRnm/pTz3QLIpMiLy2+aKJD/u4cduzLw1vjuH1tlb7NQ9c891jAJh 15 | JhwDwqKAXfFicfRs/PYWngx/XtGhbbpgm1yA6XuYL1D06gzmjzXgHvZMOFcts+GF 16 | y0JEuV7v6eYrpQJRQYCwY6xTgwKBgQDJ+bHAlgOaC94DZEXZMiUznCCjBjAstiXG 17 | BEN7Cnfn6vgvPm/b6BkKn4VrsCmbZQKT7QJDSOhYwXCC2ZlrKiF8GEUHX4mi8347 18 | 8B+DsuokTLNmN61QAZbb1c3XQVnr15xH8ijm7yYs4tCBmVLKBmpw1T4IZXXlVE5k 19 | gmee+AwIfQKBgGr+P0wnclVAc4cq8CusZKzux5VEtebxbPo21CbqWUxHtzPk3rZe 20 | elIFggK1Z3bgF7kG0NQ18QQCfLoOTqe1i6IwG8KBiA+pst1DHD0iPqroj6RvpMTs 21 | qXbU7ovcZs8GH+a8fBZtJufL6WkrSvfvyybu2X6HNP4Bi4S9WPPdlA1fAoGAE5m/ 22 | vkjQoKp2KS4Z+TH8mj2UjT2Uf0JN+CGByvcBG+iZnTwZ7uVfSMCiWgkGgKYU0fY2 23 | OgFhSvu6x3gGg3fbOAfC6yxCVyX6IibzZ/x87HjlEA5nK1R8J2lgSHt3FoQeDn1Z 24 | 
qs+ajNCWG32doy1sNvb6xiXSgybjVK2zEKJRyKECgYBJTk2IABebjvInNb6tagcI 25 | nD4d2LgBmZJZsTruHXrpO0s3XCQcFKks4JKH1CVjd34f7LkxzEOGbE7wKBBd652s 26 | ob6gFKnbqTniTo3NRUycB6ymo4LSaBvKgeY5hYbVxrYheRLPGY+gPVYb3VMKu9N9 27 | 76rcaFqJOz7OeywRG5bHUg== 28 | -----END PRIVATE KEY----- 29 | -------------------------------------------------------------------------------- /certs/root.cnf: -------------------------------------------------------------------------------- 1 | [req] 2 | distinguished_name = req_distinguished_name 3 | prompt = no 4 | 5 | [req_distinguished_name] 6 | C = US 7 | ST = Nevada 8 | L = Las Vegas 9 | O = github.com/lib/pq 10 | CN = pq CA 11 | -------------------------------------------------------------------------------- /certs/root.crt: -------------------------------------------------------------------------------- 1 | -----BEGIN CERTIFICATE----- 2 | MIIEBjCCAu6gAwIBAgIJAPizR+OD14YnMA0GCSqGSIb3DQEBCwUAMF4xCzAJBgNV 3 | BAYTAlVTMQ8wDQYDVQQIDAZOZXZhZGExEjAQBgNVBAcMCUxhcyBWZWdhczEaMBgG 4 | A1UECgwRZ2l0aHViLmNvbS9saWIvcHExDjAMBgNVBAMMBXBxIENBMB4XDTIxMDkw 5 | MjAxNTUwMloXDTMxMDkwMzAxNTUwMlowXjELMAkGA1UEBhMCVVMxDzANBgNVBAgM 6 | Bk5ldmFkYTESMBAGA1UEBwwJTGFzIFZlZ2FzMRowGAYDVQQKDBFnaXRodWIuY29t 7 | L2xpYi9wcTEOMAwGA1UEAwwFcHEgQ0EwggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAw 8 | ggEKAoIBAQDb9d6sjdU6GdibGrXRMOHREH3MRUS8T4TFqGgPEGVDP/V5bAZlBSGP 9 | AN0o9DTyVLcbQpBt8zMTw9KeIzIIe5NIVkSmA16lw/YckGhOM+kZIkiDuE6qt5Ia 10 | OQCRMdXkZ8ejG/JUu+rHU8FJZL8DE+jyYherzdjkeVAQ7JfzxAwW2Dl7T/47g337 11 | Pwmf17AEb8ibSqmXyUN7R5NhJQs+hvaYdNagzdx91E1H+qlyBvmiNeasUQljLvZ+ 12 | Y8wAuU79neA+d09O4PBiYwV17rSP6SZCeGE3oLZviL/0KM9Xig88oB+2FmvQ6Zxa 13 | L7SoBlqS+5pBZwpH7eee/wCIKAnJtMAJAgMBAAGjgcYwgcMwDwYDVR0TAQH/BAUw 14 | AwEB/zAdBgNVHQ4EFgQUfIXEczahbcM2cFrwclJF7GbdajkwgZAGA1UdIwSBiDCB 15 | hYAUfIXEczahbcM2cFrwclJF7GbdajmhYqRgMF4xCzAJBgNVBAYTAlVTMQ8wDQYD 16 | VQQIDAZOZXZhZGExEjAQBgNVBAcMCUxhcyBWZWdhczEaMBgGA1UECgwRZ2l0aHVi 17 | LmNvbS9saWIvcHExDjAMBgNVBAMMBXBxIENBggkA+LNH44PXhicwDQYJKoZIhvcN 18 | 
AQELBQADggEBABFyGgSz2mHVJqYgX1Y+7P+MfKt83cV2uYDGYvXrLG2OGiCilVul 19 | oTBG+8omIMSHOsQZvWMpA5H0tnnlQHrKpKpUyKkSL+Wv5GL0UtBmHX7mVRiaK2l4 20 | q2BjRaQUitp/FH4NSdXtVrMME5T1JBBZHsQkNL3cNRzRKwY/Vj5UGEDxDS7lILUC 21 | e01L4oaK0iKQn4beALU+TvKoAHdPvoxpPpnhkF5ss9HmdcvRktJrKZemDJZswZ7/ 22 | +omx8ZPIYYUH5VJJYYE88S7guAt+ZaKIUlel/t6xPbo2ZySFSg9u1uB99n+jTo3L 23 | 1rAxFnN3FCX2jBqgP29xMVmisaN5k04UmyI= 24 | -----END CERTIFICATE----- 25 | -------------------------------------------------------------------------------- /certs/server.cnf: -------------------------------------------------------------------------------- 1 | [ req ] 2 | default_bits = 2048 3 | distinguished_name = subject 4 | req_extensions = req_ext 5 | x509_extensions = x509_ext 6 | string_mask = utf8only 7 | prompt = no 8 | 9 | [ subject ] 10 | C = US 11 | ST = Nevada 12 | L = Las Vegas 13 | O = github.com/lib/pq 14 | 15 | [ x509_ext ] 16 | subjectKeyIdentifier = hash 17 | authorityKeyIdentifier = keyid,issuer 18 | 19 | basicConstraints = CA:FALSE 20 | keyUsage = digitalSignature, keyEncipherment 21 | subjectAltName = DNS:postgres 22 | nsComment = "OpenSSL Generated Certificate" 23 | 24 | [ req_ext ] 25 | subjectKeyIdentifier = hash 26 | basicConstraints = CA:FALSE 27 | keyUsage = digitalSignature, keyEncipherment 28 | subjectAltName = DNS:postgres 29 | nsComment = "OpenSSL Generated Certificate" 30 | -------------------------------------------------------------------------------- /certs/server.crt: -------------------------------------------------------------------------------- 1 | -----BEGIN CERTIFICATE----- 2 | MIIDqzCCApOgAwIBAgIJAPiewLrOyYipMA0GCSqGSIb3DQEBCwUAMF4xCzAJBgNV 3 | BAYTAlVTMQ8wDQYDVQQIDAZOZXZhZGExEjAQBgNVBAcMCUxhcyBWZWdhczEaMBgG 4 | A1UECgwRZ2l0aHViLmNvbS9saWIvcHExDjAMBgNVBAMMBXBxIENBMB4XDTIxMDkw 5 | MjAxNTUwMloXDTMxMDkwMzAxNTUwMlowTjELMAkGA1UEBhMCVVMxDzANBgNVBAgM 6 | Bk5ldmFkYTESMBAGA1UEBwwJTGFzIFZlZ2FzMRowGAYDVQQKDBFnaXRodWIuY29t 7 | L2xpYi9wcTCCASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEBAKf6H4UzmANN 8 | 
QiQJe92Mf3ETMYmpZKNNO9DPEHyNLIkag+XwMrBTdcCK0mLvsNCYpXuBN6703KCd 9 | WAFOeMmj7gOsWtvjt5Xm6bRHLgegekXzcG/jDwq/wyzeDzr/YkITuIlG44Lf9lhY 10 | FLwiHlHOWHnwrZaEh6aU//02aQkzyX5INeXl/3TZm2G2eIH6AOxOKOU27MUsyVSQ 11 | 5DE+SDKGcRP4bElueeQWvxAXNMZYb7sVSDdfHI3zr32K4k/tC8x0fZJ5XN/dvl4t 12 | 4N4MrYlmDO5XOrb/gQH1H4iu6+5EMDfZYab4fkThnNFdfFqu4/8Scv7KZ8mWqpKM 13 | fGAjEPctQi0CAwEAAaN8MHowHQYDVR0OBBYEFENExPbmDyFB2AJUdbMvVyhlNPD5 14 | MAkGA1UdEwQCMAAwCwYDVR0PBAQDAgWgMBMGA1UdEQQMMAqCCHBvc3RncmVzMCwG 15 | CWCGSAGG+EIBDQQfFh1PcGVuU1NMIEdlbmVyYXRlZCBDZXJ0aWZpY2F0ZTANBgkq 16 | hkiG9w0BAQsFAAOCAQEAMRVbV8RiEsmp9HAtnVCZmRXMIbgPGrqjeSwk586s4K8v 17 | BSqNCqxv6s5GfCRmDYiqSqeuCVDtUJS1HsTmbxVV7Ke71WMo+xHR1ICGKOa8WGCb 18 | TGsuicG5QZXWaxeMOg4s0qpKmKko0d1aErdVsanU5dkrVS7D6729Ffnzu4lwApk6 19 | invAB67p8u7sojwqRq5ce0vRaG+YFylTrWomF9kauEb8gKbQ9Xc7QfX+h+UH/mq9 20 | Nvdj8LOHp6/82bZdnsYUOtV4lS1IA/qzeXpqBphxqfWabD1yLtkyJyImZKq8uIPp 21 | 0CG4jhObPdWcCkXD6bg3QK3mhwlC79OtFgxWmldCRQ== 22 | -----END CERTIFICATE----- 23 | -------------------------------------------------------------------------------- /certs/server.key: -------------------------------------------------------------------------------- 1 | -----BEGIN PRIVATE KEY----- 2 | MIIEvQIBADANBgkqhkiG9w0BAQEFAASCBKcwggSjAgEAAoIBAQCn+h+FM5gDTUIk 3 | CXvdjH9xEzGJqWSjTTvQzxB8jSyJGoPl8DKwU3XAitJi77DQmKV7gTeu9NygnVgB 4 | TnjJo+4DrFrb47eV5um0Ry4HoHpF83Bv4w8Kv8Ms3g86/2JCE7iJRuOC3/ZYWBS8 5 | Ih5Rzlh58K2WhIemlP/9NmkJM8l+SDXl5f902ZthtniB+gDsTijlNuzFLMlUkOQx 6 | PkgyhnET+GxJbnnkFr8QFzTGWG+7FUg3XxyN8699iuJP7QvMdH2SeVzf3b5eLeDe 7 | DK2JZgzuVzq2/4EB9R+IruvuRDA32WGm+H5E4ZzRXXxaruP/EnL+ymfJlqqSjHxg 8 | IxD3LUItAgMBAAECggEAOE2naQ9tIZYw2EFxikZApVcooJrtx6ropMnzHbx4NBB2 9 | K4mChAXFj184u77ZxmGT/jzGvFcI6LE0wWNbK0NOUV7hKZk/fPhkV3AQZrAMrAu4 10 | IVi7PwAd3JkmA8F8XuebUDA5rDGDsgL8GD9baFJA58abeLs9eMGyuF4XgOUh4bip 11 | hgHa76O2rcDWNY5HZqqRslw75FzlYkB0PCts/UJxSswj70kTTihyOhDlrm2TnyxI 12 | ne54UbGRrpfs9wiheSGLjDG81qZToBHQDwoAnjjZhu1VCaBISuGbgZrxyyRyqdnn 13 | 
xPW+KczMv04XyvF7v6Pz+bUEppalLXGiXnH5UtWvZQKBgQDTPCdMpNE/hwlq4nAw 14 | Kf42zIBWfbnMLVWYoeDiAOhtl9XAUAXn76xe6Rvo0qeAo67yejdbJfRq3HvGyw+q 15 | 4PS8r9gXYmLYIPQxSoLL5+rFoBCN3qFippfjLB1j32mp7+15KjRj8FF2r6xIN8fu 16 | XatSRsaqmvCWYLDRv/rbHnxwkwKBgQDLkyfFLF7BtwtPWKdqrwOM7ip1UKh+oDBS 17 | vkCQ08aEFRBU7T3jChsx5GbaW6zmsSBwBwcrHclpSkz7n3aq19DDWObJR2p80Fma 18 | rsXeIcvtEpkvT3pVX268P5d+XGs1kxgFunqTysG9yChW+xzcs5MdKBzuMPPn7rL8 19 | MKAzdar6PwKBgEypkzW8x3h/4Moa3k6MnwdyVs2NGaZheaRIc95yJ+jGZzxBjrMr 20 | h+p2PbvU4BfO0AqOkpKRBtDVrlJqlggVVp04UHvEKE16QEW3Xhr0037f5cInX3j3 21 | Lz6yXwRFLAsR2aTUzWjL6jTh8uvO2s/GzQuyRh3a16Ar/WBShY+K0+zjAoGATnLT 22 | xZjWnyHRmu8X/PWakamJ9RFzDPDgDlLAgM8LVgTj+UY/LgnL9wsEU6s2UuP5ExKy 23 | QXxGDGwUhHar/SQTj+Pnc7Mwpw6HKSOmnnY5po8fNusSwml3O9XppEkrC0c236Y/ 24 | 7EobJO5IFVTJh4cv7vFxTJzSsRL8KFD4uzvh+nMCgYEAqY8NBYtIgNJA2B6C6hHF 25 | +bG7v46434ZHFfGTmMQwzE4taVg7YRnzYESAlvK4bAP5ZXR90n7GRGFhrXzoMZ38 26 | r0bw/q9rV+ReGda7/Bjf7ciCKiq0RODcHtf4IaskjPXCoQRGJtgCPLhWPfld6g9v 27 | /HTvO96xv9e3eG/PKSPog94= 28 | -----END PRIVATE KEY----- 29 | -------------------------------------------------------------------------------- /conn_go115.go: -------------------------------------------------------------------------------- 1 | //go:build go1.15 2 | // +build go1.15 3 | 4 | package pq 5 | 6 | import "database/sql/driver" 7 | 8 | var _ driver.Validator = &conn{} 9 | -------------------------------------------------------------------------------- /conn_go18.go: -------------------------------------------------------------------------------- 1 | package pq 2 | 3 | import ( 4 | "context" 5 | "database/sql" 6 | "database/sql/driver" 7 | "fmt" 8 | "io" 9 | "io/ioutil" 10 | "time" 11 | ) 12 | 13 | const ( 14 | watchCancelDialContextTimeout = time.Second * 10 15 | ) 16 | 17 | // Implement the "QueryerContext" interface 18 | func (cn *conn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) { 19 | list := make([]driver.Value, len(args)) 20 | for i, nv := 
range args { 21 | list[i] = nv.Value 22 | } 23 | finish := cn.watchCancel(ctx) 24 | r, err := cn.query(query, list) 25 | if err != nil { 26 | if finish != nil { 27 | finish() 28 | } 29 | return nil, err 30 | } 31 | r.finish = finish 32 | return r, nil 33 | } 34 | 35 | // Implement the "ExecerContext" interface 36 | func (cn *conn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) { 37 | list := make([]driver.Value, len(args)) 38 | for i, nv := range args { 39 | list[i] = nv.Value 40 | } 41 | 42 | if finish := cn.watchCancel(ctx); finish != nil { 43 | defer finish() 44 | } 45 | 46 | return cn.Exec(query, list) 47 | } 48 | 49 | // Implement the "ConnPrepareContext" interface 50 | func (cn *conn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) { 51 | if finish := cn.watchCancel(ctx); finish != nil { 52 | defer finish() 53 | } 54 | return cn.Prepare(query) 55 | } 56 | 57 | // Implement the "ConnBeginTx" interface 58 | func (cn *conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) { 59 | var mode string 60 | 61 | switch sql.IsolationLevel(opts.Isolation) { 62 | case sql.LevelDefault: 63 | // Don't touch mode: use the server's default 64 | case sql.LevelReadUncommitted: 65 | mode = " ISOLATION LEVEL READ UNCOMMITTED" 66 | case sql.LevelReadCommitted: 67 | mode = " ISOLATION LEVEL READ COMMITTED" 68 | case sql.LevelRepeatableRead: 69 | mode = " ISOLATION LEVEL REPEATABLE READ" 70 | case sql.LevelSerializable: 71 | mode = " ISOLATION LEVEL SERIALIZABLE" 72 | default: 73 | return nil, fmt.Errorf("pq: isolation level not supported: %d", opts.Isolation) 74 | } 75 | 76 | if opts.ReadOnly { 77 | mode += " READ ONLY" 78 | } else { 79 | mode += " READ WRITE" 80 | } 81 | 82 | tx, err := cn.begin(mode) 83 | if err != nil { 84 | return nil, err 85 | } 86 | cn.txnFinish = cn.watchCancel(ctx) 87 | return tx, nil 88 | } 89 | 90 | func (cn *conn) Ping(ctx context.Context) error { 91 | if finish 
:= cn.watchCancel(ctx); finish != nil { 92 | defer finish() 93 | } 94 | rows, err := cn.simpleQuery(";") 95 | if err != nil { 96 | return driver.ErrBadConn // https://golang.org/pkg/database/sql/driver/#Pinger 97 | } 98 | rows.Close() 99 | return nil 100 | } 101 | 102 | func (cn *conn) watchCancel(ctx context.Context) func() { 103 | if done := ctx.Done(); done != nil { 104 | finished := make(chan struct{}, 1) 105 | go func() { 106 | select { 107 | case <-done: 108 | select { 109 | case finished <- struct{}{}: 110 | default: 111 | // We raced with the finish func, let the next query handle this with the 112 | // context. 113 | return 114 | } 115 | 116 | // Set the connection state to bad so it does not get reused. 117 | cn.err.set(ctx.Err()) 118 | 119 | // At this point the function level context is canceled, 120 | // so it must not be used for the additional network 121 | // request to cancel the query. 122 | // Create a new context to pass into the dial. 123 | ctxCancel, cancel := context.WithTimeout(context.Background(), watchCancelDialContextTimeout) 124 | defer cancel() 125 | 126 | _ = cn.cancel(ctxCancel) 127 | case <-finished: 128 | } 129 | }() 130 | return func() { 131 | select { 132 | case <-finished: 133 | cn.err.set(ctx.Err()) 134 | cn.Close() 135 | case finished <- struct{}{}: 136 | } 137 | } 138 | } 139 | return nil 140 | } 141 | 142 | func (cn *conn) cancel(ctx context.Context) error { 143 | // Create a new values map (copy). This makes sure the connection created 144 | // in this method cannot write to the same underlying data, which could 145 | // cause a concurrent map write panic. This is necessary because cancel 146 | // is called from a goroutine in watchCancel. 
147 | o := make(values) 148 | for k, v := range cn.opts { 149 | o[k] = v 150 | } 151 | 152 | c, err := dial(ctx, cn.dialer, o) 153 | if err != nil { 154 | return err 155 | } 156 | defer c.Close() 157 | 158 | { 159 | can := conn{ 160 | c: c, 161 | } 162 | err = can.ssl(o) 163 | if err != nil { 164 | return err 165 | } 166 | 167 | w := can.writeBuf(0) 168 | w.int32(80877102) // cancel request code 169 | w.int32(cn.processID) 170 | w.int32(cn.secretKey) 171 | 172 | if err := can.sendStartupPacket(w); err != nil { 173 | return err 174 | } 175 | } 176 | 177 | // Read until EOF to ensure that the server received the cancel. 178 | { 179 | _, err := io.Copy(ioutil.Discard, c) 180 | return err 181 | } 182 | } 183 | 184 | // Implement the "StmtQueryContext" interface 185 | func (st *stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) { 186 | list := make([]driver.Value, len(args)) 187 | for i, nv := range args { 188 | list[i] = nv.Value 189 | } 190 | finish := st.watchCancel(ctx) 191 | r, err := st.query(list) 192 | if err != nil { 193 | if finish != nil { 194 | finish() 195 | } 196 | return nil, err 197 | } 198 | r.finish = finish 199 | return r, nil 200 | } 201 | 202 | // Implement the "StmtExecContext" interface 203 | func (st *stmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) { 204 | list := make([]driver.Value, len(args)) 205 | for i, nv := range args { 206 | list[i] = nv.Value 207 | } 208 | 209 | if finish := st.watchCancel(ctx); finish != nil { 210 | defer finish() 211 | } 212 | 213 | return st.Exec(list) 214 | } 215 | 216 | // watchCancel is implemented on stmt in order to not mark the parent conn as bad 217 | func (st *stmt) watchCancel(ctx context.Context) func() { 218 | if done := ctx.Done(); done != nil { 219 | finished := make(chan struct{}) 220 | go func() { 221 | select { 222 | case <-done: 223 | // At this point the function level context is canceled, 224 | // so it must not be used for 
the additional network 225 | // request to cancel the query. 226 | // Create a new context to pass into the dial. 227 | ctxCancel, cancel := context.WithTimeout(context.Background(), watchCancelDialContextTimeout) 228 | defer cancel() 229 | 230 | _ = st.cancel(ctxCancel) 231 | finished <- struct{}{} 232 | case <-finished: 233 | } 234 | }() 235 | return func() { 236 | select { 237 | case <-finished: 238 | case finished <- struct{}{}: 239 | } 240 | } 241 | } 242 | return nil 243 | } 244 | 245 | func (st *stmt) cancel(ctx context.Context) error { 246 | return st.cn.cancel(ctx) 247 | } 248 | -------------------------------------------------------------------------------- /conn_go19.go: -------------------------------------------------------------------------------- 1 | //go:build go1.9 2 | // +build go1.9 3 | 4 | package pq 5 | 6 | import ( 7 | "database/sql/driver" 8 | "reflect" 9 | ) 10 | 11 | var _ driver.NamedValueChecker = (*conn)(nil) 12 | 13 | func (c *conn) CheckNamedValue(nv *driver.NamedValue) error { 14 | if _, ok := nv.Value.(driver.Valuer); ok { 15 | // Ignore Valuer, for backward compatibility with pq.Array(). 16 | return driver.ErrSkip 17 | } 18 | 19 | // Ignoring []byte / []uint8. 
20 | if _, ok := nv.Value.([]uint8); ok { 21 | return driver.ErrSkip 22 | } 23 | 24 | v := reflect.ValueOf(nv.Value) 25 | if v.Kind() == reflect.Ptr { 26 | v = v.Elem() 27 | } 28 | if v.Kind() == reflect.Slice { 29 | var err error 30 | nv.Value, err = Array(v.Interface()).Value() 31 | return err 32 | } 33 | 34 | return driver.ErrSkip 35 | } 36 | -------------------------------------------------------------------------------- /conn_go19_test.go: -------------------------------------------------------------------------------- 1 | //go:build go1.9 2 | // +build go1.9 3 | 4 | package pq 5 | 6 | import ( 7 | "fmt" 8 | "reflect" 9 | "testing" 10 | ) 11 | 12 | func TestArrayArg(t *testing.T) { 13 | db := openTestConn(t) 14 | defer db.Close() 15 | 16 | for _, tc := range []struct { 17 | pgType string 18 | in, out interface{} 19 | }{ 20 | { 21 | pgType: "int[]", 22 | in: []int{245, 231}, 23 | out: []int64{245, 231}, 24 | }, 25 | { 26 | pgType: "int[]", 27 | in: &[]int{245, 231}, 28 | out: []int64{245, 231}, 29 | }, 30 | { 31 | pgType: "int[]", 32 | in: []int64{245, 231}, 33 | }, 34 | { 35 | pgType: "int[]", 36 | in: &[]int64{245, 231}, 37 | out: []int64{245, 231}, 38 | }, 39 | { 40 | pgType: "varchar[]", 41 | in: []string{"hello", "world"}, 42 | }, 43 | { 44 | pgType: "varchar[]", 45 | in: &[]string{"hello", "world"}, 46 | out: []string{"hello", "world"}, 47 | }, 48 | } { 49 | if tc.out == nil { 50 | tc.out = tc.in 51 | } 52 | t.Run(fmt.Sprintf("%#v", tc.in), func(t *testing.T) { 53 | r, err := db.Query(fmt.Sprintf("SELECT $1::%s", tc.pgType), tc.in) 54 | if err != nil { 55 | t.Fatal(err) 56 | } 57 | defer r.Close() 58 | 59 | if !r.Next() { 60 | if r.Err() != nil { 61 | t.Fatal(r.Err()) 62 | } 63 | t.Fatal("expected row") 64 | } 65 | 66 | defer func() { 67 | if r.Next() { 68 | t.Fatal("unexpected row") 69 | } 70 | }() 71 | 72 | got := reflect.New(reflect.TypeOf(tc.out)) 73 | if err := r.Scan(Array(got.Interface())); err != nil { 74 | t.Fatal(err) 75 | } 76 | 77 | if 
!reflect.DeepEqual(tc.out, got.Elem().Interface()) { 78 | t.Errorf("got %v, want %v", got, tc.out) 79 | } 80 | }) 81 | } 82 | 83 | } 84 | -------------------------------------------------------------------------------- /connector.go: -------------------------------------------------------------------------------- 1 | package pq 2 | 3 | import ( 4 | "context" 5 | "database/sql/driver" 6 | "errors" 7 | "fmt" 8 | "os" 9 | "strings" 10 | ) 11 | 12 | // Connector represents a fixed configuration for the pq driver with a given 13 | // name. Connector satisfies the database/sql/driver Connector interface and 14 | // can be used to create any number of DB Conn's via the database/sql OpenDB 15 | // function. 16 | // 17 | // See https://golang.org/pkg/database/sql/driver/#Connector. 18 | // See https://golang.org/pkg/database/sql/#OpenDB. 19 | type Connector struct { 20 | opts values 21 | dialer Dialer 22 | } 23 | 24 | // Connect returns a connection to the database using the fixed configuration 25 | // of this Connector. Context is not used. 26 | func (c *Connector) Connect(ctx context.Context) (driver.Conn, error) { 27 | return c.open(ctx) 28 | } 29 | 30 | // Dialer allows change the dialer used to open connections. 31 | func (c *Connector) Dialer(dialer Dialer) { 32 | c.dialer = dialer 33 | } 34 | 35 | // Driver returns the underlying driver of this Connector. 36 | func (c *Connector) Driver() driver.Driver { 37 | return &Driver{} 38 | } 39 | 40 | // NewConnector returns a connector for the pq driver in a fixed configuration 41 | // with the given dsn. The returned connector can be used to create any number 42 | // of equivalent Conn's. The returned connector is intended to be used with 43 | // database/sql.OpenDB. 44 | // 45 | // See https://golang.org/pkg/database/sql/driver/#Connector. 46 | // See https://golang.org/pkg/database/sql/#OpenDB. 
47 | func NewConnector(dsn string) (*Connector, error) {
48 | 	var err error
49 | 	o := make(values)
50 | 
51 | 	// A number of defaults are applied here, in this order:
52 | 	//
53 | 	// * Very low precedence defaults applied in every situation
54 | 	// * Environment variables
55 | 	// * Explicitly passed connection information
56 | 	o["host"] = "localhost"
57 | 	o["port"] = "5432"
58 | 	// N.B.: Extra float digits should be set to 3, but that breaks
59 | 	// Postgres 8.4 and older, where the max is 2.
60 | 	o["extra_float_digits"] = "2"
61 | 	for k, v := range parseEnviron(os.Environ()) {
62 | 		o[k] = v
63 | 	}
64 | 
65 | 	if strings.HasPrefix(dsn, "postgres://") || strings.HasPrefix(dsn, "postgresql://") {
66 | 		dsn, err = ParseURL(dsn)
67 | 		if err != nil {
68 | 			return nil, err
69 | 		}
70 | 	}
71 | 
72 | 	if err := parseOpts(dsn, o); err != nil {
73 | 		return nil, err
74 | 	}
75 | 
76 | 	// Use the "fallback" application name if necessary
77 | 	if fallback, ok := o["fallback_application_name"]; ok {
78 | 		if _, ok := o["application_name"]; !ok {
79 | 			o["application_name"] = fallback
80 | 		}
81 | 	}
82 | 
83 | 	// We can't work with any client_encoding other than UTF-8 currently.
84 | 	// However, we have historically allowed the user to set it to UTF-8
85 | 	// explicitly, and there's no reason to break such programs, so allow that.
86 | 	// Note that the "options" setting could also set client_encoding, but
87 | 	// parsing its value is not worth it. Instead, we always explicitly send
88 | 	// client_encoding as a separate run-time parameter, which should override
89 | 	// anything set in options.
90 | 	if enc, ok := o["client_encoding"]; ok && !isUTF8(enc) {
91 | 		return nil, errors.New("client_encoding must be absent or 'UTF8'")
92 | 	}
93 | 	o["client_encoding"] = "UTF8"
94 | 	// DateStyle is treated the same way: only "ISO, MDY" is accepted, and it is the default.
95 | 	if datestyle, ok := o["datestyle"]; ok {
96 | 		if datestyle != "ISO, MDY" {
97 | 			return nil, fmt.Errorf("setting datestyle must be absent or %v; got %v", "ISO, MDY", datestyle)
98 | 		}
99 | 	} else {
100 | 		o["datestyle"] = "ISO, MDY"
101 | 	}
102 | 
103 | 	// If a user is not provided by any other means, the last
104 | 	// resort is to use the current operating system provided user
105 | 	// name.
106 | 	if _, ok := o["user"]; !ok {
107 | 		u, err := userCurrent()
108 | 		if err != nil {
109 | 			return nil, err
110 | 		}
111 | 		o["user"] = u
112 | 	}
113 | 
114 | 	// SSL is not necessary or supported over UNIX domain sockets
115 | 	if network, _ := network(o); network == "unix" {
116 | 		o["sslmode"] = "disable"
117 | 	}
118 | 
119 | 	return &Connector{opts: o, dialer: defaultDialer{}}, nil
120 | }
121 | 
--------------------------------------------------------------------------------
/connector_example_test.go:
--------------------------------------------------------------------------------
1 | //go:build go1.10
2 | // +build go1.10
3 | 
4 | package pq_test
5 | 
6 | import (
7 | 	"database/sql"
8 | 	"fmt"
9 | 
10 | 	"github.com/lib/pq"
11 | )
12 | 
13 | func ExampleNewConnector() {
14 | 	name := ""
15 | 	connector, err := pq.NewConnector(name)
16 | 	if err != nil {
17 | 		fmt.Println(err)
18 | 		return
19 | 	}
20 | 	db := sql.OpenDB(connector)
21 | 	defer db.Close()
22 | 
23 | 	// Use the DB
24 | 	txn, err := db.Begin()
25 | 	if err != nil {
26 | 		fmt.Println(err)
27 | 		return
28 | 	}
29 | 	txn.Rollback()
30 | }
31 | 
--------------------------------------------------------------------------------
/connector_test.go:
--------------------------------------------------------------------------------
1 | //go:build go1.10
2 | // +build go1.10
3 | 
4 | package pq
5 | 
6 | import (
7 | 	"context"
8 | 	"database/sql"
9 | 	"database/sql/driver"
10 | 	"os"
11 | 	"testing"
12 | )
13 | 
14 | func TestNewConnector_WorksWithOpenDB(t *testing.T) {
15 | 	name := ""
16 | 	c, err := NewConnector(name)
17 | 	if err != nil {
18 | 		t.Fatal(err)
19 | 	}
20 | 	db := sql.OpenDB(c)
21 | 	defer db.Close()
22 | 	// database/sql might not call our Open at all unless we do something with
23 | 	// the connection
24 | 	txn, err := db.Begin()
25 | 	if err != nil {
26 | 		t.Fatal(err)
27 | 	}
28 | 	txn.Rollback()
29 | }
30 | 
31 | func TestNewConnector_Connect(t *testing.T) {
32 | 	name := ""
33 | 	c, err := NewConnector(name)
34 | 	if err != nil {
35 | 		t.Fatal(err)
36 | 	}
37 | 	db, err := c.Connect(context.Background())
38 | 	if err != nil {
39 | 		t.Fatal(err)
40 | 	}
41 | 	defer db.Close()
42 | 	// The connection came straight from Connect; start a transaction to
43 | 	// check that it is actually usable.
44 | 	txn, err := db.(driver.ConnBeginTx).BeginTx(context.Background(), driver.TxOptions{})
45 | 	if err != nil {
46 | 		t.Fatal(err)
47 | 	}
48 | 	txn.Rollback()
49 | }
50 | 
51 | func TestNewConnector_Driver(t *testing.T) {
52 | 	name := ""
53 | 	c, err := NewConnector(name)
54 | 	if err != nil {
55 | 		t.Fatal(err)
56 | 	}
57 | 	db, err := c.Driver().Open(name)
58 | 	if err != nil {
59 | 		t.Fatal(err)
60 | 	}
61 | 	defer db.Close()
62 | 	// The connection came straight from Driver().Open; start a transaction
63 | 	// to check that it is actually usable.
64 | 	txn, err := db.(driver.ConnBeginTx).BeginTx(context.Background(), driver.TxOptions{})
65 | 	if err != nil {
66 | 		t.Fatal(err)
67 | 	}
68 | 	txn.Rollback()
69 | }
70 | 
71 | func TestNewConnector_Environ(t *testing.T) {
72 | 	name := ""
73 | 	prev, hadPrev := os.LookupEnv("PGPASSFILE")
74 | 	os.Setenv("PGPASSFILE", "/tmp/.pgpass")
75 | 	// Restore the caller's PGPASSFILE (if any) instead of unconditionally
76 | 	// clearing it, so this test does not leak environment changes.
77 | 	defer func() {
78 | 		if hadPrev {
79 | 			os.Setenv("PGPASSFILE", prev)
80 | 		} else {
81 | 			os.Unsetenv("PGPASSFILE")
82 | 		}
83 | 	}()
84 | 	c, err := NewConnector(name)
85 | 	if err != nil {
86 | 		t.Fatal(err)
87 | 	}
88 | 	for key, expected := range map[string]string{
89 | 		"passfile": "/tmp/.pgpass",
90 | 	} {
91 | 		if got := c.opts[key]; got != expected {
92 | 			t.Fatalf("Getting values from environment variables, for %v expected %s got %s", key, expected, got)
93 | 		}
94 | 	}
95 | 
96 | }
97 | 
--------------------------------------------------------------------------------
/copy.go:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
1 | package pq
2 | 
3 | import (
4 | 	"bytes"
5 | 	"context"
6 | 	"database/sql/driver"
7 | 	"encoding/binary"
8 | 	"errors"
9 | 	"fmt"
10 | 	"sync"
11 | )
12 | 
13 | var (
14 | 	errCopyInClosed               = errors.New("pq: copyin statement has already been closed")
15 | 	errBinaryCopyNotSupported     = errors.New("pq: only text format supported for COPY")
16 | 	errCopyToNotSupported         = errors.New("pq: COPY TO is not supported")
17 | 	errCopyNotSupportedOutsideTxn = errors.New("pq: COPY is only allowed inside a transaction")
18 | 	errCopyInProgress             = errors.New("pq: COPY in progress")
19 | )
20 | 
21 | // CopyIn creates a COPY FROM statement which can be prepared with
22 | // Tx.Prepare(). The target table should be visible in search_path.
23 | func CopyIn(table string, columns ...string) string {
24 | 	buffer := bytes.NewBufferString("COPY ")
25 | 	BufferQuoteIdentifier(table, buffer)
26 | 	buffer.WriteString(" (")
27 | 	makeStmt(buffer, columns...)
28 | 	return buffer.String()
29 | }
30 | 
31 | // makeStmt appends the quoted, comma-separated column list and ") FROM STDIN" to buffer (shared by CopyIn and CopyInSchema).
32 | func makeStmt(buffer *bytes.Buffer, columns ...string) {
33 | 	// The caller has already written the opening " (" of the column list.
34 | 	for i, col := range columns {
35 | 		if i != 0 {
36 | 			buffer.WriteString(", ")
37 | 		}
38 | 		BufferQuoteIdentifier(col, buffer)
39 | 	}
40 | 	buffer.WriteString(") FROM STDIN")
41 | }
42 | 
43 | // CopyInSchema creates a COPY FROM statement which can be prepared with
44 | // Tx.Prepare().
45 | func CopyInSchema(schema, table string, columns ...string) string {
46 | 	buffer := bytes.NewBufferString("COPY ")
47 | 	BufferQuoteIdentifier(schema, buffer)
48 | 	buffer.WriteRune('.')
49 | 	BufferQuoteIdentifier(table, buffer)
50 | 	buffer.WriteString(" (")
51 | 	makeStmt(buffer, columns...)
52 | 	return buffer.String()
53 | }
54 | 
55 | type copyin struct {
56 | 	cn      *conn
57 | 	buffer  []byte
58 | 	rowData chan []byte
59 | 	done    chan bool
60 | 
61 | 	closed bool
62 | 
63 | 	mu struct {
64 | 		sync.Mutex
65 | 		err error
66 | 		driver.Result
67 | 	}
68 | }
69 | 
70 | const ciBufferSize = 64 * 1024
71 | 
72 | // Flush once the buffer exceeds this size, before it fills up and needs reallocation.
73 | const ciBufferFlushSize = 63 * 1024
74 | 
75 | func (cn *conn) prepareCopyIn(q string) (_ driver.Stmt, err error) {
76 | 	if !cn.isInTransaction() {
77 | 		return nil, errCopyNotSupportedOutsideTxn
78 | 	}
79 | 
80 | 	ci := &copyin{
81 | 		cn:      cn,
82 | 		buffer:  make([]byte, 0, ciBufferSize),
83 | 		rowData: make(chan []byte),
84 | 		done:    make(chan bool, 1),
85 | 	}
86 | 	// add CopyData identifier + 4 bytes for message length
87 | 	ci.buffer = append(ci.buffer, 'd', 0, 0, 0, 0)
88 | 
89 | 	b := cn.writeBuf('Q')
90 | 	b.string(q)
91 | 	cn.send(b)
92 | 
93 | awaitCopyInResponse:
94 | 	for {
95 | 		t, r := cn.recv1()
96 | 		switch t {
97 | 		case 'G':
98 | 			if r.byte() != 0 {
99 | 				err = errBinaryCopyNotSupported
100 | 				break awaitCopyInResponse
101 | 			}
102 | 			go ci.resploop()
103 | 			return ci, nil
104 | 		case 'H':
105 | 			err = errCopyToNotSupported
106 | 			break awaitCopyInResponse
107 | 		case 'E':
108 | 			err = parseError(r)
109 | 		case 'Z':
110 | 			if err == nil {
111 | 				ci.setBad(driver.ErrBadConn)
112 | 				errorf("unexpected ReadyForQuery in response to COPY")
113 | 			}
114 | 			cn.processReadyForQuery(r)
115 | 			return nil, err
116 | 		default:
117 | 			ci.setBad(driver.ErrBadConn)
118 | 			errorf("unknown response for copy query: %q", t)
119 | 		}
120 | 	}
121 | 
122 | 	// Something went wrong: abort the COPY with a CopyFail ('f') message before returning.
123 | 	b = cn.writeBuf('f')
124 | 	b.string(err.Error())
125 | 	cn.send(b)
126 | 
127 | 	for {
128 | 		t, r := cn.recv1()
129 | 		switch t {
130 | 		case 'c', 'C', 'E':
131 | 		case 'Z':
132 | 			// correctly aborted, we're done
133 | 			cn.processReadyForQuery(r)
134 | 			return nil, err
135 | 		default:
136 | 			ci.setBad(driver.ErrBadConn)
137 | 			errorf("unknown response for CopyFail: %q", t)
138 | 		}
139 | 	}
140 | }
141 | 
142 | func (ci *copyin) flush(buf []byte) {
143 | 	// set message length (without message identifier); errors panic and are recovered by the caller's errRecover
144 | 	binary.BigEndian.PutUint32(buf[1:], uint32(len(buf)-1))
145 | 
146 | 	_, err := ci.cn.c.Write(buf)
147 | 	if err != nil {
148 | 		panic(err)
149 | 	}
150 | }
151 | 
152 | func (ci *copyin) resploop() {
153 | 	for {
154 | 		var r readBuf
155 | 		t, err := ci.cn.recvMessage(&r)
156 | 		if err != nil {
157 | 			ci.setBad(driver.ErrBadConn)
158 | 			ci.setError(err)
159 | 			ci.done <- true
160 | 			return
161 | 		}
162 | 		switch t {
163 | 		case 'C':
164 | 			// CommandComplete: remember the result for getResult
165 | 			res, _ := ci.cn.parseComplete(r.string())
166 | 			ci.setResult(res)
167 | 		case 'N':
168 | 			if n := ci.cn.noticeHandler; n != nil {
169 | 				n(parseError(&r))
170 | 			}
171 | 		case 'Z':
172 | 			ci.cn.processReadyForQuery(&r)
173 | 			ci.done <- true
174 | 			return
175 | 		case 'E':
176 | 			err := parseError(&r)
177 | 			ci.setError(err)
178 | 		default:
179 | 			ci.setBad(driver.ErrBadConn)
180 | 			ci.setError(fmt.Errorf("unknown response during CopyIn: %q", t))
181 | 			ci.done <- true
182 | 			return
183 | 		}
184 | 	}
185 | }
186 | 
187 | func (ci *copyin) setBad(err error) {
188 | 	ci.cn.err.set(err)
189 | }
190 | 
191 | func (ci *copyin) getBad() error {
192 | 	return ci.cn.err.get()
193 | }
194 | 
195 | func (ci *copyin) err() error {
196 | 	ci.mu.Lock()
197 | 	err := ci.mu.err
198 | 	ci.mu.Unlock()
199 | 	return err
200 | }
201 | 
202 | // setError records err as the COPY error unless one was already recorded.
203 | // The caller must not be holding ci.mu; setError acquires it itself.
204 | func (ci *copyin) setError(err error) {
205 | 	ci.mu.Lock()
206 | 	if ci.mu.err == nil { // keep only the first error reported for this COPY
207 | 		ci.mu.err = err
208 | 	}
209 | 	ci.mu.Unlock()
210 | }
211 | 
212 | func (ci *copyin) setResult(result driver.Result) { // called by resploop on CommandComplete
213 | 	ci.mu.Lock()
214 | 	ci.mu.Result = result
215 | 	ci.mu.Unlock()
216 | }
217 | 
218 | func (ci *copyin) getResult() driver.Result {
219 | 	ci.mu.Lock()
220 | 	result := ci.mu.Result
221 | 	ci.mu.Unlock()
222 | 	if result == nil { // no result recorded yet: report zero rows affected
223 | 		return driver.RowsAffected(0)
224 | 	}
225 | 	return result
226 | }
227 | 
228 | func (ci *copyin) NumInput() int {
229 | 	return -1 // presumably "placeholder count unknown" per driver.Stmt.NumInput; COPY rows are variadic
230 | }
231 | 
232 | func (ci *copyin) Query(v []driver.Value) (r driver.Rows, err error) {
233 | 	return nil, ErrNotSupported // a COPY statement can only be Exec'd
234 | }
235 | 
236 | // Exec appends one row (the values in v) to the buffered COPY stream. Rows are
237 | // sent to the server in batches and errors are reported asynchronously, so Exec
238 | // can return an error caused by an earlier Exec call on the same COPY stmt.
239 | //
240 | // Calling Exec with no values (e.g. Exec(nil)) flushes the stream, finishes the
241 | // COPY and returns its result; do this to see errors from pending data, since
242 | // Stmt.Close() doesn't return errors to the user.
243 | func (ci *copyin) Exec(v []driver.Value) (r driver.Result, err error) {
244 | 	if ci.closed {
245 | 		return nil, errCopyInClosed
246 | 	}
247 | 
248 | 	if err := ci.getBad(); err != nil {
249 | 		return nil, err
250 | 	}
251 | 	defer ci.cn.errRecover(&err)
252 | 
253 | 	if err := ci.err(); err != nil {
254 | 		return nil, err
255 | 	}
256 | 
257 | 	if len(v) == 0 {
258 | 		if err := ci.Close(); err != nil {
259 | 			return driver.RowsAffected(0), err
260 | 		}
261 | 
262 | 		return ci.getResult(), nil
263 | 	}
264 | 
265 | 	numValues := len(v)
266 | 	for i, value := range v {
267 | 		ci.buffer = appendEncodedText(&ci.cn.parameterStatus, ci.buffer, value)
268 | 		if i < numValues-1 {
269 | 			ci.buffer = append(ci.buffer, '\t')
270 | 		}
271 | 	}
272 | 
273 | 	ci.buffer = append(ci.buffer, '\n')
274 | 
275 | 	if len(ci.buffer) > ciBufferFlushSize {
276 | 		ci.flush(ci.buffer)
277 | 		// reset buffer, keep bytes for message identifier and length
278 | 		ci.buffer = ci.buffer[:5]
279 | 	}
280 | 
281 | 	return driver.RowsAffected(0), nil
282 | }
283 | 
284 | // CopyData inserts a raw string into the COPY stream. The insert is
285 | // asynchronous and CopyData can return errors from previous CopyData calls to
286 | // the same COPY stmt.
287 | //
288 | // You need to call Exec(nil) to sync the COPY stream and to get any
289 | // errors from pending data, since Stmt.Close() doesn't return errors
290 | // to the user.
291 | func (ci *copyin) CopyData(ctx context.Context, line string) (r driver.Result, err error) {
292 | 	if ci.closed {
293 | 		return nil, errCopyInClosed
294 | 	}
295 | 
296 | 	if finish := ci.cn.watchCancel(ctx); finish != nil {
297 | 		defer finish()
298 | 	}
299 | 
300 | 	if err := ci.getBad(); err != nil {
301 | 		return nil, err
302 | 	}
303 | 	defer ci.cn.errRecover(&err)
304 | 
305 | 	if err := ci.err(); err != nil {
306 | 		return nil, err
307 | 	}
308 | 
309 | 	ci.buffer = append(ci.buffer, line...) // append the string directly; no []byte copy
310 | 	ci.buffer = append(ci.buffer, '\n')
311 | 
312 | 	if len(ci.buffer) > ciBufferFlushSize {
313 | 		ci.flush(ci.buffer)
314 | 		// reset buffer, keep bytes for message identifier and length
315 | 		ci.buffer = ci.buffer[:5]
316 | 	}
317 | 
318 | 	return driver.RowsAffected(0), nil
319 | }
320 | 
321 | // Close flushes any buffered COPY data, sends CopyDone ('c') and waits for the
322 | // response loop to finish. It returns an error if the connection is bad, the send
323 | // fails, or an error was recorded for this COPY. Calling Close again returns nil.
324 | func (ci *copyin) Close() (err error) {
325 | 	if ci.closed { // Don't do anything, we're already closed
326 | 		return nil
327 | 	}
328 | 	ci.closed = true
329 | 
330 | 	if err := ci.getBad(); err != nil {
331 | 		return err
332 | 	}
333 | 	defer ci.cn.errRecover(&err)
334 | 
335 | 	// The buffer always holds at least the 5-byte CopyData header, so this always
336 | 	// flushes, possibly sending a CopyData message that carries no rows.
337 | 	if len(ci.buffer) > 0 {
338 | 		ci.flush(ci.buffer)
339 | 	}
340 | 	// Avoid touching the scratch buffer as resploop could be using it.
341 | 	err = ci.cn.sendSimpleMessage('c')
342 | 	if err != nil {
343 | 		return err
344 | 	}
345 | 
346 | 	<-ci.done
347 | 	ci.cn.inCopy = false
348 | 
349 | 	if err := ci.err(); err != nil {
350 | 		return err
351 | 	}
352 | 	return nil
353 | }
354 | 
--------------------------------------------------------------------------------
/copy_test.go:
--------------------------------------------------------------------------------
1 | package pq
2 | 
3 | import (
4 | 	"bytes"
5 | 	"database/sql"
6 | 	"database/sql/driver"
7 | 	"fmt"
8 | 	"net"
9 | 	"strings"
10 | 	"testing"
11 | 	"time"
12 | )
13 | 
14 | func TestCopyInStmt(t *testing.T) {
15 | 	stmt := CopyIn("table name")
16 | 	if stmt != `COPY "table name" () FROM STDIN` {
17 | 		t.Fatal(stmt)
18 | 	}
19 | 
20 | 	stmt = CopyIn("table name", "column 1", "column 2")
21 | 	if stmt != `COPY "table name" ("column 1", "column 2") FROM STDIN` {
22 | 		t.Fatal(stmt)
23 | 	}
24 | 
25 | 	stmt = CopyIn(`table " name """`, `co"lumn""`)
26 | 	if stmt != `COPY "table "" name """"""" ("co""lumn""""") FROM STDIN` {
27 | 		t.Fatal(stmt)
28 | 	}
29 | }
30 | 
31 | func TestCopyInSchemaStmt(t *testing.T) {
32 | 	stmt := CopyInSchema("schema name", "table name")
33 | 	if stmt != `COPY "schema name"."table name" () FROM STDIN` {
34 | 		t.Fatal(stmt)
35 | 	}
36 | 
37 | 	stmt = CopyInSchema("schema name", "table name", "column 1", "column 2")
38 | 	if stmt != `COPY "schema name"."table name" ("column 1", "column 2") FROM STDIN` {
39 | 		t.Fatal(stmt)
40 | 	}
41 | 
42 | 	stmt = CopyInSchema(`schema " name """`, `table " name """`, `co"lumn""`)
43 | 	if stmt != `COPY "schema "" name """"""".`+
44 | 		`"table "" name """"""" ("co""lumn""""") FROM STDIN` {
45 | 		t.Fatal(stmt)
46 | 	}
47 | }
48 | 
49 | func TestCopyInMultipleValues(t *testing.T) {
50 | 	db := openTestConn(t)
51 | 	defer db.Close()
52 | 
53 | 	txn, err := db.Begin()
54 | 	if err != nil {
55 | 		t.Fatal(err)
56 | 	}
57 | 	defer txn.Rollback()
58 | 
59 | 	_, err = txn.Exec("CREATE TEMP TABLE temp (a int, b varchar)")
60 | 	if err != nil {
61 | 		t.Fatal(err)
62 | 	}
63 | 
64 | 	stmt, err := txn.Prepare(CopyIn("temp", "a", "b"))
65 | 	if err != nil {
66 | 		t.Fatal(err)
67 | 	}
68 | 
69 | 	longString := strings.Repeat("#", 500)
70 | 
71 | 	for i := 0; i < 500; i++ {
72 | 		_, err = stmt.Exec(int64(i), longString)
73 | 		if err != nil {
74 | 			t.Fatal(err)
75 | 		}
76 | 	}
77 | 
78 | 	result, err := stmt.Exec()
79 | 	if err != nil {
80 | 		t.Fatal(err)
81 | 	}
82 | 
83 | 	rowsAffected, err := result.RowsAffected()
84 | 	if err != nil {
85 | 		t.Fatal(err)
86 | 	}
87 | 
88 | 	if rowsAffected != 500 {
89 | 		t.Fatalf("expected 500 rows affected, not %d", rowsAffected)
90 | 	}
91 | 
92 | 	err = stmt.Close()
93 | 	if err != nil {
94 | 		t.Fatal(err)
95 | 	}
96 | 
97 | 	var num int
98 | 	err = txn.QueryRow("SELECT COUNT(*) FROM temp").Scan(&num)
99 | 	if err != nil {
100 | 		t.Fatal(err)
101 | 	}
102 | 
103 | 	if num != 500 {
104 | 		t.Fatalf("expected 500 items, not %d", num)
105 | 	}
106 | }
107 | 
108 | func TestCopyInRaiseStmtTrigger(t *testing.T) {
109 | 	db := openTestConn(t)
110 | 	defer db.Close()
111 | 
112 | 	if getServerVersion(t, db) < 90000 {
113 | 		var exists int
114 | 		err := db.QueryRow("SELECT 1 FROM pg_language WHERE lanname = 'plpgsql'").Scan(&exists)
115 | 		if err == sql.ErrNoRows {
116 | 			t.Skip("language PL/PgSQL does not exist; skipping TestCopyInRaiseStmtTrigger")
117 | 		} else if err != nil {
118 | 			t.Fatal(err)
119 | 		}
120 | 	}
121 | 
122 | 	txn, err := db.Begin()
123 | 	if err != nil {
124 | 		t.Fatal(err)
125 | 	}
126 | 	defer txn.Rollback()
127 | 
128 | 	_, err = txn.Exec("CREATE TEMP TABLE temp (a int, b varchar)")
129 | 	if err != nil {
130 | 		t.Fatal(err)
131 | 	}
132 | 
133 | 	_, err = txn.Exec(`
134 | 			CREATE OR REPLACE FUNCTION pg_temp.temptest()
135 | 			RETURNS trigger AS
136 | 			$BODY$ begin
137 | 				raise notice 'Hello world';
138 | 				return new;
139 | 			end $BODY$
140 | 			LANGUAGE plpgsql`)
141 | 	if err != nil {
142 | 		t.Fatal(err)
143 | 	}
144 | 
145 | 	_, err = txn.Exec(`
146 | 			CREATE TRIGGER temptest_trigger
147 | 			BEFORE INSERT
148 | 			ON temp
149 | 			FOR EACH ROW
150 | 			EXECUTE PROCEDURE pg_temp.temptest()`)
151 | 	if err != nil {
152 | 		t.Fatal(err)
153 | 	}
154 | 
155 | 	stmt, err := txn.Prepare(CopyIn("temp", "a", "b"))
156 | 	if err != nil {
157 | 		t.Fatal(err)
158 | 	}
159 | 
160 | 	longString := strings.Repeat("#", 500)
161 | 
162 | 	_, err = stmt.Exec(int64(1), longString)
163 | 	if err != nil {
164 | 		t.Fatal(err)
165 | 	}
166 | 
167 | 	_, err = stmt.Exec()
168 | 	if err != nil {
169 | 		t.Fatal(err)
170 | 	}
171 | 
172 | 	err = stmt.Close()
173 | 	if err != nil {
174 | 		t.Fatal(err)
175 | 	}
176 | 
177 | 	var num int
178 | 	err = txn.QueryRow("SELECT COUNT(*) FROM temp").Scan(&num)
179 | 	if err != nil {
180 | 		t.Fatal(err)
181 | 	}
182 | 
183 | 	if num != 1 {
184 | 		t.Fatalf("expected 1 items, not %d", num)
185 | 	}
186 | }
187 | 
188 | func TestCopyInTypes(t *testing.T) {
189 | 	db := openTestConn(t)
190 | 	defer db.Close()
191 | 
192 | 	txn, err := db.Begin()
193 | 	if err != nil {
194 | 		t.Fatal(err)
195 | 	}
196 | 	defer txn.Rollback()
197 | 
198 | 	_, err = txn.Exec("CREATE TEMP TABLE temp (num INTEGER, text VARCHAR, blob BYTEA, nothing VARCHAR)")
199 | 	if err != nil {
200 | 		t.Fatal(err)
201 | 	}
202 | 
203 | 	stmt, err := txn.Prepare(CopyIn("temp", "num", "text", "blob", "nothing"))
204 | 	if err != nil {
205 | 		t.Fatal(err)
206 | 	}
207 | 
208 | 	_, err = stmt.Exec(int64(1234567890), "Héllö\n ☃!\r\t\\", []byte{0, 255, 9, 10, 13}, nil)
209 | 	if err != nil {
210 | 		t.Fatal(err)
211 | 	}
212 | 
213 | 	_, err = stmt.Exec()
214 | 	if err != nil {
215 | 		t.Fatal(err)
216 | 	}
217 | 
218 | 	err = stmt.Close()
219 | 	if err != nil {
220 | 		t.Fatal(err)
221 | 	}
222 | 
223 | 	var num int
224 | 	var text string
225 | 	var blob []byte
226 | 	var nothing sql.NullString
227 | 
228 | 	err = txn.QueryRow("SELECT * FROM temp").Scan(&num, &text, &blob, &nothing)
229 | 	if err != nil {
230 | 		t.Fatal(err)
231 | 	}
232 | 
233 | 	if num != 1234567890 {
234 | 		t.Fatal("unexpected result", num)
235 | 	}
236 | 	if text != "Héllö\n ☃!\r\t\\" {
237 | 		t.Fatal("unexpected result", text)
238 | 	}
239 | 	if !bytes.Equal(blob, []byte{0, 255, 9, 10, 13}) {
240 | 		t.Fatal("unexpected result", blob)
241 | 	}
242 | 	if nothing.Valid {
243 | 		t.Fatal("unexpected result", nothing.String)
244 | 	}
245 | }
246 | 
247 | func TestCopyInWrongType(t *testing.T) {
248 | 	db := openTestConn(t)
249 | 	defer db.Close()
250 | 
251 | 	txn, err := db.Begin()
252 | 	if err != nil {
253 | 		t.Fatal(err)
254 | 	}
255 | 	defer txn.Rollback()
256 | 
257 | 	_, err = txn.Exec("CREATE TEMP TABLE temp (num INTEGER)")
258 | 	if err != nil {
259 | 		t.Fatal(err)
260 | 	}
261 | 
262 | 	stmt, err := txn.Prepare(CopyIn("temp", "num"))
263 | 	if err != nil {
264 | 		t.Fatal(err)
265 | 	}
266 | 	defer stmt.Close()
267 | 
268 | 	_, err = stmt.Exec("Héllö\n ☃!\r\t\\")
269 | 	if err != nil {
270 | 		t.Fatal(err)
271 | 	}
272 | 
273 | 	_, err = stmt.Exec()
274 | 	if err == nil {
275 | 		t.Fatal("expected error")
276 | 	}
277 | 	if pge, ok := err.(*Error); !ok || pge.Code.Name() != "invalid_text_representation" {
278 | 		t.Fatalf("expected 'invalid input syntax for integer' error, got %T (%+v)", err, err)
279 | 	}
280 | }
281 | 
282 | func TestCopyOutsideOfTxnError(t *testing.T) {
283 | 	db := openTestConn(t)
284 | 	defer db.Close()
285 | 
286 | 	_, err := db.Prepare(CopyIn("temp", "num"))
287 | 	if err == nil {
288 | 		t.Fatal("COPY outside of transaction did not return an error")
289 | 	}
290 | 	if err != errCopyNotSupportedOutsideTxn {
291 | 		t.Fatalf("expected %s, got %s", errCopyNotSupportedOutsideTxn, err)
292 | 	}
293 | }
294 | 
295 | func TestCopyInBinaryError(t *testing.T) {
296 | 	db := openTestConn(t)
297 | 	defer db.Close()
298 | 
299 | 	txn, err := db.Begin()
300 | 	if err != nil {
301 | 		t.Fatal(err)
302 | 	}
303 | 	defer txn.Rollback()
304 | 
305 | 	_, err = txn.Exec("CREATE TEMP TABLE temp (num INTEGER)")
306 | 	if err != nil {
307 | 		t.Fatal(err)
308 | 	}
309 | 	_, err = txn.Prepare("COPY temp (num) FROM STDIN WITH binary")
310 | 	if err != errBinaryCopyNotSupported {
311 | 		t.Fatalf("expected %s, got %+v", errBinaryCopyNotSupported, err)
312 | 	}
313 | 	// check that the protocol is in a valid state
314 | 	err = txn.Rollback()
315 | 	if err != nil {
316 | 		t.Fatal(err)
317 | 	}
318 | }
319 | 
320 | func TestCopyFromError(t *testing.T) {
321 | 	db := openTestConn(t)
322 | 	defer db.Close()
323 | 
324 | 	txn, err := db.Begin()
325 | 	if err != nil {
326 | 		t.Fatal(err)
327 | 	}
328 | 	defer txn.Rollback()
329 | 
330 | 	_, err = txn.Exec("CREATE TEMP TABLE temp (num INTEGER)")
331 | 	if err != nil {
332 | 		t.Fatal(err)
333 | 	}
334 | 	_, err = txn.Prepare("COPY temp (num) TO STDOUT")
335 | 	if err != errCopyToNotSupported {
336 | 		t.Fatalf("expected %s, got %+v", errCopyToNotSupported, err)
337 | 	}
338 | 	// check that the protocol is in a valid state
339 | 	err = txn.Rollback()
340 | 	if err != nil {
341 | 		t.Fatal(err)
342 | 	}
343 | }
344 | 
345 | func TestCopySyntaxError(t *testing.T) {
346 | 	db := openTestConn(t)
347 | 	defer db.Close()
348 | 
349 | 	txn, err := db.Begin()
350 | 	if err != nil {
351 | 		t.Fatal(err)
352 | 	}
353 | 	defer txn.Rollback()
354 | 
355 | 	_, err = txn.Prepare("COPY ")
356 | 	if err == nil {
357 | 		t.Fatal("expected error")
358 | 	}
359 | 	if pge, ok := err.(*Error); !ok || pge.Code.Name() != "syntax_error" {
360 | 		t.Fatalf("expected syntax error, got %T (%+v)", err, err)
361 | 	}
362 | 	// check that the protocol is in a valid state
363 | 	err = txn.Rollback()
364 | 	if err != nil {
365 | 		t.Fatal(err)
366 | 	}
367 | }
368 | 
369 | // Tests for connection errors in copyin.resploop()
370 | func TestCopyRespLoopConnectionError(t *testing.T) {
371 | 	db := openTestConn(t)
372 | 	defer db.Close()
373 | 
374 | 	txn, err := db.Begin()
375 | 	if err != nil {
376 | 		t.Fatal(err)
377 | 	}
378 | 	defer txn.Rollback()
379 | 
380 | 	var pid int
381 | 	err = txn.QueryRow("SELECT pg_backend_pid()").Scan(&pid)
382 | 	if err != nil {
383 | 		t.Fatal(err)
384 | 	}
385 | 
386 | 	_, err = txn.Exec("CREATE TEMP TABLE temp (a int)")
387 | 	if err != nil {
388 | 		t.Fatal(err)
389 | 	}
390 | 
391 | 	stmt, err := txn.Prepare(CopyIn("temp", "a"))
392 | 	if err != nil {
393 | 		t.Fatal(err)
394 | 	}
395 | 	defer stmt.Close()
396 | 
397 | 	_, err = db.Exec("SELECT pg_terminate_backend($1)", pid)
398 | 	if err != nil {
399 | 		t.Fatal(err)
400 | 	}
401 | 
402 | 	if getServerVersion(t, db) < 90500 {
403 | 		// We have to try and send something over, since postgres before
404 | 		// version 9.5 won't process SIGTERMs while it's waiting for
405 | 		// CopyData/CopyDone messages; see tcop/postgres.c.
406 | 		_, err = stmt.Exec(1)
407 | 		if err != nil {
408 | 			t.Fatal(err)
409 | 		}
410 | 	}
411 | 	retry(t, time.Second*5, func() error {
412 | 		_, err = stmt.Exec()
413 | 		if err == nil {
414 | 			return fmt.Errorf("expected error")
415 | 		}
416 | 		return nil
417 | 	})
418 | 	switch pge := err.(type) {
419 | 	case *Error:
420 | 		if pge.Code.Name() != "admin_shutdown" {
421 | 			t.Fatalf("expected admin_shutdown, got %s", pge.Code.Name())
422 | 		}
423 | 	case *net.OpError:
424 | 		// ignore
425 | 	default:
426 | 		if err == driver.ErrBadConn {
427 | 			// likely an EPIPE
428 | 		} else if err == errCopyInClosed {
429 | 			// ignore
430 | 		} else {
431 | 			t.Fatalf("unexpected error, got %+#v", err)
432 | 		}
433 | 	}
434 | 
435 | 	_ = stmt.Close()
436 | }
437 | 
438 | // retry executes f in a backoff loop until it doesn't return an error. If this
439 | // doesn't happen within duration, t.Fatal is called with the latest error.
440 | func retry(t *testing.T, duration time.Duration, f func() error) {
441 | 	t.Helper()
442 | 	start := time.Now()
443 | 	next := time.Millisecond * 100
444 | 	for {
445 | 		err := f()
446 | 		if err == nil {
447 | 			return
448 | 		}
449 | 		if time.Since(start) > duration {
450 | 			t.Fatal(err)
451 | 		}
452 | 		time.Sleep(next)
453 | 		next *= 2
454 | 	}
455 | }
456 | 
457 | // BenchmarkCopyIn measures streaming b.N rows through a CopyIn statement,
458 | // including the final flush; connection/table setup and verification are excluded.
459 | func BenchmarkCopyIn(b *testing.B) {
460 | 	db := openTestConn(b)
461 | 	defer db.Close()
462 | 
463 | 	txn, err := db.Begin()
464 | 	if err != nil {
465 | 		b.Fatal(err)
466 | 	}
467 | 	defer txn.Rollback()
468 | 
469 | 	_, err = txn.Exec("CREATE TEMP TABLE temp (a int, b varchar)")
470 | 	if err != nil {
471 | 		b.Fatal(err)
472 | 	}
473 | 
474 | 	stmt, err := txn.Prepare(CopyIn("temp", "a", "b"))
475 | 	if err != nil {
476 | 		b.Fatal(err)
477 | 	}
478 | 
479 | 	// Only the COPY work below is measured, not the setup above.
480 | 	b.ResetTimer()
481 | 	for i := 0; i < b.N; i++ {
482 | 		_, err = stmt.Exec(int64(i), "hello world!")
483 | 		if err != nil {
484 | 			b.Fatal(err)
485 | 		}
486 | 	}
487 | 
488 | 	_, err = stmt.Exec()
489 | 	if err != nil {
490 | 		b.Fatal(err)
491 | 	}
492 | 
493 | 	err = stmt.Close()
494 | 	if err != nil {
495 | 		b.Fatal(err)
496 | 	}
497 | 
498 | 	// The row-count verification is not part of the measured work.
499 | 	b.StopTimer()
500 | 	var num int
501 | 	err = txn.QueryRow("SELECT COUNT(*) FROM temp").Scan(&num)
502 | 	if err != nil {
503 | 		b.Fatal(err)
504 | 	}
505 | 
506 | 	if num != b.N {
507 | 		b.Fatalf("expected %d items, not %d", b.N, num)
508 | 	}
509 | }
510 | 
511 | // bigTableColumns is a long list of column names used by BenchmarkCopy.
512 | var bigTableColumns = []string{"ABIOGENETICALLY", "ABORIGINALITIES", "ABSORBABILITIES", "ABSORBEFACIENTS", "ABSORPTIOMETERS", "ABSTRACTIONISMS", "ABSTRACTIONISTS", "ACANTHOCEPHALAN", "ACCEPTABILITIES", "ACCEPTINGNESSES", "ACCESSARINESSES", "ACCESSIBILITIES", "ACCESSORINESSES", "ACCIDENTALITIES", "ACCIDENTOLOGIES", "ACCLIMATISATION", "ACCLIMATIZATION", "ACCOMMODATINGLY", "ACCOMMODATIONAL", "ACCOMPLISHMENTS", "ACCOUNTABLENESS", "ACCOUNTANTSHIPS", "ACCULTURATIONAL", "ACETOPHENETIDIN", "ACETYLSALICYLIC", "ACHONDROPLASIAS", "ACHONDROPLASTIC", "ACHROMATICITIES", "ACHROMATISATION", "ACHROMATIZATION", "ACIDIMETRICALLY", "ACKNOWLEDGEABLE", "ACKNOWLEDGEABLY", "ACKNOWLEDGEMENT", "ACKNOWLEDGMENTS", "ACQUIRABILITIES", "ACQUISITIVENESS", "ACRIMONIOUSNESS", "ACROPARESTHESIA", "ACTINOBIOLOGIES", "ACTINOCHEMISTRY", "ACTINOTHERAPIES", "ADAPTABLENESSES", "ADDITIONALITIES", "ADENOCARCINOMAS", "ADENOHYPOPHYSES", "ADENOHYPOPHYSIS", "ADENOIDECTOMIES", "ADIATHERMANCIES", "ADJUSTABILITIES", "ADMINISTRATIONS", "ADMIRABLENESSES", "ADMISSIBILITIES", "ADRENALECTOMIES", "ADSORBABILITIES", "ADVENTUROUSNESS", "ADVERSARINESSES", "ADVISABLENESSES", "AERODYNAMICALLY", "AERODYNAMICISTS", "AEROELASTICIANS", "AEROHYDROPLANES", "AEROLITHOLOGIES", "AEROSOLISATIONS", "AEROSOLIZATIONS", "AFFECTABILITIES", "AFFECTIVENESSES", "AFFORDABILITIES", "AFFRANCHISEMENT", "AFTERSENSATIONS", "AGGLUTINABILITY", "AGGRANDISEMENTS", "AGGRANDIZEMENTS", "AGGREGATENESSES", "AGRANULOCYTOSES", "AGRANULOCYTOSIS", "AGREEABLENESSES", "AGRIBUSINESSMAN", "AGRIBUSINESSMEN", "AGRICULTURALIST", "AIRWORTHINESSES", "ALCOHOLISATIONS", "ALCOHOLIZATIONS", "ALCOHOLOMETRIES", "ALEXIPHARMAKONS", "ALGORITHMICALLY", "ALKALINISATIONS", "ALKALINIZATIONS", "ALLEGORICALNESS", "ALLEGORISATIONS", "ALLEGORIZATIONS", "ALLELOMORPHISMS", "ALLERGENICITIES", "ALLOTETRAPLOIDS", "ALLOTETRAPLOIDY", "ALLOTRIOMORPHIC", "ALLOWABLENESSES", "ALPHABETISATION", "ALPHABETIZATION", "ALTERNATIVENESS", "ALTITUDINARIANS", "ALUMINOSILICATE", "ALUMINOTHERMIES", "AMARYLLIDACEOUS", "AMBASSADORSHIPS", "AMBIDEXTERITIES", "AMBIGUOUSNESSES", "AMBISEXUALITIES", "AMBITIOUSNESSES", "AMINOPEPTIDASES", "AMINOPHENAZONES", "AMMONIFICATIONS", "AMORPHOUSNESSES", "AMPHIDIPLOIDIES", "AMPHITHEATRICAL", "ANACOLUTHICALLY", "ANACREONTICALLY", "ANAESTHESIOLOGY", "ANAESTHETICALLY", "ANAGRAMMATISING", "ANAGRAMMATIZING", "ANALOGOUSNESSES", "ANALYZABILITIES", "ANAMORPHOSCOPES", "ANCYLOSTOMIASES", "ANCYLOSTOMIASIS", "ANDROGYNOPHORES", "ANDROMEDOTOXINS", "ANDROMONOECIOUS", "ANDROMONOECISMS", "ANESTHETIZATION", "ANFRACTUOSITIES", "ANGUSTIROSTRATE", "ANIMATRONICALLY", "ANISOTROPICALLY", "ANKYLOSTOMIASES", "ANKYLOSTOMIASIS", "ANNIHILATIONISM", "ANOMALISTICALLY", "ANOMALOUSNESSES", "ANONYMOUSNESSES", "ANSWERABILITIES", "ANTAGONISATIONS", "ANTAGONIZATIONS", "ANTAPHRODISIACS", "ANTEPENULTIMATE", "ANTHROPOBIOLOGY", "ANTHROPOCENTRIC", "ANTHROPOGENESES", "ANTHROPOGENESIS", "ANTHROPOGENETIC", "ANTHROPOLATRIES", "ANTHROPOLOGICAL", "ANTHROPOLOGISTS", "ANTHROPOMETRIES", "ANTHROPOMETRIST", "ANTHROPOMORPHIC", "ANTHROPOPATHIES", "ANTHROPOPATHISM", "ANTHROPOPHAGIES", "ANTHROPOPHAGITE", "ANTHROPOPHAGOUS", "ANTHROPOPHOBIAS", "ANTHROPOPHOBICS", "ANTHROPOPHUISMS", "ANTHROPOPSYCHIC", "ANTHROPOSOPHIES", "ANTHROPOSOPHIST", "ANTIABORTIONIST", "ANTIALCOHOLISMS", "ANTIAPHRODISIAC", "ANTIARRHYTHMICS", "ANTICAPITALISMS", "ANTICAPITALISTS", "ANTICARCINOGENS", "ANTICHOLESTEROL", "ANTICHOLINERGIC", "ANTICHRISTIANLY", "ANTICLERICALISM", "ANTICLIMACTICAL", "ANTICOINCIDENCE", "ANTICOLONIALISM", "ANTICOLONIALIST", "ANTICOMPETITIVE", "ANTICONVULSANTS", "ANTICONVULSIVES", "ANTIDEPRESSANTS", "ANTIDERIVATIVES", "ANTIDEVELOPMENT", "ANTIEDUCATIONAL", "ANTIEGALITARIAN", "ANTIFASHIONABLE", "ANTIFEDERALISTS", "ANTIFERROMAGNET", "ANTIFORECLOSURE", "ANTIHELMINTHICS", "ANTIHISTAMINICS", "ANTILIBERALISMS", "ANTILIBERTARIAN", "ANTILOGARITHMIC", "ANTIMATERIALISM", "ANTIMATERIALIST", "ANTIMETABOLITES", "ANTIMILITARISMS", "ANTIMILITARISTS", "ANTIMONARCHICAL", "ANTIMONARCHISTS", "ANTIMONOPOLISTS", "ANTINATIONALIST", "ANTINUCLEARISTS", "ANTIODONTALGICS", "ANTIPERISTALSES", "ANTIPERISTALSIS", "ANTIPERISTALTIC", "ANTIPERSPIRANTS", "ANTIPHLOGISTICS", "ANTIPORNOGRAPHY", "ANTIPROGRESSIVE", "ANTIQUARIANISMS", "ANTIRADICALISMS", "ANTIRATIONALISM", "ANTIRATIONALIST", "ANTIRATIONALITY", "ANTIREPUBLICANS", "ANTIROMANTICISM", "ANTISEGREGATION", "ANTISENTIMENTAL", "ANTISEPARATISTS", "ANTISEPTICISING", "ANTISEPTICIZING", "ANTISEXUALITIES", "ANTISHOPLIFTING", "ANTISOCIALITIES", "ANTISPECULATION", "ANTISPECULATIVE", "ANTISYPHILITICS", "ANTITHEORETICAL", "ANTITHROMBOTICS", "ANTITRADITIONAL", "ANTITRANSPIRANT", "ANTITRINITARIAN", "ANTITUBERCULOUS", "ANTIVIVISECTION", "APHELIOTROPISMS", "APOCALYPTICALLY", "APOCALYPTICISMS", "APOLIPOPROTEINS", "APOLITICALITIES", "APOPHTHEGMATISE", "APOPHTHEGMATIST", "APOPHTHEGMATIZE", "APOTHEGMATISING", "APOTHEGMATIZING", "APPEALABILITIES", "APPEALINGNESSES", "APPENDICULARIAN", "APPLICABILITIES", "APPRENTICEHOODS", "APPRENTICEMENTS", "APPRENTICESHIPS", "APPROACHABILITY", "APPROPINQUATING", "APPROPINQUATION", "APPROPINQUITIES", "APPROPRIATENESS", "ARACHNOIDITISES", "ARBITRARINESSES", "ARBORICULTURIST", "ARCHAEBACTERIUM", "ARCHAEOBOTANIES", "ARCHAEOBOTANIST", "ARCHAEOMETRISTS", "ARCHAEOPTERYXES", "ARCHAEZOOLOGIES", "ARCHEOASTRONOMY", "ARCHEOBOTANISTS", "ARCHEOLOGICALLY", "ARCHEOMAGNETISM", "ARCHEOZOOLOGIES", "ARCHEOZOOLOGIST", "ARCHGENETHLIACS", "ARCHIDIACONATES", "ARCHIEPISCOPACY", "ARCHIEPISCOPATE", "ARCHITECTURALLY", "ARCHPRIESTHOODS", "ARCHPRIESTSHIPS", "ARGUMENTATIVELY", "ARIBOFLAVINOSES", "ARIBOFLAVINOSIS", "AROMATHERAPISTS", "ARRONDISSEMENTS", "ARTERIALISATION", "ARTERIALIZATION", "ARTERIOGRAPHIES", "ARTIFICIALISING", "ARTIFICIALITIES", "ARTIFICIALIZING", "ASCLEPIADACEOUS", "ASSENTIVENESSES"}
513 | 
514 | // BenchmarkCopy measures building a CopyIn statement for bigTableColumns.
515 | func BenchmarkCopy(b *testing.B) {
516 | 	for i := 0; i < b.N; i++ {
517 | 		CopyIn("temp", bigTableColumns...)
518 | 	}
519 | }
520 | 
--------------------------------------------------------------------------------
/doc.go:
--------------------------------------------------------------------------------
1 | /*
2 | Package pq is a pure Go Postgres driver for the database/sql package.
3 | 
4 | In most cases clients will use the database/sql package instead of
For example: 6 | 7 | import ( 8 | "database/sql" 9 | 10 | _ "github.com/lib/pq" 11 | ) 12 | 13 | func main() { 14 | connStr := "user=pqgotest dbname=pqgotest sslmode=verify-full" 15 | db, err := sql.Open("postgres", connStr) 16 | if err != nil { 17 | log.Fatal(err) 18 | } 19 | 20 | age := 21 21 | rows, err := db.Query("SELECT name FROM users WHERE age = $1", age) 22 | … 23 | } 24 | 25 | You can also connect to a database using a URL. For example: 26 | 27 | connStr := "postgres://pqgotest:password@localhost/pqgotest?sslmode=verify-full" 28 | db, err := sql.Open("postgres", connStr) 29 | 30 | 31 | Connection String Parameters 32 | 33 | 34 | Similarly to libpq, when establishing a connection using pq you are expected to 35 | supply a connection string containing zero or more parameters. 36 | A subset of the connection parameters supported by libpq is also supported by pq. 37 | Additionally, pq also lets you specify run-time parameters (such as search_path or work_mem) 38 | directly in the connection string. This is different from libpq, which does not allow 39 | run-time parameters in the connection string, instead requiring you to supply 40 | them in the options parameter. 41 | 42 | For compatibility with libpq, the following special connection parameters are 43 | supported: 44 | 45 | * dbname - The name of the database to connect to 46 | * user - The user to sign in as 47 | * password - The user's password 48 | * host - The host to connect to. Values that start with / are for unix 49 | domain sockets. (default is localhost) 50 | * port - The port to connect to. (default is 5432) 51 | * sslmode - Whether or not to use SSL (default is require, this is not 52 | the default for libpq) 53 | * fallback_application_name - An application_name to fall back to if one isn't provided. 54 | * connect_timeout - Maximum wait for connection, in seconds. Zero or 55 | not specified means wait indefinitely. 56 | * sslcert - Cert file location. The file must contain PEM encoded data.
57 | * sslkey - Key file location. The file must contain PEM encoded data. 58 | * sslrootcert - The location of the root certificate file. The file 59 | must contain PEM encoded data. 60 | 61 | Valid values for sslmode are: 62 | 63 | * disable - No SSL 64 | * require - Always SSL (skip verification) 65 | * verify-ca - Always SSL (verify that the certificate presented by the 66 | server was signed by a trusted CA) 67 | * verify-full - Always SSL (verify that the certificate presented by 68 | the server was signed by a trusted CA and the server host name 69 | matches the one in the certificate) 70 | 71 | See http://www.postgresql.org/docs/current/static/libpq-connect.html#LIBPQ-CONNSTRING 72 | for more information about connection string parameters. 73 | 74 | Use single quotes for values that contain whitespace: 75 | 76 | "user=pqgotest password='with spaces'" 77 | 78 | A backslash will escape the next character in values: 79 | 80 | "user=space\ man password='it\'s valid'" 81 | 82 | Note that the connection parameter client_encoding (which sets the 83 | text encoding for the connection) may be set but must be "UTF8", 84 | matching with the same rules as Postgres. It is an error to provide 85 | any other value. 86 | 87 | In addition to the parameters listed above, any run-time parameter that can be 88 | set at backend start time can be set in the connection string. For more 89 | information, see 90 | http://www.postgresql.org/docs/current/static/runtime-config.html. 91 | 92 | Most environment variables as specified at http://www.postgresql.org/docs/current/static/libpq-envars.html 93 | supported by libpq are also supported by pq. If any of the environment 94 | variables not supported by pq are set, pq will panic during connection 95 | establishment. Environment variables have a lower precedence than explicitly 96 | provided connection parameters.
97 | 98 | The pgpass mechanism as described in http://www.postgresql.org/docs/current/static/libpq-pgpass.html 99 | is supported, but on Windows PGPASSFILE must be specified explicitly. 100 | 101 | 102 | Queries 103 | 104 | 105 | database/sql does not dictate any specific format for parameter 106 | markers in query strings, and pq uses the Postgres-native ordinal markers, 107 | as shown above. The same marker can be reused for the same parameter: 108 | 109 | rows, err := db.Query(`SELECT name FROM users WHERE favorite_fruit = $1 110 | OR age BETWEEN $2 AND $2 + 3`, "orange", 64) 111 | 112 | pq does not support the LastInsertId() method of the Result type in database/sql. 113 | To return the identifier of an INSERT (or UPDATE or DELETE), use the Postgres 114 | RETURNING clause with a standard Query or QueryRow call: 115 | 116 | var userid int 117 | err := db.QueryRow(`INSERT INTO users(name, favorite_fruit, age) 118 | VALUES('beatrice', 'starfruit', 93) RETURNING id`).Scan(&userid) 119 | 120 | For more details on RETURNING, see the Postgres documentation: 121 | 122 | http://www.postgresql.org/docs/current/static/sql-insert.html 123 | http://www.postgresql.org/docs/current/static/sql-update.html 124 | http://www.postgresql.org/docs/current/static/sql-delete.html 125 | 126 | For additional instructions on querying see the documentation for the database/sql package. 127 | 128 | 129 | Data Types 130 | 131 | 132 | Parameters pass through driver.DefaultParameterConverter before they are handled 133 | by this package. When the binary_parameters connection option is enabled, 134 | []byte values are sent directly to the backend as data in binary format. 
135 | 136 | This package returns the following types for values from the PostgreSQL backend: 137 | 138 | - integer types smallint, integer, and bigint are returned as int64 139 | - floating-point types real and double precision are returned as float64 140 | - character types char, varchar, and text are returned as string 141 | - temporal types date, time, timetz, timestamp, and timestamptz are 142 | returned as time.Time 143 | - the boolean type is returned as bool 144 | - the bytea type is returned as []byte 145 | 146 | All other types are returned directly from the backend as []byte values in text format. 147 | 148 | 149 | Errors 150 | 151 | 152 | pq may return errors of type *pq.Error which can be interrogated for error details: 153 | 154 | if err, ok := err.(*pq.Error); ok { 155 | fmt.Println("pq error:", err.Code.Name()) 156 | } 157 | 158 | See the pq.Error type for details. 159 | 160 | 161 | Bulk imports 162 | 163 | You can perform bulk imports by preparing a statement returned by pq.CopyIn (or 164 | pq.CopyInSchema) in an explicit transaction (sql.Tx). The returned statement 165 | handle can then be repeatedly "executed" to copy data into the target table. 166 | After all data has been processed you should call Exec() once with no arguments 167 | to flush all buffered data. Any call to Exec() might return an error which 168 | should be handled appropriately, but because of the internal buffering an error 169 | returned by Exec() might not be related to the data passed in the call that 170 | failed. 171 | 172 | CopyIn uses COPY FROM internally. It is not possible to COPY outside of an 173 | explicit transaction in pq. 
174 | 175 | Usage example: 176 | 177 | txn, err := db.Begin() 178 | if err != nil { 179 | log.Fatal(err) 180 | } 181 | 182 | stmt, err := txn.Prepare(pq.CopyIn("users", "name", "age")) 183 | if err != nil { 184 | log.Fatal(err) 185 | } 186 | 187 | for _, user := range users { 188 | _, err = stmt.Exec(user.Name, int64(user.Age)) 189 | if err != nil { 190 | log.Fatal(err) 191 | } 192 | } 193 | 194 | _, err = stmt.Exec() 195 | if err != nil { 196 | log.Fatal(err) 197 | } 198 | 199 | err = stmt.Close() 200 | if err != nil { 201 | log.Fatal(err) 202 | } 203 | 204 | err = txn.Commit() 205 | if err != nil { 206 | log.Fatal(err) 207 | } 208 | 209 | 210 | Notifications 211 | 212 | 213 | PostgreSQL supports a simple publish/subscribe model over database 214 | connections. See http://www.postgresql.org/docs/current/static/sql-notify.html 215 | for more information about the general mechanism. 216 | 217 | To start listening for notifications, you first have to open a new connection 218 | to the database by calling NewListener. This connection can not be used for 219 | anything other than LISTEN / NOTIFY. Calling Listen will open a "notification 220 | channel"; once a notification channel is open, a notification generated on that 221 | channel will effect a send on the Listener.Notify channel. A notification 222 | channel will remain open until Unlisten is called, though connection loss might 223 | result in some notifications being lost. To solve this problem, Listener sends 224 | a nil pointer over the Notify channel any time the connection is re-established 225 | following a connection loss. The application can get information about the 226 | state of the underlying connection by setting an event callback in the call to 227 | NewListener. 228 | 229 | A single Listener can safely be used from concurrent goroutines, which means 230 | that there is often no need to create more than one Listener in your 231 | application. 
However, a Listener is always connected to a single database, so 232 | you will need to create a new Listener instance for every database you want to 233 | receive notifications in. 234 | 235 | The channel name in both Listen and Unlisten is case sensitive, and can contain 236 | any characters legal in an identifier (see 237 | http://www.postgresql.org/docs/current/static/sql-syntax-lexical.html#SQL-SYNTAX-IDENTIFIERS 238 | for more information). Note that the channel name will be truncated to 63 239 | bytes by the PostgreSQL server. 240 | 241 | You can find a complete, working example of Listener usage at 242 | https://godoc.org/github.com/lib/pq/example/listen. 243 | 244 | 245 | Kerberos Support 246 | 247 | 248 | If you need support for Kerberos authentication, add the following to your main 249 | package: 250 | 251 | import "github.com/lib/pq/auth/kerberos" 252 | 253 | func init() { 254 | pq.RegisterGSSProvider(func() (pq.Gss, error) { return kerberos.NewGSS() }) 255 | } 256 | 257 | This package is in a separate module so that users who don't need Kerberos 258 | don't have to download unnecessary dependencies. 259 | 260 | When imported, additional connection string parameters are supported: 261 | 262 | * krbsrvname - GSS (Kerberos) service name when constructing the 263 | SPN (default is `postgres`). This will be combined with the host 264 | to form the full SPN: `krbsrvname/host`. 265 | * krbspn - GSS (Kerberos) SPN. This takes priority over 266 | `krbsrvname` if present. 267 | */ 268 | package pq 269 | -------------------------------------------------------------------------------- /example/listen/doc.go: -------------------------------------------------------------------------------- 1 | /* 2 | 3 | Package listen is a self-contained Go program which uses the LISTEN / NOTIFY 4 | mechanism to avoid polling the database while waiting for more work to arrive. 
5 | 6 | // 7 | // You can see the program in action by defining a function similar to 8 | // the following: 9 | // 10 | // CREATE OR REPLACE FUNCTION public.get_work() 11 | // RETURNS bigint 12 | // LANGUAGE sql 13 | // AS $$ 14 | // SELECT CASE WHEN random() >= 0.2 THEN int8 '1' END 15 | // $$ 16 | // ; 17 | 18 | package main 19 | 20 | import ( 21 | "database/sql" 22 | "fmt" 23 | "time" 24 | 25 | "github.com/lib/pq" 26 | ) 27 | 28 | func doWork(db *sql.DB, work int64) { 29 | // work here 30 | } 31 | 32 | func getWork(db *sql.DB) { 33 | for { 34 | // get work from the database here 35 | var work sql.NullInt64 36 | err := db.QueryRow("SELECT get_work()").Scan(&work) 37 | if err != nil { 38 | fmt.Println("call to get_work() failed: ", err) 39 | time.Sleep(10 * time.Second) 40 | continue 41 | } 42 | if !work.Valid { 43 | // no more work to do 44 | fmt.Println("ran out of work") 45 | return 46 | } 47 | 48 | fmt.Println("starting work on ", work.Int64) 49 | go doWork(db, work.Int64) 50 | } 51 | } 52 | 53 | func waitForNotification(l *pq.Listener) { 54 | select { 55 | case <-l.Notify: 56 | fmt.Println("received notification, new work available") 57 | case <-time.After(90 * time.Second): 58 | go l.Ping() 59 | // Check if there's more work available, just in case it takes 60 | // a while for the Listener to notice connection loss and 61 | // reconnect. 
62 | fmt.Println("received no work for 90 seconds, checking for new work") 63 | } 64 | } 65 | 66 | func main() { 67 | var conninfo string = "" 68 | 69 | db, err := sql.Open("postgres", conninfo) 70 | if err != nil { 71 | panic(err) 72 | } 73 | 74 | reportProblem := func(ev pq.ListenerEventType, err error) { 75 | if err != nil { 76 | fmt.Println(err.Error()) 77 | } 78 | } 79 | 80 | minReconn := 10 * time.Second 81 | maxReconn := time.Minute 82 | listener := pq.NewListener(conninfo, minReconn, maxReconn, reportProblem) 83 | err = listener.Listen("getwork") 84 | if err != nil { 85 | panic(err) 86 | } 87 | 88 | fmt.Println("entering main loop") 89 | for { 90 | // process all available work before waiting for notifications 91 | getWork(db) 92 | waitForNotification(listener) 93 | } 94 | } 95 | 96 | 97 | */ 98 | package listen 99 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/lib/pq 2 | 3 | go 1.13 4 | -------------------------------------------------------------------------------- /go18_test.go: -------------------------------------------------------------------------------- 1 | package pq 2 | 3 | import ( 4 | "context" 5 | "database/sql" 6 | "database/sql/driver" 7 | "errors" 8 | "runtime" 9 | "strings" 10 | "testing" 11 | "time" 12 | ) 13 | 14 | func TestMultipleSimpleQuery(t *testing.T) { 15 | db := openTestConn(t) 16 | defer db.Close() 17 | 18 | rows, err := db.Query("select 1; set time zone default; select 2; select 3") 19 | if err != nil { 20 | t.Fatal(err) 21 | } 22 | defer rows.Close() 23 | 24 | var i int 25 | for rows.Next() { 26 | if err := rows.Scan(&i); err != nil { 27 | t.Fatal(err) 28 | } 29 | if i != 1 { 30 | t.Fatalf("expected 1, got %d", i) 31 | } 32 | } 33 | if !rows.NextResultSet() { 34 | t.Fatal("expected more result sets", rows.Err()) 35 | } 36 | for rows.Next() { 37 | if err := rows.Scan(&i); err != nil 
{ 38 | t.Fatal(err) 39 | } 40 | if i != 2 { 41 | t.Fatalf("expected 2, got %d", i) 42 | } 43 | } 44 | 45 | // Make sure that if we ignore a result we can still query. 46 | 47 | rows, err = db.Query("select 4; select 5") 48 | if err != nil { 49 | t.Fatal(err) 50 | } 51 | defer rows.Close() 52 | 53 | for rows.Next() { 54 | if err := rows.Scan(&i); err != nil { 55 | t.Fatal(err) 56 | } 57 | if i != 4 { 58 | t.Fatalf("expected 4, got %d", i) 59 | } 60 | } 61 | if !rows.NextResultSet() { 62 | t.Fatal("expected more result sets", rows.Err()) 63 | } 64 | for rows.Next() { 65 | if err := rows.Scan(&i); err != nil { 66 | t.Fatal(err) 67 | } 68 | if i != 5 { 69 | t.Fatalf("expected 5, got %d", i) 70 | } 71 | } 72 | if rows.NextResultSet() { 73 | t.Fatal("unexpected result set") 74 | } 75 | } 76 | 77 | const contextRaceIterations = 100 78 | 79 | const cancelErrorCode ErrorCode = "57014" 80 | 81 | func TestContextCancelExec(t *testing.T) { 82 | db := openTestConn(t) 83 | defer db.Close() 84 | 85 | ctx, cancel := context.WithCancel(context.Background()) 86 | 87 | // Delay execution for just a bit until db.ExecContext has begun. 88 | defer time.AfterFunc(time.Millisecond*10, cancel).Stop() 89 | 90 | // Not canceled until after the exec has started. 91 | if _, err := db.ExecContext(ctx, "select pg_sleep(1)"); err == nil { 92 | t.Fatal("expected error") 93 | } else if pgErr := (*Error)(nil); !(errors.As(err, &pgErr) && pgErr.Code == cancelErrorCode) { 94 | t.Fatalf("unexpected error: %s", err) 95 | } 96 | 97 | // Context is already canceled, so error should come before execution. 
98 | if _, err := db.ExecContext(ctx, "select pg_sleep(1)"); err == nil { 99 | t.Fatal("expected error") 100 | } else if err.Error() != "context canceled" { 101 | t.Fatalf("unexpected error: %s", err) 102 | } 103 | 104 | for i := 0; i < contextRaceIterations; i++ { 105 | func() { 106 | ctx, cancel := context.WithCancel(context.Background()) 107 | defer cancel() 108 | if _, err := db.ExecContext(ctx, "select 1"); err != nil { 109 | t.Fatal(err) 110 | } 111 | }() 112 | 113 | if _, err := db.Exec("select 1"); err != nil { 114 | t.Fatal(err) 115 | } 116 | } 117 | } 118 | 119 | func TestContextCancelQuery(t *testing.T) { 120 | db := openTestConn(t) 121 | defer db.Close() 122 | 123 | ctx, cancel := context.WithCancel(context.Background()) 124 | 125 | // Delay execution for just a bit until db.QueryContext has begun. 126 | defer time.AfterFunc(time.Millisecond*10, cancel).Stop() 127 | 128 | // Not canceled until after the exec has started. 129 | if _, err := db.QueryContext(ctx, "select pg_sleep(1)"); err == nil { 130 | t.Fatal("expected error") 131 | } else if pgErr := (*Error)(nil); !(errors.As(err, &pgErr) && pgErr.Code == cancelErrorCode) { 132 | t.Fatalf("unexpected error: %s", err) 133 | } 134 | 135 | // Context is already canceled, so error should come before execution. 
136 | if _, err := db.QueryContext(ctx, "select pg_sleep(1)"); err == nil { 137 | t.Fatal("expected error") 138 | } else if err.Error() != "context canceled" { 139 | t.Fatalf("unexpected error: %s", err) 140 | } 141 | 142 | for i := 0; i < contextRaceIterations; i++ { 143 | func() { 144 | ctx, cancel := context.WithCancel(context.Background()) 145 | rows, err := db.QueryContext(ctx, "select 1") 146 | cancel() 147 | if err != nil { 148 | t.Fatal(err) 149 | } else if err := rows.Close(); err != nil && err != driver.ErrBadConn && err != context.Canceled { 150 | t.Fatal(err) 151 | } 152 | }() 153 | 154 | if rows, err := db.Query("select 1"); err != nil { 155 | t.Fatal(err) 156 | } else if err := rows.Close(); err != nil { 157 | t.Fatal(err) 158 | } 159 | } 160 | } 161 | 162 | // TestIssue617 tests that a failed query in QueryContext doesn't lead to a 163 | // goroutine leak. 164 | func TestIssue617(t *testing.T) { 165 | db := openTestConn(t) 166 | defer db.Close() 167 | 168 | const N = 10 169 | 170 | numGoroutineStart := runtime.NumGoroutine() 171 | for i := 0; i < N; i++ { 172 | func() { 173 | ctx, cancel := context.WithCancel(context.Background()) 174 | defer cancel() 175 | _, err := db.QueryContext(ctx, `SELECT * FROM DOESNOTEXIST`) 176 | pqErr, _ := err.(*Error) 177 | // Expecting "pq: relation \"doesnotexist\" does not exist" error. 178 | if err == nil || pqErr == nil || pqErr.Code != "42P01" { 179 | t.Fatalf("expected undefined table error, got %v", err) 180 | } 181 | }() 182 | } 183 | 184 | // Give time for goroutines to terminate 185 | delayTime := time.Millisecond * 50 186 | waitTime := time.Second 187 | iterations := int(waitTime / delayTime) 188 | 189 | var numGoroutineFinish int 190 | for i := 0; i < iterations; i++ { 191 | time.Sleep(delayTime) 192 | 193 | numGoroutineFinish = runtime.NumGoroutine() 194 | 195 | // We use N/2 and not N because the GC and other actors may increase or 196 | // decrease the number of goroutines. 
197 | if numGoroutineFinish-numGoroutineStart < N/2 { 198 | return 199 | } 200 | } 201 | 202 | t.Errorf("goroutine leak detected, was %d, now %d", numGoroutineStart, numGoroutineFinish) 203 | } 204 | 205 | func TestContextCancelBegin(t *testing.T) { 206 | db := openTestConn(t) 207 | defer db.Close() 208 | 209 | ctx, cancel := context.WithCancel(context.Background()) 210 | tx, err := db.BeginTx(ctx, nil) 211 | if err != nil { 212 | t.Fatal(err) 213 | } 214 | 215 | // Delay execution for just a bit until tx.Exec has begun. 216 | defer time.AfterFunc(time.Millisecond*10, cancel).Stop() 217 | 218 | // Not canceled until after the exec has started. 219 | if _, err := tx.Exec("select pg_sleep(1)"); err == nil { 220 | t.Fatal("expected error") 221 | } else if pgErr := (*Error)(nil); !(errors.As(err, &pgErr) && pgErr.Code == cancelErrorCode) { 222 | t.Fatalf("unexpected error: %s", err) 223 | } 224 | 225 | // Transaction is canceled, so expect an error. 226 | if _, err := tx.Query("select pg_sleep(1)"); err == nil { 227 | t.Fatal("expected error") 228 | } else if err != sql.ErrTxDone { 229 | t.Fatalf("unexpected error: %s", err) 230 | } 231 | 232 | // Context is canceled, so cannot begin a transaction. 
233 | if _, err := db.BeginTx(ctx, nil); err == nil { 234 | t.Fatal("expected error") 235 | } else if err.Error() != "context canceled" { 236 | t.Fatalf("unexpected error: %s", err) 237 | } 238 | 239 | for i := 0; i < contextRaceIterations; i++ { 240 | func() { 241 | ctx, cancel := context.WithCancel(context.Background()) 242 | tx, err := db.BeginTx(ctx, nil) 243 | cancel() 244 | if err != nil { 245 | t.Fatal(err) 246 | } else if err, pgErr := tx.Rollback(), (*Error)(nil); err != nil && 247 | !(errors.As(err, &pgErr) && pgErr.Code == cancelErrorCode) && 248 | err != sql.ErrTxDone && err != driver.ErrBadConn && err != context.Canceled { 249 | t.Fatal(err) 250 | } 251 | }() 252 | 253 | if tx, err := db.Begin(); err != nil { 254 | t.Fatal(err) 255 | } else if err := tx.Rollback(); err != nil { 256 | t.Fatal(err) 257 | } 258 | } 259 | } 260 | 261 | func TestTxOptions(t *testing.T) { 262 | db := openTestConn(t) 263 | defer db.Close() 264 | ctx := context.Background() 265 | 266 | tests := []struct { 267 | level sql.IsolationLevel 268 | isolation string 269 | }{ 270 | { 271 | level: sql.LevelDefault, 272 | isolation: "", 273 | }, 274 | { 275 | level: sql.LevelReadUncommitted, 276 | isolation: "read uncommitted", 277 | }, 278 | { 279 | level: sql.LevelReadCommitted, 280 | isolation: "read committed", 281 | }, 282 | { 283 | level: sql.LevelRepeatableRead, 284 | isolation: "repeatable read", 285 | }, 286 | { 287 | level: sql.LevelSerializable, 288 | isolation: "serializable", 289 | }, 290 | } 291 | 292 | for _, test := range tests { 293 | for _, ro := range []bool{true, false} { 294 | tx, err := db.BeginTx(ctx, &sql.TxOptions{ 295 | Isolation: test.level, 296 | ReadOnly: ro, 297 | }) 298 | if err != nil { 299 | t.Fatal(err) 300 | } 301 | 302 | var isolation string 303 | err = tx.QueryRow("select current_setting('transaction_isolation')").Scan(&isolation) 304 | if err != nil { 305 | t.Fatal(err) 306 | } 307 | 308 | if test.isolation != "" && isolation != test.isolation { 309 
| t.Errorf("wrong isolation level: %s != %s", isolation, test.isolation) 310 | } 311 | 312 | var isRO string 313 | err = tx.QueryRow("select current_setting('transaction_read_only')").Scan(&isRO) 314 | if err != nil { 315 | t.Fatal(err) 316 | } 317 | 318 | if ro != (isRO == "on") { 319 | t.Errorf("read/[write,only] not set: %t != %s for level %s", 320 | ro, isRO, test.isolation) 321 | } 322 | 323 | tx.Rollback() 324 | } 325 | } 326 | 327 | _, err := db.BeginTx(ctx, &sql.TxOptions{ 328 | Isolation: sql.LevelLinearizable, 329 | }) 330 | if err == nil { 331 | t.Fatal("expected LevelLinearizable to fail") 332 | } 333 | if !strings.Contains(err.Error(), "isolation level not supported") { 334 | t.Errorf("Expected error to mention isolation level, got %q", err) 335 | } 336 | } 337 | 338 | func TestErrorSQLState(t *testing.T) { 339 | r := readBuf([]byte{67, 52, 48, 48, 48, 49, 0, 0}) // 40001 340 | err := parseError(&r) 341 | var sqlErr errWithSQLState 342 | if !errors.As(err, &sqlErr) { 343 | t.Fatal("SQLState interface not satisfied") 344 | } 345 | if state := err.SQLState(); state != "40001" { 346 | t.Fatalf("unexpected SQL state %v", state) 347 | } 348 | } 349 | 350 | type errWithSQLState interface { 351 | SQLState() string 352 | } 353 | -------------------------------------------------------------------------------- /go19_test.go: -------------------------------------------------------------------------------- 1 | //go:build go1.9 2 | // +build go1.9 3 | 4 | package pq 5 | 6 | import ( 7 | "context" 8 | "database/sql" 9 | "database/sql/driver" 10 | "reflect" 11 | "testing" 12 | ) 13 | 14 | func TestPing(t *testing.T) { 15 | ctx, cancel := context.WithCancel(context.Background()) 16 | db := openTestConn(t) 17 | defer db.Close() 18 | 19 | if _, ok := reflect.TypeOf(db).MethodByName("Conn"); !ok { 20 | t.Skipf("Conn method undefined on type %T, skipping test (requires at least go1.9)", db) 21 | } 22 | 23 | if err := db.PingContext(ctx); err != nil { 24 | 
t.Fatal("expected Ping to succeed") 25 | } 26 | defer cancel() 27 | 28 | // grab a connection 29 | conn, err := db.Conn(ctx) 30 | if err != nil { 31 | t.Fatal(err) 32 | } 33 | 34 | // start a transaction and read backend pid of our connection 35 | tx, err := conn.BeginTx(ctx, &sql.TxOptions{ 36 | Isolation: sql.LevelDefault, 37 | ReadOnly: true, 38 | }) 39 | if err != nil { 40 | t.Fatal(err) 41 | } 42 | 43 | rows, err := tx.Query("SELECT pg_backend_pid()") 44 | if err != nil { 45 | t.Fatal(err) 46 | } 47 | defer rows.Close() 48 | 49 | // read the pid from result 50 | var pid int 51 | for rows.Next() { 52 | if err := rows.Scan(&pid); err != nil { 53 | t.Fatal(err) 54 | } 55 | } 56 | if rows.Err() != nil { 57 | t.Fatal(err) 58 | } 59 | // Fail the transaction and make sure we can still ping. 60 | if _, err := tx.Query("INVALID SQL"); err == nil { 61 | t.Fatal("expected error") 62 | } 63 | if err := conn.PingContext(ctx); err != nil { 64 | t.Fatal(err) 65 | } 66 | if err := tx.Rollback(); err != nil { 67 | t.Fatal(err) 68 | } 69 | 70 | // kill the process which handles our connection and test if the ping fails 71 | if _, err := db.Exec("SELECT pg_terminate_backend($1)", pid); err != nil { 72 | t.Fatal(err) 73 | } 74 | if err := conn.PingContext(ctx); err != driver.ErrBadConn { 75 | t.Fatalf("expected error %s, instead got %s", driver.ErrBadConn, err) 76 | } 77 | } 78 | 79 | func TestCommitInFailedTransactionWithCancelContext(t *testing.T) { 80 | db := openTestConn(t) 81 | defer db.Close() 82 | 83 | ctx, cancel := context.WithCancel(context.Background()) 84 | defer cancel() 85 | 86 | txn, err := db.BeginTx(ctx, nil) 87 | if err != nil { 88 | t.Fatal(err) 89 | } 90 | rows, err := txn.Query("SELECT error") 91 | if err == nil { 92 | rows.Close() 93 | t.Fatal("expected failure") 94 | } 95 | err = txn.Commit() 96 | if err != ErrInFailedTransaction { 97 | t.Fatalf("expected ErrInFailedTransaction; got %#v", err) 98 | } 99 | } 100 | 
-------------------------------------------------------------------------------- /hstore/hstore.go: -------------------------------------------------------------------------------- 1 | package hstore 2 | 3 | import ( 4 | "database/sql" 5 | "database/sql/driver" 6 | "strings" 7 | ) 8 | 9 | // Hstore is a wrapper for transferring Hstore values back and forth easily. 10 | type Hstore struct { 11 | Map map[string]sql.NullString 12 | } 13 | 14 | // escapes and quotes hstore keys/values 15 | // s should be a sql.NullString or string 16 | func hQuote(s interface{}) string { 17 | var str string 18 | switch v := s.(type) { 19 | case sql.NullString: 20 | if !v.Valid { 21 | return "NULL" 22 | } 23 | str = v.String 24 | case string: 25 | str = v 26 | default: 27 | panic("not a string or sql.NullString") 28 | } 29 | 30 | str = strings.Replace(str, "\\", "\\\\", -1) 31 | return `"` + strings.Replace(str, "\"", "\\\"", -1) + `"` 32 | } 33 | 34 | // Scan implements the Scanner interface. 35 | // 36 | // Note h.Map is reallocated before the scan to clear existing values. If the 37 | // hstore column's database value is NULL, then h.Map is set to nil instead. 
38 | func (h *Hstore) Scan(value interface{}) error { 39 | if value == nil { 40 | h.Map = nil 41 | return nil 42 | } 43 | h.Map = make(map[string]sql.NullString) 44 | var b byte 45 | pair := [][]byte{{}, {}} 46 | pi := 0 47 | inQuote := false 48 | didQuote := false 49 | sawSlash := false 50 | bindex := 0 51 | for bindex, b = range value.([]byte) { 52 | if sawSlash { 53 | pair[pi] = append(pair[pi], b) 54 | sawSlash = false 55 | continue 56 | } 57 | 58 | switch b { 59 | case '\\': 60 | sawSlash = true 61 | continue 62 | case '"': 63 | inQuote = !inQuote 64 | if !didQuote { 65 | didQuote = true 66 | } 67 | continue 68 | default: 69 | if !inQuote { 70 | switch b { 71 | case ' ', '\t', '\n', '\r': 72 | continue 73 | case '=': 74 | continue 75 | case '>': 76 | pi = 1 77 | didQuote = false 78 | continue 79 | case ',': 80 | s := string(pair[1]) 81 | if !didQuote && len(s) == 4 && strings.ToLower(s) == "null" { 82 | h.Map[string(pair[0])] = sql.NullString{String: "", Valid: false} 83 | } else { 84 | h.Map[string(pair[0])] = sql.NullString{String: string(pair[1]), Valid: true} 85 | } 86 | pair[0] = []byte{} 87 | pair[1] = []byte{} 88 | pi = 0 89 | continue 90 | } 91 | } 92 | } 93 | pair[pi] = append(pair[pi], b) 94 | } 95 | if bindex > 0 { 96 | s := string(pair[1]) 97 | if !didQuote && len(s) == 4 && strings.ToLower(s) == "null" { 98 | h.Map[string(pair[0])] = sql.NullString{String: "", Valid: false} 99 | } else { 100 | h.Map[string(pair[0])] = sql.NullString{String: string(pair[1]), Valid: true} 101 | } 102 | } 103 | return nil 104 | } 105 | 106 | // Value implements the driver Valuer interface. Note if h.Map is nil, the 107 | // database column value will be set to NULL. 
108 | func (h Hstore) Value() (driver.Value, error) { 109 | if h.Map == nil { 110 | return nil, nil 111 | } 112 | parts := []string{} 113 | for key, val := range h.Map { 114 | thispart := hQuote(key) + "=>" + hQuote(val) 115 | parts = append(parts, thispart) 116 | } 117 | return []byte(strings.Join(parts, ",")), nil 118 | } 119 | -------------------------------------------------------------------------------- /hstore/hstore_test.go: -------------------------------------------------------------------------------- 1 | package hstore 2 | 3 | import ( 4 | "database/sql" 5 | "os" 6 | "testing" 7 | 8 | _ "github.com/lib/pq" 9 | ) 10 | 11 | type Fatalistic interface { 12 | Fatal(args ...interface{}) 13 | } 14 | 15 | func openTestConn(t Fatalistic) *sql.DB { 16 | datname := os.Getenv("PGDATABASE") 17 | sslmode := os.Getenv("PGSSLMODE") 18 | 19 | if datname == "" { 20 | os.Setenv("PGDATABASE", "pqgotest") 21 | } 22 | 23 | if sslmode == "" { 24 | os.Setenv("PGSSLMODE", "disable") 25 | } 26 | 27 | conn, err := sql.Open("postgres", "") 28 | if err != nil { 29 | t.Fatal(err) 30 | } 31 | 32 | return conn 33 | } 34 | 35 | func TestHstore(t *testing.T) { 36 | db := openTestConn(t) 37 | defer db.Close() 38 | 39 | // quietly create hstore if it doesn't exist 40 | _, err := db.Exec("CREATE EXTENSION IF NOT EXISTS hstore") 41 | if err != nil { 42 | t.Skipf("Skipping hstore tests - hstore extension create failed: %s", err.Error()) 43 | } 44 | 45 | hs := Hstore{} 46 | 47 | // test for null-valued hstores 48 | err = db.QueryRow("SELECT NULL::hstore").Scan(&hs) 49 | if err != nil { 50 | t.Fatal(err) 51 | } 52 | if hs.Map != nil { 53 | t.Fatalf("expected null map") 54 | } 55 | 56 | err = db.QueryRow("SELECT $1::hstore", hs).Scan(&hs) 57 | if err != nil { 58 | t.Fatalf("re-query null map failed: %s", err.Error()) 59 | } 60 | if hs.Map != nil { 61 | t.Fatalf("expected null map") 62 | } 63 | 64 | // test for empty hstores 65 | err = db.QueryRow("SELECT ''::hstore").Scan(&hs) 66 | if err != 
nil { 67 | t.Fatal(err) 68 | } 69 | if hs.Map == nil { 70 | t.Fatalf("expected empty map, got null map") 71 | } 72 | if len(hs.Map) != 0 { 73 | t.Fatalf("expected empty map, got len(map)=%d", len(hs.Map)) 74 | } 75 | 76 | err = db.QueryRow("SELECT $1::hstore", hs).Scan(&hs) 77 | if err != nil { 78 | t.Fatalf("re-query empty map failed: %s", err.Error()) 79 | } 80 | if hs.Map == nil { 81 | t.Fatalf("expected empty map, got null map") 82 | } 83 | if len(hs.Map) != 0 { 84 | t.Fatalf("expected empty map, got len(map)=%d", len(hs.Map)) 85 | } 86 | 87 | // a few example maps to test out 88 | hsOnePair := Hstore{ 89 | Map: map[string]sql.NullString{ 90 | "key1": {String: "value1", Valid: true}, 91 | }, 92 | } 93 | 94 | hsThreePairs := Hstore{ 95 | Map: map[string]sql.NullString{ 96 | "key1": {String: "value1", Valid: true}, 97 | "key2": {String: "value2", Valid: true}, 98 | "key3": {String: "value3", Valid: true}, 99 | }, 100 | } 101 | 102 | hsSmorgasbord := Hstore{ 103 | Map: map[string]sql.NullString{ 104 | "nullstring": {String: "NULL", Valid: true}, 105 | "actuallynull": {String: "", Valid: false}, 106 | "NULL": {String: "NULL string key", Valid: true}, 107 | "withbracket": {String: "value>42", Valid: true}, 108 | "withequal": {String: "value=42", Valid: true}, 109 | `"withquotes1"`: {String: `this "should" be fine`, Valid: true}, 110 | `"withquotes"2"`: {String: `this "should\" also be fine`, Valid: true}, 111 | "embedded1": {String: "value1=>x1", Valid: true}, 112 | "embedded2": {String: `"value2"=>x2`, Valid: true}, 113 | "withnewlines": {String: "\n\nvalue\t=>2", Valid: true}, 114 | "<>": {String: `this, "should,\" also, => be fine`, Valid: true}, 115 | }, 116 | } 117 | 118 | // test encoding in query params, then decoding during Scan 119 | testBidirectional := func(h Hstore) { 120 | err = db.QueryRow("SELECT $1::hstore", h).Scan(&hs) 121 | if err != nil { 122 | t.Fatalf("re-query %d-pair map failed: %s", len(h.Map), err.Error()) 123 | } 124 | if hs.Map == nil { 
125 | t.Fatalf("expected %d-pair map, got null map", len(h.Map)) 126 | } 127 | if len(hs.Map) != len(h.Map) { 128 | t.Fatalf("expected %d-pair map, got len(map)=%d", len(h.Map), len(hs.Map)) 129 | } 130 | 131 | for key, val := range hs.Map { 132 | otherval, found := h.Map[key] 133 | if !found { 134 | t.Fatalf(" key '%v' not found in %d-pair map", key, len(h.Map)) 135 | } 136 | if otherval.Valid != val.Valid { 137 | t.Fatalf(" value %v <> %v in %d-pair map", otherval, val, len(h.Map)) 138 | } 139 | if otherval.String != val.String { 140 | t.Fatalf(" value '%v' <> '%v' in %d-pair map", otherval.String, val.String, len(h.Map)) 141 | } 142 | } 143 | } 144 | 145 | testBidirectional(hsOnePair) 146 | testBidirectional(hsThreePairs) 147 | testBidirectional(hsSmorgasbord) 148 | } 149 | -------------------------------------------------------------------------------- /issues_test.go: -------------------------------------------------------------------------------- 1 | package pq 2 | 3 | import ( 4 | "context" 5 | "database/sql" 6 | "errors" 7 | "testing" 8 | "time" 9 | ) 10 | 11 | func TestIssue494(t *testing.T) { 12 | db := openTestConn(t) 13 | defer db.Close() 14 | 15 | query := `CREATE TEMP TABLE t (i INT PRIMARY KEY)` 16 | if _, err := db.Exec(query); err != nil { 17 | t.Fatal(err) 18 | } 19 | 20 | txn, err := db.Begin() 21 | if err != nil { 22 | t.Fatal(err) 23 | } 24 | 25 | if _, err := txn.Prepare(CopyIn("t", "i")); err != nil { 26 | t.Fatal(err) 27 | } 28 | 29 | if _, err := txn.Query("SELECT 1"); err == nil { 30 | t.Fatal("expected error") 31 | } 32 | } 33 | 34 | func TestIssue1046(t *testing.T) { 35 | ctxTimeout := time.Second * 2 36 | 37 | db := openTestConn(t) 38 | defer db.Close() 39 | 40 | ctx, cancel := context.WithTimeout(context.Background(), ctxTimeout) 41 | defer cancel() 42 | 43 | stmt, err := db.PrepareContext(ctx, `SELECT pg_sleep(10) AS id`) 44 | if err != nil { 45 | t.Fatal(err) 46 | } 47 | 48 | var d []uint8 49 | err = 
stmt.QueryRowContext(ctx).Scan(&d) 50 | dl, _ := ctx.Deadline() 51 | since := time.Since(dl) 52 | if since > ctxTimeout { 53 | t.Logf("FAIL %s: query returned after context deadline: %v\n", t.Name(), since) 54 | t.Fail() 55 | } 56 | if pgErr := (*Error)(nil); !(errors.As(err, &pgErr) && pgErr.Code == cancelErrorCode) { 57 | t.Logf("ctx.Err(): [%T]%+v\n", ctx.Err(), ctx.Err()) 58 | t.Logf("got err: [%T] %+v expected errCode: %v", err, err, cancelErrorCode) 59 | t.Fail() 60 | } 61 | } 62 | 63 | func TestIssue1062(t *testing.T) { 64 | db := openTestConn(t) 65 | defer db.Close() 66 | 67 | // Ensure that cancelling a QueryRowContext does not result in an ErrBadConn. 68 | 69 | for i := 0; i < 100; i++ { 70 | ctx, cancel := context.WithCancel(context.Background()) 71 | go cancel() 72 | row := db.QueryRowContext(ctx, "select 1") 73 | 74 | var v int 75 | err := row.Scan(&v) 76 | if pgErr := (*Error)(nil); err != nil && 77 | err != context.Canceled && 78 | !(errors.As(err, &pgErr) && pgErr.Code == cancelErrorCode) { 79 | t.Fatalf("Scan resulted in unexpected error %v for canceled QueryRowContext at attempt %d", err, i+1) 80 | } 81 | } 82 | } 83 | 84 | func connIsValid(t *testing.T, db *sql.DB) { 85 | t.Helper() 86 | 87 | ctx := context.Background() 88 | conn, err := db.Conn(ctx) 89 | if err != nil { 90 | t.Fatal(err) 91 | } 92 | defer conn.Close() 93 | 94 | // the connection must be valid 95 | err = conn.PingContext(ctx) 96 | if err != nil { 97 | t.Errorf("PingContext err=%#v", err) 98 | } 99 | // close must not return an error 100 | err = conn.Close() 101 | if err != nil { 102 | t.Errorf("Close err=%#v", err) 103 | } 104 | } 105 | 106 | func TestQueryCancelRace(t *testing.T) { 107 | db := openTestConn(t) 108 | defer db.Close() 109 | 110 | // cancel a query while executing on Postgres: must return the cancelled error code 111 | ctx, cancel := context.WithCancel(context.Background()) 112 | go func() { 113 | time.Sleep(10 * time.Millisecond) 114 | cancel() 115 | }() 116 | row 
:= db.QueryRowContext(ctx, "select pg_sleep(0.5)") 117 | var pgSleepVoid string 118 | err := row.Scan(&pgSleepVoid) 119 | if pgErr := (*Error)(nil); !(errors.As(err, &pgErr) && pgErr.Code == cancelErrorCode) { 120 | t.Fatalf("expected cancelled error; err=%#v", err) 121 | } 122 | 123 | // get a connection: it must be a valid 124 | connIsValid(t, db) 125 | } 126 | 127 | // Test cancelling a scan after it is started. This broke with 1.10.4. 128 | func TestQueryCancelledReused(t *testing.T) { 129 | db := openTestConn(t) 130 | defer db.Close() 131 | 132 | ctx, cancel := context.WithCancel(context.Background()) 133 | // run a query that returns a lot of data 134 | rows, err := db.QueryContext(ctx, "select generate_series(1, 10000)") 135 | if err != nil { 136 | t.Fatal(err) 137 | } 138 | 139 | // scan the first value 140 | if !rows.Next() { 141 | t.Error("expected rows.Next() to return true") 142 | } 143 | var i int 144 | err = rows.Scan(&i) 145 | if err != nil { 146 | t.Fatal(err) 147 | } 148 | if i != 1 { 149 | t.Error(i) 150 | } 151 | 152 | // cancel the context and close rows, ignoring errors 153 | cancel() 154 | rows.Close() 155 | 156 | // get a connection: it must be valid 157 | connIsValid(t, db) 158 | } 159 | -------------------------------------------------------------------------------- /krb.go: -------------------------------------------------------------------------------- 1 | package pq 2 | 3 | // NewGSSFunc creates a GSS authentication provider, for use with 4 | // RegisterGSSProvider. 5 | type NewGSSFunc func() (GSS, error) 6 | 7 | var newGss NewGSSFunc 8 | 9 | // RegisterGSSProvider registers a GSS authentication provider. 
For example, if 10 | // you need to use Kerberos to authenticate with your server, add this to your 11 | // main package: 12 | // 13 | // import "github.com/lib/pq/auth/kerberos" 14 | // 15 | // func init() { 16 | // pq.RegisterGSSProvider(func() (pq.GSS, error) { return kerberos.NewGSS() }) 17 | // } 18 | func RegisterGSSProvider(newGssArg NewGSSFunc) { 19 | newGss = newGssArg 20 | } 21 | 22 | // GSS provides GSSAPI authentication (e.g., Kerberos). 23 | type GSS interface { 24 | GetInitToken(host string, service string) ([]byte, error) 25 | GetInitTokenFromSpn(spn string) ([]byte, error) 26 | Continue(inToken []byte) (done bool, outToken []byte, err error) 27 | } 28 | -------------------------------------------------------------------------------- /notice.go: -------------------------------------------------------------------------------- 1 | //go:build go1.10 2 | // +build go1.10 3 | 4 | package pq 5 | 6 | import ( 7 | "context" 8 | "database/sql/driver" 9 | ) 10 | 11 | // NoticeHandler returns the notice handler on the given connection, if any. A 12 | // runtime panic occurs if c is not a pq connection. This is rarely used 13 | // directly, use ConnectorNoticeHandler and ConnectorWithNoticeHandler instead. 14 | func NoticeHandler(c driver.Conn) func(*Error) { 15 | return c.(*conn).noticeHandler 16 | } 17 | 18 | // SetNoticeHandler sets the given notice handler on the given connection. A 19 | // runtime panic occurs if c is not a pq connection. A nil handler may be used 20 | // to unset it. This is rarely used directly, use ConnectorNoticeHandler and 21 | // ConnectorWithNoticeHandler instead. 22 | // 23 | // Note: Notice handlers are executed synchronously by pq meaning commands 24 | // won't continue to be processed until the handler returns. 25 | func SetNoticeHandler(c driver.Conn, handler func(*Error)) { 26 | c.(*conn).noticeHandler = handler 27 | } 28 | 29 | // NoticeHandlerConnector wraps a regular connector and sets a notice handler 30 | // on it. 
31 | type NoticeHandlerConnector struct { 32 | driver.Connector 33 | noticeHandler func(*Error) 34 | } 35 | 36 | // Connect calls the underlying connector's connect method and then sets the 37 | // notice handler. 38 | func (n *NoticeHandlerConnector) Connect(ctx context.Context) (driver.Conn, error) { 39 | c, err := n.Connector.Connect(ctx) 40 | if err == nil { 41 | SetNoticeHandler(c, n.noticeHandler) 42 | } 43 | return c, err 44 | } 45 | 46 | // ConnectorNoticeHandler returns the currently set notice handler, if any. If 47 | // the given connector is not a result of ConnectorWithNoticeHandler, nil is 48 | // returned. 49 | func ConnectorNoticeHandler(c driver.Connector) func(*Error) { 50 | if c, ok := c.(*NoticeHandlerConnector); ok { 51 | return c.noticeHandler 52 | } 53 | return nil 54 | } 55 | 56 | // ConnectorWithNoticeHandler creates or sets the given handler for the given 57 | // connector. If the given connector is a result of calling this function 58 | // previously, it is simply set on the given connector and returned. Otherwise, 59 | // this returns a new connector wrapping the given one and setting the notice 60 | // handler. A nil notice handler may be used to unset it. 61 | // 62 | // The returned connector is intended to be used with database/sql.OpenDB. 63 | // 64 | // Note: Notice handlers are executed synchronously by pq meaning commands 65 | // won't continue to be processed until the handler returns. 
66 | func ConnectorWithNoticeHandler(c driver.Connector, handler func(*Error)) *NoticeHandlerConnector { 67 | if c, ok := c.(*NoticeHandlerConnector); ok { 68 | c.noticeHandler = handler 69 | return c 70 | } 71 | return &NoticeHandlerConnector{Connector: c, noticeHandler: handler} 72 | } 73 | -------------------------------------------------------------------------------- /notice_example_test.go: -------------------------------------------------------------------------------- 1 | //go:build go1.10 2 | // +build go1.10 3 | 4 | package pq_test 5 | 6 | import ( 7 | "database/sql" 8 | "fmt" 9 | "log" 10 | 11 | "github.com/lib/pq" 12 | ) 13 | 14 | func ExampleConnectorWithNoticeHandler() { 15 | name := "" 16 | // Base connector to wrap 17 | base, err := pq.NewConnector(name) 18 | if err != nil { 19 | log.Fatal(err) 20 | } 21 | // Wrap the connector to simply print out the message 22 | connector := pq.ConnectorWithNoticeHandler(base, func(notice *pq.Error) { 23 | fmt.Println("Notice sent: " + notice.Message) 24 | }) 25 | db := sql.OpenDB(connector) 26 | defer db.Close() 27 | // Raise a notice 28 | sql := "DO language plpgsql $$ BEGIN RAISE NOTICE 'test notice'; END $$" 29 | if _, err := db.Exec(sql); err != nil { 30 | log.Fatal(err) 31 | } 32 | // Output: 33 | // Notice sent: test notice 34 | } 35 | -------------------------------------------------------------------------------- /notice_test.go: -------------------------------------------------------------------------------- 1 | //go:build go1.10 2 | // +build go1.10 3 | 4 | package pq 5 | 6 | import ( 7 | "database/sql" 8 | "database/sql/driver" 9 | "testing" 10 | ) 11 | 12 | func TestConnectorWithNoticeHandler_Simple(t *testing.T) { 13 | b, err := NewConnector("") 14 | if err != nil { 15 | t.Fatal(err) 16 | } 17 | var notice *Error 18 | // Make connector w/ handler to set the local var 19 | c := ConnectorWithNoticeHandler(b, func(n *Error) { notice = n }) 20 | raiseNotice(c, t, "Test notice #1") 21 | if notice == nil 
|| notice.Message != "Test notice #1" { 22 | t.Fatalf("Expected notice w/ message, got %v", notice) 23 | } 24 | // Unset the handler on the same connector 25 | prevC := c 26 | if c = ConnectorWithNoticeHandler(c, nil); c != prevC { 27 | t.Fatalf("Expected to not create new connector but did") 28 | } 29 | raiseNotice(c, t, "Test notice #2") 30 | if notice == nil || notice.Message != "Test notice #1" { 31 | t.Fatalf("Expected notice to not change, got %v", notice) 32 | } 33 | // Set it back on the same connector 34 | if c = ConnectorWithNoticeHandler(c, func(n *Error) { notice = n }); c != prevC { 35 | t.Fatal("Expected to not create new connector but did") 36 | } 37 | raiseNotice(c, t, "Test notice #3") 38 | if notice == nil || notice.Message != "Test notice #3" { 39 | t.Fatalf("Expected notice w/ message, got %v", notice) 40 | } 41 | } 42 | 43 | func raiseNotice(c driver.Connector, t *testing.T, escapedNotice string) { 44 | db := sql.OpenDB(c) 45 | defer db.Close() 46 | sql := "DO language plpgsql $$ BEGIN RAISE NOTICE '" + escapedNotice + "'; END $$" 47 | if _, err := db.Exec(sql); err != nil { 48 | t.Fatal(err) 49 | } 50 | } 51 | -------------------------------------------------------------------------------- /notify_test.go: -------------------------------------------------------------------------------- 1 | package pq 2 | 3 | import ( 4 | "database/sql" 5 | "database/sql/driver" 6 | "errors" 7 | "fmt" 8 | "io" 9 | "os" 10 | "runtime" 11 | "sync" 12 | "testing" 13 | "time" 14 | ) 15 | 16 | var errNilNotification = errors.New("nil notification") 17 | 18 | func expectNotification(t *testing.T, ch <-chan *Notification, relname string, extra string) error { 19 | select { 20 | case n := <-ch: 21 | if n == nil { 22 | return errNilNotification 23 | } 24 | if n.Channel != relname || n.Extra != extra { 25 | return fmt.Errorf("unexpected notification %v", n) 26 | } 27 | return nil 28 | case <-time.After(1500 * time.Millisecond): 29 | return fmt.Errorf("timeout") 30 | } 31 
| } 32 | 33 | func expectNoNotification(t *testing.T, ch <-chan *Notification) error { 34 | select { 35 | case n := <-ch: 36 | return fmt.Errorf("unexpected notification %v", n) 37 | case <-time.After(100 * time.Millisecond): 38 | return nil 39 | } 40 | } 41 | 42 | func expectEvent(t *testing.T, eventch <-chan ListenerEventType, et ListenerEventType) error { 43 | select { 44 | case e := <-eventch: 45 | if e != et { 46 | return fmt.Errorf("unexpected event %v", e) 47 | } 48 | return nil 49 | case <-time.After(1500 * time.Millisecond): 50 | panic("expectEvent timeout") 51 | } 52 | } 53 | 54 | func expectNoEvent(t *testing.T, eventch <-chan ListenerEventType) error { 55 | select { 56 | case e := <-eventch: 57 | return fmt.Errorf("unexpected event %v", e) 58 | case <-time.After(100 * time.Millisecond): 59 | return nil 60 | } 61 | } 62 | 63 | func newTestListenerConn(t *testing.T) (*ListenerConn, <-chan *Notification) { 64 | datname := os.Getenv("PGDATABASE") 65 | sslmode := os.Getenv("PGSSLMODE") 66 | 67 | if datname == "" { 68 | os.Setenv("PGDATABASE", "pqgotest") 69 | } 70 | 71 | if sslmode == "" { 72 | os.Setenv("PGSSLMODE", "disable") 73 | } 74 | 75 | notificationChan := make(chan *Notification) 76 | l, err := NewListenerConn("", notificationChan) 77 | if err != nil { 78 | t.Fatal(err) 79 | } 80 | 81 | return l, notificationChan 82 | } 83 | 84 | func TestNewListenerConn(t *testing.T) { 85 | l, _ := newTestListenerConn(t) 86 | 87 | defer l.Close() 88 | } 89 | 90 | func TestConnListen(t *testing.T) { 91 | l, channel := newTestListenerConn(t) 92 | 93 | defer l.Close() 94 | 95 | db := openTestConn(t) 96 | defer db.Close() 97 | 98 | ok, err := l.Listen("notify_test") 99 | if !ok || err != nil { 100 | t.Fatal(err) 101 | } 102 | 103 | _, err = db.Exec("NOTIFY notify_test") 104 | if err != nil { 105 | t.Fatal(err) 106 | } 107 | 108 | err = expectNotification(t, channel, "notify_test", "") 109 | if err != nil { 110 | t.Fatal(err) 111 | } 112 | } 113 | 114 | func 
TestConnUnlisten(t *testing.T) { 115 | l, channel := newTestListenerConn(t) 116 | 117 | defer l.Close() 118 | 119 | db := openTestConn(t) 120 | defer db.Close() 121 | 122 | ok, err := l.Listen("notify_test") 123 | if !ok || err != nil { 124 | t.Fatal(err) 125 | } 126 | 127 | _, err = db.Exec("NOTIFY notify_test") 128 | if err != nil { 129 | t.Fatal(err) 130 | } 131 | 132 | err = expectNotification(t, channel, "notify_test", "") 133 | if err != nil { 134 | t.Fatal(err) 135 | } 136 | 137 | ok, err = l.Unlisten("notify_test") 138 | if !ok || err != nil { 139 | t.Fatal(err) 140 | } 141 | 142 | _, err = db.Exec("NOTIFY notify_test") 143 | if err != nil { 144 | t.Fatal(err) 145 | } 146 | 147 | err = expectNoNotification(t, channel) 148 | if err != nil { 149 | t.Fatal(err) 150 | } 151 | } 152 | 153 | func TestConnUnlistenAll(t *testing.T) { 154 | l, channel := newTestListenerConn(t) 155 | 156 | defer l.Close() 157 | 158 | db := openTestConn(t) 159 | defer db.Close() 160 | 161 | ok, err := l.Listen("notify_test") 162 | if !ok || err != nil { 163 | t.Fatal(err) 164 | } 165 | 166 | _, err = db.Exec("NOTIFY notify_test") 167 | if err != nil { 168 | t.Fatal(err) 169 | } 170 | 171 | err = expectNotification(t, channel, "notify_test", "") 172 | if err != nil { 173 | t.Fatal(err) 174 | } 175 | 176 | ok, err = l.UnlistenAll() 177 | if !ok || err != nil { 178 | t.Fatal(err) 179 | } 180 | 181 | _, err = db.Exec("NOTIFY notify_test") 182 | if err != nil { 183 | t.Fatal(err) 184 | } 185 | 186 | err = expectNoNotification(t, channel) 187 | if err != nil { 188 | t.Fatal(err) 189 | } 190 | } 191 | 192 | func TestConnClose(t *testing.T) { 193 | l, _ := newTestListenerConn(t) 194 | defer l.Close() 195 | 196 | err := l.Close() 197 | if err != nil { 198 | t.Fatal(err) 199 | } 200 | err = l.Close() 201 | if err != errListenerConnClosed { 202 | t.Fatalf("expected errListenerConnClosed; got %v", err) 203 | } 204 | } 205 | 206 | func TestConnPing(t *testing.T) { 207 | l, _ := 
newTestListenerConn(t) 208 | defer l.Close() 209 | err := l.Ping() 210 | if err != nil { 211 | t.Fatal(err) 212 | } 213 | err = l.Close() 214 | if err != nil { 215 | t.Fatal(err) 216 | } 217 | err = l.Ping() 218 | if err != errListenerConnClosed { 219 | t.Fatalf("expected errListenerConnClosed; got %v", err) 220 | } 221 | } 222 | 223 | // Test for deadlock where a query fails while another one is queued 224 | func TestConnExecDeadlock(t *testing.T) { 225 | l, _ := newTestListenerConn(t) 226 | defer l.Close() 227 | 228 | var wg sync.WaitGroup 229 | wg.Add(2) 230 | 231 | go func() { 232 | l.ExecSimpleQuery("SELECT pg_sleep(60)") 233 | wg.Done() 234 | }() 235 | runtime.Gosched() 236 | go func() { 237 | l.ExecSimpleQuery("SELECT 1") 238 | wg.Done() 239 | }() 240 | // give the two goroutines some time to get into position 241 | runtime.Gosched() 242 | // calls Close on the net.Conn; equivalent to a network failure 243 | l.Close() 244 | 245 | defer time.AfterFunc(10*time.Second, func() { 246 | panic("timed out") 247 | }).Stop() 248 | wg.Wait() 249 | } 250 | 251 | // Test for ListenerConn being closed while a slow query is executing 252 | func TestListenerConnCloseWhileQueryIsExecuting(t *testing.T) { 253 | l, _ := newTestListenerConn(t) 254 | defer l.Close() 255 | 256 | var wg sync.WaitGroup 257 | wg.Add(1) 258 | 259 | go func() { 260 | sent, err := l.ExecSimpleQuery("SELECT pg_sleep(60)") 261 | if sent { 262 | panic("expected sent=false") 263 | } 264 | // could be any of a number of errors 265 | if err == nil { 266 | panic("expected error") 267 | } 268 | wg.Done() 269 | }() 270 | // give the above goroutine some time to get into position 271 | runtime.Gosched() 272 | err := l.Close() 273 | if err != nil { 274 | t.Fatal(err) 275 | } 276 | 277 | defer time.AfterFunc(10*time.Second, func() { 278 | panic("timed out") 279 | }).Stop() 280 | wg.Wait() 281 | } 282 | 283 | func TestNotifyExtra(t *testing.T) { 284 | db := openTestConn(t) 285 | defer db.Close() 286 | 287 | if 
getServerVersion(t, db) < 90000 { 288 | t.Skip("skipping NOTIFY payload test since the server does not appear to support it") 289 | } 290 | 291 | l, channel := newTestListenerConn(t) 292 | defer l.Close() 293 | 294 | ok, err := l.Listen("notify_test") 295 | if !ok || err != nil { 296 | t.Fatal(err) 297 | } 298 | 299 | _, err = db.Exec("NOTIFY notify_test, 'something'") 300 | if err != nil { 301 | t.Fatal(err) 302 | } 303 | 304 | err = expectNotification(t, channel, "notify_test", "something") 305 | if err != nil { 306 | t.Fatal(err) 307 | } 308 | } 309 | 310 | // create a new test listener and also set the timeouts 311 | func newTestListenerTimeout(t *testing.T, min time.Duration, max time.Duration) (*Listener, <-chan ListenerEventType) { 312 | datname := os.Getenv("PGDATABASE") 313 | sslmode := os.Getenv("PGSSLMODE") 314 | 315 | if datname == "" { 316 | os.Setenv("PGDATABASE", "pqgotest") 317 | } 318 | 319 | if sslmode == "" { 320 | os.Setenv("PGSSLMODE", "disable") 321 | } 322 | 323 | eventch := make(chan ListenerEventType, 16) 324 | l := NewListener("", min, max, func(t ListenerEventType, err error) { eventch <- t }) 325 | err := expectEvent(t, eventch, ListenerEventConnected) 326 | if err != nil { 327 | t.Fatal(err) 328 | } 329 | return l, eventch 330 | } 331 | 332 | func newTestListener(t *testing.T) (*Listener, <-chan ListenerEventType) { 333 | return newTestListenerTimeout(t, time.Hour, time.Hour) 334 | } 335 | 336 | func TestListenerListen(t *testing.T) { 337 | l, _ := newTestListener(t) 338 | defer l.Close() 339 | 340 | db := openTestConn(t) 341 | defer db.Close() 342 | 343 | err := l.Listen("notify_listen_test") 344 | if err != nil { 345 | t.Fatal(err) 346 | } 347 | 348 | _, err = db.Exec("NOTIFY notify_listen_test") 349 | if err != nil { 350 | t.Fatal(err) 351 | } 352 | 353 | err = expectNotification(t, l.Notify, "notify_listen_test", "") 354 | if err != nil { 355 | t.Fatal(err) 356 | } 357 | } 358 | 359 | func TestListenerUnlisten(t *testing.T) { 360 | 
l, _ := newTestListener(t) 361 | defer l.Close() 362 | 363 | db := openTestConn(t) 364 | defer db.Close() 365 | 366 | err := l.Listen("notify_listen_test") 367 | if err != nil { 368 | t.Fatal(err) 369 | } 370 | 371 | _, err = db.Exec("NOTIFY notify_listen_test") 372 | if err != nil { 373 | t.Fatal(err) 374 | } 375 | 376 | err = l.Unlisten("notify_listen_test") 377 | if err != nil { 378 | t.Fatal(err) 379 | } 380 | 381 | err = expectNotification(t, l.Notify, "notify_listen_test", "") 382 | if err != nil { 383 | t.Fatal(err) 384 | } 385 | 386 | _, err = db.Exec("NOTIFY notify_listen_test") 387 | if err != nil { 388 | t.Fatal(err) 389 | } 390 | 391 | err = expectNoNotification(t, l.Notify) 392 | if err != nil { 393 | t.Fatal(err) 394 | } 395 | } 396 | 397 | func TestListenerUnlistenAll(t *testing.T) { 398 | l, _ := newTestListener(t) 399 | defer l.Close() 400 | 401 | db := openTestConn(t) 402 | defer db.Close() 403 | 404 | err := l.Listen("notify_listen_test") 405 | if err != nil { 406 | t.Fatal(err) 407 | } 408 | 409 | _, err = db.Exec("NOTIFY notify_listen_test") 410 | if err != nil { 411 | t.Fatal(err) 412 | } 413 | 414 | err = l.UnlistenAll() 415 | if err != nil { 416 | t.Fatal(err) 417 | } 418 | 419 | err = expectNotification(t, l.Notify, "notify_listen_test", "") 420 | if err != nil { 421 | t.Fatal(err) 422 | } 423 | 424 | _, err = db.Exec("NOTIFY notify_listen_test") 425 | if err != nil { 426 | t.Fatal(err) 427 | } 428 | 429 | err = expectNoNotification(t, l.Notify) 430 | if err != nil { 431 | t.Fatal(err) 432 | } 433 | } 434 | 435 | func TestListenerFailedQuery(t *testing.T) { 436 | l, eventch := newTestListener(t) 437 | defer l.Close() 438 | 439 | db := openTestConn(t) 440 | defer db.Close() 441 | 442 | err := l.Listen("notify_listen_test") 443 | if err != nil { 444 | t.Fatal(err) 445 | } 446 | 447 | _, err = db.Exec("NOTIFY notify_listen_test") 448 | if err != nil { 449 | t.Fatal(err) 450 | } 451 | 452 | err = expectNotification(t, l.Notify, 
"notify_listen_test", "") 453 | if err != nil { 454 | t.Fatal(err) 455 | } 456 | 457 | // shouldn't cause a disconnect 458 | ok, err := l.cn.ExecSimpleQuery("SELECT error") 459 | if !ok { 460 | t.Fatalf("could not send query to server: %v", err) 461 | } 462 | _, ok = err.(PGError) 463 | if !ok { 464 | t.Fatalf("unexpected error %v", err) 465 | } 466 | err = expectNoEvent(t, eventch) 467 | if err != nil { 468 | t.Fatal(err) 469 | } 470 | 471 | // should still work 472 | _, err = db.Exec("NOTIFY notify_listen_test") 473 | if err != nil { 474 | t.Fatal(err) 475 | } 476 | 477 | err = expectNotification(t, l.Notify, "notify_listen_test", "") 478 | if err != nil { 479 | t.Fatal(err) 480 | } 481 | } 482 | 483 | func TestListenerReconnect(t *testing.T) { 484 | l, eventch := newTestListenerTimeout(t, 20*time.Millisecond, time.Hour) 485 | defer l.Close() 486 | 487 | db := openTestConn(t) 488 | defer db.Close() 489 | 490 | err := l.Listen("notify_listen_test") 491 | if err != nil { 492 | t.Fatal(err) 493 | } 494 | 495 | _, err = db.Exec("NOTIFY notify_listen_test") 496 | if err != nil { 497 | t.Fatal(err) 498 | } 499 | 500 | err = expectNotification(t, l.Notify, "notify_listen_test", "") 501 | if err != nil { 502 | t.Fatal(err) 503 | } 504 | 505 | // kill the connection and make sure it comes back up 506 | ok, err := l.cn.ExecSimpleQuery("SELECT pg_terminate_backend(pg_backend_pid())") 507 | if ok { 508 | t.Fatalf("could not kill the connection: %v", err) 509 | } 510 | if err != io.EOF { 511 | t.Fatalf("unexpected error %v", err) 512 | } 513 | err = expectEvent(t, eventch, ListenerEventDisconnected) 514 | if err != nil { 515 | t.Fatal(err) 516 | } 517 | err = expectEvent(t, eventch, ListenerEventReconnected) 518 | if err != nil { 519 | t.Fatal(err) 520 | } 521 | 522 | // should still work 523 | _, err = db.Exec("NOTIFY notify_listen_test") 524 | if err != nil { 525 | t.Fatal(err) 526 | } 527 | 528 | // should get nil after Reconnected 529 | err = expectNotification(t, 
l.Notify, "", "") 530 | if err != errNilNotification { 531 | t.Fatal(err) 532 | } 533 | 534 | err = expectNotification(t, l.Notify, "notify_listen_test", "") 535 | if err != nil { 536 | t.Fatal(err) 537 | } 538 | } 539 | 540 | func TestListenerClose(t *testing.T) { 541 | l, _ := newTestListenerTimeout(t, 20*time.Millisecond, time.Hour) 542 | defer l.Close() 543 | 544 | err := l.Close() 545 | if err != nil { 546 | t.Fatal(err) 547 | } 548 | err = l.Close() 549 | if err != errListenerClosed { 550 | t.Fatalf("expected errListenerClosed; got %v", err) 551 | } 552 | } 553 | 554 | func TestListenerPing(t *testing.T) { 555 | l, _ := newTestListenerTimeout(t, 20*time.Millisecond, time.Hour) 556 | defer l.Close() 557 | 558 | err := l.Ping() 559 | if err != nil { 560 | t.Fatal(err) 561 | } 562 | 563 | err = l.Close() 564 | if err != nil { 565 | t.Fatal(err) 566 | } 567 | 568 | err = l.Ping() 569 | if err != errListenerClosed { 570 | t.Fatalf("expected errListenerClosed; got %v", err) 571 | } 572 | } 573 | 574 | func TestConnectorWithNotificationHandler_Simple(t *testing.T) { 575 | b, err := NewConnector("") 576 | if err != nil { 577 | t.Fatal(err) 578 | } 579 | var notification *Notification 580 | // Make connector w/ handler to set the local var 581 | c := ConnectorWithNotificationHandler(b, func(n *Notification) { notification = n }) 582 | sendNotification(c, t, "Test notification #1") 583 | if notification == nil || notification.Extra != "Test notification #1" { 584 | t.Fatalf("Expected notification w/ message, got %v", notification) 585 | } 586 | // Unset the handler on the same connector 587 | prevC := c 588 | if c = ConnectorWithNotificationHandler(c, nil); c != prevC { 589 | t.Fatalf("Expected to not create new connector but did") 590 | } 591 | sendNotification(c, t, "Test notification #2") 592 | if notification == nil || notification.Extra != "Test notification #1" { 593 | t.Fatalf("Expected notification to not change, got %v", notification) 594 | } 595 | // Set it 
back on the same connector 596 | if c = ConnectorWithNotificationHandler(c, func(n *Notification) { notification = n }); c != prevC { 597 | t.Fatal("Expected to not create new connector but did") 598 | } 599 | sendNotification(c, t, "Test notification #3") 600 | if notification == nil || notification.Extra != "Test notification #3" { 601 | t.Fatalf("Expected notification w/ message, got %v", notification) 602 | } 603 | } 604 | 605 | func sendNotification(c driver.Connector, t *testing.T, escapedNotification string) { 606 | db := sql.OpenDB(c) 607 | defer db.Close() 608 | sql := fmt.Sprintf("LISTEN foo; NOTIFY foo, '%s';", escapedNotification) 609 | if _, err := db.Exec(sql); err != nil { 610 | t.Fatal(err) 611 | } 612 | } 613 | -------------------------------------------------------------------------------- /oid/doc.go: -------------------------------------------------------------------------------- 1 | // Package oid contains OID constants 2 | // as defined by the Postgres server. 3 | package oid 4 | 5 | // Oid is a Postgres Object ID. 6 | type Oid uint32 7 | -------------------------------------------------------------------------------- /oid/gen.go: -------------------------------------------------------------------------------- 1 | //go:build ignore 2 | // +build ignore 3 | 4 | // Generate the table of OID values 5 | // Run with 'go run gen.go'. 6 | package main 7 | 8 | import ( 9 | "database/sql" 10 | "fmt" 11 | "log" 12 | "os" 13 | "os/exec" 14 | "strings" 15 | 16 | _ "github.com/lib/pq" 17 | ) 18 | 19 | // OID represent a postgres Object Identifier Type. 20 | type OID struct { 21 | ID int 22 | Type string 23 | } 24 | 25 | // Name returns an upper case version of the oid type. 
26 | func (o OID) Name() string { 27 | return strings.ToUpper(o.Type) 28 | } 29 | 30 | func main() { 31 | datname := os.Getenv("PGDATABASE") 32 | sslmode := os.Getenv("PGSSLMODE") 33 | 34 | if datname == "" { 35 | os.Setenv("PGDATABASE", "pqgotest") 36 | } 37 | 38 | if sslmode == "" { 39 | os.Setenv("PGSSLMODE", "disable") 40 | } 41 | 42 | db, err := sql.Open("postgres", "") 43 | if err != nil { 44 | log.Fatal(err) 45 | } 46 | rows, err := db.Query(` 47 | SELECT typname, oid 48 | FROM pg_type WHERE oid < 10000 49 | ORDER BY oid; 50 | `) 51 | if err != nil { 52 | log.Fatal(err) 53 | } 54 | oids := make([]*OID, 0) 55 | for rows.Next() { 56 | var oid OID 57 | if err = rows.Scan(&oid.Type, &oid.ID); err != nil { 58 | log.Fatal(err) 59 | } 60 | oids = append(oids, &oid) 61 | } 62 | if err = rows.Err(); err != nil { 63 | log.Fatal(err) 64 | } 65 | cmd := exec.Command("gofmt") 66 | cmd.Stderr = os.Stderr 67 | w, err := cmd.StdinPipe() 68 | if err != nil { 69 | log.Fatal(err) 70 | } 71 | f, err := os.Create("types.go") 72 | if err != nil { 73 | log.Fatal(err) 74 | } 75 | cmd.Stdout = f 76 | err = cmd.Start() 77 | if err != nil { 78 | log.Fatal(err) 79 | } 80 | fmt.Fprintln(w, "// Code generated by gen.go. DO NOT EDIT.") 81 | fmt.Fprintln(w, "\npackage oid") 82 | fmt.Fprintln(w, "const (") 83 | for _, oid := range oids { 84 | fmt.Fprintf(w, "T_%s Oid = %d\n", oid.Type, oid.ID) 85 | } 86 | fmt.Fprintln(w, ")") 87 | fmt.Fprintln(w, "var TypeName = map[Oid]string{") 88 | for _, oid := range oids { 89 | fmt.Fprintf(w, "T_%s: \"%s\",\n", oid.Type, oid.Name()) 90 | } 91 | fmt.Fprintln(w, "}") 92 | w.Close() 93 | cmd.Wait() 94 | } 95 | -------------------------------------------------------------------------------- /oid/types.go: -------------------------------------------------------------------------------- 1 | // Code generated by gen.go. DO NOT EDIT. 
2 | 3 | package oid 4 | 5 | const ( 6 | T_bool Oid = 16 7 | T_bytea Oid = 17 8 | T_char Oid = 18 9 | T_name Oid = 19 10 | T_int8 Oid = 20 11 | T_int2 Oid = 21 12 | T_int2vector Oid = 22 13 | T_int4 Oid = 23 14 | T_regproc Oid = 24 15 | T_text Oid = 25 16 | T_oid Oid = 26 17 | T_tid Oid = 27 18 | T_xid Oid = 28 19 | T_cid Oid = 29 20 | T_oidvector Oid = 30 21 | T_pg_ddl_command Oid = 32 22 | T_pg_type Oid = 71 23 | T_pg_attribute Oid = 75 24 | T_pg_proc Oid = 81 25 | T_pg_class Oid = 83 26 | T_json Oid = 114 27 | T_xml Oid = 142 28 | T__xml Oid = 143 29 | T_pg_node_tree Oid = 194 30 | T__json Oid = 199 31 | T_smgr Oid = 210 32 | T_index_am_handler Oid = 325 33 | T_point Oid = 600 34 | T_lseg Oid = 601 35 | T_path Oid = 602 36 | T_box Oid = 603 37 | T_polygon Oid = 604 38 | T_line Oid = 628 39 | T__line Oid = 629 40 | T_cidr Oid = 650 41 | T__cidr Oid = 651 42 | T_float4 Oid = 700 43 | T_float8 Oid = 701 44 | T_abstime Oid = 702 45 | T_reltime Oid = 703 46 | T_tinterval Oid = 704 47 | T_unknown Oid = 705 48 | T_circle Oid = 718 49 | T__circle Oid = 719 50 | T_money Oid = 790 51 | T__money Oid = 791 52 | T_macaddr Oid = 829 53 | T_inet Oid = 869 54 | T__bool Oid = 1000 55 | T__bytea Oid = 1001 56 | T__char Oid = 1002 57 | T__name Oid = 1003 58 | T__int2 Oid = 1005 59 | T__int2vector Oid = 1006 60 | T__int4 Oid = 1007 61 | T__regproc Oid = 1008 62 | T__text Oid = 1009 63 | T__tid Oid = 1010 64 | T__xid Oid = 1011 65 | T__cid Oid = 1012 66 | T__oidvector Oid = 1013 67 | T__bpchar Oid = 1014 68 | T__varchar Oid = 1015 69 | T__int8 Oid = 1016 70 | T__point Oid = 1017 71 | T__lseg Oid = 1018 72 | T__path Oid = 1019 73 | T__box Oid = 1020 74 | T__float4 Oid = 1021 75 | T__float8 Oid = 1022 76 | T__abstime Oid = 1023 77 | T__reltime Oid = 1024 78 | T__tinterval Oid = 1025 79 | T__polygon Oid = 1027 80 | T__oid Oid = 1028 81 | T_aclitem Oid = 1033 82 | T__aclitem Oid = 1034 83 | T__macaddr Oid = 1040 84 | T__inet Oid = 1041 85 | T_bpchar Oid = 1042 86 | T_varchar Oid = 1043 
87 | T_date Oid = 1082 88 | T_time Oid = 1083 89 | T_timestamp Oid = 1114 90 | T__timestamp Oid = 1115 91 | T__date Oid = 1182 92 | T__time Oid = 1183 93 | T_timestamptz Oid = 1184 94 | T__timestamptz Oid = 1185 95 | T_interval Oid = 1186 96 | T__interval Oid = 1187 97 | T__numeric Oid = 1231 98 | T_pg_database Oid = 1248 99 | T__cstring Oid = 1263 100 | T_timetz Oid = 1266 101 | T__timetz Oid = 1270 102 | T_bit Oid = 1560 103 | T__bit Oid = 1561 104 | T_varbit Oid = 1562 105 | T__varbit Oid = 1563 106 | T_numeric Oid = 1700 107 | T_refcursor Oid = 1790 108 | T__refcursor Oid = 2201 109 | T_regprocedure Oid = 2202 110 | T_regoper Oid = 2203 111 | T_regoperator Oid = 2204 112 | T_regclass Oid = 2205 113 | T_regtype Oid = 2206 114 | T__regprocedure Oid = 2207 115 | T__regoper Oid = 2208 116 | T__regoperator Oid = 2209 117 | T__regclass Oid = 2210 118 | T__regtype Oid = 2211 119 | T_record Oid = 2249 120 | T_cstring Oid = 2275 121 | T_any Oid = 2276 122 | T_anyarray Oid = 2277 123 | T_void Oid = 2278 124 | T_trigger Oid = 2279 125 | T_language_handler Oid = 2280 126 | T_internal Oid = 2281 127 | T_opaque Oid = 2282 128 | T_anyelement Oid = 2283 129 | T__record Oid = 2287 130 | T_anynonarray Oid = 2776 131 | T_pg_authid Oid = 2842 132 | T_pg_auth_members Oid = 2843 133 | T__txid_snapshot Oid = 2949 134 | T_uuid Oid = 2950 135 | T__uuid Oid = 2951 136 | T_txid_snapshot Oid = 2970 137 | T_fdw_handler Oid = 3115 138 | T_pg_lsn Oid = 3220 139 | T__pg_lsn Oid = 3221 140 | T_tsm_handler Oid = 3310 141 | T_anyenum Oid = 3500 142 | T_tsvector Oid = 3614 143 | T_tsquery Oid = 3615 144 | T_gtsvector Oid = 3642 145 | T__tsvector Oid = 3643 146 | T__gtsvector Oid = 3644 147 | T__tsquery Oid = 3645 148 | T_regconfig Oid = 3734 149 | T__regconfig Oid = 3735 150 | T_regdictionary Oid = 3769 151 | T__regdictionary Oid = 3770 152 | T_jsonb Oid = 3802 153 | T__jsonb Oid = 3807 154 | T_anyrange Oid = 3831 155 | T_event_trigger Oid = 3838 156 | T_int4range Oid = 3904 157 | T__int4range 
Oid = 3905 158 | T_numrange Oid = 3906 159 | T__numrange Oid = 3907 160 | T_tsrange Oid = 3908 161 | T__tsrange Oid = 3909 162 | T_tstzrange Oid = 3910 163 | T__tstzrange Oid = 3911 164 | T_daterange Oid = 3912 165 | T__daterange Oid = 3913 166 | T_int8range Oid = 3926 167 | T__int8range Oid = 3927 168 | T_pg_shseclabel Oid = 4066 169 | T_regnamespace Oid = 4089 170 | T__regnamespace Oid = 4090 171 | T_regrole Oid = 4096 172 | T__regrole Oid = 4097 173 | ) 174 | 175 | var TypeName = map[Oid]string{ 176 | T_bool: "BOOL", 177 | T_bytea: "BYTEA", 178 | T_char: "CHAR", 179 | T_name: "NAME", 180 | T_int8: "INT8", 181 | T_int2: "INT2", 182 | T_int2vector: "INT2VECTOR", 183 | T_int4: "INT4", 184 | T_regproc: "REGPROC", 185 | T_text: "TEXT", 186 | T_oid: "OID", 187 | T_tid: "TID", 188 | T_xid: "XID", 189 | T_cid: "CID", 190 | T_oidvector: "OIDVECTOR", 191 | T_pg_ddl_command: "PG_DDL_COMMAND", 192 | T_pg_type: "PG_TYPE", 193 | T_pg_attribute: "PG_ATTRIBUTE", 194 | T_pg_proc: "PG_PROC", 195 | T_pg_class: "PG_CLASS", 196 | T_json: "JSON", 197 | T_xml: "XML", 198 | T__xml: "_XML", 199 | T_pg_node_tree: "PG_NODE_TREE", 200 | T__json: "_JSON", 201 | T_smgr: "SMGR", 202 | T_index_am_handler: "INDEX_AM_HANDLER", 203 | T_point: "POINT", 204 | T_lseg: "LSEG", 205 | T_path: "PATH", 206 | T_box: "BOX", 207 | T_polygon: "POLYGON", 208 | T_line: "LINE", 209 | T__line: "_LINE", 210 | T_cidr: "CIDR", 211 | T__cidr: "_CIDR", 212 | T_float4: "FLOAT4", 213 | T_float8: "FLOAT8", 214 | T_abstime: "ABSTIME", 215 | T_reltime: "RELTIME", 216 | T_tinterval: "TINTERVAL", 217 | T_unknown: "UNKNOWN", 218 | T_circle: "CIRCLE", 219 | T__circle: "_CIRCLE", 220 | T_money: "MONEY", 221 | T__money: "_MONEY", 222 | T_macaddr: "MACADDR", 223 | T_inet: "INET", 224 | T__bool: "_BOOL", 225 | T__bytea: "_BYTEA", 226 | T__char: "_CHAR", 227 | T__name: "_NAME", 228 | T__int2: "_INT2", 229 | T__int2vector: "_INT2VECTOR", 230 | T__int4: "_INT4", 231 | T__regproc: "_REGPROC", 232 | T__text: "_TEXT", 233 | T__tid: 
"_TID", 234 | T__xid: "_XID", 235 | T__cid: "_CID", 236 | T__oidvector: "_OIDVECTOR", 237 | T__bpchar: "_BPCHAR", 238 | T__varchar: "_VARCHAR", 239 | T__int8: "_INT8", 240 | T__point: "_POINT", 241 | T__lseg: "_LSEG", 242 | T__path: "_PATH", 243 | T__box: "_BOX", 244 | T__float4: "_FLOAT4", 245 | T__float8: "_FLOAT8", 246 | T__abstime: "_ABSTIME", 247 | T__reltime: "_RELTIME", 248 | T__tinterval: "_TINTERVAL", 249 | T__polygon: "_POLYGON", 250 | T__oid: "_OID", 251 | T_aclitem: "ACLITEM", 252 | T__aclitem: "_ACLITEM", 253 | T__macaddr: "_MACADDR", 254 | T__inet: "_INET", 255 | T_bpchar: "BPCHAR", 256 | T_varchar: "VARCHAR", 257 | T_date: "DATE", 258 | T_time: "TIME", 259 | T_timestamp: "TIMESTAMP", 260 | T__timestamp: "_TIMESTAMP", 261 | T__date: "_DATE", 262 | T__time: "_TIME", 263 | T_timestamptz: "TIMESTAMPTZ", 264 | T__timestamptz: "_TIMESTAMPTZ", 265 | T_interval: "INTERVAL", 266 | T__interval: "_INTERVAL", 267 | T__numeric: "_NUMERIC", 268 | T_pg_database: "PG_DATABASE", 269 | T__cstring: "_CSTRING", 270 | T_timetz: "TIMETZ", 271 | T__timetz: "_TIMETZ", 272 | T_bit: "BIT", 273 | T__bit: "_BIT", 274 | T_varbit: "VARBIT", 275 | T__varbit: "_VARBIT", 276 | T_numeric: "NUMERIC", 277 | T_refcursor: "REFCURSOR", 278 | T__refcursor: "_REFCURSOR", 279 | T_regprocedure: "REGPROCEDURE", 280 | T_regoper: "REGOPER", 281 | T_regoperator: "REGOPERATOR", 282 | T_regclass: "REGCLASS", 283 | T_regtype: "REGTYPE", 284 | T__regprocedure: "_REGPROCEDURE", 285 | T__regoper: "_REGOPER", 286 | T__regoperator: "_REGOPERATOR", 287 | T__regclass: "_REGCLASS", 288 | T__regtype: "_REGTYPE", 289 | T_record: "RECORD", 290 | T_cstring: "CSTRING", 291 | T_any: "ANY", 292 | T_anyarray: "ANYARRAY", 293 | T_void: "VOID", 294 | T_trigger: "TRIGGER", 295 | T_language_handler: "LANGUAGE_HANDLER", 296 | T_internal: "INTERNAL", 297 | T_opaque: "OPAQUE", 298 | T_anyelement: "ANYELEMENT", 299 | T__record: "_RECORD", 300 | T_anynonarray: "ANYNONARRAY", 301 | T_pg_authid: "PG_AUTHID", 302 | 
T_pg_auth_members: "PG_AUTH_MEMBERS", 303 | T__txid_snapshot: "_TXID_SNAPSHOT", 304 | T_uuid: "UUID", 305 | T__uuid: "_UUID", 306 | T_txid_snapshot: "TXID_SNAPSHOT", 307 | T_fdw_handler: "FDW_HANDLER", 308 | T_pg_lsn: "PG_LSN", 309 | T__pg_lsn: "_PG_LSN", 310 | T_tsm_handler: "TSM_HANDLER", 311 | T_anyenum: "ANYENUM", 312 | T_tsvector: "TSVECTOR", 313 | T_tsquery: "TSQUERY", 314 | T_gtsvector: "GTSVECTOR", 315 | T__tsvector: "_TSVECTOR", 316 | T__gtsvector: "_GTSVECTOR", 317 | T__tsquery: "_TSQUERY", 318 | T_regconfig: "REGCONFIG", 319 | T__regconfig: "_REGCONFIG", 320 | T_regdictionary: "REGDICTIONARY", 321 | T__regdictionary: "_REGDICTIONARY", 322 | T_jsonb: "JSONB", 323 | T__jsonb: "_JSONB", 324 | T_anyrange: "ANYRANGE", 325 | T_event_trigger: "EVENT_TRIGGER", 326 | T_int4range: "INT4RANGE", 327 | T__int4range: "_INT4RANGE", 328 | T_numrange: "NUMRANGE", 329 | T__numrange: "_NUMRANGE", 330 | T_tsrange: "TSRANGE", 331 | T__tsrange: "_TSRANGE", 332 | T_tstzrange: "TSTZRANGE", 333 | T__tstzrange: "_TSTZRANGE", 334 | T_daterange: "DATERANGE", 335 | T__daterange: "_DATERANGE", 336 | T_int8range: "INT8RANGE", 337 | T__int8range: "_INT8RANGE", 338 | T_pg_shseclabel: "PG_SHSECLABEL", 339 | T_regnamespace: "REGNAMESPACE", 340 | T__regnamespace: "_REGNAMESPACE", 341 | T_regrole: "REGROLE", 342 | T__regrole: "_REGROLE", 343 | } 344 | -------------------------------------------------------------------------------- /rows.go: -------------------------------------------------------------------------------- 1 | package pq 2 | 3 | import ( 4 | "math" 5 | "reflect" 6 | "time" 7 | 8 | "github.com/lib/pq/oid" 9 | ) 10 | 11 | const headerSize = 4 12 | 13 | type fieldDesc struct { 14 | // The object ID of the data type. 15 | OID oid.Oid 16 | // The data type size (see pg_type.typlen). 17 | // Note that negative values denote variable-width types. 18 | Len int 19 | // The type modifier (see pg_attribute.atttypmod). 20 | // The meaning of the modifier is type-specific. 
21 | Mod int 22 | } 23 | 24 | func (fd fieldDesc) Type() reflect.Type { 25 | switch fd.OID { 26 | case oid.T_int8: 27 | return reflect.TypeOf(int64(0)) 28 | case oid.T_int4: 29 | return reflect.TypeOf(int32(0)) 30 | case oid.T_int2: 31 | return reflect.TypeOf(int16(0)) 32 | case oid.T_varchar, oid.T_text: 33 | return reflect.TypeOf("") 34 | case oid.T_bool: 35 | return reflect.TypeOf(false) 36 | case oid.T_date, oid.T_time, oid.T_timetz, oid.T_timestamp, oid.T_timestamptz: 37 | return reflect.TypeOf(time.Time{}) 38 | case oid.T_bytea: 39 | return reflect.TypeOf([]byte(nil)) 40 | default: 41 | return reflect.TypeOf(new(interface{})).Elem() 42 | } 43 | } 44 | 45 | func (fd fieldDesc) Name() string { 46 | return oid.TypeName[fd.OID] 47 | } 48 | 49 | func (fd fieldDesc) Length() (length int64, ok bool) { 50 | switch fd.OID { 51 | case oid.T_text, oid.T_bytea: 52 | return math.MaxInt64, true 53 | case oid.T_varchar, oid.T_bpchar: 54 | return int64(fd.Mod - headerSize), true 55 | default: 56 | return 0, false 57 | } 58 | } 59 | 60 | func (fd fieldDesc) PrecisionScale() (precision, scale int64, ok bool) { 61 | switch fd.OID { 62 | case oid.T_numeric, oid.T__numeric: 63 | mod := fd.Mod - headerSize 64 | precision = int64((mod >> 16) & 0xffff) 65 | scale = int64(mod & 0xffff) 66 | return precision, scale, true 67 | default: 68 | return 0, 0, false 69 | } 70 | } 71 | 72 | // ColumnTypeScanType returns the value type that can be used to scan types into. 73 | func (rs *rows) ColumnTypeScanType(index int) reflect.Type { 74 | return rs.colTyps[index].Type() 75 | } 76 | 77 | // ColumnTypeDatabaseTypeName return the database system type name. 78 | func (rs *rows) ColumnTypeDatabaseTypeName(index int) string { 79 | return rs.colTyps[index].Name() 80 | } 81 | 82 | // ColumnTypeLength returns the length of the column type if the column is a 83 | // variable length type. If the column is not a variable length type ok 84 | // should return false. 
85 | func (rs *rows) ColumnTypeLength(index int) (length int64, ok bool) { 86 | return rs.colTyps[index].Length() 87 | } 88 | 89 | // ColumnTypePrecisionScale should return the precision and scale for decimal 90 | // types. If not applicable, ok should be false. 91 | func (rs *rows) ColumnTypePrecisionScale(index int) (precision, scale int64, ok bool) { 92 | return rs.colTyps[index].PrecisionScale() 93 | } 94 | -------------------------------------------------------------------------------- /rows_test.go: -------------------------------------------------------------------------------- 1 | package pq 2 | 3 | import ( 4 | "math" 5 | "reflect" 6 | "testing" 7 | 8 | "github.com/lib/pq/oid" 9 | ) 10 | 11 | func TestDataTypeName(t *testing.T) { 12 | tts := []struct { 13 | typ oid.Oid 14 | name string 15 | }{ 16 | {oid.T_int8, "INT8"}, 17 | {oid.T_int4, "INT4"}, 18 | {oid.T_int2, "INT2"}, 19 | {oid.T_varchar, "VARCHAR"}, 20 | {oid.T_text, "TEXT"}, 21 | {oid.T_bool, "BOOL"}, 22 | {oid.T_numeric, "NUMERIC"}, 23 | {oid.T_date, "DATE"}, 24 | {oid.T_time, "TIME"}, 25 | {oid.T_timetz, "TIMETZ"}, 26 | {oid.T_timestamp, "TIMESTAMP"}, 27 | {oid.T_timestamptz, "TIMESTAMPTZ"}, 28 | {oid.T_bytea, "BYTEA"}, 29 | } 30 | 31 | for i, tt := range tts { 32 | dt := fieldDesc{OID: tt.typ} 33 | if name := dt.Name(); name != tt.name { 34 | t.Errorf("(%d) got: %s want: %s", i, name, tt.name) 35 | } 36 | } 37 | } 38 | 39 | func TestDataType(t *testing.T) { 40 | tts := []struct { 41 | typ oid.Oid 42 | kind reflect.Kind 43 | }{ 44 | {oid.T_int8, reflect.Int64}, 45 | {oid.T_int4, reflect.Int32}, 46 | {oid.T_int2, reflect.Int16}, 47 | {oid.T_varchar, reflect.String}, 48 | {oid.T_text, reflect.String}, 49 | {oid.T_bool, reflect.Bool}, 50 | {oid.T_date, reflect.Struct}, 51 | {oid.T_time, reflect.Struct}, 52 | {oid.T_timetz, reflect.Struct}, 53 | {oid.T_timestamp, reflect.Struct}, 54 | {oid.T_timestamptz, reflect.Struct}, 55 | {oid.T_bytea, reflect.Slice}, 56 | } 57 | 58 | for i, tt := range tts { 59 | 
dt := fieldDesc{OID: tt.typ} 60 | if kind := dt.Type().Kind(); kind != tt.kind { 61 | t.Errorf("(%d) got: %s want: %s", i, kind, tt.kind) 62 | } 63 | } 64 | } 65 | 66 | func TestDataTypeLength(t *testing.T) { 67 | tts := []struct { 68 | typ oid.Oid 69 | len int 70 | mod int 71 | length int64 72 | ok bool 73 | }{ 74 | {oid.T_int4, 0, -1, 0, false}, 75 | {oid.T_varchar, 65535, 9, 5, true}, 76 | {oid.T_text, 65535, -1, math.MaxInt64, true}, 77 | {oid.T_bytea, 65535, -1, math.MaxInt64, true}, 78 | } 79 | 80 | for i, tt := range tts { 81 | dt := fieldDesc{OID: tt.typ, Len: tt.len, Mod: tt.mod} 82 | if l, k := dt.Length(); k != tt.ok || l != tt.length { 83 | t.Errorf("(%d) got: %d, %t want: %d, %t", i, l, k, tt.length, tt.ok) 84 | } 85 | } 86 | } 87 | 88 | func TestDataTypePrecisionScale(t *testing.T) { 89 | tts := []struct { 90 | typ oid.Oid 91 | mod int 92 | precision, scale int64 93 | ok bool 94 | }{ 95 | {oid.T_int4, -1, 0, 0, false}, 96 | {oid.T_numeric, 589830, 9, 2, true}, 97 | {oid.T_text, -1, 0, 0, false}, 98 | } 99 | 100 | for i, tt := range tts { 101 | dt := fieldDesc{OID: tt.typ, Mod: tt.mod} 102 | p, s, k := dt.PrecisionScale() 103 | if k != tt.ok { 104 | t.Errorf("(%d) got: %t want: %t", i, k, tt.ok) 105 | } 106 | if p != tt.precision { 107 | t.Errorf("(%d) wrong precision got: %d want: %d", i, p, tt.precision) 108 | } 109 | if s != tt.scale { 110 | t.Errorf("(%d) wrong scale got: %d want: %d", i, s, tt.scale) 111 | } 112 | } 113 | } 114 | 115 | func TestRowsColumnTypes(t *testing.T) { 116 | columnTypesTests := []struct { 117 | Name string 118 | TypeName string 119 | Length struct { 120 | Len int64 121 | OK bool 122 | } 123 | DecimalSize struct { 124 | Precision int64 125 | Scale int64 126 | OK bool 127 | } 128 | ScanType reflect.Type 129 | }{ 130 | { 131 | Name: "a", 132 | TypeName: "INT4", 133 | Length: struct { 134 | Len int64 135 | OK bool 136 | }{ 137 | Len: 0, 138 | OK: false, 139 | }, 140 | DecimalSize: struct { 141 | Precision int64 142 | Scale 
int64 143 | OK bool 144 | }{ 145 | Precision: 0, 146 | Scale: 0, 147 | OK: false, 148 | }, 149 | ScanType: reflect.TypeOf(int32(0)), 150 | }, { 151 | Name: "bar", 152 | TypeName: "TEXT", 153 | Length: struct { 154 | Len int64 155 | OK bool 156 | }{ 157 | Len: math.MaxInt64, 158 | OK: true, 159 | }, 160 | DecimalSize: struct { 161 | Precision int64 162 | Scale int64 163 | OK bool 164 | }{ 165 | Precision: 0, 166 | Scale: 0, 167 | OK: false, 168 | }, 169 | ScanType: reflect.TypeOf(""), 170 | }, 171 | } 172 | 173 | db := openTestConn(t) 174 | defer db.Close() 175 | 176 | rows, err := db.Query("SELECT 1 AS a, text 'bar' AS bar, 1.28::numeric(9, 2) AS dec") 177 | if err != nil { 178 | t.Fatal(err) 179 | } 180 | 181 | columns, err := rows.ColumnTypes() 182 | if err != nil { 183 | t.Fatal(err) 184 | } 185 | if len(columns) != 3 { 186 | t.Errorf("expected 3 columns found %d", len(columns)) 187 | } 188 | 189 | for i, tt := range columnTypesTests { 190 | c := columns[i] 191 | if c.Name() != tt.Name { 192 | t.Errorf("(%d) got: %s, want: %s", i, c.Name(), tt.Name) 193 | } 194 | if c.DatabaseTypeName() != tt.TypeName { 195 | t.Errorf("(%d) got: %s, want: %s", i, c.DatabaseTypeName(), tt.TypeName) 196 | } 197 | l, ok := c.Length() 198 | if l != tt.Length.Len { 199 | t.Errorf("(%d) got: %d, want: %d", i, l, tt.Length.Len) 200 | } 201 | if ok != tt.Length.OK { 202 | t.Errorf("(%d) got: %t, want: %t", i, ok, tt.Length.OK) 203 | } 204 | p, s, ok := c.DecimalSize() 205 | if p != tt.DecimalSize.Precision { 206 | t.Errorf("(%d) got: %d, want: %d", i, p, tt.DecimalSize.Precision) 207 | } 208 | if s != tt.DecimalSize.Scale { 209 | t.Errorf("(%d) got: %d, want: %d", i, s, tt.DecimalSize.Scale) 210 | } 211 | if ok != tt.DecimalSize.OK { 212 | t.Errorf("(%d) got: %t, want: %t", i, ok, tt.DecimalSize.OK) 213 | } 214 | if c.ScanType() != tt.ScanType { 215 | t.Errorf("(%d) got: %v, want: %v", i, c.ScanType(), tt.ScanType) 216 | } 217 | } 218 | } 219 | 
-------------------------------------------------------------------------------- /scram/scram.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2014 - Gustavo Niemeyer 2 | // 3 | // All rights reserved. 4 | // 5 | // Redistribution and use in source and binary forms, with or without 6 | // modification, are permitted provided that the following conditions are met: 7 | // 8 | // 1. Redistributions of source code must retain the above copyright notice, this 9 | // list of conditions and the following disclaimer. 10 | // 2. Redistributions in binary form must reproduce the above copyright notice, 11 | // this list of conditions and the following disclaimer in the documentation 12 | // and/or other materials provided with the distribution. 13 | // 14 | // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 15 | // ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 16 | // WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 17 | // DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR 18 | // ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 19 | // (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 20 | // LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND 21 | // ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 22 | // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 23 | // SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 24 | 25 | // Package scram implements a SCRAM-{SHA-1,etc} client per RFC5802. 
26 | // 27 | // http://tools.ietf.org/html/rfc5802 28 | // 29 | package scram 30 | 31 | import ( 32 | "bytes" 33 | "crypto/hmac" 34 | "crypto/rand" 35 | "encoding/base64" 36 | "fmt" 37 | "hash" 38 | "strconv" 39 | "strings" 40 | ) 41 | 42 | // Client implements a SCRAM-* client (SCRAM-SHA-1, SCRAM-SHA-256, etc). 43 | // 44 | // A Client may be used within a SASL conversation with logic resembling: 45 | // 46 | // var in []byte 47 | // var client = scram.NewClient(sha1.New, user, pass) 48 | // for client.Step(in) { 49 | // out := client.Out() 50 | // // send out to server 51 | // in := serverOut 52 | // } 53 | // if client.Err() != nil { 54 | // // auth failed 55 | // } 56 | // 57 | type Client struct { 58 | newHash func() hash.Hash 59 | 60 | user string 61 | pass string 62 | step int 63 | out bytes.Buffer 64 | err error 65 | 66 | clientNonce []byte 67 | serverNonce []byte 68 | saltedPass []byte 69 | authMsg bytes.Buffer 70 | } 71 | 72 | // NewClient returns a new SCRAM-* client with the provided hash algorithm. 73 | // 74 | // For SCRAM-SHA-256, for example, use: 75 | // 76 | // client := scram.NewClient(sha256.New, user, pass) 77 | // 78 | func NewClient(newHash func() hash.Hash, user, pass string) *Client { 79 | c := &Client{ 80 | newHash: newHash, 81 | user: user, 82 | pass: pass, 83 | } 84 | c.out.Grow(256) 85 | c.authMsg.Grow(256) 86 | return c 87 | } 88 | 89 | // Out returns the data to be sent to the server in the current step. 90 | func (c *Client) Out() []byte { 91 | if c.out.Len() == 0 { 92 | return nil 93 | } 94 | return c.out.Bytes() 95 | } 96 | 97 | // Err returns the error that occurred, or nil if there were no errors. 98 | func (c *Client) Err() error { 99 | return c.err 100 | } 101 | 102 | // SetNonce sets the client nonce to the provided value. 103 | // If not set, the nonce is generated automatically out of crypto/rand on the first step. 
104 | func (c *Client) SetNonce(nonce []byte) { 105 | c.clientNonce = nonce 106 | } 107 | 108 | var escaper = strings.NewReplacer("=", "=3D", ",", "=2C") 109 | 110 | // Step processes the incoming data from the server and makes the 111 | // next round of data for the server available via Client.Out. 112 | // Step returns false if there are no errors and more data is 113 | // still expected. 114 | func (c *Client) Step(in []byte) bool { 115 | c.out.Reset() 116 | if c.step > 2 || c.err != nil { 117 | return false 118 | } 119 | c.step++ 120 | switch c.step { 121 | case 1: 122 | c.err = c.step1(in) 123 | case 2: 124 | c.err = c.step2(in) 125 | case 3: 126 | c.err = c.step3(in) 127 | } 128 | return c.step > 2 || c.err != nil 129 | } 130 | 131 | func (c *Client) step1(in []byte) error { 132 | if len(c.clientNonce) == 0 { 133 | const nonceLen = 16 134 | buf := make([]byte, nonceLen+b64.EncodedLen(nonceLen)) 135 | if _, err := rand.Read(buf[:nonceLen]); err != nil { 136 | return fmt.Errorf("cannot read random SCRAM-SHA-256 nonce from operating system: %v", err) 137 | } 138 | c.clientNonce = buf[nonceLen:] 139 | b64.Encode(c.clientNonce, buf[:nonceLen]) 140 | } 141 | c.authMsg.WriteString("n=") 142 | escaper.WriteString(&c.authMsg, c.user) 143 | c.authMsg.WriteString(",r=") 144 | c.authMsg.Write(c.clientNonce) 145 | 146 | c.out.WriteString("n,,") 147 | c.out.Write(c.authMsg.Bytes()) 148 | return nil 149 | } 150 | 151 | var b64 = base64.StdEncoding 152 | 153 | func (c *Client) step2(in []byte) error { 154 | c.authMsg.WriteByte(',') 155 | c.authMsg.Write(in) 156 | 157 | fields := bytes.Split(in, []byte(",")) 158 | if len(fields) != 3 { 159 | return fmt.Errorf("expected 3 fields in first SCRAM-SHA-256 server message, got %d: %q", len(fields), in) 160 | } 161 | if !bytes.HasPrefix(fields[0], []byte("r=")) || len(fields[0]) < 2 { 162 | return fmt.Errorf("server sent an invalid SCRAM-SHA-256 nonce: %q", fields[0]) 163 | } 164 | if !bytes.HasPrefix(fields[1], []byte("s=")) || 
len(fields[1]) < 6 { 165 | return fmt.Errorf("server sent an invalid SCRAM-SHA-256 salt: %q", fields[1]) 166 | } 167 | if !bytes.HasPrefix(fields[2], []byte("i=")) || len(fields[2]) < 6 { 168 | return fmt.Errorf("server sent an invalid SCRAM-SHA-256 iteration count: %q", fields[2]) 169 | } 170 | 171 | c.serverNonce = fields[0][2:] 172 | if !bytes.HasPrefix(c.serverNonce, c.clientNonce) { 173 | return fmt.Errorf("server SCRAM-SHA-256 nonce is not prefixed by client nonce: got %q, want %q+\"...\"", c.serverNonce, c.clientNonce) 174 | } 175 | 176 | salt := make([]byte, b64.DecodedLen(len(fields[1][2:]))) 177 | n, err := b64.Decode(salt, fields[1][2:]) 178 | if err != nil { 179 | return fmt.Errorf("cannot decode SCRAM-SHA-256 salt sent by server: %q", fields[1]) 180 | } 181 | salt = salt[:n] 182 | iterCount, err := strconv.Atoi(string(fields[2][2:])) 183 | if err != nil { 184 | return fmt.Errorf("server sent an invalid SCRAM-SHA-256 iteration count: %q", fields[2]) 185 | } 186 | c.saltPassword(salt, iterCount) 187 | 188 | c.authMsg.WriteString(",c=biws,r=") 189 | c.authMsg.Write(c.serverNonce) 190 | 191 | c.out.WriteString("c=biws,r=") 192 | c.out.Write(c.serverNonce) 193 | c.out.WriteString(",p=") 194 | c.out.Write(c.clientProof()) 195 | return nil 196 | } 197 | 198 | func (c *Client) step3(in []byte) error { 199 | var isv, ise bool 200 | var fields = bytes.Split(in, []byte(",")) 201 | if len(fields) == 1 { 202 | isv = bytes.HasPrefix(fields[0], []byte("v=")) 203 | ise = bytes.HasPrefix(fields[0], []byte("e=")) 204 | } 205 | if ise { 206 | return fmt.Errorf("SCRAM-SHA-256 authentication error: %s", fields[0][2:]) 207 | } else if !isv { 208 | return fmt.Errorf("unsupported SCRAM-SHA-256 final message from server: %q", in) 209 | } 210 | if !bytes.Equal(c.serverSignature(), fields[0][2:]) { 211 | return fmt.Errorf("cannot authenticate SCRAM-SHA-256 server signature: %q", fields[0][2:]) 212 | } 213 | return nil 214 | } 215 | 216 | func (c *Client) saltPassword(salt 
[]byte, iterCount int) { 217 | mac := hmac.New(c.newHash, []byte(c.pass)) 218 | mac.Write(salt) 219 | mac.Write([]byte{0, 0, 0, 1}) 220 | ui := mac.Sum(nil) 221 | hi := make([]byte, len(ui)) 222 | copy(hi, ui) 223 | for i := 1; i < iterCount; i++ { 224 | mac.Reset() 225 | mac.Write(ui) 226 | mac.Sum(ui[:0]) 227 | for j, b := range ui { 228 | hi[j] ^= b 229 | } 230 | } 231 | c.saltedPass = hi 232 | } 233 | 234 | func (c *Client) clientProof() []byte { 235 | mac := hmac.New(c.newHash, c.saltedPass) 236 | mac.Write([]byte("Client Key")) 237 | clientKey := mac.Sum(nil) 238 | hash := c.newHash() 239 | hash.Write(clientKey) 240 | storedKey := hash.Sum(nil) 241 | mac = hmac.New(c.newHash, storedKey) 242 | mac.Write(c.authMsg.Bytes()) 243 | clientProof := mac.Sum(nil) 244 | for i, b := range clientKey { 245 | clientProof[i] ^= b 246 | } 247 | clientProof64 := make([]byte, b64.EncodedLen(len(clientProof))) 248 | b64.Encode(clientProof64, clientProof) 249 | return clientProof64 250 | } 251 | 252 | func (c *Client) serverSignature() []byte { 253 | mac := hmac.New(c.newHash, c.saltedPass) 254 | mac.Write([]byte("Server Key")) 255 | serverKey := mac.Sum(nil) 256 | 257 | mac = hmac.New(c.newHash, serverKey) 258 | mac.Write(c.authMsg.Bytes()) 259 | serverSignature := mac.Sum(nil) 260 | 261 | encoded := make([]byte, b64.EncodedLen(len(serverSignature))) 262 | b64.Encode(encoded, serverSignature) 263 | return encoded 264 | } 265 | -------------------------------------------------------------------------------- /ssl.go: -------------------------------------------------------------------------------- 1 | package pq 2 | 3 | import ( 4 | "crypto/tls" 5 | "crypto/x509" 6 | "io/ioutil" 7 | "net" 8 | "os" 9 | "os/user" 10 | "path/filepath" 11 | "strings" 12 | ) 13 | 14 | // ssl generates a function to upgrade a net.Conn based on the "sslmode" and 15 | // related settings. The function is nil when no upgrade should take place. 
16 | func ssl(o values) (func(net.Conn) (net.Conn, error), error) { 17 | verifyCaOnly := false 18 | tlsConf := tls.Config{} 19 | switch mode := o["sslmode"]; mode { 20 | // "require" is the default. 21 | case "", "require": 22 | // We must skip TLS's own verification since it requires full 23 | // verification since Go 1.3. 24 | tlsConf.InsecureSkipVerify = true 25 | 26 | // From http://www.postgresql.org/docs/current/static/libpq-ssl.html: 27 | // 28 | // Note: For backwards compatibility with earlier versions of 29 | // PostgreSQL, if a root CA file exists, the behavior of 30 | // sslmode=require will be the same as that of verify-ca, meaning the 31 | // server certificate is validated against the CA. Relying on this 32 | // behavior is discouraged, and applications that need certificate 33 | // validation should always use verify-ca or verify-full. 34 | if sslrootcert, ok := o["sslrootcert"]; ok { 35 | if _, err := os.Stat(sslrootcert); err == nil { 36 | verifyCaOnly = true 37 | } else { 38 | delete(o, "sslrootcert") 39 | } 40 | } 41 | case "verify-ca": 42 | // We must skip TLS's own verification since it requires full 43 | // verification since Go 1.3. 44 | tlsConf.InsecureSkipVerify = true 45 | verifyCaOnly = true 46 | case "verify-full": 47 | tlsConf.ServerName = o["host"] 48 | case "disable": 49 | return nil, nil 50 | default: 51 | return nil, fmterrorf(`unsupported sslmode %q; only "require" (default), "verify-full", "verify-ca", and "disable" supported`, mode) 52 | } 53 | 54 | // Set Server Name Indication (SNI), if enabled by connection parameters. 55 | // By default SNI is on, any value which is not starting with "1" disables 56 | // SNI -- that is the same check vanilla libpq uses. 57 | if sslsni := o["sslsni"]; sslsni == "" || strings.HasPrefix(sslsni, "1") { 58 | // RFC 6066 asks to not set SNI if the host is a literal IP address (IPv4 59 | // or IPv6). 
This check is coded already crypto.tls.hostnameInSNI, so 60 | // just always set ServerName here and let crypto/tls do the filtering. 61 | tlsConf.ServerName = o["host"] 62 | } 63 | 64 | err := sslClientCertificates(&tlsConf, o) 65 | if err != nil { 66 | return nil, err 67 | } 68 | err = sslCertificateAuthority(&tlsConf, o) 69 | if err != nil { 70 | return nil, err 71 | } 72 | 73 | // Accept renegotiation requests initiated by the backend. 74 | // 75 | // Renegotiation was deprecated then removed from PostgreSQL 9.5, but 76 | // the default configuration of older versions has it enabled. Redshift 77 | // also initiates renegotiations and cannot be reconfigured. 78 | tlsConf.Renegotiation = tls.RenegotiateFreelyAsClient 79 | 80 | return func(conn net.Conn) (net.Conn, error) { 81 | client := tls.Client(conn, &tlsConf) 82 | if verifyCaOnly { 83 | err := sslVerifyCertificateAuthority(client, &tlsConf) 84 | if err != nil { 85 | return nil, err 86 | } 87 | } 88 | return client, nil 89 | }, nil 90 | } 91 | 92 | // sslClientCertificates adds the certificate specified in the "sslcert" and 93 | // "sslkey" settings, or if they aren't set, from the .postgresql directory 94 | // in the user's home directory. The configured files must exist and have 95 | // the correct permissions. 96 | func sslClientCertificates(tlsConf *tls.Config, o values) error { 97 | sslinline := o["sslinline"] 98 | if sslinline == "true" { 99 | cert, err := tls.X509KeyPair([]byte(o["sslcert"]), []byte(o["sslkey"])) 100 | if err != nil { 101 | return err 102 | } 103 | tlsConf.Certificates = []tls.Certificate{cert} 104 | return nil 105 | } 106 | 107 | // user.Current() might fail when cross-compiling. We have to ignore the 108 | // error and continue without home directory defaults, since we wouldn't 109 | // know from where to load them. 110 | user, _ := user.Current() 111 | 112 | // In libpq, the client certificate is only loaded if the setting is not blank. 
113 | // 114 | // https://github.com/postgres/postgres/blob/REL9_6_2/src/interfaces/libpq/fe-secure-openssl.c#L1036-L1037 115 | sslcert := o["sslcert"] 116 | if len(sslcert) == 0 && user != nil { 117 | sslcert = filepath.Join(user.HomeDir, ".postgresql", "postgresql.crt") 118 | } 119 | // https://github.com/postgres/postgres/blob/REL9_6_2/src/interfaces/libpq/fe-secure-openssl.c#L1045 120 | if len(sslcert) == 0 { 121 | return nil 122 | } 123 | // https://github.com/postgres/postgres/blob/REL9_6_2/src/interfaces/libpq/fe-secure-openssl.c#L1050:L1054 124 | if _, err := os.Stat(sslcert); os.IsNotExist(err) { 125 | return nil 126 | } else if err != nil { 127 | return err 128 | } 129 | 130 | // In libpq, the ssl key is only loaded if the setting is not blank. 131 | // 132 | // https://github.com/postgres/postgres/blob/REL9_6_2/src/interfaces/libpq/fe-secure-openssl.c#L1123-L1222 133 | sslkey := o["sslkey"] 134 | if len(sslkey) == 0 && user != nil { 135 | sslkey = filepath.Join(user.HomeDir, ".postgresql", "postgresql.key") 136 | } 137 | 138 | if len(sslkey) > 0 { 139 | if err := sslKeyPermissions(sslkey); err != nil { 140 | return err 141 | } 142 | } 143 | 144 | cert, err := tls.LoadX509KeyPair(sslcert, sslkey) 145 | if err != nil { 146 | return err 147 | } 148 | 149 | tlsConf.Certificates = []tls.Certificate{cert} 150 | return nil 151 | } 152 | 153 | // sslCertificateAuthority adds the RootCA specified in the "sslrootcert" setting. 154 | func sslCertificateAuthority(tlsConf *tls.Config, o values) error { 155 | // In libpq, the root certificate is only loaded if the setting is not blank. 
156 | // 157 | // https://github.com/postgres/postgres/blob/REL9_6_2/src/interfaces/libpq/fe-secure-openssl.c#L950-L951 158 | if sslrootcert := o["sslrootcert"]; len(sslrootcert) > 0 { 159 | tlsConf.RootCAs = x509.NewCertPool() 160 | 161 | sslinline := o["sslinline"] 162 | 163 | var cert []byte 164 | if sslinline == "true" { 165 | cert = []byte(sslrootcert) 166 | } else { 167 | var err error 168 | cert, err = ioutil.ReadFile(sslrootcert) 169 | if err != nil { 170 | return err 171 | } 172 | } 173 | 174 | if !tlsConf.RootCAs.AppendCertsFromPEM(cert) { 175 | return fmterrorf("couldn't parse pem in sslrootcert") 176 | } 177 | } 178 | 179 | return nil 180 | } 181 | 182 | // sslVerifyCertificateAuthority carries out a TLS handshake to the server and 183 | // verifies the presented certificate against the CA, i.e. the one specified in 184 | // sslrootcert or the system CA if sslrootcert was not specified. 185 | func sslVerifyCertificateAuthority(client *tls.Conn, tlsConf *tls.Config) error { 186 | err := client.Handshake() 187 | if err != nil { 188 | return err 189 | } 190 | certs := client.ConnectionState().PeerCertificates 191 | opts := x509.VerifyOptions{ 192 | DNSName: client.ConnectionState().ServerName, 193 | Intermediates: x509.NewCertPool(), 194 | Roots: tlsConf.RootCAs, 195 | } 196 | for i, cert := range certs { 197 | if i == 0 { 198 | continue 199 | } 200 | opts.Intermediates.AddCert(cert) 201 | } 202 | _, err = certs[0].Verify(opts) 203 | return err 204 | } 205 | -------------------------------------------------------------------------------- /ssl_permissions.go: -------------------------------------------------------------------------------- 1 | //go:build !windows 2 | // +build !windows 3 | 4 | package pq 5 | 6 | import ( 7 | "errors" 8 | "os" 9 | "syscall" 10 | ) 11 | 12 | const ( 13 | rootUserID = uint32(0) 14 | 15 | // The maximum permissions that a private key file owned by a regular user 16 | // is allowed to have. This translates to u=rw. 
17 | maxUserOwnedKeyPermissions os.FileMode = 0600 18 | 19 | // The maximum permissions that a private key file owned by root is allowed 20 | // to have. This translates to u=rw,g=r. 21 | maxRootOwnedKeyPermissions os.FileMode = 0640 22 | ) 23 | 24 | var ( 25 | errSSLKeyHasUnacceptableUserPermissions = errors.New("permissions for files not owned by root should be u=rw (0600) or less") 26 | errSSLKeyHasUnacceptableRootPermissions = errors.New("permissions for root owned files should be u=rw,g=r (0640) or less") 27 | ) 28 | 29 | // sslKeyPermissions checks the permissions on user-supplied ssl key files. 30 | // The key file should have very little access. 31 | // 32 | // libpq does not check key file permissions on Windows. 33 | func sslKeyPermissions(sslkey string) error { 34 | info, err := os.Stat(sslkey) 35 | if err != nil { 36 | return err 37 | } 38 | 39 | err = hasCorrectPermissions(info) 40 | 41 | // return ErrSSLKeyHasWorldPermissions for backwards compatability with 42 | // existing code. 43 | if err == errSSLKeyHasUnacceptableUserPermissions || err == errSSLKeyHasUnacceptableRootPermissions { 44 | err = ErrSSLKeyHasWorldPermissions 45 | } 46 | return err 47 | } 48 | 49 | // hasCorrectPermissions checks the file info (and the unix-specific stat_t 50 | // output) to verify that the permissions on the file are correct. 51 | // 52 | // If the file is owned by the same user the process is running as, 53 | // the file should only have 0600 (u=rw). If the file is owned by root, 54 | // and the group matches the group that the process is running in, the 55 | // permissions cannot be more than 0640 (u=rw,g=r). The file should 56 | // never have world permissions. 57 | // 58 | // Returns an error when the permission check fails. 59 | func hasCorrectPermissions(info os.FileInfo) error { 60 | // if file's permission matches 0600, allow access. 
61 | userPermissionMask := (os.FileMode(0777) ^ maxUserOwnedKeyPermissions) 62 | 63 | // regardless of if we're running as root or not, 0600 is acceptable, 64 | // so we return if we match the regular user permission mask. 65 | if info.Mode().Perm()&userPermissionMask == 0 { 66 | return nil 67 | } 68 | 69 | // We need to pull the Unix file information to get the file's owner. 70 | // If we can't access it, there's some sort of operating system level error 71 | // and we should fail rather than attempting to use faulty information. 72 | sysInfo := info.Sys() 73 | if sysInfo == nil { 74 | return ErrSSLKeyUnknownOwnership 75 | } 76 | 77 | unixStat, ok := sysInfo.(*syscall.Stat_t) 78 | if !ok { 79 | return ErrSSLKeyUnknownOwnership 80 | } 81 | 82 | // if the file is owned by root, we allow 0640 (u=rw,g=r) to match what 83 | // Postgres does. 84 | if unixStat.Uid == rootUserID { 85 | rootPermissionMask := (os.FileMode(0777) ^ maxRootOwnedKeyPermissions) 86 | if info.Mode().Perm()&rootPermissionMask != 0 { 87 | return errSSLKeyHasUnacceptableRootPermissions 88 | } 89 | return nil 90 | } 91 | 92 | return errSSLKeyHasUnacceptableUserPermissions 93 | } 94 | -------------------------------------------------------------------------------- /ssl_permissions_test.go: -------------------------------------------------------------------------------- 1 | //go:build !windows 2 | // +build !windows 3 | 4 | package pq 5 | 6 | import ( 7 | "os" 8 | "syscall" 9 | "testing" 10 | "time" 11 | ) 12 | 13 | type stat_t_wrapper struct { 14 | stat syscall.Stat_t 15 | } 16 | 17 | func (stat_t *stat_t_wrapper) Name() string { 18 | return "pem.key" 19 | } 20 | 21 | func (stat_t *stat_t_wrapper) Size() int64 { 22 | return int64(100) 23 | } 24 | 25 | func (stat_t *stat_t_wrapper) Mode() os.FileMode { 26 | return os.FileMode(stat_t.stat.Mode) 27 | } 28 | 29 | func (stat_t *stat_t_wrapper) ModTime() time.Time { 30 | return time.Now() 31 | } 32 | 33 | func (stat_t *stat_t_wrapper) IsDir() bool { 34 | 
return true 35 | } 36 | 37 | func (stat_t *stat_t_wrapper) Sys() interface{} { 38 | return &stat_t.stat 39 | } 40 | 41 | func TestHasCorrectRootGroupPermissions(t *testing.T) { 42 | currentUID := uint32(os.Getuid()) 43 | currentGID := uint32(os.Getgid()) 44 | 45 | testData := []struct { 46 | expectedError error 47 | stat syscall.Stat_t 48 | }{ 49 | { 50 | expectedError: nil, 51 | stat: syscall.Stat_t{ 52 | Mode: 0600, 53 | Uid: currentUID, 54 | Gid: currentGID, 55 | }, 56 | }, 57 | { 58 | expectedError: nil, 59 | stat: syscall.Stat_t{ 60 | Mode: 0640, 61 | Uid: 0, 62 | Gid: currentGID, 63 | }, 64 | }, 65 | { 66 | expectedError: errSSLKeyHasUnacceptableUserPermissions, 67 | stat: syscall.Stat_t{ 68 | Mode: 0666, 69 | Uid: currentUID, 70 | Gid: currentGID, 71 | }, 72 | }, 73 | { 74 | expectedError: errSSLKeyHasUnacceptableRootPermissions, 75 | stat: syscall.Stat_t{ 76 | Mode: 0666, 77 | Uid: 0, 78 | Gid: currentGID, 79 | }, 80 | }, 81 | } 82 | 83 | for _, test := range testData { 84 | wrapper := &stat_t_wrapper{ 85 | stat: test.stat, 86 | } 87 | 88 | if test.expectedError != hasCorrectPermissions(wrapper) { 89 | if test.expectedError == nil { 90 | t.Errorf( 91 | "file owned by %d:%d with %s should not have failed check with error \"%s\"", 92 | test.stat.Uid, 93 | test.stat.Gid, 94 | wrapper.Mode(), 95 | hasCorrectPermissions(wrapper), 96 | ) 97 | continue 98 | } 99 | t.Errorf( 100 | "file owned by %d:%d with %s, expected \"%s\", got \"%s\"", 101 | test.stat.Uid, 102 | test.stat.Gid, 103 | wrapper.Mode(), 104 | test.expectedError, 105 | hasCorrectPermissions(wrapper), 106 | ) 107 | } 108 | } 109 | } 110 | -------------------------------------------------------------------------------- /ssl_test.go: -------------------------------------------------------------------------------- 1 | package pq 2 | 3 | // This file contains SSL tests 4 | 5 | import ( 6 | "bytes" 7 | _ "crypto/sha256" 8 | "crypto/tls" 9 | "crypto/x509" 10 | "database/sql" 11 | "errors" 12 | "fmt" 13 | 
"io" 14 | "net" 15 | "os" 16 | "path/filepath" 17 | "strings" 18 | "testing" 19 | "time" 20 | ) 21 | 22 | func maybeSkipSSLTests(t *testing.T) { 23 | // Require some special variables for testing certificates 24 | if os.Getenv("PQSSLCERTTEST_PATH") == "" { 25 | t.Skip("PQSSLCERTTEST_PATH not set, skipping SSL tests") 26 | } 27 | 28 | value := os.Getenv("PQGOSSLTESTS") 29 | if value == "" || value == "0" { 30 | t.Skip("PQGOSSLTESTS not enabled, skipping SSL tests") 31 | } else if value != "1" { 32 | t.Fatalf("unexpected value %q for PQGOSSLTESTS", value) 33 | } 34 | } 35 | 36 | func openSSLConn(t *testing.T, conninfo string) (*sql.DB, error) { 37 | db, err := openTestConnConninfo(conninfo) 38 | if err != nil { 39 | // should never fail 40 | t.Fatal(err) 41 | } 42 | // Do something with the connection to see whether it's working or not. 43 | tx, err := db.Begin() 44 | if err == nil { 45 | return db, tx.Rollback() 46 | } 47 | _ = db.Close() 48 | return nil, err 49 | } 50 | 51 | func checkSSLSetup(t *testing.T, conninfo string) { 52 | _, err := openSSLConn(t, conninfo) 53 | if pge, ok := err.(*Error); ok { 54 | if pge.Code.Name() != "invalid_authorization_specification" { 55 | t.Fatalf("unexpected error code '%s'", pge.Code.Name()) 56 | } 57 | } else { 58 | t.Fatalf("expected %T, got %v", (*Error)(nil), err) 59 | } 60 | } 61 | 62 | // Connect over SSL and run a simple query to test the basics 63 | func TestSSLConnection(t *testing.T) { 64 | maybeSkipSSLTests(t) 65 | // Environment sanity check: should fail without SSL 66 | checkSSLSetup(t, "sslmode=disable user=pqgossltest") 67 | 68 | db, err := openSSLConn(t, "sslmode=require user=pqgossltest") 69 | if err != nil { 70 | t.Fatal(err) 71 | } 72 | rows, err := db.Query("SELECT 1") 73 | if err != nil { 74 | t.Fatal(err) 75 | } 76 | rows.Close() 77 | } 78 | 79 | // Test sslmode=verify-full 80 | func TestSSLVerifyFull(t *testing.T) { 81 | maybeSkipSSLTests(t) 82 | // Environment sanity check: should fail without SSL 83 | 
checkSSLSetup(t, "sslmode=disable user=pqgossltest") 84 | 85 | // Not OK according to the system CA 86 | _, err := openSSLConn(t, "host=postgres sslmode=verify-full user=pqgossltest") 87 | if err == nil { 88 | t.Fatal("expected error") 89 | } 90 | { 91 | var x509err x509.UnknownAuthorityError 92 | if !errors.As(err, &x509err) { 93 | var x509err x509.HostnameError 94 | if !errors.As(err, &x509err) { 95 | t.Fatalf("expected x509.UnknownAuthorityError or x509.HostnameError, got %#+v", err) 96 | } 97 | } 98 | } 99 | 100 | rootCertPath := filepath.Join(os.Getenv("PQSSLCERTTEST_PATH"), "root.crt") 101 | rootCert := "sslrootcert=" + rootCertPath + " " 102 | // No match on Common Name 103 | _, err = openSSLConn(t, rootCert+"host=127.0.0.1 sslmode=verify-full user=pqgossltest") 104 | if err == nil { 105 | t.Fatal("expected error") 106 | } 107 | { 108 | var x509err x509.HostnameError 109 | if !errors.As(err, &x509err) { 110 | t.Fatalf("expected x509.HostnameError, got %#+v", err) 111 | } 112 | } 113 | // OK 114 | _, err = openSSLConn(t, rootCert+"host=postgres sslmode=verify-full user=pqgossltest") 115 | if err != nil { 116 | t.Fatal(err) 117 | } 118 | } 119 | 120 | // Test sslmode=require sslrootcert=rootCertPath 121 | func TestSSLRequireWithRootCert(t *testing.T) { 122 | maybeSkipSSLTests(t) 123 | // Environment sanity check: should fail without SSL 124 | checkSSLSetup(t, "sslmode=disable user=pqgossltest") 125 | 126 | bogusRootCertPath := filepath.Join(os.Getenv("PQSSLCERTTEST_PATH"), "bogus_root.crt") 127 | bogusRootCert := "sslrootcert=" + bogusRootCertPath + " " 128 | 129 | // Not OK according to the bogus CA 130 | _, err := openSSLConn(t, bogusRootCert+"host=postgres sslmode=require user=pqgossltest") 131 | if err == nil { 132 | t.Fatal("expected error") 133 | } 134 | { 135 | var x509err x509.UnknownAuthorityError 136 | if !errors.As(err, &x509err) { 137 | t.Fatalf("expected x509.UnknownAuthorityError, got %s, %#+v", err, err) 138 | } 139 | } 140 | 141 | 
nonExistentCertPath := filepath.Join(os.Getenv("PQSSLCERTTEST_PATH"), "non_existent.crt") 142 | nonExistentCert := "sslrootcert=" + nonExistentCertPath + " " 143 | 144 | // No match on Common Name, but that's OK because we're not validating anything. 145 | _, err = openSSLConn(t, nonExistentCert+"host=127.0.0.1 sslmode=require user=pqgossltest") 146 | if err != nil { 147 | t.Fatal(err) 148 | } 149 | 150 | rootCertPath := filepath.Join(os.Getenv("PQSSLCERTTEST_PATH"), "root.crt") 151 | rootCert := "sslrootcert=" + rootCertPath + " " 152 | 153 | // No match on Common Name, but that's OK because we're not validating the CN. 154 | _, err = openSSLConn(t, rootCert+"host=127.0.0.1 sslmode=require user=pqgossltest") 155 | if err != nil { 156 | t.Fatal(err) 157 | } 158 | // Everything OK 159 | _, err = openSSLConn(t, rootCert+"host=postgres sslmode=require user=pqgossltest") 160 | if err != nil { 161 | t.Fatal(err) 162 | } 163 | } 164 | 165 | // Test sslmode=verify-ca 166 | func TestSSLVerifyCA(t *testing.T) { 167 | maybeSkipSSLTests(t) 168 | // Environment sanity check: should fail without SSL 169 | checkSSLSetup(t, "sslmode=disable user=pqgossltest") 170 | 171 | // Not OK according to the system CA 172 | { 173 | _, err := openSSLConn(t, "host=postgres sslmode=verify-ca user=pqgossltest") 174 | var x509err x509.UnknownAuthorityError 175 | if !errors.As(err, &x509err) { 176 | t.Fatalf("expected %T, got %#+v", x509.UnknownAuthorityError{}, err) 177 | } 178 | } 179 | 180 | // Still not OK according to the system CA; empty sslrootcert is treated as unspecified. 
181 | { 182 | _, err := openSSLConn(t, "host=postgres sslmode=verify-ca user=pqgossltest sslrootcert=''") 183 | var x509err x509.UnknownAuthorityError 184 | if !errors.As(err, &x509err) { 185 | t.Fatalf("expected %T, got %#+v", x509.UnknownAuthorityError{}, err) 186 | } 187 | } 188 | 189 | rootCertPath := filepath.Join(os.Getenv("PQSSLCERTTEST_PATH"), "root.crt") 190 | rootCert := "sslrootcert=" + rootCertPath + " " 191 | // No match on Common Name, but that's OK 192 | if _, err := openSSLConn(t, rootCert+"host=127.0.0.1 sslmode=verify-ca user=pqgossltest"); err != nil { 193 | t.Fatal(err) 194 | } 195 | // Everything OK 196 | if _, err := openSSLConn(t, rootCert+"host=postgres sslmode=verify-ca user=pqgossltest"); err != nil { 197 | t.Fatal(err) 198 | } 199 | } 200 | 201 | // Authenticate over SSL using client certificates 202 | func TestSSLClientCertificates(t *testing.T) { 203 | maybeSkipSSLTests(t) 204 | // Environment sanity check: should fail without SSL 205 | checkSSLSetup(t, "sslmode=disable user=pqgossltest") 206 | 207 | const baseinfo = "sslmode=require user=pqgosslcert" 208 | 209 | // Certificate not specified, should fail 210 | { 211 | _, err := openSSLConn(t, baseinfo) 212 | if pge, ok := err.(*Error); ok { 213 | if pge.Code.Name() != "invalid_authorization_specification" { 214 | t.Fatalf("unexpected error code '%s'", pge.Code.Name()) 215 | } 216 | } else { 217 | t.Fatalf("expected %T, got %v", (*Error)(nil), err) 218 | } 219 | } 220 | 221 | // Empty certificate specified, should fail 222 | { 223 | _, err := openSSLConn(t, baseinfo+" sslcert=''") 224 | if pge, ok := err.(*Error); ok { 225 | if pge.Code.Name() != "invalid_authorization_specification" { 226 | t.Fatalf("unexpected error code '%s'", pge.Code.Name()) 227 | } 228 | } else { 229 | t.Fatalf("expected %T, got %v", (*Error)(nil), err) 230 | } 231 | } 232 | 233 | // Non-existent certificate specified, should fail 234 | { 235 | _, err := openSSLConn(t, baseinfo+" sslcert=/tmp/filedoesnotexist") 236 
| if pge, ok := err.(*Error); ok { 237 | if pge.Code.Name() != "invalid_authorization_specification" { 238 | t.Fatalf("unexpected error code '%s'", pge.Code.Name()) 239 | } 240 | } else { 241 | t.Fatalf("expected %T, got %v", (*Error)(nil), err) 242 | } 243 | } 244 | 245 | certpath, ok := os.LookupEnv("PQSSLCERTTEST_PATH") 246 | if !ok { 247 | t.Fatalf("PQSSLCERTTEST_PATH not present in environment") 248 | } 249 | 250 | sslcert := filepath.Join(certpath, "postgresql.crt") 251 | 252 | // Cert present, key not specified, should fail 253 | { 254 | _, err := openSSLConn(t, baseinfo+" sslcert="+sslcert) 255 | var pathErr *os.PathError 256 | if !errors.As(err, &pathErr) { 257 | t.Fatalf("expected %T, got %#+v", (*os.PathError)(nil), err) 258 | } 259 | } 260 | 261 | // Cert present, empty key specified, should fail 262 | { 263 | _, err := openSSLConn(t, baseinfo+" sslcert="+sslcert+" sslkey=''") 264 | var pathErr *os.PathError 265 | if !errors.As(err, &pathErr) { 266 | t.Fatalf("expected %T, got %#+v", (*os.PathError)(nil), err) 267 | } 268 | } 269 | 270 | // Cert present, non-existent key, should fail 271 | { 272 | _, err := openSSLConn(t, baseinfo+" sslcert="+sslcert+" sslkey=/tmp/filedoesnotexist") 273 | var pathErr *os.PathError 274 | if !errors.As(err, &pathErr) { 275 | t.Fatalf("expected %T, got %#+v", (*os.PathError)(nil), err) 276 | } 277 | } 278 | 279 | // Key has wrong permissions (passing the cert as the key), should fail 280 | if _, err := openSSLConn(t, baseinfo+" sslcert="+sslcert+" sslkey="+sslcert); err != ErrSSLKeyHasWorldPermissions { 281 | t.Fatalf("expected %s, got %#+v", ErrSSLKeyHasWorldPermissions, err) 282 | } 283 | 284 | sslkey := filepath.Join(certpath, "postgresql.key") 285 | 286 | // Should work 287 | if db, err := openSSLConn(t, baseinfo+" sslcert="+sslcert+" sslkey="+sslkey); err != nil { 288 | t.Fatal(err) 289 | } else { 290 | rows, err := db.Query("SELECT 1") 291 | if err != nil { 292 | t.Fatal(err) 293 | } 294 | if err := rows.Close(); err 
!= nil { 295 | t.Fatal(err) 296 | } 297 | if err := db.Close(); err != nil { 298 | t.Fatal(err) 299 | } 300 | } 301 | } 302 | 303 | // Check that clint sends SNI data when `sslsni` is not disabled 304 | func TestSNISupport(t *testing.T) { 305 | t.Parallel() 306 | tests := []struct { 307 | name string 308 | conn_param string 309 | hostname string 310 | expected_sni string 311 | }{ 312 | { 313 | name: "SNI is set by default", 314 | conn_param: "", 315 | hostname: "localhost", 316 | expected_sni: "localhost", 317 | }, 318 | { 319 | name: "SNI is passed when asked for", 320 | conn_param: "sslsni=1", 321 | hostname: "localhost", 322 | expected_sni: "localhost", 323 | }, 324 | { 325 | name: "SNI is not passed when disabled", 326 | conn_param: "sslsni=0", 327 | hostname: "localhost", 328 | expected_sni: "", 329 | }, 330 | { 331 | name: "SNI is not set for IPv4", 332 | conn_param: "", 333 | hostname: "127.0.0.1", 334 | expected_sni: "", 335 | }, 336 | } 337 | for _, tt := range tests { 338 | tt := tt 339 | t.Run(tt.name, func(t *testing.T) { 340 | t.Parallel() 341 | 342 | // Start mock postgres server on OS-provided port 343 | listener, err := net.Listen("tcp", "127.0.0.1:") 344 | if err != nil { 345 | t.Fatal(err) 346 | } 347 | serverErrChan := make(chan error, 1) 348 | serverSNINameChan := make(chan string, 1) 349 | go mockPostgresSSL(listener, serverErrChan, serverSNINameChan) 350 | 351 | defer listener.Close() 352 | defer close(serverErrChan) 353 | defer close(serverSNINameChan) 354 | 355 | // Try to establish a connection with the mock server. 
Connection will error out after TLS 356 | // clientHello, but it is enough to catch SNI data on the server side 357 | port := strings.Split(listener.Addr().String(), ":")[1] 358 | connStr := fmt.Sprintf("sslmode=require host=%s port=%s %s", tt.hostname, port, tt.conn_param) 359 | 360 | // We are okay to skip this error as we are polling serverErrChan and we'll get an error 361 | // or timeout from the server side in case of problems here. 362 | db, _ := sql.Open("postgres", connStr) 363 | _, _ = db.Exec("SELECT 1") 364 | 365 | // Check SNI data 366 | select { 367 | case sniHost := <-serverSNINameChan: 368 | if sniHost != tt.expected_sni { 369 | t.Fatalf("Expected SNI to be 'localhost', got '%+v' instead", sniHost) 370 | } 371 | case err = <-serverErrChan: 372 | t.Fatalf("mock server failed with error: %+v", err) 373 | case <-time.After(time.Second): 374 | t.Fatal("exceeded connection timeout without erroring out") 375 | } 376 | }) 377 | } 378 | } 379 | 380 | // Make a postgres mock server to test TLS SNI 381 | // 382 | // Accepts postgres StartupMessage and handles TLS clientHello, then closes a connection. 383 | // While reading clientHello catch passed SNI data and report it to nameChan. 
384 | func mockPostgresSSL(listener net.Listener, errChan chan error, nameChan chan string) { 385 | var sniHost string 386 | 387 | conn, err := listener.Accept() 388 | if err != nil { 389 | errChan <- err 390 | return 391 | } 392 | defer conn.Close() 393 | 394 | err = conn.SetDeadline(time.Now().Add(time.Second)) 395 | if err != nil { 396 | errChan <- err 397 | return 398 | } 399 | 400 | // Receive StartupMessage with SSL Request 401 | startupMessage := make([]byte, 8) 402 | if _, err := io.ReadFull(conn, startupMessage); err != nil { 403 | errChan <- err 404 | return 405 | } 406 | // StartupMessage: first four bytes -- total len = 8, last four bytes SslRequestNumber 407 | if !bytes.Equal(startupMessage, []byte{0, 0, 0, 0x8, 0x4, 0xd2, 0x16, 0x2f}) { 408 | errChan <- fmt.Errorf("unexpected startup message: %#v", startupMessage) 409 | return 410 | } 411 | 412 | // Respond with SSLOk 413 | _, err = conn.Write([]byte("S")) 414 | if err != nil { 415 | errChan <- err 416 | return 417 | } 418 | 419 | // Set up TLS context to catch clientHello. It will always error out during handshake 420 | // as no certificate is set. 421 | srv := tls.Server(conn, &tls.Config{ 422 | GetConfigForClient: func(argHello *tls.ClientHelloInfo) (*tls.Config, error) { 423 | sniHost = argHello.ServerName 424 | return nil, nil 425 | }, 426 | }) 427 | defer srv.Close() 428 | 429 | // Do the TLS handshake ignoring errors 430 | _ = srv.Handshake() 431 | 432 | nameChan <- sniHost 433 | } 434 | -------------------------------------------------------------------------------- /ssl_windows.go: -------------------------------------------------------------------------------- 1 | //go:build windows 2 | // +build windows 3 | 4 | package pq 5 | 6 | // sslKeyPermissions checks the permissions on user-supplied ssl key files. 7 | // The key file should have very little access. 8 | // 9 | // libpq does not check key file permissions on Windows. 
10 | func sslKeyPermissions(string) error { return nil }
11 |
--------------------------------------------------------------------------------
/url.go:
--------------------------------------------------------------------------------
1 | package pq
2 |
3 | import (
4 | 	"fmt"
5 | 	"net"
6 | 	nurl "net/url"
7 | 	"sort"
8 | 	"strings"
9 | )
10 |
11 | // ParseURL no longer needs to be used by clients of this library since supplying a URL as a
12 | // connection string to sql.Open() is now supported:
13 | //
14 | //	sql.Open("postgres", "postgres://bob:secret@1.2.3.4:5432/mydb?sslmode=verify-full")
15 | //
16 | // It remains exported here for backwards-compatibility.
17 | //
18 | // ParseURL converts a url to a connection string for driver.Open.
19 | // Example:
20 | //
21 | //	"postgres://bob:secret@1.2.3.4:5432/mydb?sslmode=verify-full"
22 | //
23 | // converts to:
24 | //
25 | //	"dbname='mydb' host='1.2.3.4' password='secret' port='5432' sslmode='verify-full' user='bob'"
26 | //
27 | // Values are single-quoted with ' and \ backslash-escaped, empty values are
28 | // omitted, and the pairs are sorted lexicographically. The brackets around an
29 | // IPv6 host literal are removed. A minimal example:
30 | //
31 | //	"postgres://"
32 | //
33 | // This will be blank, causing driver.Open to use all of the defaults
34 | func ParseURL(url string) (string, error) {
35 | 	u, err := nurl.Parse(url)
36 | 	if err != nil {
37 | 		return "", err
38 | 	}
39 |
40 | 	if u.Scheme != "postgres" && u.Scheme != "postgresql" {
41 | 		return "", fmt.Errorf("invalid connection protocol: %s", u.Scheme)
42 | 	}
43 |
44 | 	var kvs []string
45 | 	escaper := strings.NewReplacer(`'`, `\'`, `\`, `\\`)
46 | 	accrue := func(k, v string) {
47 | 		if v != "" {
48 | 			kvs = append(kvs, k+"='"+escaper.Replace(v)+"'")
49 | 		}
50 | 	}
51 |
52 | 	if u.User != nil {
53 | 		v := u.User.Username()
54 | 		accrue("user", v)
55 |
56 | 		v, _ = u.User.Password()
57 | 		accrue("password", v)
58 | 	}
59 |
60 | 	// u.Host is typically "host", "host:port", "[v6]" or "[v6]:port". SplitHostPort fails
61 | 	// when there is no port separator; in that case strip the brackets of an IPv6 literal so
62 | 	// that postgres://[::1]/db yields host='::1' rather than the undialable host='[::1]'.
63 | 	if host, port, err := net.SplitHostPort(u.Host); err != nil {
64 | 		host = u.Host
65 | 		if strings.HasPrefix(host, "[") && strings.HasSuffix(host, "]") {
66 | 			host = host[1 : len(host)-1]
67 | 		}
68 | 		accrue("host", host)
69 | 	} else {
70 | 		accrue("host", host)
71 | 		accrue("port", port)
72 | 	}
73 |
74 | 	if u.Path != "" {
75 | 		accrue("dbname", u.Path[1:])
76 | 	}
77 |
78 | 	q := u.Query()
79 | 	for k := range q {
80 | 		accrue(k, q.Get(k))
81 | 	}
82 |
83 | 	sort.Strings(kvs) // Makes testing easier (not a performance concern)
84 | 	return strings.Join(kvs, " "), nil
85 | }
86 |
--------------------------------------------------------------------------------
/url_test.go:
--------------------------------------------------------------------------------
1 | package pq
2 |
3 | import (
4 | 	"testing"
5 | )
6 |
7 | func TestSimpleParseURL(t *testing.T) {
8 | 	expected := "host='hostname.remote'"
9 | 	str, err := ParseURL("postgres://hostname.remote")
10 | 	if err != nil {
11 | 		t.Fatal(err)
12 | 	}
13 |
14 | 	if str != expected {
15 | 		t.Fatalf("unexpected result from ParseURL:\n+ %v\n- %v", str, expected)
16 | 	}
17 | }
18 |
19 | func TestIPv6LoopbackParseURL(t *testing.T) {
20 | 	expected := "host='::1' port='1234'"
21 | 	str, err := ParseURL("postgres://[::1]:1234")
22 | 	if err != nil {
23 | 		t.Fatal(err)
24 | 	}
25 |
26 | 	if str != expected {
27 | 		t.Fatalf("unexpected result from ParseURL:\n+ %v\n- %v", str, expected)
28 | 	}
29 | }
30 |
31 | // An IPv6 literal without a port must lose its brackets too.
32 | func TestIPv6LoopbackNoPortParseURL(t *testing.T) {
33 | 	expected := "dbname='mydb' host='::1'"
34 | 	str, err := ParseURL("postgres://[::1]/mydb")
35 | 	if err != nil {
36 | 		t.Fatal(err)
37 | 	}
38 |
39 | 	if str != expected {
40 | 		t.Fatalf("unexpected result from ParseURL:\n+ %v\n- %v", str, expected)
41 | 	}
42 | }
43 |
44 | func TestFullParseURL(t *testing.T) {
45 | 	expected := `dbname='database' host='hostname.remote' password='top secret' port='1234' user='username'`
46 | 	str, err := ParseURL("postgres://username:top%20secret@hostname.remote:1234/database")
47 | 	if err != nil {
48 | 		t.Fatal(err)
49 | 	}
50 |
51 | 	if str != expected {
52 | 		t.Fatalf("unexpected result from ParseURL:\n+ %s\n- %s", str, expected)
53 | 	}
54 | }
55 |
56 | func TestInvalidProtocolParseURL(t *testing.T) {
57 | 	_, err := ParseURL("http://hostname.remote")
58 | 	switch err {
59 | 	case nil:
60 | 		t.Fatal("Expected an error from parsing invalid protocol")
61 | 	default:
62 | 		msg := "invalid connection protocol: http"
63 | 		if err.Error() != msg {
64 | 			t.Fatalf("Unexpected error message:\n+ %s\n- %s",
65 | 				err.Error(), msg)
66 | 		}
67 | 	}
68 | }
69 |
70 | func TestMinimalURL(t *testing.T) {
71 | 	cs, err := ParseURL("postgres://")
72 | 	if err != nil {
73 | 		t.Fatal(err)
74 | 	}
75 |
76 | 	if cs != "" {
77 | 		t.Fatalf("expected blank connection string, got: %q", cs)
78 | 	}
79 | }
80 |
--------------------------------------------------------------------------------
/user_other.go:
--------------------------------------------------------------------------------
1 | // Package pq is a pure Go Postgres driver for the database/sql package.
2 |
3 | //go:build js || android || hurd || zos
4 | // +build js android hurd zos
5 |
6 | package pq
7 |
8 | // userCurrent always returns ErrCouldNotDetectUsername on these platforms.
9 | func userCurrent() (string, error) {
10 | 	return "", ErrCouldNotDetectUsername
11 | }
12 |
--------------------------------------------------------------------------------
/user_posix.go:
--------------------------------------------------------------------------------
1 | // Package pq is a pure Go Postgres driver for the database/sql package.
2 |
3 | //go:build aix || darwin || dragonfly || freebsd || (linux && !android) || nacl || netbsd || openbsd || plan9 || solaris || rumprun || illumos
4 | // +build aix darwin dragonfly freebsd linux,!android nacl netbsd openbsd plan9 solaris rumprun illumos
5 |
6 | package pq
7 |
8 | import (
9 | 	"os"
10 | 	"os/user"
11 | )
12 |
13 | // userCurrent returns the name reported by os/user, falling back to the USER
14 | // environment variable; ErrCouldNotDetectUsername is returned if both are unavailable.
15 | func userCurrent() (string, error) {
16 | 	u, err := user.Current()
17 | 	if err == nil {
18 | 		return u.Username, nil
19 | 	}
20 |
21 | 	name := os.Getenv("USER")
22 | 	if name != "" {
23 | 		return name, nil
24 | 	}
25 |
26 | 	return "", ErrCouldNotDetectUsername
27 | }
28 |
--------------------------------------------------------------------------------
/user_windows.go:
--------------------------------------------------------------------------------
1 | // Package pq is a pure Go Postgres driver for the database/sql package.
2 | package pq
3 |
4 | import (
5 | 	"path/filepath"
6 | 	"syscall"
7 | )
8 |
9 | // Perform Windows user name lookup identically to libpq.
10 | //
11 | // The PostgreSQL code makes use of the legacy Win32 function
12 | // GetUserName, and that function has not been imported into stock Go.
13 | // GetUserNameEx is available though, the difference being that a
14 | // wider range of names is available. To get the output to be the
15 | // same as GetUserName, only the base (or last) component of the
16 | // result is returned.
17 | func userCurrent() (string, error) {
18 | 	buf := make([]uint16, 128)
19 | 	size := uint32(len(buf) - 1)
20 | 	if err := syscall.GetUserNameEx(syscall.NameSamCompatible, &buf[0], &size); err != nil {
21 | 		return "", ErrCouldNotDetectUsername
22 | 	}
23 |
24 | 	// Keep only the last path component of the SAM-compatible name (presumably
25 | 	// "DOMAIN\user"), mirroring what GetUserName would have returned.
26 | 	return filepath.Base(syscall.UTF16ToString(buf)), nil
27 | }
28 |
--------------------------------------------------------------------------------
/uuid.go:
--------------------------------------------------------------------------------
1 | package pq
2 |
3 | import (
4 | 	"encoding/hex"
5 | 	"fmt"
6 | )
7 |
8 | // decodeUUIDBinary converts the 16-byte binary form of a uuid into its 36-byte
9 | // text form xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx using lowercase hex digits.
10 | // Any other input length is rejected with an error.
11 | func decodeUUIDBinary(src []byte) ([]byte, error) {
12 | 	if n := len(src); n != 16 {
13 | 		return nil, fmt.Errorf("pq: unable to decode uuid; bad length: %d", n)
14 | 	}
15 |
16 | 	text := make([]byte, 36)
17 | 	pos := 0
18 | 	for i := range src {
19 | 		// A hyphen precedes bytes 4, 6, 8 and 10 (the 8-4-4-4-12 grouping).
20 | 		switch i {
21 | 		case 4, 6, 8, 10:
22 | 			text[pos] = '-'
23 | 			pos++
24 | 		}
25 | 		hex.Encode(text[pos:pos+2], src[i:i+1])
26 | 		pos += 2
27 | 	}
28 |
29 | 	return text, nil
30 | }
31 |
--------------------------------------------------------------------------------
/uuid_test.go:
--------------------------------------------------------------------------------
1 | package pq
2 |
3 | import (
4 | 	"reflect"
5 | 	"strings"
6 | 	"testing"
7 | )
8 |
9 | // TestDecodeUUIDBinaryError checks that a wrongly sized input is rejected with a pq-prefixed error naming the length.
10 | func TestDecodeUUIDBinaryError(t *testing.T) {
11 | 	t.Parallel()
12 |
13 | 	_, err := decodeUUIDBinary([]byte{0x12, 0x34})
14 | 	if err == nil {
15 | 		t.Fatal("Expected error, got none")
16 | 	}
17 |
18 | 	msg := err.Error()
19 | 	if !strings.HasPrefix(msg, "pq:") {
20 | 		t.Errorf("Expected error to start with %q, got %q", "pq:", msg)
21 | 	}
22 | 	if !strings.Contains(msg, "bad length: 2") {
23 | 		t.Errorf("Expected error to contain length, got %q", msg)
24 | 	}
25 | }
26 |
27 | // BenchmarkDecodeUUIDBinary measures decoding of a single well-formed 16-byte uuid.
28 | func BenchmarkDecodeUUIDBinary(b *testing.B) {
29 | 	input := []byte{0x03, 0xa3, 0x52, 0x2f, 0x89, 0x28, 0x49, 0x87, 0x84, 0xd6, 0x93, 0x7b, 0x36, 0xec, 0x27, 0x6f}
30 |
31 | 	for n := 0; n < b.N; n++ {
32 | 		_, _ = decodeUUIDBinary(input)
33 | 	}
34 | }
35 |
36 | // TestDecodeUUIDBackend checks that a uuid scanned into interface{} arrives as its text form in a []byte.
37 | func TestDecodeUUIDBackend(t *testing.T) {
38 | 	db := openTestConn(t)
39 | 	defer db.Close()
40 |
41 | 	const s = "a0ecc91d-a13f-4fe4-9fce-7e09777cc70a"
42 |
43 | 	var scanned interface{}
44 | 	if err := db.QueryRow(`SELECT $1::uuid`, s).Scan(&scanned); err != nil {
45 | 		t.Fatalf("Expected no error, got %v", err)
46 | 	}
47 | 	if !reflect.DeepEqual(scanned, []byte(s)) {
48 | 		t.Errorf("Expected []byte(%q), got %T(%q)", s, scanned, scanned)
49 | 	}
50 | }
51 |
--------------------------------------------------------------------------------