├── .github └── workflows │ └── unittest.yml ├── .gitignore ├── LICENSE ├── README.md ├── example ├── duplicated_exports │ └── main.go ├── file │ └── file.go ├── for-loop │ └── for-loop.go ├── func-args │ └── func-args.go ├── goroutine │ └── main.go ├── goroutines │ └── goroutines.go ├── maps │ └── maps.go ├── math │ └── math.go ├── packagealiases │ └── main.go ├── readme │ └── readme.go └── vendoring │ ├── main.go │ └── vendor │ └── foo │ └── foo.go ├── generate ├── generate.go └── generate_test.go ├── go.mod ├── go.sum ├── main.go ├── main_test.go ├── playground ├── .gitignore ├── Dockerfile ├── Makefile ├── genmeta │ └── genmeta.go ├── index.html ├── now.json ├── package.json ├── packages.txt ├── server │ ├── server.go │ ├── server11_test.go │ └── server_test.go └── yarn.lock ├── pry-build-corpus └── main.go └── pry ├── autocomplete.go ├── fuzz.go ├── helpers.go ├── highlighter.go ├── highlighter_test.go ├── importer.go ├── importer_default.go ├── importer_js.go ├── interpreter.go ├── interpreter_test.go ├── io_default.go ├── io_default_test.go ├── io_js.go ├── package.go ├── pry.go ├── pry_test.go ├── pseudo_generics.go ├── safebuffer └── safebuffer.go ├── suggestions.go ├── tty_js.go ├── tty_unix.go ├── tty_windows.go ├── type.go └── type_test.go /.github/workflows/unittest.yml: -------------------------------------------------------------------------------- 1 | on: 2 | push: 3 | branches: 4 | - master 5 | pull_request: 6 | 7 | name: Test 8 | jobs: 9 | test: 10 | strategy: 11 | matrix: 12 | go-version: [1.16.x] 13 | os: [ubuntu-latest] 14 | runs-on: ${{ matrix.os }} 15 | steps: 16 | - name: Install Go 17 | uses: actions/setup-go@v2 18 | with: 19 | go-version: ${{ matrix.go-version }} 20 | - name: Checkout code 21 | uses: actions/checkout@v2 22 | - name: Install 23 | run: | 24 | set -ex 25 | go get -t -v . ./pry ./generate ./playground/server 26 | - name: Test 27 | run: | 28 | set -ex 29 | go test -v -race . 
./pry ./generate ./playground/server 30 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.go.go 2 | *.zip 3 | fuzz/ 4 | *.coverprofile 5 | go-pry-test* 6 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | The MIT License (MIT) 2 | 3 | Copyright (c) 2015 Tristan Rice 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in 13 | all copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 21 | THE SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # go-pry 2 | 3 | go-pry - an interactive REPL for Go that allows you to drop into your code at any point. 
4 | 5 | ![Tests](https://github.com/d4l3k/go-pry/actions/workflows/unittest.yml/badge.svg) 6 | [![GoDoc](https://godoc.org/github.com/d4l3k/go-pry/pry?status.svg)](https://godoc.org/github.com/d4l3k/go-pry/pry) 7 | 8 | ![go-pry](https://i.imgur.com/yr1BEsK.png) 9 | 10 | Example 11 | 12 | ![go-pry Animated Example](https://i.imgur.com/H8hFzPV.gif) 13 | ![go-pry Example](https://i.imgur.com/0rmwVY7.png) 14 | 15 | 16 | 17 | ## Usage 18 | 19 | Install go-pry 20 | ```bash 21 | go get github.com/d4l3k/go-pry 22 | go install -i github.com/d4l3k/go-pry 23 | 24 | ``` 25 | 26 | Add the pry statement to the code 27 | ```go 28 | package main 29 | 30 | import "github.com/d4l3k/go-pry/pry" 31 | 32 | func main() { 33 | a := 1 34 | pry.Pry() 35 | } 36 | ``` 37 | 38 | Run the code as you would normally with the `go` command. go-pry is just a wrapper. 39 | ```bash 40 | # Run 41 | go-pry run readme.go 42 | ``` 43 | 44 | If you want completions to work properly, also install `gocode` if it 45 | is not already installed on your system: 46 | 47 | ```bash 48 | go get -u github.com/nsf/gocode 49 | ``` 50 | 51 | 52 | ## How does it work? 53 | go-pry is built using a combination of metaprogramming and a massive amount of reflection. When you invoke the go-pry command it looks at the Go files in the mentioned directories (or in the current directory in cases such as `go-pry build`) and processes them. Since Go is a compiled language there's no way to dynamically get the in-scope variables, and even if there were, code from imported packages that your program never references would be left out of the binary. Thus, go-pry has to find every instance of `pry.Pry()` and inject a large blob of code that contains references to all in-scope variables and functions as well as those of the imported packages. When doing this it moves your original file `<file>.go` to `.<file>.gopry`, writes the modified version to `<file>.go`, and then passes the command arguments to the standard `go` command. Once the command exits, it restores the files.
54 | 55 | If the program unexpectedly fails there is a custom command `go-pry revert [files]` that moves the files back: it restores every `.gopry` backup found under the directories of the given files (or under the current directory if no files are given). An alternative is to just remove the `pry.Apply(...)` line. 56 | 57 | ## Inspiration 58 | 59 | go-pry is greatly inspired by [Pry REPL](http://pryrepl.org) for Ruby. 60 | 61 | ## License 62 | 63 | go-pry is licensed under the MIT license. 64 | 65 | Made by [Tristan Rice](https://fn.lc). 66 | -------------------------------------------------------------------------------- /example/duplicated_exports/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "os" 5 | "path/filepath" 6 | 7 | "github.com/d4l3k/go-pry/pry" 8 | ) 9 | 10 | func main() { 11 | a := filepath.Base("/asdf/asdf") 12 | pry.Pry() 13 | os.Setenv("foo", "bar") 14 | _ = a 15 | } 16 | -------------------------------------------------------------------------------- /example/file/file.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "github.com/d4l3k/go-pry/pry" 5 | 6 | "log" 7 | ) 8 | 9 | // Note: This file has some gibberish to test for highlighting and other edge cases.
10 | 11 | /* 12 | Block Quote 13 | */ 14 | 15 | func X() bool { 16 | return true 17 | } 18 | 19 | type Banana struct { 20 | Name string 21 | Cake []int 22 | } 23 | 24 | func (b Banana) Ly() string { 25 | return b.Name + "ly" 26 | } 27 | 28 | func main() { 29 | a := 1 30 | b := Banana{"Jeoffry", []int{1, 2, 3}} 31 | m := []int{1234} 32 | _ = m 33 | 34 | testMake := make(chan int, 1) 35 | testMap := map[int]interface{}{ 36 | 1: 2, 37 | 3: "asdf", 38 | 5: []interface{}{ 39 | 1, "asdf", 40 | }, 41 | } 42 | _ = testMap 43 | go func() { 44 | _ = 1 + 1*1/1%1 45 | }() 46 | 47 | if d := X(); d { 48 | log.Println(d) 49 | for i, j := range []int{1} { 50 | k := 1 51 | log.Println(i, j, k) 52 | // Example comment 53 | pry.Pry() 54 | } 55 | } 56 | log.Println("Test", a, b, main, testMake) 57 | } 58 | -------------------------------------------------------------------------------- /example/for-loop/for-loop.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import "github.com/d4l3k/go-pry/pry" 4 | import "fmt" 5 | 6 | func main() { 7 | for i := 0; i < 10; i++ { 8 | pry.Pry() 9 | } 10 | fmt.Println("DUCK") 11 | } 12 | -------------------------------------------------------------------------------- /example/func-args/func-args.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import "github.com/d4l3k/go-pry/pry" 4 | 5 | func a(b int) { 6 | c := 5 7 | pry.Pry() 8 | _ = c 9 | } 10 | 11 | func main() { 12 | a(5) 13 | } 14 | -------------------------------------------------------------------------------- /example/goroutine/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "fmt" 5 | "html" 6 | "log" 7 | "net/http" 8 | 9 | "github.com/d4l3k/go-pry/pry" 10 | ) 11 | 12 | func main() { 13 | w := 6 14 | 15 | http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { 16 | b := "toast" 17 | 
fmt.Fprintf(w, "Hello, %q", html.EscapeString(r.URL.Path)) 18 | pry.Pry() 19 | _ = b 20 | }) 21 | 22 | pry.Pry() 23 | 24 | log.Fatal(http.ListenAndServe(":8080", nil)) 25 | _ = w 26 | } 27 | -------------------------------------------------------------------------------- /example/goroutines/goroutines.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "fmt" 5 | "github.com/d4l3k/go-pry/pry" 6 | "time" 7 | ) 8 | 9 | func prying() { 10 | fmt.Println("PRYING!") 11 | } 12 | 13 | func main() { 14 | c := make(chan bool) 15 | go func() { 16 | prying() 17 | pry.Pry() 18 | c <- true 19 | }() 20 | <-c 21 | for { 22 | time.Sleep(time.Second) 23 | } 24 | } 25 | -------------------------------------------------------------------------------- /example/maps/maps.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "fmt" 5 | 6 | "github.com/d4l3k/go-pry/pry" 7 | ) 8 | 9 | func main() { 10 | testMap := map[string]int{ 11 | "duck": 1, 12 | "blue": 2, 13 | "5": 0xDEAD, 14 | } 15 | for k, v := range testMap { 16 | fmt.Println(k, v) 17 | } 18 | pry.Pry() 19 | } 20 | -------------------------------------------------------------------------------- /example/math/math.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "github.com/d4l3k/go-pry/pry" 5 | "math" 6 | ) 7 | 8 | func main() { 9 | a := math.Sin(10.0) 10 | pry.Pry() 11 | _ = a 12 | } 13 | -------------------------------------------------------------------------------- /example/packagealiases/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | f "fmt" 5 | 6 | "github.com/d4l3k/go-pry/pry" 7 | ) 8 | 9 | func main() { 10 | f.Println("foo") 11 | pry.Pry() 12 | } 13 | -------------------------------------------------------------------------------- 
/example/readme/readme.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import "github.com/d4l3k/go-pry/pry" 4 | 5 | func main() { 6 | a := 1 7 | pry.Pry() 8 | _ = a 9 | } 10 | -------------------------------------------------------------------------------- /example/vendoring/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import "foo" 4 | 5 | func main() { 6 | foo.Foo() 7 | } 8 | -------------------------------------------------------------------------------- /example/vendoring/vendor/foo/foo.go: -------------------------------------------------------------------------------- 1 | package foo 2 | 3 | import "fmt" 4 | 5 | func Foo() { 6 | fmt.Println("foo!") 7 | } 8 | -------------------------------------------------------------------------------- /generate/generate.go: -------------------------------------------------------------------------------- 1 | package generate 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "go/ast" 7 | "go/parser" 8 | "go/token" 9 | "io/ioutil" 10 | "log" 11 | "os" 12 | "os/exec" 13 | "path/filepath" 14 | "reflect" 15 | "strings" 16 | 17 | "github.com/d4l3k/go-pry/pry" 18 | "github.com/pkg/errors" 19 | "golang.org/x/tools/go/packages" 20 | ) 21 | 22 | type Generator struct { 23 | contexts []pryContext 24 | debug bool 25 | Config packages.Config 26 | } 27 | 28 | func NewGenerator(debug bool) *Generator { 29 | return &Generator{ 30 | debug: debug, 31 | Config: packages.Config{ 32 | Mode: packages.NeedName | packages.NeedSyntax, 33 | }, 34 | } 35 | } 36 | 37 | // Debug prints debug statements if debug is true. 38 | func (g Generator) Debug(templ string, k ...interface{}) { 39 | if g.debug { 40 | log.Printf(templ, k...) 41 | } 42 | } 43 | 44 | // ExecuteGoCmd runs the 'go' command with certain parameters. 
45 | func (g *Generator) ExecuteGoCmd(ctx context.Context, args []string, env []string) error { 46 | binary, err := exec.LookPath("go") 47 | if err != nil { 48 | return err 49 | } 50 | 51 | cmd := exec.CommandContext(ctx, binary, args...) 52 | cmd.Env = append(os.Environ(), env...) 53 | cmd.Stdout = os.Stdout 54 | cmd.Stderr = os.Stderr 55 | cmd.Stdin = os.Stdin 56 | return cmd.Run() 57 | } 58 | 59 | // InjectPry walks the scope and replaces pry.Pry with pry.Apply(pry.Scope{...}). 60 | func (g *Generator) InjectPry(filePath string) (string, error) { 61 | g.Debug("Prying into %s\n", filePath) 62 | filePath, err := filepath.Abs(filePath) 63 | if err != nil { 64 | return "", nil 65 | } 66 | 67 | g.contexts = make([]pryContext, 0) 68 | 69 | fset := token.NewFileSet() // positions are relative to fset 70 | 71 | // Parse the file containing this very example 72 | // but stop after processing the imports. 73 | f, err := parser.ParseFile(fset, filePath, nil, 0) 74 | if err != nil { 75 | return "", err 76 | } 77 | 78 | g.Config.Dir = filepath.Dir(filePath) 79 | 80 | packagePairs := []string{} 81 | for _, imp := range f.Imports { 82 | importStr := imp.Path.Value[1 : len(imp.Path.Value)-1] 83 | if importStr != "../pry" { 84 | pkgs, err := packages.Load(&g.Config, importStr) 85 | if err != nil { 86 | return "", err 87 | } 88 | pkg := pkgs[0] 89 | importName := pkg.Name 90 | if imp.Name != nil { 91 | importName = imp.Name.Name 92 | } 93 | pair := "\"" + importName + "\": pry.Package{Name: \"" + pkg.Name + "\", Functions: map[string]interface{}{" 94 | added := make(map[string]bool) 95 | exports, err := g.GetExports(importName, pkg.Syntax, added) 96 | if err != nil { 97 | return "", err 98 | } 99 | pair += exports 100 | pair += "}}, " 101 | packagePairs = append(packagePairs, pair) 102 | } 103 | } 104 | 105 | var funcs []*ast.FuncDecl 106 | var vars []string 107 | 108 | // Print the imports from the file's AST. 
109 | for k, v := range f.Scope.Objects { 110 | switch decl := v.Decl.(type) { 111 | case *ast.FuncDecl: 112 | funcs = append(funcs, decl) 113 | case *ast.ValueSpec: 114 | vars = append(vars, k) 115 | } 116 | } 117 | 118 | for _, f := range funcs { 119 | vars = append(vars, f.Name.Name) 120 | if f.Recv != nil { 121 | vars = g.extractFields(vars, f.Recv.List) 122 | } 123 | if f.Type != nil { 124 | if f.Type.Params != nil { 125 | vars = g.extractFields(vars, f.Type.Params.List) 126 | } 127 | if f.Type.Results != nil { 128 | vars = g.extractFields(vars, f.Type.Results.List) 129 | } 130 | } 131 | vars = g.extractVariables(vars, f.Body.List) 132 | } 133 | 134 | fileTextBytes, err := ioutil.ReadFile(filePath) 135 | if err != nil { 136 | return "", nil 137 | } 138 | 139 | fileText := (string)(fileTextBytes) 140 | 141 | offset := 0 142 | 143 | if len(g.contexts) == 0 { 144 | return "", nil 145 | } 146 | 147 | g.Debug(" :: Found %d pry statements.\n", len(g.contexts)) 148 | 149 | for _, context := range g.contexts { 150 | filteredVars := filterVars(context.Vars) 151 | obj := "&pry.Scope{Vals:map[string]interface{}{ " 152 | for _, v := range filteredVars { 153 | obj += "\"" + v + "\": " + v + ", " 154 | } 155 | obj += strings.Join(packagePairs, "") 156 | obj += "}}" 157 | text := "pry.Apply(" + obj + ")" 158 | fileText = fileText[0:context.Start+offset] + text + fileText[context.End+offset:] 159 | offset = len(text) - (context.End - context.Start) 160 | } 161 | 162 | newPath := filepath.Dir(filePath) + "/." + filepath.Base(filePath) + "pry" 163 | 164 | err = os.Rename(filePath, newPath) 165 | if err != nil { 166 | return "", err 167 | } 168 | ioutil.WriteFile(filePath, ([]byte)(fileText), 0644) 169 | return filePath, nil 170 | } 171 | 172 | // GetExports returns a string of gocode that represents the exports (constants/functions) of an ast.Package. 
173 | func (g *Generator) GetExports(importName string, files []*ast.File, added map[string]bool) (string, error) { 174 | vars := "" 175 | for _, file := range files { 176 | // Print the imports from the file's AST. 177 | scope := pry.NewScope() 178 | for k, obj := range file.Scope.Objects { 179 | if added[k] { 180 | continue 181 | } 182 | added[k] = true 183 | firstLetter := k[0:1] 184 | if firstLetter == strings.ToUpper(firstLetter) && firstLetter != "_" { 185 | 186 | isType := false 187 | 188 | switch stmt := obj.Decl.(type) { 189 | /* 190 | case *ast.ValueSpec: 191 | if len(stmt.Values) > 0 { 192 | out, err := pry.InterpretExpr(scope, stmt.Values[0]) 193 | if err != nil { 194 | fmt.Println("ERR", err) 195 | //continue 196 | } else { 197 | scope[obj.Name] = out 198 | } 199 | } 200 | */ 201 | case *ast.TypeSpec: 202 | switch typ := stmt.Type.(type) { 203 | case *ast.StructType: 204 | isType = true 205 | scope.Set(obj.Name, typ) 206 | 207 | default: 208 | out, err := scope.Interpret(stmt.Type) 209 | if err != nil { 210 | g.Debug("TypeSpec ERR %s\n", err.Error()) 211 | //continue 212 | } else { 213 | scope.Set(obj.Name, out) 214 | isType = true 215 | } 216 | } 217 | } 218 | 219 | if obj.Kind != ast.Typ || isType { 220 | path := importName + "." 
+ k 221 | vars += "\"" + k + "\": " 222 | if isType { 223 | out, _ := scope.Get(obj.Name) 224 | switch v := out.(type) { 225 | case reflect.Type: 226 | zero := reflect.Zero(v).Interface() 227 | val := fmt.Sprintf("%#v", zero) 228 | if zero == nil { 229 | val = "nil" 230 | } 231 | vars += fmt.Sprintf("pry.Type(%s(%s))", path, val) 232 | case *ast.StructType: 233 | vars += fmt.Sprintf("pry.Type(%s{})", path) 234 | default: 235 | log.Fatalf("got unknown type: %T %+v", out, out) 236 | } 237 | 238 | // TODO Fix hack for very large constants 239 | } else if path == "math.MaxUint64" || path == "math.MaxUint" || path == "crc64.ISO" || path == "crc64.ECMA" { 240 | vars += fmt.Sprintf("uint64(%s)", path) 241 | } else { 242 | vars += path 243 | } 244 | vars += "," 245 | if g.debug { 246 | vars += "\n" 247 | } 248 | } 249 | } 250 | } 251 | } 252 | return vars, nil 253 | } 254 | 255 | // GenerateFile generates a injected file. 256 | func (g *Generator) GenerateFile(imports []string, extraStatements, path string) error { 257 | file := "package main\nimport (\n\t\"github.com/d4l3k/go-pry/pry\"\n\n" 258 | for _, imp := range imports { 259 | if len(imp) == 0 { 260 | continue 261 | } 262 | file += fmt.Sprintf("\t%#v\n", imp) 263 | } 264 | file += ")\nfunc main() {\n\t" + extraStatements + "\n\tpry.Pry()\n}\n" 265 | 266 | if err := ioutil.WriteFile(path, []byte(file), 0644); err != nil { 267 | return err 268 | } 269 | 270 | _, err := g.InjectPry(path) 271 | return err 272 | } 273 | 274 | // GenerateAndExecuteFile generates and executes a temp file with the given imports 275 | func (g *Generator) GenerateAndExecuteFile(ctx context.Context, imports []string, extraStatements string) error { 276 | dir, err := ioutil.TempDir("", "pry") 277 | if err != nil { 278 | return err 279 | } 280 | defer func() { 281 | if err := os.RemoveAll(dir); err != nil { 282 | log.Fatal(err) 283 | } 284 | }() 285 | newPath := dir + "/main.go" 286 | 287 | if err := g.GenerateFile(imports, extraStatements, 
newPath); err != nil { 288 | return err 289 | } 290 | 291 | if err := g.ExecuteGoCmd(ctx, []string{"run", newPath}, nil); err != nil { 292 | return err 293 | } 294 | return nil 295 | } 296 | 297 | // RevertPry reverts the changes made by InjectPry. 298 | func (g *Generator) RevertPry(modifiedFiles []string) error { 299 | fmt.Println("Reverting files") 300 | for _, file := range modifiedFiles { 301 | newPath := filepath.Dir(file) + "/." + filepath.Base(file) + "pry" 302 | if _, err := os.Stat(newPath); os.IsNotExist(err) { 303 | return errors.Errorf("no such file or directory: %s", newPath) 304 | } 305 | 306 | err := os.Remove(file) 307 | if err != nil { 308 | return err 309 | } 310 | err = os.Rename(newPath, file) 311 | if err != nil { 312 | return err 313 | } 314 | } 315 | return nil 316 | } 317 | 318 | func filterVars(vars []string) (fVars []string) { 319 | for _, v := range vars { 320 | if v != "_" { 321 | fVars = append(fVars, v) 322 | } 323 | } 324 | return 325 | } 326 | 327 | func (g *Generator) extractVariables(vars []string, l []ast.Stmt) []string { 328 | for _, s := range l { 329 | vars = g.handleStatement(vars, s) 330 | } 331 | return vars 332 | } 333 | 334 | func (g *Generator) extractFields(vars []string, l []*ast.Field) []string { 335 | for _, s := range l { 336 | vars = g.handleIdents(vars, s.Names) 337 | } 338 | return vars 339 | } 340 | 341 | func (g *Generator) handleStatement(vars []string, s ast.Stmt) []string { 342 | switch stmt := s.(type) { 343 | case *ast.ExprStmt: 344 | vars = g.handleExpr(vars, stmt.X) 345 | case *ast.AssignStmt: 346 | lhsStatements := (*stmt).Lhs 347 | for _, v := range lhsStatements { 348 | vars = g.handleExpr(vars, v) 349 | } 350 | case *ast.GoStmt: 351 | g.handleExpr(vars, stmt.Call) 352 | case *ast.IfStmt: 353 | g.handleIfStmt(vars, stmt) 354 | case *ast.DeclStmt: 355 | decl := stmt.Decl.(*ast.GenDecl) 356 | if decl.Tok == token.VAR { 357 | for _, spec := range decl.Specs { 358 | valSpec := spec.(*ast.ValueSpec) 359 | 
vars = g.handleIdents(vars, valSpec.Names) 360 | } 361 | } 362 | case *ast.BlockStmt: 363 | vars = g.handleBlockStmt(vars, stmt) 364 | case *ast.RangeStmt: 365 | g.handleRangeStmt(vars, stmt) 366 | case *ast.ForStmt: 367 | vars = g.handleForStmt(vars, stmt) 368 | default: 369 | g.Debug("Unknown %T\n", stmt) 370 | } 371 | return vars 372 | } 373 | 374 | func (g *Generator) handleIfStmt(vars []string, stmt *ast.IfStmt) []string { 375 | vars = g.handleStatement(vars, stmt.Init) 376 | vars = g.handleStatement(vars, stmt.Body) 377 | return vars 378 | } 379 | 380 | func (g *Generator) handleRangeStmt(vars []string, stmt *ast.RangeStmt) []string { 381 | vars = g.handleExpr(vars, stmt.Key) 382 | vars = g.handleExpr(vars, stmt.Value) 383 | vars = g.handleStatement(vars, stmt.Body) 384 | return vars 385 | } 386 | 387 | func (g *Generator) handleForStmt(vars []string, stmt *ast.ForStmt) []string { 388 | vars = g.handleStatement(vars, stmt.Init) 389 | vars = g.handleStatement(vars, stmt.Body) 390 | return vars 391 | } 392 | 393 | func (g *Generator) handleBlockStmt(vars []string, stmt *ast.BlockStmt) []string { 394 | vars = g.extractVariables(vars, stmt.List) 395 | return vars 396 | } 397 | 398 | func (g *Generator) handleIdents(vars []string, idents []*ast.Ident) []string { 399 | for _, i := range idents { 400 | vars = append(vars, i.Name) 401 | } 402 | return vars 403 | } 404 | 405 | func (g *Generator) handleExpr(vars []string, v ast.Expr) []string { 406 | switch expr := v.(type) { 407 | case *ast.Ident: 408 | varMap := make(map[string]bool) 409 | for _, v := range vars { 410 | varMap[v] = true 411 | } 412 | if !varMap[expr.Name] { 413 | vars = append(vars, expr.Name) 414 | } 415 | case *ast.CallExpr: 416 | switch fun := expr.Fun.(type) { 417 | case *ast.SelectorExpr: 418 | funcName := fun.Sel.Name 419 | if funcName == "Pry" || funcName == "Apply" { 420 | g.contexts = append(g.contexts, pryContext{(int)(expr.Pos() - 1), (int)(expr.End() - 1), vars}) 421 | } 422 | 
//handleExpr(vars, fun.X) 423 | case *ast.FuncLit: 424 | g.handleExpr(vars, fun) 425 | default: 426 | g.Debug("Unknown function type %T\n", fun) 427 | } 428 | for _, arg := range expr.Args { 429 | g.handleExpr(vars, arg) 430 | } 431 | case *ast.FuncLit: 432 | if expr.Type.Params != nil { 433 | for _, param := range expr.Type.Params.List { 434 | for _, name := range param.Names { 435 | vars = g.handleExpr(vars, name) 436 | } 437 | } 438 | } 439 | if expr.Type.Results != nil { 440 | for _, param := range expr.Type.Results.List { 441 | for _, name := range param.Names { 442 | vars = g.handleExpr(vars, name) 443 | } 444 | } 445 | } 446 | g.handleStatement(vars, expr.Body) 447 | default: 448 | g.Debug("Unknown %T\n", expr) 449 | } 450 | return vars 451 | } 452 | 453 | type pryContext struct { 454 | Start, End int 455 | Vars []string 456 | } 457 | -------------------------------------------------------------------------------- /generate/generate_test.go: -------------------------------------------------------------------------------- 1 | package generate 2 | 3 | import ( 4 | "os" 5 | "path/filepath" 6 | "testing" 7 | ) 8 | 9 | func TestImportPry(t *testing.T) { 10 | g := NewGenerator(false) 11 | file := "../example/file/file.go" 12 | res, err := g.InjectPry(file) 13 | 14 | if err != nil { 15 | t.Errorf("Failed to inject pry %v", err) 16 | } 17 | 18 | if !fileExists(res) { 19 | t.Error("Source file not found") 20 | } 21 | 22 | pryFile := filepath.Join(filepath.Dir(res), ".file.gopry") 23 | if !fileExists(pryFile) { 24 | t.Error("Pry file not found") 25 | } 26 | 27 | // clean up 28 | g.RevertPry([]string{res}) 29 | 30 | if !fileExists(file) { 31 | t.Error("Source file not found") 32 | } 33 | 34 | res, err = g.InjectPry("nonexisting.go") 35 | if res != "" { 36 | t.Error("Non empty result received") 37 | } 38 | 39 | if fileExists(".nonexisting.gopry") { 40 | t.Error("Pry file should not exists") 41 | } 42 | } 43 | 44 | func fileExists(filePath string) bool { 45 | _, err := 
os.Stat(filePath) 46 | return !os.IsNotExist(err) 47 | } 48 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/d4l3k/go-pry 2 | 3 | go 1.16 4 | 5 | require ( 6 | github.com/cenkalti/backoff v2.2.1+incompatible 7 | github.com/davecgh/go-spew v1.1.1 8 | github.com/gorilla/handlers v1.5.1 9 | github.com/gorilla/mux v1.8.0 10 | github.com/mattn/go-colorable v0.1.8 11 | github.com/mattn/go-tty v0.0.3 12 | github.com/mgutz/ansi v0.0.0-20200706080929-d51e80ef957d 13 | github.com/mitchellh/go-homedir v1.1.0 14 | github.com/pkg/errors v0.9.1 15 | golang.org/x/tools v0.1.5 16 | ) 17 | -------------------------------------------------------------------------------- /go.sum: -------------------------------------------------------------------------------- 1 | github.com/cenkalti/backoff v2.2.1+incompatible h1:tNowT99t7UNflLxfYYSlKYsBpXdEet03Pg2g16Swow4= 2 | github.com/cenkalti/backoff v2.2.1+incompatible/go.mod h1:90ReRw6GdpyfrHakVjL/QHaoyV4aDUVVkXQJJJ3NXXM= 3 | github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= 4 | github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 5 | github.com/felixge/httpsnoop v1.0.1 h1:lvB5Jl89CsZtGIWuTcDM1E/vkVs49/Ml7JJe07l8SPQ= 6 | github.com/felixge/httpsnoop v1.0.1/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U= 7 | github.com/gorilla/handlers v1.5.1 h1:9lRY6j8DEeeBT10CvO9hGW0gmky0BprnvDI5vfhUHH4= 8 | github.com/gorilla/handlers v1.5.1/go.mod h1:t8XrUpc4KVXb7HGyJ4/cEnwQiaxrX/hz1Zv/4g96P1Q= 9 | github.com/gorilla/mux v1.8.0 h1:i40aqfkR1h2SlN9hojwV5ZA91wcXFOvkdNIeFDP5koI= 10 | github.com/gorilla/mux v1.8.0/go.mod h1:DVbg23sWSpFRCP0SfiEN6jmj59UnW/n46BH5rLB71So= 11 | github.com/mattn/go-colorable v0.1.4/go.mod h1:U0ppj6V5qS13XJ6of8GYAs25YV2eR4EVcfRqFIhoBtE= 12 | github.com/mattn/go-colorable v0.1.8 
h1:c1ghPdyEDarC70ftn0y+A/Ee++9zz8ljHG1b13eJ0s8= 13 | github.com/mattn/go-colorable v0.1.8/go.mod h1:u6P/XSegPjTcexA+o6vUJrdnUu04hMope9wVRipJSqc= 14 | github.com/mattn/go-isatty v0.0.8/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s= 15 | github.com/mattn/go-isatty v0.0.10/go.mod h1:qgIWMr58cqv1PHHyhnkY9lrL7etaEgOFcMEpPG5Rm84= 16 | github.com/mattn/go-isatty v0.0.12 h1:wuysRhFDzyxgEmMf5xjvJ2M9dZoWAXNNr5LSBS7uHXY= 17 | github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Kysco4FUpU= 18 | github.com/mattn/go-runewidth v0.0.6/go.mod h1:H031xJmbD/WCDINGzjvQ9THkh0rPKHF+m2gUSrubnMI= 19 | github.com/mattn/go-tty v0.0.3 h1:5OfyWorkyO7xP52Mq7tB36ajHDG5OHrmBGIS/DtakQI= 20 | github.com/mattn/go-tty v0.0.3/go.mod h1:ihxohKRERHTVzN+aSVRwACLCeqIoZAWpoICkkvrWyR0= 21 | github.com/mgutz/ansi v0.0.0-20200706080929-d51e80ef957d h1:5PJl274Y63IEHC+7izoQE9x6ikvDFZS2mDVS3drnohI= 22 | github.com/mgutz/ansi v0.0.0-20200706080929-d51e80ef957d/go.mod h1:01TrycV0kFyexm33Z7vhZRXopbI8J3TDReVlkTgMUxE= 23 | github.com/mitchellh/go-homedir v1.1.0 h1:lukF9ziXFxDFPkA1vsr5zpc1XuPDn/wFntq5mG+4E0Y= 24 | github.com/mitchellh/go-homedir v1.1.0/go.mod h1:SfyaCUpYCn1Vlf4IUYiD9fPX4A5wJrkLzIz1N1q0pr0= 25 | github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= 26 | github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= 27 | github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= 28 | golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= 29 | golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= 30 | golang.org/x/mod v0.4.2 h1:Gz96sIWK3OalVv/I/qNygP42zyoKp3xptRVCWRFEBvo= 31 | golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= 32 | golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= 33 | golang.org/x/net 
v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= 34 | golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM= 35 | golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= 36 | golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= 37 | golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= 38 | golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= 39 | golang.org/x/sys v0.0.0-20190222072716-a9d3bda3a223/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= 40 | golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= 41 | golang.org/x/sys v0.0.0-20191008105621-543471e840be/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= 42 | golang.org/x/sys v0.0.0-20191120155948-bd437916bb0e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= 43 | golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= 44 | golang.org/x/sys v0.0.0-20200223170610-d5e6a3e2c0ae h1:/WDfKMnPU+m5M4xB+6x4kaepxRw6jWvR5iDRdvjHgy8= 45 | golang.org/x/sys v0.0.0-20200223170610-d5e6a3e2c0ae/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= 46 | golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= 47 | golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= 48 | golang.org/x/sys v0.0.0-20210510120138-977fb7262007 h1:gG67DSER+11cZvqIMb8S8bt0vZtiN6xWYARwirrOSfE= 49 | golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= 50 | golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= 51 | 
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= 52 | golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= 53 | golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= 54 | golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= 55 | golang.org/x/tools v0.1.5 h1:ouewzE6p+/VEB31YYnTbEJdi8pFqKp4P4n85vwo3DHA= 56 | golang.org/x/tools v0.1.5/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk= 57 | golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= 58 | golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= 59 | golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 h1:go1bK/D/BFZV2I8cIQd1NKEZ+0owSTG1fDTci4IqFcE= 60 | golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= 61 | -------------------------------------------------------------------------------- /main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "context" 5 | "flag" 6 | "fmt" 7 | "log" 8 | "os" 9 | "os/signal" 10 | "path/filepath" 11 | "strings" 12 | 13 | "github.com/d4l3k/go-pry/generate" 14 | "github.com/pkg/errors" 15 | ) 16 | 17 | func main() { 18 | log.SetFlags(log.Flags() | log.Lshortfile) 19 | 20 | // Catch Ctrl-C 21 | c := make(chan os.Signal, 1) 22 | signal.Notify(c, os.Interrupt) 23 | go func() { 24 | for _ = range c { 25 | } 26 | }() 27 | 28 | if err := run(); err != nil { 29 | log.Fatalf("%+v", err) 30 | } 31 | } 32 | 33 | func run() error { 34 | ctx := context.Background() 35 | 36 | // FLAGS 37 | imports := flag.String("i", "fmt,math", "packages to import, comma seperated") 38 | revert := flag.Bool("r", true, "whether to revert changes on exit") 39 | execute := flag.String("e", "", 
"statements to execute") 40 | generatePath := flag.String("generate", "", "the path to generate a go-pry injected file - EXPERIMENTAL") 41 | debug := flag.Bool("d", false, "display debug statements") 42 | 43 | flag.CommandLine.Usage = func() { 44 | if err := generate.NewGenerator(*debug).ExecuteGoCmd(ctx, []string{}, nil); err != nil { 45 | log.Fatal(err) 46 | } 47 | fmt.Println("----") 48 | fmt.Println("go-pry is an interactive REPL and wrapper around the go command.") 49 | fmt.Println("You can execute go commands as normal and go-pry will take care of generating the pry code.") 50 | fmt.Println("Running go-pry with no arguments will drop you into an interactive REPL.") 51 | flag.PrintDefaults() 52 | fmt.Println(" revert: cleans up go-pry generated files if not automatically done") 53 | } 54 | flag.Parse() 55 | 56 | g := generate.NewGenerator(*debug) 57 | 58 | cmdArgs := flag.Args() 59 | if len(cmdArgs) == 0 { 60 | imports := strings.Split(*imports, ",") 61 | if len(*generatePath) > 0 { 62 | return g.GenerateFile(imports, *execute, *generatePath) 63 | } 64 | return g.GenerateAndExecuteFile(ctx, imports, *execute) 65 | } 66 | 67 | goDirs := []string{} 68 | for _, arg := range cmdArgs { 69 | if strings.HasSuffix(arg, ".go") { 70 | goDirs = append(goDirs, filepath.Dir(arg)) 71 | } 72 | } 73 | if len(goDirs) == 0 { 74 | dir, err := os.Getwd() 75 | if err != nil { 76 | panic(err) 77 | } 78 | goDirs = []string{dir} 79 | } 80 | 81 | processedFiles := []string{} 82 | modifiedFiles := []string{} 83 | 84 | if cmdArgs[0] == "revert" { 85 | fmt.Println("REVERTING PRY") 86 | for _, dir := range goDirs { 87 | filepath.Walk(dir, func(path string, info os.FileInfo, err error) error { 88 | if strings.HasSuffix(path, ".gopry") { 89 | processed := false 90 | for _, file := range processedFiles { 91 | if file == path { 92 | processed = true 93 | } 94 | } 95 | if !processed { 96 | base := filepath.Base(path) 97 | newPath := filepath.Dir(path) + "/" + base[1:len(base)-3] 98 | 
modifiedFiles = append(modifiedFiles, newPath) 99 | } 100 | } 101 | return nil 102 | }) 103 | } 104 | return g.RevertPry(modifiedFiles) 105 | } 106 | 107 | testsRequired := cmdArgs[0] == "test" 108 | for _, dir := range goDirs { 109 | if err := filepath.Walk(dir, func(path string, info os.FileInfo, err error) error { 110 | if !testsRequired && strings.HasSuffix(path, "_test.go") || !strings.HasSuffix(path, ".go") || strings.Contains(path, "vendor/") { 111 | return nil 112 | } 113 | for _, file := range processedFiles { 114 | if file == path { 115 | return nil 116 | } 117 | } 118 | file, err := g.InjectPry(path) 119 | if err != nil { 120 | return errors.Wrap(err, "inject") 121 | } 122 | if file != "" { 123 | modifiedFiles = append(modifiedFiles, path) 124 | } 125 | return nil 126 | }); err != nil { 127 | return err 128 | } 129 | } 130 | 131 | if cmdArgs[0] == "apply" { 132 | return nil 133 | } 134 | 135 | if err := g.ExecuteGoCmd(ctx, cmdArgs, nil); err != nil { 136 | return err 137 | } 138 | 139 | if *revert { 140 | if err := g.RevertPry(modifiedFiles); err != nil { 141 | return err 142 | } 143 | } 144 | return nil 145 | } 146 | -------------------------------------------------------------------------------- /main_test.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "testing" 5 | ) 6 | 7 | func TestBackupRestore(t *testing.T) { 8 | } 9 | -------------------------------------------------------------------------------- /playground/.gitignore: -------------------------------------------------------------------------------- 1 | wasm_exec.js 2 | node_modules/ 3 | *.wasm 4 | *.gopry 5 | bundles/ 6 | playground 7 | -------------------------------------------------------------------------------- /playground/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM golang:alpine as base 2 | 3 | RUN apk add binutils git subversion mercurial 4 | RUN strip 
/usr/local/go/bin/go 5 | RUN rm /usr/local/go/bin/gofmt 6 | RUN rm -r /usr/local/go/src/cmd 7 | #RUN rm -r /usr/local/go/src/vendor 8 | RUN rm -r /usr/local/go/api 9 | RUN rm -r /usr/local/go/lib 10 | RUN rm -r /usr/local/go/misc 11 | RUN rm -r /usr/local/go/test 12 | RUN rm -r /usr/share 13 | RUN find /usr/local/go/src -name "testdata" -exec rm -r {} + 14 | RUN find /usr/lib/python2.7 -name "*.pyo" -exec rm -r {} + 15 | RUN find /usr/lib/python2.7 -name "*.pyc" -exec rm -r {} + 16 | RUN rm -r /usr/local/go/pkg/linux_amd64 17 | RUN strip /usr/local/go/pkg/tool/linux_amd64/* 18 | 19 | WORKDIR /go/src/app 20 | COPY . . 21 | 22 | RUN CGO_ENABLED=0 go get ./server 23 | RUN CGO_ENABLED=0 go build -ldflags "-s -w" -o playground ./server 24 | RUN rm -r /go/src/github.com /go/src/golang.org /go/bin 25 | RUN rm -r /root 26 | 27 | EXPOSE 8080 28 | CMD ["/go/src/app/playground"] 29 | 30 | #WORKDIR /go/src/app 31 | #COPY . . 32 | # 33 | #FROM scratch 34 | #ENV LD_LIBRARY_PATH /lib/:/usr/lib/ 35 | # 36 | #COPY --from=base /go/src/app/ / 37 | # 38 | #COPY --from=base /usr/local/go/bin/go /bin/ 39 | #COPY --from=base /usr/local/go/src /usr/local/go/ 40 | # 41 | #COPY --from=base /lib/*.so* /lib/ 42 | #COPY --from=base /usr/lib/*.so* /usr/lib/ 43 | # 44 | #COPY --from=base /usr/bin/git /bin/ 45 | #COPY --from=base /usr/bin/svn /bin/ 46 | #COPY --from=base /usr/bin/hg /bin/ 47 | # 48 | #COPY --from=base /tmp/ /tmp/ 49 | 50 | #EXPOSE 8080 51 | #CMD ["/playground"] 52 | -------------------------------------------------------------------------------- /playground/Makefile: -------------------------------------------------------------------------------- 1 | .PHONY: build 2 | build: wasm_exec.js 3 | yarn install 4 | go install -v .. 
5 | @rm -r bundles 6 | mkdir -p bundles 7 | go-pry -generate="bundles/main.go" -i="fmt,log,math" -e='log.Println("Hello world!")' 8 | GOOS=js GOARCH=wasm go build -v -ldflags "-s -w" -o bundles/main.wasm bundles/main.go 9 | go-pry -generate="bundles/stdlib.go" -i="$(shell tr '\n' ',' < packages.txt)" -e='log.Println("Hello world!")' 10 | GOOS=js GOARCH=wasm go build -v -ldflags "-s -w" -o bundles/stdlib.wasm bundles/stdlib.go 11 | 12 | .PHONY: run 13 | run: build 14 | CGO_ENABLED=0 go build -ldflags "-s -w" -o playground ./server 15 | ./playground 16 | 17 | wasm_exec.js: $(shell go env GOROOT)/misc/wasm/wasm_exec.js 18 | cp $< . 19 | 20 | .PHONY: deploy 21 | deploy: build 22 | now 23 | now alias 24 | -------------------------------------------------------------------------------- /playground/genmeta/genmeta.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "bytes" 5 | "encoding/json" 6 | "go/parser" 7 | "go/token" 8 | "io/ioutil" 9 | "log" 10 | "os" 11 | 12 | "github.com/d4l3k/go-pry/pry" 13 | "github.com/davecgh/go-spew/spew" 14 | ) 15 | 16 | func main() { 17 | if err := run(); err != nil { 18 | log.Fatal(err) 19 | } 20 | } 21 | 22 | func run() error { 23 | wd, err := os.Getwd() 24 | if err != nil { 25 | return err 26 | } 27 | dir, err := parser.ParseDir(token.NewFileSet(), wd, nil, 0) 28 | if err != nil { 29 | return err 30 | } 31 | imp := pry.JSImporter{ 32 | Dir: dir, 33 | } 34 | spew.Dump(imp) 35 | var buf bytes.Buffer 36 | if err := json.NewEncoder(&buf).Encode(imp); err != nil { 37 | return err 38 | } 39 | if err := ioutil.WriteFile("meta.go", []byte( 40 | `package main 41 | import "github.com/d4l3k/go-pry/pry" 42 | func init(){ 43 | pry.InternalSetImports(`+"`"+buf.String()+"`"+`) 44 | }`, 45 | ), 0644); err != nil { 46 | return err 47 | } 48 | return nil 49 | } 50 | -------------------------------------------------------------------------------- /playground/index.html: 
-------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 39 | go-pry: web 40 | 41 | 42 |

