├── .gitignore ├── .license-header ├── .travis.yml ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── doc.go ├── main.go ├── mocks ├── helheim_test.go ├── helpers_test.go ├── method.go ├── method_test.go ├── mock.go ├── mock_test.go ├── mocks.go ├── mocks_test.go ├── sugar.go └── test │ ├── withimports │ └── with_imports.go │ └── withoutimports │ └── without_imports.go ├── packages ├── packages.go └── packages_test.go ├── pers ├── consistent.go ├── consistent_test.go ├── havemethodexecuted.go ├── havemethodexecuted_test.go ├── helpers.go ├── localized_test.go ├── return.go └── return_test.go └── types ├── doc.go ├── helheim_test.go ├── helpers_test.go ├── types.go └── types_test.go /.gitignore: -------------------------------------------------------------------------------- 1 | # Compiled Object files, Static and Dynamic libs (Shared Objects) 2 | *.o 3 | *.a 4 | *.so 5 | 6 | # Folders 7 | _obj 8 | _test 9 | 10 | # Architecture specific extensions/prefixes 11 | *.[568vq] 12 | [568vq].out 13 | 14 | *.cgo1.go 15 | *.cgo2.c 16 | _cgo_defun.c 17 | _cgo_gotypes.go 18 | _cgo_export.* 19 | 20 | _testmain.go 21 | 22 | *.exe 23 | *.test 24 | *.prof 25 | -------------------------------------------------------------------------------- /.license-header: -------------------------------------------------------------------------------- 1 | // This is free and unencumbered software released into the public 2 | // domain. For more information, see or the 3 | // accompanying UNLICENSE file. 4 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | language: go 2 | 3 | os: 4 | - linux 5 | - osx 6 | 7 | go: 8 | - "1.12.x" 9 | - "1.13.x" 10 | 11 | install: 12 | - go get golang.org/x/lint/golint 13 | - go get -t -v -d ./... 14 | 15 | script: 16 | - go vet ./... 17 | - golint -set_exit_status ./... 18 | - go test -v -race -parallel 4 ./... 
19 | - go build 20 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | Contributing to hel 2 | ------------------- 3 | 4 | See [the main repo's CONTRIBUTING.md](https://git.sr.ht/~nelsam/hel/tree/master/CONTRIBUTING.md) 5 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | This is free and unencumbered software released into the public domain. 2 | 3 | Anyone is free to copy, modify, publish, use, compile, sell, or 4 | distribute this software, either in source code form or as a compiled 5 | binary, for any purpose, commercial or non-commercial, and by any 6 | means. 7 | 8 | In jurisdictions that recognize copyright laws, the author or authors 9 | of this software dedicate any and all copyright interest in the 10 | software to the public domain. We make this dedication for the benefit 11 | of the public at large and to the detriment of our heirs and 12 | successors. We intend this dedication to be an overt act of 13 | relinquishment in perpetuity of all present and future rights to this 14 | software under copyright law. 15 | 16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, 17 | EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF 18 | MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 19 | IN NO EVENT SHALL THE AUTHORS BE LIABLE FOR ANY CLAIM, DAMAGES OR 20 | OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, 21 | ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR 22 | OTHER DEALINGS IN THE SOFTWARE. 
23 | 24 | For more information, please refer to 25 | 26 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Hel has moved to sr.ht! 2 | 3 | I don't trust or support Microsoft. I've weighed the options and 4 | decided to [move hel to sr.ht](https://git.sr.ht/~nelsam/hel). 5 | Please submit issues and pull requests there. 6 | 7 | No more releases will be made on GitHub. 8 | 9 | See [the main repo's README.md](https://git.sr.ht/~nelsam/hel/tree/master/README.md) 10 | for details about hel! 11 | -------------------------------------------------------------------------------- /doc.go: -------------------------------------------------------------------------------- 1 | // This is free and unencumbered software released into the public 2 | // domain. For more information, see or the 3 | // accompanying UNLICENSE file. 4 | 5 | // Package main implements the hel command. 6 | package main 7 | -------------------------------------------------------------------------------- /main.go: -------------------------------------------------------------------------------- 1 | // This is free and unencumbered software released into the public 2 | // domain. For more information, see or the 3 | // accompanying UNLICENSE file. 4 | 5 | package main 6 | 7 | import ( 8 | "fmt" 9 | "os" 10 | "os/exec" 11 | "path/filepath" 12 | "reflect" 13 | "time" 14 | 15 | "github.com/nelsam/hel/mocks" 16 | "github.com/nelsam/hel/packages" 17 | "github.com/nelsam/hel/types" 18 | "github.com/spf13/cobra" 19 | ) 20 | 21 | var ( 22 | cmd *cobra.Command 23 | goimportsPath string 24 | ) 25 | 26 | func init() { 27 | // exec.LookPath searches PATH itself, so this also works on systems 28 | // that have no `which` binary (such as Windows). 29 | path, err := exec.LookPath("goimports") 30 | if err != nil { 31 | fmt.Println("Could not locate goimports: ", err.Error()) 32 | fmt.Println("If goimports is not installed, please install it somewhere in your path. " + 33 | "See https://godoc.org/golang.org/x/tools/cmd/goimports.") 34 | os.Exit(1) 35 | } 36 | goimportsPath = path 37 | 38 | cmd = &cobra.Command{ 39 | Use: "hel", 40 | Short: "A mock generator for Go", 41 | Long: "hel is a simple mock generator. The origin of the name is the Norse goddess, Hel, " + 42 | "who guards over the souls of those unworthy to enter Valhalla. You can probably " + 43 | "guess how much I like mocks.", 44 | Run: func(cmd *cobra.Command, args []string) { 45 | if len(args) > 0 { 46 | fmt.Print("Invalid usage. Help:\n\n") 47 | // The help func renders the *Command it is given, so it must not be nil. 48 | cmd.HelpFunc()(cmd, args) 49 | os.Exit(1) 50 | } 51 | packagePatterns, err := cmd.Flags().GetStringSlice("package") 52 | if err != nil { 53 | panic(err) 54 | } 55 | typePatterns, err := cmd.Flags().GetStringSlice("type") 56 | if err != nil { 57 | panic(err) 58 | } 59 | outputName, err := cmd.Flags().GetString("output") 60 | if err != nil { 61 | panic(err) 62 | } 63 | chanSize, err := cmd.Flags().GetInt("chan-size") 64 | if err != nil { 65 | panic(err) 66 | } 67 | blockingReturn, err := cmd.Flags().GetBool("blocking-return") 68 | if err != nil { 69 | panic(err) 70 | } 71 | noTestPkg, err := cmd.Flags().GetBool("no-test-package") 72 | if err != nil { 73 | panic(err) 74 | } 75 | fmt.Printf("Loading directories matching %s %v", pluralize(packagePatterns, "pattern", "patterns"), packagePatterns) 76 | var dirList []packages.Dir 77 | progress(func() { 78 | dirList = packages.Load(packagePatterns...) 79 | }) 80 | fmt.Print("\n") 81 | fmt.Println("Found directories:") 82 | for _, dir := range dirList { 83 | fmt.Println(" " + dir.Path()) 84 | } 85 | fmt.Print("\n") 86 | 87 | fmt.Print("Loading interface types in matching directories") 88 | var typeDirs types.Dirs 89 | progress(func() { 90 | godirs := make([]types.GoDir, 0, len(dirList)) 91 | for _, dir := range dirList { 92 | godirs = append(godirs, dir) 93 | } 94 | typeDirs = types.Load(godirs...).Filter(typePatterns...)
95 | }) 96 | fmt.Print("\n\n") 97 | 98 | fmt.Printf("Generating mocks in output file %s", outputName) 99 | progress(func() { 100 | for _, typeDir := range typeDirs { 101 | mockPath, err := makeMocks(typeDir, outputName, chanSize, blockingReturn, !noTestPkg) 102 | if err != nil { 103 | panic(err) 104 | } 105 | if mockPath != "" { 106 | if err = exec.Command(goimportsPath, "-w", mockPath).Run(); err != nil { 107 | panic(err) 108 | } 109 | } 110 | } 111 | }) 112 | fmt.Print("\n") 113 | }, 114 | } 115 | cmd.Flags().StringSliceP("package", "p", []string{"."}, "The package(s) to generate mocks for.") 116 | cmd.Flags().StringSliceP("type", "t", []string{}, "The type(s) to generate mocks for. If no types "+ 117 | "are passed in, all exported interface types will be generated.") 118 | cmd.Flags().StringP("output", "o", "helheim_test.go", "The file to write generated mocks to. Since hel does "+ 119 | "not generate exported types, this file will be saved directly in all packages with generated mocks. "+ 120 | "Also note that, since the types are not exported, you will want the file to end in '_test.go'.") 121 | cmd.Flags().IntP("chan-size", "s", 100, "The size of channels used for method calls.") 122 | cmd.Flags().BoolP("blocking-return", "b", false, "Always block when returning from mock even if there is no return value.") 123 | cmd.Flags().Bool("no-test-package", false, "Generate mocks in the primary package rather than in {pkg}_test") 124 | }
125 | 126 | // makeMocks generates mocks for the interface types in dir and writes them 127 | // to fileName inside dir's directory, returning the path it wrote. It 128 | // returns an empty path and a nil error when dir yields no mocks. 129 | func makeMocks(dir types.Dir, fileName string, chanSize int, blockingReturn, useTestPkg bool) (filePath string, err error) { 130 | generated, err := mocks.Generate(dir) 131 | if err != nil { 132 | return "", err 133 | } 134 | if len(generated) == 0 { 135 | return "", nil 136 | } 137 | generated.SetBlockingReturn(blockingReturn) 138 | if useTestPkg { 139 | generated.PrependLocalPackage(dir.Package()) 140 | } 141 | filePath = filepath.Join(dir.Dir(), fileName) 142 | f, err := os.Create(filePath) 143 | if err != nil { 144 | return "", err 145 | } 146 | defer func() { 147 | // The mocks are written through f, so a failed Close may mean the 148 | // output did not persist; report it unless an earlier error won. 149 | if closeErr := f.Close(); closeErr != nil && err == nil { 150 | err = closeErr 151 | } 152 | }() 153 | testPkg := dir.Package() 154 | if useTestPkg { 155 | testPkg += "_test" 156 | } 157 | return filePath, generated.Output(testPkg, dir.Dir(), chanSize, f) 158 | } 159 | 160 | func progress(f func()) { 161 | stop, done := make(chan struct{}), make(chan struct{}) 162 | defer func() { 163 | close(stop) 164 | <-done 165 | }() 166 | go showProgress(stop, done) 167 | f() 168 | } 169 | 170 | func showProgress(stop <-chan struct{}, done chan<- struct{}) { 171 | defer close(done) 172 | ticker := time.NewTicker(time.Second / 2) 173 | defer ticker.Stop() 174 | for { 175 | select { 176 | case <-ticker.C: 177 | fmt.Print(".") 178 | case <-stop: 179 | return 180 | } 181 | } 182 | } 183 | 184 | type lengther interface { 185 | Len() int 186 | } 187 | 188 | func pluralize(values interface{}, singular, plural string) string { 189 | length := findLength(values) 190 | if length == 1 { 191 | return singular 192 | } 193 | return plural 194 | } 195 | 196 | func findLength(values interface{}) int { 197 | if lengther, ok := values.(lengther); ok { 198 | return lengther.Len() 199 | } 200 | return reflect.ValueOf(values).Len() 201 | } 202 | 203 | // main runs the hel command. Cobra prints the error itself (unless 204 | // SilenceErrors is set), so main only needs to exit with a non-zero status. 205 | func main() { 206 | if err := cmd.Execute(); err != nil { 207 | os.Exit(1) 208 | } 209 | } 210 | -------------------------------------------------------------------------------- /mocks/helheim_test.go: -------------------------------------------------------------------------------- 1 | // This file was generated by github.com/nelsam/hel. Do not 2 | // edit this code by hand unless you *really* know what you're 3 | // doing. Expect any changes made manually to be overwritten 4 | // the next time hel regenerates this file. 5 | 6 | package mocks_test 7 | 8 | import ( 9 | "go/ast" 10 | 11 | "github.com/nelsam/hel/types" 12 | ) 13 | 14 | type mockTypeFinder struct { 15 | ExportedTypesCalled chan bool 16 | ExportedTypesOutput struct { 17 | Types chan []*ast.TypeSpec 18 | } 19 | DependenciesCalled chan bool 20 | DependenciesInput struct { 21 | Inter chan *ast.InterfaceType 22 | } 23 | DependenciesOutput struct { 24 | Dependencies chan []types.Dependency 25 | } 26 | } 27 | 28 | func newMockTypeFinder() *mockTypeFinder { 29 | m := &mockTypeFinder{} 30 | m.ExportedTypesCalled = make(chan bool, 100) 31 | m.ExportedTypesOutput.Types = make(chan []*ast.TypeSpec, 100) 32 | m.DependenciesCalled = make(chan bool, 100) 33 | m.DependenciesInput.Inter = make(chan *ast.InterfaceType, 100) 34 | m.DependenciesOutput.Dependencies = make(chan []types.Dependency, 100) 35 | return m 36 | } 37 | func (m *mockTypeFinder) ExportedTypes() (types []*ast.TypeSpec) { 38 | m.ExportedTypesCalled <- true 39 | return <-m.ExportedTypesOutput.Types 40 | } 41 | func (m *mockTypeFinder) Dependencies(inter *ast.InterfaceType) (dependencies []types.Dependency) { 42 | m.DependenciesCalled <- true 43 | m.DependenciesInput.Inter <- inter 44 | return <-m.DependenciesOutput.Dependencies 45 | } 46 | --------------------------------------------------------------------------------
/mocks/helpers_test.go: -------------------------------------------------------------------------------- 1 | // This is free and unencumbered software released into the public 2 | // domain. For more information, see or the 3 | // accompanying UNLICENSE file. 4 | 5 | package mocks_test 6 | 7 | import ( 8 | "bytes" 9 | "go/ast" 10 | "go/format" 11 | "go/parser" 12 | "go/token" 13 | 14 | "github.com/a8m/expect" 15 | ) 16 | 17 | const packagePrefix = "package foo\n\n" 18 | 19 | func source(expect func(interface{}) *expect.Expect, pkg string, decls []ast.Decl, scope *ast.Scope) string { 20 | buf := bytes.Buffer{} 21 | f := &ast.File{ 22 | Name: &ast.Ident{Name: pkg}, 23 | Decls: decls, 24 | Scope: scope, 25 | } 26 | err := format.Node(&buf, token.NewFileSet(), f) 27 | expect(err).To.Be.Nil() 28 | return buf.String() 29 | } 30 | 31 | func parse(expect func(interface{}) *expect.Expect, code string) *ast.File { 32 | f, err := parser.ParseFile(token.NewFileSet(), "", packagePrefix+code, 0) 33 | expect(err).To.Be.Nil() 34 | expect(f).Not.To.Be.Nil() 35 | return f 36 | } 37 | 38 | func typeSpec(expect func(interface{}) *expect.Expect, spec string) *ast.TypeSpec { 39 | f := parse(expect, spec) 40 | expect(f.Scope.Objects).To.Have.Len(1) 41 | for _, obj := range f.Scope.Objects { 42 | spec, ok := obj.Decl.(*ast.TypeSpec) 43 | expect(ok).To.Be.Ok() 44 | return spec 45 | } 46 | return nil 47 | } 48 | 49 | func method(expect func(interface{}) *expect.Expect, spec *ast.TypeSpec) *ast.FuncType { 50 | inter, ok := spec.Type.(*ast.InterfaceType) 51 | expect(ok).To.Be.Ok() 52 | expect(inter.Methods.List).To.Have.Len(1) 53 | f, ok := inter.Methods.List[0].Type.(*ast.FuncType) 54 | expect(ok).To.Be.Ok() 55 | return f 56 | } 57 | -------------------------------------------------------------------------------- /mocks/method.go: -------------------------------------------------------------------------------- 1 | // This is free and unencumbered software released into the public 2 | // 
domain. For more information, see or the 3 | // accompanying UNLICENSE file. 4 | 5 | package mocks 6 | 7 | import ( 8 | "fmt" 9 | "go/ast" 10 | "go/token" 11 | "strconv" 12 | "strings" 13 | "unicode" 14 | ) 15 | 16 | const ( 17 | inputFmt = "arg%d" 18 | outputFmt = "ret%d" 19 | receiverName = "m" 20 | ) 21 | 22 | // Method represents a method that is being mocked. 23 | type Method struct { 24 | receiver Mock 25 | name string 26 | implements *ast.FuncType 27 | } 28 | 29 | // MethodFor returns a Method representing typ, using receiver as 30 | // the Method's receiver type and name as the method name. 31 | func MethodFor(receiver Mock, name string, typ *ast.FuncType) Method { 32 | return Method{ 33 | receiver: receiver, 34 | name: name, 35 | implements: typ, 36 | } 37 | } 38 | 39 | // Ast returns the ast representation of m. 40 | func (m Method) Ast() *ast.FuncDecl { 41 | f := &ast.FuncDecl{} 42 | f.Name = &ast.Ident{Name: m.name} 43 | f.Type = m.mockType() 44 | f.Recv = m.recv() 45 | f.Body = m.body() 46 | return f 47 | } 48 | 49 | // Fields returns the fields that need to be a part of the receiver 50 | // struct for this method.
51 | func (m Method) Fields() []*ast.Field { 52 | fields := []*ast.Field{ 53 | { 54 | Names: []*ast.Ident{{Name: m.name + "Called"}}, 55 | Type: &ast.ChanType{ 56 | Dir: ast.SEND | ast.RECV, 57 | Value: &ast.Ident{Name: "bool"}, 58 | }, 59 | }, 60 | } 61 | if len(m.params()) > 0 { 62 | fields = append(fields, &ast.Field{ 63 | Names: []*ast.Ident{{Name: m.name + "Input"}}, 64 | Type: m.chanStruct(m.implements.Params.List), 65 | }) 66 | } 67 | if len(m.results()) > 0 { 68 | fields = append(fields, &ast.Field{ 69 | Names: []*ast.Ident{{Name: m.name + "Output"}}, 70 | Type: m.chanStruct(m.results()), 71 | }) 72 | } 73 | return fields 74 | } 75 | 76 | func (m Method) chanStruct(list []*ast.Field) *ast.StructType { 77 | typ := &ast.StructType{Fields: &ast.FieldList{}} 78 | for _, f := range list { 79 | chanValType := f.Type 80 | switch src := chanValType.(type) { 81 | case *ast.ChanType: 82 | // Receive-only channels require parens, and it seems unfair to leave 83 | // out send-only channels.
84 | switch src.Dir { 85 | case ast.SEND, ast.RECV: 86 | chanValType = &ast.ParenExpr{X: src} 87 | } 88 | case *ast.Ellipsis: 89 | // The actual value of variadic types is a slice 90 | chanValType = &ast.ArrayType{Elt: src.Elt} 91 | } 92 | names := make([]*ast.Ident, 0, len(f.Names)) 93 | for _, name := range f.Names { 94 | newName := &ast.Ident{} 95 | *newName = *name 96 | names = append(names, newName) 97 | newName.Name = strings.Title(newName.Name) 98 | } 99 | typ.Fields.List = append(typ.Fields.List, &ast.Field{ 100 | Names: names, 101 | Type: &ast.ChanType{ 102 | Dir: ast.SEND | ast.RECV, 103 | Value: chanValType, 104 | }, 105 | }) 106 | } 107 | return typ 108 | } 109 | 110 | func (m Method) paramChanInit(chanSize int) []ast.Stmt { 111 | if len(m.params()) == 0 { 112 | return nil 113 | } 114 | return m.typeChanInit(m.name+"Input", m.implements.Params.List, chanSize) 115 | } 116 | 117 | func (m Method) returnChanInit(chanSize int) []ast.Stmt { 118 | return m.typeChanInit(m.name+"Output", m.results(), chanSize) 119 | } 120 | 121 | func (m Method) typeChanInit(fieldName string, fields []*ast.Field, chanSize int) (inputInits []ast.Stmt) { 122 | for _, field := range fields { 123 | for _, name := range field.Names { 124 | inputInits = append(inputInits, &ast.AssignStmt{ 125 | Lhs: []ast.Expr{selectors("m", fieldName, strings.Title(name.String()))}, 126 | Tok: token.ASSIGN, 127 | Rhs: []ast.Expr{m.makeChan(field.Type, chanSize)}, 128 | }) 129 | } 130 | } 131 | return inputInits 132 | } 133 | 134 | func (m Method) makeChan(typ ast.Expr, size int) *ast.CallExpr { 135 | switch src := typ.(type) { 136 | case *ast.ChanType: 137 | switch src.Dir { 138 | case ast.SEND, ast.RECV: 139 | typ = &ast.ParenExpr{X: src} 140 | } 141 | case *ast.Ellipsis: 142 | // The actual value of variadic types is a slice 143 | typ = &ast.ArrayType{Elt: src.Elt} 144 | } 145 | return &ast.CallExpr{ 146 | Fun: &ast.Ident{Name: "make"}, 147 | Args: []ast.Expr{ 148 | &ast.ChanType{Dir: ast.SEND | ast.RECV, Value: typ}, 149 | &ast.BasicLit{Kind: token.INT, Value: strconv.Itoa(size)}, 150 | }, 151 | } 152 | } 153 | 154 | func (m Method) chanInit(chanSize int) []ast.Stmt { 155 | stmts := []ast.Stmt{ 156 | &ast.AssignStmt{ 157 | Lhs: []ast.Expr{selectors("m", m.name+"Called")}, 158 | Tok: token.ASSIGN, 159 | Rhs: []ast.Expr{m.makeChan(&ast.Ident{Name: "bool"}, chanSize)}, 160 | }, 161 | } 162 | stmts = append(stmts, m.paramChanInit(chanSize)...) 163 | stmts = append(stmts, m.returnChanInit(chanSize)...) 164 | return stmts 165 | } 166 | 167 | func (m Method) recv() *ast.FieldList { 168 | return &ast.FieldList{ 169 | List: []*ast.Field{ 170 | { 171 | Names: []*ast.Ident{{Name: receiverName}}, 172 | Type: &ast.StarExpr{ 173 | X: &ast.Ident{Name: m.receiver.Name()}, 174 | }, 175 | }, 176 | }, 177 | } 178 | } 179 | 180 | func (m Method) mockType() *ast.FuncType { 181 | newTyp := &ast.FuncType{ 182 | Results: m.implements.Results, 183 | } 184 | if m.implements.Params != nil { 185 | newTyp.Params = &ast.FieldList{ 186 | List: m.params(), 187 | } 188 | } 189 | return newTyp 190 | } 191 | 192 | func (m Method) sendOn(receiver string, fields ...string) *ast.SendStmt { 193 | return &ast.SendStmt{Chan: selectors(receiver, fields...)} 194 | } 195 | 196 | func (m Method) called() ast.Stmt { 197 | stmt := m.sendOn(receiverName, m.name+"Called") 198 | stmt.Value = &ast.Ident{Name: "true"} 199 | return stmt 200 | } 201 | 202 | func mockField(idx int, f *ast.Field) *ast.Field { 203 | if f.Names == nil { 204 | if idx < 0 { 205 | return f 206 | } 207 | // Edit the field directly to ensure the same name is used in the mock 208 | // struct. 209 | f.Names = []*ast.Ident{{Name: fmt.Sprintf(inputFmt, idx)}} 210 | return f 211 | } 212 | 213 | // Here, we want a copy, so that we can use altered names without affecting 214 | // field names in the mock struct.
215 | newField := &ast.Field{Type: f.Type} 216 | for _, n := range f.Names { 217 | name := n.Name 218 | if name == receiverName { 219 | name += "_" 220 | } 221 | newField.Names = append(newField.Names, &ast.Ident{Name: name}) 222 | } 223 | return newField 224 | } 225 | 226 | func (m Method) params() []*ast.Field { 227 | var params []*ast.Field 228 | for idx, f := range m.implements.Params.List { 229 | params = append(params, mockField(idx, f)) 230 | } 231 | return params 232 | } 233 | 234 | func (m Method) results() []*ast.Field { 235 | if m.implements.Results == nil { 236 | if !*m.receiver.blockingReturn { 237 | return nil 238 | } 239 | return []*ast.Field{ 240 | { 241 | Names: []*ast.Ident{ 242 | {Name: "blockReturn"}, 243 | }, 244 | Type: &ast.Ident{Name: "bool"}, 245 | }, 246 | } 247 | } 248 | fields := make([]*ast.Field, 0, len(m.implements.Results.List)) 249 | for idx, f := range m.implements.Results.List { 250 | if f.Names == nil { 251 | // to avoid changing the method definition, make a copy 252 | named := *f 253 | f = &named 254 | f.Names = []*ast.Ident{{Name: fmt.Sprintf(outputFmt, idx)}} 255 | } 256 | fields = append(fields, f) 257 | } 258 | return fields 259 | } 260 | 261 | func (m Method) inputs() (stmts []ast.Stmt) { 262 | for _, input := range m.params() { 263 | for _, n := range input.Names { 264 | // Undo our hack to avoid name collisions with the receiver. 265 | name := n.Name 266 | if name == receiverName+"_" { 267 | name = receiverName 268 | } 269 | stmt := m.sendOn(receiverName, m.name+"Input", strings.Title(name)) 270 | stmt.Value = &ast.Ident{Name: n.Name} 271 | stmts = append(stmts, stmt) 272 | } 273 | } 274 | return stmts 275 | } 276 | 277 | // PrependLocalPackage prepends name as the package name for local types 278 | // in m's signature. This is most often used when mocking types that are 279 | // imported by the local package.
280 | func (m Method) PrependLocalPackage(name string) { 281 | m.prependPackage(name, m.implements.Results) 282 | m.prependPackage(name, m.implements.Params) 283 | } 284 | 285 | func (m Method) prependPackage(name string, fields *ast.FieldList) { 286 | if fields == nil { 287 | return 288 | } 289 | for _, field := range fields.List { 290 | field.Type = m.prependTypePackage(name, field.Type) 291 | } 292 | } 293 | 294 | // prependTypePackage qualifies the exported identifiers in typ with the 295 | // package name, recursing into composite type expressions (which are 296 | // updated in place). Identifiers that don't start with an upper-case 297 | // letter, selector expressions, and any other expression kinds are 298 | // returned unchanged. 299 | func (m Method) prependTypePackage(name string, typ ast.Expr) ast.Expr { 300 | switch src := typ.(type) { 301 | case *ast.Ident: 302 | // Check the whole first rune: indexing the string yields a single 303 | // byte, which misclassifies identifiers starting with non-ASCII letters. 304 | if !unicode.IsUpper([]rune(src.Name)[0]) { 305 | // Assume a built-in type, at least for now 306 | return src 307 | } 308 | return selectors(name, src.String()) 309 | case *ast.FuncType: 310 | m.prependPackage(name, src.Params) 311 | m.prependPackage(name, src.Results) 312 | return src 313 | case *ast.ArrayType: 314 | src.Elt = m.prependTypePackage(name, src.Elt) 315 | return src 316 | case *ast.MapType: 317 | src.Key = m.prependTypePackage(name, src.Key) 318 | src.Value = m.prependTypePackage(name, src.Value) 319 | return src 320 | case *ast.StarExpr: 321 | src.X = m.prependTypePackage(name, src.X) 322 | return src 323 | case *ast.Ellipsis: 324 | // Variadic parameters (...T) need their element type qualified too. 325 | src.Elt = m.prependTypePackage(name, src.Elt) 326 | return src 327 | case *ast.ChanType: 328 | src.Value = m.prependTypePackage(name, src.Value) 329 | return src 330 | case *ast.ParenExpr: 331 | src.X = m.prependTypePackage(name, src.X) 332 | return src 333 | default: 334 | return typ 335 | } 336 | } 337 | 338 | func (m Method) recvFrom(receiver string, fields ...string) *ast.UnaryExpr { 339 | return &ast.UnaryExpr{Op: token.ARROW, X: selectors(receiver, fields...)} 340 | } 341 | 342 | func (m Method) returnsExprs() (exprs []ast.Expr) { 343 | for _, output := range m.results() { 344 | for _, name := range output.Names { 345 | exprs = append(exprs, m.recvFrom(receiverName, m.name+"Output", strings.Title(name.String()))) 346 | } 347 | } 348 | return exprs 349 | } 350 | 351 | func (m Method) returns() ast.Stmt { 352 | if m.implements.Results == nil { 353 | if !*m.receiver.blockingReturn { 354 | return nil 355 | } 356 | return &ast.ExprStmt{X: m.returnsExprs()[0]} 357 | } 358 | return &ast.ReturnStmt{Results: m.returnsExprs()} 359 | } 360 | 361 | // body builds the mock method's body: signal the call, send each input, 362 | // then return (or block on) the configured outputs. 363 | func (m Method) body() *ast.BlockStmt { 364 | stmts := []ast.Stmt{m.called()} 365 | stmts = append(stmts, m.inputs()...) 366 | // Append the statement already built instead of calling m.returns() 367 | // a second time and rebuilding the same AST. 368 | if returnStmt := m.returns(); returnStmt != nil { 369 | stmts = append(stmts, returnStmt) 370 | } 371 | return &ast.BlockStmt{ 372 | List: stmts, 373 | } 374 | } 375 | -------------------------------------------------------------------------------- /mocks/method_test.go: -------------------------------------------------------------------------------- 1 | // This is free and unencumbered software released into the public 2 | // domain. For more information, see or the 3 | // accompanying UNLICENSE file. 4 | 5 | package mocks_test 6 | 7 | import ( 8 | "go/ast" 9 | "go/format" 10 | "testing" 11 | 12 | "github.com/a8m/expect" 13 | "github.com/nelsam/hel/mocks" 14 | ) 15 | 16 | func TestMockSimpleMethod(t *testing.T) { 17 | expect := expect.New(t) 18 | 19 | spec := typeSpec(expect, ` 20 | type Foo interface { 21 | Foo() 22 | }`) 23 | mock, err := mocks.For(spec) 24 | expect(err).To.Be.Nil().Else.FailNow() 25 | method := mocks.MethodFor(mock, "Foo", method(expect, spec)) 26 | 27 | expected, err := format.Source([]byte(` 28 | package foo 29 | 30 | func (m *mockFoo) Foo() { 31 | m.FooCalled <- true 32 | }`)) 33 | expect(err).To.Be.Nil().Else.FailNow() 34 | 35 | src := source(expect, "foo", []ast.Decl{method.Ast()}, nil) 36 | expect(src).To.Equal(string(expected)) 37 | 38 | fields := method.Fields() 39 | expect(fields).To.Have.Len(1).Else.FailNow() 40 | 41 | expect(fields[0].Names[0].Name).To.Equal("FooCalled") 42 | ch, ok := fields[0].Type.(*ast.ChanType) 43 | expect(ok).To.Be.Ok().Else.FailNow() 44 | expect(ch.Dir).To.Equal(ast.SEND | ast.RECV) 45 | ident, ok := ch.Value.(*ast.Ident) 46 | expect(ident.Name).To.Equal("bool") 47 | } 48 | 49 | func TestMockMethodParams(t *testing.T) { 50 | expect := expect.New(t) 51 | 52 | spec := typeSpec(expect, ` 53 | type Foo interface { 54 | Foo(foo, bar string, baz int) 55 | }`) 56 | mock, err := mocks.For(spec) 57 | expect(err).To.Be.Nil().Else.FailNow()
expect(err).To.Be.Nil().Else.FailNow() 58 | method := mocks.MethodFor(mock, "Foo", method(expect, spec)) 59 | 60 | expected, err := format.Source([]byte(` 61 | package foo 62 | 63 | func (m *mockFoo) Foo(foo, bar string, baz int) { 64 | m.FooCalled <- true 65 | m.FooInput.Foo <- foo 66 | m.FooInput.Bar <- bar 67 | m.FooInput.Baz <- baz 68 | }`)) 69 | expect(err).To.Be.Nil().Else.FailNow() 70 | 71 | src := source(expect, "foo", []ast.Decl{method.Ast()}, nil) 72 | expect(src).To.Equal(string(expected)) 73 | 74 | fields := method.Fields() 75 | expect(fields).To.Have.Len(2) 76 | 77 | expect(fields[0].Names[0].Name).To.Equal("FooCalled") 78 | ch, ok := fields[0].Type.(*ast.ChanType) 79 | expect(ok).To.Be.Ok().Else.FailNow() 80 | expect(ch.Dir).To.Equal(ast.SEND | ast.RECV) 81 | ident, ok := ch.Value.(*ast.Ident) 82 | expect(ident.Name).To.Equal("bool") 83 | 84 | expect(fields[1].Names[0].Name).To.Equal("FooInput") 85 | input, ok := fields[1].Type.(*ast.StructType) 86 | expect(ok).To.Be.Ok().Else.FailNow() 87 | expect(input.Fields.List).To.Have.Len(2).Else.FailNow() 88 | 89 | fooBar := input.Fields.List[0] 90 | expect(fooBar.Names).To.Have.Len(2).Else.FailNow() 91 | expect(fooBar.Names[0].Name).To.Equal("Foo") 92 | expect(fooBar.Names[1].Name).To.Equal("Bar") 93 | ch, ok = fooBar.Type.(*ast.ChanType) 94 | expect(ok).To.Be.Ok().Else.FailNow() 95 | expect(ch.Dir).To.Equal(ast.SEND | ast.RECV) 96 | ident, ok = ch.Value.(*ast.Ident) 97 | expect(ident.Name).To.Equal("string") 98 | 99 | baz := input.Fields.List[1] 100 | expect(baz.Names[0].Name).To.Equal("Baz") 101 | ch, ok = baz.Type.(*ast.ChanType) 102 | expect(ok).To.Be.Ok().Else.FailNow() 103 | expect(ch.Dir).To.Equal(ast.SEND | ast.RECV) 104 | ident, ok = ch.Value.(*ast.Ident) 105 | expect(ident.Name).To.Equal("int") 106 | } 107 | 108 | func TestMockMethodReturns(t *testing.T) { 109 | expect := expect.New(t) 110 | 111 | spec := typeSpec(expect, ` 112 | type Foo interface { 113 | Foo() (foo, bar string, baz int) 114 | }`) 
115 | mock, err := mocks.For(spec) 116 | expect(err).To.Be.Nil().Else.FailNow() 117 | method := mocks.MethodFor(mock, "Foo", method(expect, spec)) 118 | 119 | expected, err := format.Source([]byte(` 120 | package foo 121 | 122 | func (m *mockFoo) Foo() (foo, bar string, baz int) { 123 | m.FooCalled <- true 124 | return <-m.FooOutput.Foo, <-m.FooOutput.Bar, <-m.FooOutput.Baz 125 | }`)) 126 | expect(err).To.Be.Nil().Else.FailNow() 127 | 128 | src := source(expect, "foo", []ast.Decl{method.Ast()}, nil) 129 | expect(src).To.Equal(string(expected)) 130 | 131 | fields := method.Fields() 132 | expect(fields).To.Have.Len(2) 133 | 134 | expect(fields[0].Names[0].Name).To.Equal("FooCalled") 135 | ch, ok := fields[0].Type.(*ast.ChanType) 136 | expect(ok).To.Be.Ok().Else.FailNow() 137 | expect(ch.Dir).To.Equal(ast.SEND | ast.RECV) 138 | ident, ok := ch.Value.(*ast.Ident) 139 | expect(ident.Name).To.Equal("bool") 140 | 141 | expect(fields[1].Names[0].Name).To.Equal("FooOutput") 142 | input, ok := fields[1].Type.(*ast.StructType) 143 | expect(ok).To.Be.Ok().Else.FailNow() 144 | expect(input.Fields.List).To.Have.Len(2).Else.FailNow() 145 | 146 | fooBar := input.Fields.List[0] 147 | expect(fooBar.Names).To.Have.Len(2).Else.FailNow() 148 | expect(fooBar.Names[0].Name).To.Equal("Foo") 149 | expect(fooBar.Names[1].Name).To.Equal("Bar") 150 | ch, ok = fooBar.Type.(*ast.ChanType) 151 | expect(ok).To.Be.Ok().Else.FailNow() 152 | expect(ch.Dir).To.Equal(ast.SEND | ast.RECV) 153 | ident, ok = ch.Value.(*ast.Ident) 154 | expect(ident.Name).To.Equal("string") 155 | 156 | baz := input.Fields.List[1] 157 | expect(baz.Names[0].Name).To.Equal("Baz") 158 | ch, ok = baz.Type.(*ast.ChanType) 159 | expect(ok).To.Be.Ok().Else.FailNow() 160 | expect(ch.Dir).To.Equal(ast.SEND | ast.RECV) 161 | ident, ok = ch.Value.(*ast.Ident) 162 | expect(ident.Name).To.Equal("int") 163 | } 164 | 165 | func TestMockMethodWithBlockingReturn(t *testing.T) { 166 | expect := expect.New(t) 167 | 168 | spec := 
typeSpec(expect, ` 169 | type Foo interface { 170 | Foo() 171 | }`) 172 | mock, err := mocks.For(spec) 173 | expect(err).To.Be.Nil().Else.FailNow() 174 | mock.SetBlockingReturn(true) 175 | method := mocks.MethodFor(mock, "Foo", method(expect, spec)) 176 | 177 | expected, err := format.Source([]byte(` 178 | package foo 179 | 180 | func (m *mockFoo) Foo() () { 181 | m.FooCalled <- true 182 | <-m.FooOutput.BlockReturn 183 | }`)) 184 | expect(err).To.Be.Nil().Else.FailNow() 185 | 186 | src := source(expect, "foo", []ast.Decl{method.Ast()}, nil) 187 | expect(src).To.Equal(string(expected)) 188 | } 189 | 190 | func TestMockMethodUnnamedValues(t *testing.T) { 191 | expect := expect.New(t) 192 | 193 | spec := typeSpec(expect, ` 194 | type Foo interface { 195 | Foo(int, string) (string, error) 196 | }`) 197 | mock, err := mocks.For(spec) 198 | expect(err).To.Be.Nil().Else.FailNow() 199 | method := mocks.MethodFor(mock, "Foo", method(expect, spec)) 200 | 201 | expected, err := format.Source([]byte(` 202 | package foo 203 | 204 | func (m *mockFoo) Foo(arg0 int, arg1 string) (string, error) { 205 | m.FooCalled <- true 206 | m.FooInput.Arg0 <- arg0 207 | m.FooInput.Arg1 <- arg1 208 | return <-m.FooOutput.Ret0, <-m.FooOutput.Ret1 209 | }`)) 210 | expect(err).To.Be.Nil().Else.FailNow() 211 | 212 | src := source(expect, "foo", []ast.Decl{method.Ast()}, nil) 213 | expect(src).To.Equal(string(expected)) 214 | } 215 | 216 | func TestMockMethodLocalTypes(t *testing.T) { 217 | expect := expect.New(t) 218 | 219 | spec := typeSpec(expect, ` 220 | type Foo interface { 221 | Foo(bar bar.Bar, baz func(f Foo) error) (*Foo, func() Foo, error) 222 | }`) 223 | mock, err := mocks.For(spec) 224 | expect(err).To.Be.Nil().Else.FailNow() 225 | method := mocks.MethodFor(mock, "Foo", method(expect, spec)) 226 | 227 | expected, err := format.Source([]byte(` 228 | package foo 229 | 230 | func (m *mockFoo) Foo(bar bar.Bar, baz func(f Foo) error) (*Foo, func() Foo, error) { 231 | m.FooCalled <- true 232 | 
m.FooInput.Bar <- bar 233 | m.FooInput.Baz <- baz 234 | return <-m.FooOutput.Ret0, <-m.FooOutput.Ret1, <-m.FooOutput.Ret2 235 | }`)) 236 | expect(err).To.Be.Nil().Else.FailNow() 237 | 238 | src := source(expect, "foo", []ast.Decl{method.Ast()}, nil) 239 | expect(src).To.Equal(string(expected)) 240 | 241 | method.PrependLocalPackage("foo") 242 | 243 | expected, err = format.Source([]byte(` 244 | package foo 245 | 246 | func (m *mockFoo) Foo(bar bar.Bar, baz func(f foo.Foo) error) (*foo.Foo, func() foo.Foo, error) { 247 | m.FooCalled <- true 248 | m.FooInput.Bar <- bar 249 | m.FooInput.Baz <- baz 250 | return <-m.FooOutput.Ret0, <-m.FooOutput.Ret1, <-m.FooOutput.Ret2 251 | }`)) 252 | expect(err).To.Be.Nil().Else.FailNow() 253 | 254 | src = source(expect, "foo", []ast.Decl{method.Ast()}, nil) 255 | expect(src).To.Equal(string(expected)) 256 | } 257 | 258 | func TestMockMethodLocalTypeNesting(t *testing.T) { 259 | expect := expect.New(t) 260 | 261 | spec := typeSpec(expect, ` 262 | type Foo interface { 263 | Foo(bar []Bar, bacon map[Foo]Bar) (baz []Baz, eggs map[Foo]Bar) 264 | }`) 265 | mock, err := mocks.For(spec) 266 | expect(err).To.Be.Nil().Else.FailNow() 267 | method := mocks.MethodFor(mock, "Foo", method(expect, spec)) 268 | method.PrependLocalPackage("foo") 269 | 270 | expected, err := format.Source([]byte(` 271 | package foo 272 | 273 | func (m *mockFoo) Foo(bar []foo.Bar, bacon map[foo.Foo]foo.Bar) (baz []foo.Baz, eggs map[foo.Foo]foo.Bar) { 274 | m.FooCalled <- true 275 | m.FooInput.Bar <- bar 276 | m.FooInput.Bacon <- bacon 277 | return <-m.FooOutput.Baz, <-m.FooOutput.Eggs 278 | }`)) 279 | expect(err).To.Be.Nil().Else.FailNow() 280 | 281 | src := source(expect, "foo", []ast.Decl{method.Ast()}, nil) 282 | expect(src).To.Equal(string(expected)) 283 | } 284 | 285 | func TestMockMethodReceiverNameConflicts(t *testing.T) { 286 | expect := expect.New(t) 287 | 288 | spec := typeSpec(expect, ` 289 | type Foo interface { 290 | Foo(m string) 291 | }`) 292 | mock, err 
:= mocks.For(spec) 293 | expect(err).To.Be.Nil().Else.FailNow() 294 | method := mocks.MethodFor(mock, "Foo", method(expect, spec)) 295 | 296 | expected, err := format.Source([]byte(` 297 | package foo 298 | 299 | func (m *mockFoo) Foo(m_ string) { 300 | m.FooCalled <- true 301 | m.FooInput.M <- m_ 302 | }`)) 303 | expect(err).To.Be.Nil().Else.FailNow() 304 | 305 | src := source(expect, "foo", []ast.Decl{method.Ast()}, nil) 306 | expect(src).To.Equal(string(expected)) 307 | } 308 | -------------------------------------------------------------------------------- /mocks/mock.go: -------------------------------------------------------------------------------- 1 | // This is free and unencumbered software released into the public 2 | // domain. For more information, see or the 3 | // accompanying UNLICENSE file. 4 | 5 | package mocks 6 | 7 | import ( 8 | "fmt" 9 | "go/ast" 10 | "go/token" 11 | "strings" 12 | "unicode" 13 | ) 14 | 15 | // Mock is a mock of an interface type. 16 | type Mock struct { 17 | typeName string 18 | implements *ast.InterfaceType 19 | blockingReturn *bool 20 | } 21 | 22 | // For returns a Mock representing typ. An error will be returned 23 | // if a mock cannot be created from typ. 24 | func For(typ *ast.TypeSpec) (Mock, error) { 25 | inter, ok := typ.Type.(*ast.InterfaceType) 26 | if !ok { 27 | return Mock{}, fmt.Errorf("TypeSpec.Type expected to be *ast.InterfaceType, was %T", typ.Type) 28 | } 29 | var blockingReturn bool 30 | m := Mock{ 31 | typeName: typ.Name.String(), 32 | implements: inter, 33 | blockingReturn: &blockingReturn, 34 | } 35 | return m, nil 36 | } 37 | 38 | // Name returns the type name for m. 39 | func (m Mock) Name() string { 40 | return "mock" + strings.ToUpper(m.typeName[0:1]) + m.typeName[1:] 41 | } 42 | 43 | // Methods returns the methods that need to be created with m 44 | // as a receiver. 
45 | func (m Mock) Methods() (methods []Method) { 46 | for _, method := range m.implements.Methods.List { 47 | switch methodType := method.Type.(type) { 48 | case *ast.FuncType: 49 | methods = append(methods, MethodFor(m, method.Names[0].String(), methodType)) 50 | } 51 | } 52 | return 53 | } 54 | 55 | // PrependLocalPackage prepends name as the package name for local types 56 | // in m's signature. This is most often used when mocking types that are 57 | // imported by the local package. 58 | func (m Mock) PrependLocalPackage(name string) { 59 | for _, m := range m.Methods() { 60 | m.PrependLocalPackage(name) 61 | } 62 | } 63 | 64 | // SetBlockingReturn sets whether or not methods will include a blocking 65 | // return channel, most often used for testing data races. 66 | func (m Mock) SetBlockingReturn(blockingReturn bool) { 67 | *m.blockingReturn = blockingReturn 68 | } 69 | 70 | // Constructor returns a function AST to construct m. chanSize will be 71 | // the buffer size for all channels initialized in the constructor. 72 | func (m Mock) Constructor(chanSize int) *ast.FuncDecl { 73 | decl := &ast.FuncDecl{} 74 | typeRunes := []rune(m.Name()) 75 | typeRunes[0] = unicode.ToUpper(typeRunes[0]) 76 | decl.Name = &ast.Ident{Name: "new" + string(typeRunes)} 77 | decl.Type = &ast.FuncType{ 78 | Results: &ast.FieldList{List: []*ast.Field{{ 79 | Type: &ast.StarExpr{ 80 | X: &ast.Ident{Name: m.Name()}, 81 | }, 82 | }}}, 83 | } 84 | decl.Body = &ast.BlockStmt{List: m.constructorBody(chanSize)} 85 | return decl 86 | } 87 | 88 | // Decl returns the declaration AST for m. 89 | func (m Mock) Decl() *ast.GenDecl { 90 | spec := &ast.TypeSpec{} 91 | spec.Name = &ast.Ident{Name: m.Name()} 92 | spec.Type = m.structType() 93 | return &ast.GenDecl{ 94 | Tok: token.TYPE, 95 | Specs: []ast.Spec{spec}, 96 | } 97 | } 98 | 99 | // Ast returns all declaration AST for m. 
func (m Mock) Ast(chanSize int) []ast.Decl {
	// Order matters for the generated file: struct type, constructor,
	// then one method per interface method.
	typeDecl := m.Decl()
	ctor := m.Constructor(chanSize)
	all := []ast.Decl{typeDecl, ctor}
	for _, meth := range m.Methods() {
		all = append(all, meth.Ast())
	}
	return all
}

// constructorBody builds the statements of the generated constructor:
// allocate the mock into a local named m, initialise every channel with
// a buffer of chanSize, and return m.
func (m Mock) constructorBody(chanSize int) []ast.Stmt {
	// m := &mockXxx{}
	alloc := &ast.AssignStmt{
		Lhs: []ast.Expr{&ast.Ident{Name: "m"}},
		Tok: token.DEFINE,
		Rhs: []ast.Expr{&ast.UnaryExpr{
			Op: token.AND,
			X:  &ast.CompositeLit{Type: &ast.Ident{Name: m.Name()}},
		}},
	}
	body := []ast.Stmt{alloc}
	for _, meth := range m.Methods() {
		body = append(body, meth.chanInit(chanSize)...)
	}
	// return m
	return append(body, &ast.ReturnStmt{Results: []ast.Expr{&ast.Ident{Name: "m"}}})
}

// structType builds the mock's struct type from the fields contributed
// by each of its methods.
func (m Mock) structType() *ast.StructType {
	var fields []*ast.Field
	for _, meth := range m.Methods() {
		fields = append(fields, meth.Fields()...)
	}
	return &ast.StructType{Fields: &ast.FieldList{List: fields}}
}
132 | -------------------------------------------------------------------------------- /mocks/mock_test.go: -------------------------------------------------------------------------------- 1 | // This is free and unencumbered software released into the public 2 | // domain. For more information, see or the 3 | // accompanying UNLICENSE file.
4 | 5 | package mocks_test 6 | 7 | import ( 8 | "go/ast" 9 | "go/format" 10 | "testing" 11 | 12 | "github.com/a8m/expect" 13 | "github.com/nelsam/hel/mocks" 14 | ) 15 | 16 | func TestNewErrorsForNonInterfaceTypes(t *testing.T) { 17 | expect := expect.New(t) 18 | 19 | spec := typeSpec(expect, "type Foo func()") 20 | _, err := mocks.For(spec) 21 | expect(err).Not.To.Be.Nil().Else.FailNow() 22 | expect(err.Error()).To.Equal("TypeSpec.Type expected to be *ast.InterfaceType, was *ast.FuncType") 23 | } 24 | 25 | func TestMockName(t *testing.T) { 26 | expect := expect.New(t) 27 | 28 | spec := typeSpec(expect, "type Foo interface{}") 29 | m, err := mocks.For(spec) 30 | expect(err).To.Be.Nil().Else.FailNow() 31 | expect(m.Name()).To.Equal("mockFoo") 32 | } 33 | 34 | func TestMockTypeDecl(t *testing.T) { 35 | expect := expect.New(t) 36 | 37 | spec := typeSpec(expect, ` 38 | type Foo interface { 39 | Foo(foo string) int 40 | Bar(bar int) Foo 41 | Baz() 42 | Bacon(func(Eggs) Eggs) func(Eggs) Eggs 43 | } 44 | `) 45 | m, err := mocks.For(spec) 46 | expect(err).To.Be.Nil().Else.FailNow() 47 | 48 | expected, err := format.Source([]byte(` 49 | package foo 50 | 51 | type mockFoo struct { 52 | FooCalled chan bool 53 | FooInput struct { 54 | Foo chan string 55 | } 56 | FooOutput struct { 57 | Ret0 chan int 58 | } 59 | BarCalled chan bool 60 | BarInput struct { 61 | Bar chan int 62 | } 63 | BarOutput struct { 64 | Ret0 chan Foo 65 | } 66 | BazCalled chan bool 67 | BaconCalled chan bool 68 | BaconInput struct { 69 | Arg0 chan func(Eggs) Eggs 70 | } 71 | BaconOutput struct { 72 | Ret0 chan func(Eggs) Eggs 73 | } 74 | } 75 | `)) 76 | expect(err).To.Be.Nil().Else.FailNow() 77 | 78 | src := source(expect, "foo", []ast.Decl{m.Decl()}, nil) 79 | expect(src).To.Equal(string(expected)) 80 | 81 | m.PrependLocalPackage("foo") 82 | 83 | expected, err = format.Source([]byte(` 84 | package foo 85 | 86 | type mockFoo struct { 87 | FooCalled chan bool 88 | FooInput struct { 89 | Foo chan string 90 | } 
91 | FooOutput struct { 92 | Ret0 chan int 93 | } 94 | BarCalled chan bool 95 | BarInput struct { 96 | Bar chan int 97 | } 98 | BarOutput struct { 99 | Ret0 chan foo.Foo 100 | } 101 | BazCalled chan bool 102 | BaconCalled chan bool 103 | BaconInput struct { 104 | Arg0 chan func(foo.Eggs) foo.Eggs 105 | } 106 | BaconOutput struct { 107 | Ret0 chan func(foo.Eggs) foo.Eggs 108 | } 109 | } 110 | `)) 111 | expect(err).To.Be.Nil().Else.FailNow() 112 | 113 | src = source(expect, "foo", []ast.Decl{m.Decl()}, nil) 114 | expect(src).To.Equal(string(expected)) 115 | 116 | m.SetBlockingReturn(true) 117 | 118 | expected, err = format.Source([]byte(` 119 | package foo 120 | 121 | type mockFoo struct { 122 | FooCalled chan bool 123 | FooInput struct { 124 | Foo chan string 125 | } 126 | FooOutput struct { 127 | Ret0 chan int 128 | } 129 | BarCalled chan bool 130 | BarInput struct { 131 | Bar chan int 132 | } 133 | BarOutput struct { 134 | Ret0 chan foo.Foo 135 | } 136 | BazCalled chan bool 137 | BazOutput struct { 138 | BlockReturn chan bool 139 | } 140 | BaconCalled chan bool 141 | BaconInput struct { 142 | Arg0 chan func(foo.Eggs) foo.Eggs 143 | } 144 | BaconOutput struct { 145 | Ret0 chan func(foo.Eggs) foo.Eggs 146 | } 147 | } 148 | `)) 149 | expect(err).To.Be.Nil().Else.FailNow() 150 | 151 | src = source(expect, "foo", []ast.Decl{m.Decl()}, nil) 152 | expect(src).To.Equal(string(expected)) 153 | } 154 | 155 | func TestMockTypeDecl_DirectionalChansGetParens(t *testing.T) { 156 | expect := expect.New(t) 157 | 158 | spec := typeSpec(expect, ` 159 | type Foo interface { 160 | Foo(foo chan<- int) <-chan int 161 | } 162 | `) 163 | m, err := mocks.For(spec) 164 | expect(err).To.Be.Nil().Else.FailNow() 165 | 166 | expected, err := format.Source([]byte(` 167 | package foo 168 | 169 | type mockFoo struct { 170 | FooCalled chan bool 171 | FooInput struct { 172 | Foo chan (chan<- int) 173 | } 174 | FooOutput struct { 175 | Ret0 chan (<-chan int) 176 | } 177 | } 178 | `)) 179 | 
expect(err).To.Be.Nil().Else.FailNow() 180 | 181 | src := source(expect, "foo", []ast.Decl{m.Decl()}, nil) 182 | expect(src).To.Equal(string(expected)) 183 | } 184 | 185 | func TestMockTypeDecl_VariadicMethods(t *testing.T) { 186 | expect := expect.New(t) 187 | 188 | spec := typeSpec(expect, ` 189 | type Foo interface { 190 | Foo(foo ...int) 191 | } 192 | `) 193 | m, err := mocks.For(spec) 194 | expect(err).To.Be.Nil().Else.FailNow() 195 | 196 | expected, err := format.Source([]byte(` 197 | package foo 198 | 199 | type mockFoo struct { 200 | FooCalled chan bool 201 | FooInput struct { 202 | Foo chan []int 203 | } 204 | } 205 | `)) 206 | expect(err).To.Be.Nil().Else.FailNow() 207 | 208 | src := source(expect, "foo", []ast.Decl{m.Decl()}, nil) 209 | expect(src).To.Equal(string(expected)) 210 | } 211 | 212 | func TestMockTypeDecl_ParamsWithoutTypes(t *testing.T) { 213 | expect := expect.New(t) 214 | 215 | spec := typeSpec(expect, ` 216 | type Foo interface { 217 | Foo(foo, bar string) 218 | } 219 | `) 220 | m, err := mocks.For(spec) 221 | expect(err).To.Be.Nil().Else.FailNow() 222 | 223 | expected, err := format.Source([]byte(` 224 | package foo 225 | 226 | type mockFoo struct { 227 | FooCalled chan bool 228 | FooInput struct { 229 | Foo, Bar chan string 230 | } 231 | } 232 | `)) 233 | expect(err).To.Be.Nil().Else.FailNow() 234 | 235 | src := source(expect, "foo", []ast.Decl{m.Decl()}, nil) 236 | expect(src).To.Equal(string(expected)) 237 | } 238 | 239 | func TestMockConstructor(t *testing.T) { 240 | expect := expect.New(t) 241 | 242 | spec := typeSpec(expect, ` 243 | type Foo interface { 244 | Foo(foo string) int 245 | Bar(bar int) string 246 | } 247 | `) 248 | m, err := mocks.For(spec) 249 | expect(err).To.Be.Nil().Else.FailNow() 250 | 251 | expected, err := format.Source([]byte(` 252 | package foo 253 | 254 | func newMockFoo() *mockFoo { 255 | m := &mockFoo{} 256 | m.FooCalled = make(chan bool, 300) 257 | m.FooInput.Foo = make(chan string, 300) 258 | 
m.FooOutput.Ret0 = make(chan int, 300) 259 | m.BarCalled = make(chan bool, 300) 260 | m.BarInput.Bar = make(chan int, 300) 261 | m.BarOutput.Ret0 = make(chan string, 300) 262 | return m 263 | }`)) 264 | expect(err).To.Be.Nil().Else.FailNow() 265 | 266 | src := source(expect, "foo", []ast.Decl{m.Constructor(300)}, nil) 267 | expect(src).To.Equal(string(expected)) 268 | } 269 | 270 | func TestMockConstructor_DirectionalChansGetParens(t *testing.T) { 271 | expect := expect.New(t) 272 | 273 | spec := typeSpec(expect, ` 274 | type Foo interface { 275 | Foo(foo chan<- int) <-chan int 276 | } 277 | `) 278 | m, err := mocks.For(spec) 279 | expect(err).To.Be.Nil().Else.FailNow() 280 | 281 | expected, err := format.Source([]byte(` 282 | package foo 283 | 284 | func newMockFoo() *mockFoo { 285 | m := &mockFoo{} 286 | m.FooCalled = make(chan bool, 200) 287 | m.FooInput.Foo = make(chan (chan<- int), 200) 288 | m.FooOutput.Ret0 = make(chan (<-chan int), 200) 289 | return m 290 | } 291 | `)) 292 | expect(err).To.Be.Nil().Else.FailNow() 293 | 294 | src := source(expect, "foo", []ast.Decl{m.Constructor(200)}, nil) 295 | expect(src).To.Equal(string(expected)) 296 | } 297 | 298 | func TestMockConstructor_VariadicParams(t *testing.T) { 299 | expect := expect.New(t) 300 | 301 | spec := typeSpec(expect, ` 302 | type Foo interface { 303 | Foo(foo ...int) 304 | } 305 | `) 306 | m, err := mocks.For(spec) 307 | expect(err).To.Be.Nil().Else.FailNow() 308 | 309 | expected, err := format.Source([]byte(` 310 | package foo 311 | 312 | func newMockFoo() *mockFoo { 313 | m := &mockFoo{} 314 | m.FooCalled = make(chan bool, 200) 315 | m.FooInput.Foo = make(chan []int, 200) 316 | return m 317 | } 318 | `)) 319 | expect(err).To.Be.Nil().Else.FailNow() 320 | 321 | src := source(expect, "foo", []ast.Decl{m.Constructor(200)}, nil) 322 | expect(src).To.Equal(string(expected)) 323 | } 324 | 325 | func TestMockAst(t *testing.T) { 326 | expect := expect.New(t) 327 | 328 | spec := typeSpec(expect, ` 329 | 
type Foo interface { 330 | Bar(bar string) 331 | Baz() (baz int) 332 | }`) 333 | m, err := mocks.For(spec) 334 | expect(err).To.Be.Nil().Else.FailNow() 335 | 336 | decls := m.Ast(300) 337 | expect(decls).To.Have.Len(4).Else.FailNow() 338 | expect(decls[0]).To.Equal(m.Decl()) 339 | expect(decls[1]).To.Equal(m.Constructor(300)) 340 | expect(m.Methods()).To.Have.Len(2).Else.FailNow() 341 | expect(decls[2]).To.Equal(m.Methods()[0].Ast()) 342 | expect(decls[3]).To.Equal(m.Methods()[1].Ast()) 343 | } 344 | -------------------------------------------------------------------------------- /mocks/mocks.go: -------------------------------------------------------------------------------- 1 | // This is free and unencumbered software released into the public 2 | // domain. For more information, see or the 3 | // accompanying UNLICENSE file. 4 | 5 | //go:generate hel 6 | 7 | package mocks 8 | 9 | import ( 10 | "bytes" 11 | "go/ast" 12 | "go/format" 13 | "go/parser" 14 | "go/token" 15 | "io" 16 | "os" 17 | "strconv" 18 | "strings" 19 | 20 | "github.com/nelsam/hel/types" 21 | "golang.org/x/tools/go/ast/astutil" 22 | ) 23 | 24 | const commentHeader = `// This file was generated by github.com/nelsam/hel. Do not 25 | // edit this code by hand unless you *really* know what you're 26 | // doing. Expect any changes made manually to be overwritten 27 | // the next time hel regenerates this file. 28 | 29 | ` 30 | 31 | // TypeFinder represents a type which knows about types and dependencies. 32 | type TypeFinder interface { 33 | ExportedTypes() (types []*ast.TypeSpec) 34 | Dependencies(inter *ast.InterfaceType) (dependencies []types.Dependency) 35 | } 36 | 37 | // Mocks is a slice of Mock values. 38 | type Mocks []Mock 39 | 40 | // Output writes the go code representing m to dest. pkg will be the 41 | // package name; dir is the destination directory (needed for formatting 42 | // the file); chanSize is the buffer size of any channels created in 43 | // constructors. 
44 | func (m Mocks) Output(pkg, dir string, chanSize int, dest io.Writer) error { 45 | if _, err := dest.Write([]byte(commentHeader)); err != nil { 46 | return err 47 | } 48 | 49 | fset := token.NewFileSet() 50 | 51 | f := &ast.File{ 52 | Name: &ast.Ident{Name: pkg}, 53 | Decls: m.decls(chanSize), 54 | } 55 | 56 | var b bytes.Buffer 57 | format.Node(&b, fset, f) 58 | 59 | // TODO: Determine why adding imports without creating a new ast file 60 | // will only allow one import to be printed to the file. 61 | fset = token.NewFileSet() 62 | file, err := parser.ParseFile(fset, pkg, &b, 0) 63 | if err != nil { 64 | return err 65 | } 66 | 67 | file, fset, err = addImports(file, fset, dir) 68 | if err != nil { 69 | return err 70 | } 71 | 72 | return format.Node(dest, fset, file) 73 | } 74 | 75 | // PrependLocalPackage prepends name as the package name for local types 76 | // in m's signature. This is most often used when mocking types that are 77 | // imported by the local package. 78 | func (m Mocks) PrependLocalPackage(name string) { 79 | for _, m := range m { 80 | m.PrependLocalPackage(name) 81 | } 82 | } 83 | 84 | // SetBlockingReturn sets whether or not methods will include a blocking 85 | // return channel, most often used for testing data races. 86 | func (m Mocks) SetBlockingReturn(blockingReturn bool) { 87 | for _, m := range m { 88 | m.SetBlockingReturn(blockingReturn) 89 | } 90 | } 91 | 92 | func (m Mocks) decls(chanSize int) (decls []ast.Decl) { 93 | for _, mock := range m { 94 | decls = append(decls, mock.Ast(chanSize)...) 
95 | } 96 | return decls 97 | } 98 | 99 | func addImports(file *ast.File, fset *token.FileSet, dirPath string) (*ast.File, *token.FileSet, error) { 100 | imports, err := getImports(dirPath, fset) 101 | if err != nil { 102 | return nil, nil, err 103 | } 104 | 105 | for _, s := range imports { 106 | unquotedPath, err := strconv.Unquote(s.Path.Value) 107 | if err != nil { 108 | return nil, nil, err 109 | } 110 | 111 | if s.Name != nil { 112 | astutil.AddNamedImport(fset, file, s.Name.Name, unquotedPath) 113 | continue 114 | } 115 | 116 | astutil.AddImport(fset, file, unquotedPath) 117 | } 118 | 119 | return file, fset, nil 120 | } 121 | 122 | func getImports(dirPath string, fset *token.FileSet) ([]*ast.ImportSpec, error) { 123 | // Grab imports from all files except helheim_test 124 | pkgs, err := parser.ParseDir(fset, dirPath, func(info os.FileInfo) bool { 125 | return !strings.Contains(info.Name(), "_test.go") 126 | }, parser.ImportsOnly) 127 | if err != nil { 128 | return nil, err 129 | } 130 | 131 | var imports []*ast.ImportSpec 132 | for _, p := range pkgs { 133 | files := p.Files 134 | for _, f := range files { 135 | imports = append(imports, f.Imports...) 136 | } 137 | } 138 | return imports, nil 139 | } 140 | 141 | // Generate generates a Mocks value for all exported interface 142 | // types returned by finder. 143 | func Generate(finder TypeFinder) (Mocks, error) { 144 | base := finder.ExportedTypes() 145 | var ( 146 | typs []*ast.TypeSpec 147 | deps []types.Dependency 148 | ) 149 | for _, typ := range base { 150 | typs = append(typs, typ) 151 | if inter, ok := typ.Type.(*ast.InterfaceType); ok { 152 | deps = append(deps, finder.Dependencies(inter)...) 
153 | } 154 | } 155 | deps = deDupe(typs, deps) 156 | m := make(Mocks, 0, len(typs)) 157 | for _, typ := range typs { 158 | newMock, err := For(typ) 159 | if err != nil { 160 | return nil, err 161 | } 162 | m = append(m, newMock) 163 | } 164 | for _, dep := range deps { 165 | newMock, err := For(dep.Type) 166 | if err != nil { 167 | return nil, err 168 | } 169 | newMock.PrependLocalPackage(dep.PkgName) 170 | m = append(m, newMock) 171 | } 172 | return m, nil 173 | } 174 | 175 | func deDupe(typs []*ast.TypeSpec, deps []types.Dependency) []types.Dependency { 176 | for _, typ := range typs { 177 | for i := 0; i < len(deps); i++ { 178 | if deps[i].Type.Name.Name != typ.Name.Name { 179 | continue 180 | } 181 | if deps[i].PkgName == "" { 182 | deps = append(deps[:i], deps[i+1:]...) 183 | i-- 184 | continue 185 | } 186 | deps[i] = separate(deps[i], typ) 187 | } 188 | } 189 | for i := 0; i < len(deps); i++ { 190 | for j := i + 1; j < len(deps); j++ { 191 | if equal(deps[i], deps[j]) { 192 | deps = append(deps[:j], deps[j+1:]...) 
193 | j-- 194 | continue 195 | } 196 | deps[j] = separate(deps[j], deps[i].Type) 197 | } 198 | } 199 | return deps 200 | } 201 | 202 | func equal(a, b types.Dependency) bool { 203 | if a.PkgName != b.PkgName { 204 | return false 205 | } 206 | if a.Type.Name.Name != b.Type.Name.Name { 207 | return false 208 | } 209 | return true 210 | } 211 | 212 | func separate(dep types.Dependency, from *ast.TypeSpec) types.Dependency { 213 | if dep.Type.Name.Name != from.Name.Name { 214 | return dep 215 | } 216 | pkgTitle := strings.Title(dep.PkgName) 217 | if !strings.HasSuffix(dep.Type.Name.Name, pkgTitle) { 218 | dep.Type.Name.Name = pkgTitle + dep.Type.Name.Name 219 | } 220 | return dep 221 | } 222 | -------------------------------------------------------------------------------- /mocks/mocks_test.go: -------------------------------------------------------------------------------- 1 | // This is free and unencumbered software released into the public 2 | // domain. For more information, see or the 3 | // accompanying UNLICENSE file. 
4 | 5 | package mocks_test 6 | 7 | import ( 8 | "bytes" 9 | "go/ast" 10 | "go/format" 11 | "testing" 12 | 13 | "github.com/a8m/expect" 14 | "github.com/nelsam/hel/mocks" 15 | "github.com/nelsam/hel/types" 16 | ) 17 | 18 | func TestGenerate(t *testing.T) { 19 | expect := expect.New(t) 20 | 21 | types := []*ast.TypeSpec{ 22 | typeSpec(expect, ` 23 | type Foo interface { 24 | Bar() int 25 | }`), 26 | typeSpec(expect, ` 27 | type Bar interface { 28 | Foo(foo string) 29 | Baz() 30 | }`), 31 | } 32 | 33 | mockFinder := newMockTypeFinder() 34 | close(mockFinder.DependenciesOutput.Dependencies) 35 | mockFinder.ExportedTypesOutput.Types <- types 36 | m, err := mocks.Generate(mockFinder) 37 | expect(err).To.Be.Nil().Else.FailNow() 38 | expect(m).To.Have.Len(2).Else.FailNow() 39 | expect(m[0]).To.Equal(mockFor(expect, types[0])) 40 | expect(m[1]).To.Equal(mockFor(expect, types[1])) 41 | } 42 | 43 | func TestOutput(t *testing.T) { 44 | expect := expect.New(t) 45 | 46 | types := []*ast.TypeSpec{ 47 | typeSpec(expect, ` 48 | type Foo interface { 49 | Bar() int 50 | }`), 51 | typeSpec(expect, ` 52 | type Bar interface { 53 | Foo(foo string) Foo 54 | Baz() 55 | Bacon(func(Eggs) Eggs) func(Eggs) Eggs 56 | }`), 57 | } 58 | 59 | mockFinder := newMockTypeFinder() 60 | close(mockFinder.DependenciesOutput.Dependencies) 61 | mockFinder.ExportedTypesOutput.Types <- types 62 | m, err := mocks.Generate(mockFinder) 63 | expect(err).To.Be.Nil().Else.FailNow() 64 | 65 | buf := bytes.Buffer{} 66 | m.Output("foo", "test/withoutimports", 100, &buf) 67 | 68 | expected, err := format.Source([]byte(` 69 | // This file was generated by github.com/nelsam/hel. Do not 70 | // edit this code by hand unless you *really* know what you're 71 | // doing. Expect any changes made manually to be overwritten 72 | // the next time hel regenerates this file. 
73 | 74 | package foo 75 | 76 | type mockFoo struct { 77 | BarCalled chan bool 78 | BarOutput struct { 79 | Ret0 chan int 80 | } 81 | } 82 | 83 | func newMockFoo() *mockFoo { 84 | m := &mockFoo{} 85 | m.BarCalled = make(chan bool, 100) 86 | m.BarOutput.Ret0 = make(chan int, 100) 87 | return m 88 | } 89 | func (m *mockFoo) Bar() int { 90 | m.BarCalled <- true 91 | return <-m.BarOutput.Ret0 92 | } 93 | 94 | type mockBar struct { 95 | FooCalled chan bool 96 | FooInput struct { 97 | Foo chan string 98 | } 99 | FooOutput struct { 100 | Ret0 chan Foo 101 | } 102 | BazCalled chan bool 103 | BaconCalled chan bool 104 | BaconInput struct { 105 | Arg0 chan func(Eggs) Eggs 106 | } 107 | BaconOutput struct { 108 | Ret0 chan func(Eggs) Eggs 109 | } 110 | } 111 | 112 | func newMockBar() *mockBar { 113 | m := &mockBar{} 114 | m.FooCalled = make(chan bool, 100) 115 | m.FooInput.Foo = make(chan string, 100) 116 | m.FooOutput.Ret0 = make(chan Foo, 100) 117 | m.BazCalled = make(chan bool, 100) 118 | m.BaconCalled = make(chan bool, 100) 119 | m.BaconInput.Arg0 = make(chan func(Eggs) Eggs, 100) 120 | m.BaconOutput.Ret0 = make(chan func(Eggs) Eggs, 100) 121 | return m 122 | } 123 | func (m *mockBar) Foo(foo string) Foo { 124 | m.FooCalled <- true 125 | m.FooInput.Foo <- foo 126 | return <-m.FooOutput.Ret0 127 | } 128 | func (m *mockBar) Baz() { 129 | m.BazCalled <- true 130 | } 131 | func (m *mockBar) Bacon(arg0 func(Eggs) Eggs) func(Eggs) Eggs { 132 | m.BaconCalled <- true 133 | m.BaconInput.Arg0 <- arg0 134 | return <-m.BaconOutput.Ret0 135 | } 136 | `)) 137 | expect(err).To.Be.Nil().Else.FailNow() 138 | expect(buf.String()).To.Equal(string(expected)) 139 | 140 | m.PrependLocalPackage("foo") 141 | buf = bytes.Buffer{} 142 | m.Output("foo_test", "test/withoutimports", 100, &buf) 143 | 144 | expected, err = format.Source([]byte(` 145 | // This file was generated by github.com/nelsam/hel. Do not 146 | // edit this code by hand unless you *really* know what you're 147 | // doing. 
Expect any changes made manually to be overwritten 148 | // the next time hel regenerates this file. 149 | 150 | package foo_test 151 | 152 | type mockFoo struct { 153 | BarCalled chan bool 154 | BarOutput struct { 155 | Ret0 chan int 156 | } 157 | } 158 | 159 | func newMockFoo() *mockFoo { 160 | m := &mockFoo{} 161 | m.BarCalled = make(chan bool, 100) 162 | m.BarOutput.Ret0 = make(chan int, 100) 163 | return m 164 | } 165 | func (m *mockFoo) Bar() int { 166 | m.BarCalled <- true 167 | return <-m.BarOutput.Ret0 168 | } 169 | 170 | type mockBar struct { 171 | FooCalled chan bool 172 | FooInput struct { 173 | Foo chan string 174 | } 175 | FooOutput struct { 176 | Ret0 chan foo.Foo 177 | } 178 | BazCalled chan bool 179 | BaconCalled chan bool 180 | BaconInput struct { 181 | Arg0 chan func(foo.Eggs) foo.Eggs 182 | } 183 | BaconOutput struct { 184 | Ret0 chan func(foo.Eggs) foo.Eggs 185 | } 186 | } 187 | 188 | func newMockBar() *mockBar { 189 | m := &mockBar{} 190 | m.FooCalled = make(chan bool, 100) 191 | m.FooInput.Foo = make(chan string, 100) 192 | m.FooOutput.Ret0 = make(chan foo.Foo, 100) 193 | m.BazCalled = make(chan bool, 100) 194 | m.BaconCalled = make(chan bool, 100) 195 | m.BaconInput.Arg0 = make(chan func(foo.Eggs) foo.Eggs, 100) 196 | m.BaconOutput.Ret0 = make(chan func(foo.Eggs) foo.Eggs, 100) 197 | return m 198 | } 199 | func (m *mockBar) Foo(foo string) foo.Foo { 200 | m.FooCalled <- true 201 | m.FooInput.Foo <- foo 202 | return <-m.FooOutput.Ret0 203 | } 204 | func (m *mockBar) Baz() { 205 | m.BazCalled <- true 206 | } 207 | func (m *mockBar) Bacon(arg0 func(foo.Eggs) foo.Eggs) func(foo.Eggs) foo.Eggs { 208 | m.BaconCalled <- true 209 | m.BaconInput.Arg0 <- arg0 210 | return <-m.BaconOutput.Ret0 211 | } 212 | `)) 213 | expect(err).To.Be.Nil().Else.FailNow() 214 | expect(buf.String()).To.Equal(string(expected)) 215 | 216 | m.SetBlockingReturn(true) 217 | buf = bytes.Buffer{} 218 | m.Output("foo_test", "test/withoutimports", 100, &buf) 219 | 220 | 
expected, err = format.Source([]byte(` 221 | // This file was generated by github.com/nelsam/hel. Do not 222 | // edit this code by hand unless you *really* know what you're 223 | // doing. Expect any changes made manually to be overwritten 224 | // the next time hel regenerates this file. 225 | 226 | package foo_test 227 | 228 | type mockFoo struct { 229 | BarCalled chan bool 230 | BarOutput struct { 231 | Ret0 chan int 232 | } 233 | } 234 | 235 | func newMockFoo() *mockFoo { 236 | m := &mockFoo{} 237 | m.BarCalled = make(chan bool, 100) 238 | m.BarOutput.Ret0 = make(chan int, 100) 239 | return m 240 | } 241 | func (m *mockFoo) Bar() int { 242 | m.BarCalled <- true 243 | return <-m.BarOutput.Ret0 244 | } 245 | 246 | type mockBar struct { 247 | FooCalled chan bool 248 | FooInput struct { 249 | Foo chan string 250 | } 251 | FooOutput struct { 252 | Ret0 chan foo.Foo 253 | } 254 | BazCalled chan bool 255 | BazOutput struct { 256 | BlockReturn chan bool 257 | } 258 | BaconCalled chan bool 259 | BaconInput struct { 260 | Arg0 chan func(foo.Eggs) foo.Eggs 261 | } 262 | BaconOutput struct { 263 | Ret0 chan func(foo.Eggs) foo.Eggs 264 | } 265 | } 266 | 267 | func newMockBar() *mockBar { 268 | m := &mockBar{} 269 | m.FooCalled = make(chan bool, 100) 270 | m.FooInput.Foo = make(chan string, 100) 271 | m.FooOutput.Ret0 = make(chan foo.Foo, 100) 272 | m.BazCalled = make(chan bool, 100) 273 | m.BazOutput.BlockReturn = make(chan bool, 100) 274 | m.BaconCalled = make(chan bool, 100) 275 | m.BaconInput.Arg0 = make(chan func(foo.Eggs) foo.Eggs, 100) 276 | m.BaconOutput.Ret0 = make(chan func(foo.Eggs) foo.Eggs, 100) 277 | return m 278 | } 279 | func (m *mockBar) Foo(foo string) foo.Foo { 280 | m.FooCalled <- true 281 | m.FooInput.Foo <- foo 282 | return <-m.FooOutput.Ret0 283 | } 284 | func (m *mockBar) Baz() { 285 | m.BazCalled <- true 286 | <-m.BazOutput.BlockReturn 287 | } 288 | func (m *mockBar) Bacon(arg0 func(foo.Eggs) foo.Eggs) func(foo.Eggs) foo.Eggs { 289 | m.BaconCalled 
<- true 290 | m.BaconInput.Arg0 <- arg0 291 | return <-m.BaconOutput.Ret0 292 | } 293 | `)) 294 | expect(err).To.Be.Nil().Else.FailNow() 295 | expect(buf.String()).To.Equal(string(expected)) 296 | } 297 | 298 | func TestOutput_Dependencies(t *testing.T) { 299 | expect := expect.New(t) 300 | 301 | typs := []*ast.TypeSpec{ 302 | typeSpec(expect, ` 303 | type Bar interface { 304 | Foo(foo string) (Foo, b.Foo) 305 | }`), 306 | typeSpec(expect, ` 307 | type Foo interface { 308 | Foo() string 309 | }`), 310 | } 311 | 312 | barDeps := []types.Dependency{ 313 | // We shouldn't see duplicates of types we're 314 | // already mocking. 315 | { 316 | Type: typeSpec(expect, ` 317 | type Foo interface{ 318 | Foo() string 319 | }`), 320 | }, 321 | // Different package names should have the type 322 | // name altered. 323 | { 324 | PkgName: "b", 325 | Type: typeSpec(expect, ` 326 | type Foo interface{ 327 | Foo() string 328 | }`), 329 | }, 330 | // Different types entirely should be supported. 331 | { 332 | PkgPath: "some/path/to/foo", 333 | PkgName: "baz", 334 | Type: typeSpec(expect, ` 335 | type Baz interface { 336 | Baz() Baz 337 | } 338 | `), 339 | }, 340 | } 341 | fooDeps := []types.Dependency{ 342 | // We shouldn't see duplicate dependencies from 343 | // previous types, either. 344 | { 345 | PkgPath: "some/path/to/foo", 346 | PkgName: "baz", 347 | Type: typeSpec(expect, ` 348 | type Baz interface { 349 | Baz() Baz 350 | } 351 | `), 352 | }, 353 | } 354 | 355 | mockFinder := newMockTypeFinder() 356 | mockFinder.ExportedTypesOutput.Types <- typs 357 | mockFinder.DependenciesOutput.Dependencies <- barDeps 358 | mockFinder.DependenciesOutput.Dependencies <- fooDeps 359 | m, err := mocks.Generate(mockFinder) 360 | expect(err).To.Be.Nil().Else.FailNow() 361 | 362 | buf := bytes.Buffer{} 363 | m.Output("foo", "test/withoutimports", 100, &buf) 364 | 365 | // TODO: For some reason, functions are coming out without 366 | // whitespace between them. We need to figure that out. 
367 | expected, err := format.Source([]byte(` 368 | // This file was generated by github.com/nelsam/hel. Do not 369 | // edit this code by hand unless you *really* know what you're 370 | // doing. Expect any changes made manually to be overwritten 371 | // the next time hel regenerates this file. 372 | 373 | package foo 374 | 375 | type mockBar struct { 376 | FooCalled chan bool 377 | FooInput struct { 378 | Foo chan string 379 | } 380 | FooOutput struct { 381 | Ret0 chan Foo 382 | Ret1 chan b.Foo 383 | } 384 | } 385 | 386 | func newMockBar() *mockBar { 387 | m := &mockBar{} 388 | m.FooCalled = make(chan bool, 100) 389 | m.FooInput.Foo = make(chan string, 100) 390 | m.FooOutput.Ret0 = make(chan Foo, 100) 391 | m.FooOutput.Ret1 = make(chan b.Foo, 100) 392 | return m 393 | } 394 | func (m *mockBar) Foo(foo string) (Foo, b.Foo) { 395 | m.FooCalled <- true 396 | m.FooInput.Foo <- foo 397 | return <-m.FooOutput.Ret0, <-m.FooOutput.Ret1 398 | } 399 | 400 | type mockFoo struct { 401 | FooCalled chan bool 402 | FooOutput struct { 403 | Ret0 chan string 404 | } 405 | } 406 | 407 | func newMockFoo() *mockFoo { 408 | m := &mockFoo{} 409 | m.FooCalled = make(chan bool, 100) 410 | m.FooOutput.Ret0 = make(chan string, 100) 411 | return m 412 | } 413 | func (m *mockFoo) Foo() string { 414 | m.FooCalled <- true 415 | return <-m.FooOutput.Ret0 416 | } 417 | 418 | type mockBFoo struct { 419 | FooCalled chan bool 420 | FooOutput struct { 421 | Ret0 chan string 422 | } 423 | } 424 | 425 | func newMockBFoo() *mockBFoo { 426 | m := &mockBFoo{} 427 | m.FooCalled = make(chan bool, 100) 428 | m.FooOutput.Ret0 = make(chan string, 100) 429 | return m 430 | } 431 | func (m *mockBFoo) Foo() string { 432 | m.FooCalled <- true 433 | return <-m.FooOutput.Ret0 434 | } 435 | 436 | type mockBaz struct { 437 | BazCalled chan bool 438 | BazOutput struct { 439 | Ret0 chan baz.Baz 440 | } 441 | } 442 | 443 | func newMockBaz() *mockBaz { 444 | m := &mockBaz{} 445 | m.BazCalled = make(chan bool, 100) 446 
| m.BazOutput.Ret0 = make(chan baz.Baz, 100) 447 | return m 448 | } 449 | func (m *mockBaz) Baz() baz.Baz { 450 | m.BazCalled <- true 451 | return <-m.BazOutput.Ret0 452 | } 453 | `)) 454 | expect(err).To.Be.Nil().Else.FailNow() 455 | expect(buf.String()).To.Equal(string(expected)) 456 | } 457 | 458 | func TestOutputWithPackageInputs(t *testing.T) { 459 | expect := expect.New(t) 460 | 461 | types := []*ast.TypeSpec{ 462 | typeSpec(expect, ` 463 | type Foo interface { 464 | Bar() int 465 | }`), 466 | typeSpec(expect, ` 467 | type Bar interface { 468 | Foo(foo string) Foo 469 | Baz() 470 | Bacon(func(Eggs) Eggs) func(Eggs) Eggs 471 | }`), 472 | } 473 | 474 | mockFinder := newMockTypeFinder() 475 | close(mockFinder.DependenciesOutput.Dependencies) 476 | mockFinder.ExportedTypesOutput.Types <- types 477 | m, err := mocks.Generate(mockFinder) 478 | expect(err).To.Be.Nil().Else.FailNow() 479 | 480 | buf := bytes.Buffer{} 481 | m.Output("foo", "test/withimports", 100, &buf) 482 | 483 | expected, err := format.Source([]byte(` 484 | // This file was generated by github.com/nelsam/hel. Do not 485 | // edit this code by hand unless you *really* know what you're 486 | // doing. Expect any changes made manually to be overwritten 487 | // the next time hel regenerates this file. 
488 | 489 | package foo 490 | 491 | import ( 492 | thisIsFmt "fmt" 493 | "strconv" 494 | ) 495 | 496 | type mockFoo struct { 497 | BarCalled chan bool 498 | BarOutput struct { 499 | Ret0 chan int 500 | } 501 | } 502 | 503 | func newMockFoo() *mockFoo { 504 | m := &mockFoo{} 505 | m.BarCalled = make(chan bool, 100) 506 | m.BarOutput.Ret0 = make(chan int, 100) 507 | return m 508 | } 509 | func (m *mockFoo) Bar() int { 510 | m.BarCalled <- true 511 | return <-m.BarOutput.Ret0 512 | } 513 | 514 | type mockBar struct { 515 | FooCalled chan bool 516 | FooInput struct { 517 | Foo chan string 518 | } 519 | FooOutput struct { 520 | Ret0 chan Foo 521 | } 522 | BazCalled chan bool 523 | BaconCalled chan bool 524 | BaconInput struct { 525 | Arg0 chan func(Eggs) Eggs 526 | } 527 | BaconOutput struct { 528 | Ret0 chan func(Eggs) Eggs 529 | } 530 | } 531 | 532 | func newMockBar() *mockBar { 533 | m := &mockBar{} 534 | m.FooCalled = make(chan bool, 100) 535 | m.FooInput.Foo = make(chan string, 100) 536 | m.FooOutput.Ret0 = make(chan Foo, 100) 537 | m.BazCalled = make(chan bool, 100) 538 | m.BaconCalled = make(chan bool, 100) 539 | m.BaconInput.Arg0 = make(chan func(Eggs) Eggs, 100) 540 | m.BaconOutput.Ret0 = make(chan func(Eggs) Eggs, 100) 541 | return m 542 | } 543 | func (m *mockBar) Foo(foo string) Foo { 544 | m.FooCalled <- true 545 | m.FooInput.Foo <- foo 546 | return <-m.FooOutput.Ret0 547 | } 548 | func (m *mockBar) Baz() { 549 | m.BazCalled <- true 550 | } 551 | func (m *mockBar) Bacon(arg0 func(Eggs) Eggs) func(Eggs) Eggs { 552 | m.BaconCalled <- true 553 | m.BaconInput.Arg0 <- arg0 554 | return <-m.BaconOutput.Ret0 555 | } 556 | `)) 557 | expect(err).To.Be.Nil().Else.FailNow() 558 | expect(buf.String()).To.Equal(string(expected)) 559 | } 560 | 561 | func TestOutput_ReceiverNameInArgs(t *testing.T) { 562 | expect := expect.New(t) 563 | 564 | types := []*ast.TypeSpec{ 565 | typeSpec(expect, ` 566 | type Foo interface { 567 | Foo(m string) 568 | }`), 569 | } 570 | 571 | 
mockFinder := newMockTypeFinder() 572 | close(mockFinder.DependenciesOutput.Dependencies) 573 | mockFinder.ExportedTypesOutput.Types <- types 574 | m, err := mocks.Generate(mockFinder) 575 | expect(err).To.Be.Nil().Else.FailNow() 576 | 577 | buf := bytes.Buffer{} 578 | m.Output("foo", "test/withoutimports", 100, &buf) 579 | 580 | expected, err := format.Source([]byte(` 581 | // This file was generated by github.com/nelsam/hel. Do not 582 | // edit this code by hand unless you *really* know what you're 583 | // doing. Expect any changes made manually to be overwritten 584 | // the next time hel regenerates this file. 585 | 586 | package foo 587 | 588 | type mockFoo struct { 589 | FooCalled chan bool 590 | FooInput struct { 591 | M chan string 592 | } 593 | } 594 | 595 | func newMockFoo() *mockFoo { 596 | m := &mockFoo{} 597 | m.FooCalled = make(chan bool, 100) 598 | m.FooInput.M = make(chan string, 100) 599 | return m 600 | } 601 | func (m *mockFoo) Foo(m_ string) { 602 | m.FooCalled <- true 603 | m.FooInput.M <- m_ 604 | } 605 | `)) 606 | expect(err).To.Be.Nil().Else.FailNow() 607 | expect(buf.String()).To.Equal(string(expected)) 608 | } 609 | 610 | func mockFor(expect func(interface{}) *expect.Expect, spec *ast.TypeSpec) mocks.Mock { 611 | m, err := mocks.For(spec) 612 | expect(err).To.Be.Nil() 613 | return m 614 | } 615 | -------------------------------------------------------------------------------- /mocks/sugar.go: -------------------------------------------------------------------------------- 1 | // This is free and unencumbered software released into the public 2 | // domain. For more information, see or the 3 | // accompanying UNLICENSE file. 
4 | 5 | package mocks 6 | 7 | import "go/ast" 8 | 9 | func selectors(receiver string, fields ...string) *ast.SelectorExpr { 10 | if len(fields) == 0 { 11 | return nil 12 | } 13 | selector := &ast.SelectorExpr{ 14 | X: &ast.Ident{Name: receiver}, 15 | Sel: &ast.Ident{Name: fields[0]}, 16 | } 17 | for _, field := range fields[1:] { 18 | selector = &ast.SelectorExpr{ 19 | X: selector, 20 | Sel: &ast.Ident{Name: field}, 21 | } 22 | } 23 | return selector 24 | } 25 | -------------------------------------------------------------------------------- /mocks/test/withimports/with_imports.go: -------------------------------------------------------------------------------- 1 | package withimports 2 | 3 | import ( 4 | thisIsFmt "fmt" 5 | "strconv" 6 | ) 7 | 8 | func toMakeThisCompile() { 9 | thisIsFmt.Fprint(nil, strconv.Quote("lemons")) 10 | } 11 | -------------------------------------------------------------------------------- /mocks/test/withoutimports/without_imports.go: -------------------------------------------------------------------------------- 1 | package withoutimports 2 | -------------------------------------------------------------------------------- /packages/packages.go: -------------------------------------------------------------------------------- 1 | // This is free and unencumbered software released into the public 2 | // domain. For more information, see or the 3 | // accompanying UNLICENSE file. 4 | 5 | package packages 6 | 7 | import ( 8 | "fmt" 9 | "go/build" 10 | "os" 11 | "path/filepath" 12 | "strings" 13 | 14 | "golang.org/x/tools/go/packages" 15 | ) 16 | 17 | var ( 18 | cwd string 19 | gopathEnv = os.Getenv("GOPATH") 20 | gopath = strings.Split(gopathEnv, string(os.PathListSeparator)) 21 | ) 22 | 23 | func init() { 24 | var err error 25 | cwd, err = os.Getwd() 26 | if err != nil { 27 | panic(err) 28 | } 29 | } 30 | 31 | // Dir represents a directory containing go files. 
32 | type Dir struct { 33 | pkg *packages.Package 34 | fsPath string 35 | } 36 | 37 | // Load looks for directories matching the passed in package patterns 38 | // and returns Dir values for each directory that can be successfully 39 | // imported and is found to match one of the patterns. 40 | func Load(pkgPatterns ...string) []Dir { 41 | return load(cwd, pkgPatterns...) 42 | } 43 | 44 | func load(fromDir string, pkgPatterns ...string) (dirs []Dir) { 45 | pkgs, err := packages.Load(&packages.Config{ 46 | Mode: packages.NeedName | packages.NeedFiles | packages.NeedImports | packages.NeedDeps | packages.NeedSyntax, 47 | }, pkgPatterns...) 48 | if err != nil { 49 | panic(err) 50 | } 51 | for _, pkg := range pkgs { 52 | fsPath := "" 53 | if len(pkg.GoFiles) > 0 { 54 | fsPath = filepath.Dir(pkg.GoFiles[0]) 55 | } 56 | dirs = append(dirs, Dir{pkg: pkg, fsPath: fsPath}) 57 | } 58 | return dirs 59 | } 60 | 61 | // Path returns the file path to d. 62 | func (d Dir) Path() string { 63 | return d.fsPath 64 | } 65 | 66 | // Package returns the *packages.Package for d 67 | func (d Dir) Package() *packages.Package { 68 | return d.pkg 69 | } 70 | 71 | // Import imports path from srcDir, then loads the ast for that package. 72 | // It ensures that the returned ast is for the package that would be 73 | // imported by an import clause. 
74 | func (d Dir) Import(path string) (*packages.Package, error) { 75 | p, ok := nestedImport(d.pkg, path) 76 | if !ok { 77 | return nil, fmt.Errorf("Could not find import %s in package %s", path, d.Path()) 78 | } 79 | return p, nil 80 | } 81 | 82 | func nestedImport(pkg *packages.Package, path string) (*packages.Package, bool) { 83 | if p, ok := pkg.Imports[path]; ok { 84 | return p, true 85 | } 86 | for _, p := range pkg.Imports { 87 | if subp, ok := nestedImport(p, path); ok { 88 | return subp, true 89 | } 90 | } 91 | return nil, false 92 | } 93 | 94 | func parsePatterns(fromDir string, pkgPatterns ...string) (packages []string) { 95 | for _, pkgPattern := range pkgPatterns { 96 | if !strings.HasSuffix(pkgPattern, "...") { 97 | packages = append(packages, pkgPattern) 98 | continue 99 | } 100 | parent := strings.TrimSuffix(pkgPattern, "...") 101 | parentPkg, err := build.Import(parent, fromDir, build.AllowBinary) 102 | if err != nil { 103 | panic(err) 104 | } 105 | filepath.Walk(parentPkg.Dir, func(path string, info os.FileInfo, err error) error { 106 | if !info.IsDir() { 107 | return nil 108 | } 109 | path = strings.Replace(path, parentPkg.Dir, parent, 1) 110 | if _, err := build.Import(path, fromDir, build.AllowBinary); err != nil { 111 | // This directory doesn't appear to be a go package 112 | return nil 113 | } 114 | packages = append(packages, path) 115 | return nil 116 | }) 117 | } 118 | return 119 | } 120 | -------------------------------------------------------------------------------- /packages/packages_test.go: -------------------------------------------------------------------------------- 1 | // This is free and unencumbered software released into the public 2 | // domain. For more information, see or the 3 | // accompanying UNLICENSE file. 
4 | 5 | package packages_test 6 | 7 | import ( 8 | "os" 9 | "path/filepath" 10 | "testing" 11 | 12 | "github.com/nelsam/hel/packages" 13 | "github.com/poy/onpar" 14 | "github.com/poy/onpar/expect" 15 | "github.com/poy/onpar/matchers" 16 | ) 17 | 18 | type expectation = expect.Expectation 19 | 20 | var ( 21 | not = matchers.Not 22 | beNil = matchers.BeNil 23 | equal = matchers.Equal 24 | haveLen = matchers.HaveLen 25 | haveOccurred = matchers.HaveOccurred 26 | ) 27 | 28 | func TestLoad(t *testing.T) { 29 | o := onpar.New() 30 | defer o.Run(t) 31 | 32 | o.BeforeEach(func(t *testing.T) expectation { 33 | return expect.New(t) 34 | }) 35 | 36 | o.Spec("All", func(expect expectation) { 37 | wd, err := os.Getwd() 38 | expect(err).To(not(haveOccurred())) 39 | 40 | dirs := packages.Load(".") 41 | expect(dirs).To(haveLen(1)) 42 | expect(dirs[0].Path()).To(equal(filepath.Join(wd))) 43 | expect(dirs[0].Package().Name).To(equal("packages")) 44 | 45 | dirs = packages.Load("github.com/nelsam/hel/mocks") 46 | expect(dirs).To(haveLen(1)) 47 | expect(dirs[0].Path()).To(equal(filepath.Join(filepath.Dir(wd), "mocks"))) 48 | 49 | dirs = packages.Load("github.com/nelsam/hel/...") 50 | expect(dirs).To(haveLen(7)) 51 | 52 | dirs = packages.Load("github.com/nelsam/hel") 53 | expect(dirs).To(haveLen(1)) 54 | 55 | _, err = dirs[0].Import("golang.org/x/tools/go/packages") 56 | expect(err).To(not(haveOccurred())) 57 | 58 | dir := dirs[0] 59 | 60 | pkg, err := dir.Import("path/filepath") 61 | expect(err).To(not(haveOccurred())) 62 | expect(pkg).To(not(beNil())) 63 | expect(pkg.Name).To(equal("filepath")) 64 | 65 | pkg, err = dir.Import("github.com/nelsam/hel/packages") 66 | expect(err).To(not(haveOccurred())) 67 | expect(pkg).To(not(beNil())) 68 | expect(pkg.Name).To(equal("packages")) 69 | 70 | _, err = dir.Import("../..") 71 | expect(err).To(haveOccurred()) 72 | }) 73 | } 74 | -------------------------------------------------------------------------------- /pers/consistent.go: 
-------------------------------------------------------------------------------- 1 | // This is free and unencumbered software released into the public 2 | // domain. For more information, see or the 3 | // accompanying UNLICENSE file. 4 | 5 | package pers 6 | 7 | import ( 8 | "errors" 9 | "fmt" 10 | "reflect" 11 | ) 12 | 13 | // ConsistentlyReturn will continue adding a given value to the channel 14 | // until the returned done function is called. You may pass in either 15 | // a channel (in which case you should pass in a single argument) or a 16 | // struct full of channels (in which case you should pass in arguments 17 | // in the order the fields appear in the struct). 18 | // 19 | // After the returned function is called, you will still need to drain 20 | // any remaining calls from the channel(s) before it will start blocking 21 | // again. 22 | func ConsistentlyReturn(mock interface{}, args ...interface{}) (func(), error) { 23 | cases, err := selectCases(mock, args...) 24 | if err != nil { 25 | return nil, err 26 | } 27 | done := make(chan struct{}) 28 | exited := make(chan struct{}) 29 | go consistentlyReturn(cases, done, exited, args...) 
30 | return func() { 31 | close(done) 32 | <-exited 33 | }, nil 34 | } 35 | 36 | func selectCases(mock interface{}, args ...interface{}) ([]reflect.SelectCase, error) { 37 | v := reflect.ValueOf(mock) 38 | switch v.Kind() { 39 | case reflect.Chan: 40 | if len(args) != 1 { 41 | return nil, fmt.Errorf("expected 1 argument for %#v; got %d", mock, len(args)) 42 | } 43 | arg := args[0] 44 | argV := reflect.ValueOf(arg) 45 | if arg == nil { 46 | argV = reflect.Zero(v.Type().Elem()) 47 | } 48 | return []reflect.SelectCase{{Dir: reflect.SelectSend, Chan: v, Send: argV}}, nil 49 | case reflect.Struct: 50 | if v.NumField() == 0 { 51 | return nil, errors.New("cannot consistently return on unsupported type struct{}") 52 | } 53 | if len(args) != v.NumField() { 54 | argString := "argument" 55 | if v.NumField() != 1 { 56 | argString = "arguments" 57 | } 58 | return nil, fmt.Errorf("expected %d %s for %#v; got %d", v.NumField(), argString, mock, len(args)) 59 | } 60 | var cases []reflect.SelectCase 61 | for i := 0; i < v.NumField(); i++ { 62 | c, err := selectCases(v.Field(i).Interface(), args[i]) 63 | if err != nil { 64 | return nil, err 65 | } 66 | cases = append(cases, c...) 67 | } 68 | return cases, nil 69 | default: 70 | return nil, fmt.Errorf("cannot consistently return on unsupported type %T", mock) 71 | } 72 | } 73 | 74 | func consistentlyReturn(cases []reflect.SelectCase, done, exited chan struct{}, args ...interface{}) { 75 | defer close(exited) 76 | doneIdx := len(cases) 77 | cases = append(cases, reflect.SelectCase{Dir: reflect.SelectRecv, Chan: reflect.ValueOf(done)}) 78 | for { 79 | chosen, _, _ := reflect.Select(cases) 80 | if chosen == doneIdx { 81 | return 82 | } 83 | } 84 | } 85 | -------------------------------------------------------------------------------- /pers/consistent_test.go: -------------------------------------------------------------------------------- 1 | // This is free and unencumbered software released into the public 2 | // domain. 
For more information, see or the 3 | // accompanying UNLICENSE file. 4 | 5 | package pers_test 6 | 7 | import ( 8 | "fmt" 9 | "reflect" 10 | "testing" 11 | "time" 12 | 13 | "github.com/nelsam/hel/pers" 14 | "github.com/poy/onpar" 15 | "github.com/poy/onpar/expect" 16 | ) 17 | 18 | func TestConsistentlyReturn(t *testing.T) { 19 | o := onpar.New() 20 | defer o.Run(t) 21 | 22 | o.BeforeEach(func(t *testing.T) expectation { 23 | return expect.New(t) 24 | }) 25 | 26 | o.Spec("it errors if an unsupported type is passed in", func(expect expectation) { 27 | var f struct { 28 | Foo int 29 | } 30 | done, err := pers.ConsistentlyReturn(f, 1) 31 | expect(err).To(haveOccurred()) 32 | expect(err.Error()).To(containSubstring("unsupported type")) 33 | expect(done).To(beNil()) 34 | 35 | var e struct{} 36 | done, err = pers.ConsistentlyReturn(e) 37 | expect(err).To(haveOccurred()) 38 | expect(err.Error()).To(containSubstring("unsupported type")) 39 | expect(done).To(beNil()) 40 | }) 41 | 42 | o.Spec("it errors if there aren't enough arguments", func(expect expectation) { 43 | c := make(chan int) 44 | done, err := pers.ConsistentlyReturn(c) 45 | expect(err).To(haveOccurred()) 46 | expect(err.Error()).To(containSubstring("expected 1 argument")) 47 | expect(done).To(beNil()) 48 | 49 | var f struct { 50 | Foo chan int 51 | Bar chan string 52 | } 53 | done, err = pers.ConsistentlyReturn(f, 1) 54 | expect(err).To(haveOccurred()) 55 | expect(err.Error()).To(containSubstring("expected 2 arguments")) 56 | expect(done).To(beNil()) 57 | }) 58 | 59 | o.Spec("it errors if there are too many arguments", func(expect expectation) { 60 | c := make(chan int) 61 | done, err := pers.ConsistentlyReturn(c, 2, "foo") 62 | expect(err).To(haveOccurred()) 63 | expect(err.Error()).To(containSubstring("expected 1 argument")) 64 | expect(done).To(beNil()) 65 | 66 | var f struct { 67 | Foo chan int 68 | Bar chan string 69 | } 70 | done, err = pers.ConsistentlyReturn(f, 1, "foo", true) 71 | 
expect(err).To(haveOccurred()) 72 | expect(err.Error()).To(containSubstring("expected 2 arguments")) 73 | expect(done).To(beNil()) 74 | }) 75 | 76 | o.Spec("it handles nil values correctly", func(expect expectation) { 77 | c := make(chan error) 78 | done, err := pers.ConsistentlyReturn(c, nil) 79 | expect(err).To(not(haveOccurred())) 80 | defer done() 81 | for i := 0; i < 1000; i++ { 82 | expect(<-c).To(equal(nil)) 83 | } 84 | }) 85 | 86 | o.Spec("it consistently returns on a channel", func(expect expectation) { 87 | c := make(chan int) 88 | done, err := pers.ConsistentlyReturn(c, 1) 89 | expect(err).To(not(haveOccurred())) 90 | defer done() 91 | for i := 0; i < 1000; i++ { 92 | expect(<-c).To(equal(1)) 93 | } 94 | }) 95 | 96 | o.Spec("it consistently returns on a struct of channels", func(expect expectation) { 97 | type fooReturns struct { 98 | Foo chan string 99 | Bar chan bool 100 | } 101 | v := fooReturns{make(chan string), make(chan bool)} 102 | done, err := pers.ConsistentlyReturn(v, "foo", true) 103 | expect(err).To(not(haveOccurred())) 104 | defer done() 105 | for i := 0; i < 1000; i++ { 106 | expect(<-v.Foo).To(equal("foo")) 107 | expect(<-v.Bar).To(equal(true)) 108 | } 109 | }) 110 | 111 | o.Spec("it stops returning after done is called", func(expect expectation) { 112 | c := make(chan string) 113 | done, err := pers.ConsistentlyReturn(c, "foo") 114 | expect(err).To(not(haveOccurred())) 115 | done() 116 | expect(c).To(not(receiveMatcher{timeout: 100 * time.Millisecond})) 117 | }) 118 | } 119 | 120 | type receiveMatcher struct { 121 | timeout time.Duration 122 | } 123 | 124 | func (r receiveMatcher) Match(actual interface{}) (interface{}, error) { 125 | cases := []reflect.SelectCase{ 126 | {Dir: reflect.SelectRecv, Chan: reflect.ValueOf(actual)}, 127 | {Dir: reflect.SelectRecv, Chan: reflect.ValueOf(time.After(r.timeout))}, 128 | } 129 | i, v, _ := reflect.Select(cases) 130 | if i == 1 { 131 | return actual, fmt.Errorf("timed out after %s waiting for %#v 
to receive", r.timeout, actual) 132 | } 133 | return v.Interface(), nil 134 | } 135 | -------------------------------------------------------------------------------- /pers/havemethodexecuted.go: -------------------------------------------------------------------------------- 1 | // This is free and unencumbered software released into the public 2 | // domain. For more information, see or the 3 | // accompanying UNLICENSE file. 4 | 5 | package pers 6 | 7 | import ( 8 | "fmt" 9 | "reflect" 10 | "strings" 11 | "time" 12 | 13 | "github.com/poy/onpar/diff" 14 | "github.com/poy/onpar/expect" 15 | "github.com/poy/onpar/matchers" 16 | ) 17 | 18 | // Matcher is any type that can match values. Some code in this package supports 19 | // matching against child matchers, for example: 20 | // HaveBeenExecuted("Foo", WithArgs(matchers.HaveLen(12))) 21 | type Matcher interface { 22 | Match(actual interface{}) (interface{}, error) 23 | } 24 | 25 | type any int 26 | 27 | // Any is a special value to tell pers to allow any value at the position used. 28 | // For example, you can assert only on the second argument with: 29 | // HaveMethodExecuted("Foo", WithArgs(Any, 22)) 30 | const Any any = -1 31 | 32 | // HaveMethodExecutedOption is an option function for the HaveMethodExecutedMatcher. 33 | type HaveMethodExecutedOption func(HaveMethodExecutedMatcher) HaveMethodExecutedMatcher 34 | 35 | // Within returns a HaveMethodExecutedOption which sets the HaveMethodExecutedMatcher 36 | // to be executed within a given timeframe. 37 | func Within(d time.Duration) HaveMethodExecutedOption { 38 | return func(m HaveMethodExecutedMatcher) HaveMethodExecutedMatcher { 39 | m.within = d 40 | return m 41 | } 42 | } 43 | 44 | // WithArgs returns a HaveMethodExecutedOption which sets the HaveMethodExecutedMatcher 45 | // to only pass if the latest execution of the method called it with the passed in 46 | // arguments. 
47 | func WithArgs(args ...interface{}) HaveMethodExecutedOption { 48 | return func(m HaveMethodExecutedMatcher) HaveMethodExecutedMatcher { 49 | m.args = args 50 | return m 51 | } 52 | } 53 | 54 | // StoreArgs returns a HaveMethodExecutedOption which stores the arguments passed to 55 | // the method in the addresses provided. 56 | // 57 | // StoreArgs will panic if the values provided are not pointers or cannot store data 58 | // of the same type as the method arguments. 59 | func StoreArgs(targets ...interface{}) HaveMethodExecutedOption { 60 | return func(m HaveMethodExecutedMatcher) HaveMethodExecutedMatcher { 61 | m.saveTo = targets 62 | return m 63 | } 64 | } 65 | 66 | // HaveMethodExecutedMatcher is a matcher to ensure that a method on a mock was 67 | // executed. 68 | type HaveMethodExecutedMatcher struct { 69 | MethodName string 70 | within time.Duration 71 | args []interface{} 72 | saveTo []interface{} 73 | 74 | differ matchers.Differ 75 | } 76 | 77 | // HaveMethodExecuted returns a matcher that asserts that the method referenced 78 | // by name was executed. Options can modify the behavior of the matcher. 79 | func HaveMethodExecuted(name string, opts ...HaveMethodExecutedOption) *HaveMethodExecutedMatcher { 80 | m := HaveMethodExecutedMatcher{MethodName: name, differ: diff.New()} 81 | for _, opt := range opts { 82 | m = opt(m) 83 | } 84 | return &m 85 | } 86 | 87 | // UseDiffer sets m to use d when showing a diff between actual and expected values. 88 | func (m *HaveMethodExecutedMatcher) UseDiffer(d matchers.Differ) { 89 | m.differ = d 90 | } 91 | 92 | // Match checks the mock value v to see if it has a method matching m.MethodName 93 | // which has been called. 
94 | func (m HaveMethodExecutedMatcher) Match(v interface{}) (interface{}, error) { 95 | mv := reflect.ValueOf(v) 96 | method, exists := mv.Type().MethodByName(m.MethodName) 97 | if !exists { 98 | return v, fmt.Errorf("pers: could not find method '%s' on type %T", m.MethodName, v) 99 | } 100 | if mv.Kind() == reflect.Ptr { 101 | mv = mv.Elem() 102 | } 103 | calledField := mv.FieldByName(m.MethodName + "Called") 104 | cases := []reflect.SelectCase{ 105 | {Dir: reflect.SelectRecv, Chan: calledField}, 106 | } 107 | switch m.within { 108 | case 0: 109 | cases = append(cases, reflect.SelectCase{Dir: reflect.SelectDefault}) 110 | default: 111 | cases = append(cases, reflect.SelectCase{Dir: reflect.SelectRecv, Chan: reflect.ValueOf(time.After(m.within))}) 112 | } 113 | 114 | chosen, _, _ := reflect.Select(cases) 115 | if chosen == 1 { 116 | return v, fmt.Errorf("pers: expected method %s to have been called, but it was not", m.MethodName) 117 | } 118 | inputField := mv.FieldByName(m.MethodName + "Input") 119 | if !inputField.IsValid() { 120 | return v, nil 121 | } 122 | 123 | var calledWith []interface{} 124 | for i := 0; i < inputField.NumField(); i++ { 125 | fv, ok := inputField.Field(i).Recv() 126 | if !ok { 127 | return v, fmt.Errorf("pers: field %s is closed; cannot perform matches against this mock", inputField.Type().Field(i).Name) 128 | } 129 | calledWith = append(calledWith, fv.Interface()) 130 | 131 | if m.saveTo != nil { 132 | reflect.ValueOf(m.saveTo[i]).Elem().Set(fv) 133 | } 134 | } 135 | if len(m.args) == 0 { 136 | return v, nil 137 | } 138 | 139 | args := append([]interface{}(nil), m.args...) 
140 | if method.Type.IsVariadic() { 141 | lastTypeArg := method.Type.NumIn() - 1 142 | lastArg := lastTypeArg - 1 // lastTypeArg is including the receiver as an argument 143 | variadic := reflect.MakeSlice(method.Type.In(lastTypeArg), 0, 0) 144 | for i := lastArg; i < len(m.args); i++ { 145 | variadic = reflect.Append(variadic, reflect.ValueOf(m.args[i])) 146 | } 147 | args = append(args[:lastArg], variadic.Interface()) 148 | } 149 | if len(args) != len(calledWith) { 150 | return v, fmt.Errorf("pers: expected %d arguments, but got %d", len(calledWith), len(args)) 151 | } 152 | matched, diff := m.sliceDiff(reflect.ValueOf(calledWith), reflect.ValueOf(args)) 153 | if matched { 154 | return v, nil 155 | } 156 | const msg = "pers: %s was called with incorrect arguments: %s" 157 | return v, fmt.Errorf(msg, m.MethodName, diff) 158 | } 159 | 160 | func (m HaveMethodExecutedMatcher) sliceDiff(actual, expected reflect.Value) (bool, string) { 161 | if actual.Len() != expected.Len() { 162 | return false, m.differ.Diff(fmt.Sprintf("length %d", actual.Len()), fmt.Sprintf("length %d", expected.Len())) 163 | } 164 | var diffs []string 165 | matched := true 166 | for i := 0; i < actual.Len(); i++ { 167 | match, diff := m.valueDiff(actual.Index(i), expected.Index(i)) 168 | matched = matched && match 169 | diffs = append(diffs, diff) 170 | } 171 | return matched, fmt.Sprintf("[ %s ]", strings.Join(diffs, ", ")) 172 | } 173 | 174 | func (m HaveMethodExecutedMatcher) mapDiff(actual, expected reflect.Value) (bool, string) { 175 | matched := true 176 | var diffs []string 177 | for _, k := range expected.MapKeys() { 178 | eV := expected.MapIndex(k) 179 | aV := actual.MapIndex(k) 180 | if !aV.IsValid() { 181 | matched = false 182 | diffs = append(diffs, m.differ.Diff("missing key: %v", k.Interface())) 183 | continue 184 | } 185 | match, diff := m.valueDiff(aV, eV) 186 | matched = matched && match 187 | diffs = append(diffs, fmt.Sprintf(formatFor(k)+": %s", k.Interface(), diff)) 188 | } 
189 | return matched, fmt.Sprintf("{ %s }", strings.Join(diffs, ", ")) 190 | } 191 |
// valueDiff reports whether the argument a mock actually received matches
// the value the caller expected, along with the string used to describe
// that argument in HaveMethodExecuted's failure messages.
//
// Interface wrappers are peeled off both values first. Two nil (or
// invalid) values always match. Otherwise the expected value decides how
// the comparison is done: the Any sentinel accepts anything, a Matcher is
// asked to match the actual value (after being handed m.differ when it is
// an expect.DiffMatcher), and anything else must have the same
// reflect.Kind and then either match element-wise (slices, arrays and
// maps, via sliceDiff and mapDiff) or be reflect.DeepEqual.
func (m HaveMethodExecutedMatcher) valueDiff(actual, expected reflect.Value) (bool, string) {
	for actual.Kind() == reflect.Interface {
		actual = actual.Elem()
	}
	for expected.Kind() == reflect.Interface {
		expected = expected.Elem()
	}
	actualNil := !actual.IsValid() || isNil(actual)
	expectedNil := !expected.IsValid() || isNil(expected)
	if actualNil && expectedNil {
		return true, ""
	}
	if expectedNil {
		return false, m.differ.Diff(actual.Interface(), nil)
	}

	// A nil interface argument (a nil error, say) unwraps to the zero
	// reflect.Value, and calling Interface on that panics. Represent it as
	// an untyped nil so that Any, Matchers and diffs still handle it
	// instead of crashing the matcher.
	var actualVal interface{}
	if actual.IsValid() {
		actualVal = actual.Interface()
	}

	format := formatFor(actualVal)
	actualStr := fmt.Sprintf(format, actualVal)
	switch src := expected.Interface().(type) {
	case any:
		// NOTE(review): `any` is assumed to be pers' own type behind the
		// Any sentinel (declared outside this view). If it were the
		// predeclared alias for interface{}, this case would match every
		// expected value and the cases below would be unreachable.
		return true, actualStr
	case Matcher:
		if dm, ok := src.(expect.DiffMatcher); ok {
			dm.UseDiffer(m.differ)
		}
		if _, err := src.Match(actualVal); err != nil {
			return false, err.Error()
		}
		return true, actualStr
	default:
		if actual.Kind() != expected.Kind() {
			return false, m.differ.Diff(actualVal, expected.Interface())
		}
		switch actual.Kind() {
		case reflect.Slice, reflect.Array:
			return m.sliceDiff(actual, expected)
		case reflect.Map:
			return m.mapDiff(actual, expected)
		default:
			a, e := actualVal, expected.Interface()
			if !reflect.DeepEqual(a, e) {
				return false, fmt.Sprintf(format, m.differ.Diff(a, e))
			}
			return true, actualStr
		}
	}
}

// isNil reports whether v holds a nil value. Kinds that cannot be nil
// report false instead of panicking the way reflect.Value.IsNil would.
func isNil(v reflect.Value) bool {
	switch v.Kind() {
	case reflect.Chan, reflect.Func, reflect.Interface, reflect.Map, reflect.Ptr, reflect.Slice:
		return v.IsNil()
	default:
		return false
	}
}

// formatFor returns the format string that should be used for
// the passed in actual type.
func formatFor(actual interface{}) string {
	switch actual.(type) {
	case string:
		// Quote strings so that empty or whitespace-only values are
		// visible in failure messages.
		return `"%v"`
	default:
		return `%v`
	}
}
261 | -------------------------------------------------------------------------------- /pers/havemethodexecuted_test.go: -------------------------------------------------------------------------------- 1 | // This is free and unencumbered software released into the public 2 | // domain. For more information, see or the 3 | // accompanying UNLICENSE file. 4 | 5 | package pers_test 6 | 7 | import ( 8 | "errors" 9 | "fmt" 10 | "testing" 11 | "time" 12 | 13 | "github.com/nelsam/hel/pers" 14 | "github.com/poy/onpar" 15 | "github.com/poy/onpar/expect" 16 | "github.com/poy/onpar/matchers" 17 | ) 18 | 19 | type fakeMock struct { 20 | FooCalled chan struct{} 21 | FooInput struct { 22 | Arg0 chan int 23 | Arg1 chan string 24 | } 25 | FooOutput struct { 26 | Err chan error 27 | } 28 | BarCalled chan struct{} 29 | } 30 | 31 | func newFakeMock() *fakeMock { 32 | m := &fakeMock{} 33 | m.FooCalled = make(chan struct{}, 100) 34 | m.FooInput.Arg0 = make(chan int, 100) 35 | m.FooInput.Arg1 = make(chan string, 100) 36 | m.FooOutput.Err = make(chan error, 100) 37 | m.BarCalled = make(chan struct{}, 100) 38 | return m 39 | } 40 | 41 | func (m *fakeMock) Foo(arg0 int, arg1 string) error { 42 | m.FooCalled <- struct{}{} 43 | m.FooInput.Arg0 <- arg0 44 | m.FooInput.Arg1 <- arg1 45 | return <-m.FooOutput.Err 46 | } 47 | 48 | func (m *fakeMock) Bar() { 49 | m.BarCalled <- struct{}{} 50 | } 51 | 52 | type fakeVariadicMock struct { 53 | FooCalled chan struct{} 54 | FooInput struct { 55 | Args chan []string 56 | } 57 | } 58 | 59 | func newFakeVariadicMock() *fakeVariadicMock { 60 | m := &fakeVariadicMock{} 61 | m.FooCalled = make(chan struct{}, 100) 62 | m.FooInput.Args = make(chan []string, 100) 63 | return m 64 | } 65 | 66 | func (m *fakeVariadicMock) Foo(args ...string) { 67 | m.FooCalled <- struct{}{} 68 |
m.FooInput.Args <- args 69 | } 70 | 71 | type fakeSliceMapMock struct { 72 | FooCalled chan struct{} 73 | FooInput struct { 74 | Arg0 chan []interface{} 75 | Arg1 chan map[string]interface{} 76 | } 77 | } 78 | 79 | func newFakeSliceMapMock() *fakeSliceMapMock { 80 | m := &fakeSliceMapMock{} 81 | m.FooCalled = make(chan struct{}, 100) 82 | m.FooInput.Arg0 = make(chan []interface{}, 100) 83 | m.FooInput.Arg1 = make(chan map[string]interface{}, 100) 84 | return m 85 | } 86 | 87 | func (m *fakeSliceMapMock) Foo(arg0 []interface{}, arg1 map[string]interface{}) { 88 | m.FooCalled <- struct{}{} 89 | m.FooInput.Arg0 <- arg0 90 | m.FooInput.Arg1 <- arg1 91 | } 92 | 93 | func TestHaveMethodExecuted(t *testing.T) { 94 | o := onpar.New() 95 | defer o.Run(t) 96 | 97 | o.BeforeEach(func(t *testing.T) (*testing.T, expectation) { 98 | return t, expect.New(t) 99 | }) 100 | 101 | o.Spec("it fails gracefully if the requested method isn't found", func(t *testing.T, expect expectation) { 102 | fm := newFakeMock() 103 | 104 | m := pers.HaveMethodExecuted("Invalid") 105 | _, err := m.Match(fm) 106 | expect(err).To(haveOccurred()) 107 | expect(err.Error()).To(equal("pers: could not find method 'Invalid' on type *pers_test.fakeMock")) 108 | }) 109 | 110 | o.Spec("it drains a value off of each relevant channel", func(t *testing.T, expect expectation) { 111 | fm := newFakeMock() 112 | fm.FooCalled <- struct{}{} 113 | fm.FooInput.Arg0 <- 0 114 | fm.FooInput.Arg1 <- "foo" 115 | 116 | m := pers.HaveMethodExecuted("Foo") 117 | m.Match(fm) 118 | 119 | select { 120 | case <-fm.FooCalled: 121 | t.Fatal("Expected HaveMethodExecuted to drain from the mock's FooCalled channel") 122 | case <-fm.FooInput.Arg0: 123 | t.Fatal("Expected HaveMethodExecuted to drain from the mock's first FooInput channel") 124 | case <-fm.FooInput.Arg1: 125 | t.Fatal("Expected HaveMethodExecuted to drain from the mock's second FooInput channel") 126 | default: 127 | } 128 | }) 129 | 130 | o.Spec("it returns a success when
the method has been called", func(t *testing.T, expect expectation) { 131 | fm := newFakeMock() 132 | fm.FooOutput.Err <- nil 133 | fm.Foo(1, "foo") 134 | 135 | m := pers.HaveMethodExecuted("Foo") 136 | _, err := m.Match(fm) 137 | expect(err).To(not(haveOccurred())) 138 | }) 139 | 140 | o.Spec("it returns a failure when the method has _not_ been called", func(t *testing.T, expect expectation) { 141 | m := pers.HaveMethodExecuted("Foo") 142 | _, err := m.Match(newFakeMock()) 143 | expect(err).To(haveOccurred()) 144 | expect(err.Error()).To(equal("pers: expected method Foo to have been called, but it was not")) 145 | }) 146 | 147 | o.Spec("it can handle methods with no input or output", func(t *testing.T, expect expectation) { 148 | fm := newFakeMock() 149 | fm.Bar() 150 | 151 | m := pers.HaveMethodExecuted("Bar") 152 | _, err := m.Match(fm) 153 | 154 | expect(err).To(not(haveOccurred())) 155 | }) 156 | 157 | o.Spec("it waits for a method to be called", func(t *testing.T, expect expectation) { 158 | fm := newFakeMock() 159 | fm.FooOutput.Err <- nil 160 | 161 | m := pers.HaveMethodExecuted("Foo", pers.Within(100*time.Millisecond)) 162 | errs := make(chan error) 163 | go func() { 164 | _, err := m.Match(fm) 165 | errs <- err 166 | }() 167 | 168 | fm.Foo(10, "bar") 169 | select { 170 | case err := <-errs: 171 | expect(err).To(not(haveOccurred())) 172 | case <-time.After(100 * time.Millisecond): 173 | t.Fatal("Expected Match to wait for Foo to be called") 174 | } 175 | }) 176 | 177 | for _, test := range []struct { 178 | name string 179 | arg0 interface{} 180 | arg1 interface{} 181 | err error 182 | }{ 183 | { 184 | name: "fails due to a mismatch on the first argument", 185 | arg0: 122, 186 | arg1: "this is a value", 187 | err: errors.New(`pers: Foo was called with incorrect arguments: [ >123!=122<, "this is a value" ]`), 188 | }, 189 | { 190 | name: "fails due to a mismatch on the second argument", 191 | arg0: 123, 192 | arg1: "this is a val", 193 | err:
errors.New(`pers: Foo was called with incorrect arguments: [ 123, "this is a val>ue!=<" ]`), 194 | }, 195 | { 196 | name: "doesn't show Any in failing diff", 197 | arg0: 123, 198 | arg1: "this value", 199 | err: errors.New(`pers: Foo was called with incorrect arguments: [ 123, "this >is a != 310 | // 42 311 | // foobar 312 | } 313 | -------------------------------------------------------------------------------- /pers/helpers.go: -------------------------------------------------------------------------------- 1 | // This is free and unencumbered software released into the public 2 | // domain. For more information, see or the 3 | // accompanying UNLICENSE file. 4 | 5 | // Package pers (hel/pers ... get it?) contains a bunch of helpers 6 | // for working with hel mocks. From making a mock consistently 7 | // return to matchers - they'll all be here. 8 | package pers 9 | -------------------------------------------------------------------------------- /pers/localized_test.go: -------------------------------------------------------------------------------- 1 | // This is free and unencumbered software released into the public 2 | // domain. For more information, see or the 3 | // accompanying UNLICENSE file. 4 | 5 | package pers_test 6 | 7 | import ( 8 | "github.com/poy/onpar/expect" 9 | "github.com/poy/onpar/matchers" 10 | ) 11 | 12 | // The types and variables in this file are mimicking dot imports, 13 | // without all of the disadvantages of dot imports.
// NOTE(review): this span covers the tail of pers/localized_test.go (the
// dot-import stand-ins: the expectation alias plus equal, not, haveOccurred,
// beNil, containSubstring, chain, receive and receiveWait bound to onpar
// matchers), all of pers/return.go, pers/return_test.go, types/doc.go,
// and the header of types/helheim_test.go.
// Return: the send cases come from selectCases, which is defined outside
// this view. Presumably the "unsupported type" and "expected N
// argument(s)" errors asserted in TestReturn originate there (confirm).
// Each case is then sent with c.Chan.Send(c.Send). That is a blocking send
// (as with the <- operator), so TestReturn calls Return from a goroutine
// whenever the target channels are unbuffered.
// types/doc.go carries the `//go:generate hel` directive; presumably that
// is what produces types/helheim_test.go (whose header says it was
// generated by hel).
14 | 15 | type expectation = expect.Expectation 16 | 17 | var ( 18 | equal = matchers.Equal 19 | not = matchers.Not 20 | haveOccurred = matchers.HaveOccurred 21 | beNil = matchers.BeNil 22 | containSubstring = matchers.ContainSubstring 23 | chain = matchers.Chain 24 | receive = matchers.Receive 25 | receiveWait = matchers.ReceiveWait 26 | ) 27 | -------------------------------------------------------------------------------- /pers/return.go: -------------------------------------------------------------------------------- 1 | // This is free and unencumbered software released into the public 2 | // domain. For more information, see or the 3 | // accompanying UNLICENSE file. 4 | 5 | package pers 6 | 7 | // Return will add a given value to the channel or struct of channels. 8 | // This isn't very useful with a single value, so it's intended more 9 | // to support structs full of channels, such as the ones that hel 10 | // generates for return values in its mocks. 11 | func Return(mock interface{}, args ...interface{}) error { 12 | cases, err := selectCases(mock, args...) 13 | if err != nil { 14 | return err 15 | } 16 | for _, c := range cases { 17 | c.Chan.Send(c.Send) 18 | } 19 | return nil 20 | } 21 | -------------------------------------------------------------------------------- /pers/return_test.go: -------------------------------------------------------------------------------- 1 | // This is free and unencumbered software released into the public 2 | // domain. For more information, see or the 3 | // accompanying UNLICENSE file.
4 | 5 | package pers_test 6 | 7 | import ( 8 | "testing" 9 | "time" 10 | 11 | "github.com/nelsam/hel/pers" 12 | "github.com/poy/onpar" 13 | "github.com/poy/onpar/expect" 14 | ) 15 | 16 | func TestReturn(t *testing.T) { 17 | o := onpar.New() 18 | defer o.Run(t) 19 | 20 | o.BeforeEach(func(t *testing.T) expectation { 21 | return expect.New(t) 22 | }) 23 | 24 | o.Spec("it errors if an unexpected type is passed in", func(expect expectation) { 25 | var f struct { 26 | Foo int 27 | } 28 | err := pers.Return(f, 1) 29 | expect(err).To(haveOccurred()) 30 | expect(err.Error()).To(containSubstring("unsupported type")) 31 | 32 | var e struct{} 33 | err = pers.Return(e) 34 | expect(err).To(haveOccurred()) 35 | expect(err.Error()).To(containSubstring("unsupported type")) 36 | }) 37 | 38 | o.Spec("it errors if there aren't enough arguments", func(expect expectation) { 39 | c := make(chan int) 40 | err := pers.Return(c) 41 | expect(err).To(haveOccurred()) 42 | expect(err.Error()).To(containSubstring("expected 1 argument")) 43 | 44 | var f struct { 45 | Foo chan int 46 | Bar chan string 47 | } 48 | err = pers.Return(f, 1) 49 | expect(err).To(haveOccurred()) 50 | expect(err.Error()).To(containSubstring("expected 2 arguments")) 51 | }) 52 | 53 | o.Spec("it errors if there are too many arguments", func(expect expectation) { 54 | c := make(chan int) 55 | err := pers.Return(c, 2, "foo") 56 | expect(err).To(haveOccurred()) 57 | expect(err.Error()).To(containSubstring("expected 1 argument")) 58 | 59 | var f struct { 60 | Foo chan int 61 | Bar chan string 62 | } 63 | err = pers.Return(f, 1, "foo", true) 64 | expect(err).To(haveOccurred()) 65 | expect(err.Error()).To(containSubstring("expected 2 arguments")) 66 | }) 67 | 68 | wait := receiveWait(100 * time.Millisecond) 69 | 70 | o.Spec("it handles nil values correctly", func(expect expectation) { 71 | c := make(chan error) 72 | errs := make(chan error) 73 | go func() { 74 | errs <- pers.Return(c, nil) 75 | }() 76 |
expect(c).To(chain(receive(wait), equal(nil))) 77 | expect(errs).To(chain(receive(wait), not(haveOccurred()))) 78 | }) 79 | 80 | o.Spec("it returns on a channel", func(expect expectation) { 81 | c := make(chan int) 82 | errs := make(chan error) 83 | go func() { 84 | errs <- pers.Return(c, 1) 85 | }() 86 | expect(c).To(chain(receive(wait), equal(1))) 87 | expect(errs).To(chain(receive(wait), not(haveOccurred()))) 88 | }) 89 | 90 | o.Spec("it returns on a struct of channels", func(expect expectation) { 91 | type fooReturns struct { 92 | Foo chan string 93 | Bar chan bool 94 | } 95 | v := fooReturns{make(chan string), make(chan bool)} 96 | errs := make(chan error) 97 | go func() { 98 | errs <- pers.Return(v, "foo", true) 99 | }() 100 | expect(v.Foo).To(chain(receive(wait), equal("foo"))) 101 | expect(v.Bar).To(chain(receive(wait), equal(true))) 102 | expect(errs).To(chain(receive(wait), not(haveOccurred()))) 103 | }) 104 | 105 | } 106 | -------------------------------------------------------------------------------- /types/doc.go: -------------------------------------------------------------------------------- 1 | // This is free and unencumbered software released into the public 2 | // domain. For more information, see or the 3 | // accompanying UNLICENSE file. 4 | 5 | // Package types contains logic for parsing type definitions from ast 6 | // packages and filtering those types. 7 | package types 8 | 9 | //go:generate hel 10 | -------------------------------------------------------------------------------- /types/helheim_test.go: -------------------------------------------------------------------------------- 1 | // This file was generated by github.com/nelsam/hel. Do not 2 | // edit this code by hand unless you *really* know what you're 3 | // doing. Expect any changes made manually to be overwritten 4 | // the next time hel regenerates this file.
// NOTE(review): the mockGoDir code below is hel-generated; its header says
// manual changes will be overwritten, so fix problems in hel or in the
// GoDir interface rather than here.
// How the mock behaves:
//   - Each method first records the call on a buffered (capacity 100)
//     *Called channel, and records any arguments on *Input channels.
//   - It then blocks receiving from its *Output channels, so a test must
//     queue outputs before the code under test calls the method
//     (TestTypes does this with pers.ConsistentlyReturn).
//   - Import receives Pkg before Err, because the two receives in its
//     return statement are evaluated left to right.
5 | 6 | package types_test 7 | 8 | import ( 9 | "golang.org/x/tools/go/packages" 10 | ) 11 | 12 | type mockGoDir struct { 13 | PathCalled chan bool 14 | PathOutput struct { 15 | Path chan string 16 | } 17 | PackageCalled chan bool 18 | PackageOutput struct { 19 | Pkg chan *packages.Package 20 | } 21 | ImportCalled chan bool 22 | ImportInput struct { 23 | Path chan string 24 | } 25 | ImportOutput struct { 26 | Pkg chan *packages.Package 27 | Err chan error 28 | } 29 | } 30 | 31 | func newMockGoDir() *mockGoDir { 32 | m := &mockGoDir{} 33 | m.PathCalled = make(chan bool, 100) 34 | m.PathOutput.Path = make(chan string, 100) 35 | m.PackageCalled = make(chan bool, 100) 36 | m.PackageOutput.Pkg = make(chan *packages.Package, 100) 37 | m.ImportCalled = make(chan bool, 100) 38 | m.ImportInput.Path = make(chan string, 100) 39 | m.ImportOutput.Pkg = make(chan *packages.Package, 100) 40 | m.ImportOutput.Err = make(chan error, 100) 41 | return m 42 | } 43 | func (m *mockGoDir) Path() (path string) { 44 | m.PathCalled <- true 45 | return <-m.PathOutput.Path 46 | } 47 | func (m *mockGoDir) Package() (pkg *packages.Package) { 48 | m.PackageCalled <- true 49 | return <-m.PackageOutput.Pkg 50 | } 51 | func (m *mockGoDir) Import(path string) (pkg *packages.Package, err error) { 52 | m.ImportCalled <- true 53 | m.ImportInput.Path <- path 54 | return <-m.ImportOutput.Pkg, <-m.ImportOutput.Err 55 | } 56 | -------------------------------------------------------------------------------- /types/helpers_test.go: -------------------------------------------------------------------------------- 1 | // This is free and unencumbered software released into the public 2 | // domain. For more information, see or the 3 | // accompanying UNLICENSE file.
// NOTE(review): this span holds types/helpers_test.go and then the start of
// types/types.go.
// helpers_test.go: parse prepends packagePrefix ("package foo\n\n") so that
// test snippets can omit the package clause, and it asserts through expect
// that parsing succeeded.
// types.go declarations:
//   - errorMethod is the synthetic `Error() string` field that
//     findAnonMethods returns for an embedded builtin error. It is a single
//     shared *ast.Field, so it must never be mutated.
//   - GoDir is the interface that Load consumes.
//   - Dependency and Dir are the result types.
//   - Dir.dependencies is keyed by *ast.InterfaceType pointer, so
//     Dir.Dependencies lookups must pass the exact node found in
//     ExportedTypes.
4 | 5 | package types_test 6 | 7 | import ( 8 | "go/ast" 9 | "go/parser" 10 | "go/token" 11 | ) 12 | 13 | const packagePrefix = "package foo\n\n" 14 | 15 | func parse(expect expectation, code string) *ast.File { 16 | f, err := parser.ParseFile(token.NewFileSet(), "", packagePrefix+code, 0) 17 | expect(err).To(not(haveOccurred())) 18 | return f 19 | } 20 | -------------------------------------------------------------------------------- /types/types.go: -------------------------------------------------------------------------------- 1 | // This is free and unencumbered software released into the public 2 | // domain. For more information, see or the 3 | // accompanying UNLICENSE file. 4 | 5 | package types 6 | 7 | import ( 8 | "fmt" 9 | "go/ast" 10 | "log" 11 | "regexp" 12 | "strings" 13 | "unicode" 14 | 15 | "golang.org/x/tools/go/packages" 16 | ) 17 | 18 | var ( 19 | // errorMethod is the type of the Error method on error types. 20 | // It's defined here for any interface types that embed error. 21 | errorMethod = &ast.Field{ 22 | Names: []*ast.Ident{{Name: "Error"}}, 23 | Type: &ast.FuncType{ 24 | Params: &ast.FieldList{}, 25 | Results: &ast.FieldList{ 26 | List: []*ast.Field{{Type: &ast.Ident{Name: "string"}}}, 27 | }, 28 | }, 29 | } 30 | ) 31 | 32 | // A GoDir is a type that represents a directory of Go files. 33 | type GoDir interface { 34 | Path() (path string) 35 | Package() (pkg *packages.Package) 36 | Import(path string) (pkg *packages.Package, err error) 37 | } 38 | 39 | // A Dependency is a struct containing a package and a dependent 40 | // type spec. 41 | type Dependency struct { 42 | Type *ast.TypeSpec 43 | PkgName string 44 | PkgPath string 45 | } 46 | 47 | // A Dir is a type that represents a directory containing Go 48 | // packages. 49 | type Dir struct { 50 | dir string 51 | pkg string 52 | types []*ast.TypeSpec 53 | dependencies map[*ast.InterfaceType][]Dependency 54 | } 55 | 56 | // Dir returns the directory path that d represents.
57 | func (d Dir) Dir() string { 58 | return d.dir 59 | } 60 | 61 | // Len returns the number of types that will be returned by 62 | // d.ExportedTypes(). 63 | func (d Dir) Len() int { 64 | return len(d.types) 65 | } 66 | 67 | // Package returns the name of d's importable package. 68 | func (d Dir) Package() string { 69 | return d.pkg 70 | } 71 | 72 | // ExportedTypes returns all *ast.TypeSpecs found by d. Interface 73 | // types with anonymous interface types will be flattened, for ease of 74 | // mocking by other logic. 75 | func (d Dir) ExportedTypes() []*ast.TypeSpec { 76 | return d.types 77 | } 78 | 79 | // Dependencies returns all interface types that typ depends on for 80 | // method parameters or results. 81 | func (d Dir) Dependencies(typ *ast.InterfaceType) []Dependency { 82 | return d.dependencies[typ] 83 | } 84 | 85 | func (d Dir) addPkg(pkg *packages.Package, dir GoDir) Dir { 86 | newTypes, depMap := loadPkgTypeSpecs(pkg, dir) 87 | if d.pkg == "" { 88 | d.pkg = pkg.Name 89 | } 90 | for inter, deps := range depMap { 91 | d.dependencies[inter] = append(d.dependencies[inter], deps...) 92 | } 93 | d.types = append(d.types, newTypes...) 94 | return d 95 | } 96 | 97 | // Filter filters d's types, removing all types that don't match any 98 | // of the passed in matchers. 99 | func (d Dir) Filter(matchers ...*regexp.Regexp) Dir { 100 | oldTypes := d.ExportedTypes() 101 | d.types = make([]*ast.TypeSpec, 0, d.Len()) 102 | for _, typ := range oldTypes { 103 | for _, matcher := range matchers { 104 | if !matcher.MatchString(typ.Name.String()) { 105 | continue 106 | } 107 | d.types = append(d.types, typ) 108 | break 109 | } 110 | } 111 | return d 112 | } 113 | 114 | // Dirs is a slice of Dir values, to provide sugar for running some 115 | // methods against multiple Dir values. 116 | type Dirs []Dir 117 | 118 | // Load loads a Dirs value for goDirs. 
//
// Each GoDir produces exactly one Dir, in the order given.
func Load(goDirs ...GoDir) Dirs {
	typeDirs := make(Dirs, 0, len(goDirs))
	for _, dir := range goDirs {
		// Ask dir for its package once and reuse the result, so the Dir's
		// name and its type specs always come from the same package value.
		pkg := dir.Package()
		d := Dir{
			pkg:          pkg.Name,
			dir:          dir.Path(),
			dependencies: make(map[*ast.InterfaceType][]Dependency),
		}
		d = d.addPkg(pkg, dir)
		typeDirs = append(typeDirs, d)
	}
	return typeDirs
}

// Filter calls Dir.Filter for each Dir in d and returns only the Dirs that
// still contain at least one type. With no patterns, d is returned
// unchanged.
//
// Each pattern is a regular expression that must match a whole type name.
// The pattern is wrapped in a non-capturing group before being anchored.
// Without the group, a pattern with alternation such as "Foo|Bar" would
// become "^Foo|Bar$", which matches anything starting with Foo or ending
// with Bar. An invalid pattern panics, via regexp.MustCompile.
func (d Dirs) Filter(patterns ...string) (dirs Dirs) {
	if len(patterns) == 0 {
		return d
	}
	matchers := make([]*regexp.Regexp, 0, len(patterns))
	for _, pattern := range patterns {
		matchers = append(matchers, regexp.MustCompile("^(?:"+pattern+")$"))
	}
	for _, dir := range d {
		dir = dir.Filter(matchers...)
		if dir.Len() > 0 {
			dirs = append(dirs, dir)
		}
	}
	return dirs
}

// dependencies returns the interface type specs that typ's method
// parameters and results refer to, reporting each spec once. Local specs
// are looked up in available; imported ones are resolved through
// withImports and dir.
//
// It assumes typ has already been flattened, i.e. every method's Type is an
// *ast.FuncType. The order of the result is unspecified, because it comes
// from ranging over a map.
func dependencies(typ *ast.InterfaceType, available []*ast.TypeSpec, withImports []*ast.ImportSpec, dir GoDir) []Dependency {
	if typ.Methods == nil {
		return nil
	}
	bySpec := make(map[*ast.TypeSpec]Dependency)
	for _, meth := range typ.Methods.List {
		f := meth.Type.(*ast.FuncType)
		addSpecs(bySpec, loadDependencies(f.Params, available, withImports, dir)...)
		addSpecs(bySpec, loadDependencies(f.Results, available, withImports, dir)...)
	}
	dependentSlice := make([]Dependency, 0, len(bySpec))
	for _, dependent := range bySpec {
		dependentSlice = append(dependentSlice, dependent)
	}
	return dependentSlice
}

// addSpecs stores each of values in set, keyed by its type spec. A later
// value for the same spec replaces an earlier one.
func addSpecs(set map[*ast.TypeSpec]Dependency, values ...Dependency) {
	for _, value := range values {
		set[value.Type] = value
	}
}

// names returns the name of each spec in specs, in order.
func names(specs []*ast.TypeSpec) (names []string) {
	for _, spec := range specs {
		names = append(names, spec.Name.String())
	}
	return names
}

// loadDependencies returns the interface type specs referenced directly by
// the types of fields, and recursively through function-typed fields.
// Matches are handled as follows:
//   - A bare identifier is matched against available, by name.
//   - A selector (pkg.Name) is matched against the exported types of the
//     import whose alias, or else whose loaded package name, equals pkg.
//     These Dependencies record that name and the import path.
//
// Imports that fail to load are logged and skipped. A nil fields returns
// nil.
func loadDependencies(fields *ast.FieldList, available []*ast.TypeSpec, withImports []*ast.ImportSpec, dir GoDir) (dependencies []Dependency) {
	if fields == nil {
		return nil
	}
	for _, field := range fields.List {
		switch src := field.Type.(type) {
		case *ast.Ident:
			for _, spec := range available {
				if spec.Name.String() == src.String() {
					if _, ok := spec.Type.(*ast.InterfaceType); ok {
						dependencies = append(dependencies, Dependency{
							Type: spec,
						})
					}
					break
				}
			}
		case *ast.SelectorExpr:
			selectorName := src.X.(*ast.Ident).String()
			for _, imp := range withImports {
				importPath := strings.Trim(imp.Path.Value, `"`)
				// Each import is loaded before its name is compared: an
				// unaliased import's name is only known from the loaded
				// package.
				pkg, err := dir.Import(importPath)
				if err != nil {
					log.Printf("Error loading dependencies: %s", err)
					continue
				}
				importName := pkg.Name
				if imp.Name != nil {
					importName = imp.Name.String()
				}
				if selectorName != importName {
					continue
				}
				d := Dir{
					dependencies: make(map[*ast.InterfaceType][]Dependency),
				}
				d = d.addPkg(pkg, dir)
				for _, typ := range d.ExportedTypes() {
					if typ.Name.String() == src.Sel.String() {
						if _, ok := typ.Type.(*ast.InterfaceType); ok {
							dependencies = append(dependencies, Dependency{
								Type:    typ,
								PkgName: importName,
								PkgPath: importPath,
							})
						}
						break
					}
				}
			}
		case *ast.FuncType:
			dependencies = append(dependencies, loadDependencies(src.Params, available, withImports, dir)...)
			dependencies = append(dependencies, loadDependencies(src.Results, available, withImports, dir)...)
		}
	}
	return dependencies
}

// loadPkgTypeSpecs collects the interface type specs declared across all
// of pkg's files, flattens their embedded interfaces, and computes each
// interface's dependencies.
//
// The ordering is done with defers. They run, in reverse registration
// order, after the loop has finished and the named results are set:
//  1. The per-file flattenAnon calls are registered last, so they run
//     first, and they see the complete specs slice from every file.
//  2. The dependency computation is registered first, so it runs last, and
//     it sees fully flattened interfaces.
func loadPkgTypeSpecs(pkg *packages.Package, dir GoDir) (specs []*ast.TypeSpec, depMap map[*ast.InterfaceType][]Dependency) {
	depMap = make(map[*ast.InterfaceType][]Dependency)
	imports := make(map[*ast.TypeSpec][]*ast.ImportSpec)
	defer func() {
		for _, spec := range specs {
			inter, ok := spec.Type.(*ast.InterfaceType)
			if !ok {
				continue
			}
			depMap[inter] = dependencies(inter, specs, imports[spec], dir)
		}
	}()
	for _, f := range pkg.Syntax {
		fileImports := f.Imports
		fileSpecs := loadFileTypeSpecs(f)
		for _, spec := range fileSpecs {
			imports[spec] = fileImports
		}

		// flattenAnon needs to be called for each file, but the
		// withSpecs parameter needs *all* specs, from *all* files.
		// So we defer the flatten call until all files are processed.
		defer func() {
			flattenAnon(fileSpecs, specs, fileImports, dir)
		}()

		specs = append(specs, fileSpecs...)
	}
	return specs, depMap
}

// loadFileTypeSpecs returns the interface type specs among the objects in
// f.Scope. The order follows map iteration over f.Scope.Objects, so it is
// unspecified.
func loadFileTypeSpecs(f *ast.File) (specs []*ast.TypeSpec) {
	for _, obj := range f.Scope.Objects {
		spec, ok := obj.Decl.(*ast.TypeSpec)
		if !ok {
			continue
		}
		if _, ok := spec.Type.(*ast.InterfaceType); !ok {
			continue
		}
		specs = append(specs, spec)
	}
	return specs
}

// flattenAnon flattens each of specs, which must all be interface type
// specs. Embedded interfaces are resolved against withSpecs and
// withImports.
func flattenAnon(specs, withSpecs []*ast.TypeSpec, withImports []*ast.ImportSpec, dir GoDir) {
	for _, spec := range specs {
		inter := spec.Type.(*ast.InterfaceType)
		flatten(inter, withSpecs, withImports, dir)
	}
}

// flatten replaces inter's method list with one that contains only
// explicit methods. Embedded local interfaces (identifiers) and embedded
// imported interfaces (selector expressions) are expanded into their
// methods; any other kind of embedded element is dropped.
func flatten(inter *ast.InterfaceType, withSpecs []*ast.TypeSpec, withImports []*ast.ImportSpec, dir GoDir) {
	if inter.Methods == nil {
		return
	}
	methods := make([]*ast.Field, 0, len(inter.Methods.List))
	for _, method := range inter.Methods.List {
		switch src := method.Type.(type) {
		case *ast.FuncType:
			methods = append(methods, method)
		case *ast.Ident:
			methods = append(methods, findAnonMethods(src, withSpecs, withImports, dir)...)
		case *ast.SelectorExpr:
			importedTypes, _ := findImportedTypes(src.X.(*ast.Ident), withImports, dir)
			methods = append(methods, findAnonMethods(src.Sel, importedTypes, nil, dir)...)
		}
	}
	inter.Methods.List = methods
}

// findImportedTypes loads the package that name refers to among
// withImports and returns its interface type specs together with their
// dependencies. A package matches when its explicit import alias, or else
// its own loaded name, equals name. Exported identifiers in the returned
// method signatures are qualified by name.
//
// It returns nil, nil when no import matches. Imports that fail to load
// are logged and skipped.
func findImportedTypes(name *ast.Ident, withImports []*ast.ImportSpec, dir GoDir) ([]*ast.TypeSpec, map[*ast.InterfaceType][]Dependency) {
	importName := name.String()
	for _, imp := range withImports {
		path := strings.Trim(imp.Path.Value, `"`)
		pkg, err := dir.Import(path)
		if err != nil {
			log.Printf("Error loading import: %s", err)
			continue
		}
		name := pkg.Name
		if imp.Name != nil {
			name = imp.Name.String()
		}
		if name != importName {
			continue
		}
		typs, deps := loadPkgTypeSpecs(pkg, dir)
		addSelector(typs, importName)
		return typs, deps
	}
	return nil, nil
}

// addSelector qualifies the exported types used in the method signatures
// of typs (interface specs whose methods are already flattened) with
// selector.
func addSelector(typs []*ast.TypeSpec, selector string) {
	for _, typ := range typs {
		inter := typ.Type.(*ast.InterfaceType)
		for _, meth := range inter.Methods.List {
			addFuncSelectors(meth.Type.(*ast.FuncType), selector)
		}
	}
}

// addFuncSelectors qualifies method's parameter and result types in place.
func addFuncSelectors(method *ast.FuncType, selector string) {
	if method.Params != nil {
		addFieldSelectors(method.Params.List, selector)
	}
	if method.Results != nil {
		addFieldSelectors(method.Results.List, selector)
	}
}

// addFieldSelectors replaces each entry of fields with its qualified form.
func addFieldSelectors(fields []*ast.Field, selector string) {
	for idx, field := range fields {
		fields[idx] = addFieldSelector(field, selector)
	}
}

// addFieldSelector returns field with its type qualified by selector, so
// that a type declared in the imported package (X) becomes a selector
// expression (foo.X) that resolves from the importing package.
//
// The field's names, tag and comments are carried over. Dropping Names
// would collapse grouped parameters such as `a, b X` into one unnamed
// parameter, changing the method's arity. If nothing needed qualifying,
// field itself is returned.
func addFieldSelector(field *ast.Field, selector string) *ast.Field {
	qualified := qualifyType(field.Type, selector)
	if qualified == field.Type {
		return field
	}
	return &ast.Field{
		Doc:     field.Doc,
		Names:   field.Names,
		Type:    qualified,
		Tag:     field.Tag,
		Comment: field.Comment,
	}
}

// qualifyType returns typ with its exported identifiers qualified by
// selector:
//   - Identifiers become new *ast.SelectorExpr nodes.
//   - Pointers, arrays/slices, maps, channels and variadic element types
//     are rebuilt around their qualified element types.
//   - Nested function types are updated in place.
//
// Anything else, including an unchanged typ, is returned as is.
func qualifyType(typ ast.Expr, selector string) ast.Expr {
	switch src := typ.(type) {
	case *ast.Ident:
		// IndexFunc decodes UTF-8, so the first rune, not just the first
		// byte, decides whether the name is exported.
		if strings.IndexFunc(src.Name, unicode.IsUpper) != 0 {
			return src
		}
		return &ast.SelectorExpr{
			X:   &ast.Ident{Name: selector},
			Sel: src,
		}
	case *ast.StarExpr:
		if x := qualifyType(src.X, selector); x != src.X {
			return &ast.StarExpr{Star: src.Star, X: x}
		}
	case *ast.ArrayType:
		if elt := qualifyType(src.Elt, selector); elt != src.Elt {
			return &ast.ArrayType{Lbrack: src.Lbrack, Len: src.Len, Elt: elt}
		}
	case *ast.MapType:
		key, value := qualifyType(src.Key, selector), qualifyType(src.Value, selector)
		if key != src.Key || value != src.Value {
			return &ast.MapType{Map: src.Map, Key: key, Value: value}
		}
	case *ast.ChanType:
		if value := qualifyType(src.Value, selector); value != src.Value {
			return &ast.ChanType{Begin: src.Begin, Arrow: src.Arrow, Dir: src.Dir, Value: value}
		}
	case *ast.Ellipsis:
		if elt := qualifyType(src.Elt, selector); elt != src.Elt {
			return &ast.Ellipsis{Ellipsis: src.Ellipsis, Elt: elt}
		}
	case *ast.FuncType:
		addFuncSelectors(src, selector)
	}
	return typ
}

// findAnonMethods returns the flattened method list of the interface named
// ident among withSpecs. The builtin error interface is special-cased to
// errorMethod.
//
// It panics when ident names no spec in withSpecs and is not "error".
func findAnonMethods(ident *ast.Ident, withSpecs []*ast.TypeSpec, withImports []*ast.ImportSpec, dir GoDir) []*ast.Field {
	var spec *ast.TypeSpec
	for idx := range withSpecs {
		if withSpecs[idx].Name.String() == ident.Name {
			spec = withSpecs[idx]
			break
		}
	}
	if spec == nil {
		if ident.Name != "error" {
			// TODO: do something nicer with this error.
			panic(fmt.Errorf("Can't find anonymous type %s", ident.Name))
		}
		return []*ast.Field{errorMethod}
	}
	anon := spec.Type.(*ast.InterfaceType)
	flatten(anon, withSpecs, withImports, dir)
	return anon.Methods.List
}
396 | -------------------------------------------------------------------------------- /types/types_test.go: -------------------------------------------------------------------------------- 1 | // This is free and unencumbered software released into the public 2 | // domain. For more information, see or the 3 | // accompanying UNLICENSE file. 4 | 5 | package types_test 6 | 7 | import ( 8 | "go/ast" 9 | "testing" 10 | 11 | "github.com/nelsam/hel/pers" 12 | "github.com/nelsam/hel/types" 13 | "github.com/poy/onpar" 14 | "github.com/poy/onpar/expect" 15 | "github.com/poy/onpar/matchers" 16 | "golang.org/x/tools/go/packages" 17 | ) 18 | 19 | type expectation = expect.Expectation 20 | 21 | var ( 22 | equal = matchers.Equal 23 | not = matchers.Not 24 | haveOccurred = matchers.HaveOccurred 25 | haveLen = matchers.HaveLen 26 | beNil = matchers.BeNil 27 | beTrue = matchers.BeTrue 28 | ) 29 | 30 | func TestTypes(t *testing.T) { 31 | o := onpar.New() 32 | defer o.Run(t) 33 | 34 | o.BeforeEach(func(t *testing.T) (expectation, *mockGoDir) { 35 | return expect.New(t), newMockGoDir() 36 | }) 37 | 38 | o.Spec("Load_EmptyInterface", func(expect expectation, mockGoDir *mockGoDir) { 39 | pers.ConsistentlyReturn(mockGoDir.PathOutput, "/some/path") 40 | pers.ConsistentlyReturn(mockGoDir.PackageOutput, &packages.Package{ 41 | Name: "foo", 42 | Syntax: []*ast.File{ 43 | parse(expect, "type
Foo interface {}"), 44 | }, 45 | }) 46 | found := types.Load(mockGoDir) 47 | expect(found).To(haveLen(1)) 48 | expect(found[0].Len()).To(equal(1)) 49 | expect(found[0].Dir()).To(equal("/some/path")) 50 | expect(found[0].Package()).To(equal("foo")) 51 | }) 52 | 53 | o.Spec("Filter", func(expect expectation, mockGoDir *mockGoDir) { 54 | pers.ConsistentlyReturn(mockGoDir.PathOutput, "/some/path") 55 | pers.ConsistentlyReturn(mockGoDir.PackageOutput, &packages.Package{ 56 | Name: "foo", 57 | Syntax: []*ast.File{ 58 | parse(expect, ` 59 | type Foo interface {} 60 | type Bar interface {} 61 | type FooBar interface {} 62 | type BarFoo interface {} 63 | `), 64 | }, 65 | }) 66 | found := types.Load(mockGoDir) 67 | expect(found).To(haveLen(1)) 68 | expect(found[0].Len()).To(equal(4)) 69 | 70 | notFiltered := found.Filter() 71 | expect(notFiltered).To(haveLen(1)) 72 | expect(notFiltered[0].Len()).To(equal(4)) 73 | 74 | foos := found.Filter("Foo") 75 | expect(foos).To(haveLen(1)) 76 | expect(foos[0].Len()).To(equal(1)) 77 | expect(foos[0].ExportedTypes()[0].Name.String()).To(equal("Foo")) 78 | 79 | fooPrefixes := found.Filter("Foo.*") 80 | expect(fooPrefixes).To(haveLen(1)) 81 | expect(fooPrefixes[0].Len()).To(equal(2)) 82 | expectNamesToMatch(expect, fooPrefixes[0].ExportedTypes(), "Foo", "FooBar") 83 | 84 | fooPostfixes := found.Filter(".*Foo") 85 | expect(fooPostfixes).To(haveLen(1)) 86 | expect(fooPostfixes[0].Len()).To(equal(2)) 87 | expectNamesToMatch(expect, fooPostfixes[0].ExportedTypes(), "Foo", "BarFoo") 88 | 89 | fooContainers := found.Filter("Foo.*", ".*Foo") 90 | expect(fooContainers).To(haveLen(1)) 91 | expect(fooContainers[0].Len()).To(equal(3)) 92 | expectNamesToMatch(expect, fooContainers[0].ExportedTypes(), "Foo", "FooBar", "BarFoo") 93 | }) 94 | 95 | o.Spec("LocalDependencies", func(expect expectation, mockGoDir *mockGoDir) { 96 | pers.ConsistentlyReturn(mockGoDir.PathOutput, "/some/path") 97 | pers.ConsistentlyReturn(mockGoDir.PackageOutput,
&packages.Package{ 98 | Name: "bar", 99 | Syntax: []*ast.File{ 100 | parse(expect, ` 101 | 102 | type Bar interface{ 103 | Bar(Foo) Foo 104 | }`), 105 | parse(expect, ` 106 | 107 | type Foo interface { 108 | Foo() 109 | }`), 110 | }, 111 | }) 112 | 113 | found := types.Load(mockGoDir) 114 | 115 | expect(found).To(haveLen(1)) 116 | mockables := found[0].ExportedTypes() 117 | expect(mockables).To(haveLen(2)) 118 | 119 | var foo, bar *ast.TypeSpec 120 | for _, mockable := range mockables { 121 | switch mockable.Name.String() { 122 | case "Bar": 123 | bar = mockable 124 | case "Foo": 125 | foo = mockable 126 | } 127 | } 128 | if bar == nil { 129 | t.Fatal("expected to find a Bar type") 130 | } 131 | 132 | expect(found).To(haveLen(1)) 133 | dependencies := found[0].Dependencies(bar.Type.(*ast.InterfaceType)) 134 | expect(dependencies).To(haveLen(1)) 135 | expect(dependencies[0].Type).To(equal(foo)) 136 | expect(dependencies[0].PkgName).To(equal("")) 137 | expect(dependencies[0].PkgPath).To(equal("")) 138 | }) 139 | 140 | o.Spec("ImportedDependencies", func(expect expectation, mockGoDir *mockGoDir) { 141 | pers.ConsistentlyReturn(mockGoDir.PathOutput, "/some/path") 142 | pers.ConsistentlyReturn(mockGoDir.PackageOutput, &packages.Package{ 143 | Name: "bar", 144 | Syntax: []*ast.File{ 145 | parse(expect, ` 146 | 147 | import "some/path/to/foo" 148 | 149 | type Bar interface{ 150 | Bar(foo.Foo) foo.Bar 151 | }`), 152 | }, 153 | }) 154 | 155 | pkgName := "foo" 156 | pkg := &packages.Package{ 157 | Name: pkgName, 158 | Syntax: []*ast.File{ 159 | parse(expect, ` 160 | type Foo interface { 161 | Foo() 162 | } 163 | 164 | type Bar interface { 165 | Bar() 166 | }`), 167 | }, 168 | } 169 | done, err := pers.ConsistentlyReturn(mockGoDir.ImportOutput, pkg, nil) 170 | expect(err).To(not(haveOccurred())) 171 | defer done() 172 | 173 | found := types.Load(mockGoDir) 174 | expect(found).To(haveLen(1)) 175 | expect(mockGoDir.ImportCalled).To(haveLen(2)) 176 | 177 |
expect(<-mockGoDir.ImportInput.Path).To(equal("some/path/to/foo")) 178 | 179 | mockables := found[0].ExportedTypes() 180 | expect(mockables).To(haveLen(1)) 181 | if mockables[0] == nil { 182 | t.Fatal("expected mockables[0] to be non-nil") 183 | } 184 | 185 | dependencies := found[0].Dependencies(mockables[0].Type.(*ast.InterfaceType)) 186 | expect(dependencies).To(haveLen(2)) 187 | 188 | names := make(map[string]bool) 189 | for _, dependent := range dependencies { 190 | expect(dependent.PkgName).To(equal("foo")) 191 | expect(dependent.PkgPath).To(equal("some/path/to/foo")) 192 | names[dependent.Type.Name.String()] = true 193 | } 194 | expect(names).To(equal(map[string]bool{"Foo": true, "Bar": true})) 195 | }) 196 | 197 | o.Spec("AliasedImportedDependencies", func(expect expectation, mockGoDir *mockGoDir) { 198 | pers.ConsistentlyReturn(mockGoDir.PathOutput, "/some/path") 199 | pers.ConsistentlyReturn(mockGoDir.PackageOutput, &packages.Package{ 200 | Name: "bar", 201 | Syntax: []*ast.File{ 202 | parse(expect, ` 203 | 204 | import baz "some/path/to/foo" 205 | 206 | type Bar interface{ 207 | Bar(baz.Foo) baz.Bar 208 | }`), 209 | }, 210 | }) 211 | 212 | pkgName := "foo" 213 | pkg := &packages.Package{ 214 | Name: pkgName, 215 | Syntax: []*ast.File{ 216 | parse(expect, ` 217 | type Foo interface { 218 | Foo() 219 | } 220 | 221 | type Bar interface { 222 | Bar() 223 | }`), 224 | }, 225 | } 226 | done, err := pers.ConsistentlyReturn(mockGoDir.ImportOutput, pkg, nil) 227 | expect(err).To(not(haveOccurred())) 228 | defer done() 229 | 230 | found := types.Load(mockGoDir) 231 | expect(mockGoDir.ImportCalled).To(haveLen(2)) 232 | expect(<-mockGoDir.ImportInput.Path).To(equal("some/path/to/foo")) 233 | 234 | expect(found).To(haveLen(1)) 235 | mockables := found[0].ExportedTypes() 236 | expect(mockables).To(haveLen(1)) 237 | 238 | dependencies := found[0].Dependencies(mockables[0].Type.(*ast.InterfaceType)) 239 | expect(dependencies).To(haveLen(2)) 240 | 241 | names :=
make(map[string]bool) 242 | for _, dependent := range dependencies { 243 | expect(dependent.PkgName).To(equal("baz")) 244 | expect(dependent.PkgPath).To(equal("some/path/to/foo")) 245 | names[dependent.Type.Name.String()] = true 246 | } 247 | expect(names).To(equal(map[string]bool{"Foo": true, "Bar": true})) 248 | }) 249 | 250 | // TestAnonymousError is testing the only case (as of go 1.7) where 251 | // a builtin is an interface type. 252 | o.Spec("AnonymousError", func(expect expectation, mockGoDir *mockGoDir) { 253 | pers.ConsistentlyReturn(mockGoDir.PathOutput, "/some/path") 254 | pers.ConsistentlyReturn(mockGoDir.PackageOutput, &packages.Package{ 255 | Name: "foo", 256 | Syntax: []*ast.File{ 257 | parse(expect, ` 258 | type Foo interface{ 259 | error 260 | }`), 261 | }, 262 | }) 263 | found := types.Load(mockGoDir) 264 | expect(found).To(haveLen(1)) 265 | 266 | typs := found[0].ExportedTypes() 267 | expect(typs).To(haveLen(1)) 268 | 269 | spec := typs[0] 270 | expect(spec).To(not(beNil())) 271 | 272 | inter := spec.Type.(*ast.InterfaceType) 273 | expect(inter.Methods.List).To(haveLen(1)) 274 | err := inter.Methods.List[0] 275 | expect(err.Names[0].String()).To(equal("Error")) 276 | _, isFunc := err.Type.(*ast.FuncType) 277 | expect(isFunc).To(beTrue()) 278 | }) 279 | 280 | o.Spec("AnonymousLocalTypes", func(expect expectation, mockGoDir *mockGoDir) { 281 | pers.ConsistentlyReturn(mockGoDir.PathOutput, "/some/path") 282 | pers.ConsistentlyReturn(mockGoDir.PackageOutput, &packages.Package{ 283 | Name: "foo", 284 | Syntax: []*ast.File{ 285 | parse(expect, ` 286 | type Bar interface{ 287 | Foo 288 | Bar() 289 | }`), 290 | parse(expect, ` 291 | type Foo interface{ 292 | Foo() 293 | }`), 294 | }, 295 | }) 296 | found := types.Load(mockGoDir) 297 | expect(found).To(haveLen(1)) 298 | 299 | typs := found[0].ExportedTypes() 300 | expect(typs).To(haveLen(2)) 301 | 302 | spec := find(expect, typs, "Bar") 303 | expect(spec).To(not(beNil())) 304 | inter :=
spec.Type.(*ast.InterfaceType) 305 | expect(inter.Methods.List).To(haveLen(2)) 306 | foo := inter.Methods.List[0] 307 | expect(foo.Names[0].String()).To(equal("Foo")) 308 | _, isFunc := foo.Type.(*ast.FuncType) 309 | expect(isFunc).To(beTrue()) 310 | }) 311 | 312 | o.Spec("AnonymousImportedTypes", func(expect expectation, mockGoDir *mockGoDir) { 313 | pers.ConsistentlyReturn(mockGoDir.PathOutput, "/some/path") 314 | pers.ConsistentlyReturn(mockGoDir.PackageOutput, &packages.Package{ 315 | Name: "bar", 316 | Syntax: []*ast.File{ 317 | parse(expect, ` 318 | 319 | import "some/path/to/foo" 320 | 321 | type Bar interface{ 322 | foo.Foo 323 | Bar() 324 | }`), 325 | }, 326 | }) 327 | 328 | pkgName := "foo" 329 | pkg := &packages.Package{ 330 | Name: pkgName, 331 | Syntax: []*ast.File{ 332 | parse(expect, ` 333 | type Foo interface { 334 | Foo(x X) Y 335 | } 336 | 337 | type X int 338 | type Y int`), 339 | }, 340 | } 341 | done, err := pers.ConsistentlyReturn(mockGoDir.ImportOutput, pkg, nil) 342 | expect(err).To(not(haveOccurred())) 343 | defer done() 344 | 345 | found := types.Load(mockGoDir) 346 | 347 | // 3 calls: 1 for the initial import, then deps imports for X and Y 348 | expect(mockGoDir.ImportCalled).To(haveLen(3)) 349 | expect(<-mockGoDir.ImportInput.Path).To(equal("some/path/to/foo")) 350 | 351 | expect(found).To(haveLen(1)) 352 | typs := found[0].ExportedTypes() 353 | expect(typs).To(haveLen(1)) 354 | 355 | spec := typs[0] 356 | expect(spec).To(not(beNil())) 357 | inter := spec.Type.(*ast.InterfaceType) 358 | expect(inter.Methods.List).To(haveLen(2)) 359 | 360 | foo := inter.Methods.List[0] 361 | expect(foo.Names[0].String()).To(equal("Foo")) 362 | f, isFunc := foo.Type.(*ast.FuncType) 363 | expect(isFunc).To(beTrue()) 364 | expect(f.Params.List).To(haveLen(1)) 365 | expect(f.Results.List).To(haveLen(1)) 366 | expr, isSelector := f.Params.List[0].Type.(*ast.SelectorExpr) 367 | expect(isSelector).To(beTrue()) 368 |
expect(expr.X.(*ast.Ident).String()).To(equal("foo")) 369 | expect(expr.Sel.String()).To(equal("X")) 370 | expr, isSelector = f.Results.List[0].Type.(*ast.SelectorExpr) 371 | expect(isSelector).To(beTrue()) 372 | expect(expr.X.(*ast.Ident).String()).To(equal("foo")) 373 | expect(expr.Sel.String()).To(equal("Y")) 374 | }) 375 | 376 | o.Spec("AnonymousAliasedImportedTypes", func(expect expectation, mockGoDir *mockGoDir) { 377 | pers.ConsistentlyReturn(mockGoDir.PathOutput, "/some/path") 378 | pers.ConsistentlyReturn(mockGoDir.PackageOutput, &packages.Package{ 379 | Name: "bar", 380 | Syntax: []*ast.File{ 381 | parse(expect, ` 382 | 383 | import baz "some/path/to/foo" 384 | 385 | type Bar interface{ 386 | baz.Foo 387 | Bar() 388 | }`), 389 | }, 390 | }) 391 | 392 | pkgName := "foo" 393 | pkg := &packages.Package{ 394 | Name: pkgName, 395 | Syntax: []*ast.File{ 396 | parse(expect, ` 397 | type Foo interface { 398 | Foo(x X) Y 399 | } 400 | 401 | type X int 402 | type Y int`), 403 | }, 404 | } 405 | done, err := pers.ConsistentlyReturn(mockGoDir.ImportOutput, pkg, nil) 406 | expect(err).To(not(haveOccurred())) 407 | defer done() 408 | 409 | found := types.Load(mockGoDir) 410 | 411 | // 3 calls: 1 for the initial import, then deps imports for X and Y 412 | expect(mockGoDir.ImportCalled).To(haveLen(3)) 413 | expect(<-mockGoDir.ImportInput.Path).To(equal("some/path/to/foo")) 414 | 415 | expect(found).To(haveLen(1)) 416 | typs := found[0].ExportedTypes() 417 | expect(typs).To(haveLen(1)) 418 | 419 | spec := typs[0] 420 | expect(spec).To(not(beNil())) 421 | inter := spec.Type.(*ast.InterfaceType) 422 | expect(inter.Methods.List).To(haveLen(2)) 423 | 424 | foo := inter.Methods.List[0] 425 | expect(foo.Names[0].String()).To(equal("Foo")) 426 | f, isFunc := foo.Type.(*ast.FuncType) 427 | expect(isFunc).To(beTrue()) 428 | expect(f.Params.List).To(haveLen(1)) 429 | expect(f.Results.List).To(haveLen(1)) 430 | expr, isSelector := f.Params.List[0].Type.(*ast.SelectorExpr) 431 | 
expect(isSelector).To(beTrue()) 432 | expect(expr.X.(*ast.Ident).String()).To(equal("baz")) 433 | expect(expr.Sel.String()).To(equal("X")) 434 | expr, isSelector = f.Results.List[0].Type.(*ast.SelectorExpr) 435 | expect(isSelector).To(beTrue()) 436 | expect(expr.X.(*ast.Ident).String()).To(equal("baz")) 437 | expect(expr.Sel.String()).To(equal("Y")) 438 | }) 439 | 440 | o.Spec("AnonymousImportedTypes_Recursion", func(expect expectation, mockGoDir *mockGoDir) { 441 | pers.ConsistentlyReturn(mockGoDir.PathOutput, "/some/path") 442 | pers.ConsistentlyReturn(mockGoDir.PackageOutput, &packages.Package{ 443 | Name: "bar", 444 | Syntax: []*ast.File{ 445 | parse(expect, ` 446 | 447 | import "some/path/to/foo" 448 | 449 | type Bar interface{ 450 | foo.Foo 451 | Bar() 452 | }`), 453 | }, 454 | }) 455 | 456 | pkgName := "foo" 457 | pkg := &packages.Package{ 458 | Name: pkgName, 459 | Syntax: []*ast.File{ 460 | parse(expect, ` 461 | type Foo interface { 462 | Foo(func(X) Y) func(Y) X 463 | }`), 464 | }, 465 | } 466 | done, err := pers.ConsistentlyReturn(mockGoDir.ImportOutput, pkg, nil) 467 | expect(err).To(not(haveOccurred())) 468 | defer done() 469 | 470 | found := types.Load(mockGoDir) 471 | 472 | // One call for the initial import, four more for dependency checking 473 | expect(mockGoDir.ImportCalled).To(haveLen(5)) 474 | expect(<-mockGoDir.ImportInput.Path).To(equal("some/path/to/foo")) 475 | 476 | expect(found).To(haveLen(1)) 477 | typs := found[0].ExportedTypes() 478 | expect(typs).To(haveLen(1)) 479 | 480 | spec := typs[0] 481 | expect(spec).To(not(beNil())) 482 | inter := spec.Type.(*ast.InterfaceType) 483 | expect(inter.Methods.List).To(haveLen(2)) 484 | 485 | foo := inter.Methods.List[0] 486 | expect(foo.Names[0].String()).To(equal("Foo")) 487 | f, isFunc := foo.Type.(*ast.FuncType) 488 | expect(isFunc).To(beTrue()) 489 | expect(f.Params.List).To(haveLen(1)) 490 | expect(f.Results.List).To(haveLen(1)) 491 | 492 | input := f.Params.List[0] 493 | in, isFunc := 
input.Type.(*ast.FuncType) 494 | expect(isFunc).To(beTrue()) 495 | 496 | expr, isSelector := in.Params.List[0].Type.(*ast.SelectorExpr) 497 | expect(isSelector).To(beTrue()) 498 | expect(expr.X.(*ast.Ident).String()).To(equal("foo")) 499 | expect(expr.Sel.String()).To(equal("X")) 500 | expr, isSelector = in.Results.List[0].Type.(*ast.SelectorExpr) 501 | expect(isSelector).To(beTrue()) 502 | expect(expr.X.(*ast.Ident).String()).To(equal("foo")) 503 | expect(expr.Sel.String()).To(equal("Y")) 504 | // The method's result is func(Y) X: read it from f.Results (not f.Params), expecting Y as its param and X as its result. 505 | output := f.Results.List[0] 506 | out, isFunc := output.Type.(*ast.FuncType) 507 | expect(isFunc).To(beTrue()) 508 | 509 | expr, isSelector = out.Params.List[0].Type.(*ast.SelectorExpr) 510 | expect(isSelector).To(beTrue()) 511 | expect(expr.X.(*ast.Ident).String()).To(equal("foo")) 512 | expect(expr.Sel.String()).To(equal("Y")) 513 | expr, isSelector = out.Results.List[0].Type.(*ast.SelectorExpr) 514 | expect(isSelector).To(beTrue()) 515 | expect(expr.X.(*ast.Ident).String()).To(equal("foo")) 516 | expect(expr.Sel.String()).To(equal("X")) 517 | }) 518 | } 519 | // expectNamesToMatch asserts that the type names in list are exactly names, ignoring order and duplicates. 520 | func expectNamesToMatch(expect expectation, list []*ast.TypeSpec, names ...string) { 521 | listNames := make(map[string]struct{}, len(list)) 522 | for _, spec := range list { 523 | listNames[spec.Name.String()] = struct{}{} 524 | } 525 | expectedNames := make(map[string]struct{}, len(names)) 526 | for _, name := range names { 527 | expectedNames[name] = struct{}{} 528 | } 529 | expect(listNames).To(equal(expectedNames)) 530 | } 531 | // find returns the first spec in typs named name, or nil if none matches; expect is currently unused. 532 | func find(expect expectation, typs []*ast.TypeSpec, name string) *ast.TypeSpec { 533 | for _, typ := range typs { 534 | if typ.Name.String() == name { 535 | return typ 536 | } 537 | } 538 | return nil 539 | } 540 | --------------------------------------------------------------------------------