├── .gitignore ├── LICENSE ├── Makefile ├── README.md ├── diff └── diff.go ├── format.go ├── go.mod ├── gotopy.go ├── gotopy_test.go ├── parse.go ├── pyedits.go ├── pyprint ├── nodes.go └── printer.go ├── rewrite.go ├── simplify.go ├── testdata ├── basic.golden ├── basic.input ├── ra25.golden ├── ra25.input └── ra25.py └── version.go /.gitignore: -------------------------------------------------------------------------------- 1 | # Binaries for programs and plugins 2 | *.exe 3 | *.exe~ 4 | *.dll 5 | *.so 6 | *.dylib 7 | 8 | # Test binary, built with `go test -c` 9 | *.test 10 | 11 | # Output of the go coverage tool, specifically when used with LiteIDE 12 | *.out 13 | 14 | # Dependency directories (remove the comment below to include it) 15 | # vendor/ 16 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2020, GoKi 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | 1. Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | 2. Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 15 | 16 | 3. Neither the name of the copyright holder nor the names of its 17 | contributors may be used to endorse or promote products derived from 18 | this software without specific prior written permission. 
19 | 20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 30 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | # Basic Go makefile 2 | 3 | GOCMD=go 4 | GOBUILD=$(GOCMD) build 5 | GOCLEAN=$(GOCMD) clean 6 | GOTEST=$(GOCMD) test 7 | GOGET=$(GOCMD) get 8 | 9 | DIRS=`go list ./...` 10 | 11 | all: build 12 | 13 | build: 14 | @echo "GO111MODULE = $(value GO111MODULE)" 15 | $(GOBUILD) -v $(DIRS) 16 | 17 | test: 18 | @echo "GO111MODULE = $(value GO111MODULE)" 19 | $(GOTEST) -v $(DIRS) 20 | 21 | clean: 22 | @echo "GO111MODULE = $(value GO111MODULE)" 23 | $(GOCLEAN) ./... 24 | 25 | fmts: 26 | gofmt -s -w . 27 | 28 | vet: 29 | @echo "GO111MODULE = $(value GO111MODULE)" 30 | $(GOCMD) vet $(DIRS) | grep -v unkeyed 31 | 32 | tidy: export GO111MODULE = on 33 | tidy: 34 | @echo "GO111MODULE = $(value GO111MODULE)" 35 | go mod tidy 36 | 37 | mod-update: export GO111MODULE = on 38 | mod-update: 39 | @echo "GO111MODULE = $(value GO111MODULE)" 40 | go get -u ./... 41 | go mod tidy 42 | 43 | # gopath-update is for GOPATH to get most things updated. 
# need to call it in a target executable directory 45 | gopath-update: export GO111MODULE = off 46 | gopath-update: 47 | @echo "GO111MODULE = $(value GO111MODULE)" 48 | go get -u ./... 49 | 50 | # NOTE: MUST update version number here prior to running 'make release' and edit this file! 51 | VERS=v0.4.1 52 | PACKAGE=main 53 | GIT_COMMIT=`git rev-parse --short HEAD` 54 | VERS_DATE=`date -u +%Y-%m-%d\ %H:%M` 55 | VERS_FILE=version.go 56 | 57 | release: 58 | /bin/rm -f $(VERS_FILE) 59 | @echo "// WARNING: auto-generated by Makefile release target -- run 'make release' to update" > $(VERS_FILE) 60 | @echo "" >> $(VERS_FILE) 61 | @echo "package $(PACKAGE)" >> $(VERS_FILE) 62 | @echo "" >> $(VERS_FILE) 63 | @echo "const (" >> $(VERS_FILE) 64 | @echo " Version = \"$(VERS)\"" >> $(VERS_FILE) 65 | @echo " GitCommit = \"$(GIT_COMMIT)\" // the commit JUST BEFORE the release" >> $(VERS_FILE) 66 | @echo " VersionDate = \"$(VERS_DATE)\" // UTC" >> $(VERS_FILE) 67 | @echo ")" >> $(VERS_FILE) 68 | @echo "" >> $(VERS_FILE) 69 | goimports -w $(VERS_FILE) 70 | /bin/cat $(VERS_FILE) 71 | git commit -am "$(VERS) release" 72 | git tag -a $(VERS) -m "$(VERS) release" 73 | git push 74 | git push origin --tags 75 | 76 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # GoToPy 2 | 3 | GoToPy is a Go-to-Python converter: it translates Go code into Python code. 4 | 5 | To install, run the standard command: 6 | 7 | ```sh 8 | $ go install github.com/go-python/gotopy@latest 9 | ``` 10 | 11 | It is based on the source code of the Go `gofmt` command and the Go `printer` package, which parse Go files and write them out according to standard Go formatting. 12 | 13 | We have modified the `printer` code in the `pyprint` package to print out Python code instead.
14 | 15 | The `-gopy` flag generates [GoPy](https:://github.com/go-python/gopy) specific Python code, including: 16 | 17 | * `nil` -> `go.nil` 18 | * `[]string{...}` -> `go.Slice_string([...])` etc for int, float64, float32 19 | 20 | The `-gogi` flag generates [GoGi](https:://github.com/goki/gi) specific Python code, including: 21 | 22 | * struct tags generate: `self.SetTags()` call, for the `pygiv.ClassViewObj` class, which then provides an automatic GUI view with tag-based formatting of struct fields. 23 | 24 | # TODO 25 | 26 | * switch -> ifs.. -- grab switch expr and put into each if 27 | 28 | * string .contains -> "el" in str 29 | 30 | * map access with 2 vars = if el in map: mv = map[el] 31 | 32 | * for range with 2 vars = enumerate(slice) 33 | 34 | -------------------------------------------------------------------------------- /diff/diff.go: -------------------------------------------------------------------------------- 1 | // Copyright 2019 The Go Authors. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 4 | 5 | // Package diff implements a Diff function that compare two inputs 6 | // using the 'diff' tool. 7 | package diff 8 | 9 | import ( 10 | "io/ioutil" 11 | "os" 12 | "os/exec" 13 | "runtime" 14 | ) 15 | 16 | // Returns diff of two arrays of bytes in diff tool format. 17 | func Diff(prefix string, b1, b2 []byte) ([]byte, error) { 18 | f1, err := writeTempFile(prefix, b1) 19 | if err != nil { 20 | return nil, err 21 | } 22 | defer os.Remove(f1) 23 | 24 | f2, err := writeTempFile(prefix, b2) 25 | if err != nil { 26 | return nil, err 27 | } 28 | defer os.Remove(f2) 29 | 30 | cmd := "diff" 31 | if runtime.GOOS == "plan9" { 32 | cmd = "/bin/ape/diff" 33 | } 34 | 35 | data, err := exec.Command(cmd, "-u", f1, f2).CombinedOutput() 36 | if len(data) > 0 { 37 | // diff exits with a non-zero status when the files don't match. 38 | // Ignore that failure as long as we get output. 
39 | err = nil 40 | } 41 | return data, err 42 | } 43 | 44 | func writeTempFile(prefix string, data []byte) (string, error) { 45 | file, err := ioutil.TempFile("", prefix) 46 | if err != nil { 47 | return "", err 48 | } 49 | _, err = file.Write(data) 50 | if err1 := file.Close(); err == nil { 51 | err = err1 52 | } 53 | if err != nil { 54 | os.Remove(file.Name()) 55 | return "", err 56 | } 57 | return file.Name(), nil 58 | } 59 | -------------------------------------------------------------------------------- /format.go: -------------------------------------------------------------------------------- 1 | // Copyright 2020 The Go-Python Authors. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 4 | 5 | // This is based on gofmt source code: 6 | 7 | // Copyright 2015 The Go Authors. All rights reserved. 8 | // Use of this source code is governed by a BSD-style 9 | // license that can be found in the LICENSE file. 10 | 11 | package main 12 | 13 | import ( 14 | "bytes" 15 | "go/ast" 16 | "go/token" 17 | 18 | "github.com/go-python/gotopy/pyprint" 19 | ) 20 | 21 | // format formats the given package file originally obtained from src 22 | // and adjusts the result based on the original source via sourceAdj 23 | // and indentAdj. 24 | func format( 25 | fset *token.FileSet, 26 | file *ast.File, 27 | sourceAdj func(src []byte, indent int) []byte, 28 | indentAdj int, 29 | src []byte, 30 | cfg pyprint.Config, 31 | ) ([]byte, error) { 32 | if sourceAdj == nil { 33 | // Complete source file. 34 | var buf bytes.Buffer 35 | err := cfg.Fprint(&buf, fset, file) 36 | if err != nil { 37 | return nil, err 38 | } 39 | pyfix := pyEdits(buf.Bytes()) 40 | return pyfix, nil 41 | // return buf.Bytes(), nil 42 | } 43 | 44 | // Partial source file. 45 | // Determine and prepend leading space. 
46 | i, j := 0, 0 47 | for j < len(src) && isSpace(src[j]) { 48 | if src[j] == '\n' { 49 | i = j + 1 // byte offset of last line in leading space 50 | } 51 | j++ 52 | } 53 | var res []byte 54 | res = append(res, src[:i]...) 55 | 56 | // Determine and prepend indentation of first code line. 57 | // Spaces are ignored unless there are no tabs, 58 | // in which case spaces count as one tab. 59 | indent := 0 60 | hasSpace := false 61 | for _, b := range src[i:j] { 62 | switch b { 63 | case ' ': 64 | hasSpace = true 65 | case '\t': 66 | indent++ 67 | } 68 | } 69 | if indent == 0 && hasSpace { 70 | indent = 1 71 | } 72 | for i := 0; i < indent; i++ { 73 | res = append(res, '\t') 74 | } 75 | 76 | // Format the source. 77 | // Write it without any leading and trailing space. 78 | cfg.Indent = indent + indentAdj 79 | var buf bytes.Buffer 80 | err := cfg.Fprint(&buf, fset, file) 81 | if err != nil { 82 | return nil, err 83 | } 84 | 85 | pyfix := pyEdits(buf.Bytes()) 86 | 87 | out := sourceAdj(pyfix, cfg.Indent) 88 | 89 | // If the adjusted output is empty, the source 90 | // was empty but (possibly) for white space. 91 | // The result is the incoming source. 92 | if len(out) == 0 { 93 | return src, nil 94 | } 95 | 96 | // Otherwise, append output to leading space. 97 | res = append(res, out...) 98 | 99 | // Determine and append trailing space. 100 | i = len(src) 101 | for i > 0 && isSpace(src[i-1]) { 102 | i-- 103 | } 104 | return append(res, src[i:]...), nil 105 | } 106 | 107 | // isSpace reports whether the byte is a space character. 108 | // isSpace defines a space as being among the following bytes: ' ', '\t', '\n' and '\r'. 
109 | func isSpace(b byte) bool { 110 | return b == ' ' || b == '\t' || b == '\n' || b == '\r' 111 | } 112 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/go-python/gotopy 2 | 3 | go 1.15 4 | -------------------------------------------------------------------------------- /gotopy.go: -------------------------------------------------------------------------------- 1 | // Copyright 2020 The Go-Python Authors. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 4 | 5 | // This is based on gofmt source code: 6 | 7 | // Copyright 2009 The Go Authors. All rights reserved. 8 | // Use of this source code is governed by a BSD-style 9 | // license that can be found in the LICENSE file. 10 | 11 | package main 12 | 13 | import ( 14 | "bytes" 15 | "flag" 16 | "fmt" 17 | "go/ast" 18 | "go/parser" 19 | "go/scanner" 20 | "go/token" 21 | "io" 22 | "io/ioutil" 23 | "os" 24 | "path/filepath" 25 | "runtime" 26 | "runtime/pprof" 27 | "strings" 28 | 29 | "github.com/go-python/gotopy/diff" 30 | "github.com/go-python/gotopy/pyprint" 31 | ) 32 | 33 | var ( 34 | // main operation modes 35 | list = flag.Bool("l", false, "list files whose formatting differs from gofmt's") 36 | write = flag.Bool("w", false, "write result to (source) file instead of stdout") 37 | rewriteRule = flag.String("r", "", "rewrite rule (e.g., 'a[b:len(a)] -> a[b:]')") 38 | simplifyAST = flag.Bool("s", false, "simplify code") 39 | doDiff = flag.Bool("d", false, "display diffs instead of rewriting files") 40 | allErrors = flag.Bool("e", false, "report all errors (not just the first 10 on different lines)") 41 | gopyMode = flag.Bool("gopy", false, "support GoPy-specific Python code generation") 42 | gogiMode = flag.Bool("gogi", false, "support GoGi-specific Python code generation (implies gopy)") 
43 | 44 | // debugging 45 | cpuprofile = flag.String("cpuprofile", "", "write cpu profile to this file") 46 | ) 47 | 48 | // Keep these in sync with go/format/format.go. 49 | const ( 50 | tabWidth = 8 51 | printerModeDef = pyprint.UseSpaces | pyprint.TabIndent | printerNormalizeNumbers 52 | 53 | // printerNormalizeNumbers means to canonicalize number literal prefixes 54 | // and exponents while printing. See https://golang.org/doc/go1.13#gofmt. 55 | // 56 | // This value is defined in go/printer specifically for go/format and cmd/gofmt. 57 | printerNormalizeNumbers = 1 << 30 58 | ) 59 | 60 | var ( 61 | fileSet = token.NewFileSet() // per process FileSet 62 | exitCode = 0 63 | rewrite func(*ast.File) *ast.File 64 | parserMode parser.Mode 65 | printerMode = printerModeDef 66 | ) 67 | 68 | func report(err error) { 69 | scanner.PrintError(os.Stderr, err) 70 | exitCode = 2 71 | } 72 | 73 | func usage() { 74 | fmt.Fprintf(os.Stderr, "usage: gofmt [flags] [path ...]\n") 75 | flag.PrintDefaults() 76 | } 77 | 78 | func initParserMode() { 79 | parserMode = parser.ParseComments 80 | if *allErrors { 81 | parserMode |= parser.AllErrors 82 | } 83 | if *gopyMode { 84 | printerMode |= pyprint.GoPy 85 | } 86 | 87 | if *gogiMode { 88 | printerMode |= pyprint.GoGi | pyprint.GoPy 89 | } 90 | } 91 | 92 | func isGoFile(f os.FileInfo) bool { 93 | // ignore non-Go files 94 | name := f.Name() 95 | return !f.IsDir() && !strings.HasPrefix(name, ".") && strings.HasSuffix(name, ".go") 96 | } 97 | 98 | // If in == nil, the source is the contents of the file with the given filename. 
99 | func processFile(filename string, in io.Reader, out io.Writer, stdin bool) error { 100 | var perm os.FileMode = 0644 101 | if in == nil { 102 | f, err := os.Open(filename) 103 | if err != nil { 104 | return err 105 | } 106 | defer f.Close() 107 | fi, err := f.Stat() 108 | if err != nil { 109 | return err 110 | } 111 | in = f 112 | perm = fi.Mode().Perm() 113 | } 114 | 115 | src, err := ioutil.ReadAll(in) 116 | if err != nil { 117 | return err 118 | } 119 | 120 | file, sourceAdj, indentAdj, err := parse(fileSet, filename, src, stdin) 121 | if err != nil { 122 | return err 123 | } 124 | 125 | if rewrite != nil { 126 | if sourceAdj == nil { 127 | file = rewrite(file) 128 | } else { 129 | fmt.Fprintf(os.Stderr, "warning: rewrite ignored for incomplete programs\n") 130 | } 131 | } 132 | 133 | ast.SortImports(fileSet, file) 134 | 135 | if *simplifyAST { 136 | simplify(file) 137 | } 138 | 139 | res, err := format(fileSet, file, sourceAdj, indentAdj, src, pyprint.Config{Mode: printerMode, Tabwidth: tabWidth}) 140 | if err != nil { 141 | return err 142 | } 143 | 144 | if !bytes.Equal(src, res) { 145 | // formatting has changed 146 | if *list { 147 | fmt.Fprintln(out, filename) 148 | } 149 | if *write { 150 | // make a temporary backup before overwriting original 151 | bakname, err := backupFile(filename+".", src, perm) 152 | if err != nil { 153 | return err 154 | } 155 | err = ioutil.WriteFile(filename, res, perm) 156 | if err != nil { 157 | os.Rename(bakname, filename) 158 | return err 159 | } 160 | err = os.Remove(bakname) 161 | if err != nil { 162 | return err 163 | } 164 | } 165 | if *doDiff { 166 | data, err := diffWithReplaceTempFile(src, res, filename) 167 | if err != nil { 168 | return fmt.Errorf("computing diff: %s", err) 169 | } 170 | fmt.Printf("diff -u %s %s\n", filepath.ToSlash(filename+".orig"), filepath.ToSlash(filename)) 171 | out.Write(data) 172 | } 173 | } 174 | 175 | if !*list && !*write && !*doDiff { 176 | _, err = out.Write(res) 177 | } 178 | 179 | 
return err 180 | } 181 | 182 | func visitFile(path string, f os.FileInfo, err error) error { 183 | if err == nil && isGoFile(f) { 184 | err = processFile(path, nil, os.Stdout, false) 185 | } 186 | // Don't complain if a file was deleted in the meantime (i.e. 187 | // the directory changed concurrently while running gofmt). 188 | if err != nil && !os.IsNotExist(err) { 189 | report(err) 190 | } 191 | return nil 192 | } 193 | 194 | func walkDir(path string) { 195 | filepath.Walk(path, visitFile) 196 | } 197 | 198 | func main() { 199 | // call gofmtMain in a separate function 200 | // so that it can use defer and have them 201 | // run before the exit. 202 | gofmtMain() 203 | os.Exit(exitCode) 204 | } 205 | 206 | func gofmtMain() { 207 | flag.Usage = usage 208 | flag.Parse() 209 | 210 | if *cpuprofile != "" { 211 | f, err := os.Create(*cpuprofile) 212 | if err != nil { 213 | fmt.Fprintf(os.Stderr, "creating cpu profile: %s\n", err) 214 | exitCode = 2 215 | return 216 | } 217 | defer f.Close() 218 | pprof.StartCPUProfile(f) 219 | defer pprof.StopCPUProfile() 220 | } 221 | 222 | initParserMode() 223 | initRewrite() 224 | 225 | if flag.NArg() == 0 { 226 | if *write { 227 | fmt.Fprintln(os.Stderr, "error: cannot use -w with standard input") 228 | exitCode = 2 229 | return 230 | } 231 | if err := processFile("", os.Stdin, os.Stdout, true); err != nil { 232 | report(err) 233 | } 234 | return 235 | } 236 | 237 | for i := 0; i < flag.NArg(); i++ { 238 | path := flag.Arg(i) 239 | switch dir, err := os.Stat(path); { 240 | case err != nil: 241 | report(err) 242 | case dir.IsDir(): 243 | walkDir(path) 244 | default: 245 | if err := processFile(path, nil, os.Stdout, false); err != nil { 246 | report(err) 247 | } 248 | } 249 | } 250 | } 251 | 252 | func diffWithReplaceTempFile(b1, b2 []byte, filename string) ([]byte, error) { 253 | data, err := diff.Diff("gofmt", b1, b2) 254 | if len(data) > 0 { 255 | return replaceTempFilename(data, filename) 256 | } 257 | return data, err 258 | } 
259 | 260 | // replaceTempFilename replaces temporary filenames in diff with actual one. 261 | // 262 | // --- /tmp/gofmt316145376 2017-02-03 19:13:00.280468375 -0500 263 | // +++ /tmp/gofmt617882815 2017-02-03 19:13:00.280468375 -0500 264 | // ... 265 | // -> 266 | // --- path/to/file.go.orig 2017-02-03 19:13:00.280468375 -0500 267 | // +++ path/to/file.go 2017-02-03 19:13:00.280468375 -0500 268 | // ... 269 | func replaceTempFilename(diff []byte, filename string) ([]byte, error) { 270 | bs := bytes.SplitN(diff, []byte{'\n'}, 3) 271 | if len(bs) < 3 { 272 | return nil, fmt.Errorf("got unexpected diff for %s", filename) 273 | } 274 | // Preserve timestamps. 275 | var t0, t1 []byte 276 | if i := bytes.LastIndexByte(bs[0], '\t'); i != -1 { 277 | t0 = bs[0][i:] 278 | } 279 | if i := bytes.LastIndexByte(bs[1], '\t'); i != -1 { 280 | t1 = bs[1][i:] 281 | } 282 | // Always print filepath with slash separator. 283 | f := filepath.ToSlash(filename) 284 | bs[0] = []byte(fmt.Sprintf("--- %s%s", f+".orig", t0)) 285 | bs[1] = []byte(fmt.Sprintf("+++ %s%s", f, t1)) 286 | return bytes.Join(bs, []byte{'\n'}), nil 287 | } 288 | 289 | const chmodSupported = runtime.GOOS != "windows" 290 | 291 | // backupFile writes data to a new file named filename with permissions perm, 292 | // with >>>") 52 | 53 | endclass := "EndClass: " 54 | method := "Method: " 55 | endmethod := "EndMethod" 56 | 57 | lastMethSt := -1 58 | var lastMeth string 59 | curComSt := -1 60 | lastComSt := -1 61 | lastComEd := -1 62 | 63 | li := 0 64 | for { 65 | if li >= len(lines) { 66 | break 67 | } 68 | ln := lines[li] 69 | if len(ln) > 0 && ln[0] == '#' { 70 | if curComSt >= 0 { 71 | lastComEd = li 72 | } else { 73 | curComSt = li 74 | lastComSt = li 75 | lastComEd = li 76 | } 77 | } else { 78 | curComSt = -1 79 | } 80 | 81 | switch { 82 | case bytes.Equal(ln, []byte(" :")) || bytes.Equal(ln, []byte(":")): 83 | lines = append(lines[:li], lines[li+1:]...) 
// delete marker 84 | li-- 85 | case bytes.HasPrefix(ln, class): 86 | cl := string(ln[len(class):]) 87 | if idx := strings.Index(cl, "("); idx > 0 { 88 | cl = cl[:idx] 89 | } else if idx := strings.Index(cl, ":"); idx > 0 { // should have 90 | cl = cl[:idx] 91 | } 92 | cl = strings.TrimSpace(cl) 93 | classes[cl] = sted{st: li} 94 | // fmt.Printf("cl: %s at %d\n", cl, li) 95 | case bytes.HasPrefix(ln, pymark) && bytes.HasSuffix(ln, pyend): 96 | tag := string(ln[4 : len(ln)-4]) 97 | // fmt.Printf("tag: %s at: %d\n", tag, li) 98 | switch { 99 | case strings.HasPrefix(tag, endclass): 100 | cl := tag[len(endclass):] 101 | st := classes[cl] 102 | classes[cl] = sted{st: st.st, ed: li} 103 | lines = append(lines[:li], lines[li+1:]...) // delete marker 104 | // fmt.Printf("cl: %s at %v\n", cl, classes[cl]) 105 | li-- 106 | case strings.HasPrefix(tag, method): 107 | cl := tag[len(method):] 108 | lines = append(lines[:li], lines[li+1:]...) // delete marker 109 | li-- 110 | lastMeth = cl 111 | if lastComEd == li { 112 | lines = append(lines[:lastComSt], lines[lastComEd+1:]...) // delete comments 113 | lastMethSt = lastComSt 114 | li = lastComSt - 1 115 | } else { 116 | lastMethSt = li + 1 117 | } 118 | case tag == endmethod: 119 | se, ok := classes[lastMeth] 120 | if ok { 121 | lines = append(lines[:li], lines[li+1:]...) 
// delete marker 122 | moveLines(&lines, se.ed, lastMethSt, li+1) // extra blank 123 | classes[lastMeth] = sted{st: se.st, ed: se.ed + ((li + 1) - lastMethSt)} 124 | li -= 2 125 | } 126 | } 127 | } 128 | li++ 129 | } 130 | return lines 131 | } 132 | 133 | // pyEditsReplace replaces Go with equivalent Python code 134 | func pyEditsReplace(lines [][]byte) { 135 | fmtPrintf := []byte("fmt.Printf") 136 | fmtSprintf := []byte("fmt.Sprintf(") 137 | prints := []byte("print") 138 | eqappend := []byte("= append(") 139 | elseif := []byte("else if") 140 | elif := []byte("elif") 141 | forblank := []byte("for _, ") 142 | fornoblank := []byte("for ") 143 | itoa := []byte("strconv.Itoa") 144 | float64p := []byte("float64(") 145 | float32p := []byte("float32(") 146 | floatp := []byte("float(") 147 | stringp := []byte("string(") 148 | strp := []byte("str(") 149 | slicestr := []byte("[]str(") 150 | sliceint := []byte("[]int(") 151 | slicefloat64 := []byte("[]float64(") 152 | slicefloat32 := []byte("[]float32(") 153 | goslicestr := []byte("go.Slice_string([") 154 | gosliceint := []byte("go.Slice_int([") 155 | goslicefloat64 := []byte("go.Slice_float64([") 156 | goslicefloat32 := []byte("go.Slice_float32([") 157 | stringsdot := []byte("strings.") 158 | copyp := []byte("copy(") 159 | eqgonil := []byte(" == go.nil") 160 | eqgonil0 := []byte(" == 0") 161 | negonil := []byte(" != go.nil") 162 | negonil0 := []byte(" != 0") 163 | 164 | gopy := (printerMode&pyprint.GoPy != 0) 165 | // gogi := (printerMode&pyprint.GoGi != 0) 166 | 167 | for li, ln := range lines { 168 | ln = bytes.Replace(ln, float64p, floatp, -1) 169 | ln = bytes.Replace(ln, float32p, floatp, -1) 170 | ln = bytes.Replace(ln, stringp, strp, -1) 171 | ln = bytes.Replace(ln, forblank, fornoblank, -1) 172 | ln = bytes.Replace(ln, eqgonil, eqgonil0, -1) 173 | ln = bytes.Replace(ln, negonil, negonil0, -1) 174 | 175 | if bytes.Contains(ln, fmtSprintf) { 176 | if bytes.Contains(ln, []byte("%")) { 177 | ln = bytes.Replace(ln, 
[]byte(`", `), []byte(`" % (`), -1) 178 | } 179 | ln = bytes.Replace(ln, fmtSprintf, []byte{}, -1) 180 | } 181 | 182 | if bytes.Contains(ln, fmtPrintf) { 183 | if bytes.Contains(ln, []byte("%")) { 184 | ln = bytes.Replace(ln, []byte(`", `), []byte(`" % `), -1) 185 | } 186 | ln = bytes.Replace(ln, fmtPrintf, prints, -1) 187 | } 188 | 189 | if bytes.Contains(ln, eqappend) { 190 | idx := bytes.Index(ln, eqappend) 191 | comi := bytes.Index(ln[idx+len(eqappend):], []byte(",")) 192 | nln := make([]byte, idx-1) 193 | copy(nln, ln[:idx-1]) 194 | nln = append(nln, []byte(".append(")...) 195 | nln = append(nln, ln[idx+len(eqappend)+comi+1:]...) 196 | ln = nln 197 | } 198 | 199 | for { 200 | if bytes.Contains(ln, stringsdot) { 201 | idx := bytes.Index(ln, stringsdot) 202 | pi := idx + len(stringsdot) + bytes.Index(ln[idx+len(stringsdot):], []byte("(")) 203 | comi := bytes.Index(ln[pi:], []byte(",")) 204 | nln := make([]byte, idx) 205 | copy(nln, ln[:idx]) 206 | if comi < 0 { 207 | comi = bytes.Index(ln[pi:], []byte(")")) 208 | nln = append(nln, ln[pi+1:pi+comi]...) 209 | nln = append(nln, '.') 210 | meth := bytes.ToLower(ln[idx+len(stringsdot) : pi+1]) 211 | if bytes.Equal(meth, []byte("fields(")) { 212 | meth = []byte("split(") 213 | } 214 | nln = append(nln, meth...) 215 | nln = append(nln, ln[pi+comi:]...) 216 | } else { 217 | nln = append(nln, ln[pi+1:pi+comi]...) 218 | nln = append(nln, '.') 219 | meth := bytes.ToLower(ln[idx+len(stringsdot) : pi+1]) 220 | nln = append(nln, meth...) 221 | nln = append(nln, ln[pi+comi+1:]...) 222 | } 223 | ln = nln 224 | } else { 225 | break 226 | } 227 | } 228 | 229 | if bytes.Contains(ln, copyp) { 230 | idx := bytes.Index(ln, copyp) 231 | pi := idx + len(copyp) + bytes.Index(ln[idx+len(stringsdot):], []byte("(")) 232 | comi := bytes.Index(ln[pi:], []byte(",")) 233 | nln := make([]byte, idx) 234 | copy(nln, ln[:idx]) 235 | nln = append(nln, ln[pi+1:pi+comi]...) 236 | nln = append(nln, '.') 237 | nln = append(nln, copyp...) 
238 | nln = append(nln, ln[pi+comi+1:]...) 239 | ln = nln 240 | } 241 | 242 | if bytes.Contains(ln, itoa) { 243 | ln = bytes.Replace(ln, itoa, []byte(`str`), -1) 244 | } 245 | 246 | if bytes.Contains(ln, elseif) { 247 | ln = bytes.Replace(ln, elseif, elif, -1) 248 | } 249 | 250 | if gopy && bytes.Contains(ln, slicestr) { 251 | ln = bytes.Replace(ln, slicestr, goslicestr, -1) 252 | ln = bytes.Replace(ln, []byte(")"), []byte("])"), 1) 253 | } 254 | 255 | if gopy && bytes.Contains(ln, sliceint) { 256 | ln = bytes.Replace(ln, sliceint, gosliceint, -1) 257 | ln = bytes.Replace(ln, []byte(")"), []byte("])"), 1) 258 | } 259 | 260 | if gopy && bytes.Contains(ln, slicefloat64) { 261 | ln = bytes.Replace(ln, slicefloat64, goslicefloat64, -1) 262 | ln = bytes.Replace(ln, []byte(")"), []byte("])"), 1) 263 | } 264 | 265 | if gopy && bytes.Contains(ln, slicefloat32) { 266 | ln = bytes.Replace(ln, slicefloat32, goslicefloat32, -1) 267 | ln = bytes.Replace(ln, []byte(")"), []byte("])"), 1) 268 | } 269 | 270 | ln = bytes.Replace(ln, []byte("\t"), []byte(" "), -1) 271 | 272 | lines[li] = ln 273 | } 274 | } 275 | -------------------------------------------------------------------------------- /pyprint/printer.go: -------------------------------------------------------------------------------- 1 | // Copyright 2020 The Go-Python Authors. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 4 | 5 | // This is based on gofmt source code: 6 | 7 | // Copyright 2009 The Go Authors. All rights reserved. 8 | // Use of this source code is governed by a BSD-style 9 | // license that can be found in the LICENSE file. 10 | 11 | // Package printer implements printing of AST nodes. 12 | package pyprint 13 | 14 | import ( 15 | "fmt" 16 | "go/ast" 17 | "go/token" 18 | "io" 19 | "os" 20 | "strings" 21 | "text/tabwriter" 22 | "unicode" 23 | ) 24 | 25 | const ( 26 | maxNewlines = 2 // max. 
number of newlines between source text 27 | debug = false // enable for debugging 28 | infinity = 1 << 30 29 | ) 30 | 31 | type whiteSpace byte 32 | 33 | const ( 34 | ignore = whiteSpace(0) 35 | blank = whiteSpace(' ') 36 | vtab = whiteSpace('\v') 37 | newline = whiteSpace('\n') 38 | formfeed = whiteSpace('\f') 39 | indent = whiteSpace('>') 40 | unindent = whiteSpace('<') 41 | ) 42 | 43 | // A pmode value represents the current printer mode. 44 | type pmode int 45 | 46 | const ( 47 | noExtraBlank pmode = 1 << iota // disables extra blank after /*-style comment 48 | noExtraLinebreak // disables extra line break after /*-style comment 49 | ) 50 | 51 | type commentInfo struct { 52 | cindex int // current comment index 53 | comment *ast.CommentGroup // = printer.comments[cindex]; or nil 54 | commentOffset int // = printer.posFor(printer.comments[cindex].List[0].Pos()).Offset; or infinity 55 | commentNewline bool // true if the comment group contains newlines 56 | } 57 | 58 | type printer struct { 59 | // Configuration (does not change after initialization) 60 | Config 61 | fset *token.FileSet 62 | 63 | // Current state 64 | output []byte // raw printer result 65 | indent int // current indentation 66 | level int // level == 0: outside composite literal; level > 0: inside composite literal 67 | mode pmode // current printer mode 68 | endAlignment bool // if set, terminate alignment immediately 69 | impliedSemi bool // if set, a linebreak implies a semicolon 70 | lastTok token.Token // last token printed (token.ILLEGAL if it's whitespace) 71 | prevOpen token.Token // previous non-brace "open" token (, [, or token.ILLEGAL 72 | wsbuf []whiteSpace // delayed white space 73 | 74 | // Positions 75 | // The out position differs from the pos position when the result 76 | // formatting differs from the source formatting (in the amount of 77 | // white space). 
If there's a difference and SourcePos is set in 78 | // ConfigMode, //line directives are used in the output to restore 79 | // original source positions for a reader. 80 | pos token.Position // current position in AST (source) space 81 | out token.Position // current position in output space 82 | last token.Position // value of pos after calling writeString 83 | linePtr *int // if set, record out.Line for the next token in *linePtr 84 | 85 | // The list of all source comments, in order of appearance. 86 | comments []*ast.CommentGroup // may be nil 87 | useNodeComments bool // if not set, ignore lead and line comments of nodes 88 | 89 | // Information about p.comments[p.cindex]; set up by nextComment. 90 | commentInfo 91 | 92 | // Cache of already computed node sizes. 93 | nodeSizes map[ast.Node]int 94 | 95 | // Cache of most recently computed line position. 96 | cachedPos token.Pos 97 | cachedLine int // line corresponding to cachedPos 98 | } 99 | 100 | func (p *printer) init(cfg *Config, fset *token.FileSet, nodeSizes map[ast.Node]int) { 101 | p.Config = *cfg 102 | p.fset = fset 103 | p.pos = token.Position{Line: 1, Column: 1} 104 | p.out = token.Position{Line: 1, Column: 1} 105 | p.wsbuf = make([]whiteSpace, 0, 16) // whitespace sequences are short 106 | p.nodeSizes = nodeSizes 107 | p.cachedPos = -1 108 | } 109 | 110 | func (p *printer) internalError(msg ...interface{}) { 111 | if debug { 112 | fmt.Print(p.pos.String() + ": ") 113 | fmt.Println(msg...) 114 | panic("go/printer") 115 | } 116 | } 117 | 118 | // commentsHaveNewline reports whether a list of comments belonging to 119 | // an *ast.CommentGroup contains newlines. Because the position information 120 | // may only be partially correct, we also have to read the comment text. 
func (p *printer) commentsHaveNewline(list []*ast.Comment) bool {
	// len(list) > 0
	// list[0] is read unconditionally; nextComment only calls this for
	// non-empty lists.
	line := p.lineFor(list[0].Pos())
	for i, c := range list {
		if i > 0 && p.lineFor(list[i].Pos()) != line {
			// not all comments on the same line
			return true
		}
		// Text whose second byte is '/' (a //-style comment) or that
		// contains a '\n' counts as containing a newline.
		if t := c.Text; len(t) >= 2 && (t[1] == '/' || strings.Contains(t, "\n")) {
			return true
		}
	}
	// No-op: line is already used inside the loop above.
	_ = line
	return false
}

