├── vendor ├── github.com │ ├── xyproto │ │ ├── tinysvg │ │ │ ├── go.sum │ │ │ ├── go.mod │ │ │ ├── .gitignore │ │ │ ├── README.md │ │ │ ├── tags.go │ │ │ └── tinysvg.go │ │ ├── swish │ │ │ ├── benchmark.sh │ │ │ ├── .travis.yml │ │ │ ├── swish_asm.go │ │ │ ├── go.mod │ │ │ ├── .gitignore │ │ │ ├── go.sum │ │ │ ├── swish_amd64.s │ │ │ ├── LICENSE │ │ │ ├── swish.go │ │ │ └── README.md │ │ └── af │ │ │ ├── .travis.yml │ │ │ ├── go.mod │ │ │ ├── .gitignore │ │ │ ├── go.sum │ │ │ ├── README.md │ │ │ └── af.go │ └── dave │ │ └── jennifer │ │ └── jen │ │ ├── generics.go │ │ ├── add.go │ │ ├── do.go │ │ ├── reserved.go │ │ ├── tag.go │ │ ├── dict.go │ │ ├── custom.go │ │ ├── statement.go │ │ ├── comments.go │ │ ├── group.go │ │ ├── jen.go │ │ ├── lit.go │ │ ├── file.go │ │ ├── tokens.go │ │ └── hints.go └── modules.txt ├── cmd ├── simple │ └── main.go ├── evolve │ └── main.go └── statement │ └── main.go ├── go.mod ├── .gitignore ├── mnist └── download_extract.sh ├── TODO.md ├── diagram_test.go ├── img ├── diagram.svg ├── evolved.svg ├── best.svg ├── result.svg ├── test.svg ├── before.svg ├── after.svg └── labels.svg ├── combine.go ├── teststatement.txt ├── af_test.go ├── utils.go ├── norm.go ├── .github └── workflows │ └── test.yml ├── LICENSE ├── config.go ├── neuron_test.go ├── statement_test.go ├── go.sum ├── README.md ├── diagram.go ├── af.go ├── network_test.go ├── neuron.go ├── evolve.go └── network.go /vendor/github.com/xyproto/tinysvg/go.sum: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /vendor/github.com/dave/jennifer/jen/generics.go: -------------------------------------------------------------------------------- 1 | package jen 2 | -------------------------------------------------------------------------------- /vendor/github.com/xyproto/swish/benchmark.sh: -------------------------------------------------------------------------------- 1 | 
#!/bin/sh 2 | go test -bench=. 3 | -------------------------------------------------------------------------------- /vendor/github.com/xyproto/tinysvg/go.mod: -------------------------------------------------------------------------------- 1 | module github.com/xyproto/tinysvg 2 | 3 | go 1.9 4 | -------------------------------------------------------------------------------- /vendor/github.com/xyproto/af/.travis.yml: -------------------------------------------------------------------------------- 1 | language: go 2 | 3 | go: 4 | - "1.11" 5 | - "1.12" 6 | - "1.13" 7 | -------------------------------------------------------------------------------- /vendor/github.com/xyproto/swish/.travis.yml: -------------------------------------------------------------------------------- 1 | language: go 2 | 3 | go: 4 | - "1.11" 5 | - "1.12" 6 | - "1.13" 7 | -------------------------------------------------------------------------------- /vendor/github.com/xyproto/af/go.mod: -------------------------------------------------------------------------------- 1 | module github.com/xyproto/af 2 | 3 | go 1.11 4 | 5 | require github.com/xyproto/swish v1.3.0 6 | -------------------------------------------------------------------------------- /cmd/simple/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "fmt" 5 | 6 | "github.com/xyproto/wann" 7 | ) 8 | 9 | func main() { 10 | fmt.Println(wann.NewNetwork()) 11 | } 12 | -------------------------------------------------------------------------------- /vendor/github.com/xyproto/tinysvg/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | !*.* 3 | !*/ 4 | *.o 5 | *.swp 6 | *.tmp 7 | *.bak 8 | *.pro.user 9 | *CMakeFiles* 10 | *.dblite 11 | *.so 12 | .vscode 13 | *.svg 14 | -------------------------------------------------------------------------------- /go.mod: 
-------------------------------------------------------------------------------- 1 | module github.com/xyproto/wann 2 | 3 | go 1.11 4 | 5 | require ( 6 | github.com/dave/jennifer v1.5.0 7 | github.com/xyproto/af v0.0.0-20191018214415-1a8887381bd3 8 | github.com/xyproto/tinysvg v1.0.1 9 | ) 10 | -------------------------------------------------------------------------------- /vendor/github.com/xyproto/swish/swish_asm.go: -------------------------------------------------------------------------------- 1 | //+build amd64 2 | 3 | package swish 4 | 5 | // SwishAssembly is the swish function, written in hand-optimized assembly 6 | // go: noescape 7 | func SwishAssembly(x float64) float64 8 | -------------------------------------------------------------------------------- /vendor/github.com/xyproto/swish/go.mod: -------------------------------------------------------------------------------- 1 | module github.com/xyproto/swish 2 | 3 | go 1.11 4 | 5 | require ( 6 | github.com/buger/goterm v0.0.0-20181115115552-c206103e1f37 7 | golang.org/x/sys v0.0.0-20190613101156-ab3f67ed278a // indirect 8 | ) 9 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | !*.* 3 | !*/ 4 | *.o 5 | *.swp 6 | *.tmp 7 | *.bak 8 | *.pro.user 9 | *.dblite 10 | *.so 11 | *.swo 12 | *.ao 13 | *.pyc 14 | *.orig 15 | *.pb.go 16 | *~ 17 | ._* 18 | *.nfs.* 19 | .vscode 20 | *CMakeFiles* 21 | _obj 22 | _test 23 | _testmain.go 24 | !img/* 25 | network.svg 26 | -------------------------------------------------------------------------------- /vendor/modules.txt: -------------------------------------------------------------------------------- 1 | # github.com/dave/jennifer v1.5.0 2 | github.com/dave/jennifer/jen 3 | # github.com/xyproto/af v0.0.0-20191018214415-1a8887381bd3 4 | github.com/xyproto/af 5 | # github.com/xyproto/swish v1.3.0 6 | github.com/xyproto/swish 7 | # 
github.com/xyproto/tinysvg v1.0.1 8 | github.com/xyproto/tinysvg 9 | -------------------------------------------------------------------------------- /vendor/github.com/xyproto/swish/.gitignore: -------------------------------------------------------------------------------- 1 | # Binaries for programs and plugins 2 | *.exe 3 | *.exe~ 4 | *.dll 5 | *.so 6 | *.dylib 7 | 8 | # Test binary, build with `go test -c` 9 | *.test 10 | 11 | # Output of the go coverage tool, specifically when used with LiteIDE 12 | *.out 13 | 14 | cmd/graph/graph 15 | cmd/precision/precision 16 | -------------------------------------------------------------------------------- /vendor/github.com/xyproto/af/.gitignore: -------------------------------------------------------------------------------- 1 | # Binaries for programs and plugins 2 | *.exe 3 | *.dll 4 | *.so 5 | *.dylib 6 | 7 | # Test binary, build with `go test -c` 8 | *.test 9 | 10 | # Output of the go coverage tool, specifically when used with LiteIDE 11 | *.out 12 | 13 | # Project-local glide cache, RE: https://github.com/Masterminds/glide/issues/736 14 | .glide/ 15 | 16 | .DS_Store 17 | -------------------------------------------------------------------------------- /vendor/github.com/xyproto/af/go.sum: -------------------------------------------------------------------------------- 1 | github.com/buger/goterm v0.0.0-20181115115552-c206103e1f37/go.mod h1:u9UyCz2eTrSGy6fbupqJ54eY5c4IC8gREQ1053dK12U= 2 | github.com/xyproto/swish v1.3.0 h1:qiVl2UeqkMqDJKjRLVwRprD7Vzz2pbNRcWTb3s0TpjE= 3 | github.com/xyproto/swish v1.3.0/go.mod h1:IVhz2R80pNsPaSxbEcqR84i4MVL0b06HN7KIMhv38WA= 4 | golang.org/x/sys v0.0.0-20190613101156-ab3f67ed278a/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= 5 | -------------------------------------------------------------------------------- /vendor/github.com/xyproto/swish/go.sum: -------------------------------------------------------------------------------- 1 | github.com/buger/goterm 
v0.0.0-20181115115552-c206103e1f37 h1:uxxtrnACqI9zK4ENDMf0WpXfUsHP5V8liuq5QdgDISU= 2 | github.com/buger/goterm v0.0.0-20181115115552-c206103e1f37/go.mod h1:u9UyCz2eTrSGy6fbupqJ54eY5c4IC8gREQ1053dK12U= 3 | golang.org/x/sys v0.0.0-20190613101156-ab3f67ed278a h1:sPlwkA5W19gtxRApEyGyqWg4ngTrMzOJ43fOsWrgYEE= 4 | golang.org/x/sys v0.0.0-20190613101156-ab3f67ed278a/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= 5 | -------------------------------------------------------------------------------- /mnist/download_extract.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | echo 'downloading training set images' 4 | curl -O 'http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz' 5 | 6 | echo 'downloading training set labels' 7 | curl -O 'http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz' 8 | 9 | echo 'downloading test set images' 10 | curl -O 'http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz' 11 | 12 | echo 'downloading test set labels' 13 | curl -O 'http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz' 14 | 15 | echo 'extracting' 16 | gunzip -v *.gz 17 | -------------------------------------------------------------------------------- /vendor/github.com/dave/jennifer/jen/add.go: -------------------------------------------------------------------------------- 1 | package jen 2 | 3 | // Add appends the provided items to the statement. 4 | func Add(code ...Code) *Statement { 5 | return newStatement().Add(code...) 6 | } 7 | 8 | // Add appends the provided items to the statement. 9 | func (g *Group) Add(code ...Code) *Statement { 10 | s := Add(code...) 11 | g.items = append(g.items, s) 12 | return s 13 | } 14 | 15 | // Add appends the provided items to the statement. 16 | func (s *Statement) Add(code ...Code) *Statement { 17 | *s = append(*s, code...) 
18 | return s 19 | } 20 | -------------------------------------------------------------------------------- /TODO.md: -------------------------------------------------------------------------------- 1 | # TODO 2 | 3 | - [ ] Train and test with the Mnist dataset. 4 | - [ ] Fix any remaining issues with drawing SVG diagrams. 5 | - [ ] Fix any remaining issues with generating expressions. 6 | - [ ] Draw an "O" on the output node in the diagram. 7 | - [ ] Fix an issue with mutating the Network structs (the Neurons need to be mutated too. And there are pointers everywhere to them). 8 | - [ ] Store all neurons within the network, but keep pointers to input nodes (and the output node might work). 9 | Then all those neurons can be modified when the Network mutates, but the pointers can be kept the same. 10 | -------------------------------------------------------------------------------- /vendor/github.com/xyproto/tinysvg/README.md: -------------------------------------------------------------------------------- 1 | # tinysvg [![Go Report Card](https://goreportcard.com/badge/github.com/xyproto/tinysvg)](https://goreportcard.com/report/github.com/xyproto/tinysvg) [![GoDoc](https://godoc.org/github.com/xyproto/tinysvg?status.svg)](https://godoc.org/github.com/xyproto/tinysvg) 2 | 3 | Construct SVG documents and images using Go. 4 | 5 | This package mainly uses `[]byte` slices instead of strings, and does not indent the generated SVG data, for performance and compactness. 6 | 7 | ## General info 8 | 9 | * Version: 1.0.1 10 | * Author: Alexander F. Rødseth <xyproto@archlinux.org> 11 | * License: MIT 12 | -------------------------------------------------------------------------------- /vendor/github.com/dave/jennifer/jen/do.go: -------------------------------------------------------------------------------- 1 | package jen 2 | 3 | // Do calls the provided function with the statement as a parameter. Use for 4 | // embedding logic. 
5 | func Do(f func(*Statement)) *Statement { 6 | return newStatement().Do(f) 7 | } 8 | 9 | // Do calls the provided function with the statement as a parameter. Use for 10 | // embedding logic. 11 | func (g *Group) Do(f func(*Statement)) *Statement { 12 | s := Do(f) 13 | g.items = append(g.items, s) 14 | return s 15 | } 16 | 17 | // Do calls the provided function with the statement as a parameter. Use for 18 | // embedding logic. 19 | func (s *Statement) Do(f func(*Statement)) *Statement { 20 | f(s) 21 | return s 22 | } 23 | -------------------------------------------------------------------------------- /vendor/github.com/xyproto/swish/swish_amd64.s: -------------------------------------------------------------------------------- 1 | #include "textflag.h" 2 | 3 | DATA expodata<>+0(SB)/8, $1.0 4 | DATA expodata<>+8(SB)/8, $-0.00390625 5 | GLOBL expodata<>+0(SB), RODATA, $16 6 | 7 | TEXT ·SwishAssembly(SB),NOSPLIT|NOPTR,$0-16 8 | // x+0(FP) is the given argument 9 | MOVSD x+0(FP), X1 10 | MOVSD x+0(FP), X2 11 | // x1 *= -0.00390625 which is (1/256) 12 | MULSD expodata<>+8(SB), X1 13 | // x1 += 1.0 14 | MOVSD expodata<>+0(SB), X3 15 | ADDSD X3, X1 16 | // x1 *= x1 ... 
17 | MULSD X1, X1 18 | MULSD X1, X1 19 | MULSD X1, X1 20 | MULSD X1, X1 21 | MULSD X1, X1 22 | MULSD X1, X1 23 | MULSD X1, X1 24 | MULSD X1, X1 25 | // x1 += 1.0 26 | ADDSD X3, X1 27 | // x2 /= x1 28 | DIVSD X1, X2 29 | // return x2 30 | MOVSD X2, ret+8(FP) 31 | // done, jump back 32 | RET 33 | -------------------------------------------------------------------------------- /diagram_test.go: -------------------------------------------------------------------------------- 1 | package wann 2 | 3 | import ( 4 | "math/rand" 5 | "os" 6 | "testing" 7 | ) 8 | 9 | func TestDiagram(t *testing.T) { 10 | rand.Seed(commonSeed) 11 | net := NewNetwork(&Config{ 12 | inputs: 5, 13 | InitialConnectionRatio: 0.5, 14 | sharedWeight: 0.5, 15 | }) 16 | 17 | // Set a few activation functions 18 | net.AllNodes[net.InputNodes[0]].ActivationFunction = Linear 19 | net.AllNodes[net.InputNodes[1]].ActivationFunction = Swish 20 | net.AllNodes[net.InputNodes[2]].ActivationFunction = Gauss 21 | net.AllNodes[net.InputNodes[3]].ActivationFunction = Sigmoid 22 | net.AllNodes[net.InputNodes[4]].ActivationFunction = ReLU 23 | 24 | // Save the diagram as an image 25 | err := net.WriteSVG("test.svg") 26 | if err != nil { 27 | t.Error(err) 28 | } 29 | os.Remove("test.svg") 30 | } 31 | -------------------------------------------------------------------------------- /img/diagram.svg: -------------------------------------------------------------------------------- 1 | generated with github.com/xyproto/wann -------------------------------------------------------------------------------- /combine.go: -------------------------------------------------------------------------------- 1 | package wann 2 | 3 | import ( 4 | "github.com/dave/jennifer/jen" 5 | ) 6 | 7 | // ActivationStatement creates an activation function statement, given a weight and input statements 8 | // returns: activationFunction(input0 * w + input1 * w + ...) 
9 | // The function calling this function is responsible for inserting network input values into the network input nodes. 10 | func ActivationStatement(af ActivationFunctionIndex, w float64, inputStatements []*jen.Statement) *jen.Statement { 11 | // activationFunction(input0 * w + input1 * w + ...) 12 | weightedSum := jen.Empty() 13 | for i, inputStatement := range inputStatements { 14 | if i == 0 { 15 | // first 16 | weightedSum.Add(inputStatement).Op("*").Lit(w) 17 | } else { 18 | // the rest, same as above, but with a leading "+" 19 | weightedSum.Op("+").Add(inputStatement).Op("*").Lit(w) 20 | } 21 | } 22 | return af.Statement(weightedSum) 23 | } 24 | -------------------------------------------------------------------------------- /teststatement.txt: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "fmt" 5 | "math" 6 | ) 7 | 8 | func f(inputData []float64) float64 { 9 | return -(math.Log(1.0+math.Exp(inputData[0]))*1.999999999999592 + math.Exp(-(math.Pow(inputData[1], 2.0))/2.0)*1.999999999999592 + math.Pow(inputData[2], 2.0)*1.999999999999592 + math.Sin((inputData[3])*math.Pi)*1.999999999999592 + math.Abs(inputData[4])*1.999999999999592) 10 | } 11 | 12 | func main() { 13 | up := []float64{ 14 | 0.0, 1.0, 0.0, // o 15 | 1.0, 1.0, 1.0} // ooo 16 | 17 | down := []float64{ 18 | 1.0, 1.0, 1.0, // ooo 19 | 0.0, 1.0, 0.0} // o 20 | 21 | left := []float64{ 22 | 1.0, 1.0, 1.0, // ooo 23 | 0.0, 0.0, 1.0} // o 24 | 25 | right := []float64{ 26 | 1.0, 1.0, 1.0, // ooo 27 | 1.0, 0.0, 0.0} // o 28 | 29 | fmt.Println("up score", f(up)) 30 | fmt.Println("down score", f(down)) 31 | fmt.Println("left score", f(left)) 32 | fmt.Println("right score", f(right)) 33 | } 34 | 35 | -------------------------------------------------------------------------------- /vendor/github.com/dave/jennifer/jen/reserved.go: -------------------------------------------------------------------------------- 1 | package jen 2 | 3 | var 
reserved = []string{ 4 | /* keywords */ 5 | "break", "default", "func", "interface", "select", "case", "defer", "go", "map", "struct", "chan", "else", "goto", "package", "switch", "const", "fallthrough", "if", "range", "type", "continue", "for", "import", "return", "var", 6 | /* predeclared */ 7 | "bool", "byte", "complex64", "complex128", "error", "float32", "float64", "int", "int8", "int16", "int32", "int64", "rune", "string", "uint", "uint8", "uint16", "uint32", "uint64", "uintptr", "true", "false", "iota", "nil", "append", "cap", "close", "complex", "copy", "delete", "imag", "len", "make", "new", "panic", "print", "println", "real", "recover", 8 | /* common variables */ 9 | "err", 10 | } 11 | 12 | // IsReservedWord returns if this is a reserved word in go 13 | func IsReservedWord(alias string) bool { 14 | for _, name := range reserved { 15 | if alias == name { 16 | return true 17 | } 18 | } 19 | return false 20 | } 21 | -------------------------------------------------------------------------------- /af_test.go: -------------------------------------------------------------------------------- 1 | package wann 2 | 3 | import ( 4 | "fmt" 5 | 6 | "github.com/dave/jennifer/jen" 7 | ) 8 | 9 | func ExampleActivationFunctionIndex_Call() { 10 | fmt.Println(Gauss.Call(2.0)) 11 | // Output: 12 | // 0.13427659965015956 13 | } 14 | 15 | func ExampleGauss_Statement() { 16 | statement := Gauss.Statement(jen.Id("x")) 17 | fmt.Println(statement.GoString()) 18 | result, err := RunStatementX(statement, 0.5) 19 | if err != nil { 20 | panic(err) 21 | } 22 | fmt.Println(result) 23 | // Output: 24 | // math.Exp(-(math.Pow(x, 2.0)) / 2.0) 25 | // 0.8824969025845955 26 | } 27 | 28 | func ExampleActivationFunctionIndex_GoRun() { 29 | // Run the Gauss function directly 30 | fmt.Println(ActivationFunctions[Gauss](0.5)) 31 | // Use Jennifer to generate a source file just for running the Gauss function, then use "go run" and fetch the result 32 | if result, err := Gauss.GoRun(0.5); err == 
nil { // no error 33 | fmt.Println(result) 34 | } 35 | // Output: 36 | // 0.8824699625576026 37 | // 0.8824969025845955 38 | } 39 | -------------------------------------------------------------------------------- /utils.go: -------------------------------------------------------------------------------- 1 | package wann 2 | 3 | import ( 4 | "sort" 5 | ) 6 | 7 | // Pair is used for sorting dictionaries by value. 8 | // Thanks https://stackoverflow.com/a/18695740/131264 9 | type Pair struct { 10 | Key int 11 | Value float64 12 | } 13 | 14 | // PairList is a slice of Pair 15 | type PairList []Pair 16 | 17 | func (p PairList) Len() int { return len(p) } 18 | func (p PairList) Less(i, j int) bool { return p[i].Value < p[j].Value } 19 | func (p PairList) Swap(i, j int) { p[i], p[j] = p[j], p[i] } 20 | 21 | // SortByValue sorts a map[int]float64 by value 22 | func SortByValue(m map[int]float64) PairList { 23 | pl := make(PairList, len(m)) 24 | i := 0 25 | for k, v := range m { 26 | pl[i] = Pair{k, v} 27 | i++ 28 | } 29 | sort.Sort(sort.Reverse(pl)) 30 | return pl 31 | } 32 | 33 | // In returns true if this NeuronIndex is in the given *[]NeuronIndex slice 34 | func (ni NeuronIndex) In(nodes *[]NeuronIndex) bool { 35 | for _, ni2 := range *nodes { 36 | if ni2 == ni { 37 | return true 38 | } 39 | } 40 | return false 41 | } 42 | -------------------------------------------------------------------------------- /vendor/github.com/xyproto/swish/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Alexander F. 
Rødseth 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /norm.go: -------------------------------------------------------------------------------- 1 | package wann 2 | 3 | // NormalizationInfo contains if and how the score function should be normalized 4 | type NormalizationInfo struct { 5 | shouldNormalize bool 6 | mul, add float64 7 | } 8 | 9 | // NewNormalizationInfo returns a new struct, containing if and how the score function should be normalized 10 | func NewNormalizationInfo(enable bool) *NormalizationInfo { 11 | return &NormalizationInfo{enable, 0.0, 1.0} 12 | } 13 | 14 | // Enable signifies that normalization is enabled when this struct is used 15 | func (norm *NormalizationInfo) Enable() { 16 | norm.shouldNormalize = true 17 | } 18 | 19 | // Disable signifies that normalization is disabled when this struct is used 20 | func (norm *NormalizationInfo) Disable() { 21 | norm.shouldNormalize = false 22 | } 23 | 24 | // Get retrieves the multiplication and addition numbers that can be used for normalization 25 | func (norm *NormalizationInfo) Get() (float64, float64) { 26 | return norm.mul, norm.add 27 | } 28 | 29 | // Set sets the multiplication and addition numbers that can be used for normalization 30 | func (norm *NormalizationInfo) Set(mul, add float64) { 31 | norm.mul = mul 32 | norm.add = add 33 | } 34 | -------------------------------------------------------------------------------- /img/evolved.svg: -------------------------------------------------------------------------------- 1 | generated with github.com/xyproto/wann -------------------------------------------------------------------------------- /vendor/github.com/xyproto/af/README.md: -------------------------------------------------------------------------------- 1 | # af [![Build Status](https://travis-ci.org/xyproto/af.svg?branch=master)](https://travis-ci.org/xyproto/af) [![Go Report 
Card](https://goreportcard.com/badge/github.com/xyproto/af)](https://goreportcard.com/report/github.com/xyproto/af) [![GoDoc](https://godoc.org/github.com/xyproto/af?status.svg)](https://godoc.org/github.com/xyproto/af) 2 | 3 | Activation functions for neural networks. 4 | 5 | These activation functions are included: 6 | 7 | * Swish (`x / (1 + exp(-x))`) 8 | * Sigmoid (`1 / (1 + exp(-x))`) 9 | * SoftPlus (`log(1 + exp(x))`) 10 | * Gaussian01 (`exp(-(x * x) / 2.0)`) 11 | * Sin (`math.Sin(math.Pi * x)`) 12 | * Cos (`math.Cos(math.Pi * x)`) 13 | * Linear (`x`) 14 | * Inv (`-x`) 15 | * ReLU (`x >= 0 ? x : 0`) 16 | * Squared (`x * x`) 17 | 18 | These `math` functions are included just for convenience: 19 | 20 | * Abs (`math.Abs`) 21 | * Tanh (`math.Tanh`) 22 | 23 | One function that takes two arguments is also included: 24 | 25 | * PReLU (`x >= 0 ? x : x * a`) 26 | 27 | ## Requirements 28 | 29 | * Go 1.11 or later. 30 | 31 | ## General information 32 | 33 | * License: MIT 34 | * Version: 0.3.2 35 | * Author: Alexander F. Rødseth <xyproto@archlinux.org> 36 | -------------------------------------------------------------------------------- /.github/workflows/test.yml: -------------------------------------------------------------------------------- 1 | on: [push, pull_request] 2 | name: Build 3 | env: 4 | GO111MODULE: on 5 | jobs: 6 | test: 7 | strategy: 8 | matrix: 9 | go-version: [1.13.x, 1.14.x, 1.15.x, 1.16.x, 1.17.x] 10 | os: [ubuntu-latest, macos-latest] 11 | runs-on: ${{ matrix.os }} 12 | steps: 13 | - name: Install Go 14 | uses: actions/setup-go@v2 15 | with: 16 | go-version: ${{ matrix.go-version }} 17 | - name: Checkout code 18 | uses: actions/checkout@v2 19 | - name: Test 20 | run: go test ./... 
21 | 22 | test-cache: 23 | runs-on: ubuntu-latest 24 | steps: 25 | - name: Install Go 26 | uses: actions/setup-go@v2 27 | with: 28 | go-version: 1.17.x 29 | - name: Checkout code 30 | uses: actions/checkout@v2 31 | - uses: actions/cache@v2 32 | with: 33 | path: | 34 | ~/go/pkg/mod # Module download cache 35 | ~/.cache/go-build # Build cache (Linux) 36 | ~/Library/Caches/go-build # Build cache (Mac) 37 | '%LocalAppData%\go-build' # Build cache (Windows) 38 | key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }} 39 | restore-keys: | 40 | ${{ runner.os }}-go- 41 | - name: Test 42 | run: go test ./... 43 | -------------------------------------------------------------------------------- /vendor/github.com/xyproto/af/af.go: -------------------------------------------------------------------------------- 1 | // Package af provides several activation functions that can be used in neural networks 2 | package af 3 | 4 | import ( 5 | "github.com/xyproto/swish" 6 | "math" 7 | ) 8 | 9 | // The swish package offers optimized Swish, Sigmoid 10 | // SoftPlus and Gaussian01 activation functions 11 | var ( 12 | Sigmoid = swish.Sigmoid 13 | Swish = swish.Swish 14 | SoftPlus = swish.SoftPlus 15 | Gaussian01 = swish.Gaussian01 16 | Linear = func(x float64) float64 { return x } 17 | Inv = func(x float64) float64 { return -x } 18 | Sin = func(x float64) float64 { return math.Sin(math.Pi * x) } 19 | Cos = func(x float64) float64 { return math.Cos(math.Pi * x) } 20 | Squared = func(x float64) float64 { return x * x } 21 | Tanh = math.Tanh 22 | Abs = math.Abs 23 | ) 24 | 25 | // Step function 26 | func Step(x float64) float64 { 27 | if x >= 0 { 28 | return 1 29 | } 30 | return 0 31 | } 32 | 33 | // ReLU is the "rectified linear unit" 34 | // `x >= 0 ? x : 0` 35 | func ReLU(x float64) float64 { 36 | if x >= 0 { 37 | return x 38 | } 39 | return 0 40 | } 41 | 42 | // PReLU is the parametric rectified linear unit. 43 | // `x >= 0 ? 
x : a * x` 44 | func PReLU(x, a float64) float64 { 45 | if x >= 0 { 46 | return x 47 | } 48 | return a * x 49 | } 50 | -------------------------------------------------------------------------------- /img/best.svg: -------------------------------------------------------------------------------- 1 | generated with github.com/xyproto/wann -------------------------------------------------------------------------------- /img/result.svg: -------------------------------------------------------------------------------- 1 | generated with github.com/xyproto/wann -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright 2022 Alexander F. Rødseth 2 | 3 | Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: 4 | 5 | 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. 6 | 7 | 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. 8 | 9 | 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. 10 | 11 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 12 | -------------------------------------------------------------------------------- /vendor/github.com/dave/jennifer/jen/tag.go: -------------------------------------------------------------------------------- 1 | package jen 2 | 3 | import ( 4 | "fmt" 5 | "io" 6 | "sort" 7 | "strconv" 8 | ) 9 | 10 | // Tag renders a struct tag 11 | func Tag(items map[string]string) *Statement { 12 | return newStatement().Tag(items) 13 | } 14 | 15 | // Tag renders a struct tag 16 | func (g *Group) Tag(items map[string]string) *Statement { 17 | // notest 18 | // don't think this can ever be used in valid code? 
19 | s := Tag(items) 20 | g.items = append(g.items, s) 21 | return s 22 | } 23 | 24 | // Tag renders a struct tag 25 | func (s *Statement) Tag(items map[string]string) *Statement { 26 | c := tag{ 27 | items: items, 28 | } 29 | *s = append(*s, c) 30 | return s 31 | } 32 | 33 | type tag struct { 34 | items map[string]string 35 | } 36 | 37 | func (t tag) isNull(f *File) bool { 38 | return len(t.items) == 0 39 | } 40 | 41 | func (t tag) render(f *File, w io.Writer, s *Statement) error { 42 | 43 | if t.isNull(f) { 44 | // notest 45 | // render won't be called if t is null 46 | return nil 47 | } 48 | 49 | var str string 50 | 51 | var sorted []string 52 | for k := range t.items { 53 | sorted = append(sorted, k) 54 | } 55 | sort.Strings(sorted) 56 | 57 | for _, k := range sorted { 58 | v := t.items[k] 59 | if len(str) > 0 { 60 | str += " " 61 | } 62 | str += fmt.Sprintf(`%s:%q`, k, v) 63 | } 64 | 65 | if strconv.CanBackquote(str) { 66 | str = "`" + str + "`" 67 | } else { 68 | str = strconv.Quote(str) 69 | } 70 | 71 | if _, err := w.Write([]byte(str)); err != nil { 72 | return err 73 | } 74 | 75 | return nil 76 | } 77 | -------------------------------------------------------------------------------- /vendor/github.com/xyproto/swish/swish.go: -------------------------------------------------------------------------------- 1 | package swish 2 | 3 | import "math" 4 | 5 | // Thanks https://codingforspeed.com/using-faster-exponential-approximation/ 6 | func exp256(x float64) float64 { 7 | x = 1.0 + x/256.0 8 | x *= x 9 | x *= x 10 | x *= x 11 | x *= x 12 | x *= x 13 | x *= x 14 | x *= x 15 | x *= x 16 | return x 17 | } 18 | 19 | // Swish is the x / (1 + exp(-x)) activation function, using exp256 20 | func Swish(x float64) float64 { 21 | return x / (1.0 + exp256(-x)) 22 | } 23 | 24 | // Sigmoid is the 1 / (1 + exp(-x)) activation function, using exp256 25 | func Sigmoid(x float64) float64 { 26 | // Uses exp256 instead of math.Exp 27 | return 1.0 / (1.0 + exp256(-x)) 28 | } 29 | 
30 | // SoftPlus is the log(1 + exp(x)) function, using exp256 31 | func SoftPlus(x float64) float64 { 32 | return math.Log(1.0 + exp256(x)) 33 | } 34 | 35 | // Gaussian01 is the Gaussian function with mean 0 and sigma 1, using exp256 36 | func Gaussian01(x float64) float64 { 37 | return exp256(-(x * x) / 2.0) 38 | } 39 | 40 | // SwishPrecise is the x / (1 + exp(-x)) activation function, using math.Exp 41 | func SwishPrecise(x float64) float64 { 42 | return x / (1.0 + math.Exp(-x)) 43 | } 44 | 45 | // SigmoidPrecise is the 1 / (1 + exp(-x)) activation function, using math.Exp 46 | func SigmoidPrecise(x float64) float64 { 47 | return 1.0 / (1.0 + math.Exp(-x)) 48 | } 49 | 50 | // SoftPlusPrecise is the log(1 + exp(x)) function, using math.Exp 51 | func SoftPlusPrecise(x float64) float64 { 52 | return math.Log(1.0 + math.Exp(x)) 53 | } 54 | 55 | // Gaussian01Precise is the Gaussian function with mean 0 and sigma 1, using math.Exp 56 | func Gaussian01Precise(x float64) float64 { 57 | return math.Exp(-(x * x) / 2.0) 58 | } 59 | -------------------------------------------------------------------------------- /vendor/github.com/dave/jennifer/jen/dict.go: -------------------------------------------------------------------------------- 1 | package jen 2 | 3 | import ( 4 | "bytes" 5 | "io" 6 | "sort" 7 | ) 8 | 9 | // Dict renders as key/value pairs. Use with Values for map or composite 10 | // literals. 11 | type Dict map[Code]Code 12 | 13 | // DictFunc executes a func(Dict) to generate the value. Use with Values for 14 | // map or composite literals. 
15 | func DictFunc(f func(Dict)) Dict { 16 | d := Dict{} 17 | f(d) 18 | return d 19 | } 20 | 21 | func (d Dict) render(f *File, w io.Writer, s *Statement) error { 22 | first := true 23 | // must order keys to ensure repeatable source 24 | type kv struct { 25 | k Code 26 | v Code 27 | } 28 | lookup := map[string]kv{} 29 | keys := []string{} 30 | for k, v := range d { 31 | if k.isNull(f) || v.isNull(f) { 32 | continue 33 | } 34 | buf := &bytes.Buffer{} 35 | if err := k.render(f, buf, nil); err != nil { 36 | return err 37 | } 38 | keys = append(keys, buf.String()) 39 | lookup[buf.String()] = kv{k: k, v: v} 40 | } 41 | sort.Strings(keys) 42 | for _, key := range keys { 43 | k := lookup[key].k 44 | v := lookup[key].v 45 | if first && len(keys) > 1 { 46 | if _, err := w.Write([]byte("\n")); err != nil { 47 | return err 48 | } 49 | first = false 50 | } 51 | if err := k.render(f, w, nil); err != nil { 52 | return err 53 | } 54 | if _, err := w.Write([]byte(":")); err != nil { 55 | return err 56 | } 57 | if err := v.render(f, w, nil); err != nil { 58 | return err 59 | } 60 | if len(keys) > 1 { 61 | if _, err := w.Write([]byte(",\n")); err != nil { 62 | return err 63 | } 64 | } 65 | } 66 | return nil 67 | } 68 | 69 | func (d Dict) isNull(f *File) bool { 70 | if d == nil || len(d) == 0 { 71 | return true 72 | } 73 | for k, v := range d { 74 | if !k.isNull(f) && !v.isNull(f) { 75 | // if any of the key/value pairs are both not null, the Dict is not 76 | // null 77 | return false 78 | } 79 | } 80 | return true 81 | } 82 | -------------------------------------------------------------------------------- /img/test.svg: -------------------------------------------------------------------------------- 1 | generated with github.com/xyproto/wann -------------------------------------------------------------------------------- /config.go: -------------------------------------------------------------------------------- 1 | package wann 2 | 3 | import ( 4 | "fmt" 5 | "math/rand" 6 | "time" 
7 | ) 8 | 9 | // Config is a struct that is used when initializing new Network structs. 10 | // The idea is that referring to fields by name is more explicit, and that it can 11 | // be re-used in connection with having a configuration file, in the future. 12 | type Config struct { 13 | // Number of input neurons (inputs per slice of floats in inputData in the Evolve function) 14 | inputs int 15 | // When initializing a network, this is the probability that the node will be connected to the output node 16 | InitialConnectionRatio float64 17 | // sharedWeight is the weight that is shared by all nodes, since this is a Weight Agnostic Neural Network 18 | sharedWeight float64 19 | // How many generations to train for, at a maximum? 20 | Generations int 21 | // How large population sizes to use per generation? 22 | PopulationSize int 23 | // For how many generations should the training go on, without any improvement in the best score? Disabled if 0. 24 | MaxIterationsWithoutBestImprovement int 25 | // RandomSeed, for initializing the random number generator. The current time is used for the seed if this is set to 0. 26 | RandomSeed int64 27 | // Verbose output 28 | Verbose bool 29 | // Has the pseudo-random number generator been seeded and the activation function complexity been estimated yet? 
30 | initialized bool 31 | } 32 | 33 | // initialize the pseudo-random number generator, either using the config.RandomSeed or the time 34 | func (config *Config) initRandom() { 35 | randomSeed := config.RandomSeed 36 | if config.RandomSeed == 0 { 37 | randomSeed = time.Now().UTC().UnixNano() 38 | } 39 | if config.Verbose { 40 | fmt.Println("Using random seed:", randomSeed) 41 | } 42 | // Initialize the pseudo-random number generator 43 | rand.Seed(randomSeed) 44 | } 45 | 46 | // Init will initialize the pseudo-random number generator and estimate the complexity of the available activation functions 47 | func (config *Config) Init() { 48 | config.initRandom() 49 | config.estimateComplexity() 50 | config.initialized = true 51 | } 52 | -------------------------------------------------------------------------------- /cmd/evolve/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "fmt" 5 | "os" 6 | 7 | "github.com/xyproto/wann" 8 | ) 9 | 10 | func main() { 11 | // Here are four shapes, representing: up, down, left and right: 12 | 13 | up := []float64{ 14 | 0.0, 1.0, 0.0, // o 15 | 1.0, 1.0, 1.0} // ooo 16 | 17 | down := []float64{ 18 | 1.0, 1.0, 1.0, // ooo 19 | 0.0, 1.0, 0.0} // o 20 | 21 | left := []float64{ 22 | 1.0, 1.0, 1.0, // ooo 23 | 0.0, 0.0, 1.0} // o 24 | 25 | right := []float64{ 26 | 1.0, 1.0, 1.0, // ooo 27 | 1.0, 0.0, 0.0} // o 28 | 29 | // Prepare the input data as a 2D slice 30 | inputData := [][]float64{ 31 | up, 32 | down, 33 | left, 34 | right, 35 | } 36 | 37 | // Target scores for: up, down, left, right 38 | correctResultsForUp := []float64{1.0, -1.0, -1.0, -1.0} 39 | 40 | // Prepare a neural network configuration struct 41 | config := &wann.Config{ 42 | InitialConnectionRatio: 0.2, 43 | Generations: 2000, 44 | PopulationSize: 500, 45 | Verbose: true, 46 | } 47 | 48 | // Evolve a network, using the input data and the sought after results 49 | trainedNetwork, err := 
config.Evolve(inputData, correctResultsForUp) 50 | if err != nil { 51 | fmt.Fprintf(os.Stderr, "error: %s\n", err) 52 | os.Exit(1) 53 | } 54 | 55 | // Now to test the trained network on 4 different inputs and see if it passes the test 56 | upScore := trainedNetwork.Evaluate(up) 57 | downScore := trainedNetwork.Evaluate(down) 58 | leftScore := trainedNetwork.Evaluate(left) 59 | rightScore := trainedNetwork.Evaluate(right) 60 | 61 | if config.Verbose { 62 | if upScore > downScore && upScore > leftScore && upScore > rightScore { 63 | fmt.Println("Network training complete, the results are good.") 64 | } else { 65 | fmt.Println("Network training complete, but the results did not pass the test.") 66 | } 67 | } 68 | 69 | // Save the trained network as an SVG image 70 | if config.Verbose { 71 | fmt.Print("Writing network.svg...") 72 | } 73 | if err := trainedNetwork.WriteSVG("network.svg"); err != nil { 74 | fmt.Fprintf(os.Stderr, "error: %s\n", err) 75 | os.Exit(1) 76 | } 77 | if config.Verbose { 78 | fmt.Println("ok") 79 | } 80 | } 81 | -------------------------------------------------------------------------------- /vendor/github.com/dave/jennifer/jen/custom.go: -------------------------------------------------------------------------------- 1 | package jen 2 | 3 | // Options specifies options for the Custom method 4 | type Options struct { 5 | Open string 6 | Close string 7 | Separator string 8 | Multi bool 9 | } 10 | 11 | // Custom renders a customized statement list. Pass in options to specify multi-line, and tokens for open, close, separator. 12 | func Custom(options Options, statements ...Code) *Statement { 13 | return newStatement().Custom(options, statements...) 14 | } 15 | 16 | // Custom renders a customized statement list. Pass in options to specify multi-line, and tokens for open, close, separator. 17 | func (g *Group) Custom(options Options, statements ...Code) *Statement { 18 | s := Custom(options, statements...) 
19 | g.items = append(g.items, s) 20 | return s 21 | } 22 | 23 | // Custom renders a customized statement list. Pass in options to specify multi-line, and tokens for open, close, separator. 24 | func (s *Statement) Custom(options Options, statements ...Code) *Statement { 25 | g := &Group{ 26 | close: options.Close, 27 | items: statements, 28 | multi: options.Multi, 29 | name: "custom", 30 | open: options.Open, 31 | separator: options.Separator, 32 | } 33 | *s = append(*s, g) 34 | return s 35 | } 36 | 37 | // CustomFunc renders a customized statement list. Pass in options to specify multi-line, and tokens for open, close, separator. 38 | func CustomFunc(options Options, f func(*Group)) *Statement { 39 | return newStatement().CustomFunc(options, f) 40 | } 41 | 42 | // CustomFunc renders a customized statement list. Pass in options to specify multi-line, and tokens for open, close, separator. 43 | func (g *Group) CustomFunc(options Options, f func(*Group)) *Statement { 44 | s := CustomFunc(options, f) 45 | g.items = append(g.items, s) 46 | return s 47 | } 48 | 49 | // CustomFunc renders a customized statement list. Pass in options to specify multi-line, and tokens for open, close, separator. 
50 | func (s *Statement) CustomFunc(options Options, f func(*Group)) *Statement { 51 | g := &Group{ 52 | close: options.Close, 53 | multi: options.Multi, 54 | name: "custom", 55 | open: options.Open, 56 | separator: options.Separator, 57 | } 58 | f(g) 59 | *s = append(*s, g) 60 | return s 61 | } 62 | -------------------------------------------------------------------------------- /cmd/statement/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "fmt" 5 | "os" 6 | 7 | "github.com/xyproto/wann" 8 | ) 9 | 10 | func main() { 11 | // Here are four shapes, representing: up, down, left and right: 12 | 13 | up := []float64{ 14 | 0.0, 1.0, 0.0, // o 15 | 1.0, 1.0, 1.0} // ooo 16 | 17 | down := []float64{ 18 | 1.0, 1.0, 1.0, // ooo 19 | 0.0, 1.0, 0.0} // o 20 | 21 | left := []float64{ 22 | 1.0, 1.0, 1.0, // ooo 23 | 0.0, 0.0, 1.0} // o 24 | 25 | right := []float64{ 26 | 1.0, 1.0, 1.0, // ooo 27 | 1.0, 0.0, 0.0} // o 28 | 29 | // Prepare the input data as a 2D slice 30 | inputData := [][]float64{ 31 | up, 32 | down, 33 | left, 34 | right, 35 | } 36 | 37 | // Which of the elements in the input data are we trying to identify? 
38 | correctResultsForUp := []float64{0.0, -1.0, -1.0, -1.0} 39 | 40 | // Prepare a neural network configuration struct 41 | config := &wann.Config{ 42 | InitialConnectionRatio: 0.2, 43 | Generations: 2000, 44 | PopulationSize: 500, 45 | Verbose: true, 46 | } 47 | 48 | // Evolve a network, using the input data and the sought after results 49 | trainedNetwork, err := config.Evolve(inputData, correctResultsForUp) 50 | if err != nil { 51 | fmt.Fprintf(os.Stderr, "error: %s\n", err) 52 | os.Exit(1) 53 | } 54 | 55 | // Now to test the trained network on 4 different inputs and see if it passes the test 56 | upScore := trainedNetwork.Evaluate(up) 57 | downScore := trainedNetwork.Evaluate(down) 58 | leftScore := trainedNetwork.Evaluate(left) 59 | rightScore := trainedNetwork.Evaluate(right) 60 | 61 | if config.Verbose { 62 | if upScore > downScore && upScore > leftScore && upScore > rightScore { 63 | fmt.Println("Network training complete, the results are good.") 64 | } else { 65 | fmt.Println("Network training complete, but the results did not pass the test.") 66 | } 67 | } 68 | 69 | // Save the trained network as an SVG image 70 | if config.Verbose { 71 | fmt.Print("Writing network.svg...") 72 | } 73 | if err := trainedNetwork.WriteSVG("network.svg"); err != nil { 74 | fmt.Fprintf(os.Stderr, "error: %s\n", err) 75 | os.Exit(1) 76 | } 77 | if config.Verbose { 78 | fmt.Println("ok") 79 | } 80 | 81 | fmt.Println("Statement for this network (a work in progress, not correct yet):") 82 | 83 | networkStatement, err := trainedNetwork.StatementWithInputDataVariables() 84 | if err != nil { 85 | fmt.Fprintf(os.Stderr, "error: %s\n", err) 86 | os.Exit(1) 87 | } 88 | fmt.Println(wann.Render(networkStatement)) 89 | } 90 | -------------------------------------------------------------------------------- /img/before.svg: -------------------------------------------------------------------------------- 1 | generated with github.com/xyproto/wann 
-------------------------------------------------------------------------------- /vendor/github.com/dave/jennifer/jen/statement.go: -------------------------------------------------------------------------------- 1 | package jen 2 | 3 | import ( 4 | "bytes" 5 | "fmt" 6 | "go/format" 7 | "io" 8 | ) 9 | 10 | // Statement represents a simple list of code items. When rendered the items 11 | // are separated by spaces. 12 | type Statement []Code 13 | 14 | func newStatement() *Statement { 15 | return &Statement{} 16 | } 17 | 18 | // Clone makes a copy of the Statement, so further tokens can be appended 19 | // without affecting the original. 20 | func (s *Statement) Clone() *Statement { 21 | return &Statement{s} 22 | } 23 | 24 | func (s *Statement) previous(c Code) Code { 25 | index := -1 26 | for i, item := range *s { 27 | if item == c { 28 | index = i 29 | break 30 | } 31 | } 32 | if index > 0 { 33 | return (*s)[index-1] 34 | } 35 | return nil 36 | } 37 | 38 | func (s *Statement) isNull(f *File) bool { 39 | if s == nil { 40 | return true 41 | } 42 | for _, c := range *s { 43 | if !c.isNull(f) { 44 | return false 45 | } 46 | } 47 | return true 48 | } 49 | 50 | func (s *Statement) render(f *File, w io.Writer, _ *Statement) error { 51 | first := true 52 | for _, code := range *s { 53 | if code == nil || code.isNull(f) { 54 | // Null() token produces no output but also 55 | // no separator. Empty() token produces no 56 | // output but adds a separator. 57 | continue 58 | } 59 | if !first { 60 | if _, err := w.Write([]byte(" ")); err != nil { 61 | return err 62 | } 63 | } 64 | if err := code.render(f, w, s); err != nil { 65 | return err 66 | } 67 | first = false 68 | } 69 | return nil 70 | } 71 | 72 | // Render renders the Statement to the provided writer. 73 | func (s *Statement) Render(writer io.Writer) error { 74 | return s.RenderWithFile(writer, NewFile("")) 75 | } 76 | 77 | // GoString renders the Statement for testing. Any error will cause a panic. 
78 | func (s *Statement) GoString() string { 79 | buf := bytes.Buffer{} 80 | if err := s.Render(&buf); err != nil { 81 | panic(err) 82 | } 83 | return buf.String() 84 | } 85 | 86 | // RenderWithFile renders the Statement to the provided writer, using imports from the provided file. 87 | func (s *Statement) RenderWithFile(writer io.Writer, file *File) error { 88 | buf := &bytes.Buffer{} 89 | if err := s.render(file, buf, nil); err != nil { 90 | return err 91 | } 92 | b, err := format.Source(buf.Bytes()) 93 | if err != nil { 94 | return fmt.Errorf("Error %s while formatting source:\n%s", err, buf.String()) 95 | } 96 | if _, err := writer.Write(b); err != nil { 97 | return err 98 | } 99 | return nil 100 | } 101 | 102 | -------------------------------------------------------------------------------- /vendor/github.com/xyproto/swish/README.md: -------------------------------------------------------------------------------- 1 | # Swish 2 | 3 | [![Build Status](https://travis-ci.org/xyproto/swish.svg?branch=master)](https://travis-ci.org/xyproto/swish) [![Go Report Card](https://goreportcard.com/badge/github.com/xyproto/swish)](https://goreportcard.com/report/github.com/xyproto/swish) [![GoDoc](https://godoc.org/github.com/xyproto/swish?status.svg)](https://godoc.org/github.com/xyproto/swish) 4 | 5 | An optimized Swish activation function ([Ramachandran, Zoph and Le, 2017](https://arxiv.org/abs/1710.05941)), for neural networks. 6 | 7 | ## Screenshots 8 | 9 | ![](img/swish.png) 10 | 11 | ![](img/sigmoid.png) 12 | 13 | The graphs above were drawn using the program in `cmd/graph`, which uses [goterm](https://github.com/buger/goterm). 
14 | 15 | ## Benchmark Results 16 | 17 | ### Using a `Swish` function that uses `math.Exp` 18 | 19 | First run: 20 | 21 | ``` 22 | goos: linux 23 | goarch: amd64 24 | pkg: github.com/xyproto/swish 25 | BenchmarkSwish07-8 200000000 8.93 ns/op 26 | BenchmarkSwish03-8 200000000 8.95 ns/op 27 | PASS 28 | ok github.com/xyproto/swish 5.391s 29 | ``` 30 | 31 | ### Using the optimized `Swish` function that uses `exp256` 32 | 33 | ``` 34 | goos: linux 35 | goarch: amd64 36 | pkg: github.com/xyproto/swish 37 | BenchmarkSwish07-8 2000000000 0.26 ns/op 38 | BenchmarkSwish03-8 2000000000 0.26 ns/op 39 | PASS 40 | ok github.com/xyproto/swish 1.108s 41 | ``` 42 | 43 | The optimized `Swish` function is **34x** faster than the one that uses `math.Exp`, and quite a bit faster than my (apparently bad) attempt at a hand-written assembly version. 44 | 45 | The average error (difference in output value) between the optimized and non-optimized version is `+-0.0013` and the maximum error is `+-0.0024`. This is for `x` in the range `[5,3]`. See the program in `cmd/precision` for how this was calculated. 46 | 47 | ``` 48 | 0.00015 49 | 0.00001 50 | goos: linux 51 | goarch: amd64 52 | pkg: github.com/xyproto/swish 53 | BenchmarkSwishAssembly07-8 500000000 3.63 ns/op 54 | BenchmarkSwishAssembly03-8 500000000 3.65 ns/op 55 | BenchmarkSwish07-8 2000000000 0.30 ns/op 56 | BenchmarkSwish03-8 2000000000 0.26 ns/op 57 | BenchmarkSwishPrecise07-8 200000000 9.07 ns/op 58 | BenchmarkSwishPrecise03-8 200000000 9.25 ns/op 59 | PASS 60 | ok github.com/xyproto/swish 11.100s 61 | ``` 62 | 63 | I have no idea why the assembly version is so slow, but `0.26 ns/op` isn't bad for a non-hand-optimized version. 64 | 65 | ## General info 66 | 67 | * Version: 1.3.0 68 | * License: MIT 69 | * Author: Alexander F. 
Rødseth <xyproto@archlinux.org> 70 | -------------------------------------------------------------------------------- /img/after.svg: -------------------------------------------------------------------------------- 1 | generated with github.com/xyproto/wann -------------------------------------------------------------------------------- /vendor/github.com/dave/jennifer/jen/comments.go: -------------------------------------------------------------------------------- 1 | package jen 2 | 3 | import ( 4 | "fmt" 5 | "io" 6 | "strings" 7 | ) 8 | 9 | // Comment adds a comment. If the provided string contains a newline, the 10 | // comment is formatted in multiline style. If the comment string starts 11 | // with "//" or "/*", the automatic formatting is disabled and the string is 12 | // rendered directly. 13 | func Comment(str string) *Statement { 14 | return newStatement().Comment(str) 15 | } 16 | 17 | // Comment adds a comment. If the provided string contains a newline, the 18 | // comment is formatted in multiline style. If the comment string starts 19 | // with "//" or "/*", the automatic formatting is disabled and the string is 20 | // rendered directly. 21 | func (g *Group) Comment(str string) *Statement { 22 | s := Comment(str) 23 | g.items = append(g.items, s) 24 | return s 25 | } 26 | 27 | // Comment adds a comment. If the provided string contains a newline, the 28 | // comment is formatted in multiline style. If the comment string starts 29 | // with "//" or "/*", the automatic formatting is disabled and the string is 30 | // rendered directly. 31 | func (s *Statement) Comment(str string) *Statement { 32 | c := comment{ 33 | comment: str, 34 | } 35 | *s = append(*s, c) 36 | return s 37 | } 38 | 39 | // Commentf adds a comment, using a format string and a list of parameters. If 40 | // the provided string contains a newline, the comment is formatted in 41 | // multiline style. 
If the comment string starts with "//" or "/*", the 42 | // automatic formatting is disabled and the string is rendered directly. 43 | func Commentf(format string, a ...interface{}) *Statement { 44 | return newStatement().Commentf(format, a...) 45 | } 46 | 47 | // Commentf adds a comment, using a format string and a list of parameters. If 48 | // the provided string contains a newline, the comment is formatted in 49 | // multiline style. If the comment string starts with "//" or "/*", the 50 | // automatic formatting is disabled and the string is rendered directly. 51 | func (g *Group) Commentf(format string, a ...interface{}) *Statement { 52 | s := Commentf(format, a...) 53 | g.items = append(g.items, s) 54 | return s 55 | } 56 | 57 | // Commentf adds a comment, using a format string and a list of parameters. If 58 | // the provided string contains a newline, the comment is formatted in 59 | // multiline style. If the comment string starts with "//" or "/*", the 60 | // automatic formatting is disabled and the string is rendered directly. 61 | func (s *Statement) Commentf(format string, a ...interface{}) *Statement { 62 | c := comment{ 63 | comment: fmt.Sprintf(format, a...), 64 | } 65 | *s = append(*s, c) 66 | return s 67 | } 68 | 69 | type comment struct { 70 | comment string 71 | } 72 | 73 | func (c comment) isNull(f *File) bool { 74 | return false 75 | } 76 | 77 | func (c comment) render(f *File, w io.Writer, s *Statement) error { 78 | if strings.HasPrefix(c.comment, "//") || strings.HasPrefix(c.comment, "/*") { 79 | // automatic formatting disabled. 
80 | if _, err := w.Write([]byte(c.comment)); err != nil { 81 | return err 82 | } 83 | return nil 84 | } 85 | if strings.Contains(c.comment, "\n") { 86 | if _, err := w.Write([]byte("/*\n")); err != nil { 87 | return err 88 | } 89 | } else { 90 | if _, err := w.Write([]byte("// ")); err != nil { 91 | return err 92 | } 93 | } 94 | if _, err := w.Write([]byte(c.comment)); err != nil { 95 | return err 96 | } 97 | if strings.Contains(c.comment, "\n") { 98 | if !strings.HasSuffix(c.comment, "\n") { 99 | if _, err := w.Write([]byte("\n")); err != nil { 100 | return err 101 | } 102 | } 103 | if _, err := w.Write([]byte("*/")); err != nil { 104 | return err 105 | } 106 | } 107 | return nil 108 | } 109 | -------------------------------------------------------------------------------- /neuron_test.go: -------------------------------------------------------------------------------- 1 | package wann 2 | 3 | import ( 4 | "fmt" 5 | "math" 6 | "math/rand" 7 | "testing" 8 | ) 9 | 10 | func TestNeuron(t *testing.T) { 11 | rand.Seed(commonSeed) 12 | net := NewNetwork() 13 | n, _ := net.NewBlankNeuron() 14 | n.ActivationFunction = Swish 15 | result := n.GetActivationFunction()(0.5) 16 | diff := math.Abs(result - 0.311287) 17 | if diff > 0.00001 { // 0.0000001 { 18 | t.Errorf("default swish activation function, expected a number close to 0.311287, got %f:", result) 19 | } 20 | 21 | fmt.Printf("Neurons in network: %d\n", len(net.AllNodes)) 22 | } 23 | 24 | func TestString(t *testing.T) { 25 | rand.Seed(commonSeed) 26 | net := NewNetwork() 27 | n, _ := net.NewBlankNeuron() 28 | _ = n.String() 29 | } 30 | 31 | func TestHasInput(t *testing.T) { 32 | rand.Seed(commonSeed) 33 | net := NewNetwork() // 0 34 | a, _ := net.NewBlankNeuron() // 1 35 | b, _ := net.NewBlankNeuron() // 2 36 | fmt.Println("a is 1?", a) 37 | fmt.Println("b is 2?", b) 38 | a.AddInput(0) 39 | if !a.HasInput(0) { 40 | t.Errorf("a should have b as an input") 41 | } 42 | if b.HasInput(0) { 43 | t.Errorf("b should not have 
a as an input") 44 | } 45 | } 46 | 47 | func TestFindInput(t *testing.T) { 48 | rand.Seed(commonSeed) 49 | net := NewNetwork() 50 | 51 | a, _ := net.NewBlankNeuron() // a, 1 52 | _, bi := net.NewBlankNeuron() // b, 2 53 | c, ci := net.NewBlankNeuron() // c, 3 54 | _, di := net.NewBlankNeuron() // d, 4 55 | 56 | a.AddInput(bi) // b 57 | a.AddInputNeuron(c) // c 58 | 59 | if _, found := a.FindInput(di); found { 60 | t.Errorf("a should not have d as an input") 61 | } 62 | if pos, found := a.FindInput(bi); !found { 63 | t.Errorf("a should have b as an input") 64 | } else if found && pos != 0 { 65 | t.Errorf("a should have b as an input at position 0") 66 | } 67 | if pos, found := a.FindInput(ci); !found { 68 | t.Errorf("a should have c as an input") 69 | } else if found && pos != 1 { 70 | t.Errorf("a should have c as an input at position 1") 71 | } 72 | } 73 | 74 | func TestRemoveInput(t *testing.T) { 75 | rand.Seed(commonSeed) 76 | net := NewNetwork(&Config{ 77 | inputs: 5, 78 | InitialConnectionRatio: 0.5, 79 | sharedWeight: 0.5, 80 | }) 81 | 82 | a, _ := net.NewBlankNeuron() // 0 83 | a.AddInput(1) 84 | a.AddInput(2) 85 | if a.RemoveInput(1) != nil { 86 | t.Errorf("could not remove input b from a") 87 | } 88 | if a.RemoveInput(2) != nil { 89 | t.Errorf("could not remove input c from a") 90 | } 91 | if a.HasInput(1) { 92 | t.Errorf("a should not have b as an input") 93 | } 94 | if a.HasInput(2) { 95 | t.Errorf("a should not have c as an input") 96 | } 97 | } 98 | 99 | // func (neuron *Neuron) RemoveInput(e *Neuron) error { 100 | 101 | func TestEvaluate(t *testing.T) { 102 | rand.Seed(commonSeed) 103 | net := NewNetwork(&Config{ 104 | inputs: 7, 105 | InitialConnectionRatio: 0.5, 106 | sharedWeight: 0.5, 107 | }) 108 | 109 | // Set a few activation functions 110 | net.AllNodes[net.InputNodes[0]].ActivationFunction = Linear 111 | net.AllNodes[net.InputNodes[1]].ActivationFunction = Swish 112 | net.AllNodes[net.InputNodes[2]].ActivationFunction = Gauss 113 | 
net.AllNodes[net.InputNodes[3]].ActivationFunction = Sigmoid 114 | net.AllNodes[net.InputNodes[4]].ActivationFunction = ReLU 115 | net.AllNodes[net.InputNodes[5]].ActivationFunction = Step 116 | net.AllNodes[net.InputNodes[6]].ActivationFunction = Inv 117 | 118 | result := net.Evaluate([]float64{0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5}) 119 | fmt.Println(result) 120 | } 121 | 122 | func TestIn(t *testing.T) { 123 | rand.Seed(commonSeed) 124 | net := NewNetwork() 125 | n, ni := net.NewNeuron() 126 | if ni != 1 { 127 | t.Fail() 128 | } 129 | outputNeuronIndex := NeuronIndex(0) 130 | if !n.In([]NeuronIndex{outputNeuronIndex, 1}) { 131 | t.Fail() 132 | } 133 | } 134 | -------------------------------------------------------------------------------- /statement_test.go: -------------------------------------------------------------------------------- 1 | package wann 2 | 3 | import ( 4 | "fmt" 5 | "math/rand" 6 | "testing" 7 | ) 8 | 9 | // ExampleNetwork_StatementWithInputValues 10 | func TestNetworkStatementWithInputValues(t *testing.T) { 11 | rand.Seed(1) 12 | net := NewNetwork(&Config{ 13 | inputs: 6, 14 | InitialConnectionRatio: 0.7, 15 | sharedWeight: 0.5, 16 | }) 17 | 18 | net.SetInputValues([]float64{0.0, 0.1, 0.2, 0.3, 0.4, 0.5}) 19 | statement, err := net.StatementWithInputValues() 20 | if err != nil { 21 | panic(err) 22 | } 23 | fmt.Println(Render(statement)) 24 | } 25 | 26 | // ExampleNetwork_StatementWithInputDataVariables 27 | func TestNetwork_StatementWithInputDataVariables(t *testing.T) { 28 | rand.Seed(1) 29 | net := NewNetwork(&Config{ 30 | inputs: 6, 31 | InitialConnectionRatio: 0.7, 32 | sharedWeight: 0.5, 33 | }) 34 | 35 | // 1.234 should not appear in the output statement 36 | net.SetInputValues([]float64{1.234, 1.234, 1.234, 1.234, 1.234, 1.234}) 37 | 38 | statement, err := net.StatementWithInputDataVariables() 39 | if err != nil { 40 | panic(err) 41 | } 42 | fmt.Println(Render(statement)) 43 | } 44 | 45 | // ExampleNeuron_InputStatement 46 | func 
ExampleNeuron_InputStatement() { 47 | rand.Seed(1) 48 | net := NewNetwork(&Config{ 49 | inputs: 6, 50 | InitialConnectionRatio: 0.7, 51 | sharedWeight: 0.5, 52 | }) 53 | 54 | // 1.234 should not appear in the output statement 55 | net.SetInputValues([]float64{1.234, 1.234, 1.234, 1.234, 1.234, 1.234}) 56 | 57 | inputStatement2, err := net.AllNodes[net.InputNodes[2]].InputStatement() 58 | if err != nil { 59 | panic(err) 60 | } 61 | fmt.Println(Render(inputStatement2)) 62 | // Output: 63 | // inputData[2] 64 | } 65 | 66 | func ExampleNetwork_OutputNodeStatementX_first() { 67 | // First create a network with only one output node, that has a step function 68 | net := NewNetwork() 69 | net.AllNodes[net.OutputNode].ActivationFunction = Step 70 | 71 | fmt.Println(net.OutputNodeStatementX("score")) 72 | 73 | // Output: 74 | // score := func(s float64) float64 { 75 | // if s >= 0 { 76 | // return 1 77 | // } else { 78 | // return 0 79 | // } 80 | // }(x) 81 | } 82 | 83 | func ExampleNetwork_OutputNodeStatementX_second() { 84 | // Then create a network with an input node that has a sigmoid function and an output node that has an invert function 85 | net := NewNetwork() 86 | net.NewInputNode(Sigmoid, true) 87 | net.AllNodes[net.OutputNode].ActivationFunction = Inv 88 | 89 | // Output a Go expression for this network, using the given input variable names 90 | fmt.Println(net.OutputNodeStatementX("score")) 91 | 92 | // Output: 93 | // score := -(x) 94 | 95 | } 96 | 97 | func ExampleNetwork_OutputNodeStatementX_third() { 98 | rand.Seed(999) 99 | net := NewNetwork(&Config{ 100 | inputs: 1, 101 | InitialConnectionRatio: 0.7, 102 | sharedWeight: 0.5, 103 | }) 104 | fmt.Println(net.OutputNodeStatementX("score")) 105 | 106 | // Output: 107 | // score := math.Exp(-(math.Pow(x, 2.0)) / 2.0) 108 | } 109 | func ExampleNetwork_OutputNodeStatementX_fourth() { 110 | 111 | rand.Seed(1111113) 112 | net := NewNetwork(&Config{ 113 | inputs: 5, 114 | InitialConnectionRatio: 0.7, 115 | 
sharedWeight: 0.5, 116 | }) 117 | fmt.Println(net.OutputNodeStatementX("score")) 118 | 119 | // Output: 120 | // score := func(r float64) float64 { 121 | // if r >= 0 { 122 | // return r 123 | // } else { 124 | // return 0 125 | // } 126 | // }(x) 127 | } 128 | 129 | func ExampleNetwork_OutputNodeStatementX() { 130 | rand.Seed(1) 131 | net := NewNetwork(&Config{ 132 | inputs: 5, 133 | InitialConnectionRatio: 0.7, 134 | sharedWeight: 0.5, 135 | }) 136 | fmt.Println(net.OutputNodeStatementX("f")) 137 | //fmt.Println(net.Score()) 138 | 139 | // Output: 140 | // f := math.Pow(x, 2.0) 141 | } 142 | -------------------------------------------------------------------------------- /go.sum: -------------------------------------------------------------------------------- 1 | github.com/buger/goterm v0.0.0-20181115115552-c206103e1f37/go.mod h1:u9UyCz2eTrSGy6fbupqJ54eY5c4IC8gREQ1053dK12U= 2 | github.com/dave/astrid v0.0.0-20170323122508-8c2895878b14/go.mod h1:Sth2QfxfATb/nW4EsrSi2KyJmbcniZ8TgTaji17D6ms= 3 | github.com/dave/brenda v1.1.0/go.mod h1:4wCUr6gSlu5/1Tk7akE5X7UorwiQ8Rij0SKH3/BGMOM= 4 | github.com/dave/courtney v0.3.0/go.mod h1:BAv3hA06AYfNUjfjQr+5gc6vxeBVOupLqrColj+QSD8= 5 | github.com/dave/gopackages v0.0.0-20170318123100-46e7023ec56e/go.mod h1:i00+b/gKdIDIxuLDFob7ustLAVqhsZRk2qVZrArELGQ= 6 | github.com/dave/jennifer v1.5.0 h1:HmgPN93bVDpkQyYbqhCHj5QlgvUkvEOzMyEvKLgCRrg= 7 | github.com/dave/jennifer v1.5.0/go.mod h1:4MnyiFIlZS3l5tSDn8VnzE6ffAhYBMB2SZntBsZGUok= 8 | github.com/dave/kerr v0.0.0-20170318121727-bc25dd6abe8e/go.mod h1:qZqlPyPvfsDJt+3wHJ1EvSXDuVjFTK0j2p/ca+gtsb8= 9 | github.com/dave/patsy v0.0.0-20210517141501-957256f50cba/go.mod h1:qfR88CgEGLoiqDaE+xxDCi5QA5v4vUoW0UCX2Nd5Tlc= 10 | github.com/dave/rebecca v0.9.1/go.mod h1:N6XYdMD/OKw3lkF3ywh8Z6wPGuwNFDNtWYEMFWEmXBA= 11 | github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= 12 | github.com/xyproto/af v0.0.0-20191018214415-1a8887381bd3 
h1:QN9wTN6qa12WnnYyQ9xlUvyU1W9V58u3SWgwNKp5RpA= 13 | github.com/xyproto/af v0.0.0-20191018214415-1a8887381bd3/go.mod h1:xYoWKJKu61/NJ4/6PV6mJuNu5QymmrbJmmI5mNOW+fI= 14 | github.com/xyproto/swish v1.3.0 h1:qiVl2UeqkMqDJKjRLVwRprD7Vzz2pbNRcWTb3s0TpjE= 15 | github.com/xyproto/swish v1.3.0/go.mod h1:IVhz2R80pNsPaSxbEcqR84i4MVL0b06HN7KIMhv38WA= 16 | github.com/xyproto/tinysvg v1.0.1 h1:UwuNaudh95O0BzdBJ4zQ0lTft5a4ToHKfo6xvKb3it0= 17 | github.com/xyproto/tinysvg v1.0.1/go.mod h1:DKgmaYuFIvJab9ug4nH4ZG356VtUaKXG2mUU07GIurs= 18 | github.com/yuin/goldmark v1.4.1/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= 19 | golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= 20 | golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= 21 | golang.org/x/mod v0.5.1/go.mod h1:5OXOZSfqPIIbmVBIIKWRFfZjPR0E5r58TLhUjH0a2Ro= 22 | golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= 23 | golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= 24 | golang.org/x/net v0.0.0-20211015210444-4f30a5c0130f/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= 25 | golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= 26 | golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= 27 | golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= 28 | golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= 29 | golang.org/x/sys v0.0.0-20190613101156-ab3f67ed278a/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= 30 | golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= 31 | golang.org/x/sys 
v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= 32 | golang.org/x/sys v0.0.0-20211019181941-9d821ace8654/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= 33 | golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= 34 | golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= 35 | golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= 36 | golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= 37 | golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= 38 | golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= 39 | golang.org/x/tools v0.1.8/go.mod h1:nABZi5QlRsZVlzPpHl034qft6wpY4eDcsTt5AaioBiU= 40 | golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= 41 | golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= 42 | golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= 43 | -------------------------------------------------------------------------------- /vendor/github.com/dave/jennifer/jen/group.go: -------------------------------------------------------------------------------- 1 | package jen 2 | 3 | import ( 4 | "bytes" 5 | "fmt" 6 | "go/format" 7 | "io" 8 | ) 9 | 10 | // Group represents a list of Code items, separated by tokens with an optional 11 | // open and close token. 
12 | type Group struct { 13 | name string 14 | items []Code 15 | open string 16 | close string 17 | separator string 18 | multi bool 19 | } 20 | 21 | func (g *Group) isNull(f *File) bool { 22 | if g == nil { 23 | return true 24 | } 25 | if g.open != "" || g.close != "" { 26 | return false 27 | } 28 | for _, c := range g.items { 29 | if !c.isNull(f) { 30 | return false 31 | } 32 | } 33 | return true 34 | } 35 | 36 | func (g *Group) render(f *File, w io.Writer, s *Statement) error { 37 | if g.name == "block" && s != nil { 38 | // Special CaseBlock format for then the previous item in the statement 39 | // is a Case group or the default keyword. 40 | prev := s.previous(g) 41 | grp, isGrp := prev.(*Group) 42 | tkn, isTkn := prev.(token) 43 | if isGrp && grp.name == "case" || isTkn && tkn.content == "default" { 44 | g.open = "" 45 | g.close = "" 46 | } 47 | } 48 | if g.open != "" { 49 | if _, err := w.Write([]byte(g.open)); err != nil { 50 | return err 51 | } 52 | } 53 | isNull, err := g.renderItems(f, w) 54 | if err != nil { 55 | return err 56 | } 57 | if !isNull && g.multi && g.close != "" { 58 | // For multi-line blocks with a closing token, we insert a new line after the last item (but 59 | // not if all items were null). This is to ensure that if the statement finishes with a comment, 60 | // the closing token is not commented out. 61 | s := "\n" 62 | if g.separator == "," { 63 | // We also insert add trailing comma if the separator was ",". 
64 | s = ",\n" 65 | } 66 | if _, err := w.Write([]byte(s)); err != nil { 67 | return err 68 | } 69 | } 70 | if g.close != "" { 71 | if _, err := w.Write([]byte(g.close)); err != nil { 72 | return err 73 | } 74 | } 75 | return nil 76 | } 77 | 78 | func (g *Group) renderItems(f *File, w io.Writer) (isNull bool, err error) { 79 | first := true 80 | for _, code := range g.items { 81 | if pt, ok := code.(token); ok && pt.typ == packageToken { 82 | // Special case for package tokens in Qual groups - for dot-imports, the package token 83 | // will be null, so will not render and will not be registered in the imports block. 84 | // This ensures all packageTokens that are rendered are registered. 85 | f.register(pt.content.(string)) 86 | } 87 | if code == nil || code.isNull(f) { 88 | // Null() token produces no output but also 89 | // no separator. Empty() token products no 90 | // output but adds a separator. 91 | continue 92 | } 93 | if g.name == "values" { 94 | if _, ok := code.(Dict); ok && len(g.items) > 1 { 95 | panic("Error in Values: if Dict is used, must be one item only") 96 | } 97 | } 98 | if !first && g.separator != "" { 99 | // The separator token is added before each non-null item, but not before the first item. 100 | if _, err := w.Write([]byte(g.separator)); err != nil { 101 | return false, err 102 | } 103 | } 104 | if g.multi { 105 | // For multi-line blocks, we insert a new line before each non-null item. 106 | if _, err := w.Write([]byte("\n")); err != nil { 107 | return false, err 108 | } 109 | } 110 | if err := code.render(f, w, nil); err != nil { 111 | return false, err 112 | } 113 | first = false 114 | } 115 | return first, nil 116 | } 117 | 118 | // Render renders the Group to the provided writer. 119 | func (g *Group) Render(writer io.Writer) error { 120 | return g.RenderWithFile(writer, NewFile("")) 121 | } 122 | 123 | // GoString renders the Group for testing. Any error will cause a panic. 
124 | func (g *Group) GoString() string { 125 | buf := bytes.Buffer{} 126 | if err := g.Render(&buf); err != nil { 127 | panic(err) 128 | } 129 | return buf.String() 130 | } 131 | 132 | // RenderWithFile renders the Group to the provided writer, using imports from the provided file. 133 | func (g *Group) RenderWithFile(writer io.Writer, file *File) error { 134 | buf := &bytes.Buffer{} 135 | if err := g.render(file, buf, nil); err != nil { 136 | return err 137 | } 138 | b, err := format.Source(buf.Bytes()) 139 | if err != nil { 140 | return fmt.Errorf("Error %s while formatting source:\n%s", err, buf.String()) 141 | } 142 | if _, err := writer.Write(b); err != nil { 143 | return err 144 | } 145 | return nil 146 | } 147 | 148 | -------------------------------------------------------------------------------- /vendor/github.com/dave/jennifer/jen/jen.go: -------------------------------------------------------------------------------- 1 | // Package jen is a code generator for Go 2 | package jen 3 | 4 | import ( 5 | "bytes" 6 | "fmt" 7 | "go/format" 8 | "io" 9 | "io/ioutil" 10 | "sort" 11 | "strconv" 12 | ) 13 | 14 | // Code represents an item of code that can be rendered. 15 | type Code interface { 16 | render(f *File, w io.Writer, s *Statement) error 17 | isNull(f *File) bool 18 | } 19 | 20 | // Save renders the file and saves to the filename provided. 21 | func (f *File) Save(filename string) error { 22 | // notest 23 | buf := &bytes.Buffer{} 24 | if err := f.Render(buf); err != nil { 25 | return err 26 | } 27 | if err := ioutil.WriteFile(filename, buf.Bytes(), 0644); err != nil { 28 | return err 29 | } 30 | return nil 31 | } 32 | 33 | // Render renders the file to the provided writer. 
34 | func (f *File) Render(w io.Writer) error { 35 | body := &bytes.Buffer{} 36 | if err := f.render(f, body, nil); err != nil { 37 | return err 38 | } 39 | source := &bytes.Buffer{} 40 | if len(f.headers) > 0 { 41 | for _, c := range f.headers { 42 | if err := Comment(c).render(f, source, nil); err != nil { 43 | return err 44 | } 45 | if _, err := fmt.Fprint(source, "\n"); err != nil { 46 | return err 47 | } 48 | } 49 | // Append an extra newline so that header comments don't get lumped in 50 | // with package comments. 51 | if _, err := fmt.Fprint(source, "\n"); err != nil { 52 | return err 53 | } 54 | } 55 | for _, c := range f.comments { 56 | if err := Comment(c).render(f, source, nil); err != nil { 57 | return err 58 | } 59 | if _, err := fmt.Fprint(source, "\n"); err != nil { 60 | return err 61 | } 62 | } 63 | if _, err := fmt.Fprintf(source, "package %s", f.name); err != nil { 64 | return err 65 | } 66 | if f.CanonicalPath != "" { 67 | if _, err := fmt.Fprintf(source, " // import %q", f.CanonicalPath); err != nil { 68 | return err 69 | } 70 | } 71 | if _, err := fmt.Fprint(source, "\n\n"); err != nil { 72 | return err 73 | } 74 | if err := f.renderImports(source); err != nil { 75 | return err 76 | } 77 | if _, err := source.Write(body.Bytes()); err != nil { 78 | return err 79 | } 80 | formatted, err := format.Source(source.Bytes()) 81 | if err != nil { 82 | return fmt.Errorf("Error %s while formatting source:\n%s", err, source.String()) 83 | } 84 | if _, err := w.Write(formatted); err != nil { 85 | return err 86 | } 87 | return nil 88 | } 89 | 90 | func (f *File) renderImports(source io.Writer) error { 91 | 92 | // Render the "C" import if it's been used in a `Qual`, `Anon` or if there's a preamble comment 93 | hasCgo := f.imports["C"].name != "" || len(f.cgoPreamble) > 0 94 | 95 | // Only separate the import from the main imports block if there's a preamble 96 | separateCgo := hasCgo && len(f.cgoPreamble) > 0 97 | 98 | filtered := map[string]importdef{} 99 
| for path, def := range f.imports { 100 | // filter out the "C" pseudo-package so it's not rendered in a block with the other 101 | // imports, but only if it is accompanied by a preamble comment 102 | if path == "C" && separateCgo { 103 | continue 104 | } 105 | filtered[path] = def 106 | } 107 | 108 | if len(filtered) == 1 { 109 | for path, def := range filtered { 110 | if def.alias && path != "C" { 111 | // "C" package should be rendered without alias even when used as an anonymous import 112 | // (e.g. should never have an underscore). 113 | if _, err := fmt.Fprintf(source, "import %s %s\n\n", def.name, strconv.Quote(path)); err != nil { 114 | return err 115 | } 116 | } else { 117 | if _, err := fmt.Fprintf(source, "import %s\n\n", strconv.Quote(path)); err != nil { 118 | return err 119 | } 120 | } 121 | } 122 | } else if len(filtered) > 1 { 123 | if _, err := fmt.Fprint(source, "import (\n"); err != nil { 124 | return err 125 | } 126 | // We must sort the imports to ensure repeatable 127 | // source. 128 | paths := []string{} 129 | for path := range filtered { 130 | paths = append(paths, path) 131 | } 132 | sort.Strings(paths) 133 | for _, path := range paths { 134 | def := filtered[path] 135 | if def.alias && path != "C" { 136 | // "C" package should be rendered without alias even when used as an anonymous import 137 | // (e.g. should never have an underscore). 
138 | if _, err := fmt.Fprintf(source, "%s %s\n", def.name, strconv.Quote(path)); err != nil { 139 | return err 140 | } 141 | 142 | } else { 143 | if _, err := fmt.Fprintf(source, "%s\n", strconv.Quote(path)); err != nil { 144 | return err 145 | } 146 | } 147 | } 148 | if _, err := fmt.Fprint(source, ")\n\n"); err != nil { 149 | return err 150 | } 151 | } 152 | 153 | if separateCgo { 154 | for _, c := range f.cgoPreamble { 155 | if err := Comment(c).render(f, source, nil); err != nil { 156 | return err 157 | } 158 | if _, err := fmt.Fprint(source, "\n"); err != nil { 159 | return err 160 | } 161 | } 162 | if _, err := fmt.Fprint(source, "import \"C\"\n\n"); err != nil { 163 | return err 164 | } 165 | } 166 | 167 | return nil 168 | } 169 | -------------------------------------------------------------------------------- /vendor/github.com/dave/jennifer/jen/lit.go: -------------------------------------------------------------------------------- 1 | package jen 2 | 3 | // Lit renders a literal. Lit supports only built-in types (bool, string, int, complex128, float64, 4 | // float32, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, uintptr and complex64). 5 | // Passing any other type will panic. 6 | func Lit(v interface{}) *Statement { 7 | return newStatement().Lit(v) 8 | } 9 | 10 | // Lit renders a literal. Lit supports only built-in types (bool, string, int, complex128, float64, 11 | // float32, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, uintptr and complex64). 12 | // Passing any other type will panic. 13 | func (g *Group) Lit(v interface{}) *Statement { 14 | s := Lit(v) 15 | g.items = append(g.items, s) 16 | return s 17 | } 18 | 19 | // Lit renders a literal. Lit supports only built-in types (bool, string, int, complex128, float64, 20 | // float32, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, uintptr and complex64). 21 | // Passing any other type will panic. 
22 | func (s *Statement) Lit(v interface{}) *Statement { 23 | t := token{ 24 | typ: literalToken, 25 | content: v, 26 | } 27 | *s = append(*s, t) 28 | return s 29 | } 30 | 31 | // LitFunc renders a literal. LitFunc generates the value to render by executing the provided 32 | // function. LitFunc supports only built-in types (bool, string, int, complex128, float64, float32, 33 | // int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, uintptr and complex64). 34 | // Returning any other type will panic. 35 | func LitFunc(f func() interface{}) *Statement { 36 | return newStatement().LitFunc(f) 37 | } 38 | 39 | // LitFunc renders a literal. LitFunc generates the value to render by executing the provided 40 | // function. LitFunc supports only built-in types (bool, string, int, complex128, float64, float32, 41 | // int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, uintptr and complex64). 42 | // Returning any other type will panic. 43 | func (g *Group) LitFunc(f func() interface{}) *Statement { 44 | s := LitFunc(f) 45 | g.items = append(g.items, s) 46 | return s 47 | } 48 | 49 | // LitFunc renders a literal. LitFunc generates the value to render by executing the provided 50 | // function. LitFunc supports only built-in types (bool, string, int, complex128, float64, float32, 51 | // int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, uintptr and complex64). 52 | // Returning any other type will panic. 53 | func (s *Statement) LitFunc(f func() interface{}) *Statement { 54 | t := token{ 55 | typ: literalToken, 56 | content: f(), 57 | } 58 | *s = append(*s, t) 59 | return s 60 | } 61 | 62 | // LitRune renders a rune literal. 63 | func LitRune(v rune) *Statement { 64 | return newStatement().LitRune(v) 65 | } 66 | 67 | // LitRune renders a rune literal. 68 | func (g *Group) LitRune(v rune) *Statement { 69 | s := LitRune(v) 70 | g.items = append(g.items, s) 71 | return s 72 | } 73 | 74 | // LitRune renders a rune literal. 
75 | func (s *Statement) LitRune(v rune) *Statement { 76 | t := token{ 77 | typ: literalRuneToken, 78 | content: v, 79 | } 80 | *s = append(*s, t) 81 | return s 82 | } 83 | 84 | // LitRuneFunc renders a rune literal. LitRuneFunc generates the value to 85 | // render by executing the provided function. 86 | func LitRuneFunc(f func() rune) *Statement { 87 | return newStatement().LitRuneFunc(f) 88 | } 89 | 90 | // LitRuneFunc renders a rune literal. LitRuneFunc generates the value to 91 | // render by executing the provided function. 92 | func (g *Group) LitRuneFunc(f func() rune) *Statement { 93 | s := LitRuneFunc(f) 94 | g.items = append(g.items, s) 95 | return s 96 | } 97 | 98 | // LitRuneFunc renders a rune literal. LitRuneFunc generates the value to 99 | // render by executing the provided function. 100 | func (s *Statement) LitRuneFunc(f func() rune) *Statement { 101 | t := token{ 102 | typ: literalRuneToken, 103 | content: f(), 104 | } 105 | *s = append(*s, t) 106 | return s 107 | } 108 | 109 | // LitByte renders a byte literal. 110 | func LitByte(v byte) *Statement { 111 | return newStatement().LitByte(v) 112 | } 113 | 114 | // LitByte renders a byte literal. 115 | func (g *Group) LitByte(v byte) *Statement { 116 | s := LitByte(v) 117 | g.items = append(g.items, s) 118 | return s 119 | } 120 | 121 | // LitByte renders a byte literal. 122 | func (s *Statement) LitByte(v byte) *Statement { 123 | t := token{ 124 | typ: literalByteToken, 125 | content: v, 126 | } 127 | *s = append(*s, t) 128 | return s 129 | } 130 | 131 | // LitByteFunc renders a byte literal. LitByteFunc generates the value to 132 | // render by executing the provided function. 133 | func LitByteFunc(f func() byte) *Statement { 134 | return newStatement().LitByteFunc(f) 135 | } 136 | 137 | // LitByteFunc renders a byte literal. LitByteFunc generates the value to 138 | // render by executing the provided function. 
139 | func (g *Group) LitByteFunc(f func() byte) *Statement { 140 | s := LitByteFunc(f) 141 | g.items = append(g.items, s) 142 | return s 143 | } 144 | 145 | // LitByteFunc renders a byte literal. LitByteFunc generates the value to 146 | // render by executing the provided function. 147 | func (s *Statement) LitByteFunc(f func() byte) *Statement { 148 | t := token{ 149 | typ: literalByteToken, 150 | content: f(), 151 | } 152 | *s = append(*s, t) 153 | return s 154 | } 155 | -------------------------------------------------------------------------------- /img/labels.svg: -------------------------------------------------------------------------------- 1 | generated with github.com/xyproto/wannSwish [2]Swish [2]Swish [2]Swish [2]Swish [2]Swish [2]Swish [2]Swish [2]Swish [2]Swish [2]Swish [2]Swish [2]Swish [2]Swish [2]Swish [2]Swish [2]Swish [2]Swish [2]Swish [2]Swish [2]Swish [2]Swish [2]Swish [2]Swish [2]Swish [2]Swish [2]Swish [2]Swish [2]Swish [2]Swish [2]Swish [2]Swish [2]Swish [2]Swish [2]Swish [2]Swish [2]Swish [2]Inverted [o] -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | Logo 2 | 3 | # wann [![Go Report Card](https://goreportcard.com/badge/github.com/xyproto/wann)](https://goreportcard.com/report/github.com/xyproto/wann) [![GoDoc](https://godoc.org/github.com/xyproto/wann?status.svg)](https://godoc.org/github.com/xyproto/wann) 4 | 5 | Weight Agnostic Neural Networks is a new type of neural network, where the weights of all the neurons are shared and the structure of the network is what matters. 6 | 7 | This package implements Weight Agnostic Neural Networks for Go, and is inspired by this paper from June 2019: 8 | 9 | *"Weight Agnostic Neural Networks" by Adam Gaier and David Ha*. 
([PDF](https://arxiv.org/pdf/1906.04358.pdf) | [Interactive version](https://weightagnostic.github.io/) | [Google AI blog post](https://ai.googleblog.com/2019/08/exploring-weight-agnostic-neural.html)) 10 | 11 | ## Features and limitations 12 | 13 | * All activation functions are benchmarked at the start of the program and the results are taken into account when calculating the complexity of a network. 14 | * All networks can be translated to a Go statement, using the wonderful [jennifer](https://github.com/dave/jennifer) package (work in progress, there are a few kinks that need to be ironed out). 15 | * Networks can be saved as `SVG` diagrams. This feature needs more testing. 16 | * Neural networks can be trained and used. See the `cmd` folder for examples. 17 | * A random weight is chosen when training, instead of looping over the range of the weight. The paper describes both methods. 18 | * After the network has been trained, the optimal weight is found by looping over all weights (with a step size of `0.0001`). 19 | * Increased complexity counts negatively when evolving networks. This optimizes not only for less complex networks, but also for execution speed. 20 | * The diagram drawing routine plots the activation functions directly onto the nodes, together with a label. This can be saved as an SVG file. 
21 | 22 | ## Example program 23 | 24 | This is a simple example, for creating a network that can recognize one of four shapes: 25 | 26 | ```go 27 | package main 28 | 29 | import ( 30 | "fmt" 31 | "os" 32 | 33 | "github.com/xyproto/wann" 34 | ) 35 | 36 | func main() { 37 | // Here are four shapes, representing: up, down, left and right: 38 | 39 | up := []float64{ 40 | 0.0, 1.0, 0.0, // o 41 | 1.0, 1.0, 1.0} // ooo 42 | 43 | down := []float64{ 44 | 1.0, 1.0, 1.0, // ooo 45 | 0.0, 1.0, 0.0} // o 46 | 47 | left := []float64{ 48 | 1.0, 1.0, 1.0, // ooo 49 | 0.0, 0.0, 1.0} // o 50 | 51 | right := []float64{ 52 | 1.0, 1.0, 1.0, // ooo 53 | 1.0, 0.0, 0.0} // o 54 | 55 | // Prepare the input data as a 2D slice 56 | inputData := [][]float64{ 57 | up, 58 | down, 59 | left, 60 | right, 61 | } 62 | 63 | // Target scores for: up, down, left, right 64 | correctResultsForUp := []float64{1.0, -1.0, -1.0, -1.0} 65 | 66 | // Prepare a neural network configuration struct 67 | config := &wann.Config{ 68 | InitialConnectionRatio: 0.2, 69 | Generations: 2000, 70 | PopulationSize: 500, 71 | Verbose: true, 72 | } 73 | 74 | // Evolve a network, using the input data and the sought after results 75 | trainedNetwork, err := config.Evolve(inputData, correctResultsForUp) 76 | if err != nil { 77 | fmt.Fprintf(os.Stderr, "error: %s\n", err) 78 | os.Exit(1) 79 | } 80 | 81 | // Now to test the trained network on 4 different inputs and see if it passes the test 82 | upScore := trainedNetwork.Evaluate(up) 83 | downScore := trainedNetwork.Evaluate(down) 84 | leftScore := trainedNetwork.Evaluate(left) 85 | rightScore := trainedNetwork.Evaluate(right) 86 | 87 | if config.Verbose { 88 | if upScore > downScore && upScore > leftScore && upScore > rightScore { 89 | fmt.Println("Network training complete, the results are good.") 90 | } else { 91 | fmt.Println("Network training complete, but the results did not pass the test.") 92 | } 93 | } 94 | 95 | // Save the trained network as an SVG image 96 | if 
config.Verbose { 97 | fmt.Print("Writing network.svg...") 98 | } 99 | if err := trainedNetwork.WriteSVG("network.svg"); err != nil { 100 | fmt.Fprintf(os.Stderr, "error: %s\n", err) 101 | os.Exit(1) 102 | } 103 | if config.Verbose { 104 | fmt.Println("ok") 105 | } 106 | } 107 | ``` 108 | 109 | Here is the resulting network generated by the above program: 110 | 111 | Network 112 | 113 | This makes sense, since taking the third number in the input data (index 2), running it through a swish function and then inverting it should be a usable detector for the `up` pattern. 114 | 115 | * The generated networks may differ for each run. 116 | 117 | ## Quick start 118 | 119 | This requires Go 1.11 or later. 120 | 121 | Clone the repository: 122 | 123 | git clone https://github.com/xyproto/wann 124 | 125 | Enter the `cmd/evolve` directory: 126 | 127 | cd wann/cmd/evolve 128 | 129 | Build and run the example: 130 | 131 | go build && ./evolve 132 | 133 | Take a look at the best network for judging whether a set of numbers, each either 0 or 1, belongs to one category: 134 | 135 | xdg-open network.svg 136 | 137 | (If needed, use your favorite SVG viewer instead of the `xdg-open` command). 138 | 139 | ## Ideas 140 | 141 | * Adding convolution nodes might give interesting results. 142 | 143 | ## Generating Go code from a trained network 144 | 145 | This is an experimental feature and a work in progress! 146 | 147 | The idea is to generate one large expression from all the expressions that each node in the network represents. 148 | 149 | Right now, this only works for networks that have a depth of 1. 
150 | 151 | For example, adding these two lines to `cmd/evolve/main.go`: 152 | 153 | ```go 154 | // Output a Go function for this network 155 | fmt.Println(trainedNetwork.GoFunction()) 156 | ``` 157 | 158 | Produces this output: 159 | 160 | ```go 161 | func f(x float64) float64 { return -x } 162 | ``` 163 | 164 | The plan is to output a function that takes the input data instead, and refers to the input data by index. Support for deeper networks also needs to be added. 165 | 166 | There is a complete example for outputting Go code in `cmd/gofunction`. 167 | 168 | ## General info 169 | 170 | * Version: 0.3.2 171 | * License: BSD-3 172 | * Author: Alexander F. Rødseth <xyproto@archlinux.org> 173 | -------------------------------------------------------------------------------- /diagram.go: -------------------------------------------------------------------------------- 1 | package wann 2 | 3 | import ( 4 | "bytes" 5 | "io" 6 | "io/ioutil" 7 | "strconv" 8 | 9 | "github.com/xyproto/tinysvg" 10 | ) 11 | 12 | // OutputSVG will output the current network as an SVG image to the given io.Writer 13 | // TODO: Clean up and refactor 14 | func (net *Network) OutputSVG(w io.Writer) (int, error) { 15 | // Set up margins and the canvas size 16 | var ( 17 | marginLeft = 10 18 | marginTop = 10 19 | marginBottom = 10 20 | marginRight = 10 21 | nodeRadius = 10 22 | betweenPadding = 4 23 | d = float64(net.Depth()) * 2.5 24 | width = marginLeft + int(float64(nodeRadius)*2.0*d) + betweenPadding*(int(d)-1) + nodeRadius + marginRight 25 | l = float64(len(net.InputNodes)) 26 | height = marginTop + int(float64(nodeRadius)*1.5*l) + betweenPadding*(int(l)-1) + marginBottom 27 | imgPadding = 5 28 | lineWidth = 2 29 | ) 30 | 31 | if width < 128 { 32 | width = 128 33 | } 34 | if height < 128 { 35 | height = 128 36 | } 37 | 38 | // Start a new SVG image 39 | document, svg := tinysvg.NewTinySVG(width+imgPadding*2, height+imgPadding*2) 40 | svg.Describe("generated with github.com/xyproto/wann") 41 | 
42 | // White background rounded rectangle 43 | bg := svg.AddRoundedRect(imgPadding, imgPadding, 30, 30, width, height) 44 | bg.Fill2(tinysvg.ColorByName("white")) 45 | bg.Stroke2(tinysvg.ColorByName("black")) 46 | 47 | // Position of output node 48 | outputx := width - (marginRight + nodeRadius*2) + imgPadding 49 | outputy := (height-(nodeRadius*2))/2 + imgPadding 50 | 51 | // For each connected neuron, store it with the distance from the output neuron as the key in a map 52 | layerNeurons := make(map[int][]NeuronIndex) 53 | maxDistance := 0 54 | net.ForEachConnectedNodeIndex(func(ni NeuronIndex) { 55 | distanceFromOutput := net.AllNodes[ni].distanceFromOutputNode 56 | layerNeurons[distanceFromOutput] = append(layerNeurons[distanceFromOutput], ni) 57 | if distanceFromOutput > maxDistance { 58 | maxDistance = distanceFromOutput 59 | } 60 | }) 61 | 62 | // Draw the input nodes as circles, and connections to the output node as lines 63 | //for i, n := range net.InputNodes { 64 | columnOffset := 50 65 | 66 | getPosition := func(givenNeuron NeuronIndex) (int, int) { 67 | for outputDistance, neurons := range layerNeurons { 68 | for neuronLayerIndex, otherNeuron := range neurons { 69 | if otherNeuron == givenNeuron { 70 | x := marginLeft + imgPadding + columnOffset*(maxDistance-outputDistance) 71 | y := (neuronLayerIndex * (nodeRadius*2 + betweenPadding)) + marginTop + imgPadding 72 | return x, y 73 | } 74 | } 75 | } 76 | panic("implementation error: neuron index not found") 77 | } 78 | 79 | // Draw node lines first 80 | for _, neurons := range layerNeurons { 81 | for _, neuronIndex := range neurons { 82 | if neuronIndex == net.OutputNode { 83 | continue 84 | } 85 | // Find the position of this node circle 86 | x, y := getPosition(neuronIndex) 87 | // Draw the connection from the center of this node to the center of all input nodes, if applicable 88 | for _, inputNeuron := range (net.AllNodes[neuronIndex]).InputNodes { 89 | ix, iy := getPosition(inputNeuron) 90 | 
svg.Line(ix+nodeRadius, iy+nodeRadius, x+nodeRadius, y+nodeRadius, lineWidth, "orange") 91 | } 92 | // Draw the connection to the output node, if it has this node as input 93 | if net.AllNodes[net.OutputNode].HasInput(neuronIndex) { 94 | svg.Line(x+nodeRadius, y+nodeRadius, outputx+nodeRadius, outputy+nodeRadius, lineWidth, "#0099ff") 95 | } 96 | } 97 | } 98 | 99 | // Then draw the nodes on top, including graph plots 100 | for _, neurons := range layerNeurons { 101 | for _, neuronIndex := range neurons { 102 | if neuronIndex == net.OutputNode { 103 | continue 104 | } 105 | 106 | // Find the position of this node circle 107 | x, y := getPosition(neuronIndex) 108 | 109 | // Draw this node 110 | input := svg.AddCircle(x+nodeRadius, y+nodeRadius, nodeRadius) 111 | switch net.AllNodes[neuronIndex].distanceFromOutputNode { 112 | case 1, 6: 113 | input.Fill("lightblue") 114 | case 2, 7: 115 | input.Fill("lightgreen") 116 | case 3, 8: 117 | input.Fill("lightyellow") 118 | case 4, 9: 119 | input.Fill("orange") 120 | case 5, 10: 121 | input.Fill("red") 122 | default: 123 | input.Fill("gray") 124 | } 125 | input.Stroke2(tinysvg.ColorByName("black")) 126 | 127 | // Plot the activation function inside this node 128 | var points []*tinysvg.Pos 129 | startx := float64(x) + float64(nodeRadius)*0.5 130 | stopx := float64(x+nodeRadius*2) - float64(nodeRadius)*0.5 131 | ypos := float64(y) 132 | for xpos := startx; xpos < stopx; xpos += 0.2 { 133 | // xr is from 0 to 1 134 | xr := float64(xpos-startx) / float64(stopx-startx) 135 | // xv is from -5 to 5 136 | xv := (xr - 0.5) * float64(nodeRadius) 137 | node := net.AllNodes[neuronIndex] 138 | f := ActivationFunctions[node.ActivationFunction] 139 | yv := f(xv) 140 | // plot, 3.0 is the amplitude along y 141 | yp := float64(ypos) + float64(nodeRadius)*1.35 - (yv * 0.6 * float64(nodeRadius)) 142 | 143 | if yp < (ypos + float64(nodeRadius)*0.1) { 144 | continue 145 | } else if yp > (ypos + float64(nodeRadius)*1.9) { 146 | continue 147 | } 
148 | 149 | // Label 150 | name := node.ActivationFunction.Name() 151 | if net.IsInput(neuronIndex) { 152 | // Add a the input number to the name 153 | for i, ni := range net.InputNodes { 154 | if neuronIndex == ni { 155 | name += " [" + strconv.Itoa(i) + "]" 156 | } 157 | } 158 | } else if neuronIndex == net.OutputNode { 159 | name += " !" 160 | } 161 | box := svg.AddRect(int(startx-float64(nodeRadius)*0.4), int(ypos+float64(nodeRadius)*2.5)-5, len(name)*5, 6) 162 | box.Fill("black") 163 | svg.Text(int(startx-float64(nodeRadius)*0.4), int(ypos+float64(nodeRadius)*2.5), 8, "Courier", name, "white") 164 | 165 | p := tinysvg.NewPosf(xpos, yp) 166 | points = append(points, p) 167 | } 168 | // Draw the polyline (graph) 169 | pl := svg.Polyline(points, tinysvg.ColorByName("black")) 170 | pl.Stroke2(tinysvg.ColorByName("black")) 171 | pl.Fill2(tinysvg.ColorByName("none")) 172 | 173 | } 174 | } 175 | 176 | // Draw the output node 177 | output := svg.AddCircle(outputx+nodeRadius+1, outputy+nodeRadius+1, nodeRadius) 178 | output.Fill("magenta") 179 | output.Stroke2(tinysvg.ColorByName("black")) 180 | 181 | // Label 182 | name := net.AllNodes[net.OutputNode].ActivationFunction.Name() + " [o]" 183 | box := svg.AddRect(outputx-nodeRadius/2, (nodeRadius*2)+outputy+1, len(name)*5, 6) 184 | box.Fill("black") 185 | svg.Text(outputx-nodeRadius/2, (nodeRadius*2)+outputy+6, 8, "Courier", name, "white") 186 | 187 | // Write the data to the given io.Writer 188 | return w.Write(document.Bytes()) 189 | } 190 | 191 | // WriteSVG saves a drawing of the current network as an SVG file 192 | func (net *Network) WriteSVG(filename string) error { 193 | var buf bytes.Buffer 194 | if _, err := net.OutputSVG(&buf); err != nil { 195 | return err 196 | } 197 | return ioutil.WriteFile(filename, buf.Bytes(), 0644) 198 | } 199 | -------------------------------------------------------------------------------- /vendor/github.com/dave/jennifer/jen/file.go: 
-------------------------------------------------------------------------------- 1 | package jen 2 | 3 | import ( 4 | "bytes" 5 | "fmt" 6 | "regexp" 7 | "strings" 8 | "unicode" 9 | "unicode/utf8" 10 | ) 11 | 12 | // NewFile Creates a new file, with the specified package name. 13 | func NewFile(packageName string) *File { 14 | return &File{ 15 | Group: &Group{ 16 | multi: true, 17 | }, 18 | name: packageName, 19 | imports: map[string]importdef{}, 20 | hints: map[string]importdef{}, 21 | } 22 | } 23 | 24 | // NewFilePath creates a new file while specifying the package path - the 25 | // package name is inferred from the path. 26 | func NewFilePath(packagePath string) *File { 27 | return &File{ 28 | Group: &Group{ 29 | multi: true, 30 | }, 31 | name: guessAlias(packagePath), 32 | path: packagePath, 33 | imports: map[string]importdef{}, 34 | hints: map[string]importdef{}, 35 | } 36 | } 37 | 38 | // NewFilePathName creates a new file with the specified package path and name. 39 | func NewFilePathName(packagePath, packageName string) *File { 40 | return &File{ 41 | Group: &Group{ 42 | multi: true, 43 | }, 44 | name: packageName, 45 | path: packagePath, 46 | imports: map[string]importdef{}, 47 | hints: map[string]importdef{}, 48 | } 49 | } 50 | 51 | // File represents a single source file. Package imports are managed 52 | // automaticaly by File. 53 | type File struct { 54 | *Group 55 | name string 56 | path string 57 | imports map[string]importdef 58 | hints map[string]importdef 59 | comments []string 60 | headers []string 61 | cgoPreamble []string 62 | // If you're worried about generated package aliases conflicting with local variable names, you 63 | // can set a prefix here. Package foo becomes {prefix}_foo. 64 | PackagePrefix string 65 | // CanonicalPath adds a canonical import path annotation to the package clause. 
// importdef describes how a single package is referred to in a File. It is
// used to differentiate packages where we know the package name from packages
// where the import is aliased. If alias == false, then name is the actual
// package name, and the import will be rendered without an alias. If
// alias == true, name is the alias that is written in the import block.
type importdef struct {
	// name is the package name or alias, e.g. "_" for anonymous imports
	// and "." for dot-imports.
	name string
	// alias reports whether name is an alias rather than the real package name.
	alias bool
}
Use the [gennames](gennames) command to 112 | // automatically generate a go file containing a map of a selection of package names. 113 | func (f *File) ImportNames(names map[string]string) { 114 | for path, name := range names { 115 | f.hints[path] = importdef{name: name, alias: false} 116 | } 117 | } 118 | 119 | // ImportAlias provides the alias for a package path that should be used in the import block. A 120 | // period can be used to force a dot-import. 121 | func (f *File) ImportAlias(path, alias string) { 122 | f.hints[path] = importdef{name: alias, alias: true} 123 | } 124 | 125 | func (f *File) isLocal(path string) bool { 126 | return f.path == path 127 | } 128 | 129 | func (f *File) isValidAlias(alias string) bool { 130 | // multiple dot-imports are ok 131 | if alias == "." { 132 | return true 133 | } 134 | // the import alias is invalid if it's a reserved word 135 | if IsReservedWord(alias) { 136 | return false 137 | } 138 | // the import alias is invalid if it's already been registered 139 | for _, v := range f.imports { 140 | if alias == v.name { 141 | return false 142 | } 143 | } 144 | return true 145 | } 146 | 147 | func (f *File) isDotImport(path string) bool { 148 | if id, ok := f.hints[path]; ok { 149 | return id.name == "." && id.alias 150 | } 151 | return false 152 | } 153 | 154 | func (f *File) register(path string) string { 155 | if f.isLocal(path) { 156 | // notest 157 | // should never get here becasue in Qual the packageToken will be null, 158 | // so render will never be called. 
159 | return "" 160 | } 161 | 162 | // if the path has been registered previously, simply return the name 163 | def := f.imports[path] 164 | if def.name != "" && def.name != "_" { 165 | return def.name 166 | } 167 | 168 | // special case for "C" pseudo-package 169 | if path == "C" { 170 | f.imports["C"] = importdef{name: "C", alias: false} 171 | return "C" 172 | } 173 | 174 | var name string 175 | var alias bool 176 | 177 | if hint := f.hints[path]; hint.name != "" { 178 | // look up the path in the list of provided package names and aliases by ImportName / ImportAlias 179 | name = hint.name 180 | alias = hint.alias 181 | } else if standardLibraryHints[path] != "" { 182 | // look up the path in the list of standard library packages 183 | name = standardLibraryHints[path] 184 | alias = false 185 | } else { 186 | // if a hint is not found for the package, guess the alias from the package path 187 | name = guessAlias(path) 188 | alias = true 189 | } 190 | 191 | // If the name is invalid or has been registered already, make it unique by appending a number 192 | unique := name 193 | i := 0 194 | for !f.isValidAlias(unique) { 195 | i++ 196 | unique = fmt.Sprintf("%s%d", name, i) 197 | } 198 | 199 | // If we've changed the name to make it unique, it should definitely be an alias 200 | if unique != name { 201 | alias = true 202 | } 203 | 204 | // Only add a prefix if the name is an alias 205 | if f.PackagePrefix != "" && alias { 206 | unique = f.PackagePrefix + "_" + unique 207 | } 208 | 209 | // Register the eventual name 210 | f.imports[path] = importdef{name: unique, alias: alias} 211 | 212 | return unique 213 | } 214 | 215 | // GoString renders the File for testing. Any error will cause a panic. 
// guessAliasInvalidChars matches every character that may not appear in a
// guessed alias. It is compiled once at package initialisation instead of on
// every call to guessAlias.
var guessAliasInvalidChars = regexp.MustCompile(`[^a-z0-9]`)

