├── coverage ├── coverage.log ├── README.md ├── cov-diff.sh ├── coverage.svg └── coverage.sh ├── CHANGELOG.md ├── buf.gen.yaml ├── .copywrite.hcl ├── CODEOWNERS ├── buf.yaml ├── docker-compose.yml ├── create_unexported_test.go ├── update_unexported_test.go ├── docs ├── README_DELETE.md ├── README_OPTIONS.md ├── README_READ.md ├── README_HOOKS.md ├── README_DEBUG.md ├── README_LOCKS.md ├── README_RW.md ├── README_INITFIELDS.md ├── README_QUERY.md ├── README_OPEN.md ├── README_TX.md ├── README_USAGE.md ├── README_CREATE.md ├── README_MODELS.md └── README_UPDATE.md ├── .github ├── ISSUE_TEMPLATE │ ├── bug_report.md │ └── feature_request.md ├── pull_request_template.md └── workflows │ ├── make-gen-delta.yml │ └── go.yml ├── error.go ├── id.go ├── transactions.go ├── backoff.go ├── docs.go ├── tools └── tools.go ├── db_unexported_test.go ├── query.go ├── backoff_test.go ├── transactions_test.go ├── internal ├── dbtest │ ├── timestamp.go │ └── db.go └── proto │ └── local │ └── dbtest │ └── storage │ └── v1 │ └── dbtest.proto ├── .gitignore ├── do_tx.go ├── go.mod ├── Makefile ├── lookup.go ├── scripts └── protoc_gen_plugin.bash ├── reader.go ├── id_test.go ├── README.md ├── rw_unexported_test.go ├── query_test.go ├── clause.go ├── testing_test.go ├── common_unexported_test.go ├── writer.go ├── CONTRIBUTING.md ├── lookup_test.go ├── common.go ├── db_test.go ├── delete.go ├── update.go ├── do_tx_test.go ├── option_test.go ├── db.go ├── option.go ├── rw.go └── rw_test.go /coverage/coverage.log: -------------------------------------------------------------------------------- 1 | 1693014148,90.9 2 | 1720903770,90.8 3 | 1722112902,90.9 4 | -------------------------------------------------------------------------------- /CHANGELOG.md: -------------------------------------------------------------------------------- 1 | # Changelog 2 | 3 | ## Enhancements 4 | * Add `Dialect()` to `Reader` and `Writer`, allowing the user to make 5 | decisions based on the underlying database. 
6 | -------------------------------------------------------------------------------- /buf.gen.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) HashiCorp, Inc. 2 | # SPDX-License-Identifier: BUSL-1.1 3 | 4 | version: v1 5 | plugins: 6 | - name: go 7 | out: . 8 | opt: 9 | - paths=import 10 | -------------------------------------------------------------------------------- /.copywrite.hcl: -------------------------------------------------------------------------------- 1 | schema_version = 1 2 | 3 | project { 4 | license = "MPL-2.0" 5 | copyright_year = 2023 6 | 7 | header_ignore = [ 8 | ".github/**", 9 | "coverage/**" 10 | ] 11 | } -------------------------------------------------------------------------------- /CODEOWNERS: -------------------------------------------------------------------------------- 1 | # Each line is a file pattern followed by one or more owners. 2 | # More on CODEOWNERS files: https://help.github.com/en/github/creating-cloning-and-archiving-repositories/about-code-owners 3 | * @hashicorp/boundary 4 | -------------------------------------------------------------------------------- /buf.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) HashiCorp, Inc. 2 | # SPDX-License-Identifier: MPL-2.0 3 | 4 | version: v1beta1 5 | 6 | build: 7 | roots: 8 | - internal/proto/local 9 | 10 | lint: 11 | use: 12 | - DEFAULT 13 | ignore: 14 | - google 15 | ignore_only: 16 | 17 | breaking: 18 | use: 19 | - WIRE_JSON 20 | -------------------------------------------------------------------------------- /docker-compose.yml: -------------------------------------------------------------------------------- 1 | # Copyright (c) HashiCorp, Inc. 
2 | # SPDX-License-Identifier: MPL-2.0 3 | 4 | version: '3' 5 | 6 | services: 7 | postgres: 8 | image: 'postgres:latest' 9 | ports: 10 | - 9920:5432 11 | environment: 12 | - POSTGRES_DB=go_db 13 | - POSTGRES_USER=go_db 14 | - POSTGRES_PASSWORD=go_db -------------------------------------------------------------------------------- /create_unexported_test.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) HashiCorp, Inc. 2 | // SPDX-License-Identifier: MPL-2.0 3 | 4 | package dbw 5 | 6 | import ( 7 | "sync/atomic" 8 | "testing" 9 | 10 | "github.com/stretchr/testify/assert" 11 | ) 12 | 13 | func Test_NonCreatableFields(t *testing.T) { 14 | // do not run with t.Parallel() 15 | assert := assert.New(t) 16 | nonUpdateFields = atomic.Value{} 17 | got := NonCreatableFields() 18 | assert.Equal(got, []string{}) 19 | 20 | InitNonCreatableFields([]string{"Foo"}) 21 | got = NonCreatableFields() 22 | assert.Equal(got, []string{"Foo"}) 23 | } 24 | -------------------------------------------------------------------------------- /update_unexported_test.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) HashiCorp, Inc. 
2 | // SPDX-License-Identifier: MPL-2.0 3 | 4 | package dbw 5 | 6 | import ( 7 | "sync/atomic" 8 | "testing" 9 | 10 | "github.com/stretchr/testify/assert" 11 | ) 12 | 13 | func Test_NonUpdatableFields(t *testing.T) { 14 | // do not run with t.Parallel() 15 | assert := assert.New(t) 16 | nonUpdateFields = atomic.Value{} 17 | got := NonUpdatableFields() 18 | assert.Equal(got, []string{}) 19 | 20 | InitNonUpdatableFields([]string{"Foo"}) 21 | got = NonUpdatableFields() 22 | assert.Equal(got, []string{"Foo"}) 23 | } 24 | -------------------------------------------------------------------------------- /docs/README_DELETE.md: -------------------------------------------------------------------------------- 1 | # Delete 2 | [![Go 3 | Reference](https://pkg.go.dev/badge/github.com/hashicorp/go-dbw.svg)](https://pkg.go.dev/github.com/hashicorp/go-dbw) 4 | ## [RW.Delete(...)](https://pkg.go.dev/github.com/hashicorp/go-dbw#RW.Delete) example with one item 5 | ```go 6 | rowsAffected, err = rw.Delete(ctx, 7 | &user, 8 | dbw.WithVersion(&user.Version), 9 | ) 10 | ``` 11 | ## [RW.DeleteItems(...)](https://pkg.go.dev/github.com/hashicorp/go-dbw#RW.DeleteItems) example with multiple items 12 | ```go 13 | var rowsAffected int64 14 | err = rw.DeleteItems(ctx, 15 | []interface{}{&user1, &user2}, 16 | dbw.WithRowsAffected(&rowsAffected), 17 | ) 18 | ``` 19 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/bug_report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Bug report 3 | about: Create a report to help us improve 4 | title: '' 5 | labels: '' 6 | assignees: '' 7 | 8 | --- 9 | 10 | **Describe the bug** 11 | A clear and concise description of what the bug is. 12 | 13 | **To Reproduce** 14 | Steps to reproduce the behavior: 15 | 1. Go to '...' 16 | 2. Click on '....' 17 | 3. Scroll down to '....' 18 | 4.
See error 19 | 20 | **Expected behavior** 21 | A clear and concise description of what you expected to happen. 22 | 23 | **Screenshots** 24 | If applicable, add screenshots to help explain your problem. 25 | 26 | **Additional context** 27 | Add any other context about the problem here. 28 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature_request.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Feature request 3 | about: Suggest an idea for this project 4 | title: '' 5 | labels: '' 6 | assignees: '' 7 | 8 | --- 9 | 10 | **Is your feature request related to a problem? Please describe.** 11 | A clear and concise description of what the problem is. Ex. I'm always frustrated when [...] 12 | 13 | **Describe the solution you'd like** 14 | A clear and concise description of what you want to happen. 15 | 16 | **Describe alternatives you've considered** 17 | A clear and concise description of any alternative solutions or features you've considered. 18 | 19 | **Additional context** 20 | Add any other context or screenshots about the feature request here. 21 | -------------------------------------------------------------------------------- /.github/pull_request_template.md: -------------------------------------------------------------------------------- 1 | ## PCI review checklist 2 | 3 | 4 | 5 | - [ ] If applicable, I've documented a plan to revert these changes if they require more than reverting the pull request. 6 | 7 | - [ ] If applicable, I've worked with GRC to document the impact of any changes to security controls. 8 | 9 | Examples of changes to controls include access controls, encryption, logging, etc. 10 | 11 | - [ ] If applicable, I've worked with GRC to ensure compliance due to a significant change to the in-scope PCI environment. 
12 | 13 | Examples include changes to operating systems, ports, protocols, services, cryptography-related components, PII processing code, etc. 14 | -------------------------------------------------------------------------------- /docs/README_OPTIONS.md: -------------------------------------------------------------------------------- 1 | # Options 2 | [![Go 3 | Reference](https://pkg.go.dev/badge/github.com/hashicorp/go-dbw.svg)](https://pkg.go.dev/github.com/hashicorp/go-dbw) 4 | 5 | `dbw` supports variadic 6 | [Option](https://pkg.go.dev/github.com/hashicorp/go-dbw#Option) function 7 | parameters for the vast majority of its operations. See the [dbw package 8 | docs](https://pkg.go.dev/github.com/hashicorp/go-dbw) for more 9 | information about which 10 | [Option](https://pkg.go.dev/github.com/hashicorp/go-dbw#Option) functions are 11 | supported for each operation. 12 | 13 | 14 | ```go 15 | // just one example of variadic options: an update 16 | // using WithVersion and WithDebug options 17 | rw.Update(ctx, &user, []string{"Name"}, nil, dbw.WithVersion(&user.Version), dbw.WithDebug(true)) 18 | ``` -------------------------------------------------------------------------------- /error.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) HashiCorp, Inc.
2 | // SPDX-License-Identifier: MPL-2.0 3 | 4 | package dbw 5 | 6 | import "errors" 7 | 8 | var ( 9 | // ErrUnknown is an unknown/undefined error 10 | ErrUnknown = errors.New("unknown") 11 | 12 | // ErrInvalidParameter is an invalid parameter error 13 | ErrInvalidParameter = errors.New("invalid parameter") 14 | 15 | // ErrInternal is an internal error 16 | ErrInternal = errors.New("internal error") 17 | 18 | // ErrRecordNotFound is a not found record error 19 | ErrRecordNotFound = errors.New("record not found") 20 | 21 | // ErrMaxRetries is a max retries error 22 | ErrMaxRetries = errors.New("too many retries") 23 | 24 | // ErrInvalidFieldMask is an invalid field mask error 25 | ErrInvalidFieldMask = errors.New("invalid field mask") 26 | ) 27 | -------------------------------------------------------------------------------- /coverage/README.md: -------------------------------------------------------------------------------- 1 | # Coverage 2 | 3 | This `coverage` directory contains the bits required to generate the code 4 | coverage report and badge which are published for this repo. 5 | 6 | After making changes to source, please run `make coverage` in the root directory 7 | of this repo and check-in any changes. 8 | 9 | - **cov-diff.sh** - generates a new coverage report and checks the previous 10 | entry in `coverage.log` for differences. It's used by the github action to 11 | ensure that the published coverage report and badge are up to date. 12 | - **coverage.sh** - generates `coverage.log`, `coverage.svg`, and 13 | `coverage.html`. 14 | - **coverage.log** - A log of coverage report runs. 15 | - **coverage.html** - The published coverage report. 16 | - **coverage.svg** - The published coverage badge. 
17 | -------------------------------------------------------------------------------- /docs/README_READ.md: -------------------------------------------------------------------------------- 1 | # Reading 2 | [![Go 3 | Reference](https://pkg.go.dev/badge/github.com/hashicorp/go-dbw.svg)](https://pkg.go.dev/github.com/hashicorp/go-dbw) 4 | 5 | dbw provides a few ways to read data from the database. 6 | 7 | ```go 8 | // Get the user with either a PublicId, PrivateId or 9 | // primary keys matching the given user 10 | rw.LookupId(ctx, &user) 11 | 12 | // Get the user with a public_id matching the given user. 13 | rw.LookupByPublicId(ctx, &user) 14 | 15 | // Get the first user matching the where clause 16 | rw.LookupWhere(ctx, 17 | &user, 18 | "public_id = @id", 19 | []interface{}{sql.Named("id", "1")}, 20 | ) 21 | 22 | // Get all the users matching the where clause 23 | rw.SearchWhere(ctx, 24 | &users, 25 | "public_id in(@ids)", 26 | []interface{}{sql.Named("ids", []string{"1", "2"})}, 27 | ) 28 | ``` -------------------------------------------------------------------------------- /docs/README_HOOKS.md: -------------------------------------------------------------------------------- 1 | # Hooks 2 | [![Go 3 | Reference](https://pkg.go.dev/badge/github.com/hashicorp/go-dbw.svg)](https://pkg.go.dev/github.com/hashicorp/go-dbw) 4 | 5 | dbw provides two options for write operations which give callers hooks before 6 | and after the write operations: 7 | * [WithBeforeWrite(...)](https://pkg.go.dev/github.com/hashicorp/go-dbw#WithBeforeWrite) 8 | * [WithAfterWrite(...)](https://pkg.go.dev/github.com/hashicorp/go-dbw#WithAfterWrite) 9 | 10 | ```go 11 | beforeFn := func(_ interface{}) error { 12 | return nil // always succeed for this example 13 | } 14 | afterFn := func(_ interface{}, _ int) error { 15 | return nil // always succeed for this example 16 | } 17 | 18 | rw.Create(ctx, 19 | &user, 20 | dbw.WithBeforeWrite(beforeFn), 21 | dbw.WithAfterWrite(afterFn), 22 | ) 23 | 24 | ``` 25 | 26 |
-------------------------------------------------------------------------------- /coverage/cov-diff.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Copyright (c) HashiCorp, Inc. 3 | # SPDX-License-Identifier: MPL-2.0 4 | 5 | # Check if the file exists 6 | if [ ! -f "$1" ]; then 7 | echo "File not found!" 8 | exit 1 9 | fi 10 | 11 | # Read the last two lines of the file and extract the last columns 12 | last_col_last_line=$(tail -n 1 "$1" | awk -F',' '{print $NF}') 13 | last_col_second_last_line=$(tail -n 2 "$1" | head -n 1 | awk -F',' '{print $NF}') 14 | 15 | # Compare the last columns 16 | if [ "$last_col_last_line" = "$last_col_second_last_line" ]; then 17 | exit 0 18 | else 19 | echo "coverage has changed." 20 | echo "generate a new report and badge using: make coverage" 21 | echo "and then check-in the new report and badge?" 22 | echo "coverage before: $last_col_second_last_line" 23 | echo "coverage now: $last_col_last_line" 24 | exit 1 25 | fi -------------------------------------------------------------------------------- /docs/README_DEBUG.md: -------------------------------------------------------------------------------- 1 | # Debug output 2 | [![Go 3 | Reference](https://pkg.go.dev/badge/github.com/hashicorp/go-dbw.svg)](https://pkg.go.dev/github.com/hashicorp/go-dbw) 4 | 5 | `dbw` provides a few ways to get debug output from the underlying database. 6 | 7 | ## [DB.Debug(...)](https://pkg.go.dev/github.com/hashicorp/go-dbw#DB.Debug) 8 | 9 | ```go 10 | // enable debug output for all database operations 11 | db, err := dbw.Open(dbw.Sqlite, "dbw.db") 12 | db.Debug(true) 13 | ``` 14 | 15 | ## [WithDebug(...)](https://pkg.go.dev/github.com/hashicorp/go-dbw#WithDebug) 16 | Operations may take the 17 | [WithDebug(...)](https://pkg.go.dev/github.com/hashicorp/go-dbw#WithDebug) 18 | option which will enable/disable debug output for the duration of that operation. 
19 | 20 | ```go 21 | // enable debug output for a create operation 22 | rw.Create(ctx, &user, dbw.WithDebug(true)) 23 | ``` -------------------------------------------------------------------------------- /id.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) HashiCorp, Inc. 2 | // SPDX-License-Identifier: MPL-2.0 3 | 4 | package dbw 5 | 6 | import ( 7 | "bytes" 8 | "fmt" 9 | "strings" 10 | 11 | "github.com/hashicorp/go-secure-stdlib/base62" 12 | "golang.org/x/crypto/blake2b" 13 | ) 14 | 15 | // NewId creates a new random base62 ID with the provided prefix with an 16 | // underscore delimiter 17 | func NewId(prefix string, opt ...Option) (string, error) { 18 | const op = "dbw.NewId" 19 | if prefix == "" { 20 | return "", fmt.Errorf("%s: missing prefix: %w", op, ErrInvalidParameter) 21 | } 22 | var publicId string 23 | var err error 24 | opts := GetOpts(opt...) 25 | if len(opts.WithPrngValues) > 0 { 26 | sum := blake2b.Sum256([]byte(strings.Join(opts.WithPrngValues, "|"))) 27 | reader := bytes.NewReader(sum[0:]) 28 | publicId, err = base62.RandomWithReader(10, reader) 29 | } else { 30 | publicId, err = base62.Random(10) 31 | } 32 | if err != nil { 33 | return "", fmt.Errorf("%s: unable to generate id: %w", op, ErrInternal) 34 | } 35 | return fmt.Sprintf("%s_%s", prefix, publicId), nil 36 | } 37 | -------------------------------------------------------------------------------- /docs/README_LOCKS.md: -------------------------------------------------------------------------------- 1 | # Optimistic locking for write operations 2 | [![Go Reference](https://pkg.go.dev/badge/github.com/hashicorp/go-dbw.svg)](https://pkg.go.dev/github.com/hashicorp/go-dbw) 3 | 4 | `dbw` provides the [dbw.WithVersion(...)](https://pkg.go.dev/github.com/hashicorp/go-dbw#WithVersion) option for write operations to enable 5 | an optimistic locking pattern. 
Using this pattern, the caller must first read 6 | a resource from the database and get its version. Then the caller passes the version in 7 | with the write operation and the operation will fail if another caller has 8 | updated the resource's version in the meantime. 9 | 10 | ```go 11 | err := rw.LookupId(ctx, &user) 12 | 13 | user.Name = "Alice" 14 | rowsAffected, err := rw.Update(ctx, 15 | &user, 16 | []string{"Name"}, 17 | nil, 18 | dbw.WithVersion(&user.Version)) 19 | 20 | if err != nil && errors.Is(err, dbw.ErrRecordNotFound) { 21 | // update failed because the row wasn't found because 22 | // either it was deleted, or updated by another caller 23 | // after it was read earlier in this example 24 | } 25 | ``` -------------------------------------------------------------------------------- /coverage/coverage.svg: -------------------------------------------------------------------------------- 1 | coverage: 90.9%coverage90.9% -------------------------------------------------------------------------------- /transactions.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) HashiCorp, Inc.
2 | // SPDX-License-Identifier: MPL-2.0 3 | 4 | package dbw 5 | 6 | import ( 7 | "context" 8 | "fmt" 9 | ) 10 | 11 | // Begin will start a transaction 12 | func (rw *RW) Begin(ctx context.Context) (*RW, error) { 13 | const op = "dbw.Begin" 14 | newTx := rw.underlying.wrapped.WithContext(ctx) 15 | newTx = newTx.Begin() 16 | if newTx.Error != nil { 17 | return nil, fmt.Errorf("%s: %w", op, newTx.Error) 18 | } 19 | return New( 20 | &DB{wrapped: newTx}, 21 | ), nil 22 | } 23 | 24 | // Rollback will rollback the current transaction 25 | func (rw *RW) Rollback(ctx context.Context) error { 26 | const op = "dbw.Rollback" 27 | db := rw.underlying.wrapped.WithContext(ctx) 28 | if err := db.Rollback().Error; err != nil { 29 | return fmt.Errorf("%s: %w", op, err) 30 | } 31 | return nil 32 | } 33 | 34 | // Commit will commit a transaction 35 | func (rw *RW) Commit(ctx context.Context) error { 36 | const op = "dbw.Commit" 37 | db := rw.underlying.wrapped.WithContext(ctx) 38 | if err := db.Commit().Error; err != nil { 39 | return fmt.Errorf("%s: %w", op, err) 40 | } 41 | return nil 42 | } 43 | -------------------------------------------------------------------------------- /backoff.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) HashiCorp, Inc. 2 | // SPDX-License-Identifier: MPL-2.0 3 | 4 | package dbw 5 | 6 | import ( 7 | "math" 8 | "math/rand" 9 | "time" 10 | ) 11 | 12 | // Backoff defines an interface for providing a back off for retrying 13 | // transactions. See DoTx(...) 14 | type Backoff interface { 15 | Duration(attemptNumber uint) time.Duration 16 | } 17 | 18 | // ConstBackoff defines a constant backoff for retrying transactions. See 19 | // DoTx(...) 
20 | type ConstBackoff struct { 21 | DurationMs time.Duration 22 | } 23 | 24 | // Duration is the constant backoff duration based on the retry attempt 25 | func (b ConstBackoff) Duration(attempt uint) time.Duration { 26 | return time.Millisecond * time.Duration(b.DurationMs) 27 | } 28 | 29 | // ExpBackoff defines an exponential backoff for retrying transactions. See DoTx(...) 30 | type ExpBackoff struct { 31 | testRand float64 32 | } 33 | 34 | // Duration is the exponential backoff duration based on the retry attempt 35 | func (b ExpBackoff) Duration(attempt uint) time.Duration { 36 | var r float64 37 | switch { 38 | case b.testRand > 0: 39 | r = b.testRand 40 | default: 41 | r = rand.Float64() 42 | } 43 | return time.Millisecond * time.Duration(math.Exp2(float64(attempt))*5*(r+0.5)) 44 | } 45 | -------------------------------------------------------------------------------- /docs.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) HashiCorp, Inc. 2 | // SPDX-License-Identifier: MPL-2.0 3 | 4 | /* 5 | Package dbw is a database wrapper that supports connecting to and using any database with a 6 | gorm driver. Its intent is to completely encapsulate an application's access 7 | to its database with the exception of migrations. 8 | 9 | dbw is intentionally not an ORM and it removes typical ORM abstractions like 10 | "advanced query building", associations and migrations. 11 | 12 | This is not to say you can't easily use dbw for complicated queries; it's just 13 | that dbw doesn't try to reinvent sql by providing some sort of pattern for 14 | building them with functions. Of course, dbw also provides lookup/search 15 | functions when you simply need to read resources from the database. 16 | 17 | dbw strives to make CRUD for database resources fairly trivial, even supporting 18 | "on conflict" for its create function. dbw also allows you to opt out of its 19 | CRUD functions and use exec, query and scan rows directly.
You may want to 20 | carefully weigh when it's appropriate to use exec and query directly, since 21 | it's likely that each time you use them you're leaking a bit of your 22 | database schema into your application's domain. 23 | 24 | For more information see README.md 25 | */ 26 | package dbw 27 | -------------------------------------------------------------------------------- /tools/tools.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) HashiCorp, Inc. 2 | // SPDX-License-Identifier: MPL-2.0 3 | 4 | //go:build tools 5 | // +build tools 6 | 7 | // This file ensures tool dependencies are kept in sync. This is the 8 | // recommended way of doing this according to 9 | // https://github.com/golang/go/wiki/Modules#how-can-i-track-tool-dependencies-for-a-module 10 | // To install the following tools at the version used by this repo run: 11 | // $ make tools 12 | // or 13 | // $ go generate -tags tools tools/tools.go 14 | 15 | package tools 16 | 17 | // NOTE: This must not be indented, so to stop goimports from trying to be 18 | // helpful, it's separated out from the import block below. Please try to keep 19 | // them in the same order. 
20 | //go:generate go install mvdan.cc/gofumpt 21 | //go:generate go install github.com/favadi/protoc-go-inject-tag 22 | //go:generate go install golang.org/x/tools/cmd/goimports 23 | //go:generate go install github.com/oligot/go-mod-upgrade 24 | //go:generate go install google.golang.org/protobuf/cmd/protoc-gen-go 25 | 26 | import ( 27 | _ "mvdan.cc/gofumpt" 28 | 29 | _ "github.com/favadi/protoc-go-inject-tag" 30 | 31 | _ "golang.org/x/tools/cmd/goimports" 32 | 33 | _ "github.com/oligot/go-mod-upgrade" 34 | 35 | _ "google.golang.org/protobuf/cmd/protoc-gen-go" 36 | ) 37 | -------------------------------------------------------------------------------- /.github/workflows/make-gen-delta.yml: -------------------------------------------------------------------------------- 1 | name: "make-gen-delta" 2 | on: 3 | - workflow_dispatch 4 | - push 5 | - workflow_call 6 | 7 | permissions: 8 | contents: read 9 | 10 | jobs: 11 | make-gen-delta: 12 | name: "Check for uncommitted changes from make gen" 13 | runs-on: ubuntu-latest 14 | steps: 15 | - uses: actions/checkout@c85c95e3d7251135ab7dc9ce3241c5835cc595a9 # v3.5.3 16 | with: 17 | fetch-depth: '0' 18 | - name: Determine Go version 19 | id: get-go-version 20 | # We use .go-version as our source of truth for current Go 21 | # version, because "goenv" can react to it automatically. 
22 | run: | 23 | echo "Building with Go $(cat .go-version)" 24 | echo "go-version=$(cat .go-version)" >> "$GITHUB_OUTPUT" 25 | - name: Set up Go 26 | uses: actions/setup-go@93397bea11091df50f3d7e59dc26a7711a8bcfbe # v4.1.0 27 | with: 28 | go-version: "${{ steps.get-go-version.outputs.go-version }}" 29 | - name: Running go mod tidy 30 | run: | 31 | go mod tidy 32 | - name: Install Dependencies 33 | run: | 34 | make tools 35 | - name: Running make gen 36 | run: | 37 | make gen 38 | - name: Check for changes 39 | run: | 40 | git diff --exit-code 41 | git status --porcelain 42 | test -z "$(git status --porcelain)" 43 | -------------------------------------------------------------------------------- /db_unexported_test.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) HashiCorp, Inc. 2 | // SPDX-License-Identifier: MPL-2.0 3 | 4 | package dbw 5 | 6 | import ( 7 | "bytes" 8 | "errors" 9 | "testing" 10 | 11 | "github.com/hashicorp/go-hclog" 12 | "github.com/jackc/pgconn" 13 | "github.com/stretchr/testify/assert" 14 | "gorm.io/gorm/logger" 15 | ) 16 | 17 | func TestDB_Debug(t *testing.T) { 18 | tests := []struct { 19 | name string 20 | enable bool 21 | }{ 22 | {name: "enabled", enable: true}, 23 | {name: "disabled"}, 24 | } 25 | for _, tt := range tests { 26 | t.Run(tt.name, func(t *testing.T) { 27 | assert := assert.New(t) 28 | db, _ := TestSetup(t) 29 | db.Debug(tt.enable) 30 | if tt.enable { 31 | assert.Equal(db.wrapped.Logger, logger.Default.LogMode(logger.Info)) 32 | } else { 33 | assert.Equal(db.wrapped.Logger, logger.Default.LogMode(logger.Error)) 34 | } 35 | }) 36 | } 37 | } 38 | 39 | func TestDB_gormLogger(t *testing.T) { 40 | var buf bytes.Buffer 41 | l := getGormLogger( 42 | hclog.New(&hclog.LoggerOptions{ 43 | Level: hclog.Trace, 44 | Output: &buf, 45 | }), 46 | ) 47 | t.Run("no-output", func(t *testing.T) { 48 | l.Printf("not a pgerror", "value 0 placeholder", errors.New("test"), "values 2 placeholder") 
49 | assert.Empty(t, buf.Bytes()) 50 | }) 51 | t.Run("output", func(t *testing.T) { 52 | l.Printf("is a pgerror", "value 0 placeholder", &pgconn.PgError{}, "values 2 placeholder") 53 | assert.NotEmpty(t, buf.Bytes()) 54 | }) 55 | } 56 | -------------------------------------------------------------------------------- /docs/README_RW.md: -------------------------------------------------------------------------------- 1 | # Readers and Writers 2 | [![Go 3 | Reference](https://pkg.go.dev/badge/github.com/hashicorp/go-dbw.svg)](https://pkg.go.dev/github.com/hashicorp/go-dbw) 4 | 5 | [RW](https://pkg.go.dev/github.com/hashicorp/go-dbw#RW) provides a type which 6 | implements both the 7 | [dbw.Reader](https://pkg.go.dev/github.com/hashicorp/go-dbw#Reader) and 8 | [dbw.Writer](https://pkg.go.dev/github.com/hashicorp/go-dbw#Writer) interfaces. Many 9 | [RWs](https://pkg.go.dev/github.com/hashicorp/go-dbw#RW) 10 | can (and likely should) share the same 11 | [dbw.DB](https://pkg.go.dev/github.com/hashicorp/go-dbw#DB), since the 12 | [dbw.DB](https://pkg.go.dev/github.com/hashicorp/go-dbw#DB) 13 | is responsible for connection pooling. 14 | 15 | ```go 16 | db, _ := dbw.Open(dbw.Sqlite, url) 17 | rw := dbw.New(db) 18 | // now you can use the rw for read/write database operations 19 | ``` 20 | 21 | When required, you can create two `DB`s: one for reading from read replicas and 22 | another for writing to the primary database. In such a scenario, you'd need to 23 | create RWs with the correct DB for either reading or writing.
24 | 25 | ```go 26 | readReplicaDSN := "postgresql://go_db:go_db@reader.hostname:9920/go_db?sslmode=disable" 27 | rdb, err := dbw.Open(dbw.Postgres, readReplicaDSN) 28 | reader := dbw.New(rdb) 29 | 30 | 31 | primaryDSN := "postgresql://go_db:go_db@primary.hostname:9920/go_db?sslmode=disable" 32 | wdb, err := dbw.Open(dbw.Postgres, primaryDSN) 33 | writer := dbw.New(wdb) 34 | ``` 35 | -------------------------------------------------------------------------------- /docs/README_INITFIELDS.md: -------------------------------------------------------------------------------- 1 | # NonCreatable and NonUpdatable fields 2 | [![Go 3 | Reference](https://pkg.go.dev/badge/github.com/hashicorp/go-dbw.svg)](https://pkg.go.dev/github.com/hashicorp/go-dbw) 4 | 5 | `dbw` provides a set of functions which allow you to define sets of fields 6 | which cannot be set using 7 | [RW.Create(...)](https://pkg.go.dev/github.com/hashicorp/go-dbw#RW.Create) or 8 | updated via 9 | [RW.Update(...)](https://pkg.go.dev/github.com/hashicorp/go-dbw#RW.Update). To 10 | be clear, errors are not raised if you mistakenly try to set/update these 11 | fields, but rather `dbw` quietly removes the set/update of these fields before 12 | generating the sql to send along to the database.
13 | 14 | For more details see: 15 | * [InitNonCreatableFields](https://pkg.go.dev/github.com/hashicorp/go-dbw#InitNonCreatableFields) 16 | * [InitNonUpdatableFields](https://pkg.go.dev/github.com/hashicorp/go-dbw#InitNonUpdatableFields) 17 | * [NonCreatableFields](https://pkg.go.dev/github.com/hashicorp/go-dbw#NonCreatableFields) 18 | * [NonUpdatableFields](https://pkg.go.dev/github.com/hashicorp/go-dbw#NonUpdatableFields) 19 | 20 | ```go 21 | // initialize fields which cannot be set during creation 22 | dbw.InitNonCreatableFields([]string{"CreateTime", "UpdateTime"}) 23 | // read the current set of non-creatable fields 24 | fields := dbw.NonCreatableFields() 25 | 26 | // initialize fields which cannot be updated 27 | dbw.InitNonUpdatableFields([]string{"PublicId", "CreateTime", "UpdateTime"}) 28 | // read the current set of non-updatable fields 29 | fields = dbw.NonUpdatableFields() 30 | ``` -------------------------------------------------------------------------------- /query.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) HashiCorp, Inc. 2 | // SPDX-License-Identifier: MPL-2.0 3 | 4 | package dbw 5 | 6 | import ( 7 | "context" 8 | "database/sql" 9 | "fmt" 10 | ) 11 | 12 | // Query will run the raw query and return the *sql.Rows results. Query will 13 | // operate within the context of any ongoing transaction for the Reader. The 14 | // caller must close the returned *sql.Rows. Query can/should be used in 15 | // combination with ScanRows. The WithDebug option is supported. 16 | func (rw *RW) Query(ctx context.Context, sql string, values []interface{}, opt ...Option) (*sql.Rows, error) { 17 | const op = "dbw.Query" 18 | if rw.underlying == nil { 19 | return nil, fmt.Errorf("%s: missing underlying db: %w", op, ErrInternal) 20 | } 21 | if sql == "" { 22 | return nil, fmt.Errorf("%s: missing sql: %w", op, ErrInvalidParameter) 23 | } 24 | opts := GetOpts(opt...) 
25 | db := rw.underlying.wrapped.WithContext(ctx) 26 | if opts.WithDebug { 27 | db = db.Debug() 28 | } 29 | db = db.Raw(sql, values...) 30 | if db.Error != nil { 31 | return nil, fmt.Errorf("%s: %w", op, db.Error) 32 | } 33 | return db.Rows() 34 | } 35 | 36 | // ScanRows will scan the rows into the interface 37 | func (rw *RW) ScanRows(rows *sql.Rows, result interface{}) error { 38 | const op = "dbw.ScanRows" 39 | if rw.underlying == nil { 40 | return fmt.Errorf("%s: missing underlying db: %w", op, ErrInternal) 41 | } 42 | if rows == nil { 43 | return fmt.Errorf("%s: missing rows: %w", op, ErrInvalidParameter) 44 | } 45 | if isNil(result) { 46 | return fmt.Errorf("%s: missing result: %w", op, ErrInvalidParameter) 47 | } 48 | return rw.underlying.wrapped.ScanRows(rows, result) 49 | } 50 | -------------------------------------------------------------------------------- /docs/README_QUERY.md: -------------------------------------------------------------------------------- 1 | # Queries 2 | [![Go 3 | Reference](https://pkg.go.dev/badge/github.com/hashicorp/go-dbw.svg)](https://pkg.go.dev/github.com/hashicorp/go-dbw) 4 | 5 | `dbw` provides quite a few different ways to read resources from a database 6 | (see: [Read operations](./README_READ.md)) 7 | 8 | `dbw` intentionally doesn't support "associations" or try to reinvent sql by providing some sort of pattern for 9 | "building" a query. Instead, `dbw` provides a set of functions for directly issuing SQL to the database and scanning the results back into Go structs. 
10 | 11 | 12 | ## [RW.Query](https://pkg.go.dev/github.com/hashicorp/go-dbw#RW.Query) and [RW.ScanRows(...)](https://pkg.go.dev/github.com/hashicorp/go-dbw#RW.ScanRows) example with a CTE 13 | ```go 14 | where := ` 15 | with user_rentals as ( 16 | select user_id, count(*) as rental_count 17 | from test_rentals 18 | group by user_id) 19 | select u.public_id, u.name, r.rental_count 20 | from test_users u 21 | join user_rentals r 22 | on u.public_id = r.user_id 23 | where name in (@names)` 24 | 25 | rows, err := rw.Query( 26 | context.Background(), 27 | where, 28 | []interface{}{sql.Named("names", []string{"alice", "bob"})}, 29 | ) 30 | defer rows.Close() 31 | for rows.Next() { 32 | user, _ := dbtest.NewTestUser() 33 | _ = rw.ScanRows(rows, &user) 34 | // Do something with the user struct 35 | } 36 | ``` 37 | 38 | ## [RW.Exec](https://pkg.go.dev/github.com/hashicorp/go-dbw#RW.Exec) example 39 | 40 | ```go 41 | where := ` 42 | delete from test_rentals 43 | where user_id not in (select user_id from test_users)` 44 | 45 | rowsAffected, err := rw.Exec( 46 | context.Background(), 47 | where, 48 | nil, 49 | ) 50 | ``` -------------------------------------------------------------------------------- /docs/README_OPEN.md: -------------------------------------------------------------------------------- 1 | # Connecting 2 | [![Go 3 | Reference](https://pkg.go.dev/badge/github.com/hashicorp/go-dbw.svg)](https://pkg.go.dev/github.com/hashicorp/go-dbw) 4 | 5 | dbw has tested, official support for SQLite and Postgres. You can also use it to connect 6 | to any database that has a Gorm V2 driver; Gorm V2 has official drivers for SQLite, Postgres, 7 | MySQL and SQL Server.
8 | 9 | ## SQLite 10 | ```go 11 | import( 12 | "github.com/hashicorp/go-dbw" 13 | ) 14 | 15 | func main() { 16 | db, err := dbw.Open(dbw.Sqlite, "dbw.db") 17 | } 18 | ``` 19 | 20 | ## Postgres 21 | ```go 22 | import( 23 | "github.com/hashicorp/go-dbw" 24 | ) 25 | 26 | func main() { 27 | dsn := "postgresql://go_db:go_db@localhost:9920/go_db?sslmode=disable" 28 | db, err := dbw.Open(dbw.Postgres, dsn) 29 | } 30 | ``` 31 | 32 | ## Any gorm v2 driver or an existing connection 33 | ```go 34 | import( 35 | "database/sql" 36 | "github.com/hashicorp/go-dbw" 37 | "gorm.io/driver/mysql" 38 | ) 39 | 40 | func main() { 41 | dsn := "go_db:go_db@tcp(localhost:3306)/go_db?parseTime=true" 42 | sqlDB, err := sql.Open("mysql", dsn) 43 | db, err := dbw.OpenWith(mysql.New(mysql.Config{ 44 | Conn: sqlDB, 45 | })) 46 | } 47 | ``` 48 | 49 | ## Connection Pooling 50 | 51 | ```go 52 | import( 53 | "github.com/hashicorp/go-dbw" 54 | ) 55 | 56 | func main() { 57 | dsn := "postgresql://go_db:go_db@localhost:9920/go_db?sslmode=disable" 58 | db, err := dbw.Open(dbw.Postgres, dsn, 59 | dbw.WithMaxConnections(20), 60 | dbw.WithMinConnections(2), 61 | ) 62 | sqlDB, err := db.SqlDB() 63 | sqlDB.SetConnMaxLifetime(time.Hour) 64 | sqlDB.SetMaxIdleConns(10) 65 | } 66 | ``` -------------------------------------------------------------------------------- /backoff_test.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) HashiCorp, Inc.
2 | // SPDX-License-Identifier: MPL-2.0 3 | 4 | package dbw 5 | 6 | import ( 7 | "testing" 8 | "time" 9 | 10 | "github.com/stretchr/testify/assert" 11 | ) 12 | 13 | func TestConstBackoff_Duration(t *testing.T) { 14 | tests := []struct { 15 | name string 16 | b ConstBackoff 17 | attempt int 18 | want time.Duration 19 | }{ 20 | { 21 | name: "one", 22 | b: ConstBackoff{DurationMs: 2}, 23 | attempt: 1, 24 | want: time.Millisecond * 2, 25 | }, 26 | { 27 | name: "two", 28 | b: ConstBackoff{DurationMs: 2}, 29 | attempt: 2, // a constant backoff must not grow with the attempt number 30 | want: time.Millisecond * 2, 31 | }, 32 | } 33 | for _, tt := range tests { 34 | t.Run(tt.name, func(t *testing.T) { 35 | assert := assert.New(t) 36 | got := tt.b.Duration(uint(tt.attempt)) 37 | assert.Equal(tt.want, got) 38 | }) 39 | } 40 | } 41 | 42 | func TestExpBackoff_Duration(t *testing.T) { 43 | tests := []struct { 44 | name string 45 | b ExpBackoff 46 | attempt int 47 | want time.Duration 48 | wantRand bool 49 | }{ 50 | { 51 | name: "one", 52 | b: ExpBackoff{testRand: 1}, 53 | attempt: 1, 54 | want: time.Millisecond * 15, 55 | }, 56 | { 57 | name: "two", 58 | b: ExpBackoff{testRand: 2}, 59 | attempt: 1, 60 | want: time.Millisecond * 25, 61 | }, 62 | { 63 | name: "rand", 64 | b: ExpBackoff{}, // no testRand set, so only a non-zero duration is asserted below 65 | attempt: 1, 66 | wantRand: true, 67 | }, 68 | } 69 | for _, tt := range tests { 70 | t.Run(tt.name, func(t *testing.T) { 71 | assert := assert.New(t) 72 | got := tt.b.Duration(uint(tt.attempt)) 73 | if tt.wantRand { 74 | assert.NotZero(got) 75 | return 76 | } 77 | assert.Equal(tt.want, got) 78 | }) 79 | } 80 | } 81 | -------------------------------------------------------------------------------- /transactions_test.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) HashiCorp, Inc.
2 | // SPDX-License-Identifier: MPL-2.0 3 | 4 | package dbw_test 5 | 6 | import ( 7 | "context" 8 | "testing" 9 | 10 | "github.com/hashicorp/go-dbw" 11 | "github.com/hashicorp/go-dbw/internal/dbtest" 12 | "github.com/stretchr/testify/assert" 13 | "github.com/stretchr/testify/require" 14 | ) 15 | 16 | func TestRW_Transactions(t *testing.T) { 17 | t.Parallel() 18 | testCtx := context.Background() 19 | conn, _ := dbw.TestSetup(t) 20 | 21 | t.Run("simple", func(t *testing.T) { 22 | require := require.New(t) 23 | id, err := dbw.NewId("u") 24 | require.NoError(err) 25 | w := dbw.New(conn) 26 | 27 | tx, err := w.Begin(testCtx) 28 | require.NoError(err) 29 | 30 | user, err := dbtest.NewTestUser() 31 | require.NoError(err) 32 | require.NoError(tx.Create(testCtx, &user)) 33 | 34 | user.Name = id // a fresh unique id doubles as a new name value 35 | rowsUpdated, err := tx.Update(testCtx, user, []string{"Name"}, nil) 36 | require.NoError(err) 37 | require.Equal(1, rowsUpdated) 38 | require.NoError(tx.Commit(testCtx)) 39 | }) 40 | t.Run("rollback-success", func(t *testing.T) { 41 | require := require.New(t) 42 | id, err := dbw.NewId("u") 43 | require.NoError(err) 44 | w := dbw.New(conn) 45 | 46 | tx, err := w.Begin(testCtx) 47 | require.NoError(err) 48 | 49 | user, err := dbtest.NewTestUser() 50 | require.NoError(err) 51 | require.NoError(tx.Create(testCtx, &user)) 52 | 53 | user.Name = id 54 | rowsUpdated, err := tx.Update(testCtx, user, []string{"Name"}, nil) 55 | require.NoError(err) 56 | require.Equal(1, rowsUpdated) 57 | require.NoError(tx.Rollback(testCtx)) // NOTE(review): the rollback's effect is not verified (e.g. by a later lookup) — consider asserting it 58 | }) 59 | t.Run("no-transaction", func(t *testing.T) { 60 | assert := assert.New(t) 61 | w := dbw.New(conn) // Begin was never called on w, so Rollback and Commit are expected to fail 62 | assert.Error(w.Rollback(testCtx)) 63 | assert.Error(w.Commit(testCtx)) 64 | }) 65 | } 66 | 
-------------------------------------------------------------------------------- /internal/dbtest/timestamp.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) HashiCorp, Inc.
2 | // SPDX-License-Identifier: MPL-2.0 3 | 4 | package dbtest 5 | 6 | import ( 7 | "database/sql/driver" 8 | "errors" 9 | "math" 10 | "time" 11 | 12 | "google.golang.org/protobuf/types/known/timestamppb" 13 | ) 14 | 15 | // New constructs a new Timestamp from the provided time.Time. 16 | func New(t time.Time) *Timestamp { 17 | return &Timestamp{ 18 | Timestamp: timestamppb.New(t), 19 | } 20 | } 21 | 22 | // Now constructs a new Timestamp from the current time. 23 | func Now() *Timestamp { 24 | return &Timestamp{ 25 | Timestamp: timestamppb.Now(), 26 | } 27 | } 28 | 29 | // AsTime converts ts to a time.Time. 30 | func (ts *Timestamp) AsTime() time.Time { 31 | return ts.Timestamp.AsTime() 32 | } 33 | 34 | var ( 35 | // NegativeInfinityTS defines a value for postgres -infinity 36 | NegativeInfinityTS = time.Date(math.MinInt32, time.January, 1, 0, 0, 0, 0, time.UTC) 37 | // PositiveInfinityTS defines a value for postgres infinity 38 | PositiveInfinityTS = time.Date(math.MaxInt32, time.December, 31, 23, 59, 59, 1e9-1, time.UTC) 39 | ) 40 | 41 | // Scan implements sql.Scanner for protobuf Timestamp. 42 | func (ts *Timestamp) Scan(value interface{}) error { 43 | switch t := value.(type) { 44 | case time.Time: 45 | ts.Timestamp = timestamppb.New(t) // google proto version 46 | case string: 47 | switch value { // NOTE(review): any other string leaves ts unchanged and Scan returns nil — confirm intended 48 | case "-infinity": 49 | ts.Timestamp = timestamppb.New(NegativeInfinityTS) 50 | case "infinity": 51 | ts.Timestamp = timestamppb.New(PositiveInfinityTS) 52 | } 53 | default: 54 | return errors.New("Not a protobuf Timestamp") 55 | } 56 | return nil 57 | } 58 | 59 | // Value implements driver.Valuer for protobuf Timestamp; a nil *Timestamp is written as a nil value (SQL NULL).
60 | func (ts *Timestamp) Value() (driver.Value, error) { 61 | if ts == nil { 62 | return nil, nil 63 | } 64 | return ts.Timestamp.AsTime(), nil 65 | } 66 | 67 | // GormDataType returns the gorm common data type for Timestamp ("timestamp"); required by gorm. 68 | func (ts *Timestamp) GormDataType() string { 69 | return "timestamp" 70 | } 71 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Folders 2 | _obj 3 | _test 4 | .cover 5 | 6 | # IntelliJ IDEA project files 7 | .idea 8 | *.ipr 9 | *.iml 10 | *.iws 11 | 12 | ### Logs ### 13 | logs/ 14 | 15 | ### direnv ### 16 | .envrc 17 | .direnv/ 18 | 19 | ### Temp directories ### 20 | tmp/ 21 | temp/ 22 | 23 | ### Visual Studio ### 24 | .vscode/ 25 | 26 | ### macOS ### 27 | # General 28 | .DS_Store 29 | .AppleDouble 30 | .LSOverride 31 | 32 | # Icon must end with two \r 33 | Icon 34 | 35 | 36 | # Thumbnails 37 | ._* 38 | 39 | ### Git ### 40 | # Created by git for backups.
To disable backups in Git: 41 | # $ git config --global mergetool.keepBackup false 42 | *.orig 43 | 44 | # Created by git when using merge tools for conflicts 45 | *.BACKUP.* 46 | *.BASE.* 47 | *.LOCAL.* 48 | *.REMOTE.* 49 | *_BACKUP_*.txt 50 | *_BASE_*.txt 51 | *_LOCAL_*.txt 52 | *_REMOTE_*.txt 53 | 54 | ### Go ### 55 | # Binaries for programs and plugins 56 | *.exe 57 | *.exe~ 58 | *.dll 59 | *.so 60 | *.dylib 61 | 62 | # Test binary, built with `go test -c` 63 | *.test 64 | 65 | # Output of the go coverage tool, specifically when used with LiteIDE 66 | *.out 67 | 68 | ### Tags ### 69 | # Ignore tags created by etags, ctags, gtags (GNU global) and cscope 70 | TAGS 71 | .TAGS 72 | !TAGS/ 73 | tags 74 | .tags 75 | !tags/ 76 | gtags.files 77 | GTAGS 78 | GRTAGS 79 | GPATH 80 | GSYMS 81 | cscope.files 82 | cscope.out 83 | cscope.in.out 84 | cscope.po.out 85 | 86 | ### Vagrant ### 87 | # General 88 | .vagrant/ 89 | 90 | # Log files (if you are creating logs in debug mode, uncomment this) 91 | # *.log 92 | 93 | ### Vagrant Patch ### 94 | *.box 95 | 96 | ### Vim ### 97 | # Swap 98 | [._]*.s[a-v][a-z] 99 | [._]*.sw[a-p] 100 | [._]s[a-rt-v][a-z] 101 | [._]ss[a-gi-z] 102 | [._]sw[a-p] 103 | 104 | # Session 105 | Session.vim 106 | Sessionx.vim 107 | 108 | # Temporary 109 | .netrwhist 110 | *~ 111 | 112 | # Auto-generated tag files 113 | # Persistent undo 114 | [._]*.un~ 115 | 116 | # Test config file 117 | test*.hcl 118 | 119 | # vim: set filetype=conf : -------------------------------------------------------------------------------- /docs/README_TX.md: -------------------------------------------------------------------------------- 1 | # Transactions 2 | [![Go 3 | Reference](https://pkg.go.dev/badge/github.com/hashicorp/go-dbw.svg)](https://pkg.go.dev/github.com/hashicorp/go-dbw) 4 | 5 | `dbw` supports transactions via 6 | [RW.DoTx(...)](https://pkg.go.dev/github.com/hashicorp/go-dbw#RW.DoTx) which 7 | uses any backoff strategy that implements the 8 | 
[Backoff](https://pkg.go.dev/github.com/hashicorp/go-dbw#Backoff) interface. 9 | There are two backoffs 10 | provided by the package: 11 | [ConstBackoff](https://pkg.go.dev/github.com/hashicorp/go-dbw#ConstBackoff) and 12 | [ExpBackoff](https://pkg.go.dev/github.com/hashicorp/go-dbw#ExpBackoff). 13 | 14 | ```go 15 | // Example with ExpBackoff 16 | retryErrFn := func(_ error) bool { return true } 17 | _, err = rw.DoTx( 18 | context.Background(), 19 | retryErrFn, // retry all errors 20 | 3, // three retries 21 | dbw.ExpBackoff{}, // exponential backoff 22 | func(_ dbw.Reader, w dbw.Writer) error { 23 | // the TxHandler updates the user's name 24 | _, err := w.Update(context.Background(), 25 | user, 26 | []string{"Name"}, 27 | nil, 28 | dbw.WithVersion(&user.Version), 29 | ) 30 | // a non-nil error rolls back the transaction (and may be retried); 31 | // returning nil commits it 32 | return err 33 | }, 34 | ) 35 | if err != nil { 36 | // handle errors from the transaction... 37 | } 38 | ``` 39 | 40 | You can also control the transaction yourself using: 41 | * [RW.Begin(...)](https://pkg.go.dev/github.com/hashicorp/go-dbw#RW.Begin), 42 | * [RW.Rollback(...)](https://pkg.go.dev/github.com/hashicorp/go-dbw#RW.Rollback) 43 | * [RW.Commit(...)](https://pkg.go.dev/github.com/hashicorp/go-dbw#RW.Commit) 44 | 45 | ```go 46 | // begin a transaction 47 | tx, err := rw.Begin(ctx) 48 | 49 | // do some database operations like creating a resource 50 | if err := tx.Create(...); err != nil { 51 | // roll back the transaction if the operation failed 52 | if rbErr := tx.Rollback(ctx); rbErr != nil { 53 | // you'll need to handle rollback errors... perhaps via retry. 54 | }
55 | return err 56 | } 57 | 58 | // commit the transaction if there were no errors 59 | if err := tx.Commit(ctx); err != nil { 60 | // handle commit errors 61 | } 62 | ``` -------------------------------------------------------------------------------- /docs/README_USAGE.md: -------------------------------------------------------------------------------- 1 | # Usage highlights 2 | [![Go 3 | Reference](https://pkg.go.dev/badge/github.com/hashicorp/go-dbw.svg)](https://pkg.go.dev/github.com/hashicorp/go-dbw) 4 | 5 | Just some high-level usage highlights to get you started. Read the [dbw package 6 | docs](https://pkg.go.dev/github.com/hashicorp/go-dbw) for 7 | a complete list of capabilities and their documentation. 8 | 9 | ```go 10 | // initialize fields which cannot be set during creation 11 | dbw.InitNonCreatableFields([]string{"CreateTime", "UpdateTime"}) 12 | 13 | // initialize fields which cannot be updated 14 | dbw.InitNonUpdatableFields([]string{"PublicId", "CreateTime", "UpdateTime"}) 15 | 16 | // errors are intentionally ignored for brevity 17 | db, _ := dbw.Open(dialect, url) 18 | rw := dbw.New(db) 19 | 20 | id, _ := dbw.NewId("u") 21 | user, _ := dbtest.NewTestUser() 22 | user.PublicId = id 23 | _ = rw.Create(context.Background(), user) 24 | foundUser, _ := dbtest.NewTestUser() 25 | foundUser.PublicId = id 26 | _ = rw.LookupBy(context.Background(), foundUser) 27 | 28 | where := ` 29 | with avg_version as ( 30 | select public_id, avg(version) as avg_version_for_user 31 | from test_users 32 | group by public_id) 33 | select u.public_id, u.name, av.avg_version_for_user 34 | from test_users u 35 | join avg_version av 36 | on u.public_id = av.public_id 37 | where name in (@names)` 38 | rows, err := rw.Query( 39 | context.Background(), 40 | where, 41 | []interface{}{sql.Named("names", []string{"alice", "bob"})}, 42 | ) 43 | defer rows.Close() 44 | for rows.Next() { 45 | user, _ := dbtest.NewTestUser() 46 | _ = rw.ScanRows(rows, &user) 47 | // Do something with the user struct 48 | } 49 | 50 |
user.Name = "Alice" 51 | retryErrFn := func(_ error) bool { return true } 52 | _, err = rw.DoTx( 53 | context.Background(), 54 | retryErrFn, // retry all errors 55 | 3, // three retries 56 | dbw.ExpBackoff{}, // exponential backoff 57 | func(_ dbw.Reader, w dbw.Writer) error { 58 | // the TxHandler updates the user's name 59 | _, err := w.Update(context.Background(), 60 | user, 61 | []string{"Name"}, 62 | nil, 63 | dbw.WithVersion(&user.Version), 64 | ) 65 | // a non-nil error rolls back the transaction (and may be retried); 66 | // returning nil commits it 67 | return err 68 | }, 69 | ) 70 | ``` 71 | -------------------------------------------------------------------------------- /docs/README_CREATE.md: -------------------------------------------------------------------------------- 1 | # Create 2 | [![Go 3 | Reference](https://pkg.go.dev/badge/github.com/hashicorp/go-dbw.svg)](https://pkg.go.dev/github.com/hashicorp/go-dbw) 4 | 5 | ## [RW.Create(...)](https://pkg.go.dev/github.com/hashicorp/go-dbw#RW.Create) example with one item 6 | ```go 7 | id, err := dbw.NewId("u") 8 | 9 | user := TestUser{PublicId: id, Name: "Alice"} 10 | 11 | var rowsAffected int64 12 | err = rw.Create(ctx, &user, dbw.WithRowsAffected(&rowsAffected)) 13 | ``` 14 | ## [RW.CreateItems(...)](https://pkg.go.dev/github.com/hashicorp/go-dbw#RW.CreateItems) example with multiple items 15 | ```go 16 | var rowsAffected int64 17 | err = rw.CreateItems(ctx, []*dbtest.TestUser{&user1, &user2}, dbw.WithRowsAffected(&rowsAffected)) 18 | ``` 19 | 20 | 21 | ## [OnConflict](https://pkg.go.dev/github.com/hashicorp/go-dbw#WithOnConflict) upsert example 22 | 23 | Upserts via a variety of conflict targets and actions are supported.
24 | 25 | ```go 26 | // set columns 27 | onConflict := dbw.OnConflict{ 28 | Target: dbw.Columns{"public_id"}, 29 | Action: dbw.SetColumns([]string{"name"}), 30 | } 31 | rw.Create(ctx, &user, dbw.WithOnConflict(&onConflict)) 32 | ``` 33 | 34 | ```go 35 | // set columns and column values 36 | onConflict := dbw.OnConflict{ 37 | Target: dbw.Columns{"public_id"}, 38 | } 39 | cv := dbw.SetColumns([]string{"name"}) 40 | cv = append( 41 | cv, 42 | dbw.SetColumnValues(map[string]interface{}{ 43 | "email": "alice@gmail.com", 44 | "phone_number": dbw.Expr("NULL"), 45 | })...) 46 | onConflict.Action = cv 47 | rw.Create(ctx, &user, dbw.WithOnConflict(&onConflict)) 48 | ``` 49 | 50 | ```go 51 | // do nothing 52 | onConflict := dbw.OnConflict{ 53 | Target: dbw.Columns{"public_id"}, 54 | Action: dbw.DoNothing(true), 55 | } 56 | rw.Create(ctx, &user, dbw.WithOnConflict(&onConflict)) 57 | ``` 58 | 59 | ```go 60 | // on constraint 61 | onConflict := dbw.OnConflict{ 62 | Target: dbw.Constraint("db_test_user_pkey"), 63 | Action: dbw.SetColumns([]string{"name"}), 64 | } 65 | rw.Create(ctx, &user, dbw.WithOnConflict(&onConflict)) 66 | ``` 67 | 68 | ```go 69 | // set columns combined with WithVersion 70 | onConflict := dbw.OnConflict{ 71 | Target: dbw.Columns{"public_id"}, 72 | Action: dbw.SetColumns([]string{"name"}), 73 | } 74 | version := uint32(1) 75 | rw.Create(ctx, &user, dbw.WithOnConflict(&onConflict), dbw.WithVersion(&version)) 76 | ``` 77 | 78 | -------------------------------------------------------------------------------- /docs/README_MODELS.md: -------------------------------------------------------------------------------- 1 | # Declaring Models 2 | [![Go 3 | Reference](https://pkg.go.dev/badge/github.com/hashicorp/go-dbw.svg)](https://pkg.go.dev/github.com/hashicorp/go-dbw) 4 | 5 | Models are structs with basic Go types, or custom types implementing [Scanner](https://pkg.go.dev/database/sql#Scanner) and 6 | [Valuer](https://pkg.go.dev/database/sql/driver#Valuer) interfaces.
Currently, 7 | Gorm V2 [Field Tags](https://gorm.io/docs/models.html#Fields-Tags) are supported when declaring models. 8 | 9 | Simple example: 10 | 11 | ```go 12 | type TestUser struct { 13 | PublicId string `json:"public_id,omitempty" gorm:"primaryKey;default:null"` 14 | CreateTime *time.Time `json:"create_time,omitempty" gorm:"default:CURRENT_TIMESTAMP"` 15 | UpdateTime *time.Time `json:"update_time,omitempty" gorm:"default:CURRENT_TIMESTAMP"` 16 | Name string `json:"name,omitempty" gorm:"default:null"` 17 | Email string `json:"email,omitempty" gorm:"default:null"` 18 | PhoneNumber string `json:"phone_number,omitempty" gorm:"default:null"` 19 | Version uint32 `json:"version,omitempty" gorm:"default:null"` 20 | } 21 | 22 | ``` 23 | A more complicated example that uses an embedded protobuf: 24 | ```go 25 | type TestUser struct { 26 | *StoreTestUser 27 | } 28 | 29 | // StoreTestUser model 30 | type StoreTestUser struct { 31 | state protoimpl.MessageState 32 | sizeCache protoimpl.SizeCache 33 | unknownFields protoimpl.UnknownFields 34 | 35 | // @inject_tag: gorm:"primaryKey;default:null" 36 | PublicId string `protobuf:"bytes,4,opt,name=public_id,json=publicId,proto3" json:"public_id,omitempty" gorm:"primaryKey;default:null"` 37 | 38 | // @inject_tag: `gorm:"default:CURRENT_TIMESTAMP"` 39 | CreateTime *Timestamp `protobuf:"bytes,2,opt,name=create_time,json=createTime,proto3" json:"create_time,omitempty" gorm:"default:CURRENT_TIMESTAMP"` 40 | 41 | // @inject_tag: `gorm:"default:CURRENT_TIMESTAMP"` 42 | UpdateTime *Timestamp `protobuf:"bytes,3,opt,name=update_time,json=updateTime,proto3" json:"update_time,omitempty" gorm:"default:CURRENT_TIMESTAMP"` 43 | 44 | // @inject_tag: `gorm:"default:null"` 45 | Name string `protobuf:"bytes,5,opt,name=name,proto3" json:"name,omitempty" gorm:"default:null"` 46 | 47 | // @inject_tag: `gorm:"default:null"` 48 | Version uint32 `protobuf:"varint,8,opt,name=version,proto3" json:"version,omitempty" 
gorm:"default:null"` 49 | } 50 | 51 | ``` 52 |
-------------------------------------------------------------------------------- /do_tx.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) HashiCorp, Inc. 2 | // SPDX-License-Identifier: MPL-2.0 3 | 4 | package dbw 5 | 6 | import ( 7 | "context" 8 | "fmt" 9 | "time" 10 | ) 11 | 12 | // DoTx runs handler inside a new transaction and commits it when handler returns nil. When handler 13 | // fails, the transaction is rolled back and, if retryErrorsMatchingFn(err) reports true, retried after waiting 14 | // backOff.Duration(attempt), for at most retries+1 attempts in total. Objects written in your TxHandler must be retryable: 15 | // they may be sent to the db several times (retried), so things like the primary key may need to be reset before retry. 16 | func (rw *RW) DoTx(ctx context.Context, retryErrorsMatchingFn func(error) bool, retries uint, backOff Backoff, handler TxHandler) (RetryInfo, error) { 17 | const op = "dbw.DoTx" 18 | if rw.underlying == nil { 19 | return RetryInfo{}, fmt.Errorf("%s: missing underlying db: %w", op, ErrInvalidParameter) 20 | } 21 | if backOff == nil { 22 | return RetryInfo{}, fmt.Errorf("%s: missing backoff: %w", op, ErrInvalidParameter) 23 | } 24 | if handler == nil { 25 | return RetryInfo{}, fmt.Errorf("%s: missing handler: %w", op, ErrInvalidParameter) 26 | } 27 | if retryErrorsMatchingFn == nil { 28 | return RetryInfo{}, fmt.Errorf("%s: missing retry errors matching function: %w", op, ErrInvalidParameter) 29 | } 30 | info := RetryInfo{} 31 | for attempts := uint(1); ; attempts++ { 32 | if attempts > retries+1 { 33 | return info, fmt.Errorf("%s: too many retries: %d of %d: %w", op, attempts-1, retries+1, ErrMaxRetries) 34 | } 35 | 36 | // start a fresh transaction for this attempt
37 | newTx := rw.underlying.wrapped.WithContext(ctx) 38 | newTx = newTx.Begin() // NOTE(review): newTx.Error is not checked here, so a failed Begin still reaches handler — confirm 39 | 40 | newRW := &RW{underlying: &DB{newTx}} 41 | if err := handler(newRW, newRW); err != nil { 42 | if err := newTx.Rollback().Error; err != nil { // NOTE(review): this err shadows the handler's err, so a failed rollback reports only the rollback error 43 | return info, fmt.Errorf("%s: %w", op, err) 44 | } 45 | if retry := retryErrorsMatchingFn(err); retry { 46 | d := backOff.Duration(attempts) 47 | info.Retries++ 48 | info.Backoff = info.Backoff + d 49 | select { 50 | case <-ctx.Done(): 51 | return info, fmt.Errorf("%s: cancelled: %w", op, err) // wraps the handler's error, not ctx.Err() 52 | case <-time.After(d): 53 | continue 54 | } 55 | } 56 | return info, fmt.Errorf("%s: %w", op, err) 57 | } 58 | 59 | if err := newTx.Commit().Error; err != nil { 60 | if err := newTx.Rollback().Error; err != nil { // NOTE(review): the tx is probably already finished after a failed Commit, so this Rollback likely fails too and its error, not the commit error, is what gets wrapped — confirm 61 | return info, fmt.Errorf("%s: %w", op, err) 62 | } 63 | return info, fmt.Errorf("%s: %w", op, err) 64 | } 65 | return info, nil // handler succeeded and the transaction committed 66 | } 67 | } 68 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/hashicorp/go-dbw 2 | 3 | go 1.23.0 4 | 5 | require ( 6 | github.com/DATA-DOG/go-sqlmock v1.5.2 7 | github.com/favadi/protoc-go-inject-tag v1.3.0 8 | github.com/google/go-cmp v0.6.0 9 | github.com/hashicorp/go-hclog v1.6.3 10 | github.com/hashicorp/go-secure-stdlib/base62 v0.1.2 11 | github.com/jackc/pgconn v1.14.3 12 | github.com/jackc/pgx/v5 v5.7.4 13 | github.com/oligot/go-mod-upgrade v0.6.1 14 | github.com/stretchr/testify v1.10.0 15 | github.com/xo/dburl v0.23.7 16 | 
golang.org/x/crypto v0.37.0 17 | golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d 18 | google.golang.org/protobuf v1.34.2 19 | gorm.io/driver/postgres v1.5.11 20 | gorm.io/driver/sqlite v1.5.7 21 | gorm.io/gorm v1.25.12 22 | mvdan.cc/gofumpt v0.2.0 23 | ) 24 | 25 | require ( 26 | github.com/AlecAivazis/survey/v2 v2.2.9 // indirect 27 | github.com/Masterminds/semver/v3 v3.1.1 // indirect 28 |
github.com/apex/log v1.9.0 // indirect 29 | github.com/cpuguy83/go-md2man/v2 v2.0.0 // indirect 30 | github.com/davecgh/go-spew v1.1.1 // indirect 31 | github.com/fatih/color v1.18.0 // indirect 32 | github.com/hashicorp/go-uuid v1.0.3 // indirect 33 | github.com/jackc/chunkreader/v2 v2.0.1 // indirect 34 | github.com/jackc/pgio v1.0.0 // indirect 35 | github.com/jackc/pgpassfile v1.0.0 // indirect 36 | github.com/jackc/pgproto3/v2 v2.3.3 // indirect 37 | github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect 38 | github.com/jackc/puddle/v2 v2.2.2 // indirect 39 | github.com/jinzhu/inflection v1.0.0 // indirect 40 | github.com/jinzhu/now v1.1.5 // indirect 41 | github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51 // indirect 42 | github.com/kr/pty v1.1.8 // indirect 43 | github.com/mattn/go-colorable v0.1.14 // indirect 44 | github.com/mattn/go-isatty v0.0.20 // indirect 45 | github.com/mattn/go-sqlite3 v1.14.28 // indirect 46 | github.com/mgutz/ansi v0.0.0-20170206155736-9520e82c474b // indirect 47 | github.com/pkg/errors v0.9.1 // indirect 48 | github.com/pmezard/go-difflib v1.0.0 // indirect 49 | github.com/russross/blackfriday/v2 v2.0.1 // indirect 50 | github.com/shurcooL/sanitized_anchor_name v1.0.0 // indirect 51 | github.com/urfave/cli/v2 v2.3.0 // indirect 52 | golang.org/x/mod v0.17.0 // indirect 53 | golang.org/x/sync v0.13.0 // indirect 54 | golang.org/x/sys v0.32.0 // indirect 55 | golang.org/x/term v0.31.0 // indirect 56 | golang.org/x/text v0.24.0 // indirect 57 | gopkg.in/yaml.v3 v3.0.1 // indirect 58 | ) 59 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | # Determine this makefile's path. 2 | # Be sure to place this BEFORE `include` directives, if any. 
3 | THIS_FILE := $(lastword $(MAKEFILE_LIST)) 4 | THIS_DIR := $(dir $(realpath $(firstword $(MAKEFILE_LIST)))) 5 | 6 | TMP_DIR := $(shell mktemp -d) 7 | REPO_PATH := github.com/hashicorp/dbw 8 | # NOTE(review): REPO_PATH does not match the go.mod module path (github.com/hashicorp/go-dbw) — confirm 9 | .PHONY: tools 10 | tools: 11 | go generate -tags tools tools/tools.go 12 | go install github.com/bufbuild/buf/cmd/buf@v1.15.1 13 | go install github.com/hashicorp/copywrite@v0.15.0 14 | 15 | .PHONY: fmt 16 | fmt: 17 | gofumpt -w $$(find . -name '*.go' ! -name '*pb.go') 18 | buf format -w 19 | 20 | .PHONY: copywrite 21 | copywrite: 22 | copywrite headers 23 | 24 | .PHONY: gen 25 | gen: proto fmt copywrite 26 | 27 | .PHONY: test 28 | test: 29 | go test -race -count=1 ./... 30 | 31 | .PHONY: test-all 32 | test-all: test-sqlite test-postgres 33 | 34 | .PHONY: test-sqlite 35 | test-sqlite: 36 | DB_DIALECT=sqlite go test -race -count=1 ./... 37 | 38 | .PHONY: test-postgres 39 | test-postgres: 40 | ############################################################## 41 | # this test is dependent on first running: docker-compose up 42 | ############################################################## 43 | DB_DIALECT=postgres DB_DSN="postgresql://go_db:go_db@localhost:9920/go_db?sslmode=disable" go test -race -count=1 ./... 44 | 45 | ### db tags require protoc-gen-go v1.20.0 or later 46 | # GO111MODULE=on go get -u github.com/golang/protobuf/protoc-gen-go@v1.40 47 | .PHONY: proto 48 | proto: protolint protobuild 49 | 50 | .PHONY: protobuild 51 | protobuild: 52 | buf generate 53 | @protoc-go-inject-tag -input=./internal/dbtest/dbtest.pb.go 54 | 55 | .PHONY: protolint 56 | protolint: 57 | @buf lint 58 | # if/when this becomes a public repo, we can add this check 59 | # @buf check breaking --against 60 | # 'https://github.com/hashicorp/go-dbw.git#branch=main' 61 | 62 | # coverage-diff will run a new coverage report and check coverage.log to see if 63 | # the coverage has changed.
64 | .PHONY: coverage-diff 65 | coverage-diff: 66 | cd coverage && \ 67 | ./coverage.sh && \ 68 | ./cov-diff.sh coverage.log && \ 69 | if ./cov-diff.sh ./coverage.log; then git restore coverage.log; fi 70 | 71 | # coverage will generate a report, badge and log. When you make changes, run 72 | # this and check in the changes to publish a new/latest coverage report and 73 | # badge. 74 | .PHONY: coverage 75 | coverage: 76 | cd coverage && \ 77 | ./coverage.sh && \ 78 | if ./cov-diff.sh ./coverage.log; then git restore coverage.log; fi 79 | -------------------------------------------------------------------------------- /lookup.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) HashiCorp, Inc. 2 | // SPDX-License-Identifier: MPL-2.0 3 | 4 | package dbw 5 | 6 | import ( 7 | "context" 8 | "fmt" 9 | 10 | "gorm.io/gorm" 11 | ) 12 | 13 | // LookupBy will look up a resource by its primary keys, which must be 14 | // unique. If the resource implements either ResourcePublicIder or 15 | // ResourcePrivateIder interface, then they are used as the resource's 16 | // primary key for lookup. Otherwise, the resource tags are used to 17 | // determine its primary key(s) for lookup. The WithDebug and WithTable 18 | // options are supported. 19 | func (rw *RW) LookupBy(ctx context.Context, resourceWithIder interface{}, opt ...Option) error { 20 | const op = "dbw.LookupById" // NOTE(review): op names LookupById although the method is LookupBy; it appears in returned error text, so confirm before renaming 21 | if rw.underlying == nil { 22 | return fmt.Errorf("%s: missing underlying db: %w", op, ErrInvalidParameter) 23 | } 24 | if err := raiseErrorOnHooks(resourceWithIder); err != nil { 25 | return fmt.Errorf("%s: %w", op, err) 26 | } 27 | if err := validateResourcesInterface(resourceWithIder); err != nil { 28 | return fmt.Errorf("%s: %w", op, err) 29 | } 30 | where, keys, err := rw.primaryKeysWhere(ctx, resourceWithIder) 31 | if err != nil { 32 | return fmt.Errorf("%s: 
%w", op, err) 33 | } 34 | opts := GetOpts(opt...)
35 | db := rw.underlying.wrapped.WithContext(ctx) 36 | if opts.WithTable != "" { 37 | db = db.Table(opts.WithTable) 38 | } 39 | if opts.WithDebug { 40 | db = db.Debug() 41 | } 42 | rw.clearDefaultNullResourceFields(ctx, resourceWithIder) 43 | if err := db.Where(where, keys...).First(resourceWithIder).Error; err != nil { 44 | if err == gorm.ErrRecordNotFound { // NOTE(review): errors.Is would also match a wrapped ErrRecordNotFound 45 | return fmt.Errorf("%s: %w", op, ErrRecordNotFound) 46 | } 47 | return fmt.Errorf("%s: %w", op, err) 48 | } 49 | return nil 50 | } 51 | 52 | // LookupByPublicId will look up resource by its public_id, which must be unique. 53 | // It delegates to LookupBy, so the WithDebug and WithTable options are supported. 54 | func (rw *RW) LookupByPublicId(ctx context.Context, resource ResourcePublicIder, opt ...Option) error { 55 | return rw.LookupBy(ctx, resource, opt...) 56 | } 57 | 58 | func (rw *RW) lookupAfterWrite(ctx context.Context, i interface{}, opt ...Option) error { 59 | const op = "dbw.lookupAfterWrite" 60 | opts := GetOpts(opt...) 61 | withLookup := opts.WithLookup 62 | if err := raiseErrorOnHooks(i); err != nil { // checked even when no lookup is requested 63 | return fmt.Errorf("%s: %w", op, err) 64 | } 65 | if !withLookup { 66 | return nil 67 | } 68 | if err := rw.LookupBy(ctx, i, opt...); err != nil { 69 | return fmt.Errorf("%s: %w", op, err) 70 | } 71 | return nil 72 | } 73 | -------------------------------------------------------------------------------- /scripts/protoc_gen_plugin.bash: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # Copyright (c) HashiCorp, Inc.
3 | # SPDX-License-Identifier: MPL-2.0 4 | 5 | 6 | # Copied with minor changes from Makego at https://github.com/bufbuild/makego/blob/master/make/go/scripts/protoc_gen_plugin.bash 7 | 8 | set -eo pipefail 9 | 10 | fail() { # NOTE(review): fail is not called anywhere in this script 11 | echo "$@" >&2 12 | exit 1 13 | } 14 | 15 | usage() { 16 | echo "usage: ${0} \ 17 | --proto_path=path/to/one \ 18 | --proto_path=path/to/two \ 19 | --proto_include_path=path/to/one \ 20 | --proto_include_path=path/to/two \ 21 | --plugin_name=go \ 22 | --plugin_out=gen/proto/go \ 23 | --plugin_opt=plugins=grpc" 24 | } 25 | 26 | check_flag_value_set() { 27 | if [ -z "${1}" ]; then 28 | usage 29 | exit 1 30 | fi 31 | } 32 | 33 | PROTO_PATHS=() 34 | PROTO_INCLUDE_PATHS=() 35 | PLUGIN_NAME= 36 | PLUGIN_OUT= 37 | PLUGIN_OPT= 38 | while test $# -gt 0; do 39 | case "${1}" in 40 | -h|--help) 41 | usage 42 | exit 0 43 | ;; 44 | --proto_path*) 45 | PROTO_PATHS+=("$(echo ${1} | sed -e 's/^[^=]*=//g')") 46 | shift 47 | ;; 48 | --proto_include_path*) 49 | PROTO_INCLUDE_PATHS+=("$(echo ${1} | sed -e 's/^[^=]*=//g')") 50 | shift 51 | ;; 52 | --plugin_name*) 53 | PLUGIN_NAME="$(echo ${1} | sed -e 's/^[^=]*=//g')" 54 | shift 55 | ;; 56 | --plugin_out*) 57 | PLUGIN_OUT="$(echo ${1} | sed -e 's/^[^=]*=//g')" 58 | shift 59 | ;; 60 | --plugin_opt*) 61 | PLUGIN_OPT="$(echo ${1} | sed -e 's/^[^=]*=//g')" 62 | shift 63 | ;; 64 | *) 65 | usage 66 | exit 1 67 | ;; 68 | esac 69 | done 70 | 71 | check_flag_value_set "${PROTO_PATHS[@]}" # only the first path is inspected; an empty PROTO_PATHS fails the check 72 | check_flag_value_set "${PLUGIN_NAME}" 73 | check_flag_value_set "${PLUGIN_OUT}" 74 | 75 | PROTOC_FLAGS=() 76 | for proto_path in "${PROTO_PATHS[@]}"; do 77 | PROTOC_FLAGS+=("--proto_path=${proto_path}") 78 | done 79 | for proto_path in "${PROTO_INCLUDE_PATHS[@]}"; do 80 | PROTOC_FLAGS+=("--proto_path=${proto_path}") 81 | done 82 | PROTOC_FLAGS+=("--${PLUGIN_NAME}_out=${PLUGIN_OUT}") 83 | if [ -n "${PLUGIN_OPT}" 
]; then 84 | PROTOC_FLAGS+=("--${PLUGIN_NAME}_opt=${PLUGIN_OPT}") 85 | fi 86 | 87 | for proto_path in "${PROTO_PATHS[@]}"; do
88 | for dir in $(find "${proto_path}" -name '*.proto' -print0 | xargs -0 -n1 dirname | sort | uniq); do 89 | echo protoc "${PROTOC_FLAGS[@]}" $(find "${dir}" -name '*.proto') 90 | protoc "${PROTOC_FLAGS[@]}" $(find "${dir}" -name '*.proto') 91 | done 92 | done 93 | -------------------------------------------------------------------------------- /reader.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) HashiCorp, Inc. 2 | // SPDX-License-Identifier: MPL-2.0 3 | 4 | package dbw 5 | 6 | import ( 7 | "context" 8 | "database/sql" 9 | ) 10 | 11 | // Reader interface defines lookups/searching for resources 12 | type Reader interface { 13 | // LookupBy will look up a resource by its primary keys, which must be 14 | // unique. If the resource implements either ResourcePublicIder or 15 | // ResourcePrivateIder interface, then they are used as the resource's 16 | // primary key for lookup. Otherwise, the resource tags are used to 17 | // determine its primary key(s) for lookup. 18 | LookupBy(ctx context.Context, resource interface{}, opt ...Option) error 19 | 20 | // LookupByPublicId will look up resource by its public_id, which must be unique. 21 | LookupByPublicId(ctx context.Context, resource ResourcePublicIder, opt ...Option) error 22 | 23 | // LookupWhere will look up and return the first resource using a where clause with parameters. 24 | LookupWhere(ctx context.Context, resource interface{}, where string, args []interface{}, opt ...Option) error 25 | 26 | // SearchWhere will search for all the resources it can find using a where 27 | // clause with parameters. Supports the WithLimit option. If 28 | // WithLimit < 0, then unlimited results are returned. If WithLimit == 0, then 29 | // default limits are used for results. 30 | SearchWhere(ctx context.Context, resources interface{}, where string, args []interface{}, opt ...Option) error 31 | 32 | // Query will run the raw query and return the *sql.Rows results.
Query will 33 | // operate within the context of any ongoing transaction for the dbw.Reader. The 34 | // caller must close the returned *sql.Rows. Query can/should be used in 35 | // combination with ScanRows. 36 | Query(ctx context.Context, sql string, values []interface{}, opt ...Option) (*sql.Rows, error) 37 | 38 | // ScanRows will scan sql rows into the result provided. 39 | ScanRows(rows *sql.Rows, result interface{}) error 40 | 41 | // Dialect returns the dialect and raw connection name of the underlying database. 42 | Dialect() (_ DbType, rawName string, _ error) 43 | } 44 | 45 | // ResourcePublicIder defines an interface that LookupByPublicId() and 46 | // LookupBy() can use to get the resource's public id. 47 | type ResourcePublicIder interface { 48 | GetPublicId() string 49 | } 50 | 51 | // ResourcePrivateIder defines an interface that LookupBy() can use to get the 52 | // resource's private id. 53 | type ResourcePrivateIder interface { 54 | GetPrivateId() string 55 | } 56 | -------------------------------------------------------------------------------- /.github/workflows/go.yml: -------------------------------------------------------------------------------- 1 | name: Go 2 | 3 | on: 4 | push: 5 | branches: [main] 6 | pull_request: 7 | branches: [main] 8 | 9 | jobs: 10 | # Label of the container job 11 | sqlite: 12 | strategy: 13 | matrix: 14 | go: ["1.24", "1.23", "1.22"] 15 | platform: [ubuntu-latest] # can not run in windows OS 16 | runs-on: ${{ matrix.platform }} 17 | 18 | steps: 19 | - name: Set up Go 1.x 20 | uses: actions/setup-go@4d34df0c2316fe8122ab82dc22947d607c0c91f9 # v4.0.0 21 | with: 22 | go-version: ${{ matrix.go }} 23 | 24 | - name: Check out code into the Go module directory 25 | uses: actions/checkout@8f4b7f84864484a7bf31766abe9204da3cbe65b3 # v3.5.0 26 | 27 | - name: go mod package cache 28 | uses: actions/cache@5a3ec84eff668545956fd18022155c47e93e2684 # v4.2.3 29 | with: 30 | path: ~/go/pkg/mod 31 | key: ${{ runner.os }}-go-${{
matrix.go }}-${{ hashFiles('tests/go.mod') }} 32 | 33 | - name: Tests 34 | run: make test-sqlite 35 | 36 | - name: Coverage 37 | run: | 38 | make coverage-diff 39 | 40 | postgres: 41 | strategy: 42 | matrix: 43 | dbversion: ["postgres:latest"] 44 | go: ["1.24", "1.23", "1.22"] 45 | platform: [ubuntu-latest] # can not run in macOS and Windows 46 | runs-on: ${{ matrix.platform }} 47 | 48 | services: 49 | postgres: 50 | image: ${{ matrix.dbversion }} 51 | env: 52 | POSTGRES_PASSWORD: go_db 53 | POSTGRES_USER: go_db 54 | POSTGRES_DB: go_db 55 | ports: 56 | - 9920:5432 57 | # Set health checks to wait until postgres has started 58 | options: >- 59 | --health-cmd pg_isready 60 | --health-interval 10s 61 | --health-timeout 5s 62 | --health-retries 5 63 | 64 | steps: 65 | - name: Set up Go 1.x 66 | uses: actions/setup-go@4d34df0c2316fe8122ab82dc22947d607c0c91f9 # v4.0.0 67 | with: 68 | go-version: ${{ matrix.go }} 69 | 70 | - name: Check out code into the Go module directory 71 | uses: actions/checkout@8f4b7f84864484a7bf31766abe9204da3cbe65b3 # v3.5.0 72 | 73 | - name: go mod package cache 74 | uses: actions/cache@5a3ec84eff668545956fd18022155c47e93e2684 # v4.2.3 75 | with: 76 | path: ~/go/pkg/mod 77 | key: ${{ runner.os }}-go-${{ matrix.go }}-${{ hashFiles('tests/go.mod') }} 78 | 79 | - name: Tests 80 | run: make test-postgres 81 | -------------------------------------------------------------------------------- /docs/README_UPDATE.md: -------------------------------------------------------------------------------- 1 | # Update 2 | [![Go Reference](https://pkg.go.dev/badge/github.com/hashicorp/go-dbw.svg)](https://pkg.go.dev/github.com/hashicorp/go-dbw) 3 | 4 | [Update(...)](https://pkg.go.dev/github.com/hashicorp/go-dbw#RW.Update) requires 5 | the resource to be updated with its fields set that the caller wants updated. 6 | 7 | A `fieldMask` is optional and provides paths for fields that 8 | should be updated. 
9 | 10 | A `setToNullPaths` is optional and provides paths for the fields that should be 11 | set to null. 12 | 13 | Either `fieldMaskPaths` or `setToNullPaths` must be provided, and they must not intersect. 14 | 15 | The caller is responsible for the transaction life cycle of the writer, and if an 16 | error is returned the caller must decide what to do with the transaction, which 17 | almost always should be to roll back. Update returns the number of rows updated. 18 | 19 | There are lots of supported options: 20 | [WithBeforeWrite](https://pkg.go.dev/github.com/hashicorp/go-dbw#WithBeforeWrite), 21 | [WithAfterWrite](https://pkg.go.dev/github.com/hashicorp/go-dbw#WithAfterWrite), 22 | [WithWhere](https://pkg.go.dev/github.com/hashicorp/go-dbw#WithWhere), 23 | [WithDebug](https://pkg.go.dev/github.com/hashicorp/go-dbw#WithDebug), and 24 | [WithVersion](https://pkg.go.dev/github.com/hashicorp/go-dbw#WithVersion). 25 | 26 | If [WithVersion](https://pkg.go.dev/github.com/hashicorp/go-dbw#WithVersion) is 27 | used, then the update will include the version number in the 28 | update where clause, which basically makes the update use optimistic locking and 29 | the update will only succeed if the existing row's version matches the 30 | [WithVersion](https://pkg.go.dev/github.com/hashicorp/go-dbw#WithVersion) 31 | option. Zero is not a valid value for the 32 | [WithVersion](https://pkg.go.dev/github.com/hashicorp/go-dbw#WithVersion) option 33 | and will return an error. 34 | 35 | [WithWhere](https://pkg.go.dev/github.com/hashicorp/go-dbw#WithWhere) allows 36 | specifying an additional constraint on the operation in 37 | addition to the PKs. 38 | 39 | [WithDebug](https://pkg.go.dev/github.com/hashicorp/go-dbw#WithDebug) will turn 40 | on debugging for the update call.
41 | 42 | ### Simple update [WithVersion](https://pkg.go.dev/github.com/hashicorp/go-dbw#WithVersion) example 43 | ```go 44 | user.Name = "Alice" 45 | rowsAffected, err = rw.Update(ctx, 46 | &user, 47 | []string{"Name"}, 48 | nil, 49 | dbw.WithVersion(&user.Version)) 50 | ``` 51 | 52 | ### Update with setToNullPaths and `WithVersion` example 53 | ```go 54 | user.Name = "Alice" 55 | rowsAffected, err = rw.Update(ctx, 56 | &user, 57 | nil, 58 | []string{"Name"}, 59 | dbw.WithVersion(&user.Version)) 60 | ``` -------------------------------------------------------------------------------- /id_test.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) HashiCorp, Inc. 2 | // SPDX-License-Identifier: MPL-2.0 3 | 4 | package dbw_test 5 | 6 | import ( 7 | "strings" 8 | "testing" 9 | 10 | "github.com/hashicorp/go-dbw" 11 | "github.com/stretchr/testify/assert" 12 | "github.com/stretchr/testify/require" 13 | ) 14 | 15 | func TestNewId(t *testing.T) { 16 | type args struct { 17 | prefix string 18 | } 19 | tests := []struct { 20 | name string 21 | args args 22 | wantErr bool 23 | wantLen int 24 | }{ 25 | { 26 | name: "valid", 27 | args: args{ 28 | prefix: "id", 29 | }, 30 | wantErr: false, 31 | wantLen: 10 + len("id_"), 32 | }, 33 | { 34 | name: "bad-prefix", 35 | args: args{ 36 | prefix: "", 37 | }, 38 | wantErr: true, 39 | wantLen: 0, 40 | }, 41 | } 42 | for _, tt := range tests { 43 | t.Run(tt.name, func(t *testing.T) { 44 | got, err := dbw.NewId(tt.args.prefix) 45 | if (err != nil) != tt.wantErr { 46 | t.Errorf("NewPublicId() error = %v, wantErr %v", err, tt.wantErr) 47 | return 48 | } 49 | if !tt.wantErr && !strings.HasPrefix(got, tt.args.prefix+"_") { 50 | t.Errorf("NewPublicId() = %v, wanted it to start with %v", got, tt.args.prefix) 51 | } 52 | if len(got) != tt.wantLen { 53 | t.Errorf("NewPublicId() = %v, with len of %d and wanted len of %v", got, len(got), tt.wantLen) 54 | } 55 | }) 56 | } 57 | } 58 | 59 | func 
TestPseudoRandomId(t *testing.T) { 60 | type args struct { 61 | prngValues []string 62 | } 63 | tests := []struct { 64 | name string 65 | args args 66 | sameAsPrev bool 67 | }{ 68 | { 69 | name: "valid first", 70 | args: args{}, 71 | }, 72 | { 73 | name: "valid second", 74 | args: args{}, 75 | }, 76 | { 77 | name: "first prng", 78 | args: args{prngValues: []string{"foo", "bar"}}, 79 | }, 80 | { 81 | name: "first prng verify", 82 | args: args{prngValues: []string{"foo", "bar"}}, 83 | sameAsPrev: true, 84 | }, 85 | { 86 | name: "second prng", 87 | args: args{prngValues: []string{"bar", "foo"}}, 88 | }, 89 | { 90 | name: "second prng verify", 91 | args: args{prngValues: []string{"bar", "foo"}}, 92 | sameAsPrev: true, 93 | }, 94 | } 95 | var prevTestValue string 96 | for _, tt := range tests { 97 | t.Run(tt.name, func(t *testing.T) { 98 | assert, require := assert.New(t), require.New(t) 99 | got, err := dbw.NewId("id", dbw.WithPrngValues(tt.args.prngValues)) 100 | require.NoError(err) 101 | if tt.sameAsPrev { 102 | assert.Equal(prevTestValue, got) 103 | } 104 | prevTestValue = got 105 | }) 106 | } 107 | } 108 | -------------------------------------------------------------------------------- /coverage/coverage.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | set -euo pipefail 3 | 4 | # Copyright (c) 2023 Nuno Cruces 5 | 6 | # This is a simple script to generate an HTML coverage report, 7 | # and SVG badge for your Go project. 8 | # 9 | # It's meant to be used manually or as a pre-commit hook. 10 | # 11 | # Place it some where in your code tree and execute it. 12 | # If your tests pass, next to the script you'll find 13 | # the coverage.html report and coverage.svg badge. 14 | # 15 | # You can add the badge to your README.md as such: 16 | # [![Go Coverage](PATH_TO/coverage.svg)](https://raw.githack.com/URL/coverage.html) 17 | # 18 | # Visit https://raw.githack.com/ to find the correct URL. 
19 | # 20 | # To have the script run as a pre-commmit hook, 21 | # symlink the script to .git/hooks/pre-commit: 22 | # 23 | # ln -s PATH_TO/coverage.sh .git/hooks/pre-commit 24 | # 25 | # Or, if you have other pre-commit hooks, 26 | # call it from your main hook. 27 | 28 | # Get the script's directory after resolving a possible symlink. 29 | SCRIPT_DIR="$(dirname -- "$(readlink -f "${BASH_SOURCE[0]}")")" 30 | 31 | OUT_DIR="${1-$SCRIPT_DIR}" 32 | OUT_FILE="$(mktemp)" 33 | 34 | # Get coverage for all packages in the current directory; store next to script. 35 | cd .. && go test -coverpkg "$(go list)" -coverprofile "$OUT_FILE" 36 | 37 | if [[ "${INPUT_REPORT-true}" == "true" ]]; then 38 | # Create an HTML report; store next to script. 39 | go tool cover -html="$OUT_FILE" -o "$OUT_DIR/coverage.html" 40 | fi 41 | 42 | # Extract total coverage: the decimal number from the last line of the function report. 43 | COVERAGE=$(go tool cover -func="$OUT_FILE" | tail -1 | grep -Eo '[0-9]+\.[0-9]') 44 | 45 | echo "coverage: $COVERAGE% of statements" 46 | 47 | date "+%s,$COVERAGE" >> "$OUT_DIR/coverage.log" 48 | sort -u -o "$OUT_DIR/coverage.log" "$OUT_DIR/coverage.log" 49 | 50 | # Pick a color for the badge. 51 | if awk "BEGIN {exit !($COVERAGE >= 90)}"; then 52 | COLOR=brightgreen 53 | elif awk "BEGIN {exit !($COVERAGE >= 80)}"; then 54 | COLOR=green 55 | elif awk "BEGIN {exit !($COVERAGE >= 70)}"; then 56 | COLOR=yellowgreen 57 | elif awk "BEGIN {exit !($COVERAGE >= 60)}"; then 58 | COLOR=yellow 59 | elif awk "BEGIN {exit !($COVERAGE >= 50)}"; then 60 | COLOR=orange 61 | else 62 | COLOR=red 63 | fi 64 | 65 | # Download the badge; store next to script. 66 | curl -s "https://img.shields.io/badge/coverage-$COVERAGE%25-$COLOR" > "$OUT_DIR/coverage.svg" 67 | 68 | if [[ "${INPUT_CHART-false}" == "true" ]]; then 69 | # Download the chart; store next to script. 
70 | curl -s -H "Content-Type: text/plain" --data-binary "@$OUT_DIR/coverage.log" \ 71 | https://go-coverage-report.nunocruces.workers.dev/chart/ > \ 72 | "$OUT_DIR/coverage-chart.svg" 73 | fi 74 | 75 | # When running as a pre-commit hook, add the report and badge to the commit. 76 | if [[ -n "${GIT_INDEX_FILE-}" ]]; then 77 | git add "$OUT_DIR/coverage.html" "$OUT_DIR/coverage.svg" 78 | fi -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # dbw package 2 | [![Go Reference](https://pkg.go.dev/badge/github.com/hashicorp/go-dbw.svg)](https://pkg.go.dev/github.com/hashicorp/go-dbw) 3 | [![Go Report Card](https://goreportcard.com/badge/github.com/hashicorp/go-dbw)](https://goreportcard.com/report/github.com/hashicorp/go-dbw) 4 | [![Go Coverage](https://raw.githack.com/hashicorp/go-dbw/main/coverage/coverage.svg)](https://raw.githack.com/hashicorp/go-dbw/main/coverage/coverage.html) 5 | 6 | [dbw](https://pkg.go.dev/github.com/hashicorp/go-dbw) is a database wrapper that 7 | supports connecting and using any database with a 8 | [GORM](https://github.com/go-gorm/gorm) driver. 9 | 10 | [dbw](https://pkg.go.dev/github.com/hashicorp/go-dbw) is intended to completely 11 | encapsulate an application's access to its database with the exception of 12 | migrations. [dbw](https://pkg.go.dev/github.com/hashicorp/go-dbw) is 13 | intentionally not an ORM and it removes typical ORM abstractions like "advanced 14 | query building", associations and migrations. 15 | 16 | Of course you can use [dbw](https://pkg.go.dev/github.com/hashicorp/go-dbw) for 17 | complicated queries, it's just that 18 | [dbw](https://pkg.go.dev/github.com/hashicorp/go-dbw) doesn't try to reinvent 19 | SQL by providing some sort of pattern for building them with functions. 
Of 20 | course, [dbw](https://pkg.go.dev/github.com/hashicorp/go-dbw) also provides 21 | lookup/search functions when you simply need to read resources from the 22 | database. 23 | 24 | [dbw](https://pkg.go.dev/github.com/hashicorp/go-dbw) strives to make CRUD for 25 | database resources fairly trivial for common use cases. It also supports a 26 | [WithOnConflict(...)](https://pkg.go.dev/github.com/hashicorp/go-dbw#WithOnConflict) 27 | option for its 28 | [RW.Create(...)](https://pkg.go.dev/github.com/hashicorp/go-dbw#RW.Create) 29 | function for complex scenarios. [dbw](https://pkg.go.dev/github.com/hashicorp/go-dbw) also allows you to opt out of its CRUD 30 | functions and use exec, query and scan rows directly. You may want to carefully 31 | weigh when it's appropriate to use exec and query directly, since it's likely that 32 | each time you use them you're leaking a bit of your database layer schema into 33 | your application's domain. 34 | 35 | * [Usage highlights](./docs/README_USAGE.md) 36 | * [Declaring Models](./docs/README_MODELS.md) 37 | * [Connecting to a Database](./docs/README_OPEN.md) 38 | * [Options](./docs/README_OPTIONS.md) 39 | * [NonCreatable and NonUpdatable](./docs/README_INITFIELDS.md) 40 | * [Readers and Writers](./docs/README_RW.md) 41 | * [Create](./docs/README_CREATE.md) 42 | * [Read](./docs/README_READ.md) 43 | * [Update](./docs/README_UPDATE.md) 44 | * [Delete](./docs/README_DELETE.md) 45 | * [Queries](./docs/README_QUERY.md) 46 | * [Transactions](./docs/README_TX.md) 47 | * [Hooks](./docs/README_HOOKS.md) 48 | * [Optimistic locking for write operations](./docs/README_LOCKS.md) 49 | * [Debug output](./docs/README_DEBUG.md) 50 | -------------------------------------------------------------------------------- /rw_unexported_test.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) HashiCorp, Inc.
2 | // SPDX-License-Identifier: MPL-2.0 3 | 4 | package dbw 5 | 6 | import ( 7 | "context" 8 | "testing" 9 | 10 | "github.com/stretchr/testify/assert" 11 | "github.com/stretchr/testify/require" 12 | ) 13 | 14 | func TestRW_whereClausesFromOpts(t *testing.T) { 15 | db, _ := TestSetup(t) 16 | testCtx := context.Background() 17 | type testUser struct { 18 | Version int 19 | } 20 | 21 | tests := []struct { 22 | name string 23 | rw *RW 24 | i interface{} 25 | opts Options 26 | wantWhere string 27 | wantArgs []interface{} 28 | wantErr bool 29 | }{ 30 | { 31 | name: "with-version-with-table-on-conflict", 32 | rw: New(db), 33 | i: &testUser{}, 34 | opts: Options{ 35 | WithVersion: func() *uint32 { i := uint32(1); return &i }(), 36 | WithTable: "test_table", 37 | WithOnConflict: &OnConflict{}, 38 | }, 39 | wantWhere: "test_table.version = ?", 40 | wantArgs: []interface{}{func() *uint32 { i := uint32(1); return &i }()}, 41 | }, 42 | { 43 | name: "with-version-with-table", 44 | rw: New(db), 45 | i: &testUser{}, 46 | opts: Options{ 47 | WithVersion: func() *uint32 { i := uint32(1); return &i }(), 48 | WithTable: "test_table", 49 | }, 50 | wantWhere: "version = ?", 51 | wantArgs: []interface{}{func() *uint32 { i := uint32(1); return &i }()}, 52 | }, 53 | } 54 | 55 | for _, tt := range tests { 56 | t.Run(tt.name, func(t *testing.T) { 57 | assert, require := assert.New(t), require.New(t) 58 | where, whereArgs, err := tt.rw.whereClausesFromOpts(testCtx, tt.i, tt.opts) 59 | if tt.wantErr { 60 | require.NoError(err) 61 | assert.Empty(where) 62 | assert.Empty(whereArgs) 63 | return 64 | } 65 | require.NoError(err) 66 | assert.Equal(tt.wantWhere, where) 67 | assert.Equal(tt.wantArgs, whereArgs) 68 | }) 69 | } 70 | } 71 | 72 | func Test_validateResourcesInterface(t *testing.T) { 73 | t.Parallel() 74 | tests := []struct { 75 | name string 76 | resources interface{} 77 | wantErr bool 78 | wantErrContains string 79 | }{ 80 | { 81 | name: "not-ptr-to-slice", 82 | resources: []*string{}, 83 
| wantErrContains: "interface parameter must to be a pointer:", 84 | }, 85 | { 86 | name: "not-ptr", 87 | resources: "string", 88 | wantErrContains: "interface parameter must to be a pointer:", 89 | }, 90 | { 91 | name: "not-slice-of-ptrs", 92 | resources: &[]string{}, 93 | wantErrContains: "interface parameter is a slice, but the elements of the slice are not pointers", 94 | }, 95 | { 96 | name: "success-ptr-to-slice-of-ptrs", 97 | resources: &[]*string{}, 98 | }, 99 | { 100 | name: "success-ptr", 101 | resources: func() interface{} { 102 | s := "s" 103 | return &s 104 | }(), 105 | }, 106 | } 107 | for _, tc := range tests { 108 | t.Run(tc.name, func(t *testing.T) { 109 | assert := assert.New(t) 110 | err := validateResourcesInterface(tc.resources) 111 | if tc.wantErr { 112 | assert.Error(err) 113 | if tc.wantErrContains != "" { 114 | assert.Contains(err.Error(), tc.wantErrContains) 115 | } 116 | } 117 | }) 118 | } 119 | } 120 | -------------------------------------------------------------------------------- /query_test.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) HashiCorp, Inc. 
2 | // SPDX-License-Identifier: MPL-2.0 3 | 4 | package dbw_test 5 | 6 | import ( 7 | "context" 8 | "database/sql" 9 | "testing" 10 | 11 | "github.com/hashicorp/go-dbw" 12 | "github.com/hashicorp/go-dbw/internal/dbtest" 13 | "github.com/stretchr/testify/assert" 14 | "github.com/stretchr/testify/require" 15 | ) 16 | 17 | func TestDb_Query(t *testing.T) { 18 | t.Parallel() 19 | const ( 20 | insert = "insert into db_test_user (public_id, name) values(@public_id, @name)" 21 | query = "select * from db_test_user where name in (?, ?)" 22 | ) 23 | testCtx := context.Background() 24 | conn, _ := dbw.TestSetup(t) 25 | t.Run("valid", func(t *testing.T) { 26 | assert, require := assert.New(t), require.New(t) 27 | rw := dbw.New(conn) 28 | publicId, err := dbw.NewId("u") 29 | require.NoError(err) 30 | rowsAffected, err := rw.Exec(testCtx, insert, []interface{}{ 31 | sql.Named("public_id", publicId), 32 | sql.Named("name", "alice"), 33 | }) 34 | require.NoError(err) 35 | require.Equal(1, rowsAffected) 36 | rows, err := rw.Query(testCtx, query, []interface{}{"alice", "bob"}, dbw.WithDebug(true)) 37 | require.NoError(err) 38 | defer func() { err := rows.Close(); assert.NoError(err) }() 39 | for rows.Next() { 40 | u, err := dbtest.NewTestUser() 41 | require.NoError(err) 42 | // scan the row into your struct 43 | err = rw.ScanRows(rows, &u) 44 | require.NoError(err) 45 | assert.Equal(publicId, u.PublicId) 46 | } 47 | }) 48 | t.Run("missing-sql", func(t *testing.T) { 49 | assert, require := assert.New(t), require.New(t) 50 | rw := dbw.New(conn) 51 | got, err := rw.Query(testCtx, "", nil) 52 | require.Error(err) 53 | assert.Zero(got) 54 | }) 55 | t.Run("missing-underlying-db", func(t *testing.T) { 56 | assert, require := assert.New(t), require.New(t) 57 | rw := dbw.RW{} 58 | got, err := rw.Query(testCtx, "", nil) 59 | require.Error(err) 60 | assert.Zero(got) 61 | }) 62 | t.Run("bad-sql", func(t *testing.T) { 63 | assert, require := assert.New(t), require.New(t) 64 | rw := 
dbw.New(conn) 65 | got, err := rw.Query(testCtx, "from", nil) 66 | require.Error(err) 67 | assert.Zero(got) 68 | }) 69 | } 70 | 71 | func TestDb_ScanRows(t *testing.T) { 72 | t.Parallel() 73 | testCtx := context.Background() 74 | conn, _ := dbw.TestSetup(t) 75 | rw := dbw.New(conn) 76 | t.Run("valid", func(t *testing.T) { 77 | assert, require := assert.New(t), require.New(t) 78 | user, err := dbtest.NewTestUser() 79 | require.NoError(err) 80 | err = rw.Create(testCtx, user) 81 | require.NoError(err) 82 | assert.NotEmpty(user.PublicId) 83 | where := "select * from db_test_user where name in (?, ?)" 84 | rows, err := rw.Query(context.Background(), where, []interface{}{"alice", "bob"}) 85 | require.NoError(err) 86 | defer func() { err := rows.Close(); assert.NoError(err) }() 87 | for rows.Next() { 88 | u := dbtest.AllocTestUser() 89 | // scan the row into your struct 90 | err = rw.ScanRows(rows, &u) 91 | require.NoError(err) 92 | assert.Equal(user.PublicId, u.PublicId) 93 | } 94 | }) 95 | t.Run("missing-underlying-db", func(t *testing.T) { 96 | assert, require := assert.New(t), require.New(t) 97 | rw := dbw.RW{} 98 | u := dbtest.AllocTestUser() 99 | err := rw.ScanRows(&sql.Rows{}, &u) 100 | require.Error(err) 101 | assert.Contains(err.Error(), "missing underlying db") 102 | }) 103 | t.Run("missing-result", func(t *testing.T) { 104 | assert, require := assert.New(t), require.New(t) 105 | err := rw.ScanRows(&sql.Rows{}, nil) 106 | require.Error(err) 107 | assert.Contains(err.Error(), "missing result") 108 | }) 109 | t.Run("missing-rows", func(t *testing.T) { 110 | assert, require := assert.New(t), require.New(t) 111 | u := dbtest.AllocTestUser() 112 | err := rw.ScanRows(nil, &u) 113 | require.Error(err) 114 | assert.Contains(err.Error(), "missing rows") 115 | }) 116 | } 117 | -------------------------------------------------------------------------------- /internal/proto/local/dbtest/storage/v1/dbtest.proto: 
-------------------------------------------------------------------------------- 1 | // Copyright (c) HashiCorp, Inc. 2 | // SPDX-License-Identifier: MPL-2.0 3 | 4 | syntax = "proto3"; 5 | 6 | // define a test proto package for the internal/db package. These protos 7 | // are only used for unit tests and are not part of the rest of the domain model 8 | 9 | package dbtest.storage.v1; 10 | 11 | import "google/protobuf/timestamp.proto"; 12 | 13 | option go_package = "internal/dbtest;dbtest"; 14 | 15 | // Timestamp for storage messages. We've defined a new local type wrapper 16 | // of google.protobuf.Timestamp so we can implement sql.Scanner and sql.Valuer 17 | // interfaces. See: 18 | // https://golang.org/pkg/database/sql/#Scanner 19 | // https://golang.org/pkg/database/sql/driver/#Valuer 20 | message Timestamp { 21 | google.protobuf.Timestamp timestamp = 1; 22 | } 23 | 24 | // TestUser model 25 | message StoreTestUser { 26 | // public_id is the used to access the user via an API 27 | // @inject_tag: gorm:"primaryKey;default:null" 28 | string public_id = 4; 29 | 30 | // create_time from the RDBMS 31 | // @inject_tag: `gorm:"default:CURRENT_TIMESTAMP"` 32 | Timestamp create_time = 2; 33 | 34 | // update_time from the RDBMS 35 | // @inject_tag: `gorm:"default:CURRENT_TIMESTAMP"` 36 | Timestamp update_time = 3; 37 | 38 | // name is the optional friendly name used to 39 | // access the user via an API 40 | // @inject_tag: `gorm:"default:null"` 41 | string name = 5; 42 | 43 | // @inject_tag: `gorm:"default:null"` 44 | string phone_number = 6; 45 | 46 | // @inject_tag: `gorm:"default:null"` 47 | string email = 7; 48 | 49 | // @inject_tag: `gorm:"default:null"` 50 | uint32 version = 8; 51 | } 52 | 53 | // TestCar car model 54 | message StoreTestCar { 55 | // public_id is the used to access the car via an API 56 | // @inject_tag: gorm:"primaryKey;default:null" 57 | string public_id = 4; 58 | 59 | // create_time from the RDBMS 60 | // @inject_tag: 
`gorm:"default:CURRENT_TIMESTAMP"` 61 | Timestamp create_time = 2; 62 | 63 | // update_time from the RDBMS 64 | // @inject_tag: `gorm:"default:CURRENT_TIMESTAMP"` 65 | Timestamp update_time = 3; 66 | 67 | // name is the optional friendly name used to 68 | // access the Scope via an API 69 | // @inject_tag: `gorm:"default:null"` 70 | string name = 5; 71 | 72 | // @inject_tag: `gorm:"default:null"` 73 | string model = 6; 74 | 75 | // @inject_tag: `gorm:"default:null"` 76 | int32 mpg = 7; 77 | 78 | // intentionally there is no version field 79 | } 80 | 81 | // TestRental for test rental model 82 | message StoreTestRental { 83 | // @inject_tag: `gorm:"primaryKey"` 84 | string user_id = 1; 85 | 86 | // @inject_tag: `gorm:"primaryKey"` 87 | string car_id = 2; 88 | 89 | // create_time from the RDBMS 90 | // @inject_tag: `gorm:"default:CURRENT_TIMESTAMP"` 91 | Timestamp create_time = 3; 92 | 93 | // update_time from the RDBMS 94 | // @inject_tag: `gorm:"default:CURRENT_TIMESTAMP"` 95 | Timestamp update_time = 4; 96 | 97 | // name is the optional friendly name used to 98 | // access the rental via an API 99 | // @inject_tag: `gorm:"default:null"` 100 | string name = 5; 101 | 102 | // @inject_tag: `gorm:"default:null"` 103 | uint32 version = 6; 104 | } 105 | 106 | // StoreTestScooter used in the db tests only and provides a resource with 107 | // a private id. 
108 | message StoreTestScooter { 109 | // private_id is the used to access scooter, but not intended to be available 110 | // via the API 111 | // @inject_tag: `gorm:"primaryKey"` 112 | string private_id = 1; 113 | 114 | // create_time from the RDBMS 115 | // @inject_tag: `gorm:"default:CURRENT_TIMESTAMP"` 116 | Timestamp create_time = 2; 117 | 118 | // update_time from the RDBMS 119 | // @inject_tag: `gorm:"default:CURRENT_TIMESTAMP"` 120 | Timestamp update_time = 3; 121 | 122 | // @inject_tag: `gorm:"default:null"` 123 | string model = 4; 124 | 125 | // @inject_tag: `gorm:"default:null"` 126 | int32 mpg = 5; 127 | 128 | // @inject_tag: `gorm:"-"` 129 | string read_only_field = 6; 130 | } 131 | -------------------------------------------------------------------------------- /clause.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) HashiCorp, Inc. 2 | // SPDX-License-Identifier: MPL-2.0 3 | 4 | package dbw 5 | 6 | import ( 7 | "sort" 8 | 9 | "gorm.io/gorm" 10 | "gorm.io/gorm/clause" 11 | ) 12 | 13 | // ColumnValue defines a column and it's assigned value for a database 14 | // operation. See: SetColumnValues(...) 15 | type ColumnValue struct { 16 | // Column name 17 | Column string 18 | // Value is the column's value 19 | Value interface{} 20 | } 21 | 22 | // Column represents a table Column 23 | type Column struct { 24 | // Name of the column 25 | Name string 26 | // Table name of the column 27 | Table string 28 | } 29 | 30 | func (c *Column) toAssignment(column string) clause.Assignment { 31 | return clause.Assignment{ 32 | Column: clause.Column{Name: column}, 33 | Value: clause.Column{Table: c.Table, Name: c.Name}, 34 | } 35 | } 36 | 37 | func rawAssignment(column string, value interface{}) clause.Assignment { 38 | return clause.Assignment{ 39 | Column: clause.Column{Name: column}, 40 | Value: value, 41 | } 42 | } 43 | 44 | // ExprValue encapsulates an expression value for a column assignment. 
See 45 | // Expr(...) to create these values. 46 | type ExprValue struct { 47 | Sql string 48 | Vars []interface{} 49 | } 50 | 51 | func (ev *ExprValue) toAssignment(column string) clause.Assignment { 52 | return clause.Assignment{ 53 | Column: clause.Column{Name: column}, 54 | Value: gorm.Expr(ev.Sql, ev.Vars...), 55 | } 56 | } 57 | 58 | // Expr creates an expression value (ExprValue) which can be used when setting 59 | // column values for database operations. See: Expr(...) 60 | // 61 | // Set name column to null example: 62 | // 63 | // SetColumnValues(map[string]interface{}{"name": Expr("NULL")}) 64 | // 65 | // Set exp_time column to N seconds from now: 66 | // 67 | // SetColumnValues(map[string]interface{}{"exp_time": Expr("wt_add_seconds_to_now(?)", 10)}) 68 | func Expr(expr string, args ...interface{}) ExprValue { 69 | return ExprValue{Sql: expr, Vars: args} 70 | } 71 | 72 | // SetColumnValues defines a map from column names to values for database 73 | // operations. 74 | func SetColumnValues(columnValues map[string]interface{}) []ColumnValue { 75 | keys := make([]string, 0, len(columnValues)) 76 | for key := range columnValues { 77 | keys = append(keys, key) 78 | } 79 | sort.Strings(keys) 80 | 81 | assignments := make([]ColumnValue, len(keys)) 82 | for idx, key := range keys { 83 | assignments[idx] = ColumnValue{Column: key, Value: columnValues[key]} 84 | } 85 | return assignments 86 | } 87 | 88 | // SetColumns defines a list of column (names) to update using the set of 89 | // proposed insert columns during an on conflict update. 
90 | func SetColumns(names []string) []ColumnValue { 91 | assignments := make([]ColumnValue, len(names)) 92 | for idx, name := range names { 93 | assignments[idx] = ColumnValue{ 94 | Column: name, 95 | Value: Column{Name: name, Table: "excluded"}, 96 | } 97 | } 98 | return assignments 99 | } 100 | 101 | // OnConflict specifies how to handle alternative actions to take when an insert 102 | // results in a unique constraint or exclusion constraint error. 103 | type OnConflict struct { 104 | // Target specifies what conflict you want to define a policy for. This can 105 | // be any one of these: 106 | // Columns: the name of a specific column or columns 107 | // Constraint: the name of a unique constraint 108 | Target interface{} 109 | 110 | // Action specifies the action to take on conflict. This can be any one of 111 | // these: 112 | // DoNothing: leaves the conflicting record as-is 113 | // UpdateAll: updates all the columns of the conflicting record using the resource's data 114 | // []ColumnValue: update a set of columns of the conflicting record using the set of assignments 115 | Action interface{} 116 | } 117 | 118 | // Constraint defines database constraint name 119 | type Constraint string 120 | 121 | // Columns defines a set of column names 122 | type Columns []string 123 | 124 | // DoNothing defines an "on conflict" action of doing nothing 125 | type DoNothing bool 126 | 127 | // UpdateAll defines an "on conflict" action of updating all columns using the 128 | // proposed insert column values 129 | type UpdateAll bool 130 | -------------------------------------------------------------------------------- /testing_test.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) HashiCorp, Inc. 
2 | // SPDX-License-Identifier: MPL-2.0 3 | 4 | package dbw 5 | 6 | import ( 7 | "context" 8 | "database/sql" 9 | "errors" 10 | "os" 11 | "strings" 12 | "testing" 13 | 14 | "github.com/hashicorp/go-secure-stdlib/base62" 15 | "github.com/stretchr/testify/assert" 16 | "github.com/stretchr/testify/require" 17 | ) 18 | 19 | func Test_getTestOpts(t *testing.T) { 20 | t.Parallel() 21 | assert := assert.New(t) 22 | t.Run("WithTestMigration", func(t *testing.T) { 23 | fn := func(context.Context, string, string) error { return nil } 24 | opts := getTestOpts(WithTestMigration(fn)) 25 | testOpts := getDefaultTestOptions() 26 | testOpts.withTestMigration = fn 27 | assert.NotNil(opts.withTestMigration) // func values aren't comparable, so only check that the option was applied 28 | }) 29 | t.Run("WithTestDatabaseUrl", func(t *testing.T) { 30 | opts := getTestOpts(WithTestDatabaseUrl("url")) 31 | testOpts := getDefaultTestOptions() 32 | testOpts.withTestDatabaseUrl = "url" 33 | assert.Equal(opts, testOpts) 34 | }) 35 | } 36 | 37 | func Test_TestSetup(t *testing.T) { 38 | testMigrationFn := func(context.Context, string, string) error { 39 | conn, err := Open(Sqlite, "file::memory:") 40 | require.NoError(t, err) 41 | rw := New(conn) 42 | _, err = rw.Exec(context.Background(), testQueryCreateTablesSqlite, nil) 43 | require.NoError(t, err) 44 | return nil 45 | } 46 | 47 | testMigrationUsingDbFn := func(_ context.Context, db *sql.DB) error { 48 | var sql string 49 | switch strings.ToLower(os.Getenv("DB_DIALECT")) { 50 | case "postgres": 51 | sql = testQueryCreateTablesPostgres 52 | default: 53 | sql = testQueryCreateTablesSqlite 54 | } 55 | _, err := db.Exec(sql) 56 | require.NoError(t, err) 57 | return nil 58 | } 59 | 60 | tests := []struct { 61 | name string 62 | opt []TestOption 63 | validate func(db *DB) bool 64 | }{ 65 | { 66 | name: "sqlite-with-migration", 67 | opt: []TestOption{WithTestDialect(Sqlite.String()), WithTestMigration(testMigrationFn)}, 68 | // we can't validate this, since WithTestMigration will open a new 69 | // sqlite
connection which will result in a new in-memory db which 70 | // will only exist during the testMigrationFn... sort of silly, 71 | // but it does test that the fn is called properly at least. 72 | }, 73 | { 74 | name: "sqlite-with-migration-using-db", 75 | opt: []TestOption{WithTestDialect(Sqlite.String()), WithTestMigrationUsingDB(testMigrationUsingDbFn)}, 76 | validate: func(db *DB) bool { 77 | rw := New(db) 78 | publicId, err := base62.Random(20) 79 | require.NoError(t, err) 80 | user := &testUser{ 81 | PublicId: publicId, 82 | } 83 | 84 | user.Name = "foo-" + user.PublicId 85 | err = rw.Create(context.Background(), user) 86 | require.NoError(t, err) 87 | return true 88 | }, 89 | }, 90 | } 91 | for _, tt := range tests { 92 | t.Run(tt.name, func(t *testing.T) { 93 | assert := assert.New(t) 94 | db, url := TestSetup(t, tt.opt...) 95 | if tt.validate != nil { 96 | assert.True(tt.validate(db)) 97 | } 98 | assert.NotNil(db) 99 | assert.NotEmpty(url) 100 | }) 101 | } 102 | } 103 | 104 | func Test_TestSetupWithMock(t *testing.T) { 105 | assert := assert.New(t) 106 | testCtx := context.Background() 107 | 108 | publicId, err := base62.Random(20) 109 | require.NoError(t, err) 110 | user := &testUser{ 111 | PublicId: publicId, 112 | } 113 | 114 | db, mock := TestSetupWithMock(t) 115 | rw := New(db) 116 | mock.ExpectQuery(`SELECT`).WillReturnError(errors.New("failed-lookup")) 117 | 118 | err = rw.Create(testCtx, &user) 119 | assert.Error(err) 120 | } 121 | 122 | func Test_CreateDropTestTables(t *testing.T) { 123 | t.Run("execute", func(t *testing.T) { 124 | db, _ := TestSetup(t, WithTestDialect(Sqlite.String())) 125 | testDropTables(t, db) 126 | TestCreateTables(t, db) 127 | }) 128 | } 129 | 130 | // testUser is required since we can't import dbtest as it creates a circular dep 131 | type testUser struct { 132 | PublicId string `gorm:"primaryKey;default:null"` 133 | Name string `gorm:"default:null"` 134 | PhoneNumber string `gorm:"default:null"`
135 | Email string `gorm:"default:null"` 136 | Version uint32 `gorm:"default:null"` 137 | } 138 | 139 | func (u *testUser) TableName() string { return "db_test_user" } 140 | -------------------------------------------------------------------------------- /common_unexported_test.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) HashiCorp, Inc. 2 | // SPDX-License-Identifier: MPL-2.0 3 | 4 | package dbw 5 | 6 | import ( 7 | "testing" 8 | 9 | "github.com/stretchr/testify/assert" 10 | ) 11 | 12 | func Test_intersection(t *testing.T) { 13 | type args struct { 14 | av []string 15 | bv []string 16 | } 17 | tests := []struct { 18 | name string 19 | args args 20 | want []string 21 | want1 map[string]string 22 | want2 map[string]string 23 | wantErr bool 24 | wantErrMsg string 25 | }{ 26 | { 27 | name: "intersect", 28 | args: args{ 29 | av: []string{"alice"}, 30 | bv: []string{"alice", "bob"}, 31 | }, 32 | want: []string{"alice"}, 33 | want1: map[string]string{ 34 | "ALICE": "alice", 35 | }, 36 | want2: map[string]string{ 37 | "ALICE": "alice", 38 | "BOB": "bob", 39 | }, 40 | }, 41 | { 42 | name: "intersect-2", 43 | args: args{ 44 | av: []string{"alice", "bob", "jane", "doe"}, 45 | bv: []string{"alice", "doe", "bert", "ernie", "bigbird"}, 46 | }, 47 | want: []string{"alice", "doe"}, 48 | want1: map[string]string{ 49 | "ALICE": "alice", 50 | "BOB": "bob", 51 | "JANE": "jane", 52 | "DOE": "doe", 53 | }, 54 | want2: map[string]string{ 55 | "ALICE": "alice", 56 | "DOE": "doe", 57 | "BERT": "bert", 58 | "ERNIE": "ernie", 59 | "BIGBIRD": "bigbird", 60 | }, 61 | }, 62 | { 63 | name: "intersect-mixed-case", 64 | args: args{ 65 | av: []string{"AlicE"}, 66 | bv: []string{"alICe", "Bob"}, 67 | }, 68 | want: []string{"alice"}, 69 | want1: map[string]string{ 70 | "ALICE": "AlicE", 71 | }, 72 | want2: map[string]string{ 73 | "ALICE": "alICe", 74 | "BOB": "Bob", 75 | }, 76 | }, 77 | { 78 | name: "no-intersect-mixed-case", 79 | args:
args{ 80 | av: []string{"AliCe", "BOb", "jaNe", "DOE"}, 81 | bv: []string{"beRt", "ERnie", "bigBIRD"}, 82 | }, 83 | want: []string{}, 84 | want1: map[string]string{ 85 | "ALICE": "AliCe", 86 | "BOB": "BOb", 87 | "JANE": "jaNe", 88 | "DOE": "DOE", 89 | }, 90 | want2: map[string]string{ 91 | "BERT": "beRt", 92 | "ERNIE": "ERnie", 93 | "BIGBIRD": "bigBIRD", 94 | }, 95 | }, 96 | { 97 | name: "no-intersect-1", 98 | args: args{ 99 | av: []string{"alice", "bob", "jane", "doe"}, 100 | bv: []string{"bert", "ernie", "bigbird"}, 101 | }, 102 | want: []string{}, 103 | want1: map[string]string{ 104 | "ALICE": "alice", 105 | "BOB": "bob", 106 | "JANE": "jane", 107 | "DOE": "doe", 108 | }, 109 | want2: map[string]string{ 110 | "BERT": "bert", 111 | "ERNIE": "ernie", 112 | "BIGBIRD": "bigbird", 113 | }, 114 | }, 115 | { 116 | name: "empty-av", 117 | args: args{ 118 | av: []string{}, 119 | bv: []string{"bert", "ernie", "bigbird"}, 120 | }, 121 | want: []string{}, 122 | want1: map[string]string{}, 123 | want2: map[string]string{ 124 | "BERT": "bert", 125 | "ERNIE": "ernie", 126 | "BIGBIRD": "bigbird", 127 | }, 128 | }, 129 | { 130 | name: "empty-av-and-bv", 131 | args: args{ 132 | av: []string{}, 133 | bv: []string{}, 134 | }, 135 | want: []string{}, 136 | want1: map[string]string{}, 137 | want2: map[string]string{}, 138 | }, 139 | { 140 | name: "nil-av", 141 | args: args{ 142 | av: nil, 143 | bv: []string{"bert", "ernie", "bigbird"}, 144 | }, 145 | want: nil, 146 | want1: nil, 147 | want2: nil, 148 | wantErr: true, 149 | wantErrMsg: "dbw.Intersection: av is missing: invalid parameter", 150 | }, 151 | { 152 | name: "nil-bv", 153 | args: args{ 154 | av: []string{}, 155 | bv: nil, 156 | }, 157 | want: nil, 158 | want1: nil, 159 | want2: nil, 160 | wantErr: true, 161 | wantErrMsg: "dbw.Intersection: bv is missing: invalid parameter", 162 | }, 163 | } 164 | for _, tt := range tests { 165 | t.Run(tt.name, func(t *testing.T) { 166 | assert := assert.New(t) 167 | got, got1, got2, err :=
Intersection(tt.args.av, tt.args.bv) 168 | if tt.wantErr { 169 | if assert.Error(err) { // guard: calling err.Error() on a nil error would panic the test binary 170 | assert.Equal(tt.wantErrMsg, err.Error()) 171 | } 172 | } else { 173 | assert.NoError(err) 174 | } 175 | assert.Equal(tt.want, got) 176 | assert.Equal(tt.want1, got1) 177 | assert.Equal(tt.want2, got2) 178 | }) 179 | } 180 | } 181 | -------------------------------------------------------------------------------- /writer.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) HashiCorp, Inc. 2 | // SPDX-License-Identifier: MPL-2.0 3 | 4 | package dbw 5 | 6 | import ( 7 | "context" 8 | "database/sql" 9 | "time" 10 | ) 11 | 12 | // Writer interface defines create, update and retryable transaction handlers 13 | type Writer interface { 14 | // DoTx will wrap the TxHandler in a retryable transaction 15 | DoTx(ctx context.Context, retryErrorsMatchingFn func(error) bool, retries uint, backOff Backoff, Handler TxHandler) (RetryInfo, error) 16 | 17 | // Update an object in the db, fieldMask is required and provides 18 | // field_mask.proto paths for fields that should be updated. The i interface 19 | // parameter is the type the caller wants to update in the db and its 20 | // fields are set to the update values. setToNullPaths is optional and 21 | // provides field_mask.proto paths for the fields that should be set to 22 | // null. fieldMaskPaths and setToNullPaths must not intersect. The caller 23 | // is responsible for the transaction life cycle of the writer and if an 24 | // error is returned the caller must decide what to do with the transaction, 25 | // which almost always should be to rollback. Update returns the number of 26 | // rows updated or an error. 27 | Update(ctx context.Context, i interface{}, fieldMaskPaths []string, setToNullPaths []string, opt ...Option) (int, error) 28 | 29 | // Create a resource in the database.
The caller is responsible for the 30 | // transaction life cycle of the writer and if an error is returned the 31 | // caller must decide what to do with the transaction, which almost always 32 | // should be to rollback. 33 | Create(ctx context.Context, i interface{}, opt ...Option) error 34 | 35 | // CreateItems will create multiple items of the same type. The caller is 36 | // responsible for the transaction life cycle of the writer and if an error 37 | // is returned the caller must decide what to do with the transaction, which 38 | // almost always should be to rollback. 39 | // Supported options: WithBatchSize, WithDebug, WithBeforeWrite, 40 | // WithAfterWrite, WithReturnRowsAffected, OnConflict, WithVersion, 41 | // WithTable, and WithWhere. 42 | // WithLookup is not a supported option. 43 | CreateItems(ctx context.Context, createItems interface{}, opt ...Option) error 44 | 45 | // Delete a resource in the database. The caller is responsible for the 46 | // transaction life cycle of the writer and if an error is returned the 47 | // caller must decide what to do with the transaction, which almost always 48 | // should be to rollback. Delete returns the number of rows deleted or an 49 | // error. 50 | Delete(ctx context.Context, i interface{}, opt ...Option) (int, error) 51 | 52 | // DeleteItems will delete multiple items of the same type. The caller is 53 | // responsible for the transaction life cycle of the writer and if an error 54 | // is returned the caller must decide what to do with the transaction, which 55 | // almost always should be to rollback. Delete returns the number of rows 56 | // deleted or an error. 57 | DeleteItems(ctx context.Context, deleteItems interface{}, opt ...Option) (int, error) 58 | 59 | // Exec will execute the sql with the values as parameters. The int returned 60 | // is the number of rows affected by the sql. No options are currently 61 | // supported. 
62 | Exec(ctx context.Context, sql string, values []interface{}, opt ...Option) (int, error) 63 | 64 | // Query will run the raw query and return the *sql.Rows results. The 65 | // caller must close the returned *sql.Rows. Query can/should be used in 66 | // combination with ScanRows. Query is included in the Writer interface 67 | // so callers can execute updates and inserts with returning values. 68 | Query(ctx context.Context, sql string, values []interface{}, opt ...Option) (*sql.Rows, error) 69 | 70 | // ScanRows will scan sql rows into the interface provided 71 | ScanRows(rows *sql.Rows, result interface{}) error 72 | 73 | // Begin will start a transaction. NOTE: consider using DoTx(...) with a 74 | // TxHandler since it supports a better interface for managing transactions 75 | // via a TxHandler. 76 | Begin(ctx context.Context) (*RW, error) 77 | 78 | // Rollback will rollback the current transaction. NOTE: consider using 79 | // DoTx(...) with a TxHandler since it supports a better interface for 80 | // managing transactions via a TxHandler. 81 | Rollback(ctx context.Context) error 82 | 83 | // Commit will commit a transaction. NOTE: consider using DoTx(...) with a 84 | // TxHandler since it supports a better interface for managing transactions 85 | // via a TxHandler. 86 | Commit(ctx context.Context) error 87 | 88 | // Dialect returns the dialect and raw connection name of the underlying database. 
89 | Dialect() (_ DbType, rawName string, _ error) 90 | } 91 | 92 | // RetryInfo provides information on the retries of a transaction 93 | type RetryInfo struct { 94 | Retries int 95 | Backoff time.Duration 96 | } 97 | 98 | // TxHandler defines a handler for a func that writes a transaction for use with DoTx 99 | type TxHandler func(Reader, Writer) error 100 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to dbw 2 | 3 | Thank you for contributing! Here you can find common questions around reporting issues and opening 4 | pull requests to our project. 5 | 6 | When contributing in any way to the project (new issue, PR, etc), please be aware that our team identifies with many gender pronouns. Please remember to use nonbinary pronouns (they/them) and gender neutral language ("Hello folks") when addressing our team. For more reading on our code of conduct, please see the [HashiCorp community guidelines](https://www.hashicorp.com/community-guidelines). 7 | 8 | ## Issue Reporting 9 | ### Reporting Security Related Vulnerabilities 10 | 11 | We take security and our users' trust very seriously. If you believe you have found a security issue, please responsibly disclose by contacting us at security@hashicorp.com. Do not open an issue on 12 | our GitHub issue tracker if you believe you've found a security related issue, thank you! 13 | 14 | ### Bug Fixes 15 | 16 | If you believe you found a bug, please: 17 | 18 | 1. Build from the latest `main` HEAD commit to attempt to reproduce the issue. It's possible we've already fixed 19 | the bug, and this is a first good step to ensuring that's not the case. 20 | 1. Ensure a similar ticket is not already opened by searching our opened issues on GitHub. 
21 | 22 | 23 | Once you've verified the above, feel free to open an issue using the bug report template from our [issue selector](https://github.com/hashicorp/go-dbw/issues/new/choose) 24 | and we'll do our best to triage it as quickly as possible. 25 | 26 | ## Pull Requests 27 | 28 | ### New Features & Improvements 29 | 30 | Before writing a line of code, please ask us about a potential improvement or feature that you want to write. We may already be working on it; even if we aren't, we need to ensure that both the feature and its proposed implementation are aligned with our roadmap, vision, and standards for the project. We're happy to help walk through that via a [feature request issue](https://github.com/hashicorp/go-dbw/issues/new/choose). 31 | 32 | ### Submitting a New Pull Request 33 | 34 | When submitting a pull request, please ensure: 35 | 36 | 1. You've added a changelog line clearly describing the new addition under the correct changelog sub-section. 37 | 2. You've followed the above guidelines for contributing. 38 | 39 | Once you open your PR, our auto-labeling will add labels to help us triage and prioritize your contribution. Please 40 | allow us a couple of days to comment, request changes, or approve your PR. Thank you for your contribution! 41 | 42 | ## Changelog 43 | 44 | The changelog is updated by PR contributors. Each contribution should include a changelog update at the contributor's or reviewer's discretion. 45 | The changelog should be updated when the contribution is large enough to warrant it being called out in the larger release cycle. Enhancements, bug fixes, 46 | and other contributions that practitioners might want to be aware of should be recorded in the changelog. 47 | 48 | When contributing to the changelog, follow existing patterns for referencing PRs, issues, or other ancillary context. 49 | 50 | The changelog is broken down into sections: 51 | 52 | ### vNext 53 | 54 | The current release cycle.
New contributions slated for the next release should go under this heading. If the contribution is being backported, 55 | the inclusion of the feature in the appropriate release during the backport process is handled on an as-needed basis. 56 | 57 | ### New and Improved 58 | 59 | Any enhancements, new features, etc. fall into this section. 60 | 61 | ### Bug Fixes 62 | 63 | Any bug fixes fall into this section. 64 | 65 | ## Testing 66 | 67 | Most tests require a database connection obtained from `TestSetup(t)` to run. 68 | By default, `TestSetup(t)` returns a connection to a tmp sqlite database which 69 | it tears down for you when the test is done. 70 | 71 | For the vast majority of tests you don't need to specify the dialect or DSN when 72 | using `TestSetup`. Just run `go test` and let `TestSetup(t)` handle 73 | creating/initializing a test database for your test and tearing it down when the 74 | test is over. 75 | 76 | In most cases you shouldn't specify the dialect/DSN for `TestSetup(t)`, because 77 | leaving them unspecified allows dbw to run your tests against each of a set of dialects in CI/CD 78 | (see `test_all.sh` and our GitHub actions for details). If you wish to run 79 | `test_all.sh` locally, simply run `docker-compose up` before running 80 | `test_all.sh`. Or you can run the GitHub actions locally using 81 | [act](https://github.com/nektos/act). 82 | 83 | `TestSetup(t)` supports setting the database dialect via the env var `DB_DIALECT` or 84 | the option `WithTestDialect`. Sqlite is used when neither the env var nor the 85 | option is specified. 86 | 87 | `TestSetup(t)` supports setting the database DSN via the env var `DB_DSN` or the 88 | option `WithTestDatabaseUrl(...)`. A tmp sqlite database is used when neither the env 89 | var nor the option is specified. 90 | 91 | Note that setting `max_connections` too low may result in sporadic test 92 | failures when a connection cannot be established.
In this case, reduce the number 93 | of concurrent tests via `GOMAXPROCS` or selectively run tests. 94 | 95 | -------------------------------------------------------------------------------- /lookup_test.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) HashiCorp, Inc. 2 | // SPDX-License-Identifier: MPL-2.0 3 | 4 | package dbw_test 5 | 6 | import ( 7 | "context" 8 | "testing" 9 | 10 | "github.com/google/go-cmp/cmp" 11 | "github.com/hashicorp/go-dbw" 12 | "github.com/hashicorp/go-dbw/internal/dbtest" 13 | "github.com/stretchr/testify/assert" 14 | "github.com/stretchr/testify/require" 15 | "google.golang.org/protobuf/proto" 16 | "google.golang.org/protobuf/testing/protocmp" 17 | ) 18 | 19 | func TestDb_LookupBy(t *testing.T) { 20 | t.Parallel() 21 | db, _ := dbw.TestSetup(t) 22 | testRw := dbw.New(db) 23 | scooter := testScooter(t, testRw, "", 0, "") 24 | user := testUser(t, testRw, "", "", "") 25 | car := testCar(t, testRw) 26 | rental := testRental(t, testRw, user.PublicId, car.PublicId) 27 | 28 | testScooterWithROField := testScooter(t, testRw, "", 0, "read-only") 29 | 30 | type args struct { 31 | resource interface{} 32 | opt []dbw.Option 33 | } 34 | tests := []struct { 35 | name string 36 | rw *dbw.RW 37 | args args 38 | wantErr bool 39 | want proto.Message 40 | wantIsErr error 41 | }{ 42 | { 43 | name: "simple-private-id", 44 | rw: testRw, 45 | args: args{ 46 | resource: scooter, 47 | }, 48 | wantErr: false, 49 | want: scooter, 50 | }, 51 | { 52 | name: "simple-public-id", 53 | rw: testRw, 54 | args: args{ 55 | resource: user, 56 | }, 57 | wantErr: false, 58 | want: user, 59 | }, 60 | { 61 | name: "with-null-values-set", 62 | rw: testRw, 63 | args: args{ 64 | resource: &dbtest.TestScooter{ 65 | StoreTestScooter: &dbtest.StoreTestScooter{ 66 | PrivateId: scooter.GetPrivateId(), 67 | Model: "model", 68 | Mpg: 10, 69 | }, 70 | }, 71 | }, 72 | want: scooter, 73 | }, 74 | { 75 | name: 
"with-read-only-field-set", 76 | rw: testRw, 77 | args: args{ 78 | resource: testScooterWithROField, 79 | }, 80 | want: testScooterWithROField, 81 | }, 82 | { 83 | name: "with-table", 84 | rw: testRw, 85 | args: args{ 86 | resource: user, 87 | opt: []dbw.Option{dbw.WithTable(user.TableName())}, 88 | }, 89 | wantErr: false, 90 | want: user, 91 | }, 92 | { 93 | name: "with-debug", 94 | rw: testRw, 95 | args: args{ 96 | resource: user, 97 | opt: []dbw.Option{dbw.WithDebug(true)}, 98 | }, 99 | wantErr: false, 100 | want: user, 101 | }, 102 | { 103 | name: "with-table-fail", 104 | rw: testRw, 105 | args: args{ 106 | resource: user, 107 | opt: []dbw.Option{dbw.WithTable("invalid-table-name")}, 108 | }, 109 | wantErr: true, 110 | }, 111 | { 112 | name: "compond", 113 | rw: testRw, 114 | args: args{ 115 | resource: rental, 116 | }, 117 | wantErr: false, 118 | want: rental, 119 | }, 120 | { 121 | name: "compond-with-zero-value-pk", 122 | rw: testRw, 123 | args: args{ 124 | resource: func() interface{} { 125 | cp := rental.Clone() 126 | cp.(*dbtest.TestRental).CarId = "" 127 | return cp 128 | }(), 129 | }, 130 | wantErr: true, 131 | wantIsErr: dbw.ErrInvalidParameter, 132 | }, 133 | { 134 | name: "missing-public-id", 135 | rw: testRw, 136 | args: args{ 137 | resource: &dbtest.TestUser{ 138 | StoreTestUser: &dbtest.StoreTestUser{}, 139 | }, 140 | }, 141 | wantErr: true, 142 | wantIsErr: dbw.ErrInvalidParameter, 143 | }, 144 | { 145 | name: "missing-private-id", 146 | rw: testRw, 147 | args: args{ 148 | resource: &dbtest.TestScooter{ 149 | StoreTestScooter: &dbtest.StoreTestScooter{}, 150 | }, 151 | }, 152 | wantErr: true, 153 | wantIsErr: dbw.ErrInvalidParameter, 154 | }, 155 | { 156 | name: "not-an-ider", 157 | rw: testRw, 158 | args: args{ 159 | resource: &dbtest.NotIder{}, 160 | }, 161 | wantErr: true, 162 | wantIsErr: dbw.ErrInvalidParameter, 163 | }, 164 | { 165 | name: "missing-underlying-db", 166 | rw: &dbw.RW{}, 167 | args: args{ 168 | resource: user, 169 | }, 170 | 
wantErr: true, 171 | wantIsErr: dbw.ErrInvalidParameter, 172 | }, 173 | } 174 | for _, tt := range tests { 175 | t.Run(tt.name, func(t *testing.T) { 176 | assert, require := assert.New(t), require.New(t) 177 | cloner, ok := tt.args.resource.(dbtest.Cloner) 178 | require.True(ok) 179 | cp := cloner.Clone() 180 | err := tt.rw.LookupBy(context.Background(), cp, tt.args.opt...) 181 | if tt.wantErr { 182 | require.Error(err) 183 | if tt.wantIsErr != nil { 184 | assert.ErrorIs(err, tt.wantIsErr) 185 | } 186 | return 187 | } 188 | require.NoError(err) 189 | assert.Empty(cmp.Diff(tt.want, cp.(proto.Message), protocmp.Transform())) 190 | }) 191 | } 192 | t.Run("not-ptr", func(t *testing.T) { 193 | assert, require := assert.New(t), require.New(t) 194 | u := testUser(t, testRw, "", "", "") 195 | err := testRw.LookupBy(context.Background(), *u) 196 | require.Error(err) 197 | assert.ErrorIs(err, dbw.ErrInvalidParameter) 198 | }) 199 | t.Run("hooks", func(t *testing.T) { 200 | hookTests := []struct { 201 | name string 202 | resource interface{} 203 | }{ 204 | {"after", &dbtest.TestWithAfterFind{}}, 205 | } 206 | for _, tt := range hookTests { 207 | t.Run(tt.name, func(t *testing.T) { 208 | assert, require := assert.New(t), require.New(t) 209 | w := dbw.New(db) 210 | err := w.LookupBy(context.Background(), tt.resource) 211 | require.Error(err) 212 | assert.ErrorIs(err, dbw.ErrInvalidParameter) 213 | assert.Contains(err.Error(), "gorm callback/hooks are not supported") 214 | }) 215 | } 216 | }) 217 | } 218 | -------------------------------------------------------------------------------- /common.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) HashiCorp, Inc. 2 | // SPDX-License-Identifier: MPL-2.0 3 | 4 | package dbw 5 | 6 | import ( 7 | "fmt" 8 | "reflect" 9 | "strings" 10 | 11 | "gorm.io/gorm" 12 | ) 13 | 14 | // UpdateFields will create a map[string]interface of the update values to be 15 | // sent to the db. 
The map keys will be the field names for the fields to be 16 | // updated. The caller provided fieldMaskPaths and setToNullPaths must not 17 | // intersect. fieldMaskPaths and setToNullPaths cannot both be zero len. i must be a struct or a non-nil pointer to a struct. 18 | func UpdateFields(i interface{}, fieldMaskPaths []string, setToNullPaths []string) (map[string]interface{}, error) { 19 | const op = "dbw.UpdateFields" 20 | if i == nil { 21 | return nil, fmt.Errorf("%s: interface is missing: %w", op, ErrInvalidParameter) 22 | } 23 | if fieldMaskPaths == nil { 24 | fieldMaskPaths = []string{} 25 | } 26 | if setToNullPaths == nil { 27 | setToNullPaths = []string{} 28 | } 29 | if len(fieldMaskPaths) == 0 && len(setToNullPaths) == 0 { 30 | return nil, fmt.Errorf("%s: both fieldMaskPaths and setToNullPaths are zero len: %w", op, ErrInvalidParameter) 31 | } 32 | 33 | inter, maskPaths, nullPaths, err := Intersection(fieldMaskPaths, setToNullPaths) 34 | if err != nil { 35 | return nil, fmt.Errorf("%s: %w", op, err) 36 | } 37 | if len(inter) != 0 { 38 | return nil, fmt.Errorf("%s: fieldMashPaths and setToNullPaths cannot intersect: %w", op, ErrInvalidParameter) 39 | } 40 | updateFields := map[string]interface{}{} // case sensitive update fields to values 41 | found := map[string]struct{}{} // we need something to keep track of found fields (case insensitive) 42 | val := reflect.Indirect(reflect.ValueOf(i)) 43 | if val.Kind() != reflect.Struct { // a typed nil pointer or a non-struct would otherwise panic on val.Type().NumField() 44 | return nil, fmt.Errorf("%s: interface must be a struct or a non-nil pointer to a struct: %w", op, ErrInvalidParameter) 45 | } 46 | structTyp := val.Type() 47 | for i := 0; i < structTyp.NumField(); i++ { 48 | if f, ok := maskPaths[strings.ToUpper(structTyp.Field(i).Name)]; ok { 49 | updateFields[f] = val.Field(i).Interface() 50 | found[strings.ToUpper(f)] = struct{}{} 51 | continue 52 | } 53 | if f, ok := nullPaths[strings.ToUpper(structTyp.Field(i).Name)]; ok { 54 | updateFields[f] = gorm.Expr("NULL") 55 | found[strings.ToUpper(f)] = struct{}{} 56 | continue 57 | } 58 | kind := structTyp.Field(i).Type.Kind() 59 | if kind == reflect.Struct || kind == reflect.Ptr { 60 | embType := structTyp.Field(i).Type 61 | //
check if the embedded field is exported via CanInterface() 62 | if val.Field(i).CanInterface() { 63 | embVal := reflect.Indirect(reflect.ValueOf(val.Field(i).Interface())) 64 | // if it's a ptr to a struct, then we need a few more bits before proceeding. 65 | if kind == reflect.Ptr { 66 | embVal = val.Field(i).Elem() 67 | if !embVal.IsValid() { 68 | continue 69 | } 70 | embType = embVal.Type() 71 | if embType.Kind() != reflect.Struct { 72 | continue 73 | } 74 | } 75 | for embFieldNum := 0; embFieldNum < embType.NumField(); embFieldNum++ { 76 | if f, ok := maskPaths[strings.ToUpper(embType.Field(embFieldNum).Name)]; ok { 77 | updateFields[f] = embVal.Field(embFieldNum).Interface() 78 | found[strings.ToUpper(f)] = struct{}{} 79 | } 80 | if f, ok := nullPaths[strings.ToUpper(embType.Field(embFieldNum).Name)]; ok { 81 | updateFields[f] = gorm.Expr("NULL") 82 | found[strings.ToUpper(f)] = struct{}{} 83 | } 84 | } 85 | continue 86 | } 87 | } 88 | } 89 | 90 | if missing := findMissingPaths(setToNullPaths, found); len(missing) != 0 { 91 | return nil, fmt.Errorf("%s: null paths not found in resource: %s: %w", op, missing, ErrInvalidParameter) 92 | } 93 | 94 | if missing := findMissingPaths(fieldMaskPaths, found); len(missing) != 0 { 95 | return nil, fmt.Errorf("%s: field mask paths not found in resource: %s: %w", op, missing, ErrInvalidParameter) 96 | } 97 | 98 | return updateFields, nil 99 | } 100 | 101 | func findMissingPaths(paths []string, foundPaths map[string]struct{}) []string { 102 | notFound := []string{} 103 | for _, f := range paths { 104 | if _, ok := foundPaths[strings.ToUpper(f)]; !ok { 105 | notFound = append(notFound, f) 106 | } 107 | } 108 | return notFound 109 | } 110 | 111 | // Intersection is a case-insensitive search for intersecting values.
Returns 112 | // []string of the Intersection with values in lowercase, and map[string]string 113 | // of the original av and bv, with the key set to uppercase and value set to the 114 | // original 115 | func Intersection(av, bv []string) ([]string, map[string]string, map[string]string, error) { 116 | const op = "dbw.Intersection" 117 | if av == nil { 118 | return nil, nil, nil, fmt.Errorf("%s: av is missing: %w", op, ErrInvalidParameter) 119 | } 120 | if bv == nil { 121 | return nil, nil, nil, fmt.Errorf("%s: bv is missing: %w", op, ErrInvalidParameter) 122 | } 123 | if len(av) == 0 && len(bv) == 0 { 124 | return []string{}, map[string]string{}, map[string]string{}, nil 125 | } 126 | s := []string{} 127 | ah := map[string]string{} 128 | bh := map[string]string{} 129 | 130 | for i := 0; i < len(av); i++ { 131 | ah[strings.ToUpper(av[i])] = av[i] 132 | } 133 | for i := 0; i < len(bv); i++ { 134 | k := strings.ToUpper(bv[i]) 135 | bh[k] = bv[i] 136 | if _, found := ah[k]; found { 137 | s = append(s, strings.ToLower(bh[k])) 138 | } 139 | } 140 | return s, ah, bh, nil 141 | } 142 | 143 | // BuildUpdatePaths takes a map of field names to field values, field masks, 144 | // fields allowed to be zero value, and returns both a list of field names to 145 | // update and a list of field names that should be set to null. 
146 | func BuildUpdatePaths(fieldValues map[string]interface{}, fieldMask []string, allowZeroFields []string) (masks []string, nulls []string) { 147 | for f, v := range fieldValues { 148 | if !contains(fieldMask, f) { 149 | continue 150 | } 151 | switch { 152 | case isZero(v) && !contains(allowZeroFields, f): 153 | nulls = append(nulls, f) 154 | default: 155 | masks = append(masks, f) 156 | } 157 | } 158 | return masks, nulls 159 | } 160 | // isZero reports whether i is nil or deeply equal to the zero value of its dynamic type. 161 | func isZero(i interface{}) bool { 162 | return i == nil || reflect.DeepEqual(i, reflect.Zero(reflect.TypeOf(i)).Interface()) 163 | } 164 | -------------------------------------------------------------------------------- /internal/dbtest/db.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) HashiCorp, Inc. 2 | // SPDX-License-Identifier: MPL-2.0 3 | 4 | // Package dbtest provides some helper funcs for testing db integrations 5 | package dbtest 6 | 7 | import ( 8 | "context" 9 | "fmt" 10 | 11 | "github.com/hashicorp/go-dbw" 12 | "gorm.io/gorm" 13 | 14 | "github.com/hashicorp/go-secure-stdlib/base62" 15 | "google.golang.org/protobuf/proto" 16 | ) 17 | 18 | const ( 19 | defaultUserTablename = "db_test_user" 20 | defaultCarTableName = "db_test_car" 21 | defaultRentalTableName = "db_test_rental" 22 | defaultScooterTableName = "db_test_scooter" 23 | ) 24 | // TestUser wraps the generated StoreTestUser; a non-empty table overrides the default table name returned by TableName. 25 | type TestUser struct { 26 | *StoreTestUser 27 | table string `gorm:"-"` 28 | } 29 | // NewTestUser returns a TestUser whose PublicId is a random base62 string (length 20 is requested from base62.Random). 30 | func NewTestUser() (*TestUser, error) { 31 | publicId, err := base62.Random(20) 32 | if err != nil { 33 | return nil, err 34 | } 35 | return &TestUser{ 36 | StoreTestUser: &StoreTestUser{ 37 | PublicId: publicId, 38 | }, 39 | }, nil 40 | } 41 | // AllocTestUser returns a TestUser value with an empty, non-nil StoreTestUser. 42 | func AllocTestUser() TestUser { 43 | return TestUser{ 44 | StoreTestUser: &StoreTestUser{}, 45 | } 46 | } 47 | 48 | // Clone returns a new TestUser holding a proto.Clone copy of the embedded StoreTestUser (the table override is not copied); it is useful when you're retrying transactions and you need to send the user several times 49 | func (u *TestUser) Clone() interface{} { 50 | s :=
proto.Clone(u.StoreTestUser)
	return &TestUser{
		StoreTestUser: s.(*StoreTestUser),
	}
}

