├── vendor
├── github.com
│ ├── xyproto
│ │ ├── tinysvg
│ │ │ ├── go.sum
│ │ │ ├── go.mod
│ │ │ ├── .gitignore
│ │ │ ├── README.md
│ │ │ ├── tags.go
│ │ │ └── tinysvg.go
│ │ ├── swish
│ │ │ ├── benchmark.sh
│ │ │ ├── .travis.yml
│ │ │ ├── swish_asm.go
│ │ │ ├── go.mod
│ │ │ ├── .gitignore
│ │ │ ├── go.sum
│ │ │ ├── swish_amd64.s
│ │ │ ├── LICENSE
│ │ │ ├── swish.go
│ │ │ └── README.md
│ │ └── af
│ │ │ ├── .travis.yml
│ │ │ ├── go.mod
│ │ │ ├── .gitignore
│ │ │ ├── go.sum
│ │ │ ├── README.md
│ │ │ └── af.go
│ └── dave
│ │ └── jennifer
│ │ └── jen
│ │ ├── generics.go
│ │ ├── add.go
│ │ ├── do.go
│ │ ├── reserved.go
│ │ ├── tag.go
│ │ ├── dict.go
│ │ ├── custom.go
│ │ ├── statement.go
│ │ ├── comments.go
│ │ ├── group.go
│ │ ├── jen.go
│ │ ├── lit.go
│ │ ├── file.go
│ │ ├── tokens.go
│ │ └── hints.go
└── modules.txt
├── cmd
├── simple
│ └── main.go
├── evolve
│ └── main.go
└── statement
│ └── main.go
├── go.mod
├── .gitignore
├── mnist
└── download_extract.sh
├── TODO.md
├── diagram_test.go
├── img
├── diagram.svg
├── evolved.svg
├── best.svg
├── result.svg
├── test.svg
├── before.svg
├── after.svg
└── labels.svg
├── combine.go
├── teststatement.txt
├── af_test.go
├── utils.go
├── norm.go
├── .github
└── workflows
│ └── test.yml
├── LICENSE
├── config.go
├── neuron_test.go
├── statement_test.go
├── go.sum
├── README.md
├── diagram.go
├── af.go
├── network_test.go
├── neuron.go
├── evolve.go
└── network.go
/vendor/github.com/xyproto/tinysvg/go.sum:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/vendor/github.com/dave/jennifer/jen/generics.go:
--------------------------------------------------------------------------------
1 | package jen
2 |
--------------------------------------------------------------------------------
/vendor/github.com/xyproto/swish/benchmark.sh:
--------------------------------------------------------------------------------
#!/bin/sh
# Run every benchmark in this package (no regular-test filtering is applied).
exec go test -bench .
3 |
--------------------------------------------------------------------------------
/vendor/github.com/xyproto/tinysvg/go.mod:
--------------------------------------------------------------------------------
1 | module github.com/xyproto/tinysvg
2 |
3 | go 1.9
4 |
--------------------------------------------------------------------------------
/vendor/github.com/xyproto/af/.travis.yml:
--------------------------------------------------------------------------------
1 | language: go
2 |
3 | go:
4 | - "1.11"
5 | - "1.12"
6 | - "1.13"
7 |
--------------------------------------------------------------------------------
/vendor/github.com/xyproto/swish/.travis.yml:
--------------------------------------------------------------------------------
1 | language: go
2 |
3 | go:
4 | - "1.11"
5 | - "1.12"
6 | - "1.13"
7 |
--------------------------------------------------------------------------------
/vendor/github.com/xyproto/af/go.mod:
--------------------------------------------------------------------------------
1 | module github.com/xyproto/af
2 |
3 | go 1.11
4 |
5 | require github.com/xyproto/swish v1.3.0
6 |
--------------------------------------------------------------------------------
/cmd/simple/main.go:
--------------------------------------------------------------------------------
1 | package main
2 |
3 | import (
4 | "fmt"
5 |
6 | "github.com/xyproto/wann"
7 | )
8 |
9 | func main() {
10 | fmt.Println(wann.NewNetwork())
11 | }
12 |
--------------------------------------------------------------------------------
/vendor/github.com/xyproto/tinysvg/.gitignore:
--------------------------------------------------------------------------------
1 | *
2 | !*.*
3 | !*/
4 | *.o
5 | *.swp
6 | *.tmp
7 | *.bak
8 | *.pro.user
9 | *CMakeFiles*
10 | *.dblite
11 | *.so
12 | .vscode
13 | *.svg
14 |
--------------------------------------------------------------------------------
/go.mod:
--------------------------------------------------------------------------------
1 | module github.com/xyproto/wann
2 |
3 | go 1.11
4 |
5 | require (
6 | github.com/dave/jennifer v1.5.0
7 | github.com/xyproto/af v0.0.0-20191018214415-1a8887381bd3
8 | github.com/xyproto/tinysvg v1.0.1
9 | )
10 |
--------------------------------------------------------------------------------
/vendor/github.com/xyproto/swish/swish_asm.go:
--------------------------------------------------------------------------------
1 | //+build amd64
2 |
3 | package swish
4 |
// SwishAssembly is the swish function, written in hand-optimized assembly.
// It computes x / (1 + exp256(-x)), the same approximation that Swish uses,
// and its body lives in swish_amd64.s.
//
// The directive below must be written without a space after "//";
// "// go: noescape" is only an ordinary comment and is ignored by the compiler.
//
//go:noescape
func SwishAssembly(x float64) float64
8 |
--------------------------------------------------------------------------------
/vendor/github.com/xyproto/swish/go.mod:
--------------------------------------------------------------------------------
1 | module github.com/xyproto/swish
2 |
3 | go 1.11
4 |
5 | require (
6 | github.com/buger/goterm v0.0.0-20181115115552-c206103e1f37
7 | golang.org/x/sys v0.0.0-20190613101156-ab3f67ed278a // indirect
8 | )
9 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | *
2 | !*.*
3 | !*/
4 | *.o
5 | *.swp
6 | *.tmp
7 | *.bak
8 | *.pro.user
9 | *.dblite
10 | *.so
11 | *.swo
12 | *.ao
13 | *.pyc
14 | *.orig
15 | *.pb.go
16 | *~
17 | ._*
18 | *.nfs.*
19 | .vscode
20 | *CMakeFiles*
21 | _obj
22 | _test
23 | _testmain.go
24 | !img/*
25 | network.svg
26 |
--------------------------------------------------------------------------------
/vendor/modules.txt:
--------------------------------------------------------------------------------
1 | # github.com/dave/jennifer v1.5.0
2 | github.com/dave/jennifer/jen
3 | # github.com/xyproto/af v0.0.0-20191018214415-1a8887381bd3
4 | github.com/xyproto/af
5 | # github.com/xyproto/swish v1.3.0
6 | github.com/xyproto/swish
7 | # github.com/xyproto/tinysvg v1.0.1
8 | github.com/xyproto/tinysvg
9 |
--------------------------------------------------------------------------------
/vendor/github.com/xyproto/swish/.gitignore:
--------------------------------------------------------------------------------
1 | # Binaries for programs and plugins
2 | *.exe
3 | *.exe~
4 | *.dll
5 | *.so
6 | *.dylib
7 |
8 | # Test binary, build with `go test -c`
9 | *.test
10 |
11 | # Output of the go coverage tool, specifically when used with LiteIDE
12 | *.out
13 |
14 | cmd/graph/graph
15 | cmd/precision/precision
16 |
--------------------------------------------------------------------------------
/vendor/github.com/xyproto/af/.gitignore:
--------------------------------------------------------------------------------
1 | # Binaries for programs and plugins
2 | *.exe
3 | *.dll
4 | *.so
5 | *.dylib
6 |
7 | # Test binary, build with `go test -c`
8 | *.test
9 |
10 | # Output of the go coverage tool, specifically when used with LiteIDE
11 | *.out
12 |
13 | # Project-local glide cache, RE: https://github.com/Masterminds/glide/issues/736
14 | .glide/
15 |
16 | .DS_Store
17 |
--------------------------------------------------------------------------------
/vendor/github.com/xyproto/af/go.sum:
--------------------------------------------------------------------------------
1 | github.com/buger/goterm v0.0.0-20181115115552-c206103e1f37/go.mod h1:u9UyCz2eTrSGy6fbupqJ54eY5c4IC8gREQ1053dK12U=
2 | github.com/xyproto/swish v1.3.0 h1:qiVl2UeqkMqDJKjRLVwRprD7Vzz2pbNRcWTb3s0TpjE=
3 | github.com/xyproto/swish v1.3.0/go.mod h1:IVhz2R80pNsPaSxbEcqR84i4MVL0b06HN7KIMhv38WA=
4 | golang.org/x/sys v0.0.0-20190613101156-ab3f67ed278a/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
5 |
--------------------------------------------------------------------------------
/vendor/github.com/xyproto/swish/go.sum:
--------------------------------------------------------------------------------
1 | github.com/buger/goterm v0.0.0-20181115115552-c206103e1f37 h1:uxxtrnACqI9zK4ENDMf0WpXfUsHP5V8liuq5QdgDISU=
2 | github.com/buger/goterm v0.0.0-20181115115552-c206103e1f37/go.mod h1:u9UyCz2eTrSGy6fbupqJ54eY5c4IC8gREQ1053dK12U=
3 | golang.org/x/sys v0.0.0-20190613101156-ab3f67ed278a h1:sPlwkA5W19gtxRApEyGyqWg4ngTrMzOJ43fOsWrgYEE=
4 | golang.org/x/sys v0.0.0-20190613101156-ab3f67ed278a/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
5 |
--------------------------------------------------------------------------------
/mnist/download_extract.sh:
--------------------------------------------------------------------------------
#!/bin/sh
#
# Download the four MNIST data files into the current directory and
# decompress them.
#
# Set MNIST_BASE_URL to use a different host (for example a mirror) than the
# default one below.

# Stop at the first failing command instead of carrying on with missing files.
set -e

base_url="${MNIST_BASE_URL:-http://yann.lecun.com/exdb/mnist}"

# -f: fail on HTTP errors instead of saving the error page as a .gz file
# -L: follow redirects
# -O: keep the remote file name
echo 'downloading training set images'
curl -fLO "$base_url/train-images-idx3-ubyte.gz"

echo 'downloading training set labels'
curl -fLO "$base_url/train-labels-idx1-ubyte.gz"

echo 'downloading test set images'
curl -fLO "$base_url/t10k-images-idx3-ubyte.gz"

echo 'downloading test set labels'
curl -fLO "$base_url/t10k-labels-idx1-ubyte.gz"

echo 'extracting'
gunzip -v *.gz
17 |
--------------------------------------------------------------------------------
/vendor/github.com/dave/jennifer/jen/add.go:
--------------------------------------------------------------------------------
1 | package jen
2 |
3 | // Add appends the provided items to the statement.
4 | func Add(code ...Code) *Statement {
5 | return newStatement().Add(code...)
6 | }
7 |
8 | // Add appends the provided items to the statement.
9 | func (g *Group) Add(code ...Code) *Statement {
10 | s := Add(code...)
11 | g.items = append(g.items, s)
12 | return s
13 | }
14 |
15 | // Add appends the provided items to the statement.
16 | func (s *Statement) Add(code ...Code) *Statement {
17 | *s = append(*s, code...)
18 | return s
19 | }
20 |
--------------------------------------------------------------------------------
/TODO.md:
--------------------------------------------------------------------------------
1 | # TODO
2 |
3 | - [ ] Train and test with the Mnist dataset.
4 | - [ ] Fix any remaining issues with drawing SVG diagrams.
5 | - [ ] Fix any remaining issues with generating expressions.
6 | - [ ] Draw an "O" on the output node in the diagram.
7 | - [ ] Fix an issue with mutating the Network structs (the Neurons need to be mutated too, and there are pointers to them everywhere).
8 | - [ ] Store all neurons within the network, but keep pointers to the input nodes (and possibly to the output node).
9 | Then all those neurons can be modified when the Network mutates, but the pointers can be kept the same.
10 |
--------------------------------------------------------------------------------
/vendor/github.com/xyproto/tinysvg/README.md:
--------------------------------------------------------------------------------
1 | # tinysvg [](https://goreportcard.com/report/github.com/xyproto/tinysvg) [](https://godoc.org/github.com/xyproto/tinysvg)
2 |
3 | Construct SVG documents and images using Go.
4 |
5 | This package mainly uses `[]byte` slices instead of strings, and does not indent the generated SVG data, for performance and compactness.
6 |
7 | ## General info
8 |
9 | * Version: 1.0.1
10 | * Author: Alexander F. Rødseth <xyproto@archlinux.org>
11 | * License: MIT
12 |
--------------------------------------------------------------------------------
/vendor/github.com/dave/jennifer/jen/do.go:
--------------------------------------------------------------------------------
1 | package jen
2 |
3 | // Do calls the provided function with the statement as a parameter. Use for
4 | // embedding logic.
5 | func Do(f func(*Statement)) *Statement {
6 | return newStatement().Do(f)
7 | }
8 |
9 | // Do calls the provided function with the statement as a parameter. Use for
10 | // embedding logic.
11 | func (g *Group) Do(f func(*Statement)) *Statement {
12 | s := Do(f)
13 | g.items = append(g.items, s)
14 | return s
15 | }
16 |
17 | // Do calls the provided function with the statement as a parameter. Use for
18 | // embedding logic.
19 | func (s *Statement) Do(f func(*Statement)) *Statement {
20 | f(s)
21 | return s
22 | }
23 |
--------------------------------------------------------------------------------
/vendor/github.com/xyproto/swish/swish_amd64.s:
--------------------------------------------------------------------------------
#include "textflag.h"

// Read-only constants: 1.0 at offset 0 and -1/256 at offset 8.
DATA expodata<>+0(SB)/8, $1.0
DATA expodata<>+8(SB)/8, $-0.00390625
GLOBL expodata<>+0(SB), RODATA, $16

// func SwishAssembly(x float64) float64
// Computes x / (1 + (1 - x/256)^256), i.e. x / (1 + exp256(-x)).
// Frame: no locals, 16 bytes of argument+result (x at 0(FP), ret at 8(FP)).
TEXT ·SwishAssembly(SB),NOSPLIT|NOPTR,$0-16
// x+0(FP) is the given argument; X1 is used for the exp256 approximation
// and X2 keeps x as the numerator
MOVSD x+0(FP), X1
MOVSD x+0(FP), X2
// x1 *= -0.00390625, which is -(1/256)
MULSD expodata<>+8(SB), X1
// x1 += 1.0 (X3 keeps 1.0 for reuse below)
MOVSD expodata<>+0(SB), X3
ADDSD X3, X1
// square x1 eight times: x1 = x1^256, approximating exp(-x)
MULSD X1, X1
MULSD X1, X1
MULSD X1, X1
MULSD X1, X1
MULSD X1, X1
MULSD X1, X1
MULSD X1, X1
MULSD X1, X1
// x1 += 1.0
ADDSD X3, X1
// x2 /= x1
DIVSD X1, X2
// return x2
MOVSD X2, ret+8(FP)
// done, jump back
RET
33 |
--------------------------------------------------------------------------------
/diagram_test.go:
--------------------------------------------------------------------------------
1 | package wann
2 |
3 | import (
4 | "math/rand"
5 | "os"
6 | "testing"
7 | )
8 |
9 | func TestDiagram(t *testing.T) {
10 | rand.Seed(commonSeed)
11 | net := NewNetwork(&Config{
12 | inputs: 5,
13 | InitialConnectionRatio: 0.5,
14 | sharedWeight: 0.5,
15 | })
16 |
17 | // Set a few activation functions
18 | net.AllNodes[net.InputNodes[0]].ActivationFunction = Linear
19 | net.AllNodes[net.InputNodes[1]].ActivationFunction = Swish
20 | net.AllNodes[net.InputNodes[2]].ActivationFunction = Gauss
21 | net.AllNodes[net.InputNodes[3]].ActivationFunction = Sigmoid
22 | net.AllNodes[net.InputNodes[4]].ActivationFunction = ReLU
23 |
24 | // Save the diagram as an image
25 | err := net.WriteSVG("test.svg")
26 | if err != nil {
27 | t.Error(err)
28 | }
29 | os.Remove("test.svg")
30 | }
31 |
--------------------------------------------------------------------------------
/img/diagram.svg:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/combine.go:
--------------------------------------------------------------------------------
1 | package wann
2 |
3 | import (
4 | "github.com/dave/jennifer/jen"
5 | )
6 |
7 | // ActivationStatement creates an activation function statment, given a weight and input statements
8 | // returns: activationFunction(input0 * w + input1 * w + ...)
9 | // The function calling this function is responsible for inserting network input values into the network input nodes.
10 | func ActivationStatement(af ActivationFunctionIndex, w float64, inputStatements []*jen.Statement) *jen.Statement {
11 | // activationFunction(input0 * w + input1 * w + ...)
12 | weightedSum := jen.Empty()
13 | for i, inputStatement := range inputStatements {
14 | if i == 0 {
15 | // first
16 | weightedSum.Add(inputStatement).Op("*").Lit(w)
17 | } else {
18 | // the rest, same as above, but with a leading "+"
19 | weightedSum.Op("+").Add(inputStatement).Op("*").Lit(w)
20 | }
21 | }
22 | return af.Statement(weightedSum)
23 | }
24 |
--------------------------------------------------------------------------------
/teststatement.txt:
--------------------------------------------------------------------------------
1 | package main
2 |
3 | import (
4 | "fmt"
5 | "math"
6 | )
7 |
// f is a scoring expression (it appears to be generated from a network).
// It applies a different function to each of inputData[0] through
// inputData[4], weights every term by the same constant, and negates the sum.
// inputData[5] is not read.
func f(inputData []float64) float64 {
	return -(math.Log(1.0+math.Exp(inputData[0]))*1.999999999999592 + math.Exp(-(math.Pow(inputData[1], 2.0))/2.0)*1.999999999999592 + math.Pow(inputData[2], 2.0)*1.999999999999592 + math.Sin((inputData[3])*math.Pi)*1.999999999999592 + math.Abs(inputData[4])*1.999999999999592)
}

// main prints the score f gives to each of four hard-coded 2x3 patterns.
func main() {
	up := []float64{
		0.0, 1.0, 0.0, //  o
		1.0, 1.0, 1.0} // ooo

	down := []float64{
		1.0, 1.0, 1.0, // ooo
		0.0, 1.0, 0.0} //  o

	left := []float64{
		1.0, 1.0, 1.0, // ooo
		0.0, 0.0, 1.0} //   o

	right := []float64{
		1.0, 1.0, 1.0, // ooo
		1.0, 0.0, 0.0} // o

	fmt.Println("up score", f(up))
	fmt.Println("down score", f(down))
	fmt.Println("left score", f(left))
	fmt.Println("right score", f(right))
}
34 |
35 |
--------------------------------------------------------------------------------
/vendor/github.com/dave/jennifer/jen/reserved.go:
--------------------------------------------------------------------------------
1 | package jen
2 |
var reserved = []string{
	/* keywords */
	"break", "default", "func", "interface", "select", "case", "defer", "go", "map", "struct", "chan", "else", "goto", "package", "switch", "const", "fallthrough", "if", "range", "type", "continue", "for", "import", "return", "var",
	/* predeclared */
	"bool", "byte", "complex64", "complex128", "error", "float32", "float64", "int", "int8", "int16", "int32", "int64", "rune", "string", "uint", "uint8", "uint16", "uint32", "uint64", "uintptr", "true", "false", "iota", "nil", "append", "cap", "close", "complex", "copy", "delete", "imag", "len", "make", "new", "panic", "print", "println", "real", "recover",
	/* common variables */
	"err",
}

// IsReservedWord reports whether alias is one of the words in the reserved
// list above (Go keywords, predeclared identifiers, and "err"). The
// comparison is exact and case-sensitive.
func IsReservedWord(alias string) bool {
	for i := 0; i < len(reserved); i++ {
		if reserved[i] == alias {
			return true
		}
	}
	return false
}
21 |
--------------------------------------------------------------------------------
/af_test.go:
--------------------------------------------------------------------------------
1 | package wann
2 |
3 | import (
4 | "fmt"
5 |
6 | "github.com/dave/jennifer/jen"
7 | )
8 |
9 | func ExampleActivationFunctionIndex_Call() {
10 | fmt.Println(Gauss.Call(2.0))
11 | // Output:
12 | // 0.13427659965015956
13 | }
14 |
15 | func ExampleGauss_Statement() {
16 | statement := Gauss.Statement(jen.Id("x"))
17 | fmt.Println(statement.GoString())
18 | result, err := RunStatementX(statement, 0.5)
19 | if err != nil {
20 | panic(err)
21 | }
22 | fmt.Println(result)
23 | // Output:
24 | // math.Exp(-(math.Pow(x, 2.0)) / 2.0)
25 | // 0.8824969025845955
26 | }
27 |
28 | func ExampleActivationFunctionIndex_GoRun() {
29 | // Run the Gauss function directly
30 | fmt.Println(ActivationFunctions[Gauss](0.5))
31 | // Use Jennifer to generate a source file just for running the Gauss function, then use "go run" and fetch the result
32 | if result, err := Gauss.GoRun(0.5); err == nil { // no error
33 | fmt.Println(result)
34 | }
35 | // Output:
36 | // 0.8824699625576026
37 | // 0.8824969025845955
38 | }
39 |
--------------------------------------------------------------------------------
/utils.go:
--------------------------------------------------------------------------------
1 | package wann
2 |
3 | import (
4 | "sort"
5 | )
6 |
// Pair is used for sorting dictionaries by value.
// Thanks https://stackoverflow.com/a/18695740/131264
type Pair struct {
	Key   int
	Value float64
}

// PairList is a slice of Pair. It implements sort.Interface, ordering by
// ascending Value.
type PairList []Pair

func (p PairList) Len() int           { return len(p) }
func (p PairList) Less(i, j int) bool { return p[i].Value < p[j].Value }
func (p PairList) Swap(i, j int)      { p[i], p[j] = p[j], p[i] }

// SortByValue returns the entries of m as a PairList sorted by value in
// descending order (highest value first). Entries with equal values are
// ordered by ascending key, so the result is deterministic even though map
// iteration order is randomized. A nil or empty map gives an empty list.
func SortByValue(m map[int]float64) PairList {
	pl := make(PairList, 0, len(m))
	for k, v := range m {
		pl = append(pl, Pair{Key: k, Value: v})
	}
	sort.Slice(pl, func(i, j int) bool {
		if pl[i].Value != pl[j].Value {
			return pl[i].Value > pl[j].Value
		}
		// Tie-break on the key; keys of a map are unique, so this is a total order.
		return pl[i].Key < pl[j].Key
	})
	return pl
}
32 |
33 | // In returns true if this NeuronIndex is in the given *[]NeuronIndex slice
34 | func (ni NeuronIndex) In(nodes *[]NeuronIndex) bool {
35 | for _, ni2 := range *nodes {
36 | if ni2 == ni {
37 | return true
38 | }
39 | }
40 | return false
41 | }
42 |
--------------------------------------------------------------------------------
/vendor/github.com/xyproto/swish/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2019 Alexander F. Rødseth
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/norm.go:
--------------------------------------------------------------------------------
1 | package wann
2 |
// NormalizationInfo contains if and how the score function should be normalized
type NormalizationInfo struct {
	shouldNormalize bool
	mul, add        float64
}

// NewNormalizationInfo returns a new struct, containing if and how the score function should be normalized.
// The multiplication and addition numbers start out as the identity (mul=1, add=0), so a score is left
// unchanged until Set is called, whether it is applied as x*mul+add or as (x+add)*mul.
func NewNormalizationInfo(enable bool) *NormalizationInfo {
	// Keyed fields on purpose: the former positional literal {enable, 0.0, 1.0}
	// assigned mul=0 and add=1, which maps every score to a constant.
	return &NormalizationInfo{
		shouldNormalize: enable,
		mul:             1.0,
		add:             0.0,
	}
}

// Enable signifies that normalization is enabled when this struct is used
func (norm *NormalizationInfo) Enable() {
	norm.shouldNormalize = true
}

// Disable signifies that normalization is disabled when this struct is used
func (norm *NormalizationInfo) Disable() {
	norm.shouldNormalize = false
}

// Get retrieves the multiplication and addition numbers that can be used for normalization,
// in that order (mul, add).
func (norm *NormalizationInfo) Get() (float64, float64) {
	return norm.mul, norm.add
}

// Set sets the multiplication and addition numbers that can be used for normalization
func (norm *NormalizationInfo) Set(mul, add float64) {
	norm.mul = mul
	norm.add = add
}
34 |
--------------------------------------------------------------------------------
/img/evolved.svg:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/vendor/github.com/xyproto/af/README.md:
--------------------------------------------------------------------------------
1 | # af [](https://travis-ci.org/xyproto/af) [](https://goreportcard.com/report/github.com/xyproto/af) [](https://godoc.org/github.com/xyproto/af)
2 |
3 | Activation functions for neural networks.
4 |
5 | These activation functions are included:
6 |
7 | * Swish (`x / (1 + exp(-x))`)
8 | * Sigmoid (`1 / (1 + exp(-x))`)
9 | * SoftPlus (`log(1 + exp(x))`)
10 | * Gaussian01 (`exp(-(x * x) / 2.0)`)
11 | * Sin (`math.Sin(math.Pi * x)`)
12 | * Cos (`math.Cos(math.Pi * x)`)
13 | * Linear (`x`)
14 | * Inv (`-x`)
15 | * ReLU (`x >= 0 ? x : 0`)
16 | * Squared (`x * x`)
17 |
18 | These `math` functions are included just for convenience:
19 |
20 | * Abs (`math.Abs`)
21 | * Tanh (`math.Tanh`)
22 |
23 | One function that takes two arguments is also included:
24 |
25 | * PReLU (`x >= 0 ? x : x * a`)
26 |
27 | ## Requirements
28 |
29 | * Go 1.11 or later.
30 |
31 | ## General information
32 |
33 | * License: MIT
34 | * Version: 0.3.2
35 | * Author: Alexander F. Rødseth <xyproto@archlinux.org>
36 |
--------------------------------------------------------------------------------
/.github/workflows/test.yml:
--------------------------------------------------------------------------------
on: [push, pull_request]
name: Build
env:
  GO111MODULE: on
jobs:
  test:
    strategy:
      matrix:
        go-version: [1.13.x, 1.14.x, 1.15.x, 1.16.x, 1.17.x]
        os: [ubuntu-latest, macos-latest]
    runs-on: ${{ matrix.os }}
    steps:
    - name: Install Go
      uses: actions/setup-go@v2
      with:
        go-version: ${{ matrix.go-version }}
    - name: Checkout code
      uses: actions/checkout@v2
    - name: Test
      run: go test ./...

  test-cache:
    runs-on: ubuntu-latest
    steps:
    - name: Install Go
      uses: actions/setup-go@v2
      with:
        go-version: 1.17.x
    - name: Checkout code
      uses: actions/checkout@v2
    # actions/cache v1/v2 have been retired; v4 is the supported major version.
    - uses: actions/cache@v4
      with:
        # Cached paths, in order:
        # * Module download cache
        # * Build cache (Linux)
        # * Build cache (Mac)
        # * Build cache (Windows)
        # Keep these comments outside the "|" block: inside a YAML block scalar
        # "#" is literal text and would become part of each path.
        path: |
          ~/go/pkg/mod
          ~/.cache/go-build
          ~/Library/Caches/go-build
          ~\AppData\Local\go-build
        key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }}
        restore-keys: |
          ${{ runner.os }}-go-
    - name: Test
      run: go test ./...
43 |
--------------------------------------------------------------------------------
/vendor/github.com/xyproto/af/af.go:
--------------------------------------------------------------------------------
1 | // Package af provides several activation functions that can be used in neural networks
2 | package af
3 |
4 | import (
5 | "github.com/xyproto/swish"
6 | "math"
7 | )
8 |
9 | // The swish package offers optimized Swish, Sigmoid
10 | // SoftPlus and Gaussian01 activation functions
11 | var (
12 | Sigmoid = swish.Sigmoid
13 | Swish = swish.Swish
14 | SoftPlus = swish.SoftPlus
15 | Gaussian01 = swish.Gaussian01
16 | Linear = func(x float64) float64 { return x }
17 | Inv = func(x float64) float64 { return -x }
18 | Sin = func(x float64) float64 { return math.Sin(math.Pi * x) }
19 | Cos = func(x float64) float64 { return math.Cos(math.Pi * x) }
20 | Squared = func(x float64) float64 { return x * x }
21 | Tanh = math.Tanh
22 | Abs = math.Abs
23 | )
24 |
// Step is the step function: 1 when x >= 0, otherwise 0 (including for NaN).
func Step(x float64) float64 {
	var y float64
	if x >= 0 {
		y = 1
	}
	return y
}

// ReLU is the "rectified linear unit"
// `x >= 0 ? x : 0`
func ReLU(x float64) float64 {
	result := 0.0
	if x >= 0 {
		result = x
	}
	return result
}

