├── .travis.yml ├── LICENSE ├── README.md ├── astnorm.go ├── cmd ├── go-normalize │ └── main.go ├── grepfunc │ └── main.go └── internal │ └── loadfile │ └── loadfile.go ├── example ├── demo.bash ├── filter.go ├── mylib1 │ └── mylib1.go ├── mylib1n │ └── mylib1.go ├── mylib2 │ └── mylib2.go └── mylib2n │ └── mylib2.go ├── logo.jpg ├── normalizer.go ├── normalizer_test.go ├── testdata ├── normalize_expr.go └── normalize_stmt.go └── utils.go /.travis.yml: -------------------------------------------------------------------------------- 1 | language: go 2 | go: 3 | - 1.x 4 | install: 5 | - # Prevent default install action "go get -t -v ./...". 6 | script: 7 | - go get -t -v ./... 8 | - go vet ./... # vet every package (cmd/ included), not only the root one 9 | - go test -v -race ./... -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Iskander (Alex) Sharipov / Quasilyte 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | [![Go Report Card](https://goreportcard.com/badge/github.com/Quasilyte/astnorm)](https://goreportcard.com/report/github.com/Quasilyte/astnorm) 2 | [![GoDoc](https://godoc.org/github.com/Quasilyte/astnorm?status.svg)](https://godoc.org/github.com/Quasilyte/astnorm) 3 | [![Build Status](https://travis-ci.org/Quasilyte/astnorm.svg?branch=master)](https://travis-ci.org/Quasilyte/astnorm) 4 | 5 | ![logo](/logo.jpg) 6 | 7 | # astnorm 8 | 9 | Go AST normalization experiment. 10 | 11 | > THIS IS NOT A PROPER LIBRARY (yet?).
12 | > DO NOT USE.
13 | > It will probably be completely re-written before it becomes usable. 14 | 15 | ## Normalized code examples 16 | 17 | 1. Swap values. 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 40 |
BeforeAfter
26 | 27 | ```go 28 | tmp := xs[i] 29 | xs[i] = ys[i] 30 | ys[i] = tmp 31 | ``` 32 | 33 | 34 | 35 | ```go 36 | xs[i], ys[i] = ys[i], xs[i] 37 | ``` 38 | 39 |
41 | 42 | 2. Remove elements that are equal to `toRemove+1`. 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 79 |
BeforeAfter
51 | 52 | ```go 53 | const toRemove = 10 54 | var filtered []int 55 | filtered = xs[0:0] 56 | for i := int(0); i < len(xs); i++ { 57 | x := xs[i] 58 | if toRemove+1 != x { 59 | filtered = append(filtered, x) 60 | } 61 | } 62 | return (filtered) 63 | ``` 64 | 65 | 66 | 67 | ```go 68 | filtered := []int(nil) 69 | filtered = xs[:0] 70 | for _, x := range xs { 71 | if x != 11 { 72 | filtered = append(filtered, x) 73 | } 74 | } 75 | return filtered 76 | ``` 77 | 78 |
80 | 81 | ## Usage examples 82 | 83 | * [cmd/go-normalize](/cmd/go-normalize): normalize a given Go file 84 | * [cmd/grepfunc](/cmd/grepfunc): turn Go code into a pattern for `gogrep` and run it 85 | 86 | Potential workflow for code searching: 87 | 88 | ### 1. Code search 89 | 90 | * Normalize the entire Go stdlib 91 | * Then normalize your function 92 | * Run `grepfunc` against the normalized stdlib 93 | * If the function you implemented already has an implementation in the stdlib, you'll probably find it 94 | 95 | Basically, instead of the stdlib you can use any kind of Go corpus. 96 | 97 | Other code-search-related tasks that `astnorm` can simplify are code similarity 98 | evaluation and code duplication detection of any kind. 99 | 100 | ### 2. Static analysis 101 | 102 | Suppose we have a `badcode.go` file: 103 | 104 | ```go 105 | package badpkg 106 | 107 | func NotEqual(x1, x2 int) bool { 108 | return (x1) != x1 109 | } 110 | ``` 111 | 112 | There is an obvious mistake there: `x1` is used twice. But because of the extra parentheses, linters may not detect this issue: 113 | 114 | ```bash 115 | $ staticcheck badcode.go 116 | # No output 117 | ``` 118 | 119 | Let's normalize the input first and then run `staticcheck`: 120 | 121 | ```bash 122 | go-normalize badcode.go > normalized_badcode.go 123 | staticcheck normalized_badcode.go 124 | normalized_badcode.go:4:9: identical expressions on the left and right side of the '!=' operator (SA4000) 125 | ``` 126 | 127 | And we get the warning we deserve! 128 | No changes to `staticcheck` or any other linter are required. 129 | 130 | See also: [demo script](/example/demo.bash). 131 | -------------------------------------------------------------------------------- /astnorm.go: -------------------------------------------------------------------------------- 1 | // Package astnorm implements AST normalization routines.
2 | package astnorm 3 | 4 | import ( 5 | "go/ast" 6 | "go/types" 7 | ) 8 | 9 | // Config carries information needed to properly normalize 10 | // AST nodes as well as optional configuration values 11 | // to control different aspects of the process. 12 | type Config struct { 13 | Info *types.Info 14 | } 15 | 16 | // Expr returns normalized expression x. 17 | // x may be mutated. 18 | func Expr(cfg *Config, x ast.Expr) ast.Expr { 19 | return newNormalizer(cfg).normalizeExpr(x) 20 | } 21 | 22 | // Stmt returns normalized statement x. 23 | // x may be mutated. 24 | func Stmt(cfg *Config, x ast.Stmt) ast.Stmt { 25 | return newNormalizer(cfg).normalizeStmt(x) 26 | } 27 | 28 | // Block returns normalized block x. 29 | // x may be mutated and must not be nil. 30 | func Block(cfg *Config, x *ast.BlockStmt) *ast.BlockStmt { 31 | return newNormalizer(cfg).normalizeBlockStmt(x) 32 | } 33 | -------------------------------------------------------------------------------- /cmd/go-normalize/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "flag" 5 | "go/ast" 6 | "go/printer" 7 | "go/token" 8 | "log" 9 | "os" 10 | 11 | "github.com/Quasilyte/astnorm" 12 | "github.com/Quasilyte/astnorm/cmd/internal/loadfile" 13 | ) 14 | 15 | func main() { 16 | log.SetFlags(0) 17 | 18 | flag.Parse() 19 | 20 | targets := flag.Args() 21 | if len(targets) != 1 { 22 | log.Panicf("expected exactly 1 positional argument (input go file)") 23 | } 24 | 25 | // For now, handle only 1 input file case. 26 | // For simplicity reasons.
27 | f, info, err := loadfile.ByPath(targets[0]) 28 | if err != nil { 29 | log.Panicf("loadfile: %v", err) 30 | } 31 | normalizationConfig := &astnorm.Config{ 32 | Info: info, 33 | } 34 | f = normalizeFile(normalizationConfig, f) 35 | fset := token.NewFileSet() 36 | if err := printer.Fprint(os.Stdout, fset, f); err != nil { 37 | log.Panicf("print normalized file: %v", err) 38 | } 39 | } 40 | 41 | func normalizeFile(cfg *astnorm.Config, f *ast.File) *ast.File { 42 | // Strip comments. With f.Comments == nil, go/printer falls back to printing node Doc and Comment fields, so clear both. 43 | f.Doc = nil 44 | ast.Inspect(f, func(n ast.Node) bool { 45 | switch n := n.(type) { 46 | case *ast.FuncDecl: 47 | n.Doc = nil 48 | case *ast.GenDecl: 49 | n.Doc = nil 50 | case *ast.Field: 51 | n.Doc, n.Comment = nil, nil 52 | case *ast.ImportSpec: 53 | n.Doc, n.Comment = nil, nil 54 | case *ast.ValueSpec: 55 | n.Doc, n.Comment = nil, nil 56 | case *ast.TypeSpec: 57 | n.Doc, n.Comment = nil, nil 58 | default: 59 | } 60 | return true 61 | }) 62 | f.Comments = nil 63 | 64 | for _, decl := range f.Decls { 65 | // TODO(quasilyte): could also normalize global vars, 66 | // consts and type defs, but funcs are OK for the POC.
67 | // External (body-less) funcs have nothing to normalize, and astnorm.Block must not get a nil block. 68 | if decl, ok := decl.(*ast.FuncDecl); ok && decl.Body != nil { 69 | decl.Body = astnorm.Block(cfg, decl.Body) 70 | } 71 | } 72 | return f 73 | } 74 | -------------------------------------------------------------------------------- /cmd/grepfunc/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "flag" 5 | "fmt" 6 | "go/ast" 7 | "go/types" 8 | "log" 9 | "os/exec" 10 | "strings" 11 | 12 | "github.com/Quasilyte/astnorm/cmd/internal/loadfile" 13 | "github.com/go-toolsmith/astfmt" 14 | "github.com/go-toolsmith/typep" 15 | ) 16 | 17 | func main() { 18 | log.SetFlags(0) 19 | 20 | input := flag.String("input", "", 21 | `input Go file with pattern function`) 22 | pattern := flag.String("pattern", "_pattern", 23 | `function to be interpreted as a pattern`) 24 | verbose := flag.Bool("v", false, 25 | `turn on debug output`) 26 | flag.Parse() 27 | 28 | if *input == "" { 29 | log.Panic("-input argument can't be empty") 30 | } 31 | targets := flag.Args() 32 | 33 | f, info, err := loadfile.ByPath(*input) 34 | if err != nil { 35 | log.Panicf("loadfile: %v", err) 36 | } 37 | 38 | var fndecl *ast.FuncDecl 39 | for _, decl := range f.Decls { 40 | decl, ok := decl.(*ast.FuncDecl) 41 | if ok && decl.Name.Name == *pattern { 42 | fndecl = decl 43 | break 44 | } 45 | } 46 | if fndecl == nil { 47 | log.Panicf("found no `%s` func in %q", *pattern, *input) 48 | } 49 | if fndecl.Body == nil { 50 | log.Panic("external funcs are not supported") 51 | } 52 | pat := makeGogrepPattern(info, fndecl.Body) 53 | s := astfmt.Sprint(pat) 54 | s = strings.TrimPrefix(s, "{") 55 | s = strings.TrimSuffix(s, "}") 56 | 57 | if *verbose { 58 | fmt.Println(s) 59 | } 60 | 61 | gogrepArgs := []string{"-x", s} 62 | gogrepArgs = append(gogrepArgs, targets...)
63 | out, err := exec.Command("gogrep", gogrepArgs...).CombinedOutput() 64 | if err != nil { 65 | log.Panicf("run gogrep: %v: %s", err, out) 66 | } 67 | fmt.Print(string(out)) 68 | } 69 | 70 | type visitor struct { 71 | info *types.Info 72 | } 73 | 74 | func (v *visitor) visitNode(x ast.Node) bool { 75 | switch x := x.(type) { 76 | case *ast.Ident: 77 | // Don't replace type names. 78 | if typep.IsTypeExpr(v.info, x) { 79 | return true 80 | } 81 | x.Name = "$" + x.Name 82 | return true 83 | case *ast.CallExpr: 84 | // Don't replace function names: x.Fun is skipped entirely (so a method receiver like obj in obj.M() is kept too). 85 | for _, arg := range x.Args { 86 | ast.Inspect(arg, v.visitNode) 87 | } 88 | return false 89 | default: 90 | return true 91 | } 92 | } 93 | 94 | func makeGogrepPattern(info *types.Info, body *ast.BlockStmt) *ast.BlockStmt { 95 | v := &visitor{info: info} 96 | ast.Inspect(body, v.visitNode) 97 | return body 98 | } 99 | -------------------------------------------------------------------------------- /cmd/internal/loadfile/loadfile.go: -------------------------------------------------------------------------------- 1 | package loadfile 2 | 3 | import ( 4 | "fmt" 5 | "go/ast" 6 | "go/types" 7 | 8 | "golang.org/x/tools/go/packages" 9 | ) 10 | 11 | func ByPath(path string) (*ast.File, *types.Info, error) { 12 | cfg := &packages.Config{Mode: packages.LoadSyntax} 13 | pkgs, err := packages.Load(cfg, path) 14 | if err != nil { 15 | return nil, nil, fmt.Errorf("load: %w", err) 16 | } 17 | if errCount := packages.PrintErrors(pkgs); errCount != 0 { 18 | return nil, nil, fmt.Errorf("%d errors during package loading", errCount) 19 | } 20 | if len(pkgs) != 1 { 21 | return nil, nil, fmt.Errorf("loaded %d packages, expected only 1", len(pkgs)) 22 | } 23 | pkg := pkgs[0] 24 | if len(pkg.Syntax) != 1 { 25 | err := fmt.Errorf("loaded package has %d files, expected only 1", 26 | len(pkg.Syntax)) 27 | return nil, nil, err 28 | } 29 | return pkg.Syntax[0], pkg.TypesInfo, nil 30 | } 31 | 
-------------------------------------------------------------------------------- /example/demo.bash: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # NOTE(review): the relative paths below assume this is run from the example/ directory, with go-normalize, grepfunc and gogrep on PATH. 3 | # Both packages have a func that re-implements strings.Repeat. 4 | # Suppose we found one such place in our codebase and replaced calls to it 5 | # with strings.Repeat. But maybe there are more such funcs? 6 | 7 | cat mylib1/mylib1.go 8 | go-normalize mylib1/mylib1.go > mylib1n/mylib1.go 9 | diff mylib1/mylib1.go mylib1n/mylib1.go 10 | 11 | cat mylib2/mylib2.go 12 | go-normalize mylib2/mylib2.go > mylib2n/mylib2.go 13 | diff mylib2/mylib2.go mylib2n/mylib2.go 14 | 15 | # Use grepfunc to create a pattern from Go code. 16 | # With syntax patterns, we can now ignore variable 17 | # name differences and find both functions 18 | # by either of them (-v also prints the generated pattern). 19 | grepfunc -input mylib1n/mylib1.go -pattern=makeString ./... 20 | grepfunc -v -input mylib2n/mylib2.go -pattern=repeatString ./... 21 | 22 | # For more examples, consult the README file.
23 | -------------------------------------------------------------------------------- /example/filter.go: -------------------------------------------------------------------------------- 1 | package example 2 | 3 | func _(xs []int) []int { 4 | const toRemove = 10 5 | var filtered []int 6 | filtered = xs[0:0] 7 | for i := int(0); i < len(xs); i++ { 8 | x := xs[i] 9 | if toRemove+1 != x { 10 | filtered = append(filtered, x) 11 | } 12 | } 13 | return (filtered) 14 | } 15 | -------------------------------------------------------------------------------- /example/mylib1/mylib1.go: -------------------------------------------------------------------------------- 1 | package mylib1 2 | 3 | import "strings" 4 | 5 | func makeString(str string, num int) string { 6 | var parts = make([]string, num) 7 | for i := 0; i < len(parts); i++ { 8 | parts[i] = str 9 | } 10 | return strings.Join(parts, "") 11 | } 12 | -------------------------------------------------------------------------------- /example/mylib1n/mylib1.go: -------------------------------------------------------------------------------- 1 | package mylib1 2 | 3 | import "strings" 4 | 5 | func makeString(str string, num int) string { 6 | parts := make([]string, num) 7 | for i := range parts { 8 | parts[i] = str 9 | } 10 | return strings.Join(parts, "") 11 | } 12 | -------------------------------------------------------------------------------- /example/mylib2/mylib2.go: -------------------------------------------------------------------------------- 1 | package mylib2 2 | 3 | import "strings" 4 | 5 | func repeatString(s string, n int) string { 6 | var pieces = make([]string, n) 7 | for i := range pieces { 8 | pieces[i] = s 9 | } 10 | const sep = "" 11 | return strings.Join(pieces, sep) 12 | } 13 | -------------------------------------------------------------------------------- /example/mylib2n/mylib2.go: -------------------------------------------------------------------------------- 1 | package mylib2 2 | 3 | import
"strings" 4 | 5 | func repeatString(s string, n int) string { 6 | pieces := make([]string, n) 7 | for i := range pieces { 8 | pieces[i] = s 9 | } 10 | return strings.Join(pieces, "") 11 | } 12 | -------------------------------------------------------------------------------- /logo.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/quasilyte/astnorm/544300fc3e48de9779bafaf42eb34bff76fde35e/logo.jpg -------------------------------------------------------------------------------- /normalizer.go: -------------------------------------------------------------------------------- 1 | package astnorm 2 | 3 | import ( 4 | "fmt" 5 | "go/ast" 6 | "go/constant" 7 | "go/token" 8 | "go/types" 9 | 10 | "github.com/go-toolsmith/astcast" 11 | "github.com/go-toolsmith/astequal" 12 | "github.com/go-toolsmith/astp" 13 | "github.com/go-toolsmith/typep" 14 | ) 15 | 16 | type normalizer struct { 17 | cfg *Config 18 | } 19 | 20 | func newNormalizer(cfg *Config) *normalizer { 21 | return &normalizer{ 22 | cfg: cfg, 23 | } 24 | } 25 | 26 | func (n *normalizer) foldConstexpr(x ast.Expr) ast.Expr { 27 | if astp.IsParenExpr(x) { 28 | return nil 29 | } 30 | tv := n.cfg.Info.Types[x] 31 | if tv.Value == nil { 32 | return nil 33 | } 34 | 35 | if lit, ok := x.(*ast.BasicLit); ok && lit.Kind == token.FLOAT { 36 | // Floats may require additional handling here. 37 | if tv.Value.Kind() == constant.Int { 38 | // Usually, this case means that value is 0. 39 | // But to be sure, keep this assertion here.
40 | v, exact := constant.Int64Val(tv.Value) 41 | if !exact || v != 0 { 42 | panic(fmt.Sprintf("unexpected value for float with kind=int: %v", tv.Value)) 43 | } 44 | return &ast.BasicLit{ 45 | Kind: token.FLOAT, 46 | Value: "0.0", 47 | } 48 | } 49 | } 50 | 51 | if call, ok := x.(*ast.CallExpr); ok && typep.IsTypeExpr(n.cfg.Info, call.Fun) { 52 | if !isDefaultLiteralType(n.cfg.Info.TypeOf(call).Underlying()) { 53 | call.Args[0] = constValueNode(tv.Value) 54 | return call 55 | } 56 | } 57 | 58 | return constValueNode(tv.Value) 59 | } 60 | 61 | func (n *normalizer) normalizeExpr(x ast.Expr) ast.Expr { 62 | if folded := n.foldConstexpr(x); folded != nil { 63 | return folded 64 | } 65 | 66 | switch x := x.(type) { 67 | case *ast.CallExpr: 68 | if typep.IsTypeExpr(n.cfg.Info, x.Fun) { 69 | return n.normalizeTypeConversion(x) 70 | } 71 | x.Fun = n.normalizeExpr(x.Fun) 72 | x.Args = n.normalizeExprList(x.Args) 73 | return x 74 | case *ast.SliceExpr: 75 | return n.normalizeSliceExpr(x) 76 | case *ast.ParenExpr: 77 | return n.normalizeExpr(x.X) 78 | case *ast.BinaryExpr: 79 | return n.normalizeBinaryExpr(x) 80 | default: 81 | return x 82 | } 83 | } 84 | 85 | func (n *normalizer) normalizeTypeConversion(x *ast.CallExpr) ast.Expr { 86 | typeTo, typeFrom := n.cfg.Info.TypeOf(x), n.cfg.Info.TypeOf(x.Args[0]) // Query Info first: rewritten operands may be nodes Info does not know. 87 | x.Args[0] = n.normalizeExpr(x.Args[0]) 88 | if types.Identical(typeTo, typeFrom) { 89 | return x.Args[0] // Converting to the operand's own type is a no-op. 90 | } 91 | return x 92 | } 93 | 94 | func (n *normalizer) normalizeSliceExpr(x *ast.SliceExpr) ast.Expr { 95 | x.Low = n.normalizeExpr(x.Low) 96 | x.High = n.normalizeExpr(x.High) 97 | x.Max = n.normalizeExpr(x.Max) 98 | x.X = n.normalizeExpr(x.X) 99 | // Omit default low boundary. 100 | if astcast.ToBasicLit(x.Low).Value == "0" { 101 | x.Low = nil 102 | } 103 | // Omit default high boundary, but only if 3rd index is absent.
104 | if x.Max == nil { 105 | lenCall := astcast.ToCallExpr(x.High) 106 | if astcast.ToIdent(lenCall.Fun).Name == "len" && astequal.Expr(lenCall.Args[0], x.X) { 107 | x.High = nil 108 | } 109 | } 110 | // s[:] => s if s is a slice or a string; slicing an array (or a pointer to one) yields a different type. 111 | if x.Low == nil && x.High == nil && !isArrayOrArrayPtr(n.cfg.Info.TypeOf(x.X)) { 112 | return x.X 113 | } 114 | return x 115 | } 116 | 117 | func (n *normalizer) normalizeBinaryExpr(x *ast.BinaryExpr) ast.Expr { 118 | x.X = n.normalizeExpr(x.X) 119 | x.Y = n.normalizeExpr(x.Y) 120 | 121 | // Drop a literal 0 operand only where 0 is the identity element of x.Op: 122 | // `x == 0` or `x * 0` must not collapse into `x`. TODO(quasilyte): also handle empty strings. 123 | switch { 124 | case isCommutative(n.cfg.Info, x) && isLeftZeroIdentity(x.Op) && astcast.ToBasicLit(x.X).Value == "0": 125 | return x.Y 126 | case isRightZeroIdentity(x.Op) && astcast.ToBasicLit(x.Y).Value == "0": 127 | return x.X 128 | } 129 | 130 | if isCommutative(n.cfg.Info, x) { 131 | lhs := astcast.ToBinaryExpr(x.X) 132 | cv1 := constValueOf(n.cfg.Info, lhs.Y) 133 | cv2 := constValueOf(n.cfg.Info, x.Y) 134 | // Fold `(a op c1) op c2` into `a op (c1 op c2)`: only valid when both ops are the same associative op. 135 | if lhs.Op == x.Op && isFoldableOp(x.Op) && cv1 != nil && cv2 != nil { 136 | cv := constant.BinaryOp(cv1, x.Op, cv2) 137 | x.X = lhs.X 138 | x.Y = constValueNode(cv) 139 | return n.normalizeExpr(x) 140 | } 141 | 142 | // Turn yoda expressions into the more conventional notation. 143 | // Put constant inside the expression after the non-constant part.
144 | if isLiteralConst(n.cfg.Info, x.X) && !isLiteralConst(n.cfg.Info, x.Y) { 145 | x.X, x.Y = x.Y, x.X 146 | } 147 | } 148 | 149 | return x 150 | } 151 | 152 | func (n *normalizer) normalizeExprList(xs []ast.Expr) []ast.Expr { 153 | for i, x := range xs { 154 | xs[i] = n.normalizeExpr(x) 155 | } 156 | return xs 157 | } 158 | 159 | func (n *normalizer) normalizeStmt(x ast.Stmt) ast.Stmt { 160 | switch x := x.(type) { 161 | case *ast.LabeledStmt: 162 | return n.normalizeLabeledStmt(x) 163 | case *ast.ExprStmt: 164 | return n.normalizeExprStmt(x) 165 | case *ast.AssignStmt: 166 | return n.normalizeAssignStmt(x) 167 | case *ast.BlockStmt: 168 | return n.normalizeBlockStmt(x) 169 | case *ast.ReturnStmt: 170 | return n.normalizeReturnStmt(x) 171 | case *ast.DeclStmt: 172 | return n.normalizeDeclStmt(x) 173 | case *ast.ForStmt: 174 | return n.normalizeForStmt(x) 175 | case *ast.RangeStmt: 176 | return n.normalizeRangeStmt(x) 177 | case *ast.IfStmt: 178 | return n.normalizeIfStmt(x) 179 | case *ast.IncDecStmt: 180 | return n.normalizeIncDecStmt(x) 181 | default: 182 | return x 183 | } 184 | } 185 | 186 | func (n *normalizer) normalizeIncDecStmt(stmt *ast.IncDecStmt) *ast.IncDecStmt { 187 | stmt.X = n.normalizeExpr(stmt.X) 188 | return stmt 189 | } 190 | 191 | func (n *normalizer) normalizeReturnStmt(ret *ast.ReturnStmt) *ast.ReturnStmt { 192 | ret.Results = n.normalizeExprList(ret.Results) 193 | return ret 194 | } 195 | 196 | func (n *normalizer) normalizeBlockStmt(b *ast.BlockStmt) *ast.BlockStmt { 197 | list := b.List[:0] 198 | for _, x := range b.List { 199 | // Filter-out const decls. 200 | // We inline const values, so local const defs are 201 | // not needed to keep code valid.
202 | decl, ok := x.(*ast.DeclStmt) 203 | if ok && decl.Decl.(*ast.GenDecl).Tok == token.CONST { 204 | continue 205 | } 206 | list = append(list, n.normalizeStmt(x)) 207 | } 208 | b.List = list 209 | 210 | n.normalizeValSwap(b) 211 | 212 | return b 213 | } 214 | 215 | func (n *normalizer) normalizeDeclStmt(stmt *ast.DeclStmt) ast.Stmt { 216 | decl := stmt.Decl.(*ast.GenDecl) 217 | if decl.Tok != token.VAR { 218 | return stmt 219 | } 220 | if len(decl.Specs) != 1 { 221 | return stmt 222 | } 223 | spec := decl.Specs[0].(*ast.ValueSpec) 224 | if len(spec.Names) != 1 { 225 | return stmt 226 | } 227 | 228 | switch { 229 | case len(spec.Values) == 1: 230 | // var x T = v => x := v. NOTE(review): dropping T changes x's type when v is untyped, e.g. `var x int64 = 1`. 231 | return &ast.AssignStmt{ 232 | Tok: token.DEFINE, 233 | Lhs: []ast.Expr{spec.Names[0]}, 234 | Rhs: []ast.Expr{n.normalizeExpr(spec.Values[0])}, 235 | } 236 | case len(spec.Values) == 0 && spec.Type != nil: 237 | // var x T 238 | zv := zeroValueOf(n.cfg.Info.TypeOf(spec.Type)) 239 | if zv == nil { 240 | return stmt 241 | } 242 | return &ast.AssignStmt{ 243 | Tok: token.DEFINE, 244 | Lhs: []ast.Expr{spec.Names[0]}, 245 | Rhs: []ast.Expr{zv}, 246 | } 247 | default: 248 | return stmt 249 | } 250 | } 251 | 252 | func (n *normalizer) normalizeRangeStmt(loop *ast.RangeStmt) *ast.RangeStmt { 253 | loop.Key = n.normalizeExpr(loop.Key) // Not needed? 254 | loop.Value = n.normalizeExpr(loop.Value) // Not needed? 255 | loop.X = n.normalizeExpr(loop.X) 256 | loop.Body = n.normalizeBlockStmt(loop.Body) 257 | return loop 258 | } 259 | 260 | func (n *normalizer) normalizeForStmt(loop *ast.ForStmt) ast.Stmt { 261 | loop.Init = n.normalizeStmt(loop.Init) 262 | loop.Cond = n.normalizeExpr(loop.Cond) 263 | loop.Post = n.normalizeStmt(loop.Post) 264 | loop.Body = n.normalizeBlockStmt(loop.Body) 265 | 266 | if len(loop.Body.List) == 0 { 267 | return loop // Don't care 268 | } 269 | 270 | // I want AST matchers. 271 | // A lot. 272 | // Why go-toolsmith doesn't have one yet? 273 | // Code below is a mess even with astcast.
274 | 275 | init := astcast.ToAssignStmt(loop.Init) 276 | cond := astcast.ToBinaryExpr(loop.Cond) 277 | post := astcast.ToIncDecStmt(loop.Post) 278 | 279 | if init.Tok != token.DEFINE || len(init.Lhs) != 1 || len(init.Rhs) != 1 { 280 | return loop 281 | } 282 | if astcast.ToBasicLit(init.Rhs[0]).Value != "0" { 283 | return loop 284 | } 285 | iter := astcast.ToIdent(init.Lhs[0]) 286 | 287 | if cond.Op != token.LSS { 288 | return loop 289 | } 290 | if !astequal.Expr(iter, cond.X) { 291 | return loop 292 | } 293 | lenCall := astcast.ToCallExpr(cond.Y) 294 | if astcast.ToIdent(lenCall.Fun).Name != "len" { 295 | return loop 296 | } 297 | slice := lenCall.Args[0] 298 | if !typep.IsSlice(n.cfg.Info.TypeOf(slice)) { 299 | return loop 300 | } 301 | 302 | if post.Tok != token.INC || !astequal.Expr(post.X, iter) { 303 | return loop 304 | } 305 | 306 | // Loop header matched. 307 | // Now need to see whether we need a key and/or val. 308 | var val ast.Expr 309 | keyUsed := false 310 | skip := 0 311 | 312 | assign := astcast.ToAssignStmt(loop.Body.List[0]) 313 | if assign.Tok == token.DEFINE && len(assign.Lhs) == 1 && len(assign.Rhs) == 1 { 314 | indexing := astcast.ToIndexExpr(assign.Rhs[0]) 315 | if astequal.Expr(indexing.X, slice) && astequal.Expr(indexing.Index, iter) { 316 | val = assign.Lhs[0] 317 | skip = 1 318 | } 319 | } 320 | 321 | // Now check that iter is not modified inside loop. 322 | // And since we're lazy and this is only a POC, 323 | // give up on any usage of iter.
324 | for _, stmt := range loop.Body.List[skip:] { 325 | giveUp := false 326 | ast.Inspect(stmt, func(x ast.Node) bool { 327 | switch x := x.(type) { 328 | case *ast.IndexExpr: 329 | index := astcast.ToIdent(x.Index) 330 | if index.Name == iter.Name && astequal.Expr(x.X, slice) { 331 | keyUsed = true 332 | return false 333 | } 334 | return true 335 | case *ast.Ident: 336 | if x.Name == iter.Name { 337 | giveUp = true 338 | } 339 | return true 340 | default: 341 | return true 342 | } 343 | }) 344 | if giveUp { 345 | return loop 346 | } 347 | } 348 | 349 | var key ast.Expr = iter 350 | tok := token.DEFINE 351 | switch { 352 | case !keyUsed && val == nil: 353 | key, tok = nil, token.ILLEGAL // `for _ := range xs` does not compile, so emit `for range xs`. 354 | case !keyUsed: 355 | key = blankIdent 356 | } 357 | loop.Body.List = loop.Body.List[skip:] 358 | return &ast.RangeStmt{ 359 | Key: key, Value: val, Tok: tok, X: slice, Body: loop.Body, 360 | } 361 | } 362 | 363 | func (n *normalizer) normalizeValSwap(b *ast.BlockStmt) { 364 | // tmp := x 365 | // x = y 366 | // y = tmp 367 | // 368 | // => 369 | // 370 | // x, y = y, x 371 | // 372 | // FIXME(quasilyte): if tmp is used somewhere outside of the value swap, 373 | // this transformation is illegal.
374 | 375 | for i := 0; i < len(b.List)-2; i++ { 376 | assignTmp := astcast.ToAssignStmt(b.List[i+0]) 377 | assignX := astcast.ToAssignStmt(b.List[i+1]) 378 | assignY := astcast.ToAssignStmt(b.List[i+2]) 379 | if assignTmp.Tok != token.DEFINE { 380 | continue 381 | } 382 | if assignX.Tok != token.ASSIGN || assignY.Tok != token.ASSIGN { 383 | continue 384 | } 385 | if len(assignTmp.Lhs) != 1 || len(assignX.Lhs) != 1 || len(assignY.Lhs) != 1 { 386 | continue 387 | } 388 | tmp := astcast.ToIdent(assignTmp.Lhs[0]) 389 | x := assignX.Lhs[0] 390 | y := assignY.Lhs[0] 391 | if !astequal.Expr(assignTmp.Rhs[0], x) { 392 | continue 393 | } 394 | if !astequal.Expr(assignX.Rhs[0], y) { 395 | continue 396 | } 397 | if !astequal.Expr(assignY.Rhs[0], tmp) { 398 | continue 399 | } 400 | 401 | b.List[i] = &ast.AssignStmt{ 402 | Tok: token.ASSIGN, 403 | Lhs: []ast.Expr{x, y}, 404 | Rhs: []ast.Expr{y, x}, 405 | } 406 | b.List = append(b.List[:i+1], b.List[i+3:]...) 407 | } 408 | } 409 | 410 | func (n *normalizer) normalizeLabeledStmt(stmt *ast.LabeledStmt) *ast.LabeledStmt { 411 | stmt.Stmt = n.normalizeStmt(stmt.Stmt) 412 | return stmt 413 | } 414 | 415 | func (n *normalizer) normalizeExprStmt(stmt *ast.ExprStmt) *ast.ExprStmt { 416 | stmt.X = n.normalizeExpr(stmt.X) 417 | return stmt 418 | } 419 | 420 | func (n *normalizer) normalizeAssignStmt(assign *ast.AssignStmt) ast.Stmt { 421 | for i, lhs := range assign.Lhs { 422 | assign.Lhs[i] = n.normalizeExpr(lhs) 423 | } 424 | for i, rhs := range assign.Rhs { 425 | assign.Rhs[i] = n.normalizeExpr(rhs) 426 | } 427 | assign = n.normalizeAssignOp(assign) 428 | return assign 429 | } 430 | 431 | func (n *normalizer) normalizeAssignOp(assign *ast.AssignStmt) *ast.AssignStmt { 432 | if assign.Tok != token.ASSIGN || len(assign.Lhs) != 1 { 433 | return assign 434 | } 435 | rhs := astcast.ToBinaryExpr(assign.Rhs[0]) 436 | if !astequal.Expr(assign.Lhs[0], rhs.X) { 437 | return assign 438 | } 439 | op, ok := assignOpTab[rhs.Op] 440 | if ok {
441 | assign.Tok = op 442 | assign.Rhs[0] = rhs.Y 443 | } 444 | return assign 445 | } 446 | 447 | func (n *normalizer) normalizeIfStmt(stmt *ast.IfStmt) *ast.IfStmt { 448 | stmt.Init = n.normalizeStmt(stmt.Init) 449 | stmt.Cond = n.normalizeExpr(stmt.Cond) 450 | stmt.Body = n.normalizeBlockStmt(stmt.Body) 451 | stmt.Else = n.normalizeStmt(stmt.Else) 452 | return stmt 453 | } 454 | 455 | var assignOpTab = map[token.Token]token.Token{ 456 | token.ADD: token.ADD_ASSIGN, 457 | token.SUB: token.SUB_ASSIGN, 458 | token.MUL: token.MUL_ASSIGN, 459 | token.QUO: token.QUO_ASSIGN, 460 | token.REM: token.REM_ASSIGN, 461 | 462 | token.AND: token.AND_ASSIGN, // &= 463 | token.OR: token.OR_ASSIGN, // |= 464 | token.XOR: token.XOR_ASSIGN, // ^= 465 | token.SHL: token.SHL_ASSIGN, // <<= 466 | token.SHR: token.SHR_ASSIGN, // >>= 467 | token.AND_NOT: token.AND_NOT_ASSIGN, // &^= 468 | } 469 | 470 | // isLeftZeroIdentity reports whether `0 op x` always equals x. 471 | func isLeftZeroIdentity(op token.Token) bool { 472 | switch op { 473 | case token.ADD, token.OR, token.XOR: 474 | return true 475 | default: 476 | return false 477 | } 478 | } 479 | 480 | // isRightZeroIdentity reports whether `x op 0` always equals x. 481 | func isRightZeroIdentity(op token.Token) bool { 482 | switch op { 483 | case token.ADD, token.SUB, token.OR, token.XOR, token.SHL, token.SHR, token.AND_NOT: 484 | return true 485 | default: 486 | return false 487 | } 488 | } 489 | 490 | // isFoldableOp reports whether `(a op c1) op c2` may be rewritten as `a op (c1 op c2)`: 491 | // op must be associative and accepted by constant.BinaryOp (comparisons are not). 492 | func isFoldableOp(op token.Token) bool { 493 | switch op { 494 | case token.ADD, token.MUL, token.AND, token.OR, token.XOR, token.LAND, token.LOR: 495 | return true 496 | default: 497 | return false 498 | } 499 | } 500 | 501 | // isArrayOrArrayPtr reports whether typ is an array or a pointer to an array. 502 | // Slicing such a value produces a slice, so a full-range [:] must be kept. 503 | func isArrayOrArrayPtr(typ types.Type) bool { 504 | if typ == nil { 505 | return false 506 | } 507 | if ptr, ok := typ.Underlying().(*types.Pointer); ok { 508 | typ = ptr.Elem() 509 | } 510 | _, ok := typ.Underlying().(*types.Array) 511 | return ok 512 | } 513 | -------------------------------------------------------------------------------- /normalizer_test.go: -------------------------------------------------------------------------------- 1 | package astnorm 2 | 3 | import ( 4 | "go/ast" 5 | "strings" 6 | "testing" 7 | 8 | "github.com/go-toolsmith/astcast" 9 | "github.com/go-toolsmith/astequal" 10 | "github.com/go-toolsmith/astfmt" 11 | "golang.org/x/tools/go/packages" 12 | ) 13 | 14 | func isTestCase(assign *ast.AssignStmt) bool { 15 | return len(assign.Lhs) == 2 && 16 | len(assign.Rhs) == 2 && 17 | astcast.ToIdent(assign.Lhs[0]).Name == "_" && 18 | astcast.ToIdent(assign.Lhs[1]).Name == "_" 19 | } 20 | 21 | func TestNormalizeExpr(t *testing.T) { 22 | pkg := loadPackage(t, "./testdata/normalize_expr.go") 23 | funcs := collectFuncDecls(pkg) 24 | cfg := &Config{Info: pkg.TypesInfo} 25 | 26 | for _, fn := range funcs { 27 | for _, stmt := range fn.Body.List { 28 | assign, ok := stmt.(*ast.AssignStmt) 29 | if !ok || !isTestCase(assign) { 30 | continue 31 | } 32 | input := assign.Rhs[0] 33 | want := assign.Rhs[1] 34 | have := Expr(cfg, input) 35 | if
!astequal.Expr(have, want) { 36 | pos := pkg.Fset.Position(assign.Pos()) 37 | t.Errorf("%s:\nhave: %s\nwant: %s", 38 | pos, astfmt.Sprint(have), astfmt.Sprint(want)) 39 | } 40 | } 41 | } 42 | } 43 | 44 | func TestNormalizeStmt(t *testing.T) { 45 | pkg := loadPackage(t, "./testdata/normalize_stmt.go") 46 | funcs := collectFuncDecls(pkg) 47 | cfg := &Config{Info: pkg.TypesInfo} 48 | 49 | for _, fn := range funcs { 50 | for _, stmt := range fn.Body.List { 51 | assign, ok := stmt.(*ast.AssignStmt) 52 | if !ok || !isTestCase(assign) { 53 | continue 54 | } 55 | input := assign.Rhs[0].(*ast.FuncLit).Body 56 | want := assign.Rhs[1].(*ast.FuncLit).Body 57 | have := Stmt(cfg, input) 58 | if !astequal.Stmt(have, want) { 59 | pos := pkg.Fset.Position(assign.Pos()) 60 | t.Errorf("%s:\nhave: %s\nwant: %s", 61 | pos, astfmt.Sprint(have), astfmt.Sprint(want)) 62 | } 63 | } 64 | } 65 | } 66 | 67 | func collectFuncDecls(pkg *packages.Package) []*ast.FuncDecl { 68 | var funcs []*ast.FuncDecl 69 | for _, f := range pkg.Syntax { 70 | for _, decl := range f.Decls { 71 | decl, ok := decl.(*ast.FuncDecl) 72 | if !ok || decl.Body == nil { 73 | continue 74 | } 75 | if !strings.HasSuffix(decl.Name.Name, "Test") { 76 | continue 77 | } 78 | funcs = append(funcs, decl) 79 | } 80 | } 81 | return funcs 82 | } 83 | 84 | // loadPackage loads the package at path and fails the test unless exactly one package loads without errors. 85 | func loadPackage(t *testing.T, path string) *packages.Package { 86 | t.Helper() 87 | cfg := &packages.Config{Mode: packages.LoadSyntax} 88 | pkgs, err := packages.Load(cfg, path) 89 | if err != nil { 90 | t.Fatalf("load %q: %v", path, err) 91 | } 92 | if packages.PrintErrors(pkgs) > 0 { 93 | t.Fatalf("package %q loaded with errors", path) 94 | } 95 | if len(pkgs) != 1 { 96 | t.Fatalf("expected 1 package from %q path, got %d", 97 | path, len(pkgs)) 98 | } 99 | return pkgs[0] 100 | } 101 | -------------------------------------------------------------------------------- /testdata/normalize_expr.go: -------------------------------------------------------------------------------- 1 | package
normalize_expr 2 | 3 | func addInts(x, y int) int { return x + y } 4 | 5 | func identityTest() { 6 | var x int 7 | type T int 8 | 9 | _, _ = x, x 10 | _, _ = 102, 102 11 | _, _ = x+1, x+1 12 | _, _ = 0-x, 0-x 13 | _, _ = 1.1, 1.1 14 | _, _ = 12412.312, 12412.312 15 | } 16 | 17 | func stringLiteralsTest() { 18 | _, _ = ``, "" 19 | _, _ = `\\`, "\\\\" 20 | _, _ = `\d+`, "\\d+" 21 | _, _ = `123`, "123" 22 | _, _ = "\n"+``+"\n", "\n\n" 23 | } 24 | 25 | func defaultSlicingBoundsTest() { 26 | var xs []int 27 | var s string 28 | var a [3]int 29 | 30 | _, _ = xs[0:], xs 31 | _, _ = (xs)[(0+0):], xs 32 | _, _ = xs[0:len(xs)], xs 33 | _, _ = (xs)[0:(len(xs))], xs 34 | _, _ = xs[:0:0], xs[:0:0] 35 | 36 | _, _ = s[0:len(s)], s 37 | _, _ = s[1:], s[1:] 38 | 39 | _, _ = a[:], a[:] 40 | } 41 | 42 | func literalsTest() { 43 | // Convert any int numerical base into 10. 44 | _, _ = 0x0, 0 45 | _, _ = 0x1, 1 46 | _, _ = 04, 4 47 | _, _ = 010, 8 48 | 49 | // Represent floats in a consistent way. 50 | _, _ = 1.0, 1.0 51 | _, _ = 5.0, 5.0 52 | _, _ = 0.0, 0.0 53 | _, _ = .0, 0.0 54 | _, _ = 0., 0.0 55 | _, _ = 0.1e4, 1000.0 56 | _, _ = 00.0, 0.0 57 | } 58 | 59 | func conversionTest() { 60 | var x int 61 | 62 | // These already have proper type even without conversion. 63 | _, _ = int(1), 1 64 | _, _ = float64(40.1), 40.1 65 | _, _ = int(x), x 66 | _, _ = int(x+1), x+1 67 | 68 | // Repetitive conversions. 69 | _, _ = int(int(int(1))), 1 70 | 71 | // These require conversion. 72 | _, _ = int32(x), int32(x) 73 | 74 | // Preserve type conversion for untyped literals.
75 | _, _ = int8(1), int8(1) 76 | _, _ = int8(int16(1)), int8(1) 77 | _, _ = int8(int16(int32(1))), int8(1) 78 | _, _ = int8(int16(int32(int64(1)))), int8(1) 79 | _, _ = int16(int8(int16(int32(int64(1+1+1))))), int16(3) 80 | _, _ = int32(int16(int8(int16(int32(int64(1+1)))))), int32(2) 81 | } 82 | 83 | func yodaTest() { 84 | var x int 85 | var s string 86 | var m map[int]int 87 | 88 | _, _ = 1+x, x+1 89 | _, _ = (nil != m), m != nil 90 | 91 | // Concat is not commutative. 92 | _, _ = "prefix"+s, "prefix"+s 93 | // Other non-commutative ops. 94 | _, _ = 1-x, 1-x 95 | _, _ = 1000/x, 1000/x 96 | } 97 | 98 | func foldBoolTest() { 99 | _, _ = false && false, false 100 | _, _ = false || false, false 101 | _, _ = true && true, true 102 | _, _ = true || false, true 103 | _, _ = false || false || true, true 104 | 105 | _, _ = 1 != 1, false 106 | _, _ = 1 == 1, true 107 | 108 | var x float64 109 | // This is NaN test. 110 | // Not something that should be replaced. 111 | _, _ = x != x, x != x 112 | } 113 | 114 | func foldArithTest() { 115 | var x int 116 | 117 | // Const-only expressions are folded entirely. 118 | _, _ = 1+2+3, 6 119 | _, _ = 6-2, 4 120 | 121 | // Zeroes can be removed completely as well. 122 | _, _ = x+0, x 123 | _, _ = x+(0)+0, x 124 | _, _ = 0+x, x 125 | _, _ = 0+0+x, x 126 | _, _ = 0+x+(0), x 127 | _, _ = (0+0)+x+0, x 128 | _, _ = 0+x+0+0, x 129 | _, _ = x-0-0, x 130 | 131 | // For commutative ops fold it into a single op.
132 | _, _ = x+1, x+1 133 | _, _ = x+1+1, x+2 134 | _, _ = 1+x+1, x+2 135 | _, _ = 1+2+x+2+1, x+6 136 | _, _ = (1+2)+x+2+1, x+6 137 | _, _ = ((1 + (2)) + (x + 2) + 1), x+6 138 | _, _ = 0.2+0.1, 0.3 139 | 140 | _, _ = "a"+"b"+"c", "abc" 141 | } 142 | 143 | func parenthesisRemovalTest() { 144 | var x int 145 | type T int 146 | 147 | _, _ = (x), x 148 | _, _ = ((*T)(&x)), (*T)(&x) 149 | _, _ = (addInts)(1, 2), addInts(1, 2) 150 | _, _ = addInts((1), (2)), addInts(1, 2) 151 | } 152 | -------------------------------------------------------------------------------- /testdata/normalize_stmt.go: -------------------------------------------------------------------------------- 1 | package normalize_stmt 2 | 3 | func addInts(x, y int) int { return x + y } 4 | 5 | func identityTest() { 6 | var x int 7 | 8 | _, _ = func() { 9 | x += 1 10 | x -= 1 11 | }, func() { 12 | x += 1 13 | x -= 1 14 | } 15 | } 16 | 17 | func incdecStmtTest() { 18 | var x int 19 | 20 | _, _ = func() { 21 | x++ 22 | (x)++ 23 | }, func() { 24 | x++ 25 | x++ 26 | } 27 | } 28 | 29 | func rangeStmtTest() { 30 | var xs []int 31 | 32 | _, _ = func() { 33 | for i := range xs[0:len(xs)] { 34 | _ = (i) 35 | } 36 | }, func() { 37 | for i := range xs { 38 | _ = i 39 | } 40 | } 41 | } 42 | 43 | func assignOpTest() { 44 | var x int 45 | 46 | _, _ = func() { 47 | x = x + 5 48 | x = x - 2 49 | x = x * 4 50 | }, func() { 51 | x += 5 52 | x -= 2 53 | x *= 4 54 | } 55 | } 56 | 57 | func valueSwapTest() { 58 | var x, y int 59 | 60 | _, _ = func() { 61 | tmp := (x) 62 | x = y 63 | y = tmp 64 | }, func() { 65 | x, y = y, x 66 | } 67 | 68 | _, _ = func() { 69 | tmp1 := x 70 | x = y 71 | y = tmp1 72 | 73 | tmp2 := y 74 | y = x 75 | x = tmp2 76 | }, func() { 77 | x, y = y, x 78 | y, x = x, y 79 | } 80 | 81 | } 82 | 83 | func removeConstDeclsTest() { 84 | _, _ = func() { 85 | const n = 10 86 | _ = n + n 87 | }, func() { 88 | _ = 20 89 | } 90 | 91 | _, _ = func() { 92 | const n = 10 93 | x := 10 94 | _ = x != n+1 95 | }, func() { 
96 | x := 10 97 | _ = x != 11 98 | } 99 | } 100 | 101 | func rewriteVarSpecTest() { 102 | _, _ = func() { 103 | var x = 10 104 | var y float32 = float32(x) 105 | _ = x 106 | _ = y 107 | }, func() { 108 | x := 10 109 | y := float32(x) 110 | _ = x 111 | _ = y 112 | } 113 | 114 | _, _ = func() { 115 | var x int 116 | _ = x 117 | }, func() { 118 | x := 0 119 | _ = x 120 | } 121 | 122 | _, _ = func() { 123 | var xs [][]int 124 | var s string 125 | _ = xs 126 | _ = s 127 | }, func() { 128 | xs := [][]int(nil) 129 | s := "" 130 | _ = xs 131 | _ = s 132 | } 133 | 134 | _, _ = func() { 135 | var xs [8]string 136 | _ = xs 137 | }, func() { 138 | xs := [8]string{} 139 | _ = xs 140 | } 141 | 142 | _, _ = func() (float64, float32) { 143 | var x float64 144 | var y float32 145 | return x, y 146 | }, func() (float64, float32) { 147 | x := 0.0 148 | y := float32(0.0) 149 | return x, y 150 | } 151 | } 152 | 153 | func rangeLoopTest() { 154 | _, _ = func() { 155 | var xs []int 156 | for i := 0; i < len(xs); i++ { 157 | x := xs[i] 158 | _ = x 159 | } 160 | 161 | // Uses i+1 index. 162 | for i := 0; i < len(xs); i++ { 163 | x := xs[i+1] 164 | _ = x 165 | } 166 | 167 | // Doesn't assign elem. 168 | 169 | for i := 0; i < len(xs); i++ { 170 | _ = i 171 | } 172 | 173 | // TODO(quasilyte): more negative tests. 174 | // (Hint: use coverage to guide you, Luke!) 
175 | }, func() { 176 | xs := []int(nil) 177 | for _, x := range xs { 178 | _ = x 179 | } 180 | 181 | for i := 0; i < len(xs); i++ { 182 | x := xs[i+1] 183 | _ = x 184 | } 185 | 186 | for i := 0; i < len(xs); i++ { 187 | _ = i 188 | } 189 | } 190 | 191 | _, _ = func() { 192 | var xs []int 193 | const toRemove = 10 194 | var filtered []int 195 | filtered = xs[0:0] 196 | for i := int(0); i < len(xs); i++ { 197 | x := xs[i] 198 | if toRemove+1 != x { 199 | filtered = append(filtered, x) 200 | } 201 | } 202 | _ = (filtered) 203 | }, func() { 204 | xs := []int(nil) 205 | filtered := []int(nil) 206 | filtered = xs[:0] 207 | for _, x := range xs { 208 | if x != 11 { 209 | filtered = append(filtered, x) 210 | } 211 | } 212 | _ = filtered 213 | } 214 | 215 | _, _ = func(xs []int) { 216 | for i := 0; i < len(xs); i++ { 217 | _ = xs[i] 218 | } 219 | }, func(xs []int) { 220 | for i := range xs { 221 | _ = xs[i] 222 | } 223 | } 224 | 225 | _, _ = func(xs []int) { 226 | for i := 0; i < len(xs); i++ { 227 | _ = xs[i+1] 228 | } 229 | }, func(xs []int) { 230 | for i := 0; i < len(xs); i++ { 231 | _ = xs[i+1] 232 | } 233 | } 234 | 235 | _, _ = func(xs []int) { 236 | for i := 0; i < len(xs); i++ { 237 | v := xs[i] 238 | _ = v 239 | _ = xs[i] 240 | } 241 | }, func(xs []int) { 242 | for i, v := range xs { 243 | _ = v 244 | _ = xs[i] 245 | } 246 | } 247 | 248 | _, _ = func(xs []int) { 249 | for i := 0; i < len(xs); i++ { 250 | v := xs[i] 251 | _ = v 252 | _ = xs[i] 253 | i++ 254 | } 255 | }, func(xs []int) { 256 | for i := 0; i < len(xs); i++ { 257 | v := xs[i] 258 | _ = v 259 | _ = xs[i] 260 | i++ 261 | } 262 | } 263 | } 264 | 265 | func exprStmtTest() { 266 | _, _ = func() { 267 | addInts((1), 0+0+0) 268 | }, func() { 269 | addInts(1, 0) 270 | } 271 | } 272 | 273 | func labeledStmtTest() { 274 | _, _ = func() int { 275 | goto L 276 | L: 277 | return (0) 278 | }, func() int { 279 | goto L 280 | L: 281 | return 0 282 | } 283 | } 284 | 285 | func combinedTest() { 286 | var x int 287 | 
288 | _, _ = func() { 289 | x = x + (2) 290 | }, func() { 291 | x += 2 292 | } 293 | } 294 | -------------------------------------------------------------------------------- /utils.go: -------------------------------------------------------------------------------- 1 | package astnorm 2 | 3 | import ( 4 | "fmt" 5 | "go/ast" 6 | "go/constant" 7 | "go/token" 8 | "go/types" 9 | "strconv" 10 | "strings" 11 | 12 | "github.com/go-toolsmith/astcast" 13 | "github.com/go-toolsmith/strparse" 14 | "github.com/go-toolsmith/typep" 15 | "golang.org/x/tools/go/ast/astutil" 16 | ) 17 | 18 | var blankIdent = &ast.Ident{Name: "_"} 19 | 20 | func isLiteralConst(info *types.Info, x ast.Expr) bool { 21 | switch x := x.(type) { 22 | case *ast.Ident: 23 | // Not really literal consts, but they are 24 | // considered as such by many programmers. 25 | switch x.Name { 26 | case "nil", "true", "false": 27 | return true 28 | } 29 | return false 30 | case *ast.BasicLit: 31 | return true 32 | default: 33 | return false 34 | } 35 | } 36 | 37 | func isCommutative(info *types.Info, x *ast.BinaryExpr) bool { 38 | // TODO(quasilyte): make this list more or less complete. 39 | switch x.Op { 40 | case token.ADD: 41 | return !typep.HasStringProp(info.TypeOf(x)) 42 | case token.EQL, token.NEQ: 43 | return true 44 | default: 45 | return false 46 | } 47 | } 48 | 49 | func constValueNode(cv constant.Value) ast.Expr { 50 | switch cv.Kind() { 51 | case constant.Bool: 52 | if constant.BoolVal(cv) { 53 | return &ast.Ident{Name: "true"} 54 | } 55 | return &ast.Ident{Name: "false"} 56 | 57 | case constant.String: 58 | v := constant.StringVal(cv) 59 | return &ast.BasicLit{ 60 | Kind: token.STRING, 61 | Value: strconv.Quote(v), 62 | } 63 | 64 | case constant.Int: 65 | // For whatever reason, constant.Int can also 66 | // mean "float with 0 fractional part". 
67 | v, exact := constant.Int64Val(cv) 68 | if !exact { 69 | return nil 70 | } 71 | return &ast.BasicLit{ 72 | Kind: token.INT, 73 | Value: fmt.Sprint(v), 74 | } 75 | 76 | case constant.Float: 77 | v, exact := constant.Float64Val(cv) 78 | if !exact { 79 | return nil 80 | } 81 | s := fmt.Sprint(v) 82 | if !strings.Contains(s, ".") { 83 | s += ".0" 84 | } 85 | return &ast.BasicLit{ 86 | Kind: token.FLOAT, 87 | Value: s, 88 | } 89 | 90 | default: 91 | panic("unimplemented") 92 | } 93 | } 94 | 95 | func constValueOf(info *types.Info, x ast.Expr) constant.Value { 96 | if cv := info.Types[x].Value; cv != nil { 97 | return cv 98 | } 99 | lit := astcast.ToBasicLit(x) 100 | switch lit.Kind { 101 | case token.INT: 102 | v, err := strconv.ParseInt(lit.Value, 10, 64) 103 | if err != nil { 104 | return nil 105 | } 106 | return constant.MakeInt64(v) 107 | default: 108 | return nil 109 | } 110 | } 111 | 112 | func zeroValueOf(typ types.Type) ast.Expr { 113 | switch typ := typ.(type) { 114 | case *types.Basic: 115 | info := typ.Info() 116 | var zv ast.Expr 117 | switch { 118 | case info&types.IsInteger != 0: 119 | zv = &ast.BasicLit{Kind: token.INT, Value: "0"} 120 | case info&types.IsFloat != 0: 121 | zv = &ast.BasicLit{Kind: token.FLOAT, Value: "0.0"} 122 | case info&types.IsString != 0: 123 | zv = &ast.BasicLit{Kind: token.STRING, Value: `""`} 124 | } 125 | if isDefaultLiteralType(typ) { 126 | return zv 127 | } 128 | return &ast.CallExpr{ 129 | Fun: typeToExpr(typ), 130 | Args: []ast.Expr{zv}, 131 | } 132 | case *types.Slice: 133 | return &ast.CallExpr{ 134 | Fun: typeToExpr(typ), 135 | Args: []ast.Expr{&ast.Ident{Name: "nil"}}, 136 | } 137 | case *types.Array: 138 | return &ast.CompositeLit{Type: typeToExpr(typ)} 139 | } 140 | return nil 141 | } 142 | 143 | func typeToExpr(typ types.Type) ast.Expr { 144 | // This is a very dirty and inefficient way, 145 | // but it's at the very same time so simple and tempting. 
146 | return strparse.Expr(typ.String()) 147 | } 148 | 149 | func findNode(root ast.Node, pred func(ast.Node) bool) ast.Node { 150 | var found ast.Node 151 | astutil.Apply(root, nil, func(cur *astutil.Cursor) bool { 152 | if pred(cur.Node()) { 153 | found = cur.Node() 154 | return false 155 | } 156 | return true 157 | }) 158 | return found 159 | } 160 | 161 | func containsNode(root ast.Node, pred func(ast.Node) bool) bool { 162 | return findNode(root, pred) != nil 163 | } 164 | 165 | func isDefaultLiteralType(typ types.Type) bool { 166 | btyp, ok := typ.(*types.Basic) 167 | if !ok { 168 | return false 169 | } 170 | switch btyp.Kind() { 171 | case types.Bool, types.Int, types.Float64, types.String: 172 | return true 173 | default: 174 | return false 175 | } 176 | } 177 | --------------------------------------------------------------------------------