// TableName returns the table name set via SetTableName, falling back to
// defaultUserTablename when none has been set.
func (u *TestUser) TableName() string {
	if u.table != "" {
		return u.table
	}
	return defaultUserTablename
}

// SetTableName overrides the table used for this TestUser. An empty name
// resets the override to defaultUserTablename.
func (u *TestUser) SetTableName(name string) {
	switch name {
	case "":
		u.table = defaultUserTablename
	default:
		u.table = name
	}
}

// Compile-time check that *TestUser implements dbw.VetForWriter.
var _ dbw.VetForWriter = (*TestUser)(nil)

// VetForWrite validates a TestUser before it is written. It rejects:
//   - a missing PublicId;
//   - the sentinel name "fail-VetForWrite" (presumably a hook that lets tests
//     force a vet failure — confirm against callers);
//   - update field masks that include PublicId, CreateTime or UpdateTime;
//   - creates where CreateTime is already populated.
//
// Every rejection wraps dbw.ErrInvalidParameter.
func (u *TestUser) VetForWrite(ctx context.Context, r dbw.Reader, opType dbw.OpType, opt ...dbw.Option) error {
	const op = "dbtest.(TestUser).VetForWrite"
	if u.PublicId == "" {
		return fmt.Errorf("%s: missing public id: %w", op, dbw.ErrInvalidParameter)
	}
	if u.Name == "fail-VetForWrite" {
		return fmt.Errorf("%s: name was fail-VetForWrite: %w", op, dbw.ErrInvalidParameter)
	}
	switch opType {
	case dbw.UpdateOp:
		// Only the field mask paths passed through opt are checked here.
		dbOptions := dbw.GetOpts(opt...)
		for _, path := range dbOptions.WithFieldMaskPaths {
			switch path {
			case "PublicId", "CreateTime", "UpdateTime":
				return fmt.Errorf("%s: %s is immutable: %w", op, path, dbw.ErrInvalidParameter)
			}
		}
	case dbw.CreateOp:
		// CreateTime must be left unset on create; the database assigns it.
		if u.CreateTime != nil {
			return fmt.Errorf("%s: create time is set by the database: %w", op, dbw.ErrInvalidParameter)
		}
	}
	return nil
}