go-pry: web

43 |
44 | 45 | 46 | 47 |
48 |

hello! 👋

49 |

This is a version of go-pry 50 | that runs in your browser using WebAssembly! It's missing a few features 51 | but works pretty well for the most part.

52 | 53 |

This sometimes takes quite a while to load in Chrome, but it's almost 54 | instant in Firefox.

55 | 56 |

I'm pretty excited to try using this to demo popular client 57 | libraries.

58 | 59 |

go-pry import bundles

60 | 67 |
68 | 69 |

Created by Tristan Rice.

70 | 71 | 72 | 73 | 74 | 75 | 76 | 106 | 107 | 108 | 109 | 116 | 117 | 118 | -------------------------------------------------------------------------------- /playground/now.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "go-pry", 3 | "type": "docker", 4 | "alias": "gopry.rice.sh", 5 | "files": [ 6 | "index.html", 7 | "node_modules", 8 | "wasm_exec.js", 9 | "bundles", 10 | "Dockerfile", 11 | "server" 12 | ], 13 | "public": true 14 | } 15 | 16 | -------------------------------------------------------------------------------- /playground/package.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "playground", 3 | "version": "1.0.0", 4 | "main": "index.js", 5 | "license": "MIT", 6 | "dependencies": { 7 | "nprogress": "^0.2.0", 8 | "xterm": "^3.8.1" 9 | } 10 | } 11 | -------------------------------------------------------------------------------- /playground/packages.txt: -------------------------------------------------------------------------------- 1 | sort 2 | mime 3 | mime/quotedprintable 4 | mime/multipart 5 | context 6 | html 7 | html/template 8 | io 9 | io/ioutil 10 | fmt 11 | encoding/gob 12 | encoding/binary 13 | encoding/pem 14 | encoding/base32 15 | encoding/ascii85 16 | encoding/base64 17 | encoding/hex 18 | encoding/json 19 | encoding/xml 20 | encoding/csv 21 | encoding/asn1 22 | crypto 23 | crypto/md5 24 | crypto/dsa 25 | crypto/x509 26 | crypto/sha512 27 | crypto/des 28 | crypto/aes 29 | crypto/rc4 30 | crypto/elliptic 31 | crypto/subtle 32 | crypto/tls 33 | crypto/hmac 34 | crypto/sha256 35 | crypto/sha1 36 | crypto/rsa 37 | crypto/ecdsa 38 | crypto/cipher 39 | image 40 | image/gif 41 | image/color 42 | image/color/palette 43 | image/draw 44 | image/jpeg 45 | image/png 46 | sync 47 | strings 48 | bytes 49 | strconv 50 | math 51 | math/cmplx 52 | math/big 53 | math/bits 54 | math/rand 55 | text/template/parse 56 | text/tabwriter 57 | 
text/scanner 58 | hash/crc64 59 | hash/crc32 60 | hash/adler32 61 | hash/fnv 62 | path 63 | path/filepath 64 | reflect 65 | archive/zip 66 | archive/tar 67 | log 68 | log/syslog 69 | compress/bzip2 70 | compress/zlib 71 | compress/gzip 72 | compress/flate 73 | compress/lzw 74 | plugin 75 | errors 76 | container/heap 77 | container/list 78 | container/ring 79 | time 80 | expvar 81 | unicode 82 | unicode/utf8 83 | unicode/utf16 84 | bufio 85 | regexp 86 | regexp/syntax 87 | net 88 | net/rpc 89 | net/rpc/jsonrpc 90 | net/mail 91 | net/url 92 | net/textproto 93 | net/smtp 94 | net/http 95 | -------------------------------------------------------------------------------- /playground/server/server.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "crypto/sha256" 5 | "encoding/hex" 6 | "flag" 7 | "fmt" 8 | "log" 9 | "net/http" 10 | "os" 11 | "path/filepath" 12 | "sort" 13 | "strings" 14 | 15 | "github.com/d4l3k/go-pry/generate" 16 | "github.com/gorilla/handlers" 17 | "github.com/gorilla/mux" 18 | "github.com/pkg/errors" 19 | ) 20 | 21 | const bundlesDir = "bundles" 22 | 23 | var bind = flag.String("bind", ":8080", "address to bind to") 24 | 25 | func main() { 26 | if err := run(); err != nil { 27 | log.Fatal(err) 28 | } 29 | } 30 | 31 | func pkgHash(pkgs []string) string { 32 | hash := sha256.Sum256([]byte(strings.Join(pkgs, ","))) 33 | return hex.EncodeToString(hash[:]) 34 | } 35 | 36 | func normalizePackages(packages string) []string { 37 | var pkgs []string 38 | for _, pkg := range strings.Split(strings.ToLower(packages), ",") { 39 | pkg = strings.TrimSpace(pkg) 40 | if len(pkg) == 0 { 41 | continue 42 | } 43 | pkgs = append(pkgs, pkg) 44 | } 45 | sort.Strings(pkgs) 46 | return pkgs 47 | } 48 | 49 | func generateBundle(w http.ResponseWriter, r *http.Request, packages string) (retErr error) { 50 | pkgs := normalizePackages(packages) 51 | hash := pkgHash(pkgs) 52 | path := filepath.Join(bundlesDir, 
hash+".wasm") 53 | goPath := filepath.Join(bundlesDir, hash+".go") 54 | _, err := os.Stat(path) 55 | if err == nil { 56 | http.ServeFile(w, r, path) 57 | return nil 58 | } else if !os.IsNotExist(err) { 59 | return err 60 | } 61 | 62 | g := generate.NewGenerator(false) 63 | g.Config.Env = append(os.Environ(), "CGO_ENABLED=0") 64 | 65 | env := os.Environ() 66 | 67 | for _, pkg := range append([]string{"github.com/d4l3k/go-pry/pry"}, pkgs...) { 68 | if err := g.ExecuteGoCmd(r.Context(), []string{ 69 | "get", 70 | pkg, 71 | }, env); err != nil { 72 | return errors.Wrapf(err, "error go get %q", pkg) 73 | } 74 | } 75 | 76 | if err := g.GenerateFile(pkgs, "", goPath); err != nil { 77 | return errors.Wrap(err, "GenerateFile") 78 | } 79 | 80 | if err := g.ExecuteGoCmd(r.Context(), []string{ 81 | "build", 82 | "-ldflags", 83 | "-s -w", 84 | "-o", 85 | path, 86 | goPath, 87 | }, append([]string{ 88 | "GOOS=js", 89 | "GOARCH=wasm", 90 | }, env...)); err != nil { 91 | return errors.Wrapf(err, "go build") 92 | } 93 | 94 | http.ServeFile(w, r, path) 95 | return nil 96 | } 97 | 98 | func run() error { 99 | log.SetFlags(log.Flags() | log.Lshortfile) 100 | 101 | router := mux.NewRouter() 102 | router.PathPrefix("/wasm/").HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 103 | pkgs := strings.Join(strings.Split(r.URL.Path, "/")[2:], "/") 104 | if err := generateBundle(w, r, pkgs); err != nil { 105 | http.Error(w, fmt.Sprintf("%+v", err), http.StatusInternalServerError) 106 | } 107 | }) 108 | router.NotFoundHandler = http.FileServer(http.Dir(".")) 109 | 110 | log.Printf("Listening %s...", *bind) 111 | r := handlers.CombinedLoggingHandler(os.Stderr, router) 112 | return http.ListenAndServe(*bind, r) 113 | } 114 | -------------------------------------------------------------------------------- /playground/server/server11_test.go: -------------------------------------------------------------------------------- 1 | // +build go1.11 2 | 3 | package main 4 | 5 | import ( 6 | 
"net/http" 7 | "net/http/httptest" 8 | "os" 9 | "testing" 10 | ) 11 | 12 | func TestGenerateBundle(t *testing.T) { 13 | t.Parallel() 14 | 15 | remove := func() { 16 | if err := os.RemoveAll(bundlesDir); err != nil { 17 | t.Fatal(err) 18 | } 19 | } 20 | 21 | remove() 22 | 23 | if err := os.MkdirAll(bundlesDir, 0755); err != nil { 24 | t.Fatal(err) 25 | } 26 | 27 | r, err := http.NewRequest(http.MethodGet, "/wasm/math,fmt", nil) 28 | if err != nil { 29 | t.Fatal(err) 30 | } 31 | resp := httptest.NewRecorder() 32 | if err := generateBundle(resp, r, "math,fmt"); err != nil { 33 | t.Fatal(err) 34 | } 35 | if resp.Code != http.StatusOK { 36 | t.Fatalf("expected StatusOK got %+v", resp.Code) 37 | } 38 | 39 | remove() 40 | } 41 | -------------------------------------------------------------------------------- /playground/server/server_test.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "reflect" 5 | "testing" 6 | ) 7 | 8 | func TestNormalizePackages(t *testing.T) { 9 | t.Parallel() 10 | 11 | cases := []struct { 12 | in string 13 | want []string 14 | }{ 15 | { 16 | "", 17 | nil, 18 | }, 19 | { 20 | "foo, bar,github.com/d4l3k/go-pry ", 21 | []string{ 22 | "bar", 23 | "foo", 24 | "github.com/d4l3k/go-pry", 25 | }, 26 | }, 27 | } 28 | 29 | for i, c := range cases { 30 | out := normalizePackages(c.in) 31 | if !reflect.DeepEqual(out, c.want) { 32 | t.Errorf("%d. normalizePackages(%q) = %+v; not %+v", i, c.in, out, c.want) 33 | } 34 | } 35 | } 36 | -------------------------------------------------------------------------------- /playground/yarn.lock: -------------------------------------------------------------------------------- 1 | # THIS IS AN AUTOGENERATED FILE. DO NOT EDIT THIS FILE DIRECTLY. 
2 | # yarn lockfile v1 3 | 4 | 5 | nprogress@^0.2.0: 6 | version "0.2.0" 7 | resolved "https://registry.yarnpkg.com/nprogress/-/nprogress-0.2.0.tgz#cb8f34c53213d895723fcbab907e9422adbcafb1" 8 | 9 | xterm@^3.8.1: 10 | version "3.8.1" 11 | resolved "https://registry.yarnpkg.com/xterm/-/xterm-3.8.1.tgz#0beabaccdc23bd3ab2397c5129ed9b06b0abb167" 12 | -------------------------------------------------------------------------------- /pry-build-corpus/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "bytes" 5 | "crypto/sha1" 6 | "encoding/hex" 7 | "io/ioutil" 8 | "log" 9 | "os" 10 | "path/filepath" 11 | "regexp" 12 | ) 13 | 14 | const out = "fuzz/corpus/" 15 | 16 | var ( 17 | exampleRegexpQuotes = regexp.MustCompile("(?s)InterpretString\\(`(.*?)`\\)") 18 | ) 19 | 20 | func main() { 21 | if err := run(); err != nil { 22 | log.Fatalf("%+v", err) 23 | } 24 | } 25 | 26 | func run() error { 27 | files, err := filepath.Glob("**/*.go") 28 | if err != nil { 29 | return err 30 | } 31 | if err := os.MkdirAll(out, 0755); err != nil { 32 | return err 33 | } 34 | for _, fpath := range files { 35 | body, err := ioutil.ReadFile(fpath) 36 | if err != nil { 37 | return err 38 | } 39 | 40 | for { 41 | match := exampleRegexpQuotes.FindSubmatchIndex(body) 42 | if match == nil { 43 | break 44 | } 45 | 46 | expr := bytes.TrimSpace(body[match[2]:match[3]]) 47 | hash := sha1.Sum(expr) 48 | file := hex.EncodeToString(hash[:]) 49 | if err := ioutil.WriteFile(filepath.Join(out, file), expr, 0644); err != nil { 50 | return err 51 | } 52 | 53 | body = body[match[1]:] 54 | } 55 | } 56 | return nil 57 | } 58 | -------------------------------------------------------------------------------- /pry/autocomplete.go: -------------------------------------------------------------------------------- 1 | package pry 2 | 3 | import ( 4 | "bufio" 5 | "go/ast" 6 | "io" 7 | "io/ioutil" 8 | "os" 9 | "os/exec" 10 | "path/filepath" 11 | 
"strconv" 12 | "strings" 13 | ) 14 | 15 | const placeholder = "pryPlaceholderAutoComplete" 16 | 17 | // SuggestionsGoCode is a suggestion engine that uses gocode for autocomplete. 18 | func (scope *Scope) SuggestionsGoCode(line string, index int) ([]string, error) { 19 | var suggestions []string 20 | var code string 21 | for name, file := range scope.Files { 22 | moddedName := filepath.Dir(name) + "/." + filepath.Base(name) + "pry" 23 | if scope.path == moddedName { 24 | name = moddedName 25 | } 26 | if name == scope.path { 27 | ast.Walk(walker(func(n ast.Node) bool { 28 | switch s := n.(type) { 29 | case *ast.BlockStmt: 30 | for i, stmt := range s.List { 31 | pos := scope.fset.Position(stmt.Pos()) 32 | if pos.Line == scope.line { 33 | r := scope.Render(stmt) 34 | if strings.HasPrefix(r, "pry.Apply") { 35 | var iStmt []ast.Stmt 36 | iStmt = append(iStmt, ast.Stmt(&ast.ExprStmt{X: ast.NewIdent(placeholder)})) 37 | oldList := make([]ast.Stmt, len(s.List)) 38 | copy(oldList, s.List) 39 | 40 | s.List = append(s.List, make([]ast.Stmt, len(iStmt))...) 
41 | 42 | copy(s.List[i+len(iStmt):], s.List[i:]) 43 | copy(s.List[i:], iStmt) 44 | 45 | code = scope.Render(file) 46 | s.List = oldList 47 | return false 48 | } 49 | } 50 | } 51 | } 52 | return true 53 | }), file) 54 | 55 | i := strings.Index(code, placeholder) + index 56 | code = strings.Replace(code, placeholder, line, 1) 57 | 58 | subProcess := exec.Command("gocode", "autocomplete", filepath.Dir(name), strconv.Itoa(i)) 59 | 60 | stdin, err := subProcess.StdinPipe() 61 | if err != nil { 62 | return nil, err 63 | } 64 | 65 | stdout, err := subProcess.StdoutPipe() 66 | if err != nil { 67 | return nil, err 68 | } 69 | defer stdout.Close() 70 | 71 | subProcess.Stderr = os.Stderr 72 | 73 | if err = subProcess.Start(); err != nil { 74 | return nil, err 75 | } 76 | 77 | io.WriteString(stdin, code) 78 | stdin.Close() 79 | 80 | output, err := ioutil.ReadAll(bufio.NewReader(stdout)) 81 | if err != nil { 82 | return nil, err 83 | } 84 | rawSuggestions := strings.Split(string(output), "\n")[1:] 85 | for _, suggestion := range rawSuggestions { 86 | trimmed := strings.TrimSpace(suggestion) 87 | if len(trimmed) > 0 { 88 | suggestions = append(suggestions, trimmed) 89 | } 90 | } 91 | subProcess.Wait() 92 | 93 | break 94 | } 95 | } 96 | return suggestions, nil 97 | } 98 | -------------------------------------------------------------------------------- /pry/fuzz.go: -------------------------------------------------------------------------------- 1 | package pry 2 | 3 | import "fmt" 4 | 5 | // Fuzz is used for go-fuzz testing. 
6 | func Fuzz(data []byte) int { 7 | s := NewScope() 8 | val, err := s.InterpretString(string(data)) 9 | if err != nil { 10 | if val != nil { 11 | panic(fmt.Sprintf("%#v != nil on error: %+v", val, err)) 12 | } 13 | return 0 14 | } 15 | return 1 16 | } 17 | -------------------------------------------------------------------------------- /pry/helpers.go: -------------------------------------------------------------------------------- 1 | package pry 2 | 3 | import ( 4 | "fmt" 5 | "reflect" 6 | 7 | "github.com/pkg/errors" 8 | ) 9 | 10 | // InterpretError is an error returned by the interpreter and shouldn't be 11 | // passed to the user or running code. 12 | type InterpretError struct { 13 | err error 14 | } 15 | 16 | func (a *InterpretError) Error() error { 17 | if a == nil { 18 | return nil 19 | } 20 | return a.err 21 | } 22 | 23 | // Append is a runtime replacement for the append function 24 | func Append(arr interface{}, elems ...interface{}) (interface{}, *InterpretError) { 25 | arrVal := reflect.ValueOf(arr) 26 | valArr := make([]reflect.Value, len(elems)) 27 | for i, elem := range elems { 28 | if reflect.TypeOf(arr) != reflect.SliceOf(reflect.TypeOf(elem)) { 29 | return nil, &InterpretError{fmt.Errorf("%T cannot append to %T", elem, arr)} 30 | } 31 | valArr[i] = reflect.ValueOf(elem) 32 | } 33 | return reflect.Append(arrVal, valArr...).Interface(), nil 34 | } 35 | 36 | // Make is a runtime replacement for the make function 37 | func Make(t interface{}, args ...interface{}) (interface{}, *InterpretError) { 38 | typ, isType := t.(reflect.Type) 39 | if !isType { 40 | return nil, &InterpretError{fmt.Errorf("invalid type %#v", t)} 41 | } 42 | switch typ.Kind() { 43 | case reflect.Slice: 44 | if len(args) < 1 || len(args) > 2 { 45 | return nil, &InterpretError{errors.New("invalid number of arguments. 
Missing len or extra?")} 46 | } 47 | length, isInt := args[0].(int) 48 | if !isInt { 49 | return nil, &InterpretError{errors.New("len is not int")} 50 | } 51 | capacity := length 52 | if len(args) == 2 { 53 | capacity, isInt = args[0].(int) 54 | if !isInt { 55 | return nil, &InterpretError{errors.New("len is not int")} 56 | } 57 | } 58 | if length < 0 || capacity < 0 { 59 | return nil, &InterpretError{errors.Errorf("negative length or capacity")} 60 | } 61 | slice := reflect.MakeSlice(typ, length, capacity) 62 | return slice.Interface(), nil 63 | 64 | case reflect.Chan: 65 | if len(args) > 1 { 66 | fmt.Printf("CHAN ARGS %#v", args) 67 | return nil, &InterpretError{errors.New("too many arguments")} 68 | } 69 | size := 0 70 | if len(args) == 1 { 71 | var isInt bool 72 | size, isInt = args[0].(int) 73 | if !isInt { 74 | return nil, &InterpretError{errors.New("size is not int")} 75 | } 76 | } 77 | if size < 0 { 78 | return nil, &InterpretError{errors.Errorf("negative buffer size")} 79 | } 80 | buffer := reflect.MakeChan(typ, size) 81 | return buffer.Interface(), nil 82 | 83 | default: 84 | return nil, &InterpretError{fmt.Errorf("unknown kind type %T", t)} 85 | } 86 | } 87 | 88 | // Close is a runtime replacement for the "close" function. 
89 | func Close(t interface{}) (interface{}, *InterpretError) { 90 | reflect.ValueOf(t).Close() 91 | return nil, nil 92 | } 93 | 94 | // Len is a runtime replacement for the len function 95 | func Len(t interface{}) (interface{}, *InterpretError) { 96 | return reflect.ValueOf(t).Len(), nil 97 | } 98 | -------------------------------------------------------------------------------- /pry/highlighter.go: -------------------------------------------------------------------------------- 1 | package pry 2 | 3 | import ( 4 | "github.com/mgutz/ansi" 5 | 6 | "regexp" 7 | "strings" 8 | ) 9 | 10 | const highlightColor1 = "white+b" 11 | const highlightColor2 = "green+b" 12 | const highlightColor3 = "red" 13 | const highlightColor4 = "blue+b" 14 | const highlightColor5 = "red+b" 15 | 16 | // Highlight highlights a string of go code for outputting to bash. 17 | func Highlight(s string) string { 18 | highlightSymbols := []string{"==", "!=", ":=", "="} 19 | highlightKeywords := []string{ 20 | "for", "defer", "func", "struct", "switch", "case", 21 | "interface", "if", "range", "bool", "type", "package", "import", 22 | "make", "append", 23 | } 24 | highlightKeywordsSpaced := []string{"go"} 25 | highlightTypes := []string{ 26 | "byte", 27 | "complex128", 28 | "complex64", 29 | "error", 30 | "float", 31 | "float32", 32 | "float64", 33 | "int", 34 | "int16", 35 | "int32", 36 | "int64", 37 | "int8", 38 | "rune", 39 | "string", 40 | "uint", 41 | "uint16", 42 | "uint32", 43 | "uint64", 44 | "uint8", 45 | "uintptr", 46 | } 47 | s = highlightWords(s, []string{"-?(0[xX])?\\d+((\\.|e-?)\\d+)*", "nil", "true", "false"}, highlightColor4, "\\W") 48 | s = highlightWords(s, highlightKeywords, highlightColor1, "\\W") 49 | s = highlightWords(s, highlightKeywordsSpaced, highlightColor1, "\\s") 50 | s = highlightWords(s, highlightTypes, highlightColor2, "\\W") 51 | s = highlightWords(s, highlightSymbols, highlightColor1, "") 52 | s = highlightWords(s, []string{".+"}, highlightColor3, "\"") 53 | s = 
highlightWords(s, []string{"\""}, highlightColor5, "") 54 | s = highlightWords(s, []string{"//.+"}, highlightColor4, "") 55 | return s 56 | } 57 | 58 | func highlightWords(s string, words []string, color, edges string) string { 59 | lE := len(edges) - strings.Count(edges, "\\") 60 | s = " " + s + " " 61 | for _, word := range words { 62 | r, _ := regexp.Compile(edges + word + edges) 63 | s = (string)(r.ReplaceAllFunc(([]byte)(s), func(b []byte) []byte { 64 | bStr := string(b) 65 | return []byte(bStr[0:lE] + ansi.Color(bStr[lE:len(bStr)-lE], color) + bStr[len(bStr)-lE:]) 66 | })) 67 | } 68 | if s[0] == ' ' { 69 | s = s[1:] 70 | } 71 | if s[len(s)-1] == ' ' { 72 | s = s[:len(s)-1] 73 | } 74 | return s 75 | } 76 | -------------------------------------------------------------------------------- /pry/highlighter_test.go: -------------------------------------------------------------------------------- 1 | package pry 2 | 3 | import ( 4 | "io/ioutil" 5 | "regexp" 6 | "testing" 7 | ) 8 | 9 | // Make sure the highlighter doesn't change the code. 10 | func TestHighlightSafe(t *testing.T) { 11 | t.Parallel() 12 | 13 | fileBytes, err := ioutil.ReadFile("../example/file/file.go") 14 | if err != nil { 15 | t.Error(err) 16 | } 17 | fileStr := (string)(fileBytes) 18 | highlight := Highlight(fileStr) 19 | 20 | r, err := regexp.Compile("\\x1b\\[(.*?)m") 21 | if err != nil { 22 | t.Error(err) 23 | } 24 | 25 | // Strip Bash control sequences 26 | s := r.ReplaceAllLiteralString(highlight, "") 27 | 28 | if s != fileStr { 29 | t.Error("Highlighting has changed the code!") 30 | } 31 | } 32 | -------------------------------------------------------------------------------- /pry/importer.go: -------------------------------------------------------------------------------- 1 | package pry 2 | 3 | import ( 4 | "go/ast" 5 | "go/types" 6 | 7 | "github.com/pkg/errors" 8 | ) 9 | 10 | // JSImporter contains all the information needed to implement a types.Importer 11 | // in a javascript environment. 
12 | type JSImporter struct { 13 | packages map[string]*types.Package 14 | Dir map[string]*ast.Package 15 | } 16 | 17 | func (i *JSImporter) Import(path string) (*types.Package, error) { 18 | p, ok := i.packages[path] 19 | if !ok { 20 | return nil, errors.Errorf("package %q not found", path) 21 | } 22 | return p, nil 23 | } 24 | -------------------------------------------------------------------------------- /pry/importer_default.go: -------------------------------------------------------------------------------- 1 | // +build !js 2 | 3 | package pry 4 | 5 | import ( 6 | "go/ast" 7 | "go/types" 8 | "path/filepath" 9 | 10 | "github.com/pkg/errors" 11 | "golang.org/x/tools/go/packages" 12 | ) 13 | 14 | type packagesImporter struct { 15 | } 16 | 17 | func (i packagesImporter) Import(path string) (*types.Package, error) { 18 | return i.ImportFrom(path, "", 0) 19 | } 20 | func (packagesImporter) ImportFrom(path, dir string, mode types.ImportMode) (*types.Package, error) { 21 | conf := packages.Config{ 22 | Mode: packages.NeedImports | packages.NeedTypes, 23 | Dir: dir, 24 | } 25 | pkgs, err := packages.Load(&conf, path) 26 | if err != nil { 27 | return nil, errors.Wrapf(err, "importing %q", path) 28 | } 29 | pkg := pkgs[0] 30 | return pkg.Types, nil 31 | } 32 | 33 | func getImporter() types.ImporterFrom { 34 | return packagesImporter{} 35 | } 36 | 37 | func (s *Scope) parseDir() (map[string]*ast.File, error) { 38 | conf := packages.Config{ 39 | Fset: s.fset, 40 | Mode: packages.NeedCompiledGoFiles | packages.NeedSyntax, 41 | Dir: filepath.Dir(s.path), 42 | } 43 | pkgs, err := packages.Load(&conf, ".") 44 | if err != nil { 45 | return nil, errors.Wrapf(err, "parsing dir") 46 | } 47 | pkg := pkgs[0] 48 | files := map[string]*ast.File{} 49 | for i, name := range pkg.CompiledGoFiles { 50 | files[name] = pkg.Syntax[i] 51 | } 52 | return files, nil 53 | } 54 | -------------------------------------------------------------------------------- /pry/importer_js.go: 
-------------------------------------------------------------------------------- 1 | // +build js 2 | 3 | package pry 4 | 5 | import ( 6 | "encoding/json" 7 | "go/ast" 8 | "go/types" 9 | ) 10 | 11 | func (s *Scope) parseDir() (map[string]*ast.File, error) { 12 | files := map[string]*ast.File{} 13 | for _, p := range defaultImporter.Dir { 14 | for name, file := range p.Files { 15 | files[name] = file 16 | } 17 | } 18 | return files, nil 19 | } 20 | 21 | func getImporter() types.Importer { 22 | return defaultImporter 23 | } 24 | 25 | var defaultImporter = &JSImporter{ 26 | packages: map[string]*types.Package{}, 27 | Dir: map[string]*ast.Package{}, 28 | } 29 | 30 | func InternalSetImports(raw string) { 31 | if err := json.Unmarshal([]byte(raw), defaultImporter); err != nil { 32 | panic(err) 33 | } 34 | } 35 | -------------------------------------------------------------------------------- /pry/interpreter.go: -------------------------------------------------------------------------------- 1 | package pry 2 | 3 | import ( 4 | "bytes" 5 | "fmt" 6 | "go/ast" 7 | "go/parser" 8 | "go/printer" 9 | "go/token" 10 | "path/filepath" 11 | "reflect" 12 | "strconv" 13 | "strings" 14 | "sync" 15 | "time" 16 | 17 | "github.com/pkg/errors" 18 | 19 | "go/types" 20 | // Used by types for import determination 21 | ) 22 | 23 | var ( 24 | // ErrChanSendFailed occurs when a channel is full or there are no receivers 25 | // available. 26 | ErrChanSendFailed = errors.New("failed to send, channel full or no receivers") 27 | 28 | // ErrBranchBreak is an internal error thrown when a for loop breaks. 29 | ErrBranchBreak = errors.New("branch break") 30 | // ErrBranchContinue is an internal error thrown when a for loop continues. 31 | ErrBranchContinue = errors.New("branch continue") 32 | ) 33 | 34 | // Scope is a string-interface key-value pair that represents variables/functions in scope. 
35 | type Scope struct { 36 | Vals map[string]interface{} 37 | Parent *Scope 38 | Files map[string]*ast.File 39 | config *types.Config 40 | path string 41 | line int 42 | fset *token.FileSet 43 | 44 | isSelect bool 45 | typeAssert reflect.Type 46 | isFunction bool 47 | defers []*Defer 48 | 49 | sync.Mutex 50 | } 51 | 52 | type Defer struct { 53 | fun ast.Expr 54 | scope *Scope 55 | arguments []interface{} 56 | } 57 | 58 | func (scope *Scope) Defer(d *Defer) error { 59 | for ; scope != nil; scope = scope.Parent { 60 | if scope.isFunction { 61 | scope.defers = append(scope.defers, d) 62 | return nil 63 | } 64 | } 65 | return errors.New("defer: can't find function scope") 66 | } 67 | 68 | // NewScope creates a new initialized scope 69 | func NewScope() *Scope { 70 | s := &Scope{ 71 | Vals: map[string]interface{}{}, 72 | Files: map[string]*ast.File{}, 73 | } 74 | s.Set("_pryScope", s) 75 | return s 76 | } 77 | 78 | // GetPointer walks the scope and finds the pointer to the value of interest 79 | func (scope *Scope) GetPointer(name string) (val interface{}, exists bool) { 80 | currentScope := scope 81 | for !exists && currentScope != nil { 82 | currentScope.Lock() 83 | val, exists = currentScope.Vals[name] 84 | currentScope.Unlock() 85 | currentScope = currentScope.Parent 86 | } 87 | return 88 | } 89 | 90 | // Get walks the scope and finds the value of interest 91 | func (scope *Scope) Get(name string) (interface{}, bool) { 92 | val, exists := scope.GetPointer(name) 93 | if !exists || val == nil { 94 | return val, exists 95 | } 96 | v := reflect.ValueOf(val) 97 | if v.Kind() == reflect.Ptr { 98 | return v.Elem().Interface(), exists 99 | } 100 | return v.Interface(), exists 101 | } 102 | 103 | // Set walks the scope and sets a value in a parent scope if it exists, else current. 
func (scope *Scope) Set(name string, val interface{}) {
	// Values are stored boxed in a fresh pointer so GetPointer can hand out a
	// writable location (used for &x and for in-place assignment). A value
	// from reflect.ValueOf is never addressable, so it is always copied.
	if val != nil {
		v := reflect.ValueOf(val)
		boxed := reflect.New(v.Type())
		boxed.Elem().Set(v)
		val = boxed.Interface()
	}

	// Overwrite the entry in the nearest scope that already declares name.
	for s := scope; s != nil; s = s.Parent {
		s.Lock()
		_, declared := s.Vals[name]
		if declared {
			s.Vals[name] = val
		}
		s.Unlock()
		if declared {
			return
		}
	}

	// Not declared anywhere: declare it in the receiver scope.
	scope.Lock()
	scope.Vals[name] = val
	scope.Unlock()
}

