├── .travis.yml
├── LICENSE
├── README.md
├── astnorm.go
├── cmd
├── go-normalize
│ └── main.go
├── grepfunc
│ └── main.go
└── internal
│ └── loadfile
│ └── loadfile.go
├── example
├── demo.bash
├── filter.go
├── mylib1
│ └── mylib1.go
├── mylib1n
│ └── mylib1.go
├── mylib2
│ └── mylib2.go
└── mylib2n
│ └── mylib2.go
├── logo.jpg
├── normalizer.go
├── normalizer_test.go
├── testdata
├── normalize_expr.go
└── normalize_stmt.go
└── utils.go
/.travis.yml:
--------------------------------------------------------------------------------
1 | language: go
2 | go:
3 | - 1.x
4 | install:
5 | - # Prevent default install action "go get -t -v ./...".
6 | script:
7 | - go get -t -v ./...
8 | - go vet .
9 | - go test -v -race ./...
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2019 Iskander (Alex) Sharipov / Quasilyte
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
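1 | [![Go Report Card](https://goreportcard.com/badge/github.com/Quasilyte/astnorm)](https://goreportcard.com/report/github.com/Quasilyte/astnorm)
2 | [![GoDoc](https://godoc.org/github.com/Quasilyte/astnorm?status.svg)](https://godoc.org/github.com/Quasilyte/astnorm)
3 | [![Build Status](https://travis-ci.org/Quasilyte/astnorm.svg?branch=master)](https://travis-ci.org/Quasilyte/astnorm)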
4 |
5 | 
6 |
7 | # astnorm
8 |
9 | Go AST normalization experiment.
10 |
11 | > THIS IS NOT A PROPER LIBRARY (yet?).
12 | > DO NOT USE.
13 | > It will probably be completely re-written before it becomes usable.
14 |
15 | ## Normalized code examples
16 |
17 | 1. Swap values.
18 |
19 |
20 |
21 | Before |
22 | After |
23 |
24 |
25 |
26 |
27 | ```go
28 | tmp := xs[i]
29 | xs[i] = ys[i]
30 | ys[i] = tmp
31 | ```
32 |
33 | |
34 |
35 | ```go
36 | xs[i], ys[i] = ys[i], xs[i]
37 | ```
38 |
39 | |
40 |
41 |
42 | 2. Remove elements that are equal to `toRemove+1`.
43 |
44 |
45 |
46 | Before |
47 | After |
48 |
49 |
50 |
51 |
52 | ```go
53 | const toRemove = 10
54 | var filtered []int
55 | filtered = xs[0:0]
56 | for i := int(0); i < len(xs); i++ {
57 | x := xs[i]
58 | if toRemove+1 != x {
59 | filtered = append(filtered, x)
60 | }
61 | }
62 | return (filtered)
63 | ```
64 |
65 | |
66 |
67 | ```go
68 | filtered := []int(nil)
69 | filtered = xs[:0]
70 | for _, x := range xs {
71 | if x != 11 {
72 | filtered = append(filtered, x)
73 | }
74 | }
75 | return filtered
76 | ```
77 |
78 | |
79 |
80 |
81 | ## Usage examples
82 |
83 | * [cmd/go-normalize](/cmd/go-normalize): normalize given Go file
84 | * [cmd/grepfunc](/cmd/grepfunc): turn Go code into a pattern for `gogrep` and run it
85 |
86 | Potential workflow for code searching:
87 |
88 | ### 1. Code search
89 |
90 | * Normalize the entire Go stdlib
91 | * Then normalize your function
92 | * Run `grepfunc` against normalized stdlib
93 | * If the function you implemented already has an implementation in the stdlib, you'll probably find it
94 |
95 | Basically, instead of stdlib you can use any kind of Go corpus.
96 |
97 | Other code-search-related tasks that `astnorm` can simplify are code similarity
98 | evaluation and code duplication detection of any kind.
99 |
100 | ### 2. Static analysis
101 |
102 | Suppose we have `badcode.go` file:
103 |
104 | ```go
105 | package badpkg
106 |
107 | func NotEqual(x1, x2 int) bool {
108 | return (x1) != x1
109 | }
110 | ```
111 |
112 | There is an obvious mistake there: `x1` is used twice. But because of the extra parentheses, linters may not detect this issue:
113 |
114 | ```bash
115 | $ staticcheck badcode.go
116 | # No output
117 | ```
118 |
119 | Let's normalize the input first and then run `staticcheck`:
120 |
121 | ```bash
122 | go-normalize badcode.go > normalized_badcode.go
123 | staticcheck normalized_badcode.go
124 | normalized_badcode.go:4:9: identical expressions on the left and right side of the '!=' operator (SA4000)
125 | ```
126 |
127 | And we get the warning we deserve!
128 | No changes to `staticcheck` or any other linter are required.
129 |
130 | See also: [demo script](/example/demo.bash).
131 |
--------------------------------------------------------------------------------
/astnorm.go:
--------------------------------------------------------------------------------
1 | // Package astnorm implements AST normalization routines.
2 | package astnorm
3 |
4 | import (
5 | "go/ast"
6 | "go/types"
7 | )
8 |
// Config bundles what a normalization pass needs to do its job:
// the type information for the AST being normalized, alongside any
// optional settings that tune individual aspects of the process.
type Config struct {
	// Info holds the type-checker results for the nodes to normalize.
	Info *types.Info
}
15 |
16 | // Expr returns normalized expression x.
17 | // x may be mutated.
18 | func Expr(cfg *Config, x ast.Expr) ast.Expr {
19 | return newNormalizer(cfg).normalizeExpr(x)
20 | }
21 |
22 | // Stmt returns normalized statement x.
23 | // x may be mutated.
24 | func Stmt(cfg *Config, x ast.Stmt) ast.Stmt {
25 | return newNormalizer(cfg).normalizeStmt(x)
26 | }
27 |
28 | // Block returns normalized block x.
29 | // x may be mutated.
30 | func Block(cfg *Config, x *ast.BlockStmt) *ast.BlockStmt {
31 | return newNormalizer(cfg).normalizeBlockStmt(x)
32 | }
33 |
--------------------------------------------------------------------------------
/cmd/go-normalize/main.go:
--------------------------------------------------------------------------------
1 | package main
2 |
3 | import (
4 | "flag"
5 | "go/ast"
6 | "go/printer"
7 | "go/token"
8 | "log"
9 | "os"
10 |
11 | "github.com/Quasilyte/astnorm"
12 | "github.com/Quasilyte/astnorm/cmd/internal/loadfile"
13 | )
14 |
15 | func main() {
16 | log.SetFlags(0)
17 |
18 | flag.Parse()
19 |
20 | targets := flag.Args()
21 | if len(targets) != 1 {
22 | log.Panicf("expected exactly 1 positional argument (input go file)")
23 | }
24 |
25 | // For now, handle only 1 input file case.
26 | // For simplicity reasons.
27 | f, info, err := loadfile.ByPath(targets[0])
28 | if err != nil {
29 | log.Panicf("loadfile: %v", err)
30 | }
31 | normalizationConfig := &astnorm.Config{
32 | Info: info,
33 | }
34 | f = normalizeFile(normalizationConfig, f)
35 | fset := token.NewFileSet()
36 | if err := printer.Fprint(os.Stdout, fset, f); err != nil {
37 | log.Panicf("print normalized file: %v", err)
38 | }
39 | }
40 |
41 | func normalizeFile(cfg *astnorm.Config, f *ast.File) *ast.File {
42 | // Strip comments.
43 | f.Doc = nil
44 | ast.Inspect(f, func(n ast.Node) bool {
45 | switch n := n.(type) {
46 | case *ast.FuncDecl:
47 | n.Doc = nil
48 | case *ast.GenDecl:
49 | n.Doc = nil
50 | case *ast.Field:
51 | n.Doc = nil
52 | case *ast.ImportSpec:
53 | n.Doc = nil
54 | case *ast.ValueSpec:
55 | n.Doc = nil
56 | case *ast.TypeSpec:
57 | n.Doc = nil
58 | default:
59 | }
60 | return true
61 | })
62 | f.Comments = nil
63 |
64 | for _, decl := range f.Decls {
65 | // TODO(quasilyte): could also normalize global vars,
66 | // consts and type defs, but funcs are OK for the POC.
67 | switch decl := decl.(type) {
68 | case *ast.FuncDecl:
69 | decl.Body = astnorm.Block(cfg, decl.Body)
70 | }
71 | }
72 | return f
73 | }
74 |
--------------------------------------------------------------------------------
/cmd/grepfunc/main.go:
--------------------------------------------------------------------------------
1 | package main
2 |
3 | import (
4 | "flag"
5 | "fmt"
6 | "go/ast"
7 | "go/types"
8 | "log"
9 | "os/exec"
10 | "strings"
11 |
12 | "github.com/Quasilyte/astnorm/cmd/internal/loadfile"
13 | "github.com/go-toolsmith/astfmt"
14 | "github.com/go-toolsmith/typep"
15 | )
16 |
17 | func main() {
18 | log.SetFlags(0)
19 |
20 | input := flag.String("input", "",
21 | `input Go file with pattern function`)
22 | pattern := flag.String("pattern", "_pattern",
23 | `function to be interpreted as a pattern`)
24 | verbose := flag.Bool("v", false,
25 | `turn on debug output`)
26 | flag.Parse()
27 |
28 | if *input == "" {
29 | log.Panic("-input argument can't be empty")
30 | }
31 | targets := flag.Args()
32 |
33 | f, info, err := loadfile.ByPath(*input)
34 | if err != nil {
35 | log.Panicf("loadfile: %v", err)
36 | }
37 |
38 | var fndecl *ast.FuncDecl
39 | for _, decl := range f.Decls {
40 | decl, ok := decl.(*ast.FuncDecl)
41 | if ok && decl.Name.Name == *pattern {
42 | fndecl = decl
43 | break
44 | }
45 | }
46 | if fndecl == nil {
47 | log.Panicf("found no `%s` func in %q", *pattern, targets[0])
48 | }
49 | if fndecl.Body == nil {
50 | log.Panic("external funcs are not supported")
51 | }
52 | pat := makeGogrepPattern(info, fndecl.Body)
53 | s := astfmt.Sprint(pat)
54 | s = strings.TrimPrefix(s, "{")
55 | s = strings.TrimSuffix(s, "}")
56 |
57 | if *verbose {
58 | fmt.Println(s)
59 | }
60 |
61 | gogrepArgs := []string{"-x", s}
62 | gogrepArgs = append(gogrepArgs, targets...)
63 | out, err := exec.Command("gogrep", gogrepArgs...).CombinedOutput()
64 | if err != nil {
65 | log.Panicf("run gogrep: %v: %s", err, out)
66 | }
67 | fmt.Print(string(out))
68 | }
69 |
70 | type visitor struct {
71 | info *types.Info
72 | }
73 |
74 | func (v *visitor) visitNode(x ast.Node) bool {
75 | switch x := x.(type) {
76 | case *ast.Ident:
77 | // Don't replace type names.
78 | if typep.IsTypeExpr(v.info, x) {
79 | return true
80 | }
81 | x.Name = "$" + x.Name
82 | return true
83 | case *ast.CallExpr:
84 | // Don't want to replace function names.
85 | for _, arg := range x.Args {
86 | ast.Inspect(arg, v.visitNode)
87 | }
88 | return false
89 | default:
90 | return true
91 | }
92 | }
93 |
94 | func makeGogrepPattern(info *types.Info, body *ast.BlockStmt) *ast.BlockStmt {
95 | v := &visitor{info: info}
96 | ast.Inspect(body, v.visitNode)
97 | return body
98 | }
99 |
--------------------------------------------------------------------------------
/cmd/internal/loadfile/loadfile.go:
--------------------------------------------------------------------------------
1 | package loadfile
2 |
3 | import (
4 | "fmt"
5 | "go/ast"
6 | "go/types"
7 |
8 | "golang.org/x/tools/go/packages"
9 | )
10 |
11 | func ByPath(path string) (*ast.File, *types.Info, error) {
12 | cfg := &packages.Config{Mode: packages.LoadSyntax}
13 | pkgs, err := packages.Load(cfg, path)
14 | if err != nil {
15 | return nil, nil, fmt.Errorf("load: %v", err)
16 | }
17 | if errCount := packages.PrintErrors(pkgs); errCount != 0 {
18 | return nil, nil, fmt.Errorf("%d errors during package loading", errCount)
19 | }
20 | if len(pkgs) != 1 {
21 | return nil, nil, fmt.Errorf("loaded %d packages, expected only 1", len(pkgs))
22 | }
23 | pkg := pkgs[0]
24 | if len(pkg.Syntax) != 1 {
25 | err := fmt.Errorf("loaded package has %d files, expected only 1",
26 | len(pkg.Syntax))
27 | return nil, nil, err
28 | }
29 | return pkg.Syntax[0], pkg.TypesInfo, nil
30 | }
31 |
--------------------------------------------------------------------------------
/example/demo.bash:
--------------------------------------------------------------------------------
#!/usr/bin/env bash

# Both packages have a func that implements strings.Repeat.
# Suppose we found one such place in our codebase and replaced calls to it
# with strings.Repeat. But what if there are more such funcs?

cat mylib1/mylib1.go
go-normalize mylib1/mylib1.go > mylib1n/mylib1.go
diff mylib1/mylib1.go mylib1n/mylib1.go

cat mylib2/mylib2.go
go-normalize mylib2/mylib2.go > mylib2n/mylib2.go
diff mylib2/mylib2.go mylib2n/mylib2.go

# Use grepfunc to create a pattern from Go code.
# With syntax patterns, we can now ignore differences
# in variable names and find both functions
# by either of them.
grepfunc -input mylib1n/mylib1.go -pattern=makeString ./...
grepfunc -v -input mylib2n/mylib2.go -pattern=repeatString ./...

# For more examples, consult the README file.
23 |
--------------------------------------------------------------------------------
/example/filter.go:
--------------------------------------------------------------------------------
1 | package example
2 |
// This function is intentionally written in a non-normalized style
// (explicit var declaration, redundant conversion and parentheses,
// index-based loop): it is the "before" input of the filtering example
// shown in the README.
func _(xs []int) []int {
	const toRemove = 10
	var filtered []int
	filtered = xs[0:0]
	for i := int(0); i < len(xs); i++ {
		x := xs[i]
		if toRemove+1 != x {
			filtered = append(filtered, x)
		}
	}
	return (filtered)
}
15 |
--------------------------------------------------------------------------------
/example/mylib1/mylib1.go:
--------------------------------------------------------------------------------
1 | package mylib1
2 |
3 | import "strings"
4 |
// makeString returns str repeated num times.
// It is demo input for go-normalize (see the demo script), so its
// un-normalized form is deliberate.
func makeString(str string, num int) string {
	var parts = make([]string, num)
	for i := 0; i < len(parts); i++ {
		parts[i] = str
	}
	return strings.Join(parts, "")
}
12 |
--------------------------------------------------------------------------------
/example/mylib1n/mylib1.go:
--------------------------------------------------------------------------------
1 | package mylib1
2 |
3 | import "strings"
4 |
// makeString returns str repeated num times.
// This file is the go-normalize output for mylib1/mylib1.go; the demo
// script regenerates it.
func makeString(str string, num int) string {
	parts := make([]string, num)
	for i := range parts {
		parts[i] = str
	}
	return strings.Join(parts, "")
}
12 |
--------------------------------------------------------------------------------
/example/mylib2/mylib2.go:
--------------------------------------------------------------------------------
1 | package mylib2
2 |
3 | import "strings"
4 |
// repeatString returns s repeated n times.
// It is demo input for go-normalize (see the demo script), so its
// un-normalized form is deliberate.
func repeatString(s string, n int) string {
	var pieces = make([]string, n)
	for i := range pieces {
		pieces[i] = s
	}
	const sep = ""
	return strings.Join(pieces, sep)
}
13 |
--------------------------------------------------------------------------------
/example/mylib2n/mylib2.go:
--------------------------------------------------------------------------------
1 | package mylib2
2 |
3 | import "strings"
4 |
// repeatString returns s repeated n times.
// This file is the go-normalize output for mylib2/mylib2.go; the demo
// script regenerates it.
func repeatString(s string, n int) string {
	pieces := make([]string, n)
	for i := range pieces {
		pieces[i] = s
	}
	return strings.Join(pieces, "")
}
12 |
--------------------------------------------------------------------------------
/logo.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/quasilyte/astnorm/544300fc3e48de9779bafaf42eb34bff76fde35e/logo.jpg
--------------------------------------------------------------------------------
/normalizer.go:
--------------------------------------------------------------------------------
1 | package astnorm
2 |
3 | import (
4 | "fmt"
5 | "go/ast"
6 | "go/constant"
7 | "go/token"
8 | "go/types"
9 |
10 | "github.com/go-toolsmith/astcast"
11 | "github.com/go-toolsmith/astequal"
12 | "github.com/go-toolsmith/astp"
13 | "github.com/go-toolsmith/typep"
14 | )
15 |
16 | type normalizer struct {
17 | cfg *Config
18 | }
19 |
20 | func newNormalizer(cfg *Config) *normalizer {
21 | return &normalizer{
22 | cfg: cfg,
23 | }
24 | }
25 |
26 | func (n *normalizer) foldConstexpr(x ast.Expr) ast.Expr {
27 | if astp.IsParenExpr(x) {
28 | return nil
29 | }
30 | tv := n.cfg.Info.Types[x]
31 | if tv.Value == nil {
32 | return nil
33 | }
34 |
35 | if lit, ok := x.(*ast.BasicLit); ok && lit.Kind == token.FLOAT {
36 | // Floats may require additional handling here.
37 | if tv.Value.Kind() == constant.Int {
38 | // Usually, this case means that value is 0.
39 | // But to be sure, keep this assertion here.
40 | v, exact := constant.Int64Val(tv.Value)
41 | if !exact || v != 0 {
42 | panic(fmt.Sprintf("unexpected value for float with kind=int"))
43 | }
44 | return &ast.BasicLit{
45 | Kind: token.FLOAT,
46 | Value: "0.0",
47 | }
48 | }
49 | }
50 |
51 | if call, ok := x.(*ast.CallExpr); ok && typep.IsTypeExpr(n.cfg.Info, call.Fun) {
52 | if !isDefaultLiteralType(n.cfg.Info.TypeOf(call).Underlying()) {
53 | call.Args[0] = constValueNode(tv.Value)
54 | return call
55 | }
56 | }
57 |
58 | return constValueNode(tv.Value)
59 | }
60 |
61 | func (n *normalizer) normalizeExpr(x ast.Expr) ast.Expr {
62 | if folded := n.foldConstexpr(x); folded != nil {
63 | return folded
64 | }
65 |
66 | switch x := x.(type) {
67 | case *ast.CallExpr:
68 | if typep.IsTypeExpr(n.cfg.Info, x.Fun) {
69 | return n.normalizeTypeConversion(x)
70 | }
71 | x.Fun = n.normalizeExpr(x.Fun)
72 | x.Args = n.normalizeExprList(x.Args)
73 | return x
74 | case *ast.SliceExpr:
75 | return n.normalizeSliceExpr(x)
76 | case *ast.ParenExpr:
77 | return n.normalizeExpr(x.X)
78 | case *ast.BinaryExpr:
79 | return n.normalizeBinaryExpr(x)
80 | default:
81 | return x
82 | }
83 | }
84 |
85 | func (n *normalizer) normalizeTypeConversion(x *ast.CallExpr) ast.Expr {
86 | typeTo := n.cfg.Info.TypeOf(x)
87 | typeFrom := n.cfg.Info.TypeOf(x.Args[0])
88 | if types.Identical(typeTo, typeFrom) {
89 | return x.Args[0]
90 | }
91 | return x
92 | }
93 |
94 | func (n *normalizer) normalizeSliceExpr(x *ast.SliceExpr) ast.Expr {
95 | x.Low = n.normalizeExpr(x.Low)
96 | x.High = n.normalizeExpr(x.High)
97 | x.Max = n.normalizeExpr(x.Max)
98 | x.X = n.normalizeExpr(x.X)
99 | // Omit default low boundary.
100 | if astcast.ToBasicLit(x.Low).Value == "0" {
101 | x.Low = nil
102 | }
103 | // Omit default high boundary, but only if 3rd index is abscent.
104 | if x.Max == nil {
105 | lenCall := astcast.ToCallExpr(x.High)
106 | if astcast.ToIdent(lenCall.Fun).Name == "len" && astequal.Expr(lenCall.Args[0], x.X) {
107 | x.High = nil
108 | }
109 | }
110 | // s[:] => s if s is a slice or a string.
111 | if x.Low == nil && x.High == nil && !typep.IsArray(n.cfg.Info.TypeOf(x.X)) {
112 | return x.X
113 | }
114 | return x
115 | }
116 |
117 | func (n *normalizer) normalizeBinaryExpr(x *ast.BinaryExpr) ast.Expr {
118 | x.X = n.normalizeExpr(x.X)
119 | x.Y = n.normalizeExpr(x.Y)
120 |
121 | // TODO(quasilyte): implement this check in a proper way.
122 | // Also handle empty strings.
123 | switch {
124 | case isCommutative(n.cfg.Info, x) && astcast.ToBasicLit(x.X).Value == "0":
125 | return x.Y
126 | case astcast.ToBasicLit(x.Y).Value == "0":
127 | return x.X
128 | }
129 |
130 | if isCommutative(n.cfg.Info, x) {
131 | lhs := astcast.ToBinaryExpr(x.X)
132 | cv1 := constValueOf(n.cfg.Info, lhs.Y)
133 | cv2 := constValueOf(n.cfg.Info, x.Y)
134 |
135 | if cv1 != nil && cv2 != nil {
136 | cv := constant.BinaryOp(cv1, x.Op, cv2)
137 | x.X = lhs.X
138 | x.Y = constValueNode(cv)
139 | return n.normalizeExpr(x)
140 | }
141 |
142 | // Turn yoda expressions into the more conventional notation.
143 | // Put constant inside the expression after the non-constant part.
144 | if isLiteralConst(n.cfg.Info, x.X) && !isLiteralConst(n.cfg.Info, x.Y) {
145 | x.X, x.Y = x.Y, x.X
146 | }
147 | }
148 |
149 | return x
150 | }
151 |
152 | func (n *normalizer) normalizeExprList(xs []ast.Expr) []ast.Expr {
153 | for i, x := range xs {
154 | xs[i] = n.normalizeExpr(x)
155 | }
156 | return xs
157 | }
158 |
159 | func (n *normalizer) normalizeStmt(x ast.Stmt) ast.Stmt {
160 | switch x := x.(type) {
161 | case *ast.LabeledStmt:
162 | return n.normalizeLabeledStmt(x)
163 | case *ast.ExprStmt:
164 | return n.normalizeExprStmt(x)
165 | case *ast.AssignStmt:
166 | return n.normalizeAssignStmt(x)
167 | case *ast.BlockStmt:
168 | return n.normalizeBlockStmt(x)
169 | case *ast.ReturnStmt:
170 | return n.normalizeReturnStmt(x)
171 | case *ast.DeclStmt:
172 | return n.normalizeDeclStmt(x)
173 | case *ast.ForStmt:
174 | return n.normalizeForStmt(x)
175 | case *ast.RangeStmt:
176 | return n.normalizeRangeStmt(x)
177 | case *ast.IfStmt:
178 | return n.normalizeIfStmt(x)
179 | case *ast.IncDecStmt:
180 | return n.normalizeIncDecStmt(x)
181 | default:
182 | return x
183 | }
184 | }
185 |
186 | func (n *normalizer) normalizeIncDecStmt(stmt *ast.IncDecStmt) *ast.IncDecStmt {
187 | stmt.X = n.normalizeExpr(stmt.X)
188 | return stmt
189 | }
190 |
191 | func (n *normalizer) normalizeReturnStmt(ret *ast.ReturnStmt) *ast.ReturnStmt {
192 | ret.Results = n.normalizeExprList(ret.Results)
193 | return ret
194 | }
195 |
196 | func (n *normalizer) normalizeBlockStmt(b *ast.BlockStmt) *ast.BlockStmt {
197 | list := b.List[:0]
198 | for _, x := range b.List {
199 | // Filter-out const decls.
200 | // We inline const values, so local const defs are
201 | // not needed to keep code valid.
202 | decl, ok := x.(*ast.DeclStmt)
203 | if ok && decl.Decl.(*ast.GenDecl).Tok == token.CONST {
204 | continue
205 | }
206 | list = append(list, n.normalizeStmt(x))
207 | }
208 | b.List = list
209 |
210 | n.normalizeValSwap(b)
211 |
212 | return b
213 | }
214 |
215 | func (n *normalizer) normalizeDeclStmt(stmt *ast.DeclStmt) ast.Stmt {
216 | decl := stmt.Decl.(*ast.GenDecl)
217 | if decl.Tok != token.VAR {
218 | return stmt
219 | }
220 | if len(decl.Specs) != 1 {
221 | return stmt
222 | }
223 | spec := decl.Specs[0].(*ast.ValueSpec)
224 | if len(spec.Names) != 1 {
225 | return stmt
226 | }
227 |
228 | switch {
229 | case len(spec.Values) == 1:
230 | // var x T = v
231 | return &ast.AssignStmt{
232 | Tok: token.DEFINE,
233 | Lhs: []ast.Expr{spec.Names[0]},
234 | Rhs: []ast.Expr{spec.Values[0]},
235 | }
236 | case len(spec.Values) == 0 && spec.Type != nil:
237 | // var x T
238 | zv := zeroValueOf(n.cfg.Info.TypeOf(spec.Type))
239 | if zv == nil {
240 | return stmt
241 | }
242 | return &ast.AssignStmt{
243 | Tok: token.DEFINE,
244 | Lhs: []ast.Expr{spec.Names[0]},
245 | Rhs: []ast.Expr{zv},
246 | }
247 | default:
248 | return stmt
249 | }
250 | }
251 |
252 | func (n *normalizer) normalizeRangeStmt(loop *ast.RangeStmt) *ast.RangeStmt {
253 | loop.Key = n.normalizeExpr(loop.Key) // Not needed?
254 | loop.Value = n.normalizeExpr(loop.Value) // Not needed?
255 | loop.X = n.normalizeExpr(loop.X)
256 | loop.Body = n.normalizeBlockStmt(loop.Body)
257 | return loop
258 | }
259 |
260 | func (n *normalizer) normalizeForStmt(loop *ast.ForStmt) ast.Stmt {
261 | loop.Init = n.normalizeStmt(loop.Init)
262 | loop.Cond = n.normalizeExpr(loop.Cond)
263 | loop.Post = n.normalizeStmt(loop.Post)
264 | loop.Body = n.normalizeBlockStmt(loop.Body)
265 |
266 | if len(loop.Body.List) == 0 {
267 | return loop // Don't care
268 | }
269 |
270 | // I want AST matchers.
271 | // A lot.
272 | // Why go-toolsmith doesn't have one yet?
273 | // Code below is a mess even with astcast.
274 |
275 | init := astcast.ToAssignStmt(loop.Init)
276 | cond := astcast.ToBinaryExpr(loop.Cond)
277 | post := astcast.ToIncDecStmt(loop.Post)
278 |
279 | if init.Tok != token.DEFINE || len(init.Lhs) != 1 || len(init.Rhs) != 1 {
280 | return loop
281 | }
282 | if astcast.ToBasicLit(init.Rhs[0]).Value != "0" {
283 | return loop
284 | }
285 | iter := astcast.ToIdent(init.Lhs[0])
286 |
287 | if cond.Op != token.LSS {
288 | return loop
289 | }
290 | if !astequal.Expr(iter, cond.X) {
291 | return loop
292 | }
293 | lenCall := astcast.ToCallExpr(cond.Y)
294 | if astcast.ToIdent(lenCall.Fun).Name != "len" {
295 | return loop
296 | }
297 | slice := lenCall.Args[0]
298 | if !typep.IsSlice(n.cfg.Info.TypeOf(slice)) {
299 | return loop
300 | }
301 |
302 | if post.Tok != token.INC || !astequal.Expr(post.X, iter) {
303 | return loop
304 | }
305 |
306 | // Loop header matched.
307 | // Now need to see whether we need a key and/or val.
308 | var val ast.Expr
309 | keyUsed := false
310 | skip := 0
311 |
312 | assign := astcast.ToAssignStmt(loop.Body.List[0])
313 | if assign.Tok == token.DEFINE && len(assign.Lhs) == 1 && len(assign.Rhs) == 1 {
314 | indexing := astcast.ToIndexExpr(assign.Rhs[0])
315 | if astequal.Expr(indexing.X, slice) && astequal.Expr(indexing.Index, iter) {
316 | val = assign.Lhs[0]
317 | skip = 1
318 | }
319 | }
320 |
321 | // Now check that iter is not modified inside loop.
322 | // And since we're lazy and this is only a POC,
323 | // give up on any usage of iter.
324 | for _, stmt := range loop.Body.List[skip:] {
325 | giveUp := false
326 | ast.Inspect(stmt, func(x ast.Node) bool {
327 | switch x := x.(type) {
328 | case *ast.IndexExpr:
329 | index := astcast.ToIdent(x.Index)
330 | if index.Name == iter.Name && astequal.Expr(x.X, slice) {
331 | keyUsed = true
332 | return false
333 | }
334 | return true
335 | case *ast.Ident:
336 | if x.Name == iter.Name {
337 | giveUp = true
338 | }
339 | return true
340 | default:
341 | return true
342 | }
343 | })
344 | if giveUp {
345 | return loop
346 | }
347 | }
348 |
349 | key := iter
350 | if !keyUsed {
351 | key = blankIdent
352 | }
353 | loop.Body.List = loop.Body.List[skip:]
354 | return &ast.RangeStmt{
355 | Key: key,
356 | Value: val,
357 | Tok: token.DEFINE,
358 | X: slice,
359 | Body: loop.Body,
360 | }
361 | }
362 |
363 | func (n *normalizer) normalizeValSwap(b *ast.BlockStmt) {
364 | // tmp := x
365 | // x = y
366 | // y = tmp
367 | //
368 | // =>
369 | //
370 | // x, y = y, x
371 | //
372 | // FIXME(quasilyte): if tmp is used somewhere outside of the value swap,
373 | // this transformation is illegal.
374 |
375 | for i := 0; i < len(b.List)-2; i++ {
376 | assignTmp := astcast.ToAssignStmt(b.List[i+0])
377 | assignX := astcast.ToAssignStmt(b.List[i+1])
378 | assignY := astcast.ToAssignStmt(b.List[i+2])
379 | if assignTmp.Tok != token.DEFINE {
380 | continue
381 | }
382 | if assignX.Tok != token.ASSIGN || assignY.Tok != token.ASSIGN {
383 | continue
384 | }
385 | if len(assignTmp.Lhs) != 1 || len(assignX.Lhs) != 1 || len(assignY.Lhs) != 1 {
386 | continue
387 | }
388 | tmp := astcast.ToIdent(assignTmp.Lhs[0])
389 | x := assignX.Lhs[0]
390 | y := assignY.Lhs[0]
391 | if !astequal.Expr(assignTmp.Rhs[0], x) {
392 | continue
393 | }
394 | if !astequal.Expr(assignX.Rhs[0], y) {
395 | continue
396 | }
397 | if !astequal.Expr(assignY.Rhs[0], tmp) {
398 | continue
399 | }
400 |
401 | b.List[i] = &ast.AssignStmt{
402 | Tok: token.ASSIGN,
403 | Lhs: []ast.Expr{x, y},
404 | Rhs: []ast.Expr{y, x},
405 | }
406 | b.List = append(b.List[:i+1], b.List[i+3:]...)
407 | }
408 | }
409 |
410 | func (n *normalizer) normalizeLabeledStmt(stmt *ast.LabeledStmt) *ast.LabeledStmt {
411 | stmt.Stmt = n.normalizeStmt(stmt.Stmt)
412 | return stmt
413 | }
414 |
415 | func (n *normalizer) normalizeExprStmt(stmt *ast.ExprStmt) *ast.ExprStmt {
416 | stmt.X = n.normalizeExpr(stmt.X)
417 | return stmt
418 | }
419 |
420 | func (n *normalizer) normalizeAssignStmt(assign *ast.AssignStmt) ast.Stmt {
421 | for i, lhs := range assign.Lhs {
422 | assign.Lhs[i] = n.normalizeExpr(lhs)
423 | }
424 | for i, rhs := range assign.Rhs {
425 | assign.Rhs[i] = n.normalizeExpr(rhs)
426 | }
427 | assign = n.normalizeAssignOp(assign)
428 | return assign
429 | }
430 |
431 | func (n *normalizer) normalizeAssignOp(assign *ast.AssignStmt) *ast.AssignStmt {
432 | if assign.Tok != token.ASSIGN || len(assign.Lhs) != 1 {
433 | return assign
434 | }
435 | rhs := astcast.ToBinaryExpr(assign.Rhs[0])
436 | if !astequal.Expr(assign.Lhs[0], rhs.X) {
437 | return assign
438 | }
439 | op, ok := assignOpTab[rhs.Op]
440 | if ok {
441 | assign.Tok = op
442 | assign.Rhs[0] = rhs.Y
443 | }
444 | return assign
445 | }
446 |
447 | func (n *normalizer) normalizeIfStmt(stmt *ast.IfStmt) *ast.IfStmt {
448 | stmt.Init = n.normalizeStmt(stmt.Init)
449 | stmt.Cond = n.normalizeExpr(stmt.Cond)
450 | stmt.Body = n.normalizeBlockStmt(stmt.Body)
451 | stmt.Else = n.normalizeStmt(stmt.Else)
452 | return stmt
453 | }
454 |
// assignOpTab maps a binary operator token to the token of the
// matching compound assignment operator.
var assignOpTab = map[token.Token]token.Token{
	// Arithmetic operators.
	token.ADD: token.ADD_ASSIGN, // +=
	token.SUB: token.SUB_ASSIGN, // -=
	token.MUL: token.MUL_ASSIGN, // *=
	token.QUO: token.QUO_ASSIGN, // /=
	token.REM: token.REM_ASSIGN, // %=

	// Bitwise operators and shifts.
	token.AND:     token.AND_ASSIGN,     // &=
	token.OR:      token.OR_ASSIGN,      // |=
	token.XOR:     token.XOR_ASSIGN,     // ^=
	token.AND_NOT: token.AND_NOT_ASSIGN, // &^=
	token.SHL:     token.SHL_ASSIGN,     // <<=
	token.SHR:     token.SHR_ASSIGN,     // >>=
}
469 |
--------------------------------------------------------------------------------
/normalizer_test.go:
--------------------------------------------------------------------------------
1 | package astnorm
2 |
3 | import (
4 | "go/ast"
5 | "strings"
6 | "testing"
7 |
8 | "github.com/go-toolsmith/astcast"
9 | "github.com/go-toolsmith/astequal"
10 | "github.com/go-toolsmith/astfmt"
11 | "golang.org/x/tools/go/packages"
12 | )
13 |
14 | func isTestCase(assign *ast.AssignStmt) bool {
15 | return len(assign.Lhs) == 2 &&
16 | len(assign.Rhs) == 2 &&
17 | astcast.ToIdent(assign.Lhs[0]).Name == "_" &&
18 | astcast.ToIdent(assign.Lhs[1]).Name == "_"
19 | }
20 |
21 | func TestNormalizeExpr(t *testing.T) {
22 | pkg := loadPackage(t, "./testdata/normalize_expr.go")
23 | funcs := collectFuncDecls(pkg)
24 | cfg := &Config{Info: pkg.TypesInfo}
25 |
26 | for _, fn := range funcs {
27 | for _, stmt := range fn.Body.List {
28 | assign, ok := stmt.(*ast.AssignStmt)
29 | if !ok || !isTestCase(assign) {
30 | continue
31 | }
32 | input := assign.Rhs[0]
33 | want := assign.Rhs[1]
34 | have := Expr(cfg, input)
35 | if !astequal.Expr(have, want) {
36 | pos := pkg.Fset.Position(assign.Pos())
37 | t.Errorf("%s:\nhave: %s\nwant: %s",
38 | pos, astfmt.Sprint(have), astfmt.Sprint(want))
39 | }
40 | }
41 | }
42 | }
43 |
44 | func TestNormalizeStmt(t *testing.T) {
45 | pkg := loadPackage(t, "./testdata/normalize_stmt.go")
46 | funcs := collectFuncDecls(pkg)
47 | cfg := &Config{Info: pkg.TypesInfo}
48 |
49 | for _, fn := range funcs {
50 | for _, stmt := range fn.Body.List {
51 | assign, ok := stmt.(*ast.AssignStmt)
52 | if !ok || !isTestCase(assign) {
53 | continue
54 | }
55 | input := assign.Rhs[0].(*ast.FuncLit).Body
56 | want := assign.Rhs[1].(*ast.FuncLit).Body
57 | have := Stmt(cfg, input)
58 | if !astequal.Stmt(have, want) {
59 | pos := pkg.Fset.Position(assign.Pos())
60 | t.Errorf("%s:\nhave: %s\nwant: %s",
61 | pos, astfmt.Sprint(have), astfmt.Sprint(want))
62 | }
63 | }
64 | }
65 | }
66 |
67 | func collectFuncDecls(pkg *packages.Package) []*ast.FuncDecl {
68 | var funcs []*ast.FuncDecl
69 | for _, f := range pkg.Syntax {
70 | for _, decl := range f.Decls {
71 | decl, ok := decl.(*ast.FuncDecl)
72 | if !ok || decl.Body == nil {
73 | continue
74 | }
75 | if !strings.HasSuffix(decl.Name.Name, "Test") {
76 | continue
77 | }
78 | funcs = append(funcs, decl)
79 | }
80 | }
81 | return funcs
82 | }
83 |
84 | func loadPackage(t *testing.T, path string) *packages.Package {
85 | cfg := &packages.Config{Mode: packages.LoadSyntax}
86 | pkgs, err := packages.Load(cfg, path)
87 | if err != nil {
88 | t.Fatalf("load %q: %v", path, err)
89 | }
90 | if packages.PrintErrors(pkgs) > 0 {
91 | t.Fatalf("package %q loaded with errors", path)
92 | }
93 | if len(pkgs) != 1 {
94 | t.Fatalf("expected 1 package from %q path, got %d",
95 | path, len(pkgs))
96 | }
97 | return pkgs[0]
98 | }
99 |
--------------------------------------------------------------------------------
/testdata/normalize_expr.go:
--------------------------------------------------------------------------------
1 | package normalize_expr
2 |
// addInts returns the sum of x and y.
func addInts(x, y int) int { return x + y }
4 |
// identityTest lists expressions whose expected normalized form is
// written identically to the input.
// Each `_, _ = input, want` line is one test case.
func identityTest() {
	var x int
	type T int

	_, _ = x, x
	_, _ = 102, 102
	_, _ = x+1, x+1
	_, _ = 0-x, 0-x
	_, _ = 1.1, 1.1
	_, _ = 12412.312, 12412.312
}
16 |
// stringLiteralsTest pairs raw (backquoted) string literals, and one
// concatenation of literals, with a single double-quoted literal.
// NOTE(review): presumably left = input, right = expected normalized form;
// confirm against the test driver.
17 | func stringLiteralsTest() {
18 | _, _ = ``, ""
19 | _, _ = `\\`, "\\\\"
20 | _, _ = `\d+`, "\\d+"
21 | _, _ = `123`, "123"
22 | _, _ = "\n"+``+"\n", "\n\n"
23 | }
24 |
// defaultSlicingBoundsTest pairs slice expressions that spell out a 0 lower
// bound and/or a len(...) upper bound with the bare operand, and pairs
// xs[:0:0], s[1:] and a[:] with themselves.
// NOTE(review): presumably left = input, right = expected normalized form;
// confirm against the test driver.
25 | func defaultSlicingBoundsTest() {
26 | var xs []int
27 | var s string
28 | var a [3]int
29 |
30 | _, _ = xs[0:], xs
31 | _, _ = (xs)[(0+0):], xs
32 | _, _ = xs[0:len(xs)], xs
33 | _, _ = (xs)[0:(len(xs))], xs
34 | _, _ = xs[:0:0], xs[:0:0]
35 |
36 | _, _ = s[0:len(s)], s
37 | _, _ = s[1:], s[1:]
38 |
39 | _, _ = a[:], a[:]
40 | }
41 |
// literalsTest pairs hexadecimal and octal integer literals with decimal ones,
// and differently spelled float literals with a digits-dot-digits spelling.
// NOTE(review): presumably left = input, right = expected normalized form;
// confirm against the test driver.
42 | func literalsTest() {
43 | // Convert any int numerical base into 10.
44 | _, _ = 0x0, 0
45 | _, _ = 0x1, 1
46 | _, _ = 04, 4
47 | _, _ = 010, 8
48 |
49 | // Represent floats in a consistent way.
50 | _, _ = 1.0, 1.0
51 | _, _ = 5.0, 5.0
52 | _, _ = 0.0, 0.0
53 | _, _ = .0, 0.0
54 | _, _ = 0., 0.0
55 | _, _ = 0.1e4, 1000.0
56 | _, _ = 00.0, 0.0
57 | }
58 |
// conversionTest pairs type-conversion expressions with forms that drop the
// conversions listed as redundant and keep the ones listed as required.
// NOTE(review): presumably left = input, right = expected normalized form;
// confirm against the test driver.
59 | func conversionTest() {
60 | var x int
61 |
62 | // These already have proper type even without conversion.
63 | _, _ = int(1), 1
64 | _, _ = float64(40.1), 40.1
65 | _, _ = int(x), x
66 | _, _ = int(x+1), x+1
67 |
68 | // Repetitive conversions.
69 | _, _ = int(int(int(1))), 1
70 |
71 | // These require conversion.
72 | _, _ = int32(x), int32(x)
73 |
74 | // Preserve type conversion for untyped literals.
75 | _, _ = int8(1), int8(1)
76 | _, _ = int8(int16(1)), int8(1)
77 | _, _ = int8(int16(int32(1))), int8(1)
78 | _, _ = int8(int16(int32(int64(1)))), int8(1)
79 | _, _ = int16(int8(int16(int32(int64(1+1+1))))), int16(3)
80 | _, _ = int32(int16(int8(int16(int32(int64(1+1)))))), int32(2)
81 | }
82 |
// yodaTest pairs `1+x` and `nil != m` (constant on the left) with `x+1` and
// `m != nil`, and pairs string concatenation, subtraction and division with
// themselves.
// NOTE(review): presumably left = input, right = expected normalized form;
// confirm against the test driver.
83 | func yodaTest() {
84 | var x int
85 | var s string
86 | var m map[int]int
87 |
88 | _, _ = 1+x, x+1
89 | _, _ = (nil != m), m != nil
90 |
91 | // Concat is not commutative.
92 | _, _ = "prefix"+s, "prefix"+s
93 | // Other non-commutative ops.
94 | _, _ = 1-x, 1-x
95 | _, _ = 1000/x, 1000/x
96 | }
97 |
// foldBoolTest pairs constant boolean expressions (&&, ||, == and != over
// literals) with a single true/false, and pairs `x != x` on a float64 variable
// with itself.
// NOTE(review): presumably left = input, right = expected normalized form;
// confirm against the test driver.
98 | func foldBoolTest() {
99 | _, _ = false && false, false
100 | _, _ = false || false, false
101 | _, _ = true && true, true
102 | _, _ = true || false, true
103 | _, _ = false || false || true, true
104 |
105 | _, _ = 1 != 1, false
106 | _, _ = 1 == 1, true
107 |
108 | var x float64
109 | // This is NaN test.
110 | // Not something that should be replaced.
111 | _, _ = x != x, x != x
112 | }
113 |
// foldArithTest pairs arithmetic and string-concatenation expressions with
// folded forms: constant-only expressions with a single literal, additions and
// subtractions of zero with the bare operand, and sums mixing x with several
// constants with `x + <total>`.
// NOTE(review): presumably left = input, right = expected normalized form;
// confirm against the test driver.
114 | func foldArithTest() {
115 | var x int
116 |
117 | // Const-only expressions are folded entirely.
118 | _, _ = 1+2+3, 6
119 | _, _ = 6-2, 4
120 |
121 | // Zeroes can be removed completely as well.
122 | _, _ = x+0, x
123 | _, _ = x+(0)+0, x
124 | _, _ = 0+x, x
125 | _, _ = 0+0+x, x
126 | _, _ = 0+x+(0), x
127 | _, _ = (0+0)+x+0, x
128 | _, _ = 0+x+0+0, x
129 | _, _ = x-0-0, x
130 |
131 | // For commutative ops fold it into a single op.
132 | _, _ = x+1, x+1
133 | _, _ = x+1+1, x+2
134 | _, _ = 1+x+1, x+2
135 | _, _ = 1+2+x+2+1, x+6
136 | _, _ = (1+2)+x+2+1, x+6
137 | _, _ = ((1 + (2)) + (x + 2) + 1), x+6
138 | _, _ = 0.2+0.1, 0.3
139 |
140 | _, _ = "a"+"b"+"c", "abc"
141 | }
142 |
// parenthesisRemovalTest pairs expressions wrapped in extra parentheses with
// the same expressions without them; the parentheses around the pointer type
// in the (*T)(&x) conversion appear on both sides.
// NOTE(review): presumably left = input, right = expected normalized form;
// confirm against the test driver.
143 | func parenthesisRemovalTest() {
144 | var x int
145 | type T int
146 |
147 | _, _ = (x), x
148 | _, _ = ((*T)(&x)), (*T)(&x)
149 | _, _ = (addInts)(1, 2), addInts(1, 2)
150 | _, _ = addInts((1), (2)), addInts(1, 2)
151 | }
152 |
--------------------------------------------------------------------------------
/testdata/normalize_stmt.go:
--------------------------------------------------------------------------------
1 | package normalize_stmt
2 |
// addInts returns the sum of x and y. The fixtures below use it as a plain
// callable to build call statements.
func addInts(x, y int) int {
	sum := x + y
	return sum
}
4 |
// identityTest holds one case whose two function literals have identical
// bodies (`x += 1; x -= 1`).
// In a `_, _ = func..., func...` pair TestNormalizeStmt passes the body of the
// first literal to Stmt and compares the result with the body of the second.
5 | func identityTest() {
6 | var x int
7 |
8 | _, _ = func() {
9 | x += 1
10 | x -= 1
11 | }, func() {
12 | x += 1
13 | x -= 1
14 | }
15 | }
16 |
// incdecStmtTest pairs `x++; (x)++` with `x++; x++`.
// In a `_, _ = func..., func...` pair TestNormalizeStmt passes the body of the
// first literal to Stmt and compares the result with the body of the second.
17 | func incdecStmtTest() {
18 | var x int
19 |
20 | _, _ = func() {
21 | x++
22 | (x)++
23 | }, func() {
24 | x++
25 | x++
26 | }
27 | }
28 |
// rangeStmtTest pairs a range loop over xs[0:len(xs)] with a parenthesized
// body expression against the same loop over plain xs without parentheses.
// In a `_, _ = func..., func...` pair TestNormalizeStmt passes the body of the
// first literal to Stmt and compares the result with the body of the second.
29 | func rangeStmtTest() {
30 | var xs []int
31 |
32 | _, _ = func() {
33 | for i := range xs[0:len(xs)] {
34 | _ = (i)
35 | }
36 | }, func() {
37 | for i := range xs {
38 | _ = i
39 | }
40 | }
41 | }
42 |
// assignOpTest pairs `x = x <op> c` assignments (+, -, *) with the matching
// compound assignments `x <op>= c`.
// In a `_, _ = func..., func...` pair TestNormalizeStmt passes the body of the
// first literal to Stmt and compares the result with the body of the second.
43 | func assignOpTest() {
44 | var x int
45 |
46 | _, _ = func() {
47 | x = x + 5
48 | x = x - 2
49 | x = x * 4
50 | }, func() {
51 | x += 5
52 | x -= 2
53 | x *= 4
54 | }
55 | }
56 |
// valueSwapTest pairs three-statement swaps through a temporary variable with
// a single parallel assignment; the second case holds two such swaps in a row.
// In a `_, _ = func..., func...` pair TestNormalizeStmt passes the body of the
// first literal to Stmt and compares the result with the body of the second.
57 | func valueSwapTest() {
58 | var x, y int
59 |
60 | _, _ = func() {
61 | tmp := (x)
62 | x = y
63 | y = tmp
64 | }, func() {
65 | x, y = y, x
66 | }
67 |
68 | _, _ = func() {
69 | tmp1 := x
70 | x = y
71 | y = tmp1
72 |
73 | tmp2 := y
74 | y = x
75 | x = tmp2
76 | }, func() {
77 | x, y = y, x
78 | y, x = x, y
79 | }
80 |
81 | }
82 |
// removeConstDeclsTest pairs bodies that declare a local const and use it in
// an expression with bodies that have no const declaration and contain the
// computed literal instead.
// In a `_, _ = func..., func...` pair TestNormalizeStmt passes the body of the
// first literal to Stmt and compares the result with the body of the second.
83 | func removeConstDeclsTest() {
84 | _, _ = func() {
85 | const n = 10
86 | _ = n + n
87 | }, func() {
88 | _ = 20
89 | }
90 |
91 | _, _ = func() {
92 | const n = 10
93 | x := 10
94 | _ = x != n+1
95 | }, func() {
96 | x := 10
97 | _ = x != 11
98 | }
99 | }
100 |
// rewriteVarSpecTest pairs `var` declarations with short variable
// declarations. Declarations without an initializer are paired with an
// explicit zero value: 0, "", 0.0, T(nil) for slices, T{} for arrays and
// float32(0.0) for float32.
// In a `_, _ = func..., func...` pair TestNormalizeStmt passes the body of the
// first literal to Stmt and compares the result with the body of the second.
101 | func rewriteVarSpecTest() {
102 | _, _ = func() {
103 | var x = 10
104 | var y float32 = float32(x)
105 | _ = x
106 | _ = y
107 | }, func() {
108 | x := 10
109 | y := float32(x)
110 | _ = x
111 | _ = y
112 | }
113 |
114 | _, _ = func() {
115 | var x int
116 | _ = x
117 | }, func() {
118 | x := 0
119 | _ = x
120 | }
121 |
122 | _, _ = func() {
123 | var xs [][]int
124 | var s string
125 | _ = xs
126 | _ = s
127 | }, func() {
128 | xs := [][]int(nil)
129 | s := ""
130 | _ = xs
131 | _ = s
132 | }
133 |
134 | _, _ = func() {
135 | var xs [8]string
136 | _ = xs
137 | }, func() {
138 | xs := [8]string{}
139 | _ = xs
140 | }
141 |
142 | _, _ = func() (float64, float32) {
143 | var x float64
144 | var y float32
145 | return x, y
146 | }, func() (float64, float32) {
147 | x := 0.0
148 | y := float32(0.0)
149 | return x, y
150 | }
151 | }
152 |
// rangeLoopTest pairs counted loops of the form
// `for i := 0; i < len(xs); i++` with range loops. Loops that index with i+1,
// never read xs[i], or modify i in the body are paired with an unchanged
// counted loop. The second case combines the loop rewrite with const removal,
// var rewriting, slice bound defaults and operand reordering.
// In a `_, _ = func..., func...` pair TestNormalizeStmt passes the body of the
// first literal to Stmt and compares the result with the body of the second.
153 | func rangeLoopTest() {
154 | _, _ = func() {
155 | var xs []int
156 | for i := 0; i < len(xs); i++ {
157 | x := xs[i]
158 | _ = x
159 | }
160 |
161 | // Uses i+1 index.
162 | for i := 0; i < len(xs); i++ {
163 | x := xs[i+1]
164 | _ = x
165 | }
166 |
167 | // Doesn't assign elem.
168 |
169 | for i := 0; i < len(xs); i++ {
170 | _ = i
171 | }
172 |
173 | // TODO(quasilyte): more negative tests.
174 | // (Hint: use coverage to guide you, Luke!)
175 | }, func() {
176 | xs := []int(nil)
177 | for _, x := range xs {
178 | _ = x
179 | }
180 |
181 | for i := 0; i < len(xs); i++ {
182 | x := xs[i+1]
183 | _ = x
184 | }
185 |
186 | for i := 0; i < len(xs); i++ {
187 | _ = i
188 | }
189 | }
190 |
191 | _, _ = func() {
192 | var xs []int
193 | const toRemove = 10
194 | var filtered []int
195 | filtered = xs[0:0]
196 | for i := int(0); i < len(xs); i++ {
197 | x := xs[i]
198 | if toRemove+1 != x {
199 | filtered = append(filtered, x)
200 | }
201 | }
202 | _ = (filtered)
203 | }, func() {
204 | xs := []int(nil)
205 | filtered := []int(nil)
206 | filtered = xs[:0]
207 | for _, x := range xs {
208 | if x != 11 {
209 | filtered = append(filtered, x)
210 | }
211 | }
212 | _ = filtered
213 | }
214 |
215 | _, _ = func(xs []int) {
216 | for i := 0; i < len(xs); i++ {
217 | _ = xs[i]
218 | }
219 | }, func(xs []int) {
220 | for i := range xs {
221 | _ = xs[i]
222 | }
223 | }
224 |
225 | _, _ = func(xs []int) {
226 | for i := 0; i < len(xs); i++ {
227 | _ = xs[i+1]
228 | }
229 | }, func(xs []int) {
230 | for i := 0; i < len(xs); i++ {
231 | _ = xs[i+1]
232 | }
233 | }
234 |
235 | _, _ = func(xs []int) {
236 | for i := 0; i < len(xs); i++ {
237 | v := xs[i]
238 | _ = v
239 | _ = xs[i]
240 | }
241 | }, func(xs []int) {
242 | for i, v := range xs {
243 | _ = v
244 | _ = xs[i]
245 | }
246 | }
247 |
248 | _, _ = func(xs []int) {
249 | for i := 0; i < len(xs); i++ {
250 | v := xs[i]
251 | _ = v
252 | _ = xs[i]
253 | i++
254 | }
255 | }, func(xs []int) {
256 | for i := 0; i < len(xs); i++ {
257 | v := xs[i]
258 | _ = v
259 | _ = xs[i]
260 | i++
261 | }
262 | }
263 | }
264 |
// exprStmtTest pairs a call statement whose arguments contain parentheses and
// a constant sum with the same call using plain literal arguments.
// In a `_, _ = func..., func...` pair TestNormalizeStmt passes the body of the
// first literal to Stmt and compares the result with the body of the second.
265 | func exprStmtTest() {
266 | _, _ = func() {
267 | addInts((1), 0+0+0)
268 | }, func() {
269 | addInts(1, 0)
270 | }
271 | }
272 |
// labeledStmtTest pairs a labeled `return (0)` with a labeled `return 0`; the
// goto and the label itself are the same on both sides.
// In a `_, _ = func..., func...` pair TestNormalizeStmt passes the body of the
// first literal to Stmt and compares the result with the body of the second.
273 | func labeledStmtTest() {
274 | _, _ = func() int {
275 | goto L
276 | L:
277 | return (0)
278 | }, func() int {
279 | goto L
280 | L:
281 | return 0
282 | }
283 | }
284 |
// combinedTest pairs `x = x + (2)` with `x += 2`, a case that needs both
// parenthesis removal and the compound-assignment rewrite.
// In a `_, _ = func..., func...` pair TestNormalizeStmt passes the body of the
// first literal to Stmt and compares the result with the body of the second.
285 | func combinedTest() {
286 | var x int
287 |
288 | _, _ = func() {
289 | x = x + (2)
290 | }, func() {
291 | x += 2
292 | }
293 | }
294 |
--------------------------------------------------------------------------------
/utils.go:
--------------------------------------------------------------------------------
1 | package astnorm
2 |
3 | import (
4 | "fmt"
5 | "go/ast"
6 | "go/constant"
7 | "go/token"
8 | "go/types"
9 | "strconv"
10 | "strings"
11 |
12 | "github.com/go-toolsmith/astcast"
13 | "github.com/go-toolsmith/strparse"
14 | "github.com/go-toolsmith/typep"
15 | "golang.org/x/tools/go/ast/astutil"
16 | )
17 |
// blankIdent is a package-level identifier node named "_".
18 | var blankIdent = &ast.Ident{Name: "_"}
19 |
// isLiteralConst reports whether x is a basic literal or one of the
// identifiers nil, true and false. Those identifiers are not literals in the
// strict sense, but many programmers treat them as such.
//
// info is accepted but not consulted.
func isLiteralConst(info *types.Info, x ast.Expr) bool {
	if _, isLit := x.(*ast.BasicLit); isLit {
		return true
	}
	id, isIdent := x.(*ast.Ident)
	if !isIdent {
		return false
	}
	return id.Name == "nil" || id.Name == "true" || id.Name == "false"
}
36 |
37 | func isCommutative(info *types.Info, x *ast.BinaryExpr) bool {
38 | // TODO(quasilyte): make this list more or less complete.
39 | switch x.Op {
40 | case token.ADD:
41 | return !typep.HasStringProp(info.TypeOf(x))
42 | case token.EQL, token.NEQ:
43 | return true
44 | default:
45 | return false
46 | }
47 | }
48 |
// constValueNode builds an AST node that spells out the constant cv.
//
// Bools become the identifiers true/false, strings a quoted STRING literal,
// ints a decimal INT literal and floats a FLOAT literal. nil is returned when
// an Int value does not fit int64 or a Float value is not exactly
// representable as float64. It panics for any other constant kind.
func constValueNode(cv constant.Value) ast.Expr {
	switch cv.Kind() {
	case constant.Bool:
		if constant.BoolVal(cv) {
			return &ast.Ident{Name: "true"}
		}
		return &ast.Ident{Name: "false"}

	case constant.String:
		v := constant.StringVal(cv)
		return &ast.BasicLit{
			Kind:  token.STRING,
			Value: strconv.Quote(v),
		}

	case constant.Int:
		// For whatever reason, constant.Int can also
		// mean "float with 0 fractional part".
		v, exact := constant.Int64Val(cv)
		if !exact {
			return nil
		}
		return &ast.BasicLit{
			Kind:  token.INT,
			Value: fmt.Sprint(v),
		}

	case constant.Float:
		v, exact := constant.Float64Val(cv)
		if !exact {
			return nil
		}
		s := fmt.Sprint(v)
		// Whole numbers print without a fractional part ("1000"), so add one
		// to keep the literal a float. Values printed in exponent notation
		// ("1e+06", "1e-07") are already float literals; appending ".0" to
		// them would yield invalid syntax such as "1e+06.0".
		if !strings.ContainsAny(s, ".eE") {
			s += ".0"
		}
		return &ast.BasicLit{
			Kind:  token.FLOAT,
			Value: s,
		}

	default:
		panic("unimplemented")
	}
}
94 |
95 | func constValueOf(info *types.Info, x ast.Expr) constant.Value {
96 | if cv := info.Types[x].Value; cv != nil {
97 | return cv
98 | }
99 | lit := astcast.ToBasicLit(x)
100 | switch lit.Kind {
101 | case token.INT:
102 | v, err := strconv.ParseInt(lit.Value, 10, 64)
103 | if err != nil {
104 | return nil
105 | }
106 | return constant.MakeInt64(v)
107 | default:
108 | return nil
109 | }
110 | }
111 |
112 | func zeroValueOf(typ types.Type) ast.Expr {
113 | switch typ := typ.(type) {
114 | case *types.Basic:
115 | info := typ.Info()
116 | var zv ast.Expr
117 | switch {
118 | case info&types.IsInteger != 0:
119 | zv = &ast.BasicLit{Kind: token.INT, Value: "0"}
120 | case info&types.IsFloat != 0:
121 | zv = &ast.BasicLit{Kind: token.FLOAT, Value: "0.0"}
122 | case info&types.IsString != 0:
123 | zv = &ast.BasicLit{Kind: token.STRING, Value: `""`}
124 | }
125 | if isDefaultLiteralType(typ) {
126 | return zv
127 | }
128 | return &ast.CallExpr{
129 | Fun: typeToExpr(typ),
130 | Args: []ast.Expr{zv},
131 | }
132 | case *types.Slice:
133 | return &ast.CallExpr{
134 | Fun: typeToExpr(typ),
135 | Args: []ast.Expr{&ast.Ident{Name: "nil"}},
136 | }
137 | case *types.Array:
138 | return &ast.CompositeLit{Type: typeToExpr(typ)}
139 | }
140 | return nil
141 | }
142 |
143 | func typeToExpr(typ types.Type) ast.Expr {
144 | // This is a very dirty and inefficient way,
145 | // but it's at the very same time so simple and tempting.
146 | return strparse.Expr(typ.String())
147 | }
148 |
149 | func findNode(root ast.Node, pred func(ast.Node) bool) ast.Node {
150 | var found ast.Node
151 | astutil.Apply(root, nil, func(cur *astutil.Cursor) bool {
152 | if pred(cur.Node()) {
153 | found = cur.Node()
154 | return false
155 | }
156 | return true
157 | })
158 | return found
159 | }
160 |
161 | func containsNode(root ast.Node, pred func(ast.Node) bool) bool {
162 | return findNode(root, pred) != nil
163 | }
164 |
// isDefaultLiteralType reports whether typ is the basic type bool, int,
// float64 or string. Any other type, basic or not, yields false.
func isDefaultLiteralType(typ types.Type) bool {
	if basic, isBasic := typ.(*types.Basic); isBasic {
		k := basic.Kind()
		return k == types.Bool || k == types.Int || k == types.Float64 || k == types.String
	}
	return false
}
177 |
--------------------------------------------------------------------------------