// TestCar is a test model wrapping StoreTestCar. When table is non-empty it
// overrides the table name; the gorm:"-" tag keeps it out of column mapping.
type TestCar struct {
	*StoreTestCar
	table string `gorm:"-"`
}

// NewTestCar returns a TestCar whose PublicId is generated by
// base62.Random(20). It returns any error from that generator.
func NewTestCar() (*TestCar, error) {
	publicId, err := base62.Random(20)
	if err != nil {
		return nil, err
	}
	return &TestCar{
		StoreTestCar: &StoreTestCar{
			PublicId: publicId,
		},
	}, nil
}

// TableName returns the overridden table name, or defaultCarTableName when
// no override is set.
func (c *TestCar) TableName() string {
	if c.table != "" {
		return c.table
	}

	return defaultCarTableName
}

// SetTableName stores name as the table override. An empty name is stored
// as-is, which makes TableName fall back to defaultCarTableName.
func (c *TestCar) SetTableName(name string) {
	c.table = name
}

// Cloner is implemented by types that can return a copy of themselves.
type Cloner interface {
	Clone() interface{}
}

// NotIder is a field-less Cloner; presumably used by tests that need a value
// with no id fields — TODO confirm against its callers.
type NotIder struct{}

// Clone returns a new, empty NotIder.
func (i *NotIder) Clone() interface{} {
	return &NotIder{}
}

