├── .github └── workflows │ └── test.yml ├── .gitignore ├── LICENSE ├── Makefile ├── README.md ├── docker-compose.yml ├── go.mod ├── go.sum ├── sqalx.go └── sqalx_test.go /.github/workflows/test.yml: -------------------------------------------------------------------------------- 1 | on: [push, pull_request] 2 | name: Test 3 | jobs: 4 | test: 5 | runs-on: ubuntu-latest 6 | strategy: 7 | matrix: 8 | go-version: ["1.21", "1.22"] 9 | 10 | services: 11 | postgres: 12 | image: postgres:13-alpine 13 | ports: 14 | - 5432:5432 15 | env: 16 | POSTGRES_USER: sqalx 17 | POSTGRES_PASSWORD: sqalx 18 | mysql: 19 | image: mysql:8.0 20 | ports: 21 | - 3306:3306 22 | env: 23 | MYSQL_ROOT_PASSWORD: sqalx 24 | MYSQL_USER: sqalx 25 | MYSQL_PASSWORD: sqalx 26 | MYSQL_DATABASE: sqalx 27 | 28 | steps: 29 | - name: Install Go 30 | uses: actions/setup-go@v5 31 | with: 32 | go-version: ${{ matrix.go-version }} 33 | - name: Checkout code 34 | uses: actions/checkout@v4 35 | - name: Test 36 | run: make test 37 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Compiled Object files, Static and Dynamic libs (Shared Objects) 2 | *.o 3 | *.a 4 | *.so 5 | 6 | # Folders 7 | _obj 8 | _test 9 | 10 | # Architecture specific extensions/prefixes 11 | *.[568vq] 12 | [568vq].out 13 | 14 | *.cgo1.go 15 | *.cgo2.c 16 | _cgo_defun.c 17 | _cgo_gotypes.go 18 | _cgo_export.* 19 | 20 | _testmain.go 21 | 22 | *.exe 23 | *.test 24 | *.prof 25 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2016 Heetch 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without 
limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | POSTGRESQL_DATASOURCE ?= postgresql://sqalx:sqalx@localhost:5432/sqalx?sslmode=disable 2 | MYSQL_DATASOURCE ?= sqalx:sqalx@tcp(localhost:3306)/sqalx 3 | SQLITE_DATASOURCE ?= :memory: 4 | 5 | .PHONY: test 6 | 7 | test: 8 | POSTGRESQL_DATASOURCE="$(POSTGRESQL_DATASOURCE)" \ 9 | MYSQL_DATASOURCE="$(MYSQL_DATASOURCE)" \ 10 | SQLITE_DATASOURCE="$(SQLITE_DATASOURCE)" \ 11 | go test -v -cover -race -timeout=1m ./... && echo OK || (echo FAIL && exit 1) 12 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | # :warning: Warning: This repository is considered inactive and no change will be made to it except for security updates. 
3 | 4 | # sqalx 5 | 6 | [![GoDoc](https://godoc.org/github.com/heetch/sqalx?status.svg)](https://godoc.org/github.com/heetch/sqalx) 7 | [![Go Report Card](https://goreportcard.com/badge/github.com/heetch/sqalx)](https://goreportcard.com/report/github.com/heetch/sqalx) 8 | 9 | sqalx (pronounced 'scale-x') is a library built on top of [sqlx](https://github.com/jmoiron/sqlx) that lets you seamlessly create nested transactions without having to think about whether or not a function is called within a transaction. 10 | With sqalx you can easily create reusable and composable functions that can be called inside or outside of transactions and that can create transactions themselves. 11 | 12 | ## Getting started 13 | 14 | ```sh 15 | $ go get github.com/heetch/sqalx 16 | ``` 17 | 18 | ### Import sqalx 19 | 20 | ```go 21 | import "github.com/heetch/sqalx" 22 | ``` 23 | 24 | ### Usage 25 | 26 | ```go 27 | package main 28 | 29 | import ( 30 | "log" 31 | 32 | "github.com/heetch/sqalx" 33 | "github.com/jmoiron/sqlx" 34 | _ "github.com/lib/pq" 35 | ) 36 | 37 | func main() { 38 | // Connect to PostgreSQL with sqlx. 39 | db, err := sqlx.Connect("postgres", "user=foo dbname=bar sslmode=disable") 40 | if err != nil { 41 | log.Fatal(err) 42 | } 43 | 44 | defer db.Close() 45 | 46 | // Pass the db to sqalx. 47 | // It returns a sqalx.Node. A Node is a wrapper around sqlx.DB or sqlx.Tx. 48 | node, err := sqalx.New(db) 49 | if err != nil { 50 | log.Fatal(err) 51 | } 52 | 53 | err = createUser(node) 54 | if err != nil { 55 | log.Fatal(err) 56 | } 57 | } 58 | 59 | func createUser(node sqalx.Node) error { 60 | // Exec a query 61 | _, _ = node.Exec("INSERT INTO ....") // you can use a node as if it were a *sqlx.DB or a *sqlx.Tx 62 | 63 | // Let's create a transaction. 64 | // A transaction is also a sqalx.Node.
65 | tx, err := node.Beginx() 66 | if err != nil { 67 | return err 68 | } 69 | defer tx.Rollback() 70 | 71 | _, _ = tx.Exec("UPDATE ...") 72 | 73 | // Now we call another function and pass it the transaction. 74 | err = updateGroups(tx) 75 | if err != nil { 76 | return err 77 | } 78 | 79 | return tx.Commit() 80 | } 81 | 82 | func updateGroups(node sqalx.Node) error { 83 | // Notice we are creating a new transaction. 84 | // Without sqalx, this would normally cause a deadlock. 85 | tx, err := node.Beginx() 86 | if err != nil { 87 | return err 88 | } 89 | defer tx.Rollback() 90 | 91 | _, _ = tx.Exec("INSERT ...") 92 | _, _ = tx.Exec("UPDATE ...") 93 | _, _ = tx.Exec("DELETE ...") 94 | 95 | return tx.Commit() 96 | } 97 | ``` 98 | 99 | ### PostgreSQL Savepoints 100 | 101 | When using a PostgreSQL (`postgres`, `pgx`, `pgx/v5`), MySQL (`mysql`) or SQLite (`sqlite3`) driver, the `SavePoint` option can be passed to `New` to use [savepoints](https://www.postgresql.org/docs/current/sql-savepoint.html) for nested transactions. With any other driver, `New` returns `ErrIncompatibleOption`. 102 | 103 | ```go 104 | node, err := sqalx.New(db, sqalx.SavePoint(true)) 105 | ``` 106 | 107 | ## Issues 108 | Please open an issue if you encounter any problems. 109 | 110 | ## Development 111 | sqalx is covered by a Go test suite. In order to test against specific databases we include a docker-compose file that runs Postgres and MySQL. 112 | 113 | ### Running all tests 114 | To run the tests, first run `docker-compose up` to start Postgres and MySQL in Docker containers whose ports are exposed on localhost. Then run `make test`, which sets the data sources described below and runs all tests.
115 | 116 | ### Running specific tests 117 | To test against the Postgres instance, export the following DSN: 118 | 119 | ```sh 120 | export POSTGRESQL_DATASOURCE="postgresql://sqalx:sqalx@localhost:5432/sqalx?sslmode=disable" 121 | ``` 122 | 123 | To test against the MySQL instance, export the following DSN: 124 | 125 | ```sh 126 | export MYSQL_DATASOURCE="sqalx:sqalx@tcp(localhost:3306)/sqalx" 127 | ``` 128 | 129 | To test against SQLite, export the following DSN: 130 | 131 | ```sh 132 | export SQLITE_DATASOURCE=":memory:" 133 | ``` 134 | 135 | _Note:_ If you are developing on an M1 Mac, you will need to use Oracle's `mysql/mysql-server` image instead of the default `mysql` image; the corresponding line is commented out in `docker-compose.yml`. 136 | 137 | ## License 138 | The library is released under the MIT license. See the [LICENSE](LICENSE) file. 139 | -------------------------------------------------------------------------------- /docker-compose.yml: -------------------------------------------------------------------------------- 1 | version: '3.6' 2 | 3 | services: 4 | postgres: 5 | image: postgres:9.6-alpine 6 | ports: 7 | - 5432:5432 8 | environment: 9 | - POSTGRES_USER=sqalx 10 | - POSTGRES_PASSWORD=sqalx 11 | 12 | mysql: 13 | image: mysql:8.0 # intel only 14 | # image: mysql/mysql-server:8.0 # mac M1 preview 15 | ports: 16 | - 3306:3306 17 | environment: 18 | - MYSQL_ROOT_PASSWORD=sqalx 19 | - MYSQL_USER=sqalx 20 | - MYSQL_PASSWORD=sqalx 21 | - MYSQL_DATABASE=sqalx 22 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/heetch/sqalx 2 | 3 | go 1.21 4 | 5 | require ( 6 | github.com/DATA-DOG/go-sqlmock v1.5.2 7 | github.com/go-sql-driver/mysql v1.7.1 8 | github.com/google/uuid v1.6.0 9 | github.com/jackc/pgx/v5 v5.5.4 10 | github.com/jmoiron/sqlx v1.3.5 11 | github.com/lib/pq v1.10.9 12 |
github.com/mattn/go-sqlite3 v1.14.22 13 | github.com/stretchr/testify v1.9.0 14 | ) 15 | 16 | require ( 17 | github.com/davecgh/go-spew v1.1.1 // indirect 18 | github.com/jackc/pgpassfile v1.0.0 // indirect 19 | github.com/jackc/pgservicefile v0.0.0-20231201235250-de7065d80cb9 // indirect 20 | github.com/jackc/puddle/v2 v2.2.1 // indirect 21 | github.com/kr/text v0.2.0 // indirect 22 | github.com/pmezard/go-difflib v1.0.0 // indirect 23 | github.com/rogpeppe/go-internal v1.12.0 // indirect 24 | golang.org/x/crypto v0.21.0 // indirect 25 | golang.org/x/sync v0.1.0 // indirect 26 | golang.org/x/text v0.14.0 // indirect 27 | gopkg.in/yaml.v3 v3.0.1 // indirect 28 | ) 29 | -------------------------------------------------------------------------------- /go.sum: -------------------------------------------------------------------------------- 1 | github.com/DATA-DOG/go-sqlmock v1.5.2 h1:OcvFkGmslmlZibjAjaHm3L//6LiuBgolP7OputlJIzU= 2 | github.com/DATA-DOG/go-sqlmock v1.5.2/go.mod h1:88MAG/4G7SMwSE3CeA0ZKzrT5CiOU3OJ+JlNzwDqpNU= 3 | github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= 4 | github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 5 | github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= 6 | github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 7 | github.com/go-sql-driver/mysql v1.6.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg= 8 | github.com/go-sql-driver/mysql v1.7.1 h1:lUIinVbN1DY0xBg0eMOzmmtGoHwWBbvnWubQUrtU8EI= 9 | github.com/go-sql-driver/mysql v1.7.1/go.mod h1:OXbVy3sEdcQ2Doequ6Z5BW6fXNQTmx+9S1MCJN5yJMI= 10 | github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= 11 | github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= 12 | github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= 13 | github.com/jackc/pgpassfile v1.0.0/go.mod 
h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= 14 | github.com/jackc/pgservicefile v0.0.0-20231201235250-de7065d80cb9 h1:L0QtFUgDarD7Fpv9jeVMgy/+Ec0mtnmYuImjTz6dtDA= 15 | github.com/jackc/pgservicefile v0.0.0-20231201235250-de7065d80cb9/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM= 16 | github.com/jackc/pgx/v5 v5.5.4 h1:Xp2aQS8uXButQdnCMWNmvx6UysWQQC+u1EoizjguY+8= 17 | github.com/jackc/pgx/v5 v5.5.4/go.mod h1:ez9gk+OAat140fv9ErkZDYFWmXLfV+++K0uAOiwgm1A= 18 | github.com/jackc/puddle/v2 v2.2.1 h1:RhxXJtFG022u4ibrCSMSiu5aOq1i77R3OHKNJj77OAk= 19 | github.com/jackc/puddle/v2 v2.2.1/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4= 20 | github.com/jmoiron/sqlx v1.3.5 h1:vFFPA71p1o5gAeqtEAwLU4dnX2napprKtHr7PYIcN3g= 21 | github.com/jmoiron/sqlx v1.3.5/go.mod h1:nRVWtLre0KfCLJvgxzCsLVMogSvQ1zNJtpYr2Ccp0mQ= 22 | github.com/kisielk/sqlstruct v0.0.0-20201105191214-5f3e10d3ab46/go.mod h1:yyMNCyc/Ib3bDTKd379tNMpB/7/H5TjM2Y9QJ5THLbE= 23 | github.com/kr/pretty v0.3.0 h1:WgNl7dwNpEZ6jJ9k1snq4pZsg7DOEN8hP9Xw0Tsjwk0= 24 | github.com/kr/pretty v0.3.0/go.mod h1:640gp4NfQd8pI5XOwp5fnNeVWj67G7CFk/SaSQn7NBk= 25 | github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= 26 | github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= 27 | github.com/lib/pq v1.2.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= 28 | github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw= 29 | github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= 30 | github.com/mattn/go-sqlite3 v1.14.6/go.mod h1:NyWgC/yNuGj7Q9rpYnZvas74GogHl5/Z4A/KQRfk6bU= 31 | github.com/mattn/go-sqlite3 v1.14.22 h1:2gZY6PC6kBnID23Tichd1K+Z0oS6nE/XwU+Vz/5o4kU= 32 | github.com/mattn/go-sqlite3 v1.14.22/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= 33 | github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= 34 | github.com/pmezard/go-difflib v1.0.0/go.mod 
h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= 35 | github.com/rogpeppe/go-internal v1.12.0 h1:exVL4IDcn6na9z1rAb56Vxr+CgyK3nn3O+epU5NdKM8= 36 | github.com/rogpeppe/go-internal v1.12.0/go.mod h1:E+RYuTGaKKdloAfM02xzb0FW3Paa99yedzYV+kq4uf4= 37 | github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= 38 | github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= 39 | github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= 40 | github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= 41 | github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= 42 | golang.org/x/crypto v0.21.0 h1:X31++rzVUdKhX5sWmSOFZxx8UW/ldWx55cbf08iNAMA= 43 | golang.org/x/crypto v0.21.0/go.mod h1:0BP7YvVV9gBbVKyeTG0Gyn+gZm94bibOW5BjDEYAOMs= 44 | golang.org/x/sync v0.1.0 h1:wsuoTGHzEhffawBOhz5CYhcrV4IdKZbEyZjBMuTp12o= 45 | golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= 46 | golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ= 47 | golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= 48 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= 49 | gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= 50 | gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= 51 | gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= 52 | gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= 53 | gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= 54 | -------------------------------------------------------------------------------- /sqalx.go: -------------------------------------------------------------------------------- 1 | package 
sqalx 2 | 3 | import ( 4 | "context" 5 | "database/sql" 6 | "errors" 7 | "strings" 8 | 9 | "github.com/google/uuid" 10 | "github.com/jmoiron/sqlx" 11 | ) 12 | 13 | var ( 14 | // ErrNotInTransaction is returned when using Commit 15 | // outside of a transaction. 16 | ErrNotInTransaction = errors.New("not in transaction") 17 | 18 | // ErrIncompatibleOption is returned when using an option incompatible 19 | // with the selected driver. 20 | ErrIncompatibleOption = errors.New("incompatible option") 21 | ) 22 | 23 | // A Node is a database driver that can manage nested transactions. 24 | type Node interface { 25 | Driver 26 | 27 | // Close the underlying *sqlx.DB. 28 | Close() error 29 | // Begin a new transaction, or a nested one if the node is already in a transaction. 30 | Beginx() (Node, error) 31 | // Begin a new transaction using the provided context and options. 32 | // Note that the provided parameters are only used when opening a new transaction, 33 | // not on nested ones. 34 | BeginTxx(ctx context.Context, opts *sql.TxOptions) (Node, error) 35 | // Rollback the associated transaction. 36 | Rollback() error 37 | // Commit the associated transaction. 38 | Commit() error 39 | // Tx returns the underlying transaction, or nil if the node is not in one. 40 | Tx() *sqlx.Tx 41 | } 42 | 43 | // A Driver can query the database. It can either be a *sqlx.DB or a *sqlx.Tx 44 | // and therefore is limited to the methods they have in common.
45 | type Driver interface { 46 | sqlx.Execer 47 | sqlx.ExecerContext 48 | sqlx.Queryer 49 | sqlx.QueryerContext 50 | sqlx.Preparer 51 | sqlx.PreparerContext 52 | BindNamed(query string, arg interface{}) (string, []interface{}, error) 53 | DriverName() string 54 | Get(dest interface{}, query string, args ...interface{}) error 55 | GetContext(ctx context.Context, dest interface{}, query string, args ...interface{}) error 56 | MustExec(query string, args ...interface{}) sql.Result 57 | MustExecContext(ctx context.Context, query string, args ...interface{}) sql.Result 58 | NamedExec(query string, arg interface{}) (sql.Result, error) 59 | NamedExecContext(ctx context.Context, query string, arg interface{}) (sql.Result, error) 60 | NamedQuery(query string, arg interface{}) (*sqlx.Rows, error) 61 | PrepareNamed(query string) (*sqlx.NamedStmt, error) 62 | PrepareNamedContext(ctx context.Context, query string) (*sqlx.NamedStmt, error) 63 | Preparex(query string) (*sqlx.Stmt, error) 64 | PreparexContext(ctx context.Context, query string) (*sqlx.Stmt, error) 65 | QueryRow(string, ...interface{}) *sql.Row 66 | QueryRowContext(context.Context, string, ...interface{}) *sql.Row 67 | Rebind(query string) string 68 | Select(dest interface{}, query string, args ...interface{}) error 69 | SelectContext(ctx context.Context, dest interface{}, query string, args ...interface{}) error 70 | } 71 | 72 | // New creates a new Node with the given DB. Options are applied in order; the first option that returns an error aborts New and that error is returned. 73 | func New(db *sqlx.DB, options ...Option) (Node, error) { 74 | n := node{ 75 | db: db, 76 | Driver: db, 77 | } 78 | 79 | for _, opt := range options { 80 | err := opt(&n) 81 | if err != nil { 82 | return nil, err 83 | } 84 | } 85 | 86 | return &n, nil 87 | } 88 | 89 | // NewFromTransaction creates a new Node from the given transaction. Nested Beginx/BeginTxx calls on the returned node run inside tx.
90 | func NewFromTransaction(tx *sqlx.Tx, options ...Option) (Node, error) { 91 | n := node{ 92 | tx: tx, 93 | Driver: tx, 94 | } 95 | 96 | for _, opt := range options { 97 | err := opt(&n) 98 | if err != nil { 99 | return nil, err 100 | } 101 | } 102 | 103 | return &n, nil 104 | } 105 | 106 | // Connect to a database. 107 | func Connect(driverName, dataSourceName string, options ...Option) (Node, error) { 108 | db, err := sqlx.Connect(driverName, dataSourceName) 109 | if err != nil { 110 | return nil, err 111 | } 112 | 113 | node, err := New(db, options...) 114 | if err != nil { 115 | // the connection has been opened within this function, we must close it 116 | // on error. 117 | db.Close() 118 | return nil, err 119 | } 120 | 121 | return node, nil 122 | } 123 | 124 | type node struct { 125 | Driver 126 | db *sqlx.DB 127 | tx *sqlx.Tx 128 | savePointID string 129 | savePointEnabled bool 130 | nested bool 131 | } 132 | 133 | func (n *node) Close() error { 134 | return n.db.Close() 135 | } 136 | 137 | func (n node) Beginx() (Node, error) { 138 | return n.BeginTxx(context.Background(), nil) 139 | } 140 | 141 | func (n node) BeginTxx(ctx context.Context, opts *sql.TxOptions) (Node, error) { 142 | var err error 143 | 144 | switch { 145 | case n.tx == nil: 146 | // new actual transaction 147 | n.tx, err = n.db.BeginTxx(ctx, opts) 148 | n.Driver = n.tx 149 | case n.savePointEnabled: 150 | // already in a transaction: using savepoints 151 | n.nested = true 152 | // savepoints name must start with a char and cannot contain dashes (-) 153 | n.savePointID = "sp_" + strings.Replace(uuid.NewString(), "-", "_", -1) 154 | _, err = n.tx.Exec("SAVEPOINT " + n.savePointID) 155 | default: 156 | // already in a transaction: reusing current transaction 157 | n.nested = true 158 | } 159 | 160 | if err != nil { 161 | return nil, err 162 | } 163 | 164 | return &n, nil 165 | } 166 | 167 | func (n *node) Rollback() error { 168 | if n.tx == nil { 169 | return nil 170 | } 171 | 172 | var 
err error 173 | 174 | if n.savePointEnabled && n.savePointID != "" { 175 | _, err = n.tx.Exec("ROLLBACK TO SAVEPOINT " + n.savePointID) 176 | } else if !n.nested { 177 | err = n.tx.Rollback() 178 | } 179 | 180 | if err != nil { 181 | return err 182 | } 183 | 184 | n.tx = nil 185 | n.Driver = nil 186 | 187 | return nil 188 | } 189 | 190 | func (n *node) Commit() error { 191 | if n.tx == nil { 192 | return ErrNotInTransaction 193 | } 194 | 195 | var err error 196 | 197 | if n.savePointID != "" { 198 | _, err = n.tx.Exec("RELEASE SAVEPOINT " + n.savePointID) 199 | } else if !n.nested { 200 | err = n.tx.Commit() 201 | } 202 | 203 | if err != nil { 204 | return err 205 | } 206 | 207 | n.tx = nil 208 | n.Driver = nil 209 | 210 | return nil 211 | } 212 | 213 | // Tx returns the underlying transaction. 214 | func (n *node) Tx() *sqlx.Tx { 215 | return n.tx 216 | } 217 | 218 | // Option to configure sqalx 219 | type Option func(*node) error 220 | 221 | // SavePoint option enables PostgreSQL and SQLite Savepoints for nested 222 | // transactions. 
223 | func SavePoint(enabled bool) Option { 224 | return func(n *node) error { 225 | driverName := n.Driver.DriverName() 226 | if enabled && driverName != "postgres" && driverName != "pgx" && driverName != "pgx/v5" && driverName != "sqlite3" && driverName != "mysql" { 227 | return ErrIncompatibleOption 228 | } 229 | n.savePointEnabled = enabled 230 | return nil 231 | } 232 | } 233 | -------------------------------------------------------------------------------- /sqalx_test.go: -------------------------------------------------------------------------------- 1 | package sqalx_test 2 | 3 | import ( 4 | "context" 5 | "os" 6 | "testing" 7 | 8 | sqlmock "github.com/DATA-DOG/go-sqlmock" 9 | _ "github.com/go-sql-driver/mysql" 10 | "github.com/heetch/sqalx" 11 | _ "github.com/jackc/pgx/v5/stdlib" 12 | "github.com/jmoiron/sqlx" 13 | _ "github.com/lib/pq" 14 | _ "github.com/mattn/go-sqlite3" 15 | "github.com/stretchr/testify/require" 16 | ) 17 | 18 | func prepareDB(t *testing.T, driverName string) (*sqlx.DB, sqlmock.Sqlmock, func()) { 19 | db, mock, err := sqlmock.New() 20 | require.NoError(t, err) 21 | 22 | return sqlx.NewDb(db, driverName), mock, func() { 23 | db.Close() 24 | } 25 | } 26 | 27 | func TestSqalxConnectPostgreSQL(t *testing.T) { 28 | dataSource := os.Getenv("POSTGRESQL_DATASOURCE") 29 | if dataSource == "" { 30 | t.Log("skipping due to blank POSTGRESQL_DATASOURCE") 31 | t.Skip() 32 | return 33 | } 34 | 35 | testSqalxConnect(t, "postgres", dataSource) 36 | testSqalxConnect(t, "postgres", dataSource, sqalx.SavePoint(true)) 37 | } 38 | func TestSqalxConnectPGX(t *testing.T) { 39 | dataSource := os.Getenv("POSTGRESQL_DATASOURCE") 40 | if dataSource == "" { 41 | t.Log("skipping due to blank POSTGRESQL_DATASOURCE") 42 | t.Skip() 43 | return 44 | } 45 | 46 | testSqalxConnect(t, "pgx", dataSource) 47 | testSqalxConnect(t, "pgx", dataSource, sqalx.SavePoint(true)) 48 | } 49 | 50 | func TestSqalxConnectSqlite(t *testing.T) { 51 | dataSource := 
os.Getenv("SQLITE_DATASOURCE") 52 | if dataSource == "" { 53 | t.Skip() 54 | return 55 | } 56 | 57 | testSqalxConnect(t, "sqlite3", dataSource) 58 | testSqalxConnect(t, "sqlite3", dataSource, sqalx.SavePoint(true)) 59 | } 60 | 61 | func TestSqalxConnectMySQL(t *testing.T) { 62 | dataSource := os.Getenv("MYSQL_DATASOURCE") 63 | if dataSource == "" { 64 | t.Log("skipping due to blank MYSQL_DATASOURCE") 65 | t.Skip() 66 | return 67 | } 68 | 69 | testSqalxConnect(t, "mysql", dataSource) 70 | testSqalxConnect(t, "mysql", dataSource, sqalx.SavePoint(true)) 71 | } 72 | 73 | func testSqalxConnect(t *testing.T, driverName, dataSource string, options ...sqalx.Option) { 74 | node, err := sqalx.Connect(driverName, dataSource, options...) 75 | require.NoError(t, err) 76 | 77 | err = node.Close() 78 | require.NoError(t, err) 79 | } 80 | 81 | func TestSqalxTransactionViolations(t *testing.T) { 82 | node, err := sqalx.New(nil) 83 | require.NoError(t, err) 84 | 85 | require.Panics(t, func() { 86 | //nolint:errcheck // the intended panic makes error checking irrelevant 87 | node.Exec("UPDATE products SET views = views + 1") 88 | }) 89 | 90 | require.Panics(t, func() { 91 | //nolint:errcheck // the intended panic makes error checking irrelevant 92 | node.Beginx() 93 | }) 94 | 95 | // calling Rollback after a transaction is closed does nothing 96 | err = node.Rollback() 97 | require.NoError(t, err) 98 | 99 | err = node.Commit() 100 | require.Equal(t, err, sqalx.ErrNotInTransaction) 101 | } 102 | 103 | func TestSqalxSimpleQuery(t *testing.T) { 104 | db, mock, cleanup := prepareDB(t, "mock") 105 | defer cleanup() 106 | 107 | mock.ExpectExec("UPDATE products").WillReturnResult(sqlmock.NewResult(1, 1)) 108 | 109 | node, err := sqalx.New(db) 110 | require.NoError(t, err) 111 | 112 | _, err = node.Exec("UPDATE products SET views = views + 1") 113 | require.NoError(t, err) 114 | } 115 | 116 | func TestSqalxTopLevelTransaction(t *testing.T) { 117 | db, mock, cleanup := prepareDB(t, "mock") 
118 | defer cleanup() 119 | var err error 120 | 121 | mock.ExpectBegin() 122 | mock.ExpectExec("UPDATE products").WillReturnResult(sqlmock.NewResult(1, 1)) 123 | mock.ExpectCommit() 124 | 125 | node, err := sqalx.New(db) 126 | require.NoError(t, err) 127 | 128 | node, err = node.Beginx() 129 | require.NoError(t, err) 130 | require.NotNil(t, node) 131 | defer func() { 132 | err = node.Rollback() 133 | require.NoError(t, err) 134 | }() 135 | 136 | _, err = node.Exec("UPDATE products SET views = views + 1") 137 | require.NoError(t, err) 138 | 139 | err = node.Commit() 140 | require.NoError(t, err) 141 | } 142 | 143 | func TestSqalxNestedTransactions(t *testing.T) { 144 | testSqalxNestedTransactions(t, "mock", false) 145 | } 146 | 147 | func TestSqalxNestedTransactionsWithSavePoint(t *testing.T) { 148 | for _, driver := range []string{ 149 | "postgres", 150 | "pgx", 151 | "sqlite3", 152 | "mysql", 153 | } { 154 | t.Run(driver, func(t *testing.T) { 155 | testSqalxNestedTransactions(t, driver, true) 156 | }) 157 | } 158 | } 159 | 160 | func testSqalxNestedTransactions(t *testing.T, driverName string, testSavePoint bool) { 161 | db, mock, cleanup := prepareDB(t, driverName) 162 | defer cleanup() 163 | 164 | require.Equal(t, driverName, db.DriverName()) 165 | 166 | var err error 167 | const query = "UPDATE products SET views = views + 1" 168 | 169 | mock.ExpectExec("UPDATE products").WillReturnResult(sqlmock.NewResult(1, 1)) 170 | mock.ExpectBegin() 171 | mock.ExpectExec("UPDATE products").WillReturnResult(sqlmock.NewResult(1, 1)) 172 | if testSavePoint { 173 | mock.ExpectExec("SAVEPOINT").WillReturnResult(sqlmock.NewResult(1, 1)) 174 | } 175 | mock.ExpectExec("UPDATE products").WillReturnResult(sqlmock.NewResult(1, 1)) 176 | if testSavePoint { 177 | mock.ExpectExec("ROLLBACK TO SAVEPOINT").WillReturnResult(sqlmock.NewResult(1, 1)) 178 | } 179 | if testSavePoint { 180 | mock.ExpectExec("SAVEPOINT").WillReturnResult(sqlmock.NewResult(1, 1)) 181 | } 182 | 
mock.ExpectExec("UPDATE products").WillReturnResult(sqlmock.NewResult(1, 1)) 183 | if testSavePoint { 184 | mock.ExpectExec("RELEASE SAVEPOINT").WillReturnResult(sqlmock.NewResult(1, 1)) 185 | } 186 | mock.ExpectCommit() 187 | 188 | node, err := sqalx.New(db, sqalx.SavePoint(testSavePoint)) 189 | require.NoError(t, err) 190 | 191 | _, err = node.Exec(query) 192 | require.NoError(t, err) 193 | 194 | n1, err := node.Beginx() 195 | require.NoError(t, err) 196 | require.NotNil(t, n1) 197 | 198 | _, err = n1.Exec(query) 199 | require.NoError(t, err) 200 | 201 | n1_1, err := n1.Beginx() 202 | require.NoError(t, err) 203 | require.NotNil(t, n1_1) 204 | 205 | _, err = n1_1.Exec(query) 206 | require.NoError(t, err) 207 | 208 | err = n1_1.Rollback() 209 | require.NoError(t, err) 210 | 211 | err = n1_1.Commit() 212 | require.Equal(t, sqalx.ErrNotInTransaction, err) 213 | 214 | n1_1, err = n1.BeginTxx(context.Background(), nil) 215 | require.NoError(t, err) 216 | require.NotNil(t, n1_1) 217 | 218 | _, err = n1_1.Exec(query) 219 | require.NoError(t, err) 220 | 221 | err = n1_1.Commit() 222 | require.NoError(t, err) 223 | 224 | err = n1_1.Commit() 225 | require.Equal(t, sqalx.ErrNotInTransaction, err) 226 | 227 | err = n1_1.Rollback() 228 | require.NoError(t, err) 229 | 230 | err = n1.Commit() 231 | require.NoError(t, err) 232 | } 233 | 234 | func TestSqalxFromTransaction(t *testing.T) { 235 | db, mock, cleanup := prepareDB(t, "mock") 236 | defer cleanup() 237 | 238 | mock.ExpectBegin() 239 | mock.ExpectExec("UPDATE products").WillReturnResult(sqlmock.NewResult(1, 1)) 240 | mock.ExpectExec("UPDATE products").WillReturnResult(sqlmock.NewResult(1, 1)) 241 | mock.ExpectRollback() 242 | 243 | tx, err := db.Beginx() 244 | require.NoError(t, err) 245 | 246 | node, err := sqalx.NewFromTransaction(tx) 247 | require.NoError(t, err) 248 | 249 | _, err = node.Exec("UPDATE products SET views = views + 1") 250 | require.NoError(t, err) 251 | 252 | ntx, err := node.Beginx() 253 | 
require.NoError(t, err) 254 | _, err = ntx.Exec("UPDATE products SET views = views + 1") 255 | require.NoError(t, err) 256 | 257 | err = ntx.Rollback() 258 | require.NoError(t, err) 259 | 260 | err = node.Rollback() 261 | require.NoError(t, err) 262 | } 263 | --------------------------------------------------------------------------------