├── .github └── workflows │ └── test.yml ├── .gitignore ├── LICENSE ├── Makefile ├── README.md ├── doc.go ├── example └── example.go ├── go.mod ├── go.sum ├── gormtx ├── gormtx.go └── gormtx_test.go ├── option.go ├── pgxtxv5 ├── option.go ├── pgx.go └── pgx_test.go ├── sqltx ├── sql.go └── sql_test.go ├── testtx ├── testtx.go └── testtx_test.go ├── testutil ├── db.go ├── db │ └── schema.sql ├── int.go └── mocks │ ├── db.go │ ├── option.go │ ├── transaction.go │ └── transactor.go ├── tx.go └── tx_test.go /.github/workflows/test.yml: -------------------------------------------------------------------------------- 1 | name: Go Test 2 | 3 | on: 4 | workflow_dispatch: 5 | push: 6 | branches: [trunk] 7 | pull_request: 8 | branches: [trunk] 9 | 10 | jobs: 11 | test: 12 | runs-on: ubuntu-latest 13 | steps: 14 | - uses: actions/checkout@v4 15 | 16 | - name: Set up Go 17 | uses: actions/setup-go@v4 18 | with: 19 | go-version: 1.22 20 | 21 | - name: Test 22 | run: make test 23 | 24 | - name: Install goveralls 25 | run: go install github.com/mattn/goveralls@latest 26 | 27 | - name: Coveralls 28 | env: 29 | COVERALLS_TOKEN: ${{ secrets.COVERALLS_REPO_TOKEN }} 30 | run: goveralls -coverprofile=profile.cov -service=github -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .idea/ 2 | .DS_Store 3 | *.out 4 | *.cov -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Anes Hasicic 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or 
sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | test: 2 | @go test -tags="integration" -race -coverprofile=profile.cov -v $(shell go list ./... 
| grep -vE 'example|mocks|testdata|testutil') 3 | @go tool cover -func=profile.cov | grep total 4 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # tx 2 | [![Go Test](https://github.com/aneshas/tx/actions/workflows/test.yml/badge.svg)](https://github.com/aneshas/tx/actions/workflows/test.yml) 3 | [![Go Report Card](https://goreportcard.com/badge/github.com/aneshas/tx)](https://goreportcard.com/report/github.com/aneshas/tx) 4 | [![Coverage Status](https://coveralls.io/repos/github/aneshas/tx/badge.svg)](https://coveralls.io/github/aneshas/tx) 5 | [![Go Reference](https://pkg.go.dev/badge/github.com/aneshas/tx.svg)](https://pkg.go.dev/github.com/aneshas/tx) 6 | 7 | `go get github.com/aneshas/tx/v2@latest` 8 | 9 | Package tx provides a simple abstraction which leverages `context.Context` in order to provide a transactional behavior 10 | which one could use in their use case orchestrator (eg. application service, command handler, etc...). You might think of it 11 | as the closest thing in `Go` to `@Transactional` annotation in Java or the way you could scope a transaction in `C#`. 12 | 13 | Many people tend to implement this pattern in one way or another (I have seen it and did it quite a few times), and 14 | more often than not, the implementations still tend to couple your use case orchestrator with your database adapters (eg. repositories) or 15 | on the other hand, rely too heavily on `context.Context` and end up using it as a dependency injection mechanism. 16 | 17 | This package relies on `context.Context` in order to simply pass the database transaction down the stack in a safe and clean way which 18 | still does not violate the reasoning behind context package - which is to carry `request scoped` data across api boundaries - which is 19 | a database transaction in this case. 
20 | 21 | ## Drivers 22 | Library currently supports `pgx`, `gorm` and stdlib `sql` out of the box although it is very easy to implement any additional ones 23 | you might need. 24 | 25 | ## Example 26 | Let's assume we have the following very common setup of an example account service which has a dependency to account repository. 27 | 28 | ```go 29 | type Repo interface { 30 | Save(ctx context.Context, account Account) error 31 | Find(ctx context.Context, id int) (*Account, error) 32 | } 33 | 34 | func NewAccountService(transactor tx.Transactor, repo Repo) *AccountService { 35 | return &AccountService{ 36 | Transactor: transactor, 37 | repo: repo, 38 | } 39 | } 40 | 41 | type AccountService struct { 42 | // Embedding Transactor interface in order to decorate the service with transactional behavior, 43 | // although you can choose how and when you use it freely 44 | tx.Transactor 45 | 46 | repo Repo 47 | } 48 | 49 | type ProvisionAccountReq struct { 50 | // ... 51 | } 52 | 53 | func (s *AccountService) ProvisionAccount(ctx context.Context, r ProvisionAccountReq) error { 54 | return s.WithTransaction(ctx, func (ctx context.Context) error { 55 | // ctx contains an embedded transaction and as long as 56 | // we pass it to our repo methods, they will be able to unwrap it and use it 57 | 58 | // eg. multiple calls to the same or different repos 59 | 60 | return s.repo.Save(ctx, Account{ 61 | // ... 62 | }) 63 | }) 64 | } 65 | ``` 66 | 67 | You will notice that the service looks mostly the same as it would normally apart from embedding `Transactor` interface 68 | and wrapping the use case execution using `WithTransaction`, both of which say nothing of the way the mechanism is implemented (no infrastructure dependencies). 69 | 70 | If the function wrapped via `WithTransaction` errors out or panics the transaction itself will be rolled back and if nil error is 71 | returned the transaction will be committed. 
(this behavior can be changed by providing `WithIgnoredErrors(...)` option to `tx.New`) 72 | 73 | ### Repo implementation 74 | Then, your repo might use postgres with pgx and have the following example implementation: 75 | 76 | ```go 77 | func NewAccountRepo(pool *pgxpool.Pool) *AccountRepo { 78 | return &AccountRepo{ 79 | pool: pool, 80 | } 81 | } 82 | 83 | type AccountRepo struct { 84 | pool *pgxpool.Pool 85 | } 86 | 87 | func (r *AccountRepo) Save(ctx context.Context, account Account) error { 88 | _, err := r.conn(ctx).Exec(ctx, "...") 89 | 90 | return err 91 | } 92 | 93 | func (r *AccountRepo) Find(ctx context.Context, id int) (*Account, error) { 94 | rows, err := r.conn(ctx).Query(ctx, "...") 95 | if err != nil { 96 | return nil, err 97 | } 98 | 99 | _ = rows 100 | 101 | return nil, nil 102 | } 103 | 104 | type Conn interface { 105 | Exec(ctx context.Context, sql string, arguments ...any) (commandTag pgconn.CommandTag, err error) 106 | Query(ctx context.Context, sql string, args ...any) (pgx.Rows, error) 107 | } 108 | 109 | func (r *AccountRepo) conn(ctx context.Context) Conn { 110 | if tx, ok := pgxtxv5.From(ctx); ok { 111 | return tx 112 | } 113 | 114 | return r.pool 115 | } 116 | ``` 117 | 118 | Again, you may freely choose how you implement this and whether or not you actually use the wrapped 119 | transaction. 120 | 121 | ### main 122 | Then your main function would simply tie everything together like this for example: 123 | 124 | ```go 125 | func main() { 126 | var pool *pgxpool.Pool 127 | 128 | svc := NewAccountService( 129 | tx.New(pgxtxv5.NewDBFromPool(pool)), 130 | NewAccountRepo(pool), 131 | ) 132 | 133 | _ = svc 134 | } 135 | ``` 136 | 137 | This way, your infrastructural concerns stay in the infrastructure layer where they really belong. 138 | 139 | *Please note that this is only one way of using the abstraction* 140 | 141 | ## Testing 142 | You can use `testtx.New()` as a convenient test helper. 
It creates a test transactor which only calls f without 143 | setting any sort of transaction in the `ctx` and preserves any errors raised by f in `.Err` field. 144 | -------------------------------------------------------------------------------- /doc.go: -------------------------------------------------------------------------------- 1 | // Package tx provides a simple transaction abstraction in order to enable decoupling / abstraction of persistence from 2 | // application / domain logic while still leaving transaction control to the application service / use case coordinator 3 | // (Something like @Transactional annotation in Java, without the annotation) 4 | package tx 5 | 6 | //go:generate mockery --all --with-expecter --case=underscore --output ./testutil/mocks 7 | -------------------------------------------------------------------------------- /example/example.go: -------------------------------------------------------------------------------- 1 | package example 2 | 3 | import ( 4 | "context" 5 | "github.com/aneshas/tx/v2" 6 | "github.com/aneshas/tx/v2/pgxtxv5" 7 | "github.com/jackc/pgx/v5" 8 | "github.com/jackc/pgx/v5/pgconn" 9 | "github.com/jackc/pgx/v5/pgxpool" 10 | ) 11 | 12 | func main() { 13 | var pool *pgxpool.Pool 14 | 15 | svc := NewAccountService( 16 | tx.New(pgxtxv5.NewDBFromPool(pool)), 17 | NewAccountRepo(pool), 18 | ) 19 | 20 | _ = svc 21 | } 22 | 23 | type Account struct { 24 | // ... 25 | } 26 | 27 | type Repo interface { 28 | Save(ctx context.Context, account Account) error 29 | Find(ctx context.Context, id int) (*Account, error) 30 | } 31 | 32 | func NewAccountService(transactor tx.Transactor, repo Repo) *AccountService { 33 | return &AccountService{Transactor: transactor, repo: repo} 34 | } 35 | 36 | type AccountService struct { 37 | // Embedding transactional behavior in your service 38 | tx.Transactor 39 | 40 | repo Repo 41 | } 42 | 43 | type ProvisionAccountReq struct { 44 | // ... 
45 | } 46 | 47 | func (s *AccountService) ProvisionAccount(ctx context.Context, r ProvisionAccountReq) error { 48 | return s.WithTransaction(ctx, func(ctx context.Context) error { 49 | // ctx contains an embedded transaction and as long as 50 | // we pass it to our repo methods, they will be able to unwrap it and use it 51 | 52 | // eg. multiple calls to different repos 53 | 54 | return s.repo.Save(ctx, Account{ 55 | // ... 56 | }) 57 | }) 58 | } 59 | 60 | func NewAccountRepo(pool *pgxpool.Pool) *AccountRepo { 61 | return &AccountRepo{ 62 | pool: pool, 63 | } 64 | } 65 | 66 | type AccountRepo struct { 67 | pool *pgxpool.Pool 68 | } 69 | 70 | func (r *AccountRepo) Save(ctx context.Context, account Account) error { 71 | _, err := r.conn(ctx).Exec(ctx, "...") 72 | 73 | return err 74 | } 75 | 76 | func (r *AccountRepo) Find(ctx context.Context, id int) (*Account, error) { 77 | rows, err := r.conn(ctx).Query(ctx, "...") 78 | if err != nil { 79 | return nil, err 80 | } 81 | 82 | _ = rows 83 | 84 | return nil, nil 85 | } 86 | 87 | type Conn interface { 88 | Exec(ctx context.Context, sql string, arguments ...any) (commandTag pgconn.CommandTag, err error) 89 | Query(ctx context.Context, sql string, args ...any) (pgx.Rows, error) 90 | } 91 | 92 | func (r *AccountRepo) conn(ctx context.Context) Conn { 93 | if tx, ok := pgxtxv5.From(ctx); ok { 94 | return tx 95 | } 96 | 97 | return r.pool 98 | } 99 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/aneshas/tx/v2 2 | 3 | go 1.21.0 4 | 5 | require ( 6 | github.com/jackc/pgx/v5 v5.5.5 7 | github.com/orlangure/gnomock v0.30.0 8 | github.com/stretchr/testify v1.8.4 9 | ) 10 | 11 | require ( 12 | github.com/Azure/go-ansiterm v0.0.0-20230124172434-306776ec8161 // indirect 13 | github.com/Microsoft/go-winio v0.5.2 // indirect 14 | github.com/davecgh/go-spew v1.1.1 // indirect 15 | 
github.com/docker/distribution v2.8.2+incompatible // indirect 16 | github.com/docker/docker v24.0.5+incompatible // indirect 17 | github.com/docker/go-connections v0.4.0 // indirect 18 | github.com/docker/go-units v0.4.0 // indirect 19 | github.com/gogo/protobuf v1.3.2 // indirect 20 | github.com/google/uuid v1.3.1 // indirect 21 | github.com/jackc/pgpassfile v1.0.0 // indirect 22 | github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect 23 | github.com/jackc/puddle/v2 v2.2.1 // indirect 24 | github.com/jinzhu/inflection v1.0.0 // indirect 25 | github.com/jinzhu/now v1.1.5 // indirect 26 | github.com/lib/pq v1.10.9 // indirect 27 | github.com/opencontainers/go-digest v1.0.0 // indirect 28 | github.com/opencontainers/image-spec v1.0.2 // indirect 29 | github.com/pkg/errors v0.9.1 // indirect 30 | github.com/pmezard/go-difflib v1.0.0 // indirect 31 | github.com/rogpeppe/go-internal v1.6.1 // indirect 32 | github.com/stretchr/objx v0.5.0 // indirect 33 | go.uber.org/multierr v1.10.0 // indirect 34 | go.uber.org/zap v1.25.0 // indirect 35 | golang.org/x/crypto v0.21.0 // indirect 36 | golang.org/x/net v0.21.0 // indirect 37 | golang.org/x/sync v0.10.0 // indirect 38 | golang.org/x/sys v0.18.0 // indirect 39 | golang.org/x/text v0.21.0 // indirect 40 | gopkg.in/yaml.v3 v3.0.1 // indirect 41 | gorm.io/driver/postgres v1.5.11 // indirect 42 | gorm.io/gorm v1.25.12 // indirect 43 | ) 44 | -------------------------------------------------------------------------------- /go.sum: -------------------------------------------------------------------------------- 1 | github.com/Azure/go-ansiterm v0.0.0-20230124172434-306776ec8161 h1:L/gRVlceqvL25UVaW/CKtUDjefjrs0SPonmDGUVOYP0= 2 | github.com/Azure/go-ansiterm v0.0.0-20230124172434-306776ec8161/go.mod h1:xomTg63KZ2rFqZQzSB4Vz2SUXa1BpHTVz9L5PTmPC4E= 3 | github.com/Microsoft/go-winio v0.5.2 h1:a9IhgEQBCUEk6QCdml9CiJGhAws+YwffDHEMp1VMrpA= 4 | github.com/Microsoft/go-winio v0.5.2/go.mod 
h1:WpS1mjBmmwHBEWmogvA2mj8546UReBk4v8QkMxJ6pZY= 5 | github.com/benbjohnson/clock v1.3.0 h1:ip6w0uFQkncKQ979AypyG0ER7mqUSBdKLOgAle/AT8A= 6 | github.com/benbjohnson/clock v1.3.0/go.mod h1:J11/hYXuz8f4ySSvYwY0FKfm+ezbsZBKZxNJlLklBHA= 7 | github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 8 | github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= 9 | github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 10 | github.com/docker/distribution v2.8.2+incompatible h1:T3de5rq0dB1j30rp0sA2rER+m322EBzniBPB6ZIzuh8= 11 | github.com/docker/distribution v2.8.2+incompatible/go.mod h1:J2gT2udsDAN96Uj4KfcMRqY0/ypR+oyYUYmja8H+y+w= 12 | github.com/docker/docker v24.0.5+incompatible h1:WmgcE4fxyI6EEXxBRxsHnZXrO1pQ3smi0k/jho4HLeY= 13 | github.com/docker/docker v24.0.5+incompatible/go.mod h1:eEKB0N0r5NX/I1kEveEz05bcu8tLC/8azJZsviup8Sk= 14 | github.com/docker/go-connections v0.4.0 h1:El9xVISelRB7BuFusrZozjnkIM5YnzCViNKohAFqRJQ= 15 | github.com/docker/go-connections v0.4.0/go.mod h1:Gbd7IOopHjR8Iph03tsViu4nIes5XhDvyHbTtUxmeec= 16 | github.com/docker/go-units v0.4.0 h1:3uh0PgVws3nIA0Q+MwDC8yjEPf9zjRfZZWXZYDct3Tw= 17 | github.com/docker/go-units v0.4.0/go.mod h1:fgPhTUdO+D/Jk86RDLlptpiXQzgHJF7gydDDbaIK4Dk= 18 | github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q= 19 | github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q= 20 | github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= 21 | github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= 22 | github.com/google/uuid v1.3.1 h1:KjJaJ9iWZ3jOFZIf1Lqf4laDRCasjl0BCmnEGxkdLb4= 23 | github.com/google/uuid v1.3.1/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= 24 | github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= 25 | github.com/jackc/pgpassfile v1.0.0/go.mod 
h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= 26 | github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a h1:bbPeKD0xmW/Y25WS6cokEszi5g+S0QxI/d45PkRi7Nk= 27 | github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM= 28 | github.com/jackc/pgx/v5 v5.5.5 h1:amBjrZVmksIdNjxGW/IiIMzxMKZFelXbUoPNb+8sjQw= 29 | github.com/jackc/pgx/v5 v5.5.5/go.mod h1:ez9gk+OAat140fv9ErkZDYFWmXLfV+++K0uAOiwgm1A= 30 | github.com/jackc/puddle/v2 v2.2.1 h1:RhxXJtFG022u4ibrCSMSiu5aOq1i77R3OHKNJj77OAk= 31 | github.com/jackc/puddle/v2 v2.2.1/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4= 32 | github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E= 33 | github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= 34 | github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ= 35 | github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= 36 | github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8= 37 | github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= 38 | github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= 39 | github.com/kr/pretty v0.3.0 h1:WgNl7dwNpEZ6jJ9k1snq4pZsg7DOEN8hP9Xw0Tsjwk0= 40 | github.com/kr/pretty v0.3.0/go.mod h1:640gp4NfQd8pI5XOwp5fnNeVWj67G7CFk/SaSQn7NBk= 41 | github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= 42 | github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= 43 | github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= 44 | github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= 45 | github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw= 46 | github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= 47 | github.com/moby/term 
v0.0.0-20210619224110-3f7ff695adc6 h1:dcztxKSvZ4Id8iPpHERQBbIJfabdt4wUm5qy3wOL2Zc= 48 | github.com/moby/term v0.0.0-20210619224110-3f7ff695adc6/go.mod h1:E2VnQOmVuvZB6UYnnDB0qG5Nq/1tD9acaOpo6xmt0Kw= 49 | github.com/morikuni/aec v1.0.0 h1:nP9CBfwrvYnBRgY6qfDQkygYDmYwOilePFkwzv4dU8A= 50 | github.com/morikuni/aec v1.0.0/go.mod h1:BbKIizmSmc5MMPqRYbxO4ZU0S0+P200+tUnFx7PXmsc= 51 | github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8Oi/yOhh5U= 52 | github.com/opencontainers/go-digest v1.0.0/go.mod h1:0JzlMkj0TRzQZfJkVvzbP0HBR3IKzErnv2BNG4W4MAM= 53 | github.com/opencontainers/image-spec v1.0.2 h1:9yCKha/T5XdGtO0q9Q9a6T5NUCsTn/DrBg0D7ufOcFM= 54 | github.com/opencontainers/image-spec v1.0.2/go.mod h1:BtxoFyWECRxE4U/7sNtV5W15zMzWCbyJoFRP3s7yZA0= 55 | github.com/orlangure/gnomock v0.30.0 h1:WXq/3KTKRVYe9a3BXa5JMZCCrg2RwNAPB2bZHMxEntE= 56 | github.com/orlangure/gnomock v0.30.0/go.mod h1:vDur9icFVsecjDQrHn06SbUs0BXjJaNJRDexBsPh5f4= 57 | github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= 58 | github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= 59 | github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= 60 | github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= 61 | github.com/rogpeppe/go-internal v1.6.1 h1:/FiVV8dS/e+YqF2JvO3yXRFbBLTIuSDkuC7aBOAvL+k= 62 | github.com/rogpeppe/go-internal v1.6.1/go.mod h1:xXDCJY+GAPziupqXw64V24skbSoqbTEfhy4qGm1nDQc= 63 | github.com/sirupsen/logrus v1.7.0/go.mod h1:yWOB1SBYBC5VeMP7gHvWumXLIWorT60ONWic61uBYv0= 64 | github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= 65 | github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= 66 | github.com/stretchr/objx v0.5.0 h1:1zr/of2m5FGMsad5YfcqgdqdWrIhu+EBEJRhR1U7z/c= 67 | github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= 68 | github.com/stretchr/testify 
v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= 69 | github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= 70 | github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= 71 | github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= 72 | github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= 73 | github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= 74 | github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= 75 | github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= 76 | github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= 77 | go.uber.org/goleak v1.2.0 h1:xqgm/S+aQvhWFTtR0XK3Jvg7z8kGV8P4X14IzwN3Eqk= 78 | go.uber.org/goleak v1.2.0/go.mod h1:XJYK+MuIchqpmGmUSAzotztawfKvYLUIgg7guXrwVUo= 79 | go.uber.org/multierr v1.10.0 h1:S0h4aNzvfcFsC3dRF1jLoaov7oRaKqRGC/pUEJ2yvPQ= 80 | go.uber.org/multierr v1.10.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y= 81 | go.uber.org/zap v1.25.0 h1:4Hvk6GtkucQ790dqmj7l1eEnRdKm3k3ZUrUMS2d5+5c= 82 | go.uber.org/zap v1.25.0/go.mod h1:JIAUzQIH94IC4fOJQm7gMmBJP5k7wQfdcnYdPoEXJYk= 83 | golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= 84 | golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= 85 | golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= 86 | golang.org/x/crypto v0.21.0 h1:X31++rzVUdKhX5sWmSOFZxx8UW/ldWx55cbf08iNAMA= 87 | golang.org/x/crypto v0.21.0/go.mod h1:0BP7YvVV9gBbVKyeTG0Gyn+gZm94bibOW5BjDEYAOMs= 88 | golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= 89 | golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= 90 | 
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= 91 | golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= 92 | golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= 93 | golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= 94 | golang.org/x/net v0.21.0 h1:AQyQV4dYCvJ7vGmJyKki9+PBdyvhkSd8EIx/qb0AYv4= 95 | golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44= 96 | golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= 97 | golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= 98 | golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= 99 | golang.org/x/sync v0.3.0 h1:ftCYgMx6zT/asHUrPw8BLLscYtGznsLAnjq5RH9P66E= 100 | golang.org/x/sync v0.3.0/go.mod h1:FU7BRWz2tNW+3quACPkgCx/L+uEAv1htQ0V83Z9Rj+Y= 101 | golang.org/x/sync v0.10.0 h1:3NQrjDixjgGwUOCaF8w2+VYHv0Ve/vGYSbdkTa98gmQ= 102 | golang.org/x/sync v0.10.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= 103 | golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= 104 | golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= 105 | golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= 106 | golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= 107 | golang.org/x/sys v0.0.0-20210124154548-22da62e12c0c/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= 108 | golang.org/x/sys v0.0.0-20210616094352-59db8d763f22/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= 109 | golang.org/x/sys v0.18.0 
h1:DBdB3niSjOA/O0blCZBqDefyWNYveAYMNF1Wum0DYQ4= 110 | golang.org/x/sys v0.18.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= 111 | golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= 112 | golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= 113 | golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ= 114 | golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= 115 | golang.org/x/text v0.21.0 h1:zyQAAkrwaneQ066sspRyJaG9VNi/YJ1NfzcGB3hZ/qo= 116 | golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ= 117 | golang.org/x/time v0.3.0 h1:rg5rLMjNzMS1RkNLzCG38eapWhnYLFYXDXj2gOlr8j4= 118 | golang.org/x/time v0.3.0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= 119 | golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= 120 | golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= 121 | golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE= 122 | golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= 123 | golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= 124 | golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= 125 | golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= 126 | golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= 127 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= 128 | gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= 129 | gopkg.in/check.v1 
v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= 130 | gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= 131 | gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI= 132 | gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= 133 | gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= 134 | gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= 135 | gorm.io/driver/postgres v1.5.11 h1:ubBVAfbKEUld/twyKZ0IYn9rSQh448EdelLYk9Mv314= 136 | gorm.io/driver/postgres v1.5.11/go.mod h1:DX3GReXH+3FPWGrrgffdvCk3DQ1dwDPdmbenSkweRGI= 137 | gorm.io/gorm v1.25.12 h1:I0u8i2hWQItBq1WfE0o2+WuL9+8L21K9e2HHSTE/0f8= 138 | gorm.io/gorm v1.25.12/go.mod h1:xh7N7RHfYlNc5EmcI/El95gXusucDrQnHXe0+CgWcLQ= 139 | gotest.tools/v3 v3.0.3 h1:4AuOwCGf4lLR9u3YOe2awrHygurzhO/HeQ6laiA6Sx0= 140 | gotest.tools/v3 v3.0.3/go.mod h1:Z7Lb0S5l+klDB31fvDQX8ss/FlKDxtlFlw3Oa8Ymbl8= 141 | -------------------------------------------------------------------------------- /gormtx/gormtx.go: -------------------------------------------------------------------------------- 1 | package gormtx 2 | 3 | import ( 4 | "context" 5 | "github.com/aneshas/tx/v2" 6 | "gorm.io/gorm" 7 | ) 8 | 9 | var ( 10 | _ tx.DB = &DB{} 11 | _ tx.Transaction = &Tx{} 12 | ) 13 | 14 | // NewDB instantiates new tx.DB *gorm.DB wrapper 15 | func NewDB(db *gorm.DB) tx.DB { 16 | return &DB{DB: db} 17 | } 18 | 19 | // DB implements tx.DB 20 | type DB struct { 21 | *gorm.DB 22 | } 23 | 24 | // Begin begins gorm transaction 25 | func (db *DB) Begin(ctx context.Context) (tx.Transaction, error) { 26 | txx := db.WithContext(ctx).Begin() 27 | if txx.Error != nil { 28 | return nil, txx.Error 29 | } 30 | 31 | return &Tx{txx}, nil 32 | } 33 | 34 | // Tx wraps *gorm.DB in order top implement tx.Transaction 35 | type Tx struct { 36 | *gorm.DB 37 | } 38 | 39 
| // Commit commits the transaction 40 | func (t Tx) Commit(_ context.Context) error { 41 | return t.DB.Commit().Error 42 | } 43 | 44 | // Rollback rolls back the transaction 45 | func (t Tx) Rollback(_ context.Context) error { 46 | return t.DB.Rollback().Error 47 | } 48 | 49 | // From returns underlying *gorm.DB (wrapped in *Tx) 50 | func From(ctx context.Context) (*Tx, bool) { 51 | return tx.From[*Tx](ctx) 52 | } 53 | -------------------------------------------------------------------------------- /gormtx/gormtx_test.go: -------------------------------------------------------------------------------- 1 | //go:build integration 2 | // +build integration 3 | 4 | package gormtx_test 5 | 6 | import ( 7 | "context" 8 | "fmt" 9 | "github.com/aneshas/tx/v2" 10 | "github.com/aneshas/tx/v2/gormtx" 11 | "github.com/aneshas/tx/v2/testutil" 12 | "github.com/jackc/pgx/v5/pgxpool" 13 | "github.com/stretchr/testify/assert" 14 | "gorm.io/driver/postgres" 15 | "gorm.io/gorm" 16 | "testing" 17 | ) 18 | 19 | var ( 20 | pool *pgxpool.Pool 21 | db *gorm.DB 22 | ) 23 | 24 | func TestMain(m *testing.M) { 25 | t := new(testing.T) 26 | 27 | p, sqlDB := testutil.SetupDB(t) 28 | 29 | gormDB, err := gorm.Open(postgres.New(postgres.Config{ 30 | Conn: sqlDB, 31 | }), &gorm.Config{}) 32 | 33 | assert.NoError(t, err) 34 | 35 | pool = p 36 | db = gormDB 37 | 38 | m.Run() 39 | } 40 | 41 | func TestShould_Commit_Sql_Transaction(t *testing.T) { 42 | name := "success_sql" 43 | 44 | doSql(t, tx.New(gormtx.NewDB(db)), name, false) 45 | testutil.AssertSuccess(t, pool, name) 46 | } 47 | 48 | func TestShould_Rollback_Sql_Transaction(t *testing.T) { 49 | name := "failure_sql" 50 | 51 | doSql(t, tx.New(gormtx.NewDB(db)), name, true) 52 | testutil.AssertFailure(t, pool, name) 53 | } 54 | 55 | func doSql(t *testing.T, transactor *tx.TX, name string, fail bool) { 56 | t.Helper() 57 | 58 | err := transactor.WithTransaction(context.TODO(), func(ctx context.Context) error { 59 | ttx, _ := gormtx.From(ctx) 60 | 
61 | db := ttx.Exec(`insert into cats (name) values(?)`, name) 62 | if db.Error != nil { 63 | return db.Error 64 | } 65 | 66 | if fail { 67 | return fmt.Errorf("db error") 68 | } 69 | 70 | return db.Error 71 | }) 72 | 73 | if !fail { 74 | assert.NoError(t, err) 75 | } 76 | } 77 | -------------------------------------------------------------------------------- /option.go: -------------------------------------------------------------------------------- 1 | package tx 2 | 3 | type Option func(tx *TX) 4 | 5 | // WithIgnoredErrors offers a way to provide a list of errors which will 6 | // not cause the transaction to be rolled back. 7 | // 8 | // The transaction will still be committed but the actual error will be returned 9 | // by the WithTransaction method. 10 | func WithIgnoredErrors(errs ...error) Option { 11 | return func(tx *TX) { 12 | tx.ignoreErrs = append(tx.ignoreErrs, errs...) 13 | } 14 | } 15 | -------------------------------------------------------------------------------- /pgxtxv5/option.go: -------------------------------------------------------------------------------- 1 | package pgxtxv5 2 | 3 | import "github.com/jackc/pgx/v5" 4 | 5 | // PgxTxOption represents pgx driver transaction option 6 | type PgxTxOption func(pool *Pool) 7 | 8 | // WithTxOptions allows us to set transaction options (eg. 
isolation level) 9 | func WithTxOptions(txOptions pgx.TxOptions) PgxTxOption { 10 | return func(pool *Pool) { 11 | pool.txOpts = txOptions 12 | } 13 | } 14 | -------------------------------------------------------------------------------- /pgxtxv5/pgx.go: -------------------------------------------------------------------------------- 1 | package pgxtxv5 2 | 3 | import ( 4 | "context" 5 | "github.com/aneshas/tx/v2" 6 | "github.com/jackc/pgx/v5" 7 | "github.com/jackc/pgx/v5/pgxpool" 8 | ) 9 | 10 | var _ tx.DB = &Pool{} 11 | 12 | // NewDBFromPool instantiates new tx.DB *pgxpool.Pool wrapper 13 | func NewDBFromPool(pool *pgxpool.Pool, opts ...PgxTxOption) tx.DB { 14 | p := Pool{ 15 | Pool: pool, 16 | } 17 | 18 | for _, opt := range opts { 19 | opt(&p) 20 | } 21 | 22 | return &p 23 | } 24 | 25 | // Pool implements tx.DB 26 | type Pool struct { 27 | *pgxpool.Pool 28 | 29 | txOpts pgx.TxOptions 30 | } 31 | 32 | // Begin begins pgx transaction 33 | func (p *Pool) Begin(ctx context.Context) (tx.Transaction, error) { 34 | return p.Pool.BeginTx(ctx, p.txOpts) 35 | } 36 | 37 | // From returns underlying pgx.Tx from the context. 
38 | // If you need to obtain a different interface back see tx.From 39 | func From(ctx context.Context) (pgx.Tx, bool) { 40 | return tx.From[pgx.Tx](ctx) 41 | } 42 | -------------------------------------------------------------------------------- /pgxtxv5/pgx_test.go: -------------------------------------------------------------------------------- 1 | //go:build integration 2 | // +build integration 3 | 4 | package pgxtxv5_test 5 | 6 | import ( 7 | "context" 8 | "fmt" 9 | "github.com/aneshas/tx/v2" 10 | "github.com/aneshas/tx/v2/pgxtxv5" 11 | "github.com/aneshas/tx/v2/testutil" 12 | "github.com/jackc/pgx/v5" 13 | "github.com/jackc/pgx/v5/pgxpool" 14 | "github.com/stretchr/testify/assert" 15 | "testing" 16 | ) 17 | 18 | var db *pgxpool.Pool 19 | 20 | func TestMain(m *testing.M) { 21 | t := new(testing.T) 22 | 23 | db, _ = testutil.SetupDB(t) 24 | 25 | m.Run() 26 | } 27 | 28 | func TestShould_Commit_Pgx_Transaction(t *testing.T) { 29 | name := "success_pgx" 30 | 31 | doPgx( 32 | t, 33 | tx.New( 34 | pgxtxv5.NewDBFromPool( 35 | db, 36 | pgxtxv5.WithTxOptions(pgx.TxOptions{}), 37 | ), 38 | ), 39 | name, 40 | false, 41 | ) 42 | 43 | testutil.AssertSuccess(t, db, name) 44 | } 45 | 46 | func TestShould_Rollback_Pgx_Transaction(t *testing.T) { 47 | name := "failure_pgx" 48 | 49 | doPgx(t, tx.New(pgxtxv5.NewDBFromPool(db)), name, true) 50 | testutil.AssertFailure(t, db, name) 51 | } 52 | 53 | func doPgx(t *testing.T, transactor *tx.TX, name string, fail bool) { 54 | t.Helper() 55 | 56 | err := transactor.WithTransaction(context.TODO(), func(ctx context.Context) error { 57 | ttx, _ := pgxtxv5.From(ctx) 58 | 59 | _, err := ttx.Exec(ctx, `insert into cats (name) values($1)`, name) 60 | if err != nil { 61 | return err 62 | } 63 | 64 | if fail { 65 | return fmt.Errorf("db error") 66 | } 67 | 68 | return err 69 | }) 70 | 71 | if !fail { 72 | assert.NoError(t, err) 73 | } 74 | } 75 | -------------------------------------------------------------------------------- /sqltx/sql.go: 
-------------------------------------------------------------------------------- 1 | package sqltx 2 | 3 | import ( 4 | "context" 5 | "database/sql" 6 | "github.com/aneshas/tx/v2" 7 | ) 8 | 9 | var ( 10 | _ tx.DB = &DB{} 11 | _ tx.Transaction = &Tx{} 12 | ) 13 | 14 | // NewDB instantiates new tx.DB *sql.DB wrapper 15 | func NewDB(db *sql.DB) tx.DB { 16 | return &DB{db} 17 | } 18 | 19 | // DB implements tx.DB 20 | type DB struct { 21 | *sql.DB 22 | } 23 | 24 | // Begin begins sql transaction 25 | func (p *DB) Begin(_ context.Context) (tx.Transaction, error) { 26 | txx, err := p.DB.Begin() 27 | if err != nil { 28 | return nil, err 29 | } 30 | 31 | return &Tx{txx}, nil 32 | } 33 | 34 | // Tx wraps *sql.Tx in order to implement tx.Transaction 35 | type Tx struct { 36 | *sql.Tx 37 | } 38 | 39 | // Commit commits the transaction 40 | func (p *Tx) Commit(_ context.Context) error { 41 | return p.Tx.Commit() 42 | } 43 | 44 | // Rollback rolls back the transaction 45 | func (p *Tx) Rollback(_ context.Context) error { 46 | return p.Tx.Rollback() 47 | } 48 | 49 | // From returns underlying *sql.Tx (wrapped in *Tx) 50 | func From(ctx context.Context) (*Tx, bool) { 51 | return tx.From[*Tx](ctx) 52 | } 53 | -------------------------------------------------------------------------------- /sqltx/sql_test.go: -------------------------------------------------------------------------------- 1 | //go:build integration 2 | // +build integration 3 | 4 | package sqltx_test 5 | 6 | import ( 7 | "context" 8 | "database/sql" 9 | "fmt" 10 | "github.com/aneshas/tx/v2" 11 | "github.com/aneshas/tx/v2/sqltx" 12 | "github.com/aneshas/tx/v2/testutil" 13 | "github.com/jackc/pgx/v5/pgxpool" 14 | "github.com/stretchr/testify/assert" 15 | "testing" 16 | ) 17 | 18 | var ( 19 | pool *pgxpool.Pool 20 | db *sql.DB 21 | ) 22 | 23 | func TestMain(m *testing.M) { 24 | t := new(testing.T) 25 | 26 | pool, db = testutil.SetupDB(t) 27 | 28 | m.Run() 29 | } 30 | 31 | func TestShould_Commit_Sql_Transaction(t
*testing.T) { 32 | name := "success_sql" 33 | 34 | doSql(t, tx.New(sqltx.NewDB(db)), name, false) 35 | testutil.AssertSuccess(t, pool, name) 36 | } 37 | 38 | func TestShould_Rollback_Sql_Transaction(t *testing.T) { 39 | name := "failure_sql" 40 | 41 | doSql(t, tx.New(sqltx.NewDB(db)), name, true) 42 | testutil.AssertFailure(t, pool, name) 43 | } 44 | 45 | func doSql(t *testing.T, transactor *tx.TX, name string, fail bool) { 46 | t.Helper() 47 | 48 | err := transactor.WithTransaction(context.TODO(), func(ctx context.Context) error { 49 | ttx, _ := sqltx.From(ctx) 50 | 51 | _, err := ttx.Exec(`insert into cats (name) values($1)`, name) 52 | if err != nil { 53 | return err 54 | } 55 | 56 | if fail { 57 | return fmt.Errorf("db error") 58 | } 59 | 60 | return err 61 | }) 62 | 63 | if !fail { 64 | assert.NoError(t, err) 65 | } 66 | } 67 | -------------------------------------------------------------------------------- /testtx/testtx.go: -------------------------------------------------------------------------------- 1 | package testtx 2 | 3 | import "context" 4 | 5 | // New creates a new TX 6 | func New() *TX { 7 | return &TX{} 8 | } 9 | 10 | // TX is a noop test implementation of tx.Transactor 11 | type TX struct { 12 | Err error 13 | } 14 | 15 | // WithTransaction calls f directly, without starting any transaction, and stores the error it returns in Err 16 | func (t *TX) WithTransaction(ctx context.Context, f func(ctx context.Context) error) error { 17 | t.Err = f(ctx) 18 | 19 | return t.Err 20 | } 21 | -------------------------------------------------------------------------------- /testtx/testtx_test.go: -------------------------------------------------------------------------------- 1 | package testtx_test 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "github.com/aneshas/tx/v2" 7 | "github.com/aneshas/tx/v2/testtx" 8 | "github.com/stretchr/testify/assert" 9 | "testing" 10 | ) 11 | 12 | type svc struct { 13 | tx.Transactor 14 | 15 | DidSomething bool 16 | } 17 | 18 | func (s *svc) doSomething(ctx context.Context, err error)
error { 19 | return s.WithTransaction(ctx, func(ctx context.Context) error { 20 | s.DidSomething = true 21 | 22 | return err 23 | }) 24 | } 25 | 26 | func TestShould_Delegate_Call(t *testing.T) { 27 | transactor := testtx.New() 28 | 29 | s := &svc{ 30 | Transactor: transactor, 31 | } 32 | 33 | err := s.doSomething(context.TODO(), nil) 34 | 35 | assert.NoError(t, err) 36 | assert.True(t, s.DidSomething) 37 | } 38 | 39 | func TestShould_Save_Err(t *testing.T) { 40 | transactor := testtx.New() 41 | 42 | s := &svc{ 43 | Transactor: transactor, 44 | } 45 | 46 | wantErr := fmt.Errorf("something bad ocurred") 47 | 48 | err := s.doSomething(context.TODO(), wantErr) 49 | 50 | assert.ErrorIs(t, err, wantErr) 51 | assert.ErrorIs(t, transactor.Err, wantErr) 52 | } 53 | -------------------------------------------------------------------------------- /testutil/db.go: -------------------------------------------------------------------------------- 1 | package testutil 2 | 3 | import ( 4 | "github.com/aneshas/tx/v2/testutil/mocks" 5 | "github.com/stretchr/testify/mock" 6 | "testing" 7 | ) 8 | 9 | func NewDB(t *testing.T, opts ...Option) *DB { 10 | db := DB{ 11 | t: t, 12 | DB: mocks.NewDB(t), 13 | } 14 | 15 | for _, opt := range opts { 16 | opt(&db) 17 | } 18 | 19 | return &db 20 | } 21 | 22 | type Option func(db *DB) 23 | 24 | func WithUnsuccessfulTransactionStart(with error) Option { 25 | return func(db *DB) { 26 | db.EXPECT().Begin(mock.Anything).Return(nil, with).Once() 27 | } 28 | } 29 | 30 | func WithSuccessfulCommit() Option { 31 | return func(db *DB) { 32 | tx := mocks.NewTransaction(db.t) 33 | tx.EXPECT().Commit(mock.Anything).Return(nil).Once() 34 | db.EXPECT().Begin(mock.Anything).Return(tx, nil).Once() 35 | } 36 | } 37 | 38 | func WithSuccessfulRollback() Option { 39 | return func(db *DB) { 40 | tx := mocks.NewTransaction(db.t) 41 | tx.EXPECT().Rollback(mock.Anything).Return(nil).Once() 42 | db.EXPECT().Begin(mock.Anything).Return(tx, nil).Once() 43 | } 44 | } 45 | 46 
| func WithUnsuccessfulRollback(with error) Option { 47 | return func(db *DB) { 48 | tx := mocks.NewTransaction(db.t) 49 | tx.EXPECT().Rollback(mock.Anything).Return(with).Once() 50 | db.EXPECT().Begin(mock.Anything).Return(tx, nil).Once() 51 | } 52 | } 53 | 54 | type DB struct { 55 | t *testing.T 56 | *mocks.DB 57 | } 58 | -------------------------------------------------------------------------------- /testutil/db/schema.sql: -------------------------------------------------------------------------------- 1 | create table cats ( 2 | id bigserial, 3 | name text not null 4 | ); 5 | 6 | insert into cats (name) values ('foo'), ('bar'); 7 | -------------------------------------------------------------------------------- /testutil/int.go: -------------------------------------------------------------------------------- 1 | package testutil 2 | 3 | import ( 4 | "context" 5 | "database/sql" 6 | "fmt" 7 | "github.com/jackc/pgx/v5" 8 | "github.com/jackc/pgx/v5/pgxpool" 9 | "github.com/orlangure/gnomock" 10 | "github.com/orlangure/gnomock/preset/postgres" 11 | "github.com/stretchr/testify/assert" 12 | "testing" 13 | ) 14 | 15 | func SetupDB(t *testing.T) (*pgxpool.Pool, *sql.DB) { 16 | t.Helper() 17 | 18 | p := postgres.Preset( 19 | postgres.WithUser("gnomock", "gnomick"), 20 | postgres.WithDatabase("mydb"), 21 | postgres.WithQueriesFile("../testutil/db/schema.sql"), 22 | ) 23 | 24 | container, err := gnomock.Start(p) 25 | assert.NoError(t, err) 26 | 27 | t.Cleanup(func() { _ = gnomock.Stop(container) }) 28 | 29 | connStr := fmt.Sprintf( 30 | "host=%s port=%d user=%s password=%s dbname=%s sslmode=disable", 31 | container.Host, container.DefaultPort(), 32 | "gnomock", "gnomick", "mydb", 33 | ) 34 | 35 | pgConfig, err := pgxpool.ParseConfig(connStr) 36 | assert.NoError(t, err) 37 | 38 | pool, err := pgxpool.NewWithConfig(context.Background(), pgConfig) 39 | assert.NoError(t, err) 40 | 41 | db, err := sql.Open("postgres", connStr) 42 | assert.NoError(t, err) 43 | 44 | return 
pool, db 45 | } 46 | 47 | func AssertSuccess(t *testing.T, pool *pgxpool.Pool, name string) { 48 | t.Helper() 49 | 50 | row := pool.QueryRow(context.TODO(), `select name from cats where name=$1`, name) 51 | 52 | var n string 53 | 54 | err := row.Scan(&n) 55 | assert.NoError(t, err) 56 | 57 | assert.Equal(t, name, n) 58 | } 59 | 60 | func AssertFailure(t *testing.T, pool *pgxpool.Pool, name string) { 61 | t.Helper() 62 | 63 | row := pool.QueryRow(context.TODO(), `select name from cats where name=$1`, name) 64 | 65 | var n string 66 | 67 | err := row.Scan(&n) 68 | 69 | assert.ErrorIs(t, err, pgx.ErrNoRows) 70 | } 71 | -------------------------------------------------------------------------------- /testutil/mocks/db.go: -------------------------------------------------------------------------------- 1 | // Code generated by mockery v2.42.1. DO NOT EDIT. 2 | 3 | package mocks 4 | 5 | import ( 6 | context "context" 7 | 8 | tx "github.com/aneshas/tx/v2" 9 | mock "github.com/stretchr/testify/mock" 10 | ) 11 | 12 | // DB is an autogenerated mock type for the DB type 13 | type DB struct { 14 | mock.Mock 15 | } 16 | 17 | type DB_Expecter struct { 18 | mock *mock.Mock 19 | } 20 | 21 | func (_m *DB) EXPECT() *DB_Expecter { 22 | return &DB_Expecter{mock: &_m.Mock} 23 | } 24 | 25 | // Begin provides a mock function with given fields: ctx 26 | func (_m *DB) Begin(ctx context.Context) (tx.Transaction, error) { 27 | ret := _m.Called(ctx) 28 | 29 | if len(ret) == 0 { 30 | panic("no return value specified for Begin") 31 | } 32 | 33 | var r0 tx.Transaction 34 | var r1 error 35 | if rf, ok := ret.Get(0).(func(context.Context) (tx.Transaction, error)); ok { 36 | return rf(ctx) 37 | } 38 | if rf, ok := ret.Get(0).(func(context.Context) tx.Transaction); ok { 39 | r0 = rf(ctx) 40 | } else { 41 | if ret.Get(0) != nil { 42 | r0 = ret.Get(0).(tx.Transaction) 43 | } 44 | } 45 | 46 | if rf, ok := ret.Get(1).(func(context.Context) error); ok { 47 | r1 = rf(ctx) 48 | } else { 49 | r1 = 
ret.Error(1) 50 | } 51 | 52 | return r0, r1 53 | } 54 | 55 | // DB_Begin_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Begin' 56 | type DB_Begin_Call struct { 57 | *mock.Call 58 | } 59 | 60 | // Begin is a helper method to define mock.On call 61 | // - ctx context.Context 62 | func (_e *DB_Expecter) Begin(ctx interface{}) *DB_Begin_Call { 63 | return &DB_Begin_Call{Call: _e.mock.On("Begin", ctx)} 64 | } 65 | 66 | func (_c *DB_Begin_Call) Run(run func(ctx context.Context)) *DB_Begin_Call { 67 | _c.Call.Run(func(args mock.Arguments) { 68 | run(args[0].(context.Context)) 69 | }) 70 | return _c 71 | } 72 | 73 | func (_c *DB_Begin_Call) Return(_a0 tx.Transaction, _a1 error) *DB_Begin_Call { 74 | _c.Call.Return(_a0, _a1) 75 | return _c 76 | } 77 | 78 | func (_c *DB_Begin_Call) RunAndReturn(run func(context.Context) (tx.Transaction, error)) *DB_Begin_Call { 79 | _c.Call.Return(run) 80 | return _c 81 | } 82 | 83 | // NewDB creates a new instance of DB. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. 84 | // The first argument is typically a *testing.T value. 85 | func NewDB(t interface { 86 | mock.TestingT 87 | Cleanup(func()) 88 | }) *DB { 89 | mock := &DB{} 90 | mock.Mock.Test(t) 91 | 92 | t.Cleanup(func() { mock.AssertExpectations(t) }) 93 | 94 | return mock 95 | } 96 | -------------------------------------------------------------------------------- /testutil/mocks/option.go: -------------------------------------------------------------------------------- 1 | // Code generated by mockery v2.42.1. DO NOT EDIT. 
2 | 3 | package mocks 4 | 5 | import ( 6 | tx "github.com/aneshas/tx/v2" 7 | mock "github.com/stretchr/testify/mock" 8 | ) 9 | 10 | // Option is an autogenerated mock type for the Option type 11 | type Option struct { 12 | mock.Mock 13 | } 14 | 15 | type Option_Expecter struct { 16 | mock *mock.Mock 17 | } 18 | 19 | func (_m *Option) EXPECT() *Option_Expecter { 20 | return &Option_Expecter{mock: &_m.Mock} 21 | } 22 | 23 | // Execute provides a mock function with given fields: _a0 24 | func (_m *Option) Execute(_a0 *tx.TX) { 25 | _m.Called(_a0) 26 | } 27 | 28 | // Option_Execute_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Execute' 29 | type Option_Execute_Call struct { 30 | *mock.Call 31 | } 32 | 33 | // Execute is a helper method to define mock.On call 34 | // - _a0 *tx.TX 35 | func (_e *Option_Expecter) Execute(_a0 interface{}) *Option_Execute_Call { 36 | return &Option_Execute_Call{Call: _e.mock.On("Execute", _a0)} 37 | } 38 | 39 | func (_c *Option_Execute_Call) Run(run func(_a0 *tx.TX)) *Option_Execute_Call { 40 | _c.Call.Run(func(args mock.Arguments) { 41 | run(args[0].(*tx.TX)) 42 | }) 43 | return _c 44 | } 45 | 46 | func (_c *Option_Execute_Call) Return() *Option_Execute_Call { 47 | _c.Call.Return() 48 | return _c 49 | } 50 | 51 | func (_c *Option_Execute_Call) RunAndReturn(run func(*tx.TX)) *Option_Execute_Call { 52 | _c.Call.Return(run) 53 | return _c 54 | } 55 | 56 | // NewOption creates a new instance of Option. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. 57 | // The first argument is typically a *testing.T value. 
58 | func NewOption(t interface { 59 | mock.TestingT 60 | Cleanup(func()) 61 | }) *Option { 62 | mock := &Option{} 63 | mock.Mock.Test(t) 64 | 65 | t.Cleanup(func() { mock.AssertExpectations(t) }) 66 | 67 | return mock 68 | } 69 | -------------------------------------------------------------------------------- /testutil/mocks/transaction.go: -------------------------------------------------------------------------------- 1 | // Code generated by mockery v2.42.1. DO NOT EDIT. 2 | 3 | package mocks 4 | 5 | import ( 6 | context "context" 7 | 8 | mock "github.com/stretchr/testify/mock" 9 | ) 10 | 11 | // Transaction is an autogenerated mock type for the Transaction type 12 | type Transaction struct { 13 | mock.Mock 14 | } 15 | 16 | type Transaction_Expecter struct { 17 | mock *mock.Mock 18 | } 19 | 20 | func (_m *Transaction) EXPECT() *Transaction_Expecter { 21 | return &Transaction_Expecter{mock: &_m.Mock} 22 | } 23 | 24 | // Commit provides a mock function with given fields: ctx 25 | func (_m *Transaction) Commit(ctx context.Context) error { 26 | ret := _m.Called(ctx) 27 | 28 | if len(ret) == 0 { 29 | panic("no return value specified for Commit") 30 | } 31 | 32 | var r0 error 33 | if rf, ok := ret.Get(0).(func(context.Context) error); ok { 34 | r0 = rf(ctx) 35 | } else { 36 | r0 = ret.Error(0) 37 | } 38 | 39 | return r0 40 | } 41 | 42 | // Transaction_Commit_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Commit' 43 | type Transaction_Commit_Call struct { 44 | *mock.Call 45 | } 46 | 47 | // Commit is a helper method to define mock.On call 48 | // - ctx context.Context 49 | func (_e *Transaction_Expecter) Commit(ctx interface{}) *Transaction_Commit_Call { 50 | return &Transaction_Commit_Call{Call: _e.mock.On("Commit", ctx)} 51 | } 52 | 53 | func (_c *Transaction_Commit_Call) Run(run func(ctx context.Context)) *Transaction_Commit_Call { 54 | _c.Call.Run(func(args mock.Arguments) { 55 | run(args[0].(context.Context)) 56 | }) 
57 | return _c 58 | } 59 | 60 | func (_c *Transaction_Commit_Call) Return(_a0 error) *Transaction_Commit_Call { 61 | _c.Call.Return(_a0) 62 | return _c 63 | } 64 | 65 | func (_c *Transaction_Commit_Call) RunAndReturn(run func(context.Context) error) *Transaction_Commit_Call { 66 | _c.Call.Return(run) 67 | return _c 68 | } 69 | 70 | // Rollback provides a mock function with given fields: ctx 71 | func (_m *Transaction) Rollback(ctx context.Context) error { 72 | ret := _m.Called(ctx) 73 | 74 | if len(ret) == 0 { 75 | panic("no return value specified for Rollback") 76 | } 77 | 78 | var r0 error 79 | if rf, ok := ret.Get(0).(func(context.Context) error); ok { 80 | r0 = rf(ctx) 81 | } else { 82 | r0 = ret.Error(0) 83 | } 84 | 85 | return r0 86 | } 87 | 88 | // Transaction_Rollback_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Rollback' 89 | type Transaction_Rollback_Call struct { 90 | *mock.Call 91 | } 92 | 93 | // Rollback is a helper method to define mock.On call 94 | // - ctx context.Context 95 | func (_e *Transaction_Expecter) Rollback(ctx interface{}) *Transaction_Rollback_Call { 96 | return &Transaction_Rollback_Call{Call: _e.mock.On("Rollback", ctx)} 97 | } 98 | 99 | func (_c *Transaction_Rollback_Call) Run(run func(ctx context.Context)) *Transaction_Rollback_Call { 100 | _c.Call.Run(func(args mock.Arguments) { 101 | run(args[0].(context.Context)) 102 | }) 103 | return _c 104 | } 105 | 106 | func (_c *Transaction_Rollback_Call) Return(_a0 error) *Transaction_Rollback_Call { 107 | _c.Call.Return(_a0) 108 | return _c 109 | } 110 | 111 | func (_c *Transaction_Rollback_Call) RunAndReturn(run func(context.Context) error) *Transaction_Rollback_Call { 112 | _c.Call.Return(run) 113 | return _c 114 | } 115 | 116 | // NewTransaction creates a new instance of Transaction. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. 
117 | // The first argument is typically a *testing.T value. 118 | func NewTransaction(t interface { 119 | mock.TestingT 120 | Cleanup(func()) 121 | }) *Transaction { 122 | mock := &Transaction{} 123 | mock.Mock.Test(t) 124 | 125 | t.Cleanup(func() { mock.AssertExpectations(t) }) 126 | 127 | return mock 128 | } 129 | -------------------------------------------------------------------------------- /testutil/mocks/transactor.go: -------------------------------------------------------------------------------- 1 | // Code generated by mockery v2.42.1. DO NOT EDIT. 2 | 3 | package mocks 4 | 5 | import ( 6 | context "context" 7 | 8 | mock "github.com/stretchr/testify/mock" 9 | ) 10 | 11 | // Transactor is an autogenerated mock type for the Transactor type 12 | type Transactor struct { 13 | mock.Mock 14 | } 15 | 16 | type Transactor_Expecter struct { 17 | mock *mock.Mock 18 | } 19 | 20 | func (_m *Transactor) EXPECT() *Transactor_Expecter { 21 | return &Transactor_Expecter{mock: &_m.Mock} 22 | } 23 | 24 | // WithTransaction provides a mock function with given fields: ctx, f 25 | func (_m *Transactor) WithTransaction(ctx context.Context, f func(context.Context) error) error { 26 | ret := _m.Called(ctx, f) 27 | 28 | if len(ret) == 0 { 29 | panic("no return value specified for WithTransaction") 30 | } 31 | 32 | var r0 error 33 | if rf, ok := ret.Get(0).(func(context.Context, func(context.Context) error) error); ok { 34 | r0 = rf(ctx, f) 35 | } else { 36 | r0 = ret.Error(0) 37 | } 38 | 39 | return r0 40 | } 41 | 42 | // Transactor_WithTransaction_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'WithTransaction' 43 | type Transactor_WithTransaction_Call struct { 44 | *mock.Call 45 | } 46 | 47 | // WithTransaction is a helper method to define mock.On call 48 | // - ctx context.Context 49 | // - f func(context.Context) error 50 | func (_e *Transactor_Expecter) WithTransaction(ctx interface{}, f interface{}) 
*Transactor_WithTransaction_Call { 51 | return &Transactor_WithTransaction_Call{Call: _e.mock.On("WithTransaction", ctx, f)} 52 | } 53 | 54 | func (_c *Transactor_WithTransaction_Call) Run(run func(ctx context.Context, f func(context.Context) error)) *Transactor_WithTransaction_Call { 55 | _c.Call.Run(func(args mock.Arguments) { 56 | run(args[0].(context.Context), args[1].(func(context.Context) error)) 57 | }) 58 | return _c 59 | } 60 | 61 | func (_c *Transactor_WithTransaction_Call) Return(_a0 error) *Transactor_WithTransaction_Call { 62 | _c.Call.Return(_a0) 63 | return _c 64 | } 65 | 66 | func (_c *Transactor_WithTransaction_Call) RunAndReturn(run func(context.Context, func(context.Context) error) error) *Transactor_WithTransaction_Call { 67 | _c.Call.Return(run) 68 | return _c 69 | } 70 | 71 | // NewTransactor creates a new instance of Transactor. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. 72 | // The first argument is typically a *testing.T value. 73 | func NewTransactor(t interface { 74 | mock.TestingT 75 | Cleanup(func()) 76 | }) *Transactor { 77 | mock := &Transactor{} 78 | mock.Mock.Test(t) 79 | 80 | t.Cleanup(func() { mock.AssertExpectations(t) }) 81 | 82 | return mock 83 | } 84 | -------------------------------------------------------------------------------- /tx.go: -------------------------------------------------------------------------------- 1 | package tx 2 | 3 | import ( 4 | "context" 5 | "errors" 6 | "fmt" 7 | ) 8 | 9 | type key struct{} 10 | 11 | // Transactor is a helper transactor interface added for brevity purposes, so you don't have to define your own 12 | // See TX 13 | type Transactor interface { 14 | // WithTransaction will wrap f in a sql transaction depending on the DB provider. 15 | // This is mostly useful for when we want to control the transaction scope from 16 | // application layer, for example application service/command handler. 
17 | // If f fails with an error, transactor will automatically try to roll the transaction back and report back any errors, 18 | // otherwise, the implicit transaction will be committed. 19 | WithTransaction(ctx context.Context, f func(ctx context.Context) error) error 20 | } 21 | 22 | // DB represents an interface to a db capable of starting a transaction 23 | type DB interface { 24 | Begin(ctx context.Context) (Transaction, error) 25 | } 26 | 27 | // Transaction represents db transaction 28 | type Transaction interface { 29 | Commit(ctx context.Context) error 30 | Rollback(ctx context.Context) error 31 | } 32 | 33 | // New constructs new transactor which will use provided db to handle the transaction 34 | func New(db DB, opts ...Option) *TX { 35 | ttx := TX{db: db} 36 | 37 | for _, opt := range opts { 38 | opt(&ttx) 39 | } 40 | 41 | return &ttx 42 | } 43 | 44 | // TX represents sql transactor 45 | type TX struct { 46 | db DB 47 | ignoreErrs []error 48 | } 49 | 50 | // WithTransaction will wrap f in a sql transaction depending on the DB provider. 51 | // This is mostly useful for when we want to control the transaction scope from 52 | // application layer, for example application service/command handler. 53 | // If f fails with an error, transactor will automatically try to roll the transaction back and report back any errors, 54 | // otherwise, the implicit transaction will be committed. 
55 | func (t *TX) WithTransaction(ctx context.Context, f func(ctx context.Context) error) error { 56 | tx, err := t.db.Begin(ctx) // add tx options if different isolation levels are needed 57 | if err != nil { 58 | return fmt.Errorf("tx: could not start transaction: %w", err) 59 | } 60 | 61 | defer func() { 62 | if r := recover(); r != nil { 63 | _ = tx.Rollback(ctx) 64 | panic(r) 65 | } 66 | }() 67 | 68 | ctx = context.WithValue(ctx, key{}, tx) 69 | 70 | err = f(ctx) 71 | if err != nil && !t.shouldIgnore(err) { 72 | e := tx.Rollback(ctx) 73 | if e != nil { 74 | return errors.Join(e, err) 75 | } 76 | 77 | return err 78 | } 79 | 80 | return errors.Join(err, tx.Commit(ctx)) 81 | } 82 | 83 | func (t *TX) shouldIgnore(err error) bool { 84 | for _, e := range t.ignoreErrs { 85 | if errors.Is(err, e) { 86 | return true 87 | } 88 | } 89 | 90 | return false 91 | } 92 | 93 | // From returns underlying tx value from context if it can be type-casted to T 94 | // Otherwise it returns default T, false. 95 | // From returns underlying T from the context which in most cases should probably be pgx.Tx 96 | // T will mostly be your Tx type (pgx.Tx, *sql.Tx, etc...) but is left as a generic type in order 97 | // to accommodate cases where people tend to abstract the whole connection/transaction 98 | // away behind an interface for example, something like Executor (see example). 99 | // 100 | // Example: 101 | // 102 | // type Executor interface { 103 | // Exec(ctx context.Context, sql string, args ...interface{}) (pgconn.CommandTag, error) 104 | // // ... 
other stuff 105 | // } 106 | // 107 | // tx, ok := tx.From[Executor](ctx) 108 | // 109 | // # Or 110 | // 111 | // tx, ok := tx.From[pgx.Tx](ctx) 112 | func From[T any](ctx context.Context) (T, bool) { 113 | val := ctx.Value(key{}) 114 | if val == nil { 115 | var t T 116 | 117 | return t, false 118 | } 119 | 120 | tx, ok := val.(T) 121 | if !ok { 122 | var t T 123 | 124 | return t, false 125 | } 126 | 127 | return tx, true 128 | } 129 | -------------------------------------------------------------------------------- /tx_test.go: -------------------------------------------------------------------------------- 1 | package tx_test 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "github.com/aneshas/tx/v2" 7 | "github.com/aneshas/tx/v2/testutil" 8 | "github.com/aneshas/tx/v2/testutil/mocks" 9 | "github.com/stretchr/testify/assert" 10 | "testing" 11 | ) 12 | 13 | func TestShould_Report_Transaction_Begin_Error(t *testing.T) { 14 | wantErr := fmt.Errorf("something bad occurred") 15 | 16 | db := testutil.NewDB( 17 | t, 18 | testutil.WithUnsuccessfulTransactionStart(wantErr), 19 | ) 20 | transactor := tx.New(db) 21 | 22 | err := transactor.WithTransaction(context.TODO(), func(ctx context.Context) error { 23 | return nil 24 | }) 25 | 26 | assert.ErrorIs(t, err, wantErr) 27 | } 28 | 29 | func TestShould_Commit_Transaction_On_No_Error(t *testing.T) { 30 | db := testutil.NewDB( 31 | t, 32 | testutil.WithSuccessfulCommit(), 33 | ) 34 | transactor := tx.New(db) 35 | 36 | err := transactor.WithTransaction(context.TODO(), func(ctx context.Context) error { 37 | return nil 38 | }) 39 | 40 | assert.NoError(t, err) 41 | } 42 | 43 | func TestShould_Rollback_Transaction_On_Error(t *testing.T) { 44 | db := testutil.NewDB( 45 | t, 46 | testutil.WithSuccessfulRollback(), 47 | ) 48 | transactor := tx.New(db) 49 | 50 | wantErr := fmt.Errorf("something bad occurred") 51 | 52 | err := transactor.WithTransaction(context.TODO(), func(ctx context.Context) error { 53 | return wantErr 54 |
}) 55 | 56 | assert.ErrorIs(t, err, wantErr) 57 | } 58 | 59 | func TestShould_Report_Unsuccessful_Rollback(t *testing.T) { 60 | wantTxErr := fmt.Errorf("something bad occurred") 61 | 62 | db := testutil.NewDB( 63 | t, 64 | testutil.WithUnsuccessfulRollback(wantTxErr), 65 | ) 66 | transactor := tx.New(db) 67 | 68 | wantErr := fmt.Errorf("process error") 69 | 70 | err := transactor.WithTransaction(context.TODO(), func(ctx context.Context) error { 71 | return wantErr 72 | }) 73 | 74 | assert.ErrorIs(t, err, wantErr) 75 | assert.ErrorIs(t, err, wantTxErr) 76 | } 77 | 78 | func TestShould_Rollback_Transaction_On_Panic_And_RePanic(t *testing.T) { 79 | db := testutil.NewDB( 80 | t, 81 | testutil.WithSuccessfulRollback(), 82 | ) 83 | transactor := tx.New(db) 84 | 85 | defer func() { 86 | if r := recover(); r == nil { 87 | t.Fatalf("expected panic to be propagated") 88 | } 89 | }() 90 | 91 | _ = transactor.WithTransaction(context.TODO(), func(ctx context.Context) error { 92 | panic("something very bad occurred") 93 | }) 94 | } 95 | 96 | func TestShould_Still_Commit_On_Ignored_Error_And_Propagate_Error(t *testing.T) { 97 | wantErr := fmt.Errorf("something bad occurred") 98 | 99 | db := testutil.NewDB( 100 | t, 101 | testutil.WithSuccessfulCommit(), 102 | ) 103 | transactor := tx.New(db, tx.WithIgnoredErrors(wantErr)) 104 | 105 | err := transactor.WithTransaction(context.TODO(), func(ctx context.Context) error { 106 | return wantErr 107 | }) 108 | 109 | assert.ErrorIs(t, err, wantErr) 110 | } 111 | 112 | func TestShould_Retrieve_Tx_From_Context(t *testing.T) { 113 | db := testutil.NewDB( 114 | t, 115 | testutil.WithSuccessfulCommit(), 116 | ) 117 | transactor := tx.New(db) 118 | 119 | _ = transactor.WithTransaction(context.TODO(), func(ctx context.Context) error { 120 | ttx, ok := tx.From[tx.Transaction](ctx) 121 | 122 | assert.True(t, ok) 123 | assert.IsType(t, &mocks.Transaction{}, ttx) 124 | 125 | return nil 126 | }) 127 | } 128 | 129 | func 
TestShould_Not_Retrieve_Conn_From_Context_On_Mismatched_Type(t *testing.T) { 130 | db := testutil.NewDB( 131 | t, 132 | testutil.WithSuccessfulCommit(), 133 | ) 134 | transactor := tx.New(db) 135 | 136 | _ = transactor.WithTransaction(context.TODO(), func(ctx context.Context) error { 137 | _, ok := tx.From[mocks.Transaction](ctx) 138 | 139 | assert.False(t, ok) 140 | 141 | return nil 142 | }) 143 | } 144 | 145 | func TestShould_Not_Retrieve_Conn_From_Context_Without_Transaction(t *testing.T) { 146 | ttx, ok := tx.From[*mocks.Transaction](context.TODO()) 147 | 148 | assert.False(t, ok) 149 | assert.Nil(t, ttx) 150 | } 151 | --------------------------------------------------------------------------------