├── Makefile ├── util.go ├── AUTHORS ├── examples ├── hellosql │ └── hellosql.go ├── scan │ └── scan.go ├── multipleselects │ └── multipleselects.go ├── statements │ └── statements.go └── pool │ └── pool.go ├── LICENSE ├── README.mdown ├── conn_log.go ├── data └── testdatabase.sql ├── error.go ├── state.go ├── types.go ├── driver.go ├── parameter.go ├── pool.go ├── messagecodes.go ├── conn_write.go ├── statement.go ├── conn_read.go ├── conn.go ├── resultset.go └── pgsql_test.go /Makefile: -------------------------------------------------------------------------------- 1 | include $(GOROOT)/src/Make.inc 2 | 3 | TARG=pgsql 4 | GOFILES=\ 5 | conn.go\ 6 | conn_log.go\ 7 | conn_read.go\ 8 | conn_write.go\ 9 | error.go\ 10 | messagecodes.go\ 11 | parameter.go\ 12 | resultset.go\ 13 | state.go\ 14 | statement.go\ 15 | types.go\ 16 | util.go\ 17 | pool.go 18 | 19 | include $(GOROOT)/src/Make.pkg 20 | -------------------------------------------------------------------------------- /util.go: -------------------------------------------------------------------------------- 1 | // Copyright 2010 The go-pgsql Authors. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 4 | 5 | package pgsql 6 | 7 | // panicIfErr panics with err if err is not nil. 8 | func panicIfErr(err error) { 9 | if err != nil { 10 | panic(err) 11 | } 12 | } 13 | 14 | // panicNotImplemented panics with the message "not implemented". 15 | func panicNotImplemented() { 16 | panic("not implemented") 17 | } 18 | -------------------------------------------------------------------------------- /AUTHORS: -------------------------------------------------------------------------------- 1 | # This is the official list of 'go-pgsql' authors for copyright purposes. 2 | 3 | # Names should be added to this file as 4 | # Name or Organization 5 | # The email address is not required for organizations. 6 | 7 | # Please keep the list sorted.
8 | 9 | # Contributors 10 | # ============ 11 | 12 | Alexander Neumann 13 | Andrew Zeneski 14 | Christopher Browne 15 | Martin Marcher 16 | Samuel Stauffer 17 | Sascha Peilicke 18 | Sergey Shepelev 19 | -------------------------------------------------------------------------------- /examples/hellosql/hellosql.go: -------------------------------------------------------------------------------- 1 | // Copyright 2012 The go-pgsql Authors. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 4 | 5 | package main 6 | 7 | import ( 8 | "database/sql" 9 | "fmt" 10 | "log" 11 | ) 12 | 13 | import ( 14 | _ "github.com/lxn/go-pgsql" 15 | ) 16 | 17 | func main() { 18 | db, err := sql.Open("postgres", "dbname=testdatabase user=testuser password=testpassword") 19 | if err != nil { 20 | log.Fatal(err) 21 | } 22 | 23 | defer db.Close() 24 | 25 | var msg string 26 | 27 | err = db.QueryRow("SELECT $1 || ' ' || $2;", "Hello", "SQL").Scan(&msg) 28 | if err != nil { 29 | log.Fatal(err) 30 | } 31 | 32 | fmt.Println(msg) 33 | } 34 | -------------------------------------------------------------------------------- /examples/scan/scan.go: -------------------------------------------------------------------------------- 1 | // Copyright 2012 The go-pgsql Authors. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file.
4 | 5 | package main 6 | 7 | import ( 8 | "fmt" 9 | "os" 10 | ) 11 | 12 | import ( 13 | "github.com/lxn/go-pgsql" 14 | ) 15 | 16 | type item struct { 17 | id int 18 | name string 19 | price float32 20 | } 21 | 22 | func main() { 23 | // conn, err := pgsql.Connect("dbname=postgres user=cbbrowne port=7099", pgsql.LogError) 24 | // Can have a long connection string, if needed, but if values are set in environment, it's all optional 25 | conn, err := pgsql.Connect("", pgsql.LogError) 26 | if err != nil { 27 | os.Exit(1) 28 | } 29 | defer conn.Close() 30 | 31 | var x item 32 | 33 | _, err = conn.Scan("SELECT 123, 'abc', 14.99;", &x.id, &x.name, &x.price) 34 | if err != nil { 35 | os.Exit(1) 36 | } 37 | 38 | fmt.Printf("item x: '%+v'\n", x) 39 | } 40 | -------------------------------------------------------------------------------- /examples/multipleselects/multipleselects.go: -------------------------------------------------------------------------------- 1 | // Copyright 2010 The go-pgsql Authors. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file.
4 | 5 | package main 6 | 7 | import ( 8 | "fmt" 9 | "os" 10 | ) 11 | 12 | import ( 13 | "github.com/lxn/go-pgsql" 14 | ) 15 | 16 | func main() { 17 | conn, err := pgsql.Connect("dbname=postgres user=cbbrowne port=7099", pgsql.LogError) 18 | 19 | if err != nil { 20 | os.Exit(1) 21 | } 22 | defer conn.Close() 23 | 24 | rs, err := conn.Query("SELECT 1 AS num; SELECT 2 AS num; SELECT 3 AS num;") 25 | if err != nil { 26 | os.Exit(1) 27 | } 28 | defer rs.Close() 29 | 30 | for { 31 | hasRow, err := rs.FetchNext() 32 | if err != nil { 33 | os.Exit(1) 34 | } 35 | if hasRow { 36 | // Exit on error instead of printing an unchecked value. 37 | num, _, err := rs.Int(0) 38 | if err != nil { 39 | os.Exit(1) 40 | } 41 | fmt.Println("num:", num) 42 | } else { 43 | hasResult, err := rs.NextResult() 44 | if err != nil { 45 | os.Exit(1) 46 | } 47 | if !hasResult { 48 | break 49 | } 50 | fmt.Println("next result") 51 | } 52 | } 53 | } 54 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2010 The go-pgsql Authors. All rights reserved. 2 | 3 | Redistribution and use in source and binary forms, with or without 4 | modification, are permitted provided that the following conditions 5 | are met: 6 | 1. Redistributions of source code must retain the above copyright 7 | notice, this list of conditions and the following disclaimer. 8 | 2. Redistributions in binary form must reproduce the above copyright 9 | notice, this list of conditions and the following disclaimer in the 10 | documentation and/or other materials provided with the distribution. 11 | 3. The names of the authors may not be used to endorse or promote products 12 | derived from this software without specific prior written permission. 13 | 14 | THIS SOFTWARE IS PROVIDED BY THE AUTHORS ``AS IS'' AND ANY EXPRESS OR 15 | IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES 16 | OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.
17 | IN NO EVENT SHALL THE AUTHORS BE LIABLE FOR ANY DIRECT, INDIRECT, 18 | INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT 19 | NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 20 | DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 21 | THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 22 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF 23 | THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 24 | -------------------------------------------------------------------------------- /examples/statements/statements.go: -------------------------------------------------------------------------------- 1 | // Copyright 2010 The go-pgsql Authors. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 4 | 5 | package main 6 | 7 | import ( 8 | "fmt" 9 | "os" 10 | ) 11 | 12 | import ( 13 | "github.com/lxn/go-pgsql" 14 | ) 15 | 16 | // queryAndPrintResults executes stmt and prints the stropt column of every 17 | // returned row ("(null)" for NULL); it exits with status 1 on any error. 18 | func queryAndPrintResults(stmt *pgsql.Statement) { 19 | rs, err := stmt.Query() 20 | if err != nil { 21 | os.Exit(1) 22 | } 23 | defer rs.Close() 24 | 25 | stroptOrd := rs.Ordinal("stropt") 26 | 27 | for { 28 | hasRow, err := rs.FetchNext() 29 | if err != nil { 30 | os.Exit(1) 31 | } 32 | if !hasRow { 33 | break 34 | } 35 | 36 | stropt, isNull, err := rs.String(stroptOrd) 37 | if err != nil { 38 | os.Exit(1) 39 | } 40 | if isNull { 41 | stropt = "(null)" 42 | } 43 | fmt.Println("stropt:", stropt) 44 | } 45 | } 46 | 47 | func main() { 48 | conn, err := pgsql.Connect("dbname=postgres user=cbbrowne port=7099", pgsql.LogError) 49 | 50 | if err != nil { 51 | os.Exit(1) 52 | } 53 | defer conn.Close() 54 | 55 | command := "SELECT * FROM table1 WHERE id = @id;" 56 | idParam := pgsql.NewParameter("@id", pgsql.Integer) 57 | 58 | stmt, err := conn.Prepare(command, idParam) 59 | if err != nil { 60 | os.Exit(1) 61 | } 62 | defer stmt.Close() 63 | 64 | for id :=
1; id <= 3; id++ { 65 | err = idParam.SetValue(id) 66 | if err != nil { 67 | os.Exit(1) 68 | } 69 | queryAndPrintResults(stmt) 70 | } 71 | } 72 | -------------------------------------------------------------------------------- /README.mdown: -------------------------------------------------------------------------------- 1 | About go-pgsql 2 | ============== 3 | 4 | go-pgsql is a [PostgreSQL](http://www.postgresql.org) client library for the 5 | [Go](http://golang.org) programming language. 6 | 7 | It partially implements version 3.0 of the 8 | [PostgreSQL](http://www.postgresql.org) frontend/backend protocol, so it 9 | should work with servers of version 7.4 and later. 10 | 11 | It now supports database/sql in addition to its existing interface. 12 | 13 | Installing go-pgsql 14 | =================== 15 | 16 | First make sure you have a working [Go](http://golang.org) installation, see 17 | the installation guide at http://golang.org/doc/install.html 18 | 19 | Now you should be able to install go-pgsql by running 20 | `go get github.com/lxn/go-pgsql` 21 | 22 | Using go-pgsql 23 | ============== 24 | 25 | There are some examples in the 26 | [examples](examples) directory which 27 | should get you started. 28 | 29 | Please open an issue on the bug tracker if you encounter a bug. 30 | 31 | Missing Features 32 | ================ 33 | 34 | go-pgsql is currently missing support for some features, including: 35 | 36 | - authentication types other than MD5 37 | - SSL encrypted sessions 38 | - some data types like bytea, ... 39 | - canceling commands/queries 40 | - bulk copy 41 | - ... 42 | 43 | Connection Info 44 | ================ 45 | 46 | To connect, you must pass a connection string to pgsql.Connect().
47 | Much as with the libpq conninfo parameter, values are optional, and 48 | may be overridden by environment variables at runtime, specifically: 49 | 50 | - PGPORT - port number 51 | - PGUSER - user name (defaults to postgres) 52 | - PGHOST - host name/address (defaults to localhost) 53 | - PGDATABASE - name of database (defaults to user name if not specified) 54 | -------------------------------------------------------------------------------- /conn_log.go: -------------------------------------------------------------------------------- 1 | // Copyright 2010 The go-pgsql Authors. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 4 | 5 | package pgsql 6 | 7 | import ( 8 | "bytes" 9 | "errors" 10 | "fmt" 11 | "log" 12 | "runtime" 13 | ) 14 | 15 | // log writes v to the standard logger. level is not consulted here; callers 16 | // such as logError compare it against conn.LogLevel before calling. 17 | func (conn *Conn) log(level LogLevel, v ...interface{}) { 18 | log.Print(v...) 19 | } 20 | 21 | // logf is the Printf-style counterpart of log; level is likewise ignored. 22 | func (conn *Conn) logf(level LogLevel, format string, v ...interface{}) { 23 | log.Printf(format, v...)
24 | } 25 | 26 | // logError logs err if conn.LogLevel is at least level. 27 | func (conn *Conn) logError(level LogLevel, err error) { 28 | if conn.LogLevel >= level { 29 | conn.log(level, err) 30 | } 31 | } 32 | 33 | // logEnter logs entry into funcName without checking conn.LogLevel and returns 34 | // funcName, so that it can be used as defer conn.logExit(conn.logEnter(name)). 35 | func (conn *Conn) logEnter(funcName string) string { 36 | conn.log(LogDebug, "entering: ", "pgsql."+funcName) 37 | return funcName 38 | } 39 | 40 | // logExit logs exit from funcName without checking conn.LogLevel. 41 | func (conn *Conn) logExit(funcName string) { 42 | conn.log(LogDebug, "exiting: ", "pgsql."+funcName) 43 | } 44 | 45 | // logAndConvertPanic formats x together with a stack trace of the current 46 | // goroutine, logs it if conn.LogLevel is at least LogError, and returns x as 47 | // an error. If x is not an error, the returned error carries the full text. 48 | func (conn *Conn) logAndConvertPanic(x interface{}) (err error) { 49 | buf := bytes.NewBuffer(nil) 50 | 51 | buf.WriteString(fmt.Sprintf("Error: %v\nStack Trace:\n", x)) 52 | buf.WriteString("=======================================================\n") 53 | 54 | i := 0 55 | for { 56 | // NOTE(review): skip=3 presumably omits this function, the deferred recover 57 | // handler that called it and the runtime panic frame — confirm. 58 | pc, file, line, ok := runtime.Caller(i + 3) 59 | if !ok { 60 | break 61 | } 62 | if i > 0 { 63 | buf.WriteString("-------------------------------------------------------\n") 64 | } 65 | 66 | fun := runtime.FuncForPC(pc) 67 | name := fun.Name() 68 | 69 | buf.WriteString(fmt.Sprintf("%s (%s, Line %d)\n", name, file, line)) 70 | 71 | i++ 72 | } 73 | buf.WriteString("=======================================================\n") 74 | 75 | if conn.LogLevel >= LogError { 76 | conn.log(LogError, buf) 77 | } 78 | 79 | err, ok := x.(error) 80 | if !ok { 81 | err = errors.New(buf.String()) 82 | } 83 | 84 | return 85 | } 86 | -------------------------------------------------------------------------------- /examples/pool/pool.go: -------------------------------------------------------------------------------- 1 | // Copyright 2011 The go-pgsql Authors. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 4 | 5 | package main 6 | 7 | import ( 8 | "fmt" 9 | "log" 10 | "sync" 11 | ) 12 | 13 | import ( 14 | "github.com/lxn/go-pgsql" 15 | ) 16 | 17 | func main() { 18 | // Create a connection pool with up to 3 connections, automatically closing 19 | // idle connections after the default timeout period (5 minutes).
20 | pool, err := pgsql.NewPool("dbname=postgres user=postgres", 3, 3, pgsql.DEFAULT_IDLE_TIMEOUT) 21 | if err != nil { 22 | log.Fatalf("Error opening connection pool: %s\n", err) 23 | } 24 | pool.Debug = true 25 | 26 | // Create 10 worker goroutines each of which acquires and uses a 27 | // connection from the pool. 28 | var wg sync.WaitGroup 29 | nthreads := 10 30 | wg.Add(nthreads) 31 | for i := 0; i < nthreads; i++ { 32 | go worker(i+1, pool, &wg) 33 | } 34 | wg.Wait() // Wait for all the workers to finish. 35 | pool.Close() // Close all pool connections. 36 | } 37 | 38 | // worker acquires a connection from pool, queries the server's current time 39 | // and prints it, then returns the connection to the pool. It calls wg.Done 40 | // when it returns, whether or not it succeeded. 41 | func worker(id int, pool *pgsql.Pool, wg *sync.WaitGroup) { 42 | defer wg.Done() 43 | 44 | conn, err := pool.Acquire() 45 | if err != nil { 46 | log.Printf("Error acquiring connection: %s\n", err) 47 | return 48 | } 49 | // Return the connection back to the pool; only a successfully acquired 50 | // connection is released. 51 | defer pool.Release(conn) 52 | 53 | res, err := conn.Query("SELECT now();") 54 | if err != nil { 55 | log.Printf("Error executing query: %s\n", err) 56 | return 57 | } 58 | // Deferred calls run last-in first-out, so the result set is closed 59 | // before the connection goes back to the pool. 60 | defer res.Close() 61 | 62 | hasRow, err := res.FetchNext() 63 | if err != nil { 64 | log.Printf("Error fetching row: %s\n", err) 65 | return 66 | } 67 | if !hasRow { 68 | log.Println("Couldn't advance result cursor") 69 | return 70 | } 71 | 72 | var now string 73 | if err := res.Scan(&now); err != nil { 74 | log.Printf("Error scanning result: %s\n", err) 75 | return 76 | } 77 | fmt.Printf("Timestamp returned for worker %d: %s\n", id, now) 78 | } 79 |
-------------------------------------------------------------------------------- /data/testdatabase.sql: -------------------------------------------------------------------------------- 1 | -- 2 | -- PostgreSQL database dump 3 | -- 4 | 5 | SET statement_timeout = 0; 6 | SET client_encoding = 'UTF8'; 7 | SET standard_conforming_strings = off; 8 | SET check_function_bodies = false; 9 | SET client_min_messages = warning; 10 | SET escape_string_warning = off; 11 | 12 | SET search_path = public, pg_catalog; 13 | 14 | SET default_tablespace = ''; 15 | 16 | SET default_with_oids = false; 17 | 18 | -- 19 | -- Name: table1; Type: TABLE; Schema: public; Owner: testuser; Tablespace: 20 | -- 21 | 22 | CREATE TABLE table1 ( 23 | id integer NOT NULL, 24 | strreq character varying(20) NOT NULL, 25 | stropt character varying(20), 26 | blnreq boolean NOT NULL, 27 | i32req integer NOT NULL 28 | ); 29 | 30 | 31 | ALTER TABLE public.table1 OWNER TO testuser; 32 | 33 | -- 34 | -- Name: table1_id_seq; Type: SEQUENCE; Schema: public; Owner: testuser 35 | -- 36 | 37 | CREATE SEQUENCE table1_id_seq 38 | START WITH 1 39 | INCREMENT BY 1 40 | NO MAXVALUE 41 | NO MINVALUE 42 | CACHE 1; 43 | 44 | 45 | ALTER TABLE public.table1_id_seq OWNER TO testuser; 46 | 47 | -- 48 | -- Name: table1_id_seq; Type: SEQUENCE OWNED BY; Schema: public; Owner: testuser 49 | -- 50 | 51 | ALTER SEQUENCE table1_id_seq OWNED BY table1.id; 52 | 53 | 54 | -- 55 | -- Name: table1_id_seq; Type: SEQUENCE SET; Schema: public; Owner: testuser 56 | -- 57 | 58 | SELECT pg_catalog.setval('table1_id_seq', 3, true); 59 | 60 | 61 | -- 62 | -- Name: id; Type: DEFAULT; Schema: public; Owner: testuser 63 | -- 64 | 65 | ALTER TABLE table1 ALTER COLUMN id SET DEFAULT nextval('table1_id_seq'::regclass); 66 | 67 | 68 | -- 69 | -- Data for Name: table1; Type: TABLE DATA; Schema: public; Owner: testuser 70 | -- 71 | 72 | COPY table1 (id, strreq, stropt, blnreq, i32req) FROM stdin; 73 | 1
foo bar t 1234567890 74 | 2 baz t 5432 75 | 3 ※‣⁈ \N f -987654321 76 | \. 77 | 78 | 79 | -- 80 | -- Name: table1_pkey; Type: CONSTRAINT; Schema: public; Owner: testuser; Tablespace: 81 | -- 82 | 83 | ALTER TABLE ONLY table1 84 | ADD CONSTRAINT table1_pkey PRIMARY KEY (id); 85 | 86 | 87 | -- 88 | -- Name: public; Type: ACL; Schema: -; Owner: postgres 89 | -- 90 | 91 | REVOKE ALL ON SCHEMA public FROM PUBLIC; 92 | REVOKE ALL ON SCHEMA public FROM postgres; 93 | GRANT ALL ON SCHEMA public TO postgres; 94 | GRANT ALL ON SCHEMA public TO PUBLIC; 95 | 96 | 97 | -- 98 | -- PostgreSQL database dump complete 99 | -- 100 | 101 | -------------------------------------------------------------------------------- /error.go: -------------------------------------------------------------------------------- 1 | // Copyright 2010 The go-pgsql Authors. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 4 | 5 | package pgsql 6 | 7 | import ( 8 | "fmt" 9 | ) 10 | 11 | // Error contains detailed error information received from a PostgreSQL backend. 12 | // 13 | // Many go-pgsql functions return an error value. In case of a backend error, 14 | // a type assertion as shown below gives you a *pgsql.Error with all details: 15 | // 16 | // ... 17 | // _, err := rs.FetchNext() 18 | // if err != nil { 19 | // if pgerr, ok := err.(*pgsql.Error); ok { 20 | // // Do something with pgerr 21 | // } 22 | // } 23 | // ...
24 | type Error struct { 25 | severity string 26 | code string 27 | message string 28 | detail string 29 | hint string 30 | position string 31 | internalPosition string 32 | internalQuery string 33 | where string 34 | file string 35 | line string 36 | routine string 37 | } 38 | 39 | func (e *Error) Severity() string { 40 | return e.severity 41 | } 42 | 43 | func (e *Error) Code() string { 44 | return e.code 45 | } 46 | 47 | func (e *Error) Message() string { 48 | return e.message 49 | } 50 | 51 | func (e *Error) Detail() string { 52 | return e.detail 53 | } 54 | 55 | func (e *Error) Hint() string { 56 | return e.hint 57 | } 58 | 59 | func (e *Error) Position() string { 60 | return e.position 61 | } 62 | 63 | func (e *Error) InternalPosition() string { 64 | return e.internalPosition 65 | } 66 | 67 | func (e *Error) InternalQuery() string { 68 | return e.internalQuery 69 | } 70 | 71 | func (e *Error) Where() string { 72 | return e.where 73 | } 74 | 75 | func (e *Error) File() string { 76 | return e.file 77 | } 78 | 79 | func (e *Error) Line() string { 80 | return e.line 81 | } 82 | 83 | func (e *Error) Routine() string { 84 | return e.routine 85 | } 86 | 87 | // Error implements the error interface, listing each field on its own line. 88 | func (e *Error) Error() string { 89 | return fmt.Sprintf( 90 | `Severity: %s 91 | Code: %s 92 | Message: %s 93 | Detail: %s 94 | Hint: %s 95 | Position: %s 96 | Internal Position: %s 97 | Internal Query: %s 98 | Where: %s 99 | File: %s 100 | Line: %s 101 | Routine: %s`, 102 | e.severity, e.code, e.message, e.detail, e.hint, e.position, 103 | e.internalPosition, e.internalQuery, e.where, e.file, e.line, e.routine) 104 | } 105 | -------------------------------------------------------------------------------- /state.go: -------------------------------------------------------------------------------- 1 | // Copyright 2010 The go-pgsql Authors. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file.
4 | 5 | package pgsql 6 | 7 | const invalidOpForStateMsg = "invalid operation for this state" 8 | 9 | // state is the interface that all states must implement. 10 | type state interface { 11 | // code returns the ConnStatus that matches the state. 12 | code() ConnStatus 13 | 14 | // execute sends Bind, Describe, Execute and Sync packets to the server. 15 | execute(stmt *Statement, rs *ResultSet) 16 | 17 | // flush sends a Flush packet to the server. 18 | flush(conn *Conn) 19 | 20 | // prepare sends a Parse packet to the server. 21 | prepare(stmt *Statement) 22 | 23 | // query sends a Query packet to the server. 24 | query(conn *Conn, rs *ResultSet, sql string) 25 | } 26 | 27 | // abstractState can be embedded in any real state struct, so it satisfies 28 | // the state interface without implementing all state methods itself. 29 | type abstractState struct{} 30 | 31 | func (abstractState) execute(stmt *Statement, rs *ResultSet) { 32 | panic(invalidOpForStateMsg) 33 | } 34 | 35 | func (abstractState) flush(conn *Conn) { 36 | panic(invalidOpForStateMsg) 37 | } 38 | 39 | func (abstractState) prepare(stmt *Statement) { 40 | panic(invalidOpForStateMsg) 41 | } 42 | 43 | func (abstractState) query(conn *Conn, rs *ResultSet, sql string) { 44 | panic(invalidOpForStateMsg) 45 | } 46 | 47 | // copyState is the state that is active when the connection is used 48 | // to exchange CopyData messages for bulk import/export. 49 | type copyState struct { 50 | abstractState 51 | } 52 | 53 | func (copyState) code() ConnStatus { 54 | return StatusCopy 55 | } 56 | 57 | // disconnectedState is the initial state before a connection is established. 58 | type disconnectedState struct { 59 | abstractState 60 | } 61 | 62 | func (disconnectedState) code() ConnStatus { 63 | return StatusDisconnected 64 | } 65 | 66 | // processingQueryState is the state that is active when 67 | // the results of a query are being processed.
68 | type processingQueryState struct { 69 | abstractState 70 | } 71 | 72 | func (processingQueryState) code() ConnStatus { 73 | return StatusProcessingQuery 74 | } 75 | 76 | // readyState is the state that is active when the connection to the 77 | // PostgreSQL server is ready for queries. 78 | type readyState struct { 79 | abstractState 80 | } 81 | 82 | func (readyState) code() ConnStatus { 83 | return StatusReady 84 | } 85 | 86 | // execute writes Bind for stmt and reads the replies into rs, does the same 87 | // for Describe, then writes Execute and Sync and switches to processingQueryState. 88 | // If it panics before then, the deferred func writes Sync and drains replies. 89 | func (readyState) execute(stmt *Statement, rs *ResultSet) { 90 | conn := stmt.conn 91 | 92 | if conn.LogLevel >= LogDebug { 93 | defer conn.logExit(conn.logEnter("readyState.execute")) 94 | } 95 | 96 | succeeded := false 97 | conn.onErrorDontRequireReadyForQuery = true 98 | defer func() { 99 | conn.onErrorDontRequireReadyForQuery = false 100 | 101 | if !succeeded { 102 | conn.writeSync() 103 | 104 | conn.readBackendMessages(nil) 105 | } 106 | }() 107 | 108 | conn.writeBind(stmt) 109 | 110 | conn.readBackendMessages(rs) 111 | 112 | conn.writeDescribe(stmt) 113 | 114 | conn.readBackendMessages(rs) 115 | 116 | conn.writeExecute(stmt) 117 | 118 | conn.writeSync() 119 | 120 | conn.state = processingQueryState{} 121 | 122 | succeeded = true 123 | } 124 | 125 | // prepare writes a Parse message for stmt and reads the backend's replies. 126 | func (readyState) prepare(stmt *Statement) { 127 | conn := stmt.conn 128 | 129 | if conn.LogLevel >= LogDebug { 130 | defer conn.logExit(conn.logEnter("readyState.prepare")) 131 | } 132 | 133 | conn.writeParse(stmt) 134 | 135 | conn.onErrorDontRequireReadyForQuery = true 136 | defer func() { conn.onErrorDontRequireReadyForQuery = false }() 137 | 138 | conn.readBackendMessages(nil) 139 | } 140 | 141 | // query writes command as a Query message, reads the backend's replies into rs 142 | // and switches the connection to processingQueryState. 143 | func (readyState) query(conn *Conn, rs *ResultSet, command string) { 144 | if conn.LogLevel >= LogDebug { 145 | defer conn.logExit(conn.logEnter("readyState.query")) 146 | } 147 | 148 | conn.writeQuery(command) 149 | 150 | conn.readBackendMessages(rs) 151 | 152 | conn.state = processingQueryState{} 153 | } 154 | --------------------------------------------------------------------------------
/types.go: -------------------------------------------------------------------------------- 1 | // Copyright 2010 The go-pgsql Authors. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 4 | 5 | package pgsql 6 | 7 | const ( 8 | _BOOLOID = 16 9 | _BYTEAOID = 17 10 | _CHAROID = 18 11 | _NAMEOID = 19 12 | _INT8OID = 20 13 | _INT2OID = 21 14 | _INT2VECTOROID = 22 15 | _INT4OID = 23 16 | _REGPROCOID = 24 17 | _TEXTOID = 25 18 | _OIDOID = 26 19 | _TIDOID = 27 20 | _XIDOID = 28 21 | _CIDOID = 29 22 | _OIDVECTOROID = 30 23 | _XMLOID = 142 24 | _POINTOID = 600 25 | _LSEGOID = 601 26 | _PATHOID = 602 27 | _BOXOID = 603 28 | _POLYGONOID = 604 29 | _LINEOID = 628 30 | _FLOAT4OID = 700 31 | _FLOAT8OID = 701 32 | _ABSTIMEOID = 702 33 | _RELTIMEOID = 703 34 | _TINTERVALOID = 704 35 | _UNKNOWNOID = 705 36 | _CIRCLEOID = 718 37 | _CASHOID = 790 38 | _MACADDROID = 829 39 | _INETOID = 869 40 | _CIDROID = 650 41 | _INT4ARRAYOID = 1007 42 | _TEXTARRAYOID = 1009 43 | _FLOAT4ARRAYOID = 1021 44 | _ACLITEMOID = 1033 45 | _CSTRINGARRAYOID = 1263 46 | _BPCHAROID = 1042 47 | _VARCHAROID = 1043 48 | _DATEOID = 1082 49 | _TIMEOID = 1083 50 | _TIMESTAMPOID = 1114 51 | _TIMESTAMPTZOID = 1184 52 | _INTERVALOID = 1186 53 | _TIMETZOID = 1266 54 | _BITOID = 1560 55 | _VARBITOID = 1562 56 | _NUMERICOID = 1700 57 | _REFCURSOROID = 1790 58 | _REGPROCEDUREOID = 2202 59 | _REGOPEROID = 2203 60 | _REGOPERATOROID = 2204 61 | _REGCLASSOID = 2205 62 | _REGTYPEOID = 2206 63 | _REGTYPEARRAYOID = 2211 64 | _TSVECTOROID = 3614 65 | _GTSVECTOROID = 3642 66 | _TSQUERYOID = 3615 67 | _REGCONFIGOID = 3734 68 | _REGDICTIONARYOID = 3769 69 | _RECORDOID = 2249 70 | _RECORDARRAYOID = 2287 71 | _CSTRINGOID = 2275 72 | _ANYOID = 2276 73 | _ANYARRAYOID = 2277 74 | _VOIDOID = 2278 75 | _TRIGGEROID = 2279 76 | _LANGUAGE_HANDLEROID = 2280 77 | _INTERNALOID = 2281 78 | _OPAQUEOID = 2282 79 | _ANYELEMENTOID = 2283 80 | _ANYNONARRAYOID = 2776 81 |
_ANYENUMOID = 3500 82 | ) 83 | 84 | // Type represents the PostgreSQL data type of fields and parameters. 85 | type Type int32 86 | 87 | const ( 88 | Custom Type = 0 89 | Boolean Type = _BOOLOID 90 | Char Type = _CHAROID 91 | Date Type = _DATEOID 92 | Real Type = _FLOAT4OID 93 | Double Type = _FLOAT8OID 94 | Smallint Type = _INT2OID 95 | Integer Type = _INT4OID 96 | Bigint Type = _INT8OID 97 | Numeric Type = _NUMERICOID 98 | Text Type = _TEXTOID 99 | Time Type = _TIMEOID 100 | TimeTZ Type = _TIMETZOID 101 | Timestamp Type = _TIMESTAMPOID 102 | TimestampTZ Type = _TIMESTAMPTZOID 103 | Varchar Type = _VARCHAROID 104 | ) 105 | 106 | // String returns the name of the Type constant t, such as "Integer", or 107 | // "Unknown" for values that have no named Type constant. 108 | func (t Type) String() string { 109 | switch t { 110 | case Boolean: 111 | return "Boolean" 112 | 113 | case Char: 114 | return "Char" 115 | 116 | case Custom: 117 | return "Custom" 118 | 119 | case Date: 120 | return "Date" 121 | 122 | case Real: 123 | return "Real" 124 | 125 | case Double: 126 | return "Double" 127 | 128 | case Smallint: 129 | return "Smallint" 130 | 131 | case Integer: 132 | return "Integer" 133 | 134 | case Bigint: 135 | return "Bigint" 136 | 137 | case Numeric: 138 | return "Numeric" 139 | 140 | case Text: 141 | return "Text" 142 | 143 | case Time: 144 | return "Time" 145 | 146 | case TimeTZ: 147 | return "TimeTZ" 148 | 149 | case Timestamp: 150 | return "Timestamp" 151 | 152 | case TimestampTZ: 153 | return "TimestampTZ" 154 | 155 | case Varchar: 156 | return "Varchar" 157 | } 158 | 159 | return "Unknown" 160 | } 161 | -------------------------------------------------------------------------------- /driver.go: -------------------------------------------------------------------------------- 1 | // Copyright 2012 The go-pgsql Authors. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file.
4 | 5 | package pgsql 6 | 7 | import ( 8 | "database/sql" 9 | "database/sql/driver" 10 | "fmt" 11 | "io" 12 | "time" 13 | ) 14 | 15 | func init() { 16 | sql.Register("postgres", sqlDriver{}) 17 | } 18 | 19 | type sqlDriver struct { 20 | } 21 | 22 | func (sqlDriver) Open(name string) (driver.Conn, error) { 23 | conn, err := Connect(name, LogNothing) 24 | if err != nil { 25 | return nil, err 26 | } 27 | 28 | return &sqlConn{conn}, nil 29 | } 30 | 31 | type sqlConn struct { 32 | conn *Conn 33 | } 34 | 35 | func (c *sqlConn) Exec(query string, args []driver.Value) (driver.Result, error) { 36 | n, err := c.conn.Execute(query, paramsFromValues(nil, args)...) 37 | if err != nil { 38 | return nil, err 39 | } 40 | 41 | return driver.RowsAffected(n), nil 42 | } 43 | 44 | func (c *sqlConn) Prepare(query string) (driver.Stmt, error) { 45 | stmt, err := c.conn.Prepare(query) 46 | if err != nil { 47 | return nil, err 48 | } 49 | 50 | return &sqlStmt{stmt}, nil 51 | } 52 | 53 | func (c *sqlConn) Close() error { 54 | return c.conn.Close() 55 | } 56 | 57 | func (c *sqlConn) Begin() (driver.Tx, error) { 58 | if _, err := c.conn.Execute("BEGIN;"); err != nil { 59 | return nil, err 60 | } 61 | 62 | return &sqlTx{c.conn}, nil 63 | } 64 | 65 | type sqlStmt struct { 66 | stmt *Statement 67 | } 68 | 69 | func (s *sqlStmt) Close() error { 70 | return s.stmt.Close() 71 | } 72 | 73 | func (s *sqlStmt) NumInput() int { 74 | return -1 75 | } 76 | 77 | func (s *sqlStmt) Exec(args []driver.Value) (driver.Result, error) { 78 | s.stmt.params = paramsFromValues(s.stmt.params, args) 79 | 80 | n, err := s.stmt.Execute() 81 | if err != nil { 82 | return nil, err 83 | } 84 | 85 | return driver.RowsAffected(n), nil 86 | } 87 | 88 | func (s *sqlStmt) Query(args []driver.Value) (driver.Rows, error) { 89 | s.stmt.params = paramsFromValues(s.stmt.params, args) 90 | 91 | rs, err := s.stmt.Query() 92 | if err != nil { 93 | return nil, err 94 | } 95 | 96 | return &sqlRows{rs}, nil 97 | } 98 | 99 | type sqlTx
struct { 100 | conn *Conn 101 | } 102 | 103 | func (t *sqlTx) Commit() error { 104 | _, err := t.conn.Execute("COMMIT;") 105 | 106 | return err 107 | } 108 | 109 | func (t *sqlTx) Rollback() error { 110 | _, err := t.conn.Execute("ROLLBACK;") 111 | 112 | return err 113 | } 114 | 115 | type sqlRows struct { 116 | rs *ResultSet 117 | } 118 | 119 | func (r *sqlRows) Columns() []string { 120 | names := make([]string, len(r.rs.fields)) 121 | 122 | for i, f := range r.rs.fields { 123 | names[i] = f.name 124 | } 125 | 126 | return names 127 | } 128 | 129 | func (r *sqlRows) Close() error { 130 | return r.rs.Close() 131 | } 132 | 133 | func (r *sqlRows) Next(dest []driver.Value) error { 134 | fetched, err := r.rs.FetchNext() 135 | if err != nil { 136 | return err 137 | } 138 | 139 | if !fetched { 140 | return io.EOF 141 | } 142 | 143 | for i := range dest { 144 | val, isNull, err := r.rs.Any(i) 145 | if err != nil { 146 | return err 147 | } 148 | 149 | if isNull { 150 | val = nil 151 | } else { 152 | switch v := val.(type) { 153 | case float32: 154 | val = float64(v) 155 | 156 | case int: 157 | val = int64(v) 158 | 159 | case int8: 160 | val = int64(v) 161 | 162 | case int16: 163 | val = int64(v) 164 | 165 | case int32: 166 | val = int64(v) 167 | 168 | case string: 169 | val = ([]byte)(v) 170 | } 171 | } 172 | 173 | dest[i] = val 174 | } 175 | 176 | return nil 177 | } 178 | 179 | // paramsFromValues stores vals in params and returns the resulting slice. 180 | // If params is shorter than vals, a fresh slice is allocated and the old 181 | // entries are discarded. nil entries are filled with new parameters named 182 | // "$0", "$1", ... (zero-based index) whose type follows the value's dynamic 183 | // type; unrecognized types panic. Errors returned by SetValue are ignored. 184 | // NOTE(review): PostgreSQL placeholders are 1-based ($1, $2, ...); confirm how 185 | // these zero-based names are matched against the query text. 186 | func paramsFromValues(params []*Parameter, vals []driver.Value) []*Parameter { 187 | if len(params) < len(vals) { 188 | params = make([]*Parameter, len(vals)) 189 | } 190 | 191 | for i, val := range vals { 192 | p := params[i] 193 | 194 | if p == nil { 195 | var typ Type 196 | 197 | switch val.(type) { 198 | case nil: 199 | typ = Integer 200 | 201 | case bool: 202 | typ = Boolean 203 | 204 | case []byte, string: 205 | typ = Varchar 206 | 207 | case float64: 208 | typ = Double 209 | 210 | case int64: 211 | typ = Bigint 212 | 213 | case time.Time: 214 | typ = TimestampTZ 215 | 216 | default: 217 |
panic("unexpected value type") 218 | } 219 | 220 | p = NewParameter(fmt.Sprintf("$%d", i), typ) 221 | params[i] = p 222 | } 223 | 224 | p.SetValue(val) 225 | } 226 | 227 | return params 228 | } 229 | -------------------------------------------------------------------------------- /parameter.go: -------------------------------------------------------------------------------- 1 | // Copyright 2010 The go-pgsql Authors. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 4 | 5 | package pgsql 6 | 7 | import ( 8 | "errors" 9 | "fmt" 10 | "math/big" 11 | "reflect" 12 | "time" 13 | ) 14 | 15 | // Parameter is used to set the value of a parameter in a Statement. 16 | type Parameter struct { 17 | name string 18 | stmt *Statement 19 | typ Type 20 | customTypeName string 21 | value interface{} 22 | } 23 | 24 | // NewParameter returns a new Parameter with the specified name and type. 25 | func NewParameter(name string, typ Type) *Parameter { 26 | return &Parameter{name: name, typ: typ} 27 | } 28 | 29 | // NewCustomTypeParameter returns a new Parameter with the specified name and 30 | // custom data type. 31 | // 32 | // The value of customTypeName will be used to insert a type cast into the 33 | // command text for each occurrence of the parameter. 34 | // 35 | // This constructor can be used for enum type parameters. In that case the value 36 | // provided to SetValue is expected to be a string. 37 | func NewCustomTypeParameter(name, customTypeName string) *Parameter { 38 | return &Parameter{name: name, customTypeName: customTypeName} 39 | } 40 | 41 | // CustomTypeName returns the custom type name of the Parameter. 42 | func (p *Parameter) CustomTypeName() string { 43 | return p.customTypeName 44 | } 45 | 46 | // Name returns the name of the Parameter.
47 | func (p *Parameter) Name() string { 48 | return p.name 49 | } 50 | 51 | // Statement returns the *Statement this Parameter is associated with. 52 | func (p *Parameter) Statement() *Statement { 53 | return p.stmt 54 | } 55 | 56 | // Type returns the PostgreSQL data type of the Parameter. 57 | func (p *Parameter) Type() Type { 58 | return p.typ 59 | } 60 | 61 | // Value returns the current value of the Parameter. 62 | func (p *Parameter) Value() interface{} { 63 | return p.value 64 | } 65 | 66 | func (p *Parameter) panicInvalidValue(v interface{}) { 67 | panic(errors.New(fmt.Sprintf("Parameter %s: Invalid value for PostgreSQL type %s: '%v' (Go type: %T)", 68 | p.name, p.typ, v, v))) 69 | } 70 | 71 | func isNilPtr(v interface{}) bool { 72 | ptr := reflect.ValueOf(v) 73 | 74 | return ptr.Kind() == reflect.Ptr && 75 | ptr.IsNil() 76 | } 77 | 78 | // SetValue sets the current value of the Parameter. 79 | func (p *Parameter) SetValue(v interface{}) (err error) { 80 | if p.stmt != nil && p.stmt.conn.LogLevel >= LogVerbose { 81 | defer p.stmt.conn.logExit(p.stmt.conn.logEnter("*Parameter.SetValue")) 82 | } 83 | 84 | defer func() { 85 | if x := recover(); x != nil { 86 | if p.stmt == nil { 87 | switch ex := x.(type) { 88 | case error: 89 | err = ex 90 | 91 | case string: 92 | err = errors.New(ex) 93 | 94 | default: 95 | err = errors.New("pgsql.*Parameter.SetValue: D'oh!") 96 | } 97 | } else { 98 | err = p.stmt.conn.logAndConvertPanic(x) 99 | } 100 | } 101 | }() 102 | 103 | if v == nil { 104 | p.value = nil 105 | return 106 | } 107 | 108 | switch p.typ { 109 | case Bigint: 110 | switch val := v.(type) { 111 | case byte: 112 | p.value = int64(val) 113 | 114 | case int: 115 | p.value = int64(val) 116 | 117 | case int16: 118 | p.value = int64(val) 119 | 120 | case int32: 121 | p.value = int64(val) 122 | 123 | case uint: 124 | p.value = int64(val) 125 | 126 | case uint16: 127 | p.value = int64(val) 128 | 129 | case uint32: 130 | p.value = int64(val) 131 | 132 | case uint64: 
133 | p.value = int64(val) 134 | 135 | case int64: 136 | p.value = val 137 | 138 | default: 139 | p.panicInvalidValue(v) 140 | } 141 | 142 | case Boolean: 143 | val, ok := v.(bool) 144 | if !ok { 145 | p.panicInvalidValue(v) 146 | } 147 | p.value = val 148 | 149 | case Char, Text, Varchar: 150 | val, ok := v.(string) 151 | if !ok { 152 | p.panicInvalidValue(v) 153 | } 154 | p.value = val 155 | 156 | case Custom: 157 | p.value = v 158 | 159 | case Date, Time, TimeTZ, Timestamp, TimestampTZ: 160 | switch val := v.(type) { 161 | case int64: 162 | p.value = val 163 | 164 | case time.Time: 165 | if isNilPtr(v) { 166 | p.value = nil 167 | return 168 | } 169 | 170 | p.value = val 171 | 172 | case uint64: 173 | p.value = val 174 | 175 | default: 176 | p.panicInvalidValue(v) 177 | } 178 | 179 | case Double: 180 | switch val := v.(type) { 181 | case float32: 182 | p.value = float64(val) 183 | 184 | case float64: 185 | p.value = val 186 | 187 | default: 188 | p.panicInvalidValue(v) 189 | } 190 | 191 | case Integer: 192 | switch val := v.(type) { 193 | case byte: 194 | p.value = int32(val) 195 | 196 | case int: 197 | p.value = int32(val) 198 | 199 | case int16: 200 | p.value = int32(val) 201 | 202 | case uint: 203 | p.value = int32(val) 204 | 205 | case uint16: 206 | p.value = int32(val) 207 | 208 | case uint32: 209 | p.value = int32(val) 210 | 211 | case int32: 212 | p.value = val 213 | 214 | default: 215 | p.panicInvalidValue(v) 216 | } 217 | 218 | case Numeric: 219 | val, ok := v.(*big.Rat) 220 | if !ok { 221 | p.panicInvalidValue(v) 222 | } 223 | 224 | if isNilPtr(v) { 225 | p.value = nil 226 | return 227 | } 228 | 229 | p.value = val 230 | 231 | case Real: 232 | switch val := v.(type) { 233 | case float32: 234 | p.value = val 235 | 236 | default: 237 | p.panicInvalidValue(v) 238 | } 239 | 240 | case Smallint: 241 | switch val := v.(type) { 242 | case byte: 243 | p.value = int16(val) 244 | 245 | case uint16: 246 | p.value = int16(val) 247 | 248 | case int16: 249 | p.value 
= val 250 | 251 | default: 252 | p.panicInvalidValue(v) 253 | } 254 | } 255 | 256 | return 257 | } 258 | -------------------------------------------------------------------------------- /pool.go: -------------------------------------------------------------------------------- 1 | // Copyright 2011 The go-pgsql Authors. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 4 | 5 | package pgsql 6 | 7 | import ( 8 | "bufio" 9 | "container/list" 10 | "errors" 11 | "fmt" 12 | "log" 13 | "runtime" 14 | "sync" 15 | "time" 16 | ) 17 | 18 | const DEFAULT_IDLE_TIMEOUT = 300 // Seconds 19 | 20 | type poolConn struct { 21 | *Conn 22 | atime time.Time // Time at which Conn is inserted into free list 23 | } 24 | 25 | type pool struct { 26 | params string // Params to create new Conn 27 | conns *list.List // List of available Conns 28 | max int // Maximum number of connections to create 29 | min int // min number of connections to create 30 | n int // Number of connections created 31 | cond *sync.Cond // Pool lock, and condition to signal when connection is released 32 | timeout time.Duration // Idle timeout period in seconds 33 | closed bool 34 | Debug bool // Set to true to print debug messages to stderr 35 | } 36 | 37 | func (p *pool) log(msg string) { 38 | if p.Debug { 39 | log.Println(time.Now().Format("2006-01-02 15:04:05"), msg) 40 | } 41 | } 42 | 43 | // A Pool manages a list of connections that can be safely used by multiple goroutines. 44 | type Pool struct { 45 | // Subtle: Embed *pool struct so that timeoutCloser can operate on *pool 46 | // without preventing *Pool being garbage collected (and properly finalized). 47 | // See http://groups.google.com/group/golang-nuts/browse_thread/thread/d48b4d38e8fcc96f for discussion. 48 | *pool 49 | } 50 | 51 | // Close connections that have been idle for > p.timeout seconds. 
52 | func timeoutCloser(p *pool) { 53 | for p != nil && !p.closed { 54 | p.cond.L.Lock() 55 | now := time.Now() 56 | delay := p.timeout 57 | for p.conns.Len() > 0 { 58 | front := p.conns.Front() 59 | pc := front.Value.(poolConn) 60 | atime := pc.atime 61 | if (now.Sub(atime)) > p.timeout { 62 | pc.Conn.Close() 63 | p.conns.Remove(front) 64 | p.n-- 65 | p.log("idle connection closed") 66 | } else { 67 | // Wait until first connection would timeout if it isn't used. 68 | delay = p.timeout - now.Sub(atime) + 1 69 | break 70 | } 71 | } 72 | // don't let the pool fall below the min 73 | for i := p.n; i < p.min; i++ { 74 | c, err := Connect(p.params, LogError) 75 | if err != nil { 76 | p.log("can't create connection") 77 | } else { 78 | p.conns.PushFront(poolConn{c, time.Now()}) 79 | p.n++ 80 | } 81 | } 82 | p.cond.L.Unlock() 83 | time.Sleep(delay * time.Second) 84 | } 85 | p.log("timeoutCloser finished") 86 | } 87 | 88 | // NewPool returns a new Pool that will create new connections on demand 89 | // using connectParams, up to a maximum of maxConns outstanding connections. 90 | // An error is returned if an initial connection cannot be created. 91 | // Connections that have been idle for idleTimeout seconds will be automatically 92 | // closed. 93 | func NewPool(connectParams string, minConns, maxConns int, idleTimeout time.Duration) (p *Pool, err error) { 94 | if minConns < 1 { 95 | return nil, errors.New("minConns must be >= 1") 96 | } 97 | if maxConns < 1 { 98 | return nil, errors.New("maxConns must be >= 1") 99 | } 100 | if idleTimeout < 5 { 101 | return nil, errors.New("idleTimeout must be >= 5") 102 | } 103 | 104 | // Create initial connection to verify connectParams will work. 
105 | c, err := Connect(connectParams, LogError) 106 | if err != nil { 107 | return 108 | } 109 | p = &Pool{ 110 | &pool{ 111 | params: connectParams, 112 | conns: list.New(), 113 | max: maxConns, 114 | min: minConns, 115 | n: 1, 116 | cond: sync.NewCond(new(sync.Mutex)), 117 | timeout: idleTimeout, 118 | }, 119 | } 120 | p.conns.PushFront(poolConn{c, time.Now()}) 121 | 122 | for i := 0; i < minConns-1; i++ { 123 | // pre-fill the pool 124 | _c, err := Connect(connectParams, LogError) 125 | if err != nil { 126 | return nil, err 127 | } 128 | p.conns.PushFront(poolConn{_c, time.Now()}) 129 | p.n++ 130 | } 131 | 132 | go timeoutCloser(p.pool) 133 | runtime.SetFinalizer(p, (*Pool).close) 134 | return 135 | } 136 | 137 | // Acquire returns the next available connection, or returns an error if it 138 | // failed to create a new connection. 139 | // When an Acquired connection has been finished with, it should be returned 140 | // to the pool via Release. 141 | func (p *Pool) Acquire() (c *Conn, err error) { 142 | p.cond.L.Lock() 143 | defer p.cond.L.Unlock() 144 | if p.closed { 145 | return nil, errors.New("pool is closed") 146 | } 147 | if p.conns.Len() > 0 { 148 | c = p.conns.Remove(p.conns.Front()).(poolConn).Conn 149 | } else if p.conns.Len() == 0 && p.n < p.max { 150 | c, err = Connect(p.params, LogError) 151 | if err != nil { 152 | return 153 | } 154 | p.n++ 155 | if p.Debug { 156 | p.log(fmt.Sprintf("connection %d created", p.n)) 157 | } 158 | } else { // p.conns.Len() == 0 && p.n == p.max 159 | for p.conns.Len() == 0 { 160 | p.cond.Wait() 161 | } 162 | c = p.conns.Remove(p.conns.Front()).(poolConn).Conn 163 | } 164 | if p.Debug { 165 | p.log(fmt.Sprintf("connection acquired: %d idle, %d unused", p.conns.Len(), p.max-p.n)) 166 | } 167 | return c, nil 168 | } 169 | 170 | // Release returns the previously Acquired connection to the list of available connections. 
171 | func (p *Pool) Release(c *Conn) { 172 | p.cond.L.Lock() 173 | defer p.cond.L.Unlock() 174 | if !p.closed { 175 | // reset the connection 176 | c.reader = bufio.NewReader(c.tcpConn) 177 | c.writer = bufio.NewWriter(c.tcpConn) 178 | c.state = readyState{} 179 | 180 | // push back to the queue 181 | p.conns.PushBack(poolConn{c, time.Now()}) 182 | if p.Debug { 183 | p.log(fmt.Sprintf("connection released: %d idle, %d unused", p.conns.Len(), p.max-p.n)) 184 | } 185 | p.cond.Signal() 186 | } 187 | } 188 | 189 | func (p *Pool) close() { 190 | if p != nil { 191 | nConns := p.conns.Len() 192 | for p.conns.Len() > 0 { 193 | p.conns.Remove(p.conns.Front()).(poolConn).Close() 194 | } 195 | p.n -= nConns 196 | p.closed = true 197 | runtime.SetFinalizer(p, nil) 198 | p.log("close finished") 199 | } 200 | } 201 | 202 | // Close closes any available connections and prevents the Acquiring of any new connections. 203 | // It returns an error if there are any outstanding connections remaining. 204 | func (p *Pool) Close() error { 205 | p.cond.L.Lock() 206 | defer p.cond.L.Unlock() 207 | if !p.closed { 208 | p.close() 209 | if p.n > 0 { 210 | return errors.New(fmt.Sprintf("pool closed but %d connections in use", p.n)) 211 | } 212 | } 213 | return nil 214 | } 215 | -------------------------------------------------------------------------------- /messagecodes.go: -------------------------------------------------------------------------------- 1 | // Copyright 2010 The go-pgsql Authors. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 
4 | 5 | package pgsql 6 | 7 | import ( 8 | "fmt" 9 | ) 10 | 11 | //------------------------------------------------------------------------------ 12 | 13 | type backendMessageCode byte 14 | 15 | const ( 16 | _AuthenticationRequest backendMessageCode = 'R' 17 | _BackendKeyData backendMessageCode = 'K' 18 | _BindComplete backendMessageCode = '2' 19 | _CloseComplete backendMessageCode = '3' 20 | _CommandComplete backendMessageCode = 'C' 21 | _CopyData_BE backendMessageCode = 'd' 22 | _CopyDone_BE backendMessageCode = 'c' 23 | _CopyInResponse backendMessageCode = 'G' 24 | _CopyOutResponse backendMessageCode = 'H' 25 | _DataRow backendMessageCode = 'D' 26 | _EmptyQueryResponse backendMessageCode = 'I' 27 | _ErrorResponse backendMessageCode = 'E' 28 | _FunctionCallResponse backendMessageCode = 'V' 29 | _NoData backendMessageCode = 'n' 30 | _NoticeResponse backendMessageCode = 'N' 31 | _NotificationResponse backendMessageCode = 'A' 32 | _ParameterDescription backendMessageCode = 't' 33 | _ParameterStatus backendMessageCode = 'S' 34 | _ParseComplete backendMessageCode = '1' 35 | _PortalSuspended backendMessageCode = 's' 36 | _ReadyForQuery backendMessageCode = 'Z' 37 | _RowDescription backendMessageCode = 'T' 38 | ) 39 | 40 | var backendMsgCode2String map[backendMessageCode]string 41 | 42 | func (x backendMessageCode) String() string { 43 | s, ok := backendMsgCode2String[x] 44 | if !ok { 45 | return fmt.Sprintf("unknown backendMessageCode: %02x", x) 46 | } 47 | 48 | return s 49 | } 50 | 51 | //------------------------------------------------------------------------------ 52 | 53 | type frontendMessageCode byte 54 | 55 | const ( 56 | _Bind frontendMessageCode = 'B' 57 | _Close frontendMessageCode = 'C' 58 | _CopyData_FE frontendMessageCode = 'd' 59 | _CopyDone_FE frontendMessageCode = 'c' 60 | _CopyFail frontendMessageCode = 'f' 61 | _Describe frontendMessageCode = 'D' 62 | _Execute frontendMessageCode = 'E' 63 | _Flush frontendMessageCode = 'H' 64 | _FunctionCall 
frontendMessageCode = 'F' 65 | _Parse frontendMessageCode = 'P' 66 | _PasswordMessage frontendMessageCode = 'p' 67 | _Query frontendMessageCode = 'Q' 68 | _SSLRequest frontendMessageCode = '8' 69 | _Sync frontendMessageCode = 'S' 70 | _Terminate frontendMessageCode = 'X' 71 | ) 72 | 73 | var frontendMsgCode2String map[frontendMessageCode]string 74 | 75 | func (x frontendMessageCode) String() string { 76 | s, ok := frontendMsgCode2String[x] 77 | if !ok { 78 | return "unkown frontendMessageCode" 79 | } 80 | 81 | return s 82 | } 83 | 84 | //------------------------------------------------------------------------------ 85 | 86 | type authenticationType int32 87 | 88 | const ( 89 | _AuthenticationOk authenticationType = 0 90 | _AuthenticationKerberosV5 authenticationType = 2 91 | _AuthenticationCleartextPassword authenticationType = 3 92 | _AuthenticationMD5Password authenticationType = 5 93 | _AuthenticationSCMCredential authenticationType = 6 94 | _AuthenticationGSS authenticationType = 7 95 | _AuthenticationGSSContinue authenticationType = 8 96 | _AuthenticationSSPI authenticationType = 9 97 | ) 98 | 99 | var authType2String map[authenticationType]string 100 | 101 | func (x authenticationType) String() string { 102 | s, ok := authType2String[x] 103 | if !ok { 104 | return "unkown authenticationType" 105 | } 106 | 107 | return s 108 | } 109 | 110 | //------------------------------------------------------------------------------ 111 | 112 | func init() { 113 | 114 | backendMsgCode2String = make(map[backendMessageCode]string) 115 | 116 | backendMsgCode2String[_AuthenticationRequest] = "AuthenticationRequest" 117 | backendMsgCode2String[_BackendKeyData] = "BackendKeyData" 118 | backendMsgCode2String[_BindComplete] = "BindComplete" 119 | backendMsgCode2String[_CloseComplete] = "CloseComplete" 120 | backendMsgCode2String[_CommandComplete] = "CommandComplete" 121 | backendMsgCode2String[_CopyData_BE] = "CopyData" 122 | backendMsgCode2String[_CopyDone_BE] = "CopyDone" 123 | 
backendMsgCode2String[_CopyInResponse] = "CopyInResponse" 124 | backendMsgCode2String[_CopyOutResponse] = "CopyOutResponse" 125 | backendMsgCode2String[_DataRow] = "DataRow" 126 | backendMsgCode2String[_EmptyQueryResponse] = "EmptyQueryResponse" 127 | backendMsgCode2String[_ErrorResponse] = "ErrorResponse" 128 | backendMsgCode2String[_FunctionCallResponse] = "FunctionCallResponse" 129 | backendMsgCode2String[_NoData] = "NoData" 130 | backendMsgCode2String[_NoticeResponse] = "NoticeResponse" 131 | backendMsgCode2String[_NotificationResponse] = "NotificationResponse" 132 | backendMsgCode2String[_ParameterDescription] = "ParameterDescription" 133 | backendMsgCode2String[_ParameterStatus] = "ParameterStatus" 134 | backendMsgCode2String[_ParseComplete] = "ParseComplete" 135 | backendMsgCode2String[_PortalSuspended] = "PortalSuspended" 136 | backendMsgCode2String[_ReadyForQuery] = "ReadyForQuery" 137 | backendMsgCode2String[_RowDescription] = "RowDescription" 138 | 139 | //-------- 140 | 141 | frontendMsgCode2String = make(map[frontendMessageCode]string) 142 | 143 | frontendMsgCode2String[_Bind] = "Bind" 144 | frontendMsgCode2String[_Close] = "Close" 145 | frontendMsgCode2String[_CopyData_FE] = "CopyData" 146 | frontendMsgCode2String[_CopyDone_FE] = "CopyDone" 147 | frontendMsgCode2String[_CopyFail] = "CopyFail" 148 | frontendMsgCode2String[_Describe] = "Describe" 149 | frontendMsgCode2String[_Execute] = "Execute" 150 | frontendMsgCode2String[_Flush] = "Flush" 151 | frontendMsgCode2String[_FunctionCall] = "FunctionCall" 152 | frontendMsgCode2String[_Parse] = "Parse" 153 | frontendMsgCode2String[_PasswordMessage] = "PasswordMessage" 154 | frontendMsgCode2String[_Query] = "Query" 155 | frontendMsgCode2String[_SSLRequest] = "SSLRequest" 156 | frontendMsgCode2String[_Sync] = "Sync" 157 | frontendMsgCode2String[_Terminate] = "Terminate" 158 | 159 | //-------- 160 | 161 | authType2String = make(map[authenticationType]string) 162 | 163 | authType2String[_AuthenticationOk] = 
"AuthenticationOk" 164 | authType2String[_AuthenticationKerberosV5] = "AuthenticationKerberosV5" 165 | authType2String[_AuthenticationCleartextPassword] = "AuthenticationCleartextPassword" 166 | authType2String[_AuthenticationMD5Password] = "AuthenticationMD5Password" 167 | authType2String[_AuthenticationSCMCredential] = "AuthenticationSCMCredential" 168 | authType2String[_AuthenticationGSS] = "AuthenticationGSS" 169 | authType2String[_AuthenticationGSSContinue] = "AuthenticationGSSContinue" 170 | authType2String[_AuthenticationSSPI] = "AuthenticationSSPI" 171 | } 172 | -------------------------------------------------------------------------------- /conn_write.go: -------------------------------------------------------------------------------- 1 | // Copyright 2010 The go-pgsql Authors. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 4 | 5 | package pgsql 6 | 7 | import ( 8 | "encoding/binary" 9 | "fmt" 10 | "math" 11 | "math/big" 12 | "strconv" 13 | "strings" 14 | "time" 15 | ) 16 | 17 | func (conn *Conn) flush() { 18 | panicIfErr(conn.writer.Flush()) 19 | } 20 | 21 | func (conn *Conn) write(b []byte) { 22 | _, err := conn.writer.Write(b) 23 | panicIfErr(err) 24 | } 25 | 26 | func (conn *Conn) writeByte(b byte) { 27 | panicIfErr(conn.writer.WriteByte(b)) 28 | } 29 | 30 | func (conn *Conn) writeFloat32(f float32) { 31 | var buf [4]byte 32 | b := buf[:] 33 | 34 | binary.BigEndian.PutUint32(b, math.Float32bits(f)) 35 | conn.write(b) 36 | } 37 | 38 | func (conn *Conn) writeFloat64(f float64) { 39 | var buf [8]byte 40 | b := buf[:] 41 | 42 | binary.BigEndian.PutUint64(b, math.Float64bits(f)) 43 | conn.write(b) 44 | } 45 | 46 | func (conn *Conn) writeFrontendMessageCode(code frontendMessageCode) { 47 | panicIfErr(conn.writer.WriteByte(byte(code))) 48 | } 49 | 50 | func (conn *Conn) writeInt16(i int16) { 51 | var buf [2]byte 52 | b := buf[:] 53 | 54 | binary.BigEndian.PutUint16(b, 
uint16(i)) 55 | conn.write(b) 56 | } 57 | 58 | func (conn *Conn) writeInt32(i int32) { 59 | var buf [4]byte 60 | b := buf[:] 61 | 62 | binary.BigEndian.PutUint32(b, uint32(i)) 63 | conn.write(b) 64 | } 65 | 66 | func (conn *Conn) writeInt64(i int64) { 67 | var buf [8]byte 68 | b := buf[:] 69 | 70 | binary.BigEndian.PutUint64(b, uint64(i)) 71 | conn.write(b) 72 | } 73 | 74 | func (conn *Conn) writeString(s string) { 75 | _, err := conn.writer.WriteString(s) 76 | panicIfErr(err) 77 | } 78 | 79 | func (conn *Conn) writeString0(s string) { 80 | conn.writeString(s) 81 | conn.writeByte(0) 82 | } 83 | 84 | func (conn *Conn) writeFlush() { 85 | conn.writeFrontendMessageCode(_Flush) 86 | conn.writeInt32(4) 87 | 88 | conn.flush() 89 | } 90 | 91 | func (conn *Conn) writeBind(stmt *Statement) { 92 | values := make([]string, len(stmt.params)) 93 | 94 | var paramValuesLen int 95 | for i, param := range stmt.params { 96 | value := param.value 97 | if val, ok := value.(uint64); ok { 98 | value = int64(val) 99 | } 100 | 101 | switch val := value.(type) { 102 | case bool: 103 | if val { 104 | values[i] = "t" 105 | } else { 106 | values[i] = "f" 107 | } 108 | 109 | case byte: 110 | values[i] = string([]byte{val}) 111 | 112 | case float32: 113 | values[i] = strconv.FormatFloat(float64(val), 'f', -1, 32) 114 | 115 | case float64: 116 | values[i] = strconv.FormatFloat(val, 'f', -1, 64) 117 | 118 | case int: 119 | values[i] = strconv.Itoa(val) 120 | 121 | case int16: 122 | values[i] = strconv.Itoa(int(val)) 123 | 124 | case int32: 125 | values[i] = strconv.Itoa(int(val)) 126 | 127 | case int64: 128 | switch param.typ { 129 | case Date: 130 | values[i] = time.Unix(val, 0).UTC().Format("2006-01-02") 131 | 132 | case Time, TimeTZ: 133 | values[i] = time.Unix(val, 0).UTC().Format("15:04:05") 134 | 135 | case Timestamp, TimestampTZ: 136 | values[i] = time.Unix(val, 0).UTC().Format("2006-01-02 15:04:05") 137 | 138 | default: 139 | values[i] = strconv.FormatInt(val, 10) 140 | } 141 | 142 | case 
nil: 143 | 144 | case *big.Rat: 145 | if val.IsInt() { 146 | values[i] = val.Num().String() 147 | } else { 148 | // FIXME: Find a better way to do this. 149 | prec999 := val.FloatString(999) 150 | trimmed := strings.TrimRight(prec999, "0") 151 | sepIndex := strings.Index(trimmed, ".") 152 | prec := len(trimmed) - sepIndex - 1 153 | values[i] = val.FloatString(prec) 154 | } 155 | 156 | case string: 157 | values[i] = val 158 | 159 | case time.Time: 160 | switch param.typ { 161 | case Date: 162 | values[i] = val.Format("2006-01-02") 163 | 164 | case Time, TimeTZ: 165 | values[i] = val.Format("15:04:05") 166 | 167 | case Timestamp, TimestampTZ: 168 | values[i] = val.Format("2006-01-02 15:04:05") 169 | 170 | default: 171 | panic("invalid use of time.Time") 172 | } 173 | 174 | default: 175 | panic("unsupported parameter type") 176 | } 177 | 178 | paramValuesLen += len(values[i]) 179 | } 180 | 181 | msgLen := int32(4 + 182 | len(stmt.portalName) + 1 + 183 | len(stmt.name) + 1 + 184 | 2 + 2 + 185 | 2 + len(stmt.params)*4 + paramValuesLen + 186 | 2 + 2) 187 | 188 | conn.writeFrontendMessageCode(_Bind) 189 | conn.writeInt32(msgLen) 190 | conn.writeString0(stmt.portalName) 191 | conn.writeString0(stmt.name) 192 | conn.writeInt16(1) 193 | conn.writeInt16(int16(textFormat)) 194 | conn.writeInt16(int16(len(stmt.params))) 195 | 196 | for i, param := range stmt.params { 197 | if param.value == nil { 198 | conn.writeInt32(-1) 199 | } else { 200 | conn.writeInt32(int32(len(values[i]))) 201 | conn.writeString(values[i]) 202 | } 203 | } 204 | 205 | conn.writeInt16(1) 206 | conn.writeInt16(int16(textFormat)) 207 | 208 | conn.writeFlush() 209 | } 210 | 211 | func (conn *Conn) writeClose(itemType byte, itemName string) { 212 | msgLen := int32(4 + 1 + len(itemName) + 1) 213 | 214 | conn.writeFrontendMessageCode(_Close) 215 | conn.writeInt32(msgLen) 216 | conn.writeByte(itemType) 217 | conn.writeString0(itemName) 218 | 219 | conn.flush() 220 | } 221 | 222 | func (conn *Conn) 
writeDescribe(stmt *Statement) { 223 | msgLen := int32(4 + 1 + len(stmt.portalName) + 1) 224 | 225 | conn.writeFrontendMessageCode(_Describe) 226 | conn.writeInt32(msgLen) 227 | conn.writeByte('P') 228 | conn.writeString0(stmt.portalName) 229 | 230 | conn.writeFlush() 231 | } 232 | 233 | func (conn *Conn) writeExecute(stmt *Statement) { 234 | msgLen := int32(4 + len(stmt.portalName) + 1 + 4) 235 | 236 | conn.writeFrontendMessageCode(_Execute) 237 | conn.writeInt32(msgLen) 238 | conn.writeString0(stmt.portalName) 239 | conn.writeInt32(0) 240 | 241 | conn.writeFlush() 242 | } 243 | 244 | func (conn *Conn) writeParse(stmt *Statement) { 245 | if conn.LogLevel >= LogDebug { 246 | defer conn.logExit(conn.logEnter("*Conn.writeParse")) 247 | } 248 | 249 | if conn.LogLevel >= LogCommand { 250 | conn.log(LogCommand, fmt.Sprintf("stmt.ActualCommand: '%s'", stmt.ActualCommand())) 251 | } 252 | 253 | msgLen := int32(4 + 254 | len(stmt.name) + 1 + 255 | len(stmt.actualCommand) + 1 + 256 | 2 + len(stmt.params)*4) 257 | 258 | conn.writeFrontendMessageCode(_Parse) 259 | conn.writeInt32(msgLen) 260 | conn.writeString0(stmt.name) 261 | conn.writeString0(stmt.actualCommand) 262 | 263 | conn.writeInt16(int16(len(stmt.params))) 264 | for _, param := range stmt.params { 265 | typ := param.typ 266 | if typ == Char { 267 | // FIXME: There seems to be something wrong with CHAR parameters. 268 | // Had a query that correctly returned rows in psql, but didn't 269 | // via go-pgsql statement. Changed param type from Char to Varchar 270 | // and it worked. The corresponding field in the table was CHAR(32). 
271 | typ = Varchar 272 | } 273 | conn.writeInt32(int32(typ)) 274 | } 275 | 276 | conn.writeFlush() 277 | } 278 | 279 | func (conn *Conn) writePasswordMessage(password string) { 280 | if conn.LogLevel >= LogDebug { 281 | defer conn.logExit(conn.logEnter("*Conn.writePasswordMessage")) 282 | } 283 | 284 | msgLen := int32(4 + len(password) + 1) 285 | 286 | conn.writeFrontendMessageCode(_PasswordMessage) 287 | conn.writeInt32(msgLen) 288 | conn.writeString0(password) 289 | 290 | conn.flush() 291 | } 292 | 293 | func (conn *Conn) writeQuery(command string) { 294 | if conn.LogLevel >= LogDebug { 295 | defer conn.logExit(conn.logEnter("*Conn.writeQuery")) 296 | } 297 | 298 | if conn.LogLevel >= LogCommand { 299 | conn.log(LogCommand, fmt.Sprintf("command: '%s'", command)) 300 | } 301 | 302 | conn.writeFrontendMessageCode(_Query) 303 | conn.writeInt32(int32(4 + len(command) + 1)) 304 | conn.writeString0(command) 305 | 306 | conn.flush() 307 | } 308 | 309 | func (conn *Conn) writeStartup() { 310 | if conn.LogLevel >= LogDebug { 311 | defer conn.logExit(conn.logEnter("*Conn.writeStartup")) 312 | } 313 | 314 | msglen := int32(4 + 4 + 315 | len("user") + 1 + len(conn.params.User) + 1 + 316 | len("database") + 1 + len(conn.params.Database) + 1 + 1) 317 | 318 | conn.writeInt32(msglen) 319 | 320 | // For now we only support protocol version 3.0. 
321 | conn.writeInt32(3 << 16) 322 | 323 | conn.writeString0("user") 324 | conn.writeString0(conn.params.User) 325 | 326 | conn.writeString0("database") 327 | conn.writeString0(conn.params.Database) 328 | 329 | conn.writeByte(0) 330 | 331 | conn.flush() 332 | } 333 | 334 | func (conn *Conn) writeSync() { 335 | conn.writeFrontendMessageCode(_Sync) 336 | conn.writeInt32(4) 337 | 338 | conn.writeFlush() 339 | } 340 | 341 | func (conn *Conn) writeTerminate() { 342 | conn.writeFrontendMessageCode(_Terminate) 343 | conn.writeInt32(4) 344 | 345 | conn.flush() 346 | } 347 | -------------------------------------------------------------------------------- /statement.go: -------------------------------------------------------------------------------- 1 | // Copyright 2010 The go-pgsql Authors. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 4 | 5 | package pgsql 6 | 7 | import ( 8 | "bytes" 9 | "fmt" 10 | "regexp" 11 | ) 12 | 13 | var quoteRegExp = regexp.MustCompile("['][^']*[']") 14 | 15 | // Statement is a means to efficiently execute a parameterized SQL command multiple times. 16 | // 17 | // Call *Conn.Prepare to create a new prepared Statement. 
18 | type Statement struct { 19 | conn *Conn 20 | name string 21 | portalName string 22 | command string 23 | actualCommand string 24 | isClosed bool 25 | params []*Parameter 26 | name2param map[string]*Parameter 27 | } 28 | 29 | func replaceParameterNameInSubstring(s, old, new string, buf *bytes.Buffer, paramRegExp *regexp.Regexp) { 30 | matchIndexPairs := paramRegExp.FindAllStringIndex(s, -1) 31 | prevMatchEnd := 1 32 | 33 | for _, pair := range matchIndexPairs { 34 | matchStart := pair[0] 35 | matchEnd := pair[1] 36 | 37 | buf.WriteString(s[prevMatchEnd-1 : matchStart+1]) 38 | buf.WriteString(new) 39 | 40 | prevMatchEnd = matchEnd 41 | } 42 | 43 | if prevMatchEnd > 1 { 44 | buf.WriteString(s[prevMatchEnd-1:]) 45 | return 46 | } 47 | 48 | buf.WriteString(s) 49 | } 50 | 51 | func replaceParameterName(command, old, new string) string { 52 | paramRegExp := regexp.MustCompile("[\\- |\n\r\t,)(;=+/<>][:|@]" + old[1:] + "([\\- |\n\r\t,)(;=+/<>]|$)") 53 | 54 | buf := bytes.NewBuffer(nil) 55 | 56 | quoteIndexPairs := quoteRegExp.FindAllStringIndex(command, -1) 57 | prevQuoteEnd := 0 58 | 59 | for _, pair := range quoteIndexPairs { 60 | quoteStart := pair[0] 61 | quoteEnd := pair[1] 62 | 63 | replaceParameterNameInSubstring(command[prevQuoteEnd:quoteStart], old, new, buf, paramRegExp) 64 | buf.WriteString(command[quoteStart:quoteEnd]) 65 | 66 | prevQuoteEnd = quoteEnd 67 | } 68 | 69 | if buf.Len() > 0 { 70 | replaceParameterNameInSubstring(command[prevQuoteEnd:], old, new, buf, paramRegExp) 71 | 72 | return buf.String() 73 | } 74 | 75 | replaceParameterNameInSubstring(command, old, new, buf, paramRegExp) 76 | 77 | return buf.String() 78 | } 79 | 80 | func adjustCommand(command string, params []*Parameter) string { 81 | for i, p := range params { 82 | var cast string 83 | if p.customTypeName != "" { 84 | cast = fmt.Sprintf("::%s", p.customTypeName) 85 | } 86 | command = replaceParameterName(command, p.name, fmt.Sprintf("$%d%s", i+1, cast)) 87 | } 88 | 89 | return command 90 
| } 91 | 92 | func newStatement(conn *Conn, command string, params []*Parameter) *Statement { 93 | if conn.LogLevel >= LogDebug { 94 | defer conn.logExit(conn.logEnter("newStatement")) 95 | } 96 | 97 | stmt := &Statement{} 98 | 99 | stmt.name2param = make(map[string]*Parameter) 100 | 101 | for _, param := range params { 102 | if param == nil { 103 | panic("received a nil parameter") 104 | } 105 | if param.stmt != nil { 106 | panic(fmt.Sprintf("parameter '%s' already used in another statement", param.name)) 107 | } 108 | param.stmt = stmt 109 | 110 | stmt.name2param[param.name] = param 111 | } 112 | 113 | stmt.conn = conn 114 | 115 | stmt.name = fmt.Sprint("stmt", conn.nextStatementId) 116 | conn.nextStatementId++ 117 | 118 | stmt.portalName = fmt.Sprint("prtl", conn.nextPortalId) 119 | conn.nextPortalId++ 120 | 121 | stmt.command = command 122 | stmt.actualCommand = adjustCommand(command, params) 123 | 124 | stmt.params = make([]*Parameter, len(params)) 125 | copy(stmt.params, params) 126 | 127 | return stmt 128 | } 129 | 130 | // Conn returns the *Conn this Statement is associated with. 131 | func (stmt *Statement) Conn() *Conn { 132 | return stmt.conn 133 | } 134 | 135 | // Parameter returns the Parameter with the specified name or nil, if the Statement has no Parameter with that name. 136 | func (stmt *Statement) Parameter(name string) *Parameter { 137 | conn := stmt.conn 138 | 139 | if conn.LogLevel >= LogVerbose { 140 | defer conn.logExit(conn.logEnter("*Statement.Parameter")) 141 | } 142 | 143 | param, ok := stmt.name2param[name] 144 | if !ok { 145 | return nil 146 | } 147 | 148 | return param 149 | } 150 | 151 | // Parameters returns a slice containing the parameters of the Statement. 
152 | func (stmt *Statement) Parameters() []*Parameter { 153 | conn := stmt.conn 154 | 155 | if conn.LogLevel >= LogVerbose { 156 | defer conn.logExit(conn.logEnter("*Statement.Parameters")) 157 | } 158 | 159 | params := make([]*Parameter, len(stmt.params)) 160 | copy(params, stmt.params) 161 | return params 162 | } 163 | 164 | // IsClosed returns if the Statement has been closed. 165 | func (stmt *Statement) IsClosed() bool { 166 | conn := stmt.conn 167 | 168 | if conn.LogLevel >= LogVerbose { 169 | defer conn.logExit(conn.logEnter("*Statement.IsClosed")) 170 | } 171 | 172 | return stmt.isClosed 173 | } 174 | 175 | func (stmt *Statement) close() { 176 | conn := stmt.conn 177 | 178 | if conn.LogLevel >= LogDebug { 179 | defer conn.logExit(conn.logEnter("*Statement.close")) 180 | } 181 | 182 | stmt.conn.writeClose('S', stmt.name) 183 | 184 | stmt.isClosed = true 185 | return 186 | } 187 | 188 | // Close closes the Statement, releasing resources on the server. 189 | func (stmt *Statement) Close() (err error) { 190 | err = stmt.conn.withRecover("*Statement.Close", func() { 191 | stmt.close() 192 | }) 193 | 194 | return 195 | } 196 | 197 | // ActualCommand returns the actual command text that is sent to the server. 198 | // 199 | // The original command is automatically adjusted if it contains parameters so 200 | // it complies with what PostgreSQL expects. Refer to the return value of this 201 | // method to make sense of the position information contained in many error 202 | // messages. 203 | func (stmt *Statement) ActualCommand() string { 204 | conn := stmt.conn 205 | 206 | if conn.LogLevel >= LogVerbose { 207 | defer conn.logExit(conn.logEnter("*Statement.ActualCommand")) 208 | } 209 | 210 | return stmt.actualCommand 211 | } 212 | 213 | // Command is the original command text as given to *Conn.Prepare. 
214 | func (stmt *Statement) Command() string { 215 | conn := stmt.conn 216 | 217 | if conn.LogLevel >= LogVerbose { 218 | defer conn.logExit(conn.logEnter("*Statement.Command")) 219 | } 220 | 221 | return stmt.command 222 | } 223 | 224 | func (stmt *Statement) query() (rs *ResultSet) { 225 | conn := stmt.conn 226 | 227 | if conn.LogLevel >= LogDebug { 228 | defer conn.logExit(conn.logEnter("*Statement.query")) 229 | } 230 | 231 | if conn.LogLevel >= LogCommand { 232 | buf := bytes.NewBuffer(nil) 233 | 234 | buf.WriteString("\n=================================================\n") 235 | 236 | buf.WriteString("ActualCommand:\n") 237 | buf.WriteString(stmt.actualCommand) 238 | buf.WriteString("\n-------------------------------------------------\n") 239 | buf.WriteString("Parameters:\n") 240 | 241 | for i, p := range stmt.params { 242 | buf.WriteString(fmt.Sprintf("$%d (%s) = '%v'\n", i+1, p.name, p.value)) 243 | } 244 | 245 | buf.WriteString("=================================================\n") 246 | 247 | conn.log(LogCommand, buf.String()) 248 | } 249 | 250 | r := newResultSet(conn) 251 | 252 | conn.state.execute(stmt, r) 253 | 254 | rs = r 255 | 256 | return 257 | } 258 | 259 | // Query executes the Statement and returns a 260 | // ResultSet for row-by-row retrieval of the results. 261 | // 262 | // The returned ResultSet must be closed before sending another 263 | // query or command to the server over the same connection. 
264 | func (stmt *Statement) Query() (rs *ResultSet, err error) { 265 | err = stmt.conn.withRecover("*Statement.Query", func() { 266 | rs = stmt.query() 267 | }) 268 | 269 | return 270 | } 271 | 272 | func (stmt *Statement) execute() (rowsAffected int64) { 273 | conn := stmt.conn 274 | 275 | if conn.LogLevel >= LogDebug { 276 | defer conn.logExit(conn.logEnter("*Statement.Execute")) 277 | } 278 | 279 | rs := stmt.query() 280 | rs.close() 281 | 282 | return rs.rowsAffected 283 | } 284 | 285 | // Execute executes the Statement and returns the number 286 | // of rows affected. 287 | // 288 | // If the results of a query are needed, use the 289 | // Query method instead. 290 | func (stmt *Statement) Execute() (rowsAffected int64, err error) { 291 | err = stmt.conn.withRecover("*Statement.Execute", func() { 292 | rowsAffected = stmt.execute() 293 | }) 294 | 295 | return 296 | } 297 | 298 | func (stmt *Statement) scan(args ...interface{}) (*ResultSet, bool) { 299 | conn := stmt.conn 300 | 301 | if conn.LogLevel >= LogDebug { 302 | defer conn.logExit(conn.logEnter("*Statement.Scan")) 303 | } 304 | 305 | rs := stmt.query() 306 | 307 | return rs, rs.scanNext(args...) 308 | } 309 | 310 | // Scan executes the statement and scans the fields of the first row 311 | // in the ResultSet, trying to store field values into the specified 312 | // arguments. 313 | // 314 | // The arguments must be of pointer types. If a row has 315 | // been fetched, fetched will be true, otherwise false. 316 | func (stmt *Statement) Scan(args ...interface{}) (fetched bool, err error) { 317 | err = stmt.conn.withRecover("*Statement.Scan", func() { 318 | var rs *ResultSet 319 | rs, fetched = stmt.scan(args...) 320 | rs.close() 321 | }) 322 | 323 | return 324 | } 325 | -------------------------------------------------------------------------------- /conn_read.go: -------------------------------------------------------------------------------- 1 | // Copyright 2010 The go-pgsql Authors. 
All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 4 | 5 | package pgsql 6 | 7 | import ( 8 | "crypto/md5" 9 | "encoding/binary" 10 | "encoding/hex" 11 | "fmt" 12 | "strconv" 13 | "strings" 14 | ) 15 | 16 | func (conn *Conn) read(b []byte) { 17 | readTotal := 0 18 | for { 19 | n, err := conn.reader.Read(b[readTotal:]) 20 | panicIfErr(err) 21 | 22 | readTotal += n 23 | if readTotal == len(b) { 24 | break 25 | } 26 | } 27 | } 28 | 29 | func (conn *Conn) readByte() byte { 30 | b, err := conn.reader.ReadByte() 31 | panicIfErr(err) 32 | 33 | return b 34 | } 35 | 36 | func (conn *Conn) readBytes(delim byte) []byte { 37 | b, err := conn.reader.ReadBytes(delim) 38 | panicIfErr(err) 39 | 40 | return b 41 | } 42 | 43 | func (conn *Conn) readInt16() int16 { 44 | var buf [2]byte 45 | b := buf[:] 46 | 47 | conn.read(b) 48 | return int16(binary.BigEndian.Uint16(b)) 49 | } 50 | 51 | func (conn *Conn) readInt32() int32 { 52 | var buf [4]byte 53 | b := buf[:] 54 | 55 | conn.read(b) 56 | return int32(binary.BigEndian.Uint32(b)) 57 | } 58 | 59 | func (conn *Conn) readString() string { 60 | b := conn.readBytes(0) 61 | return string(b[:len(b)-1]) 62 | } 63 | 64 | func (conn *Conn) readAuthenticationRequest() { 65 | if conn.LogLevel >= LogDebug { 66 | defer conn.logExit(conn.logEnter("*Conn.readAuthenticationRequest")) 67 | } 68 | 69 | // Just eat message length. 
70 | conn.readInt32() 71 | 72 | authType := conn.readInt32() 73 | switch authenticationType(authType) { 74 | case _AuthenticationOk: 75 | // nop 76 | 77 | // case _AuthenticationKerberosV5 authenticationType: 78 | 79 | // case _AuthenticationCleartextPassword: 80 | 81 | case _AuthenticationMD5Password: 82 | salt := make([]byte, 4) 83 | 84 | conn.read(salt) 85 | 86 | md5Hasher := md5.New() 87 | 88 | _, err := md5Hasher.Write([]byte(conn.params.Password)) 89 | panicIfErr(err) 90 | 91 | _, err = md5Hasher.Write([]byte(conn.params.User)) 92 | panicIfErr(err) 93 | 94 | md5HashHex1 := hex.EncodeToString(md5Hasher.Sum(nil)) 95 | 96 | md5Hasher.Reset() 97 | 98 | _, err = md5Hasher.Write([]byte(md5HashHex1)) 99 | panicIfErr(err) 100 | 101 | _, err = md5Hasher.Write(salt) 102 | panicIfErr(err) 103 | 104 | md5HashHex2 := hex.EncodeToString(md5Hasher.Sum(nil)) 105 | 106 | password := "md5" + md5HashHex2 107 | 108 | conn.writePasswordMessage(password) 109 | 110 | // case _AuthenticationSCMCredential: 111 | 112 | // case _AuthenticationGSS: 113 | 114 | // case _AuthenticationGSSContinue: 115 | 116 | // case _AuthenticationSSPI: 117 | 118 | default: 119 | panic(fmt.Sprintf("unsupported authentication type: %d", authType)) 120 | } 121 | } 122 | 123 | func (conn *Conn) readBackendKeyData() { 124 | if conn.LogLevel >= LogDebug { 125 | defer conn.logExit(conn.logEnter("*Conn.readBackendKeyData")) 126 | } 127 | 128 | // Just eat message length. 129 | conn.readInt32() 130 | 131 | conn.backendPID = conn.readInt32() 132 | conn.backendSecretKey = conn.readInt32() 133 | } 134 | 135 | func (conn *Conn) readBindComplete() { 136 | if conn.LogLevel >= LogDebug { 137 | defer conn.logExit(conn.logEnter("*Conn.readBindComplete")) 138 | } 139 | 140 | // Just eat message length. 
141 | conn.readInt32() 142 | } 143 | 144 | func (conn *Conn) readCloseComplete() { 145 | if conn.LogLevel >= LogDebug { 146 | defer conn.logExit(conn.logEnter("*Conn.readCloseComplete")) 147 | } 148 | 149 | // Just eat message length. 150 | conn.readInt32() 151 | } 152 | 153 | func (conn *Conn) readCommandComplete(rs *ResultSet) { 154 | if conn.LogLevel >= LogDebug { 155 | defer conn.logExit(conn.logEnter("*Conn.readCommandComplete")) 156 | } 157 | 158 | // Just eat message length. 159 | conn.readInt32() 160 | 161 | // Retrieve the number of affected rows from the command tag. 162 | tag := conn.readString() 163 | 164 | if rs != nil { 165 | parts := strings.Split(tag, " ") 166 | 167 | rs.rowsAffected, _ = strconv.ParseInt(parts[len(parts)-1], 10, 64) 168 | rs.currentResultComplete = true 169 | } 170 | } 171 | 172 | // As of PostgreSQL 9.2 (protocol 3.0), CopyOutResponse and CopyBothResponse 173 | // are exactly the same. 174 | func (conn *Conn) readCopyInResponse() { 175 | if conn.LogLevel >= LogDebug { 176 | defer conn.logExit(conn.logEnter("*Conn.readCopyInResponse")) 177 | } 178 | 179 | // Just eat message length. 180 | conn.readInt32() 181 | 182 | // Just eat overall COPY format. 0 - textual, 1 - binary. 183 | conn.readByte() 184 | 185 | numColumns := conn.readInt16() 186 | for i := int16(0); i < numColumns; i++ { 187 | // Just eat column formats. 188 | conn.readInt16() 189 | } 190 | 191 | conn.state = copyState{} 192 | } 193 | 194 | func (conn *Conn) readDataRow(rs *ResultSet) { 195 | // Just eat message length. 
196 | conn.readInt32() 197 | 198 | fieldCount := conn.readInt16() 199 | 200 | var ord int16 201 | for ord = 0; ord < fieldCount; ord++ { 202 | valLen := conn.readInt32() 203 | 204 | var val []byte 205 | 206 | if valLen == -1 { 207 | val = nil 208 | } else { 209 | val = make([]byte, valLen) 210 | conn.read(val) 211 | } 212 | 213 | rs.values[ord] = val 214 | } 215 | } 216 | 217 | func (conn *Conn) readEmptyQueryResponse() { 218 | if conn.LogLevel >= LogDebug { 219 | defer conn.logExit(conn.logEnter("*Conn.readEmptyQueryResponse")) 220 | } 221 | 222 | // Just eat message length. 223 | conn.readInt32() 224 | } 225 | 226 | func (conn *Conn) readErrorOrNoticeResponse(isError bool) { 227 | if conn.LogLevel >= LogDebug { 228 | defer conn.logExit(conn.logEnter("*Conn.readErrorOrNoticeResponse")) 229 | } 230 | 231 | // Just eat message length. 232 | conn.readInt32() 233 | 234 | err := &Error{} 235 | 236 | // Read all fields, just ignore unknown ones. 237 | for { 238 | fieldType := conn.readByte() 239 | 240 | if fieldType == 0 { 241 | if isError { 242 | if !conn.onErrorDontRequireReadyForQuery { 243 | // Before panicking, we have to wait for a ReadyForQuery message. 244 | conn.readBackendMessages(nil) 245 | } 246 | 247 | // We panic with our error as parameter, so the right thing (TM) will happen. 248 | panic(err) 249 | } else { 250 | // For now, we just log notices. 
251 | conn.logError(LogDebug, err) 252 | return 253 | } 254 | } 255 | 256 | str := conn.readString() 257 | 258 | switch fieldType { 259 | case 'S': 260 | err.severity = str 261 | 262 | case 'C': 263 | err.code = str 264 | 265 | case 'M': 266 | err.message = str 267 | 268 | case 'D': 269 | err.detail = str 270 | 271 | case 'H': 272 | err.hint = str 273 | 274 | case 'P': 275 | err.position = str 276 | 277 | case 'p': 278 | err.internalPosition = str 279 | 280 | case 'q': 281 | err.internalQuery = str 282 | 283 | case 'W': 284 | err.where = str 285 | 286 | case 'F': 287 | err.file = str 288 | 289 | case 'L': 290 | err.line = str 291 | 292 | case 'R': 293 | err.routine = str 294 | } 295 | } 296 | } 297 | 298 | func (conn *Conn) readNoData() { 299 | if conn.LogLevel >= LogDebug { 300 | defer conn.logExit(conn.logEnter("*Conn.readNoData")) 301 | } 302 | 303 | // Just eat message length. 304 | conn.readInt32() 305 | } 306 | 307 | func (conn *Conn) readParameterStatus() { 308 | if conn.LogLevel >= LogDebug { 309 | defer conn.logExit(conn.logEnter("*Conn.readParameterStatus")) 310 | } 311 | 312 | // Just eat message length. 313 | conn.readInt32() 314 | 315 | name := conn.readString() 316 | value := conn.readString() 317 | 318 | if conn.LogLevel >= LogDebug { 319 | conn.logf(LogDebug, "ParameterStatus: Name: '%s', Value: '%s'", name, value) 320 | } 321 | 322 | conn.runtimeParameters[name] = value 323 | 324 | if name == "DateStyle" { 325 | conn.updateTimeFormats() 326 | } 327 | } 328 | 329 | func (conn *Conn) readParseComplete() { 330 | if conn.LogLevel >= LogDebug { 331 | defer conn.logExit(conn.logEnter("*Conn.readParseComplete")) 332 | } 333 | 334 | // Just eat message length. 335 | conn.readInt32() 336 | } 337 | 338 | func (conn *Conn) readReadyForQuery(rs *ResultSet) { 339 | if conn.LogLevel >= LogDebug { 340 | defer conn.logExit(conn.logEnter("*Conn.readReadyForQuery")) 341 | } 342 | 343 | // Just eat message length. 
344 | conn.readInt32() 345 | 346 | txStatus := conn.readByte() 347 | 348 | if conn.LogLevel >= LogDebug { 349 | conn.log(LogDebug, "Transaction Status: ", string([]byte{txStatus})) 350 | } 351 | 352 | conn.transactionStatus = TransactionStatus(txStatus) 353 | 354 | if rs != nil { 355 | rs.allResultsComplete = true 356 | } 357 | 358 | conn.state = readyState{} 359 | } 360 | 361 | func (conn *Conn) readRowDescription(rs *ResultSet) { 362 | // Just eat message length. 363 | conn.readInt32() 364 | 365 | fieldCount := conn.readInt16() 366 | 367 | rs.fields = make([]field, fieldCount) 368 | rs.values = make([][]byte, fieldCount) 369 | 370 | var ord int16 371 | for ord = 0; ord < fieldCount; ord++ { 372 | rs.fields[ord].name = conn.readString() 373 | 374 | // Just eat table OID. 375 | conn.readInt32() 376 | 377 | // Just eat field OID. 378 | conn.readInt16() 379 | 380 | rs.fields[ord].typeOID = conn.readInt32() 381 | 382 | // Just eat field size. 383 | conn.readInt16() 384 | 385 | // Just eat field type modifier. 
386 | conn.readInt32() 387 | 388 | format := fieldFormat(conn.readInt16()) 389 | switch format { 390 | case textFormat, binaryFormat: 391 | // nop 392 | 393 | default: 394 | panic("unsupported field format") 395 | } 396 | rs.fields[ord].format = format 397 | } 398 | } 399 | 400 | func (conn *Conn) readBackendMessages(rs *ResultSet) { 401 | if conn.LogLevel >= LogDebug { 402 | defer conn.logExit(conn.logEnter("*Conn.readBackendMessages")) 403 | } 404 | 405 | for { 406 | msgCode := backendMessageCode(conn.readByte()) 407 | 408 | if conn.LogLevel >= LogDebug { 409 | conn.logf(LogDebug, "received '%s' backend message", msgCode) 410 | } 411 | 412 | switch msgCode { 413 | case _AuthenticationRequest: 414 | conn.readAuthenticationRequest() 415 | 416 | case _BackendKeyData: 417 | conn.readBackendKeyData() 418 | 419 | case _BindComplete: 420 | conn.readBindComplete() 421 | return 422 | 423 | case _CloseComplete: 424 | conn.readCloseComplete() 425 | 426 | case _CommandComplete: 427 | conn.readCommandComplete(rs) 428 | return 429 | 430 | case _CopyInResponse: 431 | conn.readCopyInResponse() 432 | return 433 | 434 | case _DataRow: 435 | rs.readRow() 436 | return 437 | 438 | case _EmptyQueryResponse: 439 | conn.readEmptyQueryResponse() 440 | 441 | case _ErrorResponse: 442 | conn.readErrorOrNoticeResponse(true) 443 | 444 | case _NoData: 445 | conn.readNoData() 446 | return 447 | 448 | case _NoticeResponse: 449 | conn.readErrorOrNoticeResponse(false) 450 | 451 | case _ParameterStatus: 452 | conn.readParameterStatus() 453 | 454 | case _ParseComplete: 455 | conn.readParseComplete() 456 | return 457 | 458 | case _ReadyForQuery: 459 | conn.readReadyForQuery(rs) 460 | return 461 | 462 | case _RowDescription: 463 | rs.initializeResult() 464 | return 465 | } 466 | } 467 | } 468 | -------------------------------------------------------------------------------- /conn.go: -------------------------------------------------------------------------------- 1 | // Copyright 2010 The go-pgsql 
Authors. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 4 | 5 | // The pgsql package implements a PostgreSQL frontend library. 6 | // It is compatible with servers of version 7.4 and later. 7 | package pgsql 8 | 9 | import ( 10 | "bufio" 11 | "bytes" 12 | "errors" 13 | "fmt" 14 | "io" 15 | "net" 16 | "os" 17 | "strconv" 18 | "strings" 19 | "time" 20 | ) 21 | 22 | // LogLevel is used to control what is written to the log. 23 | type LogLevel int 24 | 25 | const ( 26 | // Log nothing. 27 | LogNothing LogLevel = iota 28 | 29 | // Log fatal errors. 30 | LogFatal 31 | 32 | // Log all errors. 33 | LogError 34 | 35 | // Log errors and warnings. 36 | LogWarning 37 | 38 | // Log errors, warnings and sent commands. 39 | LogCommand 40 | 41 | // Log errors, warnings, sent commands and additional debug info. 42 | LogDebug 43 | 44 | // Log everything. 45 | LogVerbose 46 | ) 47 | 48 | type connParams struct { 49 | Host string 50 | Port int 51 | User string 52 | Password string 53 | Database string 54 | TimeoutSeconds int 55 | } 56 | 57 | // ConnStatus represents the status of a connection. 58 | type ConnStatus int 59 | 60 | const ( 61 | StatusDisconnected ConnStatus = iota 62 | StatusReady 63 | StatusProcessingQuery 64 | StatusCopy 65 | ) 66 | 67 | func (s ConnStatus) String() string { 68 | switch s { 69 | case StatusDisconnected: 70 | return "Disconnected" 71 | 72 | case StatusReady: 73 | return "Ready" 74 | 75 | case StatusProcessingQuery: 76 | return "Processing Query" 77 | 78 | case StatusCopy: 79 | return "Bulk Copy" 80 | } 81 | 82 | return "Unknown" 83 | } 84 | 85 | // IsolationLevel represents the isolation level of a transaction. 
86 | type IsolationLevel int 87 | 88 | const ( 89 | ReadCommittedIsolation IsolationLevel = iota 90 | SerializableIsolation 91 | ) 92 | 93 | func (il IsolationLevel) String() string { 94 | switch il { 95 | case ReadCommittedIsolation: 96 | return "Read Committed" 97 | 98 | case SerializableIsolation: 99 | return "Serializable" 100 | } 101 | 102 | return "Unknown" 103 | } 104 | 105 | // TransactionStatus represents the transaction status of a connection. 106 | type TransactionStatus byte 107 | 108 | const ( 109 | NotInTransaction TransactionStatus = 'I' 110 | InTransaction TransactionStatus = 'T' 111 | InFailedTransaction TransactionStatus = 'E' 112 | ) 113 | 114 | func (s TransactionStatus) String() string { 115 | switch s { 116 | case NotInTransaction: 117 | return "Not In Transaction" 118 | 119 | case InTransaction: 120 | return "In Transaction" 121 | 122 | case InFailedTransaction: 123 | return "In Failed Transaction" 124 | } 125 | 126 | return "Unknown" 127 | } 128 | 129 | // Conn represents a PostgreSQL database connection. 
130 | type Conn struct { 131 | LogLevel LogLevel 132 | tcpConn net.Conn 133 | reader *bufio.Reader 134 | writer *bufio.Writer 135 | params *connParams 136 | state state 137 | backendPID int32 138 | backendSecretKey int32 139 | onErrorDontRequireReadyForQuery bool 140 | runtimeParameters map[string]string 141 | nextStatementId uint64 142 | nextPortalId uint64 143 | nextSavepointId uint64 144 | transactionStatus TransactionStatus 145 | dateFormat string 146 | timeFormat string 147 | timestampFormat string 148 | timestampTimezoneFormat string 149 | } 150 | 151 | func (conn *Conn) withRecover(funcName string, f func()) (err error) { 152 | if conn.LogLevel >= LogDebug { 153 | defer conn.logExit(conn.logEnter(funcName)) 154 | } 155 | 156 | defer func() { 157 | if x := recover(); x != nil { 158 | err = conn.logAndConvertPanic(x) 159 | } 160 | }() 161 | 162 | f() 163 | 164 | return 165 | } 166 | 167 | func parseParamsInUnquotedSubstring(s string, name2value map[string]string) (lastKeyword string) { 168 | var words []string 169 | 170 | for { 171 | index := strings.IndexAny(s, "= \n\r\t") 172 | if index == -1 { 173 | break 174 | } 175 | 176 | word := s[0:index] 177 | if word != "" { 178 | words = append(words, word) 179 | } 180 | s = s[index+1:] 181 | } 182 | if len(s) > 0 { 183 | words = append(words, s) 184 | } 185 | 186 | for i := 0; i < len(words)-1; i += 2 { 187 | name2value[words[i]] = words[i+1] 188 | } 189 | 190 | if len(words) > 0 && len(words)%2 == 1 { 191 | lastKeyword = words[len(words)-1] 192 | } 193 | 194 | return 195 | } 196 | 197 | func (conn *Conn) parseParams(s string) *connParams { 198 | name2value := make(map[string]string) 199 | 200 | quoteIndexPairs := quoteRegExp.FindAllStringIndex(s, -1) 201 | prevQuoteEnd := 0 202 | 203 | for _, pair := range quoteIndexPairs { 204 | quoteStart := pair[0] 205 | quoteEnd := pair[1] 206 | 207 | lastKeyword := parseParamsInUnquotedSubstring(s[prevQuoteEnd:quoteStart], name2value) 208 | if lastKeyword != "" { 209 | 
name2value[lastKeyword] = s[quoteStart+1 : quoteEnd-1] 210 | } 211 | 212 | prevQuoteEnd = quoteEnd 213 | } 214 | 215 | if prevQuoteEnd > 0 { 216 | parseParamsInUnquotedSubstring(s[prevQuoteEnd:], name2value) 217 | } else { 218 | parseParamsInUnquotedSubstring(s, name2value) 219 | } 220 | 221 | params := &connParams{} 222 | 223 | params.Host = name2value["host"] 224 | params.Port, _ = strconv.Atoi(name2value["port"]) 225 | params.Database = name2value["dbname"] 226 | params.User = name2value["user"] 227 | params.Password = name2value["password"] 228 | if params.Password == "" { 229 | params.Password, _ = passwordfromfile(params.Host, params.Port, params.Database, params.User) 230 | } 231 | params.TimeoutSeconds, _ = strconv.Atoi(name2value["timeout"]) 232 | 233 | if conn.LogLevel >= LogDebug { 234 | buf := bytes.NewBuffer(nil) 235 | 236 | for name, value := range name2value { 237 | buf.WriteString(fmt.Sprintf("%s = '%s'\n", name, value)) 238 | } 239 | 240 | conn.log(LogDebug, "Parsed connection parameter settings:\n", buf) 241 | } 242 | 243 | return params 244 | } 245 | 246 | // Connect establishes a database connection. 247 | // 248 | // Parameter settings in connStr have to be separated by whitespace and are 249 | // expected in keyword = value form. Spaces around equal signs are optional. 250 | // Use single quotes for empty values or values containing spaces. 
251 | // 252 | // Currently these keywords are supported: 253 | // 254 | // host = Name of the host to connect to (default: localhost) 255 | // port = Integer port number the server listens on (default: 5432) 256 | // dbname = Database name (default: same as user) 257 | // user = User to connect as 258 | // password = Password for password based authentication methods 259 | // timeout = Timeout in seconds, 0 or not specified disables timeout (default: 0) 260 | func Connect(connStr string, logLevel LogLevel) (conn *Conn, err error) { 261 | newConn := &Conn{} 262 | 263 | newConn.LogLevel = logLevel 264 | newConn.state = disconnectedState{} 265 | 266 | if newConn.LogLevel >= LogDebug { 267 | defer newConn.logExit(newConn.logEnter("Connect")) 268 | } 269 | 270 | defer func() { 271 | if x := recover(); x != nil { 272 | err = newConn.logAndConvertPanic(x) 273 | } 274 | }() 275 | 276 | params := newConn.parseParams(connStr) 277 | newConn.params = params 278 | 279 | var env string // Reusable environment variable used to capture PG environment variables - PGHOST, PGPORT, PGDATABASE, PGUSER 280 | 281 | if params.Host == "" { 282 | params.Host = "localhost" 283 | } 284 | env = os.Getenv("PGHOST") 285 | if env != "" { 286 | params.Host = env 287 | } 288 | if params.Port == 0 { 289 | params.Port = 5432 290 | } 291 | env = os.Getenv("PGPORT") 292 | if env != "" { 293 | params.Port, _ = strconv.Atoi(env) 294 | } 295 | if params.Database == "" { 296 | params.Database = params.User 297 | } 298 | env = os.Getenv("PGDATABASE") 299 | if env != "" { 300 | params.Database = env 301 | } 302 | env = os.Getenv("PGUSER") 303 | if env != "" { 304 | params.User = env 305 | } 306 | 307 | tcpConn, err := net.Dial("tcp", fmt.Sprintf("%s:%d", params.Host, params.Port)) 308 | panicIfErr(err) 309 | 310 | panicIfErr(tcpConn.SetDeadline(time.Unix(int64(params.TimeoutSeconds*1000*1000*1000), 0))) 311 | 312 | newConn.tcpConn = tcpConn 313 | 314 | newConn.reader = bufio.NewReader(tcpConn) 315 | 
newConn.writer = bufio.NewWriter(tcpConn) 316 | 317 | newConn.runtimeParameters = make(map[string]string) 318 | 319 | newConn.onErrorDontRequireReadyForQuery = true 320 | defer func() { 321 | newConn.onErrorDontRequireReadyForQuery = false 322 | }() 323 | 324 | newConn.writeStartup() 325 | 326 | newConn.readBackendMessages(nil) 327 | 328 | newConn.state = readyState{} 329 | newConn.params = nil 330 | 331 | newConn.transactionStatus = NotInTransaction 332 | 333 | conn = newConn 334 | 335 | return 336 | } 337 | 338 | // Close closes the connection to the database. 339 | func (conn *Conn) Close() (err error) { 340 | return conn.withRecover("*Conn.Close", func() { 341 | if conn.Status() == StatusDisconnected { 342 | err = errors.New("connection already closed") 343 | conn.logError(LogWarning, err) 344 | return 345 | } 346 | 347 | conn.writeTerminate() 348 | 349 | panicIfErr(conn.tcpConn.Close()) 350 | 351 | conn.state = disconnectedState{} 352 | }) 353 | } 354 | 355 | func (conn *Conn) copyFrom(command string, r io.Reader) int64 { 356 | if conn.LogLevel >= LogDebug { 357 | defer conn.logExit(conn.logEnter("*Conn.copyFrom")) 358 | } 359 | 360 | conn.writeQuery(command) 361 | conn.readBackendMessages(nil) 362 | if stateCode := conn.state.code(); stateCode != StatusCopy { 363 | panic("wrong state, expected: StatusCopy, have: " + stateCode.String()) 364 | return 0 365 | } 366 | 367 | // FIXME: magic number; wild guess without any reason. 
368 | const CopyBufferSize = 32 << 10 369 | buf := make([]byte, CopyBufferSize) 370 | var nr int 371 | var err error 372 | for { 373 | nr, err = r.Read(buf) 374 | if err != nil && err != io.EOF { 375 | message := err.Error() 376 | conn.writeFrontendMessageCode(_CopyFail) 377 | conn.writeInt32(int32(5 + len(message))) 378 | conn.writeString0(message) 379 | panic(err) 380 | } 381 | if nr > 0 { 382 | conn.writeFrontendMessageCode(_CopyData_FE) 383 | conn.writeInt32(int32(4 + nr)) 384 | conn.write(buf[:nr]) 385 | conn.flush() 386 | } 387 | // TODO: peek backend message. Maybe there was error in data 388 | // and we can stop sending early. 389 | if err == io.EOF { 390 | break 391 | } 392 | } 393 | conn.writeFrontendMessageCode(_CopyDone_FE) 394 | conn.writeInt32(4) 395 | conn.flush() 396 | 397 | rs := newResultSet(conn) 398 | conn.readBackendMessages(rs) 399 | rs.close() 400 | 401 | return rs.rowsAffected 402 | } 403 | 404 | // CopyFrom sends a `COPY table FROM STDIN` SQL command to the server and 405 | // returns the number of rows affected. 
406 | func (conn *Conn) CopyFrom(command string, r io.Reader) (rowsAffected int64, err error) { 407 | err = conn.withRecover("*Conn.CopyFrom", func() { 408 | rowsAffected = conn.copyFrom(command, r) 409 | }) 410 | 411 | return 412 | } 413 | 414 | func getpgpassfilename() string { 415 | var env string 416 | env = os.Getenv("PGPASSFILE") 417 | if env != "" { 418 | return env 419 | } 420 | env = os.Getenv("HOME") 421 | return fmt.Sprintf("%s/.pgpass", env) 422 | } 423 | 424 | func passwordfromfile(hostname string, port int, dbname string, username string) (string, error) { 425 | var sport string 426 | var lhostname string 427 | if dbname == "" { 428 | return "", nil 429 | } 430 | if username == "" { 431 | return "", nil 432 | } 433 | if hostname == "" { 434 | lhostname = "localhost" 435 | } else { 436 | lhostname = hostname 437 | } 438 | if port == 0 { 439 | sport = "5432" 440 | } else { 441 | sport = fmt.Sprintf("%d", port) 442 | } 443 | pgfile := getpgpassfilename() 444 | fileinfo, err := os.Stat(pgfile) 445 | if err != nil { 446 | err := errors.New(fmt.Sprintf("WARNING: password file \"%s\" is not a plain file\n", pgfile)) 447 | return "", err 448 | } 449 | if (fileinfo.Mode() & 077) != 0 { 450 | err := errors.New(fmt.Sprintf("WARNING: password file \"%s\" has group or world access; permissions should be u=rw (0600) or less", pgfile)) 451 | return "", err 452 | } 453 | fp, err := os.Open(pgfile) 454 | if err != nil { 455 | err := errors.New(fmt.Sprintf("Problem opening pgpass file \"%s\"", pgfile)) 456 | return "", err 457 | } 458 | br := bufio.NewReader(fp) 459 | for { 460 | line, ok := br.ReadString('\n') 461 | if ok == io.EOF { 462 | return "", nil 463 | } 464 | // Now, split the line into pieces 465 | // hostname:port:database:username:password 466 | // and * matches anything 467 | pieces := strings.Split(line, ":") 468 | phost := pieces[0] 469 | pport := pieces[1] 470 | pdb := pieces[2] 471 | puser := pieces[3] 472 | ppass := pieces[4] 473 | 474 | if (phost == 
lhostname || phost == "*") && 475 | (pport == "*" || pport == sport) && 476 | (pdb == "*" || pdb == dbname) && 477 | (puser == "*" || puser == username) { 478 | 479 | return ppass, nil 480 | } 481 | } 482 | return "", nil 483 | } 484 | 485 | func (conn *Conn) execute(command string, params ...*Parameter) int64 { 486 | if conn.LogLevel >= LogDebug { 487 | defer conn.logExit(conn.logEnter("*Conn.execute")) 488 | } 489 | 490 | rs := conn.query(command, params...) 491 | rs.close() 492 | 493 | return rs.rowsAffected 494 | } 495 | 496 | // Execute sends a SQL command to the server and returns the number 497 | // of rows affected. 498 | // 499 | // If the results of a query are needed, use the 500 | // Query method instead. 501 | func (conn *Conn) Execute(command string, params ...*Parameter) (rowsAffected int64, err error) { 502 | err = conn.withRecover("*Conn.Execute", func() { 503 | rowsAffected = conn.execute(command, params...) 504 | }) 505 | 506 | return 507 | } 508 | 509 | func (conn *Conn) prepare(command string, params ...*Parameter) *Statement { 510 | if conn.LogLevel >= LogDebug { 511 | defer conn.logExit(conn.logEnter("*Conn.prepare")) 512 | } 513 | 514 | stmt := newStatement(conn, command, params) 515 | 516 | conn.state.prepare(stmt) 517 | 518 | return stmt 519 | } 520 | 521 | // Prepare returns a new prepared Statement, optimized to be executed multiple 522 | // times with different parameter values. 523 | func (conn *Conn) Prepare(command string, params ...*Parameter) (stmt *Statement, err error) { 524 | err = conn.withRecover("*Conn.Prepare", func() { 525 | stmt = conn.prepare(command, params...) 
526 | }) 527 | 528 | return 529 | } 530 | 531 | func (conn *Conn) query(command string, params ...*Parameter) (rs *ResultSet) { 532 | if conn.LogLevel >= LogDebug { 533 | defer conn.logExit(conn.logEnter("*Conn.query")) 534 | } 535 | 536 | var stmt *Statement 537 | if len(params) == 0 { 538 | r := newResultSet(conn) 539 | 540 | conn.state.query(conn, r, command) 541 | 542 | rs = r 543 | } else { 544 | stmt = conn.prepare(command, params...) 545 | defer stmt.close() 546 | 547 | rs = stmt.query() 548 | } 549 | 550 | return 551 | } 552 | 553 | // Query sends a SQL query to the server and returns a 554 | // ResultSet for row-by-row retrieval of the results. 555 | // 556 | // The returned ResultSet must be closed before sending another 557 | // query or command to the server over the same connection. 558 | func (conn *Conn) Query(command string, params ...*Parameter) (rs *ResultSet, err error) { 559 | err = conn.withRecover("*Conn.Query", func() { 560 | rs = conn.query(command, params...) 561 | }) 562 | 563 | return 564 | } 565 | 566 | // RuntimeParameter returns the value of the specified runtime parameter. 567 | // 568 | // If the value was successfully retrieved, ok is true, otherwise false. 569 | func (conn *Conn) RuntimeParameter(name string) (value string, ok bool) { 570 | if conn.LogLevel >= LogVerbose { 571 | defer conn.logExit(conn.logEnter("*Conn.RuntimeParameter")) 572 | } 573 | 574 | value, ok = conn.runtimeParameters[name] 575 | return 576 | } 577 | 578 | func (conn *Conn) scan(command string, args ...interface{}) (*ResultSet, bool) { 579 | if conn.LogLevel >= LogDebug { 580 | defer conn.logExit(conn.logEnter("*Conn.scan")) 581 | } 582 | 583 | rs := conn.query(command) 584 | 585 | return rs, rs.scanNext(args...) 586 | } 587 | 588 | // Scan executes the command and scans the fields of the first row 589 | // in the ResultSet, trying to store field values into the specified 590 | // arguments. 591 | // 592 | // The arguments must be of pointer types. 
If a row has 593 | // been fetched, fetched will be true, otherwise false. 594 | func (conn *Conn) Scan(command string, args ...interface{}) (fetched bool, err error) { 595 | err = conn.withRecover("*Conn.Scan", func() { 596 | var rs *ResultSet 597 | rs, fetched = conn.scan(command, args...) 598 | rs.close() 599 | }) 600 | 601 | return 602 | } 603 | 604 | // Status returns the current connection status. 605 | func (conn *Conn) Status() ConnStatus { 606 | return conn.state.code() 607 | } 608 | 609 | // TransactionStatus returns the current transaction status of the connection. 610 | func (conn *Conn) TransactionStatus() TransactionStatus { 611 | return conn.transactionStatus 612 | } 613 | 614 | // WithTransaction starts a new transaction, if none is in progress, then 615 | // calls f. 616 | // 617 | // If f returns an error or panicks, the transaction is rolled back, 618 | // otherwise it is committed. If the connection is in a failed transaction when 619 | // calling WithTransaction, this function immediately returns with an error, 620 | // without calling f. In case of an active transaction without error, 621 | // WithTransaction just calls f. 
622 | func (conn *Conn) WithTransaction(isolation IsolationLevel, f func() error) (err error) { 623 | if conn.LogLevel >= LogDebug { 624 | defer conn.logExit(conn.logEnter("*Conn.WithTransaction")) 625 | } 626 | 627 | oldStatus := conn.transactionStatus 628 | 629 | if oldStatus == InFailedTransaction { 630 | return conn.logAndConvertPanic("error in transaction") 631 | } 632 | 633 | defer func() { 634 | if x := recover(); x != nil { 635 | err = conn.logAndConvertPanic(x) 636 | } 637 | if err == nil && conn.transactionStatus == InFailedTransaction { 638 | err = conn.logAndConvertPanic("error in transaction") 639 | } 640 | if err != nil && oldStatus == NotInTransaction { 641 | conn.execute("ROLLBACK;") 642 | } 643 | }() 644 | 645 | if oldStatus == NotInTransaction { 646 | var isol string 647 | if isolation == SerializableIsolation { 648 | isol = "SERIALIZABLE" 649 | } else { 650 | isol = "READ COMMITTED" 651 | } 652 | cmd := fmt.Sprintf("BEGIN; SET TRANSACTION ISOLATION LEVEL %s;", isol) 653 | conn.execute(cmd) 654 | } 655 | 656 | panicIfErr(f()) 657 | 658 | if oldStatus == NotInTransaction && conn.transactionStatus == InTransaction { 659 | conn.execute("COMMIT;") 660 | } 661 | return 662 | } 663 | 664 | // WithSavepoint creates a transaction savepoint, if the connection is in an 665 | // active transaction without errors, then calls f. 666 | // 667 | // If f returns an error or 668 | // panicks, the transaction is rolled back to the savepoint. If the connection 669 | // is in a failed transaction when calling WithSavepoint, this function 670 | // immediately returns with an error, without calling f. If no transaction is in 671 | // progress, instead of creating a savepoint, a new transaction is started. 
672 | func (conn *Conn) WithSavepoint(isolation IsolationLevel, f func() error) (err error) { 673 | if conn.LogLevel >= LogDebug { 674 | defer conn.logExit(conn.logEnter("*Conn.WithSavepoint")) 675 | } 676 | 677 | oldStatus := conn.transactionStatus 678 | 679 | switch oldStatus { 680 | case InFailedTransaction: 681 | return conn.logAndConvertPanic("error in transaction") 682 | 683 | case NotInTransaction: 684 | return conn.WithTransaction(isolation, f) 685 | } 686 | 687 | savepointName := fmt.Sprintf("sp%d", conn.nextSavepointId) 688 | conn.nextSavepointId++ 689 | 690 | defer func() { 691 | if x := recover(); x != nil { 692 | err = conn.logAndConvertPanic(x) 693 | } 694 | if err == nil && conn.transactionStatus == InFailedTransaction { 695 | err = conn.logAndConvertPanic("error in transaction") 696 | } 697 | if err != nil { 698 | conn.execute(fmt.Sprintf("ROLLBACK TO %s;", savepointName)) 699 | } 700 | }() 701 | 702 | conn.execute(fmt.Sprintf("SAVEPOINT %s;", savepointName)) 703 | 704 | panicIfErr(f()) 705 | 706 | return 707 | } 708 | 709 | func (conn *Conn) updateTimeFormats() { 710 | style := conn.runtimeParameters["DateStyle"] 711 | 712 | switch style { 713 | case "ISO", "ISO, DMY", "ISO, MDY", "ISO, YMD": 714 | conn.dateFormat = "2006-01-02" 715 | conn.timeFormat = "15:04:05" 716 | conn.timestampFormat = "2006-01-02 15:04:05" 717 | conn.timestampTimezoneFormat = "-07" 718 | 719 | case "SQL", "SQL, MDY": 720 | conn.dateFormat = "01/02/2006" 721 | conn.timeFormat = "15:04:05" 722 | conn.timestampFormat = "01/02/2006 15:04:05" 723 | conn.timestampTimezoneFormat = " MST" 724 | 725 | case "SQL, DMY": 726 | conn.dateFormat = "02/01/2006" 727 | conn.timeFormat = "15:04:05" 728 | conn.timestampFormat = "02/01/2006 15:04:05" 729 | conn.timestampTimezoneFormat = " MST" 730 | 731 | case "Postgres", "Postgres, DMY": 732 | conn.dateFormat = "02-01-2006" 733 | conn.timeFormat = "15:04:05" 734 | conn.timestampFormat = "Mon 02 Jan 15:04:05 2006" 735 | 
conn.timestampTimezoneFormat = " MST" 736 | 737 | case "Postgres, MDY": 738 | conn.dateFormat = "01-02-2006" 739 | conn.timeFormat = "15:04:05" 740 | conn.timestampFormat = "Mon Jan 02 15:04:05 2006" 741 | conn.timestampTimezoneFormat = " MST" 742 | 743 | case "German", "German, DMY", "German, MDY": 744 | conn.dateFormat = "02.01.2006" 745 | conn.timeFormat = "15:04:05" 746 | conn.timestampFormat = "02.01.2006 15:04:05" 747 | conn.timestampTimezoneFormat = " MST" 748 | 749 | default: 750 | if conn.LogLevel >= LogWarning { 751 | conn.log(LogWarning, "Unknown DateStyle: "+style) 752 | } 753 | conn.dateFormat = "" 754 | conn.timeFormat = "" 755 | conn.timestampFormat = "" 756 | conn.timestampTimezoneFormat = "" 757 | } 758 | } 759 | -------------------------------------------------------------------------------- /resultset.go: -------------------------------------------------------------------------------- 1 | // Copyright 2010 The go-pgsql Authors. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 4 | 5 | package pgsql 6 | 7 | import ( 8 | "encoding/binary" 9 | "fmt" 10 | "math" 11 | "math/big" 12 | "strconv" 13 | "strings" 14 | "time" 15 | ) 16 | 17 | type fieldFormat int16 18 | 19 | const ( 20 | textFormat fieldFormat = 0 21 | binaryFormat fieldFormat = 1 22 | ) 23 | 24 | type field struct { 25 | name string 26 | format fieldFormat 27 | typeOID int32 28 | } 29 | 30 | // ResultSet reads the results of a query, row by row, and provides methods to 31 | // retrieve field values of the current row. 32 | // 33 | // Access is by 0-based field ordinal position. 
34 | type ResultSet struct { 35 | conn *Conn 36 | stmt *Statement 37 | hasCurrentRow bool 38 | currentResultComplete bool 39 | allResultsComplete bool 40 | rowsAffected int64 41 | name2ord map[string]int 42 | fields []field 43 | values [][]byte 44 | } 45 | 46 | func newResultSet(conn *Conn) *ResultSet { 47 | if conn.LogLevel >= LogDebug { 48 | defer conn.logExit(conn.logEnter("newResultSet")) 49 | } 50 | 51 | return &ResultSet{conn: conn} 52 | } 53 | 54 | func (rs *ResultSet) initializeResult() { 55 | if rs.conn.LogLevel >= LogDebug { 56 | defer rs.conn.logExit(rs.conn.logEnter("*ResultSet.initializeResult")) 57 | } 58 | 59 | rs.conn.readRowDescription(rs) 60 | 61 | rs.name2ord = make(map[string]int) 62 | 63 | for ord, field := range rs.fields { 64 | rs.name2ord[field.name] = ord 65 | } 66 | 67 | rs.currentResultComplete = false 68 | rs.hasCurrentRow = false 69 | } 70 | 71 | func (rs *ResultSet) readRow() { 72 | if rs.conn.LogLevel >= LogDebug { 73 | defer rs.conn.logExit(rs.conn.logEnter("*ResultSet.readRow")) 74 | } 75 | 76 | rs.conn.readDataRow(rs) 77 | 78 | rs.hasCurrentRow = true 79 | } 80 | 81 | func (rs *ResultSet) eatCurrentResultRows() { 82 | for { 83 | hasRow := rs.fetchNext() 84 | if !hasRow { 85 | return 86 | } 87 | } 88 | } 89 | 90 | func (rs *ResultSet) eatAllResultRows() { 91 | for { 92 | hasResult := rs.nextResult() 93 | if !hasResult { 94 | return 95 | } 96 | } 97 | } 98 | 99 | // Conn returns the *Conn this ResultSet is associated with. 100 | func (rs *ResultSet) Conn() *Conn { 101 | return rs.conn 102 | } 103 | 104 | // Statement returns the *Statement this ResultSet is associated with. 
105 | func (rs *ResultSet) Statement() *Statement { 106 | return rs.stmt 107 | } 108 | 109 | func (rs *ResultSet) nextResult() bool { 110 | if rs.conn.LogLevel >= LogDebug { 111 | defer rs.conn.logExit(rs.conn.logEnter("*ResultSet.nextResult")) 112 | } 113 | 114 | rs.eatCurrentResultRows() 115 | 116 | if !rs.allResultsComplete { 117 | rs.conn.readBackendMessages(rs) 118 | } 119 | 120 | return !rs.allResultsComplete 121 | } 122 | 123 | // NextResult moves the ResultSet to the next result, if there is one. 124 | // 125 | // In this case true is returned, otherwise false. 126 | // Statements support a single result only, use *Conn.Query if you need 127 | // this functionality. 128 | func (rs *ResultSet) NextResult() (hasResult bool, err error) { 129 | err = rs.conn.withRecover("*ResultSet.NextResult", func() { 130 | hasResult = rs.nextResult() 131 | }) 132 | 133 | return 134 | } 135 | 136 | func (rs *ResultSet) fetchNext() bool { 137 | if rs.conn.LogLevel >= LogDebug { 138 | defer rs.conn.logExit(rs.conn.logEnter("*ResultSet.fetchNext")) 139 | } 140 | 141 | if rs.currentResultComplete { 142 | return false 143 | } 144 | 145 | rs.conn.readBackendMessages(rs) 146 | 147 | return !rs.currentResultComplete 148 | } 149 | 150 | func (rs *ResultSet) setCompletedOnPgsqlError(err error) { 151 | if err != nil && !rs.hasCurrentRow { 152 | if _, ok := err.(*Error); ok { 153 | // This is likely an exception raised by a user defined PostgreSQL 154 | // function. 155 | // FIXME: Not sure if this handling is sane. 156 | rs.currentResultComplete = true 157 | rs.allResultsComplete = true 158 | } 159 | } 160 | } 161 | 162 | // FetchNext reads the next row, if there is one. 163 | // 164 | // In this case true is returned, otherwise false. 
165 | func (rs *ResultSet) FetchNext() (hasRow bool, err error) { 166 | err = rs.conn.withRecover("*ResultSet.FetchNext", func() { 167 | hasRow = rs.fetchNext() 168 | }) 169 | 170 | rs.setCompletedOnPgsqlError(err) 171 | 172 | return 173 | } 174 | 175 | func (rs *ResultSet) close() { 176 | if rs.conn.LogLevel >= LogDebug { 177 | defer rs.conn.logExit(rs.conn.logEnter("*ResultSet.close")) 178 | } 179 | 180 | if rs.stmt != nil { 181 | defer rs.conn.writeClose('P', rs.stmt.portalName) 182 | } 183 | 184 | // TODO: Instead of eating all records, try to cancel the query processing. 185 | // (The required message has to be sent through another connection though.) 186 | rs.eatAllResultRows() 187 | 188 | rs.conn.state = readyState{} 189 | } 190 | 191 | // Close closes the ResultSet, so another query or command can be sent to 192 | // the server over the same connection. 193 | func (rs *ResultSet) Close() (err error) { 194 | err = rs.conn.withRecover("*ResultSet.Close", func() { 195 | rs.close() 196 | }) 197 | 198 | return 199 | } 200 | 201 | func (rs *ResultSet) isNull(ord int) bool { 202 | if rs.conn.LogLevel >= LogVerbose { 203 | defer rs.conn.logExit(rs.conn.logEnter("*ResultSet.isNull")) 204 | } 205 | 206 | // Since all field value retrieval methods call this method, 207 | // we only check for a valid current row here. 208 | if !rs.hasCurrentRow { 209 | panic("invalid row") 210 | } 211 | 212 | return rs.values[ord] == nil 213 | } 214 | 215 | // IsNull returns if the value of the field with the specified ordinal is null. 216 | func (rs *ResultSet) IsNull(ord int) (isNull bool, err error) { 217 | err = rs.conn.withRecover("*ResultSet.IsNull", func() { 218 | isNull = rs.isNull(ord) 219 | }) 220 | 221 | return 222 | } 223 | 224 | // FieldCount returns the number of fields in the current result of the ResultSet. 
225 | func (rs *ResultSet) FieldCount() int { 226 | if rs.conn.LogLevel >= LogVerbose { 227 | defer rs.conn.logExit(rs.conn.logEnter("*ResultSet.FieldCount")) 228 | } 229 | 230 | return len(rs.fields) 231 | } 232 | 233 | // Name returns the name of the field with the specified ordinal. 234 | func (rs *ResultSet) Name(ord int) (name string, err error) { 235 | err = rs.conn.withRecover("*ResultSet.Name", func() { 236 | name = rs.fields[ord].name 237 | }) 238 | 239 | return 240 | } 241 | 242 | // Type returns the PostgreSQL type of the field with the specified ordinal. 243 | func (rs *ResultSet) Type(ord int) (typ Type, err error) { 244 | err = rs.conn.withRecover("*ResultSet.Type", func() { 245 | switch t := rs.fields[ord].typeOID; t { 246 | case _BOOLOID, _CHAROID, _DATEOID, _FLOAT4OID, _FLOAT8OID, _INT2OID, 247 | _INT4OID, _INT8OID, _NUMERICOID, _TEXTOID, _TIMEOID, _TIMETZOID, 248 | _TIMESTAMPOID, _TIMESTAMPTZOID, _VARCHAROID: 249 | typ = Type(t) 250 | return 251 | } 252 | 253 | typ = Custom 254 | }) 255 | 256 | return 257 | } 258 | 259 | // Ordinal returns the 0-based ordinal position of the field with the 260 | // specified name, or -1 if the ResultSet has no field with such a name. 
261 | func (rs *ResultSet) Ordinal(name string) int { 262 | if rs.conn.LogLevel >= LogVerbose { 263 | defer rs.conn.logExit(rs.conn.logEnter("*ResultSet.Ordinal")) 264 | } 265 | 266 | ord, ok := rs.name2ord[name] 267 | if !ok { 268 | return -1 269 | } 270 | 271 | return ord 272 | } 273 | 274 | func (rs *ResultSet) bool(ord int) (value, isNull bool) { 275 | if rs.conn.LogLevel >= LogVerbose { 276 | defer rs.conn.logExit(rs.conn.logEnter("*ResultSet.bool")) 277 | } 278 | 279 | isNull = rs.isNull(ord) 280 | if isNull { 281 | return 282 | } 283 | 284 | val := rs.values[ord] 285 | 286 | switch rs.fields[ord].format { 287 | case textFormat: 288 | value = val[0] == 't' 289 | 290 | case binaryFormat: 291 | value = val[0] != 0 292 | } 293 | 294 | return 295 | } 296 | 297 | // Bool returns the value of the field with the specified ordinal as bool. 298 | func (rs *ResultSet) Bool(ord int) (value, isNull bool, err error) { 299 | err = rs.conn.withRecover("*ResultSet.Bool", func() { 300 | value, isNull = rs.bool(ord) 301 | }) 302 | 303 | return 304 | } 305 | 306 | func (rs *ResultSet) float32(ord int) (value float32, isNull bool) { 307 | if rs.conn.LogLevel >= LogVerbose { 308 | defer rs.conn.logExit(rs.conn.logEnter("*ResultSet.float32")) 309 | } 310 | 311 | isNull = rs.isNull(ord) 312 | if isNull { 313 | return 314 | } 315 | 316 | val := rs.values[ord] 317 | 318 | switch rs.fields[ord].format { 319 | case textFormat: 320 | // strconv.Atof32 does not handle "-Infinity" and "Infinity" 321 | valStr := string(val) 322 | switch valStr { 323 | case "-Infinity": 324 | value = float32(math.Inf(-1)) 325 | 326 | case "Infinity": 327 | value = float32(math.Inf(1)) 328 | 329 | default: 330 | val, err := strconv.ParseFloat(valStr, 32) 331 | panicIfErr(err) 332 | value = float32(val) 333 | } 334 | 335 | case binaryFormat: 336 | value = math.Float32frombits(binary.BigEndian.Uint32(val)) 337 | } 338 | 339 | return 340 | } 341 | 342 | // Float32 returns the value of the field with the 
specified ordinal as float32. 343 | func (rs *ResultSet) Float32(ord int) (value float32, isNull bool, err error) { 344 | err = rs.conn.withRecover("*ResultSet.Float32", func() { 345 | value, isNull = rs.float32(ord) 346 | }) 347 | 348 | return 349 | } 350 | 351 | func (rs *ResultSet) float64(ord int) (value float64, isNull bool) { 352 | if rs.conn.LogLevel >= LogVerbose { 353 | defer rs.conn.logExit(rs.conn.logEnter("*ResultSet.float64")) 354 | } 355 | 356 | isNull = rs.isNull(ord) 357 | if isNull { 358 | return 359 | } 360 | 361 | val := rs.values[ord] 362 | 363 | switch rs.fields[ord].format { 364 | case textFormat: 365 | // strconv.Atof64 does not handle "-Infinity" and "Infinity" 366 | valStr := string(val) 367 | switch valStr { 368 | case "-Infinity": 369 | value = math.Inf(-1) 370 | 371 | case "Infinity": 372 | value = math.Inf(1) 373 | 374 | default: 375 | var err error 376 | value, err = strconv.ParseFloat(valStr, 64) 377 | panicIfErr(err) 378 | } 379 | 380 | case binaryFormat: 381 | value = math.Float64frombits(binary.BigEndian.Uint64(val)) 382 | } 383 | 384 | return 385 | } 386 | 387 | // Float64 returns the value of the field with the specified ordinal as float64. 
388 | func (rs *ResultSet) Float64(ord int) (value float64, isNull bool, err error) { 389 | err = rs.conn.withRecover("*ResultSet.Float64", func() { 390 | value, isNull = rs.float64(ord) 391 | }) 392 | 393 | return 394 | } 395 | 396 | func (rs *ResultSet) int16(ord int) (value int16, isNull bool) { 397 | if rs.conn.LogLevel >= LogVerbose { 398 | defer rs.conn.logExit(rs.conn.logEnter("*ResultSet.int16")) 399 | } 400 | 401 | isNull = rs.isNull(ord) 402 | if isNull { 403 | return 404 | } 405 | 406 | val := rs.values[ord] 407 | 408 | switch rs.fields[ord].format { 409 | case textFormat: 410 | x, err := strconv.Atoi(string(val)) 411 | panicIfErr(err) 412 | value = int16(x) 413 | 414 | case binaryFormat: 415 | value = int16(binary.BigEndian.Uint16(val)) 416 | } 417 | 418 | return 419 | } 420 | 421 | // Int16 returns the value of the field with the specified ordinal as int16. 422 | func (rs *ResultSet) Int16(ord int) (value int16, isNull bool, err error) { 423 | err = rs.conn.withRecover("*ResultSet.Int16", func() { 424 | value, isNull = rs.int16(ord) 425 | }) 426 | 427 | return 428 | } 429 | 430 | func (rs *ResultSet) int32(ord int) (value int32, isNull bool) { 431 | if rs.conn.LogLevel >= LogVerbose { 432 | defer rs.conn.logExit(rs.conn.logEnter("*ResultSet.int32")) 433 | } 434 | 435 | isNull = rs.isNull(ord) 436 | if isNull { 437 | return 438 | } 439 | 440 | val := rs.values[ord] 441 | 442 | switch rs.fields[ord].format { 443 | case textFormat: 444 | x, err := strconv.Atoi(string(val)) 445 | panicIfErr(err) 446 | value = int32(x) 447 | 448 | case binaryFormat: 449 | value = int32(binary.BigEndian.Uint32(val)) 450 | } 451 | 452 | return 453 | } 454 | 455 | // Int32 returns the value of the field with the specified ordinal as int32. 
456 | func (rs *ResultSet) Int32(ord int) (value int32, isNull bool, err error) { 457 | err = rs.conn.withRecover("*ResultSet.Int32", func() { 458 | value, isNull = rs.int32(ord) 459 | }) 460 | 461 | return 462 | } 463 | 464 | func (rs *ResultSet) int64(ord int) (value int64, isNull bool) { 465 | if rs.conn.LogLevel >= LogVerbose { 466 | defer rs.conn.logExit(rs.conn.logEnter("*ResultSet.int64")) 467 | } 468 | 469 | isNull = rs.isNull(ord) 470 | if isNull { 471 | return 472 | } 473 | 474 | val := rs.values[ord] 475 | 476 | switch rs.fields[ord].format { 477 | case textFormat: 478 | x, err := strconv.ParseInt(string(val), 10, 64) 479 | panicIfErr(err) 480 | value = int64(x) 481 | 482 | case binaryFormat: 483 | value = int64(binary.BigEndian.Uint64(val)) 484 | } 485 | 486 | return 487 | } 488 | 489 | // Int64 returns the value of the field with the specified ordinal as int64. 490 | func (rs *ResultSet) Int64(ord int) (value int64, isNull bool, err error) { 491 | err = rs.conn.withRecover("*ResultSet.Int64", func() { 492 | value, isNull = rs.int64(ord) 493 | }) 494 | 495 | return 496 | } 497 | 498 | func (rs *ResultSet) int(ord int) (value int, isNull bool) { 499 | var val int32 500 | val, isNull = rs.int32(ord) 501 | value = int(val) 502 | 503 | return 504 | } 505 | 506 | // Int returns the value of the field with the specified ordinal as int. 
507 | func (rs *ResultSet) Int(ord int) (value int, isNull bool, err error) { 508 | err = rs.conn.withRecover("*ResultSet.Int", func() { 509 | value, isNull = rs.int(ord) 510 | }) 511 | 512 | return 513 | } 514 | 515 | func (rs *ResultSet) rat(ord int) (value *big.Rat, isNull bool) { 516 | if rs.conn.LogLevel >= LogVerbose { 517 | defer rs.conn.logExit(rs.conn.logEnter("*ResultSet.rat")) 518 | } 519 | 520 | isNull = rs.isNull(ord) 521 | if isNull { 522 | return 523 | } 524 | 525 | val := rs.values[ord] 526 | 527 | switch rs.fields[ord].format { 528 | case textFormat: 529 | x := big.NewRat(1, 1) 530 | if _, ok := x.SetString(string(val)); !ok { 531 | panic("*big.Rat.SetString failed") 532 | } 533 | value = x 534 | 535 | case binaryFormat: 536 | panicNotImplemented() 537 | } 538 | 539 | return 540 | } 541 | 542 | // Rat returns the value of the field with the specified ordinal as *big.Rat. 543 | func (rs *ResultSet) Rat(ord int) (value *big.Rat, isNull bool, err error) { 544 | err = rs.conn.withRecover("*ResultSet.Rat", func() { 545 | value, isNull = rs.rat(ord) 546 | }) 547 | 548 | return 549 | } 550 | 551 | func (rs *ResultSet) string(ord int) (value string, isNull bool) { 552 | if rs.conn.LogLevel >= LogVerbose { 553 | defer rs.conn.logExit(rs.conn.logEnter("*ResultSet.string")) 554 | } 555 | 556 | isNull = rs.isNull(ord) 557 | if isNull { 558 | return 559 | } 560 | 561 | value = string(rs.values[ord]) 562 | 563 | return 564 | } 565 | 566 | // String returns the value of the field with the specified ordinal as string. 
567 | func (rs *ResultSet) String(ord int) (value string, isNull bool, err error) { 568 | err = rs.conn.withRecover("*ResultSet.String", func() { 569 | value, isNull = rs.string(ord) 570 | }) 571 | 572 | return 573 | } 574 | 575 | func (rs *ResultSet) time(ord int) (value time.Time, isNull bool) { 576 | if rs.conn.LogLevel >= LogVerbose { 577 | defer rs.conn.logExit(rs.conn.logEnter("*ResultSet.Time")) 578 | } 579 | 580 | // We need to convert the parsed *time.Time to seconds and back, 581 | // because otherwise the Weekday field will always equal 0 (Sunday). 582 | // See http://code.google.com/p/go/issues/detail?id=1025 583 | seconds, isNull := rs.timeSeconds(ord) 584 | if isNull { 585 | return 586 | } 587 | 588 | value = time.Unix(seconds, 0).UTC() 589 | 590 | return 591 | } 592 | 593 | // Time returns the value of the field with the specified ordinal as *time.Time. 594 | func (rs *ResultSet) Time(ord int) (value time.Time, isNull bool, err error) { 595 | err = rs.conn.withRecover("*ResultSet.Time", func() { 596 | value, isNull = rs.time(ord) 597 | }) 598 | 599 | return 600 | } 601 | 602 | func (rs *ResultSet) timeSeconds(ord int) (value int64, isNull bool) { 603 | if rs.conn.LogLevel >= LogVerbose { 604 | defer rs.conn.logExit(rs.conn.logEnter("*ResultSet.timeSeconds")) 605 | } 606 | 607 | isNull = rs.isNull(ord) 608 | if isNull { 609 | return 610 | } 611 | 612 | val := rs.values[ord] 613 | 614 | var t time.Time 615 | 616 | switch rs.fields[ord].format { 617 | case textFormat: 618 | var format string 619 | switch rs.fields[ord].typeOID { 620 | case _DATEOID: 621 | format = rs.conn.dateFormat 622 | 623 | case _TIMEOID, _TIMETZOID: 624 | format = rs.conn.timeFormat 625 | 626 | case _TIMESTAMPOID, _TIMESTAMPTZOID: 627 | format = rs.conn.timestampFormat 628 | } 629 | 630 | switch rs.fields[ord].typeOID { 631 | case _TIMETZOID: 632 | format += "-07" 633 | 634 | case _TIMESTAMPTZOID: 635 | format += rs.conn.timestampTimezoneFormat 636 | } 637 | 638 | s := string(val) 
639 | 640 | if rs.fields[ord].typeOID != _DATEOID { 641 | // The resolution of time.Time is seconds, so we will have to drop 642 | // fractions, if present. 643 | lastSemicolon := strings.LastIndex(s, ":") 644 | lastDot := strings.LastIndex(s, ".") 645 | if lastSemicolon < lastDot { 646 | // There are fractions 647 | plusOrMinus := strings.IndexAny(s[lastDot:], "+-") 648 | if -1 < plusOrMinus { 649 | // There is a time zone 650 | s = s[:lastDot] + s[lastDot+plusOrMinus:] 651 | } else { 652 | s = s[:lastDot] 653 | } 654 | } 655 | } 656 | 657 | var err error 658 | t, err = time.Parse(format, s) 659 | panicIfErr(err) 660 | 661 | case binaryFormat: 662 | panicNotImplemented() 663 | } 664 | 665 | value = t.Unix() 666 | 667 | return 668 | } 669 | 670 | // TimeSeconds returns the value of the field with the specified ordinal as int64. 671 | func (rs *ResultSet) TimeSeconds(ord int) (value int64, isNull bool, err error) { 672 | err = rs.conn.withRecover("*ResultSet.TimeSeconds", func() { 673 | value, isNull = rs.timeSeconds(ord) 674 | }) 675 | 676 | return 677 | } 678 | 679 | func (rs *ResultSet) uint(ord int) (value uint, isNull bool) { 680 | var val uint32 681 | val, isNull = rs.uint32(ord) 682 | value = uint(val) 683 | 684 | return 685 | } 686 | 687 | // Uint returns the value of the field with the specified ordinal as uint. 688 | func (rs *ResultSet) Uint(ord int) (value uint, isNull bool, err error) { 689 | err = rs.conn.withRecover("*ResultSet.Uint", func() { 690 | value, isNull = rs.uint(ord) 691 | }) 692 | 693 | return 694 | } 695 | 696 | func (rs *ResultSet) uint16(ord int) (value uint16, isNull bool) { 697 | var val int16 698 | val, isNull = rs.int16(ord) 699 | value = uint16(val) 700 | 701 | return 702 | } 703 | 704 | // Uint16 returns the value of the field with the specified ordinal as uint16. 
705 | func (rs *ResultSet) Uint16(ord int) (value uint16, isNull bool, err error) { 706 | err = rs.conn.withRecover("*ResultSet.Uint16", func() { 707 | value, isNull = rs.uint16(ord) 708 | }) 709 | 710 | return 711 | } 712 | 713 | func (rs *ResultSet) uint32(ord int) (value uint32, isNull bool) { 714 | var val int32 715 | val, isNull = rs.int32(ord) 716 | value = uint32(val) 717 | 718 | return 719 | } 720 | 721 | // Uint32 returns the value of the field with the specified ordinal as uint32. 722 | func (rs *ResultSet) Uint32(ord int) (value uint32, isNull bool, err error) { 723 | err = rs.conn.withRecover("*ResultSet.Uint32", func() { 724 | value, isNull = rs.uint32(ord) 725 | }) 726 | 727 | return 728 | } 729 | 730 | func (rs *ResultSet) uint64(ord int) (value uint64, isNull bool) { 731 | var val int64 732 | val, isNull = rs.int64(ord) 733 | value = uint64(val) 734 | 735 | return 736 | } 737 | 738 | // Uint64 returns the value of the field with the specified ordinal as uint64. 739 | func (rs *ResultSet) Uint64(ord int) (value uint64, isNull bool, err error) { 740 | err = rs.conn.withRecover("*ResultSet.Uint64", func() { 741 | value, isNull = rs.uint64(ord) 742 | }) 743 | 744 | return 745 | } 746 | 747 | func (rs *ResultSet) any(ord int) (value interface{}, isNull bool) { 748 | if rs.values[ord] == nil { 749 | isNull = true 750 | return 751 | } 752 | 753 | switch rs.fields[ord].typeOID { 754 | case _BOOLOID: 755 | value, isNull = rs.bool(ord) 756 | 757 | case _BPCHAROID, _CHAROID, _VARCHAROID, _TEXTOID: 758 | value, isNull = rs.string(ord) 759 | 760 | case _DATEOID, _TIMEOID, _TIMETZOID, _TIMESTAMPOID, _TIMESTAMPTZOID: 761 | value, isNull = rs.time(ord) 762 | 763 | case _FLOAT4OID: 764 | value, isNull = rs.float32(ord) 765 | 766 | case _FLOAT8OID: 767 | value, isNull = rs.float64(ord) 768 | 769 | case _INT2OID: 770 | value, isNull = rs.int16(ord) 771 | 772 | case _INT4OID: 773 | value, isNull = rs.int(ord) 774 | 775 | case _INT8OID: 776 | value, isNull = 
rs.int64(ord) 777 | 778 | case _NUMERICOID: 779 | value, isNull = rs.rat(ord) 780 | 781 | default: 782 | panic(fmt.Sprintf("unexpected field type: field: '%s' OID: %d", rs.fields[ord].name, rs.fields[ord].typeOID)) 783 | } 784 | 785 | return 786 | } 787 | 788 | // Any returns the value of the field with the specified ordinal as interface{}. 789 | // 790 | // Types are mapped as follows: 791 | // 792 | // PostgreSQL Go 793 | // 794 | // Bigint int64 795 | // Boolean bool 796 | // Char string 797 | // Date int64 798 | // Double float64 799 | // Integer int 800 | // Numeric *big.Rat 801 | // Real float 802 | // Smallint int16 803 | // Text string 804 | // Time time.Time 805 | // TimeTZ time.Time 806 | // Timestamp time.Time 807 | // TimestampTZ time.Time 808 | // Varchar string 809 | func (rs *ResultSet) Any(ord int) (value interface{}, isNull bool, err error) { 810 | err = rs.conn.withRecover("*ResultSet.Any", func() { 811 | value, isNull = rs.any(ord) 812 | }) 813 | 814 | return 815 | } 816 | 817 | func (rs *ResultSet) scan(args ...interface{}) { 818 | if rs.conn.LogLevel >= LogVerbose { 819 | defer rs.conn.logExit(rs.conn.logEnter("*ResultSet.Scan")) 820 | } 821 | 822 | if len(args) != len(rs.fields) { 823 | panic("wrong argument count") 824 | } 825 | 826 | for i, arg := range args { 827 | switch a := arg.(type) { 828 | case *bool: 829 | *a, _ = rs.bool(i) 830 | 831 | case *float32: 832 | *a, _ = rs.float32(i) 833 | 834 | case *float64: 835 | *a, _ = rs.float64(i) 836 | 837 | case *int: 838 | *a, _ = rs.int(i) 839 | 840 | case *int16: 841 | *a, _ = rs.int16(i) 842 | 843 | case *int32: 844 | *a, _ = rs.int32(i) 845 | 846 | case *int64: 847 | switch rs.fields[i].typeOID { 848 | case _DATEOID, _TIMEOID, _TIMETZOID, _TIMESTAMPOID, _TIMESTAMPTZOID: 849 | *a, _ = rs.timeSeconds(i) 850 | 851 | default: 852 | *a, _ = rs.int64(i) 853 | } 854 | 855 | case *interface{}: 856 | *a, _ = rs.any(i) 857 | 858 | case **big.Rat: 859 | var r *big.Rat 860 | r, _ = rs.rat(i) 861 | *a = 
r 862 | 863 | case *string: 864 | *a, _ = rs.string(i) 865 | 866 | case *time.Time: 867 | var t time.Time 868 | t, _ = rs.time(i) 869 | *a = t 870 | 871 | case *uint: 872 | *a, _ = rs.uint(i) 873 | 874 | case *uint16: 875 | *a, _ = rs.uint16(i) 876 | 877 | case *uint32: 878 | *a, _ = rs.uint32(i) 879 | 880 | case *uint64: 881 | switch rs.fields[i].typeOID { 882 | case _DATEOID, _TIMEOID, _TIMETZOID, _TIMESTAMPOID, _TIMESTAMPTZOID: 883 | var seconds int64 884 | seconds, _ = rs.timeSeconds(i) 885 | *a = uint64(seconds) 886 | 887 | default: 888 | *a, _ = rs.uint64(i) 889 | } 890 | } 891 | } 892 | 893 | return 894 | } 895 | 896 | // Scan scans the fields of the current row in the ResultSet, trying 897 | // to store field values into the specified arguments. 898 | // 899 | // The arguments must be of pointer types. 900 | func (rs *ResultSet) Scan(args ...interface{}) (err error) { 901 | err = rs.conn.withRecover("*ResultSet.Scan", func() { 902 | rs.scan(args...) 903 | }) 904 | 905 | return 906 | } 907 | 908 | func (rs *ResultSet) scanNext(args ...interface{}) (fetched bool) { 909 | fetched = rs.fetchNext() 910 | if !fetched { 911 | return 912 | } 913 | 914 | rs.scan(args...) 915 | 916 | return 917 | } 918 | 919 | // ScanNext scans the fields of the next row in the ResultSet, trying 920 | // to store field values into the specified arguments. 921 | // 922 | // The arguments must be of pointer types. If a row has been fetched, fetched 923 | // will be true, otherwise false. 924 | func (rs *ResultSet) ScanNext(args ...interface{}) (fetched bool, err error) { 925 | err = rs.conn.withRecover("*ResultSet.ScanNext", func() { 926 | fetched = rs.scanNext(args...) 927 | }) 928 | 929 | rs.setCompletedOnPgsqlError(err) 930 | 931 | return 932 | } 933 | -------------------------------------------------------------------------------- /pgsql_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2010 The go-pgsql Authors. 
All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 4 | 5 | package pgsql 6 | 7 | import ( 8 | "bytes" 9 | "errors" 10 | "fmt" 11 | "math" 12 | "math/big" 13 | "strings" 14 | "testing" 15 | "time" 16 | ) 17 | 18 | func withConnLog(t *testing.T, logLevel LogLevel, f func(conn *Conn)) { 19 | conn, err := Connect("dbname=testdatabase user=testuser password=testpassword", logLevel) 20 | if err != nil { 21 | t.Error("withConn: Connect:", err) 22 | return 23 | } 24 | if conn == nil { 25 | t.Error("withConn: Connect: conn == nil") 26 | return 27 | } 28 | defer conn.Close() 29 | 30 | f(conn) 31 | } 32 | 33 | func withConn(t *testing.T, f func(conn *Conn)) { 34 | withConnLog(t, LogNothing, f) 35 | } 36 | 37 | func withSimpleQueryResultSet(t *testing.T, command string, f func(rs *ResultSet)) { 38 | withConn(t, func(conn *Conn) { 39 | rs, err := conn.Query(command) 40 | if err != nil { 41 | t.Error("withSimpleQueryResultSet: conn.Query:", err) 42 | return 43 | } 44 | if rs == nil { 45 | t.Error("withSimpleQueryResultSet: conn.Query: rs == nil") 46 | return 47 | } 48 | defer rs.Close() 49 | 50 | f(rs) 51 | }) 52 | } 53 | 54 | func withStatement(t *testing.T, command string, params []*Parameter, f func(stmt *Statement)) { 55 | withConn(t, func(conn *Conn) { 56 | stmt, err := conn.Prepare(command, params...) 
57 | if err != nil { 58 | t.Error("withStatement: conn.Prepare:", err) 59 | return 60 | } 61 | if stmt == nil { 62 | t.Error("withStatement: conn.Prepare: stmt == nil") 63 | return 64 | } 65 | defer stmt.Close() 66 | 67 | f(stmt) 68 | }) 69 | } 70 | 71 | func withStatementResultSet(t *testing.T, command string, params []*Parameter, f func(rs *ResultSet)) { 72 | withStatement(t, command, params, func(stmt *Statement) { 73 | rs, err := stmt.Query() 74 | if err != nil { 75 | t.Error("withStatementResultSet: stmt.Query:", err) 76 | return 77 | } 78 | if rs == nil { 79 | t.Error("withStatementResultSet: stmt.Query: rs == nil") 80 | return 81 | } 82 | defer rs.Close() 83 | 84 | f(rs) 85 | }) 86 | } 87 | 88 | func param(name string, typ Type, value interface{}) *Parameter { 89 | p := NewParameter(name, typ) 90 | err := p.SetValue(value) 91 | if err != nil { 92 | panic(err) 93 | } 94 | 95 | return p 96 | } 97 | 98 | func Test_Connect_UglyButValidParamsStyle_ExpectErrNil(t *testing.T) { 99 | conn, err := Connect( 100 | `dbname=testdatabase 101 | 102 | user ='testuser' password = 'testpassword' `, 103 | LogNothing) 104 | if err != nil { 105 | t.Fail() 106 | } 107 | if conn != nil { 108 | conn.Close() 109 | } 110 | } 111 | 112 | func Test_Connect_InvalidPassword_ExpectConnNil(t *testing.T) { 113 | conn, _ := Connect("dbname=testdatabase user=testuser password=wrongpassword", LogNothing) 114 | if conn != nil { 115 | t.Fail() 116 | conn.Close() 117 | } 118 | } 119 | 120 | func Test_Connect_InvalidPassword_ExpectErrorClass28(t *testing.T) { 121 | conn, err := Connect("dbname=testdatabase user=testuser password=wrongpassword", LogNothing) 122 | if err == nil { 123 | t.Error("expected err != nil") 124 | } 125 | // Class 28 == invalid authorization specification 126 | if pgerr, ok := err.(*Error); !ok || !strings.HasPrefix(pgerr.Code(), "28") { 127 | t.Error("expected *pgsql.Error of class 28") 128 | } 129 | if conn != nil { 130 | conn.Close() 131 | } 132 | } 133 | 134 | func 
Test_DoSimpleQueryResultSetTests(t *testing.T) { 135 | tests := []func(rs *ResultSet) (have, want interface{}, name string){ 136 | // Basic rs tests 137 | func(rs *ResultSet) (have, want interface{}, name string) { 138 | hasRow, _ := rs.FetchNext() 139 | return hasRow, true, "FetchNext" 140 | }, 141 | func(rs *ResultSet) (have, want interface{}, name string) { 142 | hasRow, _ := rs.FetchNext() 143 | hasRow, _ = rs.FetchNext() 144 | return hasRow, false, "FetchNext_RetValSecondCall" 145 | }, 146 | func(rs *ResultSet) (have, want interface{}, name string) { 147 | _, err := rs.FetchNext() 148 | return err == nil, true, "FetchNext_ErrNil" 149 | }, 150 | 151 | // Field info tests 152 | func(rs *ResultSet) (have, want interface{}, name string) { 153 | fieldCount := rs.FieldCount() 154 | return fieldCount, 5, "field count" 155 | }, 156 | func(rs *ResultSet) (have, want interface{}, name string) { 157 | fieldName, _ := rs.Name(1) 158 | return fieldName, "_two", "field #1 name" 159 | }, 160 | func(rs *ResultSet) (have, want interface{}, name string) { 161 | typ, _ := rs.Type(2) 162 | return typ, Boolean, "field #2 type" 163 | }, 164 | 165 | // Get value tests 166 | func(rs *ResultSet) (have, want interface{}, name string) { 167 | rs.FetchNext() 168 | val, _, _ := rs.Int32(0) 169 | return val, int32(1), "field #0" 170 | }, 171 | func(rs *ResultSet) (have, want interface{}, name string) { 172 | rs.FetchNext() 173 | val, _, _ := rs.String(1) 174 | return val, "two", "field #1" 175 | }, 176 | func(rs *ResultSet) (have, want interface{}, name string) { 177 | rs.FetchNext() 178 | val, _, _ := rs.Bool(2) 179 | return val, true, "field #2" 180 | }, 181 | func(rs *ResultSet) (have, want interface{}, name string) { 182 | rs.FetchNext() 183 | val, _ := rs.IsNull(3) 184 | return val, true, "field #3 is null" 185 | }, 186 | func(rs *ResultSet) (have, want interface{}, name string) { 187 | rs.FetchNext() 188 | val, _, _ := rs.Float64(4) 189 | return val, float64(4.5), "field #4" 190 | }, 
191 | } 192 | 193 | for _, test := range tests { 194 | withSimpleQueryResultSet(t, "SELECT 1 AS _1, 'two' AS _two, true AS _true, null AS _null, 4.5 AS _4_5;", func(rs *ResultSet) { 195 | if have, want, name := test(rs); have != want { 196 | t.Errorf("%s failed - have: '%v', but want '%v'", name, have, want) 197 | } 198 | }) 199 | } 200 | } 201 | 202 | func Test_SimpleQuery_MultipleSelects(t *testing.T) { 203 | tests := []func(rs *ResultSet) (have, want interface{}, name string){ 204 | // First result 205 | func(rs *ResultSet) (have, want interface{}, name string) { 206 | hasRead, _ := rs.FetchNext() 207 | return hasRead, true, "hasRead on first FetchNext (first result)" 208 | }, 209 | func(rs *ResultSet) (have, want interface{}, name string) { 210 | _, err := rs.FetchNext() 211 | return err, nil, "err on first FetchNext (first result)" 212 | }, 213 | func(rs *ResultSet) (have, want interface{}, name string) { 214 | rs.FetchNext() 215 | hasRead, _ := rs.FetchNext() 216 | return hasRead, false, "hasRead on second FetchNext (first result)" 217 | }, 218 | func(rs *ResultSet) (have, want interface{}, name string) { 219 | rs.FetchNext() 220 | _, err := rs.FetchNext() 221 | return err, nil, "err on second FetchNext (first result)" 222 | }, 223 | func(rs *ResultSet) (have, want interface{}, name string) { 224 | rs.FetchNext() 225 | val, _, _ := rs.Int(0) 226 | return val, 1, "value Int(0) (first result)" 227 | }, 228 | func(rs *ResultSet) (have, want interface{}, name string) { 229 | rs.FetchNext() 230 | _, isNull, _ := rs.Int(0) 231 | return isNull, false, "isNull Int(0) (first result)" 232 | }, 233 | func(rs *ResultSet) (have, want interface{}, name string) { 234 | rs.FetchNext() 235 | _, _, err := rs.Int(0) 236 | return err, nil, "err Int(0) (first result)" 237 | }, 238 | func(rs *ResultSet) (have, want interface{}, name string) { 239 | hasResult, _ := rs.NextResult() 240 | return hasResult, true, "hasResult on NextResult (first result)" 241 | }, 242 | func(rs 
*ResultSet) (have, want interface{}, name string) { 243 | _, err := rs.NextResult() 244 | return err, nil, "err on NextResult (first result)" 245 | }, 246 | // Second result 247 | func(rs *ResultSet) (have, want interface{}, name string) { 248 | rs.NextResult() 249 | hasRead, _ := rs.FetchNext() 250 | return hasRead, true, "hasRead on first FetchNext (second result)" 251 | }, 252 | func(rs *ResultSet) (have, want interface{}, name string) { 253 | rs.NextResult() 254 | _, err := rs.FetchNext() 255 | return err, nil, "err on first FetchNext (second result)" 256 | }, 257 | func(rs *ResultSet) (have, want interface{}, name string) { 258 | rs.NextResult() 259 | rs.FetchNext() 260 | hasRead, _ := rs.FetchNext() 261 | return hasRead, false, "hasRead on second FetchNext (second result)" 262 | }, 263 | func(rs *ResultSet) (have, want interface{}, name string) { 264 | rs.NextResult() 265 | rs.FetchNext() 266 | _, err := rs.FetchNext() 267 | return err, nil, "err on second FetchNext (second result)" 268 | }, 269 | func(rs *ResultSet) (have, want interface{}, name string) { 270 | rs.NextResult() 271 | rs.FetchNext() 272 | val, _, _ := rs.String(0) 273 | return val, "two", "value String(0) (second result)" 274 | }, 275 | func(rs *ResultSet) (have, want interface{}, name string) { 276 | rs.NextResult() 277 | rs.FetchNext() 278 | _, isNull, _ := rs.String(0) 279 | return isNull, false, "isNull String(0) (second result)" 280 | }, 281 | func(rs *ResultSet) (have, want interface{}, name string) { 282 | rs.NextResult() 283 | rs.FetchNext() 284 | _, _, err := rs.String(0) 285 | return err, nil, "err String(0) (second result)" 286 | }, 287 | func(rs *ResultSet) (have, want interface{}, name string) { 288 | rs.NextResult() 289 | hasResult, _ := rs.NextResult() 290 | return hasResult, false, "hasResult on NextResult (second result)" 291 | }, 292 | func(rs *ResultSet) (have, want interface{}, name string) { 293 | rs.NextResult() 294 | _, err := rs.NextResult() 295 | return err, nil, "err 
on NextResult (second result)" 296 | }, 297 | } 298 | 299 | for _, test := range tests { 300 | withSimpleQueryResultSet(t, "SELECT 1 AS _1; SELECT 'two' AS _two;", func(rs *ResultSet) { 301 | if have, want, name := test(rs); have != want { 302 | t.Errorf("%s failed - have: '%v', but want '%v'", name, have, want) 303 | } 304 | }) 305 | } 306 | } 307 | 308 | func idParameter(value int) *Parameter { 309 | idParam := NewParameter("@id", Integer) 310 | idParam.SetValue(value) 311 | 312 | return idParam 313 | } 314 | 315 | func Test_Statement_ActualCommand(t *testing.T) { 316 | withStatement(t, "SELECT id FROM table1 WHERE strreq = '@id' OR id = @id;", []*Parameter{idParameter(3)}, func(stmt *Statement) { 317 | if stmt.ActualCommand() != "SELECT id FROM table1 WHERE strreq = '@id' OR id = $1;" { 318 | t.Fail() 319 | } 320 | }) 321 | } 322 | 323 | type statementResultSetTest struct { 324 | command string 325 | params []*Parameter 326 | fun func(rs *ResultSet) (have, want interface{}, name string) 327 | } 328 | 329 | func whereIdEquals2StatementResultSetTest(fun func(rs *ResultSet) (have, want interface{}, name string)) *statementResultSetTest { 330 | return &statementResultSetTest{ 331 | command: "SELECT id FROM table1 WHERE id = @id;", 332 | params: []*Parameter{idParameter(2)}, 333 | fun: fun, 334 | } 335 | } 336 | 337 | func Test_DoStatementResultSetTests(t *testing.T) { 338 | tests := []*statementResultSetTest{ 339 | whereIdEquals2StatementResultSetTest(func(rs *ResultSet) (have, want interface{}, name string) { 340 | hasRead, _ := rs.FetchNext() 341 | return hasRead, true, "WHERE id = 2 - 'hasRead, _ := rs.FetchNext()'" 342 | }), 343 | whereIdEquals2StatementResultSetTest(func(rs *ResultSet) (have, want interface{}, name string) { 344 | _, err := rs.FetchNext() 345 | return err, nil, "WHERE id = 2 - '_, err := rs.FetchNext()'" 346 | }), 347 | whereIdEquals2StatementResultSetTest(func(rs *ResultSet) (have, want interface{}, name string) { 348 | rs.FetchNext() 349 | 
val, _, _ := rs.Int32(0) 350 | return val, int32(2), "WHERE id = 2 - 'val, _, _ := rs.Int32(0)'" 351 | }), 352 | whereIdEquals2StatementResultSetTest(func(rs *ResultSet) (have, want interface{}, name string) { 353 | rs.FetchNext() 354 | _, isNull, _ := rs.Int32(0) 355 | return isNull, false, "WHERE id = 2 - '_, isNull, _ := rs.Int32(0)'" 356 | }), 357 | whereIdEquals2StatementResultSetTest(func(rs *ResultSet) (have, want interface{}, name string) { 358 | rs.FetchNext() 359 | _, _, err := rs.Int32(0) 360 | return err, nil, "WHERE id = 2 - '_, _, err := rs.Int32(0)'" 361 | }), 362 | } 363 | 364 | for _, test := range tests { 365 | withStatementResultSet(t, test.command, test.params, func(rs *ResultSet) { 366 | if have, want, name := test.fun(rs); have != want { 367 | t.Errorf("%s failed - have: '%v', but want '%v'", name, have, want) 368 | } 369 | }) 370 | } 371 | } 372 | 373 | type item struct { 374 | id int 375 | name string 376 | price float64 377 | packUnit uint 378 | onSale bool 379 | something interface{} 380 | } 381 | 382 | func Test_Conn_Scan(t *testing.T) { 383 | withConn(t, func(conn *Conn) { 384 | var x item 385 | command := "SELECT 123, 'abc', 14.99, 4, true, '2010-08-20'::DATE;" 386 | fetched, err := conn.Scan(command, &x.id, &x.name, &x.price, &x.packUnit, &x.onSale, &x.something) 387 | if err != nil { 388 | t.Error(err) 389 | return 390 | } 391 | if !fetched { 392 | t.Error("fetched == false") 393 | } 394 | if x.id != 123 { 395 | t.Errorf("id - have: %d, but want: 123", x.id) 396 | } 397 | if x.name != "abc" { 398 | t.Errorf("name - have: '%s', but want: 'abc'", x.name) 399 | } 400 | if math.Abs(float64(x.price)-14.99) > 0.000001 { 401 | t.Errorf("price - have: %f, but want: 14.99", x.price) 402 | } 403 | if x.packUnit != 4 { 404 | t.Errorf("packUnit - have: %d, but want: 4", x.packUnit) 405 | } 406 | if !x.onSale { 407 | t.Error("onSale - have: true, but want: false") 408 | } 409 | tm, ok := x.something.(time.Time) 410 | if !ok { 411 | 
t.Error("something should have type time.Time") 412 | } else { 413 | dateStr := tm.Format(dateFormat) 414 | if dateStr != "2010-08-20" { 415 | t.Errorf("something - have: '%s', but want: '2010-08-20'", dateStr) 416 | } 417 | } 418 | }) 419 | } 420 | 421 | type dateStyleTest struct { 422 | typ, format, want string 423 | } 424 | 425 | func Test_DateStyle(t *testing.T) { 426 | dateStyles := []string{ 427 | "ISO", "ISO, DMY", "ISO, MDY", "ISO, YMD", 428 | "SQL", "SQL, DMY", "SQL, MDY", 429 | "Postgres", "Postgres, DMY", "Postgres, MDY", 430 | "German", "German, DMY", "German, MDY", 431 | } 432 | 433 | tests := []*dateStyleTest{ 434 | &dateStyleTest{ 435 | typ: "DATE", 436 | format: dateFormat, 437 | want: "2010-08-16", 438 | }, 439 | &dateStyleTest{ 440 | typ: "TIME", 441 | format: timeFormat, 442 | want: "01:23:45", 443 | }, 444 | &dateStyleTest{ 445 | typ: "TIME WITH TIME ZONE", 446 | format: timeFormat, 447 | want: "01:23:45", 448 | }, 449 | &dateStyleTest{ 450 | typ: "TIMESTAMP", 451 | format: timestampFormat, 452 | want: "2010-08-16 01:23:45", 453 | }, 454 | &dateStyleTest{ 455 | typ: "TIMESTAMP WITH TIME ZONE", 456 | format: timestampFormat, 457 | want: "2010-08-16 01:23:45", 458 | }, 459 | } 460 | 461 | for _, style := range dateStyles { 462 | withConn(t, func(conn *Conn) { 463 | _, err := conn.Execute("SET TimeZone = UTC;") 464 | if err != nil { 465 | t.Errorf("failed to set time zone = UTC: %s", err) 466 | return 467 | } 468 | 469 | _, err = conn.Execute(fmt.Sprintf("SET DateStyle = %s;", style)) 470 | if err != nil { 471 | t.Errorf("failed to set DateStyle = %s: %s", style, err) 472 | return 473 | } 474 | 475 | var ts time.Time 476 | 477 | for _, test := range tests { 478 | _, err = conn.Scan(fmt.Sprintf("SELECT %s '%s';", test.typ, test.want), &ts) 479 | if err != nil { 480 | t.Errorf("failed to scan with DateStyle = %s: %s", style, err) 481 | return 482 | } 483 | 484 | have := ts.Format(test.format) 485 | 486 | if have != test.want { 487 | 
t.Errorf("DateStyle = %s, typ = %s: want: '%s', but have: '%s'", style, test.typ, test.want, have) 488 | } 489 | } 490 | }) 491 | } 492 | } 493 | 494 | type timeTest struct { 495 | command, timeString string 496 | seconds int64 497 | } 498 | 499 | func newTimeTest(commandTemplate, format, value string) *timeTest { 500 | test := &timeTest{} 501 | 502 | t, err := time.Parse(format, value) 503 | if err != nil { 504 | panic(err) 505 | } 506 | t = time.Unix(t.Unix(), 0).UTC() 507 | 508 | if strings.Index(commandTemplate, "%s") > -1 { 509 | test.command = fmt.Sprintf(commandTemplate, value) 510 | } else { 511 | test.command = commandTemplate 512 | } 513 | test.seconds = t.Unix() 514 | test.timeString = t.String() 515 | 516 | return test 517 | } 518 | 519 | const ( 520 | dateFormat = "2006-01-02" 521 | timeFormat = "15:04:05" 522 | timestampFormat = "2006-01-02 15:04:05" 523 | ) 524 | 525 | func Test_Conn_Scan_Time(t *testing.T) { 526 | tests := []*timeTest{ 527 | newTimeTest( 528 | "SELECT DATE '%s';", 529 | dateFormat, 530 | "2010-08-14"), 531 | newTimeTest( 532 | "SELECT TIME '%s';", 533 | timeFormat, 534 | "18:43:32"), 535 | newTimeTest( 536 | "SELECT TIME WITH TIME ZONE '%s';", 537 | timeFormat+"-07", 538 | "18:43:32+02"), 539 | newTimeTest( 540 | "SELECT TIMESTAMP '%s';", 541 | timestampFormat, 542 | "2010-08-14 18:43:32"), 543 | newTimeTest( 544 | "SELECT TIMESTAMP WITH TIME ZONE '%s';", 545 | timestampFormat+"-07", 546 | "2010-08-14 18:43:32+02"), 547 | } 548 | 549 | for _, test := range tests { 550 | withConn(t, func(conn *Conn) { 551 | _, err := conn.Execute("SET TimeZone = 02; SET DateStyle = ISO") 552 | if err != nil { 553 | t.Error("failed to set time zone or date style:", err) 554 | return 555 | } 556 | 557 | var seconds int64 558 | _, err = conn.Scan(test.command, &seconds) 559 | if err != nil { 560 | t.Error(err) 561 | return 562 | } 563 | if seconds != test.seconds { 564 | t.Errorf("'%s' failed - have: '%d', but want '%d'", test.command, seconds, 
test.seconds) 565 | } 566 | 567 | var tm time.Time 568 | _, err = conn.Scan(test.command, &tm) 569 | if err != nil { 570 | t.Error(err) 571 | return 572 | } 573 | timeString := tm.String() 574 | if timeString != test.timeString { 575 | t.Errorf("'%s' failed - have: '%s', but want '%s'", test.command, timeString, test.timeString) 576 | } 577 | }) 578 | } 579 | } 580 | 581 | func Test_Insert_Time(t *testing.T) { 582 | tests := []*timeTest{ 583 | newTimeTest( 584 | "SELECT _d FROM _gopgsql_test_time;", 585 | dateFormat, 586 | "2010-08-14"), 587 | newTimeTest( 588 | "SELECT _t FROM _gopgsql_test_time;", 589 | timeFormat, 590 | "20:03:38"), 591 | newTimeTest( 592 | "SELECT _ttz FROM _gopgsql_test_time;", 593 | timeFormat+"-07", 594 | "20:03:38+02"), 595 | newTimeTest( 596 | "SELECT _ts FROM _gopgsql_test_time;", 597 | timestampFormat, 598 | "2010-08-14 20:03:38"), 599 | newTimeTest( 600 | "SELECT _tstz FROM _gopgsql_test_time;", 601 | timestampFormat+"-07", 602 | "2010-08-14 20:03:38+02"), 603 | } 604 | 605 | for _, test := range tests { 606 | withConn(t, func(conn *Conn) { 607 | conn.Execute("DROP TABLE _gopgsql_test_time;") 608 | 609 | _, err := conn.Execute( 610 | `CREATE TABLE _gopgsql_test_time 611 | ( 612 | _d DATE, 613 | _t TIME, 614 | _ttz TIME WITH TIME ZONE, 615 | _ts TIMESTAMP, 616 | _tstz TIMESTAMP WITH TIME ZONE 617 | );`) 618 | if err != nil { 619 | t.Error("failed to create table:", err) 620 | return 621 | } 622 | defer func() { 623 | conn.Execute("DROP TABLE _gopgsql_test_time;") 624 | }() 625 | 626 | _, err = conn.Execute("SET TimeZone = 02; SET DateStyle = ISO") 627 | if err != nil { 628 | t.Error("failed to set time zone or date style:", err) 629 | return 630 | } 631 | 632 | _d, _ := time.Parse(dateFormat, "2010-08-14") 633 | _t, _ := time.Parse(timeFormat, "20:03:38") 634 | _ttz, _ := time.Parse(timeFormat, "20:03:38") 635 | _ts, _ := time.Parse(timestampFormat, "2010-08-14 20:03:38") 636 | _tstz, _ := time.Parse(timestampFormat, "2010-08-14 
20:03:38") 637 | 638 | stmt, err := conn.Prepare( 639 | `INSERT INTO _gopgsql_test_time 640 | (_d, _t, _ttz, _ts, _tstz) 641 | VALUES 642 | (@d, @t, @ttz, @ts, @tstz);`, 643 | param("@d", Date, _d), 644 | param("@t", Time, _t.Unix()), 645 | param("@ttz", TimeTZ, _ttz), 646 | param("@ts", Timestamp, _ts), 647 | param("@tstz", TimestampTZ, uint64(_tstz.Unix()))) 648 | if err != nil { 649 | t.Error("failed to prepare insert statement:", err) 650 | return 651 | } 652 | defer stmt.Close() 653 | 654 | _, err = stmt.Execute() 655 | if err != nil { 656 | t.Error("failed to execute insert statement:", err) 657 | } 658 | 659 | var seconds uint64 660 | _, err = conn.Scan(test.command, &seconds) 661 | if err != nil { 662 | t.Error(err) 663 | return 664 | } 665 | if seconds != uint64(test.seconds) { 666 | t.Errorf("'%s' failed - have: '%d', but want '%d'", test.command, seconds, test.seconds) 667 | } 668 | 669 | var tm time.Time 670 | _, err = conn.Scan(test.command, &tm) 671 | if err != nil { 672 | t.Error(err) 673 | return 674 | } 675 | timeString := tm.String() 676 | if timeString != test.timeString { 677 | t.Errorf("'%s' failed - have: '%s', but want '%s'", test.command, timeString, test.timeString) 678 | } 679 | }) 680 | } 681 | } 682 | 683 | func Test_Conn_WithSavepoint(t *testing.T) { 684 | withConn(t, func(conn *Conn) { 685 | conn.Execute("DROP TABLE _gopgsql_test_account;") 686 | 687 | _, err := conn.Execute(` 688 | CREATE TABLE _gopgsql_test_account 689 | ( 690 | name VARCHAR(20) PRIMARY KEY, 691 | balance REAL NOT NULL 692 | ); 693 | INSERT INTO _gopgsql_test_account (name, balance) VALUES ('Alice', 100.0); 694 | INSERT INTO _gopgsql_test_account (name, balance) VALUES ('Bob', 0.0); 695 | INSERT INTO _gopgsql_test_account (name, balance) VALUES ('Wally', 0.0); 696 | `) 697 | if err != nil { 698 | t.Error("failed to create table:", err) 699 | return 700 | } 701 | defer func() { 702 | conn.Execute("DROP TABLE _gopgsql_test_account;") 703 | }() 704 | 705 | err = 
conn.WithTransaction(ReadCommittedIsolation, func() (err error) { 706 | _, err = conn.Execute(` 707 | UPDATE _gopgsql_test_account 708 | SET balance = balance - 100.0 709 | WHERE name = 'Alice';`) 710 | if err != nil { 711 | t.Error("failed to execute update:", err) 712 | return 713 | } 714 | 715 | err = conn.WithSavepoint(ReadCommittedIsolation, func() (err error) { 716 | _, err = conn.Execute(` 717 | UPDATE _gopgsql_test_account 718 | SET balance = balance + 100.0 719 | WHERE name = 'Bob';`) 720 | if err != nil { 721 | t.Error("failed to execute update:", err) 722 | return 723 | } 724 | 725 | err = errors.New("wrong credit account") 726 | 727 | return 728 | }) 729 | 730 | _, err = conn.Execute(` 731 | UPDATE _gopgsql_test_account 732 | SET balance = balance + 100.0 733 | WHERE name = 'Wally';`) 734 | if err != nil { 735 | t.Error("failed to execute update:", err) 736 | return 737 | } 738 | 739 | return 740 | }) 741 | 742 | var rs *ResultSet 743 | rs, err = conn.Query("SELECT name, balance FROM _gopgsql_test_account;") 744 | if err != nil { 745 | t.Error("failed to query:", err) 746 | return 747 | } 748 | defer rs.Close() 749 | 750 | have := make(map[string]float64) 751 | want := map[string]float64{ 752 | "Alice": 0, 753 | "Bob": 0, 754 | "Wally": 100, 755 | } 756 | var name string 757 | var balance float64 758 | var fetched bool 759 | 760 | for { 761 | fetched, err = rs.ScanNext(&name, &balance) 762 | if err != nil { 763 | t.Error("failed to scan next:", err) 764 | return 765 | } 766 | if !fetched { 767 | break 768 | } 769 | 770 | have[name] = balance 771 | } 772 | 773 | for name, haveBalance := range have { 774 | wantBalance := want[name] 775 | 776 | if math.Abs(haveBalance-wantBalance) > 0.000001 { 777 | t.Errorf("name: %s have: %f, but want: %f", name, haveBalance, wantBalance) 778 | } 779 | } 780 | }) 781 | } 782 | 783 | func Test_Numeric(t *testing.T) { 784 | strWant := "0." 
+ strings.Repeat("0123456789", 100)[1:] 785 | numWant, _ := big.NewRat(1, 1).SetString(strWant) 786 | numParam := param("@num", Numeric, numWant) 787 | 788 | withStatementResultSet(t, "SELECT @num;", []*Parameter{numParam}, func(rs *ResultSet) { 789 | // Use interface{}, so *resultSet.Any will be tested as well. 790 | var numHaveInterface interface{} 791 | 792 | _, err := rs.ScanNext(&numHaveInterface) 793 | if err != nil { 794 | t.Error("failed to scan next:", err) 795 | } 796 | 797 | numHave, ok := numHaveInterface.(*big.Rat) 798 | if !ok { 799 | t.Errorf("unexpected type: %T", numHaveInterface) 800 | return 801 | } 802 | 803 | strHave := numHave.FloatString(999) 804 | if strHave != strWant { 805 | t.Errorf("have: %s, but want: %s", strHave, strWant) 806 | } 807 | }) 808 | } 809 | 810 | func Test_FloatInf(t *testing.T) { 811 | numParam := param("@num", Real, float32(math.Inf(-1))) 812 | 813 | withStatementResultSet(t, "SELECT @num;", []*Parameter{numParam}, func(rs *ResultSet) { 814 | var numHave float32 815 | 816 | _, err := rs.ScanNext(&numHave) 817 | if err != nil { 818 | t.Error("failed to scan next:", err) 819 | } 820 | 821 | if !math.IsInf(float64(numHave), -1) { 822 | t.Fail() 823 | } 824 | }) 825 | } 826 | 827 | func Test_FloatNaN(t *testing.T) { 828 | numParam := param("@num", Double, math.NaN()) 829 | 830 | withStatementResultSet(t, "SELECT @num;", []*Parameter{numParam}, func(rs *ResultSet) { 831 | var numHave float64 832 | 833 | _, err := rs.ScanNext(&numHave) 834 | if err != nil { 835 | t.Error("failed to scan next:", err) 836 | } 837 | 838 | if !math.IsNaN(numHave) { 839 | t.Fail() 840 | } 841 | }) 842 | } 843 | 844 | func Test_Parameter_SetValue_NilPtr_ValueReturnsNil(t *testing.T) { 845 | initialValue, _ := big.NewRat(1, 1).SetString("123.456") 846 | p := param("@num", Numeric, initialValue) 847 | 848 | p.SetValue(nil) 849 | 850 | if p.Value() != nil { 851 | t.Fail() 852 | } 853 | } 854 | 855 | // This test hung when using no timeout before 
*ResultSet.FetchNext was fixed. 856 | func Test_Query_Exception(t *testing.T) { 857 | withConn(t, func(conn *Conn) { 858 | conn.Execute("CREATE LANGUAGE plpgsql;") 859 | 860 | _, err := conn.Execute(` 861 | CREATE OR REPLACE FUNCTION one_or_fail(num int) RETURNS int AS $$ 862 | BEGIN 863 | IF num != 1 THEN 864 | RAISE EXCEPTION 'FAIL!'; 865 | END IF; 866 | 867 | RETURN 1; 868 | END; 869 | $$ LANGUAGE plpgsql; 870 | `) 871 | if err != nil { 872 | t.Error("create function failed:", err) 873 | return 874 | } 875 | defer func() { 876 | conn.Execute("DROP FUNCTION one_or_fail(int);") 877 | }() 878 | 879 | rs, err := conn.Query("SELECT one_or_fail(2);") 880 | if err != nil { 881 | t.Error("query failed") 882 | return 883 | } 884 | defer rs.Close() 885 | 886 | _, err = rs.FetchNext() 887 | if err == nil { 888 | t.Error("error expected") 889 | return 890 | } 891 | if _, ok := err.(*Error); !ok { 892 | t.Error("*pgsql.Error expected") 893 | return 894 | } 895 | rs.Close() 896 | 897 | rs, err = conn.Query("SELECT one_or_fail(2);") 898 | if err != nil { 899 | t.Error("query failed") 900 | return 901 | } 902 | defer rs.Close() 903 | 904 | var one int 905 | _, err = rs.ScanNext(&one) 906 | if err == nil { 907 | t.Error("error expected") 908 | return 909 | } 910 | if _, ok := err.(*Error); !ok { 911 | t.Error("*pgsql.Error expected") 912 | return 913 | } 914 | rs.Close() 915 | 916 | stmt, err := conn.Prepare("SELECT one_or_fail(2);") 917 | if err != nil { 918 | t.Error("prepare failed") 919 | return 920 | } 921 | defer stmt.Close() 922 | 923 | _, err = stmt.Execute() 924 | if err == nil { 925 | t.Error("error expected") 926 | return 927 | } 928 | if _, ok := err.(*Error); !ok { 929 | t.Error("*pgsql.Error expected") 930 | return 931 | } 932 | 933 | _, err = stmt.Scan(&one) 934 | if err == nil { 935 | t.Error("error expected") 936 | return 937 | } 938 | if _, ok := err.(*Error); !ok { 939 | t.Error("*pgsql.Error expected") 940 | return 941 | } 942 | 943 | _, err = 
conn.Execute("SELECT one_or_fail(2);") 944 | if err == nil { 945 | t.Error("error expected") 946 | return 947 | } 948 | if _, ok := err.(*Error); !ok { 949 | t.Error("*pgsql.Error expected") 950 | return 951 | } 952 | 953 | _, err = conn.Scan("SELECT one_or_fail(2);", &one) 954 | if err == nil { 955 | t.Error("error expected") 956 | return 957 | } 958 | if _, ok := err.(*Error); !ok { 959 | t.Error("*pgsql.Error expected") 960 | return 961 | } 962 | 963 | var abc string 964 | _, err = conn.Scan("SELECT 'abc';", &abc) 965 | if err != nil { 966 | t.Error("*Conn.Scan failed after previous expected *Conn.Scan error") 967 | return 968 | } 969 | }) 970 | } 971 | 972 | func Test_bufio_Reader_Read_release_2010_12_08(t *testing.T) { 973 | withConn(t, func(conn *Conn) { 974 | conn.Execute("DROP TABLE _gopgsql_test;") 975 | 976 | _, err := conn.Execute(` 977 | CREATE TABLE _gopgsql_test 978 | ( 979 | str text 980 | ); 981 | `) 982 | if err != nil { 983 | t.Error("failed to create table:", err) 984 | return 985 | } 986 | defer func() { 987 | conn.Execute("DROP TABLE _gopgsql_test;") 988 | }() 989 | 990 | in := strings.Repeat("x", 10000) 991 | 992 | stmt, err := conn.Prepare("INSERT INTO _gopgsql_test (str) VALUES (@str);", param("@str", Text, in)) 993 | if err != nil { 994 | t.Error("failed to prepare statement:", err) 995 | return 996 | } 997 | defer stmt.Close() 998 | 999 | _, err = stmt.Execute() 1000 | if err != nil { 1001 | t.Error("failed to execute statement:", err) 1002 | return 1003 | } 1004 | 1005 | var out string 1006 | 1007 | _, err = conn.Scan("SELECT str FROM _gopgsql_test;", &out) 1008 | if err != nil { 1009 | t.Error("failed to read str:", err) 1010 | return 1011 | } 1012 | 1013 | if out != in { 1014 | t.Error("out != in") 1015 | } 1016 | }) 1017 | } 1018 | 1019 | func Test_Issue2_Uint64_OutOfRange(t *testing.T) { 1020 | withConn(t, func(conn *Conn) { 1021 | want := uint64(9989608743) 1022 | query := fmt.Sprintf("SELECT %d::bigint;", want) 1023 | 1024 | var 
have uint64 1025 | if _, err := conn.Scan(query, &have); err != nil { 1026 | t.Error("failed to read uint64:", err) 1027 | return 1028 | } 1029 | 1030 | if have != want { 1031 | t.Errorf("have: %d, but want: %d", have, want) 1032 | } 1033 | }) 1034 | } 1035 | 1036 | func Test_Issue14_CopyFrom(t *testing.T) { 1037 | const data = "1\ts1\t\\N\ttrue\t2\n" 1038 | dataBuf := bytes.NewBufferString(data) 1039 | withConnLog(t, LogNothing, func(conn *Conn) { 1040 | if _, err := conn.Execute("TRUNCATE table1;"); err != nil { 1041 | t.Error("failed to truncate table1:", err) 1042 | return 1043 | } 1044 | 1045 | if n, err := conn.CopyFrom("COPY table1 FROM STDIN;", dataBuf); err != nil && n != 1 { 1046 | t.Error("COPY failed. err:", err, "n:", n) 1047 | } 1048 | 1049 | var b1, b2, b3, b4, b5 bool 1050 | if _, err := conn.Scan("SELECT id = 1, strreq = 's1', stropt IS NULL, blnreq, i32req = 2 FROM table1;", 1051 | &b1, &b2, &b3, &b4, &b5); err != nil { 1052 | t.Error("failed to SELECT table1:", err) 1053 | return 1054 | } else { 1055 | if !(b1 && b2 && b3 && b4 && b5) { 1056 | t.Error("some columns have incorrect data:", b1, b2, b3, b4, b5) 1057 | return 1058 | } 1059 | } 1060 | }) 1061 | } 1062 | --------------------------------------------------------------------------------