// PReLU is the parametric rectified linear unit.
// `x >= 0 ? x : a * x`
func PReLU(x, a float64) float64 {
	result := a * x
	if x >= 0 {
		result = x
	}
	return result
}
50 |
--------------------------------------------------------------------------------
/img/best.svg:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/img/result.svg:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Copyright 2022 Alexander F. Rødseth
2 |
3 | Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
4 |
5 | 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
6 |
7 | 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
8 |
9 | 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
10 |
11 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
12 |
--------------------------------------------------------------------------------
/vendor/github.com/dave/jennifer/jen/tag.go:
--------------------------------------------------------------------------------
1 | package jen
2 |
3 | import (
4 | "fmt"
5 | "io"
6 | "sort"
7 | "strconv"
8 | )
9 |
10 | // Tag renders a struct tag
11 | func Tag(items map[string]string) *Statement {
12 | return newStatement().Tag(items)
13 | }
14 |
15 | // Tag renders a struct tag
16 | func (g *Group) Tag(items map[string]string) *Statement {
17 | // notest
18 | // don't think this can ever be used in valid code?
19 | s := Tag(items)
20 | g.items = append(g.items, s)
21 | return s
22 | }
23 |
24 | // Tag renders a struct tag
25 | func (s *Statement) Tag(items map[string]string) *Statement {
26 | c := tag{
27 | items: items,
28 | }
29 | *s = append(*s, c)
30 | return s
31 | }
32 |
33 | type tag struct {
34 | items map[string]string
35 | }
36 |
37 | func (t tag) isNull(f *File) bool {
38 | return len(t.items) == 0
39 | }
40 |
41 | func (t tag) render(f *File, w io.Writer, s *Statement) error {
42 |
43 | if t.isNull(f) {
44 | // notest
45 | // render won't be called if t is null
46 | return nil
47 | }
48 |
49 | var str string
50 |
51 | var sorted []string
52 | for k := range t.items {
53 | sorted = append(sorted, k)
54 | }
55 | sort.Strings(sorted)
56 |
57 | for _, k := range sorted {
58 | v := t.items[k]
59 | if len(str) > 0 {
60 | str += " "
61 | }
62 | str += fmt.Sprintf(`%s:%q`, k, v)
63 | }
64 |
65 | if strconv.CanBackquote(str) {
66 | str = "`" + str + "`"
67 | } else {
68 | str = strconv.Quote(str)
69 | }
70 |
71 | if _, err := w.Write([]byte(str)); err != nil {
72 | return err
73 | }
74 |
75 | return nil
76 | }
77 |
--------------------------------------------------------------------------------
/vendor/github.com/xyproto/swish/swish.go:
--------------------------------------------------------------------------------
1 | package swish
2 |
3 | import "math"
4 |
// exp256 approximates math.Exp(x) as (1 + x/256)^256, computed with eight
// successive squarings.
// Thanks https://codingforspeed.com/using-faster-exponential-approximation/
func exp256(x float64) float64 {
	r := 1.0 + x/256.0
	for i := 0; i < 8; i++ {
		r *= r
	}
	return r
}

// Swish is the x / (1 + exp(-x)) activation function, using exp256
func Swish(x float64) float64 {
	denominator := 1.0 + exp256(-x)
	return x / denominator
}

// Sigmoid is the 1 / (1 + exp(-x)) activation function, using exp256
func Sigmoid(x float64) float64 {
	// exp256 stands in for math.Exp here, trading accuracy for speed
	denominator := 1.0 + exp256(-x)
	return 1.0 / denominator
}

// SoftPlus is the log(1 + exp(x)) function, using exp256
func SoftPlus(x float64) float64 {
	sum := 1.0 + exp256(x)
	return math.Log(sum)
}

