├── .codeclimate.yml ├── .github └── workflows │ └── test.yml ├── .gitignore ├── CHANGELOG.md ├── LICENSE ├── README.md ├── cmd └── go-mockgen │ ├── args.go │ └── main.go ├── go.mod ├── go.sum ├── internal ├── integration │ ├── gen.go │ ├── gomega_test.go │ ├── strict_test.go │ ├── testdata │ │ ├── .gitignore │ │ ├── complex_types.go │ │ ├── empty.go │ │ ├── generics.go │ │ ├── mocks │ │ │ └── mocks.go │ │ ├── reference.go │ │ ├── relation.go │ │ ├── unexported.go │ │ ├── variadic_args.go │ │ └── variadic_non_interface.go │ └── testify_test.go ├── mockgen │ ├── consts │ │ └── consts.go │ ├── generation │ │ ├── errors.go │ │ ├── generate.go │ │ ├── generate_comment.go │ │ ├── generate_constructors.go │ │ ├── generate_constructors_test.go │ │ ├── generate_mock_func_call_methods.go │ │ ├── generate_mock_func_call_methods_test.go │ │ ├── generate_mock_func_methods.go │ │ ├── generate_mock_func_methods_test.go │ │ ├── generate_mock_methods.go │ │ ├── generate_mock_methods_test.go │ │ ├── generate_structs.go │ │ ├── generate_structs_test.go │ │ ├── generate_test.go │ │ ├── generate_type.go │ │ ├── helpers_test.go │ │ ├── paths.go │ │ ├── util.go │ │ ├── wrapped_interface.go │ │ └── wrapped_method.go │ ├── paths │ │ ├── exist.go │ │ ├── project.go │ │ └── relative.go │ └── types │ │ ├── extract.go │ │ ├── interface.go │ │ ├── method.go │ │ └── visitor.go └── testutil │ ├── helpers_test.go │ ├── reflect_helpers.go │ └── reflect_helpers_test.go └── testutil ├── assert ├── asserter.go ├── asserter_test.go └── assertions.go ├── gomega ├── anything_matcher.go ├── anything_matcher_test.go ├── called_matcher.go ├── called_matcher_test.go ├── called_with_matcher.go ├── called_with_matcher_test.go └── helpers_test.go └── require ├── asserter.go └── require.go /.codeclimate.yml: -------------------------------------------------------------------------------- 1 | --- 2 | engines: 3 | golint: 4 | enabled: true 5 | govet: 6 | enabled: true 7 | gofmt: 8 | enabled: true 9 | fixme: 10 | 
enabled: true 11 | exclude_patterns: 12 | - "examples/" 13 | - "**/*_test.go" 14 | -------------------------------------------------------------------------------- /.github/workflows/test.yml: -------------------------------------------------------------------------------- 1 | name: test 2 | on: 3 | push 4 | 5 | jobs: 6 | test: 7 | runs-on: ubuntu-latest 8 | steps: 9 | - name: Checkout 10 | uses: actions/checkout@v3 11 | 12 | - name: Setup Go 13 | uses: actions/setup-go@v3 14 | with: 15 | go-version: 'stable' 16 | 17 | - name: Generate 18 | run: go generate ./... 19 | - name: Test 20 | run: go test -race -v ./... 21 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | /go-mockgen 2 | -------------------------------------------------------------------------------- /CHANGELOG.md: -------------------------------------------------------------------------------- 1 | # Changelog 2 | 3 | ## [Unreleased] 4 | 5 | Nothing yet! 6 | 7 | ## [v2.0.2] - 2025-05-01 8 | 9 | - Version bump. 10 | 11 | ## [v2.0.1] - 2024-03-04 12 | 13 | - Changed name of module to `github.com/derision-test/go-mockgen/v2`. 14 | - Updated module to Go 1.22. [#52](https://github.com/derision-test/go-mockgen/pull/52) 15 | 16 | ## v2.0.0 - 2024-03-04 17 | 18 | **PULLED**, as module name was not updated properly. 19 | 20 | ## [v1.3.7] - 2022-11-11 21 | 22 | - Fixed import cycle generated by generic param referencing local type. [#39](https://github.com/derision-test/go-mockgen/pull/39) 23 | 24 | ## [v1.3.6] - 2022-10-26 25 | 26 | - Config file loading now disambiguates interfaces with the same names in multiple search packages. [#36](https://github.com/derision-test/go-mockgen/pull/36) 27 | 28 | ## [v1.3.5] - 2022-10-26 29 | 30 | - No updates. 
31 | 32 | ## [v1.3.4] - 2022-08-11 33 | 34 | - Added `NeedDeps` mode to package loader (prevents fatal log from within `golang.org/x/tools/go/packages` in some circumstances). [3ae60a2](https://github.com/derision-test/go-mockgen/commit/3ae60a20c75f7eb1ae85fc6af66f237f5ee1a04d) 35 | 36 | ## [v1.3.3] - 2022-06-09 37 | 38 | - Added support for `include-config-paths` in config file. [#35](https://github.com/derision-test/go-mockgen/pull/35) 39 | 40 | ## [v1.3.2] - 2022-06-09 41 | 42 | ### Added 43 | 44 | - Added support for `sources` in config file. [#33](https://github.com/derision-test/go-mockgen/pull/33) 45 | 46 | ### Fixed 47 | 48 | - Fixed broken `import-path` flag. [#34](https://github.com/derision-test/go-mockgen/pull/34) 49 | 50 | ## [v1.3.1] - 2022-06-06 51 | 52 | - Added `--file-prefix` flag. [#32](https://github.com/derision-test/go-mockgen/pull/32) 53 | 54 | ## [v1.3.0] - 2022-06-06 55 | 56 | ### Added 57 | 58 | - Added support for configuration files. [#31](https://github.com/derision-test/go-mockgen/pull/31) 59 | - Added `--constructor-prefix` flag. [#28](https://github.com/derision-test/go-mockgen/pull/28) 60 | 61 | ## [v1.2.0] - 2022-03-28 62 | 63 | ### Changed 64 | 65 | - Fixed generation of code with inline interface definitions. [#23](https://github.com/derision-test/go-mockgen/pull/23) 66 | - Added basic support for generic interfaces - now requires Go 1.18 or above. [#20](https://github.com/derision-test/go-mockgen/pull/20) 67 | 68 | ## [v1.1.5] - 2022-04-08 69 | 70 | ### Changed 71 | 72 | - Updated x/tools for Go 1.18 support. [#22](https://github.com/derision-test/go-mockgen/pull/22) 73 | 74 | ## [v1.1.4] - 2022-02-01 75 | 76 | ### Changed 77 | 78 | - Fixed generation for nested package on Windows. [#19](https://github.com/derision-test/go-mockgen/pull/19) 79 | - Fixed support for array types in method signatures. 
[#21](https://github.com/derision-test/go-mockgen/pull/21) 80 | 81 | ## [v1.1.3] - 2022-02-21 82 | 83 | ### Added 84 | 85 | - Added `--exclude`/`-e` flag to support exclusion of target interfaces. [#13](https://github.com/derision-test/go-mockgen/pull/13) 86 | - Added `--for-test` flag. [#14](https://github.com/derision-test/go-mockgen/pull/14) 87 | - Added `NewStrictMockX` constructor. [#16](https://github.com/derision-test/go-mockgen/pull/16) 88 | 89 | ## [v1.1.2] - 2021-06-14 90 | 91 | No significant changes (only corrected version output). 92 | 93 | ## [v1.1.1] - 2021-06-14 94 | 95 | ### Added 96 | 97 | - Added `--goimports` flag. [0f4ed82](https://github.com/derision-test/go-mockgen/commit/0f4ed82247eff5446b885c3ea48f48b870a9ee4a) 98 | 99 | ## [v1.0.0] - 2021-06-14 100 | 101 | ### Added 102 | 103 | - Added support for testify assertions. [#3](https://github.com/derision-test/go-mockgen/pull/3), [#8](https://github.com/derision-test/go-mockgen/pull/8) 104 | 105 | ### Changed 106 | 107 | - Migrated from [efritz/go-mockgen](https://github.com/efritz/go-mockgen). [#1](https://github.com/derision-test/go-mockgen/pull/1) 108 | - We now run `goimports` over rendered files. [096f848](https://github.com/derision-test/go-mockgen/commit/096f848333579e185c8018ff2d17688e4b5f6f27) 109 | - Fixed output paths when directories are generated. 
[#10](https://github.com/derision-test/go-mockgen/pull/10) 110 | 111 | [Unreleased]: https://github.com/derision-test/go-mockgen/compare/v2.0.2...HEAD 112 | [v1.0.0]: https://github.com/derision-test/go-mockgen/releases/tag/v1.0.0 113 | [v1.1.1]: https://github.com/derision-test/go-mockgen/compare/v1.0.0...v1.1.1 114 | [v1.1.2]: https://github.com/derision-test/go-mockgen/compare/v1.1.1...v1.1.2 115 | [v1.1.3]: https://github.com/derision-test/go-mockgen/compare/v1.1.2...v1.1.3 116 | [v1.1.4]: https://github.com/derision-test/go-mockgen/compare/v1.1.3...v1.1.4 117 | [v1.2.0]: https://github.com/derision-test/go-mockgen/compare/v1.1.4...v1.2.0 118 | [v1.3.0]: https://github.com/derision-test/go-mockgen/compare/v1.2.0...v1.3.0 119 | [v1.3.1]: https://github.com/derision-test/go-mockgen/compare/v1.3.0...v1.3.1 120 | [v1.3.2]: https://github.com/derision-test/go-mockgen/compare/v1.3.1...v1.3.2 121 | [v1.3.3]: https://github.com/derision-test/go-mockgen/compare/v1.3.2...v1.3.3 122 | [v1.3.4]: https://github.com/derision-test/go-mockgen/compare/v1.3.3...v1.3.4 123 | [v1.3.5]: https://github.com/derision-test/go-mockgen/compare/v1.3.4...v1.3.5 124 | [v1.3.6]: https://github.com/derision-test/go-mockgen/compare/v1.3.5...v1.3.6 125 | [v1.3.7]: https://github.com/derision-test/go-mockgen/compare/v1.3.6...v1.3.7 126 | [v2.0.1]: https://github.com/derision-test/go-mockgen/compare/v1.3.7...v2.0.1 127 | [v2.0.2]: https://github.com/derision-test/go-mockgen/compare/v2.0.1...v2.0.2 128 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2024 Eric Fritz 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy 4 | of this software and associated documentation files (the "Software"), to deal 5 | in the Software without restriction, including without limitation the rights 6 | to use, copy, modify, merge, publish,
distribute, sublicense, and/or sell 7 | copies of the Software, and to permit persons to whom the Software is 8 | furnished to do so, subject to the following conditions: 9 | 10 | The above copyright notice and this permission notice shall be included in 11 | all copies or substantial portions of the Software. 12 | 13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 16 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 18 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 19 | THE SOFTWARE. 20 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # go-mockgen 2 | 3 | [![PkgGoDev](https://pkg.go.dev/badge/github.com/derision-test/go-mockgen/v2.svg)](https://pkg.go.dev/github.com/derision-test/go-mockgen/v2) 4 | [![Build status](https://github.com/derision-test/go-mockgen/actions/workflows/test.yml/badge.svg)](https://github.com/derision-test/go-mockgen/actions/workflows/test.yml) 5 | [![Latest release](https://img.shields.io/github/release/derision-test/go-mockgen.svg)](https://github.com/derision-test/go-mockgen/releases/) 6 | 7 | A mock interface code generator (supports generics as of [v1.2.0](https://github.com/derision-test/go-mockgen/releases/tag/v1.2.0) 🎉). 8 | 9 | ## Generating Mocks 10 | 11 | Install with `go install github.com/derision-test/go-mockgen/v2/cmd/go-mockgen@latest`. 12 | 13 | Mocks should be generated via `go generate` and should be regenerated on each update to the target interface.
For example, in `gen.go`: 14 | 15 | ```go 16 | package mocks 17 | 18 | //go:generate go-mockgen -f github.com/cache/user/pkg -i Cache -o mock_cache_test.go 19 | ``` 20 | 21 | Depending on how you prefer to structure your code, you can either 22 | 23 | 1. generate mocks next to the implementation (as a sibling or in a sibling `mocks` package), or 24 | 2. generate mocks as needed in test code (generating them into a `_test.go` file). 25 | 26 | ### Flags 27 | 28 | The following flags are defined by the binary. 29 | 30 | | Name | Short Flag | Description | 31 | | ------------------ | ---------- | ----------- | 32 | | package | p | The name of the generated package. By default, it is the last element of the output import path (which is inferred from the output directory). | 33 | | prefix | | A prefix used in the name of each mock struct. Should be TitleCase by convention. | 34 | | constructor-prefix | | A prefix used in the name of each mock constructor function (after the initial `New`/`NewStrict` prefixes). Should be TitleCase by convention. | 35 | | interfaces | i | A list of interfaces to generate from the given import paths. | 36 | | exclude | e | A list of interfaces to exclude from generation. | 37 | | filename | o | The target output file. All mocks are written to this file. | 38 | | dirname | d | The target output directory. Each mock will be written to a unique file. | 39 | | force | f | Do not abort if a write to disk would overwrite an existing file. | 40 | | disable-formatting | | Do not run goimports over the rendered files (by default, rendered files are formatted). | 41 | | goimports | | Path to the goimports binary (uses goimports on your PATH by default). | 42 | | for-test | | Append _test suffix to generated package names and file names. | 43 | | file-prefix | | Content that is written at the top of each generated file. | 44 | | build-constraints | | [Build constraints](https://pkg.go.dev/cmd/go#hdr-Build_constraints) that are added to each generated file.
| 45 | 46 | ### Configuration file 47 | 48 | A configuration file is also supported. If no command line arguments are supplied, then the file `mockgen.yaml` in the current directory is used for input. The structure of the configuration file is as follows (where each entry in the `mocks` list can supply a value for each flag described above): 49 | 50 | ```yaml 51 | force: true 52 | mocks: 53 | - filename: foo/bar/mock_cache_test.go 54 | path: github.com/usr/pkg/cache 55 | interfaces: 56 | - Cache 57 | - filename: foo/baz/mocks_test.go 58 | # Supports multiple package sources in a single file 59 | sources: 60 | - path: github.com/usr/pkg/timer 61 | interfaces: 62 | - Timer 63 | - path: github.com/usr/pkg/stopwatch 64 | interfaces: 65 | - LapTimer 66 | - Stopwatch 67 | ``` 68 | 69 | The top level of the configuration file may also set the keys `exclude`, `prefix`, `constructor-prefix`, `goimports`, `file-prefix`, `force`, `disable-formatting`, and `for-test`. Top-level excludes will also be applied to each mock generator entry. The values for interface and constructor prefixes, goimports, generated package names, and file content prefixes will apply to each mock generator entry source(s) if a value is not set. The remaining boolean values will be true for each mock generator entry if set at the top level (regardless of the setting of each entry). 70 | 71 | To organize long lists of mocks, multiple files can be used, as follows. 72 | 73 | ```yaml 74 | include-config-paths: 75 | - foo.mockgen.yaml 76 | - bar.mockgen.yaml 77 | - baz.mockgen.yaml 78 | mocks: 79 | - filename: foo/bar/mock_cache_test.go 80 | path: github.com/usr/pkg/cache 81 | interfaces: 82 | - Cache 83 | ``` 84 | 85 | This file results in the mocks defined in the `mockgen.yaml` file, concatenated with the mocks defined in `{foo,bar,baz}.mockgen.yaml`.
The included config paths do not have global-level configuration and should encode a top-level mocks array, e.g., 86 | 87 | ```yaml 88 | - filename: mock_cache_test.go 89 | path: github.com/usr/pkg/cache 90 | interfaces: 91 | - Cache 92 | - filename: mock_timer_test.go 93 | path: github.com/usr/pkg/timer 94 | interfaces: 95 | - Timer 96 | - filename: mock_stopwatch_test.go 97 | path: github.com/usr/pkg/stopwatch 98 | interfaces: 99 | - LapTimer 100 | - Stopwatch 101 | ``` 102 | 103 | ## Testing with Mocks 104 | 105 | A mock value fulfills all of the methods of the target interface from which it was generated. Unless overridden, all methods of the mock will return zero values for everything. To override a specific method, you can set its `hook` or its `return values`. 106 | 107 | A hook is a function that is called on each invocation and allows the test to specify complex behaviors in the mocked interface (conditionally returning values, synchronizing on external state, etc.). The default hook for a method is set with the `SetDefaultHook` method. 108 | 109 | ```go 110 | func TestCache(t *testing.T) { 111 | cache := mocks.NewMockCache[string, int]() 112 | cache.GetFunc.SetDefaultHook(func (key string) (int, bool) { 113 | if key == "expected" { 114 | return 42, true 115 | } 116 | return 0, false 117 | }) 118 | 119 | testSubject := NewThingThatNeedsCache(cache) 120 | // ... 121 | } 122 | ``` 123 | 124 | In the cases where you don't need specific behaviors but just need to return some data, the setup gets a bit easier with `SetDefaultReturn`. 125 | 126 | ```go 127 | func TestCache(t *testing.T) { 128 | cache := mocks.NewMockCache[string, int]() 129 | cache.GetFunc.SetDefaultReturn(42, true) 130 | 131 | testSubject := NewThingThatNeedsCache(cache) 132 | // ... 133 | } 134 | ``` 135 | 136 | Hook and return values can also be _stacked_ when your test can anticipate multiple calls to the same function.
Pushing a hook or a return value will set the hook or return value for _one_ invocation of the mocked method. Once this hook or return value has been spent, it will be removed from the queue. Hooks and return values can be interleaved. If the queue is empty, the default hook will be invoked (or the default return values returned). 137 | 138 | The following example will test a cache that returns values 50, 51, and 52 in sequence, then panic if there is an unexpected fourth call. 139 | 140 | ```go 141 | func TestCache(t *testing.T) { 142 | cache := mocks.NewMockCache[string, int]() 143 | cache.GetFunc.SetDefaultHook(func (key string) (int, bool) { 144 | panic("unexpected call") 145 | }) 146 | cache.GetFunc.PushReturn(50, true) 147 | cache.GetFunc.PushReturn(51, true) 148 | cache.GetFunc.PushReturn(52, true) 149 | 150 | testSubject := NewThingThatNeedsCache(cache) 151 | // ... 152 | } 153 | ``` 154 | 155 | Note that this "panic by default" behavior is given automatically when using the `NewStrictMockCache` constructor, also automatically generated for all mocks. 156 | 157 | ### Assertions 158 | 159 | Mocks track their invocations and can be retrieved via the `History` method. Structs are generated for each method type containing fields for each argument and result type. Raw assertions can be performed on these values. 160 | 161 | ```go 162 | allCalls := cache.GetFunc.History() 163 | allCalls[0].Arg0 // key (type string) 164 | allCalls[0].Result0 // value (type int) 165 | allCalls[0].Result1 // exists flag (type bool) 166 | ``` 167 | 168 | ### Testify integration 169 | 170 | This library also contains an API that integrates with the style of [Testify](https://github.com/stretchr/testify) assertions. 171 | 172 | To use the assertions, import the assert and require packages by name. 
173 | 174 | ```go 175 | import ( 176 | mockassert "github.com/derision-test/go-mockgen/v2/testutil/assert" 177 | mockrequire "github.com/derision-test/go-mockgen/v2/testutil/require" 178 | ) 179 | ``` 180 | 181 | The following methods are defined in both packages. 182 | 183 | - `Called(t, mockFn, msgAndArgs...)` 184 | - `NotCalled(t, mockFn, msgAndArgs...)` 185 | - `CalledOnce(t, mockFn, msgAndArgs...)` 186 | - `CalledN(t, mockFn, n, msgAndArgs...)` 187 | - `CalledWith(t, mockFn, msgAndArgs...)` 188 | - `NotCalledWith(t, mockFn, msgAndArgs...)` 189 | - `CalledOnceWith(t, mockFn, msgAndArgs...)` 190 | - `CalledNWith(t, mockFn, n, msgAndArgs...)` 191 | - `CalledAtNWith(t, mockFn, n, msgAndArgs...)` 192 | 193 | These methods can be used as follows. 194 | 195 | ```go 196 | // cache.Get called 3 times 197 | mockassert.CalledN(t, cache.GetFunc, 3) 198 | 199 | // Ensure cache.Set("foo", 42) was called 200 | mockassert.CalledWith(t, cache.SetFunc, mockassert.Values("foo", 42)) 201 | 202 | // Ensure cache.Set("foo", _) was called 203 | mockassert.CalledWith(t, cache.SetFunc, mockassert.Values("foo", mockassert.Skip)) 204 | ``` 205 | 206 | ### Gomega integration 207 | 208 | This library also contains a set of [Gomega](https://onsi.github.io/gomega/) matchers which simplify assertions over a mocked method's call history. 209 | 210 | To use the matchers, dot-import the matchers package. 211 | 212 | ```go 213 | import . "github.com/derision-test/go-mockgen/v2/testutil/gomega" 214 | ``` 215 | 216 | The following matchers are defined. 217 | 218 | - `BeCalled()` 219 | - `BeCalledN(n)` 220 | - `BeCalledOnce()` 221 | - `BeCalledWith(args...)` 222 | - `BeCalledNWith(args...)` 223 | - `BeCalledOnceWith(args...)` 224 | - `BeAnything()` 225 | 226 | These matchers can be used as follows.
227 | 
228 | ```go
229 | // cache.Get called 3 times
230 | Expect(cache.GetFunc).To(BeCalledN(3))
231 | 
232 | // Ensure cache.Set("foo", "bar") was called
233 | Expect(cache.SetFunc).To(BeCalledWith("foo", "bar"))
234 | 
235 | // Ensure cache.Set("foo", _) was called
236 | Expect(cache.SetFunc).To(BeCalledWith("foo", BeAnything()))
237 | ```
238 | 
--------------------------------------------------------------------------------
/cmd/go-mockgen/args.go:
--------------------------------------------------------------------------------
1 | package main
2 | 
3 | import (
4 | 	"fmt"
5 | 	"os"
6 | 	"path"
7 | 	"path/filepath"
8 | 	"regexp"
9 | 	"strings"
10 | 
11 | 	"github.com/alecthomas/kingpin"
12 | 	"github.com/derision-test/go-mockgen/v2/internal/mockgen/consts"
13 | 	"github.com/derision-test/go-mockgen/v2/internal/mockgen/generation"
14 | 	"github.com/derision-test/go-mockgen/v2/internal/mockgen/paths"
15 | 	"gopkg.in/yaml.v3"
16 | )
17 | 
18 | // parseAndValidateOptions returns one set of generation options per mock
19 | // target, read from command-line flags or, when no arguments are given, from
20 | // mockgen.yaml. Every set is passed through each validator, which may also
21 | // normalize it in place (output directory, import path, package name).
22 | //
23 | // Validators report whether an error is fatal. Non-fatal errors are usage
24 | // errors and are reported through kingpin.Fatalf with a hint to run --help;
25 | // fatal errors are returned to the caller.
26 | func parseAndValidateOptions() ([]*generation.Options, error) {
27 | 	allOptions, err := parseOptions()
28 | 	if err != nil {
29 | 		return nil, err
30 | 	}
31 | 
32 | 	validators := []func(opts *generation.Options) (bool, error){
33 | 		validateOutputPaths,
34 | 		validateOptions,
35 | 	}
36 | 
37 | 	for _, opts := range allOptions {
38 | 		for _, f := range validators {
39 | 			if fatal, err := f(opts); err != nil {
40 | 				if !fatal {
41 | 					kingpin.Fatalf("%s, try --help", err.Error())
42 | 				}
43 | 
44 | 				return nil, err
45 | 			}
46 | 		}
47 | 	}
48 | 
49 | 	return allOptions, nil
50 | }
51 | 
52 | // parseOptions reads options from mockgen.yaml when the binary is invoked
53 | // without arguments, and from command-line flags otherwise.
54 | func parseOptions() ([]*generation.Options, error) {
55 | 	if len(os.Args) == 1 {
56 | 		return parseManifest()
57 | 	}
58 | 
59 | 	opts, err := parseFlags()
60 | 	if err != nil {
61 | 		return nil, err
62 | 	}
63 | 
64 | 	return []*generation.Options{opts}, nil
65 | }
66 | 
67 | // parseFlags builds a single set of options from os.Args. The positional
68 | // import paths and the interface/exclude lists all populate one shared
69 | // PackageOptions entry.
70 | func parseFlags() (*generation.Options, error) {
71 | 	opts := &generation.Options{
72 | 		PackageOptions: []generation.PackageOptions{
73 | 			{
74 | 				ImportPaths: []string{},
75 | 				Interfaces:  []string{},
76 | 			},
77 | 		},
78 | 	}
79 | 
80 | 	app := kingpin.New(consts.Name, consts.Description).Version(consts.Version)
81 | 	app.UsageWriter(os.Stdout)
82 | 
83 | 	app.Arg("path", "The import paths used to search for eligible interfaces").Required().StringsVar(&opts.PackageOptions[0].ImportPaths)
84 | 	app.Flag("package", "The name of the generated package. It will be inferred from the output options by default.").Short('p').StringVar(&opts.ContentOptions.PkgName)
85 | 	app.Flag("interfaces", "A list of target interfaces to generate, defined in the given import paths.").Short('i').StringsVar(&opts.PackageOptions[0].Interfaces)
86 | 	app.Flag("exclude", "A list of interfaces to exclude from generation. Mocks for all other exported interfaces defined in the given import paths are generated.").Short('e').StringsVar(&opts.PackageOptions[0].Exclude)
87 | 	app.Flag("dirname", "The target output directory. Each mock will be written to a unique file.").Short('d').StringVar(&opts.OutputOptions.OutputDir)
88 | 	app.Flag("filename", "The target output file. All mocks are written to this file.").Short('o').StringVar(&opts.OutputOptions.OutputFilename)
89 | 	app.Flag("import-path", "The import path of the generated package. It will be inferred from the target directory by default.").StringVar(&opts.ContentOptions.OutputImportPath)
90 | 	app.Flag("prefix", "A prefix used in the name of each mock struct. Should be TitleCase by convention.").StringVar(&opts.ContentOptions.Prefix)
91 | 	app.Flag("constructor-prefix", "A prefix used in the name of each mock constructor function (after the initial `New`/`NewStrict` prefixes). Should be TitleCase by convention.").StringVar(&opts.ContentOptions.ConstructorPrefix)
92 | 	app.Flag("force", "Do not abort if a write to disk would overwrite an existing file.").Short('f').BoolVar(&opts.OutputOptions.Force)
93 | 	app.Flag("disable-formatting", "Do not run goimports over the rendered files.").BoolVar(&opts.OutputOptions.DisableFormatting)
94 | 	app.Flag("goimports", "Path to the goimports binary.").Default("goimports").StringVar(&opts.OutputOptions.GoImportsBinary)
95 | 	app.Flag("for-test", "Append _test suffix to generated package names and file names.").Default("false").BoolVar(&opts.OutputOptions.ForTest)
96 | 	app.Flag("file-prefix", "Content that is written at the top of each generated file.").StringVar(&opts.ContentOptions.FilePrefix)
97 | 	app.Flag("build-constraints", "Build constraints that are added to each generated file.").StringVar(&opts.ContentOptions.BuildConstraints)
98 | 
99 | 	if _, err := app.Parse(os.Args[1:]); err != nil {
100 | 		return nil, err
101 | 	}
102 | 
103 | 	return opts, nil
104 | }
105 | 
106 | // parseManifest converts the entries of mockgen.yaml (and any included config
107 | // files) into generation options. Top-level settings are merged into each
108 | // entry: excludes are appended, string settings fill in values the entry
109 | // leaves empty, and boolean settings that are true at the top level force the
110 | // entry's value to true.
111 | func parseManifest() ([]*generation.Options, error) {
112 | 	payload, err := readManifest()
113 | 	if err != nil {
114 | 		return nil, err
115 | 	}
116 | 
117 | 	allOptions := make([]*generation.Options, 0, len(payload.Mocks))
118 | 	for _, opts := range payload.Mocks {
119 | 		// Mix: top-level excludes apply to every entry.
120 | 		opts.Exclude = append(opts.Exclude, payload.Exclude...)
121 | 
122 | 		// Set if not overwritten in this entry
123 | 		if opts.Prefix == "" {
124 | 			opts.Prefix = payload.Prefix
125 | 		}
126 | 		if opts.ConstructorPrefix == "" {
127 | 			opts.ConstructorPrefix = payload.ConstructorPrefix
128 | 		}
129 | 		if opts.Goimports == "" {
130 | 			opts.Goimports = payload.Goimports
131 | 		}
132 | 		if opts.FilePrefix == "" {
133 | 			opts.FilePrefix = payload.FilePrefix
134 | 		}
135 | 		if opts.BuildConstraints == "" {
136 | 			opts.BuildConstraints = payload.BuildConstraints
137 | 		}
138 | 
139 | 		// Overwrite
140 | 		if payload.Force {
141 | 			opts.Force = true
142 | 		}
143 | 		if payload.DisableFormatting {
144 | 			opts.DisableFormatting = true
145 | 		}
146 | 		if payload.ForTest {
147 | 			opts.ForTest = true
148 | 		}
149 | 
150 | 		// Canonicalization: `path` is shorthand for one more entry in `paths`.
151 | 		importPaths := opts.Paths
152 | 		if opts.Path != "" {
153 | 			importPaths = append(importPaths, opts.Path)
154 | 		}
155 | 
156 | 		// Defaults
157 | 		if opts.Goimports == "" {
158 | 			opts.Goimports = "goimports"
159 | 		}
160 | 
161 | 		var packageOptions []generation.PackageOptions
162 | 		if len(opts.Sources) > 0 {
163 | 			// importPaths covers both `path` and `paths`, so a lone `path`
164 | 			// next to `sources` is rejected instead of being silently ignored.
165 | 			if len(importPaths) > 0 || len(opts.Interfaces) > 0 {
166 | 				return nil, fmt.Errorf("sources and path/paths/interfaces are mutually exclusive")
167 | 			}
168 | 
169 | 			for _, source := range opts.Sources {
170 | 				// Canonicalization
171 | 				sourcePaths := source.Paths
172 | 				if source.Path != "" {
173 | 					sourcePaths = append(sourcePaths, source.Path)
174 | 				}
175 | 
176 | 				packageOptions = append(packageOptions, generation.PackageOptions{
177 | 					ImportPaths: sourcePaths,
178 | 					Interfaces:  source.Interfaces,
179 | 					Exclude:     source.Exclude,
180 | 					Prefix:      source.Prefix,
181 | 				})
182 | 			}
183 | 		} else {
184 | 			packageOptions = append(packageOptions, generation.PackageOptions{
185 | 				ImportPaths: importPaths,
186 | 				Interfaces:  opts.Interfaces,
187 | 				Exclude:     opts.Exclude,
188 | 				Prefix:      opts.Prefix,
189 | 			})
190 | 		}
191 | 
192 | 		allOptions = append(allOptions, &generation.Options{
193 | 			PackageOptions: packageOptions,
194 | 			OutputOptions: generation.OutputOptions{
195 | 				OutputDir:         opts.Dirname,
196 | 				OutputFilename:    opts.Filename,
197 | 				Force:             opts.Force,
198 | 				DisableFormatting: opts.DisableFormatting,
199 | 				GoImportsBinary:   opts.Goimports,
200 | 				ForTest:           opts.ForTest,
201 | 			},
202 | 			ContentOptions: generation.ContentOptions{
203 | 				PkgName:           opts.Package,
204 | 				OutputImportPath:  opts.ImportPath,
205 | 				Prefix:            opts.Prefix,
206 | 				ConstructorPrefix: opts.ConstructorPrefix,
207 | 				FilePrefix:        opts.FilePrefix,
208 | 				BuildConstraints:  opts.BuildConstraints,
209 | 			},
210 | 		})
211 | 	}
212 | 
213 | 	return allOptions, nil
214 | }
215 | 
216 | // yamlPayload is the schema of mockgen.yaml.
217 | type yamlPayload struct {
218 | 	// Meta options
219 | 	IncludeConfigPaths []string `yaml:"include-config-paths"`
220 | 
221 | 	// Global options
222 | 	Exclude           []string `yaml:"exclude"`
223 | 	Prefix            string   `yaml:"prefix"`
224 | 	ConstructorPrefix string   `yaml:"constructor-prefix"`
225 | 	Force             bool     `yaml:"force"`
226 | 	DisableFormatting bool     `yaml:"disable-formatting"`
227 | 	Goimports         string   `yaml:"goimports"`
228 | 	ForTest           bool     `yaml:"for-test"`
229 | 	FilePrefix        string   `yaml:"file-prefix"`
230 | 	BuildConstraints  string   `yaml:"build-constraints"`
231 | 
232 | 	Mocks []yamlMock `yaml:"mocks"`
233 | }
234 | 
235 | // yamlMock is a single entry of the mocks list, either in mockgen.yaml or in
236 | // an included config file.
237 | type yamlMock struct {
238 | 	Path              string       `yaml:"path"`
239 | 	Paths             []string     `yaml:"paths"`
240 | 	Sources           []yamlSource `yaml:"sources"`
241 | 	Package           string       `yaml:"package"`
242 | 	Interfaces        []string     `yaml:"interfaces"`
243 | 	Exclude           []string     `yaml:"exclude"`
244 | 	Dirname           string       `yaml:"dirname"`
245 | 	Filename          string       `yaml:"filename"`
246 | 	ImportPath        string       `yaml:"import-path"`
247 | 	Prefix            string       `yaml:"prefix"`
248 | 	ConstructorPrefix string       `yaml:"constructor-prefix"`
249 | 	Force             bool         `yaml:"force"`
250 | 	DisableFormatting bool         `yaml:"disable-formatting"`
251 | 	Goimports         string       `yaml:"goimports"`
252 | 	ForTest           bool         `yaml:"for-test"`
253 | 	FilePrefix        string       `yaml:"file-prefix"`
254 | 	BuildConstraints  string       `yaml:"build-constraints"`
255 | }
256 | 
257 | // yamlSource is one package source of a mock entry that uses `sources`.
258 | type yamlSource struct {
259 | 	Path       string   `yaml:"path"`
260 | 	Paths      []string `yaml:"paths"`
261 | 	Interfaces []string `yaml:"interfaces"`
262 | 	Exclude    []string `yaml:"exclude"`
263 | 	Prefix     string   `yaml:"prefix"`
264 | }
265 | 
266 | // readManifest loads mockgen.yaml from the working directory and appends the
267 | // mock entries of every file listed under include-config-paths.
268 | func readManifest() (yamlPayload, error) {
269 | 	contents, err := os.ReadFile("mockgen.yaml")
270 | 	if err != nil {
271 | 		return yamlPayload{}, err
272 | 	}
273 | 
274 | 	var payload yamlPayload
275 | 	if err := yaml.Unmarshal(contents, &payload); err != nil {
276 | 		return yamlPayload{}, fmt.Errorf("failed to parse mockgen.yaml: %w", err)
277 | 	}
278 | 
279 | 	for _, configPath := range payload.IncludeConfigPaths {
280 | 		payload, err = readIncludeConfig(payload, configPath)
281 | 		if err != nil {
282 | 			return yamlPayload{}, err
283 | 		}
284 | 	}
285 | 
286 | 	return payload, nil
287 | }
288 | 
289 | // readIncludeConfig appends the mocks declared in the config file at path to
290 | // payload.Mocks. Included files hold a bare list of mock entries and carry no
291 | // top-level options.
292 | func readIncludeConfig(payload yamlPayload, path string) (yamlPayload, error) {
293 | 	contents, err := os.ReadFile(path)
294 | 	if err != nil {
295 | 		return yamlPayload{}, err
296 | 	}
297 | 
298 | 	var mocks []yamlMock
299 | 	if err := yaml.Unmarshal(contents, &mocks); err != nil {
300 | 		return yamlPayload{}, fmt.Errorf("failed to parse %s: %w", path, err)
301 | 	}
302 | 
303 | 	payload.Mocks = append(payload.Mocks, mocks...)
304 | 	return payload, nil
305 | }
306 | 
307 | // validateOutputPaths normalizes the output location in place. Without a
308 | // dirname or filename, output goes to the working directory; a filename is
309 | // split into OutputDir and a bare OutputFilename. The output directory is then
310 | // passed to paths.EnsureDirExists and replaced by its absolute, symlink-free
311 | // form. The boolean result reports whether a returned error is fatal.
312 | func validateOutputPaths(opts *generation.Options) (bool, error) {
313 | 	wd, err := os.Getwd()
314 | 	if err != nil {
315 | 		return true, fmt.Errorf("failed to get current directory: %w", err)
316 | 	}
317 | 
318 | 	if opts.OutputOptions.OutputFilename == "" && opts.OutputOptions.OutputDir == "" {
319 | 		opts.OutputOptions.OutputDir = wd
320 | 	}
321 | 
322 | 	if opts.OutputOptions.OutputFilename != "" && opts.OutputOptions.OutputDir != "" {
323 | 		return false, fmt.Errorf("dirname and filename are mutually exclusive")
324 | 	}
325 | 
326 | 	if opts.OutputOptions.OutputFilename != "" {
327 | 		// NOTE(review): path.Dir/path.Base split only on '/'; filepath.Dir and
328 | 		// filepath.Base would also accept '\' separators on Windows.
329 | 		opts.OutputOptions.OutputDir = path.Dir(opts.OutputOptions.OutputFilename)
330 | 		opts.OutputOptions.OutputFilename = path.Base(opts.OutputOptions.OutputFilename)
331 | 	}
332 | 
333 | 	if err := paths.EnsureDirExists(opts.OutputOptions.OutputDir); err != nil {
334 | 		return true, fmt.Errorf(
335 | 			"failed to make output directory %s: %w",
336 | 			opts.OutputOptions.OutputDir,
337 | 			err,
338 | 		)
339 | 	}
340 | 
341 | 	if opts.OutputOptions.OutputDir, err = cleanPath(opts.OutputOptions.OutputDir); err != nil {
342 | 		return true, err
343 | 	}
344 | 
345 | 	return false, nil
346 | }
347 | 
348 | // goIdentifierPattern matches the names accepted for package names and
349 | // prefixes: an ASCII letter followed by ASCII letters, digits, or underscores.
350 | var goIdentifierPattern = regexp.MustCompile("^[A-Za-z]([A-Za-z0-9_]*)?$")
351 | 
352 | // validateOptions checks the per-package and content options. When unset, the
353 | // output import path is inferred from the output directory, and the package
354 | // name is taken from the last element of that import path. Every error it
355 | // returns is a non-fatal usage error.
356 | func validateOptions(opts *generation.Options) (bool, error) {
357 | 	for _, packageOpts := range opts.PackageOptions {
358 | 		if len(packageOpts.Interfaces) != 0 && len(packageOpts.Exclude) != 0 {
359 | 			return false, fmt.Errorf("interface lists and exclude lists are mutually exclusive")
360 | 		}
361 | 
362 | 		if packageOpts.Prefix != "" && !goIdentifierPattern.MatchString(packageOpts.Prefix) {
363 | 			return false, fmt.Errorf("prefix `%s` is illegal", packageOpts.Prefix)
364 | 		}
365 | 	}
366 | 
367 | 	if opts.ContentOptions.OutputImportPath == "" {
368 | 		importPath, ok := paths.InferImportPath(opts.OutputOptions.OutputDir)
369 | 		if !ok {
370 | 			return false, fmt.Errorf("could not infer output import path")
371 | 		}
372 | 
373 | 		opts.ContentOptions.OutputImportPath = importPath
374 | 	}
375 | 
376 | 	if opts.ContentOptions.PkgName == "" {
377 | 		// Import paths are '/'-separated; os.PathSeparator is accepted as well
378 | 		// so the last element is found even if the path uses OS separators.
379 | 		importPath := opts.ContentOptions.OutputImportPath
380 | 		opts.ContentOptions.PkgName = importPath[strings.LastIndexAny(importPath, "/"+string(os.PathSeparator))+1:]
381 | 	}
382 | 
383 | 	if !goIdentifierPattern.MatchString(opts.ContentOptions.PkgName) {
384 | 		return false, fmt.Errorf("package name `%s` is illegal", opts.ContentOptions.PkgName)
385 | 	}
386 | 
387 | 	if opts.ContentOptions.Prefix != "" && !goIdentifierPattern.MatchString(opts.ContentOptions.Prefix) {
388 | 		return false, fmt.Errorf("prefix `%s` is illegal", opts.ContentOptions.Prefix)
389 | 	}
390 | 
391 | 	if opts.ContentOptions.ConstructorPrefix != "" && !goIdentifierPattern.MatchString(opts.ContentOptions.ConstructorPrefix) {
392 | 		return false, fmt.Errorf("constructor-prefix `%s` is illegal", opts.ContentOptions.ConstructorPrefix)
393 | 	}
394 | 
395 | 	return false, nil
396 | }
397 | 
398 | // cleanPath returns path made absolute and with symlinks resolved. Because
399 | // filepath.EvalSymlinks is used, the path must already exist.
400 | func cleanPath(path string) (cleaned string, err error) {
401 | 	if path, err = filepath.Abs(path); err != nil {
402 | 		return "", err
403 | 	}
404 | 
405 | 	if path, err = filepath.EvalSymlinks(path); err != nil {
406 | 		return "", err
407 | 	}
408 | 
409 | 	return path, nil
410 | }
411 | 
--------------------------------------------------------------------------------
/cmd/go-mockgen/main.go:
--------------------------------------------------------------------------------
1 | package main
2 | 
3 | import (
4 | 	"errors"
5 | 	"fmt"
6 | 	"log"
7 | 	"strings"
8 | 
9 | 	"github.com/derision-test/go-mockgen/v2/internal/mockgen/generation"
10 | 	"github.com/derision-test/go-mockgen/v2/internal/mockgen/types"
11 | 	"golang.org/x/tools/go/packages"
12 | )
13 | 
14 | // init configures the standard logger: no timestamps, and every line is
15 | // prefixed with the tool name.
16 | func init() {
17 | 	log.SetFlags(0)
18 | 	log.SetPrefix("go-mockgen: ")
19 | }
20 | 
21 | // main runs the generator. On failure it logs the error, followed by any
22 | // solutions the error suggests, and exits with status 1.
23 | func main() {
24 | 	if err := mainErr(); err != nil {
25 | 		message := fmt.Sprintf("error: %s\n", err.Error())
26 | 
27 | 		// errors.As also finds solvable errors that were wrapped with %w.
28 | 		var solvable solvableError
29 | 		if errors.As(err, &solvable) {
30 | 			message += "\nPossible solutions:\n"
31 | 
32 | 			for _, hint := range solvable.Solutions() {
33 | 				message += fmt.Sprintf(" - %s\n", hint)
34 | 			}
35 | 
36 | 			message += "\n"
37 | 		}
38 | 
39 | 		// message is already formatted; log.Fatalf would reinterpret any '%'
40 | 		// inside the error text or hints as a formatting verb.
41 | 		log.Fatal(message)
42 | 	}
43 | }
44 | 
45 | // solvableError is implemented by errors that can suggest fixes to the user.
46 | type solvableError interface {
47 | 	Solutions() []string
48 | }
49 | 
50 | // mainErr parses and validates the options, loads all referenced packages in
51 | // a single packages.Load call, then extracts interfaces and generates mocks
52 | // for each set of options. It fails if an interface requested by name is not
53 | // among the interfaces extracted for that set.
54 | func mainErr() error {
55 | 	allOptions, err := parseAndValidateOptions()
56 | 	if err != nil {
57 | 		return err
58 | 	}
59 | 
60 | 	var importPaths []string
61 | 	for _, opts := range allOptions {
62 | 		for _, packageOpts := range opts.PackageOptions {
63 | 			importPaths = append(importPaths, packageOpts.ImportPaths...)
64 | 		}
65 | 	}
66 | 
67 | 	log.Printf("loading data for %d packages\n", len(importPaths))
68 | 
69 | 	pkgs, err := packages.Load(&packages.Config{Mode: packages.NeedName | packages.NeedImports | packages.NeedSyntax | packages.NeedTypes | packages.NeedDeps}, importPaths...)
70 | 	if err != nil {
71 | 		return fmt.Errorf("could not load packages %s (%w)", strings.Join(importPaths, ","), err)
72 | 	}
73 | 
74 | 	for _, opts := range allOptions {
75 | 		typePackageOpts := make([]types.PackageOptions, 0, len(opts.PackageOptions))
76 | 		for _, packageOpts := range opts.PackageOptions {
77 | 			typePackageOpts = append(typePackageOpts, types.PackageOptions(packageOpts))
78 | 		}
79 | 
80 | 		ifaces, err := types.Extract(pkgs, typePackageOpts)
81 | 		if err != nil {
82 | 			return err
83 | 		}
84 | 
85 | 		// Requested interface names are matched case-insensitively against
86 | 		// the extracted interfaces.
87 | 		nameMap := make(map[string]struct{}, len(ifaces))
88 | 		for _, t := range ifaces {
89 | 			nameMap[strings.ToLower(t.Name)] = struct{}{}
90 | 		}
91 | 
92 | 		for _, packageOpts := range opts.PackageOptions {
93 | 			for _, name := range packageOpts.Interfaces {
94 | 				if _, ok := nameMap[strings.ToLower(name)]; !ok {
95 | 					return fmt.Errorf("type '%s' not found in supplied import paths", name)
96 | 				}
97 | 			}
98 | 		}
99 | 
100 | 		if err := generation.Generate(ifaces, opts); err != nil {
101 | 			return err
102 | 		}
103 | 	}
104 | 
105 | 	return nil
106 | }
107 | 
--------------------------------------------------------------------------------
/go.mod:
--------------------------------------------------------------------------------
1 | module github.com/derision-test/go-mockgen/v2
2 | 
3 | go 1.22
4 | 
5 | require (
6 | 	github.com/alecthomas/kingpin v2.2.6+incompatible
7 | 	github.com/dave/jennifer v1.5.0
8 | 	github.com/dustin/go-humanize v1.0.0
9 | 	github.com/mitchellh/go-wordwrap v1.0.1
10 | 	github.com/onsi/gomega v1.19.0
11 | 	github.com/stretchr/testify v1.7.1
12 | 	golang.org/x/tools v0.18.0
13 | 	gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b
14 | )
15 | 
16 | require (
17 | 	github.com/alecthomas/template v0.0.0-20190718012654-fb15b899a751 // indirect
18 | 	github.com/alecthomas/units v0.0.0-20211218093645-b94a6e3cc137 // indirect
19 | 	github.com/davecgh/go-spew v1.1.1 // indirect
20 | 	github.com/kr/pretty v0.1.0 // indirect
21 | 	github.com/pmezard/go-difflib v1.0.0 // indirect
22 | 
golang.org/x/mod v0.15.0 // indirect 23 | golang.org/x/net v0.21.0 // indirect 24 | golang.org/x/text v0.14.0 // indirect 25 | gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 // indirect 26 | gopkg.in/yaml.v2 v2.4.0 // indirect 27 | ) 28 | -------------------------------------------------------------------------------- /go.sum: -------------------------------------------------------------------------------- 1 | github.com/alecthomas/kingpin v2.2.6+incompatible h1:5svnBTFgJjZvGKyYBtMB0+m5wvrbUHiqye8wRJMlnYI= 2 | github.com/alecthomas/kingpin v2.2.6+incompatible/go.mod h1:59OFYbFVLKQKq+mqrL6Rw5bR0c3ACQaawgXx0QYndlE= 3 | github.com/alecthomas/template v0.0.0-20190718012654-fb15b899a751 h1:JYp7IbQjafoB+tBA3gMyHYHrpOtNuDiK/uB5uXxq5wM= 4 | github.com/alecthomas/template v0.0.0-20190718012654-fb15b899a751/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc= 5 | github.com/alecthomas/units v0.0.0-20211218093645-b94a6e3cc137 h1:s6gZFSlWYmbqAuRjVTiNNhvNRfY2Wxp9nhfyel4rklc= 6 | github.com/alecthomas/units v0.0.0-20211218093645-b94a6e3cc137/go.mod h1:OMCwj8VM1Kc9e19TLln2VL61YJF0x1XFtfdL4JdbSyE= 7 | github.com/dave/astrid v0.0.0-20170323122508-8c2895878b14/go.mod h1:Sth2QfxfATb/nW4EsrSi2KyJmbcniZ8TgTaji17D6ms= 8 | github.com/dave/brenda v1.1.0/go.mod h1:4wCUr6gSlu5/1Tk7akE5X7UorwiQ8Rij0SKH3/BGMOM= 9 | github.com/dave/courtney v0.3.0/go.mod h1:BAv3hA06AYfNUjfjQr+5gc6vxeBVOupLqrColj+QSD8= 10 | github.com/dave/gopackages v0.0.0-20170318123100-46e7023ec56e/go.mod h1:i00+b/gKdIDIxuLDFob7ustLAVqhsZRk2qVZrArELGQ= 11 | github.com/dave/jennifer v1.5.0 h1:HmgPN93bVDpkQyYbqhCHj5QlgvUkvEOzMyEvKLgCRrg= 12 | github.com/dave/jennifer v1.5.0/go.mod h1:4MnyiFIlZS3l5tSDn8VnzE6ffAhYBMB2SZntBsZGUok= 13 | github.com/dave/kerr v0.0.0-20170318121727-bc25dd6abe8e/go.mod h1:qZqlPyPvfsDJt+3wHJ1EvSXDuVjFTK0j2p/ca+gtsb8= 14 | github.com/dave/patsy v0.0.0-20210517141501-957256f50cba/go.mod h1:qfR88CgEGLoiqDaE+xxDCi5QA5v4vUoW0UCX2Nd5Tlc= 15 | github.com/dave/rebecca v0.9.1/go.mod 
h1:N6XYdMD/OKw3lkF3ywh8Z6wPGuwNFDNtWYEMFWEmXBA= 16 | github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 17 | github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= 18 | github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 19 | github.com/dustin/go-humanize v1.0.0 h1:VSnTsYCnlFHaM2/igO1h6X3HA71jcobQuxemgkq4zYo= 20 | github.com/dustin/go-humanize v1.0.0/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk= 21 | github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI= 22 | github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= 23 | github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= 24 | github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= 25 | github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= 26 | github.com/mitchellh/go-wordwrap v1.0.1 h1:TLuKupo69TCn6TQSyGxwI1EblZZEsQ0vMlAFQflz0v0= 27 | github.com/mitchellh/go-wordwrap v1.0.1/go.mod h1:R62XHJLzvMFRBbcrT7m7WgmE1eOyTSsCt+hzestvNj0= 28 | github.com/onsi/ginkgo/v2 v2.1.3 h1:e/3Cwtogj0HA+25nMP1jCMDIf8RtRYbGwGGuBIFztkc= 29 | github.com/onsi/ginkgo/v2 v2.1.3/go.mod h1:vw5CSIxN1JObi/U8gcbwft7ZxR2dgaR70JSE3/PpL4c= 30 | github.com/onsi/gomega v1.19.0 h1:4ieX6qQjPP/BfC3mpsAtIGGlxTWPeA3Inl/7DtXw1tw= 31 | github.com/onsi/gomega v1.19.0/go.mod h1:LY+I3pBVzYsTBU1AnDwOSxaYi9WoWiqgwooUqq9yPro= 32 | github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= 33 | github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= 34 | github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= 35 | github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= 36 | github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= 37 | github.com/stretchr/testify v1.7.1 
h1:5TQK59W5E3v0r2duFAb7P95B6hEeOyEnHRa8MjYSMTY= 38 | github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= 39 | github.com/yuin/goldmark v1.4.1/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= 40 | golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= 41 | golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= 42 | golang.org/x/mod v0.5.1/go.mod h1:5OXOZSfqPIIbmVBIIKWRFfZjPR0E5r58TLhUjH0a2Ro= 43 | golang.org/x/mod v0.15.0 h1:SernR4v+D55NyBH2QiEQrlBAnj1ECL6AGrA5+dPaMY8= 44 | golang.org/x/mod v0.15.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= 45 | golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= 46 | golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= 47 | golang.org/x/net v0.0.0-20211015210444-4f30a5c0130f/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= 48 | golang.org/x/net v0.21.0 h1:AQyQV4dYCvJ7vGmJyKki9+PBdyvhkSd8EIx/qb0AYv4= 49 | golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44= 50 | golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= 51 | golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= 52 | golang.org/x/sync v0.6.0 h1:5BMeUDZ7vkXGfEr1x9B4bRcTH4lpkTkpdh0T/J+qjbQ= 53 | golang.org/x/sync v0.6.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= 54 | golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= 55 | golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= 56 | golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= 57 | golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod 
h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= 58 | golang.org/x/sys v0.0.0-20211019181941-9d821ace8654/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= 59 | golang.org/x/sys v0.17.0 h1:25cE3gD+tdBA7lp7QfhuV+rJiE9YXTcS3VG1SqssI/Y= 60 | golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= 61 | golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= 62 | golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= 63 | golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= 64 | golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= 65 | golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ= 66 | golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= 67 | golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= 68 | golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= 69 | golang.org/x/tools v0.1.8/go.mod h1:nABZi5QlRsZVlzPpHl034qft6wpY4eDcsTt5AaioBiU= 70 | golang.org/x/tools v0.18.0 h1:k8NLag8AGHnn+PHbl7g43CtqZAwG60vZkLqgyZgIHgQ= 71 | golang.org/x/tools v0.18.0/go.mod h1:GL7B4CwcLLeo59yx/9UWWuNOW1n3VZ4f5axWfML7Lcg= 72 | golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= 73 | golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= 74 | golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= 75 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= 76 | gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY= 77 | gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod 
h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= 78 | gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= 79 | gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= 80 | gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= 81 | gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= 82 | gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b h1:h8qDotaEPuJATrMmW04NCwg7v22aHH28wwpauUhK9Oo= 83 | gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= 84 | -------------------------------------------------------------------------------- /internal/integration/gen.go: -------------------------------------------------------------------------------- 1 | package integration 2 | 3 | //go:generate go run ../../cmd/go-mockgen ./testdata -f -d ./testdata --disable-formatting 4 | //go:generate go run ../../cmd/go-mockgen ./testdata -f -d ./testdata/mocks --disable-formatting 5 | -------------------------------------------------------------------------------- /internal/integration/gomega_test.go: -------------------------------------------------------------------------------- 1 | package integration 2 | 3 | import ( 4 | "fmt" 5 | "testing" 6 | 7 | "github.com/derision-test/go-mockgen/v2/internal/integration/testdata" 8 | "github.com/derision-test/go-mockgen/v2/internal/integration/testdata/mocks" 9 | . "github.com/derision-test/go-mockgen/v2/testutil/gomega" 10 | . 
"github.com/onsi/gomega" 11 | ) 12 | 13 | func TestGomegaCalls(t *testing.T) { 14 | RegisterTestingT(t) 15 | 16 | mock := mocks.NewMockClient() 17 | Expect(mock.CloseFunc).NotTo(BeCalled()) 18 | Expect(mock.Close()).To(BeNil()) 19 | Expect(mock.CloseFunc).To(BeCalled()) 20 | Expect(mock.CloseFunc).To(BeCalledOnce()) 21 | } 22 | 23 | func TestGomegaCallsWithArgs(t *testing.T) { 24 | RegisterTestingT(t) 25 | 26 | mock := mocks.NewMockClient() 27 | mock.Do("foo") 28 | Expect(mock.DoFunc).To(BeCalled()) 29 | Expect(mock.DoFunc).To(BeCalledOnce()) 30 | Expect(mock.DoFunc).To(BeCalledWith("foo")) 31 | Expect(mock.DoFunc).NotTo(BeCalledWith("bar")) 32 | } 33 | 34 | func TestGomegaCallsWithVariadicArgs(t *testing.T) { 35 | RegisterTestingT(t) 36 | 37 | mock := mocks.NewMockClient() 38 | mock.DoArgs("foo", 1, 2, 3) 39 | Expect(mock.DoArgsFunc).To(BeCalledWith("foo", 1, 2, 3)) 40 | Expect(mock.DoArgsFunc).To(BeCalledWith(Equal("foo"), Equal(1), Equal(2), Equal(3))) 41 | 42 | mock.DoArgs("bar", 42) 43 | mock.DoArgs("baz") 44 | Expect(mock.DoArgsFunc).To(BeCalledN(3)) 45 | Expect(mock.DoArgsFunc).To(BeCalledNWith(2, ContainSubstring("a"))) 46 | 47 | // Mismatched variadic arg 48 | Expect(mock.DoArgsFunc).NotTo(BeCalledWith("baz", BeAnything())) 49 | } 50 | 51 | func TestGomegaPushHook(t *testing.T) { 52 | RegisterTestingT(t) 53 | 54 | child1 := mocks.NewMockChild() 55 | child2 := mocks.NewMockChild() 56 | child3 := mocks.NewMockChild() 57 | parent := mocks.NewMockParent() 58 | 59 | parent.GetChildFunc.PushHook(func(i int) (testdata.Child, error) { return child1, nil }) 60 | parent.GetChildFunc.PushHook(func(i int) (testdata.Child, error) { return child2, nil }) 61 | parent.GetChildFunc.PushHook(func(i int) (testdata.Child, error) { return child3, nil }) 62 | parent.GetChildFunc.SetDefaultHook(func(i int) (testdata.Child, error) { 63 | return nil, fmt.Errorf("uh-oh") 64 | }) 65 | 66 | for _, expected := range []interface{}{child1, child2, child3} { 67 | 
Expect(parent.GetChild(0)).To(Equal(expected)) 68 | } 69 | 70 | _, err := parent.GetChild(0) 71 | Expect(err).To(MatchError("uh-oh")) 72 | } 73 | 74 | func TestGomegaSetDefaultReturn(t *testing.T) { 75 | RegisterTestingT(t) 76 | 77 | parent := mocks.NewMockParent() 78 | parent.GetChildFunc.SetDefaultReturn(nil, fmt.Errorf("uh-oh")) 79 | _, err := parent.GetChild(0) 80 | Expect(err).To(MatchError("uh-oh")) 81 | } 82 | 83 | func TestGomegaPushReturn(t *testing.T) { 84 | RegisterTestingT(t) 85 | 86 | parent := mocks.NewMockParent() 87 | parent.GetChildrenFunc.PushReturn([]testdata.Child{nil}) 88 | parent.GetChildrenFunc.PushReturn([]testdata.Child{nil, nil}) 89 | parent.GetChildrenFunc.PushReturn([]testdata.Child{nil, nil, nil}) 90 | 91 | Expect(parent.GetChildren()).To(HaveLen(1)) 92 | Expect(parent.GetChildren()).To(HaveLen(2)) 93 | Expect(parent.GetChildren()).To(HaveLen(3)) 94 | Expect(parent.GetChildren()).To(HaveLen(0)) 95 | } 96 | 97 | func TestGomegaGenerics(t *testing.T) { 98 | RegisterTestingT(t) 99 | 100 | mock := mocks.NewMockI2[string, int]() 101 | mock.M2Func.SetDefaultReturn(42) 102 | Expect(mock.M2("foo")).To(Equal(42)) 103 | Expect(mock.M2Func).To(BeCalledOnceWith("foo")) 104 | } 105 | -------------------------------------------------------------------------------- /internal/integration/strict_test.go: -------------------------------------------------------------------------------- 1 | package integration 2 | 3 | import ( 4 | "context" 5 | "testing" 6 | 7 | "github.com/derision-test/go-mockgen/v2/internal/integration/testdata/mocks" 8 | "github.com/stretchr/testify/assert" 9 | ) 10 | 11 | func TestStrictConstructor(t *testing.T) { 12 | // All invocations panic by default 13 | mock := mocks.NewStrictMockRetrier() 14 | 15 | assert.Panics(t, func() { 16 | _ = mock.Retry(context.Background(), func() error { 17 | return nil 18 | }) 19 | }) 20 | 21 | // Should not panic if overwritten 22 | mock.RetryFunc.SetDefaultReturn(nil) 23 | 24 | assert.NotPanics(t, 
func() { 25 | _ = mock.Retry(context.Background(), func() error { 26 | return nil 27 | }) 28 | }) 29 | } 30 | -------------------------------------------------------------------------------- /internal/integration/testdata/.gitignore: -------------------------------------------------------------------------------- 1 | *_mock.go 2 | -------------------------------------------------------------------------------- /internal/integration/testdata/complex_types.go: -------------------------------------------------------------------------------- 1 | package testdata 2 | 3 | type InterfaceType interface { 4 | ComplexParam(interface{ M(int) bool }) 5 | CopmlexResult() (interface{ M(int) bool }, error) 6 | } 7 | -------------------------------------------------------------------------------- /internal/integration/testdata/empty.go: -------------------------------------------------------------------------------- 1 | package testdata 2 | 3 | type Empty interface{} 4 | -------------------------------------------------------------------------------- /internal/integration/testdata/generics.go: -------------------------------------------------------------------------------- 1 | package testdata 2 | 3 | type I1[T any] interface { 4 | M1(T) 5 | } 6 | 7 | type I2[T1, T2 any] interface { 8 | I1[T1] 9 | M2(T1) T2 10 | } 11 | 12 | type SignedIntegerConstraint interface { 13 | ~int | ~int8 | ~int16 | ~int32 | ~int64 14 | } 15 | 16 | type I3[T SignedIntegerConstraint] interface { 17 | M3(I1[T]) int 18 | } 19 | 20 | type I4[T interface{ ~string }] interface { 21 | M4(I1[T]) int 22 | } 23 | 24 | type fooer[T any] interface { 25 | Foo() T 26 | } 27 | -------------------------------------------------------------------------------- /internal/integration/testdata/mocks/mocks.go: -------------------------------------------------------------------------------- 1 | package mocks 2 | -------------------------------------------------------------------------------- 
/internal/integration/testdata/reference.go: --------------------------------------------------------------------------------
package testdata