// Keys returns the names declared in scope and in every ancestor. A name
// declared in several scopes of the chain appears once per declaring scope,
// and the order is unspecified.
func (scope *Scope) Keys() (keys []string) {
	// Advance with s.Parent; the previous code re-read scope.Parent on every
	// iteration and therefore never terminated once a parent existed.
	for s := scope; s != nil; s = s.Parent {
		s.Lock()
		for k := range s.Vals {
			keys = append(keys, k)
		}
		s.Unlock()
	}
	return keys
}

// NewChild returns a new scope whose Parent is scope. Lookups from the child
// fall back to scope; names declared only in the child are not visible to scope.
func (scope *Scope) NewChild() *Scope {
	child := NewScope()
	child.Parent = scope
	return child
}

// Func represents an interpreted function definition.
type Func struct {
	// Def is the function literal the value was created from.
	Def *ast.FuncLit
}

// ParseString parses Go source text into an AST node: first as a statement
// list, falling back to a single expression. It also returns the offset of
// the text inside the source that was actually parsed.
161 | func (scope *Scope) ParseString(exprStr string) (ast.Node, int, error) { 162 | exprStr = strings.Trim(exprStr, " \n\t") 163 | wrappedExpr := "func(){" + exprStr + "}()" 164 | shifted := 7 165 | expr, err := parser.ParseExpr(wrappedExpr) 166 | if err != nil && strings.HasPrefix(err.Error(), "1:8: expected statement, found '") { 167 | expr, err = parser.ParseExpr(exprStr) 168 | shifted = 0 169 | if err != nil { 170 | return expr, shifted, err 171 | } 172 | node, ok := expr.(ast.Node) 173 | if !ok { 174 | return nil, 0, errors.Errorf("expected ast.Node; got %#v", expr) 175 | } 176 | return node, shifted, nil 177 | } else if err != nil { 178 | return expr, shifted, err 179 | } 180 | if expr == nil { 181 | return nil, 0, errors.Errorf("expression is empty") 182 | } 183 | callExpr, ok := expr.(*ast.CallExpr) 184 | if !ok { 185 | return nil, 0, errors.Errorf("expected CallExpr; got %#v", callExpr) 186 | } 187 | return callExpr.Fun.(*ast.FuncLit).Body, shifted, nil 188 | } 189 | 190 | // InterpretString interprets a string of go code and returns the result. 191 | func (scope *Scope) InterpretString(exprStr string) (v interface{}, err error) { 192 | defer func() { 193 | if r := recover(); r != nil { 194 | err = errors.Errorf("interpreting %q: %s", exprStr, fmt.Sprint(r)) 195 | } 196 | }() 197 | 198 | node, _, err := scope.ParseString(exprStr) 199 | if err != nil { 200 | return node, err 201 | } 202 | errs := scope.CheckStatement(node) 203 | if len(errs) > 0 { 204 | return node, errs[0] 205 | } 206 | return scope.Interpret(node) 207 | } 208 | 209 | // Interpret interprets an ast.Node and returns the value. 
func (scope *Scope) Interpret(expr ast.Node) (interface{}, error) {
	// Optional child nodes (IfStmt.Else, SliceExpr.Low/High, ValueSpec.Type,
	// ...) are passed to Interpret unchecked. An absent node evaluates to nil
	// instead of falling through to the "unknown node" error at the bottom.
	if expr == nil {
		return nil, nil
	}

	switch e := expr.(type) {
	case *ast.Ident:
		// Predeclared type names win, then variables in scope, then builtins.
		if typ, err := StringToType(e.Name); err == nil {
			return typ, nil
		}
		if obj, exists := scope.Get(e.Name); exists {
			return obj, nil
		}
		// Only needed when the name is not in scope, so it is built here.
		// TODO make builtinScope root of other scopes
		builtinScope := map[string]interface{}{
			"nil":    nil,
			"true":   true,
			"false":  false,
			"append": Append,
			"make":   Make,
			"len":    Len,
			"close":  Close,
		}
		if obj, exists := builtinScope[e.Name]; exists {
			return obj, nil
		}
		return nil, fmt.Errorf("can't find EXPR %s", e.Name)

	case *ast.SelectorExpr:
		X, err := scope.Interpret(e.X)
		if err != nil {
			return nil, err
		}
		name := e.Sel.Name

		if pkg, isPackage := X.(Package); isPackage {
			if obj, isPresent := pkg.Functions[name]; isPresent {
				return obj, nil
			}
			return nil, fmt.Errorf("unknown field %#v", name)
		}

		rVal := reflect.ValueOf(X)
		if !rVal.IsValid() {
			return nil, fmt.Errorf("%#v is not a struct and thus has no field %#v", X, name)
		}
		// Methods are looked up before any dereference, so pointer-receiver
		// methods are found and methods of non-struct named types work too.
		if method := rVal.MethodByName(name); method.IsValid() {
			return method.Interface(), nil
		}
		if rVal.Kind() == reflect.Ptr {
			rVal = rVal.Elem()
		}
		if rVal.Kind() != reflect.Struct {
			return nil, fmt.Errorf("%#v is not a struct and thus has no field %#v", X, name)
		}
		if field := rVal.FieldByName(name); field.IsValid() {
			return field.Interface(), nil
		}
		return nil, fmt.Errorf("unknown field %#v", name)

	case *ast.CallExpr:
		args := make([]interface{}, len(e.Args))
		for i, arg := range e.Args {
			interpretedArg, err := scope.Interpret(arg)
			if err != nil {
				return nil, err
			}
			args[i] = interpretedArg
		}
		return scope.ExecuteFunc(e.Fun, args)

	case *ast.GoStmt:
		go func() {
			_, err := scope.NewChild().Interpret(e.Call)
			if err != nil {
				fmt.Printf("goroutine failed: %s\n", err)
			}
		}()
		return nil, nil

	case *ast.BasicLit:
		switch e.Kind {
		case token.INT:
			n, err := strconv.ParseInt(e.Value, 0, 64)
			if err != nil {
				return nil, err
			}
			return int(n), nil
		case token.FLOAT:
			v, err := strconv.ParseFloat(e.Value, 64)
			if err != nil {
				return nil, err
			}
			return v, nil
		case token.IMAG:
			// "2.5i" -> complex(0, 2.5); ParseFloat rejects the trailing i.
			v, err := strconv.ParseFloat(strings.TrimSuffix(e.Value, "i"), 64)
			if err != nil {
				return nil, err
			}
			return complex(0, v), nil
		case token.CHAR:
			// Decode escapes ('\n', '\x41') and multi-byte runes instead of
			// returning the literal's first byte.
			if len(e.Value) < 3 {
				return nil, fmt.Errorf("invalid character literal %s", e.Value)
			}
			r, _, tail, err := strconv.UnquoteChar(e.Value[1:len(e.Value)-1], '\'')
			if err != nil {
				return nil, err
			}
			if tail != "" {
				return nil, fmt.Errorf("invalid character literal %s", e.Value)
			}
			return r, nil
		case token.STRING:
			// Handles escapes in "..." strings as well as raw `...` strings.
			s, err := strconv.Unquote(e.Value)
			if err != nil {
				return nil, err
			}
			return s, nil
		default:
			return nil, fmt.Errorf("unknown basic literal %d", e.Kind)
		}

	case *ast.CompositeLit:
		typ, err := scope.Interpret(e.Type)
		if err != nil {
			return nil, err
		}
		rType, ok := typ.(reflect.Type)
		if !ok {
			return nil, fmt.Errorf("unknown composite literal %#v", e.Type)
		}
		// Dispatch on the resolved kind rather than on the syntax of e.Type,
		// so named slice, array, map and struct types are all handled.
		switch rType.Kind() {
		case reflect.Slice, reflect.Array:
			n := len(e.Elts)
			var list reflect.Value
			if rType.Kind() == reflect.Slice {
				list = reflect.MakeSlice(rType, n, n)
			} else {
				list = reflect.New(rType).Elem()
			}
			if n > list.Len() {
				return nil, errors.Errorf("array index %d out of bounds [0:%d]", list.Len(), list.Len())
			}
			for i, elem := range e.Elts {
				elemValue, err := scope.Interpret(elem)
				if err != nil {
					return nil, err
				}
				v, err := coerceForAssignment(elemValue, rType.Elem())
				if err != nil {
					return nil, err
				}
				list.Index(i).Set(v)
			}
			return list.Interface(), nil

		case reflect.Map:
			nMap := reflect.MakeMap(rType)
			for _, elem := range e.Elts {
				kv, isKV := elem.(*ast.KeyValueExpr)
				if !isKV {
					return nil, fmt.Errorf("invalid element type %#v to map. Expecting key value pair", elem)
				}
				key, err := scope.Interpret(kv.Key)
				if err != nil {
					return nil, err
				}
				val, err := scope.Interpret(kv.Value)
				if err != nil {
					return nil, err
				}
				keyV, err := coerceForAssignment(key, rType.Key())
				if err != nil {
					return nil, err
				}
				valV, err := coerceForAssignment(val, rType.Elem())
				if err != nil {
					return nil, err
				}
				nMap.SetMapIndex(keyV, valV)
			}
			return nMap.Interface(), nil

		case reflect.Struct:
			obj := reflect.New(rType).Elem()
			for i, elem := range e.Elts {
				var field reflect.Value
				var valueExpr ast.Expr
				if kv, isKV := elem.(*ast.KeyValueExpr); isKV {
					fieldName, isIdent := kv.Key.(*ast.Ident)
					if !isIdent {
						return nil, fmt.Errorf("invalid element type %T %#v to struct literal", kv.Key, kv.Key)
					}
					field = obj.FieldByName(fieldName.Name)
					if !field.IsValid() {
						return nil, fmt.Errorf("unknown field %#v in struct literal", fieldName.Name)
					}
					valueExpr = kv.Value
				} else {
					// Positional element: any expression, in field order.
					if i >= obj.NumField() {
						return nil, fmt.Errorf("too many values in struct literal")
					}
					field = obj.Field(i)
					valueExpr = elem
				}
				val, err := scope.Interpret(valueExpr)
				if err != nil {
					return nil, err
				}
				v, err := coerceForAssignment(val, field.Type())
				if err != nil {
					return nil, err
				}
				field.Set(v)
			}
			return obj.Interface(), nil

		default:
			return nil, fmt.Errorf("unknown composite literal %#v", e.Type)
		}

	case *ast.BinaryExpr:
		x, err := scope.Interpret(e.X)
		if err != nil {
			return nil, err
		}
		// && and || short-circuit as in Go: the right operand is evaluated
		// only when it can change the result.
		if b, isBool := x.(bool); isBool {
			if (e.Op == token.LAND && !b) || (e.Op == token.LOR && b) {
				return b, nil
			}
		}
		y, err := scope.Interpret(e.Y)
		if err != nil {
			return nil, err
		}
		return ComputeBinaryOp(x, y, e.Op)

	case *ast.UnaryExpr:
		// &x yields the variable's stored pointer, so writes through it are
		// visible in scope; &T{...} boxes a freshly built composite value.
		if e.Op == token.AND {
			switch x := e.X.(type) {
			case *ast.Ident:
				val, exists := scope.GetPointer(x.Name)
				if !exists {
					return nil, errors.Errorf("unknown identifier %#v", x)
				}
				return val, nil
			case *ast.CompositeLit:
				val, err := scope.Interpret(x)
				if err != nil {
					return nil, err
				}
				rv := reflect.ValueOf(val)
				ptr := reflect.New(rv.Type())
				ptr.Elem().Set(rv)
				return ptr.Interface(), nil
			default:
				return nil, errors.Errorf("expected identifier; got %#v", e.X)
			}
		}

		x, err := scope.Interpret(e.X)
		if err != nil {
			return nil, err
		}
		return scope.ComputeUnaryOp(x, e.Op)

	case *ast.ArrayType:
		typ, err := scope.Interpret(e.Elt)
		if err != nil {
			return nil, err
		}
		rType, ok := typ.(reflect.Type)
		if !ok {
			return nil, errors.Errorf("invalid type %#v", typ)
		}
		if e.Len == nil {
			return reflect.SliceOf(rType), nil
		}

		lenV, err := scope.Interpret(e.Len)
		if err != nil {
			return nil, err
		}
		n, ok := lenV.(int)
		if !ok {
			return nil, errors.Errorf("expected int; got %#v", lenV)
		}
		if n < 0 {
			return nil, errors.Errorf("negative array size")
		}
		return reflect.ArrayOf(n, rType), nil

	case *ast.MapType:
		keyType, err := scope.Interpret(e.Key)
		if err != nil {
			return nil, err
		}
		valType, err := scope.Interpret(e.Value)
		if err != nil {
			return nil, err
		}
		kt, okKey := keyType.(reflect.Type)
		vt, okVal := valType.(reflect.Type)
		if !okKey || !okVal {
			return nil, errors.Errorf("invalid map type map[%#v]%#v", keyType, valType)
		}
		return reflect.MapOf(kt, vt), nil

	case *ast.ChanType:
		typeI, err := scope.Interpret(e.Value)
		if err != nil {
			return nil, err
		}
		typ, isType := typeI.(reflect.Type)
		if !isType {
			return nil, fmt.Errorf("chan needs to be passed a type not %T", typeI)
		}
		// The declared direction is ignored: channels are always bidirectional.
		return reflect.ChanOf(reflect.BothDir, typ), nil

	case *ast.IndexExpr:
		X, err := scope.Interpret(e.X)
		if err != nil {
			return nil, err
		}
		i, err := scope.Interpret(e.Index)
		if err != nil {
			return nil, err
		}
		xVal := reflect.ValueOf(X)
		for xVal.Kind() == reflect.Ptr {
			xVal = xVal.Elem()
		}
		switch xVal.Kind() {
		case reflect.Map:
			key, err := coerceForAssignment(i, xVal.Type().Key())
			if err != nil {
				return nil, err
			}
			val := xVal.MapIndex(key)
			if !val.IsValid() {
				// If not valid key, return the "zero" type. Eg for int 0, string ""
				return reflect.Zero(xVal.Type().Elem()).Interface(), nil
			}
			return val.Interface(), nil

		case reflect.Slice, reflect.Array, reflect.String:
			iVal, isInt := i.(int)
			if !isInt {
				return nil, fmt.Errorf("index has to be an int not %T", i)
			}
			if iVal >= xVal.Len() || iVal < 0 {
				return nil, errors.New("slice index out of range")
			}
			return xVal.Index(iVal).Interface(), nil

		default:
			return nil, errors.Errorf("invalid X for IndexExpr: %#v", X)
		}

	case *ast.SliceExpr:
		X, err := scope.Interpret(e.X)
		if err != nil {
			return nil, err
		}
		low, err := scope.Interpret(e.Low)
		if err != nil {
			return nil, err
		}
		high, err := scope.Interpret(e.High)
		if err != nil {
			return nil, err
		}
		xVal := reflect.ValueOf(X)
		kind := xVal.Kind()
		if kind != reflect.Array && kind != reflect.Slice && kind != reflect.String {
			return nil, errors.Errorf("invalid X for SliceExpr: %#v", X)
		}
		if low == nil {
			low = 0
		}
		if high == nil {
			high = xVal.Len()
		}
		lowVal, isLowInt := low.(int)
		highVal, isHighInt := high.(int)
		if !isLowInt || !isHighInt {
			return nil, fmt.Errorf("slice: indexes have to be an ints not %T and %T", low, high)
		}
		// Per the Go spec the upper bound is cap(x) for slices and len(x)
		// for arrays and strings; high == len(x) is valid.
		bound := xVal.Len()
		if kind == reflect.Slice {
			bound = xVal.Cap()
		}
		if kind == reflect.Array {
			// reflect can only slice addressable arrays and interpreted values
			// never are, so slice a copy. NOTE: unlike Go, the result does not
			// alias the original array.
			arr := reflect.New(xVal.Type()).Elem()
			arr.Set(xVal)
			xVal = arr
		}
		if !e.Slice3 {
			if lowVal < 0 || highVal < lowVal || highVal > bound {
				return nil, errors.New("slice: index out of bounds")
			}
			return xVal.Slice(lowVal, highVal).Interface(), nil
		}
		if kind == reflect.String {
			return nil, errors.New("slice: 3-index slice of string")
		}
		maxI, err := scope.Interpret(e.Max)
		if err != nil {
			return nil, err
		}
		maxVal, isMaxInt := maxI.(int)
		if !isMaxInt {
			return nil, fmt.Errorf("slice: indexes have to be an ints not %T", maxI)
		}
		if lowVal < 0 || highVal < lowVal || maxVal < highVal || maxVal > bound {
			return nil, errors.New("slice: index out of bounds")
		}
		return xVal.Slice3(lowVal, highVal, maxVal).Interface(), nil

	case *ast.ParenExpr:
		return scope.Interpret(e.X)

	case *ast.FuncLit:
		return &Func{e}, nil

	case *ast.BlockStmt:
		var outFinal interface{}
		for _, stmts := range e.List {
			out, err := scope.Interpret(stmts)
			if err != nil {
				return out, err
			}
			outFinal = out
		}
		return outFinal, nil

	case *ast.ReturnStmt:
		results := make([]interface{}, len(e.Results))
		for i, result := range e.Results {
			out, err := scope.Interpret(result)
			if err != nil {
				return out, err
			}
			results[i] = out
		}

		if len(results) == 0 {
			return nil, nil
		} else if len(results) == 1 {
			return results[0], nil
		}
		return results, nil

	case *ast.AssignStmt:
		// TODO implement type checking
		rhs := make([]interface{}, len(e.Rhs))
		for i, expr := range e.Rhs {
			val, err := scope.Interpret(expr)
			if err != nil {
				return nil, err
			}
			rhs[i] = val
		}

		// A single multi-valued call on the right (its results arrive as a
		// slice) is spread across the left-hand side.
		if len(rhs) == 1 && len(e.Lhs) > 1 && rhs[0] != nil && reflect.TypeOf(rhs[0]).Kind() == reflect.Slice {
			rhsV := reflect.ValueOf(rhs[0])
			rhsLen := rhsV.Len()
			if rhsLen != len(e.Lhs) {
				return nil, fmt.Errorf("assignment count mismatch: %d = %d", len(e.Lhs), rhsLen)
			}

			rhs = rhs[:0]
			for i := 0; i < rhsLen; i++ {
				rhs = append(rhs, rhsV.Index(i).Interface())
			}
		}

		if len(rhs) != len(e.Lhs) {
			return nil, fmt.Errorf("assignment count mismatch: %d = %d (%+v)", len(e.Lhs), len(rhs), rhs)
		}

		for i, id := range e.Lhs {
			// getR computes the value to store, applying the operator of
			// compound assignments (+=, -=, ...) to the current value.
			getR := func(val interface{}) (interface{}, error) {
				r := rhs[i]
				isModAssign := e.Tok != token.ASSIGN && e.Tok != token.DEFINE
				if isModAssign {
					var err error
					r, err = ComputeBinaryOp(val, r, DeAssign(e.Tok))
					if err != nil {
						return nil, err
					}
				}
				return r, nil
			}

			if ident, ok := id.(*ast.Ident); ok {
				val, exists := scope.Get(ident.Name)
				if !exists && e.Tok != token.DEFINE {
					return nil, errors.Errorf("undefined %s", ident.Name)
				}
				r, err := getR(val)
				if err != nil {
					return nil, err
				}
				if e.Tok == token.DEFINE {
					// := declares in the current scope, shadowing any
					// same-named variable of an enclosing scope.
					scope.defineLocal(ident.Name, r)
				} else {
					scope.Set(ident.Name, r)
				}
				continue
			}

			if idx, ok := id.(*ast.IndexExpr); ok {
				left, err := scope.getValue(idx.X)
				if err != nil {
					return nil, err
				}
				if left.Kind() == reflect.Map {
					index, err := scope.Interpret(idx.Index)
					if err != nil {
						return nil, err
					}
					key, err := coerceForAssignment(index, left.Type().Key())
					if err != nil {
						return nil, err
					}
					var val interface{}
					if cur := left.MapIndex(key); cur.IsValid() {
						val = cur.Interface()
					} else {
						val = reflect.Zero(left.Type().Elem()).Interface()
					}
					r, err := getR(val)
					if err != nil {
						return nil, err
					}
					rv, err := coerceForAssignment(r, left.Type().Elem())
					if err != nil {
						return nil, err
					}
					left.SetMapIndex(key, rv)
					continue
				}
			}

			target, err := scope.getValue(id)
			if err != nil {
				return nil, err
			}
			if !target.IsValid() || !target.CanSet() {
				return nil, errors.Errorf("cannot assign to %#v", id)
			}
			r, err := getR(target.Interface())
			if err != nil {
				return nil, err
			}
			rv, err := coerceForAssignment(r, target.Type())
			if err != nil {
				return nil, err
			}
			target.Set(rv)
		}

		if len(rhs) > 1 {
			return rhs, nil
		}
		return rhs[0], nil

	case *ast.IncDecStmt:
		var dir string
		switch e.Tok {
		case token.INC:
			dir = "1"
		case token.DEC:
			dir = "-1"
		}
		ass := &ast.AssignStmt{
			Tok: token.ASSIGN,
			Lhs: []ast.Expr{e.X},
			Rhs: []ast.Expr{&ast.BinaryExpr{
				X:  e.X,
				Op: token.ADD,
				Y: &ast.BasicLit{
					Kind:  token.INT,
					Value: dir,
				},
			}},
		}
		return scope.Interpret(ass)

	case *ast.RangeStmt:
		s := scope.NewChild()
		ranger, err := s.Interpret(e.X)
		if err != nil {
			return nil, err
		}
		var key, value string
		if e.Key != nil {
			key = e.Key.(*ast.Ident).Name
		}
		if e.Value != nil {
			value = e.Value.(*ast.Ident).Name
		}
		// bind stores an iteration variable: := declares it in the loop
		// scope, = assigns to an existing variable.
		bind := func(name string, v interface{}) {
			if name == "" {
				return
			}
			if e.Tok == token.DEFINE {
				s.defineLocal(name, v)
			} else {
				s.Set(name, v)
			}
		}
		// iterate runs the body once and reports whether the loop must stop,
		// honouring break/continue and propagating real errors.
		iterate := func() (bool, error) {
			_, err := s.Interpret(e.Body)
			switch err {
			case nil, ErrBranchContinue:
				return false, nil
			case ErrBranchBreak:
				return true, nil
			default:
				return true, err
			}
		}

		rv := reflect.ValueOf(ranger)
		switch rv.Kind() {
		case reflect.Array, reflect.Slice:
			for i := 0; i < rv.Len(); i++ {
				bind(key, i)
				if value != "" {
					bind(value, rv.Index(i).Interface())
				}
				if stop, err := iterate(); stop {
					return nil, err
				}
			}
		case reflect.Map:
			for _, keyV := range rv.MapKeys() {
				bind(key, keyV.Interface())
				if value != "" {
					bind(value, rv.MapIndex(keyV).Interface())
				}
				if stop, err := iterate(); stop {
					return nil, err
				}
			}
		case reflect.String:
			for i, r := range rv.String() {
				bind(key, i)
				if value != "" {
					bind(value, r)
				}
				if stop, err := iterate(); stop {
					return nil, err
				}
			}
		case reflect.Chan:
			// Receives until the channel is closed; the element goes to the
			// first iteration variable.
			for {
				elem, ok := rv.Recv()
				if !ok {
					break
				}
				bind(key, elem.Interface())
				if stop, err := iterate(); stop {
					return nil, err
				}
			}
		default:
			return nil, fmt.Errorf("ranging on %s is unsupported", rv.Kind().String())
		}
		return nil, nil

	case *ast.ExprStmt:
		return scope.Interpret(e.X)

	case *ast.DeclStmt:
		return scope.Interpret(e.Decl)

	case *ast.GenDecl:
		for _, spec := range e.Specs {
			if _, err := scope.Interpret(spec); err != nil {
				return nil, err
			}
		}
		return nil, nil

	case *ast.ValueSpec:
		var declared reflect.Type
		if e.Type != nil {
			typ, err := scope.Interpret(e.Type)
			if err != nil {
				return nil, err
			}
			rType, ok := typ.(reflect.Type)
			if !ok {
				return nil, errors.Errorf("invalid type %#v", typ)
			}
			declared = rType
		}
		values := make([]interface{}, 0, len(e.Values))
		for _, valueExpr := range e.Values {
			v, err := scope.Interpret(valueExpr)
			if err != nil {
				return nil, err
			}
			values = append(values, v)
		}
		// var a, b = f(): the results of one multi-valued call arrive as a
		// single []interface{} and are spread across the names.
		if len(values) == 1 && len(e.Names) > 1 {
			if multi, ok := values[0].([]interface{}); ok && len(multi) == len(e.Names) {
				values = multi
			}
		}
		for i, name := range e.Names {
			var v interface{}
			switch {
			case i < len(values) && declared != nil:
				rv, err := coerceForAssignment(values[i], declared)
				if err != nil {
					return nil, err
				}
				v = rv.Interface()
			case i < len(values):
				v = values[i]
			case declared != nil:
				v = reflect.Zero(declared).Interface()
			}
			scope.defineLocal(name.Name, v)
		}
		return nil, nil

	case *ast.ForStmt:
		s := scope.NewChild()
		if e.Init != nil {
			if _, err := s.Interpret(e.Init); err != nil {
				return nil, err
			}
		}
		var err error
		var last interface{}
		for {
			if e.Cond != nil {
				cond, err := s.Interpret(e.Cond)
				if err != nil {
					return nil, err
				}
				if cont, ok := cond.(bool); !ok {
					return nil, fmt.Errorf("for loop requires a boolean condition not %#v", cond)
				} else if !cont {
					return last, nil
				}
			}

			last, err = s.Interpret(e.Body)
			if err == ErrBranchBreak {
				break
			} else if err != nil && err != ErrBranchContinue {
				return nil, err
			}

			if e.Post != nil {
				if _, err := s.Interpret(e.Post); err != nil {
					return nil, err
				}
			}
		}
		return last, nil

	case *ast.BranchStmt:
		switch e.Tok {
		case token.BREAK:
			return nil, ErrBranchBreak
		case token.CONTINUE:
			return nil, ErrBranchContinue
		default:
			return nil, fmt.Errorf("unsupported BranchStmt %#v", e)
		}

	case *ast.SendStmt:
		val, err := scope.Interpret(e.Value)
		if err != nil {
			return nil, err
		}
		channel, err := scope.Interpret(e.Chan)
		if err != nil {
			return nil, err
		}
		chanV := reflect.ValueOf(channel)
		if chanV.Kind() != reflect.Chan {
			return nil, errors.Errorf("expected chan; got %#v", channel)
		}
		elem, err := coerceForAssignment(val, chanV.Type().Elem())
		if err != nil {
			return nil, err
		}
		if !chanV.TrySend(elem) {
			return nil, ErrChanSendFailed
		}
		return nil, nil

	case *ast.SelectStmt:
		list := e.Body.List
		var defaultCase *ast.CommClause

		// We're using a map here since we want iteration on clauses to be
		// pseudo-random.
		clauses := map[int]*ast.CommClause{}
		for i, stmt := range list {
			cc := stmt.(*ast.CommClause)
			if cc.Comm == nil {
				defaultCase = cc
			} else {
				clauses[i] = cc
			}
		}

		for {
			for _, cc := range clauses {
				child := scope.NewChild()
				child.isSelect = true
				_, err := child.Interpret(cc.Comm)
				child.isSelect = false
				if err == ErrChanSendFailed || err == ErrBranchContinue || err == ErrChanRecvInSelect {
					continue
				} else if err != nil {
					return nil, err
				}
				return child.Interpret(cc)
			}
			if defaultCase != nil {
				child := scope.NewChild()
				return child.Interpret(defaultCase)
			}
			time.Sleep(10 * time.Millisecond)
		}

	case *ast.SwitchStmt:
		list := e.Body.List
		var defaultCase *ast.CaseClause
		var clauses []*ast.CaseClause
		for _, stmt := range list {
			cc := stmt.(*ast.CaseClause)
			if cc.List == nil {
				defaultCase = cc
			} else {
				clauses = append(clauses, cc)
			}
		}

		currentScope := scope.NewChild()
		if e.Init != nil {
			if _, err := currentScope.Interpret(e.Init); err != nil {
				return nil, err
			}
		}

		var err error
		var want interface{}
		if e.Tag != nil {
			want, err = currentScope.Interpret(e.Tag)
		} else {
			want = true
		}
		if err != nil {
			return nil, err
		}

		for _, cc := range clauses {
			for _, c := range cc.List {
				child := currentScope.NewChild()
				out, err := child.Interpret(c)
				if err != nil {
					return nil, err
				}
				if reflect.DeepEqual(out, want) {
					return child.Interpret(cc)
				}
			}
		}
		if defaultCase != nil {
			// The default clause must also see variables from the Init statement.
			return currentScope.NewChild().Interpret(defaultCase)
		}
		return nil, nil

	case *ast.TypeSwitchStmt:
		list := e.Body.List
		var defaultCase *ast.CaseClause
		var clauses []*ast.CaseClause
		for _, stmt := range list {
			cc := stmt.(*ast.CaseClause)
			if cc.List == nil {
				defaultCase = cc
			} else {
				clauses = append(clauses, cc)
			}
		}

		currentScope := scope.NewChild()
		if e.Init != nil {
			if _, err := currentScope.Interpret(e.Init); err != nil {
				return nil, err
			}
		}

		// Evaluating x.(type) records the dynamic type on currentScope.
		var want reflect.Type
		if e.Assign != nil {
			if _, err := currentScope.Interpret(e.Assign); err != nil {
				return nil, err
			}
			want = currentScope.typeAssert
		}

		for _, cc := range clauses {
			for _, c := range cc.List {
				child := currentScope.NewChild()
				out, err := child.Interpret(c)
				if err != nil {
					return nil, err
				}
				if out == want {
					return child.Interpret(cc)
				}
				// An interface case matches any dynamic type implementing it.
				if iface, ok := out.(reflect.Type); ok && want != nil && iface.Kind() == reflect.Interface && want.Implements(iface) {
					return child.Interpret(cc)
				}
			}
		}
		if defaultCase != nil {
			return currentScope.NewChild().Interpret(defaultCase)
		}
		return nil, nil

	case *ast.CommClause:
		return scope.Interpret(&ast.BlockStmt{List: e.Body})

	case *ast.CaseClause:
		return scope.Interpret(&ast.BlockStmt{List: e.Body})

	case *ast.InterfaceType:
		if len(e.Methods.List) > 0 {
			return nil, fmt.Errorf("don't support non-anonymous interfaces yet")
		}
		return reflect.TypeOf((*interface{})(nil)).Elem(), nil

	case *ast.TypeAssertExpr:
		out, err := scope.Interpret(e.X)
		if err != nil {
			return nil, err
		}
		outType := reflect.TypeOf(out)
		if e.Type == nil {
			// x.(type) in a type switch: remember the dynamic type.
			scope.typeAssert = outType
			return out, nil
		}
		typ, err := scope.Interpret(e.Type)
		if err != nil {
			return nil, err
		}
		// Asserting to an interface succeeds when the dynamic type implements it.
		if iface, ok := typ.(reflect.Type); ok && iface.Kind() == reflect.Interface {
			if outType != nil && outType.Implements(iface) {
				return out, nil
			}
			return nil, fmt.Errorf("%#v is not of type %#v, is %T", out, typ, out)
		}
		if typ != outType {
			return nil, fmt.Errorf("%#v is not of type %#v, is %T", out, typ, out)
		}
		return out, nil

	case *ast.IfStmt:
		currentScope := scope.NewChild()
		if e.Init != nil {
			if _, err := currentScope.Interpret(e.Init); err != nil {
				return nil, err
			}
		}
		cond, err := currentScope.Interpret(e.Cond)
		if err != nil {
			return nil, err
		}
		if cond == true {
			return currentScope.Interpret(e.Body)
		}
		// A missing else is nil and evaluates to nil (see the top of Interpret).
		return currentScope.Interpret(e.Else)

	case *ast.DeferStmt:
		// Arguments are evaluated now; the call runs when the enclosing
		// interpreted function returns (see ExecuteFunc).
		var args []interface{}
		for _, arg := range e.Call.Args {
			v, err := scope.Interpret(arg)
			if err != nil {
				return nil, err
			}
			args = append(args, v)
		}
		if err := scope.Defer(&Defer{
			fun:       e.Call.Fun,
			scope:     scope,
			arguments: args,
		}); err != nil {
			return nil, err
		}
		return nil, nil

	case *ast.StructType:
		if len(e.Fields.List) > 0 {
			return nil, errors.New("don't support non-empty structs yet")
		}
		return reflect.TypeOf(struct{}{}), nil

	default:
		return nil, fmt.Errorf("unknown node %#v", e)
	}
}

