├── .gitignore ├── oid ├── doc.go ├── gen.go └── types.go ├── certs ├── README ├── postgresql.key ├── bogus_root.crt ├── root.crt ├── server.key ├── postgresql.crt └── server.crt ├── user_posix.go ├── issues_test.go ├── user_windows.go ├── .travis.yml ├── LICENSE.md ├── CONTRIBUTING.md ├── go18_test.go ├── url_test.go ├── url.go ├── buf.go ├── .travis.sh ├── hstore ├── hstore.go └── hstore_test.go ├── listen_example └── doc.go ├── README.md ├── copy.go ├── ssl_test.go ├── doc.go ├── copy_test.go ├── bench_test.go ├── notify_test.go ├── error.go ├── encode.go ├── array.go ├── encode_test.go └── notify.go /.gitignore: -------------------------------------------------------------------------------- 1 | .db 2 | *.test 3 | *~ 4 | *.swp 5 | -------------------------------------------------------------------------------- /oid/doc.go: -------------------------------------------------------------------------------- 1 | // Package oid contains OID constants 2 | // as defined by the Postgres server. 3 | package oid 4 | 5 | // Oid is a Postgres Object ID. 6 | type Oid uint32 7 | -------------------------------------------------------------------------------- /certs/README: -------------------------------------------------------------------------------- 1 | This directory contains certificates and private keys for testing some 2 | SSL-related functionality in Travis. Do NOT use these certificates for 3 | anything other than testing. 4 | -------------------------------------------------------------------------------- /user_posix.go: -------------------------------------------------------------------------------- 1 | // Package pq is a pure Go Postgres driver for the database/sql package. 
2 | 3 | // +build darwin dragonfly freebsd linux nacl netbsd openbsd solaris rumprun 4 | 5 | package pq 6 | 7 | import ( 8 | "os" 9 | "os/user" 10 | ) 11 | 12 | func userCurrent() (string, error) { 13 | u, err := user.Current() 14 | if err == nil { 15 | return u.Username, nil 16 | } 17 | 18 | name := os.Getenv("USER") 19 | if name != "" { 20 | return name, nil 21 | } 22 | 23 | return "", ErrCouldNotDetectUsername 24 | } 25 | -------------------------------------------------------------------------------- /issues_test.go: -------------------------------------------------------------------------------- 1 | package pq 2 | 3 | import "testing" 4 | 5 | func TestIssue494(t *testing.T) { 6 | db := openTestConn(t) 7 | defer db.Close() 8 | 9 | query := `CREATE TEMP TABLE t (i INT PRIMARY KEY)` 10 | if _, err := db.Exec(query); err != nil { 11 | t.Fatal(err) 12 | } 13 | 14 | txn, err := db.Begin() 15 | if err != nil { 16 | t.Fatal(err) 17 | } 18 | 19 | if _, err := txn.Prepare(CopyIn("t", "i")); err != nil { 20 | t.Fatal(err) 21 | } 22 | 23 | if _, err := txn.Query("SELECT 1"); err == nil { 24 | t.Fatal("expected error") 25 | } 26 | } 27 | -------------------------------------------------------------------------------- /certs/postgresql.key: -------------------------------------------------------------------------------- 1 | -----BEGIN RSA PRIVATE KEY----- 2 | MIICWwIBAAKBgQDjjAaacFRR0TQ0gznNolkPBe2N2A400JL0CU3ujHhVSST4POA0 3 | WAKy55RYwejlu9Gv9lTBQLGQcHkNNVScjxbpwvCS5mRJOMF2+EdmxFtKtqlDzsi+ 4 | bE0rlJc8VbzR0G63U66JXEtrhkC+wa4eZM6crocKaeXIIRK+rh32Rd8WpwIDAQAB 5 | AoGAM5dM6/kp9P700i8qjOgRPym96Zoh5nGfz/rIE5z/r36NBkdvIg8OVZfR96nH 6 | b0b9TOMR5lsPp0sI9yivTWvX6qyvLJRWy2vvx17hXK9NxXUNTAm0PYZUTvCtcPeX 7 | RnJpzQKNZQPkFzF0uXBc4CtPK2Vz0+FGvAelrhYAxnw1dIkCQQD+9qaW5QhXjsjb 8 | Nl85CmXgxPmGROcgLQCO+omfrjf9UXrituU9Dz6auym5lDGEdMFnkzfr+wpasEy9 9 | mf5ZZOhDAkEA5HjXfVGaCtpydOt6hDon/uZsyssCK2lQ7NSuE3vP+sUsYMzIpEoy 10 | t3VWXqKbo+g9KNDTP4WEliqp1aiSIylzzQJANPeqzihQnlgEdD4MdD4rwhFJwVIp 11 | 
Le8Lcais1KaN7StzOwxB/XhgSibd2TbnPpw+3bSg5n5lvUdo+e62/31OHwJAU1jS 12 | I+F09KikQIr28u3UUWT2IzTT4cpVv1AHAQyV3sG3YsjSGT0IK20eyP9BEBZU2WL0 13 | 7aNjrvR5aHxKc5FXsQJABsFtyGpgI5X4xufkJZVZ+Mklz2n7iXa+XPatMAHFxAtb 14 | EEMt60rngwMjXAzBSC6OYuYogRRAY3UCacNC5VhLYQ== 15 | -----END RSA PRIVATE KEY----- 16 | -------------------------------------------------------------------------------- /user_windows.go: -------------------------------------------------------------------------------- 1 | // Package pq is a pure Go Postgres driver for the database/sql package. 2 | package pq 3 | 4 | import ( 5 | "path/filepath" 6 | "syscall" 7 | ) 8 | 9 | // Perform Windows user name lookup identically to libpq. 10 | // 11 | // The PostgreSQL code makes use of the legacy Win32 function 12 | // GetUserName, and that function has not been imported into stock Go. 13 | // GetUserNameEx is available though, the difference being that a 14 | // wider range of names are available. To get the output to be the 15 | // same as GetUserName, only the base (or last) component of the 16 | // result is returned. 
17 | func userCurrent() (string, error) { 18 | pw_name := make([]uint16, 128) 19 | pwname_size := uint32(len(pw_name)) - 1 20 | err := syscall.GetUserNameEx(syscall.NameSamCompatible, &pw_name[0], &pwname_size) 21 | if err != nil { 22 | return "", ErrCouldNotDetectUsername 23 | } 24 | s := syscall.UTF16ToString(pw_name) 25 | u := filepath.Base(s) 26 | return u, nil 27 | } 28 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | language: go 2 | 3 | go: 4 | - 1.5 5 | - 1.6 6 | - 1.7 7 | - tip 8 | 9 | sudo: true 10 | 11 | env: 12 | global: 13 | - PGUSER=postgres 14 | - PQGOSSLTESTS=1 15 | - PQSSLCERTTEST_PATH=$PWD/certs 16 | - PGHOST=127.0.0.1 17 | matrix: 18 | - PGVERSION=9.6 19 | - PGVERSION=9.5 20 | - PGVERSION=9.4 21 | - PGVERSION=9.3 22 | - PGVERSION=9.2 23 | - PGVERSION=9.1 24 | - PGVERSION=9.0 25 | 26 | before_install: 27 | - ./.travis.sh postgresql_uninstall 28 | - ./.travis.sh pgdg_repository 29 | - ./.travis.sh postgresql_install 30 | - ./.travis.sh postgresql_configure 31 | - ./.travis.sh client_configure 32 | - go get golang.org/x/tools/cmd/goimports 33 | 34 | before_script: 35 | - createdb pqgotest 36 | - createuser -DRS pqgossltest 37 | - createuser -DRS pqgosslcert 38 | 39 | script: 40 | - > 41 | goimports -d -e $(find -name '*.go') | awk '{ print } END { exit NR == 0 ? 0 : 1 }' 42 | - go vet ./... 43 | - PQTEST_BINARY_PARAMETERS=no go test -v ./... 44 | - PQTEST_BINARY_PARAMETERS=yes go test -v ./... 
45 | -------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | Copyright (c) 2011-2013, 'pq' Contributors 2 | Portions Copyright (C) 2011 Blake Mizerany 3 | 4 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 5 | 6 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 7 | 8 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
9 | -------------------------------------------------------------------------------- /certs/bogus_root.crt: -------------------------------------------------------------------------------- 1 | -----BEGIN CERTIFICATE----- 2 | MIIDBjCCAe6gAwIBAgIQSnDYp/Naet9HOZljF5PuwDANBgkqhkiG9w0BAQsFADAr 3 | MRIwEAYDVQQKEwlDb2Nrcm9hY2gxFTATBgNVBAMTDENvY2tyb2FjaCBDQTAeFw0x 4 | NjAyMDcxNjQ0MzdaFw0xNzAyMDYxNjQ0MzdaMCsxEjAQBgNVBAoTCUNvY2tyb2Fj 5 | aDEVMBMGA1UEAxMMQ29ja3JvYWNoIENBMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8A 6 | MIIBCgKCAQEAxdln3/UdgP7ayA/G1kT7upjLe4ERwQjYQ25q0e1+vgsB5jhiirxJ 7 | e0+WkhhYu/mwoSAXzvlsbZ2PWFyfdanZeD/Lh6SvIeWXVVaPcWVWL1TEcoN2jr5+ 8 | E85MMHmbbmaT2he8s6br2tM/UZxyTQ2XRprIzApbDssyw1c0Yufcpu3C6267FLEl 9 | IfcWrzDhnluFhthhtGXv3ToD8IuMScMC5qlKBXtKmD1B5x14ngO/ecNJ+OlEi0HU 10 | mavK4KWgI2rDXRZ2EnCpyTZdkc3kkRnzKcg653oOjMDRZdrhfIrha+Jq38ACsUmZ 11 | Su7Sp5jkIHOCO8Zg+l6GKVSq37dKMapD8wIDAQABoyYwJDAOBgNVHQ8BAf8EBAMC 12 | AuQwEgYDVR0TAQH/BAgwBgEB/wIBATANBgkqhkiG9w0BAQsFAAOCAQEAwZ2Tu0Yu 13 | rrSVdMdoPEjT1IZd+5OhM/SLzL0ddtvTithRweLHsw2lDQYlXFqr24i3UGZJQ1sp 14 | cqSrNwswgLUQT3vWyTjmM51HEb2vMYWKmjZ+sBQYAUP1CadrN/+OTfNGnlF1+B4w 15 | IXOzh7EvQmJJnNybLe4a/aRvj1NE2n8Z898B76SVU9WbfKKz8VwLzuIPDqkKcZda 16 | lMy5yzthyztV9YjcWs2zVOUGZvGdAhDrvZuUq6mSmxrBEvR2LBOggmVf3tGRT+Ls 17 | lW7c9Lrva5zLHuqmoPP07A+vuI9a0D1X44jwGDuPWJ5RnTOQ63Uez12mKNjqleHw 18 | DnkwNanuO8dhAA== 19 | -----END CERTIFICATE----- 20 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | ## Contributing to pq 2 | 3 | `pq` has a backlog of pull requests, but contributions are still very 4 | much welcome. You can help with patch review, submitting bug reports, 5 | or adding new functionality. There is no formal style guide, but 6 | please conform to the style of existing code and general Go formatting 7 | conventions when submitting patches. 
8 | 9 | ### Patch review 10 | 11 | Help review existing open pull requests by commenting on the code or 12 | proposed functionality. 13 | 14 | ### Bug reports 15 | 16 | We appreciate any bug reports, but especially ones with self-contained 17 | (doesn't depend on code outside of pq), minimal (can't be simplified 18 | further) test cases. It's especially helpful if you can submit a pull 19 | request with just the failing test case (you'll probably want to 20 | pattern it after the tests in 21 | [conn_test.go](https://github.com/lib/pq/blob/master/conn_test.go)). 22 | 23 | ### New functionality 24 | 25 | There are a number of pending patches for new functionality, so 26 | additional feature patches will take a while to merge. Still, patches 27 | are generally reviewed based on usefulness and complexity in addition 28 | to time-in-queue, so if you have a knockout idea, take a shot. Feel 29 | free to open an issue discussing your proposed patch beforehand. 30 | -------------------------------------------------------------------------------- /go18_test.go: -------------------------------------------------------------------------------- 1 | // +build go1.8 2 | 3 | package pq 4 | 5 | import "testing" 6 | 7 | func TestMultipleSimpleQuery(t *testing.T) { 8 | db := openTestConn(t) 9 | defer db.Close() 10 | 11 | rows, err := db.Query("select 1; set time zone default; select 2; select 3") 12 | if err != nil { 13 | t.Fatal(err) 14 | } 15 | defer rows.Close() 16 | 17 | var i int 18 | for rows.Next() { 19 | if err := rows.Scan(&i); err != nil { 20 | t.Fatal(err) 21 | } 22 | if i != 1 { 23 | t.Fatalf("expected 1, got %d", i) 24 | } 25 | } 26 | if !rows.NextResultSet() { 27 | t.Fatal("expected more result sets", rows.Err()) 28 | } 29 | for rows.Next() { 30 | if err := rows.Scan(&i); err != nil { 31 | t.Fatal(err) 32 | } 33 | if i != 2 { 34 | t.Fatalf("expected 2, got %d", i) 35 | } 36 | } 37 | 38 | // Make sure that if we ignore a result we can still query.
39 | 40 | rows, err = db.Query("select 4; select 5") 41 | if err != nil { 42 | t.Fatal(err) 43 | } 44 | defer rows.Close() 45 | 46 | for rows.Next() { 47 | if err := rows.Scan(&i); err != nil { 48 | t.Fatal(err) 49 | } 50 | if i != 4 { 51 | t.Fatalf("expected 4, got %d", i) 52 | } 53 | } 54 | if !rows.NextResultSet() { 55 | t.Fatal("expected more result sets", rows.Err()) 56 | } 57 | for rows.Next() { 58 | if err := rows.Scan(&i); err != nil { 59 | t.Fatal(err) 60 | } 61 | if i != 5 { 62 | t.Fatalf("expected 5, got %d", i) 63 | } 64 | } 65 | if rows.NextResultSet() { 66 | t.Fatal("unexpected result set") 67 | } 68 | } 69 | -------------------------------------------------------------------------------- /certs/root.crt: -------------------------------------------------------------------------------- 1 | -----BEGIN CERTIFICATE----- 2 | MIIEAzCCAuugAwIBAgIJANmheROCdW1NMA0GCSqGSIb3DQEBBQUAMF4xCzAJBgNV 3 | BAYTAlVTMQ8wDQYDVQQIEwZOZXZhZGExEjAQBgNVBAcTCUxhcyBWZWdhczEaMBgG 4 | A1UEChMRZ2l0aHViLmNvbS9saWIvcHExDjAMBgNVBAMTBXBxIENBMB4XDTE0MTAx 5 | MTE1MDQyOVoXDTI0MTAwODE1MDQyOVowXjELMAkGA1UEBhMCVVMxDzANBgNVBAgT 6 | Bk5ldmFkYTESMBAGA1UEBxMJTGFzIFZlZ2FzMRowGAYDVQQKExFnaXRodWIuY29t 7 | L2xpYi9wcTEOMAwGA1UEAxMFcHEgQ0EwggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAw 8 | ggEKAoIBAQCV4PxP7ShzWBzUCThcKk3qZtOLtHmszQVtbqhvgTpm1kTRtKBdVMu0 9 | pLAHQ3JgJCnAYgH0iZxVGoMP16T3irdgsdC48+nNTFM2T0cCdkfDURGIhSFN47cb 10 | Pgy306BcDUD2q7ucW33+dlFSRuGVewocoh4BWM/vMtMvvWzdi4Ag/L/jhb+5wZxZ 11 | sWymsadOVSDePEMKOvlCa3EdVwVFV40TVyDb+iWBUivDAYsS2a3KajuJrO6MbZiE 12 | Sp2RCIkZS2zFmzWxVRi9ZhzIZhh7EVF9JAaNC3T52jhGUdlRq3YpBTMnd89iOh74 13 | 6jWXG7wSuPj3haFzyNhmJ0ZUh+2Ynoh1AgMBAAGjgcMwgcAwHQYDVR0OBBYEFFKT 14 | 7R52Cp9lT94ZZsHVIkA1y6ByMIGQBgNVHSMEgYgwgYWAFFKT7R52Cp9lT94ZZsHV 15 | IkA1y6ByoWKkYDBeMQswCQYDVQQGEwJVUzEPMA0GA1UECBMGTmV2YWRhMRIwEAYD 16 | VQQHEwlMYXMgVmVnYXMxGjAYBgNVBAoTEWdpdGh1Yi5jb20vbGliL3BxMQ4wDAYD 17 | VQQDEwVwcSBDQYIJANmheROCdW1NMAwGA1UdEwQFMAMBAf8wDQYJKoZIhvcNAQEF 18 | 
BQADggEBAAEhCLWkqJNMI8b4gkbmj5fqQ/4+oO83bZ3w2Oqf6eZ8I8BC4f2NOyE6 19 | tRUlq5+aU7eqC1cOAvGjO+YHN/bF/DFpwLlzvUSXt+JP/pYcUjL7v+pIvwqec9hD 20 | ndvM4iIbkD/H/OYQ3L+N3W+G1x7AcFIX+bGCb3PzYVQAjxreV6//wgKBosMGFbZo 21 | HPxT9RPMun61SViF04H5TNs0derVn1+5eiiYENeAhJzQNyZoOOUuX1X/Inx9bEPh 22 | C5vFBtSMgIytPgieRJVWAiMLYsfpIAStrHztRAbBs2DU01LmMgRvHdxgFEKinC/d 23 | UHZZQDP+6pT+zADrGhQGXe4eThaO6f0= 24 | -----END CERTIFICATE----- 25 | -------------------------------------------------------------------------------- /oid/gen.go: -------------------------------------------------------------------------------- 1 | // +build ignore 2 | 3 | // Generate the table of OID values 4 | // Run with 'go run gen.go'. 5 | package main 6 | 7 | import ( 8 | "database/sql" 9 | "fmt" 10 | "log" 11 | "os" 12 | "os/exec" 13 | 14 | _ "github.com/lib/pq" 15 | ) 16 | 17 | func main() { 18 | datname := os.Getenv("PGDATABASE") 19 | sslmode := os.Getenv("PGSSLMODE") 20 | 21 | if datname == "" { 22 | os.Setenv("PGDATABASE", "pqgotest") 23 | } 24 | 25 | if sslmode == "" { 26 | os.Setenv("PGSSLMODE", "disable") 27 | } 28 | 29 | db, err := sql.Open("postgres", "") 30 | if err != nil { 31 | log.Fatal(err) 32 | } 33 | cmd := exec.Command("gofmt") 34 | cmd.Stderr = os.Stderr 35 | w, err := cmd.StdinPipe() 36 | if err != nil { 37 | log.Fatal(err) 38 | } 39 | f, err := os.Create("types.go") 40 | if err != nil { 41 | log.Fatal(err) 42 | } 43 | cmd.Stdout = f 44 | err = cmd.Start() 45 | if err != nil { 46 | log.Fatal(err) 47 | } 48 | fmt.Fprintln(w, "// generated by 'go run gen.go'; do not edit") 49 | fmt.Fprintln(w, "\npackage oid") 50 | fmt.Fprintln(w, "const (") 51 | rows, err := db.Query(` 52 | SELECT typname, oid 53 | FROM pg_type WHERE oid < 10000 54 | ORDER BY oid; 55 | `) 56 | if err != nil { 57 | log.Fatal(err) 58 | } 59 | var name string 60 | var oid int 61 | for rows.Next() { 62 | err = rows.Scan(&name, &oid) 63 | if err != nil { 64 | log.Fatal(err) 65 | } 66 | fmt.Fprintf(w, "T_%s Oid = %d\n", name, oid) 67 | } 
68 | if err = rows.Err(); err != nil { 69 | log.Fatal(err) 70 | } 71 | fmt.Fprintln(w, ")") 72 | w.Close() 73 | cmd.Wait() 74 | } 75 | -------------------------------------------------------------------------------- /url_test.go: -------------------------------------------------------------------------------- 1 | package pq 2 | 3 | import ( 4 | "testing" 5 | ) 6 | 7 | func TestSimpleParseURL(t *testing.T) { 8 | expected := "host=hostname.remote" 9 | str, err := ParseURL("postgres://hostname.remote") 10 | if err != nil { 11 | t.Fatal(err) 12 | } 13 | 14 | if str != expected { 15 | t.Fatalf("unexpected result from ParseURL:\n+ %v\n- %v", str, expected) 16 | } 17 | } 18 | 19 | func TestIPv6LoopbackParseURL(t *testing.T) { 20 | expected := "host=::1 port=1234" 21 | str, err := ParseURL("postgres://[::1]:1234") 22 | if err != nil { 23 | t.Fatal(err) 24 | } 25 | 26 | if str != expected { 27 | t.Fatalf("unexpected result from ParseURL:\n+ %v\n- %v", str, expected) 28 | } 29 | } 30 | 31 | func TestFullParseURL(t *testing.T) { 32 | expected := `dbname=database host=hostname.remote password=top\ secret port=1234 user=username` 33 | str, err := ParseURL("postgres://username:top%20secret@hostname.remote:1234/database") 34 | if err != nil { 35 | t.Fatal(err) 36 | } 37 | 38 | if str != expected { 39 | t.Fatalf("unexpected result from ParseURL:\n+ %s\n- %s", str, expected) 40 | } 41 | } 42 | 43 | func TestInvalidProtocolParseURL(t *testing.T) { 44 | _, err := ParseURL("http://hostname.remote") 45 | switch err { 46 | case nil: 47 | t.Fatal("Expected an error from parsing invalid protocol") 48 | default: 49 | msg := "invalid connection protocol: http" 50 | if err.Error() != msg { 51 | t.Fatalf("Unexpected error message:\n+ %s\n- %s", 52 | err.Error(), msg) 53 | } 54 | } 55 | } 56 | 57 | func TestMinimalURL(t *testing.T) { 58 | cs, err := ParseURL("postgres://") 59 | if err != nil { 60 | t.Fatal(err) 61 | } 62 | 63 | if cs != "" { 64 | t.Fatalf("expected blank connection string, 
got: %q", cs) 65 | } 66 | } 67 | -------------------------------------------------------------------------------- /certs/server.key: -------------------------------------------------------------------------------- 1 | -----BEGIN RSA PRIVATE KEY----- 2 | MIIEogIBAAKCAQEA14pMhfsXpTyP4HIRKc4/sB8/fcbuf6f8Ais1RwimPZDfXFYU 3 | lADHbdHS4mGVd7jjpmYx+R8hfWLhJ9qUN2FK6mNToGG4nLul4ue3ptgPBQTHKeLq 4 | SSt/3hUAphhwUMcM3pr5Wpaw4ZQGxm1KITu0D6VtkoY0sk7XDqcZwHcLe4fIkt5C 5 | /4bSt5qk1BUjyq2laSG4zn5my4Vdue2LLQmNlOQEHnLs79B2kBVapPeRS+nOTp1d 6 | mnAXnNjpc4PqPWGZps2skUBaiHflTiqOPRPz+ThvgWuKlcoOB6tv2rSM2f+qeAOq 7 | x8LPb2SS09iD1a/xIxinLnsXC+d98fqoQaMEVwIDAQABAoIBAF3ZoihUhJ82F4+r 8 | Gz4QyDpv4L1reT2sb1aiabhcU8ZK5nbWJG+tRyjSS/i2dNaEcttpdCj9HR/zhgZM 9 | bm0OuAgG58rVwgS80CZUruq++Qs+YVojq8/gWPTiQD4SNhV2Fmx3HkwLgUk3oxuT 10 | SsvdqzGE3okGVrutCIcgy126eA147VPMoej1Bb3fO6npqK0pFPhZfAc0YoqJuM+k 11 | obRm5pAnGUipyLCFXjA9HYPKwYZw2RtfdA3CiImHeanSdqS+ctrC9y8BV40Th7gZ 12 | haXdKUNdjmIxV695QQ1mkGqpKLZFqhzKioGQ2/Ly2d1iaKN9fZltTusu8unepWJ2 13 | tlT9qMECgYEA9uHaF1t2CqE+AJvWTihHhPIIuLxoOQXYea1qvxfcH/UMtaLKzCNm 14 | lQ5pqCGsPvp+10f36yttO1ZehIvlVNXuJsjt0zJmPtIolNuJY76yeussfQ9jHheB 15 | 5uPEzCFlHzxYbBUyqgWaF6W74okRGzEGJXjYSP0yHPPdU4ep2q3bGiUCgYEA34Af 16 | wBSuQSK7uLxArWHvQhyuvi43ZGXls6oRGl+Ysj54s8BP6XGkq9hEJ6G4yxgyV+BR 17 | DUOs5X8/TLT8POuIMYvKTQthQyCk0eLv2FLdESDuuKx0kBVY3s8lK3/z5HhrdOiN 18 | VMNZU+xDKgKc3hN9ypkk8vcZe6EtH7Y14e0rVcsCgYBTgxi8F/M5K0wG9rAqphNz 19 | VFBA9XKn/2M33cKjO5X5tXIEKzpAjaUQvNxexG04rJGljzG8+mar0M6ONahw5yD1 20 | O7i/XWgazgpuOEkkVYiYbd8RutfDgR4vFVMn3hAP3eDnRtBplRWH9Ec3HTiNIys6 21 | F8PKBOQjyRZQQC7jyzW3hQKBgACe5HeuFwXLSOYsb6mLmhR+6+VPT4wR1F95W27N 22 | USk9jyxAnngxfpmTkiziABdgS9N+pfr5cyN4BP77ia/Jn6kzkC5Cl9SN5KdIkA3z 23 | vPVtN/x/ThuQU5zaymmig1ThGLtMYggYOslG4LDfLPxY5YKIhle+Y+259twdr2yf 24 | Mf2dAoGAaGv3tWMgnIdGRk6EQL/yb9PKHo7ShN+tKNlGaK7WwzBdKs+Fe8jkgcr7 25 | pz4Ne887CmxejdISzOCcdT+Zm9Bx6I/uZwWOtDvWpIgIxVX9a9URj/+D1MxTE/y4 26 | d6H+c89yDY62I2+drMpdjCd3EtCaTlxpTbRS+s1eAHMH7aEkcCE= 27 | -----END RSA 
PRIVATE KEY----- 28 | -------------------------------------------------------------------------------- /url.go: -------------------------------------------------------------------------------- 1 | package pq 2 | 3 | import ( 4 | "fmt" 5 | "net" 6 | nurl "net/url" 7 | "sort" 8 | "strings" 9 | ) 10 | 11 | // ParseURL no longer needs to be used by clients of this library since supplying a URL as a 12 | // connection string to sql.Open() is now supported: 13 | // 14 | // sql.Open("postgres", "postgres://bob:secret@1.2.3.4:5432/mydb?sslmode=verify-full") 15 | // 16 | // It remains exported here for backwards-compatibility. 17 | // 18 | // ParseURL converts a url to a connection string for driver.Open. 19 | // Example: 20 | // 21 | // "postgres://bob:secret@1.2.3.4:5432/mydb?sslmode=verify-full" 22 | // 23 | // converts to: 24 | // 25 | // "user=bob password=secret host=1.2.3.4 port=5432 dbname=mydb sslmode=verify-full" 26 | // 27 | // A minimal example: 28 | // 29 | // "postgres://" 30 | // 31 | // This will be blank, causing driver.Open to use all of the defaults 32 | func ParseURL(url string) (string, error) { 33 | u, err := nurl.Parse(url) 34 | if err != nil { 35 | return "", err 36 | } 37 | 38 | if u.Scheme != "postgres" && u.Scheme != "postgresql" { 39 | return "", fmt.Errorf("invalid connection protocol: %s", u.Scheme) 40 | } 41 | 42 | var kvs []string 43 | escaper := strings.NewReplacer(` `, `\ `, `'`, `\'`, `\`, `\\`) 44 | accrue := func(k, v string) { 45 | if v != "" { 46 | kvs = append(kvs, k+"="+escaper.Replace(v)) 47 | } 48 | } 49 | 50 | if u.User != nil { 51 | v := u.User.Username() 52 | accrue("user", v) 53 | 54 | v, _ = u.User.Password() 55 | accrue("password", v) 56 | } 57 | 58 | if host, port, err := net.SplitHostPort(u.Host); err != nil { 59 | accrue("host", u.Host) 60 | } else { 61 | accrue("host", host) 62 | accrue("port", port) 63 | } 64 | 65 | if u.Path != "" { 66 | accrue("dbname", u.Path[1:]) 67 | } 68 | 69 | q := u.Query() 70 | for k := range q 
{ 71 | accrue(k, q.Get(k)) 72 | } 73 | 74 | sort.Strings(kvs) // Makes testing easier (not a performance concern) 75 | return strings.Join(kvs, " "), nil 76 | } 77 | -------------------------------------------------------------------------------- /buf.go: -------------------------------------------------------------------------------- 1 | package pq 2 | 3 | import ( 4 | "bytes" 5 | "encoding/binary" 6 | 7 | "github.com/lib/pq/oid" 8 | ) 9 | 10 | type readBuf []byte 11 | 12 | func (b *readBuf) int32() (n int) { 13 | n = int(int32(binary.BigEndian.Uint32(*b))) 14 | *b = (*b)[4:] 15 | return 16 | } 17 | 18 | func (b *readBuf) oid() (n oid.Oid) { 19 | n = oid.Oid(binary.BigEndian.Uint32(*b)) 20 | *b = (*b)[4:] 21 | return 22 | } 23 | 24 | // N.B: this is actually an unsigned 16-bit integer, unlike int32 25 | func (b *readBuf) int16() (n int) { 26 | n = int(binary.BigEndian.Uint16(*b)) 27 | *b = (*b)[2:] 28 | return 29 | } 30 | 31 | func (b *readBuf) string() string { 32 | i := bytes.IndexByte(*b, 0) 33 | if i < 0 { 34 | errorf("invalid message format; expected string terminator") 35 | } 36 | s := (*b)[:i] 37 | *b = (*b)[i+1:] 38 | return string(s) 39 | } 40 | 41 | func (b *readBuf) next(n int) (v []byte) { 42 | v = (*b)[:n] 43 | *b = (*b)[n:] 44 | return 45 | } 46 | 47 | func (b *readBuf) byte() byte { 48 | return b.next(1)[0] 49 | } 50 | 51 | type writeBuf struct { 52 | buf []byte 53 | pos int 54 | } 55 | 56 | func (b *writeBuf) int32(n int) { 57 | x := make([]byte, 4) 58 | binary.BigEndian.PutUint32(x, uint32(n)) 59 | b.buf = append(b.buf, x...) 60 | } 61 | 62 | func (b *writeBuf) int16(n int) { 63 | x := make([]byte, 2) 64 | binary.BigEndian.PutUint16(x, uint16(n)) 65 | b.buf = append(b.buf, x...) 66 | } 67 | 68 | func (b *writeBuf) string(s string) { 69 | b.buf = append(b.buf, (s + "\000")...) 70 | } 71 | 72 | func (b *writeBuf) byte(c byte) { 73 | b.buf = append(b.buf, c) 74 | } 75 | 76 | func (b *writeBuf) bytes(v []byte) { 77 | b.buf = append(b.buf, v...) 
78 | } 79 | 80 | func (b *writeBuf) wrap() []byte { 81 | p := b.buf[b.pos:] 82 | binary.BigEndian.PutUint32(p, uint32(len(p))) 83 | return b.buf 84 | } 85 | 86 | func (b *writeBuf) next(c byte) { 87 | p := b.buf[b.pos:] 88 | binary.BigEndian.PutUint32(p, uint32(len(p))) 89 | b.pos = len(b.buf) + 1 90 | b.buf = append(b.buf, c, 0, 0, 0, 0) 91 | } 92 | -------------------------------------------------------------------------------- /.travis.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -eu 4 | 5 | client_configure() { 6 | sudo chmod 600 $PQSSLCERTTEST_PATH/postgresql.key 7 | } 8 | 9 | pgdg_repository() { 10 | local sourcelist='sources.list.d/postgresql.list' 11 | 12 | curl -sS 'https://www.postgresql.org/media/keys/ACCC4CF8.asc' | sudo apt-key add - 13 | echo deb http://apt.postgresql.org/pub/repos/apt/ $(lsb_release -cs)-pgdg main $PGVERSION | sudo tee "/etc/apt/$sourcelist" 14 | sudo apt-get -o Dir::Etc::sourcelist="$sourcelist" -o Dir::Etc::sourceparts='-' -o APT::Get::List-Cleanup='0' update 15 | } 16 | 17 | postgresql_configure() { 18 | sudo tee /etc/postgresql/$PGVERSION/main/pg_hba.conf > /dev/null <<-config 19 | local all all trust 20 | hostnossl all pqgossltest 127.0.0.1/32 reject 21 | hostnossl all pqgosslcert 127.0.0.1/32 reject 22 | hostssl all pqgossltest 127.0.0.1/32 trust 23 | hostssl all pqgosslcert 127.0.0.1/32 cert 24 | host all all 127.0.0.1/32 trust 25 | hostnossl all pqgossltest ::1/128 reject 26 | hostnossl all pqgosslcert ::1/128 reject 27 | hostssl all pqgossltest ::1/128 trust 28 | hostssl all pqgosslcert ::1/128 cert 29 | host all all ::1/128 trust 30 | config 31 | 32 | xargs sudo install -o postgres -g postgres -m 600 -t /var/lib/postgresql/$PGVERSION/main/ <<-certificates 33 | certs/root.crt 34 | certs/server.crt 35 | certs/server.key 36 | certificates 37 | 38 | sort -VCu <<-versions || 39 | $PGVERSION 40 | 9.2 41 | versions 42 | sudo tee -a 
/etc/postgresql/$PGVERSION/main/postgresql.conf > /dev/null <<-config 43 | ssl_ca_file = 'root.crt' 44 | ssl_cert_file = 'server.crt' 45 | ssl_key_file = 'server.key' 46 | config 47 | 48 | echo 127.0.0.1 postgres | sudo tee -a /etc/hosts > /dev/null 49 | 50 | sudo service postgresql restart 51 | } 52 | 53 | postgresql_install() { 54 | xargs sudo apt-get -y -o Dpkg::Options::='--force-confdef' -o Dpkg::Options::='--force-confnew' install <<-packages 55 | postgresql-$PGVERSION 56 | postgresql-server-dev-$PGVERSION 57 | postgresql-contrib-$PGVERSION 58 | packages 59 | } 60 | 61 | postgresql_uninstall() { 62 | sudo service postgresql stop 63 | xargs sudo apt-get -y --purge remove <<-packages 64 | libpq-dev 65 | libpq5 66 | postgresql 67 | postgresql-client-common 68 | postgresql-common 69 | packages 70 | sudo rm -rf /var/lib/postgresql 71 | } 72 | 73 | $1 74 | -------------------------------------------------------------------------------- /hstore/hstore.go: -------------------------------------------------------------------------------- 1 | package hstore 2 | 3 | import ( 4 | "database/sql" 5 | "database/sql/driver" 6 | "strings" 7 | ) 8 | 9 | // Hstore is a wrapper for transferring hstore values back and forth easily; each hstore key maps to a sql.NullString so NULL values are representable. 10 | type Hstore struct { 11 | Map map[string]sql.NullString 12 | } 13 | 14 | // hQuote renders s as a double-quoted hstore key or value, backslash-escaping any \ and " characters. 15 | // s must be a string or sql.NullString (any other type panics); an invalid sql.NullString yields the unquoted literal NULL. 16 | func hQuote(s interface{}) string { 17 | var str string 18 | switch v := s.(type) { 19 | case sql.NullString: 20 | if !v.Valid { 21 | return "NULL" 22 | } 23 | str = v.String 24 | case string: 25 | str = v 26 | default: 27 | panic("not a string or sql.NullString") 28 | } 29 | 30 | str = strings.Replace(str, "\\", "\\\\", -1) 31 | return `"` + strings.Replace(str, "\"", "\\\"", -1) + `"` 32 | } 33 | 34 | // Scan implements the Scanner interface. 35 | // 36 | // Note h.Map is reallocated before the scan to clear existing values.
If the 37 | // hstore column's database value is NULL, then h.Map is set to nil instead. 38 | func (h *Hstore) Scan(value interface{}) error { 39 | if value == nil { 40 | h.Map = nil 41 | return nil 42 | } 43 | h.Map = make(map[string]sql.NullString) 44 | var b byte 45 | pair := [][]byte{{}, {}} 46 | pi := 0 47 | inQuote := false 48 | didQuote := false 49 | sawSlash := false 50 | bindex := 0 51 | for bindex, b = range value.([]byte) { 52 | if sawSlash { 53 | pair[pi] = append(pair[pi], b) 54 | sawSlash = false 55 | continue 56 | } 57 | 58 | switch b { 59 | case '\\': 60 | sawSlash = true 61 | continue 62 | case '"': 63 | inQuote = !inQuote 64 | if !didQuote { 65 | didQuote = true 66 | } 67 | continue 68 | default: 69 | if !inQuote { 70 | switch b { 71 | case ' ', '\t', '\n', '\r': 72 | continue 73 | case '=': 74 | continue 75 | case '>': 76 | pi = 1 77 | didQuote = false 78 | continue 79 | case ',': 80 | s := string(pair[1]) 81 | if !didQuote && len(s) == 4 && strings.ToLower(s) == "null" { 82 | h.Map[string(pair[0])] = sql.NullString{String: "", Valid: false} 83 | } else { 84 | h.Map[string(pair[0])] = sql.NullString{String: string(pair[1]), Valid: true} 85 | } 86 | pair[0] = []byte{} 87 | pair[1] = []byte{} 88 | pi = 0 89 | continue 90 | } 91 | } 92 | } 93 | pair[pi] = append(pair[pi], b) 94 | } 95 | if bindex > 0 { 96 | s := string(pair[1]) 97 | if !didQuote && len(s) == 4 && strings.ToLower(s) == "null" { 98 | h.Map[string(pair[0])] = sql.NullString{String: "", Valid: false} 99 | } else { 100 | h.Map[string(pair[0])] = sql.NullString{String: string(pair[1]), Valid: true} 101 | } 102 | } 103 | return nil 104 | } 105 | 106 | // Value implements the driver Valuer interface. Note if h.Map is nil, the 107 | // database column value will be set to NULL. 
108 | func (h Hstore) Value() (driver.Value, error) { 109 | if h.Map == nil { 110 | return nil, nil 111 | } 112 | parts := []string{} 113 | for key, val := range h.Map { 114 | thispart := hQuote(key) + "=>" + hQuote(val) 115 | parts = append(parts, thispart) 116 | } 117 | return []byte(strings.Join(parts, ",")), nil 118 | } 119 | -------------------------------------------------------------------------------- /listen_example/doc.go: -------------------------------------------------------------------------------- 1 | /* 2 | 3 | Below you will find a self-contained Go program which uses the LISTEN / NOTIFY 4 | mechanism to avoid polling the database while waiting for more work to arrive. 5 | 6 | // 7 | // You can see the program in action by defining a function similar to 8 | // the following: 9 | // 10 | // CREATE OR REPLACE FUNCTION public.get_work() 11 | // RETURNS bigint 12 | // LANGUAGE sql 13 | // AS $$ 14 | // SELECT CASE WHEN random() >= 0.2 THEN int8 '1' END 15 | // $$ 16 | // ; 17 | 18 | package main 19 | 20 | import ( 21 | "database/sql" 22 | "fmt" 23 | "time" 24 | 25 | "github.com/lib/pq" 26 | ) 27 | 28 | func doWork(db *sql.DB, work int64) { 29 | // work here 30 | } 31 | 32 | func getWork(db *sql.DB) { 33 | for { 34 | // get work from the database here 35 | var work sql.NullInt64 36 | err := db.QueryRow("SELECT get_work()").Scan(&work) 37 | if err != nil { 38 | fmt.Println("call to get_work() failed: ", err) 39 | time.Sleep(10 * time.Second) 40 | continue 41 | } 42 | if !work.Valid { 43 | // no more work to do 44 | fmt.Println("ran out of work") 45 | return 46 | } 47 | 48 | fmt.Println("starting work on ", work.Int64) 49 | go doWork(db, work.Int64) 50 | } 51 | } 52 | 53 | func waitForNotification(l *pq.Listener) { 54 | for { 55 | select { 56 | case <-l.Notify: 57 | fmt.Println("received notification, new work available") 58 | return 59 | case <-time.After(90 * time.Second): 60 | go func() { 61 | l.Ping() 62 | }() 63 | // Check if there's more work 
available, just in case it takes 64 | // a while for the Listener to notice connection loss and 65 | // reconnect. 66 | fmt.Println("received no work for 90 seconds, checking for new work") 67 | return 68 | } 69 | } 70 | } 71 | 72 | func main() { 73 | var conninfo string = "" 74 | 75 | db, err := sql.Open("postgres", conninfo) 76 | if err != nil { 77 | panic(err) 78 | } 79 | 80 | reportProblem := func(ev pq.ListenerEventType, err error) { 81 | if err != nil { 82 | fmt.Println(err.Error()) 83 | } 84 | } 85 | 86 | listener := pq.NewListener(conninfo, 10 * time.Second, time.Minute, reportProblem) 87 | err = listener.Listen("getwork") 88 | if err != nil { 89 | panic(err) 90 | } 91 | 92 | fmt.Println("entering main loop") 93 | for { 94 | // process all available work before waiting for notifications 95 | getWork(db) 96 | waitForNotification(listener) 97 | } 98 | } 99 | 100 | 101 | */ 102 | package listen_example 103 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # pq - A pure Go postgres driver for Go's database/sql package 2 | 3 | [![Build Status](https://travis-ci.org/lib/pq.png?branch=master)](https://travis-ci.org/lib/pq) 4 | 5 | ## Install 6 | 7 | go get github.com/lib/pq 8 | 9 | ## Docs 10 | 11 | For detailed documentation and basic usage examples, please see the package 12 | documentation at . 13 | 14 | ## Tests 15 | 16 | `go test` is used for testing. A running PostgreSQL server is 17 | required, with the ability to log in. The default database to connect 18 | to test with is "pqgotest," but it can be overridden using environment 19 | variables. 20 | 21 | Example: 22 | 23 | PGHOST=/run/postgresql go test github.com/lib/pq 24 | 25 | Optionally, a benchmark suite can be run as part of the tests: 26 | 27 | PGHOST=/run/postgresql go test -bench . 
28 | 29 | ## Features 30 | 31 | * SSL 32 | * Handles bad connections for `database/sql` 33 | * Scan `time.Time` correctly (i.e. `timestamp[tz]`, `time[tz]`, `date`) 34 | * Scan binary blobs correctly (i.e. `bytea`) 35 | * Package for `hstore` support 36 | * COPY FROM support 37 | * `pq.ParseURL` for converting URLs to connection strings for `sql.Open`. 38 | * Many libpq-compatible environment variables 39 | * Unix socket support 40 | * Notifications: `LISTEN`/`NOTIFY` 41 | * pgpass support 42 | 43 | ## Future / Things you can help with 44 | 45 | * Better COPY FROM / COPY TO (see discussion in #181) 46 | 47 | ## Thank you (alphabetical) 48 | 49 | Some of these contributors are from the original library `bmizerany/pq.go` whose 50 | code still exists here. 51 | 52 | * Andy Balholm (andybalholm) 53 | * Ben Berkert (benburkert) 54 | * Benjamin Heatwole (bheatwole) 55 | * Bill Mill (llimllib) 56 | * Bjørn Madsen (aeons) 57 | * Blake Gentry (bgentry) 58 | * Brad Fitzpatrick (bradfitz) 59 | * Charlie Melbye (cmelbye) 60 | * Chris Bandy (cbandy) 61 | * Chris Gilling (cgilling) 62 | * Chris Walsh (cwds) 63 | * Dan Sosedoff (sosedoff) 64 | * Daniel Farina (fdr) 65 | * Eric Chlebek (echlebek) 66 | * Eric Garrido (minusnine) 67 | * Eric Urban (hydrogen18) 68 | * Everyone at The Go Team 69 | * Evan Shaw (edsrzf) 70 | * Ewan Chou (coocood) 71 | * Fazal Majid (fazalmajid) 72 | * Federico Romero (federomero) 73 | * Fumin (fumin) 74 | * Gary Burd (garyburd) 75 | * Heroku (heroku) 76 | * James Pozdena (jpoz) 77 | * Jason McVetta (jmcvetta) 78 | * Jeremy Jay (pbnjay) 79 | * Joakim Sernbrant (serbaut) 80 | * John Gallagher (jgallagher) 81 | * Jonathan Rudenberg (titanous) 82 | * Joël Stemmer (jstemmer) 83 | * Kamil Kisiel (kisielk) 84 | * Kelly Dunn (kellydunn) 85 | * Keith Rarick (kr) 86 | * Kir Shatrov (kirs) 87 | * Lann Martin (lann) 88 | * Maciek Sakrejda (uhoh-itsmaciek) 89 | * Marc Brinkmann (mbr) 90 | * Marko Tiikkaja (johto) 91 | * Matt Newberry (MattNewberry) 92 | * Matt
Robenolt (mattrobenolt) 93 | * Martin Olsen (martinolsen) 94 | * Mike Lewis (mikelikespie) 95 | * Nicolas Patry (Narsil) 96 | * Oliver Tonnhofer (olt) 97 | * Patrick Hayes (phayes) 98 | * Paul Hammond (paulhammond) 99 | * Ryan Smith (ryandotsmith) 100 | * Samuel Stauffer (samuel) 101 | * Timothée Peignier (cyberdelia) 102 | * Travis Cline (tmc) 103 | * TruongSinh Tran-Nguyen (truongsinh) 104 | * Yaismel Miranda (ympons) 105 | * notedit (notedit) 106 | -------------------------------------------------------------------------------- /certs/postgresql.crt: -------------------------------------------------------------------------------- 1 | Certificate: 2 | Data: 3 | Version: 3 (0x2) 4 | Serial Number: 2 (0x2) 5 | Signature Algorithm: sha256WithRSAEncryption 6 | Issuer: C=US, ST=Nevada, L=Las Vegas, O=github.com/lib/pq, CN=pq CA 7 | Validity 8 | Not Before: Oct 11 15:10:11 2014 GMT 9 | Not After : Oct 8 15:10:11 2024 GMT 10 | Subject: C=US, ST=Nevada, L=Las Vegas, O=github.com/lib/pq, CN=pqgosslcert 11 | Subject Public Key Info: 12 | Public Key Algorithm: rsaEncryption 13 | RSA Public Key: (1024 bit) 14 | Modulus (1024 bit): 15 | 00:e3:8c:06:9a:70:54:51:d1:34:34:83:39:cd:a2: 16 | 59:0f:05:ed:8d:d8:0e:34:d0:92:f4:09:4d:ee:8c: 17 | 78:55:49:24:f8:3c:e0:34:58:02:b2:e7:94:58:c1: 18 | e8:e5:bb:d1:af:f6:54:c1:40:b1:90:70:79:0d:35: 19 | 54:9c:8f:16:e9:c2:f0:92:e6:64:49:38:c1:76:f8: 20 | 47:66:c4:5b:4a:b6:a9:43:ce:c8:be:6c:4d:2b:94: 21 | 97:3c:55:bc:d1:d0:6e:b7:53:ae:89:5c:4b:6b:86: 22 | 40:be:c1:ae:1e:64:ce:9c:ae:87:0a:69:e5:c8:21: 23 | 12:be:ae:1d:f6:45:df:16:a7 24 | Exponent: 65537 (0x10001) 25 | X509v3 extensions: 26 | X509v3 Subject Key Identifier: 27 | 9B:25:31:63:A2:D8:06:FF:CB:E3:E9:96:FF:0D:BA:DC:12:7D:04:CF 28 | X509v3 Authority Key Identifier: 29 | keyid:52:93:ED:1E:76:0A:9F:65:4F:DE:19:66:C1:D5:22:40:35:CB:A0:72 30 | 31 | X509v3 Basic Constraints: 32 | CA:FALSE 33 | X509v3 Key Usage: 34 | Digital Signature, Non Repudiation, Key Encipherment 35 | Signature 
Algorithm: sha256WithRSAEncryption 36 | 3e:f5:f8:0b:4e:11:bd:00:86:1f:ce:dc:97:02:98:91:11:f5: 37 | 65:f6:f2:8a:b2:3e:47:92:05:69:28:c9:e9:b4:f7:cf:93:d1: 38 | 2d:81:5d:00:3c:23:be:da:70:ea:59:e1:2c:d3:25:49:ae:a6: 39 | 95:54:c1:10:df:23:e3:fe:d6:e4:76:c7:6b:73:ad:1b:34:7c: 40 | e2:56:cc:c0:37:ae:c5:7a:11:20:6c:3d:05:0e:99:cd:22:6c: 41 | cf:59:a1:da:28:d4:65:ba:7d:2f:2b:3d:69:6d:a6:c1:ae:57: 42 | bf:56:64:13:79:f8:48:46:65:eb:81:67:28:0b:7b:de:47:10: 43 | b3:80:3c:31:d1:58:94:01:51:4a:c7:c8:1a:01:a8:af:c4:cd: 44 | bb:84:a5:d9:8b:b4:b9:a1:64:3e:95:d9:90:1d:d5:3f:67:cc: 45 | 3b:ba:f5:b4:d1:33:77:ee:c2:d2:3e:7e:c5:66:6e:b7:35:4c: 46 | 60:57:b0:b8:be:36:c8:f3:d3:95:8c:28:4a:c9:f7:27:a4:0d: 47 | e5:96:99:eb:f5:c8:bd:f3:84:6d:ef:02:f9:8a:36:7d:6b:5f: 48 | 36:68:37:41:d9:74:ae:c6:78:2e:44:86:a1:ad:43:ca:fb:b5: 49 | 3e:ba:10:23:09:02:ac:62:d1:d0:83:c8:95:b9:e3:5e:30:ff: 50 | 5b:2b:38:fa 51 | -----BEGIN CERTIFICATE----- 52 | MIIDEzCCAfugAwIBAgIBAjANBgkqhkiG9w0BAQsFADBeMQswCQYDVQQGEwJVUzEP 53 | MA0GA1UECBMGTmV2YWRhMRIwEAYDVQQHEwlMYXMgVmVnYXMxGjAYBgNVBAoTEWdp 54 | dGh1Yi5jb20vbGliL3BxMQ4wDAYDVQQDEwVwcSBDQTAeFw0xNDEwMTExNTEwMTFa 55 | Fw0yNDEwMDgxNTEwMTFaMGQxCzAJBgNVBAYTAlVTMQ8wDQYDVQQIEwZOZXZhZGEx 56 | EjAQBgNVBAcTCUxhcyBWZWdhczEaMBgGA1UEChMRZ2l0aHViLmNvbS9saWIvcHEx 57 | FDASBgNVBAMTC3BxZ29zc2xjZXJ0MIGfMA0GCSqGSIb3DQEBAQUAA4GNADCBiQKB 58 | gQDjjAaacFRR0TQ0gznNolkPBe2N2A400JL0CU3ujHhVSST4POA0WAKy55RYwejl 59 | u9Gv9lTBQLGQcHkNNVScjxbpwvCS5mRJOMF2+EdmxFtKtqlDzsi+bE0rlJc8VbzR 60 | 0G63U66JXEtrhkC+wa4eZM6crocKaeXIIRK+rh32Rd8WpwIDAQABo1owWDAdBgNV 61 | HQ4EFgQUmyUxY6LYBv/L4+mW/w263BJ9BM8wHwYDVR0jBBgwFoAUUpPtHnYKn2VP 62 | 3hlmwdUiQDXLoHIwCQYDVR0TBAIwADALBgNVHQ8EBAMCBeAwDQYJKoZIhvcNAQEL 63 | BQADggEBAD71+AtOEb0Ahh/O3JcCmJER9WX28oqyPkeSBWkoyem098+T0S2BXQA8 64 | I77acOpZ4SzTJUmuppVUwRDfI+P+1uR2x2tzrRs0fOJWzMA3rsV6ESBsPQUOmc0i 65 | bM9Zodoo1GW6fS8rPWltpsGuV79WZBN5+EhGZeuBZygLe95HELOAPDHRWJQBUUrH 66 | yBoBqK/EzbuEpdmLtLmhZD6V2ZAd1T9nzDu69bTRM3fuwtI+fsVmbrc1TGBXsLi+ 67 | 
Nsjz05WMKErJ9yekDeWWmev1yL3zhG3vAvmKNn1rXzZoN0HZdK7GeC5EhqGtQ8r7 68 | tT66ECMJAqxi0dCDyJW5414w/1srOPo= 69 | -----END CERTIFICATE----- 70 | -------------------------------------------------------------------------------- /hstore/hstore_test.go: -------------------------------------------------------------------------------- 1 | package hstore 2 | 3 | import ( 4 | "database/sql" 5 | "os" 6 | "testing" 7 | 8 | _ "github.com/lib/pq" 9 | ) 10 | 11 | type Fatalistic interface { // minimal failure-reporting interface; *testing.T satisfies it 12 | Fatal(args ...interface{}) 13 | } 14 | 15 | func openTestConn(t Fatalistic) *sql.DB { // opens a postgres handle, defaulting PGDATABASE=pqgotest and PGSSLMODE=disable when unset 16 | datname := os.Getenv("PGDATABASE") 17 | sslmode := os.Getenv("PGSSLMODE") 18 | 19 | if datname == "" { 20 | os.Setenv("PGDATABASE", "pqgotest") 21 | } 22 | 23 | if sslmode == "" { 24 | os.Setenv("PGSSLMODE", "disable") 25 | } 26 | 27 | conn, err := sql.Open("postgres", "") 28 | if err != nil { 29 | t.Fatal(err) 30 | } 31 | 32 | return conn 33 | } 34 | 35 | func TestHstore(t *testing.T) { 36 | db := openTestConn(t) 37 | defer db.Close() 38 | 39 | // quietly create hstore if it doesn't exist; skip the test if that fails 40 | _, err := db.Exec("CREATE EXTENSION IF NOT EXISTS hstore") 41 | if err != nil { 42 | t.Skipf("Skipping hstore tests - hstore extension create failed: %s", err.Error()) 43 | } 44 | 45 | hs := Hstore{} 46 | 47 | // test for null-valued hstores 48 | err = db.QueryRow("SELECT NULL::hstore").Scan(&hs) 49 | if err != nil { 50 | t.Fatal(err) 51 | } 52 | if hs.Map != nil { 53 | t.Fatalf("expected null map") 54 | } 55 | 56 | err = db.QueryRow("SELECT $1::hstore", hs).Scan(&hs) 57 | if err != nil { 58 | t.Fatalf("re-query null map failed: %s", err.Error()) 59 | } 60 | if hs.Map != nil { 61 | t.Fatalf("expected null map") 62 | } 63 | 64 | // test for empty hstores 65 | err = db.QueryRow("SELECT ''::hstore").Scan(&hs) 66 | if err != nil { 67 | t.Fatal(err) 68 | } 69 | if hs.Map == nil { 70 | t.Fatalf("expected empty map, got null map") 71 | } 72 | if len(hs.Map) != 0 { 73 | t.Fatalf("expected empty map, got len(map)=%d",
len(hs.Map)) 74 | } 75 | 76 | err = db.QueryRow("SELECT $1::hstore", hs).Scan(&hs) 77 | if err != nil { 78 | t.Fatalf("re-query empty map failed: %s", err.Error()) 79 | } 80 | if hs.Map == nil { 81 | t.Fatalf("expected empty map, got null map") 82 | } 83 | if len(hs.Map) != 0 { 84 | t.Fatalf("expected empty map, got len(map)=%d", len(hs.Map)) 85 | } 86 | 87 | // a few example maps to test out 88 | hsOnePair := Hstore{ 89 | Map: map[string]sql.NullString{ 90 | "key1": {String: "value1", Valid: true}, 91 | }, 92 | } 93 | 94 | hsThreePairs := Hstore{ 95 | Map: map[string]sql.NullString{ 96 | "key1": {String: "value1", Valid: true}, 97 | "key2": {String: "value2", Valid: true}, 98 | "key3": {String: "value3", Valid: true}, 99 | }, 100 | } 101 | 102 | hsSmorgasbord := Hstore{ 103 | Map: map[string]sql.NullString{ 104 | "nullstring": {String: "NULL", Valid: true}, 105 | "actuallynull": {String: "", Valid: false}, 106 | "NULL": {String: "NULL string key", Valid: true}, 107 | "withbracket": {String: "value>42", Valid: true}, 108 | "withequal": {String: "value=42", Valid: true}, 109 | `"withquotes1"`: {String: `this "should" be fine`, Valid: true}, 110 | `"withquotes"2"`: {String: `this "should\" also be fine`, Valid: true}, 111 | "embedded1": {String: "value1=>x1", Valid: true}, 112 | "embedded2": {String: `"value2"=>x2`, Valid: true}, 113 | "withnewlines": {String: "\n\nvalue\t=>2", Valid: true}, 114 | "<>": {String: `this, "should,\" also, => be fine`, Valid: true}, 115 | }, 116 | } 117 | 118 | // test encoding in query params, then decoding during Scan 119 | testBidirectional := func(h Hstore) { 120 | err = db.QueryRow("SELECT $1::hstore", h).Scan(&hs) 121 | if err != nil { 122 | t.Fatalf("re-query %d-pair map failed: %s", len(h.Map), err.Error()) 123 | } 124 | if hs.Map == nil { 125 | t.Fatalf("expected %d-pair map, got null map", len(h.Map)) 126 | } 127 | if len(hs.Map) != len(h.Map) { 128 | t.Fatalf("expected %d-pair map, got len(map)=%d", len(h.Map), len(hs.Map)) 
129 | } 130 | 131 | for key, val := range hs.Map { 132 | otherval, found := h.Map[key] 133 | if !found { 134 | t.Fatalf(" key '%v' not found in %d-pair map", key, len(h.Map)) 135 | } 136 | if otherval.Valid != val.Valid { 137 | t.Fatalf(" value %v <> %v in %d-pair map", otherval, val, len(h.Map)) 138 | } 139 | if otherval.String != val.String { 140 | t.Fatalf(" value '%v' <> '%v' in %d-pair map", otherval.String, val.String, len(h.Map)) 141 | } 142 | } 143 | } 144 | 145 | testBidirectional(hsOnePair) 146 | testBidirectional(hsThreePairs) 147 | testBidirectional(hsSmorgasbord) 148 | } 149 | -------------------------------------------------------------------------------- /certs/server.crt: -------------------------------------------------------------------------------- 1 | Certificate: 2 | Data: 3 | Version: 3 (0x2) 4 | Serial Number: 1 (0x1) 5 | Signature Algorithm: sha256WithRSAEncryption 6 | Issuer: C=US, ST=Nevada, L=Las Vegas, O=github.com/lib/pq, CN=pq CA 7 | Validity 8 | Not Before: Oct 11 15:05:15 2014 GMT 9 | Not After : Oct 8 15:05:15 2024 GMT 10 | Subject: C=US, ST=Nevada, L=Las Vegas, O=github.com/lib/pq, CN=postgres 11 | Subject Public Key Info: 12 | Public Key Algorithm: rsaEncryption 13 | RSA Public Key: (2048 bit) 14 | Modulus (2048 bit): 15 | 00:d7:8a:4c:85:fb:17:a5:3c:8f:e0:72:11:29:ce: 16 | 3f:b0:1f:3f:7d:c6:ee:7f:a7:fc:02:2b:35:47:08: 17 | a6:3d:90:df:5c:56:14:94:00:c7:6d:d1:d2:e2:61: 18 | 95:77:b8:e3:a6:66:31:f9:1f:21:7d:62:e1:27:da: 19 | 94:37:61:4a:ea:63:53:a0:61:b8:9c:bb:a5:e2:e7: 20 | b7:a6:d8:0f:05:04:c7:29:e2:ea:49:2b:7f:de:15: 21 | 00:a6:18:70:50:c7:0c:de:9a:f9:5a:96:b0:e1:94: 22 | 06:c6:6d:4a:21:3b:b4:0f:a5:6d:92:86:34:b2:4e: 23 | d7:0e:a7:19:c0:77:0b:7b:87:c8:92:de:42:ff:86: 24 | d2:b7:9a:a4:d4:15:23:ca:ad:a5:69:21:b8:ce:7e: 25 | 66:cb:85:5d:b9:ed:8b:2d:09:8d:94:e4:04:1e:72: 26 | ec:ef:d0:76:90:15:5a:a4:f7:91:4b:e9:ce:4e:9d: 27 | 5d:9a:70:17:9c:d8:e9:73:83:ea:3d:61:99:a6:cd: 28 | ac:91:40:5a:88:77:e5:4e:2a:8e:3d:13:f3:f9:38: 29 | 
6f:81:6b:8a:95:ca:0e:07:ab:6f:da:b4:8c:d9:ff: 30 | aa:78:03:aa:c7:c2:cf:6f:64:92:d3:d8:83:d5:af: 31 | f1:23:18:a7:2e:7b:17:0b:e7:7d:f1:fa:a8:41:a3: 32 | 04:57 33 | Exponent: 65537 (0x10001) 34 | X509v3 extensions: 35 | X509v3 Subject Key Identifier: 36 | EE:F0:B3:46:DC:C7:09:EB:0E:B6:2F:E5:FE:62:60:45:44:9F:59:CC 37 | X509v3 Authority Key Identifier: 38 | keyid:52:93:ED:1E:76:0A:9F:65:4F:DE:19:66:C1:D5:22:40:35:CB:A0:72 39 | 40 | X509v3 Basic Constraints: 41 | CA:FALSE 42 | X509v3 Key Usage: 43 | Digital Signature, Non Repudiation, Key Encipherment 44 | Signature Algorithm: sha256WithRSAEncryption 45 | 7e:5a:6e:be:bf:d2:6c:c1:d6:fa:b6:fb:3f:06:53:36:08:87: 46 | 9d:95:b1:39:af:9e:f6:47:38:17:39:da:25:7c:f2:ad:0c:e3: 47 | ab:74:19:ca:fb:8c:a0:50:c0:1d:19:8a:9c:21:ed:0f:3a:d1: 48 | 96:54:2e:10:09:4f:b8:70:f7:2b:99:43:d2:c6:15:bc:3f:24: 49 | 7d:28:39:32:3f:8d:a4:4f:40:75:7f:3e:0d:1c:d1:69:f2:4e: 50 | 98:83:47:97:d2:25:ac:c9:36:86:2f:04:a6:c4:86:c7:c4:00: 51 | 5f:7f:b9:ad:fc:bf:e9:f5:78:d7:82:1a:51:0d:fc:ab:9e:92: 52 | 1d:5f:0c:18:d1:82:e0:14:c9:ce:91:89:71:ff:49:49:ff:35: 53 | bf:7b:44:78:42:c1:d0:66:65:bb:28:2e:60:ca:9b:20:12:a9: 54 | 90:61:b1:96:ec:15:46:c9:37:f7:07:90:8a:89:45:2a:3f:37: 55 | ec:dc:e3:e5:8f:c3:3a:57:80:a5:54:60:0c:e1:b2:26:99:2b: 56 | 40:7e:36:d1:9a:70:02:ec:63:f4:3b:72:ae:81:fb:30:20:6d: 57 | cb:48:46:c6:b5:8f:39:b1:84:05:25:55:8d:f5:62:f6:1b:46: 58 | 2e:da:a3:4c:26:12:44:d7:56:b6:b8:a9:ca:d3:ab:71:45:7c: 59 | 9f:48:6d:1e 60 | -----BEGIN CERTIFICATE----- 61 | MIIDlDCCAnygAwIBAgIBATANBgkqhkiG9w0BAQsFADBeMQswCQYDVQQGEwJVUzEP 62 | MA0GA1UECBMGTmV2YWRhMRIwEAYDVQQHEwlMYXMgVmVnYXMxGjAYBgNVBAoTEWdp 63 | dGh1Yi5jb20vbGliL3BxMQ4wDAYDVQQDEwVwcSBDQTAeFw0xNDEwMTExNTA1MTVa 64 | Fw0yNDEwMDgxNTA1MTVaMGExCzAJBgNVBAYTAlVTMQ8wDQYDVQQIEwZOZXZhZGEx 65 | EjAQBgNVBAcTCUxhcyBWZWdhczEaMBgGA1UEChMRZ2l0aHViLmNvbS9saWIvcHEx 66 | ETAPBgNVBAMTCHBvc3RncmVzMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKC 67 | AQEA14pMhfsXpTyP4HIRKc4/sB8/fcbuf6f8Ais1RwimPZDfXFYUlADHbdHS4mGV 68 | 
d7jjpmYx+R8hfWLhJ9qUN2FK6mNToGG4nLul4ue3ptgPBQTHKeLqSSt/3hUAphhw 69 | UMcM3pr5Wpaw4ZQGxm1KITu0D6VtkoY0sk7XDqcZwHcLe4fIkt5C/4bSt5qk1BUj 70 | yq2laSG4zn5my4Vdue2LLQmNlOQEHnLs79B2kBVapPeRS+nOTp1dmnAXnNjpc4Pq 71 | PWGZps2skUBaiHflTiqOPRPz+ThvgWuKlcoOB6tv2rSM2f+qeAOqx8LPb2SS09iD 72 | 1a/xIxinLnsXC+d98fqoQaMEVwIDAQABo1owWDAdBgNVHQ4EFgQU7vCzRtzHCesO 73 | ti/l/mJgRUSfWcwwHwYDVR0jBBgwFoAUUpPtHnYKn2VP3hlmwdUiQDXLoHIwCQYD 74 | VR0TBAIwADALBgNVHQ8EBAMCBeAwDQYJKoZIhvcNAQELBQADggEBAH5abr6/0mzB 75 | 1vq2+z8GUzYIh52VsTmvnvZHOBc52iV88q0M46t0Gcr7jKBQwB0Zipwh7Q860ZZU 76 | LhAJT7hw9yuZQ9LGFbw/JH0oOTI/jaRPQHV/Pg0c0WnyTpiDR5fSJazJNoYvBKbE 77 | hsfEAF9/ua38v+n1eNeCGlEN/Kuekh1fDBjRguAUyc6RiXH/SUn/Nb97RHhCwdBm 78 | ZbsoLmDKmyASqZBhsZbsFUbJN/cHkIqJRSo/N+zc4+WPwzpXgKVUYAzhsiaZK0B+ 79 | NtGacALsY/Q7cq6B+zAgbctIRsa1jzmxhAUlVY31YvYbRi7ao0wmEkTXVra4qcrT 80 | q3FFfJ9IbR4= 81 | -----END CERTIFICATE----- 82 | -------------------------------------------------------------------------------- /oid/types.go: -------------------------------------------------------------------------------- 1 | // generated by 'go run gen.go'; do not edit 2 | 3 | package oid 4 | 5 | const ( 6 | T_bool Oid = 16 7 | T_bytea Oid = 17 8 | T_char Oid = 18 9 | T_name Oid = 19 10 | T_int8 Oid = 20 11 | T_int2 Oid = 21 12 | T_int2vector Oid = 22 13 | T_int4 Oid = 23 14 | T_regproc Oid = 24 15 | T_text Oid = 25 16 | T_oid Oid = 26 17 | T_tid Oid = 27 18 | T_xid Oid = 28 19 | T_cid Oid = 29 20 | T_oidvector Oid = 30 21 | T_pg_type Oid = 71 22 | T_pg_attribute Oid = 75 23 | T_pg_proc Oid = 81 24 | T_pg_class Oid = 83 25 | T_json Oid = 114 26 | T_xml Oid = 142 27 | T__xml Oid = 143 28 | T_pg_node_tree Oid = 194 29 | T__json Oid = 199 30 | T_smgr Oid = 210 31 | T_point Oid = 600 32 | T_lseg Oid = 601 33 | T_path Oid = 602 34 | T_box Oid = 603 35 | T_polygon Oid = 604 36 | T_line Oid = 628 37 | T__line Oid = 629 38 | T_cidr Oid = 650 39 | T__cidr Oid = 651 40 | T_float4 Oid = 700 41 | T_float8 Oid = 701 42 | T_abstime Oid = 702 43 | 
T_reltime Oid = 703 44 | T_tinterval Oid = 704 45 | T_unknown Oid = 705 46 | T_circle Oid = 718 47 | T__circle Oid = 719 48 | T_money Oid = 790 49 | T__money Oid = 791 50 | T_macaddr Oid = 829 51 | T_inet Oid = 869 52 | T__bool Oid = 1000 53 | T__bytea Oid = 1001 54 | T__char Oid = 1002 55 | T__name Oid = 1003 56 | T__int2 Oid = 1005 57 | T__int2vector Oid = 1006 58 | T__int4 Oid = 1007 59 | T__regproc Oid = 1008 60 | T__text Oid = 1009 61 | T__tid Oid = 1010 62 | T__xid Oid = 1011 63 | T__cid Oid = 1012 64 | T__oidvector Oid = 1013 65 | T__bpchar Oid = 1014 66 | T__varchar Oid = 1015 67 | T__int8 Oid = 1016 68 | T__point Oid = 1017 69 | T__lseg Oid = 1018 70 | T__path Oid = 1019 71 | T__box Oid = 1020 72 | T__float4 Oid = 1021 73 | T__float8 Oid = 1022 74 | T__abstime Oid = 1023 75 | T__reltime Oid = 1024 76 | T__tinterval Oid = 1025 77 | T__polygon Oid = 1027 78 | T__oid Oid = 1028 79 | T_aclitem Oid = 1033 80 | T__aclitem Oid = 1034 81 | T__macaddr Oid = 1040 82 | T__inet Oid = 1041 83 | T_bpchar Oid = 1042 84 | T_varchar Oid = 1043 85 | T_date Oid = 1082 86 | T_time Oid = 1083 87 | T_timestamp Oid = 1114 88 | T__timestamp Oid = 1115 89 | T__date Oid = 1182 90 | T__time Oid = 1183 91 | T_timestamptz Oid = 1184 92 | T__timestamptz Oid = 1185 93 | T_interval Oid = 1186 94 | T__interval Oid = 1187 95 | T__numeric Oid = 1231 96 | T_pg_database Oid = 1248 97 | T__cstring Oid = 1263 98 | T_timetz Oid = 1266 99 | T__timetz Oid = 1270 100 | T_bit Oid = 1560 101 | T__bit Oid = 1561 102 | T_varbit Oid = 1562 103 | T__varbit Oid = 1563 104 | T_numeric Oid = 1700 105 | T_refcursor Oid = 1790 106 | T__refcursor Oid = 2201 107 | T_regprocedure Oid = 2202 108 | T_regoper Oid = 2203 109 | T_regoperator Oid = 2204 110 | T_regclass Oid = 2205 111 | T_regtype Oid = 2206 112 | T__regprocedure Oid = 2207 113 | T__regoper Oid = 2208 114 | T__regoperator Oid = 2209 115 | T__regclass Oid = 2210 116 | T__regtype Oid = 2211 117 | T_record Oid = 2249 118 | T_cstring Oid = 2275 119 | T_any 
Oid = 2276 120 | T_anyarray Oid = 2277 121 | T_void Oid = 2278 122 | T_trigger Oid = 2279 123 | T_language_handler Oid = 2280 124 | T_internal Oid = 2281 125 | T_opaque Oid = 2282 126 | T_anyelement Oid = 2283 127 | T__record Oid = 2287 128 | T_anynonarray Oid = 2776 129 | T_pg_authid Oid = 2842 130 | T_pg_auth_members Oid = 2843 131 | T__txid_snapshot Oid = 2949 132 | T_uuid Oid = 2950 133 | T__uuid Oid = 2951 134 | T_txid_snapshot Oid = 2970 135 | T_fdw_handler Oid = 3115 136 | T_anyenum Oid = 3500 137 | T_tsvector Oid = 3614 138 | T_tsquery Oid = 3615 139 | T_gtsvector Oid = 3642 140 | T__tsvector Oid = 3643 141 | T__gtsvector Oid = 3644 142 | T__tsquery Oid = 3645 143 | T_regconfig Oid = 3734 144 | T__regconfig Oid = 3735 145 | T_regdictionary Oid = 3769 146 | T__regdictionary Oid = 3770 147 | T_anyrange Oid = 3831 148 | T_event_trigger Oid = 3838 149 | T_int4range Oid = 3904 150 | T__int4range Oid = 3905 151 | T_numrange Oid = 3906 152 | T__numrange Oid = 3907 153 | T_tsrange Oid = 3908 154 | T__tsrange Oid = 3909 155 | T_tstzrange Oid = 3910 156 | T__tstzrange Oid = 3911 157 | T_daterange Oid = 3912 158 | T__daterange Oid = 3913 159 | T_int8range Oid = 3926 160 | T__int8range Oid = 3927 161 | ) 162 | -------------------------------------------------------------------------------- /copy.go: -------------------------------------------------------------------------------- 1 | package pq 2 | 3 | import ( 4 | "database/sql/driver" 5 | "encoding/binary" 6 | "errors" 7 | "fmt" 8 | "sync" 9 | ) 10 | 11 | var ( 12 | errCopyInClosed = errors.New("pq: copyin statement has already been closed") 13 | errBinaryCopyNotSupported = errors.New("pq: only text format supported for COPY") 14 | errCopyToNotSupported = errors.New("pq: COPY TO is not supported") 15 | errCopyNotSupportedOutsideTxn = errors.New("pq: COPY is only allowed inside a transaction") 16 | errCopyInProgress = errors.New("pq: COPY in progress") 17 | ) 18 | 19 | // CopyIn creates a COPY FROM statement which can 
be prepared with 20 | // Tx.Prepare(). The target table should be visible in search_path. 21 | func CopyIn(table string, columns ...string) string { 22 | stmt := "COPY " + QuoteIdentifier(table) + " (" 23 | for i, col := range columns { 24 | if i != 0 { 25 | stmt += ", " 26 | } 27 | stmt += QuoteIdentifier(col) 28 | } 29 | stmt += ") FROM STDIN" 30 | return stmt 31 | } 32 | 33 | // CopyInSchema creates a COPY FROM statement which can be prepared with 34 | // Tx.Prepare(). 35 | func CopyInSchema(schema, table string, columns ...string) string { 36 | stmt := "COPY " + QuoteIdentifier(schema) + "." + QuoteIdentifier(table) + " (" 37 | for i, col := range columns { 38 | if i != 0 { 39 | stmt += ", " 40 | } 41 | stmt += QuoteIdentifier(col) 42 | } 43 | stmt += ") FROM STDIN" 44 | return stmt 45 | } 46 | 47 | type copyin struct { 48 | cn *conn 49 | buffer []byte 50 | rowData chan []byte 51 | done chan bool 52 | 53 | closed bool 54 | 55 | sync.Mutex // guards err 56 | err error 57 | } 58 | 59 | const ciBufferSize = 64 * 1024 60 | 61 | // flush buffer before the buffer is filled up and needs reallocation 62 | const ciBufferFlushSize = 63 * 1024 63 | 64 | func (cn *conn) prepareCopyIn(q string) (_ driver.Stmt, err error) { 65 | if !cn.isInTransaction() { 66 | return nil, errCopyNotSupportedOutsideTxn 67 | } 68 | 69 | ci := &copyin{ 70 | cn: cn, 71 | buffer: make([]byte, 0, ciBufferSize), 72 | rowData: make(chan []byte), 73 | done: make(chan bool, 1), 74 | } 75 | // add CopyData identifier + 4 bytes for message length 76 | ci.buffer = append(ci.buffer, 'd', 0, 0, 0, 0) 77 | 78 | b := cn.writeBuf('Q') 79 | b.string(q) 80 | cn.send(b) 81 | 82 | awaitCopyInResponse: 83 | for { 84 | t, r := cn.recv1() 85 | switch t { 86 | case 'G': 87 | if r.byte() != 0 { 88 | err = errBinaryCopyNotSupported 89 | break awaitCopyInResponse 90 | } 91 | go ci.resploop() 92 | return ci, nil 93 | case 'H': 94 | err = errCopyToNotSupported 95 | break awaitCopyInResponse 96 | case 'E': 97 | err =
parseError(r) 98 | case 'Z': 99 | if err == nil { 100 | cn.bad = true 101 | errorf("unexpected ReadyForQuery in response to COPY") 102 | } 103 | cn.processReadyForQuery(r) 104 | return nil, err 105 | default: 106 | cn.bad = true 107 | errorf("unknown response for copy query: %q", t) 108 | } 109 | } 110 | 111 | // something went wrong, abort COPY before we return 112 | b = cn.writeBuf('f') 113 | b.string(err.Error()) 114 | cn.send(b) 115 | 116 | for { // drain the server's replies to CopyFail until ReadyForQuery 117 | t, r := cn.recv1() 118 | switch t { 119 | case 'c', 'C', 'E': 120 | case 'Z': 121 | // correctly aborted, we're done 122 | cn.processReadyForQuery(r) 123 | return nil, err 124 | default: 125 | cn.bad = true 126 | errorf("unknown response for CopyFail: %q", t) 127 | } 128 | } 129 | } 130 | 131 | func (ci *copyin) flush(buf []byte) { // sends buf as one CopyData message; panics on write error (presumably recovered by the errRecover defer in Exec/Close) 132 | // set message length (without message identifier) 133 | binary.BigEndian.PutUint32(buf[1:], uint32(len(buf)-1)) 134 | 135 | _, err := ci.cn.c.Write(buf) 136 | if err != nil { 137 | panic(err) 138 | } 139 | } 140 | 141 | func (ci *copyin) resploop() { // runs in the goroutine started by prepareCopyIn; reads server messages until ReadyForQuery or a fatal error, then signals ci.done 142 | for { 143 | var r readBuf 144 | t, err := ci.cn.recvMessage(&r) 145 | if err != nil { 146 | ci.cn.bad = true 147 | ci.setError(err) 148 | ci.done <- true 149 | return 150 | } 151 | switch t { 152 | case 'C': 153 | // complete 154 | case 'N': 155 | // NoticeResponse 156 | case 'Z': 157 | ci.cn.processReadyForQuery(&r) 158 | ci.done <- true 159 | return 160 | case 'E': 161 | err := parseError(&r) 162 | ci.setError(err) 163 | default: 164 | ci.cn.bad = true 165 | ci.setError(fmt.Errorf("unknown response during CopyIn: %q", t)) 166 | ci.done <- true 167 | return 168 | } 169 | } 170 | } 171 | 172 | func (ci *copyin) isErrorSet() bool { // reports, under ci.Mutex, whether an error has been recorded 173 | ci.Lock() 174 | isSet := (ci.err != nil) 175 | ci.Unlock() 176 | return isSet 177 | } 178 | 179 | // setError() sets ci.err if one has not been set already. Caller must not be 180 | // holding ci.Mutex.
181 | func (ci *copyin) setError(err error) { 182 | ci.Lock() 183 | if ci.err == nil { 184 | ci.err = err 185 | } 186 | ci.Unlock() 187 | } 188 | 189 | func (ci *copyin) NumInput() int { 190 | return -1 191 | } 192 | 193 | func (ci *copyin) Query(v []driver.Value) (r driver.Rows, err error) { 194 | return nil, ErrNotSupported 195 | } 196 | 197 | // Exec inserts values into the COPY stream. The insert is asynchronous 198 | // and Exec can return errors from previous Exec calls to the same 199 | // COPY stmt. 200 | // 201 | // You need to call Exec(nil) to sync the COPY stream and to get any 202 | // errors from pending data, since Stmt.Close() doesn't return errors 203 | // to the user. 204 | func (ci *copyin) Exec(v []driver.Value) (r driver.Result, err error) { 205 | if ci.closed { 206 | return nil, errCopyInClosed 207 | } 208 | 209 | if ci.cn.bad { 210 | return nil, driver.ErrBadConn 211 | } 212 | defer ci.cn.errRecover(&err) 213 | 214 | if ci.isErrorSet() { 215 | return nil, ci.err 216 | } 217 | 218 | if len(v) == 0 { 219 | return nil, ci.Close() 220 | } 221 | 222 | numValues := len(v) 223 | for i, value := range v { 224 | ci.buffer = appendEncodedText(&ci.cn.parameterStatus, ci.buffer, value) 225 | if i < numValues-1 { 226 | ci.buffer = append(ci.buffer, '\t') 227 | } 228 | } 229 | 230 | ci.buffer = append(ci.buffer, '\n') 231 | 232 | if len(ci.buffer) > ciBufferFlushSize { 233 | ci.flush(ci.buffer) 234 | // reset buffer, keep bytes for message identifier and length 235 | ci.buffer = ci.buffer[:5] 236 | } 237 | 238 | return driver.RowsAffected(0), nil 239 | } 240 | 241 | func (ci *copyin) Close() (err error) { 242 | if ci.closed { // Don't do anything, we're already closed 243 | return nil 244 | } 245 | ci.closed = true 246 | 247 | if ci.cn.bad { 248 | return driver.ErrBadConn 249 | } 250 | defer ci.cn.errRecover(&err) 251 | 252 | if len(ci.buffer) > 5 { // the first 5 bytes are the CopyData header; flush only when rows are pending 253 | ci.flush(ci.buffer) 254 | } 255 | // Avoid touching the scratch buffer as resploop could be using it.
256 | err = ci.cn.sendSimpleMessage('c') 257 | if err != nil { 258 | return err 259 | } 260 | 261 | <-ci.done 262 | ci.cn.inCopy = false 263 | 264 | if ci.isErrorSet() { 265 | err = ci.err 266 | return err 267 | } 268 | return nil 269 | } 270 | -------------------------------------------------------------------------------- /ssl_test.go: -------------------------------------------------------------------------------- 1 | package pq 2 | 3 | // This file contains SSL tests 4 | 5 | import ( 6 | _ "crypto/sha256" 7 | "crypto/x509" 8 | "database/sql" 9 | "fmt" 10 | "os" 11 | "path/filepath" 12 | "testing" 13 | ) 14 | 15 | func maybeSkipSSLTests(t *testing.T) { 16 | // Require some special variables for testing certificates 17 | if os.Getenv("PQSSLCERTTEST_PATH") == "" { 18 | t.Skip("PQSSLCERTTEST_PATH not set, skipping SSL tests") 19 | } 20 | 21 | value := os.Getenv("PQGOSSLTESTS") 22 | if value == "" || value == "0" { 23 | t.Skip("PQGOSSLTESTS not enabled, skipping SSL tests") 24 | } else if value != "1" { 25 | t.Fatalf("unexpected value %q for PQGOSSLTESTS", value) 26 | } 27 | } 28 | 29 | func openSSLConn(t *testing.T, conninfo string) (*sql.DB, error) { 30 | db, err := openTestConnConninfo(conninfo) 31 | if err != nil { 32 | // should never fail 33 | t.Fatal(err) 34 | } 35 | // Do something with the connection to see whether it's working or not. 
36 | tx, err := db.Begin() 37 | if err == nil { 38 | return db, tx.Rollback() 39 | } 40 | _ = db.Close() 41 | return nil, err 42 | } 43 | 44 | func checkSSLSetup(t *testing.T, conninfo string) { 45 | db, err := openSSLConn(t, conninfo) 46 | if err == nil { 47 | db.Close() 48 | t.Fatalf("expected error with conninfo=%q", conninfo) 49 | } 50 | } 51 | 52 | // Connect over SSL and run a simple query to test the basics 53 | func TestSSLConnection(t *testing.T) { 54 | maybeSkipSSLTests(t) 55 | // Environment sanity check: should fail without SSL 56 | checkSSLSetup(t, "sslmode=disable user=pqgossltest") 57 | 58 | db, err := openSSLConn(t, "sslmode=require user=pqgossltest") 59 | if err != nil { 60 | t.Fatal(err) 61 | } 62 | rows, err := db.Query("SELECT 1") 63 | if err != nil { 64 | t.Fatal(err) 65 | } 66 | rows.Close() 67 | } 68 | 69 | // Test sslmode=verify-full 70 | func TestSSLVerifyFull(t *testing.T) { 71 | maybeSkipSSLTests(t) 72 | // Environment sanity check: should fail without SSL 73 | checkSSLSetup(t, "sslmode=disable user=pqgossltest") 74 | 75 | // Not OK according to the system CA 76 | _, err := openSSLConn(t, "host=postgres sslmode=verify-full user=pqgossltest") 77 | if err == nil { 78 | t.Fatal("expected error") 79 | } 80 | _, ok := err.(x509.UnknownAuthorityError) 81 | if !ok { 82 | t.Fatalf("expected x509.UnknownAuthorityError, got %#+v", err) 83 | } 84 | 85 | rootCertPath := filepath.Join(os.Getenv("PQSSLCERTTEST_PATH"), "root.crt") 86 | rootCert := "sslrootcert=" + rootCertPath + " " 87 | // No match on Common Name 88 | _, err = openSSLConn(t, rootCert+"host=127.0.0.1 sslmode=verify-full user=pqgossltest") 89 | if err == nil { 90 | t.Fatal("expected error") 91 | } 92 | _, ok = err.(x509.HostnameError) 93 | if !ok { 94 | t.Fatalf("expected x509.HostnameError, got %#+v", err) 95 | } 96 | // OK 97 | _, err = openSSLConn(t, rootCert+"host=postgres sslmode=verify-full user=pqgossltest") 98 | if err != nil { 99 | t.Fatal(err) 100 | } 101 | } 102 | 103 | // 
Test sslmode=require sslrootcert=rootCertPath 104 | func TestSSLRequireWithRootCert(t *testing.T) { 105 | maybeSkipSSLTests(t) 106 | // Environment sanity check: should fail without SSL 107 | checkSSLSetup(t, "sslmode=disable user=pqgossltest") 108 | 109 | bogusRootCertPath := filepath.Join(os.Getenv("PQSSLCERTTEST_PATH"), "bogus_root.crt") 110 | bogusRootCert := "sslrootcert=" + bogusRootCertPath + " " 111 | 112 | // Not OK according to the bogus CA 113 | _, err := openSSLConn(t, bogusRootCert+"host=postgres sslmode=require user=pqgossltest") 114 | if err == nil { 115 | t.Fatal("expected error") 116 | } 117 | _, ok := err.(x509.UnknownAuthorityError) 118 | if !ok { 119 | t.Fatalf("expected x509.UnknownAuthorityError, got %s, %#+v", err, err) 120 | } 121 | 122 | nonExistentCertPath := filepath.Join(os.Getenv("PQSSLCERTTEST_PATH"), "non_existent.crt") 123 | nonExistentCert := "sslrootcert=" + nonExistentCertPath + " " 124 | 125 | // No match on Common Name, but that's OK because we're not validating anything. 126 | _, err = openSSLConn(t, nonExistentCert+"host=127.0.0.1 sslmode=require user=pqgossltest") 127 | if err != nil { 128 | t.Fatal(err) 129 | } 130 | 131 | rootCertPath := filepath.Join(os.Getenv("PQSSLCERTTEST_PATH"), "root.crt") 132 | rootCert := "sslrootcert=" + rootCertPath + " " 133 | 134 | // No match on Common Name, but that's OK because we're not validating the CN. 
135 | _, err = openSSLConn(t, rootCert+"host=127.0.0.1 sslmode=require user=pqgossltest") 136 | if err != nil { 137 | t.Fatal(err) 138 | } 139 | // Everything OK 140 | _, err = openSSLConn(t, rootCert+"host=postgres sslmode=require user=pqgossltest") 141 | if err != nil { 142 | t.Fatal(err) 143 | } 144 | } 145 | 146 | // Test sslmode=verify-ca 147 | func TestSSLVerifyCA(t *testing.T) { 148 | maybeSkipSSLTests(t) 149 | // Environment sanity check: should fail without SSL 150 | checkSSLSetup(t, "sslmode=disable user=pqgossltest") 151 | 152 | // Not OK according to the system CA 153 | _, err := openSSLConn(t, "host=postgres sslmode=verify-ca user=pqgossltest") 154 | if err == nil { 155 | t.Fatal("expected error") 156 | } 157 | _, ok := err.(x509.UnknownAuthorityError) 158 | if !ok { 159 | t.Fatalf("expected x509.UnknownAuthorityError, got %#+v", err) 160 | } 161 | 162 | rootCertPath := filepath.Join(os.Getenv("PQSSLCERTTEST_PATH"), "root.crt") 163 | rootCert := "sslrootcert=" + rootCertPath + " " 164 | // No match on Common Name, but that's OK 165 | _, err = openSSLConn(t, rootCert+"host=127.0.0.1 sslmode=verify-ca user=pqgossltest") 166 | if err != nil { 167 | t.Fatal(err) 168 | } 169 | // Everything OK 170 | _, err = openSSLConn(t, rootCert+"host=postgres sslmode=verify-ca user=pqgossltest") 171 | if err != nil { 172 | t.Fatal(err) 173 | } 174 | } 175 | 176 | func getCertConninfo(t *testing.T, source string) string { 177 | var sslkey string 178 | var sslcert string 179 | 180 | certpath := os.Getenv("PQSSLCERTTEST_PATH") 181 | 182 | switch source { 183 | case "missingkey": 184 | sslkey = "/tmp/filedoesnotexist" 185 | sslcert = filepath.Join(certpath, "postgresql.crt") 186 | case "missingcert": 187 | sslkey = filepath.Join(certpath, "postgresql.key") 188 | sslcert = "/tmp/filedoesnotexist" 189 | case "certtwice": 190 | sslkey = filepath.Join(certpath, "postgresql.crt") 191 | sslcert = filepath.Join(certpath, "postgresql.crt") 192 | case "valid": 193 | sslkey = 
filepath.Join(certpath, "postgresql.key") 194 | sslcert = filepath.Join(certpath, "postgresql.crt") 195 | default: 196 | t.Fatalf("invalid source %q", source) 197 | } 198 | return fmt.Sprintf("sslmode=require user=pqgosslcert sslkey=%s sslcert=%s", sslkey, sslcert) 199 | } 200 | 201 | // TestSSLClientCertificates checks that user pqgosslcert is rejected without a client certificate and accepted with a valid one. 202 | func TestSSLClientCertificates(t *testing.T) { 203 | maybeSkipSSLTests(t) 204 | // Environment sanity check: should fail without SSL 205 | checkSSLSetup(t, "sslmode=disable user=pqgossltest") 206 | 207 | // Should also fail without a valid certificate 208 | db, err := openSSLConn(t, "sslmode=require user=pqgosslcert") 209 | if err == nil { 210 | db.Close() 211 | t.Fatal("expected error") 212 | } 213 | pge, ok := err.(*Error) 214 | if !ok { 215 | t.Fatalf("expected pq.Error, got %#+v", err) 216 | } 217 | if pge.Code.Name() != "invalid_authorization_specification" { 218 | t.Fatalf("unexpected error code %q", pge.Code.Name()) 219 | } 220 | // Should work; the pool is closed when the test returns 221 | db, err = openSSLConn(t, getCertConninfo(t, "valid")) 222 | if err != nil { 223 | t.Fatal(err) 224 | } 225 | defer db.Close() 226 | rows, err := db.Query("SELECT 1") 227 | if err != nil { 228 | t.Fatal(err) 229 | } 230 | rows.Close() 231 | } 232 | 233 | // Test errors with ssl certificates 234 | func TestSSLClientCertificatesMissingFiles(t *testing.T) { 235 | maybeSkipSSLTests(t) 236 | // Environment sanity check: should fail without SSL 237 | checkSSLSetup(t, "sslmode=disable user=pqgossltest") 238 | 239 | // Key missing, should fail 240 | _, err := openSSLConn(t, getCertConninfo(t, "missingkey")) 241 | if err == nil { 242 | t.Fatal("expected error") 243 | } 244 | // should be a PathError 245 | _, ok := err.(*os.PathError) 246 | if !ok { 247 | t.Fatalf("expected PathError, got %#+v", err) 248 | } 249 | 250 | // Cert missing, should fail 251 | _, err = openSSLConn(t, getCertConninfo(t, "missingcert")) 252 | if err == nil { 253 | t.Fatal("expected error") 254 | } 255 | // should be a PathError 256 | _, ok =
err.(*os.PathError) 257 | if !ok { 258 | t.Fatalf("expected PathError, got %#+v", err) 259 | } 260 | 261 | // "certtwice" passes postgresql.crt as sslkey; the key-permission check is expected to reject it 262 | _, err = openSSLConn(t, getCertConninfo(t, "certtwice")) 263 | if err == nil { 264 | t.Fatal("expected error") 265 | } 266 | if err != ErrSSLKeyHasWorldPermissions { 267 | t.Fatalf("expected ErrSSLKeyHasWorldPermissions, got %#+v", err) 268 | } 269 | } 270 | -------------------------------------------------------------------------------- /doc.go: -------------------------------------------------------------------------------- 1 | /* 2 | Package pq is a pure Go Postgres driver for the database/sql package. 3 | 4 | In most cases clients will use the database/sql package instead of 5 | using this package directly. For example: 6 | 7 | import ( 8 | "database/sql" 9 | 10 | _ "github.com/lib/pq" 11 | ) 12 | 13 | func main() { 14 | db, err := sql.Open("postgres", "user=pqgotest dbname=pqgotest sslmode=verify-full") 15 | if err != nil { 16 | log.Fatal(err) 17 | } 18 | 19 | age := 21 20 | rows, err := db.Query("SELECT name FROM users WHERE age = $1", age) 21 | … 22 | } 23 | 24 | You can also connect to a database using a URL. For example: 25 | 26 | db, err := sql.Open("postgres", "postgres://pqgotest:password@localhost/pqgotest?sslmode=verify-full") 27 | 28 | 29 | Connection String Parameters 30 | 31 | 32 | Similarly to libpq, when establishing a connection using pq you are expected to 33 | supply a connection string containing zero or more parameters. 34 | A subset of the connection parameters supported by libpq are also supported by pq. 35 | Additionally, pq also lets you specify run-time parameters (such as search_path or work_mem) 36 | directly in the connection string. This is different from libpq, which does not allow 37 | run-time parameters in the connection string, instead requiring you to supply 38 | them in the options parameter.
39 | 40 | For compatibility with libpq, the following special connection parameters are 41 | supported: 42 | 43 | * dbname - The name of the database to connect to 44 | * user - The user to sign in as 45 | * password - The user's password 46 | * host - The host to connect to. Values that start with / are for unix domain sockets. (default is localhost) 47 | * port - The port to connect to. (default is 5432) 48 | * sslmode - Whether or not to use SSL (default is require; note that this differs from the libpq default) 49 | * fallback_application_name - An application_name to fall back to if one isn't provided. 50 | * connect_timeout - Maximum wait for connection, in seconds. Zero or not specified means wait indefinitely. 51 | * sslcert - Cert file location. The file must contain PEM encoded data. 52 | * sslkey - Key file location. The file must contain PEM encoded data. 53 | * sslrootcert - The location of the root certificate file. The file must contain PEM encoded data. 54 | 55 | Valid values for sslmode are: 56 | 57 | * disable - No SSL 58 | * require - Always SSL (skip verification) 59 | * verify-ca - Always SSL (verify that the certificate presented by the server was signed by a trusted CA) 60 | * verify-full - Always SSL (verify that the certificate presented by the server was signed by a trusted CA and the server host name matches the one in the certificate) 61 | 62 | See http://www.postgresql.org/docs/current/static/libpq-connect.html#LIBPQ-CONNSTRING 63 | for more information about connection string parameters. 64 | 65 | Use single quotes for values that contain whitespace: 66 | 67 | "user=pqgotest password='with spaces'" 68 | 69 | A backslash will escape the next character in values: 70 | 71 | "user=space\ man password='it\'s valid'" 72 | 73 | Note that the connection parameter client_encoding (which sets the 74 | text encoding for the connection) may be set but must be "UTF8", 75 | matched using the same rules as Postgres. It is an error to provide 76 | any other value.
77 | 78 | In addition to the parameters listed above, any run-time parameter that can be 79 | set at backend start time can be set in the connection string. For more 80 | information, see 81 | http://www.postgresql.org/docs/current/static/runtime-config.html. 82 | 83 | Most environment variables as specified at http://www.postgresql.org/docs/current/static/libpq-envars.html 84 | supported by libpq are also supported by pq. If any of the environment 85 | variables not supported by pq are set, pq will panic during connection 86 | establishment. Environment variables have a lower precedence than explicitly 87 | provided connection parameters. 88 | 89 | The pgpass mechanism as described in http://www.postgresql.org/docs/current/static/libpq-pgpass.html 90 | is supported, but on Windows PGPASSFILE must be specified explicitly. 91 | 92 | 93 | Queries 94 | 95 | 96 | database/sql does not dictate any specific format for parameter 97 | markers in query strings, and pq uses the Postgres-native ordinal markers, 98 | as shown above. The same marker can be reused for the same parameter: 99 | 100 | rows, err := db.Query(`SELECT name FROM users WHERE favorite_fruit = $1 101 | OR age BETWEEN $2 AND $2 + 3`, "orange", 64) 102 | 103 | pq does not support the LastInsertId() method of the Result type in database/sql. 
104 | To return the identifier of an INSERT (or UPDATE or DELETE), use the Postgres 105 | RETURNING clause with a standard Query or QueryRow call: 106 | 107 | var userid int 108 | err := db.QueryRow(`INSERT INTO users(name, favorite_fruit, age) 109 | VALUES('beatrice', 'starfruit', 93) RETURNING id`).Scan(&userid) 110 | 111 | For more details on RETURNING, see the Postgres documentation: 112 | 113 | http://www.postgresql.org/docs/current/static/sql-insert.html 114 | http://www.postgresql.org/docs/current/static/sql-update.html 115 | http://www.postgresql.org/docs/current/static/sql-delete.html 116 | 117 | For additional instructions on querying see the documentation for the database/sql package. 118 | 119 | 120 | Data Types 121 | 122 | 123 | Parameters pass through driver.DefaultParameterConverter before they are handled 124 | by this package. When the binary_parameters connection option is enabled, 125 | []byte values are sent directly to the backend as data in binary format. 126 | 127 | This package returns the following types for values from the PostgreSQL backend: 128 | 129 | - integer types smallint, integer, and bigint are returned as int64 130 | - floating-point types real and double precision are returned as float64 131 | - character types char, varchar, and text are returned as string 132 | - temporal types date, time, timetz, timestamp, and timestamptz are returned as time.Time 133 | - the boolean type is returned as bool 134 | - the bytea type is returned as []byte 135 | 136 | All other types are returned directly from the backend as []byte values in text format. 137 | 138 | 139 | Errors 140 | 141 | 142 | pq may return errors of type *pq.Error which can be interrogated for error details: 143 | 144 | if err, ok := err.(*pq.Error); ok { 145 | fmt.Println("pq error:", err.Code.Name()) 146 | } 147 | 148 | See the pq.Error type for details. 
149 | 150 | 151 | Bulk imports 152 | 153 | You can perform bulk imports by preparing a statement returned by pq.CopyIn (or 154 | pq.CopyInSchema) in an explicit transaction (sql.Tx). The returned statement 155 | handle can then be repeatedly "executed" to copy data into the target table. 156 | After all data has been processed, you should call Exec() once with no arguments 157 | to flush all buffered data. Any call to Exec() might return an error which 158 | should be handled appropriately, but because of the internal buffering an error 159 | returned by Exec() might not be related to the data passed in the call that 160 | failed. 161 | 162 | CopyIn uses COPY FROM internally. It is not possible to COPY outside of an 163 | explicit transaction in pq. 164 | 165 | Usage example: 166 | 167 | txn, err := db.Begin() 168 | if err != nil { 169 | log.Fatal(err) 170 | } 171 | 172 | stmt, err := txn.Prepare(pq.CopyIn("users", "name", "age")) 173 | if err != nil { 174 | log.Fatal(err) 175 | } 176 | 177 | for _, user := range users { 178 | _, err = stmt.Exec(user.Name, int64(user.Age)) 179 | if err != nil { 180 | log.Fatal(err) 181 | } 182 | } 183 | 184 | _, err = stmt.Exec() 185 | if err != nil { 186 | log.Fatal(err) 187 | } 188 | 189 | err = stmt.Close() 190 | if err != nil { 191 | log.Fatal(err) 192 | } 193 | 194 | err = txn.Commit() 195 | if err != nil { 196 | log.Fatal(err) 197 | } 198 | 199 | 200 | Notifications 201 | 202 | 203 | PostgreSQL supports a simple publish/subscribe model over database 204 | connections. See http://www.postgresql.org/docs/current/static/sql-notify.html 205 | for more information about the general mechanism. 206 | 207 | To start listening for notifications, you first have to open a new connection 208 | to the database by calling NewListener. This connection cannot be used for 209 | anything other than LISTEN / NOTIFY.
Calling Listen will open a "notification 210 | channel"; once a notification channel is open, a notification generated on that 211 | channel will effect a send on the Listener.Notify channel. A notification 212 | channel will remain open until Unlisten is called, though connection loss might 213 | result in some notifications being lost. To solve this problem, Listener sends 214 | a nil pointer over the Notify channel any time the connection is re-established 215 | following a connection loss. The application can get information about the 216 | state of the underlying connection by setting an event callback in the call to 217 | NewListener. 218 | 219 | A single Listener can safely be used from concurrent goroutines, which means 220 | that there is often no need to create more than one Listener in your 221 | application. However, a Listener is always connected to a single database, so 222 | you will need to create a new Listener instance for every database you want to 223 | receive notifications in. 224 | 225 | The channel name in both Listen and Unlisten is case sensitive, and can contain 226 | any characters legal in an identifier (see 227 | http://www.postgresql.org/docs/current/static/sql-syntax-lexical.html#SQL-SYNTAX-IDENTIFIERS 228 | for more information). Note that the channel name will be truncated to 63 229 | bytes by the PostgreSQL server. 230 | 231 | You can find a complete, working example of Listener usage at 232 | http://godoc.org/github.com/lib/pq/listen_example. 
233 | 234 | */ 235 | package pq 236 | -------------------------------------------------------------------------------- /copy_test.go: -------------------------------------------------------------------------------- 1 | package pq 2 | 3 | import ( 4 | "bytes" 5 | "database/sql" 6 | "database/sql/driver" 7 | "strings" 8 | "testing" 9 | ) 10 | // TestCopyInStmt checks that CopyIn double-quotes table and column names, doubling any embedded double quotes. 11 | func TestCopyInStmt(t *testing.T) { 12 | var stmt string 13 | stmt = CopyIn("table name") 14 | if stmt != `COPY "table name" () FROM STDIN` { 15 | t.Fatal(stmt) 16 | } 17 | 18 | stmt = CopyIn("table name", "column 1", "column 2") 19 | if stmt != `COPY "table name" ("column 1", "column 2") FROM STDIN` { 20 | t.Fatal(stmt) 21 | } 22 | 23 | stmt = CopyIn(`table " name """`, `co"lumn""`) 24 | if stmt != `COPY "table "" name """"""" ("co""lumn""""") FROM STDIN` { 25 | t.Fatal(stmt) 26 | } 27 | } 28 | // TestCopyInSchemaStmt checks the same quoting rules for the schema-qualified statement built by CopyInSchema. 29 | func TestCopyInSchemaStmt(t *testing.T) { 30 | var stmt string 31 | stmt = CopyInSchema("schema name", "table name") 32 | if stmt != `COPY "schema name"."table name" () FROM STDIN` { 33 | t.Fatal(stmt) 34 | } 35 | 36 | stmt = CopyInSchema("schema name", "table name", "column 1", "column 2") 37 | if stmt != `COPY "schema name"."table name" ("column 1", "column 2") FROM STDIN` { 38 | t.Fatal(stmt) 39 | } 40 | 41 | stmt = CopyInSchema(`schema " name """`, `table " name """`, `co"lumn""`) 42 | if stmt != `COPY "schema "" name """"""".`+ 43 | `"table "" name """"""" ("co""lumn""""") FROM STDIN` { 44 | t.Fatal(stmt) 45 | } 46 | } 47 | // TestCopyInMultipleValues copies 500 rows through a single CopyIn statement and checks the resulting row count. 48 | func TestCopyInMultipleValues(t *testing.T) { 49 | db := openTestConn(t) 50 | defer db.Close() 51 | 52 | txn, err := db.Begin() 53 | if err != nil { 54 | t.Fatal(err) 55 | } 56 | defer txn.Rollback() 57 | 58 | _, err = txn.Exec("CREATE TEMP TABLE temp (a int, b varchar)") 59 | if err != nil { 60 | t.Fatal(err) 61 | } 62 | 63 | stmt, err := txn.Prepare(CopyIn("temp", "a", "b")) 64 | if err != nil { 65 | t.Fatal(err) 66 | } 67 | 68 | longString := strings.Repeat("#", 500) 69 | 70 | for i := 0; i <
500; i++ { 71 | _, err = stmt.Exec(int64(i), longString) 72 | if err != nil { 73 | t.Fatal(err) 74 | } 75 | } 76 | 77 | _, err = stmt.Exec() 78 | if err != nil { 79 | t.Fatal(err) 80 | } 81 | 82 | err = stmt.Close() 83 | if err != nil { 84 | t.Fatal(err) 85 | } 86 | 87 | var num int 88 | err = txn.QueryRow("SELECT COUNT(*) FROM temp").Scan(&num) 89 | if err != nil { 90 | t.Fatal(err) 91 | } 92 | 93 | if num != 500 { 94 | t.Fatalf("expected 500 items, not %d", num) 95 | } 96 | } 97 | // TestCopyInRaiseStmtTrigger checks that COPY still succeeds when a BEFORE INSERT row trigger raises a NOTICE. 98 | func TestCopyInRaiseStmtTrigger(t *testing.T) { 99 | db := openTestConn(t) 100 | defer db.Close() 101 | 102 | if getServerVersion(t, db) < 90000 { 103 | var exists int 104 | err := db.QueryRow("SELECT 1 FROM pg_language WHERE lanname = 'plpgsql'").Scan(&exists) 105 | if err == sql.ErrNoRows { 106 | t.Skip("language PL/PgSQL does not exist; skipping TestCopyInRaiseStmtTrigger") 107 | } else if err != nil { 108 | t.Fatal(err) 109 | } 110 | } 111 | 112 | txn, err := db.Begin() 113 | if err != nil { 114 | t.Fatal(err) 115 | } 116 | defer txn.Rollback() 117 | 118 | _, err = txn.Exec("CREATE TEMP TABLE temp (a int, b varchar)") 119 | if err != nil { 120 | t.Fatal(err) 121 | } 122 | 123 | _, err = txn.Exec(` 124 | CREATE OR REPLACE FUNCTION pg_temp.temptest() 125 | RETURNS trigger AS 126 | $BODY$ begin 127 | raise notice 'Hello world'; 128 | return new; 129 | end $BODY$ 130 | LANGUAGE plpgsql`) 131 | if err != nil { 132 | t.Fatal(err) 133 | } 134 | 135 | _, err = txn.Exec(` 136 | CREATE TRIGGER temptest_trigger 137 | BEFORE INSERT 138 | ON temp 139 | FOR EACH ROW 140 | EXECUTE PROCEDURE pg_temp.temptest()`) 141 | if err != nil { 142 | t.Fatal(err) 143 | } 144 | 145 | stmt, err := txn.Prepare(CopyIn("temp", "a", "b")) 146 | if err != nil { 147 | t.Fatal(err) 148 | } 149 | 150 | longString := strings.Repeat("#", 500) 151 | 152 | _, err = stmt.Exec(int64(1), longString) 153 | if err != nil { 154 | t.Fatal(err) 155 | } 156 | 157 | _, err = stmt.Exec() 158 | if err != nil { 159 | t.Fatal(err)
160 | } 161 | 162 | err = stmt.Close() 163 | if err != nil { 164 | t.Fatal(err) 165 | } 166 | 167 | var num int 168 | err = txn.QueryRow("SELECT COUNT(*) FROM temp").Scan(&num) 169 | if err != nil { 170 | t.Fatal(err) 171 | } 172 | 173 | if num != 1 { 174 | t.Fatalf("expected 1 items, not %d", num) 175 | } 176 | } 177 | 178 | func TestCopyInTypes(t *testing.T) { 179 | db := openTestConn(t) 180 | defer db.Close() 181 | 182 | txn, err := db.Begin() 183 | if err != nil { 184 | t.Fatal(err) 185 | } 186 | defer txn.Rollback() 187 | 188 | _, err = txn.Exec("CREATE TEMP TABLE temp (num INTEGER, text VARCHAR, blob BYTEA, nothing VARCHAR)") 189 | if err != nil { 190 | t.Fatal(err) 191 | } 192 | 193 | stmt, err := txn.Prepare(CopyIn("temp", "num", "text", "blob", "nothing")) 194 | if err != nil { 195 | t.Fatal(err) 196 | } 197 | 198 | _, err = stmt.Exec(int64(1234567890), "Héllö\n ☃!\r\t\\", []byte{0, 255, 9, 10, 13}, nil) 199 | if err != nil { 200 | t.Fatal(err) 201 | } 202 | 203 | _, err = stmt.Exec() 204 | if err != nil { 205 | t.Fatal(err) 206 | } 207 | 208 | err = stmt.Close() 209 | if err != nil { 210 | t.Fatal(err) 211 | } 212 | 213 | var num int 214 | var text string 215 | var blob []byte 216 | var nothing sql.NullString 217 | 218 | err = txn.QueryRow("SELECT * FROM temp").Scan(&num, &text, &blob, ¬hing) 219 | if err != nil { 220 | t.Fatal(err) 221 | } 222 | 223 | if num != 1234567890 { 224 | t.Fatal("unexpected result", num) 225 | } 226 | if text != "Héllö\n ☃!\r\t\\" { 227 | t.Fatal("unexpected result", text) 228 | } 229 | if bytes.Compare(blob, []byte{0, 255, 9, 10, 13}) != 0 { 230 | t.Fatal("unexpected result", blob) 231 | } 232 | if nothing.Valid { 233 | t.Fatal("unexpected result", nothing.String) 234 | } 235 | } 236 | 237 | func TestCopyInWrongType(t *testing.T) { 238 | db := openTestConn(t) 239 | defer db.Close() 240 | 241 | txn, err := db.Begin() 242 | if err != nil { 243 | t.Fatal(err) 244 | } 245 | defer txn.Rollback() 246 | 247 | _, err = txn.Exec("CREATE 
TEMP TABLE temp (num INTEGER)") 248 | if err != nil { 249 | t.Fatal(err) 250 | } 251 | 252 | stmt, err := txn.Prepare(CopyIn("temp", "num")) 253 | if err != nil { 254 | t.Fatal(err) 255 | } 256 | defer stmt.Close() 257 | 258 | _, err = stmt.Exec("Héllö\n ☃!\r\t\\") 259 | if err != nil { 260 | t.Fatal(err) 261 | } 262 | 263 | _, err = stmt.Exec() 264 | if err == nil { 265 | t.Fatal("expected error") 266 | } 267 | if pge, ok := err.(*Error); !ok || pge.Code.Name() != "invalid_text_representation" { 268 | t.Fatalf("expected 'invalid input syntax for integer' error, got %#+v", err) 269 | } 270 | } 271 | // TestCopyOutsideOfTxnError checks that preparing a COPY directly on the *sql.DB fails with errCopyNotSupportedOutsideTxn. 272 | func TestCopyOutsideOfTxnError(t *testing.T) { 273 | db := openTestConn(t) 274 | defer db.Close() 275 | 276 | _, err := db.Prepare(CopyIn("temp", "num")) 277 | if err == nil { 278 | t.Fatal("COPY outside of transaction did not return an error") 279 | } 280 | if err != errCopyNotSupportedOutsideTxn { 281 | t.Fatalf("expected %s, got %s", errCopyNotSupportedOutsideTxn, err) 282 | } 283 | } 284 | 285 | func TestCopyInBinaryError(t *testing.T) { 286 | db := openTestConn(t) 287 | defer db.Close() 288 | 289 | txn, err := db.Begin() 290 | if err != nil { 291 | t.Fatal(err) 292 | } 293 | defer txn.Rollback() 294 | 295 | _, err = txn.Exec("CREATE TEMP TABLE temp (num INTEGER)") 296 | if err != nil { 297 | t.Fatal(err) 298 | } 299 | _, err = txn.Prepare("COPY temp (num) FROM STDIN WITH binary") 300 | if err != errBinaryCopyNotSupported { 301 | t.Fatalf("expected %s, got %+v", errBinaryCopyNotSupported, err) 302 | } 303 | // check that the protocol is in a valid state 304 | err = txn.Rollback() 305 | if err != nil { 306 | t.Fatal(err) 307 | } 308 | } 309 | 310 | func TestCopyFromError(t *testing.T) { 311 | db := openTestConn(t) 312 | defer db.Close() 313 | 314 | txn, err := db.Begin() 315 | if err != nil { 316 | t.Fatal(err) 317 | } 318 | defer txn.Rollback() 319 | 320 | _, err = txn.Exec("CREATE TEMP TABLE temp (num INTEGER)") 321 | if err != nil { 322 | t.Fatal(err) 323 | } 324 |
_, err = txn.Prepare("COPY temp (num) TO STDOUT") 325 | if err != errCopyToNotSupported { 326 | t.Fatalf("expected %s, got %+v", errCopyToNotSupported, err) 327 | } 328 | // check that the protocol is in a valid state 329 | err = txn.Rollback() 330 | if err != nil { 331 | t.Fatal(err) 332 | } 333 | } 334 | // TestCopySyntaxError checks that a malformed COPY reports syntax_error and leaves the transaction usable. 335 | func TestCopySyntaxError(t *testing.T) { 336 | db := openTestConn(t) 337 | defer db.Close() 338 | 339 | txn, err := db.Begin() 340 | if err != nil { 341 | t.Fatal(err) 342 | } 343 | defer txn.Rollback() 344 | 345 | _, err = txn.Prepare("COPY ") 346 | if err == nil { 347 | t.Fatal("expected error") 348 | } 349 | if pge, ok := err.(*Error); !ok || pge.Code.Name() != "syntax_error" { 350 | t.Fatalf("expected syntax error, got %#+v", err) 351 | } 352 | // check that the protocol is in a valid state 353 | err = txn.Rollback() 354 | if err != nil { 355 | t.Fatal(err) 356 | } 357 | } 358 | 359 | // Tests for connection errors in copyin.resploop() 360 | func TestCopyRespLoopConnectionError(t *testing.T) { 361 | db := openTestConn(t) 362 | defer db.Close() 363 | 364 | txn, err := db.Begin() 365 | if err != nil { 366 | t.Fatal(err) 367 | } 368 | defer txn.Rollback() 369 | 370 | var pid int 371 | err = txn.QueryRow("SELECT pg_backend_pid()").Scan(&pid) 372 | if err != nil { 373 | t.Fatal(err) 374 | } 375 | 376 | _, err = txn.Exec("CREATE TEMP TABLE temp (a int)") 377 | if err != nil { 378 | t.Fatal(err) 379 | } 380 | 381 | stmt, err := txn.Prepare(CopyIn("temp", "a")) 382 | if err != nil { 383 | t.Fatal(err) 384 | } 385 | defer stmt.Close() 386 | 387 | _, err = db.Exec("SELECT pg_terminate_backend($1)", pid) 388 | if err != nil { 389 | t.Fatal(err) 390 | } 391 | 392 | if getServerVersion(t, db) < 90500 { 393 | // We have to try and send something over, since postgres before 394 | // version 9.5 won't process SIGTERMs while it's waiting for 395 | // CopyData/CopyEnd messages; see tcop/postgres.c.
396 | _, err = stmt.Exec(1) 397 | if err != nil { 398 | t.Fatal(err) 399 | } 400 | } 401 | _, err = stmt.Exec() 402 | if err == nil { 403 | t.Fatalf("expected error") 404 | } 405 | pge, ok := err.(*Error) 406 | if !ok { 407 | if err == driver.ErrBadConn { 408 | // likely an EPIPE 409 | } else { 410 | t.Fatalf("expected *pq.Error or driver.ErrBadConn, got %+#v", err) 411 | } 412 | } else if pge.Code.Name() != "admin_shutdown" { 413 | t.Fatalf("expected admin_shutdown, got %s", pge.Code.Name()) 414 | } 415 | 416 | _ = stmt.Close() 417 | } 418 | // BenchmarkCopyIn copies b.N rows through one CopyIn statement inside a transaction and checks the final row count. 419 | func BenchmarkCopyIn(b *testing.B) { 420 | db := openTestConn(b) 421 | defer db.Close() 422 | 423 | txn, err := db.Begin() 424 | if err != nil { 425 | b.Fatal(err) 426 | } 427 | defer txn.Rollback() 428 | 429 | _, err = txn.Exec("CREATE TEMP TABLE temp (a int, b varchar)") 430 | if err != nil { 431 | b.Fatal(err) 432 | } 433 | 434 | stmt, err := txn.Prepare(CopyIn("temp", "a", "b")) 435 | if err != nil { 436 | b.Fatal(err) 437 | } 438 | 439 | for i := 0; i < b.N; i++ { 440 | _, err = stmt.Exec(int64(i), "hello world!") 441 | if err != nil { 442 | b.Fatal(err) 443 | } 444 | } 445 | 446 | _, err = stmt.Exec() 447 | if err != nil { 448 | b.Fatal(err) 449 | } 450 | 451 | err = stmt.Close() 452 | if err != nil { 453 | b.Fatal(err) 454 | } 455 | 456 | var num int 457 | err = txn.QueryRow("SELECT COUNT(*) FROM temp").Scan(&num) 458 | if err != nil { 459 | b.Fatal(err) 460 | } 461 | 462 | if num != b.N { 463 | b.Fatalf("expected %d items, not %d", b.N, num) 464 | } 465 | } 466 | -------------------------------------------------------------------------------- /bench_test.go: -------------------------------------------------------------------------------- 1 | // +build go1.1 2 | 3 | package pq 4 | 5 | import ( 6 | "bufio" 7 | "bytes" 8 | "database/sql" 9 | "database/sql/driver" 10 | "io" 11 | "math/rand" 12 | "net" 13 | "runtime" 14 | "strconv" 15 | "strings" 16 | "sync" 17 | "testing" 18 | "time" 19 | 20 | "github.com/lib/pq/oid"
21 | ) 22 | 23 | var ( 24 | selectStringQuery = "SELECT '" + strings.Repeat("0123456789", 10) + "'" 25 | selectSeriesQuery = "SELECT generate_series(1, 100)" 26 | ) 27 | // BenchmarkSelectString measures a real query that returns a single 100-character text value. 28 | func BenchmarkSelectString(b *testing.B) { 29 | var result string 30 | benchQuery(b, selectStringQuery, &result) 31 | } 32 | 33 | func BenchmarkSelectSeries(b *testing.B) { 34 | var result int 35 | benchQuery(b, selectSeriesQuery, &result) 36 | } 37 | // benchQuery opens a test connection (untimed) and runs query b.N times, scanning each row into result. 38 | func benchQuery(b *testing.B, query string, result interface{}) { 39 | b.StopTimer() 40 | db := openTestConn(b) 41 | defer db.Close() 42 | b.StartTimer() 43 | 44 | for i := 0; i < b.N; i++ { 45 | benchQueryLoop(b, db, query, result) 46 | } 47 | } 48 | // benchQueryLoop executes query once and scans every returned row into result, which must be a pointer. 49 | func benchQueryLoop(b *testing.B, db *sql.DB, query string, result interface{}) { 50 | rows, err := db.Query(query) 51 | if err != nil { 52 | b.Fatal(err) 53 | } 54 | defer rows.Close() 55 | for rows.Next() { 56 | err = rows.Scan(result) 57 | if err != nil { 58 | b.Fatal("failed to scan", err) 59 | } 60 | } 61 | } 62 | 63 | // reading from circularConn yields content[:prefixLen] once, followed by 64 | // content[prefixLen:] over and over again. It never returns EOF.
65 | type circularConn struct { 66 | content string 67 | prefixLen int 68 | pos int 69 | net.Conn // for all other net.Conn methods that will never be called 70 | } 71 | 72 | func (r *circularConn) Read(b []byte) (n int, err error) { 73 | n = copy(b, r.content[r.pos:]) 74 | r.pos += n 75 | if r.pos >= len(r.content) { 76 | r.pos = r.prefixLen 77 | } 78 | return 79 | } 80 | 81 | func (r *circularConn) Write(b []byte) (n int, err error) { return len(b), nil } 82 | 83 | func (r *circularConn) Close() error { return nil } 84 | // fakeConn returns a *conn whose reads replay content via circularConn and whose writes are discarded. 85 | func fakeConn(content string, prefixLen int) *conn { 86 | c := &circularConn{content: content, prefixLen: prefixLen} 87 | return &conn{buf: bufio.NewReader(c), c: c} 88 | } 89 | 90 | // This benchmark is meant to be the same as BenchmarkSelectString, but takes 91 | // out some of the factors this package can't control. The numbers are less noisy, 92 | // but also the costs of network communication aren't accurately represented. 93 | func BenchmarkMockSelectString(b *testing.B) { 94 | b.StopTimer() 95 | // taken from a recorded run of BenchmarkSelectString 96 | // See: http://www.postgresql.org/docs/current/static/protocol-message-formats.html 97 | const response = "1\x00\x00\x00\x04" + 98 | "t\x00\x00\x00\x06\x00\x00" + 99 | "T\x00\x00\x00!\x00\x01?column?\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02\xc1\xff\xfe\xff\xff\xff\xff\x00\x00" + 100 | "Z\x00\x00\x00\x05I" + 101 | "2\x00\x00\x00\x04" + 102 | "D\x00\x00\x00n\x00\x01\x00\x00\x00d0123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789" + 103 | "C\x00\x00\x00\rSELECT 1\x00" + 104 | "Z\x00\x00\x00\x05I" + 105 | "3\x00\x00\x00\x04" + 106 | "Z\x00\x00\x00\x05I" 107 | c := fakeConn(response, 0) 108 | b.StartTimer() 109 | 110 | for i := 0; i < b.N; i++ { 111 | benchMockQuery(b, c, selectStringQuery) 112 | } 113 | } 114 | // seriesRowData holds one DataRow ('D') message per value 1..100, each a single text column, matching selectSeriesQuery. 115 | var seriesRowData = func() string { 116 | var buf bytes.Buffer 117 | for i := 1; i <= 100; i++ { 118 | digits := byte(2) 119 |
if i >= 100 { 120 | digits = 3 121 | } else if i < 10 { 122 | digits = 1 123 | } 124 | buf.WriteString("D\x00\x00\x00") 125 | buf.WriteByte(10 + digits) 126 | buf.WriteString("\x00\x01\x00\x00\x00") 127 | buf.WriteByte(digits) 128 | buf.WriteString(strconv.Itoa(i)) 129 | } 130 | return buf.String() 131 | }() 132 | // BenchmarkMockSelectSeries replays a canned response for selectSeriesQuery over a fake connection. 133 | func BenchmarkMockSelectSeries(b *testing.B) { 134 | b.StopTimer() 135 | var response = "1\x00\x00\x00\x04" + 136 | "t\x00\x00\x00\x06\x00\x00" + 137 | "T\x00\x00\x00!\x00\x01?column?\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02\xc1\xff\xfe\xff\xff\xff\xff\x00\x00" + 138 | "Z\x00\x00\x00\x05I" + 139 | "2\x00\x00\x00\x04" + 140 | seriesRowData + 141 | "C\x00\x00\x00\x0fSELECT 100\x00" + 142 | "Z\x00\x00\x00\x05I" + 143 | "3\x00\x00\x00\x04" + 144 | "Z\x00\x00\x00\x05I" 145 | c := fakeConn(response, 0) 146 | b.StartTimer() 147 | 148 | for i := 0; i < b.N; i++ { 149 | benchMockQuery(b, c, selectSeriesQuery) 150 | } 151 | } 152 | // benchMockQuery prepares query on the fake connection c and reads rows until io.EOF, failing b on any other error. 153 | func benchMockQuery(b *testing.B, c *conn, query string) { 154 | stmt, err := c.Prepare(query) 155 | if err != nil { 156 | b.Fatal(err) 157 | } 158 | defer stmt.Close() 159 | rows, err := stmt.Query(nil) 160 | if err != nil { 161 | b.Fatal(err) 162 | } 163 | defer rows.Close() 164 | var dest [1]driver.Value 165 | for { 166 | if err := rows.Next(dest[:]); err != nil { 167 | if err == io.EOF { 168 | break 169 | } 170 | b.Fatal(err) 171 | } 172 | } 173 | } 174 | 175 | func BenchmarkPreparedSelectString(b *testing.B) { 176 | var result string 177 | benchPreparedQuery(b, selectStringQuery, &result) 178 | } 179 | 180 | func BenchmarkPreparedSelectSeries(b *testing.B) { 181 | var result int 182 | benchPreparedQuery(b, selectSeriesQuery, &result) 183 | } 184 | // benchPreparedQuery prepares query once on a test connection (untimed) and then executes it b.N times. 185 | func benchPreparedQuery(b *testing.B, query string, result interface{}) { 186 | b.StopTimer() 187 | db := openTestConn(b) 188 | defer db.Close() 189 | stmt, err := db.Prepare(query) 190 | if err != nil { 191 | b.Fatal(err) 192 | } 193 | defer stmt.Close() 194 |
b.StartTimer() 195 | 196 | for i := 0; i < b.N; i++ { 197 | benchPreparedQueryLoop(b, db, stmt, result) 198 | } 199 | } 200 | // benchPreparedQueryLoop executes stmt once, scans every row into result (a pointer, as in benchQueryLoop), and fails b if no rows come back. 201 | func benchPreparedQueryLoop(b *testing.B, db *sql.DB, stmt *sql.Stmt, result interface{}) { 202 | rows, err := stmt.Query() 203 | if err != nil { 204 | b.Fatal(err) 205 | } 206 | defer rows.Close() 207 | scanned := false 208 | for rows.Next() { 209 | if err = rows.Scan(result); err != nil { 210 | b.Fatal("failed to scan", err) 211 | } 212 | scanned = true 213 | } 214 | if !scanned { 215 | b.Fatal("no rows") 216 | } 217 | } 218 | 219 | // See the comment for BenchmarkMockSelectString. 220 | func BenchmarkMockPreparedSelectString(b *testing.B) { 221 | b.StopTimer() 222 | const parseResponse = "1\x00\x00\x00\x04" + 223 | "t\x00\x00\x00\x06\x00\x00" + 224 | "T\x00\x00\x00!\x00\x01?column?\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02\xc1\xff\xfe\xff\xff\xff\xff\x00\x00" + 225 | "Z\x00\x00\x00\x05I" 226 | const responses = parseResponse + 227 | "2\x00\x00\x00\x04" + 228 | "D\x00\x00\x00n\x00\x01\x00\x00\x00d0123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789" + 229 | "C\x00\x00\x00\rSELECT 1\x00" + 230 | "Z\x00\x00\x00\x05I" 231 | c := fakeConn(responses, len(parseResponse)) 232 | 233 | stmt, err := c.Prepare(selectStringQuery) 234 | if err != nil { 235 | b.Fatal(err) 236 | } 237 | b.StartTimer() 238 | 239 | for i := 0; i < b.N; i++ { 240 | benchPreparedMockQuery(b, c, stmt) 241 | } 242 | } 243 | 244 | func BenchmarkMockPreparedSelectSeries(b *testing.B) { 245 | b.StopTimer() 246 | const parseResponse = "1\x00\x00\x00\x04" + 247 | "t\x00\x00\x00\x06\x00\x00" + 248 | "T\x00\x00\x00!\x00\x01?column?\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02\xc1\xff\xfe\xff\xff\xff\xff\x00\x00" + 249 | "Z\x00\x00\x00\x05I" 250 | var responses = parseResponse + 251 | "2\x00\x00\x00\x04" + 252 | seriesRowData + 253 | "C\x00\x00\x00\x0fSELECT 100\x00" + 254 | "Z\x00\x00\x00\x05I" 255 | c := fakeConn(responses, len(parseResponse)) 256 | 257 |
stmt, err := c.Prepare(selectSeriesQuery) 258 | if err != nil { 259 | b.Fatal(err) 260 | } 261 | b.StartTimer() 262 | 263 | for i := 0; i < b.N; i++ { 264 | benchPreparedMockQuery(b, c, stmt) 265 | } 266 | } 267 | 268 | func benchPreparedMockQuery(b *testing.B, c *conn, stmt driver.Stmt) { 269 | rows, err := stmt.Query(nil) 270 | if err != nil { 271 | b.Fatal(err) 272 | } 273 | defer rows.Close() 274 | var dest [1]driver.Value 275 | for { 276 | if err := rows.Next(dest[:]); err != nil { 277 | if err == io.EOF { 278 | break 279 | } 280 | b.Fatal(err) 281 | } 282 | } 283 | } 284 | 285 | func BenchmarkEncodeInt64(b *testing.B) { 286 | for i := 0; i < b.N; i++ { 287 | encode(¶meterStatus{}, int64(1234), oid.T_int8) 288 | } 289 | } 290 | 291 | func BenchmarkEncodeFloat64(b *testing.B) { 292 | for i := 0; i < b.N; i++ { 293 | encode(¶meterStatus{}, 3.14159, oid.T_float8) 294 | } 295 | } 296 | 297 | var testByteString = []byte("abcdefghijklmnopqrstuvwxyz") 298 | 299 | func BenchmarkEncodeByteaHex(b *testing.B) { 300 | for i := 0; i < b.N; i++ { 301 | encode(¶meterStatus{serverVersion: 90000}, testByteString, oid.T_bytea) 302 | } 303 | } 304 | func BenchmarkEncodeByteaEscape(b *testing.B) { 305 | for i := 0; i < b.N; i++ { 306 | encode(¶meterStatus{serverVersion: 84000}, testByteString, oid.T_bytea) 307 | } 308 | } 309 | 310 | func BenchmarkEncodeBool(b *testing.B) { 311 | for i := 0; i < b.N; i++ { 312 | encode(¶meterStatus{}, true, oid.T_bool) 313 | } 314 | } 315 | 316 | var testTimestamptz = time.Date(2001, time.January, 1, 0, 0, 0, 0, time.Local) 317 | 318 | func BenchmarkEncodeTimestamptz(b *testing.B) { 319 | for i := 0; i < b.N; i++ { 320 | encode(¶meterStatus{}, testTimestamptz, oid.T_timestamptz) 321 | } 322 | } 323 | 324 | var testIntBytes = []byte("1234") 325 | 326 | func BenchmarkDecodeInt64(b *testing.B) { 327 | for i := 0; i < b.N; i++ { 328 | decode(¶meterStatus{}, testIntBytes, oid.T_int8, formatText) 329 | } 330 | } 331 | 332 | var testFloatBytes = 
[]byte("3.14159") 333 | 334 | func BenchmarkDecodeFloat64(b *testing.B) { 335 | for i := 0; i < b.N; i++ { 336 | decode(¶meterStatus{}, testFloatBytes, oid.T_float8, formatText) 337 | } 338 | } 339 | 340 | var testBoolBytes = []byte{'t'} 341 | 342 | func BenchmarkDecodeBool(b *testing.B) { 343 | for i := 0; i < b.N; i++ { 344 | decode(¶meterStatus{}, testBoolBytes, oid.T_bool, formatText) 345 | } 346 | } 347 | 348 | func TestDecodeBool(t *testing.T) { 349 | db := openTestConn(t) 350 | rows, err := db.Query("select true") 351 | if err != nil { 352 | t.Fatal(err) 353 | } 354 | rows.Close() 355 | } 356 | 357 | var testTimestamptzBytes = []byte("2013-09-17 22:15:32.360754-07") 358 | 359 | func BenchmarkDecodeTimestamptz(b *testing.B) { 360 | for i := 0; i < b.N; i++ { 361 | decode(¶meterStatus{}, testTimestamptzBytes, oid.T_timestamptz, formatText) 362 | } 363 | } 364 | 365 | func BenchmarkDecodeTimestamptzMultiThread(b *testing.B) { 366 | oldProcs := runtime.GOMAXPROCS(0) 367 | defer runtime.GOMAXPROCS(oldProcs) 368 | runtime.GOMAXPROCS(runtime.NumCPU()) 369 | globalLocationCache = newLocationCache() 370 | 371 | f := func(wg *sync.WaitGroup, loops int) { 372 | defer wg.Done() 373 | for i := 0; i < loops; i++ { 374 | decode(¶meterStatus{}, testTimestamptzBytes, oid.T_timestamptz, formatText) 375 | } 376 | } 377 | 378 | wg := &sync.WaitGroup{} 379 | b.ResetTimer() 380 | for j := 0; j < 10; j++ { 381 | wg.Add(1) 382 | go f(wg, b.N/10) 383 | } 384 | wg.Wait() 385 | } 386 | 387 | func BenchmarkLocationCache(b *testing.B) { 388 | globalLocationCache = newLocationCache() 389 | for i := 0; i < b.N; i++ { 390 | globalLocationCache.getLocation(rand.Intn(10000)) 391 | } 392 | } 393 | 394 | func BenchmarkLocationCacheMultiThread(b *testing.B) { 395 | oldProcs := runtime.GOMAXPROCS(0) 396 | defer runtime.GOMAXPROCS(oldProcs) 397 | runtime.GOMAXPROCS(runtime.NumCPU()) 398 | globalLocationCache = newLocationCache() 399 | 400 | f := func(wg *sync.WaitGroup, loops int) { 401 | defer 
wg.Done() 402 | for i := 0; i < loops; i++ { 403 | globalLocationCache.getLocation(rand.Intn(10000)) 404 | } 405 | } 406 | 407 | wg := &sync.WaitGroup{} 408 | b.ResetTimer() 409 | for j := 0; j < 10; j++ { 410 | wg.Add(1) 411 | go f(wg, b.N/10) 412 | } 413 | wg.Wait() 414 | } 415 | 416 | // Stress test the performance of parsing results from the wire. 417 | func BenchmarkResultParsing(b *testing.B) { 418 | b.StopTimer() 419 | 420 | db := openTestConn(b) 421 | defer db.Close() 422 | _, err := db.Exec("BEGIN") 423 | if err != nil { 424 | b.Fatal(err) 425 | } 426 | 427 | b.StartTimer() 428 | for i := 0; i < b.N; i++ { 429 | res, err := db.Query("SELECT generate_series(1, 50000)") 430 | if err != nil { 431 | b.Fatal(err) 432 | } 433 | res.Close() 434 | } 435 | } 436 | -------------------------------------------------------------------------------- /notify_test.go: -------------------------------------------------------------------------------- 1 | package pq 2 | 3 | import ( 4 | "errors" 5 | "fmt" 6 | "io" 7 | "os" 8 | "runtime" 9 | "sync" 10 | "sync/atomic" 11 | "testing" 12 | "time" 13 | ) 14 | 15 | var errNilNotification = errors.New("nil notification") 16 | 17 | func expectNotification(t *testing.T, ch <-chan *Notification, relname string, extra string) error { 18 | select { 19 | case n := <-ch: 20 | if n == nil { 21 | return errNilNotification 22 | } 23 | if n.Channel != relname || n.Extra != extra { 24 | return fmt.Errorf("unexpected notification %v", n) 25 | } 26 | return nil 27 | case <-time.After(1500 * time.Millisecond): 28 | return fmt.Errorf("timeout") 29 | } 30 | } 31 | 32 | func expectNoNotification(t *testing.T, ch <-chan *Notification) error { 33 | select { 34 | case n := <-ch: 35 | return fmt.Errorf("unexpected notification %v", n) 36 | case <-time.After(100 * time.Millisecond): 37 | return nil 38 | } 39 | } 40 | 41 | func expectEvent(t *testing.T, eventch <-chan ListenerEventType, et ListenerEventType) error { 42 | select { 43 | case e := <-eventch: 
44 | if e != et { 45 | return fmt.Errorf("unexpected event %v", e) 46 | } 47 | return nil 48 | case <-time.After(1500 * time.Millisecond): 49 | panic("expectEvent timeout") 50 | } 51 | } 52 | 53 | func expectNoEvent(t *testing.T, eventch <-chan ListenerEventType) error { 54 | select { 55 | case e := <-eventch: 56 | return fmt.Errorf("unexpected event %v", e) 57 | case <-time.After(100 * time.Millisecond): 58 | return nil 59 | } 60 | } 61 | 62 | func newTestListenerConn(t *testing.T) (*ListenerConn, <-chan *Notification) { 63 | datname := os.Getenv("PGDATABASE") 64 | sslmode := os.Getenv("PGSSLMODE") 65 | 66 | if datname == "" { 67 | os.Setenv("PGDATABASE", "pqgotest") 68 | } 69 | 70 | if sslmode == "" { 71 | os.Setenv("PGSSLMODE", "disable") 72 | } 73 | 74 | notificationChan := make(chan *Notification) 75 | l, err := NewListenerConn("", notificationChan) 76 | if err != nil { 77 | t.Fatal(err) 78 | } 79 | 80 | return l, notificationChan 81 | } 82 | 83 | func TestNewListenerConn(t *testing.T) { 84 | l, _ := newTestListenerConn(t) 85 | 86 | defer l.Close() 87 | } 88 | 89 | func TestConnListen(t *testing.T) { 90 | l, channel := newTestListenerConn(t) 91 | 92 | defer l.Close() 93 | 94 | db := openTestConn(t) 95 | defer db.Close() 96 | 97 | ok, err := l.Listen("notify_test") 98 | if !ok || err != nil { 99 | t.Fatal(err) 100 | } 101 | 102 | _, err = db.Exec("NOTIFY notify_test") 103 | if err != nil { 104 | t.Fatal(err) 105 | } 106 | 107 | err = expectNotification(t, channel, "notify_test", "") 108 | if err != nil { 109 | t.Fatal(err) 110 | } 111 | } 112 | 113 | func TestConnUnlisten(t *testing.T) { 114 | l, channel := newTestListenerConn(t) 115 | 116 | defer l.Close() 117 | 118 | db := openTestConn(t) 119 | defer db.Close() 120 | 121 | ok, err := l.Listen("notify_test") 122 | if !ok || err != nil { 123 | t.Fatal(err) 124 | } 125 | 126 | _, err = db.Exec("NOTIFY notify_test") 127 | 128 | err = expectNotification(t, channel, "notify_test", "") 129 | if err != nil { 130 | 
t.Fatal(err) 131 | } 132 | 133 | ok, err = l.Unlisten("notify_test") 134 | if !ok || err != nil { 135 | t.Fatal(err) 136 | } 137 | 138 | _, err = db.Exec("NOTIFY notify_test") 139 | if err != nil { 140 | t.Fatal(err) 141 | } 142 | 143 | err = expectNoNotification(t, channel) 144 | if err != nil { 145 | t.Fatal(err) 146 | } 147 | } 148 | 149 | func TestConnUnlistenAll(t *testing.T) { 150 | l, channel := newTestListenerConn(t) 151 | 152 | defer l.Close() 153 | 154 | db := openTestConn(t) 155 | defer db.Close() 156 | 157 | ok, err := l.Listen("notify_test") 158 | if !ok || err != nil { 159 | t.Fatal(err) 160 | } 161 | 162 | _, err = db.Exec("NOTIFY notify_test") 163 | 164 | err = expectNotification(t, channel, "notify_test", "") 165 | if err != nil { 166 | t.Fatal(err) 167 | } 168 | 169 | ok, err = l.UnlistenAll() 170 | if !ok || err != nil { 171 | t.Fatal(err) 172 | } 173 | 174 | _, err = db.Exec("NOTIFY notify_test") 175 | if err != nil { 176 | t.Fatal(err) 177 | } 178 | 179 | err = expectNoNotification(t, channel) 180 | if err != nil { 181 | t.Fatal(err) 182 | } 183 | } 184 | 185 | func TestConnClose(t *testing.T) { 186 | l, _ := newTestListenerConn(t) 187 | defer l.Close() 188 | 189 | err := l.Close() 190 | if err != nil { 191 | t.Fatal(err) 192 | } 193 | err = l.Close() 194 | if err != errListenerConnClosed { 195 | t.Fatalf("expected errListenerConnClosed; got %v", err) 196 | } 197 | } 198 | 199 | func TestConnPing(t *testing.T) { 200 | l, _ := newTestListenerConn(t) 201 | defer l.Close() 202 | err := l.Ping() 203 | if err != nil { 204 | t.Fatal(err) 205 | } 206 | err = l.Close() 207 | if err != nil { 208 | t.Fatal(err) 209 | } 210 | err = l.Ping() 211 | if err != errListenerConnClosed { 212 | t.Fatalf("expected errListenerConnClosed; got %v", err) 213 | } 214 | } 215 | 216 | // Test for deadlock where a query fails while another one is queued 217 | func TestConnExecDeadlock(t *testing.T) { 218 | l, _ := newTestListenerConn(t) 219 | defer l.Close() 220 | 221 | 
var wg sync.WaitGroup 222 | wg.Add(2) 223 | 224 | go func() { 225 | l.ExecSimpleQuery("SELECT pg_sleep(60)") 226 | wg.Done() 227 | }() 228 | runtime.Gosched() 229 | go func() { 230 | l.ExecSimpleQuery("SELECT 1") 231 | wg.Done() 232 | }() 233 | // give the two goroutines some time to get into position 234 | runtime.Gosched() 235 | // calls Close on the net.Conn; equivalent to a network failure 236 | l.Close() 237 | 238 | var done int32 = 0 239 | go func() { 240 | time.Sleep(10 * time.Second) 241 | if atomic.LoadInt32(&done) != 1 { 242 | panic("timed out") 243 | } 244 | }() 245 | wg.Wait() 246 | atomic.StoreInt32(&done, 1) 247 | } 248 | 249 | // Test for ListenerConn being closed while a slow query is executing 250 | func TestListenerConnCloseWhileQueryIsExecuting(t *testing.T) { 251 | l, _ := newTestListenerConn(t) 252 | defer l.Close() 253 | 254 | var wg sync.WaitGroup 255 | wg.Add(1) 256 | 257 | go func() { 258 | sent, err := l.ExecSimpleQuery("SELECT pg_sleep(60)") 259 | if sent { 260 | panic("expected sent=false") 261 | } 262 | // could be any of a number of errors 263 | if err == nil { 264 | panic("expected error") 265 | } 266 | wg.Done() 267 | }() 268 | // give the above goroutine some time to get into position 269 | runtime.Gosched() 270 | err := l.Close() 271 | if err != nil { 272 | t.Fatal(err) 273 | } 274 | var done int32 = 0 275 | go func() { 276 | time.Sleep(10 * time.Second) 277 | if atomic.LoadInt32(&done) != 1 { 278 | panic("timed out") 279 | } 280 | }() 281 | wg.Wait() 282 | atomic.StoreInt32(&done, 1) 283 | } 284 | 285 | func TestNotifyExtra(t *testing.T) { 286 | db := openTestConn(t) 287 | defer db.Close() 288 | 289 | if getServerVersion(t, db) < 90000 { 290 | t.Skip("skipping NOTIFY payload test since the server does not appear to support it") 291 | } 292 | 293 | l, channel := newTestListenerConn(t) 294 | defer l.Close() 295 | 296 | ok, err := l.Listen("notify_test") 297 | if !ok || err != nil { 298 | t.Fatal(err) 299 | } 300 | 301 | _, err = 
db.Exec("NOTIFY notify_test, 'something'") 302 | if err != nil { 303 | t.Fatal(err) 304 | } 305 | 306 | err = expectNotification(t, channel, "notify_test", "something") 307 | if err != nil { 308 | t.Fatal(err) 309 | } 310 | } 311 | 312 | // create a new test listener and also set the timeouts 313 | func newTestListenerTimeout(t *testing.T, min time.Duration, max time.Duration) (*Listener, <-chan ListenerEventType) { 314 | datname := os.Getenv("PGDATABASE") 315 | sslmode := os.Getenv("PGSSLMODE") 316 | 317 | if datname == "" { 318 | os.Setenv("PGDATABASE", "pqgotest") 319 | } 320 | 321 | if sslmode == "" { 322 | os.Setenv("PGSSLMODE", "disable") 323 | } 324 | 325 | eventch := make(chan ListenerEventType, 16) 326 | l := NewListener("", min, max, func(t ListenerEventType, err error) { eventch <- t }) 327 | err := expectEvent(t, eventch, ListenerEventConnected) 328 | if err != nil { 329 | t.Fatal(err) 330 | } 331 | return l, eventch 332 | } 333 | 334 | func newTestListener(t *testing.T) (*Listener, <-chan ListenerEventType) { 335 | return newTestListenerTimeout(t, time.Hour, time.Hour) 336 | } 337 | 338 | func TestListenerListen(t *testing.T) { 339 | l, _ := newTestListener(t) 340 | defer l.Close() 341 | 342 | db := openTestConn(t) 343 | defer db.Close() 344 | 345 | err := l.Listen("notify_listen_test") 346 | if err != nil { 347 | t.Fatal(err) 348 | } 349 | 350 | _, err = db.Exec("NOTIFY notify_listen_test") 351 | if err != nil { 352 | t.Fatal(err) 353 | } 354 | 355 | err = expectNotification(t, l.Notify, "notify_listen_test", "") 356 | if err != nil { 357 | t.Fatal(err) 358 | } 359 | } 360 | 361 | func TestListenerUnlisten(t *testing.T) { 362 | l, _ := newTestListener(t) 363 | defer l.Close() 364 | 365 | db := openTestConn(t) 366 | defer db.Close() 367 | 368 | err := l.Listen("notify_listen_test") 369 | if err != nil { 370 | t.Fatal(err) 371 | } 372 | 373 | _, err = db.Exec("NOTIFY notify_listen_test") 374 | if err != nil { 375 | t.Fatal(err) 376 | } 377 | 378 | err = 
l.Unlisten("notify_listen_test") 379 | if err != nil { 380 | t.Fatal(err) 381 | } 382 | 383 | err = expectNotification(t, l.Notify, "notify_listen_test", "") 384 | if err != nil { 385 | t.Fatal(err) 386 | } 387 | 388 | _, err = db.Exec("NOTIFY notify_listen_test") 389 | if err != nil { 390 | t.Fatal(err) 391 | } 392 | 393 | err = expectNoNotification(t, l.Notify) 394 | if err != nil { 395 | t.Fatal(err) 396 | } 397 | } 398 | 399 | func TestListenerUnlistenAll(t *testing.T) { 400 | l, _ := newTestListener(t) 401 | defer l.Close() 402 | 403 | db := openTestConn(t) 404 | defer db.Close() 405 | 406 | err := l.Listen("notify_listen_test") 407 | if err != nil { 408 | t.Fatal(err) 409 | } 410 | 411 | _, err = db.Exec("NOTIFY notify_listen_test") 412 | if err != nil { 413 | t.Fatal(err) 414 | } 415 | 416 | err = l.UnlistenAll() 417 | if err != nil { 418 | t.Fatal(err) 419 | } 420 | 421 | err = expectNotification(t, l.Notify, "notify_listen_test", "") 422 | if err != nil { 423 | t.Fatal(err) 424 | } 425 | 426 | _, err = db.Exec("NOTIFY notify_listen_test") 427 | if err != nil { 428 | t.Fatal(err) 429 | } 430 | 431 | err = expectNoNotification(t, l.Notify) 432 | if err != nil { 433 | t.Fatal(err) 434 | } 435 | } 436 | 437 | func TestListenerFailedQuery(t *testing.T) { 438 | l, eventch := newTestListener(t) 439 | defer l.Close() 440 | 441 | db := openTestConn(t) 442 | defer db.Close() 443 | 444 | err := l.Listen("notify_listen_test") 445 | if err != nil { 446 | t.Fatal(err) 447 | } 448 | 449 | _, err = db.Exec("NOTIFY notify_listen_test") 450 | if err != nil { 451 | t.Fatal(err) 452 | } 453 | 454 | err = expectNotification(t, l.Notify, "notify_listen_test", "") 455 | if err != nil { 456 | t.Fatal(err) 457 | } 458 | 459 | // shouldn't cause a disconnect 460 | ok, err := l.cn.ExecSimpleQuery("SELECT error") 461 | if !ok { 462 | t.Fatalf("could not send query to server: %v", err) 463 | } 464 | _, ok = err.(PGError) 465 | if !ok { 466 | t.Fatalf("unexpected error %v", err) 467 | 
} 468 | err = expectNoEvent(t, eventch) 469 | if err != nil { 470 | t.Fatal(err) 471 | } 472 | 473 | // should still work 474 | _, err = db.Exec("NOTIFY notify_listen_test") 475 | if err != nil { 476 | t.Fatal(err) 477 | } 478 | 479 | err = expectNotification(t, l.Notify, "notify_listen_test", "") 480 | if err != nil { 481 | t.Fatal(err) 482 | } 483 | } 484 | 485 | func TestListenerReconnect(t *testing.T) { 486 | l, eventch := newTestListenerTimeout(t, 20*time.Millisecond, time.Hour) 487 | defer l.Close() 488 | 489 | db := openTestConn(t) 490 | defer db.Close() 491 | 492 | err := l.Listen("notify_listen_test") 493 | if err != nil { 494 | t.Fatal(err) 495 | } 496 | 497 | _, err = db.Exec("NOTIFY notify_listen_test") 498 | if err != nil { 499 | t.Fatal(err) 500 | } 501 | 502 | err = expectNotification(t, l.Notify, "notify_listen_test", "") 503 | if err != nil { 504 | t.Fatal(err) 505 | } 506 | 507 | // kill the connection and make sure it comes back up 508 | ok, err := l.cn.ExecSimpleQuery("SELECT pg_terminate_backend(pg_backend_pid())") 509 | if ok { 510 | t.Fatalf("could not kill the connection: %v", err) 511 | } 512 | if err != io.EOF { 513 | t.Fatalf("unexpected error %v", err) 514 | } 515 | err = expectEvent(t, eventch, ListenerEventDisconnected) 516 | if err != nil { 517 | t.Fatal(err) 518 | } 519 | err = expectEvent(t, eventch, ListenerEventReconnected) 520 | if err != nil { 521 | t.Fatal(err) 522 | } 523 | 524 | // should still work 525 | _, err = db.Exec("NOTIFY notify_listen_test") 526 | if err != nil { 527 | t.Fatal(err) 528 | } 529 | 530 | // should get nil after Reconnected 531 | err = expectNotification(t, l.Notify, "", "") 532 | if err != errNilNotification { 533 | t.Fatal(err) 534 | } 535 | 536 | err = expectNotification(t, l.Notify, "notify_listen_test", "") 537 | if err != nil { 538 | t.Fatal(err) 539 | } 540 | } 541 | 542 | func TestListenerClose(t *testing.T) { 543 | l, _ := newTestListenerTimeout(t, 20*time.Millisecond, time.Hour) 544 | defer 
l.Close() 545 | 546 | err := l.Close() 547 | if err != nil { 548 | t.Fatal(err) 549 | } 550 | err = l.Close() 551 | if err != errListenerClosed { 552 | t.Fatalf("expected errListenerClosed; got %v", err) 553 | } 554 | } 555 | 556 | func TestListenerPing(t *testing.T) { 557 | l, _ := newTestListenerTimeout(t, 20*time.Millisecond, time.Hour) 558 | defer l.Close() 559 | 560 | err := l.Ping() 561 | if err != nil { 562 | t.Fatal(err) 563 | } 564 | 565 | err = l.Close() 566 | if err != nil { 567 | t.Fatal(err) 568 | } 569 | 570 | err = l.Ping() 571 | if err != errListenerClosed { 572 | t.Fatalf("expected errListenerClosed; got %v", err) 573 | } 574 | } 575 | -------------------------------------------------------------------------------- /error.go: -------------------------------------------------------------------------------- 1 | package pq 2 | 3 | import ( 4 | "database/sql/driver" 5 | "fmt" 6 | "io" 7 | "net" 8 | "runtime" 9 | ) 10 | 11 | // Error severities 12 | const ( 13 | Efatal = "FATAL" 14 | Epanic = "PANIC" 15 | Ewarning = "WARNING" 16 | Enotice = "NOTICE" 17 | Edebug = "DEBUG" 18 | Einfo = "INFO" 19 | Elog = "LOG" 20 | ) 21 | 22 | // Error represents an error communicating with the server. 23 | // 24 | // See http://www.postgresql.org/docs/current/static/protocol-error-fields.html for details of the fields 25 | type Error struct { 26 | Severity string 27 | Code ErrorCode 28 | Message string 29 | Detail string 30 | Hint string 31 | Position string 32 | InternalPosition string 33 | InternalQuery string 34 | Where string 35 | Schema string 36 | Table string 37 | Column string 38 | DataTypeName string 39 | Constraint string 40 | File string 41 | Line string 42 | Routine string 43 | } 44 | 45 | // ErrorCode is a five-character error code. 46 | type ErrorCode string 47 | 48 | // Name returns a more human friendly rendering of the error code, namely the 49 | // "condition name". 
50 | // 51 | // See http://www.postgresql.org/docs/9.3/static/errcodes-appendix.html for 52 | // details. 53 | func (ec ErrorCode) Name() string { 54 | return errorCodeNames[ec] 55 | } 56 | 57 | // ErrorClass is only the class part of an error code. 58 | type ErrorClass string 59 | 60 | // Name returns the condition name of an error class. It is equivalent to the 61 | // condition name of the "standard" error code (i.e. the one having the last 62 | // three characters "000"). 63 | func (ec ErrorClass) Name() string { 64 | return errorCodeNames[ErrorCode(ec+"000")] 65 | } 66 | 67 | // Class returns the error class, e.g. "28". 68 | // 69 | // See http://www.postgresql.org/docs/9.3/static/errcodes-appendix.html for 70 | // details. 71 | func (ec ErrorCode) Class() ErrorClass { 72 | return ErrorClass(ec[0:2]) 73 | } 74 | 75 | // errorCodeNames is a mapping between the five-character error codes and the 76 | // human readable "condition names". It is derived from the list at 77 | // http://www.postgresql.org/docs/9.3/static/errcodes-appendix.html 78 | var errorCodeNames = map[ErrorCode]string{ 79 | // Class 00 - Successful Completion 80 | "00000": "successful_completion", 81 | // Class 01 - Warning 82 | "01000": "warning", 83 | "0100C": "dynamic_result_sets_returned", 84 | "01008": "implicit_zero_bit_padding", 85 | "01003": "null_value_eliminated_in_set_function", 86 | "01007": "privilege_not_granted", 87 | "01006": "privilege_not_revoked", 88 | "01004": "string_data_right_truncation", 89 | "01P01": "deprecated_feature", 90 | // Class 02 - No Data (this is also a warning class per the SQL standard) 91 | "02000": "no_data", 92 | "02001": "no_additional_dynamic_result_sets_returned", 93 | // Class 03 - SQL Statement Not Yet Complete 94 | "03000": "sql_statement_not_yet_complete", 95 | // Class 08 - Connection Exception 96 | "08000": "connection_exception", 97 | "08003": "connection_does_not_exist", 98 | "08006": "connection_failure", 99 | "08001": 
"sqlclient_unable_to_establish_sqlconnection", 100 | "08004": "sqlserver_rejected_establishment_of_sqlconnection", 101 | "08007": "transaction_resolution_unknown", 102 | "08P01": "protocol_violation", 103 | // Class 09 - Triggered Action Exception 104 | "09000": "triggered_action_exception", 105 | // Class 0A - Feature Not Supported 106 | "0A000": "feature_not_supported", 107 | // Class 0B - Invalid Transaction Initiation 108 | "0B000": "invalid_transaction_initiation", 109 | // Class 0F - Locator Exception 110 | "0F000": "locator_exception", 111 | "0F001": "invalid_locator_specification", 112 | // Class 0L - Invalid Grantor 113 | "0L000": "invalid_grantor", 114 | "0LP01": "invalid_grant_operation", 115 | // Class 0P - Invalid Role Specification 116 | "0P000": "invalid_role_specification", 117 | // Class 0Z - Diagnostics Exception 118 | "0Z000": "diagnostics_exception", 119 | "0Z002": "stacked_diagnostics_accessed_without_active_handler", 120 | // Class 20 - Case Not Found 121 | "20000": "case_not_found", 122 | // Class 21 - Cardinality Violation 123 | "21000": "cardinality_violation", 124 | // Class 22 - Data Exception 125 | "22000": "data_exception", 126 | "2202E": "array_subscript_error", 127 | "22021": "character_not_in_repertoire", 128 | "22008": "datetime_field_overflow", 129 | "22012": "division_by_zero", 130 | "22005": "error_in_assignment", 131 | "2200B": "escape_character_conflict", 132 | "22022": "indicator_overflow", 133 | "22015": "interval_field_overflow", 134 | "2201E": "invalid_argument_for_logarithm", 135 | "22014": "invalid_argument_for_ntile_function", 136 | "22016": "invalid_argument_for_nth_value_function", 137 | "2201F": "invalid_argument_for_power_function", 138 | "2201G": "invalid_argument_for_width_bucket_function", 139 | "22018": "invalid_character_value_for_cast", 140 | "22007": "invalid_datetime_format", 141 | "22019": "invalid_escape_character", 142 | "2200D": "invalid_escape_octet", 143 | "22025": "invalid_escape_sequence", 144 | 
"22P06": "nonstandard_use_of_escape_character", 145 | "22010": "invalid_indicator_parameter_value", 146 | "22023": "invalid_parameter_value", 147 | "2201B": "invalid_regular_expression", 148 | "2201W": "invalid_row_count_in_limit_clause", 149 | "2201X": "invalid_row_count_in_result_offset_clause", 150 | "22009": "invalid_time_zone_displacement_value", 151 | "2200C": "invalid_use_of_escape_character", 152 | "2200G": "most_specific_type_mismatch", 153 | "22004": "null_value_not_allowed", 154 | "22002": "null_value_no_indicator_parameter", 155 | "22003": "numeric_value_out_of_range", 156 | "22026": "string_data_length_mismatch", 157 | "22001": "string_data_right_truncation", 158 | "22011": "substring_error", 159 | "22027": "trim_error", 160 | "22024": "unterminated_c_string", 161 | "2200F": "zero_length_character_string", 162 | "22P01": "floating_point_exception", 163 | "22P02": "invalid_text_representation", 164 | "22P03": "invalid_binary_representation", 165 | "22P04": "bad_copy_file_format", 166 | "22P05": "untranslatable_character", 167 | "2200L": "not_an_xml_document", 168 | "2200M": "invalid_xml_document", 169 | "2200N": "invalid_xml_content", 170 | "2200S": "invalid_xml_comment", 171 | "2200T": "invalid_xml_processing_instruction", 172 | // Class 23 - Integrity Constraint Violation 173 | "23000": "integrity_constraint_violation", 174 | "23001": "restrict_violation", 175 | "23502": "not_null_violation", 176 | "23503": "foreign_key_violation", 177 | "23505": "unique_violation", 178 | "23514": "check_violation", 179 | "23P01": "exclusion_violation", 180 | // Class 24 - Invalid Cursor State 181 | "24000": "invalid_cursor_state", 182 | // Class 25 - Invalid Transaction State 183 | "25000": "invalid_transaction_state", 184 | "25001": "active_sql_transaction", 185 | "25002": "branch_transaction_already_active", 186 | "25008": "held_cursor_requires_same_isolation_level", 187 | "25003": "inappropriate_access_mode_for_branch_transaction", 188 | "25004": 
"inappropriate_isolation_level_for_branch_transaction", 189 | "25005": "no_active_sql_transaction_for_branch_transaction", 190 | "25006": "read_only_sql_transaction", 191 | "25007": "schema_and_data_statement_mixing_not_supported", 192 | "25P01": "no_active_sql_transaction", 193 | "25P02": "in_failed_sql_transaction", 194 | // Class 26 - Invalid SQL Statement Name 195 | "26000": "invalid_sql_statement_name", 196 | // Class 27 - Triggered Data Change Violation 197 | "27000": "triggered_data_change_violation", 198 | // Class 28 - Invalid Authorization Specification 199 | "28000": "invalid_authorization_specification", 200 | "28P01": "invalid_password", 201 | // Class 2B - Dependent Privilege Descriptors Still Exist 202 | "2B000": "dependent_privilege_descriptors_still_exist", 203 | "2BP01": "dependent_objects_still_exist", 204 | // Class 2D - Invalid Transaction Termination 205 | "2D000": "invalid_transaction_termination", 206 | // Class 2F - SQL Routine Exception 207 | "2F000": "sql_routine_exception", 208 | "2F005": "function_executed_no_return_statement", 209 | "2F002": "modifying_sql_data_not_permitted", 210 | "2F003": "prohibited_sql_statement_attempted", 211 | "2F004": "reading_sql_data_not_permitted", 212 | // Class 34 - Invalid Cursor Name 213 | "34000": "invalid_cursor_name", 214 | // Class 38 - External Routine Exception 215 | "38000": "external_routine_exception", 216 | "38001": "containing_sql_not_permitted", 217 | "38002": "modifying_sql_data_not_permitted", 218 | "38003": "prohibited_sql_statement_attempted", 219 | "38004": "reading_sql_data_not_permitted", 220 | // Class 39 - External Routine Invocation Exception 221 | "39000": "external_routine_invocation_exception", 222 | "39001": "invalid_sqlstate_returned", 223 | "39004": "null_value_not_allowed", 224 | "39P01": "trigger_protocol_violated", 225 | "39P02": "srf_protocol_violated", 226 | // Class 3B - Savepoint Exception 227 | "3B000": "savepoint_exception", 228 | "3B001": 
"invalid_savepoint_specification", 229 | // Class 3D - Invalid Catalog Name 230 | "3D000": "invalid_catalog_name", 231 | // Class 3F - Invalid Schema Name 232 | "3F000": "invalid_schema_name", 233 | // Class 40 - Transaction Rollback 234 | "40000": "transaction_rollback", 235 | "40002": "transaction_integrity_constraint_violation", 236 | "40001": "serialization_failure", 237 | "40003": "statement_completion_unknown", 238 | "40P01": "deadlock_detected", 239 | // Class 42 - Syntax Error or Access Rule Violation 240 | "42000": "syntax_error_or_access_rule_violation", 241 | "42601": "syntax_error", 242 | "42501": "insufficient_privilege", 243 | "42846": "cannot_coerce", 244 | "42803": "grouping_error", 245 | "42P20": "windowing_error", 246 | "42P19": "invalid_recursion", 247 | "42830": "invalid_foreign_key", 248 | "42602": "invalid_name", 249 | "42622": "name_too_long", 250 | "42939": "reserved_name", 251 | "42804": "datatype_mismatch", 252 | "42P18": "indeterminate_datatype", 253 | "42P21": "collation_mismatch", 254 | "42P22": "indeterminate_collation", 255 | "42809": "wrong_object_type", 256 | "42703": "undefined_column", 257 | "42883": "undefined_function", 258 | "42P01": "undefined_table", 259 | "42P02": "undefined_parameter", 260 | "42704": "undefined_object", 261 | "42701": "duplicate_column", 262 | "42P03": "duplicate_cursor", 263 | "42P04": "duplicate_database", 264 | "42723": "duplicate_function", 265 | "42P05": "duplicate_prepared_statement", 266 | "42P06": "duplicate_schema", 267 | "42P07": "duplicate_table", 268 | "42712": "duplicate_alias", 269 | "42710": "duplicate_object", 270 | "42702": "ambiguous_column", 271 | "42725": "ambiguous_function", 272 | "42P08": "ambiguous_parameter", 273 | "42P09": "ambiguous_alias", 274 | "42P10": "invalid_column_reference", 275 | "42611": "invalid_column_definition", 276 | "42P11": "invalid_cursor_definition", 277 | "42P12": "invalid_database_definition", 278 | "42P13": "invalid_function_definition", 279 | "42P14": 
"invalid_prepared_statement_definition", 280 | "42P15": "invalid_schema_definition", 281 | "42P16": "invalid_table_definition", 282 | "42P17": "invalid_object_definition", 283 | // Class 44 - WITH CHECK OPTION Violation 284 | "44000": "with_check_option_violation", 285 | // Class 53 - Insufficient Resources 286 | "53000": "insufficient_resources", 287 | "53100": "disk_full", 288 | "53200": "out_of_memory", 289 | "53300": "too_many_connections", 290 | "53400": "configuration_limit_exceeded", 291 | // Class 54 - Program Limit Exceeded 292 | "54000": "program_limit_exceeded", 293 | "54001": "statement_too_complex", 294 | "54011": "too_many_columns", 295 | "54023": "too_many_arguments", 296 | // Class 55 - Object Not In Prerequisite State 297 | "55000": "object_not_in_prerequisite_state", 298 | "55006": "object_in_use", 299 | "55P02": "cant_change_runtime_param", 300 | "55P03": "lock_not_available", 301 | // Class 57 - Operator Intervention 302 | "57000": "operator_intervention", 303 | "57014": "query_canceled", 304 | "57P01": "admin_shutdown", 305 | "57P02": "crash_shutdown", 306 | "57P03": "cannot_connect_now", 307 | "57P04": "database_dropped", 308 | // Class 58 - System Error (errors external to PostgreSQL itself) 309 | "58000": "system_error", 310 | "58030": "io_error", 311 | "58P01": "undefined_file", 312 | "58P02": "duplicate_file", 313 | // Class F0 - Configuration File Error 314 | "F0000": "config_file_error", 315 | "F0001": "lock_file_exists", 316 | // Class HV - Foreign Data Wrapper Error (SQL/MED) 317 | "HV000": "fdw_error", 318 | "HV005": "fdw_column_name_not_found", 319 | "HV002": "fdw_dynamic_parameter_value_needed", 320 | "HV010": "fdw_function_sequence_error", 321 | "HV021": "fdw_inconsistent_descriptor_information", 322 | "HV024": "fdw_invalid_attribute_value", 323 | "HV007": "fdw_invalid_column_name", 324 | "HV008": "fdw_invalid_column_number", 325 | "HV004": "fdw_invalid_data_type", 326 | "HV006": "fdw_invalid_data_type_descriptors", 327 | "HV091": 
"fdw_invalid_descriptor_field_identifier", 328 | "HV00B": "fdw_invalid_handle", 329 | "HV00C": "fdw_invalid_option_index", 330 | "HV00D": "fdw_invalid_option_name", 331 | "HV090": "fdw_invalid_string_length_or_buffer_length", 332 | "HV00A": "fdw_invalid_string_format", 333 | "HV009": "fdw_invalid_use_of_null_pointer", 334 | "HV014": "fdw_too_many_handles", 335 | "HV001": "fdw_out_of_memory", 336 | "HV00P": "fdw_no_schemas", 337 | "HV00J": "fdw_option_name_not_found", 338 | "HV00K": "fdw_reply_handle", 339 | "HV00Q": "fdw_schema_not_found", 340 | "HV00R": "fdw_table_not_found", 341 | "HV00L": "fdw_unable_to_create_execution", 342 | "HV00M": "fdw_unable_to_create_reply", 343 | "HV00N": "fdw_unable_to_establish_connection", 344 | // Class P0 - PL/pgSQL Error 345 | "P0000": "plpgsql_error", 346 | "P0001": "raise_exception", 347 | "P0002": "no_data_found", 348 | "P0003": "too_many_rows", 349 | // Class XX - Internal Error 350 | "XX000": "internal_error", 351 | "XX001": "data_corrupted", 352 | "XX002": "index_corrupted", 353 | } 354 | 355 | func parseError(r *readBuf) *Error { 356 | err := new(Error) 357 | for t := r.byte(); t != 0; t = r.byte() { 358 | msg := r.string() 359 | switch t { 360 | case 'S': 361 | err.Severity = msg 362 | case 'C': 363 | err.Code = ErrorCode(msg) 364 | case 'M': 365 | err.Message = msg 366 | case 'D': 367 | err.Detail = msg 368 | case 'H': 369 | err.Hint = msg 370 | case 'P': 371 | err.Position = msg 372 | case 'p': 373 | err.InternalPosition = msg 374 | case 'q': 375 | err.InternalQuery = msg 376 | case 'W': 377 | err.Where = msg 378 | case 's': 379 | err.Schema = msg 380 | case 't': 381 | err.Table = msg 382 | case 'c': 383 | err.Column = msg 384 | case 'd': 385 | err.DataTypeName = msg 386 | case 'n': 387 | err.Constraint = msg 388 | case 'F': 389 | err.File = msg 390 | case 'L': 391 | err.Line = msg 392 | case 'R': 393 | err.Routine = msg 394 | } 395 | } 396 | return err 397 | } 398 | 399 | // Fatal returns true if the Error Severity is 
fatal. 400 | func (err *Error) Fatal() bool { 401 | return err.Severity == Efatal 402 | } 403 | 404 | // Get implements the legacy PGError interface. New code should use the fields 405 | // of the Error struct directly. 406 | func (err *Error) Get(k byte) (v string) { 407 | switch k { 408 | case 'S': 409 | return err.Severity 410 | case 'C': 411 | return string(err.Code) 412 | case 'M': 413 | return err.Message 414 | case 'D': 415 | return err.Detail 416 | case 'H': 417 | return err.Hint 418 | case 'P': 419 | return err.Position 420 | case 'p': 421 | return err.InternalPosition 422 | case 'q': 423 | return err.InternalQuery 424 | case 'W': 425 | return err.Where 426 | case 's': 427 | return err.Schema 428 | case 't': 429 | return err.Table 430 | case 'c': 431 | return err.Column 432 | case 'd': 433 | return err.DataTypeName 434 | case 'n': 435 | return err.Constraint 436 | case 'F': 437 | return err.File 438 | case 'L': 439 | return err.Line 440 | case 'R': 441 | return err.Routine 442 | } 443 | return "" 444 | } 445 | 446 | func (err Error) Error() string { 447 | return "pq: " + err.Message 448 | } 449 | 450 | // PGError is an interface used by previous versions of pq. It is provided 451 | // only to support legacy code. New code should use the Error type. 
452 | type PGError interface { 453 | Error() string 454 | Fatal() bool 455 | Get(k byte) (v string) 456 | } 457 | 458 | func errorf(s string, args ...interface{}) { 459 | panic(fmt.Errorf("pq: %s", fmt.Sprintf(s, args...))) 460 | } 461 | 462 | func errRecoverNoErrBadConn(err *error) { 463 | e := recover() 464 | if e == nil { 465 | // Do nothing 466 | return 467 | } 468 | var ok bool 469 | *err, ok = e.(error) 470 | if !ok { 471 | *err = fmt.Errorf("pq: unexpected error: %#v", e) 472 | } 473 | } 474 | 475 | func (c *conn) errRecover(err *error) { 476 | e := recover() 477 | switch v := e.(type) { 478 | case nil: 479 | // Do nothing 480 | case runtime.Error: 481 | c.bad = true 482 | panic(v) 483 | case *Error: 484 | if v.Fatal() { 485 | *err = driver.ErrBadConn 486 | } else { 487 | *err = v 488 | } 489 | case *net.OpError: 490 | *err = driver.ErrBadConn 491 | case error: 492 | if v == io.EOF || v.(error).Error() == "remote error: handshake failure" { 493 | *err = driver.ErrBadConn 494 | } else { 495 | *err = v 496 | } 497 | 498 | default: 499 | c.bad = true 500 | panic(fmt.Sprintf("unknown error: %#v", e)) 501 | } 502 | 503 | // Any time we return ErrBadConn, we need to remember it since *Tx doesn't 504 | // mark the connection bad in database/sql. 
505 | if *err == driver.ErrBadConn { 506 | c.bad = true 507 | } 508 | } 509 | -------------------------------------------------------------------------------- /encode.go: -------------------------------------------------------------------------------- 1 | package pq 2 | 3 | import ( 4 | "bytes" 5 | "database/sql/driver" 6 | "encoding/binary" 7 | "encoding/hex" 8 | "errors" 9 | "fmt" 10 | "math" 11 | "strconv" 12 | "strings" 13 | "sync" 14 | "time" 15 | 16 | "github.com/lib/pq/oid" 17 | ) 18 | 19 | func binaryEncode(parameterStatus *parameterStatus, x interface{}) []byte { 20 | switch v := x.(type) { 21 | case []byte: 22 | return v 23 | default: 24 | return encode(parameterStatus, x, oid.T_unknown) 25 | } 26 | } 27 | 28 | func encode(parameterStatus *parameterStatus, x interface{}, pgtypOid oid.Oid) []byte { 29 | switch v := x.(type) { 30 | case int64: 31 | return strconv.AppendInt(nil, v, 10) 32 | case float64: 33 | return strconv.AppendFloat(nil, v, 'f', -1, 64) 34 | case []byte: 35 | if pgtypOid == oid.T_bytea { 36 | return encodeBytea(parameterStatus.serverVersion, v) 37 | } 38 | 39 | return v 40 | case string: 41 | if pgtypOid == oid.T_bytea { 42 | return encodeBytea(parameterStatus.serverVersion, []byte(v)) 43 | } 44 | 45 | return []byte(v) 46 | case bool: 47 | return strconv.AppendBool(nil, v) 48 | case time.Time: 49 | return formatTs(v) 50 | 51 | default: 52 | errorf("encode: unknown type for %T", v) 53 | } 54 | 55 | panic("not reached") 56 | } 57 | 58 | func decode(parameterStatus *parameterStatus, s []byte, typ oid.Oid, f format) interface{} { 59 | switch f { 60 | case formatBinary: 61 | return binaryDecode(parameterStatus, s, typ) 62 | case formatText: 63 | return textDecode(parameterStatus, s, typ) 64 | default: 65 | panic("not reached") 66 | } 67 | } 68 | 69 | func binaryDecode(parameterStatus *parameterStatus, s []byte, typ oid.Oid) interface{} { 70 | switch typ { 71 | case oid.T_bytea: 72 | return s 73 | case oid.T_int8: 74 | return 
int64(binary.BigEndian.Uint64(s)) 75 | case oid.T_int4: 76 | return int64(int32(binary.BigEndian.Uint32(s))) 77 | case oid.T_int2: 78 | return int64(int16(binary.BigEndian.Uint16(s))) 79 | 80 | default: 81 | errorf("don't know how to decode binary parameter of type %d", uint32(typ)) 82 | } 83 | 84 | panic("not reached") 85 | } 86 | 87 | func textDecode(parameterStatus *parameterStatus, s []byte, typ oid.Oid) interface{} { 88 | switch typ { 89 | case oid.T_char, oid.T_varchar, oid.T_text: 90 | return string(s) 91 | case oid.T_bytea: 92 | b, err := parseBytea(s) 93 | if err != nil { 94 | errorf("%s", err) 95 | } 96 | return b 97 | case oid.T_timestamptz: 98 | return parseTs(parameterStatus.currentLocation, string(s)) 99 | case oid.T_timestamp, oid.T_date: 100 | return parseTs(nil, string(s)) 101 | case oid.T_time: 102 | return mustParse("15:04:05", typ, s) 103 | case oid.T_timetz: 104 | return mustParse("15:04:05-07", typ, s) 105 | case oid.T_bool: 106 | return s[0] == 't' 107 | case oid.T_int8, oid.T_int4, oid.T_int2: 108 | i, err := strconv.ParseInt(string(s), 10, 64) 109 | if err != nil { 110 | errorf("%s", err) 111 | } 112 | return i 113 | case oid.T_float4, oid.T_float8: 114 | bits := 64 115 | if typ == oid.T_float4 { 116 | bits = 32 117 | } 118 | f, err := strconv.ParseFloat(string(s), bits) 119 | if err != nil { 120 | errorf("%s", err) 121 | } 122 | return f 123 | } 124 | 125 | return s 126 | } 127 | 128 | // appendEncodedText encodes item in text format as required by COPY 129 | // and appends to buf 130 | func appendEncodedText(parameterStatus *parameterStatus, buf []byte, x interface{}) []byte { 131 | switch v := x.(type) { 132 | case int64: 133 | return strconv.AppendInt(buf, v, 10) 134 | case float64: 135 | return strconv.AppendFloat(buf, v, 'f', -1, 64) 136 | case []byte: 137 | encodedBytea := encodeBytea(parameterStatus.serverVersion, v) 138 | return appendEscapedText(buf, string(encodedBytea)) 139 | case string: 140 | return appendEscapedText(buf, v) 
141 | case bool: 142 | return strconv.AppendBool(buf, v) 143 | case time.Time: 144 | return append(buf, formatTs(v)...) 145 | case nil: 146 | return append(buf, "\\N"...) 147 | default: 148 | errorf("encode: unknown type for %T", v) 149 | } 150 | 151 | panic("not reached") 152 | } 153 | 154 | func appendEscapedText(buf []byte, text string) []byte { 155 | escapeNeeded := false 156 | startPos := 0 157 | var c byte 158 | 159 | // check if we need to escape 160 | for i := 0; i < len(text); i++ { 161 | c = text[i] 162 | if c == '\\' || c == '\n' || c == '\r' || c == '\t' { 163 | escapeNeeded = true 164 | startPos = i 165 | break 166 | } 167 | } 168 | if !escapeNeeded { 169 | return append(buf, text...) 170 | } 171 | 172 | // copy till first char to escape, iterate the rest 173 | result := append(buf, text[:startPos]...) 174 | for i := startPos; i < len(text); i++ { 175 | c = text[i] 176 | switch c { 177 | case '\\': 178 | result = append(result, '\\', '\\') 179 | case '\n': 180 | result = append(result, '\\', 'n') 181 | case '\r': 182 | result = append(result, '\\', 'r') 183 | case '\t': 184 | result = append(result, '\\', 't') 185 | default: 186 | result = append(result, c) 187 | } 188 | } 189 | return result 190 | } 191 | 192 | func mustParse(f string, typ oid.Oid, s []byte) time.Time { 193 | str := string(s) 194 | 195 | // check for a 30-minute-offset timezone 196 | if (typ == oid.T_timestamptz || typ == oid.T_timetz) && 197 | str[len(str)-3] == ':' { 198 | f += ":00" 199 | } 200 | t, err := time.Parse(f, str) 201 | if err != nil { 202 | errorf("decode: %s", err) 203 | } 204 | return t 205 | } 206 | 207 | var errInvalidTimestamp = errors.New("invalid timestamp") 208 | 209 | type timestampParser struct { 210 | err error 211 | } 212 | 213 | func (p *timestampParser) expect(str string, char byte, pos int) { 214 | if p.err != nil { 215 | return 216 | } 217 | if pos+1 > len(str) { 218 | p.err = errInvalidTimestamp 219 | return 220 | } 221 | if c := str[pos]; c != char && 
p.err == nil { 222 | p.err = fmt.Errorf("expected '%v' at position %v; got '%v'", char, pos, c) 223 | } 224 | } 225 | 226 | func (p *timestampParser) mustAtoi(str string, begin int, end int) int { 227 | if p.err != nil { 228 | return 0 229 | } 230 | if begin < 0 || end < 0 || begin > end || end > len(str) { 231 | p.err = errInvalidTimestamp 232 | return 0 233 | } 234 | result, err := strconv.Atoi(str[begin:end]) 235 | if err != nil { 236 | if p.err == nil { 237 | p.err = fmt.Errorf("expected number; got '%v'", str) 238 | } 239 | return 0 240 | } 241 | return result 242 | } 243 | 244 | // The location cache caches the time zones typically used by the client. 245 | type locationCache struct { 246 | cache map[int]*time.Location 247 | lock sync.Mutex 248 | } 249 | 250 | // All connections share the same list of timezones. Benchmarking shows that 251 | // about 5% speed could be gained by putting the cache in the connection and 252 | // losing the mutex, at the cost of a small amount of memory and a somewhat 253 | // significant increase in code complexity. 254 | var globalLocationCache = newLocationCache() 255 | 256 | func newLocationCache() *locationCache { 257 | return &locationCache{cache: make(map[int]*time.Location)} 258 | } 259 | 260 | // Returns the cached timezone for the specified offset, creating and caching 261 | // it if necessary. 
262 | func (c *locationCache) getLocation(offset int) *time.Location { 263 | c.lock.Lock() 264 | defer c.lock.Unlock() 265 | 266 | location, ok := c.cache[offset] 267 | if !ok { 268 | location = time.FixedZone("", offset) 269 | c.cache[offset] = location 270 | } 271 | 272 | return location 273 | } 274 | 275 | var infinityTsEnabled = false 276 | var infinityTsNegative time.Time 277 | var infinityTsPositive time.Time 278 | 279 | const ( 280 | infinityTsEnabledAlready = "pq: infinity timestamp enabled already" 281 | infinityTsNegativeMustBeSmaller = "pq: infinity timestamp: negative value must be smaller (before) than positive" 282 | ) 283 | 284 | // EnableInfinityTs controls the handling of Postgres' "-infinity" and 285 | // "infinity" "timestamp"s. 286 | // 287 | // If EnableInfinityTs is not called, "-infinity" and "infinity" will return 288 | // []byte("-infinity") and []byte("infinity") respectively, and potentially 289 | // cause error "sql: Scan error on column index 0: unsupported driver -> Scan 290 | // pair: []uint8 -> *time.Time", when scanning into a time.Time value. 291 | // 292 | // Once EnableInfinityTs has been called, all connections created using this 293 | // driver will decode Postgres' "-infinity" and "infinity" for "timestamp", 294 | // "timestamp with time zone" and "date" types to the predefined minimum and 295 | // maximum times, respectively. When encoding time.Time values, any time which 296 | // equals or precedes the predefined minimum time will be encoded to 297 | // "-infinity". Any values at or past the maximum time will similarly be 298 | // encoded to "infinity". 299 | // 300 | // If EnableInfinityTs is called with negative >= positive, it will panic. 301 | // Calling EnableInfinityTs after a connection has been established results in 302 | // undefined behavior. If EnableInfinityTs is called more than once, it will 303 | // panic. 
304 | func EnableInfinityTs(negative time.Time, positive time.Time) { 305 | if infinityTsEnabled { 306 | panic(infinityTsEnabledAlready) 307 | } 308 | if !negative.Before(positive) { 309 | panic(infinityTsNegativeMustBeSmaller) 310 | } 311 | infinityTsEnabled = true 312 | infinityTsNegative = negative 313 | infinityTsPositive = positive 314 | } 315 | 316 | /* 317 | * Testing might want to toggle infinityTsEnabled 318 | */ 319 | func disableInfinityTs() { 320 | infinityTsEnabled = false 321 | } 322 | 323 | // This is a time function specific to the Postgres default DateStyle 324 | // setting ("ISO, MDY"), the only one we currently support. This 325 | // accounts for the discrepancies between the parsing available with 326 | // time.Parse and the Postgres date formatting quirks. 327 | func parseTs(currentLocation *time.Location, str string) interface{} { 328 | switch str { 329 | case "-infinity": 330 | if infinityTsEnabled { 331 | return infinityTsNegative 332 | } 333 | return []byte(str) 334 | case "infinity": 335 | if infinityTsEnabled { 336 | return infinityTsPositive 337 | } 338 | return []byte(str) 339 | } 340 | t, err := ParseTimestamp(currentLocation, str) 341 | if err != nil { 342 | panic(err) 343 | } 344 | return t 345 | } 346 | 347 | // ParseTimestamp parses Postgres' text format. It returns a time.Time in 348 | // currentLocation iff that time's offset agrees with the offset sent from the 349 | // Postgres server. Otherwise, ParseTimestamp returns a time.Time with the 350 | // fixed offset offset provided by the Postgres server. 
351 | func ParseTimestamp(currentLocation *time.Location, str string) (time.Time, error) { 352 | p := timestampParser{} 353 | 354 | monSep := strings.IndexRune(str, '-') 355 | // this is Gregorian year, not ISO Year 356 | // In Gregorian system, the year 1 BC is followed by AD 1 357 | year := p.mustAtoi(str, 0, monSep) 358 | daySep := monSep + 3 359 | month := p.mustAtoi(str, monSep+1, daySep) 360 | p.expect(str, '-', daySep) 361 | timeSep := daySep + 3 362 | day := p.mustAtoi(str, daySep+1, timeSep) 363 | 364 | var hour, minute, second int 365 | if len(str) > monSep+len("01-01")+1 { 366 | p.expect(str, ' ', timeSep) 367 | minSep := timeSep + 3 368 | p.expect(str, ':', minSep) 369 | hour = p.mustAtoi(str, timeSep+1, minSep) 370 | secSep := minSep + 3 371 | p.expect(str, ':', secSep) 372 | minute = p.mustAtoi(str, minSep+1, secSep) 373 | secEnd := secSep + 3 374 | second = p.mustAtoi(str, secSep+1, secEnd) 375 | } 376 | remainderIdx := monSep + len("01-01 00:00:00") + 1 377 | // Three optional (but ordered) sections follow: the 378 | // fractional seconds, the time zone offset, and the BC 379 | // designation. We set them up here and adjust the other 380 | // offsets if the preceding sections exist. 381 | 382 | nanoSec := 0 383 | tzOff := 0 384 | 385 | if remainderIdx < len(str) && str[remainderIdx] == '.' 
{ 386 | fracStart := remainderIdx + 1 387 | fracOff := strings.IndexAny(str[fracStart:], "-+ ") 388 | if fracOff < 0 { 389 | fracOff = len(str) - fracStart 390 | } 391 | fracSec := p.mustAtoi(str, fracStart, fracStart+fracOff) 392 | nanoSec = fracSec * (1000000000 / int(math.Pow(10, float64(fracOff)))) 393 | 394 | remainderIdx += fracOff + 1 395 | } 396 | if tzStart := remainderIdx; tzStart < len(str) && (str[tzStart] == '-' || str[tzStart] == '+') { 397 | // time zone separator is always '-' or '+' (UTC is +00) 398 | var tzSign int 399 | switch c := str[tzStart]; c { 400 | case '-': 401 | tzSign = -1 402 | case '+': 403 | tzSign = +1 404 | default: 405 | return time.Time{}, fmt.Errorf("expected '-' or '+' at position %v; got %v", tzStart, c) 406 | } 407 | tzHours := p.mustAtoi(str, tzStart+1, tzStart+3) 408 | remainderIdx += 3 409 | var tzMin, tzSec int 410 | if remainderIdx < len(str) && str[remainderIdx] == ':' { 411 | tzMin = p.mustAtoi(str, remainderIdx+1, remainderIdx+3) 412 | remainderIdx += 3 413 | } 414 | if remainderIdx < len(str) && str[remainderIdx] == ':' { 415 | tzSec = p.mustAtoi(str, remainderIdx+1, remainderIdx+3) 416 | remainderIdx += 3 417 | } 418 | tzOff = tzSign * ((tzHours * 60 * 60) + (tzMin * 60) + tzSec) 419 | } 420 | var isoYear int 421 | if remainderIdx+3 <= len(str) && str[remainderIdx:remainderIdx+3] == " BC" { 422 | isoYear = 1 - year 423 | remainderIdx += 3 424 | } else { 425 | isoYear = year 426 | } 427 | if remainderIdx < len(str) { 428 | return time.Time{}, fmt.Errorf("expected end of input, got %v", str[remainderIdx:]) 429 | } 430 | t := time.Date(isoYear, time.Month(month), day, 431 | hour, minute, second, nanoSec, 432 | globalLocationCache.getLocation(tzOff)) 433 | 434 | if currentLocation != nil { 435 | // Set the location of the returned Time based on the session's 436 | // TimeZone value, but only if the local time zone database agrees with 437 | // the remote database on the offset. 
438 | lt := t.In(currentLocation) 439 | _, newOff := lt.Zone() 440 | if newOff == tzOff { 441 | t = lt 442 | } 443 | } 444 | 445 | return t, p.err 446 | } 447 | 448 | // formatTs formats t into a format postgres understands. 449 | func formatTs(t time.Time) []byte { 450 | if infinityTsEnabled { 451 | // t <= -infinity : ! (t > -infinity) 452 | if !t.After(infinityTsNegative) { 453 | return []byte("-infinity") 454 | } 455 | // t >= infinity : ! (!t < infinity) 456 | if !t.Before(infinityTsPositive) { 457 | return []byte("infinity") 458 | } 459 | } 460 | return FormatTimestamp(t) 461 | } 462 | 463 | // FormatTimestamp formats t into Postgres' text format for timestamps. 464 | func FormatTimestamp(t time.Time) []byte { 465 | // Need to send dates before 0001 A.D. with " BC" suffix, instead of the 466 | // minus sign preferred by Go. 467 | // Beware, "0000" in ISO is "1 BC", "-0001" is "2 BC" and so on 468 | bc := false 469 | if t.Year() <= 0 { 470 | // flip year sign, and add 1, e.g: "0" will be "1", and "-10" will be "11" 471 | t = t.AddDate((-t.Year())*2+1, 0, 0) 472 | bc = true 473 | } 474 | b := []byte(t.Format(time.RFC3339Nano)) 475 | 476 | _, offset := t.Zone() 477 | offset = offset % 60 478 | if offset != 0 { 479 | // RFC3339Nano already printed the minus sign 480 | if offset < 0 { 481 | offset = -offset 482 | } 483 | 484 | b = append(b, ':') 485 | if offset < 10 { 486 | b = append(b, '0') 487 | } 488 | b = strconv.AppendInt(b, int64(offset), 10) 489 | } 490 | 491 | if bc { 492 | b = append(b, " BC"...) 493 | } 494 | return b 495 | } 496 | 497 | // Parse a bytea value received from the server. Both "hex" and the legacy 498 | // "escape" format are supported. 
499 | func parseBytea(s []byte) (result []byte, err error) { 500 | if len(s) >= 2 && bytes.Equal(s[:2], []byte("\\x")) { 501 | // bytea_output = hex 502 | s = s[2:] // trim off leading "\\x" 503 | result = make([]byte, hex.DecodedLen(len(s))) 504 | _, err := hex.Decode(result, s) 505 | if err != nil { 506 | return nil, err 507 | } 508 | } else { 509 | // bytea_output = escape 510 | for len(s) > 0 { 511 | if s[0] == '\\' { 512 | // escaped '\\' 513 | if len(s) >= 2 && s[1] == '\\' { 514 | result = append(result, '\\') 515 | s = s[2:] 516 | continue 517 | } 518 | 519 | // '\\' followed by an octal number 520 | if len(s) < 4 { 521 | return nil, fmt.Errorf("invalid bytea sequence %v", s) 522 | } 523 | r, err := strconv.ParseInt(string(s[1:4]), 8, 9) 524 | if err != nil { 525 | return nil, fmt.Errorf("could not parse bytea value: %s", err.Error()) 526 | } 527 | result = append(result, byte(r)) 528 | s = s[4:] 529 | } else { 530 | // We hit an unescaped, raw byte. Try to read in as many as 531 | // possible in one go. 532 | i := bytes.IndexByte(s, '\\') 533 | if i == -1 { 534 | result = append(result, s...) 535 | break 536 | } 537 | result = append(result, s[:i]...) 538 | s = s[i:] 539 | } 540 | } 541 | } 542 | 543 | return result, nil 544 | } 545 | 546 | func encodeBytea(serverVersion int, v []byte) (result []byte) { 547 | if serverVersion >= 90000 { 548 | // Use the hex format if we know that the server supports it 549 | result = make([]byte, 2+hex.EncodedLen(len(v))) 550 | result[0] = '\\' 551 | result[1] = 'x' 552 | hex.Encode(result[2:], v) 553 | } else { 554 | // .. or resort to "escape" 555 | for _, b := range v { 556 | if b == '\\' { 557 | result = append(result, '\\', '\\') 558 | } else if b < 0x20 || b > 0x7e { 559 | result = append(result, []byte(fmt.Sprintf("\\%03o", b))...) 560 | } else { 561 | result = append(result, b) 562 | } 563 | } 564 | } 565 | 566 | return result 567 | } 568 | 569 | // NullTime represents a time.Time that may be null. 
NullTime implements the 570 | // sql.Scanner interface so it can be used as a scan destination, similar to 571 | // sql.NullString. 572 | type NullTime struct { 573 | Time time.Time 574 | Valid bool // Valid is true if Time is not NULL 575 | } 576 | 577 | // Scan implements the Scanner interface. 578 | func (nt *NullTime) Scan(value interface{}) error { 579 | nt.Time, nt.Valid = value.(time.Time) 580 | return nil 581 | } 582 | 583 | // Value implements the driver Valuer interface. 584 | func (nt NullTime) Value() (driver.Value, error) { 585 | if !nt.Valid { 586 | return nil, nil 587 | } 588 | return nt.Time, nil 589 | } 590 | -------------------------------------------------------------------------------- /array.go: -------------------------------------------------------------------------------- 1 | package pq 2 | 3 | import ( 4 | "bytes" 5 | "database/sql" 6 | "database/sql/driver" 7 | "encoding/hex" 8 | "fmt" 9 | "reflect" 10 | "strconv" 11 | "strings" 12 | ) 13 | 14 | var typeByteSlice = reflect.TypeOf([]byte{}) 15 | var typeDriverValuer = reflect.TypeOf((*driver.Valuer)(nil)).Elem() 16 | var typeSqlScanner = reflect.TypeOf((*sql.Scanner)(nil)).Elem() 17 | 18 | // Array returns the optimal driver.Valuer and sql.Scanner for an array or 19 | // slice of any dimension. 20 | // 21 | // For example: 22 | // db.Query(`SELECT * FROM t WHERE id = ANY($1)`, pq.Array([]int{235, 401})) 23 | // 24 | // var x []sql.NullInt64 25 | // db.QueryRow('SELECT ARRAY[235, 401]').Scan(pq.Array(&x)) 26 | // 27 | // Scanning multi-dimensional arrays is not supported. Arrays where the lower 28 | // bound is not one (such as `[0:0]={1}') are not supported. 
29 | func Array(a interface{}) interface { 30 | driver.Valuer 31 | sql.Scanner 32 | } { 33 | switch a := a.(type) { 34 | case []bool: 35 | return (*BoolArray)(&a) 36 | case []float64: 37 | return (*Float64Array)(&a) 38 | case []int64: 39 | return (*Int64Array)(&a) 40 | case []string: 41 | return (*StringArray)(&a) 42 | 43 | case *[]bool: 44 | return (*BoolArray)(a) 45 | case *[]float64: 46 | return (*Float64Array)(a) 47 | case *[]int64: 48 | return (*Int64Array)(a) 49 | case *[]string: 50 | return (*StringArray)(a) 51 | } 52 | 53 | return GenericArray{a} 54 | } 55 | 56 | // ArrayDelimiter may be optionally implemented by driver.Valuer or sql.Scanner 57 | // to override the array delimiter used by GenericArray. 58 | type ArrayDelimiter interface { 59 | // ArrayDelimiter returns the delimiter character(s) for this element's type. 60 | ArrayDelimiter() string 61 | } 62 | 63 | // BoolArray represents a one-dimensional array of the PostgreSQL boolean type. 64 | type BoolArray []bool 65 | 66 | // Scan implements the sql.Scanner interface. 
67 | func (a *BoolArray) Scan(src interface{}) error { 68 | switch src := src.(type) { 69 | case []byte: 70 | return a.scanBytes(src) 71 | case string: 72 | return a.scanBytes([]byte(src)) 73 | } 74 | 75 | return fmt.Errorf("pq: cannot convert %T to BoolArray", src) 76 | } 77 | 78 | func (a *BoolArray) scanBytes(src []byte) error { 79 | elems, err := scanLinearArray(src, []byte{','}, "BoolArray") 80 | if err != nil { 81 | return err 82 | } 83 | if len(elems) == 0 { 84 | *a = (*a)[:0] 85 | } else { 86 | b := make(BoolArray, len(elems)) 87 | for i, v := range elems { 88 | if len(v) != 1 { 89 | return fmt.Errorf("pq: could not parse boolean array index %d: invalid boolean %q", i, v) 90 | } 91 | switch v[0] { 92 | case 't': 93 | b[i] = true 94 | case 'f': 95 | b[i] = false 96 | default: 97 | return fmt.Errorf("pq: could not parse boolean array index %d: invalid boolean %q", i, v) 98 | } 99 | } 100 | *a = b 101 | } 102 | return nil 103 | } 104 | 105 | // Value implements the driver.Valuer interface. 106 | func (a BoolArray) Value() (driver.Value, error) { 107 | if a == nil { 108 | return nil, nil 109 | } 110 | 111 | if n := len(a); n > 0 { 112 | // There will be exactly two curly brackets, N bytes of values, 113 | // and N-1 bytes of delimiters. 114 | b := make([]byte, 1+2*n) 115 | 116 | for i := 0; i < n; i++ { 117 | b[2*i] = ',' 118 | if a[i] { 119 | b[1+2*i] = 't' 120 | } else { 121 | b[1+2*i] = 'f' 122 | } 123 | } 124 | 125 | b[0] = '{' 126 | b[2*n] = '}' 127 | 128 | return string(b), nil 129 | } 130 | 131 | return "{}", nil 132 | } 133 | 134 | // ByteaArray represents a one-dimensional array of the PostgreSQL bytea type. 135 | type ByteaArray [][]byte 136 | 137 | // Scan implements the sql.Scanner interface. 
138 | func (a *ByteaArray) Scan(src interface{}) error { 139 | switch src := src.(type) { 140 | case []byte: 141 | return a.scanBytes(src) 142 | case string: 143 | return a.scanBytes([]byte(src)) 144 | } 145 | 146 | return fmt.Errorf("pq: cannot convert %T to ByteaArray", src) 147 | } 148 | 149 | func (a *ByteaArray) scanBytes(src []byte) error { 150 | elems, err := scanLinearArray(src, []byte{','}, "ByteaArray") 151 | if err != nil { 152 | return err 153 | } 154 | if len(elems) == 0 { 155 | *a = (*a)[:0] 156 | } else { 157 | b := make(ByteaArray, len(elems)) 158 | for i, v := range elems { 159 | b[i], err = parseBytea(v) 160 | if err != nil { 161 | return fmt.Errorf("could not parse bytea array index %d: %s", i, err.Error()) 162 | } 163 | } 164 | *a = b 165 | } 166 | return nil 167 | } 168 | 169 | // Value implements the driver.Valuer interface. It uses the "hex" format which 170 | // is only supported on PostgreSQL 9.0 or newer. 171 | func (a ByteaArray) Value() (driver.Value, error) { 172 | if a == nil { 173 | return nil, nil 174 | } 175 | 176 | if n := len(a); n > 0 { 177 | // There will be at least two curly brackets, 2*N bytes of quotes, 178 | // 3*N bytes of hex formatting, and N-1 bytes of delimiters. 179 | size := 1 + 6*n 180 | for _, x := range a { 181 | size += hex.EncodedLen(len(x)) 182 | } 183 | 184 | b := make([]byte, size) 185 | 186 | for i, s := 0, b; i < n; i++ { 187 | o := copy(s, `,"\\x`) 188 | o += hex.Encode(s[o:], a[i]) 189 | s[o] = '"' 190 | s = s[o+1:] 191 | } 192 | 193 | b[0] = '{' 194 | b[size-1] = '}' 195 | 196 | return string(b), nil 197 | } 198 | 199 | return "{}", nil 200 | } 201 | 202 | // Float64Array represents a one-dimensional array of the PostgreSQL double 203 | // precision type. 204 | type Float64Array []float64 205 | 206 | // Scan implements the sql.Scanner interface. 
207 | func (a *Float64Array) Scan(src interface{}) error { 208 | switch src := src.(type) { 209 | case []byte: 210 | return a.scanBytes(src) 211 | case string: 212 | return a.scanBytes([]byte(src)) 213 | } 214 | 215 | return fmt.Errorf("pq: cannot convert %T to Float64Array", src) 216 | } 217 | 218 | func (a *Float64Array) scanBytes(src []byte) error { 219 | elems, err := scanLinearArray(src, []byte{','}, "Float64Array") 220 | if err != nil { 221 | return err 222 | } 223 | if len(elems) == 0 { 224 | *a = (*a)[:0] 225 | } else { 226 | b := make(Float64Array, len(elems)) 227 | for i, v := range elems { 228 | if b[i], err = strconv.ParseFloat(string(v), 64); err != nil { 229 | return fmt.Errorf("pq: parsing array element index %d: %v", i, err) 230 | } 231 | } 232 | *a = b 233 | } 234 | return nil 235 | } 236 | 237 | // Value implements the driver.Valuer interface. 238 | func (a Float64Array) Value() (driver.Value, error) { 239 | if a == nil { 240 | return nil, nil 241 | } 242 | 243 | if n := len(a); n > 0 { 244 | // There will be at least two curly brackets, N bytes of values, 245 | // and N-1 bytes of delimiters. 246 | b := make([]byte, 1, 1+2*n) 247 | b[0] = '{' 248 | 249 | b = strconv.AppendFloat(b, a[0], 'f', -1, 64) 250 | for i := 1; i < n; i++ { 251 | b = append(b, ',') 252 | b = strconv.AppendFloat(b, a[i], 'f', -1, 64) 253 | } 254 | 255 | return string(append(b, '}')), nil 256 | } 257 | 258 | return "{}", nil 259 | } 260 | 261 | // GenericArray implements the driver.Valuer and sql.Scanner interfaces for 262 | // an array or slice of any dimension. 
263 | type GenericArray struct{ A interface{} } 264 | 265 | func (GenericArray) evaluateDestination(rt reflect.Type) (reflect.Type, func([]byte, reflect.Value) error, string) { 266 | var assign func([]byte, reflect.Value) error 267 | var del = "," 268 | 269 | // TODO calculate the assign function for other types 270 | // TODO repeat this section on the element type of arrays or slices (multidimensional) 271 | { 272 | if reflect.PtrTo(rt).Implements(typeSqlScanner) { 273 | // dest is always addressable because it is an element of a slice. 274 | assign = func(src []byte, dest reflect.Value) (err error) { 275 | ss := dest.Addr().Interface().(sql.Scanner) 276 | if src == nil { 277 | err = ss.Scan(nil) 278 | } else { 279 | err = ss.Scan(src) 280 | } 281 | return 282 | } 283 | goto FoundType 284 | } 285 | 286 | assign = func([]byte, reflect.Value) error { 287 | return fmt.Errorf("pq: scanning to %s is not implemented; only sql.Scanner", rt) 288 | } 289 | } 290 | 291 | FoundType: 292 | 293 | if ad, ok := reflect.Zero(rt).Interface().(ArrayDelimiter); ok { 294 | del = ad.ArrayDelimiter() 295 | } 296 | 297 | return rt, assign, del 298 | } 299 | 300 | // Scan implements the sql.Scanner interface. 
301 | func (a GenericArray) Scan(src interface{}) error { 302 | dpv := reflect.ValueOf(a.A) 303 | switch { 304 | case dpv.Kind() != reflect.Ptr: 305 | return fmt.Errorf("pq: destination %T is not a pointer to array or slice", a.A) 306 | case dpv.IsNil(): 307 | return fmt.Errorf("pq: destination %T is nil", a.A) 308 | } 309 | 310 | dv := dpv.Elem() 311 | switch dv.Kind() { 312 | case reflect.Slice: 313 | case reflect.Array: 314 | default: 315 | return fmt.Errorf("pq: destination %T is not a pointer to array or slice", a.A) 316 | } 317 | 318 | switch src := src.(type) { 319 | case []byte: 320 | return a.scanBytes(src, dv) 321 | case string: 322 | return a.scanBytes([]byte(src), dv) 323 | } 324 | 325 | return fmt.Errorf("pq: cannot convert %T to %s", src, dv.Type()) 326 | } 327 | 328 | func (a GenericArray) scanBytes(src []byte, dv reflect.Value) error { 329 | dtype, assign, del := a.evaluateDestination(dv.Type().Elem()) 330 | dims, elems, err := parseArray(src, []byte(del)) 331 | if err != nil { 332 | return err 333 | } 334 | 335 | // TODO allow multidimensional 336 | 337 | if len(dims) > 1 { 338 | return fmt.Errorf("pq: scanning from multidimensional ARRAY%s is not implemented", 339 | strings.Replace(fmt.Sprint(dims), " ", "][", -1)) 340 | } 341 | 342 | // Treat a zero-dimensional array like an array with a single dimension of zero. 
343 | if len(dims) == 0 { 344 | dims = append(dims, 0) 345 | } 346 | 347 | for i, rt := 0, dv.Type(); i < len(dims); i, rt = i+1, rt.Elem() { 348 | switch rt.Kind() { 349 | case reflect.Slice: 350 | case reflect.Array: 351 | if rt.Len() != dims[i] { 352 | return fmt.Errorf("pq: cannot convert ARRAY%s to %s", 353 | strings.Replace(fmt.Sprint(dims), " ", "][", -1), dv.Type()) 354 | } 355 | default: 356 | // TODO handle multidimensional 357 | } 358 | } 359 | 360 | values := reflect.MakeSlice(reflect.SliceOf(dtype), len(elems), len(elems)) 361 | for i, e := range elems { 362 | if err := assign(e, values.Index(i)); err != nil { 363 | return fmt.Errorf("pq: parsing array element index %d: %v", i, err) 364 | } 365 | } 366 | 367 | // TODO handle multidimensional 368 | 369 | switch dv.Kind() { 370 | case reflect.Slice: 371 | dv.Set(values.Slice(0, dims[0])) 372 | case reflect.Array: 373 | for i := 0; i < dims[0]; i++ { 374 | dv.Index(i).Set(values.Index(i)) 375 | } 376 | } 377 | 378 | return nil 379 | } 380 | 381 | // Value implements the driver.Valuer interface. 382 | func (a GenericArray) Value() (driver.Value, error) { 383 | if a.A == nil { 384 | return nil, nil 385 | } 386 | 387 | rv := reflect.ValueOf(a.A) 388 | 389 | if k := rv.Kind(); k != reflect.Array && k != reflect.Slice { 390 | return nil, fmt.Errorf("pq: Unable to convert %T to array", a.A) 391 | } 392 | 393 | if n := rv.Len(); n > 0 { 394 | // There will be at least two curly brackets, N bytes of values, 395 | // and N-1 bytes of delimiters. 396 | b := make([]byte, 0, 1+2*n) 397 | 398 | b, _, err := appendArray(b, rv, n) 399 | return string(b), err 400 | } 401 | 402 | return "{}", nil 403 | } 404 | 405 | // Int64Array represents a one-dimensional array of the PostgreSQL integer types. 406 | type Int64Array []int64 407 | 408 | // Scan implements the sql.Scanner interface. 
409 | func (a *Int64Array) Scan(src interface{}) error { 410 | switch src := src.(type) { 411 | case []byte: 412 | return a.scanBytes(src) 413 | case string: 414 | return a.scanBytes([]byte(src)) 415 | } 416 | 417 | return fmt.Errorf("pq: cannot convert %T to Int64Array", src) 418 | } 419 | 420 | func (a *Int64Array) scanBytes(src []byte) error { 421 | elems, err := scanLinearArray(src, []byte{','}, "Int64Array") 422 | if err != nil { 423 | return err 424 | } 425 | if len(elems) == 0 { 426 | *a = (*a)[:0] 427 | } else { 428 | b := make(Int64Array, len(elems)) 429 | for i, v := range elems { 430 | if b[i], err = strconv.ParseInt(string(v), 10, 64); err != nil { 431 | return fmt.Errorf("pq: parsing array element index %d: %v", i, err) 432 | } 433 | } 434 | *a = b 435 | } 436 | return nil 437 | } 438 | 439 | // Value implements the driver.Valuer interface. 440 | func (a Int64Array) Value() (driver.Value, error) { 441 | if a == nil { 442 | return nil, nil 443 | } 444 | 445 | if n := len(a); n > 0 { 446 | // There will be at least two curly brackets, N bytes of values, 447 | // and N-1 bytes of delimiters. 448 | b := make([]byte, 1, 1+2*n) 449 | b[0] = '{' 450 | 451 | b = strconv.AppendInt(b, a[0], 10) 452 | for i := 1; i < n; i++ { 453 | b = append(b, ',') 454 | b = strconv.AppendInt(b, a[i], 10) 455 | } 456 | 457 | return string(append(b, '}')), nil 458 | } 459 | 460 | return "{}", nil 461 | } 462 | 463 | // StringArray represents a one-dimensional array of the PostgreSQL character types. 464 | type StringArray []string 465 | 466 | // Scan implements the sql.Scanner interface. 
467 | func (a *StringArray) Scan(src interface{}) error { 468 | switch src := src.(type) { 469 | case []byte: 470 | return a.scanBytes(src) 471 | case string: 472 | return a.scanBytes([]byte(src)) 473 | } 474 | 475 | return fmt.Errorf("pq: cannot convert %T to StringArray", src) 476 | } 477 | 478 | func (a *StringArray) scanBytes(src []byte) error { 479 | elems, err := scanLinearArray(src, []byte{','}, "StringArray") 480 | if err != nil { 481 | return err 482 | } 483 | if len(elems) == 0 { 484 | *a = (*a)[:0] 485 | } else { 486 | b := make(StringArray, len(elems)) 487 | for i, v := range elems { 488 | if b[i] = string(v); v == nil { 489 | return fmt.Errorf("pq: parsing array element index %d: cannot convert nil to string", i) 490 | } 491 | } 492 | *a = b 493 | } 494 | return nil 495 | } 496 | 497 | // Value implements the driver.Valuer interface. 498 | func (a StringArray) Value() (driver.Value, error) { 499 | if a == nil { 500 | return nil, nil 501 | } 502 | 503 | if n := len(a); n > 0 { 504 | // There will be at least two curly brackets, 2*N bytes of quotes, 505 | // and N-1 bytes of delimiters. 506 | b := make([]byte, 1, 1+3*n) 507 | b[0] = '{' 508 | 509 | b = appendArrayQuotedBytes(b, []byte(a[0])) 510 | for i := 1; i < n; i++ { 511 | b = append(b, ',') 512 | b = appendArrayQuotedBytes(b, []byte(a[i])) 513 | } 514 | 515 | return string(append(b, '}')), nil 516 | } 517 | 518 | return "{}", nil 519 | } 520 | 521 | // appendArray appends rv to the buffer, returning the extended buffer and 522 | // the delimiter used between elements. 523 | // 524 | // It panics when n <= 0 or rv's Kind is not reflect.Array nor reflect.Slice. 525 | func appendArray(b []byte, rv reflect.Value, n int) ([]byte, string, error) { 526 | var del string 527 | var err error 528 | 529 | b = append(b, '{') 530 | 531 | if b, del, err = appendArrayElement(b, rv.Index(0)); err != nil { 532 | return b, del, err 533 | } 534 | 535 | for i := 1; i < n; i++ { 536 | b = append(b, del...) 
537 | if b, del, err = appendArrayElement(b, rv.Index(i)); err != nil { 538 | return b, del, err 539 | } 540 | } 541 | 542 | return append(b, '}'), del, nil 543 | } 544 | 545 | // appendArrayElement appends rv to the buffer, returning the extended buffer 546 | // and the delimiter to use before the next element. 547 | // 548 | // When rv's Kind is neither reflect.Array nor reflect.Slice, it is converted 549 | // using driver.DefaultParameterConverter and the resulting []byte or string 550 | // is double-quoted. 551 | // 552 | // See http://www.postgresql.org/docs/current/static/arrays.html#ARRAYS-IO 553 | func appendArrayElement(b []byte, rv reflect.Value) ([]byte, string, error) { 554 | if k := rv.Kind(); k == reflect.Array || k == reflect.Slice { 555 | if t := rv.Type(); t != typeByteSlice && !t.Implements(typeDriverValuer) { 556 | if n := rv.Len(); n > 0 { 557 | return appendArray(b, rv, n) 558 | } 559 | 560 | return b, "", nil 561 | } 562 | } 563 | 564 | var del string = "," 565 | var err error 566 | var iv interface{} = rv.Interface() 567 | 568 | if ad, ok := iv.(ArrayDelimiter); ok { 569 | del = ad.ArrayDelimiter() 570 | } 571 | 572 | if iv, err = driver.DefaultParameterConverter.ConvertValue(iv); err != nil { 573 | return b, del, err 574 | } 575 | 576 | switch v := iv.(type) { 577 | case nil: 578 | return append(b, "NULL"...), del, nil 579 | case []byte: 580 | return appendArrayQuotedBytes(b, v), del, nil 581 | case string: 582 | return appendArrayQuotedBytes(b, []byte(v)), del, nil 583 | } 584 | 585 | b, err = appendValue(b, iv) 586 | return b, del, err 587 | } 588 | 589 | func appendArrayQuotedBytes(b, v []byte) []byte { 590 | b = append(b, '"') 591 | for { 592 | i := bytes.IndexAny(v, `"\`) 593 | if i < 0 { 594 | b = append(b, v...) 595 | break 596 | } 597 | if i > 0 { 598 | b = append(b, v[:i]...) 
599 | } 600 | b = append(b, '\\', v[i]) 601 | v = v[i+1:] 602 | } 603 | return append(b, '"') 604 | } 605 | 606 | func appendValue(b []byte, v driver.Value) ([]byte, error) { 607 | return append(b, encode(nil, v, 0)...), nil 608 | } 609 | 610 | // parseArray extracts the dimensions and elements of an array represented in 611 | // text format. Only representations emitted by the backend are supported. 612 | // Notably, whitespace around brackets and delimiters is significant, and NULL 613 | // is case-sensitive. 614 | // 615 | // See http://www.postgresql.org/docs/current/static/arrays.html#ARRAYS-IO 616 | func parseArray(src, del []byte) (dims []int, elems [][]byte, err error) { 617 | var depth, i int 618 | 619 | if len(src) < 1 || src[0] != '{' { 620 | return nil, nil, fmt.Errorf("pq: unable to parse array; expected %q at offset %d", '{', 0) 621 | } 622 | 623 | Open: 624 | for i < len(src) { 625 | switch src[i] { 626 | case '{': 627 | depth++ 628 | i++ 629 | case '}': 630 | elems = make([][]byte, 0) 631 | goto Close 632 | default: 633 | break Open 634 | } 635 | } 636 | dims = make([]int, i) 637 | 638 | Element: 639 | for i < len(src) { 640 | switch src[i] { 641 | case '{': 642 | depth++ 643 | dims[depth-1] = 0 644 | i++ 645 | case '"': 646 | var elem = []byte{} 647 | var escape bool 648 | for i++; i < len(src); i++ { 649 | if escape { 650 | elem = append(elem, src[i]) 651 | escape = false 652 | } else { 653 | switch src[i] { 654 | default: 655 | elem = append(elem, src[i]) 656 | case '\\': 657 | escape = true 658 | case '"': 659 | elems = append(elems, elem) 660 | i++ 661 | break Element 662 | } 663 | } 664 | } 665 | default: 666 | for start := i; i < len(src); i++ { 667 | if bytes.HasPrefix(src[i:], del) || src[i] == '}' { 668 | elem := src[start:i] 669 | if len(elem) == 0 { 670 | return nil, nil, fmt.Errorf("pq: unable to parse array; unexpected %q at offset %d", src[i], i) 671 | } 672 | if bytes.Equal(elem, []byte("NULL")) { 673 | elem = nil 674 | } 675 | elems 
= append(elems, elem) 676 | break Element 677 | } 678 | } 679 | } 680 | } 681 | 682 | for i < len(src) { 683 | if bytes.HasPrefix(src[i:], del) { 684 | dims[depth-1]++ 685 | i += len(del) 686 | goto Element 687 | } else if src[i] == '}' { 688 | dims[depth-1]++ 689 | depth-- 690 | i++ 691 | } else { 692 | return nil, nil, fmt.Errorf("pq: unable to parse array; unexpected %q at offset %d", src[i], i) 693 | } 694 | } 695 | 696 | Close: 697 | for i < len(src) { 698 | if src[i] == '}' && depth > 0 { 699 | depth-- 700 | i++ 701 | } else { 702 | return nil, nil, fmt.Errorf("pq: unable to parse array; unexpected %q at offset %d", src[i], i) 703 | } 704 | } 705 | if depth > 0 { 706 | err = fmt.Errorf("pq: unable to parse array; expected %q at offset %d", '}', i) 707 | } 708 | if err == nil { 709 | for _, d := range dims { 710 | if (len(elems) % d) != 0 { 711 | err = fmt.Errorf("pq: multidimensional arrays must have elements with matching dimensions") 712 | } 713 | } 714 | } 715 | return 716 | } 717 | 718 | func scanLinearArray(src, del []byte, typ string) (elems [][]byte, err error) { 719 | dims, elems, err := parseArray(src, del) 720 | if err != nil { 721 | return nil, err 722 | } 723 | if len(dims) > 1 { 724 | return nil, fmt.Errorf("pq: cannot convert ARRAY%s to %s", strings.Replace(fmt.Sprint(dims), " ", "][", -1), typ) 725 | } 726 | return elems, err 727 | } 728 | -------------------------------------------------------------------------------- /encode_test.go: -------------------------------------------------------------------------------- 1 | package pq 2 | 3 | import ( 4 | "bytes" 5 | "database/sql" 6 | "fmt" 7 | "strings" 8 | "testing" 9 | "time" 10 | 11 | "github.com/lib/pq/oid" 12 | ) 13 | 14 | func TestScanTimestamp(t *testing.T) { 15 | var nt NullTime 16 | tn := time.Now() 17 | nt.Scan(tn) 18 | if !nt.Valid { 19 | t.Errorf("Expected Valid=false") 20 | } 21 | if nt.Time != tn { 22 | t.Errorf("Time value mismatch") 23 | } 24 | } 25 | 26 | func 
TestScanNilTimestamp(t *testing.T) { 27 | var nt NullTime 28 | nt.Scan(nil) 29 | if nt.Valid { 30 | t.Errorf("Expected Valid=false") 31 | } 32 | } 33 | 34 | var timeTests = []struct { 35 | str string 36 | timeval time.Time 37 | }{ 38 | {"22001-02-03", time.Date(22001, time.February, 3, 0, 0, 0, 0, time.FixedZone("", 0))}, 39 | {"2001-02-03", time.Date(2001, time.February, 3, 0, 0, 0, 0, time.FixedZone("", 0))}, 40 | {"2001-02-03 04:05:06", time.Date(2001, time.February, 3, 4, 5, 6, 0, time.FixedZone("", 0))}, 41 | {"2001-02-03 04:05:06.000001", time.Date(2001, time.February, 3, 4, 5, 6, 1000, time.FixedZone("", 0))}, 42 | {"2001-02-03 04:05:06.00001", time.Date(2001, time.February, 3, 4, 5, 6, 10000, time.FixedZone("", 0))}, 43 | {"2001-02-03 04:05:06.0001", time.Date(2001, time.February, 3, 4, 5, 6, 100000, time.FixedZone("", 0))}, 44 | {"2001-02-03 04:05:06.001", time.Date(2001, time.February, 3, 4, 5, 6, 1000000, time.FixedZone("", 0))}, 45 | {"2001-02-03 04:05:06.01", time.Date(2001, time.February, 3, 4, 5, 6, 10000000, time.FixedZone("", 0))}, 46 | {"2001-02-03 04:05:06.1", time.Date(2001, time.February, 3, 4, 5, 6, 100000000, time.FixedZone("", 0))}, 47 | {"2001-02-03 04:05:06.12", time.Date(2001, time.February, 3, 4, 5, 6, 120000000, time.FixedZone("", 0))}, 48 | {"2001-02-03 04:05:06.123", time.Date(2001, time.February, 3, 4, 5, 6, 123000000, time.FixedZone("", 0))}, 49 | {"2001-02-03 04:05:06.1234", time.Date(2001, time.February, 3, 4, 5, 6, 123400000, time.FixedZone("", 0))}, 50 | {"2001-02-03 04:05:06.12345", time.Date(2001, time.February, 3, 4, 5, 6, 123450000, time.FixedZone("", 0))}, 51 | {"2001-02-03 04:05:06.123456", time.Date(2001, time.February, 3, 4, 5, 6, 123456000, time.FixedZone("", 0))}, 52 | {"2001-02-03 04:05:06.123-07", time.Date(2001, time.February, 3, 4, 5, 6, 123000000, 53 | time.FixedZone("", -7*60*60))}, 54 | {"2001-02-03 04:05:06-07", time.Date(2001, time.February, 3, 4, 5, 6, 0, 55 | time.FixedZone("", -7*60*60))}, 56 | {"2001-02-03 
04:05:06-07:42", time.Date(2001, time.February, 3, 4, 5, 6, 0, 57 | time.FixedZone("", -(7*60*60+42*60)))}, 58 | {"2001-02-03 04:05:06-07:30:09", time.Date(2001, time.February, 3, 4, 5, 6, 0, 59 | time.FixedZone("", -(7*60*60+30*60+9)))}, 60 | {"2001-02-03 04:05:06+07", time.Date(2001, time.February, 3, 4, 5, 6, 0, 61 | time.FixedZone("", 7*60*60))}, 62 | {"0011-02-03 04:05:06 BC", time.Date(-10, time.February, 3, 4, 5, 6, 0, time.FixedZone("", 0))}, 63 | {"0011-02-03 04:05:06.123 BC", time.Date(-10, time.February, 3, 4, 5, 6, 123000000, time.FixedZone("", 0))}, 64 | {"0011-02-03 04:05:06.123-07 BC", time.Date(-10, time.February, 3, 4, 5, 6, 123000000, 65 | time.FixedZone("", -7*60*60))}, 66 | {"0001-02-03 04:05:06.123", time.Date(1, time.February, 3, 4, 5, 6, 123000000, time.FixedZone("", 0))}, 67 | {"0001-02-03 04:05:06.123 BC", time.Date(1, time.February, 3, 4, 5, 6, 123000000, time.FixedZone("", 0)).AddDate(-1, 0, 0)}, 68 | {"0001-02-03 04:05:06.123 BC", time.Date(0, time.February, 3, 4, 5, 6, 123000000, time.FixedZone("", 0))}, 69 | {"0002-02-03 04:05:06.123 BC", time.Date(0, time.February, 3, 4, 5, 6, 123000000, time.FixedZone("", 0)).AddDate(-1, 0, 0)}, 70 | {"0002-02-03 04:05:06.123 BC", time.Date(-1, time.February, 3, 4, 5, 6, 123000000, time.FixedZone("", 0))}, 71 | {"12345-02-03 04:05:06.1", time.Date(12345, time.February, 3, 4, 5, 6, 100000000, time.FixedZone("", 0))}, 72 | {"123456-02-03 04:05:06.1", time.Date(123456, time.February, 3, 4, 5, 6, 100000000, time.FixedZone("", 0))}, 73 | } 74 | 75 | // Test that parsing the string results in the expected value. 
76 | func TestParseTs(t *testing.T) { 77 | for i, tt := range timeTests { 78 | val, err := ParseTimestamp(nil, tt.str) 79 | if err != nil { 80 | t.Errorf("%d: got error: %v", i, err) 81 | } else if val.String() != tt.timeval.String() { 82 | t.Errorf("%d: expected to parse %q into %q; got %q", 83 | i, tt.str, tt.timeval, val) 84 | } 85 | } 86 | } 87 | 88 | var timeErrorTests = []string{ 89 | "2001", 90 | "2001-2-03", 91 | "2001-02-3", 92 | "2001-02-03 ", 93 | "2001-02-03 04", 94 | "2001-02-03 04:", 95 | "2001-02-03 04:05", 96 | "2001-02-03 04:05:", 97 | "2001-02-03 04:05:6", 98 | "2001-02-03 04:05:06.123 B", 99 | } 100 | 101 | // Test that parsing the string results in an error. 102 | func TestParseTsErrors(t *testing.T) { 103 | for i, tt := range timeErrorTests { 104 | _, err := ParseTimestamp(nil, tt) 105 | if err == nil { 106 | t.Errorf("%d: expected an error from parsing: %v", i, tt) 107 | } 108 | } 109 | } 110 | 111 | // Now test that sending the value into the database and parsing it back 112 | // returns the same time.Time value. 
113 | func TestEncodeAndParseTs(t *testing.T) { 114 | db, err := openTestConnConninfo("timezone='Etc/UTC'") 115 | if err != nil { 116 | t.Fatal(err) 117 | } 118 | defer db.Close() 119 | 120 | for i, tt := range timeTests { 121 | var dbstr string 122 | err = db.QueryRow("SELECT ($1::timestamptz)::text", tt.timeval).Scan(&dbstr) 123 | if err != nil { 124 | t.Errorf("%d: could not send value %q to the database: %s", i, tt.timeval, err) 125 | continue 126 | } 127 | 128 | val, err := ParseTimestamp(nil, dbstr) 129 | if err != nil { 130 | t.Errorf("%d: could not parse value %q: %s", i, dbstr, err) 131 | continue 132 | } 133 | val = val.In(tt.timeval.Location()) 134 | if val.String() != tt.timeval.String() { 135 | t.Errorf("%d: expected to parse %q into %q; got %q", i, dbstr, tt.timeval, val) 136 | } 137 | } 138 | } 139 | 140 | var formatTimeTests = []struct { 141 | time time.Time 142 | expected string 143 | }{ 144 | {time.Time{}, "0001-01-01T00:00:00Z"}, 145 | {time.Date(2001, time.February, 3, 4, 5, 6, 123456789, time.FixedZone("", 0)), "2001-02-03T04:05:06.123456789Z"}, 146 | {time.Date(2001, time.February, 3, 4, 5, 6, 123456789, time.FixedZone("", 2*60*60)), "2001-02-03T04:05:06.123456789+02:00"}, 147 | {time.Date(2001, time.February, 3, 4, 5, 6, 123456789, time.FixedZone("", -6*60*60)), "2001-02-03T04:05:06.123456789-06:00"}, 148 | {time.Date(2001, time.February, 3, 4, 5, 6, 0, time.FixedZone("", -(7*60*60+30*60+9))), "2001-02-03T04:05:06-07:30:09"}, 149 | 150 | {time.Date(1, time.February, 3, 4, 5, 6, 123456789, time.FixedZone("", 0)), "0001-02-03T04:05:06.123456789Z"}, 151 | {time.Date(1, time.February, 3, 4, 5, 6, 123456789, time.FixedZone("", 2*60*60)), "0001-02-03T04:05:06.123456789+02:00"}, 152 | {time.Date(1, time.February, 3, 4, 5, 6, 123456789, time.FixedZone("", -6*60*60)), "0001-02-03T04:05:06.123456789-06:00"}, 153 | 154 | {time.Date(0, time.February, 3, 4, 5, 6, 123456789, time.FixedZone("", 0)), "0001-02-03T04:05:06.123456789Z BC"}, 155 | {time.Date(0, 
time.February, 3, 4, 5, 6, 123456789, time.FixedZone("", 2*60*60)), "0001-02-03T04:05:06.123456789+02:00 BC"}, 156 | {time.Date(0, time.February, 3, 4, 5, 6, 123456789, time.FixedZone("", -6*60*60)), "0001-02-03T04:05:06.123456789-06:00 BC"}, 157 | 158 | {time.Date(1, time.February, 3, 4, 5, 6, 0, time.FixedZone("", -(7*60*60+30*60+9))), "0001-02-03T04:05:06-07:30:09"}, 159 | {time.Date(0, time.February, 3, 4, 5, 6, 0, time.FixedZone("", -(7*60*60+30*60+9))), "0001-02-03T04:05:06-07:30:09 BC"}, 160 | } 161 | 162 | func TestFormatTs(t *testing.T) { 163 | for i, tt := range formatTimeTests { 164 | val := string(formatTs(tt.time)) 165 | if val != tt.expected { 166 | t.Errorf("%d: incorrect time format %q, want %q", i, val, tt.expected) 167 | } 168 | } 169 | } 170 | 171 | func TestTimestampWithTimeZone(t *testing.T) { 172 | db := openTestConn(t) 173 | defer db.Close() 174 | 175 | tx, err := db.Begin() 176 | if err != nil { 177 | t.Fatal(err) 178 | } 179 | defer tx.Rollback() 180 | 181 | // try several different locations, all included in Go's zoneinfo.zip 182 | for _, locName := range []string{ 183 | "UTC", 184 | "America/Chicago", 185 | "America/New_York", 186 | "Australia/Darwin", 187 | "Australia/Perth", 188 | } { 189 | loc, err := time.LoadLocation(locName) 190 | if err != nil { 191 | t.Logf("Could not load time zone %s - skipping", locName) 192 | continue 193 | } 194 | 195 | // Postgres timestamps have a resolution of 1 microsecond, so don't 196 | // use the full range of the Nanosecond argument 197 | refTime := time.Date(2012, 11, 6, 10, 23, 42, 123456000, loc) 198 | 199 | for _, pgTimeZone := range []string{"US/Eastern", "Australia/Darwin"} { 200 | // Switch Postgres's timezone to test different output timestamp formats 201 | _, err = tx.Exec(fmt.Sprintf("set time zone '%s'", pgTimeZone)) 202 | if err != nil { 203 | t.Fatal(err) 204 | } 205 | 206 | var gotTime time.Time 207 | row := tx.QueryRow("select $1::timestamp with time zone", refTime) 208 | err = 
row.Scan(&gotTime) 209 | if err != nil { 210 | t.Fatal(err) 211 | } 212 | 213 | if !refTime.Equal(gotTime) { 214 | t.Errorf("timestamps not equal: %s != %s", refTime, gotTime) 215 | } 216 | 217 | // check that the time zone is set correctly based on TimeZone 218 | pgLoc, err := time.LoadLocation(pgTimeZone) 219 | if err != nil { 220 | t.Logf("Could not load time zone %s - skipping", pgLoc) 221 | continue 222 | } 223 | translated := refTime.In(pgLoc) 224 | if translated.String() != gotTime.String() { 225 | t.Errorf("timestamps not equal: %s != %s", translated, gotTime) 226 | } 227 | } 228 | } 229 | } 230 | 231 | func TestTimestampWithOutTimezone(t *testing.T) { 232 | db := openTestConn(t) 233 | defer db.Close() 234 | 235 | test := func(ts, pgts string) { 236 | r, err := db.Query("SELECT $1::timestamp", pgts) 237 | if err != nil { 238 | t.Fatalf("Could not run query: %v", err) 239 | } 240 | 241 | n := r.Next() 242 | 243 | if n != true { 244 | t.Fatal("Expected at least one row") 245 | } 246 | 247 | var result time.Time 248 | err = r.Scan(&result) 249 | if err != nil { 250 | t.Fatalf("Did not expect error scanning row: %v", err) 251 | } 252 | 253 | expected, err := time.Parse(time.RFC3339, ts) 254 | if err != nil { 255 | t.Fatalf("Could not parse test time literal: %v", err) 256 | } 257 | 258 | if !result.Equal(expected) { 259 | t.Fatalf("Expected time to match %v: got mismatch %v", 260 | expected, result) 261 | } 262 | 263 | n = r.Next() 264 | if n != false { 265 | t.Fatal("Expected only one row") 266 | } 267 | } 268 | 269 | test("2000-01-01T00:00:00Z", "2000-01-01T00:00:00") 270 | 271 | // Test higher precision time 272 | test("2013-01-04T20:14:58.80033Z", "2013-01-04 20:14:58.80033") 273 | } 274 | 275 | func TestInfinityTimestamp(t *testing.T) { 276 | db := openTestConn(t) 277 | defer db.Close() 278 | var err error 279 | var resultT time.Time 280 | 281 | expectedErrorStrPrefix := `sql: Scan error on column index 0: unsupported` 282 | type testCases []struct { 283 | 
Query string 284 | Param string 285 | ExpectedErrStrPrefix string 286 | ExpectedVal interface{} 287 | } 288 | tc := testCases{ 289 | {"SELECT $1::timestamp", "-infinity", expectedErrorStrPrefix, "-infinity"}, 290 | {"SELECT $1::timestamptz", "-infinity", expectedErrorStrPrefix, "-infinity"}, 291 | {"SELECT $1::timestamp", "infinity", expectedErrorStrPrefix, "infinity"}, 292 | {"SELECT $1::timestamptz", "infinity", expectedErrorStrPrefix, "infinity"}, 293 | } 294 | // try to assert []byte to time.Time 295 | for _, q := range tc { 296 | err = db.QueryRow(q.Query, q.Param).Scan(&resultT) 297 | if !strings.HasPrefix(err.Error(), q.ExpectedErrStrPrefix) { 298 | t.Errorf("Scanning -/+infinity, expected error to have prefix %q, got %q", q.ExpectedErrStrPrefix, err) 299 | } 300 | } 301 | // yield []byte 302 | for _, q := range tc { 303 | var resultI interface{} 304 | err = db.QueryRow(q.Query, q.Param).Scan(&resultI) 305 | if err != nil { 306 | t.Errorf("Scanning -/+infinity, expected no error, got %q", err) 307 | } 308 | result, ok := resultI.([]byte) 309 | if !ok { 310 | t.Errorf("Scanning -/+infinity, expected []byte, got %#v", resultI) 311 | } 312 | if string(result) != q.ExpectedVal { 313 | t.Errorf("Scanning -/+infinity, expected %q, got %q", q.ExpectedVal, result) 314 | } 315 | } 316 | 317 | y1500 := time.Date(1500, time.January, 1, 0, 0, 0, 0, time.UTC) 318 | y2500 := time.Date(2500, time.January, 1, 0, 0, 0, 0, time.UTC) 319 | EnableInfinityTs(y1500, y2500) 320 | 321 | err = db.QueryRow("SELECT $1::timestamp", "infinity").Scan(&resultT) 322 | if err != nil { 323 | t.Errorf("Scanning infinity, expected no error, got %q", err) 324 | } 325 | if !resultT.Equal(y2500) { 326 | t.Errorf("Scanning infinity, expected %q, got %q", y2500, resultT) 327 | } 328 | 329 | err = db.QueryRow("SELECT $1::timestamptz", "infinity").Scan(&resultT) 330 | if err != nil { 331 | t.Errorf("Scanning infinity, expected no error, got %q", err) 332 | } 333 | if !resultT.Equal(y2500) { 334 | 
t.Errorf("Scanning Infinity, expected time %q, got %q", y2500, resultT.String()) 335 | } 336 | 337 | err = db.QueryRow("SELECT $1::timestamp", "-infinity").Scan(&resultT) 338 | if err != nil { 339 | t.Errorf("Scanning -infinity, expected no error, got %q", err) 340 | } 341 | if !resultT.Equal(y1500) { 342 | t.Errorf("Scanning -infinity, expected time %q, got %q", y1500, resultT.String()) 343 | } 344 | 345 | err = db.QueryRow("SELECT $1::timestamptz", "-infinity").Scan(&resultT) 346 | if err != nil { 347 | t.Errorf("Scanning -infinity, expected no error, got %q", err) 348 | } 349 | if !resultT.Equal(y1500) { 350 | t.Errorf("Scanning -infinity, expected time %q, got %q", y1500, resultT.String()) 351 | } 352 | 353 | y_1500 := time.Date(-1500, time.January, 1, 0, 0, 0, 0, time.UTC) 354 | y11500 := time.Date(11500, time.January, 1, 0, 0, 0, 0, time.UTC) 355 | var s string 356 | err = db.QueryRow("SELECT $1::timestamp::text", y_1500).Scan(&s) 357 | if err != nil { 358 | t.Errorf("Encoding -infinity, expected no error, got %q", err) 359 | } 360 | if s != "-infinity" { 361 | t.Errorf("Encoding -infinity, expected %q, got %q", "-infinity", s) 362 | } 363 | err = db.QueryRow("SELECT $1::timestamptz::text", y_1500).Scan(&s) 364 | if err != nil { 365 | t.Errorf("Encoding -infinity, expected no error, got %q", err) 366 | } 367 | if s != "-infinity" { 368 | t.Errorf("Encoding -infinity, expected %q, got %q", "-infinity", s) 369 | } 370 | 371 | err = db.QueryRow("SELECT $1::timestamp::text", y11500).Scan(&s) 372 | if err != nil { 373 | t.Errorf("Encoding infinity, expected no error, got %q", err) 374 | } 375 | if s != "infinity" { 376 | t.Errorf("Encoding infinity, expected %q, got %q", "infinity", s) 377 | } 378 | err = db.QueryRow("SELECT $1::timestamptz::text", y11500).Scan(&s) 379 | if err != nil { 380 | t.Errorf("Encoding infinity, expected no error, got %q", err) 381 | } 382 | if s != "infinity" { 383 | t.Errorf("Encoding infinity, expected %q, got %q", "infinity", s) 384 | 
} 385 | 386 | disableInfinityTs() 387 | 388 | var panicErrorString string 389 | func() { 390 | defer func() { 391 | panicErrorString, _ = recover().(string) 392 | }() 393 | EnableInfinityTs(y2500, y1500) 394 | }() 395 | if panicErrorString != infinityTsNegativeMustBeSmaller { 396 | t.Errorf("Expected error, %q, got %q", infinityTsNegativeMustBeSmaller, panicErrorString) 397 | } 398 | } 399 | 400 | func TestStringWithNul(t *testing.T) { 401 | db := openTestConn(t) 402 | defer db.Close() 403 | 404 | hello0world := string("hello\x00world") 405 | _, err := db.Query("SELECT $1::text", &hello0world) 406 | if err == nil { 407 | t.Fatal("Postgres accepts a string with nul in it; " + 408 | "injection attacks may be plausible") 409 | } 410 | } 411 | 412 | func TestByteSliceToText(t *testing.T) { 413 | db := openTestConn(t) 414 | defer db.Close() 415 | 416 | b := []byte("hello world") 417 | row := db.QueryRow("SELECT $1::text", b) 418 | 419 | var result []byte 420 | err := row.Scan(&result) 421 | if err != nil { 422 | t.Fatal(err) 423 | } 424 | 425 | if string(result) != string(b) { 426 | t.Fatalf("expected %v but got %v", b, result) 427 | } 428 | } 429 | 430 | func TestStringToBytea(t *testing.T) { 431 | db := openTestConn(t) 432 | defer db.Close() 433 | 434 | b := "hello world" 435 | row := db.QueryRow("SELECT $1::bytea", b) 436 | 437 | var result []byte 438 | err := row.Scan(&result) 439 | if err != nil { 440 | t.Fatal(err) 441 | } 442 | 443 | if !bytes.Equal(result, []byte(b)) { 444 | t.Fatalf("expected %v but got %v", b, result) 445 | } 446 | } 447 | 448 | func TestTextByteSliceToUUID(t *testing.T) { 449 | db := openTestConn(t) 450 | defer db.Close() 451 | 452 | b := []byte("a0eebc99-9c0b-4ef8-bb6d-6bb9bd380a11") 453 | row := db.QueryRow("SELECT $1::uuid", b) 454 | 455 | var result string 456 | err := row.Scan(&result) 457 | if forceBinaryParameters() { 458 | pqErr := err.(*Error) 459 | if pqErr == nil { 460 | t.Errorf("Expected to get error") 461 | } else if pqErr.Code 
!= "22P03" { 462 | t.Fatalf("Expected to get invalid binary encoding error (22P03), got %s", pqErr.Code) 463 | } 464 | } else { 465 | if err != nil { 466 | t.Fatal(err) 467 | } 468 | 469 | if result != string(b) { 470 | t.Fatalf("expected %v but got %v", b, result) 471 | } 472 | } 473 | } 474 | 475 | func TestBinaryByteSlicetoUUID(t *testing.T) { 476 | db := openTestConn(t) 477 | defer db.Close() 478 | 479 | b := []byte{'\xa0', '\xee', '\xbc', '\x99', 480 | '\x9c', '\x0b', 481 | '\x4e', '\xf8', 482 | '\xbb', '\x00', '\x6b', 483 | '\xb9', '\xbd', '\x38', '\x0a', '\x11'} 484 | row := db.QueryRow("SELECT $1::uuid", b) 485 | 486 | var result string 487 | err := row.Scan(&result) 488 | if forceBinaryParameters() { 489 | if err != nil { 490 | t.Fatal(err) 491 | } 492 | 493 | if result != string("a0eebc99-9c0b-4ef8-bb00-6bb9bd380a11") { 494 | t.Fatalf("expected %v but got %v", b, result) 495 | } 496 | } else { 497 | pqErr := err.(*Error) 498 | if pqErr == nil { 499 | t.Errorf("Expected to get error") 500 | } else if pqErr.Code != "22021" { 501 | t.Fatalf("Expected to get invalid byte sequence for encoding error (22021), got %s", pqErr.Code) 502 | } 503 | } 504 | } 505 | 506 | func TestStringToUUID(t *testing.T) { 507 | db := openTestConn(t) 508 | defer db.Close() 509 | 510 | s := "a0eebc99-9c0b-4ef8-bb00-6bb9bd380a11" 511 | row := db.QueryRow("SELECT $1::uuid", s) 512 | 513 | var result string 514 | err := row.Scan(&result) 515 | if err != nil { 516 | t.Fatal(err) 517 | } 518 | 519 | if result != s { 520 | t.Fatalf("expected %v but got %v", s, result) 521 | } 522 | } 523 | 524 | func TestTextByteSliceToInt(t *testing.T) { 525 | db := openTestConn(t) 526 | defer db.Close() 527 | 528 | expected := 12345678 529 | b := []byte(fmt.Sprintf("%d", expected)) 530 | row := db.QueryRow("SELECT $1::int", b) 531 | 532 | var result int 533 | err := row.Scan(&result) 534 | if forceBinaryParameters() { 535 | pqErr := err.(*Error) 536 | if pqErr == nil { 537 | t.Errorf("Expected to get 
error") 538 | } else if pqErr.Code != "22P03" { 539 | t.Fatalf("Expected to get invalid binary encoding error (22P03), got %s", pqErr.Code) 540 | } 541 | } else { 542 | if err != nil { 543 | t.Fatal(err) 544 | } 545 | if result != expected { 546 | t.Fatalf("expected %v but got %v", expected, result) 547 | } 548 | } 549 | } 550 | 551 | func TestBinaryByteSliceToInt(t *testing.T) { 552 | db := openTestConn(t) 553 | defer db.Close() 554 | 555 | expected := 12345678 556 | b := []byte{'\x00', '\xbc', '\x61', '\x4e'} 557 | row := db.QueryRow("SELECT $1::int", b) 558 | 559 | var result int 560 | err := row.Scan(&result) 561 | if forceBinaryParameters() { 562 | if err != nil { 563 | t.Fatal(err) 564 | } 565 | if result != expected { 566 | t.Fatalf("expected %v but got %v", expected, result) 567 | } 568 | } else { 569 | pqErr := err.(*Error) 570 | if pqErr == nil { 571 | t.Errorf("Expected to get error") 572 | } else if pqErr.Code != "22021" { 573 | t.Fatalf("Expected to get invalid byte sequence for encoding error (22021), got %s", pqErr.Code) 574 | } 575 | } 576 | } 577 | 578 | func TestTextDecodeIntoString(t *testing.T) { 579 | input := []byte("hello world") 580 | want := string(input) 581 | for _, typ := range []oid.Oid{oid.T_char, oid.T_varchar, oid.T_text} { 582 | got := decode(¶meterStatus{}, input, typ, formatText) 583 | if got != want { 584 | t.Errorf("invalid string decoding output for %T(%+v), got %v but expected %v", typ, typ, got, want) 585 | } 586 | } 587 | } 588 | 589 | func TestByteaOutputFormatEncoding(t *testing.T) { 590 | input := []byte("\\x\x00\x01\x02\xFF\xFEabcdefg0123") 591 | want := []byte("\\x5c78000102fffe6162636465666730313233") 592 | got := encode(¶meterStatus{serverVersion: 90000}, input, oid.T_bytea) 593 | if !bytes.Equal(want, got) { 594 | t.Errorf("invalid hex bytea output, got %v but expected %v", got, want) 595 | } 596 | 597 | want = []byte("\\\\x\\000\\001\\002\\377\\376abcdefg0123") 598 | got = encode(¶meterStatus{serverVersion: 84000}, 
input, oid.T_bytea) 599 | if !bytes.Equal(want, got) { 600 | t.Errorf("invalid escape bytea output, got %v but expected %v", got, want) 601 | } 602 | } 603 | 604 | func TestByteaOutputFormats(t *testing.T) { 605 | db := openTestConn(t) 606 | defer db.Close() 607 | 608 | if getServerVersion(t, db) < 90000 { 609 | // skip 610 | return 611 | } 612 | 613 | testByteaOutputFormat := func(f string, usePrepared bool) { 614 | expectedData := []byte("\x5c\x78\x00\xff\x61\x62\x63\x01\x08") 615 | sqlQuery := "SELECT decode('5c7800ff6162630108', 'hex')" 616 | 617 | var data []byte 618 | 619 | // use a txn to avoid relying on getting the same connection 620 | txn, err := db.Begin() 621 | if err != nil { 622 | t.Fatal(err) 623 | } 624 | defer txn.Rollback() 625 | 626 | _, err = txn.Exec("SET LOCAL bytea_output TO " + f) 627 | if err != nil { 628 | t.Fatal(err) 629 | } 630 | var rows *sql.Rows 631 | var stmt *sql.Stmt 632 | if usePrepared { 633 | stmt, err = txn.Prepare(sqlQuery) 634 | if err != nil { 635 | t.Fatal(err) 636 | } 637 | rows, err = stmt.Query() 638 | } else { 639 | // use Query; QueryRow would hide the actual error 640 | rows, err = txn.Query(sqlQuery) 641 | } 642 | if err != nil { 643 | t.Fatal(err) 644 | } 645 | if !rows.Next() { 646 | if rows.Err() != nil { 647 | t.Fatal(rows.Err()) 648 | } 649 | t.Fatal("shouldn't happen") 650 | } 651 | err = rows.Scan(&data) 652 | if err != nil { 653 | t.Fatal(err) 654 | } 655 | err = rows.Close() 656 | if err != nil { 657 | t.Fatal(err) 658 | } 659 | if stmt != nil { 660 | err = stmt.Close() 661 | if err != nil { 662 | t.Fatal(err) 663 | } 664 | } 665 | if !bytes.Equal(data, expectedData) { 666 | t.Errorf("unexpected bytea value %v for format %s; expected %v", data, f, expectedData) 667 | } 668 | } 669 | 670 | testByteaOutputFormat("hex", false) 671 | testByteaOutputFormat("escape", false) 672 | testByteaOutputFormat("hex", true) 673 | testByteaOutputFormat("escape", true) 674 | } 675 | 676 | func TestAppendEncodedText(t 
*testing.T) { 677 | var buf []byte 678 | 679 | buf = appendEncodedText(¶meterStatus{serverVersion: 90000}, buf, int64(10)) 680 | buf = append(buf, '\t') 681 | buf = appendEncodedText(¶meterStatus{serverVersion: 90000}, buf, 42.0000000001) 682 | buf = append(buf, '\t') 683 | buf = appendEncodedText(¶meterStatus{serverVersion: 90000}, buf, "hello\tworld") 684 | buf = append(buf, '\t') 685 | buf = appendEncodedText(¶meterStatus{serverVersion: 90000}, buf, []byte{0, 128, 255}) 686 | 687 | if string(buf) != "10\t42.0000000001\thello\\tworld\t\\\\x0080ff" { 688 | t.Fatal(string(buf)) 689 | } 690 | } 691 | 692 | func TestAppendEscapedText(t *testing.T) { 693 | if esc := appendEscapedText(nil, "hallo\tescape"); string(esc) != "hallo\\tescape" { 694 | t.Fatal(string(esc)) 695 | } 696 | if esc := appendEscapedText(nil, "hallo\\tescape\n"); string(esc) != "hallo\\\\tescape\\n" { 697 | t.Fatal(string(esc)) 698 | } 699 | if esc := appendEscapedText(nil, "\n\r\t\f"); string(esc) != "\\n\\r\\t\f" { 700 | t.Fatal(string(esc)) 701 | } 702 | } 703 | 704 | func TestAppendEscapedTextExistingBuffer(t *testing.T) { 705 | var buf []byte 706 | buf = []byte("123\t") 707 | if esc := appendEscapedText(buf, "hallo\tescape"); string(esc) != "123\thallo\\tescape" { 708 | t.Fatal(string(esc)) 709 | } 710 | buf = []byte("123\t") 711 | if esc := appendEscapedText(buf, "hallo\\tescape\n"); string(esc) != "123\thallo\\\\tescape\\n" { 712 | t.Fatal(string(esc)) 713 | } 714 | buf = []byte("123\t") 715 | if esc := appendEscapedText(buf, "\n\r\t\f"); string(esc) != "123\t\\n\\r\\t\f" { 716 | t.Fatal(string(esc)) 717 | } 718 | } 719 | 720 | func BenchmarkAppendEscapedText(b *testing.B) { 721 | longString := "" 722 | for i := 0; i < 100; i++ { 723 | longString += "123456789\n" 724 | } 725 | for i := 0; i < b.N; i++ { 726 | appendEscapedText(nil, longString) 727 | } 728 | } 729 | 730 | func BenchmarkAppendEscapedTextNoEscape(b *testing.B) { 731 | longString := "" 732 | for i := 0; i < 100; i++ { 733 | 
longString += "1234567890" 734 | } 735 | for i := 0; i < b.N; i++ { 736 | appendEscapedText(nil, longString) 737 | } 738 | } 739 | -------------------------------------------------------------------------------- /notify.go: -------------------------------------------------------------------------------- 1 | package pq 2 | 3 | // Package pq is a pure Go Postgres driver for the database/sql package. 4 | // This module contains support for Postgres LISTEN/NOTIFY. 5 | 6 | import ( 7 | "errors" 8 | "fmt" 9 | "sync" 10 | "sync/atomic" 11 | "time" 12 | ) 13 | 14 | // Notification represents a single notification from the database. 15 | type Notification struct { 16 | // Process ID (PID) of the notifying postgres backend. 17 | BePid int 18 | // Name of the channel the notification was sent on. 19 | Channel string 20 | // Payload, or the empty string if unspecified. 21 | Extra string 22 | } 23 | 24 | func recvNotification(r *readBuf) *Notification { 25 | bePid := r.int32() 26 | channel := r.string() 27 | extra := r.string() 28 | 29 | return &Notification{bePid, channel, extra} 30 | } 31 | 32 | const ( 33 | connStateIdle int32 = iota 34 | connStateExpectResponse 35 | connStateExpectReadyForQuery 36 | ) 37 | 38 | type message struct { 39 | typ byte 40 | err error 41 | } 42 | 43 | var errListenerConnClosed = errors.New("pq: ListenerConn has been closed") 44 | 45 | // ListenerConn is a low-level interface for waiting for notifications. You 46 | // should use Listener instead. 47 | type ListenerConn struct { 48 | // guards cn and err 49 | connectionLock sync.Mutex 50 | cn *conn 51 | err error 52 | 53 | connState int32 54 | 55 | // the sending goroutine will be holding this lock 56 | senderLock sync.Mutex 57 | 58 | notificationChan chan<- *Notification 59 | 60 | replyChan chan message 61 | } 62 | 63 | // Creates a new ListenerConn. Use NewListener instead. 
64 | func NewListenerConn(name string, notificationChan chan<- *Notification) (*ListenerConn, error) { 65 | return newDialListenerConn(defaultDialer{}, name, notificationChan) 66 | } 67 | 68 | func newDialListenerConn(d Dialer, name string, c chan<- *Notification) (*ListenerConn, error) { 69 | cn, err := DialOpen(d, name) 70 | if err != nil { 71 | return nil, err 72 | } 73 | 74 | l := &ListenerConn{ 75 | cn: cn.(*conn), 76 | notificationChan: c, 77 | connState: connStateIdle, 78 | replyChan: make(chan message, 2), 79 | } 80 | 81 | go l.listenerConnMain() 82 | 83 | return l, nil 84 | } 85 | 86 | // We can only allow one goroutine at a time to be running a query on the 87 | // connection for various reasons, so the goroutine sending on the connection 88 | // must be holding senderLock. 89 | // 90 | // Returns an error if an unrecoverable error has occurred and the ListenerConn 91 | // should be abandoned. 92 | func (l *ListenerConn) acquireSenderLock() error { 93 | // we must acquire senderLock first to avoid deadlocks; see ExecSimpleQuery 94 | l.senderLock.Lock() 95 | 96 | l.connectionLock.Lock() 97 | err := l.err 98 | l.connectionLock.Unlock() 99 | if err != nil { 100 | l.senderLock.Unlock() 101 | return err 102 | } 103 | return nil 104 | } 105 | 106 | func (l *ListenerConn) releaseSenderLock() { 107 | l.senderLock.Unlock() 108 | } 109 | 110 | // setState advances the protocol state to newState. Returns false if moving 111 | // to that state from the current state is not allowed. 
112 | func (l *ListenerConn) setState(newState int32) bool { 113 | var expectedState int32 114 | 115 | switch newState { 116 | case connStateIdle: 117 | expectedState = connStateExpectReadyForQuery 118 | case connStateExpectResponse: 119 | expectedState = connStateIdle 120 | case connStateExpectReadyForQuery: 121 | expectedState = connStateExpectResponse 122 | default: 123 | panic(fmt.Sprintf("unexpected listenerConnState %d", newState)) 124 | } 125 | 126 | return atomic.CompareAndSwapInt32(&l.connState, expectedState, newState) 127 | } 128 | 129 | // Main logic is here: receive messages from the postgres backend, forward 130 | // notifications and query replies and keep the internal state in sync with the 131 | // protocol state. Returns when the connection has been lost, is about to go 132 | // away or should be discarded because we couldn't agree on the state with the 133 | // server backend. 134 | func (l *ListenerConn) listenerConnLoop() (err error) { 135 | defer errRecoverNoErrBadConn(&err) 136 | 137 | r := &readBuf{} 138 | for { 139 | t, err := l.cn.recvMessage(r) 140 | if err != nil { 141 | return err 142 | } 143 | 144 | switch t { 145 | case 'A': 146 | // recvNotification copies all the data so we don't need to worry 147 | // about the scratch buffer being overwritten. 148 | l.notificationChan <- recvNotification(r) 149 | 150 | case 'T', 'D': 151 | // only used by tests; ignore 152 | 153 | case 'E': 154 | // We might receive an ErrorResponse even when not in a query; it 155 | // is expected that the server will close the connection after 156 | // that, but we should make sure that the error we display is the 157 | // one from the stray ErrorResponse, not io.ErrUnexpectedEOF. 
158 | if !l.setState(connStateExpectReadyForQuery) { 159 | return parseError(r) 160 | } 161 | l.replyChan <- message{t, parseError(r)} 162 | 163 | case 'C', 'I': 164 | if !l.setState(connStateExpectReadyForQuery) { 165 | // protocol out of sync 166 | return fmt.Errorf("unexpected CommandComplete") 167 | } 168 | // ExecSimpleQuery doesn't need to know about this message 169 | 170 | case 'Z': 171 | if !l.setState(connStateIdle) { 172 | // protocol out of sync 173 | return fmt.Errorf("unexpected ReadyForQuery") 174 | } 175 | l.replyChan <- message{t, nil} 176 | 177 | case 'N', 'S': 178 | // ignore 179 | default: 180 | return fmt.Errorf("unexpected message %q from server in listenerConnLoop", t) 181 | } 182 | } 183 | } 184 | 185 | // This is the main routine for the goroutine receiving on the database 186 | // connection. Most of the main logic is in listenerConnLoop. 187 | func (l *ListenerConn) listenerConnMain() { 188 | err := l.listenerConnLoop() 189 | 190 | // listenerConnLoop terminated; we're done, but we still have to clean up. 191 | // Make sure nobody tries to start any new queries by making sure the err 192 | // pointer is set. It is important that we do not overwrite its value; a 193 | // connection could be closed by either this goroutine or one sending on 194 | // the connection -- whoever closes the connection is assumed to have the 195 | // more meaningful error message (as the other one will probably get 196 | // net.errClosed), so that goroutine sets the error we expose while the 197 | // other error is discarded. If the connection is lost while two 198 | // goroutines are operating on the socket, it probably doesn't matter which 199 | // error we expose so we don't try to do anything more complex. 
200 | l.connectionLock.Lock() 201 | if l.err == nil { 202 | l.err = err 203 | } 204 | l.cn.Close() 205 | l.connectionLock.Unlock() 206 | 207 | // There might be a query in-flight; make sure nobody's waiting for a 208 | // response to it, since there's not going to be one. 209 | close(l.replyChan) 210 | 211 | // let the listener know we're done 212 | close(l.notificationChan) 213 | 214 | // this ListenerConn is done 215 | } 216 | 217 | // Send a LISTEN query to the server. See ExecSimpleQuery. 218 | func (l *ListenerConn) Listen(channel string) (bool, error) { 219 | return l.ExecSimpleQuery("LISTEN " + QuoteIdentifier(channel)) 220 | } 221 | 222 | // Send an UNLISTEN query to the server. See ExecSimpleQuery. 223 | func (l *ListenerConn) Unlisten(channel string) (bool, error) { 224 | return l.ExecSimpleQuery("UNLISTEN " + QuoteIdentifier(channel)) 225 | } 226 | 227 | // Send `UNLISTEN *` to the server. See ExecSimpleQuery. 228 | func (l *ListenerConn) UnlistenAll() (bool, error) { 229 | return l.ExecSimpleQuery("UNLISTEN *") 230 | } 231 | 232 | // Ping the remote server to make sure it's alive. Non-nil error means the 233 | // connection has failed and should be abandoned. 234 | func (l *ListenerConn) Ping() error { 235 | sent, err := l.ExecSimpleQuery("") 236 | if !sent { 237 | return err 238 | } 239 | if err != nil { 240 | // shouldn't happen 241 | panic(err) 242 | } 243 | return nil 244 | } 245 | 246 | // Attempt to send a query on the connection. Returns an error if sending the 247 | // query failed, and the caller should initiate closure of this connection. 248 | // The caller must be holding senderLock (see acquireSenderLock and 249 | // releaseSenderLock). 
250 | func (l *ListenerConn) sendSimpleQuery(q string) (err error) { 251 | defer errRecoverNoErrBadConn(&err) 252 | 253 | // must set connection state before sending the query 254 | if !l.setState(connStateExpectResponse) { 255 | panic("two queries running at the same time") 256 | } 257 | 258 | // Can't use l.cn.writeBuf here because it uses the scratch buffer which 259 | // might get overwritten by listenerConnLoop. 260 | b := &writeBuf{ 261 | buf: []byte("Q\x00\x00\x00\x00"), 262 | pos: 1, 263 | } 264 | b.string(q) 265 | l.cn.send(b) 266 | 267 | return nil 268 | } 269 | 270 | // Execute a "simple query" (i.e. one with no bindable parameters) on the 271 | // connection. The possible return values are: 272 | // 1) "executed" is true; the query was executed to completion on the 273 | // database server. If the query failed, err will be set to the error 274 | // returned by the database, otherwise err will be nil. 275 | // 2) If "executed" is false, the query could not be executed on the remote 276 | // server. err will be non-nil. 277 | // 278 | // After a call to ExecSimpleQuery has returned an executed=false value, the 279 | // connection has either been closed or will be closed shortly thereafter, and 280 | // all subsequently executed queries will return an error. 281 | func (l *ListenerConn) ExecSimpleQuery(q string) (executed bool, err error) { 282 | if err = l.acquireSenderLock(); err != nil { 283 | return false, err 284 | } 285 | defer l.releaseSenderLock() 286 | 287 | err = l.sendSimpleQuery(q) 288 | if err != nil { 289 | // We can't know what state the protocol is in, so we need to abandon 290 | // this connection. 291 | l.connectionLock.Lock() 292 | // Set the error pointer if it hasn't been set already; see 293 | // listenerConnMain. 294 | if l.err == nil { 295 | l.err = err 296 | } 297 | l.connectionLock.Unlock() 298 | l.cn.c.Close() 299 | return false, err 300 | } 301 | 302 | // now we just wait for a reply.. 
303 | for { 304 | m, ok := <-l.replyChan 305 | if !ok { 306 | // We lost the connection to server, don't bother waiting for a 307 | // a response. err should have been set already. 308 | l.connectionLock.Lock() 309 | err := l.err 310 | l.connectionLock.Unlock() 311 | return false, err 312 | } 313 | switch m.typ { 314 | case 'Z': 315 | // sanity check 316 | if m.err != nil { 317 | panic("m.err != nil") 318 | } 319 | // done; err might or might not be set 320 | return true, err 321 | 322 | case 'E': 323 | // sanity check 324 | if m.err == nil { 325 | panic("m.err == nil") 326 | } 327 | // server responded with an error; ReadyForQuery to follow 328 | err = m.err 329 | 330 | default: 331 | return false, fmt.Errorf("unknown response for simple query: %q", m.typ) 332 | } 333 | } 334 | } 335 | 336 | func (l *ListenerConn) Close() error { 337 | l.connectionLock.Lock() 338 | if l.err != nil { 339 | l.connectionLock.Unlock() 340 | return errListenerConnClosed 341 | } 342 | l.err = errListenerConnClosed 343 | l.connectionLock.Unlock() 344 | // We can't send anything on the connection without holding senderLock. 345 | // Simply close the net.Conn to wake up everyone operating on it. 346 | return l.cn.c.Close() 347 | } 348 | 349 | // Err() returns the reason the connection was closed. It is not safe to call 350 | // this function until l.Notify has been closed. 351 | func (l *ListenerConn) Err() error { 352 | return l.err 353 | } 354 | 355 | var errListenerClosed = errors.New("pq: Listener has been closed") 356 | 357 | var ErrChannelAlreadyOpen = errors.New("pq: channel is already open") 358 | var ErrChannelNotOpen = errors.New("pq: channel is not open") 359 | 360 | type ListenerEventType int 361 | 362 | const ( 363 | // Emitted only when the database connection has been initially 364 | // initialized. err will always be nil. 
365 | ListenerEventConnected ListenerEventType = iota 366 | 367 | // Emitted after a database connection has been lost, either because of an 368 | // error or because Close has been called. err will be set to the reason 369 | // the database connection was lost. 370 | ListenerEventDisconnected 371 | 372 | // Emitted after a database connection has been re-established after 373 | // connection loss. err will always be nil. After this event has been 374 | // emitted, a nil pq.Notification is sent on the Listener.Notify channel. 375 | ListenerEventReconnected 376 | 377 | // Emitted after a connection to the database was attempted, but failed. 378 | // err will be set to an error describing why the connection attempt did 379 | // not succeed. 380 | ListenerEventConnectionAttemptFailed 381 | ) 382 | 383 | type EventCallbackType func(event ListenerEventType, err error) 384 | 385 | // Listener provides an interface for listening to notifications from a 386 | // PostgreSQL database. For general usage information, see section 387 | // "Notifications". 388 | // 389 | // Listener can safely be used from concurrently running goroutines. 390 | type Listener struct { 391 | // Channel for receiving notifications from the database. In some cases a 392 | // nil value will be sent. See section "Notifications" above. 393 | Notify chan *Notification 394 | 395 | name string 396 | minReconnectInterval time.Duration 397 | maxReconnectInterval time.Duration 398 | dialer Dialer 399 | eventCallback EventCallbackType 400 | 401 | lock sync.Mutex 402 | isClosed bool 403 | reconnectCond *sync.Cond 404 | cn *ListenerConn 405 | connNotificationChan <-chan *Notification 406 | channels map[string]struct{} 407 | } 408 | 409 | // NewListener creates a new database connection dedicated to LISTEN / NOTIFY. 410 | // 411 | // name should be set to a connection string to be used to establish the 412 | // database connection (see section "Connection String Parameters" above). 
413 | // 414 | // minReconnectInterval controls the duration to wait before trying to 415 | // re-establish the database connection after connection loss. After each 416 | // consecutive failure this interval is doubled, until maxReconnectInterval is 417 | // reached. Successfully completing the connection establishment procedure 418 | // resets the interval back to minReconnectInterval. 419 | // 420 | // The last parameter eventCallback can be set to a function which will be 421 | // called by the Listener when the state of the underlying database connection 422 | // changes. This callback will be called by the goroutine which dispatches the 423 | // notifications over the Notify channel, so you should try to avoid doing 424 | // potentially time-consuming operations from the callback. 425 | func NewListener(name string, 426 | minReconnectInterval time.Duration, 427 | maxReconnectInterval time.Duration, 428 | eventCallback EventCallbackType) *Listener { 429 | return NewDialListener(defaultDialer{}, name, minReconnectInterval, maxReconnectInterval, eventCallback) 430 | } 431 | 432 | // NewDialListener is like NewListener but it takes a Dialer. 433 | func NewDialListener(d Dialer, 434 | name string, 435 | minReconnectInterval time.Duration, 436 | maxReconnectInterval time.Duration, 437 | eventCallback EventCallbackType) *Listener { 438 | 439 | l := &Listener{ 440 | name: name, 441 | minReconnectInterval: minReconnectInterval, 442 | maxReconnectInterval: maxReconnectInterval, 443 | dialer: d, 444 | eventCallback: eventCallback, 445 | 446 | channels: make(map[string]struct{}), 447 | 448 | Notify: make(chan *Notification, 32), 449 | } 450 | l.reconnectCond = sync.NewCond(&l.lock) 451 | 452 | go l.listenerMain() 453 | 454 | return l 455 | } 456 | 457 | // Returns the notification channel for this listener. This is the same 458 | // channel as Notify, and will not be recreated during the life time of the 459 | // Listener. 
460 | func (l *Listener) NotificationChannel() <-chan *Notification { 461 | return l.Notify 462 | } 463 | 464 | // Listen starts listening for notifications on a channel. Calls to this 465 | // function will block until an acknowledgement has been received from the 466 | // server. Note that Listener automatically re-establishes the connection 467 | // after connection loss, so this function may block indefinitely if the 468 | // connection can not be re-established. 469 | // 470 | // Listen will only fail in three conditions: 471 | // 1) The channel is already open. The returned error will be 472 | // ErrChannelAlreadyOpen. 473 | // 2) The query was executed on the remote server, but PostgreSQL returned an 474 | // error message in response to the query. The returned error will be a 475 | // pq.Error containing the information the server supplied. 476 | // 3) Close is called on the Listener before the request could be completed. 477 | // 478 | // The channel name is case-sensitive. 479 | func (l *Listener) Listen(channel string) error { 480 | l.lock.Lock() 481 | defer l.lock.Unlock() 482 | 483 | if l.isClosed { 484 | return errListenerClosed 485 | } 486 | 487 | // The server allows you to issue a LISTEN on a channel which is already 488 | // open, but it seems useful to be able to detect this case to spot for 489 | // mistakes in application logic. If the application genuinely does't 490 | // care, it can check the exported error and ignore it. 491 | _, exists := l.channels[channel] 492 | if exists { 493 | return ErrChannelAlreadyOpen 494 | } 495 | 496 | if l.cn != nil { 497 | // If gotResponse is true but error is set, the query was executed on 498 | // the remote server, but resulted in an error. This should be 499 | // relatively rare, so it's fine if we just pass the error to our 500 | // caller. 
However, if gotResponse is false, we could not complete the 501 | // query on the remote server and our underlying connection is about 502 | // to go away, so we only add relname to l.channels, and wait for 503 | // resync() to take care of the rest. 504 | gotResponse, err := l.cn.Listen(channel) 505 | if gotResponse && err != nil { 506 | return err 507 | } 508 | } 509 | 510 | l.channels[channel] = struct{}{} 511 | for l.cn == nil { 512 | l.reconnectCond.Wait() 513 | // we let go of the mutex for a while 514 | if l.isClosed { 515 | return errListenerClosed 516 | } 517 | } 518 | 519 | return nil 520 | } 521 | 522 | // Unlisten removes a channel from the Listener's channel list. Returns 523 | // ErrChannelNotOpen if the Listener is not listening on the specified channel. 524 | // Returns immediately with no error if there is no connection. Note that you 525 | // might still get notifications for this channel even after Unlisten has 526 | // returned. 527 | // 528 | // The channel name is case-sensitive. 529 | func (l *Listener) Unlisten(channel string) error { 530 | l.lock.Lock() 531 | defer l.lock.Unlock() 532 | 533 | if l.isClosed { 534 | return errListenerClosed 535 | } 536 | 537 | // Similarly to LISTEN, this is not an error in Postgres, but it seems 538 | // useful to distinguish from the normal conditions. 539 | _, exists := l.channels[channel] 540 | if !exists { 541 | return ErrChannelNotOpen 542 | } 543 | 544 | if l.cn != nil { 545 | // Similarly to Listen (see comment in that function), the caller 546 | // should only be bothered with an error if it came from the backend as 547 | // a response to our query. 548 | gotResponse, err := l.cn.Unlisten(channel) 549 | if gotResponse && err != nil { 550 | return err 551 | } 552 | } 553 | 554 | // Don't bother waiting for resync if there's no connection. 555 | delete(l.channels, channel) 556 | return nil 557 | } 558 | 559 | // UnlistenAll removes all channels from the Listener's channel list. 
Returns 560 | // immediately with no error if there is no connection. Note that you might 561 | // still get notifications for any of the deleted channels even after 562 | // UnlistenAll has returned. 563 | func (l *Listener) UnlistenAll() error { 564 | l.lock.Lock() 565 | defer l.lock.Unlock() 566 | 567 | if l.isClosed { 568 | return errListenerClosed 569 | } 570 | 571 | if l.cn != nil { 572 | // Similarly to Listen (see comment in that function), the caller 573 | // should only be bothered with an error if it came from the backend as 574 | // a response to our query. 575 | gotResponse, err := l.cn.UnlistenAll() 576 | if gotResponse && err != nil { 577 | return err 578 | } 579 | } 580 | 581 | // Don't bother waiting for resync if there's no connection. 582 | l.channels = make(map[string]struct{}) 583 | return nil 584 | } 585 | 586 | // Ping the remote server to make sure it's alive. Non-nil return value means 587 | // that there is no active connection. 588 | func (l *Listener) Ping() error { 589 | l.lock.Lock() 590 | defer l.lock.Unlock() 591 | 592 | if l.isClosed { 593 | return errListenerClosed 594 | } 595 | if l.cn == nil { 596 | return errors.New("no connection") 597 | } 598 | 599 | return l.cn.Ping() 600 | } 601 | 602 | // Clean up after losing the server connection. Returns l.cn.Err(), which 603 | // should have the reason the connection was lost. 604 | func (l *Listener) disconnectCleanup() error { 605 | l.lock.Lock() 606 | defer l.lock.Unlock() 607 | 608 | // sanity check; can't look at Err() until the channel has been closed 609 | select { 610 | case _, ok := <-l.connNotificationChan: 611 | if ok { 612 | panic("connNotificationChan not closed") 613 | } 614 | default: 615 | panic("connNotificationChan not closed") 616 | } 617 | 618 | err := l.cn.Err() 619 | l.cn.Close() 620 | l.cn = nil 621 | return err 622 | } 623 | 624 | // Synchronize the list of channels we want to be listening on with the server 625 | // after the connection has been established. 
626 | func (l *Listener) resync(cn *ListenerConn, notificationChan <-chan *Notification) error { 627 | doneChan := make(chan error) 628 | go func() { 629 | for channel := range l.channels { 630 | // If we got a response, return that error to our caller as it's 631 | // going to be more descriptive than cn.Err(). 632 | gotResponse, err := cn.Listen(channel) 633 | if gotResponse && err != nil { 634 | doneChan <- err 635 | return 636 | } 637 | 638 | // If we couldn't reach the server, wait for notificationChan to 639 | // close and then return the error message from the connection, as 640 | // per ListenerConn's interface. 641 | if err != nil { 642 | for _ = range notificationChan { 643 | } 644 | doneChan <- cn.Err() 645 | return 646 | } 647 | } 648 | doneChan <- nil 649 | }() 650 | 651 | // Ignore notifications while synchronization is going on to avoid 652 | // deadlocks. We have to send a nil notification over Notify anyway as 653 | // we can't possibly know which notifications (if any) were lost while 654 | // the connection was down, so there's no reason to try and process 655 | // these messages at all. 
656 | for { 657 | select { 658 | case _, ok := <-notificationChan: 659 | if !ok { 660 | notificationChan = nil 661 | } 662 | 663 | case err := <-doneChan: 664 | return err 665 | } 666 | } 667 | } 668 | 669 | // caller should NOT be holding l.lock 670 | func (l *Listener) closed() bool { 671 | l.lock.Lock() 672 | defer l.lock.Unlock() 673 | 674 | return l.isClosed 675 | } 676 | 677 | func (l *Listener) connect() error { 678 | notificationChan := make(chan *Notification, 32) 679 | cn, err := newDialListenerConn(l.dialer, l.name, notificationChan) 680 | if err != nil { 681 | return err 682 | } 683 | 684 | l.lock.Lock() 685 | defer l.lock.Unlock() 686 | 687 | err = l.resync(cn, notificationChan) 688 | if err != nil { 689 | cn.Close() 690 | return err 691 | } 692 | 693 | l.cn = cn 694 | l.connNotificationChan = notificationChan 695 | l.reconnectCond.Broadcast() 696 | 697 | return nil 698 | } 699 | 700 | // Close disconnects the Listener from the database and shuts it down. 701 | // Subsequent calls to its methods will return an error. Close returns an 702 | // error if the connection has already been closed. 703 | func (l *Listener) Close() error { 704 | l.lock.Lock() 705 | defer l.lock.Unlock() 706 | 707 | if l.isClosed { 708 | return errListenerClosed 709 | } 710 | 711 | if l.cn != nil { 712 | l.cn.Close() 713 | } 714 | l.isClosed = true 715 | 716 | return nil 717 | } 718 | 719 | func (l *Listener) emitEvent(event ListenerEventType, err error) { 720 | if l.eventCallback != nil { 721 | l.eventCallback(event, err) 722 | } 723 | } 724 | 725 | // Main logic here: maintain a connection to the server when possible, wait 726 | // for notifications and emit events. 
727 | func (l *Listener) listenerConnLoop() { 728 | var nextReconnect time.Time 729 | 730 | reconnectInterval := l.minReconnectInterval 731 | for { 732 | for { 733 | err := l.connect() 734 | if err == nil { 735 | break 736 | } 737 | 738 | if l.closed() { 739 | return 740 | } 741 | l.emitEvent(ListenerEventConnectionAttemptFailed, err) 742 | 743 | time.Sleep(reconnectInterval) 744 | reconnectInterval *= 2 745 | if reconnectInterval > l.maxReconnectInterval { 746 | reconnectInterval = l.maxReconnectInterval 747 | } 748 | } 749 | 750 | if nextReconnect.IsZero() { 751 | l.emitEvent(ListenerEventConnected, nil) 752 | } else { 753 | l.emitEvent(ListenerEventReconnected, nil) 754 | l.Notify <- nil 755 | } 756 | 757 | reconnectInterval = l.minReconnectInterval 758 | nextReconnect = time.Now().Add(reconnectInterval) 759 | 760 | for { 761 | notification, ok := <-l.connNotificationChan 762 | if !ok { 763 | // lost connection, loop again 764 | break 765 | } 766 | l.Notify <- notification 767 | } 768 | 769 | err := l.disconnectCleanup() 770 | if l.closed() { 771 | return 772 | } 773 | l.emitEvent(ListenerEventDisconnected, err) 774 | 775 | time.Sleep(nextReconnect.Sub(time.Now())) 776 | } 777 | } 778 | 779 | func (l *Listener) listenerMain() { 780 | l.listenerConnLoop() 781 | close(l.Notify) 782 | } 783 | --------------------------------------------------------------------------------