├── .github ├── FUNDING.yml ├── ISSUE_TEMPLATE │ ├── bug_report.md │ ├── feature_request.md │ ├── sweep-bugfix.yml │ ├── sweep-feature.yml │ └── sweep-refactor.yml ├── dependabot.yml └── workflows │ ├── build.yml │ ├── codeql-analysis.yml │ └── stale.yml ├── .gitignore ├── .golangci.yml ├── LICENSE ├── README.md ├── argument.go ├── argument_test.go ├── batch.go ├── batch_test.go ├── driver.go ├── driver_test.go ├── examples ├── .vscode │ └── launch.json ├── basic │ ├── basic.go │ └── basic_test.go ├── blog │ ├── blog.go │ └── blog_test.go └── doc.go ├── expectations.go ├── expectations_test.go ├── go.mod ├── go.sum ├── options.go ├── pgxmock.go ├── pgxmock_test.go ├── query.go ├── query_test.go ├── result.go ├── result_test.go ├── rows.go ├── rows_test.go └── sql_test.go /.github/FUNDING.yml: -------------------------------------------------------------------------------- 1 | # These are supported funding model platforms 2 | 3 | github: pashagolub 4 | # patreon: # Replace with a single Patreon username 5 | # open_collective: # Replace with a single Open Collective username 6 | # ko_fi: # Replace with a single Ko-fi username 7 | # tidelift: # Replace with a single Tidelift platform-name/package-name e.g., npm/babel 8 | # community_bridge: # Replace with a single Community Bridge project-name e.g., cloud-foundry 9 | # liberapay: # Replace with a single Liberapay username 10 | # issuehunt: # Replace with a single IssueHunt username 11 | # otechie: # Replace with a single Otechie username 12 | # lfx_crowdfunding: # Replace with a single LFX Crowdfunding project-name e.g., cloud-foundry 13 | custom: ['https://revolut.me/pavlogolub'] 14 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/bug_report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Bug report 3 | about: Create a report to help us improve 4 | title: '' 5 | labels: '' 6 | assignees: '' 7 | 8 | --- 9 | 
10 | **Describe the bug** 11 | A clear and concise description of what the bug is. 12 | 13 | **To Reproduce** 14 | Steps to reproduce the behavior: 15 | 1. Go to '...' 16 | 2. Click on '....' 17 | 3. Scroll down to '....' 18 | 4. See error 19 | 20 | **Expected behavior** 21 | A clear and concise description of what you expected to happen. 22 | 23 | **Screenshots** 24 | If applicable, add screenshots to help explain your problem. 25 | 26 | **Desktop (please complete the following information):** 27 | - OS: [e.g. iOS] 28 | - Browser [e.g. chrome, safari] 29 | - Version [e.g. 22] 30 | 31 | **Smartphone (please complete the following information):** 32 | - Device: [e.g. iPhone6] 33 | - OS: [e.g. iOS8.1] 34 | - Browser [e.g. stock browser, safari] 35 | - Version [e.g. 22] 36 | 37 | **Additional context** 38 | Add any other context about the problem here. 39 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature_request.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Feature request 3 | about: Suggest an idea for this project 4 | title: '' 5 | labels: '' 6 | assignees: '' 7 | 8 | --- 9 | 10 | **Is your feature request related to a problem? Please describe.** 11 | A clear and concise description of what the problem is. Ex. I'm always frustrated when [...] 12 | 13 | **Describe the solution you'd like** 14 | A clear and concise description of what you want to happen. 15 | 16 | **Describe alternatives you've considered** 17 | A clear and concise description of any alternative solutions or features you've considered. 18 | 19 | **Additional context** 20 | Add any other context or screenshots about the feature request here. 
21 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/sweep-bugfix.yml: -------------------------------------------------------------------------------- 1 | name: Bugfix 2 | title: 'Sweep: ' 3 | description: Write something like "We notice ... behavior when ... happens instead of ..."" 4 | labels: sweep 5 | body: 6 | - type: textarea 7 | id: description 8 | attributes: 9 | label: Details 10 | description: More details about the bug 11 | placeholder: The bug might be in ... file -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/sweep-feature.yml: -------------------------------------------------------------------------------- 1 | name: Feature Request 2 | title: 'Sweep: ' 3 | description: Write something like "Write an api endpoint that does "..." in the "..." file" 4 | labels: sweep 5 | body: 6 | - type: textarea 7 | id: description 8 | attributes: 9 | label: Details 10 | description: More details for Sweep 11 | placeholder: The new endpoint should use the ... class from ... file because it contains ... logic -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/sweep-refactor.yml: -------------------------------------------------------------------------------- 1 | name: Refactor 2 | title: 'Sweep: ' 3 | description: Write something like "Modify the ... api endpoint to use ... version and ... framework" 4 | labels: sweep 5 | body: 6 | - type: textarea 7 | id: description 8 | attributes: 9 | label: Details 10 | description: More details for Sweep 11 | placeholder: We are migrating this function to ... version because ... 
-------------------------------------------------------------------------------- /.github/dependabot.yml: -------------------------------------------------------------------------------- 1 | version: 2 2 | updates: 3 | 4 | # Maintain dependencies for Go modules 5 | - package-ecosystem: gomod 6 | directory: "/" 7 | schedule: 8 | interval: daily 9 | time: "04:00" 10 | open-pull-requests-limit: 10 11 | 12 | # Maintain dependencies for GitHub Actions 13 | - package-ecosystem: "github-actions" 14 | directory: "/" 15 | schedule: 16 | interval: "daily" 17 | -------------------------------------------------------------------------------- /.github/workflows/build.yml: -------------------------------------------------------------------------------- 1 | name: Build & Test 2 | on: 3 | push: 4 | branches: 5 | - master 6 | pull_request: 7 | workflow_dispatch: 8 | 9 | jobs: 10 | 11 | build-and-test: 12 | if: true # false to skip job during debug 13 | name: Test and Build on Ubuntu 14 | runs-on: ubuntu-latest 15 | steps: 16 | 17 | - name: Check out code 18 | uses: actions/checkout@v4 19 | 20 | - name: Set up Golang 21 | uses: actions/setup-go@v5 22 | with: 23 | go-version: '1.21' 24 | 25 | - name: Get dependencies 26 | run: | 27 | go mod download 28 | go version 29 | 30 | - name: GolangCI-Lint 31 | uses: golangci/golangci-lint-action@v8 32 | with: 33 | version: latest 34 | 35 | - name: Test 36 | run: go test -v -coverprofile=profile.cov 37 | 38 | - name: Coveralls 39 | uses: shogo82148/actions-goveralls@v1 40 | with: 41 | path-to-profile: profile.cov 42 | -------------------------------------------------------------------------------- /.github/workflows/codeql-analysis.yml: -------------------------------------------------------------------------------- 1 | # For most projects, this workflow file will not need changing; you simply need 2 | # to commit it to your repository. 
3 | # 4 | # You may wish to alter this file to override the set of languages analyzed, 5 | # or to provide custom queries or build logic. 6 | # 7 | # ******** NOTE ******** 8 | # We have attempted to detect the languages in your repository. Please check 9 | # the `language` matrix defined below to confirm you have the correct set of 10 | # supported CodeQL languages. 11 | # 12 | name: "CodeQL" 13 | 14 | on: 15 | push: 16 | branches: [ "master" ] 17 | pull_request: 18 | # The branches below must be a subset of the branches above 19 | branches: [ "master" ] 20 | schedule: 21 | - cron: '30 20 * * 2' 22 | 23 | jobs: 24 | analyze: 25 | name: Analyze 26 | runs-on: ubuntu-latest 27 | permissions: 28 | actions: read 29 | contents: read 30 | security-events: write 31 | 32 | strategy: 33 | fail-fast: false 34 | matrix: 35 | language: [ 'go' ] 36 | # CodeQL supports [ 'cpp', 'csharp', 'go', 'java', 'javascript', 'python', 'ruby' ] 37 | # Learn more about CodeQL language support at https://aka.ms/codeql-docs/language-support 38 | 39 | steps: 40 | - name: Checkout repository 41 | uses: actions/checkout@v4 42 | 43 | # Initializes the CodeQL tools for scanning. 44 | - name: Initialize CodeQL 45 | uses: github/codeql-action/init@v3 46 | with: 47 | languages: ${{ matrix.language }} 48 | # If you wish to specify custom queries, you can do so here or in a config file. 49 | # By default, queries listed here will override any specified in a config file. 50 | # Prefix the list here with "+" to use these queries and those in the config file. 51 | 52 | # Details on CodeQL's query packs refer to : https://docs.github.com/en/code-security/code-scanning/automatically-scanning-your-code-for-vulnerabilities-and-errors/configuring-code-scanning#using-queries-in-ql-packs 53 | # queries: security-extended,security-and-quality 54 | 55 | 56 | # Autobuild attempts to build any compiled languages (C/C++, C#, or Java). 
57 | # If this step fails, then you should remove it and run the build manually (see below) 58 | - name: Autobuild 59 | uses: github/codeql-action/autobuild@v3 60 | 61 | # ℹ️ Command-line programs to run using the OS shell. 62 | # 📚 See https://docs.github.com/en/actions/using-workflows/workflow-syntax-for-github-actions#jobsjob_idstepsrun 63 | 64 | # If the Autobuild fails above, remove it and uncomment the following three lines. 65 | # modify them (or add more) to build your code if your project, please refer to the EXAMPLE below for guidance. 66 | 67 | # - run: | 68 | # echo "Run, Build Application using script" 69 | # ./location_of_script_within_repo/buildscript.sh 70 | 71 | - name: Perform CodeQL Analysis 72 | uses: github/codeql-action/analyze@v3 73 | -------------------------------------------------------------------------------- /.github/workflows/stale.yml: -------------------------------------------------------------------------------- 1 | name: Close Stale Issues and PRs 2 | 3 | on: 4 | schedule: 5 | - cron: '0 0 * * *' 6 | 7 | workflow_dispatch: 8 | 9 | 10 | jobs: 11 | stale: 12 | runs-on: ubuntu-latest 13 | steps: 14 | - uses: actions/stale@v9 15 | with: 16 | repo-token: ${{ secrets.GITHUB_TOKEN }} 17 | stale-issue-label: 'stale' 18 | stale-pr-label: 'stale' 19 | stale-issue-message: | 20 | 📅 This issue has been automatically marked as stale because lack of recent activity. It will be closed if no further activity occurs. 21 | ♻️ If you think there is new information allowing us to address the issue, please reopen it and provide us with updated details. 22 | 🤝 Thank you for your contributions. 23 | stale-pr-message: | 24 | 📅 This PR has been automatically marked as stale because lack of recent activity. It will be closed if no further activity occurs. 25 | ♻️ If you think there is new information allowing us to address this PR, please reopen it and provide us with updated details. 26 | 🤝 Thank you for your contributions. 
27 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Binaries for programs and plugins 2 | *.exe 3 | *.exe~ 4 | *.dll 5 | *.so 6 | *.dylib 7 | 8 | # Test binary, build with `go test -c` 9 | *.test 10 | 11 | # Output of the go coverage tool, specifically when used with LiteIDE 12 | *.out 13 | 14 | # Visual Studio Code internal folder 15 | .vscode 16 | 17 | # Packages output folder 18 | dist 19 | 20 | # delve debugger file 21 | debug -------------------------------------------------------------------------------- /.golangci.yml: -------------------------------------------------------------------------------- 1 | version: "2" 2 | linters: 3 | enable: 4 | - gocyclo 5 | - misspell 6 | - revive 7 | settings: 8 | gocyclo: 9 | min-complexity: 20 10 | exclusions: 11 | generated: lax 12 | presets: 13 | - comments 14 | - common-false-positives 15 | - legacy 16 | - std-error-handling 17 | paths: 18 | - third_party$ 19 | - builtin$ 20 | - examples$ 21 | formatters: 22 | exclusions: 23 | generated: lax 24 | paths: 25 | - third_party$ 26 | - builtin$ 27 | - examples$ -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | The three clause BSD license (http://en.wikipedia.org/wiki/BSD_licenses) 2 | 3 | Copyright (c) 2021-2024 Pavlo Golub 4 | Copyright (c) 2013-2020, DATA-DOG team 5 | All rights reserved. 6 | 7 | Redistribution and use in source and binary forms, with or without 8 | modification, are permitted provided that the following conditions are met: 9 | 10 | * Redistributions of source code must retain the above copyright notice, this 11 | list of conditions and the following disclaimer. 
12 | 13 | * Redistributions in binary form must reproduce the above copyright notice, 14 | this list of conditions and the following disclaimer in the documentation 15 | and/or other materials provided with the distribution. 16 | 17 | * The names from the above copyright notices may not be used to endorse or 18 | promote products derived from this software without specific prior written 19 | permission. 20 | 21 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 22 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 23 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 24 | DISCLAIMED. IN NO EVENT SHALL MICHAEL BOSTOCK BE LIABLE FOR ANY DIRECT, 25 | INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, 26 | BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 27 | DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY 28 | OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING 29 | NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, 30 | EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
31 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | [![Go Reference](https://pkg.go.dev/badge/github.com/pashagolub/pgxmock.svg)](https://pkg.go.dev/github.com/pashagolub/pgxmock/v4) 2 | [![Go Report Card](https://goreportcard.com/badge/github.com/pashagolub/pgxmock)](https://goreportcard.com/report/github.com/pashagolub/pgxmock/v4) 3 | [![Coverage Status](https://coveralls.io/repos/github/pashagolub/pgxmock/badge.svg?branch=master)](https://coveralls.io/github/pashagolub/pgxmock?branch=master) 4 | [![Mentioned in Awesome Go](https://awesome.re/mentioned-badge.svg)](https://github.com/avelino/awesome-go) 5 | 6 | # pgx driver mock for Golang 7 | 8 | **pgxmock** is a mock library implementing [pgx - PostgreSQL Driver and Toolkit](https://github.com/jackc/pgx/). 9 | It's based on the well-known [sqlmock](https://github.com/DATA-DOG/go-sqlmock) library for `sql/driver`. 10 | 11 | **pgxmock** has one and only purpose - to simulate **pgx** behavior in tests, without needing a real database connection. It helps to maintain correct **TDD** workflow. 12 | 13 | - written based on **go1.21** version; 14 | - does not require any modifications to your source code; 15 | - has strict by default expectation order matching; 16 | - has no third party dependencies except **pgx** packages. 17 | 18 | ## Install 19 | 20 | go get github.com/pashagolub/pgxmock/v4 21 | 22 | ## Documentation and Examples 23 | 24 | Visit [godoc](http://pkg.go.dev/github.com/pashagolub/pgxmock/v4) for general examples and public api reference. 
25 | 26 | See implementation examples: 27 | 28 | - [the simplest one](https://github.com/pashagolub/pgxmock/tree/master/examples/basic) 29 | - [blog API server](https://github.com/pashagolub/pgxmock/tree/master/examples/blog) 30 | 31 | 32 | ### Something you may want to test 33 | 34 | ``` go 35 | package main 36 | 37 | import ( 38 | "context" 39 | 40 | pgx "github.com/jackc/pgx/v5" 41 | ) 42 | 43 | type PgxIface interface { 44 | Begin(context.Context) (pgx.Tx, error) 45 | Close(context.Context) error 46 | } 47 | 48 | func recordStats(db PgxIface, userID, productID int) (err error) { 49 | if tx, err := db.Begin(context.Background()); err != nil { 50 | return 51 | } 52 | defer func() { 53 | switch err { 54 | case nil: 55 | err = tx.Commit(context.Background()) 56 | default: 57 | _ = tx.Rollback(context.Background()) 58 | } 59 | }() 60 | sql := "UPDATE products SET views = views + 1" 61 | if _, err = tx.Exec(context.Background(), sql); err != nil { 62 | return 63 | } 64 | sql = "INSERT INTO product_viewers (user_id, product_id) VALUES ($1, $2)" 65 | if _, err = tx.Exec(context.Background(), sql, userID, productID); err != nil { 66 | return 67 | } 68 | return 69 | } 70 | 71 | func main() { 72 | // @NOTE: the real connection is not required for tests 73 | db, err := pgx.Connect(context.Background(), "postgres://rolname@hostname/dbname") 74 | if err != nil { 75 | panic(err) 76 | } 77 | defer db.Close(context.Background()) 78 | 79 | if err = recordStats(db, 1 /*some user id*/, 5 /*some product id*/); err != nil { 80 | panic(err) 81 | } 82 | } 83 | ``` 84 | 85 | ### Tests with pgxmock 86 | 87 | ``` go 88 | package main 89 | 90 | import ( 91 | "context" 92 | "fmt" 93 | "testing" 94 | 95 | "github.com/pashagolub/pgxmock/v4" 96 | ) 97 | 98 | // a successful case 99 | func TestShouldUpdateStats(t *testing.T) { 100 | mock, err := pgxmock.NewPool() 101 | if err != nil { 102 | t.Fatal(err) 103 | } 104 | defer mock.Close() 105 | 106 | mock.ExpectBegin() 107 | 
mock.ExpectExec("UPDATE products"). 108 | WillReturnResult(pgxmock.NewResult("UPDATE", 1)) 109 | mock.ExpectExec("INSERT INTO product_viewers"). 110 | WithArgs(2, 3). 111 | WillReturnResult(pgxmock.NewResult("INSERT", 1)) 112 | mock.ExpectCommit() 113 | 114 | // now we execute our method 115 | if err = recordStats(mock, 2, 3); err != nil { 116 | t.Errorf("error was not expected while updating: %s", err) 117 | } 118 | 119 | // we make sure that all expectations were met 120 | if err := mock.ExpectationsWereMet(); err != nil { 121 | t.Errorf("there were unfulfilled expectations: %s", err) 122 | } 123 | } 124 | 125 | // a failing test case 126 | func TestShouldRollbackStatUpdatesOnFailure(t *testing.T) { 127 | mock, err := pgxmock.NewPool() 128 | if err != nil { 129 | t.Fatal(err) 130 | } 131 | defer mock.Close() 132 | 133 | mock.ExpectBegin() 134 | mock.ExpectExec("UPDATE products"). 135 | WillReturnResult(pgxmock.NewResult("UPDATE", 1)) 136 | mock.ExpectExec("INSERT INTO product_viewers"). 137 | WithArgs(2, 3). 138 | WillReturnError(fmt.Errorf("some error")) 139 | mock.ExpectRollback() 140 | 141 | // now we execute our method 142 | if err = recordStats(mock, 2, 3); err == nil { 143 | t.Errorf("was expecting an error, but there was none") 144 | } 145 | 146 | // we make sure that all expectations were met 147 | if err := mock.ExpectationsWereMet(); err != nil { 148 | t.Errorf("there were unfulfilled expectations: %s", err) 149 | } 150 | } 151 | ``` 152 | 153 | ## Customize SQL query matching 154 | 155 | There were plenty of requests from users regarding SQL query string validation or different matching options. 156 | We have now implemented the `QueryMatcher` interface, which can be passed through an option when calling 157 | `pgxmock.New` or `pgxmock.NewWithDSN`. 158 | 159 | This makes it possible to plug in a library that can, for example, parse and validate the SQL AST, 160 | and to create a custom QueryMatcher that validates SQL in sophisticated ways. 
161 | 162 | By default, **pgxmock** preserves backward compatibility, and the default query matcher is `pgxmock.QueryMatcherRegexp` 163 | which uses the expected SQL string as a regular expression to match the incoming query string. There is an equality matcher: 164 | `QueryMatcherEqual` which will do a full case-sensitive match. 165 | 166 | In order to customize the QueryMatcher, use the following: 167 | 168 | ``` go 169 | mock, err := pgxmock.New(context.Background(), pgxmock.QueryMatcherOption(pgxmock.QueryMatcherEqual)) 170 | ``` 171 | 172 | The query matcher can be fully customized based on user needs. **pgxmock** will not 173 | provide standard SQL parsing matchers. 174 | 175 | ## Matching arguments like time.Time 176 | 177 | There may be arguments which are of `struct` type and cannot be compared easily by value like `time.Time`. In this case 178 | **pgxmock** provides an [Argument](https://pkg.go.dev/github.com/pashagolub/pgxmock/v4#Argument) interface which 179 | can be used in more sophisticated matching. Here is a simple example of time argument matching: 180 | 181 | ``` go 182 | type AnyTime struct{} 183 | 184 | // Match satisfies pgxmock.Argument interface 185 | func (a AnyTime) Match(v interface{}) bool { 186 | _, ok := v.(time.Time) 187 | return ok 188 | } 189 | 190 | func TestAnyTimeArgument(t *testing.T) { 191 | t.Parallel() 192 | db, mock, err := New() 193 | if err != nil { 194 | t.Errorf("an error '%s' was not expected when opening a stub database connection", err) 195 | } 196 | defer db.Close() 197 | 198 | mock.ExpectExec("INSERT INTO users"). 199 | WithArgs("john", AnyTime{}). 
200 | WillReturnResult(NewResult(1, 1)) 201 | 202 | _, err = db.Exec("INSERT INTO users(name, created_at) VALUES (?, ?)", "john", time.Now()) 203 | if err != nil { 204 | t.Errorf("error '%s' was not expected, while inserting a row", err) 205 | } 206 | 207 | if err := mock.ExpectationsWereMet(); err != nil { 208 | t.Errorf("there were unfulfilled expectations: %s", err) 209 | } 210 | } 211 | ``` 212 | 213 | It only asserts that the argument is of `time.Time` type. 214 | 215 | ## Run tests 216 | 217 | go test -race 218 | 219 | ## Contributions 220 | 221 | Feel free to open a pull request. Note: if you wish to contribute an extension to the public API (exported methods or types), 222 | please open an issue first to discuss whether these changes can be accepted. All backward-incompatible changes are 223 | and will be treated cautiously. 224 | 225 | ## License 226 | 227 | The [three clause BSD license](http://en.wikipedia.org/wiki/BSD_licenses) 228 | 229 | ## Star History 230 | 231 | [![Star History Chart](https://api.star-history.com/svg?repos=pashagolub/pgxmock&type=Date)](https://star-history.com/#pashagolub/pgxmock&Date) 232 | 233 | -------------------------------------------------------------------------------- /argument.go: -------------------------------------------------------------------------------- 1 | package pgxmock 2 | 3 | // Argument interface allows to match 4 | // any argument in specific way when used with 5 | // ExpectedQuery and ExpectedExec expectations. 6 | type Argument interface { 7 | Match(interface{}) bool 8 | } 9 | 10 | // AnyArg will return an Argument which can 11 | // match any kind of arguments. 12 | // 13 | // Useful for time.Time or similar kinds of arguments. 
14 | func AnyArg() Argument { 15 | return anyArgument{} 16 | } 17 | 18 | type anyArgument struct{} 19 | 20 | func (a anyArgument) Match(_ interface{}) bool { 21 | return true 22 | } 23 | 24 | -------------------------------------------------------------------------------- /argument_test.go: -------------------------------------------------------------------------------- 1 | package pgxmock 2 | 3 | import ( 4 | "context" 5 | "errors" 6 | "testing" 7 | "time" 8 | 9 | pgx "github.com/jackc/pgx/v5" 10 | "github.com/stretchr/testify/assert" 11 | ) 12 | 13 | type AnyTime struct{} 14 | 15 | // Match satisfies pgxmock.Argument interface 16 | func (a AnyTime) Match(v interface{}) bool { 17 | _, ok := v.(time.Time) 18 | return ok 19 | } 20 | 21 | func TestAnyTimeArgument(t *testing.T) { 22 | t.Parallel() 23 | mock, err := NewConn() 24 | if err != nil { 25 | t.Fatalf("an error '%s' was not expected when opening a stub database connection", err) 26 | } 27 | 28 | mock.ExpectExec("INSERT INTO users"). 29 | WithArgs("john", AnyTime{}). 30 | WillReturnResult(NewResult("INSERT", 1)) 31 | 32 | _, err = mock.Exec(context.Background(), "INSERT INTO users(name, created_at) VALUES (?, ?)", "john", time.Now()) 33 | if err != nil { 34 | t.Errorf("error '%s' was not expected, while inserting a row", err) 35 | } 36 | 37 | if err := mock.ExpectationsWereMet(); err != nil { 38 | t.Errorf("there were unfulfilled expectations: %s", err) 39 | } 40 | } 41 | 42 | func TestAnyTimeNamedArgument(t *testing.T) { 43 | t.Parallel() 44 | mock, err := NewConn() 45 | if err != nil { 46 | t.Fatalf("an error '%s' was not expected when opening a stub database connection", err) 47 | } 48 | 49 | mock.ExpectExec("INSERT INTO users"). 50 | WithArgs(pgx.NamedArgs{"name": "john", "time": AnyTime{}}). 
51 | WillReturnResult(NewResult("INSERT", 1)) 52 | 53 | _, err = mock.Exec(context.Background(), 54 | "INSERT INTO users(name, created_at) VALUES (@name, @time)", 55 | pgx.NamedArgs{"name": "john", "time": time.Now()}, 56 | ) 57 | if err != nil { 58 | t.Errorf("error '%s' was not expected, while inserting a row", err) 59 | } 60 | 61 | if err := mock.ExpectationsWereMet(); err != nil { 62 | t.Errorf("there were unfulfilled expectations: %s", err) 63 | } 64 | } 65 | 66 | func TestByteSliceArgument(t *testing.T) { 67 | t.Parallel() 68 | mock, err := NewConn() 69 | if err != nil { 70 | t.Fatalf("an error '%s' was not expected when opening a stub database connection", err) 71 | } 72 | 73 | username := []byte("user") 74 | mock.ExpectExec("INSERT INTO users").WithArgs(username).WillReturnResult(NewResult("INSERT", 1)) 75 | 76 | _, err = mock.Exec(context.Background(), "INSERT INTO users(username) VALUES (?)", username) 77 | if err != nil { 78 | t.Errorf("error '%s' was not expected, while inserting a row", err) 79 | } 80 | 81 | if err := mock.ExpectationsWereMet(); err != nil { 82 | t.Errorf("there were unfulfilled expectations: %s", err) 83 | } 84 | } 85 | 86 | type failQryRW struct { 87 | pgx.QueryRewriter 88 | } 89 | 90 | func (fqrw failQryRW) RewriteQuery(_ context.Context, _ *pgx.Conn, sql string, _ []any) (newSQL string, newArgs []any, err error) { 91 | return "", nil, errors.New("cannot rewrite query " + sql) 92 | } 93 | 94 | func TestExpectQueryRewriterFail(t *testing.T) { 95 | 96 | t.Parallel() 97 | mock, err := NewConn() 98 | if err != nil { 99 | t.Fatalf("an error '%s' was not expected when opening a stub database connection", err) 100 | } 101 | 102 | mock.ExpectQuery(`INSERT INTO users\(username\) VALUES \(\@user\)`). 103 | WithRewrittenSQL(`INSERT INTO users\(username\) VALUES \(\$1\)`). 
104 | WithArgs(failQryRW{}) 105 | _, err = mock.Query(context.Background(), "INSERT INTO users(username) VALUES (@user)", "baz") 106 | assert.Error(t, err) 107 | } 108 | 109 | func TestQueryRewriterFail(t *testing.T) { 110 | 111 | t.Parallel() 112 | mock, err := NewConn() 113 | if err != nil { 114 | t.Fatalf("an error '%s' was not expected when opening a stub database connection", err) 115 | } 116 | mock.ExpectExec(`INSERT INTO .+`).WithArgs("foo") 117 | _, err = mock.Exec(context.Background(), "INSERT INTO users(username) VALUES (@user)", failQryRW{}) 118 | assert.Error(t, err) 119 | 120 | } 121 | 122 | func TestByteSliceNamedArgument(t *testing.T) { 123 | t.Parallel() 124 | mock, err := NewConn() 125 | if err != nil { 126 | t.Fatalf("an error '%s' was not expected when opening a stub database connection", err) 127 | } 128 | 129 | username := []byte("user") 130 | mock.ExpectExec(`INSERT INTO users\(username\) VALUES \(\@user\)`). 131 | WithArgs(pgx.NamedArgs{"user": username}). 132 | WithRewrittenSQL(`INSERT INTO users\(username\) VALUES \(\$1\)`). 133 | WillReturnResult(NewResult("INSERT", 1)) 134 | 135 | _, err = mock.Exec(context.Background(), 136 | "INSERT INTO users(username) VALUES (@user)", 137 | pgx.NamedArgs{"user": username}, 138 | ) 139 | if err != nil { 140 | t.Errorf("error '%s' was not expected, while inserting a row", err) 141 | } 142 | 143 | if err := mock.ExpectationsWereMet(); err != nil { 144 | t.Errorf("there were unfulfilled expectations: %s", err) 145 | } 146 | } 147 | 148 | func TestAnyArgument(t *testing.T) { 149 | t.Parallel() 150 | mock, err := NewConn() 151 | if err != nil { 152 | t.Fatalf("an error '%s' was not expected when opening a stub database connection", err) 153 | } 154 | 155 | mock.ExpectExec("INSERT INTO users"). 156 | WithArgs("john", AnyArg()). 
157 | WillReturnResult(NewResult("INSERT", 1)) 158 | 159 | _, err = mock.Exec(context.Background(), "INSERT INTO users(name, created_at) VALUES (?, ?)", "john", time.Now()) 160 | if err != nil { 161 | t.Errorf("error '%s' was not expected, while inserting a row", err) 162 | } 163 | 164 | if err := mock.ExpectationsWereMet(); err != nil { 165 | t.Errorf("there were unfulfilled expectations: %s", err) 166 | } 167 | } 168 | 169 | func TestAnyNamedArgument(t *testing.T) { 170 | t.Parallel() 171 | mock, err := NewConn() 172 | if err != nil { 173 | t.Fatalf("an error '%s' was not expected when opening a stub database connection", err) 174 | } 175 | 176 | mock.ExpectExec("INSERT INTO users"). 177 | WithArgs("john", AnyArg()). 178 | WillReturnResult(NewResult("INSERT", 1)) 179 | 180 | _, err = mock.Exec(context.Background(), "INSERT INTO users(name, created_at) VALUES (@name, @created)", 181 | pgx.NamedArgs{"name": "john", "created": time.Now()}, 182 | ) 183 | if err != nil { 184 | t.Errorf("error '%s' was not expected, while inserting a row", err) 185 | } 186 | 187 | if err := mock.ExpectationsWereMet(); err != nil { 188 | t.Errorf("there were unfulfilled expectations: %s", err) 189 | } 190 | } 191 | 192 | type panicArg struct{} 193 | 194 | var errPanicArg = errors.New("this is a panic argument") 195 | 196 | func (p panicArg) Match(_ any) bool { 197 | // This will always panic when called 198 | panic(errPanicArg) 199 | } 200 | 201 | var _ Argument = panicArg{} 202 | 203 | func TestCloseAfterArgumentPanic(t *testing.T) { 204 | mock, err := NewConn() 205 | if err != nil { 206 | t.Fatalf("an error '%s' was not expected when opening a stub database connection", err) 207 | } 208 | 209 | defer func() { 210 | checkFinishedWithin(t, 1*time.Second, func(ctx context.Context) { 211 | _ = mock.Close(ctx) 212 | }) 213 | }() 214 | 215 | mock.ExpectExec("INSERT INTO users"). 216 | WithArgs(panicArg{}). 
217 | WillReturnResult(NewResult("INSERT", 1)) 218 | 219 | assert.PanicsWithValue(t, errPanicArg, func() { 220 | _, _ = mock.Exec(context.Background(), "INSERT INTO users(name) VALUES (@name)", 221 | pgx.NamedArgs{"name": "john"}, 222 | ) 223 | }) 224 | } 225 | 226 | func checkFinishedWithin(t *testing.T, timeout time.Duration, fun func(ctx context.Context)) { 227 | t.Helper() 228 | closeCtx, cancel := context.WithTimeout(context.Background(), timeout) 229 | defer cancel() 230 | finishedChan := make(chan bool) 231 | go func() { 232 | defer func() { 233 | finishedChan <- true 234 | close(finishedChan) 235 | }() 236 | defer func() { 237 | _ = recover() 238 | }() 239 | fun(closeCtx) 240 | }() 241 | select { 242 | case <-finishedChan: 243 | return 244 | case <-closeCtx.Done(): 245 | t.Error("timed out waiting for function to finish") 246 | } 247 | } 248 | -------------------------------------------------------------------------------- /batch.go: -------------------------------------------------------------------------------- 1 | package pgxmock 2 | 3 | import ( 4 | "context" 5 | "errors" 6 | 7 | pgx "github.com/jackc/pgx/v5" 8 | pgconn "github.com/jackc/pgx/v5/pgconn" 9 | ) 10 | 11 | type batchResults struct { 12 | mock *pgxmock 13 | batch *pgx.Batch 14 | expectedBatch *ExpectedBatch 15 | qqIdx int 16 | err error 17 | } 18 | 19 | func (br *batchResults) nextQueryAndArgs() (query string, args []any, err error) { 20 | if br.err != nil { 21 | return "", nil, br.err 22 | } 23 | if br.batch == nil { 24 | return "", nil, errors.New("no batch expectations set") 25 | } 26 | if br.qqIdx >= len(br.batch.QueuedQueries) { 27 | return "", nil, errors.New("no more queries in batch") 28 | } 29 | bi := br.batch.QueuedQueries[br.qqIdx] 30 | query = bi.SQL 31 | args = bi.Arguments 32 | br.qqIdx++ 33 | return 34 | } 35 | 36 | func (br *batchResults) Exec() (pgconn.CommandTag, error) { 37 | query, arguments, err := br.nextQueryAndArgs() 38 | if err != nil { 39 | return 
pgconn.NewCommandTag(""), err 40 | } 41 | return br.mock.Exec(context.Background(), query, arguments...) 42 | } 43 | 44 | func (br *batchResults) Query() (pgx.Rows, error) { 45 | query, arguments, err := br.nextQueryAndArgs() 46 | if err != nil { 47 | return nil, err 48 | } 49 | return br.mock.Query(context.Background(), query, arguments...) 50 | } 51 | 52 | func (br *batchResults) QueryRow() pgx.Row { 53 | query, arguments, err := br.nextQueryAndArgs() 54 | if err != nil { 55 | return errRow{err: err} 56 | } 57 | return br.mock.QueryRow(context.Background(), query, arguments...) 58 | } 59 | 60 | func (br *batchResults) Close() error { 61 | if br.err != nil { 62 | return br.err 63 | } 64 | // Read and run fn for all remaining items 65 | for br.err == nil && br.expectedBatch != nil && !br.expectedBatch.closed && br.qqIdx < len(br.batch.QueuedQueries) { 66 | if qq := br.batch.QueuedQueries[br.qqIdx]; qq != nil { 67 | br.err = errors.Join(br.err, br.callQuedQueryFn(qq)) 68 | } 69 | } 70 | br.expectedBatch.closed = true 71 | return br.err 72 | } 73 | 74 | func (br *batchResults) callQuedQueryFn(qq *pgx.QueuedQuery) error { 75 | if qq.Fn != nil { 76 | return qq.Fn(br) 77 | } 78 | _, err := br.Exec() 79 | return err 80 | } 81 | -------------------------------------------------------------------------------- /batch_test.go: -------------------------------------------------------------------------------- 1 | package pgxmock 2 | 3 | import ( 4 | "errors" 5 | "testing" 6 | 7 | pgx "github.com/jackc/pgx/v5" 8 | pgconn "github.com/jackc/pgx/v5/pgconn" 9 | "github.com/stretchr/testify/assert" 10 | ) 11 | 12 | func TestBatch(t *testing.T) { 13 | t.Parallel() 14 | mock, _ := NewConn() 15 | a := assert.New(t) 16 | 17 | // define our expectations 18 | eb := mock.ExpectBatch() 19 | eb.ExpectQuery("select").WillReturnRows(NewRows([]string{"sum"}).AddRow(2)) 20 | eb.ExpectExec("update").WithArgs(true, 1).WillReturnResult(NewResult("UPDATE", 1)) 21 | 22 | // run the test 23 | batch := 
&pgx.Batch{} 24 | batch.Queue("select 1 + 1").QueryRow(func(row pgx.Row) error { 25 | var n int 26 | return row.Scan(&n) 27 | }) 28 | batch.Queue("update users set active = $1 where id = $2", true, 1).Exec(func(ct pgconn.CommandTag) (err error) { 29 | if ct.RowsAffected() != 1 { 30 | err = errors.New("expected 1 row to be affected") 31 | } 32 | return 33 | }) 34 | 35 | err := mock.SendBatch(ctx, batch).Close() 36 | a.NoError(err) 37 | a.NoError(mock.ExpectationsWereMet()) 38 | } 39 | 40 | func TestExplicitBatch(t *testing.T) { 41 | t.Parallel() 42 | mock, _ := NewConn() 43 | a := assert.New(t) 44 | 45 | // define our expectations 46 | eb := mock.ExpectBatch() 47 | eb.ExpectQuery("select").WillReturnRows(NewRows([]string{"sum"}).AddRow(2)) 48 | eb.ExpectQuery("select").WillReturnRows(NewRows([]string{"answer"}).AddRow(42)) 49 | eb.ExpectExec("update").WithArgs(true, 1).WillReturnResult(NewResult("UPDATE", 1)) 50 | 51 | // run the test 52 | batch := &pgx.Batch{} 53 | batch.Queue("select 1 + 1") 54 | batch.Queue("select 42") 55 | batch.Queue("update users set active = $1 where id = $2", true, 1) 56 | 57 | var sum int 58 | br := mock.SendBatch(ctx, batch) 59 | err := br.QueryRow().Scan(&sum) 60 | a.NoError(err) 61 | a.Equal(2, sum) 62 | 63 | var answer int 64 | rows, err := br.Query() 65 | a.NoError(err) 66 | rows.Next() 67 | err = rows.Scan(&answer) 68 | a.NoError(err) 69 | a.Equal(42, answer) 70 | 71 | ct, err := br.Exec() 72 | a.NoError(err) 73 | a.True(ct.Update()) 74 | a.EqualValues(1, ct.RowsAffected()) 75 | 76 | // no more queries 77 | _, err = br.Exec() 78 | a.Error(err) 79 | _, err = br.Query() 80 | a.Error(err) 81 | err = br.QueryRow().Scan(&sum) 82 | a.Error(err) 83 | 84 | a.NoError(mock.ExpectationsWereMet()) 85 | } 86 | 87 | func processBatch(db PgxPoolIface) error { 88 | batch := &pgx.Batch{} 89 | // Random order 90 | batch.Queue("SELECT id FROM normalized_queries WHERE query = $1", "some query") 91 | batch.Queue("INSERT INTO normalized_queries (query) 
VALUES ($1) RETURNING id", "some query") 92 | 93 | results := db.SendBatch(ctx, batch) 94 | defer results.Close() 95 | 96 | for i := 0; i < batch.Len(); i++ { 97 | var id int 98 | err := results.QueryRow().Scan(&id) 99 | if err != nil { 100 | return err 101 | } 102 | } 103 | 104 | return nil 105 | } 106 | 107 | func TestUnorderedBatchExpectations(t *testing.T) { 108 | t.Parallel() 109 | a := assert.New(t) 110 | 111 | mock, err := NewPool() 112 | a.NoError(err) 113 | defer mock.Close() 114 | 115 | mock.MatchExpectationsInOrder(false) 116 | 117 | expectedBatch := mock.ExpectBatch() 118 | expectedBatch.ExpectQuery("INSERT INTO").WithArgs("some query"). 119 | WillReturnRows(NewRows([]string{"id"}).AddRow(10)) 120 | expectedBatch.ExpectQuery("SELECT id").WithArgs("some query"). 121 | WillReturnRows(NewRows([]string{"id"}).AddRow(20)) 122 | 123 | err = processBatch(mock) 124 | a.NoError(err) 125 | a.NoError(mock.ExpectationsWereMet()) 126 | } 127 | -------------------------------------------------------------------------------- /driver.go: -------------------------------------------------------------------------------- 1 | package pgxmock 2 | 3 | import ( 4 | "context" 5 | "errors" 6 | 7 | pgx "github.com/jackc/pgx/v5" 8 | "github.com/jackc/pgx/v5/pgxpool" 9 | ) 10 | 11 | type pgxmockConn struct { 12 | pgxmock 13 | } 14 | 15 | // NewConn creates PgxConnIface database connection and a mock to manage expectations. 16 | // Accepts options, like QueryMatcherOption, to match SQL query strings in more sophisticated ways. 17 | func NewConn(options ...func(*pgxmock) error) (PgxConnIface, error) { 18 | smock := &pgxmockConn{} 19 | smock.ordered = true 20 | return smock, smock.open(options) 21 | } 22 | 23 | func (c *pgxmockConn) Config() *pgx.ConnConfig { 24 | return &pgx.ConnConfig{} 25 | } 26 | 27 | type pgxmockPool struct { 28 | pgxmock 29 | } 30 | 31 | // NewPool creates PgxPoolIface pool of database connections and a mock to manage expectations. 
32 | // Accepts options, like QueryMatcherOption, to match SQL query strings in more sophisticated ways. 33 | func NewPool(options ...func(*pgxmock) error) (PgxPoolIface, error) { 34 | smock := &pgxmockPool{} 35 | smock.ordered = true 36 | return smock, smock.open(options) 37 | } 38 | 39 | func (p *pgxmockPool) Close() { 40 | p.pgxmock.Close(context.Background()) 41 | } 42 | 43 | func (p *pgxmockPool) Acquire(context.Context) (*pgxpool.Conn, error) { 44 | return nil, errors.New("pgpool.Acquire() method is not implemented") 45 | } 46 | 47 | func (p *pgxmockPool) Config() *pgxpool.Config { 48 | return &pgxpool.Config{ConnConfig: &pgx.ConnConfig{}} 49 | } 50 | 51 | // AsConn is similar to Acquire but returns proper mocking interface 52 | func (p *pgxmockPool) AsConn() PgxConnIface { 53 | return &pgxmockConn{pgxmock: p.pgxmock} 54 | } 55 | 56 | func (p *pgxmockPool) Stat() *pgxpool.Stat { 57 | return &pgxpool.Stat{} 58 | } 59 | -------------------------------------------------------------------------------- /driver_test.go: -------------------------------------------------------------------------------- 1 | package pgxmock 2 | 3 | import ( 4 | "context" 5 | "testing" 6 | ) 7 | 8 | func TestTwoOpenConnectionsOnTheSameDSN(t *testing.T) { 9 | mock, err := NewConn() 10 | if err != nil { 11 | t.Fatalf("expected no error, but got: %s", err) 12 | } 13 | mock2, err := NewConn() 14 | if err != nil { 15 | t.Fatalf("expected no error, but got: %s", err) 16 | } 17 | if mock == mock2 { 18 | t.Errorf("expected not the same mock instance, but it is the same") 19 | } 20 | mock.Close(context.Background()) 21 | mock2.Close(context.Background()) 22 | } 23 | 24 | func TestPools(t *testing.T) { 25 | mock, err := NewPool() 26 | if err != nil { 27 | t.Fatalf("expected no error, but got: %s", err) 28 | } 29 | mock2, err := NewPool() 30 | if err != nil { 31 | t.Fatalf("expected no error, but got: %s", err) 32 | } 33 | if mock == mock2 { 34 | t.Errorf("expected not the same mock instance, but 
it is the same") 35 | } 36 | conn := mock.AsConn() 37 | if conn == nil { 38 | t.Error("expected connection strruct, but got nil") 39 | } 40 | mock.Close() 41 | mock2.Close() 42 | } 43 | 44 | func TestAcquire(t *testing.T) { 45 | mock, err := NewPool() 46 | if err != nil { 47 | t.Fatalf("expected no error, but got: %s", err) 48 | } 49 | _, err = mock.Acquire(context.Background()) 50 | if err == nil { 51 | t.Error("expected error, but got nil") 52 | } 53 | } 54 | 55 | func TestPoolStat(t *testing.T) { 56 | mock, err := NewPool() 57 | if err != nil { 58 | t.Fatalf("expected no error, but got: %s", err) 59 | } 60 | s := mock.Stat() 61 | if s == nil { 62 | t.Error("expected stat object, but got nil") 63 | } 64 | } 65 | 66 | func TestPoolConfig(t *testing.T) { 67 | mock, err := NewPool() 68 | if err != nil { 69 | t.Fatalf("expected no error, but got: %s", err) 70 | } 71 | c := mock.Config() 72 | if c == nil { 73 | t.Fatal("expected config object, but got nil") 74 | } 75 | if c.ConnConfig == nil { 76 | t.Fatal("expected conn config object, but got nil") 77 | } 78 | } 79 | 80 | func TestConnConfig(t *testing.T) { 81 | mock, err := NewConn() 82 | if err != nil { 83 | t.Fatalf("expected no error, but got: %s", err) 84 | } 85 | c := mock.Config() 86 | if c == nil { 87 | t.Fatal("expected config object, but got nil") 88 | } 89 | } 90 | -------------------------------------------------------------------------------- /examples/.vscode/launch.json: -------------------------------------------------------------------------------- 1 | { 2 | // Use IntelliSense to learn about possible attributes. 3 | // Hover to view descriptions of existing attributes. 
4 | // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387 5 | "version": "0.2.0", 6 | "configurations": [] 7 | } -------------------------------------------------------------------------------- /examples/basic/basic.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "context" 5 | 6 | pgx "github.com/jackc/pgx/v5" 7 | pgxpool "github.com/jackc/pgx/v5/pgxpool" 8 | ) 9 | 10 | type PgxIface interface { 11 | Begin(context.Context) (pgx.Tx, error) 12 | Close() 13 | } 14 | 15 | func recordStats(db PgxIface, userID, productID int) (err error) { 16 | tx, err := db.Begin(context.Background()) 17 | if err != nil { 18 | return 19 | } 20 | defer func() { 21 | switch err { 22 | case nil: 23 | err = tx.Commit(context.Background()) 24 | default: 25 | _ = tx.Rollback(context.Background()) 26 | } 27 | }() 28 | sql := "UPDATE products SET views = views + 1" 29 | if _, err = tx.Exec(context.Background(), sql); err != nil { 30 | return 31 | } 32 | sql = "INSERT INTO product_viewers (user_id, product_id) VALUES ($1, $2)" 33 | if _, err = tx.Exec(context.Background(), sql, userID, productID); err != nil { 34 | return 35 | } 36 | return 37 | } 38 | 39 | func main() { 40 | // @NOTE: the real connection is not required for tests 41 | db, err := pgxpool.New(context.Background(), "postgres://rolname@hostname/dbname") 42 | if err != nil { 43 | panic(err) 44 | } 45 | defer db.Close() 46 | 47 | if err = recordStats(db, 1 /*some user id*/, 5 /*some product id*/); err != nil { 48 | panic(err) 49 | } 50 | } 51 | -------------------------------------------------------------------------------- /examples/basic/basic_test.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "fmt" 5 | "testing" 6 | 7 | "github.com/pashagolub/pgxmock/v4" 8 | ) 9 | 10 | // a successful case 11 | func TestShouldUpdateStats(t *testing.T) { 12 | mock, err := 
pgxmock.NewPool() 13 | if err != nil { 14 | t.Fatal(err) 15 | } 16 | defer mock.Close() 17 | 18 | mock.ExpectBegin() 19 | mock.ExpectExec("UPDATE products"). 20 | WillReturnResult(pgxmock.NewResult("UPDATE", 1)) 21 | mock.ExpectExec("INSERT INTO product_viewers"). 22 | WithArgs(2, 3). 23 | WillReturnResult(pgxmock.NewResult("INSERT", 1)) 24 | mock.ExpectCommit() 25 | 26 | // now we execute our method 27 | if err = recordStats(mock, 2, 3); err != nil { 28 | t.Errorf("error was not expected while updating: %s", err) 29 | } 30 | 31 | // we make sure that all expectations were met 32 | if err := mock.ExpectationsWereMet(); err != nil { 33 | t.Errorf("there were unfulfilled expectations: %s", err) 34 | } 35 | } 36 | 37 | // a failing test case 38 | func TestShouldRollbackStatUpdatesOnFailure(t *testing.T) { 39 | mock, err := pgxmock.NewPool() 40 | if err != nil { 41 | t.Fatal(err) 42 | } 43 | defer mock.Close() 44 | 45 | mock.ExpectBegin() 46 | mock.ExpectExec("UPDATE products"). 47 | WillReturnResult(pgxmock.NewResult("UPDATE", 1)) 48 | mock.ExpectExec("INSERT INTO product_viewers"). 49 | WithArgs(2, 3). 
50 | WillReturnError(fmt.Errorf("some error")) 51 | mock.ExpectRollback() 52 | 53 | // now we execute our method 54 | if err = recordStats(mock, 2, 3); err == nil { 55 | t.Errorf("was expecting an error, but there was none") 56 | } 57 | 58 | // we make sure that all expectations were met 59 | if err := mock.ExpectationsWereMet(); err != nil { 60 | t.Errorf("there were unfulfilled expectations: %s", err) 61 | } 62 | } 63 | -------------------------------------------------------------------------------- /examples/blog/blog.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "context" 5 | "encoding/json" 6 | "net/http" 7 | 8 | pgx "github.com/jackc/pgx/v5" 9 | pgconn "github.com/jackc/pgx/v5/pgconn" 10 | ) 11 | 12 | type PgxIface interface { 13 | Begin(context.Context) (pgx.Tx, error) 14 | Exec(context.Context, string, ...interface{}) (pgconn.CommandTag, error) 15 | QueryRow(context.Context, string, ...interface{}) pgx.Row 16 | Query(context.Context, string, ...interface{}) (pgx.Rows, error) 17 | Ping(context.Context) error 18 | Prepare(context.Context, string, string) (*pgconn.StatementDescription, error) 19 | Close(context.Context) error 20 | } 21 | 22 | type api struct { 23 | db PgxIface 24 | } 25 | 26 | type post struct { 27 | ID int 28 | Title string 29 | Body string 30 | } 31 | 32 | func (a *api) posts(w http.ResponseWriter, _ *http.Request) { 33 | rows, err := a.db.Query(context.Background(), "SELECT id, title, body FROM posts") 34 | if err != nil { 35 | a.fail(w, "failed to fetch posts: "+err.Error(), 500) 36 | return 37 | } 38 | defer rows.Close() 39 | 40 | var posts []*post 41 | for rows.Next() { 42 | p := &post{} 43 | if err := rows.Scan(&p.ID, &p.Title, &p.Body); err != nil { 44 | a.fail(w, "failed to scan post: "+err.Error(), 500) 45 | return 46 | } 47 | posts = append(posts, p) 48 | } 49 | if rows.Err() != nil { 50 | a.fail(w, "failed to read all posts: "+rows.Err().Error(), 500) 51 | 
return 52 | } 53 | 54 | data := struct { 55 | Posts []*post 56 | }{posts} 57 | 58 | a.ok(w, data) 59 | } 60 | 61 | func main() { 62 | // @NOTE: the real connection is not required for tests 63 | db, err := pgx.Connect(context.Background(), "postgres://postgres@localhost/blog") 64 | if err != nil { 65 | panic(err) 66 | } 67 | app := &api{db: db} 68 | http.HandleFunc("/posts", app.posts) 69 | _ = http.ListenAndServe(":8080", nil) 70 | } 71 | 72 | func (a *api) fail(w http.ResponseWriter, msg string, status int) { 73 | w.Header().Set("Content-Type", "application/json") 74 | 75 | data := struct { 76 | Error string 77 | }{Error: msg} 78 | 79 | resp, _ := json.Marshal(data) 80 | w.WriteHeader(status) 81 | _, _ = w.Write(resp) 82 | } 83 | 84 | func (a *api) ok(w http.ResponseWriter, data interface{}) { 85 | w.Header().Set("Content-Type", "application/json") 86 | 87 | resp, err := json.Marshal(data) 88 | if err != nil { 89 | w.WriteHeader(http.StatusInternalServerError) 90 | a.fail(w, "oops something evil has happened", 500) 91 | return 92 | } 93 | _, _ = w.Write(resp) 94 | } 95 | -------------------------------------------------------------------------------- /examples/blog/blog_test.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "bytes" 5 | "context" 6 | "encoding/json" 7 | "fmt" 8 | "net/http" 9 | "net/http/httptest" 10 | "testing" 11 | 12 | "github.com/pashagolub/pgxmock/v4" 13 | ) 14 | 15 | func (a *api) assertJSON(actual []byte, data interface{}, t *testing.T) { 16 | expected, err := json.Marshal(data) 17 | if err != nil { 18 | t.Fatalf("an error '%s' was not expected when marshaling expected json data", err) 19 | } 20 | 21 | if !bytes.Equal(expected, actual) { 22 | t.Errorf("the expected json: %s is different from actual %s", expected, actual) 23 | } 24 | } 25 | 26 | func TestShouldGetPosts(t *testing.T) { 27 | mock, err := pgxmock.NewConn() 28 | if err != nil { 29 | t.Fatalf("an error '%s' 
was not expected when opening a stub database connection", err) 30 | } 31 | defer mock.Close(context.Background()) 32 | 33 | // create app with mocked db, request and response to test 34 | app := &api{mock} 35 | req, err := http.NewRequest("GET", "http://localhost/posts", nil) 36 | if err != nil { 37 | t.Fatalf("an error '%s' was not expected while creating request", err) 38 | } 39 | w := httptest.NewRecorder() 40 | 41 | // before we actually execute our api function, we need to expect required DB actions 42 | rows := mock.NewRows([]string{"id", "title", "body"}). 43 | AddRow(1, "post 1", "hello"). 44 | AddRow(2, "post 2", "world") 45 | 46 | mock.ExpectQuery("^SELECT (.+) FROM posts$").WillReturnRows(rows) 47 | 48 | // now we execute our request 49 | app.posts(w, req) 50 | 51 | if w.Code != 200 { 52 | t.Fatalf("expected status code to be 200, but got: %d\nBody: %v", w.Code, w.Body) 53 | } 54 | 55 | data := struct { 56 | Posts []*post 57 | }{Posts: []*post{ 58 | {ID: 1, Title: "post 1", Body: "hello"}, 59 | {ID: 2, Title: "post 2", Body: "world"}, 60 | }} 61 | app.assertJSON(w.Body.Bytes(), data, t) 62 | 63 | // we make sure that all expectations were met 64 | if err := mock.ExpectationsWereMet(); err != nil { 65 | t.Errorf("there were unfulfilled expectations: %s", err) 66 | } 67 | } 68 | 69 | func TestShouldRespondWithErrorOnFailure(t *testing.T) { 70 | mock, err := pgxmock.NewConn() 71 | if err != nil { 72 | t.Fatalf("an error '%s' was not expected when opening a stub database connection", err) 73 | } 74 | defer mock.Close(context.Background()) 75 | 76 | // create app with mocked db, request and response to test 77 | app := &api{mock} 78 | req, err := http.NewRequest("GET", "http://localhost/posts", nil) 79 | if err != nil { 80 | t.Fatalf("an error '%s' was not expected while creating request", err) 81 | } 82 | w := httptest.NewRecorder() 83 | 84 | // before we actually execute our api function, we need to expect required DB actions 85 | mock.ExpectQuery("^SELECT 
(.+) FROM posts$").WillReturnError(fmt.Errorf("some error")) 86 | 87 | // now we execute our request 88 | app.posts(w, req) 89 | 90 | if w.Code != 500 { 91 | t.Fatalf("expected status code to be 500, but got: %d", w.Code) 92 | } 93 | 94 | data := struct { 95 | Error string 96 | }{"failed to fetch posts: some error"} 97 | app.assertJSON(w.Body.Bytes(), data, t) 98 | 99 | // we make sure that all expectations were met 100 | if err := mock.ExpectationsWereMet(); err != nil { 101 | t.Errorf("there were unfulfilled expectations: %s", err) 102 | } 103 | } 104 | 105 | func TestNoPostsReturned(t *testing.T) { 106 | mock, err := pgxmock.NewConn() 107 | if err != nil { 108 | t.Fatalf("an error '%s' was not expected when opening a stub database connection", err) 109 | } 110 | defer mock.Close(context.Background()) 111 | 112 | // create app with mocked db, request and response to test 113 | app := &api{mock} 114 | req, err := http.NewRequest("GET", "http://localhost/posts", nil) 115 | if err != nil { 116 | t.Fatalf("an error '%s' was not expected while creating request", err) 117 | } 118 | w := httptest.NewRecorder() 119 | 120 | mock.ExpectQuery("^SELECT (.+) FROM posts$").WillReturnRows(mock.NewRows([]string{"id", "title", "body"})) 121 | // now we execute our request 122 | app.posts(w, req) 123 | if w.Code != 200 { 124 | t.Fatalf("expected status code to be 200, but got: %d\nBody: %v", w.Code, w.Body) 125 | } 126 | 127 | // we make sure that all expectations were met 128 | if err := mock.ExpectationsWereMet(); err != nil { 129 | t.Errorf("there were unfulfilled expectations: %s", err) 130 | } 131 | } 132 | -------------------------------------------------------------------------------- /examples/doc.go: -------------------------------------------------------------------------------- 1 | package examples 2 | -------------------------------------------------------------------------------- /expectations.go: 
-------------------------------------------------------------------------------- 1 | package pgxmock 2 | 3 | import ( 4 | "context" 5 | "errors" 6 | "fmt" 7 | "reflect" 8 | "strings" 9 | "sync" 10 | "time" 11 | 12 | pgx "github.com/jackc/pgx/v5" 13 | pgconn "github.com/jackc/pgx/v5/pgconn" 14 | ) 15 | 16 | // an expectation interface 17 | type expectation interface { 18 | error() error 19 | required() bool 20 | fulfilled() bool 21 | fulfill() 22 | sync.Locker 23 | fmt.Stringer 24 | } 25 | 26 | // CallModifier interface represents common interface for all expectations supported 27 | type CallModifier interface { 28 | // Maybe allows the expected method call to be optional. 29 | // Not calling an optional method will not cause an error while asserting expectations 30 | Maybe() CallModifier 31 | // Times indicates that that the expected method should only fire the indicated number of times. 32 | // Zero value is ignored and means the same as one. 33 | Times(n uint) CallModifier 34 | // WillDelayFor allows to specify duration for which it will delay 35 | // result. 
May be used together with Context 36 | WillDelayFor(duration time.Duration) CallModifier 37 | // WillReturnError allows to set an error for the expected method 38 | WillReturnError(err error) 39 | // WillPanic allows to force the expected method to panic 40 | WillPanic(v any) 41 | } 42 | 43 | // common expectation struct 44 | // satisfies the expectation interface 45 | type commonExpectation struct { 46 | sync.Mutex 47 | triggered uint // how many times method was called 48 | err error // should method return error 49 | optional bool // can method be skipped 50 | panicArgument any // panic value to return for recovery 51 | plannedDelay time.Duration // should method delay before return 52 | plannedCalls uint // how many sequentional calls should be made 53 | } 54 | 55 | func (e *commonExpectation) error() error { 56 | return e.err 57 | } 58 | 59 | func (e *commonExpectation) fulfill() { 60 | e.triggered++ 61 | } 62 | 63 | func (e *commonExpectation) fulfilled() bool { 64 | return e.triggered >= max(e.plannedCalls, 1) 65 | } 66 | 67 | func (e *commonExpectation) required() bool { 68 | return !e.optional 69 | } 70 | 71 | func (e *commonExpectation) waitForDelay(ctx context.Context) (err error) { 72 | select { 73 | case <-time.After(e.plannedDelay): 74 | err = e.error() 75 | case <-ctx.Done(): 76 | err = ctx.Err() 77 | } 78 | if e.panicArgument != nil { 79 | panic(e.panicArgument) 80 | } 81 | return err 82 | } 83 | 84 | func (e *commonExpectation) Maybe() CallModifier { 85 | e.optional = true 86 | return e 87 | } 88 | 89 | func (e *commonExpectation) Times(n uint) CallModifier { 90 | e.plannedCalls = n 91 | return e 92 | } 93 | 94 | func (e *commonExpectation) WillDelayFor(duration time.Duration) CallModifier { 95 | e.plannedDelay = duration 96 | return e 97 | } 98 | 99 | func (e *commonExpectation) WillReturnError(err error) { 100 | e.err = err 101 | } 102 | 103 | var errPanic = errors.New("pgxmock panic") 104 | 105 | func (e *commonExpectation) WillPanic(v any) { 
106 | e.err = errPanic 107 | e.panicArgument = v 108 | } 109 | 110 | // String returns string representation 111 | func (e *commonExpectation) String() string { 112 | w := new(strings.Builder) 113 | if e.err != nil { 114 | if e.err != errPanic { 115 | fmt.Fprintf(w, "\t- returns error: %v\n", e.err) 116 | } else { 117 | fmt.Fprintf(w, "\t- panics with: %v\n", e.panicArgument) 118 | } 119 | } 120 | if e.plannedDelay > 0 { 121 | fmt.Fprintf(w, "\t- delayed execution for: %v\n", e.plannedDelay) 122 | } 123 | if e.optional { 124 | fmt.Fprint(w, "\t- execution is optional\n") 125 | } 126 | if e.plannedCalls > 0 { 127 | fmt.Fprintf(w, "\t- execution calls awaited: %d\n", e.plannedCalls) 128 | } 129 | return w.String() 130 | } 131 | 132 | // queryBasedExpectation is a base class that adds a query matching logic 133 | type queryBasedExpectation struct { 134 | expectSQL string 135 | expectRewrittenSQL string 136 | args []interface{} 137 | } 138 | 139 | func (e *queryBasedExpectation) argsMatches(sql string, args []interface{}) (rewrittenSQL string, err error) { 140 | eargs := e.args 141 | // check for any QueryRewriter arguments: only supported as the first argument 142 | if len(args) == 1 { 143 | if qrw, ok := args[0].(pgx.QueryRewriter); ok { 144 | // note: pgx.Conn is not currently used by the query rewriter 145 | if rewrittenSQL, args, err = qrw.RewriteQuery(context.Background(), nil, sql, args); err != nil { 146 | return rewrittenSQL, fmt.Errorf("error rewriting query: %w", err) 147 | } 148 | } 149 | // also do rewriting on the expected args if a QueryRewriter is present 150 | if len(eargs) == 1 { 151 | if qrw, ok := eargs[0].(pgx.QueryRewriter); ok { 152 | if _, eargs, err = qrw.RewriteQuery(context.Background(), nil, sql, eargs); err != nil { 153 | return "", fmt.Errorf("error rewriting query expectation: %w", err) 154 | } 155 | } 156 | } 157 | } 158 | if len(args) != len(eargs) { 159 | return rewrittenSQL, fmt.Errorf("expected %d, but got %d arguments", len(eargs), 
len(args)) 160 | } 161 | for k, v := range args { 162 | // custom argument matcher 163 | if matcher, ok := eargs[k].(Argument); ok { 164 | if !matcher.Match(v) { 165 | return rewrittenSQL, fmt.Errorf("matcher %T could not match %d argument %T - %+v", matcher, k, args[k], args[k]) 166 | } 167 | continue 168 | } 169 | if darg := eargs[k]; !reflect.DeepEqual(darg, v) { 170 | return rewrittenSQL, fmt.Errorf("argument %d expected [%T - %+v] does not match actual [%T - %+v]", k, darg, darg, v, v) 171 | } 172 | } 173 | return 174 | } 175 | 176 | // ExpectedClose is used to manage pgx.Close expectation 177 | // returned by pgxmock.ExpectClose 178 | type ExpectedClose struct { 179 | commonExpectation 180 | } 181 | 182 | // String returns string representation 183 | func (e *ExpectedClose) String() string { 184 | return "ExpectedClose => expecting call to Close()\n" + e.commonExpectation.String() 185 | } 186 | 187 | // ExpectedBegin is used to manage *pgx.Begin expectation 188 | // returned by pgxmock.ExpectBegin. 189 | type ExpectedBegin struct { 190 | commonExpectation 191 | opts pgx.TxOptions 192 | } 193 | 194 | // String returns string representation 195 | func (e *ExpectedBegin) String() string { 196 | msg := "ExpectedBegin => expecting call to Begin() or to BeginTx()\n" 197 | if e.opts != (pgx.TxOptions{}) { 198 | msg += fmt.Sprintf("\t- transaction options awaited: %+v\n", e.opts) 199 | } 200 | return msg + e.commonExpectation.String() 201 | } 202 | 203 | // ExpectedCommit is used to manage pgx.Tx.Commit expectation 204 | // returned by pgxmock.ExpectCommit. 205 | type ExpectedCommit struct { 206 | commonExpectation 207 | } 208 | 209 | // String returns string representation 210 | func (e *ExpectedCommit) String() string { 211 | return "ExpectedCommit => expecting call to Tx.Commit()\n" + e.commonExpectation.String() 212 | } 213 | 214 | // ExpectedExec is used to manage pgx.Exec, pgx.Tx.Exec or pgx.Stmt.Exec expectations. 215 | // Returned by pgxmock.ExpectExec. 
216 | type ExpectedExec struct { 217 | commonExpectation 218 | queryBasedExpectation 219 | result pgconn.CommandTag 220 | } 221 | 222 | // WithArgs will match given expected args to actual database exec operation arguments. 223 | // if at least one argument does not match, it will return an error. For specific 224 | // arguments an pgxmock.Argument interface can be used to match an argument. 225 | func (e *ExpectedExec) WithArgs(args ...interface{}) *ExpectedExec { 226 | e.args = args 227 | return e 228 | } 229 | 230 | // WithRewrittenSQL will match given expected expression to a rewritten SQL statement by 231 | // an pgx.QueryRewriter argument 232 | func (e *ExpectedExec) WithRewrittenSQL(sql string) *ExpectedExec { 233 | e.expectRewrittenSQL = sql 234 | return e 235 | } 236 | 237 | // String returns string representation 238 | func (e *ExpectedExec) String() string { 239 | msg := "ExpectedExec => expecting call to Exec():\n" 240 | msg += fmt.Sprintf("\t- matches sql: '%s'\n", e.expectSQL) 241 | 242 | if len(e.args) == 0 { 243 | msg += "\t- is without arguments\n" 244 | } else { 245 | msg += "\t- is with arguments:\n" 246 | for i, arg := range e.args { 247 | msg += fmt.Sprintf("\t\t%d - %+v\n", i, arg) 248 | } 249 | } 250 | if e.result.String() != "" { 251 | msg += fmt.Sprintf("\t- returns result: %s\n", e.result) 252 | } 253 | 254 | return msg + e.commonExpectation.String() 255 | } 256 | 257 | // WillReturnResult arranges for an expected Exec() to return a particular 258 | // result, there is pgxmock.NewResult(op string, rowsAffected int64) method 259 | // to build a corresponding result. 260 | func (e *ExpectedExec) WillReturnResult(result pgconn.CommandTag) *ExpectedExec { 261 | e.result = result 262 | return e 263 | } 264 | 265 | // ExpectedBatch is used to manage pgx.Batch expectations. 266 | // Returned by pgxmock.ExpectBatch. 
267 | type ExpectedBatch struct { 268 | commonExpectation 269 | mock *pgxmock 270 | expectedQueries []*queryBasedExpectation 271 | closed bool 272 | mustBeClosed bool 273 | } 274 | 275 | // ExpectExec allows to expect Queue().Exec() on this batch. 276 | func (e *ExpectedBatch) ExpectExec(query string) *ExpectedExec { 277 | ee := &ExpectedExec{} 278 | ee.expectSQL = query 279 | e.expectedQueries = append(e.expectedQueries, &ee.queryBasedExpectation) 280 | e.mock.expectations = append(e.mock.expectations, ee) 281 | return ee 282 | } 283 | 284 | // ExpectQuery allows to expect Queue().Query() or Queue().QueryRow() on this batch. 285 | func (e *ExpectedBatch) ExpectQuery(query string) *ExpectedQuery { 286 | eq := &ExpectedQuery{} 287 | eq.expectSQL = query 288 | e.expectedQueries = append(e.expectedQueries, &eq.queryBasedExpectation) 289 | e.mock.expectations = append(e.mock.expectations, eq) 290 | return eq 291 | } 292 | 293 | // String returns string representation 294 | func (e *ExpectedBatch) String() string { 295 | msg := "ExpectedBatch => expecting call to SendBatch()\n" 296 | if e.mustBeClosed { 297 | msg += "\t- batch must be closed\n" 298 | } 299 | return msg + e.commonExpectation.String() 300 | } 301 | 302 | // ExpectedPrepare is used to manage pgx.Prepare or pgx.Tx.Prepare expectations. 303 | // Returned by pgxmock.ExpectPrepare. 304 | type ExpectedPrepare struct { 305 | commonExpectation 306 | expectStmtName string 307 | expectSQL string 308 | } 309 | 310 | // String returns string representation 311 | func (e *ExpectedPrepare) String() string { 312 | msg := "ExpectedPrepare => expecting call to Prepare():\n" 313 | msg += fmt.Sprintf("\t- matches statement name: '%s'\n", e.expectStmtName) 314 | msg += fmt.Sprintf("\t- matches sql: '%s'\n", e.expectSQL) 315 | return msg + e.commonExpectation.String() 316 | } 317 | 318 | // ExpectedDeallocate is used to manage pgx.Deallocate and pgx.DeallocateAll expectations. 
319 | // Returned by pgxmock.ExpectDeallocate(string) and pgxmock.ExpectDeallocateAll(). 320 | type ExpectedDeallocate struct { 321 | commonExpectation 322 | expectStmtName string 323 | expectAll bool 324 | } 325 | 326 | // String returns string representation 327 | func (e *ExpectedDeallocate) String() string { 328 | msg := "ExpectedDeallocate => expecting call to Deallocate():\n" 329 | if e.expectAll { 330 | msg += "\t- matches all statements\n" 331 | } else { 332 | msg += fmt.Sprintf("\t- matches statement name: '%s'\n", e.expectStmtName) 333 | } 334 | return msg + e.commonExpectation.String() 335 | } 336 | 337 | // ExpectedPing is used to manage Ping() expectations 338 | type ExpectedPing struct { 339 | commonExpectation 340 | } 341 | 342 | // String returns string representation 343 | func (e *ExpectedPing) String() string { 344 | msg := "ExpectedPing => expecting call to Ping()\n" 345 | return msg + e.commonExpectation.String() 346 | } 347 | 348 | // ExpectedQuery is used to manage *pgx.Conn.Query, *pgx.Conn.QueryRow, *pgx.Tx.Query, 349 | // *pgx.Tx.QueryRow, *pgx.Stmt.Query or *pgx.Stmt.QueryRow expectations 350 | type ExpectedQuery struct { 351 | commonExpectation 352 | queryBasedExpectation 353 | rows pgx.Rows 354 | rowsMustBeClosed bool 355 | rowsWereClosed bool 356 | } 357 | 358 | // WithArgs will match given expected args to actual database query arguments. 359 | // if at least one argument does not match, it will return an error. For specific 360 | // arguments an pgxmock.Argument interface can be used to match an argument. 
361 | func (e *ExpectedQuery) WithArgs(args ...interface{}) *ExpectedQuery { 362 | e.args = args 363 | return e 364 | } 365 | 366 | // WithRewrittenSQL will match given expected expression to a rewritten SQL statement by 367 | // an pgx.QueryRewriter argument 368 | func (e *ExpectedQuery) WithRewrittenSQL(sql string) *ExpectedQuery { 369 | e.expectRewrittenSQL = sql 370 | return e 371 | } 372 | 373 | // RowsWillBeClosed expects this query rows to be closed. 374 | func (e *ExpectedQuery) RowsWillBeClosed() *ExpectedQuery { 375 | e.rowsMustBeClosed = true 376 | return e 377 | } 378 | 379 | // String returns string representation 380 | func (e *ExpectedQuery) String() string { 381 | msg := "ExpectedQuery => expecting call to Query() or to QueryRow():\n" 382 | msg += fmt.Sprintf("\t- matches sql: '%s'\n", e.expectSQL) 383 | 384 | if len(e.args) == 0 { 385 | msg += "\t- is without arguments\n" 386 | } else { 387 | msg += "\t- is with arguments:\n" 388 | for i, arg := range e.args { 389 | msg += fmt.Sprintf("\t\t%d - %+v\n", i, arg) 390 | } 391 | } 392 | if e.rows != nil { 393 | msg += fmt.Sprintf("%s\n", e.rows) 394 | } 395 | return msg + e.commonExpectation.String() 396 | } 397 | 398 | // WillReturnRows specifies the set of resulting rows that will be returned 399 | // by the triggered query 400 | func (e *ExpectedQuery) WillReturnRows(rows ...*Rows) *ExpectedQuery { 401 | e.rows = &rowSets{sets: rows, ex: e} 402 | return e 403 | } 404 | 405 | // ExpectedCopyFrom is used to manage *pgx.Conn.CopyFrom expectations. 406 | // Returned by *Pgxmock.ExpectCopyFrom. 
// ExpectedReset is used to manage Reset() expectations.
// Returned by pgxmock.ExpectReset.
type ExpectedReset struct {
	commonExpectation
}

// String returns string representation
func (e *ExpectedReset) String() string {
	return "ExpectedReset => expecting database Reset"
}
444 | type ExpectedRollback struct { 445 | commonExpectation 446 | } 447 | 448 | // String returns string representation 449 | func (e *ExpectedRollback) String() string { 450 | msg := "ExpectedRollback => expecting transaction Rollback" 451 | if e.err != nil { 452 | msg += fmt.Sprintf(", which should return error: %s", e.err) 453 | } 454 | return msg 455 | } 456 | -------------------------------------------------------------------------------- /expectations_test.go: -------------------------------------------------------------------------------- 1 | package pgxmock 2 | 3 | import ( 4 | "context" 5 | "errors" 6 | "fmt" 7 | "reflect" 8 | "testing" 9 | "time" 10 | 11 | "github.com/jackc/pgx/v5" 12 | "github.com/jackc/pgx/v5/pgconn" 13 | "github.com/jackc/pgx/v5/pgtype" 14 | "github.com/stretchr/testify/assert" 15 | ) 16 | 17 | var ctx = context.Background() 18 | 19 | func TestTimes(t *testing.T) { 20 | t.Parallel() 21 | mock, _ := NewConn() 22 | a := assert.New(t) 23 | mock.ExpectPing().Times(2) 24 | err := mock.Ping(ctx) 25 | a.NoError(err) 26 | a.Error(mock.ExpectationsWereMet()) // must be two Ping() calls 27 | err = mock.Ping(ctx) 28 | a.NoError(err) 29 | a.NoError(mock.ExpectationsWereMet()) 30 | } 31 | 32 | func TestMaybe(t *testing.T) { 33 | t.Parallel() 34 | mock, _ := NewConn() 35 | a := assert.New(t) 36 | mock.ExpectPing().Maybe() 37 | mock.ExpectBegin().Maybe() 38 | mock.ExpectQuery("SET TIME ZONE 'Europe/Rome'").Maybe() //only if we're in Italy 39 | cmdtag := pgconn.NewCommandTag("SELECT 1") 40 | mock.ExpectExec("select").WillReturnResult(cmdtag) 41 | mock.ExpectCommit().Maybe() 42 | 43 | res, err := mock.Exec(ctx, "select version()") 44 | a.Equal(cmdtag, res) 45 | a.NoError(err) 46 | a.NoError(mock.ExpectationsWereMet()) 47 | } 48 | 49 | func TestPanic(t *testing.T) { 50 | t.Parallel() 51 | mock, _ := NewConn() 52 | a := assert.New(t) 53 | defer func() { 54 | a.NotNil(recover(), "The code did not panic") 55 | a.NoError(mock.ExpectationsWereMet()) 56 | 
}() 57 | 58 | ex := mock.ExpectPing() 59 | ex.WillPanic("i'm tired") 60 | fmt.Println(ex) 61 | a.NoError(mock.Ping(ctx)) 62 | } 63 | 64 | func TestCallModifier(t *testing.T) { 65 | t.Parallel() 66 | mock, _ := NewConn() 67 | a := assert.New(t) 68 | 69 | mock.ExpectPing().WillDelayFor(time.Second).Maybe().Times(4) 70 | 71 | c, f := context.WithCancel(ctx) 72 | f() 73 | a.Error(mock.Ping(c), "should raise error for cancelled context") 74 | 75 | a.NoError(mock.ExpectationsWereMet()) //should produce no error since Ping() call is optional 76 | 77 | a.NoError(mock.Ping(ctx)) 78 | a.NoError(mock.ExpectationsWereMet()) //should produce no error since Ping() was called actually 79 | } 80 | 81 | func TestCopyFromBug(t *testing.T) { 82 | mock, _ := NewConn() 83 | a := assert.New(t) 84 | 85 | mock.ExpectCopyFrom(pgx.Identifier{"foo"}, []string{"bar"}).WillReturnResult(1) 86 | 87 | var rows [][]any 88 | rows = append(rows, []any{"baz"}) 89 | 90 | r, err := mock.CopyFrom(ctx, pgx.Identifier{"foo"}, []string{"bar"}, pgx.CopyFromRows(rows)) 91 | a.EqualValues(len(rows), r) 92 | a.NoError(err) 93 | a.NoError(mock.ExpectationsWereMet()) 94 | } 95 | 96 | func ExampleExpectedExec() { 97 | mock, _ := NewConn() 98 | ex := mock.ExpectExec("^INSERT (.+)").WillReturnResult(NewResult("INSERT", 15)) 99 | ex.WillDelayFor(time.Second).Maybe().Times(2) 100 | 101 | fmt.Print(ex) 102 | res, _ := mock.Exec(ctx, "INSERT something") 103 | fmt.Println(res) 104 | ex.WithArgs(42) 105 | fmt.Print(ex) 106 | res, _ = mock.Exec(ctx, "INSERT something", 42) 107 | fmt.Print(res) 108 | // Output: 109 | // ExpectedExec => expecting call to Exec(): 110 | // - matches sql: '^INSERT (.+)' 111 | // - is without arguments 112 | // - returns result: INSERT 15 113 | // - delayed execution for: 1s 114 | // - execution is optional 115 | // - execution calls awaited: 2 116 | // INSERT 15 117 | // ExpectedExec => expecting call to Exec(): 118 | // - matches sql: '^INSERT (.+)' 119 | // - is with arguments: 120 | // 0 - 
42 121 | // - returns result: INSERT 15 122 | // - delayed execution for: 1s 123 | // - execution is optional 124 | // - execution calls awaited: 2 125 | // INSERT 15 126 | } 127 | 128 | func TestUnexpectedPing(t *testing.T) { 129 | mock, _ := NewConn() 130 | err := mock.Ping(ctx) 131 | if err == nil { 132 | t.Error("Ping should return error for unexpected call") 133 | } 134 | mock.ExpectExec("foo") 135 | err = mock.Ping(ctx) 136 | if err == nil { 137 | t.Error("Ping should return error for unexpected call") 138 | } 139 | } 140 | 141 | func TestUnexpectedPrepare(t *testing.T) { 142 | mock, _ := NewConn() 143 | _, err := mock.Prepare(ctx, "foo", "bar") 144 | if err == nil { 145 | t.Error("Prepare should return error for unexpected call") 146 | } 147 | mock.ExpectExec("foo") 148 | _, err = mock.Prepare(ctx, "foo", "bar") 149 | if err == nil { 150 | t.Error("Prepare should return error for unexpected call") 151 | } 152 | } 153 | 154 | func TestUnexpectedCopyFrom(t *testing.T) { 155 | mock, _ := NewConn() 156 | _, err := mock.CopyFrom(ctx, pgx.Identifier{"schema", "table"}, []string{"foo", "bar"}, nil) 157 | if err == nil { 158 | t.Error("CopyFrom should return error for unexpected call") 159 | } 160 | mock.ExpectExec("foo") 161 | _, err = mock.CopyFrom(ctx, pgx.Identifier{"schema", "table"}, []string{"foo", "bar"}, nil) 162 | if err == nil { 163 | t.Error("CopyFrom should return error for unexpected call") 164 | } 165 | } 166 | 167 | func TestBuildQuery(t *testing.T) { 168 | mock, _ := NewConn() 169 | a := assert.New(t) 170 | query := ` 171 | SELECT 172 | name, 173 | email, 174 | address, 175 | anotherfield 176 | FROM user 177 | where 178 | name = 'John' 179 | and 180 | address = 'Jakarta' 181 | 182 | ` 183 | 184 | mock.ExpectPing().WillDelayFor(1 * time.Second).WillReturnError(errors.New("no ping please")) 185 | mock.ExpectQuery(query).WillReturnError(errors.New("oops")) 186 | mock.ExpectExec(query).WillReturnResult(NewResult("SELECT", 1)) 187 | 
mock.ExpectPrepare("foo", query) 188 | 189 | err := mock.Ping(ctx) 190 | a.Error(err) 191 | mock.QueryRow(ctx, query) 192 | _, err = mock.Exec(ctx, query) 193 | a.NoError(err) 194 | _, err = mock.Prepare(ctx, "foo", query) 195 | a.NoError(err) 196 | 197 | a.NoError(mock.ExpectationsWereMet()) 198 | } 199 | 200 | func TestQueryRowScan(t *testing.T) { 201 | mock, _ := NewConn() //TODO New(ValueConverterOption(CustomConverter{})) 202 | query := ` 203 | SELECT 204 | name, 205 | email, 206 | address, 207 | anotherfield 208 | FROM user 209 | where 210 | name = 'John' 211 | and 212 | address = 'Jakarta' 213 | 214 | ` 215 | expectedStringValue := "ValueOne" 216 | expectedIntValue := 2 217 | expectedArrayValue := []string{"Three", "Four"} 218 | mock.ExpectQuery(query).WillReturnRows(mock.NewRows([]string{"One", "Two", "Three"}).AddRow(expectedStringValue, expectedIntValue, []string{"Three", "Four"})) 219 | row := mock.QueryRow(ctx, query) 220 | var stringValue string 221 | var intValue int 222 | var arrayValue []string 223 | if e := row.Scan(&stringValue, &intValue, &arrayValue); e != nil { 224 | t.Error(e) 225 | } 226 | if stringValue != expectedStringValue { 227 | t.Errorf("Expectation %s does not met: %s", expectedStringValue, stringValue) 228 | } 229 | if intValue != expectedIntValue { 230 | t.Errorf("Expectation %d does not met: %d", expectedIntValue, intValue) 231 | } 232 | if !reflect.DeepEqual(expectedArrayValue, arrayValue) { 233 | t.Errorf("Expectation %v does not met: %v", expectedArrayValue, arrayValue) 234 | } 235 | if err := mock.ExpectationsWereMet(); err != nil { 236 | t.Error(err) 237 | } 238 | } 239 | 240 | func TestMissingWithArgs(t *testing.T) { 241 | mock, _ := NewConn() 242 | // No arguments expected 243 | mock.ExpectExec("INSERT something") 244 | // Receiving argument 245 | _, err := mock.Exec(ctx, "INSERT something", "something") 246 | if err == nil { 247 | t.Error("arguments do not match error was expected") 248 | } 249 | if err := 
mock.ExpectationsWereMet(); err == nil { 250 | t.Error("expectation was not matched error was expected") 251 | } 252 | } 253 | 254 | type user struct { 255 | ID int64 256 | name string 257 | email pgtype.Text 258 | } 259 | 260 | func (u *user) RewriteQuery(_ context.Context, _ *pgx.Conn, sql string, _ []any) (newSQL string, newArgs []any, err error) { 261 | switch sql { 262 | case "INSERT": 263 | return `INSERT INTO users (username, email) VALUES ($1, $2) RETURNING id`, []any{u.name, u.email}, nil 264 | case "UPDATE": 265 | return `UPDATE users SET username = $1, email = $2 WHERE id = $1`, []any{u.ID, u.name, u.email}, nil 266 | case "DELETE": 267 | return `DELETE FROM users WHERE id = $1`, []any{u.ID}, nil 268 | } 269 | return 270 | } 271 | 272 | func TestWithRewrittenSQL(t *testing.T) { 273 | t.Parallel() 274 | mock, err := NewConn(QueryMatcherOption(QueryMatcherEqual)) 275 | a := assert.New(t) 276 | a.NoError(err) 277 | 278 | u := user{name: "John", email: pgtype.Text{String: "john@example.com", Valid: true}} 279 | mock.ExpectQuery(`INSERT`). 280 | WithArgs(&u). 281 | WithRewrittenSQL(`INSERT INTO users (username, email) VALUES ($1, $2) RETURNING id`). 282 | WillReturnRows() 283 | 284 | _, err = mock.Query(context.Background(), "INSERT", &u) 285 | a.NoError(err) 286 | a.NoError(mock.ExpectationsWereMet()) 287 | 288 | mock.ExpectQuery(`INSERT INTO users(username, password) VALUES (@user, @password)`). 289 | WithArgs(pgx.NamedArgs{"user": "John", "password": "strong"}). 290 | WithRewrittenSQL(`INSERT INTO users(username, password) VALUES ($1)`). 
291 | WillReturnRows() 292 | 293 | _, err = mock.Query(context.Background(), 294 | "INSERT INTO users(username) VALUES (@user)", 295 | pgx.NamedArgs{"user": "John", "password": "strong"}, 296 | ) 297 | a.Error(err) 298 | a.Error(mock.ExpectationsWereMet()) 299 | } 300 | 301 | func TestQueryRewriter(t *testing.T) { 302 | t.Parallel() 303 | mock, err := NewConn(QueryMatcherOption(QueryMatcherEqual)) 304 | a := assert.New(t) 305 | a.NoError(err) 306 | 307 | update := `UPDATE "user" SET email = @email, password = @password, updated_utc = @updated_utc WHERE id = @id` 308 | 309 | mock.ExpectExec(update).WithArgs(pgx.NamedArgs{ 310 | "id": "mockUser.ID", 311 | "email": "mockUser.Email", 312 | "password": "mockUser.Password", 313 | "updated_utc": AnyArg(), 314 | }).WillReturnError(errPanic) 315 | 316 | _, err = mock.Exec(context.Background(), update, pgx.NamedArgs{ 317 | "id": "mockUser.ID", 318 | "email": "mockUser.Email", 319 | "password": "mockUser.Password", 320 | "updated_utc": time.Now().UTC(), 321 | }) 322 | a.Error(err) 323 | a.NoError(mock.ExpectationsWereMet()) 324 | } 325 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/pashagolub/pgxmock/v4 2 | 3 | go 1.23.0 4 | 5 | toolchain go1.24.2 6 | 7 | require ( 8 | github.com/jackc/pgx/v5 v5.7.4 9 | github.com/stretchr/testify v1.9.0 10 | ) 11 | 12 | require ( 13 | github.com/davecgh/go-spew v1.1.1 // indirect 14 | github.com/jackc/pgpassfile v1.0.0 // indirect 15 | github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect 16 | github.com/jackc/puddle/v2 v2.2.2 // indirect 17 | github.com/kr/text v0.2.0 // indirect 18 | github.com/pmezard/go-difflib v1.0.0 // indirect 19 | github.com/rogpeppe/go-internal v1.12.0 // indirect 20 | golang.org/x/crypto v0.37.0 // indirect 21 | golang.org/x/sync v0.13.0 // indirect 22 | golang.org/x/text v0.24.0 // indirect 23 | 
gopkg.in/yaml.v3 v3.0.1 // indirect 24 | ) 25 | -------------------------------------------------------------------------------- /go.sum: -------------------------------------------------------------------------------- 1 | github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= 2 | github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 3 | github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= 4 | github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 5 | github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= 6 | github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= 7 | github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo= 8 | github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM= 9 | github.com/jackc/pgx/v5 v5.7.0 h1:FG6VLIdzvAPhnYqP14sQ2xhFLkiUQHCs6ySqO91kF4g= 10 | github.com/jackc/pgx/v5 v5.7.0/go.mod h1:awP1KNnjylvpxHuHP63gzjhnGkI1iw+PMoIwvoleN/8= 11 | github.com/jackc/pgx/v5 v5.7.4 h1:9wKznZrhWa2QiHL+NjTSPP6yjl3451BX3imWDnokYlg= 12 | github.com/jackc/pgx/v5 v5.7.4/go.mod h1:ncY89UGWxg82EykZUwSpUKEfccBGGYq1xjrOpsbsfGQ= 13 | github.com/jackc/puddle/v2 v2.2.1 h1:RhxXJtFG022u4ibrCSMSiu5aOq1i77R3OHKNJj77OAk= 14 | github.com/jackc/puddle/v2 v2.2.1/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4= 15 | github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo= 16 | github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4= 17 | github.com/kr/pretty v0.3.0 h1:WgNl7dwNpEZ6jJ9k1snq4pZsg7DOEN8hP9Xw0Tsjwk0= 18 | github.com/kr/pretty v0.3.0/go.mod h1:640gp4NfQd8pI5XOwp5fnNeVWj67G7CFk/SaSQn7NBk= 19 | github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= 20 | github.com/kr/text v0.2.0/go.mod 
h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= 21 | github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= 22 | github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= 23 | github.com/rogpeppe/go-internal v1.12.0 h1:exVL4IDcn6na9z1rAb56Vxr+CgyK3nn3O+epU5NdKM8= 24 | github.com/rogpeppe/go-internal v1.12.0/go.mod h1:E+RYuTGaKKdloAfM02xzb0FW3Paa99yedzYV+kq4uf4= 25 | github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= 26 | github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= 27 | github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= 28 | github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= 29 | github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= 30 | golang.org/x/crypto v0.27.0 h1:GXm2NjJrPaiv/h1tb2UH8QfgC/hOf/+z0p6PT8o1w7A= 31 | golang.org/x/crypto v0.27.0/go.mod h1:1Xngt8kV6Dvbssa53Ziq6Eqn0HqbZi5Z6R0ZpwQzt70= 32 | golang.org/x/crypto v0.37.0 h1:kJNSjF/Xp7kU0iB2Z+9viTPMW4EqqsrywMXLJOOsXSE= 33 | golang.org/x/crypto v0.37.0/go.mod h1:vg+k43peMZ0pUMhYmVAWysMK35e6ioLh3wB8ZCAfbVc= 34 | golang.org/x/sync v0.8.0 h1:3NFvSEYkUoMifnESzZl15y791HH1qU2xm6eCJU5ZPXQ= 35 | golang.org/x/sync v0.8.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= 36 | golang.org/x/sync v0.13.0 h1:AauUjRAJ9OSnvULf/ARrrVywoJDy0YS2AwQ98I37610= 37 | golang.org/x/sync v0.13.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= 38 | golang.org/x/text v0.18.0 h1:XvMDiNzPAl0jr17s6W9lcaIhGUfUORdGCNsuLmPG224= 39 | golang.org/x/text v0.18.0/go.mod h1:BuEKDfySbSR4drPmRPG/7iBdf8hvFMuRexcpahXilzY= 40 | golang.org/x/text v0.24.0 h1:dd5Bzh4yt5KYA8f9CJHCP4FB4D51c2c6JvN37xJJkJ0= 41 | golang.org/x/text v0.24.0/go.mod h1:L8rBsPeo2pSS+xqN0d5u2ikmjtmoJbDBT1b7nHvFCdU= 42 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod 
h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= 43 | gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= 44 | gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= 45 | gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= 46 | gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= 47 | gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= 48 | -------------------------------------------------------------------------------- /options.go: -------------------------------------------------------------------------------- 1 | package pgxmock 2 | 3 | // QueryMatcherOption allows to customize SQL query matcher 4 | // and match SQL query strings in more sophisticated ways. 5 | // The default QueryMatcher is QueryMatcherRegexp. 6 | func QueryMatcherOption(queryMatcher QueryMatcher) func(*pgxmock) error { 7 | return func(s *pgxmock) error { 8 | s.queryMatcher = queryMatcher 9 | return nil 10 | } 11 | } 12 | -------------------------------------------------------------------------------- /pgxmock.go: -------------------------------------------------------------------------------- 1 | /* 2 | package pgxmock is a mock library implementing pgx connector. Which has one and only 3 | purpose - to simulate pgx driver behavior in tests, without needing a real 4 | database connection. It helps to maintain correct **TDD** workflow. 5 | 6 | It does not require (almost) any modifications to your source code in order to test 7 | and mock database operations. Supports concurrency and multiple database mocking. 8 | 9 | The driver allows to mock any pgx driver method behavior. 
// Expecter interface serves to create expectations
// for any kind of database action in order to mock
// and test real database behavior.
type Expecter interface {
	// ExpectationsWereMet checks whether all queued expectations
	// were met in order (unless MatchExpectationsInOrder is set to false).
	// If any of them was not met, an error is returned.
	ExpectationsWereMet() error

	// ExpectBatch expects SendBatch() to be called. The *ExpectedBatch
	// allows to mock the database response.
	ExpectBatch() *ExpectedBatch

	// ExpectClose queues an expectation for this database
	// action to be triggered. The *ExpectedClose allows
	// to mock the database response.
	ExpectClose() *ExpectedClose

	// ExpectPrepare expects Prepare() to be called with expectedStmtName
	// and the expectedSQL query.
	ExpectPrepare(expectedStmtName, expectedSQL string) *ExpectedPrepare

	// ExpectDeallocate expects Deallocate() to be called with expectedStmtName.
	// The *ExpectedDeallocate allows to mock the database response.
	ExpectDeallocate(expectedStmtName string) *ExpectedDeallocate
	// ExpectDeallocateAll expects DeallocateAll() to be called.
	ExpectDeallocateAll() *ExpectedDeallocate

	// ExpectQuery expects Query() or QueryRow() to be called with the expectedSQL query.
	// The *ExpectedQuery allows to mock the database response.
	ExpectQuery(expectedSQL string) *ExpectedQuery

	// ExpectExec expects Exec() to be called with the expectedSQL query.
	// The *ExpectedExec allows to mock the database response.
	ExpectExec(expectedSQL string) *ExpectedExec

	// ExpectBegin expects pgx.Conn.Begin to be called.
	// The *ExpectedBegin allows to mock the database response.
	ExpectBegin() *ExpectedBegin

	// ExpectBeginTx expects BeginTx() to be called with the given
	// txOptions. The *ExpectedBegin allows to mock the database response.
	ExpectBeginTx(txOptions pgx.TxOptions) *ExpectedBegin

	// ExpectCommit expects pgx.Tx.Commit to be called.
	// The *ExpectedCommit allows to mock the database response.
	ExpectCommit() *ExpectedCommit

	// ExpectReset expects pgxpool.Reset() to be called.
	// The *ExpectedReset allows to mock the database response.
	ExpectReset() *ExpectedReset

	// ExpectRollback expects pgx.Tx.Rollback to be called.
	// The *ExpectedRollback allows to mock the database response.
	ExpectRollback() *ExpectedRollback

	// ExpectPing expects Ping() to be called.
	// The *ExpectedPing allows to mock the database response.
	ExpectPing() *ExpectedPing

	// ExpectCopyFrom expects pgx.CopyFrom to be called.
	// The *ExpectedCopyFrom allows to mock the database response.
	ExpectCopyFrom(expectedTableName pgx.Identifier, expectedColumns []string) *ExpectedCopyFrom

	// MatchExpectationsInOrder gives an option whether to match all
	// expectations in the order they were set or not.
	//
	// By default it is set to true. But if you use goroutines
	// to parallelize your query execution, that option may
	// be handy.
	//
	// This option may be changed anytime during tests. As soon
	// as it is switched to false, expectations will be matched
	// in any order. Otherwise, if switched to true, any unmatched
	// expectations will be expected in order.
	MatchExpectationsInOrder(bool)

	// NewRows allows Rows to be created from a []string slice of column names.
	NewRows(columns []string) *Rows

	// NewRowsWithColumnDefinition allows Rows to be created from a
	// pgconn.FieldDescription slice with a definition of sql metadata.
	NewRowsWithColumnDefinition(columns ...pgconn.FieldDescription) *Rows

	// NewColumn creates a column description (*pgconn.FieldDescription)
	// with the given name.
	NewColumn(name string) *pgconn.FieldDescription
}
append(c.expectations, e) 160 | return e 161 | } 162 | 163 | func (c *pgxmock) ExpectClose() *ExpectedClose { 164 | e := &ExpectedClose{} 165 | c.expectations = append(c.expectations, e) 166 | return e 167 | } 168 | 169 | func (c *pgxmock) MatchExpectationsInOrder(b bool) { 170 | c.ordered = b 171 | } 172 | 173 | func (c *pgxmock) ExpectationsWereMet() error { 174 | for _, e := range c.expectations { 175 | e.Lock() 176 | fulfilled := e.fulfilled() || !e.required() 177 | e.Unlock() 178 | 179 | if !fulfilled { 180 | return fmt.Errorf("there is a remaining expectation which was not matched: %s", e) 181 | } 182 | 183 | // must check whether all expected queried rows are closed 184 | if query, ok := e.(*ExpectedQuery); ok { 185 | if query.rowsMustBeClosed && !query.rowsWereClosed { 186 | return fmt.Errorf("expected query rows to be closed, but it was not: %s", query) 187 | } 188 | } 189 | } 190 | return nil 191 | } 192 | 193 | func (c *pgxmock) ExpectQuery(expectedSQL string) *ExpectedQuery { 194 | e := &ExpectedQuery{} 195 | e.expectSQL = expectedSQL 196 | c.expectations = append(c.expectations, e) 197 | return e 198 | } 199 | 200 | func (c *pgxmock) ExpectCommit() *ExpectedCommit { 201 | e := &ExpectedCommit{} 202 | c.expectations = append(c.expectations, e) 203 | return e 204 | } 205 | 206 | func (c *pgxmock) ExpectRollback() *ExpectedRollback { 207 | e := &ExpectedRollback{} 208 | c.expectations = append(c.expectations, e) 209 | return e 210 | } 211 | 212 | func (c *pgxmock) ExpectBegin() *ExpectedBegin { 213 | e := &ExpectedBegin{} 214 | c.expectations = append(c.expectations, e) 215 | return e 216 | } 217 | 218 | func (c *pgxmock) ExpectBeginTx(txOptions pgx.TxOptions) *ExpectedBegin { 219 | e := &ExpectedBegin{opts: txOptions} 220 | c.expectations = append(c.expectations, e) 221 | return e 222 | } 223 | 224 | func (c *pgxmock) ExpectExec(expectedSQL string) *ExpectedExec { 225 | e := &ExpectedExec{} 226 | e.expectSQL = expectedSQL 227 | c.expectations = 
append(c.expectations, e) 228 | return e 229 | } 230 | 231 | func (c *pgxmock) ExpectCopyFrom(expectedTableName pgx.Identifier, expectedColumns []string) *ExpectedCopyFrom { 232 | e := &ExpectedCopyFrom{expectedTableName: expectedTableName, expectedColumns: expectedColumns} 233 | c.expectations = append(c.expectations, e) 234 | return e 235 | } 236 | 237 | // ExpectReset expects Reset to be called. 238 | func (c *pgxmock) ExpectReset() *ExpectedReset { 239 | e := &ExpectedReset{} 240 | c.expectations = append(c.expectations, e) 241 | return e 242 | } 243 | 244 | func (c *pgxmock) ExpectPing() *ExpectedPing { 245 | e := &ExpectedPing{} 246 | c.expectations = append(c.expectations, e) 247 | return e 248 | } 249 | 250 | func (c *pgxmock) ExpectPrepare(expectedStmtName, expectedSQL string) *ExpectedPrepare { 251 | e := &ExpectedPrepare{expectSQL: expectedSQL, expectStmtName: expectedStmtName} 252 | c.expectations = append(c.expectations, e) 253 | return e 254 | } 255 | 256 | func (c *pgxmock) ExpectDeallocate(expectedStmtName string) *ExpectedDeallocate { 257 | e := &ExpectedDeallocate{expectStmtName: expectedStmtName} 258 | c.expectations = append(c.expectations, e) 259 | return e 260 | } 261 | 262 | func (c *pgxmock) ExpectDeallocateAll() *ExpectedDeallocate { 263 | e := &ExpectedDeallocate{expectAll: true} 264 | c.expectations = append(c.expectations, e) 265 | return e 266 | } 267 | 268 | //endregion Expectations 269 | 270 | // NewRows allows Rows to be created from a 271 | // atring slice or from the CSV string and 272 | // to be used as sql driver.Rows. 273 | func (c *pgxmock) NewRows(columns []string) *Rows { 274 | r := NewRows(columns) 275 | return r 276 | } 277 | 278 | // PgConn exposes the underlying low level postgres connection 279 | // This is just here to support interfaces that use it. 
Here is just returns an empty PgConn 280 | func (c *pgxmock) PgConn() *pgconn.PgConn { 281 | p := pgconn.PgConn{} 282 | return &p 283 | } 284 | 285 | // NewRowsWithColumnDefinition allows Rows to be created from a 286 | // sql driver.Value slice with a definition of sql metadata 287 | func (c *pgxmock) NewRowsWithColumnDefinition(columns ...pgconn.FieldDescription) *Rows { 288 | r := NewRowsWithColumnDefinition(columns...) 289 | return r 290 | } 291 | 292 | // NewColumn allows to create a Column that can be enhanced with metadata 293 | // using OfType/Nullable/WithLength/WithPrecisionAndScale methods. 294 | func (c *pgxmock) NewColumn(name string) *pgconn.FieldDescription { 295 | return &pgconn.FieldDescription{Name: name} 296 | } 297 | 298 | // open a mock database driver connection 299 | func (c *pgxmock) open(options []func(*pgxmock) error) error { 300 | for _, option := range options { 301 | err := option(c) 302 | if err != nil { 303 | return err 304 | } 305 | } 306 | 307 | if c.queryMatcher == nil { 308 | c.queryMatcher = QueryMatcherRegexp 309 | } 310 | 311 | return nil 312 | } 313 | 314 | // Close a mock database driver connection. It may or may not 315 | // be called depending on the circumstances, but if it is called 316 | // there must be an *ExpectedClose expectation satisfied. 
317 | func (c *pgxmock) Close(ctx context.Context) error { 318 | ex, err := findExpectation[*ExpectedClose](c, "Close()") 319 | if err != nil { 320 | return err 321 | } 322 | return ex.waitForDelay(ctx) 323 | } 324 | 325 | func (c *pgxmock) Conn() *pgx.Conn { 326 | panic("Conn() is not available in pgxmock") 327 | } 328 | 329 | func (c *pgxmock) CopyFrom(ctx context.Context, tableName pgx.Identifier, columnNames []string, rowSrc pgx.CopyFromSource) (int64, error) { 330 | ex, err := findExpectationFunc(c, "CopyFrom()", func(copyExp *ExpectedCopyFrom) error { 331 | if !reflect.DeepEqual(copyExp.expectedTableName, tableName) { 332 | return fmt.Errorf("CopyFrom: table name '%s' was not expected, expected table name is '%s'", tableName, copyExp.expectedTableName) 333 | } 334 | if !reflect.DeepEqual(copyExp.expectedColumns, columnNames) { 335 | return fmt.Errorf("CopyFrom: column names '%v' were not expected, expected column names are '%v'", columnNames, copyExp.expectedColumns) 336 | } 337 | return nil 338 | }) 339 | if err != nil { 340 | return -1, err 341 | } 342 | for rowSrc.Next() { 343 | if _, err := rowSrc.Values(); err != nil { 344 | return ex.rowsAffected, err 345 | } 346 | if rowSrc.Err() != nil { 347 | return ex.rowsAffected, rowSrc.Err() 348 | } 349 | } 350 | return ex.rowsAffected, ex.waitForDelay(ctx) 351 | } 352 | 353 | func (c *pgxmock) SendBatch(ctx context.Context, b *pgx.Batch) pgx.BatchResults { 354 | ex, err := findExpectationFunc(c, "Batch()", func(batchExp *ExpectedBatch) error { 355 | if len(batchExp.expectedQueries) != len(b.QueuedQueries) { 356 | return fmt.Errorf("SendBatch: number of queries in batch '%d' was not expected, expected number of queries is '%d'", 357 | len(b.QueuedQueries), len(batchExp.expectedQueries)) 358 | } 359 | if !c.ordered { // postpone the check of every query until/if it is called 360 | return nil 361 | } 362 | for i, query := range b.QueuedQueries { 363 | if err := 
c.queryMatcher.Match(batchExp.expectedQueries[i].expectSQL, query.SQL); err != nil { 364 | return err 365 | } 366 | if rewrittenSQL, err := batchExp.expectedQueries[i].argsMatches(query.SQL, query.Arguments); err != nil { 367 | return err 368 | } else if rewrittenSQL != "" && batchExp.expectedQueries[i].expectRewrittenSQL != "" { 369 | if err := c.queryMatcher.Match(batchExp.expectedQueries[i].expectRewrittenSQL, rewrittenSQL); err != nil { 370 | return err 371 | } 372 | } 373 | } 374 | return nil 375 | }) 376 | br := &batchResults{mock: c, batch: b, expectedBatch: ex, err: err} 377 | if err != nil { 378 | return br 379 | } 380 | br.err = ex.waitForDelay(ctx) 381 | return br 382 | } 383 | 384 | func (c *pgxmock) LargeObjects() pgx.LargeObjects { 385 | return pgx.LargeObjects{} 386 | } 387 | 388 | func (c *pgxmock) Begin(ctx context.Context) (pgx.Tx, error) { 389 | return c.BeginTx(ctx, pgx.TxOptions{}) 390 | } 391 | 392 | func (c *pgxmock) BeginTx(ctx context.Context, txOptions pgx.TxOptions) (pgx.Tx, error) { 393 | ex, err := findExpectationFunc(c, "BeginTx()", func(beginExp *ExpectedBegin) error { 394 | if beginExp.opts != txOptions { 395 | return fmt.Errorf("BeginTx: call with transaction options '%v' was not expected: %s", txOptions, beginExp) 396 | } 397 | return nil 398 | }) 399 | if err != nil { 400 | return nil, err 401 | } 402 | if err = ex.waitForDelay(ctx); err != nil { 403 | return nil, err 404 | } 405 | return c, nil 406 | } 407 | 408 | func (c *pgxmock) Prepare(ctx context.Context, name, query string) (*pgconn.StatementDescription, error) { 409 | ex, err := findExpectationFunc(c, "Prepare()", func(prepareExp *ExpectedPrepare) error { 410 | if err := c.queryMatcher.Match(prepareExp.expectSQL, query); err != nil { 411 | return err 412 | } 413 | if prepareExp.expectStmtName != name { 414 | return fmt.Errorf("Prepare: prepared statement name '%s' was not expected, expected name is '%s'", name, prepareExp.expectStmtName) 415 | } 416 | return nil 417 | }) 
418 | if err != nil { 419 | return nil, err 420 | } 421 | if err = ex.waitForDelay(ctx); err != nil { 422 | return nil, err 423 | } 424 | return &pgconn.StatementDescription{Name: name, SQL: query}, nil 425 | } 426 | 427 | func (c *pgxmock) Deallocate(ctx context.Context, name string) error { 428 | ex, err := findExpectationFunc(c, "Deallocate()", func(deallocateExp *ExpectedDeallocate) error { 429 | if deallocateExp.expectAll { 430 | return fmt.Errorf("Deallocate: all prepared statements were expected to be deallocated, instead only '%s' specified", name) 431 | } 432 | if deallocateExp.expectStmtName != name { 433 | return fmt.Errorf("Deallocate: prepared statement name '%s' was not expected, expected name is '%s'", name, deallocateExp.expectStmtName) 434 | } 435 | return nil 436 | }) 437 | if err != nil { 438 | return err 439 | } 440 | return ex.waitForDelay(ctx) 441 | } 442 | 443 | func (c *pgxmock) DeallocateAll(ctx context.Context) error { 444 | ex, err := findExpectationFunc(c, "DeallocateAll()", func(deallocateExp *ExpectedDeallocate) error { 445 | if !deallocateExp.expectAll { 446 | return fmt.Errorf("Deallocate: deallocate all prepared statements was not expected, expected name is '%s'", deallocateExp.expectStmtName) 447 | } 448 | return nil 449 | }) 450 | if err != nil { 451 | return err 452 | } 453 | return ex.waitForDelay(ctx) 454 | } 455 | 456 | func (c *pgxmock) Commit(ctx context.Context) error { 457 | ex, err := findExpectation[*ExpectedCommit](c, "Commit()") 458 | if err != nil { 459 | return err 460 | } 461 | return ex.waitForDelay(ctx) 462 | } 463 | 464 | func (c *pgxmock) Rollback(ctx context.Context) error { 465 | ex, err := findExpectation[*ExpectedRollback](c, "Rollback()") 466 | if err != nil { 467 | return err 468 | } 469 | return ex.waitForDelay(ctx) 470 | } 471 | 472 | // Implement the "QueryerContext" interface 473 | func (c *pgxmock) Query(ctx context.Context, sql string, args ...interface{}) (pgx.Rows, error) { 474 | ex, err := 
findExpectationFunc(c, "Query()", func(queryExp *ExpectedQuery) error { 475 | if err := c.queryMatcher.Match(queryExp.expectSQL, sql); err != nil { 476 | return err 477 | } 478 | if rewrittenSQL, err := queryExp.argsMatches(sql, args); err != nil { 479 | return err 480 | } else if rewrittenSQL != "" && queryExp.expectRewrittenSQL != "" { 481 | if err := c.queryMatcher.Match(queryExp.expectRewrittenSQL, rewrittenSQL); err != nil { 482 | return err 483 | } 484 | } 485 | if queryExp.err == nil && queryExp.rows == nil { 486 | return fmt.Errorf("Query must return a result rows or raise an error: %v", queryExp) 487 | } 488 | return nil 489 | }) 490 | if err != nil { 491 | return nil, err 492 | } 493 | return ex.rows, ex.waitForDelay(ctx) 494 | } 495 | 496 | type errRow struct { 497 | err error 498 | } 499 | 500 | func (er errRow) Scan(...interface{}) error { 501 | return er.err 502 | } 503 | 504 | func (c *pgxmock) QueryRow(ctx context.Context, sql string, args ...interface{}) pgx.Row { 505 | rows, err := c.Query(ctx, sql, args...) 
506 | if err != nil { 507 | return errRow{err: err} 508 | } 509 | return (*connRow)(rows.(*rowSets)) 510 | } 511 | 512 | func (c *pgxmock) Exec(ctx context.Context, query string, args ...interface{}) (pgconn.CommandTag, error) { 513 | ex, err := findExpectationFunc(c, "Exec()", func(execExp *ExpectedExec) error { 514 | if err := c.queryMatcher.Match(execExp.expectSQL, query); err != nil { 515 | return err 516 | } 517 | if rewrittenSQL, err := execExp.argsMatches(query, args); err != nil { 518 | return err 519 | } else if rewrittenSQL != "" && execExp.expectRewrittenSQL != "" { 520 | if err := c.queryMatcher.Match(execExp.expectRewrittenSQL, rewrittenSQL); err != nil { 521 | return err 522 | } 523 | } 524 | if execExp.result.String() == "" && execExp.err == nil { 525 | return fmt.Errorf("Exec must return a result or raise an error: %s", execExp) 526 | } 527 | return nil 528 | }) 529 | if err != nil { 530 | return pgconn.NewCommandTag(""), err 531 | } 532 | return ex.result, ex.waitForDelay(ctx) 533 | } 534 | 535 | func (c *pgxmock) Ping(ctx context.Context) (err error) { 536 | ex, err := findExpectation[*ExpectedPing](c, "Ping()") 537 | if err != nil { 538 | return err 539 | } 540 | return ex.waitForDelay(ctx) 541 | } 542 | 543 | func (c *pgxmock) Reset() { 544 | if ex, err := findExpectation[*ExpectedReset](c, "Reset()"); err == nil { 545 | _ = ex.waitForDelay(context.Background()) 546 | } 547 | } 548 | 549 | type expectationType[t any] interface { 550 | *t 551 | expectation 552 | } 553 | 554 | func findExpectationFunc[ET expectationType[t], t any](c *pgxmock, method string, cmp func(ET) error) (ET, error) { 555 | var expected ET 556 | var fulfilled int 557 | var ok bool 558 | var err error 559 | defer func() { 560 | if expected != nil { 561 | expected.Unlock() 562 | } 563 | }() 564 | for _, next := range c.expectations { 565 | next.Lock() 566 | if next.fulfilled() { 567 | next.Unlock() 568 | fulfilled++ 569 | continue 570 | } 571 | if expected, ok = next.(ET); ok 
{ 572 | if err = cmp(expected); err == nil { 573 | break 574 | } 575 | } 576 | expected = nil 577 | next.Unlock() 578 | if c.ordered { 579 | if !next.required() { 580 | continue 581 | } 582 | if err != nil { 583 | return nil, err 584 | } 585 | return nil, fmt.Errorf("call to method %s, was not expected, next expectation is: %s", method, next) 586 | } 587 | } 588 | 589 | if expected == nil { 590 | msg := fmt.Sprintf("call to method %s was not expected", method) 591 | if fulfilled == len(c.expectations) { 592 | msg = "all expectations were already fulfilled, " + msg 593 | } 594 | return nil, errors.New(msg) 595 | } 596 | 597 | expected.fulfill() 598 | return expected, nil 599 | } 600 | 601 | func findExpectation[ET expectationType[t], t any](c *pgxmock, method string) (ET, error) { 602 | return findExpectationFunc(c, method, func(_ ET) error { return nil }) 603 | } 604 | -------------------------------------------------------------------------------- /pgxmock_test.go: -------------------------------------------------------------------------------- 1 | package pgxmock 2 | 3 | import ( 4 | "context" 5 | "errors" 6 | "fmt" 7 | "strconv" 8 | "sync" 9 | "testing" 10 | "time" 11 | 12 | pgx "github.com/jackc/pgx/v5" 13 | "github.com/jackc/pgx/v5/pgxpool" 14 | "github.com/stretchr/testify/assert" 15 | ) 16 | 17 | func cancelOrder(db PgxCommonIface, orderID int) error { 18 | tx, _ := db.Begin(context.Background()) 19 | _, _ = tx.Query(context.Background(), "SELECT * FROM orders {0} FOR UPDATE", orderID) 20 | err := tx.Rollback(context.Background()) 21 | if err != nil { 22 | return err 23 | } 24 | return nil 25 | } 26 | 27 | func TestIssue14EscapeSQL(t *testing.T) { 28 | t.Parallel() 29 | mock, err := NewConn() 30 | if err != nil { 31 | t.Fatalf("an error '%s' was not expected when opening a stub database connection", err) 32 | } 33 | defer mock.Close(context.Background()) 34 | mock.ExpectExec("INSERT INTO mytable\\(a, b\\)"). 35 | WithArgs("A", "B"). 
36 | WillReturnResult(NewResult("INSERT", 1)) 37 | 38 | _, err = mock.Exec(context.Background(), "INSERT INTO mytable(a, b) VALUES (?, ?)", "A", "B") 39 | if err != nil { 40 | t.Errorf("error '%s' was not expected, while inserting a row", err) 41 | } 42 | 43 | if err := mock.ExpectationsWereMet(); err != nil { 44 | t.Errorf("there were unfulfilled expectations: %s", err) 45 | } 46 | } 47 | 48 | // test the case when db is not triggered and expectations 49 | // are not asserted on close 50 | func TestIssue4(t *testing.T) { 51 | t.Parallel() 52 | mock, err := NewConn() 53 | if err != nil { 54 | t.Fatalf("an error '%s' was not expected when opening a stub database connection", err) 55 | } 56 | defer mock.Close(context.Background()) 57 | 58 | mock.ExpectQuery("some sql query which will not be called"). 59 | WillReturnRows(NewRows([]string{"id"})) 60 | 61 | if err := mock.ExpectationsWereMet(); err == nil { 62 | t.Errorf("was expecting an error since query was not triggered") 63 | } 64 | } 65 | 66 | func TestMockQuery(t *testing.T) { 67 | t.Parallel() 68 | mock, err := NewConn() 69 | if err != nil { 70 | t.Fatalf("an error '%s' was not expected when opening a stub database connection", err) 71 | } 72 | defer mock.Close(context.Background()) 73 | 74 | rs := NewRows([]string{"id", "title"}).AddRow(5, "hello world") 75 | 76 | mock.ExpectQuery("SELECT (.+) FROM articles WHERE id = ?"). 77 | WithArgs(5). 
78 | WillReturnRows(rs) 79 | 80 | rows, err := mock.Query(context.Background(), "SELECT (.+) FROM articles WHERE id = ?", 5) 81 | if err != nil { 82 | t.Fatalf("error '%s' was not expected while retrieving mock rows", err) 83 | } 84 | 85 | defer rows.Close() 86 | 87 | if !rows.Next() { 88 | t.Fatal("it must have had one row as result, but got empty result set instead") 89 | } 90 | 91 | var id int 92 | var title string 93 | 94 | err = rows.Scan(&id, &title) 95 | if err != nil { 96 | t.Errorf("error '%s' was not expected while trying to scan row", err) 97 | } 98 | 99 | if id != 5 { 100 | t.Errorf("expected mocked id to be 5, but got %d instead", id) 101 | } 102 | 103 | if title != "hello world" { 104 | t.Errorf("expected mocked title to be 'hello world', but got '%s' instead", title) 105 | } 106 | 107 | if err := mock.ExpectationsWereMet(); err != nil { 108 | t.Errorf("there were unfulfilled expectations: %s", err) 109 | } 110 | } 111 | 112 | func TestMockCopyFrom(t *testing.T) { 113 | t.Parallel() 114 | mock, _ := NewConn() 115 | a := assert.New(t) 116 | mock.ExpectCopyFrom(pgx.Identifier{"fooschema", "baztable"}, []string{"col1"}). 
117 | WillReturnResult(2).WillDelayFor(1 * time.Second) 118 | 119 | res, err := mock.CopyFrom(context.Background(), pgx.Identifier{"error", "error"}, []string{"error"}, nil) 120 | a.Error(err, "incorrect table should raise an error") 121 | a.EqualValues(res, -1) 122 | a.Error(mock.ExpectationsWereMet(), "there must be unfulfilled expectations") 123 | 124 | res, err = mock.CopyFrom(context.Background(), pgx.Identifier{"fooschema", "baztable"}, []string{"error"}, nil) 125 | a.Error(err, "incorrect columns should raise an error") 126 | a.EqualValues(res, -1) 127 | a.Error(mock.ExpectationsWereMet(), "there must be unfulfilled expectations") 128 | 129 | cfs := pgx.CopyFromRows([][]any{{"foo"}, {"bar"}}) 130 | res, err = mock.CopyFrom(context.Background(), pgx.Identifier{"fooschema", "baztable"}, []string{"col1"}, cfs) 131 | a.NoError(err) 132 | a.EqualValues(res, 2) 133 | 134 | mock.ExpectCopyFrom(pgx.Identifier{"fooschema", "baztable"}, []string{"col1"}). 135 | WillReturnError(errors.New("error is here")) 136 | 137 | _, err = mock.CopyFrom(context.Background(), pgx.Identifier{"fooschema", "baztable"}, []string{"col1"}, cfs) 138 | a.Error(err) 139 | 140 | a.NoError(mock.ExpectationsWereMet()) 141 | } 142 | 143 | func TestMockQueryTypes(t *testing.T) { 144 | t.Parallel() 145 | mock, err := NewConn() 146 | if err != nil { 147 | t.Fatalf("an error '%s' was not expected when opening a stub database connection", err) 148 | } 149 | defer mock.Close(context.Background()) 150 | 151 | columns := []string{"id", "timestamp", "sold"} 152 | 153 | timestamp := time.Now() 154 | rs := NewRows(columns) 155 | rs.AddRow(5, timestamp, true) 156 | 157 | mock.ExpectQuery("SELECT (.+) FROM sales WHERE id = ?"). 158 | WithArgs(5). 
159 | WillReturnRows(rs) 160 | 161 | rows, err := mock.Query(context.Background(), "SELECT (.+) FROM sales WHERE id = ?", 5) 162 | if err != nil { 163 | t.Fatalf("error '%s' was not expected while retrieving mock rows", err) 164 | } 165 | defer rows.Close() 166 | if !rows.Next() { 167 | t.Error("it must have had one row as result, but got empty result set instead") 168 | } 169 | 170 | var id int 171 | var time time.Time 172 | var sold bool 173 | 174 | err = rows.Scan(&id, &time, &sold) 175 | if err != nil { 176 | t.Errorf("error '%s' was not expected while trying to scan row", err) 177 | } 178 | 179 | if id != 5 { 180 | t.Errorf("expected mocked id to be 5, but got %d instead", id) 181 | } 182 | 183 | if time != timestamp { 184 | t.Errorf("expected mocked time to be %s, but got '%s' instead", timestamp, time) 185 | } 186 | 187 | if sold != true { 188 | t.Errorf("expected mocked boolean to be true, but got %v instead", sold) 189 | } 190 | 191 | if err := mock.ExpectationsWereMet(); err != nil { 192 | t.Errorf("there were unfulfilled expectations: %s", err) 193 | } 194 | } 195 | 196 | func TestTransactionExpectations(t *testing.T) { 197 | t.Parallel() 198 | mock, _ := NewConn() 199 | a := assert.New(t) 200 | 201 | // begin and commit 202 | mock.ExpectBegin() 203 | mock.ExpectCommit() 204 | 205 | tx, err := mock.Begin(ctx) 206 | a.NoError(err) 207 | err = tx.Commit(ctx) 208 | a.NoError(err) 209 | 210 | // beginTx and commit 211 | mock.ExpectBeginTx(pgx.TxOptions{AccessMode: pgx.ReadOnly}) 212 | mock.ExpectCommit() 213 | 214 | _, err = mock.BeginTx(ctx, pgx.TxOptions{}) 215 | a.Error(err, "wrong tx access mode should raise error") 216 | 217 | tx, err = mock.BeginTx(ctx, pgx.TxOptions{AccessMode: pgx.ReadOnly}) 218 | a.NoError(err) 219 | err = tx.Commit(ctx) 220 | a.NoError(err) 221 | 222 | // begin and rollback 223 | mock.ExpectBegin() 224 | mock.ExpectRollback() 225 | 226 | tx, err = mock.Begin(ctx) 227 | a.NoError(err) 228 | err = tx.Rollback(ctx) 229 | 
a.NoError(err) 230 | 231 | // begin with an error 232 | mock.ExpectBegin().WillReturnError(errors.New("some err")) 233 | 234 | _, err = mock.Begin(ctx) 235 | a.Error(err) 236 | 237 | a.NoError(mock.ExpectationsWereMet()) 238 | } 239 | 240 | func TestPrepareExpectations(t *testing.T) { 241 | t.Parallel() 242 | mock, _ := NewConn() 243 | a := assert.New(t) 244 | expErr := errors.New("invaders must die") 245 | mock.ExpectPrepare("foo", "SELECT (.+) FROM articles WHERE id = ?"). 246 | WillDelayFor(1 * time.Second) 247 | mock.ExpectDeallocate("foo").WillReturnError(expErr) 248 | 249 | stmt, err := mock.Prepare(context.Background(), "baz", "SELECT (.+) FROM articles WHERE id = ?") 250 | a.Error(err, "wrong prepare stmt name should raise an error") 251 | a.Nil(stmt) 252 | 253 | stmt, err = mock.Prepare(context.Background(), "foo", "SELECT (.+) FROM articles WHERE id = $1") 254 | a.NoError(err) 255 | a.NotNil(stmt) 256 | err = mock.Deallocate(context.Background(), "foo") 257 | a.EqualError(err, expErr.Error()) 258 | 259 | // expect something else, w/o ExpectPrepare() 260 | var id int 261 | var title string 262 | rs := NewRows([]string{"id", "title"}).AddRow(5, "hello world") 263 | 264 | mock.ExpectQuery("foo"). 265 | WithArgs(5). 266 | WillReturnRows(rs) 267 | 268 | err = mock.QueryRow(context.Background(), "foo", 5).Scan(&id, &title) 269 | a.NoError(err) 270 | 271 | mock.ExpectPrepare("foo", "SELECT (.+) FROM articles WHERE id = ?"). 
272 | WillReturnError(fmt.Errorf("Some DB error occurred")) 273 | 274 | stmt, err = mock.Prepare(context.Background(), "foo", "SELECT id FROM articles WHERE id = $1") 275 | a.Error(err) 276 | a.Nil(stmt) 277 | a.NoError(mock.ExpectationsWereMet()) 278 | } 279 | 280 | func TestPreparedQueryExecutions(t *testing.T) { 281 | t.Parallel() 282 | mock, err := NewConn() 283 | if err != nil { 284 | t.Fatalf("an error '%s' was not expected when opening a stub database connection", err) 285 | } 286 | defer mock.Close(context.Background()) 287 | 288 | mock.ExpectPrepare("foo", "SELECT (.+) FROM articles WHERE id = ?") 289 | 290 | rs1 := NewRows([]string{"id", "title"}).AddRow(5, "hello world") 291 | mock.ExpectQuery("foo"). 292 | WithArgs(5). 293 | WillReturnRows(rs1) 294 | 295 | rs2 := NewRows([]string{"id", "title"}).AddRow(2, "whoop") 296 | mock.ExpectQuery("foo"). 297 | WithArgs(2). 298 | WillReturnRows(rs2) 299 | 300 | _, err = mock.Prepare(context.Background(), "foo", "SELECT id, title FROM articles WHERE id = ?") 301 | if err != nil { 302 | t.Fatalf("error '%s' was not expected while creating a prepared statement", err) 303 | } 304 | 305 | var id int 306 | var title string 307 | err = mock.QueryRow(context.Background(), "foo", 5).Scan(&id, &title) 308 | if err != nil { 309 | t.Errorf("error '%s' was not expected querying row from statement and scanning", err) 310 | } 311 | 312 | if id != 5 { 313 | t.Errorf("expected mocked id to be 5, but got %d instead", id) 314 | } 315 | 316 | if title != "hello world" { 317 | t.Errorf("expected mocked title to be 'hello world', but got '%s' instead", title) 318 | } 319 | 320 | err = mock.QueryRow(context.Background(), "foo", 2).Scan(&id, &title) 321 | if err != nil { 322 | t.Errorf("error '%s' was not expected querying row from statement and scanning", err) 323 | } 324 | 325 | if id != 2 { 326 | t.Errorf("expected mocked id to be 2, but got %d instead", id) 327 | } 328 | 329 | if title != "whoop" { 330 | t.Errorf("expected mocked 
title to be 'whoop', but got '%s' instead", title) 331 | } 332 | 333 | if err := mock.ExpectationsWereMet(); err != nil { 334 | t.Errorf("there were unfulfilled expectations: %s", err) 335 | } 336 | } 337 | 338 | func TestUnorderedPreparedQueryExecutions(t *testing.T) { 339 | t.Parallel() 340 | mock, err := NewConn() 341 | if err != nil { 342 | t.Fatalf("an error '%s' was not expected when opening a stub database connection", err) 343 | } 344 | defer mock.Close(context.Background()) 345 | 346 | mock.MatchExpectationsInOrder(false) 347 | 348 | mock.ExpectPrepare("articles_stmt", "SELECT (.+) FROM articles WHERE id = ?") 349 | mock.ExpectQuery("articles_stmt"). 350 | WithArgs(5). 351 | WillReturnRows(NewRows([]string{"id", "title"}).AddRow(5, "The quick brown fox")) 352 | mock.ExpectPrepare("authors_stmt", "SELECT (.+) FROM authors WHERE id = ?") 353 | mock.ExpectQuery("authors_stmt"). 354 | WithArgs(1). 355 | WillReturnRows(NewRows([]string{"id", "title"}).AddRow(1, "Betty B.")) 356 | 357 | var id int 358 | var name string 359 | 360 | _, err = mock.Prepare(context.Background(), "authors_stmt", "SELECT id, name FROM authors WHERE id = ?") 361 | if err != nil { 362 | t.Fatalf("error '%s' was not expected while creating a prepared statement", err) 363 | } 364 | 365 | err = mock.QueryRow(context.Background(), "authors_stmt", 1).Scan(&id, &name) 366 | if err != nil { 367 | t.Errorf("error '%s' was not expected querying row from statement and scanning", err) 368 | } 369 | 370 | if name != "Betty B." 
{ 371 | t.Errorf("expected mocked name to be 'Betty B.', but got '%s' instead", name) 372 | } 373 | } 374 | 375 | func TestUnexpectedOperations(t *testing.T) { 376 | t.Parallel() 377 | mock, err := NewConn() 378 | if err != nil { 379 | t.Fatalf("an error '%s' was not expected when opening a stub database connection", err) 380 | } 381 | defer mock.Close(context.Background()) 382 | 383 | mock.ExpectPrepare("foo", "SELECT (.+) FROM articles WHERE id = ?") 384 | _, err = mock.Prepare(context.Background(), "foo", "SELECT id, title FROM articles WHERE id = ?") 385 | if err != nil { 386 | t.Fatalf("error '%s' was not expected while creating a prepared statement", err) 387 | } 388 | 389 | var id int 390 | var title string 391 | 392 | err = mock.QueryRow(context.Background(), "foo", 5).Scan(&id, &title) 393 | if err == nil { 394 | t.Error("error was expected querying row, since there was no such expectation") 395 | } 396 | 397 | mock.ExpectRollback() 398 | 399 | if err := mock.ExpectationsWereMet(); err == nil { 400 | t.Errorf("was expecting an error since query was not triggered") 401 | } 402 | } 403 | 404 | func TestWrongExpectations(t *testing.T) { 405 | t.Parallel() 406 | mock, err := NewConn() 407 | if err != nil { 408 | t.Fatalf("an error '%s' was not expected when opening a stub database connection", err) 409 | } 410 | defer mock.Close(context.Background()) 411 | 412 | mock.ExpectBegin() 413 | 414 | rs1 := NewRows([]string{"id", "title"}).AddRow(5, "hello world") 415 | mock.ExpectQuery("SELECT (.+) FROM articles WHERE id = ?"). 416 | WithArgs(5). 417 | WillReturnRows(rs1) 418 | 419 | mock.ExpectCommit().WillReturnError(fmt.Errorf("deadlock occurred")) 420 | mock.ExpectRollback() // won't be triggered 421 | 422 | var id int 423 | var title string 424 | 425 | err = mock.QueryRow(context.Background(), "SELECT id, title FROM articles WHERE id = ? 
FOR UPDATE", 5).Scan(&id, &title) 426 | if err == nil { 427 | t.Error("error was expected while querying row, since there begin transaction expectation is not fulfilled") 428 | } 429 | 430 | // lets go around and start transaction 431 | tx, err := mock.Begin(context.Background()) 432 | if err != nil { 433 | t.Errorf("an error '%s' was not expected when beginning a transaction", err) 434 | } 435 | 436 | err = mock.QueryRow(context.Background(), "SELECT id, title FROM articles WHERE id = ? FOR UPDATE", 5).Scan(&id, &title) 437 | if err != nil { 438 | t.Errorf("error '%s' was not expected while querying row, since transaction was started", err) 439 | } 440 | 441 | err = tx.Commit(context.Background()) 442 | if err == nil { 443 | t.Error("a deadlock error was expected when committing a transaction", err) 444 | } 445 | 446 | if err := mock.ExpectationsWereMet(); err == nil { 447 | t.Errorf("was expecting an error since query was not triggered") 448 | } 449 | } 450 | 451 | func TestExecExpectations(t *testing.T) { 452 | t.Parallel() 453 | mock, err := NewConn() 454 | if err != nil { 455 | t.Fatalf("an error '%s' was not expected when opening a stub database connection", err) 456 | } 457 | defer mock.Close(context.Background()) 458 | 459 | result := NewResult("INSERT", 1) 460 | mock.ExpectExec("^INSERT INTO articles"). 461 | WithArgs("hello"). 
462 | WillReturnResult(result) 463 | 464 | res, err := mock.Exec(context.Background(), "INSERT INTO articles (title) VALUES (?)", "hello") 465 | if err != nil { 466 | t.Fatalf("error '%s' was not expected, while inserting a row", err) 467 | } 468 | 469 | if res.RowsAffected() != 1 { 470 | t.Errorf("expected affected rows to be 1, but got %d instead", res.RowsAffected()) 471 | } 472 | 473 | if err := mock.ExpectationsWereMet(); err != nil { 474 | t.Errorf("there were unfulfilled expectations: %s", err) 475 | } 476 | } 477 | 478 | type NullTime struct { 479 | Time time.Time 480 | Valid bool // Valid is true if Time is not NULL 481 | } 482 | 483 | type NullInt struct { 484 | Integer int 485 | Valid bool 486 | } 487 | 488 | // Satisfy sql.Scanner interface 489 | func (ni *NullInt) Scan(value interface{}) error { 490 | switch v := value.(type) { 491 | case nil: 492 | ni.Integer, ni.Valid = 0, false 493 | case int64: 494 | const maxUint = ^uint(0) 495 | const maxInt = int(maxUint >> 1) 496 | const minInt = -maxInt - 1 497 | 498 | if v > int64(maxInt) || v < int64(minInt) { 499 | return errors.New("value out of int range") 500 | } 501 | ni.Integer, ni.Valid = int(v), true 502 | case []byte: 503 | n, err := strconv.Atoi(string(v)) 504 | if err != nil { 505 | return err 506 | } 507 | ni.Integer, ni.Valid = n, true 508 | case string: 509 | n, err := strconv.Atoi(v) 510 | if err != nil { 511 | return err 512 | } 513 | ni.Integer, ni.Valid = n, true 514 | default: 515 | return fmt.Errorf("can't convert %T to integer", value) 516 | } 517 | return nil 518 | } 519 | 520 | // Satisfy sql.Valuer interface. 
521 | func (ni NullInt) Value() (interface{}, error) { 522 | if !ni.Valid { 523 | return nil, nil 524 | } 525 | return int64(ni.Integer), nil 526 | } 527 | 528 | // Satisfy sql.Scanner interface 529 | func (nt *NullTime) Scan(value interface{}) error { 530 | switch v := value.(type) { 531 | case nil: 532 | nt.Time, nt.Valid = time.Time{}, false 533 | case time.Time: 534 | nt.Time, nt.Valid = v, true 535 | default: 536 | return fmt.Errorf("can't convert %T to time.Time", value) 537 | } 538 | return nil 539 | } 540 | 541 | // Satisfy sql.Valuer interface. 542 | func (nt NullTime) Value() (interface{}, error) { 543 | if !nt.Valid { 544 | return nil, nil 545 | } 546 | return nt.Time, nil 547 | } 548 | 549 | func TestRowBuilderAndNilTypes(t *testing.T) { 550 | t.Parallel() 551 | mock, err := NewConn() 552 | if err != nil { 553 | t.Fatalf("an error '%s' was not expected when opening a stub database connection", err) 554 | } 555 | defer mock.Close(context.Background()) 556 | 557 | rs := NewRows([]string{"id", "active", "created", "status"}). 558 | AddRow(1, true, NullTime{time.Now(), true}, NullInt{5, true}). 
559 | AddRow(2, false, NullTime{Valid: false}, NullInt{Valid: false}) 560 | 561 | mock.ExpectQuery("SELECT (.+) FROM sales").WillReturnRows(rs) 562 | 563 | rows, err := mock.Query(context.Background(), "SELECT * FROM sales") 564 | if err != nil { 565 | t.Fatalf("error '%s' was not expected while retrieving mock rows", err) 566 | } 567 | defer rows.Close() 568 | 569 | type boolAlias bool 570 | 571 | // NullTime and NullInt are used from stubs_test.go 572 | var ( 573 | id int 574 | active boolAlias 575 | created NullTime 576 | status NullInt 577 | ) 578 | 579 | if !rows.Next() { 580 | t.Fatal("it must have had row in rows, but got empty result set instead") 581 | } 582 | 583 | err = rows.Scan(&id, &active, &created, &status) 584 | if err != nil { 585 | t.Errorf("error '%s' was not expected while trying to scan row", err) 586 | } 587 | 588 | if id != 1 { 589 | t.Errorf("expected mocked id to be 1, but got %d instead", id) 590 | } 591 | 592 | if !active { 593 | t.Errorf("expected 'active' to be 'true', but got '%v' instead", active) 594 | } 595 | 596 | if !created.Valid { 597 | t.Errorf("expected 'created' to be valid, but it %+v is not", created) 598 | } 599 | 600 | if !status.Valid { 601 | t.Errorf("expected 'status' to be valid, but it %+v is not", status) 602 | } 603 | 604 | if status.Integer != 5 { 605 | t.Errorf("expected 'status' to be '5', but got '%d'", status.Integer) 606 | } 607 | 608 | // test second row 609 | if !rows.Next() { 610 | t.Fatal("it must have had row in rows, but got empty result set instead") 611 | } 612 | 613 | err = rows.Scan(&id, &active, &created, &status) 614 | if err != nil { 615 | t.Errorf("error '%s' was not expected while trying to scan row", err) 616 | } 617 | 618 | if id != 2 { 619 | t.Errorf("expected mocked id to be 2, but got %d instead", id) 620 | } 621 | 622 | if active { 623 | t.Errorf("expected 'active' to be 'false', but got '%v' instead", active) 624 | } 625 | 626 | if created.Valid { 627 | t.Errorf("expected 'created' to 
be invalid, but it %+v is not", created) 628 | } 629 | 630 | if status.Valid { 631 | t.Errorf("expected 'status' to be invalid, but it %+v is not", status) 632 | } 633 | 634 | if err := mock.ExpectationsWereMet(); err != nil { 635 | t.Errorf("there were unfulfilled expectations: %s", err) 636 | } 637 | } 638 | 639 | func TestArgumentReflectValueTypeError(t *testing.T) { 640 | t.Parallel() 641 | mock, err := NewConn() 642 | if err != nil { 643 | t.Fatalf("an error '%s' was not expected when opening a stub database connection", err) 644 | } 645 | defer mock.Close(context.Background()) 646 | 647 | rs := NewRows([]string{"id"}).AddRow(1) 648 | 649 | mock.ExpectQuery("SELECT (.+) FROM sales").WithArgs(5.5).WillReturnRows(rs) 650 | 651 | _, err = mock.Query(context.Background(), "SELECT * FROM sales WHERE x = ?", 5) 652 | if err == nil { 653 | t.Error("expected error, but got none") 654 | } 655 | } 656 | 657 | func TestGoroutineExecutionWithUnorderedExpectationMatching(t *testing.T) { 658 | t.Parallel() 659 | mock, err := NewConn() 660 | if err != nil { 661 | t.Fatalf("an error '%s' was not expected when opening a stub database connection", err) 662 | } 663 | defer mock.Close(context.Background()) 664 | 665 | // note this line is important for unordered expectation matching 666 | mock.MatchExpectationsInOrder(false) 667 | 668 | result := NewResult("UPDATE", 1) 669 | 670 | mock.ExpectExec("^UPDATE one").WithArgs("one").WillReturnResult(result) 671 | mock.ExpectExec("^UPDATE two").WithArgs("one", "two").WillReturnResult(result) 672 | mock.ExpectExec("^UPDATE three").WithArgs("one", "two", "three").WillReturnResult(result) 673 | 674 | var wg sync.WaitGroup 675 | queries := map[string][]interface{}{ 676 | "one": {"one"}, 677 | "two": {"one", "two"}, 678 | "three": {"one", "two", "three"}, 679 | } 680 | 681 | wg.Add(len(queries)) 682 | for table, args := range queries { 683 | go func(tbl string, a []interface{}) { 684 | if _, err := mock.Exec(context.Background(), "UPDATE 
"+tbl, a...); err != nil { 685 | t.Errorf("error was not expected: %s", err) 686 | } 687 | wg.Done() 688 | }(table, args) 689 | } 690 | 691 | wg.Wait() 692 | 693 | if err := mock.ExpectationsWereMet(); err != nil { 694 | t.Errorf("there were unfulfilled expectations: %s", err) 695 | } 696 | } 697 | 698 | // func Test_goroutines() { 699 | // mock, err := NewConn() 700 | // if err != nil { 701 | // fmt.Println("failed to open pgxmock database:", err) 702 | // } 703 | // defer mock.Close(context.Background()) 704 | 705 | // // note this line is important for unordered expectation matching 706 | // mock.MatchExpectationsInOrder(false) 707 | 708 | // result := NewResult("UPDATE", 1) 709 | 710 | // mock.ExpectExec("^UPDATE one").WithArgs("one").WillReturnResult(result) 711 | // mock.ExpectExec("^UPDATE two").WithArgs("one", "two").WillReturnResult(result) 712 | // mock.ExpectExec("^UPDATE three").WithArgs("one", "two", "three").WillReturnResult(result) 713 | 714 | // var wg sync.WaitGroup 715 | // queries := map[string][]interface{}{ 716 | // "one": {"one"}, 717 | // "two": {"one", "two"}, 718 | // "three": {"one", "two", "three"}, 719 | // } 720 | 721 | // wg.Add(len(queries)) 722 | // for table, args := range queries { 723 | // go func(tbl string, a []interface{}) { 724 | // if _, err := mock.Exec(context.Background(), "UPDATE "+tbl, a...); err != nil { 725 | // fmt.Println("error was not expected:", err) 726 | // } 727 | // wg.Done() 728 | // }(table, args) 729 | // } 730 | 731 | // wg.Wait() 732 | 733 | // if err := mock.ExpectationsWereMet(); err != nil { 734 | // fmt.Println("there were unfulfilled expectations:", err) 735 | // } 736 | // // Output: 737 | // } 738 | 739 | // False Positive - passes despite mismatched Exec 740 | // see #37 issue 741 | func TestRunExecsWithOrderedShouldNotMeetAllExpectations(t *testing.T) { 742 | dbmock, _ := NewConn() 743 | dbmock.ExpectExec("THE FIRST EXEC") 744 | dbmock.ExpectExec("THE SECOND EXEC") 745 | 746 | _, _ = 
dbmock.Exec(context.Background(), "THE FIRST EXEC") 747 | _, _ = dbmock.Exec(context.Background(), "THE WRONG EXEC") 748 | 749 | err := dbmock.ExpectationsWereMet() 750 | if err == nil { 751 | t.Fatal("was expecting an error, but there wasn't any") 752 | } 753 | } 754 | 755 | // False Positive - passes despite mismatched Exec 756 | // see #37 issue 757 | func TestRunQueriesWithOrderedShouldNotMeetAllExpectations(t *testing.T) { 758 | dbmock, _ := NewConn() 759 | dbmock.ExpectQuery("THE FIRST QUERY") 760 | dbmock.ExpectQuery("THE SECOND QUERY") 761 | 762 | _, _ = dbmock.Query(context.Background(), "THE FIRST QUERY") 763 | _, _ = dbmock.Query(context.Background(), "THE WRONG QUERY") 764 | 765 | err := dbmock.ExpectationsWereMet() 766 | if err == nil { 767 | t.Fatal("was expecting an error, but there wasn't any") 768 | } 769 | } 770 | 771 | func TestRunExecsWithExpectedErrorMeetsExpectations(t *testing.T) { 772 | dbmock, _ := NewConn() 773 | dbmock.ExpectExec("THE FIRST EXEC").WillReturnError(fmt.Errorf("big bad bug")) 774 | dbmock.ExpectExec("THE SECOND EXEC").WillReturnResult(NewResult("UPDATE", 0)) 775 | 776 | _, _ = dbmock.Exec(context.Background(), "THE FIRST EXEC") 777 | _, _ = dbmock.Exec(context.Background(), "THE SECOND EXEC") 778 | 779 | err := dbmock.ExpectationsWereMet() 780 | if err != nil { 781 | t.Fatalf("all expectations should be met: %s", err) 782 | } 783 | } 784 | 785 | func TestRunQueryWithExpectedErrorMeetsExpectations(t *testing.T) { 786 | dbmock, _ := NewConn() 787 | dbmock.ExpectQuery("THE FIRST QUERY").WillReturnError(fmt.Errorf("big bad bug")) 788 | dbmock.ExpectQuery("THE SECOND QUERY").WillReturnRows(NewRows([]string{"col"}).AddRow(1)) 789 | 790 | _, _ = dbmock.Query(context.Background(), "THE FIRST QUERY") 791 | _, _ = dbmock.Query(context.Background(), "THE SECOND QUERY") 792 | 793 | err := dbmock.ExpectationsWereMet() 794 | if err != nil { 795 | t.Fatalf("all expectations should be met: %s", err) 796 | } 797 | } 798 | 799 | func 
TestEmptyRowSet(t *testing.T) { 800 | t.Parallel() 801 | mock, err := NewConn() 802 | if err != nil { 803 | t.Fatalf("an error '%s' was not expected when opening a stub database connection", err) 804 | } 805 | defer mock.Close(context.Background()) 806 | 807 | rs := NewRows([]string{"id", "title"}) 808 | 809 | mock.ExpectQuery("SELECT (.+) FROM articles WHERE id = ?"). 810 | WithArgs(5). 811 | WillReturnRows(rs) 812 | 813 | rows, err := mock.Query(context.Background(), "SELECT (.+) FROM articles WHERE id = ?", 5) 814 | if err != nil { 815 | t.Fatalf("error '%s' was not expected while retrieving mock rows", err) 816 | } 817 | defer rows.Close() 818 | 819 | if rows.Next() { 820 | t.Error("expected no rows but got one") 821 | } 822 | 823 | err = mock.ExpectationsWereMet() 824 | if err != nil { 825 | t.Fatalf("all expectations should be met: %s", err) 826 | } 827 | } 828 | 829 | // Based on issue #50 830 | func TestPrepareExpectationNotFulfilled(t *testing.T) { 831 | t.Parallel() 832 | mock, err := NewConn() 833 | if err != nil { 834 | t.Fatalf("an error '%s' was not expected when opening a stub database connection", err) 835 | } 836 | defer mock.Close(context.Background()) 837 | 838 | mock.ExpectPrepare("foo", "^BADSELECT$") 839 | 840 | if _, err := mock.Prepare(context.Background(), "foo", "SELECT"); err == nil { 841 | t.Fatal("prepare should not match expected query string") 842 | } 843 | 844 | if err := mock.ExpectationsWereMet(); err == nil { 845 | t.Errorf("was expecting an error, since prepared statement query does not match, but there was none") 846 | } 847 | } 848 | 849 | func TestRollbackThrow(t *testing.T) { 850 | // Open new mock database 851 | mock, err := NewConn() 852 | if err != nil { 853 | fmt.Println("error creating mock database") 854 | return 855 | } 856 | // columns to be used for result 857 | columns := []string{"id", "status"} 858 | // expect transaction begin 859 | mock.ExpectBegin() 860 | // expect query to fetch order, match it with regexp 861 
| mock.ExpectQuery("SELECT (.+) FROM orders (.+) FOR UPDATE"). 862 | WithArgs(1). 863 | WillReturnRows(NewRows(columns).AddRow(1, 1)) 864 | // expect transaction rollback, since order status is "cancelled" 865 | mock.ExpectRollback().WillReturnError(fmt.Errorf("rollback failed")) 866 | 867 | // run the cancel order function 868 | someOrderID := 1 869 | // call a function which executes expected database operations 870 | err = cancelOrder(mock, someOrderID) 871 | if err == nil { 872 | t.Error("an error was expected when rolling back transaction, but got none") 873 | } 874 | 875 | // ensure all expectations have been met 876 | if err = mock.ExpectationsWereMet(); err != nil { 877 | t.Errorf("unmet expectation error: %s", err) 878 | } 879 | // Output: 880 | } 881 | 882 | func TestUnexpectedBegin(t *testing.T) { 883 | // Open new mock database 884 | mock, err := NewConn() 885 | if err != nil { 886 | fmt.Println("error creating mock database") 887 | return 888 | } 889 | if _, err := mock.Begin(context.Background()); err == nil { 890 | t.Error("an error was expected when calling begin, but got none") 891 | } 892 | } 893 | 894 | func TestUnexpectedExec(t *testing.T) { 895 | // Open new mock database 896 | mock, err := NewConn() 897 | if err != nil { 898 | fmt.Println("error creating mock database") 899 | return 900 | } 901 | mock.ExpectBegin() 902 | _, _ = mock.Begin(context.Background()) 903 | if _, err := mock.Exec(context.Background(), "SELECT 1"); err == nil { 904 | t.Error("an error was expected when calling exec, but got none") 905 | } 906 | } 907 | 908 | func TestUnexpectedCommit(t *testing.T) { 909 | // Open new mock database 910 | mock, err := NewConn() 911 | if err != nil { 912 | fmt.Println("error creating mock database") 913 | return 914 | } 915 | mock.ExpectBegin() 916 | tx, _ := mock.Begin(context.Background()) 917 | if err := tx.Commit(context.Background()); err == nil { 918 | t.Error("an error was expected when calling commit, but got none") 919 | } 920 | 
} 921 | 922 | func TestUnexpectedCommitOrder(t *testing.T) { 923 | // Open new mock database 924 | mock, err := NewConn() 925 | if err != nil { 926 | fmt.Println("error creating mock database") 927 | return 928 | } 929 | mock.ExpectBegin() 930 | mock.ExpectRollback().WillReturnError(fmt.Errorf("Rollback failed")) 931 | tx, _ := mock.Begin(context.Background()) 932 | if err := tx.Commit(context.Background()); err == nil { 933 | t.Error("an error was expected when calling commit, but got none") 934 | } 935 | } 936 | 937 | func TestExpectedCommitOrder(t *testing.T) { 938 | // Open new mock database 939 | mock, err := NewConn() 940 | if err != nil { 941 | fmt.Println("error creating mock database") 942 | return 943 | } 944 | mock.ExpectCommit().WillReturnError(fmt.Errorf("Commit failed")) 945 | if _, err := mock.Begin(context.Background()); err == nil { 946 | t.Error("an error was expected when calling begin, but got none") 947 | } 948 | } 949 | 950 | func TestUnexpectedRollback(t *testing.T) { 951 | // Open new mock database 952 | mock, err := NewConn() 953 | if err != nil { 954 | fmt.Println("error creating mock database") 955 | return 956 | } 957 | mock.ExpectBegin() 958 | tx, _ := mock.Begin(context.Background()) 959 | if err := tx.Rollback(context.Background()); err == nil { 960 | t.Error("an error was expected when calling rollback, but got none") 961 | } 962 | } 963 | 964 | func TestUnexpectedRollbackOrder(t *testing.T) { 965 | // Open new mock database 966 | mock, err := NewConn() 967 | if err != nil { 968 | fmt.Println("error creating mock database") 969 | return 970 | } 971 | mock.ExpectBegin() 972 | 973 | tx, _ := mock.Begin(context.Background()) 974 | if err := tx.Rollback(context.Background()); err == nil { 975 | t.Error("an error was expected when calling rollback, but got none") 976 | } 977 | } 978 | 979 | func TestPrepareExec(t *testing.T) { 980 | // Open new mock database 981 | mock, err := NewConn() 982 | if err != nil { 983 | fmt.Println("error 
creating mock database") 984 | return 985 | } 986 | defer mock.Close(context.Background()) 987 | mock.ExpectBegin() 988 | mock.ExpectPrepare("foo", "INSERT INTO ORDERS\\(ID, STATUS\\) VALUES \\(\\?, \\?\\)") 989 | for i := 0; i < 3; i++ { 990 | mock.ExpectExec("foo").WithArgs(AnyArg(), AnyArg()).WillReturnResult(NewResult("UPDATE", 1)) 991 | } 992 | mock.ExpectCommit() 993 | tx, _ := mock.Begin(context.Background()) 994 | _, err = tx.Prepare(context.Background(), "foo", "INSERT INTO ORDERS(ID, STATUS) VALUES (?, ?)") 995 | if err != nil { 996 | t.Fatal(err) 997 | } 998 | for i := 0; i < 3; i++ { 999 | _, err := mock.Exec(context.Background(), "foo", i, "Hello"+strconv.Itoa(i)) 1000 | if err != nil { 1001 | t.Fatal(err) 1002 | } 1003 | } 1004 | _ = tx.Commit(context.Background()) 1005 | if err := mock.ExpectationsWereMet(); err != nil { 1006 | t.Errorf("there were unfulfilled expectations: %s", err) 1007 | } 1008 | } 1009 | 1010 | func TestPrepareQuery(t *testing.T) { 1011 | // Open new mock database 1012 | mock, err := NewConn() 1013 | if err != nil { 1014 | fmt.Println("error creating mock database") 1015 | return 1016 | } 1017 | defer mock.Close(context.Background()) 1018 | mock.ExpectBegin() 1019 | mock.ExpectPrepare("foo", "SELECT ID, STATUS FROM ORDERS WHERE ID = \\?") 1020 | mock.ExpectQuery("foo").WithArgs(101).WillReturnRows(NewRows([]string{"ID", "STATUS"}).AddRow(101, "Hello")) 1021 | mock.ExpectCommit() 1022 | tx, _ := mock.Begin(context.Background()) 1023 | _, err = tx.Prepare(context.Background(), "foo", "SELECT ID, STATUS FROM ORDERS WHERE ID = ?") 1024 | if err != nil { 1025 | t.Fatal(err) 1026 | } 1027 | rows, err := mock.Query(context.Background(), "foo", 101) 1028 | if err != nil { 1029 | t.Fatal(err) 1030 | } 1031 | defer rows.Close() 1032 | for rows.Next() { 1033 | var ( 1034 | id int 1035 | status string 1036 | ) 1037 | if _ = rows.Scan(&id, &status); id != 101 || status != "Hello" { 1038 | t.Fatal("wrong query results") 1039 | } 1040 | 1041 | 
} 1042 | _ = tx.Commit(context.Background()) 1043 | if err := mock.ExpectationsWereMet(); err != nil { 1044 | t.Errorf("there were unfulfilled expectations: %s", err) 1045 | } 1046 | } 1047 | 1048 | func TestExpectedCloseError(t *testing.T) { 1049 | // Open new mock database 1050 | mock, err := NewConn() 1051 | if err != nil { 1052 | fmt.Println("error creating mock database") 1053 | return 1054 | } 1055 | mock.ExpectClose().WillReturnError(fmt.Errorf("Close failed")) 1056 | if err := mock.Close(context.Background()); err == nil { 1057 | t.Error("an error was expected when calling close, but got none") 1058 | } 1059 | if err := mock.ExpectationsWereMet(); err != nil { 1060 | t.Errorf("there were unfulfilled expectations: %s", err) 1061 | } 1062 | } 1063 | 1064 | func TestExpectedCloseOrder(t *testing.T) { 1065 | // Open new mock database 1066 | mock, err := NewConn() 1067 | if err != nil { 1068 | fmt.Println("error creating mock database") 1069 | return 1070 | } 1071 | defer mock.Close(context.Background()) 1072 | mock.ExpectClose().WillReturnError(fmt.Errorf("Close failed")) 1073 | t.Log() 1074 | _, _ = mock.Begin(context.Background()) 1075 | if err := mock.ExpectationsWereMet(); err == nil { 1076 | t.Error("expected error on ExpectationsWereMet") 1077 | } 1078 | } 1079 | 1080 | func TestExpectedBeginOrder(t *testing.T) { 1081 | // Open new mock database 1082 | mock, err := NewConn() 1083 | if err != nil { 1084 | fmt.Println("error creating mock database") 1085 | return 1086 | } 1087 | mock.ExpectBegin().WillReturnError(fmt.Errorf("Begin failed")) 1088 | if err := mock.Close(context.Background()); err == nil { 1089 | t.Error("an error was expected when calling close, but got none") 1090 | } 1091 | } 1092 | 1093 | func TestPreparedStatementCloseExpectation(t *testing.T) { 1094 | t.Parallel() 1095 | mock, _ := NewConn() 1096 | a := assert.New(t) 1097 | 1098 | mock.ExpectPrepare("foo", "INSERT INTO ORDERS") 1099 | mock.ExpectExec("foo").WithArgs(AnyArg(), 
AnyArg()).WillReturnResult(NewResult("UPDATE", 1)) 1100 | mock.ExpectDeallocate("foo") 1101 | mock.ExpectDeallocateAll() 1102 | 1103 | stmt, err := mock.Prepare(context.Background(), "foo", "INSERT INTO ORDERS(ID, STATUS) VALUES (?, ?)") 1104 | a.NoError(err) 1105 | a.NotNil(stmt) 1106 | 1107 | _, err = mock.Exec(context.Background(), "foo", 1, "Hello") 1108 | a.NoError(err) 1109 | 1110 | err = mock.Deallocate(context.Background(), "baz") 1111 | a.Error(err, "wrong prepares stmt name should raise an error") 1112 | 1113 | err = mock.DeallocateAll(context.Background()) 1114 | a.Error(err, "we're expecting one statement deallocation, not all") 1115 | 1116 | err = mock.Ping(context.Background()) 1117 | a.Error(err, "ping should raise an error, we're expecting deallocate") 1118 | 1119 | err = mock.Deallocate(context.Background(), "foo") 1120 | a.NoError(err) 1121 | 1122 | err = mock.Ping(context.Background()) 1123 | a.Error(err, "ping should raise an error, we're expecting deallocate") 1124 | 1125 | err = mock.Deallocate(context.Background(), "baz") 1126 | a.Error(err, "wrong prepares stmt name should raise an error") 1127 | 1128 | err = mock.DeallocateAll(context.Background()) 1129 | a.NoError(err) 1130 | 1131 | if err := mock.ExpectationsWereMet(); err != nil { 1132 | t.Errorf("there were unfulfilled expectations: %s", err) 1133 | } 1134 | } 1135 | 1136 | func TestExecExpectationErrorDelay(t *testing.T) { 1137 | t.Parallel() 1138 | mock, err := NewConn() 1139 | if err != nil { 1140 | t.Fatalf("an error '%s' was not expected when opening a stub database connection", err) 1141 | } 1142 | defer mock.Close(context.Background()) 1143 | 1144 | // test that return of error is delayed 1145 | delay := time.Millisecond * 100 1146 | mock.ExpectExec("^INSERT INTO articles").WithArgs(AnyArg()). 1147 | WillDelayFor(delay). 
1148 | WillReturnError(errors.New("slow fail")) 1149 | 1150 | start := time.Now() 1151 | res, err := mock.Exec(context.Background(), "INSERT INTO articles (title) VALUES (?)", "hello") 1152 | stop := time.Now() 1153 | 1154 | if res.String() != "" { 1155 | t.Errorf("result was not expected, was expecting nil") 1156 | } 1157 | 1158 | if err == nil { 1159 | t.Errorf("error was expected, was not expecting nil") 1160 | } 1161 | 1162 | if err.Error() != "slow fail" { 1163 | t.Errorf("error '%s' was not expected, was expecting '%s'", err.Error(), "slow fail") 1164 | } 1165 | 1166 | elapsed := stop.Sub(start) 1167 | if elapsed < delay { 1168 | t.Errorf("expecting a delay of %v before error, actual delay was %v", delay, elapsed) 1169 | } 1170 | 1171 | // also test that return of error is not delayed 1172 | mock.ExpectExec("^INSERT INTO articles").WillReturnError(errors.New("fast fail")) 1173 | 1174 | start = time.Now() 1175 | _, _ = mock.Exec(context.Background(), "INSERT INTO articles (title) VALUES (?)", "hello") 1176 | stop = time.Now() 1177 | 1178 | elapsed = stop.Sub(start) 1179 | if elapsed > delay { 1180 | t.Errorf("expecting a delay of less than %v before error, actual delay was %v", delay, elapsed) 1181 | } 1182 | } 1183 | 1184 | func TestOptionsFail(t *testing.T) { 1185 | t.Parallel() 1186 | expected := errors.New("failing option") 1187 | option := func(*pgxmock) error { 1188 | return expected 1189 | } 1190 | mock, err := NewConn(option) 1191 | defer func() { _ = mock.Close(context.Background()) }() 1192 | if err == nil { 1193 | t.Errorf("missing expecting error '%s' when opening a stub database connection", expected) 1194 | } 1195 | } 1196 | 1197 | func TestNewRows(t *testing.T) { 1198 | t.Parallel() 1199 | mock, err := NewConn() 1200 | if err != nil { 1201 | t.Fatalf("an error '%s' was not expected when opening a stub database connection", err) 1202 | } 1203 | defer mock.Close(context.Background()) 1204 | columns := []string{"col1", "col2"} 1205 | 1206 | r := 
mock.NewRows(columns) 1207 | if len(r.defs) != len(columns) || string(r.defs[0].Name) != columns[0] || string(r.defs[1].Name) != columns[1] { 1208 | t.Errorf("expecting to create a row with columns %v, actual colmns are %v", r.defs, columns) 1209 | } 1210 | } 1211 | 1212 | // This is actually a test of ExpectationsWereMet. Without a lock around e.fulfilled() inside 1213 | // ExpectationWereMet, the race detector complains if e.triggered is being read while it is also 1214 | // being written by the query running in another goroutine. 1215 | func TestQueryWithTimeout(t *testing.T) { 1216 | mock, err := NewConn() 1217 | if err != nil { 1218 | t.Fatalf("an error '%s' was not expected when opening a stub database connection", err) 1219 | } 1220 | defer mock.Close(context.Background()) 1221 | 1222 | rs := NewRows([]string{"id", "title"}).FromCSVString("5,hello world") 1223 | 1224 | mock.ExpectQuery("SELECT (.+) FROM articles WHERE id = ?"). 1225 | WithArgs(5). 1226 | WillReturnRows(rs). 1227 | WillDelayFor(50 * time.Millisecond) // Query will take longer than timeout 1228 | 1229 | _, err = queryWithTimeout(10*time.Millisecond, mock, "SELECT (.+) FROM articles WHERE id = ?", 5) 1230 | if err == nil { 1231 | t.Errorf("expecting query to time out") 1232 | } 1233 | 1234 | if err := mock.ExpectationsWereMet(); err != nil { 1235 | t.Errorf("there were unfulfilled expectations: %s", err) 1236 | } 1237 | } 1238 | 1239 | func queryWithTimeout(t time.Duration, db PgxCommonIface, query string, args ...interface{}) (pgx.Rows, error) { 1240 | rowsChan := make(chan pgx.Rows, 1) 1241 | errChan := make(chan error, 1) 1242 | 1243 | go func() { 1244 | rows, err := db.Query(context.Background(), query, args...) 
1245 | if err != nil { 1246 | errChan <- err 1247 | return 1248 | } 1249 | rowsChan <- rows 1250 | }() 1251 | 1252 | select { 1253 | case rows := <-rowsChan: 1254 | return rows, nil 1255 | case err := <-errChan: 1256 | return nil, err 1257 | case <-time.After(t): 1258 | return nil, fmt.Errorf("query timed out after %v", t) 1259 | } 1260 | } 1261 | 1262 | func TestUnmockedMethods(t *testing.T) { 1263 | mock, _ := NewPool() 1264 | a := assert.New(t) 1265 | a.NotNil(mock.Config()) 1266 | a.NotNil(mock.AsConn().Config()) 1267 | a.NotNil(mock.AcquireAllIdle(ctx)) 1268 | a.Nil(mock.AcquireFunc(ctx, func(*pgxpool.Conn) error { return nil })) 1269 | a.Zero(mock.LargeObjects()) 1270 | a.Panics(func() { _ = mock.Conn() }) 1271 | } 1272 | 1273 | func TestNewRowsWithColumnDefinition(t *testing.T) { 1274 | mock, _ := NewConn() 1275 | a := assert.New(t) 1276 | a.NotNil(mock.PgConn()) 1277 | r := mock.NewRowsWithColumnDefinition(*mock.NewColumn("foo")) 1278 | a.Equal(1, len(r.defs)) 1279 | } 1280 | 1281 | func TestExpectReset(t *testing.T) { 1282 | mock, _ := NewPool() 1283 | a := assert.New(t) 1284 | // Successful scenario 1285 | mock.ExpectReset() 1286 | mock.Reset() 1287 | a.NoError(mock.ExpectationsWereMet()) 1288 | 1289 | // Unsuccessful scenario 1290 | mock.ExpectReset() 1291 | a.Error(mock.ExpectationsWereMet()) 1292 | } 1293 | 1294 | func TestDoubleUnlock(t *testing.T) { 1295 | mock, _ := NewConn() 1296 | mock.MatchExpectationsInOrder(false) 1297 | a := assert.New(t) 1298 | 1299 | mock.ExpectExec("insert").WillReturnResult(NewResult("ok", 1)) 1300 | mock.ExpectExec("update").WillReturnResult(NewResult("ok", 1)) 1301 | 1302 | _, err := mock.Exec(ctx, "foo") 1303 | a.Error(err) 1304 | a.NotPanics(func() { _ = mock.Ping(ctx) }) 1305 | } 1306 | -------------------------------------------------------------------------------- /query.go: -------------------------------------------------------------------------------- 1 | package pgxmock 2 | 3 | import ( 4 | "fmt" 5 | "regexp" 6 
// re matches any run of whitespace; stripQuery uses it to collapse such runs
// into a single space.
var re = regexp.MustCompile(`\s+`)

// stripQuery normalizes q for comparison: every run of whitespace (new lines
// included) becomes a single space and leading/trailing space is trimmed.
func stripQuery(q string) string {
	return strings.TrimSpace(re.ReplaceAllString(q, " "))
}

// QueryMatcher is an SQL query string matcher interface,
// which can be used to customize validation of SQL query strings.
// As an example, external library could be used to build
// and validate SQL ast, columns selected.
//
// A different QueryMatcher is selected with QueryMatcherOption when the mock
// is created, e.g. NewConn(QueryMatcherOption(QueryMatcherEqual)).
// The default QueryMatcher is QueryMatcherRegexp.
type QueryMatcher interface {

	// Match compares the expected SQL with the actual SQL and returns a
	// non-nil error when the actual SQL is not acceptable.
	Match(expectedSQL, actualSQL string) error
}

// QueryMatcherFunc type is an adapter to allow the use of
// ordinary functions as QueryMatcher. If f is a function
// with the appropriate signature, QueryMatcherFunc(f) is a
// QueryMatcher that calls f.
type QueryMatcherFunc func(expectedSQL, actualSQL string) error

// Match implements the QueryMatcher interface by calling f.
func (f QueryMatcherFunc) Match(expectedSQL, actualSQL string) error {
	return f(expectedSQL, actualSQL)
}

// QueryMatcherRegexp is the default SQL query matcher
// used by pgxmock. It compiles the whitespace-normalized expectedSQL as a
// regular expression and attempts to match the whitespace-normalized
// actualSQL against it. A compilation failure is returned as is.
var QueryMatcherRegexp QueryMatcher = QueryMatcherFunc(func(expectedSQL, actualSQL string) error {
	expect := stripQuery(expectedSQL)
	actual := stripQuery(actualSQL)
	// Named pattern rather than re so the package-level whitespace regexp is
	// not shadowed.
	pattern, err := regexp.Compile(expect)
	if err != nil {
		return err
	}
	if !pattern.MatchString(actual) {
		return fmt.Errorf(`could not match actual sql: "%s" with expected regexp "%s"`, actual, pattern.String())
	}
	return nil
})

// QueryMatcherEqual is the SQL query matcher
// which simply tries a case sensitive match of
// expected and actual SQL strings after whitespace normalization.
var QueryMatcherEqual QueryMatcher = QueryMatcherFunc(func(expectedSQL, actualSQL string) error {
	expect := stripQuery(expectedSQL)
	actual := stripQuery(actualSQL)
	if actual != expect {
		return fmt.Errorf(`actual sql: "%s" does not equal to expected "%s"`, actual, expect)
	}
	return nil
})

// QueryMatcherAny is the SQL query matcher
// which always returns nil, used to disable
// SQL query matching.
var QueryMatcherAny QueryMatcher = QueryMatcherFunc(func(_, _ string) error {
	return nil
})
20 | AddRow(2, "two") 21 | 22 | mock.ExpectQuery("SELECT * FROM users").WillReturnRows(rows) 23 | 24 | rs, err := mock.Query(context.Background(), "SELECT * FROM users") 25 | if err != nil { 26 | fmt.Println("failed to match expected query") 27 | return 28 | } 29 | defer rs.Close() 30 | 31 | for rs.Next() { 32 | var id int 33 | var title string 34 | _ = rs.Scan(&id, &title) 35 | fmt.Println("scanned id:", id, "and title:", title) 36 | } 37 | 38 | if rs.Err() != nil { 39 | fmt.Println("got rows error:", rs.Err()) 40 | } 41 | // Output: scanned id: 1 and title: one 42 | // scanned id: 2 and title: two 43 | } 44 | 45 | func ExampleQueryMatcherAny() { 46 | mock, err := NewConn(QueryMatcherOption(QueryMatcherAny)) 47 | if err != nil { 48 | fmt.Println("failed to open pgxmock database:", err) 49 | } 50 | // defer db.Close() 51 | 52 | rows := NewRows([]string{"id", "title"}). 53 | AddRow(1, "one"). 54 | AddRow(2, "two") 55 | 56 | mock.ExpectQuery("").WillReturnRows(rows) 57 | 58 | rs, err := mock.Query(context.Background(), "SELECT * FROM users") 59 | if err != nil { 60 | fmt.Println("failed to match expected query") 61 | return 62 | } 63 | defer rs.Close() 64 | 65 | for rs.Next() { 66 | var id int 67 | var title string 68 | _ = rs.Scan(&id, &title) 69 | fmt.Println("scanned id:", id, "and title:", title) 70 | } 71 | 72 | if rs.Err() != nil { 73 | fmt.Println("got rows error:", rs.Err()) 74 | } 75 | // Output: scanned id: 1 and title: one 76 | // scanned id: 2 and title: two 77 | } 78 | 79 | func TestQueryStringStripping(t *testing.T) { 80 | assert := func(actual, expected string) { 81 | if res := stripQuery(actual); res != expected { 82 | t.Errorf("Expected '%s' to be '%s', but got '%s'", actual, expected, res) 83 | } 84 | } 85 | 86 | assert(" SELECT 1", "SELECT 1") 87 | assert("SELECT 1 FROM d", "SELECT 1 FROM d") 88 | assert(` 89 | SELECT c 90 | FROM D 91 | `, "SELECT c FROM D") 92 | assert("UPDATE (.+) SET ", "UPDATE (.+) SET") 93 | } 94 | 95 | func 
TestQueryMatcherRegexp(t *testing.T) { 96 | type testCase struct { 97 | expected string 98 | actual string 99 | err error 100 | } 101 | 102 | cases := []testCase{ 103 | {"?\\l", "SEL", fmt.Errorf("error parsing regexp: missing argument to repetition operator: `?`")}, 104 | {"SELECT (.+) FROM users", "SELECT name, email FROM users WHERE id = ?", nil}, 105 | {"Select (.+) FROM users", "SELECT name, email FROM users WHERE id = ?", fmt.Errorf(`could not match actual sql: "SELECT name, email FROM users WHERE id = ?" with expected regexp "Select (.+) FROM users"`)}, 106 | {"SELECT (.+) FROM\nusers", "SELECT name, email\n FROM users\n WHERE id = ?", nil}, 107 | } 108 | 109 | for i, c := range cases { 110 | err := QueryMatcherRegexp.Match(c.expected, c.actual) 111 | if err == nil && c.err != nil { 112 | t.Errorf(`got no error, but expected "%v" at %d case`, c.err, i) 113 | continue 114 | } 115 | if err != nil && c.err == nil { 116 | t.Errorf(`got unexpected error "%v" at %d case`, err, i) 117 | continue 118 | } 119 | if err == nil { 120 | continue 121 | } 122 | if err.Error() != c.err.Error() { 123 | t.Errorf(`expected error "%v", but got "%v" at %d case`, c.err, err, i) 124 | } 125 | } 126 | } 127 | 128 | func TestQueryMatcherEqual(t *testing.T) { 129 | type testCase struct { 130 | expected string 131 | actual string 132 | err error 133 | } 134 | 135 | cases := []testCase{ 136 | {"SELECT name, email FROM users WHERE id = ?", "SELECT name, email\n FROM users\n WHERE id = ?", nil}, 137 | {"SELECT", "Select", fmt.Errorf(`actual sql: "Select" does not equal to expected "SELECT"`)}, 138 | {"SELECT from users", "SELECT from table", fmt.Errorf(`actual sql: "SELECT from table" does not equal to expected "SELECT from users"`)}, 139 | } 140 | 141 | for i, c := range cases { 142 | err := QueryMatcherEqual.Match(c.expected, c.actual) 143 | if err == nil && c.err != nil { 144 | t.Errorf(`got no error, but expected "%v" at %d case`, c.err, i) 145 | continue 146 | } 147 | if err != nil 
&& c.err == nil { 148 | t.Errorf(`got unexpected error "%v" at %d case`, err, i) 149 | continue 150 | } 151 | if err == nil { 152 | continue 153 | } 154 | if err.Error() != c.err.Error() { 155 | t.Errorf(`expected error "%v", but got "%v" at %d case`, c.err, err, i) 156 | } 157 | } 158 | } 159 | 160 | func TestQueryMatcherAny(t *testing.T) { 161 | err := QueryMatcherAny.Match("foo", "SELECT * FROM users") 162 | if err != nil { 163 | t.Errorf("expected no error, but got %v", err) 164 | } 165 | } 166 | -------------------------------------------------------------------------------- /result.go: -------------------------------------------------------------------------------- 1 | package pgxmock 2 | 3 | import ( 4 | "fmt" 5 | 6 | "github.com/jackc/pgx/v5/pgconn" 7 | ) 8 | 9 | // NewResult creates a new pgconn.CommandTag result 10 | // for Exec based query mocks. 11 | func NewResult(op string, rowsAffected int64) pgconn.CommandTag { 12 | return pgconn.NewCommandTag(fmt.Sprintf("%s %d", op, rowsAffected)) 13 | } 14 | -------------------------------------------------------------------------------- /result_test.go: -------------------------------------------------------------------------------- 1 | package pgxmock 2 | 3 | import ( 4 | "testing" 5 | ) 6 | 7 | func TestShouldReturnValidSqlDriverResult(t *testing.T) { 8 | result := NewResult("SELECT", 2) 9 | if !result.Select() { 10 | t.Errorf("expected SELECT operation result, but got: %v", result.String()) 11 | } 12 | affected := result.RowsAffected() 13 | if affected != 2 { 14 | t.Errorf("expected affected rows to be 2, but got: %d", affected) 15 | } 16 | } 17 | -------------------------------------------------------------------------------- /rows.go: -------------------------------------------------------------------------------- 1 | package pgxmock 2 | 3 | import ( 4 | "encoding/csv" 5 | "encoding/json" 6 | "errors" 7 | "fmt" 8 | "reflect" 9 | "strings" 10 | 11 | "github.com/jackc/pgx/v5" 12 | 
// CSVColumnParser is a function which converts a trimmed csv
// column string to the value stored in the mocked row. The literal NULL,
// compared case-insensitively, becomes nil; any other input is returned
// unchanged as a string.
var CSVColumnParser = func(s string) interface{} {
	// EqualFold avoids allocating a lower-cased copy of every column.
	if strings.EqualFold(s, "null") {
		return nil
	}
	return s
}
90 | func (rs *rowSets) close() { 91 | r := rs.sets[rs.RowSetNo] 92 | r.recNo = len(r.rows) 93 | r.nextErr[r.recNo-1] = r.closeErr 94 | r.closed = true 95 | } 96 | 97 | // advances to next row 98 | func (rs *rowSets) Next() bool { 99 | r := rs.sets[rs.RowSetNo] 100 | if r.recNo == len(r.rows) && r.nextErr[r.recNo] == nil { 101 | rs.close() 102 | return false 103 | } 104 | r.recNo++ 105 | return r.recNo <= len(r.rows) 106 | } 107 | 108 | // Values returns the decoded row values. As with Scan(), it is an error to 109 | // call Values without first calling Next() and checking that it returned 110 | // true. 111 | func (rs *rowSets) Values() ([]interface{}, error) { 112 | r := rs.sets[rs.RowSetNo] 113 | return r.rows[r.recNo-1], r.nextErr[r.recNo-1] 114 | } 115 | 116 | func (rs *rowSets) Scan(dest ...interface{}) error { 117 | r := rs.sets[rs.RowSetNo] 118 | if r.closed { 119 | // If there is no error, we should return one anyway. Weirdly, pgx returns 120 | // `number of field descriptions must equal number of values, got %d and %d`. 
121 | return r.nextErr[r.recNo-1] 122 | } 123 | if len(dest) == 1 { 124 | if rc, ok := dest[0].(pgx.RowScanner); ok { 125 | return rc.ScanRow(rs) 126 | } 127 | } 128 | if len(dest) != len(r.defs) { 129 | return fmt.Errorf("incorrect argument number %d for columns %d", len(dest), len(r.defs)) 130 | } 131 | if len(r.rows) == 0 { 132 | return pgx.ErrNoRows 133 | } 134 | for i, col := range r.rows[r.recNo-1] { 135 | if dest[i] == nil { 136 | //behave compatible with pgx 137 | continue 138 | } 139 | destVal := reflect.ValueOf(dest[i]) 140 | if destVal.Kind() != reflect.Ptr { 141 | return fmt.Errorf("destination argument must be a pointer for column %s", r.defs[i].Name) 142 | } 143 | if col == nil { 144 | dest[i] = nil 145 | continue 146 | } 147 | val := reflect.ValueOf(col) 148 | if _, ok := dest[i].(*interface{}); ok || val.Type().AssignableTo(destVal.Elem().Type()) { 149 | if destElem := destVal.Elem(); destElem.CanSet() { 150 | destElem.Set(val) 151 | } else { 152 | return fmt.Errorf("cannot set destination value for column %s", r.defs[i].Name) 153 | } 154 | } else if scanner, ok := destVal.Interface().(interface{ Scan(interface{}) error }); ok { 155 | // Try to use Scanner interface 156 | if err := scanner.Scan(val.Interface()); err != nil { 157 | return fmt.Errorf("scanning value error for column '%s': %w", string(r.defs[i].Name), err) 158 | } 159 | } else if val.CanConvert(destVal.Elem().Type()) { 160 | if destElem := destVal.Elem(); destElem.CanSet() { 161 | destElem.Set(val.Convert(destElem.Type())) 162 | } else { 163 | return fmt.Errorf("cannot set destination value for column %s", r.defs[i].Name) 164 | } 165 | } else { 166 | return fmt.Errorf("destination kind '%v' not supported for value kind '%v' of column '%s'", 167 | destVal.Elem().Kind(), val.Kind(), string(r.defs[i].Name)) 168 | } 169 | } 170 | return r.nextErr[r.recNo-1] 171 | } 172 | 173 | func (rs *rowSets) RawValues() [][]byte { 174 | r := rs.sets[rs.RowSetNo] 175 | dest := make([][]byte, 
len(r.defs)) 176 | 177 | for i, col := range r.rows[r.recNo-1] { 178 | if b, ok := rawBytes(col); ok { 179 | dest[i] = b 180 | continue 181 | } 182 | dest[i] = []byte(fmt.Sprintf("%v", col)) 183 | } 184 | 185 | return dest 186 | } 187 | 188 | // transforms to debuggable printable string 189 | func (rs *rowSets) String() string { 190 | if rs.empty() { 191 | return "\t- returns no data" 192 | } 193 | 194 | msg := "\t- returns data:\n" 195 | if len(rs.sets) == 1 { 196 | for n, row := range rs.sets[0].rows { 197 | msg += fmt.Sprintf("\t\trow %d - %+v\n", n, row) 198 | } 199 | return msg 200 | } 201 | for i, set := range rs.sets { 202 | msg += fmt.Sprintf("\t\tresult set: %d\n", i) 203 | for n, row := range set.rows { 204 | msg += fmt.Sprintf("\t\t\trow %d: %+v\n", n, row) 205 | } 206 | } 207 | return msg 208 | } 209 | 210 | func (rs *rowSets) empty() bool { 211 | for _, set := range rs.sets { 212 | if len(set.rows) > 0 { 213 | return false 214 | } 215 | } 216 | return true 217 | } 218 | 219 | func rawBytes(col interface{}) (_ []byte, ok bool) { 220 | val, err := json.Marshal(col) 221 | if err != nil || len(val) == 0 { 222 | return nil, false 223 | } 224 | // Copy the bytes from the mocked row into a shared raw buffer, which we'll replace the content of later 225 | b := make([]byte, len(val)) 226 | copy(b, val) 227 | return b, true 228 | } 229 | 230 | // Rows is a mocked collection of rows to 231 | // return for Query result 232 | type Rows struct { 233 | commandTag pgconn.CommandTag 234 | defs []pgconn.FieldDescription 235 | rows [][]interface{} 236 | recNo int 237 | nextErr map[int]error 238 | closeErr error 239 | closed bool 240 | } 241 | 242 | // NewRows allows Rows to be created from a 243 | // sql interface{} slice or from the CSV string and 244 | // to be used as sql driver.Rows. 
245 | // Use pgxmock.NewRows instead if using a custom converter 246 | func NewRows(columns []string) *Rows { 247 | var coldefs []pgconn.FieldDescription 248 | for _, column := range columns { 249 | coldefs = append(coldefs, pgconn.FieldDescription{Name: column}) 250 | } 251 | return &Rows{ 252 | defs: coldefs, 253 | nextErr: make(map[int]error), 254 | } 255 | } 256 | 257 | // CloseError sets an error which will be returned by [Rows.Err] after 258 | // [Rows.Close] has been called or [Rows.Next] returns false. 259 | func (r *Rows) CloseError(err error) *Rows { 260 | r.closeErr = err 261 | return r 262 | } 263 | 264 | // RowError allows to set an error 265 | // which will be returned when a given 266 | // row number is read 267 | func (r *Rows) RowError(row int, err error) *Rows { 268 | r.nextErr[row] = err 269 | return r 270 | } 271 | 272 | // AddRow composed from database interface{} slice 273 | // return the same instance to perform subsequent actions. 274 | // Note that the number of values must match the number 275 | // of columns 276 | func (r *Rows) AddRow(values ...any) *Rows { 277 | if len(values) != len(r.defs) { 278 | panic("Expected number of values to match number of columns") 279 | } 280 | 281 | row := make([]interface{}, len(r.defs)) 282 | copy(row, values) 283 | r.rows = append(r.rows, row) 284 | return r 285 | } 286 | 287 | // AddRows adds multiple rows composed from any slice and 288 | // returns the same instance to perform subsequent actions. 289 | func (r *Rows) AddRows(values ...[]any) *Rows { 290 | for _, value := range values { 291 | r.AddRow(value...) 292 | } 293 | return r 294 | } 295 | 296 | // AddCommandTag will add a command tag to the result set 297 | func (r *Rows) AddCommandTag(tag pgconn.CommandTag) *Rows { 298 | r.commandTag = tag 299 | return r 300 | } 301 | 302 | // FromCSVString build rows from csv string. 303 | // return the same instance to perform subsequent actions. 
304 | // Note that the number of values must match the number 305 | // of columns 306 | func (r *Rows) FromCSVString(s string) *Rows { 307 | res := strings.NewReader(strings.TrimSpace(s)) 308 | csvReader := csv.NewReader(res) 309 | 310 | for { 311 | res, err := csvReader.Read() 312 | if err != nil || res == nil { 313 | break 314 | } 315 | 316 | row := make([]interface{}, len(r.defs)) 317 | for i, v := range res { 318 | row[i] = CSVColumnParser(strings.TrimSpace(v)) 319 | } 320 | r.rows = append(r.rows, row) 321 | } 322 | return r 323 | } 324 | 325 | // Kind returns rows corresponding to the interface pgx.Rows 326 | // useful for testing entities that implement an interface pgx.RowScanner 327 | func (r *Rows) Kind() pgx.Rows { 328 | return &rowSets{ 329 | sets: []*Rows{r}, 330 | } 331 | } 332 | 333 | // NewRowsWithColumnDefinition return rows with columns metadata 334 | func NewRowsWithColumnDefinition(columns ...pgconn.FieldDescription) *Rows { 335 | return &Rows{ 336 | defs: columns, 337 | nextErr: make(map[int]error), 338 | } 339 | } 340 | -------------------------------------------------------------------------------- /rows_test.go: -------------------------------------------------------------------------------- 1 | package pgxmock 2 | 3 | import ( 4 | "context" 5 | "encoding/json" 6 | "errors" 7 | "fmt" 8 | "testing" 9 | 10 | "github.com/jackc/pgx/v5" 11 | "github.com/jackc/pgx/v5/pgconn" 12 | "github.com/jackc/pgx/v5/pgtype" 13 | "github.com/stretchr/testify/assert" 14 | ) 15 | 16 | func TestPointerToInterfaceArgument(t *testing.T) { 17 | mock, err := NewPool() 18 | if err != nil { 19 | panic(err) 20 | } 21 | 22 | mock.ExpectQuery(`SELECT 123`). 23 | WillReturnRows( 24 | mock.NewRows([]string{"id"}). 
25 | AddRow(int64(123))) // Value which should be scanned in *interface{} 26 | 27 | var value interface{} 28 | err = mock.QueryRow(context.Background(), `SELECT 123`).Scan(&value) 29 | if err != nil || value.(int64) != 123 { 30 | t.Error(err) 31 | } 32 | 33 | } 34 | 35 | func TestExplicitTypeCasting(t *testing.T) { 36 | mock, err := NewPool() 37 | if err != nil { 38 | panic(err) 39 | } 40 | 41 | mock.ExpectQuery("SELECT .+ FROM test WHERE .+"). 42 | WithArgs(uint64(1)). 43 | WillReturnRows(NewRows( 44 | []string{"id"}). 45 | AddRow(uint64(1)), 46 | ) 47 | 48 | rows := mock.QueryRow(context.Background(), "SELECT id FROM test WHERE id = $1", uint64(1)) 49 | 50 | var id uint64 51 | err = rows.Scan(&id) 52 | if err != nil { 53 | t.Error(err) 54 | } 55 | } 56 | 57 | func TestAddRows(t *testing.T) { 58 | t.Parallel() 59 | mock, err := NewConn() 60 | if err != nil { 61 | t.Fatal("failed to open sqlmock database:", err) 62 | } 63 | defer mock.Close(context.Background()) 64 | 65 | values := [][]any{ 66 | { 67 | 1, "John", 68 | }, 69 | { 70 | 2, "Jane", 71 | }, 72 | { 73 | 3, "Peter", 74 | }, 75 | { 76 | 4, "Emily", 77 | }, 78 | } 79 | 80 | rows := NewRows([]string{"id", "name"}).AddRows(values...) 
81 | mock.ExpectQuery("SELECT").WillReturnRows(rows).RowsWillBeClosed() 82 | 83 | rs, _ := mock.Query(context.Background(), "SELECT") 84 | defer rs.Close() 85 | 86 | for rs.Next() { 87 | var id int 88 | var name string 89 | _ = rs.Scan(&id, &name) 90 | fmt.Println("scanned id:", id, "and name:", name) 91 | } 92 | 93 | if rs.Err() != nil { 94 | fmt.Println("got rows error:", rs.Err()) 95 | } 96 | // Output: scanned id: 1 and title: John 97 | // scanned id: 2 and title: Jane 98 | // scanned id: 3 and title: Peter 99 | // scanned id: 4 and title: Emily 100 | } 101 | 102 | func ExampleRows_AddRows() { 103 | mock, err := NewConn() 104 | if err != nil { 105 | fmt.Println("failed to open sqlmock database:", err) 106 | return 107 | } 108 | defer mock.Close(context.Background()) 109 | 110 | values := [][]any{ 111 | { 112 | 1, "one", 113 | }, 114 | { 115 | 2, "two", 116 | }, 117 | } 118 | 119 | rows := NewRows([]string{"id", "title"}).AddRows(values...) 120 | 121 | mock.ExpectQuery("SELECT").WillReturnRows(rows) 122 | 123 | rs, _ := mock.Query(context.Background(), "SELECT") 124 | defer rs.Close() 125 | 126 | for rs.Next() { 127 | var id int 128 | var title string 129 | _ = rs.Scan(&id, &title) 130 | fmt.Println("scanned id:", id, "and title:", title) 131 | } 132 | 133 | if rs.Err() != nil { 134 | fmt.Println("got rows error:", rs.Err()) 135 | } 136 | // Output: scanned id: 1 and title: one 137 | // scanned id: 2 and title: two 138 | } 139 | 140 | func ExampleRows() { 141 | mock, err := NewConn() 142 | if err != nil { 143 | fmt.Println("failed to open pgxmock database:", err) 144 | return 145 | } 146 | defer mock.Close(context.Background()) 147 | 148 | rows := NewRows([]string{"id", "title"}). 149 | AddRow(1, "one"). 150 | AddRow(2, "two"). 
151 | AddCommandTag(pgconn.NewCommandTag("SELECT 2")) 152 | 153 | mock.ExpectQuery("SELECT").WillReturnRows(rows) 154 | 155 | rs, _ := mock.Query(context.Background(), "SELECT") 156 | defer rs.Close() 157 | 158 | fmt.Println("command tag:", rs.CommandTag()) 159 | if len(rs.FieldDescriptions()) != 2 { 160 | fmt.Println("got wrong number of fields") 161 | } 162 | 163 | for rs.Next() { 164 | var id int 165 | var title string 166 | _ = rs.Scan(&id, &title) 167 | fmt.Println("scanned id:", id, "and title:", title) 168 | } 169 | 170 | if rs.Err() != nil { 171 | fmt.Println("got rows error:", rs.Err()) 172 | } 173 | 174 | // Output: command tag: SELECT 2 175 | // scanned id: 1 and title: one 176 | // scanned id: 2 and title: two 177 | } 178 | 179 | func ExampleRows_rowError() { 180 | mock, err := NewConn() 181 | if err != nil { 182 | fmt.Println("failed to open pgxmock database:", err) 183 | return 184 | } 185 | // defer mock.Close(context.Background()) 186 | 187 | rows := NewRows([]string{"id", "title"}). 188 | AddRow(0, "one"). 189 | AddRow(1, "two"). 
190 | RowError(1, fmt.Errorf("row error")) 191 | mock.ExpectQuery("SELECT").WillReturnRows(rows) 192 | 193 | rs, _ := mock.Query(context.Background(), "SELECT") 194 | defer rs.Close() 195 | 196 | for rs.Next() { 197 | var id int 198 | var title string 199 | _ = rs.Scan(&id, &title) 200 | fmt.Println("scanned id:", id, "and title:", title) 201 | if rs.Err() != nil { 202 | fmt.Println("got rows error:", rs.Err()) 203 | } 204 | } 205 | 206 | // Output: scanned id: 0 and title: one 207 | // scanned id: 1 and title: two 208 | // got rows error: row error 209 | } 210 | 211 | func ExampleRows_expectToBeClosed() { 212 | mock, err := NewConn() 213 | if err != nil { 214 | fmt.Println("failed to open pgxmock database:", err) 215 | return 216 | } 217 | defer mock.Close(context.Background()) 218 | 219 | row := NewRows([]string{"id", "title"}).AddRow(1, "john") 220 | rows := NewRowsWithColumnDefinition( 221 | pgconn.FieldDescription{Name: "id"}, 222 | pgconn.FieldDescription{Name: "title"}). 223 | AddRow(1, "john").AddRow(2, "anna") 224 | mock.ExpectQuery("SELECT").WillReturnRows(row, rows).RowsWillBeClosed() 225 | 226 | _, _ = mock.Query(context.Background(), "SELECT") 227 | _, _ = mock.Query(context.Background(), "SELECT") 228 | 229 | if err := mock.ExpectationsWereMet(); err != nil { 230 | fmt.Println("got error:", err) 231 | } 232 | 233 | /*Output: got error: expected query rows to be closed, but it was not: ExpectedQuery => expecting call to Query() or to QueryRow(): 234 | - matches sql: 'SELECT' 235 | - is without arguments 236 | - returns data: 237 | result set: 0 238 | row 0: [1 john] 239 | result set: 1 240 | row 0: [1 john] 241 | row 1: [2 anna] 242 | */ 243 | } 244 | 245 | func ExampleRows_customDriverValue() { 246 | mock, err := NewConn() 247 | if err != nil { 248 | fmt.Println("failed to open pgxmock database:", err) 249 | return 250 | } 251 | defer mock.Close(context.Background()) 252 | 253 | rows := NewRows([]string{"id", "null_int"}). 
254 | AddRow(5, pgtype.Int8{Int64: 5, Valid: true}). 255 | AddRow(2, pgtype.Int8{Valid: false}) 256 | 257 | mock.ExpectQuery("SELECT").WillReturnRows(rows) 258 | 259 | rs, _ := mock.Query(context.Background(), "SELECT") 260 | defer rs.Close() 261 | 262 | for rs.Next() { 263 | var id int 264 | var num pgtype.Int8 265 | _ = rs.Scan(&id, &num) 266 | fmt.Println("scanned id:", id, "and null int64:", num) 267 | } 268 | 269 | if rs.Err() != nil { 270 | fmt.Println("got rows error:", rs.Err()) 271 | } 272 | // Output: scanned id: 5 and null int64: {5 true} 273 | // scanned id: 2 and null int64: {0 false} 274 | } 275 | 276 | func TestAllowsToSetRowsErrors(t *testing.T) { 277 | t.Parallel() 278 | mock, err := NewConn() 279 | if err != nil { 280 | t.Fatalf("an error '%s' was not expected when opening a stub database connection", err) 281 | } 282 | defer mock.Close(context.Background()) 283 | 284 | rows := NewRows([]string{"id", "title"}). 285 | AddRow(0, "one"). 286 | AddRow(1, "two"). 287 | RowError(1, fmt.Errorf("error")) 288 | mock.ExpectQuery("SELECT").WillReturnRows(rows) 289 | 290 | rs, err := mock.Query(context.Background(), "SELECT") 291 | if err != nil { 292 | t.Fatalf("unexpected error: %s", err) 293 | } 294 | defer rs.Close() 295 | 296 | if !rs.Next() { 297 | t.Fatal("expected the first row to be available") 298 | } 299 | if rs.Err() != nil { 300 | t.Fatalf("unexpected error: %s", rs.Err()) 301 | } 302 | 303 | if !rs.Next() { 304 | t.Fatal("expected the second row to be available, even there should be an error") 305 | } 306 | if rs.Err() == nil { 307 | t.Fatal("expected an error, but got none") 308 | } 309 | 310 | if err := mock.ExpectationsWereMet(); err != nil { 311 | t.Fatal(err) 312 | } 313 | } 314 | 315 | func TestOneRowError(t *testing.T) { 316 | t.Parallel() 317 | mock, err := NewConn() 318 | if err != nil { 319 | t.Fatalf("an error '%s' was not expected when opening a stub database connection", err) 320 | } 321 | defer mock.Close(context.Background()) 322 
| 323 | rows := NewRows([]string{"id", "title"}). 324 | RowError(0, fmt.Errorf("error")) 325 | mock.ExpectQuery("SELECT").WillReturnRows(rows) 326 | 327 | rs, err := mock.Query(context.Background(), "SELECT") 328 | if err != nil { 329 | t.Fatalf("unexpected error: %s", err) 330 | } 331 | defer rs.Close() 332 | 333 | if rs.Next() { 334 | t.Fatal("expected the first row not to be available") 335 | } 336 | if rs.Err() == nil { 337 | t.Fatalf("unexpected success") 338 | } 339 | 340 | if err := mock.ExpectationsWereMet(); err != nil { 341 | t.Fatal(err) 342 | } 343 | } 344 | 345 | func TestNoRowsCloseError(t *testing.T) { 346 | t.Parallel() 347 | mock, err := NewConn() 348 | if err != nil { 349 | t.Fatalf("an error '%s' was not expected when opening a stub database connection", err) 350 | } 351 | defer mock.Close(context.Background()) 352 | 353 | rows := NewRows([]string{"id"}).CloseError(fmt.Errorf("close error")) 354 | mock.ExpectQuery("SELECT").WillReturnRows(rows) 355 | 356 | rs, err := mock.Query(context.Background(), "SELECT") 357 | if err != nil { 358 | t.Fatalf("unexpected error: %s", err) 359 | } 360 | rs.Close() 361 | 362 | if err := mock.ExpectationsWereMet(); err != nil { 363 | t.Fatal(err) 364 | } 365 | } 366 | 367 | func TestRowsCloseError(t *testing.T) { 368 | t.Parallel() 369 | mock, err := NewConn() 370 | if err != nil { 371 | t.Fatalf("an error '%s' was not expected when opening a stub database connection", err) 372 | } 373 | defer mock.Close(context.Background()) 374 | 375 | rows := NewRows([]string{"id"}).AddRow(1).AddRow(2).CloseError(fmt.Errorf("close error")) 376 | mock.ExpectQuery("SELECT").WillReturnRows(rows) 377 | 378 | rs, err := mock.Query(context.Background(), "SELECT") 379 | if err != nil { 380 | t.Fatalf("unexpected error: %s", err) 381 | } 382 | defer rs.Close() 383 | 384 | total := 0 385 | for rs.Next() { 386 | var id int 387 | if err := rs.Scan(&id); err != nil { 388 | t.Fatalf("unexpected error: %s", err) 389 | } 390 | total += id 391 
| } 392 | 393 | if total != 3 { 394 | t.Fatalf("expected id sum of 3, got %d", total) 395 | } 396 | 397 | if rs.Err() == nil { 398 | t.Fatal("expected an error, but got none") 399 | } 400 | } 401 | 402 | func TestRowsCloseEarlyError(t *testing.T) { 403 | t.Parallel() 404 | mock, err := NewConn() 405 | if err != nil { 406 | t.Fatalf("an error '%s' was not expected when opening a stub database connection", err) 407 | } 408 | defer mock.Close(context.Background()) 409 | 410 | rows := NewRows([]string{"id"}).AddRow(1).AddRow(2).CloseError(fmt.Errorf("close error")) 411 | mock.ExpectQuery("SELECT").WillReturnRows(rows) 412 | 413 | rs, err := mock.Query(context.Background(), "SELECT") 414 | if err != nil { 415 | t.Fatalf("unexpected error: %s", err) 416 | } 417 | 418 | // Fetch the first row. 419 | if !rs.Next() { 420 | t.Fatal("unexpected false Next()") 421 | } 422 | var id int 423 | if err := rs.Scan(&id); err != nil { 424 | t.Fatalf("unexpected error: %s", err) 425 | } 426 | if id != 1 { 427 | t.Fatalf("expected id to be 1, got %d", id) 428 | } 429 | 430 | // Close before fetching the next row. 431 | rs.Close() 432 | 433 | // The close error should be set. 434 | if rs.Err() == nil { 435 | t.Fatal("expected an error, but got none") 436 | } 437 | 438 | // Next should be false now. 439 | if rs.Next() { 440 | t.Fatal("unexpected true Next()") 441 | } 442 | 443 | // Scan should return the error. 444 | id = -1 445 | if err := rs.Scan(&id); err == nil { 446 | t.Fatal("expected an error, but got none") 447 | } 448 | if id != -1 { 449 | t.Fatalf("expected no id but got %v", id) 450 | } 451 | 452 | // The close error should be set. 
453 | if rs.Err() == nil { 454 | t.Fatal("expected an error, but got none") 455 | } 456 | } 457 | 458 | func TestRowsClosed(t *testing.T) { 459 | t.Parallel() 460 | mock, err := NewConn() 461 | if err != nil { 462 | t.Fatalf("an error '%s' was not expected when opening a stub database connection", err) 463 | } 464 | defer mock.Close(context.Background()) 465 | 466 | rows := NewRows([]string{"id"}).AddRow(1) 467 | mock.ExpectQuery("SELECT").WillReturnRows(rows).RowsWillBeClosed() 468 | 469 | rs, err := mock.Query(context.Background(), "SELECT") 470 | if err != nil { 471 | t.Fatalf("unexpected error: %s", err) 472 | } 473 | rs.Close() 474 | 475 | if err := mock.ExpectationsWereMet(); err != nil { 476 | t.Fatal(err) 477 | } 478 | } 479 | 480 | func TestQuerySingleRow(t *testing.T) { 481 | t.Parallel() 482 | mock, err := NewConn() 483 | if err != nil { 484 | t.Fatalf("an error '%s' was not expected when opening a stub database connection", err) 485 | } 486 | defer mock.Close(context.Background()) 487 | 488 | rows := NewRows([]string{"id"}). 489 | AddRow(1). 490 | AddRow(2) 491 | mock.ExpectQuery("SELECT").WillReturnRows(rows) 492 | 493 | var id int 494 | if err := mock.QueryRow(context.Background(), "SELECT").Scan(&id); err != nil { 495 | t.Fatalf("unexpected error: %s", err) 496 | } 497 | 498 | mock.ExpectQuery("SELECT").WillReturnRows(NewRows([]string{"id"})) 499 | if err := mock.QueryRow(context.Background(), "SELECT").Scan(&id); err != pgx.ErrNoRows { 500 | t.Fatal("expected sql no rows error") 501 | } 502 | 503 | if err := mock.ExpectationsWereMet(); err != nil { 504 | t.Fatal(err) 505 | } 506 | } 507 | 508 | func ExampleRows_values() { 509 | mock, err := NewConn() 510 | if err != nil { 511 | fmt.Println("failed to open pgxmock database:", err) 512 | return 513 | } 514 | defer mock.Close(context.Background()) 515 | 516 | rows := NewRows([]string{"raw"}). 517 | AddRow(`one string value with some text!`). 
518 | AddRow(`two string value with even more text than the first one`). 519 | AddRow([]byte{}) 520 | mock.ExpectQuery("SELECT").WillReturnRows(rows) 521 | 522 | rs, err := mock.Query(context.Background(), "SELECT") 523 | if err != nil { 524 | fmt.Print(err) 525 | return 526 | } 527 | defer rs.Close() 528 | 529 | for rs.Next() { 530 | v, e := rs.Values() 531 | fmt.Println(v[0], e) 532 | } 533 | // Output: one string value with some text! 534 | // two string value with even more text than the first one 535 | // [] 536 | } 537 | 538 | func ExampleRows_rawValues() { 539 | mock, err := NewConn() 540 | if err != nil { 541 | fmt.Println("failed to open pgxmock database:", err) 542 | return 543 | } 544 | defer mock.Close(context.Background()) 545 | 546 | rows := NewRows([]string{"raw"}). 547 | AddRow([]byte(`one binary value with some text!`)). 548 | AddRow([]byte(`two binary value with even more text than the first one`)). 549 | AddRow([]byte{}) 550 | mock.ExpectQuery("SELECT").WillReturnRows(rows) 551 | 552 | rs, err := mock.Query(context.Background(), "SELECT") 553 | if err != nil { 554 | fmt.Print(err) 555 | return 556 | } 557 | defer rs.Close() 558 | 559 | for rs.Next() { 560 | var rawValue []byte 561 | if err := json.Unmarshal(rs.RawValues()[0], &rawValue); err != nil { 562 | fmt.Print(err) 563 | } 564 | fmt.Println(string(rawValue)) 565 | } 566 | // Output: one binary value with some text! 
567 | // two binary value with even more text than the first one 568 | // 569 | } 570 | 571 | func TestRowsScanError(t *testing.T) { 572 | t.Parallel() 573 | mock, err := NewConn() 574 | if err != nil { 575 | t.Fatalf("an error '%s' was not expected when opening a stub database connection", err) 576 | } 577 | defer mock.Close(context.Background()) 578 | 579 | r := NewRows([]string{"col1", "col2"}).AddRow("one", "two").AddRow("one", nil) 580 | mock.ExpectQuery("SELECT").WillReturnRows(r) 581 | 582 | rs, err := mock.Query(context.Background(), "SELECT") 583 | if err != nil { 584 | t.Fatalf("unexpected error: %s", err) 585 | } 586 | defer rs.Close() 587 | 588 | var one, two string 589 | if !rs.Next() || rs.Err() != nil || rs.Scan(&one, &two) != nil { 590 | t.Fatal("unexpected error on first row scan") 591 | } 592 | 593 | if !rs.Next() || rs.Err() != nil { 594 | t.Fatal("unexpected error on second row read") 595 | } 596 | 597 | err = rs.Scan(&one, two) 598 | if err == nil { 599 | t.Fatal("expected an error for scan, but got none") 600 | } 601 | 602 | if err := mock.ExpectationsWereMet(); err != nil { 603 | t.Fatal(err) 604 | } 605 | } 606 | 607 | type testScanner struct { 608 | Value int64 609 | } 610 | 611 | func (s *testScanner) Scan(src interface{}) error { 612 | switch src := src.(type) { 613 | case int64: 614 | s.Value = src 615 | return nil 616 | default: 617 | return errors.New("a dummy scan error") 618 | } 619 | } 620 | 621 | func TestRowsScanWithScannerIface(t *testing.T) { 622 | mock, err := NewConn() 623 | if err != nil { 624 | t.Fatalf("an error '%s' was not expected when opening a stub database connection", err) 625 | } 626 | defer mock.Close(context.Background()) 627 | 628 | r := NewRows([]string{"col1"}).AddRow(int64(23)) 629 | mock.ExpectQuery("SELECT").WillReturnRows(r) 630 | 631 | rs, err := mock.Query(context.Background(), "SELECT") 632 | if err != nil { 633 | t.Fatalf("unexpected error: %s", err) 634 | } 635 | 636 | var result testScanner 637 | if 
!rs.Next() || rs.Err() != nil { 638 | t.Fatal("unexpected error on first row read") 639 | } 640 | if rs.Scan(&result) != nil { 641 | t.Fatal("unexpected error for scan") 642 | } 643 | 644 | if result.Value != int64(23) { 645 | t.Fatalf("expected Value to be 23 but got: %d", result.Value) 646 | } 647 | 648 | } 649 | 650 | func TestRowsScanErrorOnScannerIface(t *testing.T) { 651 | mock, err := NewConn() 652 | if err != nil { 653 | t.Fatalf("an error '%s' was not expected when opening a stub database connection", err) 654 | } 655 | defer mock.Close(context.Background()) 656 | 657 | r := NewRows([]string{"col1"}).AddRow("one").AddRow("two") 658 | mock.ExpectQuery("SELECT").WillReturnRows(r) 659 | 660 | rs, err := mock.Query(context.Background(), "SELECT") 661 | if err != nil { 662 | t.Fatalf("unexpected error: %s", err) 663 | } 664 | 665 | var one int64 // No scanner interface 666 | var two testScanner // scanner error 667 | if !rs.Next() || rs.Err() != nil { 668 | t.Fatal("unexpected error on first row read") 669 | } 670 | if rs.Scan(&one) == nil { 671 | t.Fatal("expected an error for first scan (no scanner interface), but got none") 672 | } 673 | 674 | if !rs.Next() || rs.Err() != nil { 675 | t.Fatal("unexpected error on second row read") 676 | } 677 | 678 | err = rs.Scan(&two) 679 | if err == nil { 680 | t.Fatal("expected an error for second scan (scanner error), but got none") 681 | } 682 | } 683 | 684 | func TestCSVRowParser(t *testing.T) { 685 | t.Parallel() 686 | rs := NewRows([]string{"col1", "col2"}).FromCSVString("a,NULL") 687 | mock, err := NewConn() 688 | if err != nil { 689 | t.Fatalf("an error '%s' was not expected when opening a stub database connection", err) 690 | } 691 | defer mock.Close(context.Background()) 692 | 693 | mock.ExpectQuery("SELECT").WillReturnRows(rs) 694 | 695 | rw, err := mock.Query(context.Background(), "SELECT") 696 | if err != nil { 697 | t.Fatalf("unexpected error: %s", err) 698 | } 699 | defer rw.Close() 700 | var col1 string 701 
| var col2 []byte 702 | 703 | rw.Next() 704 | if err = rw.Scan(&col1, &col2); err != nil { 705 | t.Fatalf("unexpected error: %s", err) 706 | } 707 | if col1 != "a" { 708 | t.Fatalf("expected col1 to be 'a', but got [%T]:%+v", col1, col1) 709 | } 710 | if col2 != nil { 711 | t.Fatalf("expected col2 to be nil, but got [%T]:%+v", col2, col2) 712 | } 713 | } 714 | 715 | func TestWrongNumberOfValues(t *testing.T) { 716 | // Open new mock database 717 | mock, err := NewConn() 718 | if err != nil { 719 | fmt.Println("error creating mock database") 720 | return 721 | } 722 | defer mock.Close(context.Background()) 723 | defer func() { 724 | _ = recover() 725 | }() 726 | mock.ExpectQuery("SELECT ID FROM TABLE").WithArgs(101).WillReturnRows(NewRows([]string{"ID"}).AddRow(101, "Hello")) 727 | _, _ = mock.Query(context.Background(), "SELECT ID FROM TABLE", 101) 728 | // shouldn't reach here 729 | t.Error("expected panic from query") 730 | } 731 | 732 | func TestEmptyRowSets(t *testing.T) { 733 | rs1 := NewRows([]string{"a"}).AddRow("a") 734 | rs2 := NewRows([]string{"b"}) 735 | rs3 := NewRows([]string{"c"}) 736 | 737 | set1 := &rowSets{sets: []*Rows{rs1, rs2}} 738 | set2 := &rowSets{sets: []*Rows{rs3, rs2}} 739 | set3 := &rowSets{sets: []*Rows{rs2}} 740 | 741 | if set1.empty() { 742 | t.Fatalf("expected rowset 1, not to be empty, but it was") 743 | } 744 | if !set2.empty() { 745 | t.Fatalf("expected rowset 2, to be empty, but it was not") 746 | } 747 | if !set3.empty() { 748 | t.Fatalf("expected rowset 3, to be empty, but it was not") 749 | } 750 | } 751 | 752 | func TestMockQueryWithCollect(t *testing.T) { 753 | t.Parallel() 754 | mock, err := NewConn() 755 | if err != nil { 756 | t.Fatalf("an error '%s' was not expected when opening a stub database connection", err) 757 | } 758 | defer mock.Close(context.Background()) 759 | type rowStructType struct { 760 | ID int 761 | Title string 762 | } 763 | rs := NewRows([]string{"id", "title"}).AddRow(5, "hello world") 764 | 765 | 
mock.ExpectQuery("SELECT (.+) FROM articles WHERE id = ?"). 766 | WithArgs(5). 767 | WillReturnRows(rs) 768 | 769 | rows, err := mock.Query(context.Background(), "SELECT (.+) FROM articles WHERE id = ?", 5) 770 | if err != nil { 771 | t.Fatalf("error '%s' was not expected while retrieving mock rows", err) 772 | } 773 | 774 | defer rows.Close() 775 | 776 | rawMap, err := pgx.CollectRows(rows, pgx.RowToAddrOfStructByPos[rowStructType]) 777 | if err != nil { 778 | t.Fatalf("error '%s' was not expected while trying to collect rows", err) 779 | } 780 | 781 | var id = rawMap[0].ID 782 | var title = rawMap[0].Title 783 | 784 | if err != nil { 785 | t.Fatalf("error '%s' was not expected while trying to scan row", err) 786 | } 787 | 788 | if id != 5 { 789 | t.Errorf("expected mocked id to be 5, but got %d instead", id) 790 | } 791 | 792 | if title != "hello world" { 793 | t.Errorf("expected mocked title to be 'hello world', but got '%s' instead", title) 794 | } 795 | 796 | if err := mock.ExpectationsWereMet(); err != nil { 797 | t.Errorf("there were unfulfilled expectations: %s", err) 798 | } 799 | } 800 | 801 | func TestRowsConn(t *testing.T) { 802 | assert.Nil(t, (&rowSets{}).Conn()) 803 | } 804 | 805 | func TestRowsKind(t *testing.T) { 806 | var alphabet = []string{"a", "b", "c", "d", "e", "f"} 807 | rows := NewRows([]string{"id", "alphabet"}) 808 | 809 | for id, b := range alphabet { 810 | rows.AddRow(id, b) 811 | } 812 | 813 | kindRows := rows.Kind() 814 | 815 | for i := 0; kindRows.Next(); i++ { 816 | var ( 817 | letter string 818 | index int 819 | ) 820 | if err := kindRows.Scan(&index, &letter); err != nil { 821 | t.Fatalf("unexpected error: %s", err) 822 | } 823 | 824 | if index != i { 825 | t.Fatalf("expected %d, but got %d", i, index) 826 | } 827 | 828 | if letter != alphabet[i] { 829 | t.Fatalf("expected %s, but got %s", alphabet[i], letter) 830 | } 831 | } 832 | 833 | // Test closing as this is called by the pgx library in pgx.CollectRows 834 | // Previously 
this caused a nil pointer exception when Close was called on kindRows 835 | kindRows.Close() 836 | } 837 | 838 | // TestConnRow tests the ConnRow interface implementation for Conn.QueryRow. 839 | func TestConnRow(t *testing.T) { 840 | t.Parallel() 841 | mock, _ := NewConn() 842 | a := assert.New(t) 843 | 844 | // check error case 845 | expectedErr := errors.New("error") 846 | mock.ExpectQuery("SELECT").WillReturnError(expectedErr) 847 | err := mock.QueryRow(context.Background(), "SELECT").Scan(nil) 848 | a.ErrorIs(err, expectedErr) 849 | 850 | // check no rows returned case 851 | var id int 852 | rows := NewRows([]string{"id"}) 853 | mock.ExpectQuery("SELECT").WillReturnRows(rows) 854 | err = mock.QueryRow(context.Background(), "SELECT").Scan(&id) 855 | a.ErrorIs(err, pgx.ErrNoRows) 856 | 857 | // check single row returned case 858 | rows = NewRows([]string{"id"}).AddRow(1) 859 | mock.ExpectQuery("SELECT").WillReturnRows(rows) 860 | err = mock.QueryRow(context.Background(), "SELECT").Scan(&id) 861 | a.NoError(err) 862 | a.Equal(1, id) 863 | 864 | // check multiple rows returned case 865 | rows = NewRows([]string{"id"}).AddRow(1).AddRow(42) 866 | mock.ExpectQuery("SELECT").WillReturnRows(rows) 867 | err = mock.QueryRow(context.Background(), "SELECT").Scan(&id) 868 | a.NoError(err) 869 | a.Equal(1, id) 870 | 871 | a.NoError(mock.ExpectationsWereMet()) 872 | } 873 | 874 | func TestInvalidsQueryRow(t *testing.T) { 875 | mock, _ := NewPool() 876 | a := assert.New(t) 877 | 878 | // check invalid argument type 879 | mock.ExpectQuery("SELECT").WillReturnRows(mock.NewRows([]string{"seq"}).AddRow("not-an-int")) 880 | var expectedInt int 881 | err := mock.QueryRow(ctx, "SELECT").Scan(&expectedInt) 882 | a.Error(err) 883 | 884 | // check BOF error 885 | rs := mock.NewRows([]string{"seq"}) 886 | rs.AddRow("not-an-int").RowError(-1, errors.New("error")) // emulate pre-Next() error 887 | mock.ExpectQuery("SELECT").WillReturnRows(rs) 888 | err = mock.QueryRow(ctx, 
"SELECT").Scan(&expectedInt) 889 | a.Error(err) 890 | 891 | // check no row error 892 | rs = mock.NewRows([]string{"seq"}) 893 | mock.ExpectQuery("SELECT").WillReturnRows(rs) 894 | err = mock.QueryRow(ctx, "SELECT").Scan(&expectedInt) 895 | a.Error(err) 896 | 897 | //check first row error 898 | rs = mock.NewRows([]string{"seq"}).RowError(0, errors.New("error")) 899 | mock.ExpectQuery("SELECT").WillReturnRows(rs) 900 | err = mock.QueryRow(ctx, "SELECT").Scan(&expectedInt) 901 | a.Error(err) 902 | 903 | // check pgtype.DriverBytes error 904 | mock.ExpectQuery("SELECT").WillReturnRows(mock.NewRows([]string{"seq"}).AddRow("not-an-int")) 905 | var d pgtype.DriverBytes 906 | err = mock.QueryRow(ctx, "SELECT").Scan(&d) 907 | a.Error(err) 908 | } 909 | -------------------------------------------------------------------------------- /sql_test.go: -------------------------------------------------------------------------------- 1 | package pgxmock 2 | 3 | import ( 4 | "context" 5 | "database/sql" 6 | "testing" 7 | "time" 8 | ) 9 | 10 | func TestScanTime(t *testing.T) { 11 | mock, err := NewPool() 12 | if err != nil { 13 | panic(err) 14 | } 15 | 16 | now, _ := time.Parse(time.RFC3339, "2006-01-02T15:04:05Z07:00") 17 | 18 | mock.ExpectQuery(`SELECT now()`). 19 | WillReturnRows( 20 | mock.NewRows([]string{"stamp"}). 21 | AddRow(now)) 22 | 23 | var value sql.NullTime 24 | err = mock.QueryRow(context.Background(), `SELECT now()`).Scan(&value) 25 | if err != nil { 26 | t.Error(err) 27 | } 28 | if value.Time != now { 29 | t.Errorf("want %v, got %v", now, value.Time) 30 | } 31 | } 32 | --------------------------------------------------------------------------------