// getValue resolves the assignable expression id (identifier, index or field
// selector) to the reflect.Value it denotes. Variables, slice/array elements
// and exported struct fields reached through a variable are settable; map
// entries are returned as read-only values.
func (scope *Scope) getValue(id ast.Expr) (reflect.Value, error) {
	switch id := id.(type) {
	case *ast.Ident:
		variable := id.Name
		current, exists := scope.GetPointer(variable)
		if !exists {
			return reflect.Value{}, fmt.Errorf("variable %#v is not defined", variable)
		}
		// Variables are stored boxed in a pointer; anything else (an
		// untyped nil, say) has no location to write through.
		ptr := reflect.ValueOf(current)
		if ptr.Kind() != reflect.Ptr || ptr.IsNil() {
			return reflect.Value{}, fmt.Errorf("variable %#v is not addressable", variable)
		}
		return ptr.Elem(), nil

	case *ast.IndexExpr:
		index, err := scope.Interpret(id.Index)
		if err != nil {
			return reflect.Value{}, err
		}
		elem, err := scope.getValue(id.X)
		if err != nil {
			return reflect.Value{}, err
		}
		// Indexing through a pointer (e.g. to an array) indexes the pointee.
		for elem.Kind() == reflect.Ptr && !elem.IsNil() {
			elem = elem.Elem()
		}

		switch elem.Kind() {
		case reflect.Slice, reflect.Array:
			indexInt, ok := index.(int)
			if !ok {
				return reflect.Value{}, errors.Errorf("expected index to be int, got %#v", index)
			}
			if indexInt < 0 || indexInt >= elem.Len() {
				return reflect.Value{}, errors.Errorf("index out of range")
			}
			return elem.Index(indexInt), nil

		case reflect.Map:
			key, err := coerceForAssignment(index, elem.Type().Key())
			if err != nil {
				return reflect.Value{}, err
			}
			return elem.MapIndex(key), nil

		default:
			return reflect.Value{}, errors.Errorf("unknown type of X %#v", id)
		}

	case *ast.SelectorExpr:
		elem, err := scope.getValue(id.X)
		if err != nil {
			return reflect.Value{}, err
		}
		// p.Field with p a pointer to a struct selects the pointee's field.
		for elem.Kind() == reflect.Ptr && !elem.IsNil() {
			elem = elem.Elem()
		}
		if elem.Kind() != reflect.Struct {
			return reflect.Value{}, errors.Errorf("cannot select field %s of non-struct %s", id.Sel.Name, elem.Kind())
		}
		field := elem.FieldByName(id.Sel.Name)
		if !field.IsValid() {
			return reflect.Value{}, errors.Errorf("unknown field %#v", id.Sel.Name)
		}
		return field, nil

	default:
		return reflect.Value{}, errors.Errorf("unknown assignment expr %#v", id)
	}
}