// guessAlias derives a package alias from an import path.
//
// The last path element is used (a single trailing slash is tolerated),
// lower-cased, stripped of everything that is not a-z or 0-9, and stripped of
// leading digits, since a Go identifier can not start with a digit. If
// nothing is left, "pkg" is returned.
func guessAlias(path string) string {
	alias := path

	if strings.HasSuffix(alias, "/") {
		// trailing slashes are usually tolerated, so we can get rid of one if
		// it exists
		alias = alias[:len(alias)-1]
	}

	if i := strings.LastIndex(alias, "/"); i >= 0 {
		// if the path contains a "/", use the last part
		alias = alias[i+1:]
	}

	// alias should be lower case
	alias = strings.ToLower(alias)

	// alias should now only contain alphanumerics
	alias = guessAliasInvalidChars.ReplaceAllString(alias, "")

	// can't have a first digit, per Go identifier rules, so just skip them
	for first, size := utf8.DecodeRuneInString(alias); unicode.IsDigit(first); first, size = utf8.DecodeRuneInString(alias) {
		alias = alias[size:]
	}

	// If the path part was all digits (or empty), we are left with an empty
	// string. In this case use "pkg" as the alias.
	if alias == "" {
		alias = "pkg"
	}

	return alias
}
// ActivationFunctions maps each ActivationFunctionIndex constant defined above
// to its implementation in the github.com/xyproto/af package. The selection of
// functions follows the WANN reference implementation:
// https://github.com/google/brain-tokyo-workshop/blob/master/WANNRelease/WANN/wann_src/ind.py
var ActivationFunctions = map[ActivationFunctionIndex](func(float64) float64){
	Step:     af.Step,       // Unsigned step function
	Linear:   af.Linear,     // Linear
	Sin:      af.Sin,        // Sine
	Gauss:    af.Gaussian01, // Gaussian with mean 0 and sigma 1
	Tanh:     af.Tanh,       // Hyperbolic tangent (signed?)
	Sigmoid:  af.Sigmoid,    // Sigmoid (unsigned?)
	Inv:      af.Inv,        // Inverse
	Abs:      af.Abs,        // Absolute value
	ReLU:     af.ReLU,       // Rectified linear unit
	Cos:      af.Cos,        // Cosine
	Squared:  af.Squared,    // Squared
	Swish:    af.Swish,      // Swish
	SoftPlus: af.SoftPlus,   // SoftPlus
}
64 | var ComplexityEstimate = make(map[ActivationFunctionIndex]float64) 65 | 66 | func (config *Config) estimateComplexity() { 67 | if config.Verbose { 68 | fmt.Print("Estimating activation function complexity...") 69 | } 70 | startEstimate := time.Now() 71 | resolution := 0.0001 72 | durationMap := make(map[ActivationFunctionIndex]time.Duration) 73 | var maxDuration time.Duration 74 | for i, f := range ActivationFunctions { 75 | start := time.Now() 76 | for x := 0.0; x <= 1.0; x += resolution { 77 | _ = f(x) 78 | } 79 | duration := time.Since(start) 80 | durationMap[ActivationFunctionIndex(i)] = duration 81 | if duration > maxDuration { 82 | maxDuration = duration 83 | } 84 | } 85 | for i := range ActivationFunctions { 86 | // 1.0 means the function took maxDuration 87 | ComplexityEstimate[ActivationFunctionIndex(i)] = float64(durationMap[ActivationFunctionIndex(i)]) / float64(maxDuration) 88 | } 89 | estimateDuration := time.Since(startEstimate) 90 | if config.Verbose { 91 | fmt.Printf(" done. (In %v)\n", estimateDuration) 92 | } 93 | } 94 | 95 | // Call runs an activation function with the given float64 value. 96 | // The activation function is chosen by one of the constants above. 
97 | func (afi ActivationFunctionIndex) Call(x float64) float64 { 98 | if f, ok := ActivationFunctions[afi]; ok { 99 | return f(x) 100 | } 101 | // Use the linear function by default 102 | return af.Linear(x) 103 | } 104 | 105 | // Name returns a name for each activation function 106 | func (afi ActivationFunctionIndex) Name() string { 107 | switch afi { 108 | case Step: 109 | return "Step" 110 | case Linear: 111 | return "Linear" 112 | case Sin: 113 | return "Sinusoid" 114 | case Gauss: 115 | return "Gaussian" 116 | case Tanh: 117 | return "Tanh" 118 | case Sigmoid: 119 | return "Sigmoid" 120 | case Inv: 121 | return "Inverted" 122 | case Abs: 123 | return "Absolute" 124 | case ReLU: 125 | return "ReLU" 126 | case Cos: 127 | return "Cosinusoid" 128 | case Squared: 129 | return "Squared" 130 | case Swish: 131 | return "Swish" 132 | case SoftPlus: 133 | return "SoftPlus" 134 | default: 135 | return "Untitled" 136 | } 137 | } 138 | 139 | // goExpression returns the Go expression for this activation function, using the given variable name string as the input variable name 140 | func (afi ActivationFunctionIndex) goExpression(varName string) string { 141 | switch afi { 142 | case Step: 143 | // Using s to not confuse it with the varName 144 | return "func(s float64) float64 { if s >= 0 { return 1 } else { return 0 } }(" + varName + ")" 145 | case Linear: 146 | return varName 147 | case Sin: 148 | return "math.Sin(math.Pi * " + varName + ")" 149 | case Gauss: 150 | return "math.Exp(-(" + varName + " * " + varName + ") / 2.0)" 151 | case Tanh: 152 | return "math.Tanh(" + varName + ")" 153 | case Sigmoid: 154 | return "(1.0 / (1.0 + math.Exp(-" + varName + ")))" 155 | case Inv: 156 | return "-" + varName 157 | case Abs: 158 | return "math.Abs(" + varName + ")" 159 | case ReLU: 160 | // Using r to not confuse it with the varName 161 | return "func(r float64) float64 { if r >= 0 { return r } else { return 0 } }(" + varName + ")" 162 | case Cos: 163 | return 
"math.Cos(math.Pi * " + varName + ")" 164 | case Squared: 165 | return "(" + varName + " * " + varName + ")" 166 | case Swish: 167 | return "(" + varName + "/ (1.0 + math.Exp(-" + varName + ")))" 168 | case SoftPlus: 169 | return "math.Log(1.0 + math.Exp(" + varName + "))" 170 | default: 171 | return varName 172 | } 173 | } 174 | 175 | // String returns the Go expression for this activation function, using "x" as the input variable name 176 | func (afi ActivationFunctionIndex) String() string { 177 | return afi.goExpression("x") 178 | } 179 | 180 | // Statement returns the Statement statement for this activation function, using the given inner statement 181 | func (afi ActivationFunctionIndex) Statement(inner *jen.Statement) *jen.Statement { 182 | switch afi { 183 | case Step: 184 | // func(s float64) float64 { if s >= 0 { return 1 } else { return 0 } }(inner) 185 | // Using s to not confuse it with the varName 186 | return jen.Func().Params(jen.Id("s").Id("float64")).Id("float64").Block( 187 | jen.If(jen.Id("s").Op(">=").Id("0")).Block( 188 | jen.Return(jen.Lit(1)), 189 | ).Else().Block( 190 | jen.Return(jen.Lit(0)), 191 | ), 192 | ).Call(inner) 193 | case Cos: 194 | // math.Cos((inner) * math.Pi) 195 | return jen.Qual("math", "Cos").Call(jen.Parens(inner).Op("*").Id("math").Dot("Pi")) 196 | case Sin: 197 | // math.Sin((inner) * math.Pi) 198 | return jen.Qual("math", "Sin").Call(jen.Parens(inner).Op("*").Id("math").Dot("Pi")) 199 | case Gauss: 200 | // return math.Exp(-(math.Pow(inner, 2.0)) / 2.0) 201 | return jen.Qual("math", "Exp").Call(jen.Op("-").Parens( 202 | // Using math.Pow ensures the inner expression is only calculated once, if it's a large expression 203 | //inner.Op("*").Add(inner), 204 | jen.Qual("math", "Pow").Params( 205 | inner, 206 | jen.Lit(2.0), 207 | ), 208 | ).Op("/").Lit(2.0)) 209 | case Tanh: 210 | // math.Tanh(inner) 211 | return jen.Qual("math", "Tanh").Call(inner) 212 | case Sigmoid: 213 | // (1.0 / (1.0 + math.Exp(-(inner)))) 214 | 
return jen.Lit(1.0).Op("/").Parens(jen.Lit(1.0).Op("+").Qual("math", "Exp").Call(jen.Op("-").Parens(inner))) 215 | case Inv: 216 | // -(inner) 217 | return jen.Op("-").Parens(inner) 218 | case Abs: 219 | // math.Abs(inner) 220 | return jen.Qual("math", "Abs").Call(inner) 221 | case ReLU: 222 | //return "func(r float64) float64 { if r >= 0 { return r } else { return 0 } }(" + varName + ")" 223 | // Using r to not confuse it with the varName 224 | return jen.Func().Params(jen.Id("r").Id("float64")).Id("float64").Block( 225 | jen.If(jen.Id("r").Op(">=").Id("0")).Block( 226 | jen.Return(jen.Id("r")), 227 | ).Else().Block( 228 | jen.Return(jen.Lit(0)), 229 | ), 230 | ).Call(inner) 231 | case Squared: 232 | // inner^2 233 | //return inner.Op("*").Add(inner) 234 | // Using math.Pow ensures the inner expression is only calculated once, if it's a large expression 235 | return jen.Qual("math", "Pow").Call(inner, jen.Lit(2.0)) 236 | case Swish: 237 | // (inner / (1.0 + math.Exp(-inner))) 238 | return jen.Parens(inner.Op("/").Parens(jen.Lit(1.0).Op("+").Qual("math", "Exp").Call(jen.Op("-").Parens(inner)))) 239 | case SoftPlus: 240 | // math.Log(1.0 + math.Exp(inner)) 241 | return jen.Qual("math", "Log").Call(jen.Lit(1.0).Op("+").Qual("math", "Exp").Call(inner)) 242 | case Linear: 243 | // This is also the default case: (inner) 244 | fallthrough 245 | default: 246 | // (inner) 247 | return jen.Parens(inner) 248 | } 249 | } 250 | 251 | // GoRun will first construct the expression using jennifer and then evaluate the result using "go run" and a source file innn /tmp 252 | func (afi ActivationFunctionIndex) GoRun(x float64) (float64, error) { 253 | return RunStatementX(afi.Statement(jen.Id("x")), x) 254 | } 255 | -------------------------------------------------------------------------------- /vendor/github.com/dave/jennifer/jen/tokens.go: -------------------------------------------------------------------------------- 1 | package jen 2 | 3 | import ( 4 | "fmt" 5 | "io" 6 | 
"strconv" 7 | "strings" 8 | ) 9 | 10 | type tokenType string 11 | 12 | const ( 13 | packageToken tokenType = "package" 14 | identifierToken tokenType = "identifier" 15 | qualifiedToken tokenType = "qualified" 16 | keywordToken tokenType = "keyword" 17 | operatorToken tokenType = "operator" 18 | delimiterToken tokenType = "delimiter" 19 | literalToken tokenType = "literal" 20 | literalRuneToken tokenType = "literal_rune" 21 | literalByteToken tokenType = "literal_byte" 22 | nullToken tokenType = "null" 23 | layoutToken tokenType = "layout" 24 | ) 25 | 26 | type token struct { 27 | typ tokenType 28 | content interface{} 29 | } 30 | 31 | func (t token) isNull(f *File) bool { 32 | if t.typ == packageToken { 33 | // package token is null if the path is a dot-import or the local package path 34 | return f.isDotImport(t.content.(string)) || f.isLocal(t.content.(string)) 35 | } 36 | return t.typ == nullToken 37 | } 38 | 39 | func (t token) render(f *File, w io.Writer, s *Statement) error { 40 | switch t.typ { 41 | case literalToken: 42 | var out string 43 | switch t.content.(type) { 44 | case bool, string, int, complex128: 45 | // default constant types can be left bare 46 | out = fmt.Sprintf("%#v", t.content) 47 | case float64: 48 | out = fmt.Sprintf("%#v", t.content) 49 | if !strings.Contains(out, ".") && !strings.Contains(out, "e") { 50 | // If the formatted value is not in scientific notation, and does not have a dot, then 51 | // we add ".0". Otherwise it will be interpreted as an int. 
52 | // See: 53 | // https://github.com/dave/jennifer/issues/39 54 | // https://github.com/golang/go/issues/26363 55 | out += ".0" 56 | } 57 | case float32, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, uintptr: 58 | // other built-in types need specific type info 59 | out = fmt.Sprintf("%T(%#v)", t.content, t.content) 60 | case complex64: 61 | // fmt package already renders parenthesis for complex64 62 | out = fmt.Sprintf("%T%#v", t.content, t.content) 63 | default: 64 | panic(fmt.Sprintf("unsupported type for literal: %T", t.content)) 65 | } 66 | if _, err := w.Write([]byte(out)); err != nil { 67 | return err 68 | } 69 | case literalRuneToken: 70 | if _, err := w.Write([]byte(strconv.QuoteRune(t.content.(rune)))); err != nil { 71 | return err 72 | } 73 | case literalByteToken: 74 | if _, err := w.Write([]byte(fmt.Sprintf("byte(%#v)", t.content))); err != nil { 75 | return err 76 | } 77 | case keywordToken, operatorToken, layoutToken, delimiterToken: 78 | if _, err := w.Write([]byte(fmt.Sprintf("%s", t.content))); err != nil { 79 | return err 80 | } 81 | if t.content.(string) == "default" { 82 | // Special case for Default, which must always be followed by a colon 83 | if _, err := w.Write([]byte(":")); err != nil { 84 | return err 85 | } 86 | } 87 | case packageToken: 88 | path := t.content.(string) 89 | alias := f.register(path) 90 | if _, err := w.Write([]byte(alias)); err != nil { 91 | return err 92 | } 93 | case identifierToken: 94 | if _, err := w.Write([]byte(t.content.(string))); err != nil { 95 | return err 96 | } 97 | case nullToken: // notest 98 | // do nothing (should never render a null token) 99 | } 100 | return nil 101 | } 102 | 103 | // Null adds a null item. Null items render nothing and are not followed by a 104 | // separator in lists. 105 | func Null() *Statement { 106 | return newStatement().Null() 107 | } 108 | 109 | // Null adds a null item. Null items render nothing and are not followed by a 110 | // separator in lists. 
111 | func (g *Group) Null() *Statement { 112 | s := Null() 113 | g.items = append(g.items, s) 114 | return s 115 | } 116 | 117 | // Null adds a null item. Null items render nothing and are not followed by a 118 | // separator in lists. 119 | func (s *Statement) Null() *Statement { 120 | t := token{ 121 | typ: nullToken, 122 | } 123 | *s = append(*s, t) 124 | return s 125 | } 126 | 127 | // Empty adds an empty item. Empty items render nothing but are followed by a 128 | // separator in lists. 129 | func Empty() *Statement { 130 | return newStatement().Empty() 131 | } 132 | 133 | // Empty adds an empty item. Empty items render nothing but are followed by a 134 | // separator in lists. 135 | func (g *Group) Empty() *Statement { 136 | s := Empty() 137 | g.items = append(g.items, s) 138 | return s 139 | } 140 | 141 | // Empty adds an empty item. Empty items render nothing but are followed by a 142 | // separator in lists. 143 | func (s *Statement) Empty() *Statement { 144 | t := token{ 145 | typ: operatorToken, 146 | content: "", 147 | } 148 | *s = append(*s, t) 149 | return s 150 | } 151 | 152 | // Op renders the provided operator / token. 153 | func Op(op string) *Statement { 154 | return newStatement().Op(op) 155 | } 156 | 157 | // Op renders the provided operator / token. 158 | func (g *Group) Op(op string) *Statement { 159 | s := Op(op) 160 | g.items = append(g.items, s) 161 | return s 162 | } 163 | 164 | // Op renders the provided operator / token. 165 | func (s *Statement) Op(op string) *Statement { 166 | t := token{ 167 | typ: operatorToken, 168 | content: op, 169 | } 170 | *s = append(*s, t) 171 | return s 172 | } 173 | 174 | // Dot renders a period followed by an identifier. Use for fields and selectors. 175 | func Dot(name string) *Statement { 176 | // notest 177 | // don't think this can be used in valid code? 178 | return newStatement().Dot(name) 179 | } 180 | 181 | // Dot renders a period followed by an identifier. Use for fields and selectors. 
182 | func (g *Group) Dot(name string) *Statement { 183 | // notest 184 | // don't think this can be used in valid code? 185 | s := Dot(name) 186 | g.items = append(g.items, s) 187 | return s 188 | } 189 | 190 | // Dot renders a period followed by an identifier. Use for fields and selectors. 191 | func (s *Statement) Dot(name string) *Statement { 192 | d := token{ 193 | typ: delimiterToken, 194 | content: ".", 195 | } 196 | t := token{ 197 | typ: identifierToken, 198 | content: name, 199 | } 200 | *s = append(*s, d, t) 201 | return s 202 | } 203 | 204 | // Id renders an identifier. 205 | func Id(name string) *Statement { 206 | return newStatement().Id(name) 207 | } 208 | 209 | // Id renders an identifier. 210 | func (g *Group) Id(name string) *Statement { 211 | s := Id(name) 212 | g.items = append(g.items, s) 213 | return s 214 | } 215 | 216 | // Id renders an identifier. 217 | func (s *Statement) Id(name string) *Statement { 218 | t := token{ 219 | typ: identifierToken, 220 | content: name, 221 | } 222 | *s = append(*s, t) 223 | return s 224 | } 225 | 226 | // Qual renders a qualified identifier. Imports are automatically added when 227 | // used with a File. If the path matches the local path, the package name is 228 | // omitted. If package names conflict they are automatically renamed. Note that 229 | // it is not possible to reliably determine the package name given an arbitrary 230 | // package path, so a sensible name is guessed from the path and added as an 231 | // alias. The names of all standard library packages are known so these do not 232 | // need to be aliased. If more control is needed of the aliases, see 233 | // [File.ImportName](#importname) or [File.ImportAlias](#importalias). 234 | func Qual(path, name string) *Statement { 235 | return newStatement().Qual(path, name) 236 | } 237 | 238 | // Qual renders a qualified identifier. Imports are automatically added when 239 | // used with a File. 
If the path matches the local path, the package name is 240 | // omitted. If package names conflict they are automatically renamed. Note that 241 | // it is not possible to reliably determine the package name given an arbitrary 242 | // package path, so a sensible name is guessed from the path and added as an 243 | // alias. The names of all standard library packages are known so these do not 244 | // need to be aliased. If more control is needed of the aliases, see 245 | // [File.ImportName](#importname) or [File.ImportAlias](#importalias). 246 | func (g *Group) Qual(path, name string) *Statement { 247 | s := Qual(path, name) 248 | g.items = append(g.items, s) 249 | return s 250 | } 251 | 252 | // Qual renders a qualified identifier. Imports are automatically added when 253 | // used with a File. If the path matches the local path, the package name is 254 | // omitted. If package names conflict they are automatically renamed. Note that 255 | // it is not possible to reliably determine the package name given an arbitrary 256 | // package path, so a sensible name is guessed from the path and added as an 257 | // alias. The names of all standard library packages are known so these do not 258 | // need to be aliased. If more control is needed of the aliases, see 259 | // [File.ImportName](#importname) or [File.ImportAlias](#importalias). 260 | func (s *Statement) Qual(path, name string) *Statement { 261 | g := &Group{ 262 | close: "", 263 | items: []Code{ 264 | token{ 265 | typ: packageToken, 266 | content: path, 267 | }, 268 | token{ 269 | typ: identifierToken, 270 | content: name, 271 | }, 272 | }, 273 | name: "qual", 274 | open: "", 275 | separator: ".", 276 | } 277 | *s = append(*s, g) 278 | return s 279 | } 280 | 281 | // Line inserts a blank line. 282 | func Line() *Statement { 283 | return newStatement().Line() 284 | } 285 | 286 | // Line inserts a blank line. 
287 | func (g *Group) Line() *Statement { 288 | s := Line() 289 | g.items = append(g.items, s) 290 | return s 291 | } 292 | 293 | // Line inserts a blank line. 294 | func (s *Statement) Line() *Statement { 295 | t := token{ 296 | typ: layoutToken, 297 | content: "\n", 298 | } 299 | *s = append(*s, t) 300 | return s 301 | } 302 | -------------------------------------------------------------------------------- /vendor/github.com/xyproto/tinysvg/tags.go: -------------------------------------------------------------------------------- 1 | // Package tinysvg has structs and functions for creating and rendering TinySVG images 2 | package tinysvg 3 | 4 | // Everything here deals with bytes, not strings 5 | // TODO: Add a Write function that takes an io.Writer so that the image can be written as it is generated. 6 | 7 | import ( 8 | "bytes" 9 | "fmt" 10 | "io/ioutil" 11 | ) 12 | 13 | // Tag represents an XML tag, as part of a larger XML document 14 | type Tag struct { 15 | name []byte 16 | content []byte 17 | lastContent []byte 18 | xmlContent []byte 19 | attrs map[string][]byte 20 | nextSibling *Tag // siblings 21 | firstChild *Tag // first child 22 | } 23 | 24 | // Document is an XML document, with a title and a root tag 25 | type Document struct { 26 | title []byte 27 | root *Tag 28 | } 29 | 30 | // NewDocument creates a new XML/HTML/SVG image, with a root tag. 31 | // If rootTagName contains "<" or ">", it can be used for preceding declarations, 32 | // like or . 33 | // Returns a pointer to a Document. 34 | func NewDocument(title, rootTagName []byte) *Document { 35 | var image Document 36 | image.title = []byte(title) 37 | rootTag := NewTag([]byte(rootTagName)) 38 | image.root = rootTag 39 | return &image 40 | } 41 | 42 | // NewTag creates a new tag based on the given name. 43 | // "name" is what will appear right after "<" when rendering as XML/HTML/SVG. 
44 | func NewTag(name []byte) *Tag { 45 | var tag Tag 46 | tag.name = name 47 | tag.attrs = make(map[string][]byte) 48 | tag.nextSibling = nil 49 | tag.firstChild = nil 50 | tag.content = make([]byte, 0) 51 | tag.lastContent = make([]byte, 0) 52 | return &tag 53 | } 54 | 55 | // AddNewTag adds a new tag to another tag. This will place it one step lower 56 | // in the hierarchy of tags. You can for example add a body tag to an html tag. 57 | func (tag *Tag) AddNewTag(name []byte) *Tag { 58 | child := NewTag(name) 59 | tag.AddChild(child) 60 | return child 61 | } 62 | 63 | // AddTag adds a tag to another tag 64 | func (tag *Tag) AddTag(child *Tag) { 65 | tag.AddChild(child) 66 | } 67 | 68 | // AddAttrib adds an attribute to a tag, for instance "size" and "20" 69 | func (tag *Tag) AddAttrib(attrName string, attrValue []byte) { 70 | tag.attrs[attrName] = attrValue 71 | } 72 | 73 | // AddSingularAttrib adds attribute without a value 74 | func (tag *Tag) AddSingularAttrib(attrName string) { 75 | tag.attrs[attrName] = nil 76 | } 77 | 78 | // GetAttrString returns a []byte that represents all the attribute keys and 79 | // values of a tag. This can be used when generating XML, SVG or HTML. 80 | func (tag *Tag) GetAttrString() []byte { 81 | ret := make([]byte, 0) 82 | for key, value := range tag.attrs { 83 | if value == nil { 84 | ret = append(ret, key...) 85 | ret = append(ret, ' ') 86 | } else { 87 | ret = append(ret, key...) 88 | ret = append(ret, []byte("=\"")...) 89 | ret = append(ret, value...) 90 | ret = append(ret, []byte("\" ")...) 91 | } 92 | } 93 | if len(ret) > 0 { 94 | ret = ret[:len(ret)-1] 95 | } 96 | return ret 97 | } 98 | 99 | // getFlatXML renders XML. 100 | // This will generate a []byte for a tag, non-recursively. 
101 | func (tag *Tag) getFlatXML() []byte { 102 | // For the root tag 103 | if (len(tag.name) > 0) && (tag.name[0] == '<') { 104 | ret := make([]byte, 0, len(tag.name)+len(tag.content)+len(tag.xmlContent)+len(tag.lastContent)) 105 | ret = append(ret, tag.name...) 106 | ret = append(ret, tag.content...) 107 | ret = append(ret, tag.xmlContent...) 108 | ret = append(ret, tag.lastContent...) 109 | return ret 110 | } 111 | // For indenting 112 | spacing := make([]byte, 0) 113 | // Generate the XML based on the tag 114 | attrs := tag.GetAttrString() 115 | ret := make([]byte, 0) 116 | ret = append(ret, spacing...) 117 | ret = append(ret, []byte("<")...) 118 | ret = append(ret, tag.name...) 119 | if len(attrs) > 0 { 120 | ret = append(ret, []byte(" ")...) 121 | ret = append(ret, attrs...) 122 | } 123 | if (len(tag.content) == 0) && (len(tag.xmlContent) == 0) && (len(tag.lastContent) == 0) { 124 | ret = append(ret, []byte(" />")...) 125 | } else { 126 | if len(tag.xmlContent) > 0 { 127 | if tag.xmlContent[0] != ' ' { 128 | ret = append(ret, []byte(">")...) 129 | ret = append(ret, spacing...) 130 | ret = append(ret, tag.xmlContent...) 131 | ret = append(ret, spacing...) 132 | ret = append(ret, []byte("")...) 135 | } else { 136 | ret = append(ret, []byte(">")...) 137 | ret = append(ret, tag.xmlContent...) 138 | ret = append(ret, spacing...) 139 | ret = append(ret, []byte("")...) 142 | } 143 | } else { 144 | ret = append(ret, []byte(">")...) 145 | ret = append(ret, tag.content...) 146 | ret = append(ret, tag.lastContent...) 147 | ret = append(ret, []byte("")...) 150 | } 151 | } 152 | return ret 153 | } 154 | 155 | // GetChildren returns all children for a given tag. 156 | // Returns a slice of pointers to tags. 
157 | func (tag *Tag) GetChildren() []*Tag { 158 | var children []*Tag 159 | current := tag.firstChild 160 | for current != nil { 161 | children = append(children, current) 162 | current = current.nextSibling 163 | } 164 | return children 165 | } 166 | 167 | // AddChild adds a tag as a child to another tag 168 | func (tag *Tag) AddChild(child *Tag) { 169 | if tag.firstChild == nil { 170 | tag.firstChild = child 171 | return 172 | } 173 | lastChild := tag.LastChild() 174 | child.nextSibling = nil 175 | lastChild.nextSibling = child 176 | } 177 | 178 | // AddContent adds text to a tag. 179 | // This is what will appear between two tag markers, for example: 180 | // content 181 | // If the tag contains child tags, they will be rendered after this content. 182 | func (tag *Tag) AddContent(content []byte) { 183 | tag.content = append(tag.content, content...) 184 | } 185 | 186 | // AppendContent appends content to the end of the existing content of a tag 187 | func (tag *Tag) AppendContent(content []byte) { 188 | tag.lastContent = append(tag.lastContent, content...) 189 | } 190 | 191 | // AddLastContent appends content to the end of the existing content of a tag. 192 | // Deprecated. 
193 | func (tag *Tag) AddLastContent(content []byte) { 194 | tag.AppendContent(content) 195 | } 196 | 197 | // CountChildren returns the number of children a tag has 198 | func (tag *Tag) CountChildren() int { 199 | child := tag.firstChild 200 | if child == nil { 201 | return 0 202 | } 203 | count := 1 204 | if child.nextSibling == nil { 205 | return count 206 | } 207 | child = child.nextSibling 208 | for child != nil { 209 | count++ 210 | child = child.nextSibling 211 | } 212 | return count 213 | } 214 | 215 | // CountSiblings returns the number of siblings a tag has 216 | func (tag *Tag) CountSiblings() int { 217 | sib := tag.nextSibling 218 | if sib == nil { 219 | return 0 220 | } 221 | count := 1 222 | if sib.nextSibling == nil { 223 | return count 224 | } 225 | sib = sib.nextSibling 226 | for sib != nil { 227 | count++ 228 | sib = sib.nextSibling 229 | } 230 | return count 231 | } 232 | 233 | // LastChild returns the last child of a tag 234 | func (tag *Tag) LastChild() *Tag { 235 | child := tag.firstChild 236 | for child.nextSibling != nil { 237 | child = child.nextSibling 238 | } 239 | return child 240 | } 241 | 242 | // GetTag searches all tags for the given name 243 | func (image *Document) GetTag(name []byte) (*Tag, error) { 244 | return image.root.GetTag(name) 245 | } 246 | 247 | // GetRoot returns the root tag of the image 248 | func (image *Document) GetRoot() *Tag { 249 | return image.root 250 | } 251 | 252 | // GetTag finds a tag by name and returns an error if not found. 253 | // Returns the first tag that matches. 254 | func (tag *Tag) GetTag(name []byte) (*Tag, error) { 255 | if bytes.Index(tag.name, name) == 0 { 256 | return tag, nil 257 | } 258 | couldNotFindError := fmt.Errorf("could not find tag: %s", name) 259 | if tag.CountChildren() == 0 { 260 | // No children. 
Not found so far 261 | return nil, couldNotFindError 262 | } 263 | 264 | child := tag.firstChild 265 | for child != nil { 266 | found, err := child.GetTag(name) 267 | if err == nil { 268 | return found, err 269 | } 270 | child = child.nextSibling 271 | } 272 | 273 | return nil, couldNotFindError 274 | } 275 | 276 | // Bytes (previously getXMLRecursively) renders XML for a tag, recursively. 277 | // The generated XML is returned as a []byte. 278 | func (tag *Tag) Bytes() []byte { 279 | var content, xmlContent []byte 280 | 281 | if tag.CountChildren() == 0 { 282 | return tag.getFlatXML() 283 | } 284 | 285 | child := tag.firstChild 286 | for child != nil { 287 | xmlContent = child.Bytes() 288 | if len(xmlContent) > 0 { 289 | content = append(content, xmlContent...) 290 | } 291 | child = child.nextSibling 292 | } 293 | 294 | tag.xmlContent = append(tag.xmlContent, tag.content...) 295 | tag.xmlContent = append(tag.xmlContent, content...) 296 | tag.xmlContent = append(tag.xmlContent, tag.lastContent...) 297 | 298 | return tag.getFlatXML() 299 | } 300 | 301 | // String returns the XML contents as a string 302 | func (tag *Tag) String() string { 303 | return string(tag.Bytes()) 304 | } 305 | 306 | // AddContent adds content to the body tag. 307 | // Returns the body tag and nil if successful. 308 | // Returns and an error if no body tag is found, else nil. 
309 | func (image *Document) AddContent(content []byte) (*Tag, error) { 310 | body, err := image.root.GetTag([]byte("body")) 311 | if err == nil { 312 | body.AddContent(content) 313 | } 314 | return body, err 315 | } 316 | 317 | // Bytes renders the image as an XML document 318 | func (image *Document) Bytes() []byte { 319 | return image.root.Bytes() 320 | } 321 | 322 | // String renders the image as an XML document 323 | func (image *Document) String() string { 324 | return image.root.String() 325 | } 326 | 327 | // SaveSVG will save the current image as an SVG file 328 | func (image *Document) SaveSVG(filename string) error { 329 | return ioutil.WriteFile(filename, image.Bytes(), 0644) 330 | } 331 | -------------------------------------------------------------------------------- /network_test.go: -------------------------------------------------------------------------------- 1 | package wann 2 | 3 | import ( 4 | "fmt" 5 | "math/rand" 6 | "testing" 7 | ) 8 | 9 | // Use a specific seed for the random number generator, just when testing 10 | var commonSeed int64 = 1571917826405889420 11 | 12 | func TestNewNetwork(t *testing.T) { 13 | rand.Seed(commonSeed) 14 | net := NewNetwork(&Config{ 15 | inputs: 5, 16 | InitialConnectionRatio: 0.5, 17 | sharedWeight: 0.5, 18 | }) 19 | fmt.Println(net) 20 | for i, n := range net.AllNodes { 21 | if NeuronIndex(i) != n.neuronIndex { 22 | t.Fail() 23 | } 24 | } 25 | } 26 | 27 | func TestGet(t *testing.T) { 28 | rand.Seed(commonSeed) 29 | net := NewNetwork() 30 | fmt.Println(net) 31 | fmt.Println(net.Get(0)) 32 | if net.OutputNode != 0 { 33 | t.Fail() 34 | } 35 | } 36 | 37 | func TestIsInput(t *testing.T) { 38 | rand.Seed(commonSeed) 39 | net := NewNetwork(&Config{ 40 | inputs: 5, 41 | InitialConnectionRatio: 0.5, 42 | sharedWeight: 0.5, 43 | }) 44 | if !net.IsInput(1) { 45 | t.Fail() 46 | } 47 | if net.IsInput(0) { 48 | t.Fail() 49 | } 50 | } 51 | 52 | func TestForEachConnected(t *testing.T) { 53 | rand.Seed(commonSeed) 54 | net := 
NewNetwork(&Config{ 55 | inputs: 5, 56 | InitialConnectionRatio: 0.5, 57 | sharedWeight: 0.5, 58 | }) 59 | net.ForEachConnected(func(n *Neuron) { 60 | fmt.Printf("%d: %s, distance from output node: %d\n", n.neuronIndex, n, n.distanceFromOutputNode) 61 | }) 62 | } 63 | 64 | func TestAll(t *testing.T) { 65 | rand.Seed(commonSeed) 66 | net := NewNetwork(&Config{ 67 | inputs: 5, 68 | InitialConnectionRatio: 0.7, 69 | sharedWeight: 0.5, 70 | }) 71 | for _, node := range net.All() { 72 | fmt.Println(node) 73 | } 74 | } 75 | 76 | func TestEvaluate2(t *testing.T) { 77 | rand.Seed(commonSeed) 78 | net := NewNetwork(&Config{ 79 | inputs: 5, 80 | InitialConnectionRatio: 0.7, 81 | sharedWeight: 0.5, 82 | }) 83 | _ = net.Evaluate([]float64{0.1, 0.2, 0.3, 0.4, 0.5}) 84 | } 85 | 86 | func TestInsertNode(t *testing.T) { 87 | rand.Seed(commonSeed) 88 | net := NewNetwork(&Config{ 89 | inputs: 5, 90 | InitialConnectionRatio: 0.5, 91 | sharedWeight: 0.5, 92 | }) 93 | _, newNeuronIndex := net.NewNeuron() 94 | if err := net.InsertNode(0, 2, newNeuronIndex); err != nil { 95 | t.Error(err) 96 | } 97 | _ = net.Evaluate([]float64{0.1, 0.2, 0.3, 0.4, 0.5}) 98 | } 99 | 100 | func TestAddConnection(t *testing.T) { 101 | rand.Seed(commonSeed) 102 | net := NewNetwork(&Config{ 103 | inputs: 5, 104 | InitialConnectionRatio: 0.5, 105 | }) 106 | _, newNeuronIndex := net.NewNeuron() 107 | if err := net.InsertNode(net.OutputNode, 2, newNeuronIndex); err != nil { 108 | t.Error(err) 109 | } 110 | // Add a connection from 1 to the new neuron. 111 | // This is the same as making the new neuron have an additional input neuron: index 1 112 | if err := net.AddConnection(1, newNeuronIndex); err != nil { 113 | t.Error(err) 114 | } 115 | // Add a connection from the output node to the output node. Should fail. 
116 | if err := net.AddConnection(net.OutputNode, net.OutputNode); err == nil { 117 | t.Fail() 118 | } 119 | // Adding a made-up index should fail as well 120 | if err := net.AddConnection(net.OutputNode, 999); err == nil { 121 | t.Fail() 122 | } 123 | } 124 | 125 | func TestRandomizeActivationFunctionForRandomNeuron(t *testing.T) { 126 | rand.Seed(commonSeed) 127 | net := NewNetwork(&Config{ 128 | inputs: 5, 129 | InitialConnectionRatio: 0.5, 130 | }) 131 | net.RandomizeActivationFunctionForRandomNeuron() 132 | } 133 | 134 | func TestNetworkString(t *testing.T) { 135 | rand.Seed(commonSeed) 136 | net := NewNetwork(&Config{ 137 | inputs: 5, 138 | InitialConnectionRatio: 0.5, 139 | }) 140 | _ = net.String() 141 | } 142 | 143 | func TestSetWeight(t *testing.T) { 144 | net := NewNetwork() 145 | net.SetWeight(0.1234) 146 | if net.Weight != 0.1234 { 147 | t.Fail() 148 | } 149 | } 150 | 151 | func TestComplexity(t *testing.T) { 152 | rand.Seed(commonSeed) 153 | net := NewNetwork(&Config{ 154 | inputs: 5, 155 | InitialConnectionRatio: 0.0, 156 | }) 157 | // The complexity will vary, because the performance varies when 158 | // estimating the complexity of each function. 159 | // But the complexity compared between networks should still hold. 
160 | firstComplexity := net.Complexity() 161 | // Adding a connection increases the complexity 162 | net.AddConnection(0, 1) 163 | // The complexity for the network, after a connection has been added 164 | secondComplexity := net.Complexity() 165 | if firstComplexity >= secondComplexity { 166 | t.Fail() 167 | } 168 | } 169 | 170 | func ExampleNetwork_InsertNode() { 171 | rand.Seed(commonSeed) 172 | net := NewNetwork(&Config{ 173 | inputs: 3, 174 | InitialConnectionRatio: 1.0, 175 | }) 176 | fmt.Println("Before insertion:") 177 | fmt.Println(net) 178 | _, nodeIndex := net.NewNeuron() 179 | err := net.InsertNode(0, 1, nodeIndex) 180 | if err != nil { 181 | fmt.Println("error: " + err.Error()) 182 | } 183 | fmt.Println("After insertion:") 184 | fmt.Println(net) 185 | // Output: 186 | // Before insertion: 187 | // Network (4 nodes, 3 input nodes, 1 output node) 188 | // Connected inputs to output node: 3 189 | // Output node ID 0 has these input connections: [1 2 3] 190 | // Input node ID 1 has these input connections: [] 191 | // Input node ID 2 has these input connections: [] 192 | // Input node ID 3 has these input connections: [] 193 | // 194 | // After insertion: 195 | // Network (5 nodes, 3 input nodes, 1 output node) 196 | // Connected inputs to output node: 3 197 | // Output node ID 0 has these input connections: [2 3 4] 198 | // Input node ID 1 has these input connections: [] 199 | // Input node ID 2 has these input connections: [] 200 | // Input node ID 3 has these input connections: [] 201 | // Node ID 4 has these input connections: [1] 202 | } 203 | 204 | func TestLeftRight(t *testing.T) { 205 | rand.Seed(commonSeed) 206 | net := NewNetwork(&Config{ 207 | inputs: 3, 208 | InitialConnectionRatio: 1.0, 209 | }) 210 | net.AllNodes[1].ActivationFunction = Swish 211 | a, b, _ := net.LeftRight(0, 1) 212 | // output node to the right 213 | if a != 1 || b != 0 { 214 | t.Fail() 215 | } 216 | // output node to the right 217 | a, b, _ = net.LeftRight(1, 0) 218 | if a 
!= 1 || b != 0 { 219 | t.Fail() 220 | } 221 | _, nodeIndex := net.NewNeuron() 222 | err := net.InsertNode(0, 1, nodeIndex) 223 | if err != nil { 224 | t.Error(err) 225 | } 226 | a, b, _ = net.LeftRight(0, nodeIndex) 227 | // output node to the right 228 | if a != nodeIndex || b != 0 { 229 | t.Fail() 230 | } 231 | a, b, _ = net.LeftRight(nodeIndex, 0) 232 | // output node to the right 233 | if a != nodeIndex || b != 0 { 234 | t.Fail() 235 | } 236 | a, b, _ = net.LeftRight(1, nodeIndex) 237 | // Here, the new node should be to the right, since it's between node 1 and the output node 238 | if a != 1 || b != nodeIndex { 239 | t.Fail() 240 | } 241 | //net.WriteSVG("c.svg") 242 | fmt.Println(net) 243 | a, b, _ = net.LeftRight(nodeIndex, 1) 244 | if a != 1 || b != nodeIndex { 245 | t.Fail() 246 | } 247 | } 248 | 249 | func TestDepth(t *testing.T) { 250 | rand.Seed(commonSeed) 251 | net := NewNetwork(&Config{ 252 | inputs: 3, 253 | InitialConnectionRatio: 1.0, 254 | }) 255 | fmt.Println(net.Depth()) 256 | _, nodeIndex := net.NewBlankNeuron() 257 | _ = net.InsertNode(0, 1, nodeIndex) 258 | fmt.Println(net.Depth()) 259 | } 260 | 261 | func ExampleCombine() { 262 | ac := []NeuronIndex{0, 1, 2, 3, 4} 263 | bc := []NeuronIndex{5, 6, 7, 8, 9} 264 | fmt.Println(Combine(ac, bc)) 265 | // Output: 266 | // [0 1 2 3 4 5 6 7 8 9] 267 | } 268 | 269 | func TestGetRandomNeuron(t *testing.T) { 270 | rand.Seed(commonSeed) 271 | net := NewNetwork(&Config{ 272 | inputs: 5, 273 | InitialConnectionRatio: 1.0, 274 | }) 275 | stats := make(map[NeuronIndex]uint) 276 | for i := 0; i < 1000; i++ { 277 | ni := net.GetRandomNode() 278 | if _, ok := stats[ni]; !ok { 279 | stats[ni] = 0 280 | } else { 281 | stats[ni]++ 282 | } 283 | } 284 | fmt.Println(stats) 285 | // Check that the output node exists in the stats 286 | if _, ok := stats[0]; !ok { 287 | t.Fail() 288 | } 289 | } 290 | 291 | func TestGetRandomInputNode(t *testing.T) { 292 | rand.Seed(commonSeed) 293 | net := NewNetwork(&Config{ 294 | 
inputs: 5, 295 | InitialConnectionRatio: 1.0, 296 | }) 297 | stats := make(map[NeuronIndex]uint) 298 | for i := 0; i < 1000; i++ { 299 | ni := net.GetRandomInputNode() 300 | if _, ok := stats[ni]; !ok { 301 | stats[ni] = 0 302 | } else { 303 | stats[ni]++ 304 | } 305 | } 306 | fmt.Println(stats) 307 | // Check that the output node does not exist in the stats 308 | if _, ok := stats[0]; ok { 309 | t.Fail() 310 | } 311 | } 312 | 313 | func TestConnected(t *testing.T) { 314 | rand.Seed(commonSeed) 315 | net := NewNetwork(&Config{ 316 | inputs: 5, 317 | InitialConnectionRatio: 0.1, 318 | }) 319 | connected := net.Connected() 320 | if connected[0] != 0 || connected[1] != 2 { 321 | t.Fail() 322 | } 323 | } 324 | 325 | func TestUnconnected(t *testing.T) { 326 | rand.Seed(commonSeed) 327 | net := NewNetwork(&Config{ 328 | inputs: 5, 329 | InitialConnectionRatio: 0.5, 330 | }) 331 | unconnected := net.Unconnected() 332 | correct := []NeuronIndex{1, 3, 4} 333 | for i := 0; i < len(unconnected); i++ { 334 | if unconnected[i] != correct[i] { 335 | t.Fail() 336 | } 337 | } 338 | } 339 | 340 | func TestCopy(t *testing.T) { 341 | rand.Seed(commonSeed) 342 | net := NewNetwork(&Config{ 343 | inputs: 5, 344 | InitialConnectionRatio: 0.5, 345 | }) 346 | 347 | // Take a deep copy with the Copy() function 348 | net2 := net.Copy() 349 | // Modify net2 by inserting an unconnected neuron 350 | n := NewUnconnectedNeuron() 351 | net2.AllNodes[1] = *n 352 | // net and net2 should now be different, since net2 is a proper copy 353 | if net.String() == net2.String() { 354 | t.Fail() 355 | } 356 | 357 | // Take a shallow copy 358 | net3 := net 359 | // Modify net3 by inserting an unconnected neuron 360 | net3.AllNodes[1] = *n 361 | // net and net3 should still be the same, since net3 is just a shallow copy 362 | if net.String() != net3.String() { 363 | t.Fail() 364 | } 365 | } 366 | 367 | func TestForEachConnectedNodeIndex(t *testing.T) { 368 | rand.Seed(commonSeed) 369 | net := 
NewNetwork(&Config{ 370 | inputs: 5, 371 | InitialConnectionRatio: 0.5, 372 | }) 373 | lastNi := NeuronIndex(-1) 374 | net.ForEachConnectedNodeIndex(func(ni NeuronIndex) { 375 | fmt.Println(ni) 376 | lastNi = ni 377 | }) 378 | if lastNi != 5 { 379 | t.Fail() 380 | } 381 | } 382 | -------------------------------------------------------------------------------- /vendor/github.com/dave/jennifer/jen/hints.go: -------------------------------------------------------------------------------- 1 | // This file is generated - do not edit. 2 | 3 | package jen 4 | 5 | // standardLibraryHints contains package name hints 6 | var standardLibraryHints = map[string]string{ 7 | "archive/tar": "tar", 8 | "archive/zip": "zip", 9 | "bufio": "bufio", 10 | "bytes": "bytes", 11 | "compress/bzip2": "bzip2", 12 | "compress/flate": "flate", 13 | "compress/gzip": "gzip", 14 | "compress/lzw": "lzw", 15 | "compress/zlib": "zlib", 16 | "constraints": "constraints", 17 | "container/heap": "heap", 18 | "container/list": "list", 19 | "container/ring": "ring", 20 | "context": "context", 21 | "crypto": "crypto", 22 | "crypto/aes": "aes", 23 | "crypto/cipher": "cipher", 24 | "crypto/des": "des", 25 | "crypto/dsa": "dsa", 26 | "crypto/ecdsa": "ecdsa", 27 | "crypto/ed25519": "ed25519", 28 | "crypto/ed25519/internal/edwards25519": "edwards25519", 29 | "crypto/ed25519/internal/edwards25519/field": "field", 30 | "crypto/elliptic": "elliptic", 31 | "crypto/elliptic/internal/fiat": "fiat", 32 | "crypto/elliptic/internal/nistec": "nistec", 33 | "crypto/hmac": "hmac", 34 | "crypto/internal/randutil": "randutil", 35 | "crypto/internal/subtle": "subtle", 36 | "crypto/md5": "md5", 37 | "crypto/rand": "rand", 38 | "crypto/rc4": "rc4", 39 | "crypto/rsa": "rsa", 40 | "crypto/sha1": "sha1", 41 | "crypto/sha256": "sha256", 42 | "crypto/sha512": "sha512", 43 | "crypto/subtle": "subtle", 44 | "crypto/tls": "tls", 45 | "crypto/x509": "x509", 46 | "crypto/x509/internal/macos": "macOS", 47 | "crypto/x509/pkix": "pkix", 48 
| "database/sql": "sql", 49 | "database/sql/driver": "driver", 50 | "debug/buildinfo": "buildinfo", 51 | "debug/dwarf": "dwarf", 52 | "debug/elf": "elf", 53 | "debug/gosym": "gosym", 54 | "debug/macho": "macho", 55 | "debug/pe": "pe", 56 | "debug/plan9obj": "plan9obj", 57 | "embed": "embed", 58 | "embed/internal/embedtest": "embedtest", 59 | "encoding": "encoding", 60 | "encoding/ascii85": "ascii85", 61 | "encoding/asn1": "asn1", 62 | "encoding/base32": "base32", 63 | "encoding/base64": "base64", 64 | "encoding/binary": "binary", 65 | "encoding/csv": "csv", 66 | "encoding/gob": "gob", 67 | "encoding/hex": "hex", 68 | "encoding/json": "json", 69 | "encoding/pem": "pem", 70 | "encoding/xml": "xml", 71 | "errors": "errors", 72 | "expvar": "expvar", 73 | "flag": "flag", 74 | "fmt": "fmt", 75 | "go/ast": "ast", 76 | "go/build": "build", 77 | "go/build/constraint": "constraint", 78 | "go/constant": "constant", 79 | "go/doc": "doc", 80 | "go/format": "format", 81 | "go/importer": "importer", 82 | "go/internal/gccgoimporter": "gccgoimporter", 83 | "go/internal/gcimporter": "gcimporter", 84 | "go/internal/srcimporter": "srcimporter", 85 | "go/internal/typeparams": "typeparams", 86 | "go/parser": "parser", 87 | "go/printer": "printer", 88 | "go/scanner": "scanner", 89 | "go/token": "token", 90 | "go/types": "types", 91 | "hash": "hash", 92 | "hash/adler32": "adler32", 93 | "hash/crc32": "crc32", 94 | "hash/crc64": "crc64", 95 | "hash/fnv": "fnv", 96 | "hash/maphash": "maphash", 97 | "html": "html", 98 | "html/template": "template", 99 | "image": "image", 100 | "image/color": "color", 101 | "image/color/palette": "palette", 102 | "image/draw": "draw", 103 | "image/gif": "gif", 104 | "image/internal/imageutil": "imageutil", 105 | "image/jpeg": "jpeg", 106 | "image/png": "png", 107 | "index/suffixarray": "suffixarray", 108 | "internal/abi": "abi", 109 | "internal/buildcfg": "buildcfg", 110 | "internal/bytealg": "bytealg", 111 | "internal/cfg": "cfg", 112 | "internal/cpu": 
"cpu", 113 | "internal/execabs": "execabs", 114 | "internal/fmtsort": "fmtsort", 115 | "internal/fuzz": "fuzz", 116 | "internal/goarch": "goarch", 117 | "internal/godebug": "godebug", 118 | "internal/goexperiment": "goexperiment", 119 | "internal/goos": "goos", 120 | "internal/goroot": "goroot", 121 | "internal/goversion": "goversion", 122 | "internal/intern": "intern", 123 | "internal/itoa": "itoa", 124 | "internal/lazyregexp": "lazyregexp", 125 | "internal/lazytemplate": "lazytemplate", 126 | "internal/nettrace": "nettrace", 127 | "internal/obscuretestdata": "obscuretestdata", 128 | "internal/oserror": "oserror", 129 | "internal/poll": "poll", 130 | "internal/profile": "profile", 131 | "internal/race": "race", 132 | "internal/reflectlite": "reflectlite", 133 | "internal/singleflight": "singleflight", 134 | "internal/syscall/execenv": "execenv", 135 | "internal/syscall/unix": "unix", 136 | "internal/sysinfo": "sysinfo", 137 | "internal/testenv": "testenv", 138 | "internal/testlog": "testlog", 139 | "internal/trace": "trace", 140 | "internal/unsafeheader": "unsafeheader", 141 | "internal/xcoff": "xcoff", 142 | "io": "io", 143 | "io/fs": "fs", 144 | "io/ioutil": "ioutil", 145 | "log": "log", 146 | "log/syslog": "syslog", 147 | "math": "math", 148 | "math/big": "big", 149 | "math/bits": "bits", 150 | "math/cmplx": "cmplx", 151 | "math/rand": "rand", 152 | "mime": "mime", 153 | "mime/multipart": "multipart", 154 | "mime/quotedprintable": "quotedprintable", 155 | "net": "net", 156 | "net/http": "http", 157 | "net/http/cgi": "cgi", 158 | "net/http/cookiejar": "cookiejar", 159 | "net/http/fcgi": "fcgi", 160 | "net/http/httptest": "httptest", 161 | "net/http/httptrace": "httptrace", 162 | "net/http/httputil": "httputil", 163 | "net/http/internal": "internal", 164 | "net/http/internal/ascii": "ascii", 165 | "net/http/internal/testcert": "testcert", 166 | "net/http/pprof": "pprof", 167 | "net/internal/socktest": "socktest", 168 | "net/mail": "mail", 169 | "net/netip": 
"netip", 170 | "net/rpc": "rpc", 171 | "net/rpc/jsonrpc": "jsonrpc", 172 | "net/smtp": "smtp", 173 | "net/textproto": "textproto", 174 | "net/url": "url", 175 | "os": "os", 176 | "os/exec": "exec", 177 | "os/exec/internal/fdtest": "fdtest", 178 | "os/signal": "signal", 179 | "os/signal/internal/pty": "pty", 180 | "os/user": "user", 181 | "path": "path", 182 | "path/filepath": "filepath", 183 | "plugin": "plugin", 184 | "reflect": "reflect", 185 | "reflect/internal/example1": "example1", 186 | "reflect/internal/example2": "example2", 187 | "regexp": "regexp", 188 | "regexp/syntax": "syntax", 189 | "runtime": "runtime", 190 | "runtime/cgo": "cgo", 191 | "runtime/debug": "debug", 192 | "runtime/internal/atomic": "atomic", 193 | "runtime/internal/math": "math", 194 | "runtime/internal/sys": "sys", 195 | "runtime/metrics": "metrics", 196 | "runtime/pprof": "pprof", 197 | "runtime/race": "race", 198 | "runtime/trace": "trace", 199 | "sort": "sort", 200 | "strconv": "strconv", 201 | "strings": "strings", 202 | "sync": "sync", 203 | "sync/atomic": "atomic", 204 | "syscall": "syscall", 205 | "testing": "testing", 206 | "testing/fstest": "fstest", 207 | "testing/internal/testdeps": "testdeps", 208 | "testing/iotest": "iotest", 209 | "testing/quick": "quick", 210 | "text/scanner": "scanner", 211 | "text/tabwriter": "tabwriter", 212 | "text/template": "template", 213 | "text/template/parse": "parse", 214 | "time": "time", 215 | "time/tzdata": "tzdata", 216 | "unicode": "unicode", 217 | "unicode/utf16": "utf16", 218 | "unicode/utf8": "utf8", 219 | "unsafe": "unsafe", 220 | } 221 | -------------------------------------------------------------------------------- /neuron.go: -------------------------------------------------------------------------------- 1 | package wann 2 | 3 | import ( 4 | "errors" 5 | "fmt" 6 | "math/rand" 7 | 8 | "github.com/dave/jennifer/jen" 9 | ) 10 | 11 | // Neuron is a list of input-neurons, and an activation function. 
12 | type Neuron struct { 13 | Net *Network 14 | InputNodes []NeuronIndex // pointers to other neurons 15 | ActivationFunction ActivationFunctionIndex 16 | Value *float64 17 | distanceFromOutputNode int // Used when traversing nodes and drawing diagrams 18 | neuronIndex NeuronIndex 19 | } 20 | 21 | // NewBlankNeuron creates a new Neuron, with the Step activation function as the default 22 | func (net *Network) NewBlankNeuron() (*Neuron, NeuronIndex) { 23 | // Pre-allocate room for 16 connections and use Linear as the default activation function 24 | neuron := Neuron{Net: net, InputNodes: make([]NeuronIndex, 0, 16), ActivationFunction: Swish} 25 | neuron.neuronIndex = NeuronIndex(len(net.AllNodes)) 26 | net.AllNodes = append(net.AllNodes, neuron) 27 | return &neuron, neuron.neuronIndex 28 | } 29 | 30 | // NewNeuron creates a new *Neuron, with a randomly chosen activation function 31 | func (net *Network) NewNeuron() (*Neuron, NeuronIndex) { 32 | chosenActivationFunctionIndex := ActivationFunctionIndex(rand.Intn(len(ActivationFunctions))) 33 | inputNodes := make([]NeuronIndex, 0, 16) 34 | neuron := Neuron{ 35 | Net: net, 36 | InputNodes: inputNodes, 37 | ActivationFunction: chosenActivationFunctionIndex, 38 | } 39 | // The length of net.AllNodes is what will be the last index 40 | neuronIndex := NeuronIndex(len(net.AllNodes)) 41 | // Assign the neuron index in the net to the neuron 42 | neuron.neuronIndex = neuronIndex 43 | // Add this neuron to the net 44 | net.AllNodes = append(net.AllNodes, neuron) 45 | return &neuron, neuronIndex 46 | } 47 | 48 | // NewUnconnectedNeuron returns a new unconnected neuron with neuronIndex -1 and net pointer set to nil 49 | func NewUnconnectedNeuron() *Neuron { 50 | // Pre-allocate room for 16 connections and use Linear as the default activation function 51 | neuron := Neuron{Net: nil, InputNodes: make([]NeuronIndex, 0, 16), ActivationFunction: Linear} 52 | neuron.neuronIndex = -1 53 | return &neuron 54 | } 55 | 56 | // Connect this 
neuron to a network, overwriting any existing connections. 57 | // This will also clear any input nodes to this neuron, since the net is different. 58 | // TODO: Find the input nodes from the neuron.Net, save those and re-assign if there are matches? 59 | func (neuron *Neuron) Connect(net *Network) { 60 | neuron.InputNodes = []NeuronIndex{} 61 | neuron.Net = net 62 | for ni := range net.AllNodes { 63 | // Check if this network already has a pointer to this neuron 64 | if &net.AllNodes[ni] == neuron { 65 | // Yes, assign the index 66 | neuron.neuronIndex = NeuronIndex(ni) 67 | // All good, bail 68 | return 69 | } 70 | } 71 | // The neuron was not found in the network 72 | // Find what will be the last index in net.AllNodes 73 | neuronIndex := len(net.AllNodes) 74 | // Add this neuron to the network 75 | net.AllNodes = append(net.AllNodes, *neuron) 76 | // Assign the index 77 | net.AllNodes[neuronIndex].neuronIndex = NeuronIndex(neuronIndex) 78 | } 79 | 80 | // RandomizeActivationFunction will choose a random activation function for this neuron 81 | func (neuron *Neuron) RandomizeActivationFunction() { 82 | chosenActivationFunctionIndex := ActivationFunctionIndex(rand.Intn(len(ActivationFunctions))) 83 | neuron.ActivationFunction = chosenActivationFunctionIndex 84 | } 85 | 86 | // SetValue can be used for setting a value for this neuron instead of using input neutrons. 87 | // This changes how the Evaluation function behaves. 88 | func (neuron *Neuron) SetValue(x float64) { 89 | neuron.Value = &x 90 | } 91 | 92 | // HasInput checks if the given neuron is an input neuron to this one 93 | func (neuron *Neuron) HasInput(e NeuronIndex) bool { 94 | for _, ni := range neuron.InputNodes { 95 | if ni == e { 96 | return true 97 | } 98 | } 99 | return false 100 | } 101 | 102 | // FindInput checks if the given neuron is an input neuron to this one, 103 | // and also returns the index to InputNeurons, if found. 
104 | func (neuron *Neuron) FindInput(e NeuronIndex) (int, bool) { 105 | for i, n := range neuron.InputNodes { 106 | if n == e { 107 | return i, true 108 | } 109 | } 110 | return -1, false 111 | } 112 | 113 | // Is check if the given NeuronIndex points to this neuron 114 | func (neuron *Neuron) Is(e NeuronIndex) bool { 115 | return neuron.neuronIndex == e 116 | } 117 | 118 | // AddInput will add an input neuron 119 | func (neuron *Neuron) AddInput(ni NeuronIndex) error { 120 | if neuron.Is(ni) { 121 | return errors.New("adding a neuron as input to itself") 122 | } 123 | if neuron.HasInput(ni) { 124 | return errors.New("neuron already exists") 125 | } 126 | neuron.InputNodes = append(neuron.InputNodes, ni) 127 | 128 | return nil 129 | } 130 | 131 | // AddInputNeuron both adds a neuron to this network (if needed) and also 132 | // adds its neuron index to the neuron.InputNeurons 133 | func (neuron *Neuron) AddInputNeuron(n *Neuron) error { 134 | // If n.neuronIndex is known to this network, just add the NeuronIndex to neuron.InputNeurons 135 | if neuron.Net.Exists(n.neuronIndex) { 136 | return neuron.AddInput(n.neuronIndex) 137 | } 138 | // If not, add this neuron to the network first 139 | node := *n 140 | node.neuronIndex = NeuronIndex(len(neuron.Net.AllNodes)) 141 | neuron.Net.AllNodes = append(neuron.Net.AllNodes, node) 142 | return neuron.AddInput(n.neuronIndex) 143 | } 144 | 145 | // RemoveInput will remove an input neuron 146 | func (neuron *Neuron) RemoveInput(e NeuronIndex) error { 147 | if i, found := neuron.FindInput(e); found { 148 | // Found it, remove the neuron at index i 149 | neuron.InputNodes = append(neuron.InputNodes[:i], neuron.InputNodes[i+1:]...) 
150 | return nil 151 | } 152 | return errors.New("neuron does not exist") 153 | } 154 | 155 | // Exists checks if the given NeuronIndex exists in this Network 156 | func (net *Network) Exists(ni NeuronIndex) bool { 157 | for i := range net.AllNodes { 158 | neuronIndex := NeuronIndex(i) 159 | if neuronIndex == ni { 160 | return true 161 | } 162 | } 163 | return false 164 | } 165 | 166 | // InputNeuronsAreGood checks if all input neurons of this neuron exists in neuron.Net 167 | func (neuron *Neuron) InputNeuronsAreGood() bool { 168 | for _, inputNeuronIndex := range neuron.InputNodes { 169 | if !neuron.Net.Exists(inputNeuronIndex) { 170 | return false 171 | } 172 | } 173 | return true 174 | } 175 | 176 | // evaluate will return a weighted sum of the input nodes, 177 | // using the .Value field if it is set and no input nodes are available. 178 | // returns true if the maximum number of evaluation loops is reached 179 | func (neuron *Neuron) evaluate(weight float64, maxEvaluationLoops *int) (float64, bool) { 180 | if *maxEvaluationLoops <= 0 { 181 | return 0.0, true 182 | } 183 | // Assume this is the Output neuron, recursively evaluating the result 184 | // For each input neuron, evaluate them 185 | summed := 0.0 186 | counter := 0 187 | 188 | for _, inputNeuronIndex := range neuron.InputNodes { 189 | // Let each input neuron do its own evauluation, using the given weight 190 | (*maxEvaluationLoops)-- 191 | // TODO: Figure out exactly why this one kicks in (and if it matters) 192 | // It only seems to kick in during "go test" and not in evolve/main.go 193 | if int(inputNeuronIndex) >= len(neuron.Net.AllNodes) { 194 | continue 195 | //panic("TOO HIGH INPUT NEURON INDEX") 196 | } 197 | result, stopNow := neuron.Net.AllNodes[inputNeuronIndex].evaluate(weight, maxEvaluationLoops) 198 | summed += result * weight 199 | counter++ 200 | if stopNow || (*maxEvaluationLoops < 0) { 201 | break 202 | } 203 | } 204 | // No input neurons. 
Use the .Value field if it's not nil and this is not the output node 205 | if counter == 0 && neuron.Value != nil && !neuron.IsOutput() { 206 | return *(neuron.Value), false 207 | } 208 | // Return the averaged sum, or 0 209 | if counter == 0 { 210 | // This should never happen 211 | return 0.0, false 212 | } 213 | // This should run, also when this neuron is the output neuron 214 | f := neuron.GetActivationFunction() 215 | //fmt.Println(neuron.ActivationFunction.Name() + " = " + neuron.ActivationFunction.String()) 216 | // Run the input through the activation function 217 | // TODO: Does "retval := f(summed)"" perform better?, or the one that averages the sum first? 218 | //retval := f(summed / float64(counter)) 219 | retval := f(summed) 220 | return retval, false 221 | } 222 | 223 | // GetActivationFunction returns the activation function for this neuron 224 | func (neuron *Neuron) GetActivationFunction() func(float64) float64 { 225 | return ActivationFunctions[neuron.ActivationFunction] 226 | } 227 | 228 | // In checks if this neuron is in the given collection 229 | func (neuron *Neuron) In(collection []NeuronIndex) bool { 230 | for _, existingNodeIndex := range collection { 231 | if neuron.Is(existingNodeIndex) { 232 | return true 233 | } 234 | } 235 | return false 236 | } 237 | 238 | // IsInput returns true if this is an input node or not 239 | // Returns false if nil 240 | func (neuron *Neuron) IsInput() bool { 241 | if neuron.Net == nil { 242 | return false 243 | } 244 | return neuron.Net.IsInput(neuron.neuronIndex) 245 | } 246 | 247 | // IsOutput returns true if this is an output node or not 248 | // Returns false if nil 249 | func (neuron *Neuron) IsOutput() bool { 250 | if neuron.Net == nil { 251 | return false 252 | } 253 | return neuron.Net.OutputNode == neuron.neuronIndex 254 | } 255 | 256 | // Copy a Neuron to a new Neuron, and assign the pointer to the given network to .Net 257 | func (neuron Neuron) Copy(net *Network) Neuron { 258 | var newNeuron 
Neuron 259 | newNeuron.Net = net 260 | newNeuron.InputNodes = neuron.InputNodes 261 | newNeuron.ActivationFunction = neuron.ActivationFunction 262 | newNeuron.Value = neuron.Value 263 | newNeuron.distanceFromOutputNode = neuron.distanceFromOutputNode 264 | newNeuron.neuronIndex = neuron.neuronIndex 265 | return newNeuron 266 | } 267 | 268 | // String will return a string containing both the pointer address and the number of input neurons 269 | func (neuron *Neuron) String() string { 270 | nodeType := " Node" 271 | if neuron.IsInput() { 272 | nodeType = " Input node" 273 | } else if neuron.IsOutput() { 274 | nodeType = "Output node" 275 | } 276 | return fmt.Sprintf("%s ID %d has these input connections: %v", nodeType, neuron.neuronIndex, neuron.InputNodes) 277 | } 278 | 279 | // InputStatement returns a statement like "inputData[0]", if this node is a network input node 280 | func (neuron *Neuron) InputStatement() (*jen.Statement, error) { 281 | // If this node is a network input node, return a statement representing this input, 282 | // like "inputData[0]" 283 | if !neuron.IsInput() { 284 | return jen.Empty(), errors.New(" not an input node") 285 | } 286 | for i, ni := range neuron.Net.InputNodes { 287 | if ni == neuron.neuronIndex { 288 | // This index in the neuron.NetInputNodes is i 289 | return jen.Id("inputData").Index(jen.Lit(i)), nil 290 | } 291 | } 292 | // Not found! 293 | return jen.Empty(), errors.New("not an input node for the associated network") 294 | } 295 | -------------------------------------------------------------------------------- /evolve.go: -------------------------------------------------------------------------------- 1 | package wann 2 | 3 | import ( 4 | "errors" 5 | "fmt" 6 | "math/rand" 7 | "strconv" 8 | ) 9 | 10 | // ScorePopulation evaluates a population, given a slice of input numbers. 11 | // It returns a map with scores, together with the sum of scores. 
12 | func ScorePopulation(population []*Network, weight float64, inputData [][]float64, incorrectOutputMultipliers []float64) (map[int]float64, float64) { 13 | 14 | scoreMap := make(map[int]float64) 15 | scoreSum := 0.0 16 | 17 | for i := 0; i < len(population); i++ { 18 | net := population[i] 19 | 20 | if len(net.AllNodes[net.OutputNode].InputNodes) == 0 { 21 | // The output node has no input nodes, not great 22 | scoreMap[i] = 0.0 23 | continue 24 | } 25 | 26 | net.SetWeight(weight) 27 | 28 | // Evaluate all the input data examples for this network 29 | result := 0.0 30 | for i := 0; i < len(inputData); i++ { 31 | result += net.Evaluate(inputData[i]) * incorrectOutputMultipliers[i] 32 | } 33 | 34 | // The score is how well the network is doing, divided by the network complexity rating 35 | score := result / net.Complexity() 36 | 37 | scoreSum += score 38 | scoreMap[i] = score 39 | } 40 | return scoreMap, scoreSum 41 | } 42 | 43 | // Modify the network using one of the three methods outlined in the paper: 44 | // * Insert node 45 | // * Add connection 46 | // * Change activation function 47 | func (net *Network) Modify(maxIterations int) { 48 | 49 | // Use method 0, 1 or 2 50 | method := rand.Intn(3) // up to and not including 3 51 | 52 | // Perform a modfification, using one of the three methods outlined in the paper 53 | switch method { 54 | case 0: 55 | // Insert a node, replacing a randomly chosen existing connection 56 | counter := 0 57 | for !net.InsertRandomNode() { 58 | counter++ 59 | if maxIterations > 0 && counter > maxIterations { 60 | break 61 | } 62 | } 63 | case 1: 64 | nodeA, nodeB := net.GetRandomNode(), net.GetRandomNode() 65 | // Continue finding random neurons until they work out or until maxIterations is reached 66 | // Create a new connection 67 | counter := 0 68 | for net.AddConnection(nodeA, nodeB) != nil { 69 | nodeA, nodeB = net.GetRandomNode(), net.GetRandomNode() 70 | counter++ 71 | if maxIterations > 0 && counter > maxIterations { 72 | 
// Could not add a connection. The possibilities for connections might be saturated. 73 | return 74 | } 75 | } 76 | case 2: 77 | // Change the activation function to a randomly selected one 78 | net.RandomizeActivationFunctionForRandomNeuron() 79 | default: 80 | panic("implementation error: invalid method number: " + strconv.Itoa(method)) 81 | } 82 | } 83 | 84 | // Complexity measures the network complexity 85 | // Will return 1.0 at a minimum 86 | func (net *Network) Complexity() float64 { 87 | 88 | // TODO: These two constants really affect the results. Place them in the Config struct instead. 89 | 90 | // How much should the function complexity matter in relation to the number of connected nodes? 91 | const functionComplexityMultiplier = 1.0 92 | 93 | // Weight the number of connected nodes 94 | const numberOfNodesMultiplier = 2.0 95 | 96 | // Number of input nodes connected to the output node multiplier 97 | const outputNodeInputNodesMultiplier = 3.0 98 | 99 | activationFunctionComplexity := 0.0 100 | // Sum the complexity of all activation functions. 101 | // This penalizes both slow activation functions and 102 | // unconnected nodes. 103 | for _, n := range net.AllNodes { 104 | if n.Value == nil { 105 | activationFunctionComplexity += ComplexityEstimate[n.ActivationFunction] 106 | } 107 | } 108 | activationFunctionComplexity *= functionComplexityMultiplier 109 | // The number of connected nodes should also carry some weight 110 | connectedNodes := float64(len(net.Connected())) * numberOfNodesMultiplier 111 | // The number of input nodes to the output node 112 | outputNodeComplexity := float64(len(net.AllNodes[net.OutputNode].InputNodes)) * outputNodeInputNodesMultiplier 113 | // This must always be larger than 0, to avoid divide by zero later 114 | return connectedNodes + activationFunctionComplexity + outputNodeComplexity + 1.0 115 | } 116 | 117 | // Evolve evolves a neural network, given a slice of training data and a slice of correct output values. 
118 | // Will overwrite config.Inputs. 119 | func (config *Config) Evolve(inputData [][]float64, incorrectOutputMultipliers []float64) (*Network, error) { 120 | 121 | // TODO: If the config.initialConnectionRatio field is too low (0.0, for instance), then this function will fail. 122 | // Return with an error if none of the networks in a population has any connections left, then get rid of the "no improvement counter". 123 | 124 | // Initialize, if needed 125 | if !config.initialized { 126 | config.Init() 127 | } 128 | 129 | inputLength := len(inputData) 130 | if inputLength == 0 { 131 | return nil, errors.New("no input data") 132 | } 133 | 134 | const maxModificationInterationsWhenMutating = 10 135 | 136 | // incorrectOutputMultipliers := make([]float64, len(correctOutputMultipliers)) 137 | // for i := range correctOutputMultipliers { 138 | // // Convert from having 0..1 for meaning from incorrect to correct, to -1..1 to mean the same 139 | // incorrectOutputMultipliers[i] = correctOutputMultipliers[i]*2.0 - 1.0 140 | // // Convert from having 0..1 for meaning from incorrect to correct, to 1..0 to mean the same 141 | // //incorrectOutputMultipliers[i] = -correctOutputMultipliers[i] + 1.0 142 | // } 143 | 144 | if len(incorrectOutputMultipliers) == 1 && inputLength != 1 { 145 | // Assume the first slice of floats in the input data is the correct and that the rest are examples of being wrong 146 | for i := 1; i < inputLength; i++ { 147 | incorrectOutputMultipliers = append(incorrectOutputMultipliers, -1.0) 148 | } 149 | } else if inputLength != len(incorrectOutputMultipliers) { 150 | // Assume that the list of correct output multipliers should match the length of the float64 slices in inputData 151 | return nil, errors.New("the length of the input data and the slice of output multipliers differs") 152 | } 153 | 154 | config.inputs = len(inputData[0]) 155 | 156 | population := make([]*Network, config.PopulationSize) 157 | 158 | // Initialize the population 159 | for 
i := 0; i < config.PopulationSize; i++ { 160 | n := NewNetwork(config) 161 | population[i] = &n 162 | population[i].UpdateNetworkPointers() 163 | } 164 | 165 | var ( 166 | bestNetwork *Network 167 | 168 | // Keep track of the best scores 169 | bestScore float64 170 | lastBestScore float64 171 | 172 | noImprovementCounter int // Counts how many times the best score has been stagnant 173 | 174 | // Keep track of the average scores 175 | averageScore float64 176 | 177 | // Keep track of the worst scores 178 | worstScore float64 179 | ) 180 | 181 | if config.Verbose { 182 | fmt.Printf("Starting evolution with population size %d, for %d generations.\n", config.PopulationSize, config.Generations) 183 | } 184 | 185 | // For each generation, evaluate and modify the networks 186 | for j := 0; j < config.Generations; j++ { 187 | 188 | bestNetwork = nil 189 | 190 | // Initialize the scores with unlikely values 191 | // TODO: Use the first network in the population for initializing these instead 192 | first := true 193 | 194 | // Random weight from -2.0 to 2.0 195 | w := rand.Float64() 196 | 197 | // The scores for this generation (using a random shared weight within ScorePopulation). 198 | // CorrectOutputMultipliers gives weight to the "correct" or "wrong" results, with the same index as the inputData 199 | // Score each network in the population. 
200 | scoreMap, scoreSum := ScorePopulation(population, w, inputData, incorrectOutputMultipliers) 201 | 202 | // Sort by score 203 | scoreList := SortByValue(scoreMap) 204 | 205 | // Handle the best score stats 206 | if first { 207 | lastBestScore = 0.0 208 | bestScore = scoreList[0].Value 209 | worstScore = scoreList[len(scoreList)-1].Value 210 | bestNetwork = population[scoreList[0].Key] 211 | bestNetwork.SetWeight(w) 212 | first = false 213 | } else { 214 | lastBestScore = bestScore 215 | if scoreList[0].Value > bestScore { 216 | bestScore = scoreList[0].Value 217 | } 218 | } 219 | if bestScore > lastBestScore { 220 | bestNetwork = population[scoreList[0].Key] 221 | bestNetwork.SetWeight(w) 222 | noImprovementCounter = 0 223 | } else { 224 | noImprovementCounter++ 225 | } 226 | 227 | // Handle the average score stats 228 | averageScore = scoreSum / float64(config.PopulationSize) 229 | 230 | // Handle the worst score stats 231 | if scoreList[len(scoreList)-1].Value < worstScore { 232 | worstScore = scoreList[len(scoreList)-1].Value 233 | } 234 | 235 | if bestNetwork == nil { 236 | panic("implementation error: no best network") 237 | } 238 | 239 | if config.Verbose { 240 | fmt.Printf("[generation %d] worst score = %f, average score = %f, best score = %f\n", j, worstScore, averageScore, bestScore) 241 | //fmt.Printf("[generation %d] worst score = %f, average score = %f, best score = %f, no improvement counter for this generation = %d\n", j, worstScore, averageScore, bestScore, noImprovementCounter) 242 | if noImprovementCounter > 0 { 243 | fmt.Printf("No improvement in the best score for the last %d generations\n", noImprovementCounter) 244 | } 245 | } 246 | 247 | // Only keep the best 7% 248 | bestFractionCountdown := int(float64(len(population)) * 0.07) 249 | 250 | goodNetworks := make([]*Network, 0, bestFractionCountdown) 251 | 252 | // Now loop over all networks, sorted by score (descending order) 253 | // p.Key is the network index 254 | // p.Value is the 
network score 255 | for _, p := range scoreList { 256 | networkIndex := p.Key 257 | if bestFractionCountdown > 0 { 258 | bestFractionCountdown-- 259 | // In the best third of the networks 260 | goodNetworks = append(goodNetworks, population[networkIndex]) 261 | continue 262 | } 263 | // // If there has not been any improvement to the best score lately, randomize the bad half 264 | // if noImprovementCounter > 100 { 265 | // n := NewNetwork(config) 266 | // population[networkIndex] = &n 267 | // continue 268 | // } 269 | randomGoodNetwork := goodNetworks[rand.Intn(len(goodNetworks))] 270 | randomGoodNetworkCopy := randomGoodNetwork.Copy() 271 | randomGoodNetworkCopy.Modify(maxModificationInterationsWhenMutating) 272 | // Replace the "bad" network with the modified copy of a "good" one 273 | // It's important that this is a pointer to a Network and not 274 | // a bare Network, so that the node .Net pointers are correct. 275 | population[networkIndex] = randomGoodNetworkCopy 276 | } 277 | // if noImprovementCounter > 100 { 278 | // noImprovementCounter = 0 279 | // } 280 | } 281 | if config.Verbose { 282 | fmt.Printf("[all time best network, random weight ] weight=%f score=%f\n", bestNetwork.Weight, bestScore) 283 | } 284 | 285 | // Now find the best weight for the best network, using a population of 1 286 | // and a step size of 0.0001 for the weight 287 | population = []*Network{bestNetwork} 288 | bestWeight := -2.0 289 | for w := -2.0; w <= 2.0; w += 0.0001 { 290 | scoreMap, _ := ScorePopulation(population, w, inputData, incorrectOutputMultipliers) 291 | // Handle the best score stats 292 | if scoreMap[0] > bestScore { 293 | bestScore = scoreMap[0] 294 | population[0].SetWeight(w) 295 | bestWeight = w 296 | } 297 | } 298 | 299 | // Check if the best network is nil, just in case 300 | if bestNetwork == nil { 301 | return nil, errors.New("the total best network is nil") 302 | } 303 | 304 | // Save the best weight for the network 305 | 
bestNetwork.SetWeight(bestWeight) 306 | 307 | if config.Verbose { 308 | fmt.Printf("[all time best network, optimal weight ] weight=%f best score=%f\n", bestNetwork.Weight, bestScore) 309 | } 310 | 311 | return bestNetwork, nil 312 | } 313 | -------------------------------------------------------------------------------- /network.go: -------------------------------------------------------------------------------- 1 | package wann 2 | 3 | import ( 4 | "errors" 5 | "fmt" 6 | "math/rand" 7 | "strconv" 8 | "strings" 9 | ) 10 | 11 | // NeuronIndex is an index into the AllNodes slice 12 | type NeuronIndex int 13 | 14 | // Network is a collection of nodes, an output node and a shared weight. 15 | type Network struct { 16 | AllNodes []Neuron // Storing the actual neurons 17 | InputNodes []NeuronIndex // Pointers to the input nodes 18 | OutputNode NeuronIndex // Pointer to the output node 19 | Weight float64 // Shared weight 20 | } 21 | 22 | // NewNetwork creates a new minimal network with n input nodes and ratio of r connections. 23 | // Passing "nil" as an argument is supported. 24 | func NewNetwork(cs ...*Config) Network { 25 | c := &Config{} 26 | // If a single non-nil *Config struct is given, use that 27 | if len(cs) == 1 && cs[0] != nil { 28 | c = cs[0] 29 | } 30 | n := c.inputs 31 | r := c.InitialConnectionRatio 32 | w := c.sharedWeight 33 | // Create a new network that has one node, the output node 34 | outputNodeIndex := NeuronIndex(0) 35 | net := Network{make([]Neuron, 0, n+1), make([]NeuronIndex, n), outputNodeIndex, w} 36 | outputNode, outputNodeIndex := net.NewNeuron() 37 | net.OutputNode = outputNodeIndex 38 | 39 | // Initialize n input nodes that all are inputs to the one output node. 
40 | for i := 0; i < n; i++ { 41 | // Add a new input node 42 | _, nodeIndex := net.NewNeuron() 43 | 44 | // Register the input node index in the input node NeuronIndex slice 45 | net.InputNodes[i] = nodeIndex 46 | 47 | // Make connections for all nodes where a random number between 0 and 1 are larger than r 48 | if r >= rand.Float64() { 49 | if err := outputNode.AddInput(nodeIndex); err != nil { 50 | panic(err) 51 | } 52 | } 53 | } 54 | 55 | // Store the modified output node 56 | net.AllNodes[outputNodeIndex] = *outputNode 57 | 58 | return net 59 | } 60 | 61 | // Get returns a pointer to a neuron, based on the given NeuronIndex 62 | func (net *Network) Get(i NeuronIndex) *Neuron { 63 | return &(net.AllNodes[i]) 64 | } 65 | 66 | // IsInput checks if the given node is an input node 67 | func (net *Network) IsInput(ni NeuronIndex) bool { 68 | for _, inputNodeIndex := range net.InputNodes { 69 | if ni == inputNodeIndex { 70 | return true 71 | } 72 | } 73 | return false 74 | } 75 | 76 | // 77 | // Operators for searching the space of network topologies 78 | // 79 | 80 | // InsertNode takes two neurons and inserts a third neuron between them 81 | // Assumes that a is the leftmost node and the b is the rightmost node. 82 | func (net *Network) InsertNode(a, b NeuronIndex, newNodeIndex NeuronIndex) error { 83 | // This is done by first checking that a is an input node to b, 84 | // then setting newNode to be an input node to b, 85 | // then setting a to be an input node to a. 
86 | 87 | // TODO: When a neuron is inserted, the input index 88 | 89 | // Sort the nodes by where they place in the diagram 90 | a, b, arbitrary := net.LeftRight(a, b) 91 | if arbitrary { 92 | if a == b { 93 | return errors.New("insert node: the a and b nodes are the same") 94 | } 95 | if net.IsInput(a) && net.IsInput(b) { 96 | return errors.New("insert node: both node a and b are input nodes") 97 | } 98 | return errors.New("insert node: arbitrary ordering when inserting a node") 99 | } 100 | 101 | // This should never happen 102 | if a == net.OutputNode { 103 | panic("implementation error: the leftmost node is an output node and this was not cought earlier") 104 | } 105 | 106 | // b already has a as an input (a -> b) 107 | if net.AllNodes[b].HasInput(a) { 108 | // Remove the old connection 109 | if err := net.AllNodes[b].RemoveInput(a); err != nil { 110 | return errors.New("error in InsertNode b.RemoveInput(a): " + err.Error()) 111 | } 112 | } 113 | 114 | // b already has newNodeIndex as an input (newIndex -> b) 115 | if net.AllNodes[b].HasInput(newNodeIndex) { 116 | // Remove the old connection 117 | if err := net.AllNodes[b].RemoveInput(a); err != nil { 118 | return errors.New("error in InsertNode b.RemoveInput(a): " + err.Error()) 119 | } 120 | } 121 | 122 | // Connect the new node to b 123 | if err := net.AllNodes[b].AddInput(newNodeIndex); err != nil { 124 | // This does not kick in, the problem must be in AddInput! 125 | return errors.New("error in InsertNode b.AddInput(newNode): " + err.Error()) 126 | } 127 | 128 | // Connect a to the new node 129 | if err := net.AllNodes[newNodeIndex].AddInput(a); err != nil { 130 | return errors.New("error in InsertNode newNode.AddInput(a): " + err.Error()) 131 | } 132 | 133 | // The situation should now be: a -> newNode -> b 134 | return nil 135 | } 136 | 137 | // AddConnection adds a connection from a to b. 138 | // The order is swapped if needed, then a is added as an input to b. 
139 | func (net *Network) AddConnection(a, b NeuronIndex) error { 140 | lastIndex := NeuronIndex(len(net.AllNodes) - 1) 141 | if a < 0 || a > lastIndex || b < 0 || b > lastIndex { 142 | return errors.New("index out of range") 143 | } 144 | // Sort the nodes by where they place in the diagram 145 | var arbitrary bool 146 | a, b, arbitrary = net.LeftRight(a, b) 147 | if arbitrary { 148 | if a == b { 149 | return errors.New("can't connect to self") 150 | } 151 | return errors.New("error: arbitrary ordering when adding a connection") 152 | } 153 | // a should not be an output node 154 | if a == net.OutputNode { 155 | return errors.New("error: will not insert a node between the output node and another node") 156 | } 157 | // b should not be an input node 158 | if net.IsInput(b) { 159 | return errors.New("error: b is an input node") 160 | } 161 | // same thing 162 | if net.AllNodes[b].Value != nil { 163 | return errors.New("error: b is an input node") 164 | } 165 | if net.AllNodes[b].HasInput(a) { 166 | return errors.New("error: input already exists") 167 | } 168 | return net.AllNodes[b].AddInput(a) 169 | } 170 | 171 | // RandomizeActivationFunctionForRandomNeuron randomizes the activation function for a randomly selected neuron 172 | func (net *Network) RandomizeActivationFunctionForRandomNeuron() { 173 | chosenNeuronIndex := net.GetRandomNode() 174 | chosenActivationFunctionIndex := ActivationFunctionIndex(rand.Intn(len(ActivationFunctions))) 175 | net.AllNodes[chosenNeuronIndex].ActivationFunction = chosenActivationFunctionIndex 176 | } 177 | 178 | // Evaluate will return a weighted sum of the input nodes, 179 | // using the .Value field if it is set and no input nodes are available. 180 | // A shared weight can be given. 
181 | func (net *Network) Evaluate(inputValues []float64) float64 { 182 | inputLength := len(inputValues) 183 | for i, nindex := range net.InputNodes { 184 | if i < inputLength { 185 | net.AllNodes[nindex].SetValue(inputValues[i]) 186 | } 187 | } 188 | maxIterationCounter := inputLength 189 | result, _ := net.AllNodes[net.OutputNode].evaluate(net.Weight, &maxIterationCounter) 190 | return result 191 | } 192 | 193 | // SetWeight will set a shared weight for the entire network 194 | func (net *Network) SetWeight(weight float64) { 195 | net.Weight = weight 196 | } 197 | 198 | // LeftRight returns two neurons, such that the first on is the one that is 199 | // most to the left (towards the input neurons) and the second one is most to 200 | // the right (towards the output neuron). Assumes that a and b are not equal. 201 | // The returned bool is true if there is no order (if the nodes are equal, both are output nodes or both are input nodes) 202 | func (net *Network) LeftRight(a, b NeuronIndex) (NeuronIndex, NeuronIndex, bool) { 203 | // First check if they are equal 204 | if a == b { 205 | return a, b, true // Arbitrary order 206 | } 207 | // First check the network output nodes 208 | if a == net.OutputNode && b == net.OutputNode { 209 | return a, b, true // Arbitrary order 210 | } 211 | if a == net.OutputNode && b != net.OutputNode { 212 | return b, a, false // Swap order 213 | } 214 | if a != net.OutputNode && b == net.OutputNode { 215 | return a, b, false // Same order 216 | } 217 | // Then check if the nodes are already connected 218 | if net.AllNodes[a].In(net.AllNodes[b].InputNodes) { 219 | return a, b, false // Same order 220 | } 221 | if net.AllNodes[b].In(net.AllNodes[a].InputNodes) { 222 | return b, a, false // Swap order 223 | } 224 | // Then check the input nodes of the network 225 | aIsNetworkInputNode := net.AllNodes[a].In(net.InputNodes) 226 | bIsNetworkInputNode := net.AllNodes[b].In(net.InputNodes) 227 | if aIsNetworkInputNode && !bIsNetworkInputNode 
{ 228 | return a, b, false // Same order 229 | } 230 | if !aIsNetworkInputNode && bIsNetworkInputNode { 231 | return b, a, false // Swap order 232 | } 233 | if aIsNetworkInputNode && bIsNetworkInputNode { 234 | return a, b, true // Arbitrary order 235 | } 236 | // Then check the distance from the output node, in steps 237 | aDistance := net.AllNodes[a].distanceFromOutputNode 238 | bDistance := net.AllNodes[b].distanceFromOutputNode 239 | if bDistance > aDistance { 240 | return b, a, false // Swap order, b is further away from the output node, which (usually) means further left in the graph 241 | } 242 | // Everything else 243 | return a, b, false 244 | } 245 | 246 | // Depth returns the maximum connection distance from the output node 247 | func (net *Network) Depth() int { 248 | maxDepth := 0 249 | net.ForEachConnected(func(n *Neuron) { 250 | if n.distanceFromOutputNode > maxDepth { 251 | maxDepth = n.distanceFromOutputNode 252 | } 253 | }) 254 | return maxDepth 255 | } 256 | 257 | // All returns a slice with pointers to all nodes in this network 258 | func (net *Network) All() []*Neuron { 259 | allNodes := make([]*Neuron, 0) 260 | for i := range net.AllNodes { 261 | allNodes = append(allNodes, &net.AllNodes[i]) 262 | } 263 | // Return pointers to all nodes in this network 264 | return allNodes 265 | } 266 | 267 | // GetRandomNode will select a random neuron. 268 | // This can be any node, including the output node. 
269 | func (net *Network) GetRandomNode() NeuronIndex { 270 | return NeuronIndex(rand.Intn(len(net.AllNodes))) 271 | } 272 | 273 | // GetRandomInputNode returns a random input node 274 | func (net *Network) GetRandomInputNode() NeuronIndex { 275 | inputPosition := rand.Intn(len(net.InputNodes)) 276 | inputNodeIndex := net.InputNodes[inputPosition] 277 | return inputNodeIndex 278 | } 279 | 280 | // Combine will combine two lists of indices 281 | func Combine(a, b []NeuronIndex) []NeuronIndex { 282 | lena := len(a) 283 | lenb := len(b) 284 | // Allocate the exact size needed 285 | res := make([]NeuronIndex, lena+lenb) 286 | // Add the elements from a 287 | for i := 0; i < lena; i++ { 288 | res[i] = a[i] 289 | } 290 | // Add the elements from b 291 | for i := 0; i < lenb; i++ { 292 | res[i+lena] = b[i] 293 | } 294 | return res 295 | } 296 | 297 | // getAllNodes is a helper function for the recursive network traversal. 298 | // Given the output node and the number 0, it will return a slice of all 299 | // connected nodes, where the distance from the output node has been stored in 300 | // node.distanceFromOutputNode. 
func (net *Network) getAllConnectedNodes(nodeIndex NeuronIndex, distanceFromFirstNode int, alreadyHaveThese []NeuronIndex) []NeuronIndex {
	// The nodes that are found by this call. There is room for every node in the network.
	allNodes := make([]NeuronIndex, 0, len(net.AllNodes))
	// node is a copy of the neuron, which is why it is stored back after being changed
	node := net.AllNodes[nodeIndex]
	// Record the distance for every node except the output node
	if nodeIndex != net.OutputNode {
		node.distanceFromOutputNode = distanceFromFirstNode
		net.AllNodes[nodeIndex] = node
	}
	// Include this node, unless the caller already has it
	if !node.In(alreadyHaveThese) {
		allNodes = append(allNodes, nodeIndex)
	}
	// Continue with the nodes that feed into this node, one step further away
	for _, inputNodeIndex := range node.InputNodes {
		if node.Is(inputNodeIndex) {
			panic("implementation error: node is input node to self")
		}
		// Skip input indices that are beyond the end of the AllNodes slice
		if int(inputNodeIndex) >= len(net.AllNodes) {
			continue
		}
		inputNode := net.AllNodes[inputNodeIndex]
		// Only visit nodes that have not been collected yet, neither here nor by the caller.
		// This is what keeps circular connections from recursing forever.
		if !inputNode.In(allNodes) && !inputNode.In(alreadyHaveThese) {
			// The recursive call gets everything that is collected so far as its "already have" slice.
			// NOTE(review): the append may write into the spare capacity of allNodes. The length of
			// allNodes is unchanged and Combine allocates a new slice, so this looks harmless. Confirm.
			allNodes = Combine(allNodes, net.getAllConnectedNodes(inputNodeIndex, distanceFromFirstNode+1, append(allNodes, alreadyHaveThese...)))
		}
	}
	return allNodes
}

// ForEachConnected will only go through nodes that are connected to the output node (directly or indirectly)
// Unconnected input nodes are not covered.
328 | func (net *Network) ForEachConnected(f func(n *Neuron)) { 329 | // Start at the output node, traverse left towards the input nodes 330 | // The network has a counter for how many nodes has been added/removed, for quick memory allocation here 331 | // the final slice is to avoid circular connections 332 | for _, nodeIndex := range net.getAllConnectedNodes(net.OutputNode, 0, []NeuronIndex{}) { 333 | f(&(net.AllNodes[nodeIndex])) 334 | } 335 | } 336 | 337 | // Connected returns a slice of neuron indexes, that are all connected to the output node (directly or indirectly) 338 | func (net *Network) Connected() []NeuronIndex { 339 | allConnected := make([]NeuronIndex, 0, len(net.AllNodes)) // Use a bit more memory, but don't allocate at every iteration 340 | net.ForEachConnectedNodeIndex(func(ni NeuronIndex) { 341 | allConnected = append(allConnected, ni) 342 | }) 343 | return allConnected 344 | } 345 | 346 | // Unconnected returns a slice of all unconnected neurons 347 | func (net *Network) Unconnected() []NeuronIndex { 348 | connected := net.Connected() 349 | // TODO: Benchmark if using len(net.AllNodes) here is faster or not 350 | unconnected := make([]NeuronIndex, 0, len(net.AllNodes)) 351 | for i, node := range net.AllNodes { 352 | if !node.In(connected) { 353 | unconnected = append(unconnected, NeuronIndex(i)) 354 | } 355 | } 356 | return unconnected 357 | } 358 | 359 | // ForEachConnectedNodeIndex will only go through nodes that are connected to the output node (directly or indirectly) 360 | // Unconnected input nodes are not covered. 361 | func (net *Network) ForEachConnectedNodeIndex(f func(ni NeuronIndex)) { 362 | net.ForEachConnected(func(n *Neuron) { 363 | f(n.neuronIndex) 364 | }) 365 | } 366 | 367 | // InsertRandomNode will try the given number of times to insert a node in a random location, 368 | // replacing an existing connection between two nodes. 
369 | // `a -> b` will then become `a -> newNode -> b` 370 | // Returns true if one was inserted or false if the randomly chosen location wasn't fruitful 371 | func (net *Network) InsertRandomNode() bool { 372 | 373 | // Find a random node among the nodes that are connected to the output node (directly or indirectly) 374 | connectedNodes := net.Connected() 375 | randomNodeIndexThatIsConnected := connectedNodes[rand.Intn(len(connectedNodes))] 376 | 377 | // If this is one of the network input nodes, return 378 | if net.IsInput(randomNodeIndexThatIsConnected) { 379 | // Nothing to do here, the input nodes get their input from the input numbers 380 | return false 381 | } 382 | 383 | // If this is the output node, and there are no inputs to the output node, return 384 | if randomNodeIndexThatIsConnected == net.OutputNode && len(net.AllNodes[net.OutputNode].InputNodes) == 0 { 385 | // Nothing to do here, no connections to the output node 386 | return false 387 | } 388 | 389 | // If we arrived here, this node must have input nodes. Choose one at random. 390 | rightIndex := randomNodeIndexThatIsConnected 391 | inputNodes := net.AllNodes[rightIndex].InputNodes 392 | 393 | leftIndex := inputNodes[rand.Intn(len(inputNodes))] 394 | 395 | // We now have a left and right node index, that we know are connected, replace this connection with 396 | // one that goes through an entirely new node. 
397 | 398 | // Create a new node and connect it with the left node 399 | newNode, newNodeIndex := net.NewNeuron() 400 | err := newNode.AddInput(leftIndex) 401 | if err != nil { 402 | panic(err) 403 | } 404 | 405 | // Remove the connection to the left node and then connect the right node to the new node 406 | err = net.AllNodes[rightIndex].RemoveInput(leftIndex) 407 | if err != nil { 408 | panic(err) 409 | } 410 | err = net.AllNodes[rightIndex].AddInput(newNodeIndex) 411 | if err != nil { 412 | panic(err) 413 | } 414 | 415 | return true 416 | } 417 | 418 | // UpdateNetworkPointers will update all the node.Net pointers to point to this network 419 | func (net *Network) UpdateNetworkPointers() { 420 | for nodeIndex := range net.AllNodes { 421 | net.AllNodes[nodeIndex].Net = net 422 | } 423 | } 424 | 425 | // NewInputNode creates a new input node for this network, optionally connecting it to the output node 426 | func (net *Network) NewInputNode(activationFunction ActivationFunctionIndex, connectToOutput bool) error { 427 | // Create a new node 428 | _, ni := net.NewBlankNeuron() 429 | // Set the activation function 430 | net.AllNodes[ni].ActivationFunction = activationFunction 431 | // Set the parent network 432 | net.AllNodes[ni].Net = net 433 | // Add the new node to the input nodes of the net 434 | net.InputNodes = append(net.InputNodes, ni) 435 | // Connect the input node to the output node 436 | if connectToOutput { 437 | return net.AddConnection(ni, net.OutputNode) 438 | } 439 | return nil 440 | } 441 | 442 | // Copy a Network to a new network 443 | func (net Network) Copy() *Network { 444 | var newNet Network 445 | newNet.AllNodes = make([]Neuron, len(net.AllNodes)) 446 | for nodeIndex := range net.AllNodes { 447 | // This copies the node and also sets the .Net pointer correctly to this network 448 | newNet.AllNodes[nodeIndex] = newNet.AllNodes[nodeIndex].Copy(&newNet) 449 | } 450 | newNet.InputNodes = net.InputNodes 451 | newNet.OutputNode = net.OutputNode 452 
| newNet.Weight = net.Weight 453 | 454 | // NOTE: It's important that a pointer to a Network is returned, 455 | // instead of an entire Network struct, so that the .Net pointers in the nodes point correctly. 456 | return &newNet 457 | } 458 | 459 | // String creates a simple and not very useful ASCII representation of the input nodes and the output node. 460 | // Nodes that are not input nodes are skipped. 461 | // Input nodes that are not connected directly to the output node are drawn as non-connected, 462 | // even if they are connected via another node. 463 | func (net Network) String() string { 464 | var sb strings.Builder 465 | sb.WriteString(fmt.Sprintf("Network (%d nodes, %d input nodes, %d output node)\n", len(net.AllNodes), len(net.InputNodes), 1)) 466 | sb.WriteString("\tConnected inputs to output node: " + strconv.Itoa(len(net.AllNodes[net.OutputNode].InputNodes)) + "\n") 467 | for _, node := range net.AllNodes { 468 | sb.WriteString("\t" + node.String() + "\n") 469 | } 470 | return sb.String() 471 | } 472 | -------------------------------------------------------------------------------- /vendor/github.com/xyproto/tinysvg/tinysvg.go: -------------------------------------------------------------------------------- 1 | // Package tinysvg supports generating and writing TinySVG 1.2 images 2 | // 3 | // Some function names are suffixed with "2" if they take structs instead of ints/floats, 4 | // "i" if they take ints and "f" if they take floats. There is no support for multiple dispatch in Go. 
//
package tinysvg

import (
	"bytes"
	"errors"
	"fmt"
	"strconv"
	"strings"
)

const (
	// Alpha values for Color.A
	TRANSPARENT = 0.0
	OPAQUE      = 1.0

	// NOTE(review): presumably meant for use with YesNoAuto. Note that YES is 0 and NO is 1.
	YES  = 0
	NO   = 1
	AUTO = 2
)

type (
	// Vec2 is a pair of float64 values
	Vec2 struct {
		X, Y float64
	}

	// Pos is a position, Radius is a horizontal and a vertical radius
	Pos    Vec2
	Radius Vec2

	// Size is a width and a height
	Size struct {
		W, H float64
	}

	// Color is either an RGBA value or a named color
	Color struct {
		R, G, B int     // red, green, blue (0..255)
		A       float64 // alpha, 0.0..1.0
		N       string  // name (optional, will override the above values)
	}

	// Font is a font family and a font size
	Font struct {
		Family string
		Size   int
	}

	YesNoAuto int
)

// ErrPair is the error for position pairs that are not two comma separated numbers
var ErrPair = errors.New("position pairs must be exactly two comma separated numbers")

// NewTinySVG creates a new TinySVG document, where the width and height is defined in pixels, using the "px" suffix.
// The viewBox is set to start at 0,0 and to have the given width and height.
// Both the document and the root "svg" tag are returned.
func NewTinySVG(w, h int) (*Document, *Tag) {
	page := NewDocument([]byte(""), []byte(``))
	svg := page.root.AddNewTag([]byte("svg"))
	svg.AddAttrib("xmlns", []byte("http://www.w3.org/2000/svg"))
	svg.AddAttrib("version", []byte("1.2"))
	svg.AddAttrib("baseProfile", []byte("tiny"))
	svg.AddAttrib("viewBox", []byte(fmt.Sprintf("%d %d %d %d", 0, 0, w, h)))
	svg.AddAttrib("width", []byte(fmt.Sprintf("%dpx", w)))
	svg.AddAttrib("height", []byte(fmt.Sprintf("%dpx", h)))
	return page, svg
}

// NewTinySVG2 creates new TinySVG 1.2 image. Pos and Size defines the viewbox.
// Unlike NewTinySVG, no "width" or "height" attributes are added.
// Both the document and the root "svg" tag are returned.
func NewTinySVG2(p *Pos, s *Size) (*Document, *Tag) {
	// No page title is needed when building an SVG tag tree
	page := NewDocument([]byte(""), []byte(``))

	// No longer needed for TinySVG 1.2. See: https://www.w3.org/TR/SVGTiny12/intro.html#defining
	//

	// Add the root tag
	svg := page.root.AddNewTag([]byte("svg"))
	svg.AddAttrib("xmlns", []byte("http://www.w3.org/2000/svg"))
	svg.AddAttrib("version", []byte("1.2"))
	svg.AddAttrib("baseProfile", []byte("tiny"))
	// The viewBox values are written with "%f" and are not shortened by f2b
	svg.AddAttrib("viewBox", []byte(fmt.Sprintf("%f %f %f %f", p.X, p.Y, s.W, s.H)))
	return page, svg
}

// f2b converts the given float64 to a string representation.
// .0 or .000000 suffixes are stripped.
// The string is returned as a byte slice.
func f2b(x float64) []byte {
	// "%f" formats with six decimals by default, so whole numbers end with ".000000"
	fs := fmt.Sprintf("%f", x)
	fs = strings.TrimSuffix(fs, ".0")
	fs = strings.TrimSuffix(fs, ".000000")
	return []byte(fs)
}

// Rect2 adds a rectangle, given x and y position, width and height.
// The fill color is set from c, by calling Fill2.
// The new "rect" tag is returned.
func (svg *Tag) Rect2(p *Pos, s *Size, c *Color) *Tag {
	rect := svg.AddNewTag([]byte("rect"))
	rect.AddAttrib("x", f2b(p.X))
	rect.AddAttrib("y", f2b(p.Y))
	rect.AddAttrib("width", f2b(s.W))
	rect.AddAttrib("height", f2b(s.H))
	rect.Fill2(c)
	return rect
}

// RoundedRect2 adds a rectangle with rounded corners, given x and y position,
// corner radius (rx, ry), width and height.
// The fill color is set from c, by calling Fill2.
// The new "rect" tag is returned.
func (svg *Tag) RoundedRect2(p *Pos, r *Radius, s *Size, c *Color) *Tag {
	rect := svg.AddNewTag([]byte("rect"))
	rect.AddAttrib("x", f2b(p.X))
	rect.AddAttrib("y", f2b(p.Y))
	rect.AddAttrib("rx", f2b(r.X))
	rect.AddAttrib("ry", f2b(r.Y))
	rect.AddAttrib("width", f2b(s.W))
	rect.AddAttrib("height", f2b(s.H))
	rect.Fill2(c)
	return rect
}

// Text2 adds text.
No color is being set 120 | func (svg *Tag) Text2(p *Pos, f *Font, message string, c *Color) *Tag { 121 | text := svg.AddNewTag([]byte("text")) 122 | text.AddAttrib("x", f2b(p.X)) 123 | text.AddAttrib("y", f2b(p.Y)) 124 | text.AddAttrib("font-family", []byte(f.Family)) 125 | text.AddAttrib("font-size", []byte(strconv.Itoa(f.Size))) 126 | text.Fill2(c) 127 | text.AddContent([]byte(message)) 128 | return text 129 | } 130 | 131 | // Circle adds a circle, given a position, radius and color 132 | func (svg *Tag) Circle2(p *Pos, radius int, c *Color) *Tag { 133 | circle := svg.AddNewTag([]byte("circle")) 134 | circle.AddAttrib("cx", f2b(p.X)) 135 | circle.AddAttrib("cy", f2b(p.Y)) 136 | circle.AddAttrib("r", []byte(strconv.Itoa(radius))) 137 | circle.Fill2(c) 138 | return circle 139 | } 140 | 141 | // Circle adds a circle, given a position, radius and color 142 | func (svg *Tag) Circlef(p *Pos, radius float64, c *Color) *Tag { 143 | circle := svg.AddNewTag([]byte("circle")) 144 | circle.AddAttrib("cx", f2b(p.X)) 145 | circle.AddAttrib("cy", f2b(p.Y)) 146 | circle.AddAttrib("r", f2b(radius)) 147 | circle.Fill2(c) 148 | return circle 149 | } 150 | 151 | // Ellipse adds an ellipse with a given position (x,y) and radius (rx, ry). 
152 | func (svg *Tag) Ellipse2(p *Pos, r *Radius, c *Color) *Tag { 153 | ellipse := svg.AddNewTag([]byte("ellipse")) 154 | ellipse.AddAttrib("cx", f2b(p.X)) 155 | ellipse.AddAttrib("cy", f2b(p.Y)) 156 | ellipse.AddAttrib("rx", f2b(r.X)) 157 | ellipse.AddAttrib("ry", f2b(r.Y)) 158 | ellipse.Fill2(c) 159 | return ellipse 160 | } 161 | 162 | // Line adds a line from (x1, y1) to (x2, y2) with a given stroke width and color 163 | func (svg *Tag) Line2(p1, p2 *Pos, thickness int, c *Color) *Tag { 164 | line := svg.AddNewTag([]byte("line")) 165 | line.AddAttrib("x1", f2b(p1.X)) 166 | line.AddAttrib("y1", f2b(p1.Y)) 167 | line.AddAttrib("x2", f2b(p2.X)) 168 | line.AddAttrib("y2", f2b(p2.Y)) 169 | line.Thickness(thickness) 170 | line.Stroke2(c) 171 | return line 172 | } 173 | 174 | // Triangle adds a colored triangle 175 | func (svg *Tag) Triangle2(p1, p2, p3 *Pos, c *Color) *Tag { 176 | triangle := svg.AddNewTag([]byte("path")) 177 | triangle.AddAttrib("d", []byte(fmt.Sprintf("M %f %f L %f %f L %f %f L %f %f", p1.X, p1.Y, p2.X, p2.Y, p3.X, p3.Y, p1.X, p1.Y))) 178 | triangle.Fill2(c) 179 | return triangle 180 | } 181 | 182 | // Poly2 adds a colored path with 4 points 183 | func (svg *Tag) Poly2(p1, p2, p3, p4 *Pos, c *Color) *Tag { 184 | poly4 := svg.AddNewTag([]byte("path")) 185 | poly4.AddAttrib("d", []byte(fmt.Sprintf("M %f %f L %f %f L %f %f L %f %f L %f %f", p1.X, p1.Y, p2.X, p2.Y, p3.X, p3.Y, p4.X, p4.Y, p1.X, p1.Y))) 186 | poly4.Fill2(c) 187 | return poly4 188 | } 189 | 190 | // Fill selects the fill color that will be used when drawing 191 | func (svg *Tag) Fill2(c *Color) { 192 | // If no color name is given and the color is transparent, don't set a fill color 193 | if (c == nil) || (len(c.N) == 0 && c.A == TRANSPARENT) { 194 | return 195 | } 196 | svg.AddAttrib("fill", c.Bytes()) 197 | } 198 | 199 | // Stroke selects the stroke color that will be used when drawing 200 | func (svg *Tag) Stroke2(c *Color) { 201 | // If no color name is given and the color is 
transparent, don't set a stroke color 202 | if (c == nil) || (len(c.N) == 0 && c.A == TRANSPARENT) { 203 | return 204 | } 205 | svg.AddAttrib("stroke", c.Bytes()) 206 | } 207 | 208 | // RGBBytes converts r, g and b (integers in the range 0..255) 209 | // to a color string on the form "#nnnnnn", returned as a byte slice. 210 | // May also return colors strings on the form "#nnn". 211 | func RGBBytes(r, g, b int) []byte { 212 | rs := strconv.FormatInt(int64(r), 16) 213 | gs := strconv.FormatInt(int64(g), 16) 214 | bs := strconv.FormatInt(int64(g), 16) 215 | if len(rs) == 1 && len(gs) == 1 && len(bs) == 1 { 216 | // short form 217 | return []byte("#" + rs + gs + bs) 218 | } 219 | // long form 220 | return []byte(fmt.Sprintf("#%02x%02x%02x", r, g, b)) 221 | } 222 | 223 | // RGBABytes converts integers r, g and b (the color) and also 224 | // a given alpha (opacity) to a color-string on the form 225 | // "rgba(255, 255, 255, 1.0)". 226 | func RGBABytes(r, g, b int, a float64) []byte { 227 | return []byte(fmt.Sprintf("rgba(%d, %d, %d, %f)", r, g, b, a)) 228 | } 229 | 230 | // RGBA creates a new Color with the given red, green and blue values. 231 | // The colors are in the range 0..255 232 | func RGB(r, g, b int) *Color { 233 | return &Color{r, g, b, OPAQUE, ""} 234 | } 235 | 236 | // RGBA creates a new Color with the given red, green, blue and alpha values. 237 | // Alpha is between 0 and 1, the rest are 0..255. 238 | // For the alpha value, 0 is transparent and 1 is opaque. 239 | func RGBA(r, g, b int, a float64) *Color { 240 | return &Color{r, g, b, a, ""} 241 | } 242 | 243 | // ColorByName creates a new Color with a given name, like "blue" 244 | func ColorByName(name string) *Color { 245 | return &Color{N: name} 246 | } 247 | 248 | // NewColor is the same as ColorByName 249 | func NewColor(name string) *Color { 250 | return ColorByName(name) 251 | } 252 | 253 | // String returns the color as an RGB (#1234FF) string 254 | // or as an RGBA (rgba(0, 1, 2 ,3)) string. 
255 | func (c *Color) Bytes() []byte { 256 | // Return an empty string if nil 257 | if c == nil { 258 | return make([]byte, 0) 259 | } 260 | // Return the name, if specified 261 | if len(c.N) != 0 { 262 | return []byte(c.N) 263 | } 264 | // Return a regular RGB string if alpha is 1.0 265 | if c.A == OPAQUE { 266 | // Generate a rgb string 267 | return RGBBytes(c.R, c.G, c.B) 268 | } 269 | // Generate a rgba string if alpha is < 1.0 270 | return RGBABytes(c.R, c.G, c.B, c.A) 271 | } 272 | 273 | // --- Convenience functions and functions for backward compatibility --- 274 | 275 | func NewTinySVGi(x, y, w, h int) (*Document, *Tag) { 276 | return NewTinySVG2(&Pos{float64(x), float64(y)}, &Size{float64(w), float64(h)}) 277 | } 278 | 279 | func NewTinySVGf(x, y, w, h float64) (*Document, *Tag) { 280 | return NewTinySVG2(&Pos{x, y}, &Size{w, h}) 281 | } 282 | 283 | // AddRect adds a rectangle, given x and y position, width and height. 284 | // No color is being set. 285 | func (svg *Tag) AddRect(x, y, w, h int) *Tag { 286 | return svg.Rect2(&Pos{float64(x), float64(y)}, &Size{float64(w), float64(h)}, nil) 287 | } 288 | 289 | // AddRoundedRect adds a rectangle, given x and y position, radius x, radius y, width and height. 290 | // No color is being set. 291 | func (svg *Tag) AddRoundedRect(x, y, rx, ry, w, h int) *Tag { 292 | return svg.RoundedRect2(&Pos{float64(x), float64(y)}, &Radius{float64(rx), float64(ry)}, &Size{float64(w), float64(h)}, nil) 293 | } 294 | 295 | // AddRectf adds a rectangle, given x and y position, width and height. 296 | // No color is being set. 297 | func (svg *Tag) AddRectf(x, y, w, h float64) *Tag { 298 | return svg.Rect2(&Pos{x, y}, &Size{w, h}, nil) 299 | } 300 | 301 | // AddRoundedRectf adds a rectangle, given x and y position, radius x, radius y, width and height. 302 | // No color is being set. 
303 | func (svg *Tag) AddRoundedRectf(x, y, rx, ry, w, h float64) *Tag { 304 | return svg.RoundedRect2(&Pos{x, y}, &Radius{rx, ry}, &Size{w, h}, nil) 305 | } 306 | 307 | // AddText adds text. No color is being set 308 | func (svg *Tag) AddText(x, y, fontSize int, fontFamily, text string) *Tag { 309 | return svg.Text2(&Pos{float64(x), float64(y)}, &Font{fontFamily, fontSize}, text, nil) 310 | } 311 | 312 | // Box adds a rectangle, given x and y position, width, height and color 313 | func (svg *Tag) Box(x, y, w, h int, color string) *Tag { 314 | return svg.Rect2(&Pos{float64(x), float64(y)}, &Size{float64(w), float64(h)}, ColorByName(color)) 315 | } 316 | 317 | // AddCircle adds a circle Add a circle, given a position (x, y) and a radius. 318 | // No color is being set. 319 | func (svg *Tag) AddCircle(x, y, radius int) *Tag { 320 | return svg.Circle2(&Pos{float64(x), float64(y)}, radius, nil) 321 | } 322 | 323 | // AddCirclef adds a circle Add a circle, given a position (x, y) and a radius. 324 | // No color is being set. 325 | func (svg *Tag) AddCirclef(x, y, radius float64) *Tag { 326 | return svg.Circlef(&Pos{x, y}, radius, nil) 327 | } 328 | 329 | // AddEllipse adds an ellipse with a given position (x,y) and radius (rx, ry). 330 | // No color is being set. 331 | func (svg *Tag) AddEllipse(x, y, rx, ry int) *Tag { 332 | return svg.Ellipse2(&Pos{float64(x), float64(y)}, &Radius{float64(rx), float64(ry)}, nil) 333 | } 334 | 335 | // AddEllipsef adds an ellipse with a given position (x,y) and radius (rx, ry). 336 | // No color is being set. 
337 | func (svg *Tag) AddEllipsef(x, y, rx, ry float64) *Tag { 338 | return svg.Ellipse2(&Pos{x, y}, &Radius{rx, ry}, nil) 339 | } 340 | 341 | // Line adds a line from (x1, y1) to (x2, y2) with a given stroke width and color 342 | func (svg *Tag) Line(x1, y1, x2, y2, thickness int, color string) *Tag { 343 | return svg.Line2(&Pos{float64(x1), float64(y1)}, &Pos{float64(x2), float64(y2)}, thickness, ColorByName(color)) 344 | } 345 | 346 | // Triangle adds a colored triangle 347 | func (svg *Tag) Triangle(x1, y1, x2, y2, x3, y3 int, color string) *Tag { 348 | return svg.Triangle2(&Pos{float64(x1), float64(y1)}, &Pos{float64(x2), float64(y2)}, &Pos{float64(x3), float64(y3)}, ColorByName(color)) 349 | } 350 | 351 | // Poly4 adds a colored path with 4 points 352 | func (svg *Tag) Poly4(x1, y1, x2, y2, x3, y3, x4, y4 int, color string) *Tag { 353 | return svg.Poly2(&Pos{float64(x1), float64(y1)}, &Pos{float64(x2), float64(y2)}, &Pos{float64(x3), float64(y3)}, &Pos{float64(x4), float64(y4)}, ColorByName(color)) 354 | } 355 | 356 | // Circle adds a circle, given x and y position, radius and color 357 | func (svg *Tag) Circle(x, y, radius int, color string) *Tag { 358 | return svg.Circle2(&Pos{float64(x), float64(y)}, radius, ColorByName(color)) 359 | } 360 | 361 | // Ellipse adds an ellipse, given x and y position, radiuses and color 362 | func (svg *Tag) Ellipse(x, y, xr, yr int, color string) *Tag { 363 | return svg.Ellipse2(&Pos{float64(x), float64(y)}, &Radius{float64(xr), float64(yr)}, ColorByName(color)) 364 | } 365 | 366 | // Fill selects the fill color that will be used when drawing 367 | func (svg *Tag) Fill(color string) { 368 | svg.AddAttrib("fill", []byte(color)) 369 | } 370 | 371 | // ColorBytes converts r, g and b (integers in the range 0..255) 372 | // to a color string on the form "#nnnnnn". 
373 | func ColorBytes(r, g, b int) []byte { 374 | return RGB(r, g, b).Bytes() 375 | } 376 | 377 | // ColorBytesAlpha converts integers r, g and b (the color) and also 378 | // a given alpha (opacity) to a color-string on the form 379 | // "rgba(255, 255, 255, 1.0)". 380 | func ColorBytesAlpha(r, g, b int, a float64) []byte { 381 | return RGBA(r, g, b, a).Bytes() 382 | } 383 | 384 | // Pixel creates a rectangle that is 1 wide with the given color. 385 | // Note that the size of the "pixel" depends on how large the viewBox is. 386 | func (svg *Tag) Pixel(x, y, r, g, b int) *Tag { 387 | return svg.Rect2(&Pos{float64(x), float64(y)}, &Size{1.0, 1.0}, RGB(r, g, b)) 388 | } 389 | 390 | // AlphaDot creates a small circle that can be transparent. 391 | // Takes a position (x, y) and a color (r, g, b, a). 392 | func (svg *Tag) AlphaDot(x, y, r, g, b int, a float32) *Tag { 393 | return svg.Circle2(&Pos{float64(x), float64(y)}, 1, RGBA(r, g, b, float64(a))) 394 | } 395 | 396 | // Dot adds a small colored circle. 397 | // Takes a position (x, y) and a color (r, g, b). 398 | func (svg *Tag) Dot(x, y, r, g, b int) *Tag { 399 | return svg.Circle2(&Pos{float64(x), float64(y)}, 1, RGB(r, g, b)) 400 | } 401 | 402 | // Text adds text, with a color 403 | func (svg *Tag) Text(x, y, fontSize int, fontFamily, text, color string) *Tag { 404 | return svg.Text2(&Pos{float64(x), float64(y)}, &Font{fontFamily, fontSize}, text, ColorByName(color)) 405 | } 406 | 407 | // Create a new Yes/No/Auto struct. If auto is true, it overrides the yes value. 
408 | func NewYesNoAuto(yes bool, auto bool) YesNoAuto { 409 | if auto { 410 | return AUTO 411 | } 412 | if yes { 413 | return YES 414 | } 415 | return NO 416 | } 417 | 418 | func (yna YesNoAuto) Bytes() []byte { 419 | switch yna { 420 | case YES: 421 | return []byte("true") 422 | case AUTO: 423 | return []byte("auto") 424 | default: 425 | return []byte("false") 426 | } 427 | } 428 | 429 | // Focusable sets the "focusable" attribute to either true, false or auto 430 | // If "auto" is true, it overrides the value of "yes". 431 | func (svg *Tag) Focusable(yes bool, auto bool) { 432 | svg.AddAttrib("focusable", NewYesNoAuto(yes, auto).Bytes()) 433 | } 434 | 435 | // Thickness sets the stroke-width attribute 436 | func (svg *Tag) Thickness(thickness int) { 437 | svg.AddAttrib("stroke-width", []byte(strconv.Itoa(thickness))) 438 | } 439 | 440 | // Polyline adds a set of connected straight lines, an open shape 441 | func (svg *Tag) Polyline(points []*Pos, c *Color) *Tag { 442 | polyline := svg.AddNewTag([]byte("polyline")) 443 | var buf bytes.Buffer 444 | lastIndex := len(points) - 1 445 | for i, p := range points { 446 | buf.Write(f2b(p.X)) 447 | buf.WriteByte(',') 448 | buf.Write(f2b(p.Y)) 449 | if i != lastIndex { 450 | buf.WriteByte(' ') 451 | } 452 | } 453 | polyline.AddAttrib("points", buf.Bytes()) 454 | return polyline 455 | } 456 | 457 | // Polygon adds a set of connected straight lines, a closed shape 458 | func (svg *Tag) Polygon(points []*Pos, c *Color) *Tag { 459 | polygon := svg.AddNewTag([]byte("polygon")) 460 | var buf bytes.Buffer 461 | lastIndex := len(points) - 1 462 | for i, p := range points { 463 | buf.Write(f2b(p.X)) 464 | buf.WriteByte(',') 465 | buf.Write(f2b(p.Y)) 466 | if i != lastIndex { 467 | buf.WriteByte(' ') 468 | } 469 | } 470 | polygon.AddAttrib("points", buf.Bytes()) 471 | polygon.Fill2(c) 472 | return polygon 473 | } 474 | 475 | func NewPos(xString, yString string) (*Pos, error) { 476 | x, err := strconv.ParseFloat(xString, 64) 477 | if 
err != nil { 478 | return nil, err 479 | } 480 | y, err := strconv.ParseFloat(yString, 64) 481 | if err != nil { 482 | return nil, err 483 | } 484 | return &Pos{x, y}, nil 485 | } 486 | 487 | func NewPosf(x, y float64) *Pos { 488 | return &Pos{x, y} 489 | } 490 | 491 | func PointsFromString(pointString string) ([]*Pos, error) { 492 | points := make([]*Pos, 0) 493 | for _, positionPair := range strings.Split(pointString, " ") { 494 | elements := strings.Split(positionPair, ",") 495 | if len(elements) != 2 { 496 | return nil, ErrPair 497 | } 498 | p, err := NewPos(elements[0], elements[1]) 499 | if err != nil { 500 | return nil, err 501 | } 502 | points = append(points, p) 503 | } 504 | return points, nil 505 | } 506 | 507 | // Describe can be used for adding a description to the SVG header 508 | func (svg *Tag) Describe(description string) { 509 | desc := svg.AddNewTag([]byte("desc")) 510 | desc.AddContent([]byte(description)) 511 | } 512 | --------------------------------------------------------------------------------