// nextComment advances p.commentInfo to the next non-empty comment group in
// p.comments, recording the group, the byte offset of its first comment and
// whether it contains newlines. When no groups remain, commentOffset is set
// to infinity, so commentBefore reports false for any smaller offset.
func (p *printer) nextComment() {
	for p.cindex < len(p.comments) {
		c := p.comments[p.cindex]
		p.cindex++
		if list := c.List; len(list) > 0 {
			p.comment = c
			p.commentOffset = p.posFor(list[0].Pos()).Offset
			p.commentNewline = p.commentsHaveNewline(list)
			return
		}
		// we should not reach here (correct ASTs don't have empty
		// ast.CommentGroup nodes), but be conservative and try again
	}
	// no more comments
	p.commentOffset = infinity
}

// commentBefore reports whether the current comment group occurs
// before the next position in the source code and printing it does
// not introduce implicit semicolons.
//
func (p *printer) commentBefore(next token.Position) bool {
	// A multi-line comment group is only allowed when no semicolon is
	// implied at the pending line break.
	return p.commentOffset < next.Offset && (!p.impliedSemi || !p.commentNewline)
}

// commentSizeBefore returns the estimated size of the
// comments on the same line before the next position.
164 | // 165 | func (p *printer) commentSizeBefore(next token.Position) int { 166 | // save/restore current p.commentInfo (p.nextComment() modifies it) 167 | defer func(info commentInfo) { 168 | p.commentInfo = info 169 | }(p.commentInfo) 170 | 171 | size := 0 172 | for p.commentBefore(next) { 173 | for _, c := range p.comment.List { 174 | size += len(c.Text) 175 | } 176 | p.nextComment() 177 | } 178 | return size 179 | } 180 | 181 | // recordLine records the output line number for the next non-whitespace 182 | // token in *linePtr. It is used to compute an accurate line number for a 183 | // formatted construct, independent of pending (not yet emitted) whitespace 184 | // or comments. 185 | // 186 | func (p *printer) recordLine(linePtr *int) { 187 | p.linePtr = linePtr 188 | } 189 | 190 | // linesFrom returns the number of output lines between the current 191 | // output line and the line argument, ignoring any pending (not yet 192 | // emitted) whitespace or comments. It is used to compute an accurate 193 | // size (in number of lines) for a formatted construct. 194 | // 195 | func (p *printer) linesFrom(line int) int { 196 | return p.out.Line - line 197 | } 198 | 199 | func (p *printer) posFor(pos token.Pos) token.Position { 200 | // not used frequently enough to cache entire token.Position 201 | return p.fset.PositionFor(pos, false /* absolute position */) 202 | } 203 | 204 | func (p *printer) lineFor(pos token.Pos) int { 205 | if pos != p.cachedPos { 206 | p.cachedPos = pos 207 | p.cachedLine = p.fset.PositionFor(pos, false /* absolute position */).Line 208 | } 209 | return p.cachedLine 210 | } 211 | 212 | // writeLineDirective writes a //line directive if necessary. 
213 | func (p *printer) writeLineDirective(pos token.Position) { 214 | if pos.IsValid() && (p.out.Line != pos.Line || p.out.Filename != pos.Filename) { 215 | p.output = append(p.output, tabwriter.Escape) // protect '\n' in //line from tabwriter interpretation 216 | p.output = append(p.output, fmt.Sprintf("//line %s:%d\n", pos.Filename, pos.Line)...) 217 | p.output = append(p.output, tabwriter.Escape) 218 | // p.out must match the //line directive 219 | p.out.Filename = pos.Filename 220 | p.out.Line = pos.Line 221 | } 222 | } 223 | 224 | // writeIndent writes indentation. 225 | func (p *printer) writeIndent() { 226 | // use "hard" htabs - indentation columns 227 | // must not be discarded by the tabwriter 228 | n := p.Config.Indent + p.indent // include base indentation 229 | for i := 0; i < n; i++ { 230 | p.output = append(p.output, '\t') 231 | } 232 | 233 | // update positions 234 | p.pos.Offset += n 235 | p.pos.Column += n 236 | p.out.Column += n 237 | } 238 | 239 | // writeByte writes ch n times to p.output and updates p.pos. 240 | // Only used to write formatting (white space) characters. 241 | func (p *printer) writeByte(ch byte, n int) { 242 | if p.endAlignment { 243 | // Ignore any alignment control character; 244 | // and at the end of the line, break with 245 | // a formfeed to indicate termination of 246 | // existing columns. 
247 | switch ch { 248 | case '\t', '\v': 249 | ch = ' ' 250 | case '\n', '\f': 251 | ch = '\f' 252 | p.endAlignment = false 253 | } 254 | } 255 | 256 | if p.out.Column == 1 { 257 | // no need to write line directives before white space 258 | p.writeIndent() 259 | } 260 | 261 | for i := 0; i < n; i++ { 262 | p.output = append(p.output, ch) 263 | } 264 | 265 | // update positions 266 | p.pos.Offset += n 267 | if ch == '\n' || ch == '\f' { 268 | p.pos.Line += n 269 | p.out.Line += n 270 | p.pos.Column = 1 271 | p.out.Column = 1 272 | return 273 | } 274 | p.pos.Column += n 275 | p.out.Column += n 276 | } 277 | 278 | // writeString writes the string s to p.output and updates p.pos, p.out, 279 | // and p.last. If isLit is set, s is escaped w/ tabwriter.Escape characters 280 | // to protect s from being interpreted by the tabwriter. 281 | // 282 | // Note: writeString is only used to write Go tokens, literals, and 283 | // comments, all of which must be written literally. Thus, it is correct 284 | // to always set isLit = true. However, setting it explicitly only when 285 | // needed (i.e., when we don't know that s contains no tabs or line breaks) 286 | // avoids processing extra escape characters and reduces run time of the 287 | // printer benchmark by up to 10%. 288 | // 289 | func (p *printer) writeString(pos token.Position, s string, isLit bool) { 290 | if p.out.Column == 1 { 291 | if p.Config.Mode&SourcePos != 0 { 292 | p.writeLineDirective(pos) 293 | } 294 | p.writeIndent() 295 | } 296 | 297 | if pos.IsValid() { 298 | // update p.pos (if pos is invalid, continue with existing p.pos) 299 | // Note: Must do this after handling line beginnings because 300 | // writeIndent updates p.pos if there's indentation, but p.pos 301 | // is the position of s. 302 | p.pos = pos 303 | } 304 | 305 | if isLit { 306 | // Protect s such that is passes through the tabwriter 307 | // unchanged. 
Note that valid Go programs cannot contain 308 | // tabwriter.Escape bytes since they do not appear in legal 309 | // UTF-8 sequences. 310 | p.output = append(p.output, tabwriter.Escape) 311 | } 312 | 313 | if debug { 314 | p.output = append(p.output, fmt.Sprintf("/*%s*/", pos)...) // do not update p.pos! 315 | } 316 | p.output = append(p.output, s...) 317 | 318 | // update positions 319 | nlines := 0 320 | var li int // index of last newline; valid if nlines > 0 321 | for i := 0; i < len(s); i++ { 322 | // Raw string literals may contain any character except back quote (`). 323 | if ch := s[i]; ch == '\n' || ch == '\f' { 324 | // account for line break 325 | nlines++ 326 | li = i 327 | // A line break inside a literal will break whatever column 328 | // formatting is in place; ignore any further alignment through 329 | // the end of the line. 330 | p.endAlignment = true 331 | } 332 | } 333 | p.pos.Offset += len(s) 334 | if nlines > 0 { 335 | p.pos.Line += nlines 336 | p.out.Line += nlines 337 | c := len(s) - li 338 | p.pos.Column = c 339 | p.out.Column = c 340 | } else { 341 | p.pos.Column += len(s) 342 | p.out.Column += len(s) 343 | } 344 | 345 | if isLit { 346 | p.output = append(p.output, tabwriter.Escape) 347 | } 348 | 349 | p.last = p.pos 350 | } 351 | 352 | // writeCommentPrefix writes the whitespace before a comment. 353 | // If there is any pending whitespace, it consumes as much of 354 | // it as is likely to help position the comment nicely. 355 | // pos is the comment position, next the position of the item 356 | // after all pending comments, prev is the previous comment in 357 | // a group of comments (or nil), and tok is the next token. 
358 | // 359 | func (p *printer) writeCommentPrefix(pos, next token.Position, prev *ast.Comment, tok token.Token) { 360 | if len(p.output) == 0 { 361 | // the comment is the first item to be printed - don't write any whitespace 362 | return 363 | } 364 | 365 | if pos.IsValid() && pos.Filename != p.last.Filename { 366 | // comment in a different file - separate with newlines 367 | p.writeByte('\f', maxNewlines) 368 | return 369 | } 370 | 371 | if pos.Line == p.last.Line && (prev == nil || prev.Text[1] != '/') { 372 | // comment on the same line as last item: 373 | // separate with at least one separator 374 | hasSep := false 375 | if prev == nil { 376 | // first comment of a comment group 377 | j := 0 378 | for i, ch := range p.wsbuf { 379 | switch ch { 380 | case blank: 381 | // ignore any blanks before a comment 382 | p.wsbuf[i] = ignore 383 | continue 384 | case vtab: 385 | // respect existing tabs - important 386 | // for proper formatting of commented structs 387 | hasSep = true 388 | continue 389 | case indent: 390 | // apply pending indentation 391 | continue 392 | } 393 | j = i 394 | break 395 | } 396 | p.writeWhitespace(j) 397 | } 398 | // make sure there is at least one separator 399 | if !hasSep { 400 | sep := byte('\t') 401 | if pos.Line == next.Line { 402 | // next item is on the same line as the comment 403 | // (which must be a /*-style comment): separate 404 | // with a blank instead of a tab 405 | sep = ' ' 406 | } 407 | p.writeByte(sep, 1) 408 | } 409 | 410 | } else { 411 | // comment on a different line: 412 | // separate with at least one line break 413 | droppedLinebreak := false 414 | j := 0 415 | for i, ch := range p.wsbuf { 416 | switch ch { 417 | case blank, vtab: 418 | // ignore any horizontal whitespace before line breaks 419 | p.wsbuf[i] = ignore 420 | continue 421 | case indent: 422 | // apply pending indentation 423 | continue 424 | case unindent: 425 | // if this is not the last unindent, apply it 426 | // as it is (likely) belonging 
to the last 427 | // construct (e.g., a multi-line expression list) 428 | // and is not part of closing a block 429 | if i+1 < len(p.wsbuf) && p.wsbuf[i+1] == unindent { 430 | continue 431 | } 432 | // if the next token is not a closing }, apply the unindent 433 | // if it appears that the comment is aligned with the 434 | // token; otherwise assume the unindent is part of a 435 | // closing block and stop (this scenario appears with 436 | // comments before a case label where the comments 437 | // apply to the next case instead of the current one) 438 | if tok != token.RBRACE && pos.Column == next.Column { 439 | continue 440 | } 441 | case newline, formfeed: 442 | p.wsbuf[i] = ignore 443 | droppedLinebreak = prev == nil // record only if first comment of a group 444 | } 445 | j = i 446 | break 447 | } 448 | p.writeWhitespace(j) 449 | 450 | // determine number of linebreaks before the comment 451 | n := 0 452 | if pos.IsValid() && p.last.IsValid() { 453 | n = pos.Line - p.last.Line 454 | if n < 0 { // should never happen 455 | n = 0 456 | } 457 | } 458 | 459 | // at the package scope level only (p.indent == 0), 460 | // add an extra newline if we dropped one before: 461 | // this preserves a blank line before documentation 462 | // comments at the package scope level (issue 2570) 463 | if p.indent == 0 && droppedLinebreak { 464 | n++ 465 | } 466 | 467 | // make sure there is at least one line break 468 | // if the previous comment was a line comment 469 | if n == 0 && prev != nil && prev.Text[1] == '/' { 470 | n = 1 471 | } 472 | 473 | if n > 0 { 474 | // use formfeeds to break columns before a comment; 475 | // this is analogous to using formfeeds to separate 476 | // individual lines of /*-style comments 477 | p.writeByte('\f', nlimit(n)) 478 | } 479 | } 480 | } 481 | 482 | // Returns true if s contains only white space 483 | // (only tabs and blanks can appear in the printer's context). 
484 | // 485 | func isBlank(s string) bool { 486 | for i := 0; i < len(s); i++ { 487 | if s[i] > ' ' { 488 | return false 489 | } 490 | } 491 | return true 492 | } 493 | 494 | // commonPrefix returns the common prefix of a and b. 495 | func commonPrefix(a, b string) string { 496 | i := 0 497 | for i < len(a) && i < len(b) && a[i] == b[i] && (a[i] <= ' ' || a[i] == '*') { 498 | i++ 499 | } 500 | return a[0:i] 501 | } 502 | 503 | // trimRight returns s with trailing whitespace removed. 504 | func trimRight(s string) string { 505 | return strings.TrimRightFunc(s, unicode.IsSpace) 506 | } 507 | 508 | // stripCommonPrefix removes a common prefix from /*-style comment lines (unless no 509 | // comment line is indented, all but the first line have some form of space prefix). 510 | // The prefix is computed using heuristics such that is likely that the comment 511 | // contents are nicely laid out after re-printing each line using the printer's 512 | // current indentation. 513 | // 514 | func stripCommonPrefix(lines []string) { 515 | if len(lines) <= 1 { 516 | return // at most one line - nothing to do 517 | } 518 | // len(lines) > 1 519 | 520 | // The heuristic in this function tries to handle a few 521 | // common patterns of /*-style comments: Comments where 522 | // the opening /* and closing */ are aligned and the 523 | // rest of the comment text is aligned and indented with 524 | // blanks or tabs, cases with a vertical "line of stars" 525 | // on the left, and cases where the closing */ is on the 526 | // same line as the last comment text. 527 | 528 | // Compute maximum common white prefix of all but the first, 529 | // last, and blank lines, and replace blank lines with empty 530 | // lines (the first line starts with /* and has no prefix). 
531 | // In cases where only the first and last lines are not blank, 532 | // such as two-line comments, or comments where all inner lines 533 | // are blank, consider the last line for the prefix computation 534 | // since otherwise the prefix would be empty. 535 | // 536 | // Note that the first and last line are never empty (they 537 | // contain the opening /* and closing */ respectively) and 538 | // thus they can be ignored by the blank line check. 539 | prefix := "" 540 | prefixSet := false 541 | if len(lines) > 2 { 542 | for i, line := range lines[1 : len(lines)-1] { 543 | if isBlank(line) { 544 | lines[1+i] = "" // range starts with lines[1] 545 | } else { 546 | if !prefixSet { 547 | prefix = line 548 | prefixSet = true 549 | } 550 | prefix = commonPrefix(prefix, line) 551 | } 552 | 553 | } 554 | } 555 | // If we don't have a prefix yet, consider the last line. 556 | if !prefixSet { 557 | line := lines[len(lines)-1] 558 | prefix = commonPrefix(line, line) 559 | } 560 | 561 | /* 562 | * Check for vertical "line of stars" and correct prefix accordingly. 563 | */ 564 | lineOfStars := false 565 | if i := strings.Index(prefix, "*"); i >= 0 { 566 | // Line of stars present. 567 | if i > 0 && prefix[i-1] == ' ' { 568 | i-- // remove trailing blank from prefix so stars remain aligned 569 | } 570 | prefix = prefix[0:i] 571 | lineOfStars = true 572 | } else { 573 | // No line of stars present. 574 | // Determine the white space on the first line after the /* 575 | // and before the beginning of the comment text, assume two 576 | // blanks instead of the /* unless the first character after 577 | // the /* is a tab. If the first comment line is empty but 578 | // for the opening /*, assume up to 3 blanks or a tab. This 579 | // whitespace may be found as suffix in the common prefix. 
580 | first := lines[0] 581 | if isBlank(first[2:]) { 582 | // no comment text on the first line: 583 | // reduce prefix by up to 3 blanks or a tab 584 | // if present - this keeps comment text indented 585 | // relative to the /* and */'s if it was indented 586 | // in the first place 587 | i := len(prefix) 588 | for n := 0; n < 3 && i > 0 && prefix[i-1] == ' '; n++ { 589 | i-- 590 | } 591 | if i == len(prefix) && i > 0 && prefix[i-1] == '\t' { 592 | i-- 593 | } 594 | prefix = prefix[0:i] 595 | } else { 596 | // comment text on the first line 597 | suffix := make([]byte, len(first)) 598 | n := 2 // start after opening /* 599 | for n < len(first) && first[n] <= ' ' { 600 | suffix[n] = first[n] 601 | n++ 602 | } 603 | if n > 2 && suffix[2] == '\t' { 604 | // assume the '\t' compensates for the /* 605 | suffix = suffix[2:n] 606 | } else { 607 | // otherwise assume two blanks 608 | suffix[0], suffix[1] = ' ', ' ' 609 | suffix = suffix[0:n] 610 | } 611 | // Shorten the computed common prefix by the length of 612 | // suffix, if it is found as suffix of the prefix. 613 | prefix = strings.TrimSuffix(prefix, string(suffix)) 614 | } 615 | } 616 | 617 | // Handle last line: If it only contains a closing */, align it 618 | // with the opening /*, otherwise align the text with the other 619 | // lines. 620 | last := lines[len(lines)-1] 621 | closing := "*/" 622 | i := strings.Index(last, closing) // i >= 0 (closing is always present) 623 | if isBlank(last[0:i]) { 624 | // last line only contains closing */ 625 | if lineOfStars { 626 | closing = " */" // add blank to align final star 627 | } 628 | lines[len(lines)-1] = prefix + closing 629 | } else { 630 | // last line contains more comment text - assume 631 | // it is aligned like the other lines and include 632 | // in prefix computation 633 | prefix = commonPrefix(prefix, last) 634 | } 635 | 636 | // Remove the common prefix from all but the first and empty lines. 
637 | for i, line := range lines { 638 | if i > 0 && line != "" { 639 | lines[i] = line[len(prefix):] 640 | } 641 | } 642 | } 643 | 644 | func (p *printer) writeComment(comment *ast.Comment) { 645 | text := comment.Text 646 | pos := p.posFor(comment.Pos()) 647 | 648 | const linePrefix = "//line " 649 | if strings.HasPrefix(text, linePrefix) && (!pos.IsValid() || pos.Column == 1) { 650 | // Possibly a //-style line directive. 651 | // Suspend indentation temporarily to keep line directive valid. 652 | defer func(indent int) { p.indent = indent }(p.indent) 653 | p.indent = 0 654 | } 655 | 656 | // shortcut common case of //-style comments 657 | if text[1] == '/' { 658 | cstr := "#" + trimRight(text[2:]) 659 | p.writeString(pos, cstr, true) 660 | return 661 | } 662 | 663 | // for /*-style comments, print line by line and let the 664 | // write function take care of the proper indentation 665 | lines := strings.Split(text, "\n") 666 | 667 | // The comment started in the first column but is going 668 | // to be indented. For an idempotent result, add indentation 669 | // to all lines such that they look like they were indented 670 | // before - this will make sure the common prefix computation 671 | // is the same independent of how many times formatting is 672 | // applied (was issue 1835). 673 | if pos.IsValid() && pos.Column == 1 && p.indent > 0 { 674 | for i, line := range lines[1:] { 675 | lines[1+i] = " " + line 676 | } 677 | } 678 | 679 | stripCommonPrefix(lines) 680 | 681 | // write comment lines, separated by formfeed, 682 | // without a line break after the last line 683 | for i, line := range lines { 684 | if i > 0 { 685 | p.writeByte('\f', 1) 686 | pos = p.pos 687 | } 688 | if len(line) > 0 { 689 | p.writeString(pos, trimRight(line), true) 690 | } 691 | } 692 | } 693 | 694 | // writeCommentSuffix writes a line break after a comment if indicated 695 | // and processes any leftover indentation information. 
If a line break 696 | // is needed, the kind of break (newline vs formfeed) depends on the 697 | // pending whitespace. The writeCommentSuffix result indicates if a 698 | // newline was written or if a formfeed was dropped from the whitespace 699 | // buffer. 700 | // 701 | func (p *printer) writeCommentSuffix(needsLinebreak bool) (wroteNewline, droppedFF bool) { 702 | for i, ch := range p.wsbuf { 703 | switch ch { 704 | case blank, vtab: 705 | // ignore trailing whitespace 706 | p.wsbuf[i] = ignore 707 | case indent, unindent: 708 | // don't lose indentation information 709 | case newline, formfeed: 710 | // if we need a line break, keep exactly one 711 | // but remember if we dropped any formfeeds 712 | if needsLinebreak { 713 | needsLinebreak = false 714 | wroteNewline = true 715 | } else { 716 | if ch == formfeed { 717 | droppedFF = true 718 | } 719 | p.wsbuf[i] = ignore 720 | } 721 | } 722 | } 723 | p.writeWhitespace(len(p.wsbuf)) 724 | 725 | // make sure we have a line break 726 | if needsLinebreak { 727 | p.writeByte('\n', 1) 728 | wroteNewline = true 729 | } 730 | 731 | return 732 | } 733 | 734 | // containsLinebreak reports whether the whitespace buffer contains any line breaks. 735 | func (p *printer) containsLinebreak() bool { 736 | for _, ch := range p.wsbuf { 737 | if ch == newline || ch == formfeed { 738 | return true 739 | } 740 | } 741 | return false 742 | } 743 | 744 | // intersperseComments consumes all comments that appear before the next token 745 | // tok and prints it together with the buffered whitespace (i.e., the whitespace 746 | // that needs to be written before the next token). A heuristic is used to mix 747 | // the comments and whitespace. The intersperseComments result indicates if a 748 | // newline was written or if a formfeed was dropped from the whitespace buffer. 
749 | // 750 | func (p *printer) intersperseComments(next token.Position, tok token.Token) (wroteNewline, droppedFF bool) { 751 | var last *ast.Comment 752 | for p.commentBefore(next) { 753 | for _, c := range p.comment.List { 754 | p.writeCommentPrefix(p.posFor(c.Pos()), next, last, tok) 755 | p.writeComment(c) 756 | last = c 757 | } 758 | p.nextComment() 759 | } 760 | 761 | if last != nil { 762 | // If the last comment is a /*-style comment and the next item 763 | // follows on the same line but is not a comma, and not a "closing" 764 | // token immediately following its corresponding "opening" token, 765 | // add an extra separator unless explicitly disabled. Use a blank 766 | // as separator unless we have pending linebreaks, they are not 767 | // disabled, and we are outside a composite literal, in which case 768 | // we want a linebreak (issue 15137). 769 | // TODO(gri) This has become overly complicated. We should be able 770 | // to track whether we're inside an expression or statement and 771 | // use that information to decide more directly. 772 | needsLinebreak := false 773 | if p.mode&noExtraBlank == 0 && 774 | last.Text[1] == '*' && p.lineFor(last.Pos()) == next.Line && 775 | tok != token.COMMA && 776 | (tok != token.RPAREN || p.prevOpen == token.LPAREN) && 777 | (tok != token.RBRACK || p.prevOpen == token.LBRACK) { 778 | if p.containsLinebreak() && p.mode&noExtraLinebreak == 0 && p.level == 0 { 779 | needsLinebreak = true 780 | } else { 781 | p.writeByte(' ', 1) 782 | } 783 | } 784 | // Ensure that there is a line break after a //-style comment, 785 | // before EOF, and before a closing '}' unless explicitly disabled. 
786 | if last.Text[1] == '/' || 787 | tok == token.EOF || 788 | tok == token.RBRACE && p.mode&noExtraLinebreak == 0 { 789 | needsLinebreak = true 790 | } 791 | return p.writeCommentSuffix(needsLinebreak) 792 | } 793 | 794 | // no comment was written - we should never reach here since 795 | // intersperseComments should not be called in that case 796 | p.internalError("intersperseComments called without pending comments") 797 | return 798 | } 799 | 800 | // whiteWhitespace writes the first n whitespace entries. 801 | func (p *printer) writeWhitespace(n int) { 802 | // write entries 803 | for i := 0; i < n; i++ { 804 | switch ch := p.wsbuf[i]; ch { 805 | case ignore: 806 | // ignore! 807 | case indent: 808 | p.indent++ 809 | case unindent: 810 | p.indent-- 811 | if p.indent < 0 { 812 | p.internalError("negative indentation:", p.indent) 813 | p.indent = 0 814 | } 815 | case newline, formfeed: 816 | // A line break immediately followed by a "correcting" 817 | // unindent is swapped with the unindent - this permits 818 | // proper label positioning. If a comment is between 819 | // the line break and the label, the unindent is not 820 | // part of the comment whitespace prefix and the comment 821 | // will be positioned correctly indented. 822 | if i+1 < n && p.wsbuf[i+1] == unindent { 823 | // Use a formfeed to terminate the current section. 824 | // Otherwise, a long label name on the next line leading 825 | // to a wide column may increase the indentation column 826 | // of lines before the label; effectively leading to wrong 827 | // indentation. 
828 | p.wsbuf[i], p.wsbuf[i+1] = unindent, formfeed 829 | i-- // do it again 830 | continue 831 | } 832 | fallthrough 833 | default: 834 | p.writeByte(byte(ch), 1) 835 | } 836 | } 837 | 838 | // shift remaining entries down 839 | l := copy(p.wsbuf, p.wsbuf[n:]) 840 | p.wsbuf = p.wsbuf[:l] 841 | } 842 | 843 | // ---------------------------------------------------------------------------- 844 | // Printing interface 845 | 846 | // nlines limits n to maxNewlines. 847 | func nlimit(n int) int { 848 | if n > maxNewlines { 849 | n = maxNewlines 850 | } 851 | return n 852 | } 853 | 854 | func mayCombine(prev token.Token, next byte) (b bool) { 855 | switch prev { 856 | case token.INT: 857 | b = next == '.' // 1. 858 | case token.ADD: 859 | b = next == '+' // ++ 860 | case token.SUB: 861 | b = next == '-' // -- 862 | case token.QUO: 863 | b = next == '*' // /* 864 | case token.LSS: 865 | b = next == '-' || next == '<' // <- or << 866 | case token.AND: 867 | b = next == '&' || next == '^' // && or &^ 868 | } 869 | return 870 | } 871 | 872 | // print prints a list of "items" (roughly corresponding to syntactic 873 | // tokens, but also including whitespace and formatting information). 874 | // It is the only print function that should be called directly from 875 | // any of the AST printing functions in nodes.go. 876 | // 877 | // Whitespace is accumulated until a non-whitespace token appears. Any 878 | // comments that need to appear before that token are printed first, 879 | // taking into account the amount and structure of any pending white- 880 | // space for best comment placement. Then, any leftover whitespace is 881 | // printed, followed by the actual token. 
882 | // 883 | func (p *printer) print(args ...interface{}) { 884 | for _, arg := range args { 885 | // information about the current arg 886 | var data string 887 | var isLit bool 888 | var impliedSemi bool // value for p.impliedSemi after this arg 889 | 890 | // record previous opening token, if any 891 | switch p.lastTok { 892 | case token.ILLEGAL: 893 | // ignore (white space) 894 | case token.LPAREN, token.LBRACK: 895 | p.prevOpen = p.lastTok 896 | default: 897 | // other tokens followed any opening token 898 | p.prevOpen = token.ILLEGAL 899 | } 900 | 901 | switch x := arg.(type) { 902 | case pmode: 903 | // toggle printer mode 904 | p.mode ^= x 905 | continue 906 | 907 | case whiteSpace: 908 | if x == ignore { 909 | // don't add ignore's to the buffer; they 910 | // may screw up "correcting" unindents (see 911 | // LabeledStmt) 912 | continue 913 | } 914 | i := len(p.wsbuf) 915 | if i == cap(p.wsbuf) { 916 | // Whitespace sequences are very short so this should 917 | // never happen. Handle gracefully (but possibly with 918 | // bad comment placement) if it does happen. 
919 | p.writeWhitespace(i) 920 | i = 0 921 | } 922 | p.wsbuf = p.wsbuf[0 : i+1] 923 | p.wsbuf[i] = x 924 | if x == newline || x == formfeed { 925 | // newlines affect the current state (p.impliedSemi) 926 | // and not the state after printing arg (impliedSemi) 927 | // because comments can be interspersed before the arg 928 | // in this case 929 | p.impliedSemi = false 930 | } 931 | p.lastTok = token.ILLEGAL 932 | continue 933 | 934 | case *ast.Ident: 935 | data = x.Name 936 | impliedSemi = true 937 | p.lastTok = token.IDENT 938 | 939 | case *ast.BasicLit: 940 | data = x.Value 941 | isLit = true 942 | impliedSemi = true 943 | p.lastTok = x.Kind 944 | 945 | case token.Token: 946 | s := x.String() 947 | if mayCombine(p.lastTok, s[0]) { 948 | // the previous and the current token must be 949 | // separated by a blank otherwise they combine 950 | // into a different incorrect token sequence 951 | // (except for token.INT followed by a '.' this 952 | // should never happen because it is taken care 953 | // of via binary expression formatting) 954 | if len(p.wsbuf) != 0 { 955 | p.internalError("whitespace buffer not empty") 956 | } 957 | p.wsbuf = p.wsbuf[0:1] 958 | p.wsbuf[0] = ' ' 959 | } 960 | data = s 961 | // some keywords followed by a newline imply a semicolon 962 | switch x { 963 | case token.BREAK, token.CONTINUE, token.FALLTHROUGH, token.RETURN, 964 | token.INC, token.DEC, token.RPAREN, token.RBRACK, token.RBRACE: 965 | impliedSemi = true 966 | } 967 | p.lastTok = x 968 | 969 | case token.Pos: 970 | if x.IsValid() { 971 | p.pos = p.posFor(x) // accurate position of next item 972 | } 973 | continue 974 | 975 | case string: 976 | // incorrect AST - print error message 977 | data = x 978 | isLit = true 979 | impliedSemi = true 980 | p.lastTok = token.STRING 981 | 982 | default: 983 | fmt.Fprintf(os.Stderr, "print: unsupported argument %v (%T)\n", arg, arg) 984 | panic("go/printer type") 985 | } 986 | // data != "" 987 | 988 | next := p.pos // estimated/accurate 
position of next item 989 | wroteNewline, droppedFF := p.flush(next, p.lastTok) 990 | 991 | // intersperse extra newlines if present in the source and 992 | // if they don't cause extra semicolons (don't do this in 993 | // flush as it will cause extra newlines at the end of a file) 994 | if !p.impliedSemi { 995 | n := nlimit(next.Line - p.pos.Line) 996 | // don't exceed maxNewlines if we already wrote one 997 | if wroteNewline && n == maxNewlines { 998 | n = maxNewlines - 1 999 | } 1000 | if n > 0 { 1001 | ch := byte('\n') 1002 | if droppedFF { 1003 | ch = '\f' // use formfeed since we dropped one before 1004 | } 1005 | p.writeByte(ch, n) 1006 | impliedSemi = false 1007 | } 1008 | } 1009 | 1010 | // the next token starts now - record its line number if requested 1011 | if p.linePtr != nil { 1012 | *p.linePtr = p.out.Line 1013 | p.linePtr = nil 1014 | } 1015 | 1016 | switch data { 1017 | case "true": 1018 | data = "True" 1019 | case "false": 1020 | data = "False" 1021 | case "nil": 1022 | if p.Config.Mode&GoPy != 0 { 1023 | data = "go.nil" 1024 | } 1025 | } 1026 | 1027 | p.writeString(next, data, isLit) 1028 | p.impliedSemi = impliedSemi 1029 | } 1030 | } 1031 | 1032 | // flush prints any pending comments and whitespace occurring textually 1033 | // before the position of the next token tok. The flush result indicates 1034 | // if a newline was written or if a formfeed was dropped from the whitespace 1035 | // buffer. 1036 | // 1037 | func (p *printer) flush(next token.Position, tok token.Token) (wroteNewline, droppedFF bool) { 1038 | if p.commentBefore(next) { 1039 | // if there are comments before the next item, intersperse them 1040 | wroteNewline, droppedFF = p.intersperseComments(next, tok) 1041 | } else { 1042 | // otherwise, write any leftover whitespace 1043 | p.writeWhitespace(len(p.wsbuf)) 1044 | } 1045 | return 1046 | } 1047 | 1048 | // getNode returns the ast.CommentGroup associated with n, if any. 
1049 | func getDoc(n ast.Node) *ast.CommentGroup { 1050 | switch n := n.(type) { 1051 | case *ast.Field: 1052 | return n.Doc 1053 | case *ast.ImportSpec: 1054 | return n.Doc 1055 | case *ast.ValueSpec: 1056 | return n.Doc 1057 | case *ast.TypeSpec: 1058 | return n.Doc 1059 | case *ast.GenDecl: 1060 | return n.Doc 1061 | // case *ast.FuncDecl: // handled separately 1062 | // return n.Doc 1063 | case *ast.File: 1064 | return n.Doc 1065 | } 1066 | return nil 1067 | } 1068 | 1069 | func getLastComment(n ast.Node) *ast.CommentGroup { 1070 | switch n := n.(type) { 1071 | case *ast.Field: 1072 | return n.Comment 1073 | case *ast.ImportSpec: 1074 | return n.Comment 1075 | case *ast.ValueSpec: 1076 | return n.Comment 1077 | case *ast.TypeSpec: 1078 | return n.Comment 1079 | case *ast.GenDecl: 1080 | if len(n.Specs) > 0 { 1081 | return getLastComment(n.Specs[len(n.Specs)-1]) 1082 | } 1083 | case *ast.File: 1084 | if len(n.Comments) > 0 { 1085 | return n.Comments[len(n.Comments)-1] 1086 | } 1087 | } 1088 | return nil 1089 | } 1090 | 1091 | func (p *printer) printNode(node interface{}) error { 1092 | // unpack *CommentedNode, if any 1093 | var comments []*ast.CommentGroup 1094 | if cnode, ok := node.(*CommentedNode); ok { 1095 | node = cnode.Node 1096 | comments = cnode.Comments 1097 | } 1098 | 1099 | if comments != nil { 1100 | // commented node - restrict comment list to relevant range 1101 | n, ok := node.(ast.Node) 1102 | if !ok { 1103 | goto unsupported 1104 | } 1105 | beg := n.Pos() 1106 | end := n.End() 1107 | // if the node has associated documentation, 1108 | // include that commentgroup in the range 1109 | // (the comment list is sorted in the order 1110 | // of the comment appearance in the source code) 1111 | if doc := getDoc(n); doc != nil { 1112 | beg = doc.Pos() 1113 | } 1114 | if com := getLastComment(n); com != nil { 1115 | if e := com.End(); e > end { 1116 | end = e 1117 | } 1118 | } 1119 | // token.Pos values are global offsets, we can 1120 | // compare them 
directly 1121 | i := 0 1122 | for i < len(comments) && comments[i].End() < beg { 1123 | i++ 1124 | } 1125 | j := i 1126 | for j < len(comments) && comments[j].Pos() < end { 1127 | j++ 1128 | } 1129 | if i < j { 1130 | p.comments = comments[i:j] 1131 | } 1132 | } else if n, ok := node.(*ast.File); ok { 1133 | // use ast.File comments, if any 1134 | // note: this seems to be typical source for a parsed .go file 1135 | p.comments = n.Comments 1136 | } 1137 | 1138 | // if there are no comments, use node comments 1139 | p.useNodeComments = p.comments == nil 1140 | 1141 | // get comments ready for use 1142 | p.nextComment() 1143 | 1144 | // format node 1145 | switch n := node.(type) { 1146 | case ast.Expr: 1147 | p.expr(n) 1148 | case ast.Stmt: 1149 | // A labeled statement will un-indent to position the label. 1150 | // Set p.indent to 1 so we don't get indent "underflow". 1151 | if _, ok := n.(*ast.LabeledStmt); ok { 1152 | p.indent = 1 1153 | } 1154 | p.stmt(n, false) 1155 | case ast.Decl: 1156 | p.decl(n) 1157 | case ast.Spec: 1158 | p.spec(n, 1, false) 1159 | case []ast.Stmt: 1160 | // A labeled statement will un-indent to position the label. 1161 | // Set p.indent to 1 so we don't get indent "underflow". 1162 | for _, s := range n { 1163 | if _, ok := s.(*ast.LabeledStmt); ok { 1164 | p.indent = 1 1165 | } 1166 | } 1167 | p.stmtList(n, 0, false) 1168 | case []ast.Decl: 1169 | p.declList(n) 1170 | case *ast.File: 1171 | p.file(n) 1172 | default: 1173 | goto unsupported 1174 | } 1175 | 1176 | return nil 1177 | 1178 | unsupported: 1179 | return fmt.Errorf("go/printer: unsupported node type %T", node) 1180 | } 1181 | 1182 | // ---------------------------------------------------------------------------- 1183 | // Trimmer 1184 | 1185 | // A trimmer is an io.Writer filter for stripping tabwriter.Escape 1186 | // characters, trailing blanks and tabs, and for converting formfeed 1187 | // and vtab characters into newlines and htabs (in case no tabwriter 1188 | // is used). 
// Text bracketed by tabwriter.Escape characters is passed
// through unchanged.
//
type trimmer struct {
	output io.Writer // destination that receives the filtered bytes
	state  int       // current state: inSpace, inEscape, or inText (see below)
	space  []byte    // buffered blanks/tabs that have not been written yet
}

// trimmer is implemented as a state machine.
// It can be in one of the following states:
const (
	inSpace  = iota // inside space
	inEscape        // inside text bracketed by tabwriter.Escapes
	inText          // inside text
)

// resetSpace switches the trimmer back to the inSpace state and empties the
// pending-whitespace buffer; the buffer's backing array is kept for reuse.
func (p *trimmer) resetSpace() {
	p.state = inSpace
	p.space = p.space[0:0]
}

// Design note: It is tempting to eliminate extra blanks occurring in
//              whitespace in this function as it could simplify some
//              of the blanks logic in the node printing functions.
//              However, this would mess up any formatting done by
//              the tabwriter.

// aNewline is the single newline written for every '\n' or '\f' seen
// outside Escape-bracketed text.
var aNewline = []byte("\n")

// Write implements io.Writer. Outside Escape-bracketed text it forwards data
// to p.output with these transformations:
//   - '\v' is treated as '\t';
//   - blanks and tabs are buffered in p.space and dropped when a newline or
//     formfeed follows them, so no emitted line ends in whitespace;
//   - '\n' and '\f' are both emitted as a single '\n'.
//
// tabwriter.Escape bytes themselves are removed, and the bytes between a pair
// of them are copied verbatim (including any '\v', '\n' or whitespace).
//
// Whitespace still pending at the end of data stays in p.space until a later
// Write shows whether it is trailing. On success n == len(data); if a write to
// p.output fails inside the loop, n is the index in data of the byte that was
// being processed.
//
// NOTE(review): the final switch resets the state to inSpace, so an
// Escape-bracketed run that is split across two Write calls is not treated as
// a continuation on the second call; this presumably relies on callers never
// splitting such a run — confirm.
func (p *trimmer) Write(data []byte) (n int, err error) {
	// invariants:
	// p.state == inSpace:
	//	p.space is unwritten
	// p.state == inEscape, inText:
	//	data[m:n] is unwritten
	m := 0
	var b byte
	for n, b = range data {
		if b == '\v' {
			b = '\t' // convert to htab
		}
		switch p.state {
		case inSpace:
			switch b {
			case '\t', ' ':
				p.space = append(p.space, b)
			case '\n', '\f':
				p.resetSpace() // discard trailing space
				_, err = p.output.Write(aNewline)
			case tabwriter.Escape:
				// whitespace followed by an escaped run is written, not dropped
				_, err = p.output.Write(p.space)
				p.state = inEscape
				m = n + 1 // +1: skip tabwriter.Escape
			default:
				// first non-space byte: flush the buffered whitespace and
				// start a text run at this byte
				_, err = p.output.Write(p.space)
				p.state = inText
				m = n
			}
		case inEscape:
			// copy everything up to the closing Escape verbatim; both
			// Escape bytes are dropped
			if b == tabwriter.Escape {
				_, err = p.output.Write(data[m:n])
				p.resetSpace()
			}
		case inText:
			switch b {
			case '\t', ' ':
				// end of a text run: emit it and start buffering whitespace
				_, err = p.output.Write(data[m:n])
				p.resetSpace()
				p.space = append(p.space, b)
			case '\n', '\f':
				_, err = p.output.Write(data[m:n])
				p.resetSpace()
				if err == nil {
					_, err = p.output.Write(aNewline)
				}
			case tabwriter.Escape:
				_, err = p.output.Write(data[m:n])
				p.state = inEscape
				m = n + 1 // +1: skip tabwriter.Escape
			}
		default:
			// should be unreachable: p.state is only set to the three
			// states declared above
			panic("unreachable")
		}
		if err != nil {
			return
		}
	}
	n = len(data)

	// flush an unfinished text or escaped run; whitespace pending in the
	// inSpace state stays in p.space for the next call
	switch p.state {
	case inEscape, inText:
		_, err = p.output.Write(data[m:n])
		p.resetSpace()
	}

	return
}

// ----------------------------------------------------------------------------
// Public interface

// A Mode value is a set of flags (or 0). They control printing.
type Mode uint

const (
	RawFormat Mode = 1 << iota // do not use a tabwriter; if set, UseSpaces is ignored
	TabIndent                  // use tabs for indentation independent of UseSpaces
	UseSpaces                  // use spaces instead of tabs for alignment
	SourcePos                  // emit //line directives to preserve original source positions
	GoPy                       // support GoPy-specific Python code generation
	GoGi                       // support GoGi-specific Python code generation
)

// The mode below is not included in printer's public API because
// editing code text is deemed out of scope. Because this mode is
// unexported, it's also possible to modify or remove it based on
// the evolving needs of go/format and cmd/gofmt without breaking
// users. See discussion in CL 240683.
const (
	// normalizeNumbers means to canonicalize number
	// literal prefixes and exponents while printing.
	//
	// This value is known in and used by go/format and cmd/gofmt.
	// It is currently more convenient and performant for those
	// packages to apply number normalization during printing,
	// rather than by modifying the AST in advance.
	normalizeNumbers Mode = 1 << 30
)

// A Config node controls the output of Fprint.
type Config struct {
	Mode     Mode // default: 0
	Tabwidth int  // default: 8
	Indent   int  // default: 0 (all code is indented at least by this much)
}

// fprint implements Fprint and takes a nodeSizes map for setting up the printer state.
//
// The node (and any outstanding comments) is first printed into the
// printer's in-memory output buffer. That buffer is then written in a single
// call through a trimmer and, unless RawFormat is set, a tabwriter placed in
// front of the trimmer; finally the tabwriter (if any) is flushed. It returns
// the first error from printing, writing, or flushing.
func (cfg *Config) fprint(output io.Writer, fset *token.FileSet, node interface{}, nodeSizes map[ast.Node]int) (err error) {
	// print node
	var p printer
	p.init(cfg, fset, nodeSizes)
	if err = p.printNode(node); err != nil {
		return
	}
	// print outstanding comments
	p.impliedSemi = false // EOF acts like a newline
	p.flush(token.Position{Offset: infinity, Line: infinity}, token.EOF)

	// redirect output through a trimmer to eliminate trailing whitespace
	// (Input to a tabwriter must be untrimmed since trailing tabs provide
	// formatting information. The tabwriter could provide trimming
	// functionality but no tabwriter is used when RawFormat is set.)
	output = &trimmer{output: output}

	// redirect output through a tabwriter if necessary
	if cfg.Mode&RawFormat == 0 {
		minwidth := cfg.Tabwidth

		padchar := byte('\t')
		if cfg.Mode&UseSpaces != 0 {
			padchar = ' '
		}

		twmode := tabwriter.DiscardEmptyColumns
		if cfg.Mode&TabIndent != 0 {
			minwidth = 0
			twmode |= tabwriter.TabIndent
		}

		output = tabwriter.NewWriter(output, minwidth, cfg.Tabwidth, 1, padchar, twmode)
	}

	// write printer result via tabwriter/trimmer to output
	if _, err = output.Write(p.output); err != nil {
		return
	}

	// flush tabwriter, if any
	if tw, _ := output.(*tabwriter.Writer); tw != nil {
		err = tw.Flush()
	}

	return
}

// A CommentedNode bundles an AST node and corresponding comments.
// It may be provided as argument to any of the Fprint functions.
//
type CommentedNode struct {
	Node     interface{} // *ast.File, or ast.Expr, ast.Decl, ast.Spec, or ast.Stmt
	Comments []*ast.CommentGroup
}

// Fprint "pretty-prints" an AST node to output for a given configuration cfg.
// Position information is interpreted relative to the file set fset.
// The node type must be *ast.File, *CommentedNode, []ast.Decl, []ast.Stmt,
// or assignment-compatible to ast.Expr, ast.Decl, ast.Spec, or ast.Stmt.
// Each call starts with a fresh, empty node-size map.
//
func (cfg *Config) Fprint(output io.Writer, fset *token.FileSet, node interface{}) error {
	return cfg.fprint(output, fset, node, make(map[ast.Node]int))
}

// Fprint "pretty-prints" an AST node to output.
// It calls Config.Fprint with default settings.
// Note that gofmt uses tabs for indentation but spaces for alignment;
// use format.Node (package go/format) for output that matches gofmt.
1395 | // 1396 | func Fprint(output io.Writer, fset *token.FileSet, node interface{}) error { 1397 | return (&Config{Tabwidth: 8}).Fprint(output, fset, node) 1398 | } 1399 | -------------------------------------------------------------------------------- /rewrite.go: -------------------------------------------------------------------------------- 1 | // Copyright 2020 The Go-Python Authors. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 4 | 5 | // This is based on gofmt source code: 6 | 7 | // Copyright 2009 The Go Authors. All rights reserved. 8 | // Use of this source code is governed by a BSD-style 9 | // license that can be found in the LICENSE file. 10 | 11 | package main 12 | 13 | import ( 14 | "fmt" 15 | "go/ast" 16 | "go/parser" 17 | "go/token" 18 | "os" 19 | "reflect" 20 | "strings" 21 | "unicode" 22 | "unicode/utf8" 23 | ) 24 | 25 | func initRewrite() { 26 | if *rewriteRule == "" { 27 | rewrite = nil // disable any previous rewrite 28 | return 29 | } 30 | f := strings.Split(*rewriteRule, "->") 31 | if len(f) != 2 { 32 | fmt.Fprintf(os.Stderr, "rewrite rule must be of the form 'pattern -> replacement'\n") 33 | os.Exit(2) 34 | } 35 | pattern := parseExpr(f[0], "pattern") 36 | replace := parseExpr(f[1], "replacement") 37 | rewrite = func(p *ast.File) *ast.File { return rewriteFile(pattern, replace, p) } 38 | } 39 | 40 | // parseExpr parses s as an expression. 41 | // It might make sense to expand this to allow statement patterns, 42 | // but there are problems with preserving formatting and also 43 | // with what a wildcard for a statement looks like. 44 | func parseExpr(s, what string) ast.Expr { 45 | x, err := parser.ParseExpr(s) 46 | if err != nil { 47 | fmt.Fprintf(os.Stderr, "parsing %s %s at %s\n", what, s, err) 48 | os.Exit(2) 49 | } 50 | return x 51 | } 52 | 53 | // Keep this function for debugging. 
54 | /* 55 | func dump(msg string, val reflect.Value) { 56 | fmt.Printf("%s:\n", msg) 57 | ast.Print(fileSet, val.Interface()) 58 | fmt.Println() 59 | } 60 | */ 61 | 62 | // rewriteFile applies the rewrite rule 'pattern -> replace' to an entire file. 63 | func rewriteFile(pattern, replace ast.Expr, p *ast.File) *ast.File { 64 | cmap := ast.NewCommentMap(fileSet, p, p.Comments) 65 | m := make(map[string]reflect.Value) 66 | pat := reflect.ValueOf(pattern) 67 | repl := reflect.ValueOf(replace) 68 | 69 | var rewriteVal func(val reflect.Value) reflect.Value 70 | rewriteVal = func(val reflect.Value) reflect.Value { 71 | // don't bother if val is invalid to start with 72 | if !val.IsValid() { 73 | return reflect.Value{} 74 | } 75 | val = apply(rewriteVal, val) 76 | for k := range m { 77 | delete(m, k) 78 | } 79 | if match(m, pat, val) { 80 | val = subst(m, repl, reflect.ValueOf(val.Interface().(ast.Node).Pos())) 81 | } 82 | return val 83 | } 84 | 85 | r := apply(rewriteVal, reflect.ValueOf(p)).Interface().(*ast.File) 86 | r.Comments = cmap.Filter(r).Comments() // recreate comments list 87 | return r 88 | } 89 | 90 | // set is a wrapper for x.Set(y); it protects the caller from panics if x cannot be changed to y. 91 | func set(x, y reflect.Value) { 92 | // don't bother if x cannot be set or y is invalid 93 | if !x.CanSet() || !y.IsValid() { 94 | return 95 | } 96 | defer func() { 97 | if x := recover(); x != nil { 98 | if s, ok := x.(string); ok && 99 | (strings.Contains(s, "type mismatch") || strings.Contains(s, "not assignable")) { 100 | // x cannot be set to y - ignore this rewrite 101 | return 102 | } 103 | panic(x) 104 | } 105 | }() 106 | x.Set(y) 107 | } 108 | 109 | // Values/types for special cases. 
110 | var ( 111 | objectPtrNil = reflect.ValueOf((*ast.Object)(nil)) 112 | scopePtrNil = reflect.ValueOf((*ast.Scope)(nil)) 113 | 114 | identType = reflect.TypeOf((*ast.Ident)(nil)) 115 | objectPtrType = reflect.TypeOf((*ast.Object)(nil)) 116 | positionType = reflect.TypeOf(token.NoPos) 117 | callExprType = reflect.TypeOf((*ast.CallExpr)(nil)) 118 | scopePtrType = reflect.TypeOf((*ast.Scope)(nil)) 119 | ) 120 | 121 | // apply replaces each AST field x in val with f(x), returning val. 122 | // To avoid extra conversions, f operates on the reflect.Value form. 123 | func apply(f func(reflect.Value) reflect.Value, val reflect.Value) reflect.Value { 124 | if !val.IsValid() { 125 | return reflect.Value{} 126 | } 127 | 128 | // *ast.Objects introduce cycles and are likely incorrect after 129 | // rewrite; don't follow them but replace with nil instead 130 | if val.Type() == objectPtrType { 131 | return objectPtrNil 132 | } 133 | 134 | // similarly for scopes: they are likely incorrect after a rewrite; 135 | // replace them with nil 136 | if val.Type() == scopePtrType { 137 | return scopePtrNil 138 | } 139 | 140 | switch v := reflect.Indirect(val); v.Kind() { 141 | case reflect.Slice: 142 | for i := 0; i < v.Len(); i++ { 143 | e := v.Index(i) 144 | set(e, f(e)) 145 | } 146 | case reflect.Struct: 147 | for i := 0; i < v.NumField(); i++ { 148 | e := v.Field(i) 149 | set(e, f(e)) 150 | } 151 | case reflect.Interface: 152 | e := v.Elem() 153 | set(v, f(e)) 154 | } 155 | return val 156 | } 157 | 158 | func isWildcard(s string) bool { 159 | rune, size := utf8.DecodeRuneInString(s) 160 | return size == len(s) && unicode.IsLower(rune) 161 | } 162 | 163 | // match reports whether pattern matches val, 164 | // recording wildcard submatches in m. 165 | // If m == nil, match checks whether pattern == val. 166 | func match(m map[string]reflect.Value, pattern, val reflect.Value) bool { 167 | // Wildcard matches any expression. 
If it appears multiple 168 | // times in the pattern, it must match the same expression 169 | // each time. 170 | if m != nil && pattern.IsValid() && pattern.Type() == identType { 171 | name := pattern.Interface().(*ast.Ident).Name 172 | if isWildcard(name) && val.IsValid() { 173 | // wildcards only match valid (non-nil) expressions. 174 | if _, ok := val.Interface().(ast.Expr); ok && !val.IsNil() { 175 | if old, ok := m[name]; ok { 176 | return match(nil, old, val) 177 | } 178 | m[name] = val 179 | return true 180 | } 181 | } 182 | } 183 | 184 | // Otherwise, pattern and val must match recursively. 185 | if !pattern.IsValid() || !val.IsValid() { 186 | return !pattern.IsValid() && !val.IsValid() 187 | } 188 | if pattern.Type() != val.Type() { 189 | return false 190 | } 191 | 192 | // Special cases. 193 | switch pattern.Type() { 194 | case identType: 195 | // For identifiers, only the names need to match 196 | // (and none of the other *ast.Object information). 197 | // This is a common case, handle it all here instead 198 | // of recursing down any further via reflection. 199 | p := pattern.Interface().(*ast.Ident) 200 | v := val.Interface().(*ast.Ident) 201 | return p == nil && v == nil || p != nil && v != nil && p.Name == v.Name 202 | case objectPtrType, positionType: 203 | // object pointers and token positions always match 204 | return true 205 | case callExprType: 206 | // For calls, the Ellipsis fields (token.Position) must 207 | // match since that is how f(x) and f(x...) are different. 208 | // Check them here but fall through for the remaining fields. 
209 | p := pattern.Interface().(*ast.CallExpr) 210 | v := val.Interface().(*ast.CallExpr) 211 | if p.Ellipsis.IsValid() != v.Ellipsis.IsValid() { 212 | return false 213 | } 214 | } 215 | 216 | p := reflect.Indirect(pattern) 217 | v := reflect.Indirect(val) 218 | if !p.IsValid() || !v.IsValid() { 219 | return !p.IsValid() && !v.IsValid() 220 | } 221 | 222 | switch p.Kind() { 223 | case reflect.Slice: 224 | if p.Len() != v.Len() { 225 | return false 226 | } 227 | for i := 0; i < p.Len(); i++ { 228 | if !match(m, p.Index(i), v.Index(i)) { 229 | return false 230 | } 231 | } 232 | return true 233 | 234 | case reflect.Struct: 235 | for i := 0; i < p.NumField(); i++ { 236 | if !match(m, p.Field(i), v.Field(i)) { 237 | return false 238 | } 239 | } 240 | return true 241 | 242 | case reflect.Interface: 243 | return match(m, p.Elem(), v.Elem()) 244 | } 245 | 246 | // Handle token integers, etc. 247 | return p.Interface() == v.Interface() 248 | } 249 | 250 | // subst returns a copy of pattern with values from m substituted in place 251 | // of wildcards and pos used as the position of tokens from the pattern. 252 | // if m == nil, subst returns a copy of pattern and doesn't change the line 253 | // number information. 254 | func subst(m map[string]reflect.Value, pattern reflect.Value, pos reflect.Value) reflect.Value { 255 | if !pattern.IsValid() { 256 | return reflect.Value{} 257 | } 258 | 259 | // Wildcard gets replaced with map value. 260 | if m != nil && pattern.Type() == identType { 261 | name := pattern.Interface().(*ast.Ident).Name 262 | if isWildcard(name) { 263 | if old, ok := m[name]; ok { 264 | return subst(nil, old, reflect.Value{}) 265 | } 266 | } 267 | } 268 | 269 | if pos.IsValid() && pattern.Type() == positionType { 270 | // use new position only if old position was valid in the first place 271 | if old := pattern.Interface().(token.Pos); !old.IsValid() { 272 | return pattern 273 | } 274 | return pos 275 | } 276 | 277 | // Otherwise copy. 
278 | switch p := pattern; p.Kind() { 279 | case reflect.Slice: 280 | if p.IsNil() { 281 | // Do not turn nil slices into empty slices. go/ast 282 | // guarantees that certain lists will be nil if not 283 | // populated. 284 | return reflect.Zero(p.Type()) 285 | } 286 | v := reflect.MakeSlice(p.Type(), p.Len(), p.Len()) 287 | for i := 0; i < p.Len(); i++ { 288 | v.Index(i).Set(subst(m, p.Index(i), pos)) 289 | } 290 | return v 291 | 292 | case reflect.Struct: 293 | v := reflect.New(p.Type()).Elem() 294 | for i := 0; i < p.NumField(); i++ { 295 | v.Field(i).Set(subst(m, p.Field(i), pos)) 296 | } 297 | return v 298 | 299 | case reflect.Ptr: 300 | v := reflect.New(p.Type()).Elem() 301 | if elem := p.Elem(); elem.IsValid() { 302 | v.Set(subst(m, elem, pos).Addr()) 303 | } 304 | return v 305 | 306 | case reflect.Interface: 307 | v := reflect.New(p.Type()).Elem() 308 | if elem := p.Elem(); elem.IsValid() { 309 | v.Set(subst(m, elem, pos)) 310 | } 311 | return v 312 | } 313 | 314 | return pattern 315 | } 316 | -------------------------------------------------------------------------------- /simplify.go: -------------------------------------------------------------------------------- 1 | // Copyright 2020 The Go-Python Authors. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 4 | 5 | // This is based on gofmt source code: 6 | 7 | // Copyright 2010 The Go Authors. All rights reserved. 8 | // Use of this source code is governed by a BSD-style 9 | // license that can be found in the LICENSE file. 
10 | 11 | package main 12 | 13 | import ( 14 | "go/ast" 15 | "go/token" 16 | "reflect" 17 | ) 18 | 19 | type simplifier struct{} 20 | 21 | func (s simplifier) Visit(node ast.Node) ast.Visitor { 22 | switch n := node.(type) { 23 | case *ast.CompositeLit: 24 | // array, slice, and map composite literals may be simplified 25 | outer := n 26 | var keyType, eltType ast.Expr 27 | switch typ := outer.Type.(type) { 28 | case *ast.ArrayType: 29 | eltType = typ.Elt 30 | case *ast.MapType: 31 | keyType = typ.Key 32 | eltType = typ.Value 33 | } 34 | 35 | if eltType != nil { 36 | var ktyp reflect.Value 37 | if keyType != nil { 38 | ktyp = reflect.ValueOf(keyType) 39 | } 40 | typ := reflect.ValueOf(eltType) 41 | for i, x := range outer.Elts { 42 | px := &outer.Elts[i] 43 | // look at value of indexed/named elements 44 | if t, ok := x.(*ast.KeyValueExpr); ok { 45 | if keyType != nil { 46 | s.simplifyLiteral(ktyp, keyType, t.Key, &t.Key) 47 | } 48 | x = t.Value 49 | px = &t.Value 50 | } 51 | s.simplifyLiteral(typ, eltType, x, px) 52 | } 53 | // node was simplified - stop walk (there are no subnodes to simplify) 54 | return nil 55 | } 56 | 57 | case *ast.SliceExpr: 58 | // a slice expression of the form: s[a:len(s)] 59 | // can be simplified to: s[a:] 60 | // if s is "simple enough" (for now we only accept identifiers) 61 | // 62 | // Note: This may not be correct because len may have been redeclared in another 63 | // file belonging to the same package. However, this is extremely unlikely 64 | // and so far (April 2016, after years of supporting this rewrite feature) 65 | // has never come up, so let's keep it working as is (see also #15153). 
66 | if n.Max != nil { 67 | // - 3-index slices always require the 2nd and 3rd index 68 | break 69 | } 70 | if s, _ := n.X.(*ast.Ident); s != nil && s.Obj != nil { 71 | // the array/slice object is a single, resolved identifier 72 | if call, _ := n.High.(*ast.CallExpr); call != nil && len(call.Args) == 1 && !call.Ellipsis.IsValid() { 73 | // the high expression is a function call with a single argument 74 | if fun, _ := call.Fun.(*ast.Ident); fun != nil && fun.Name == "len" && fun.Obj == nil { 75 | // the function called is "len" and it is not locally defined; and 76 | // because we don't have dot imports, it must be the predefined len() 77 | if arg, _ := call.Args[0].(*ast.Ident); arg != nil && arg.Obj == s.Obj { 78 | // the len argument is the array/slice object 79 | n.High = nil 80 | } 81 | } 82 | } 83 | } 84 | // Note: We could also simplify slice expressions of the form s[0:b] to s[:b] 85 | // but we leave them as is since sometimes we want to be very explicit 86 | // about the lower bound. 
87 | // An example where the 0 helps: 88 | // x, y, z := b[0:2], b[2:4], b[4:6] 89 | // An example where it does not: 90 | // x, y := b[:n], b[n:] 91 | 92 | case *ast.RangeStmt: 93 | // - a range of the form: for x, _ = range v {...} 94 | // can be simplified to: for x = range v {...} 95 | // - a range of the form: for _ = range v {...} 96 | // can be simplified to: for range v {...} 97 | if isBlank(n.Value) { 98 | n.Value = nil 99 | } 100 | if isBlank(n.Key) && n.Value == nil { 101 | n.Key = nil 102 | } 103 | } 104 | 105 | return s 106 | } 107 | 108 | func (s simplifier) simplifyLiteral(typ reflect.Value, astType, x ast.Expr, px *ast.Expr) { 109 | ast.Walk(s, x) // simplify x 110 | 111 | // if the element is a composite literal and its literal type 112 | // matches the outer literal's element type exactly, the inner 113 | // literal type may be omitted 114 | if inner, ok := x.(*ast.CompositeLit); ok { 115 | if match(nil, typ, reflect.ValueOf(inner.Type)) { 116 | inner.Type = nil 117 | } 118 | } 119 | // if the outer literal's element type is a pointer type *T 120 | // and the element is & of a composite literal of type T, 121 | // the inner &T may be omitted. 
122 | if ptr, ok := astType.(*ast.StarExpr); ok { 123 | if addr, ok := x.(*ast.UnaryExpr); ok && addr.Op == token.AND { 124 | if inner, ok := addr.X.(*ast.CompositeLit); ok { 125 | if match(nil, reflect.ValueOf(ptr.X), reflect.ValueOf(inner.Type)) { 126 | inner.Type = nil // drop T 127 | *px = inner // drop & 128 | } 129 | } 130 | } 131 | } 132 | } 133 | 134 | func isBlank(x ast.Expr) bool { 135 | ident, ok := x.(*ast.Ident) 136 | return ok && ident.Name == "_" 137 | } 138 | 139 | func simplify(f *ast.File) { 140 | // remove empty declarations such as "const ()", etc 141 | removeEmptyDeclGroups(f) 142 | 143 | var s simplifier 144 | ast.Walk(s, f) 145 | } 146 | 147 | func removeEmptyDeclGroups(f *ast.File) { 148 | i := 0 149 | for _, d := range f.Decls { 150 | if g, ok := d.(*ast.GenDecl); !ok || !isEmpty(f, g) { 151 | f.Decls[i] = d 152 | i++ 153 | } 154 | } 155 | f.Decls = f.Decls[:i] 156 | } 157 | 158 | func isEmpty(f *ast.File, g *ast.GenDecl) bool { 159 | if g.Doc != nil || g.Specs != nil { 160 | return false 161 | } 162 | 163 | for _, c := range f.Comments { 164 | // if there is a comment in the declaration, it is not considered empty 165 | if g.Pos() <= c.Pos() && c.End() <= g.End() { 166 | return false 167 | } 168 | } 169 | 170 | return true 171 | } 172 | -------------------------------------------------------------------------------- /testdata/basic.golden: -------------------------------------------------------------------------------- 1 | package test 2 | 3 | GlobInt int = 2 4 | GlobStr = "a string" 5 | GlobBool = False 6 | 7 | class MyStru: 8 | # A struct definition 9 | """ 10 | A struct definition 11 | 12 | # field desc 13 | """ 14 | 15 | def __init__(self): 16 | self.A = int() 17 | self.B = float() # `desc:"field tag"` 18 | self.C = str() # `desc:"more tags"` 19 | 20 | def MethOne(st, arg1): 21 | """ 22 | MethOne does something 23 | """ 24 | rv = st.A 25 | for a in SomeList : 26 | rv += a 27 | st.A = True 28 | 29 | ano = MyStru(A= 22, B= 44.2, C= 
"happy") 30 | 31 | return rv 32 | 33 | def MethTwo(st, arg1, arg2, arg3): 34 | """ 35 | MethTwo does something 36 | it is pretty cool 37 | not really sure about that 38 | """ 39 | rv = st.A 40 | for a in range(100): 41 | rv += a 42 | switch rv: 43 | if 100: 44 | rv *= 2 45 | if 500: 46 | rv /= 5 47 | return rv 48 | 49 | 50 | # A global function 51 | def GlobFun(a, b): 52 | """ 53 | A global function 54 | """ 55 | if a > b and a == 0 or b == 0: 56 | return a + b 57 | elif a == b: 58 | return a * b 59 | else: 60 | return a - b 61 | 62 | -------------------------------------------------------------------------------- /testdata/basic.input: -------------------------------------------------------------------------------- 1 | package test 2 | 3 | var ( 4 | GlobInt int = 2 5 | GlobStr = "a string" 6 | GlobBool = false 7 | ) 8 | 9 | // A struct definition 10 | type MyStru struct { 11 | A int // field desc 12 | B float32 `desc:"field tag"` 13 | C string `desc:"more tags"` 14 | } 15 | 16 | // A global function 17 | func GlobFun(a, b float32) float32 { 18 | if a > b && a == 0 || b == 0 { 19 | return a + b 20 | } else if a == b { 21 | return a * b 22 | } else { 23 | return a - b 24 | } 25 | } 26 | 27 | // MethOne does something 28 | func (st *MyStru) MethOne(arg1 float32) int { 29 | rv := st.A 30 | for _, a := range SomeList { 31 | rv += a 32 | } 33 | st.A = true 34 | 35 | ano := MyStru{A: 22, B: 44.2, C: "happy"} 36 | 37 | return rv 38 | } 39 | 40 | // MethTwo does something 41 | // it is pretty cool 42 | // not really sure about that 43 | func (st *MyStru) MethTwo(arg1, arg2 float32, arg3 int) int { 44 | rv := st.A 45 | for a := 0; a < 100; a++ { 46 | rv += a 47 | } 48 | switch rv { 49 | case 100: 50 | rv *= 2 51 | case 500: 52 | rv /= 5 53 | } 54 | return rv 55 | } 56 | -------------------------------------------------------------------------------- /testdata/ra25.input: -------------------------------------------------------------------------------- 1 | // Copyright (c) 
2019, The Emergent Authors. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 4 | 5 | //gofmt -gogi 6 | 7 | // ra25 runs a simple random-associator four-layer leabra network 8 | // that uses the standard supervised learning paradigm to learn 9 | // mappings between 25 random input / output patterns 10 | // defined over 5x5 input / output layers (i.e., 25 units) 11 | package main 12 | 13 | import ( 14 | "flag" 15 | "fmt" 16 | "log" 17 | "math/rand" 18 | "os" 19 | "strconv" 20 | "strings" 21 | "time" 22 | 23 | "github.com/emer/emergent/emer" 24 | "github.com/emer/emergent/env" 25 | "github.com/emer/emergent/netview" 26 | "github.com/emer/emergent/params" 27 | "github.com/emer/emergent/patgen" 28 | "github.com/emer/emergent/prjn" 29 | "github.com/emer/emergent/relpos" 30 | "github.com/emer/etable/agg" 31 | "github.com/emer/etable/eplot" 32 | "github.com/emer/etable/etable" 33 | "github.com/emer/etable/etensor" 34 | _ "github.com/emer/etable/etview" // include to get gui views 35 | "github.com/emer/etable/split" 36 | "github.com/emer/leabra/leabra" 37 | "github.com/goki/gi/gi" 38 | "github.com/goki/gi/gimain" 39 | "github.com/goki/gi/giv" 40 | "github.com/goki/ki/ki" 41 | "github.com/goki/ki/kit" 42 | "github.com/goki/mat32" 43 | ) 44 | 45 | func main() { 46 | TheSim.New() 47 | TheSim.Config() 48 | if len(os.Args) > 1 { 49 | TheSim.CmdArgs() // simple assumption is that any args = no gui -- could add explicit arg if you want 50 | } else { 51 | gimain.Main(func() { // this starts gui -- requires valid OpenGL display connection (e.g., X11) 52 | guirun() 53 | }) 54 | } 55 | } 56 | 57 | func guirun() { 58 | TheSim.Init() 59 | win := TheSim.ConfigGui() 60 | win.StartEventLoop() 61 | } 62 | 63 | // LogPrec is precision for saving float values in logs 64 | const LogPrec = 4 65 | 66 | // ParamSets is the default set of parameters -- Base is always applied, and others can be optionally 67 | // 
selected to apply on top of that 68 | var ParamSets = params.Sets{ 69 | {Name: "Base", Desc: "these are the best params", Sheets: params.Sheets{ 70 | "Network": ¶ms.Sheet{ 71 | {Sel: "Prjn", Desc: "norm and momentum on works better, but wt bal is not better for smaller nets", 72 | Params: params.Params{ 73 | "Prjn.Learn.Norm.On": "true", 74 | "Prjn.Learn.Momentum.On": "true", 75 | "Prjn.Learn.WtBal.On": "false", 76 | }}, 77 | {Sel: "Layer", Desc: "using default 1.8 inhib for all of network -- can explore", 78 | Params: params.Params{ 79 | "Layer.Inhib.Layer.Gi": "1.8", 80 | "Layer.Act.Gbar.L": "0.1", // set explictly, new default, a bit better vs 0.2 81 | }}, 82 | {Sel: ".Back", Desc: "top-down back-projections MUST have lower relative weight scale, otherwise network hallucinates", 83 | Params: params.Params{ 84 | "Prjn.WtScale.Rel": "0.2", 85 | }}, 86 | {Sel: "#Output", Desc: "output definitely needs lower inhib -- true for smaller layers in general", 87 | Params: params.Params{ 88 | "Layer.Inhib.Layer.Gi": "1.4", 89 | }}, 90 | }, 91 | "Sim": ¶ms.Sheet{ // sim params apply to sim object 92 | {Sel: "Sim", Desc: "best params always finish in this time", 93 | Params: params.Params{ 94 | "Sim.MaxEpcs": "50", 95 | }}, 96 | }, 97 | }}, 98 | {Name: "DefaultInhib", Desc: "output uses default inhib instead of lower", Sheets: params.Sheets{ 99 | "Network": ¶ms.Sheet{ 100 | {Sel: "#Output", Desc: "go back to default", 101 | Params: params.Params{ 102 | "Layer.Inhib.Layer.Gi": "1.8", 103 | }}, 104 | }, 105 | "Sim": ¶ms.Sheet{ // sim params apply to sim object 106 | {Sel: "Sim", Desc: "takes longer -- generally doesn't finish..", 107 | Params: params.Params{ 108 | "Sim.MaxEpcs": "100", 109 | }}, 110 | }, 111 | }}, 112 | {Name: "NoMomentum", Desc: "no momentum or normalization", Sheets: params.Sheets{ 113 | "Network": ¶ms.Sheet{ 114 | {Sel: "Prjn", Desc: "no norm or momentum", 115 | Params: params.Params{ 116 | "Prjn.Learn.Norm.On": "false", 117 | "Prjn.Learn.Momentum.On": 
"false", 118 | }}, 119 | }, 120 | }}, 121 | {Name: "WtBalOn", Desc: "try with weight bal on", Sheets: params.Sheets{ 122 | "Network": ¶ms.Sheet{ 123 | {Sel: "Prjn", Desc: "weight bal on", 124 | Params: params.Params{ 125 | "Prjn.Learn.WtBal.On": "true", 126 | }}, 127 | }, 128 | }}, 129 | } 130 | 131 | // Sim encapsulates the entire simulation model, and we define all the 132 | // functionality as methods on this struct. This structure keeps all relevant 133 | // state information organized and available without having to pass everything around 134 | // as arguments to methods, and provides the core GUI interface (note the view tags 135 | // for the fields which provide hints to how things should be displayed). 136 | type Sim struct { 137 | Net *leabra.Network `view:"no-inline" desc:"the network -- click to view / edit parameters for layers, prjns, etc"` 138 | Pats *etable.Table `view:"no-inline" desc:"the training patterns to use"` 139 | TrnEpcLog *etable.Table `view:"no-inline" desc:"training epoch-level log data"` 140 | TstEpcLog *etable.Table `view:"no-inline" desc:"testing epoch-level log data"` 141 | TstTrlLog *etable.Table `view:"no-inline" desc:"testing trial-level log data"` 142 | TstErrLog *etable.Table `view:"no-inline" desc:"log of all test trials where errors were made"` 143 | TstErrStats *etable.Table `view:"no-inline" desc:"stats on test trials where errors were made"` 144 | TstCycLog *etable.Table `view:"no-inline" desc:"testing cycle-level log data"` 145 | RunLog *etable.Table `view:"no-inline" desc:"summary log of each run"` 146 | RunStats *etable.Table `view:"no-inline" desc:"aggregate stats on all runs"` 147 | Params params.Sets `view:"no-inline" desc:"full collection of param sets"` 148 | ParamSet string `desc:"which set of *additional* parameters to use -- always applies Base and optionaly this next if set -- can use multiple names separated by spaces (don't put spaces in ParamSet names!)"` 149 | Tag string `desc:"extra tag string to add to any 
file names output from sim (e.g., weights files, log files, params for run)"` 150 | MaxRuns int `desc:"maximum number of model runs to perform"` 151 | MaxEpcs int `desc:"maximum number of epochs to run per model run"` 152 | NZeroStop int `desc:"if a positive number, training will stop after this many epochs with zero SSE"` 153 | TrainEnv env.FixedTable `desc:"Training environment -- contains everything about iterating over input / output patterns over training"` 154 | TestEnv env.FixedTable `desc:"Testing environment -- manages iterating over testing"` 155 | Time leabra.Time `desc:"leabra timing parameters and state"` 156 | ViewOn bool `desc:"whether to update the network view while running"` 157 | TrainUpdt leabra.TimeScales `desc:"at what time scale to update the display during training? Anything longer than Epoch updates at Epoch in this model"` 158 | TestUpdt leabra.TimeScales `desc:"at what time scale to update the display during testing? Anything longer than Epoch updates at Epoch in this model"` 159 | TestInterval int `desc:"how often to run through all the test patterns, in terms of training epochs -- can use 0 or -1 for no testing"` 160 | LayStatNms []string `desc:"names of layers to collect more detailed stats on (avg act, etc)"` 161 | 162 | // statistics: note use float64 as that is best for etable.Table 163 | TrlErr float64 `inactive:"+" desc:"1 if trial was error, 0 if correct -- based on SSE = 0 (subject to .5 unit-wise tolerance)"` 164 | TrlSSE float64 `inactive:"+" desc:"current trial's sum squared error"` 165 | TrlAvgSSE float64 `inactive:"+" desc:"current trial's average sum squared error"` 166 | TrlCosDiff float64 `inactive:"+" desc:"current trial's cosine difference"` 167 | EpcSSE float64 `inactive:"+" desc:"last epoch's total sum squared error"` 168 | EpcAvgSSE float64 `inactive:"+" desc:"last epoch's average sum squared error (average over trials, and over units within layer)"` 169 | EpcPctErr float64 `inactive:"+" desc:"last epoch's average 
TrlErr"` 170 | EpcPctCor float64 `inactive:"+" desc:"1 - last epoch's average TrlErr"` 171 | EpcCosDiff float64 `inactive:"+" desc:"last epoch's average cosine difference for output layer (a normalized error measure, maximum of 1 when the minus phase exactly matches the plus)"` 172 | EpcPerTrlMSec float64 `inactive:"+" desc:"how long did the epoch take per trial in wall-clock milliseconds"` 173 | FirstZero int `inactive:"+" desc:"epoch at when SSE first went to zero"` 174 | NZero int `inactive:"+" desc:"number of epochs in a row with zero SSE"` 175 | 176 | // internal state - view:"-" 177 | SumErr float64 `view:"-" inactive:"+" desc:"sum to increment as we go through epoch"` 178 | SumSSE float64 `view:"-" inactive:"+" desc:"sum to increment as we go through epoch"` 179 | SumAvgSSE float64 `view:"-" inactive:"+" desc:"sum to increment as we go through epoch"` 180 | SumCosDiff float64 `view:"-" inactive:"+" desc:"sum to increment as we go through epoch"` 181 | Win *gi.Window `view:"-" desc:"main GUI window"` 182 | NetView *netview.NetView `view:"-" desc:"the network viewer"` 183 | ToolBar *gi.ToolBar `view:"-" desc:"the master toolbar"` 184 | TrnEpcPlot *eplot.Plot2D `view:"-" desc:"the training epoch plot"` 185 | TstEpcPlot *eplot.Plot2D `view:"-" desc:"the testing epoch plot"` 186 | TstTrlPlot *eplot.Plot2D `view:"-" desc:"the test-trial plot"` 187 | TstCycPlot *eplot.Plot2D `view:"-" desc:"the test-cycle plot"` 188 | RunPlot *eplot.Plot2D `view:"-" desc:"the run plot"` 189 | TrnEpcFile *os.File `view:"-" desc:"log file"` 190 | RunFile *os.File `view:"-" desc:"log file"` 191 | ValsTsrs map[string]*etensor.Float32 `view:"-" desc:"for holding layer values"` 192 | SaveWts bool `view:"-" desc:"for command-line run only, auto-save final weights after each run"` 193 | NoGui bool `view:"-" desc:"if true, runing in no GUI mode"` 194 | LogSetParams bool `view:"-" desc:"if true, print message for all params that are set"` 195 | IsRunning bool `view:"-" desc:"true if sim is 
running"` 196 | StopNow bool `view:"-" desc:"flag to stop running"` 197 | NeedsNewRun bool `view:"-" desc:"flag to initialize NewRun if last one finished"` 198 | RndSeed int64 `view:"-" desc:"the current random seed"` 199 | LastEpcTime time.Time `view:"-" desc:"timer for last epoch"` 200 | } 201 | 202 | // this registers this Sim Type and gives it properties that e.g., 203 | // prompt for filename for save methods. 204 | var KiT_Sim = kit.Types.AddType(&Sim{}, SimProps) 205 | 206 | // TheSim is the overall state for this simulation 207 | var TheSim Sim 208 | 209 | // New creates new blank elements and initializes defaults 210 | func (ss *Sim) New() { 211 | ss.Net = &leabra.Network{} 212 | ss.Pats = &etable.Table{} 213 | ss.TrnEpcLog = &etable.Table{} 214 | ss.TstEpcLog = &etable.Table{} 215 | ss.TstTrlLog = &etable.Table{} 216 | ss.TstCycLog = &etable.Table{} 217 | ss.RunLog = &etable.Table{} 218 | ss.RunStats = &etable.Table{} 219 | ss.Params = ParamSets 220 | ss.RndSeed = 1 221 | ss.ViewOn = true 222 | ss.TrainUpdt = leabra.AlphaCycle 223 | ss.TestUpdt = leabra.Cycle 224 | ss.TestInterval = 5 225 | ss.LayStatNms = []string{"Hidden1", "Hidden2", "Output"} 226 | } 227 | 228 | //////////////////////////////////////////////////////////////////////////////////////////// 229 | // Configs 230 | 231 | // Config configures all the elements using the standard functions 232 | func (ss *Sim) Config() { 233 | //ss.ConfigPats() 234 | ss.OpenPats() 235 | ss.ConfigEnv() 236 | ss.ConfigNet(ss.Net) 237 | ss.ConfigTrnEpcLog(ss.TrnEpcLog) 238 | ss.ConfigTstEpcLog(ss.TstEpcLog) 239 | ss.ConfigTstTrlLog(ss.TstTrlLog) 240 | ss.ConfigTstCycLog(ss.TstCycLog) 241 | ss.ConfigRunLog(ss.RunLog) 242 | } 243 | 244 | func (ss *Sim) ConfigEnv() { 245 | if ss.MaxRuns == 0 { // allow user override 246 | ss.MaxRuns = 10 247 | } 248 | if ss.MaxEpcs == 0 { // allow user override 249 | ss.MaxEpcs = 50 250 | ss.NZeroStop = 5 251 | } 252 | 253 | ss.TrainEnv.Nm = "TrainEnv" 254 | ss.TrainEnv.Dsc = 
"training params and state" 255 | ss.TrainEnv.Table = etable.NewIdxView(ss.Pats) 256 | ss.TrainEnv.Validate() 257 | ss.TrainEnv.Run.Max = ss.MaxRuns // note: we are not setting epoch max -- do that manually 258 | 259 | ss.TestEnv.Nm = "TestEnv" 260 | ss.TestEnv.Dsc = "testing params and state" 261 | ss.TestEnv.Table = etable.NewIdxView(ss.Pats) 262 | ss.TestEnv.Sequential = true 263 | ss.TestEnv.Validate() 264 | 265 | // note: to create a train / test split of pats, do this: 266 | // all := etable.NewIdxView(ss.Pats) 267 | // splits, _ := split.Permuted(all, []float64{.8, .2}, []string{"Train", "Test"}) 268 | // ss.TrainEnv.Table = splits.Splits[0] 269 | // ss.TestEnv.Table = splits.Splits[1] 270 | 271 | ss.TrainEnv.Init(0) 272 | ss.TestEnv.Init(0) 273 | } 274 | 275 | func (ss *Sim) ConfigNet(net *leabra.Network) { 276 | net.InitName(net, "RA25") 277 | inp := net.AddLayer2D("Input", 5, 5, emer.Input) 278 | hid1 := net.AddLayer2D("Hidden1", 7, 7, emer.Hidden) 279 | hid2 := net.AddLayer4D("Hidden2", 2, 4, 3, 2, emer.Hidden) 280 | out := net.AddLayer2D("Output", 5, 5, emer.Target) 281 | 282 | // use this to position layers relative to each other 283 | // default is Above, YAlign = Front, XAlign = Center 284 | hid2.SetRelPos(relpos.Rel{Rel: relpos.RightOf, Other: "Hidden1", YAlign: relpos.Front, Space: 2}) 285 | 286 | // note: see emergent/prjn module for all the options on how to connect 287 | // NewFull returns a new prjn.Full connectivity pattern 288 | full := prjn.NewFull() 289 | 290 | net.ConnectLayers(inp, hid1, full, emer.Forward) 291 | net.BidirConnectLayers(hid1, hid2, full) 292 | net.BidirConnectLayers(hid2, out, full) 293 | 294 | // note: can set these to do parallel threaded computation across multiple cpus 295 | // not worth it for this small of a model, but definitely helps for larger ones 296 | // if Thread { 297 | // hid2.SetThread(1) 298 | // out.SetThread(1) 299 | // } 300 | 301 | // note: if you wanted to change a layer type from e.g., Target to 
Compare, do this: 302 | // out.SetType(emer.Compare) 303 | // that would mean that the output layer doesn't reflect target values in plus phase 304 | // and thus removes error-driven learning -- but stats are still computed. 305 | 306 | net.Defaults() 307 | ss.SetParams("Network", ss.LogSetParams) // only set Network params 308 | err := net.Build() 309 | if err != nil { 310 | log.Println(err) 311 | return 312 | } 313 | net.InitWts() 314 | } 315 | 316 | //////////////////////////////////////////////////////////////////////////////// 317 | // Init, utils 318 | 319 | // Init restarts the run, and initializes everything, including network weights 320 | // and resets the epoch log table 321 | func (ss *Sim) Init() { 322 | rand.Seed(ss.RndSeed) 323 | ss.ConfigEnv() // re-config env just in case a different set of patterns was 324 | // selected or patterns have been modified etc 325 | ss.StopNow = false 326 | ss.SetParams("", ss.LogSetParams) // all sheets 327 | ss.NewRun() 328 | ss.UpdateView(true) 329 | } 330 | 331 | // NewRndSeed gets a new random seed based on current time -- otherwise uses 332 | // the same random seed for every run 333 | func (ss *Sim) NewRndSeed() { 334 | ss.RndSeed = time.Now().UnixNano() 335 | } 336 | 337 | // Counters returns a string of the current counter state 338 | // use tabs to achieve a reasonable formatting overall 339 | // and add a few tabs at the end to allow for expansion.. 
340 | func (ss *Sim) Counters(train bool) string { 341 | if train { 342 | return fmt.Sprintf("Run:\t%d\tEpoch:\t%d\tTrial:\t%d\tCycle:\t%d\tName:\t%s\t\t\t", ss.TrainEnv.Run.Cur, ss.TrainEnv.Epoch.Cur, ss.TrainEnv.Trial.Cur, ss.Time.Cycle, ss.TrainEnv.TrialName.Cur) 343 | } else { 344 | return fmt.Sprintf("Run:\t%d\tEpoch:\t%d\tTrial:\t%d\tCycle:\t%d\tName:\t%s\t\t\t", ss.TrainEnv.Run.Cur, ss.TrainEnv.Epoch.Cur, ss.TestEnv.Trial.Cur, ss.Time.Cycle, ss.TestEnv.TrialName.Cur) 345 | } 346 | } 347 | 348 | func (ss *Sim) UpdateView(train bool) { 349 | if ss.NetView != nil && ss.NetView.IsVisible() { 350 | ss.NetView.Record(ss.Counters(train)) 351 | // note: essential to use Go version of update when called from another goroutine 352 | ss.NetView.GoUpdate() // note: using counters is significantly slower.. 353 | } 354 | } 355 | 356 | //////////////////////////////////////////////////////////////////////////////// 357 | // Running the Network, starting bottom-up.. 358 | 359 | // AlphaCyc runs one alpha-cycle (100 msec, 4 quarters) of processing. 360 | // External inputs must have already been applied prior to calling, 361 | // using ApplyExt method on relevant layers (see TrainTrial, TestTrial). 362 | // If train is true, then learning DWt or WtFmDWt calls are made. 363 | // Handles netview updating within scope of AlphaCycle 364 | func (ss *Sim) AlphaCyc(train bool) { 365 | // ss.Win.PollEvents() // this can be used instead of running in a separate goroutine 366 | viewUpdt := ss.TrainUpdt 367 | if !train { 368 | viewUpdt = ss.TestUpdt 369 | } 370 | 371 | // update prior weight changes at start, so any DWt values remain visible at end 372 | // you might want to do this less frequently to achieve a mini-batch update 373 | // in which case, move it out to the TrainTrial method where the relevant 374 | // counters are being dealt with. 
375 | if train { 376 | ss.Net.WtFmDWt() 377 | } 378 | 379 | ss.Net.AlphaCycInit() 380 | ss.Time.AlphaCycStart() 381 | for qtr := 0; qtr < 4; qtr++ { 382 | for cyc := 0; cyc < ss.Time.CycPerQtr; cyc++ { 383 | ss.Net.Cycle(&ss.Time) 384 | if !train { 385 | ss.LogTstCyc(ss.TstCycLog, ss.Time.Cycle) 386 | } 387 | ss.Time.CycleInc() 388 | if ss.ViewOn { 389 | switch viewUpdt { 390 | case leabra.Cycle: 391 | if cyc != ss.Time.CycPerQtr-1 { // will be updated by quarter 392 | ss.UpdateView(train) 393 | } 394 | case leabra.FastSpike: 395 | if (cyc+1)%10 == 0 { 396 | ss.UpdateView(train) 397 | } 398 | } 399 | } 400 | } 401 | ss.Net.QuarterFinal(&ss.Time) 402 | ss.Time.QuarterInc() 403 | if ss.ViewOn { 404 | switch { 405 | case viewUpdt <= leabra.Quarter: 406 | ss.UpdateView(train) 407 | case viewUpdt == leabra.Phase: 408 | if qtr >= 2 { 409 | ss.UpdateView(train) 410 | } 411 | } 412 | } 413 | } 414 | 415 | if train { 416 | ss.Net.DWt() 417 | } 418 | if ss.ViewOn && viewUpdt == leabra.AlphaCycle { 419 | ss.UpdateView(train) 420 | } 421 | if !train { 422 | ss.TstCycPlot.GoUpdate() // make sure up-to-date at end 423 | } 424 | } 425 | 426 | // ApplyInputs applies input patterns from given envirbonment. 427 | // It is good practice to have this be a separate method with appropriate 428 | // args so that it can be used for various different contexts 429 | // (training, testing, etc). 
430 | func (ss *Sim) ApplyInputs(en env.Env) { 431 | // ss.Net.InitExt() // clear any existing inputs -- not strictly necessary if always 432 | // going to the same layers, but good practice and cheap anyway 433 | 434 | lays := []string{"Input", "Output"} 435 | for _, lnm := range lays { 436 | ly := ss.Net.LayerByName(lnm).(leabra.LeabraLayer).AsLeabra() 437 | pats := en.State(ly.Nm) 438 | if pats != nil { 439 | ly.ApplyExt(pats) 440 | } 441 | } 442 | } 443 | 444 | // TrainTrial runs one trial of training using TrainEnv 445 | func (ss *Sim) TrainTrial() { 446 | if ss.NeedsNewRun { 447 | ss.NewRun() 448 | } 449 | 450 | ss.TrainEnv.Step() // the Env encapsulates and manages all counter state 451 | 452 | // Key to query counters FIRST because current state is in NEXT epoch 453 | // if epoch counter has changed 454 | epc, _, chg := ss.TrainEnv.Counter(env.Epoch) 455 | if chg { 456 | ss.LogTrnEpc(ss.TrnEpcLog) 457 | if ss.ViewOn && ss.TrainUpdt > leabra.AlphaCycle { 458 | ss.UpdateView(true) 459 | } 460 | if ss.TestInterval > 0 && epc%ss.TestInterval == 0 { // note: epc is *next* so won't trigger first time 461 | ss.TestAll() 462 | } 463 | if epc >= ss.MaxEpcs || (ss.NZeroStop > 0 && ss.NZero >= ss.NZeroStop) { 464 | // done with training.. 465 | ss.RunEnd() 466 | if ss.TrainEnv.Run.Incr() { // we are done! 
467 | ss.StopNow = true 468 | return 469 | } else { 470 | ss.NeedsNewRun = true 471 | return 472 | } 473 | } 474 | } 475 | 476 | ss.ApplyInputs(&ss.TrainEnv) 477 | ss.AlphaCyc(true) // train 478 | ss.TrialStats(true) // accumulate 479 | } 480 | 481 | // RunEnd is called at the end of a run -- save weights, record final log, etc here 482 | func (ss *Sim) RunEnd() { 483 | ss.LogRun(ss.RunLog) 484 | if ss.SaveWts { 485 | fnm := ss.WeightsFileName() 486 | fmt.Printf("Saving Weights to: %s\n", fnm) 487 | ss.Net.SaveWtsJSON(gi.FileName(fnm)) 488 | } 489 | } 490 | 491 | // NewRun intializes a new run of the model, using the TrainEnv.Run counter 492 | // for the new run value 493 | func (ss *Sim) NewRun() { 494 | run := ss.TrainEnv.Run.Cur 495 | ss.TrainEnv.Init(run) 496 | ss.TestEnv.Init(run) 497 | ss.Time.Reset() 498 | ss.Net.InitWts() 499 | ss.InitStats() 500 | ss.TrnEpcLog.SetNumRows(0) 501 | ss.TstEpcLog.SetNumRows(0) 502 | ss.NeedsNewRun = false 503 | } 504 | 505 | // InitStats initializes all the statistics, especially important for the 506 | // cumulative epoch stats -- called at start of new run 507 | func (ss *Sim) InitStats() { 508 | // accumulators 509 | ss.SumErr = 0 510 | ss.SumSSE = 0 511 | ss.SumAvgSSE = 0 512 | ss.SumCosDiff = 0 513 | ss.FirstZero = -1 514 | ss.NZero = 0 515 | // clear rest just to make Sim look initialized 516 | ss.TrlErr = 0 517 | ss.TrlSSE = 0 518 | ss.TrlAvgSSE = 0 519 | ss.EpcSSE = 0 520 | ss.EpcAvgSSE = 0 521 | ss.EpcPctErr = 0 522 | ss.EpcCosDiff = 0 523 | } 524 | 525 | // TrialStats computes the trial-level statistics and adds them to the epoch accumulators if 526 | // accum is true. Note that we're accumulating stats here on the Sim side so the 527 | // core algorithm side remains as simple as possible, and doesn't need to worry about 528 | // different time-scales over which stats could be accumulated etc. 
529 | // You can also aggregate directly from log data, as is done for testing stats 530 | func (ss *Sim) TrialStats(accum bool) { 531 | out := ss.Net.LayerByName("Output").(leabra.LeabraLayer).AsLeabra() 532 | ss.TrlCosDiff = float64(out.CosDiff.Cos) 533 | ss.TrlSSE, ss.TrlAvgSSE = out.MSE(0.5) // 0.5 = per-unit tolerance -- right side of .5 534 | if ss.TrlSSE > 0 { 535 | ss.TrlErr = 1 536 | } else { 537 | ss.TrlErr = 0 538 | } 539 | if accum { 540 | ss.SumErr += ss.TrlErr 541 | ss.SumSSE += ss.TrlSSE 542 | ss.SumAvgSSE += ss.TrlAvgSSE 543 | ss.SumCosDiff += ss.TrlCosDiff 544 | } 545 | } 546 | 547 | // TrainEpoch runs training trials for remainder of this epoch 548 | func (ss *Sim) TrainEpoch() { 549 | ss.StopNow = false 550 | curEpc := ss.TrainEnv.Epoch.Cur 551 | for { 552 | ss.TrainTrial() 553 | if ss.StopNow || ss.TrainEnv.Epoch.Cur != curEpc { 554 | break 555 | } 556 | } 557 | ss.Stopped() 558 | } 559 | 560 | // TrainRun runs training trials for remainder of run 561 | func (ss *Sim) TrainRun() { 562 | ss.StopNow = false 563 | curRun := ss.TrainEnv.Run.Cur 564 | for { 565 | ss.TrainTrial() 566 | if ss.StopNow || ss.TrainEnv.Run.Cur != curRun { 567 | break 568 | } 569 | } 570 | ss.Stopped() 571 | } 572 | 573 | // Train runs the full training from this point onward 574 | func (ss *Sim) Train() { 575 | ss.StopNow = false 576 | for { 577 | ss.TrainTrial() 578 | if ss.StopNow { 579 | break 580 | } 581 | } 582 | ss.Stopped() 583 | } 584 | 585 | // Stop tells the sim to stop running 586 | func (ss *Sim) Stop() { 587 | ss.StopNow = true 588 | } 589 | 590 | // Stopped is called when a run method stops running -- updates the IsRunning flag and toolbar 591 | func (ss *Sim) Stopped() { 592 | ss.IsRunning = false 593 | if ss.Win != nil { 594 | vp := ss.Win.WinViewport2D() 595 | if ss.ToolBar != nil { 596 | ss.ToolBar.UpdateActions() 597 | } 598 | vp.SetNeedsFullRender() 599 | } 600 | } 601 | 602 | // SaveWeights saves the network weights -- when called with giv.CallMethod 
603 | // it will auto-prompt for filename 604 | func (ss *Sim) SaveWeights(filename gi.FileName) { 605 | ss.Net.SaveWtsJSON(filename) 606 | } 607 | 608 | //////////////////////////////////////////////////////////////////////////////////////////// 609 | // Testing 610 | 611 | // TestTrial runs one trial of testing -- always sequentially presented inputs 612 | func (ss *Sim) TestTrial(returnOnChg bool) { 613 | ss.TestEnv.Step() 614 | 615 | // Query counters FIRST 616 | _, _, chg := ss.TestEnv.Counter(env.Epoch) 617 | if chg { 618 | if ss.ViewOn && ss.TestUpdt > leabra.AlphaCycle { 619 | ss.UpdateView(false) 620 | } 621 | ss.LogTstEpc(ss.TstEpcLog) 622 | if returnOnChg { 623 | return 624 | } 625 | } 626 | 627 | ss.ApplyInputs(&ss.TestEnv) 628 | ss.AlphaCyc(false) // !train 629 | ss.TrialStats(false) // !accumulate 630 | ss.LogTstTrl(ss.TstTrlLog) 631 | } 632 | 633 | // TestItem tests given item which is at given index in test item list 634 | func (ss *Sim) TestItem(idx int) { 635 | cur := ss.TestEnv.Trial.Cur 636 | ss.TestEnv.Trial.Cur = idx 637 | ss.TestEnv.SetTrialName() 638 | ss.ApplyInputs(&ss.TestEnv) 639 | ss.AlphaCyc(false) // !train 640 | ss.TrialStats(false) // !accumulate 641 | ss.TestEnv.Trial.Cur = cur 642 | } 643 | 644 | // TestAll runs through the full set of testing items 645 | func (ss *Sim) TestAll() { 646 | ss.TestEnv.Init(ss.TrainEnv.Run.Cur) 647 | for { 648 | ss.TestTrial(true) // return on change -- don't wrap 649 | _, _, chg := ss.TestEnv.Counter(env.Epoch) 650 | if chg || ss.StopNow { 651 | break 652 | } 653 | } 654 | } 655 | 656 | // RunTestAll runs through the full set of testing items, has stop running = false at end -- for gui 657 | func (ss *Sim) RunTestAll() { 658 | ss.StopNow = false 659 | ss.TestAll() 660 | ss.Stopped() 661 | } 662 | 663 | ///////////////////////////////////////////////////////////////////////// 664 | // Params setting 665 | 666 | // ParamsName returns name of current set of parameters 667 | func (ss *Sim) ParamsName() 
string { 668 | if ss.ParamSet == "" { 669 | return "Base" 670 | } 671 | return ss.ParamSet 672 | } 673 | 674 | // SetParams sets the params for "Base" and then current ParamSet. 675 | // If sheet is empty, then it applies all avail sheets (e.g., Network, Sim) 676 | // otherwise just the named sheet 677 | // if setMsg = true then we output a message for each param that was set. 678 | func (ss *Sim) SetParams(sheet string, setMsg bool) error { 679 | if sheet == "" { 680 | // this is important for catching typos and ensuring that all sheets can be used 681 | ss.Params.ValidateSheets([]string{"Network", "Sim"}) 682 | } 683 | err := ss.SetParamsSet("Base", sheet, setMsg) 684 | if ss.ParamSet != "" && ss.ParamSet != "Base" { 685 | sps := strings.Fields(ss.ParamSet) 686 | for _, ps := range sps { 687 | err = ss.SetParamsSet(ps, sheet, setMsg) 688 | } 689 | } 690 | return err 691 | } 692 | 693 | // SetParamsSet sets the params for given params.Set name. 694 | // If sheet is empty, then it applies all avail sheets (e.g., Network, Sim) 695 | // otherwise just the named sheet 696 | // if setMsg = true then we output a message for each param that was set. 
697 | func (ss *Sim) SetParamsSet(setNm string, sheet string, setMsg bool) error { 698 | pset, err := ss.Params.SetByNameTry(setNm) 699 | if err != nil { 700 | return err 701 | } 702 | if sheet == "" || sheet == "Network" { 703 | netp, ok := pset.Sheets["Network"] 704 | if ok { 705 | ss.Net.ApplyParams(netp, setMsg) 706 | } 707 | } 708 | 709 | if sheet == "" || sheet == "Sim" { 710 | simp, ok := pset.Sheets["Sim"] 711 | if ok { 712 | simp.Apply(ss, setMsg) 713 | } 714 | } 715 | // note: if you have more complex environments with parameters, definitely add 716 | // sheets for them, e.g., "TrainEnv", "TestEnv" etc 717 | return err 718 | } 719 | 720 | func (ss *Sim) ConfigPats() { 721 | dt := ss.Pats 722 | dt.SetMetaData("name", "TrainPats") 723 | dt.SetMetaData("desc", "Training patterns") 724 | dt.SetFromSchema(etable.Schema{ 725 | {"Name", etensor.STRING, nil, nil}, 726 | {"Input", etensor.FLOAT32, []int{5, 5}, []string{"Y", "X"}}, 727 | {"Output", etensor.FLOAT32, []int{5, 5}, []string{"Y", "X"}}, 728 | }, 25) 729 | 730 | patgen.PermutedBinaryRows(dt.Cols[1], 6, 1, 0) 731 | patgen.PermutedBinaryRows(dt.Cols[2], 6, 1, 0) 732 | dt.SaveCSV("random_5x5_25_gen.csv", etable.Comma, etable.Headers) 733 | } 734 | 735 | func (ss *Sim) OpenPats() { 736 | dt := ss.Pats 737 | dt.SetMetaData("name", "TrainPats") 738 | dt.SetMetaData("desc", "Training patterns") 739 | err := dt.OpenCSV("random_5x5_25.tsv", etable.Tab) 740 | if err != nil { 741 | log.Println(err) 742 | } 743 | } 744 | 745 | //////////////////////////////////////////////////////////////////////////////////////////// 746 | // Logging 747 | 748 | // ValsTsr gets value tensor of given name, creating if not yet made 749 | func (ss *Sim) ValsTsr(name string) *etensor.Float32 { 750 | if ss.ValsTsrs == nil { 751 | ss.ValsTsrs = make(map[string]*etensor.Float32) 752 | } 753 | tsr, ok := ss.ValsTsrs[name] 754 | if !ok { 755 | tsr = &etensor.Float32{} 756 | ss.ValsTsrs[name] = tsr 757 | } 758 | return tsr 759 | } 760 | 761 
| // RunName returns a name for this run that combines Tag and Params -- add this to 762 | // any file names that are saved. 763 | func (ss *Sim) RunName() string { 764 | if ss.Tag != "" { 765 | return ss.Tag + "_" + ss.ParamsName() 766 | } else { 767 | return ss.ParamsName() 768 | } 769 | } 770 | 771 | // RunEpochName returns a string with the run and epoch numbers with leading zeros, suitable 772 | // for using in weights file names. Uses 3, 5 digits for each. 773 | func (ss *Sim) RunEpochName(run, epc int) string { 774 | return fmt.Sprintf("%03d_%05d", run, epc) 775 | } 776 | 777 | // WeightsFileName returns default current weights file name 778 | func (ss *Sim) WeightsFileName() string { 779 | return ss.Net.Nm + "_" + ss.RunName() + "_" + ss.RunEpochName(ss.TrainEnv.Run.Cur, ss.TrainEnv.Epoch.Cur) + ".wts" 780 | } 781 | 782 | // LogFileName returns default log file name 783 | func (ss *Sim) LogFileName(lognm string) string { 784 | return ss.Net.Nm + "_" + ss.RunName() + "_" + lognm + ".tsv" 785 | } 786 | 787 | ////////////////////////////////////////////// 788 | // TrnEpcLog 789 | 790 | // LogTrnEpc adds data from current epoch to the TrnEpcLog table. 791 | // computes epoch averages prior to logging. 
792 | func (ss *Sim) LogTrnEpc(dt *etable.Table) { 793 | row := dt.Rows 794 | dt.SetNumRows(row + 1) 795 | 796 | epc := ss.TrainEnv.Epoch.Prv // this is triggered by increment so use previous value 797 | nt := float64(len(ss.TrainEnv.Order)) // number of trials in view 798 | 799 | ss.EpcSSE = ss.SumSSE / nt 800 | ss.SumSSE = 0 801 | ss.EpcAvgSSE = ss.SumAvgSSE / nt 802 | ss.SumAvgSSE = 0 803 | ss.EpcPctErr = float64(ss.SumErr) / nt 804 | ss.SumErr = 0 805 | ss.EpcPctCor = 1 - ss.EpcPctErr 806 | ss.EpcCosDiff = ss.SumCosDiff / nt 807 | ss.SumCosDiff = 0 808 | if ss.FirstZero < 0 && ss.EpcPctErr == 0 { 809 | ss.FirstZero = epc 810 | } 811 | if ss.EpcPctErr == 0 { 812 | ss.NZero++ 813 | } else { 814 | ss.NZero = 0 815 | } 816 | 817 | if ss.LastEpcTime.IsZero() { 818 | ss.EpcPerTrlMSec = 0 819 | } else { 820 | iv := time.Now().Sub(ss.LastEpcTime) 821 | ss.EpcPerTrlMSec = float64(iv) / (nt * float64(time.Millisecond)) 822 | } 823 | ss.LastEpcTime = time.Now() 824 | 825 | dt.SetCellFloat("Run", row, float64(ss.TrainEnv.Run.Cur)) 826 | dt.SetCellFloat("Epoch", row, float64(epc)) 827 | dt.SetCellFloat("SSE", row, ss.EpcSSE) 828 | dt.SetCellFloat("AvgSSE", row, ss.EpcAvgSSE) 829 | dt.SetCellFloat("PctErr", row, ss.EpcPctErr) 830 | dt.SetCellFloat("PctCor", row, ss.EpcPctCor) 831 | dt.SetCellFloat("CosDiff", row, ss.EpcCosDiff) 832 | dt.SetCellFloat("PerTrlMSec", row, ss.EpcPerTrlMSec) 833 | 834 | for _, lnm := range ss.LayStatNms { 835 | ly := ss.Net.LayerByName(lnm).(leabra.LeabraLayer).AsLeabra() 836 | dt.SetCellFloat(ly.Nm+"_ActAvg", row, float64(ly.Pools[0].ActAvg.ActPAvgEff)) 837 | } 838 | 839 | // note: essential to use Go version of update when called from another goroutine 840 | ss.TrnEpcPlot.GoUpdate() 841 | if ss.TrnEpcFile != nil { 842 | if ss.TrainEnv.Run.Cur == 0 && epc == 0 { 843 | dt.WriteCSVHeaders(ss.TrnEpcFile, etable.Tab) 844 | } 845 | dt.WriteCSVRow(ss.TrnEpcFile, row, etable.Tab) 846 | } 847 | } 848 | 849 | func (ss *Sim) ConfigTrnEpcLog(dt 
*etable.Table) { 850 | dt.SetMetaData("name", "TrnEpcLog") 851 | dt.SetMetaData("desc", "Record of performance over epochs of training") 852 | dt.SetMetaData("read-only", "true") 853 | dt.SetMetaData("precision", strconv.Itoa(LogPrec)) 854 | 855 | sch := etable.Schema{ 856 | {"Run", etensor.INT64, nil, nil}, 857 | {"Epoch", etensor.INT64, nil, nil}, 858 | {"SSE", etensor.FLOAT64, nil, nil}, 859 | {"AvgSSE", etensor.FLOAT64, nil, nil}, 860 | {"PctErr", etensor.FLOAT64, nil, nil}, 861 | {"PctCor", etensor.FLOAT64, nil, nil}, 862 | {"CosDiff", etensor.FLOAT64, nil, nil}, 863 | {"PerTrlMSec", etensor.FLOAT64, nil, nil}, 864 | } 865 | for _, lnm := range ss.LayStatNms { 866 | sch = append(sch, etable.Column{lnm + "_ActAvg", etensor.FLOAT64, nil, nil}) 867 | } 868 | dt.SetFromSchema(sch, 0) 869 | } 870 | 871 | func (ss *Sim) ConfigTrnEpcPlot(plt *eplot.Plot2D, dt *etable.Table) *eplot.Plot2D { 872 | plt.Params.Title = "Leabra Random Associator 25 Epoch Plot" 873 | plt.Params.XAxisCol = "Epoch" 874 | plt.SetTable(dt) 875 | // order of params: on, fixMin, min, fixMax, max 876 | plt.SetColParams("Run", eplot.Off, eplot.FixMin, 0, eplot.FloatMax, 0) 877 | plt.SetColParams("Epoch", eplot.Off, eplot.FixMin, 0, eplot.FloatMax, 0) 878 | plt.SetColParams("SSE", eplot.Off, eplot.FixMin, 0, eplot.FloatMax, 0) 879 | plt.SetColParams("AvgSSE", eplot.Off, eplot.FixMin, 0, eplot.FloatMax, 0) 880 | plt.SetColParams("PctErr", eplot.On, eplot.FixMin, 0, eplot.FixMax, 1) // default plot 881 | plt.SetColParams("PctCor", eplot.On, eplot.FixMin, 0, eplot.FixMax, 1) // default plot 882 | plt.SetColParams("CosDiff", eplot.Off, eplot.FixMin, 0, eplot.FixMax, 1) 883 | plt.SetColParams("PerTrlMSec", eplot.Off, eplot.FixMin, 0, eplot.FloatMax, 0) 884 | 885 | for _, lnm := range ss.LayStatNms { 886 | plt.SetColParams(lnm+"_ActAvg", eplot.Off, eplot.FixMin, 0, eplot.FixMax, .5) 887 | } 888 | return plt 889 | } 890 | 891 | ////////////////////////////////////////////// 892 | // TstTrlLog 893 | 894 | 
// LogTstTrl adds data from current trial to the TstTrlLog table. 895 | // log always contains number of testing items 896 | func (ss *Sim) LogTstTrl(dt *etable.Table) { 897 | epc := ss.TrainEnv.Epoch.Prv // this is triggered by increment so use previous value 898 | inp := ss.Net.LayerByName("Input").(leabra.LeabraLayer).AsLeabra() 899 | out := ss.Net.LayerByName("Output").(leabra.LeabraLayer).AsLeabra() 900 | 901 | trl := ss.TestEnv.Trial.Cur 902 | row := trl 903 | 904 | if dt.Rows <= row { 905 | dt.SetNumRows(row + 1) 906 | } 907 | 908 | dt.SetCellFloat("Run", row, float64(ss.TrainEnv.Run.Cur)) 909 | dt.SetCellFloat("Epoch", row, float64(epc)) 910 | dt.SetCellFloat("Trial", row, float64(trl)) 911 | dt.SetCellString("TrialName", row, ss.TestEnv.TrialName.Cur) 912 | dt.SetCellFloat("Err", row, ss.TrlErr) 913 | dt.SetCellFloat("SSE", row, ss.TrlSSE) 914 | dt.SetCellFloat("AvgSSE", row, ss.TrlAvgSSE) 915 | dt.SetCellFloat("CosDiff", row, ss.TrlCosDiff) 916 | 917 | for _, lnm := range ss.LayStatNms { 918 | ly := ss.Net.LayerByName(lnm).(leabra.LeabraLayer).AsLeabra() 919 | dt.SetCellFloat(ly.Nm+" ActM.Avg", row, float64(ly.Pools[0].ActM.Avg)) 920 | } 921 | ivt := ss.ValsTsr("Input") 922 | ovt := ss.ValsTsr("Output") 923 | inp.UnitValsTensor(ivt, "Act") 924 | dt.SetCellTensor("InAct", row, ivt) 925 | out.UnitValsTensor(ovt, "ActM") 926 | dt.SetCellTensor("OutActM", row, ovt) 927 | out.UnitValsTensor(ovt, "ActP") 928 | dt.SetCellTensor("OutActP", row, ovt) 929 | 930 | // note: essential to use Go version of update when called from another goroutine 931 | ss.TstTrlPlot.GoUpdate() 932 | } 933 | 934 | func (ss *Sim) ConfigTstTrlLog(dt *etable.Table) { 935 | inp := ss.Net.LayerByName("Input").(leabra.LeabraLayer).AsLeabra() 936 | out := ss.Net.LayerByName("Output").(leabra.LeabraLayer).AsLeabra() 937 | 938 | dt.SetMetaData("name", "TstTrlLog") 939 | dt.SetMetaData("desc", "Record of testing per input pattern") 940 | dt.SetMetaData("read-only", "true") 941 | 
dt.SetMetaData("precision", strconv.Itoa(LogPrec)) 942 | 943 | nt := ss.TestEnv.Table.Len() // number in view 944 | sch := etable.Schema{ 945 | {"Run", etensor.INT64, nil, nil}, 946 | {"Epoch", etensor.INT64, nil, nil}, 947 | {"Trial", etensor.INT64, nil, nil}, 948 | {"TrialName", etensor.STRING, nil, nil}, 949 | {"Err", etensor.FLOAT64, nil, nil}, 950 | {"SSE", etensor.FLOAT64, nil, nil}, 951 | {"AvgSSE", etensor.FLOAT64, nil, nil}, 952 | {"CosDiff", etensor.FLOAT64, nil, nil}, 953 | } 954 | for _, lnm := range ss.LayStatNms { 955 | sch = append(sch, etable.Column{lnm + " ActM.Avg", etensor.FLOAT64, nil, nil}) 956 | } 957 | sch = append(sch, etable.Schema{ 958 | {"InAct", etensor.FLOAT64, inp.Shp.Shp, nil}, 959 | {"OutActM", etensor.FLOAT64, out.Shp.Shp, nil}, 960 | {"OutActP", etensor.FLOAT64, out.Shp.Shp, nil}, 961 | }...) 962 | dt.SetFromSchema(sch, nt) 963 | } 964 | 965 | func (ss *Sim) ConfigTstTrlPlot(plt *eplot.Plot2D, dt *etable.Table) *eplot.Plot2D { 966 | plt.Params.Title = "Leabra Random Associator 25 Test Trial Plot" 967 | plt.Params.XAxisCol = "Trial" 968 | plt.SetTable(dt) 969 | // order of params: on, fixMin, min, fixMax, max 970 | plt.SetColParams("Run", eplot.Off, eplot.FixMin, 0, eplot.FloatMax, 0) 971 | plt.SetColParams("Epoch", eplot.Off, eplot.FixMin, 0, eplot.FloatMax, 0) 972 | plt.SetColParams("Trial", eplot.Off, eplot.FixMin, 0, eplot.FloatMax, 0) 973 | plt.SetColParams("TrialName", eplot.Off, eplot.FixMin, 0, eplot.FloatMax, 0) 974 | plt.SetColParams("Err", eplot.Off, eplot.FixMin, 0, eplot.FloatMax, 0) 975 | plt.SetColParams("SSE", eplot.Off, eplot.FixMin, 0, eplot.FloatMax, 0) 976 | plt.SetColParams("AvgSSE", eplot.On, eplot.FixMin, 0, eplot.FloatMax, 0) 977 | plt.SetColParams("CosDiff", eplot.On, eplot.FixMin, 0, eplot.FixMax, 1) 978 | 979 | for _, lnm := range ss.LayStatNms { 980 | plt.SetColParams(lnm+" ActM.Avg", eplot.Off, eplot.FixMin, 0, eplot.FixMax, .5) 981 | } 982 | 983 | plt.SetColParams("InAct", eplot.Off, eplot.FixMin, 0, 
eplot.FixMax, 1) 984 | plt.SetColParams("OutActM", eplot.Off, eplot.FixMin, 0, eplot.FixMax, 1) 985 | plt.SetColParams("OutActP", eplot.Off, eplot.FixMin, 0, eplot.FixMax, 1) 986 | return plt 987 | } 988 | 989 | ////////////////////////////////////////////// 990 | // TstEpcLog 991 | 992 | func (ss *Sim) LogTstEpc(dt *etable.Table) { 993 | row := dt.Rows 994 | dt.SetNumRows(row + 1) 995 | 996 | trl := ss.TstTrlLog 997 | tix := etable.NewIdxView(trl) 998 | epc := ss.TrainEnv.Epoch.Prv // ? 999 | 1000 | // note: this shows how to use agg methods to compute summary data from another 1001 | // data table, instead of incrementing on the Sim 1002 | dt.SetCellFloat("Run", row, float64(ss.TrainEnv.Run.Cur)) 1003 | dt.SetCellFloat("Epoch", row, float64(epc)) 1004 | dt.SetCellFloat("SSE", row, agg.Sum(tix, "SSE")[0]) 1005 | dt.SetCellFloat("AvgSSE", row, agg.Mean(tix, "AvgSSE")[0]) 1006 | dt.SetCellFloat("PctErr", row, agg.Mean(tix, "Err")[0]) 1007 | dt.SetCellFloat("PctCor", row, 1-agg.Mean(tix, "Err")[0]) 1008 | dt.SetCellFloat("CosDiff", row, agg.Mean(tix, "CosDiff")[0]) 1009 | 1010 | trlix := etable.NewIdxView(trl) 1011 | trlix.Filter(func(et *etable.Table, row int) bool { 1012 | return et.CellFloat("SSE", row) > 0 // include error trials 1013 | }) 1014 | ss.TstErrLog = trlix.NewTable() 1015 | 1016 | allsp := split.All(trlix) 1017 | split.Agg(allsp, "SSE", agg.AggSum) 1018 | split.Agg(allsp, "AvgSSE", agg.AggMean) 1019 | split.Agg(allsp, "InAct", agg.AggMean) 1020 | split.Agg(allsp, "OutActM", agg.AggMean) 1021 | split.Agg(allsp, "OutActP", agg.AggMean) 1022 | 1023 | ss.TstErrStats = allsp.AggsToTable(etable.AddAggName) 1024 | 1025 | // note: essential to use Go version of update when called from another goroutine 1026 | ss.TstEpcPlot.GoUpdate() 1027 | } 1028 | 1029 | func (ss *Sim) ConfigTstEpcLog(dt *etable.Table) { 1030 | dt.SetMetaData("name", "TstEpcLog") 1031 | dt.SetMetaData("desc", "Summary stats for testing trials") 1032 | dt.SetMetaData("read-only", "true") 1033 
| dt.SetMetaData("precision", strconv.Itoa(LogPrec)) 1034 | 1035 | dt.SetFromSchema(etable.Schema{ 1036 | {"Run", etensor.INT64, nil, nil}, 1037 | {"Epoch", etensor.INT64, nil, nil}, 1038 | {"SSE", etensor.FLOAT64, nil, nil}, 1039 | {"AvgSSE", etensor.FLOAT64, nil, nil}, 1040 | {"PctErr", etensor.FLOAT64, nil, nil}, 1041 | {"PctCor", etensor.FLOAT64, nil, nil}, 1042 | {"CosDiff", etensor.FLOAT64, nil, nil}, 1043 | }, 0) 1044 | } 1045 | 1046 | func (ss *Sim) ConfigTstEpcPlot(plt *eplot.Plot2D, dt *etable.Table) *eplot.Plot2D { 1047 | plt.Params.Title = "Leabra Random Associator 25 Testing Epoch Plot" 1048 | plt.Params.XAxisCol = "Epoch" 1049 | plt.SetTable(dt) 1050 | // order of params: on, fixMin, min, fixMax, max 1051 | plt.SetColParams("Run", eplot.Off, eplot.FixMin, 0, eplot.FloatMax, 0) 1052 | plt.SetColParams("Epoch", eplot.Off, eplot.FixMin, 0, eplot.FloatMax, 0) 1053 | plt.SetColParams("SSE", eplot.Off, eplot.FixMin, 0, eplot.FloatMax, 0) 1054 | plt.SetColParams("AvgSSE", eplot.Off, eplot.FixMin, 0, eplot.FloatMax, 0) 1055 | plt.SetColParams("PctErr", eplot.On, eplot.FixMin, 0, eplot.FixMax, 1) // default plot 1056 | plt.SetColParams("PctCor", eplot.On, eplot.FixMin, 0, eplot.FixMax, 1) // default plot 1057 | plt.SetColParams("CosDiff", eplot.Off, eplot.FixMin, 0, eplot.FixMax, 1) 1058 | return plt 1059 | } 1060 | 1061 | ////////////////////////////////////////////// 1062 | // TstCycLog 1063 | 1064 | // LogTstCyc adds data from current trial to the TstCycLog table. 
1065 | // log just has 100 cycles, is overwritten 1066 | func (ss *Sim) LogTstCyc(dt *etable.Table, cyc int) { 1067 | if dt.Rows <= cyc { 1068 | dt.SetNumRows(cyc + 1) 1069 | } 1070 | 1071 | dt.SetCellFloat("Cycle", cyc, float64(cyc)) 1072 | for _, lnm := range ss.LayStatNms { 1073 | ly := ss.Net.LayerByName(lnm).(leabra.LeabraLayer).AsLeabra() 1074 | dt.SetCellFloat(ly.Nm+" Ge.Avg", cyc, float64(ly.Pools[0].Inhib.Ge.Avg)) 1075 | dt.SetCellFloat(ly.Nm+" Act.Avg", cyc, float64(ly.Pools[0].Inhib.Act.Avg)) 1076 | } 1077 | 1078 | if cyc%10 == 0 { // too slow to do every cyc 1079 | // note: essential to use Go version of update when called from another goroutine 1080 | ss.TstCycPlot.GoUpdate() 1081 | } 1082 | } 1083 | 1084 | func (ss *Sim) ConfigTstCycLog(dt *etable.Table) { 1085 | dt.SetMetaData("name", "TstCycLog") 1086 | dt.SetMetaData("desc", "Record of activity etc over one trial by cycle") 1087 | dt.SetMetaData("read-only", "true") 1088 | dt.SetMetaData("precision", strconv.Itoa(LogPrec)) 1089 | 1090 | np := 100 // max cycles 1091 | sch := etable.Schema{ 1092 | {"Cycle", etensor.INT64, nil, nil}, 1093 | } 1094 | for _, lnm := range ss.LayStatNms { 1095 | sch = append(sch, etable.Column{lnm + " Ge.Avg", etensor.FLOAT64, nil, nil}) 1096 | sch = append(sch, etable.Column{lnm + " Act.Avg", etensor.FLOAT64, nil, nil}) 1097 | } 1098 | dt.SetFromSchema(sch, np) 1099 | } 1100 | 1101 | func (ss *Sim) ConfigTstCycPlot(plt *eplot.Plot2D, dt *etable.Table) *eplot.Plot2D { 1102 | plt.Params.Title = "Leabra Random Associator 25 Test Cycle Plot" 1103 | plt.Params.XAxisCol = "Cycle" 1104 | plt.SetTable(dt) 1105 | // order of params: on, fixMin, min, fixMax, max 1106 | plt.SetColParams("Cycle", eplot.Off, eplot.FixMin, 0, eplot.FloatMax, 0) 1107 | for _, lnm := range ss.LayStatNms { 1108 | plt.SetColParams(lnm+" Ge.Avg", true, true, 0, true, .5) 1109 | plt.SetColParams(lnm+" Act.Avg", true, true, 0, true, .5) 1110 | } 1111 | return plt 1112 | } 1113 | 1114 | 
////////////////////////////////////////////// 1115 | // RunLog 1116 | 1117 | // LogRun adds data from current run to the RunLog table. 1118 | func (ss *Sim) LogRun(dt *etable.Table) { 1119 | run := ss.TrainEnv.Run.Cur // this is NOT triggered by increment yet -- use Cur 1120 | row := dt.Rows 1121 | dt.SetNumRows(row + 1) 1122 | 1123 | epclog := ss.TrnEpcLog 1124 | epcix := etable.NewIdxView(epclog) 1125 | // compute mean over last N epochs for run level 1126 | nlast := 5 1127 | if nlast > epcix.Len()-1 { 1128 | nlast = epcix.Len() - 1 1129 | } 1130 | epcix.Idxs = epcix.Idxs[epcix.Len()-nlast:] 1131 | 1132 | params := ss.RunName() // includes tag 1133 | 1134 | dt.SetCellFloat("Run", row, float64(run)) 1135 | dt.SetCellString("Params", row, params) 1136 | dt.SetCellFloat("FirstZero", row, float64(ss.FirstZero)) 1137 | dt.SetCellFloat("SSE", row, agg.Mean(epcix, "SSE")[0]) 1138 | dt.SetCellFloat("AvgSSE", row, agg.Mean(epcix, "AvgSSE")[0]) 1139 | dt.SetCellFloat("PctErr", row, agg.Mean(epcix, "PctErr")[0]) 1140 | dt.SetCellFloat("PctCor", row, agg.Mean(epcix, "PctCor")[0]) 1141 | dt.SetCellFloat("CosDiff", row, agg.Mean(epcix, "CosDiff")[0]) 1142 | 1143 | runix := etable.NewIdxView(dt) 1144 | spl := split.GroupBy(runix, []string{"Params"}) 1145 | split.Desc(spl, "FirstZero") 1146 | split.Desc(spl, "PctCor") 1147 | ss.RunStats = spl.AggsToTable(etable.AddAggName) 1148 | 1149 | // note: essential to use Go version of update when called from another goroutine 1150 | ss.RunPlot.GoUpdate() 1151 | if ss.RunFile != nil { 1152 | if row == 0 { 1153 | dt.WriteCSVHeaders(ss.RunFile, etable.Tab) 1154 | } 1155 | dt.WriteCSVRow(ss.RunFile, row, etable.Tab) 1156 | } 1157 | } 1158 | 1159 | func (ss *Sim) ConfigRunLog(dt *etable.Table) { 1160 | dt.SetMetaData("name", "RunLog") 1161 | dt.SetMetaData("desc", "Record of performance at end of training") 1162 | dt.SetMetaData("read-only", "true") 1163 | dt.SetMetaData("precision", strconv.Itoa(LogPrec)) 1164 | 1165 | 
dt.SetFromSchema(etable.Schema{ 1166 | {"Run", etensor.INT64, nil, nil}, 1167 | {"Params", etensor.STRING, nil, nil}, 1168 | {"FirstZero", etensor.FLOAT64, nil, nil}, 1169 | {"SSE", etensor.FLOAT64, nil, nil}, 1170 | {"AvgSSE", etensor.FLOAT64, nil, nil}, 1171 | {"PctErr", etensor.FLOAT64, nil, nil}, 1172 | {"PctCor", etensor.FLOAT64, nil, nil}, 1173 | {"CosDiff", etensor.FLOAT64, nil, nil}, 1174 | }, 0) 1175 | } 1176 | 1177 | func (ss *Sim) ConfigRunPlot(plt *eplot.Plot2D, dt *etable.Table) *eplot.Plot2D { 1178 | plt.Params.Title = "Leabra Random Associator 25 Run Plot" 1179 | plt.Params.XAxisCol = "Run" 1180 | plt.SetTable(dt) 1181 | // order of params: on, fixMin, min, fixMax, max 1182 | plt.SetColParams("Run", eplot.Off, eplot.FixMin, 0, eplot.FloatMax, 0) 1183 | plt.SetColParams("FirstZero", eplot.On, eplot.FixMin, 0, eplot.FloatMax, 0) // default plot 1184 | plt.SetColParams("SSE", eplot.Off, eplot.FixMin, 0, eplot.FloatMax, 0) 1185 | plt.SetColParams("AvgSSE", eplot.Off, eplot.FixMin, 0, eplot.FloatMax, 0) 1186 | plt.SetColParams("PctErr", eplot.Off, eplot.FixMin, 0, eplot.FixMax, 1) 1187 | plt.SetColParams("PctCor", eplot.Off, eplot.FixMin, 0, eplot.FixMax, 1) 1188 | plt.SetColParams("CosDiff", eplot.Off, eplot.FixMin, 0, eplot.FixMax, 1) 1189 | return plt 1190 | } 1191 | 1192 | //////////////////////////////////////////////////////////////////////////////////////////// 1193 | // Gui 1194 | 1195 | // ConfigGui configures the GoGi gui interface for this simulation, 1196 | func (ss *Sim) ConfigGui() *gi.Window { 1197 | width := 1600 1198 | height := 1200 1199 | 1200 | gi.SetAppName("ra25") 1201 | gi.SetAppAbout(`This demonstrates a basic Leabra model. See emergent on GitHub.

`) 1202 | 1203 | win := gi.NewMainWindow("ra25", "Leabra Random Associator", width, height) 1204 | ss.Win = win 1205 | 1206 | vp := win.WinViewport2D() 1207 | updt := vp.UpdateStart() 1208 | 1209 | mfr := win.SetMainFrame() 1210 | 1211 | tbar := gi.AddNewToolBar(mfr, "tbar") 1212 | tbar.SetStretchMaxWidth() 1213 | ss.ToolBar = tbar 1214 | 1215 | split := gi.AddNewSplitView(mfr, "split") 1216 | split.Dim = mat32.X 1217 | split.SetStretchMax() 1218 | 1219 | sv := giv.AddNewStructView(split, "sv") 1220 | sv.SetStruct(ss) 1221 | 1222 | tv := gi.AddNewTabView(split, "tv") 1223 | 1224 | nv := tv.AddNewTab(netview.KiT_NetView, "NetView").(*netview.NetView) 1225 | nv.Var = "Act" 1226 | // nv.Params.ColorMap = "Jet" // default is ColdHot 1227 | // which fares pretty well in terms of discussion here: 1228 | // https://matplotlib.org/tutorials/colors/colormaps.html 1229 | nv.SetNet(ss.Net) 1230 | ss.NetView = nv 1231 | 1232 | nv.Scene().Camera.Pose.Pos.Set(0, 1, 2.75) // more "head on" than default which is more "top down" 1233 | nv.Scene().Camera.LookAt(mat32.Vec3{0, 0, 0}, mat32.Vec3{0, 1, 0}) 1234 | 1235 | plt := tv.AddNewTab(eplot.KiT_Plot2D, "TrnEpcPlot").(*eplot.Plot2D) 1236 | ss.TrnEpcPlot = ss.ConfigTrnEpcPlot(plt, ss.TrnEpcLog) 1237 | 1238 | plt = tv.AddNewTab(eplot.KiT_Plot2D, "TstTrlPlot").(*eplot.Plot2D) 1239 | ss.TstTrlPlot = ss.ConfigTstTrlPlot(plt, ss.TstTrlLog) 1240 | 1241 | plt = tv.AddNewTab(eplot.KiT_Plot2D, "TstCycPlot").(*eplot.Plot2D) 1242 | ss.TstCycPlot = ss.ConfigTstCycPlot(plt, ss.TstCycLog) 1243 | 1244 | plt = tv.AddNewTab(eplot.KiT_Plot2D, "TstEpcPlot").(*eplot.Plot2D) 1245 | ss.TstEpcPlot = ss.ConfigTstEpcPlot(plt, ss.TstEpcLog) 1246 | 1247 | plt = tv.AddNewTab(eplot.KiT_Plot2D, "RunPlot").(*eplot.Plot2D) 1248 | ss.RunPlot = ss.ConfigRunPlot(plt, ss.RunLog) 1249 | 1250 | split.SetSplits(.3, .7) 1251 | 1252 | tbar.AddAction(gi.ActOpts{Label: "Init", Icon: "update", Tooltip: "Initialize everything including network weights, and start over. 
Also applies current params.", UpdateFunc: func(act *gi.Action) { 1253 | act.SetActiveStateUpdt(!ss.IsRunning) 1254 | }}, win.This(), func(recv, send ki.Ki, sig int64, data interface{}) { 1255 | ss.Init() 1256 | vp.SetNeedsFullRender() 1257 | }) 1258 | 1259 | tbar.AddAction(gi.ActOpts{Label: "Train", Icon: "run", Tooltip: "Starts the network training, picking up from wherever it may have left off. If not stopped, training will complete the specified number of Runs through the full number of Epochs of training, with testing automatically occuring at the specified interval.", 1260 | UpdateFunc: func(act *gi.Action) { 1261 | act.SetActiveStateUpdt(!ss.IsRunning) 1262 | }}, win.This(), func(recv, send ki.Ki, sig int64, data interface{}) { 1263 | if !ss.IsRunning { 1264 | ss.IsRunning = true 1265 | tbar.UpdateActions() 1266 | // ss.Train() 1267 | go ss.Train() 1268 | } 1269 | }) 1270 | 1271 | tbar.AddAction(gi.ActOpts{Label: "Stop", Icon: "stop", Tooltip: "Interrupts running. Hitting Train again will pick back up where it left off.", UpdateFunc: func(act *gi.Action) { 1272 | act.SetActiveStateUpdt(ss.IsRunning) 1273 | }}, win.This(), func(recv, send ki.Ki, sig int64, data interface{}) { 1274 | ss.Stop() 1275 | }) 1276 | 1277 | tbar.AddAction(gi.ActOpts{Label: "Step Trial", Icon: "step-fwd", Tooltip: "Advances one training trial at a time.", UpdateFunc: func(act *gi.Action) { 1278 | act.SetActiveStateUpdt(!ss.IsRunning) 1279 | }}, win.This(), func(recv, send ki.Ki, sig int64, data interface{}) { 1280 | if !ss.IsRunning { 1281 | ss.IsRunning = true 1282 | ss.TrainTrial() 1283 | ss.IsRunning = false 1284 | vp.SetNeedsFullRender() 1285 | } 1286 | }) 1287 | 1288 | tbar.AddAction(gi.ActOpts{Label: "Step Epoch", Icon: "fast-fwd", Tooltip: "Advances one epoch (complete set of training patterns) at a time.", UpdateFunc: func(act *gi.Action) { 1289 | act.SetActiveStateUpdt(!ss.IsRunning) 1290 | }}, win.This(), func(recv, send ki.Ki, sig int64, data interface{}) { 1291 | if 
!ss.IsRunning { 1292 | ss.IsRunning = true 1293 | tbar.UpdateActions() 1294 | go ss.TrainEpoch() 1295 | } 1296 | }) 1297 | 1298 | tbar.AddAction(gi.ActOpts{Label: "Step Run", Icon: "fast-fwd", Tooltip: "Advances one full training Run at a time.", UpdateFunc: func(act *gi.Action) { 1299 | act.SetActiveStateUpdt(!ss.IsRunning) 1300 | }}, win.This(), func(recv, send ki.Ki, sig int64, data interface{}) { 1301 | if !ss.IsRunning { 1302 | ss.IsRunning = true 1303 | tbar.UpdateActions() 1304 | go ss.TrainRun() 1305 | } 1306 | }) 1307 | 1308 | tbar.AddSeparator("test") 1309 | 1310 | tbar.AddAction(gi.ActOpts{Label: "Test Trial", Icon: "step-fwd", Tooltip: "Runs the next testing trial.", UpdateFunc: func(act *gi.Action) { 1311 | act.SetActiveStateUpdt(!ss.IsRunning) 1312 | }}, win.This(), func(recv, send ki.Ki, sig int64, data interface{}) { 1313 | if !ss.IsRunning { 1314 | ss.IsRunning = true 1315 | ss.TestTrial(false) // don't return on change -- wrap 1316 | ss.IsRunning = false 1317 | vp.SetNeedsFullRender() 1318 | } 1319 | }) 1320 | 1321 | tbar.AddAction(gi.ActOpts{Label: "Test Item", Icon: "step-fwd", Tooltip: "Prompts for a specific input pattern name to run, and runs it in testing mode.", UpdateFunc: func(act *gi.Action) { 1322 | act.SetActiveStateUpdt(!ss.IsRunning) 1323 | }}, win.This(), func(recv, send ki.Ki, sig int64, data interface{}) { 1324 | gi.StringPromptDialog(vp, "", "Test Item", 1325 | gi.DlgOpts{Title: "Test Item", Prompt: "Enter the Name of a given input pattern to test (case insensitive, contains given string."}, 1326 | win.This(), func(recv, send ki.Ki, sig int64, data interface{}) { 1327 | dlg := send.(*gi.Dialog) 1328 | if sig == int64(gi.DialogAccepted) { 1329 | val := gi.StringPromptDialogValue(dlg) 1330 | idxs := ss.TestEnv.Table.RowsByString("Name", val, etable.Contains, etable.IgnoreCase) 1331 | if len(idxs) == 0 { 1332 | gi.PromptDialog(nil, gi.DlgOpts{Title: "Name Not Found", Prompt: "No patterns found containing: " + val}, gi.AddOk, 
gi.NoCancel, nil, nil) 1333 | } else { 1334 | if !ss.IsRunning { 1335 | ss.IsRunning = true 1336 | fmt.Printf("testing index: %d\n", idxs[0]) 1337 | ss.TestItem(idxs[0]) 1338 | ss.IsRunning = false 1339 | vp.SetNeedsFullRender() 1340 | } 1341 | } 1342 | } 1343 | }) 1344 | }) 1345 | 1346 | tbar.AddAction(gi.ActOpts{Label: "Test All", Icon: "fast-fwd", Tooltip: "Tests all of the testing trials.", UpdateFunc: func(act *gi.Action) { 1347 | act.SetActiveStateUpdt(!ss.IsRunning) 1348 | }}, win.This(), func(recv, send ki.Ki, sig int64, data interface{}) { 1349 | if !ss.IsRunning { 1350 | ss.IsRunning = true 1351 | tbar.UpdateActions() 1352 | go ss.RunTestAll() 1353 | } 1354 | }) 1355 | 1356 | tbar.AddSeparator("log") 1357 | 1358 | tbar.AddAction(gi.ActOpts{Label: "Reset RunLog", Icon: "reset", Tooltip: "Reset the accumulated log of all Runs, which are tagged with the ParamSet used"}, win.This(), 1359 | func(recv, send ki.Ki, sig int64, data interface{}) { 1360 | ss.RunLog.SetNumRows(0) 1361 | ss.RunPlot.Update() 1362 | }) 1363 | 1364 | tbar.AddSeparator("misc") 1365 | 1366 | tbar.AddAction(gi.ActOpts{Label: "New Seed", Icon: "new", Tooltip: "Generate a new initial random seed to get different results. 
By default, Init re-establishes the same initial seed every time."}, win.This(), 1367 | func(recv, send ki.Ki, sig int64, data interface{}) { 1368 | ss.NewRndSeed() 1369 | }) 1370 | 1371 | tbar.AddAction(gi.ActOpts{Label: "README", Icon: "file-markdown", Tooltip: "Opens your browser on the README file that contains instructions for how to run this model."}, win.This(), 1372 | func(recv, send ki.Ki, sig int64, data interface{}) { 1373 | gi.OpenURL("https://github.com/emer/leabra/blob/master/examples/ra25/README.md") 1374 | }) 1375 | 1376 | vp.UpdateEndNoSig(updt) 1377 | 1378 | // main menu 1379 | appnm := gi.AppName() 1380 | mmen := win.MainMenu 1381 | mmen.ConfigMenus([]string{appnm, "File", "Edit", "Window"}) 1382 | 1383 | amen := win.MainMenu.ChildByName(appnm, 0).(*gi.Action) 1384 | amen.Menu.AddAppMenu(win) 1385 | 1386 | emen := win.MainMenu.ChildByName("Edit", 1).(*gi.Action) 1387 | emen.Menu.AddCopyCutPaste(win) 1388 | 1389 | // note: Command in shortcuts is automatically translated into Control for 1390 | // Linux, Windows or Meta for MacOS 1391 | // fmen := win.MainMenu.ChildByName("File", 0).(*gi.Action) 1392 | // fmen.Menu.AddAction(gi.ActOpts{Label: "Open", Shortcut: "Command+O"}, 1393 | // win.This(), func(recv, send ki.Ki, sig int64, data interface{}) { 1394 | // FileViewOpenSVG(vp) 1395 | // }) 1396 | // fmen.Menu.AddSeparator("csep") 1397 | // fmen.Menu.AddAction(gi.ActOpts{Label: "Close Window", Shortcut: "Command+W"}, 1398 | // win.This(), func(recv, send ki.Ki, sig int64, data interface{}) { 1399 | // win.Close() 1400 | // }) 1401 | 1402 | inQuitPrompt := false 1403 | gi.SetQuitReqFunc(func() { 1404 | if inQuitPrompt { 1405 | return 1406 | } 1407 | inQuitPrompt = true 1408 | gi.PromptDialog(vp, gi.DlgOpts{Title: "Really Quit?", 1409 | Prompt: "Are you sure you want to quit and lose any unsaved params, weights, logs, etc?"}, gi.AddOk, gi.AddCancel, 1410 | win.This(), func(recv, send ki.Ki, sig int64, data interface{}) { 1411 | if sig == 
int64(gi.DialogAccepted) { 1412 | gi.Quit() 1413 | } else { 1414 | inQuitPrompt = false 1415 | } 1416 | }) 1417 | }) 1418 | 1419 | // gi.SetQuitCleanFunc(func() { 1420 | // fmt.Printf("Doing final Quit cleanup here..\n") 1421 | // }) 1422 | 1423 | inClosePrompt := false 1424 | win.SetCloseReqFunc(func(w *gi.Window) { 1425 | if inClosePrompt { 1426 | return 1427 | } 1428 | inClosePrompt = true 1429 | gi.PromptDialog(vp, gi.DlgOpts{Title: "Really Close Window?", 1430 | Prompt: "Are you sure you want to close the window? This will Quit the App as well, losing all unsaved params, weights, logs, etc"}, gi.AddOk, gi.AddCancel, 1431 | win.This(), func(recv, send ki.Ki, sig int64, data interface{}) { 1432 | if sig == int64(gi.DialogAccepted) { 1433 | gi.Quit() 1434 | } else { 1435 | inClosePrompt = false 1436 | } 1437 | }) 1438 | }) 1439 | 1440 | win.SetCloseCleanFunc(func(w *gi.Window) { 1441 | go gi.Quit() // once main window is closed, quit 1442 | }) 1443 | 1444 | win.MainMenuUpdated() 1445 | return win 1446 | } 1447 | 1448 | // These props register Save methods so they can be used 1449 | var SimProps = ki.Props{ 1450 | "CallMethods": ki.PropSlice{ 1451 | {"SaveWeights", ki.Props{ 1452 | "desc": "save network weights to file", 1453 | "icon": "file-save", 1454 | "Args": ki.PropSlice{ 1455 | {"File Name", ki.Props{ 1456 | "ext": ".wts,.wts.gz", 1457 | }}, 1458 | }, 1459 | }}, 1460 | }, 1461 | } 1462 | 1463 | func (ss *Sim) CmdArgs() { 1464 | ss.NoGui = true 1465 | var nogui bool 1466 | var saveEpcLog bool 1467 | var saveRunLog bool 1468 | var note string 1469 | flag.StringVar(&ss.ParamSet, "params", "", "ParamSet name to use -- must be valid name as listed in compiled-in params or loaded params") 1470 | flag.StringVar(&ss.Tag, "tag", "", "extra tag to add to file names saved from this run") 1471 | flag.StringVar(¬e, "note", "", "user note -- describe the run params etc") 1472 | flag.IntVar(&ss.MaxRuns, "runs", 10, "number of runs to do (note that MaxEpcs is in paramset)") 
1473 | flag.BoolVar(&ss.LogSetParams, "setparams", false, "if true, print a record of each parameter that is set") 1474 | flag.BoolVar(&ss.SaveWts, "wts", false, "if true, save final weights after each run") 1475 | flag.BoolVar(&saveEpcLog, "epclog", true, "if true, save train epoch log to file") 1476 | flag.BoolVar(&saveRunLog, "runlog", true, "if true, save run epoch log to file") 1477 | flag.BoolVar(&nogui, "nogui", true, "if not passing any other args and want to run nogui, use nogui") 1478 | flag.Parse() 1479 | ss.Init() 1480 | 1481 | if note != "" { 1482 | fmt.Printf("note: %s\n", note) 1483 | } 1484 | if ss.ParamSet != "" { 1485 | fmt.Printf("Using ParamSet: %s\n", ss.ParamSet) 1486 | } 1487 | 1488 | if saveEpcLog { 1489 | var err error 1490 | fnm := ss.LogFileName("epc") 1491 | ss.TrnEpcFile, err = os.Create(fnm) 1492 | if err != nil { 1493 | log.Println(err) 1494 | ss.TrnEpcFile = nil 1495 | } else { 1496 | fmt.Printf("Saving epoch log to: %s\n", fnm) 1497 | defer ss.TrnEpcFile.Close() 1498 | } 1499 | } 1500 | if saveRunLog { 1501 | var err error 1502 | fnm := ss.LogFileName("run") 1503 | ss.RunFile, err = os.Create(fnm) 1504 | if err != nil { 1505 | log.Println(err) 1506 | ss.RunFile = nil 1507 | } else { 1508 | fmt.Printf("Saving run log to: %s\n", fnm) 1509 | defer ss.RunFile.Close() 1510 | } 1511 | } 1512 | if ss.SaveWts { 1513 | fmt.Printf("Saving final weights per run\n") 1514 | } 1515 | fmt.Printf("Running %d Runs\n", ss.MaxRuns) 1516 | ss.Train() 1517 | } 1518 | 1519 | -------------------------------------------------------------------------------- /testdata/ra25.py: -------------------------------------------------------------------------------- 1 | #!/usr/local/bin/pyleabra 2 | 3 | # Copyright (c) 2019, The Emergent Authors. All rights reserved. 4 | # Use of this source code is governed by a BSD-style 5 | # license that can be found in the LICENSE file. 
6 | 7 | # use: 8 | # pyleabra -i ra25.py 9 | # to run in gui interactive mode from the command line (or pyleabra, import ra25) 10 | # see main function at the end for startup args 11 | 12 | # to run this python version of the demo: 13 | # * install gopy, currently in fork at https://github.com/goki/gopy 14 | # e.g., 'go get github.com/goki/gopy -u ./...' and then cd to that package 15 | # and do 'go install' 16 | # * go to the python directory in this emergent repository, read README.md there, and 17 | # type 'make' -- if that works, then type make install (may need sudo) 18 | # * cd back here, and run 'pyemergent' which was installed into /usr/local/bin 19 | # * then type 'import ra25' and this should run 20 | # * you'll need various standard packages such as pandas, numpy, matplotlib, etc 21 | 22 | # labra25ra runs a simple random-associator 5x5 = 25 four-layer leabra network 23 | 24 | from leabra import go, leabra, emer, relpos, eplot, env, agg, patgen, prjn, etable, efile, split, etensor, params, netview, rand, erand, gi, giv, epygiv 25 | 26 | # il.reload(ra25) -- doesn't seem to work for reasons unknown 27 | import importlib as il 28 | import io 29 | import sys 30 | import getopt 31 | # import numpy as np 32 | # import matplotlib 33 | # matplotlib.use('SVG') 34 | # import matplotlib.pyplot as plt 35 | # plt.rcParams['svg.fonttype'] = 'none' # essential for not rendering fonts as paths 36 | 37 | # note: pandas, xarray or pytorch TensorDataSet can be used for input / output 38 | # patterns and recording of "log" data for plotting. However, the etable.Table 39 | # has better GUI and API support, and handles tensor columns directly unlike 40 | # pandas. Support for easy migration between these is forthcoming. 41 | # import pandas as pd 42 | 43 | # this will become Sim later.. 
44 | TheSim = 1 45 | 46 | # use this for e.g., etable.Column construction args where nil would be passed 47 | nilInts = go.Slice_int() 48 | 49 | # use this for e.g., etable.Column construction args where nil would be passed 50 | nilStrs = go.Slice_string() 51 | 52 | # LogPrec is precision for saving float values in logs 53 | LogPrec = 4 54 | 55 | # note: we cannot use methods for callbacks from Go -- must be separate functions 56 | # so below are all the callbacks from the GUI toolbar actions 57 | 58 | 59 | def InitCB(recv, send, sig, data): 60 | TheSim.Init() 61 | TheSim.ClassView.Update() 62 | TheSim.vp.SetNeedsFullRender() 63 | 64 | 65 | def TrainCB(recv, send, sig, data): 66 | if not TheSim.IsRunning: 67 | TheSim.IsRunning = True 68 | TheSim.ToolBar.UpdateActions() 69 | TheSim.Train() 70 | 71 | 72 | def StopCB(recv, send, sig, data): 73 | TheSim.Stop() 74 | 75 | 76 | def StepTrialCB(recv, send, sig, data): 77 | if not TheSim.IsRunning: 78 | TheSim.IsRunning = True 79 | TheSim.TrainTrial() 80 | TheSim.IsRunning = False 81 | TheSim.ClassView.Update() 82 | TheSim.vp.SetNeedsFullRender() 83 | 84 | 85 | def StepEpochCB(recv, send, sig, data): 86 | if not TheSim.IsRunning: 87 | TheSim.IsRunning = True 88 | TheSim.ToolBar.UpdateActions() 89 | TheSim.TrainEpoch() 90 | 91 | 92 | def StepRunCB(recv, send, sig, data): 93 | if not TheSim.IsRunning: 94 | TheSim.IsRunning = True 95 | TheSim.ToolBar.UpdateActions() 96 | TheSim.TrainRun() 97 | 98 | 99 | def TestTrialCB(recv, send, sig, data): 100 | if not TheSim.IsRunning: 101 | TheSim.IsRunning = True 102 | TheSim.TestTrial() 103 | TheSim.IsRunning = False 104 | TheSim.ClassView.Update() 105 | TheSim.vp.SetNeedsFullRender() 106 | 107 | 108 | def TestItemCB2(recv, send, sig, data): 109 | win = gi.Window(handle=recv) 110 | vp = win.WinViewport2D() 111 | dlg = gi.Dialog(handle=send) 112 | if sig != gi.DialogAccepted: 113 | return 114 | val = gi.StringPromptDialogValue(dlg) 115 | idxs = TheSim.TestEnv.Table.RowsByString( 116 | 
"Name", val, True, True) # contains, ignoreCase 117 | if len(idxs) == 0: 118 | gi.PromptDialog(vp, gi.DlgOpts(Title="Name Not Found", 119 | Prompt="No patterns found containing: " + val), True, False, go.nil, go.nil) 120 | else: 121 | if not TheSim.IsRunning: 122 | TheSim.IsRunning = True 123 | print("testing index: %s" % idxs[0]) 124 | TheSim.TestItem(idxs[0]) 125 | TheSim.IsRunning = False 126 | vp.SetNeedsFullRender() 127 | 128 | 129 | def TestItemCB(recv, send, sig, data): 130 | win = gi.Window(handle=recv) 131 | gi.StringPromptDialog(win.WinViewport2D(), "", "Test Item", 132 | gi.DlgOpts(Title="Test Item", Prompt="Enter the Name of a given input pattern to test (case insensitive, contains given string."), win, TestItemCB2) 133 | 134 | 135 | def TestAllCB(recv, send, sig, data): 136 | if not TheSim.IsRunning: 137 | TheSim.IsRunning = True 138 | TheSim.ToolBar.UpdateActions() 139 | TheSim.RunTestAll() 140 | 141 | 142 | def ResetRunLogCB(recv, send, sig, data): 143 | TheSim.RunLog.SetNumRows(0) 144 | TheSim.RunPlot.Update() 145 | 146 | 147 | def NewRndSeedCB(recv, send, sig, data): 148 | TheSim.NewRndSeed() 149 | 150 | 151 | def ReadmeCB(recv, send, sig, data): 152 | gi.OpenURL( 153 | "https://github.com/emer/leabra/blob/master/examples/ra25/README.md") 154 | 155 | 156 | def FilterSSE(et, row): 157 | # include error trials 158 | return etable.Table(handle=et).CellFloat("SSE", row) > 0 159 | 160 | 161 | def UpdtFuncNotRunning(act): 162 | act.SetActiveStateUpdt(not TheSim.IsRunning) 163 | 164 | 165 | def UpdtFuncRunning(act): 166 | act.SetActiveStateUpdt(TheSim.IsRunning) 167 | 168 | 169 | ##################################################### 170 | # Sim 171 | 172 | class Sim(object): 173 | """ 174 | Sim encapsulates the entire simulation model, and we define all the 175 | functionality as methods on this struct. 
This structure keeps all relevant 176 | state information organized and available without having to pass everything around 177 | as arguments to methods, and provides the core GUI interface (note the view tags 178 | for the fields which provide hints to how things should be displayed). 179 | """ 180 | 181 | def __init__(self): 182 | self.Net = leabra.Network() 183 | self.Pats = etable.Table() 184 | self.TrnEpcLog = etable.Table() 185 | self.TstEpcLog = etable.Table() 186 | self.TstTrlLog = etable.Table() 187 | self.TstErrLog = etable.Table() 188 | self.TstErrStats = etable.Table() 189 | self.TstCycLog = etable.Table() 190 | self.RunLog = etable.Table() 191 | self.RunStats = etable.Table() 192 | self.Params = params.Sets() 193 | self.ParamSet = "" 194 | self.Tag = "" 195 | self.MaxRuns = 10 196 | self.MaxEpcs = 50 197 | self.TrainEnv = env.FixedTable() 198 | self.TestEnv = env.FixedTable() 199 | self.Time = leabra.Time() 200 | self.ViewOn = True 201 | self.TrainUpdt = leabra.AlphaCycle 202 | self.TestUpdt = leabra.Cycle 203 | self.TestInterval = 5 204 | 205 | # statistics 206 | self.TrlSSE = 0.0 207 | self.TrlAvgSSE = 0.0 208 | self.TrlCosDiff = 0.0 209 | self.EpcSSE = 0.0 210 | self.EpcAvgSSE = 0.0 211 | self.EpcPctErr = 0.0 212 | self.EpcPctCor = 0.0 213 | self.EpcCosDiff = 0.0 214 | self.FirstZero = -1 215 | 216 | # internal state - view:"-" 217 | self.SumSSE = 0.0 218 | self.SumAvgSSE = 0.0 219 | self.SumCosDiff = 0.0 220 | self.CntErr = 0.0 221 | self.Win = 0 222 | self.vp = 0 223 | self.ToolBar = 0 224 | self.NetView = 0 225 | self.TrnEpcPlot = 0 226 | self.TstEpcPlot = 0 227 | self.TstTrlPlot = 0 228 | self.TstCycPlot = 0 229 | self.RunPlot = 0 230 | self.TrnEpcFile = 0 231 | self.RunFile = 0 232 | self.InputValsTsr = 0 233 | self.OutputValsTsr = 0 234 | self.SaveWts = False 235 | self.NoGui = False 236 | self.LogSetParams = False # True 237 | self.IsRunning = False 238 | self.StopNow = False 239 | self.RndSeed = 0 240 | 241 | # ClassView tags for controlling 
display of fields 242 | self.Tags = { 243 | 'TrlSSE': 'inactive:"+"', 244 | 'TrlAvgSSE': 'inactive:"+"', 245 | 'TrlCosDiff': 'inactive:"+"', 246 | 'EpcSSE': 'inactive:"+"', 247 | 'EpcAvgSSE': 'inactive:"+"', 248 | 'EpcPctErr': 'inactive:"+"', 249 | 'EpcPctCor': 'inactive:"+"', 250 | 'EpcCosDiff': 'inactive:"+"', 251 | 'FirstZero': 'inactive:"+"', 252 | 'SumSSE': 'view:"-"', 253 | 'SumAvgSSE': 'view:"-"', 254 | 'SumCosDiff': 'view:"-"', 255 | 'CntErr': 'view:"-"', 256 | 'Win': 'view:"-"', 257 | 'vp': 'view:"-"', 258 | 'ToolBar': 'view:"-"', 259 | 'NetView': 'view:"-"', 260 | 'TrnEpcPlot': 'view:"-"', 261 | 'TstEpcPlot': 'view:"-"', 262 | 'TstTrlPlot': 'view:"-"', 263 | 'TstCycPlot': 'view:"-"', 264 | 'RunPlot': 'view:"-"', 265 | 'TrnEpcFile': 'view:"-"', 266 | 'RunFile': 'view:"-"', 267 | 'InputValsTsr': 'view:"-"', 268 | 'OutputValsTsr': 'view:"-"', 269 | 'SaveWts': 'view:"-"', 270 | 'NoGui': 'view:"-"', 271 | 'LogSetParams': 'view:"-"', 272 | 'IsRunning': 'view:"-"', 273 | 'StopNow': 'view:"-"', 274 | 'RndSeed': 'view:"-"', 275 | 'ClassView': 'view:"-"', 276 | 'Tags': 'view:"-"', 277 | } 278 | 279 | def InitParams(self): 280 | """ 281 | Sets the default set of parameters -- Base is always applied, and others can be optionally 282 | selected to apply on top of that 283 | """ 284 | self.Params.OpenJSON("ra25_std.params") 285 | 286 | def Config(ss): 287 | """ 288 | Config configures all the elements using the standard functions 289 | """ 290 | 291 | ss.InitParams() 292 | # self.OpenPats() 293 | ss.ConfigPats() 294 | ss.ConfigEnv() 295 | ss.ConfigNet(ss.Net) 296 | ss.ConfigTrnEpcLog(ss.TrnEpcLog) 297 | ss.ConfigTstEpcLog(ss.TstEpcLog) 298 | ss.ConfigTstTrlLog(ss.TstTrlLog) 299 | ss.ConfigTstCycLog(ss.TstCycLog) 300 | ss.ConfigRunLog(ss.RunLog) 301 | 302 | def ConfigEnv(ss): 303 | if ss.MaxRuns == 0: # allow user override 304 | ss.MaxRuns = 10 305 | if ss.MaxEpcs == 0: # allow user override 306 | ss.MaxEpcs = 50 307 | ss.NZeroStop = 5 308 | 309 | ss.TrainEnv.Nm = 
"TrainEnv" 310 | ss.TrainEnv.Dsc = "training params and state" 311 | ss.TrainEnv.Table = etable.NewIdxView(ss.Pats) 312 | ss.TrainEnv.Validate() 313 | # note: we are not setting epoch max -- do that manually 314 | ss.TrainEnv.Run.Max = ss.MaxRuns 315 | 316 | ss.TestEnv.Nm = "TestEnv" 317 | ss.TestEnv.Dsc = "testing params and state" 318 | ss.TestEnv.Table = etable.NewIdxView(ss.Pats) 319 | ss.TestEnv.Sequential = True 320 | ss.TestEnv.Validate() 321 | 322 | # note: to create a train / test split of pats, do this: 323 | # all = etable.NewIdxView(self.Pats) 324 | # splits = split.Permuted(all, []float64{.8, .2}, []string{"Train", "Test"}) 325 | # self.TrainEnv.Table = splits.Splits[0] 326 | # self.TestEnv.Table = splits.Splits[1] 327 | 328 | ss.TrainEnv.Init(0) 329 | ss.TestEnv.Init(0) 330 | 331 | def ConfigNet(ss, net): 332 | net.InitName(net, "RA25") 333 | inLay = net.AddLayer2D("Input", 5, 5, emer.Input) 334 | hid1Lay = net.AddLayer2D("Hidden1", 7, 7, emer.Hidden) 335 | hid2Lay = net.AddLayer2D("Hidden2", 7, 7, emer.Hidden) 336 | outLay = net.AddLayer2D("Output", 5, 5, emer.Target) 337 | 338 | # use this to position layers relative to each other 339 | # default is Above, YAlign = Front, XAlign = Center 340 | hid2Lay.SetRelPos(relpos.Rel(Rel=relpos.RightOf, 341 | Other="Hidden1", YAlign=relpos.Front, Space=2)) 342 | 343 | # note: see emergent/prjn module for all the options on how to connect 344 | # NewFull returns a new prjn.Full connectivity pattern 345 | net.ConnectLayers(inLay, hid1Lay, prjn.NewFull(), emer.Forward) 346 | net.ConnectLayers(hid1Lay, hid2Lay, prjn.NewFull(), emer.Forward) 347 | net.ConnectLayers(hid2Lay, outLay, prjn.NewFull(), emer.Forward) 348 | 349 | net.ConnectLayers(outLay, hid2Lay, prjn.NewFull(), emer.Back) 350 | net.ConnectLayers(hid2Lay, hid1Lay, prjn.NewFull(), emer.Back) 351 | 352 | # note: can set these to do parallel threaded computation across multiple cpus 353 | # not worth it for this small of a model, but definitely helps for 
larger ones 354 | # if Thread { 355 | # hid2Lay.SetThread(1) 356 | # outLay.SetThread(1) 357 | # } 358 | 359 | # note: if you wanted to change a layer type from e.g., Target to Compare, do this: 360 | # outLay.SetType(emer.Compare) 361 | # that would mean that the output layer doesn't reflect target values in plus phase 362 | # and thus removes error-driven learning -- but stats are still computed. 363 | 364 | net.Defaults() 365 | ss.SetParams("Network", ss.LogSetParams) # only set Network params 366 | net.Build() 367 | net.InitWts() 368 | 369 | ###################################### 370 | # Init, utils 371 | 372 | def Init(ss): 373 | """Init restarts the run, and initializes everything, including network weights and resets the epoch log table""" 374 | rand.Seed(ss.RndSeed) 375 | ss.ConfigEnv() # just in case another set of pats was selected.. 376 | ss.StopNow = False 377 | ss.SetParams("", ss.LogSetParams) # all sheets 378 | ss.NewRun() 379 | ss.UpdateView(True) 380 | 381 | def NewRndSeed(ss): 382 | """NewRndSeed gets a new random seed based on current time -- otherwise uses the same random seed for every run""" 383 | # self.RndSeed = time.Now().UnixNano() 384 | 385 | def Counters(ss, train): 386 | """ 387 | Counters returns a string of the current counter state 388 | use tabs to achieve a reasonable formatting overall 389 | and add a few tabs at the end to allow for expansion.. 
390 | """ 391 | if train: 392 | return "Run:\t%d\tEpoch:\t%d\tTrial:\t%d\tCycle:\t%d\tName:\t%s\t\t\t" % (ss.TrainEnv.Run.Cur, ss.TrainEnv.Epoch.Cur, ss.TrainEnv.Trial.Cur, ss.Time.Cycle, ss.TrainEnv.TrialName.Cur) 393 | else: 394 | return "Run:\t%d\tEpoch:\t%d\tTrial:\t%d\t\tCycle:\t%dName:\t%s\t\t\t" % (ss.TrainEnv.Run.Cur, ss.TrainEnv.Epoch.Cur, ss.TestEnv.Trial.Cur, ss.Time.Cycle, ss.TestEnv.TrialName.Cur) 395 | 396 | def UpdateView(ss, train): 397 | if ss.NetView != 0 and ss.NetView.IsVisible(): 398 | ss.NetView.Record(ss.Counters(train)) 399 | # note: essential to use Go version of update when called from another goroutine 400 | # note: using counters is significantly slower.. 401 | ss.NetView.GoUpdate() 402 | 403 | ###################################### 404 | # Running the network 405 | 406 | def AlphaCyc(ss, train): 407 | """ 408 | AlphaCyc runs one alpha-cycle (100 msec, 4 quarters) of processing. 409 | External inputs must have already been applied prior to calling, 410 | using ApplyExt method on relevant layers (see TrainTrial, TestTrial). 411 | If train is true, then learning DWt or WtFmDWt calls are made. 412 | Handles netview updating within scope of AlphaCycle 413 | """ 414 | if ss.Win != 0: 415 | ss.Win.PollEvents() # this is essential for GUI responsiveness while running 416 | viewUpdt = ss.TrainUpdt 417 | if not train: 418 | viewUpdt = ss.TestUpdt 419 | 420 | # update prior weight changes at start, so any DWt values remain visible at end 421 | # you might want to do this less frequently to achieve a mini-batch update 422 | # in which case, move it out to the TrainTrial method where the relevant 423 | # counters are being dealt with. 
424 | if train: 425 | ss.Net.WtFmDWt() 426 | 427 | ss.Net.AlphaCycInit() 428 | ss.Time.AlphaCycStart() 429 | for qtr in range(4): 430 | for cyc in range(ss.Time.CycPerQtr): 431 | ss.Net.Cycle(ss.Time) 432 | if not train: 433 | ss.LogTstCyc(ss.TstCycLog, ss.Time.Cycle) 434 | ss.Time.CycleInc() 435 | if ss.ViewOn: 436 | if viewUpdt == leabra.Cycle: 437 | ss.UpdateView(train) 438 | if viewUpdt == leabra.FastSpike: 439 | if (cyc+1) % 10 == 0: 440 | ss.UpdateView(train) 441 | ss.Net.QuarterFinal(ss.Time) 442 | ss.Time.QuarterInc() 443 | if ss.ViewOn: 444 | if viewUpdt == leabra.Quarter: 445 | ss.UpdateView(train) 446 | if viewUpdt == leabra.Phase: 447 | if qtr >= 2: 448 | ss.UpdateView(train) 449 | if train: 450 | ss.Net.DWt() 451 | if ss.ViewOn and viewUpdt == leabra.AlphaCycle: 452 | ss.UpdateView(train) 453 | if ss.TstCycPlot != 0 and not train: 454 | ss.TstCycPlot.GoUpdate() 455 | 456 | def ApplyInputs(ss, en): 457 | """ 458 | ApplyInputs applies input patterns from given environment. 459 | It is good practice to have this be a separate method with appropriate 460 | args so that it can be used for various different contexts 461 | (training, testing, etc). 462 | """ 463 | ss.Net.InitExt() # clear any existing inputs -- not strictly necessary if always 464 | # going to the same layers, but good practice and cheap anyway 465 | inLay = leabra.Layer(ss.Net.LayerByName("Input")) 466 | outLay = leabra.Layer(ss.Net.LayerByName("Output")) 467 | 468 | inPats = en.State(inLay.Nm) 469 | if inPats != go.nil: 470 | inLay.ApplyExt(inPats) 471 | 472 | outPats = en.State(outLay.Nm) 473 | if inPats != go.nil: 474 | outLay.ApplyExt(outPats) 475 | 476 | # NOTE: this is how you can use a pandas.DataFrame() to apply inputs 477 | # we are using etable.Table instead because it provides a full GUI 478 | # for viewing your patterns, and has a more convenient API, that integrates 479 | # with the env environment interface. 
480 | # 481 | # inLay = leabra.Layer(self.Net.LayerByName("Input")) 482 | # outLay = leabra.Layer(self.Net.LayerByName("Output")) 483 | # pidx = self.Trial 484 | # if not self.Sequential: 485 | # pidx = self.Porder[self.Trial] 486 | # # note: these indexes must be updated based on columns in patterns.. 487 | # inp = self.Pats.iloc[pidx,1:26].values 488 | # outp = self.Pats.iloc[pidx,26:26+25].values 489 | # self.ApplyExt(inLay, inp) 490 | # self.ApplyExt(outLay, outp) 491 | # 492 | # def ApplyExt(self, lay, nparray): 493 | # flt = np.ndarray.flatten(nparray, 'C') 494 | # slc = go.Slice_float32(flt) 495 | # lay.ApplyExt1D(slc) 496 | 497 | def TrainTrial(ss): 498 | """ TrainTrial runs one trial of training using TrainEnv""" 499 | ss.TrainEnv.Step() # the Env encapsulates and manages all counter state 500 | 501 | # Key to query counters FIRST because current state is in NEXT epoch 502 | # if epoch counter has changed 503 | epc = env.CounterCur(ss.TrainEnv, env.Epoch) 504 | chg = env.CounterChg(ss.TrainEnv, env.Epoch) 505 | if chg: 506 | ss.LogTrnEpc(ss.TrnEpcLog) 507 | if ss.ViewOn and ss.TrainUpdt > leabra.AlphaCycle: 508 | ss.UpdateView(True) 509 | if epc % ss.TestInterval == 0: # note: epc is *next* so won't trigger first time 510 | ss.TestAll() 511 | if epc >= ss.MaxEpcs: # done with training.. 512 | ss.RunEnd() 513 | if ss.TrainEnv.Run.Incr(): # we are done! 
514 | ss.StopNow = True 515 | return 516 | else: 517 | ss.NewRun() 518 | return 519 | 520 | ss.ApplyInputs(ss.TrainEnv) 521 | ss.AlphaCyc(True) # train 522 | ss.TrialStats(True) # accumulate 523 | 524 | def RunEnd(ss): 525 | """ RunEnd is called at the end of a run -- save weights, record final log, etc here """ 526 | ss.LogRun(ss.RunLog) 527 | if ss.SaveWts: 528 | fnm = ss.WeightsFileName() 529 | fmt.Printf("Saving Weights to: %v", fnm) 530 | ss.Net.SaveWtsJSON(gi.FileName(fnm)) 531 | 532 | def NewRun(ss): 533 | """ NewRun intializes a new run of the model, using the TrainEnv.Run counter for the new run value """ 534 | run = ss.TrainEnv.Run.Cur 535 | ss.TrainEnv.Init(run) 536 | ss.TestEnv.Init(run) 537 | ss.Time.Reset() 538 | ss.Net.InitWts() 539 | ss.InitStats() 540 | ss.TrnEpcLog.SetNumRows(0) 541 | ss.TstEpcLog.SetNumRows(0) 542 | 543 | def InitStats(ss): 544 | """ InitStats initializes all the statistics, especially important for the 545 | cumulative epoch stats -- called at start of new run """ 546 | # accumulators 547 | ss.SumSSE = 0.0 548 | ss.SumAvgSSE = 0.0 549 | ss.SumCosDiff = 0.0 550 | ss.CntErr = 0.0 551 | ss.FirstZero = -1 552 | # clear rest just to make Sim look initialized 553 | ss.TrlSSE = 0.0 554 | ss.TrlAvgSSE = 0.0 555 | ss.EpcSSE = 0.0 556 | ss.EpcAvgSSE = 0.0 557 | ss.EpcPctErr = 0.0 558 | ss.EpcCosDiff = 0.0 559 | 560 | def TrialStats(ss, accum): 561 | """ 562 | TrialStats computes the trial-level statistics and adds them to the epoch accumulators if 563 | accum is true. Note that we're accumulating stats here on the Sim side so the 564 | core algorithm side remains as simple as possible, and doesn't need to worry about 565 | different time-scales over which stats could be accumulated etc. 
566 | You can also aggregate directly from log data, as is done for testing stats 567 | """ 568 | outLay = leabra.Layer(ss.Net.LayerByName("Output")) 569 | ss.TrlCosDiff = outLay.CosDiff.Cos 570 | # 0.5 = per-unit tolerance -- right side of .5 571 | ss.TrlSSE = outLay.SSE(0.5) 572 | ss.TrlAvgSSE = ss.TrlSSE / len(outLay.Neurons) 573 | if accum: 574 | ss.SumSSE += ss.TrlSSE 575 | ss.SumAvgSSE += ss.TrlAvgSSE 576 | ss.SumCosDiff += ss.TrlCosDiff 577 | if ss.TrlSSE != 0: 578 | ss.CntErr += 1.0 579 | 580 | def TrainEpoch(ss): 581 | """ TrainEpoch runs training trials for remainder of this epoch """ 582 | ss.StopNow = False 583 | curEpc = ss.TrainEnv.Epoch.Cur 584 | while True: 585 | ss.TrainTrial() 586 | if ss.StopNow or ss.TrainEnv.Epoch.Cur != curEpc: 587 | break 588 | ss.Stopped() 589 | 590 | def TrainRun(ss): 591 | """ TrainRun runs training trials for remainder of run """ 592 | ss.StopNow = False 593 | curRun = ss.TrainEnv.Run.Cur 594 | while True: 595 | ss.TrainTrial() 596 | if ss.StopNow or ss.TrainEnv.Run.Cur != curRun: 597 | break 598 | ss.Stopped() 599 | 600 | def Train(ss): 601 | """ Train runs the full training from this point onward """ 602 | ss.StopNow = False 603 | while True: 604 | ss.TrainTrial() 605 | if ss.StopNow: 606 | break 607 | ss.Stopped() 608 | 609 | def Stop(ss): 610 | """ Stop tells the sim to stop running """ 611 | ss.StopNow = True 612 | 613 | def Stopped(ss): 614 | """ Stopped is called when a run method stops running -- updates the IsRunning flag and toolbar """ 615 | ss.IsRunning = False 616 | if ss.Win != 0: 617 | ss.vp.BlockUpdates() 618 | if ss.ToolBar != go.nil: 619 | ss.ToolBar.UpdateActions() 620 | ss.vp.UnblockUpdates() 621 | ss.ClassView.Update() 622 | ss.vp.SetNeedsFullRender() 623 | 624 | ###################################### 625 | # Testing 626 | 627 | def TestTrial(ss): 628 | """ TestTrial runs one trial of testing -- always sequentially presented inputs """ 629 | ss.TestEnv.Step() 630 | 631 | # Query counters FIRST 632 | 
chg = env.CounterChg(ss.TestEnv, env.Epoch) 633 | if chg: 634 | if ss.ViewOn and ss.TestUpdt > leabra.AlphaCycle: 635 | ss.UpdateView(False) 636 | ss.LogTstEpc(ss.TstEpcLog) 637 | return 638 | 639 | ss.ApplyInputs(ss.TestEnv) 640 | ss.AlphaCyc(False) # !train 641 | ss.TrialStats(False) # !accumulate 642 | ss.LogTstTrl(ss.TstTrlLog) 643 | 644 | def TestItem(ss, idx): 645 | """ TestItem tests given item which is at given index in test item list """ 646 | cur = ss.TestEnv.Trial.Cur 647 | ss.TestEnv.Trial.Cur = idx 648 | ss.TestEnv.SetTrialName() 649 | ss.ApplyInputs(ss.TestEnv) 650 | ss.AlphaCyc(False) # !train 651 | ss.TrialStats(False) # !accumulate 652 | ss.TestEnv.Trial.Cur = cur 653 | 654 | def TestAll(ss): 655 | """ TestAll runs through the full set of testing items """ 656 | ss.TestEnv.Init(ss.TrainEnv.Run.Cur) 657 | while True: 658 | ss.TestTrial() 659 | chg = env.CounterChg(ss.TestEnv, env.Epoch) 660 | if chg or ss.StopNow: 661 | break 662 | 663 | def RunTestAll(ss): 664 | """ RunTestAll runs through the full set of testing items, has stop running = false at end -- for gui """ 665 | ss.StopNow = False 666 | ss.TestAll() 667 | ss.Stopped() 668 | 669 | ########################################## 670 | # Params methods 671 | 672 | def ParamsName(ss): 673 | """ ParamsName returns name of current set of parameters """ 674 | if ss.ParamSet == "": 675 | return "Base" 676 | return ss.ParamSet 677 | 678 | def SetParams(ss, sheet, setMsg): 679 | """ 680 | SetParams sets the params for "Base" and then current ParamSet. 681 | If sheet is empty, then it applies all avail sheets (e.g., Network, Sim) 682 | otherwise just the named sheet 683 | if setMsg = true then we output a message for each param that was set. 
684 | """ 685 | 686 | if sheet == "": 687 | # this is important for catching typos and ensuring that all sheets can be used 688 | ss.Params.ValidateSheets(go.Slice_string(["Network", "Sim"])) 689 | ss.SetParamsSet("Base", sheet, setMsg) 690 | if ss.ParamSet != "" and ss.ParamSet != "Base": 691 | ss.SetParamsSet(ss.ParamSet, sheet, setMsg) 692 | 693 | def SetParamsSet(ss, setNm, sheet, setMsg): 694 | """ 695 | SetParamsSet sets the params for given params.Set name. 696 | If sheet is empty, then it applies all avail sheets (e.g., Network, Sim) 697 | otherwise just the named sheet 698 | if setMsg = true then we output a message for each param that was set. 699 | """ 700 | pset = ss.Params.SetByNameTry(setNm) 701 | if pset == go.nil: 702 | return 703 | if sheet == "" or sheet == "Network": 704 | if "Network" in pset.Sheets: 705 | netp = pset.SheetByNameTry("Network") 706 | ss.Net.ApplyParams(netp, setMsg) 707 | if sheet == "" or sheet == "Sim": 708 | if "Sim" in pset.Sheets: 709 | simp = pset.SheetByNameTry("Sim") 710 | epygiv.ApplyParams(ss, simp, setMsg) 711 | # note: if you have more complex environments with parameters, definitely add 712 | # sheets for them, e.g., "TrainEnv", "TestEnv" etc 713 | 714 | def ConfigPats(ss): 715 | # note: this is all go-based for using etable.Table instead of pandas 716 | dt = ss.Pats 717 | sc = etable.Schema() 718 | sc.append(etable.Column("Name", etensor.STRING, nilInts, nilStrs)) 719 | sc.append(etable.Column("Input", etensor.FLOAT32, 720 | go.Slice_int([5, 5]), go.Slice_string(["Y", "X"]))) 721 | sc.append(etable.Column("Output", etensor.FLOAT32, 722 | go.Slice_int([5, 5]), go.Slice_string(["Y", "X"]))) 723 | dt.SetFromSchema(sc, 25) 724 | 725 | patgen.PermutedBinaryRows(dt.Cols[1], 6, 1, 0) 726 | patgen.PermutedBinaryRows(dt.Cols[2], 6, 1, 0) 727 | dt.SaveCSV("random_5x5_25_gen.dat", etable.Tab, True) 728 | 729 | def OpenPats(ss): 730 | dt = ss.Pats 731 | ss.Pats = dt 732 | dt.SetMetaData("name", "TrainPats") 733 | 
dt.SetMetaData("desc", "Training patterns") 734 | dt.OpenCSV("random_5x5_25.dat", etable.Tab) 735 | # Note: here's how to read into a pandas DataFrame 736 | # dt = pd.read_csv("random_5x5_25.dat", sep='\t') 737 | # dt = dt.drop(columns="_H:") 738 | 739 | ########################################## 740 | # Logging 741 | 742 | def RunName(ss): 743 | """ 744 | RunName returns a name for this run that combines Tag and Params -- add this to 745 | any file names that are saved. 746 | """ 747 | if ss.Tag != "": 748 | return ss.Tag + "_" + ss.ParamsName() 749 | else: 750 | return ss.ParamsName() 751 | 752 | def RunEpochName(ss, run, epc): 753 | """ 754 | RunEpochName returns a string with the run and epoch numbers with leading zeros, suitable 755 | for using in weights file names. Uses 3, 5 digits for each. 756 | """ 757 | return "%03d_%05d" % run, epc 758 | 759 | def WeightsFileName(ss): 760 | """ WeightsFileName returns default current weights file name """ 761 | return ss.Net.Nm + "_" + ss.RunName() + "_" + ss.RunEpochName(ss.TrainEnv.Run.Cur, ss.TrainEnv.Epoch.Cur) + ".wts" 762 | 763 | def LogFileName(ss, lognm): 764 | """ LogFileName returns default log file name """ 765 | return ss.Net.Nm + "_" + ss.RunName() + "_" + lognm + ".csv" 766 | 767 | ############################# 768 | # TrnEpcLog 769 | 770 | def LogTrnEpc(ss, dt): 771 | """ 772 | LogTrnEpc adds data from current epoch to a TrnEpcLog table 773 | computes epoch averages prior to logging. 
774 | """ 775 | row = dt.Rows 776 | ss.TrnEpcLog.SetNumRows(row + 1) 777 | 778 | hid1Lay = leabra.Layer(ss.Net.LayerByName("Hidden1")) 779 | hid2Lay = leabra.Layer(ss.Net.LayerByName("Hidden2")) 780 | outLay = leabra.Layer(ss.Net.LayerByName("Output")) 781 | 782 | # this is triggered by increment so use previous value 783 | epc = ss.TrainEnv.Epoch.Prv 784 | nt = ss.TrainEnv.Table.Len() # number of trials in view 785 | 786 | ss.EpcSSE = ss.SumSSE / nt 787 | ss.SumSSE = 0.0 788 | ss.EpcAvgSSE = ss.SumAvgSSE / nt 789 | ss.SumAvgSSE = 0.0 790 | ss.EpcPctErr = ss.CntErr / nt 791 | ss.CntErr = 0.0 792 | ss.EpcPctCor = 1.0 - ss.EpcPctErr 793 | ss.EpcCosDiff = ss.SumCosDiff / nt 794 | ss.SumCosDiff = 0.0 795 | if ss.FirstZero < 0 and ss.EpcPctErr == 0: 796 | ss.FirstZero = epc 797 | 798 | dt.SetCellFloat("Run", row, ss.TrainEnv.Run.Cur) 799 | dt.SetCellFloat("Epoch", row, epc) 800 | dt.SetCellFloat("SSE", row, ss.EpcSSE) 801 | dt.SetCellFloat("AvgSSE", row, ss.EpcAvgSSE) 802 | dt.SetCellFloat("PctErr", row, ss.EpcPctErr) 803 | dt.SetCellFloat("PctCor", row, ss.EpcPctCor) 804 | dt.SetCellFloat("CosDiff", row, ss.EpcCosDiff) 805 | dt.SetCellFloat("Hid1 ActAvg", row, hid1Lay.Pool(0).ActAvg.ActPAvgEff) 806 | dt.SetCellFloat("Hid2 ActAvg", row, hid2Lay.Pool(0).ActAvg.ActPAvgEff) 807 | dt.SetCellFloat("Out ActAvg", row, outLay.Pool(0).ActAvg.ActPAvgEff) 808 | 809 | # note: essential to use Go version of update when called from another goroutine 810 | if ss.TrnEpcPlot != 0: 811 | ss.TrnEpcPlot.GoUpdate() 812 | 813 | if ss.TrnEpcFile != 0: 814 | if ss.TrainEnv.Run.Cur == 0 and epc == 0: 815 | dt.WriteCSVHeaders(ss.TrnEpcFile, etable.Tab) 816 | dt.WriteCSVRow(ss.TrnEpcFile, row, etable.Tab) 817 | 818 | # note: this is how you log to a pandas.DataFrame 819 | # nwdat = [epc, self.EpcSSE, self.EpcAvgSSE, self.EpcPctErr, self.EpcPctCor, self.EpcCosDiff, 0, 0, 0] 820 | # nrow = len(self.EpcLog.index) 821 | # self.EpcLog.loc[nrow] = nwdat # note: this is reportedly rather slow 822 | 823 
| def ConfigTrnEpcLog(ss, dt): 824 | dt.SetMetaData("name", "TrnEpcLog") 825 | dt.SetMetaData("desc", "Record of performance over epochs of training") 826 | dt.SetMetaData("read-only", "true") 827 | dt.SetMetaData("precision", str(LogPrec)) 828 | 829 | sc = etable.Schema() 830 | sc.append(etable.Column("Run", etensor.INT64, nilInts, nilStrs)) 831 | sc.append(etable.Column("Epoch", etensor.INT64, nilInts, nilStrs)) 832 | sc.append(etable.Column("SSE", etensor.FLOAT64, nilInts, nilStrs)) 833 | sc.append(etable.Column("AvgSSE", etensor.FLOAT64, nilInts, nilStrs)) 834 | sc.append(etable.Column("PctErr", etensor.FLOAT64, nilInts, nilStrs)) 835 | sc.append(etable.Column("PctCor", etensor.FLOAT64, nilInts, nilStrs)) 836 | sc.append(etable.Column("CosDiff", etensor.FLOAT64, nilInts, nilStrs)) 837 | sc.append(etable.Column("Hid1 ActAvg", 838 | etensor.FLOAT64, nilInts, nilStrs)) 839 | sc.append(etable.Column("Hid2 ActAvg", 840 | etensor.FLOAT64, nilInts, nilStrs)) 841 | sc.append(etable.Column("Out ActAvg", etensor.FLOAT64, nilInts, nilStrs)) 842 | dt.SetFromSchema(sc, 0) 843 | 844 | # note: pandas.DataFrame version 845 | # self.EpcLog = pd.DataFrame(columns=["Epoch", "SSE", "Avg SSE", "Pct Err", "Pct Cor", "CosDiff", "Hid1 ActAvg", "Hid2 ActAvg", "Out ActAvg"]) 846 | # self.PlotVals = ["SSE", "Pct Err"] 847 | # self.Plot = True 848 | 849 | def ConfigTrnEpcPlot(ss, plt, dt): 850 | plt.Params.Title = "Leabra Random Associator 25 Epoch Plot" 851 | plt.Params.XAxisCol = "Epoch" 852 | plt.SetTable(dt) 853 | # order of params: on, fixMin, min, fixMax, max 854 | plt.SetColParams("Run", False, True, 0, False, 0) 855 | plt.SetColParams("Epoch", False, True, 0, False, 0) 856 | plt.SetColParams("SSE", False, True, 0, False, 0) 857 | plt.SetColParams("AvgSSE", False, True, 0, False, 0) 858 | plt.SetColParams("PctErr", True, True, 0, True, 1) # default plot 859 | plt.SetColParams("PctCor", True, True, 0, True, 1) # default plot 860 | plt.SetColParams("CosDiff", False, True, 0, True, 1) 
861 | plt.SetColParams("Hid1 ActAvg", False, True, 0, True, .5) 862 | plt.SetColParams("Hid2 ActAvg", False, True, 0, True, .5) 863 | plt.SetColParams("Out ActAvg", False, True, 0, True, .5) 864 | return plt 865 | 866 | ############################# 867 | # TstTrlLog 868 | 869 | def LogTstTrl(ss, dt): 870 | """ 871 | LogTstTrl adds data from current epoch to the TstTrlLog table 872 | log always contains number of testing items 873 | """ 874 | dt = ss.TstTrlLog 875 | 876 | inLay = leabra.Layer(ss.Net.LayerByName("Input")) 877 | hid1Lay = leabra.Layer(ss.Net.LayerByName("Hidden1")) 878 | hid2Lay = leabra.Layer(ss.Net.LayerByName("Hidden2")) 879 | outLay = leabra.Layer(ss.Net.LayerByName("Output")) 880 | 881 | # this is triggered by increment so use previous value 882 | epc = ss.TrainEnv.Epoch.Prv 883 | trl = ss.TestEnv.Trial.Cur 884 | 885 | dt.SetCellFloat("Epoch", trl, epc) 886 | dt.SetCellFloat("Trial", trl, trl) 887 | dt.SetCellString("TrialName", trl, ss.TestEnv.TrialName.Cur) 888 | dt.SetCellFloat("SSE", trl, ss.TrlSSE) 889 | dt.SetCellFloat("AvgSSE", trl, ss.TrlAvgSSE) 890 | dt.SetCellFloat("CosDiff", trl, ss.TrlCosDiff) 891 | dt.SetCellFloat("Hid1 ActM.Avg", trl, hid1Lay.Pool(0).ActM.Avg) 892 | dt.SetCellFloat("Hid2 ActM.Avg", trl, hid2Lay.Pool(0).ActM.Avg) 893 | dt.SetCellFloat("Out ActM.Avg", trl, outLay.Pool(0).ActM.Avg) 894 | 895 | if ss.InputValsTsr == 0: # re-use same tensors so not always reallocating mem 896 | ss.InputValsTsr = etensor.Float32() 897 | ss.OutputValsTsr = etensor.Float32() 898 | inLay.UnitValsTensor(ss.InputValsTsr, "Act") 899 | dt.SetCellTensor("InAct", trl, ss.InputValsTsr) 900 | outLay.UnitValsTensor(ss.OutputValsTsr, "ActM") 901 | dt.SetCellTensor("OutActM", trl, ss.OutputValsTsr) 902 | outLay.UnitValsTensor(ss.OutputValsTsr, "ActP") 903 | dt.SetCellTensor("OutActP", trl, ss.OutputValsTsr) 904 | 905 | # note: essential to use Go version of update when called from another goroutine 906 | if ss.TstTrlPlot != 0: 907 | 
ss.TstTrlPlot.GoUpdate() 908 | 909 | def ConfigTstTrlLog(ss, dt): 910 | inLay = leabra.Layer(ss.Net.LayerByName("Input")) 911 | outLay = leabra.Layer(ss.Net.LayerByName("Output")) 912 | 913 | dt.SetMetaData("name", "TstTrlLog") 914 | dt.SetMetaData("desc", "Record of testing per input pattern") 915 | dt.SetMetaData("read-only", "true") 916 | dt.SetMetaData("precision", str(LogPrec)) 917 | nt = ss.TestEnv.Table.Len() # number in view 918 | 919 | sc = etable.Schema() 920 | sc.append(etable.Column("Run", etensor.INT64, nilInts, nilStrs)) 921 | sc.append(etable.Column("Epoch", etensor.INT64, nilInts, nilStrs)) 922 | sc.append(etable.Column("Trial", etensor.INT64, nilInts, nilStrs)) 923 | sc.append(etable.Column("TrialName", etensor.STRING, nilInts, nilStrs)) 924 | sc.append(etable.Column("SSE", etensor.FLOAT64, nilInts, nilStrs)) 925 | sc.append(etable.Column("AvgSSE", etensor.FLOAT64, nilInts, nilStrs)) 926 | sc.append(etable.Column("CosDiff", etensor.FLOAT64, nilInts, nilStrs)) 927 | sc.append(etable.Column("Hid1 ActM.Avg", 928 | etensor.FLOAT64, nilInts, nilStrs)) 929 | sc.append(etable.Column("Hid2 ActM.Avg", 930 | etensor.FLOAT64, nilInts, nilStrs)) 931 | sc.append(etable.Column("Out ActM.Avg", 932 | etensor.FLOAT64, nilInts, nilStrs)) 933 | sc.append(etable.Column( 934 | "InAct", etensor.FLOAT64, inLay.Shp.Shp, nilStrs)) 935 | sc.append(etable.Column( 936 | "OutActM", etensor.FLOAT64, outLay.Shp.Shp, nilStrs)) 937 | sc.append(etable.Column( 938 | "OutActP", etensor.FLOAT64, outLay.Shp.Shp, nilStrs)) 939 | dt.SetFromSchema(sc, nt) 940 | 941 | def ConfigTstTrlPlot(ss, plt, dt): 942 | plt.Params.Title = "Leabra Random Associator 25 Test Trial Plot" 943 | plt.Params.XAxisCol = "Trial" 944 | plt.SetTable(dt) 945 | # order of params: on, fixMin, min, fixMax, max 946 | plt.SetColParams("Run", False, True, 0, False, 0) 947 | plt.SetColParams("Epoch", False, True, 0, False, 0) 948 | plt.SetColParams("Trial", False, True, 0, False, 0) 949 | plt.SetColParams("TrialName", 
False, True, 0, False, 0) 950 | plt.SetColParams("SSE", False, True, 0, False, 0) 951 | plt.SetColParams("AvgSSE", False, True, 0, False, 0) 952 | plt.SetColParams("CosDiff", True, True, 0, True, 1) 953 | plt.SetColParams("Hid1 ActM.Avg", True, True, 0, True, .5) 954 | plt.SetColParams("Hid2 ActM.Avg", True, True, 0, True, .5) 955 | plt.SetColParams("Out ActM.Avg", True, True, 0, True, .5) 956 | 957 | plt.SetColParams("InAct", False, True, 0, True, 1) 958 | plt.SetColParams("OutActM", False, True, 0, True, 1) 959 | plt.SetColParams("OutActP", False, True, 0, True, 1) 960 | return plt 961 | 962 | ############################# 963 | # TstEpcLog 964 | 965 | def LogTstEpc(ss, dt): 966 | """ 967 | LogTstEpc adds data from current epoch to the TstEpcLog table 968 | log always contains number of testing items 969 | """ 970 | row = dt.Rows 971 | dt.SetNumRows(row + 1) 972 | 973 | trl = ss.TstTrlLog 974 | tix = etable.NewIdxView(trl) 975 | epc = ss.TrainEnv.Epoch.Prv 976 | 977 | # note: this shows how to use agg methods to compute summary data from another 978 | # data table, instead of incrementing on the Sim 979 | dt.SetCellFloat("Run", row, ss.TrainEnv.Run.Cur) 980 | dt.SetCellFloat("Epoch", row, epc) 981 | dt.SetCellFloat("SSE", row, agg.Sum(tix, "SSE")[0]) 982 | dt.SetCellFloat("AvgSSE", row, agg.Mean(tix, "AvgSSE")[0]) 983 | dt.SetCellFloat("PctErr", row, agg.PropIf( 984 | tix, "SSE", lambda idx, val: val > 0)[0]) 985 | dt.SetCellFloat("PctCor", row, agg.PropIf( 986 | tix, "SSE", lambda idx, val: val == 0)[0]) 987 | dt.SetCellFloat("CosDiff", row, agg.Mean(tix, "CosDiff")[0]) 988 | 989 | trlix = etable.NewIdxView(trl) 990 | trlix.Filter(FilterSSE) 991 | 992 | ss.TstErrLog = trlix.NewTable() 993 | 994 | allsp = split.All(trlix) 995 | split.Agg(allsp, "SSE", agg.AggSum) 996 | split.Agg(allsp, "AvgSSE", agg.AggMean) 997 | split.Agg(allsp, "InAct", agg.AggMean) 998 | split.Agg(allsp, "OutActM", agg.AggMean) 999 | split.Agg(allsp, "OutActP", agg.AggMean) 1000 | 1001 | 
ss.TstErrStats = allsp.AggsToTable(False) 1002 | 1003 | # note: essential to use Go version of update when called from another goroutine 1004 | if ss.TstEpcPlot != 0: 1005 | ss.TstEpcPlot.GoUpdate() 1006 | 1007 | def ConfigTstEpcLog(ss, dt): 1008 | dt.SetMetaData("name", "TstEpcLog") 1009 | dt.SetMetaData("desc", "Summary stats for testing trials") 1010 | dt.SetMetaData("read-only", "true") 1011 | dt.SetMetaData("precision", str(LogPrec)) 1012 | 1013 | sc = etable.Schema() 1014 | sc.append(etable.Column("Run", etensor.INT64, nilInts, nilStrs)) 1015 | sc.append(etable.Column("Epoch", etensor.INT64, nilInts, nilStrs)) 1016 | sc.append(etable.Column("SSE", etensor.FLOAT64, nilInts, nilStrs)) 1017 | sc.append(etable.Column("AvgSSE", etensor.FLOAT64, nilInts, nilStrs)) 1018 | sc.append(etable.Column("PctErr", etensor.FLOAT64, nilInts, nilStrs)) 1019 | sc.append(etable.Column("PctCor", etensor.FLOAT64, nilInts, nilStrs)) 1020 | sc.append(etable.Column("CosDiff", etensor.FLOAT64, nilInts, nilStrs)) 1021 | dt.SetFromSchema(sc, 0) 1022 | 1023 | def ConfigTstEpcPlot(ss, plt, dt): 1024 | plt.Params.Title = "Leabra Random Associator 25 Testing Epoch Plot" 1025 | plt.Params.XAxisCol = "Epoch" 1026 | plt.SetTable(dt) 1027 | # order of params: on, fixMin, min, fixMax, max 1028 | plt.SetColParams("Run", False, True, 0, False, 0) 1029 | plt.SetColParams("Epoch", False, True, 0, False, 0) 1030 | plt.SetColParams("SSE", False, True, 0, False, 0) 1031 | plt.SetColParams("AvgSSE", False, True, 0, False, 0) 1032 | plt.SetColParams("PctErr", True, True, 0, True, 1) # default plot 1033 | plt.SetColParams("PctCor", True, True, 0, True, 1) # default plot 1034 | plt.SetColParams("CosDiff", False, True, 0, True, 1) 1035 | return plt 1036 | 1037 | ############################# 1038 | # TstCycLog 1039 | 1040 | def LogTstCyc(ss, dt, cyc): 1041 | """ 1042 | LogTstCyc adds data from current trial to the TstCycLog table. 
1043 | log just has 100 cycles, is overwritten 1044 | """ 1045 | if dt.Rows <= cyc: 1046 | dt.SetNumRows(cyc + 1) 1047 | 1048 | hid1Lay = leabra.Layer(ss.Net.LayerByName("Hidden1")) 1049 | hid2Lay = leabra.Layer(ss.Net.LayerByName("Hidden2")) 1050 | outLay = leabra.Layer(ss.Net.LayerByName("Output")) 1051 | 1052 | dt.SetCellFloat("Cycle", cyc, cyc) 1053 | dt.SetCellFloat("Hid1 Ge.Avg", cyc, hid1Lay.Pool(0).Inhib.Ge.Avg) 1054 | dt.SetCellFloat("Hid2 Ge.Avg", cyc, hid2Lay.Pool(0).Inhib.Ge.Avg) 1055 | dt.SetCellFloat("Out Ge.Avg", cyc, outLay.Pool(0).Inhib.Ge.Avg) 1056 | dt.SetCellFloat("Hid1 Act.Avg", cyc, hid1Lay.Pool(0).Inhib.Act.Avg) 1057 | dt.SetCellFloat("Hid2 Act.Avg", cyc, hid2Lay.Pool(0).Inhib.Act.Avg) 1058 | dt.SetCellFloat("Out Act.Avg", cyc, outLay.Pool(0).Inhib.Act.Avg) 1059 | 1060 | if ss.TstCycPlot != 0 and cyc % 10 == 0: # too slow to do every cyc 1061 | # note: essential to use Go version of update when called from another goroutine 1062 | ss.TstCycPlot.GoUpdate() 1063 | 1064 | def ConfigTstCycLog(ss, dt): 1065 | dt.SetMetaData("name", "TstCycLog") 1066 | dt.SetMetaData( 1067 | "desc", "Record of activity etc over one trial by cycle") 1068 | dt.SetMetaData("read-only", "true") 1069 | dt.SetMetaData("precision", str(LogPrec)) 1070 | np = 100 # max cycles 1071 | 1072 | sc = etable.Schema() 1073 | sc.append(etable.Column("Cycle", etensor.INT64, nilInts, nilStrs)) 1074 | sc.append(etable.Column("Hid1 Ge.Avg", 1075 | etensor.FLOAT64, nilInts, nilStrs)) 1076 | sc.append(etable.Column("Hid2 Ge.Avg", 1077 | etensor.FLOAT64, nilInts, nilStrs)) 1078 | sc.append(etable.Column("Out Ge.Avg", etensor.FLOAT64, nilInts, nilStrs)) 1079 | sc.append(etable.Column("Hid1 Act.Avg", 1080 | etensor.FLOAT64, nilInts, nilStrs)) 1081 | sc.append(etable.Column("Hid2 Act.Avg", 1082 | etensor.FLOAT64, nilInts, nilStrs)) 1083 | sc.append(etable.Column("Out Act.Avg", 1084 | etensor.FLOAT64, nilInts, nilStrs)) 1085 | dt.SetFromSchema(sc, np) 1086 | 1087 | def ConfigTstCycPlot(ss, 
plt, dt): 1088 | plt.Params.Title = "Leabra Random Associator 25 Test Cycle Plot" 1089 | plt.Params.XAxisCol = "Cycle" 1090 | plt.SetTable(dt) 1091 | # order of params: on, fixMin, min, fixMax, max 1092 | plt.SetColParams("Cycle", False, True, 0, False, 0) 1093 | plt.SetColParams("Hid1 Ge.Avg", True, True, 0, True, .5) 1094 | plt.SetColParams("Hid2 Ge.Avg", True, True, 0, True, .5) 1095 | plt.SetColParams("Out Ge.Avg", True, True, 0, True, .5) 1096 | plt.SetColParams("Hid1 Act.Avg", True, True, 0, True, .5) 1097 | plt.SetColParams("Hid2 Act.Avg", True, True, 0, True, .5) 1098 | plt.SetColParams("Out Act.Avg", True, True, 0, True, .5) 1099 | return plt 1100 | 1101 | ############################# 1102 | # RunLog 1103 | 1104 | def LogRun(ss, dt): 1105 | run = ss.TrainEnv.Run.Cur # this is NOT triggered by increment yet -- use Cur 1106 | row = dt.Rows 1107 | ss.RunLog.SetNumRows(row + 1) 1108 | 1109 | epclog = ss.TrnEpcLog 1110 | # compute mean over last N epochs for run level 1111 | nlast = 10 1112 | epcix = etable.NewIdxView(epclog) 1113 | epcix.Idxs = go.Slice_int(epcix.Idxs[epcix.Len()-nlast-1:]) 1114 | # print(epcix.Idxs[epcix.Len()-nlast-1:]) 1115 | 1116 | params = ss.RunName() # includes tag 1117 | 1118 | dt.SetCellFloat("Run", row, run) 1119 | dt.SetCellString("Params", row, params) 1120 | dt.SetCellFloat("FirstZero", row, ss.FirstZero) 1121 | dt.SetCellFloat("SSE", row, agg.Mean(epcix, "SSE")[0]) 1122 | dt.SetCellFloat("AvgSSE", row, agg.Mean(epcix, "AvgSSE")[0]) 1123 | dt.SetCellFloat("PctErr", row, agg.Mean(epcix, "PctErr")[0]) 1124 | dt.SetCellFloat("PctCor", row, agg.Mean(epcix, "PctCor")[0]) 1125 | dt.SetCellFloat("CosDiff", row, agg.Mean(epcix, "CosDiff")[0]) 1126 | 1127 | runix = etable.NewIdxView(dt) 1128 | spl = split.GroupBy(runix, go.Slice_string(["Params"])) 1129 | split.Desc(spl, "FirstZero") 1130 | split.Desc(spl, "PctCor") 1131 | ss.RunStats = spl.AggsToTable(False) 1132 | 1133 | # note: essential to use Go version of update when called from 
another goroutine 1134 | if ss.RunPlot != 0: 1135 | ss.RunPlot.GoUpdate() 1136 | 1137 | if ss.RunFile != 0: 1138 | if row == 0: 1139 | dt.WriteCSVHeaders(ss.RunFile, etable.Tab) 1140 | dt.WriteCSVRow(ss.RunFile, row, etable.Tab) 1141 | 1142 | def ConfigRunLog(ss, dt): 1143 | dt.SetMetaData("name", "RunLog") 1144 | dt.SetMetaData("desc", "Record of performance at end of training") 1145 | dt.SetMetaData("read-only", "true") 1146 | dt.SetMetaData("precision", str(LogPrec)) 1147 | 1148 | sc = etable.Schema() 1149 | sc.append(etable.Column("Run", etensor.INT64, nilInts, nilStrs)) 1150 | sc.append(etable.Column("Params", etensor.STRING, nilInts, nilStrs)) 1151 | sc.append(etable.Column("FirstZero", etensor.FLOAT64, nilInts, nilStrs)) 1152 | sc.append(etable.Column("SSE", etensor.FLOAT64, nilInts, nilStrs)) 1153 | sc.append(etable.Column("AvgSSE", etensor.FLOAT64, nilInts, nilStrs)) 1154 | sc.append(etable.Column("PctErr", etensor.FLOAT64, nilInts, nilStrs)) 1155 | sc.append(etable.Column("PctCor", etensor.FLOAT64, nilInts, nilStrs)) 1156 | sc.append(etable.Column("CosDiff", etensor.FLOAT64, nilInts, nilStrs)) 1157 | dt.SetFromSchema(sc, 0) 1158 | 1159 | def ConfigRunPlot(ss, plt, dt): 1160 | plt.Params.Title = "Leabra Random Associator 25 Run Plot" 1161 | plt.Params.XAxisCol = "Run" 1162 | plt.SetTable(dt) 1163 | # order of params: on, fixMin, min, fixMax, max 1164 | plt.SetColParams("Run", False, True, 0, False, 0) 1165 | plt.SetColParams("FirstZero", True, True, 0, False, 0) # default plot 1166 | plt.SetColParams("SSE", False, True, 0, False, 0) 1167 | plt.SetColParams("AvgSSE", False, True, 0, False, 0) 1168 | plt.SetColParams("PctErr", False, True, 0, True, 1) 1169 | plt.SetColParams("PctCor", False, True, 0, True, 1) 1170 | plt.SetColParams("CosDiff", False, True, 0, True, 1) 1171 | return plt 1172 | 1173 | ############################## 1174 | # ConfigGui 1175 | 1176 | def ConfigGui(ss): 1177 | """ConfigGui configures the GoGi gui interface for this simulation""" 
1178 | width = 1600 1179 | height = 1200 1180 | 1181 | gi.SetAppName("ra25") 1182 | gi.SetAppAbout( 1183 | 'This demonstrates a basic Leabra model. See emergent on GitHub.

') 1184 | 1185 | win = gi.NewMainWindow( 1186 | "ra25", "Leabra Random Associator", width, height) 1187 | ss.Win = win 1188 | 1189 | vp = win.WinViewport2D() 1190 | ss.vp = vp 1191 | updt = vp.UpdateStart() 1192 | 1193 | mfr = win.SetMainFrame() 1194 | 1195 | tbar = gi.AddNewToolBar(mfr, "tbar") 1196 | tbar.SetStretchMaxWidth() 1197 | ss.ToolBar = tbar 1198 | 1199 | split = gi.AddNewSplitView(mfr, "split") 1200 | split.Dim = gi.X 1201 | split.SetStretchMaxWidth() 1202 | split.SetStretchMaxHeight() 1203 | 1204 | ss.ClassView = epygiv.ClassView("ra25sv", ss.Tags) 1205 | ss.ClassView.AddFrame(split) 1206 | ss.ClassView.SetClass(ss) 1207 | 1208 | tv = gi.AddNewTabView(split, "tv") 1209 | 1210 | nv = netview.NetView() 1211 | tv.AddTab(nv, "NetView") 1212 | nv.Var = "Act" 1213 | nv.SetNet(ss.Net) 1214 | ss.NetView = nv 1215 | 1216 | plt = eplot.Plot2D() 1217 | tv.AddTab(plt, "TrnEpcPlot") 1218 | ss.TrnEpcPlot = ss.ConfigTrnEpcPlot(plt, ss.TrnEpcLog) 1219 | 1220 | plt = eplot.Plot2D() 1221 | tv.AddTab(plt, "TstTrlPlot") 1222 | ss.TstTrlPlot = ss.ConfigTstTrlPlot(plt, ss.TstTrlLog) 1223 | 1224 | plt = eplot.Plot2D() 1225 | tv.AddTab(plt, "TstCycPlot") 1226 | ss.TstCycPlot = ss.ConfigTstCycPlot(plt, ss.TstCycLog) 1227 | 1228 | plt = eplot.Plot2D() 1229 | tv.AddTab(plt, "TstEpcPlot") 1230 | ss.TstEpcPlot = ss.ConfigTstEpcPlot(plt, ss.TstEpcLog) 1231 | 1232 | plt = eplot.Plot2D() 1233 | tv.AddTab(plt, "RunPlot") 1234 | ss.RunPlot = ss.ConfigRunPlot(plt, ss.RunLog) 1235 | 1236 | split.SetSplitsList(go.Slice_float32([.3, .7])) 1237 | 1238 | recv = win.This() 1239 | 1240 | tbar.AddAction(gi.ActOpts(Label="Init", Icon="update", 1241 | Tooltip="Initialize everything including network weights, and start over. Also applies current params.", UpdateFunc=UpdtFuncNotRunning), recv, InitCB) 1242 | 1243 | tbar.AddAction(gi.ActOpts(Label="Train", Icon="run", Tooltip="Starts the network training, picking up from wherever it may have left off. 
If not stopped, training will complete the specified number of Runs through the full number of Epochs of training, with testing automatically occuring at the specified interval.", UpdateFunc=UpdtFuncNotRunning), recv, TrainCB) 1244 | 1245 | tbar.AddAction(gi.ActOpts(Label="Stop", Icon="stop", 1246 | Tooltip="Interrupts running. Hitting Train again will pick back up where it left off.", UpdateFunc=UpdtFuncRunning), recv, StopCB) 1247 | 1248 | tbar.AddAction(gi.ActOpts(Label="Step Trial", Icon="step-fwd", 1249 | Tooltip="Advances one training trial at a time.", UpdateFunc=UpdtFuncNotRunning), recv, StepTrialCB) 1250 | 1251 | tbar.AddAction(gi.ActOpts(Label="Step Epoch", Icon="fast-fwd", 1252 | Tooltip="Advances one epoch (complete set of training patterns) at a time.", UpdateFunc=UpdtFuncNotRunning), recv, StepEpochCB) 1253 | 1254 | tbar.AddAction(gi.ActOpts(Label="Step Run", Icon="fast-fwd", 1255 | Tooltip="Advances one full training Run at a time.", UpdateFunc=UpdtFuncNotRunning), recv, StepRunCB) 1256 | 1257 | tbar.AddSeparator("test") 1258 | 1259 | tbar.AddAction(gi.ActOpts(Label="Test Trial", Icon="step-fwd", 1260 | Tooltip="Runs the next testing trial.", UpdateFunc=UpdtFuncNotRunning), recv, TestTrialCB) 1261 | 1262 | tbar.AddAction(gi.ActOpts(Label="Test Item", Icon="step-fwd", 1263 | Tooltip="Prompts for a specific input pattern name to run, and runs it in testing mode.", UpdateFunc=UpdtFuncNotRunning), recv, TestItemCB) 1264 | 1265 | tbar.AddAction(gi.ActOpts(Label="Test All", Icon="fast-fwd", 1266 | Tooltip="Tests all of the testing trials.", UpdateFunc=UpdtFuncNotRunning), recv, TestAllCB) 1267 | 1268 | tbar.AddSeparator("log") 1269 | 1270 | tbar.AddAction(gi.ActOpts(Label="Reset RunLog", Icon="reset", 1271 | Tooltip="Resets the accumulated log of all Runs, which are tagged with the ParamSet used"), recv, ResetRunLogCB) 1272 | 1273 | tbar.AddSeparator("misc") 1274 | 1275 | tbar.AddAction(gi.ActOpts(Label="New Seed", Icon="new", 1276 | Tooltip="Generate a 
new initial random seed to get different results. By default, Init re-establishes the same initial seed every time."), recv, NewRndSeedCB) 1277 | 1278 | tbar.AddAction(gi.ActOpts(Label="README", Icon="file-markdown", 1279 | Tooltip="Opens your browser on the README file that contains instructions for how to run this model."), recv, ReadmeCB) 1280 | 1281 | # main menu 1282 | appnm = gi.AppName() 1283 | mmen = win.MainMenu 1284 | mmen.ConfigMenus(go.Slice_string([appnm, "File", "Edit", "Window"])) 1285 | 1286 | amen = gi.Action(win.MainMenu.ChildByName(appnm, 0)) 1287 | amen.Menu.AddAppMenu(win) 1288 | 1289 | emen = gi.Action(win.MainMenu.ChildByName("Edit", 1)) 1290 | emen.Menu.AddCopyCutPaste(win) 1291 | 1292 | # note: Command in shortcuts is automatically translated into Control for 1293 | # Linux, Windows or Meta for MacOS 1294 | # fmen = win.MainMenu.ChildByName("File", 0).(*gi.Action) 1295 | # fmen.Menu = make(gi.Menu, 0, 10) 1296 | # fmen.Menu.AddAction(gi.ActOpts{Label: "Open", Shortcut: "Command+O"}, 1297 | # recv, func(recv, send ki.Ki, sig int64, data interface{}) { 1298 | # FileViewOpenSVG(vp) 1299 | # }) 1300 | # fmen.Menu.AddSeparator("csep") 1301 | # fmen.Menu.AddAction(gi.ActOpts{Label: "Close Window", Shortcut: "Command+W"}, 1302 | # recv, func(recv, send ki.Ki, sig int64, data interface{}) { 1303 | # win.CloseReq() 1304 | # }) 1305 | 1306 | # win.SetCloseCleanFunc(func(w *gi.Window) { 1307 | # gi.Quit() # once main window is closed, quit 1308 | # }) 1309 | # 1310 | win.MainMenuUpdated() 1311 | vp.UpdateEndNoSig(updt) 1312 | win.GoStartEventLoop() 1313 | 1314 | 1315 | # TheSim is the overall state for this simulation 1316 | TheSim = Sim() 1317 | 1318 | 1319 | def usage(): 1320 | print(sys.argv[0] + " --params= --tag= --setparams --wts --epclog=0 --runlog=0 --nogui") 1321 | print("\t pyleabra -i %s to run in interactive, gui mode" % sys.argv[0]) 1322 | print("\t --params= additional params to apply on top of Base (name must be in loaded Params") 1323 
| print("\t --tag= tag is appended to file names to uniquely identify this run") 1324 | print("\t --runs= number of runs to do") 1325 | print("\t --setparams show the parameter values that are set") 1326 | print("\t --wts save final trained weights after every run") 1327 | print("\t --epclog=0/False turn off save training epoch log data to file named by param set, tag") 1328 | print("\t --runlog=0/False turn off save run log data to file named by param set, tag") 1329 | print("\t --nogui if no other args needed, this prevents running under the gui") 1330 | 1331 | 1332 | def main(argv): 1333 | TheSim.Config() 1334 | 1335 | # print("n args: %d" % len(argv)) 1336 | TheSim.NoGui = len(argv) > 1 1337 | saveEpcLog = True 1338 | saveRunLog = True 1339 | 1340 | try: 1341 | opts, args = getopt.getopt(argv, "h:", [ 1342 | "params=", "tag=", "runs=", "setparams", "wts", "epclog=", "runlog=", "nogui"]) 1343 | except getopt.GetoptError: 1344 | usage() 1345 | sys.exit(2) 1346 | for opt, arg in opts: 1347 | # print("opt: %s arg: %s" % (opt, arg)) 1348 | if opt == '-h': 1349 | usage() 1350 | sys.exit() 1351 | elif opt == "--tag": 1352 | TheSim.Tag = arg 1353 | elif opt == "--runs": 1354 | TheSim.MaxRuns = int(arg) 1355 | print("Running %d runs" % TheSim.MaxRuns) 1356 | elif opt == "--setparams": 1357 | TheSim.LogSetParams = True 1358 | elif opt == "--wts": 1359 | TheSim.SaveWts = True 1360 | print("Saving final weights per run") 1361 | elif opt == "--epclog": 1362 | if arg.lower() == "false" or arg == "0": 1363 | saveEpcLog = False 1364 | elif opt == "--runlog": 1365 | if arg.lower() == "false" or arg == "0": 1366 | saveRunLog = False 1367 | elif opt == "--nogui": 1368 | TheSim.NoGui = True 1369 | 1370 | TheSim.Init() 1371 | 1372 | if TheSim.NoGui: 1373 | if saveEpcLog: 1374 | fnm = TheSim.LogFileName("epc") 1375 | print("Saving epoch log to: %s" % fnm) 1376 | TheSim.TrnEpcFile = efile.Create(fnm) 1377 | 1378 | if saveRunLog: 1379 | fnm = TheSim.LogFileName("run") 1380 | 
print("Saving run log to: %s" % fnm) 1381 | TheSim.RunFile = efile.Create(fnm) 1382 | 1383 | TheSim.Train() 1384 | else: 1385 | TheSim.ConfigGui() 1386 | print("Note: run pyleabra -i ra25.py to run in interactive mode, or just pyleabra, then 'import ra25'") 1387 | print("for non-gui background running, here are the args:") 1388 | usage() 1389 | 1390 | 1391 | main(sys.argv[1:]) 1392 | -------------------------------------------------------------------------------- /version.go: -------------------------------------------------------------------------------- 1 | // WARNING: auto-generated by Makefile release target -- run 'make release' to update 2 | 3 | package main 4 | 5 | const ( 6 | Version = "v0.4.1" 7 | GitCommit = "cd0e95d" // the commit JUST BEFORE the release 8 | VersionDate = "2021-09-25 19:54" // UTC 9 | ) 10 | --------------------------------------------------------------------------------