// ExecuteFunc evaluates funExpr and calls the result with the already
// evaluated args. It handles three kinds of callee:
//   - a reflect.Type, which performs a conversion such as float64(x);
//   - an interpreted *Func, whose body runs in a child scope, followed by its
//     deferred calls in last-in first-out order;
//   - any native Go function, called through reflection with arguments
//     converted to the parameter types where Go would convert an untyped
//     constant.
//
// A trailing *InterpretError result is stripped and returned as the error.
// Several results are returned as []interface{}.
func (scope *Scope) ExecuteFunc(funExpr ast.Expr, args []interface{}) (interface{}, error) {
	fun, err := scope.Interpret(funExpr)
	if err != nil {
		return nil, err
	}

	switch funV := fun.(type) {
	case reflect.Type:
		if len(args) != 1 {
			return nil, errors.Errorf("expected args len = 1; args %#v", args)
		}
		if args[0] == nil {
			// e.g. error(nil) or []int(nil)
			return reflect.Zero(funV).Interface(), nil
		}
		arg := reflect.ValueOf(args[0])
		if !arg.Type().ConvertibleTo(funV) {
			return nil, errors.Errorf("cannot convert %#v (type %T) to type %s", args[0], args[0], funV)
		}
		return arg.Convert(funV).Interface(), nil

	case *Func:
		// TODO enforce func return values
		currentScope := scope.NewChild()
		var params []*ast.Ident
		for _, field := range funV.Def.Type.Params.List {
			params = append(params, field.Names...)
		}
		if len(args) < len(params) {
			return nil, errors.Errorf("not enough arguments in call: expected %d; got %d", len(params), len(args))
		}
		for i, param := range params {
			// Parameters always live in the call's own scope.
			currentScope.defineLocal(param.Name, args[i])
		}
		currentScope.isFunction = true
		ret, err := currentScope.Interpret(funV.Def.Body)
		if err != nil {
			return nil, err
		}
		for i := len(currentScope.defers) - 1; i >= 0; i-- {
			d := currentScope.defers[i]
			if _, err := d.scope.ExecuteFunc(d.fun, d.arguments); err != nil {
				return nil, err
			}
		}
		return ret, nil
	}

	funVal := reflect.ValueOf(fun)
	if funVal.Kind() != reflect.Func {
		return nil, errors.Errorf("expected func; got %#v", fun)
	}

	funType := funVal.Type()
	numIn := funType.NumIn()
	variadic := funType.IsVariadic()
	if (!variadic && len(args) != numIn) || (variadic && len(args) < numIn-1) {
		return nil, errors.Errorf("number of arguments doesn't match function; expected %d; got %+v", numIn, args)
	}
	valueArgs := make([]reflect.Value, len(args))
	for i, arg := range args {
		want := funType.In(numIn - 1)
		if variadic && i >= numIn-1 {
			want = want.Elem()
		} else {
			want = funType.In(i)
		}
		v, err := coerceForAssignment(arg, want)
		if err != nil {
			return nil, err
		}
		valueArgs[i] = v
	}

	values := ValuesToInterfaces(funVal.Call(valueArgs))
	if len(values) > 0 {
		if last, ok := values[len(values)-1].(*InterpretError); ok {
			values = values[:len(values)-1]
			if err := last.Error(); err != nil {
				return nil, err
			}
		}
	}

	if len(values) == 0 {
		return nil, nil
	} else if len(values) == 1 {
		return values[0], nil
	}
	return values, nil
}

// defineLocal declares name directly in scope (never in an ancestor),
// shadowing any outer variable of the same name. It then stores val through
// Set, so the value is boxed like every other variable.
func (scope *Scope) defineLocal(name string, val interface{}) {
	scope.Lock()
	if _, declared := scope.Vals[name]; !declared {
		scope.Vals[name] = nil
	}
	scope.Unlock()
	scope.Set(name, val)
}

// coerceForAssignment returns v as a reflect.Value that can be stored in a
// location of type t, or passed as an argument of type t:
//   - nil becomes t's zero value;
//   - an assignable value is used as is;
//   - numeric values are converted between numeric types, and strings or
//     bools to named string or bool types. This mirrors how Go treats untyped
//     constants, e.g. the int the interpreter produces for "1" stored in a
//     float64 slot.
//
// Any other mismatch is reported as an error instead of a reflect panic.
func coerceForAssignment(v interface{}, t reflect.Type) (reflect.Value, error) {
	if v == nil {
		return reflect.Zero(t), nil
	}
	rv := reflect.ValueOf(v)
	if rv.Type().AssignableTo(t) {
		return rv, nil
	}
	isNumeric := func(k reflect.Kind) bool {
		switch k {
		case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
			reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr,
			reflect.Float32, reflect.Float64:
			return true
		}
		return false
	}
	sameBasic := rv.Kind() == t.Kind() && (rv.Kind() == reflect.String || rv.Kind() == reflect.Bool)
	if (isNumeric(rv.Kind()) && isNumeric(t.Kind())) || sameBasic {
		return rv.Convert(t), nil
	}
	return reflect.Value{}, errors.Errorf("cannot use %#v (type %T) as type %s", v, v, t)
}

// ConfigureTypes prepares the scope's type checker for the source at path.
// line is the line on which CheckStatement later looks for the pry.Apply
// call. It loads the source files into scope.Files and runs an initial type
// check, returning the first type error, if any.
func (scope *Scope) ConfigureTypes(path string, line int) error {
	scope.path = path
	scope.line = line
	scope.fset = token.NewFileSet() // positions are relative to fset
	scope.config = &types.Config{
		FakeImportC: true,
		Importer:    getImporter(),
	}

	// Load every source file that takes part in type checking.
	f, err := scope.parseDir()
	if err != nil {
		return errors.Wrapf(err, "parser.ParseDir %q", scope.path)
	}

	for name, file := range f {
		scope.Files[name] = file
	}

	_, errs := scope.TypeCheck()
	if len(errs) > 0 {
		return errors.Wrap(errs[0], "failed to TypeCheck")
	}

	return nil
}

