├── LICENSE
├── README.md
├── cmd
    └── add_argument
    │   ├── add_argument.go
    │   ├── main.go
    │   └── refactor.go
├── pos
    ├── pos.go
    └── query_pos.go
└── test
    └── original
        ├── main.go
        ├── x.go
        ├── y.go
        └── z.go


/LICENSE:
--------------------------------------------------------------------------------
1 | Copyright (c) 2014, Travis Cline
2 | 
3 | Permission to use, copy, modify, and/or distribute this software for any purpose
4 | with or without fee is hereby granted, provided that the above copyright notice
5 | and this permission notice appear in all copies.
6 | 
7 | THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES WITH
8 | REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY AND
9 | FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY SPECIAL, DIRECT,
10 | INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM LOSS
11 | OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER
12 | TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF
13 | THIS SOFTWARE.
14 | 
15 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | srcutils
2 | ========
3 | 
4 | Utilities to perform modifications on Go codebases.
5 | 
6 | * license: isc
7 | 
8 | Installation
9 | ------------
10 | 
11 | ```sh
12 | $ go get github.com/tmc/srcutils/cmd/add_argument
13 | ```
14 | 
15 | Utilities
16 | ---------
17 | 
18 | add_argument
19 | 
20 | Adds a new argument to a function's signature, propagates it up through the function's callers, and passes it at every affected call site.
21 | 
22 | Example:
23 | 
24 | ```sh
25 | $ add_argument -w -arg="ctx context.Context" -pos=$GOPATH/src/github.com/tmc/srcutils/test/original/z.go:#26 github.com/tmc/srcutils/test/original
26 | ```
27 | 
28 | Produces the diff:
29 | https://github.com/tmc/srcutils/commit/e70de1db99149dcf51940d1abbba0beba9779506
30 | 
--------------------------------------------------------------------------------
/cmd/add_argument/add_argument.go:
--------------------------------------------------------------------------------
1 | package main
2 | 
3 | import (
4 | 	"bytes"
5 | 	"fmt"
6 | 	"go/ast"
7 | 	"go/printer"
8 | 	"go/types"
9 | 	"io/ioutil"
10 | 	"log"
11 | 	"os"
12 | 	"regexp"
13 | 	"sort"
14 | 	"strings"
15 | 
16 | 	"golang.org/x/tools/imports"
17 | 
18 | 	"github.com/tmc/srcutils/pos"
19 | )
20 | 
21 | // commandAddArgument runs the add_argument command configured by options.
22 | //
23 | // options.argument must have the form "name type" (e.g. "ctx context.Context");
24 | // everything after the first space is the type. It is validated before the
25 | // program is loaded so that a malformed -arg fails fast instead of panicking.
26 | func commandAddArgument(options Options) error {
27 | 	parts := strings.SplitN(strings.TrimSpace(options.argument), " ", 2)
28 | 	if len(parts) != 2 || strings.TrimSpace(parts[1]) == "" {
29 | 		return fmt.Errorf("-arg must have the form \"name type\", got %q", options.argument)
30 | 	}
31 | 	argumentName, argumentType := parts[0], strings.TrimSpace(parts[1])
32 | 
33 | 	re, err := regexp.Compile(options.packageNameRe)
34 | 	if err != nil {
35 | 		return err
36 | 	}
37 | 
38 | 	r, err := newRefactor(options.args, re)
39 | 	if err != nil {
40 | 		return err
41 | 	}
42 | 
43 | 	return r.addArgument(argumentName, argumentType, options.position, options.skipExists, options.callgraphDepth)
44 | }
45 | 
46 | // addArgument resolves position to a function's parameter list, collects that
47 | // function and its callers up the call graph (at most depth levels; negative
48 | // means unlimited), prepends "argumentName argumentType" to each collected
49 | // signature and argumentName to each collected call site, and then writes
50 | // (with -w) or prints every modified file whose path matches r.packageNameRe.
51 | //
52 | // Failures to rewrite an individual signature or call site are logged and
53 | // skipped; printing, formatting and write failures abort the command.
54 | func (r *refactor) addArgument(argumentName, argumentType, position string, skipExists bool, depth int) error {
55 | 	qpos, err := r.queryPos(position, false)
56 | 	if err != nil {
57 | 		return err
58 | 	}
59 | 
60 | 	funcPositions, callSites, err := r.callersAndCallsites(qpos, depth)
61 | 	if err != nil {
62 | 		return err
63 | 	}
64 | 
65 | 	for _, funcPos := range funcPositions {
66 | 		if err := addArgument(argumentName, argumentType, funcPos, skipExists); err != nil {
67 | 			log.Println(err)
68 | 		}
69 | 	}
70 | 
71 | 	for _, callSite := range callSites {
72 | 		if err := addParameter(argumentName, callSite, skipExists); err != nil {
73 | 			log.Println(err)
74 | 		}
75 | 	}
76 | 
77 | 	// Collect the touched files keyed by name. The two slices are walked
78 | 	// separately rather than via append(funcPositions, callSites...), which
79 | 	// could write into funcPositions' backing array, and the names are sorted
80 | 	// because map iteration order would make the output order random.
81 | 	modifiedFiles := map[string]*ast.File{}
82 | 	for _, positions := range [][]*pos.QueryPos{funcPositions, callSites} {
83 | 		for _, p := range positions {
84 | 			file := p.Path[len(p.Path)-1].(*ast.File)
85 | 			modifiedFiles[r.iprog.Fset.Position(file.Pos()).Filename] = file
86 | 		}
87 | 	}
88 | 	fileNames := make([]string, 0, len(modifiedFiles))
89 | 	for fileName := range modifiedFiles {
90 | 		fileNames = append(fileNames, fileName)
91 | 	}
92 | 	sort.Strings(fileNames)
93 | 
94 | 	for _, fileName := range fileNames {
95 | 		if !r.packageNameRe.MatchString(fileName) {
96 | 			fmt.Fprintln(os.Stderr, "File didn't match:", fileName)
97 | 			continue
98 | 		}
99 | 
100 | 		var buf bytes.Buffer
101 | 		if err := printer.Fprint(&buf, r.iprog.Fset, modifiedFiles[fileName]); err != nil {
102 | 			return fmt.Errorf("printing %s: %v", fileName, err)
103 | 		}
104 | 
105 | 		formatted, err := imports.Process(fileName, buf.Bytes(), nil)
106 | 		if err != nil {
107 | 			return err
108 | 		}
109 | 
110 | 		if !options.write {
111 | 			fmt.Println(fileName)
112 | 			fmt.Println(string(formatted))
113 | 			continue
114 | 		}
115 | 		// The file already exists, so WriteFile truncates it in place; the
116 | 		// permission bits only apply if it has to be created.
117 | 		if err := ioutil.WriteFile(fileName, formatted, 0644); err != nil {
118 | 			return err
119 | 		}
120 | 		log.Println("wrote", fileName)
121 | 	}
122 | 	return nil
123 | }
124 | 
125 | // addArgument prepends the parameter "name argType" to the *ast.FieldList at
126 | // position.Path[0], i.e. the parameter list of the function being rewritten.
127 | //
128 | // When skipExists is true and the first existing parameter already has that
129 | // name and type, the list is left unchanged. Types are compared by their
130 | // printed form, so an existing "ctx context.Context" (a selector expression
131 | // in the AST) is recognized rather than tripping a failed type assertion.
132 | func addArgument(name, argType string, position *pos.QueryPos, skipExists bool) error {
133 | 	if len(position.Path) == 0 {
134 | 		return fmt.Errorf("got empty node path")
135 | 	}
136 | 	node := position.Path[0]
137 | 
138 | 	fieldList, ok := node.(*ast.FieldList)
139 | 	if !ok {
140 | 		// Report the location in the error instead of dumping the AST to
141 | 		// stdout, which also carries the rewritten source when -w is unset.
142 | 		return fmt.Errorf("pos must be in a FieldList, got: %T instead (at %s)",
143 | 			node, position.Fset.Position(node.Pos()))
144 | 	}
145 | 
146 | 	if skipExists && leadingFieldMatches(fieldList, name, argType) {
147 | 		return nil
148 | 	}
149 | 
150 | 	// An Ident holding the raw type text prints verbatim, which also works
151 | 	// for qualified types such as "context.Context".
152 | 	newField := &ast.Field{
153 | 		Names: []*ast.Ident{{Name: name}},
154 | 		Type:  &ast.Ident{Name: argType},
155 | 	}
156 | 	fieldList.List = append([]*ast.Field{newField}, fieldList.List...)
157 | 	return nil
158 | }
159 | 
160 | // leadingFieldMatches reports whether the first field in fields has name as
161 | // its first name and a type whose source form is argType. Unnamed parameters
162 | // (as in "func(int)") never match.
163 | func leadingFieldMatches(fields *ast.FieldList, name, argType string) bool {
164 | 	if len(fields.List) == 0 {
165 | 		return false
166 | 	}
167 | 	first := fields.List[0]
168 | 	if len(first.Names) == 0 || first.Names[0].Name != name {
169 | 		return false
170 | 	}
171 | 	return types.ExprString(first.Type) == argType
172 | }
173 | 
174 | // addParameter prepends the identifier name to the arguments of the
175 | // *ast.CallExpr at position.Path[0].
176 | //
177 | // When skipExists is true and the first argument is already the identifier
178 | // name, the call is left unchanged.
179 | func addParameter(name string, position *pos.QueryPos, skipExists bool) error {
180 | 	if len(position.Path) == 0 {
181 | 		return fmt.Errorf("got empty node path")
182 | 	}
183 | 	node := position.Path[0]
184 | 
185 | 	call, ok := node.(*ast.CallExpr)
186 | 	if !ok {
187 | 		return fmt.Errorf("pos must be in a CallExpr, got: %T instead", node)
188 | 	}
189 | 	if skipExists && len(call.Args) > 0 {
190 | 		if first, ok := call.Args[0].(*ast.Ident); ok && first.Name == name {
191 | 			return nil
192 | 		}
193 | 	}
194 | 	call.Args = append([]ast.Expr{&ast.Ident{Name: name}}, call.Args...)
195 | 	return nil
196 | }
197 | 
--------------------------------------------------------------------------------
/cmd/add_argument/main.go:
--------------------------------------------------------------------------------
1 | // Program add_argument inserts a new argument into a function and all of its
2 | // callers.
3 | //
4 | // Example:
5 | //
6 | //	$ add_argument -arg="foo int" -pos=$GOPATH/src/github.com/tmc/srcutils/test/original/z.go:#20 github.com/tmc/srcutils/test/original
7 | package main
8 | 
9 | import (
10 | 	"flag"
11 | 	"fmt"
12 | 	"os"
13 | )
14 | 
15 | // Options holds the command-line configuration: flags are registered in init
16 | // and the positional arguments are filled in by main.
17 | type Options struct {
18 | 	position       string   // -pos: file and byte offset of the function's parameter list
19 | 	argument       string   // -arg: argument to add, as "name type"
20 | 	args           []string // positional args, passed to loader.Config.FromArgs
21 | 	write          bool     // -w: overwrite the source files instead of printing them
22 | 	skipExists     bool     // -skip-exists: skip if specified name and type are already present
23 | 	packageNameRe  string   // -package-regexp: only files whose path matches are output
24 | 	callgraphDepth int      // -depth: how far up the call graph to make modifications (-1 = unlimited)
25 | }
26 | 
27 | var options Options
28 | 
29 | func init() {
30 | 	flag.StringVar(&options.argument, "arg", "",
31 | 		"argument to add to the specified function, as \"name type\" (e.g. \"ctx context.Context\")")
32 | 	flag.StringVar(&options.position, "pos", "",
33 | 		"Filename and byte offset or extent of a syntax element about which to "+
34 | 			"query, e.g. foo.go:#123,#456, bar.go:#123.")
35 | 	flag.BoolVar(&options.write, "w", false,
36 | 		"write result to (source) file instead of stdout")
37 | 	flag.BoolVar(&options.skipExists, "skip-exists", true,
38 | 		"if an argument appears to exist already don't add it")
39 | 	flag.StringVar(&options.packageNameRe, "package-regexp", "",
40 | 		"regular expression that file paths must match to be modified")
41 | 	flag.IntVar(&options.callgraphDepth, "depth", -1,
42 | 		"callgraph traversal limit (-1 for unlimited)")
43 | }
44 | 
45 | func main() {
46 | 	flag.Parse()
47 | 	options.args = flag.Args()
48 | 	if err := commandAddArgument(options); err != nil {
49 | 		fmt.Fprintf(os.Stderr, "Error: %s.\n", err)
50 | 		os.Exit(1)
51 | 	}
52 | }
53 | 
--------------------------------------------------------------------------------
/cmd/add_argument/refactor.go:
--------------------------------------------------------------------------------
1 | package main
2 | 
3 | import (
4 | 	"fmt"
5 | 	"go/ast"
6 | 	"go/build"
7 | 	"go/parser"
8 | 	"go/token"
9 | 	"regexp"
10 | 
11 | 	"github.com/tmc/srcutils/pos"
12 | 
13 | 	"golang.org/x/tools/go/loader"
14 | 	"golang.org/x/tools/go/pointer"
15 | 	"golang.org/x/tools/go/ssa"
16 | )
17 | 
18 | // refactor bundles the loaded program, its SSA form and the configuration
19 | // of the pointer analysis used to compute call graphs.
20 | type refactor struct {
21 | 	iprog         *loader.Program
22 | 	prog          *ssa.Program
23 | 	ptraCfg       *pointer.Config
24 | 	packageNameRe *regexp.Regexp // matched against file paths before output
25 | 
26 | 	// ptrResult caches the pointer analysis. It is a whole-program analysis
27 | 	// that does not depend on the queried position, so it is run at most
28 | 	// once instead of once per function visited in the call graph.
29 | 	ptrResult *pointer.Result
30 | }
31 | 
32 | // newRefactor loads the packages named by args (loader.Config.FromArgs
33 | // syntax), builds their SSA form and chooses the pointer-analysis roots:
34 | // the initial packages that have a main function, plus a synthesized test
35 | // main for the initial packages that have none.
36 | func newRefactor(args []string, packageNameRe *regexp.Regexp) (*refactor, error) {
37 | 	conf := loader.Config{
38 | 		Build:      &build.Default,
39 | 		ParserMode: parser.ParseComments,
40 | 	}
41 | 	args, err := conf.FromArgs(args, true)
42 | 	if err != nil {
43 | 		return nil, err
44 | 	}
45 | 	if len(args) > 0 {
46 | 		return nil, fmt.Errorf("surplus arguments: %q", args)
47 | 	}
48 | 
49 | 	iprog, err := conf.Load()
50 | 	if err != nil {
51 | 		return nil, err
52 | 	}
53 | 
54 | 	// NOTE(review): no SSA packages are created from iprog before Build, so
55 | 	// prog.Package below may return nil with newer x/tools; if it does, build
56 | 	// the program with ssautil.CreateProgram(iprog, mode) instead — confirm
57 | 	// against the x/tools version in use.
58 | 	var mode ssa.BuilderMode
59 | 	prog := ssa.NewProgram(iprog.Fset, mode)
60 | 	prog.Build()
61 | 
62 | 	// For each initial package (specified on the command line),
63 | 	// if it has a main function, analyze that,
64 | 	// otherwise analyze its tests, if any.
65 | 	var testPkgs, mains []*ssa.Package
66 | 	for _, info := range iprog.InitialPackages() {
67 | 		initialPkg := prog.Package(info.Pkg)
68 | 		if initialPkg == nil {
69 | 			return nil, fmt.Errorf("no SSA package for %s", info.Pkg.Path())
70 | 		}
71 | 
72 | 		// Add package to the pointer analysis scope.
73 | 		if initialPkg.Func("main") != nil {
74 | 			mains = append(mains, initialPkg)
75 | 		} else {
76 | 			testPkgs = append(testPkgs, initialPkg)
77 | 		}
78 | 	}
79 | 	if testPkgs != nil {
80 | 		if p := prog.CreateTestMainPackage(testPkgs...); p != nil {
81 | 			mains = append(mains, p)
82 | 		}
83 | 	}
84 | 	if mains == nil {
85 | 		return nil, fmt.Errorf("analysis scope has no main and no tests")
86 | 	}
87 | 
88 | 	return &refactor{
89 | 		iprog:         iprog,
90 | 		prog:          prog,
91 | 		ptraCfg:       &pointer.Config{Mains: mains, BuildCallGraph: true},
92 | 		packageNameRe: packageNameRe,
93 | 	}, nil
94 | }
95 | 
96 | // pointerAnalysis runs the pointer analysis on first use, removes synthetic
97 | // nodes from the resulting call graph, and returns the cached result on
98 | // every later call.
99 | func (r *refactor) pointerAnalysis() (*pointer.Result, error) {
100 | 	if r.ptrResult == nil {
101 | 		result, err := pointer.Analyze(r.ptraCfg)
102 | 		if err != nil {
103 | 			return nil, err
104 | 		}
105 | 		result.CallGraph.DeleteSyntheticNodes()
106 | 		r.ptrResult = result
107 | 	}
108 | 	return r.ptrResult, nil
109 | }
110 | 
111 | // callers returns the call sites, as query positions, of the function that
112 | // encloses qpos according to the pointer-analysis call graph. Edges whose
113 | // caller node has ID <= 1 are skipped (presumably the graph's synthetic
114 | // root; NOTE(review): confirm).
115 | func (r *refactor) callers(qpos *pos.QueryPos) ([]*pos.QueryPos, error) {
116 | 	pkg := r.prog.Package(qpos.Info.Pkg)
117 | 	if pkg == nil {
118 | 		return nil, fmt.Errorf("no SSA package")
119 | 	}
120 | 	if !ssa.HasEnclosingFunction(pkg, qpos.Path) {
121 | 		return nil, fmt.Errorf("this position is not inside a function")
122 | 	}
123 | 
124 | 	target := ssa.EnclosingFunction(pkg, qpos.Path)
125 | 	if target == nil {
126 | 		return nil, fmt.Errorf("no SSA function built for this location (dead code?)")
127 | 	}
128 | 
129 | 	ptrAnalysis, err := r.pointerAnalysis()
130 | 	if err != nil {
131 | 		return nil, err
132 | 	}
133 | 
134 | 	edges := ptrAnalysis.CallGraph.CreateNode(target).In
135 | 	callers := make([]*pos.QueryPos, 0, len(edges))
136 | 	for _, edge := range edges {
137 | 		if edge.Caller.ID <= 1 {
138 | 			continue
139 | 		}
140 | 		caller, err := r.posToQueryPos(edge.Pos())
141 | 		if err != nil {
142 | 			return nil, err
143 | 		}
144 | 		callers = append(callers, caller)
145 | 	}
146 | 	return callers, nil
147 | }
148 | 
149 | // callersAndCallsites returns the parameter-list positions of the function at
150 | // qpos and of the callers reached by walking up the call graph (at most depth
151 | // levels; a negative depth is unlimited), together with the positions of the
152 | // call sites that invoke them. Both results are de-duplicated by start
153 | // position and come back in no particular order.
154 | func (r *refactor) callersAndCallsites(qpos *pos.QueryPos, depth int) ([]*pos.QueryPos, []*pos.QueryPos, error) {
155 | 	allCallers, allCallsites := map[token.Pos]*pos.QueryPos{}, map[token.Pos]*pos.QueryPos{}
156 | 
157 | 	err := r.addCallersAndCallsites(qpos, depth, allCallers, allCallsites)
158 | 	if err != nil {
159 | 		return nil, nil, err
160 | 	}
161 | 
162 | 	resultCallers := make([]*pos.QueryPos, 0, len(allCallers))
163 | 	for _, caller := range allCallers {
164 | 		resultCallers = append(resultCallers, caller)
165 | 	}
166 | 	resultCallsites := make([]*pos.QueryPos, 0, len(allCallsites))
167 | 	for _, callsite := range allCallsites {
168 | 		resultCallsites = append(resultCallsites, callsite)
169 | 	}
170 | 	return resultCallers, resultCallsites, nil
171 | }
172 | 
173 | // addCallersAndCallsites records qpos in allCallers and each call site of its
174 | // function in allCallsites, then recurses with depth-1 into the function
175 | // declaration enclosing every call site. Positions already recorded in
176 | // allCallers are not revisited, and a depth of 0 stops the walk.
177 | func (r *refactor) addCallersAndCallsites(qpos *pos.QueryPos, depth int, allCallers, allCallsites map[token.Pos]*pos.QueryPos) error {
178 | 	if _, present := allCallers[qpos.Start]; present {
179 | 		return nil
180 | 	}
181 | 	if depth == 0 {
182 | 		return nil
183 | 	}
184 | 	allCallers[qpos.Start] = qpos
185 | 	callers, err := r.callers(qpos)
186 | 	if err != nil {
187 | 		return err
188 | 	}
189 | 	for _, caller := range callers {
190 | 		allCallsites[caller.Start] = caller
191 | 
192 | 		parent, err := r.parentFunc(caller.Path)
193 | 		if err != nil {
194 | 			return err
195 | 		}
196 | 		if parent != nil {
197 | 			if err := r.addCallersAndCallsites(parent, depth-1, allCallers, allCallsites); err != nil {
198 | 				return err
199 | 			}
200 | 		}
201 | 	}
202 | 	return nil
203 | }
204 | 
205 | // queryPos parses a "file:#offset" or "file:#start,#end" position against the
206 | // loaded program. When needExact is set the position must identify a single
207 | // AST subtree.
208 | func (r *refactor) queryPos(position string, needExact bool) (*pos.QueryPos, error) {
209 | 	return pos.ParseQueryPos(r.iprog, position, needExact)
210 | }
211 | 
212 | // posToQueryPos converts a position in the loaded program into a QueryPos.
213 | func (r *refactor) posToQueryPos(p token.Pos) (*pos.QueryPos, error) {
214 | 	position := r.prog.Fset.Position(p)
215 | 	return r.queryPos(fmt.Sprintf("%s:#%d", position.Filename, position.Offset), false)
216 | }
217 | 
218 | // parentFunc returns the parameter-list position of the first function
219 | // declaration on path (which runs from a node up to its *ast.File). A call
220 | // inside a function literal therefore resolves to the enclosing declaration.
221 | func (r *refactor) parentFunc(path []ast.Node) (*pos.QueryPos, error) {
222 | 	for _, node := range path {
223 | 		// TODO consider function literals
224 | 		if fn, ok := node.(*ast.FuncDecl); ok {
225 | 			return r.posToQueryPos(fn.Type.Params.Pos())
226 | 		}
227 | 	}
228 | 	if len(path) == 0 {
229 | 		return nil, fmt.Errorf("no parent found: empty path")
230 | 	}
231 | 	return nil, fmt.Errorf("no parent found: %s", r.prog.Fset.Position(path[0].Pos()))
232 | }
233 | 
--------------------------------------------------------------------------------
/pos/pos.go:
--------------------------------------------------------------------------------
1 | package pos
2 | 
3 | import (
4 | 	"fmt"
5 | 	"go/token"
6 | 	"os"
7 | 	"path/filepath"
8 | 	"strconv"
9 | 	"strings"
10 | )
11 | 
12 | // parseOctothorpDecimal returns the numeric value if s matches "#%d",
13 | // otherwise -1.
14 | func parseOctothorpDecimal(s string) int {
15 | 	if s != "" && s[0] == '#' {
16 | 		if n, err := strconv.ParseInt(s[1:], 10, 32); err == nil {
17 | 			return int(n)
18 | 		}
19 | 	}
20 | 	return -1
21 | }
22 | 
23 | // ParsePosFlag parses a string of the form "file:pos" or
24 | // "file:start,end" where pos, start, end match #%d and represent byte
25 | // offsets, and returns its components.
26 | //
27 | // (Numbers without a '#' prefix are reserved for future use,
28 | // e.g. to indicate line/column positions.)
29 | //
30 | func ParsePosFlag(posFlag string) (filename string, startOffset, endOffset int, err error) {
31 | 	if posFlag == "" {
32 | 		err = fmt.Errorf("no source position specified (-pos flag)")
33 | 		return
34 | 	}
35 | 
36 | 	colon := strings.LastIndex(posFlag, ":")
37 | 	if colon < 0 {
38 | 		err = fmt.Errorf("invalid source position -pos=%q", posFlag)
39 | 		return
40 | 	}
41 | 	filename, offset := posFlag[:colon], posFlag[colon+1:]
42 | 	startOffset = -1
43 | 	endOffset = -1
44 | 	if comma := strings.Index(offset, ","); comma < 0 {
45 | 		// e.g. "foo.go:#123"
46 | 		startOffset = parseOctothorpDecimal(offset)
47 | 		endOffset = startOffset
48 | 	} else {
49 | 		// e.g. "foo.go:#123,#456"
50 | 		startOffset = parseOctothorpDecimal(offset[:comma])
51 | 		endOffset = parseOctothorpDecimal(offset[comma+1:])
52 | 	}
53 | 	if startOffset < 0 || endOffset < 0 {
54 | 		err = fmt.Errorf("invalid -pos offset %q", offset)
55 | 		return
56 | 	}
57 | 	return
58 | }
59 | 
60 | // findQueryPos searches fset for filename and translates the
61 | // specified file-relative byte offsets into token.Pos form. It
62 | // returns an error if the file was not found or the offsets were out
63 | // of bounds.
64 | //
65 | func findQueryPos(fset *token.FileSet, filename string, startOffset, endOffset int) (start, end token.Pos, err error) {
66 | 	var file *token.File
67 | 	fset.Iterate(func(f *token.File) bool {
68 | 		if sameFile(filename, f.Name()) {
69 | 			// (f.Name() is absolute)
70 | 			file = f
71 | 			return false // done
72 | 		}
73 | 		return true // continue
74 | 	})
75 | 	if file == nil {
76 | 		err = fmt.Errorf("couldn't find file containing position")
77 | 		return
78 | 	}
79 | 
80 | 	// Range check [start..end], inclusive of both end-points.
81 | 
82 | 	if 0 <= startOffset && startOffset <= file.Size() {
83 | 		start = file.Pos(startOffset)
84 | 	} else {
85 | 		err = fmt.Errorf("start position is beyond end of file")
86 | 		return
87 | 	}
88 | 
89 | 	if 0 <= endOffset && endOffset <= file.Size() {
90 | 		end = file.Pos(endOffset)
91 | 	} else {
92 | 		err = fmt.Errorf("end position is beyond end of file")
93 | 		return
94 | 	}
95 | 
96 | 	return
97 | }
98 | 
99 | // sameFile returns true if x and y have the same basename and denote
100 | // the same file.
101 | //
102 | func sameFile(x, y string) bool {
103 | 	if filepath.Base(x) == filepath.Base(y) { // (optimisation)
104 | 		if xi, err := os.Stat(x); err == nil {
105 | 			if yi, err := os.Stat(y); err == nil {
106 | 				return os.SameFile(xi, yi)
107 | 			}
108 | 		}
109 | 	}
110 | 	return false
111 | }
112 | 
--------------------------------------------------------------------------------
/pos/query_pos.go:
--------------------------------------------------------------------------------
1 | package pos
2 | 
3 | import (
4 | 	"fmt"
5 | 	"go/ast"
6 | 	"go/token"
7 | 
8 | 	"golang.org/x/tools/go/ast/astutil"
9 | 	"golang.org/x/tools/go/loader"
10 | )
11 | 
12 | // A QueryPos represents the position provided as input to a query:
13 | // a textual extent in the program's source code, the AST node it
14 | // corresponds to, and the package to which it belongs.
15 | // Instances are created by ParseQueryPos.
16 | //
17 | type QueryPos struct {
18 | 	Fset       *token.FileSet
19 | 	Start, End token.Pos           // source extent of query
20 | 	Path       []ast.Node          // AST path from query node to root of ast.File
21 | 	Exact      bool                // 2nd result of PathEnclosingInterval
22 | 	Info       *loader.PackageInfo // type info for the queried package
23 | }
24 | 
25 | // ParseQueryPos parses the source query position pos.
26 | // If needExact, it must identify a single AST subtree;
27 | // this is appropriate for queries that allow fairly arbitrary syntax,
28 | // e.g. "describe".
29 | //
30 | func ParseQueryPos(iprog *loader.Program, posFlag string, needExact bool) (*QueryPos, error) {
31 | 	filename, startOffset, endOffset, err := ParsePosFlag(posFlag)
32 | 	if err != nil {
33 | 		return nil, err
34 | 	}
35 | 	start, end, err := findQueryPos(iprog.Fset, filename, startOffset, endOffset)
36 | 
37 | 	if err != nil {
38 | 		return nil, err
39 | 	}
40 | 	info, path, exact := iprog.PathEnclosingInterval(start, end)
41 | 	if path == nil {
42 | 		return nil, fmt.Errorf("no syntax here")
43 | 	}
44 | 	if needExact && !exact {
45 | 		return nil, fmt.Errorf("ambiguous selection within %s", astutil.NodeDescription(path[0]))
46 | 	}
47 | 	return &QueryPos{iprog.Fset, start, end, path, exact, info}, nil
48 | }
49 | 
--------------------------------------------------------------------------------
/test/original/main.go:
--------------------------------------------------------------------------------
1 | package main
2 | 
3 | func main() {
4 | 	x(42, "life")
5 | }
6 | 
7 | func init() {
8 | 	z(31, "foobar")
9 | }
10 | 
11 | func init() {
12 | 	alreadyPresent(42)
13 | }
14 | 
15 | // shouldn't be modified
16 | func alreadyPresent(foo int) {
17 | 	x(31, "foobar")
18 | 	x(foo, "foobar")
19 | }
20 | 
21 | func init() {
22 | 	new(dummy).test()
23 | }
24 | 
25 | type dummy struct{}
26 | 
27 | func (d *dummy) test() {
28 | 	baz := &struct{ foo func() }{
29 | 		func() {
30 | 			x(7, "oi")
31 | 		},
32 | 	}
33 | 	baz.foo()
34 | }
35 | 
--------------------------------------------------------------------------------
/test/original/x.go:
--------------------------------------------------------------------------------
1 | package main
2 | 
3 | func x(a int, b string) {
4 | 	y(a, b)
5 | }
6 | 
--------------------------------------------------------------------------------
/test/original/y.go:
--------------------------------------------------------------------------------
1 | package main
2 | 
3 | func y(a int, b string) {
4 | 	z(a, b)
5 | }
6 | 
--------------------------------------------------------------------------------
/test/original/z.go:
--------------------------------------------------------------------------------
1 | package main
2 | 
3 | func z(a int, b string) {
4 | 	println("⚛")
5 | }
6 | 
--------------------------------------------------------------------------------