├── LICENSE
├── README.md
├── conn.go
├── examples
│   ├── rock-paper-scissor-lizard-spock
│   │   └── main.go
│   └── rock-paper-scissor
│       └── main.go
├── img
│   └── rpsls3.jpg
├── net.go
└── neuron.go

--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
Copyright (c) 2016, Philippe Anel
All rights reserved.

Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:

* Redistributions of source code must retain the above copyright notice, this
  list of conditions and the following disclaimer.

* Redistributions in binary form must reproduce the above copyright notice,
  this list of conditions and the following disclaimer in the documentation
  and/or other materials provided with the distribution.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# godnn
A deep neural network package written in the Go programming language.

Installation
-----------

    go get github.com/xigh/godnn

Documentation
-----------

This package is very simple:

* Import the package

```go
import "github.com/xigh/godnn"
```

* Create a network instance

```go
func Create(topology []uint) (*Net, error)
```

where the parameter to dnn.Create is the topology of your neural network, i.e. the number of neurons per layer. The first layer is the input; the last layer is the output.

* Train your network:

```go
func (net *Net) Train(input, target []float64, rate float64) (float64, error)
```

where input is the input vector, target is the expected output to converge to, and rate is the learning rate. It returns the sum of squared errors over the output neurons.

* Ask your network to predict an answer:

```go
func (net *Net) Predict(input []float64) ([]float64, error)
```

It returns the output vector computed by the network.
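Putting it all together, here is a minimal sketch that trains a tiny network on XOR. The topology, learning rate, and iteration count below are arbitrary illustration choices, not recommendations:

```go
package main

import (
	"fmt"
	"log"

	"github.com/xigh/godnn"
)

func main() {
	// 2 inputs, one hidden layer of 4 neurons, 1 output
	// (Create requires at least 3 layers).
	net, err := dnn.Create([]uint{2, 4, 1})
	if err != nil {
		log.Fatal(err)
	}

	inputs := [][]float64{{0, 0}, {0, 1}, {1, 0}, {1, 1}}
	targets := [][]float64{{0}, {1}, {1}, {0}}

	for i := 0; i < 10000; i++ {
		for n := range inputs {
			if _, err := net.Train(inputs[n], targets[n], 0.5); err != nil {
				log.Fatal(err)
			}
		}
	}

	out, err := net.Predict([]float64{1, 0})
	if err != nil {
		log.Fatal(err)
	}
	fmt.Println(out) // should be close to [1] once trained
}
```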
Disclaimer
-----------

I'm not an AI researcher. I mean I have not studied AI at school, but I often use it at work. This is the reason why I wrote this small [IBM Watson SDK in Go](https://github.com/Mediawen/watson-go-sdk).

I watched Prof. Patrick Henry Winston's course at [MIT OpenCourseWare](http://ocw.mit.edu/courses/electrical-engineering-and-computer-science/6-034-artificial-intelligence-fall-2010/index.htm) along with Yann LeCun videos here and there (especially the course at [Collège de France](http://www.college-de-france.fr/site/yann-lecun/course-2016-02-12-14h30.htm)).

With this DNN package, I want to learn more about how DNNs work. My goal is to use it in the tools we develop for STVHub, our subtitling platform...

Todo
-----------

Add tests.

Write better documentation.

Make it more configurable (threshold function, ...).

Make it more scalable. Use OpenCL/CUDA.

Add benchmarks.

Add more examples (train it on the [MNIST dataset](http://yann.lecun.com/exdb/mnist/)).

Try RNNs (recurrent neural networks) with the LSTM (long short-term memory) architecture.

Learn, learn, study and learn...

Example
-----------

As a fun example, I trained this DNN to learn Rock-Paper-Scissors-Lizard-Spock. You can find the rules in The Big Bang Theory, season 2, episode 8. Here is the result:

![My image](img/rpsls3.jpg)
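For reference, both example programs encode a round as two concatenated one-hot vectors, one per player's move (N is the number of possible moves, as defined in the examples). A sketch of that encoding — the `encode` helper name is mine; the examples inline this code:

```go
// encode builds the network input for one round. a and b are move
// indices in [0, N); the input is two concatenated one-hot vectors.
func encode(a, b int) []float64 {
	input := make([]float64, N*2)
	input[a] = 1   // first player's move
	input[b+N] = 1 // second player's move
	return input
}
```

The two network outputs indicate whether the first or the second player wins; both stay near 0 on a draw.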
License
-----------

Copyright (c) 2016, Philippe Anel
All rights reserved.

Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:

* Redistributions of source code must retain the above copyright notice, this
  list of conditions and the following disclaimer.

* Redistributions in binary form must reproduce the above copyright notice,
  this list of conditions and the following disclaimer in the documentation
  and/or other materials provided with the distribution.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
--------------------------------------------------------------------------------
/conn.go:
--------------------------------------------------------------------------------
package dnn

import (
	"math/rand"
	"time"
)

// Conn is a connection between two neurons. Delta keeps the previous
// weight update so it can act as a momentum term.
type Conn struct {
	Weight, Delta float64
}

var crnd *rand.Rand

func init() {
	crnd = rand.New(rand.NewSource(time.Now().Unix()))
}

// Init resets the connection to a random weight in [-20, 20) and
// clears the previous update.
func (c *Conn) Init() {
	c.Weight = crnd.Float64()*40 - 20
	c.Delta = 0
}
--------------------------------------------------------------------------------
/examples/rock-paper-scissor-lizard-spock/main.go:
--------------------------------------------------------------------------------
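// This example trains the network to predict the outcome of
// Rock-Paper-Scissors-Lizard-Spock rounds: the input is two one-hot
// encoded moves, and the two outputs say which player wins (both are
// 0 on a draw).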
package main

import (
	"fmt"
	"log"

	"github.com/xigh/godnn"
)

const (
	rate    = 0.45
	minErr  = 0.01
	maxIter = 100000
)

var (
	topology = []uint{10, 100, 50, 2}
)

const (
	Rock = iota
	Paper
	Scissors
	Lizard
	Spock
	N
)

var (
	names = []string{"Rock", "Paper", "Scissors", "Lizard", "Spock"}

	COLOR_RESET  = "\x1b[39;49;0m"
	COLOR_RED    = "\x1b[31;1m"
	COLOR_GREEN  = "\x1b[32;1m"
	COLOR_YELLOW = "\x1b[33;1m"
	COLOR_OTHER  = "\x1b[34;1m"
	COLOR_DBLUE  = "\x1b[35;1m"
	COLOR_BLUE   = "\x1b[36;1m"
	COLOR_WHITE  = "\x1b[37;1m"
)

// eval returns the expected output for a round: {1, 0} if a wins,
// {0, 1} if b wins, {0, 0} on a draw.
func eval(a, b int) []float64 {
	// paper covers rock
	if a == Paper && b == Rock {
		return []float64{1, 0}
	}
	if b == Paper && a == Rock {
		return []float64{0, 1}
	}

	// paper disproves spock
	if a == Paper && b == Spock {
		return []float64{1, 0}
	}
	if b == Paper && a == Spock {
		return []float64{0, 1}
	}

	// rock crushes scissors
	if a == Rock && b == Scissors {
		return []float64{1, 0}
	}
	if b == Rock && a == Scissors {
		return []float64{0, 1}
	}

	// rock crushes lizard
	if a == Rock && b == Lizard {
		return []float64{1, 0}
	}
	if b == Rock && a == Lizard {
		return []float64{0, 1}
	}

	// scissors cuts paper
	if a == Scissors && b == Paper {
		return []float64{1, 0}
	}
	if b == Scissors && a == Paper {
		return []float64{0, 1}
	}

	// scissors decapitates lizard
	if a == Scissors && b == Lizard {
		return []float64{1, 0}
	}
	if b == Scissors && a == Lizard {
		return []float64{0, 1}
	}

	// spock smashes scissors
	if a == Spock && b == Scissors {
		return []float64{1, 0}
	}
	if b == Spock && a == Scissors {
		return []float64{0, 1}
	}

	// spock vaporizes rock
	if a == Spock && b == Rock {
		return []float64{1, 0}
	}
	if b == Spock && a == Rock {
		return []float64{0, 1}
	}

	// lizard poisons spock
	if a == Lizard && b == Spock {
		return []float64{1, 0}
	}
	if b == Lizard && a == Spock {
		return []float64{0, 1}
	}

	// lizard eats paper
	if a == Lizard && b == Paper {
		return []float64{1, 0}
	}
	if b == Lizard && a == Paper {
		return []float64{0, 1}
	}

	// draw
	return []float64{0, 0}
}

// test runs every possible round through the network and prints the
// prediction next to the expected result.
func test(net *dnn.Net) {
	avg := 0.0
	nb := 0
	for a := 0; a < N; a++ {
		for b := 0; b < N; b++ {
			input := make([]float64, N*2)
			input[a] = 1
			input[b+N] = 1

			results, err := net.Predict(input)
			if err != nil {
				log.Fatal(err)
			}

			output := eval(a, b)

			if len(output) != len(results) {
				log.Fatal("output size != results size")
			}

			dist := 0.0
			for n := range output {
				delta := output[n] - results[n]
				dist += delta * delta
			}
			dist /= float64(len(output))
			avg += dist
			nb++

			fmt.Printf(COLOR_BLUE+"%-10s"+COLOR_RESET, names[a])
			fmt.Printf(" vs ")
			fmt.Printf(COLOR_BLUE+"%-10s"+COLOR_RESET, names[b])
			fmt.Printf(" src="+COLOR_YELLOW+"%v"+COLOR_RESET, input)
			fmt.Printf(" res="+COLOR_OTHER+"%12.7f"+COLOR_RESET, results)
			fmt.Printf(" exp="+COLOR_GREEN+"%v"+COLOR_RESET, output)
			fmt.Printf(" err="+COLOR_RED+"%12.7f%%\n"+COLOR_RESET, dist*100)
		}
	}
	fmt.Printf("average error: %9.5f%%\n\n", 100*avg/float64(nb))
}

// train iterates over all rounds until the average error drops below
// min or the iteration count exceeds max.
func train(net *dnn.Net, min float64, max uint64) uint64 {
	i := uint64(0)
	p := 0.0
	for {
		if i%1000 == 0 {
			fmt.Printf(".")
		}
		i++

		if i > 0 && i%5000 == 0 {
			fmt.Printf("%.2f", p)
		}

		if i > max {
			fmt.Printf("\ntoo many iterations\n")
			break
		}

		avg := 0.0
		nb := 0
		for a := 0; a < N; a++ {
			for b := 0; b < N; b++ {
				input := make([]float64, N*2)
				input[a] = 1
				input[b+N] = 1

				output := eval(a, b)

				dist, err := net.Train(input, output, rate)
				if err != nil {
					log.Fatal(err)
				}

				avg += dist
				nb++
			}
		}

		p = 100 * avg / float64(nb)
		if p < min {
			fmt.Printf("\naverage error=%9.5f%%\n", p)
			break
		}
	}
	return i
}

func main() {
	net, err := dnn.Create(topology)
	if err != nil {
		log.Fatal(err)
	}

	// -------

	fmt.Printf("topology: %v\n", topology)

	// -------

	fmt.Printf("test before training:\n")
	test(net)

	// -------

	fmt.Printf("learning [min avg error: %f]:\n", minErr)
	itn := train(net, minErr, maxIter)
	fmt.Printf(" - %d iterations\n\n", itn)

	// -------

	fmt.Printf("test after training:\n")
	test(net)
}
--------------------------------------------------------------------------------
/examples/rock-paper-scissor/main.go:
--------------------------------------------------------------------------------
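// This example trains the same kind of network on plain
// Rock-Paper-Scissors, and additionally shows how to persist a
// trained network to JSON with net.Save and restore it with dnn.Load.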
package main

import (
	"fmt"
	"log"
	"os"

	"github.com/xigh/godnn"
)

const (
	rate    = 0.15
	minErr  = 0.001
	maxIter = 100000
)

var (
	topology = []uint{6, 30, 12, 2}
)

const (
	Rock = iota
	Paper
	Scissors
	N
)

var (
	names = []string{"Rock", "Paper", "Scissors"}

	COLOR_RESET  = "\x1b[39;49;0m"
	COLOR_RED    = "\x1b[31;1m"
	COLOR_GREEN  = "\x1b[32;1m"
	COLOR_YELLOW = "\x1b[33;1m"
	COLOR_OTHER  = "\x1b[34;1m"
	COLOR_DBLUE  = "\x1b[35;1m"
	COLOR_BLUE   = "\x1b[36;1m"
	COLOR_WHITE  = "\x1b[37;1m"
)

// eval returns the expected output for a round: {1, 0} if a wins,
// {0, 1} if b wins, {0, 0} on a draw.
func eval(a, b int) []float64 {
	// paper covers rock
	if a == Paper && b == Rock {
		return []float64{1, 0}
	}
	if b == Paper && a == Rock {
		return []float64{0, 1}
	}

	// rock crushes scissors
	if a == Rock && b == Scissors {
		return []float64{1, 0}
	}
	if b == Rock && a == Scissors {
		return []float64{0, 1}
	}

	// scissors cuts paper
	if a == Scissors && b == Paper {
		return []float64{1, 0}
	}
	if b == Scissors && a == Paper {
		return []float64{0, 1}
	}

	// draw
	return []float64{0, 0}
}

// test runs every possible round through the network and prints the
// prediction next to the expected result.
func test(net *dnn.Net) {
	avg := 0.0
	nb := 0
	for a := 0; a < N; a++ {
		for b := 0; b < N; b++ {
			input := make([]float64, N*2)
			input[a] = 1
			input[b+N] = 1

			results, err := net.Predict(input)
			if err != nil {
				log.Fatal(err)
			}

			output := eval(a, b)

			if len(output) != len(results) {
				log.Fatal("output size != results size")
			}

			dist := 0.0
			for n := range output {
				delta := output[n] - results[n]
				dist += delta * delta
			}
			dist /= float64(len(output))
			avg += dist
			nb++

			fmt.Printf(COLOR_BLUE+"%-10s"+COLOR_RESET, names[a])
			fmt.Printf(" vs ")
			fmt.Printf(COLOR_BLUE+"%-10s"+COLOR_RESET, names[b])
			fmt.Printf(" src="+COLOR_YELLOW+"%v"+COLOR_RESET, input)
			fmt.Printf(" res="+COLOR_OTHER+"%12.7f"+COLOR_RESET, results)
			fmt.Printf(" exp="+COLOR_GREEN+"%v"+COLOR_RESET, output)
			fmt.Printf(" err="+COLOR_RED+"%12.7f%%\n"+COLOR_RESET, dist*100)
		}
	}
	fmt.Printf("average error: %9.5f%%\n\n", 100*avg/float64(nb))
}

// train iterates over all rounds until the average error drops below
// min or the iteration count exceeds max.
func train(net *dnn.Net, min float64, max uint64) uint64 {
	i := uint64(0)
	p := 0.0
	for {
		if i%1000 == 0 {
			fmt.Printf(".")
		}
		i++

		if i > 0 && i%5000 == 0 {
			fmt.Printf("%.2f", p)
		}

		if i > max {
			fmt.Printf("\ntoo many iterations\n")
			break
		}

		avg := 0.0
		nb := 0
		for a := 0; a < N; a++ {
			for b := 0; b < N; b++ {
				input := make([]float64, N*2)
				input[a] = 1
				input[b+N] = 1

				output := eval(a, b)

				dist, err := net.Train(input, output, rate)
				if err != nil {
					log.Fatal(err)
				}

				avg += dist
				nb++
			}
		}

		p = 100 * avg / float64(nb)
		if p < min {
			fmt.Printf("\naverage error=%9.5f%%\n", p)
			break
		}
	}
	return i
}

func learnAndSave(name string) error {
	net, err := dnn.Create(topology)
	if err != nil {
		return err
	}

	// -------

	fmt.Printf("topology: %v\n", topology)

	// -------

	fmt.Printf("test before training:\n")
	test(net)

	// -------

	fmt.Printf("learning [min avg error: %f]:\n", minErr)
	itn := train(net, minErr, maxIter)
	fmt.Printf(" - %d iterations\n\n", itn)

	// -------

	fmt.Printf("test after training:\n")
	test(net)

	w, err := os.Create(name)
	if err != nil {
		return err
	}
	defer w.Close()

	return net.Save(w)
}

func loadAndTest(name string) error {
	r, err := os.Open(name)
	if err != nil {
		return err
	}
	defer r.Close()

	net, err := dnn.Load(r)
	if err != nil {
		return err
	}

	// -------

	fmt.Printf("topology: %v\n", net.Topology())

	// -------

	fmt.Printf("test loaded network:\n")
	test(net)

	return nil
}

func main() {
	err := learnAndSave("network.json")
	if err != nil {
		log.Fatal(err)
	}

	err = loadAndTest("network.json")
	if err != nil {
		log.Fatal(err)
	}
}
--------------------------------------------------------------------------------
/img/rpsls3.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xigh/godnn/4c6b87a33a3c1173b1fb47dc9bb5712d6a4b33d2/img/rpsls3.jpg
--------------------------------------------------------------------------------
/net.go:
--------------------------------------------------------------------------------
package dnn

import (
	"encoding/json"
	"fmt"
	"io"
)

// Layer is one layer of the network. The last neuron of each layer is
// a bias neuron with a constant output of 1.
type Layer struct {
	Neurons []Neuron
}

// Net is a fully connected feed-forward neural network.
type Net struct {
	Layers []Layer
}

// Create builds a network from a topology, i.e. the number of neurons
// per layer (input layer first, output layer last, at least 3 layers).
// A bias neuron is appended to every layer.
func Create(topology []uint) (*Net, error) {
	layerCount := len(topology)
	if layerCount < 3 {
		return nil, fmt.Errorf("need at least 3 layers")
	}

	layers := make([]Layer, layerCount)

	for l := range layers {
		neurons := make([]Neuron, topology[l]+1)
		connCount := uint(0)
		if l < layerCount-1 {
			connCount = topology[l+1]
		}
		for n := range neurons {
			neurons[n].Init(connCount, uint(n))
		}
		// the extra neuron is the bias: its output is a constant 1
		neurons[len(neurons)-1].Output = 1
		layers[l].Neurons = neurons
	}

	return &Net{
		Layers: layers,
	}, nil
}

// Load restores a network previously serialized with Save.
func Load(r io.Reader) (*Net, error) {
	dec := json.NewDecoder(r)
	var net Net
	err := dec.Decode(&net)
	return &net, err
}

// feed propagates input through the network and returns the last layer.
func (net *Net) feed(input []float64) (*Layer, error) {
	layerCount := len(net.Layers)
	if layerCount < 3 {
		return nil, fmt.Errorf("unexpected layer count")
	}

	first := &net.Layers[0]
	if len(input) != len(first.Neurons)-1 {
		return nil, fmt.Errorf("input size different than 1st layer size")
	}

	for i := range input {
		first.Neurons[i].Output = input[i]
	}

	for l := 1; l < len(net.Layers); l++ {
		prev := &net.Layers[l-1]
		curr := &net.Layers[l]

		// skip the bias neuron: its output stays constant
		for n := 0; n < len(curr.Neurons)-1; n++ {
			neuron := &curr.Neurons[n]
			neuron.Feed(prev)
		}
	}

	return &net.Layers[layerCount-1], nil
}

// Train runs one forward and one backward pass, and returns the sum
// of squared errors over the output neurons.
func (net *Net) Train(input, target []float64, rate float64) (float64, error) {
	last, err := net.feed(input)
	if err != nil {
		return -1, err
	}

	if len(target) != len(last.Neurons)-1 {
		return -1, fmt.Errorf("target size different than last layer size")
	}

	dist := 0.0
	for n := 0; n < len(last.Neurons)-1; n++ {
		neuron := &last.Neurons[n]
		delta := target[n] - neuron.Output
		dist += delta * delta
		neuron.updateGradient(target[n])
	}

	// back propagation: derive hidden-layer gradients from the layer above
	for l := len(net.Layers) - 2; l > 0; l-- {
		layer := &net.Layers[l]
		next := &net.Layers[l+1]

		for n := range layer.Neurons {
			neuron := &layer.Neurons[n]
			neuron.deriveGradients(next)
		}
	}

	// update weights, starting from the output layer so that its
	// incoming weights are trained too
	for l := len(net.Layers) - 1; l > 0; l-- {
		layer := &net.Layers[l]
		prev := &net.Layers[l-1]

		for n := 0; n < len(layer.Neurons)-1; n++ {
			neuron := &layer.Neurons[n]
			neuron.updateWeight(prev, rate)
		}
	}

	return dist, nil
}

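// Predict runs a forward pass on input and returns a copy of the
// output layer's activations (the bias neuron is excluded).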
func (net *Net) Predict(input []float64) ([]float64, error) {
	last, err := net.feed(input)
	if err != nil {
		return nil, err
	}

	sz := len(last.Neurons) - 1
	result := make([]float64, sz)
	for n := 0; n < sz; n++ {
		neuron := &last.Neurons[n]
		result[n] = neuron.Output
	}

	return result, nil
}

// Save serializes the network as JSON.
func (net *Net) Save(w io.Writer) error {
	enc := json.NewEncoder(w)
	return enc.Encode(net)
}

// Topology returns the number of neurons per layer, excluding the
// bias neurons.
func (net *Net) Topology() []uint {
	var t []uint
	for l := range net.Layers {
		t = append(t, uint(len(net.Layers[l].Neurons)-1))
	}
	return t
}
--------------------------------------------------------------------------------
/neuron.go:
--------------------------------------------------------------------------------
package dnn

import (
	"math"
)

// Sigmoid is the logistic activation function.
func Sigmoid(x float64) float64 {
	return 1.0 / (1.0 + math.Exp(-x))
}

// SigmoidDerive returns the derivative of the sigmoid expressed in
// terms of the sigmoid's output: if y = Sigmoid(x), then
// dy/dx = y * (1 - y). Callers pass neuron.Output, which is already a
// sigmoid value.
func SigmoidDerive(y float64) float64 {
	return y * (1 - y)
}

type Neuron struct {
	Index    uint
	Conns    []Conn
	Output   float64
	gradient float64
}

// Init allocates the neuron's outgoing connections with random
// weights. index is the neuron's position within its layer.
func (neuron *Neuron) Init(connCount, index uint) {
	conns := make([]Conn, connCount)
	for c := range conns {
		conns[c].Init()
	}
	neuron.Index = index
	neuron.Conns = conns
	neuron.Output = 0
	neuron.gradient = 0
}

// Feed computes the neuron's output from the previous layer's outputs
// (including its bias neuron) and the corresponding weights.
func (neuron *Neuron) Feed(prevLayer *Layer) {
	sum := 0.0
	idx := neuron.Index
	for n := range prevLayer.Neurons {
		prevNeuron := &prevLayer.Neurons[n]
		sum += prevNeuron.Output * prevNeuron.Conns[idx].Weight
	}
	neuron.Output = Sigmoid(sum)
}

// updateGradient sets the gradient of an output neuron from its target.
func (neuron *Neuron) updateGradient(target float64) {
	delta := target - neuron.Output
	neuron.gradient = delta * SigmoidDerive(neuron.Output)
}

// deriveGradients sets the gradient of a hidden neuron from the
// gradients of the next layer (its bias neuron excluded).
func (neuron *Neuron) deriveGradients(nextLayer *Layer) {
	sum := 0.0
	for n := 0; n < len(nextLayer.Neurons)-1; n++ {
		tmp := &nextLayer.Neurons[n]
		sum += neuron.Conns[n].Weight * tmp.gradient
	}

	neuron.gradient = sum * SigmoidDerive(neuron.Output)
}

// updateWeight adjusts the neuron's incoming weights, keeping a
// fraction of the previous update (conn.Delta) as a momentum term.
func (neuron *Neuron) updateWeight(prevLayer *Layer, rate float64) {
	for n := range prevLayer.Neurons {
		tmp := &prevLayer.Neurons[n]
		conn := &tmp.Conns[neuron.Index]

		delta := rate*tmp.Output*neuron.gradient + conn.Delta*0.2
		conn.Delta = delta
		conn.Weight += delta
	}
}
--------------------------------------------------------------------------------