├── .travis.yml ├── README.md ├── ast_util.go ├── ast_util_test.go ├── decl.go ├── decl_test.go ├── lib └── diff │ ├── LICENSE │ ├── README.md │ ├── diff.go │ ├── diff_test.go │ └── example_test.go ├── main.go ├── main_test.go ├── put.go ├── put_test.go ├── put_tmpl.go ├── testdata ├── cases │ └── 1 │ │ ├── a │ │ ├── main.go │ │ └── main_test.go │ │ └── b │ │ ├── main.go │ │ └── main_test.go ├── m │ └── m.go ├── s │ └── s.go └── x │ ├── x.go │ ├── x_fail_test.go │ └── x_pass_test.go └── util_test.go /.travis.yml: -------------------------------------------------------------------------------- 1 | language: go 2 | go: 3 | - 1.4 4 | - tip 5 | os: 6 | - linux 7 | - osx 8 | before_install: 9 | - go get github.com/axw/gocov/gocov 10 | - go get github.com/modocache/gover 11 | - go get github.com/mattn/goveralls 12 | - go get golang.org/x/tools/cmd/cover 13 | script: 14 | - go test -bench=. -benchmem -covermode=count -coverprofile=main.coverprofile github.com/emil2k/tab 15 | - $HOME/gopath/bin/gover 16 | - $HOME/gopath/bin/goveralls -coverprofile=gover.coverprofile -service travis-ci -repotoken gLzHgh214HbvIdIV23QpAj7Gz4LxWgEoq 17 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Tab 2 | [![Travis 3 | branch](https://img.shields.io/travis/emil2k/tab.svg?style=flat)](https://travis-ci.org/emil2k/tab) 4 | [![Coverage 5 | Status](https://img.shields.io/coveralls/emil2k/tab.svg?style=flat)](https://coveralls.io/r/emil2k/tab) 6 | 7 | **WARNING: This is a work in progress, if you want to help jump in.** 8 | 9 | A tool for generating [table driven 10 | tests](https://github.com/golang/go/wiki/TableDrivenTests) in Go. 11 | 12 | ## Installation 13 | 14 | ``` 15 | go get github.com/emil2k/tab 16 | ``` 17 | 18 | ## Compatibility 19 | 20 | - Go 1.4+ 21 | - Should work on OSX and Linux, someone should test it on Windows. 
22 | 23 | ## Usage 24 | 25 | Initiate a variable that holds your table driven tests using the following 26 | naming convention where `F` is a function and `M` is a method on type `T` : 27 | 28 | ```go 29 | // ttF generates a table driven test for function F. 30 | var ttF = []struct{ 31 | ... 32 | }{ 33 | ... 34 | } 35 | 36 | // ttT_M generates a table driven test for method M of type T. 37 | var ttT_M = []struct{ 38 | ... 39 | }{ 40 | ... 41 | } 42 | ``` 43 | 44 | All the types and functions specified by `T` and `F` must be located in the same 45 | package as the variable. 46 | 47 | The `struct`s representing the test must define fields of the same type as the 48 | inputs and expected outputs of the function or method. The fields should be 49 | ordered with the inputs first and the outputs afterwards mirroring the function 50 | signature. When testing a method of type `T` the first field must be an instance 51 | of the type `T` which will be used as a receiver for the test. 52 | 53 | If the function has a variadic input it must be represented as a slice. 54 | Additionally, all the fields can be represented by a function with no parameters 55 | that returns the necessary type, e.g. `func() int` for `int`. 56 | 57 | Afterwards, add a `go generate` directive to the file for generating the tests : 58 | 59 | ```go 60 | //go:generate tab 61 | ``` 62 | 63 | To generate the tests in the package directory run : 64 | 65 | ``` 66 | go generate 67 | ``` 68 | 69 | The tool will place or update a test function underneath each table test 70 | variable, with the following naming convention : 71 | 72 | ```go 73 | // ttF generates a table driven test for function F. 74 | var ttF = []struct{ 75 | ... 76 | }{ 77 | ... 78 | } 79 | 80 | // TestTTF is an automatically generated table driven test for the function F 81 | // using the tests defined in ttF. 82 | func TestTTF(t *testing.T) { 83 | ... 84 | } 85 | 86 | // ttT_M generates a table driven test for method M of type T.
87 | var ttT_M = []struct{ 88 | ... 89 | }{ 90 | ... 91 | } 92 | 93 | // TestTTT_M is an automatically generated table driven test for the method T.M 94 | // using the tests defined in ttT_M. 95 | func TestTTT_M(t *testing.T) { 96 | ... 97 | } 98 | ``` 99 | 100 | The generated functions will test that the outputs match expectations. 101 | 102 | ## Example 103 | 104 | ```go 105 | package main 106 | 107 | //go:generate tab 108 | 109 | func DummyFunction(a, b int) (c, d, e int, f float64, err error) { 110 | // dummy function to test 111 | return 112 | } 113 | 114 | var ttDummyFunction = []struct { 115 | // inputs 116 | a, b int 117 | // outputs 118 | c, d, e int 119 | f float64 120 | err error 121 | }{ 122 | {1, 2, 5, 6, 7, 5.4, nil}, 123 | {1, 2, 5, 6, 7, 5.4, nil}, 124 | {1, 2, 5, 6, 7, 5.4, nil}, // and on and on ... 125 | } 126 | ``` 127 | 128 | After running `go generate` it adds a table test underneath : 129 | 130 | ```go 131 | package main 132 | 133 | //go:generate tab 134 | 135 | func DummyFunction(a, b int) (c, d, e int, f float64, err error) { 136 | // dummy function to test 137 | return 138 | } 139 | 140 | var ttDummyFunction = []struct { 141 | // inputs 142 | a, b int 143 | // outputs 144 | c, d, e int 145 | f float64 146 | err error 147 | }{ 148 | {1, 2, 5, 6, 7, 5.4, nil}, 149 | {1, 2, 5, 6, 7, 5.4, nil}, 150 | {1, 2, 5, 6, 7, 5.4, nil}, // and on and on ... 151 | } 152 | 153 | // TestTTDummyFunction is an automatically generated table driven test for the 154 | // function DummyFunction using the tests defined in ttDummyFunction.
155 | func TestTTDummyFunction(t *testing.T) { 156 | for i, tt := range ttDummyFunction { 157 | c, d, e, f, err := DummyFunction(tt.a, tt.b) 158 | if c != tt.c { 159 | t.Errorf("%d : c : got %v, expected %v", i, c, tt.c) 160 | } 161 | if d != tt.d { 162 | t.Errorf("%d : d : got %v, expected %v", i, d, tt.d) 163 | } 164 | if e != tt.e { 165 | t.Errorf("%d : e : got %v, expected %v", i, e, tt.e) 166 | } 167 | if f != tt.f { 168 | t.Errorf("%d : f : got %v, expected %v", i, f, tt.f) 169 | } 170 | if err != tt.err { 171 | t.Errorf("%d : err : got %v, expected %v", i, err, tt.err) 172 | } 173 | } 174 | } 175 | ``` 176 | 177 | ## TODO 178 | 179 | Improve error messages of the generated tests, can base on the output type : 180 | 181 | - Allow naming of expected values with field tags. 182 | - If outputs a struct it could test that each field of the struct matches 183 | expectations and output errors for individual fields. 184 | - If outputs a map it could test that all the keys match expectations and output 185 | errors for individual keys. Similar test for arrays/slices. 186 | - If outputs a function it could be provided with another set of table tests 187 | for the expected value which the returned function must pass for the test to 188 | pass. 189 | 190 | ## TODO 191 | 192 | Provide flags to prevent replacement of existing function and the option to 193 | place each table test into a separate test function. 194 | 195 | ## TODO 196 | 197 | By default inequality is evaluated using `!=`. Equality can also be determined 198 | by defining a custom function using the following naming convention : 199 | 200 | ```go 201 | // tt_T determines equality for all fields of type T in this package. 202 | // Returns true if a & b are equal. 203 | func tt_T(a, b T) bool { 204 | ... 205 | } 206 | 207 | // ttT_M_X determines equality for the field X, which is of type T and is an 208 | // output of the tests generated by ttT_M. 209 | // Returns true if a & b are equal. 
210 | func ttT_M_X(a, b T) bool { 211 | ... 212 | } 213 | ``` 214 | -------------------------------------------------------------------------------- /ast_util.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "errors" 5 | "fmt" 6 | "go/ast" 7 | "go/build" 8 | "go/parser" 9 | "go/token" 10 | "strconv" 11 | ) 12 | 13 | // ErrPkgNotFound returned when a package with the provided name is not found in 14 | // a directory. 15 | var ErrPkgNotFound = errors.New("package not found") 16 | 17 | // importer is used to import package given the import path, using getPkg. 18 | // Returns ast.Obj of kind pkg with the scope in the Data field and the package 19 | // itself in the Decl field, adds the object to the passed imports map for 20 | // caching. 21 | func importer(imports map[string]*ast.Object, path string) (pkg *ast.Object, err error) { 22 | if obj, ok := imports[path]; ok { 23 | return obj, nil 24 | } 25 | pkgInfo, err := build.Import(path, "", 0) 26 | if err != nil { 27 | return nil, err 28 | } 29 | oPkg, err := getPkg(pkgInfo.Dir, pkgInfo.Name) 30 | if err != nil { 31 | return nil, err 32 | } 33 | oo := ast.NewObj(ast.Pkg, pkgInfo.Name) 34 | oo.Decl = oPkg 35 | oo.Data = oPkg.Scope 36 | imports[path] = oo 37 | return oo, nil 38 | } 39 | 40 | // getPkg parses the directory and returns an AST for the specified package 41 | // name, making sure that the Scope attribute is set for the package. 42 | // Is meant to overcome the issue that parser.ParseDir does not set the Scope 43 | // of the package. 44 | // Returns an error if a package with the given name cannot be found in the 45 | // directory or the source cannot be parsed. 
46 | func getPkg(dir, pkgName string) (*ast.Package, error) { 47 | pkgs, err := parser.ParseDir(token.NewFileSet(), dir, nil, 48 | parser.AllErrors|parser.ParseComments) 49 | if err != nil { 50 | return nil, err 51 | } 52 | if oPkg, ok := pkgs[pkgName]; ok { 53 | // Must create a new Package instance, because ParseDir does not 54 | // set the Scope field. 55 | // Ignoring error because for some reason ParseDir marks regular 56 | // types such as int and string unresolved and the NewPackage 57 | // call attempts to resolve them compiling a list of "undeclared 58 | // name" errors. 59 | pkg, _ := ast.NewPackage(token.NewFileSet(), oPkg.Files, 60 | nil, nil) 61 | return pkg, nil 62 | } 63 | return nil, ErrPkgNotFound 64 | } 65 | 66 | // containsFunction checks the passed packages scope to determine if it 67 | // contains a function with the passed identifier. If so it returns the 68 | // ast.FuncDecl and true, otherwise returns nil and false. 69 | // Relies on the package's Scope to lookup identifiers will panic if it is nil. 70 | // Does not match methods. 71 | func containsFunction(pkg *ast.Package, ident string) (*ast.FuncDecl, bool) { 72 | if pkg.Scope == nil { 73 | panic(fmt.Sprintf("package %s scope is nil\n", pkg.Name)) 74 | } 75 | if obj, ok := pkg.Scope.Objects[ident]; ok && obj.Kind == ast.Fun { 76 | if fd, ok := obj.Decl.(*ast.FuncDecl); ok && fd.Recv == nil { 77 | return fd, true 78 | } 79 | } 80 | return nil, false 81 | } 82 | 83 | // containsMethod checks the passed packages to determine if it contains a 84 | // method with the passed identifier and passed type identifier. If so it 85 | // returns the ast.FuncDecl and true, otherwise returns nil and false. 86 | // Depending on the whether `pointer` is set determines whether to include 87 | // methods with the `*tIdent` receiver in the search, instead of just `tIdent` 88 | // receivers. 89 | // Does not match functions. 
90 | func containsMethod(pkg *ast.Package, mIdent, tIdent string, pointer bool) (*ast.FuncDecl, bool) { 91 | ts, ok := containsType(pkg, tIdent) 92 | if !ok { 93 | return nil, false 94 | } 95 | // Methods don't show up in the package scope, so need to walk the AST 96 | // to find them. 97 | var fd *ast.FuncDecl 98 | walk := func(p *ast.Package, mid, tid string, wpointer bool) funcVisitor { 99 | return func(n ast.Node) { 100 | if xFd, ok := n.(*ast.FuncDecl); ok && xFd.Name.Name == mid && xFd.Recv != nil { 101 | // Check the receiver name & whether pointer 102 | // matches. 103 | _, xTs, xPointer, _ := resolveExpr(p, xFd.Recv.List[0].Type) 104 | if (wpointer || (!wpointer && !xPointer)) && 105 | xTs != nil && xTs.Name.Name == tid { 106 | fd = xFd 107 | } 108 | } 109 | } 110 | } 111 | ast.Walk(walk(pkg, mIdent, tIdent, pointer), pkg) 112 | if fd != nil { 113 | return fd, true 114 | } 115 | // Check embeded fields in a struct type, method could have them as a 116 | // receiver. 117 | if st, ok := ts.Type.(*ast.StructType); ok && st.Fields != nil && 118 | st.Fields.NumFields() > 0 { 119 | for _, f := range st.Fields.List { 120 | if len(f.Names) != 0 { 121 | continue 122 | } 123 | ePkg, eTs, ePointer, _ := resolveExpr(pkg, f.Type) 124 | // Check for method in embedded type's package with it 125 | // as a receiver. 126 | if eTs != nil { 127 | ast.Walk(walk(ePkg, mIdent, eTs.Name.Name, ePointer), ePkg) 128 | if fd != nil { 129 | return fd, true 130 | } 131 | } 132 | } 133 | } 134 | return nil, false 135 | } 136 | 137 | // containsType checks the passed packages scope to determine if it contains a 138 | // type with the passed identifier. If so it returns the ast.TypeSpec and 139 | // true, otherwise returns nil and false. 140 | // Relies on the package's Scope to lookup identifiers will panic if it is nil. 
func containsType(pkg *ast.Package, ident string) (*ast.TypeSpec, bool) {
	if pkg.Scope == nil {
		panic(fmt.Sprintf("package %s scope is nil\n", pkg.Name))
	}
	obj, found := pkg.Scope.Objects[ident]
	if !found || obj.Kind != ast.Typ {
		return nil, false
	}
	spec, isSpec := obj.Decl.(*ast.TypeSpec)
	if !isSpec {
		return nil, false
	}
	return spec, true
}

// containsVar looks ident up in the package scope and, when it names a
// variable declared by an ast.ValueSpec, returns that spec and true.
// Otherwise it returns nil and false.
// The package's Scope must be set; a nil Scope causes a panic.
func containsVar(pkg *ast.Package, ident string) (*ast.ValueSpec, bool) {
	if pkg.Scope == nil {
		panic(fmt.Sprintf("package %s scope is nil\n", pkg.Name))
	}
	obj, found := pkg.Scope.Objects[ident]
	if !found || obj.Kind != ast.Var {
		return nil, false
	}
	spec, isSpec := obj.Decl.(*ast.ValueSpec)
	if !isSpec {
		return nil, false
	}
	return spec, true
}

// funcVisitor adapts a plain function to the ast.Visitor interface. Every
// node ast.Walk visits is handed to the function, and the walk always
// continues into the node's children.
type funcVisitor func(n ast.Node)

// Visit implements ast.Visitor: it calls fn with node and hands fn back so
// that the walk continues.
func (fn funcVisitor) Visit(node ast.Node) ast.Visitor {
	fn(node)
	return fn
}

// exprEqual returns true if the types of the two expressions match; it checks
// idents, function signatures, channels, and interface types.
// If expression `b` is an interface then `a` must also be an interface, and if
// expression `a` is an interface then `b` must meet its requirements.
// Expressions are resolved in their corresponding passed packages.
// If expression `a` is an ast.StarExpr then `b` must also be an ast.StarExpr
// pointing to an equivalent type, otherwise `b` can be either.
186 | func exprEqual(ap, bp *ast.Package, a, b ast.Expr) bool { 187 | if a == nil || b == nil { 188 | return a == b 189 | } 190 | // Attempt to resolve idents and selector expressions, update the 191 | // respective packages if necessary. 192 | ap, _, apoint, ao := resolveExpr(ap, a) 193 | bp, bts, bpoint, bo := resolveExpr(bp, b) 194 | // If a is `a` is a pointer expression, then `b` must also be. 195 | if apoint && !bpoint { 196 | return false 197 | } 198 | switch at := ao.(type) { 199 | case *ast.MapType: 200 | if bt, ok := bo.(*ast.MapType); ok { 201 | return exprEqual(ap, bp, at.Key, bt.Key) && 202 | exprEqual(ap, bp, at.Value, bt.Value) 203 | } 204 | case *ast.StructType: 205 | if bt, ok := bo.(*ast.StructType); ok { 206 | return fieldListEqual(ap, bp, at.Fields, bt.Fields) 207 | } 208 | case *ast.ArrayType: 209 | if bt, ok := bo.(*ast.ArrayType); ok { 210 | return exprEqual(ap, bp, at.Elt, bt.Elt) && 211 | exprEqual(ap, bp, at.Len, bt.Len) 212 | } 213 | case *ast.InterfaceType: 214 | return exprInterface(ap, bp, at, bts, false) 215 | case *ast.FuncType: 216 | if bt, ok := bo.(*ast.FuncType); ok { 217 | return fieldListEqual(ap, bp, at.Params, bt.Params) && 218 | fieldListEqual(ap, bp, at.Results, bt.Results) 219 | } 220 | case *ast.ChanType: 221 | if bt, ok := bo.(*ast.ChanType); ok { 222 | return at.Dir == bt.Dir && exprEqual(ap, bp, at.Value, bt.Value) 223 | } 224 | case *ast.Ident: 225 | // For matching all other first class types, i.e. intX, floatX, 226 | // complexX, byte, rune. 227 | if bt, ok := bo.(*ast.Ident); ok && 228 | at.Name == bt.Name { 229 | return true 230 | } 231 | default: 232 | panic(fmt.Sprintf("unhandled type %T (%v), compared to %T (%v)\n", 233 | ao, ao, bo, bo)) 234 | } 235 | return false 236 | } 237 | 238 | // exprInterface return true if the the passed type meets the requirements of 239 | // the interface within their respective packages. 
240 | // If provided type is an interface it must include the methods in the passed 241 | // interface. 242 | // If `pointer` is set it includes methods with pointer receivers. 243 | func exprInterface(ifacePkg, tsPkg *ast.Package, iface *ast.InterfaceType, ts *ast.TypeSpec, pointer bool) bool { 244 | if iface == nil || iface.Methods.NumFields() == 0 { 245 | return true 246 | } 247 | for _, m := range iface.Methods.List { 248 | mPkg, _, _, mObj := resolveExpr(ifacePkg, m.Type) 249 | switch x := mObj.(type) { 250 | case *ast.InterfaceType: 251 | // Embedded interface found, should have its methods 252 | // aswell. 253 | if ok := exprInterface(mPkg, tsPkg, x, ts, pointer); !ok { 254 | return false 255 | } 256 | case *ast.FuncType: 257 | for _, n := range m.Names { 258 | // Find the potential method declaration, 259 | // depending on whether the passed type is an 260 | // interface or another type. 261 | var md *ast.FuncType 262 | if tsi, ok := ts.Type.(*ast.InterfaceType); ok { 263 | if md, ok = ifaceContainsMethod(tsPkg, tsi, n.Name); !ok { 264 | return false 265 | } 266 | } else { 267 | if mdecl, ok := containsMethod(tsPkg, n.Name, ts.Name.Name, pointer); ok { 268 | md = mdecl.Type 269 | } 270 | } 271 | if md == nil { 272 | // Corresponding method not found 273 | return false 274 | } 275 | // Check that the method signatures match 276 | if mm := exprEqual(ifacePkg, tsPkg, x, md); !mm { 277 | return false 278 | } 279 | } 280 | default: 281 | panic(fmt.Sprintf("unhandled interface field %T\n", 282 | m.Type)) 283 | } 284 | } 285 | return true 286 | } 287 | 288 | // ifaceContainsMethod checks if the interface contains a method with the passed 289 | // name, if not returns nil and false. 290 | // Takes into account embedded interfaces. 
291 | func ifaceContainsMethod(pkg *ast.Package, iface *ast.InterfaceType, name string) (*ast.FuncType, bool) { 292 | for _, m := range iface.Methods.List { 293 | mPkg, _, _, mObj := resolveExpr(pkg, m.Type) 294 | switch x := mObj.(type) { 295 | case *ast.InterfaceType: 296 | // Embedded interface found, check if it contains the 297 | // method. 298 | if xf, ok := ifaceContainsMethod(mPkg, x, name); ok { 299 | return xf, true 300 | } 301 | case *ast.FuncType: 302 | for _, n := range m.Names { 303 | if n.Name == name { 304 | return x, true 305 | } 306 | } 307 | } 308 | } 309 | return nil, false 310 | } 311 | 312 | // resolveExpr attempts to resolve expressions such as an ident or selector 313 | // expression down to their underlying type. 314 | // Returns the type spec and the underlying object XXXType, returns in nil when 315 | // cannot resolve something or if there is no type spec for the type, i.e. int, 316 | // float32, etc. 317 | // Returns whether the type spec is a pointer type. 318 | // Returns the package where the type spec was found, may change if there is 319 | // a selector expression. 320 | func resolveExpr(pkg *ast.Package, in ast.Expr) (resolvedPackage *ast.Package, typeSpec *ast.TypeSpec, pointer bool, obj interface{}) { 321 | if x, ok := in.(*ast.StarExpr); ok { 322 | in = x.X 323 | pointer = true 324 | } 325 | switch x := in.(type) { 326 | case *ast.Ident: 327 | resolvedPackage, typeSpec, obj = resolveIdent(pkg, x) 328 | case *ast.SelectorExpr: 329 | resolvedPackage, typeSpec, obj = resolveSelectorExpr(pkg, x) 330 | default: 331 | obj = in 332 | } 333 | return 334 | } 335 | 336 | // resolveSelectorExpr attempts to resolve a selector expression to its 337 | // underlying type assuming the selector is a package identifier and that the 338 | // field ident is in its scope. Imports a package if necessary. 339 | // Returns the input and the passed package if cannot resolve. 340 | // Returns the package where the type spec was found. 
341 | func resolveSelectorExpr(pkg *ast.Package, in *ast.SelectorExpr) (*ast.Package, *ast.TypeSpec, interface{}) { 342 | if xi, ok := in.X.(*ast.Ident); ok { 343 | // Attempt to lookup the idents obj in case it is a package, 344 | // otherwise will have to try to import it. 345 | selObj, ok := objPkgLookup(xi.Obj, in.Sel.Name) 346 | if ok { 347 | return resolveObjDecl(pkg, selObj) 348 | } 349 | // Find the file the selector is in, attempt to find the 350 | // package it refers to, import it, and lookup the 351 | // object in it. 352 | if f, ok := lookupFile(pkg, in); ok { 353 | if pkgObj, ok := lookupImport(pkg, f, xi.Name); ok { 354 | if selPkg, ok := pkgObj.Decl.(*ast.Package); ok { 355 | selObj, _ = objPkgLookup(pkgObj, in.Sel.Name) 356 | return resolveObjDecl(selPkg, selObj) 357 | } 358 | } 359 | } 360 | } 361 | return pkg, nil, in 362 | } 363 | 364 | // resolveIdent attempts to resolve an ident expression into it's underlying 365 | // type. 366 | // Returns the input and the passed package if cannot resolve. 367 | // Returns the package where the type spec was found, may change if there is 368 | // a selector expression. 369 | func resolveIdent(pkg *ast.Package, in *ast.Ident) (*ast.Package, *ast.TypeSpec, interface{}) { 370 | if npkg, ts, decl := resolveObjDecl(pkg, in.Obj); decl != nil { 371 | return npkg, ts, decl 372 | } 373 | return pkg, nil, in 374 | } 375 | 376 | // resolveObjDecl resolves the declaration of an object, returning the type 377 | // spec of the object and the declaration of the underlying object, i.e. 378 | // XXXTypes. Returns a nil declaration if the object has no declaration, i.e. 379 | // int, float32, etc., and returns a nil type spec and the passed package when 380 | // the object does not refer to an ast.TypeSpec. 381 | // Returns the package where the type spec was found, may change if there is 382 | // a selector expression. 
383 | func resolveObjDecl(pkg *ast.Package, obj *ast.Object) (*ast.Package, *ast.TypeSpec, interface{}) { 384 | if obj != nil && obj.Decl != nil { 385 | if ts, ok := obj.Decl.(*ast.TypeSpec); ok { 386 | // Recurse to the underlying type, but return this type 387 | // spec and package. 388 | _, _, _, robj := resolveExpr(pkg, ts.Type) 389 | return pkg, ts, robj 390 | } 391 | return pkg, nil, obj.Decl // no type spec found 392 | } 393 | return nil, nil, nil 394 | } 395 | 396 | // lookupImport attempts to import the package refered to by the passed selector 397 | // depending on the file. If found it returns an ast.Object of the package kind 398 | // and true, otherwise nil and false. 399 | // It must import the package to determine the package name, if pkg is not nil 400 | // it uses its Imports field as a cache so it won't have to import repeatedly. 401 | func lookupImport(pkg *ast.Package, file *ast.File, sel string) (*ast.Object, bool) { 402 | var imports map[string]*ast.Object 403 | if pkg != nil { 404 | imports = pkg.Imports 405 | } 406 | for _, i := range file.Imports { 407 | if i.Name != nil && (i.Name.Name == "." || i.Name.Name == "_") { 408 | continue 409 | } 410 | ip, err := strconv.Unquote(i.Path.Value) 411 | if err != nil { 412 | continue 413 | } 414 | pkgObj, err := importer(imports, ip) 415 | if err != nil { 416 | continue 417 | } 418 | // Must match either package name or local package name. 419 | if (i.Name != nil && i.Name.Name == sel) || 420 | (i.Name == nil && pkgObj.Name == sel) { 421 | return pkgObj, true 422 | } 423 | } 424 | return nil, false 425 | } 426 | 427 | // lookupFile attempts to locate the file in the package where the AST node 428 | // is defined, by walking the passed package. 
func lookupFile(pkg *ast.Package, node ast.Node) (*ast.File, bool) {
	// A nil node must not match: the walk visits nil after every node's
	// children.
	if pkg == nil || node == nil {
		return nil, false
	}
	for _, f := range pkg.Files {
		found := false
		ast.Inspect(f, func(n ast.Node) bool {
			if n == node {
				found = true
			}
			// Stop descending as soon as the node has been seen.
			return !found
		})
		if found {
			return f, true
		}
	}
	return nil, false
}

// objPkgLookup attempts to retrieve the named object from a package scope.
// If the passed object is not a package object with a scope in its Data
// field, it returns nil and false.
func objPkgLookup(pkg *ast.Object, name string) (*ast.Object, bool) {
	if pkg == nil || pkg.Kind != ast.Pkg || pkg.Data == nil {
		return nil, false
	}
	pkgScope, ok := pkg.Data.(*ast.Scope)
	if !ok {
		return nil, false
	}
	return scopeLookup(pkgScope, name)
}

// scopeLookup looks up an object with the given name in the provided scope,
// moving out to enclosing scopes when it is not found.
// Returns nil and false when not found.
func scopeLookup(scope *ast.Scope, name string) (*ast.Object, bool) {
	for ; scope != nil; scope = scope.Outer {
		if obj := scope.Lookup(name); obj != nil {
			return obj, true
		}
	}
	return nil, false
}

// fieldListEqual returns true if the two field lists are equivalent.
// Accounts for the fact that fields can declare types on an individual basis or
// on multiple idents but the two field lists should be considered equivalent.
// For example : `a, b int` and `a int, b int` should be considered the same.
475 | func fieldListEqual(ap, bp *ast.Package, a, b *ast.FieldList) bool { 476 | if a.NumFields() != b.NumFields() { 477 | return false 478 | } 479 | bes := fieldListExpr(b) 480 | for i, ae := range fieldListExpr(a) { 481 | if !exprEqual(ap, bp, ae, bes[i]) { 482 | return false 483 | } 484 | } 485 | return true 486 | } 487 | 488 | // isStructSlice checks if the value spec is an array of structs, if so returns 489 | // the struct type and true, otherwise returns nil and false. 490 | // Only checks the first value in var declaration should not be used with 491 | // multi assigning initilizations. 492 | func isStructSlice(vs *ast.ValueSpec) (*ast.StructType, bool) { 493 | if len(vs.Values) > 0 { 494 | if cl, ok := vs.Values[0].(*ast.CompositeLit); ok { 495 | if at, ok := cl.Type.(*ast.ArrayType); ok { 496 | if st, ok := at.Elt.(*ast.StructType); ok { 497 | return st, true 498 | } 499 | } 500 | } 501 | } 502 | return nil, false 503 | } 504 | 505 | // structExpr returns a list of expressions representing the fields of a struct 506 | // type. 507 | // If mutiple idents are provided for a type in specifying the fields, i.e. 508 | // a, b int, then provides an expression for each ident. 509 | func structExpr(s *ast.StructType) []ast.Expr { 510 | oe := make([]ast.Expr, 0) 511 | if s == nil || s.Fields == nil { 512 | return oe 513 | } 514 | for _, f := range s.Fields.List { 515 | for i := 0; i < len(f.Names); i++ { 516 | oe = append(oe, f.Type) 517 | } 518 | } 519 | return oe 520 | } 521 | 522 | // funcExpr compiles a list of expressions in the signature of a function or 523 | // method, in the following order receiver, inputs, outputs. 524 | // If multiple ident are provided for a type in the signature, i.e. a, b int, 525 | // then provides an expression for each ident. 526 | func funcExpr(f *ast.FuncDecl) []ast.Expr { 527 | fs := make([]*ast.Field, 0) 528 | if f.Recv != nil { 529 | fs = append(fs, f.Recv.List...) 530 | } 531 | fs = append(fs, f.Type.Params.List...) 
532 | if f.Type.Results != nil { 533 | fs = append(fs, f.Type.Results.List...) 534 | } 535 | oe := make([]ast.Expr, 0) 536 | for _, f := range fs { 537 | if len(f.Names) == 0 { // anonymous fields 538 | oe = append(oe, f.Type) 539 | continue 540 | } 541 | for i := 0; i < len(f.Names); i++ { 542 | oe = append(oe, f.Type) 543 | } 544 | } 545 | return oe 546 | } 547 | 548 | // fieldListExpr compiles a list of expressions in a field list. 549 | // If multiple ident are provided for a type in the signature, i.e. a, b int, 550 | // then provides an expression for each ident. 551 | func fieldListExpr(f *ast.FieldList) []ast.Expr { 552 | oe := make([]ast.Expr, 0) 553 | if f == nil { 554 | return oe 555 | } 556 | for _, f := range f.List { 557 | if len(f.Names) == 0 { // anonymous field 558 | oe = append(oe, f.Type) 559 | continue 560 | } 561 | for _ = range f.Names { 562 | oe = append(oe, f.Type) 563 | } 564 | } 565 | return oe 566 | } 567 | -------------------------------------------------------------------------------- /ast_util_test.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "go/ast" 5 | "go/parser" 6 | "testing" 7 | ) 8 | 9 | // TestGetPkgNotFound tests that getPkg properly returns not found error. 10 | func TestGetPkgNotFound(t *testing.T) { 11 | _, err := getPkg("testdata/x", "doesnotexist") 12 | if err != ErrPkgNotFound { 13 | t.Error("package does not exist") 14 | } 15 | } 16 | 17 | // TestGetPkg tests that the retrieved package has the expected package name. 18 | func TestGetPkg(t *testing.T) { 19 | pkg := getTestPkg(t, "testdata/x", "x") 20 | if pkg.Name != "x" { 21 | t.Errorf("package name %s, expected testdata/x\n", pkg.Name) 22 | } 23 | if pkg.Scope == nil { 24 | t.Errorf("package scope is nil") 25 | } 26 | } 27 | 28 | // TestContainsFunction tests containsFunction to make sure that it matches 29 | // functions and excludes methods. 
30 | // Specifically tests matching of exported and non exported functions. 31 | func TestContainsFunction(t *testing.T) { 32 | pkg := getTestPkg(t, "testdata/x", "x") 33 | if fd, ok := containsFunction(pkg, "ExportedFunction"); !ok { 34 | t.Error("should contain") 35 | } else if fd.Name.Name != "ExportedFunction" { 36 | t.Error("ident does not match") 37 | } 38 | if _, ok := containsFunction(pkg, "nonExportedFunction"); !ok { 39 | t.Error("should contain") 40 | } 41 | if _, ok := containsFunction(pkg, "doesnotexist"); ok { 42 | t.Error("should not contain") 43 | } 44 | if _, ok := containsFunction(pkg, "ExportedMethod"); ok { 45 | t.Error("should not match method") 46 | } 47 | } 48 | 49 | // TestContainsMethod tests containsMethod to make sure it does not match 50 | // functions, makes sure that receiver ident matched. 51 | // Specifically tests matching of exported and nonexported methods. 52 | func TestContainsMethod(t *testing.T) { 53 | pkg := getTestPkg(t, "testdata/x", "x") 54 | if fd, ok := containsMethod(pkg, "ExportedMethod", "ExportedType", false); !ok { 55 | t.Error("should contain") 56 | } else if fd.Name.Name != "ExportedMethod" { 57 | t.Error("ident does not match") 58 | } 59 | if _, ok := containsMethod(pkg, "nonExportedMethod", "ExportedType", false); !ok { 60 | t.Error("should contain") 61 | } 62 | if _, ok := containsMethod(pkg, "ExportedMethod", "wrongtype", false); ok { 63 | t.Error("method exists but not for this type") 64 | } 65 | if _, ok := containsMethod(pkg, "doesnotexist", "ExportedType", false); ok { 66 | t.Error("type exists but not the method") 67 | } 68 | } 69 | 70 | // TestContainsType tests containsType to make sure it matches existing types, 71 | // and does not match variables or functions. 72 | // Specifically tests matching of exported and nonexported types. 
73 | func TestContainsType(t *testing.T) { 74 | pkg := getTestPkg(t, "testdata/x", "x") 75 | if ts, ok := containsType(pkg, "ExportedType"); !ok { 76 | t.Error("should contain") 77 | } else if ts.Name.Name != "ExportedType" { 78 | t.Error("ident does not match") 79 | } 80 | if _, ok := containsType(pkg, "nonExportedType"); !ok { 81 | t.Error("should contain") 82 | } 83 | if _, ok := containsType(pkg, "doesnotexist"); ok { 84 | t.Error("should not contain") 85 | } 86 | if _, ok := containsType(pkg, "A"); ok { 87 | t.Error("should not match vars") 88 | } 89 | if _, ok := containsType(pkg, "ExportedFunction"); ok { 90 | t.Error("should not match functions") 91 | } 92 | } 93 | 94 | // TestContainsVar tests containsVar to make sure it matches variable, and does 95 | // not match functions, methods, or types. 96 | // Specifically tests matching of exported and nonexported vars. 97 | func TestContainsVar(t *testing.T) { 98 | pkg := getTestPkg(t, "testdata/x", "x") 99 | if vs, ok := containsVar(pkg, "ExportedVar"); !ok { 100 | t.Error("should contain") 101 | } else if vs.Names[0].Name != "ExportedVar" { 102 | t.Error("ident does not match") 103 | } 104 | if _, ok := containsVar(pkg, "nonExportedVar"); !ok { 105 | t.Error("should contain") 106 | } 107 | if _, ok := containsVar(pkg, "doesnotexist"); ok { 108 | t.Error("should not contain") 109 | } 110 | if _, ok := containsVar(pkg, "ExportedFunction"); ok { 111 | t.Error("should not match function") 112 | } 113 | if _, ok := containsVar(pkg, "ExportedMethod"); ok { 114 | t.Error("should not match methods") 115 | } 116 | if _, ok := containsVar(pkg, "ExportedType"); ok { 117 | t.Error("should not match types") 118 | } 119 | } 120 | 121 | // testsStructSliceExpr are table tests for TestStrucSliceExpr. 
122 | var testsStructSliceExpr = []struct { 123 | name string // var name defining struct 124 | isStruct bool // whether expected to be struct 125 | count int // count of expression expected 126 | }{ 127 | {"StructArray", true, 3}, 128 | {"emptyStructArray", true, 4}, 129 | {"NotStructArray", false, 0}, 130 | {"notArray", false, 0}, 131 | } 132 | 133 | // TestStructSliceExpr tests isStructSlice and strucExpr, checks struct arrays, 134 | // empty struct arrays, non struct arrays, and non arrays to make sure struct 135 | // slice checking works. When a struct type can be retrieved checks that the 136 | // expression count matches with structExpr. 137 | func TestStructSliceExpr(t *testing.T) { 138 | pkg := getTestPkg(t, "testdata/s", "s") 139 | for _, tt := range testsStructSliceExpr { 140 | vs, ok := containsVar(pkg, tt.name) 141 | if !ok { 142 | t.Error(tt.name, "should contain") 143 | } 144 | if st, ok := isStructSlice(vs); ok != tt.isStruct { 145 | t.Errorf("%s is struct %t, expected %t\n", 146 | tt.name, ok, tt.isStruct) 147 | } else if x := len(structExpr(st)); x != tt.count { 148 | t.Errorf("%s expr count %d, expected %d\n", 149 | tt.name, x, tt.count) 150 | } 151 | } 152 | } 153 | 154 | // TestFuncExpr tests funcExpr checks the count of expressions returned on a 155 | // function an a method, which use a combination of anonymous and named fields 156 | // and cases where there are multiple expr per field. 
157 | func TestFuncExpr(t *testing.T) { 158 | pkg := getTestPkg(t, "testdata/x", "x") 159 | // Test a function 160 | f, ok := containsFunction(pkg, "ExportedFunction") 161 | if !ok { 162 | t.Error("does not contain") 163 | } 164 | test := func(xf *ast.FuncDecl, count int) { 165 | if x := len(funcExpr(xf)); x != count { 166 | t.Errorf("expr count does not match %d, expected %d\n", 167 | x, count) 168 | } 169 | } 170 | test(f, 3) 171 | // Test a method 172 | m, ok := containsMethod(pkg, "ExportedMethod", "ExportedType", false) 173 | if !ok { 174 | t.Error("does not contain") 175 | } 176 | test(m, 5) 177 | } 178 | 179 | // testFieldListExpr are table driven tests for fieldListExpr the expected 180 | // outputs are counts of the parameters and the results. 181 | var testsFieldListExpr = []struct { 182 | expr string 183 | param, results int 184 | }{ 185 | {"func(a int, b int) bool", 2, 1}, 186 | {"func(a, b int) bool", 2, 1}, 187 | {"func(a ...int)", 1, 0}, 188 | {"func() bool", 0, 1}, 189 | {"func() (a, b bool)", 0, 2}, 190 | {"func() func(a, b int) bool", 0, 1}, 191 | {"func(a <-chan int) func(a, b int) bool", 1, 1}, 192 | } 193 | 194 | // TestFieldListExpr test fieldListExpr by simply testing of the counts of expr 195 | // returned match expectations for several functions. 196 | // Tests cases with anonymous returns, variadic inputs, and channels. 
197 | func TestFieldListExpr(t *testing.T) { 198 | for _, tt := range testsFieldListExpr { 199 | f, err := parser.ParseExpr(tt.expr) 200 | if err != nil { 201 | t.Error(tt.expr+" could not be parsed : ", err.Error()) 202 | } 203 | if ft, ok := f.(*ast.FuncType); !ok { 204 | t.Errorf(tt.expr+" tt.expr is %T not FuncType\n", f) 205 | } else { 206 | if p := fieldListExpr(ft.Params); len(p) != tt.param { 207 | t.Error(tt.expr + " params count doesn't match") 208 | } 209 | if r := fieldListExpr(ft.Results); len(r) != tt.results { 210 | t.Error(tt.expr + " results count doesn't match") 211 | } 212 | } 213 | } 214 | } 215 | -------------------------------------------------------------------------------- /decl.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "fmt" 5 | "go/ast" 6 | "go/parser" 7 | "go/token" 8 | "path/filepath" 9 | "strings" 10 | ) 11 | 12 | // fileTTDecls generates table driven tests inside the file specified by path 13 | // for the specified package name. 14 | func fileTTDecls(path, pkgName string) ([]*ttDecl, error) { 15 | // Parse file for potential tt identifiers. 16 | ttIdents, err := fileTTIdents(path) 17 | if err != nil { 18 | return nil, err 19 | } 20 | // Parse directory to find the functions, types, and methods associated 21 | // with the table test declarations. 22 | dir := filepath.Dir(path) 23 | return dirTTDecls(dir, ttIdents, pkgName) 24 | } 25 | 26 | // fileTTIdents returns a list of all the identifiers of the potential table 27 | // driven declarations found in the file specified by path. 
28 | func fileTTIdents(path string) ([]string, error) { 29 | f, err := parser.ParseFile(token.NewFileSet(), path, nil, 30 | parser.AllErrors|parser.ParseComments) 31 | if err != nil { 32 | return nil, err 33 | } 34 | tts := make([]string, 0) 35 | for _, n := range f.Decls { 36 | if gd, ok := n.(*ast.GenDecl); !ok { 37 | continue 38 | } else if _, ident, ok := isTTVar(gd); ok { 39 | tts = append(tts, ident) 40 | } 41 | } 42 | return tts, nil 43 | } 44 | 45 | // ttDecl holds a table driven test declaration. 46 | type ttDecl struct { 47 | pkg *ast.Package // package where the declaration is made 48 | tt *ast.ValueSpec // variable declaration that contains tt declaration 49 | f *ast.FuncDecl // function or method declaration to test 50 | t *ast.TypeSpec // type declaration if testing a method 51 | 52 | ttIdent, fIdent, tIdent string 53 | } 54 | 55 | // isMethod returns whether the test is testing a method, otherwise testing a 56 | // function. 57 | func (td ttDecl) isMethod() bool { 58 | return len(td.tIdent) > 0 59 | } 60 | 61 | // testName returns the name for the test function. 62 | func (td ttDecl) testName() string { 63 | if td.isMethod() { 64 | return fmt.Sprintf("TestTT%s_%s", td.tIdent, td.fIdent) 65 | } else { 66 | return fmt.Sprintf("TestTT%s", td.fIdent) 67 | } 68 | } 69 | 70 | // testDoc returns the doc string for the test function. 
71 | func (td ttDecl) testDoc() string { 72 | if td.isMethod() { 73 | return fmt.Sprintf("%s is an automatically generated table driven test for the method %s.%s using the tests defined in %s.", 74 | td.testName(), td.tIdent, td.fIdent, td.ttIdent) 75 | } else { 76 | return fmt.Sprintf("%s is an automatically generated table driven test for the function %s using the tests defined in %s.", 77 | td.testName(), td.fIdent, td.ttIdent) 78 | } 79 | } 80 | 81 | // dirTTDecls compiles a list of all the valid tt declarations found in the 82 | // specified diretory that are associated with the passed tt identifiers and 83 | // package name. 84 | func dirTTDecls(dir string, ttIdents []string, pkgName string) ([]*ttDecl, error) { 85 | pkg, err := getPkg(dir, pkgName) 86 | if err != nil { 87 | return nil, err 88 | } 89 | return pkgTTDecls(pkg, ttIdents) 90 | } 91 | 92 | // pkgTTDecls compiles a list of all the valid tt declarations found in the 93 | // passed package that are associated with the passed tt identifiers. 94 | // Returns an error if any of the found tt declarations are in valid, meaning 95 | // the struct field types don't match the receiver/input/output types. 96 | func pkgTTDecls(pkg *ast.Package, ttIdents []string) ([]*ttDecl, error) { 97 | ttDecls := make([]*ttDecl, 0) 98 | for _, ttIdent := range ttIdents { 99 | if ttDecl, ok := isTTDecl(pkg, ttIdent); ok { 100 | if err := isTTDeclValid(ttDecl); err != nil { 101 | return nil, err 102 | } else { 103 | ttDecls = append(ttDecls, ttDecl) 104 | } 105 | } 106 | } 107 | return ttDecls, nil 108 | } 109 | 110 | // isTTDecl checks if the identifier is a tt declaration in the provided 111 | // package, if so returns a ttDecl instance with all the necessary AST nodes, 112 | // otherwise returns nil and false. 
113 | func isTTDecl(pkg *ast.Package, ttIdent string) (*ttDecl, bool) { 114 | vs, ok := containsVar(pkg, ttIdent) 115 | if !ok { 116 | return nil, false 117 | } 118 | ttD := &ttDecl{pkg: pkg, tt: vs, ttIdent: ttIdent} 119 | // First, attempt to find a function with the name. A function may 120 | // contain underscore also. 121 | // Otherwise attempt to find a method. 122 | ident := strings.TrimPrefix(ttIdent, "tt") 123 | if fd, ok := containsFunction(pkg, ident); ok { 124 | ttD.f = fd 125 | ttD.fIdent = ident 126 | return ttD, true 127 | } else { 128 | // Try to find a method, by trying all the various type and 129 | // method names that can be inferred from the original ident. 130 | parts := strings.Split(ident, "_") 131 | for i := 0; i < len(parts)-1; i++ { 132 | tIdent := strings.Join(parts[:i+1], "_") // type ident 133 | mIdent := strings.Join(parts[i+1:], "_") // method ident 134 | md, ok := containsMethod(pkg, mIdent, tIdent, true) 135 | if !ok { 136 | continue 137 | } 138 | td, ok := containsType(pkg, tIdent) 139 | if !ok { 140 | continue 141 | } 142 | ttD.f = md 143 | ttD.fIdent = mIdent 144 | ttD.t = td 145 | ttD.tIdent = tIdent 146 | return ttD, true 147 | } 148 | } 149 | return nil, false 150 | } 151 | 152 | // isTTVar checks if the node is a possible tt declaration, returns the 153 | // ValueSpec, matched identifier, and a bool specifying whether matched. Only 154 | // matches the first identifier in a var declaration, make sure to only declare 155 | // on tt per var. 156 | func isTTVar(gd *ast.GenDecl) (*ast.ValueSpec, string, bool) { 157 | if gd.Tok != token.VAR { 158 | return nil, "", false 159 | } 160 | for _, sp := range gd.Specs { 161 | if vs, ok := sp.(*ast.ValueSpec); ok { 162 | for _, n := range vs.Names { 163 | if strings.HasPrefix(n.Name, "tt") { 164 | return vs, n.Name, true 165 | } 166 | } 167 | } 168 | } 169 | return nil, "", false 170 | } 171 | 172 | // isTTDeclValid returns nil if the tt declaration is valid, otherwise an error. 
173 | // Returns an error if the the fields of the test declaration don't match the 174 | // reciever/inputs/outputs of the function or method being tested. 175 | func isTTDeclValid(td *ttDecl) error { 176 | // Check that tt declaration is a list of structs 177 | st, ok := isStructSlice(td.tt) 178 | if !ok { 179 | return fmt.Errorf("%s should be an array of structs", 180 | td.ttIdent) 181 | } 182 | // Gather expressions 183 | fes, ses := funcExpr(td.f), structExpr(st) 184 | if len(fes) != len(ses) { 185 | return fmt.Errorf("expression count does not match in %s and %s", 186 | td.ttIdent, td.fIdent) 187 | } 188 | for i, fe := range fes { 189 | if !isTTExprValid(td.pkg, fe, ses[i]) { 190 | return fmt.Errorf("expressions %T (%v) and %T (%v) don't match in %s and %s", 191 | fe, fe, ses[i], ses[i], td.fIdent, td.ttIdent) 192 | } 193 | } 194 | return nil 195 | } 196 | 197 | // isTTExprValid returns true if the field of the struct declaring the tt test 198 | // properly match the field of the function or method it is testing, both 199 | // expressions must located in the passed package. 200 | // In case of a selector expression, it will import packages as necessary. 201 | // Returns true if the types are the same or if the function has a variadic 202 | // input the struct must have an array of the same type representing the input. 203 | // Returns true if the struct contains a function with no parameters but 204 | // returns the same type as function field, i.e. `func() int` for `int`. 205 | // If the struct expression is an interface then the function must also be an 206 | // interface, and if the function expression is an interface the struct 207 | // expression must meet its requirements. 208 | func isTTExprValid(pkg *ast.Package, funcExpr, structExpr ast.Expr) bool { 209 | // Function may have variadic input which must be represented by an 210 | // ast.ArrayType with the same type in the struct. 
211 | if vi, ok := funcExpr.(*ast.Ellipsis); ok { 212 | if at, ok := structExpr.(*ast.ArrayType); ok { 213 | if exprEqual(pkg, pkg, vi.Elt, at.Elt) { 214 | return true 215 | } 216 | } 217 | return false 218 | } 219 | if exprEqual(pkg, pkg, funcExpr, structExpr) { 220 | return true 221 | } 222 | // Struct may have a function that returns the necessary type for the 223 | // function expr, without parameters. 224 | if ft, ok := structExpr.(*ast.FuncType); ok { 225 | return ft.Params.NumFields() == 0 && 226 | ft.Results.NumFields() == 1 && 227 | exprEqual(pkg, pkg, funcExpr, ft.Results.List[0].Type) 228 | } 229 | return false 230 | } 231 | -------------------------------------------------------------------------------- /decl_test.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "fmt" 5 | "go/ast" 6 | "go/token" 7 | "reflect" 8 | "testing" 9 | ) 10 | 11 | // testsIsTTDeclValid are table tests for isTTDeclValid. 12 | var testsIsTTDeclValid = []struct { 13 | tt, f, t string // idents 14 | hasErr bool // whether should return error 15 | }{ 16 | {"ttSimpleMatch", "SimpleMatch", "", false}, 17 | {"ttSimpleMatch", "SimpleMisMatch", "", true}, 18 | {"ttAdvancedMatch", "AdvancedMatch", "", false}, 19 | {"ttAdvancedMatch", "AdvancedMisMatch", "", true}, // flips the chan 20 | {"ttAdvancedMatch", "AdvancedMisMatch2", "", true}, 21 | {"ttReaderMatch", "ReaderMatch", "", false}, 22 | {"ttReaderMatch", "ReaderMisMatch", "", true}, 23 | {"ttReadWriterMatch", "ReadWriterMatch", "", false}, 24 | {"ttReadWriterMatch", "ReadWriterMisMatch", "", true}, 25 | {"ttReaderMatch", "ReaderInterfaceMatch", "", false}, // internal 26 | {"ttReaderMatch", "ReaderInterfaceMisMatch", "", true}, 27 | {"ttReaderExternalStructMatch", "ReaderInterfaceMatch", "", false}, 28 | {"ttReaderExternalStructMatch", "ReaderInterfaceMisMatch", "", true}, 29 | {"ttReaderExternalInterfaceMatch", "ReaderInterfaceMatch", "", false}, 30 | 
{"ttReaderExternalInterfaceMatch", "ReaderInterfaceMisMatch", "", true}, 31 | {"ttVariadicMatch", "VariadicMatch", "", false}, 32 | {"ttVariadicMatch", "VariadicMisMatch", "", true}, 33 | {"ttVariadicMatch", "VariadicMisMatchType", "", true}, 34 | {"ttStructFunctionMatch", "StructFunctionMatch", "", false}, 35 | {"ttStructFunctionMatch", "StructFunctionMisMatch", "", true}, 36 | {"ttMethodTypeMatch_MethodValueMatch", "MethodValueMatch", "MethodTypeMatch", false}, 37 | {"ttMethodTypeMatch_MethodValueMatch_Pointer", "MethodValueMatch", "MethodTypeMatch", false}, 38 | {"ttMethodTypeMatch_MethodPointerMatch", "MethodPointerMatch", "MethodTypeMatch", false}, 39 | {"ttMethodTypeMatch_MethodPointerMisMatch", "MethodPointerMatch", "MethodTypeMatch", true}, 40 | } 41 | 42 | // TestIsTTDeclValid tests the isTTDeclValid function. 43 | func TestIsTTDeclValid(t *testing.T) { 44 | pkg := getTestPkg(t, "testdata/m", "m") 45 | for _, td := range testsIsTTDeclValid { 46 | pre := fmt.Sprintf("tt : %s : f : %s : t : %s", td.tt, td.f, td.t) 47 | ttDecl := &ttDecl{pkg: pkg, ttIdent: td.tt, 48 | fIdent: td.f, tIdent: td.t} 49 | tt, ok := containsVar(pkg, td.tt) 50 | if !ok { 51 | t.Error(td.tt, "should contain") 52 | t.FailNow() 53 | } 54 | ttDecl.tt = tt 55 | if len(td.t) > 0 { 56 | m, ok := containsMethod(pkg, td.f, td.t, true) 57 | if !ok { 58 | t.Error(pre, td.f, "should contain") 59 | t.FailNow() 60 | } 61 | ttDecl.f = m 62 | xt, ok := containsType(pkg, td.t) 63 | if !ok { 64 | t.Error(pre, td.t, "should contain") 65 | t.FailNow() 66 | } 67 | ttDecl.t = xt 68 | } else { 69 | f, ok := containsFunction(pkg, td.f) 70 | if !ok { 71 | t.Error(pre, td.f, "should contain") 72 | t.FailNow() 73 | } 74 | ttDecl.f = f 75 | } 76 | if err := isTTDeclValid(ttDecl); td.hasErr && err == nil { 77 | t.Error(pre, "expected error, returned nil") 78 | t.FailNow() 79 | } else if !td.hasErr && err != nil { 80 | t.Errorf("%s : not expecting error, returned %v\n", 81 | pre, err) 82 | t.FailNow() 83 | } 84 
| } 85 | } 86 | 87 | // genDeclValueWrap put ValueSpecs into the Specs splice in a dummy var GenDecl. 88 | func genDeclValueWrap(vss ...*ast.ValueSpec) *ast.GenDecl { 89 | ss := make([]ast.Spec, 0) 90 | for _, vs := range vss { 91 | ss = append(ss, ast.Spec(vs)) 92 | } 93 | return &ast.GenDecl{Tok: token.VAR, Specs: ss} 94 | } 95 | 96 | // TestIsTTVar tests isTTVar to make sure it matches variable declarations 97 | // starting with "tt", and the negative case. 98 | func TestIsTTVar(t *testing.T) { 99 | pkg := getTestPkg(t, "testdata/x", "x") 100 | if n, ok := containsVar(pkg, "ttExportedFunction"); !ok { 101 | t.Error("should contain") 102 | } else if vs, ident, ok := isTTVar(genDeclValueWrap(n)); !ok { 103 | t.Error("should be tt var") 104 | } else if ident != "ttExportedFunction" { 105 | t.Error("ident does not match") 106 | } else if !reflect.DeepEqual(n, vs) { 107 | t.Error("nodes should equal") 108 | } 109 | // Test a variablet that should not match 110 | if n, ok := containsVar(pkg, "ExportedVar"); !ok { 111 | t.Error("should contain") 112 | } else if _, _, ok := isTTVar(genDeclValueWrap(n)); ok { 113 | t.Error("should not be tt var") 114 | } 115 | // Test the false response for a GenDecl that is not a Var. 116 | if vs, ident, ok := isTTVar(&ast.GenDecl{Tok: token.IMPORT}); ok { 117 | t.Error("should not be tt var") 118 | } else if len(ident) > 0 { 119 | t.Error("should return empty string") 120 | } else if vs != nil { 121 | t.Error("should return nil") 122 | } 123 | } 124 | 125 | // TestIsTTDecl tests isTTDecl checking it returns a proper ttDecl for a 126 | // function and a method. Also checks with a tt decl that does not exist and one 127 | // where the declaration exists but not the method it intends to test. 128 | func TestIsTTDecl(t *testing.T) { 129 | pkg := getTestPkg(t, "testdata/x", "x") 130 | // Test tt decl for function. 
131 | if tt, ok := isTTDecl(pkg, "ttExportedFunction"); !ok { 132 | t.Error("should be a tt decl") 133 | } else { 134 | xTT := &ttDecl{} 135 | xTT.pkg = pkg 136 | xTT.ttIdent = "ttExportedFunction" 137 | xTT.fIdent = "ExportedFunction" 138 | xTT.tt, _ = containsVar(pkg, "ttExportedFunction") 139 | xTT.f, _ = containsFunction(pkg, "ExportedFunction") 140 | if !reflect.DeepEqual(tt, xTT) { 141 | t.Error("tt decl not as expected") 142 | } 143 | } 144 | // Test tt decl for method. 145 | if tt, ok := isTTDecl(pkg, "ttExportedType_ExportedMethod"); !ok { 146 | t.Error("should be a tt decl") 147 | } else { 148 | xTT := &ttDecl{} 149 | xTT.pkg = pkg 150 | xTT.ttIdent = "ttExportedType_ExportedMethod" 151 | xTT.fIdent = "ExportedMethod" 152 | xTT.tIdent = "ExportedType" 153 | xTT.tt, _ = containsVar(pkg, "ttExportedType_ExportedMethod") 154 | xTT.f, _ = containsMethod(pkg, "ExportedMethod", "ExportedType", false) 155 | xTT.t, _ = containsType(pkg, "ExportedType") 156 | if !reflect.DeepEqual(tt, xTT) { 157 | t.Error("tt decl not as expected") 158 | } 159 | } 160 | // Test tt decl that does not exist. 161 | if _, ok := isTTDecl(pkg, "ttDoesNotExist"); ok { 162 | t.Error("should not exist") 163 | } 164 | // Test tt decl where the declaration exists but the method it intends 165 | // to test does not exist. 166 | if _, ok := isTTDecl(pkg, "ttExportedType_DoesNotExist"); ok { 167 | t.Error("method should not exist") 168 | } 169 | } 170 | 171 | // TestFileTTIdents tests fileTTIdents making sure it retrieves all identifiers 172 | // that may be tt declarations. 
173 | func TestFileTTIdents(t *testing.T) { 174 | idents, err := fileTTIdents("testdata/x/x_pass_test.go") 175 | if err != nil { 176 | t.Errorf(err.Error()) 177 | } 178 | expected := []string{ 179 | "ttExportedFunction", 180 | "ttExportedType_ExportedMethod", 181 | } 182 | if !reflect.DeepEqual(idents, expected) { 183 | t.Errorf("identifiers %s, expected %s", idents, expected) 184 | } 185 | } 186 | 187 | // TestProcessFile attempts to process file, there should be no errors. 188 | func TestProcessFile(t *testing.T) { 189 | if _, err := fileTTDecls("testdata/x/x_pass_test.go", "x"); err != nil { 190 | t.Error("should not get error", err.Error()) 191 | } 192 | if _, err := fileTTDecls("testdata/x/x_fail_test.go", "x"); err == nil { 193 | t.Error("should get an error") 194 | } 195 | } 196 | -------------------------------------------------------------------------------- /lib/diff/LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2012 Martin Schnabel. All rights reserved. 2 | 3 | Redistribution and use in source and binary forms, with or without 4 | modification, are permitted provided that the following conditions are met: 5 | 6 | 1. Redistributions of source code must retain the above copyright notice, this 7 | list of conditions and the following disclaimer. 8 | 2. Redistributions in binary form must reproduce the above copyright notice, 9 | this list of conditions and the following disclaimer in the documentation 10 | and/or other materials provided with the distribution. 11 | 12 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 13 | "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 14 | LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 15 | A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT 16 | OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 17 | SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 18 | LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 19 | DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 20 | THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 21 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 22 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 23 | -------------------------------------------------------------------------------- /lib/diff/README.md: -------------------------------------------------------------------------------- 1 | diff 2 | ==== 3 | 4 | A difference algorithm package for go. 5 | 6 | The algorithm is described by Eugene Myers in 7 | ["An O(ND) Difference Algorithm and its Variations"](http://www.xmailserver.org/diff2.pdf). 8 | 9 | Example 10 | ------- 11 | You can use diff.Ints, diff.Runes, diff.ByteStrings, and diff.Bytes 12 | 13 | diff.Runes([]rune("sögen"), []rune("mögen")) // returns []Changes{{0,0,1,1}} 14 | 15 | or you can implement diff.Data 16 | 17 | type MixedInput struct { 18 | A []int 19 | B []string 20 | } 21 | func (m *MixedInput) Equal(i, j int) bool { 22 | return m.A[i] == len(m.B[j]) 23 | } 24 | 25 | and call 26 | 27 | m := &MixedInput{..} 28 | diff.Diff(len(m.A), len(m.B), m) 29 | 30 | Also has granularity functions to merge changes that are close by. 31 | 32 | diff.Granular(1, diff.ByteStrings("emtire", "umpire")) // returns []Changes{{0,0,3,3}} 33 | 34 | Documentation at http://godoc.org/github.com/mb0/diff 35 | -------------------------------------------------------------------------------- /lib/diff/diff.go: -------------------------------------------------------------------------------- 1 | // Copyright 2012 Martin Schnabel. All rights reserved. 
2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 4 | 5 | // Package diff implements a difference algorithm. 6 | // The algorithm is described in "An O(ND) Difference Algorithm and its Variations", Eugene Myers, Algorithmica Vol. 1 No. 2, 1986, pp. 251-266. 7 | package diff 8 | 9 | // A type that satisfies diff.Data can be diffed by this package. 10 | // It typically has two sequences A and B of comparable elements. 11 | type Data interface { 12 | // Equal returns whether the elements at i and j are considered equal. 13 | Equal(i, j int) bool 14 | } 15 | 16 | // ByteStrings returns the differences of two strings in bytes. 17 | func ByteStrings(a, b string) []Change { 18 | return Diff(len(a), len(b), &strings{a, b}) 19 | } 20 | 21 | type strings struct{ a, b string } 22 | 23 | func (d *strings) Equal(i, j int) bool { return d.a[i] == d.b[j] } 24 | 25 | // Bytes returns the difference of two byte slices 26 | func Bytes(a, b []byte) []Change { 27 | return Diff(len(a), len(b), &bytes{a, b}) 28 | } 29 | 30 | type bytes struct{ a, b []byte } 31 | 32 | func (d *bytes) Equal(i, j int) bool { return d.a[i] == d.b[j] } 33 | 34 | // Ints returns the difference of two int slices 35 | func Ints(a, b []int) []Change { 36 | return Diff(len(a), len(b), &ints{a, b}) 37 | } 38 | 39 | type ints struct{ a, b []int } 40 | 41 | func (d *ints) Equal(i, j int) bool { return d.a[i] == d.b[j] } 42 | 43 | // Runes returns the difference of two rune slices 44 | func Runes(a, b []rune) []Change { 45 | return Diff(len(a), len(b), &runes{a, b}) 46 | } 47 | 48 | type runes struct{ a, b []rune } 49 | 50 | func (d *runes) Equal(i, j int) bool { return d.a[i] == d.b[j] } 51 | 52 | // Granular merges neighboring changes smaller than the specified granularity. 53 | // The changes must be ordered by ascending positions as returned by this package. 
54 | func Granular(granularity int, changes []Change) []Change { 55 | if len(changes) == 0 { 56 | return changes 57 | } 58 | gap := 0 59 | for i := 1; i < len(changes); i++ { 60 | curr := changes[i] 61 | prev := changes[i-gap-1] 62 | // same as curr.B-(prev.B+prev.Ins); consistency is key 63 | if curr.A-(prev.A+prev.Del) <= granularity { 64 | // merge changes: 65 | curr = Change{ 66 | A: prev.A, B: prev.B, // start at same spot 67 | Del: curr.A - prev.A + curr.Del, // from first to end of second 68 | Ins: curr.B - prev.B + curr.Ins, // from first to end of second 69 | } 70 | gap++ 71 | } 72 | changes[i-gap] = curr 73 | } 74 | return changes[:len(changes)-gap] 75 | } 76 | 77 | // Diff returns the differences of data. 78 | // data.Equal is called repeatedly with 0<=i m { 82 | c.flags = make([]byte, n) 83 | } else { 84 | c.flags = make([]byte, m) 85 | } 86 | c.max = n + m + 1 87 | c.compare(0, 0, n, m) 88 | return c.result(n, m) 89 | } 90 | 91 | // A Change contains one or more deletions or inserts 92 | // at one position in two sequences. 
type Change struct {
	A, B int // position in input a and b
	Del  int // delete Del elements from input a
	Ins  int // insert Ins elements from input b
}

// context holds the working state of a single Diff run.
type context struct {
	data Data
	// flags is shared by both inputs: bit 1 at index x marks a[x] as
	// deleted, bit 2 at index y marks b[y] as inserted.
	flags []byte // element bits 1 delete, 2 insert
	// max is the base offset for diagonal indices into forward/reverse so
	// they never go negative (Diff sets it to n+m+1).
	max int
	// forward and reverse d-path endpoint x components
	forward, reverse []int
}

// compare marks in c.flags the deletions from a[aoffset:alimit] and the
// insertions from b[boffset:blimit] needed to turn one range into the other.
// It strips the common prefix and suffix, handles the all-insert and
// all-delete cases directly, and otherwise splits the problem at the point
// returned by findMiddleSnake and recurses on both halves.
func (c *context) compare(aoffset, boffset, alimit, blimit int) {
	// eat common prefix
	for aoffset < alimit && boffset < blimit && c.data.Equal(aoffset, boffset) {
		aoffset++
		boffset++
	}
	// eat common suffix
	for alimit > aoffset && blimit > boffset && c.data.Equal(alimit-1, blimit-1) {
		alimit--
		blimit--
	}
	// both equal or b inserts
	if aoffset == alimit {
		for boffset < blimit {
			c.flags[boffset] |= 2
			boffset++
		}
		return
	}
	// a deletes
	if boffset == blimit {
		for aoffset < alimit {
			c.flags[aoffset] |= 1
			aoffset++
		}
		return
	}
	x, y := c.findMiddleSnake(aoffset, boffset, alimit, blimit)
	c.compare(aoffset, boffset, x, y)
	c.compare(x, y, alimit, blimit)
}

// findMiddleSnake finds a split point for the sub-problem
// a[aoffset:alimit] vs b[boffset:blimit]. It extends furthest-reaching paths
// forward from (aoffset, boffset) and in reverse from (alimit, blimit), one
// edit distance d at a time, and returns the (x, y) where the two searches
// first overlap. Diagonals are numbered k = x - y; fmid and rmid are the
// diagonals the forward and reverse searches start on.
func (c *context) findMiddleSnake(aoffset, boffset, alimit, blimit int) (int, int) {
	// midpoints
	fmid := aoffset - boffset
	rmid := alimit - blimit
	// correct offset in d-path slices
	foff := c.max - fmid
	roff := c.max - rmid
	// The parity of the distance between the start diagonals decides which
	// pass can detect the overlap: the forward pass when odd, the reverse
	// pass when even.
	isodd := (rmid-fmid)&1 != 0
	// Upper bound for d; reaching the end of the loop triggers the panic below.
	maxd := (alimit - aoffset + blimit - boffset + 2) / 2
	// allocate when first used
	if c.forward == nil {
		c.forward = make([]int, 2*c.max)
		c.reverse = make([]int, 2*c.max)
	}
	// Seed the slots just outside the d=0 diagonal so the first step of each
	// search starts exactly at its corner (aoffset resp. alimit).
	c.forward[c.max+1] = aoffset
	c.reverse[c.max-1] = alimit
	var x, y int
	for d := 0; d <= maxd; d++ {
		// forward search
		for k := fmid - d; k <= fmid+d; k += 2 {
			if k == fmid-d || k != fmid+d && c.forward[foff+k+1] > c.forward[foff+k-1] {
				x = c.forward[foff+k+1] // down
			} else {
				x = c.forward[foff+k-1] + 1 // right
			}
			y = x - k
			// follow the diagonal while elements match
			for x < alimit && y < blimit && c.data.Equal(x, y) {
				x++
				y++
			}
			c.forward[foff+k] = x
			if isodd && k > rmid-d && k < rmid+d {
				if c.reverse[roff+k] <= c.forward[foff+k] {
					return x, x - k
				}
			}
		}
		// reverse search x,y correspond to u,v
		for k := rmid - d; k <= rmid+d; k += 2 {
			if k == rmid+d || k != rmid-d && c.reverse[roff+k-1] < c.reverse[roff+k+1] {
				x = c.reverse[roff+k-1] // up
			} else {
				x = c.reverse[roff+k+1] - 1 // left
			}
			y = x - k
			// follow the diagonal backwards while elements match
			for x > aoffset && y > boffset && c.data.Equal(x-1, y-1) {
				x--
				y--
			}
			c.reverse[roff+k] = x
			if !isodd && k >= fmid-d && k <= fmid+d {
				if c.reverse[roff+k] <= c.forward[foff+k] {
					// lookup opposite end
					x = c.forward[foff+k]
					return x, x - k
				}
			}
		}
	}
	panic("should never be reached")
}

// result turns the flags set by compare into Change hunks. It walks a
// (x < n) and b (y < m) in lockstep: positions where neither flag is set
// advance both sides as a matching pair, and each maximal run of flagged
// deletions and/or insertions becomes one Change.
func (c *context) result(n, m int) (res []Change) {
	var x, y int
	for x < n || y < m {
		if x < n && y < m && c.flags[x]&1 == 0 && c.flags[y]&2 == 0 {
			x++
			y++
		} else {
			a := x
			b := y
			for x < n && (y >= m || c.flags[x]&1 != 0) {
				x++
			}
			for y < m && (x >= n || c.flags[y]&2 != 0) {
				y++
			}
			if a < x || b < y {
				res = append(res, Change{a, b, x - a, y - b})
			}
		}
	}
	return
}
--------------------------------------------------------------------------------
/lib/diff/diff_test.go:
--------------------------------------------------------------------------------
// Copyright 2012 Martin Schnabel. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
4 | 5 | package diff_test 6 | 7 | import ( 8 | "github.com/emil2k/tab/lib/diff" 9 | "testing" 10 | ) 11 | 12 | type testcase struct { 13 | name string 14 | a, b []int 15 | res []diff.Change 16 | } 17 | 18 | var tests = []testcase{ 19 | {"shift", 20 | []int{1, 2, 3}, 21 | []int{0, 1, 2, 3}, 22 | []diff.Change{{0, 0, 0, 1}}, 23 | }, 24 | {"push", 25 | []int{1, 2, 3}, 26 | []int{1, 2, 3, 4}, 27 | []diff.Change{{3, 3, 0, 1}}, 28 | }, 29 | {"unshift", 30 | []int{0, 1, 2, 3}, 31 | []int{1, 2, 3}, 32 | []diff.Change{{0, 0, 1, 0}}, 33 | }, 34 | {"pop", 35 | []int{1, 2, 3, 4}, 36 | []int{1, 2, 3}, 37 | []diff.Change{{3, 3, 1, 0}}, 38 | }, 39 | {"all changed", 40 | []int{0, 1, 2, 3, 4, 5, 6, 7, 8, 9}, 41 | []int{10, 11, 12, 13, 14}, 42 | []diff.Change{ 43 | {0, 0, 10, 5}, 44 | }, 45 | }, 46 | {"all same", 47 | []int{0, 1, 2, 3, 4, 5, 6, 7, 8, 9}, 48 | []int{0, 1, 2, 3, 4, 5, 6, 7, 8, 9}, 49 | []diff.Change{}, 50 | }, 51 | {"wrap", 52 | []int{1}, 53 | []int{0, 1, 2, 3}, 54 | []diff.Change{ 55 | {0, 0, 0, 1}, 56 | {1, 2, 0, 2}, 57 | }, 58 | }, 59 | {"snake", 60 | []int{0, 1, 2, 3, 4, 5}, 61 | []int{1, 2, 3, 4, 5, 6}, 62 | []diff.Change{ 63 | {0, 0, 1, 0}, 64 | {6, 5, 0, 1}, 65 | }, 66 | }, 67 | // note: input is ambiguous 68 | // first two traces differ from fig.1 69 | // it still is a lcs and ses path 70 | {"paper fig. 
1", 71 | []int{1, 2, 3, 1, 2, 2, 1}, 72 | []int{3, 2, 1, 2, 1, 3}, 73 | []diff.Change{ 74 | {0, 0, 1, 1}, 75 | {2, 2, 1, 0}, 76 | {5, 4, 1, 0}, 77 | {7, 5, 0, 1}, 78 | }, 79 | }, 80 | } 81 | 82 | func TestDiffAB(t *testing.T) { 83 | for _, test := range tests { 84 | res := diff.Ints(test.a, test.b) 85 | if len(res) != len(test.res) { 86 | t.Error(test.name, "expected length", len(test.res), "for", res) 87 | continue 88 | } 89 | for i, c := range test.res { 90 | if c != res[i] { 91 | t.Error(test.name, "expected ", c, "got", res[i]) 92 | } 93 | } 94 | } 95 | } 96 | 97 | func TestDiffBA(t *testing.T) { 98 | // interesting: fig.1 Diff(b, a) results in the same path as `diff -d a b` 99 | tests[len(tests)-1].res = []diff.Change{ 100 | {0, 0, 2, 0}, 101 | {3, 1, 1, 0}, 102 | {5, 2, 0, 1}, 103 | {7, 5, 0, 1}, 104 | } 105 | for _, test := range tests { 106 | res := diff.Ints(test.b, test.a) 107 | if len(res) != len(test.res) { 108 | t.Error(test.name, "expected length", len(test.res), "for", res) 109 | continue 110 | } 111 | for i, c := range test.res { 112 | // flip change data also 113 | rc := diff.Change{c.B, c.A, c.Ins, c.Del} 114 | if rc != res[i] { 115 | t.Error(test.name, "expected ", rc, "got", res[i]) 116 | } 117 | } 118 | } 119 | } 120 | 121 | func diffsEqual(a, b []diff.Change) bool { 122 | if len(a) != len(b) { 123 | return false 124 | } 125 | for i := 0; i < len(a); i++ { 126 | if a[i] != b[i] { 127 | return false 128 | } 129 | } 130 | return true 131 | } 132 | 133 | func TestGranularStrings(t *testing.T) { 134 | a := "abcdefghijklmnopqrstuvwxyza" 135 | b := "AbCdeFghiJklmnOpqrstUvwxyzab" 136 | // each iteration of i increases granularity and will absorb one more lower-letter-followed-by-upper-letters sequence 137 | changesI := [][]diff.Change{ 138 | {{0, 0, 1, 1}, {2, 2, 1, 1}, {5, 5, 1, 1}, {9, 9, 1, 1}, {14, 14, 1, 1}, {20, 20, 1, 1}, {27, 27, 0, 1}}, 139 | {{0, 0, 3, 3}, {5, 5, 1, 1}, {9, 9, 1, 1}, {14, 14, 1, 1}, {20, 20, 1, 1}, {27, 27, 0, 1}}, 140 | 
{{0, 0, 6, 6}, {9, 9, 1, 1}, {14, 14, 1, 1}, {20, 20, 1, 1}, {27, 27, 0, 1}}, 141 | {{0, 0, 10, 10}, {14, 14, 1, 1}, {20, 20, 1, 1}, {27, 27, 0, 1}}, 142 | {{0, 0, 15, 15}, {20, 20, 1, 1}, {27, 27, 0, 1}}, 143 | {{0, 0, 21, 21}, {27, 27, 0, 1}}, 144 | {{0, 0, 27, 28}}, 145 | } 146 | for i := 0; i < len(changesI); i++ { 147 | diffs := diff.Granular(i, diff.ByteStrings(a, b)) 148 | if !diffsEqual(diffs, changesI[i]) { 149 | t.Errorf("expected %v, got %v", diffs, changesI[i]) 150 | } 151 | } 152 | } 153 | 154 | func TestDiffRunes(t *testing.T) { 155 | a := []rune("brown fox jumps over the lazy dog") 156 | b := []rune("brwn faax junps ovver the lay dago") 157 | res := diff.Runes(a, b) 158 | echange := []diff.Change{ 159 | {2, 2, 1, 0}, 160 | {7, 6, 1, 2}, 161 | {12, 12, 1, 1}, 162 | {18, 18, 0, 1}, 163 | {27, 28, 1, 0}, 164 | {31, 31, 0, 2}, 165 | {32, 34, 1, 0}, 166 | } 167 | for i, c := range res { 168 | t.Log(c) 169 | if c != echange[i] { 170 | t.Error("expected", echange[i], "got", c) 171 | } 172 | } 173 | } 174 | 175 | func TestDiffByteStrings(t *testing.T) { 176 | a := "brown fox jumps over the lazy dog" 177 | b := "brwn faax junps ovver the lay dago" 178 | res := diff.ByteStrings(a, b) 179 | echange := []diff.Change{ 180 | {2, 2, 1, 0}, 181 | {7, 6, 1, 2}, 182 | {12, 12, 1, 1}, 183 | {18, 18, 0, 1}, 184 | {27, 28, 1, 0}, 185 | {31, 31, 0, 2}, 186 | {32, 34, 1, 0}, 187 | } 188 | for i, c := range res { 189 | t.Log(c) 190 | if c != echange[i] { 191 | t.Error("expected", echange[i], "got", c) 192 | } 193 | } 194 | } 195 | 196 | type ints struct{ a, b []int } 197 | 198 | func (d *ints) Equal(i, j int) bool { return d.a[i] == d.b[j] } 199 | func BenchmarkDiff(b *testing.B) { 200 | t := tests[len(tests)-1] 201 | d := &ints{t.a, t.b} 202 | n, m := len(d.a), len(d.b) 203 | for i := 0; i < b.N; i++ { 204 | diff.Diff(n, m, d) 205 | } 206 | } 207 | 208 | func BenchmarkInts(b *testing.B) { 209 | t := tests[len(tests)-1] 210 | d1 := t.a 211 | d2 := t.b 212 | for i := 0; i < 
b.N; i++ { 213 | diff.Ints(d1, d2) 214 | } 215 | } 216 | 217 | func BenchmarkDiffRunes(b *testing.B) { 218 | d1 := []rune("1231221") 219 | d2 := []rune("321213") 220 | for i := 0; i < b.N; i++ { 221 | diff.Runes(d1, d2) 222 | } 223 | } 224 | 225 | func BenchmarkDiffBytes(b *testing.B) { 226 | d1 := []byte("lorem ipsum dolor sit amet consectetur") 227 | d2 := []byte("lorem lovesum daenerys targaryen ami consecteture") 228 | for i := 0; i < b.N; i++ { 229 | diff.Bytes(d1, d2) 230 | } 231 | } 232 | 233 | func BenchmarkDiffByteStrings(b *testing.B) { 234 | d1 := "lorem ipsum dolor sit amet consectetur" 235 | d2 := "lorem lovesum daenerys targaryen ami consecteture" 236 | for i := 0; i < b.N; i++ { 237 | diff.ByteStrings(d1, d2) 238 | } 239 | } 240 | -------------------------------------------------------------------------------- /lib/diff/example_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2012 Martin Schnabel. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 
4 | 5 | package diff_test 6 | 7 | import ( 8 | "fmt" 9 | "github.com/emil2k/tab/lib/diff" 10 | ) 11 | 12 | // Diff on inputs with different representations 13 | type MixedInput struct { 14 | A []int 15 | B []string 16 | } 17 | 18 | var names map[string]int 19 | 20 | func (m *MixedInput) Equal(a, b int) bool { 21 | return m.A[a] == names[m.B[b]] 22 | } 23 | 24 | func ExampleDiff() { 25 | names = map[string]int{ 26 | "one": 1, 27 | "two": 2, 28 | "three": 3, 29 | } 30 | 31 | m := &MixedInput{ 32 | []int{1, 2, 3, 1, 2, 2, 1}, 33 | []string{"three", "two", "one", "two", "one", "three"}, 34 | } 35 | changes := diff.Diff(len(m.A), len(m.B), m) 36 | for _, c := range changes { 37 | fmt.Println("change at", c.A, c.B) 38 | } 39 | // Output: 40 | // change at 0 0 41 | // change at 2 2 42 | // change at 5 4 43 | // change at 7 5 44 | } 45 | 46 | func ExampleGranular() { 47 | a := "hElLo!" 48 | b := "hello!" 49 | changes := diff.Granular(5, diff.ByteStrings(a, b)) // ignore small gaps in differences 50 | for l := len(changes) - 1; l >= 0; l-- { 51 | change := changes[l] 52 | b = b[:change.B] + "|" + b[change.B:change.B+change.Ins] + "|" + b[change.B+change.Ins:] 53 | } 54 | fmt.Println(b) 55 | // Output: 56 | // h|ell|o! 57 | } 58 | -------------------------------------------------------------------------------- /main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "fmt" 5 | "os" 6 | ) 7 | 8 | // main gets the GOFILE and GOPACKAGE environment variables set by `go 9 | // generate` and passes them to process. 
10 | func main() { 11 | goFile, goPkg := os.Getenv("GOFILE"), os.Getenv("GOPACKAGE") 12 | if len(goFile) == 0 || len(goPkg) == 0 { 13 | fmt.Fprintf(os.Stderr, "tab : command must be called using `go generate`\n") 14 | os.Exit(1) 15 | } 16 | fmt.Fprintf(os.Stdout, "tab : processing file %s in package %s\n", goFile, goPkg) 17 | n, err := process(goFile, goPkg) 18 | if err != nil { 19 | fmt.Fprintf(os.Stderr, "tab : %s\n", err.Error()) 20 | } 21 | // Success. 22 | fmt.Fprintf(os.Stdout, "tab : processed file %s, placed %d table driven test(s)\n", 23 | goFile, n) 24 | os.Exit(0) 25 | } 26 | 27 | // process processes a file in the given package and returns the number of 28 | // table test placed or an error if there is an issue. 29 | func process(file, pkg string) (int, error) { 30 | ttDecls, err := fileTTDecls(file, pkg) 31 | if err != nil { 32 | return 0, fmt.Errorf("error looking for table test declarations : %s", err.Error()) 33 | } 34 | // Put the found declarations in the file. 35 | for _, td := range ttDecls { 36 | if err := putTTDecl(file, *td); err != nil { 37 | return 0, fmt.Errorf("error putting table driven test : %s", err.Error()) 38 | } 39 | } 40 | return len(ttDecls), nil 41 | } 42 | -------------------------------------------------------------------------------- /main_test.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "os" 5 | "path/filepath" 6 | "strconv" 7 | "testing" 8 | ) 9 | 10 | // testCase runs the test case identified by the passed number, it processes the 11 | // file main_test.go in the `a` folder and compares it with the same file in the 12 | // `b` folder. If they don't match it fails the test. 
13 | func testCase(t *testing.T, n int) { 14 | casePath := filepath.Join("testdata", "cases", strconv.Itoa(n)) 15 | tmp := getTestDir(t, filepath.Join(casePath, "a")) 16 | defer os.RemoveAll(tmp) 17 | aFile := filepath.Join(tmp, "main_test.go") 18 | process(aFile, "main") 19 | bFile := filepath.Join(casePath, "b", "main_test.go") 20 | testFiles(t, aFile, bFile) 21 | } 22 | 23 | // TestExampleCase runs the example test case. 24 | func TestExampleCase(t *testing.T) { 25 | testCase(t, 1) 26 | } 27 | -------------------------------------------------------------------------------- /put.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "bytes" 5 | "fmt" 6 | "go/ast" 7 | "go/parser" 8 | "go/scanner" 9 | "go/token" 10 | "io/ioutil" 11 | "os" 12 | "strings" 13 | "text/template" 14 | ) 15 | 16 | // putTTDecl either updates or creates the test function described by the passed 17 | // tt declaration under the variable that declares it, in the file specified by 18 | // the path. 19 | func putTTDecl(path string, td ttDecl) error { 20 | // Slurp the file with ReadAll. 21 | content, err := slurpFile(path) 22 | if err != nil { 23 | return err 24 | } 25 | // Find old test function declaration and if necessary remove it. 26 | fs, f, err := parseBytes(content) 27 | if err != nil { 28 | return err 29 | } 30 | rmStart, rmEnd, ok := funcDeclRange(fs, f, td.testName()) 31 | if ok { 32 | content = replaceRange(content, []byte{}, rmStart, rmEnd) 33 | // Need to update the AST and fileset, because content has 34 | // changed and it needs to be used for determine append range. 35 | fs, f, err = parseBytes(content) 36 | if err != nil { 37 | return err 38 | } 39 | } 40 | // Find range where to place the new test declaration. 
41 | appendStart, appendEnd, appendEOF, ok := appendRange(fs, f, content, td.ttIdent) 42 | if !ok { 43 | return fmt.Errorf("%s not found in file %s", td.ttIdent, path) 44 | } 45 | // Template out the test function from the declaration. 46 | tdh, err := newTTHolder(td, !appendEOF) 47 | if err != nil { 48 | return err 49 | } 50 | testContent := renderTTTestFunction(*tdh) 51 | content = replaceRange(content, testContent, appendStart, appendEnd) 52 | // Write the new file to disk. 53 | if err := writeFile(path, content); err != nil { 54 | return err 55 | } 56 | return nil 57 | } 58 | 59 | // writeFile writes out the file to the given path, truncates the file if 60 | // necessary. 61 | // Returns an error if there is an issue with opening or writing to the path. 62 | func writeFile(path string, content []byte) error { 63 | f, err := os.OpenFile(path, os.O_WRONLY, 0755) // TODO mirror permissions 64 | if err != nil { 65 | return err 66 | } 67 | defer f.Close() 68 | f.Write(content) 69 | return nil 70 | } 71 | 72 | // slurpFile opens up and read the full content of the specified path. 73 | func slurpFile(path string) ([]byte, error) { 74 | f, err := os.Open(path) 75 | if err != nil { 76 | return nil, err 77 | } 78 | defer f.Close() 79 | content, err := ioutil.ReadAll(f) 80 | if err != nil { 81 | return nil, err 82 | } 83 | return content, err 84 | } 85 | 86 | // parseBytes parses the AST of a file from a byte slice, includes and returns 87 | // all errors. 88 | func parseBytes(in []byte) (*token.FileSet, *ast.File, error) { 89 | fs := token.NewFileSet() 90 | f, err := parser.ParseFile(fs, "DOESNTMATTER", in, 91 | parser.AllErrors|parser.ParseComments) 92 | return fs, f, err 93 | } 94 | 95 | // replaceRange replaces the range specified by the start/end offset with the 96 | // sub slice in the input slice. 97 | func replaceRange(in []byte, sub []byte, start, end int) []byte { 98 | out := make([]byte, 0, len(in)+len(sub)+start-end) 99 | out = append(out, in[:start]...) 
100 | out = append(out, sub...) 101 | out = append(out, in[end:]...) 102 | return out 103 | } 104 | 105 | // funcDeclRange if the a func declaration exists in the file with the specified 106 | // ident provides the offset range where it resides, including documentation 107 | // comments (adjacent to declaration). 108 | // If a func declaration is not found returns false for the third result. 109 | func funcDeclRange(fs *token.FileSet, f *ast.File, ident string) (start, end int, ok bool) { 110 | if obj := f.Scope.Lookup(ident); obj != nil { 111 | if fd, ok := obj.Decl.(*ast.FuncDecl); ok { 112 | sp, ep := fd.Pos(), fd.End() 113 | // Determine where the functions documentation begins. 114 | for _, c := range fd.Doc.List { 115 | if c.Pos() < sp { 116 | sp = c.Pos() 117 | } 118 | } 119 | return fs.PositionFor(sp, true).Offset, 120 | fs.PositionFor(ep, true).Offset, true 121 | } 122 | } 123 | return 0, 0, false 124 | } 125 | 126 | // appendRange finds the node with the specified ident in the file's scope and 127 | // returns the range of offsets that includes adjacent whitespace that would 128 | // need to be replaced to append something right after it. 129 | // Returns whether the range reaches the end of file, this is important when 130 | // deciding whether to add whitespace after the node. 131 | // The contents of the file should be passed via src, must be the same size, 132 | // used by Scanner to find the extent of the whitespace up to the next keyword 133 | // or comment. 134 | func appendRange(fs *token.FileSet, f *ast.File, src []byte, ident string) (start, end int, eof, ok bool) { 135 | if obj := f.Scope.Lookup(ident); obj != nil { 136 | if n, ok := obj.Decl.(ast.Node); ok { 137 | sp, ep := n.End(), n.End() 138 | // Scan file to find where the whitespace ends. 
139 | s := new(scanner.Scanner) 140 | tf := fs.File(sp) 141 | s.Init(tf, src, nil, scanner.ScanComments) 142 | for { 143 | tp, tok, _ := s.Scan() 144 | if tp < ep { 145 | continue 146 | } else if tp > ep { 147 | if tok == token.EOF { 148 | eof = true 149 | ep = tp 150 | break 151 | } 152 | if tok == token.COMMENT || tok.IsKeyword() { 153 | ep = tp 154 | break 155 | } 156 | } 157 | } 158 | return fs.PositionFor(sp, true).Offset, 159 | fs.PositionFor(ep, true).Offset, eof, true 160 | } 161 | } 162 | return 0, 0, false, false 163 | } 164 | 165 | // ttTmpl holds the table test template used for generating tests. 166 | var ttTmpl = template.Must(template.New("tt").Parse(ttTmplString)) 167 | 168 | // renderTTTestFunction generates the code for the table test function from the 169 | // template holder. 170 | func renderTTTestFunction(tdh ttHolder) []byte { 171 | buf := new(bytes.Buffer) 172 | err := ttTmpl.Execute(buf, tdh) 173 | if err != nil { 174 | panic(fmt.Sprintf("rendering table test function : %v", err)) 175 | } 176 | testContent, _ := ioutil.ReadAll(buf) 177 | return testContent 178 | } 179 | 180 | // ttHolder is a holder to provide to the template engine variables necessary to 181 | // output a table test. 182 | type ttHolder struct { 183 | Name string 184 | CallExpr string // expression for calling function or method 185 | TTIdent string // identifier for the structs slice to range over 186 | Doc string // docstring for the test function. 187 | Params, Results string 188 | Checks []ttCheck 189 | AppendNewlines bool // whether reaches EOF 190 | } 191 | 192 | // ttCheck is a holder to provide to the template engine variables necessary to 193 | // output a check that a value received for a result matches the expected value. 194 | type ttCheck struct { 195 | Name, Expected, Got string 196 | } 197 | 198 | // newTTHolder initiates the variables necessary to render a table test, returns 199 | // a ttHolder. 
The appendNewLines is used to determine whether new lines need to 200 | // be attached after the test function. 201 | func newTTHolder(td ttDecl, appendNewLines bool) (*ttHolder, error) { 202 | name := td.testName() 203 | i := 0 204 | // Get the struct slide and compile a list of its fields. 205 | tds, ok := isStructSlice(td.tt) 206 | if !ok { 207 | return nil, fmt.Errorf("%s is not a struct slice", td.ttIdent) 208 | } 209 | var fields []string 210 | for _, s := range tds.Fields.List { 211 | for _, n := range s.Names { 212 | fields = append(fields, n.Name) 213 | } 214 | } 215 | // Determine the function or method expression. 216 | var ident string 217 | if len(td.tIdent) > 0 { 218 | ident = fmt.Sprintf("%s.%s", fields[0], td.fIdent) 219 | i++ 220 | } else { 221 | ident = td.fIdent 222 | } 223 | // Determine expressions for the function/method parameters, results, 224 | // and the equivalence checks. 225 | var params, results []string 226 | for _, p := range td.f.Type.Params.List { 227 | switch p.Type.(type) { 228 | case *ast.Ellipsis: 229 | for range p.Names { 230 | params = append(params, fmt.Sprintf("tt.%s...", fields[i])) 231 | i++ 232 | } 233 | default: 234 | for range p.Names { 235 | params = append(params, fmt.Sprintf("tt.%s", fields[i])) 236 | i++ 237 | } 238 | } 239 | } 240 | var checks []ttCheck 241 | addResult := func() { 242 | field := fields[i] 243 | checks = append(checks, 244 | ttCheck{field, fmt.Sprintf("tt.%s", field), field}) 245 | results = append(results, field) 246 | i++ 247 | } 248 | for _, r := range td.f.Type.Results.List { 249 | if len(r.Names) == 0 { // unnamed return 250 | addResult() 251 | continue 252 | } 253 | for range r.Names { 254 | addResult() 255 | } 256 | } 257 | return &ttHolder{ 258 | name, 259 | ident, 260 | td.ttIdent, 261 | renderComment(td.testDoc()), 262 | strings.Join(params, ", "), 263 | strings.Join(results, ", "), 264 | checks, 265 | appendNewLines, 266 | }, nil 267 | } 268 | 269 | // renderComment returns a comment 
string with a new line roughly every 80 270 | // characters, without splitting up words. At the start of each new line adds 271 | // a "//" to make it a comment. No newline is added at the end. 272 | func renderComment(str string) string { 273 | if len(str) == 0 { 274 | return "" 275 | } 276 | comment := "// " 277 | out := len(comment) 278 | words := strings.Split(str, " ") 279 | for i, w := range words { 280 | comment += w 281 | out += len(w) 282 | if len(words) > i+1 { 283 | if out+len(words[i+1]) > 77 { 284 | comment += "\n// " 285 | out = 0 286 | } else { 287 | // Avoid trailing white space. 288 | comment += " " 289 | out++ 290 | } 291 | } 292 | } 293 | return comment 294 | } 295 | -------------------------------------------------------------------------------- /put_test.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "go/build" 5 | "os" 6 | "path/filepath" 7 | "testing" 8 | ) 9 | 10 | // TestPutTTDecl tests putTTDecl function by copying a test package processing 11 | // a file and checking if the package still builds. 12 | // Processes the file twice to see if multiple operations properly replace a 13 | // rendered test function without breaking the code. 14 | func TestPutTTDecl(t *testing.T) { 15 | pkgPath := getTestDir(t, filepath.FromSlash("testdata/x/")) 16 | defer os.RemoveAll(pkgPath) 17 | file := filepath.Join(pkgPath, "x_pass_test.go") 18 | tds, err := fileTTDecls(file, "x") 19 | if err != nil { 20 | t.Error("error while processing file :", err.Error()) 21 | } 22 | for _, td := range tds { 23 | err := putTTDecl(file, *td) 24 | if err != nil { 25 | t.Error("error while putting tt decl :", err.Error()) 26 | } 27 | } 28 | // Test that the package still builds. 29 | if _, err := build.ImportDir(pkgPath, 0); err != nil { 30 | t.Error("error while building processed package :", err) 31 | } 32 | // Process one more time to test that nothing breaks. 
33 | tds, err = fileTTDecls(file, "x") 34 | if err != nil { 35 | t.Error("error while processing file second time :", 36 | err.Error()) 37 | } 38 | for _, td := range tds { 39 | err := putTTDecl(file, *td) 40 | if err != nil { 41 | t.Error("error while putting tt decl second time :", 42 | err.Error()) 43 | } 44 | } 45 | // Test that the package still builds. 46 | if _, err := build.ImportDir(pkgPath, 0); err != nil { 47 | t.Error("error while building processed package second time :", 48 | err) 49 | } 50 | } 51 | -------------------------------------------------------------------------------- /put_tmpl.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | var ttTmplString string = ` 4 | 5 | {{ .Doc }} 6 | func {{ .Name }}(t *testing.T) { 7 | for i, tt := range {{ .TTIdent }} { 8 | {{ if .Results }}{{ .Results }} := {{ end }}{{ .CallExpr }}({{ .Params }}){{ range .Checks }} 9 | if {{ .Got }} != {{ .Expected }} { 10 | t.Errorf("%d : {{ .Name }} : got %v, expected %v", i, {{ .Got }}, {{ .Expected }}) 11 | }{{ end }} 12 | } 13 | }{{ if .AppendNewlines }} 14 | 15 | {{ end }} 16 | ` 17 | -------------------------------------------------------------------------------- /testdata/cases/1/a/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | func DummyFunction(a, b int) (c, d, e int, f float64, err error) { 4 | // dummy function to test 5 | return 6 | } 7 | -------------------------------------------------------------------------------- /testdata/cases/1/a/main_test.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | //go:generate tab 4 | 5 | var ttDummyFunction = []struct { 6 | // inputs 7 | a, b int 8 | // outputs 9 | c, d, e int 10 | f float64 11 | err error 12 | }{ 13 | {1, 2, 5, 6, 7, 5.4, nil}, 14 | {1, 2, 5, 6, 7, 5.4, nil}, 15 | {1, 2, 5, 6, 7, 5.4, nil}, // and on and on ... 
16 | } 17 | -------------------------------------------------------------------------------- /testdata/cases/1/b/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | func DummyFunction(a, b int) (c, d, e int, f float64, err error) { 4 | // dummy function to test 5 | return 6 | } 7 | -------------------------------------------------------------------------------- /testdata/cases/1/b/main_test.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | //go:generate tab 4 | 5 | var ttDummyFunction = []struct { 6 | // inputs 7 | a, b int 8 | // outputs 9 | c, d, e int 10 | f float64 11 | err error 12 | }{ 13 | {1, 2, 5, 6, 7, 5.4, nil}, 14 | {1, 2, 5, 6, 7, 5.4, nil}, 15 | {1, 2, 5, 6, 7, 5.4, nil}, // and on and on ... 16 | } 17 | 18 | // TestTTDummyFunction is an automatically generated table driven test for the 19 | // function DummyFunction using the tests defined in ttDummyFunction. 20 | func TestTTDummyFunction(t *testing.T) { 21 | for i, tt := range ttDummyFunction { 22 | c, d, e, f, err := DummyFunction(tt.a, tt.b) 23 | if c != tt.c { 24 | t.Errorf("%d : c : got %v, expected %v", i, c, tt.c) 25 | } 26 | if d != tt.d { 27 | t.Errorf("%d : d : got %v, expected %v", i, d, tt.d) 28 | } 29 | if e != tt.e { 30 | t.Errorf("%d : e : got %v, expected %v", i, e, tt.e) 31 | } 32 | if f != tt.f { 33 | t.Errorf("%d : f : got %v, expected %v", i, f, tt.f) 34 | } 35 | if err != tt.err { 36 | t.Errorf("%d : err : got %v, expected %v", i, err, tt.err) 37 | } 38 | } 39 | } 40 | -------------------------------------------------------------------------------- /testdata/m/m.go: -------------------------------------------------------------------------------- 1 | // Package m is meant for contains several test cases for testing tt declaration 2 | // validation. 
3 | package m 4 | 5 | import ( 6 | "bufio" 7 | "io" 8 | ) 9 | 10 | // Simple cases 11 | 12 | var ttSimpleMatch = []struct { 13 | n int 14 | out bool 15 | }{} 16 | 17 | func SimpleMatch(n int) bool { 18 | return false 19 | } 20 | 21 | func SimpleMisMatch() bool { 22 | return false 23 | } 24 | 25 | // Advanced cases 26 | 27 | var ttAdvancedMatch = []struct { 28 | a, b, c int 29 | f func(a int, b int) <-chan int 30 | }{} 31 | 32 | func AdvancedMatch(a int, b, c int) func(a, b int) <-chan int { 33 | return nil 34 | } 35 | 36 | func AdvancedMisMatch(a int, b, c int) func(a, b int) chan<- int { 37 | return nil 38 | } 39 | 40 | func AdvancedMisMatch2(a int, b, c int) func() <-chan int { 41 | return nil 42 | } 43 | 44 | // Interfaces cases 45 | 46 | type SimpleReader int 47 | 48 | func (_ SimpleReader) Read(p []byte) (n int, err error) { 49 | return 0, nil 50 | } 51 | 52 | var ttReaderMatch = []struct { 53 | r SimpleReader 54 | }{} 55 | 56 | func ReaderMatch(r io.Reader) {} 57 | 58 | func ReaderMisMatch(r io.Writer) {} 59 | 60 | // Embedded interfaces, also tests that idents don't have to match to meet 61 | // interface requirements and a type that is not based on a first class type. 62 | 63 | type SimpleReadWriter SimpleReader 64 | 65 | func (_ SimpleReadWriter) Read(input []byte) (read int, problem error) { 66 | return 0, nil 67 | } 68 | 69 | func (_ SimpleReadWriter) Write(x []byte) (xx int, xxx error) { 70 | return 0, nil 71 | } 72 | 73 | var ttReadWriterMatch = []struct { 74 | r SimpleReadWriter 75 | }{} 76 | 77 | func ReadWriterMatch(r io.ReadWriter) {} 78 | 79 | func ReadWriterMisMatch(r io.ReadCloser) {} 80 | 81 | // Matching two interfaces, you should be able to pass the types in the struct 82 | // through the function. 83 | // The bufio.ReadWriter and io.ReadWriter have embedded structs and interfaces 84 | // that the program must take into account when looking for methods. 
// Also using external packages io and bufio to test that expressions are
// being properly resolved across packages.

var ttReaderExternalInterfaceMatch = []struct {
	r io.ReadWriter
}{}

var ttReaderExternalStructMatch = []struct {
	r *bufio.ReadWriter
}{}

// ReaderInterfaceMatch accepts an io.Reader; both io.ReadWriter and
// *bufio.ReadWriter provide Read.
func ReaderInterfaceMatch(r io.Reader) {}

// ReaderInterfaceMisMatch accepts an io.WriteSeeker; neither field type
// provides Seek.
func ReaderInterfaceMisMatch(r io.WriteSeeker) {}

// Variadic input test.

var ttVariadicMatch = []struct {
	n []int
}{}

// VariadicMatch takes ...int, which the []int field n represents.
func VariadicMatch(n ...int) {}

// VariadicMisMatch takes a single int rather than a variadic one.
func VariadicMisMatch(n int) {}

// VariadicMisMatchType is variadic but over string instead of int.
func VariadicMisMatchType(n ...string) {}

// Struct contains functions that returns the type.

var ttStructFunctionMatch = []struct {
	in func() bool
}{}

// StructFunctionMatch takes a bool, which the func() bool field provides.
func StructFunctionMatch(f bool) {}

// StructFunctionMisMatch takes an int, which func() bool does not provide.
func StructFunctionMisMatch(f int) {}

// Method tests, with some map and struct inputs.

type MethodTypeMatch int

// MethodValueMatch has a value receiver, so it is in the method set of both
// MethodTypeMatch and *MethodTypeMatch.
func (m MethodTypeMatch) MethodValueMatch(in map[string]string) {}

// MethodPointerMatch has a pointer receiver, so only *MethodTypeMatch has it.
func (m *MethodTypeMatch) MethodPointerMatch(in struct{}) {}

var ttMethodTypeMatch_MethodValueMatch = []struct {
	m  MethodTypeMatch
	in map[string]string
}{}

var ttMethodTypeMatch_MethodValueMatch_Pointer = []struct {
	m  *MethodTypeMatch
	in map[string]string
}{}

var ttMethodTypeMatch_MethodPointerMatch = []struct {
	m  *MethodTypeMatch
	in struct{}
}{}

// This should not match as the type MethodTypeMatch won't have the method
// pointer match in its method list.
var ttMethodTypeMatch_MethodPointerMisMatch = []struct {
	m  MethodTypeMatch
	in struct{}
}{}

--------------------------------------------------------------------------------
/testdata/s/s.go:
--------------------------------------------------------------------------------
// Package s contains various declarations of struct arrays for testing purposes.
package s

import (
	"io"
)

// StructArray is a populated struct slice.
var StructArray = []struct {
	a, b int
	c    string
}{
	{1, 2, "boo"},
	{3, 4, "shoe"},
}

// emptyStructArray is an empty struct slice with interface and func fields.
var emptyStructArray = []struct {
	a, b int
	r    io.Reader
	f    func() error
}{}

// NotStructArray is a slice whose elements are not structs.
var NotStructArray = []int{1, 2, 3}

// notArray is not a slice at all.
var notArray int = 1

--------------------------------------------------------------------------------
/testdata/x/x.go:
--------------------------------------------------------------------------------
package x

import (
	"fmt"
)

// ExportedFunction is the target of ttExportedFunction in x_pass_test.go.
func ExportedFunction(a int, b string) error {
	return nil
}

// ExportedFailFunction is the target of ttExportedFailFunction in
// x_fail_test.go, whose first field is a rune instead of an int.
func ExportedFailFunction(a int, b string) error {
	return nil
}

func nonExportedFunction(a int, b string) error {
	return nil
}

var ExportedVar int = 1

var nonExportedVar int = 1

type ExportedType struct{}

type nonExportedType struct{}

// ExportedMethod is the target of ttExportedType_ExportedMethod in
// x_pass_test.go.
func (e ExportedType) ExportedMethod(a string, b func() bool, c ...int) error {
	return fmt.Errorf("nop")
}

// ExportedFailMethod is the target of ttExportedType_ExportedFailMethod in
// x_fail_test.go, whose b field returns int instead of bool.
func (e ExportedType) ExportedFailMethod(a string, b func() bool, c ...int) error {
	return fmt.Errorf("nop")
}

func (e ExportedType) nonExportedMethod(a, b int) error {
	return fmt.Errorf("nop")
}

--------------------------------------------------------------------------------
/testdata/x/x_fail_test.go:
--------------------------------------------------------------------------------
// This file contains declarations that should FAIL declaration validation.
package x

var ttExportedFailFunction = []struct {
	a rune // should be int
	b string
	e error
}{
	{1, "beep", nil},
	{2, "bop", nil},
}

var ttExportedType_ExportedFailMethod = []struct {
	r ExportedType
	a string
	b func() int // should return a bool
	c []int
	e error
}{}

--------------------------------------------------------------------------------
/testdata/x/x_pass_test.go:
--------------------------------------------------------------------------------
// This file contains declarations that should PASS declaration validation.
package x

import (
	"testing"
)

var ttExportedFunction = []struct {
	a int
	b string
	e error
}{
	{1, "beep", nil},
	{2, "bop", nil},
}

// TestTTExportedFunction OLD DOC should be replaced.
func TestTTExportedFunction(t *testing.T) {}

var ttExportedType_ExportedMethod = []struct {
	r *ExportedType // pointer is always allowed
	a string
	b func() bool
	c []int
	e error
}{}

--------------------------------------------------------------------------------
/util_test.go:
--------------------------------------------------------------------------------
package main

import (
	"bytes"
	"errors"
	"go/ast"
	"io"
	"io/ioutil"
	"os"
	"path/filepath"
	"testing"

	"github.com/emil2k/tab/lib/diff"
)

// testFiles compares the contents of two files specified by the passed paths.
// The test fails immediately if a file cannot be read; if the contents differ
// it reports each removed/inserted span as a test error.
// The first file is considered the "got" value while the second is considered
// the "expected" value.
20 | func testFiles(t *testing.T, got, expected string) { 21 | check := func(err error) { 22 | if err != nil { 23 | t.Error(err.Error()) 24 | t.FailNow() 25 | } 26 | } 27 | gotF, err := os.Open(got) 28 | check(err) 29 | defer gotF.Close() 30 | expectedF, err := os.Open(expected) 31 | check(err) 32 | defer expectedF.Close() 33 | gotContent, err := ioutil.ReadAll(gotF) 34 | check(err) 35 | expectedContent, err := ioutil.ReadAll(expectedF) 36 | check(err) 37 | if !bytes.Equal(gotContent, expectedContent) { 38 | changes := diff.Bytes(gotContent, expectedContent) 39 | for _, c := range changes { 40 | if c.Del > 0 { 41 | t.Errorf("diff found, removed :\n%s\n", 42 | gotContent[c.A:c.A+c.Del]) 43 | } 44 | if c.Ins > 0 { 45 | t.Errorf("diff found, inserted :\n%s\n", 46 | expectedContent[c.B:c.B+c.Ins]) 47 | } 48 | } 49 | } 50 | } 51 | 52 | // getTestPkg attempt to get a package, in case of errors it fails and 53 | // terminates the test. 54 | func getTestPkg(t *testing.T, dir, pkgName string) *ast.Package { 55 | pkg, err := getPkg(dir, pkgName) 56 | if err != nil { 57 | t.Errorf("error when getting test package %s from %s : %s\n", 58 | pkgName, dir, err.Error()) 59 | t.FailNow() 60 | } 61 | return pkg 62 | } 63 | 64 | // getTestDirCopy creates a temporary directory copies the src directory, 65 | // returns path to temporary directory. 66 | // In case of errors it immediately fails the test with a proper message. 67 | func getTestDir(t *testing.T, src string) string { 68 | temp, err := ioutil.TempDir("", "tabtest") 69 | if err != nil { 70 | t.Error("error while creating temp dir :", err.Error()) 71 | t.FailNow() 72 | } 73 | err = copyDir(src, temp) 74 | if err != nil { 75 | t.Error("error while copying test directory :", err.Error()) 76 | t.FailNow() 77 | } 78 | return temp 79 | } 80 | 81 | // copyFileJob holds a pending copyFile call. 
82 | type copyFileJob struct { 83 | si os.FileInfo 84 | src, dst string 85 | } 86 | 87 | // copyDir recursively copies the src directory to the desination directory. 88 | // Creates directories as necessary. Attempts to chmod everything to the src 89 | // mode. 90 | func copyDir(src, dst string) error { 91 | // First compile a list of copies to execute then execute, otherwise 92 | // infinite copy situations could arise when copying a parent directory 93 | // into a child directory. 94 | cjs := make([]copyFileJob, 0) 95 | walk := func(path string, info os.FileInfo, err error) error { 96 | if err != nil { 97 | return err 98 | } 99 | rel, err := filepath.Rel(src, path) 100 | if err != nil { 101 | return err 102 | } 103 | fileDst := filepath.Join(dst, rel) 104 | cjs = append(cjs, copyFileJob{info, path, fileDst}) 105 | return nil 106 | } 107 | if err := filepath.Walk(src, walk); err != nil { 108 | return err 109 | } 110 | // Execute copies 111 | for _, cj := range cjs { 112 | if err := copyFile(cj.si, cj.src, cj.dst); err != nil { 113 | return err 114 | } 115 | } 116 | return nil 117 | } 118 | 119 | // ErrIrregularFile is returned when attempts are made to copy links, pipes, 120 | // devices, and etc. 121 | var ErrIrregularFile = errors.New("non regular file") 122 | 123 | // copyFile copies a file or directory from src to dst. Creates directories as 124 | // necessary. Attempts to chmod to the src mode. Returns an error if the file 125 | // is src file is irregular, i.e. link, pipe, or device. 
126 | func copyFile(si os.FileInfo, src, dst string) (err error) { 127 | switch { 128 | case si.Mode().IsDir(): 129 | return os.MkdirAll(dst, si.Mode()) 130 | case si.Mode().IsRegular(): 131 | closeErr := func(f *os.File) { 132 | // Properly return a close error 133 | if cerr := f.Close(); err == nil { 134 | err = cerr 135 | } 136 | } 137 | sf, err := os.Open(src) 138 | if err != nil { 139 | return err 140 | } 141 | defer closeErr(sf) 142 | df, err := os.Create(dst) 143 | if err != nil { 144 | return err 145 | } 146 | defer closeErr(df) 147 | // Copy contents 148 | if _, err = io.Copy(df, sf); err != nil { 149 | return err 150 | } else if err = df.Sync(); err != nil { 151 | return err 152 | } else { 153 | return df.Chmod(si.Mode()) 154 | } 155 | default: 156 | return ErrIrregularFile 157 | } 158 | } 159 | --------------------------------------------------------------------------------