// TestRental is a test model wrapping StoreTestRental. When table is
// non-empty it overrides the table name.
type TestRental struct {
	*StoreTestRental
	table string `gorm:"-"`
}

// NewTestRental returns a TestRental for the given user and car ids. The
// returned error is always nil.
func NewTestRental(userId, carId string) (*TestRental, error) {
	return &TestRental{
		StoreTestRental: &StoreTestRental{
			UserId: userId,
			CarId:  carId,
		},
	}, nil
}

// Clone returns a deep copy (via proto.Clone) of the rental. This is useful when
// retrying transactions that must resend the same value several times.
func (t *TestRental) Clone() interface{} {
	s := proto.Clone(t.StoreTestRental)
	return &TestRental{
		StoreTestRental: s.(*StoreTestRental),
	}
}

// TableName returns the overridden table name, or defaultRentalTableName
// when no override is set.
func (r *TestRental) TableName() string {
	if r.table != "" {
		return r.table
	}

	return defaultRentalTableName
}

// SetTableName stores name as the table override (empty means default).
func (r *TestRental) SetTableName(name string) {
	r.table = name
}

// TestScooter is a test model wrapping StoreTestScooter. Its key field is
// PrivateId rather than PublicId.
type TestScooter struct {
	*StoreTestScooter
	table string `gorm:"-"`
}

