├── .gitignore
├── LICENSE
├── Makefile
├── README.md
├── diff
└── diff.go
├── format.go
├── go.mod
├── gotopy.go
├── gotopy_test.go
├── parse.go
├── pyedits.go
├── pyprint
├── nodes.go
└── printer.go
├── rewrite.go
├── simplify.go
├── testdata
├── basic.golden
├── basic.input
├── ra25.golden
├── ra25.input
└── ra25.py
└── version.go
/.gitignore:
--------------------------------------------------------------------------------
1 | # Binaries for programs and plugins
2 | *.exe
3 | *.exe~
4 | *.dll
5 | *.so
6 | *.dylib
7 |
8 | # Test binary, built with `go test -c`
9 | *.test
10 |
11 | # Output of the go coverage tool, specifically when used with LiteIDE
12 | *.out
13 |
14 | # Dependency directories (remove the comment below to include it)
15 | # vendor/
16 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | BSD 3-Clause License
2 |
3 | Copyright (c) 2020, GoKi
4 | All rights reserved.
5 |
6 | Redistribution and use in source and binary forms, with or without
7 | modification, are permitted provided that the following conditions are met:
8 |
9 | 1. Redistributions of source code must retain the above copyright notice, this
10 | list of conditions and the following disclaimer.
11 |
12 | 2. Redistributions in binary form must reproduce the above copyright notice,
13 | this list of conditions and the following disclaimer in the documentation
14 | and/or other materials provided with the distribution.
15 |
16 | 3. Neither the name of the copyright holder nor the names of its
17 | contributors may be used to endorse or promote products derived from
18 | this software without specific prior written permission.
19 |
20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 |
--------------------------------------------------------------------------------
/Makefile:
--------------------------------------------------------------------------------
1 | # Basic Go makefile
2 |
3 | GOCMD=go
4 | GOBUILD=$(GOCMD) build
5 | GOCLEAN=$(GOCMD) clean
6 | GOTEST=$(GOCMD) test
7 | GOGET=$(GOCMD) get
8 |
9 | DIRS=`go list ./...`
10 |
11 | all: build
12 |
13 | build:
14 | @echo "GO111MODULE = $(value GO111MODULE)"
15 | $(GOBUILD) -v $(DIRS)
16 |
17 | test:
18 | @echo "GO111MODULE = $(value GO111MODULE)"
19 | $(GOTEST) -v $(DIRS)
20 |
21 | clean:
22 | @echo "GO111MODULE = $(value GO111MODULE)"
23 | $(GOCLEAN) ./...
24 |
25 | fmts:
26 | gofmt -s -w .
27 |
28 | vet:
29 | @echo "GO111MODULE = $(value GO111MODULE)"
30 | $(GOCMD) vet $(DIRS) | grep -v unkeyed
31 |
32 | tidy: export GO111MODULE = on
33 | tidy:
34 | @echo "GO111MODULE = $(value GO111MODULE)"
35 | go mod tidy
36 |
37 | mod-update: export GO111MODULE = on
38 | mod-update:
39 | @echo "GO111MODULE = $(value GO111MODULE)"
40 | go get -u ./...
41 | go mod tidy
42 |
43 | # gopath-update updates dependencies in GOPATH mode (GO111MODULE=off).
44 | # It must be run from the directory of a target executable (main package).
45 | gopath-update: export GO111MODULE = off
46 | gopath-update:
47 | @echo "GO111MODULE = $(value GO111MODULE)"
48 | go get -u ./...
49 |
50 | # NOTE: you MUST edit VERS below to the new version number before running 'make release'!
51 | VERS=v0.4.1
52 | PACKAGE=main
53 | GIT_COMMIT=`git rev-parse --short HEAD`
54 | VERS_DATE=`date -u +%Y-%m-%d\ %H:%M`
55 | VERS_FILE=version.go
56 |
57 | release:
58 | /bin/rm -f $(VERS_FILE)
59 | @echo "// WARNING: auto-generated by Makefile release target -- run 'make release' to update" > $(VERS_FILE)
60 | @echo "" >> $(VERS_FILE)
61 | @echo "package $(PACKAGE)" >> $(VERS_FILE)
62 | @echo "" >> $(VERS_FILE)
63 | @echo "const (" >> $(VERS_FILE)
64 | @echo " Version = \"$(VERS)\"" >> $(VERS_FILE)
65 | @echo " GitCommit = \"$(GIT_COMMIT)\" // the commit JUST BEFORE the release" >> $(VERS_FILE)
66 | @echo " VersionDate = \"$(VERS_DATE)\" // UTC" >> $(VERS_FILE)
67 | @echo ")" >> $(VERS_FILE)
68 | @echo "" >> $(VERS_FILE)
69 | goimports -w $(VERS_FILE)
70 | /bin/cat $(VERS_FILE)
71 | git commit -am "$(VERS) release"
72 | git tag -a $(VERS) -m "$(VERS) release"
73 | git push
74 | git push origin --tags
75 |
76 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # GoToPy
2 |
3 | GoToPy is a Go to Python converter -- translates Go code into Python code.
4 |
5 | To install, use the standard Go tooling:
6 |
7 | ```sh
8 | $ go install github.com/go-python/gotopy@latest
9 | ```
10 |
11 | It is based on the source code of the Go `gofmt` command and the `go/printer` package, which parse Go files and write them back out in standard Go formatting.
12 |
13 | We have modified the `printer` code in the `pyprint` package to instead print out Python code.
14 |
15 | The `-gopy` flag generates [GoPy](https://github.com/go-python/gopy)-specific Python code, including:
16 |
17 | * `nil` -> `go.nil`
18 | * `[]string{...}` -> `go.Slice_string([...])`, and likewise for `int`, `float64` and `float32`
19 |
20 | The `-gogi` flag generates [GoGi](https://github.com/goki/gi)-specific Python code, including:
21 |
22 | * struct tags generate: `self.SetTags()` call, for the `pygiv.ClassViewObj` class, which then provides an automatic GUI view with tag-based formatting of struct fields.
23 |
24 | # TODO
25 |
26 | * switch -> ifs.. -- grab switch expr and put into each if
27 |
28 | * string .contains -> "el" in str
29 |
30 | * map access with 2 vars = if el in map: mv = map[el]
31 |
32 | * for range with 2 vars = enumerate(slice)
33 |
34 |
--------------------------------------------------------------------------------
/diff/diff.go:
--------------------------------------------------------------------------------
1 | // Copyright 2019 The Go Authors. All rights reserved.
2 | // Use of this source code is governed by a BSD-style
3 | // license that can be found in the LICENSE file.
4 |
5 | // Package diff implements a Diff function that compares two inputs
6 | // using the 'diff' tool.
7 | package diff
8 |
import (
	"fmt"
	"io/ioutil"
	"os"
	"os/exec"
	"runtime"
)
15 |
// Diff returns the unified diff ("diff -u") between b1 and b2.
//
// Both inputs are written to temporary files whose names begin with prefix;
// the files are removed before Diff returns. The external diff tool is run
// on them (/bin/ape/diff on Plan 9). Identical inputs yield empty output.
//
// diff exits with status 1 when the inputs differ, which is the expected
// outcome and is not reported as an error. Any other failure (diff missing,
// or diff exiting with a trouble status) is returned as an error. Because
// stdout and stderr are combined, diff's own message is included in that
// error rather than being mistaken for a diff.
func Diff(prefix string, b1, b2 []byte) ([]byte, error) {
	f1, err := writeTempFile(prefix, b1)
	if err != nil {
		return nil, err
	}
	defer os.Remove(f1)

	f2, err := writeTempFile(prefix, b2)
	if err != nil {
		return nil, err
	}
	defer os.Remove(f2)

	cmd := "diff"
	if runtime.GOOS == "plan9" {
		cmd = "/bin/ape/diff"
	}

	data, err := exec.Command(cmd, "-u", f1, f2).CombinedOutput()
	if err == nil {
		return data, nil
	}
	if ee, ok := err.(*exec.ExitError); ok && ee.ExitCode() == 1 && len(data) > 0 {
		// Exit status 1 only means "the inputs differ": data is the diff.
		return data, nil
	}
	if len(data) > 0 {
		// data holds diff's error message, not a diff; don't pass it off as one.
		return nil, fmt.Errorf("%s -u: %w\n%s", cmd, err, data)
	}
	return nil, err
}

// writeTempFile writes data to a new temporary file whose name starts with
// prefix and returns the file's name. On failure the file is removed and the
// write error (or, if the write succeeded, the close error) is returned.
func writeTempFile(prefix string, data []byte) (string, error) {
	file, err := ioutil.TempFile("", prefix)
	if err != nil {
		return "", err
	}
	_, err = file.Write(data)
	if err1 := file.Close(); err == nil {
		err = err1
	}
	if err != nil {
		os.Remove(file.Name())
		return "", err
	}
	return file.Name(), nil
}
59 |
--------------------------------------------------------------------------------
/format.go:
--------------------------------------------------------------------------------
1 | // Copyright 2020 The Go-Python Authors. All rights reserved.
2 | // Use of this source code is governed by a BSD-style
3 | // license that can be found in the LICENSE file.
4 |
5 | // This is based on gofmt source code:
6 |
7 | // Copyright 2015 The Go Authors. All rights reserved.
8 | // Use of this source code is governed by a BSD-style
9 | // license that can be found in the LICENSE file.
10 |
11 | package main
12 |
13 | import (
14 | "bytes"
15 | "go/ast"
16 | "go/token"
17 |
18 | "github.com/go-python/gotopy/pyprint"
19 | )
20 |
21 | // format formats the given package file originally obtained from src
22 | // and adjusts the result based on the original source via sourceAdj
23 | // and indentAdj.
24 | func format(
25 | fset *token.FileSet,
26 | file *ast.File,
27 | sourceAdj func(src []byte, indent int) []byte,
28 | indentAdj int,
29 | src []byte,
30 | cfg pyprint.Config,
31 | ) ([]byte, error) {
32 | if sourceAdj == nil {
33 | // Complete source file.
34 | var buf bytes.Buffer
35 | err := cfg.Fprint(&buf, fset, file)
36 | if err != nil {
37 | return nil, err
38 | }
39 | pyfix := pyEdits(buf.Bytes())
40 | return pyfix, nil
41 | // return buf.Bytes(), nil
42 | }
43 |
44 | // Partial source file.
45 | // Determine and prepend leading space.
46 | i, j := 0, 0
47 | for j < len(src) && isSpace(src[j]) {
48 | if src[j] == '\n' {
49 | i = j + 1 // byte offset of last line in leading space
50 | }
51 | j++
52 | }
53 | var res []byte
54 | res = append(res, src[:i]...)
55 |
56 | // Determine and prepend indentation of first code line.
57 | // Spaces are ignored unless there are no tabs,
58 | // in which case spaces count as one tab.
59 | indent := 0
60 | hasSpace := false
61 | for _, b := range src[i:j] {
62 | switch b {
63 | case ' ':
64 | hasSpace = true
65 | case '\t':
66 | indent++
67 | }
68 | }
69 | if indent == 0 && hasSpace {
70 | indent = 1
71 | }
72 | for i := 0; i < indent; i++ {
73 | res = append(res, '\t')
74 | }
75 |
76 | // Format the source.
77 | // Write it without any leading and trailing space.
78 | cfg.Indent = indent + indentAdj
79 | var buf bytes.Buffer
80 | err := cfg.Fprint(&buf, fset, file)
81 | if err != nil {
82 | return nil, err
83 | }
84 |
85 | pyfix := pyEdits(buf.Bytes())
86 |
87 | out := sourceAdj(pyfix, cfg.Indent)
88 |
89 | // If the adjusted output is empty, the source
90 | // was empty but (possibly) for white space.
91 | // The result is the incoming source.
92 | if len(out) == 0 {
93 | return src, nil
94 | }
95 |
96 | // Otherwise, append output to leading space.
97 | res = append(res, out...)
98 |
99 | // Determine and append trailing space.
100 | i = len(src)
101 | for i > 0 && isSpace(src[i-1]) {
102 | i--
103 | }
104 | return append(res, src[i:]...), nil
105 | }
106 |
// isSpace reports whether b is one of the ASCII white-space bytes recognized
// by format: space, horizontal tab, line feed or carriage return.
func isSpace(b byte) bool {
	switch b {
	case ' ', '\t', '\n', '\r':
		return true
	}
	return false
}
112 |
--------------------------------------------------------------------------------
/go.mod:
--------------------------------------------------------------------------------
1 | module github.com/go-python/gotopy
2 |
3 | go 1.15
4 |
--------------------------------------------------------------------------------
/gotopy.go:
--------------------------------------------------------------------------------
1 | // Copyright 2020 The Go-Python Authors. All rights reserved.
2 | // Use of this source code is governed by a BSD-style
3 | // license that can be found in the LICENSE file.
4 |
5 | // This is based on gofmt source code:
6 |
7 | // Copyright 2009 The Go Authors. All rights reserved.
8 | // Use of this source code is governed by a BSD-style
9 | // license that can be found in the LICENSE file.
10 |
11 | package main
12 |
13 | import (
14 | "bytes"
15 | "flag"
16 | "fmt"
17 | "go/ast"
18 | "go/parser"
19 | "go/scanner"
20 | "go/token"
21 | "io"
22 | "io/ioutil"
23 | "os"
24 | "path/filepath"
25 | "runtime"
26 | "runtime/pprof"
27 | "strings"
28 |
29 | "github.com/go-python/gotopy/diff"
30 | "github.com/go-python/gotopy/pyprint"
31 | )
32 |
// Command-line flags.
var (
	// Main operation modes.
	list = flag.Bool("l", false, "list files whose Python translation differs from their source")
	// -w writes the Python output over the input file itself (see processFile).
	write       = flag.Bool("w", false, "write result to (source) file instead of stdout; replaces the Go source with Python")
	rewriteRule = flag.String("r", "", "rewrite rule (e.g., 'a[b:len(a)] -> a[b:]')")
	simplifyAST = flag.Bool("s", false, "simplify code")
	doDiff      = flag.Bool("d", false, "display diffs instead of rewriting files")
	allErrors   = flag.Bool("e", false, "report all errors (not just the first 10 on different lines)")
	gopyMode    = flag.Bool("gopy", false, "support GoPy-specific Python code generation")
	gogiMode    = flag.Bool("gogi", false, "support GoGi-specific Python code generation (implies gopy)")

	// Debugging.
	cpuprofile = flag.String("cpuprofile", "", "write cpu profile to this file")
)
47 |
48 | // Keep these in sync with go/format/format.go.
49 | const (
50 | tabWidth = 8
51 | printerModeDef = pyprint.UseSpaces | pyprint.TabIndent | printerNormalizeNumbers
52 |
53 | // printerNormalizeNumbers means to canonicalize number literal prefixes
54 | // and exponents while printing. See https://golang.org/doc/go1.13#gofmt.
55 | //
56 | // This value is defined in go/printer specifically for go/format and cmd/gofmt.
57 | printerNormalizeNumbers = 1 << 30
58 | )
59 |
60 | var (
61 | fileSet = token.NewFileSet() // per process FileSet
62 | exitCode = 0
63 | rewrite func(*ast.File) *ast.File
64 | parserMode parser.Mode
65 | printerMode = printerModeDef
66 | )
67 |
68 | func report(err error) {
69 | scanner.PrintError(os.Stderr, err)
70 | exitCode = 2
71 | }
72 |
// usage prints the command synopsis followed by the flag defaults to
// standard error. It is installed as flag.Usage by gofmtMain.
func usage() {
	fmt.Fprintf(os.Stderr, "usage: gotopy [flags] [path ...]\n")
	flag.PrintDefaults()
}
77 |
78 | func initParserMode() {
79 | parserMode = parser.ParseComments
80 | if *allErrors {
81 | parserMode |= parser.AllErrors
82 | }
83 | if *gopyMode {
84 | printerMode |= pyprint.GoPy
85 | }
86 |
87 | if *gogiMode {
88 | printerMode |= pyprint.GoGi | pyprint.GoPy
89 | }
90 | }
91 |
// isGoFile reports whether f is a regular (non-directory) entry whose name
// ends in ".go" and does not start with "." (hidden files are skipped).
func isGoFile(f os.FileInfo) bool {
	if f.IsDir() {
		return false
	}
	name := f.Name()
	return strings.HasSuffix(name, ".go") && !strings.HasPrefix(name, ".")
}
97 |
98 | // If in == nil, the source is the contents of the file with the given filename.
99 | func processFile(filename string, in io.Reader, out io.Writer, stdin bool) error {
100 | var perm os.FileMode = 0644
101 | if in == nil {
102 | f, err := os.Open(filename)
103 | if err != nil {
104 | return err
105 | }
106 | defer f.Close()
107 | fi, err := f.Stat()
108 | if err != nil {
109 | return err
110 | }
111 | in = f
112 | perm = fi.Mode().Perm()
113 | }
114 |
115 | src, err := ioutil.ReadAll(in)
116 | if err != nil {
117 | return err
118 | }
119 |
120 | file, sourceAdj, indentAdj, err := parse(fileSet, filename, src, stdin)
121 | if err != nil {
122 | return err
123 | }
124 |
125 | if rewrite != nil {
126 | if sourceAdj == nil {
127 | file = rewrite(file)
128 | } else {
129 | fmt.Fprintf(os.Stderr, "warning: rewrite ignored for incomplete programs\n")
130 | }
131 | }
132 |
133 | ast.SortImports(fileSet, file)
134 |
135 | if *simplifyAST {
136 | simplify(file)
137 | }
138 |
139 | res, err := format(fileSet, file, sourceAdj, indentAdj, src, pyprint.Config{Mode: printerMode, Tabwidth: tabWidth})
140 | if err != nil {
141 | return err
142 | }
143 |
144 | if !bytes.Equal(src, res) {
145 | // formatting has changed
146 | if *list {
147 | fmt.Fprintln(out, filename)
148 | }
149 | if *write {
150 | // make a temporary backup before overwriting original
151 | bakname, err := backupFile(filename+".", src, perm)
152 | if err != nil {
153 | return err
154 | }
155 | err = ioutil.WriteFile(filename, res, perm)
156 | if err != nil {
157 | os.Rename(bakname, filename)
158 | return err
159 | }
160 | err = os.Remove(bakname)
161 | if err != nil {
162 | return err
163 | }
164 | }
165 | if *doDiff {
166 | data, err := diffWithReplaceTempFile(src, res, filename)
167 | if err != nil {
168 | return fmt.Errorf("computing diff: %s", err)
169 | }
170 | fmt.Printf("diff -u %s %s\n", filepath.ToSlash(filename+".orig"), filepath.ToSlash(filename))
171 | out.Write(data)
172 | }
173 | }
174 |
175 | if !*list && !*write && !*doDiff {
176 | _, err = out.Write(res)
177 | }
178 |
179 | return err
180 | }
181 |
182 | func visitFile(path string, f os.FileInfo, err error) error {
183 | if err == nil && isGoFile(f) {
184 | err = processFile(path, nil, os.Stdout, false)
185 | }
186 | // Don't complain if a file was deleted in the meantime (i.e.
187 | // the directory changed concurrently while running gofmt).
188 | if err != nil && !os.IsNotExist(err) {
189 | report(err)
190 | }
191 | return nil
192 | }
193 |
194 | func walkDir(path string) {
195 | filepath.Walk(path, visitFile)
196 | }
197 |
198 | func main() {
199 | // call gofmtMain in a separate function
200 | // so that it can use defer and have them
201 | // run before the exit.
202 | gofmtMain()
203 | os.Exit(exitCode)
204 | }
205 |
206 | func gofmtMain() {
207 | flag.Usage = usage
208 | flag.Parse()
209 |
210 | if *cpuprofile != "" {
211 | f, err := os.Create(*cpuprofile)
212 | if err != nil {
213 | fmt.Fprintf(os.Stderr, "creating cpu profile: %s\n", err)
214 | exitCode = 2
215 | return
216 | }
217 | defer f.Close()
218 | pprof.StartCPUProfile(f)
219 | defer pprof.StopCPUProfile()
220 | }
221 |
222 | initParserMode()
223 | initRewrite()
224 |
225 | if flag.NArg() == 0 {
226 | if *write {
227 | fmt.Fprintln(os.Stderr, "error: cannot use -w with standard input")
228 | exitCode = 2
229 | return
230 | }
231 | if err := processFile("", os.Stdin, os.Stdout, true); err != nil {
232 | report(err)
233 | }
234 | return
235 | }
236 |
237 | for i := 0; i < flag.NArg(); i++ {
238 | path := flag.Arg(i)
239 | switch dir, err := os.Stat(path); {
240 | case err != nil:
241 | report(err)
242 | case dir.IsDir():
243 | walkDir(path)
244 | default:
245 | if err := processFile(path, nil, os.Stdout, false); err != nil {
246 | report(err)
247 | }
248 | }
249 | }
250 | }
251 |
252 | func diffWithReplaceTempFile(b1, b2 []byte, filename string) ([]byte, error) {
253 | data, err := diff.Diff("gofmt", b1, b2)
254 | if len(data) > 0 {
255 | return replaceTempFilename(data, filename)
256 | }
257 | return data, err
258 | }
259 |
// replaceTempFilename rewrites the two header lines of a unified diff so
// they name filename+".orig" and filename instead of temporary files. Any
// tab-separated timestamp on a header line is kept, and the path is always
// written with forward slashes:
//
//	--- /tmp/gofmt316145376	2017-02-03 19:13:00.280468375 -0500
//	+++ /tmp/gofmt617882815	2017-02-03 19:13:00.280468375 -0500
//
// becomes
//
//	--- path/to/file.go.orig	2017-02-03 19:13:00.280468375 -0500
//	+++ path/to/file.go	2017-02-03 19:13:00.280468375 -0500
//
// An error is returned if diff has fewer than three lines.
func replaceTempFilename(diff []byte, filename string) ([]byte, error) {
	parts := bytes.SplitN(diff, []byte("\n"), 3)
	if len(parts) < 3 {
		return nil, fmt.Errorf("got unexpected diff for %s", filename)
	}

	// stamp returns the "\t<timestamp>" tail of a header line, or nil.
	stamp := func(line []byte) []byte {
		if i := bytes.LastIndexByte(line, '\t'); i >= 0 {
			return line[i:]
		}
		return nil
	}
	oldStamp, newStamp := stamp(parts[0]), stamp(parts[1])

	name := filepath.ToSlash(filename)
	parts[0] = append([]byte("--- "+name+".orig"), oldStamp...)
	parts[1] = append([]byte("+++ "+name), newStamp...)
	return bytes.Join(parts, []byte("\n")), nil
}
288 |
// chmodSupported is false when running on Windows and true on every other
// GOOS. NOTE(review): presumably consulted by backupFile before changing the
// backup's mode; its body is not fully visible here, so confirm.
const chmodSupported = runtime.GOOS != "windows"
290 |
291 | // backupFile writes data to a new file named filename with permissions perm,
292 | // with >>>")
52 |
53 | endclass := "EndClass: "
54 | method := "Method: "
55 | endmethod := "EndMethod"
56 |
57 | lastMethSt := -1
58 | var lastMeth string
59 | curComSt := -1
60 | lastComSt := -1
61 | lastComEd := -1
62 |
63 | li := 0
64 | for {
65 | if li >= len(lines) {
66 | break
67 | }
68 | ln := lines[li]
69 | if len(ln) > 0 && ln[0] == '#' {
70 | if curComSt >= 0 {
71 | lastComEd = li
72 | } else {
73 | curComSt = li
74 | lastComSt = li
75 | lastComEd = li
76 | }
77 | } else {
78 | curComSt = -1
79 | }
80 |
81 | switch {
82 | case bytes.Equal(ln, []byte(" :")) || bytes.Equal(ln, []byte(":")):
83 | lines = append(lines[:li], lines[li+1:]...) // delete marker
84 | li--
85 | case bytes.HasPrefix(ln, class):
86 | cl := string(ln[len(class):])
87 | if idx := strings.Index(cl, "("); idx > 0 {
88 | cl = cl[:idx]
89 | } else if idx := strings.Index(cl, ":"); idx > 0 { // should have
90 | cl = cl[:idx]
91 | }
92 | cl = strings.TrimSpace(cl)
93 | classes[cl] = sted{st: li}
94 | // fmt.Printf("cl: %s at %d\n", cl, li)
95 | case bytes.HasPrefix(ln, pymark) && bytes.HasSuffix(ln, pyend):
96 | tag := string(ln[4 : len(ln)-4])
97 | // fmt.Printf("tag: %s at: %d\n", tag, li)
98 | switch {
99 | case strings.HasPrefix(tag, endclass):
100 | cl := tag[len(endclass):]
101 | st := classes[cl]
102 | classes[cl] = sted{st: st.st, ed: li}
103 | lines = append(lines[:li], lines[li+1:]...) // delete marker
104 | // fmt.Printf("cl: %s at %v\n", cl, classes[cl])
105 | li--
106 | case strings.HasPrefix(tag, method):
107 | cl := tag[len(method):]
108 | lines = append(lines[:li], lines[li+1:]...) // delete marker
109 | li--
110 | lastMeth = cl
111 | if lastComEd == li {
112 | lines = append(lines[:lastComSt], lines[lastComEd+1:]...) // delete comments
113 | lastMethSt = lastComSt
114 | li = lastComSt - 1
115 | } else {
116 | lastMethSt = li + 1
117 | }
118 | case tag == endmethod:
119 | se, ok := classes[lastMeth]
120 | if ok {
121 | lines = append(lines[:li], lines[li+1:]...) // delete marker
122 | moveLines(&lines, se.ed, lastMethSt, li+1) // extra blank
123 | classes[lastMeth] = sted{st: se.st, ed: se.ed + ((li + 1) - lastMethSt)}
124 | li -= 2
125 | }
126 | }
127 | }
128 | li++
129 | }
130 | return lines
131 | }
132 |
133 | // pyEditsReplace replaces Go with equivalent Python code
134 | func pyEditsReplace(lines [][]byte) {
135 | fmtPrintf := []byte("fmt.Printf")
136 | fmtSprintf := []byte("fmt.Sprintf(")
137 | prints := []byte("print")
138 | eqappend := []byte("= append(")
139 | elseif := []byte("else if")
140 | elif := []byte("elif")
141 | forblank := []byte("for _, ")
142 | fornoblank := []byte("for ")
143 | itoa := []byte("strconv.Itoa")
144 | float64p := []byte("float64(")
145 | float32p := []byte("float32(")
146 | floatp := []byte("float(")
147 | stringp := []byte("string(")
148 | strp := []byte("str(")
149 | slicestr := []byte("[]str(")
150 | sliceint := []byte("[]int(")
151 | slicefloat64 := []byte("[]float64(")
152 | slicefloat32 := []byte("[]float32(")
153 | goslicestr := []byte("go.Slice_string([")
154 | gosliceint := []byte("go.Slice_int([")
155 | goslicefloat64 := []byte("go.Slice_float64([")
156 | goslicefloat32 := []byte("go.Slice_float32([")
157 | stringsdot := []byte("strings.")
158 | copyp := []byte("copy(")
159 | eqgonil := []byte(" == go.nil")
160 | eqgonil0 := []byte(" == 0")
161 | negonil := []byte(" != go.nil")
162 | negonil0 := []byte(" != 0")
163 |
164 | gopy := (printerMode&pyprint.GoPy != 0)
165 | // gogi := (printerMode&pyprint.GoGi != 0)
166 |
167 | for li, ln := range lines {
168 | ln = bytes.Replace(ln, float64p, floatp, -1)
169 | ln = bytes.Replace(ln, float32p, floatp, -1)
170 | ln = bytes.Replace(ln, stringp, strp, -1)
171 | ln = bytes.Replace(ln, forblank, fornoblank, -1)
172 | ln = bytes.Replace(ln, eqgonil, eqgonil0, -1)
173 | ln = bytes.Replace(ln, negonil, negonil0, -1)
174 |
175 | if bytes.Contains(ln, fmtSprintf) {
176 | if bytes.Contains(ln, []byte("%")) {
177 | ln = bytes.Replace(ln, []byte(`", `), []byte(`" % (`), -1)
178 | }
179 | ln = bytes.Replace(ln, fmtSprintf, []byte{}, -1)
180 | }
181 |
182 | if bytes.Contains(ln, fmtPrintf) {
183 | if bytes.Contains(ln, []byte("%")) {
184 | ln = bytes.Replace(ln, []byte(`", `), []byte(`" % `), -1)
185 | }
186 | ln = bytes.Replace(ln, fmtPrintf, prints, -1)
187 | }
188 |
189 | if bytes.Contains(ln, eqappend) {
190 | idx := bytes.Index(ln, eqappend)
191 | comi := bytes.Index(ln[idx+len(eqappend):], []byte(","))
192 | nln := make([]byte, idx-1)
193 | copy(nln, ln[:idx-1])
194 | nln = append(nln, []byte(".append(")...)
195 | nln = append(nln, ln[idx+len(eqappend)+comi+1:]...)
196 | ln = nln
197 | }
198 |
199 | for {
200 | if bytes.Contains(ln, stringsdot) {
201 | idx := bytes.Index(ln, stringsdot)
202 | pi := idx + len(stringsdot) + bytes.Index(ln[idx+len(stringsdot):], []byte("("))
203 | comi := bytes.Index(ln[pi:], []byte(","))
204 | nln := make([]byte, idx)
205 | copy(nln, ln[:idx])
206 | if comi < 0 {
207 | comi = bytes.Index(ln[pi:], []byte(")"))
208 | nln = append(nln, ln[pi+1:pi+comi]...)
209 | nln = append(nln, '.')
210 | meth := bytes.ToLower(ln[idx+len(stringsdot) : pi+1])
211 | if bytes.Equal(meth, []byte("fields(")) {
212 | meth = []byte("split(")
213 | }
214 | nln = append(nln, meth...)
215 | nln = append(nln, ln[pi+comi:]...)
216 | } else {
217 | nln = append(nln, ln[pi+1:pi+comi]...)
218 | nln = append(nln, '.')
219 | meth := bytes.ToLower(ln[idx+len(stringsdot) : pi+1])
220 | nln = append(nln, meth...)
221 | nln = append(nln, ln[pi+comi+1:]...)
222 | }
223 | ln = nln
224 | } else {
225 | break
226 | }
227 | }
228 |
229 | if bytes.Contains(ln, copyp) {
230 | idx := bytes.Index(ln, copyp)
231 | pi := idx + len(copyp) + bytes.Index(ln[idx+len(stringsdot):], []byte("("))
232 | comi := bytes.Index(ln[pi:], []byte(","))
233 | nln := make([]byte, idx)
234 | copy(nln, ln[:idx])
235 | nln = append(nln, ln[pi+1:pi+comi]...)
236 | nln = append(nln, '.')
237 | nln = append(nln, copyp...)
238 | nln = append(nln, ln[pi+comi+1:]...)
239 | ln = nln
240 | }
241 |
242 | if bytes.Contains(ln, itoa) {
243 | ln = bytes.Replace(ln, itoa, []byte(`str`), -1)
244 | }
245 |
246 | if bytes.Contains(ln, elseif) {
247 | ln = bytes.Replace(ln, elseif, elif, -1)
248 | }
249 |
250 | if gopy && bytes.Contains(ln, slicestr) {
251 | ln = bytes.Replace(ln, slicestr, goslicestr, -1)
252 | ln = bytes.Replace(ln, []byte(")"), []byte("])"), 1)
253 | }
254 |
255 | if gopy && bytes.Contains(ln, sliceint) {
256 | ln = bytes.Replace(ln, sliceint, gosliceint, -1)
257 | ln = bytes.Replace(ln, []byte(")"), []byte("])"), 1)
258 | }
259 |
260 | if gopy && bytes.Contains(ln, slicefloat64) {
261 | ln = bytes.Replace(ln, slicefloat64, goslicefloat64, -1)
262 | ln = bytes.Replace(ln, []byte(")"), []byte("])"), 1)
263 | }
264 |
265 | if gopy && bytes.Contains(ln, slicefloat32) {
266 | ln = bytes.Replace(ln, slicefloat32, goslicefloat32, -1)
267 | ln = bytes.Replace(ln, []byte(")"), []byte("])"), 1)
268 | }
269 |
270 | ln = bytes.Replace(ln, []byte("\t"), []byte(" "), -1)
271 |
272 | lines[li] = ln
273 | }
274 | }
275 |
--------------------------------------------------------------------------------
/pyprint/printer.go:
--------------------------------------------------------------------------------
1 | // Copyright 2020 The Go-Python Authors. All rights reserved.
2 | // Use of this source code is governed by a BSD-style
3 | // license that can be found in the LICENSE file.
4 |
5 | // This is based on gofmt source code:
6 |
7 | // Copyright 2009 The Go Authors. All rights reserved.
8 | // Use of this source code is governed by a BSD-style
9 | // license that can be found in the LICENSE file.
10 |
11 | // Package pyprint implements printing of AST nodes as Python source.
12 | package pyprint
13 |
14 | import (
15 | "fmt"
16 | "go/ast"
17 | "go/token"
18 | "io"
19 | "os"
20 | "strings"
21 | "text/tabwriter"
22 | "unicode"
23 | )
24 |
const (
	maxNewlines = 2       // max. number of newlines between source text
	debug       = false   // enable for debugging
	infinity    = 1 << 30 // commentOffset value meaning "no more comments"
)

// whiteSpace is a pending white-space or indentation action; the printer
// buffers these in its wsbuf field.
type whiteSpace byte

const (
	ignore   whiteSpace = 0
	blank    whiteSpace = ' '
	vtab     whiteSpace = '\v'
	newline  whiteSpace = '\n'
	formfeed whiteSpace = '\f'
	indent   whiteSpace = '>'
	unindent whiteSpace = '<'
)

// pmode is a set of flags describing the current printer mode.
type pmode int

const (
	noExtraBlank     pmode = 1 << iota // disables extra blank after /*-style comment
	noExtraLinebreak                   // disables extra line break after /*-style comment
)

// commentInfo caches facts about the comment group printer.comments[cindex].
type commentInfo struct {
	cindex         int               // current comment index
	comment        *ast.CommentGroup // = printer.comments[cindex]; or nil
	commentOffset  int               // = printer.posFor(printer.comments[cindex].List[0].Pos()).Offset; or infinity
	commentNewline bool              // true if the comment group contains newlines
}
57 |
58 | type printer struct {
59 | // Configuration (does not change after initialization)
60 | Config
61 | fset *token.FileSet
62 |
63 | // Current state
64 | output []byte // raw printer result
65 | indent int // current indentation
66 | level int // level == 0: outside composite literal; level > 0: inside composite literal
67 | mode pmode // current printer mode
68 | endAlignment bool // if set, terminate alignment immediately
69 | impliedSemi bool // if set, a linebreak implies a semicolon
70 | lastTok token.Token // last token printed (token.ILLEGAL if it's whitespace)
71 | prevOpen token.Token // previous non-brace "open" token (, [, or token.ILLEGAL
72 | wsbuf []whiteSpace // delayed white space
73 |
74 | // Positions
75 | // The out position differs from the pos position when the result
76 | // formatting differs from the source formatting (in the amount of
77 | // white space). If there's a difference and SourcePos is set in
78 | // ConfigMode, //line directives are used in the output to restore
79 | // original source positions for a reader.
80 | pos token.Position // current position in AST (source) space
81 | out token.Position // current position in output space
82 | last token.Position // value of pos after calling writeString
83 | linePtr *int // if set, record out.Line for the next token in *linePtr
84 |
85 | // The list of all source comments, in order of appearance.
86 | comments []*ast.CommentGroup // may be nil
87 | useNodeComments bool // if not set, ignore lead and line comments of nodes
88 |
89 | // Information about p.comments[p.cindex]; set up by nextComment.
90 | commentInfo
91 |
92 | // Cache of already computed node sizes.
93 | nodeSizes map[ast.Node]int
94 |
95 | // Cache of most recently computed line position.
96 | cachedPos token.Pos
97 | cachedLine int // line corresponding to cachedPos
98 | }
99 |
100 | func (p *printer) init(cfg *Config, fset *token.FileSet, nodeSizes map[ast.Node]int) {
101 | p.Config = *cfg
102 | p.fset = fset
103 | p.pos = token.Position{Line: 1, Column: 1}
104 | p.out = token.Position{Line: 1, Column: 1}
105 | p.wsbuf = make([]whiteSpace, 0, 16) // whitespace sequences are short
106 | p.nodeSizes = nodeSizes
107 | p.cachedPos = -1
108 | }
109 |
110 | func (p *printer) internalError(msg ...interface{}) {
111 | if debug {
112 | fmt.Print(p.pos.String() + ": ")
113 | fmt.Println(msg...)
114 | panic("go/printer")
115 | }
116 | }
117 |
118 | // commentsHaveNewline reports whether a list of comments belonging to
119 | // an *ast.CommentGroup contains newlines. Because the position information
120 | // may only be partially correct, we also have to read the comment text.
121 | func (p *printer) commentsHaveNewline(list []*ast.Comment) bool {
122 | // len(list) > 0
123 | line := p.lineFor(list[0].Pos())
124 | for i, c := range list {
125 | if i > 0 && p.lineFor(list[i].Pos()) != line {
126 | // not all comments on the same line
127 | return true
128 | }
129 | if t := c.Text; len(t) >= 2 && (t[1] == '/' || strings.Contains(t, "\n")) {
130 | return true
131 | }
132 | }
133 | _ = line
134 | return false
135 | }
136 |
137 | func (p *printer) nextComment() {
138 | for p.cindex < len(p.comments) {
139 | c := p.comments[p.cindex]
140 | p.cindex++
141 | if list := c.List; len(list) > 0 {
142 | p.comment = c
143 | p.commentOffset = p.posFor(list[0].Pos()).Offset
144 | p.commentNewline = p.commentsHaveNewline(list)
145 | return
146 | }
147 | // we should not reach here (correct ASTs don't have empty
148 | // ast.CommentGroup nodes), but be conservative and try again
149 | }
150 | // no more comments
151 | p.commentOffset = infinity
152 | }
153 |
154 | // commentBefore reports whether the current comment group occurs
155 | // before the next position in the source code and printing it does
156 | // not introduce implicit semicolons.
157 | //
158 | func (p *printer) commentBefore(next token.Position) bool {
159 | return p.commentOffset < next.Offset && (!p.impliedSemi || !p.commentNewline)
160 | }
161 |
162 | // commentSizeBefore returns the estimated size of the
163 | // comments on the same line before the next position.
164 | //
165 | func (p *printer) commentSizeBefore(next token.Position) int {
166 | // save/restore current p.commentInfo (p.nextComment() modifies it)
167 | defer func(info commentInfo) {
168 | p.commentInfo = info
169 | }(p.commentInfo)
170 |
171 | size := 0
172 | for p.commentBefore(next) {
173 | for _, c := range p.comment.List {
174 | size += len(c.Text)
175 | }
176 | p.nextComment()
177 | }
178 | return size
179 | }
180 |
181 | // recordLine records the output line number for the next non-whitespace
182 | // token in *linePtr. It is used to compute an accurate line number for a
183 | // formatted construct, independent of pending (not yet emitted) whitespace
184 | // or comments.
185 | //
186 | func (p *printer) recordLine(linePtr *int) {
187 | p.linePtr = linePtr
188 | }
189 |
190 | // linesFrom returns the number of output lines between the current
191 | // output line and the line argument, ignoring any pending (not yet
192 | // emitted) whitespace or comments. It is used to compute an accurate
193 | // size (in number of lines) for a formatted construct.
194 | //
195 | func (p *printer) linesFrom(line int) int {
196 | return p.out.Line - line
197 | }
198 |
199 | func (p *printer) posFor(pos token.Pos) token.Position {
200 | // not used frequently enough to cache entire token.Position
201 | return p.fset.PositionFor(pos, false /* absolute position */)
202 | }
203 |
204 | func (p *printer) lineFor(pos token.Pos) int {
205 | if pos != p.cachedPos {
206 | p.cachedPos = pos
207 | p.cachedLine = p.fset.PositionFor(pos, false /* absolute position */).Line
208 | }
209 | return p.cachedLine
210 | }
211 |
212 | // writeLineDirective writes a //line directive if necessary.
213 | func (p *printer) writeLineDirective(pos token.Position) {
214 | if pos.IsValid() && (p.out.Line != pos.Line || p.out.Filename != pos.Filename) {
215 | p.output = append(p.output, tabwriter.Escape) // protect '\n' in //line from tabwriter interpretation
216 | p.output = append(p.output, fmt.Sprintf("//line %s:%d\n", pos.Filename, pos.Line)...)
217 | p.output = append(p.output, tabwriter.Escape)
218 | // p.out must match the //line directive
219 | p.out.Filename = pos.Filename
220 | p.out.Line = pos.Line
221 | }
222 | }
223 |
224 | // writeIndent writes indentation.
225 | func (p *printer) writeIndent() {
226 | // use "hard" htabs - indentation columns
227 | // must not be discarded by the tabwriter
228 | n := p.Config.Indent + p.indent // include base indentation
229 | for i := 0; i < n; i++ {
230 | p.output = append(p.output, '\t')
231 | }
232 |
233 | // update positions
234 | p.pos.Offset += n
235 | p.pos.Column += n
236 | p.out.Column += n
237 | }
238 |
239 | // writeByte writes ch n times to p.output and updates p.pos.
240 | // Only used to write formatting (white space) characters.
241 | func (p *printer) writeByte(ch byte, n int) {
242 | if p.endAlignment {
243 | // Ignore any alignment control character;
244 | // and at the end of the line, break with
245 | // a formfeed to indicate termination of
246 | // existing columns.
247 | switch ch {
248 | case '\t', '\v':
249 | ch = ' '
250 | case '\n', '\f':
251 | ch = '\f'
252 | p.endAlignment = false
253 | }
254 | }
255 |
256 | if p.out.Column == 1 {
257 | // no need to write line directives before white space
258 | p.writeIndent()
259 | }
260 |
261 | for i := 0; i < n; i++ {
262 | p.output = append(p.output, ch)
263 | }
264 |
265 | // update positions
266 | p.pos.Offset += n
267 | if ch == '\n' || ch == '\f' {
268 | p.pos.Line += n
269 | p.out.Line += n
270 | p.pos.Column = 1
271 | p.out.Column = 1
272 | return
273 | }
274 | p.pos.Column += n
275 | p.out.Column += n
276 | }
277 |
278 | // writeString writes the string s to p.output and updates p.pos, p.out,
279 | // and p.last. If isLit is set, s is escaped w/ tabwriter.Escape characters
280 | // to protect s from being interpreted by the tabwriter.
281 | //
282 | // Note: writeString is only used to write Go tokens, literals, and
283 | // comments, all of which must be written literally. Thus, it is correct
284 | // to always set isLit = true. However, setting it explicitly only when
285 | // needed (i.e., when we don't know that s contains no tabs or line breaks)
286 | // avoids processing extra escape characters and reduces run time of the
287 | // printer benchmark by up to 10%.
288 | //
289 | func (p *printer) writeString(pos token.Position, s string, isLit bool) {
290 | if p.out.Column == 1 {
291 | if p.Config.Mode&SourcePos != 0 {
292 | p.writeLineDirective(pos)
293 | }
294 | p.writeIndent()
295 | }
296 |
297 | if pos.IsValid() {
298 | // update p.pos (if pos is invalid, continue with existing p.pos)
299 | // Note: Must do this after handling line beginnings because
300 | // writeIndent updates p.pos if there's indentation, but p.pos
301 | // is the position of s.
302 | p.pos = pos
303 | }
304 |
305 | if isLit {
306 | // Protect s such that is passes through the tabwriter
307 | // unchanged. Note that valid Go programs cannot contain
308 | // tabwriter.Escape bytes since they do not appear in legal
309 | // UTF-8 sequences.
310 | p.output = append(p.output, tabwriter.Escape)
311 | }
312 |
313 | if debug {
314 | p.output = append(p.output, fmt.Sprintf("/*%s*/", pos)...) // do not update p.pos!
315 | }
316 | p.output = append(p.output, s...)
317 |
318 | // update positions
319 | nlines := 0
320 | var li int // index of last newline; valid if nlines > 0
321 | for i := 0; i < len(s); i++ {
322 | // Raw string literals may contain any character except back quote (`).
323 | if ch := s[i]; ch == '\n' || ch == '\f' {
324 | // account for line break
325 | nlines++
326 | li = i
327 | // A line break inside a literal will break whatever column
328 | // formatting is in place; ignore any further alignment through
329 | // the end of the line.
330 | p.endAlignment = true
331 | }
332 | }
333 | p.pos.Offset += len(s)
334 | if nlines > 0 {
335 | p.pos.Line += nlines
336 | p.out.Line += nlines
337 | c := len(s) - li
338 | p.pos.Column = c
339 | p.out.Column = c
340 | } else {
341 | p.pos.Column += len(s)
342 | p.out.Column += len(s)
343 | }
344 |
345 | if isLit {
346 | p.output = append(p.output, tabwriter.Escape)
347 | }
348 |
349 | p.last = p.pos
350 | }
351 |
// writeCommentPrefix writes the whitespace before a comment.
// If there is any pending whitespace, it consumes as much of
// it as is likely to help position the comment nicely.
//
// Parameters:
//   - pos:  source position of the comment being written.
//   - next: source position of the item after all pending comments.
//   - prev: the previous comment in the same comment group, or nil if
//     this is the first comment of the group.
//   - tok:  the next token to be printed after the comments.
//
// Cases:
//   - Nothing has been output yet: no whitespace is written.
//   - The comment is in a different file than the last printed item:
//     maxNewlines formfeeds are written.
//   - The comment is on the same line as the last item (and prev is not a
//     //-style comment): at least one blank or tab separates them.
//   - Otherwise: the comment goes on a later line, separated by formfeeds.
func (p *printer) writeCommentPrefix(pos, next token.Position, prev *ast.Comment, tok token.Token) {
	if len(p.output) == 0 {
		// the comment is the first item to be printed - don't write any whitespace
		return
	}

	if pos.IsValid() && pos.Filename != p.last.Filename {
		// comment in a different file - separate with newlines
		p.writeByte('\f', maxNewlines)
		return
	}

	// A //-style prev comment ends its line, so a following comment can
	// never be "on the same line" as the last item.
	if pos.Line == p.last.Line && (prev == nil || prev.Text[1] != '/') {
		// comment on the same line as last item:
		// separate with at least one separator
		hasSep := false
		if prev == nil {
			// first comment of a comment group
			// Scan the pending whitespace. Leading blanks are dropped,
			// vtabs count as an existing separator, and indents are kept.
			// The scan stops at the first other entry (index j), and only
			// entries before j are flushed.
			// NOTE(review): if every entry is blank/vtab/indent, j stays 0
			// and nothing is flushed here; those entries (blanks now marked
			// ignore) remain in p.wsbuf.
			j := 0
			for i, ch := range p.wsbuf {
				switch ch {
				case blank:
					// ignore any blanks before a comment
					p.wsbuf[i] = ignore
					continue
				case vtab:
					// respect existing tabs - important
					// for proper formatting of commented structs
					hasSep = true
					continue
				case indent:
					// apply pending indentation
					continue
				}
				j = i
				break // outside the switch, so this exits the range loop
			}
			p.writeWhitespace(j)
		}
		// make sure there is at least one separator
		if !hasSep {
			sep := byte('\t')
			if pos.Line == next.Line {
				// next item is on the same line as the comment
				// (which must be a /*-style comment): separate
				// with a blank instead of a tab
				sep = ' '
			}
			p.writeByte(sep, 1)
		}

	} else {
		// comment on a different line:
		// separate with at least one line break
		droppedLinebreak := false
		// Scan the pending whitespace. Horizontal whitespace is dropped,
		// indents are kept, and the scan stops (at index j) at the first
		// entry that is not skipped. A first line break found is dropped
		// (marked ignore) and the scan stops there. Entries before j are
		// then flushed.
		j := 0
		for i, ch := range p.wsbuf {
			switch ch {
			case blank, vtab:
				// ignore any horizontal whitespace before line breaks
				p.wsbuf[i] = ignore
				continue
			case indent:
				// apply pending indentation
				continue
			case unindent:
				// if this is not the last unindent, apply it
				// as it is (likely) belonging to the last
				// construct (e.g., a multi-line expression list)
				// and is not part of closing a block
				if i+1 < len(p.wsbuf) && p.wsbuf[i+1] == unindent {
					continue
				}
				// if the next token is not a closing }, apply the unindent
				// if it appears that the comment is aligned with the
				// token; otherwise assume the unindent is part of a
				// closing block and stop (this scenario appears with
				// comments before a case label where the comments
				// apply to the next case instead of the current one)
				if tok != token.RBRACE && pos.Column == next.Column {
					continue
				}
			case newline, formfeed:
				p.wsbuf[i] = ignore
				droppedLinebreak = prev == nil // record only if first comment of a group
			}
			j = i
			break // outside the switch, so this exits the range loop
		}
		p.writeWhitespace(j)

		// determine number of linebreaks before the comment
		n := 0
		if pos.IsValid() && p.last.IsValid() {
			n = pos.Line - p.last.Line
			if n < 0 { // should never happen
				n = 0
			}
		}

		// at the package scope level only (p.indent == 0),
		// add an extra newline if we dropped one before:
		// this preserves a blank line before documentation
		// comments at the package scope level (issue 2570)
		if p.indent == 0 && droppedLinebreak {
			n++
		}

		// make sure there is at least one line break
		// if the previous comment was a line comment
		if n == 0 && prev != nil && prev.Text[1] == '/' {
			n = 1
		}

		if n > 0 {
			// use formfeeds to break columns before a comment;
			// this is analogous to using formfeeds to separate
			// individual lines of /*-style comments
			// (nlimit caps the count at maxNewlines)
			p.writeByte('\f', nlimit(n))
		}
	}
}
481 |
// isBlank reports whether s consists only of white space. Every byte must be
// at or below ' '; in the printer's context only tabs and blanks occur, and
// an empty string counts as blank.
func isBlank(s string) bool {
	for _, b := range []byte(s) {
		if b > ' ' {
			return false
		}
	}
	return true
}
493 |
// commonPrefix returns the longest common prefix of a and b that consists
// solely of white-space bytes (<= ' ') and '*' characters. It is used to
// find the shared indentation (and optional "line of stars") of /*-style
// comment lines. The result is a slice of a.
func commonPrefix(a, b string) string {
	limit := len(a)
	if len(b) < limit {
		limit = len(b)
	}
	end := 0
	for end < limit {
		c := a[end]
		if c != b[end] || (c > ' ' && c != '*') {
			break
		}
		end++
	}
	return a[:end]
}
502 |
// trimRight returns s without any trailing white space, using Unicode's
// definition of space (unicode.IsSpace).
func trimRight(s string) string {
	trimmed := strings.TrimRightFunc(s, unicode.IsSpace)
	return trimmed
}
507 |
// stripCommonPrefix removes a common white-space prefix (possibly ending in
// a vertical "line of stars") from the lines of a /*-style comment. lines
// is modified in place. lines[0] must start with the opening "/*", and the
// last line must contain the closing "*/".
//
// The prefix is chosen heuristically, so that re-printing each line with
// the printer's current indentation is likely to lay the comment out
// nicely. The first line is never modified. Inner lines that are blank
// become empty strings. A last line holding only "*/" is re-aligned with
// the opening "/*" (plus a blank if a line of stars is present).
// Comments with at most one line are left untouched.
func stripCommonPrefix(lines []string) {
	if len(lines) <= 1 {
		return // at most one line - nothing to do
	}
	// len(lines) > 1

	// The heuristic in this function tries to handle a few
	// common patterns of /*-style comments: Comments where
	// the opening /* and closing */ are aligned and the
	// rest of the comment text is aligned and indented with
	// blanks or tabs, cases with a vertical "line of stars"
	// on the left, and cases where the closing */ is on the
	// same line as the last comment text.

	// Compute maximum common white prefix of all but the first,
	// last, and blank lines, and replace blank lines with empty
	// lines (the first line starts with /* and has no prefix).
	// In cases where only the first and last lines are not blank,
	// such as two-line comments, or comments where all inner lines
	// are blank, consider the last line for the prefix computation
	// since otherwise the prefix would be empty.
	//
	// Note that the first and last line are never empty (they
	// contain the opening /* and closing */ respectively) and
	// thus they can be ignored by the blank line check.
	prefix := ""
	prefixSet := false
	if len(lines) > 2 {
		for i, line := range lines[1 : len(lines)-1] {
			if isBlank(line) {
				lines[1+i] = "" // range starts with lines[1]
			} else {
				if !prefixSet {
					prefix = line
					prefixSet = true
				}
				prefix = commonPrefix(prefix, line)
			}

		}
	}
	// If we don't have a prefix yet, consider the last line.
	// (commonPrefix(line, line) yields the line's leading run of
	// white space and '*' characters.)
	if !prefixSet {
		line := lines[len(lines)-1]
		prefix = commonPrefix(line, line)
	}

	/*
	 * Check for vertical "line of stars" and correct prefix accordingly.
	 */
	lineOfStars := false
	if i := strings.Index(prefix, "*"); i >= 0 {
		// Line of stars present.
		if i > 0 && prefix[i-1] == ' ' {
			i-- // remove trailing blank from prefix so stars remain aligned
		}
		prefix = prefix[0:i]
		lineOfStars = true
	} else {
		// No line of stars present.
		// Determine the white space on the first line after the /*
		// and before the beginning of the comment text, assume two
		// blanks instead of the /* unless the first character after
		// the /* is a tab. If the first comment line is empty but
		// for the opening /*, assume up to 3 blanks or a tab. This
		// whitespace may be found as suffix in the common prefix.
		first := lines[0]
		if isBlank(first[2:]) {
			// no comment text on the first line:
			// reduce prefix by up to 3 blanks or a tab
			// if present - this keeps comment text indented
			// relative to the /* and */'s if it was indented
			// in the first place
			i := len(prefix)
			for n := 0; n < 3 && i > 0 && prefix[i-1] == ' '; n++ {
				i--
			}
			// only drop a trailing tab if no blanks were dropped above
			if i == len(prefix) && i > 0 && prefix[i-1] == '\t' {
				i--
			}
			prefix = prefix[0:i]
		} else {
			// comment text on the first line
			// Build suffix: the white space that follows "/*" on the
			// first line, with the "/*" itself counted as two blanks
			// unless a tab directly follows it.
			suffix := make([]byte, len(first))
			n := 2 // start after opening /*
			for n < len(first) && first[n] <= ' ' {
				suffix[n] = first[n]
				n++
			}
			if n > 2 && suffix[2] == '\t' {
				// assume the '\t' compensates for the /*
				suffix = suffix[2:n]
			} else {
				// otherwise assume two blanks
				suffix[0], suffix[1] = ' ', ' '
				suffix = suffix[0:n]
			}
			// Shorten the computed common prefix by the length of
			// suffix, if it is found as suffix of the prefix.
			prefix = strings.TrimSuffix(prefix, string(suffix))
		}
	}

	// Handle last line: If it only contains a closing */, align it
	// with the opening /*, otherwise align the text with the other
	// lines.
	last := lines[len(lines)-1]
	closing := "*/"
	i := strings.Index(last, closing) // i >= 0 (closing is always present)
	if isBlank(last[0:i]) {
		// last line only contains closing */
		if lineOfStars {
			closing = " */" // add blank to align final star
		}
		lines[len(lines)-1] = prefix + closing
	} else {
		// last line contains more comment text - assume
		// it is aligned like the other lines and include
		// in prefix computation
		prefix = commonPrefix(prefix, last)
	}

	// Remove the common prefix from all but the first and empty lines.
	// (Every such line starts with prefix at this point, so the slice
	// below is in range.)
	for i, line := range lines {
		if i > 0 && line != "" {
			lines[i] = line[len(prefix):]
		}
	}
}
643 |
644 | func (p *printer) writeComment(comment *ast.Comment) {
645 | text := comment.Text
646 | pos := p.posFor(comment.Pos())
647 |
648 | const linePrefix = "//line "
649 | if strings.HasPrefix(text, linePrefix) && (!pos.IsValid() || pos.Column == 1) {
650 | // Possibly a //-style line directive.
651 | // Suspend indentation temporarily to keep line directive valid.
652 | defer func(indent int) { p.indent = indent }(p.indent)
653 | p.indent = 0
654 | }
655 |
656 | // shortcut common case of //-style comments
657 | if text[1] == '/' {
658 | cstr := "#" + trimRight(text[2:])
659 | p.writeString(pos, cstr, true)
660 | return
661 | }
662 |
663 | // for /*-style comments, print line by line and let the
664 | // write function take care of the proper indentation
665 | lines := strings.Split(text, "\n")
666 |
667 | // The comment started in the first column but is going
668 | // to be indented. For an idempotent result, add indentation
669 | // to all lines such that they look like they were indented
670 | // before - this will make sure the common prefix computation
671 | // is the same independent of how many times formatting is
672 | // applied (was issue 1835).
673 | if pos.IsValid() && pos.Column == 1 && p.indent > 0 {
674 | for i, line := range lines[1:] {
675 | lines[1+i] = " " + line
676 | }
677 | }
678 |
679 | stripCommonPrefix(lines)
680 |
681 | // write comment lines, separated by formfeed,
682 | // without a line break after the last line
683 | for i, line := range lines {
684 | if i > 0 {
685 | p.writeByte('\f', 1)
686 | pos = p.pos
687 | }
688 | if len(line) > 0 {
689 | p.writeString(pos, trimRight(line), true)
690 | }
691 | }
692 | }
693 |
694 | // writeCommentSuffix writes a line break after a comment if indicated
695 | // and processes any leftover indentation information. If a line break
696 | // is needed, the kind of break (newline vs formfeed) depends on the
697 | // pending whitespace. The writeCommentSuffix result indicates if a
698 | // newline was written or if a formfeed was dropped from the whitespace
699 | // buffer.
700 | //
701 | func (p *printer) writeCommentSuffix(needsLinebreak bool) (wroteNewline, droppedFF bool) {
702 | for i, ch := range p.wsbuf {
703 | switch ch {
704 | case blank, vtab:
705 | // ignore trailing whitespace
706 | p.wsbuf[i] = ignore
707 | case indent, unindent:
708 | // don't lose indentation information
709 | case newline, formfeed:
710 | // if we need a line break, keep exactly one
711 | // but remember if we dropped any formfeeds
712 | if needsLinebreak {
713 | needsLinebreak = false
714 | wroteNewline = true
715 | } else {
716 | if ch == formfeed {
717 | droppedFF = true
718 | }
719 | p.wsbuf[i] = ignore
720 | }
721 | }
722 | }
723 | p.writeWhitespace(len(p.wsbuf))
724 |
725 | // make sure we have a line break
726 | if needsLinebreak {
727 | p.writeByte('\n', 1)
728 | wroteNewline = true
729 | }
730 |
731 | return
732 | }
733 |
734 | // containsLinebreak reports whether the whitespace buffer contains any line breaks.
735 | func (p *printer) containsLinebreak() bool {
736 | for _, ch := range p.wsbuf {
737 | if ch == newline || ch == formfeed {
738 | return true
739 | }
740 | }
741 | return false
742 | }
743 |
// intersperseComments consumes all comments that appear before the next token
// tok (at source position next) and prints them. They are mixed with the
// buffered whitespace, i.e. the whitespace that still has to be written
// before tok, using a heuristic (see writeCommentPrefix).
//
// After the last comment it may write a separating blank or decide that a
// line break is needed. writeCommentSuffix then finishes the run. The
// results report whether a newline was written and whether a formfeed was
// dropped from the whitespace buffer.
//
// Callers must only invoke this when p.commentBefore(next) holds. Otherwise
// internalError is called (which panics only when debug is set), and
// false, false is returned.
func (p *printer) intersperseComments(next token.Position, tok token.Token) (wroteNewline, droppedFF bool) {
	var last *ast.Comment
	for p.commentBefore(next) {
		for _, c := range p.comment.List {
			p.writeCommentPrefix(p.posFor(c.Pos()), next, last, tok)
			p.writeComment(c)
			last = c
		}
		p.nextComment()
	}

	if last != nil {
		// If the last comment is a /*-style comment and the next item
		// follows on the same line, add an extra separator unless that
		// is explicitly disabled (noExtraBlank). No separator is added
		// before a comma. Before a ')' or ']', one is added only when
		// the previous token was the matching '(' or '[' (p.prevOpen,
		// recorded by print).
		// Use a blank as separator unless we have pending linebreaks, they
		// are not disabled, and we are outside a composite literal
		// (p.level == 0), in which case we want a linebreak (issue 15137).
		// NOTE(review): an earlier wording of this comment said "not a
		// closing token immediately following its corresponding opening
		// token", which reads as the opposite of the condition below.
		// The text above follows the code; confirm which is intended.
		// TODO(gri) This has become overly complicated. We should be able
		// to track whether we're inside an expression or statement and
		// use that information to decide more directly.
		needsLinebreak := false
		if p.mode&noExtraBlank == 0 &&
			last.Text[1] == '*' && p.lineFor(last.Pos()) == next.Line &&
			tok != token.COMMA &&
			(tok != token.RPAREN || p.prevOpen == token.LPAREN) &&
			(tok != token.RBRACK || p.prevOpen == token.LBRACK) {
			if p.containsLinebreak() && p.mode&noExtraLinebreak == 0 && p.level == 0 {
				needsLinebreak = true
			} else {
				p.writeByte(' ', 1)
			}
		}
		// Ensure that there is a line break after a //-style comment,
		// before EOF, and before a closing '}' unless explicitly disabled.
		if last.Text[1] == '/' ||
			tok == token.EOF ||
			tok == token.RBRACE && p.mode&noExtraLinebreak == 0 {
			needsLinebreak = true
		}
		return p.writeCommentSuffix(needsLinebreak)
	}

	// no comment was written - we should never reach here since
	// intersperseComments should not be called in that case
	p.internalError("intersperseComments called without pending comments")
	return
}
799 |
800 | // whiteWhitespace writes the first n whitespace entries.
801 | func (p *printer) writeWhitespace(n int) {
802 | // write entries
803 | for i := 0; i < n; i++ {
804 | switch ch := p.wsbuf[i]; ch {
805 | case ignore:
806 | // ignore!
807 | case indent:
808 | p.indent++
809 | case unindent:
810 | p.indent--
811 | if p.indent < 0 {
812 | p.internalError("negative indentation:", p.indent)
813 | p.indent = 0
814 | }
815 | case newline, formfeed:
816 | // A line break immediately followed by a "correcting"
817 | // unindent is swapped with the unindent - this permits
818 | // proper label positioning. If a comment is between
819 | // the line break and the label, the unindent is not
820 | // part of the comment whitespace prefix and the comment
821 | // will be positioned correctly indented.
822 | if i+1 < n && p.wsbuf[i+1] == unindent {
823 | // Use a formfeed to terminate the current section.
824 | // Otherwise, a long label name on the next line leading
825 | // to a wide column may increase the indentation column
826 | // of lines before the label; effectively leading to wrong
827 | // indentation.
828 | p.wsbuf[i], p.wsbuf[i+1] = unindent, formfeed
829 | i-- // do it again
830 | continue
831 | }
832 | fallthrough
833 | default:
834 | p.writeByte(byte(ch), 1)
835 | }
836 | }
837 |
838 | // shift remaining entries down
839 | l := copy(p.wsbuf, p.wsbuf[n:])
840 | p.wsbuf = p.wsbuf[:l]
841 | }
842 |
843 | // ----------------------------------------------------------------------------
844 | // Printing interface
845 |
846 | // nlines limits n to maxNewlines.
847 | func nlimit(n int) int {
848 | if n > maxNewlines {
849 | n = maxNewlines
850 | }
851 | return n
852 | }
853 |
// mayCombine reports whether writing the byte next directly after the token
// prev would fuse the two into a different token. Examples: "1" + "." reads
// as "1.", "+" + "+" reads as "++", "/" + "*" starts a comment. In such
// cases the caller must separate them with a blank.
func mayCombine(prev token.Token, next byte) (b bool) {
	switch prev {
	case token.INT:
		return next == '.' // 1.
	case token.ADD:
		return next == '+' // ++
	case token.SUB:
		return next == '-' // --
	case token.QUO:
		return next == '*' // /*
	case token.LSS:
		return next == '-' || next == '<' // <- or <<
	case token.AND:
		return next == '&' || next == '^' // && or &^
	}
	return false
}
871 |
// print prints a list of "items" (roughly corresponding to syntactic
// tokens, but also including whitespace and formatting information).
// It is the only print function that should be called directly from
// any of the AST printing functions in nodes.go.
//
// Whitespace is accumulated until a non-whitespace token appears. Any
// comments that need to appear before that token are printed first,
// taking into account the amount and structure of any pending white-
// space for best comment placement. Then, any leftover whitespace is
// printed, followed by the actual token.
//
// Accepted item types: pmode (toggles mode bits), whiteSpace (buffered),
// *ast.Ident, *ast.BasicLit, token.Token, token.Pos (sets the position of
// the next item), and string (printed verbatim for incorrect ASTs). Any
// other type is reported on os.Stderr and causes a panic.
//
// Python output: the identifiers true and false are written as True and
// False, and nil is written as go.nil when the GoPy mode bit is set (see
// the switch near the end of the loop).
func (p *printer) print(args ...interface{}) {
	for _, arg := range args {
		// information about the current arg
		var data string
		var isLit bool
		var impliedSemi bool // value for p.impliedSemi after this arg

		// record previous opening token, if any
		// (whitespace items set p.lastTok to ILLEGAL and leave prevOpen as is,
		// so prevOpen reflects the last non-whitespace token)
		switch p.lastTok {
		case token.ILLEGAL:
			// ignore (white space)
		case token.LPAREN, token.LBRACK:
			p.prevOpen = p.lastTok
		default:
			// other tokens followed any opening token
			p.prevOpen = token.ILLEGAL
		}

		switch x := arg.(type) {
		case pmode:
			// toggle printer mode
			p.mode ^= x
			continue

		case whiteSpace:
			if x == ignore {
				// don't add ignore's to the buffer; they
				// may screw up "correcting" unindents (see
				// LabeledStmt)
				continue
			}
			i := len(p.wsbuf)
			if i == cap(p.wsbuf) {
				// Whitespace sequences are very short so this should
				// never happen. Handle gracefully (but possibly with
				// bad comment placement) if it does happen.
				p.writeWhitespace(i)
				i = 0
			}
			p.wsbuf = p.wsbuf[0 : i+1]
			p.wsbuf[i] = x
			if x == newline || x == formfeed {
				// newlines affect the current state (p.impliedSemi)
				// and not the state after printing arg (impliedSemi)
				// because comments can be interspersed before the arg
				// in this case
				p.impliedSemi = false
			}
			p.lastTok = token.ILLEGAL
			continue

		case *ast.Ident:
			data = x.Name
			impliedSemi = true
			p.lastTok = token.IDENT

		case *ast.BasicLit:
			data = x.Value
			isLit = true
			impliedSemi = true
			p.lastTok = x.Kind

		case token.Token:
			s := x.String()
			if mayCombine(p.lastTok, s[0]) {
				// the previous and the current token must be
				// separated by a blank otherwise they combine
				// into a different incorrect token sequence
				// (except for token.INT followed by a '.' this
				// should never happen because it is taken care
				// of via binary expression formatting)
				if len(p.wsbuf) != 0 {
					p.internalError("whitespace buffer not empty")
				}
				p.wsbuf = p.wsbuf[0:1]
				p.wsbuf[0] = ' '
			}
			data = s
			// some keywords followed by a newline imply a semicolon
			switch x {
			case token.BREAK, token.CONTINUE, token.FALLTHROUGH, token.RETURN,
				token.INC, token.DEC, token.RPAREN, token.RBRACK, token.RBRACE:
				impliedSemi = true
			}
			p.lastTok = x

		case token.Pos:
			if x.IsValid() {
				p.pos = p.posFor(x) // accurate position of next item
			}
			continue

		case string:
			// incorrect AST - print error message
			data = x
			isLit = true
			impliedSemi = true
			p.lastTok = token.STRING

		default:
			fmt.Fprintf(os.Stderr, "print: unsupported argument %v (%T)\n", arg, arg)
			panic("go/printer type")
		}
		// data != ""

		// Flush pending comments and whitespace that precede this item.
		next := p.pos // estimated/accurate position of next item
		wroteNewline, droppedFF := p.flush(next, p.lastTok)

		// intersperse extra newlines if present in the source and
		// if they don't cause extra semicolons (don't do this in
		// flush as it will cause extra newlines at the end of a file)
		if !p.impliedSemi {
			n := nlimit(next.Line - p.pos.Line)
			// don't exceed maxNewlines if we already wrote one
			if wroteNewline && n == maxNewlines {
				n = maxNewlines - 1
			}
			if n > 0 {
				ch := byte('\n')
				if droppedFF {
					ch = '\f' // use formfeed since we dropped one before
				}
				p.writeByte(ch, n)
				impliedSemi = false
			}
		}

		// the next token starts now - record its line number if requested
		if p.linePtr != nil {
			*p.linePtr = p.out.Line
			p.linePtr = nil
		}

		// Python literal mapping. true/false are always rewritten to
		// True/False; nil becomes go.nil only in GoPy mode. The match is on
		// the item text alone, so it applies to any item whose text is
		// exactly one of these words (identifier or verbatim string item).
		switch data {
		case "true":
			data = "True"
		case "false":
			data = "False"
		case "nil":
			if p.Config.Mode&GoPy != 0 {
				data = "go.nil"
			}
		}

		p.writeString(next, data, isLit)
		p.impliedSemi = impliedSemi
	}
}
1031 |
1032 | // flush prints any pending comments and whitespace occurring textually
1033 | // before the position of the next token tok. The flush result indicates
1034 | // if a newline was written or if a formfeed was dropped from the whitespace
1035 | // buffer.
1036 | //
1037 | func (p *printer) flush(next token.Position, tok token.Token) (wroteNewline, droppedFF bool) {
1038 | if p.commentBefore(next) {
1039 | // if there are comments before the next item, intersperse them
1040 | wroteNewline, droppedFF = p.intersperseComments(next, tok)
1041 | } else {
1042 | // otherwise, write any leftover whitespace
1043 | p.writeWhitespace(len(p.wsbuf))
1044 | }
1045 | return
1046 | }
1047 |
// getDoc returns the documentation comment group (the Doc field) of n. It
// supports *ast.Field, *ast.ImportSpec, *ast.ValueSpec, *ast.TypeSpec,
// *ast.GenDecl and *ast.File, and returns nil for any other node type.
//
// *ast.FuncDecl is intentionally not handled here; its documentation is
// handled separately.
func getDoc(n ast.Node) *ast.CommentGroup {
	switch n := n.(type) {
	case *ast.Field:
		return n.Doc
	case *ast.ImportSpec:
		return n.Doc
	case *ast.ValueSpec:
		return n.Doc
	case *ast.TypeSpec:
		return n.Doc
	case *ast.GenDecl:
		return n.Doc
	case *ast.File:
		return n.Doc
	}
	return nil
}
1068 |
// getLastComment returns the trailing comment group associated with n:
//   - the Comment field of *ast.Field, *ast.ImportSpec, *ast.ValueSpec
//     and *ast.TypeSpec;
//   - for *ast.GenDecl, the trailing comment of its last spec;
//   - for *ast.File, the last comment group in the file.
//
// It returns nil when there is no such group or n is of another type.
func getLastComment(n ast.Node) *ast.CommentGroup {
	switch node := n.(type) {
	case *ast.Field:
		return node.Comment
	case *ast.ImportSpec:
		return node.Comment
	case *ast.ValueSpec:
		return node.Comment
	case *ast.TypeSpec:
		return node.Comment
	case *ast.GenDecl:
		if k := len(node.Specs); k > 0 {
			return getLastComment(node.Specs[k-1])
		}
	case *ast.File:
		if k := len(node.Comments); k > 0 {
			return node.Comments[k-1]
		}
	}
	return nil
}
1090 |
// printNode prints node using p; the formatted text is accumulated in the
// printer (fprint copies p.output to the real destination afterwards).
//
// node may be a *CommentedNode, an *ast.File, a []ast.Stmt or []ast.Decl, or
// a value implementing ast.Expr, ast.Stmt, ast.Decl or ast.Spec. Any other
// type yields an "unsupported node type" error.
//
// Comment selection happens before printing:
//   - For a *CommentedNode with a non-nil comment list, the wrapped node must
//     implement ast.Node. Only the comment groups overlapping its extent are
//     kept; the extent is widened to include its doc comment (getDoc) and
//     trailing comment (getLastComment).
//   - For a plain *ast.File, the file's own Comments list is used.
//
// If no comment list was selected, p.useNodeComments is set.
func (p *printer) printNode(node interface{}) error {
	// unpack *CommentedNode, if any
	var comments []*ast.CommentGroup
	if cnode, ok := node.(*CommentedNode); ok {
		node = cnode.Node
		comments = cnode.Comments
	}

	if comments != nil {
		// commented node - restrict comment list to relevant range
		n, ok := node.(ast.Node)
		if !ok {
			// e.g. a []ast.Stmt wrapped in a CommentedNode has no extent
			goto unsupported
		}
		beg := n.Pos()
		end := n.End()
		// if the node has associated documentation,
		// include that commentgroup in the range
		// (the comment list is sorted in the order
		// of the comment appearance in the source code)
		if doc := getDoc(n); doc != nil {
			beg = doc.Pos()
		}
		// a trailing comment may end after the node itself
		if com := getLastComment(n); com != nil {
			if e := com.End(); e > end {
				end = e
			}
		}
		// token.Pos values are global offsets, we can
		// compare them directly
		// skip groups that end before the (widened) start
		i := 0
		for i < len(comments) && comments[i].End() < beg {
			i++
		}
		// keep groups that start before the (widened) end
		j := i
		for j < len(comments) && comments[j].Pos() < end {
			j++
		}
		// if nothing overlaps, p.comments is left unchanged
		if i < j {
			p.comments = comments[i:j]
		}
	} else if n, ok := node.(*ast.File); ok {
		// use ast.File comments, if any
		// note: this seems to be typical source for a parsed .go file
		p.comments = n.Comments
	}

	// if there are no comments, use node comments
	p.useNodeComments = p.comments == nil

	// get comments ready for use
	p.nextComment()

	// format node
	switch n := node.(type) {
	case ast.Expr:
		p.expr(n)
	case ast.Stmt:
		// A labeled statement will un-indent to position the label.
		// Set p.indent to 1 so we don't get indent "underflow".
		if _, ok := n.(*ast.LabeledStmt); ok {
			p.indent = 1
		}
		p.stmt(n, false)
	case ast.Decl:
		p.decl(n)
	case ast.Spec:
		p.spec(n, 1, false)
	case []ast.Stmt:
		// A labeled statement will un-indent to position the label.
		// Set p.indent to 1 so we don't get indent "underflow".
		for _, s := range n {
			if _, ok := s.(*ast.LabeledStmt); ok {
				p.indent = 1
			}
		}
		p.stmtList(n, 0, false)
	case []ast.Decl:
		p.declList(n)
	case *ast.File:
		p.file(n)
	default:
		goto unsupported
	}

	return nil

unsupported:
	return fmt.Errorf("go/printer: unsupported node type %T", node)
}
1181 |
1182 | // ----------------------------------------------------------------------------
1183 | // Trimmer
1184 |
// A trimmer is an io.Writer filter for stripping tabwriter.Escape
// characters, trailing blanks and tabs, and for converting formfeed
// and vtab characters into newlines and htabs (in case no tabwriter
// is used). Text bracketed by tabwriter.Escape characters is passed
// through unchanged.
//
type trimmer struct {
	output io.Writer // destination of the filtered bytes
	state  int       // current scanner state: inSpace, inEscape or inText
	space  []byte    // blanks/tabs buffered while in inSpace, not yet written
}
1196 |
// trimmer is implemented as a state machine.
// It can be in one of the following states:
const (
	inSpace  = iota // inside space: blanks/tabs are buffered in trimmer.space
	inEscape        // inside text bracketed by tabwriter.Escapes (copied verbatim)
	inText          // inside text (copied once a delimiter is reached)
)
1204 |
1205 | func (p *trimmer) resetSpace() {
1206 | p.state = inSpace
1207 | p.space = p.space[0:0]
1208 | }
1209 |
// Design note: It is tempting to eliminate extra blanks occurring in
// whitespace in trimmer.Write, as it could simplify some
// of the blanks logic in the node printing functions.
// However, this would mess up any formatting done by
// the tabwriter.

// aNewline is the newline written by trimmer.Write in place of each
// '\n' or '\f' byte it encounters outside escaped text.
var aNewline = []byte("\n")
1217 |
// Write implements io.Writer. It copies data to p.output and applies these rules:
//   - '\v' is treated as '\t', and '\n' and '\f' are each written as a
//     single '\n'.
//   - Blanks and tabs that end up immediately before a newline are dropped.
//   - tabwriter.Escape bytes are removed; the text between a pair of them is
//     passed through unchanged.
//
// Whitespace at the end of data is kept in p.space. It is written only if a
// later call continues with text or an Escape, and discarded if a newline
// comes first. On success Write returns len(data). If a write fails inside
// the scan loop, n is the index of the byte being processed.
//
// NOTE(review): the final switch resets the state to inSpace. An escaped
// section split across two Write calls is therefore not recognized as
// escaped in the second call; this appears to assume callers never split
// one — confirm.
func (p *trimmer) Write(data []byte) (n int, err error) {
	// invariants:
	// p.state == inSpace:
	//	p.space is unwritten
	// p.state == inEscape, inText:
	//	data[m:n] is unwritten
	m := 0
	var b byte
	for n, b = range data {
		if b == '\v' {
			b = '\t' // convert to htab
		}
		switch p.state {
		case inSpace:
			switch b {
			case '\t', ' ':
				p.space = append(p.space, b)
			case '\n', '\f':
				p.resetSpace() // discard trailing space
				_, err = p.output.Write(aNewline)
			case tabwriter.Escape:
				// buffered space precedes escaped text, so it is kept
				_, err = p.output.Write(p.space)
				p.state = inEscape
				m = n + 1 // +1: skip tabwriter.Escape
			default:
				_, err = p.output.Write(p.space)
				p.state = inText
				m = n
			}
		case inEscape:
			if b == tabwriter.Escape {
				// write the escaped text, excluding the closing Escape
				_, err = p.output.Write(data[m:n])
				p.resetSpace()
			}
		case inText:
			switch b {
			case '\t', ' ':
				_, err = p.output.Write(data[m:n])
				p.resetSpace()
				p.space = append(p.space, b)
			case '\n', '\f':
				_, err = p.output.Write(data[m:n])
				p.resetSpace()
				if err == nil {
					_, err = p.output.Write(aNewline)
				}
			case tabwriter.Escape:
				_, err = p.output.Write(data[m:n])
				p.state = inEscape
				m = n + 1 // +1: skip tabwriter.Escape
			}
		default:
			panic("unreachable")
		}
		if err != nil {
			return
		}
	}
	n = len(data)

	// flush any unwritten text; buffered space (inSpace) stays pending
	switch p.state {
	case inEscape, inText:
		_, err = p.output.Write(data[m:n])
		p.resetSpace()
	}

	return
}
1286 |
1287 | // ----------------------------------------------------------------------------
1288 | // Public interface
1289 |
// A Mode value is a set of flags (or 0). They control printing.
// Each flag below is a distinct bit, so flags may be combined with |.
type Mode uint

const (
	RawFormat Mode = 1 << iota // do not use a tabwriter; if set, UseSpaces is ignored
	TabIndent                  // use tabs for indentation independent of UseSpaces
	UseSpaces                  // use spaces instead of tabs for alignment
	SourcePos                  // emit //line directives to preserve original source positions
	GoPy                       // support GoPy-specific Python code generation
	GoGi                       // support GoGi-specific Python code generation
)
1301 |
// The mode below is not included in printer's public API because
// editing code text is deemed out of scope. Because this mode is
// unexported, it's also possible to modify or remove it based on
// the evolving needs of go/format and cmd/gofmt without breaking
// users. See discussion in CL 240683.
//
// NOTE(review): this comment is inherited from go/printer. In this fork,
// confirm which in-repo callers (e.g. the local format code) set this bit.
const (
	// normalizeNumbers means to canonicalize number
	// literal prefixes and exponents while printing.
	//
	// This value is known in and used by go/format and cmd/gofmt.
	// It is currently more convenient and performant for those
	// packages to apply number normalization during printing,
	// rather than by modifying the AST in advance.
	//
	// It uses a high bit (1 << 30) so it cannot collide with the
	// exported Mode flags.
	normalizeNumbers Mode = 1 << 30
)
1317 |
// A Config node controls the output of Fprint.
//
// The "default" values noted below are those used by the package-level
// Fprint function; a zero Config has Tabwidth 0. Unless RawFormat is set,
// Tabwidth is passed to the tabwriter as its tab width and also as its
// minimum cell width, except that TabIndent sets the minimum width to 0.
type Config struct {
	Mode     Mode // default: 0
	Tabwidth int  // default: 8
	Indent   int  // default: 0 (all code is indented at least by this much)
}
1324 |
1325 | // fprint implements Fprint and takes a nodesSizes map for setting up the printer state.
1326 | func (cfg *Config) fprint(output io.Writer, fset *token.FileSet, node interface{}, nodeSizes map[ast.Node]int) (err error) {
1327 | // print node
1328 | var p printer
1329 | p.init(cfg, fset, nodeSizes)
1330 | if err = p.printNode(node); err != nil {
1331 | return
1332 | }
1333 | // print outstanding comments
1334 | p.impliedSemi = false // EOF acts like a newline
1335 | p.flush(token.Position{Offset: infinity, Line: infinity}, token.EOF)
1336 |
1337 | // redirect output through a trimmer to eliminate trailing whitespace
1338 | // (Input to a tabwriter must be untrimmed since trailing tabs provide
1339 | // formatting information. The tabwriter could provide trimming
1340 | // functionality but no tabwriter is used when RawFormat is set.)
1341 | output = &trimmer{output: output}
1342 |
1343 | // redirect output through a tabwriter if necessary
1344 | if cfg.Mode&RawFormat == 0 {
1345 | minwidth := cfg.Tabwidth
1346 |
1347 | padchar := byte('\t')
1348 | if cfg.Mode&UseSpaces != 0 {
1349 | padchar = ' '
1350 | }
1351 |
1352 | twmode := tabwriter.DiscardEmptyColumns
1353 | if cfg.Mode&TabIndent != 0 {
1354 | minwidth = 0
1355 | twmode |= tabwriter.TabIndent
1356 | }
1357 |
1358 | output = tabwriter.NewWriter(output, minwidth, cfg.Tabwidth, 1, padchar, twmode)
1359 | }
1360 |
1361 | // write printer result via tabwriter/trimmer to output
1362 | if _, err = output.Write(p.output); err != nil {
1363 | return
1364 | }
1365 |
1366 | // flush tabwriter, if any
1367 | if tw, _ := output.(*tabwriter.Writer); tw != nil {
1368 | err = tw.Flush()
1369 | }
1370 |
1371 | return
1372 | }
1373 |
// A CommentedNode bundles an AST node and corresponding comments.
// It may be provided as argument to any of the Fprint functions.
//
// When Comments is non-nil, Node must implement ast.Node. Only the comment
// groups overlapping Node's extent (including its doc comment and trailing
// comment) are printed. Non-node values such as []ast.Stmt are rejected in
// that case.
//
type CommentedNode struct {
	Node     interface{} // *ast.File, or ast.Expr, ast.Decl, ast.Spec, or ast.Stmt
	Comments []*ast.CommentGroup
}
1381 |
1382 | // Fprint "pretty-prints" an AST node to output for a given configuration cfg.
1383 | // Position information is interpreted relative to the file set fset.
1384 | // The node type must be *ast.File, *CommentedNode, []ast.Decl, []ast.Stmt,
1385 | // or assignment-compatible to ast.Expr, ast.Decl, ast.Spec, or ast.Stmt.
1386 | //
1387 | func (cfg *Config) Fprint(output io.Writer, fset *token.FileSet, node interface{}) error {
1388 | return cfg.fprint(output, fset, node, make(map[ast.Node]int))
1389 | }
1390 |
1391 | // Fprint "pretty-prints" an AST node to output.
1392 | // It calls Config.Fprint with default settings.
1393 | // Note that gofmt uses tabs for indentation but spaces for alignment;
1394 | // use format.Node (package go/format) for output that matches gofmt.
1395 | //
1396 | func Fprint(output io.Writer, fset *token.FileSet, node interface{}) error {
1397 | return (&Config{Tabwidth: 8}).Fprint(output, fset, node)
1398 | }
1399 |
--------------------------------------------------------------------------------
/rewrite.go:
--------------------------------------------------------------------------------
1 | // Copyright 2020 The Go-Python Authors. All rights reserved.
2 | // Use of this source code is governed by a BSD-style
3 | // license that can be found in the LICENSE file.
4 |
5 | // This is based on gofmt source code:
6 |
7 | // Copyright 2009 The Go Authors. All rights reserved.
8 | // Use of this source code is governed by a BSD-style
9 | // license that can be found in the LICENSE file.
10 |
11 | package main
12 |
13 | import (
14 | "fmt"
15 | "go/ast"
16 | "go/parser"
17 | "go/token"
18 | "os"
19 | "reflect"
20 | "strings"
21 | "unicode"
22 | "unicode/utf8"
23 | )
24 |
25 | func initRewrite() {
26 | if *rewriteRule == "" {
27 | rewrite = nil // disable any previous rewrite
28 | return
29 | }
30 | f := strings.Split(*rewriteRule, "->")
31 | if len(f) != 2 {
32 | fmt.Fprintf(os.Stderr, "rewrite rule must be of the form 'pattern -> replacement'\n")
33 | os.Exit(2)
34 | }
35 | pattern := parseExpr(f[0], "pattern")
36 | replace := parseExpr(f[1], "replacement")
37 | rewrite = func(p *ast.File) *ast.File { return rewriteFile(pattern, replace, p) }
38 | }
39 |
// parseExpr parses s as a Go expression. what names the role of s ("pattern"
// or "replacement") for the error message. A parse failure is reported on
// stderr and terminates the process with status 2.
//
// It might make sense to expand this to allow statement patterns,
// but there are problems with preserving formatting and also
// with what a wildcard for a statement looks like.
func parseExpr(s, what string) ast.Expr {
	expr, err := parser.ParseExpr(s)
	if err == nil {
		return expr
	}
	fmt.Fprintf(os.Stderr, "parsing %s %s at %s\n", what, s, err)
	os.Exit(2)
	return nil // unreachable: os.Exit does not return
}
52 |
53 | // Keep this function for debugging.
54 | /*
55 | func dump(msg string, val reflect.Value) {
56 | fmt.Printf("%s:\n", msg)
57 | ast.Print(fileSet, val.Interface())
58 | fmt.Println()
59 | }
60 | */
61 |
62 | // rewriteFile applies the rewrite rule 'pattern -> replace' to an entire file.
63 | func rewriteFile(pattern, replace ast.Expr, p *ast.File) *ast.File {
64 | cmap := ast.NewCommentMap(fileSet, p, p.Comments)
65 | m := make(map[string]reflect.Value)
66 | pat := reflect.ValueOf(pattern)
67 | repl := reflect.ValueOf(replace)
68 |
69 | var rewriteVal func(val reflect.Value) reflect.Value
70 | rewriteVal = func(val reflect.Value) reflect.Value {
71 | // don't bother if val is invalid to start with
72 | if !val.IsValid() {
73 | return reflect.Value{}
74 | }
75 | val = apply(rewriteVal, val)
76 | for k := range m {
77 | delete(m, k)
78 | }
79 | if match(m, pat, val) {
80 | val = subst(m, repl, reflect.ValueOf(val.Interface().(ast.Node).Pos()))
81 | }
82 | return val
83 | }
84 |
85 | r := apply(rewriteVal, reflect.ValueOf(p)).Interface().(*ast.File)
86 | r.Comments = cmap.Filter(r).Comments() // recreate comments list
87 | return r
88 | }
89 |
// set performs x.Set(y) on a best-effort basis. Nothing happens if x is not
// settable or y is invalid. If Set panics because y's type does not fit x
// (a "type mismatch" or "not assignable" panic message), the panic is
// swallowed and the rewrite is silently skipped. Any other panic propagates.
func set(x, y reflect.Value) {
	if !(x.CanSet() && y.IsValid()) {
		return
	}
	defer func() {
		r := recover()
		if r == nil {
			return
		}
		msg, isString := r.(string)
		if isString && (strings.Contains(msg, "type mismatch") || strings.Contains(msg, "not assignable")) {
			return // x cannot be set to y - ignore this rewrite
		}
		panic(r)
	}()
	x.Set(y)
}
108 |
// Values/types for special cases.
var (
	// Typed nil pointers that apply returns in place of *ast.Object and
	// *ast.Scope values instead of following them.
	objectPtrNil = reflect.ValueOf((*ast.Object)(nil))
	scopePtrNil  = reflect.ValueOf((*ast.Scope)(nil))

	// Types compared against by apply, match and subst.
	identType     = reflect.TypeOf((*ast.Ident)(nil))    // wildcard candidates; idents match by name only
	objectPtrType = reflect.TypeOf((*ast.Object)(nil))   // not followed by apply; always matches
	positionType  = reflect.TypeOf(token.NoPos)          // always matches; replaced by subst
	callExprType  = reflect.TypeOf((*ast.CallExpr)(nil)) // presence of Ellipsis must match
	scopePtrType  = reflect.TypeOf((*ast.Scope)(nil))    // not followed by apply
)
120 |
121 | // apply replaces each AST field x in val with f(x), returning val.
122 | // To avoid extra conversions, f operates on the reflect.Value form.
123 | func apply(f func(reflect.Value) reflect.Value, val reflect.Value) reflect.Value {
124 | if !val.IsValid() {
125 | return reflect.Value{}
126 | }
127 |
128 | // *ast.Objects introduce cycles and are likely incorrect after
129 | // rewrite; don't follow them but replace with nil instead
130 | if val.Type() == objectPtrType {
131 | return objectPtrNil
132 | }
133 |
134 | // similarly for scopes: they are likely incorrect after a rewrite;
135 | // replace them with nil
136 | if val.Type() == scopePtrType {
137 | return scopePtrNil
138 | }
139 |
140 | switch v := reflect.Indirect(val); v.Kind() {
141 | case reflect.Slice:
142 | for i := 0; i < v.Len(); i++ {
143 | e := v.Index(i)
144 | set(e, f(e))
145 | }
146 | case reflect.Struct:
147 | for i := 0; i < v.NumField(); i++ {
148 | e := v.Field(i)
149 | set(e, f(e))
150 | }
151 | case reflect.Interface:
152 | e := v.Elem()
153 | set(v, f(e))
154 | }
155 | return val
156 | }
157 |
// isWildcard reports whether s names a rewrite-pattern wildcard, meaning s
// consists of exactly one lowercase Unicode letter (for example "a", "x" or
// "é"). The empty string and invalid UTF-8 are never wildcards.
func isWildcard(s string) bool {
	// r, not rune: avoid shadowing the predeclared type
	r, size := utf8.DecodeRuneInString(s)
	return size == len(s) && unicode.IsLower(r)
}
162 |
// match reports whether pattern matches val,
// recording wildcard submatches in m.
// If m == nil, match checks whether pattern == val.
//
// Comparison is structural via reflection, with these special cases:
//   - Identifiers compare by Name only.
//   - *ast.Object pointers and token positions always match.
//   - For *ast.CallExpr, the presence of an Ellipsis must agree.
//   - Any remaining non-composite values are compared with ==.
//
// When m is non-nil, a single-lowercase-letter identifier in pattern is a
// wildcard. It matches any non-nil ast.Expr, and each repeated occurrence
// must match an expression equal to the first binding.
func match(m map[string]reflect.Value, pattern, val reflect.Value) bool {
	// Wildcard matches any expression. If it appears multiple
	// times in the pattern, it must match the same expression
	// each time.
	if m != nil && pattern.IsValid() && pattern.Type() == identType {
		name := pattern.Interface().(*ast.Ident).Name
		if isWildcard(name) && val.IsValid() {
			// wildcards only match valid (non-nil) expressions.
			if _, ok := val.Interface().(ast.Expr); ok && !val.IsNil() {
				if old, ok := m[name]; ok {
					// repeated wildcard: compare against the first binding
					return match(nil, old, val)
				}
				m[name] = val
				return true
			}
		}
	}

	// Otherwise, pattern and val must match recursively.
	if !pattern.IsValid() || !val.IsValid() {
		return !pattern.IsValid() && !val.IsValid()
	}
	if pattern.Type() != val.Type() {
		return false
	}

	// Special cases.
	switch pattern.Type() {
	case identType:
		// For identifiers, only the names need to match
		// (and none of the other *ast.Object information).
		// This is a common case, handle it all here instead
		// of recursing down any further via reflection.
		p := pattern.Interface().(*ast.Ident)
		v := val.Interface().(*ast.Ident)
		return p == nil && v == nil || p != nil && v != nil && p.Name == v.Name
	case objectPtrType, positionType:
		// object pointers and token positions always match
		return true
	case callExprType:
		// For calls, the Ellipsis fields (token.Position) must
		// match since that is how f(x) and f(x...) are different.
		// Check them here but fall through for the remaining fields.
		// NOTE(review): this dereferences both pointers, so it assumes
		// neither is a nil *ast.CallExpr.
		p := pattern.Interface().(*ast.CallExpr)
		v := val.Interface().(*ast.CallExpr)
		if p.Ellipsis.IsValid() != v.Ellipsis.IsValid() {
			return false
		}
	}

	p := reflect.Indirect(pattern)
	v := reflect.Indirect(val)
	if !p.IsValid() || !v.IsValid() {
		// nil pointers only match nil pointers
		return !p.IsValid() && !v.IsValid()
	}

	switch p.Kind() {
	case reflect.Slice:
		if p.Len() != v.Len() {
			return false
		}
		for i := 0; i < p.Len(); i++ {
			if !match(m, p.Index(i), v.Index(i)) {
				return false
			}
		}
		return true

	case reflect.Struct:
		for i := 0; i < p.NumField(); i++ {
			if !match(m, p.Field(i), v.Field(i)) {
				return false
			}
		}
		return true

	case reflect.Interface:
		return match(m, p.Elem(), v.Elem())
	}

	// Handle token integers, etc.
	return p.Interface() == v.Interface()
}
249 |
// subst returns a copy of pattern with values from m substituted in place
// of wildcards and pos used as the position of tokens from the pattern.
//
// Wildcards are only replaced when m is non-nil. A bound value is copied
// with an invalid pos, so the substituted subexpression keeps its own
// original positions.
//
// Positions are rewritten only when pos is valid, and only for positions
// that were valid in pattern; NoPos stays NoPos. When pos is invalid, all
// positions are copied unchanged.
//
// Slices, structs, pointers and interfaces are deep-copied. Nil slices,
// pointers and interfaces stay nil, and other values are returned as is.
func subst(m map[string]reflect.Value, pattern reflect.Value, pos reflect.Value) reflect.Value {
	if !pattern.IsValid() {
		return reflect.Value{}
	}

	// Wildcard gets replaced with map value.
	if m != nil && pattern.Type() == identType {
		name := pattern.Interface().(*ast.Ident).Name
		if isWildcard(name) {
			if old, ok := m[name]; ok {
				// copy the bound expression, keeping its positions
				return subst(nil, old, reflect.Value{})
			}
		}
	}

	if pos.IsValid() && pattern.Type() == positionType {
		// use new position only if old position was valid in the first place
		if old := pattern.Interface().(token.Pos); !old.IsValid() {
			return pattern
		}
		return pos
	}

	// Otherwise copy.
	switch p := pattern; p.Kind() {
	case reflect.Slice:
		if p.IsNil() {
			// Do not turn nil slices into empty slices. go/ast
			// guarantees that certain lists will be nil if not
			// populated.
			return reflect.Zero(p.Type())
		}
		v := reflect.MakeSlice(p.Type(), p.Len(), p.Len())
		for i := 0; i < p.Len(); i++ {
			v.Index(i).Set(subst(m, p.Index(i), pos))
		}
		return v

	case reflect.Struct:
		v := reflect.New(p.Type()).Elem()
		for i := 0; i < p.NumField(); i++ {
			v.Field(i).Set(subst(m, p.Field(i), pos))
		}
		return v

	case reflect.Ptr:
		v := reflect.New(p.Type()).Elem()
		if elem := p.Elem(); elem.IsValid() {
			// .Addr() relies on the copied pointee being addressable
			// (true for the struct case above)
			v.Set(subst(m, elem, pos).Addr())
		}
		return v

	case reflect.Interface:
		v := reflect.New(p.Type()).Elem()
		if elem := p.Elem(); elem.IsValid() {
			v.Set(subst(m, elem, pos))
		}
		return v
	}

	return pattern
}
316 |
--------------------------------------------------------------------------------
/simplify.go:
--------------------------------------------------------------------------------
1 | // Copyright 2020 The Go-Python Authors. All rights reserved.
2 | // Use of this source code is governed by a BSD-style
3 | // license that can be found in the LICENSE file.
4 |
5 | // This is based on gofmt source code:
6 |
7 | // Copyright 2010 The Go Authors. All rights reserved.
8 | // Use of this source code is governed by a BSD-style
9 | // license that can be found in the LICENSE file.
10 |
11 | package main
12 |
13 | import (
14 | "go/ast"
15 | "go/token"
16 | "reflect"
17 | )
18 |
// simplifier is a stateless ast.Visitor that applies gofmt -s style
// simplifications to the nodes it visits, modifying them in place.
type simplifier struct{}
20 |
21 | func (s simplifier) Visit(node ast.Node) ast.Visitor {
22 | switch n := node.(type) {
23 | case *ast.CompositeLit:
24 | // array, slice, and map composite literals may be simplified
25 | outer := n
26 | var keyType, eltType ast.Expr
27 | switch typ := outer.Type.(type) {
28 | case *ast.ArrayType:
29 | eltType = typ.Elt
30 | case *ast.MapType:
31 | keyType = typ.Key
32 | eltType = typ.Value
33 | }
34 |
35 | if eltType != nil {
36 | var ktyp reflect.Value
37 | if keyType != nil {
38 | ktyp = reflect.ValueOf(keyType)
39 | }
40 | typ := reflect.ValueOf(eltType)
41 | for i, x := range outer.Elts {
42 | px := &outer.Elts[i]
43 | // look at value of indexed/named elements
44 | if t, ok := x.(*ast.KeyValueExpr); ok {
45 | if keyType != nil {
46 | s.simplifyLiteral(ktyp, keyType, t.Key, &t.Key)
47 | }
48 | x = t.Value
49 | px = &t.Value
50 | }
51 | s.simplifyLiteral(typ, eltType, x, px)
52 | }
53 | // node was simplified - stop walk (there are no subnodes to simplify)
54 | return nil
55 | }
56 |
57 | case *ast.SliceExpr:
58 | // a slice expression of the form: s[a:len(s)]
59 | // can be simplified to: s[a:]
60 | // if s is "simple enough" (for now we only accept identifiers)
61 | //
62 | // Note: This may not be correct because len may have been redeclared in another
63 | // file belonging to the same package. However, this is extremely unlikely
64 | // and so far (April 2016, after years of supporting this rewrite feature)
65 | // has never come up, so let's keep it working as is (see also #15153).
66 | if n.Max != nil {
67 | // - 3-index slices always require the 2nd and 3rd index
68 | break
69 | }
70 | if s, _ := n.X.(*ast.Ident); s != nil && s.Obj != nil {
71 | // the array/slice object is a single, resolved identifier
72 | if call, _ := n.High.(*ast.CallExpr); call != nil && len(call.Args) == 1 && !call.Ellipsis.IsValid() {
73 | // the high expression is a function call with a single argument
74 | if fun, _ := call.Fun.(*ast.Ident); fun != nil && fun.Name == "len" && fun.Obj == nil {
75 | // the function called is "len" and it is not locally defined; and
76 | // because we don't have dot imports, it must be the predefined len()
77 | if arg, _ := call.Args[0].(*ast.Ident); arg != nil && arg.Obj == s.Obj {
78 | // the len argument is the array/slice object
79 | n.High = nil
80 | }
81 | }
82 | }
83 | }
84 | // Note: We could also simplify slice expressions of the form s[0:b] to s[:b]
85 | // but we leave them as is since sometimes we want to be very explicit
86 | // about the lower bound.
87 | // An example where the 0 helps:
88 | // x, y, z := b[0:2], b[2:4], b[4:6]
89 | // An example where it does not:
90 | // x, y := b[:n], b[n:]
91 |
92 | case *ast.RangeStmt:
93 | // - a range of the form: for x, _ = range v {...}
94 | // can be simplified to: for x = range v {...}
95 | // - a range of the form: for _ = range v {...}
96 | // can be simplified to: for range v {...}
97 | if isBlank(n.Value) {
98 | n.Value = nil
99 | }
100 | if isBlank(n.Key) && n.Value == nil {
101 | n.Key = nil
102 | }
103 | }
104 |
105 | return s
106 | }
107 |
108 | func (s simplifier) simplifyLiteral(typ reflect.Value, astType, x ast.Expr, px *ast.Expr) {
109 | ast.Walk(s, x) // simplify x
110 |
111 | // if the element is a composite literal and its literal type
112 | // matches the outer literal's element type exactly, the inner
113 | // literal type may be omitted
114 | if inner, ok := x.(*ast.CompositeLit); ok {
115 | if match(nil, typ, reflect.ValueOf(inner.Type)) {
116 | inner.Type = nil
117 | }
118 | }
119 | // if the outer literal's element type is a pointer type *T
120 | // and the element is & of a composite literal of type T,
121 | // the inner &T may be omitted.
122 | if ptr, ok := astType.(*ast.StarExpr); ok {
123 | if addr, ok := x.(*ast.UnaryExpr); ok && addr.Op == token.AND {
124 | if inner, ok := addr.X.(*ast.CompositeLit); ok {
125 | if match(nil, reflect.ValueOf(ptr.X), reflect.ValueOf(inner.Type)) {
126 | inner.Type = nil // drop T
127 | *px = inner // drop &
128 | }
129 | }
130 | }
131 | }
132 | }
133 |
// isBlank reports whether x is the blank identifier "_".
func isBlank(x ast.Expr) bool {
	if id, ok := x.(*ast.Ident); ok {
		return id.Name == "_"
	}
	return false
}
138 |
139 | func simplify(f *ast.File) {
140 | // remove empty declarations such as "const ()", etc
141 | removeEmptyDeclGroups(f)
142 |
143 | var s simplifier
144 | ast.Walk(s, f)
145 | }
146 |
147 | func removeEmptyDeclGroups(f *ast.File) {
148 | i := 0
149 | for _, d := range f.Decls {
150 | if g, ok := d.(*ast.GenDecl); !ok || !isEmpty(f, g) {
151 | f.Decls[i] = d
152 | i++
153 | }
154 | }
155 | f.Decls = f.Decls[:i]
156 | }
157 |
// isEmpty reports whether g is an empty declaration group such as "const ()"
// that can be dropped from f. That is the case when g has no doc comment,
// no specs, and no comment of f lies within g's textual extent.
func isEmpty(f *ast.File, g *ast.GenDecl) bool {
	// len, not != nil: an empty but non-nil Specs slice is still empty.
	if g.Doc != nil || len(g.Specs) > 0 {
		return false
	}
	if !g.Rparen.IsValid() {
		// An unparenthesized declaration without specs has no body that
		// could hold a comment. Bail out before g.End(), which would index
		// g.Specs[0] when Rparen is missing.
		return true
	}

	start, end := g.Pos(), g.End()
	for _, c := range f.Comments {
		// if there is a comment in the declaration, it is not considered empty
		if start <= c.Pos() && c.End() <= end {
			return false
		}
	}
	return true
}
172 |
--------------------------------------------------------------------------------
/testdata/basic.golden:
--------------------------------------------------------------------------------
1 | package test
2 |
3 | GlobInt int = 2
4 | GlobStr = "a string"
5 | GlobBool = False
6 |
7 | class MyStru:
8 | # A struct definition
9 | """
10 | A struct definition
11 |
12 | # field desc
13 | """
14 |
15 | def __init__(self):
16 | self.A = int()
17 | self.B = float() # `desc:"field tag"`
18 | self.C = str() # `desc:"more tags"`
19 |
20 | def MethOne(st, arg1):
21 | """
22 | MethOne does something
23 | """
24 | rv = st.A
25 | for a in SomeList :
26 | rv += a
27 | st.A = True
28 |
29 | ano = MyStru(A= 22, B= 44.2, C= "happy")
30 |
31 | return rv
32 |
33 | def MethTwo(st, arg1, arg2, arg3):
34 | """
35 | MethTwo does something
36 | it is pretty cool
37 | not really sure about that
38 | """
39 | rv = st.A
40 | for a in range(100):
41 | rv += a
42 | switch rv:
43 | if 100:
44 | rv *= 2
45 | if 500:
46 | rv /= 5
47 | return rv
48 |
49 |
50 | # A global function
51 | def GlobFun(a, b):
52 | """
53 | A global function
54 | """
55 | if a > b and a == 0 or b == 0:
56 | return a + b
57 | elif a == b:
58 | return a * b
59 | else:
60 | return a - b
61 |
62 |
--------------------------------------------------------------------------------
/testdata/basic.input:
--------------------------------------------------------------------------------
1 | package test
2 |
3 | var (
4 | GlobInt int = 2
5 | GlobStr = "a string"
6 | GlobBool = false
7 | )
8 |
9 | // A struct definition
10 | type MyStru struct {
11 | A int // field desc
12 | B float32 `desc:"field tag"`
13 | C string `desc:"more tags"`
14 | }
15 |
16 | // A global function
17 | func GlobFun(a, b float32) float32 {
18 | if a > b && a == 0 || b == 0 {
19 | return a + b
20 | } else if a == b {
21 | return a * b
22 | } else {
23 | return a - b
24 | }
25 | }
26 |
27 | // MethOne does something
28 | func (st *MyStru) MethOne(arg1 float32) int {
29 | rv := st.A
30 | for _, a := range SomeList {
31 | rv += a
32 | }
33 | st.A = true
34 |
35 | ano := MyStru{A: 22, B: 44.2, C: "happy"}
36 |
37 | return rv
38 | }
39 |
40 | // MethTwo does something
41 | // it is pretty cool
42 | // not really sure about that
43 | func (st *MyStru) MethTwo(arg1, arg2 float32, arg3 int) int {
44 | rv := st.A
45 | for a := 0; a < 100; a++ {
46 | rv += a
47 | }
48 | switch rv {
49 | case 100:
50 | rv *= 2
51 | case 500:
52 | rv /= 5
53 | }
54 | return rv
55 | }
56 |
--------------------------------------------------------------------------------
/testdata/ra25.input:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2019, The Emergent Authors. All rights reserved.
2 | // Use of this source code is governed by a BSD-style
3 | // license that can be found in the LICENSE file.
4 |
5 | //gofmt -gogi
6 |
7 | // ra25 runs a simple random-associator four-layer leabra network
8 | // that uses the standard supervised learning paradigm to learn
9 | // mappings between 25 random input / output patterns
10 | // defined over 5x5 input / output layers (i.e., 25 units)
11 | package main
12 |
13 | import (
14 | "flag"
15 | "fmt"
16 | "log"
17 | "math/rand"
18 | "os"
19 | "strconv"
20 | "strings"
21 | "time"
22 |
23 | "github.com/emer/emergent/emer"
24 | "github.com/emer/emergent/env"
25 | "github.com/emer/emergent/netview"
26 | "github.com/emer/emergent/params"
27 | "github.com/emer/emergent/patgen"
28 | "github.com/emer/emergent/prjn"
29 | "github.com/emer/emergent/relpos"
30 | "github.com/emer/etable/agg"
31 | "github.com/emer/etable/eplot"
32 | "github.com/emer/etable/etable"
33 | "github.com/emer/etable/etensor"
34 | _ "github.com/emer/etable/etview" // include to get gui views
35 | "github.com/emer/etable/split"
36 | "github.com/emer/leabra/leabra"
37 | "github.com/goki/gi/gi"
38 | "github.com/goki/gi/gimain"
39 | "github.com/goki/gi/giv"
40 | "github.com/goki/ki/ki"
41 | "github.com/goki/ki/kit"
42 | "github.com/goki/mat32"
43 | )
44 |
45 | func main() {
46 | TheSim.New()
47 | TheSim.Config()
48 | if len(os.Args) > 1 {
49 | TheSim.CmdArgs() // simple assumption is that any args = no gui -- could add explicit arg if you want
50 | } else {
51 | gimain.Main(func() { // this starts gui -- requires valid OpenGL display connection (e.g., X11)
52 | guirun()
53 | })
54 | }
55 | }
56 |
57 | func guirun() {
58 | TheSim.Init()
59 | win := TheSim.ConfigGui()
60 | win.StartEventLoop()
61 | }
62 |
63 | // LogPrec is precision for saving float values in logs
64 | const LogPrec = 4
65 |
66 | // ParamSets is the default set of parameters -- Base is always applied, and others can be optionally
67 | // selected to apply on top of that
68 | var ParamSets = params.Sets{
69 | {Name: "Base", Desc: "these are the best params", Sheets: params.Sheets{
70 | "Network": ¶ms.Sheet{
71 | {Sel: "Prjn", Desc: "norm and momentum on works better, but wt bal is not better for smaller nets",
72 | Params: params.Params{
73 | "Prjn.Learn.Norm.On": "true",
74 | "Prjn.Learn.Momentum.On": "true",
75 | "Prjn.Learn.WtBal.On": "false",
76 | }},
77 | {Sel: "Layer", Desc: "using default 1.8 inhib for all of network -- can explore",
78 | Params: params.Params{
79 | "Layer.Inhib.Layer.Gi": "1.8",
80 | "Layer.Act.Gbar.L": "0.1", // set explictly, new default, a bit better vs 0.2
81 | }},
82 | {Sel: ".Back", Desc: "top-down back-projections MUST have lower relative weight scale, otherwise network hallucinates",
83 | Params: params.Params{
84 | "Prjn.WtScale.Rel": "0.2",
85 | }},
86 | {Sel: "#Output", Desc: "output definitely needs lower inhib -- true for smaller layers in general",
87 | Params: params.Params{
88 | "Layer.Inhib.Layer.Gi": "1.4",
89 | }},
90 | },
91 | "Sim": ¶ms.Sheet{ // sim params apply to sim object
92 | {Sel: "Sim", Desc: "best params always finish in this time",
93 | Params: params.Params{
94 | "Sim.MaxEpcs": "50",
95 | }},
96 | },
97 | }},
98 | {Name: "DefaultInhib", Desc: "output uses default inhib instead of lower", Sheets: params.Sheets{
99 | "Network": ¶ms.Sheet{
100 | {Sel: "#Output", Desc: "go back to default",
101 | Params: params.Params{
102 | "Layer.Inhib.Layer.Gi": "1.8",
103 | }},
104 | },
105 | "Sim": ¶ms.Sheet{ // sim params apply to sim object
106 | {Sel: "Sim", Desc: "takes longer -- generally doesn't finish..",
107 | Params: params.Params{
108 | "Sim.MaxEpcs": "100",
109 | }},
110 | },
111 | }},
112 | {Name: "NoMomentum", Desc: "no momentum or normalization", Sheets: params.Sheets{
113 | "Network": ¶ms.Sheet{
114 | {Sel: "Prjn", Desc: "no norm or momentum",
115 | Params: params.Params{
116 | "Prjn.Learn.Norm.On": "false",
117 | "Prjn.Learn.Momentum.On": "false",
118 | }},
119 | },
120 | }},
121 | {Name: "WtBalOn", Desc: "try with weight bal on", Sheets: params.Sheets{
122 | "Network": ¶ms.Sheet{
123 | {Sel: "Prjn", Desc: "weight bal on",
124 | Params: params.Params{
125 | "Prjn.Learn.WtBal.On": "true",
126 | }},
127 | },
128 | }},
129 | }
130 |
// Sim encapsulates the entire simulation model, and we define all the
// functionality as methods on this struct. This structure keeps all relevant
// state information organized and available without having to pass everything around
// as arguments to methods, and provides the core GUI interface (note the view tags
// for the fields which provide hints to how things should be displayed).
//
// The fields fall into three groups:
//   - model objects, log tables and run configuration;
//   - trial / epoch statistics, shown read-only in the GUI (inactive:"+");
//   - internal bookkeeping and GUI handles, hidden from the view (view:"-").
//
// The desc tags double as per-field documentation in the GUI. Field order and
// tags are read via reflection, so keep them stable.
type Sim struct {
	Net          *leabra.Network   `view:"no-inline" desc:"the network -- click to view / edit parameters for layers, prjns, etc"`
	Pats         *etable.Table     `view:"no-inline" desc:"the training patterns to use"`
	TrnEpcLog    *etable.Table     `view:"no-inline" desc:"training epoch-level log data"`
	TstEpcLog    *etable.Table     `view:"no-inline" desc:"testing epoch-level log data"`
	TstTrlLog    *etable.Table     `view:"no-inline" desc:"testing trial-level log data"`
	TstErrLog    *etable.Table     `view:"no-inline" desc:"log of all test trials where errors were made"`
	TstErrStats  *etable.Table     `view:"no-inline" desc:"stats on test trials where errors were made"`
	TstCycLog    *etable.Table     `view:"no-inline" desc:"testing cycle-level log data"`
	RunLog       *etable.Table     `view:"no-inline" desc:"summary log of each run"`
	RunStats     *etable.Table     `view:"no-inline" desc:"aggregate stats on all runs"`
	Params       params.Sets       `view:"no-inline" desc:"full collection of param sets"`
	ParamSet     string            `desc:"which set of *additional* parameters to use -- always applies Base and optionaly this next if set -- can use multiple names separated by spaces (don't put spaces in ParamSet names!)"`
	Tag          string            `desc:"extra tag string to add to any file names output from sim (e.g., weights files, log files, params for run)"`
	MaxRuns      int               `desc:"maximum number of model runs to perform"`
	MaxEpcs      int               `desc:"maximum number of epochs to run per model run"`
	NZeroStop    int               `desc:"if a positive number, training will stop after this many epochs with zero SSE"`
	TrainEnv     env.FixedTable    `desc:"Training environment -- contains everything about iterating over input / output patterns over training"`
	TestEnv      env.FixedTable    `desc:"Testing environment -- manages iterating over testing"`
	Time         leabra.Time       `desc:"leabra timing parameters and state"`
	ViewOn       bool              `desc:"whether to update the network view while running"`
	TrainUpdt    leabra.TimeScales `desc:"at what time scale to update the display during training? Anything longer than Epoch updates at Epoch in this model"`
	TestUpdt     leabra.TimeScales `desc:"at what time scale to update the display during testing? Anything longer than Epoch updates at Epoch in this model"`
	TestInterval int               `desc:"how often to run through all the test patterns, in terms of training epochs -- can use 0 or -1 for no testing"`
	LayStatNms   []string          `desc:"names of layers to collect more detailed stats on (avg act, etc)"`

	// statistics: note use float64 as that is best for etable.Table
	TrlErr        float64 `inactive:"+" desc:"1 if trial was error, 0 if correct -- based on SSE = 0 (subject to .5 unit-wise tolerance)"`
	TrlSSE        float64 `inactive:"+" desc:"current trial's sum squared error"`
	TrlAvgSSE     float64 `inactive:"+" desc:"current trial's average sum squared error"`
	TrlCosDiff    float64 `inactive:"+" desc:"current trial's cosine difference"`
	EpcSSE        float64 `inactive:"+" desc:"last epoch's total sum squared error"`
	EpcAvgSSE     float64 `inactive:"+" desc:"last epoch's average sum squared error (average over trials, and over units within layer)"`
	EpcPctErr     float64 `inactive:"+" desc:"last epoch's average TrlErr"`
	EpcPctCor     float64 `inactive:"+" desc:"1 - last epoch's average TrlErr"`
	EpcCosDiff    float64 `inactive:"+" desc:"last epoch's average cosine difference for output layer (a normalized error measure, maximum of 1 when the minus phase exactly matches the plus)"`
	EpcPerTrlMSec float64 `inactive:"+" desc:"how long did the epoch take per trial in wall-clock milliseconds"`
	FirstZero     int     `inactive:"+" desc:"epoch at when SSE first went to zero"`
	NZero         int     `inactive:"+" desc:"number of epochs in a row with zero SSE"`

	// internal state - view:"-"
	SumErr       float64                     `view:"-" inactive:"+" desc:"sum to increment as we go through epoch"`
	SumSSE       float64                     `view:"-" inactive:"+" desc:"sum to increment as we go through epoch"`
	SumAvgSSE    float64                     `view:"-" inactive:"+" desc:"sum to increment as we go through epoch"`
	SumCosDiff   float64                     `view:"-" inactive:"+" desc:"sum to increment as we go through epoch"`
	Win          *gi.Window                  `view:"-" desc:"main GUI window"`
	NetView      *netview.NetView            `view:"-" desc:"the network viewer"`
	ToolBar      *gi.ToolBar                 `view:"-" desc:"the master toolbar"`
	TrnEpcPlot   *eplot.Plot2D               `view:"-" desc:"the training epoch plot"`
	TstEpcPlot   *eplot.Plot2D               `view:"-" desc:"the testing epoch plot"`
	TstTrlPlot   *eplot.Plot2D               `view:"-" desc:"the test-trial plot"`
	TstCycPlot   *eplot.Plot2D               `view:"-" desc:"the test-cycle plot"`
	RunPlot      *eplot.Plot2D               `view:"-" desc:"the run plot"`
	TrnEpcFile   *os.File                    `view:"-" desc:"log file"`
	RunFile      *os.File                    `view:"-" desc:"log file"`
	ValsTsrs     map[string]*etensor.Float32 `view:"-" desc:"for holding layer values"`
	SaveWts      bool                        `view:"-" desc:"for command-line run only, auto-save final weights after each run"`
	NoGui        bool                        `view:"-" desc:"if true, runing in no GUI mode"`
	LogSetParams bool                        `view:"-" desc:"if true, print message for all params that are set"`
	IsRunning    bool                        `view:"-" desc:"true if sim is running"`
	StopNow      bool                        `view:"-" desc:"flag to stop running"`
	NeedsNewRun  bool                        `view:"-" desc:"flag to initialize NewRun if last one finished"`
	RndSeed      int64                       `view:"-" desc:"the current random seed"`
	LastEpcTime  time.Time                   `view:"-" desc:"timer for last epoch"`
}
201 |
202 | // this registers this Sim Type and gives it properties that e.g.,
203 | // prompt for filename for save methods.
204 | var KiT_Sim = kit.Types.AddType(&Sim{}, SimProps)
205 |
206 | // TheSim is the overall state for this simulation
207 | var TheSim Sim
208 |
209 | // New creates new blank elements and initializes defaults
210 | func (ss *Sim) New() {
211 | ss.Net = &leabra.Network{}
212 | ss.Pats = &etable.Table{}
213 | ss.TrnEpcLog = &etable.Table{}
214 | ss.TstEpcLog = &etable.Table{}
215 | ss.TstTrlLog = &etable.Table{}
216 | ss.TstCycLog = &etable.Table{}
217 | ss.RunLog = &etable.Table{}
218 | ss.RunStats = &etable.Table{}
219 | ss.Params = ParamSets
220 | ss.RndSeed = 1
221 | ss.ViewOn = true
222 | ss.TrainUpdt = leabra.AlphaCycle
223 | ss.TestUpdt = leabra.Cycle
224 | ss.TestInterval = 5
225 | ss.LayStatNms = []string{"Hidden1", "Hidden2", "Output"}
226 | }
227 |
228 | ////////////////////////////////////////////////////////////////////////////////////////////
229 | // Configs
230 |
231 | // Config configures all the elements using the standard functions
232 | func (ss *Sim) Config() {
233 | //ss.ConfigPats()
234 | ss.OpenPats()
235 | ss.ConfigEnv()
236 | ss.ConfigNet(ss.Net)
237 | ss.ConfigTrnEpcLog(ss.TrnEpcLog)
238 | ss.ConfigTstEpcLog(ss.TstEpcLog)
239 | ss.ConfigTstTrlLog(ss.TstTrlLog)
240 | ss.ConfigTstCycLog(ss.TstCycLog)
241 | ss.ConfigRunLog(ss.RunLog)
242 | }
243 |
244 | func (ss *Sim) ConfigEnv() {
245 | if ss.MaxRuns == 0 { // allow user override
246 | ss.MaxRuns = 10
247 | }
248 | if ss.MaxEpcs == 0 { // allow user override
249 | ss.MaxEpcs = 50
250 | ss.NZeroStop = 5
251 | }
252 |
253 | ss.TrainEnv.Nm = "TrainEnv"
254 | ss.TrainEnv.Dsc = "training params and state"
255 | ss.TrainEnv.Table = etable.NewIdxView(ss.Pats)
256 | ss.TrainEnv.Validate()
257 | ss.TrainEnv.Run.Max = ss.MaxRuns // note: we are not setting epoch max -- do that manually
258 |
259 | ss.TestEnv.Nm = "TestEnv"
260 | ss.TestEnv.Dsc = "testing params and state"
261 | ss.TestEnv.Table = etable.NewIdxView(ss.Pats)
262 | ss.TestEnv.Sequential = true
263 | ss.TestEnv.Validate()
264 |
265 | // note: to create a train / test split of pats, do this:
266 | // all := etable.NewIdxView(ss.Pats)
267 | // splits, _ := split.Permuted(all, []float64{.8, .2}, []string{"Train", "Test"})
268 | // ss.TrainEnv.Table = splits.Splits[0]
269 | // ss.TestEnv.Table = splits.Splits[1]
270 |
271 | ss.TrainEnv.Init(0)
272 | ss.TestEnv.Init(0)
273 | }
274 |
275 | func (ss *Sim) ConfigNet(net *leabra.Network) {
276 | net.InitName(net, "RA25")
277 | inp := net.AddLayer2D("Input", 5, 5, emer.Input)
278 | hid1 := net.AddLayer2D("Hidden1", 7, 7, emer.Hidden)
279 | hid2 := net.AddLayer4D("Hidden2", 2, 4, 3, 2, emer.Hidden)
280 | out := net.AddLayer2D("Output", 5, 5, emer.Target)
281 |
282 | // use this to position layers relative to each other
283 | // default is Above, YAlign = Front, XAlign = Center
284 | hid2.SetRelPos(relpos.Rel{Rel: relpos.RightOf, Other: "Hidden1", YAlign: relpos.Front, Space: 2})
285 |
286 | // note: see emergent/prjn module for all the options on how to connect
287 | // NewFull returns a new prjn.Full connectivity pattern
288 | full := prjn.NewFull()
289 |
290 | net.ConnectLayers(inp, hid1, full, emer.Forward)
291 | net.BidirConnectLayers(hid1, hid2, full)
292 | net.BidirConnectLayers(hid2, out, full)
293 |
294 | // note: can set these to do parallel threaded computation across multiple cpus
295 | // not worth it for this small of a model, but definitely helps for larger ones
296 | // if Thread {
297 | // hid2.SetThread(1)
298 | // out.SetThread(1)
299 | // }
300 |
301 | // note: if you wanted to change a layer type from e.g., Target to Compare, do this:
302 | // out.SetType(emer.Compare)
303 | // that would mean that the output layer doesn't reflect target values in plus phase
304 | // and thus removes error-driven learning -- but stats are still computed.
305 |
306 | net.Defaults()
307 | ss.SetParams("Network", ss.LogSetParams) // only set Network params
308 | err := net.Build()
309 | if err != nil {
310 | log.Println(err)
311 | return
312 | }
313 | net.InitWts()
314 | }
315 |
316 | ////////////////////////////////////////////////////////////////////////////////
317 | // Init, utils
318 |
319 | // Init restarts the run, and initializes everything, including network weights
320 | // and resets the epoch log table
321 | func (ss *Sim) Init() {
322 | rand.Seed(ss.RndSeed)
323 | ss.ConfigEnv() // re-config env just in case a different set of patterns was
324 | // selected or patterns have been modified etc
325 | ss.StopNow = false
326 | ss.SetParams("", ss.LogSetParams) // all sheets
327 | ss.NewRun()
328 | ss.UpdateView(true)
329 | }
330 |
331 | // NewRndSeed gets a new random seed based on current time -- otherwise uses
332 | // the same random seed for every run
333 | func (ss *Sim) NewRndSeed() {
334 | ss.RndSeed = time.Now().UnixNano()
335 | }
336 |
337 | // Counters returns a string of the current counter state
338 | // use tabs to achieve a reasonable formatting overall
339 | // and add a few tabs at the end to allow for expansion..
340 | func (ss *Sim) Counters(train bool) string {
341 | if train {
342 | return fmt.Sprintf("Run:\t%d\tEpoch:\t%d\tTrial:\t%d\tCycle:\t%d\tName:\t%s\t\t\t", ss.TrainEnv.Run.Cur, ss.TrainEnv.Epoch.Cur, ss.TrainEnv.Trial.Cur, ss.Time.Cycle, ss.TrainEnv.TrialName.Cur)
343 | } else {
344 | return fmt.Sprintf("Run:\t%d\tEpoch:\t%d\tTrial:\t%d\tCycle:\t%d\tName:\t%s\t\t\t", ss.TrainEnv.Run.Cur, ss.TrainEnv.Epoch.Cur, ss.TestEnv.Trial.Cur, ss.Time.Cycle, ss.TestEnv.TrialName.Cur)
345 | }
346 | }
347 |
348 | func (ss *Sim) UpdateView(train bool) {
349 | if ss.NetView != nil && ss.NetView.IsVisible() {
350 | ss.NetView.Record(ss.Counters(train))
351 | // note: essential to use Go version of update when called from another goroutine
352 | ss.NetView.GoUpdate() // note: using counters is significantly slower..
353 | }
354 | }
355 |
356 | ////////////////////////////////////////////////////////////////////////////////
357 | // Running the Network, starting bottom-up..
358 |
// AlphaCyc runs one alpha-cycle (100 msec, 4 quarters) of processing.
// External inputs must have already been applied prior to calling,
// using ApplyExt method on relevant layers (see TrainTrial, TestTrial).
// If train is true, then learning DWt or WtFmDWt calls are made.
// Handles netview updating within scope of AlphaCycle.
//
// Sequence: when training, first apply the previous trial's weight changes
// (WtFmDWt); then run 4 quarters of ss.Time.CycPerQtr cycles each, calling
// QuarterFinal after each quarter; finally, when training, compute DWt. When
// not training, every cycle is logged to TstCycLog and the test-cycle plot is
// refreshed at the end. If ViewOn, the netview is refreshed at the granularity
// selected by TrainUpdt (training) or TestUpdt (testing).
func (ss *Sim) AlphaCyc(train bool) {
	// ss.Win.PollEvents() // this can be used instead of running in a separate goroutine
	viewUpdt := ss.TrainUpdt
	if !train {
		viewUpdt = ss.TestUpdt
	}

	// update prior weight changes at start, so any DWt values remain visible at end
	// you might want to do this less frequently to achieve a mini-batch update
	// in which case, move it out to the TrainTrial method where the relevant
	// counters are being dealt with.
	if train {
		ss.Net.WtFmDWt()
	}

	ss.Net.AlphaCycInit()
	ss.Time.AlphaCycStart()
	for qtr := 0; qtr < 4; qtr++ {
		for cyc := 0; cyc < ss.Time.CycPerQtr; cyc++ {
			ss.Net.Cycle(&ss.Time)
			if !train {
				// cycle-level log is only kept for testing
				ss.LogTstCyc(ss.TstCycLog, ss.Time.Cycle)
			}
			ss.Time.CycleInc()
			if ss.ViewOn {
				switch viewUpdt {
				case leabra.Cycle:
					// the last cycle of each quarter is covered by the per-quarter update below
					if cyc != ss.Time.CycPerQtr-1 { // will be updated by quarter
						ss.UpdateView(train)
					}
				case leabra.FastSpike:
					// refresh every 10th cycle within the quarter
					if (cyc+1)%10 == 0 {
						ss.UpdateView(train)
					}
				}
			}
		}
		ss.Net.QuarterFinal(&ss.Time)
		ss.Time.QuarterInc()
		if ss.ViewOn {
			switch {
			case viewUpdt <= leabra.Quarter:
				ss.UpdateView(train)
			case viewUpdt == leabra.Phase:
				// NOTE(review): presumably quarters 2 and 3 end the minus and plus
				// phases respectively -- confirm against leabra.Time.
				if qtr >= 2 {
					ss.UpdateView(train)
				}
			}
		}
	}

	if train {
		ss.Net.DWt()
	}
	if ss.ViewOn && viewUpdt == leabra.AlphaCycle {
		ss.UpdateView(train)
	}
	if !train {
		ss.TstCycPlot.GoUpdate() // make sure up-to-date at end
	}
}
425 |
426 | // ApplyInputs applies input patterns from given envirbonment.
427 | // It is good practice to have this be a separate method with appropriate
428 | // args so that it can be used for various different contexts
429 | // (training, testing, etc).
430 | func (ss *Sim) ApplyInputs(en env.Env) {
431 | // ss.Net.InitExt() // clear any existing inputs -- not strictly necessary if always
432 | // going to the same layers, but good practice and cheap anyway
433 |
434 | lays := []string{"Input", "Output"}
435 | for _, lnm := range lays {
436 | ly := ss.Net.LayerByName(lnm).(leabra.LeabraLayer).AsLeabra()
437 | pats := en.State(ly.Nm)
438 | if pats != nil {
439 | ly.ApplyExt(pats)
440 | }
441 | }
442 | }
443 |
444 | // TrainTrial runs one trial of training using TrainEnv
445 | func (ss *Sim) TrainTrial() {
446 | if ss.NeedsNewRun {
447 | ss.NewRun()
448 | }
449 |
450 | ss.TrainEnv.Step() // the Env encapsulates and manages all counter state
451 |
452 | // Key to query counters FIRST because current state is in NEXT epoch
453 | // if epoch counter has changed
454 | epc, _, chg := ss.TrainEnv.Counter(env.Epoch)
455 | if chg {
456 | ss.LogTrnEpc(ss.TrnEpcLog)
457 | if ss.ViewOn && ss.TrainUpdt > leabra.AlphaCycle {
458 | ss.UpdateView(true)
459 | }
460 | if ss.TestInterval > 0 && epc%ss.TestInterval == 0 { // note: epc is *next* so won't trigger first time
461 | ss.TestAll()
462 | }
463 | if epc >= ss.MaxEpcs || (ss.NZeroStop > 0 && ss.NZero >= ss.NZeroStop) {
464 | // done with training..
465 | ss.RunEnd()
466 | if ss.TrainEnv.Run.Incr() { // we are done!
467 | ss.StopNow = true
468 | return
469 | } else {
470 | ss.NeedsNewRun = true
471 | return
472 | }
473 | }
474 | }
475 |
476 | ss.ApplyInputs(&ss.TrainEnv)
477 | ss.AlphaCyc(true) // train
478 | ss.TrialStats(true) // accumulate
479 | }
480 |
481 | // RunEnd is called at the end of a run -- save weights, record final log, etc here
482 | func (ss *Sim) RunEnd() {
483 | ss.LogRun(ss.RunLog)
484 | if ss.SaveWts {
485 | fnm := ss.WeightsFileName()
486 | fmt.Printf("Saving Weights to: %s\n", fnm)
487 | ss.Net.SaveWtsJSON(gi.FileName(fnm))
488 | }
489 | }
490 |
491 | // NewRun intializes a new run of the model, using the TrainEnv.Run counter
492 | // for the new run value
493 | func (ss *Sim) NewRun() {
494 | run := ss.TrainEnv.Run.Cur
495 | ss.TrainEnv.Init(run)
496 | ss.TestEnv.Init(run)
497 | ss.Time.Reset()
498 | ss.Net.InitWts()
499 | ss.InitStats()
500 | ss.TrnEpcLog.SetNumRows(0)
501 | ss.TstEpcLog.SetNumRows(0)
502 | ss.NeedsNewRun = false
503 | }
504 |
505 | // InitStats initializes all the statistics, especially important for the
506 | // cumulative epoch stats -- called at start of new run
507 | func (ss *Sim) InitStats() {
508 | // accumulators
509 | ss.SumErr = 0
510 | ss.SumSSE = 0
511 | ss.SumAvgSSE = 0
512 | ss.SumCosDiff = 0
513 | ss.FirstZero = -1
514 | ss.NZero = 0
515 | // clear rest just to make Sim look initialized
516 | ss.TrlErr = 0
517 | ss.TrlSSE = 0
518 | ss.TrlAvgSSE = 0
519 | ss.EpcSSE = 0
520 | ss.EpcAvgSSE = 0
521 | ss.EpcPctErr = 0
522 | ss.EpcCosDiff = 0
523 | }
524 |
525 | // TrialStats computes the trial-level statistics and adds them to the epoch accumulators if
526 | // accum is true. Note that we're accumulating stats here on the Sim side so the
527 | // core algorithm side remains as simple as possible, and doesn't need to worry about
528 | // different time-scales over which stats could be accumulated etc.
529 | // You can also aggregate directly from log data, as is done for testing stats
530 | func (ss *Sim) TrialStats(accum bool) {
531 | out := ss.Net.LayerByName("Output").(leabra.LeabraLayer).AsLeabra()
532 | ss.TrlCosDiff = float64(out.CosDiff.Cos)
533 | ss.TrlSSE, ss.TrlAvgSSE = out.MSE(0.5) // 0.5 = per-unit tolerance -- right side of .5
534 | if ss.TrlSSE > 0 {
535 | ss.TrlErr = 1
536 | } else {
537 | ss.TrlErr = 0
538 | }
539 | if accum {
540 | ss.SumErr += ss.TrlErr
541 | ss.SumSSE += ss.TrlSSE
542 | ss.SumAvgSSE += ss.TrlAvgSSE
543 | ss.SumCosDiff += ss.TrlCosDiff
544 | }
545 | }
546 |
547 | // TrainEpoch runs training trials for remainder of this epoch
548 | func (ss *Sim) TrainEpoch() {
549 | ss.StopNow = false
550 | curEpc := ss.TrainEnv.Epoch.Cur
551 | for {
552 | ss.TrainTrial()
553 | if ss.StopNow || ss.TrainEnv.Epoch.Cur != curEpc {
554 | break
555 | }
556 | }
557 | ss.Stopped()
558 | }
559 |
560 | // TrainRun runs training trials for remainder of run
561 | func (ss *Sim) TrainRun() {
562 | ss.StopNow = false
563 | curRun := ss.TrainEnv.Run.Cur
564 | for {
565 | ss.TrainTrial()
566 | if ss.StopNow || ss.TrainEnv.Run.Cur != curRun {
567 | break
568 | }
569 | }
570 | ss.Stopped()
571 | }
572 |
573 | // Train runs the full training from this point onward
574 | func (ss *Sim) Train() {
575 | ss.StopNow = false
576 | for {
577 | ss.TrainTrial()
578 | if ss.StopNow {
579 | break
580 | }
581 | }
582 | ss.Stopped()
583 | }
584 |
585 | // Stop tells the sim to stop running
586 | func (ss *Sim) Stop() {
587 | ss.StopNow = true
588 | }
589 |
590 | // Stopped is called when a run method stops running -- updates the IsRunning flag and toolbar
591 | func (ss *Sim) Stopped() {
592 | ss.IsRunning = false
593 | if ss.Win != nil {
594 | vp := ss.Win.WinViewport2D()
595 | if ss.ToolBar != nil {
596 | ss.ToolBar.UpdateActions()
597 | }
598 | vp.SetNeedsFullRender()
599 | }
600 | }
601 |
602 | // SaveWeights saves the network weights -- when called with giv.CallMethod
603 | // it will auto-prompt for filename
604 | func (ss *Sim) SaveWeights(filename gi.FileName) {
605 | ss.Net.SaveWtsJSON(filename)
606 | }
607 |
608 | ////////////////////////////////////////////////////////////////////////////////////////////
609 | // Testing
610 |
611 | // TestTrial runs one trial of testing -- always sequentially presented inputs
612 | func (ss *Sim) TestTrial(returnOnChg bool) {
613 | ss.TestEnv.Step()
614 |
615 | // Query counters FIRST
616 | _, _, chg := ss.TestEnv.Counter(env.Epoch)
617 | if chg {
618 | if ss.ViewOn && ss.TestUpdt > leabra.AlphaCycle {
619 | ss.UpdateView(false)
620 | }
621 | ss.LogTstEpc(ss.TstEpcLog)
622 | if returnOnChg {
623 | return
624 | }
625 | }
626 |
627 | ss.ApplyInputs(&ss.TestEnv)
628 | ss.AlphaCyc(false) // !train
629 | ss.TrialStats(false) // !accumulate
630 | ss.LogTstTrl(ss.TstTrlLog)
631 | }
632 |
633 | // TestItem tests given item which is at given index in test item list
634 | func (ss *Sim) TestItem(idx int) {
635 | cur := ss.TestEnv.Trial.Cur
636 | ss.TestEnv.Trial.Cur = idx
637 | ss.TestEnv.SetTrialName()
638 | ss.ApplyInputs(&ss.TestEnv)
639 | ss.AlphaCyc(false) // !train
640 | ss.TrialStats(false) // !accumulate
641 | ss.TestEnv.Trial.Cur = cur
642 | }
643 |
644 | // TestAll runs through the full set of testing items
645 | func (ss *Sim) TestAll() {
646 | ss.TestEnv.Init(ss.TrainEnv.Run.Cur)
647 | for {
648 | ss.TestTrial(true) // return on change -- don't wrap
649 | _, _, chg := ss.TestEnv.Counter(env.Epoch)
650 | if chg || ss.StopNow {
651 | break
652 | }
653 | }
654 | }
655 |
656 | // RunTestAll runs through the full set of testing items, has stop running = false at end -- for gui
657 | func (ss *Sim) RunTestAll() {
658 | ss.StopNow = false
659 | ss.TestAll()
660 | ss.Stopped()
661 | }
662 |
663 | /////////////////////////////////////////////////////////////////////////
664 | // Params setting
665 |
666 | // ParamsName returns name of current set of parameters
667 | func (ss *Sim) ParamsName() string {
668 | if ss.ParamSet == "" {
669 | return "Base"
670 | }
671 | return ss.ParamSet
672 | }
673 |
674 | // SetParams sets the params for "Base" and then current ParamSet.
675 | // If sheet is empty, then it applies all avail sheets (e.g., Network, Sim)
676 | // otherwise just the named sheet
677 | // if setMsg = true then we output a message for each param that was set.
678 | func (ss *Sim) SetParams(sheet string, setMsg bool) error {
679 | if sheet == "" {
680 | // this is important for catching typos and ensuring that all sheets can be used
681 | ss.Params.ValidateSheets([]string{"Network", "Sim"})
682 | }
683 | err := ss.SetParamsSet("Base", sheet, setMsg)
684 | if ss.ParamSet != "" && ss.ParamSet != "Base" {
685 | sps := strings.Fields(ss.ParamSet)
686 | for _, ps := range sps {
687 | err = ss.SetParamsSet(ps, sheet, setMsg)
688 | }
689 | }
690 | return err
691 | }
692 |
693 | // SetParamsSet sets the params for given params.Set name.
694 | // If sheet is empty, then it applies all avail sheets (e.g., Network, Sim)
695 | // otherwise just the named sheet
696 | // if setMsg = true then we output a message for each param that was set.
697 | func (ss *Sim) SetParamsSet(setNm string, sheet string, setMsg bool) error {
698 | pset, err := ss.Params.SetByNameTry(setNm)
699 | if err != nil {
700 | return err
701 | }
702 | if sheet == "" || sheet == "Network" {
703 | netp, ok := pset.Sheets["Network"]
704 | if ok {
705 | ss.Net.ApplyParams(netp, setMsg)
706 | }
707 | }
708 |
709 | if sheet == "" || sheet == "Sim" {
710 | simp, ok := pset.Sheets["Sim"]
711 | if ok {
712 | simp.Apply(ss, setMsg)
713 | }
714 | }
715 | // note: if you have more complex environments with parameters, definitely add
716 | // sheets for them, e.g., "TrainEnv", "TestEnv" etc
717 | return err
718 | }
719 |
720 | func (ss *Sim) ConfigPats() {
721 | dt := ss.Pats
722 | dt.SetMetaData("name", "TrainPats")
723 | dt.SetMetaData("desc", "Training patterns")
724 | dt.SetFromSchema(etable.Schema{
725 | {"Name", etensor.STRING, nil, nil},
726 | {"Input", etensor.FLOAT32, []int{5, 5}, []string{"Y", "X"}},
727 | {"Output", etensor.FLOAT32, []int{5, 5}, []string{"Y", "X"}},
728 | }, 25)
729 |
730 | patgen.PermutedBinaryRows(dt.Cols[1], 6, 1, 0)
731 | patgen.PermutedBinaryRows(dt.Cols[2], 6, 1, 0)
732 | dt.SaveCSV("random_5x5_25_gen.csv", etable.Comma, etable.Headers)
733 | }
734 |
735 | func (ss *Sim) OpenPats() {
736 | dt := ss.Pats
737 | dt.SetMetaData("name", "TrainPats")
738 | dt.SetMetaData("desc", "Training patterns")
739 | err := dt.OpenCSV("random_5x5_25.tsv", etable.Tab)
740 | if err != nil {
741 | log.Println(err)
742 | }
743 | }
744 |
745 | ////////////////////////////////////////////////////////////////////////////////////////////
746 | // Logging
747 |
748 | // ValsTsr gets value tensor of given name, creating if not yet made
749 | func (ss *Sim) ValsTsr(name string) *etensor.Float32 {
750 | if ss.ValsTsrs == nil {
751 | ss.ValsTsrs = make(map[string]*etensor.Float32)
752 | }
753 | tsr, ok := ss.ValsTsrs[name]
754 | if !ok {
755 | tsr = &etensor.Float32{}
756 | ss.ValsTsrs[name] = tsr
757 | }
758 | return tsr
759 | }
760 |
761 | // RunName returns a name for this run that combines Tag and Params -- add this to
762 | // any file names that are saved.
763 | func (ss *Sim) RunName() string {
764 | if ss.Tag != "" {
765 | return ss.Tag + "_" + ss.ParamsName()
766 | } else {
767 | return ss.ParamsName()
768 | }
769 | }
770 |
771 | // RunEpochName returns a string with the run and epoch numbers with leading zeros, suitable
772 | // for using in weights file names. Uses 3, 5 digits for each.
773 | func (ss *Sim) RunEpochName(run, epc int) string {
774 | return fmt.Sprintf("%03d_%05d", run, epc)
775 | }
776 |
777 | // WeightsFileName returns default current weights file name
778 | func (ss *Sim) WeightsFileName() string {
779 | return ss.Net.Nm + "_" + ss.RunName() + "_" + ss.RunEpochName(ss.TrainEnv.Run.Cur, ss.TrainEnv.Epoch.Cur) + ".wts"
780 | }
781 |
782 | // LogFileName returns default log file name
783 | func (ss *Sim) LogFileName(lognm string) string {
784 | return ss.Net.Nm + "_" + ss.RunName() + "_" + lognm + ".tsv"
785 | }
786 |
787 | //////////////////////////////////////////////
788 | // TrnEpcLog
789 |
790 | // LogTrnEpc adds data from current epoch to the TrnEpcLog table.
791 | // computes epoch averages prior to logging.
792 | func (ss *Sim) LogTrnEpc(dt *etable.Table) {
793 | row := dt.Rows
794 | dt.SetNumRows(row + 1)
795 |
796 | epc := ss.TrainEnv.Epoch.Prv // this is triggered by increment so use previous value
797 | nt := float64(len(ss.TrainEnv.Order)) // number of trials in view
798 |
799 | ss.EpcSSE = ss.SumSSE / nt
800 | ss.SumSSE = 0
801 | ss.EpcAvgSSE = ss.SumAvgSSE / nt
802 | ss.SumAvgSSE = 0
803 | ss.EpcPctErr = float64(ss.SumErr) / nt
804 | ss.SumErr = 0
805 | ss.EpcPctCor = 1 - ss.EpcPctErr
806 | ss.EpcCosDiff = ss.SumCosDiff / nt
807 | ss.SumCosDiff = 0
808 | if ss.FirstZero < 0 && ss.EpcPctErr == 0 {
809 | ss.FirstZero = epc
810 | }
811 | if ss.EpcPctErr == 0 {
812 | ss.NZero++
813 | } else {
814 | ss.NZero = 0
815 | }
816 |
817 | if ss.LastEpcTime.IsZero() {
818 | ss.EpcPerTrlMSec = 0
819 | } else {
820 | iv := time.Now().Sub(ss.LastEpcTime)
821 | ss.EpcPerTrlMSec = float64(iv) / (nt * float64(time.Millisecond))
822 | }
823 | ss.LastEpcTime = time.Now()
824 |
825 | dt.SetCellFloat("Run", row, float64(ss.TrainEnv.Run.Cur))
826 | dt.SetCellFloat("Epoch", row, float64(epc))
827 | dt.SetCellFloat("SSE", row, ss.EpcSSE)
828 | dt.SetCellFloat("AvgSSE", row, ss.EpcAvgSSE)
829 | dt.SetCellFloat("PctErr", row, ss.EpcPctErr)
830 | dt.SetCellFloat("PctCor", row, ss.EpcPctCor)
831 | dt.SetCellFloat("CosDiff", row, ss.EpcCosDiff)
832 | dt.SetCellFloat("PerTrlMSec", row, ss.EpcPerTrlMSec)
833 |
834 | for _, lnm := range ss.LayStatNms {
835 | ly := ss.Net.LayerByName(lnm).(leabra.LeabraLayer).AsLeabra()
836 | dt.SetCellFloat(ly.Nm+"_ActAvg", row, float64(ly.Pools[0].ActAvg.ActPAvgEff))
837 | }
838 |
839 | // note: essential to use Go version of update when called from another goroutine
840 | ss.TrnEpcPlot.GoUpdate()
841 | if ss.TrnEpcFile != nil {
842 | if ss.TrainEnv.Run.Cur == 0 && epc == 0 {
843 | dt.WriteCSVHeaders(ss.TrnEpcFile, etable.Tab)
844 | }
845 | dt.WriteCSVRow(ss.TrnEpcFile, row, etable.Tab)
846 | }
847 | }
848 |
849 | func (ss *Sim) ConfigTrnEpcLog(dt *etable.Table) {
850 | dt.SetMetaData("name", "TrnEpcLog")
851 | dt.SetMetaData("desc", "Record of performance over epochs of training")
852 | dt.SetMetaData("read-only", "true")
853 | dt.SetMetaData("precision", strconv.Itoa(LogPrec))
854 |
855 | sch := etable.Schema{
856 | {"Run", etensor.INT64, nil, nil},
857 | {"Epoch", etensor.INT64, nil, nil},
858 | {"SSE", etensor.FLOAT64, nil, nil},
859 | {"AvgSSE", etensor.FLOAT64, nil, nil},
860 | {"PctErr", etensor.FLOAT64, nil, nil},
861 | {"PctCor", etensor.FLOAT64, nil, nil},
862 | {"CosDiff", etensor.FLOAT64, nil, nil},
863 | {"PerTrlMSec", etensor.FLOAT64, nil, nil},
864 | }
865 | for _, lnm := range ss.LayStatNms {
866 | sch = append(sch, etable.Column{lnm + "_ActAvg", etensor.FLOAT64, nil, nil})
867 | }
868 | dt.SetFromSchema(sch, 0)
869 | }
870 |
871 | func (ss *Sim) ConfigTrnEpcPlot(plt *eplot.Plot2D, dt *etable.Table) *eplot.Plot2D {
872 | plt.Params.Title = "Leabra Random Associator 25 Epoch Plot"
873 | plt.Params.XAxisCol = "Epoch"
874 | plt.SetTable(dt)
875 | // order of params: on, fixMin, min, fixMax, max
876 | plt.SetColParams("Run", eplot.Off, eplot.FixMin, 0, eplot.FloatMax, 0)
877 | plt.SetColParams("Epoch", eplot.Off, eplot.FixMin, 0, eplot.FloatMax, 0)
878 | plt.SetColParams("SSE", eplot.Off, eplot.FixMin, 0, eplot.FloatMax, 0)
879 | plt.SetColParams("AvgSSE", eplot.Off, eplot.FixMin, 0, eplot.FloatMax, 0)
880 | plt.SetColParams("PctErr", eplot.On, eplot.FixMin, 0, eplot.FixMax, 1) // default plot
881 | plt.SetColParams("PctCor", eplot.On, eplot.FixMin, 0, eplot.FixMax, 1) // default plot
882 | plt.SetColParams("CosDiff", eplot.Off, eplot.FixMin, 0, eplot.FixMax, 1)
883 | plt.SetColParams("PerTrlMSec", eplot.Off, eplot.FixMin, 0, eplot.FloatMax, 0)
884 |
885 | for _, lnm := range ss.LayStatNms {
886 | plt.SetColParams(lnm+"_ActAvg", eplot.Off, eplot.FixMin, 0, eplot.FixMax, .5)
887 | }
888 | return plt
889 | }
890 |
891 | //////////////////////////////////////////////
892 | // TstTrlLog
893 |
894 | // LogTstTrl adds data from current trial to the TstTrlLog table.
895 | // log always contains number of testing items
896 | func (ss *Sim) LogTstTrl(dt *etable.Table) {
897 | epc := ss.TrainEnv.Epoch.Prv // this is triggered by increment so use previous value
898 | inp := ss.Net.LayerByName("Input").(leabra.LeabraLayer).AsLeabra()
899 | out := ss.Net.LayerByName("Output").(leabra.LeabraLayer).AsLeabra()
900 |
901 | trl := ss.TestEnv.Trial.Cur
902 | row := trl
903 |
904 | if dt.Rows <= row {
905 | dt.SetNumRows(row + 1)
906 | }
907 |
908 | dt.SetCellFloat("Run", row, float64(ss.TrainEnv.Run.Cur))
909 | dt.SetCellFloat("Epoch", row, float64(epc))
910 | dt.SetCellFloat("Trial", row, float64(trl))
911 | dt.SetCellString("TrialName", row, ss.TestEnv.TrialName.Cur)
912 | dt.SetCellFloat("Err", row, ss.TrlErr)
913 | dt.SetCellFloat("SSE", row, ss.TrlSSE)
914 | dt.SetCellFloat("AvgSSE", row, ss.TrlAvgSSE)
915 | dt.SetCellFloat("CosDiff", row, ss.TrlCosDiff)
916 |
917 | for _, lnm := range ss.LayStatNms {
918 | ly := ss.Net.LayerByName(lnm).(leabra.LeabraLayer).AsLeabra()
919 | dt.SetCellFloat(ly.Nm+" ActM.Avg", row, float64(ly.Pools[0].ActM.Avg))
920 | }
921 | ivt := ss.ValsTsr("Input")
922 | ovt := ss.ValsTsr("Output")
923 | inp.UnitValsTensor(ivt, "Act")
924 | dt.SetCellTensor("InAct", row, ivt)
925 | out.UnitValsTensor(ovt, "ActM")
926 | dt.SetCellTensor("OutActM", row, ovt)
927 | out.UnitValsTensor(ovt, "ActP")
928 | dt.SetCellTensor("OutActP", row, ovt)
929 |
930 | // note: essential to use Go version of update when called from another goroutine
931 | ss.TstTrlPlot.GoUpdate()
932 | }
933 |
934 | func (ss *Sim) ConfigTstTrlLog(dt *etable.Table) {
935 | inp := ss.Net.LayerByName("Input").(leabra.LeabraLayer).AsLeabra()
936 | out := ss.Net.LayerByName("Output").(leabra.LeabraLayer).AsLeabra()
937 |
938 | dt.SetMetaData("name", "TstTrlLog")
939 | dt.SetMetaData("desc", "Record of testing per input pattern")
940 | dt.SetMetaData("read-only", "true")
941 | dt.SetMetaData("precision", strconv.Itoa(LogPrec))
942 |
943 | nt := ss.TestEnv.Table.Len() // number in view
944 | sch := etable.Schema{
945 | {"Run", etensor.INT64, nil, nil},
946 | {"Epoch", etensor.INT64, nil, nil},
947 | {"Trial", etensor.INT64, nil, nil},
948 | {"TrialName", etensor.STRING, nil, nil},
949 | {"Err", etensor.FLOAT64, nil, nil},
950 | {"SSE", etensor.FLOAT64, nil, nil},
951 | {"AvgSSE", etensor.FLOAT64, nil, nil},
952 | {"CosDiff", etensor.FLOAT64, nil, nil},
953 | }
954 | for _, lnm := range ss.LayStatNms {
955 | sch = append(sch, etable.Column{lnm + " ActM.Avg", etensor.FLOAT64, nil, nil})
956 | }
957 | sch = append(sch, etable.Schema{
958 | {"InAct", etensor.FLOAT64, inp.Shp.Shp, nil},
959 | {"OutActM", etensor.FLOAT64, out.Shp.Shp, nil},
960 | {"OutActP", etensor.FLOAT64, out.Shp.Shp, nil},
961 | }...)
962 | dt.SetFromSchema(sch, nt)
963 | }
964 |
965 | func (ss *Sim) ConfigTstTrlPlot(plt *eplot.Plot2D, dt *etable.Table) *eplot.Plot2D {
966 | plt.Params.Title = "Leabra Random Associator 25 Test Trial Plot"
967 | plt.Params.XAxisCol = "Trial"
968 | plt.SetTable(dt)
969 | // order of params: on, fixMin, min, fixMax, max
970 | plt.SetColParams("Run", eplot.Off, eplot.FixMin, 0, eplot.FloatMax, 0)
971 | plt.SetColParams("Epoch", eplot.Off, eplot.FixMin, 0, eplot.FloatMax, 0)
972 | plt.SetColParams("Trial", eplot.Off, eplot.FixMin, 0, eplot.FloatMax, 0)
973 | plt.SetColParams("TrialName", eplot.Off, eplot.FixMin, 0, eplot.FloatMax, 0)
974 | plt.SetColParams("Err", eplot.Off, eplot.FixMin, 0, eplot.FloatMax, 0)
975 | plt.SetColParams("SSE", eplot.Off, eplot.FixMin, 0, eplot.FloatMax, 0)
976 | plt.SetColParams("AvgSSE", eplot.On, eplot.FixMin, 0, eplot.FloatMax, 0)
977 | plt.SetColParams("CosDiff", eplot.On, eplot.FixMin, 0, eplot.FixMax, 1)
978 |
979 | for _, lnm := range ss.LayStatNms {
980 | plt.SetColParams(lnm+" ActM.Avg", eplot.Off, eplot.FixMin, 0, eplot.FixMax, .5)
981 | }
982 |
983 | plt.SetColParams("InAct", eplot.Off, eplot.FixMin, 0, eplot.FixMax, 1)
984 | plt.SetColParams("OutActM", eplot.Off, eplot.FixMin, 0, eplot.FixMax, 1)
985 | plt.SetColParams("OutActP", eplot.Off, eplot.FixMin, 0, eplot.FixMax, 1)
986 | return plt
987 | }
988 |
989 | //////////////////////////////////////////////
990 | // TstEpcLog
991 |
992 | func (ss *Sim) LogTstEpc(dt *etable.Table) {
993 | row := dt.Rows
994 | dt.SetNumRows(row + 1)
995 |
996 | trl := ss.TstTrlLog
997 | tix := etable.NewIdxView(trl)
998 | epc := ss.TrainEnv.Epoch.Prv // ?
999 |
1000 | // note: this shows how to use agg methods to compute summary data from another
1001 | // data table, instead of incrementing on the Sim
1002 | dt.SetCellFloat("Run", row, float64(ss.TrainEnv.Run.Cur))
1003 | dt.SetCellFloat("Epoch", row, float64(epc))
1004 | dt.SetCellFloat("SSE", row, agg.Sum(tix, "SSE")[0])
1005 | dt.SetCellFloat("AvgSSE", row, agg.Mean(tix, "AvgSSE")[0])
1006 | dt.SetCellFloat("PctErr", row, agg.Mean(tix, "Err")[0])
1007 | dt.SetCellFloat("PctCor", row, 1-agg.Mean(tix, "Err")[0])
1008 | dt.SetCellFloat("CosDiff", row, agg.Mean(tix, "CosDiff")[0])
1009 |
1010 | trlix := etable.NewIdxView(trl)
1011 | trlix.Filter(func(et *etable.Table, row int) bool {
1012 | return et.CellFloat("SSE", row) > 0 // include error trials
1013 | })
1014 | ss.TstErrLog = trlix.NewTable()
1015 |
1016 | allsp := split.All(trlix)
1017 | split.Agg(allsp, "SSE", agg.AggSum)
1018 | split.Agg(allsp, "AvgSSE", agg.AggMean)
1019 | split.Agg(allsp, "InAct", agg.AggMean)
1020 | split.Agg(allsp, "OutActM", agg.AggMean)
1021 | split.Agg(allsp, "OutActP", agg.AggMean)
1022 |
1023 | ss.TstErrStats = allsp.AggsToTable(etable.AddAggName)
1024 |
1025 | // note: essential to use Go version of update when called from another goroutine
1026 | ss.TstEpcPlot.GoUpdate()
1027 | }
1028 |
1029 | func (ss *Sim) ConfigTstEpcLog(dt *etable.Table) {
1030 | dt.SetMetaData("name", "TstEpcLog")
1031 | dt.SetMetaData("desc", "Summary stats for testing trials")
1032 | dt.SetMetaData("read-only", "true")
1033 | dt.SetMetaData("precision", strconv.Itoa(LogPrec))
1034 |
1035 | dt.SetFromSchema(etable.Schema{
1036 | {"Run", etensor.INT64, nil, nil},
1037 | {"Epoch", etensor.INT64, nil, nil},
1038 | {"SSE", etensor.FLOAT64, nil, nil},
1039 | {"AvgSSE", etensor.FLOAT64, nil, nil},
1040 | {"PctErr", etensor.FLOAT64, nil, nil},
1041 | {"PctCor", etensor.FLOAT64, nil, nil},
1042 | {"CosDiff", etensor.FLOAT64, nil, nil},
1043 | }, 0)
1044 | }
1045 |
1046 | func (ss *Sim) ConfigTstEpcPlot(plt *eplot.Plot2D, dt *etable.Table) *eplot.Plot2D {
1047 | plt.Params.Title = "Leabra Random Associator 25 Testing Epoch Plot"
1048 | plt.Params.XAxisCol = "Epoch"
1049 | plt.SetTable(dt)
1050 | // order of params: on, fixMin, min, fixMax, max
1051 | plt.SetColParams("Run", eplot.Off, eplot.FixMin, 0, eplot.FloatMax, 0)
1052 | plt.SetColParams("Epoch", eplot.Off, eplot.FixMin, 0, eplot.FloatMax, 0)
1053 | plt.SetColParams("SSE", eplot.Off, eplot.FixMin, 0, eplot.FloatMax, 0)
1054 | plt.SetColParams("AvgSSE", eplot.Off, eplot.FixMin, 0, eplot.FloatMax, 0)
1055 | plt.SetColParams("PctErr", eplot.On, eplot.FixMin, 0, eplot.FixMax, 1) // default plot
1056 | plt.SetColParams("PctCor", eplot.On, eplot.FixMin, 0, eplot.FixMax, 1) // default plot
1057 | plt.SetColParams("CosDiff", eplot.Off, eplot.FixMin, 0, eplot.FixMax, 1)
1058 | return plt
1059 | }
1060 |
1061 | //////////////////////////////////////////////
1062 | // TstCycLog
1063 |
1064 | // LogTstCyc adds data from current trial to the TstCycLog table.
1065 | // log just has 100 cycles, is overwritten
1066 | func (ss *Sim) LogTstCyc(dt *etable.Table, cyc int) {
1067 | if dt.Rows <= cyc {
1068 | dt.SetNumRows(cyc + 1)
1069 | }
1070 |
1071 | dt.SetCellFloat("Cycle", cyc, float64(cyc))
1072 | for _, lnm := range ss.LayStatNms {
1073 | ly := ss.Net.LayerByName(lnm).(leabra.LeabraLayer).AsLeabra()
1074 | dt.SetCellFloat(ly.Nm+" Ge.Avg", cyc, float64(ly.Pools[0].Inhib.Ge.Avg))
1075 | dt.SetCellFloat(ly.Nm+" Act.Avg", cyc, float64(ly.Pools[0].Inhib.Act.Avg))
1076 | }
1077 |
1078 | if cyc%10 == 0 { // too slow to do every cyc
1079 | // note: essential to use Go version of update when called from another goroutine
1080 | ss.TstCycPlot.GoUpdate()
1081 | }
1082 | }
1083 |
1084 | func (ss *Sim) ConfigTstCycLog(dt *etable.Table) {
1085 | dt.SetMetaData("name", "TstCycLog")
1086 | dt.SetMetaData("desc", "Record of activity etc over one trial by cycle")
1087 | dt.SetMetaData("read-only", "true")
1088 | dt.SetMetaData("precision", strconv.Itoa(LogPrec))
1089 |
1090 | np := 100 // max cycles
1091 | sch := etable.Schema{
1092 | {"Cycle", etensor.INT64, nil, nil},
1093 | }
1094 | for _, lnm := range ss.LayStatNms {
1095 | sch = append(sch, etable.Column{lnm + " Ge.Avg", etensor.FLOAT64, nil, nil})
1096 | sch = append(sch, etable.Column{lnm + " Act.Avg", etensor.FLOAT64, nil, nil})
1097 | }
1098 | dt.SetFromSchema(sch, np)
1099 | }
1100 |
1101 | func (ss *Sim) ConfigTstCycPlot(plt *eplot.Plot2D, dt *etable.Table) *eplot.Plot2D {
1102 | plt.Params.Title = "Leabra Random Associator 25 Test Cycle Plot"
1103 | plt.Params.XAxisCol = "Cycle"
1104 | plt.SetTable(dt)
1105 | // order of params: on, fixMin, min, fixMax, max
1106 | plt.SetColParams("Cycle", eplot.Off, eplot.FixMin, 0, eplot.FloatMax, 0)
1107 | for _, lnm := range ss.LayStatNms {
1108 | plt.SetColParams(lnm+" Ge.Avg", true, true, 0, true, .5)
1109 | plt.SetColParams(lnm+" Act.Avg", true, true, 0, true, .5)
1110 | }
1111 | return plt
1112 | }
1113 |
1114 | //////////////////////////////////////////////
1115 | // RunLog
1116 |
1117 | // LogRun adds data from current run to the RunLog table.
1118 | func (ss *Sim) LogRun(dt *etable.Table) {
1119 | run := ss.TrainEnv.Run.Cur // this is NOT triggered by increment yet -- use Cur
1120 | row := dt.Rows
1121 | dt.SetNumRows(row + 1)
1122 |
1123 | epclog := ss.TrnEpcLog
1124 | epcix := etable.NewIdxView(epclog)
1125 | // compute mean over last N epochs for run level
1126 | nlast := 5
1127 | if nlast > epcix.Len()-1 {
1128 | nlast = epcix.Len() - 1
1129 | }
1130 | epcix.Idxs = epcix.Idxs[epcix.Len()-nlast:]
1131 |
1132 | params := ss.RunName() // includes tag
1133 |
1134 | dt.SetCellFloat("Run", row, float64(run))
1135 | dt.SetCellString("Params", row, params)
1136 | dt.SetCellFloat("FirstZero", row, float64(ss.FirstZero))
1137 | dt.SetCellFloat("SSE", row, agg.Mean(epcix, "SSE")[0])
1138 | dt.SetCellFloat("AvgSSE", row, agg.Mean(epcix, "AvgSSE")[0])
1139 | dt.SetCellFloat("PctErr", row, agg.Mean(epcix, "PctErr")[0])
1140 | dt.SetCellFloat("PctCor", row, agg.Mean(epcix, "PctCor")[0])
1141 | dt.SetCellFloat("CosDiff", row, agg.Mean(epcix, "CosDiff")[0])
1142 |
1143 | runix := etable.NewIdxView(dt)
1144 | spl := split.GroupBy(runix, []string{"Params"})
1145 | split.Desc(spl, "FirstZero")
1146 | split.Desc(spl, "PctCor")
1147 | ss.RunStats = spl.AggsToTable(etable.AddAggName)
1148 |
1149 | // note: essential to use Go version of update when called from another goroutine
1150 | ss.RunPlot.GoUpdate()
1151 | if ss.RunFile != nil {
1152 | if row == 0 {
1153 | dt.WriteCSVHeaders(ss.RunFile, etable.Tab)
1154 | }
1155 | dt.WriteCSVRow(ss.RunFile, row, etable.Tab)
1156 | }
1157 | }
1158 |
1159 | func (ss *Sim) ConfigRunLog(dt *etable.Table) {
1160 | dt.SetMetaData("name", "RunLog")
1161 | dt.SetMetaData("desc", "Record of performance at end of training")
1162 | dt.SetMetaData("read-only", "true")
1163 | dt.SetMetaData("precision", strconv.Itoa(LogPrec))
1164 |
1165 | dt.SetFromSchema(etable.Schema{
1166 | {"Run", etensor.INT64, nil, nil},
1167 | {"Params", etensor.STRING, nil, nil},
1168 | {"FirstZero", etensor.FLOAT64, nil, nil},
1169 | {"SSE", etensor.FLOAT64, nil, nil},
1170 | {"AvgSSE", etensor.FLOAT64, nil, nil},
1171 | {"PctErr", etensor.FLOAT64, nil, nil},
1172 | {"PctCor", etensor.FLOAT64, nil, nil},
1173 | {"CosDiff", etensor.FLOAT64, nil, nil},
1174 | }, 0)
1175 | }
1176 |
1177 | func (ss *Sim) ConfigRunPlot(plt *eplot.Plot2D, dt *etable.Table) *eplot.Plot2D {
1178 | plt.Params.Title = "Leabra Random Associator 25 Run Plot"
1179 | plt.Params.XAxisCol = "Run"
1180 | plt.SetTable(dt)
1181 | // order of params: on, fixMin, min, fixMax, max
1182 | plt.SetColParams("Run", eplot.Off, eplot.FixMin, 0, eplot.FloatMax, 0)
1183 | plt.SetColParams("FirstZero", eplot.On, eplot.FixMin, 0, eplot.FloatMax, 0) // default plot
1184 | plt.SetColParams("SSE", eplot.Off, eplot.FixMin, 0, eplot.FloatMax, 0)
1185 | plt.SetColParams("AvgSSE", eplot.Off, eplot.FixMin, 0, eplot.FloatMax, 0)
1186 | plt.SetColParams("PctErr", eplot.Off, eplot.FixMin, 0, eplot.FixMax, 1)
1187 | plt.SetColParams("PctCor", eplot.Off, eplot.FixMin, 0, eplot.FixMax, 1)
1188 | plt.SetColParams("CosDiff", eplot.Off, eplot.FixMin, 0, eplot.FixMax, 1)
1189 | return plt
1190 | }
1191 |
1192 | ////////////////////////////////////////////////////////////////////////////////////////////
1193 | // Gui
1194 |
// ConfigGui builds and returns the main GoGi window for this simulation.
//
// Window layout:
//   - a toolbar (stored in ss.ToolBar);
//   - a horizontal split view with a StructView of the Sim on the left;
//   - a tab view on the right holding the NetView (ss.NetView) and the
//     TrnEpc, TstTrl, TstCyc, TstEpc and Run plots.
//
// Toolbar actions drive training and testing. Train, Step Epoch, Step Run and
// Test All launch their work in a separate goroutine. Step Trial, Test Trial
// and Test Item run synchronously inside the callback.
//
// It also configures the main menu, and confirmation prompts for quitting the
// app and closing the window.
func (ss *Sim) ConfigGui() *gi.Window {
	width := 1600
	height := 1200

	gi.SetAppName("ra25")
	gi.SetAppAbout(`This demonstrates a basic Leabra model. See emergent on GitHub.
`)

	win := gi.NewMainWindow("ra25", "Leabra Random Associator", width, height)
	ss.Win = win

	vp := win.WinViewport2D()
	// Widget construction below is bracketed by UpdateStart / UpdateEndNoSig
	// (after the toolbar is built), batching it into one viewport update.
	updt := vp.UpdateStart()

	mfr := win.SetMainFrame()

	tbar := gi.AddNewToolBar(mfr, "tbar")
	tbar.SetStretchMaxWidth()
	ss.ToolBar = tbar

	// Left/right split: Sim struct view | tab view.
	// NOTE: this local `split` shadows the split package for the rest of
	// this function (the package is not used here).
	split := gi.AddNewSplitView(mfr, "split")
	split.Dim = mat32.X
	split.SetStretchMax()

	sv := giv.AddNewStructView(split, "sv")
	sv.SetStruct(ss)

	tv := gi.AddNewTabView(split, "tv")

	nv := tv.AddNewTab(netview.KiT_NetView, "NetView").(*netview.NetView)
	nv.Var = "Act"
	// nv.Params.ColorMap = "Jet" // default is ColdHot
	// which fares pretty well in terms of discussion here:
	// https://matplotlib.org/tutorials/colors/colormaps.html
	nv.SetNet(ss.Net)
	ss.NetView = nv

	nv.Scene().Camera.Pose.Pos.Set(0, 1, 2.75) // more "head on" than default which is more "top down"
	nv.Scene().Camera.LookAt(mat32.Vec3{0, 0, 0}, mat32.Vec3{0, 1, 0})

	// One tab per log plot; each Config*Plot call binds the plot to its log
	// table, and the returned plot is stored on the Sim.
	plt := tv.AddNewTab(eplot.KiT_Plot2D, "TrnEpcPlot").(*eplot.Plot2D)
	ss.TrnEpcPlot = ss.ConfigTrnEpcPlot(plt, ss.TrnEpcLog)

	plt = tv.AddNewTab(eplot.KiT_Plot2D, "TstTrlPlot").(*eplot.Plot2D)
	ss.TstTrlPlot = ss.ConfigTstTrlPlot(plt, ss.TstTrlLog)

	plt = tv.AddNewTab(eplot.KiT_Plot2D, "TstCycPlot").(*eplot.Plot2D)
	ss.TstCycPlot = ss.ConfigTstCycPlot(plt, ss.TstCycLog)

	plt = tv.AddNewTab(eplot.KiT_Plot2D, "TstEpcPlot").(*eplot.Plot2D)
	ss.TstEpcPlot = ss.ConfigTstEpcPlot(plt, ss.TstEpcLog)

	plt = tv.AddNewTab(eplot.KiT_Plot2D, "RunPlot").(*eplot.Plot2D)
	ss.RunPlot = ss.ConfigRunPlot(plt, ss.RunLog)

	// 30% struct view, 70% tabs
	split.SetSplits(.3, .7)

	// Each action's UpdateFunc enables or disables it based on ss.IsRunning.

	// Init: re-initialize the Sim and redraw.
	tbar.AddAction(gi.ActOpts{Label: "Init", Icon: "update", Tooltip: "Initialize everything including network weights, and start over. Also applies current params.", UpdateFunc: func(act *gi.Action) {
		act.SetActiveStateUpdt(!ss.IsRunning)
	}}, win.This(), func(recv, send ki.Ki, sig int64, data interface{}) {
		ss.Init()
		vp.SetNeedsFullRender()
	})

	// Train: mark running, refresh action states, train in a goroutine.
	tbar.AddAction(gi.ActOpts{Label: "Train", Icon: "run", Tooltip: "Starts the network training, picking up from wherever it may have left off. If not stopped, training will complete the specified number of Runs through the full number of Epochs of training, with testing automatically occuring at the specified interval.",
		UpdateFunc: func(act *gi.Action) {
			act.SetActiveStateUpdt(!ss.IsRunning)
		}}, win.This(), func(recv, send ki.Ki, sig int64, data interface{}) {
		if !ss.IsRunning {
			ss.IsRunning = true
			tbar.UpdateActions()
			// ss.Train()
			go ss.Train()
		}
	})

	// Stop: only enabled while running.
	tbar.AddAction(gi.ActOpts{Label: "Stop", Icon: "stop", Tooltip: "Interrupts running. Hitting Train again will pick back up where it left off.", UpdateFunc: func(act *gi.Action) {
		act.SetActiveStateUpdt(ss.IsRunning)
	}}, win.This(), func(recv, send ki.Ki, sig int64, data interface{}) {
		ss.Stop()
	})

	// Step Trial: one training trial, run synchronously in the callback.
	tbar.AddAction(gi.ActOpts{Label: "Step Trial", Icon: "step-fwd", Tooltip: "Advances one training trial at a time.", UpdateFunc: func(act *gi.Action) {
		act.SetActiveStateUpdt(!ss.IsRunning)
	}}, win.This(), func(recv, send ki.Ki, sig int64, data interface{}) {
		if !ss.IsRunning {
			ss.IsRunning = true
			ss.TrainTrial()
			ss.IsRunning = false
			vp.SetNeedsFullRender()
		}
	})

	// Step Epoch: one training epoch in a goroutine.
	tbar.AddAction(gi.ActOpts{Label: "Step Epoch", Icon: "fast-fwd", Tooltip: "Advances one epoch (complete set of training patterns) at a time.", UpdateFunc: func(act *gi.Action) {
		act.SetActiveStateUpdt(!ss.IsRunning)
	}}, win.This(), func(recv, send ki.Ki, sig int64, data interface{}) {
		if !ss.IsRunning {
			ss.IsRunning = true
			tbar.UpdateActions()
			go ss.TrainEpoch()
		}
	})

	// Step Run: one full training run in a goroutine.
	tbar.AddAction(gi.ActOpts{Label: "Step Run", Icon: "fast-fwd", Tooltip: "Advances one full training Run at a time.", UpdateFunc: func(act *gi.Action) {
		act.SetActiveStateUpdt(!ss.IsRunning)
	}}, win.This(), func(recv, send ki.Ki, sig int64, data interface{}) {
		if !ss.IsRunning {
			ss.IsRunning = true
			tbar.UpdateActions()
			go ss.TrainRun()
		}
	})

	tbar.AddSeparator("test")

	// Test Trial: next testing trial, synchronous; the false argument is
	// explained by the inline comment on the call.
	tbar.AddAction(gi.ActOpts{Label: "Test Trial", Icon: "step-fwd", Tooltip: "Runs the next testing trial.", UpdateFunc: func(act *gi.Action) {
		act.SetActiveStateUpdt(!ss.IsRunning)
	}}, win.This(), func(recv, send ki.Ki, sig int64, data interface{}) {
		if !ss.IsRunning {
			ss.IsRunning = true
			ss.TestTrial(false) // don't return on change -- wrap
			ss.IsRunning = false
			vp.SetNeedsFullRender()
		}
	})

	// Test Item: prompt for a name, find test-table rows whose "Name"
	// contains it (ignoring case), and test the first match synchronously.
	tbar.AddAction(gi.ActOpts{Label: "Test Item", Icon: "step-fwd", Tooltip: "Prompts for a specific input pattern name to run, and runs it in testing mode.", UpdateFunc: func(act *gi.Action) {
		act.SetActiveStateUpdt(!ss.IsRunning)
	}}, win.This(), func(recv, send ki.Ki, sig int64, data interface{}) {
		gi.StringPromptDialog(vp, "", "Test Item",
			gi.DlgOpts{Title: "Test Item", Prompt: "Enter the Name of a given input pattern to test (case insensitive, contains given string."},
			win.This(), func(recv, send ki.Ki, sig int64, data interface{}) {
				dlg := send.(*gi.Dialog)
				if sig == int64(gi.DialogAccepted) {
					val := gi.StringPromptDialogValue(dlg)
					idxs := ss.TestEnv.Table.RowsByString("Name", val, etable.Contains, etable.IgnoreCase)
					if len(idxs) == 0 {
						// NOTE(review): passes a nil viewport, unlike the other
						// dialogs here which use vp -- confirm this is intended.
						gi.PromptDialog(nil, gi.DlgOpts{Title: "Name Not Found", Prompt: "No patterns found containing: " + val}, gi.AddOk, gi.NoCancel, nil, nil)
					} else {
						if !ss.IsRunning {
							ss.IsRunning = true
							fmt.Printf("testing index: %d\n", idxs[0])
							ss.TestItem(idxs[0])
							ss.IsRunning = false
							vp.SetNeedsFullRender()
						}
					}
				}
			})
	})

	// Test All: run the full test set in a goroutine.
	tbar.AddAction(gi.ActOpts{Label: "Test All", Icon: "fast-fwd", Tooltip: "Tests all of the testing trials.", UpdateFunc: func(act *gi.Action) {
		act.SetActiveStateUpdt(!ss.IsRunning)
	}}, win.This(), func(recv, send ki.Ki, sig int64, data interface{}) {
		if !ss.IsRunning {
			ss.IsRunning = true
			tbar.UpdateActions()
			go ss.RunTestAll()
		}
	})

	tbar.AddSeparator("log")

	// Reset RunLog: drop all rows and redraw the run plot (always enabled).
	tbar.AddAction(gi.ActOpts{Label: "Reset RunLog", Icon: "reset", Tooltip: "Reset the accumulated log of all Runs, which are tagged with the ParamSet used"}, win.This(),
		func(recv, send ki.Ki, sig int64, data interface{}) {
			ss.RunLog.SetNumRows(0)
			ss.RunPlot.Update()
		})

	tbar.AddSeparator("misc")

	tbar.AddAction(gi.ActOpts{Label: "New Seed", Icon: "new", Tooltip: "Generate a new initial random seed to get different results. By default, Init re-establishes the same initial seed every time."}, win.This(),
		func(recv, send ki.Ki, sig int64, data interface{}) {
			ss.NewRndSeed()
		})

	tbar.AddAction(gi.ActOpts{Label: "README", Icon: "file-markdown", Tooltip: "Opens your browser on the README file that contains instructions for how to run this model."}, win.This(),
		func(recv, send ki.Ki, sig int64, data interface{}) {
			gi.OpenURL("https://github.com/emer/leabra/blob/master/examples/ra25/README.md")
		})

	// end of the batched update started by UpdateStart above
	vp.UpdateEndNoSig(updt)

	// main menu
	appnm := gi.AppName()
	mmen := win.MainMenu
	mmen.ConfigMenus([]string{appnm, "File", "Edit", "Window"})

	amen := win.MainMenu.ChildByName(appnm, 0).(*gi.Action)
	amen.Menu.AddAppMenu(win)

	emen := win.MainMenu.ChildByName("Edit", 1).(*gi.Action)
	emen.Menu.AddCopyCutPaste(win)

	// note: Command in shortcuts is automatically translated into Control for
	// Linux, Windows or Meta for MacOS
	// fmen := win.MainMenu.ChildByName("File", 0).(*gi.Action)
	// fmen.Menu.AddAction(gi.ActOpts{Label: "Open", Shortcut: "Command+O"},
	// win.This(), func(recv, send ki.Ki, sig int64, data interface{}) {
	// FileViewOpenSVG(vp)
	// })
	// fmen.Menu.AddSeparator("csep")
	// fmen.Menu.AddAction(gi.ActOpts{Label: "Close Window", Shortcut: "Command+W"},
	// win.This(), func(recv, send ki.Ki, sig int64, data interface{}) {
	// win.Close()
	// })

	// Quit confirmation. inQuitPrompt prevents stacking a second prompt
	// while one is already showing; it is cleared when the user cancels.
	inQuitPrompt := false
	gi.SetQuitReqFunc(func() {
		if inQuitPrompt {
			return
		}
		inQuitPrompt = true
		gi.PromptDialog(vp, gi.DlgOpts{Title: "Really Quit?",
			Prompt: "Are you sure you want to quit and lose any unsaved params, weights, logs, etc?"}, gi.AddOk, gi.AddCancel,
			win.This(), func(recv, send ki.Ki, sig int64, data interface{}) {
				if sig == int64(gi.DialogAccepted) {
					gi.Quit()
				} else {
					inQuitPrompt = false
				}
			})
	})

	// gi.SetQuitCleanFunc(func() {
	// fmt.Printf("Doing final Quit cleanup here..\n")
	// })

	// Window-close confirmation, with the same re-entrancy guard; accepting
	// quits the whole app.
	inClosePrompt := false
	win.SetCloseReqFunc(func(w *gi.Window) {
		if inClosePrompt {
			return
		}
		inClosePrompt = true
		gi.PromptDialog(vp, gi.DlgOpts{Title: "Really Close Window?",
			Prompt: "Are you sure you want to close the window? This will Quit the App as well, losing all unsaved params, weights, logs, etc"}, gi.AddOk, gi.AddCancel,
			win.This(), func(recv, send ki.Ki, sig int64, data interface{}) {
				if sig == int64(gi.DialogAccepted) {
					gi.Quit()
				} else {
					inClosePrompt = false
				}
			})
	})

	win.SetCloseCleanFunc(func(w *gi.Window) {
		go gi.Quit() // once main window is closed, quit
	})

	win.MainMenuUpdated()
	return win
}
1447 |
1448 | // These props register Save methods so they can be used
1449 | var SimProps = ki.Props{
1450 | "CallMethods": ki.PropSlice{
1451 | {"SaveWeights", ki.Props{
1452 | "desc": "save network weights to file",
1453 | "icon": "file-save",
1454 | "Args": ki.PropSlice{
1455 | {"File Name", ki.Props{
1456 | "ext": ".wts,.wts.gz",
1457 | }},
1458 | },
1459 | }},
1460 | },
1461 | }
1462 |
1463 | func (ss *Sim) CmdArgs() {
1464 | ss.NoGui = true
1465 | var nogui bool
1466 | var saveEpcLog bool
1467 | var saveRunLog bool
1468 | var note string
1469 | flag.StringVar(&ss.ParamSet, "params", "", "ParamSet name to use -- must be valid name as listed in compiled-in params or loaded params")
1470 | flag.StringVar(&ss.Tag, "tag", "", "extra tag to add to file names saved from this run")
1471 | flag.StringVar(¬e, "note", "", "user note -- describe the run params etc")
1472 | flag.IntVar(&ss.MaxRuns, "runs", 10, "number of runs to do (note that MaxEpcs is in paramset)")
1473 | flag.BoolVar(&ss.LogSetParams, "setparams", false, "if true, print a record of each parameter that is set")
1474 | flag.BoolVar(&ss.SaveWts, "wts", false, "if true, save final weights after each run")
1475 | flag.BoolVar(&saveEpcLog, "epclog", true, "if true, save train epoch log to file")
1476 | flag.BoolVar(&saveRunLog, "runlog", true, "if true, save run epoch log to file")
1477 | flag.BoolVar(&nogui, "nogui", true, "if not passing any other args and want to run nogui, use nogui")
1478 | flag.Parse()
1479 | ss.Init()
1480 |
1481 | if note != "" {
1482 | fmt.Printf("note: %s\n", note)
1483 | }
1484 | if ss.ParamSet != "" {
1485 | fmt.Printf("Using ParamSet: %s\n", ss.ParamSet)
1486 | }
1487 |
1488 | if saveEpcLog {
1489 | var err error
1490 | fnm := ss.LogFileName("epc")
1491 | ss.TrnEpcFile, err = os.Create(fnm)
1492 | if err != nil {
1493 | log.Println(err)
1494 | ss.TrnEpcFile = nil
1495 | } else {
1496 | fmt.Printf("Saving epoch log to: %s\n", fnm)
1497 | defer ss.TrnEpcFile.Close()
1498 | }
1499 | }
1500 | if saveRunLog {
1501 | var err error
1502 | fnm := ss.LogFileName("run")
1503 | ss.RunFile, err = os.Create(fnm)
1504 | if err != nil {
1505 | log.Println(err)
1506 | ss.RunFile = nil
1507 | } else {
1508 | fmt.Printf("Saving run log to: %s\n", fnm)
1509 | defer ss.RunFile.Close()
1510 | }
1511 | }
1512 | if ss.SaveWts {
1513 | fmt.Printf("Saving final weights per run\n")
1514 | }
1515 | fmt.Printf("Running %d Runs\n", ss.MaxRuns)
1516 | ss.Train()
1517 | }
1518 |
1519 |
--------------------------------------------------------------------------------
/testdata/ra25.py:
--------------------------------------------------------------------------------
1 | #!/usr/local/bin/pyleabra
2 |
3 | # Copyright (c) 2019, The Emergent Authors. All rights reserved.
4 | # Use of this source code is governed by a BSD-style
5 | # license that can be found in the LICENSE file.
6 |
7 | # use:
8 | # pyleabra -i ra25.py
9 | # to run in gui interactive mode from the command line (or pyleabra, import ra25)
10 | # see main function at the end for startup args
11 |
12 | # to run this python version of the demo:
13 | # * install gopy, currently in fork at https://github.com/goki/gopy
14 | # e.g., 'go get github.com/goki/gopy -u ./...' and then cd to that package
15 | # and do 'go install'
16 | # * go to the python directory in this emergent repository, read README.md there, and
17 | # type 'make' -- if that works, then type make install (may need sudo)
# * cd back here, and run 'pyleabra' which was installed into /usr/local/bin
19 | # * then type 'import ra25' and this should run
20 | # * you'll need various standard packages such as pandas, numpy, matplotlib, etc
21 |
# ra25 runs a simple random-associator 5x5 = 25 four-layer leabra network
23 |
24 | from leabra import go, leabra, emer, relpos, eplot, env, agg, patgen, prjn, etable, efile, split, etensor, params, netview, rand, erand, gi, giv, epygiv
25 |
26 | # il.reload(ra25) -- doesn't seem to work for reasons unknown
27 | import importlib as il
28 | import io
29 | import sys
30 | import getopt
31 | # import numpy as np
32 | # import matplotlib
33 | # matplotlib.use('SVG')
34 | # import matplotlib.pyplot as plt
35 | # plt.rcParams['svg.fonttype'] = 'none' # essential for not rendering fonts as paths
36 |
37 | # note: pandas, xarray or pytorch TensorDataSet can be used for input / output
38 | # patterns and recording of "log" data for plotting. However, the etable.Table
39 | # has better GUI and API support, and handles tensor columns directly unlike
40 | # pandas. Support for easy migration between these is forthcoming.
41 | # import pandas as pd
42 |
# placeholder: this global is rebound to the Sim instance later on
TheSim = 1

# LogPrec is the number of digits of precision used when saving float values in logs
LogPrec = 4

# Empty Go slices to pass wherever the Go API would take nil, e.g. the
# int-slice (cell shape) and string-slice arguments of etable.Column.
nilInts = go.Slice_int()
nilStrs = go.Slice_string()
54 |
55 | # note: we cannot use methods for callbacks from Go -- must be separate functions
56 | # so below are all the callbacks from the GUI toolbar actions
57 |
58 |
def InitCB(recv, send, sig, data):
    """Toolbar "Init" callback.

    Re-initializes the Sim, updates its ClassView, and requests a full redraw.
    """
    sim = TheSim
    sim.Init()
    sim.ClassView.Update()
    sim.vp.SetNeedsFullRender()
63 |
64 |
def TrainCB(recv, send, sig, data):
    """Toolbar "Train" callback.

    Ignored while the Sim is already running. Otherwise it marks the Sim as
    running, refreshes the toolbar action states, and calls Train directly
    from the callback.
    """
    if TheSim.IsRunning:
        return
    TheSim.IsRunning = True
    TheSim.ToolBar.UpdateActions()
    TheSim.Train()
70 |
71 |
def StopCB(recv, send, sig, data):
    """Toolbar "Stop" callback: calls Stop on the Sim to interrupt running."""
    sim = TheSim
    sim.Stop()
74 |
75 |
def StepTrialCB(recv, send, sig, data):
    """Toolbar "Step Trial" callback.

    If nothing is running, it runs one training trial with IsRunning set for
    its duration, then updates the ClassView and requests a full redraw.
    """
    if TheSim.IsRunning:
        return
    TheSim.IsRunning = True
    TheSim.TrainTrial()
    TheSim.IsRunning = False
    TheSim.ClassView.Update()
    TheSim.vp.SetNeedsFullRender()
83 |
84 |
def StepEpochCB(recv, send, sig, data):
    """Toolbar "Step Epoch" callback.

    If nothing is running, it marks the Sim as running, refreshes the toolbar
    action states, and trains one epoch.
    """
    if TheSim.IsRunning:
        return
    TheSim.IsRunning = True
    TheSim.ToolBar.UpdateActions()
    TheSim.TrainEpoch()
90 |
91 |
def StepRunCB(recv, send, sig, data):
    """Toolbar "Step Run" callback.

    If nothing is running, it marks the Sim as running, refreshes the toolbar
    action states, and trains one full run.
    """
    if TheSim.IsRunning:
        return
    TheSim.IsRunning = True
    TheSim.ToolBar.UpdateActions()
    TheSim.TrainRun()
97 |
98 |
def TestTrialCB(recv, send, sig, data):
    """Toolbar "Test Trial" callback.

    If nothing is running, it runs the next testing trial with IsRunning set
    for its duration, then updates the ClassView and requests a full redraw.
    """
    if TheSim.IsRunning:
        return
    TheSim.IsRunning = True
    TheSim.TestTrial()
    TheSim.IsRunning = False
    TheSim.ClassView.Update()
    TheSim.vp.SetNeedsFullRender()
106 |
107 |
def TestItemCB2(recv, send, sig, data):
    """Dialog callback for "Test Item".

    Runs once the name prompt closes. If the prompt was accepted, it looks up
    test-table rows whose "Name" contains the entered text (ignoring case):
      - no match: shows a "Name Not Found" dialog;
      - otherwise, if nothing is running: tests the first matching row and
        requests a redraw.
    """
    win = gi.Window(handle=recv)
    vp = win.WinViewport2D()
    dlg = gi.Dialog(handle=send)
    if sig != gi.DialogAccepted:
        return
    val = gi.StringPromptDialogValue(dlg)
    contains, ignoreCase = True, True
    idxs = TheSim.TestEnv.Table.RowsByString("Name", val, contains, ignoreCase)
    if len(idxs) == 0:
        opts = gi.DlgOpts(Title="Name Not Found",
                          Prompt="No patterns found containing: " + val)
        gi.PromptDialog(vp, opts, True, False, go.nil, go.nil)
        return
    if TheSim.IsRunning:
        return
    TheSim.IsRunning = True
    first = idxs[0]
    print("testing index: %s" % first)
    TheSim.TestItem(first)
    TheSim.IsRunning = False
    vp.SetNeedsFullRender()
127 |
128 |
def TestItemCB(recv, send, sig, data):
    """Toolbar "Test Item" callback: prompts for a pattern name.

    The answer is handled by TestItemCB2.
    """
    win = gi.Window(handle=recv)
    vp = win.WinViewport2D()
    opts = gi.DlgOpts(Title="Test Item", Prompt="Enter the Name of a given input pattern to test (case insensitive, contains given string.")
    gi.StringPromptDialog(vp, "", "Test Item", opts, win, TestItemCB2)
133 |
134 |
def TestAllCB(recv, send, sig, data):
    """Toolbar "Test All" callback.

    If nothing is running, it marks the Sim as running, refreshes the toolbar
    action states, and tests every testing trial.
    """
    if TheSim.IsRunning:
        return
    TheSim.IsRunning = True
    TheSim.ToolBar.UpdateActions()
    TheSim.RunTestAll()
140 |
141 |
def ResetRunLogCB(recv, send, sig, data):
    """Toolbar "Reset RunLog" callback: removes every row from the run log
    and redraws the run plot."""
    sim = TheSim
    sim.RunLog.SetNumRows(0)
    sim.RunPlot.Update()
145 |
146 |
def NewRndSeedCB(recv, send, sig, data):
    """Toolbar "New Seed" callback: asks the Sim to pick a new random seed."""
    sim = TheSim
    sim.NewRndSeed()
149 |
150 |
def ReadmeCB(recv, send, sig, data):
    """Toolbar "README" callback: opens the model's README in a browser."""
    url = "https://github.com/emer/leabra/blob/master/examples/ra25/README.md"
    gi.OpenURL(url)
154 |
155 |
def FilterSSE(et, row):
    """Row filter that keeps error trials.

    Returns True when the SSE cell of `row` in the table wrapped by handle
    `et` is greater than zero.
    """
    tbl = etable.Table(handle=et)
    sse = tbl.CellFloat("SSE", row)
    return sse > 0
159 |
160 |
def UpdtFuncNotRunning(act):
    """Action update func: makes `act` active only while the Sim is idle."""
    idle = not TheSim.IsRunning
    act.SetActiveStateUpdt(idle)
163 |
164 |
def UpdtFuncRunning(act):
    """Action update func: makes `act` active only while the Sim is running."""
    running = TheSim.IsRunning
    act.SetActiveStateUpdt(running)
167 |
168 |
169 | #####################################################
170 | # Sim
171 |
172 | class Sim(object):
173 | """
174 | Sim encapsulates the entire simulation model, and we define all the
175 | functionality as methods on this struct. This structure keeps all relevant
176 | state information organized and available without having to pass everything around
177 | as arguments to methods, and provides the core GUI interface (note the view tags
178 | for the fields which provide hints to how things should be displayed).
179 | """
180 |
def __init__(self):
    """
    Creates the simulation state: network, pattern and log tables, params,
    environments and run limits, followed by statistics, internal GUI/file
    handles and flags, and the ClassView display tags for those fields.
    Attributes are created in a fixed order.
    """
    # model, patterns and log tables
    self.Net = leabra.Network()
    self.Pats = etable.Table()
    self.TrnEpcLog = etable.Table()
    self.TstEpcLog = etable.Table()
    self.TstTrlLog = etable.Table()
    self.TstErrLog = etable.Table()
    self.TstErrStats = etable.Table()
    self.TstCycLog = etable.Table()
    self.RunLog = etable.Table()
    self.RunStats = etable.Table()
    self.Params = params.Sets()
    self.ParamSet = ""
    self.Tag = ""
    self.MaxRuns = 10
    self.MaxEpcs = 50
    self.TrainEnv = env.FixedTable()
    self.TestEnv = env.FixedTable()
    self.Time = leabra.Time()
    self.ViewOn = True
    self.TrainUpdt = leabra.AlphaCycle
    self.TestUpdt = leabra.Cycle
    self.TestInterval = 5

    # statistics (shown read-only in the GUI)
    statFields = ("TrlSSE", "TrlAvgSSE", "TrlCosDiff", "EpcSSE", "EpcAvgSSE",
                  "EpcPctErr", "EpcPctCor", "EpcCosDiff")
    for fld in statFields:
        setattr(self, fld, 0.0)
    self.FirstZero = -1

    # internal state (hidden in the GUI)
    for fld in ("SumSSE", "SumAvgSSE", "SumCosDiff", "CntErr"):
        setattr(self, fld, 0.0)
    # 0 marks a GUI / file / tensor handle that has not been created yet
    handleFields = ("Win", "vp", "ToolBar", "NetView", "TrnEpcPlot",
                    "TstEpcPlot", "TstTrlPlot", "TstCycPlot", "RunPlot",
                    "TrnEpcFile", "RunFile", "InputValsTsr", "OutputValsTsr")
    for fld in handleFields:
        setattr(self, fld, 0)
    self.SaveWts = False
    self.NoGui = False
    self.LogSetParams = False  # True
    self.IsRunning = False
    self.StopNow = False
    self.RndSeed = 0

    # ClassView tags for controlling display of fields
    inactiveFields = statFields + ("FirstZero",)
    hiddenFields = ("SumSSE", "SumAvgSSE", "SumCosDiff", "CntErr") + handleFields + (
        "SaveWts", "NoGui", "LogSetParams", "IsRunning", "StopNow", "RndSeed",
        "ClassView", "Tags")
    tags = {}
    for fld in inactiveFields:
        tags[fld] = 'inactive:"+"'
    for fld in hiddenFields:
        tags[fld] = 'view:"-"'
    self.Tags = tags
278 |
def InitParams(self):
    """
    InitParams loads the parameter sets from the "ra25_std.params" JSON file
    into self.Params.  "Base" is always applied, and other sets can be
    optionally selected to apply on top of it (see SetParams).
    """
    paramsFile = "ra25_std.params"
    self.Params.OpenJSON(paramsFile)
285 |
def Config(ss):
    """
    Config configures all the elements using the standard functions, in
    order: params, patterns, environments, network, then each log table.
    """
    ss.InitParams()
    # ss.OpenPats()  # alternative: load patterns from file instead of generating them
    ss.ConfigPats()
    ss.ConfigEnv()
    ss.ConfigNet(ss.Net)
    logConfigs = (
        (ss.ConfigTrnEpcLog, ss.TrnEpcLog),
        (ss.ConfigTstEpcLog, ss.TstEpcLog),
        (ss.ConfigTstTrlLog, ss.TstTrlLog),
        (ss.ConfigTstCycLog, ss.TstCycLog),
        (ss.ConfigRunLog, ss.RunLog),
    )
    for configure, table in logConfigs:
        configure(table)
301 |
def ConfigEnv(ss):
    """
    ConfigEnv fills in default run/epoch limits (a value of 0 means "not set";
    anything else is kept as a user override), sets NZeroStop, and
    configures TrainEnv and TestEnv (sequential) over ss.Pats, initializing
    both at run 0.
    """
    for fld, dflt in (("MaxRuns", 10), ("MaxEpcs", 50)):
        if getattr(ss, fld) == 0:  # allow user override
            setattr(ss, fld, dflt)
    ss.NZeroStop = 5

    trn = ss.TrainEnv
    trn.Nm = "TrainEnv"
    trn.Dsc = "training params and state"
    trn.Table = etable.NewIdxView(ss.Pats)
    trn.Validate()
    # note: epoch max is not set here -- TrainTrial checks it against MaxEpcs
    trn.Run.Max = ss.MaxRuns

    tst = ss.TestEnv
    tst.Nm = "TestEnv"
    tst.Dsc = "testing params and state"
    tst.Table = etable.NewIdxView(ss.Pats)
    tst.Sequential = True
    tst.Validate()

    # note: to create a train / test split of pats, do this
    # (convert the lists to Go slices as needed):
    # all = etable.NewIdxView(ss.Pats)
    # splits = split.Permuted(all, [.8, .2], ["Train", "Test"])
    # ss.TrainEnv.Table = splits.Splits[0]
    # ss.TestEnv.Table = splits.Splits[1]

    trn.Init(0)
    tst.Init(0)
330 |
def ConfigNet(ss, net):
    """
    ConfigNet builds the RA25 network into net.

    The layers are a 5x5 Input, two 7x7 Hidden layers and a 5x5 Output
    (Target) layer. Projections are full Forward Input -> Hidden1 ->
    Hidden2 -> Output, plus full Back Output -> Hidden2 -> Hidden1. It then
    applies defaults and the "Network" params sheet, builds the network and
    initializes the weights.
    """
    net.InitName(net, "RA25")
    inLay = net.AddLayer2D("Input", 5, 5, emer.Input)
    hid1Lay = net.AddLayer2D("Hidden1", 7, 7, emer.Hidden)
    hid2Lay = net.AddLayer2D("Hidden2", 7, 7, emer.Hidden)
    outLay = net.AddLayer2D("Output", 5, 5, emer.Target)

    # use this to position layers relative to each other
    # default is Above, YAlign = Front, XAlign = Center
    hid2Pos = relpos.Rel(Rel=relpos.RightOf, Other="Hidden1",
                         YAlign=relpos.Front, Space=2)
    hid2Lay.SetRelPos(hid2Pos)

    # note: see emergent/prjn module for all the options on how to connect
    # NewFull returns a new prjn.Full connectivity pattern
    projections = (
        (inLay, hid1Lay, emer.Forward),
        (hid1Lay, hid2Lay, emer.Forward),
        (hid2Lay, outLay, emer.Forward),
        (outLay, hid2Lay, emer.Back),
        (hid2Lay, hid1Lay, emer.Back),
    )
    for frm, to, ptype in projections:
        net.ConnectLayers(frm, to, prjn.NewFull(), ptype)

    # note: can set these to do parallel threaded computation across multiple cpus
    # not worth it for this small of a model, but definitely helps for larger ones
    # if Thread:
    #     hid2Lay.SetThread(1)
    #     outLay.SetThread(1)

    # note: if you wanted to change a layer type from e.g., Target to Compare, do this:
    # outLay.SetType(emer.Compare)
    # that would mean that the output layer doesn't reflect target values in plus phase
    # and thus removes error-driven learning -- but stats are still computed.

    net.Defaults()
    ss.SetParams("Network", ss.LogSetParams)  # only set Network params
    net.Build()
    net.InitWts()
368 |
369 | ######################################
370 | # Init, utils
371 |
def Init(ss):
    """
    Init restarts the run and initializes everything. It reseeds the random
    generator from RndSeed, reconfigures the environments, applies all param
    sheets, starts a new run (network weights, stats, epoch logs) and
    refreshes the view.
    """
    rand.Seed(ss.RndSeed)
    ss.ConfigEnv()  # just in case another set of pats was selected..
    ss.StopNow = False
    allSheets = ""  # an empty sheet name applies every sheet
    ss.SetParams(allSheets, ss.LogSetParams)
    ss.NewRun()
    ss.UpdateView(True)
380 |
def NewRndSeed(ss):
    """
    NewRndSeed gets a new random seed based on the current time (in
    nanoseconds since the epoch) and stores it in ss.RndSeed. Init passes
    RndSeed to rand.Seed, so the new seed takes effect at the next Init.
    Without calling this, the same random seed is used for every run.
    """
    import time
    ss.RndSeed = int(time.time() * 1e9)
384 |
def Counters(ss, train):
    """
    Counters returns a string of the current counter state.
    Tabs are used to achieve a reasonable formatting overall, with a few
    trailing tabs to allow for expansion.

    train: if true, the Trial and Name fields come from TrainEnv; otherwise
    they come from TestEnv. Run and Epoch always come from TrainEnv.
    """
    if train:
        return "Run:\t%d\tEpoch:\t%d\tTrial:\t%d\tCycle:\t%d\tName:\t%s\t\t\t" % (
            ss.TrainEnv.Run.Cur, ss.TrainEnv.Epoch.Cur, ss.TrainEnv.Trial.Cur,
            ss.Time.Cycle, ss.TrainEnv.TrialName.Cur)
    # a tab must separate the Cycle value from "Name:" (it was missing here)
    return "Run:\t%d\tEpoch:\t%d\tTrial:\t%d\t\tCycle:\t%d\tName:\t%s\t\t\t" % (
        ss.TrainEnv.Run.Cur, ss.TrainEnv.Epoch.Cur, ss.TestEnv.Trial.Cur,
        ss.Time.Cycle, ss.TestEnv.TrialName.Cur)
395 |
def UpdateView(ss, train):
    """
    UpdateView records the current counters (see Counters) on the NetView
    and refreshes it. Nothing happens if there is no NetView (0) or it is
    not visible.
    """
    nv = ss.NetView
    if nv == 0 or not nv.IsVisible():
        return
    nv.Record(ss.Counters(train))
    # note: essential to use Go version of update when called from another goroutine
    # note: using counters is significantly slower..
    nv.GoUpdate()
402 |
403 | ######################################
404 | # Running the network
405 |
def AlphaCyc(ss, train):
    """
    AlphaCyc runs one alpha-cycle (100 msec, 4 quarters) of processing.
    External inputs must have already been applied prior to calling,
    using ApplyExt method on relevant layers (see TrainTrial, TestTrial).
    If train is true, then learning DWt or WtFmDWt calls are made.
    Handles netview updating within scope of AlphaCycle, at the granularity
    given by TrainUpdt (training) or TestUpdt (testing).
    """
    if ss.Win != 0:
        ss.Win.PollEvents()  # this is essential for GUI responsiveness while running
    viewUpdt = ss.TrainUpdt if train else ss.TestUpdt

    # update prior weight changes at start, so any DWt values remain visible at end
    # you might want to do this less frequently to achieve a mini-batch update
    # in which case, move it out to the TrainTrial method where the relevant
    # counters are being dealt with.
    if train:
        ss.Net.WtFmDWt()

    ss.Net.AlphaCycInit()
    ss.Time.AlphaCycStart()
    for qtr in range(4):
        for cyc in range(ss.Time.CycPerQtr):
            ss.Net.Cycle(ss.Time)
            if not train:
                ss.LogTstCyc(ss.TstCycLog, ss.Time.Cycle)
            ss.Time.CycleInc()
            if ss.ViewOn:
                everyCycle = viewUpdt == leabra.Cycle
                # FastSpike: every 10th cycle within the quarter
                tenthCycle = viewUpdt == leabra.FastSpike and (cyc + 1) % 10 == 0
                if everyCycle or tenthCycle:
                    ss.UpdateView(train)
        ss.Net.QuarterFinal(ss.Time)
        ss.Time.QuarterInc()
        if ss.ViewOn:
            # Phase: only after the last two quarters (qtr 2 and 3)
            if viewUpdt == leabra.Quarter or (viewUpdt == leabra.Phase and qtr >= 2):
                ss.UpdateView(train)

    if train:
        ss.Net.DWt()
    if ss.ViewOn and viewUpdt == leabra.AlphaCycle:
        ss.UpdateView(train)
    if ss.TstCycPlot != 0 and not train:
        ss.TstCycPlot.GoUpdate()
455 |
def ApplyInputs(ss, en):
    """
    ApplyInputs applies input patterns from given environment.
    It is good practice to have this be a separate method with appropriate
    args so that it can be used for various different contexts
    (training, testing, etc).

    en: the environment; en.State(layerName) gives the pattern for that
    layer, and a layer whose state is go.nil is left without external input.
    """
    ss.Net.InitExt()  # clear any existing inputs -- not strictly necessary if always
    # going to the same layers, but good practice and cheap anyway
    inLay = leabra.Layer(ss.Net.LayerByName("Input"))
    outLay = leabra.Layer(ss.Net.LayerByName("Output"))

    inPats = en.State(inLay.Nm)
    if inPats != go.nil:
        inLay.ApplyExt(inPats)

    outPats = en.State(outLay.Nm)
    if outPats != go.nil:  # check the output state itself, not inPats
        outLay.ApplyExt(outPats)

    # NOTE: this is how you can use a pandas.DataFrame() to apply inputs
    # we are using etable.Table instead because it provides a full GUI
    # for viewing your patterns, and has a more convenient API, that integrates
    # with the env environment interface.
    #
    # inLay = leabra.Layer(self.Net.LayerByName("Input"))
    # outLay = leabra.Layer(self.Net.LayerByName("Output"))
    # pidx = self.Trial
    # if not self.Sequential:
    #     pidx = self.Porder[self.Trial]
    # # note: these indexes must be updated based on columns in patterns..
    # inp = self.Pats.iloc[pidx,1:26].values
    # outp = self.Pats.iloc[pidx,26:26+25].values
    # self.ApplyExt(inLay, inp)
    # self.ApplyExt(outLay, outp)
    #
    # def ApplyExt(self, lay, nparray):
    #     flt = np.ndarray.flatten(nparray, 'C')
    #     slc = go.Slice_float32(flt)
    #     lay.ApplyExt1D(slc)
496 |
def TrainTrial(ss):
    """
    TrainTrial runs one trial of training using TrainEnv.

    When the epoch counter has just changed, the finished epoch is logged
    first. Testing then runs every TestInterval epochs; a TestInterval of
    0 or less disables periodic testing. Once MaxEpcs is reached the run
    ends (either stopping or starting a new run) and no trial is run.
    """
    ss.TrainEnv.Step()  # the Env encapsulates and manages all counter state

    # Key to query counters FIRST because current state is in NEXT epoch
    # if epoch counter has changed
    epc = env.CounterCur(ss.TrainEnv, env.Epoch)
    chg = env.CounterChg(ss.TrainEnv, env.Epoch)
    if chg:
        ss.LogTrnEpc(ss.TrnEpcLog)
        if ss.ViewOn and ss.TrainUpdt > leabra.AlphaCycle:
            ss.UpdateView(True)
        # guard: a TestInterval of 0 would otherwise raise ZeroDivisionError
        # note: epc is *next* so won't trigger first time
        if ss.TestInterval > 0 and epc % ss.TestInterval == 0:
            ss.TestAll()
        if epc >= ss.MaxEpcs:  # done with training..
            ss.RunEnd()
            if ss.TrainEnv.Run.Incr():  # we are done!
                ss.StopNow = True
            else:
                ss.NewRun()
            return

    ss.ApplyInputs(ss.TrainEnv)
    ss.AlphaCyc(True)  # train
    ss.TrialStats(True)  # accumulate
523 |
def RunEnd(ss):
    """
    RunEnd is called at the end of a run. It records the run in RunLog and,
    if SaveWts is set, saves the network weights (as JSON) to
    WeightsFileName().
    """
    ss.LogRun(ss.RunLog)
    if ss.SaveWts:
        fnm = ss.WeightsFileName()
        # Python print instead of Go's fmt.Printf (adds the trailing newline)
        print("Saving Weights to: %s" % fnm)
        ss.Net.SaveWtsJSON(gi.FileName(fnm))
531 |
def NewRun(ss):
    """
    NewRun initializes a new run of the model, using the TrainEnv.Run counter
    for the new run value. It re-inits both environments for that run,
    resets time, network weights and stats, and empties the epoch logs.
    """
    runIdx = ss.TrainEnv.Run.Cur
    for en in (ss.TrainEnv, ss.TestEnv):
        en.Init(runIdx)
    ss.Time.Reset()
    ss.Net.InitWts()
    ss.InitStats()
    for epochLog in (ss.TrnEpcLog, ss.TstEpcLog):
        epochLog.SetNumRows(0)
542 |
def InitStats(ss):
    """
    InitStats initializes all the statistics. This matters most for the
    cumulative epoch accumulators; it is called at the start of a new run.
    """
    # accumulators
    ss.SumSSE = 0.0
    ss.SumAvgSSE = 0.0
    ss.SumCosDiff = 0.0
    ss.CntErr = 0.0
    ss.FirstZero = -1
    # clear rest just to make Sim look initialized
    # (TrlCosDiff and EpcPctCor were previously left with stale values)
    ss.TrlSSE = 0.0
    ss.TrlAvgSSE = 0.0
    ss.TrlCosDiff = 0.0
    ss.EpcSSE = 0.0
    ss.EpcAvgSSE = 0.0
    ss.EpcPctErr = 0.0
    ss.EpcPctCor = 0.0
    ss.EpcCosDiff = 0.0
559 |
def TrialStats(ss, accum):
    """
    TrialStats computes the trial-level statistics from the Output layer.
    These are the cosine difference, SSE with a 0.5 per-unit tolerance, and
    that SSE divided by the number of output units. If accum is true, they
    are added to the epoch accumulators, and a trial with nonzero SSE counts
    as an error.

    Stats are accumulated here on the Sim side so the core algorithm side
    remains as simple as possible. You can also aggregate directly from log
    data, as is done for testing stats.
    """
    outLay = leabra.Layer(ss.Net.LayerByName("Output"))
    ss.TrlCosDiff = outLay.CosDiff.Cos
    # 0.5 = per-unit tolerance -- right side of .5
    ss.TrlSSE = outLay.SSE(0.5)
    ss.TrlAvgSSE = ss.TrlSSE / len(outLay.Neurons)
    if not accum:
        return
    ss.SumSSE += ss.TrlSSE
    ss.SumAvgSSE += ss.TrlAvgSSE
    ss.SumCosDiff += ss.TrlCosDiff
    if ss.TrlSSE != 0:
        ss.CntErr += 1.0
579 |
def TrainEpoch(ss):
    """
    TrainEpoch runs training trials for the remainder of this epoch. It
    stops when the epoch counter changes or StopNow is set, then calls
    Stopped().
    """
    ss.StopNow = False
    startEpc = ss.TrainEnv.Epoch.Cur
    ss.TrainTrial()
    while not ss.StopNow and ss.TrainEnv.Epoch.Cur == startEpc:
        ss.TrainTrial()
    ss.Stopped()
589 |
def TrainRun(ss):
    """
    TrainRun runs training trials for the remainder of the run. It stops
    when the run counter changes or StopNow is set, then calls Stopped().
    """
    ss.StopNow = False
    startRun = ss.TrainEnv.Run.Cur
    ss.TrainTrial()
    while not ss.StopNow and ss.TrainEnv.Run.Cur == startRun:
        ss.TrainTrial()
    ss.Stopped()
599 |
def Train(ss):
    """
    Train runs the full training from this point onward. It keeps running
    trials until StopNow is set, then calls Stopped().
    """
    ss.StopNow = False
    ss.TrainTrial()
    while not ss.StopNow:
        ss.TrainTrial()
    ss.Stopped()
608 |
def Stop(ss):
    """
    Stop tells the sim to stop running by setting StopNow; the run loops
    check it after each trial.
    """
    ss.StopNow = True
612 |
def Stopped(ss):
    """
    Stopped is called when a run method stops running. It clears the
    IsRunning flag. When a window exists, it also refreshes the toolbar
    actions (if a toolbar exists) and the ClassView, and requests a full
    re-render.
    """
    ss.IsRunning = False
    if ss.Win == 0:
        return
    ss.vp.BlockUpdates()
    try:
        # 0 is the "not created" value set in __init__; go.nil is a null Go handle
        if ss.ToolBar != 0 and ss.ToolBar != go.nil:
            ss.ToolBar.UpdateActions()
    finally:
        # never leave the viewport with updates blocked
        ss.vp.UnblockUpdates()
    ss.ClassView.Update()
    ss.vp.SetNeedsFullRender()
623 |
624 | ######################################
625 | # Testing
626 |
def TestTrial(ss):
    """
    TestTrial runs one trial of testing; inputs are always presented
    sequentially. If the test epoch counter has just changed, it logs the
    testing epoch instead of running a trial.
    """
    ss.TestEnv.Step()

    # Query counters FIRST
    if env.CounterChg(ss.TestEnv, env.Epoch):
        if ss.ViewOn and ss.TestUpdt > leabra.AlphaCycle:
            ss.UpdateView(False)
        ss.LogTstEpc(ss.TstEpcLog)
        return

    ss.ApplyInputs(ss.TestEnv)
    ss.AlphaCyc(False)  # !train
    ss.TrialStats(False)  # !accumulate
    ss.LogTstTrl(ss.TstTrlLog)
643 |
def TestItem(ss, idx):
    """
    TestItem tests the item at index idx in the test item list. It
    temporarily points TestEnv's trial counter at idx and restores the
    previous value afterwards, even if the trial raises.
    """
    cur = ss.TestEnv.Trial.Cur
    ss.TestEnv.Trial.Cur = idx
    try:
        ss.TestEnv.SetTrialName()
        ss.ApplyInputs(ss.TestEnv)
        ss.AlphaCyc(False)  # !train
        ss.TrialStats(False)  # !accumulate
    finally:
        ss.TestEnv.Trial.Cur = cur
653 |
def TestAll(ss):
    """
    TestAll runs through the full set of testing items. It re-initializes
    TestEnv for the current training run, then runs test trials until the
    test epoch counter changes or StopNow is set.
    """
    ss.TestEnv.Init(ss.TrainEnv.Run.Cur)
    done = False
    while not done:
        ss.TestTrial()
        done = env.CounterChg(ss.TestEnv, env.Epoch) or ss.StopNow
662 |
def RunTestAll(ss):
    """
    RunTestAll runs through the full set of testing items for the GUI. It
    clears StopNow first and calls Stopped() at the end.
    """
    ss.StopNow = False
    ss.TestAll()
    ss.Stopped()
668 |
669 | ##########################################
670 | # Params methods
671 |
def ParamsName(ss):
    """ParamsName returns the name of the current set of parameters ("Base" when ParamSet is empty)."""
    return "Base" if ss.ParamSet == "" else ss.ParamSet
677 |
def SetParams(ss, sheet, setMsg):
    """
    SetParams sets the params for "Base" and then for the current ParamSet
    (unless it is empty or "Base").
    If sheet is empty, it first validates that only the "Network" and "Sim"
    sheets are present, then applies all available sheets; otherwise it
    applies just the named sheet.
    If setMsg is true, a message is output for each param that was set.
    """
    if sheet == "":
        # this is important for catching typos and ensuring that all sheets can be used
        ss.Params.ValidateSheets(go.Slice_string(["Network", "Sim"]))
    ss.SetParamsSet("Base", sheet, setMsg)
    extraSet = ss.ParamSet
    if extraSet not in ("", "Base"):
        ss.SetParamsSet(extraSet, sheet, setMsg)
692 |
def SetParamsSet(ss, setNm, sheet, setMsg):
    """
    SetParamsSet sets the params for the params.Set named setNm; it does
    nothing if no such set exists.
    If sheet is empty, it applies all available sheets ("Network" to the
    network, "Sim" to the Sim itself); otherwise it applies just the named
    sheet.
    If setMsg is true, a message is output for each param that was set.
    """
    pset = ss.Params.SetByNameTry(setNm)
    if pset == go.nil:
        return
    if sheet in ("", "Network") and "Network" in pset.Sheets:
        ss.Net.ApplyParams(pset.SheetByNameTry("Network"), setMsg)
    if sheet in ("", "Sim") and "Sim" in pset.Sheets:
        epygiv.ApplyParams(ss, pset.SheetByNameTry("Sim"), setMsg)
    # note: if you have more complex environments with parameters, definitely add
    # sheets for them, e.g., "TrainEnv", "TestEnv" etc
713 |
def ConfigPats(ss):
    """
    ConfigPats generates the 25 training patterns into ss.Pats and saves
    them to "random_5x5_25_gen.dat".

    The table has a Name column plus 5x5 (Y, X) Input and Output columns.
    Each of the two pattern columns is filled with
    patgen.PermutedBinaryRows(col, 6, 1, 0). The table gets the same
    name/desc metadata that OpenPats sets.
    """
    # note: this is all go-based for using etable.Table instead of pandas
    dt = ss.Pats
    dt.SetMetaData("name", "TrainPats")
    dt.SetMetaData("desc", "Training patterns")
    sc = etable.Schema()
    sc.append(etable.Column("Name", etensor.STRING, nilInts, nilStrs))
    sc.append(etable.Column("Input", etensor.FLOAT32,
                            go.Slice_int([5, 5]), go.Slice_string(["Y", "X"])))
    sc.append(etable.Column("Output", etensor.FLOAT32,
                            go.Slice_int([5, 5]), go.Slice_string(["Y", "X"])))
    dt.SetFromSchema(sc, 25)

    # presumably 6 units "on" (value 1) per row, the rest 0 -- confirm in patgen docs
    patgen.PermutedBinaryRows(dt.Cols[1], 6, 1, 0)
    patgen.PermutedBinaryRows(dt.Cols[2], 6, 1, 0)
    dt.SaveCSV("random_5x5_25_gen.dat", etable.Tab, True)
728 |
def OpenPats(ss):
    """
    OpenPats loads the training patterns into ss.Pats from the tab-delimited
    file "random_5x5_25.dat". It is an alternative to generating them with
    ConfigPats.
    """
    dt = ss.Pats  # dt aliases ss.Pats; no reassignment needed
    dt.SetMetaData("name", "TrainPats")
    dt.SetMetaData("desc", "Training patterns")
    dt.OpenCSV("random_5x5_25.dat", etable.Tab)
    # Note: here's how to read into a pandas DataFrame
    # dt = pd.read_csv("random_5x5_25.dat", sep='\t')
    # dt = dt.drop(columns="_H:")
738 |
739 | ##########################################
740 | # Logging
741 |
def RunName(ss):
    """
    RunName returns a name for this run that combines Tag and Params, as
    "Tag_Params", or just the params name when Tag is empty. Add this to any
    file names that are saved.
    """
    parts = [ss.ParamsName()]
    if ss.Tag != "":
        parts.insert(0, ss.Tag)
    return "_".join(parts)
751 |
def RunEpochName(ss, run, epc):
    """
    RunEpochName returns a string with the run and epoch numbers with leading
    zeros, suitable for using in weights file names. It uses 3 digits for the
    run and 5 for the epoch, e.g. run=3, epc=42 -> "003_00042".
    """
    # both values must go into the % operator as one tuple; the original
    # `"..." % run, epc` raised TypeError (not enough arguments)
    return "%03d_%05d" % (run, epc)
758 |
def WeightsFileName(ss):
    """
    WeightsFileName returns the default current weights file name:
    <net name>_<RunName>_<RunEpochName(run, epoch)>.wts
    """
    trn = ss.TrainEnv
    pieces = (ss.Net.Nm, ss.RunName(), ss.RunEpochName(trn.Run.Cur, trn.Epoch.Cur))
    return "_".join(pieces) + ".wts"
762 |
def LogFileName(ss, lognm):
    """LogFileName returns the default log file name: <net name>_<RunName>_<lognm>.csv"""
    pieces = (ss.Net.Nm, ss.RunName(), lognm)
    return "_".join(pieces) + ".csv"
766 |
767 | #############################
768 | # TrnEpcLog
769 |
def LogTrnEpc(ss, dt):
    """
    LogTrnEpc adds data from the current epoch as a new row of dt (the
    TrnEpcLog table).

    It first computes epoch averages over the trials in TrainEnv's table from
    the Sum*/CntErr accumulators, resetting them. FirstZero is set the first
    time PctErr reaches 0. Afterwards it updates the plot if one exists and
    writes the row to TrnEpcFile if one is open.
    """
    row = dt.Rows
    # grow dt itself -- the same table the row index was read from
    dt.SetNumRows(row + 1)

    hid1Lay = leabra.Layer(ss.Net.LayerByName("Hidden1"))
    hid2Lay = leabra.Layer(ss.Net.LayerByName("Hidden2"))
    outLay = leabra.Layer(ss.Net.LayerByName("Output"))

    # this is triggered by increment so use previous value
    epc = ss.TrainEnv.Epoch.Prv
    nt = ss.TrainEnv.Table.Len()  # number of trials in view

    ss.EpcSSE = ss.SumSSE / nt
    ss.SumSSE = 0.0
    ss.EpcAvgSSE = ss.SumAvgSSE / nt
    ss.SumAvgSSE = 0.0
    ss.EpcPctErr = ss.CntErr / nt
    ss.CntErr = 0.0
    ss.EpcPctCor = 1.0 - ss.EpcPctErr
    ss.EpcCosDiff = ss.SumCosDiff / nt
    ss.SumCosDiff = 0.0
    if ss.FirstZero < 0 and ss.EpcPctErr == 0:
        ss.FirstZero = epc

    dt.SetCellFloat("Run", row, ss.TrainEnv.Run.Cur)
    dt.SetCellFloat("Epoch", row, epc)
    dt.SetCellFloat("SSE", row, ss.EpcSSE)
    dt.SetCellFloat("AvgSSE", row, ss.EpcAvgSSE)
    dt.SetCellFloat("PctErr", row, ss.EpcPctErr)
    dt.SetCellFloat("PctCor", row, ss.EpcPctCor)
    dt.SetCellFloat("CosDiff", row, ss.EpcCosDiff)
    dt.SetCellFloat("Hid1 ActAvg", row, hid1Lay.Pool(0).ActAvg.ActPAvgEff)
    dt.SetCellFloat("Hid2 ActAvg", row, hid2Lay.Pool(0).ActAvg.ActPAvgEff)
    dt.SetCellFloat("Out ActAvg", row, outLay.Pool(0).ActAvg.ActPAvgEff)

    # note: essential to use Go version of update when called from another goroutine
    if ss.TrnEpcPlot != 0:
        ss.TrnEpcPlot.GoUpdate()

    if ss.TrnEpcFile != 0:
        # headers are written only at epoch 0 of run 0
        if ss.TrainEnv.Run.Cur == 0 and epc == 0:
            dt.WriteCSVHeaders(ss.TrnEpcFile, etable.Tab)
        dt.WriteCSVRow(ss.TrnEpcFile, row, etable.Tab)

    # note: this is how you log to a pandas.DataFrame
    # nwdat = [epc, self.EpcSSE, self.EpcAvgSSE, self.EpcPctErr, self.EpcPctCor, self.EpcCosDiff, 0, 0, 0]
    # nrow = len(self.EpcLog.index)
    # self.EpcLog.loc[nrow] = nwdat # note: this is reportedly rather slow
822 |
def ConfigTrnEpcLog(ss, dt):
    """
    ConfigTrnEpcLog sets up dt as the training epoch log. It gets
    name/desc/read-only/precision metadata and one scalar column per logged
    value: INT64 Run and Epoch, FLOAT64 for the rest. It starts with zero
    rows.
    """
    metadata = (
        ("name", "TrnEpcLog"),
        ("desc", "Record of performance over epochs of training"),
        ("read-only", "true"),
        ("precision", str(LogPrec)),
    )
    for key, val in metadata:
        dt.SetMetaData(key, val)

    columns = [("Run", etensor.INT64), ("Epoch", etensor.INT64)]
    for colNm in ("SSE", "AvgSSE", "PctErr", "PctCor", "CosDiff",
                  "Hid1 ActAvg", "Hid2 ActAvg", "Out ActAvg"):
        columns.append((colNm, etensor.FLOAT64))
    sc = etable.Schema()
    for colNm, colType in columns:
        sc.append(etable.Column(colNm, colType, nilInts, nilStrs))
    dt.SetFromSchema(sc, 0)

    # note: pandas.DataFrame version
    # self.EpcLog = pd.DataFrame(columns=["Epoch", "SSE", "Avg SSE", "Pct Err", "Pct Cor", "CosDiff", "Hid1 ActAvg", "Hid2 ActAvg", "Out ActAvg"])
    # self.PlotVals = ["SSE", "Pct Err"]
    # self.Plot = True
848 |
def ConfigTrnEpcPlot(ss, plt, dt):
    """
    ConfigTrnEpcPlot configures plt to plot the training epoch log dt with
    Epoch on the X axis. PctErr and PctCor are shown by default. Returns plt.
    """
    plt.Params.Title = "Leabra Random Associator 25 Epoch Plot"
    plt.Params.XAxisCol = "Epoch"
    plt.SetTable(dt)
    # each row: column, on, fixMin, min, fixMax, max
    colParams = (
        ("Run", False, True, 0, False, 0),
        ("Epoch", False, True, 0, False, 0),
        ("SSE", False, True, 0, False, 0),
        ("AvgSSE", False, True, 0, False, 0),
        ("PctErr", True, True, 0, True, 1),  # default plot
        ("PctCor", True, True, 0, True, 1),  # default plot
        ("CosDiff", False, True, 0, True, 1),
        ("Hid1 ActAvg", False, True, 0, True, .5),
        ("Hid2 ActAvg", False, True, 0, True, .5),
        ("Out ActAvg", False, True, 0, True, .5),
    )
    for params in colParams:
        plt.SetColParams(*params)
    return plt
865 |
866 | #############################
867 | # TstTrlLog
868 |
def LogTstTrl(ss, dt):
    """
    LogTstTrl records the current testing trial into dt (the TstTrlLog
    table).

    The row index is the TestEnv trial index; ConfigTstTrlLog pre-sizes the
    table to the number of testing items. Scalar stats and per-unit
    activity tensors (Input Act, Output ActM/ActP) are logged.
    """
    # note: dt is used as given (it used to be silently replaced by ss.TstTrlLog)
    inLay = leabra.Layer(ss.Net.LayerByName("Input"))
    hid1Lay = leabra.Layer(ss.Net.LayerByName("Hidden1"))
    hid2Lay = leabra.Layer(ss.Net.LayerByName("Hidden2"))
    outLay = leabra.Layer(ss.Net.LayerByName("Output"))

    # this is triggered by increment so use previous value
    epc = ss.TrainEnv.Epoch.Prv
    trl = ss.TestEnv.Trial.Cur

    # fill the schema's Run column too (it was never written before)
    dt.SetCellFloat("Run", trl, ss.TrainEnv.Run.Cur)
    dt.SetCellFloat("Epoch", trl, epc)
    dt.SetCellFloat("Trial", trl, trl)
    dt.SetCellString("TrialName", trl, ss.TestEnv.TrialName.Cur)
    dt.SetCellFloat("SSE", trl, ss.TrlSSE)
    dt.SetCellFloat("AvgSSE", trl, ss.TrlAvgSSE)
    dt.SetCellFloat("CosDiff", trl, ss.TrlCosDiff)
    dt.SetCellFloat("Hid1 ActM.Avg", trl, hid1Lay.Pool(0).ActM.Avg)
    dt.SetCellFloat("Hid2 ActM.Avg", trl, hid2Lay.Pool(0).ActM.Avg)
    dt.SetCellFloat("Out ActM.Avg", trl, outLay.Pool(0).ActM.Avg)

    if ss.InputValsTsr == 0:  # re-use same tensors so not always reallocating mem
        ss.InputValsTsr = etensor.Float32()
        ss.OutputValsTsr = etensor.Float32()
    inLay.UnitValsTensor(ss.InputValsTsr, "Act")
    dt.SetCellTensor("InAct", trl, ss.InputValsTsr)
    outLay.UnitValsTensor(ss.OutputValsTsr, "ActM")
    dt.SetCellTensor("OutActM", trl, ss.OutputValsTsr)
    outLay.UnitValsTensor(ss.OutputValsTsr, "ActP")
    dt.SetCellTensor("OutActP", trl, ss.OutputValsTsr)

    # note: essential to use Go version of update when called from another goroutine
    if ss.TstTrlPlot != 0:
        ss.TstTrlPlot.GoUpdate()
908 |
def ConfigTstTrlLog(ss, dt):
    """
    ConfigTstTrlLog sets up dt as the per-pattern testing log.

    The table has one row per item in the TestEnv view, scalar stat columns,
    and InAct/OutActM/OutActP tensor columns shaped like the Input and
    Output layers.
    """
    inLay = leabra.Layer(ss.Net.LayerByName("Input"))
    outLay = leabra.Layer(ss.Net.LayerByName("Output"))

    metadata = (
        ("name", "TstTrlLog"),
        ("desc", "Record of testing per input pattern"),
        ("read-only", "true"),
        ("precision", str(LogPrec)),
    )
    for key, val in metadata:
        dt.SetMetaData(key, val)
    nt = ss.TestEnv.Table.Len()  # number in view

    columns = [
        ("Run", etensor.INT64, nilInts),
        ("Epoch", etensor.INT64, nilInts),
        ("Trial", etensor.INT64, nilInts),
        ("TrialName", etensor.STRING, nilInts),
    ]
    for colNm in ("SSE", "AvgSSE", "CosDiff",
                  "Hid1 ActM.Avg", "Hid2 ActM.Avg", "Out ActM.Avg"):
        columns.append((colNm, etensor.FLOAT64, nilInts))
    # unit-activity tensor columns take the shape of their layer
    columns.append(("InAct", etensor.FLOAT64, inLay.Shp.Shp))
    columns.append(("OutActM", etensor.FLOAT64, outLay.Shp.Shp))
    columns.append(("OutActP", etensor.FLOAT64, outLay.Shp.Shp))

    sc = etable.Schema()
    for colNm, colType, colShape in columns:
        sc.append(etable.Column(colNm, colType, colShape, nilStrs))
    dt.SetFromSchema(sc, nt)
940 |
def ConfigTstTrlPlot(ss, plt, dt):
    """
    ConfigTstTrlPlot configures plt to plot the test trial log dt with Trial
    on the X axis. CosDiff and the three ActM.Avg columns are shown by
    default. Returns plt.
    """
    plt.Params.Title = "Leabra Random Associator 25 Test Trial Plot"
    plt.Params.XAxisCol = "Trial"
    plt.SetTable(dt)
    # each row: column, on, fixMin, min, fixMax, max
    colParams = (
        ("Run", False, True, 0, False, 0),
        ("Epoch", False, True, 0, False, 0),
        ("Trial", False, True, 0, False, 0),
        ("TrialName", False, True, 0, False, 0),
        ("SSE", False, True, 0, False, 0),
        ("AvgSSE", False, True, 0, False, 0),
        ("CosDiff", True, True, 0, True, 1),
        ("Hid1 ActM.Avg", True, True, 0, True, .5),
        ("Hid2 ActM.Avg", True, True, 0, True, .5),
        ("Out ActM.Avg", True, True, 0, True, .5),
        ("InAct", False, True, 0, True, 1),
        ("OutActM", False, True, 0, True, 1),
        ("OutActP", False, True, 0, True, 1),
    )
    for params in colParams:
        plt.SetColParams(*params)
    return plt
961 |
962 | #############################
963 | # TstEpcLog
964 |
def LogTstEpc(ss, dt):
    """
    LogTstEpc appends a row to dt (the TstEpcLog table) summarizing the
    current TstTrlLog with agg functions.

    It then rebuilds ss.TstErrLog from the error trials (as selected by
    FilterSSE) and ss.TstErrStats from aggregates over those error trials.
    """
    row = dt.Rows
    dt.SetNumRows(row + 1)

    trl = ss.TstTrlLog
    tix = etable.NewIdxView(trl)
    epc = ss.TrainEnv.Epoch.Prv

    # note: this shows how to use agg methods to compute summary data from another
    # data table, instead of incrementing on the Sim
    summary = (
        ("Run", ss.TrainEnv.Run.Cur),
        ("Epoch", epc),
        ("SSE", agg.Sum(tix, "SSE")[0]),
        ("AvgSSE", agg.Mean(tix, "AvgSSE")[0]),
        ("PctErr", agg.PropIf(tix, "SSE", lambda idx, val: val > 0)[0]),
        ("PctCor", agg.PropIf(tix, "SSE", lambda idx, val: val == 0)[0]),
        ("CosDiff", agg.Mean(tix, "CosDiff")[0]),
    )
    for colNm, val in summary:
        dt.SetCellFloat(colNm, row, val)

    errIdx = etable.NewIdxView(trl)
    errIdx.Filter(FilterSSE)  # keep only the error trials
    ss.TstErrLog = errIdx.NewTable()

    allsp = split.All(errIdx)
    split.Agg(allsp, "SSE", agg.AggSum)
    for colNm in ("AvgSSE", "InAct", "OutActM", "OutActP"):
        split.Agg(allsp, colNm, agg.AggMean)
    ss.TstErrStats = allsp.AggsToTable(False)

    # note: essential to use Go version of update when called from another goroutine
    if ss.TstEpcPlot != 0:
        ss.TstEpcPlot.GoUpdate()
1006 |
def ConfigTstEpcLog(ss, dt):
    """
    ConfigTstEpcLog sets up dt as the testing-epoch summary log. It has INT64
    Run/Epoch columns and FLOAT64 summary stats, and starts with zero rows.
    """
    metadata = (
        ("name", "TstEpcLog"),
        ("desc", "Summary stats for testing trials"),
        ("read-only", "true"),
        ("precision", str(LogPrec)),
    )
    for key, val in metadata:
        dt.SetMetaData(key, val)

    sc = etable.Schema()
    for colNm in ("Run", "Epoch"):
        sc.append(etable.Column(colNm, etensor.INT64, nilInts, nilStrs))
    for colNm in ("SSE", "AvgSSE", "PctErr", "PctCor", "CosDiff"):
        sc.append(etable.Column(colNm, etensor.FLOAT64, nilInts, nilStrs))
    dt.SetFromSchema(sc, 0)
1022 |
def ConfigTstEpcPlot(ss, plt, dt):
    """
    ConfigTstEpcPlot configures plt to plot the testing epoch log dt with
    Epoch on the X axis. PctErr and PctCor are shown by default. Returns plt.
    """
    plt.Params.Title = "Leabra Random Associator 25 Testing Epoch Plot"
    plt.Params.XAxisCol = "Epoch"
    plt.SetTable(dt)
    # each row: column, on, fixMin, min, fixMax, max
    colParams = (
        ("Run", False, True, 0, False, 0),
        ("Epoch", False, True, 0, False, 0),
        ("SSE", False, True, 0, False, 0),
        ("AvgSSE", False, True, 0, False, 0),
        ("PctErr", True, True, 0, True, 1),  # default plot
        ("PctCor", True, True, 0, True, 1),  # default plot
        ("CosDiff", False, True, 0, True, 1),
    )
    for params in colParams:
        plt.SetColParams(*params)
    return plt
1036 |
1037 | #############################
1038 | # TstCycLog
1039 |
def LogTstCyc(ss, dt, cyc):
    """
    LogTstCyc writes the current cycle's layer-average Ge and Act values
    (from pool 0 inhibition stats) into row cyc of dt (the TstCycLog table).
    The table is grown if cyc is past its end; rows are overwritten on each
    trial. The plot is refreshed every 10 cycles.
    """
    if dt.Rows <= cyc:
        dt.SetNumRows(cyc + 1)

    hid1 = leabra.Layer(ss.Net.LayerByName("Hidden1")).Pool(0).Inhib
    hid2 = leabra.Layer(ss.Net.LayerByName("Hidden2")).Pool(0).Inhib
    out = leabra.Layer(ss.Net.LayerByName("Output")).Pool(0).Inhib

    cells = (
        ("Cycle", cyc),
        ("Hid1 Ge.Avg", hid1.Ge.Avg),
        ("Hid2 Ge.Avg", hid2.Ge.Avg),
        ("Out Ge.Avg", out.Ge.Avg),
        ("Hid1 Act.Avg", hid1.Act.Avg),
        ("Hid2 Act.Avg", hid2.Act.Avg),
        ("Out Act.Avg", out.Act.Avg),
    )
    for colNm, val in cells:
        dt.SetCellFloat(colNm, cyc, val)

    if ss.TstCycPlot != 0 and cyc % 10 == 0:  # too slow to do every cyc
        # note: essential to use Go version of update when called from another goroutine
        ss.TstCycPlot.GoUpdate()
1063 |
def ConfigTstCycLog(ss, dt):
    """
    ConfigTstCycLog prepares dt as the per-cycle test log: sets its metadata
    and schema (an int Cycle column plus float Ge.Avg / Act.Avg columns for
    Hid1, Hid2 and Out) and preallocates 100 rows, one per cycle.
    """
    dt.SetMetaData("name", "TstCycLog")
    dt.SetMetaData("desc", "Record of activity etc over one trial by cycle")
    dt.SetMetaData("read-only", "true")
    dt.SetMetaData("precision", str(LogPrec))
    maxCycles = 100  # rows allocated up front; one per cycle

    floatCols = [
        "Hid1 Ge.Avg", "Hid2 Ge.Avg", "Out Ge.Avg",
        "Hid1 Act.Avg", "Hid2 Act.Avg", "Out Act.Avg",
    ]
    sc = etable.Schema()
    sc.append(etable.Column("Cycle", etensor.INT64, nilInts, nilStrs))
    for name in floatCols:
        sc.append(etable.Column(name, etensor.FLOAT64, nilInts, nilStrs))
    dt.SetFromSchema(sc, maxCycles)
1086 |
def ConfigTstCycPlot(ss, plt, dt):
    """
    ConfigTstCycPlot points plt at dt (the per-cycle test log) with Cycle as
    the x axis. Every Ge.Avg / Act.Avg column is shown with a fixed 0..0.5
    range; the Cycle column itself is not plotted. Returns plt.
    """
    plt.Params.Title = "Leabra Random Associator 25 Test Cycle Plot"
    plt.Params.XAxisCol = "Cycle"
    plt.SetTable(dt)

    # SetColParams args: column, on, fixMin, min, fixMax, max
    plt.SetColParams("Cycle", False, True, 0, False, 0)
    for col in ("Hid1 Ge.Avg", "Hid2 Ge.Avg", "Out Ge.Avg",
                "Hid1 Act.Avg", "Hid2 Act.Avg", "Out Act.Avg"):
        plt.SetColParams(col, True, True, 0, True, 0.5)
    return plt
1100 |
1101 | #############################
1102 | # RunLog
1103 |
def LogRun(ss, dt):
    """
    LogRun appends one row to dt (the run-level log) summarizing the current
    training run.

    The row holds the run number, the param-set name from ss.RunName(),
    ss.FirstZero, and the means of SSE / AvgSSE / PctErr / PctCor / CosDiff
    over the last (up to) 10 epochs of ss.TrnEpcLog. Afterwards it
    recomputes ss.RunStats (split.Desc of FirstZero and PctCor grouped by
    Params), refreshes ss.RunPlot if set, and appends the row to ss.RunFile
    if set (writing the CSV headers first when this is row 0).
    """
    run = ss.TrainEnv.Run.Cur  # this is NOT triggered by increment yet -- use Cur
    row = dt.Rows
    # grow the table actually being written (previously ss.RunLog was grown,
    # which only worked when dt happened to be ss.RunLog)
    dt.SetNumRows(row + 1)

    # compute means over the last nlast epochs for run level; clamp nlast to
    # the number of logged epochs so the window start is never negative, and
    # slice from Len()-nlast so exactly nlast epochs are averaged
    epcix = etable.NewIdxView(ss.TrnEpcLog)
    nlast = min(10, epcix.Len())
    epcix.Idxs = go.Slice_int(epcix.Idxs[epcix.Len()-nlast:])

    params = ss.RunName()  # includes tag

    dt.SetCellFloat("Run", row, run)
    dt.SetCellString("Params", row, params)
    dt.SetCellFloat("FirstZero", row, ss.FirstZero)
    for col in ("SSE", "AvgSSE", "PctErr", "PctCor", "CosDiff"):
        dt.SetCellFloat(col, row, agg.Mean(epcix, col)[0])

    runix = etable.NewIdxView(dt)
    spl = split.GroupBy(runix, go.Slice_string(["Params"]))
    split.Desc(spl, "FirstZero")
    split.Desc(spl, "PctCor")
    ss.RunStats = spl.AggsToTable(False)

    # note: essential to use Go version of update when called from another goroutine
    if ss.RunPlot != 0:
        ss.RunPlot.GoUpdate()

    if ss.RunFile != 0:
        if row == 0:
            dt.WriteCSVHeaders(ss.RunFile, etable.Tab)
        dt.WriteCSVRow(ss.RunFile, row, etable.Tab)
1141 |
def ConfigRunLog(ss, dt):
    """
    ConfigRunLog prepares dt as the run-level log, one row per finished run:
    sets the table metadata and a schema of Run (int), Params (string) and
    the FirstZero / error / cosine summary stats (float), with no rows yet.
    """
    meta = [
        ("name", "RunLog"),
        ("desc", "Record of performance at end of training"),
        ("read-only", "true"),
        ("precision", str(LogPrec)),
    ]
    for key, val in meta:
        dt.SetMetaData(key, val)

    sc = etable.Schema()
    sc.append(etable.Column("Run", etensor.INT64, nilInts, nilStrs))
    sc.append(etable.Column("Params", etensor.STRING, nilInts, nilStrs))
    for name in ("FirstZero", "SSE", "AvgSSE", "PctErr", "PctCor", "CosDiff"):
        sc.append(etable.Column(name, etensor.FLOAT64, nilInts, nilStrs))
    dt.SetFromSchema(sc, 0)
1158 |
def ConfigRunPlot(ss, plt, dt):
    """
    ConfigRunPlot points plt at dt (the run-level log) with Run as the x axis
    and sets per-column display params; only FirstZero is shown by default.
    Returns plt.
    """
    plt.Params.Title = "Leabra Random Associator 25 Run Plot"
    plt.Params.XAxisCol = "Run"
    plt.SetTable(dt)

    # (column, on, fixMin, min, fixMax, max)
    unbounded = (True, 0, False, 0)
    unitRange = (True, 0, True, 1)
    settings = [
        ("Run", False) + unbounded,
        ("FirstZero", True) + unbounded,  # default plot
        ("SSE", False) + unbounded,
        ("AvgSSE", False) + unbounded,
        ("PctErr", False) + unitRange,
        ("PctCor", False) + unitRange,
        ("CosDiff", False) + unitRange,
    ]
    for colSetting in settings:
        plt.SetColParams(*colSetting)
    return plt
1172 |
1173 | ##############################
1174 | # ConfigGui
1175 |
def ConfigGui(ss):
    """
    ConfigGui configures the GoGi gui interface for this simulation.

    Builds a 1600x1200 "ra25" main window holding a toolbar and a split view
    (Dim X, 30/70 split). The first pane is a ClassView of the Sim itself;
    the second is a tab view with the NetView and the TrnEpc, TstTrl, TstCyc,
    TstEpc and Run plots. Adds the toolbar actions and the app / Edit menu
    items, stores the created widgets on ss (Win, vp, ToolBar, ClassView,
    NetView and the *Plot fields), then starts the window event loop.
    """
    width = 1600
    height = 1200

    gi.SetAppName("ra25")
    gi.SetAppAbout(
        'This demonstrates a basic Leabra model. See emergent on GitHub.')

    win = gi.NewMainWindow(
        "ra25", "Leabra Random Associator", width, height)
    ss.Win = win

    vp = win.WinViewport2D()
    ss.vp = vp
    # everything below is built between UpdateStart and the matching
    # vp.UpdateEndNoSig(updt) at the end
    updt = vp.UpdateStart()

    mfr = win.SetMainFrame()

    tbar = gi.AddNewToolBar(mfr, "tbar")
    tbar.SetStretchMaxWidth()
    ss.ToolBar = tbar

    splitView = gi.AddNewSplitView(mfr, "split")
    splitView.Dim = gi.X
    splitView.SetStretchMaxWidth()
    splitView.SetStretchMaxHeight()

    # first pane: a ClassView bound to the Sim itself
    ss.ClassView = epygiv.ClassView("ra25sv", ss.Tags)
    ss.ClassView.AddFrame(splitView)
    ss.ClassView.SetClass(ss)

    # second pane: tabs for the network view and every plot
    tabs = gi.AddNewTabView(splitView, "tv")

    nv = netview.NetView()
    tabs.AddTab(nv, "NetView")
    nv.Var = "Act"
    nv.SetNet(ss.Net)
    ss.NetView = nv

    def addPlotTab(tabName, configure, log):
        # new Plot2D in its own tab, configured on log by the given method
        plt = eplot.Plot2D()
        tabs.AddTab(plt, tabName)
        return configure(plt, log)

    ss.TrnEpcPlot = addPlotTab("TrnEpcPlot", ss.ConfigTrnEpcPlot, ss.TrnEpcLog)
    ss.TstTrlPlot = addPlotTab("TstTrlPlot", ss.ConfigTstTrlPlot, ss.TstTrlLog)
    ss.TstCycPlot = addPlotTab("TstCycPlot", ss.ConfigTstCycPlot, ss.TstCycLog)
    ss.TstEpcPlot = addPlotTab("TstEpcPlot", ss.ConfigTstEpcPlot, ss.TstEpcLog)
    ss.RunPlot = addPlotTab("RunPlot", ss.ConfigRunPlot, ss.RunLog)

    splitView.SetSplitsList(go.Slice_float32([.3, .7]))

    recv = win.This()

    # toolbar layout, in order: a str entry adds a named separator; an
    # (ActOpts keyword args, callback) pair adds an action
    toolbar = [
        (dict(Label="Init", Icon="update",
              Tooltip="Initialize everything including network weights, and start over. Also applies current params.",
              UpdateFunc=UpdtFuncNotRunning), InitCB),
        (dict(Label="Train", Icon="run",
              Tooltip="Starts the network training, picking up from wherever it may have left off. If not stopped, training will complete the specified number of Runs through the full number of Epochs of training, with testing automatically occuring at the specified interval.",
              UpdateFunc=UpdtFuncNotRunning), TrainCB),
        (dict(Label="Stop", Icon="stop",
              Tooltip="Interrupts running. Hitting Train again will pick back up where it left off.",
              UpdateFunc=UpdtFuncRunning), StopCB),
        (dict(Label="Step Trial", Icon="step-fwd",
              Tooltip="Advances one training trial at a time.",
              UpdateFunc=UpdtFuncNotRunning), StepTrialCB),
        (dict(Label="Step Epoch", Icon="fast-fwd",
              Tooltip="Advances one epoch (complete set of training patterns) at a time.",
              UpdateFunc=UpdtFuncNotRunning), StepEpochCB),
        (dict(Label="Step Run", Icon="fast-fwd",
              Tooltip="Advances one full training Run at a time.",
              UpdateFunc=UpdtFuncNotRunning), StepRunCB),
        "test",
        (dict(Label="Test Trial", Icon="step-fwd",
              Tooltip="Runs the next testing trial.",
              UpdateFunc=UpdtFuncNotRunning), TestTrialCB),
        (dict(Label="Test Item", Icon="step-fwd",
              Tooltip="Prompts for a specific input pattern name to run, and runs it in testing mode.",
              UpdateFunc=UpdtFuncNotRunning), TestItemCB),
        (dict(Label="Test All", Icon="fast-fwd",
              Tooltip="Tests all of the testing trials.",
              UpdateFunc=UpdtFuncNotRunning), TestAllCB),
        "log",
        (dict(Label="Reset RunLog", Icon="reset",
              Tooltip="Resets the accumulated log of all Runs, which are tagged with the ParamSet used"),
         ResetRunLogCB),
        "misc",
        (dict(Label="New Seed", Icon="new",
              Tooltip="Generate a new initial random seed to get different results. By default, Init re-establishes the same initial seed every time."),
         NewRndSeedCB),
        (dict(Label="README", Icon="file-markdown",
              Tooltip="Opens your browser on the README file that contains instructions for how to run this model."),
         ReadmeCB),
    ]
    for entry in toolbar:
        if isinstance(entry, str):
            tbar.AddSeparator(entry)
            continue
        opts, callback = entry
        tbar.AddAction(gi.ActOpts(**opts), recv, callback)

    # main menu: the app and Edit menus get their standard items; "File" and
    # "Window" are created by ConfigMenus but no actions are added to them here
    appnm = gi.AppName()
    mmen = win.MainMenu
    mmen.ConfigMenus(go.Slice_string([appnm, "File", "Edit", "Window"]))

    gi.Action(mmen.ChildByName(appnm, 0)).Menu.AddAppMenu(win)
    gi.Action(mmen.ChildByName("Edit", 1)).Menu.AddCopyCutPaste(win)

    win.MainMenuUpdated()
    vp.UpdateEndNoSig(updt)
    win.GoStartEventLoop()
1313 |
1314 |
# TheSim is the overall state for this simulation: the single module-level
# Sim instance that main() configures, initializes and then trains or shows
# in the gui.
TheSim = Sim()
1317 |
1318 |
def usage():
    """
    usage prints the command-line options for non-gui (batch) runs to stdout,
    one option per line, using sys.argv[0] as the program name.
    """
    prog = sys.argv[0]
    helpLines = (
        prog + " --params= --tag= --setparams --wts --epclog=0 --runlog=0 --nogui",
        "\t pyleabra -i %s to run in interactive, gui mode" % prog,
        "\t --params= additional params to apply on top of Base (name must be in loaded Params",
        "\t --tag= tag is appended to file names to uniquely identify this run",
        "\t --runs= number of runs to do",
        "\t --setparams show the parameter values that are set",
        "\t --wts save final trained weights after every run",
        "\t --epclog=0/False turn off save training epoch log data to file named by param set, tag",
        "\t --runlog=0/False turn off save run log data to file named by param set, tag",
        "\t --nogui if no other args needed, this prevents running under the gui",
    )
    print("\n".join(helpLines))
1330 |
1331 |
def main(argv):
    """
    main is the script entry point. It configures TheSim, parses argv (the
    command-line args after the program name), initializes the sim, and then
    either trains in non-gui mode (opening the epoch / run log files unless
    --epclog / --runlog turn them off) or builds the gui.

    Passing any argument selects non-gui mode; --nogui does so on its own
    when no other args are needed. -h prints usage and exits; unknown
    options or a non-integer --runs value print usage and exit with status 2.
    """
    TheSim.Config()

    # argv already excludes the program name, so any argument at all means
    # batch mode (the old `> 1` test needed two args)
    TheSim.NoGui = len(argv) > 0
    saveEpcLog = True
    saveRunLog = True

    try:
        # "h" rather than "h:": -h is a plain flag and takes no value
        opts, _ = getopt.getopt(argv, "h", [
            "params=", "tag=", "runs=", "setparams", "wts", "epclog=", "runlog=", "nogui"])
    except getopt.GetoptError:
        usage()
        sys.exit(2)
    for opt, arg in opts:
        if opt == '-h':
            usage()
            sys.exit()
        elif opt == "--params":
            # NOTE(review): assumes Sim.ParamSet holds the name of the extra
            # param set applied on top of Base -- confirm against Sim
            TheSim.ParamSet = arg
        elif opt == "--tag":
            TheSim.Tag = arg
        elif opt == "--runs":
            try:
                TheSim.MaxRuns = int(arg)
            except ValueError:
                usage()
                sys.exit(2)
            print("Running %d runs" % TheSim.MaxRuns)
        elif opt == "--setparams":
            TheSim.LogSetParams = True
        elif opt == "--wts":
            TheSim.SaveWts = True
            print("Saving final weights per run")
        elif opt == "--epclog":
            if arg.lower() in ("false", "0"):
                saveEpcLog = False
        elif opt == "--runlog":
            if arg.lower() in ("false", "0"):
                saveRunLog = False
        elif opt == "--nogui":
            TheSim.NoGui = True

    TheSim.Init()

    if TheSim.NoGui:
        if saveEpcLog:
            fnm = TheSim.LogFileName("epc")
            print("Saving epoch log to: %s" % fnm)
            TheSim.TrnEpcFile = efile.Create(fnm)

        if saveRunLog:
            fnm = TheSim.LogFileName("run")
            print("Saving run log to: %s" % fnm)
            TheSim.RunFile = efile.Create(fnm)

        TheSim.Train()
    else:
        TheSim.ConfigGui()
        print("Note: run pyleabra -i ra25.py to run in interactive mode, or just pyleabra, then 'import ra25'")
        print("for non-gui background running, here are the args:")
        usage()
1389 |
1390 |
1391 | main(sys.argv[1:])
1392 |
--------------------------------------------------------------------------------
/version.go:
--------------------------------------------------------------------------------
1 | // WARNING: auto-generated by Makefile release target -- run 'make release' to update
2 |
3 | package main
4 |
// Release metadata, stamped by the Makefile release target.
const (
	// Version is the release version string.
	Version = "v0.4.1"

	// GitCommit identifies the git commit just before the release.
	GitCommit = "cd0e95d" // the commit JUST BEFORE the release

	// VersionDate is the release timestamp, in UTC.
	VersionDate = "2021-09-25 19:54" // UTC
)
10 |
--------------------------------------------------------------------------------