├── .github ├── dependabot.yml └── workflows │ └── test.yml ├── .gitignore ├── .golangci.yml ├── LICENSE ├── Makefile ├── README.md ├── compose.yaml ├── db.go ├── db_test.go ├── go.mod ├── go.sum └── main_test.go /.github/dependabot.yml: -------------------------------------------------------------------------------- 1 | version: 2 2 | updates: 3 | - package-ecosystem: "gomod" 4 | directory: "/" 5 | schedule: 6 | interval: "weekly" 7 | - package-ecosystem: "github-actions" 8 | directory: "/" 9 | schedule: 10 | interval: "weekly" 11 | -------------------------------------------------------------------------------- /.github/workflows/test.yml: -------------------------------------------------------------------------------- 1 | name: test 2 | 3 | on: 4 | push: 5 | branches: 6 | - master 7 | pull_request: 8 | 9 | concurrency: 10 | group: ${{ github.workflow }}-${{ github.ref }} 11 | cancel-in-progress: true 12 | 13 | jobs: 14 | test: 15 | runs-on: ubuntu-latest 16 | services: 17 | postgres: 18 | image: postgres:12 19 | env: 20 | POSTGRES_HOST_AUTH_METHOD: trust 21 | POSTGRES_DB: pgtxdbtest 22 | POSTGRES_USER: pgtxdbtest 23 | ports: 24 | - 5432:5432 25 | options: --health-cmd pg_isready --health-interval 10s --health-timeout 5s --health-retries 5 26 | steps: 27 | - uses: actions/checkout@v4 28 | - uses: actions/setup-go@v5 29 | with: 30 | go-version-file: go.mod 31 | - uses: golangci/golangci-lint-action@v6 32 | - run: make test 33 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Compiled Object files, Static and Dynamic libs (Shared Objects) 2 | *.o 3 | *.a 4 | *.so 5 | 6 | # Folders 7 | _obj 8 | _test 9 | 10 | # Architecture specific extensions/prefixes 11 | *.[568vq] 12 | [568vq].out 13 | 14 | *.cgo1.go 15 | *.cgo2.c 16 | _cgo_defun.c 17 | _cgo_gotypes.go 18 | _cgo_export.* 19 | 20 | _testmain.go 21 | 22 | *.exe 23 | *.test 24 | *.prof 
25 | 26 | vendor 27 | -------------------------------------------------------------------------------- /.golangci.yml: -------------------------------------------------------------------------------- 1 | linters: 2 | enable: 3 | - gofmt 4 | - gosimple 5 | - misspell 6 | run: 7 | timeout: 5m 8 | issues: 9 | exclude-rules: 10 | - path: _test\.go 11 | linters: 12 | - errcheck 13 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2016, Akira Chiku 2 | All rights reserved. 3 | 4 | Redistribution and use in source and binary forms, with or without 5 | modification, are permitted provided that the following conditions are met: 6 | 7 | * Redistributions of source code must retain the above copyright notice, this 8 | list of conditions and the following disclaimer. 9 | 10 | * Redistributions in binary form must reproduce the above copyright notice, 11 | this list of conditions and the following disclaimer in the documentation 12 | and/or other materials provided with the distribution. 13 | 14 | * Neither the name of pgtxdb nor the names of its 15 | contributors may be used to endorse or promote products derived from 16 | this software without specific prior written permission. 17 | 18 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 19 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 20 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 21 | DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 22 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 23 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 24 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 25 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 26 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 27 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 28 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | .PHONY: test 2 | test: vet 3 | go test -v -count 1 ./... 4 | 5 | .PHONY: vet 6 | vet: 7 | go vet ./... 8 | 9 | .PHONY: lint 10 | lint: 11 | golangci-lint run 12 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # pgtxdb 2 | 3 | [![test](https://github.com/kanmu/pgtxdb/actions/workflows/test.yml/badge.svg)](https://github.com/kanmu/pgtxdb/actions/workflows/test.yml) 4 | [![GitHub license](https://img.shields.io/badge/license-BSD--3--Clause-blue.svg)](https://raw.githubusercontent.com/kanmu/pgtxdb/master/LICENSE) 5 | [![Go Report Card](https://goreportcard.com/badge/github.com/kanmu/pgtxdb)](https://goreportcard.com/report/github.com/kanmu/pgtxdb) 6 | 7 | ## Description 8 | 9 | Single transaction sql driver for Golang x PostgreSQL. This is almost a clone of [go-txdb](https://github.com/DATA-DOG/go-txdb) with a few PostgreSQL-specific tweaks. 10 | 11 | - When `conn.Begin()` is called, this library executes `SAVEPOINT pgtxdb_xxx;` instead of actually beginning a transaction. 12 | - `tx.Commit()` does nothing. 13 | - `ROLLBACK TO SAVEPOINT pgtxdb_xxx;` will be executed upon `tx.Rollback()` call so that it can emulate transaction rollback.
14 | - Above features enable us to emulate multiple transactions in one test case. 15 | 16 | 17 | ## Run test 18 | 19 | ``` 20 | docker compose up -d 21 | make test 22 | ``` 23 | -------------------------------------------------------------------------------- /compose.yaml: -------------------------------------------------------------------------------- 1 | services: 2 | db: 3 | image: postgres:12 4 | environment: 5 | - POSTGRES_HOST_AUTH_METHOD=trust 6 | - POSTGRES_DB=pgtxdbtest 7 | - POSTGRES_USER=pgtxdbtest 8 | ports: 9 | - 5432:5432 10 | -------------------------------------------------------------------------------- /db.go: -------------------------------------------------------------------------------- 1 | /* 2 | Package pgtxdb is a single transaction based database/sql driver for PostgreSQL. 3 | When a connection is opened, it begins a transaction on the source database, and all 4 | operations performed on the resulting *sql.DB run within that transaction. Query 5 | results are read into memory before they are returned, so rows never hold the 6 | underlying connection; Exec, Query and Prepare calls are serialized by a mutex. 7 | 8 | Why is it useful? A very basic use case is functional testing: you can prepare a test 9 | database once, and within each test you do not have to reload it. All tests are 10 | isolated within a transaction and therefore run fast. And you do not have to hide your 11 | *sql.DB behind an interface in your code, because pgtxdb is a standard sql.Driver. 12 | 13 | Under the hood, whenever a pgtxdb driver is opened with a new identifier (dsn), it opens 14 | a real connection through the source driver and begins a transaction. Begin creates a 15 | SAVEPOINT, Rollback rolls back to that savepoint, and Commit does nothing. When the 16 | last connection for an identifier is closed, the transaction is rolled back, leaving 17 | your prepared test database in the same state as before. 18 | 19 | Given a PostgreSQL database called pgtxdbtest and a table app_user with username and 20 | email columns:
21 | 22 | Example: 23 | 24 | package main 25 | 26 | import ( 27 | "database/sql" 28 | "log" 29 | 30 | _ "github.com/jackc/pgx/v5/stdlib" 31 | "github.com/kanmu/pgtxdb" 32 | ) 33 | 34 | func init() { 35 | // we register an sql driver named "pgtxdb" 36 | pgtxdb.Register("pgtxdb", "pgx", "postgres://pgtxdbtest@localhost:5432/pgtxdbtest?sslmode=disable") 37 | } 38 | 39 | func main() { 40 | // dsn serves as a unique identifier for the shared transaction 41 | db, err := sql.Open("pgtxdb", "identifier") 42 | if err != nil { 43 | log.Fatal(err) 44 | } 45 | defer db.Close() 46 | 47 | if _, err := db.Exec(`INSERT INTO app_user(username, email) VALUES('gopher', 'gopher@example.com')`); err != nil { 48 | log.Fatal(err) 49 | } 50 | } 51 | 52 | Every time you run this application, the database remains in the same state as before. 53 | */ 54 | package pgtxdb 55 | 56 | import ( 57 | "context" 58 | "database/sql" 59 | "database/sql/driver" 60 | "fmt" 61 | "io" 62 | "sync" 63 | 64 | "github.com/pkg/errors" 65 | ) 66 | 67 | // Register a pgtxdb sql driver under the given sql driver name 68 | // which can be used to open a single transaction based database 69 | // connection. 70 | // 71 | // When Open is called any number of times with the same dsn, it returns 72 | // the same transaction connection. Any Begin, Commit calls 73 | // will not start or close the transaction. 74 | // 75 | // When the last connection for a dsn is closed, the transaction is rolled back. 76 | // 77 | // Use drv (Driver) and dsn (DataSourceName) as the standard sql properties for 78 | // your test database connection to be isolated within transaction. 79 | // 80 | // The drv and dsn are the same items passed into `sql.Open(drv, dsn)`. 81 | // 82 | // Note: if you open a secondary database, make sure to differentiate 83 | // the dsn string when opening the sql.DB. The transaction will be
84 | // isolated within that dsn. 85 | func Register(name, drv, dsn string) { 86 | sql.Register(name, &txDriver{ 87 | dsn: dsn, 88 | drv: drv, 89 | conns: make(map[string]*conn), 90 | log: io.Discard, 91 | }) 92 | } 93 | 94 | // conn is the driver.Conn shared by every database/sql connection opened with 95 | // the same dsn. It wraps a single *sql.Tx on the source database and doubles as 96 | // the driver.Tx returned by Begin, emulating transactions with savepoints. 97 | type conn struct { 98 | sync.Mutex 99 | tx *sql.Tx 100 | dsn string 101 | opened int // open references; the tx is rolled back when this drops to 0 102 | drv *txDriver 103 | savepoints []int // savepoint ids; index 0 is a placeholder for the base tx 104 | log io.Writer 105 | } 106 | 107 | // txDriver is an sql driver which runs every connection of a dsn on a single 108 | // transaction; when the last one is closed, the transaction is rolled back. 109 | type txDriver struct { 110 | sync.Mutex 111 | db *sql.DB 112 | conns map[string]*conn 113 | log io.Writer 114 | 115 | drv string 116 | dsn string 117 | } 118 | 119 | // txConnector is a driver.Connector that opens the conn identified by dsn. 120 | type txConnector struct { 121 | driver *txDriver 122 | dsn string 123 | } 124 | 125 | // Driver returns the underlying txDriver. 126 | func (c *txConnector) Driver() driver.Driver { 127 | return c.driver 128 | } 129 | 130 | // Connect opens (or reuses) the conn for the connector's dsn; ctx is not used. 131 | func (c *txConnector) Connect(ctx context.Context) (driver.Conn, error) { 132 | return c.driver.Open(c.dsn) 133 | } 134 | 135 | // NewConnector returns a driver.Connector for use with sql.OpenDB. 136 | // 137 | // dsn identifies the shared transaction, while srcDrv and srcDsn are the driver 138 | // name and data source name of the real database. Each connector owns its own 139 | // source *sql.DB, so transactions are never shared between connectors. 140 | // Driver activity (open, exec, begin, commit, rollback, close) is written to 141 | // log; a nil log discards it. 142 | func NewConnector(dsn, srcDrv, srcDsn string, log io.Writer) driver.Connector { 143 | if log == nil { 144 | log = io.Discard 145 | } 146 | 147 | return &txConnector{ 148 | driver: &txDriver{ 149 | dsn: srcDsn, 150 | drv: srcDrv, 151 | conns: make(map[string]*conn), 152 | log: log, 153 | }, 154 | dsn: dsn, 155 | } 156 | } 157 |
158 | // Open returns the conn for dsn, beginning a new transaction on the source 159 | // database when no conn for dsn is currently open. 160 | func (d *txDriver) Open(dsn string) (driver.Conn, error) { 161 | d.Lock() 162 | defer d.Unlock() 163 | // first open a real database connection 164 | if d.db == nil { 165 | db, err := sql.Open(d.drv, d.dsn) 166 | if err != nil { 167 | return nil, err 168 | } 169 | d.db = db 170 | } 171 | c, ok := d.conns[dsn] 172 | if !ok { 173 | tx, err := d.db.Begin() 174 | if err != nil { 175 | // Do not cache a conn without a live transaction: a later Open would 176 | // return it and every operation would dereference a nil *sql.Tx. 177 | return nil, err 178 | } 179 | c = &conn{tx: tx, dsn: dsn, drv: d, savepoints: []int{0}, log: d.log} 180 | fmt.Fprintf(c.log, "%s: open\n", c.dsn) 181 | d.conns[dsn] = c 182 | } 183 | c.opened++ 184 | return c, nil 185 | } 186 | 187 | // Close releases one reference to the shared conn. When the last reference is 188 | // released, the transaction is rolled back and the conn is forgotten, so the 189 | // next Open for the same dsn starts a fresh transaction. 190 | func (c *conn) Close() (err error) { 191 | c.drv.Lock() 192 | defer c.drv.Unlock() 193 | c.opened-- 194 | if c.opened == 0 { 195 | err = c.tx.Rollback() 196 | // Forget the conn even if the rollback failed: after Rollback every 197 | // operation on the tx fails with sql.ErrTxDone, so it cannot be reused. 198 | c.tx = nil 199 | delete(c.drv.conns, c.dsn) 200 | if err != nil { 201 | return err 202 | } 203 | } 204 | fmt.Fprintf(c.log, "%s: close\n", c.dsn) 205 | return nil 206 | } 207 |
208 | // Begin starts an emulated transaction by creating a savepoint; it is 209 | // BeginTx with a background context and default options. 210 | func (c *conn) Begin() (driver.Tx, error) { 211 | return c.BeginTx(context.Background(), driver.TxOptions{}) 212 | } 213 | 214 | // Commit is a no-op: changes made since the savepoint stay in the shared 215 | // transaction, which is rolled back when the last conn for the dsn is closed. 216 | func (c *conn) Commit() error { 217 | fmt.Fprintf(c.log, "%s: commit\n", c.dsn) 218 | return nil 219 | } 220 | 221 | // Rollback rolls back to the most recently created savepoint and pops it. 222 | func (c *conn) Rollback() error { 223 | c.Lock() 224 | defer c.Unlock() 225 | 226 | if len(c.savepoints) <= 1 { 227 | // Only the placeholder for the base transaction is left. 228 | return errors.New("no savepoint to rollback to") 229 | } 230 | savepointID := c.savepoints[len(c.savepoints)-1] 231 | c.savepoints = c.savepoints[:len(c.savepoints)-1] 232 | query := fmt.Sprintf("ROLLBACK TO SAVEPOINT pgtxdb_%d", savepointID) 233 | if _, err := c.tx.Exec(query); err != nil { 234 | return errors.Wrap(err, "failed to rollback to savepoint") 235 | } 236 | fmt.Fprintf(c.log, "%s: rollback\n", c.dsn) 237 | return nil 238 | } 239 | 240 | // Prepare prepares query on the shared transaction. 241 | func (c *conn) Prepare(query string) (driver.Stmt, error) { 242 | c.Lock() 243 | defer c.Unlock() 244 | 245 | st, err := c.tx.Prepare(query) 246 | if err != nil { 247 | return nil, err 248 | } 249 | return &stmt{st: st, dsn: c.dsn, log: c.log}, nil 250 | } 251 | 252 | // Exec runs query on the shared transaction. 253 | func (c *conn) Exec(query string, args []driver.Value) (driver.Result, error) { 254 | c.Lock() 255 | defer c.Unlock() 256 | 257 | fmt.Fprintf(c.log, "%s: exec %s\n", c.dsn, query) 258 | return c.tx.Exec(query, mapArgs(args)...)
259 | } 260 | 261 | // mapArgs converts driver values into arguments for the database/sql API. 262 | func mapArgs(args []driver.Value) (res []interface{}) { 263 | res = make([]interface{}, len(args)) 264 | for i := range args { 265 | res[i] = args[i] 266 | } 267 | return 268 | } 269 | 270 | // Query runs query on the shared transaction and buffers the whole result, so 271 | // the returned rows do not hold the transaction's connection. 272 | func (c *conn) Query(query string, args []driver.Value) (driver.Rows, error) { 273 | c.Lock() 274 | defer c.Unlock() 275 | 276 | // query rows 277 | rs, err := c.tx.Query(query, mapArgs(args)...) 278 | if err != nil { 279 | return nil, err 280 | } 281 | defer rs.Close() 282 | 283 | return buildRows(rs) 284 | } 285 | 286 | // stmt is a driver.Stmt backed by a statement prepared on the shared transaction. 287 | type stmt struct { 288 | st *sql.Stmt 289 | dsn string 290 | log io.Writer 291 | } 292 | 293 | func (s *stmt) Exec(args []driver.Value) (driver.Result, error) { 294 | fmt.Fprintf(s.log, "%s: exec prepared statement\n", s.dsn) 295 | return s.st.Exec(mapArgs(args)...) 296 | } 297 | 298 | // NumInput returns -1 so that database/sql does not sanity check the number 299 | // of arguments and leaves that to the source driver. 300 | func (s *stmt) NumInput() int { 301 | return -1 302 | } 303 | 304 | func (s *stmt) Close() error { 305 | return s.st.Close() 306 | } 307 | 308 | // Query runs the prepared statement and buffers the whole result. 309 | func (s *stmt) Query(args []driver.Value) (driver.Rows, error) { 310 | rows, err := s.st.Query(mapArgs(args)...)
311 | if err != nil { 312 | return nil, err 313 | } 314 | // Close explicitly: if buildRows fails midway, the *sql.Rows would 315 | // otherwise stay open on the shared transaction. 316 | defer rows.Close() 317 | 318 | return buildRows(rows) 319 | } 320 | 321 | // rows is one fully buffered result set. 322 | type rows struct { 323 | rows [][]driver.Value 324 | pos int 325 | cols []string 326 | } 327 | 328 | func (r *rows) Columns() (cols []string) { 329 | return r.cols 330 | } 331 | 332 | func (r *rows) Next(dest []driver.Value) error { 333 | r.pos++ 334 | if r.pos > len(r.rows) { 335 | return io.EOF 336 | } 337 | 338 | for i, val := range r.rows[r.pos-1] { 339 | dest[i] = *(val.(*interface{})) 340 | } 341 | 342 | return nil 343 | } 344 | 345 | func (r *rows) Close() error { 346 | return nil 347 | } 348 | 349 | // read buffers every row of the current result set of rs. Each value is kept 350 | // as the *interface{} it was scanned into and dereferenced again in Next. 351 | func (r *rows) read(rs *sql.Rows) error { 352 | var err error 353 | r.cols, err = rs.Columns() 354 | if err != nil { 355 | return err 356 | } 357 | for rs.Next() { 358 | values := make([]interface{}, len(r.cols)) 359 | for i := range values { 360 | values[i] = new(interface{}) 361 | } 362 | if err := rs.Scan(values...); err != nil { 363 | return err 364 | } 365 | row := make([]driver.Value, len(r.cols)) 366 | for i, v := range values { 367 | row[i] = driver.Value(v) 368 | } 369 | r.rows = append(r.rows, row) 370 | } 371 | return rs.Err() 372 | } 373 | 374 | // rowSets holds one buffered rows value per result set. 375 | type rowSets struct { 376 | sets []*rows 377 | pos int 378 | } 379 | 380 | func (rs *rowSets) Columns() []string { 381 | return rs.sets[rs.pos].cols 382 | } 383 | 384 | func (rs *rowSets) Close() error { 385 | return nil 386 | } 387 | 388 | // advances to next row 389 | func (rs *rowSets) Next(dest []driver.Value) error { 390 | return rs.sets[rs.pos].Next(dest) 391 | } 392 | 393 | // buildRows reads every result set of r into memory. It does not close r; 394 | // on error it returns nil rows. 395 | func buildRows(r *sql.Rows) (driver.Rows, error) { 396 | set := &rowSets{} 397 | rs := &rows{} 398 | if err := rs.read(r); err != nil { 399 | return nil, err 400 | } 401 | set.sets = append(set.sets, rs) 402 | for r.NextResultSet() { 403 | rss := &rows{} 404 | if err := rss.read(r); err != nil { 405 | return nil, err 406 | } 407 | set.sets = append(set.sets, rss) 408 | } 409 | // NextResultSet reports false both at the end and on failure. 410 | if err := r.Err(); err != nil { 411 | return nil, err 412 | } 413 | return set, nil 414 | } 415 |
416 | // Implement the "RowsNextResultSet" interface 417 | func (rs *rowSets) HasNextResultSet() bool { 418 | return rs.pos+1 < len(rs.sets) 419 | } 420 | 421 | // Implement the "RowsNextResultSet" interface 422 | func (rs *rowSets) NextResultSet() error { 423 | if !rs.HasNextResultSet() { 424 | return io.EOF 425 | } 426 | 427 | rs.pos++ 428 | return nil 429 | } 430 | 431 | // Implement the "QueryerContext" interface 432 | func (c *conn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) { 433 | c.Lock() 434 | defer c.Unlock() 435 | 436 | rs, err := c.tx.QueryContext(ctx, query, mapNamedArgs(args)...) 437 | if err != nil { 438 | return nil, err 439 | } 440 | defer rs.Close() 441 | 442 | return buildRows(rs) 443 | } 444 | 445 | // Implement the "ExecerContext" interface 446 | func (c *conn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) { 447 | c.Lock() 448 | defer c.Unlock() 449 | 450 | fmt.Fprintf(c.log, "%s: exec %s\n", c.dsn, query) 451 | return c.tx.ExecContext(ctx, query, mapNamedArgs(args)...)
452 | } 453 | 454 | // Implement the "ConnBeginTx" interface. 455 | // 456 | // BeginTx creates a savepoint and returns c itself as the driver.Tx. opts is 457 | // ignored: a savepoint cannot change the isolation level or read-only mode 458 | // of the enclosing transaction. 459 | func (c *conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) { 460 | c.Lock() 461 | defer c.Unlock() 462 | 463 | savepointID := len(c.savepoints) 464 | query := fmt.Sprintf("SAVEPOINT pgtxdb_%d", savepointID) 465 | if _, err := c.tx.ExecContext(ctx, query); err != nil { 466 | return nil, errors.Wrap(err, "failed to create savepoint") 467 | } 468 | // Push only once the savepoint exists, so a failed Begin leaves no stale id. 469 | c.savepoints = append(c.savepoints, savepointID) 470 | fmt.Fprintf(c.log, "%s: begin\n", c.dsn) 471 | return c, nil 472 | } 473 | 474 | // Implement the "ConnPrepareContext" interface 475 | func (c *conn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) { 476 | c.Lock() 477 | defer c.Unlock() 478 | 479 | st, err := c.tx.PrepareContext(ctx, query) 480 | if err != nil { 481 | return nil, err 482 | } 483 | return &stmt{st: st, dsn: c.dsn, log: c.log}, nil 484 | } 485 | 486 | // Implement the "Pinger" interface 487 | func (c *conn) Ping(ctx context.Context) error { 488 | return c.drv.db.PingContext(ctx) 489 | } 490 | 491 | // Implement the "StmtExecContext" interface 492 | func (s *stmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) { 493 | fmt.Fprintf(s.log, "%s: exec prepared statement\n", s.dsn) 494 | return s.st.ExecContext(ctx, mapNamedArgs(args)...) 495 | } 496 | 497 | // Implement the "StmtQueryContext" interface 498 | func (s *stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) { 499 | rows, err := s.st.QueryContext(ctx, mapNamedArgs(args)...) 500 | if err != nil { 501 | return nil, err 502 | } 503 | defer rows.Close() 504 | 505 | return buildRows(rows) 506 | } 507 | 508 | // mapNamedArgs converts driver named values into database/sql arguments, 509 | // wrapping the named ones with sql.Named. 510 | func mapNamedArgs(args []driver.NamedValue) (res []interface{}) { 511 | res = make([]interface{}, len(args)) 512 | for i := range args { 513 | name := args[i].Name 514 | if name != "" { 515 | res[i] = sql.Named(name, args[i].Value) 516 | } else { 517 | res[i] = args[i].Value 518 | } 519 | } 520 | return 521 | } 522 | -------------------------------------------------------------------------------- /db_test.go: -------------------------------------------------------------------------------- 1 | package pgtxdb_test 2 | 3 | import ( 4 | "database/sql" 5 | "fmt" 6 | "runtime" 7 | "strings" 8 | "sync" 9 | "testing" 10 | 11 | _ "github.com/jackc/pgx/v5/stdlib" // pgx 12 | "github.com/kanmu/pgtxdb" 13 | ) 14 | 15 | func init() { 16 | pgtxdb.Register("pgtxdb", "pgx", "postgres://pgtxdbtest@localhost:5432/pgtxdbtest?sslmode=disable") 17 | } 18 | 19 | func TestShouldRunWithinTransaction(t *testing.T) { 20 | t.Parallel() 21 | var count int 22 | db1, err := sql.Open("pgtxdb", "one") 23 | if err != nil { 24 | t.Fatalf("failed to open a postgres connection, have you run 'make test'?
err: %s", err) 25 | } 26 | defer db1.Close() 27 | 28 | _, err = db1.Exec(`INSERT INTO app_user(username, email) VALUES('txdb', 'txdb@test.com')`) 29 | if err != nil { 30 | t.Fatalf("failed to insert an app_user: %s", err) 31 | } 32 | err = db1.QueryRow("SELECT COUNT(id) FROM app_user").Scan(&count) 33 | if err != nil { 34 | t.Fatalf("failed to count users: %s", err) 35 | } 36 | if count != 1 { 37 | t.Fatalf("expected 1 user to be in database, but got %d", count) 38 | } 39 | 40 | db2, err := sql.Open("pgtxdb", "two") 41 | if err != nil { 42 | t.Fatalf("failed to reopen a postgres connection: %s", err) 43 | } 44 | defer db2.Close() 45 | 46 | err = db2.QueryRow("SELECT COUNT(id) FROM app_user").Scan(&count) 47 | if err != nil { 48 | t.Fatalf("failed to count app_user: %s", err) 49 | } 50 | if count != 0 { 51 | t.Fatalf("expected 0 user to be in database, but got %d", count) 52 | } 53 | } 54 | 55 | func TestShouldRunWithinTransactionForOpenDB(t *testing.T) { 56 | t.Parallel() 57 | var count int 58 | var db1Log strings.Builder 59 | db1 := sql.OpenDB(pgtxdb.NewConnector("one", "pgx", "postgres://pgtxdbtest@localhost:5432/pgtxdbtest?sslmode=disable", &db1Log)) 60 | defer db1.Close() 61 | 62 | _, err := db1.Exec(`INSERT INTO app_user(username, email) VALUES('txdb', 'txdb@test.com')`) 63 | if err != nil { 64 | t.Fatalf("failed to insert an app_user: %s", err) 65 | } 66 | err = db1.QueryRow("SELECT COUNT(id) FROM app_user").Scan(&count) 67 | if err != nil { 68 | t.Fatalf("failed to count users: %s", err) 69 | } 70 | if count != 1 { 71 | t.Fatalf("expected 1 user to be in database, but got %d", count) 72 | } 73 | 74 | expectedDb1Log := `one: open 75 | one: exec INSERT INTO app_user(username, email) VALUES('txdb', 'txdb@test.com') 76 | ` 77 | 78 | if db1Log.String() != expectedDb1Log { 79 | t.Errorf("unexpected db1 log: %s", db1Log.String()) 80 | } 81 | 82 | var db2Log strings.Builder 83 | db2 := sql.OpenDB(pgtxdb.NewConnector("two", "pgx",
"postgres://pgtxdbtest@localhost:5432/pgtxdbtest?sslmode=disable", &db2Log)) 84 | defer db2.Close() 85 | 86 | err = db2.QueryRow("SELECT COUNT(id) FROM app_user").Scan(&count) 87 | if err != nil { 88 | t.Fatalf("failed to count app_user: %s", err) 89 | } 90 | if count != 0 { 91 | t.Errorf("expected 0 user to be in database, but got %d", count) 92 | } 93 | 94 | expectedDb2Log := `two: open 95 | ` 96 | 97 | if db2Log.String() != expectedDb2Log { 98 | t.Fatalf("unexpected db2 log: %s", db2Log.String()) 99 | } 100 | } 101 | 102 | func TestShouldNotHoldConnectionForRows(t *testing.T) { 103 | t.Parallel() 104 | db, err := sql.Open("pgtxdb", "three") 105 | if err != nil { 106 | t.Fatalf("failed to open a postgres connection, have you run 'make test'? err: %s", err) 107 | } 108 | defer db.Close() 109 | 110 | rows, err := db.Query("SELECT username FROM app_user") 111 | if err != nil { 112 | t.Fatalf("failed to query users: %s", err) 113 | } 114 | defer rows.Close() 115 | 116 | _, err = db.Exec(`INSERT INTO app_user(username, email) VALUES('txdb', 'txdb@test.com')`) 117 | if err != nil { 118 | t.Fatalf("failed to insert an app_user: %s", err) 119 | } 120 | } 121 | 122 | func TestShouldPerformParallelActions(t *testing.T) { 123 | runtime.GOMAXPROCS(runtime.NumCPU()) 124 | t.Parallel() 125 | db, err := sql.Open("pgtxdb", "four") 126 | if err != nil { 127 | t.Fatalf("failed to open a postgres connection, have you run 'make test'?
err: %s", err) 128 | } 129 | defer db.Close() 130 | 131 | wg := &sync.WaitGroup{} 132 | for i := 0; i < 4; i++ { 133 | wg.Add(1) 134 | go func(d *sql.DB, idx int) { 135 | defer wg.Done() 136 | rows, err := d.Query("SELECT username FROM app_user") 137 | if err != nil { 138 | t.Errorf("failed to query app_user: %s", err) 139 | return // rows is nil; deferring rows.Close() would panic 140 | } 141 | defer rows.Close() 142 | username := fmt.Sprintf("parallel%d", idx) 143 | email := fmt.Sprintf("parallel%d@test.com", idx) 144 | _, err = d.Exec(`INSERT INTO app_user(username, email) VALUES($1, $2)`, username, email) 145 | if err != nil { 146 | t.Errorf("failed to insert an app_user: %s", err) 147 | } 148 | }(db, i) 149 | } 150 | wg.Wait() 151 | var count int 152 | err = db.QueryRow("SELECT COUNT(id) FROM app_user").Scan(&count) 153 | if err != nil { 154 | t.Fatalf("failed to count users: %s", err) 155 | } 156 | if count != 4 { 157 | t.Fatalf("expected 4 users to be in database, but got %d", count) 158 | } 159 | } 160 | 161 | func TestShouldHandlePrepare(t *testing.T) { 162 | t.Parallel() 163 | db, err := sql.Open("pgtxdb", "five") 164 | if err != nil { 165 | t.Fatalf("failed to open a postgres connection, have you run 'make test'?
err: %s", err) 166 | } 167 | defer db.Close() 168 | 169 | stmt1, err := db.Prepare("SELECT email FROM app_user WHERE username = $1") 170 | if err != nil { 171 | t.Fatalf("could not prepare - %s", err) 172 | } 173 | 174 | stmt2, err := db.Prepare("INSERT INTO app_user(username, email) VALUES($1, $2)") 175 | if err != nil { 176 | t.Fatalf("could not prepare - %s", err) 177 | } 178 | _, err = stmt2.Exec("jane", "jane@gmail.com") 179 | if err != nil { 180 | t.Fatalf("should have inserted user - %s", err) 181 | } 182 | 183 | var email string 184 | if err = stmt1.QueryRow("jane").Scan(&email); err != nil { 185 | t.Fatalf("could not scan email - %s", err) 186 | } 187 | 188 | _, err = stmt2.Exec("mark", "mark.spencer@gmail.com") 189 | if err != nil { 190 | t.Fatalf("should have inserted user - %s", err) 191 | } 192 | } 193 | 194 | func sequentialRollbackTest(t *testing.T, db *sql.DB) error { 195 | tx1, err := db.Begin() 196 | if err != nil { 197 | return err 198 | } 199 | defer tx1.Rollback() 200 | _, err = tx1.Exec(`INSERT INTO app_user(username, email) VALUES ('taro', 'taro@gmail.com')`) 201 | if err != nil { 202 | t.Logf("failed to insert the first taro record: %s", err) 203 | return err 204 | } 205 | tx1.Commit() 206 | 207 | tx2, err := db.Begin() 208 | if err != nil { 209 | return err 210 | } 211 | defer tx2.Rollback() 212 | _, err = tx2.Exec(`INSERT INTO app_user(username, email) VALUES ('taro', 'taro@gmail.com')`) 213 | if err != nil { 214 | t.Logf("successfully failed to insert the second taro record: %s", err) 215 | return err 216 | } 217 | tx2.Commit() 218 | return nil 219 | } 220 | 221 | func TestSavepointRollbackSequential(t *testing.T) { 222 | t.Parallel() 223 | db, err := sql.Open("pgtxdb", "six") 224 | if err != nil { 225 | t.Fatalf("failed to open a postgres connection, have you run 'make test'?
err: %s", err) 226 | } 227 | defer db.Close() 228 | 229 | // sequentialRollbackTest has to return an error since it tries to insert a duplicate record. 230 | // Although it returns an error, the first record was committed inside the function. 231 | if err := sequentialRollbackTest(t, db); err == nil { 232 | t.Fatal("expected an error from inserting a duplicate record, got nil") 233 | } 234 | // Thus, we can retrieve a record from db scope 235 | var count int 236 | err = db.QueryRow(`SELECT count(*) FROM app_user WHERE username = 'taro'`).Scan(&count) 237 | if err != nil { 238 | t.Fatal(err) 239 | } 240 | if count != 1 { 241 | t.Errorf("expected 1 user with username taro, but got %d", count) 242 | } 243 | } 244 | 245 | func TestSavepointRollbackSequentialForOpenDB(t *testing.T) { 246 | t.Parallel() 247 | var dbLog strings.Builder 248 | db := sql.OpenDB(pgtxdb.NewConnector("one", "pgx", "postgres://pgtxdbtest@localhost:5432/pgtxdbtest?sslmode=disable", &dbLog)) 249 | defer db.Close() 250 | 251 | // sequentialRollbackTest has to return an error since it tries to insert a duplicate record. 252 | // Although it returns an error, the first record was committed inside the function. 253 | if err := sequentialRollbackTest(t, db); err == nil { 254 | t.Fatal("expected an error from inserting a duplicate record, got nil") 255 | } 256 | // Thus, we can retrieve a record from db scope 257 | var count int 258 | err := db.QueryRow(`SELECT count(*) FROM app_user WHERE username = 'taro'`).Scan(&count) 259 | if err != nil { 260 | t.Fatal(err) 261 | } 262 | if count != 1 { 263 | t.Errorf("expected 1 user with username taro, but got %d", count) 264 | } 265 | 266 | expectedDbLog := `one: open 267 | one: begin 268 | one: exec INSERT INTO app_user(username, email) VALUES ('taro', 'taro@gmail.com') 269 | one: commit 270 | one: begin 271 | one: exec INSERT INTO app_user(username, email) VALUES ('taro', 'taro@gmail.com') 272 | one: rollback 273 | ` 274 | 275 | if dbLog.String() != expectedDbLog { 276 | t.Fatalf("unexpected db log: %s", dbLog.String()) 277 | } 278 | } 279 | 280 | func nestedRollbackTest(t *testing.T, db *sql.DB) error { 281 | tx1, err := db.Begin() 282 | if err != nil { 283 | return err 284 | } 285 | defer tx1.Rollback() 286 | t.Log("tx1 started") 287 | _, err
= tx1.Exec(`INSERT INTO app_user(username, email) VALUES ('taro', 'taro@gmail.com')`) 288 | if err != nil { 289 | t.Logf("failed to insert the first taro record: %s", err) 290 | return err 291 | } 292 | tx1.Commit() 293 | t.Log("tx1 committed") 294 | 295 | tx2, err := db.Begin() 296 | if err != nil { 297 | return err 298 | } 299 | defer tx2.Rollback() 300 | t.Log("tx2 started") 301 | 302 | _, err = tx2.Exec(`INSERT INTO app_user(username, email) VALUES ('taro', 'taro@gmail.com')`) 303 | if err != nil { 304 | if eventErr := createErrorEventWithTx(t, tx2, db); eventErr != nil { 305 | return fmt.Errorf("createErrorEvent failed: %w", eventErr) 306 | } 307 | return err 308 | } 309 | tx2.Commit() 310 | return nil 311 | } 312 | 313 | func createErrorEventWithTx(t *testing.T, prevTx *sql.Tx, db *sql.DB) error { 314 | // need to rollback error tx before starting new tx 315 | prevTx.Rollback() 316 | 317 | tx, err := db.Begin() 318 | if err != nil { 319 | return err 320 | } 321 | defer tx.Rollback() 322 | 323 | t.Log("error event tx started") 324 | _, err = tx.Exec(`INSERT INTO error_event (message) values ('error creating app_user')`) 325 | if err != nil { 326 | return err 327 | } 328 | tx.Commit() 329 | t.Log("error event tx committed") 330 | return nil 331 | } 332 | 333 | func TestSavepointRollbackNested(t *testing.T) { 334 | t.Parallel() 335 | db, err := sql.Open("pgtxdb", "seven") 336 | if err != nil { 337 | t.Fatalf("failed to open a postgres connection, have you run 'make test'?
err: %s", err) 338 | } 339 | defer db.Close() 340 | 341 | if err := nestedRollbackTest(t, db); err == nil { 342 | t.Fatal("expected an error from inserting a duplicate record, got nil") 343 | } 344 | 345 | var count int 346 | err = db.QueryRow(`SELECT count(*) FROM app_user WHERE username = 'taro'`).Scan(&count) 347 | if err != nil { 348 | t.Fatal(err) 349 | } 350 | if count != 1 { 351 | t.Errorf("expected 1 user with username taro, but got %d", count) 352 | } 353 | var errCount int 354 | err = db.QueryRow(`SELECT count(*) FROM error_event`).Scan(&errCount) 355 | if err != nil { 356 | t.Fatal(err) 357 | } 358 | if errCount != 1 { 359 | t.Errorf("expected 1 error event, but got %d", errCount) 360 | } 361 | } 362 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/kanmu/pgtxdb 2 | 3 | go 1.21 4 | 5 | toolchain go1.23.2 6 | 7 | require ( 8 | github.com/jackc/pgx/v5 v5.7.1 9 | github.com/pkg/errors v0.9.1 10 | ) 11 | 12 | require ( 13 | github.com/jackc/pgpassfile v1.0.0 // indirect 14 | github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect 15 | github.com/jackc/puddle/v2 v2.2.2 // indirect 16 | golang.org/x/crypto v0.27.0 // indirect 17 | golang.org/x/sync v0.8.0 // indirect 18 | golang.org/x/text v0.18.0 // indirect 19 | ) 20 | -------------------------------------------------------------------------------- /go.sum: -------------------------------------------------------------------------------- 1 | github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 2 | github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= 3 | github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 4 | github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= 5 | github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= 6 |
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo= 7 | github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM= 8 | github.com/jackc/pgx/v5 v5.7.1 h1:x7SYsPBYDkHDksogeSmZZ5xzThcTgRz++I5E+ePFUcs= 9 | github.com/jackc/pgx/v5 v5.7.1/go.mod h1:e7O26IywZZ+naJtWWos6i6fvWK+29etgITqrqHLfoZA= 10 | github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo= 11 | github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4= 12 | github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= 13 | github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= 14 | github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= 15 | github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= 16 | github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= 17 | github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= 18 | github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= 19 | github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk= 20 | github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= 21 | golang.org/x/crypto v0.27.0 h1:GXm2NjJrPaiv/h1tb2UH8QfgC/hOf/+z0p6PT8o1w7A= 22 | golang.org/x/crypto v0.27.0/go.mod h1:1Xngt8kV6Dvbssa53Ziq6Eqn0HqbZi5Z6R0ZpwQzt70= 23 | golang.org/x/sync v0.8.0 h1:3NFvSEYkUoMifnESzZl15y791HH1qU2xm6eCJU5ZPXQ= 24 | golang.org/x/sync v0.8.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= 25 | golang.org/x/text v0.18.0 h1:XvMDiNzPAl0jr17s6W9lcaIhGUfUORdGCNsuLmPG224= 26 | golang.org/x/text v0.18.0/go.mod h1:BuEKDfySbSR4drPmRPG/7iBdf8hvFMuRexcpahXilzY= 27 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod 
h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= 28 | gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= 29 | gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= 30 | gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= 31 | -------------------------------------------------------------------------------- /main_test.go: -------------------------------------------------------------------------------- 1 | package pgtxdb_test 2 | 3 | import ( 4 | "database/sql" 5 | "log" 6 | "os" 7 | "testing" 8 | 9 | _ "github.com/jackc/pgx/v5/stdlib" // pgx 10 | ) 11 | 12 | // TestMain creates the test tables before running the package's tests and drops them afterwards. 13 | func TestMain(m *testing.M) { 14 | db, err := sql.Open("pgx", "postgres://pgtxdbtest@localhost:5432/pgtxdbtest?sslmode=disable") 15 | if err != nil { 16 | log.Fatalf("failed to connect test db: %s", err.Error()) 17 | } 18 | _, err = db.Exec(` 19 | CREATE TABLE IF NOT EXISTS app_user ( 20 | id BIGSERIAL NOT NULL, 21 | username TEXT NOT NULL, 22 | email TEXT NOT NULL, 23 | PRIMARY KEY (id), 24 | UNIQUE (email) 25 | ); 26 | CREATE TABLE IF NOT EXISTS error_event ( 27 | id BIGSERIAL NOT NULL, 28 | message TEXT NOT NULL, 29 | UNIQUE (id) 30 | ); 31 | `) 32 | if err != nil { 33 | log.Fatalf("failed to create test table: %s", err.Error()) 34 | } 35 | code := m.Run() 36 | _, err = db.Exec(` 37 | DROP TABLE IF EXISTS app_user; 38 | DROP TABLE IF EXISTS error_event; 39 | `) 40 | if err != nil { 41 | log.Fatalf("failed to drop test tables: %s", err.Error()) 42 | } 43 | os.Exit(code) 44 | } 45 | --------------------------------------------------------------------------------