import "context"

// Retrier is a fixture whose method takes a context and a value of the
// package-local function type Command.
type Retrier interface {
	Retry(ctx context.Context, command Command) error
}

// Command is a named function type used as the parameter type of
// Retrier.Retry.
type Command func() error

-------------------------------------------------------------------------------- /internal/integration/testdata/relation.go: --------------------------------------------------------------------------------
package testdata

// Parent and Child refer to each other, forming mutually recursive interface
// types.
type Parent interface {
	AddChild(c Child)
	GetChildren() []Child
	GetChild(i int) (Child, error)
}

// Child returns the Parent it belongs to.
type Child interface {
	Parent() Parent
}

-------------------------------------------------------------------------------- /internal/integration/testdata/unexported.go: --------------------------------------------------------------------------------
package testdata

// unexported is an interface whose name is not exported from the package.
type unexported interface {
	String() string
}

-------------------------------------------------------------------------------- /internal/integration/testdata/variadic_args.go: --------------------------------------------------------------------------------
package testdata

// Client combines a method with no arguments, a method with one fixed
// argument, and a method with a trailing variadic parameter.
type Client interface {
	Close() error
	Do(command string) (interface{}, error)
	DoArgs(command string, args ...interface{}) (interface{}, error)
}

-------------------------------------------------------------------------------- /internal/integration/testdata/variadic_non_interface.go: --------------------------------------------------------------------------------
package testdata