// walker adapts a function to satisfy the ast.Visitor interface.
// The function returns whether the walk should proceed into the node's children.
1193 | type walker func(ast.Node) bool 1194 | 1195 | func (w walker) Visit(node ast.Node) ast.Visitor { 1196 | if w(node) { 1197 | return w 1198 | } 1199 | return nil 1200 | } 1201 | 1202 | // CheckStatement checks if a statement is type safe 1203 | func (scope *Scope) CheckStatement(node ast.Node) (errs []error) { 1204 | for name, file := range scope.Files { 1205 | name = filepath.Dir(name) + "/." + filepath.Base(name) + "pry" 1206 | if name == scope.path { 1207 | ast.Walk(walker(func(n ast.Node) bool { 1208 | switch s := n.(type) { 1209 | case *ast.BlockStmt: 1210 | for i, stmt := range s.List { 1211 | pos := scope.fset.Position(stmt.Pos()) 1212 | if pos.Line == scope.line { 1213 | r := scope.Render(stmt) 1214 | if strings.HasPrefix(r, "pry.Apply") { 1215 | var iStmt []ast.Stmt 1216 | switch s2 := node.(type) { 1217 | case *ast.BlockStmt: 1218 | iStmt = append(iStmt, s2.List...) 1219 | case ast.Stmt: 1220 | iStmt = append(iStmt, s2) 1221 | case ast.Expr: 1222 | iStmt = append(iStmt, ast.Stmt(&ast.ExprStmt{X: s2})) 1223 | default: 1224 | errs = append(errs, errors.New("not a statement")) 1225 | return false 1226 | } 1227 | oldList := make([]ast.Stmt, len(s.List)) 1228 | copy(oldList, s.List) 1229 | 1230 | s.List = append(s.List, make([]ast.Stmt, len(iStmt))...) 
1231 | 1232 | copy(s.List[i+len(iStmt):], s.List[i:]) 1233 | copy(s.List[i:], iStmt) 1234 | 1235 | _, errs = scope.TypeCheck() 1236 | if len(errs) > 0 { 1237 | s.List = oldList 1238 | return false 1239 | } 1240 | return false 1241 | } 1242 | } 1243 | } 1244 | } 1245 | return true 1246 | }), file) 1247 | } 1248 | } 1249 | return 1250 | } 1251 | 1252 | // Render renders an ast node 1253 | func (scope *Scope) Render(x ast.Node) string { 1254 | var buf bytes.Buffer 1255 | if err := printer.Fprint(&buf, scope.fset, x); err != nil { 1256 | panic(err) 1257 | } 1258 | return buf.String() 1259 | } 1260 | 1261 | // TypeCheck does type checking and returns the info object 1262 | func (scope *Scope) TypeCheck() (*types.Info, []error) { 1263 | var errs []error 1264 | scope.config.Error = func(err error) { 1265 | if !strings.HasSuffix(err.Error(), "not used") { 1266 | err := errors.New(strings.TrimPrefix(err.Error(), scope.path)) 1267 | errs = append(errs, errors.Wrapf(err, "path %q", scope.path)) 1268 | } 1269 | } 1270 | info := &types.Info{} 1271 | var files []*ast.File 1272 | for _, f := range scope.Files { 1273 | files = append(files, f) 1274 | } 1275 | // these errors should be reported via the error reporter above 1276 | if _, err := scope.config.Check(filepath.Dir(scope.path), scope.fset, files, info); errs == nil && err != nil { 1277 | return nil, []error{err} 1278 | } 1279 | return info, errs 1280 | } 1281 | 1282 | // StringToType returns the reflect.Type corresponding to the type string provided. 
Ex: StringToType("int") 1283 | func StringToType(str string) (reflect.Type, error) { 1284 | builtinTypes := map[string]reflect.Type{ 1285 | "bool": reflect.TypeOf(true), 1286 | "byte": reflect.TypeOf(byte(0)), 1287 | "rune": reflect.TypeOf(rune(0)), 1288 | "string": reflect.TypeOf(""), 1289 | "int": reflect.TypeOf(int(0)), 1290 | "int8": reflect.TypeOf(int8(0)), 1291 | "int16": reflect.TypeOf(int16(0)), 1292 | "int32": reflect.TypeOf(int32(0)), 1293 | "int64": reflect.TypeOf(int64(0)), 1294 | "uint": reflect.TypeOf(uint(0)), 1295 | "uint8": reflect.TypeOf(uint8(0)), 1296 | "uint16": reflect.TypeOf(uint16(0)), 1297 | "uint32": reflect.TypeOf(uint32(0)), 1298 | "uint64": reflect.TypeOf(uint64(0)), 1299 | "uintptr": reflect.TypeOf(uintptr(0)), 1300 | "float32": reflect.TypeOf(float32(0)), 1301 | "float64": reflect.TypeOf(float64(0)), 1302 | "complex64": reflect.TypeOf(complex64(0)), 1303 | "complex128": reflect.TypeOf(complex128(0)), 1304 | "error": reflect.TypeOf(errors.New("")), 1305 | } 1306 | val, present := builtinTypes[str] 1307 | if !present { 1308 | return nil, fmt.Errorf("type %#v is not in table", str) 1309 | } 1310 | return val, nil 1311 | } 1312 | 1313 | // ValuesToInterfaces converts a slice of []reflect.Value to []interface{} 1314 | func ValuesToInterfaces(vals []reflect.Value) []interface{} { 1315 | inters := make([]interface{}, len(vals)) 1316 | for i, val := range vals { 1317 | inters[i] = val.Interface() 1318 | } 1319 | return inters 1320 | } 1321 | -------------------------------------------------------------------------------- /pry/interpreter_test.go: -------------------------------------------------------------------------------- 1 | package pry 2 | 3 | import ( 4 | "fmt" 5 | "reflect" 6 | "testing" 7 | ) 8 | 9 | func TestEmptyString(t *testing.T) { 10 | t.Parallel() 11 | 12 | scope := NewScope() 13 | out, err := scope.InterpretString(``) 14 | if err != nil { 15 | t.Error(err) 16 | } 17 | if out != nil { 18 | t.Error("expected nil") 19 | } 20 | } 
21 | 22 | // Literals 23 | func TestStringLiteral(t *testing.T) { 24 | t.Parallel() 25 | 26 | scope := NewScope() 27 | out, err := scope.InterpretString(`"Hello!"`) 28 | if err != nil { 29 | t.Error(err) 30 | } 31 | if out != "Hello!" { 32 | t.Error("Expected Hello!") 33 | } 34 | } 35 | 36 | func TestIntLiteral(t *testing.T) { 37 | t.Parallel() 38 | 39 | scope := NewScope() 40 | out, err := scope.InterpretString(`-1234`) 41 | if err != nil { 42 | t.Error(err) 43 | } 44 | if out != -1234 { 45 | t.Error("Expected -1234") 46 | } 47 | } 48 | func TestHexIntLiteral(t *testing.T) { 49 | t.Parallel() 50 | 51 | scope := NewScope() 52 | out, err := scope.InterpretString(`0xC123`) 53 | if err != nil { 54 | t.Error(err) 55 | } 56 | expected := 0xC123 57 | if !reflect.DeepEqual(expected, out) { 58 | t.Errorf("Expected %#v got %#v.", expected, out) 59 | } 60 | } 61 | func TestOctalIntLiteral(t *testing.T) { 62 | t.Parallel() 63 | 64 | scope := NewScope() 65 | out, err := scope.InterpretString(`03272`) 66 | if err != nil { 67 | t.Error(err) 68 | } 69 | expected := 03272 70 | if !reflect.DeepEqual(expected, out) { 71 | t.Errorf("Expected %#v got %#v.", expected, out) 72 | } 73 | } 74 | func TestCharLiteral(t *testing.T) { 75 | t.Parallel() 76 | 77 | scope := NewScope() 78 | out, err := scope.InterpretString(`'a'`) 79 | if err != nil { 80 | t.Error(err) 81 | } 82 | if out != 'a' { 83 | t.Errorf("Expected 'a' got %#v.", out) 84 | } 85 | } 86 | 87 | func TestArrayLiteral(t *testing.T) { 88 | t.Parallel() 89 | 90 | scope := NewScope() 91 | out, err := scope.InterpretString(`[]int{1,2,3,4}`) 92 | if err != nil { 93 | t.Error(err) 94 | } 95 | expected := []int{1, 2, 3, 4} 96 | if !reflect.DeepEqual(expected, out) { 97 | t.Errorf("Expected %#v got %#v.", expected, out) 98 | } 99 | } 100 | 101 | func TestFixedArrayLiteral(t *testing.T) { 102 | t.Parallel() 103 | 104 | scope := NewScope() 105 | out, err := scope.InterpretString(`[4]int{1,2,3,4}`) 106 | if err != nil { 107 | t.Error(err) 
108 | } 109 | expected := [4]int{1, 2, 3, 4} 110 | if !reflect.DeepEqual(expected, out) { 111 | t.Errorf("Expected %#v got %#v.", expected, out) 112 | } 113 | } 114 | 115 | func TestFixedArray(t *testing.T) { 116 | t.Parallel() 117 | 118 | scope := NewScope() 119 | out, err := scope.InterpretString(` 120 | var a [3]int 121 | a[2] 122 | `) 123 | if err != nil { 124 | t.Error(err) 125 | } 126 | expected := 0 127 | if !reflect.DeepEqual(expected, out) { 128 | t.Errorf("Expected %#v got %#v.", expected, out) 129 | } 130 | } 131 | 132 | func TestFixedArraySet(t *testing.T) { 133 | t.Parallel() 134 | 135 | scope := NewScope() 136 | out, err := scope.InterpretString(` 137 | var a [3]int 138 | b := &a 139 | a[2] = 1 140 | b[2] 141 | `) 142 | if err != nil { 143 | t.Errorf("%+v", err) 144 | } 145 | expected := 1 146 | if !reflect.DeepEqual(expected, out) { 147 | t.Errorf("Expected %#v got %#v.", expected, out) 148 | } 149 | } 150 | 151 | func TestArraySet(t *testing.T) { 152 | t.Parallel() 153 | 154 | scope := NewScope() 155 | out, err := scope.InterpretString(` 156 | a := []int{1,2,3,4} 157 | a[2] = 1 158 | a[2] 159 | `) 160 | if err != nil { 161 | t.Errorf("%+v", err) 162 | } 163 | expected := 1 164 | if !reflect.DeepEqual(expected, out) { 165 | t.Errorf("Expected %#v got %#v.", expected, out) 166 | } 167 | } 168 | 169 | func TestMapLiteral(t *testing.T) { 170 | t.Parallel() 171 | 172 | scope := NewScope() 173 | out, err := scope.InterpretString(` 174 | map[string]int{ 175 | "duck": 5, 176 | "banana": -123, 177 | } 178 | `) 179 | if err != nil { 180 | t.Error(err) 181 | } 182 | expected := map[string]int{ 183 | "duck": 5, 184 | "banana": -123, 185 | } 186 | if !reflect.DeepEqual(expected, out) { 187 | t.Errorf("Expected %#v got %#v.", expected, out) 188 | } 189 | } 190 | 191 | func TestMapSet(t *testing.T) { 192 | t.Parallel() 193 | 194 | scope := NewScope() 195 | out, err := scope.InterpretString(` 196 | a := map[string]int{} 197 | a["blah"] = 1 198 | a["blah"] 199 | `) 
200 | if err != nil { 201 | t.Errorf("%+v", err) 202 | } 203 | expected := 1 204 | if !reflect.DeepEqual(expected, out) { 205 | t.Errorf("Expected %#v got %#v.", expected, out) 206 | } 207 | } 208 | 209 | func TestMapLiteralInterface(t *testing.T) { 210 | t.Parallel() 211 | 212 | scope := NewScope() 213 | out, err := scope.InterpretString(` 214 | map[string]interface{}{ 215 | "duck": 5, 216 | "banana": -123, 217 | } 218 | `) 219 | if err != nil { 220 | t.Error(err) 221 | } 222 | expected := map[string]interface{}{ 223 | "duck": 5, 224 | "banana": -123, 225 | } 226 | if !reflect.DeepEqual(expected, out) { 227 | t.Errorf("Expected %#v got %#v.", expected, out) 228 | } 229 | } 230 | 231 | func TestTypeCast(t *testing.T) { 232 | t.Parallel() 233 | 234 | scope := NewScope() 235 | scope.Set("a", -1234.0) 236 | out, err := scope.InterpretString(`int(a)`) 237 | if err != nil { 238 | t.Error(err) 239 | } 240 | expected := -1234 241 | if !reflect.DeepEqual(expected, out) { 242 | t.Errorf("Expected %#v got %#v.", expected, out) 243 | } 244 | } 245 | 246 | // Selectors and Ident 247 | func TestBasicIdent(t *testing.T) { 248 | t.Parallel() 249 | 250 | scope := NewScope() 251 | scope.Set("a", 5) 252 | out, err := scope.InterpretString(`a`) 253 | if err != nil { 254 | t.Error(err) 255 | } 256 | expected := 5 257 | if !reflect.DeepEqual(expected, out) { 258 | t.Errorf("Expected %#v got %#v.", expected, out) 259 | } 260 | } 261 | func TestMissingBasicIdent(t *testing.T) { 262 | t.Parallel() 263 | 264 | scope := NewScope() 265 | out, err := scope.InterpretString(`a`) 266 | if err == nil || out != nil { 267 | t.Error("Found non-existant ident.") 268 | } 269 | } 270 | func TestMapIdent(t *testing.T) { 271 | t.Parallel() 272 | 273 | scope := NewScope() 274 | scope.Set("a", map[string]int{ 275 | "B": 10, 276 | }) 277 | out, err := scope.InterpretString(`a["B"]`) 278 | if err != nil { 279 | t.Error(err) 280 | } 281 | expected := 10 282 | if !reflect.DeepEqual(expected, out) { 283 | 
t.Errorf("Expected %#v got %#v.", expected, out) 284 | } 285 | } 286 | func TestMissingMapIdent(t *testing.T) { 287 | t.Parallel() 288 | 289 | scope := NewScope() 290 | scope.Set("a", map[string]int{}) 291 | 292 | out, err := scope.InterpretString(`a["b"]`) 293 | if err != nil { 294 | t.Error(err) 295 | } 296 | if out != 0 { 297 | t.Error("Found non-existant ident.") 298 | } 299 | } 300 | func TestArrIdent(t *testing.T) { 301 | t.Parallel() 302 | 303 | scope := NewScope() 304 | scope.Set("a", []int{1, 2, 3}) 305 | 306 | out, err := scope.InterpretString(`a[1]`) 307 | if err != nil { 308 | t.Error(err) 309 | } 310 | expected := 2 311 | if !reflect.DeepEqual(expected, out) { 312 | t.Errorf("Expected %#v got %#v.", expected, out) 313 | } 314 | } 315 | 316 | func TestMissingArrIdent(t *testing.T) { 317 | t.Parallel() 318 | 319 | scope := NewScope() 320 | scope.Set("a", []int{1}) 321 | 322 | out, err := scope.InterpretString(`a[1]`) 323 | if err == nil || out != nil { 324 | t.Error("Should have thrown out of range error") 325 | } 326 | } 327 | 328 | func TestSlice(t *testing.T) { 329 | t.Parallel() 330 | 331 | scope := NewScope() 332 | scope.Set("a", []int{1, 2, 3, 4}) 333 | 334 | out, err := scope.InterpretString(`a[1:3]`) 335 | if err != nil { 336 | t.Error(err) 337 | } 338 | expected := []int{2, 3} 339 | if !reflect.DeepEqual(expected, out) { 340 | t.Errorf("Expected %#v got %#v.", expected, out) 341 | } 342 | } 343 | 344 | // Structs 345 | type testStruct struct { 346 | A int 347 | C, D string 348 | } 349 | 350 | func (a testStruct) B() int { 351 | return a.A 352 | } 353 | 354 | func TestSelector(t *testing.T) { 355 | t.Parallel() 356 | 357 | scope := NewScope() 358 | scope.Set("a", testStruct{A: 1}) 359 | 360 | out, err := scope.InterpretString(`a.A`) 361 | if err != nil { 362 | t.Error(err) 363 | } 364 | expected := 1 365 | if !reflect.DeepEqual(expected, out) { 366 | t.Errorf("Expected %#v got %#v.", expected, out) 367 | } 368 | } 369 | 370 | func 
TestStructLiteral(t *testing.T) { 371 | scope := NewScope() 372 | scope.Set("a", Type(testStruct{})) 373 | 374 | out, err := scope.InterpretString(`a{0, "a", "b"}`) 375 | if err != nil { 376 | t.Error(err) 377 | } 378 | expected := testStruct{0, "a", "b"} 379 | if !reflect.DeepEqual(expected, out) { 380 | t.Errorf("Expected %#v got %#v.", expected, out) 381 | } 382 | } 383 | 384 | func TestStructLiteralNamed(t *testing.T) { 385 | scope := NewScope() 386 | scope.Set("a", Type(testStruct{})) 387 | 388 | out, err := scope.InterpretString(`a{C: "c", A: 0}`) 389 | if err != nil { 390 | t.Error(err) 391 | } 392 | expected := testStruct{C: "c", A: 0} 393 | if !reflect.DeepEqual(expected, out) { 394 | t.Errorf("Expected %#v got %#v.", expected, out) 395 | } 396 | } 397 | 398 | func TestStructLiteralEmpty(t *testing.T) { 399 | scope := NewScope() 400 | scope.Set("a", Type(testStruct{})) 401 | 402 | out, err := scope.InterpretString(`a{}`) 403 | if err != nil { 404 | t.Error(err) 405 | } 406 | expected := testStruct{} 407 | if !reflect.DeepEqual(expected, out) { 408 | t.Errorf("Expected %#v got %#v.", expected, out) 409 | } 410 | } 411 | 412 | func TestStructSelectorAssignment(t *testing.T) { 413 | scope := NewScope() 414 | scope.Set("a", testStruct{}) 415 | 416 | out, err := scope.InterpretString(`a.A = 10; a`) 417 | if err != nil { 418 | t.Error(err) 419 | } 420 | expected := testStruct{A: 10} 421 | if !reflect.DeepEqual(expected, out) { 422 | t.Errorf("Expected %#v got %#v.", expected, out) 423 | } 424 | } 425 | 426 | func TestSelectorFunc(t *testing.T) { 427 | t.Parallel() 428 | 429 | scope := NewScope() 430 | scope.Set("a", testStruct{A: 1}) 431 | 432 | out, err := scope.InterpretString(`a.B()`) 433 | if err != nil { 434 | t.Error(err) 435 | } 436 | expected := 1 437 | if !reflect.DeepEqual(expected, out) { 438 | t.Errorf("Expected %#v got %#v.", expected, out) 439 | } 440 | } 441 | 442 | // Basic Math 443 | func TestBasicMath(t *testing.T) { 444 | t.Parallel() 445 | 
446 | scope := NewScope() 447 | pairs := map[string]interface{}{ 448 | "2*3": 6, 449 | "2.0 * 3.0": 6.0, 450 | "10 / 2": 5, 451 | "10.0 / 2.0": 5.0, 452 | "1 + 2": 3, 453 | "1.0 + 2.0": 3.0, 454 | } 455 | for k, expected := range pairs { 456 | out, err := scope.InterpretString(k) 457 | if err != nil { 458 | t.Error(err) 459 | } 460 | if !reflect.DeepEqual(expected, out) { 461 | t.Errorf("Expected %#v got %#v.", expected, out) 462 | } 463 | } 464 | } 465 | 466 | func TestMathShifting(t *testing.T) { 467 | t.Parallel() 468 | 469 | types := []string{ 470 | "int", "int8", "int16", "int32", "int64", 471 | "uint", "uint8", "uint16", "uint32", "uint64", 472 | "uintptr", 473 | } 474 | cases := []struct { 475 | l int 476 | op string 477 | r, out int 478 | }{ 479 | {3, "%", 2, 1}, 480 | {7, "&", 2, 2}, 481 | {6, "|", 2, 6}, 482 | {6, "^", 2, 4}, 483 | {2, "<<", 2, 8}, 484 | {8, ">>", 2, 2}, 485 | {6, "&^", 4, 2}, 486 | } 487 | scope := NewScope() 488 | for _, typ := range types { 489 | for _, td := range cases { 490 | query := fmt.Sprintf("%s(%d) %s %s(%d)", typ, td.l, td.op, typ, td.r) 491 | outI, err := scope.InterpretString(query) 492 | if err != nil { 493 | t.Error(err) 494 | } 495 | out := interfaceToInt(outI) 496 | if !reflect.DeepEqual(td.out, out) { 497 | t.Errorf("Expected %#v = %#v got %#v.", query, td.out, out) 498 | } 499 | } 500 | } 501 | } 502 | 503 | func TestMathBasic(t *testing.T) { 504 | t.Parallel() 505 | 506 | types := []string{ 507 | "int", "int8", "int16", "int32", "int64", 508 | "uint", "uint8", "uint16", "uint32", "uint64", 509 | "uintptr", 510 | "float32", "float64", 511 | } 512 | cases := []struct { 513 | l int 514 | op string 515 | r, out int 516 | }{ 517 | {3, "+", 2, 5}, 518 | {3, "-", 2, 1}, 519 | {3, "*", 2, 6}, 520 | {4, "/", 2, 2}, 521 | {4, ">", 3, 1}, 522 | {3, ">", 4, -1}, 523 | {4, ">=", 3, 1}, 524 | {3, ">=", 4, -1}, 525 | {4, "<", 3, -1}, 526 | {3, "<", 4, 1}, 527 | {4, "<=", 3, -1}, 528 | {3, "<=", 4, 1}, 529 | {3, "==", 3, 1}, 530 | 
{3, "==", 4, -1}, 531 | {3, "!=", 3, -1}, 532 | {3, "!=", 4, 1}, 533 | } 534 | scope := NewScope() 535 | for _, typ := range types { 536 | for _, td := range cases { 537 | query := fmt.Sprintf("%s(%d) %s %s(%d)", typ, td.l, td.op, typ, td.r) 538 | outI, err := scope.InterpretString(query) 539 | if err != nil { 540 | t.Error(err) 541 | } 542 | out := interfaceToInt(outI) 543 | if !reflect.DeepEqual(td.out, out) { 544 | t.Errorf("Expected %#v = %#v got %#v.", query, td.out, out) 545 | } 546 | } 547 | } 548 | } 549 | 550 | func TestBoolConds(t *testing.T) { 551 | t.Parallel() 552 | 553 | cases := []struct { 554 | l bool 555 | op string 556 | r, out bool 557 | }{ 558 | {true, "&&", true, true}, 559 | {true, "&&", false, false}, 560 | {false, "&&", true, false}, 561 | {false, "&&", false, false}, 562 | {true, "||", true, true}, 563 | {true, "||", false, true}, 564 | {false, "||", true, true}, 565 | {false, "||", false, false}, 566 | } 567 | scope := NewScope() 568 | for _, td := range cases { 569 | query := fmt.Sprintf("%#v %s %#v", td.l, td.op, td.r) 570 | outI, err := scope.InterpretString(query) 571 | if err != nil { 572 | t.Error(err) 573 | } 574 | out := outI.(bool) 575 | if !reflect.DeepEqual(td.out, out) { 576 | t.Errorf("Expected %#v = %#v got %#v.", query, td.out, out) 577 | } 578 | } 579 | } 580 | 581 | func interfaceToInt(i interface{}) int { 582 | switch v := i.(type) { 583 | case int: 584 | return int(v) 585 | case int8: 586 | return int(v) 587 | case int16: 588 | return int(v) 589 | case int32: 590 | return int(v) 591 | case int64: 592 | return int(v) 593 | case uint: 594 | return int(v) 595 | case uint8: 596 | return int(v) 597 | case uint16: 598 | return int(v) 599 | case uint32: 600 | return int(v) 601 | case uint64: 602 | return int(v) 603 | case uintptr: 604 | return int(v) 605 | case float32: 606 | return int(v) 607 | case float64: 608 | return int(v) 609 | case bool: 610 | if v { 611 | return 1 612 | } 613 | return -1 614 | } 615 | return 0 616 | } 
617 | 618 | func TestStringConcat(t *testing.T) { 619 | t.Parallel() 620 | 621 | scope := NewScope() 622 | scope.Set("a", 5) 623 | 624 | out, err := scope.InterpretString(`"hello" + "foo"`) 625 | if err != nil { 626 | t.Error(err) 627 | } 628 | expected := "hellofoo" 629 | if !reflect.DeepEqual(expected, out) { 630 | t.Errorf("Expected %#v got %#v.", expected, out) 631 | } 632 | } 633 | 634 | func TestParens(t *testing.T) { 635 | t.Parallel() 636 | 637 | scope := NewScope() 638 | scope.Set("a", 5) 639 | 640 | out, err := scope.InterpretString(`((10) * (a))`) 641 | if err != nil { 642 | t.Error(err) 643 | } 644 | expected := 50 645 | if !reflect.DeepEqual(expected, out) { 646 | t.Errorf("Expected %#v got %#v.", expected, out) 647 | } 648 | } 649 | 650 | // Test Make 651 | func TestMakeSlice(t *testing.T) { 652 | t.Parallel() 653 | 654 | scope := NewScope() 655 | out, err := scope.InterpretString(`make([]int, 1, 10)`) 656 | if err != nil { 657 | t.Error(err) 658 | } 659 | expected := make([]int, 1, 10) 660 | if !reflect.DeepEqual(expected, out) { 661 | t.Errorf("Expected %#v got %#v.", expected, out) 662 | } 663 | } 664 | 665 | func TestMakeChan(t *testing.T) { 666 | t.Parallel() 667 | 668 | scope := NewScope() 669 | out, err := scope.InterpretString(`make(chan int, 10)`) 670 | if err != nil { 671 | t.Error(err) 672 | } 673 | expected := make(chan int, 10) 674 | if reflect.TypeOf(expected) != reflect.TypeOf(out) { 675 | t.Errorf("Expected %#v got %#v.", expected, out) 676 | } 677 | } 678 | 679 | func TestMakeChanInterface(t *testing.T) { 680 | t.Parallel() 681 | 682 | scope := NewScope() 683 | out, err := scope.InterpretString(`make(chan interface{}, 10)`) 684 | if err != nil { 685 | t.Error(err) 686 | } 687 | expected := make(chan interface{}, 10) 688 | if reflect.TypeOf(expected) != reflect.TypeOf(out) { 689 | t.Errorf("Expected %#v got %#v.", expected, out) 690 | } 691 | } 692 | 693 | func TestMakeUnknown(t *testing.T) { 694 | t.Parallel() 695 | 696 | scope := 
NewScope() 697 | out, err := scope.InterpretString(`make(int)`) 698 | if err == nil || out != nil { 699 | t.Error("Should have thrown error.") 700 | } 701 | } 702 | 703 | func TestAppend(t *testing.T) { 704 | t.Parallel() 705 | 706 | scope := NewScope() 707 | scope.Set("a", []int{1}) 708 | 709 | _, err := scope.InterpretString(`a = append(a, 2, 3)`) 710 | if err != nil { 711 | t.Error(err) 712 | } 713 | expected := []int{1, 2, 3} 714 | outV, found := scope.Get("a") 715 | if !found { 716 | t.Errorf("failed to find \"a\"") 717 | } 718 | out := outV.([]int) 719 | if !reflect.DeepEqual(expected, out) { 720 | t.Errorf("Expected %#v got %#v.", expected, out) 721 | } 722 | } 723 | 724 | func TestMultiReturn(t *testing.T) { 725 | t.Parallel() 726 | 727 | scope := NewScope() 728 | scope.Set("f", func() (int, error) { 729 | return 0, nil 730 | }) 731 | 732 | _, err := scope.InterpretString(`a, err := f()`) 733 | if err != nil { 734 | t.Error(err) 735 | } 736 | expected := 0 737 | outV, found := scope.Get("a") 738 | if !found { 739 | t.Errorf("failed to find \"a\"") 740 | } 741 | out := outV.(int) 742 | if !reflect.DeepEqual(expected, out) { 743 | t.Errorf("Expected %#v got %#v.", expected, out) 744 | } 745 | } 746 | 747 | func TestDeclareAssignVar(t *testing.T) { 748 | t.Parallel() 749 | 750 | scope := NewScope() 751 | scope.Set("a", []int{1}) 752 | 753 | out, err := scope.InterpretString(`var a, b int = 2, 3`) 754 | if err != nil { 755 | t.Error(err) 756 | } 757 | testData := []struct { 758 | v string 759 | want int 760 | }{ 761 | {"a", 2}, 762 | {"b", 3}, 763 | } 764 | for _, td := range testData { 765 | out, _ = scope.Get(td.v) 766 | if !reflect.DeepEqual(td.want, out) { 767 | t.Errorf("Expected %#v got %#v.", td.want, out) 768 | } 769 | } 770 | } 771 | 772 | func TestDeclareAssign(t *testing.T) { 773 | t.Parallel() 774 | 775 | scope := NewScope() 776 | scope.Set("a", []int{1}) 777 | 778 | out, err := scope.InterpretString(`b := 2`) 779 | if err != nil { 780 | 
t.Error(err) 781 | } 782 | expected := 2 783 | out, _ = scope.Get("b") 784 | if !reflect.DeepEqual(expected, out) { 785 | t.Errorf("Expected %#v got %#v.", expected, out) 786 | } 787 | } 788 | 789 | func TestAssign(t *testing.T) { 790 | t.Parallel() 791 | 792 | scope := NewScope() 793 | scope.Set("a", 1) 794 | 795 | if _, err := scope.InterpretString(`b = 1`); err == nil { 796 | t.Fatal("expected error") 797 | } 798 | 799 | out, err := scope.InterpretString(`a = 2`) 800 | if err != nil { 801 | t.Error(err) 802 | } 803 | expected := 2 804 | out, _ = scope.Get("a") 805 | if !reflect.DeepEqual(expected, out) { 806 | t.Errorf("Expected %#v got %#v.", expected, out) 807 | } 808 | } 809 | 810 | // Statements 811 | 812 | func TestFuncDeclAndCall(t *testing.T) { 813 | t.Parallel() 814 | 815 | scope := NewScope() 816 | 817 | out, err := scope.InterpretString(` 818 | a := func(){ return 5 } 819 | a() 820 | `) 821 | if err != nil { 822 | t.Error(err) 823 | } 824 | expected := 5 825 | if !reflect.DeepEqual(expected, out) { 826 | t.Errorf("Expected %#v got %#v.", expected, out) 827 | } 828 | } 829 | 830 | // Channels 831 | 832 | func TestChannel(t *testing.T) { 833 | t.Parallel() 834 | 835 | scope := NewScope() 836 | 837 | out, err := scope.InterpretString(` 838 | a := make(chan int, 10) 839 | a <- 1 840 | a <- 2 841 | []int{<-a, <-a} 842 | `) 843 | if err != nil { 844 | t.Error(err) 845 | } 846 | expected := []int{1, 2} 847 | if !reflect.DeepEqual(expected, out) { 848 | t.Errorf("Expected %#v got %#v.", expected, out) 849 | } 850 | } 851 | 852 | func TestChannelSendFail(t *testing.T) { 853 | t.Parallel() 854 | 855 | scope := NewScope() 856 | 857 | _, out := scope.InterpretString(` 858 | a := make(chan int) 859 | a <- 1 860 | `) 861 | expected := ErrChanSendFailed 862 | if !reflect.DeepEqual(expected, out) { 863 | t.Errorf("Expected err %#v got %#v.", expected, out) 864 | } 865 | } 866 | 867 | func TestChannelRecvFail(t *testing.T) { 868 | t.Parallel() 869 | 870 | scope := 
NewScope() 871 | 872 | _, out := scope.InterpretString(` 873 | a := make(chan int) 874 | close(a) 875 | <-a 876 | `) 877 | expected := ErrChanRecvFailed 878 | if !reflect.DeepEqual(expected, out) { 879 | t.Errorf("Expected err %#v got %#v.", expected, out) 880 | } 881 | } 882 | 883 | // Control structures 884 | 885 | func TestFor(t *testing.T) { 886 | t.Parallel() 887 | 888 | scope := NewScope() 889 | 890 | out, err := scope.InterpretString(` 891 | a := 1 892 | for i := 0; i < 5; i++ { 893 | a++ 894 | } 895 | a 896 | `) 897 | if err != nil { 898 | t.Error(err) 899 | } 900 | expected := 6 901 | if !reflect.DeepEqual(expected, out) { 902 | t.Errorf("Expected %#v got %#v.", expected, out) 903 | } 904 | 905 | out, err = scope.InterpretString(` 906 | a := 1 907 | for i := 5; i > 0; i-- { 908 | a++ 909 | } 910 | a 911 | `) 912 | if err != nil { 913 | t.Error(err) 914 | } 915 | if !reflect.DeepEqual(expected, out) { 916 | t.Errorf("Expected %#v got %#v.", expected, out) 917 | } 918 | } 919 | 920 | func TestForBreak(t *testing.T) { 921 | t.Parallel() 922 | 923 | scope := NewScope() 924 | 925 | _, err := scope.InterpretString(`for { break }`) 926 | if err != nil { 927 | t.Error(err) 928 | } 929 | } 930 | 931 | func TestForContinue(t *testing.T) { 932 | t.Parallel() 933 | 934 | scope := NewScope() 935 | 936 | out, err := scope.InterpretString(` 937 | a := 0 938 | for i:=0; i < 1; i++ { 939 | a = 1 940 | continue 941 | a = 2 942 | } 943 | a 944 | `) 945 | if err != nil { 946 | t.Error(err) 947 | } 948 | expected := 1 949 | if !reflect.DeepEqual(expected, out) { 950 | t.Errorf("Expected %#v got %#v.", expected, out) 951 | } 952 | } 953 | 954 | func TestForRangeArray(t *testing.T) { 955 | t.Parallel() 956 | 957 | scope := NewScope() 958 | 959 | out, err := scope.InterpretString(` 960 | a := 1 961 | for i, c := range []int{1,2,3} { 962 | a = a + i + c 963 | } 964 | a 965 | `) 966 | if err != nil { 967 | t.Error(err) 968 | } 969 | expected := 1 + 0 + 1 + 2 + 1 + 2 + 3 970 | if 
!reflect.DeepEqual(expected, out) { 971 | t.Errorf("Expected %#v got %#v.", expected, out) 972 | } 973 | } 974 | 975 | func TestForRangeMap(t *testing.T) { 976 | t.Parallel() 977 | 978 | scope := NewScope() 979 | 980 | out, err := scope.InterpretString(` 981 | a := 1 982 | for i, c := range map[int]int{0: 1, 1: 2, 2: 3} { 983 | a=a+i+c 984 | } 985 | a 986 | `) 987 | if err != nil { 988 | t.Error(err) 989 | } 990 | expected := 1 + 0 + 1 + 2 + 1 + 2 + 3 991 | if !reflect.DeepEqual(expected, out) { 992 | t.Errorf("Expected %#v got %#v.", expected, out) 993 | } 994 | } 995 | 996 | func TestSelectDefault(t *testing.T) { 997 | t.Parallel() 998 | 999 | scope := NewScope() 1000 | 1001 | out, err := scope.InterpretString(` 1002 | a := 0 1003 | c := make(chan int) 1004 | select { 1005 | case b := <-c: 1006 | a = b 1007 | default: 1008 | a = 1 1009 | } 1010 | a 1011 | `) 1012 | if err != nil { 1013 | t.Error(err) 1014 | } 1015 | expected := 1 1016 | if !reflect.DeepEqual(expected, out) { 1017 | t.Errorf("Expected %#v got %#v.", expected, out) 1018 | } 1019 | } 1020 | 1021 | func TestSelect(t *testing.T) { 1022 | t.Parallel() 1023 | 1024 | scope := NewScope() 1025 | 1026 | out, err := scope.InterpretString(` 1027 | a := 0 1028 | c := make(chan int, 10) 1029 | c <- 2 1030 | select { 1031 | case b := <-c: 1032 | a = b 1033 | default: 1034 | a = 1 1035 | } 1036 | a 1037 | `) 1038 | if err != nil { 1039 | t.Error(err) 1040 | } 1041 | expected := 2 1042 | if !reflect.DeepEqual(expected, out) { 1043 | t.Errorf("Expected %#v got %#v.", expected, out) 1044 | } 1045 | } 1046 | 1047 | func TestSelectMultiCase(t *testing.T) { 1048 | t.Parallel() 1049 | 1050 | scope := NewScope() 1051 | 1052 | out, err := scope.InterpretString(` 1053 | c := make(chan int, 10) 1054 | e := make(chan int, 10) 1055 | c <- 2 1056 | a := 0 1057 | select { 1058 | case d := <-e: 1059 | a = d 1060 | case b := <-c: 1061 | a = b 1062 | } 1063 | a 1064 | `) 1065 | if err != nil { 1066 | t.Error(err) 1067 | } 1068 | 
expected := 2 1069 | if !reflect.DeepEqual(expected, out) { 1070 | t.Errorf("Expected %#v got %#v.", expected, out) 1071 | } 1072 | } 1073 | 1074 | func TestSwitch(t *testing.T) { 1075 | t.Parallel() 1076 | 1077 | scope := NewScope() 1078 | 1079 | out, err := scope.InterpretString(` 1080 | a := 10 1081 | out := 0 1082 | switch a { 1083 | case 10: 1084 | out = 1 1085 | default: 1086 | out = 2 1087 | } 1088 | out 1089 | `) 1090 | if err != nil { 1091 | t.Error(err) 1092 | } 1093 | expected := 1 1094 | if !reflect.DeepEqual(expected, out) { 1095 | t.Errorf("Expected %#v got %#v.", expected, out) 1096 | } 1097 | } 1098 | 1099 | func TestSwitchDefault(t *testing.T) { 1100 | t.Parallel() 1101 | 1102 | scope := NewScope() 1103 | 1104 | out, err := scope.InterpretString(` 1105 | a := 0 1106 | out := 0 1107 | switch a { 1108 | case 10: 1109 | out = 1 1110 | default: 1111 | out = 2 1112 | } 1113 | out 1114 | `) 1115 | if err != nil { 1116 | t.Error(err) 1117 | } 1118 | expected := 2 1119 | if !reflect.DeepEqual(expected, out) { 1120 | t.Errorf("Expected %#v got %#v.", expected, out) 1121 | } 1122 | } 1123 | 1124 | func TestSwitchBool(t *testing.T) { 1125 | t.Parallel() 1126 | 1127 | scope := NewScope() 1128 | 1129 | out, err := scope.InterpretString(` 1130 | out := 0 1131 | switch { 1132 | case true: 1133 | out = 1 1134 | } 1135 | out 1136 | `) 1137 | if err != nil { 1138 | t.Error(err) 1139 | } 1140 | expected := 1 1141 | if !reflect.DeepEqual(expected, out) { 1142 | t.Errorf("Expected %#v got %#v.", expected, out) 1143 | } 1144 | } 1145 | 1146 | func TestSwitchType(t *testing.T) { 1147 | t.Parallel() 1148 | 1149 | scope := NewScope() 1150 | 1151 | out, err := scope.InterpretString(` 1152 | out := 0 1153 | var t interface{} 1154 | t = 10 1155 | switch t.(type){ 1156 | case int: 1157 | out = 1 1158 | case bool: 1159 | out = 2 1160 | } 1161 | out 1162 | `) 1163 | if err != nil { 1164 | t.Error(err) 1165 | } 1166 | expected := 1 1167 | if !reflect.DeepEqual(expected, out) { 
1168 | t.Errorf("Expected %#v got %#v.", expected, out) 1169 | } 1170 | } 1171 | 1172 | func TestSwitchTypeUse(t *testing.T) { 1173 | t.Parallel() 1174 | 1175 | scope := NewScope() 1176 | 1177 | out, err := scope.InterpretString(` 1178 | out := 0 1179 | var t interface{} 1180 | t = 10 1181 | switch t := t.(type){ 1182 | case int: 1183 | out = t 1184 | case bool: 1185 | out = 2 1186 | } 1187 | out 1188 | `) 1189 | if err != nil { 1190 | t.Error(err) 1191 | } 1192 | expected := 10 1193 | if !reflect.DeepEqual(expected, out) { 1194 | t.Errorf("Expected %#v got %#v.", expected, out) 1195 | } 1196 | } 1197 | 1198 | func TestSwitchNone(t *testing.T) { 1199 | t.Parallel() 1200 | 1201 | scope := NewScope() 1202 | 1203 | out, err := scope.InterpretString(` 1204 | out := 0 1205 | switch { 1206 | case false: 1207 | out = 1 1208 | } 1209 | out 1210 | `) 1211 | if err != nil { 1212 | t.Error(err) 1213 | } 1214 | expected := 0 1215 | if !reflect.DeepEqual(expected, out) { 1216 | t.Errorf("Expected %#v got %#v.", expected, out) 1217 | } 1218 | } 1219 | 1220 | func TestIf(t *testing.T) { 1221 | t.Parallel() 1222 | 1223 | scope := NewScope() 1224 | 1225 | out, err := scope.InterpretString(` 1226 | a := 0 1227 | if true { 1228 | a = 1 1229 | } else { 1230 | a = 2 1231 | } 1232 | a 1233 | `) 1234 | if err != nil { 1235 | t.Error(err) 1236 | } 1237 | expected := 1 1238 | if !reflect.DeepEqual(expected, out) { 1239 | t.Errorf("Expected %#v got %#v.", expected, out) 1240 | } 1241 | } 1242 | 1243 | func TestIfElse(t *testing.T) { 1244 | t.Parallel() 1245 | 1246 | scope := NewScope() 1247 | 1248 | out, err := scope.InterpretString(` 1249 | a := 0 1250 | if false { 1251 | a = 1 1252 | } else { 1253 | a = 2 1254 | } 1255 | a 1256 | `) 1257 | if err != nil { 1258 | t.Error(err) 1259 | } 1260 | expected := 2 1261 | if !reflect.DeepEqual(expected, out) { 1262 | t.Errorf("Expected %#v got %#v.", expected, out) 1263 | } 1264 | } 1265 | 1266 | func TestIfIfElse(t *testing.T) { 1267 | t.Parallel() 
1268 | 1269 | scope := NewScope() 1270 | 1271 | out, err := scope.InterpretString(` 1272 | a := 0 1273 | if false { 1274 | a = 1 1275 | } else if true { 1276 | a = 2 1277 | } 1278 | a 1279 | `) 1280 | if err != nil { 1281 | t.Error(err) 1282 | } 1283 | expected := 2 1284 | if !reflect.DeepEqual(expected, out) { 1285 | t.Errorf("Expected %#v got %#v.", expected, out) 1286 | } 1287 | } 1288 | 1289 | func TestFunctionArgs(t *testing.T) { 1290 | t.Parallel() 1291 | 1292 | scope := NewScope() 1293 | 1294 | out, err := scope.InterpretString(` 1295 | f := func(b, c int) int { 1296 | return b + c 1297 | } 1298 | f(10, 5) 1299 | `) 1300 | if err != nil { 1301 | t.Error(err) 1302 | } 1303 | expected := 15 1304 | if !reflect.DeepEqual(expected, out) { 1305 | t.Errorf("Expected %#v got %#v.", expected, out) 1306 | } 1307 | } 1308 | 1309 | func TestFunctionArgsBad(t *testing.T) { 1310 | t.Parallel() 1311 | 1312 | scope := NewScope() 1313 | 1314 | scope.Set("f", func(b, c int) int { 1315 | return b + c 1316 | }) 1317 | 1318 | _, err := scope.InterpretString(` 1319 | f(10.0, "foo") 1320 | `) 1321 | if err == nil { 1322 | t.Fatalf("expected error") 1323 | } 1324 | } 1325 | 1326 | func TestDefer(t *testing.T) { 1327 | t.Parallel() 1328 | 1329 | scope := NewScope() 1330 | 1331 | out, err := scope.InterpretString(` 1332 | a := 0 1333 | f := func() { 1334 | defer func() { 1335 | a = 2 1336 | }() 1337 | defer func() { 1338 | a = 3 1339 | }() 1340 | a = 1 1341 | } 1342 | f() 1343 | a 1344 | `) 1345 | if err != nil { 1346 | t.Error(err) 1347 | } 1348 | expected := 2 1349 | if !reflect.DeepEqual(expected, out) { 1350 | t.Errorf("Expected %#v got %#v.", expected, out) 1351 | } 1352 | } 1353 | 1354 | func TestStringAppend(t *testing.T) { 1355 | t.Parallel() 1356 | 1357 | scope := NewScope() 1358 | 1359 | out, err := scope.InterpretString(` 1360 | a := "foo" 1361 | a += "bar" 1362 | a 1363 | `) 1364 | if err != nil { 1365 | t.Error(err) 1366 | } 1367 | expected := "foobar" 1368 | if 
!reflect.DeepEqual(expected, out) { 1369 | t.Errorf("Expected %#v got %#v.", expected, out) 1370 | } 1371 | } 1372 | 1373 | func TestIntMod(t *testing.T) { 1374 | t.Parallel() 1375 | 1376 | scope := NewScope() 1377 | 1378 | out, err := scope.InterpretString(` 1379 | a := 10 1380 | a += 6 1381 | a -= 1 1382 | a /= 3 1383 | a *= 4 1384 | a 1385 | `) 1386 | if err != nil { 1387 | t.Error(err) 1388 | } 1389 | expected := 20 1390 | if !reflect.DeepEqual(expected, out) { 1391 | t.Errorf("Expected %#v got %#v.", expected, out) 1392 | } 1393 | } 1394 | 1395 | // TODO Packages 1396 | 1397 | // TODO References 1398 | -------------------------------------------------------------------------------- /pry/io_default.go: -------------------------------------------------------------------------------- 1 | // +build !js 2 | 3 | package pry 4 | 5 | import ( 6 | "encoding/json" 7 | "io/ioutil" 8 | "log" 9 | "path" 10 | 11 | homedir "github.com/mitchellh/go-homedir" 12 | "github.com/pkg/errors" 13 | ) 14 | 15 | var readFile = ioutil.ReadFile 16 | 17 | var historyFile = ".go-pry_history" 18 | 19 | type ioHistory struct { 20 | FileName string 21 | FilePath string 22 | Records []string 23 | } 24 | 25 | // NewHistory constructs ioHistory instance 26 | func NewHistory() (*ioHistory, error) { 27 | h := ioHistory{} 28 | h.FileName = historyFile 29 | 30 | dir, err := homedir.Dir() 31 | if err != nil { 32 | log.Printf("Error finding user home dir: %s", err) 33 | return nil, err 34 | } 35 | h.FilePath = path.Join(dir, h.FileName) 36 | 37 | return &h, nil 38 | } 39 | 40 | // Load unmarshal history file into history's records 41 | func (h *ioHistory) Load() error { 42 | body, err := ioutil.ReadFile(h.FilePath) 43 | if err != nil { 44 | return errors.Wrapf(err, "History file not found") 45 | } 46 | var records []string 47 | if err := json.Unmarshal(body, &records); err != nil { 48 | return errors.Wrapf(err, "Error reading history file") 49 | } 50 | 51 | h.Records = records 52 | return nil 53 | } 
54 | 55 | // Save saves marshaled history's records into file 56 | func (h ioHistory) Save() error { 57 | body, err := json.Marshal(h.Records) 58 | if err != nil { 59 | return errors.Wrapf(err, "error marshaling history") 60 | } 61 | if err := ioutil.WriteFile(h.FilePath, body, 0755); err != nil { 62 | return errors.Wrapf(err, "error writing history to the file") 63 | } 64 | 65 | return nil 66 | } 67 | 68 | // Len returns amount of records in history 69 | func (h ioHistory) Len() int { return len(h.Records) } 70 | 71 | // Add appends record into history's records 72 | func (h *ioHistory) Add(record string) { 73 | h.Records = append(h.Records, record) 74 | } 75 | -------------------------------------------------------------------------------- /pry/io_default_test.go: -------------------------------------------------------------------------------- 1 | // +build !js 2 | 3 | package pry 4 | 5 | import ( 6 | "fmt" 7 | "math/rand" 8 | "os" 9 | "path/filepath" 10 | "reflect" 11 | "testing" 12 | "time" 13 | ) 14 | 15 | func TestHistory(t *testing.T) { 16 | t.Parallel() 17 | 18 | rand.Seed(time.Now().UnixNano()) 19 | 20 | history := &ioHistory{} 21 | history.FileName = ".go-pry_history_test" 22 | history.FilePath = filepath.Join(os.TempDir(), history.FileName) 23 | 24 | expected := []string{ 25 | "test", 26 | fmt.Sprintf("rand: %d", rand.Int63()), 27 | } 28 | history.Add(expected[0]) 29 | history.Add(expected[1]) 30 | 31 | if err := history.Save(); err != nil { 32 | t.Error("Failed to save history") 33 | } 34 | 35 | if err := history.Load(); err != nil { 36 | t.Error("Failed to load history") 37 | } 38 | 39 | if !reflect.DeepEqual(expected, history.Records) { 40 | t.Errorf("history.Load() = %+v; expected %+v", history.Records, expected) 41 | } 42 | 43 | // delete test history file 44 | err := os.Remove(history.FilePath) 45 | if err != nil { 46 | t.Error(err) 47 | } 48 | } 49 | -------------------------------------------------------------------------------- /pry/io_js.go: 
-------------------------------------------------------------------------------- 1 | // +build js 2 | 3 | package pry 4 | 5 | import ( 6 | "encoding/json" 7 | "io" 8 | "io/ioutil" 9 | "path/filepath" 10 | "syscall/js" 11 | ) 12 | 13 | func readFile(path string) ([]byte, error) { 14 | path = filepath.Join("bundles", filepath.Base(path)) 15 | 16 | r, w := io.Pipe() 17 | var respCB js.Func 18 | respCB = js.FuncOf(func(this js.Value, args []js.Value) interface{} { 19 | defer respCB.Release() 20 | 21 | var textCB js.Func 22 | textCB = js.FuncOf(func(this js.Value, args []js.Value) interface{} { 23 | defer textCB.Release() 24 | 25 | w.Write([]byte(args[0].String())) 26 | w.Close() 27 | 28 | return nil 29 | }) 30 | args[0].Call("text").Call("then", textCB) 31 | 32 | return nil 33 | }) 34 | js.Global().Call("fetch", path).Call("then", respCB) 35 | return ioutil.ReadAll(r) 36 | } 37 | 38 | type browserHistory struct { 39 | Records []string 40 | } 41 | 42 | // NewHistory constructs browserHistory instance 43 | func NewHistory() (*browserHistory, error) { 44 | 45 | // FIXME: 46 | // when localStorage is full, can be return an error 47 | 48 | return &browserHistory{}, nil 49 | } 50 | 51 | // Load unmarshal localStorage data into history's records 52 | func (bh *browserHistory) Load() error { 53 | hist := js.Global().Get("localStorage").Get("history") 54 | if hist.Type() == js.TypeUndefined { 55 | return nil // nothing to unmarashal 56 | } 57 | var records []string 58 | if err := json.Unmarshal([]byte(hist.String()), &records); err != nil { 59 | return err 60 | } 61 | bh.Records = records 62 | 63 | return nil 64 | } 65 | 66 | // Save saves marshaled history's records into localStorage 67 | func (bh browserHistory) Save() error { 68 | bytes, err := json.Marshal(bh.Records) 69 | if err != nil { 70 | return err 71 | } 72 | js.Global().Get("localStorage").Set("history", string(bytes)) 73 | 74 | return nil 75 | } 76 | 77 | // Len returns amount of records in history 78 | func (bh 
browserHistory) Len() int { return len(bh.Records) } 79 | 80 | // Add appends record into history's records 81 | func (bh *browserHistory) Add(record string) { 82 | bh.Records = append(bh.Records, record) 83 | } 84 | -------------------------------------------------------------------------------- /pry/package.go: -------------------------------------------------------------------------------- 1 | package pry 2 | 3 | // Package represents a Go package for use with pry 4 | type Package struct { 5 | Name string 6 | Functions map[string]interface{} 7 | } 8 | 9 | func (p Package) Keys() []string { 10 | var keys []string 11 | for k := range p.Functions { 12 | keys = append(keys, k) 13 | } 14 | return keys 15 | } 16 | 17 | func (p Package) Get(key string) (interface{}, bool) { 18 | v, ok := p.Functions[key] 19 | return v, ok 20 | } 21 | -------------------------------------------------------------------------------- /pry/pry.go: -------------------------------------------------------------------------------- 1 | package pry 2 | 3 | import ( 4 | "fmt" 5 | "io" 6 | "log" 7 | "path/filepath" 8 | "runtime" 9 | "strings" 10 | 11 | "go/ast" 12 | 13 | "github.com/mgutz/ansi" 14 | ) 15 | 16 | // Pry does nothing. It only exists so running code without go-pry doesn't throw an error. 17 | func Pry(v ...interface{}) { 18 | } 19 | 20 | // Apply drops into a pry shell in the location required. 21 | func Apply(scope *Scope) { 22 | out, tty := openTTY() 23 | defer tty.Close() 24 | 25 | _, filePathRaw, lineNum, _ := runtime.Caller(1) 26 | filePath := filepath.Dir(filePathRaw) + "/." 
+ filepath.Base(filePathRaw) + "pry" 27 | 28 | if err := apply(scope, out, tty, filePath, filePathRaw, lineNum); err != nil { 29 | log.Fatalf("%+v", err) 30 | } 31 | } 32 | 33 | type genericTTY interface { 34 | ReadRune() (rune, error) 35 | Size() (int, int, error) 36 | Close() error 37 | } 38 | 39 | func apply( 40 | scope *Scope, 41 | out io.Writer, 42 | tty genericTTY, 43 | filePath, filePathRaw string, 44 | lineNum int, 45 | ) error { 46 | if scope.Files == nil { 47 | scope.Files = map[string]*ast.File{} 48 | } 49 | 50 | if err := scope.ConfigureTypes(filePath, lineNum); err != nil { 51 | return err 52 | } 53 | 54 | displayFilePosition(out, filePathRaw, filePath, lineNum) 55 | 56 | history, err := NewHistory() 57 | if err != nil { 58 | fmt.Errorf("Failed to initiliaze history %+v", err) 59 | } 60 | if err := history.Load(); err != nil { 61 | fmt.Errorf("Failed to load the history %+v", err) 62 | } 63 | 64 | currentPos := history.Len() 65 | 66 | line := "" 67 | count := history.Len() 68 | index := 0 69 | r := rune(0) 70 | for { 71 | prompt := fmt.Sprintf("[%d] go-pry> ", currentPos) 72 | fmt.Fprintf(out, "\r\033[K%s%s \033[0J\033[%dD", prompt, Highlight(line), len(line)-index+1) 73 | 74 | promptWidth := len(prompt) + index 75 | displaySuggestions(scope, out, tty, line, index, promptWidth) 76 | 77 | bPrev := r 78 | 79 | r = 0 80 | for r == 0 { 81 | var err error 82 | r, err = tty.ReadRune() 83 | if err != nil { 84 | return err 85 | } 86 | } 87 | switch r { 88 | default: 89 | if bPrev == 27 && r == 91 { 90 | continue 91 | } else if bPrev == 91 { 92 | switch r { 93 | case 66: // Down 94 | currentPos++ 95 | if history.Len() < currentPos { 96 | currentPos = history.Len() 97 | } 98 | if history.Len() == currentPos { 99 | line = "" 100 | } else { 101 | line = history.Records[currentPos] 102 | } 103 | index = len(line) 104 | case 65: // Up 105 | currentPos-- 106 | if currentPos < 0 { 107 | currentPos = 0 108 | } 109 | if history.Len() > 0 { 110 | line = 
history.Records[currentPos] 111 | } 112 | index = len(line) 113 | case 67: // Right 114 | index++ 115 | if index > len(line) { 116 | index = len(line) 117 | } 118 | case 68: // Left 119 | index-- 120 | if index < 0 { 121 | index = 0 122 | } 123 | } 124 | continue 125 | } else if bPrev == 51 && r == 126 { // DELETE 126 | if len(line) > 0 && index < len(line) { 127 | line = line[:index] + line[index+1:] 128 | } 129 | if index > len(line) { 130 | index = len(line) 131 | } 132 | continue 133 | } 134 | line = line[:index] + string(r) + line[index:] 135 | index++ 136 | case 127, '\b': // Backspace 137 | if len(line) > 0 && index > 0 { 138 | line = line[:index-1] + line[index:] 139 | index-- 140 | } 141 | if index > len(line) { 142 | index = len(line) 143 | } 144 | case 27: // ? This happens on key press 145 | case 9: //TAB 146 | case 10, 13: //ENTER 147 | fmt.Fprintln(out, "\033[100000C\033[0J") 148 | if len(line) == 0 { 149 | continue 150 | } 151 | if line == "continue" || line == "exit" { 152 | return nil 153 | } 154 | resp, err := scope.InterpretString(line) 155 | if err != nil { 156 | fmt.Fprintln(out, "Error: ", err, resp) 157 | } else { 158 | respStr := Highlight(fmt.Sprintf("%#v", resp)) 159 | fmt.Fprintf(out, "=> %s\n", respStr) 160 | } 161 | history.Add(line) 162 | err = history.Save() 163 | if err != nil { 164 | fmt.Fprintln(out, "Error: ", err) 165 | } 166 | 167 | count++ 168 | currentPos = count 169 | line = "" 170 | index = 0 171 | case 4: // Ctrl-D 172 | fmt.Fprintln(out) 173 | return nil 174 | } 175 | } 176 | } 177 | 178 | func displayFilePosition( 179 | out io.Writer, filePathRaw, filePath string, lineNum int, 180 | ) { 181 | fmt.Fprintf(out, "\nFrom %s @ line %d :\n\n", filePathRaw, lineNum) 182 | file, err := readFile(filePath) 183 | if err != nil { 184 | fmt.Fprintln(out, err) 185 | } 186 | lines := strings.Split((string)(file), "\n") 187 | lineNum-- 188 | start := lineNum - 5 189 | if start < 0 { 190 | start = 0 191 | } 192 | end := lineNum + 6 193 | 
if end > len(lines) { 194 | end = len(lines) 195 | } 196 | maxLen := len(fmt.Sprint(end)) 197 | for i := start; i < end; i++ { 198 | caret := " " 199 | if i == lineNum { 200 | caret = "=>" 201 | } 202 | numStr := fmt.Sprint(i + 1) 203 | if len(numStr) < maxLen { 204 | numStr = " " + numStr 205 | } 206 | num := ansi.Color(numStr, "blue+b") 207 | highlightedLine := Highlight(strings.Replace(lines[i], "\t", " ", -1)) 208 | fmt.Fprintf(out, " %s %s: %s\n", caret, num, highlightedLine) 209 | } 210 | fmt.Fprintln(out) 211 | } 212 | 213 | // displaySuggestions renders the live autocomplete from GoCode. 214 | func displaySuggestions( 215 | scope *Scope, 216 | out io.Writer, 217 | tty genericTTY, 218 | line string, 219 | index, promptWidth int, 220 | ) { 221 | var err error 222 | var suggestions []string 223 | if runtime.GOOS == "js" { 224 | suggestions, err = scope.SuggestionsPry(line, index) 225 | } else { 226 | suggestions, err = scope.SuggestionsGoCode(line, index) 227 | } 228 | if err != nil { 229 | suggestions = []string{"ERR", err.Error()} 230 | } 231 | 232 | maxLength := 0 233 | if len(suggestions) > 10 { 234 | suggestions = suggestions[:10] 235 | } 236 | for _, term := range suggestions { 237 | if len(term) > maxLength { 238 | maxLength = len(term) 239 | } 240 | } 241 | termWidth, _, _ := tty.Size() 242 | for _, term := range suggestions { 243 | paddedTerm := term 244 | for len(paddedTerm) < maxLength { 245 | paddedTerm += " " 246 | } 247 | var leftPadding string 248 | for i := 0; i < promptWidth; i++ { 249 | leftPadding += " " 250 | } 251 | if promptWidth > termWidth { 252 | return 253 | } else if len(paddedTerm)+promptWidth > termWidth { 254 | paddedTerm = paddedTerm[:termWidth-promptWidth] 255 | } 256 | fmt.Fprintf(out, "\n%s%s\033[%dD", leftPadding, ansi.Color(paddedTerm, "white+b:magenta"), len(paddedTerm)) 257 | } 258 | if len(suggestions) > 0 { 259 | fmt.Fprintf(out, "\033[%dA", len(suggestions)) 260 | } 261 | } 262 | 
-------------------------------------------------------------------------------- /pry/pry_test.go: -------------------------------------------------------------------------------- 1 | package pry 2 | 3 | import ( 4 | "io" 5 | "io/ioutil" 6 | "log" 7 | "os" 8 | "path" 9 | "reflect" 10 | "testing" 11 | "time" 12 | 13 | "github.com/cenkalti/backoff" 14 | "github.com/d4l3k/go-pry/pry/safebuffer" 15 | "github.com/pkg/errors" 16 | ) 17 | 18 | func TestCLIBasicStatement(t *testing.T) { 19 | t.Parallel() 20 | 21 | env := testPryApply(t) 22 | defer env.Close() 23 | 24 | env.Write([]byte("a := 10\n")) 25 | 26 | succeedsSoon(t, func() error { 27 | out, _ := env.Get("a") 28 | want := 10 29 | if !reflect.DeepEqual(out, want) { 30 | return errors.Errorf( 31 | "expected a = %d; got %d\nOutput:\n%s\n", want, out, env.Output()) 32 | } 33 | return nil 34 | }) 35 | } 36 | 37 | func TestCLIHistory(t *testing.T) { 38 | t.Parallel() 39 | 40 | env := testPryApply(t) 41 | defer env.Close() 42 | 43 | env.Write([]byte("var a int\na = 1\na = 2\na = 3\n")) 44 | // down down up up up down enter 45 | env.Write([]byte("\x1b\x5b\x42\x1b\x5b\x42\x1b\x5b\x41\x1b\x5b\x41\x1b\x5b\x41\x1b\x5b\x42\n")) 46 | 47 | succeedsSoon(t, func() error { 48 | out, _ := env.Get("a") 49 | want := 2 50 | if !reflect.DeepEqual(out, want) { 51 | return errors.Errorf( 52 | "expected a = %d; got %d\nOutput:\n%s\n", want, out, env.Output()) 53 | } 54 | return nil 55 | }) 56 | } 57 | 58 | func TestCLIEditingArrows(t *testing.T) { 59 | t.Parallel() 60 | 61 | env := testPryApply(t) 62 | defer env.Close() 63 | 64 | env.Write([]byte("a := 100")) 65 | // left left backspace 2 right right 5 enter 66 | env.Write([]byte("\x1b\x5b\x44\x1b\x5b\x44\b2\x1b\x5b\x43\x1b\x5b\x435\n")) 67 | 68 | succeedsSoon(t, func() error { 69 | out, _ := env.Get("a") 70 | want := 2005 71 | if !reflect.DeepEqual(out, want) { 72 | return errors.Errorf( 73 | "expected a = %d; got %d\nOutput:\n%s\n", want, out, env.Output()) 74 | } 75 | return nil 76 | }) 
77 | } 78 | 79 | type testTTY struct { 80 | *io.PipeReader 81 | *io.PipeWriter 82 | } 83 | 84 | func makeTestTTY() *testTTY { 85 | r, w := io.Pipe() 86 | return &testTTY{r, w} 87 | } 88 | 89 | func (t *testTTY) ReadRune() (rune, error) { 90 | buf := make([]byte, 1) 91 | _, err := t.PipeReader.Read(buf) 92 | return rune(buf[0]), err 93 | } 94 | 95 | func (t *testTTY) Size() (int, int, error) { 96 | return 10000, 100, nil 97 | } 98 | 99 | func (t *testTTY) Close() error { 100 | t.PipeReader.Close() 101 | return t.PipeWriter.Close() 102 | } 103 | 104 | type testPryEnv struct { 105 | stdout *safebuffer.Buffer 106 | *testTTY 107 | *Scope 108 | dir, file string 109 | } 110 | 111 | func testPryApply(t testing.TB) *testPryEnv { 112 | var stdout safebuffer.Buffer 113 | tty := makeTestTTY() 114 | scope := NewScope() 115 | 116 | wd, err := os.Getwd() 117 | if err != nil { 118 | t.Fatal(err) 119 | } 120 | log.Printf("cwd %+v", wd) 121 | 122 | dir, err := ioutil.TempDir(wd, "go-pry-test") 123 | if err != nil { 124 | t.Fatal(err) 125 | } 126 | file, err := os.OpenFile( 127 | path.Join(dir, "main.go"), 128 | os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 129 | 0755, 130 | ) 131 | if err != nil { 132 | t.Fatal(err) 133 | } 134 | if _, err := file.Write([]byte( 135 | `package main 136 | 137 | import "github.com/d4l3k/go-pry/pry" 138 | 139 | func main() { 140 | pry.Pry() 141 | } 142 | `, 143 | )); err != nil { 144 | t.Fatal(err) 145 | } 146 | file.Close() 147 | 148 | filePath := file.Name() 149 | lineNum := 2 150 | 151 | go func() { 152 | if err := apply(scope, &stdout, tty, filePath, filePath, lineNum); err != nil { 153 | log.Fatalf("%+v", err) 154 | } 155 | }() 156 | 157 | return &testPryEnv{ 158 | stdout: &stdout, 159 | testTTY: tty, 160 | Scope: scope, 161 | dir: dir, 162 | file: filePath, 163 | } 164 | } 165 | 166 | func (env *testPryEnv) Output() string { 167 | return env.stdout.String() 168 | } 169 | 170 | func (env *testPryEnv) Close() { 171 | env.Write([]byte("\nexit\n")) 172 | 
env.testTTY.Close() 173 | os.RemoveAll(env.file) 174 | os.RemoveAll(env.dir) 175 | } 176 | 177 | func succeedsSoon(t testing.TB, f func() error) { 178 | b := backoff.NewExponentialBackOff() 179 | b.MaxElapsedTime = 5 * time.Second 180 | if err := backoff.Retry(f, b); err != nil { 181 | t.Fatal(errors.Wrapf(err, "failed after 5 seconds")) 182 | } 183 | } 184 | -------------------------------------------------------------------------------- /pry/pseudo_generics.go: -------------------------------------------------------------------------------- 1 | package pry 2 | 3 | import ( 4 | "fmt" 5 | "go/token" 6 | "reflect" 7 | 8 | "github.com/pkg/errors" 9 | ) 10 | 11 | // ErrChanRecvFailed occurs when a channel is closed. 12 | var ErrChanRecvFailed = errors.New("receive failed: channel closed") 13 | 14 | // ErrChanRecvInSelect is an internal error that is used to indicate it's in a 15 | // select statement. 16 | var ErrChanRecvInSelect = errors.New("receive failed: in select") 17 | 18 | var ErrDivisionByZero = errors.New("division by zero") 19 | 20 | // DeAssign takes a *_ASSIGN token and returns the corresponding * token. 21 | func DeAssign(tok token.Token) token.Token { 22 | switch tok { 23 | case token.ADD_ASSIGN: 24 | return token.ADD 25 | case token.SUB_ASSIGN: 26 | return token.SUB 27 | case token.MUL_ASSIGN: 28 | return token.MUL 29 | case token.QUO_ASSIGN: 30 | return token.QUO 31 | case token.REM_ASSIGN: 32 | return token.REM 33 | case token.AND_ASSIGN: 34 | return token.AND 35 | case token.OR_ASSIGN: 36 | return token.OR 37 | case token.XOR_ASSIGN: 38 | return token.XOR 39 | case token.SHL_ASSIGN: 40 | return token.SHL 41 | case token.SHR_ASSIGN: 42 | return token.SHR 43 | case token.AND_NOT_ASSIGN: 44 | return token.AND_NOT 45 | } 46 | 47 | return tok 48 | } 49 | 50 | // ComputeBinaryOp executes the corresponding binary operation (+, -, etc) on two interfaces. 
51 | func ComputeBinaryOp(xI, yI interface{}, op token.Token) (interface{}, error) { 52 | typeX := reflect.TypeOf(xI) 53 | typeY := reflect.TypeOf(yI) 54 | if typeX == typeY { 55 | switch xI.(type) { 56 | case string: 57 | x := xI.(string) 58 | y := yI.(string) 59 | switch op { 60 | case token.ADD: 61 | return x + y, nil 62 | } 63 | case int: 64 | x := xI.(int) 65 | y := yI.(int) 66 | switch op { 67 | case token.ADD: 68 | return x + y, nil 69 | case token.SUB: 70 | return x - y, nil 71 | case token.MUL: 72 | return x * y, nil 73 | case token.QUO: 74 | if y == 0 { 75 | return nil, ErrDivisionByZero 76 | } 77 | 78 | return x / y, nil 79 | case token.REM: 80 | if y == 0 { 81 | return nil, ErrDivisionByZero 82 | } 83 | 84 | return x % y, nil 85 | case token.AND: 86 | return x & y, nil 87 | case token.OR: 88 | return x | y, nil 89 | case token.XOR: 90 | return x ^ y, nil 91 | case token.AND_NOT: 92 | return x &^ y, nil 93 | case token.LSS: 94 | return x < y, nil 95 | case token.GTR: 96 | return x > y, nil 97 | case token.LEQ: 98 | return x <= y, nil 99 | case token.GEQ: 100 | return x >= y, nil 101 | } 102 | case int8: 103 | x := xI.(int8) 104 | y := yI.(int8) 105 | switch op { 106 | case token.ADD: 107 | return x + y, nil 108 | case token.SUB: 109 | return x - y, nil 110 | case token.MUL: 111 | return x * y, nil 112 | case token.QUO: 113 | if y == 0 { 114 | return nil, ErrDivisionByZero 115 | } 116 | 117 | return x / y, nil 118 | case token.REM: 119 | if y == 0 { 120 | return nil, ErrDivisionByZero 121 | } 122 | 123 | return x % y, nil 124 | case token.AND: 125 | return x & y, nil 126 | case token.OR: 127 | return x | y, nil 128 | case token.XOR: 129 | return x ^ y, nil 130 | case token.AND_NOT: 131 | return x &^ y, nil 132 | case token.LSS: 133 | return x < y, nil 134 | case token.GTR: 135 | return x > y, nil 136 | case token.LEQ: 137 | return x <= y, nil 138 | case token.GEQ: 139 | return x >= y, nil 140 | } 141 | case int16: 142 | x := xI.(int16) 143 | y := 
yI.(int16) 144 | switch op { 145 | case token.ADD: 146 | return x + y, nil 147 | case token.SUB: 148 | return x - y, nil 149 | case token.MUL: 150 | return x * y, nil 151 | case token.QUO: 152 | if y == 0 { 153 | return nil, ErrDivisionByZero 154 | } 155 | 156 | return x / y, nil 157 | case token.REM: 158 | if y == 0 { 159 | return nil, ErrDivisionByZero 160 | } 161 | 162 | return x % y, nil 163 | case token.AND: 164 | return x & y, nil 165 | case token.OR: 166 | return x | y, nil 167 | case token.XOR: 168 | return x ^ y, nil 169 | case token.AND_NOT: 170 | return x &^ y, nil 171 | case token.LSS: 172 | return x < y, nil 173 | case token.GTR: 174 | return x > y, nil 175 | case token.LEQ: 176 | return x <= y, nil 177 | case token.GEQ: 178 | return x >= y, nil 179 | } 180 | case int32: 181 | x := xI.(int32) 182 | y := yI.(int32) 183 | switch op { 184 | case token.ADD: 185 | return x + y, nil 186 | case token.SUB: 187 | return x - y, nil 188 | case token.MUL: 189 | return x * y, nil 190 | case token.QUO: 191 | if y == 0 { 192 | return nil, ErrDivisionByZero 193 | } 194 | 195 | return x / y, nil 196 | case token.REM: 197 | if y == 0 { 198 | return nil, ErrDivisionByZero 199 | } 200 | 201 | return x % y, nil 202 | case token.AND: 203 | return x & y, nil 204 | case token.OR: 205 | return x | y, nil 206 | case token.XOR: 207 | return x ^ y, nil 208 | case token.AND_NOT: 209 | return x &^ y, nil 210 | case token.LSS: 211 | return x < y, nil 212 | case token.GTR: 213 | return x > y, nil 214 | case token.LEQ: 215 | return x <= y, nil 216 | case token.GEQ: 217 | return x >= y, nil 218 | } 219 | case int64: 220 | x := xI.(int64) 221 | y := yI.(int64) 222 | switch op { 223 | case token.ADD: 224 | return x + y, nil 225 | case token.SUB: 226 | return x - y, nil 227 | case token.MUL: 228 | return x * y, nil 229 | case token.QUO: 230 | if y == 0 { 231 | return nil, ErrDivisionByZero 232 | } 233 | 234 | return x / y, nil 235 | case token.REM: 236 | if y == 0 { 237 | return nil, 
ErrDivisionByZero 238 | } 239 | 240 | return x % y, nil 241 | case token.AND: 242 | return x & y, nil 243 | case token.OR: 244 | return x | y, nil 245 | case token.XOR: 246 | return x ^ y, nil 247 | case token.AND_NOT: 248 | return x &^ y, nil 249 | case token.LSS: 250 | return x < y, nil 251 | case token.GTR: 252 | return x > y, nil 253 | case token.LEQ: 254 | return x <= y, nil 255 | case token.GEQ: 256 | return x >= y, nil 257 | } 258 | case uint: 259 | x := xI.(uint) 260 | y := yI.(uint) 261 | switch op { 262 | case token.ADD: 263 | return x + y, nil 264 | case token.SUB: 265 | return x - y, nil 266 | case token.MUL: 267 | return x * y, nil 268 | case token.QUO: 269 | if y == 0 { 270 | return nil, ErrDivisionByZero 271 | } 272 | 273 | return x / y, nil 274 | case token.REM: 275 | if y == 0 { 276 | return nil, ErrDivisionByZero 277 | } 278 | 279 | return x % y, nil 280 | case token.AND: 281 | return x & y, nil 282 | case token.OR: 283 | return x | y, nil 284 | case token.XOR: 285 | return x ^ y, nil 286 | case token.AND_NOT: 287 | return x &^ y, nil 288 | case token.LSS: 289 | return x < y, nil 290 | case token.GTR: 291 | return x > y, nil 292 | case token.LEQ: 293 | return x <= y, nil 294 | case token.GEQ: 295 | return x >= y, nil 296 | } 297 | case uint8: 298 | x := xI.(uint8) 299 | y := yI.(uint8) 300 | switch op { 301 | case token.ADD: 302 | return x + y, nil 303 | case token.SUB: 304 | return x - y, nil 305 | case token.MUL: 306 | return x * y, nil 307 | case token.QUO: 308 | if y == 0 { 309 | return nil, ErrDivisionByZero 310 | } 311 | 312 | return x / y, nil 313 | case token.REM: 314 | if y == 0 { 315 | return nil, ErrDivisionByZero 316 | } 317 | 318 | return x % y, nil 319 | case token.AND: 320 | return x & y, nil 321 | case token.OR: 322 | return x | y, nil 323 | case token.XOR: 324 | return x ^ y, nil 325 | case token.AND_NOT: 326 | return x &^ y, nil 327 | case token.LSS: 328 | return x < y, nil 329 | case token.GTR: 330 | return x > y, nil 331 | case 
token.LEQ: 332 | return x <= y, nil 333 | case token.GEQ: 334 | return x >= y, nil 335 | } 336 | case uint16: 337 | x := xI.(uint16) 338 | y := yI.(uint16) 339 | switch op { 340 | case token.ADD: 341 | return x + y, nil 342 | case token.SUB: 343 | return x - y, nil 344 | case token.MUL: 345 | return x * y, nil 346 | case token.QUO: 347 | if y == 0 { 348 | return nil, ErrDivisionByZero 349 | } 350 | 351 | return x / y, nil 352 | case token.REM: 353 | if y == 0 { 354 | return nil, ErrDivisionByZero 355 | } 356 | 357 | return x % y, nil 358 | case token.AND: 359 | return x & y, nil 360 | case token.OR: 361 | return x | y, nil 362 | case token.XOR: 363 | return x ^ y, nil 364 | case token.AND_NOT: 365 | return x &^ y, nil 366 | case token.LSS: 367 | return x < y, nil 368 | case token.GTR: 369 | return x > y, nil 370 | case token.LEQ: 371 | return x <= y, nil 372 | case token.GEQ: 373 | return x >= y, nil 374 | } 375 | case uint32: 376 | x := xI.(uint32) 377 | y := yI.(uint32) 378 | switch op { 379 | case token.ADD: 380 | return x + y, nil 381 | case token.SUB: 382 | return x - y, nil 383 | case token.MUL: 384 | return x * y, nil 385 | case token.QUO: 386 | if y == 0 { 387 | return nil, ErrDivisionByZero 388 | } 389 | 390 | return x / y, nil 391 | case token.REM: 392 | if y == 0 { 393 | return nil, ErrDivisionByZero 394 | } 395 | 396 | return x % y, nil 397 | case token.AND: 398 | return x & y, nil 399 | case token.OR: 400 | return x | y, nil 401 | case token.XOR: 402 | return x ^ y, nil 403 | case token.AND_NOT: 404 | return x &^ y, nil 405 | case token.LSS: 406 | return x < y, nil 407 | case token.GTR: 408 | return x > y, nil 409 | case token.LEQ: 410 | return x <= y, nil 411 | case token.GEQ: 412 | return x >= y, nil 413 | } 414 | case uint64: 415 | x := xI.(uint64) 416 | y := yI.(uint64) 417 | switch op { 418 | case token.ADD: 419 | return x + y, nil 420 | case token.SUB: 421 | return x - y, nil 422 | case token.MUL: 423 | return x * y, nil 424 | case token.QUO: 425 
| if y == 0 { 426 | return nil, ErrDivisionByZero 427 | } 428 | 429 | return x / y, nil 430 | case token.REM: 431 | if y == 0 { 432 | return nil, ErrDivisionByZero 433 | } 434 | 435 | return x % y, nil 436 | case token.AND: 437 | return x & y, nil 438 | case token.OR: 439 | return x | y, nil 440 | case token.XOR: 441 | return x ^ y, nil 442 | case token.AND_NOT: 443 | return x &^ y, nil 444 | case token.LSS: 445 | return x < y, nil 446 | case token.GTR: 447 | return x > y, nil 448 | case token.LEQ: 449 | return x <= y, nil 450 | case token.GEQ: 451 | return x >= y, nil 452 | } 453 | case uintptr: 454 | x := xI.(uintptr) 455 | y := yI.(uintptr) 456 | switch op { 457 | case token.ADD: 458 | return x + y, nil 459 | case token.SUB: 460 | return x - y, nil 461 | case token.MUL: 462 | return x * y, nil 463 | case token.QUO: 464 | if y == 0 { 465 | return nil, ErrDivisionByZero 466 | } 467 | 468 | return x / y, nil 469 | case token.REM: 470 | if y == 0 { 471 | return nil, ErrDivisionByZero 472 | } 473 | 474 | return x % y, nil 475 | case token.AND: 476 | return x & y, nil 477 | case token.OR: 478 | return x | y, nil 479 | case token.XOR: 480 | return x ^ y, nil 481 | case token.AND_NOT: 482 | return x &^ y, nil 483 | case token.LSS: 484 | return x < y, nil 485 | case token.GTR: 486 | return x > y, nil 487 | case token.LEQ: 488 | return x <= y, nil 489 | case token.GEQ: 490 | return x >= y, nil 491 | } 492 | case complex64: 493 | x := xI.(complex64) 494 | y := yI.(complex64) 495 | switch op { 496 | case token.ADD: 497 | return x + y, nil 498 | case token.SUB: 499 | return x - y, nil 500 | case token.MUL: 501 | return x * y, nil 502 | case token.QUO: 503 | if y == 0 { 504 | return nil, ErrDivisionByZero 505 | } 506 | 507 | return x / y, nil 508 | } 509 | case complex128: 510 | x := xI.(complex128) 511 | y := yI.(complex128) 512 | switch op { 513 | case token.ADD: 514 | return x + y, nil 515 | case token.SUB: 516 | return x - y, nil 517 | case token.MUL: 518 | return x * y, 
nil 519 | case token.QUO: 520 | if y == 0 { 521 | return nil, ErrDivisionByZero 522 | } 523 | 524 | return x / y, nil 525 | } 526 | case float32: 527 | x := xI.(float32) 528 | y := yI.(float32) 529 | switch op { 530 | case token.ADD: 531 | return x + y, nil 532 | case token.SUB: 533 | return x - y, nil 534 | case token.MUL: 535 | return x * y, nil 536 | case token.QUO: 537 | if y == 0 { 538 | return nil, ErrDivisionByZero 539 | } 540 | 541 | return x / y, nil 542 | case token.LSS: 543 | return x < y, nil 544 | case token.GTR: 545 | return x > y, nil 546 | case token.LEQ: 547 | return x <= y, nil 548 | case token.GEQ: 549 | return x >= y, nil 550 | } 551 | case float64: 552 | x := xI.(float64) 553 | y := yI.(float64) 554 | switch op { 555 | case token.ADD: 556 | return x + y, nil 557 | case token.SUB: 558 | return x - y, nil 559 | case token.MUL: 560 | return x * y, nil 561 | case token.QUO: 562 | if y == 0 { 563 | return nil, ErrDivisionByZero 564 | } 565 | 566 | return x / y, nil 567 | case token.LSS: 568 | return x < y, nil 569 | case token.GTR: 570 | return x > y, nil 571 | case token.LEQ: 572 | return x <= y, nil 573 | case token.GEQ: 574 | return x >= y, nil 575 | } 576 | case bool: 577 | x := xI.(bool) 578 | y := yI.(bool) 579 | switch op { 580 | // Bool 581 | case token.LAND: 582 | return x && y, nil 583 | case token.LOR: 584 | return x || y, nil 585 | } 586 | } 587 | } 588 | yUint, isUint := yI.(uint64) 589 | if !isUint { 590 | isUint = true 591 | switch yV := yI.(type) { 592 | case int: 593 | yUint = uint64(yV) 594 | case int8: 595 | yUint = uint64(yV) 596 | case int16: 597 | yUint = uint64(yV) 598 | case int32: 599 | yUint = uint64(yV) 600 | case int64: 601 | yUint = uint64(yV) 602 | case uint: 603 | yUint = uint64(yV) 604 | case uintptr: 605 | yUint = uint64(yV) 606 | case uint8: 607 | yUint = uint64(yV) 608 | case uint16: 609 | yUint = uint64(yV) 610 | case uint32: 611 | yUint = uint64(yV) 612 | case float32: 613 | yUint = uint64(yV) 614 | case float64: 
615 | yUint = uint64(yV) 616 | default: 617 | isUint = false 618 | } 619 | } 620 | if isUint { 621 | switch xI.(type) { 622 | case int: 623 | x := xI.(int) 624 | switch op { 625 | // Num, uint 626 | case token.SHL: 627 | return x << yUint, nil 628 | case token.SHR: 629 | return x >> yUint, nil 630 | } 631 | case int8: 632 | x := xI.(int8) 633 | switch op { 634 | // Num, uint 635 | case token.SHL: 636 | return x << yUint, nil 637 | case token.SHR: 638 | return x >> yUint, nil 639 | } 640 | case int16: 641 | x := xI.(int16) 642 | switch op { 643 | // Num, uint 644 | case token.SHL: 645 | return x << yUint, nil 646 | case token.SHR: 647 | return x >> yUint, nil 648 | } 649 | case int32: 650 | x := xI.(int32) 651 | switch op { 652 | // Num, uint 653 | case token.SHL: 654 | return x << yUint, nil 655 | case token.SHR: 656 | return x >> yUint, nil 657 | } 658 | case int64: 659 | x := xI.(int64) 660 | switch op { 661 | // Num, uint 662 | case token.SHL: 663 | return x << yUint, nil 664 | case token.SHR: 665 | return x >> yUint, nil 666 | } 667 | case uint: 668 | x := xI.(uint) 669 | switch op { 670 | // Num, uint 671 | case token.SHL: 672 | return x << yUint, nil 673 | case token.SHR: 674 | return x >> yUint, nil 675 | } 676 | case uint8: 677 | x := xI.(uint8) 678 | switch op { 679 | // Num, uint 680 | case token.SHL: 681 | return x << yUint, nil 682 | case token.SHR: 683 | return x >> yUint, nil 684 | } 685 | case uint16: 686 | x := xI.(uint16) 687 | switch op { 688 | // Num, uint 689 | case token.SHL: 690 | return x << yUint, nil 691 | case token.SHR: 692 | return x >> yUint, nil 693 | } 694 | case uint32: 695 | x := xI.(uint32) 696 | switch op { 697 | // Num, uint 698 | case token.SHL: 699 | return x << yUint, nil 700 | case token.SHR: 701 | return x >> yUint, nil 702 | } 703 | case uint64: 704 | x := xI.(uint64) 705 | switch op { 706 | // Num, uint 707 | case token.SHL: 708 | return x << yUint, nil 709 | case token.SHR: 710 | return x >> yUint, nil 711 | } 712 | case 
uintptr: 713 | x := xI.(uintptr) 714 | switch op { 715 | // Num, uint 716 | case token.SHL: 717 | return x << yUint, nil 718 | case token.SHR: 719 | return x >> yUint, nil 720 | } 721 | } 722 | } 723 | // Anything 724 | switch op { 725 | case token.EQL: 726 | return xI == yI, nil 727 | case token.NEQ: 728 | return xI != yI, nil 729 | } 730 | return nil, fmt.Errorf("unknown operation %#v between %#v and %#v", op, xI, yI) 731 | } 732 |
// ComputeUnaryOp computes a unary operation on an interface value:
// pointer indirection (*x), logical negation (!x), unary plus (+x),
// negation (-x), bitwise complement (^x, integer types only) and channel
// receive (<-ch).
//
// It returns an error, rather than panicking inside reflect, when the
// operand is nil, when *x is applied to something that is not a non-nil
// pointer, or when receiving from a send-only channel. Unsupported
// operator/operand combinations yield an "unknown unary operation" error.
func (scope *Scope) ComputeUnaryOp(xI interface{}, op token.Token) (interface{}, error) {
	if xI == nil {
		return nil, errors.Errorf("can't run unary ops on nil value")
	}

	if op == token.MUL {
		ptr := reflect.ValueOf(xI)
		// Value.Elem panics for non-pointer kinds, and Interface panics on the
		// invalid Value that Elem returns for a nil pointer.
		if ptr.Kind() != reflect.Ptr {
			return nil, errors.Errorf("invalid indirect of %#v (type %T)", xI, xI)
		}
		if ptr.IsNil() {
			return nil, errors.Errorf("invalid memory address or nil pointer dereference")
		}
		return ptr.Elem().Interface(), nil
	}

	switch x := xI.(type) {
	case bool:
		if op == token.NOT {
			return !x, nil
		}
	case int:
		switch op {
		case token.ADD:
			return +x, nil
		case token.SUB:
			return -x, nil
		case token.XOR:
			return ^x, nil
		}
	case int8:
		switch op {
		case token.ADD:
			return +x, nil
		case token.SUB:
			return -x, nil
		case token.XOR:
			return ^x, nil
		}
	case int16:
		switch op {
		case token.ADD:
			return +x, nil
		case token.SUB:
			return -x, nil
		case token.XOR:
			return ^x, nil
		}
	case int32:
		switch op {
		case token.ADD:
			return +x, nil
		case token.SUB:
			return -x, nil
		case token.XOR:
			return ^x, nil
		}
	case int64:
		switch op {
		case token.ADD:
			return +x, nil
		case token.SUB:
			return -x, nil
		case token.XOR:
			return ^x, nil
		}
	case uint:
		switch op {
		case token.ADD:
			return +x, nil
		case token.SUB:
			return -x, nil
		case token.XOR:
			return ^x, nil
		}
	case uint8:
		switch op {
		case token.ADD:
			return +x, nil
		case token.SUB:
			return -x, nil
		case token.XOR:
			return ^x, nil
		}
	case uint16:
		switch op {
		case token.ADD:
			return +x, nil
		case token.SUB:
			return -x, nil
		case token.XOR:
			return ^x, nil
		}
	case uint32:
		switch op {
		case token.ADD:
			return +x, nil
		case token.SUB:
			return -x, nil
		case token.XOR:
			return ^x, nil
		}
	case uint64:
		switch op {
		case token.ADD:
			return +x, nil
		case token.SUB:
			return -x, nil
		case token.XOR:
			return ^x, nil
		}
	case uintptr:
		switch op {
		case token.ADD:
			return +x, nil
		case token.SUB:
			return -x, nil
		case token.XOR:
			return ^x, nil
		}
	case float32:
		switch op {
		case token.ADD:
			return +x, nil
		case token.SUB:
			return -x, nil
		}
	case float64:
		switch op {
		case token.ADD:
			return +x, nil
		case token.SUB:
			return -x, nil
		}
	case complex64:
		switch op {
		case token.ADD:
			return +x, nil
		case token.SUB:
			return -x, nil
		}
	case complex128:
		switch op {
		case token.ADD:
			return +x, nil
		case token.SUB:
			return -x, nil
		}
	}

	if op == token.ARROW && reflect.TypeOf(xI).Kind() == reflect.Chan {
		ch := reflect.ValueOf(xI)
		// Value.Recv/TryRecv panic on a send-only channel.
		if ch.Type().ChanDir()&reflect.RecvDir == 0 {
			return nil, errors.Errorf("invalid operation: cannot receive from send-only channel %#v", xI)
		}
		var v reflect.Value
		var ok bool
		if scope.isSelect {
			// TryRecv returns an invalid Value when the receive would block;
			// a valid zero value with ok == false means the channel is closed.
			v, ok = ch.TryRecv()
			if !ok && !v.IsValid() {
				return nil, ErrChanRecvInSelect
			}
		} else {
			v, ok = ch.Recv()
		}
		if !ok {
			// The channel is closed.
			return nil, ErrChanRecvFailed
		}
		return v.Interface(), nil
	}

	return nil, fmt.Errorf("unknown unary operation %#v on %#v", op, xI)
}
-------------------------------------------------------------------------------- /pry/safebuffer/safebuffer.go:
-------------------------------------------------------------------------------- 1 | // safebuffer is a goroutine safe bytes.Buffer. 2 | // From https://gist.github.com/arkan/5924e155dbb4254b64614069ba0afd81 3 | package safebuffer 4 | 5 | import ( 6 | "bytes" 7 | "sync" 8 | ) 9 | 10 | // Buffer is a goroutine safe bytes.Buffer 11 | type Buffer struct { 12 | buffer bytes.Buffer 13 | mutex sync.Mutex 14 | } 15 | 16 | // Write appends the contents of p to the buffer, growing the buffer as needed. 17 | // It returns 18 | // the number of bytes written. 19 | func (s *Buffer) Write(p []byte) (n int, err error) { 20 | s.mutex.Lock() 21 | defer s.mutex.Unlock() 22 | return s.buffer.Write(p) 23 | } 24 | 25 | // String returns the contents of the unread portion of the buffer 26 | // as a string. If the Buffer is a nil pointer, it returns "". 27 | func (s *Buffer) String() string { 28 | s.mutex.Lock() 29 | defer s.mutex.Unlock() 30 | return s.buffer.String() 31 | } 32 | -------------------------------------------------------------------------------- /pry/suggestions.go: -------------------------------------------------------------------------------- 1 | package pry 2 | 3 | import ( 4 | "reflect" 5 | "regexp" 6 | "sort" 7 | "strings" 8 | ) 9 | 10 | var suggestionsRegexp = regexp.MustCompile("[.0-9a-zA-Z]+$") 11 | 12 | func (s *Scope) SuggestionsPry(line string, index int) ([]string, error) { 13 | text := line[:index] 14 | wip := suggestionsRegexp.FindString(text) 15 | 16 | if len(wip) == 0 { 17 | return nil, nil 18 | } 19 | 20 | var ok bool 21 | v := interface{}(s) 22 | parts := strings.Split(wip, ".") 23 | for _, k := range parts[:len(parts)-1] { 24 | v, ok = get(v, k) 25 | if !ok { 26 | return nil, nil 27 | } 28 | } 29 | 30 | partial := parts[len(parts)-1] 31 | 32 | var matchingKeys []string 33 | for _, key := range keys(v) { 34 | if strings.HasPrefix(key, partial) { 35 | matchingKeys = append(matchingKeys, key) 36 | } 37 | } 38 | 39 | sort.Strings(matchingKeys) 40 | 41 | 
return matchingKeys, nil 42 | } 43 | 44 | type keyser interface { 45 | Keys() []string 46 | } 47 | 48 | type getter interface { 49 | Get(string) (interface{}, bool) 50 | } 51 | 52 | func get(v interface{}, key string) (interface{}, bool) { 53 | if v == nil { 54 | return nil, false 55 | } 56 | 57 | g, ok := v.(getter) 58 | if ok { 59 | return g.Get(key) 60 | } 61 | 62 | val := reflect.ValueOf(v) 63 | 64 | typ := val.Type() 65 | switch typ.Kind() { 66 | case reflect.Ptr: 67 | return get(val.Elem().Addr(), key) 68 | 69 | case reflect.Struct: 70 | if _, ok := typ.FieldByName(key); !ok { 71 | return nil, false 72 | } 73 | return val.FieldByName(key).Interface(), true 74 | } 75 | 76 | return nil, false 77 | } 78 | 79 | func keys(v interface{}) []string { 80 | if v == nil { 81 | return nil 82 | } 83 | 84 | g, ok := v.(keyser) 85 | if ok { 86 | return g.Keys() 87 | } 88 | 89 | val := reflect.ValueOf(v) 90 | 91 | typ := val.Type() 92 | switch typ.Kind() { 93 | case reflect.Ptr: 94 | return keys(val.Elem().Addr()) 95 | 96 | case reflect.Struct: 97 | var keys []string 98 | for i := 0; i < typ.NumField(); i++ { 99 | keys = append(keys, typ.Field(i).Name) 100 | } 101 | for i := 0; i < typ.NumMethod(); i++ { 102 | keys = append(keys, typ.Method(i).Name) 103 | } 104 | return keys 105 | } 106 | 107 | return nil 108 | } 109 | -------------------------------------------------------------------------------- /pry/tty_js.go: -------------------------------------------------------------------------------- 1 | // +build js 2 | 3 | package pry 4 | 5 | import ( 6 | "io" 7 | "log" 8 | "syscall/js" 9 | ) 10 | 11 | var tty = newWASMTTY() 12 | 13 | func newWASMTTY() *wasmTTY { 14 | 15 | r, w := io.Pipe() 16 | t := &wasmTTY{ 17 | term: js.Global().Get("term"), 18 | r: r, 19 | } 20 | cb := js.FuncOf(func(this js.Value, args []js.Value) interface{} { 21 | data := args[0].String() 22 | w.Write([]byte(data)) 23 | return nil 24 | }) 25 | t.term.Call("on", "data", cb) 26 | 27 | return t 28 | } 29 | 
30 | func init() { 31 | log.SetFlags(log.Flags() | log.Lshortfile) 32 | log.SetOutput(tty) 33 | } 34 | 35 | func openTTY() (io.Writer, genericTTY) { 36 | return tty, tty 37 | } 38 | 39 | type wasmTTY struct { 40 | term js.Value 41 | r io.Reader 42 | } 43 | 44 | func (t *wasmTTY) Write(buf []byte) (int, error) { 45 | t.term.Call("write", string(buf)) 46 | return len(buf), nil 47 | } 48 | 49 | func (t *wasmTTY) ReadRune() (rune, error) { 50 | var buf [1]byte 51 | if _, err := t.r.Read(buf[:]); err != nil { 52 | return 0, err 53 | } 54 | return rune(buf[0]), nil 55 | } 56 | 57 | func (t *wasmTTY) Size() (int, int, error) { 58 | return t.term.Get("cols").Int(), t.term.Get("rows").Int(), nil 59 | } 60 | 61 | func (t *wasmTTY) Close() error { 62 | return nil 63 | } 64 | -------------------------------------------------------------------------------- /pry/tty_unix.go: -------------------------------------------------------------------------------- 1 | // +build linux darwin 2 | 3 | package pry 4 | 5 | import ( 6 | "io" 7 | "os" 8 | 9 | gotty "github.com/mattn/go-tty" 10 | ) 11 | 12 | func openTTY() (io.Writer, genericTTY) { 13 | tty, err := gotty.Open() 14 | if err != nil { 15 | panic(err) 16 | } 17 | return os.Stdout, tty 18 | } 19 | -------------------------------------------------------------------------------- /pry/tty_windows.go: -------------------------------------------------------------------------------- 1 | // +build windows 2 | 3 | package pry 4 | 5 | import ( 6 | "io" 7 | 8 | colorable "github.com/mattn/go-colorable" 9 | gotty "github.com/mattn/go-tty" 10 | ) 11 | 12 | func openTTY() (io.Writer, genericTTY) { 13 | tty, err := gotty.Open() 14 | if err != nil { 15 | panic(err) 16 | } 17 | return colorable.NewColorableStdout(), tty 18 | } 19 | -------------------------------------------------------------------------------- /pry/type.go: -------------------------------------------------------------------------------- 1 | package pry 2 | 3 | import "reflect" 4 | 5 
| // Type returns the reflect type of the passed object. 6 | func Type(t interface{}) reflect.Type { 7 | return reflect.TypeOf(t) 8 | } 9 | -------------------------------------------------------------------------------- /pry/type_test.go: -------------------------------------------------------------------------------- 1 | package pry 2 | 3 | import ( 4 | "reflect" 5 | "testing" 6 | ) 7 | 8 | func TestType(t *testing.T) { 9 | t.Parallel() 10 | 11 | a := 0 12 | out := Type(a) 13 | want := reflect.TypeOf(a) 14 | if !reflect.DeepEqual(want, out) { 15 | t.Errorf("Expected %#v got %#v.", want, out) 16 | } 17 | } 18 | --------------------------------------------------------------------------------