// NewTestScooter returns a TestScooter whose PrivateId is generated by
// base62.Random(20). It returns any error from that generator.
func NewTestScooter() (*TestScooter, error) {
	privateId, err := base62.Random(20)
	if err != nil {
		return nil, err
	}
	return &TestScooter{
		StoreTestScooter: &StoreTestScooter{
			PrivateId: privateId,
		},
	}, nil
}

// Clone returns a deep copy (via proto.Clone) of the scooter.
func (t *TestScooter) Clone() interface{} {
	s := proto.Clone(t.StoreTestScooter)
	return &TestScooter{
		StoreTestScooter: s.(*StoreTestScooter),
	}
}

// TableName returns the overridden table name, or defaultScooterTableName
// when no override is set.
func (t *TestScooter) TableName() string {
	if t.table != "" {
		return t.table
	}
	return defaultScooterTableName
}

// SetTableName stores name as the table override (empty means default).
func (t *TestScooter) SetTableName(name string) {
	t.table = name
}

// The TestWith* types below each implement a single gorm hook as a no-op.
// They are presumably used to check that dbw rejects models carrying gorm
// hooks (see raiseErrorOnHooks) — confirm against the tests using them.

// TestWithBeforeCreate implements a no-op gorm BeforeCreate hook.
type TestWithBeforeCreate struct{}

func (t *TestWithBeforeCreate) BeforeCreate(_ *gorm.DB) error {
	return nil
}

// TestWithAfterCreate implements a no-op gorm AfterCreate hook.
type TestWithAfterCreate struct{}

func (t *TestWithAfterCreate) AfterCreate(_ *gorm.DB) error {
	return nil
}

// TestWithBeforeSave implements a no-op gorm BeforeSave hook.
type TestWithBeforeSave struct{}

func (t *TestWithBeforeSave) BeforeSave(_ *gorm.DB) error {
	return nil
}

// TestWithAfterSave implements a no-op gorm AfterSave hook.
type TestWithAfterSave struct{}

func (t *TestWithAfterSave) AfterSave(_ *gorm.DB) error {
	return nil
}

// TestWithBeforeUpdate implements a no-op gorm BeforeUpdate hook.
type TestWithBeforeUpdate struct{}

func (t *TestWithBeforeUpdate) BeforeUpdate(_ *gorm.DB) error {
	return nil
}

// TestWithAfterUpdate implements a no-op gorm AfterUpdate hook.
type TestWithAfterUpdate struct{}

func (t *TestWithAfterUpdate) AfterUpdate(_ *gorm.DB) error {
	return nil
}

// TestWithBeforeDelete implements a no-op gorm BeforeDelete hook.
type TestWithBeforeDelete struct{}

func (t *TestWithBeforeDelete) BeforeDelete(_ *gorm.DB) error {
	return nil
}

// TestWithAfterDelete implements a no-op gorm AfterDelete hook.
type TestWithAfterDelete struct{}

func (t *TestWithAfterDelete) AfterDelete(_ *gorm.DB) error {
	return nil
}

// TestWithAfterFind implements a no-op gorm AfterFind hook.
type TestWithAfterFind struct{}

func (t *TestWithAfterFind) AfterFind(_ *gorm.DB) error {
	return nil
}

