├── .gitignore ├── _testdata ├── go.mod └── foo.go ├── error.go ├── go.mod ├── doc.go ├── errf_test.go ├── debug.go ├── types.go ├── debug_test.go ├── .github └── workflows │ └── go.yml ├── LICENSE ├── go.sum ├── cmd └── decouple │ ├── decouple_test.go │ └── main.go ├── types_test.go ├── Readme.md ├── decouple_test.go └── decouple.go /.gitignore: -------------------------------------------------------------------------------- 1 | /decouple 2 | /cover.out 3 | -------------------------------------------------------------------------------- /_testdata/go.mod: -------------------------------------------------------------------------------- 1 | module github.com/bobg/decouple/testdata 2 | 3 | go 1.23 4 | -------------------------------------------------------------------------------- /error.go: -------------------------------------------------------------------------------- 1 | package decouple 2 | 3 | import "fmt" 4 | 5 | type derr struct { 6 | error 7 | } 8 | 9 | func errf(format string, args ...any) error { 10 | return derr{error: fmt.Errorf(format, args...)} 11 | } 12 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/bobg/decouple 2 | 3 | go 1.23 4 | 5 | toolchain go1.24.0 6 | 7 | require ( 8 | github.com/bobg/errors v1.1.0 9 | github.com/bobg/go-generics/v4 v4.2.0 10 | github.com/bobg/seqs v1.8.0 11 | golang.org/x/tools v0.30.0 12 | ) 13 | 14 | require ( 15 | golang.org/x/mod v0.23.0 // indirect 16 | golang.org/x/sync v0.11.0 // indirect 17 | ) 18 | -------------------------------------------------------------------------------- /doc.go: -------------------------------------------------------------------------------- 1 | // Decouple analyzes Go packages to find overspecified function parameters. 
2 | // If your function takes a *os.File for example, 3 | // but only ever calls Read on it, 4 | // the function can be rewritten to take an io.Reader. 5 | // This generalizes the function, 6 | // making it easier to test 7 | // and decoupling it from whatever the source of the *os.File is. 8 | package decouple 9 | -------------------------------------------------------------------------------- /errf_test.go: -------------------------------------------------------------------------------- 1 | package decouple 2 | 3 | import ( 4 | "errors" 5 | "testing" 6 | ) 7 | 8 | func TestErrf(t *testing.T) { 9 | got := errf("What's a %d?", 412) 10 | 11 | var d derr 12 | if !errors.As(got, &d) { 13 | t.Errorf("got %v, want derr", got) 14 | } 15 | 16 | const want = "What's a 412?" 17 | if got.Error() != want { 18 | t.Errorf("got %s, want %s", got, want) 19 | } 20 | } 21 | -------------------------------------------------------------------------------- /debug.go: -------------------------------------------------------------------------------- 1 | package decouple 2 | 3 | import ( 4 | "fmt" 5 | "os" 6 | "strings" 7 | ) 8 | 9 | func (a *analyzer) debugf(format string, args ...any) { 10 | if !a.debug { 11 | return 12 | } 13 | s := fmt.Sprintf(format, args...) 14 | s = strings.TrimRight(s, "\r\n") 15 | if a.level > 0 { 16 | fmt.Fprint(os.Stderr, strings.Repeat(" ", a.level)) 17 | } 18 | fmt.Fprintln(os.Stderr, s) 19 | } 20 | -------------------------------------------------------------------------------- /types.go: -------------------------------------------------------------------------------- 1 | package decouple 2 | 3 | import "go/types" 4 | 5 | func getType[T types.Type](typ types.Type) T { 6 | switch typ := typ.(type) { 7 | case T: 8 | return typ 9 | case *types.Alias: 10 | return getType[T](typ.Rhs()) 11 | case *types.Named: 12 | return getType[T](typ.Underlying()) 13 | default: 14 | return zero[T]() 15 | } 16 | } 17 | 18 | // Returns the zero value for any type. 
19 | func zero[T any]() (res T) {
20 | 	return
21 | }
22 | 
--------------------------------------------------------------------------------
/debug_test.go:
--------------------------------------------------------------------------------
1 | package decouple
2 | 
3 | import (
4 | 	"os"
5 | 	"testing"
6 | )
7 | 
8 | // TestDebugf checks that debugf indents its message by a.level spaces
9 | // and ends it with a single newline.
10 | func TestDebugf(t *testing.T) {
11 | 	a := analyzer{debug: true, level: 2}
12 | 
13 | 	// debugf writes to os.Stderr, so point os.Stderr at a temp file while it runs.
14 | 	// The t.TempDir directory (and the file in it) is removed when the test finishes.
15 | 	f, err := os.CreateTemp(t.TempDir(), "decouple")
16 | 	if err != nil {
17 | 		t.Fatal(err)
18 | 	}
19 | 	tmpname := f.Name()
20 | 
21 | 	oldStderr := os.Stderr
22 | 	os.Stderr = f
23 | 	defer func() { os.Stderr = oldStderr }()
24 | 
25 | 	a.debugf("What's a %d?", 412)
26 | 
27 | 	// Close exactly once, before reading the file back.
28 | 	if err := f.Close(); err != nil {
29 | 		t.Fatal(err)
30 | 	}
31 | 
32 | 	got, err := os.ReadFile(tmpname)
33 | 	if err != nil {
34 | 		t.Fatal(err)
35 | 	}
36 | 
37 | 	const want = "  What's a 412?\n"
38 | 	if string(got) != want {
39 | 		// %q makes leading-space and newline differences visible.
40 | 		t.Errorf("got %q, want %q", got, want)
41 | 	}
42 | }
43 | 
--------------------------------------------------------------------------------
/.github/workflows/go.yml:
--------------------------------------------------------------------------------
1 | name: Tests
2 | 
3 | on:
4 |   push:
5 |     branches: [main]
6 |   pull_request:
7 |     branches: [main]
8 | 
9 | jobs:
10 |   test:
11 |     runs-on: ubuntu-latest
12 |     steps:
13 |       - name: Checkout
14 |         uses: actions/checkout@v4
15 |         with:
16 |           fetch-depth: 0
17 | 
18 |       - name: Set up Go
19 |         uses: actions/setup-go@v4
20 |         with:
21 |           go-version-file: go.mod
22 | 
23 |       - name: Unit tests
24 |         run: go test -coverprofile=cover.out ./...
25 | 26 | - name: Send coverage 27 | uses: shogo82148/actions-goveralls@v1 28 | with: 29 | path-to-profile: cover.out 30 | continue-on-error: true 31 | 32 | - name: Modver 33 | if: ${{ github.event_name == 'pull_request' }} 34 | uses: bobg/modver@v2.12.2 35 | with: 36 | github_token: ${{ secrets.GITHUB_TOKEN }} 37 | pull_request_url: https://github.com/${{ github.repository }}/pull/${{ github.event.number }} 38 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Bob Glickstein 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | 
--------------------------------------------------------------------------------
/go.sum:
--------------------------------------------------------------------------------
1 | github.com/bobg/errors v1.1.0 h1:gsVanPzJMpZQpwY+27/GQYElZez5CuMYwiIpk2A3RGw=
2 | github.com/bobg/errors v1.1.0/go.mod h1:Q4775qBZpnte7EGFJqmvnlB1U4pkI1XmU3qxqdp7Zcc=
3 | github.com/bobg/go-generics/v4 v4.2.0 h1:c3eX8rlFCRrxFnUepwQIA174JK7WuckbdRHf5ARCl7w=
4 | github.com/bobg/go-generics/v4 v4.2.0/go.mod h1:KVwpxEYErjvcqjJSJqVNZd/JEq3SsQzb9t01+82pZGw=
5 | github.com/bobg/seqs v1.8.0 h1:UfmjNlR3PIWfu+7ok4oVuekzlXWDL12g99fyMngNuT4=
6 | github.com/bobg/seqs v1.8.0/go.mod h1:Iw4ESqX24EovuZ+0UHrnPmHYK1UyO9jcAZpPIzlNMa0=
7 | github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
8 | github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
9 | github.com/mattn/go-sqlite3 v1.14.22 h1:2gZY6PC6kBnID23Tichd1K+Z0oS6nE/XwU+Vz/5o4kU=
10 | github.com/mattn/go-sqlite3 v1.14.22/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y=
11 | golang.org/x/mod v0.23.0 h1:Zb7khfcRGKk+kqfxFaP5tZqCnDZMjC5VtUBs87Hr6QM=
12 | golang.org/x/mod v0.23.0/go.mod h1:6SkKJ3Xj0I0BrPOZoBy3bdMptDDU9oJrpohJ3eWZ1fY=
13 | golang.org/x/sync v0.11.0 h1:GGz8+XQP4FvTTrjZPzNKTMFtSXH80RAzG+5ghFPgK9w=
14 | golang.org/x/sync v0.11.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
15 | golang.org/x/tools v0.30.0 h1:BgcpHewrV5AUp2G9MebG4XPFI1E2W41zU1SaqVA9vJY=
16 | golang.org/x/tools v0.30.0/go.mod h1:c347cR/OJfw5TI+GfX7RUPNMdDRRbjvYTS0jPyvsVtY=
17 | 
--------------------------------------------------------------------------------
/cmd/decouple/decouple_test.go:
--------------------------------------------------------------------------------
1 | package main
2 | 
3 | import (
4 | 	"bytes"
5 | 	"encoding/json"
6 | 	"path/filepath"
7 | 	"reflect"
8 | 	"slices"
9 | 	"strings"
10 | 	"testing"
11 | 
12 | 	"github.com/bobg/seqs"
13 | )
14 | 
15 | func TestRunJSON(t *testing.T) {
16 | 	buf := new(bytes.Buffer)
17 | 	if err := run2(buf, false, true, false, []string{"../.."}); err != nil {
18 | 		t.Fatal(err)
19 | 	}
20 | 
21 | 	var (
22 | 		got []jtuple
23 | 		dec = json.NewDecoder(buf)
24 | 	)
25 | 	for dec.More() {
26 | 		var val jtuple
27 | 		if err := dec.Decode(&val); err != nil {
28 | 			t.Fatal(err)
29 | 		}
30 | 		val.FileName = filepath.Base(val.FileName)
31 | 		got = append(got, val)
32 | 	}
33 | 
34 | 	want := []jtuple{{
35 | 		PackageName: "main",
36 | 		FileName:    "main.go",
37 | 		Line:        113,
38 | 		Column:      6,
39 | 		FuncName:    "showJSON",
40 | 		Params: []jparam{{
41 | 			Name: "checker",
42 | 			Methods: []string{
43 | 				"NameForMethods",
44 | 			},
45 | 		}},
46 | 	}}
47 | 
48 | 	if !reflect.DeepEqual(got, want) {
49 | 		t.Errorf("got %v, want %v", got, want)
50 | 	}
51 | }
52 | 
53 | func TestRunPlain(t *testing.T) {
54 | 	buf := new(bytes.Buffer)
55 | 	if err := run2(buf, false, false, false, []string{"../.."}); err != nil {
56 | 		t.Fatal(err)
57 | 	}
58 | 
59 | 	linesSeq, errptr := seqs.Lines(buf)
60 | 	lines := slices.Collect(linesSeq)
61 | 	if err := *errptr; err != nil {
62 | 		t.Fatal(err)
63 | 	}
64 | 
65 | 	if len(lines) != 2 {
66 | 		t.Fatalf("got %d lines, want 2", len(lines))
67 | 	}
68 | 	if !strings.HasSuffix(lines[0], ": showJSON") {
69 | 		t.Fatalf(`line 1 is "%s", want something ending in ": showJSON"`, lines[0])
70 | 	}
71 | 
72 | 	lines[1] = strings.TrimSpace(lines[1])
73 | 	const want = "checker: [NameForMethods]"
74 | 	if lines[1] != want {
75 | 		t.Fatalf(`line 2 is "%s", want "%s"`, lines[1], want)
76 | 	}
77 | }
78 | 
--------------------------------------------------------------------------------
/types_test.go:
--------------------------------------------------------------------------------
1 | package decouple
2 | 
3 | import (
4 | 	"fmt"
5 | 	"go/types"
6 | 	"testing"
7 | )
8 | 
9 | // typeConstraint admits types.Type implementations that checkType can compare with ==.
10 | type typeConstraint interface {
11 | 	types.Type
12 | 	comparable
13 | }
14 | 
15 | // TestGetType checks getType on several kinds of type, each given directly,
16 | // wrapped in a named type, and wrapped in an alias, since getType follows
17 | // both *types.Named and *types.Alias to the type underneath.
18 | func TestGetType(t *testing.T) {
19 | 	cases := []struct {
20 | 		t                            types.Type
21 | 		isChan, isSig, isIntf, isMap bool
22 | 	}{{
23 | 		t: types.NewStruct(nil, nil),
24 | 	}, {
25 | 		t:      types.NewChan(types.SendRecv, types.NewStruct(nil, nil)),
26 | 		isChan: true,
27 | 	}, {
28 | 		t:     types.NewSignatureType(nil, nil, nil, nil, nil, false),
29 | 		isSig: true,
30 | 	}, {
31 | 		t:      types.NewInterfaceType(nil, nil),
32 | 		isIntf: true,
33 | 	}, {
34 | 		t:     types.NewMap(types.NewStruct(nil, nil), types.NewStruct(nil, nil)),
35 | 		isMap: true,
36 | 	}}
37 | 
38 | 	for i, tc := range cases {
39 | 		t.Run(fmt.Sprintf("case_%02d", i+1), func(t *testing.T) {
40 | 			inputs := []types.Type{
41 | 				tc.t,
42 | 				types.NewNamed(types.NewTypeName(0, nil, "foo", nil), tc.t, nil),
43 | 				// Exercises getType's *types.Alias branch.
44 | 				types.NewAlias(types.NewTypeName(0, nil, "bar", nil), tc.t),
45 | 			}
46 | 			for _, inp := range inputs {
47 | 				checkType[*types.Chan](t, inp, tc.isChan)
48 | 				checkType[*types.Signature](t, inp, tc.isSig)
49 | 				checkType[*types.Interface](t, inp, tc.isIntf)
50 | 				checkType[*types.Map](t, inp, tc.isMap)
51 | 			}
52 | 		})
53 | 	}
54 | }
55 | 
56 | // checkType reports a test error unless getType[T](inp) yields a non-zero T exactly when isType is true.
57 | func checkType[T typeConstraint](t *testing.T, inp types.Type, isType bool) {
58 | 	t.Helper()
59 | 
60 | 	var zero T
61 | 
62 | 	got := getType[T](inp)
63 | 	if (got != zero) != isType {
64 | 		t.Errorf("is-type[%T] is %v, want %v", zero, got != zero, isType)
65 | 	}
66 | }
67 | 
--------------------------------------------------------------------------------
/cmd/decouple/main.go:
--------------------------------------------------------------------------------
1 | package main
2 | 
3 | import (
4 | 	"encoding/json"
5 | 	"flag"
6 | 	"fmt"
7 | 	"io"
8 | 	"maps"
9 | 	"os"
10 | 	"slices"
11 | 	"sort"
12 | 	"strings"
13 | 
14 | 	"github.com/bobg/errors"
15 | 
16 | 	"github.com/bobg/decouple"
17 | )
18 | 
19 | func main() {
20 | 	if err := run(); err != nil {
21 | 		fmt.Fprintf(os.Stderr, "Error: %s\n", err)
22 | 		os.Exit(1)
23 | 	}
24 | }
25 | 
26 | // run parses the command-line flags and hands them, with the remaining arguments, to run2.
27 | func run() error {
28 | 	var (
29 | 		deprecated bool
30 | 		doJSON     bool
31 | 		verbose    bool
32 | 	)
33 | 	flag.BoolVar(&deprecated, "deprecated", false, "include deprecated functions in the analysis")
34 | 	flag.BoolVar(&doJSON, "json", false, "output in JSON format")
35 | 	flag.BoolVar(&verbose, "v", false, "verbose")
36 | 	flag.Parse()
37 | 	return run2(os.Stdout, deprecated, doJSON, verbose, flag.Args())
38 | }
39 | 
40 | // run2 analyzes the Go packages rooted at args[0] (default ".") and writes a report to w:
41 | // JSON via showJSON when doJSON is set, otherwise plain text. More than one arg is a usage error.
42 | func run2(w io.Writer, deprecated, doJSON, verbose bool, args []string) error {
43 | 	if len(args) > 1 {
44 | 		return fmt.Errorf("usage: %s [-deprecated] [-v] [-json] [DIR]", os.Args[0])
45 | 	}
46 | 	dir := "."
47 | 	if len(args) == 1 {
48 | 		dir = args[0]
49 | 	}
50 | 
51 | 	checker, err := decouple.NewCheckerFromDir(dir)
52 | 	if err != nil {
53 | 		return errors.Wrapf(err, "creating checker for %s", dir)
54 | 	}
55 | 	checker.Deprecated = deprecated
56 | 	checker.Verbose = verbose
57 | 
58 | 	tuples, err := checker.Check()
59 | 	if err != nil {
60 | 		return errors.Wrapf(err, "checking %s", dir)
61 | 	}
62 | 
63 | 	// Order findings by file name, then by position within the file.
64 | 	sort.Slice(tuples, func(i, j int) bool {
65 | 		iPos, jPos := tuples[i].Pos(), tuples[j].Pos()
66 | 		if iPos.Filename != jPos.Filename {
67 | 			return iPos.Filename < jPos.Filename
68 | 		}
69 | 		return iPos.Offset < jPos.Offset
70 | 	})
71 | 
72 | 	if doJSON {
73 | 		return errors.Wrap(showJSON(w, checker, tuples), "formatting JSON output")
74 | 	}
75 | 
76 | 	for _, tuple := range tuples {
77 | 		var showedFuncName bool
78 | 
79 | 		for _, param := range slices.Sorted(maps.Keys(tuple.M)) {
80 | 			mm := tuple.M[param]
81 | 			if len(mm) == 0 {
82 | 				continue
83 | 			}
84 | 
85 | 			// Print the function's position and name once, before its first reported parameter.
86 | 			if !showedFuncName {
87 | 				fmt.Fprintf(w, "%s: %s\n", tuple.Pos(), tuple.F.Name.Name)
88 | 				showedFuncName = true
89 | 			}
90 | 
91 | 			// If NameForMethods finds an existing interface for this method set, report it by name.
92 | 			if pkg, intfName := checker.NameForMethods(mm); intfName != "" {
93 | 				pkgpath := pkg.PkgPath
94 | 				// Quote package paths containing "." or "/", e.g. "github.com/bobg/decouple".
95 | 				if strings.ContainsAny(pkgpath, "./") {
96 | 					pkgpath = fmt.Sprintf(`"%s"`, pkgpath)
97 | 				}
98 | 				fmt.Fprintf(w, "    %s: %s.%s\n", param, pkgpath, intfName)
99 | 				continue
100 | 			}
101 | 
102 | 			// Otherwise list the method names the parameter actually uses.
103 | 			fmt.Fprintf(w, "    %s: %v\n", param, slices.Sorted(maps.Keys(mm)))
104 | 		}
105 | 	}
106 | 
107 | 	return nil
108 | }
109 | 
110 | // showJSON writes to w one JSON object (a jtuple) for each tuple
111 | // that has at least one parameter with a non-empty method set.
112 | // Parameters appear in name order, each with its methods sorted.
113 | func showJSON(w io.Writer, checker decouple.Checker, tuples []decouple.Tuple) error {
114 | 	enc := json.NewEncoder(w)
115 | 	enc.SetIndent("", "  ")
116 | 
117 | 	for _, tuple := range tuples {
118 | 		p := tuple.Pos()
119 | 		jt := jtuple{
120 | 			PackageName: tuple.P.Name,
121 | 			FileName:    p.Filename,
122 | 			Line:        p.Line,
123 | 			Column:      p.Column,
124 | 			FuncName:    tuple.F.Name.Name,
125 | 		}
126 | 		// Visit parameters in name order so jt.Params comes out sorted.
127 | 		for _, param := range slices.Sorted(maps.Keys(tuple.M)) {
128 | 			mm := tuple.M[param]
129 | 			if len(mm) == 0 {
130 | 				continue
131 | 			}
132 | 			jp := jparam{
133 | 				Name:    param,
134 | 				Methods: slices.Sorted(maps.Keys(mm)),
135 | 			}
136 | 			if pkg, intfName := checker.NameForMethods(mm); intfName != "" {
137 | 				jp.InterfacePkg = pkg.PkgPath
138 | 				jp.InterfaceName = intfName
139 | 			}
140 | 			jt.Params = append(jt.Params, jp)
141 | 		}
142 | 		// Skip functions with nothing to report.
143 | 		if len(jt.Params) == 0 {
144 | 			continue
145 | 		}
146 | 		if err := enc.Encode(jt); err != nil {
147 | 			return err
148 | 		}
149 | 	}
150 | 
151 | 	return nil
152 | }
153 | 
154 | // jtuple is the JSON form of one function's findings, as written by showJSON.
155 | type jtuple struct {
156 | 	PackageName  string
157 | 	FileName     string
158 | 	Line, Column int
159 | 	FuncName     string
160 | 	Params       []jparam
161 | }
162 | 
163 | // jparam is the JSON form of one parameter's findings: the methods it uses and,
164 | // when NameForMethods found one, the package path and name of a matching existing interface.
165 | type jparam struct {
166 | 	Name          string
167 | 	Methods       []string `json:",omitempty"`
168 | 	InterfacePkg  string   `json:",omitempty"`
169 | 	InterfaceName string   `json:",omitempty"`
170 | }
171 | 
--------------------------------------------------------------------------------
/Readme.md:
--------------------------------------------------------------------------------
1 | # Decouple - find overspecified function parameters in Go code
2 | 
3 | [![Go
Reference](https://pkg.go.dev/badge/github.com/bobg/decouple.svg)](https://pkg.go.dev/github.com/bobg/decouple) 4 | [![Go Report Card](https://goreportcard.com/badge/github.com/bobg/decouple)](https://goreportcard.com/report/github.com/bobg/decouple) 5 | [![Tests](https://github.com/bobg/decouple/actions/workflows/go.yml/badge.svg)](https://github.com/bobg/decouple/actions/workflows/go.yml) 6 | [![Coverage Status](https://coveralls.io/repos/github/bobg/decouple/badge.svg?branch=main)](https://coveralls.io/github/bobg/decouple?branch=main) 7 | [![Mentioned in Awesome Go](https://awesome.re/mentioned-badge.svg)](https://github.com/avelino/awesome-go) 8 | 9 | This is decouple, 10 | a Go package and command that analyzes your Go code 11 | to find “overspecified” function parameters. 12 | 13 | A parameter is overspecified, 14 | and eligible for “decoupling,” 15 | if it has a more-specific type than it actually needs. 16 | 17 | For example, 18 | if your function takes a `*os.File` parameter, 19 | but it’s only ever used for its `Read` method, 20 | it could be specified as an abstract `io.Reader` instead. 21 | 22 | ## Why decouple? 23 | 24 | When you decouple a function parameter from its too-specific type, 25 | you broaden the set of values on which it can operate. 26 | 27 | You also make it easier to test. 28 | For a simple example, 29 | suppose you’re testing this function: 30 | 31 | ```go 32 | func CountLines(f *os.File) (int, error) { 33 | var result int 34 | sc := bufio.NewScanner(f) 35 | for sc.Scan() { 36 | result++ 37 | } 38 | return result, sc.Err() 39 | } 40 | ``` 41 | 42 | Your unit test will need to open a testdata file and pass it to this function to get a result. 43 | But as `decouple` can tell you, 44 | `f` is only ever used as an `io.Reader` 45 | (the type of the argument to [bufio.NewScanner](https://pkg.go.dev/bufio#NewScanner)). 
46 | 
47 | If you were testing `func CountLines(r io.Reader) (int, error)` instead,
48 | the unit test can simply pass it something like `strings.NewReader("a\nb\nc")`.
49 | 
50 | ## Installation
51 | 
52 | ```sh
53 | go install github.com/bobg/decouple/cmd/decouple@latest
54 | ```
55 | 
56 | ## Usage
57 | 
58 | ```sh
59 | decouple [-deprecated] [-v] [-json] [DIR]
60 | ```
61 | 
62 | This produces a report about the Go packages rooted at DIR
63 | (the current directory by default).
64 | With -v, very verbose debugging output is printed along the way.
65 | With -json, the output is in JSON format.
66 | With -deprecated,
67 | functions marked as deprecated are included in the analysis (by default they are skipped).
68 | 
69 | The report will be empty if decouple has no findings.
70 | Otherwise, it will look something like this (without -json):
71 | 
72 | ```
73 | $ decouple
74 | /home/bobg/kodigcs/handle.go:105:18: handleDir
75 |     req: [Context]
76 |     w: io.Writer
77 | /home/bobg/kodigcs/handle.go:167:18: handleNFO
78 |     req: [Context]
79 |     w: [Header Write]
80 | /home/bobg/kodigcs/handle.go:428:6: isStale
81 |     t: [Before]
82 | /home/bobg/kodigcs/imdb.go:59:6: parseIMDbPage
83 |     cl: [Do]
84 | ```
85 | 
86 | This is the output when running decouple on [commit f4e8cf0](https://github.com/bobg/kodigcs/commit/f4e8cf0e44de0ea98fa7ad4f88705324ff446444)
87 | of [kodigcs](https://github.com/bobg/kodigcs).
88 | It’s saying that:
89 | 
90 | - In the function [handleDir](https://github.com/bobg/kodigcs/blob/f4e8cf0e44de0ea98fa7ad4f88705324ff446444/handle.go#L105),
91 |   the `req` parameter is being used only for its `Context` method
92 |   and so could be declared as `interface{ Context() context.Context }`,
93 |   allowing objects other than `*http.Request` values to be passed in here
94 |   (or, better still, the function could be rewritten to take a `context.Context` parameter instead);
95 | - Also in [handleDir](https://github.com/bobg/kodigcs/blob/f4e8cf0e44de0ea98fa7ad4f88705324ff446444/handle.go#L105),
96 |   `w` could be an `io.Writer`,
97 |   allowing more types to be used than just `http.ResponseWriter`;
98 | - Similarly in [handleNFO](https://github.com/bobg/kodigcs/blob/f4e8cf0e44de0ea98fa7ad4f88705324ff446444/handle.go#L167),
99 |   `req` is used only for its `Context` method,
100 |   and `w` for its `Write` and `Header` methods
101 |   (more than `io.Writer`, but less than `http.ResponseWriter`);
102 | - Anything with a `Before(time.Time) bool` method
103 |   could be used in [isStale](https://github.com/bobg/kodigcs/blob/f4e8cf0e44de0ea98fa7ad4f88705324ff446444/handle.go#L428),
104 |   so its `t` parameter need not be limited to `time.Time`;
105 | - The `*http.Client` argument of [parseIMDbPage](https://github.com/bobg/kodigcs/blob/f4e8cf0e44de0ea98fa7ad4f88705324ff446444/imdb.go#L59)
106 |   is being used only for its `Do` method.
107 | 
108 | Note that,
109 | in the report,
110 | the presence of square brackets means “this is a set of methods,”
111 | while the absence of them means “this is an existing type that already has the right method set”
112 | (as in the `io.Writer` line in the example above).
113 | Decouple can’t always find a suitable existing type even when one exists,
114 | and if two or more types match,
115 | it doesn’t always choose the best one.
116 | 
117 | The same report with `-json` specified looks like this:
118 | 
119 | ```json
120 | {
121 |   "PackageName": "main",
122 |   "FileName": "/home/bobg/kodigcs/handle.go",
123 |   "Line": 105,
124 |   "Column": 18,
125 |   "FuncName": "handleDir",
126 |   "Params": [
127 |     {
128 |       "Name": "req",
129 |       "Methods": [
130 |         "Context"
131 |       ]
132 |     },
133 |     {
134 |       "Name": "w",
135 |       "Methods": [
136 |         "Write"
137 |       ],
138 |       "InterfacePkg": "io",
139 |       "InterfaceName": "Writer"
140 |     }
141 |   ]
142 | }
143 | {
144 |   "PackageName": "main",
145 |   "FileName": "/home/bobg/kodigcs/handle.go",
146 |   "Line": 167,
147 |   "Column": 18,
148 |   "FuncName": "handleNFO",
149 |   "Params": [
150 |     {
151 |       "Name": "req",
152 |       "Methods": [
153 |         "Context"
154 |       ]
155 |     },
156 |     {
157 |       "Name": "w",
158 |       "Methods": [
159 |         "Header",
160 |         "Write"
161 |       ]
162 |     }
163 |   ]
164 | }
165 | {
166 |   "PackageName": "main",
167 |   "FileName": "/home/bobg/kodigcs/handle.go",
168 |   "Line": 428,
169 |   "Column": 6,
170 |   "FuncName": "isStale",
171 |   "Params": [
172 |     {
173 |       "Name": "t",
174 |       "Methods": [
175 |         "Before"
176 |       ]
177 |     }
178 |   ]
179 | }
180 | {
181 |   "PackageName": "main",
182 |   "FileName": "/home/bobg/kodigcs/imdb.go",
183 |   "Line": 59,
184 |   "Column": 6,
185 |   "FuncName": "parseIMDbPage",
186 |   "Params": [
187 |     {
188 |       "Name": "cl",
189 |       "Methods": [
190 |         "Do"
191 |       ]
192 |     }
193 |   ]
194 | }
195 | ```
196 | 
197 | ## Performance note
198 | 
199 | Replacing overspecified function parameters with more-abstract ones,
200 | which this tool helps you to do,
201 | is often but not always the right thing,
202 | and it should not be done blindly.
203 | 
204 | Using Go interfaces can impose an _abstraction penalty_ compared to using concrete types.
205 | Function arguments that could have been on [the stack](https://en.wikipedia.org/wiki/Stack-based_memory_allocation)
206 | may end up in [the heap](https://en.wikipedia.org/wiki/Heap-based_memory_allocation),
207 | and method calls may involve [a virtual-dispatch step](https://en.wikipedia.org/wiki/Dynamic_dispatch).
208 | 
209 | In many cases this penalty is small and can be ignored,
210 | especially since the Go compiler may optimize some or all of it away.
211 | But in tight inner loops
212 | and other performance-critical code
213 | it is often preferable to operate only on concrete types when possible.
214 | 
215 | That said,
216 | avoid the fallacy of [premature optimization](https://wiki.c2.com/?PrematureOptimization).
217 | Write your code for clarity and utility first.
218 | Then sacrifice those for the sake of performance
219 | not in the places where you _think_ they’ll make a difference,
220 | but in the places where you’ve _measured_ that they’re needed.
221 | 
--------------------------------------------------------------------------------
/_testdata/foo.go:
--------------------------------------------------------------------------------
1 | package m
2 | 
3 | import (
4 | 	"context"
5 | 	"fmt"
6 | 	"io"
7 | 	"os"
8 | )
9 | 
10 | // This named type should never be suggested,
11 | // since it is superseded by an identical one in the stdlib.
12 | type JankyReader interface {
13 | 	Read([]byte) (int, error)
14 | }
15 | 
16 | // {"r": {"Read": "func([]byte) (int, error)"}}
17 | // {"r": "io.Reader"}
18 | func F1(r *os.File, n int) ([]byte, error) {
19 | 	if true { // This exercises the *ast.BlockStmt typeswitch clause.
20 | 		buf := make([]byte, n)
21 | 		n, err := r.Read(buf)
22 | 		return buf[:n], err
23 | 	}
24 | 	return nil, nil
25 | }
26 | 
27 | // {"r": {"Read": "func([]byte) (int, error)"}}
28 | func F2(r *os.File) ([]byte, error) {
29 | 	return io.ReadAll(r)
30 | }
31 | 
32 | // {}
33 | func F3(lf *io.LimitedReader) ([]byte, int64, error) {
34 | 	b, err := io.ReadAll((lf)) // extra parens sic
35 | 	return b, lf.N, err // lf.N is a struct field, not a method, so lf needs its concrete type
36 | }
37 | 
38 | // {}
39 | func F4(f *os.File) ([]byte, error) {
40 | 	var f2 *os.File = f // Some day perhaps decouple will be clever enough to know that f and f2 can both be io.Readers.
41 | 	return io.ReadAll(f2)
42 | }
43 | 
44 | // {"r": {"Read": "func([]byte) (int, error)"}}
45 | func F5(r *os.File) ([]byte, error) {
46 | 	var f2 io.Reader = r
47 | 	return io.ReadAll(f2)
48 | }
49 | 
50 | // {}
51 | func F6(f *os.File) ([]byte, error) {
52 | 	return F7(f) // F7's parameter is declared *os.File
53 | }
54 | 
55 | // {"rc": {"Close": "func() error", "Read": "func([]byte) (int, error)"}}
56 | // {"rc": "io.ReadCloser"}
57 | func F7(rc *os.File) ([]byte, error) {
58 | 	defer rc.Close()
59 | 	goto LABEL
60 | LABEL:
61 | 	return io.ReadAll(rc)
62 | }
63 | 
64 | type intErface int
65 | 
66 | // {}
67 | func (i intErface) Read([]byte) (int, error) {
68 | 	return 0, nil
69 | }
70 | 
71 | // {"r": {"Read": "func([]byte) (int, error)"}}
72 | func F8(r intErface) ([]byte, error) {
73 | 	return io.ReadAll(r)
74 | }
75 | 
76 | // {}
77 | func F9(i intErface) int {
78 | 	return int(i) + 1 // converting i to int depends on its concrete type
79 | }
80 | 
81 | // {"r": {"Read": "func([]byte) (int, error)"}}
82 | func F10(r *os.File) ([]byte, error) {
83 | 	var r2 io.Reader
84 | 	r2 = r // separate non-defining assignment line sic
85 | 	return io.ReadAll(r2)
86 | }
87 | 
88 | // {"r": {"Read": "func([]byte) (int, error)"}}
89 | func F11(r *os.File) ([]byte, error) {
90 | 	switch r {
91 | 	case r:
92 | 		return io.ReadAll(r)
93 | 	default:
94 | 		return nil, nil
95 | 	}
96 | }
97 | 
98 | // {}
99 | func F12(f *os.File) ([]byte, error) {
100 | 	var f2 os.File
101 | 	switch f2 {
102 | 	case *f: // dereferencing f requires a pointer type
103 | 		return io.ReadAll(f)
104 | 	default:
105 | 		return nil, nil
106 | 	}
107 | }
108 | 
109 | // {"ctx": {"Done": "func() <-chan struct{}"},
110 | //
111 | // "r": {"Read": "func([]byte) (int, error)"}}
112 | func F13(ctx context.Context, ch chan<- io.Reader, r *os.File) {
113 | 	for {
114 | 		select {
115 | 		case <-ctx.Done():
116 | 			return
117 | 		case ch <- r:
118 | 			// do nothing
119 | 		}
120 | 	}
121 | }
122 | 
123 | // {"r": {"Read": "func([]byte) (int, error)"}}
124 | func F14(r *os.File) []io.Reader {
125 | 	return []io.Reader{r}
126 | }
127 | 
128 | type boolErface bool
129 | 
130 | // {}
131 | func (b boolErface) Read([]byte) (int, error) {
132 | 	return 0, nil
133 | }
134 | 
135 | // {}
136 | func F15(b boolErface) ([]byte, error) {
137 | 	switch {
138 | 	case bool(b): // converting b to bool depends on its concrete type
139 | 		return io.ReadAll(b)
140 | 	default:
141 | 		return nil, nil
142 | 	}
143 | }
144 | 
145 | // {}
146 | func F16(b boolErface) ([]byte, error) {
147 | 	switch {
148 | 	case true:
149 | 		if bool(b) { // converting b to bool depends on its concrete type
150 | 			return io.ReadAll(b)
151 | 		}
152 | 	}
153 | 	return nil, nil
154 | }
155 | 
156 | // {"r": {"Read": "func([]byte) (int, error)"}}
157 | func F17(r *os.File) ([]byte, error) {
158 | 	var x io.Reader
159 | 	if r == x {
160 | 		return nil, nil
161 | 	}
162 | 	return io.ReadAll(r)
163 | }
164 | 
165 | // {"r": {"Read": "func([]byte) (int, error)"}}
166 | func F17b(r *os.File) ([]byte, error) {
167 | 	var x io.Reader
168 | 	if x == r {
169 | 		return nil, nil
170 | 	}
171 | 	return io.ReadAll(r)
172 | }
173 | 
174 | // {}
175 | func F18(f *os.File) ([]byte, error) {
176 | 	if f == nil {
177 | 		return nil, nil
178 | 	}
179 | 	return io.ReadAll(f)
180 | }
181 | 
182 | type funcErface func()
183 | 
184 | // {}
185 | func (f funcErface) Read([]byte) (int, error) {
186 | 	return 0, nil
187 | }
188 | 
189 | // {}
190 | func F19(f funcErface) ([]byte, error) {
191 | 	f() // calling f depends on its func type
192 | 	return io.ReadAll(f)
193 | }
194 | 
195 | // {"r": {"Read": "func([]byte) (int, error)"}}
196 | func F20(r *os.File) func([]byte) (int, error) {
197 | 	return r.Read
198 | }
199 | 
200 | //
{}
201 | func F21(f *os.File) map[*os.File]int {
202 | 	return map[*os.File]int{f: 0} // f is used as a key of a map[*os.File]int
203 | }
204 | 
205 | // {"rc": {"Close": "func() error", "Read": "func([]byte) (int, error)"}}
206 | func F22(rc *os.File) map[io.ReadCloser]int {
207 | 	return map[io.ReadCloser]int{rc: 0}
208 | }
209 | 
210 | // {}
211 | func F23(f *os.File) *os.File {
212 | 	return f // the declared result type is *os.File
213 | }
214 | 
215 | // {"rc": {"Close": "func() error", "Read": "func([]byte) (int, error)"}}
216 | func F24(rc *os.File) io.ReadCloser {
217 | 	return rc
218 | }
219 | 
220 | // {"r": {"Read": "func([]byte) (int, error)"}}
221 | func F25(r *os.File) ([]byte, error) {
222 | 	return func() ([]byte, error) {
223 | 		return io.ReadAll(r)
224 | 	}()
225 | }
226 | 
227 | // {}
228 | func F26(f *os.File) io.Reader {
229 | 	return func() *os.File {
230 | 		return f // the closure's result type is *os.File
231 | 	}()
232 | }
233 | 
234 | // {"r": {"Read": "func([]byte) (int, error)"}}
235 | func F27(r *os.File) (data []byte, err error) {
236 | 	ch := make(chan struct{})
237 | 	go func() {
238 | 		data, err = io.ReadAll(r)
239 | 		close(ch)
240 | 	}()
241 | 	<-ch
242 | 	return
243 | }
244 | 
245 | // {"r": {"Read": "func([]byte) (int, error)"}}
246 | func F28(r *os.File) map[int]io.Reader {
247 | 	return map[int]io.Reader{7: r}
248 | }
249 | 
250 | // {"r": {"Read": "func([]byte) (int, error)"}}
251 | func F29(r io.ReadCloser) ([]byte, error) {
252 | 	return io.ReadAll(r)
253 | }
254 | 
255 | // {}
256 | func F30(x io.ReadCloser) ([]byte, error) {
257 | 	defer x.Close() // together with the read below, x uses every method of io.ReadCloser
258 | 	return io.ReadAll(x)
259 | }
260 | 
261 | // {"r": {"Read": "func([]byte) (int, error)"}}
262 | func F31(r *os.File) io.Reader {
263 | 	x := []io.Reader{r}
264 | 	return x[0]
265 | }
266 | 
267 | // {}
268 | func F32(_ io.Reader) {}
269 | 
270 | // {}
271 | func F33(ch <-chan *os.File) ([]byte, error) {
272 | 	r := <-ch
273 | 	return io.ReadAll(r)
274 | }
275 | 
276 | // {}
277 | func F34(r *os.File, ch chan<- *os.File) ([]byte, error) {
278 | 	ch <- r // the channel's element type is *os.File
279 | 	return io.ReadAll(r)
280 | }
281 | 
282 | // {"x": {"foo": "func()"}}
283 |
func F35(x interface { 284 | foo() 285 | bar() 286 | }) { 287 | x.foo() 288 | } 289 | 290 | // {} 291 | func F36(w io.Writer, inps []*os.File) error { 292 | for _, inp := range inps { 293 | if _, err := io.Copy(w, inp); err != nil { 294 | return err 295 | } 296 | } 297 | return nil 298 | } 299 | 300 | // {} 301 | func F37(r io.Reader) ([]byte, error) { 302 | switch r := r.(type) { 303 | case *os.File: 304 | fmt.Println(r.Name()) 305 | } 306 | return io.ReadAll(r) 307 | } 308 | 309 | // {} 310 | func F38(x int) int { 311 | return x + 1 312 | } 313 | 314 | // {"r": {"Read": "func([]byte) (int, error)"}} 315 | func F39(r *os.File) ([]byte, error) { 316 | type mtype map[io.Reader]io.Reader 317 | m := mtype{r: r} 318 | return io.ReadAll(m[r]) 319 | } 320 | 321 | // {} 322 | func F40[W, R any](w W, r R) error { 323 | if w, ok := any(w).(io.Writer); ok { 324 | if r, ok := any(r).(io.Reader); ok { 325 | _, err := io.Copy(w, r) 326 | return err 327 | } 328 | } 329 | return nil 330 | } 331 | 332 | // {} 333 | func F41(w io.Writer, readers []io.Reader) error { 334 | f := func(w io.Writer, readers ...io.Reader) error { 335 | for _, r := range readers { 336 | if _, err := io.Copy(w, r); err != nil { 337 | return err 338 | } 339 | } 340 | return nil 341 | } 342 | return f(w, readers...) 
343 | } 344 | 345 | // {"ctx": {"Done": "func() <-chan struct{}", "Err": "func() error"}, 346 | // 347 | // "f": {"Name": "func() string"}} 348 | func F42(ctx context.Context, f *os.File, ch <-chan struct{}) (string, error) { 349 | select { 350 | case <-ctx.Done(): 351 | return "", ctx.Err() 352 | case <-ch: 353 | return f.Name(), nil 354 | } 355 | } 356 | 357 | // {"f": {"Read": "func([]byte) (int, error)"}} 358 | func F43(w io.Writer, f *os.File) error { 359 | fn := func(readers ...io.Reader) error { 360 | for _, r := range readers { 361 | if _, err := io.Copy(w, r); err != nil { 362 | return err 363 | } 364 | } 365 | return nil 366 | } 367 | return fn(f) 368 | } 369 | 370 | // {} 371 | func F44(s []int, x, y, z int) []int { 372 | return s[(x+1)*2 : y : z] // exercises *ast.ParenExpr and *ast.SliceExpr 373 | } 374 | 375 | // {} 376 | func F45(m map[string]int, k string) int { 377 | return m[k] 378 | } 379 | 380 | // {"f": {"Read": "func([]byte) (int, error)"}} 381 | func F46(f *os.File) ([]byte, error) { 382 | fn := func(r io.Reader) ([]byte, error) { 383 | return io.ReadAll(r) 384 | } 385 | return fn(f) 386 | } 387 | 388 | type t47 struct { 389 | f *os.File 390 | r io.Reader 391 | } 392 | 393 | // {} 394 | func F47(f *os.File) t47 { 395 | return t47{f: f, r: f} 396 | } 397 | 398 | // {"f": {"Read": "func([]byte) (int, error)"}} 399 | func F48(f *os.File) t47 { 400 | return t47{r: f} 401 | } 402 | 403 | // {} 404 | func F49(n int) int { 405 | n++ 406 | return n 407 | } 408 | -------------------------------------------------------------------------------- /decouple_test.go: -------------------------------------------------------------------------------- 1 | package decouple 2 | 3 | import ( 4 | "bytes" 5 | "encoding/json" 6 | "go/ast" 7 | "go/token" 8 | "go/types" 9 | "maps" 10 | "strings" 11 | "testing" 12 | 13 | "github.com/bobg/go-generics/v4/set" 14 | "golang.org/x/tools/go/packages" 15 | // "github.com/davecgh/go-spew/spew" 16 | ) 17 | 18 | func TestCheck(t 
*testing.T) { 19 | checker, err := NewCheckerFromDir("_testdata") 20 | if err != nil { 21 | t.Fatal(err) 22 | } 23 | 24 | // if testing.Verbose() { 25 | // checker.Verbose = true 26 | // } 27 | 28 | tuples, err := checker.Check() 29 | if err != nil { 30 | t.Fatal(err) 31 | } 32 | 33 | for _, tuple := range tuples { 34 | t.Run(tuple.F.Name.Name, func(t *testing.T) { 35 | if tuple.F.Doc == nil { 36 | t.Fatal("no doc") 37 | } 38 | var docb bytes.Buffer 39 | for _, c := range tuple.F.Doc.List { 40 | docb.WriteString(strings.TrimLeft(c.Text, "/")) 41 | docb.WriteByte('\n') 42 | } 43 | 44 | var ( 45 | dec = json.NewDecoder(&docb) 46 | pre map[string]map[string]string 47 | ) 48 | if err := dec.Decode(&pre); err != nil { 49 | t.Fatalf("unmarshaling `%s`: %s", docb.String(), err) 50 | } 51 | 52 | var ( 53 | gotParamNames = set.Collect(maps.Keys(tuple.M)) 54 | wantParamNames = set.Collect(maps.Keys(pre)) 55 | ) 56 | if !gotParamNames.Equal(wantParamNames) { 57 | t.Fatalf("got param names %v, want %v", gotParamNames.Slice(), wantParamNames.Slice()) 58 | } 59 | 60 | for paramName, methods := range pre { 61 | t.Run(paramName, func(t *testing.T) { 62 | var ( 63 | gotMethodNames = set.Collect(maps.Keys(tuple.M[paramName])) 64 | wantMethodNames = set.Collect(maps.Keys(methods)) 65 | ) 66 | if !gotMethodNames.Equal(wantMethodNames) { 67 | t.Fatalf("got method names %v, want %v", gotMethodNames.Slice(), wantMethodNames.Slice()) 68 | } 69 | for methodName, sigstr := range methods { 70 | t.Run(methodName, func(t *testing.T) { 71 | typ, err := types.Eval(tuple.P.Fset, tuple.P.Types, tuple.F.Pos(), sigstr) 72 | if err != nil { 73 | t.Fatal(err) 74 | } 75 | if !types.Identical(tuple.M[paramName][methodName], typ.Type) { 76 | t.Errorf("got %s, want %s", tuple.M[paramName][methodName], typ.Type) 77 | } 78 | }) 79 | } 80 | }) 81 | } 82 | 83 | if !dec.More() { 84 | return 85 | } 86 | 87 | t.Run("intf", func(t *testing.T) { 88 | var intfnames map[string]string 89 | if err := 
dec.Decode(&intfnames); err != nil { 90 | t.Fatalf("unmarshaling interface names: %s", err) 91 | } 92 | 93 | for paramName, intfname := range intfnames { 94 | t.Run(paramName, func(t *testing.T) { 95 | gotPkg, gotName := checker.NameForMethods(tuple.M[paramName]) 96 | if gotName == "" { 97 | t.Fatalf("no named interface found for param %s", paramName) 98 | } 99 | got := gotPkg.PkgPath + "." + gotName 100 | if got != intfname { 101 | t.Errorf("got %s, want %s", got, intfname) 102 | } 103 | }) 104 | } 105 | }) 106 | }) 107 | } 108 | } 109 | 110 | func TestGetIdent(t *testing.T) { 111 | var expr ast.Expr = &ast.BasicLit{Kind: token.INT, Value: "42"} 112 | 113 | if ident := getIdent(expr); ident != nil { 114 | t.Errorf("got %v, want nil", ident) 115 | } 116 | 117 | expr = ast.NewIdent("foo") 118 | 119 | if ident := getIdent(expr); ident == nil || ident.Name != "foo" { 120 | t.Errorf("got %v, want foo", ident) 121 | } 122 | 123 | expr = &ast.ParenExpr{X: expr} 124 | 125 | if ident := getIdent(expr); ident == nil || ident.Name != "foo" { 126 | t.Errorf("got %v, want foo", ident) 127 | } 128 | } 129 | 130 | func TestBestChooser(t *testing.T) { 131 | cases := []struct { 132 | name string 133 | pkgpath1, pkgpath2 string 134 | modpath1, modpath2 string 135 | isAlias1, isAlias2 bool 136 | isMain1, isMain2 bool 137 | isDirect1, isDirect2 bool 138 | want1 bool 139 | }{{ 140 | name: "stdlib vs non-stdlib", 141 | pkgpath1: "fmt", 142 | pkgpath2: "github.com/bobg/decouple", 143 | modpath2: "github.com/bobg/decouple", 144 | isAlias1: false, 145 | isAlias2: false, 146 | isMain1: false, 147 | isMain2: true, 148 | isDirect1: true, 149 | isDirect2: true, 150 | want1: true, 151 | }, { 152 | name: "non-stdlib vs stdlib", 153 | pkgpath1: "github.com/bobg/decouple", 154 | modpath1: "github.com/bobg/decouple", 155 | pkgpath2: "fmt", 156 | isAlias1: false, 157 | isAlias2: false, 158 | isMain1: true, 159 | isMain2: false, 160 | isDirect1: true, 161 | isDirect2: true, 162 | want1: false, 163 | 
}, { 164 | name: "higher alias vs lower non-alias in main module", 165 | pkgpath1: "google.golang.org/protobuf/proto", 166 | modpath1: "google.golang.org/protobuf", 167 | pkgpath2: "google.golang.org/protobuf/reflect/protoreflect", 168 | modpath2: "google.golang.org/protobuf", 169 | isMain1: true, 170 | isMain2: true, 171 | isAlias1: true, 172 | isAlias2: false, 173 | want1: true, 174 | }, { 175 | name: "lower non-alias vs higher alias in main module", 176 | pkgpath1: "google.golang.org/protobuf/reflect/protoreflect", 177 | modpath1: "google.golang.org/protobuf", 178 | pkgpath2: "google.golang.org/protobuf/proto", 179 | modpath2: "google.golang.org/protobuf", 180 | isMain1: true, 181 | isMain2: true, 182 | isAlias1: false, 183 | isAlias2: true, 184 | want1: false, 185 | }, { 186 | name: "higher alias vs lower non-alias in direct dependency", 187 | pkgpath1: "google.golang.org/protobuf/proto", 188 | modpath1: "google.golang.org/protobuf", 189 | pkgpath2: "google.golang.org/protobuf/reflect/protoreflect", 190 | modpath2: "google.golang.org/protobuf", 191 | isDirect1: true, 192 | isDirect2: true, 193 | isAlias1: true, 194 | isAlias2: false, 195 | want1: true, 196 | }, { 197 | name: "lower non-alias vs higher alias in direct dependency", 198 | pkgpath1: "google.golang.org/protobuf/reflect/protoreflect", 199 | modpath1: "google.golang.org/protobuf", 200 | pkgpath2: "google.golang.org/protobuf/proto", 201 | modpath2: "google.golang.org/protobuf", 202 | isDirect1: true, 203 | isDirect2: true, 204 | isAlias1: false, 205 | isAlias2: true, 206 | want1: false, 207 | }, { 208 | name: "higher alias vs lower non-alias in other dependency", 209 | pkgpath1: "google.golang.org/protobuf/proto", 210 | modpath1: "google.golang.org/protobuf", 211 | pkgpath2: "google.golang.org/protobuf/reflect/protoreflect", 212 | modpath2: "google.golang.org/protobuf", 213 | isAlias1: true, 214 | isAlias2: false, 215 | want1: true, 216 | }, { 217 | name: "lower non-alias vs higher alias in other 
dependency", 218 | pkgpath1: "google.golang.org/protobuf/reflect/protoreflect", 219 | modpath1: "google.golang.org/protobuf", 220 | pkgpath2: "google.golang.org/protobuf/proto", 221 | modpath2: "google.golang.org/protobuf", 222 | isAlias1: false, 223 | isAlias2: true, 224 | want1: false, 225 | }, { 226 | name: "same height alias vs non-alias in same other-dependency module", 227 | pkgpath1: "google.golang.org/protobuf/proto", 228 | modpath1: "google.golang.org/protobuf", 229 | pkgpath2: "google.golang.org/protobuf/protoadapt", 230 | modpath2: "google.golang.org/protobuf", 231 | isAlias1: true, 232 | isAlias2: false, 233 | want1: true, 234 | }, { 235 | name: "same height non-alias vs alias in same other-dependency module", 236 | pkgpath1: "google.golang.org/protobuf/proto", 237 | modpath1: "google.golang.org/protobuf", 238 | pkgpath2: "google.golang.org/protobuf/protoadapt", 239 | modpath2: "google.golang.org/protobuf", 240 | isAlias1: false, 241 | isAlias2: true, 242 | want1: true, 243 | }, { 244 | name: "alias vs non-alias in different direct dependency modules", 245 | pkgpath1: "github.com/foo/pkg", 246 | modpath1: "github.com/foo/pkg", 247 | pkgpath2: "github.com/bar/pkg", 248 | modpath2: "github.com/bar/pkg", 249 | isDirect1: true, 250 | isDirect2: true, 251 | isAlias1: true, 252 | isAlias2: false, 253 | want1: false, 254 | }, { 255 | name: "non-alias vs alias in different direct dependency modules", 256 | pkgpath1: "github.com/foo/pkg", 257 | modpath1: "github.com/foo/pkg", 258 | pkgpath2: "github.com/bar/pkg", 259 | modpath2: "github.com/bar/pkg", 260 | isDirect1: true, 261 | isDirect2: true, 262 | isAlias1: false, 263 | isAlias2: true, 264 | want1: true, 265 | }, { 266 | name: "alias vs non-alias in different non-direct-dependency modules", 267 | pkgpath1: "github.com/foo/pkg", 268 | modpath1: "github.com/foo/pkg", 269 | pkgpath2: "github.com/bar/pkg", 270 | modpath2: "github.com/bar/pkg", 271 | isAlias1: true, 272 | isAlias2: false, 273 | want1: false, 274 
| }, { 275 | name: "non-alias vs alias in different non-direct-dependency modules", 276 | pkgpath1: "github.com/foo/pkg", 277 | modpath1: "github.com/foo/pkg", 278 | pkgpath2: "github.com/bar/pkg", 279 | modpath2: "github.com/bar/pkg", 280 | isAlias1: false, 281 | isAlias2: true, 282 | want1: true, 283 | }, { 284 | name: "alias vs non-alias in main vs direct-dependency module", 285 | pkgpath1: "github.com/bobg/decouple", 286 | modpath1: "github.com/bobg/decouple", 287 | pkgpath2: "github.com/some/dependency", 288 | modpath2: "github.com/some/dependency", 289 | isMain1: true, 290 | isDirect2: true, 291 | isAlias1: true, 292 | isAlias2: false, 293 | want1: true, 294 | }, { 295 | name: "non-alias vs alias in main vs direct-dependency module", 296 | pkgpath1: "github.com/bobg/decouple", 297 | modpath1: "github.com/bobg/decouple", 298 | pkgpath2: "github.com/some/dependency", 299 | modpath2: "github.com/some/dependency", 300 | isMain1: true, 301 | isDirect2: true, 302 | isAlias1: false, 303 | isAlias2: true, 304 | want1: true, 305 | }, { 306 | name: "alias vs non-alias in direct-dependency vs main module", 307 | pkgpath1: "github.com/bobg/decouple", 308 | modpath1: "github.com/bobg/decouple", 309 | pkgpath2: "github.com/some/dependency", 310 | modpath2: "github.com/some/dependency", 311 | isDirect1: true, 312 | isMain2: true, 313 | isAlias1: true, 314 | isAlias2: false, 315 | want1: false, 316 | }, { 317 | name: "non-alias vs alias in direct-dependency vs main module", 318 | pkgpath1: "github.com/bobg/decouple", 319 | modpath1: "github.com/bobg/decouple", 320 | pkgpath2: "github.com/some/dependency", 321 | modpath2: "github.com/some/dependency", 322 | isDirect1: true, 323 | isMain2: true, 324 | isAlias1: false, 325 | isAlias2: true, 326 | want1: false, 327 | }, { 328 | name: "main module vs direct dependency", 329 | pkgpath1: "github.com/bobg/decouple", 330 | modpath1: "github.com/bobg/decouple", 331 | pkgpath2: "github.com/some/dependency", 332 | modpath2: 
"github.com/some/dependency", 333 | isMain1: true, 334 | isDirect2: true, 335 | want1: true, 336 | }, { 337 | name: "direct dependency vs main module", 338 | pkgpath1: "github.com/some/dependency", 339 | modpath1: "github.com/some/dependency", 340 | pkgpath2: "github.com/bobg/decouple", 341 | modpath2: "github.com/bobg/decouple", 342 | isDirect1: true, 343 | isMain2: true, 344 | want1: false, 345 | }} 346 | 347 | for _, tc := range cases { 348 | t.Run(tc.name, func(t *testing.T) { 349 | pkg1 := &packages.Package{ 350 | PkgPath: tc.pkgpath1, 351 | Module: &packages.Module{ 352 | Main: tc.isMain1, 353 | Indirect: !tc.isDirect1, 354 | }, 355 | } 356 | var chooser bestChooser 357 | chooser.choose(pkg1, "x", tc.isAlias1) 358 | 359 | pkg2 := &packages.Package{ 360 | PkgPath: tc.pkgpath2, 361 | Module: &packages.Module{ 362 | Main: tc.isMain2, 363 | Indirect: !tc.isDirect2, 364 | }, 365 | } 366 | 367 | chooser.choose(pkg2, "y", tc.isAlias2) 368 | 369 | if tc.want1 { 370 | if chooser.name != "x" { 371 | t.Errorf("got name %q, want x", chooser.name) 372 | } 373 | } else { 374 | if chooser.name != "y" { 375 | t.Errorf("got name %q, want y", chooser.name) 376 | } 377 | } 378 | }) 379 | } 380 | } 381 | -------------------------------------------------------------------------------- /decouple.go: -------------------------------------------------------------------------------- 1 | package decouple 2 | 3 | import ( 4 | "fmt" 5 | "go/ast" 6 | "go/token" 7 | "go/types" 8 | "slices" 9 | "strings" 10 | 11 | "github.com/bobg/errors" 12 | "golang.org/x/tools/go/packages" 13 | ) 14 | 15 | // PkgMode is the minimal set of bit flags needed for the Config.Mode field of golang.org/x/go/packages 16 | // for the result to be usable by a Checker. 
17 | const PkgMode = packages.NeedName | packages.NeedFiles | packages.NeedImports | packages.NeedDeps | packages.NeedTypes | packages.NeedSyntax | packages.NeedTypesInfo | packages.NeedModule 18 | 19 | // Checker is the object that can analyze a directory tree of Go code, 20 | // or a set of packages loaded with "golang.org/x/go/packages".Load, 21 | // or a single such package, 22 | // or a function or function parameter in one. 23 | // 24 | // Set Deprecated to include deprecated functions in the analysis, which are normally excluded. 25 | // Set Verbose to true to get (very) verbose debugging output. 26 | type Checker struct { 27 | Deprecated bool 28 | Verbose bool 29 | 30 | pkgs []*packages.Package 31 | namedInterfaces map[*packages.Package]map[string]namedInterfacePair // pkg -> named interface type -> (method map, is alias) 32 | } 33 | 34 | type namedInterfacePair struct { 35 | mmap MethodMap 36 | isAlias bool 37 | } 38 | 39 | // NewCheckerFromDir creates a new Checker containing packages loaded 40 | // (using "golang.org/x/go/packages".Load) 41 | // from the given directory tree. 42 | func NewCheckerFromDir(dir string) (Checker, error) { 43 | conf := &packages.Config{Dir: dir, Mode: PkgMode} 44 | pkgs, err := packages.Load(conf, "./...") 45 | if err != nil { 46 | return Checker{}, errors.Wrapf(err, "loading packages from %s", dir) 47 | } 48 | for _, pkg := range pkgs { 49 | for _, pkgerr := range pkg.Errors { 50 | err = errors.Join(err, errors.Wrapf(pkgerr, "in package %s", pkg.PkgPath)) 51 | } 52 | } 53 | if err != nil { 54 | return Checker{}, errors.Wrapf(err, "after loading packages from %s", dir) 55 | } 56 | return NewCheckerFromPackages(pkgs), nil 57 | } 58 | 59 | // NewCheckerFromPackages creates a new Checker containing the given packages, 60 | // which should be the result of calling "golang.org/x/go/packages".Load 61 | // with at least the bits in PkgMode set in the Config.Mode field. 
62 | func NewCheckerFromPackages(pkgs []*packages.Package) Checker { 63 | namedInterfaces := make(map[*packages.Package]map[string]namedInterfacePair) 64 | for _, pkg := range pkgs { 65 | findNamedInterfaces(pkg, namedInterfaces) 66 | } 67 | return Checker{pkgs: pkgs, namedInterfaces: namedInterfaces} 68 | } 69 | 70 | func findNamedInterfaces(pkg *packages.Package, namedInterfaces map[*packages.Package]map[string]namedInterfacePair) { 71 | if _, ok := namedInterfaces[pkg]; ok { 72 | // Already visited this package. 73 | return 74 | } 75 | 76 | namedInterfacesForPkg := make(map[string]namedInterfacePair) 77 | namedInterfaces[pkg] = namedInterfacesForPkg 78 | 79 | for _, ipkg := range pkg.Imports { 80 | findNamedInterfaces(ipkg, namedInterfaces) 81 | } 82 | 83 | if isInternal(pkg.PkgPath) { 84 | return 85 | } 86 | 87 | for _, file := range pkg.Syntax { 88 | for _, decl := range file.Decls { 89 | gendecl, ok := decl.(*ast.GenDecl) 90 | if !ok { 91 | continue 92 | } 93 | if gendecl.Tok != token.TYPE { 94 | continue 95 | } 96 | for _, spec := range gendecl.Specs { 97 | typespec, ok := spec.(*ast.TypeSpec) 98 | if !ok { 99 | // Should be impossible. 100 | continue 101 | } 102 | if !ast.IsExported(typespec.Name.Name) { 103 | continue 104 | } 105 | obj := pkg.TypesInfo.Defs[typespec.Name] 106 | if obj == nil { 107 | // Should be impossible. 108 | continue 109 | } 110 | intf := getType[*types.Interface](obj.Type()) 111 | if intf == nil { 112 | continue 113 | } 114 | mm := make(MethodMap) 115 | addMethodsToMap(intf, mm) 116 | namedInterfacesForPkg[typespec.Name.Name] = namedInterfacePair{mmap: mm, isAlias: typespec.Assign != token.NoPos} 117 | } 118 | } 119 | } 120 | } 121 | 122 | // Check checks all the packages in the Checker. 123 | // It analyzes the functions in them, 124 | // looking for parameters with concrete types that could be interfaces instead. 
125 | // The result is a list of Tuples, 126 | // one for each function checked that has parameters eligible for decoupling. 127 | func (ch Checker) Check() ([]Tuple, error) { 128 | var result []Tuple 129 | 130 | for _, pkg := range ch.pkgs { 131 | pkgResult, err := ch.CheckPackage(pkg) 132 | if err != nil { 133 | return nil, errors.Wrapf(err, "analyzing package %s", pkg.PkgPath) 134 | } 135 | result = append(result, pkgResult...) 136 | } 137 | 138 | return result, nil 139 | } 140 | 141 | // CheckPackage checks a single package. 142 | // It should be one of the packages contained in the Checker. 143 | // The result is a list of Tuples, 144 | // one for each function checked that has parameters eligible for decoupling. 145 | func (ch Checker) CheckPackage(pkg *packages.Package) ([]Tuple, error) { 146 | var result []Tuple 147 | 148 | for _, file := range pkg.Syntax { 149 | for _, decl := range file.Decls { 150 | fndecl, ok := decl.(*ast.FuncDecl) 151 | if !ok { 152 | continue 153 | } 154 | m, err := ch.CheckFunc(pkg, fndecl) 155 | if err != nil { 156 | return nil, errors.Wrapf(err, "analyzing function %s at %s", fndecl.Name.Name, pkg.Fset.Position(fndecl.Name.Pos())) 157 | } 158 | result = append(result, Tuple{ 159 | F: fndecl, 160 | P: pkg, 161 | M: m, 162 | }) 163 | } 164 | } 165 | 166 | return result, nil 167 | } 168 | 169 | // Tuple is the type of a result from Checker.Check and Checker.CheckPackage. 170 | type Tuple struct { 171 | // F is the function declaration that this result is about. 172 | F *ast.FuncDecl 173 | 174 | // P is the package in which the function declaration appears. 175 | P *packages.Package 176 | 177 | // M is a map from the names of function parameters eligible for decoupling 178 | // to MethodMaps for each such parameter. 179 | M map[string]MethodMap 180 | } 181 | 182 | // Pos computes the filename and offset 183 | // of the function name of the Tuple. 
184 | func (t Tuple) Pos() token.Position { 185 | return t.P.Fset.Position(t.F.Name.Pos()) 186 | } 187 | 188 | // MethodMap maps a set of method names to their calling signatures. 189 | type MethodMap = map[string]*types.Signature 190 | 191 | // CheckFunc checks a single function declaration, 192 | // which should appear in the given package, 193 | // which should be one of the packages contained in the Checker. 194 | // The result is a map from parameter names eligible for decoupling to MethodMaps. 195 | func (ch Checker) CheckFunc(pkg *packages.Package, fndecl *ast.FuncDecl) (map[string]MethodMap, error) { 196 | if !ch.Deprecated && isDeprecated(fndecl) { 197 | return nil, nil 198 | } 199 | 200 | result := make(map[string]MethodMap) 201 | for _, field := range fndecl.Type.Params.List { 202 | for _, name := range field.Names { 203 | if name.Name == "_" { 204 | continue 205 | } 206 | 207 | nameResult, err := ch.CheckParam(pkg, fndecl, name) 208 | if err != nil { 209 | return nil, errors.Wrapf(err, "analyzing parameter %s of %s", name.Name, fndecl.Name.Name) 210 | } 211 | if len(nameResult) != 0 { 212 | result[name.Name] = nameResult 213 | } 214 | } 215 | } 216 | return result, nil 217 | } 218 | 219 | func isDeprecated(fndecl *ast.FuncDecl) bool { 220 | if fndecl.Doc == nil { 221 | return false 222 | } 223 | pars := strings.Split(fndecl.Doc.Text(), "\n\n") 224 | for _, par := range pars { 225 | if strings.HasPrefix(par, "Deprecated:") { 226 | return true 227 | } 228 | } 229 | return false 230 | } 231 | 232 | // CheckParam checks a single named parameter in a given function declaration, 233 | // which must apepar in the given package, 234 | // which should be one of the packages in the Checker. 235 | // The result is a MethodMap for the parameter, 236 | // and may be nil if the parameter is not eligible for decoupling. 
237 | func (ch Checker) CheckParam(pkg *packages.Package, fndecl *ast.FuncDecl, name *ast.Ident) (_ MethodMap, err error) { 238 | defer func() { 239 | if r := recover(); r != nil { 240 | if e, ok := r.(error); ok { 241 | var d derr 242 | if errors.As(e, &d) { 243 | err = d 244 | return 245 | } 246 | } 247 | panic(r) 248 | } 249 | }() 250 | 251 | obj, ok := pkg.TypesInfo.Defs[name] 252 | if !ok { 253 | return nil, fmt.Errorf("no def found for %s", name.Name) 254 | } 255 | 256 | var ( 257 | intf = getType[*types.Interface](obj.Type()) 258 | mm MethodMap 259 | ) 260 | if intf != nil { 261 | mm = make(MethodMap) 262 | addMethodsToMap(intf, mm) 263 | } 264 | a := analyzer{ 265 | name: name, 266 | obj: obj, 267 | pkg: pkg, 268 | objmethods: mm, 269 | methods: make(MethodMap), 270 | enclosingFunc: &funcDeclOrLit{decl: fndecl}, 271 | debug: ch.Verbose, 272 | } 273 | a.debugf("fn %s param %s", fndecl.Name.Name, name.Name) 274 | for _, stmt := range fndecl.Body.List { 275 | if !a.stmt(stmt) { 276 | return nil, nil 277 | } 278 | } 279 | 280 | if len(a.objmethods) > 0 { 281 | if len(a.methods) < len(a.objmethods) { 282 | // A smaller interface will do. 283 | return a.methods, nil 284 | } 285 | return nil, nil 286 | } 287 | return a.methods, nil 288 | } 289 | 290 | // NameForMethods takes a MethodMap 291 | // and returns the name of an interface defining exactly the methods in it, 292 | // if it can find one among the packages in the Checker. 293 | // 294 | // If there are multiple such interfaces, 295 | // one is chosen arbitrarily from the Go standard library if possible, 296 | // otherwise from the "main module" if possible, 297 | // otherwise from any modules directly depended on by the main module, 298 | // and finally from any other module. 299 | // 300 | // Within each category, 301 | // if two packages are from the same module, 302 | // we prefer the one closer to the module root, 303 | // and we prefer a non-alias type over an alias type. 
304 | func (ch Checker) NameForMethods(inp MethodMap) (*packages.Package, string) { 305 | var chooser bestChooser 306 | 307 | for pkg, namedInterfaces := range ch.namedInterfaces { 308 | for name, pair := range namedInterfaces { 309 | if !sameMethodMaps(pair.mmap, inp) { 310 | continue 311 | } 312 | chooser.choose(pkg, name, pair.isAlias) 313 | if chooser.isStdlib && !chooser.isAlias { 314 | // Instant winner. 315 | return chooser.pkg, chooser.name 316 | } 317 | } 318 | } 319 | 320 | return chooser.pkg, chooser.name 321 | } 322 | 323 | type bestChooser struct { 324 | pkg *packages.Package 325 | name string 326 | isAlias bool 327 | isStdlib bool 328 | isMainModule bool 329 | isDirectDependency bool 330 | } 331 | 332 | func (b *bestChooser) choose(pkg *packages.Package, name string, isAlias bool) { 333 | switch { 334 | case b.pkg == nil: 335 | b.pkg = pkg 336 | b.name = name 337 | b.isStdlib = isStdlib(pkg.PkgPath) 338 | b.isMainModule = isMainModulePackage(pkg) 339 | b.isDirectDependency = isDirectDependencyOfMainModule(pkg) 340 | b.isAlias = isAlias 341 | 342 | case isStdlib(pkg.PkgPath): 343 | if !b.isStdlib || (b.isAlias && !isAlias) { 344 | b.pkg = pkg 345 | b.name = name 346 | b.isStdlib = true 347 | b.isMainModule = false 348 | b.isDirectDependency = false 349 | b.isAlias = isAlias 350 | } 351 | 352 | case isMainModulePackage(pkg): 353 | if b.isStdlib { 354 | return 355 | } 356 | if b.isMainModule && sameModuleButHigher(b.pkg, pkg) { 357 | return 358 | } 359 | if !b.isMainModule || sameModuleButHigher(pkg, b.pkg) || (b.isAlias && !isAlias) { 360 | b.pkg = pkg 361 | b.name = name 362 | b.isStdlib = false 363 | b.isMainModule = true 364 | b.isDirectDependency = false 365 | b.isAlias = isAlias 366 | } 367 | 368 | case isDirectDependencyOfMainModule(pkg): 369 | if b.isStdlib || b.isMainModule { 370 | return 371 | } 372 | if b.isDirectDependency && sameModuleButHigher(b.pkg, pkg) { 373 | return 374 | } 375 | if !b.isDirectDependency || sameModuleButHigher(pkg, 
b.pkg) || (b.isAlias && !isAlias) { 376 | b.pkg = pkg 377 | b.name = name 378 | b.isStdlib = false 379 | b.isMainModule = false 380 | b.isDirectDependency = true 381 | b.isAlias = isAlias 382 | } 383 | 384 | case !b.isStdlib && !b.isMainModule && !b.isDirectDependency: 385 | if sameModuleButHigher(pkg, b.pkg) || (!sameModuleButHigher(b.pkg, pkg) && b.isAlias && !isAlias) { 386 | b.pkg = pkg 387 | b.name = name 388 | b.isStdlib = false 389 | b.isMainModule = false 390 | b.isDirectDependency = false 391 | b.isAlias = isAlias 392 | } 393 | } 394 | } 395 | 396 | func isStdlib(pkgPath string) bool { 397 | return !strings.Contains(pkgPath, ".") 398 | } 399 | 400 | // Return true if a and b are in the same module and a is closer to the module root than b. 401 | // If a and b are at the same height in the same module, 402 | // this function considers the length of the package paths. 403 | // Shorter one wins. 404 | func sameModuleButHigher(a, b *packages.Package) bool { 405 | modpath := a.Module.Path 406 | if b.Module.Path != modpath { 407 | return false 408 | } 409 | acount, bcount := strings.Count(a.PkgPath, "/"), strings.Count(b.PkgPath, "/") 410 | if acount != bcount { 411 | return acount < bcount 412 | } 413 | return len(a.PkgPath) < len(b.PkgPath) 414 | } 415 | 416 | func isMainModulePackage(pkg *packages.Package) bool { 417 | return pkg.Module != nil && pkg.Module.Main 418 | } 419 | 420 | func isDirectDependencyOfMainModule(pkg *packages.Package) bool { 421 | return pkg.Module != nil && !pkg.Module.Main && !pkg.Module.Indirect 422 | } 423 | 424 | type funcDeclOrLit struct { 425 | decl *ast.FuncDecl 426 | lit *ast.FuncLit 427 | } 428 | 429 | type analyzer struct { 430 | name *ast.Ident 431 | obj types.Object 432 | pkg *packages.Package 433 | 434 | // objmethods is input: the methodmap for obj's type, 435 | // if that's an interface type. 436 | // methods is output: the set of methods actually used. 
437 | objmethods, methods MethodMap 438 | 439 | enclosingFunc *funcDeclOrLit 440 | enclosingSwitchStmt *ast.SwitchStmt 441 | 442 | level int 443 | debug bool 444 | } 445 | 446 | func (a *analyzer) enclosingFuncInfo() (types.Type, token.Position, bool) { 447 | if a.enclosingFunc == nil { 448 | return nil, token.Position{}, false 449 | } 450 | if decl := a.enclosingFunc.decl; decl != nil { 451 | obj, ok := a.pkg.TypesInfo.Defs[decl.Name] 452 | if !ok { 453 | return nil, token.Position{}, false 454 | } 455 | return obj.Type(), a.pos(obj), true 456 | } 457 | lit := a.enclosingFunc.lit 458 | tv, ok := a.pkg.TypesInfo.Types[lit] 459 | if !ok { 460 | return nil, token.Position{}, false 461 | } 462 | return tv.Type, a.pos(lit), true 463 | } 464 | 465 | func (a *analyzer) getSig(expr ast.Expr) *types.Signature { 466 | return getType[*types.Signature](a.pkg.TypesInfo.Types[expr].Type) 467 | } 468 | 469 | // Does expr denote the object in a? 470 | func (a *analyzer) isObj(expr ast.Expr) bool { 471 | switch expr := expr.(type) { 472 | case *ast.Ident: 473 | obj := a.pkg.TypesInfo.Uses[expr] 474 | return obj == a.obj 475 | 476 | case *ast.ParenExpr: 477 | return a.isObj(expr.X) 478 | 479 | default: 480 | return false 481 | } 482 | } 483 | 484 | func (a *analyzer) stmt(stmt ast.Stmt) (ok bool) { 485 | a.level++ 486 | a.debugf("> stmt %#v", stmt) 487 | defer func() { 488 | a.debugf("< stmt %#v %v", stmt, ok) 489 | a.level-- 490 | }() 491 | 492 | if stmt == nil { 493 | return true 494 | } 495 | 496 | switch stmt := stmt.(type) { 497 | case *ast.AssignStmt: 498 | for _, lhs := range stmt.Lhs { 499 | // I think we can ignore the rhs value if a.isObj(lhs). 500 | // What matters is only how our object is being used, 501 | // not what's being assigned to it. 502 | if !a.expr(lhs) { 503 | return false 504 | } 505 | } 506 | for i, rhs := range stmt.Rhs { 507 | // xxx do a recursive analysis of how this var is used! 
508 | if a.isObj(rhs) && stmt.Tok != token.DEFINE { 509 | if stmt.Tok != token.ASSIGN { 510 | // Reject OP= 511 | return false 512 | } 513 | tv, ok := a.pkg.TypesInfo.Types[stmt.Lhs[i]] 514 | if !ok { 515 | panic(errf("no type info for lvalue %d in assignment at %s", i, a.pos(stmt))) 516 | } 517 | intf := getType[*types.Interface](tv.Type) 518 | if intf == nil { 519 | return false 520 | } 521 | a.addMethods(intf) 522 | continue 523 | } 524 | if !a.expr(rhs) { 525 | return false 526 | } 527 | } 528 | return true 529 | 530 | case *ast.BlockStmt: 531 | for _, s := range stmt.List { 532 | if !a.stmt(s) { 533 | return false 534 | } 535 | } 536 | return true 537 | 538 | case *ast.BranchStmt: 539 | return true 540 | 541 | case *ast.CaseClause: 542 | for _, expr := range stmt.List { 543 | if a.isObj(expr) { 544 | if a.enclosingSwitchStmt == nil { 545 | panic(errf("case clause with no enclosing switch statement at %s", a.pos(stmt))) 546 | } 547 | if a.enclosingSwitchStmt.Tag == nil { 548 | return false // would require our obj to evaluate as a boolean 549 | } 550 | tv, ok := a.pkg.TypesInfo.Types[a.enclosingSwitchStmt.Tag] 551 | if !ok { 552 | panic(errf("no type info for switch tag at %s", a.pos(a.enclosingSwitchStmt.Tag))) 553 | } 554 | t1, t2 := a.obj.Type(), tv.Type 555 | if !types.AssignableTo(t1, t2) && !types.AssignableTo(t2, t1) { 556 | // "In any comparison, the first operand must be assignable to the type of the second operand, or vice versa." 
557 | // https://go.dev/ref/spec#Comparison_operators 558 | return false 559 | } 560 | continue 561 | } 562 | 563 | if !a.expr(expr) { 564 | return false 565 | } 566 | } 567 | for _, s := range stmt.Body { 568 | if !a.stmt(s) { 569 | return false 570 | } 571 | } 572 | return true 573 | 574 | case *ast.CommClause: 575 | if !a.stmt(stmt.Comm) { 576 | return false 577 | } 578 | for _, s := range stmt.Body { 579 | if !a.stmt(s) { 580 | return false 581 | } 582 | } 583 | return true 584 | 585 | case *ast.DeclStmt: 586 | return a.decl(stmt.Decl) 587 | 588 | case *ast.DeferStmt: 589 | return a.expr(stmt.Call) 590 | 591 | case *ast.ExprStmt: 592 | return !a.isObjOrNotExpr(stmt.X) // a.isObj(stmt.X) probably can't happen in a well-formed program. 593 | 594 | case *ast.ForStmt: 595 | if !a.stmt(stmt.Init) { 596 | return false 597 | } 598 | if a.isObjOrNotExpr(stmt.Cond) { 599 | return false 600 | } 601 | if !a.stmt(stmt.Post) { 602 | return false 603 | } 604 | return a.stmt(stmt.Body) 605 | 606 | case *ast.GoStmt: 607 | return a.expr(stmt.Call) 608 | 609 | case *ast.IfStmt: 610 | if !a.stmt(stmt.Init) { 611 | return false 612 | } 613 | if a.isObjOrNotExpr(stmt.Cond) { 614 | return false 615 | } 616 | if !a.stmt(stmt.Body) { 617 | return false 618 | } 619 | return a.stmt(stmt.Else) 620 | 621 | case *ast.IncDecStmt: 622 | return !a.isObjOrNotExpr(stmt.X) 623 | 624 | case *ast.LabeledStmt: 625 | return a.stmt(stmt.Stmt) 626 | 627 | case *ast.RangeStmt: 628 | // As with AssignStmt, 629 | // if our object appears on the lhs we don't care. 
630 | if a.isObjOrNotExpr(stmt.X) { 631 | return false 632 | } 633 | return a.stmt(stmt.Body) 634 | 635 | case *ast.ReturnStmt: 636 | for i, expr := range stmt.Results { 637 | if a.isObj(expr) { 638 | typ, fpos, ok := a.enclosingFuncInfo() 639 | if !ok { 640 | panic(errf("no type info for function containing return statement at %s", a.pos(expr))) 641 | } 642 | sig, ok := typ.(*types.Signature) 643 | if !ok { 644 | panic(errf("got %T, want *types.Signature for type of function at %s", typ, fpos)) 645 | } 646 | if i >= sig.Results().Len() { 647 | panic(errf("cannot return %d value(s) from %d-value-returning function at %s", i+1, sig.Results().Len(), a.pos(stmt))) 648 | } 649 | resultvar := sig.Results().At(i) 650 | intf := getType[*types.Interface](resultvar.Type()) 651 | if intf == nil { 652 | return false 653 | } 654 | a.addMethods(intf) 655 | continue 656 | } 657 | if !a.expr(expr) { 658 | return false 659 | } 660 | } 661 | return true 662 | 663 | case *ast.SelectStmt: 664 | return a.stmt(stmt.Body) 665 | 666 | case *ast.SendStmt: 667 | if a.isObjOrNotExpr(stmt.Chan) { 668 | return false 669 | } 670 | if a.isObj(stmt.Value) { 671 | tv, ok := a.pkg.TypesInfo.Types[stmt.Chan] 672 | if !ok { 673 | panic(errf("no type info for channel in send statement at %s", a.pos(stmt))) 674 | } 675 | chtyp := getType[*types.Chan](tv.Type) 676 | if chtyp == nil { 677 | panic(errf("got %T, want channel for type of channel in send statement at %s", tv.Type, a.pos(stmt))) 678 | } 679 | intf := getType[*types.Interface](chtyp.Elem()) 680 | if intf == nil { 681 | return false 682 | } 683 | a.addMethods(intf) 684 | return true 685 | } 686 | return a.expr(stmt.Value) 687 | 688 | case *ast.SwitchStmt: 689 | return a.switchStmt(stmt) 690 | 691 | case *ast.TypeSwitchStmt: 692 | if !a.stmt(stmt.Init) { 693 | return false 694 | } 695 | // Can skip stmt.Assign. 
696 | return a.stmt(stmt.Body) 697 | } 698 | 699 | return false 700 | } 701 | 702 | func (a *analyzer) pos(p interface{ Pos() token.Pos }) token.Position { 703 | return a.pkg.Fset.Position(p.Pos()) 704 | } 705 | 706 | type methoder interface { 707 | NumMethods() int 708 | Method(int) *types.Func 709 | } 710 | 711 | func (a *analyzer) addMethods(intf methoder) { 712 | addMethodsToMap(intf, a.methods) 713 | } 714 | 715 | func addMethodsToMap(intf methoder, mm MethodMap) { 716 | for i := 0; i < intf.NumMethods(); i++ { 717 | m := intf.Method(i) 718 | 719 | // m is a *types.Func, and the Type() of a *types.Func is always *types.Signature. 720 | mm[m.Name()] = m.Type().(*types.Signature) 721 | } 722 | } 723 | 724 | func (a *analyzer) expr(expr ast.Expr) (ok bool) { 725 | a.level++ 726 | a.debugf("> expr %#v", expr) 727 | defer func() { 728 | a.debugf("< expr %#v %v", expr, ok) 729 | a.level-- 730 | }() 731 | 732 | if expr == nil { 733 | return true 734 | } 735 | 736 | switch expr := expr.(type) { 737 | case *ast.BinaryExpr: 738 | var other ast.Expr 739 | if a.isObj(expr.X) { 740 | other = expr.Y 741 | } else if a.isObj(expr.Y) { 742 | other = expr.X 743 | } 744 | if other != nil { 745 | switch expr.Op { 746 | case token.EQL, token.NEQ: 747 | if a.isObj(other) { 748 | return true 749 | } 750 | tv, ok := a.pkg.TypesInfo.Types[other] 751 | if !ok { 752 | panic(errf("no type info for expr at %s", a.pos(other))) 753 | } 754 | intf := getType[*types.Interface](tv.Type) 755 | if intf == nil { 756 | return false 757 | } 758 | a.addMethods(intf) 759 | // Continue below. 760 | 761 | default: 762 | return false 763 | } 764 | } 765 | 766 | return a.expr(expr.X) && a.expr(expr.Y) 767 | 768 | case *ast.CallExpr: 769 | if a.isObjOrNotExpr(expr.Fun) { 770 | return false 771 | } 772 | for i, arg := range expr.Args { 773 | if a.isObj(arg) { 774 | if i == len(expr.Args)-1 && expr.Ellipsis != token.NoPos { 775 | // This is "obj..." using our object, requiring it to be a slice. 
776 | return false 777 | } 778 | tv, ok := a.pkg.TypesInfo.Types[expr.Fun] 779 | if !ok { 780 | panic(errf("no type info for function in call expression at %s", a.pos(expr))) 781 | } 782 | sig := getType[*types.Signature](tv.Type) 783 | if sig == nil { 784 | // This could be a type conversion expression; e.g. int(x). 785 | if len(expr.Args) == 1 { 786 | return false 787 | } 788 | panic(errf("got %T, want *types.Signature for type of function in call expression at %s", tv.Type, a.pos(expr))) 789 | } 790 | var ( 791 | params = sig.Params() 792 | plen = params.Len() 793 | ptype types.Type 794 | ) 795 | if sig.Variadic() && i >= plen-1 { 796 | ptype = params.At(plen - 1).Type() 797 | slice, ok := ptype.(*types.Slice) 798 | if !ok { 799 | panic(errf("got %T, want slice for type of final parameter of variadic function in call expression at %s", ptype, a.pos(expr))) 800 | } 801 | ptype = slice.Elem() 802 | } else if i >= plen { 803 | panic(errf("cannot send %d argument(s) to %d-parameter function in call expression at %s", i+1, plen, a.pos(expr))) 804 | } else { 805 | ptype = params.At(i).Type() 806 | } 807 | intf := getType[*types.Interface](ptype) 808 | if intf == nil { 809 | return false 810 | } 811 | a.addMethods(intf) 812 | continue 813 | } 814 | if !a.expr(arg) { 815 | return false 816 | } 817 | } 818 | return true 819 | 820 | case *ast.CompositeLit: 821 | // Can skip expr.Type. 
822 | for i, elt := range expr.Elts { 823 | if kv, ok := elt.(*ast.KeyValueExpr); ok { 824 | if a.isObj(kv.Key) { 825 | tv, ok := a.pkg.TypesInfo.Types[expr] 826 | if !ok { 827 | panic(errf("no type info for composite literal at %s", a.pos(expr))) 828 | } 829 | mapType := getType[*types.Map](tv.Type) 830 | if mapType == nil { 831 | return false 832 | } 833 | intf := getType[*types.Interface](mapType.Key()) 834 | if intf == nil { 835 | return false 836 | } 837 | a.addMethods(intf) 838 | } else if !a.expr(kv.Key) { 839 | return false 840 | } 841 | if a.isObj(kv.Value) { 842 | tv, ok := a.pkg.TypesInfo.Types[expr] 843 | if !ok { 844 | panic(errf("no type info for composite literal at %s", a.pos(expr))) 845 | } 846 | 847 | literalType := tv.Type 848 | if named, ok := literalType.(*types.Named); ok { // xxx should this be a loop? 849 | literalType = named.Underlying() 850 | } 851 | 852 | var elemType types.Type 853 | 854 | switch literalType := literalType.(type) { 855 | case *types.Map: 856 | elemType = literalType.Elem() 857 | 858 | case *types.Struct: 859 | id := getIdent(kv.Key) 860 | if id == nil { 861 | panic(errf("got %T, want *ast.Ident in key-value entry of struct-typed composite literal at %s", kv.Key, a.pos(kv))) 862 | } 863 | 864 | for j := 0; j < literalType.NumFields(); j++ { 865 | field := literalType.Field(j) 866 | if field.Name() == id.Name { 867 | elemType = field.Type() 868 | break 869 | } 870 | } 871 | if elemType == nil { 872 | panic(errf("assignment to unknown struct field %s at %s", id.Name, a.pos(kv))) 873 | } 874 | 875 | case *types.Slice: 876 | elemType = literalType.Elem() 877 | 878 | case *types.Array: 879 | elemType = literalType.Elem() 880 | 881 | default: 882 | return false 883 | } 884 | 885 | intf := getType[*types.Interface](elemType) 886 | if intf == nil { 887 | return false 888 | } 889 | a.addMethods(intf) 890 | 891 | } else if !a.expr(kv.Value) { 892 | return false 893 | } 894 | continue 895 | } 896 | if a.isObj(elt) { 897 | tv, ok := 
a.pkg.TypesInfo.Types[expr] 898 | if !ok { 899 | panic(errf("no type info for composite literal at %s", a.pos(expr))) 900 | } 901 | 902 | literalType := tv.Type 903 | if named, ok := literalType.(*types.Named); ok { // xxx should this be a loop? 904 | literalType = named.Underlying() 905 | } 906 | 907 | var elemType types.Type 908 | 909 | switch literalType := literalType.(type) { 910 | case *types.Struct: 911 | if i >= literalType.NumFields() { 912 | panic(errf("cannot assign field %d of %d-field struct at %s", i, literalType.NumFields(), a.pos(elt))) 913 | } 914 | elemType = literalType.Field(i).Type() 915 | 916 | case *types.Slice: 917 | elemType = literalType.Elem() 918 | 919 | case *types.Array: 920 | elemType = literalType.Elem() 921 | } 922 | 923 | intf := getType[*types.Interface](elemType) 924 | if intf == nil { 925 | return false 926 | } 927 | a.addMethods(intf) 928 | 929 | continue 930 | } 931 | if !a.expr(elt) { 932 | return false 933 | } 934 | } 935 | return true 936 | 937 | case *ast.Ellipsis: 938 | return !a.isObjOrNotExpr(expr.Elt) 939 | 940 | case *ast.FuncLit: 941 | return a.funcLit(expr) 942 | 943 | case *ast.Ident: 944 | return true 945 | 946 | case *ast.IndexExpr: 947 | if a.isObjOrNotExpr(expr.X) { 948 | return false 949 | } 950 | if a.isObj(expr.Index) { 951 | // In expression x[index], 952 | // index can be an interface 953 | // if x is a map. 
954 | tv, ok := a.pkg.TypesInfo.Types[expr.X] 955 | if !ok { 956 | panic(errf("no type info for index expression at %s", a.pos(expr))) 957 | } 958 | mapType := getType[*types.Map](tv.Type) 959 | if mapType == nil { 960 | return false 961 | } 962 | intf := getType[*types.Interface](mapType.Key()) 963 | if intf == nil { 964 | return false 965 | } 966 | a.addMethods(intf) 967 | return true 968 | } 969 | return a.expr(expr.Index) 970 | 971 | case *ast.IndexListExpr: 972 | if a.isObjOrNotExpr(expr.X) { 973 | return false 974 | } 975 | for _, idx := range expr.Indices { 976 | if a.isObjOrNotExpr(idx) { 977 | return false 978 | } 979 | } 980 | return true 981 | 982 | case *ast.KeyValueExpr: 983 | panic("did not expect to reach the KeyValueExpr clause") 984 | 985 | case *ast.ParenExpr: 986 | return a.expr(expr.X) 987 | 988 | case *ast.SelectorExpr: 989 | if a.isObj(expr.X) { 990 | if sig := a.getSig(expr); sig != nil { 991 | a.methods[expr.Sel.Name] = sig 992 | return true 993 | } 994 | return false 995 | } 996 | return a.expr(expr.X) 997 | 998 | case *ast.SliceExpr: 999 | if a.isObjOrNotExpr(expr.X) { 1000 | return false 1001 | } 1002 | if a.isObjOrNotExpr(expr.Low) { 1003 | return false 1004 | } 1005 | if a.isObjOrNotExpr(expr.High) { 1006 | return false 1007 | } 1008 | return !a.isObjOrNotExpr(expr.Max) 1009 | 1010 | case *ast.StarExpr: 1011 | return !a.isObjOrNotExpr(expr.X) 1012 | 1013 | case *ast.TypeAssertExpr: 1014 | // Can skip expr.Type. 
1015 | return a.expr(expr.X) 1016 | 1017 | case *ast.UnaryExpr: 1018 | if a.isObj(expr.X) { 1019 | return expr.Op == token.AND 1020 | } 1021 | return a.expr(expr.X) 1022 | } 1023 | 1024 | return true 1025 | } 1026 | 1027 | func (a *analyzer) isObjOrNotExpr(expr ast.Expr) bool { 1028 | if a.isObj(expr) { 1029 | return true 1030 | } 1031 | return !a.expr(expr) 1032 | } 1033 | 1034 | func (a *analyzer) decl(decl ast.Decl) bool { 1035 | switch decl := decl.(type) { 1036 | case *ast.GenDecl: 1037 | if decl.Tok != token.VAR { 1038 | return true 1039 | } 1040 | for _, spec := range decl.Specs { 1041 | valspec, ok := spec.(*ast.ValueSpec) 1042 | if !ok { 1043 | panic(errf("got %T, want *ast.ValueSpec in variable declaration at %s", spec, a.pos(decl))) 1044 | } 1045 | for _, val := range valspec.Values { 1046 | if a.isObj(val) { 1047 | if valspec.Type == nil { 1048 | continue 1049 | } 1050 | tv, ok := a.pkg.TypesInfo.Types[valspec.Type] 1051 | if !ok { 1052 | panic(errf("no type info for variable declaration at %s", a.pos(valspec))) 1053 | } 1054 | intf := getType[*types.Interface](tv.Type) 1055 | if intf == nil { 1056 | return false 1057 | } 1058 | a.addMethods(intf) 1059 | continue 1060 | } 1061 | if !a.expr(val) { 1062 | return false 1063 | } 1064 | } 1065 | } 1066 | return true 1067 | 1068 | case *ast.FuncDecl: 1069 | outer := a.enclosingFunc 1070 | a.enclosingFunc = &funcDeclOrLit{decl: decl} 1071 | defer func() { a.enclosingFunc = outer }() 1072 | 1073 | return a.stmt(decl.Body) 1074 | 1075 | default: 1076 | return true 1077 | } 1078 | } 1079 | 1080 | func (a *analyzer) funcLit(expr *ast.FuncLit) bool { 1081 | outer := a.enclosingFunc 1082 | a.enclosingFunc = &funcDeclOrLit{lit: expr} 1083 | defer func() { 1084 | a.enclosingFunc = outer 1085 | }() 1086 | 1087 | return a.stmt(expr.Body) 1088 | } 1089 | 1090 | func (a *analyzer) switchStmt(stmt *ast.SwitchStmt) bool { 1091 | outer := a.enclosingSwitchStmt 1092 | a.enclosingSwitchStmt = stmt 1093 | defer func() { 1094 | 
a.enclosingSwitchStmt = outer 1095 | }() 1096 | 1097 | if !a.stmt(stmt.Init) { 1098 | return false 1099 | } 1100 | // It's OK if stmt.Tag is our object. 1101 | if !a.expr(stmt.Tag) { 1102 | return false 1103 | } 1104 | return a.stmt(stmt.Body) 1105 | } 1106 | 1107 | func getIdent(expr ast.Expr) *ast.Ident { 1108 | switch expr := expr.(type) { 1109 | case *ast.Ident: 1110 | return expr 1111 | case *ast.ParenExpr: 1112 | return getIdent(expr.X) 1113 | default: 1114 | return nil 1115 | } 1116 | } 1117 | 1118 | func isInternal(path string) bool { 1119 | parts := strings.Split(path, "/") 1120 | return slices.Contains(parts, "internal") 1121 | } 1122 | 1123 | func sameMethodMaps(a, b MethodMap) bool { 1124 | if len(a) != len(b) { 1125 | return false 1126 | } 1127 | for name, asig := range a { 1128 | bsig, ok := b[name] 1129 | if !ok { 1130 | return false 1131 | } 1132 | if !types.Identical(asig, bsig) { 1133 | return false 1134 | } 1135 | } 1136 | return true 1137 | } 1138 | --------------------------------------------------------------------------------