// OptionValidator takes a variadic parameter whose element type is a struct
// rather than an interface.
type OptionValidator interface {
	Validate(options ...Option) error
}

// Option is an empty struct used as OptionValidator's variadic element type.
type Option struct{}

-------------------------------------------------------------------------------- /internal/integration/testify_test.go: --------------------------------------------------------------------------------
package integration
2 | 3 | import ( 4 | "fmt" 5 | "strings" 6 | "testing" 7 | 8 | "github.com/derision-test/go-mockgen/v2/internal/integration/testdata" 9 | "github.com/derision-test/go-mockgen/v2/internal/integration/testdata/mocks" 10 | mockassert "github.com/derision-test/go-mockgen/v2/testutil/assert" 11 | "github.com/stretchr/testify/assert" 12 | ) 13 | 14 | func TestTestifyCalls(t *testing.T) { 15 | mock := mocks.NewMockClient() 16 | mockassert.NotCalled(t, mock.CloseFunc) 17 | assert.Nil(t, mock.Close()) 18 | mockassert.Called(t, mock.CloseFunc) 19 | mockassert.CalledOnce(t, mock.CloseFunc) 20 | } 21 | 22 | func TestTestifyCallsWithArgs(t *testing.T) { 23 | mock := mocks.NewMockClient() 24 | mock.Do("foo") 25 | mockassert.Called(t, mock.DoFunc) 26 | mockassert.CalledOnce(t, mock.DoFunc) 27 | mockassert.CalledWith(t, mock.DoFunc, mockassert.Values("foo")) 28 | mockassert.NotCalledWith(t, mock.DoFunc, mockassert.Values("bar")) 29 | } 30 | 31 | func TestTestifyCallsWithVariadicArgs(t *testing.T) { 32 | mock := mocks.NewMockClient() 33 | mock.DoArgs("foo", 1, 2, 3) 34 | mockassert.CalledWith(t, mock.DoArgsFunc, mockassert.Values("foo", 1, 2, 3)) 35 | 36 | mock.DoArgs("bar", 42) 37 | mock.DoArgs("baz") 38 | mockassert.CalledN(t, mock.DoArgsFunc, 3) 39 | mockassert.CalledNWith(t, mock.DoArgsFunc, 2, mockassert.Values( 40 | func(v string) bool { return strings.Contains(v, "a") }, 41 | )) 42 | mockassert.CalledAtNWith(t, mock.DoArgsFunc, 1, mockassert.Values( 43 | func(v string) bool { return strings.Contains(v, "a") }, 44 | )) 45 | mockassert.CalledAtNWith(t, mock.DoArgsFunc, 2, mockassert.Values( 46 | func(v string) bool { return strings.Contains(v, "a") }, 47 | )) 48 | 49 | // Mismatched variadic arg 50 | mockassert.NotCalledWith(t, mock.DoArgsFunc, mockassert.Values( 51 | mockassert.Skip, 52 | func(v []interface{}) bool { return len(v) > 0 }, 53 | )) 54 | } 55 | 56 | func TestTestifyPushHook(t *testing.T) { 57 | child1 := mocks.NewMockChild() 58 | child2 := mocks.NewMockChild() 59 
| child3 := mocks.NewMockChild() 60 | parent := mocks.NewMockParent() 61 | 62 | parent.GetChildFunc.PushHook(func(i int) (testdata.Child, error) { return child1, nil }) 63 | parent.GetChildFunc.PushHook(func(i int) (testdata.Child, error) { return child2, nil }) 64 | parent.GetChildFunc.PushHook(func(i int) (testdata.Child, error) { return child3, nil }) 65 | parent.GetChildFunc.SetDefaultHook(func(i int) (testdata.Child, error) { 66 | return nil, fmt.Errorf("uh-oh") 67 | }) 68 | 69 | for _, expected := range []interface{}{child1, child2, child3} { 70 | child, _ := parent.GetChild(0) 71 | assert.Equal(t, expected, child) 72 | } 73 | 74 | _, err := parent.GetChild(0) 75 | assert.EqualError(t, err, "uh-oh") 76 | } 77 | 78 | func TestTestifySetDefaultReturn(t *testing.T) { 79 | parent := mocks.NewMockParent() 80 | parent.GetChildFunc.SetDefaultReturn(nil, fmt.Errorf("uh-oh")) 81 | _, err := parent.GetChild(0) 82 | assert.EqualError(t, err, "uh-oh") 83 | } 84 | 85 | func TestTestifyPushReturn(t *testing.T) { 86 | parent := mocks.NewMockParent() 87 | parent.GetChildrenFunc.PushReturn([]testdata.Child{nil}) 88 | parent.GetChildrenFunc.PushReturn([]testdata.Child{nil, nil}) 89 | parent.GetChildrenFunc.PushReturn([]testdata.Child{nil, nil, nil}) 90 | 91 | assert.Len(t, parent.GetChildren(), 1) 92 | assert.Len(t, parent.GetChildren(), 2) 93 | assert.Len(t, parent.GetChildren(), 3) 94 | assert.Len(t, parent.GetChildren(), 0) 95 | } 96 | 97 | func TestTestifyGenerics(t *testing.T) { 98 | mock := mocks.NewMockI2[string, int]() 99 | mock.M2Func.SetDefaultReturn(42) 100 | assert.Equal(t, 42, mock.M2("foo")) 101 | mockassert.CalledOnceWith(t, mock.M2Func, mockassert.Values("foo")) 102 | } 103 | -------------------------------------------------------------------------------- /internal/mockgen/consts/consts.go: -------------------------------------------------------------------------------- 1 | package consts 2 | 3 | const ( 4 | Name = "go-mockgen" 5 | PackageName = 
"github.com/derision-test/go-mockgen/v2" 6 | Description = "go-mockgen generates mock implementations from interface definitions." 7 | Version = "2.0.2" 8 | ) 9 | -------------------------------------------------------------------------------- /internal/mockgen/generation/errors.go: -------------------------------------------------------------------------------- 1 | package generation 2 | 3 | type errorWithSolutions struct { 4 | err error 5 | solutions []string 6 | } 7 | 8 | func (e errorWithSolutions) Error() string { return e.err.Error() } 9 | func (e errorWithSolutions) Solutions() []string { return e.solutions } 10 | -------------------------------------------------------------------------------- /internal/mockgen/generation/generate.go: -------------------------------------------------------------------------------- 1 | package generation 2 | 3 | import ( 4 | "bytes" 5 | "fmt" 6 | "io/ioutil" 7 | "log" 8 | "os/exec" 9 | "path" 10 | "path/filepath" 11 | "strings" 12 | 13 | "github.com/dave/jennifer/jen" 14 | "github.com/derision-test/go-mockgen/v2/internal/mockgen/consts" 15 | "github.com/derision-test/go-mockgen/v2/internal/mockgen/paths" 16 | "github.com/derision-test/go-mockgen/v2/internal/mockgen/types" 17 | ) 18 | 19 | type Options struct { 20 | PackageOptions []PackageOptions 21 | OutputOptions OutputOptions 22 | ContentOptions ContentOptions 23 | } 24 | 25 | type PackageOptions struct { 26 | ImportPaths []string 27 | Interfaces []string 28 | Exclude []string 29 | Prefix string 30 | } 31 | 32 | type OutputOptions struct { 33 | OutputFilename string 34 | OutputDir string 35 | Force bool 36 | DisableFormatting bool 37 | GoImportsBinary string 38 | ForTest bool 39 | } 40 | 41 | type ContentOptions struct { 42 | PkgName string 43 | OutputImportPath string 44 | Prefix string 45 | ConstructorPrefix string 46 | FilePrefix string 47 | BuildConstraints string 48 | } 49 | 50 | func Generate(ifaces []*types.Interface, opts *Options) error { 51 | if 
opts.OutputOptions.OutputFilename != "" { 52 | return generateFile(ifaces, opts) 53 | } 54 | 55 | return generateDirectory(ifaces, opts) 56 | } 57 | 58 | func generateFile(ifaces []*types.Interface, opts *Options) error { 59 | basename := opts.OutputOptions.OutputFilename 60 | if opts.OutputOptions.ForTest { 61 | ext := filepath.Ext(basename) 62 | basename = strings.TrimSuffix(basename, ext) + "_test" + ext 63 | } 64 | 65 | filename := filepath.Join(opts.OutputOptions.OutputDir, basename) 66 | 67 | exists, err := paths.Exists(filename) 68 | if err != nil { 69 | return err 70 | } 71 | if exists && !opts.OutputOptions.Force { 72 | return fmt.Errorf("filename %s already exists, overwrite with --force", paths.GetRelativePath(filename)) 73 | } 74 | 75 | return generateAndRender(ifaces, filename, opts) 76 | } 77 | 78 | func generateDirectory(ifaces []*types.Interface, opts *Options) error { 79 | suffix := "_mock" 80 | if opts.OutputOptions.ForTest { 81 | suffix += "_test" 82 | } 83 | 84 | makeFilename := func(iface *types.Interface) string { 85 | prefix := opts.ContentOptions.Prefix 86 | if iface.Prefix != "" { 87 | prefix = iface.Prefix 88 | } 89 | if prefix != "" { 90 | prefix += "_" 91 | } 92 | 93 | filename := fmt.Sprintf("%s%s%s.go", prefix, iface.Name, suffix) 94 | return path.Join(opts.OutputOptions.OutputDir, strings.Replace(strings.ToLower(filename), "-", "_", -1)) 95 | } 96 | 97 | if !opts.OutputOptions.Force { 98 | allPaths := make([]string, 0, len(ifaces)) 99 | for _, iface := range ifaces { 100 | 101 | allPaths = append(allPaths, makeFilename(iface)) 102 | } 103 | 104 | conflict, err := paths.AnyExists(allPaths) 105 | if err != nil { 106 | return err 107 | } 108 | if conflict != "" { 109 | return fmt.Errorf("filename %s already exists, overwrite with --force", paths.GetRelativePath(conflict)) 110 | } 111 | } 112 | 113 | for _, iface := range ifaces { 114 | if err := generateAndRender([]*types.Interface{iface}, makeFilename(iface), opts); err != nil { 115 | 
return err 116 | } 117 | } 118 | 119 | return nil 120 | } 121 | 122 | func generateAndRender(ifaces []*types.Interface, filename string, opts *Options) error { 123 | pkgName := opts.ContentOptions.PkgName 124 | if opts.OutputOptions.ForTest { 125 | pkgName += "_test" 126 | } 127 | 128 | content, err := generateContent(ifaces, pkgName, opts.ContentOptions) 129 | if err != nil { 130 | return err 131 | } 132 | 133 | log.Printf("writing to '%s'\n", paths.GetRelativePath(filename)) 134 | if err := ioutil.WriteFile(filename, []byte(content), 0644); err != nil { 135 | return err 136 | } 137 | 138 | if !opts.OutputOptions.DisableFormatting { 139 | if err := exec.Command(opts.OutputOptions.GoImportsBinary, "-w", filename).Run(); err != nil { 140 | return errorWithSolutions{ 141 | err: fmt.Errorf("failed to format file: %s", err), 142 | solutions: []string{ 143 | "install goimports on your PATH", 144 | "specify a non-standard path to a goimports binary via --goimports", 145 | "disable post-render formatting via --disable-formatting", 146 | }, 147 | } 148 | } 149 | } 150 | 151 | return nil 152 | } 153 | 154 | func generateContent(ifaces []*types.Interface, pkgName string, opts ContentOptions) (string, error) { 155 | prefix := opts.Prefix 156 | constructorPrefix := opts.ConstructorPrefix 157 | fileContentPrefix := opts.FilePrefix 158 | outputImportPath := opts.OutputImportPath 159 | 160 | if fileContentPrefix != "" { 161 | separator := "\n// " 162 | fileContentPrefix = "\n//" + separator + strings.Join(strings.Split(strings.TrimSpace(fileContentPrefix), "\n"), separator) 163 | } 164 | 165 | file := jen.NewFile(pkgName) 166 | file.HeaderComment(fmt.Sprintf("// Code generated by %s %s; DO NOT EDIT.%s", consts.Name, consts.Version, fileContentPrefix)) 167 | 168 | if opts.BuildConstraints != "" { 169 | file.HeaderComment("//go:build " + opts.BuildConstraints) 170 | } 171 | 172 | for _, iface := range ifaces { 173 | log.Printf("generating code for interface '%s'\n", iface.Name) 174 
| generateInterface(file, iface, prefix, constructorPrefix, outputImportPath) 175 | } 176 | 177 | buffer := &bytes.Buffer{} 178 | if err := file.Render(buffer); err != nil { 179 | return "", err 180 | } 181 | 182 | return buffer.String(), nil 183 | } 184 | 185 | func generateInterface(file *jen.File, iface *types.Interface, prefix, constructorPrefix, outputImportPath string) { 186 | if iface.Prefix != "" { 187 | // Override parent prefix if one is set on the iface 188 | prefix = iface.Prefix 189 | } 190 | 191 | withConstructorPrefix := func(f func(*wrappedInterface, string, string) jen.Code) func(*wrappedInterface, string) jen.Code { 192 | return func(iface *wrappedInterface, outputImportPath string) jen.Code { 193 | return f(iface, constructorPrefix, outputImportPath) 194 | } 195 | } 196 | 197 | topLevelGenerators := []func(*wrappedInterface, string) jen.Code{ 198 | generateMockStruct, 199 | withConstructorPrefix(generateMockStructConstructor), 200 | withConstructorPrefix(generateMockStructStrictConstructor), 201 | withConstructorPrefix(generateMockStructFromConstructor), 202 | } 203 | 204 | methodGenerators := []func(*wrappedInterface, *wrappedMethod, string) jen.Code{ 205 | generateMockFuncStruct, 206 | generateMockInterfaceMethod, 207 | generateMockFuncSetHookMethod, 208 | generateMockFuncPushHookMethod, 209 | generateMockFuncSetReturnMethod, 210 | generateMockFuncPushReturnMethod, 211 | generateMockFuncNextHookMethod, 212 | generateMockFuncAppendCallMethod, 213 | generateMockFuncHistoryMethod, 214 | generateMockFuncCallStruct, 215 | generateMockFuncCallArgsMethod, 216 | generateMockFuncCallResultsMethod, 217 | } 218 | 219 | titleName := strings.ToUpper(string(iface.Name[0])) + iface.Name[1:] 220 | mockStructName := fmt.Sprintf("Mock%s%s", prefix, titleName) 221 | wrappedInterface := wrapInterface(iface, prefix, titleName, mockStructName, outputImportPath) 222 | 223 | for _, generator := range topLevelGenerators { 224 | file.Add(generator(wrappedInterface, 
outputImportPath)) 225 | file.Line() 226 | } 227 | 228 | for _, method := range wrappedInterface.wrappedMethods { 229 | for _, generator := range methodGenerators { 230 | file.Add(generator(wrappedInterface, method, outputImportPath)) 231 | file.Line() 232 | } 233 | } 234 | } 235 | -------------------------------------------------------------------------------- /internal/mockgen/generation/generate_comment.go: -------------------------------------------------------------------------------- 1 | package generation 2 | 3 | import ( 4 | "fmt" 5 | "strings" 6 | 7 | "github.com/dave/jennifer/jen" 8 | "github.com/mitchellh/go-wordwrap" 9 | ) 10 | 11 | var ( 12 | maxAllowance = 80 13 | minAllowance = maxAllowance - indent*maxLevels 14 | indent = 4 15 | maxLevels = 3 16 | ) 17 | 18 | func generateComment(level int, format string, args ...interface{}) *jen.Statement { 19 | allowance := maxAllowance - indent*level - 3 20 | if allowance < minAllowance { 21 | allowance = minAllowance 22 | } 23 | 24 | commentText := fmt.Sprintf(format, args...) 
25 | wrapped := wordwrap.WrapString(commentText, uint(allowance)) 26 | lines := strings.Split(wrapped, "\n") 27 | commentBlock := jen.Comment(lines[0]).Line() 28 | 29 | for i := 1; i < len(lines); i++ { 30 | commentBlock = commentBlock.Comment(lines[i]).Line() 31 | } 32 | 33 | return commentBlock 34 | } 35 | -------------------------------------------------------------------------------- /internal/mockgen/generation/generate_constructors.go: -------------------------------------------------------------------------------- 1 | package generation 2 | 3 | import ( 4 | "fmt" 5 | "strings" 6 | "unicode" 7 | 8 | "github.com/dave/jennifer/jen" 9 | "github.com/derision-test/go-mockgen/v2/internal/mockgen/types" 10 | ) 11 | 12 | func generateMockStructConstructor(iface *wrappedInterface, constructorPrefix, outputImportPath string) jen.Code { 13 | makeField := func(method *wrappedMethod) jen.Code { 14 | return makeDefaultHookField(iface, method, outputImportPath, generateNoopFunction(iface, method, outputImportPath)) 15 | } 16 | 17 | name := fmt.Sprintf("New%s%s", constructorPrefix, iface.mockStructName) 18 | commentText := []string{ 19 | fmt.Sprintf(`%s creates a new mock of the %s interface.`, name, iface.Name), 20 | `All methods return zero values for all results, unless overwritten.`, 21 | } 22 | return generateConstructor(iface, strings.Join(commentText, " "), name, nil, outputImportPath, makeField) 23 | } 24 | 25 | func generateMockStructStrictConstructor(iface *wrappedInterface, constructorPrefix, outputImportPath string) jen.Code { 26 | makeField := func(method *wrappedMethod) jen.Code { 27 | return makeDefaultHookField(iface, method, outputImportPath, generatePanickingFunction(iface, method, outputImportPath)) 28 | } 29 | 30 | name := fmt.Sprintf("NewStrict%s%s", constructorPrefix, iface.mockStructName) 31 | commentText := []string{ 32 | fmt.Sprintf(`%s creates a new mock of the %s interface.`, name, iface.Name), 33 | `All methods panic on invocation, unless 
overwritten.`, 34 | } 35 | return generateConstructor(iface, strings.Join(commentText, " "), name, nil, outputImportPath, makeField) 36 | } 37 | 38 | func generateMockStructFromConstructor(iface *wrappedInterface, constructorPrefix, outputImportPath string) jen.Code { 39 | if !unicode.IsUpper([]rune(iface.Name)[0]) { 40 | surrogateStructName := fmt.Sprintf("surrogateMock%s", iface.titleName) 41 | surrogateDefinition := generateSurrogateInterface(iface, surrogateStructName, outputImportPath) 42 | name := jen.Id(surrogateStructName) 43 | constructor := generateMockStructFromConstructorCommon(iface, name, constructorPrefix, outputImportPath) 44 | return compose(surrogateDefinition, constructor) 45 | } 46 | 47 | importPath := sanitizeImportPath(iface.ImportPath, outputImportPath) 48 | name := jen.Qual(importPath, iface.Name) 49 | return generateMockStructFromConstructorCommon(iface, name, constructorPrefix, outputImportPath) 50 | } 51 | 52 | func generateMockStructFromConstructorCommon(iface *wrappedInterface, ifaceName *jen.Statement, constructorPrefix, outputImportPath string) jen.Code { 53 | makeField := func(method *wrappedMethod) jen.Code { 54 | // i. 
55 | return makeDefaultHookField(iface, method, outputImportPath, jen.Id("i").Dot(method.Name)) 56 | } 57 | 58 | name := fmt.Sprintf("New%s%sFrom", constructorPrefix, iface.mockStructName) 59 | commentText := []string{ 60 | fmt.Sprintf(`%s creates a new mock of the %s interface.`, name, iface.mockStructName), 61 | `All methods delegate to the given implementation, unless overwritten.`, 62 | } 63 | 64 | // (i ) 65 | params := []jen.Code{compose(jen.Id("i"), addTypes(ifaceName, iface.TypeParams, outputImportPath, false))} 66 | return generateConstructor(iface, strings.Join(commentText, " "), name, params, outputImportPath, makeField) 67 | } 68 | 69 | func generateConstructor( 70 | iface *wrappedInterface, 71 | commentText string, 72 | methodName string, 73 | params []jen.Code, 74 | outputImportPath string, 75 | makeField func(method *wrappedMethod) jen.Code, 76 | ) jen.Code { 77 | constructorFields := make([]jen.Code, 0, len(iface.Methods)) 78 | for _, method := range iface.wrappedMethods { 79 | constructorFields = append(constructorFields, makeField(method)) 80 | } 81 | 82 | // return &Mock{ , ... } 83 | returnStatement := compose(jen.Return(), generateStructInitializer(iface.mockStructName, outputImportPath, iface.TypeParams, constructorFields...)) 84 | results := []jen.Code{addTypes(jen.Op("*").Id(iface.mockStructName), iface.TypeParams, outputImportPath, false)} 85 | functionDeclaration := compose(addTypes(jen.Func().Id(methodName), iface.TypeParams, outputImportPath, true), jen.Params(params...).Params(results...).Block(returnStatement)) 86 | return addComment(functionDeclaration, 1, commentText) 87 | } 88 | 89 | func generateNoopFunction(iface *wrappedInterface, method *wrappedMethod, outputImportPath string) jen.Code { 90 | rt := make([]jen.Code, 0, len(method.resultTypes)) 91 | for i, resultType := range method.resultTypes { 92 | // (r0 , r1 , ...) 
93 | rt = append(rt, compose(jen.Id(fmt.Sprintf("r%d", i)), resultType)) 94 | } 95 | 96 | // Note: an empty return here returns the zero valued variables r0, r1, ... 97 | return jen.Func().Params(method.paramTypes...).Params(rt...).Block(jen.Return()) 98 | } 99 | 100 | func generatePanickingFunction(iface *wrappedInterface, method *wrappedMethod, outputImportPath string) jen.Code { 101 | // panic("unexpected invocation of .") 102 | panicStatement := jen.Panic(jen.Lit(fmt.Sprintf("unexpected invocation of %s.%s", iface.mockStructName, method.Method.Name))) 103 | return jen.Func().Params(method.paramTypes...).Params(method.resultTypes...).Block(panicStatement) 104 | } 105 | 106 | func generateSurrogateInterface(iface *wrappedInterface, surrogateName, outputImportPath string) *jen.Statement { 107 | surrogateCommentText := strings.Join([]string{ 108 | fmt.Sprintf(`%s is a copy of the %s interface (from the package %s).`, surrogateName, iface.Name, iface.ImportPath), 109 | `It is redefined here as it is unexported in the source package.`, 110 | }, " ") 111 | 112 | signatures := make([]jen.Code, 0, len(iface.wrappedMethods)) 113 | for _, method := range iface.wrappedMethods { 114 | signatures = append(signatures, jen.Id(method.Name).Params(method.paramTypes...).Params(method.resultTypes...)) 115 | } 116 | 117 | // type interface { (, ...) (, ...), ... 
} 118 | typeDeclaration := addTypes(jen.Type().Id(surrogateName), iface.Interface.TypeParams, outputImportPath, true).Interface(signatures...).Line() 119 | return addComment(typeDeclaration, 1, surrogateCommentText) 120 | } 121 | 122 | func makeDefaultHookField(iface *wrappedInterface, method *wrappedMethod, outputImportPath string, function jen.Code) jen.Code { 123 | fieldName := fmt.Sprintf("%sFunc", method.Name) 124 | structName := fmt.Sprintf("%s%s%sFunc", iface.prefix, iface.titleName, method.Name) 125 | 126 | initializer := generateStructInitializer(structName, outputImportPath, iface.TypeParams, compose( 127 | jen.Id("defaultHook").Op(":"), 128 | function, 129 | )) 130 | 131 | // : &StructName{ defaultHook: } 132 | return compose(jen.Id(fieldName), jen.Op(":"), initializer) 133 | } 134 | 135 | func generateStructInitializer(structName string, outputImportPath string, typeParams []types.TypeParam, fields ...jen.Code) jen.Code { 136 | // &{ fields, ... } 137 | return compose(addTypes(jen.Op("&").Id(structName), typeParams, outputImportPath, false), jen.Values(padFields(fields)...)) 138 | } 139 | 140 | func padFields(fields []jen.Code) []jen.Code { 141 | paddedFields := make([]jen.Code, 0, len(fields)+1) 142 | for _, field := range fields { 143 | paddedFields = append(paddedFields, compose(jen.Line(), field)) 144 | } 145 | 146 | return append(paddedFields, jen.Line()) 147 | } 148 | -------------------------------------------------------------------------------- /internal/mockgen/generation/generate_constructors_test.go: -------------------------------------------------------------------------------- 1 | package generation 2 | 3 | import ( 4 | "fmt" 5 | "testing" 6 | 7 | "github.com/stretchr/testify/assert" 8 | ) 9 | 10 | func TestGenerateMockStructConstructor(t *testing.T) { 11 | code := generateMockStructConstructor(makeInterface(TestMethodStatus, TestMethodDo, TestMethodDof), "", "") 12 | expected := strip(` 13 | // NewMockTestClient creates a new mock of the 
Client interface. All methods 14 | // return zero values for all results, unless overwritten. 15 | func NewMockTestClient() *MockTestClient { 16 | return &MockTestClient{ 17 | StatusFunc: &TestClientStatusFunc{ 18 | defaultHook: func() (r0 string, r1 bool) { 19 | return 20 | }, 21 | }, 22 | DoFunc: &TestClientDoFunc{ 23 | defaultHook: func(string) (r0 bool) { 24 | return 25 | }, 26 | }, 27 | DofFunc: &TestClientDofFunc{ 28 | defaultHook: func(string, ...string) (r0 bool) { 29 | return 30 | }, 31 | }, 32 | } 33 | } 34 | `) 35 | assert.Equal(t, expected, fmt.Sprintf("%#v", code)) 36 | } 37 | 38 | func TestGenerateMockStructStrictConstructor(t *testing.T) { 39 | code := generateMockStructStrictConstructor(makeInterface(TestMethodStatus, TestMethodDo, TestMethodDof), "", "") 40 | expected := strip(` 41 | // NewStrictMockTestClient creates a new mock of the Client interface. All 42 | // methods panic on invocation, unless overwritten. 43 | func NewStrictMockTestClient() *MockTestClient { 44 | return &MockTestClient{ 45 | StatusFunc: &TestClientStatusFunc{ 46 | defaultHook: func() (string, bool) { 47 | panic("unexpected invocation of MockTestClient.Status") 48 | }, 49 | }, 50 | DoFunc: &TestClientDoFunc{ 51 | defaultHook: func(string) bool { 52 | panic("unexpected invocation of MockTestClient.Do") 53 | }, 54 | }, 55 | DofFunc: &TestClientDofFunc{ 56 | defaultHook: func(string, ...string) bool { 57 | panic("unexpected invocation of MockTestClient.Dof") 58 | }, 59 | }, 60 | } 61 | } 62 | `) 63 | assert.Equal(t, expected, fmt.Sprintf("%#v", code)) 64 | } 65 | 66 | func TestGenerateMockStructFromConstructor(t *testing.T) { 67 | code := generateMockStructFromConstructor(makeInterface(TestMethodStatus, TestMethodDo, TestMethodDof), "", "") 68 | expected := strip(` 69 | // NewMockTestClientFrom creates a new mock of the MockTestClient interface. 70 | // All methods delegate to the given implementation, unless overwritten. 
71 | func NewMockTestClientFrom(i test.Client) *MockTestClient { 72 | return &MockTestClient{ 73 | StatusFunc: &TestClientStatusFunc{ 74 | defaultHook: i.Status, 75 | }, 76 | DoFunc: &TestClientDoFunc{ 77 | defaultHook: i.Do, 78 | }, 79 | DofFunc: &TestClientDofFunc{ 80 | defaultHook: i.Dof, 81 | }, 82 | } 83 | } 84 | `) 85 | assert.Equal(t, expected, fmt.Sprintf("%#v", code)) 86 | } 87 | 88 | func TestGenerateMockStructFromConstructorUnexported(t *testing.T) { 89 | iface := makeBareInterface(TestMethodStatus, TestMethodDo, TestMethodDof) 90 | iface.Name = "client" 91 | code := generateMockStructFromConstructor(wrapInterface(iface, TestPrefix, TestTitleName, TestMockStructName, ""), "", "") 92 | 93 | expected := strip(` 94 | // surrogateMockClient is a copy of the client interface (from the package 95 | // github.com/derision-test/go-mockgen/v2/test). It is redefined here as it 96 | // is unexported in the source package. 97 | type surrogateMockClient interface { 98 | Status() (string, bool) 99 | Do(string) bool 100 | Dof(string, ...string) bool 101 | } 102 | 103 | // NewMockTestClientFrom creates a new mock of the MockTestClient interface. 104 | // All methods delegate to the given implementation, unless overwritten. 
105 | func NewMockTestClientFrom(i surrogateMockClient) *MockTestClient { 106 | return &MockTestClient{ 107 | StatusFunc: &TestClientStatusFunc{ 108 | defaultHook: i.Status, 109 | }, 110 | DoFunc: &TestClientDoFunc{ 111 | defaultHook: i.Do, 112 | }, 113 | DofFunc: &TestClientDofFunc{ 114 | defaultHook: i.Dof, 115 | }, 116 | } 117 | } 118 | `) 119 | assert.Equal(t, expected, fmt.Sprintf("%#v", code)) 120 | } 121 | -------------------------------------------------------------------------------- /internal/mockgen/generation/generate_mock_func_call_methods.go: -------------------------------------------------------------------------------- 1 | package generation 2 | 3 | import ( 4 | "fmt" 5 | "strings" 6 | 7 | "github.com/dave/jennifer/jen" 8 | ) 9 | 10 | func generateMockFuncCallArgsMethod(iface *wrappedInterface, method *wrappedMethod, outputImportPath string) jen.Code { 11 | if method.Variadic { 12 | return generateMockFuncCallArgsMethodVariadic(iface, method, outputImportPath) 13 | } 14 | 15 | return generateMockFuncCallArgsMethodNonVariadic(iface, method, outputImportPath) 16 | } 17 | 18 | func generateMockFuncCallArgsMethodNonVariadic(iface *wrappedInterface, method *wrappedMethod, outputImportPath string) jen.Code { 19 | commentText := `Args returns an interface slice containing the arguments of this invocation.` 20 | 21 | valueExpressions := make([]jen.Code, 0, len(method.Params)) 22 | for i := range method.Params { 23 | valueExpressions = append(valueExpressions, jen.Id("c").Dot(fmt.Sprintf("Arg%d", i))) 24 | } 25 | returnStatement := jen.Return().Index().Interface().Values(valueExpressions...) 26 | 27 | results := []jen.Code{jen.Index().Interface()} 28 | return generateMockFuncCallMethod(iface, outputImportPath, method, "Args", commentText, nil, results, 29 | returnStatement, // return []interface{ c.Arg, ... 
} 30 | ) 31 | } 32 | 33 | func generateMockFuncCallArgsMethodVariadic(iface *wrappedInterface, method *wrappedMethod, outputImportPath string) jen.Code { 34 | commentText := strings.Join([]string{ 35 | `Args returns an interface slice containing the arguments of this invocation.`, 36 | `The variadic slice argument is flattened in this array such that one positional argument and three variadic arguments would result in a slice of four, not two.`, 37 | }, " ") 38 | 39 | valueExpressions := make([]jen.Code, 0, len(method.Params)) 40 | for i := range method.Params { 41 | valueExpressions = append(valueExpressions, jen.Id("c").Dot(fmt.Sprintf("Arg%d", i))) 42 | } 43 | 44 | lastIndex := len(valueExpressions) - 1 45 | trailingDeclaration := jen.Id("trailing").Op(":=").Index().Interface().Values() 46 | loopCondition := compose(jen.Id("_").Op(",").Id("val").Op(":=").Range(), valueExpressions[lastIndex]) 47 | loopBody := selfAppend(jen.Id("trailing"), jen.Id("val")) 48 | loopStatement := jen.For(loopCondition).Block(loopBody) 49 | simpleValuesExpression := jen.Index().Interface().Values(valueExpressions[:lastIndex]...) 50 | returnStatement := jen.Return().Append(simpleValuesExpression, jen.Id("trailing").Op("...")) 51 | 52 | results := []jen.Code{jen.Index().Interface()} 53 | return generateMockFuncCallMethod(iface, outputImportPath, method, "Args", commentText, nil, results, 54 | trailingDeclaration, // trailingDeclaration := []interface{} 55 | loopStatement, jen.Line(), jen.Line(), // for _, val := range Arg { trailing = append(trailing, val) } 56 | returnStatement, // return append([]interface{ , ... }, trailing...) 
57 | ) 58 | } 59 | 60 | func generateMockFuncCallResultsMethod(iface *wrappedInterface, method *wrappedMethod, outputImportPath string) jen.Code { 61 | commentText := `Results returns an interface slice containing the results of this invocation.` 62 | 63 | values := make([]jen.Code, 0, len(method.Results)) 64 | for i := range method.Results { 65 | values = append(values, jen.Id("c").Dot(fmt.Sprintf("Result%d", i))) 66 | } 67 | 68 | returnStatement := jen.Return().Index().Interface().Values(values...) 69 | 70 | results := []jen.Code{jen.Index().Interface()} 71 | return generateMockFuncCallMethod(iface, outputImportPath, method, "Results", commentText, nil, results, 72 | returnStatement, // return []interface{ c.Result, ... } 73 | ) 74 | } 75 | 76 | func generateMockFuncCallMethod( 77 | iface *wrappedInterface, 78 | outputImportPath string, 79 | method *wrappedMethod, 80 | methodName string, 81 | commentText string, 82 | params, results []jen.Code, 83 | body ...jen.Code, 84 | ) jen.Code { 85 | mockFuncCallStructName := fmt.Sprintf("%s%s%sFuncCall", iface.prefix, iface.titleName, method.Name) 86 | receiver := compose(jen.Id("c"), addTypes(jen.Id(mockFuncCallStructName), iface.TypeParams, outputImportPath, false)) 87 | methodDeclaration := jen.Func().Params(receiver).Id(methodName).Params(params...).Params(results...).Block(body...) 
88 | return addComment(methodDeclaration, 1, commentText) 89 | } 90 | -------------------------------------------------------------------------------- /internal/mockgen/generation/generate_mock_func_call_methods_test.go: -------------------------------------------------------------------------------- 1 | package generation 2 | 3 | import ( 4 | "fmt" 5 | "testing" 6 | 7 | "github.com/stretchr/testify/assert" 8 | ) 9 | 10 | func TestGenerateMockFuncCallArgsMethod(t *testing.T) { 11 | wrappedInterface := makeInterface(TestMethodDo) 12 | code := generateMockFuncCallArgsMethod(wrappedInterface, wrappedInterface.wrappedMethods[0], "") 13 | expected := strip(` 14 | // Args returns an interface slice containing the arguments of this 15 | // invocation. 16 | func (c TestClientDoFuncCall) Args() []interface{} { 17 | return []interface{}{c.Arg0} 18 | } 19 | `) 20 | assert.Equal(t, expected, fmt.Sprintf("%#v", code)) 21 | } 22 | 23 | func TestGenerateMockFuncCallArgsMethodVariadic(t *testing.T) { 24 | wrappedInterface := makeInterface(TestMethodDof) 25 | code := generateMockFuncCallArgsMethod(wrappedInterface, wrappedInterface.wrappedMethods[0], "") 26 | expected := strip(` 27 | // Args returns an interface slice containing the arguments of this 28 | // invocation. The variadic slice argument is flattened in this array such 29 | // that one positional argument and three variadic arguments would result in 30 | // a slice of four, not two. 31 | func (c TestClientDofFuncCall) Args() []interface{} { 32 | trailing := []interface{}{} 33 | for _, val := range c.Arg1 { 34 | trailing = append(trailing, val) 35 | } 36 | 37 | return append([]interface{}{c.Arg0}, trailing...) 
38 | } 39 | `) 40 | assert.Equal(t, expected, fmt.Sprintf("%#v", code)) 41 | } 42 | 43 | func TestGenerateMockFuncCallResultsMethod(t *testing.T) { 44 | wrappedInterface := makeInterface(TestMethodDo) 45 | code := generateMockFuncCallResultsMethod(wrappedInterface, wrappedInterface.wrappedMethods[0], "") 46 | expected := strip(` 47 | // Results returns an interface slice containing the results of this 48 | // invocation. 49 | func (c TestClientDoFuncCall) Results() []interface{} { 50 | return []interface{}{c.Result0} 51 | } 52 | `) 53 | assert.Equal(t, expected, fmt.Sprintf("%#v", code)) 54 | } 55 | 56 | func TestGenerateMockFuncCallResultsMethodMultiple(t *testing.T) { 57 | wrappedInterface := makeInterface(TestMethodStatus) 58 | code := generateMockFuncCallResultsMethod(wrappedInterface, wrappedInterface.wrappedMethods[0], "") 59 | expected := strip(` 60 | // Results returns an interface slice containing the results of this 61 | // invocation. 62 | func (c TestClientStatusFuncCall) Results() []interface{} { 63 | return []interface{}{c.Result0, c.Result1} 64 | } 65 | `) 66 | assert.Equal(t, expected, fmt.Sprintf("%#v", code)) 67 | } 68 | -------------------------------------------------------------------------------- /internal/mockgen/generation/generate_mock_func_methods.go: -------------------------------------------------------------------------------- 1 | package generation 2 | 3 | import ( 4 | "fmt" 5 | "strings" 6 | 7 | "github.com/dave/jennifer/jen" 8 | ) 9 | 10 | func generateMockFuncSetHookMethod(iface *wrappedInterface, method *wrappedMethod, outputImportPath string) jen.Code { 11 | commentText := fmt.Sprintf( 12 | `SetDefaultHook sets function that is called when the %s method of the parent %s instance is invoked and the hook queue is empty.`, 13 | method.Name, 14 | iface.mockStructName, 15 | ) 16 | 17 | assignStatement := jen.Id("f").Dot("defaultHook").Op("=").Id("hook") 18 | 19 | params := []jen.Code{compose(jen.Id("hook"), method.signature)} 20 | 
return generateMockFuncMethod(iface, outputImportPath, method, "SetDefaultHook", commentText, params, nil, 21 | assignStatement, // f.defaultHook = hook 22 | ) 23 | } 24 | 25 | func generateMockFuncPushHookMethod(iface *wrappedInterface, method *wrappedMethod, outputImportPath string) jen.Code { 26 | commentText := strings.Join([]string{ 27 | `PushHook adds a function to the end of hook queue.`, 28 | fmt.Sprintf(`Each invocation of the %s method of the parent %s instance invokes the hook at the front of the queue and discards it.`, method.Name, iface.mockStructName), 29 | `After the queue is empty, the default hook function is invoked for any future action.`, 30 | }, " ") 31 | 32 | lockStatement := jen.Id("f").Dot("mutex").Dot("Lock").Call() 33 | unlockStatement := jen.Id("f").Dot("mutex").Dot("Unlock").Call() 34 | appendStatement := selfAppend(jen.Id("f").Dot("hooks"), jen.Id("hook")) 35 | 36 | params := []jen.Code{compose(jen.Id("hook"), method.signature)} 37 | return generateMockFuncMethod(iface, outputImportPath, method, "PushHook", commentText, params, nil, 38 | lockStatement, // f.mutex.Lock() 39 | appendStatement, // f.mutex.Unlock() 40 | unlockStatement, // f.hooks = append(f.hooks, hook) 41 | ) 42 | } 43 | 44 | func generateMockFuncSetReturnMethod(iface *wrappedInterface, method *wrappedMethod, outputImportPath string) jen.Code { 45 | return generateMockReturnMethod(iface, method, "SetDefault", outputImportPath) 46 | } 47 | 48 | func generateMockFuncPushReturnMethod(iface *wrappedInterface, method *wrappedMethod, outputImportPath string) jen.Code { 49 | return generateMockReturnMethod(iface, method, "Push", outputImportPath) 50 | } 51 | 52 | func generateMockReturnMethod(iface *wrappedInterface, method *wrappedMethod, methodPrefix, outputImportPath string) jen.Code { 53 | commentText := fmt.Sprintf( 54 | `%sReturn calls %sHook with a function that returns the given values.`, 55 | methodPrefix, 56 | methodPrefix, 57 | ) 58 | 59 | names := make([]jen.Code, 
0, len(method.resultTypes)) 60 | params := make([]jen.Code, 0, len(method.resultTypes)) 61 | for i, typ := range method.resultTypes { 62 | name := jen.Id(fmt.Sprintf("r%d", i)) 63 | names = append(names, name) 64 | params = append(params, compose(name, typ)) 65 | } 66 | 67 | returnStatement := jen.Return().List(names...) 68 | functionExpression := jen.Func().Params(method.paramTypes...).Params(method.resultTypes...).Block(returnStatement) 69 | callStatement := jen.Id("f").Dot(fmt.Sprintf("%sHook", methodPrefix)).Call(functionExpression) 70 | 71 | return generateMockFuncMethod(iface, outputImportPath, method, fmt.Sprintf("%sReturn", methodPrefix), commentText, params, nil, 72 | callStatement, // f.Hook(func( T, ... ) { return r, ... }) 73 | ) 74 | } 75 | 76 | func generateMockFuncNextHookMethod(iface *wrappedInterface, method *wrappedMethod, outputImportPath string) jen.Code { 77 | lockStatement := jen.Id("f").Dot("mutex").Dot("Lock").Call() 78 | deferUnlockStatement := jen.Defer().Id("f").Dot("mutex").Dot("Unlock").Call() 79 | lenHooksExpression := jen.Len(jen.Id("f").Dot("hooks")) 80 | earlyReturnStatement := jen.Return(jen.Id("f").Dot("defaultHook")) 81 | returnDefaultIfEmptyCondition := jen.If(lenHooksExpression.Op("==").Lit(0)).Block(earlyReturnStatement) 82 | firstHookStatement := jen.Id("hook").Op(":=").Id("f").Dot("hooks").Index(jen.Lit(0)) 83 | popHookStatement := jen.Id("f").Dot("hooks").Op("=").Id("f").Dot("hooks").Index(jen.Lit(1).Op(":")) 84 | returnStatement := jen.Return(jen.Id("hook")) 85 | 86 | results := []jen.Code{method.signature} 87 | return generateMockFuncMethod(iface, outputImportPath, method, "nextHook", "", nil, results, 88 | lockStatement, // f.mutex.Lock() 89 | deferUnlockStatement, jen.Line(), // defer f.mutex.Unlock() 90 | returnDefaultIfEmptyCondition, jen.Line(), // if len(f.hooks) == 0 { return f.defaultHook } 91 | firstHookStatement, // hook := f.hooks[0] 92 | popHookStatement, // f.hooks = f.hooks[1:] 93 | returnStatement, // 
return hook 94 | ) 95 | } 96 | 97 | func generateMockFuncAppendCallMethod(iface *wrappedInterface, method *wrappedMethod, outputImportPath string) jen.Code { 98 | mockFuncCallStructName := fmt.Sprintf("%s%s%sFuncCall", iface.prefix, iface.titleName, method.Name) 99 | 100 | lockStatement := jen.Id("f").Dot("mutex").Dot("Lock").Call() 101 | unlockStatement := jen.Id("f").Dot("mutex").Dot("Unlock").Call() 102 | appendStatement := selfAppend(jen.Id("f").Dot("history"), jen.Id("r0")) 103 | 104 | params := []jen.Code{compose(jen.Id("r0"), addTypes(jen.Id(mockFuncCallStructName), iface.TypeParams, outputImportPath, false))} 105 | return generateMockFuncMethod(iface, outputImportPath, method, "appendCall", "", params, nil, 106 | lockStatement, // f.mutex.Lock() 107 | appendStatement, // f.history = append(f.history, r0) 108 | unlockStatement, // f.mutex.Unlock() 109 | ) 110 | } 111 | 112 | func generateMockFuncHistoryMethod(iface *wrappedInterface, method *wrappedMethod, outputImportPath string) jen.Code { 113 | mockFuncCallStructName := fmt.Sprintf("%s%s%sFuncCall", iface.prefix, iface.titleName, method.Name) 114 | commentText := fmt.Sprintf( 115 | `History returns a sequence of %s objects describing the invocations of this function.`, 116 | mockFuncCallStructName, 117 | ) 118 | 119 | lockStatement := jen.Id("f").Dot("mutex").Dot("Lock").Call() 120 | unlockStatement := jen.Id("f").Dot("mutex").Dot("Unlock").Call() 121 | callStructSliceType := compose(jen.Index(), addTypes(jen.Id(mockFuncCallStructName), iface.TypeParams, outputImportPath, false)) 122 | lenHistoryExpression := jen.Len(jen.Id("f").Dot("history")) 123 | makeSliceStatement := jen.Id("history").Op(":=").Make(callStructSliceType, lenHistoryExpression) 124 | copyStatement := jen.Copy(jen.Id("history"), jen.Id("f").Dot("history")) 125 | returnStatement := jen.Return().Id("history") 126 | 127 | results := []jen.Code{compose(jen.Index(), addTypes(jen.Id(mockFuncCallStructName), iface.TypeParams, outputImportPath, 
false))} 128 | return generateMockFuncMethod(iface, outputImportPath, method, "History", commentText, nil, results, 129 | lockStatement, // f.mutex.Lock() 130 | makeSliceStatement, // history := make([], len(f.history)) 131 | copyStatement, // copy(history, f.history) 132 | unlockStatement, jen.Line(), // f.mutex.Unlock() 133 | returnStatement, // return history 134 | ) 135 | } 136 | 137 | func generateMockFuncMethod( 138 | iface *wrappedInterface, 139 | outputImportPath string, 140 | method *wrappedMethod, 141 | methodName string, 142 | commentText string, 143 | params, results []jen.Code, 144 | body ...jen.Code, 145 | ) jen.Code { 146 | mockFuncStructName := fmt.Sprintf("%s%s%sFunc", iface.prefix, iface.titleName, method.Name) 147 | receiver := compose(jen.Id("f").Op("*"), addTypes(jen.Id(mockFuncStructName), iface.TypeParams, outputImportPath, false)) 148 | methodDeclaration := jen.Func().Params(receiver).Id(methodName).Params(params...).Params(results...).Block(body...) 149 | return addComment(methodDeclaration, 1, commentText) 150 | } 151 | -------------------------------------------------------------------------------- /internal/mockgen/generation/generate_mock_func_methods_test.go: -------------------------------------------------------------------------------- 1 | package generation 2 | 3 | import ( 4 | "fmt" 5 | "testing" 6 | 7 | "github.com/stretchr/testify/assert" 8 | ) 9 | 10 | func TestGenerateMockFuncSetHookMethod(t *testing.T) { 11 | wrappedInterface := makeInterface(TestMethodDo) 12 | code := generateMockFuncSetHookMethod(wrappedInterface, wrappedInterface.wrappedMethods[0], "") 13 | expected := strip(` 14 | // SetDefaultHook sets function that is called when the Do method of the 15 | // parent MockTestClient instance is invoked and the hook queue is empty. 
16 | func (f *TestClientDoFunc) SetDefaultHook(hook func(string) bool) { 17 | f.defaultHook = hook 18 | } 19 | `) 20 | assert.Equal(t, expected, fmt.Sprintf("%#v", code)) 21 | } 22 | 23 | func TestGenerateMockFuncSetHookMethodVariadic(t *testing.T) { 24 | wrappedInterface := makeInterface(TestMethodDof) 25 | code := generateMockFuncSetHookMethod(wrappedInterface, wrappedInterface.wrappedMethods[0], "") 26 | expected := strip(` 27 | // SetDefaultHook sets function that is called when the Dof method of the 28 | // parent MockTestClient instance is invoked and the hook queue is empty. 29 | func (f *TestClientDofFunc) SetDefaultHook(hook func(string, ...string) bool) { 30 | f.defaultHook = hook 31 | } 32 | `) 33 | assert.Equal(t, expected, fmt.Sprintf("%#v", code)) 34 | } 35 | 36 | func TestGenerateMockFuncPushHookMethod(t *testing.T) { 37 | wrappedInterface := makeInterface(TestMethodDo) 38 | code := generateMockFuncPushHookMethod(wrappedInterface, wrappedInterface.wrappedMethods[0], "") 39 | expected := strip(` 40 | // PushHook adds a function to the end of hook queue. Each invocation of the 41 | // Do method of the parent MockTestClient instance invokes the hook at the 42 | // front of the queue and discards it. After the queue is empty, the default 43 | // hook function is invoked for any future action. 44 | func (f *TestClientDoFunc) PushHook(hook func(string) bool) { 45 | f.mutex.Lock() 46 | f.hooks = append(f.hooks, hook) 47 | f.mutex.Unlock() 48 | } 49 | `) 50 | assert.Equal(t, expected, fmt.Sprintf("%#v", code)) 51 | } 52 | 53 | func TestGenerateMockFuncPushHookMethodVariadic(t *testing.T) { 54 | wrappedInterface := makeInterface(TestMethodDof) 55 | code := generateMockFuncPushHookMethod(wrappedInterface, wrappedInterface.wrappedMethods[0], "") 56 | expected := strip(` 57 | // PushHook adds a function to the end of hook queue. 
Each invocation of the 58 | // Dof method of the parent MockTestClient instance invokes the hook at the 59 | // front of the queue and discards it. After the queue is empty, the default 60 | // hook function is invoked for any future action. 61 | func (f *TestClientDofFunc) PushHook(hook func(string, ...string) bool) { 62 | f.mutex.Lock() 63 | f.hooks = append(f.hooks, hook) 64 | f.mutex.Unlock() 65 | } 66 | `) 67 | assert.Equal(t, expected, fmt.Sprintf("%#v", code)) 68 | } 69 | 70 | func TestGenerateMockFuncSetReturnMethod(t *testing.T) { 71 | wrappedInterface := makeInterface(TestMethodDo) 72 | code := generateMockFuncSetReturnMethod(wrappedInterface, wrappedInterface.wrappedMethods[0], "") 73 | expected := strip(` 74 | // SetDefaultReturn calls SetDefaultHook with a function that returns the 75 | // given values. 76 | func (f *TestClientDoFunc) SetDefaultReturn(r0 bool) { 77 | f.SetDefaultHook(func(string) bool { 78 | return r0 79 | }) 80 | } 81 | `) 82 | assert.Equal(t, expected, fmt.Sprintf("%#v", code)) 83 | } 84 | 85 | func TestGenerateMockFuncPushReturnMethod(t *testing.T) { 86 | wrappedInterface := makeInterface(TestMethodDo) 87 | code := generateMockFuncPushReturnMethod(wrappedInterface, wrappedInterface.wrappedMethods[0], "") 88 | expected := strip(` 89 | // PushReturn calls PushHook with a function that returns the given values. 
90 | func (f *TestClientDoFunc) PushReturn(r0 bool) { 91 | f.PushHook(func(string) bool { 92 | return r0 93 | }) 94 | } 95 | `) 96 | assert.Equal(t, expected, fmt.Sprintf("%#v", code)) 97 | } 98 | 99 | func TestGenerateMockFuncNextHookMethod(t *testing.T) { 100 | wrappedInterface := makeInterface(TestMethodDo) 101 | code := generateMockFuncNextHookMethod(wrappedInterface, wrappedInterface.wrappedMethods[0], "") 102 | expected := strip(` 103 | func (f *TestClientDoFunc) nextHook() func(string) bool { 104 | f.mutex.Lock() 105 | defer f.mutex.Unlock() 106 | 107 | if len(f.hooks) == 0 { 108 | return f.defaultHook 109 | } 110 | 111 | hook := f.hooks[0] 112 | f.hooks = f.hooks[1:] 113 | return hook 114 | } 115 | `) 116 | assert.Equal(t, expected, fmt.Sprintf("%#v", code)) 117 | } 118 | 119 | func TestGenerateMockFuncAppendCallMethod(t *testing.T) { 120 | wrappedInterface := makeInterface(TestMethodDo) 121 | code := generateMockFuncAppendCallMethod(wrappedInterface, wrappedInterface.wrappedMethods[0], "") 122 | expected := strip(` 123 | func (f *TestClientDoFunc) appendCall(r0 TestClientDoFuncCall) { 124 | f.mutex.Lock() 125 | f.history = append(f.history, r0) 126 | f.mutex.Unlock() 127 | } 128 | `) 129 | assert.Equal(t, expected, fmt.Sprintf("%#v", code)) 130 | } 131 | 132 | func TestGenerateMockFuncHistoryMethod(t *testing.T) { 133 | wrappedInterface := makeInterface(TestMethodDo) 134 | code := generateMockFuncHistoryMethod(wrappedInterface, wrappedInterface.wrappedMethods[0], "") 135 | expected := strip(` 136 | // History returns a sequence of TestClientDoFuncCall objects describing the 137 | // invocations of this function. 
138 | func (f *TestClientDoFunc) History() []TestClientDoFuncCall { 139 | f.mutex.Lock() 140 | history := make([]TestClientDoFuncCall, len(f.history)) 141 | copy(history, f.history) 142 | f.mutex.Unlock() 143 | 144 | return history 145 | } 146 | `) 147 | assert.Equal(t, expected, fmt.Sprintf("%#v", code)) 148 | } 149 | -------------------------------------------------------------------------------- /internal/mockgen/generation/generate_mock_methods.go: -------------------------------------------------------------------------------- 1 | package generation 2 | 3 | import ( 4 | "fmt" 5 | 6 | "github.com/dave/jennifer/jen" 7 | ) 8 | 9 | func generateMockInterfaceMethod(iface *wrappedInterface, method *wrappedMethod, outputImportPath string) jen.Code { 10 | mockFuncFieldName := fmt.Sprintf("%sFunc", method.Name) 11 | mockFuncCallStructName := fmt.Sprintf("%s%s%sFuncCall", iface.prefix, iface.titleName, method.Name) 12 | commentText := fmt.Sprintf( 13 | `%s delegates to the next hook function in the queue and stores the parameter and result values of this invocation.`, 14 | method.Name, 15 | ) 16 | 17 | paramNames := make([]jen.Code, 0, len(method.Params)) 18 | argumentExpressions := make([]jen.Code, 0, len(method.Params)) 19 | for i := 0; i < len(method.Params); i++ { 20 | name := fmt.Sprintf("v%d", i) 21 | 22 | nameExpression := jen.Id(name) 23 | if method.Variadic && i == len(method.Params)-1 { 24 | nameExpression = compose(nameExpression, jen.Op("...")) 25 | } 26 | 27 | paramNames = append(paramNames, jen.Id(name)) 28 | argumentExpressions = append(argumentExpressions, nameExpression) 29 | } 30 | 31 | resultNames := make([]jen.Code, 0, len(method.Results)) 32 | for i := 0; i < len(method.Results); i++ { 33 | resultNames = append(resultNames, jen.Id(fmt.Sprintf("r%d", i))) 34 | } 35 | 36 | functionExpression := jen.Id("m").Dot(mockFuncFieldName).Dot("nextHook").Call() 37 | callStatement := functionExpression.Call(argumentExpressions...) 
38 | callInstanceExpression := compose(addTypes(jen.Id(mockFuncCallStructName), iface.TypeParams, outputImportPath, false), jen.Values(append(paramNames, resultNames...)...)) 39 | appendFuncCall := jen.Id("m").Dot(mockFuncFieldName).Dot("appendCall").Call(callInstanceExpression) 40 | returnStatement := jen.Return() 41 | 42 | if len(method.Results) != 0 { 43 | assignmentTarget := jen.Id("r0") 44 | returnStatement = returnStatement.Id("r0") 45 | 46 | for i := 1; i < len(method.Results); i++ { 47 | assignmentTarget = assignmentTarget.Op(",").Id(fmt.Sprintf("r%d", i)) 48 | returnStatement = returnStatement.Op(",").Id(fmt.Sprintf("r%d", i)) 49 | } 50 | 51 | callStatement = compose(assignmentTarget.Op(":="), callStatement) 52 | } 53 | 54 | return generateMockMethod(iface, method, commentText, outputImportPath, 55 | callStatement, // r, ... := m.Func.nextHook()(Param, ...) 56 | appendFuncCall, // m.Func.appendCall(FuncCall{Param, ..., r, ...}) 57 | returnStatement, // return r, ... 58 | ) 59 | } 60 | 61 | func generateMockMethod( 62 | iface *wrappedInterface, 63 | method *wrappedMethod, 64 | commentText string, 65 | outputImportPath string, 66 | body ...jen.Code, 67 | ) jen.Code { 68 | params := make([]jen.Code, 0, len(method.paramTypes)) 69 | for i, param := range method.paramTypes { 70 | params = append(params, compose(jen.Id(fmt.Sprintf("v%d", i)), param)) 71 | } 72 | 73 | receiver := compose(jen.Id("m").Op("*"), addTypes(jen.Id(iface.mockStructName), iface.TypeParams, outputImportPath, false)) 74 | methodDeclaration := jen.Func().Params(receiver).Id(method.Name).Params(params...).Params(method.resultTypes...).Block(body...) 
75 | return addComment(methodDeclaration, 1, commentText) 76 | } 77 | -------------------------------------------------------------------------------- /internal/mockgen/generation/generate_mock_methods_test.go: -------------------------------------------------------------------------------- 1 | package generation 2 | 3 | import ( 4 | "fmt" 5 | "testing" 6 | 7 | "github.com/stretchr/testify/assert" 8 | ) 9 | 10 | func TestGenerateMockInterfaceMethod(t *testing.T) { 11 | wrappedInterface := makeInterface(TestMethodDo) 12 | code := generateMockInterfaceMethod(wrappedInterface, wrappedInterface.wrappedMethods[0], "") 13 | expected := strip(` 14 | // Do delegates to the next hook function in the queue and stores the 15 | // parameter and result values of this invocation. 16 | func (m *MockTestClient) Do(v0 string) bool { 17 | r0 := m.DoFunc.nextHook()(v0) 18 | m.DoFunc.appendCall(TestClientDoFuncCall{v0, r0}) 19 | return r0 20 | } 21 | `) 22 | assert.Equal(t, expected, fmt.Sprintf("%#v", code)) 23 | } 24 | 25 | func TestGenerateMockInterfaceMethodVariadic(t *testing.T) { 26 | wrappedInterface := makeInterface(TestMethodDof) 27 | code := generateMockInterfaceMethod(wrappedInterface, wrappedInterface.wrappedMethods[0], "") 28 | expected := strip(` 29 | // Dof delegates to the next hook function in the queue and stores the 30 | // parameter and result values of this invocation. 31 | func (m *MockTestClient) Dof(v0 string, v1 ...string) bool { 32 | r0 := m.DofFunc.nextHook()(v0, v1...) 
33 | m.DofFunc.appendCall(TestClientDofFuncCall{v0, v1, r0}) 34 | return r0 35 | } 36 | `) 37 | assert.Equal(t, expected, fmt.Sprintf("%#v", code)) 38 | } 39 | -------------------------------------------------------------------------------- /internal/mockgen/generation/generate_structs.go: -------------------------------------------------------------------------------- 1 | package generation 2 | 3 | import ( 4 | "fmt" 5 | "strconv" 6 | 7 | "github.com/dave/jennifer/jen" 8 | "github.com/derision-test/go-mockgen/v2/internal/mockgen/types" 9 | "github.com/dustin/go-humanize" 10 | ) 11 | 12 | func generateMockStruct(iface *wrappedInterface, outputImportPath string) jen.Code { 13 | mockStructName := iface.mockStructName 14 | commentText := fmt.Sprintf( 15 | `%s is a mock implementation of the %s interface (from the package %s) used for unit testing.`, 16 | mockStructName, 17 | iface.Name, 18 | iface.ImportPath, 19 | ) 20 | 21 | structFields := make([]jen.Code, 0, len(iface.Methods)) 22 | for _, method := range iface.wrappedMethods { 23 | mockFuncFieldName := fmt.Sprintf("%sFunc", method.Name) 24 | mockFuncStructName := fmt.Sprintf("%s%s%sFunc", iface.prefix, iface.titleName, method.Name) 25 | commentText := fmt.Sprintf( 26 | `%s is an instance of a mock function object controlling the behavior of the method %s.`, 27 | mockFuncFieldName, 28 | method.Name, 29 | ) 30 | 31 | hook := compose(jen.Id(mockFuncFieldName).Op("*"), addTypes(jen.Id(mockFuncStructName), iface.TypeParams, outputImportPath, false)) 32 | structFields = append(structFields, addComment(hook, 2, commentText)) 33 | } 34 | 35 | // Func *Func, ... 
36 | return generateStruct(mockStructName, iface.TypeParams, commentText, outputImportPath, structFields) 37 | } 38 | 39 | func generateMockFuncStruct(iface *wrappedInterface, method *wrappedMethod, outputImportPath string) jen.Code { 40 | mockStructName := iface.mockStructName 41 | mockFuncStructName := fmt.Sprintf("%s%s%sFunc", iface.prefix, iface.titleName, method.Name) 42 | mockFuncCallStructName := fmt.Sprintf("%s%s%sFuncCall", iface.prefix, iface.titleName, method.Name) 43 | commentText := fmt.Sprintf( 44 | `%s describes the behavior when the %s method of the parent %s instance is invoked.`, 45 | mockFuncStructName, 46 | method.Name, 47 | mockStructName, 48 | ) 49 | 50 | return generateStruct(mockFuncStructName, iface.TypeParams, commentText, outputImportPath, []jen.Code{ 51 | compose(jen.Id("defaultHook"), method.signature), // defaultHook 52 | compose(jen.Id("hooks").Index(), method.signature), // hooks [] 53 | compose(jen.Id("history").Index(), addTypes(jen.Id(mockFuncCallStructName), iface.TypeParams, outputImportPath, false)), // history []FuncCall 54 | jen.Id("mutex").Qual("sync", "Mutex"), // mutex sync.Mutex 55 | }) 56 | } 57 | 58 | func generateMockFuncCallStruct(iface *wrappedInterface, method *wrappedMethod, outputImportPath string) jen.Code { 59 | mockStructName := iface.mockStructName 60 | mockFuncCallStructName := fmt.Sprintf("%s%s%sFuncCall", iface.prefix, iface.titleName, method.Name) 61 | commentText := fmt.Sprintf( 62 | `%s is an object that describes an invocation of method %s on an instance of %s.`, 63 | mockFuncCallStructName, 64 | method.Name, 65 | mockStructName, 66 | ) 67 | 68 | makeFields := func(prefix string, params []jen.Code, makeComment commentFactory) []jen.Code { 69 | fields := make([]jen.Code, 0, len(params)) 70 | for i, param := range params { 71 | name := prefix + strconv.Itoa(i) 72 | field := jen.Id(name).Add(param) 73 | fields = append(fields, addComment(field, 2, makeComment(method, name, i))) 74 | } 75 | 76 | return 
fields 77 | } 78 | 79 | argFields := makeFields("Arg", method.dotlessParamTypes, argFieldComment) // Arg , ... 80 | resultFields := makeFields("Result", method.resultTypes, resultFieldComment) // Result , ... 81 | return generateStruct(mockFuncCallStructName, iface.TypeParams, commentText, outputImportPath, append(argFields, resultFields...)) 82 | } 83 | 84 | func generateStruct(name string, typeParams []types.TypeParam, commentText, outputImportPath string, structFields []jen.Code) jen.Code { 85 | typeDeclaration := compose(addTypes(jen.Type().Id(name), typeParams, outputImportPath, true), jen.Struct(structFields...)) 86 | return addComment(typeDeclaration, 1, commentText) 87 | } 88 | 89 | type commentFactory func(method *wrappedMethod, name string, i int) string 90 | 91 | var ( 92 | _ commentFactory = argFieldComment 93 | _ commentFactory = resultFieldComment 94 | ) 95 | 96 | func argFieldComment(method *wrappedMethod, name string, i int) string { 97 | if i == len(method.dotlessParamTypes)-1 && method.Variadic { 98 | return fmt.Sprintf( 99 | `%s is a slice containing the values of the variadic arguments passed to this method invocation.`, 100 | name, 101 | ) 102 | } 103 | 104 | return fmt.Sprintf( 105 | `%s is the value of the %s argument passed to this method invocation.`, 106 | name, 107 | humanize.Ordinal(i+1), 108 | ) 109 | } 110 | 111 | func resultFieldComment(method *wrappedMethod, name string, i int) string { 112 | return fmt.Sprintf( 113 | `%s is the value of the %s result returned from this method invocation.`, 114 | name, 115 | humanize.Ordinal(i+1), 116 | ) 117 | } 118 | -------------------------------------------------------------------------------- /internal/mockgen/generation/generate_structs_test.go: -------------------------------------------------------------------------------- 1 | package generation 2 | 3 | import ( 4 | "fmt" 5 | "testing" 6 | 7 | "github.com/stretchr/testify/assert" 8 | ) 9 | 10 | func TestGenerateMockStruct(t *testing.T) { 11 
| code := generateMockStruct(makeInterface(TestMethodStatus, TestMethodDo, TestMethodDof), "") 12 | expected := strip(` 13 | // MockTestClient is a mock implementation of the Client interface (from the 14 | // package github.com/derision-test/go-mockgen/v2/test) used for unit 15 | // testing. 16 | type MockTestClient struct { 17 | // StatusFunc is an instance of a mock function object controlling the 18 | // behavior of the method Status. 19 | StatusFunc *TestClientStatusFunc 20 | // DoFunc is an instance of a mock function object controlling the 21 | // behavior of the method Do. 22 | DoFunc *TestClientDoFunc 23 | // DofFunc is an instance of a mock function object controlling the 24 | // behavior of the method Dof. 25 | DofFunc *TestClientDofFunc 26 | } 27 | `) 28 | assert.Equal(t, expected, fmt.Sprintf("%#v", code)) 29 | } 30 | 31 | func TestGenerateFuncStruct(t *testing.T) { 32 | wrappedInterface := makeInterface(TestMethodDo) 33 | code := generateMockFuncStruct(wrappedInterface, wrappedInterface.wrappedMethods[0], "") 34 | expected := strip(` 35 | // TestClientDoFunc describes the behavior when the Do method of the parent 36 | // MockTestClient instance is invoked. 37 | type TestClientDoFunc struct { 38 | defaultHook func(string) bool 39 | hooks []func(string) bool 40 | history []TestClientDoFuncCall 41 | mutex sync.Mutex 42 | } 43 | `) 44 | assert.Equal(t, expected, fmt.Sprintf("%#v", code)) 45 | } 46 | 47 | func TestGenerateFuncStructVariadic(t *testing.T) { 48 | wrappedInterface := makeInterface(TestMethodDof) 49 | code := generateMockFuncStruct(wrappedInterface, wrappedInterface.wrappedMethods[0], "") 50 | expected := strip(` 51 | // TestClientDofFunc describes the behavior when the Dof method of the 52 | // parent MockTestClient instance is invoked. 
53 | type TestClientDofFunc struct { 54 | defaultHook func(string, ...string) bool 55 | hooks []func(string, ...string) bool 56 | history []TestClientDofFuncCall 57 | mutex sync.Mutex 58 | } 59 | `) 60 | assert.Equal(t, expected, fmt.Sprintf("%#v", code)) 61 | } 62 | 63 | func TestGenerateMockFuncCallStruct(t *testing.T) { 64 | wrappedInterface := makeInterface(TestMethodDo) 65 | code := generateMockFuncCallStruct(wrappedInterface, wrappedInterface.wrappedMethods[0], "") 66 | expected := strip(` 67 | // TestClientDoFuncCall is an object that describes an invocation of method 68 | // Do on an instance of MockTestClient. 69 | type TestClientDoFuncCall struct { 70 | // Arg0 is the value of the 1st argument passed to this method 71 | // invocation. 72 | Arg0 string 73 | // Result0 is the value of the 1st result returned from this method 74 | // invocation. 75 | Result0 bool 76 | } 77 | `) 78 | assert.Equal(t, expected, fmt.Sprintf("%#v", code)) 79 | } 80 | 81 | func TestGenerateMockFuncCallStructVariadic(t *testing.T) { 82 | wrappedInterface := makeInterface(TestMethodDof) 83 | code := generateMockFuncCallStruct(wrappedInterface, wrappedInterface.wrappedMethods[0], "") 84 | expected := strip(` 85 | // TestClientDofFuncCall is an object that describes an invocation of method 86 | // Dof on an instance of MockTestClient. 87 | type TestClientDofFuncCall struct { 88 | // Arg0 is the value of the 1st argument passed to this method 89 | // invocation. 90 | Arg0 string 91 | // Arg1 is a slice containing the values of the variadic arguments 92 | // passed to this method invocation. 93 | Arg1 []string 94 | // Result0 is the value of the 1st result returned from this method 95 | // invocation. 
96 | Result0 bool 97 | } 98 | `) 99 | 100 | assert.Equal(t, expected, fmt.Sprintf("%#v", code)) 101 | } 102 | -------------------------------------------------------------------------------- /internal/mockgen/generation/generate_test.go: -------------------------------------------------------------------------------- 1 | package generation 2 | 3 | import ( 4 | "fmt" 5 | "testing" 6 | 7 | "github.com/dave/jennifer/jen" 8 | "github.com/stretchr/testify/assert" 9 | ) 10 | 11 | func TestGenerateInterface(t *testing.T) { 12 | expectedDecls := []string{ 13 | // Structs 14 | "type MockTestClient struct", 15 | "type TestClientDoFunc struct", 16 | "type TestClientDoFuncCall struct", 17 | "type TestClientDofFunc struct", 18 | "type TestClientDofFuncCall struct", 19 | "func NewMockTestClient() *MockTestClient", 20 | // Overrides 21 | "func (m *MockTestClient) Do(v0 string) bool", 22 | "func (m *MockTestClient) Dof(v0 string, v1 ...string) bool", 23 | // DoFunc Methods 24 | "func (f *TestClientDoFunc) SetDefaultHook(hook func(string) bool)", 25 | "func (f *TestClientDoFunc) PushHook(hook func(string) bool)", 26 | "func (f *TestClientDoFunc) SetDefaultReturn(r0 bool)", 27 | "func (f *TestClientDoFunc) PushReturn(r0 bool)", 28 | "func (f *TestClientDoFunc) History() []TestClientDoFuncCall", 29 | // DoFuncCall methods 30 | "func (c TestClientDoFuncCall) Args() []interface{}", 31 | "func (c TestClientDoFuncCall) Results() []interface{}", 32 | // DofFunc Methods 33 | "func (f *TestClientDofFunc) SetDefaultHook(hook func(string, ...string) bool)", 34 | "func (f *TestClientDofFunc) PushHook(hook func(string, ...string) bool)", 35 | "func (f *TestClientDofFunc) SetDefaultReturn(r0 bool)", 36 | "func (f *TestClientDofFunc) PushReturn(r0 bool)", 37 | "func (f *TestClientDofFunc) History() []TestClientDofFuncCall", 38 | // DofFuncCall methods 39 | "func (c TestClientDofFuncCall) Args() []interface{}", 40 | "func (c TestClientDofFuncCall) Results() []interface{}", 41 | } 42 | 43 | file := 
jen.NewFile("test") 44 | 45 | generateInterface(file, makeBareInterface(TestMethodDo, TestMethodDof), TestPrefix, "", "") 46 | rendered := fmt.Sprintf("%#v\n", file) 47 | 48 | for _, decl := range expectedDecls { 49 | assert.Contains(t, rendered, decl) 50 | } 51 | } 52 | 53 | func TestGenerateContent(t *testing.T) { 54 | t.Run("with generated by header only", func(t *testing.T) { 55 | pkg := "testpkg" 56 | got, err := generateContent(nil, pkg, ContentOptions{}) 57 | if !assert.NoError(t, err) { 58 | return 59 | } 60 | assert.Regexp(t, `(?m)^// Code generated by .+; DO NOT EDIT\.$`, got) 61 | }) 62 | t.Run("with file prefix", func(t *testing.T) { 63 | pkg := "testpkg" 64 | wantPrefix := "Example file prefix" 65 | got, err := generateContent(nil, pkg, ContentOptions{ 66 | FilePrefix: wantPrefix, 67 | }) 68 | if !assert.NoError(t, err) { 69 | return 70 | } 71 | assert.Regexp(t, `(?m)^// Code generated by .+; DO NOT EDIT\.\n//\n// Example file prefix$`, got) 72 | }) 73 | t.Run("with build constraints", func(t *testing.T) { 74 | pkg := "testpkg" 75 | wantConstraints := "(linux && 386) || (darwin && !cgo)" 76 | got, err := generateContent(nil, pkg, ContentOptions{ 77 | BuildConstraints: wantConstraints, 78 | }) 79 | if !assert.NoError(t, err) { 80 | return 81 | } 82 | assert.Regexp(t, `(?m)^// Code generated by .+; DO NOT EDIT\.\n//go:build .+$`, got) 83 | assert.Contains(t, got, "//go:build "+wantConstraints) 84 | }) 85 | t.Run("with build constraints and file prefix", func(t *testing.T) { 86 | pkg := "testpkg" 87 | wantConstraints := "(linux && 386) || (darwin && !cgo)" 88 | wantPrefix := "Example file prefix" 89 | got, err := generateContent(nil, pkg, ContentOptions{ 90 | BuildConstraints: wantConstraints, 91 | FilePrefix: wantPrefix, 92 | }) 93 | if !assert.NoError(t, err) { 94 | return 95 | } 96 | assert.Regexp(t, 97 | `(?m)^// Code generated by .+; DO NOT EDIT\.\n//\n// Example file prefix\n//go:build .+$`, 98 | got, 99 | ) 100 | assert.Contains(t, got, "//go:build 
"+wantConstraints) 101 | }) 102 | } 103 | -------------------------------------------------------------------------------- /internal/mockgen/generation/generate_type.go: -------------------------------------------------------------------------------- 1 | package generation 2 | 3 | import ( 4 | "fmt" 5 | "go/types" 6 | 7 | "github.com/dave/jennifer/jen" 8 | ) 9 | 10 | type typeGenerator func(typ types.Type) *jen.Statement 11 | 12 | func generateType(typ types.Type, importPath, outputImportPath string, variadic bool) (out *jen.Statement) { 13 | recur := func(typ types.Type) *jen.Statement { 14 | return generateType(typ, importPath, outputImportPath, false) 15 | } 16 | 17 | switch t := typ.(type) { 18 | case *types.Array: 19 | return generateArrayType(t, recur) 20 | case *types.Basic: 21 | return generateBasicType(t, recur) 22 | case *types.Chan: 23 | return generateChanType(t, recur) 24 | case *types.Interface: 25 | return generateInterfaceType(t, recur) 26 | case *types.Map: 27 | return generateMapType(t, recur) 28 | case *types.Named: 29 | return generateNamedType(t, importPath, outputImportPath, recur) 30 | case *types.Pointer: 31 | return generatePointerType(t, recur) 32 | case *types.Signature: 33 | return generateSignatureType(t, recur) 34 | case *types.Slice: 35 | return generateSliceType(t, variadic, recur) 36 | case *types.Struct: 37 | return generateStructType(t, recur) 38 | case *types.TypeParam: 39 | return generateTypeParamType(t) 40 | case *types.Union: 41 | return generateUnionType(t, recur) 42 | 43 | default: 44 | panic(fmt.Sprintf("unsupported case: %#v\n", typ)) 45 | } 46 | } 47 | 48 | func generateArrayType(t *types.Array, generate typeGenerator) *jen.Statement { 49 | return compose(jen.Index(jen.Lit(int(t.Len()))), generate(t.Elem())) 50 | } 51 | 52 | func generateBasicType(t *types.Basic, _ typeGenerator) *jen.Statement { 53 | return jen.Id(t.String()) 54 | } 55 | 56 | func generateChanType(t *types.Chan, generate typeGenerator) *jen.Statement { 
57 | c := jen.Chan() 58 | 59 | if t.Dir() == types.RecvOnly { 60 | c = compose(jen.Op("<-"), c) 61 | } else if t.Dir() == types.SendOnly { 62 | c = compose(c, jen.Op("<-")) 63 | } 64 | 65 | return compose(c, generate(t.Elem())) 66 | } 67 | 68 | func generateInterfaceType(t *types.Interface, generate typeGenerator) *jen.Statement { 69 | embeds := make([]jen.Code, 0, t.NumEmbeddeds()) 70 | for i := 0; i < t.NumEmbeddeds(); i++ { 71 | if typ := t.EmbeddedType(i); typ != nil { 72 | embeds = append(embeds, compose(jen.Op("~"), generate(typ))) 73 | } 74 | } 75 | 76 | methods := make([]jen.Code, 0, t.NumMethods()) 77 | for i := 0; i < t.NumMethods(); i++ { 78 | params, results := generatePartialSignature(t.Method(i).Type().(*types.Signature), generate) 79 | methods = append(methods, jen.Id(t.Method(i).Name()).Params(params...).Params(results...)) 80 | } 81 | 82 | return jen.Interface(append(embeds, methods...)...) 83 | } 84 | 85 | func generateMapType(t *types.Map, generate typeGenerator) *jen.Statement { 86 | return compose(jen.Map(generate(t.Key())), generate(t.Elem())) 87 | } 88 | 89 | func generateNamedType(t *types.Named, importPath, outputImportPath string, generate typeGenerator) *jen.Statement { 90 | name := generateQualifiedName(t, importPath, outputImportPath) 91 | 92 | if typeArgs := t.TypeArgs(); typeArgs != nil { 93 | typeArguments := make([]jen.Code, 0, typeArgs.Len()) 94 | for i := 0; i < typeArgs.Len(); i++ { 95 | typeArguments = append(typeArguments, generate(typeArgs.At(i))) 96 | } 97 | 98 | name = name.Types(typeArguments...) 99 | } 100 | 101 | return name 102 | } 103 | 104 | func generatePointerType(t *types.Pointer, generate typeGenerator) *jen.Statement { 105 | return compose(jen.Op("*"), generate(t.Elem())) 106 | } 107 | 108 | func generateSignatureType(t *types.Signature, generate typeGenerator) *jen.Statement { 109 | params, results := generatePartialSignature(t, generate) 110 | return jen.Func().Params(params...).Params(results...) 
111 | } 112 | 113 | func generatePartialSignature(t *types.Signature, generate typeGenerator) (params, results []jen.Code) { 114 | params = make([]jen.Code, 0, t.Params().Len()) 115 | for i := 0; i < t.Params().Len(); i++ { 116 | params = append(params, compose(jen.Id(t.Params().At(i).Name()), generate(t.Params().At(i).Type()))) 117 | } 118 | 119 | results = make([]jen.Code, 0, t.Results().Len()) 120 | for i := 0; i < t.Results().Len(); i++ { 121 | results = append(results, generate(t.Results().At(i).Type())) 122 | } 123 | 124 | return params, results 125 | } 126 | 127 | func generateSliceType(t *types.Slice, variadic bool, generate typeGenerator) *jen.Statement { 128 | if variadic { 129 | return compose(jen.Op("..."), generate(t.Elem())) 130 | } 131 | 132 | return compose(jen.Index(), generate(t.Elem())) 133 | } 134 | 135 | func generateStructType(t *types.Struct, generate typeGenerator) *jen.Statement { 136 | fields := make([]jen.Code, 0, t.NumFields()) 137 | for i := 0; i < t.NumFields(); i++ { 138 | fields = append(fields, compose(jen.Id(t.Field(i).Name()), generate(t.Field(i).Type()))) 139 | } 140 | 141 | return jen.Struct(fields...) 142 | } 143 | 144 | func generateTypeParamType(t *types.TypeParam) *jen.Statement { 145 | return jen.Id(t.String()) 146 | } 147 | 148 | func generateUnionType(t *types.Union, generate typeGenerator) *jen.Statement { 149 | types := make([]jen.Code, 0, t.Len()) 150 | for i := 0; i < t.Len(); i++ { 151 | types = append(types, generate(t.Term(i).Type())) 152 | } 153 | 154 | return jen.Union(types...) 
155 | } 156 | -------------------------------------------------------------------------------- /internal/mockgen/generation/helpers_test.go: -------------------------------------------------------------------------------- 1 | package generation 2 | 3 | import ( 4 | gotypes "go/types" 5 | "strings" 6 | 7 | "github.com/derision-test/go-mockgen/v2/internal/mockgen/types" 8 | ) 9 | 10 | const ( 11 | TestPrefix = "Test" 12 | TestTitleName = "Client" 13 | TestMockStructName = "MockTestClient" 14 | TestImportPath = "github.com/derision-test/go-mockgen/v2/test" 15 | ) 16 | 17 | var ( 18 | boolType = getType(gotypes.Bool) 19 | stringType = getType(gotypes.String) 20 | stringSliceType = gotypes.NewSlice(getType(gotypes.String)) 21 | 22 | TestMethodStatus = &types.Method{ 23 | Name: "Status", 24 | Params: []gotypes.Type{}, 25 | Results: []gotypes.Type{stringType, boolType}, 26 | } 27 | 28 | TestMethodDo = &types.Method{ 29 | Name: "Do", 30 | Params: []gotypes.Type{stringType}, 31 | Results: []gotypes.Type{boolType}, 32 | } 33 | 34 | TestMethodDof = &types.Method{ 35 | Name: "Dof", 36 | Params: []gotypes.Type{stringType, stringSliceType}, 37 | Results: []gotypes.Type{boolType}, 38 | Variadic: true, 39 | } 40 | ) 41 | 42 | func getType(kind gotypes.BasicKind) gotypes.Type { 43 | return gotypes.Typ[kind].Underlying() 44 | } 45 | 46 | func makeBareInterface(methods ...*types.Method) *types.Interface { 47 | return &types.Interface{ 48 | Name: TestTitleName, 49 | ImportPath: TestImportPath, 50 | Methods: methods, 51 | } 52 | } 53 | 54 | func makeInterface(methods ...*types.Method) *wrappedInterface { 55 | return wrapInterface(makeBareInterface(methods...), TestPrefix, TestTitleName, TestMockStructName, "") 56 | } 57 | 58 | func makeMethod(methods ...*types.Method) (*wrappedInterface, *wrappedMethod) { 59 | wrapped := makeInterface(methods...) 
60 | return wrapped, wrapped.wrappedMethods[0] 61 | } 62 | 63 | func strip(block string) string { 64 | lines := strings.Split(block, "\n") 65 | for i, line := range lines { 66 | if strings.HasPrefix(line, "\t\t") { 67 | lines[i] = line[2:] 68 | } 69 | } 70 | 71 | return strings.TrimSpace(strings.Join(lines, "\n")) 72 | } 73 | -------------------------------------------------------------------------------- /internal/mockgen/generation/paths.go: -------------------------------------------------------------------------------- 1 | package generation 2 | 3 | import ( 4 | gotypes "go/types" 5 | "strings" 6 | 7 | "github.com/dave/jennifer/jen" 8 | ) 9 | 10 | func generateQualifiedName(t *gotypes.Named, importPath, outputImportPath string) *jen.Statement { 11 | name := t.Obj().Name() 12 | 13 | if t.Obj().Pkg() == nil { 14 | return jen.Id(name) 15 | } 16 | 17 | if path := t.Obj().Pkg().Path(); path != "" { 18 | return jen.Qual(sanitizeImportPath(path, outputImportPath), name) 19 | } 20 | 21 | return jen.Qual(sanitizeImportPath(importPath, outputImportPath), name) 22 | } 23 | 24 | func sanitizeImportPath(path, outputImportPath string) string { 25 | path = stripVendor(path) 26 | if path == outputImportPath { 27 | return "" 28 | } 29 | 30 | return path 31 | } 32 | 33 | func stripVendor(path string) string { 34 | parts := strings.Split(path, "/vendor/") 35 | return parts[len(parts)-1] 36 | } 37 | -------------------------------------------------------------------------------- /internal/mockgen/generation/util.go: -------------------------------------------------------------------------------- 1 | package generation 2 | 3 | import ( 4 | "github.com/dave/jennifer/jen" 5 | "github.com/derision-test/go-mockgen/v2/internal/mockgen/types" 6 | ) 7 | 8 | func compose(stmt *jen.Statement, tail ...jen.Code) *jen.Statement { 9 | head := *stmt 10 | for _, value := range tail { 11 | head = append(head, value) 12 | } 13 | 14 | return &head 15 | } 16 | 17 | func addComment(code 
*jen.Statement, level int, commentText string) *jen.Statement { 18 | if commentText == "" { 19 | return code 20 | } 21 | 22 | comment := generateComment(level, commentText) 23 | return compose(comment, code) 24 | } 25 | 26 | func selfAppend(sliceRef *jen.Statement, value jen.Code) jen.Code { 27 | return compose(sliceRef, jen.Op("=").Id("append").Call(sliceRef, value)) 28 | } 29 | 30 | func addTypes(code *jen.Statement, typeParams []types.TypeParam, outputImportPath string, includeTypes bool) *jen.Statement { 31 | if len(typeParams) == 0 { 32 | return code 33 | } 34 | 35 | types := make([]jen.Code, 0, len(typeParams)) 36 | for _, typeParam := range typeParams { 37 | if includeTypes { 38 | types = append(types, compose(jen.Id(typeParam.Name), generateType(typeParam.Type, "", outputImportPath, false))) 39 | } else { 40 | types = append(types, jen.Id(typeParam.Name)) 41 | } 42 | } 43 | 44 | return compose(code, jen.Types(types...)) 45 | } 46 | -------------------------------------------------------------------------------- /internal/mockgen/generation/wrapped_interface.go: -------------------------------------------------------------------------------- 1 | package generation 2 | 3 | import "github.com/derision-test/go-mockgen/v2/internal/mockgen/types" 4 | 5 | type wrappedInterface struct { 6 | *types.Interface 7 | prefix string 8 | titleName string 9 | mockStructName string 10 | wrappedMethods []*wrappedMethod 11 | } 12 | 13 | func wrapInterface(iface *types.Interface, prefix, titleName, mockStructName, outputImportPath string) *wrappedInterface { 14 | wrapped := &wrappedInterface{ 15 | Interface: iface, 16 | prefix: prefix, 17 | titleName: titleName, 18 | mockStructName: mockStructName, 19 | } 20 | 21 | for _, method := range iface.Methods { 22 | wrapped.wrappedMethods = append(wrapped.wrappedMethods, wrapMethod(iface, method, outputImportPath)) 23 | } 24 | 25 | return wrapped 26 | } 27 | 
-------------------------------------------------------------------------------- /internal/mockgen/generation/wrapped_method.go: -------------------------------------------------------------------------------- 1 | package generation 2 | 3 | import ( 4 | "github.com/dave/jennifer/jen" 5 | "github.com/derision-test/go-mockgen/v2/internal/mockgen/types" 6 | ) 7 | 8 | type wrappedMethod struct { 9 | *types.Method 10 | iface *types.Interface 11 | dotlessParamTypes []jen.Code 12 | paramTypes []jen.Code 13 | resultTypes []jen.Code 14 | signature jen.Code 15 | } 16 | 17 | func wrapMethod(iface *types.Interface, method *types.Method, outputImportPath string) *wrappedMethod { 18 | m := &wrappedMethod{ 19 | Method: method, 20 | iface: iface, 21 | dotlessParamTypes: generateParamTypes(method, iface.ImportPath, outputImportPath, true), 22 | paramTypes: generateParamTypes(method, iface.ImportPath, outputImportPath, false), 23 | resultTypes: generateResultTypes(method, iface.ImportPath, outputImportPath), 24 | } 25 | 26 | m.signature = jen.Func().Params(m.paramTypes...).Params(m.resultTypes...) 
27 | return m 28 | } 29 | 30 | func generateParamTypes(method *types.Method, importPath, outputImportPath string, omitDots bool) []jen.Code { 31 | params := make([]jen.Code, 0, len(method.Params)) 32 | for i, typ := range method.Params { 33 | params = append(params, generateType( 34 | typ, 35 | importPath, 36 | outputImportPath, 37 | method.Variadic && i == len(method.Params)-1 && !omitDots, 38 | )) 39 | } 40 | 41 | return params 42 | } 43 | 44 | func generateResultTypes(method *types.Method, importPath, outputImportPath string) []jen.Code { 45 | results := make([]jen.Code, 0, len(method.Results)) 46 | for _, typ := range method.Results { 47 | results = append(results, generateType( 48 | typ, 49 | importPath, 50 | outputImportPath, 51 | false, 52 | )) 53 | } 54 | 55 | return results 56 | } 57 | -------------------------------------------------------------------------------- /internal/mockgen/paths/exist.go: -------------------------------------------------------------------------------- 1 | package paths 2 | 3 | import "os" 4 | 5 | func DirExists(path string) bool { 6 | if info, err := os.Stat(path); err == nil { 7 | return info.IsDir() 8 | } 9 | 10 | return false 11 | } 12 | 13 | func Exists(path string) (bool, error) { 14 | if _, err := os.Stat(path); err != nil { 15 | if os.IsNotExist(err) { 16 | err = nil 17 | } 18 | 19 | return false, err 20 | } 21 | 22 | return true, nil 23 | } 24 | 25 | func AnyExists(paths []string) (string, error) { 26 | for _, path := range paths { 27 | exists, err := Exists(path) 28 | if err != nil { 29 | return "", err 30 | } 31 | 32 | if exists { 33 | return path, nil 34 | } 35 | } 36 | 37 | return "", nil 38 | } 39 | 40 | func EnsureDirExists(dirname string) error { 41 | exists, err := Exists(dirname) 42 | if err != nil { 43 | return err 44 | } 45 | 46 | if exists { 47 | return nil 48 | } 49 | 50 | return os.MkdirAll(dirname, os.ModeDir|os.ModePerm) 51 | } 52 | 
-------------------------------------------------------------------------------- /internal/mockgen/paths/project.go: -------------------------------------------------------------------------------- 1 | package paths 2 | 3 | import ( 4 | "go/build" 5 | "io/ioutil" 6 | "os" 7 | "path/filepath" 8 | "regexp" 9 | "strings" 10 | ) 11 | 12 | var srcpath = filepath.Join(gopath(), "src") 13 | var modulePattern = regexp.MustCompile(`^module\s+(.+)$`) 14 | 15 | func InferImportPath(dirname string) (string, bool) { 16 | if module, wd, ok := module(dirname); ok { 17 | return filepath.Join(module, dirname[len(wd):]), true 18 | } 19 | 20 | if strings.HasPrefix(dirname, srcpath) { 21 | return dirname[len(srcpath):], true 22 | } 23 | 24 | return "", false 25 | } 26 | 27 | func ResolveImportPath(wd, importPath string) (string, string) { 28 | // See if we're in a module and generating for our own package 29 | if module, baseDir, ok := module(wd); ok && strings.HasPrefix(importPath, module) { 30 | return importPath, filepath.Join(baseDir, importPath[len(module):]) 31 | } 32 | 33 | // See if it's a relative path to working directory 34 | if dir := filepath.Join(wd, importPath); DirExists(dir) { 35 | if path, ok := InferImportPath(dir); ok { 36 | return path, dir 37 | } 38 | } 39 | 40 | if strings.HasPrefix(wd, srcpath) { 41 | for wd != srcpath { 42 | // See if it's vendored on any path up to the GOPATH root 43 | if dir := filepath.Join(wd, "vendor", importPath); DirExists(dir) { 44 | return importPath, dir 45 | } 46 | 47 | wd = filepath.Dir(wd) 48 | } 49 | } 50 | 51 | // See if it's in the GOPATH 52 | if dir := filepath.Join(srcpath, importPath); DirExists(dir) { 53 | return importPath, dir 54 | } 55 | 56 | // It's installed as a module 57 | return importPath, importPath 58 | } 59 | 60 | func module(dirname string) (string, string, bool) { 61 | wd := dirname 62 | for wd != srcpath && wd != string(os.PathSeparator) { 63 | if module, ok := gomod(wd); ok { 64 | return module, wd, true 65 
| } 66 | 67 | wd = filepath.Dir(wd) 68 | } 69 | 70 | return "", "", false 71 | } 72 | 73 | func gomod(dirname string) (string, bool) { 74 | content, err := ioutil.ReadFile(filepath.Join(dirname, "go.mod")) 75 | if err != nil { 76 | return "", false 77 | } 78 | 79 | for _, line := range strings.Split(string(content), "\n") { 80 | if matches := modulePattern.FindStringSubmatch(line); len(matches) > 0 { 81 | return matches[1], true 82 | } 83 | } 84 | 85 | return "", false 86 | } 87 | 88 | func gopath() string { 89 | if gopath := os.Getenv("GOPATH"); gopath != "" { 90 | return gopath 91 | } 92 | 93 | return build.Default.GOPATH 94 | } 95 | -------------------------------------------------------------------------------- /internal/mockgen/paths/relative.go: -------------------------------------------------------------------------------- 1 | package paths 2 | 3 | import ( 4 | "fmt" 5 | "os" 6 | "path/filepath" 7 | "strings" 8 | ) 9 | 10 | func GetRelativePath(path string) string { 11 | wd, err := os.Getwd() 12 | if err != nil { 13 | return path 14 | } 15 | 16 | wd, err = filepath.EvalSymlinks(wd) 17 | if err != nil { 18 | return path 19 | } 20 | 21 | if strings.HasPrefix(path, wd) { 22 | return fmt.Sprintf(".%s", path[len(wd):]) 23 | } 24 | 25 | return path 26 | } 27 | -------------------------------------------------------------------------------- /internal/mockgen/types/extract.go: -------------------------------------------------------------------------------- 1 | package types 2 | 3 | import ( 4 | "fmt" 5 | "go/ast" 6 | "log" 7 | "os" 8 | "sort" 9 | "strings" 10 | "unicode" 11 | 12 | "github.com/derision-test/go-mockgen/v2/internal/mockgen/paths" 13 | "golang.org/x/tools/go/packages" 14 | ) 15 | 16 | type PackageOptions struct { 17 | ImportPaths []string 18 | Interfaces []string 19 | Exclude []string 20 | Prefix string 21 | } 22 | 23 | func Extract(pkgs []*packages.Package, packageOptions []PackageOptions) (ifaces []*Interface, _ error) { 24 | workingDirectory, err := 
os.Getwd() 25 | if err != nil { 26 | return nil, fmt.Errorf("failed to get working directory (%s)", err.Error()) 27 | } 28 | 29 | for _, packageOpts := range packageOptions { 30 | packageTypes, err := gatherAllPackageTypes(pkgs, workingDirectory, packageOpts.ImportPaths) 31 | if err != nil { 32 | return nil, err 33 | } 34 | 35 | for _, name := range gatherAllPackageTypeNames(packageTypes) { 36 | iface, err := extractInterface(packageTypes, name, packageOpts.Interfaces, packageOpts.Exclude) 37 | if err != nil { 38 | return nil, err 39 | } 40 | 41 | if iface != nil { 42 | iface.Prefix = packageOpts.Prefix 43 | ifaces = append(ifaces, iface) 44 | } 45 | } 46 | } 47 | 48 | return ifaces, nil 49 | } 50 | 51 | func gatherAllPackageTypes(pkgs []*packages.Package, workingDirectory string, importPaths []string) (map[string]map[string]*Interface, error) { 52 | packageTypes := make(map[string]map[string]*Interface, len(importPaths)) 53 | for _, importPath := range importPaths { 54 | path, dir := paths.ResolveImportPath(workingDirectory, importPath) 55 | log.Printf("parsing package '%s'\n", paths.GetRelativePath(dir)) 56 | 57 | types, err := gatherTypesForPackage(pkgs, importPath, path) 58 | if err != nil { 59 | return nil, err 60 | } 61 | 62 | packageTypes[path] = types 63 | } 64 | 65 | return packageTypes, nil 66 | } 67 | 68 | func gatherTypesForPackage(pkgs []*packages.Package, importPath, path string) (map[string]*Interface, error) { 69 | for _, pkg := range pkgs { 70 | if pkg.PkgPath != path { 71 | continue 72 | } 73 | 74 | for _, err := range pkg.Errors { 75 | switch err.Kind { 76 | case packages.TypeError: 77 | log.Printf("package %s failed to type check, but attempting to continue: %s", importPath, err.Msg) 78 | default: 79 | return nil, fmt.Errorf("malformed package %s (%s)", importPath, err.Msg) 80 | } 81 | } 82 | 83 | visitor := newVisitor(path, pkg.Types) 84 | for _, file := range pkg.Syntax { 85 | ast.Walk(visitor, file) 86 | } 87 | 88 | return visitor.types, nil 
89 | } 90 | 91 | return nil, fmt.Errorf("malformed package %s (not found)", importPath) 92 | } 93 | 94 | func gatherAllPackageTypeNames(packageTypes map[string]map[string]*Interface) []string { 95 | nameMap := map[string]struct{}{} 96 | for _, pkg := range packageTypes { 97 | for name := range pkg { 98 | nameMap[name] = struct{}{} 99 | } 100 | } 101 | 102 | names := make([]string, 0, len(nameMap)) 103 | for name := range nameMap { 104 | names = append(names, name) 105 | } 106 | sort.Strings(names) 107 | 108 | return names 109 | } 110 | 111 | func extractInterface(packageTypes map[string]map[string]*Interface, name string, targetNames, excludeNames []string) (*Interface, error) { 112 | if !shouldInclude(name, targetNames, excludeNames) { 113 | return nil, nil 114 | } 115 | 116 | candidates := make([]*Interface, 0, 1) 117 | for _, pkg := range packageTypes { 118 | if t, ok := pkg[name]; ok { 119 | candidates = append(candidates, t) 120 | 121 | if len(candidates) > 1 { 122 | return nil, fmt.Errorf("type '%s' is multiply-defined in supplied import paths", name) 123 | } 124 | } 125 | } 126 | if len(candidates) == 0 { 127 | return nil, nil 128 | } 129 | 130 | iface := candidates[0] 131 | 132 | for _, method := range iface.Methods { 133 | if !unicode.IsUpper([]rune(method.Name)[0]) { 134 | return nil, fmt.Errorf( 135 | "type '%s' has unexported an method '%s'", 136 | name, 137 | method.Name, 138 | ) 139 | } 140 | } 141 | 142 | return iface, nil 143 | } 144 | 145 | func shouldInclude(name string, targetNames, excludeNames []string) bool { 146 | for _, v := range excludeNames { 147 | if strings.ToLower(v) == strings.ToLower(name) { 148 | return false 149 | } 150 | } 151 | 152 | for _, v := range targetNames { 153 | if strings.ToLower(v) == strings.ToLower(name) { 154 | return true 155 | } 156 | } 157 | 158 | return len(targetNames) == 0 159 | } 160 | -------------------------------------------------------------------------------- /internal/mockgen/types/interface.go: 
-------------------------------------------------------------------------------- 1 | package types 2 | 3 | import ( 4 | "go/ast" 5 | "go/types" 6 | "sort" 7 | ) 8 | 9 | type Interface struct { 10 | Name string 11 | ImportPath string 12 | TypeParams []TypeParam 13 | Methods []*Method 14 | 15 | // Prefix is set on extraction based on the current PackageOptions 16 | Prefix string 17 | } 18 | 19 | type TypeParam struct { 20 | Name string 21 | Type types.Type 22 | } 23 | 24 | func newInterfaceFromTypeSpec(name, importPath string, typeSpec *ast.TypeSpec, underlyingType *types.Interface, ps *types.TypeParamList) *Interface { 25 | methodMap := make(map[string]*Method, underlyingType.NumMethods()) 26 | for i := 0; i < underlyingType.NumMethods(); i++ { 27 | method := underlyingType.Method(i) 28 | name := method.Name() 29 | methodMap[name] = newMethodFromSignature(name, method.Type().(*types.Signature)) 30 | } 31 | 32 | methodNames := make([]string, 0, len(methodMap)) 33 | for k := range methodMap { 34 | methodNames = append(methodNames, k) 35 | } 36 | sort.Strings(methodNames) 37 | 38 | methods := make([]*Method, 0, len(methodNames)) 39 | for _, name := range methodNames { 40 | methods = append(methods, methodMap[name]) 41 | } 42 | 43 | var typeParams []TypeParam 44 | if typeSpec.TypeParams != nil && ps != nil { 45 | for i, field := range typeSpec.TypeParams.List { 46 | for _, name := range field.Names { 47 | typeParams = append(typeParams, TypeParam{Name: name.Name, Type: ps.At(i).Constraint()}) 48 | } 49 | } 50 | } 51 | 52 | return &Interface{ 53 | Name: name, 54 | ImportPath: importPath, 55 | TypeParams: typeParams, 56 | Methods: methods, 57 | } 58 | } 59 | -------------------------------------------------------------------------------- /internal/mockgen/types/method.go: -------------------------------------------------------------------------------- 1 | package types 2 | 3 | import "go/types" 4 | 5 | type Method struct { 6 | Name string 7 | Params []types.Type 8 | 
Results []types.Type 9 | Variadic bool 10 | } 11 | 12 | func newMethodFromSignature(name string, signature *types.Signature) *Method { 13 | ps := signature.Params() 14 | pn := ps.Len() 15 | params := make([]types.Type, 0, pn) 16 | for i := 0; i < pn; i++ { 17 | params = append(params, ps.At(i).Type()) 18 | } 19 | 20 | rs := signature.Results() 21 | rn := rs.Len() 22 | results := make([]types.Type, 0, rn) 23 | for i := 0; i < rn; i++ { 24 | results = append(results, rs.At(i).Type()) 25 | } 26 | 27 | return &Method{ 28 | Name: name, 29 | Params: params, 30 | Results: results, 31 | Variadic: signature.Variadic(), 32 | } 33 | } 34 | -------------------------------------------------------------------------------- /internal/mockgen/types/visitor.go: -------------------------------------------------------------------------------- 1 | package types 2 | 3 | import ( 4 | "fmt" 5 | "go/ast" 6 | "go/types" 7 | ) 8 | 9 | type visitor struct { 10 | importPath string 11 | pkgType *types.Package 12 | types map[string]*Interface 13 | } 14 | 15 | func newVisitor(importPath string, pkgType *types.Package) *visitor { 16 | return &visitor{ 17 | importPath: importPath, 18 | pkgType: pkgType, 19 | types: map[string]*Interface{}, 20 | } 21 | } 22 | 23 | func (v *visitor) Visit(node ast.Node) ast.Visitor { 24 | switch n := node.(type) { 25 | case *ast.File: 26 | return v 27 | 28 | case *ast.GenDecl: 29 | for _, spec := range n.Specs { 30 | if typeSpec, ok := spec.(*ast.TypeSpec); ok { 31 | name := typeSpec.Name.Name 32 | _, obj := v.pkgType.Scope().Innermost(typeSpec.Pos()).LookupParent(name, 0) 33 | 34 | switch t := obj.Type().Underlying().(type) { 35 | case *types.Interface: 36 | namedType, ok := obj.Type().(*types.Named) 37 | if !ok { 38 | panic(fmt.Sprintf("Unexpected type %T: expected *types.Named", obj.Type())) 39 | } 40 | 41 | if !t.IsMethodSet() { 42 | // Contains type constraints - we generate illegal code in this circumstance. 
43 | // I'm not sure it makes sense to support this case, but we can revisit if we 44 | // get a feature request in the future or run into a case in the wild. 45 | continue 46 | } 47 | 48 | v.types[name] = newInterfaceFromTypeSpec(name, v.importPath, typeSpec, t, namedType.TypeParams()) 49 | } 50 | } 51 | } 52 | } 53 | 54 | return nil 55 | } 56 | -------------------------------------------------------------------------------- /internal/testutil/helpers_test.go: -------------------------------------------------------------------------------- 1 | package testutil 2 | 3 | type mockFunc struct { 4 | history []mockCall 5 | } 6 | 7 | type mockCall struct { 8 | args []interface{} 9 | results []interface{} 10 | } 11 | 12 | func newHistory(calls ...mockCall) *mockFunc { 13 | return &mockFunc{ 14 | history: calls, 15 | } 16 | } 17 | 18 | func (m mockFunc) History() []mockCall { return m.history } 19 | func (m mockCall) Args() []interface{} { return m.args } 20 | func (m mockCall) Results() []interface{} { return m.results } 21 | -------------------------------------------------------------------------------- /internal/testutil/reflect_helpers.go: -------------------------------------------------------------------------------- 1 | package testutil 2 | 3 | import "reflect" 4 | 5 | // CallInstance holds the arguments and results of a single mock function call. 6 | type CallInstance interface { 7 | Args() []interface{} 8 | Results() []interface{} 9 | } 10 | 11 | // GetCallHistory extracts the history from the given mock function and returns the 12 | // set of call instances. 
13 | func GetCallHistory(v interface{}) ([]CallInstance, bool) { 14 | value := reflect.ValueOf(v) 15 | if !value.IsValid() { 16 | return nil, false 17 | } 18 | 19 | // Get reflect value of method 20 | method := value.MethodByName("History") 21 | if !method.IsValid() { 22 | return nil, false 23 | } 24 | 25 | // Check method arity 26 | if method.Type().NumIn() != 0 || method.Type().NumOut() != 1 { 27 | return nil, false 28 | } 29 | 30 | // Invoke the function with no arguments and get the reflect.Value result 31 | history := method.Call(nil)[0] 32 | 33 | // Ensure the returned type is []interface{} 34 | if history.Kind() != reflect.Slice || !history.Type().Elem().Implements(reflect.TypeOf((*CallInstance)(nil)).Elem()) { 35 | return nil, false 36 | } 37 | 38 | calls := make([]CallInstance, 0, history.Len()) 39 | for i := 0; i < history.Len(); i++ { 40 | calls = append(calls, history.Index(i).Interface().(CallInstance)) 41 | } 42 | 43 | return calls, true 44 | } 45 | 46 | // GetCallHistoryWith extracts the history from the given mock function and returns the 47 | // set of call instances that match the given function. If the given parameter is not of 48 | // the required type, a false-valued flag is returned. 49 | func GetCallHistoryWith(v interface{}, matcher func(v CallInstance) bool) (matching []CallInstance, _ bool) { 50 | history, ok := GetCallHistory(v) 51 | if !ok { 52 | return nil, false 53 | } 54 | 55 | for _, call := range history { 56 | if matcher(call) { 57 | matching = append(matching, call) 58 | } 59 | } 60 | 61 | return matching, true 62 | } 63 | 64 | // GetArgs returns the arguments from teh given mock function invocation. If the given 65 | // parameter is not of the required type, a false-valued flag is returned. 
66 | func GetArgs(v interface{}) ([]interface{}, bool) { 67 | value := reflect.ValueOf(v) 68 | if !value.IsValid() { 69 | return nil, false 70 | } 71 | 72 | // Get reflect value of method 73 | method := value.MethodByName("Args") 74 | if !method.IsValid() { 75 | return nil, false 76 | } 77 | 78 | // Check method arity 79 | if method.Type().NumIn() != 0 || method.Type().NumOut() != 1 { 80 | return nil, false 81 | } 82 | 83 | // Invoke the function with no arguments and get the reflect.Value result 84 | args := method.Call(nil)[0] 85 | 86 | // Ensure the returned type is []interface{} 87 | if args.Kind() != reflect.Slice || args.Type().Elem().Kind() != reflect.Interface { 88 | return nil, false 89 | } 90 | 91 | // Return result unchanged 92 | return args.Interface().([]interface{}), true 93 | } 94 | -------------------------------------------------------------------------------- /internal/testutil/reflect_helpers_test.go: -------------------------------------------------------------------------------- 1 | package testutil 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/stretchr/testify/assert" 7 | ) 8 | 9 | func TestGetCallHistory(t *testing.T) { 10 | value := newHistory( 11 | mockCall{args: []interface{}{"foo", "bar"}}, 12 | mockCall{args: []interface{}{"foo", "bar", "baz"}}, 13 | mockCall{args: []interface{}{"foo", "bar", "baz", "bonk"}}, 14 | mockCall{args: []interface{}{"foo", "bar", "bonk"}}, 15 | mockCall{args: []interface{}{"foo", "bar", "baz"}}, 16 | ) 17 | 18 | history, ok := GetCallHistory(value) 19 | assert.True(t, ok) 20 | assert.Len(t, history, 5) 21 | } 22 | 23 | func TestGetCallHistoryWith(t *testing.T) { 24 | value := newHistory( 25 | mockCall{args: []interface{}{"foo", "bar"}}, 26 | mockCall{args: []interface{}{"foo", "bar", "baz"}}, 27 | mockCall{args: []interface{}{"foo", "bar", "baz", "bonk"}}, 28 | mockCall{args: []interface{}{"foo", "bar", "bonk"}}, 29 | mockCall{args: []interface{}{"foo", "bar", "baz"}}, 30 | ) 31 | 32 | matchingHistory, ok 
:= GetCallHistoryWith(value, func(v CallInstance) bool { return len(v.Args()) == 3 }) 33 | assert.True(t, ok) 34 | assert.Len(t, matchingHistory, 3) 35 | } 36 | 37 | func TestGetCallHistoryNil(t *testing.T) { 38 | _, ok := GetCallHistory(nil) 39 | assert.False(t, ok) 40 | } 41 | 42 | func TestGetCallHistoryNoHistoryMethod(t *testing.T) { 43 | _, ok := GetCallHistory(struct{}{}) 44 | assert.False(t, ok) 45 | } 46 | 47 | func TestGetCallHistoryBadParamArity(t *testing.T) { 48 | _, ok := GetCallHistory(&historyFuncBadParamArity{}) 49 | assert.False(t, ok) 50 | } 51 | 52 | type historyFuncBadParamArity struct{} 53 | 54 | func (h *historyFuncBadParamArity) History(n int) []CallInstance { 55 | return nil 56 | } 57 | 58 | func TestGetCallHistoryBadResultArity(t *testing.T) { 59 | _, ok := GetCallHistory(&historyFuncBadResultArity{}) 60 | assert.False(t, ok) 61 | } 62 | 63 | type historyFuncBadResultArity struct{} 64 | 65 | func (h *historyFuncBadResultArity) History() ([]CallInstance, error) { 66 | return nil, nil 67 | } 68 | 69 | func TestGetCallHistoryNonSliceResult(t *testing.T) { 70 | _, ok := GetCallHistory(&historyFuncNonSliceResult{}) 71 | assert.False(t, ok) 72 | } 73 | 74 | type historyFuncNonSliceResult struct{} 75 | 76 | func (h *historyFuncNonSliceResult) History() string { 77 | return "" 78 | } 79 | 80 | func TestGetCallHistoryBadSliceTypes(t *testing.T) { 81 | _, ok := GetCallHistory(&historyFuncBadSliceTypes{}) 82 | assert.False(t, ok) 83 | } 84 | 85 | type historyFuncBadSliceTypes struct{} 86 | 87 | func (h *historyFuncBadSliceTypes) History() []string { 88 | return nil 89 | } 90 | 91 | func TestGetArgs(t *testing.T) { 92 | expectedArgs := []interface{}{"foo", "bar", "baz", "bonk"} 93 | args, ok := GetArgs(mockCall{args: expectedArgs}) 94 | assert.True(t, ok) 95 | assert.Equal(t, expectedArgs, args) 96 | } 97 | 98 | func TestGetArgsNil(t *testing.T) { 99 | _, ok := GetArgs(nil) 100 | assert.False(t, ok) 101 | } 102 | 103 | func TestGetArgsNoArgsMethod(t 
*testing.T) { 104 | _, ok := GetArgs(struct{}{}) 105 | assert.False(t, ok) 106 | } 107 | 108 | func TestGetArgsBadParamArity(t *testing.T) { 109 | _, ok := GetArgs(&argsFuncBadParamArity{}) 110 | assert.False(t, ok) 111 | } 112 | 113 | type argsFuncBadParamArity struct{} 114 | 115 | func (m *argsFuncBadParamArity) Args(n int) []interface{} { 116 | return nil 117 | } 118 | 119 | func TestGetArgsBadResultArity(t *testing.T) { 120 | _, ok := GetArgs(&argsFuncBadResultArity{}) 121 | assert.False(t, ok) 122 | } 123 | 124 | type argsFuncBadResultArity struct{} 125 | 126 | func (m *argsFuncBadResultArity) Args() ([]interface{}, error) { 127 | return nil, nil 128 | } 129 | 130 | func TestGetArgsNonSliceResult(t *testing.T) { 131 | _, ok := GetArgs(&argsFuncNonSliceResult{}) 132 | assert.False(t, ok) 133 | } 134 | 135 | type argsFuncNonSliceResult struct{} 136 | 137 | func (m *argsFuncNonSliceResult) Args() string { 138 | return "" 139 | } 140 | 141 | func TestGetArgsBadSliceTypes(t *testing.T) { 142 | _, ok := GetArgs(&argsFuncBadSliceTypes{}) 143 | assert.False(t, ok) 144 | } 145 | 146 | type argsFuncBadSliceTypes struct{} 147 | 148 | func (m *argsFuncBadSliceTypes) Args() []string { 149 | return nil 150 | } 151 | -------------------------------------------------------------------------------- /testutil/assert/asserter.go: -------------------------------------------------------------------------------- 1 | package mockassert 2 | 3 | import ( 4 | "reflect" 5 | 6 | "github.com/derision-test/go-mockgen/v2/internal/testutil" 7 | "github.com/stretchr/testify/assert" 8 | ) 9 | 10 | // CallInstanceAsserter determines whether or not a set of argument values from a call 11 | // of a mock function match the test constraints of a particular function call. See the 12 | // assertions `CalledWith`, `NotCalledWith`, `CalledOnceWith`, `CalledNWith`, and 13 | // `CalledAtNWith` for further usage. 
type CallInstanceAsserter interface {
	// Assert determines if the given argument values match the expected
	// function call.
	Assert(interface{}) bool
}

// CallInstanceAsserterFunc adapts an ordinary function to the
// CallInstanceAsserter interface.
type CallInstanceAsserterFunc func(v interface{}) bool

// Compile-time check that CallInstanceAsserterFunc satisfies CallInstanceAsserter.
var _ CallInstanceAsserter = CallInstanceAsserterFunc(nil)

// Assert calls f with v and returns its result.
func (f CallInstanceAsserterFunc) Assert(v interface{}) bool {
	return f(v)
}

// valueAsserter matches call arguments positionally against expectedValues.
// It is constructed by Values.
type valueAsserter struct {
	expectedValues []interface{}
}

// skip is the type of the Skip sentinel.
type skip struct{}

// Skip is a sentinel value which is skipped in a call instance asserter. This is useful
// when used to skip the leading "don't care" values such as leading context parameters.
var Skip = &skip{}

// Values returns a new call instance asserter that will match the arguments of each
// function call positionally with each of the expected values. The assertion behavior
// in each position can be tuned:
//
// Use the value `mockassert.Skip` to skip validation for values in that parameter
// position.
//
// Use a function with the type `func(v T) bool` (for any `T`) to override validation for
// values in that parameter position.
45 | func Values(expectedValues ...interface{}) CallInstanceAsserter { 46 | return &valueAsserter{ 47 | expectedValues: expectedValues, 48 | } 49 | } 50 | 51 | func (a *valueAsserter) Assert(v interface{}) bool { 52 | args, ok := testutil.GetArgs(v) 53 | if !ok { 54 | return false 55 | } 56 | 57 | if len(a.expectedValues) > len(args) { 58 | return false 59 | } 60 | 61 | for i, expectedValue := range a.expectedValues { 62 | if expectedValue == Skip { 63 | continue 64 | } 65 | 66 | // First check to see if it's a hook function we should invoke 67 | if ret, ok := callTesterFunc(expectedValue, args[i]); ok { 68 | if ret { 69 | continue 70 | } 71 | 72 | return false 73 | } 74 | 75 | // Fall back to value equality checks 76 | if assert.ObjectsAreEqual(expectedValue, args[i]) { 77 | continue 78 | } 79 | 80 | return false 81 | } 82 | 83 | return true 84 | } 85 | 86 | // callTesterFunc attempts to invoke the given value `v` of type func(T) bool 87 | // with the given argument `arg` of type T. 88 | // 89 | // If the runtime types match these assumptions, then teh function is invoked 90 | // and the result is returned along with a true-valued flag. If the runtime 91 | // values break these assumptions, a false-valued flag is returned. 92 | func callTesterFunc(v interface{}, arg interface{}) (result bool, ok bool) { 93 | value := reflect.ValueOf(v) 94 | if !value.IsValid() { 95 | return false, false 96 | } 97 | 98 | // Ensure value is a function (Type will panic otherwise) 99 | if value.Kind() != reflect.Func { 100 | return false, false 101 | } 102 | 103 | // Check function arity 104 | if value.Type().NumIn() != 1 || value.Type().NumOut() != 1 { 105 | return false, false 106 | } 107 | 108 | argValue := reflect.ValueOf(arg) 109 | 110 | // Ensure argument and parameter types match. 
(Call will panic otherwise) 111 | if value.Type().In(0).Kind() != argValue.Kind() { 112 | return false, false 113 | } 114 | 115 | // Invoke the function with a single argument and get the reflect.Value result 116 | resultValue := value.Call([]reflect.Value{argValue})[0] 117 | 118 | // Ensure the returned type is bool 119 | if resultValue.Kind() != reflect.Bool { 120 | return false, false 121 | } 122 | 123 | return resultValue.Interface().(bool), true 124 | } 125 | -------------------------------------------------------------------------------- /testutil/assert/asserter_test.go: -------------------------------------------------------------------------------- 1 | package mockassert 2 | 3 | import ( 4 | "context" 5 | "testing" 6 | 7 | "github.com/stretchr/testify/assert" 8 | ) 9 | 10 | func TestValues(t *testing.T) { 11 | asserter := Values( 12 | Skip, 13 | 123, 14 | func(a string) bool { return len(a) == 2 }, 15 | ) 16 | 17 | assert.True(t, asserter.Assert(newMockArgs(context.Background(), 123, "xx"))) 18 | assert.True(t, asserter.Assert(newMockArgs(context.Background(), 123, "yy"))) 19 | assert.True(t, asserter.Assert(newMockArgs(context.Background(), 123, "zz"))) 20 | assert.False(t, asserter.Assert(newMockArgs(context.Background(), 789, "w"))) 21 | assert.False(t, asserter.Assert(newMockArgs(context.Background(), 123, 123))) 22 | assert.False(t, asserter.Assert(newMockArgs(context.Background(), 123, nil))) 23 | } 24 | 25 | func TestCallTesterFunc(t *testing.T) { 26 | v1, ok := callTesterFunc(func(v int) bool { return v%2 == 0 }, 4) 27 | assert.True(t, ok) 28 | assert.True(t, v1) 29 | 30 | v2, ok := callTesterFunc(func(v int) bool { return v%2 == 0 }, 3) 31 | assert.True(t, ok) 32 | assert.False(t, v2) 33 | } 34 | 35 | func TestCallTesterFuncNil(t *testing.T) { 36 | _, ok := callTesterFunc(nil, nil) 37 | assert.False(t, ok) 38 | } 39 | 40 | func TestCallTesterFuncNonFunc(t *testing.T) { 41 | _, ok := callTesterFunc(123, nil) 42 | assert.False(t, ok) 43 | } 44 | 45 
| func TestCallTesterFuncBadParamArity(t *testing.T) { 46 | _, ok := callTesterFunc(func(a, b string) bool { return false }, nil) 47 | assert.False(t, ok) 48 | } 49 | 50 | func TestCallTesterFuncBadResultArity(t *testing.T) { 51 | _, ok := callTesterFunc(func(a string) (bool, error) { return false, nil }, nil) 52 | assert.False(t, ok) 53 | } 54 | 55 | func TestCallTesterFuncBadResultType(t *testing.T) { 56 | _, ok := callTesterFunc(func(a string) string { return a }, nil) 57 | assert.False(t, ok) 58 | } 59 | 60 | func TestCallTesterFuncMismatchedTypes(t *testing.T) { 61 | _, ok := callTesterFunc(func(a string) bool { return true }, 123) 62 | assert.False(t, ok) 63 | } 64 | 65 | type testArgs interface { 66 | Args() []interface{} 67 | } 68 | 69 | type mockArgs struct { 70 | args []interface{} 71 | } 72 | 73 | func newMockArgs(args ...interface{}) testArgs { 74 | return &mockArgs{ 75 | args: args, 76 | } 77 | } 78 | 79 | func (m *mockArgs) Args() []interface{} { 80 | return m.args 81 | } 82 | -------------------------------------------------------------------------------- /testutil/assert/assertions.go: -------------------------------------------------------------------------------- 1 | package mockassert 2 | 3 | import ( 4 | "fmt" 5 | 6 | "github.com/derision-test/go-mockgen/v2/internal/testutil" 7 | "github.com/stretchr/testify/assert" 8 | ) 9 | 10 | // Called asserts that the mock function object was called at least once. 11 | func Called(t assert.TestingT, mockFn interface{}, msgAndArgs ...interface{}) bool { 12 | callCount, ok := callCount(t, mockFn, msgAndArgs...) 13 | if !ok { 14 | return false 15 | } 16 | if callCount == 0 { 17 | return assert.Fail(t, fmt.Sprintf("Expected %T to be called at least once", mockFn), msgAndArgs...) 18 | } 19 | 20 | return true 21 | } 22 | 23 | // NotCalled asserts that the mock function object was not called. 
24 | func NotCalled(t assert.TestingT, mockFn interface{}, msgAndArgs ...interface{}) bool { 25 | callCount, ok := callCount(t, mockFn, msgAndArgs...) 26 | if !ok { 27 | return false 28 | } 29 | if callCount != 0 { 30 | return assert.Fail(t, fmt.Sprintf("Did not expect %T to be called", mockFn), msgAndArgs...) 31 | } 32 | 33 | return true 34 | } 35 | 36 | // CalledOnce asserts that the mock function object was called exactly once. 37 | func CalledOnce(t assert.TestingT, mockFn interface{}, msgAndArgs ...interface{}) bool { 38 | return CalledN(t, mockFn, 1, msgAndArgs...) 39 | } 40 | 41 | // CalledN asserts that the mock function object was called exactly n times. 42 | func CalledN(t assert.TestingT, mockFn interface{}, n int, msgAndArgs ...interface{}) bool { 43 | callCount, ok := callCount(t, mockFn, msgAndArgs...) 44 | if !ok { 45 | return false 46 | } 47 | if callCount != n { 48 | return assert.Fail(t, fmt.Sprintf("Expected %T to be called exactly %d times, called %d times", mockFn, n, callCount), msgAndArgs...) 49 | } 50 | 51 | return true 52 | } 53 | 54 | // CalledWith asserts that the mock function object was called at least once with a set of 55 | // arguments matching the given call instance asserter. 56 | func CalledWith(t assert.TestingT, mockFn interface{}, asserter CallInstanceAsserter, msgAndArgs ...interface{}) bool { 57 | matchingCallCount, ok := callCountWith(t, mockFn, asserter, msgAndArgs...) 58 | if !ok { 59 | return false 60 | } 61 | if matchingCallCount == 0 { 62 | return assert.Fail(t, fmt.Sprintf("Expected %T to be called with given arguments at least once", mockFn), msgAndArgs...) 63 | } 64 | return true 65 | } 66 | 67 | // NotCalledWith asserts that the mock function object was not called with a set of arguments 68 | // matching the given call instance asserter. 
69 | func NotCalledWith(t assert.TestingT, mockFn interface{}, asserter CallInstanceAsserter, msgAndArgs ...interface{}) bool { 70 | matchingCallCount, ok := callCountWith(t, mockFn, asserter, msgAndArgs...) 71 | if !ok { 72 | return false 73 | } 74 | if matchingCallCount != 0 { 75 | return assert.Fail(t, fmt.Sprintf("Did not expect %T to be called with given arguments", mockFn), msgAndArgs...) 76 | } 77 | return true 78 | } 79 | 80 | // CalledOnceWith asserts that the mock function object was called exactly once with a set of 81 | // arguments matching the given call instance asserter. 82 | func CalledOnceWith(t assert.TestingT, mockFn interface{}, asserter CallInstanceAsserter, msgAndArgs ...interface{}) bool { 83 | return CalledNWith(t, mockFn, 1, asserter, msgAndArgs...) 84 | } 85 | 86 | // CalledNWith asserts that the mock function object was called exactly n times with a set of 87 | // arguments matching the given call instance asserter. 88 | func CalledNWith(t assert.TestingT, mockFn interface{}, n int, asserter CallInstanceAsserter, msgAndArgs ...interface{}) bool { 89 | matchingCallCount, ok := callCountWith(t, mockFn, asserter, msgAndArgs...) 90 | if !ok { 91 | return false 92 | } 93 | if matchingCallCount != n { 94 | return assert.Fail(t, fmt.Sprintf("Expected %T to be called with given arguments exactly %d times, called %d times", mockFn, n, matchingCallCount), msgAndArgs...) 95 | } 96 | return true 97 | } 98 | 99 | // CalledAtNWith asserts that the mock function objects nth call was with a set of 100 | // arguments matching the given call instance asserter. 101 | func CalledAtNWith(t assert.TestingT, mockFn interface{}, n int, asserter CallInstanceAsserter, msgAndArgs ...interface{}) bool { 102 | hist, ok := testutil.GetCallHistory(mockFn) 103 | if !ok { 104 | return false 105 | } 106 | if len(hist) < n { 107 | return assert.Fail(t, fmt.Sprintf("Expected %T to be called at least %d times, called %d times", mockFn, n, len(hist)), msgAndArgs...) 
108 | } 109 | 110 | if !asserter.Assert(hist[n]) { 111 | return assert.Fail(t, fmt.Sprintf("Expected call %d of %T to be with given arguments", n, mockFn), msgAndArgs...) 112 | } 113 | 114 | return true 115 | } 116 | 117 | // callCount returns the number of times the given mock function was called. 118 | func callCount(t assert.TestingT, mockFn interface{}, msgAndArgs ...interface{}) (int, bool) { 119 | return callCountWith(t, mockFn, CallInstanceAsserterFunc(func(call interface{}) bool { return true }), msgAndArgs...) 120 | } 121 | 122 | // callCount returns the number of times the given mock function was called with a set of 123 | // arguments matching the given call instance asserter. 124 | func callCountWith(t assert.TestingT, mockFn interface{}, asserter CallInstanceAsserter, msgAndArgs ...interface{}) (int, bool) { 125 | matchingHistory, ok := testutil.GetCallHistoryWith(mockFn, func(call testutil.CallInstance) bool { 126 | // Pass in a dummy non-erroring TestingT so that any assertions done inside 127 | // of the asserter will not fail the enclosing test. 128 | return asserter.Assert(call) 129 | }) 130 | if !ok { 131 | return 0, assert.Fail(t, fmt.Sprintf("Parameters must be a mock function description, got %T", mockFn), msgAndArgs...) 132 | } 133 | 134 | return len(matchingHistory), true 135 | } 136 | -------------------------------------------------------------------------------- /testutil/gomega/anything_matcher.go: -------------------------------------------------------------------------------- 1 | package matchers 2 | 3 | import ( 4 | "fmt" 5 | 6 | "github.com/onsi/gomega/format" 7 | "github.com/onsi/gomega/types" 8 | ) 9 | 10 | type anythingMatcher struct { 11 | n int 12 | } 13 | 14 | // BeAnything returns a matcher that never fails. 
15 | func BeAnything() types.GomegaMatcher { 16 | return &anythingMatcher{} 17 | } 18 | 19 | func (m *anythingMatcher) Match(actual interface{}) (bool, error) { 20 | return true, nil 21 | } 22 | 23 | func (m *anythingMatcher) FailureMessage(actual interface{}) string { 24 | return fmt.Sprintf("Expected\n%s\nto be anything", format.Object(actual, 1)) 25 | } 26 | 27 | func (m *anythingMatcher) NegatedFailureMessage(actual interface{}) string { 28 | return fmt.Sprintf("Expected\n%s\nnot to be anything", format.Object(actual, 1)) 29 | } 30 | -------------------------------------------------------------------------------- /testutil/gomega/anything_matcher_test.go: -------------------------------------------------------------------------------- 1 | package matchers 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/stretchr/testify/assert" 7 | ) 8 | 9 | func TestAnythingMatch(t *testing.T) { 10 | ok, err := BeAnything().Match(nil) 11 | assert.Nil(t, err) 12 | assert.True(t, ok) 13 | } 14 | -------------------------------------------------------------------------------- /testutil/gomega/called_matcher.go: -------------------------------------------------------------------------------- 1 | package matchers 2 | 3 | import ( 4 | "fmt" 5 | 6 | "github.com/derision-test/go-mockgen/v2/internal/testutil" 7 | "github.com/onsi/gomega/format" 8 | "github.com/onsi/gomega/types" 9 | ) 10 | 11 | type calledMatcher struct { 12 | name string 13 | } 14 | 15 | var _ types.GomegaMatcher = &calledMatcher{} 16 | 17 | // BeCalled constructs a matcher that asserts the mock function object was called at least once. 18 | func BeCalled() types.GomegaMatcher { 19 | return &calledMatcher{ 20 | name: "BeCalled", 21 | } 22 | } 23 | 24 | func (m *calledMatcher) Match(actual interface{}) (bool, error) { 25 | history, ok := testutil.GetCallHistory(actual) 26 | if !ok { 27 | return false, fmt.Errorf("%s expects a mock function description. 
Got:\n%s", m.name, format.Object(actual, 1)) 28 | } 29 | 30 | return len(history) > 0, nil 31 | } 32 | 33 | func (m *calledMatcher) FailureMessage(actual interface{}) string { 34 | return fmt.Sprintf("Expected\n%s\nto be called at least once", format.Object(actual, 1)) 35 | } 36 | 37 | func (m *calledMatcher) NegatedFailureMessage(actual interface{}) string { 38 | return fmt.Sprintf("Expected\n%s\nnot to be called at least once", format.Object(actual, 1)) 39 | } 40 | 41 | // BeCalledOnce constructs a matcher that asserts the mock function object was called exactly once. 42 | func BeCalledOnce() types.GomegaMatcher { 43 | return &calledNMatcher{ 44 | name: "BeCalledOnce", 45 | n: 1, 46 | } 47 | } 48 | 49 | type calledNMatcher struct { 50 | name string 51 | n int 52 | } 53 | 54 | var _ types.GomegaMatcher = &calledNMatcher{} 55 | 56 | // BeCalledN constructs a matcher that asserts the mock function object was called exactly n times. 57 | func BeCalledN(n int) types.GomegaMatcher { 58 | return &calledNMatcher{ 59 | name: "BeCalledN", 60 | n: n, 61 | } 62 | } 63 | 64 | func (m *calledNMatcher) Match(actual interface{}) (bool, error) { 65 | history, ok := testutil.GetCallHistory(actual) 66 | if !ok { 67 | return false, fmt.Errorf("%s expects a mock function description. 
Got:\n%s", m.name, format.Object(actual, 1)) 68 | } 69 | 70 | return len(history) == m.n, nil 71 | } 72 | 73 | func (m *calledNMatcher) FailureMessage(actual interface{}) string { 74 | return fmt.Sprintf("Expected\n%s\nto be called %d times", format.Object(actual, 1), m.n) 75 | } 76 | 77 | func (m *calledNMatcher) NegatedFailureMessage(actual interface{}) string { 78 | return fmt.Sprintf("Expected\n%s\nnot to be called %d times", format.Object(actual, 1), m.n) 79 | } 80 | -------------------------------------------------------------------------------- /testutil/gomega/called_matcher_test.go: -------------------------------------------------------------------------------- 1 | package matchers 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/stretchr/testify/assert" 7 | ) 8 | 9 | func TestCalledMatch(t *testing.T) { 10 | ok, err := BeCalled().Match(newHistory(mockCall{})) 11 | assert.Nil(t, err) 12 | assert.True(t, ok) 13 | } 14 | 15 | func TestCalledMatchEmptyHistory(t *testing.T) { 16 | ok, err := BeCalled().Match(newHistory()) 17 | assert.Nil(t, err) 18 | assert.False(t, ok) 19 | } 20 | 21 | func TestCalledMatchError(t *testing.T) { 22 | _, err := BeCalled().Match(nil) 23 | assert.NotNil(t, err) 24 | assert.Contains(t, err.Error(), "BeCalled expects a mock function") 25 | } 26 | 27 | func TestCalledNMatch(t *testing.T) { 28 | ok, err := BeCalledN(2).Match(newHistory(mockCall{}, mockCall{})) 29 | assert.Nil(t, err) 30 | assert.True(t, ok) 31 | } 32 | 33 | func TestCalledNMatchEmptyHistory(t *testing.T) { 34 | ok, err := BeCalledN(1).Match(newHistory()) 35 | assert.Nil(t, err) 36 | assert.False(t, ok) 37 | } 38 | 39 | func TestCalledNMatchMismatchedHistory(t *testing.T) { 40 | ok, err := BeCalledN(1).Match(newHistory(mockCall{}, mockCall{})) 41 | assert.Nil(t, err) 42 | assert.False(t, ok) 43 | } 44 | 45 | func TestCalledNMatchError(t *testing.T) { 46 | _, err := BeCalledN(1).Match(nil) 47 | assert.NotNil(t, err) 48 | assert.Contains(t, err.Error(), "BeCalledN 
expects a mock function") 49 | } 50 | 51 | func TestCalledOnceMatch(t *testing.T) { 52 | ok, err := BeCalledOnce().Match(newHistory(mockCall{})) 53 | assert.Nil(t, err) 54 | assert.True(t, ok) 55 | } 56 | 57 | func TestCalledOnceMatchEmptyHistory(t *testing.T) { 58 | ok, err := BeCalledOnce().Match(newHistory()) 59 | assert.Nil(t, err) 60 | assert.False(t, ok) 61 | } 62 | 63 | func TestCalledOnceMatchMismatchedHistory(t *testing.T) { 64 | ok, err := BeCalledOnce().Match(newHistory(mockCall{}, mockCall{})) 65 | assert.Nil(t, err) 66 | assert.False(t, ok) 67 | } 68 | 69 | func TestCalledOnceMatchError(t *testing.T) { 70 | _, err := BeCalledOnce().Match(nil) 71 | assert.NotNil(t, err) 72 | assert.Contains(t, err.Error(), "BeCalledOnce expects a mock function") 73 | } 74 | -------------------------------------------------------------------------------- /testutil/gomega/called_with_matcher.go: -------------------------------------------------------------------------------- 1 | package matchers 2 | 3 | import ( 4 | "fmt" 5 | 6 | "github.com/derision-test/go-mockgen/v2/internal/testutil" 7 | "github.com/onsi/gomega/format" 8 | "github.com/onsi/gomega/matchers" 9 | "github.com/onsi/gomega/types" 10 | ) 11 | 12 | type calledWithMatcher struct { 13 | name string 14 | args []interface{} 15 | } 16 | 17 | var _ types.GomegaMatcher = &calledWithMatcher{} 18 | 19 | // BeCalledWith constructs a matcher that asserts the mock function object was called at least once 20 | // with a set of arguments matching the given values. The values can be another matcher or a literal 21 | // value. In the latter case, the values will be checked for equality. 22 | func BeCalledWith(args ...interface{}) types.GomegaMatcher { 23 | return &calledWithMatcher{ 24 | name: "BeCalledWith", 25 | args: args, 26 | } 27 | } 28 | 29 | func (m *calledWithMatcher) Match(actual interface{}) (bool, error) { 30 | matchingHistory, ok := getCallHistoryWith(actual, m.args...) 
31 | if !ok { 32 | return false, fmt.Errorf("%s expects a mock function description. Got:\n%s", m.name, format.Object(actual, 1)) 33 | } 34 | 35 | return len(matchingHistory) > 0, nil 36 | } 37 | 38 | func (m *calledWithMatcher) FailureMessage(actual interface{}) string { 39 | return format.Message(actual, "to contain at least one call with argument list matching", m.args) 40 | } 41 | 42 | func (m *calledWithMatcher) NegatedFailureMessage(actual interface{}) string { 43 | return format.Message(actual, "not to contain at least one call with argument list matching", m.args) 44 | } 45 | 46 | // BeCalledOnceWith constructs a matcher that asserts the mock function object was called exactly once 47 | // with a set of arguments matching the given values. The values can be another matcher or a literal 48 | // value. In the latter case, the values will be checked for equality. 49 | func BeCalledOnceWith(args ...interface{}) types.GomegaMatcher { 50 | return &calledNWithMatcher{ 51 | name: "BeCalledOnceWith", 52 | n: 1, 53 | args: args, 54 | } 55 | } 56 | 57 | type calledNWithMatcher struct { 58 | name string 59 | n int 60 | args []interface{} 61 | } 62 | 63 | var _ types.GomegaMatcher = &calledNWithMatcher{} 64 | 65 | // BeCalledNWith constructs a matcher that asserts the mock function object was called exactly n times 66 | // with a set of arguments matching the given values. The values can be another matcher or a literal 67 | // value. In the latter case, the values will be checked for equality. 68 | func BeCalledNWith(n int, args ...interface{}) types.GomegaMatcher { 69 | return &calledNWithMatcher{ 70 | name: "BeCalledNWith", 71 | n: n, 72 | args: args, 73 | } 74 | } 75 | 76 | func (m *calledNWithMatcher) Match(actual interface{}) (bool, error) { 77 | matchingHistory, ok := getCallHistoryWith(actual, m.args...) 78 | if !ok { 79 | return false, fmt.Errorf("%s expects a mock function description. 
Got:\n%s", m.name, format.Object(actual, 1)) 80 | } 81 | 82 | return len(matchingHistory) == m.n, nil 83 | } 84 | 85 | func (m *calledNWithMatcher) FailureMessage(actual interface{}) string { 86 | return format.Message(actual, "to contain one call with argument list matching", m.args) 87 | } 88 | 89 | func (m *calledNWithMatcher) NegatedFailureMessage(actual interface{}) string { 90 | return format.Message(actual, "not to contain one call with argument list matching", m.args) 91 | } 92 | 93 | // getCallHistoryWith returns the set of call instances matching the given values. The values can 94 | // be another matcher or a literal value. In the latter case, the values will be checked for equality. 95 | func getCallHistoryWith(actual interface{}, args ...interface{}) ([]testutil.CallInstance, bool) { 96 | return testutil.GetCallHistoryWith(actual, func(v testutil.CallInstance) bool { 97 | if len(args) > len(v.Args()) { 98 | return false 99 | } 100 | 101 | for i, expectedArg := range args { 102 | matcher, ok := expectedArg.(types.GomegaMatcher) 103 | if !ok { 104 | matcher = &matchers.EqualMatcher{Expected: expectedArg} 105 | } 106 | 107 | success, err := matcher.Match(v.Args()[i]) 108 | if err != nil || !success { 109 | return false 110 | } 111 | } 112 | 113 | return true 114 | }) 115 | } 116 | -------------------------------------------------------------------------------- /testutil/gomega/called_with_matcher_test.go: -------------------------------------------------------------------------------- 1 | package matchers 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/onsi/gomega" 7 | "github.com/stretchr/testify/assert" 8 | ) 9 | 10 | func TestCalledWithMatch(t *testing.T) { 11 | ok, err := BeCalledWith(gomega.ContainSubstring("foo"), 1, gomega.Not(gomega.Equal(2)), 3).Match(newHistory( 12 | mockCall{[]interface{}{"foobar", 1, 2, 3}, nil}, 13 | mockCall{[]interface{}{"foobar", 1, 4, 3}, nil}, 14 | mockCall{[]interface{}{"barbaz", 1, 2, 3}, nil}, 15 | )) 16 | 17 | 
assert.Nil(t, err) 18 | assert.True(t, ok) 19 | } 20 | 21 | func TestCalledWithNoMatch(t *testing.T) { 22 | ok, err := BeCalledWith(gomega.ContainSubstring("foo"), 1, gomega.Not(gomega.Equal(2)), 3).Match(newHistory( 23 | mockCall{[]interface{}{"foobar", 1, 2, 3}, nil}, 24 | mockCall{[]interface{}{"barbaz", 1, 4, 3}, nil}, 25 | mockCall{[]interface{}{"foobaz", 1, 2, 3}, nil}, 26 | )) 27 | 28 | assert.Nil(t, err) 29 | assert.False(t, ok) 30 | } 31 | 32 | func TestCalledWithMatchError(t *testing.T) { 33 | _, err := BeCalledWith("foo", 1, 2, 3).Match(nil) 34 | assert.NotNil(t, err) 35 | assert.Contains(t, err.Error(), "BeCalledWith expects a mock function") 36 | } 37 | 38 | func TestCalledOnceWithMatch(t *testing.T) { 39 | ok, err := BeCalledOnceWith(gomega.ContainSubstring("foo"), 1, gomega.Not(gomega.Equal(2)), 3).Match(newHistory( 40 | mockCall{[]interface{}{"foobar", 1, 2, 3}, nil}, 41 | mockCall{[]interface{}{"foobar", 1, 4, 3}, nil}, 42 | mockCall{[]interface{}{"barbaz", 1, 2, 3}, nil}, 43 | )) 44 | 45 | assert.Nil(t, err) 46 | assert.True(t, ok) 47 | } 48 | 49 | func TestCalledOnceWithNoMatch(t *testing.T) { 50 | ok, err := BeCalledOnceWith(gomega.ContainSubstring("foo"), 1, gomega.Not(gomega.Equal(2)), 3).Match(newHistory( 51 | mockCall{[]interface{}{"foobar", 1, 2, 3}, nil}, 52 | mockCall{[]interface{}{"barbaz", 1, 4, 3}, nil}, 53 | mockCall{[]interface{}{"foobaz", 1, 2, 3}, nil}, 54 | )) 55 | 56 | assert.Nil(t, err) 57 | assert.False(t, ok) 58 | } 59 | 60 | func TestCalledOnceWithMultipleMatches(t *testing.T) { 61 | ok, err := BeCalledOnceWith(gomega.ContainSubstring("foo"), 1, gomega.Not(gomega.Equal(2)), 3).Match(newHistory( 62 | mockCall{[]interface{}{"foobar", 1, 2, 3}, nil}, 63 | mockCall{[]interface{}{"foobar", 1, 4, 3}, nil}, 64 | mockCall{[]interface{}{"foobar", 1, 4, 3}, nil}, 65 | mockCall{[]interface{}{"foobaz", 1, 2, 3}, nil}, 66 | )) 67 | 68 | assert.Nil(t, err) 69 | assert.False(t, ok) 70 | } 71 | 72 | func TestCalledOnceWithMatchError(t 
*testing.T) { 73 | _, err := BeCalledOnceWith("foo", 1, 2, 3).Match(nil) 74 | assert.NotNil(t, err) 75 | assert.Contains(t, err.Error(), "BeCalledOnceWith expects a mock function") 76 | } 77 | 78 | func TestGetMatchingCallCountsLiterals(t *testing.T) { 79 | history := newHistory( 80 | mockCall{args: []interface{}{"foo", "bar"}}, 81 | mockCall{args: []interface{}{"foo", "bar", "baz"}}, 82 | mockCall{args: []interface{}{"foo", "bar", "baz", "bonk"}}, 83 | mockCall{args: []interface{}{"foo", "bar", "bonk"}}, 84 | mockCall{args: []interface{}{"foo", "bar", "baz"}}, 85 | ) 86 | 87 | matchingHistory, ok := getCallHistoryWith(history, "foo", "bar", "baz") 88 | assert.True(t, ok) 89 | assert.Len(t, matchingHistory, 3) 90 | } 91 | 92 | func TestGetMatchingCallCountsMatchers(t *testing.T) { 93 | history := newHistory( 94 | mockCall{args: []interface{}{"foo", "bar"}}, 95 | mockCall{args: []interface{}{"foo", "bar", "baz"}}, 96 | mockCall{args: []interface{}{"foo", "bar", "baz", "bonk"}}, 97 | mockCall{args: []interface{}{"foo", "bar", "bonk"}}, 98 | mockCall{args: []interface{}{"foo", "bar", "baz"}}, 99 | ) 100 | 101 | matchingHistory, ok := getCallHistoryWith(history, gomega.HaveLen(3), gomega.HaveLen(3), gomega.HaveLen(3)) 102 | assert.True(t, ok) 103 | assert.Len(t, matchingHistory, 3) 104 | } 105 | 106 | func TestGetMatchingCallCountsMixed(t *testing.T) { 107 | history := newHistory( 108 | mockCall{args: []interface{}{"foo", "bar"}}, 109 | mockCall{args: []interface{}{"foo", "bar", "baz"}}, 110 | mockCall{args: []interface{}{"foo", "bar", "baz", "bonk"}}, 111 | mockCall{args: []interface{}{"foo", "bar", "bonk"}}, 112 | mockCall{args: []interface{}{"foo", "bar", "baz"}}, 113 | ) 114 | 115 | matchingHistory, ok := getCallHistoryWith(history, "foo", "bar", gomega.ContainSubstring("bo")) 116 | assert.True(t, ok) 117 | assert.Len(t, matchingHistory, 1) 118 | } 119 | -------------------------------------------------------------------------------- 
/testutil/gomega/helpers_test.go: -------------------------------------------------------------------------------- 1 | package matchers 2 | 3 | type mockFunc struct { 4 | history []mockCall 5 | } 6 | 7 | type mockCall struct { 8 | args []interface{} 9 | results []interface{} 10 | } 11 | 12 | func newHistory(calls ...mockCall) *mockFunc { 13 | return &mockFunc{ 14 | history: calls, 15 | } 16 | } 17 | 18 | func (m mockFunc) History() []mockCall { return m.history } 19 | func (m mockCall) Args() []interface{} { return m.args } 20 | func (m mockCall) Results() []interface{} { return m.results } 21 | -------------------------------------------------------------------------------- /testutil/require/asserter.go: -------------------------------------------------------------------------------- 1 | package mockrequire 2 | 3 | import mockassert "github.com/derision-test/go-mockgen/v2/testutil/assert" 4 | 5 | type CallInstanceAsserter = mockassert.CallInstanceAsserter 6 | type CallInstanceAsserterFunc = mockassert.CallInstanceAsserterFunc 7 | 8 | var Values = mockassert.Values 9 | var Skip = mockassert.Skip 10 | -------------------------------------------------------------------------------- /testutil/require/require.go: -------------------------------------------------------------------------------- 1 | package mockrequire 2 | 3 | import ( 4 | mockassert "github.com/derision-test/go-mockgen/v2/testutil/assert" 5 | "github.com/stretchr/testify/require" 6 | ) 7 | 8 | // Called asserts that the mock function object was called at least once. 9 | func Called(t require.TestingT, mockFn interface{}, msgAndArgs ...interface{}) { 10 | if !mockassert.Called(t, mockFn, msgAndArgs...) { 11 | t.FailNow() 12 | } 13 | } 14 | 15 | // NotCalled asserts that the mock function object was not called. 16 | func NotCalled(t require.TestingT, mockFn interface{}, msgAndArgs ...interface{}) { 17 | if !mockassert.NotCalled(t, mockFn, msgAndArgs...) 
{ 18 | t.FailNow() 19 | } 20 | } 21 | 22 | // CalledOnce asserts that the mock function object was called exactly once. 23 | func CalledOnce(t require.TestingT, mockFn interface{}, msgAndArgs ...interface{}) { 24 | if !mockassert.CalledOnce(t, mockFn, msgAndArgs...) { 25 | t.FailNow() 26 | } 27 | } 28 | 29 | // CalledN asserts that the mock function object was called exactly n times. 30 | func CalledN(t require.TestingT, mockFn interface{}, n int, msgAndArgs ...interface{}) { 31 | if !mockassert.CalledN(t, mockFn, n, msgAndArgs...) { 32 | t.FailNow() 33 | } 34 | } 35 | 36 | // CalledWith asserts that the mock function object was called at least once with a set of 37 | // arguments matching the given mockassertion function. 38 | func CalledWith(t require.TestingT, mockFn interface{}, asserter CallInstanceAsserter, msgAndArgs ...interface{}) { 39 | if !mockassert.CalledWith(t, mockFn, asserter, msgAndArgs...) { 40 | t.FailNow() 41 | } 42 | } 43 | 44 | // NotCalledWith asserts that the mock function object was not called with a set of arguments 45 | // matching the given mockassertion function. 46 | func NotCalledWith(t require.TestingT, mockFn interface{}, asserter CallInstanceAsserter, msgAndArgs ...interface{}) { 47 | if !mockassert.NotCalledWith(t, mockFn, asserter, msgAndArgs...) { 48 | t.FailNow() 49 | } 50 | } 51 | 52 | // CalledOnceWith asserts that the mock function object was called exactly once with a set of 53 | // arguments matching the given mockassertion function. 54 | func CalledOnceWith(t require.TestingT, mockFn interface{}, asserter CallInstanceAsserter, msgAndArgs ...interface{}) { 55 | if !mockassert.CalledOnceWith(t, mockFn, asserter, msgAndArgs...) { 56 | t.FailNow() 57 | } 58 | } 59 | 60 | // CalledNWith asserts that the mock function object was called exactly n times with a set of 61 | // arguments matching the given mockassertion function. 
62 | func CalledNWith(t require.TestingT, mockFn interface{}, n int, asserter CallInstanceAsserter, msgAndArgs ...interface{}) { 63 | if !mockassert.CalledNWith(t, mockFn, n, asserter, msgAndArgs...) { 64 | t.FailNow() 65 | } 66 | } 67 | 68 | // CalledAtNWith asserts that the mock function objects nth call was with a set of 69 | // arguments matching the given call instance asserter. 70 | func CalledAtNWith(t require.TestingT, mockFn interface{}, n int, asserter CallInstanceAsserter, msgAndArgs ...interface{}) { 71 | if !mockassert.CalledAtNWith(t, mockFn, n, asserter, msgAndArgs...) { 72 | t.FailNow() 73 | } 74 | } 75 | --------------------------------------------------------------------------------