--------------------------------------------------------------------------------
/db_test.go:
--------------------------------------------------------------------------------
// Copyright (c) HashiCorp, Inc.
2 | // SPDX-License-Identifier: MPL-2.0 3 | 4 | package dbw_test 5 | 6 | import ( 7 | "context" 8 | "fmt" 9 | "io/ioutil" 10 | "os" 11 | "strings" 12 | "sync" 13 | "testing" 14 | 15 | "github.com/hashicorp/go-dbw" 16 | "github.com/hashicorp/go-hclog" 17 | "github.com/stretchr/testify/assert" 18 | "github.com/stretchr/testify/require" 19 | "gorm.io/driver/sqlite" 20 | ) 21 | 22 | func TestOpen(t *testing.T) { 23 | ctx := context.Background() 24 | _, url := dbw.TestSetup(t) 25 | 26 | type args struct { 27 | dbType dbw.DbType 28 | connectionUrl string 29 | opts []dbw.Option 30 | } 31 | tests := []struct { 32 | name string 33 | args args 34 | wantErr bool 35 | }{ 36 | { 37 | name: "valid-sqlite-with-opts", 38 | args: args{ 39 | dbType: dbw.Sqlite, 40 | connectionUrl: url, 41 | opts: []dbw.Option{ 42 | dbw.WithMinOpenConnections(1), 43 | dbw.WithMaxOpenConnections(2), 44 | dbw.WithLogger(hclog.New(hclog.DefaultOptions)), 45 | }, 46 | }, 47 | wantErr: false, 48 | }, 49 | { 50 | name: "valid-sqlite-no-opts", 51 | args: args{ 52 | dbType: dbw.Sqlite, 53 | connectionUrl: url, 54 | }, 55 | wantErr: false, 56 | }, 57 | { 58 | name: "invalid-connection-opts", 59 | args: args{ 60 | dbType: dbw.Sqlite, 61 | connectionUrl: url, 62 | opts: []dbw.Option{ 63 | dbw.WithMinOpenConnections(3), 64 | dbw.WithMaxOpenConnections(2), 65 | dbw.WithLogger(hclog.New(hclog.DefaultOptions)), 66 | }, 67 | }, 68 | wantErr: true, 69 | }, 70 | { 71 | name: "missing-url", 72 | args: args{ 73 | dbType: dbw.Sqlite, 74 | connectionUrl: "", 75 | }, 76 | wantErr: true, 77 | }, 78 | { 79 | name: "invalid-url", 80 | args: args{ 81 | dbType: dbw.Sqlite, 82 | connectionUrl: "file::memory:?cache=invalid-parameter", 83 | opts: []dbw.Option{dbw.WithLogger(hclog.New( 84 | &hclog.LoggerOptions{ 85 | Output: ioutil.Discard, 86 | }, 87 | ))}, 88 | }, 89 | wantErr: true, 90 | }, 91 | { 92 | name: "unknown-type", 93 | args: args{ 94 | dbType: dbw.UnknownDB, 95 | connectionUrl: url, 96 | }, 97 | wantErr: true, 98 | }, 
99 | } 100 | for _, tt := range tests { 101 | t.Run(tt.name, func(t *testing.T) { 102 | require := require.New(t) 103 | t.Cleanup(func() { 104 | os.Remove(tt.args.connectionUrl + "-journal") 105 | os.Remove(tt.args.connectionUrl) 106 | }) 107 | 108 | got, err := dbw.Open(tt.args.dbType, tt.args.connectionUrl, tt.args.opts...) 109 | defer func() { 110 | if err == nil { 111 | err = got.Close(ctx) 112 | require.NoError(err) 113 | } 114 | }() 115 | if tt.wantErr { 116 | require.Error(err) 117 | return 118 | } 119 | require.NoError(err) 120 | rw := dbw.New(got) 121 | rows, err := rw.Query(context.Background(), "PRAGMA foreign_keys", nil) 122 | require.NoError(err) 123 | require.True(rows.Next()) 124 | type foo struct{} 125 | f := struct { 126 | ForeignKeys int 127 | }{} 128 | err = rw.ScanRows(rows, &f) 129 | require.NoError(err) 130 | require.Equal(1, f.ForeignKeys) 131 | fmt.Println(f) 132 | }) 133 | } 134 | } 135 | 136 | func TestDB_OpenWith(t *testing.T) { 137 | t.Run("simple-sqlite", func(t *testing.T) { 138 | assert := assert.New(t) 139 | _, err := dbw.OpenWith(sqlite.Open("file::memory:"), nil) 140 | assert.NoError(err) 141 | }) 142 | t.Run("sqlite-with-logger", func(t *testing.T) { 143 | assert, require := assert.New(t), require.New(t) 144 | buf := new(strings.Builder) 145 | testLock := &sync.Mutex{} 146 | testLogger := hclog.New(&hclog.LoggerOptions{ 147 | Mutex: testLock, 148 | Name: "test", 149 | JSONFormat: true, 150 | Output: buf, 151 | Level: hclog.Debug, 152 | }) 153 | d, err := dbw.OpenWith(sqlite.Open("file::memory:"), dbw.WithLogger(gormDebugLogger{Logger: testLogger}), dbw.WithDebug(true)) 154 | require.NoError(err) 155 | require.NotEmpty(d) 156 | rw := dbw.New(d) 157 | const sql = "select 'hello world'" 158 | testCtx := context.Background() 159 | rows, err := rw.Query(testCtx, sql, nil) 160 | require.NoError(err) 161 | defer rows.Close() 162 | assert.Contains(buf.String(), sql) 163 | t.Log(buf.String()) 164 | }) 165 | } 166 | 167 | type 
gormDebugLogger struct { 168 | hclog.Logger 169 | } 170 | 171 | func (g gormDebugLogger) Printf(msg string, values ...interface{}) { 172 | b := new(strings.Builder) 173 | fmt.Fprintf(b, msg, values...) 174 | g.Debug(b.String()) 175 | } 176 | 177 | func getGormLogger(log hclog.Logger) gormDebugLogger { 178 | return gormDebugLogger{Logger: log} 179 | } 180 | 181 | func TestDB_StringToDbType(t *testing.T) { 182 | tests := []struct { 183 | name string 184 | want dbw.DbType 185 | wantErr bool 186 | }{ 187 | {name: "postgres", want: dbw.Postgres}, 188 | {name: "sqlite", want: dbw.Sqlite}, 189 | {name: "unknown", want: dbw.UnknownDB, wantErr: true}, 190 | } 191 | for _, tt := range tests { 192 | t.Run(tt.name, func(t *testing.T) { 193 | assert, require := assert.New(t), require.New(t) 194 | got, err := dbw.StringToDbType(tt.name) 195 | if tt.wantErr { 196 | require.Error(err) 197 | return 198 | } 199 | require.NoError(err) 200 | assert.Equal(got, tt.want) 201 | }) 202 | } 203 | } 204 | 205 | func TestDB_SqlDB(t *testing.T) { 206 | testCtx := context.Background() 207 | t.Run("valid", func(t *testing.T) { 208 | assert, require := assert.New(t), require.New(t) 209 | db, err := dbw.Open(dbw.Sqlite, "file::memory:") 210 | require.NoError(err) 211 | got, err := db.SqlDB(testCtx) 212 | require.NoError(err) 213 | assert.NotNil(got) 214 | }) 215 | 216 | t.Run("invalid", func(t *testing.T) { 217 | assert, require := assert.New(t), require.New(t) 218 | db := &dbw.DB{} 219 | got, err := db.SqlDB(testCtx) 220 | require.Error(err) 221 | assert.Nil(got) 222 | }) 223 | } 224 | 225 | func TestDB_Close(t *testing.T) { 226 | testCtx := context.Background() 227 | t.Run("valid", func(t *testing.T) { 228 | assert, require := assert.New(t), require.New(t) 229 | db, err := dbw.Open(dbw.Sqlite, "file::memory:") 230 | require.NoError(err) 231 | got, err := db.SqlDB(testCtx) 232 | require.NoError(err) 233 | require.NotNil(got) 234 | assert.NoError(got.Close()) 235 | }) 236 | t.Run("invalid", func(t 
*testing.T) { 237 | assert := assert.New(t) 238 | db := &dbw.DB{} 239 | err := db.Close(testCtx) 240 | assert.Error(err) 241 | }) 242 | } 243 | 244 | func TestDB_LogLevel(t *testing.T) { 245 | tests := []struct { 246 | name string 247 | level dbw.LogLevel 248 | }{ 249 | {"default", dbw.Default}, 250 | {"silent", dbw.Silent}, 251 | {"error", dbw.Error}, 252 | {"warn", dbw.Warn}, 253 | {"info", dbw.Info}, 254 | } 255 | for _, tt := range tests { 256 | t.Run(tt.name, func(t *testing.T) { 257 | db, _ := dbw.TestSetup(t) 258 | db.LogLevel(tt.level) 259 | }) 260 | } 261 | } 262 | -------------------------------------------------------------------------------- /delete.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) HashiCorp, Inc. 2 | // SPDX-License-Identifier: MPL-2.0 3 | 4 | package dbw 5 | 6 | import ( 7 | "context" 8 | "fmt" 9 | "reflect" 10 | ) 11 | 12 | // Delete a resource in the db with options: WithWhere, WithDebug, WithTable, 13 | // and WithVersion. WithWhere and WithVersion allows specifying a additional 14 | // constraints on the operation in addition to the PKs. Delete returns the 15 | // number of rows deleted and any errors. 16 | func (rw *RW) Delete(ctx context.Context, i interface{}, opt ...Option) (int, error) { 17 | const op = "dbw.Delete" 18 | if rw.underlying == nil { 19 | return noRowsAffected, fmt.Errorf("%s: missing underlying db: %w", op, ErrInvalidParameter) 20 | } 21 | if isNil(i) { 22 | return noRowsAffected, fmt.Errorf("%s: missing interface: %w", op, ErrInvalidParameter) 23 | } 24 | if err := raiseErrorOnHooks(i); err != nil { 25 | return noRowsAffected, fmt.Errorf("%s: %w", op, err) 26 | } 27 | opts := GetOpts(opt...) 
28 | 29 | mDb := rw.underlying.wrapped.Model(i) 30 | err := mDb.Statement.Parse(i) 31 | if err == nil && mDb.Statement.Schema == nil { 32 | return noRowsAffected, fmt.Errorf("%s: (internal error) unable to parse stmt: %w", op, ErrUnknown) 33 | } 34 | reflectValue := reflect.Indirect(reflect.ValueOf(i)) 35 | for _, pf := range mDb.Statement.Schema.PrimaryFields { 36 | if _, isZero := pf.ValueOf(ctx, reflectValue); isZero { 37 | return noRowsAffected, fmt.Errorf("%s: primary key %s is not set: %w", op, pf.Name, ErrInvalidParameter) 38 | } 39 | } 40 | if opts.WithBeforeWrite != nil { 41 | if err := opts.WithBeforeWrite(i); err != nil { 42 | return noRowsAffected, fmt.Errorf("%s: error before write: %w", op, err) 43 | } 44 | } 45 | db := rw.underlying.wrapped.WithContext(ctx) 46 | if opts.WithVersion != nil || opts.WithWhereClause != "" { 47 | where, args, err := rw.whereClausesFromOpts(ctx, i, opts) 48 | if err != nil { 49 | return noRowsAffected, fmt.Errorf("%s: %w", op, err) 50 | } 51 | db = db.Where(where, args...) 52 | } 53 | if opts.WithDebug { 54 | db = db.Debug() 55 | } 56 | if opts.WithTable != "" { 57 | db = db.Table(opts.WithTable) 58 | } 59 | db = db.Delete(i) 60 | if db.Error != nil { 61 | return noRowsAffected, fmt.Errorf("%s: %w", op, db.Error) 62 | } 63 | rowsDeleted := int(db.RowsAffected) 64 | if rowsDeleted > 0 && opts.WithAfterWrite != nil { 65 | if err := opts.WithAfterWrite(i, rowsDeleted); err != nil { 66 | return rowsDeleted, fmt.Errorf("%s: error after write: %w", op, err) 67 | } 68 | } 69 | return rowsDeleted, nil 70 | } 71 | 72 | // DeleteItems will delete multiple items of the same type. 
Options supported: 73 | // WithWhereClause, WithDebug, WithTable 74 | func (rw *RW) DeleteItems(ctx context.Context, deleteItems interface{}, opt ...Option) (int, error) { 75 | const op = "dbw.DeleteItems" 76 | switch { 77 | case rw.underlying == nil: 78 | return noRowsAffected, fmt.Errorf("%s: missing underlying db: %w", op, ErrInvalidParameter) 79 | case isNil(deleteItems): 80 | return noRowsAffected, fmt.Errorf("%s: no interfaces to delete: %w", op, ErrInvalidParameter) 81 | } 82 | valDeleteItems := reflect.ValueOf(deleteItems) 83 | switch { 84 | case valDeleteItems.Kind() != reflect.Slice: 85 | return noRowsAffected, fmt.Errorf("%s: not a slice: %w", op, ErrInvalidParameter) 86 | case valDeleteItems.Len() == 0: 87 | return noRowsAffected, fmt.Errorf("%s: missing items: %w", op, ErrInvalidParameter) 88 | 89 | } 90 | if err := raiseErrorOnHooks(deleteItems); err != nil { 91 | return noRowsAffected, fmt.Errorf("%s: %w", op, err) 92 | } 93 | 94 | opts := GetOpts(opt...) 95 | switch { 96 | case opts.WithLookup: 97 | return noRowsAffected, fmt.Errorf("%s: with lookup not a supported option: %w", op, ErrInvalidParameter) 98 | case opts.WithVersion != nil: 99 | return noRowsAffected, fmt.Errorf("%s: with version is not a supported option: %w", op, ErrInvalidParameter) 100 | } 101 | 102 | // we need to dig out the stmt so in just a sec we can make sure the PKs are 103 | // set for all the items, so we'll just use the first item to do so. 
104 | mDb := rw.underlying.wrapped.Model(valDeleteItems.Index(0).Interface()) 105 | err := mDb.Statement.Parse(valDeleteItems.Index(0).Interface()) 106 | switch { 107 | case err != nil: 108 | return noRowsAffected, fmt.Errorf("%s: (internal error) error parsing stmt: %w", op, err) 109 | case err == nil && mDb.Statement.Schema == nil: 110 | return noRowsAffected, fmt.Errorf("%s: (internal error) unable to parse stmt: %w", op, ErrUnknown) 111 | } 112 | 113 | // verify that deleteItems are all the same type, among a myriad of 114 | // other things on the set of items 115 | var foundType reflect.Type 116 | 117 | for i := 0; i < valDeleteItems.Len(); i++ { 118 | if i == 0 { 119 | foundType = reflect.TypeOf(valDeleteItems.Index(i).Interface()) 120 | } 121 | currentType := reflect.TypeOf(valDeleteItems.Index(i).Interface()) 122 | switch { 123 | case isNil(valDeleteItems.Index(i).Interface()) || currentType == nil: 124 | return noRowsAffected, fmt.Errorf("%s: unable to determine type of item %d: %w", op, i, ErrInvalidParameter) 125 | case foundType != currentType: 126 | return noRowsAffected, fmt.Errorf("%s: items contain disparate types. 
item %d is not a %s: %w", op, i, foundType.Name(), ErrInvalidParameter) 127 | } 128 | if opts.WithWhereClause == "" { 129 | // make sure the PK is set for the current item 130 | reflectValue := reflect.Indirect(reflect.ValueOf(valDeleteItems.Index(i).Interface())) 131 | for _, pf := range mDb.Statement.Schema.PrimaryFields { 132 | if _, isZero := pf.ValueOf(ctx, reflectValue); isZero { 133 | return noRowsAffected, fmt.Errorf("%s: primary key %s is not set: %w", op, pf.Name, ErrInvalidParameter) 134 | } 135 | } 136 | } 137 | } 138 | 139 | if opts.WithBeforeWrite != nil { 140 | if err := opts.WithBeforeWrite(deleteItems); err != nil { 141 | return noRowsAffected, fmt.Errorf("%s: error before write: %w", op, err) 142 | } 143 | } 144 | 145 | db := rw.underlying.wrapped.WithContext(ctx) 146 | if opts.WithDebug { 147 | db = db.Debug() 148 | } 149 | 150 | if opts.WithWhereClause != "" { 151 | where, args, err := rw.whereClausesFromOpts(ctx, valDeleteItems.Index(0).Interface(), opts) 152 | if err != nil { 153 | return noRowsAffected, fmt.Errorf("%s: %w", op, err) 154 | } 155 | db = db.Where(where, args...) 
156 | } 157 | 158 | switch { 159 | case opts.WithTable != "": 160 | db = db.Table(opts.WithTable) 161 | default: 162 | tabler, ok := valDeleteItems.Index(0).Interface().(tableNamer) 163 | if ok { 164 | db = db.Table(tabler.TableName()) 165 | } 166 | } 167 | 168 | db = db.Delete(deleteItems) 169 | if db.Error != nil { 170 | return noRowsAffected, fmt.Errorf("%s: %w", op, db.Error) 171 | } 172 | rowsDeleted := int(db.RowsAffected) 173 | if rowsDeleted > 0 && opts.WithAfterWrite != nil { 174 | if err := opts.WithAfterWrite(deleteItems, int(rowsDeleted)); err != nil { 175 | return rowsDeleted, fmt.Errorf("%s: error after write: %w", op, err) 176 | } 177 | } 178 | return rowsDeleted, nil 179 | } 180 | 181 | type tableNamer interface { 182 | TableName() string 183 | } 184 | -------------------------------------------------------------------------------- /update.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) HashiCorp, Inc. 2 | // SPDX-License-Identifier: MPL-2.0 3 | 4 | package dbw 5 | 6 | import ( 7 | "context" 8 | "fmt" 9 | "reflect" 10 | "sync/atomic" 11 | 12 | "gorm.io/gorm" 13 | ) 14 | 15 | var nonUpdateFields atomic.Value 16 | 17 | // InitNonUpdatableFields sets the fields which are not updatable using 18 | // via RW.Update(...) 19 | func InitNonUpdatableFields(fields []string) { 20 | m := make(map[string]struct{}, len(fields)) 21 | for _, f := range fields { 22 | m[f] = struct{}{} 23 | } 24 | nonUpdateFields.Store(m) 25 | } 26 | 27 | // NonUpdatableFields returns the current set of fields which are not updatable using 28 | // via RW.Update(...) 
29 | func NonUpdatableFields() []string { 30 | m := nonUpdateFields.Load() 31 | if m == nil { 32 | return []string{} 33 | } 34 | 35 | fields := make([]string, 0, len(m.(map[string]struct{}))) 36 | for f := range m.(map[string]struct{}) { 37 | fields = append(fields, f) 38 | } 39 | return fields 40 | } 41 | 42 | // Update a resource in the db, a fieldMask is required and provides 43 | // field_mask.proto paths for fields that should be updated. The i interface 44 | // parameter is the type the caller wants to update in the db and its fields are 45 | // set to the update values. setToNullPaths is optional and provides 46 | // field_mask.proto paths for the fields that should be set to null. 47 | // fieldMaskPaths and setToNullPaths must not intersect. The caller is 48 | // responsible for the transaction life cycle of the writer and if an error is 49 | // returned the caller must decide what to do with the transaction, which almost 50 | // always should be to rollback. Update returns the number of rows updated. 51 | // 52 | // Supported options: WithBeforeWrite, WithAfterWrite, WithWhere, WithDebug, 53 | // WithTable and WithVersion. If WithVersion is used, then the update will 54 | // include the version number in the update where clause, which basically makes 55 | // the update use optimistic locking and the update will only succeed if the 56 | // existing rows version matches the WithVersion option. Zero is not a valid 57 | // value for the WithVersion option and will return an error. WithWhere allows 58 | // specifying an additional constraint on the operation in addition to the PKs. 59 | // WithDebug will turn on debugging for the update call. 
60 | func (rw *RW) Update(ctx context.Context, i interface{}, fieldMaskPaths []string, setToNullPaths []string, opt ...Option) (int, error) { 61 | const op = "dbw.Update" 62 | if rw.underlying == nil { 63 | return noRowsAffected, fmt.Errorf("%s: missing underlying db: %w", op, ErrInvalidParameter) 64 | } 65 | if isNil(i) { 66 | return noRowsAffected, fmt.Errorf("%s: missing interface: %w", op, ErrInvalidParameter) 67 | } 68 | if err := raiseErrorOnHooks(i); err != nil { 69 | return noRowsAffected, fmt.Errorf("%s: %w", op, err) 70 | } 71 | if len(fieldMaskPaths) == 0 && len(setToNullPaths) == 0 { 72 | return noRowsAffected, fmt.Errorf("%s: both fieldMaskPaths and setToNullPaths are missing: %w", op, ErrInvalidParameter) 73 | } 74 | opts := GetOpts(opt...) 75 | 76 | // we need to filter out some non-updatable fields (like: CreateTime, etc) 77 | fieldMaskPaths = filterPaths(fieldMaskPaths) 78 | setToNullPaths = filterPaths(setToNullPaths) 79 | if len(fieldMaskPaths) == 0 && len(setToNullPaths) == 0 { 80 | return noRowsAffected, fmt.Errorf("%s: after filtering non-updated fields, there are no fields left in fieldMaskPaths or setToNullPaths: %w", op, ErrInvalidParameter) 81 | } 82 | 83 | updateFields, err := UpdateFields(i, fieldMaskPaths, setToNullPaths) 84 | if err != nil { 85 | return noRowsAffected, fmt.Errorf("%s: getting update fields failed: %w", op, err) 86 | } 87 | if len(updateFields) == 0 { 88 | return noRowsAffected, fmt.Errorf("%s: no fields matched using fieldMaskPaths %s: %w", op, fieldMaskPaths, ErrInvalidParameter) 89 | } 90 | 91 | names, isZero, err := rw.primaryFieldsAreZero(ctx, i) 92 | if err != nil { 93 | return noRowsAffected, fmt.Errorf("%s: %w", op, err) 94 | } 95 | if isZero { 96 | return noRowsAffected, fmt.Errorf("%s: primary key is not set for: %s: %w", op, names, ErrInvalidParameter) 97 | } 98 | 99 | mDb := rw.underlying.wrapped.Model(i) 100 | err = mDb.Statement.Parse(i) 101 | if err != nil || mDb.Statement.Schema == nil { 102 | return 
noRowsAffected, fmt.Errorf("%s: internal error: unable to parse stmt: %w", op, err) 103 | } 104 | reflectValue := reflect.Indirect(reflect.ValueOf(i)) 105 | for _, pf := range mDb.Statement.Schema.PrimaryFields { 106 | if _, isZero := pf.ValueOf(ctx, reflectValue); isZero { 107 | return noRowsAffected, fmt.Errorf("%s: primary key %s is not set: %w", op, pf.Name, ErrInvalidParameter) 108 | } 109 | if contains(fieldMaskPaths, pf.Name) { 110 | return noRowsAffected, fmt.Errorf("%s: not allowed on primary key field %s: %w", op, pf.Name, ErrInvalidFieldMask) 111 | } 112 | } 113 | 114 | if !opts.WithSkipVetForWrite { 115 | if vetter, ok := i.(VetForWriter); ok { 116 | if err := vetter.VetForWrite(ctx, rw, UpdateOp, WithFieldMaskPaths(fieldMaskPaths), WithNullPaths(setToNullPaths)); err != nil { 117 | return noRowsAffected, fmt.Errorf("%s: %w", op, err) 118 | } 119 | } 120 | } 121 | if opts.WithBeforeWrite != nil { 122 | if err := opts.WithBeforeWrite(i); err != nil { 123 | return noRowsAffected, fmt.Errorf("%s: error before write: %w", op, err) 124 | } 125 | } 126 | underlying := rw.underlying.wrapped.Model(i) 127 | if opts.WithDebug { 128 | underlying = underlying.Debug() 129 | } 130 | if opts.WithTable != "" { 131 | underlying = underlying.Table(opts.WithTable) 132 | } 133 | switch { 134 | case opts.WithVersion != nil || opts.WithWhereClause != "": 135 | where, args, err := rw.whereClausesFromOpts(ctx, i, opts) 136 | if err != nil { 137 | return noRowsAffected, fmt.Errorf("%s: %w", op, err) 138 | } 139 | underlying = underlying.Where(where, args...).Updates(updateFields) 140 | default: 141 | underlying = underlying.Updates(updateFields) 142 | } 143 | if underlying.Error != nil { 144 | if underlying.Error == gorm.ErrRecordNotFound { 145 | return noRowsAffected, fmt.Errorf("%s: %w", op, gorm.ErrRecordNotFound) 146 | } 147 | return noRowsAffected, fmt.Errorf("%s: %w", op, underlying.Error) 148 | } 149 | rowsUpdated := int(underlying.RowsAffected) 150 | if rowsUpdated > 0 
&& (opts.WithAfterWrite != nil) { 151 | if err := opts.WithAfterWrite(i, rowsUpdated); err != nil { 152 | return rowsUpdated, fmt.Errorf("%s: error after write: %w", op, err) 153 | } 154 | } 155 | // we need to force a lookupAfterWrite so the resource returned is correctly initialized 156 | // from the db 157 | opt = append(opt, WithLookup(true)) 158 | if err := rw.lookupAfterWrite(ctx, i, opt...); err != nil { 159 | return noRowsAffected, fmt.Errorf("%s: %w", op, err) 160 | } 161 | return rowsUpdated, nil 162 | } 163 | 164 | // filterPaths will filter out non-updatable fields 165 | func filterPaths(paths []string) []string { 166 | if len(paths) == 0 { 167 | return nil 168 | } 169 | nonUpdatable := NonUpdatableFields() 170 | if len(nonUpdatable) == 0 { 171 | return paths 172 | } 173 | var filtered []string 174 | for _, p := range paths { 175 | switch { 176 | case contains(nonUpdatable, p): 177 | continue 178 | default: 179 | filtered = append(filtered, p) 180 | } 181 | } 182 | return filtered 183 | } 184 | -------------------------------------------------------------------------------- /do_tx_test.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) HashiCorp, Inc. 
2 | // SPDX-License-Identifier: MPL-2.0 3 | 4 | package dbw_test 5 | 6 | import ( 7 | "context" 8 | "errors" 9 | "fmt" 10 | "strings" 11 | "testing" 12 | "time" 13 | 14 | "github.com/hashicorp/go-dbw" 15 | "github.com/hashicorp/go-dbw/internal/dbtest" 16 | "github.com/stretchr/testify/assert" 17 | "github.com/stretchr/testify/require" 18 | ) 19 | 20 | func TestDb_DoTx(t *testing.T) { 21 | t.Parallel() 22 | testCtx := context.TODO() 23 | conn, _ := dbw.TestSetup(t) 24 | retryErr := errors.New("retry error") 25 | retryOnFn := func(err error) bool { 26 | if errors.Is(err, retryErr) { 27 | return true 28 | } 29 | return false 30 | } 31 | t.Run("timed-out", func(t *testing.T) { 32 | assert, require := assert.New(t), require.New(t) 33 | timeoutCtx, timeoutCancel := context.WithTimeout(testCtx, 1*time.Microsecond) 34 | defer timeoutCancel() 35 | 36 | w := dbw.New(conn) 37 | attempts := 0 38 | _, err := w.DoTx(timeoutCtx, retryOnFn, 2, dbw.ConstBackoff{DurationMs: 1}, func(dbw.Reader, dbw.Writer) error { 39 | attempts += 1 40 | return retryErr 41 | }) 42 | require.Error(err) 43 | const ( 44 | cancelledMsg = "dbw.DoTx: cancelled" 45 | deadlineMsg = "dbw.DoTx: context deadline exceeded" 46 | ) 47 | switch { 48 | case strings.Contains(err.Error(), cancelledMsg), strings.Contains(err.Error(), deadlineMsg): 49 | default: 50 | assert.Failf("error does not contain %q or %q", cancelledMsg, deadlineMsg) 51 | } 52 | }) 53 | t.Run("valid-with-10-retries", func(t *testing.T) { 54 | assert, require := assert.New(t), require.New(t) 55 | w := dbw.New(conn) 56 | attempts := 0 57 | got, err := w.DoTx(testCtx, retryOnFn, 10, dbw.ExpBackoff{}, 58 | func(dbw.Reader, dbw.Writer) error { 59 | attempts += 1 60 | if attempts < 9 { 61 | return retryErr 62 | } 63 | return nil 64 | }) 65 | require.NoError(err) 66 | assert.Equal(8, got.Retries) 67 | assert.Equal(9, attempts) // attempted 1 + 8 retries 68 | }) 69 | t.Run("valid-with-1-retries", func(t *testing.T) { 70 | assert, require := 
assert.New(t), require.New(t) 71 | w := dbw.New(conn) 72 | attempts := 0 73 | got, err := w.DoTx(testCtx, retryOnFn, 1, dbw.ExpBackoff{}, 74 | func(dbw.Reader, dbw.Writer) error { 75 | attempts += 1 76 | if attempts < 2 { 77 | return retryErr 78 | } 79 | return nil 80 | }) 81 | require.NoError(err) 82 | assert.Equal(1, got.Retries) 83 | assert.Equal(2, attempts) // attempted 1 + 8 retries 84 | }) 85 | t.Run("valid-with-2-retries", func(t *testing.T) { 86 | assert, require := assert.New(t), require.New(t) 87 | w := dbw.New(conn) 88 | attempts := 0 89 | got, err := w.DoTx(testCtx, retryOnFn, 3, dbw.ExpBackoff{}, 90 | func(dbw.Reader, dbw.Writer) error { 91 | attempts += 1 92 | if attempts < 3 { 93 | return retryErr 94 | } 95 | return nil 96 | }) 97 | require.NoError(err) 98 | assert.Equal(2, got.Retries) 99 | assert.Equal(3, attempts) // attempted 1 + 8 retries 100 | }) 101 | t.Run("valid-with-4-retries", func(t *testing.T) { 102 | assert, require := assert.New(t), require.New(t) 103 | w := dbw.New(conn) 104 | attempts := 0 105 | got, err := w.DoTx(testCtx, retryOnFn, 4, dbw.ExpBackoff{}, 106 | func(dbw.Reader, dbw.Writer) error { 107 | attempts += 1 108 | if attempts < 4 { 109 | return retryErr 110 | } 111 | return nil 112 | }) 113 | require.NoError(err) 114 | assert.Equal(3, got.Retries) 115 | assert.Equal(4, attempts) // attempted 1 + 8 retries 116 | }) 117 | t.Run("zero-retries", func(t *testing.T) { 118 | assert, require := assert.New(t), require.New(t) 119 | w := dbw.New(conn) 120 | attempts := 0 121 | got, err := w.DoTx(testCtx, retryOnFn, 0, dbw.ExpBackoff{}, func(dbw.Reader, dbw.Writer) error { attempts += 1; return nil }) 122 | require.NoError(err) 123 | assert.Equal(dbw.RetryInfo{}, got) 124 | assert.Equal(1, attempts) 125 | }) 126 | t.Run("nil-tx", func(t *testing.T) { 127 | assert, require := assert.New(t), require.New(t) 128 | w := &dbw.RW{} 129 | attempts := 0 130 | got, err := w.DoTx(testCtx, retryOnFn, 1, dbw.ExpBackoff{}, func(dbw.Reader, 
dbw.Writer) error { attempts += 1; return nil }) 131 | require.Error(err) 132 | assert.Equal(dbw.RetryInfo{}, got) 133 | assert.Equal("dbw.DoTx: missing underlying db: invalid parameter", err.Error()) 134 | }) 135 | t.Run("nil-retryOnFn", func(t *testing.T) { 136 | assert, require := assert.New(t), require.New(t) 137 | w := dbw.New(conn) 138 | attempts := 0 139 | got, err := w.DoTx(testCtx, nil, 1, dbw.ExpBackoff{}, func(dbw.Reader, dbw.Writer) error { attempts += 1; return nil }) 140 | require.Error(err) 141 | assert.Equal(dbw.RetryInfo{}, got) 142 | assert.Equal("dbw.DoTx: missing retry errors matching function: invalid parameter", err.Error()) 143 | }) 144 | t.Run("nil-handler", func(t *testing.T) { 145 | assert, require := assert.New(t), require.New(t) 146 | w := dbw.New(conn) 147 | got, err := w.DoTx(testCtx, retryOnFn, 1, dbw.ExpBackoff{}, nil) 148 | require.Error(err) 149 | assert.Equal(dbw.RetryInfo{}, got) 150 | assert.Equal("dbw.DoTx: missing handler: invalid parameter", err.Error()) 151 | }) 152 | t.Run("nil-backoff", func(t *testing.T) { 153 | assert, require := assert.New(t), require.New(t) 154 | w := dbw.New(conn) 155 | attempts := 0 156 | got, err := w.DoTx(testCtx, retryOnFn, 1, nil, func(dbw.Reader, dbw.Writer) error { attempts += 1; return nil }) 157 | require.Error(err) 158 | assert.Equal(dbw.RetryInfo{}, got) 159 | assert.Equal("dbw.DoTx: missing backoff: invalid parameter", err.Error()) 160 | }) 161 | t.Run("not-a-retry-err", func(t *testing.T) { 162 | assert, require := assert.New(t), require.New(t) 163 | w := dbw.New(conn) 164 | got, err := w.DoTx(testCtx, retryOnFn, 1, dbw.ExpBackoff{}, func(dbw.Reader, dbw.Writer) error { return errors.New("not a retry error") }) 165 | require.Error(err) 166 | assert.Equal(dbw.RetryInfo{}, got) 167 | assert.False(errors.Is(err, retryErr)) 168 | }) 169 | t.Run("too-many-retries", func(t *testing.T) { 170 | assert, require := assert.New(t), require.New(t) 171 | w := dbw.New(conn) 172 | attempts := 0 173 | 
got, err := w.DoTx(testCtx, retryOnFn, 2, dbw.ConstBackoff{}, func(dbw.Reader, dbw.Writer) error { 174 | attempts += 1 175 | return retryErr 176 | }) 177 | require.Error(err) 178 | assert.Equal(3, got.Retries) 179 | assert.Contains(err.Error(), "dbw.DoTx: too many retries: 3 of 3") 180 | }) 181 | t.Run("updating-good-bad-good", func(t *testing.T) { 182 | assert, require := assert.New(t), require.New(t) 183 | rw := dbw.New(conn) 184 | id, err := dbw.NewId("i") 185 | require.NoError(err) 186 | user, err := dbtest.NewTestUser() 187 | require.NoError(err) 188 | user.Name = "foo-" + id 189 | err = rw.Create(context.Background(), user) 190 | require.NoError(err) 191 | 192 | _, err = rw.DoTx(testCtx, retryOnFn, 10, dbw.ExpBackoff{}, func(r dbw.Reader, w dbw.Writer) error { 193 | user.Name = "friendly-" + id 194 | rowsUpdated, err := w.Update(context.Background(), user, []string{"Name"}, nil) 195 | if err != nil { 196 | return err 197 | } 198 | if rowsUpdated != 1 { 199 | return fmt.Errorf("error in number of rows updated %d", rowsUpdated) 200 | } 201 | return nil 202 | }) 203 | require.NoError(err) 204 | 205 | foundUser := dbtest.AllocTestUser() 206 | assert.NoError(err) 207 | foundUser.PublicId = user.PublicId 208 | err = rw.LookupByPublicId(context.Background(), &foundUser) 209 | require.NoError(err) 210 | assert.Equal(foundUser.Name, user.Name) 211 | 212 | user2, err := dbtest.NewTestUser() 213 | require.NoError(err) 214 | _, err = rw.DoTx(testCtx, retryOnFn, 10, dbw.ExpBackoff{}, func(_ dbw.Reader, w dbw.Writer) error { 215 | user2.Name = "friendly2-" + id 216 | rowsUpdated, err := w.Update(context.Background(), user2, []string{"Name"}, nil) 217 | if err != nil { 218 | return err 219 | } 220 | if rowsUpdated != 1 { 221 | return fmt.Errorf("error in number of rows updated %d", rowsUpdated) 222 | } 223 | return nil 224 | }) 225 | require.Error(err) 226 | err = rw.LookupByPublicId(context.Background(), &foundUser) 227 | require.NoError(err) 228 | 
assert.NotEqual(foundUser.Name, user2.Name) 229 | 230 | _, err = rw.DoTx(testCtx, retryOnFn, 10, dbw.ExpBackoff{}, func(r dbw.Reader, w dbw.Writer) error { 231 | user.Name = "friendly2-" + id 232 | rowsUpdated, err := w.Update(context.Background(), user, []string{"Name"}, nil) 233 | if err != nil { 234 | return err 235 | } 236 | if rowsUpdated != 1 { 237 | return fmt.Errorf("error in number of rows updated %d", rowsUpdated) 238 | } 239 | return nil 240 | }) 241 | require.NoError(err) 242 | err = rw.LookupByPublicId(context.Background(), &foundUser) 243 | require.NoError(err) 244 | assert.Equal(foundUser.Name, user.Name) 245 | }) 246 | } 247 | -------------------------------------------------------------------------------- /option_test.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) HashiCorp, Inc. 2 | // SPDX-License-Identifier: MPL-2.0 3 | 4 | package dbw 5 | 6 | import ( 7 | "testing" 8 | 9 | "github.com/hashicorp/go-hclog" 10 | "github.com/stretchr/testify/assert" 11 | ) 12 | 13 | // Test_getOpts provides unit tests for GetOpts and all the options 14 | func Test_getOpts(t *testing.T) { 15 | t.Parallel() 16 | t.Run("WithLookup", func(t *testing.T) { 17 | assert := assert.New(t) 18 | // test default of true 19 | opts := GetOpts() 20 | testOpts := getDefaultOptions() 21 | testOpts.WithLookup = false 22 | assert.Equal(opts, testOpts) 23 | 24 | // try setting to false 25 | opts = GetOpts(WithLookup(true)) 26 | testOpts = getDefaultOptions() 27 | testOpts.WithLookup = true 28 | assert.Equal(opts, testOpts) 29 | }) 30 | t.Run("WithFieldMaskPaths", func(t *testing.T) { 31 | assert := assert.New(t) 32 | // test default of []string{} 33 | opts := GetOpts() 34 | testOpts := getDefaultOptions() 35 | testOpts.WithFieldMaskPaths = []string{} 36 | assert.Equal(opts, testOpts) 37 | 38 | testPaths := []string{"alice", "bob"} 39 | opts = GetOpts(WithFieldMaskPaths(testPaths)) 40 | testOpts = getDefaultOptions() 41 | 
testOpts.WithFieldMaskPaths = testPaths 42 | assert.Equal(opts, testOpts) 43 | }) 44 | t.Run("WithNullPaths", func(t *testing.T) { 45 | assert := assert.New(t) 46 | // test default of []string{} 47 | opts := GetOpts() 48 | testOpts := getDefaultOptions() 49 | testOpts.WithNullPaths = []string{} 50 | assert.Equal(opts, testOpts) 51 | 52 | testPaths := []string{"alice", "bob"} 53 | opts = GetOpts(WithNullPaths(testPaths)) 54 | testOpts = getDefaultOptions() 55 | testOpts.WithNullPaths = testPaths 56 | assert.Equal(opts, testOpts) 57 | }) 58 | t.Run("WithLimit", func(t *testing.T) { 59 | assert := assert.New(t) 60 | // test default of 0 61 | opts := GetOpts() 62 | testOpts := getDefaultOptions() 63 | testOpts.WithLimit = 0 64 | assert.Equal(opts, testOpts) 65 | 66 | opts = GetOpts(WithLimit(-1)) 67 | testOpts = getDefaultOptions() 68 | testOpts.WithLimit = -1 69 | assert.Equal(opts, testOpts) 70 | 71 | opts = GetOpts(WithLimit(1)) 72 | testOpts = getDefaultOptions() 73 | testOpts.WithLimit = 1 74 | assert.Equal(opts, testOpts) 75 | }) 76 | t.Run("WithVersion", func(t *testing.T) { 77 | assert := assert.New(t) 78 | // test default of 0 79 | opts := GetOpts() 80 | testOpts := getDefaultOptions() 81 | testOpts.WithVersion = nil 82 | assert.Equal(opts, testOpts) 83 | versionTwo := uint32(2) 84 | opts = GetOpts(WithVersion(&versionTwo)) 85 | testOpts = getDefaultOptions() 86 | testOpts.WithVersion = &versionTwo 87 | assert.Equal(opts, testOpts) 88 | }) 89 | t.Run("WithSkipVetForWrite", func(t *testing.T) { 90 | assert := assert.New(t) 91 | // test default of false 92 | opts := GetOpts() 93 | testOpts := getDefaultOptions() 94 | testOpts.WithSkipVetForWrite = false 95 | assert.Equal(opts, testOpts) 96 | opts = GetOpts(WithSkipVetForWrite(true)) 97 | testOpts = getDefaultOptions() 98 | testOpts.WithSkipVetForWrite = true 99 | assert.Equal(opts, testOpts) 100 | }) 101 | t.Run("WithWhere", func(t *testing.T) { 102 | assert := assert.New(t) 103 | // test default of false 104 | 
opts := GetOpts() 105 | testOpts := getDefaultOptions() 106 | testOpts.WithWhereClause = "" 107 | testOpts.WithWhereClauseArgs = nil 108 | assert.Equal(opts, testOpts) 109 | opts = GetOpts(WithWhere("id = ? and foo = ?", 1234, "bar")) 110 | testOpts.WithWhereClause = "id = ? and foo = ?" 111 | testOpts.WithWhereClauseArgs = []interface{}{1234, "bar"} 112 | assert.Equal(opts, testOpts) 113 | }) 114 | t.Run("WithOrder", func(t *testing.T) { 115 | assert := assert.New(t) 116 | // test default of false 117 | opts := GetOpts() 118 | testOpts := getDefaultOptions() 119 | testOpts.WithOrder = "" 120 | assert.Equal(opts, testOpts) 121 | opts = GetOpts(WithOrder("version desc")) 122 | testOpts.WithOrder = "version desc" 123 | assert.Equal(opts, testOpts) 124 | }) 125 | t.Run("WithGormFormatter", func(t *testing.T) { 126 | assert := assert.New(t) 127 | // test default of false 128 | opts := GetOpts() 129 | testOpts := getDefaultOptions() 130 | assert.Equal(opts, testOpts) 131 | 132 | testLogger := hclog.New(&hclog.LoggerOptions{}) 133 | opts = GetOpts(WithLogger(testLogger)) 134 | testOpts.WithLogger = testLogger 135 | assert.Equal(opts, testOpts) 136 | }) 137 | t.Run("WithMaxOpenConnections", func(t *testing.T) { 138 | assert := assert.New(t) 139 | // test default of false 140 | opts := GetOpts() 141 | testOpts := getDefaultOptions() 142 | assert.Equal(opts, testOpts) 143 | opts = GetOpts(WithMaxOpenConnections(22)) 144 | testOpts.WithMaxOpenConnections = 22 145 | assert.Equal(opts, testOpts) 146 | }) 147 | t.Run("WithDebug", func(t *testing.T) { 148 | assert := assert.New(t) 149 | // test default of false 150 | opts := GetOpts() 151 | testOpts := getDefaultOptions() 152 | assert.Equal(opts, testOpts) 153 | // try setting to true 154 | opts = GetOpts(WithDebug(true)) 155 | testOpts.WithDebug = true 156 | assert.Equal(opts, testOpts) 157 | }) 158 | t.Run("WithOnConflict", func(t *testing.T) { 159 | assert := assert.New(t) 160 | // test default of false 161 | opts := 
GetOpts() 162 | testOpts := getDefaultOptions() 163 | assert.Equal(opts, testOpts) 164 | columns := SetColumns([]string{"name", "description"}) 165 | columnValues := SetColumnValues(map[string]interface{}{"expiration": "NULL"}) 166 | testOnConflict := OnConflict{ 167 | Target: Constraint("uniq-name"), 168 | Action: append(columns, columnValues...), 169 | } 170 | opts = GetOpts(WithOnConflict(&testOnConflict)) 171 | testOpts.WithOnConflict = &testOnConflict 172 | assert.Equal(opts, testOpts) 173 | }) 174 | t.Run("WithReturnRowsAffected", func(t *testing.T) { 175 | assert := assert.New(t) 176 | // test default of false 177 | opts := GetOpts() 178 | testOpts := getDefaultOptions() 179 | assert.Equal(opts, testOpts) 180 | 181 | var rowsAffected int64 182 | opts = GetOpts(WithReturnRowsAffected(&rowsAffected)) 183 | testOpts.WithRowsAffected = &rowsAffected 184 | assert.Equal(opts, testOpts) 185 | }) 186 | t.Run("WithBeforeWrite", func(t *testing.T) { 187 | assert := assert.New(t) 188 | // test defaults 189 | opts := GetOpts() 190 | assert.Nil(opts.WithBeforeWrite) 191 | 192 | fn := func(interface{}) error { return nil } 193 | opts = GetOpts(WithBeforeWrite(fn)) 194 | assert.NotNil(opts.WithBeforeWrite) 195 | }) 196 | t.Run("WithAfterWrite", func(t *testing.T) { 197 | assert := assert.New(t) 198 | // test defaults 199 | opts := GetOpts() 200 | assert.Nil(opts.WithAfterWrite) 201 | 202 | fn := func(interface{}, int) error { return nil } 203 | opts = GetOpts(WithAfterWrite(fn)) 204 | assert.NotNil(opts.WithAfterWrite) 205 | }) 206 | t.Run("WithMaxOpenConnections", func(t *testing.T) { 207 | assert := assert.New(t) 208 | // test default of 0 209 | opts := GetOpts() 210 | testOpts := getDefaultOptions() 211 | testOpts.WithMaxOpenConnections = 0 212 | assert.Equal(opts, testOpts) 213 | opts = GetOpts(WithMaxOpenConnections(1)) 214 | testOpts = getDefaultOptions() 215 | testOpts.WithMaxOpenConnections = 1 216 | assert.Equal(opts, testOpts) 217 | }) 218 | 
t.Run("WithMinOpenConnections", func(t *testing.T) { 219 | assert := assert.New(t) 220 | // test default of 0 221 | opts := GetOpts() 222 | testOpts := getDefaultOptions() 223 | testOpts.WithMinOpenConnections = 0 224 | assert.Equal(opts, testOpts) 225 | opts = GetOpts(WithMinOpenConnections(1)) 226 | testOpts = getDefaultOptions() 227 | testOpts.WithMinOpenConnections = 1 228 | assert.Equal(opts, testOpts) 229 | }) 230 | t.Run("WithTable", func(t *testing.T) { 231 | assert := assert.New(t) 232 | // test default 233 | opts := GetOpts() 234 | testOpts := getDefaultOptions() 235 | testOpts.WithTable = "" 236 | assert.Equal(opts, testOpts) 237 | 238 | opts = GetOpts(WithTable("tmp_table_name")) 239 | testOpts = getDefaultOptions() 240 | testOpts.WithTable = "tmp_table_name" 241 | assert.Equal(opts, testOpts) 242 | }) 243 | t.Run("WithLogLevel", func(t *testing.T) { 244 | assert := assert.New(t) 245 | // test default 246 | opts := GetOpts() 247 | testOpts := getDefaultOptions() 248 | testOpts.withLogLevel = Error 249 | assert.Equal(opts, testOpts) 250 | 251 | opts = GetOpts(WithLogLevel(Warn)) 252 | testOpts = getDefaultOptions() 253 | testOpts.withLogLevel = Warn 254 | assert.Equal(opts, testOpts) 255 | }) 256 | t.Run("WithBatchSize", func(t *testing.T) { 257 | assert := assert.New(t) 258 | // test default 259 | opts := GetOpts() 260 | testOpts := getDefaultOptions() 261 | testOpts.WithBatchSize = DefaultBatchSize 262 | assert.Equal(opts, testOpts) 263 | 264 | opts = GetOpts(WithBatchSize(100)) 265 | testOpts = getDefaultOptions() 266 | testOpts.WithBatchSize = 100 267 | assert.Equal(opts, testOpts) 268 | }) 269 | } 270 | -------------------------------------------------------------------------------- /db.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) HashiCorp, Inc. 
2 | // SPDX-License-Identifier: MPL-2.0 3 | 4 | package dbw 5 | 6 | import ( 7 | "context" 8 | "database/sql" 9 | "fmt" 10 | "strings" 11 | 12 | "github.com/hashicorp/go-hclog" 13 | "github.com/jackc/pgconn" 14 | 15 | _ "github.com/jackc/pgx/v5" // required to load postgres drivers 16 | "gorm.io/driver/postgres" 17 | "gorm.io/driver/sqlite" 18 | 19 | "gorm.io/gorm" 20 | "gorm.io/gorm/logger" 21 | ) 22 | 23 | // DbType defines a database type. It's not an exhaustive list of database 24 | // types which can be used by the dbw package, since you can always use 25 | // OpenWith(...) to connect to KnownDB types. 26 | type DbType int 27 | 28 | const ( 29 | // UnknownDB is an unknown db type 30 | UnknownDB DbType = 0 31 | 32 | // Postgres is a postgre db type 33 | Postgres DbType = 1 34 | 35 | // Sqlite is a sqlite db type 36 | Sqlite DbType = 2 37 | ) 38 | 39 | // String provides a string rep of the DbType. 40 | func (db DbType) String() string { 41 | return [...]string{ 42 | "unknown", 43 | "postgres", 44 | "sqlite", 45 | }[db] 46 | } 47 | 48 | // StringToDbType provides a string to type conversion. If the type is known, 49 | // then UnknownDB with and error is returned. 50 | func StringToDbType(dialect string) (DbType, error) { 51 | switch dialect { 52 | case "postgres": 53 | return Postgres, nil 54 | case "sqlite": 55 | return Sqlite, nil 56 | default: 57 | return UnknownDB, fmt.Errorf("%s is an unknown dialect", dialect) 58 | } 59 | } 60 | 61 | // DB is a wrapper around whatever is providing the interface for database 62 | // operations (typically an ORM). DB uses database/sql to maintain connection 63 | // pool. 
64 | type DB struct { 65 | wrapped *gorm.DB 66 | } 67 | 68 | // DbType will return the DbType and raw name of the connection type 69 | func (db *DB) DbType() (typ DbType, rawName string, e error) { 70 | rawName = db.wrapped.Dialector.Name() 71 | typ, _ = StringToDbType(rawName) 72 | return typ, rawName, nil 73 | } 74 | 75 | // Debug will enable/disable debug info for the connection 76 | func (db *DB) Debug(on bool) { 77 | if on { 78 | // info level in the Gorm domain which maps to a debug level in this domain 79 | db.LogLevel(Info) 80 | } else { 81 | // the default level in the gorm domain is: error level 82 | db.LogLevel(Error) 83 | } 84 | } 85 | 86 | // LogLevel defines a log level 87 | type LogLevel int 88 | 89 | const ( 90 | // Default specifies the default log level 91 | Default LogLevel = iota 92 | 93 | // Silent is the silent log level 94 | Silent 95 | 96 | // Error is the error log level 97 | Error 98 | 99 | // Warn is the warning log level 100 | Warn 101 | 102 | // Info is the info log level 103 | Info 104 | ) 105 | 106 | // LogLevel will set the logging level for the db 107 | func (db *DB) LogLevel(l LogLevel) { 108 | db.wrapped.Logger = db.wrapped.Logger.LogMode(logger.LogLevel(l)) 109 | } 110 | 111 | // SqlDB returns the underlying sql.DB Note: this makes it possible to do 112 | // things like set database/sql connection options like SetMaxIdleConns. If 113 | // you're simply setting max/min connections then you should use the 114 | // WithMinOpenConnections and WithMaxOpenConnections options when 115 | // "opening" the database. 116 | // 117 | // Care should be take when deciding to use this for basic database operations 118 | // like Exec, Query, etc since these functions are already provided by dbw.RW 119 | // which provides a layer of encapsulation of the underlying database. 
120 | func (db *DB) SqlDB(_ context.Context) (*sql.DB, error) { 121 | const op = "dbw.(DB).SqlDB" 122 | if db.wrapped == nil { 123 | return nil, fmt.Errorf("%s: missing underlying database: %w", op, ErrInternal) 124 | } 125 | return db.wrapped.DB() 126 | } 127 | 128 | // Close the database 129 | // 130 | // Note: Consider if you need to call Close() on the returned DB. Typically the 131 | // answer is no, but there are occasions when it's necessary. See the sql.DB 132 | // docs for more information. 133 | func (db *DB) Close(ctx context.Context) error { 134 | const op = "dbw.(DB).Close" 135 | if db.wrapped == nil { 136 | return fmt.Errorf("%s: missing underlying database: %w", op, ErrInternal) 137 | } 138 | underlying, err := db.wrapped.DB() 139 | if err != nil { 140 | return fmt.Errorf("%s: %w", op, err) 141 | } 142 | return underlying.Close() 143 | } 144 | 145 | // Open a database connection which is long-lived. The options of 146 | // WithLogger, WithLogLevel and WithMaxOpenConnections are supported. 147 | // 148 | // Note: Consider if you need to call Close() on the returned DB. Typically the 149 | // answer is no, but there are occasions when it's necessary. See the sql.DB 150 | // docs for more information. 151 | func Open(dbType DbType, connectionUrl string, opt ...Option) (*DB, error) { 152 | const op = "dbw.Open" 153 | if connectionUrl == "" { 154 | return nil, fmt.Errorf("%s: missing connection url: %w", op, ErrInvalidParameter) 155 | } 156 | var dialect gorm.Dialector 157 | switch dbType { 158 | case Postgres: 159 | dialect = postgres.New(postgres.Config{ 160 | DSN: connectionUrl, 161 | }, 162 | ) 163 | case Sqlite: 164 | dialect = sqlite.Open(connectionUrl) 165 | 166 | default: 167 | return nil, fmt.Errorf("unable to open %s database type", dbType) 168 | } 169 | db, err := openDialector(dialect, opt...) 
170 | if err != nil { 171 | return nil, fmt.Errorf("%s: %w", op, err) 172 | } 173 | if dbType == Sqlite { 174 | if _, err := New(db).Exec(context.Background(), "PRAGMA foreign_keys=ON", nil); err != nil { 175 | return nil, fmt.Errorf("%s: unable to enable sqlite foreign keys: %w", op, err) 176 | } 177 | } 178 | return db, nil 179 | } 180 | 181 | // Dialector provides a set of functions the database dialect must satisfy to 182 | // be used with OpenWith(...) 183 | // It's a simple wrapper of the gorm.Dialector and provides the ability to open 184 | // any support gorm dialect driver. 185 | type Dialector interface { 186 | gorm.Dialector 187 | } 188 | 189 | // OpenWith will open a database connection using a Dialector which is 190 | // long-lived. The options of WithLogger, WithLogLevel and 191 | // WithMaxOpenConnections are supported. 192 | // 193 | // Note: Consider if you need to call Close() on the returned DB. Typically the 194 | // answer is no, but there are occasions when it's necessary. See the sql.DB 195 | // docs for more information. 196 | func OpenWith(dialector Dialector, opt ...Option) (*DB, error) { 197 | return openDialector(dialector, opt...) 198 | } 199 | 200 | func openDialector(dialect gorm.Dialector, opt ...Option) (*DB, error) { 201 | db, err := gorm.Open(dialect, &gorm.Config{}) 202 | if err != nil { 203 | return nil, fmt.Errorf("unable to open database: %w", err) 204 | } 205 | if strings.ToLower(dialect.Name()) == "sqlite" { 206 | if err := db.Exec("PRAGMA foreign_keys=ON", nil).Error; err != nil { 207 | return nil, fmt.Errorf("unable to enable sqlite foreign keys: %w", err) 208 | } 209 | } 210 | opts := GetOpts(opt...) 
211 | if opts.WithLogger != nil { 212 | var newLogger logger.Interface 213 | loggerConfig := logger.Config{ 214 | LogLevel: logger.LogLevel(opts.withLogLevel), // Log level 215 | Colorful: false, // Disable color 216 | } 217 | switch v := opts.WithLogger.(type) { 218 | case LogWriter: 219 | // it's already a gorm logger, so we just need to configure it 220 | newLogger = logger.New(v, loggerConfig) 221 | default: 222 | newLogger = logger.New( 223 | getGormLogger(opts.WithLogger), // wrap the hclog with a gorm logger that only logs errors 224 | loggerConfig, 225 | ) 226 | } 227 | db = db.Session(&gorm.Session{Logger: newLogger}) 228 | } 229 | if opts.WithMaxOpenConnections > 0 { 230 | if opts.WithMinOpenConnections > 0 && (opts.WithMaxOpenConnections < opts.WithMinOpenConnections) { 231 | return nil, fmt.Errorf("unable to create db object with dialect %s: %s", dialect, fmt.Sprintf("max_open_connections must be unlimited by setting 0 or at least %d", opts.WithMinOpenConnections)) 232 | } 233 | underlyingDB, err := db.DB() 234 | if err != nil { 235 | return nil, fmt.Errorf("unable retrieve db: %w", err) 236 | } 237 | underlyingDB.SetMaxOpenConns(opts.WithMaxOpenConnections) 238 | } 239 | 240 | ret := &DB{wrapped: db} 241 | ret.Debug(opts.WithDebug) 242 | return ret, nil 243 | } 244 | 245 | // LogWriter defines an interface which can be used when passing a logger via 246 | // WithLogger(...). 
This interface allows callers to override the default 247 | // behavior for a logger (the default only emits postgres errors) 248 | type LogWriter interface { 249 | Printf(string, ...any) 250 | } 251 | 252 | type gormLogger struct { 253 | logger hclog.Logger 254 | } 255 | 256 | func (g gormLogger) Printf(_ string, values ...interface{}) { 257 | if len(values) > 1 { 258 | switch values[1].(type) { 259 | case *pgconn.PgError: 260 | g.logger.Trace("error from database adapter", "location", values[0], "error", values[1]) 261 | } 262 | } 263 | } 264 | 265 | func getGormLogger(log hclog.Logger) gormLogger { 266 | return gormLogger{logger: log} 267 | } 268 | -------------------------------------------------------------------------------- /option.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) HashiCorp, Inc. 2 | // SPDX-License-Identifier: MPL-2.0 3 | 4 | package dbw 5 | 6 | import ( 7 | "github.com/hashicorp/go-hclog" 8 | ) 9 | 10 | // GetOpts - iterate the inbound Options and return a struct. 11 | func GetOpts(opt ...Option) Options { 12 | opts := getDefaultOptions() 13 | for _, o := range opt { 14 | if o != nil { 15 | o(&opts) 16 | } 17 | } 18 | return opts 19 | } 20 | 21 | // Option - how Options are passed as arguments. 22 | type Option func(*Options) 23 | 24 | // Options - how Options are represented which have been set via an Option 25 | // function. Use GetOpts(...) to populated this struct with the options that 26 | // have been specified for an operation. All option fields are exported so 27 | // they're available for use by other packages. 28 | type Options struct { 29 | // WithBeforeWrite provides and option to provide a func to be called before a 30 | // write operation. The i interface{} passed at runtime will be the resource(s) 31 | // being written. 
32 | WithBeforeWrite func(i interface{}) error 33 | 34 | // WithAfterWrite provides and option to provide a func to be called after a 35 | // write operation. The i interface{} passed at runtime will be the resource(s) 36 | // being written. 37 | WithAfterWrite func(i interface{}, rowsAffected int) error 38 | 39 | // WithLookup enables a lookup after a write operation. 40 | WithLookup bool 41 | 42 | // WithLimit provides an option to provide a limit. Intentionally allowing 43 | // negative integers. If WithLimit < 0, then unlimited results are returned. 44 | // If WithLimit == 0, then default limits are used for results (see DefaultLimit 45 | // const). 46 | WithLimit int 47 | 48 | // WithFieldMaskPaths provides an option to provide field mask paths for update 49 | // operations. 50 | WithFieldMaskPaths []string 51 | 52 | // WithNullPaths provides an option to provide null paths for update 53 | // operations. 54 | WithNullPaths []string 55 | 56 | // WithVersion provides an option version number for update operations. Using 57 | // this option requires that your resource has a version column that's 58 | // incremented for every successful update operation. Version provides an 59 | // optimistic locking mechanism for write operations. 60 | WithVersion *uint32 61 | 62 | WithSkipVetForWrite bool 63 | 64 | // WithWhereClause provides an option to provide a where clause for an 65 | // operation. 66 | WithWhereClause string 67 | 68 | // WithWhereClauseArgs provides an option to provide a where clause arguments for an 69 | // operation. 70 | WithWhereClauseArgs []interface{} 71 | 72 | // WithOrder provides an option to provide an order when searching and looking 73 | // up. 74 | WithOrder string 75 | 76 | // WithPrngValues provides an option to provide values to seed an PRNG when generating IDs 77 | WithPrngValues []string 78 | 79 | // WithLogger specifies an optional hclog to use for db operations. It's only 80 | // valid for Open(..) and OpenWith(...) 
The logger provided can optionally 81 | // implement the LogWriter interface as well which would override the default 82 | // behavior for a logger (the default only emits postgres errors) 83 | WithLogger hclog.Logger 84 | 85 | // WithMinOpenConnections specifies and optional min open connections for the 86 | // database. A value of zero means that there is no min. 87 | WithMaxOpenConnections int 88 | 89 | // WithMaxOpenConnections specifies and optional max open connections for the 90 | // database. A value of zero equals unlimited connections 91 | WithMinOpenConnections int 92 | 93 | // WithDebug indicates that the given operation should invoke debug output 94 | // mode 95 | WithDebug bool 96 | 97 | // WithOnConflict specifies an optional on conflict criteria which specify 98 | // alternative actions to take when an insert results in a unique constraint or 99 | // exclusion constraint error 100 | WithOnConflict *OnConflict 101 | 102 | // WithRowsAffected specifies an option for returning the rows affected 103 | // and typically used with "bulk" write operations. 104 | WithRowsAffected *int64 105 | 106 | // WithTable specifies an option for setting a table name to use for the 107 | // operation. 108 | WithTable string 109 | 110 | // WithBatchSize specifies an option for setting the batch size for bulk 111 | // operations. If WithBatchSize == 0, then the default batch size is used. 112 | WithBatchSize int 113 | 114 | withLogLevel LogLevel 115 | } 116 | 117 | func getDefaultOptions() Options { 118 | return Options{ 119 | WithFieldMaskPaths: []string{}, 120 | WithNullPaths: []string{}, 121 | WithBatchSize: DefaultBatchSize, 122 | withLogLevel: Error, 123 | } 124 | } 125 | 126 | // WithBeforeWrite provides and option to provide a func to be called before a 127 | // write operation. The i interface{} passed at runtime will be the resource(s) 128 | // being written. 
129 | func WithBeforeWrite(fn func(i interface{}) error) Option { 130 | return func(o *Options) { 131 | o.WithBeforeWrite = fn 132 | } 133 | } 134 | 135 | // WithAfterWrite provides and option to provide a func to be called after a 136 | // write operation. The i interface{} passed at runtime will be the resource(s) 137 | // being written. 138 | func WithAfterWrite(fn func(i interface{}, rowsAffected int) error) Option { 139 | return func(o *Options) { 140 | o.WithAfterWrite = fn 141 | } 142 | } 143 | 144 | // WithLookup enables a lookup after a write operation. 145 | func WithLookup(enable bool) Option { 146 | return func(o *Options) { 147 | o.WithLookup = enable 148 | } 149 | } 150 | 151 | // WithFieldMaskPaths provides an option to provide field mask paths for update 152 | // operations. 153 | func WithFieldMaskPaths(paths []string) Option { 154 | return func(o *Options) { 155 | o.WithFieldMaskPaths = paths 156 | } 157 | } 158 | 159 | // WithNullPaths provides an option to provide null paths for update operations. 160 | func WithNullPaths(paths []string) Option { 161 | return func(o *Options) { 162 | o.WithNullPaths = paths 163 | } 164 | } 165 | 166 | // WithLimit provides an option to provide a limit. Intentionally allowing 167 | // negative integers. If WithLimit < 0, then unlimited results are returned. 168 | // If WithLimit == 0, then default limits are used for results (see DefaultLimit 169 | // const). 170 | func WithLimit(limit int) Option { 171 | return func(o *Options) { 172 | o.WithLimit = limit 173 | } 174 | } 175 | 176 | // WithVersion provides an option version number for update operations. Using 177 | // this option requires that your resource has a version column that's 178 | // incremented for every successful update operation. Version provides an 179 | // optimistic locking mechanism for write operations. 
180 | func WithVersion(version *uint32) Option { 181 | return func(o *Options) { 182 | o.WithVersion = version 183 | } 184 | } 185 | 186 | // WithSkipVetForWrite provides an option to allow skipping vet checks to allow 187 | // testing lower-level SQL triggers and constraints 188 | func WithSkipVetForWrite(enable bool) Option { 189 | return func(o *Options) { 190 | o.WithSkipVetForWrite = enable 191 | } 192 | } 193 | 194 | // WithWhere provides an option to provide a where clause with arguments for an 195 | // operation. 196 | func WithWhere(whereClause string, args ...interface{}) Option { 197 | return func(o *Options) { 198 | o.WithWhereClause = whereClause 199 | o.WithWhereClauseArgs = append(o.WithWhereClauseArgs, args...) 200 | } 201 | } 202 | 203 | // WithOrder provides an option to provide an order when searching and looking 204 | // up. 205 | func WithOrder(withOrder string) Option { 206 | return func(o *Options) { 207 | o.WithOrder = withOrder 208 | } 209 | } 210 | 211 | // WithPrngValues provides an option to provide values to seed an PRNG when generating IDs 212 | func WithPrngValues(withPrngValues []string) Option { 213 | return func(o *Options) { 214 | o.WithPrngValues = withPrngValues 215 | } 216 | } 217 | 218 | // WithLogger specifies an optional hclog to use for db operations. It's only 219 | // valid for Open(..) and OpenWith(...). The logger provided can optionally 220 | // implement the LogWriter interface as well which would override the default 221 | // behavior for a logger (the default only emits postgres errors) 222 | func WithLogger(l hclog.Logger) Option { 223 | return func(o *Options) { 224 | o.WithLogger = l 225 | } 226 | } 227 | 228 | // WithMaxOpenConnections specifies and optional max open connections for the 229 | // database. 
A value of zero equals unlimited connections 230 | func WithMaxOpenConnections(max int) Option { 231 | return func(o *Options) { 232 | o.WithMaxOpenConnections = max 233 | } 234 | } 235 | 236 | // WithMinOpenConnections specifies and optional min open connections for the 237 | // database. A value of zero means that there is no min. 238 | func WithMinOpenConnections(max int) Option { 239 | return func(o *Options) { 240 | o.WithMinOpenConnections = max 241 | } 242 | } 243 | 244 | // WithDebug specifies the given operation should invoke debug mode for the 245 | // database output 246 | func WithDebug(with bool) Option { 247 | return func(o *Options) { 248 | o.WithDebug = with 249 | } 250 | } 251 | 252 | // WithOnConflict specifies an optional on conflict criteria which specify 253 | // alternative actions to take when an insert results in a unique constraint or 254 | // exclusion constraint error 255 | func WithOnConflict(onConflict *OnConflict) Option { 256 | return func(o *Options) { 257 | o.WithOnConflict = onConflict 258 | } 259 | } 260 | 261 | // WithReturnRowsAffected specifies an option for returning the rows affected 262 | // and typically used with "bulk" write operations. 263 | func WithReturnRowsAffected(rowsAffected *int64) Option { 264 | return func(o *Options) { 265 | o.WithRowsAffected = rowsAffected 266 | } 267 | } 268 | 269 | // WithTable specifies an option for setting a table name to use for the 270 | // operation. 271 | func WithTable(name string) Option { 272 | return func(o *Options) { 273 | o.WithTable = name 274 | } 275 | } 276 | 277 | // WithLogLevel specifies an option for setting the log level 278 | func WithLogLevel(l LogLevel) Option { 279 | return func(o *Options) { 280 | o.withLogLevel = l 281 | } 282 | } 283 | 284 | // WithBatchSize specifies an option for setting the batch size for bulk 285 | // operations like CreateItems. If WithBatchSize == 0, the default batch size is 286 | // used (see DefaultBatchSize const). 
287 | func WithBatchSize(size int) Option { 288 | return func(o *Options) { 289 | o.WithBatchSize = size 290 | } 291 | } 292 | -------------------------------------------------------------------------------- /rw.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) HashiCorp, Inc. 2 | // SPDX-License-Identifier: MPL-2.0 3 | 4 | package dbw 5 | 6 | import ( 7 | "context" 8 | "fmt" 9 | "reflect" 10 | "strings" 11 | 12 | "gorm.io/gorm" 13 | "gorm.io/gorm/callbacks" 14 | ) 15 | 16 | const ( 17 | noRowsAffected = 0 18 | 19 | // DefaultLimit is the default for search results when no limit is specified 20 | // via the WithLimit(...) option 21 | DefaultLimit = 10000 22 | ) 23 | 24 | // RW uses a DB as a connection for it's read/write operations. This is 25 | // basically the primary type for the package's operations. 26 | type RW struct { 27 | underlying *DB 28 | } 29 | 30 | // ensure that RW implements the interfaces of: Reader and Writer 31 | var ( 32 | _ Reader = (*RW)(nil) 33 | _ Writer = (*RW)(nil) 34 | ) 35 | 36 | // New creates a new RW using an open DB. Note: there can by many RWs that share 37 | // the same DB, since the DB manages the connection pool. 38 | func New(underlying *DB) *RW { 39 | return &RW{underlying: underlying} 40 | } 41 | 42 | // DB returns the underlying DB 43 | func (rw *RW) DB() *DB { 44 | return rw.underlying 45 | } 46 | 47 | // Exec will execute the sql with the values as parameters. The int returned 48 | // is the number of rows affected by the sql. The WithDebug option is supported. 49 | func (rw *RW) Exec(ctx context.Context, sql string, values []interface{}, opt ...Option) (int, error) { 50 | const op = "dbw.Exec" 51 | if rw.underlying == nil { 52 | return 0, fmt.Errorf("%s: missing underlying db: %w", op, ErrInternal) 53 | } 54 | if sql == "" { 55 | return noRowsAffected, fmt.Errorf("%s: missing sql: %w", op, ErrInvalidParameter) 56 | } 57 | opts := GetOpts(opt...) 
58 | db := rw.underlying.wrapped.WithContext(ctx) 59 | if opts.WithDebug { 60 | db = db.Debug() 61 | } 62 | db = db.Exec(sql, values...) 63 | if db.Error != nil { 64 | return noRowsAffected, fmt.Errorf("%s: %w", op, db.Error) 65 | } 66 | return int(db.RowsAffected), nil 67 | } 68 | 69 | func (rw *RW) primaryFieldsAreZero(ctx context.Context, i interface{}) ([]string, bool, error) { 70 | const op = "dbw.primaryFieldsAreZero" 71 | var fieldNames []string 72 | tx := rw.underlying.wrapped.Model(i) 73 | if err := tx.Statement.Parse(i); err != nil { 74 | return nil, false, fmt.Errorf("%s: %w", op, ErrInvalidParameter) 75 | } 76 | for _, f := range tx.Statement.Schema.PrimaryFields { 77 | if f.PrimaryKey { 78 | if _, isZero := f.ValueOf(ctx, reflect.ValueOf(i)); isZero { 79 | fieldNames = append(fieldNames, f.Name) 80 | } 81 | } 82 | } 83 | return fieldNames, len(fieldNames) > 0, nil 84 | } 85 | 86 | func isNil(i interface{}) bool { 87 | if i == nil { 88 | return true 89 | } 90 | switch reflect.TypeOf(i).Kind() { 91 | case reflect.Ptr, reflect.Map, reflect.Chan, reflect.Slice: 92 | return reflect.ValueOf(i).IsNil() 93 | } 94 | return false 95 | } 96 | 97 | func contains(ss []string, t string) bool { 98 | for _, s := range ss { 99 | if strings.EqualFold(s, t) { 100 | return true 101 | } 102 | } 103 | return false 104 | } 105 | 106 | func validateResourcesInterface(resources interface{}) error { 107 | const op = "dbw.validateResourcesInterface" 108 | vo := reflect.ValueOf(resources) 109 | if vo.Kind() != reflect.Ptr { 110 | return fmt.Errorf("%s: interface parameter must to be a pointer: %w", op, ErrInvalidParameter) 111 | } 112 | e := vo.Elem() 113 | if e.Kind() == reflect.Slice { 114 | if e.Type().Elem().Kind() != reflect.Ptr { 115 | return fmt.Errorf("%s: interface parameter is a slice, but the elements of the slice are not pointers: %w", op, ErrInvalidParameter) 116 | } 117 | } 118 | return nil 119 | } 120 | 121 | func raiseErrorOnHooks(i interface{}) error { 122 | 
const op = "dbw.raiseErrorOnHooks" 123 | v := i 124 | valOf := reflect.ValueOf(i) 125 | if valOf.Kind() == reflect.Slice { 126 | if valOf.Len() == 0 { 127 | return nil 128 | } 129 | v = valOf.Index(0).Interface() 130 | } 131 | 132 | switch v.(type) { 133 | case 134 | // create hooks 135 | callbacks.BeforeCreateInterface, 136 | callbacks.AfterCreateInterface, 137 | callbacks.BeforeSaveInterface, 138 | callbacks.AfterSaveInterface, 139 | 140 | // update hooks 141 | callbacks.BeforeUpdateInterface, 142 | callbacks.AfterUpdateInterface, 143 | 144 | // delete hooks 145 | callbacks.BeforeDeleteInterface, 146 | callbacks.AfterDeleteInterface, 147 | 148 | // find hooks 149 | callbacks.AfterFindInterface: 150 | 151 | return fmt.Errorf("%s: gorm callback/hooks are not supported: %w", op, ErrInvalidParameter) 152 | } 153 | return nil 154 | } 155 | 156 | // IsTx returns true if there's an existing transaction in progress 157 | func (rw *RW) IsTx() bool { 158 | if rw.underlying == nil { 159 | return false 160 | } 161 | switch rw.underlying.wrapped.Statement.ConnPool.(type) { 162 | case gorm.TxBeginner, gorm.ConnPoolBeginner: 163 | return false 164 | default: 165 | return true 166 | } 167 | } 168 | 169 | func (rw *RW) whereClausesFromOpts(_ context.Context, i interface{}, opts Options) (string, []interface{}, error) { 170 | const op = "dbw.whereClausesFromOpts" 171 | var where []string 172 | var args []interface{} 173 | if opts.WithVersion != nil { 174 | if *opts.WithVersion == 0 { 175 | return "", nil, fmt.Errorf("%s: with version option is zero: %w", op, ErrInvalidParameter) 176 | } 177 | mDb := rw.underlying.wrapped.Model(i) 178 | err := mDb.Statement.Parse(i) 179 | if err != nil && mDb.Statement.Schema == nil { 180 | return "", nil, fmt.Errorf("%s: (internal error) unable to parse stmt: %w", op, ErrUnknown) 181 | } 182 | if !contains(mDb.Statement.Schema.DBNames, "version") { 183 | return "", nil, fmt.Errorf("%s: %s does not have a version field: %w", op, 
mDb.Statement.Schema.Table, ErrInvalidParameter) 184 | } 185 | if opts.WithOnConflict != nil { 186 | // on conflict clauses requires the version to be qualified with a 187 | // table name 188 | var tableName string 189 | switch { 190 | case opts.WithTable != "": 191 | tableName = opts.WithTable 192 | default: 193 | tableName = mDb.Statement.Schema.Table 194 | } 195 | where = append(where, fmt.Sprintf("%s.version = ?", tableName)) // we need to include the table name because of "on conflict" use cases 196 | } else { 197 | where = append(where, "version = ?") 198 | } 199 | args = append(args, opts.WithVersion) 200 | } 201 | if opts.WithWhereClause != "" { 202 | where, args = append(where, opts.WithWhereClause), append(args, opts.WithWhereClauseArgs...) 203 | } 204 | return strings.Join(where, " and "), args, nil 205 | } 206 | 207 | // clearDefaultNullResourceFields will clear fields in the resource which are 208 | // defaulted to a null value. This addresses the unfixed issue in gorm: 209 | // https://github.com/go-gorm/gorm/issues/6351 210 | func (rw *RW) clearDefaultNullResourceFields(ctx context.Context, i interface{}) error { 211 | const op = "dbw.ClearResourceFields" 212 | stmt := rw.underlying.wrapped.Model(i).Statement 213 | if err := stmt.Parse(i); err != nil { 214 | return fmt.Errorf("%s: %w", op, err) 215 | } 216 | v := reflect.ValueOf(i) 217 | for _, f := range stmt.Schema.Fields { 218 | switch { 219 | case f.PrimaryKey: 220 | // seems a bit redundant, with the test for null, but it's very 221 | // important to not clear the primary fields, so we'll make an 222 | // explicit test 223 | continue 224 | case !f.Updatable: 225 | // well, based on the gorm tags it's a read-only field, so we're done. 
226 | continue 227 | case !strings.EqualFold(f.DefaultValue, "null"): 228 | continue 229 | default: 230 | _, isZero := f.ValueOf(ctx, v) 231 | if isZero { 232 | continue 233 | } 234 | if err := f.Set(stmt.Context, v, f.DefaultValueInterface); err != nil { 235 | return fmt.Errorf("%s: unable to set value of non-zero field: %w", op, err) 236 | } 237 | } 238 | } 239 | return nil 240 | } 241 | 242 | func (rw *RW) primaryKeysWhere(ctx context.Context, i interface{}) (string, []interface{}, error) { 243 | const op = "dbw.primaryKeysWhere" 244 | var fieldNames []string 245 | var fieldValues []interface{} 246 | tx := rw.underlying.wrapped.Model(i) 247 | if err := tx.Statement.Parse(i); err != nil { 248 | return "", nil, fmt.Errorf("%s: %w", op, err) 249 | } 250 | switch resourceType := i.(type) { 251 | case ResourcePublicIder: 252 | if resourceType.GetPublicId() == "" { 253 | return "", nil, fmt.Errorf("%s: missing primary key: %w", op, ErrInvalidParameter) 254 | } 255 | fieldValues = []interface{}{resourceType.GetPublicId()} 256 | fieldNames = []string{"public_id"} 257 | case ResourcePrivateIder: 258 | if resourceType.GetPrivateId() == "" { 259 | return "", nil, fmt.Errorf("%s: missing primary key: %w", op, ErrInvalidParameter) 260 | } 261 | fieldValues = []interface{}{resourceType.GetPrivateId()} 262 | fieldNames = []string{"private_id"} 263 | default: 264 | v := reflect.ValueOf(i) 265 | for _, f := range tx.Statement.Schema.PrimaryFields { 266 | if f.PrimaryKey { 267 | val, isZero := f.ValueOf(ctx, v) 268 | if isZero { 269 | return "", nil, fmt.Errorf("%s: primary field %s is zero: %w", op, f.Name, ErrInvalidParameter) 270 | } 271 | fieldNames = append(fieldNames, f.DBName) 272 | fieldValues = append(fieldValues, val) 273 | } 274 | } 275 | } 276 | if len(fieldNames) == 0 { 277 | return "", nil, fmt.Errorf("%s: no primary key(s) for %t: %w", op, i, ErrInvalidParameter) 278 | } 279 | clauses := make([]string, 0, len(fieldNames)) 280 | for _, col := range fieldNames { 281 
| clauses = append(clauses, fmt.Sprintf("%s = ?", col)) 282 | } 283 | return strings.Join(clauses, " and "), fieldValues, nil 284 | } 285 | 286 | // LookupWhere will lookup the first resource using a where clause with 287 | // parameters (it only returns the first one). Supports WithDebug, and 288 | // WithTable options. 289 | func (rw *RW) LookupWhere(ctx context.Context, resource interface{}, where string, args []interface{}, opt ...Option) error { 290 | const op = "dbw.LookupWhere" 291 | if rw.underlying == nil { 292 | return fmt.Errorf("%s: missing underlying db: %w", op, ErrInvalidParameter) 293 | } 294 | if err := validateResourcesInterface(resource); err != nil { 295 | return fmt.Errorf("%s: %w", op, err) 296 | } 297 | if err := raiseErrorOnHooks(resource); err != nil { 298 | return fmt.Errorf("%s: %w", op, err) 299 | } 300 | opts := GetOpts(opt...) 301 | db := rw.underlying.wrapped.WithContext(ctx) 302 | if opts.WithTable != "" { 303 | db = db.Table(opts.WithTable) 304 | } 305 | if opts.WithDebug { 306 | db = db.Debug() 307 | } 308 | if err := db.Where(where, args...).First(resource).Error; err != nil { 309 | if err == gorm.ErrRecordNotFound { 310 | return fmt.Errorf("%s: %w", op, ErrRecordNotFound) 311 | } 312 | return fmt.Errorf("%s: %w", op, err) 313 | } 314 | return nil 315 | } 316 | 317 | // SearchWhere will search for all the resources it can find using a where 318 | // clause with parameters. An error will be returned if args are provided without a 319 | // where clause. 320 | // 321 | // Supports WithTable and WithLimit options. If WithLimit < 0, then unlimited results are returned. 322 | // If WithLimit == 0, then default limits are used for results. 323 | // Supports the WithOrder, WithTable, and WithDebug options. 324 | func (rw *RW) SearchWhere(ctx context.Context, resources interface{}, where string, args []interface{}, opt ...Option) error { 325 | const op = "dbw.SearchWhere" 326 | opts := GetOpts(opt...) 
327 | if rw.underlying == nil { 328 | return fmt.Errorf("%s: missing underlying db: %w", op, ErrInvalidParameter) 329 | } 330 | if where == "" && len(args) > 0 { 331 | return fmt.Errorf("%s: args provided with empty where: %w", op, ErrInvalidParameter) 332 | } 333 | if err := raiseErrorOnHooks(resources); err != nil { 334 | return fmt.Errorf("%s: %w", op, err) 335 | } 336 | if err := validateResourcesInterface(resources); err != nil { 337 | return fmt.Errorf("%s: %w", op, err) 338 | } 339 | var err error 340 | db := rw.underlying.wrapped.WithContext(ctx) 341 | if opts.WithOrder != "" { 342 | db = db.Order(opts.WithOrder) 343 | } 344 | if opts.WithDebug { 345 | db = db.Debug() 346 | } 347 | if opts.WithTable != "" { 348 | db = db.Table(opts.WithTable) 349 | } 350 | // Perform limiting 351 | switch { 352 | case opts.WithLimit < 0: // any negative number signals unlimited results 353 | case opts.WithLimit == 0: // zero signals the default value and default limits 354 | db = db.Limit(DefaultLimit) 355 | default: 356 | db = db.Limit(opts.WithLimit) 357 | } 358 | 359 | if where != "" { 360 | db = db.Where(where, args...) 361 | } 362 | 363 | // Perform the query 364 | err = db.Find(resources).Error 365 | if err != nil { 366 | // searching with a slice parameter does not return a gorm.ErrRecordNotFound 367 | return fmt.Errorf("%s: %w", op, err) 368 | } 369 | return nil 370 | } 371 | 372 | func (rw *RW) Dialect() (_ DbType, rawName string, _ error) { 373 | return rw.underlying.DbType() 374 | } 375 | -------------------------------------------------------------------------------- /rw_test.go: -------------------------------------------------------------------------------- 1 | // Copyright (c) HashiCorp, Inc. 
2 | // SPDX-License-Identifier: MPL-2.0 3 | 4 | package dbw_test 5 | 6 | import ( 7 | "context" 8 | "database/sql" 9 | "fmt" 10 | "strconv" 11 | "testing" 12 | 13 | "github.com/hashicorp/go-dbw" 14 | "github.com/hashicorp/go-dbw/internal/dbtest" 15 | "github.com/stretchr/testify/assert" 16 | "github.com/stretchr/testify/require" 17 | ) 18 | 19 | func TestDb_Exec(t *testing.T) { 20 | t.Parallel() 21 | testCtx := context.Background() 22 | conn, _ := dbw.TestSetup(t) 23 | t.Run("update", func(t *testing.T) { 24 | require := require.New(t) 25 | w := dbw.New(conn) 26 | id, err := dbw.NewId("i") 27 | require.NoError(err) 28 | _, err = w.Exec(testCtx, 29 | "insert into db_test_user(public_id, name) values(@public_id, @name)", 30 | []interface{}{ 31 | sql.Named("public_id", id), 32 | sql.Named("name", "alice"), 33 | }, 34 | dbw.WithDebug(true), 35 | ) 36 | 37 | require.NoError(err) 38 | rowsAffected, err := w.Exec(testCtx, 39 | "update db_test_user set name = @name where public_id = @public_id", 40 | []interface{}{ 41 | sql.Named("public_id", id), 42 | sql.Named("name", "alice-"+id), 43 | }) 44 | require.NoError(err) 45 | require.Equal(1, rowsAffected) 46 | }) 47 | t.Run("missing-sql", func(t *testing.T) { 48 | assert, require := assert.New(t), require.New(t) 49 | rw := dbw.New(conn) 50 | got, err := rw.Exec(testCtx, "", nil) 51 | require.Error(err) 52 | assert.Zero(got) 53 | }) 54 | t.Run("missing-underlying-db", func(t *testing.T) { 55 | assert, require := assert.New(t), require.New(t) 56 | rw := dbw.RW{} 57 | got, err := rw.Exec(testCtx, "", nil) 58 | require.Error(err) 59 | assert.Zero(got) 60 | }) 61 | t.Run("bad-sql", func(t *testing.T) { 62 | assert, require := assert.New(t), require.New(t) 63 | rw := dbw.New(conn) 64 | got, err := rw.Exec(testCtx, "insert from", nil) 65 | require.Error(err) 66 | assert.Zero(got) 67 | }) 68 | } 69 | 70 | func TestDb_LookupWhere(t *testing.T) { 71 | t.Parallel() 72 | conn, _ := dbw.TestSetup(t) 73 | t.Run("simple", func(t *testing.T) 
{ 74 | assert, require := assert.New(t), require.New(t) 75 | w := dbw.New(conn) 76 | user, err := dbtest.NewTestUser() 77 | require.NoError(err) 78 | user.Name = "foo-" + user.PublicId 79 | err = w.Create(context.Background(), user) 80 | require.NoError(err) 81 | assert.NotEmpty(user.PublicId) 82 | 83 | var foundUser dbtest.TestUser 84 | err = w.LookupWhere(context.Background(), &foundUser, "public_id = ? and 1 = ?", []interface{}{user.PublicId, 1}, dbw.WithDebug(true)) 85 | require.NoError(err) 86 | assert.Equal(foundUser.PublicId, user.PublicId) 87 | }) 88 | t.Run("with-table", func(t *testing.T) { 89 | assert, require := assert.New(t), require.New(t) 90 | w := dbw.New(conn) 91 | user, err := dbtest.NewTestUser() 92 | require.NoError(err) 93 | user.Name = "foo-" + user.PublicId 94 | err = w.Create(context.Background(), user, dbw.WithTable(user.TableName())) 95 | require.NoError(err) 96 | assert.NotEmpty(user.PublicId) 97 | 98 | var foundUser dbtest.TestUser 99 | err = w.LookupWhere(context.Background(), &foundUser, "public_id = ?", []interface{}{user.PublicId}, dbw.WithTable(user.TableName())) 100 | require.NoError(err) 101 | assert.Equal(foundUser.PublicId, user.PublicId) 102 | 103 | err = w.LookupWhere(context.Background(), &foundUser, "public_id = ?", []interface{}{user.PublicId}, dbw.WithTable("invalid-table-name")) 104 | require.Error(err) 105 | }) 106 | t.Run("tx-nil,", func(t *testing.T) { 107 | assert, require := assert.New(t), require.New(t) 108 | w := dbw.RW{} 109 | var foundUser dbtest.TestUser 110 | err := w.LookupWhere(context.Background(), &foundUser, "public_id = ?", []interface{}{1}) 111 | require.Error(err) 112 | assert.Equal("dbw.LookupWhere: missing underlying db: invalid parameter", err.Error()) 113 | }) 114 | t.Run("not-found", func(t *testing.T) { 115 | assert, require := assert.New(t), require.New(t) 116 | w := dbw.New(conn) 117 | id, err := dbw.NewId("i") 118 | require.NoError(err) 119 | 120 | var foundUser dbtest.TestUser 121 | err = 
w.LookupWhere(context.Background(), &foundUser, "public_id = ?", []interface{}{id}) 122 | require.Error(err) 123 | assert.ErrorIs(err, dbw.ErrRecordNotFound) 124 | }) 125 | t.Run("bad-where", func(t *testing.T) { 126 | require := require.New(t) 127 | w := dbw.New(conn) 128 | id, err := dbw.NewId("i") 129 | require.NoError(err) 130 | 131 | var foundUser dbtest.TestUser 132 | err = w.LookupWhere(context.Background(), &foundUser, "? = ?", []interface{}{id}) 133 | require.Error(err) 134 | }) 135 | t.Run("not-ptr", func(t *testing.T) { 136 | require := require.New(t) 137 | w := dbw.New(conn) 138 | id, err := dbw.NewId("i") 139 | require.NoError(err) 140 | 141 | var foundUser dbtest.TestUser 142 | err = w.LookupWhere(context.Background(), foundUser, "public_id = ?", []interface{}{id}) 143 | require.Error(err) 144 | }) 145 | t.Run("hooks", func(t *testing.T) { 146 | hookTests := []struct { 147 | name string 148 | resource interface{} 149 | }{ 150 | {"after", &dbtest.TestWithAfterFind{}}, 151 | } 152 | for _, tt := range hookTests { 153 | t.Run(tt.name, func(t *testing.T) { 154 | assert, require := assert.New(t), require.New(t) 155 | w := dbw.New(conn) 156 | err := w.LookupWhere(context.Background(), tt.resource, "public_id = ?", []interface{}{"1"}) 157 | require.Error(err) 158 | assert.ErrorIs(err, dbw.ErrInvalidParameter) 159 | assert.Contains(err.Error(), "gorm callback/hooks are not supported") 160 | }) 161 | } 162 | }) 163 | } 164 | 165 | func TestDb_SearchWhere(t *testing.T) { 166 | t.Parallel() 167 | conn, _ := dbw.TestSetup(t) 168 | testRw := dbw.New(conn) 169 | knownUser := testUser(t, testRw, "zedUser", "", "") 170 | 171 | type args struct { 172 | where string 173 | arg []interface{} 174 | opt []dbw.Option 175 | } 176 | tests := []struct { 177 | name string 178 | rw *dbw.RW 179 | createCnt int 180 | args args 181 | wantCnt int 182 | wantErr bool 183 | wantNameOrder bool 184 | }{ 185 | { 186 | name: "no-limit", 187 | rw: testRw, 188 | createCnt: 10, 189 | args: 
args{ 190 | where: "1=1", 191 | opt: []dbw.Option{dbw.WithLimit(-1), dbw.WithOrder("name asc")}, 192 | }, 193 | wantCnt: 11, // there's an additional knownUser 194 | wantErr: false, 195 | wantNameOrder: true, 196 | }, 197 | { 198 | name: "no-where", 199 | rw: testRw, 200 | createCnt: 10, 201 | args: args{ 202 | opt: []dbw.Option{dbw.WithLimit(10)}, 203 | }, 204 | wantCnt: 10, 205 | wantErr: false, 206 | }, 207 | { 208 | name: "custom-limit", 209 | rw: testRw, 210 | createCnt: 10, 211 | args: args{ 212 | where: "1=1", 213 | opt: []dbw.Option{dbw.WithLimit(3)}, 214 | }, 215 | wantCnt: 3, 216 | wantErr: false, 217 | }, 218 | { 219 | name: "simple", 220 | rw: testRw, 221 | createCnt: 1, 222 | args: args{ 223 | where: "public_id = ?", 224 | arg: []interface{}{knownUser.PublicId}, 225 | opt: []dbw.Option{dbw.WithLimit(3)}, 226 | }, 227 | wantCnt: 1, 228 | wantErr: false, 229 | }, 230 | { 231 | name: "with-table", 232 | rw: testRw, 233 | createCnt: 1, 234 | args: args{ 235 | where: "public_id = ?", 236 | arg: []interface{}{knownUser.PublicId}, 237 | opt: []dbw.Option{dbw.WithLimit(3), dbw.WithTable(knownUser.TableName())}, 238 | }, 239 | wantCnt: 1, 240 | wantErr: false, 241 | }, 242 | { 243 | name: "with-table-fail", 244 | rw: testRw, 245 | createCnt: 1, 246 | args: args{ 247 | where: "public_id = ?", 248 | arg: []interface{}{knownUser.PublicId}, 249 | opt: []dbw.Option{dbw.WithLimit(3), dbw.WithTable("invalid-table-name")}, 250 | }, 251 | wantErr: true, 252 | }, 253 | { 254 | name: "no args", 255 | rw: testRw, 256 | createCnt: 1, 257 | args: args{ 258 | where: fmt.Sprintf("public_id = '%v'", knownUser.PublicId), 259 | opt: []dbw.Option{dbw.WithLimit(3)}, 260 | }, 261 | wantCnt: 1, 262 | wantErr: false, 263 | }, 264 | { 265 | name: "no where, but with args", 266 | rw: testRw, 267 | createCnt: 1, 268 | args: args{ 269 | arg: []interface{}{knownUser.PublicId}, 270 | opt: []dbw.Option{dbw.WithLimit(3)}, 271 | }, 272 | wantErr: true, 273 | }, 274 | { 275 | name: "not-found", 
276 | rw: testRw, 277 | createCnt: 1, 278 | args: args{ 279 | where: "public_id = ?", 280 | arg: []interface{}{"bad-id"}, 281 | opt: []dbw.Option{dbw.WithLimit(3)}, 282 | }, 283 | wantCnt: 0, 284 | wantErr: false, 285 | }, 286 | { 287 | name: "bad-where", 288 | rw: testRw, 289 | createCnt: 1, 290 | args: args{ 291 | where: "bad_column_name = ?", 292 | arg: []interface{}{knownUser.PublicId}, 293 | opt: []dbw.Option{dbw.WithLimit(3)}, 294 | }, 295 | wantCnt: 0, 296 | wantErr: true, 297 | }, 298 | { 299 | name: "nil-underlying", 300 | rw: &dbw.RW{}, 301 | createCnt: 1, 302 | args: args{ 303 | where: "public_id = ?", 304 | arg: []interface{}{knownUser.PublicId}, 305 | opt: []dbw.Option{dbw.WithLimit(3)}, 306 | }, 307 | wantCnt: 0, 308 | wantErr: true, 309 | }, 310 | } 311 | for _, tt := range tests { 312 | t.Run(tt.name, func(t *testing.T) { 313 | assert, require := assert.New(t), require.New(t) 314 | testUsers := []*dbtest.TestUser{} 315 | for i := 0; i < tt.createCnt; i++ { 316 | testUsers = append(testUsers, testUser(t, testRw, tt.name+strconv.Itoa(i), "", "")) 317 | } 318 | assert.Equal(tt.createCnt, len(testUsers)) 319 | 320 | var foundUsers []*dbtest.TestUser 321 | err := tt.rw.SearchWhere(context.Background(), &foundUsers, tt.args.where, tt.args.arg, tt.args.opt...) 
322 | if tt.wantErr { 323 | require.Error(err) 324 | return 325 | } 326 | require.NoError(err) 327 | assert.Equal(tt.wantCnt, len(foundUsers)) 328 | if tt.wantNameOrder { 329 | assert.Equal(tt.name+strconv.Itoa(0), foundUsers[0].Name) 330 | for i, u := range foundUsers { 331 | if u.Name != "zedUser" { 332 | assert.Equal(tt.name+strconv.Itoa(i), u.Name) 333 | } 334 | } 335 | } 336 | }) 337 | } 338 | t.Run("hooks", func(t *testing.T) { 339 | hookTests := []struct { 340 | name string 341 | resource interface{} 342 | }{ 343 | {"after", &dbtest.TestWithAfterFind{}}, 344 | } 345 | for _, tt := range hookTests { 346 | t.Run(tt.name, func(t *testing.T) { 347 | assert, require := assert.New(t), require.New(t) 348 | w := dbw.New(conn) 349 | err := w.SearchWhere(context.Background(), tt.resource, "public_id = 1", nil) 350 | require.Error(err) 351 | assert.ErrorIs(err, dbw.ErrInvalidParameter) 352 | assert.Contains(err.Error(), "gorm callback/hooks are not supported") 353 | }) 354 | } 355 | }) 356 | } 357 | 358 | func TestRW_IsTx(t *testing.T) { 359 | t.Parallel() 360 | testCtx := context.Background() 361 | conn, _ := dbw.TestSetup(t) 362 | testRw := dbw.New(conn) 363 | assert, require := assert.New(t), require.New(t) 364 | 365 | assert.False(testRw.IsTx()) 366 | 367 | tx, err := testRw.Begin(testCtx) 368 | require.NoError(err) 369 | assert.NotNil(tx) 370 | assert.True(tx.IsTx()) 371 | } 372 | 373 | func TestDialect(t *testing.T) { 374 | t.Parallel() 375 | conn, _ := dbw.TestSetup(t) 376 | testRw := dbw.New(conn) 377 | assert, require := assert.New(t), require.New(t) 378 | 379 | gotTyp, gotRawName, err := testRw.Dialect() 380 | require.NoError(err) 381 | typ, rawName, err := conn.DbType() 382 | require.NoError(err) 383 | assert.Equal(typ, gotTyp) 384 | assert.Equal(rawName, gotRawName) 385 | } 386 | 387 | func testUser(t *testing.T, rw *dbw.RW, name, email, phoneNumber string) *dbtest.TestUser { 388 | t.Helper() 389 | require := require.New(t) 390 | r, err := 
dbtest.NewTestUser() 391 | require.NoError(err) 392 | r.Name = name 393 | r.Email = email 394 | r.PhoneNumber = phoneNumber 395 | if rw != nil { 396 | err = rw.Create(context.Background(), r) 397 | require.NoError(err) 398 | } 399 | return r 400 | } 401 | 402 | func testCar(t *testing.T, rw *dbw.RW) *dbtest.TestCar { 403 | t.Helper() 404 | require := require.New(t) 405 | r, err := dbtest.NewTestCar() 406 | require.NoError(err) 407 | if rw != nil { 408 | err = rw.Create(context.Background(), r) 409 | require.NoError(err) 410 | } 411 | return r 412 | } 413 | 414 | func testScooter(t *testing.T, rw *dbw.RW, model string, mpg int32, readOnlyField string) *dbtest.TestScooter { 415 | t.Helper() 416 | require := require.New(t) 417 | r, err := dbtest.NewTestScooter() 418 | require.NoError(err) 419 | r.Model = model 420 | r.Mpg = mpg 421 | r.ReadOnlyField = readOnlyField 422 | if rw != nil { 423 | err = rw.Create(context.Background(), r) 424 | require.NoError(err) 425 | } 426 | return r 427 | } 428 | 429 | func testRental(t *testing.T, rw *dbw.RW, userId, carId string) *dbtest.TestRental { 430 | t.Helper() 431 | require := require.New(t) 432 | r, err := dbtest.NewTestRental(userId, carId) 433 | require.NoError(err) 434 | if rw != nil { 435 | err = rw.Create(context.Background(), r) 436 | require.NoError(err) 437 | } 438 | return r 439 | } 440 | --------------------------------------------------------------------------------