// Gaussian01 is the Gaussian function with mean 0 and sigma 1, using exp256
func Gaussian01(x float64) float64 {
	sq := x * x
	return exp256(-sq / 2.0)
}
39 |
40 | // SwishPrecise is the x / (1 + exp(-x)) activation function, using math.Exp
41 | func SwishPrecise(x float64) float64 {
42 | return x / (1.0 + math.Exp(-x))
43 | }
44 |
// SigmoidPrecise is the logistic 1 / (1 + exp(-x)) activation function. It is
// computed with the exact math.Exp instead of the exp256 approximation.
func SigmoidPrecise(x float64) float64 {
	e := math.Exp(-x)
	return 1.0 / (e + 1.0)
}
49 |
// SoftPlusPrecise is the log(1 + exp(x)) function, using math.Exp.
//
// It is evaluated in the numerically stable form
// max(x, 0) + log1p(exp(-|x|)). This is mathematically identical, but:
//   - for large x the naive form overflows (math.Exp(x) is +Inf above about
//     709, so the result would be +Inf where the true value is about x);
//   - for very negative x, 1 + exp(x) rounds to exactly 1, so the naive form
//     returns 0 instead of the tiny positive value exp(x).
func SoftPlusPrecise(x float64) float64 {
	return math.Max(x, 0) + math.Log1p(math.Exp(-math.Abs(x)))
}
54 |
// Gaussian01Precise is the Gaussian function exp(-x²/2) with mean 0 and
// sigma 1, using math.Exp. It is not normalized by 1/sqrt(2π), so its peak
// value at x = 0 is 1.
func Gaussian01Precise(x float64) float64 {
	return math.Exp(-(x * x) / 2.0)
}
59 |
--------------------------------------------------------------------------------
/vendor/github.com/dave/jennifer/jen/dict.go:
--------------------------------------------------------------------------------
1 | package jen
2 |
3 | import (
4 | "bytes"
5 | "io"
6 | "sort"
7 | )
8 |
// Dict renders as key/value pairs. Use with Values for map or composite
// literals.
//
// When rendered, pairs whose key or value is null are skipped. The remaining
// pairs are written in order of each key's rendered source text, so the
// generated code does not depend on map iteration order.
type Dict map[Code]Code
12 |
13 | // DictFunc executes a func(Dict) to generate the value. Use with Values for
14 | // map or composite literals.
15 | func DictFunc(f func(Dict)) Dict {
16 | d := Dict{}
17 | f(d)
18 | return d
19 | }
20 |
21 | func (d Dict) render(f *File, w io.Writer, s *Statement) error {
22 | first := true
23 | // must order keys to ensure repeatable source
24 | type kv struct {
25 | k Code
26 | v Code
27 | }
28 | lookup := map[string]kv{}
29 | keys := []string{}
30 | for k, v := range d {
31 | if k.isNull(f) || v.isNull(f) {
32 | continue
33 | }
34 | buf := &bytes.Buffer{}
35 | if err := k.render(f, buf, nil); err != nil {
36 | return err
37 | }
38 | keys = append(keys, buf.String())
39 | lookup[buf.String()] = kv{k: k, v: v}
40 | }
41 | sort.Strings(keys)
42 | for _, key := range keys {
43 | k := lookup[key].k
44 | v := lookup[key].v
45 | if first && len(keys) > 1 {
46 | if _, err := w.Write([]byte("\n")); err != nil {
47 | return err
48 | }
49 | first = false
50 | }
51 | if err := k.render(f, w, nil); err != nil {
52 | return err
53 | }
54 | if _, err := w.Write([]byte(":")); err != nil {
55 | return err
56 | }
57 | if err := v.render(f, w, nil); err != nil {
58 | return err
59 | }
60 | if len(keys) > 1 {
61 | if _, err := w.Write([]byte(",\n")); err != nil {
62 | return err
63 | }
64 | }
65 | }
66 | return nil
67 | }
68 |
69 | func (d Dict) isNull(f *File) bool {
70 | if d == nil || len(d) == 0 {
71 | return true
72 | }
73 | for k, v := range d {
74 | if !k.isNull(f) && !v.isNull(f) {
75 | // if any of the key/value pairs are both not null, the Dict is not
76 | // null
77 | return false
78 | }
79 | }
80 | return true
81 | }
82 |
--------------------------------------------------------------------------------
/img/test.svg:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/config.go:
--------------------------------------------------------------------------------
1 | package wann
2 |
3 | import (
4 | "fmt"
5 | "math/rand"
6 | "time"
7 | )
8 |
// Config is a struct that is used when initializing new Network structs.
// The idea is that referring to fields by name is more explicit, and that it can
// be re-used in connection with having a configuration file, in the future.
//
// The lowercase fields (inputs, sharedWeight, initialized) are unexported and
// can only be set from inside this package.
type Config struct {
	// Number of input neurons (inputs per slice of floats in inputData in the Evolve function).
	// NOTE(review): callers outside the package cannot set this; presumably
	// Evolve derives it from the input data — confirm.
	inputs int
	// When initializing a network, this is the probability that a node will be connected to the output node
	InitialConnectionRatio float64
	// sharedWeight is the weight that is shared by all nodes, since this is a Weight Agnostic Neural Network
	sharedWeight float64
	// The maximum number of generations to train for
	Generations int
	// The population size to use per generation
	PopulationSize int
	// For how many generations should the training go on, without any improvement in the best score? Disabled if 0.
	MaxIterationsWithoutBestImprovement int
	// RandomSeed, for initializing the random number generator. The current time is used for the seed if this is set to 0.
	RandomSeed int64
	// Verbose output
	Verbose bool
	// Has the pseudo-random number generator been seeded and the activation function complexity been estimated yet?
	initialized bool
}
32 |
33 | // initialize the pseaudo-random number generator, either using the config.RandomSeed or the time
34 | func (config *Config) initRandom() {
35 | randomSeed := config.RandomSeed
36 | if config.RandomSeed == 0 {
37 | randomSeed = time.Now().UTC().UnixNano()
38 | }
39 | if config.Verbose {
40 | fmt.Println("Using random seed:", randomSeed)
41 | }
42 | // Initialize the pseudo-random number generator
43 | rand.Seed(randomSeed)
44 | }
45 |
// Init will initialize the pseudo-random number generator and estimate the complexity of the available activation functions.
// Afterwards the config is marked as initialized.
//
// The generator is seeded before estimateComplexity runs.
// NOTE(review): presumably so that any randomness used there is reproducible; confirm before reordering.
func (config *Config) Init() {
	config.initRandom()
	config.estimateComplexity()
	config.initialized = true
}
52 |
--------------------------------------------------------------------------------
/cmd/evolve/main.go:
--------------------------------------------------------------------------------
1 | package main
2 |
3 | import (
4 | "fmt"
5 | "os"
6 |
7 | "github.com/xyproto/wann"
8 | )
9 |
10 | func main() {
11 | // Here are four shapes, representing: up, down, left and right:
12 |
13 | up := []float64{
14 | 0.0, 1.0, 0.0, // o
15 | 1.0, 1.0, 1.0} // ooo
16 |
17 | down := []float64{
18 | 1.0, 1.0, 1.0, // ooo
19 | 0.0, 1.0, 0.0} // o
20 |
21 | left := []float64{
22 | 1.0, 1.0, 1.0, // ooo
23 | 0.0, 0.0, 1.0} // o
24 |
25 | right := []float64{
26 | 1.0, 1.0, 1.0, // ooo
27 | 1.0, 0.0, 0.0} // o
28 |
29 | // Prepare the input data as a 2D slice
30 | inputData := [][]float64{
31 | up,
32 | down,
33 | left,
34 | right,
35 | }
36 |
37 | // Target scores for: up, down, left, right
38 | correctResultsForUp := []float64{1.0, -1.0, -1.0, -1.0}
39 |
40 | // Prepare a neural network configuration struct
41 | config := &wann.Config{
42 | InitialConnectionRatio: 0.2,
43 | Generations: 2000,
44 | PopulationSize: 500,
45 | Verbose: true,
46 | }
47 |
48 | // Evolve a network, using the input data and the sought after results
49 | trainedNetwork, err := config.Evolve(inputData, correctResultsForUp)
50 | if err != nil {
51 | fmt.Fprintf(os.Stderr, "error: %s\n", err)
52 | os.Exit(1)
53 | }
54 |
55 | // Now to test the trained network on 4 different inputs and see if it passes the test
56 | upScore := trainedNetwork.Evaluate(up)
57 | downScore := trainedNetwork.Evaluate(down)
58 | leftScore := trainedNetwork.Evaluate(left)
59 | rightScore := trainedNetwork.Evaluate(right)
60 |
61 | if config.Verbose {
62 | if upScore > downScore && upScore > leftScore && upScore > rightScore {
63 | fmt.Println("Network training complete, the results are good.")
64 | } else {
65 | fmt.Println("Network training complete, but the results did not pass the test.")
66 | }
67 | }
68 |
69 | // Save the trained network as an SVG image
70 | if config.Verbose {
71 | fmt.Print("Writing network.svg...")
72 | }
73 | if err := trainedNetwork.WriteSVG("network.svg"); err != nil {
74 | fmt.Fprintf(os.Stderr, "error: %s\n", err)
75 | os.Exit(1)
76 | }
77 | if config.Verbose {
78 | fmt.Println("ok")
79 | }
80 | }
81 |
--------------------------------------------------------------------------------
/vendor/github.com/dave/jennifer/jen/custom.go:
--------------------------------------------------------------------------------
1 | package jen
2 |
// Options specifies options for the Custom method.
// Each field is copied onto the Group that Custom/CustomFunc creates.
type Options struct {
	// Open is copied to the group's open token.
	Open string
	// Close is copied to the group's close token.
	Close string
	// Separator is copied to the group's separator token.
	Separator string
	// Multi is copied to the group's multi flag (multi-line rendering).
	Multi bool
}
10 |
11 | // Custom renders a customized statement list. Pass in options to specify multi-line, and tokens for open, close, separator.
12 | func Custom(options Options, statements ...Code) *Statement {
13 | return newStatement().Custom(options, statements...)
14 | }
15 |
16 | // Custom renders a customized statement list. Pass in options to specify multi-line, and tokens for open, close, separator.
17 | func (g *Group) Custom(options Options, statements ...Code) *Statement {
18 | s := Custom(options, statements...)
19 | g.items = append(g.items, s)
20 | return s
21 | }
22 |
23 | // Custom renders a customized statement list. Pass in options to specify multi-line, and tokens for open, close, separator.
24 | func (s *Statement) Custom(options Options, statements ...Code) *Statement {
25 | g := &Group{
26 | close: options.Close,
27 | items: statements,
28 | multi: options.Multi,
29 | name: "custom",
30 | open: options.Open,
31 | separator: options.Separator,
32 | }
33 | *s = append(*s, g)
34 | return s
35 | }
36 |
37 | // CustomFunc renders a customized statement list. Pass in options to specify multi-line, and tokens for open, close, separator.
38 | func CustomFunc(options Options, f func(*Group)) *Statement {
39 | return newStatement().CustomFunc(options, f)
40 | }
41 |
42 | // CustomFunc renders a customized statement list. Pass in options to specify multi-line, and tokens for open, close, separator.
43 | func (g *Group) CustomFunc(options Options, f func(*Group)) *Statement {
44 | s := CustomFunc(options, f)
45 | g.items = append(g.items, s)
46 | return s
47 | }
48 |
49 | // CustomFunc renders a customized statement list. Pass in options to specify multi-line, and tokens for open, close, separator.
50 | func (s *Statement) CustomFunc(options Options, f func(*Group)) *Statement {
51 | g := &Group{
52 | close: options.Close,
53 | multi: options.Multi,
54 | name: "custom",
55 | open: options.Open,
56 | separator: options.Separator,
57 | }
58 | f(g)
59 | *s = append(*s, g)
60 | return s
61 | }
62 |
--------------------------------------------------------------------------------
/cmd/statement/main.go:
--------------------------------------------------------------------------------
1 | package main
2 |
3 | import (
4 | "fmt"
5 | "os"
6 |
7 | "github.com/xyproto/wann"
8 | )
9 |
10 | func main() {
11 | // Here are four shapes, representing: up, down, left and right:
12 |
13 | up := []float64{
14 | 0.0, 1.0, 0.0, // o
15 | 1.0, 1.0, 1.0} // ooo
16 |
17 | down := []float64{
18 | 1.0, 1.0, 1.0, // ooo
19 | 0.0, 1.0, 0.0} // o
20 |
21 | left := []float64{
22 | 1.0, 1.0, 1.0, // ooo
23 | 0.0, 0.0, 1.0} // o
24 |
25 | right := []float64{
26 | 1.0, 1.0, 1.0, // ooo
27 | 1.0, 0.0, 0.0} // o
28 |
29 | // Prepare the input data as a 2D slice
30 | inputData := [][]float64{
31 | up,
32 | down,
33 | left,
34 | right,
35 | }
36 |
37 | // Which of the elements in the input data are we trying to identify?
38 | correctResultsForUp := []float64{0.0, -1.0, -1.0, -1.0}
39 |
40 | // Prepare a neural network configuration struct
41 | config := &wann.Config{
42 | InitialConnectionRatio: 0.2,
43 | Generations: 2000,
44 | PopulationSize: 500,
45 | Verbose: true,
46 | }
47 |
48 | // Evolve a network, using the input data and the sought after results
49 | trainedNetwork, err := config.Evolve(inputData, correctResultsForUp)
50 | if err != nil {
51 | fmt.Fprintf(os.Stderr, "error: %s\n", err)
52 | os.Exit(1)
53 | }
54 |
55 | // Now to test the trained network on 4 different inputs and see if it passes the test
56 | upScore := trainedNetwork.Evaluate(up)
57 | downScore := trainedNetwork.Evaluate(down)
58 | leftScore := trainedNetwork.Evaluate(left)
59 | rightScore := trainedNetwork.Evaluate(right)
60 |
61 | if config.Verbose {
62 | if upScore > downScore && upScore > leftScore && upScore > rightScore {
63 | fmt.Println("Network training complete, the results are good.")
64 | } else {
65 | fmt.Println("Network training complete, but the results did not pass the test.")
66 | }
67 | }
68 |
69 | // Save the trained network as an SVG image
70 | if config.Verbose {
71 | fmt.Print("Writing network.svg...")
72 | }
73 | if err := trainedNetwork.WriteSVG("network.svg"); err != nil {
74 | fmt.Fprintf(os.Stderr, "error: %s\n", err)
75 | os.Exit(1)
76 | }
77 | if config.Verbose {
78 | fmt.Println("ok")
79 | }
80 |
81 | fmt.Println("Statement for this network (a work in progress, not correct yet):")
82 |
83 | networkStatement, err := trainedNetwork.StatementWithInputDataVariables()
84 | if err != nil {
85 | fmt.Fprintf(os.Stderr, "error: %s\n", err)
86 | os.Exit(1)
87 | }
88 | fmt.Println(wann.Render(networkStatement))
89 | }
90 |
--------------------------------------------------------------------------------
/img/before.svg:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/vendor/github.com/dave/jennifer/jen/statement.go:
--------------------------------------------------------------------------------
1 | package jen
2 |
3 | import (
4 | "bytes"
5 | "fmt"
6 | "go/format"
7 | "io"
8 | )
9 |
// Statement represents a simple list of code items. When rendered the items
// are separated by spaces.
//
// Items that are nil or report themselves as null are skipped entirely when
// rendering, so they do not produce a separator either.
type Statement []Code
13 |
14 | func newStatement() *Statement {
15 | return &Statement{}
16 | }
17 |
18 | // Clone makes a copy of the Statement, so further tokens can be appended
19 | // without affecting the original.
20 | func (s *Statement) Clone() *Statement {
21 | return &Statement{s}
22 | }
23 |
24 | func (s *Statement) previous(c Code) Code {
25 | index := -1
26 | for i, item := range *s {
27 | if item == c {
28 | index = i
29 | break
30 | }
31 | }
32 | if index > 0 {
33 | return (*s)[index-1]
34 | }
35 | return nil
36 | }
37 |
38 | func (s *Statement) isNull(f *File) bool {
39 | if s == nil {
40 | return true
41 | }
42 | for _, c := range *s {
43 | if !c.isNull(f) {
44 | return false
45 | }
46 | }
47 | return true
48 | }
49 |
50 | func (s *Statement) render(f *File, w io.Writer, _ *Statement) error {
51 | first := true
52 | for _, code := range *s {
53 | if code == nil || code.isNull(f) {
54 | // Null() token produces no output but also
55 | // no separator. Empty() token products no
56 | // output but adds a separator.
57 | continue
58 | }
59 | if !first {
60 | if _, err := w.Write([]byte(" ")); err != nil {
61 | return err
62 | }
63 | }
64 | if err := code.render(f, w, s); err != nil {
65 | return err
66 | }
67 | first = false
68 | }
69 | return nil
70 | }
71 |
72 | // Render renders the Statement to the provided writer.
73 | func (s *Statement) Render(writer io.Writer) error {
74 | return s.RenderWithFile(writer, NewFile(""))
75 | }
76 |
77 | // GoString renders the Statement for testing. Any error will cause a panic.
78 | func (s *Statement) GoString() string {
79 | buf := bytes.Buffer{}
80 | if err := s.Render(&buf); err != nil {
81 | panic(err)
82 | }
83 | return buf.String()
84 | }
85 |
86 | // RenderWithFile renders the Statement to the provided writer, using imports from the provided file.
87 | func (s *Statement) RenderWithFile(writer io.Writer, file *File) error {
88 | buf := &bytes.Buffer{}
89 | if err := s.render(file, buf, nil); err != nil {
90 | return err
91 | }
92 | b, err := format.Source(buf.Bytes())
93 | if err != nil {
94 | return fmt.Errorf("Error %s while formatting source:\n%s", err, buf.String())
95 | }
96 | if _, err := writer.Write(b); err != nil {
97 | return err
98 | }
99 | return nil
100 | }
101 |
102 |
--------------------------------------------------------------------------------
/vendor/github.com/xyproto/swish/README.md:
--------------------------------------------------------------------------------
1 | # Swish
2 |
3 | [Build Status](https://travis-ci.org/xyproto/swish) [Go Report Card](https://goreportcard.com/report/github.com/xyproto/swish) [GoDoc](https://godoc.org/github.com/xyproto/swish)
4 |
5 | An optimized Swish activation function ([Ramachandran, Zoph and Le, 2017](https://arxiv.org/abs/1710.05941)), for neural networks.
6 |
7 | ## Screenshots
8 |
9 | 
10 |
11 | 
12 |
13 | The graphs above were drawn using the program in `cmd/graph`, which uses [goterm](https://github.com/buger/goterm).
14 |
15 | ## Benchmark Results
16 |
17 | ### Using a `Swish` function that uses `math.Exp`
18 |
19 | First run:
20 |
21 | ```
22 | goos: linux
23 | goarch: amd64
24 | pkg: github.com/xyproto/swish
25 | BenchmarkSwish07-8 200000000 8.93 ns/op
26 | BenchmarkSwish03-8 200000000 8.95 ns/op
27 | PASS
28 | ok github.com/xyproto/swish 5.391s
29 | ```
30 |
31 | ### Using the optimized `Swish` function that uses `exp256`
32 |
33 | ```
34 | goos: linux
35 | goarch: amd64
36 | pkg: github.com/xyproto/swish
37 | BenchmarkSwish07-8 2000000000 0.26 ns/op
38 | BenchmarkSwish03-8 2000000000 0.26 ns/op
39 | PASS
40 | ok github.com/xyproto/swish 1.108s
41 | ```
42 |
43 | The optimized `Swish` function is **34x** faster than the one that uses `math.Exp`, and quite a bit faster than my (apparently bad) attempt at a hand-written assembly version.
44 |
45 | The average error (difference in output value) between the optimized and non-optimized versions is `+-0.0013` and the maximum error is `+-0.0024`. This is for `x` in the range `[5,3]`. See the program in `cmd/precision` for how this was calculated.
46 |
47 | ```
48 | 0.00015
49 | 0.00001
50 | goos: linux
51 | goarch: amd64
52 | pkg: github.com/xyproto/swish
53 | BenchmarkSwishAssembly07-8 500000000 3.63 ns/op
54 | BenchmarkSwishAssembly03-8 500000000 3.65 ns/op
55 | BenchmarkSwish07-8 2000000000 0.30 ns/op
56 | BenchmarkSwish03-8 2000000000 0.26 ns/op
57 | BenchmarkSwishPrecise07-8 200000000 9.07 ns/op
58 | BenchmarkSwishPrecise03-8 200000000 9.25 ns/op
59 | PASS
60 | ok github.com/xyproto/swish 11.100s
61 | ```
62 |
63 | I have no idea why the assembly version is so slow, but `0.26 ns/op` isn't bad for a non-hand-optimized version.
64 |
65 | ## General info
66 |
67 | * Version: 1.3.0
68 | * License: MIT
69 | * Author: Alexander F. Rødseth <xyproto@archlinux.org>
70 |
--------------------------------------------------------------------------------
/img/after.svg:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/vendor/github.com/dave/jennifer/jen/comments.go:
--------------------------------------------------------------------------------
1 | package jen
2 |
3 | import (
4 | "fmt"
5 | "io"
6 | "strings"
7 | )
8 |
9 | // Comment adds a comment. If the provided string contains a newline, the
10 | // comment is formatted in multiline style. If the comment string starts
11 | // with "//" or "/*", the automatic formatting is disabled and the string is
12 | // rendered directly.
13 | func Comment(str string) *Statement {
14 | return newStatement().Comment(str)
15 | }
16 |
17 | // Comment adds a comment. If the provided string contains a newline, the
18 | // comment is formatted in multiline style. If the comment string starts
19 | // with "//" or "/*", the automatic formatting is disabled and the string is
20 | // rendered directly.
21 | func (g *Group) Comment(str string) *Statement {
22 | s := Comment(str)
23 | g.items = append(g.items, s)
24 | return s
25 | }
26 |
27 | // Comment adds a comment. If the provided string contains a newline, the
28 | // comment is formatted in multiline style. If the comment string starts
29 | // with "//" or "/*", the automatic formatting is disabled and the string is
30 | // rendered directly.
31 | func (s *Statement) Comment(str string) *Statement {
32 | c := comment{
33 | comment: str,
34 | }
35 | *s = append(*s, c)
36 | return s
37 | }
38 |
39 | // Commentf adds a comment, using a format string and a list of parameters. If
40 | // the provided string contains a newline, the comment is formatted in
41 | // multiline style. If the comment string starts with "//" or "/*", the
42 | // automatic formatting is disabled and the string is rendered directly.
43 | func Commentf(format string, a ...interface{}) *Statement {
44 | return newStatement().Commentf(format, a...)
45 | }
46 |
47 | // Commentf adds a comment, using a format string and a list of parameters. If
48 | // the provided string contains a newline, the comment is formatted in
49 | // multiline style. If the comment string starts with "//" or "/*", the
50 | // automatic formatting is disabled and the string is rendered directly.
51 | func (g *Group) Commentf(format string, a ...interface{}) *Statement {
52 | s := Commentf(format, a...)
53 | g.items = append(g.items, s)
54 | return s
55 | }
56 |
57 | // Commentf adds a comment, using a format string and a list of parameters. If
58 | // the provided string contains a newline, the comment is formatted in
59 | // multiline style. If the comment string starts with "//" or "/*", the
60 | // automatic formatting is disabled and the string is rendered directly.
61 | func (s *Statement) Commentf(format string, a ...interface{}) *Statement {
62 | c := comment{
63 | comment: fmt.Sprintf(format, a...),
64 | }
65 | *s = append(*s, c)
66 | return s
67 | }
68 |
// comment is the Code item created by Comment/Commentf. It holds the raw
// comment text; the comment markers are added when it is rendered.
type comment struct {
	comment string
}
72 |
// isNull always reports false: a comment is always rendered, even when its
// text is empty.
func (c comment) isNull(f *File) bool {
	return false
}
76 |
77 | func (c comment) render(f *File, w io.Writer, s *Statement) error {
78 | if strings.HasPrefix(c.comment, "//") || strings.HasPrefix(c.comment, "/*") {
79 | // automatic formatting disabled.
80 | if _, err := w.Write([]byte(c.comment)); err != nil {
81 | return err
82 | }
83 | return nil
84 | }
85 | if strings.Contains(c.comment, "\n") {
86 | if _, err := w.Write([]byte("/*\n")); err != nil {
87 | return err
88 | }
89 | } else {
90 | if _, err := w.Write([]byte("// ")); err != nil {
91 | return err
92 | }
93 | }
94 | if _, err := w.Write([]byte(c.comment)); err != nil {
95 | return err
96 | }
97 | if strings.Contains(c.comment, "\n") {
98 | if !strings.HasSuffix(c.comment, "\n") {
99 | if _, err := w.Write([]byte("\n")); err != nil {
100 | return err
101 | }
102 | }
103 | if _, err := w.Write([]byte("*/")); err != nil {
104 | return err
105 | }
106 | }
107 | return nil
108 | }
109 |
--------------------------------------------------------------------------------
/neuron_test.go:
--------------------------------------------------------------------------------
1 | package wann
2 |
3 | import (
4 | "fmt"
5 | "math"
6 | "math/rand"
7 | "testing"
8 | )
9 |
10 | func TestNeuron(t *testing.T) {
11 | rand.Seed(commonSeed)
12 | net := NewNetwork()
13 | n, _ := net.NewBlankNeuron()
14 | n.ActivationFunction = Swish
15 | result := n.GetActivationFunction()(0.5)
16 | diff := math.Abs(result - 0.311287)
17 | if diff > 0.00001 { // 0.0000001 {
18 | t.Errorf("default swish activation function, expected a number close to 0.311287, got %f:", result)
19 | }
20 |
21 | fmt.Printf("Neurons in network: %d\n", len(net.AllNodes))
22 | }
23 |
24 | func TestString(t *testing.T) {
25 | rand.Seed(commonSeed)
26 | net := NewNetwork()
27 | n, _ := net.NewBlankNeuron()
28 | _ = n.String()
29 | }
30 |
31 | func TestHasInput(t *testing.T) {
32 | rand.Seed(commonSeed)
33 | net := NewNetwork() // 0
34 | a, _ := net.NewBlankNeuron() // 1
35 | b, _ := net.NewBlankNeuron() // 2
36 | fmt.Println("a is 1?", a)
37 | fmt.Println("b is 2?", b)
38 | a.AddInput(0)
39 | if !a.HasInput(0) {
40 | t.Errorf("a should have b as an input")
41 | }
42 | if b.HasInput(0) {
43 | t.Errorf("b should not have a as an input")
44 | }
45 | }
46 |
47 | func TestFindInput(t *testing.T) {
48 | rand.Seed(commonSeed)
49 | net := NewNetwork()
50 |
51 | a, _ := net.NewBlankNeuron() // a, 1
52 | _, bi := net.NewBlankNeuron() // b, 2
53 | c, ci := net.NewBlankNeuron() // c, 3
54 | _, di := net.NewBlankNeuron() // d, 4
55 |
56 | a.AddInput(bi) // b
57 | a.AddInputNeuron(c) // c
58 |
59 | if _, found := a.FindInput(di); found {
60 | t.Errorf("a should not have d as an input")
61 | }
62 | if pos, found := a.FindInput(bi); !found {
63 | t.Errorf("a should have b as an input")
64 | } else if found && pos != 0 {
65 | t.Errorf("a should have b as an input at position 0")
66 | }
67 | if pos, found := a.FindInput(ci); !found {
68 | t.Errorf("a should have c as an input")
69 | } else if found && pos != 1 {
70 | t.Errorf("a should have c as an input at position 1")
71 | }
72 | }
73 |
74 | func TestRemoveInput(t *testing.T) {
75 | rand.Seed(commonSeed)
76 | net := NewNetwork(&Config{
77 | inputs: 5,
78 | InitialConnectionRatio: 0.5,
79 | sharedWeight: 0.5,
80 | })
81 |
82 | a, _ := net.NewBlankNeuron() // 0
83 | a.AddInput(1)
84 | a.AddInput(2)
85 | if a.RemoveInput(1) != nil {
86 | t.Errorf("could not remove input b from a")
87 | }
88 | if a.RemoveInput(2) != nil {
89 | t.Errorf("could not remove input c from a")
90 | }
91 | if a.HasInput(1) {
92 | t.Errorf("a should not have b as an input")
93 | }
94 | if a.HasInput(2) {
95 | t.Errorf("a should not have c as an input")
96 | }
97 | }
98 |
99 | // func (neuron *Neuron) RemoveInput(e *Neuron) error {
100 |
101 | func TestEvaluate(t *testing.T) {
102 | rand.Seed(commonSeed)
103 | net := NewNetwork(&Config{
104 | inputs: 7,
105 | InitialConnectionRatio: 0.5,
106 | sharedWeight: 0.5,
107 | })
108 |
109 | // Set a few activation functions
110 | net.AllNodes[net.InputNodes[0]].ActivationFunction = Linear
111 | net.AllNodes[net.InputNodes[1]].ActivationFunction = Swish
112 | net.AllNodes[net.InputNodes[2]].ActivationFunction = Gauss
113 | net.AllNodes[net.InputNodes[3]].ActivationFunction = Sigmoid
114 | net.AllNodes[net.InputNodes[4]].ActivationFunction = ReLU
115 | net.AllNodes[net.InputNodes[5]].ActivationFunction = Step
116 | net.AllNodes[net.InputNodes[6]].ActivationFunction = Inv
117 |
118 | result := net.Evaluate([]float64{0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5})
119 | fmt.Println(result)
120 | }
121 |
122 | func TestIn(t *testing.T) {
123 | rand.Seed(commonSeed)
124 | net := NewNetwork()
125 | n, ni := net.NewNeuron()
126 | if ni != 1 {
127 | t.Fail()
128 | }
129 | outputNeuronIndex := NeuronIndex(0)
130 | if !n.In([]NeuronIndex{outputNeuronIndex, 1}) {
131 | t.Fail()
132 | }
133 | }
134 |
--------------------------------------------------------------------------------
/statement_test.go:
--------------------------------------------------------------------------------
1 | package wann
2 |
3 | import (
4 | "fmt"
5 | "math/rand"
6 | "testing"
7 | )
8 |
9 | // ExampleNetwork_StatementWithInputValues
10 | func TestNetworkStatementWithInputValues(t *testing.T) {
11 | rand.Seed(1)
12 | net := NewNetwork(&Config{
13 | inputs: 6,
14 | InitialConnectionRatio: 0.7,
15 | sharedWeight: 0.5,
16 | })
17 |
18 | net.SetInputValues([]float64{0.0, 0.1, 0.2, 0.3, 0.4, 0.5})
19 | statement, err := net.StatementWithInputValues()
20 | if err != nil {
21 | panic(err)
22 | }
23 | fmt.Println(Render(statement))
24 | }
25 |
26 | // ExampleNetwork_StatementWithInputDataVariables
27 | func TestNetwork_StatementWithInputDataVariables(t *testing.T) {
28 | rand.Seed(1)
29 | net := NewNetwork(&Config{
30 | inputs: 6,
31 | InitialConnectionRatio: 0.7,
32 | sharedWeight: 0.5,
33 | })
34 |
35 | // 1.234 should not appear in the output statement
36 | net.SetInputValues([]float64{1.234, 1.234, 1.234, 1.234, 1.234, 1.234})
37 |
38 | statement, err := net.StatementWithInputDataVariables()
39 | if err != nil {
40 | panic(err)
41 | }
42 | fmt.Println(Render(statement))
43 | }
44 |
45 | // ExampleNeuron_InputStatement
46 | func ExampleNeuron_InputStatement() {
47 | rand.Seed(1)
48 | net := NewNetwork(&Config{
49 | inputs: 6,
50 | InitialConnectionRatio: 0.7,
51 | sharedWeight: 0.5,
52 | })
53 |
54 | // 1.234 should not appear in the output statement
55 | net.SetInputValues([]float64{1.234, 1.234, 1.234, 1.234, 1.234, 1.234})
56 |
57 | inputStatement2, err := net.AllNodes[net.InputNodes[2]].InputStatement()
58 | if err != nil {
59 | panic(err)
60 | }
61 | fmt.Println(Render(inputStatement2))
62 | // Output:
63 | // inputData[2]
64 | }
65 |
66 | func ExampleNetwork_OutputNodeStatementX_first() {
67 | // First create a network with only one output node, that has a step function
68 | net := NewNetwork()
69 | net.AllNodes[net.OutputNode].ActivationFunction = Step
70 |
71 | fmt.Println(net.OutputNodeStatementX("score"))
72 |
73 | // Output:
74 | // score := func(s float64) float64 {
75 | // if s >= 0 {
76 | // return 1
77 | // } else {
78 | // return 0
79 | // }
80 | // }(x)
81 | }
82 |
83 | func ExampleNetwork_OutputNodeStatementX_second() {
84 | // Then create a network with an input node that has a sigmoid function and an output node that has an invert function
85 | net := NewNetwork()
86 | net.NewInputNode(Sigmoid, true)
87 | net.AllNodes[net.OutputNode].ActivationFunction = Inv
88 |
89 | // Output a Go expression for this network, using the given input variable names
90 | fmt.Println(net.OutputNodeStatementX("score"))
91 |
92 | // Output:
93 | // score := -(x)
94 |
95 | }
96 |
97 | func ExampleNetwork_OutputNodeStatementX_third() {
98 | rand.Seed(999)
99 | net := NewNetwork(&Config{
100 | inputs: 1,
101 | InitialConnectionRatio: 0.7,
102 | sharedWeight: 0.5,
103 | })
104 | fmt.Println(net.OutputNodeStatementX("score"))
105 |
106 | // Output:
107 | // score := math.Exp(-(math.Pow(x, 2.0)) / 2.0)
108 | }
109 | func ExampleNetwork_OutputNodeStatementX_fourth() {
110 |
111 | rand.Seed(1111113)
112 | net := NewNetwork(&Config{
113 | inputs: 5,
114 | InitialConnectionRatio: 0.7,
115 | sharedWeight: 0.5,
116 | })
117 | fmt.Println(net.OutputNodeStatementX("score"))
118 |
119 | // Output:
120 | // score := func(r float64) float64 {
121 | // if r >= 0 {
122 | // return r
123 | // } else {
124 | // return 0
125 | // }
126 | // }(x)
127 | }
128 |
129 | func ExampleNetwork_OutputNodeStatementX() {
130 | rand.Seed(1)
131 | net := NewNetwork(&Config{
132 | inputs: 5,
133 | InitialConnectionRatio: 0.7,
134 | sharedWeight: 0.5,
135 | })
136 | fmt.Println(net.OutputNodeStatementX("f"))
137 | //fmt.Println(net.Score())
138 |
139 | // Output:
140 | // f := math.Pow(x, 2.0)
141 | }
142 |
--------------------------------------------------------------------------------
/go.sum:
--------------------------------------------------------------------------------
1 | github.com/buger/goterm v0.0.0-20181115115552-c206103e1f37/go.mod h1:u9UyCz2eTrSGy6fbupqJ54eY5c4IC8gREQ1053dK12U=
2 | github.com/dave/astrid v0.0.0-20170323122508-8c2895878b14/go.mod h1:Sth2QfxfATb/nW4EsrSi2KyJmbcniZ8TgTaji17D6ms=
3 | github.com/dave/brenda v1.1.0/go.mod h1:4wCUr6gSlu5/1Tk7akE5X7UorwiQ8Rij0SKH3/BGMOM=
4 | github.com/dave/courtney v0.3.0/go.mod h1:BAv3hA06AYfNUjfjQr+5gc6vxeBVOupLqrColj+QSD8=
5 | github.com/dave/gopackages v0.0.0-20170318123100-46e7023ec56e/go.mod h1:i00+b/gKdIDIxuLDFob7ustLAVqhsZRk2qVZrArELGQ=
6 | github.com/dave/jennifer v1.5.0 h1:HmgPN93bVDpkQyYbqhCHj5QlgvUkvEOzMyEvKLgCRrg=
7 | github.com/dave/jennifer v1.5.0/go.mod h1:4MnyiFIlZS3l5tSDn8VnzE6ffAhYBMB2SZntBsZGUok=
8 | github.com/dave/kerr v0.0.0-20170318121727-bc25dd6abe8e/go.mod h1:qZqlPyPvfsDJt+3wHJ1EvSXDuVjFTK0j2p/ca+gtsb8=
9 | github.com/dave/patsy v0.0.0-20210517141501-957256f50cba/go.mod h1:qfR88CgEGLoiqDaE+xxDCi5QA5v4vUoW0UCX2Nd5Tlc=
10 | github.com/dave/rebecca v0.9.1/go.mod h1:N6XYdMD/OKw3lkF3ywh8Z6wPGuwNFDNtWYEMFWEmXBA=
11 | github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
12 | github.com/xyproto/af v0.0.0-20191018214415-1a8887381bd3 h1:QN9wTN6qa12WnnYyQ9xlUvyU1W9V58u3SWgwNKp5RpA=
13 | github.com/xyproto/af v0.0.0-20191018214415-1a8887381bd3/go.mod h1:xYoWKJKu61/NJ4/6PV6mJuNu5QymmrbJmmI5mNOW+fI=
14 | github.com/xyproto/swish v1.3.0 h1:qiVl2UeqkMqDJKjRLVwRprD7Vzz2pbNRcWTb3s0TpjE=
15 | github.com/xyproto/swish v1.3.0/go.mod h1:IVhz2R80pNsPaSxbEcqR84i4MVL0b06HN7KIMhv38WA=
16 | github.com/xyproto/tinysvg v1.0.1 h1:UwuNaudh95O0BzdBJ4zQ0lTft5a4ToHKfo6xvKb3it0=
17 | github.com/xyproto/tinysvg v1.0.1/go.mod h1:DKgmaYuFIvJab9ug4nH4ZG356VtUaKXG2mUU07GIurs=
18 | github.com/yuin/goldmark v1.4.1/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k=
19 | golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
20 | golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
21 | golang.org/x/mod v0.5.1/go.mod h1:5OXOZSfqPIIbmVBIIKWRFfZjPR0E5r58TLhUjH0a2Ro=
22 | golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
23 | golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
24 | golang.org/x/net v0.0.0-20211015210444-4f30a5c0130f/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
25 | golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
26 | golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
27 | golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
28 | golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
29 | golang.org/x/sys v0.0.0-20190613101156-ab3f67ed278a/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
30 | golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
31 | golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
32 | golang.org/x/sys v0.0.0-20211019181941-9d821ace8654/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
33 | golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
34 | golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
35 | golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
36 | golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
37 | golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
38 | golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
39 | golang.org/x/tools v0.1.8/go.mod h1:nABZi5QlRsZVlzPpHl034qft6wpY4eDcsTt5AaioBiU=
40 | golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
41 | golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
42 | golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
43 |
--------------------------------------------------------------------------------
/vendor/github.com/dave/jennifer/jen/group.go:
--------------------------------------------------------------------------------
1 | package jen
2 |
3 | import (
4 | "bytes"
5 | "fmt"
6 | "go/format"
7 | "io"
8 | )
9 |
10 | // Group represents a list of Code items, separated by tokens with an optional
11 | // open and close token.
12 | type Group struct {
13 | name string
14 | items []Code
15 | open string
16 | close string
17 | separator string
18 | multi bool
19 | }
20 |
21 | func (g *Group) isNull(f *File) bool {
22 | if g == nil {
23 | return true
24 | }
25 | if g.open != "" || g.close != "" {
26 | return false
27 | }
28 | for _, c := range g.items {
29 | if !c.isNull(f) {
30 | return false
31 | }
32 | }
33 | return true
34 | }
35 |
36 | func (g *Group) render(f *File, w io.Writer, s *Statement) error {
37 | if g.name == "block" && s != nil {
38 | // Special CaseBlock format for then the previous item in the statement
39 | // is a Case group or the default keyword.
40 | prev := s.previous(g)
41 | grp, isGrp := prev.(*Group)
42 | tkn, isTkn := prev.(token)
43 | if isGrp && grp.name == "case" || isTkn && tkn.content == "default" {
44 | g.open = ""
45 | g.close = ""
46 | }
47 | }
48 | if g.open != "" {
49 | if _, err := w.Write([]byte(g.open)); err != nil {
50 | return err
51 | }
52 | }
53 | isNull, err := g.renderItems(f, w)
54 | if err != nil {
55 | return err
56 | }
57 | if !isNull && g.multi && g.close != "" {
58 | // For multi-line blocks with a closing token, we insert a new line after the last item (but
59 | // not if all items were null). This is to ensure that if the statement finishes with a comment,
60 | // the closing token is not commented out.
61 | s := "\n"
62 | if g.separator == "," {
63 | // We also insert add trailing comma if the separator was ",".
64 | s = ",\n"
65 | }
66 | if _, err := w.Write([]byte(s)); err != nil {
67 | return err
68 | }
69 | }
70 | if g.close != "" {
71 | if _, err := w.Write([]byte(g.close)); err != nil {
72 | return err
73 | }
74 | }
75 | return nil
76 | }
77 |
78 | func (g *Group) renderItems(f *File, w io.Writer) (isNull bool, err error) {
79 | first := true
80 | for _, code := range g.items {
81 | if pt, ok := code.(token); ok && pt.typ == packageToken {
82 | // Special case for package tokens in Qual groups - for dot-imports, the package token
83 | // will be null, so will not render and will not be registered in the imports block.
84 | // This ensures all packageTokens that are rendered are registered.
85 | f.register(pt.content.(string))
86 | }
87 | if code == nil || code.isNull(f) {
88 | // Null() token produces no output but also
89 | // no separator. Empty() token products no
90 | // output but adds a separator.
91 | continue
92 | }
93 | if g.name == "values" {
94 | if _, ok := code.(Dict); ok && len(g.items) > 1 {
95 | panic("Error in Values: if Dict is used, must be one item only")
96 | }
97 | }
98 | if !first && g.separator != "" {
99 | // The separator token is added before each non-null item, but not before the first item.
100 | if _, err := w.Write([]byte(g.separator)); err != nil {
101 | return false, err
102 | }
103 | }
104 | if g.multi {
105 | // For multi-line blocks, we insert a new line before each non-null item.
106 | if _, err := w.Write([]byte("\n")); err != nil {
107 | return false, err
108 | }
109 | }
110 | if err := code.render(f, w, nil); err != nil {
111 | return false, err
112 | }
113 | first = false
114 | }
115 | return first, nil
116 | }
117 |
118 | // Render renders the Group to the provided writer.
119 | func (g *Group) Render(writer io.Writer) error {
120 | return g.RenderWithFile(writer, NewFile(""))
121 | }
122 |
123 | // GoString renders the Group for testing. Any error will cause a panic.
124 | func (g *Group) GoString() string {
125 | buf := bytes.Buffer{}
126 | if err := g.Render(&buf); err != nil {
127 | panic(err)
128 | }
129 | return buf.String()
130 | }
131 |
132 | // RenderWithFile renders the Group to the provided writer, using imports from the provided file.
133 | func (g *Group) RenderWithFile(writer io.Writer, file *File) error {
134 | buf := &bytes.Buffer{}
135 | if err := g.render(file, buf, nil); err != nil {
136 | return err
137 | }
138 | b, err := format.Source(buf.Bytes())
139 | if err != nil {
140 | return fmt.Errorf("Error %s while formatting source:\n%s", err, buf.String())
141 | }
142 | if _, err := writer.Write(b); err != nil {
143 | return err
144 | }
145 | return nil
146 | }
147 |
148 |
--------------------------------------------------------------------------------
/vendor/github.com/dave/jennifer/jen/jen.go:
--------------------------------------------------------------------------------
1 | // Package jen is a code generator for Go
2 | package jen
3 |
4 | import (
5 | "bytes"
6 | "fmt"
7 | "go/format"
8 | "io"
9 | "io/ioutil"
10 | "sort"
11 | "strconv"
12 | )
13 |
14 | // Code represents an item of code that can be rendered.
15 | type Code interface {
16 | render(f *File, w io.Writer, s *Statement) error
17 | isNull(f *File) bool
18 | }
19 |
20 | // Save renders the file and saves to the filename provided.
21 | func (f *File) Save(filename string) error {
22 | // notest
23 | buf := &bytes.Buffer{}
24 | if err := f.Render(buf); err != nil {
25 | return err
26 | }
27 | if err := ioutil.WriteFile(filename, buf.Bytes(), 0644); err != nil {
28 | return err
29 | }
30 | return nil
31 | }
32 |
33 | // Render renders the file to the provided writer.
34 | func (f *File) Render(w io.Writer) error {
35 | body := &bytes.Buffer{}
36 | if err := f.render(f, body, nil); err != nil {
37 | return err
38 | }
39 | source := &bytes.Buffer{}
40 | if len(f.headers) > 0 {
41 | for _, c := range f.headers {
42 | if err := Comment(c).render(f, source, nil); err != nil {
43 | return err
44 | }
45 | if _, err := fmt.Fprint(source, "\n"); err != nil {
46 | return err
47 | }
48 | }
49 | // Append an extra newline so that header comments don't get lumped in
50 | // with package comments.
51 | if _, err := fmt.Fprint(source, "\n"); err != nil {
52 | return err
53 | }
54 | }
55 | for _, c := range f.comments {
56 | if err := Comment(c).render(f, source, nil); err != nil {
57 | return err
58 | }
59 | if _, err := fmt.Fprint(source, "\n"); err != nil {
60 | return err
61 | }
62 | }
63 | if _, err := fmt.Fprintf(source, "package %s", f.name); err != nil {
64 | return err
65 | }
66 | if f.CanonicalPath != "" {
67 | if _, err := fmt.Fprintf(source, " // import %q", f.CanonicalPath); err != nil {
68 | return err
69 | }
70 | }
71 | if _, err := fmt.Fprint(source, "\n\n"); err != nil {
72 | return err
73 | }
74 | if err := f.renderImports(source); err != nil {
75 | return err
76 | }
77 | if _, err := source.Write(body.Bytes()); err != nil {
78 | return err
79 | }
80 | formatted, err := format.Source(source.Bytes())
81 | if err != nil {
82 | return fmt.Errorf("Error %s while formatting source:\n%s", err, source.String())
83 | }
84 | if _, err := w.Write(formatted); err != nil {
85 | return err
86 | }
87 | return nil
88 | }
89 |
90 | func (f *File) renderImports(source io.Writer) error {
91 |
92 | // Render the "C" import if it's been used in a `Qual`, `Anon` or if there's a preamble comment
93 | hasCgo := f.imports["C"].name != "" || len(f.cgoPreamble) > 0
94 |
95 | // Only separate the import from the main imports block if there's a preamble
96 | separateCgo := hasCgo && len(f.cgoPreamble) > 0
97 |
98 | filtered := map[string]importdef{}
99 | for path, def := range f.imports {
100 | // filter out the "C" pseudo-package so it's not rendered in a block with the other
101 | // imports, but only if it is accompanied by a preamble comment
102 | if path == "C" && separateCgo {
103 | continue
104 | }
105 | filtered[path] = def
106 | }
107 |
108 | if len(filtered) == 1 {
109 | for path, def := range filtered {
110 | if def.alias && path != "C" {
111 | // "C" package should be rendered without alias even when used as an anonymous import
112 | // (e.g. should never have an underscore).
113 | if _, err := fmt.Fprintf(source, "import %s %s\n\n", def.name, strconv.Quote(path)); err != nil {
114 | return err
115 | }
116 | } else {
117 | if _, err := fmt.Fprintf(source, "import %s\n\n", strconv.Quote(path)); err != nil {
118 | return err
119 | }
120 | }
121 | }
122 | } else if len(filtered) > 1 {
123 | if _, err := fmt.Fprint(source, "import (\n"); err != nil {
124 | return err
125 | }
126 | // We must sort the imports to ensure repeatable
127 | // source.
128 | paths := []string{}
129 | for path := range filtered {
130 | paths = append(paths, path)
131 | }
132 | sort.Strings(paths)
133 | for _, path := range paths {
134 | def := filtered[path]
135 | if def.alias && path != "C" {
136 | // "C" package should be rendered without alias even when used as an anonymous import
137 | // (e.g. should never have an underscore).
138 | if _, err := fmt.Fprintf(source, "%s %s\n", def.name, strconv.Quote(path)); err != nil {
139 | return err
140 | }
141 |
142 | } else {
143 | if _, err := fmt.Fprintf(source, "%s\n", strconv.Quote(path)); err != nil {
144 | return err
145 | }
146 | }
147 | }
148 | if _, err := fmt.Fprint(source, ")\n\n"); err != nil {
149 | return err
150 | }
151 | }
152 |
153 | if separateCgo {
154 | for _, c := range f.cgoPreamble {
155 | if err := Comment(c).render(f, source, nil); err != nil {
156 | return err
157 | }
158 | if _, err := fmt.Fprint(source, "\n"); err != nil {
159 | return err
160 | }
161 | }
162 | if _, err := fmt.Fprint(source, "import \"C\"\n\n"); err != nil {
163 | return err
164 | }
165 | }
166 |
167 | return nil
168 | }
169 |
--------------------------------------------------------------------------------
/vendor/github.com/dave/jennifer/jen/lit.go:
--------------------------------------------------------------------------------
1 | package jen
2 |
3 | // Lit renders a literal. Lit supports only built-in types (bool, string, int, complex128, float64,
4 | // float32, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, uintptr and complex64).
5 | // Passing any other type will panic.
6 | func Lit(v interface{}) *Statement {
7 | return newStatement().Lit(v)
8 | }
9 |
10 | // Lit renders a literal. Lit supports only built-in types (bool, string, int, complex128, float64,
11 | // float32, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, uintptr and complex64).
12 | // Passing any other type will panic.
13 | func (g *Group) Lit(v interface{}) *Statement {
14 | s := Lit(v)
15 | g.items = append(g.items, s)
16 | return s
17 | }
18 |
19 | // Lit renders a literal. Lit supports only built-in types (bool, string, int, complex128, float64,
20 | // float32, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, uintptr and complex64).
21 | // Passing any other type will panic.
22 | func (s *Statement) Lit(v interface{}) *Statement {
23 | t := token{
24 | typ: literalToken,
25 | content: v,
26 | }
27 | *s = append(*s, t)
28 | return s
29 | }
30 |
31 | // LitFunc renders a literal. LitFunc generates the value to render by executing the provided
32 | // function. LitFunc supports only built-in types (bool, string, int, complex128, float64, float32,
33 | // int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, uintptr and complex64).
34 | // Returning any other type will panic.
35 | func LitFunc(f func() interface{}) *Statement {
36 | return newStatement().LitFunc(f)
37 | }
38 |
39 | // LitFunc renders a literal. LitFunc generates the value to render by executing the provided
40 | // function. LitFunc supports only built-in types (bool, string, int, complex128, float64, float32,
41 | // int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, uintptr and complex64).
42 | // Returning any other type will panic.
43 | func (g *Group) LitFunc(f func() interface{}) *Statement {
44 | s := LitFunc(f)
45 | g.items = append(g.items, s)
46 | return s
47 | }
48 |
49 | // LitFunc renders a literal. LitFunc generates the value to render by executing the provided
50 | // function. LitFunc supports only built-in types (bool, string, int, complex128, float64, float32,
51 | // int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, uintptr and complex64).
52 | // Returning any other type will panic.
53 | func (s *Statement) LitFunc(f func() interface{}) *Statement {
54 | t := token{
55 | typ: literalToken,
56 | content: f(),
57 | }
58 | *s = append(*s, t)
59 | return s
60 | }
61 |
62 | // LitRune renders a rune literal.
63 | func LitRune(v rune) *Statement {
64 | return newStatement().LitRune(v)
65 | }
66 |
67 | // LitRune renders a rune literal.
68 | func (g *Group) LitRune(v rune) *Statement {
69 | s := LitRune(v)
70 | g.items = append(g.items, s)
71 | return s
72 | }
73 |
74 | // LitRune renders a rune literal.
75 | func (s *Statement) LitRune(v rune) *Statement {
76 | t := token{
77 | typ: literalRuneToken,
78 | content: v,
79 | }
80 | *s = append(*s, t)
81 | return s
82 | }
83 |
84 | // LitRuneFunc renders a rune literal. LitRuneFunc generates the value to
85 | // render by executing the provided function.
86 | func LitRuneFunc(f func() rune) *Statement {
87 | return newStatement().LitRuneFunc(f)
88 | }
89 |
90 | // LitRuneFunc renders a rune literal. LitRuneFunc generates the value to
91 | // render by executing the provided function.
92 | func (g *Group) LitRuneFunc(f func() rune) *Statement {
93 | s := LitRuneFunc(f)
94 | g.items = append(g.items, s)
95 | return s
96 | }
97 |
98 | // LitRuneFunc renders a rune literal. LitRuneFunc generates the value to
99 | // render by executing the provided function.
100 | func (s *Statement) LitRuneFunc(f func() rune) *Statement {
101 | t := token{
102 | typ: literalRuneToken,
103 | content: f(),
104 | }
105 | *s = append(*s, t)
106 | return s
107 | }
108 |
109 | // LitByte renders a byte literal.
110 | func LitByte(v byte) *Statement {
111 | return newStatement().LitByte(v)
112 | }
113 |
114 | // LitByte renders a byte literal.
115 | func (g *Group) LitByte(v byte) *Statement {
116 | s := LitByte(v)
117 | g.items = append(g.items, s)
118 | return s
119 | }
120 |
121 | // LitByte renders a byte literal.
122 | func (s *Statement) LitByte(v byte) *Statement {
123 | t := token{
124 | typ: literalByteToken,
125 | content: v,
126 | }
127 | *s = append(*s, t)
128 | return s
129 | }
130 |
131 | // LitByteFunc renders a byte literal. LitByteFunc generates the value to
132 | // render by executing the provided function.
133 | func LitByteFunc(f func() byte) *Statement {
134 | return newStatement().LitByteFunc(f)
135 | }
136 |
137 | // LitByteFunc renders a byte literal. LitByteFunc generates the value to
138 | // render by executing the provided function.
139 | func (g *Group) LitByteFunc(f func() byte) *Statement {
140 | s := LitByteFunc(f)
141 | g.items = append(g.items, s)
142 | return s
143 | }
144 |
145 | // LitByteFunc renders a byte literal. LitByteFunc generates the value to
146 | // render by executing the provided function.
147 | func (s *Statement) LitByteFunc(f func() byte) *Statement {
148 | t := token{
149 | typ: literalByteToken,
150 | content: f(),
151 | }
152 | *s = append(*s, t)
153 | return s
154 | }
155 |
--------------------------------------------------------------------------------
/img/labels.svg:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | # wann [](https://goreportcard.com/report/github.com/xyproto/wann) [](https://godoc.org/github.com/xyproto/wann)
4 |
5 | Weight Agnostic Neural Networks are a new type of neural network where the weights of all the neurons are shared, and the structure of the network is what matters.
6 |
7 | This package implements Weight Agnostic Neural Networks for Go, and is inspired by this paper from June 2019:
8 |
9 | *"Weight Agnostic Neural Networks" by Adam Gaier and David Ha*. ([PDF](https://arxiv.org/pdf/1906.04358.pdf) | [Interactive version](https://weightagnostic.github.io/) | [Google AI blog post](https://ai.googleblog.com/2019/08/exploring-weight-agnostic-neural.html))
10 |
11 | ## Features and limitations
12 |
13 | * All activation functions are benchmarked at the start of the program and the results are taken into account when calculating the complexity of a network.
14 | * All networks can be translated to a Go statement, using the wonderful [jennifer](https://github.com/dave/jennifer) package (work in progress; there are a few kinks that need to be ironed out).
15 | * Networks can be saved as `SVG` diagrams. This feature needs more testing.
16 | * Neural networks can be trained and used. See the `cmd` folder for examples.
17 | * A random weight is chosen when training, instead of looping over the range of the weight. The paper describes both methods.
18 | * After the network has been trained, the optimal weight is found by looping over all weights (with a step size of `0.0001`).
19 | * Increased complexity counts negatively when evolving networks. This optimizes not only for less complex networks, but also for execution speed.
20 | * The diagram drawing routine plots the activation functions directly onto the nodes, together with a label. This can be saved as an SVG file.
21 |
22 | ## Example program
23 |
24 | This is a simple example of creating a network that can recognize one of four shapes:
25 |
26 | ```go
27 | package main
28 |
29 | import (
30 | "fmt"
31 | "os"
32 |
33 | "github.com/xyproto/wann"
34 | )
35 |
36 | func main() {
37 | // Here are four shapes, representing: up, down, left and right:
38 |
39 | up := []float64{
40 | 0.0, 1.0, 0.0, // o
41 | 1.0, 1.0, 1.0} // ooo
42 |
43 | down := []float64{
44 | 1.0, 1.0, 1.0, // ooo
45 | 0.0, 1.0, 0.0} // o
46 |
47 | left := []float64{
48 | 1.0, 1.0, 1.0, // ooo
49 | 0.0, 0.0, 1.0} // o
50 |
51 | right := []float64{
52 | 1.0, 1.0, 1.0, // ooo
53 | 1.0, 0.0, 0.0} // o
54 |
55 | // Prepare the input data as a 2D slice
56 | inputData := [][]float64{
57 | up,
58 | down,
59 | left,
60 | right,
61 | }
62 |
63 | // Target scores for: up, down, left, right
64 | correctResultsForUp := []float64{1.0, -1.0, -1.0, -1.0}
65 |
66 | // Prepare a neural network configuration struct
67 | config := &wann.Config{
68 | InitialConnectionRatio: 0.2,
69 | Generations: 2000,
70 | PopulationSize: 500,
71 | Verbose: true,
72 | }
73 |
74 | // Evolve a network, using the input data and the sought after results
75 | trainedNetwork, err := config.Evolve(inputData, correctResultsForUp)
76 | if err != nil {
77 | fmt.Fprintf(os.Stderr, "error: %s\n", err)
78 | os.Exit(1)
79 | }
80 |
81 | // Now to test the trained network on 4 different inputs and see if it passes the test
82 | upScore := trainedNetwork.Evaluate(up)
83 | downScore := trainedNetwork.Evaluate(down)
84 | leftScore := trainedNetwork.Evaluate(left)
85 | rightScore := trainedNetwork.Evaluate(right)
86 |
87 | if config.Verbose {
88 | if upScore > downScore && upScore > leftScore && upScore > rightScore {
89 | fmt.Println("Network training complete, the results are good.")
90 | } else {
91 | fmt.Println("Network training complete, but the results did not pass the test.")
92 | }
93 | }
94 |
95 | // Save the trained network as an SVG image
96 | if config.Verbose {
97 | fmt.Print("Writing network.svg...")
98 | }
99 | if err := trainedNetwork.WriteSVG("network.svg"); err != nil {
100 | fmt.Fprintf(os.Stderr, "error: %s\n", err)
101 | os.Exit(1)
102 | }
103 | if config.Verbose {
104 | fmt.Println("ok")
105 | }
106 | }
107 | ```
108 |
109 | Here is the resulting network generated by the above program:
110 |
111 |
112 |
113 | This makes sense, since taking the third number in the input data (index 2), running it through a swish function and then inverting it should be a usable detector for the `up` pattern.
114 |
115 | * The generated networks may differ for each run.
116 |
117 | ## Quick start
118 |
119 | This requires Go 1.11 or later.
120 |
121 | Clone the repository:
122 |
123 | git clone https://github.com/xyproto/wann
124 |
125 | Enter the `cmd/evolve` directory:
126 |
127 | cd wann/cmd/evolve
128 |
129 | Build and run the example:
130 |
131 | go build && ./evolve
132 |
133 | Take a look at the best network for judging whether a set of numbers, each either 0 or 1, belongs to one category:
134 |
135 | xdg-open network.svg
136 |
137 | (If needed, use your favorite SVG viewer instead of the `xdg-open` command).
138 |
139 | ## Ideas
140 |
141 | * Adding convolution nodes might give interesting results.
142 |
143 | ## Generating Go code from a trained network
144 |
145 | This is an experimental feature and a work in progress!
146 |
147 | The idea is to generate one large expression from all the expressions that each node in the network represents.
148 |
149 | Right now, this only works for networks that have a depth of 1.
150 |
151 | For example, adding these two lines to `cmd/evolve/main.go`:
152 |
153 | ```go
154 | // Output a Go function for this network
155 | fmt.Println(trainedNetwork.GoFunction())
156 | ```
157 |
158 | Produces this output:
159 |
160 | ```go
161 | func f(x float64) float64 { return -x }
162 | ```
163 |
164 | The plan is to output a function that takes the input data instead, and refers to the input data by index. Support for deeper networks also needs to be added.
165 |
166 | There is a complete example for outputting Go code in `cmd/gofunction`.
167 |
168 | ## General info
169 |
170 | * Version: 0.3.2
171 | * License: BSD-3
172 | * Author: Alexander F. Rødseth <xyproto@archlinux.org>
173 |
--------------------------------------------------------------------------------
/diagram.go:
--------------------------------------------------------------------------------
1 | package wann
2 |
import (
	"bytes"
	"io"
	"io/ioutil"
	"sort"
	"strconv"

	"github.com/xyproto/tinysvg"
)
11 |
12 | // OutputSVG will output the current network as an SVG image to the given io.Writer
13 | // TODO: Clean up and refactor
14 | func (net *Network) OutputSVG(w io.Writer) (int, error) {
15 | // Set up margins and the canvas size
16 | var (
17 | marginLeft = 10
18 | marginTop = 10
19 | marginBottom = 10
20 | marginRight = 10
21 | nodeRadius = 10
22 | betweenPadding = 4
23 | d = float64(net.Depth()) * 2.5
24 | width = marginLeft + int(float64(nodeRadius)*2.0*d) + betweenPadding*(int(d)-1) + nodeRadius + marginRight
25 | l = float64(len(net.InputNodes))
26 | height = marginTop + int(float64(nodeRadius)*1.5*l) + betweenPadding*(int(l)-1) + marginBottom
27 | imgPadding = 5
28 | lineWidth = 2
29 | )
30 |
31 | if width < 128 {
32 | width = 128
33 | }
34 | if height < 128 {
35 | height = 128
36 | }
37 |
38 | // Start a new SVG image
39 | document, svg := tinysvg.NewTinySVG(width+imgPadding*2, height+imgPadding*2)
40 | svg.Describe("generated with github.com/xyproto/wann")
41 |
42 | // White background rounded rectangle
43 | bg := svg.AddRoundedRect(imgPadding, imgPadding, 30, 30, width, height)
44 | bg.Fill2(tinysvg.ColorByName("white"))
45 | bg.Stroke2(tinysvg.ColorByName("black"))
46 |
47 | // Position of output node
48 | outputx := width - (marginRight + nodeRadius*2) + imgPadding
49 | outputy := (height-(nodeRadius*2))/2 + imgPadding
50 |
51 | // For each connected neuron, store it with the distance from the output neuron as the key in a map
52 | layerNeurons := make(map[int][]NeuronIndex)
53 | maxDistance := 0
54 | net.ForEachConnectedNodeIndex(func(ni NeuronIndex) {
55 | distanceFromOutput := net.AllNodes[ni].distanceFromOutputNode
56 | layerNeurons[distanceFromOutput] = append(layerNeurons[distanceFromOutput], ni)
57 | if distanceFromOutput > maxDistance {
58 | maxDistance = distanceFromOutput
59 | }
60 | })
61 |
62 | // Draw the input nodes as circles, and connections to the output node as lines
63 | //for i, n := range net.InputNodes {
64 | columnOffset := 50
65 |
66 | getPosition := func(givenNeuron NeuronIndex) (int, int) {
67 | for outputDistance, neurons := range layerNeurons {
68 | for neuronLayerIndex, otherNeuron := range neurons {
69 | if otherNeuron == givenNeuron {
70 | x := marginLeft + imgPadding + columnOffset*(maxDistance-outputDistance)
71 | y := (neuronLayerIndex * (nodeRadius*2 + betweenPadding)) + marginTop + imgPadding
72 | return x, y
73 | }
74 | }
75 | }
76 | panic("implementation error: neuron index not found")
77 | }
78 |
79 | // Draw node lines first
80 | for _, neurons := range layerNeurons {
81 | for _, neuronIndex := range neurons {
82 | if neuronIndex == net.OutputNode {
83 | continue
84 | }
85 | // Find the position of this node circle
86 | x, y := getPosition(neuronIndex)
87 | // Draw the connection from the center of this node to the center of all input nodes, if applicable
88 | for _, inputNeuron := range (net.AllNodes[neuronIndex]).InputNodes {
89 | ix, iy := getPosition(inputNeuron)
90 | svg.Line(ix+nodeRadius, iy+nodeRadius, x+nodeRadius, y+nodeRadius, lineWidth, "orange")
91 | }
92 | // Draw the connection to the output node, if it has this node as input
93 | if net.AllNodes[net.OutputNode].HasInput(neuronIndex) {
94 | svg.Line(x+nodeRadius, y+nodeRadius, outputx+nodeRadius, outputy+nodeRadius, lineWidth, "#0099ff")
95 | }
96 | }
97 | }
98 |
99 | // Then draw the nodes on top, including graph plots
100 | for _, neurons := range layerNeurons {
101 | for _, neuronIndex := range neurons {
102 | if neuronIndex == net.OutputNode {
103 | continue
104 | }
105 |
106 | // Find the position of this node circle
107 | x, y := getPosition(neuronIndex)
108 |
109 | // Draw this node
110 | input := svg.AddCircle(x+nodeRadius, y+nodeRadius, nodeRadius)
111 | switch net.AllNodes[neuronIndex].distanceFromOutputNode {
112 | case 1, 6:
113 | input.Fill("lightblue")
114 | case 2, 7:
115 | input.Fill("lightgreen")
116 | case 3, 8:
117 | input.Fill("lightyellow")
118 | case 4, 9:
119 | input.Fill("orange")
120 | case 5, 10:
121 | input.Fill("red")
122 | default:
123 | input.Fill("gray")
124 | }
125 | input.Stroke2(tinysvg.ColorByName("black"))
126 |
127 | // Plot the activation function inside this node
128 | var points []*tinysvg.Pos
129 | startx := float64(x) + float64(nodeRadius)*0.5
130 | stopx := float64(x+nodeRadius*2) - float64(nodeRadius)*0.5
131 | ypos := float64(y)
132 | for xpos := startx; xpos < stopx; xpos += 0.2 {
133 | // xr is from 0 to 1
134 | xr := float64(xpos-startx) / float64(stopx-startx)
135 | // xv is from -5 to 5
136 | xv := (xr - 0.5) * float64(nodeRadius)
137 | node := net.AllNodes[neuronIndex]
138 | f := ActivationFunctions[node.ActivationFunction]
139 | yv := f(xv)
140 | // plot, 3.0 is the amplitude along y
141 | yp := float64(ypos) + float64(nodeRadius)*1.35 - (yv * 0.6 * float64(nodeRadius))
142 |
143 | if yp < (ypos + float64(nodeRadius)*0.1) {
144 | continue
145 | } else if yp > (ypos + float64(nodeRadius)*1.9) {
146 | continue
147 | }
148 |
149 | // Label
150 | name := node.ActivationFunction.Name()
151 | if net.IsInput(neuronIndex) {
152 | // Add a the input number to the name
153 | for i, ni := range net.InputNodes {
154 | if neuronIndex == ni {
155 | name += " [" + strconv.Itoa(i) + "]"
156 | }
157 | }
158 | } else if neuronIndex == net.OutputNode {
159 | name += " !"
160 | }
161 | box := svg.AddRect(int(startx-float64(nodeRadius)*0.4), int(ypos+float64(nodeRadius)*2.5)-5, len(name)*5, 6)
162 | box.Fill("black")
163 | svg.Text(int(startx-float64(nodeRadius)*0.4), int(ypos+float64(nodeRadius)*2.5), 8, "Courier", name, "white")
164 |
165 | p := tinysvg.NewPosf(xpos, yp)
166 | points = append(points, p)
167 | }
168 | // Draw the polyline (graph)
169 | pl := svg.Polyline(points, tinysvg.ColorByName("black"))
170 | pl.Stroke2(tinysvg.ColorByName("black"))
171 | pl.Fill2(tinysvg.ColorByName("none"))
172 |
173 | }
174 | }
175 |
176 | // Draw the output node
177 | output := svg.AddCircle(outputx+nodeRadius+1, outputy+nodeRadius+1, nodeRadius)
178 | output.Fill("magenta")
179 | output.Stroke2(tinysvg.ColorByName("black"))
180 |
181 | // Label
182 | name := net.AllNodes[net.OutputNode].ActivationFunction.Name() + " [o]"
183 | box := svg.AddRect(outputx-nodeRadius/2, (nodeRadius*2)+outputy+1, len(name)*5, 6)
184 | box.Fill("black")
185 | svg.Text(outputx-nodeRadius/2, (nodeRadius*2)+outputy+6, 8, "Courier", name, "white")
186 |
187 | // Write the data to the given io.Writer
188 | return w.Write(document.Bytes())
189 | }
190 |
191 | // WriteSVG saves a drawing of the current network as an SVG file
192 | func (net *Network) WriteSVG(filename string) error {
193 | var buf bytes.Buffer
194 | if _, err := net.OutputSVG(&buf); err != nil {
195 | return err
196 | }
197 | return ioutil.WriteFile(filename, buf.Bytes(), 0644)
198 | }
199 |
--------------------------------------------------------------------------------
/vendor/github.com/dave/jennifer/jen/file.go:
--------------------------------------------------------------------------------
1 | package jen
2 |
3 | import (
4 | "bytes"
5 | "fmt"
6 | "regexp"
7 | "strings"
8 | "unicode"
9 | "unicode/utf8"
10 | )
11 |
12 | // NewFile Creates a new file, with the specified package name.
13 | func NewFile(packageName string) *File {
14 | return &File{
15 | Group: &Group{
16 | multi: true,
17 | },
18 | name: packageName,
19 | imports: map[string]importdef{},
20 | hints: map[string]importdef{},
21 | }
22 | }
23 |
24 | // NewFilePath creates a new file while specifying the package path - the
25 | // package name is inferred from the path.
26 | func NewFilePath(packagePath string) *File {
27 | return &File{
28 | Group: &Group{
29 | multi: true,
30 | },
31 | name: guessAlias(packagePath),
32 | path: packagePath,
33 | imports: map[string]importdef{},
34 | hints: map[string]importdef{},
35 | }
36 | }
37 |
38 | // NewFilePathName creates a new file with the specified package path and name.
39 | func NewFilePathName(packagePath, packageName string) *File {
40 | return &File{
41 | Group: &Group{
42 | multi: true,
43 | },
44 | name: packageName,
45 | path: packagePath,
46 | imports: map[string]importdef{},
47 | hints: map[string]importdef{},
48 | }
49 | }
50 |
51 | // File represents a single source file. Package imports are managed
52 | // automaticaly by File.
53 | type File struct {
54 | *Group
55 | name string
56 | path string
57 | imports map[string]importdef
58 | hints map[string]importdef
59 | comments []string
60 | headers []string
61 | cgoPreamble []string
62 | // If you're worried about generated package aliases conflicting with local variable names, you
63 | // can set a prefix here. Package foo becomes {prefix}_foo.
64 | PackagePrefix string
65 | // CanonicalPath adds a canonical import path annotation to the package clause.
66 | CanonicalPath string
67 | }
68 |
// importdef is used to differentiate packages where we know the package name from packages where the
// import is aliased. If alias == false, then name is the actual package name, and the import will be
// rendered without an alias. If alias == true, name is rendered as an explicit alias in the import
// block ("_" for anonymous imports added with Anon, "." for dot-imports).
type importdef struct {
	// name is the package name or alias used to refer to the import.
	name string
	// alias reports whether name must be written as an explicit alias.
	alias bool
}
77 |
78 | // HeaderComment adds a comment to the top of the file, above any package
79 | // comments. A blank line is rendered below the header comments, ensuring
80 | // header comments are not included in the package doc.
81 | func (f *File) HeaderComment(comment string) {
82 | f.headers = append(f.headers, comment)
83 | }
84 |
85 | // PackageComment adds a comment to the top of the file, above the package
86 | // keyword.
87 | func (f *File) PackageComment(comment string) {
88 | f.comments = append(f.comments, comment)
89 | }
90 |
91 | // CgoPreamble adds a cgo preamble comment that is rendered directly before the "C" pseudo-package
92 | // import.
93 | func (f *File) CgoPreamble(comment string) {
94 | f.cgoPreamble = append(f.cgoPreamble, comment)
95 | }
96 |
97 | // Anon adds an anonymous import.
98 | func (f *File) Anon(paths ...string) {
99 | for _, p := range paths {
100 | f.imports[p] = importdef{name: "_", alias: true}
101 | }
102 | }
103 |
104 | // ImportName provides the package name for a path. If specified, the alias will be omitted from the
105 | // import block. This is optional. If not specified, a sensible package name is used based on the path
106 | // and this is added as an alias in the import block.
107 | func (f *File) ImportName(path, name string) {
108 | f.hints[path] = importdef{name: name, alias: false}
109 | }
110 |
111 | // ImportNames allows multiple names to be imported as a map. Use the [gennames](gennames) command to
112 | // automatically generate a go file containing a map of a selection of package names.
113 | func (f *File) ImportNames(names map[string]string) {
114 | for path, name := range names {
115 | f.hints[path] = importdef{name: name, alias: false}
116 | }
117 | }
118 |
119 | // ImportAlias provides the alias for a package path that should be used in the import block. A
120 | // period can be used to force a dot-import.
121 | func (f *File) ImportAlias(path, alias string) {
122 | f.hints[path] = importdef{name: alias, alias: true}
123 | }
124 |
125 | func (f *File) isLocal(path string) bool {
126 | return f.path == path
127 | }
128 |
129 | func (f *File) isValidAlias(alias string) bool {
130 | // multiple dot-imports are ok
131 | if alias == "." {
132 | return true
133 | }
134 | // the import alias is invalid if it's a reserved word
135 | if IsReservedWord(alias) {
136 | return false
137 | }
138 | // the import alias is invalid if it's already been registered
139 | for _, v := range f.imports {
140 | if alias == v.name {
141 | return false
142 | }
143 | }
144 | return true
145 | }
146 |
147 | func (f *File) isDotImport(path string) bool {
148 | if id, ok := f.hints[path]; ok {
149 | return id.name == "." && id.alias
150 | }
151 | return false
152 | }
153 |
// register returns the name used to refer to the package at path in this
// file, registering the import on first use.
//
// Resolution order: an already-registered (non-anonymous) name; the special
// "C" pseudo-package; a hint from ImportName/ImportAlias; the standard library
// name table; and finally guessAlias. A numeric suffix is appended until the
// name passes isValidAlias, and PackagePrefix is prepended to aliased names.
func (f *File) register(path string) string {
	if f.isLocal(path) {
		// notest
		// should never get here because in Qual the packageToken will be null,
		// so render will never be called.
		return ""
	}

	// if the path has been registered previously, simply return the name
	// (an anonymous "_" registration from Anon is upgraded to a real name below)
	def := f.imports[path]
	if def.name != "" && def.name != "_" {
		return def.name
	}

	// special case for "C" pseudo-package
	if path == "C" {
		f.imports["C"] = importdef{name: "C", alias: false}
		return "C"
	}

	var name string
	var alias bool

	if hint := f.hints[path]; hint.name != "" {
		// look up the path in the list of provided package names and aliases by ImportName / ImportAlias
		name = hint.name
		alias = hint.alias
	} else if standardLibraryHints[path] != "" {
		// look up the path in the list of standard library packages
		name = standardLibraryHints[path]
		alias = false
	} else {
		// if a hint is not found for the package, guess the alias from the package path
		name = guessAlias(path)
		alias = true
	}

	// If the name is invalid or has been registered already, make it unique by appending a number
	unique := name
	i := 0
	for !f.isValidAlias(unique) {
		i++
		unique = fmt.Sprintf("%s%d", name, i)
	}

	// If we've changed the name to make it unique, it should definitely be an alias
	if unique != name {
		alias = true
	}

	// Only add a prefix if the name is an alias
	if f.PackagePrefix != "" && alias {
		unique = f.PackagePrefix + "_" + unique
	}

	// Register the eventual name
	f.imports[path] = importdef{name: unique, alias: alias}

	return unique
}
214 |
215 | // GoString renders the File for testing. Any error will cause a panic.
216 | func (f *File) GoString() string {
217 | buf := &bytes.Buffer{}
218 | if err := f.Render(buf); err != nil {
219 | panic(err)
220 | }
221 | return buf.String()
222 | }
223 |
// aliasInvalidChars matches every character that may not appear in an alias
// produced by guessAlias. It is compiled once, instead of on every call.
var aliasInvalidChars = regexp.MustCompile(`[^a-z0-9]`)

// guessAlias derives a package alias from an import path. It takes the last
// path element (ignoring a single trailing slash), lower-cases it, removes
// every character that is not an ASCII letter or digit, and drops leading
// digits so the result is a valid Go identifier. If nothing is left, "pkg" is
// returned. For example, "github.com/foo/Bar-Baz/" yields "barbaz".
func guessAlias(path string) string {
	// trailing slashes are usually tolerated, so we can get rid of one if it exists
	alias := strings.TrimSuffix(path, "/")

	// if the path contains a "/", use the last part
	if i := strings.LastIndex(alias, "/"); i >= 0 {
		alias = alias[i+1:]
	}

	// alias should be lower case and contain only alphanumerics
	alias = aliasInvalidChars.ReplaceAllString(strings.ToLower(alias), "")

	// can't have a first digit, per Go identifier rules, so just skip them
	for firstRune, runeLen := utf8.DecodeRuneInString(alias); unicode.IsDigit(firstRune); firstRune, runeLen = utf8.DecodeRuneInString(alias) {
		alias = alias[runeLen:]
	}

	// If the path part was all digits, we may be left with an empty string.
	// In this case use "pkg" as the alias.
	if alias == "" {
		alias = "pkg"
	}

	return alias
}
257 |
--------------------------------------------------------------------------------
/af.go:
--------------------------------------------------------------------------------
1 | package wann
2 |
3 | import (
4 | "fmt"
5 | "time"
6 |
7 | "github.com/dave/jennifer/jen"
8 | "github.com/xyproto/af"
9 | )
10 |
// ActivationFunctionIndex is a number that represents a specific activation function.
// The numeric values come from the order of the iota constants below, so new
// functions should be appended at the end to keep existing values stable.
type ActivationFunctionIndex int

const (
	// Step is a step function: first 0 and then abruptly up to 1.
	Step ActivationFunctionIndex = iota
	// Linear is the linear (identity) activation function: the input is passed through.
	Linear
	// Sin is the sinusoid activation function, sin(pi * x).
	Sin
	// Gauss is the (unnormalized) Gaussian exp(-x^2 / 2), with a mean of 0 and a sigma of 1.
	Gauss
	// Tanh is math.Tanh.
	Tanh
	// Sigmoid is the logistic sigmoid 1 / (1 + exp(-x)), evaluated with af.Sigmoid.
	Sigmoid
	// Inv is the inverted linear function, -x.
	Inv
	// Abs is math.Abs.
	Abs
	// ReLU is the rectified linear unit: 0 for negative input, otherwise the input itself.
	ReLU
	// Cos is the cosine activation function, cos(pi * x).
	Cos
	// Squared is x * x, which increases rapidly.
	Squared
	// Swish is x / (1 + exp(-x)), a later invention than ReLU.
	Swish
	// SoftPlus is log(1 + exp(x)).
	SoftPlus
)
42 |
43 | // ActivationFunctions is a collection of activation functions, where the keys are constants that are defined above
44 | // https://github.com/google/brain-tokyo-workshop/blob/master/WANNRelease/WANN/wann_src/ind.py
45 | var ActivationFunctions = map[ActivationFunctionIndex](func(float64) float64){
46 | Step: af.Step, // Unsigned Step Function
47 | Linear: af.Linear, // Linear
48 | Sin: af.Sin, // Sin
49 | Gauss: af.Gaussian01, // Gaussian with mean 0 and sigma 1
50 | Tanh: af.Tanh, // Hyperbolic Tangent (signed?)
51 | Sigmoid: af.Sigmoid, // Sigmoid (unsigned?)
52 | Inv: af.Inv, // Inverse
53 | Abs: af.Abs, // Absolute value
54 | ReLU: af.ReLU, // Rectified linear unit
55 | Cos: af.Cos, // Cosine
56 | Squared: af.Squared, // Squared
57 | Swish: af.Swish, // Swish
58 | SoftPlus: af.SoftPlus, // SoftPlus
59 | }
60 |
61 | // ComplexityEstimate is a map for having an estimate of how complex each function is,
62 | // based on a quick benchmark of each function.
63 | // The complexity estimates will vary, depending on the performance.
64 | var ComplexityEstimate = make(map[ActivationFunctionIndex]float64)
65 |
66 | func (config *Config) estimateComplexity() {
67 | if config.Verbose {
68 | fmt.Print("Estimating activation function complexity...")
69 | }
70 | startEstimate := time.Now()
71 | resolution := 0.0001
72 | durationMap := make(map[ActivationFunctionIndex]time.Duration)
73 | var maxDuration time.Duration
74 | for i, f := range ActivationFunctions {
75 | start := time.Now()
76 | for x := 0.0; x <= 1.0; x += resolution {
77 | _ = f(x)
78 | }
79 | duration := time.Since(start)
80 | durationMap[ActivationFunctionIndex(i)] = duration
81 | if duration > maxDuration {
82 | maxDuration = duration
83 | }
84 | }
85 | for i := range ActivationFunctions {
86 | // 1.0 means the function took maxDuration
87 | ComplexityEstimate[ActivationFunctionIndex(i)] = float64(durationMap[ActivationFunctionIndex(i)]) / float64(maxDuration)
88 | }
89 | estimateDuration := time.Since(startEstimate)
90 | if config.Verbose {
91 | fmt.Printf(" done. (In %v)\n", estimateDuration)
92 | }
93 | }
94 |
95 | // Call runs an activation function with the given float64 value.
96 | // The activation function is chosen by one of the constants above.
97 | func (afi ActivationFunctionIndex) Call(x float64) float64 {
98 | if f, ok := ActivationFunctions[afi]; ok {
99 | return f(x)
100 | }
101 | // Use the linear function by default
102 | return af.Linear(x)
103 | }
104 |
105 | // Name returns a name for each activation function
106 | func (afi ActivationFunctionIndex) Name() string {
107 | switch afi {
108 | case Step:
109 | return "Step"
110 | case Linear:
111 | return "Linear"
112 | case Sin:
113 | return "Sinusoid"
114 | case Gauss:
115 | return "Gaussian"
116 | case Tanh:
117 | return "Tanh"
118 | case Sigmoid:
119 | return "Sigmoid"
120 | case Inv:
121 | return "Inverted"
122 | case Abs:
123 | return "Absolute"
124 | case ReLU:
125 | return "ReLU"
126 | case Cos:
127 | return "Cosinusoid"
128 | case Squared:
129 | return "Squared"
130 | case Swish:
131 | return "Swish"
132 | case SoftPlus:
133 | return "SoftPlus"
134 | default:
135 | return "Untitled"
136 | }
137 | }
138 |
139 | // goExpression returns the Go expression for this activation function, using the given variable name string as the input variable name
140 | func (afi ActivationFunctionIndex) goExpression(varName string) string {
141 | switch afi {
142 | case Step:
143 | // Using s to not confuse it with the varName
144 | return "func(s float64) float64 { if s >= 0 { return 1 } else { return 0 } }(" + varName + ")"
145 | case Linear:
146 | return varName
147 | case Sin:
148 | return "math.Sin(math.Pi * " + varName + ")"
149 | case Gauss:
150 | return "math.Exp(-(" + varName + " * " + varName + ") / 2.0)"
151 | case Tanh:
152 | return "math.Tanh(" + varName + ")"
153 | case Sigmoid:
154 | return "(1.0 / (1.0 + math.Exp(-" + varName + ")))"
155 | case Inv:
156 | return "-" + varName
157 | case Abs:
158 | return "math.Abs(" + varName + ")"
159 | case ReLU:
160 | // Using r to not confuse it with the varName
161 | return "func(r float64) float64 { if r >= 0 { return r } else { return 0 } }(" + varName + ")"
162 | case Cos:
163 | return "math.Cos(math.Pi * " + varName + ")"
164 | case Squared:
165 | return "(" + varName + " * " + varName + ")"
166 | case Swish:
167 | return "(" + varName + "/ (1.0 + math.Exp(-" + varName + ")))"
168 | case SoftPlus:
169 | return "math.Log(1.0 + math.Exp(" + varName + "))"
170 | default:
171 | return varName
172 | }
173 | }
174 |
175 | // String returns the Go expression for this activation function, using "x" as the input variable name
176 | func (afi ActivationFunctionIndex) String() string {
177 | return afi.goExpression("x")
178 | }
179 |
180 | // Statement returns the Statement statement for this activation function, using the given inner statement
181 | func (afi ActivationFunctionIndex) Statement(inner *jen.Statement) *jen.Statement {
182 | switch afi {
183 | case Step:
184 | // func(s float64) float64 { if s >= 0 { return 1 } else { return 0 } }(inner)
185 | // Using s to not confuse it with the varName
186 | return jen.Func().Params(jen.Id("s").Id("float64")).Id("float64").Block(
187 | jen.If(jen.Id("s").Op(">=").Id("0")).Block(
188 | jen.Return(jen.Lit(1)),
189 | ).Else().Block(
190 | jen.Return(jen.Lit(0)),
191 | ),
192 | ).Call(inner)
193 | case Cos:
194 | // math.Cos((inner) * math.Pi)
195 | return jen.Qual("math", "Cos").Call(jen.Parens(inner).Op("*").Id("math").Dot("Pi"))
196 | case Sin:
197 | // math.Sin((inner) * math.Pi)
198 | return jen.Qual("math", "Sin").Call(jen.Parens(inner).Op("*").Id("math").Dot("Pi"))
199 | case Gauss:
200 | // return math.Exp(-(math.Pow(inner, 2.0)) / 2.0)
201 | return jen.Qual("math", "Exp").Call(jen.Op("-").Parens(
202 | // Using math.Pow ensures the inner expression is only calculated once, if it's a large expression
203 | //inner.Op("*").Add(inner),
204 | jen.Qual("math", "Pow").Params(
205 | inner,
206 | jen.Lit(2.0),
207 | ),
208 | ).Op("/").Lit(2.0))
209 | case Tanh:
210 | // math.Tanh(inner)
211 | return jen.Qual("math", "Tanh").Call(inner)
212 | case Sigmoid:
213 | // (1.0 / (1.0 + math.Exp(-(inner))))
214 | return jen.Lit(1.0).Op("/").Parens(jen.Lit(1.0).Op("+").Qual("math", "Exp").Call(jen.Op("-").Parens(inner)))
215 | case Inv:
216 | // -(inner)
217 | return jen.Op("-").Parens(inner)
218 | case Abs:
219 | // math.Abs(inner)
220 | return jen.Qual("math", "Abs").Call(inner)
221 | case ReLU:
222 | //return "func(r float64) float64 { if r >= 0 { return r } else { return 0 } }(" + varName + ")"
223 | // Using r to not confuse it with the varName
224 | return jen.Func().Params(jen.Id("r").Id("float64")).Id("float64").Block(
225 | jen.If(jen.Id("r").Op(">=").Id("0")).Block(
226 | jen.Return(jen.Id("r")),
227 | ).Else().Block(
228 | jen.Return(jen.Lit(0)),
229 | ),
230 | ).Call(inner)
231 | case Squared:
232 | // inner^2
233 | //return inner.Op("*").Add(inner)
234 | // Using math.Pow ensures the inner expression is only calculated once, if it's a large expression
235 | return jen.Qual("math", "Pow").Call(inner, jen.Lit(2.0))
236 | case Swish:
237 | // (inner / (1.0 + math.Exp(-inner)))
238 | return jen.Parens(inner.Op("/").Parens(jen.Lit(1.0).Op("+").Qual("math", "Exp").Call(jen.Op("-").Parens(inner))))
239 | case SoftPlus:
240 | // math.Log(1.0 + math.Exp(inner))
241 | return jen.Qual("math", "Log").Call(jen.Lit(1.0).Op("+").Qual("math", "Exp").Call(inner))
242 | case Linear:
243 | // This is also the default case: (inner)
244 | fallthrough
245 | default:
246 | // (inner)
247 | return jen.Parens(inner)
248 | }
249 | }
250 |
251 | // GoRun will first construct the expression using jennifer and then evaluate the result using "go run" and a source file innn /tmp
252 | func (afi ActivationFunctionIndex) GoRun(x float64) (float64, error) {
253 | return RunStatementX(afi.Statement(jen.Id("x")), x)
254 | }
255 |
--------------------------------------------------------------------------------
/vendor/github.com/dave/jennifer/jen/tokens.go:
--------------------------------------------------------------------------------
1 | package jen
2 |
3 | import (
4 | "fmt"
5 | "io"
6 | "strconv"
7 | "strings"
8 | )
9 |
// tokenType classifies a token and selects how token.render writes it.
type tokenType string

const (
	// Content is a string written verbatim by token.render.
	identifierToken tokenType = "identifier"
	keywordToken    tokenType = "keyword"
	operatorToken   tokenType = "operator"
	delimiterToken  tokenType = "delimiter"
	layoutToken     tokenType = "layout"

	// Content is an import path, rendered as the alias registered in the File.
	packageToken tokenType = "package"
	// No case handles this type in token.render in this file.
	qualifiedToken tokenType = "qualified"

	// Content is a Go value formatted as a literal.
	literalToken     tokenType = "literal"
	literalRuneToken tokenType = "literal_rune"
	literalByteToken tokenType = "literal_byte"

	// Renders nothing.
	nullToken tokenType = "null"
)

// token is a single renderable fragment: typ selects the rendering rule and
// content carries the value to render.
type token struct {
	typ     tokenType
	content interface{}
}
30 |
31 | func (t token) isNull(f *File) bool {
32 | if t.typ == packageToken {
33 | // package token is null if the path is a dot-import or the local package path
34 | return f.isDotImport(t.content.(string)) || f.isLocal(t.content.(string))
35 | }
36 | return t.typ == nullToken
37 | }
38 |
// render writes the Go source text for this single token to w.
//
// By token type:
//   - literalToken: bool, string, int and complex128 are written with %#v;
//     float64 uses %#v plus ".0" when the text has neither "." nor "e";
//     the other sized numeric types are wrapped in a conversion such as
//     "int8(3)"; complex64 is its type name followed by %#v. Any other
//     content type panics.
//   - literalRuneToken: strconv.QuoteRune of the content.
//   - literalByteToken: "byte(%#v)".
//   - keyword, operator, layout and delimiter tokens: the content string,
//     with ":" appended after the "default" keyword.
//   - packageToken: the alias returned by f.register for the path.
//   - identifierToken: the content string.
//   - nullToken: nothing.
// The s parameter is not used here. Only write errors from w are returned.
func (t token) render(f *File, w io.Writer, s *Statement) error {
	switch t.typ {
	case literalToken:
		var out string
		switch t.content.(type) {
		case bool, string, int, complex128:
			// default constant types can be left bare
			out = fmt.Sprintf("%#v", t.content)
		case float64:
			out = fmt.Sprintf("%#v", t.content)
			if !strings.Contains(out, ".") && !strings.Contains(out, "e") {
				// If the formatted value is not in scientific notation, and does not have a dot, then
				// we add ".0". Otherwise it will be interpreted as an int.
				// See:
				// https://github.com/dave/jennifer/issues/39
				// https://github.com/golang/go/issues/26363
				out += ".0"
			}
		case float32, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, uintptr:
			// other built-in types need specific type info
			out = fmt.Sprintf("%T(%#v)", t.content, t.content)
		case complex64:
			// fmt package already renders parenthesis for complex64
			out = fmt.Sprintf("%T%#v", t.content, t.content)
		default:
			panic(fmt.Sprintf("unsupported type for literal: %T", t.content))
		}
		if _, err := w.Write([]byte(out)); err != nil {
			return err
		}
	case literalRuneToken:
		if _, err := w.Write([]byte(strconv.QuoteRune(t.content.(rune)))); err != nil {
			return err
		}
	case literalByteToken:
		if _, err := w.Write([]byte(fmt.Sprintf("byte(%#v)", t.content))); err != nil {
			return err
		}
	case keywordToken, operatorToken, layoutToken, delimiterToken:
		if _, err := w.Write([]byte(fmt.Sprintf("%s", t.content))); err != nil {
			return err
		}
		if t.content.(string) == "default" {
			// Special case for Default, which must always be followed by a colon
			if _, err := w.Write([]byte(":")); err != nil {
				return err
			}
		}
	case packageToken:
		// registering here means only packages that are actually rendered get imported
		path := t.content.(string)
		alias := f.register(path)
		if _, err := w.Write([]byte(alias)); err != nil {
			return err
		}
	case identifierToken:
		if _, err := w.Write([]byte(t.content.(string))); err != nil {
			return err
		}
	case nullToken: // notest
		// do nothing (should never render a null token)
	}
	return nil
}
102 |
103 | // Null adds a null item. Null items render nothing and are not followed by a
104 | // separator in lists.
105 | func Null() *Statement {
106 | return newStatement().Null()
107 | }
108 |
109 | // Null adds a null item. Null items render nothing and are not followed by a
110 | // separator in lists.
111 | func (g *Group) Null() *Statement {
112 | s := Null()
113 | g.items = append(g.items, s)
114 | return s
115 | }
116 |
117 | // Null adds a null item. Null items render nothing and are not followed by a
118 | // separator in lists.
119 | func (s *Statement) Null() *Statement {
120 | t := token{
121 | typ: nullToken,
122 | }
123 | *s = append(*s, t)
124 | return s
125 | }
126 |
127 | // Empty adds an empty item. Empty items render nothing but are followed by a
128 | // separator in lists.
129 | func Empty() *Statement {
130 | return newStatement().Empty()
131 | }
132 |
133 | // Empty adds an empty item. Empty items render nothing but are followed by a
134 | // separator in lists.
135 | func (g *Group) Empty() *Statement {
136 | s := Empty()
137 | g.items = append(g.items, s)
138 | return s
139 | }
140 |
141 | // Empty adds an empty item. Empty items render nothing but are followed by a
142 | // separator in lists.
143 | func (s *Statement) Empty() *Statement {
144 | t := token{
145 | typ: operatorToken,
146 | content: "",
147 | }
148 | *s = append(*s, t)
149 | return s
150 | }
151 |
152 | // Op renders the provided operator / token.
153 | func Op(op string) *Statement {
154 | return newStatement().Op(op)
155 | }
156 |
157 | // Op renders the provided operator / token.
158 | func (g *Group) Op(op string) *Statement {
159 | s := Op(op)
160 | g.items = append(g.items, s)
161 | return s
162 | }
163 |
164 | // Op renders the provided operator / token.
165 | func (s *Statement) Op(op string) *Statement {
166 | t := token{
167 | typ: operatorToken,
168 | content: op,
169 | }
170 | *s = append(*s, t)
171 | return s
172 | }
173 |
174 | // Dot renders a period followed by an identifier. Use for fields and selectors.
175 | func Dot(name string) *Statement {
176 | // notest
177 | // don't think this can be used in valid code?
178 | return newStatement().Dot(name)
179 | }
180 |
181 | // Dot renders a period followed by an identifier. Use for fields and selectors.
182 | func (g *Group) Dot(name string) *Statement {
183 | // notest
184 | // don't think this can be used in valid code?
185 | s := Dot(name)
186 | g.items = append(g.items, s)
187 | return s
188 | }
189 |
190 | // Dot renders a period followed by an identifier. Use for fields and selectors.
191 | func (s *Statement) Dot(name string) *Statement {
192 | d := token{
193 | typ: delimiterToken,
194 | content: ".",
195 | }
196 | t := token{
197 | typ: identifierToken,
198 | content: name,
199 | }
200 | *s = append(*s, d, t)
201 | return s
202 | }
203 |
204 | // Id renders an identifier.
205 | func Id(name string) *Statement {
206 | return newStatement().Id(name)
207 | }
208 |
209 | // Id renders an identifier.
210 | func (g *Group) Id(name string) *Statement {
211 | s := Id(name)
212 | g.items = append(g.items, s)
213 | return s
214 | }
215 |
216 | // Id renders an identifier.
217 | func (s *Statement) Id(name string) *Statement {
218 | t := token{
219 | typ: identifierToken,
220 | content: name,
221 | }
222 | *s = append(*s, t)
223 | return s
224 | }
225 |
226 | // Qual renders a qualified identifier. Imports are automatically added when
227 | // used with a File. If the path matches the local path, the package name is
228 | // omitted. If package names conflict they are automatically renamed. Note that
229 | // it is not possible to reliably determine the package name given an arbitrary
230 | // package path, so a sensible name is guessed from the path and added as an
231 | // alias. The names of all standard library packages are known so these do not
232 | // need to be aliased. If more control is needed of the aliases, see
233 | // [File.ImportName](#importname) or [File.ImportAlias](#importalias).
234 | func Qual(path, name string) *Statement {
235 | return newStatement().Qual(path, name)
236 | }
237 |
238 | // Qual renders a qualified identifier. Imports are automatically added when
239 | // used with a File. If the path matches the local path, the package name is
240 | // omitted. If package names conflict they are automatically renamed. Note that
241 | // it is not possible to reliably determine the package name given an arbitrary
242 | // package path, so a sensible name is guessed from the path and added as an
243 | // alias. The names of all standard library packages are known so these do not
244 | // need to be aliased. If more control is needed of the aliases, see
245 | // [File.ImportName](#importname) or [File.ImportAlias](#importalias).
246 | func (g *Group) Qual(path, name string) *Statement {
247 | s := Qual(path, name)
248 | g.items = append(g.items, s)
249 | return s
250 | }
251 |
252 | // Qual renders a qualified identifier. Imports are automatically added when
253 | // used with a File. If the path matches the local path, the package name is
254 | // omitted. If package names conflict they are automatically renamed. Note that
255 | // it is not possible to reliably determine the package name given an arbitrary
256 | // package path, so a sensible name is guessed from the path and added as an
257 | // alias. The names of all standard library packages are known so these do not
258 | // need to be aliased. If more control is needed of the aliases, see
259 | // [File.ImportName](#importname) or [File.ImportAlias](#importalias).
260 | func (s *Statement) Qual(path, name string) *Statement {
261 | g := &Group{
262 | close: "",
263 | items: []Code{
264 | token{
265 | typ: packageToken,
266 | content: path,
267 | },
268 | token{
269 | typ: identifierToken,
270 | content: name,
271 | },
272 | },
273 | name: "qual",
274 | open: "",
275 | separator: ".",
276 | }
277 | *s = append(*s, g)
278 | return s
279 | }
280 |
281 | // Line inserts a blank line.
282 | func Line() *Statement {
283 | return newStatement().Line()
284 | }
285 |
286 | // Line inserts a blank line.
287 | func (g *Group) Line() *Statement {
288 | s := Line()
289 | g.items = append(g.items, s)
290 | return s
291 | }
292 |
293 | // Line inserts a blank line.
294 | func (s *Statement) Line() *Statement {
295 | t := token{
296 | typ: layoutToken,
297 | content: "\n",
298 | }
299 | *s = append(*s, t)
300 | return s
301 | }
302 |
--------------------------------------------------------------------------------
/vendor/github.com/xyproto/tinysvg/tags.go:
--------------------------------------------------------------------------------
1 | // Package tinysvg has structs and functions for creating and rendering TinySVG images
2 | package tinysvg
3 |
4 | // Everything here deals with bytes, not strings
5 | // TODO: Add a Write function that takes an io.Writer so that the image can be written as it is generated.
6 |
import (
	"bytes"
	"fmt"
	"io/ioutil"
	"sort"
)
12 |
// Tag represents an XML tag, as part of a larger XML document. Children form
// a singly linked list: firstChild is the first child, and each child links
// to the next one through nextSibling.
type Tag struct {
	name, content, lastContent, xmlContent []byte
	attrs                                  map[string][]byte
	nextSibling, firstChild                *Tag
}
23 |
// Document is an XML document, with a title and a root tag
type Document struct {
	// title is the document title, as passed to NewDocument.
	title []byte
	// root is the top-level tag that all other tags hang off.
	root *Tag
}
29 |
30 | // NewDocument creates a new XML/HTML/SVG image, with a root tag.
31 | // If rootTagName contains "<" or ">", it can be used for preceding declarations,
32 | // like or .
33 | // Returns a pointer to a Document.
34 | func NewDocument(title, rootTagName []byte) *Document {
35 | var image Document
36 | image.title = []byte(title)
37 | rootTag := NewTag([]byte(rootTagName))
38 | image.root = rootTag
39 | return &image
40 | }
41 |
42 | // NewTag creates a new tag based on the given name.
43 | // "name" is what will appear right after "<" when rendering as XML/HTML/SVG.
44 | func NewTag(name []byte) *Tag {
45 | var tag Tag
46 | tag.name = name
47 | tag.attrs = make(map[string][]byte)
48 | tag.nextSibling = nil
49 | tag.firstChild = nil
50 | tag.content = make([]byte, 0)
51 | tag.lastContent = make([]byte, 0)
52 | return &tag
53 | }
54 |
55 | // AddNewTag adds a new tag to another tag. This will place it one step lower
56 | // in the hierarchy of tags. You can for example add a body tag to an html tag.
57 | func (tag *Tag) AddNewTag(name []byte) *Tag {
58 | child := NewTag(name)
59 | tag.AddChild(child)
60 | return child
61 | }
62 |
63 | // AddTag adds a tag to another tag
64 | func (tag *Tag) AddTag(child *Tag) {
65 | tag.AddChild(child)
66 | }
67 |
68 | // AddAttrib adds an attribute to a tag, for instance "size" and "20"
69 | func (tag *Tag) AddAttrib(attrName string, attrValue []byte) {
70 | tag.attrs[attrName] = attrValue
71 | }
72 |
73 | // AddSingularAttrib adds attribute without a value
74 | func (tag *Tag) AddSingularAttrib(attrName string) {
75 | tag.attrs[attrName] = nil
76 | }
77 |
78 | // GetAttrString returns a []byte that represents all the attribute keys and
79 | // values of a tag. This can be used when generating XML, SVG or HTML.
80 | func (tag *Tag) GetAttrString() []byte {
81 | ret := make([]byte, 0)
82 | for key, value := range tag.attrs {
83 | if value == nil {
84 | ret = append(ret, key...)
85 | ret = append(ret, ' ')
86 | } else {
87 | ret = append(ret, key...)
88 | ret = append(ret, []byte("=\"")...)
89 | ret = append(ret, value...)
90 | ret = append(ret, []byte("\" ")...)
91 | }
92 | }
93 | if len(ret) > 0 {
94 | ret = ret[:len(ret)-1]
95 | }
96 | return ret
97 | }
98 |
99 | // getFlatXML renders XML.
100 | // This will generate a []byte for a tag, non-recursively.
101 | func (tag *Tag) getFlatXML() []byte {
102 | // For the root tag
103 | if (len(tag.name) > 0) && (tag.name[0] == '<') {
104 | ret := make([]byte, 0, len(tag.name)+len(tag.content)+len(tag.xmlContent)+len(tag.lastContent))
105 | ret = append(ret, tag.name...)
106 | ret = append(ret, tag.content...)
107 | ret = append(ret, tag.xmlContent...)
108 | ret = append(ret, tag.lastContent...)
109 | return ret
110 | }
111 | // For indenting
112 | spacing := make([]byte, 0)
113 | // Generate the XML based on the tag
114 | attrs := tag.GetAttrString()
115 | ret := make([]byte, 0)
116 | ret = append(ret, spacing...)
117 | ret = append(ret, []byte("<")...)
118 | ret = append(ret, tag.name...)
119 | if len(attrs) > 0 {
120 | ret = append(ret, []byte(" ")...)
121 | ret = append(ret, attrs...)
122 | }
123 | if (len(tag.content) == 0) && (len(tag.xmlContent) == 0) && (len(tag.lastContent) == 0) {
124 | ret = append(ret, []byte(" />")...)
125 | } else {
126 | if len(tag.xmlContent) > 0 {
127 | if tag.xmlContent[0] != ' ' {
128 | ret = append(ret, []byte(">")...)
129 | ret = append(ret, spacing...)
130 | ret = append(ret, tag.xmlContent...)
131 | ret = append(ret, spacing...)
132 | ret = append(ret, []byte("")...)
133 | ret = append(ret, tag.name...)
134 | ret = append(ret, []byte(">")...)
135 | } else {
136 | ret = append(ret, []byte(">")...)
137 | ret = append(ret, tag.xmlContent...)
138 | ret = append(ret, spacing...)
139 | ret = append(ret, []byte("")...)
140 | ret = append(ret, tag.name...)
141 | ret = append(ret, []byte(">")...)
142 | }
143 | } else {
144 | ret = append(ret, []byte(">")...)
145 | ret = append(ret, tag.content...)
146 | ret = append(ret, tag.lastContent...)
147 | ret = append(ret, []byte("")...)
148 | ret = append(ret, tag.name...)
149 | ret = append(ret, []byte(">")...)
150 | }
151 | }
152 | return ret
153 | }
154 |
155 | // GetChildren returns all children for a given tag.
156 | // Returns a slice of pointers to tags.
157 | func (tag *Tag) GetChildren() []*Tag {
158 | var children []*Tag
159 | current := tag.firstChild
160 | for current != nil {
161 | children = append(children, current)
162 | current = current.nextSibling
163 | }
164 | return children
165 | }
166 |
167 | // AddChild adds a tag as a child to another tag
168 | func (tag *Tag) AddChild(child *Tag) {
169 | if tag.firstChild == nil {
170 | tag.firstChild = child
171 | return
172 | }
173 | lastChild := tag.LastChild()
174 | child.nextSibling = nil
175 | lastChild.nextSibling = child
176 | }
177 |
178 | // AddContent adds text to a tag.
179 | // This is what will appear between two tag markers, for example:
180 | // content
181 | // If the tag contains child tags, they will be rendered after this content.
182 | func (tag *Tag) AddContent(content []byte) {
183 | tag.content = append(tag.content, content...)
184 | }
185 |
186 | // AppendContent appends content to the end of the existing content of a tag
187 | func (tag *Tag) AppendContent(content []byte) {
188 | tag.lastContent = append(tag.lastContent, content...)
189 | }
190 |
191 | // AddLastContent appends content to the end of the existing content of a tag.
192 | // Deprecated.
193 | func (tag *Tag) AddLastContent(content []byte) {
194 | tag.AppendContent(content)
195 | }
196 |
197 | // CountChildren returns the number of children a tag has
198 | func (tag *Tag) CountChildren() int {
199 | child := tag.firstChild
200 | if child == nil {
201 | return 0
202 | }
203 | count := 1
204 | if child.nextSibling == nil {
205 | return count
206 | }
207 | child = child.nextSibling
208 | for child != nil {
209 | count++
210 | child = child.nextSibling
211 | }
212 | return count
213 | }
214 |
215 | // CountSiblings returns the number of siblings a tag has
216 | func (tag *Tag) CountSiblings() int {
217 | sib := tag.nextSibling
218 | if sib == nil {
219 | return 0
220 | }
221 | count := 1
222 | if sib.nextSibling == nil {
223 | return count
224 | }
225 | sib = sib.nextSibling
226 | for sib != nil {
227 | count++
228 | sib = sib.nextSibling
229 | }
230 | return count
231 | }
232 |
233 | // LastChild returns the last child of a tag
234 | func (tag *Tag) LastChild() *Tag {
235 | child := tag.firstChild
236 | for child.nextSibling != nil {
237 | child = child.nextSibling
238 | }
239 | return child
240 | }
241 |
242 | // GetTag searches all tags for the given name
243 | func (image *Document) GetTag(name []byte) (*Tag, error) {
244 | return image.root.GetTag(name)
245 | }
246 |
247 | // GetRoot returns the root tag of the image
248 | func (image *Document) GetRoot() *Tag {
249 | return image.root
250 | }
251 |
252 | // GetTag finds a tag by name and returns an error if not found.
253 | // Returns the first tag that matches.
254 | func (tag *Tag) GetTag(name []byte) (*Tag, error) {
255 | if bytes.Index(tag.name, name) == 0 {
256 | return tag, nil
257 | }
258 | couldNotFindError := fmt.Errorf("could not find tag: %s", name)
259 | if tag.CountChildren() == 0 {
260 | // No children. Not found so far
261 | return nil, couldNotFindError
262 | }
263 |
264 | child := tag.firstChild
265 | for child != nil {
266 | found, err := child.GetTag(name)
267 | if err == nil {
268 | return found, err
269 | }
270 | child = child.nextSibling
271 | }
272 |
273 | return nil, couldNotFindError
274 | }
275 |
276 | // Bytes (previously getXMLRecursively) renders XML for a tag, recursively.
277 | // The generated XML is returned as a []byte.
278 | func (tag *Tag) Bytes() []byte {
279 | var content, xmlContent []byte
280 |
281 | if tag.CountChildren() == 0 {
282 | return tag.getFlatXML()
283 | }
284 |
285 | child := tag.firstChild
286 | for child != nil {
287 | xmlContent = child.Bytes()
288 | if len(xmlContent) > 0 {
289 | content = append(content, xmlContent...)
290 | }
291 | child = child.nextSibling
292 | }
293 |
294 | tag.xmlContent = append(tag.xmlContent, tag.content...)
295 | tag.xmlContent = append(tag.xmlContent, content...)
296 | tag.xmlContent = append(tag.xmlContent, tag.lastContent...)
297 |
298 | return tag.getFlatXML()
299 | }
300 |
301 | // String returns the XML contents as a string
302 | func (tag *Tag) String() string {
303 | return string(tag.Bytes())
304 | }
305 |
306 | // AddContent adds content to the body tag.
307 | // Returns the body tag and nil if successful.
308 | // Returns and an error if no body tag is found, else nil.
309 | func (image *Document) AddContent(content []byte) (*Tag, error) {
310 | body, err := image.root.GetTag([]byte("body"))
311 | if err == nil {
312 | body.AddContent(content)
313 | }
314 | return body, err
315 | }
316 |
317 | // Bytes renders the image as an XML document
318 | func (image *Document) Bytes() []byte {
319 | return image.root.Bytes()
320 | }
321 |
322 | // String renders the image as an XML document
323 | func (image *Document) String() string {
324 | return image.root.String()
325 | }
326 |
327 | // SaveSVG will save the current image as an SVG file
328 | func (image *Document) SaveSVG(filename string) error {
329 | return ioutil.WriteFile(filename, image.Bytes(), 0644)
330 | }
331 |
--------------------------------------------------------------------------------
/network_test.go:
--------------------------------------------------------------------------------
1 | package wann
2 |
3 | import (
4 | "fmt"
5 | "math/rand"
6 | "testing"
7 | )
8 |
// commonSeed is the fixed seed the tests below pass to rand.Seed, so the randomly
// built networks, and the values asserted about them, are the same on every run.
var commonSeed = int64(1571917826405889420)
11 |
12 | func TestNewNetwork(t *testing.T) {
13 | rand.Seed(commonSeed)
14 | net := NewNetwork(&Config{
15 | inputs: 5,
16 | InitialConnectionRatio: 0.5,
17 | sharedWeight: 0.5,
18 | })
19 | fmt.Println(net)
20 | for i, n := range net.AllNodes {
21 | if NeuronIndex(i) != n.neuronIndex {
22 | t.Fail()
23 | }
24 | }
25 | }
26 |
27 | func TestGet(t *testing.T) {
28 | rand.Seed(commonSeed)
29 | net := NewNetwork()
30 | fmt.Println(net)
31 | fmt.Println(net.Get(0))
32 | if net.OutputNode != 0 {
33 | t.Fail()
34 | }
35 | }
36 |
37 | func TestIsInput(t *testing.T) {
38 | rand.Seed(commonSeed)
39 | net := NewNetwork(&Config{
40 | inputs: 5,
41 | InitialConnectionRatio: 0.5,
42 | sharedWeight: 0.5,
43 | })
44 | if !net.IsInput(1) {
45 | t.Fail()
46 | }
47 | if net.IsInput(0) {
48 | t.Fail()
49 | }
50 | }
51 |
52 | func TestForEachConnected(t *testing.T) {
53 | rand.Seed(commonSeed)
54 | net := NewNetwork(&Config{
55 | inputs: 5,
56 | InitialConnectionRatio: 0.5,
57 | sharedWeight: 0.5,
58 | })
59 | net.ForEachConnected(func(n *Neuron) {
60 | fmt.Printf("%d: %s, distance from output node: %d\n", n.neuronIndex, n, n.distanceFromOutputNode)
61 | })
62 | }
63 |
64 | func TestAll(t *testing.T) {
65 | rand.Seed(commonSeed)
66 | net := NewNetwork(&Config{
67 | inputs: 5,
68 | InitialConnectionRatio: 0.7,
69 | sharedWeight: 0.5,
70 | })
71 | for _, node := range net.All() {
72 | fmt.Println(node)
73 | }
74 | }
75 |
76 | func TestEvaluate2(t *testing.T) {
77 | rand.Seed(commonSeed)
78 | net := NewNetwork(&Config{
79 | inputs: 5,
80 | InitialConnectionRatio: 0.7,
81 | sharedWeight: 0.5,
82 | })
83 | _ = net.Evaluate([]float64{0.1, 0.2, 0.3, 0.4, 0.5})
84 | }
85 |
86 | func TestInsertNode(t *testing.T) {
87 | rand.Seed(commonSeed)
88 | net := NewNetwork(&Config{
89 | inputs: 5,
90 | InitialConnectionRatio: 0.5,
91 | sharedWeight: 0.5,
92 | })
93 | _, newNeuronIndex := net.NewNeuron()
94 | if err := net.InsertNode(0, 2, newNeuronIndex); err != nil {
95 | t.Error(err)
96 | }
97 | _ = net.Evaluate([]float64{0.1, 0.2, 0.3, 0.4, 0.5})
98 | }
99 |
100 | func TestAddConnection(t *testing.T) {
101 | rand.Seed(commonSeed)
102 | net := NewNetwork(&Config{
103 | inputs: 5,
104 | InitialConnectionRatio: 0.5,
105 | })
106 | _, newNeuronIndex := net.NewNeuron()
107 | if err := net.InsertNode(net.OutputNode, 2, newNeuronIndex); err != nil {
108 | t.Error(err)
109 | }
110 | // Add a connection from 1 to the new neuron.
111 | // This is the same as making the new neuron have an additional input neuron: index 1
112 | if err := net.AddConnection(1, newNeuronIndex); err != nil {
113 | t.Error(err)
114 | }
115 | // Add a connection from the output node to the output node. Should fail.
116 | if err := net.AddConnection(net.OutputNode, net.OutputNode); err == nil {
117 | t.Fail()
118 | }
119 | // Adding a made-up index should fail as well
120 | if err := net.AddConnection(net.OutputNode, 999); err == nil {
121 | t.Fail()
122 | }
123 | }
124 |
125 | func TestRandomizeActivationFunctionForRandomNeuron(t *testing.T) {
126 | rand.Seed(commonSeed)
127 | net := NewNetwork(&Config{
128 | inputs: 5,
129 | InitialConnectionRatio: 0.5,
130 | })
131 | net.RandomizeActivationFunctionForRandomNeuron()
132 | }
133 |
134 | func TestNetworkString(t *testing.T) {
135 | rand.Seed(commonSeed)
136 | net := NewNetwork(&Config{
137 | inputs: 5,
138 | InitialConnectionRatio: 0.5,
139 | })
140 | _ = net.String()
141 | }
142 |
143 | func TestSetWeight(t *testing.T) {
144 | net := NewNetwork()
145 | net.SetWeight(0.1234)
146 | if net.Weight != 0.1234 {
147 | t.Fail()
148 | }
149 | }
150 |
151 | func TestComplexity(t *testing.T) {
152 | rand.Seed(commonSeed)
153 | net := NewNetwork(&Config{
154 | inputs: 5,
155 | InitialConnectionRatio: 0.0,
156 | })
157 | // The complexity will vary, because the performance varies when
158 | // estimating the complexity of each function.
159 | // But the complexity compared between networks should still hold.
160 | firstComplexity := net.Complexity()
161 | // Adding a connection increases the complexity
162 | net.AddConnection(0, 1)
163 | // The complexity for the network, after a connection has been added
164 | secondComplexity := net.Complexity()
165 | if firstComplexity >= secondComplexity {
166 | t.Fail()
167 | }
168 | }
169 |
// ExampleNetwork_InsertNode prints a fully connected network with three input
// nodes, inserts a new node between the output node (0) and input node 1, and
// prints the network again.
func ExampleNetwork_InsertNode() {
	rand.Seed(commonSeed)
	net := NewNetwork(&Config{
		inputs:                 3,
		InitialConnectionRatio: 1.0,
	})
	fmt.Println("Before insertion:")
	fmt.Println(net)
	_, nodeIndex := net.NewNeuron()
	err := net.InsertNode(0, 1, nodeIndex)
	if err != nil {
		fmt.Println("error: " + err.Error())
	}
	fmt.Println("After insertion:")
	fmt.Println(net)
	// Output:
	// Before insertion:
	// Network (4 nodes, 3 input nodes, 1 output node)
	// Connected inputs to output node: 3
	// Output node ID 0 has these input connections: [1 2 3]
	// Input node ID 1 has these input connections: []
	// Input node ID 2 has these input connections: []
	// Input node ID 3 has these input connections: []
	//
	// After insertion:
	// Network (5 nodes, 3 input nodes, 1 output node)
	// Connected inputs to output node: 3
	// Output node ID 0 has these input connections: [2 3 4]
	// Input node ID 1 has these input connections: []
	// Input node ID 2 has these input connections: []
	// Input node ID 3 has these input connections: []
	// Node ID 4 has these input connections: [1]
}
203 |
204 | func TestLeftRight(t *testing.T) {
205 | rand.Seed(commonSeed)
206 | net := NewNetwork(&Config{
207 | inputs: 3,
208 | InitialConnectionRatio: 1.0,
209 | })
210 | net.AllNodes[1].ActivationFunction = Swish
211 | a, b, _ := net.LeftRight(0, 1)
212 | // output node to the right
213 | if a != 1 || b != 0 {
214 | t.Fail()
215 | }
216 | // output node to the right
217 | a, b, _ = net.LeftRight(1, 0)
218 | if a != 1 || b != 0 {
219 | t.Fail()
220 | }
221 | _, nodeIndex := net.NewNeuron()
222 | err := net.InsertNode(0, 1, nodeIndex)
223 | if err != nil {
224 | t.Error(err)
225 | }
226 | a, b, _ = net.LeftRight(0, nodeIndex)
227 | // output node to the right
228 | if a != nodeIndex || b != 0 {
229 | t.Fail()
230 | }
231 | a, b, _ = net.LeftRight(nodeIndex, 0)
232 | // output node to the right
233 | if a != nodeIndex || b != 0 {
234 | t.Fail()
235 | }
236 | a, b, _ = net.LeftRight(1, nodeIndex)
237 | // Here, the new node should be to the right, since it's between node 1 and the output node
238 | if a != 1 || b != nodeIndex {
239 | t.Fail()
240 | }
241 | //net.WriteSVG("c.svg")
242 | fmt.Println(net)
243 | a, b, _ = net.LeftRight(nodeIndex, 1)
244 | if a != 1 || b != nodeIndex {
245 | t.Fail()
246 | }
247 | }
248 |
249 | func TestDepth(t *testing.T) {
250 | rand.Seed(commonSeed)
251 | net := NewNetwork(&Config{
252 | inputs: 3,
253 | InitialConnectionRatio: 1.0,
254 | })
255 | fmt.Println(net.Depth())
256 | _, nodeIndex := net.NewBlankNeuron()
257 | _ = net.InsertNode(0, 1, nodeIndex)
258 | fmt.Println(net.Depth())
259 | }
260 |
261 | func ExampleCombine() {
262 | ac := []NeuronIndex{0, 1, 2, 3, 4}
263 | bc := []NeuronIndex{5, 6, 7, 8, 9}
264 | fmt.Println(Combine(ac, bc))
265 | // Output:
266 | // [0 1 2 3 4 5 6 7 8 9]
267 | }
268 |
269 | func TestGetRandomNeuron(t *testing.T) {
270 | rand.Seed(commonSeed)
271 | net := NewNetwork(&Config{
272 | inputs: 5,
273 | InitialConnectionRatio: 1.0,
274 | })
275 | stats := make(map[NeuronIndex]uint)
276 | for i := 0; i < 1000; i++ {
277 | ni := net.GetRandomNode()
278 | if _, ok := stats[ni]; !ok {
279 | stats[ni] = 0
280 | } else {
281 | stats[ni]++
282 | }
283 | }
284 | fmt.Println(stats)
285 | // Check that the output node exists in the stats
286 | if _, ok := stats[0]; !ok {
287 | t.Fail()
288 | }
289 | }
290 |
291 | func TestGetRandomInputNode(t *testing.T) {
292 | rand.Seed(commonSeed)
293 | net := NewNetwork(&Config{
294 | inputs: 5,
295 | InitialConnectionRatio: 1.0,
296 | })
297 | stats := make(map[NeuronIndex]uint)
298 | for i := 0; i < 1000; i++ {
299 | ni := net.GetRandomInputNode()
300 | if _, ok := stats[ni]; !ok {
301 | stats[ni] = 0
302 | } else {
303 | stats[ni]++
304 | }
305 | }
306 | fmt.Println(stats)
307 | // Check that the output node does not exist in the stats
308 | if _, ok := stats[0]; ok {
309 | t.Fail()
310 | }
311 | }
312 |
313 | func TestConnected(t *testing.T) {
314 | rand.Seed(commonSeed)
315 | net := NewNetwork(&Config{
316 | inputs: 5,
317 | InitialConnectionRatio: 0.1,
318 | })
319 | connected := net.Connected()
320 | if connected[0] != 0 || connected[1] != 2 {
321 | t.Fail()
322 | }
323 | }
324 |
325 | func TestUnconnected(t *testing.T) {
326 | rand.Seed(commonSeed)
327 | net := NewNetwork(&Config{
328 | inputs: 5,
329 | InitialConnectionRatio: 0.5,
330 | })
331 | unconnected := net.Unconnected()
332 | correct := []NeuronIndex{1, 3, 4}
333 | for i := 0; i < len(unconnected); i++ {
334 | if unconnected[i] != correct[i] {
335 | t.Fail()
336 | }
337 | }
338 | }
339 |
340 | func TestCopy(t *testing.T) {
341 | rand.Seed(commonSeed)
342 | net := NewNetwork(&Config{
343 | inputs: 5,
344 | InitialConnectionRatio: 0.5,
345 | })
346 |
347 | // Take a deep copy with the Copy() function
348 | net2 := net.Copy()
349 | // Modify net2 by inserting an unconnected neuron
350 | n := NewUnconnectedNeuron()
351 | net2.AllNodes[1] = *n
352 | // net and net2 should now be different, since net2 is a proper copy
353 | if net.String() == net2.String() {
354 | t.Fail()
355 | }
356 |
357 | // Take a shallow copy
358 | net3 := net
359 | // Modify net3 by inserting an unconnected neuron
360 | net3.AllNodes[1] = *n
361 | // net and net3 should still be the same, since net3 is just a shallow copy
362 | if net.String() != net3.String() {
363 | t.Fail()
364 | }
365 | }
366 |
367 | func TestForEachConnectedNodeIndex(t *testing.T) {
368 | rand.Seed(commonSeed)
369 | net := NewNetwork(&Config{
370 | inputs: 5,
371 | InitialConnectionRatio: 0.5,
372 | })
373 | lastNi := NeuronIndex(-1)
374 | net.ForEachConnectedNodeIndex(func(ni NeuronIndex) {
375 | fmt.Println(ni)
376 | lastNi = ni
377 | })
378 | if lastNi != 5 {
379 | t.Fail()
380 | }
381 | }
382 |
--------------------------------------------------------------------------------
/vendor/github.com/dave/jennifer/jen/hints.go:
--------------------------------------------------------------------------------
1 | // This file is generated - do not edit.
2 |
3 | package jen
4 |
// standardLibraryHints contains package name hints: it maps standard library
// import paths to the name declared by that package, which is not always the
// last path element (e.g. "crypto/x509/internal/macos" declares "macOS").
// NOTE(review): this file is generated; edits here are lost on regeneration.
var standardLibraryHints = map[string]string{
	"archive/tar": "tar",
	"archive/zip": "zip",
	"bufio": "bufio",
	"bytes": "bytes",
	"compress/bzip2": "bzip2",
	"compress/flate": "flate",
	"compress/gzip": "gzip",
	"compress/lzw": "lzw",
	"compress/zlib": "zlib",
	"constraints": "constraints",
	"container/heap": "heap",
	"container/list": "list",
	"container/ring": "ring",
	"context": "context",
	"crypto": "crypto",
	"crypto/aes": "aes",
	"crypto/cipher": "cipher",
	"crypto/des": "des",
	"crypto/dsa": "dsa",
	"crypto/ecdsa": "ecdsa",
	"crypto/ed25519": "ed25519",
	"crypto/ed25519/internal/edwards25519": "edwards25519",
	"crypto/ed25519/internal/edwards25519/field": "field",
	"crypto/elliptic": "elliptic",
	"crypto/elliptic/internal/fiat": "fiat",
	"crypto/elliptic/internal/nistec": "nistec",
	"crypto/hmac": "hmac",
	"crypto/internal/randutil": "randutil",
	"crypto/internal/subtle": "subtle",
	"crypto/md5": "md5",
	"crypto/rand": "rand",
	"crypto/rc4": "rc4",
	"crypto/rsa": "rsa",
	"crypto/sha1": "sha1",
	"crypto/sha256": "sha256",
	"crypto/sha512": "sha512",
	"crypto/subtle": "subtle",
	"crypto/tls": "tls",
	"crypto/x509": "x509",
	"crypto/x509/internal/macos": "macOS",
	"crypto/x509/pkix": "pkix",
	"database/sql": "sql",
	"database/sql/driver": "driver",
	"debug/buildinfo": "buildinfo",
	"debug/dwarf": "dwarf",
	"debug/elf": "elf",
	"debug/gosym": "gosym",
	"debug/macho": "macho",
	"debug/pe": "pe",
	"debug/plan9obj": "plan9obj",
	"embed": "embed",
	"embed/internal/embedtest": "embedtest",
	"encoding": "encoding",
	"encoding/ascii85": "ascii85",
	"encoding/asn1": "asn1",
	"encoding/base32": "base32",
	"encoding/base64": "base64",
	"encoding/binary": "binary",
	"encoding/csv": "csv",
	"encoding/gob": "gob",
	"encoding/hex": "hex",
	"encoding/json": "json",
	"encoding/pem": "pem",
	"encoding/xml": "xml",
	"errors": "errors",
	"expvar": "expvar",
	"flag": "flag",
	"fmt": "fmt",
	"go/ast": "ast",
	"go/build": "build",
	"go/build/constraint": "constraint",
	"go/constant": "constant",
	"go/doc": "doc",
	"go/format": "format",
	"go/importer": "importer",
	"go/internal/gccgoimporter": "gccgoimporter",
	"go/internal/gcimporter": "gcimporter",
	"go/internal/srcimporter": "srcimporter",
	"go/internal/typeparams": "typeparams",
	"go/parser": "parser",
	"go/printer": "printer",
	"go/scanner": "scanner",
	"go/token": "token",
	"go/types": "types",
	"hash": "hash",
	"hash/adler32": "adler32",
	"hash/crc32": "crc32",
	"hash/crc64": "crc64",
	"hash/fnv": "fnv",
	"hash/maphash": "maphash",
	"html": "html",
	"html/template": "template",
	"image": "image",
	"image/color": "color",
	"image/color/palette": "palette",
	"image/draw": "draw",
	"image/gif": "gif",
	"image/internal/imageutil": "imageutil",
	"image/jpeg": "jpeg",
	"image/png": "png",
	"index/suffixarray": "suffixarray",
	"internal/abi": "abi",
	"internal/buildcfg": "buildcfg",
	"internal/bytealg": "bytealg",
	"internal/cfg": "cfg",
	"internal/cpu": "cpu",
	"internal/execabs": "execabs",
	"internal/fmtsort": "fmtsort",
	"internal/fuzz": "fuzz",
	"internal/goarch": "goarch",
	"internal/godebug": "godebug",
	"internal/goexperiment": "goexperiment",
	"internal/goos": "goos",
	"internal/goroot": "goroot",
	"internal/goversion": "goversion",
	"internal/intern": "intern",
	"internal/itoa": "itoa",
	"internal/lazyregexp": "lazyregexp",
	"internal/lazytemplate": "lazytemplate",
	"internal/nettrace": "nettrace",
	"internal/obscuretestdata": "obscuretestdata",
	"internal/oserror": "oserror",
	"internal/poll": "poll",
	"internal/profile": "profile",
	"internal/race": "race",
	"internal/reflectlite": "reflectlite",
	"internal/singleflight": "singleflight",
	"internal/syscall/execenv": "execenv",
	"internal/syscall/unix": "unix",
	"internal/sysinfo": "sysinfo",
	"internal/testenv": "testenv",
	"internal/testlog": "testlog",
	"internal/trace": "trace",
	"internal/unsafeheader": "unsafeheader",
	"internal/xcoff": "xcoff",
	"io": "io",
	"io/fs": "fs",
	"io/ioutil": "ioutil",
	"log": "log",
	"log/syslog": "syslog",
	"math": "math",
	"math/big": "big",
	"math/bits": "bits",
	"math/cmplx": "cmplx",
	"math/rand": "rand",
	"mime": "mime",
	"mime/multipart": "multipart",
	"mime/quotedprintable": "quotedprintable",
	"net": "net",
	"net/http": "http",
	"net/http/cgi": "cgi",
	"net/http/cookiejar": "cookiejar",
	"net/http/fcgi": "fcgi",
	"net/http/httptest": "httptest",
	"net/http/httptrace": "httptrace",
	"net/http/httputil": "httputil",
	"net/http/internal": "internal",
	"net/http/internal/ascii": "ascii",
	"net/http/internal/testcert": "testcert",
	"net/http/pprof": "pprof",
	"net/internal/socktest": "socktest",
	"net/mail": "mail",
	"net/netip": "netip",
	"net/rpc": "rpc",
	"net/rpc/jsonrpc": "jsonrpc",
	"net/smtp": "smtp",
	"net/textproto": "textproto",
	"net/url": "url",
	"os": "os",
	"os/exec": "exec",
	"os/exec/internal/fdtest": "fdtest",
	"os/signal": "signal",
	"os/signal/internal/pty": "pty",
	"os/user": "user",
	"path": "path",
	"path/filepath": "filepath",
	"plugin": "plugin",
	"reflect": "reflect",
	"reflect/internal/example1": "example1",
	"reflect/internal/example2": "example2",
	"regexp": "regexp",
	"regexp/syntax": "syntax",
	"runtime": "runtime",
	"runtime/cgo": "cgo",
	"runtime/debug": "debug",
	"runtime/internal/atomic": "atomic",
	"runtime/internal/math": "math",
	"runtime/internal/sys": "sys",
	"runtime/metrics": "metrics",
	"runtime/pprof": "pprof",
	"runtime/race": "race",
	"runtime/trace": "trace",
	"sort": "sort",
	"strconv": "strconv",
	"strings": "strings",
	"sync": "sync",
	"sync/atomic": "atomic",
	"syscall": "syscall",
	"testing": "testing",
	"testing/fstest": "fstest",
	"testing/internal/testdeps": "testdeps",
	"testing/iotest": "iotest",
	"testing/quick": "quick",
	"text/scanner": "scanner",
	"text/tabwriter": "tabwriter",
	"text/template": "template",
	"text/template/parse": "parse",
	"time": "time",
	"time/tzdata": "tzdata",
	"unicode": "unicode",
	"unicode/utf16": "utf16",
	"unicode/utf8": "utf8",
	"unsafe": "unsafe",
}
221 |
--------------------------------------------------------------------------------
/neuron.go:
--------------------------------------------------------------------------------
1 | package wann
2 |
3 | import (
4 | "errors"
5 | "fmt"
6 | "math/rand"
7 |
8 | "github.com/dave/jennifer/jen"
9 | )
10 |
// Neuron is a node in a Network: a list of input nodes, referenced by their
// index in the network, and an activation function.
type Neuron struct {
	Net                    *Network                // network this neuron belongs to; nil for an unconnected neuron
	InputNodes             []NeuronIndex           // indices of the input neurons in Net.AllNodes (not pointers)
	ActivationFunction     ActivationFunctionIndex // index into ActivationFunctions
	Value                  *float64                // optional fixed value, used by evaluate when there are no input nodes
	distanceFromOutputNode int                     // Used when traversing nodes and drawing diagrams
	neuronIndex            NeuronIndex             // this neuron's own index in Net.AllNodes, or -1 when unconnected
}
20 |
21 | // NewBlankNeuron creates a new Neuron, with the Step activation function as the default
22 | func (net *Network) NewBlankNeuron() (*Neuron, NeuronIndex) {
23 | // Pre-allocate room for 16 connections and use Linear as the default activation function
24 | neuron := Neuron{Net: net, InputNodes: make([]NeuronIndex, 0, 16), ActivationFunction: Swish}
25 | neuron.neuronIndex = NeuronIndex(len(net.AllNodes))
26 | net.AllNodes = append(net.AllNodes, neuron)
27 | return &neuron, neuron.neuronIndex
28 | }
29 |
30 | // NewNeuron creates a new *Neuron, with a randomly chosen activation function
31 | func (net *Network) NewNeuron() (*Neuron, NeuronIndex) {
32 | chosenActivationFunctionIndex := ActivationFunctionIndex(rand.Intn(len(ActivationFunctions)))
33 | inputNodes := make([]NeuronIndex, 0, 16)
34 | neuron := Neuron{
35 | Net: net,
36 | InputNodes: inputNodes,
37 | ActivationFunction: chosenActivationFunctionIndex,
38 | }
39 | // The length of net.AllNodes is what will be the last index
40 | neuronIndex := NeuronIndex(len(net.AllNodes))
41 | // Assign the neuron index in the net to the neuron
42 | neuron.neuronIndex = neuronIndex
43 | // Add this neuron to the net
44 | net.AllNodes = append(net.AllNodes, neuron)
45 | return &neuron, neuronIndex
46 | }
47 |
48 | // NewUnconnectedNeuron returns a new unconnected neuron with neuronIndex -1 and net pointer set to nil
49 | func NewUnconnectedNeuron() *Neuron {
50 | // Pre-allocate room for 16 connections and use Linear as the default activation function
51 | neuron := Neuron{Net: nil, InputNodes: make([]NeuronIndex, 0, 16), ActivationFunction: Linear}
52 | neuron.neuronIndex = -1
53 | return &neuron
54 | }
55 |
56 | // Connect this neuron to a network, overwriting any existing connections.
57 | // This will also clear any input nodes to this neuron, since the net is different.
58 | // TODO: Find the input nodes from the neuron.Net, save those and re-assign if there are matches?
59 | func (neuron *Neuron) Connect(net *Network) {
60 | neuron.InputNodes = []NeuronIndex{}
61 | neuron.Net = net
62 | for ni := range net.AllNodes {
63 | // Check if this network already has a pointer to this neuron
64 | if &net.AllNodes[ni] == neuron {
65 | // Yes, assign the index
66 | neuron.neuronIndex = NeuronIndex(ni)
67 | // All good, bail
68 | return
69 | }
70 | }
71 | // The neuron was not found in the network
72 | // Find what will be the last index in net.AllNodes
73 | neuronIndex := len(net.AllNodes)
74 | // Add this neuron to the network
75 | net.AllNodes = append(net.AllNodes, *neuron)
76 | // Assign the index
77 | net.AllNodes[neuronIndex].neuronIndex = NeuronIndex(neuronIndex)
78 | }
79 |
80 | // RandomizeActivationFunction will choose a random activation function for this neuron
81 | func (neuron *Neuron) RandomizeActivationFunction() {
82 | chosenActivationFunctionIndex := ActivationFunctionIndex(rand.Intn(len(ActivationFunctions)))
83 | neuron.ActivationFunction = chosenActivationFunctionIndex
84 | }
85 |
86 | // SetValue can be used for setting a value for this neuron instead of using input neutrons.
87 | // This changes how the Evaluation function behaves.
88 | func (neuron *Neuron) SetValue(x float64) {
89 | neuron.Value = &x
90 | }
91 |
92 | // HasInput checks if the given neuron is an input neuron to this one
93 | func (neuron *Neuron) HasInput(e NeuronIndex) bool {
94 | for _, ni := range neuron.InputNodes {
95 | if ni == e {
96 | return true
97 | }
98 | }
99 | return false
100 | }
101 |
102 | // FindInput checks if the given neuron is an input neuron to this one,
103 | // and also returns the index to InputNeurons, if found.
104 | func (neuron *Neuron) FindInput(e NeuronIndex) (int, bool) {
105 | for i, n := range neuron.InputNodes {
106 | if n == e {
107 | return i, true
108 | }
109 | }
110 | return -1, false
111 | }
112 |
113 | // Is check if the given NeuronIndex points to this neuron
114 | func (neuron *Neuron) Is(e NeuronIndex) bool {
115 | return neuron.neuronIndex == e
116 | }
117 |
118 | // AddInput will add an input neuron
119 | func (neuron *Neuron) AddInput(ni NeuronIndex) error {
120 | if neuron.Is(ni) {
121 | return errors.New("adding a neuron as input to itself")
122 | }
123 | if neuron.HasInput(ni) {
124 | return errors.New("neuron already exists")
125 | }
126 | neuron.InputNodes = append(neuron.InputNodes, ni)
127 |
128 | return nil
129 | }
130 |
131 | // AddInputNeuron both adds a neuron to this network (if needed) and also
132 | // adds its neuron index to the neuron.InputNeurons
133 | func (neuron *Neuron) AddInputNeuron(n *Neuron) error {
134 | // If n.neuronIndex is known to this network, just add the NeuronIndex to neuron.InputNeurons
135 | if neuron.Net.Exists(n.neuronIndex) {
136 | return neuron.AddInput(n.neuronIndex)
137 | }
138 | // If not, add this neuron to the network first
139 | node := *n
140 | node.neuronIndex = NeuronIndex(len(neuron.Net.AllNodes))
141 | neuron.Net.AllNodes = append(neuron.Net.AllNodes, node)
142 | return neuron.AddInput(n.neuronIndex)
143 | }
144 |
145 | // RemoveInput will remove an input neuron
146 | func (neuron *Neuron) RemoveInput(e NeuronIndex) error {
147 | if i, found := neuron.FindInput(e); found {
148 | // Found it, remove the neuron at index i
149 | neuron.InputNodes = append(neuron.InputNodes[:i], neuron.InputNodes[i+1:]...)
150 | return nil
151 | }
152 | return errors.New("neuron does not exist")
153 | }
154 |
155 | // Exists checks if the given NeuronIndex exists in this Network
156 | func (net *Network) Exists(ni NeuronIndex) bool {
157 | for i := range net.AllNodes {
158 | neuronIndex := NeuronIndex(i)
159 | if neuronIndex == ni {
160 | return true
161 | }
162 | }
163 | return false
164 | }
165 |
166 | // InputNeuronsAreGood checks if all input neurons of this neuron exists in neuron.Net
167 | func (neuron *Neuron) InputNeuronsAreGood() bool {
168 | for _, inputNeuronIndex := range neuron.InputNodes {
169 | if !neuron.Net.Exists(inputNeuronIndex) {
170 | return false
171 | }
172 | }
173 | return true
174 | }
175 |
// evaluate recursively computes this neuron's output.
// Each input node is evaluated in turn (with the same weight); every result is
// multiplied by weight and added to a running sum, and that sum (not averaged)
// is passed through this neuron's activation function.
// If no input node was evaluated, the .Value field is returned when it is set
// and this is not the output node; otherwise the result is 0.
// maxEvaluationLoops is a budget shared by the whole recursion and decremented
// once per input node visited. The returned bool is true only when the budget
// was already used up on entry, telling the caller to stop.
func (neuron *Neuron) evaluate(weight float64, maxEvaluationLoops *int) (float64, bool) {
	if *maxEvaluationLoops <= 0 {
		return 0.0, true
	}
	// Assume this is the Output neuron, recursively evaluating the result
	// For each input neuron, evaluate them
	summed := 0.0
	counter := 0

	for _, inputNeuronIndex := range neuron.InputNodes {
		// Let each input neuron do its own evaluation, using the given weight.
		// The budget is spent even if the index below turns out to be invalid.
		(*maxEvaluationLoops)--
		// TODO: Figure out exactly why this one kicks in (and if it matters)
		// It only seems to kick in during "go test" and not in evolve/main.go
		// Input indices beyond net.AllNodes are skipped instead of panicking.
		if int(inputNeuronIndex) >= len(neuron.Net.AllNodes) {
			continue
			//panic("TOO HIGH INPUT NEURON INDEX")
		}
		result, stopNow := neuron.Net.AllNodes[inputNeuronIndex].evaluate(weight, maxEvaluationLoops)
		summed += result * weight
		counter++
		// Stop early when the input ran out of budget or the budget went
		// negative; the partial sum collected so far is still used below.
		if stopNow || (*maxEvaluationLoops < 0) {
			break
		}
	}
	// No input neurons. Use the .Value field if it's not nil and this is not the output node
	if counter == 0 && neuron.Value != nil && !neuron.IsOutput() {
		return *(neuron.Value), false
	}
	// No input was evaluated and there is no usable .Value: return 0
	if counter == 0 {
		// Reached e.g. for a node with no (valid) inputs and no .Value set,
		// or for an output node without inputs
		return 0.0, false
	}
	// This should run, also when this neuron is the output neuron
	f := neuron.GetActivationFunction()
	//fmt.Println(neuron.ActivationFunction.Name() + " = " + neuron.ActivationFunction.String())
	// Run the input through the activation function
	// TODO: Does "retval := f(summed)"" perform better?, or the one that averages the sum first?
	//retval := f(summed / float64(counter))
	retval := f(summed)
	return retval, false
}
222 |
223 | // GetActivationFunction returns the activation function for this neuron
224 | func (neuron *Neuron) GetActivationFunction() func(float64) float64 {
225 | return ActivationFunctions[neuron.ActivationFunction]
226 | }
227 |
228 | // In checks if this neuron is in the given collection
229 | func (neuron *Neuron) In(collection []NeuronIndex) bool {
230 | for _, existingNodeIndex := range collection {
231 | if neuron.Is(existingNodeIndex) {
232 | return true
233 | }
234 | }
235 | return false
236 | }
237 |
238 | // IsInput returns true if this is an input node or not
239 | // Returns false if nil
240 | func (neuron *Neuron) IsInput() bool {
241 | if neuron.Net == nil {
242 | return false
243 | }
244 | return neuron.Net.IsInput(neuron.neuronIndex)
245 | }
246 |
247 | // IsOutput returns true if this is an output node or not
248 | // Returns false if nil
249 | func (neuron *Neuron) IsOutput() bool {
250 | if neuron.Net == nil {
251 | return false
252 | }
253 | return neuron.Net.OutputNode == neuron.neuronIndex
254 | }
255 |
256 | // Copy a Neuron to a new Neuron, and assign the pointer to the given network to .Net
257 | func (neuron Neuron) Copy(net *Network) Neuron {
258 | var newNeuron Neuron
259 | newNeuron.Net = net
260 | newNeuron.InputNodes = neuron.InputNodes
261 | newNeuron.ActivationFunction = neuron.ActivationFunction
262 | newNeuron.Value = neuron.Value
263 | newNeuron.distanceFromOutputNode = neuron.distanceFromOutputNode
264 | newNeuron.neuronIndex = neuron.neuronIndex
265 | return newNeuron
266 | }
267 |
268 | // String will return a string containing both the pointer address and the number of input neurons
269 | func (neuron *Neuron) String() string {
270 | nodeType := " Node"
271 | if neuron.IsInput() {
272 | nodeType = " Input node"
273 | } else if neuron.IsOutput() {
274 | nodeType = "Output node"
275 | }
276 | return fmt.Sprintf("%s ID %d has these input connections: %v", nodeType, neuron.neuronIndex, neuron.InputNodes)
277 | }
278 |
279 | // InputStatement returns a statement like "inputData[0]", if this node is a network input node
280 | func (neuron *Neuron) InputStatement() (*jen.Statement, error) {
281 | // If this node is a network input node, return a statement representing this input,
282 | // like "inputData[0]"
283 | if !neuron.IsInput() {
284 | return jen.Empty(), errors.New(" not an input node")
285 | }
286 | for i, ni := range neuron.Net.InputNodes {
287 | if ni == neuron.neuronIndex {
288 | // This index in the neuron.NetInputNodes is i
289 | return jen.Id("inputData").Index(jen.Lit(i)), nil
290 | }
291 | }
292 | // Not found!
293 | return jen.Empty(), errors.New("not an input node for the associated network")
294 | }
295 |
--------------------------------------------------------------------------------
/evolve.go:
--------------------------------------------------------------------------------
1 | package wann
2 |
3 | import (
4 | "errors"
5 | "fmt"
6 | "math/rand"
7 | "strconv"
8 | )
9 |
10 | // ScorePopulation evaluates a population, given a slice of input numbers.
11 | // It returns a map with scores, together with the sum of scores.
12 | func ScorePopulation(population []*Network, weight float64, inputData [][]float64, incorrectOutputMultipliers []float64) (map[int]float64, float64) {
13 |
14 | scoreMap := make(map[int]float64)
15 | scoreSum := 0.0
16 |
17 | for i := 0; i < len(population); i++ {
18 | net := population[i]
19 |
20 | if len(net.AllNodes[net.OutputNode].InputNodes) == 0 {
21 | // The output node has no input nodes, not great
22 | scoreMap[i] = 0.0
23 | continue
24 | }
25 |
26 | net.SetWeight(weight)
27 |
28 | // Evaluate all the input data examples for this network
29 | result := 0.0
30 | for i := 0; i < len(inputData); i++ {
31 | result += net.Evaluate(inputData[i]) * incorrectOutputMultipliers[i]
32 | }
33 |
34 | // The score is how well the network is doing, divided by the network complexity rating
35 | score := result / net.Complexity()
36 |
37 | scoreSum += score
38 | scoreMap[i] = score
39 | }
40 | return scoreMap, scoreSum
41 | }
42 |
43 | // Modify the network using one of the three methods outlined in the paper:
44 | // * Insert node
45 | // * Add connection
46 | // * Change activation function
47 | func (net *Network) Modify(maxIterations int) {
48 |
49 | // Use method 0, 1 or 2
50 | method := rand.Intn(3) // up to and not including 3
51 |
52 | // Perform a modfification, using one of the three methods outlined in the paper
53 | switch method {
54 | case 0:
55 | // Insert a node, replacing a randomly chosen existing connection
56 | counter := 0
57 | for !net.InsertRandomNode() {
58 | counter++
59 | if maxIterations > 0 && counter > maxIterations {
60 | break
61 | }
62 | }
63 | case 1:
64 | nodeA, nodeB := net.GetRandomNode(), net.GetRandomNode()
65 | // Continue finding random neurons until they work out or until maxIterations is reached
66 | // Create a new connection
67 | counter := 0
68 | for net.AddConnection(nodeA, nodeB) != nil {
69 | nodeA, nodeB = net.GetRandomNode(), net.GetRandomNode()
70 | counter++
71 | if maxIterations > 0 && counter > maxIterations {
72 | // Could not add a connection. The possibilities for connections might be saturated.
73 | return
74 | }
75 | }
76 | case 2:
77 | // Change the activation function to a randomly selected one
78 | net.RandomizeActivationFunctionForRandomNeuron()
79 | default:
80 | panic("implementation error: invalid method number: " + strconv.Itoa(method))
81 | }
82 | }
83 |
84 | // Complexity measures the network complexity
85 | // Will return 1.0 at a minimum
86 | func (net *Network) Complexity() float64 {
87 |
88 | // TODO: These two constants really affect the results. Place them in the Config struct instead.
89 |
90 | // How much should the function complexity matter in relation to the number of connected nodes?
91 | const functionComplexityMultiplier = 1.0
92 |
93 | // Weight the number of connected nodes
94 | const numberOfNodesMultiplier = 2.0
95 |
96 | // Number of input nodes connected to the output node multiplier
97 | const outputNodeInputNodesMultiplier = 3.0
98 |
99 | activationFunctionComplexity := 0.0
100 | // Sum the complexity of all activation functions.
101 | // This penalizes both slow activation functions and
102 | // unconnected nodes.
103 | for _, n := range net.AllNodes {
104 | if n.Value == nil {
105 | activationFunctionComplexity += ComplexityEstimate[n.ActivationFunction]
106 | }
107 | }
108 | activationFunctionComplexity *= functionComplexityMultiplier
109 | // The number of connected nodes should also carry some weight
110 | connectedNodes := float64(len(net.Connected())) * numberOfNodesMultiplier
111 | // The number of input nodes to the output node
112 | outputNodeComplexity := float64(len(net.AllNodes[net.OutputNode].InputNodes)) * outputNodeInputNodesMultiplier
113 | // This must always be larger than 0, to avoid divide by zero later
114 | return connectedNodes + activationFunctionComplexity + outputNodeComplexity + 1.0
115 | }
116 |
117 | // Evolve evolves a neural network, given a slice of training data and a slice of correct output values.
118 | // Will overwrite config.Inputs.
119 | func (config *Config) Evolve(inputData [][]float64, incorrectOutputMultipliers []float64) (*Network, error) {
120 |
121 | // TODO: If the config.initialConnectionRatio field is too low (0.0, for instance), then this function will fail.
122 | // Return with an error if none of the networks in a population has any connections left, then get rid of the "no improvement counter".
123 |
124 | // Initialize, if needed
125 | if !config.initialized {
126 | config.Init()
127 | }
128 |
129 | inputLength := len(inputData)
130 | if inputLength == 0 {
131 | return nil, errors.New("no input data")
132 | }
133 |
134 | const maxModificationInterationsWhenMutating = 10
135 |
136 | // incorrectOutputMultipliers := make([]float64, len(correctOutputMultipliers))
137 | // for i := range correctOutputMultipliers {
138 | // // Convert from having 0..1 for meaning from incorrect to correct, to -1..1 to mean the same
139 | // incorrectOutputMultipliers[i] = correctOutputMultipliers[i]*2.0 - 1.0
140 | // // Convert from having 0..1 for meaning from incorrect to correct, to 1..0 to mean the same
141 | // //incorrectOutputMultipliers[i] = -correctOutputMultipliers[i] + 1.0
142 | // }
143 |
144 | if len(incorrectOutputMultipliers) == 1 && inputLength != 1 {
145 | // Assume the first slice of floats in the input data is the correct and that the rest are examples of being wrong
146 | for i := 1; i < inputLength; i++ {
147 | incorrectOutputMultipliers = append(incorrectOutputMultipliers, -1.0)
148 | }
149 | } else if inputLength != len(incorrectOutputMultipliers) {
150 | // Assume that the list of correct output multipliers should match the length of the float64 slices in inputData
151 | return nil, errors.New("the length of the input data and the slice of output multipliers differs")
152 | }
153 |
154 | config.inputs = len(inputData[0])
155 |
156 | population := make([]*Network, config.PopulationSize)
157 |
158 | // Initialize the population
159 | for i := 0; i < config.PopulationSize; i++ {
160 | n := NewNetwork(config)
161 | population[i] = &n
162 | population[i].UpdateNetworkPointers()
163 | }
164 |
165 | var (
166 | bestNetwork *Network
167 |
168 | // Keep track of the best scores
169 | bestScore float64
170 | lastBestScore float64
171 |
172 | noImprovementCounter int // Counts how many times the best score has been stagnant
173 |
174 | // Keep track of the average scores
175 | averageScore float64
176 |
177 | // Keep track of the worst scores
178 | worstScore float64
179 | )
180 |
181 | if config.Verbose {
182 | fmt.Printf("Starting evolution with population size %d, for %d generations.\n", config.PopulationSize, config.Generations)
183 | }
184 |
185 | // For each generation, evaluate and modify the networks
186 | for j := 0; j < config.Generations; j++ {
187 |
188 | bestNetwork = nil
189 |
190 | // Initialize the scores with unlikely values
191 | // TODO: Use the first network in the population for initializing these instead
192 | first := true
193 |
194 | // Random weight from -2.0 to 2.0
195 | w := rand.Float64()
196 |
197 | // The scores for this generation (using a random shared weight within ScorePopulation).
198 | // CorrectOutputMultipliers gives weight to the "correct" or "wrong" results, with the same index as the inputData
199 | // Score each network in the population.
200 | scoreMap, scoreSum := ScorePopulation(population, w, inputData, incorrectOutputMultipliers)
201 |
202 | // Sort by score
203 | scoreList := SortByValue(scoreMap)
204 |
205 | // Handle the best score stats
206 | if first {
207 | lastBestScore = 0.0
208 | bestScore = scoreList[0].Value
209 | worstScore = scoreList[len(scoreList)-1].Value
210 | bestNetwork = population[scoreList[0].Key]
211 | bestNetwork.SetWeight(w)
212 | first = false
213 | } else {
214 | lastBestScore = bestScore
215 | if scoreList[0].Value > bestScore {
216 | bestScore = scoreList[0].Value
217 | }
218 | }
219 | if bestScore > lastBestScore {
220 | bestNetwork = population[scoreList[0].Key]
221 | bestNetwork.SetWeight(w)
222 | noImprovementCounter = 0
223 | } else {
224 | noImprovementCounter++
225 | }
226 |
227 | // Handle the average score stats
228 | averageScore = scoreSum / float64(config.PopulationSize)
229 |
230 | // Handle the worst score stats
231 | if scoreList[len(scoreList)-1].Value < worstScore {
232 | worstScore = scoreList[len(scoreList)-1].Value
233 | }
234 |
235 | if bestNetwork == nil {
236 | panic("implementation error: no best network")
237 | }
238 |
239 | if config.Verbose {
240 | fmt.Printf("[generation %d] worst score = %f, average score = %f, best score = %f\n", j, worstScore, averageScore, bestScore)
241 | //fmt.Printf("[generation %d] worst score = %f, average score = %f, best score = %f, no improvement counter for this generation = %d\n", j, worstScore, averageScore, bestScore, noImprovementCounter)
242 | if noImprovementCounter > 0 {
243 | fmt.Printf("No improvement in the best score for the last %d generations\n", noImprovementCounter)
244 | }
245 | }
246 |
247 | // Only keep the best 7%
248 | bestFractionCountdown := int(float64(len(population)) * 0.07)
249 |
250 | goodNetworks := make([]*Network, 0, bestFractionCountdown)
251 |
252 | // Now loop over all networks, sorted by score (descending order)
253 | // p.Key is the network index
254 | // p.Value is the network score
255 | for _, p := range scoreList {
256 | networkIndex := p.Key
257 | if bestFractionCountdown > 0 {
258 | bestFractionCountdown--
259 | // In the best third of the networks
260 | goodNetworks = append(goodNetworks, population[networkIndex])
261 | continue
262 | }
263 | // // If there has not been any improvement to the best score lately, randomize the bad half
264 | // if noImprovementCounter > 100 {
265 | // n := NewNetwork(config)
266 | // population[networkIndex] = &n
267 | // continue
268 | // }
269 | randomGoodNetwork := goodNetworks[rand.Intn(len(goodNetworks))]
270 | randomGoodNetworkCopy := randomGoodNetwork.Copy()
271 | randomGoodNetworkCopy.Modify(maxModificationInterationsWhenMutating)
272 | // Replace the "bad" network with the modified copy of a "good" one
273 | // It's important that this is a pointer to a Network and not
274 | // a bare Network, so that the node .Net pointers are correct.
275 | population[networkIndex] = randomGoodNetworkCopy
276 | }
277 | // if noImprovementCounter > 100 {
278 | // noImprovementCounter = 0
279 | // }
280 | }
281 | if config.Verbose {
282 | fmt.Printf("[all time best network, random weight ] weight=%f score=%f\n", bestNetwork.Weight, bestScore)
283 | }
284 |
285 | // Now find the best weight for the best network, using a population of 1
286 | // and a step size of 0.0001 for the weight
287 | population = []*Network{bestNetwork}
288 | bestWeight := -2.0
289 | for w := -2.0; w <= 2.0; w += 0.0001 {
290 | scoreMap, _ := ScorePopulation(population, w, inputData, incorrectOutputMultipliers)
291 | // Handle the best score stats
292 | if scoreMap[0] > bestScore {
293 | bestScore = scoreMap[0]
294 | population[0].SetWeight(w)
295 | bestWeight = w
296 | }
297 | }
298 |
299 | // Check if the best network is nil, just in case
300 | if bestNetwork == nil {
301 | return nil, errors.New("the total best network is nil")
302 | }
303 |
304 | // Save the best weight for the network
305 | bestNetwork.SetWeight(bestWeight)
306 |
307 | if config.Verbose {
308 | fmt.Printf("[all time best network, optimal weight ] weight=%f best score=%f\n", bestNetwork.Weight, bestScore)
309 | }
310 |
311 | return bestNetwork, nil
312 | }
313 |
--------------------------------------------------------------------------------
/network.go:
--------------------------------------------------------------------------------
1 | package wann
2 |
3 | import (
4 | "errors"
5 | "fmt"
6 | "math/rand"
7 | "strconv"
8 | "strings"
9 | )
10 |
// NeuronIndex is an index into the AllNodes slice of a Network.
// Neuron input connections (Neuron.InputNodes), as well as Network.InputNodes
// and Network.OutputNode, refer to neurons by such indices.
type NeuronIndex int
13 |
// Network is a collection of nodes, an output node and a shared weight.
// Nodes are stored by value in AllNodes and refer to each other by NeuronIndex.
type Network struct {
	AllNodes   []Neuron      // Storing the actual neurons
	InputNodes []NeuronIndex // Indices (into AllNodes) of the network input nodes
	OutputNode NeuronIndex   // Index (into AllNodes) of the output node
	Weight     float64       // Shared weight, used for every connection when evaluating
}
21 |
22 | // NewNetwork creates a new minimal network with n input nodes and ratio of r connections.
23 | // Passing "nil" as an argument is supported.
24 | func NewNetwork(cs ...*Config) Network {
25 | c := &Config{}
26 | // If a single non-nil *Config struct is given, use that
27 | if len(cs) == 1 && cs[0] != nil {
28 | c = cs[0]
29 | }
30 | n := c.inputs
31 | r := c.InitialConnectionRatio
32 | w := c.sharedWeight
33 | // Create a new network that has one node, the output node
34 | outputNodeIndex := NeuronIndex(0)
35 | net := Network{make([]Neuron, 0, n+1), make([]NeuronIndex, n), outputNodeIndex, w}
36 | outputNode, outputNodeIndex := net.NewNeuron()
37 | net.OutputNode = outputNodeIndex
38 |
39 | // Initialize n input nodes that all are inputs to the one output node.
40 | for i := 0; i < n; i++ {
41 | // Add a new input node
42 | _, nodeIndex := net.NewNeuron()
43 |
44 | // Register the input node index in the input node NeuronIndex slice
45 | net.InputNodes[i] = nodeIndex
46 |
47 | // Make connections for all nodes where a random number between 0 and 1 are larger than r
48 | if r >= rand.Float64() {
49 | if err := outputNode.AddInput(nodeIndex); err != nil {
50 | panic(err)
51 | }
52 | }
53 | }
54 |
55 | // Store the modified output node
56 | net.AllNodes[outputNodeIndex] = *outputNode
57 |
58 | return net
59 | }
60 |
61 | // Get returns a pointer to a neuron, based on the given NeuronIndex
62 | func (net *Network) Get(i NeuronIndex) *Neuron {
63 | return &(net.AllNodes[i])
64 | }
65 |
66 | // IsInput checks if the given node is an input node
67 | func (net *Network) IsInput(ni NeuronIndex) bool {
68 | for _, inputNodeIndex := range net.InputNodes {
69 | if ni == inputNodeIndex {
70 | return true
71 | }
72 | }
73 | return false
74 | }
75 |
76 | //
77 | // Operators for searching the space of network topologies
78 | //
79 |
80 | // InsertNode takes two neurons and inserts a third neuron between them
81 | // Assumes that a is the leftmost node and the b is the rightmost node.
82 | func (net *Network) InsertNode(a, b NeuronIndex, newNodeIndex NeuronIndex) error {
83 | // This is done by first checking that a is an input node to b,
84 | // then setting newNode to be an input node to b,
85 | // then setting a to be an input node to a.
86 |
87 | // TODO: When a neuron is inserted, the input index
88 |
89 | // Sort the nodes by where they place in the diagram
90 | a, b, arbitrary := net.LeftRight(a, b)
91 | if arbitrary {
92 | if a == b {
93 | return errors.New("insert node: the a and b nodes are the same")
94 | }
95 | if net.IsInput(a) && net.IsInput(b) {
96 | return errors.New("insert node: both node a and b are input nodes")
97 | }
98 | return errors.New("insert node: arbitrary ordering when inserting a node")
99 | }
100 |
101 | // This should never happen
102 | if a == net.OutputNode {
103 | panic("implementation error: the leftmost node is an output node and this was not cought earlier")
104 | }
105 |
106 | // b already has a as an input (a -> b)
107 | if net.AllNodes[b].HasInput(a) {
108 | // Remove the old connection
109 | if err := net.AllNodes[b].RemoveInput(a); err != nil {
110 | return errors.New("error in InsertNode b.RemoveInput(a): " + err.Error())
111 | }
112 | }
113 |
114 | // b already has newNodeIndex as an input (newIndex -> b)
115 | if net.AllNodes[b].HasInput(newNodeIndex) {
116 | // Remove the old connection
117 | if err := net.AllNodes[b].RemoveInput(a); err != nil {
118 | return errors.New("error in InsertNode b.RemoveInput(a): " + err.Error())
119 | }
120 | }
121 |
122 | // Connect the new node to b
123 | if err := net.AllNodes[b].AddInput(newNodeIndex); err != nil {
124 | // This does not kick in, the problem must be in AddInput!
125 | return errors.New("error in InsertNode b.AddInput(newNode): " + err.Error())
126 | }
127 |
128 | // Connect a to the new node
129 | if err := net.AllNodes[newNodeIndex].AddInput(a); err != nil {
130 | return errors.New("error in InsertNode newNode.AddInput(a): " + err.Error())
131 | }
132 |
133 | // The situation should now be: a -> newNode -> b
134 | return nil
135 | }
136 |
137 | // AddConnection adds a connection from a to b.
138 | // The order is swapped if needed, then a is added as an input to b.
139 | func (net *Network) AddConnection(a, b NeuronIndex) error {
140 | lastIndex := NeuronIndex(len(net.AllNodes) - 1)
141 | if a < 0 || a > lastIndex || b < 0 || b > lastIndex {
142 | return errors.New("index out of range")
143 | }
144 | // Sort the nodes by where they place in the diagram
145 | var arbitrary bool
146 | a, b, arbitrary = net.LeftRight(a, b)
147 | if arbitrary {
148 | if a == b {
149 | return errors.New("can't connect to self")
150 | }
151 | return errors.New("error: arbitrary ordering when adding a connection")
152 | }
153 | // a should not be an output node
154 | if a == net.OutputNode {
155 | return errors.New("error: will not insert a node between the output node and another node")
156 | }
157 | // b should not be an input node
158 | if net.IsInput(b) {
159 | return errors.New("error: b is an input node")
160 | }
161 | // same thing
162 | if net.AllNodes[b].Value != nil {
163 | return errors.New("error: b is an input node")
164 | }
165 | if net.AllNodes[b].HasInput(a) {
166 | return errors.New("error: input already exists")
167 | }
168 | return net.AllNodes[b].AddInput(a)
169 | }
170 |
171 | // RandomizeActivationFunctionForRandomNeuron randomizes the activation function for a randomly selected neuron
172 | func (net *Network) RandomizeActivationFunctionForRandomNeuron() {
173 | chosenNeuronIndex := net.GetRandomNode()
174 | chosenActivationFunctionIndex := ActivationFunctionIndex(rand.Intn(len(ActivationFunctions)))
175 | net.AllNodes[chosenNeuronIndex].ActivationFunction = chosenActivationFunctionIndex
176 | }
177 |
178 | // Evaluate will return a weighted sum of the input nodes,
179 | // using the .Value field if it is set and no input nodes are available.
180 | // A shared weight can be given.
181 | func (net *Network) Evaluate(inputValues []float64) float64 {
182 | inputLength := len(inputValues)
183 | for i, nindex := range net.InputNodes {
184 | if i < inputLength {
185 | net.AllNodes[nindex].SetValue(inputValues[i])
186 | }
187 | }
188 | maxIterationCounter := inputLength
189 | result, _ := net.AllNodes[net.OutputNode].evaluate(net.Weight, &maxIterationCounter)
190 | return result
191 | }
192 |
193 | // SetWeight will set a shared weight for the entire network
194 | func (net *Network) SetWeight(weight float64) {
195 | net.Weight = weight
196 | }
197 |
198 | // LeftRight returns two neurons, such that the first on is the one that is
199 | // most to the left (towards the input neurons) and the second one is most to
200 | // the right (towards the output neuron). Assumes that a and b are not equal.
201 | // The returned bool is true if there is no order (if the nodes are equal, both are output nodes or both are input nodes)
202 | func (net *Network) LeftRight(a, b NeuronIndex) (NeuronIndex, NeuronIndex, bool) {
203 | // First check if they are equal
204 | if a == b {
205 | return a, b, true // Arbitrary order
206 | }
207 | // First check the network output nodes
208 | if a == net.OutputNode && b == net.OutputNode {
209 | return a, b, true // Arbitrary order
210 | }
211 | if a == net.OutputNode && b != net.OutputNode {
212 | return b, a, false // Swap order
213 | }
214 | if a != net.OutputNode && b == net.OutputNode {
215 | return a, b, false // Same order
216 | }
217 | // Then check if the nodes are already connected
218 | if net.AllNodes[a].In(net.AllNodes[b].InputNodes) {
219 | return a, b, false // Same order
220 | }
221 | if net.AllNodes[b].In(net.AllNodes[a].InputNodes) {
222 | return b, a, false // Swap order
223 | }
224 | // Then check the input nodes of the network
225 | aIsNetworkInputNode := net.AllNodes[a].In(net.InputNodes)
226 | bIsNetworkInputNode := net.AllNodes[b].In(net.InputNodes)
227 | if aIsNetworkInputNode && !bIsNetworkInputNode {
228 | return a, b, false // Same order
229 | }
230 | if !aIsNetworkInputNode && bIsNetworkInputNode {
231 | return b, a, false // Swap order
232 | }
233 | if aIsNetworkInputNode && bIsNetworkInputNode {
234 | return a, b, true // Arbitrary order
235 | }
236 | // Then check the distance from the output node, in steps
237 | aDistance := net.AllNodes[a].distanceFromOutputNode
238 | bDistance := net.AllNodes[b].distanceFromOutputNode
239 | if bDistance > aDistance {
240 | return b, a, false // Swap order, b is further away from the output node, which (usually) means further left in the graph
241 | }
242 | // Everything else
243 | return a, b, false
244 | }
245 |
246 | // Depth returns the maximum connection distance from the output node
247 | func (net *Network) Depth() int {
248 | maxDepth := 0
249 | net.ForEachConnected(func(n *Neuron) {
250 | if n.distanceFromOutputNode > maxDepth {
251 | maxDepth = n.distanceFromOutputNode
252 | }
253 | })
254 | return maxDepth
255 | }
256 |
257 | // All returns a slice with pointers to all nodes in this network
258 | func (net *Network) All() []*Neuron {
259 | allNodes := make([]*Neuron, 0)
260 | for i := range net.AllNodes {
261 | allNodes = append(allNodes, &net.AllNodes[i])
262 | }
263 | // Return pointers to all nodes in this network
264 | return allNodes
265 | }
266 |
267 | // GetRandomNode will select a random neuron.
268 | // This can be any node, including the output node.
269 | func (net *Network) GetRandomNode() NeuronIndex {
270 | return NeuronIndex(rand.Intn(len(net.AllNodes)))
271 | }
272 |
273 | // GetRandomInputNode returns a random input node
274 | func (net *Network) GetRandomInputNode() NeuronIndex {
275 | inputPosition := rand.Intn(len(net.InputNodes))
276 | inputNodeIndex := net.InputNodes[inputPosition]
277 | return inputNodeIndex
278 | }
279 |
280 | // Combine will combine two lists of indices
281 | func Combine(a, b []NeuronIndex) []NeuronIndex {
282 | lena := len(a)
283 | lenb := len(b)
284 | // Allocate the exact size needed
285 | res := make([]NeuronIndex, lena+lenb)
286 | // Add the elements from a
287 | for i := 0; i < lena; i++ {
288 | res[i] = a[i]
289 | }
290 | // Add the elements from b
291 | for i := 0; i < lenb; i++ {
292 | res[i+lena] = b[i]
293 | }
294 | return res
295 | }
296 |
// getAllConnectedNodes is a helper function for the recursive network traversal.
// It walks depth-first from nodeIndex towards the input nodes, following
// InputNodes, and returns the indices of the nodes it reaches: nodeIndex itself
// (unless it is already in alreadyHaveThese) and every node reachable from it.
// Nodes already in alreadyHaveThese, or already collected, are not visited again,
// which also stops the walk from looping on circular connections.
//
// Side effect: every visited node except the output node gets its
// distanceFromOutputNode set to the depth at which the walk first reached it
// (distanceFromFirstNode for nodeIndex, plus one per step). Because the walk is
// depth-first, this is not necessarily the shortest distance.
//
// Given the output node, the number 0 and an empty slice, it returns all nodes
// connected to the output node. Input indices that are out of range for
// AllNodes are skipped, and a node that lists itself as an input causes a panic.
func (net *Network) getAllConnectedNodes(nodeIndex NeuronIndex, distanceFromFirstNode int, alreadyHaveThese []NeuronIndex) []NeuronIndex {
	allNodes := make([]NeuronIndex, 0, len(net.AllNodes))
	// node is a copy; it is written back after updating the distance
	node := net.AllNodes[nodeIndex]
	if nodeIndex != net.OutputNode {
		node.distanceFromOutputNode = distanceFromFirstNode
		net.AllNodes[nodeIndex] = node
	}
	if !node.In(alreadyHaveThese) {
		allNodes = append(allNodes, nodeIndex)
	}
	for _, inputNodeIndex := range node.InputNodes {
		if node.Is(inputNodeIndex) {
			panic("implementation error: node is input node to self")
		}
		if int(inputNodeIndex) >= len(net.AllNodes) {
			continue
		}
		inputNode := net.AllNodes[inputNodeIndex]
		if !inputNode.In(allNodes) && !inputNode.In(alreadyHaveThese) {
			// Everything collected so far is passed down, so the recursive call skips it
			allNodes = Combine(allNodes, net.getAllConnectedNodes(inputNodeIndex, distanceFromFirstNode+1, append(allNodes, alreadyHaveThese...)))
		}
	}
	return allNodes
}
325 |
326 | // ForEachConnected will only go through nodes that are connected to the output node (directly or indirectly)
327 | // Unconnected input nodes are not covered.
328 | func (net *Network) ForEachConnected(f func(n *Neuron)) {
329 | // Start at the output node, traverse left towards the input nodes
330 | // The network has a counter for how many nodes has been added/removed, for quick memory allocation here
331 | // the final slice is to avoid circular connections
332 | for _, nodeIndex := range net.getAllConnectedNodes(net.OutputNode, 0, []NeuronIndex{}) {
333 | f(&(net.AllNodes[nodeIndex]))
334 | }
335 | }
336 |
337 | // Connected returns a slice of neuron indexes, that are all connected to the output node (directly or indirectly)
338 | func (net *Network) Connected() []NeuronIndex {
339 | allConnected := make([]NeuronIndex, 0, len(net.AllNodes)) // Use a bit more memory, but don't allocate at every iteration
340 | net.ForEachConnectedNodeIndex(func(ni NeuronIndex) {
341 | allConnected = append(allConnected, ni)
342 | })
343 | return allConnected
344 | }
345 |
346 | // Unconnected returns a slice of all unconnected neurons
347 | func (net *Network) Unconnected() []NeuronIndex {
348 | connected := net.Connected()
349 | // TODO: Benchmark if using len(net.AllNodes) here is faster or not
350 | unconnected := make([]NeuronIndex, 0, len(net.AllNodes))
351 | for i, node := range net.AllNodes {
352 | if !node.In(connected) {
353 | unconnected = append(unconnected, NeuronIndex(i))
354 | }
355 | }
356 | return unconnected
357 | }
358 |
359 | // ForEachConnectedNodeIndex will only go through nodes that are connected to the output node (directly or indirectly)
360 | // Unconnected input nodes are not covered.
361 | func (net *Network) ForEachConnectedNodeIndex(f func(ni NeuronIndex)) {
362 | net.ForEachConnected(func(n *Neuron) {
363 | f(n.neuronIndex)
364 | })
365 | }
366 |
367 | // InsertRandomNode will try the given number of times to insert a node in a random location,
368 | // replacing an existing connection between two nodes.
369 | // `a -> b` will then become `a -> newNode -> b`
370 | // Returns true if one was inserted or false if the randomly chosen location wasn't fruitful
371 | func (net *Network) InsertRandomNode() bool {
372 |
373 | // Find a random node among the nodes that are connected to the output node (directly or indirectly)
374 | connectedNodes := net.Connected()
375 | randomNodeIndexThatIsConnected := connectedNodes[rand.Intn(len(connectedNodes))]
376 |
377 | // If this is one of the network input nodes, return
378 | if net.IsInput(randomNodeIndexThatIsConnected) {
379 | // Nothing to do here, the input nodes get their input from the input numbers
380 | return false
381 | }
382 |
383 | // If this is the output node, and there are no inputs to the output node, return
384 | if randomNodeIndexThatIsConnected == net.OutputNode && len(net.AllNodes[net.OutputNode].InputNodes) == 0 {
385 | // Nothing to do here, no connections to the output node
386 | return false
387 | }
388 |
389 | // If we arrived here, this node must have input nodes. Choose one at random.
390 | rightIndex := randomNodeIndexThatIsConnected
391 | inputNodes := net.AllNodes[rightIndex].InputNodes
392 |
393 | leftIndex := inputNodes[rand.Intn(len(inputNodes))]
394 |
395 | // We now have a left and right node index, that we know are connected, replace this connection with
396 | // one that goes through an entirely new node.
397 |
398 | // Create a new node and connect it with the left node
399 | newNode, newNodeIndex := net.NewNeuron()
400 | err := newNode.AddInput(leftIndex)
401 | if err != nil {
402 | panic(err)
403 | }
404 |
405 | // Remove the connection to the left node and then connect the right node to the new node
406 | err = net.AllNodes[rightIndex].RemoveInput(leftIndex)
407 | if err != nil {
408 | panic(err)
409 | }
410 | err = net.AllNodes[rightIndex].AddInput(newNodeIndex)
411 | if err != nil {
412 | panic(err)
413 | }
414 |
415 | return true
416 | }
417 |
418 | // UpdateNetworkPointers will update all the node.Net pointers to point to this network
419 | func (net *Network) UpdateNetworkPointers() {
420 | for nodeIndex := range net.AllNodes {
421 | net.AllNodes[nodeIndex].Net = net
422 | }
423 | }
424 |
425 | // NewInputNode creates a new input node for this network, optionally connecting it to the output node
426 | func (net *Network) NewInputNode(activationFunction ActivationFunctionIndex, connectToOutput bool) error {
427 | // Create a new node
428 | _, ni := net.NewBlankNeuron()
429 | // Set the activation function
430 | net.AllNodes[ni].ActivationFunction = activationFunction
431 | // Set the parent network
432 | net.AllNodes[ni].Net = net
433 | // Add the new node to the input nodes of the net
434 | net.InputNodes = append(net.InputNodes, ni)
435 | // Connect the input node to the output node
436 | if connectToOutput {
437 | return net.AddConnection(ni, net.OutputNode)
438 | }
439 | return nil
440 | }
441 |
442 | // Copy a Network to a new network
443 | func (net Network) Copy() *Network {
444 | var newNet Network
445 | newNet.AllNodes = make([]Neuron, len(net.AllNodes))
446 | for nodeIndex := range net.AllNodes {
447 | // This copies the node and also sets the .Net pointer correctly to this network
448 | newNet.AllNodes[nodeIndex] = newNet.AllNodes[nodeIndex].Copy(&newNet)
449 | }
450 | newNet.InputNodes = net.InputNodes
451 | newNet.OutputNode = net.OutputNode
452 | newNet.Weight = net.Weight
453 |
454 | // NOTE: It's important that a pointer to a Network is returned,
455 | // instead of an entire Network struct, so that the .Net pointers in the nodes point correctly.
456 | return &newNet
457 | }
458 |
459 | // String creates a simple and not very useful ASCII representation of the input nodes and the output node.
460 | // Nodes that are not input nodes are skipped.
461 | // Input nodes that are not connected directly to the output node are drawn as non-connected,
462 | // even if they are connected via another node.
463 | func (net Network) String() string {
464 | var sb strings.Builder
465 | sb.WriteString(fmt.Sprintf("Network (%d nodes, %d input nodes, %d output node)\n", len(net.AllNodes), len(net.InputNodes), 1))
466 | sb.WriteString("\tConnected inputs to output node: " + strconv.Itoa(len(net.AllNodes[net.OutputNode].InputNodes)) + "\n")
467 | for _, node := range net.AllNodes {
468 | sb.WriteString("\t" + node.String() + "\n")
469 | }
470 | return sb.String()
471 | }
472 |
--------------------------------------------------------------------------------
/vendor/github.com/xyproto/tinysvg/tinysvg.go:
--------------------------------------------------------------------------------
1 | // Package tinysvg supports generating and writing TinySVG 1.2 images
2 | //
3 | // Some function names are suffixed with "2" if they take structs instead of ints/floats,
// "i" if they take ints and "f" if they take floats, since Go does not support function overloading.
5 | //
6 | package tinysvg
7 |
8 | import (
9 | "bytes"
10 | "errors"
11 | "fmt"
12 | "strconv"
13 | "strings"
14 | )
15 |
// Alpha (opacity) values for Color.A.
const (
	TRANSPARENT = 0.0 // fully transparent
	OPAQUE      = 1.0 // fully opaque
)

// Values for the YesNoAuto type, as returned by NewYesNoAuto.
const (
	YES = iota // 0
	NO         // 1
	AUTO       // 2
)
24 |
// Vec2 is a pair of float64 coordinates.
type Vec2 struct {
	X, Y float64
}

// Pos is a position; it has the same layout as Vec2.
type Pos Vec2

// Radius holds an X and a Y radius; it has the same layout as Vec2.
type Radius Vec2

// Size is a width and a height.
type Size struct {
	W, H float64
}

// Color is either an RGB(A) color or a named color.
type Color struct {
	R, G, B int     // red, green and blue, 0..255
	A       float64 // alpha, 0.0..1.0
	N       string  // optional name; when set, it takes precedence over R, G, B and A
}

// Font is a font family and an integer font size.
type Font struct {
	Family string
	Size   int
}

// YesNoAuto is a tri-state value; see the YES, NO and AUTO constants.
type YesNoAuto int
50 |
var (
	// ErrPair is returned by PointsFromString when a position pair does not
	// consist of exactly two comma-separated fields.
	ErrPair = errors.New("position pairs must be exactly two comma separated numbers")
)
52 |
53 | // Create a new TinySVG document, where the width and height is defined in pixels, using the "px" suffix
54 | func NewTinySVG(w, h int) (*Document, *Tag) {
55 | page := NewDocument([]byte(""), []byte(``))
56 | svg := page.root.AddNewTag([]byte("svg"))
57 | svg.AddAttrib("xmlns", []byte("http://www.w3.org/2000/svg"))
58 | svg.AddAttrib("version", []byte("1.2"))
59 | svg.AddAttrib("baseProfile", []byte("tiny"))
60 | svg.AddAttrib("viewBox", []byte(fmt.Sprintf("%d %d %d %d", 0, 0, w, h)))
61 | svg.AddAttrib("width", []byte(fmt.Sprintf("%dpx", w)))
62 | svg.AddAttrib("height", []byte(fmt.Sprintf("%dpx", h)))
63 | return page, svg
64 | }
65 |
66 | // NewTinySVG2 creates new TinySVG 1.2 image. Pos and Size defines the viewbox
67 | func NewTinySVG2(p *Pos, s *Size) (*Document, *Tag) {
68 | // No page title is needed when building an SVG tag tree
69 | page := NewDocument([]byte(""), []byte(``))
70 |
71 | // No longer needed for TinySVG 1.2. See: https://www.w3.org/TR/SVGTiny12/intro.html#defining
72 | //
73 |
74 | // Add the root tag
75 | svg := page.root.AddNewTag([]byte("svg"))
76 | svg.AddAttrib("xmlns", []byte("http://www.w3.org/2000/svg"))
77 | svg.AddAttrib("version", []byte("1.2"))
78 | svg.AddAttrib("baseProfile", []byte("tiny"))
79 | svg.AddAttrib("viewBox", []byte(fmt.Sprintf("%f %f %f %f", p.X, p.Y, s.W, s.H)))
80 | return page, svg
81 | }
82 |
// f2b formats x with exactly six digits after the decimal point (the same
// text that fmt's "%f" verb produces) and returns it as a byte slice.
//
// A ".000000" fraction is dropped, so whole numbers are written without a
// decimal part ("3", not "3.000000"). Any other fraction keeps all six digits
// ("1.500000").
func f2b(x float64) []byte {
	return bytes.TrimSuffix(strconv.AppendFloat(nil, x, 'f', 6, 64), []byte(".000000"))
}
92 |
93 | // Rect a rectangle, given x and y position, width and height.
94 | // No color is being set.
95 | func (svg *Tag) Rect2(p *Pos, s *Size, c *Color) *Tag {
96 | rect := svg.AddNewTag([]byte("rect"))
97 | rect.AddAttrib("x", f2b(p.X))
98 | rect.AddAttrib("y", f2b(p.Y))
99 | rect.AddAttrib("width", f2b(s.W))
100 | rect.AddAttrib("height", f2b(s.H))
101 | rect.Fill2(c)
102 | return rect
103 | }
104 |
105 | // RoundedRect2 a rectangle, given x and y position, width and height.
106 | // No color is being set.
107 | func (svg *Tag) RoundedRect2(p *Pos, r *Radius, s *Size, c *Color) *Tag {
108 | rect := svg.AddNewTag([]byte("rect"))
109 | rect.AddAttrib("x", f2b(p.X))
110 | rect.AddAttrib("y", f2b(p.Y))
111 | rect.AddAttrib("rx", f2b(r.X))
112 | rect.AddAttrib("ry", f2b(r.Y))
113 | rect.AddAttrib("width", f2b(s.W))
114 | rect.AddAttrib("height", f2b(s.H))
115 | rect.Fill2(c)
116 | return rect
117 | }
118 |
119 | // Text adds text. No color is being set
120 | func (svg *Tag) Text2(p *Pos, f *Font, message string, c *Color) *Tag {
121 | text := svg.AddNewTag([]byte("text"))
122 | text.AddAttrib("x", f2b(p.X))
123 | text.AddAttrib("y", f2b(p.Y))
124 | text.AddAttrib("font-family", []byte(f.Family))
125 | text.AddAttrib("font-size", []byte(strconv.Itoa(f.Size)))
126 | text.Fill2(c)
127 | text.AddContent([]byte(message))
128 | return text
129 | }
130 |
131 | // Circle adds a circle, given a position, radius and color
132 | func (svg *Tag) Circle2(p *Pos, radius int, c *Color) *Tag {
133 | circle := svg.AddNewTag([]byte("circle"))
134 | circle.AddAttrib("cx", f2b(p.X))
135 | circle.AddAttrib("cy", f2b(p.Y))
136 | circle.AddAttrib("r", []byte(strconv.Itoa(radius)))
137 | circle.Fill2(c)
138 | return circle
139 | }
140 |
141 | // Circle adds a circle, given a position, radius and color
142 | func (svg *Tag) Circlef(p *Pos, radius float64, c *Color) *Tag {
143 | circle := svg.AddNewTag([]byte("circle"))
144 | circle.AddAttrib("cx", f2b(p.X))
145 | circle.AddAttrib("cy", f2b(p.Y))
146 | circle.AddAttrib("r", f2b(radius))
147 | circle.Fill2(c)
148 | return circle
149 | }
150 |
151 | // Ellipse adds an ellipse with a given position (x,y) and radius (rx, ry).
152 | func (svg *Tag) Ellipse2(p *Pos, r *Radius, c *Color) *Tag {
153 | ellipse := svg.AddNewTag([]byte("ellipse"))
154 | ellipse.AddAttrib("cx", f2b(p.X))
155 | ellipse.AddAttrib("cy", f2b(p.Y))
156 | ellipse.AddAttrib("rx", f2b(r.X))
157 | ellipse.AddAttrib("ry", f2b(r.Y))
158 | ellipse.Fill2(c)
159 | return ellipse
160 | }
161 |
162 | // Line adds a line from (x1, y1) to (x2, y2) with a given stroke width and color
163 | func (svg *Tag) Line2(p1, p2 *Pos, thickness int, c *Color) *Tag {
164 | line := svg.AddNewTag([]byte("line"))
165 | line.AddAttrib("x1", f2b(p1.X))
166 | line.AddAttrib("y1", f2b(p1.Y))
167 | line.AddAttrib("x2", f2b(p2.X))
168 | line.AddAttrib("y2", f2b(p2.Y))
169 | line.Thickness(thickness)
170 | line.Stroke2(c)
171 | return line
172 | }
173 |
174 | // Triangle adds a colored triangle
175 | func (svg *Tag) Triangle2(p1, p2, p3 *Pos, c *Color) *Tag {
176 | triangle := svg.AddNewTag([]byte("path"))
177 | triangle.AddAttrib("d", []byte(fmt.Sprintf("M %f %f L %f %f L %f %f L %f %f", p1.X, p1.Y, p2.X, p2.Y, p3.X, p3.Y, p1.X, p1.Y)))
178 | triangle.Fill2(c)
179 | return triangle
180 | }
181 |
182 | // Poly2 adds a colored path with 4 points
183 | func (svg *Tag) Poly2(p1, p2, p3, p4 *Pos, c *Color) *Tag {
184 | poly4 := svg.AddNewTag([]byte("path"))
185 | poly4.AddAttrib("d", []byte(fmt.Sprintf("M %f %f L %f %f L %f %f L %f %f L %f %f", p1.X, p1.Y, p2.X, p2.Y, p3.X, p3.Y, p4.X, p4.Y, p1.X, p1.Y)))
186 | poly4.Fill2(c)
187 | return poly4
188 | }
189 |
190 | // Fill selects the fill color that will be used when drawing
191 | func (svg *Tag) Fill2(c *Color) {
192 | // If no color name is given and the color is transparent, don't set a fill color
193 | if (c == nil) || (len(c.N) == 0 && c.A == TRANSPARENT) {
194 | return
195 | }
196 | svg.AddAttrib("fill", c.Bytes())
197 | }
198 |
199 | // Stroke selects the stroke color that will be used when drawing
200 | func (svg *Tag) Stroke2(c *Color) {
201 | // If no color name is given and the color is transparent, don't set a stroke color
202 | if (c == nil) || (len(c.N) == 0 && c.A == TRANSPARENT) {
203 | return
204 | }
205 | svg.AddAttrib("stroke", c.Bytes())
206 | }
207 |
// RGBBytes converts r, g and b (integers in the range 0..255) to a color
// string on the form "#rrggbb", returned as a byte slice.
//
// When every channel is a doubled hex digit (0x00, 0x11, ..., 0xff), the
// equivalent three-digit shorthand "#rgb" is returned instead. In SVG/CSS,
// "#rgb" expands to "#rrggbb" (for example "#0f8" means "#00ff88"), so the
// shorthand is only valid in that case. Values outside 0..255 are not clamped.
func RGBBytes(r, g, b int) []byte {
	// isDoubled reports whether v is in range and its two hex digits are equal,
	// i.e. whether v can be written as the single shorthand digit v/0x11.
	isDoubled := func(v int) bool {
		return v >= 0 && v <= 0xff && v%0x11 == 0
	}
	if isDoubled(r) && isDoubled(g) && isDoubled(b) {
		// short form
		return []byte(fmt.Sprintf("#%x%x%x", r/0x11, g/0x11, b/0x11))
	}
	// long form
	return []byte(fmt.Sprintf("#%02x%02x%02x", r, g, b))
}
222 |
// RGBABytes formats r, g and b (the color) and the alpha a (opacity) as an
// "rgba(r, g, b, a)" string, returned as a byte slice. The alpha is written
// with six decimals, e.g. "rgba(255, 255, 255, 1.000000)".
func RGBABytes(r, g, b int, a float64) []byte {
	buf := make([]byte, 0, 32)
	buf = append(buf, "rgba("...)
	buf = strconv.AppendInt(buf, int64(r), 10)
	buf = append(buf, ", "...)
	buf = strconv.AppendInt(buf, int64(g), 10)
	buf = append(buf, ", "...)
	buf = strconv.AppendInt(buf, int64(b), 10)
	buf = append(buf, ", "...)
	buf = strconv.AppendFloat(buf, a, 'f', 6, 64)
	return append(buf, ')')
}
229 |
230 | // RGBA creates a new Color with the given red, green and blue values.
231 | // The colors are in the range 0..255
232 | func RGB(r, g, b int) *Color {
233 | return &Color{r, g, b, OPAQUE, ""}
234 | }
235 |
236 | // RGBA creates a new Color with the given red, green, blue and alpha values.
237 | // Alpha is between 0 and 1, the rest are 0..255.
238 | // For the alpha value, 0 is transparent and 1 is opaque.
239 | func RGBA(r, g, b int, a float64) *Color {
240 | return &Color{r, g, b, a, ""}
241 | }
242 |
243 | // ColorByName creates a new Color with a given name, like "blue"
244 | func ColorByName(name string) *Color {
245 | return &Color{N: name}
246 | }
247 |
248 | // NewColor is the same as ColorByName
249 | func NewColor(name string) *Color {
250 | return ColorByName(name)
251 | }
252 |
253 | // String returns the color as an RGB (#1234FF) string
254 | // or as an RGBA (rgba(0, 1, 2 ,3)) string.
255 | func (c *Color) Bytes() []byte {
256 | // Return an empty string if nil
257 | if c == nil {
258 | return make([]byte, 0)
259 | }
260 | // Return the name, if specified
261 | if len(c.N) != 0 {
262 | return []byte(c.N)
263 | }
264 | // Return a regular RGB string if alpha is 1.0
265 | if c.A == OPAQUE {
266 | // Generate a rgb string
267 | return RGBBytes(c.R, c.G, c.B)
268 | }
269 | // Generate a rgba string if alpha is < 1.0
270 | return RGBABytes(c.R, c.G, c.B, c.A)
271 | }
272 |
273 | // --- Convenience functions and functions for backward compatibility ---
274 |
275 | func NewTinySVGi(x, y, w, h int) (*Document, *Tag) {
276 | return NewTinySVG2(&Pos{float64(x), float64(y)}, &Size{float64(w), float64(h)})
277 | }
278 |
279 | func NewTinySVGf(x, y, w, h float64) (*Document, *Tag) {
280 | return NewTinySVG2(&Pos{x, y}, &Size{w, h})
281 | }
282 |
283 | // AddRect adds a rectangle, given x and y position, width and height.
284 | // No color is being set.
285 | func (svg *Tag) AddRect(x, y, w, h int) *Tag {
286 | return svg.Rect2(&Pos{float64(x), float64(y)}, &Size{float64(w), float64(h)}, nil)
287 | }
288 |
289 | // AddRoundedRect adds a rectangle, given x and y position, radius x, radius y, width and height.
290 | // No color is being set.
291 | func (svg *Tag) AddRoundedRect(x, y, rx, ry, w, h int) *Tag {
292 | return svg.RoundedRect2(&Pos{float64(x), float64(y)}, &Radius{float64(rx), float64(ry)}, &Size{float64(w), float64(h)}, nil)
293 | }
294 |
295 | // AddRectf adds a rectangle, given x and y position, width and height.
296 | // No color is being set.
297 | func (svg *Tag) AddRectf(x, y, w, h float64) *Tag {
298 | return svg.Rect2(&Pos{x, y}, &Size{w, h}, nil)
299 | }
300 |
301 | // AddRoundedRectf adds a rectangle, given x and y position, radius x, radius y, width and height.
302 | // No color is being set.
303 | func (svg *Tag) AddRoundedRectf(x, y, rx, ry, w, h float64) *Tag {
304 | return svg.RoundedRect2(&Pos{x, y}, &Radius{rx, ry}, &Size{w, h}, nil)
305 | }
306 |
307 | // AddText adds text. No color is being set
308 | func (svg *Tag) AddText(x, y, fontSize int, fontFamily, text string) *Tag {
309 | return svg.Text2(&Pos{float64(x), float64(y)}, &Font{fontFamily, fontSize}, text, nil)
310 | }
311 |
312 | // Box adds a rectangle, given x and y position, width, height and color
313 | func (svg *Tag) Box(x, y, w, h int, color string) *Tag {
314 | return svg.Rect2(&Pos{float64(x), float64(y)}, &Size{float64(w), float64(h)}, ColorByName(color))
315 | }
316 |
317 | // AddCircle adds a circle Add a circle, given a position (x, y) and a radius.
318 | // No color is being set.
319 | func (svg *Tag) AddCircle(x, y, radius int) *Tag {
320 | return svg.Circle2(&Pos{float64(x), float64(y)}, radius, nil)
321 | }
322 |
323 | // AddCirclef adds a circle Add a circle, given a position (x, y) and a radius.
324 | // No color is being set.
325 | func (svg *Tag) AddCirclef(x, y, radius float64) *Tag {
326 | return svg.Circlef(&Pos{x, y}, radius, nil)
327 | }
328 |
329 | // AddEllipse adds an ellipse with a given position (x,y) and radius (rx, ry).
330 | // No color is being set.
331 | func (svg *Tag) AddEllipse(x, y, rx, ry int) *Tag {
332 | return svg.Ellipse2(&Pos{float64(x), float64(y)}, &Radius{float64(rx), float64(ry)}, nil)
333 | }
334 |
335 | // AddEllipsef adds an ellipse with a given position (x,y) and radius (rx, ry).
336 | // No color is being set.
337 | func (svg *Tag) AddEllipsef(x, y, rx, ry float64) *Tag {
338 | return svg.Ellipse2(&Pos{x, y}, &Radius{rx, ry}, nil)
339 | }
340 |
341 | // Line adds a line from (x1, y1) to (x2, y2) with a given stroke width and color
342 | func (svg *Tag) Line(x1, y1, x2, y2, thickness int, color string) *Tag {
343 | return svg.Line2(&Pos{float64(x1), float64(y1)}, &Pos{float64(x2), float64(y2)}, thickness, ColorByName(color))
344 | }
345 |
346 | // Triangle adds a colored triangle
347 | func (svg *Tag) Triangle(x1, y1, x2, y2, x3, y3 int, color string) *Tag {
348 | return svg.Triangle2(&Pos{float64(x1), float64(y1)}, &Pos{float64(x2), float64(y2)}, &Pos{float64(x3), float64(y3)}, ColorByName(color))
349 | }
350 |
351 | // Poly4 adds a colored path with 4 points
352 | func (svg *Tag) Poly4(x1, y1, x2, y2, x3, y3, x4, y4 int, color string) *Tag {
353 | return svg.Poly2(&Pos{float64(x1), float64(y1)}, &Pos{float64(x2), float64(y2)}, &Pos{float64(x3), float64(y3)}, &Pos{float64(x4), float64(y4)}, ColorByName(color))
354 | }
355 |
356 | // Circle adds a circle, given x and y position, radius and color
357 | func (svg *Tag) Circle(x, y, radius int, color string) *Tag {
358 | return svg.Circle2(&Pos{float64(x), float64(y)}, radius, ColorByName(color))
359 | }
360 |
361 | // Ellipse adds an ellipse, given x and y position, radiuses and color
362 | func (svg *Tag) Ellipse(x, y, xr, yr int, color string) *Tag {
363 | return svg.Ellipse2(&Pos{float64(x), float64(y)}, &Radius{float64(xr), float64(yr)}, ColorByName(color))
364 | }
365 |
366 | // Fill selects the fill color that will be used when drawing
367 | func (svg *Tag) Fill(color string) {
368 | svg.AddAttrib("fill", []byte(color))
369 | }
370 |
371 | // ColorBytes converts r, g and b (integers in the range 0..255)
372 | // to a color string on the form "#nnnnnn".
373 | func ColorBytes(r, g, b int) []byte {
374 | return RGB(r, g, b).Bytes()
375 | }
376 |
377 | // ColorBytesAlpha converts integers r, g and b (the color) and also
378 | // a given alpha (opacity) to a color-string on the form
379 | // "rgba(255, 255, 255, 1.0)".
380 | func ColorBytesAlpha(r, g, b int, a float64) []byte {
381 | return RGBA(r, g, b, a).Bytes()
382 | }
383 |
384 | // Pixel creates a rectangle that is 1 wide with the given color.
385 | // Note that the size of the "pixel" depends on how large the viewBox is.
386 | func (svg *Tag) Pixel(x, y, r, g, b int) *Tag {
387 | return svg.Rect2(&Pos{float64(x), float64(y)}, &Size{1.0, 1.0}, RGB(r, g, b))
388 | }
389 |
390 | // AlphaDot creates a small circle that can be transparent.
391 | // Takes a position (x, y) and a color (r, g, b, a).
392 | func (svg *Tag) AlphaDot(x, y, r, g, b int, a float32) *Tag {
393 | return svg.Circle2(&Pos{float64(x), float64(y)}, 1, RGBA(r, g, b, float64(a)))
394 | }
395 |
396 | // Dot adds a small colored circle.
397 | // Takes a position (x, y) and a color (r, g, b).
398 | func (svg *Tag) Dot(x, y, r, g, b int) *Tag {
399 | return svg.Circle2(&Pos{float64(x), float64(y)}, 1, RGB(r, g, b))
400 | }
401 |
402 | // Text adds text, with a color
403 | func (svg *Tag) Text(x, y, fontSize int, fontFamily, text, color string) *Tag {
404 | return svg.Text2(&Pos{float64(x), float64(y)}, &Font{fontFamily, fontSize}, text, ColorByName(color))
405 | }
406 |
407 | // Create a new Yes/No/Auto struct. If auto is true, it overrides the yes value.
408 | func NewYesNoAuto(yes bool, auto bool) YesNoAuto {
409 | if auto {
410 | return AUTO
411 | }
412 | if yes {
413 | return YES
414 | }
415 | return NO
416 | }
417 |
418 | func (yna YesNoAuto) Bytes() []byte {
419 | switch yna {
420 | case YES:
421 | return []byte("true")
422 | case AUTO:
423 | return []byte("auto")
424 | default:
425 | return []byte("false")
426 | }
427 | }
428 |
429 | // Focusable sets the "focusable" attribute to either true, false or auto
430 | // If "auto" is true, it overrides the value of "yes".
431 | func (svg *Tag) Focusable(yes bool, auto bool) {
432 | svg.AddAttrib("focusable", NewYesNoAuto(yes, auto).Bytes())
433 | }
434 |
435 | // Thickness sets the stroke-width attribute
436 | func (svg *Tag) Thickness(thickness int) {
437 | svg.AddAttrib("stroke-width", []byte(strconv.Itoa(thickness)))
438 | }
439 |
440 | // Polyline adds a set of connected straight lines, an open shape
441 | func (svg *Tag) Polyline(points []*Pos, c *Color) *Tag {
442 | polyline := svg.AddNewTag([]byte("polyline"))
443 | var buf bytes.Buffer
444 | lastIndex := len(points) - 1
445 | for i, p := range points {
446 | buf.Write(f2b(p.X))
447 | buf.WriteByte(',')
448 | buf.Write(f2b(p.Y))
449 | if i != lastIndex {
450 | buf.WriteByte(' ')
451 | }
452 | }
453 | polyline.AddAttrib("points", buf.Bytes())
454 | return polyline
455 | }
456 |
457 | // Polygon adds a set of connected straight lines, a closed shape
458 | func (svg *Tag) Polygon(points []*Pos, c *Color) *Tag {
459 | polygon := svg.AddNewTag([]byte("polygon"))
460 | var buf bytes.Buffer
461 | lastIndex := len(points) - 1
462 | for i, p := range points {
463 | buf.Write(f2b(p.X))
464 | buf.WriteByte(',')
465 | buf.Write(f2b(p.Y))
466 | if i != lastIndex {
467 | buf.WriteByte(' ')
468 | }
469 | }
470 | polygon.AddAttrib("points", buf.Bytes())
471 | polygon.Fill2(c)
472 | return polygon
473 | }
474 |
475 | func NewPos(xString, yString string) (*Pos, error) {
476 | x, err := strconv.ParseFloat(xString, 64)
477 | if err != nil {
478 | return nil, err
479 | }
480 | y, err := strconv.ParseFloat(yString, 64)
481 | if err != nil {
482 | return nil, err
483 | }
484 | return &Pos{x, y}, nil
485 | }
486 |
487 | func NewPosf(x, y float64) *Pos {
488 | return &Pos{x, y}
489 | }
490 |
491 | func PointsFromString(pointString string) ([]*Pos, error) {
492 | points := make([]*Pos, 0)
493 | for _, positionPair := range strings.Split(pointString, " ") {
494 | elements := strings.Split(positionPair, ",")
495 | if len(elements) != 2 {
496 | return nil, ErrPair
497 | }
498 | p, err := NewPos(elements[0], elements[1])
499 | if err != nil {
500 | return nil, err
501 | }
502 | points = append(points, p)
503 | }
504 | return points, nil
505 | }
506 |
507 | // Describe can be used for adding a description to the SVG header
508 | func (svg *Tag) Describe(description string) {
509 | desc := svg.AddNewTag([]byte("desc"))
510 | desc.AddContent([]byte(description))
511 | }
512 |
--------------------------------------------------------------------------------