├── go.sum ├── go.mod ├── examples ├── go.mod ├── go.sum ├── examples_test.go ├── load.go ├── simple.go ├── ff.network └── save.go ├── .gitignore ├── util.go ├── .github └── workflows │ └── go.yml ├── persist ├── persist_test.go └── persist.go ├── LICENSE ├── feedforward_test.go ├── README.md └── feedforward.go /go.sum: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/goml/gobrain 2 | 3 | go 1.13 4 | -------------------------------------------------------------------------------- /examples/go.mod: -------------------------------------------------------------------------------- 1 | module github.com/goml/gobrain/example 2 | 3 | go 1.13 4 | 5 | require github.com/goml/gobrain v0.0.0-20200606141943-08de5fe3f708 6 | -------------------------------------------------------------------------------- /examples/go.sum: -------------------------------------------------------------------------------- 1 | github.com/goml/gobrain v0.0.0-20200606141943-08de5fe3f708 h1:WnaVXug9uxxlMjKEpridlU36EINiN7c0UixfbCPom/8= 2 | github.com/goml/gobrain v0.0.0-20200606141943-08de5fe3f708/go.mod h1:3C9khhaMIcUpCkbwcG7fVaqIjUOOnRxRUvZ+btWi+5A= 3 | -------------------------------------------------------------------------------- /examples/examples_test.go: -------------------------------------------------------------------------------- 1 | package examples 2 | 3 | func ExampleSimple() { 4 | Simple() 5 | 6 | // Output: 7 | // [0.09740879532462123] 8 | } 9 | 10 | func ExampleLoad() { 11 | Load("ff.network") 12 | 13 | // Output: 14 | // [0.09740879532462095] 15 | } 16 | 17 | func ExampleSave() { 18 | filename := "_saved.network" 19 | Save(filename) 20 | Load(filename) 21 | 22 | // Output: 23 | // [0.09740879532462123] 24 | // [0.09740879532462123] 25 | } 26 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Compiled Object files, Static and Dynamic libs (Shared Objects) 2 | *.o 3 | *.a 4 | *.so 5 | 6 | # Folders 7 | _obj 8 | _test 9 | 10 | # Architecture specific extensions/prefixes 11 | *.[568vq] 12 | [568vq].out 13 | 14 | *.cgo1.go 15 | *.cgo2.c 16 | _cgo_defun.c 17 | _cgo_gotypes.go 18 | _cgo_export.* 19 | 20 | _testmain.go 21 | 22 | *.exe 23 | *.test 24 | 25 | test.go 26 | examples/01/01 27 | examples/02/02 28 | examples/03/03 29 | 30 | .idea 31 | 32 | examples/_saved.network -------------------------------------------------------------------------------- /util.go: -------------------------------------------------------------------------------- 1 | package gobrain 2 | 3 | import ( 4 | "math" 5 | "math/rand" 6 | ) 7 | 8 | func random(a, b float64) float64 { 9 | return (b-a)*rand.Float64() + a 10 | } 11 | 12 | func matrix(I, J int) [][]float64 { 13 | m := make([][]float64, I) 14 | for i := 0; i < I; i++ { 15 | m[i] = make([]float64, J) 16 | } 17 | return m 18 | } 19 | 20 | func vector(I int, fill float64) []float64 { 21 | v := make([]float64, I) 22 | for i := 0; i < I; i++ { 23 | v[i] = fill 24 | } 25 | return v 26 | } 27 | 28 | func sigmoid(x float64) float64 { 29 | return 1 / (1 + math.Exp(-x)) 30 | } 31 | 32 | func dsigmoid(y float64) float64 { 33 | return y * (1 - y) 34 | } 35 | -------------------------------------------------------------------------------- /examples/load.go: -------------------------------------------------------------------------------- 1 | package examples 2 | 3 | import ( 4 | "fmt" 5 | "log" 6 | "math/rand" 7 | 8 | "github.com/goml/gobrain" 9 | "github.com/goml/gobrain/persist" 10 | ) 11 | 12 | func Load(filename string) { 13 | // set the random seed to 0 14 | rand.Seed(0) 15 | 16 | // instantiate the Feed Forward 17 | ff := &gobrain.FeedForward{} 18 | 19 | err := persist.Load(filename, &ff) 20 | if err != nil { 21 | log.Println("impossible to load network from file: ", err.Error()) 22 | } 23 | 24 | // sends inputs to the neural network 25 | inputs := []float64{1, 1} 26 | 27 | // saves the result 28 | result := ff.Update(inputs) 29 | 30 | // prints the result 31 | fmt.Println(result) 32 | } 33 | -------------------------------------------------------------------------------- /.github/workflows/go.yml: -------------------------------------------------------------------------------- 1 | name: Go 2 | 3 | on: 4 | push: 5 | branches: [ master ] 6 | pull_request: 7 | branches: [ master ] 8 | 9 | jobs: 10 | 11 | build: 12 | name: Build 13 | runs-on: ubuntu-latest 14 | steps: 15 | 16 | - name: Set up Go 1.x 17 | uses: actions/setup-go@v2 18 | with: 19 | go-version: ^1.13 20 | 21 | - name: Check out code into the Go module directory 22 | uses: actions/checkout@v2 23 | 24 | - name: Get dependencies 25 | run: | 26 | go get -v -t -d ./... 27 | 28 | - name: Build 29 | run: go build -v ./... 30 | 31 | - name: Test 32 | run: go test -v ./... 33 | 34 | - name: Examples Test 35 | run: cd examples && go test -v ./... 36 | -------------------------------------------------------------------------------- /persist/persist_test.go: -------------------------------------------------------------------------------- 1 | package persist 2 | 3 | import ( 4 | "fmt" 5 | "io/ioutil" 6 | "log" 7 | "os" 8 | ) 9 | 10 | const value string = `{"foo": "bar"}` 11 | 12 | var storingFile = "./persiststorefile_tobedeleted" 13 | 14 | func ExampleSave() { 15 | defer os.Remove(storingFile) 16 | 17 | Save(storingFile, value) 18 | 19 | fileRead, _ := ioutil.ReadFile(storingFile) 20 | str := string(fileRead) 21 | fmt.Printf(str) 22 | // Output: 23 | // "{\"foo\": \"bar\"}" 24 | } 25 | 26 | func ExampleLoad() { 27 | defer os.Remove(storingFile) 28 | 29 | ioutil.WriteFile(storingFile, []byte(value), 0666) 30 | 31 | var fileLoaded interface{} 32 | err := Load(storingFile, &fileLoaded) 33 | if err != nil { 34 | log.Println(err.Error()) 35 | } 36 | fmt.Println(fileLoaded) 37 | // Output: 38 | // map[foo:bar] 39 | } 40 | -------------------------------------------------------------------------------- /examples/simple.go: -------------------------------------------------------------------------------- 1 | package examples 2 | 3 | import ( 4 | "fmt" 5 | "math/rand" 6 | 7 | "github.com/goml/gobrain" 8 | ) 9 | 10 | func Simple() { 11 | // set the random seed to 0 12 | rand.Seed(0) 13 | 14 | // create the XOR representation patter to train the network 15 | patterns := [][][]float64{ 16 | {{0, 0}, {0}}, 17 | {{0, 1}, {1}}, 18 | {{1, 0}, {1}}, 19 | {{1, 1}, {0}}, 20 | } 21 | 22 | // instantiate the Feed Forward 23 | ff := &gobrain.FeedForward{} 24 | 25 | // initialize the Neural Network; 26 | // the networks structure will contain: 27 | // 2 inputs, 2 hidden nodes and 1 output. 28 | ff.Init(2, 2, 1) 29 | 30 | // train the network using the XOR patterns 31 | // the training will run for 1000 epochs 32 | // the learning rate is set to 0.6 and the momentum factor to 0.4 33 | // use true in the last parameter to receive reports about the learning error 34 | ff.Train(patterns, 1000, 0.6, 0.4, false) 35 | 36 | // inputs to send to the neural network 37 | inputs := []float64{1, 1} 38 | 39 | // saves the result 40 | result := ff.Update(inputs) 41 | 42 | // prints the result 43 | fmt.Println(result) 44 | } 45 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | The MIT License (MIT) 2 | 3 | Copyright (c) 2014 Jonas Trevisan 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /feedforward_test.go: -------------------------------------------------------------------------------- 1 | package gobrain 2 | 3 | import ( 4 | // "testing" 5 | "math/rand" 6 | ) 7 | 8 | func ExampleFeedForward() { 9 | // set the random seed to 0 10 | rand.Seed(0) 11 | 12 | // create the XOR representation patter to train the network 13 | patterns := [][][]float64{ 14 | {{0, 0}, {0}}, 15 | {{0, 1}, {1}}, 16 | {{1, 0}, {1}}, 17 | {{1, 1}, {0}}, 18 | } 19 | 20 | // instantiate the Feed Forward 21 | ff := &FeedForward{} 22 | 23 | // initialize the Neural Network; 24 | // the networks structure will contain: 25 | // 2 inputs, 2 hidden nodes and 1 output. 26 | ff.Init(2, 2, 1) 27 | 28 | // train the network using the XOR patterns 29 | // the training will run for 1000 epochs 30 | // the learning rate is set to 0.6 and the momentum factor to 0.4 31 | // use true in the last parameter to receive reports about the learning error 32 | ff.Train(patterns, 1000, 0.6, 0.4, false) 33 | 34 | // testing the network 35 | ff.Test(patterns) 36 | 37 | // predicting a value 38 | inputs := []float64{1, 1} 39 | ff.Update(inputs) 40 | 41 | // Output: 42 | // [0 0] -> [0.057503945708445206] : [0] 43 | // [0 1] -> [0.9301006350712101] : [1] 44 | // [1 0] -> [0.9278099662272838] : [1] 45 | // [1 1] -> [0.09740879532462123] : [0] 46 | } 47 | -------------------------------------------------------------------------------- /examples/ff.network: -------------------------------------------------------------------------------- 1 | { 2 | "NInputs": 3, 3 | "NHiddens": 3, 4 | "NOutputs": 1, 5 | "Regression": false, 6 | "InputActivations": [ 7 | 1, 8 | 1, 9 | 1 10 | ], 11 | "HiddenActivations": [ 12 | 0.12352598522727118, 13 | 0.000036082505119834765, 14 | 1 15 | ], 16 | "OutputActivations": [ 17 | 0.09821937121962863 18 | ], 19 | "Contexts": null, 20 | "InputWeights": [ 21 | [ 22 | -3.4880623577856333, 23 | -5.743463781786981, 24 | 0.31191253039081035 25 | ], 26 | [ 27 | -3.5542287917518895, 28 | -6.5511516639871665, 29 | -0.42103913368681445 30 | ], 31 | [ 32 | 5.074841058472211, 33 | 2.0642400505525687, 34 | 0.794339426299602 35 | ] 36 | ], 37 | "OutputWeights": [ 38 | [ 39 | 7.017755827192258 40 | ], 41 | [ 42 | -7.532219157078145 43 | ], 44 | [ 45 | -3.0869008820533197 46 | ] 47 | ], 48 | "InputChanges": [ 49 | [ 50 | -0.0066089076107655595, 51 | 0.0000023642965484966626, 52 | 0 53 | ], 54 | [ 55 | -0.0066089076107655595, 56 | 0.0000023642965484966626, 57 | 0 58 | ], 59 | [ 60 | -0.0066089076107655595, 61 | 0.0000023642965484966626, 62 | 0 63 | ] 64 | ], 65 | "OutputChanges": [ 66 | [ 67 | -0.0010746165566903239 68 | ], 69 | [ 70 | -3.139004100011613e-7 71 | ], 72 | [ 73 | -0.008699518200265103 74 | ] 75 | ] 76 | } -------------------------------------------------------------------------------- /examples/save.go: -------------------------------------------------------------------------------- 1 | package examples 2 | 3 | import ( 4 | "fmt" 5 | "log" 6 | "math/rand" 7 | 8 | "github.com/goml/gobrain" 9 | "github.com/goml/gobrain/persist" 10 | ) 11 | 12 | func Save(filename string) { 13 | // set the random seed to 0 14 | rand.Seed(0) 15 | 16 | // create the XOR representation patter to train the network 17 | patterns := [][][]float64{ 18 | {{0, 0}, {0}}, 19 | {{0, 1}, {1}}, 20 | {{1, 0}, {1}}, 21 | {{1, 1}, {0}}, 22 | } 23 | 24 | // instantiate the Feed Forward 25 | ff := &gobrain.FeedForward{} 26 | 27 | // initialize the Neural Network; 28 | // the networks structure will contain: 29 | // 2 inputs, 2 hidden nodes and 1 output. 30 | ff.Init(2, 2, 1) 31 | 32 | // train the network using the XOR patterns 33 | // the training will run for 1000 epochs 34 | // the learning rate is set to 0.6 and the momentum factor to 0.4 35 | // use true in the last parameter to receive reports about the learning error 36 | ff.Train(patterns, 1000, 0.6, 0.4, false) 37 | 38 | // saves neural network to file 39 | err := persist.Save(filename, ff) 40 | if err != nil { 41 | log.Println("impossible to save network on file: ", err.Error()) 42 | } 43 | 44 | // sends inputs to the neural network 45 | inputs := []float64{1, 1} 46 | 47 | // saves the result 48 | result := ff.Update(inputs) 49 | 50 | // prints the result 51 | fmt.Println(result) 52 | } 53 | -------------------------------------------------------------------------------- /persist/persist.go: -------------------------------------------------------------------------------- 1 | // Thanks to Matt Ryer 2 | // https://medium.com/@matryer/golang-advent-calendar-day-eleven-persisting-go-objects-to-disk-7caf1ee3d11d 3 | 4 | package persist 5 | 6 | import ( 7 | "bytes" 8 | "encoding/json" 9 | "io" 10 | "os" 11 | "sync" 12 | ) 13 | 14 | var lock sync.Mutex 15 | 16 | // Marshal is a function that marshals the object into an 17 | // io.Reader. 18 | // By default, it uses the JSON marshaller. 19 | var Marshal = func(v interface{}) (io.Reader, error) { 20 | b, err := json.MarshalIndent(v, "", "\t") 21 | if err != nil { 22 | return nil, err 23 | } 24 | return bytes.NewReader(b), nil 25 | } 26 | 27 | // Save saves a representation of v to the file at path. 28 | func Save(path string, v interface{}) error { 29 | lock.Lock() 30 | defer lock.Unlock() 31 | f, err := os.Create(path) 32 | if err != nil { 33 | return err 34 | } 35 | defer f.Close() 36 | r, err := Marshal(v) 37 | if err != nil { 38 | return err 39 | } 40 | _, err = io.Copy(f, r) 41 | return err 42 | } 43 | 44 | // Unmarshal is a function that unmarshals the data from the 45 | // reader into the specified value. 46 | // By default, it uses the JSON unmarshaller. 47 | var Unmarshal = func(r io.Reader, v interface{}) error { 48 | return json.NewDecoder(r).Decode(v) 49 | } 50 | 51 | // Load loads the file at path into v. 52 | // Use os.IsNotExist() to see if the returned error is due 53 | // to the file being missing. 54 | func Load(path string, v interface{}) error { 55 | lock.Lock() 56 | defer lock.Unlock() 57 | f, err := os.Open(path) 58 | if err != nil { 59 | return err 60 | } 61 | defer f.Close() 62 | return Unmarshal(f, v) 63 | } 64 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # gobrain 2 | 3 | Neural Networks written in go 4 | 5 | [![GoDoc](https://godoc.org/github.com/goml/gobrain?status.svg)](https://godoc.org/github.com/goml/gobrain) 6 | [![Build Status](https://travis-ci.org/goml/gobrain.svg?branch=master)](https://travis-ci.org/goml/gobrain) 7 | 8 | ## Getting Started 9 | The version `1.0.0` includes just basic Neural Network functions such as Feed Forward and Elman Recurrent Neural Network. 10 | A simple Feed Forward Neural Network can be constructed and trained as follows: 11 | 12 | ```go 13 | package main 14 | 15 | import ( 16 | "github.com/goml/gobrain" 17 | "math/rand" 18 | ) 19 | 20 | func main() { 21 | // set the random seed to 0 22 | rand.Seed(0) 23 | 24 | // create the XOR representation patter to train the network 25 | patterns := [][][]float64{ 26 | {{0, 0}, {0}}, 27 | {{0, 1}, {1}}, 28 | {{1, 0}, {1}}, 29 | {{1, 1}, {0}}, 30 | } 31 | 32 | // instantiate the Feed Forward 33 | ff := &gobrain.FeedForward{} 34 | 35 | // initialize the Neural Network; 36 | // the networks structure will contain: 37 | // 2 inputs, 2 hidden nodes and 1 output. 38 | ff.Init(2, 2, 1) 39 | 40 | // train the network using the XOR patterns 41 | // the training will run for 1000 epochs 42 | // the learning rate is set to 0.6 and the momentum factor to 0.4 43 | // use true in the last parameter to receive reports about the learning error 44 | ff.Train(patterns, 1000, 0.6, 0.4, true) 45 | } 46 | 47 | ``` 48 | 49 | After running this code the network will be trained and ready to be used. 50 | 51 | The network can be tested running using the `Test` method, for instance: 52 | 53 | ```go 54 | ff.Test(patterns) 55 | ``` 56 | 57 | The test operation will print in the console something like: 58 | 59 | ``` 60 | [0 0] -> [0.057503945708445] : [0] 61 | [0 1] -> [0.930100635071210] : [1] 62 | [1 0] -> [0.927809966227284] : [1] 63 | [1 1] -> [0.097408795324620] : [0] 64 | ``` 65 | 66 | Where the first values are the inputs, the values after the arrow `->` are the output values from the network and the values after `:` are the expected outputs. 67 | 68 | The method `Update` can be used to predict the output given an input, for example: 69 | 70 | ```go 71 | inputs := []float64{1, 1} 72 | ff.Update(inputs) 73 | ``` 74 | 75 | the output will be a vector with values ranging from `0` to `1`. 76 | 77 | In the example folder there are runnable examples with persistence of the trained network on file. 78 | 79 | In example/02 the network is saved on file and in example/03 the network is loaded from file. 80 | 81 | To run the example cd in the folder and run 82 | 83 | go run main.go 84 | 85 | ## Recurrent Neural Network 86 | 87 | This library implements Elman's Simple Recurrent Network. 88 | 89 | To take advantage of this, one can use the `SetContexts` function. 90 | 91 | ```go 92 | ff.SetContexts(1, nil) 93 | ``` 94 | 95 | In the example above, a single context will be created initialized with `0.5`. It is also possible 96 | to create custom initialized contexts, for instance: 97 | 98 | ```go 99 | contexts := [][]float64{ 100 | {0.5, 0.8, 0.1} 101 | } 102 | ``` 103 | 104 | Note that custom contexts must have the same size of hidden nodes + 1 (bias node), 105 | in the example above the size of hidden nodes is 2, thus the context has 3 values. 106 | 107 | ## Changelog 108 | * 1.0.0 - Added Feed Forward Neural Network with contexts from Elman RNN 109 | 110 | -------------------------------------------------------------------------------- /feedforward.go: -------------------------------------------------------------------------------- 1 | // Package gobrain provides basic neural networks algorithms. 2 | package gobrain 3 | 4 | import ( 5 | "fmt" 6 | "log" 7 | "math" 8 | ) 9 | 10 | // FeedForwad struct is used to represent a simple neural network 11 | type FeedForward struct { 12 | // Number of input, hidden, output nodes and contexts 13 | NInputs, NHiddens, NOutputs, NContexts int 14 | // Whether it is regression or not 15 | Regression bool 16 | // Activations for nodes 17 | InputActivations, HiddenActivations, OutputActivations []float64 18 | // ElmanRNN contexts 19 | Contexts [][]float64 20 | // Weights 21 | InputWeights, OutputWeights [][]float64 22 | ContextWeights [][][]float64 23 | // Last change in weights for momentum 24 | InputChanges, OutputChanges [][]float64 25 | ContextChanges [][][]float64 26 | } 27 | 28 | /* 29 | Initialize the neural network; 30 | 31 | the 'inputs' value is the number of inputs the network will have, 32 | the 'hiddens' value is the number of hidden nodes and 33 | the 'outputs' value is the number of the outputs of the network. 34 | */ 35 | func (nn *FeedForward) Init(inputs, hiddens, outputs int) { 36 | nn.NInputs = inputs + 1 // +1 for bias 37 | nn.NHiddens = hiddens + 1 // +1 for bias 38 | nn.NOutputs = outputs 39 | 40 | nn.InputActivations = vector(nn.NInputs, 1.0) 41 | nn.HiddenActivations = vector(nn.NHiddens, 1.0) 42 | nn.OutputActivations = vector(nn.NOutputs, 1.0) 43 | 44 | nn.InputWeights = matrix(nn.NInputs, nn.NHiddens) 45 | nn.OutputWeights = matrix(nn.NHiddens, nn.NOutputs) 46 | 47 | for i := 0; i < nn.NInputs; i++ { 48 | for j := 0; j < nn.NHiddens; j++ { 49 | nn.InputWeights[i][j] = random(-1, 1) 50 | } 51 | } 52 | 53 | for i := 0; i < nn.NHiddens; i++ { 54 | for j := 0; j < nn.NOutputs; j++ { 55 | nn.OutputWeights[i][j] = random(-1, 1) 56 | } 57 | } 58 | 59 | nn.InputChanges = matrix(nn.NInputs, nn.NHiddens) 60 | nn.OutputChanges = matrix(nn.NHiddens, nn.NOutputs) 61 | } 62 | 63 | /* 64 | Set the number of contexts to add to the network. 65 | 66 | By default the network do not have any context so it is a simple Feed Forward network, 67 | when contexts are added the network behaves like an Elman's SRN (Simple Recurrent Network). 68 | 69 | The first parameter (nContexts) is used to indicate the number of contexts to be used, 70 | the second parameter (initValues) can be used to create custom initialized contexts. 71 | 72 | If 'initValues' is set, the first parameter 'nContexts' is ignored and 73 | the contexts provided in 'initValues' are used. 74 | 75 | When using 'initValues' note that contexts must have the same size of hidden nodes + 1 (bias node). 76 | */ 77 | func (nn *FeedForward) SetContexts(nContexts int, initValues [][]float64) { 78 | if initValues == nil { 79 | initValues = make([][]float64, nContexts) 80 | 81 | for i := 0; i < nContexts; i++ { 82 | initValues[i] = vector(nn.NHiddens, 0.5) 83 | } 84 | } 85 | 86 | nn.NContexts = len(initValues) 87 | 88 | nn.ContextWeights = make([][][]float64, nn.NContexts) 89 | nn.ContextChanges = make([][][]float64, nn.NContexts) 90 | 91 | for i := 0; i < nn.NContexts; i++ { 92 | nn.ContextWeights[i] = matrix(nn.NHiddens, nn.NHiddens) 93 | nn.ContextChanges[i] = matrix(nn.NHiddens, nn.NHiddens) 94 | 95 | for j := 0; j < nn.NHiddens; j++ { 96 | for k := 0; k < nn.NHiddens; k++ { 97 | nn.ContextWeights[i][j][k] = random(-1, 1) 98 | } 99 | } 100 | } 101 | 102 | nn.Contexts = initValues 103 | } 104 | 105 | /* 106 | Reset the context values. 107 | 108 | Useful to remove noise from previous context when the network is given the start of a new sequence. 109 | This does not affect the context weights. 110 | */ 111 | func (nn *FeedForward) ResetContexts() { 112 | for i := 0; i < nn.NContexts; i++ { 113 | for j := 0; j < nn.NHiddens; j++ { 114 | nn.Contexts[i][j] = 0.5 115 | } 116 | } 117 | } 118 | 119 | /* 120 | The Update method is used to activate the Neural Network. 121 | 122 | Given an array of inputs, it returns an array, of length equivalent of number of outputs, with values ranging from 0 to 1. 123 | */ 124 | func (nn *FeedForward) Update(inputs []float64) []float64 { 125 | if len(inputs) != nn.NInputs-1 { 126 | log.Fatal("Error: wrong number of inputs") 127 | } 128 | 129 | for i := 0; i < nn.NInputs-1; i++ { 130 | nn.InputActivations[i] = inputs[i] 131 | } 132 | 133 | for i := 0; i < nn.NHiddens-1; i++ { 134 | var sum float64 135 | 136 | for j := 0; j < nn.NInputs; j++ { 137 | sum += nn.InputActivations[j] * nn.InputWeights[j][i] 138 | } 139 | 140 | // compute contexts sum 141 | for k := 0; k < nn.NContexts; k++ { 142 | for j := 0; j < nn.NHiddens-1; j++ { 143 | sum += nn.Contexts[k][j] * nn.ContextWeights[k][j][i] 144 | } 145 | } 146 | 147 | nn.HiddenActivations[i] = sigmoid(sum) 148 | } 149 | 150 | // update the contexts 151 | if len(nn.Contexts) > 0 { 152 | for i := len(nn.Contexts) - 1; i > 0; i-- { 153 | nn.Contexts[i] = nn.Contexts[i-1] 154 | } 155 | nn.Contexts[0] = nn.HiddenActivations 156 | } 157 | 158 | for i := 0; i < nn.NOutputs; i++ { 159 | var sum float64 160 | for j := 0; j < nn.NHiddens; j++ { 161 | sum += nn.HiddenActivations[j] * nn.OutputWeights[j][i] 162 | } 163 | 164 | nn.OutputActivations[i] = sigmoid(sum) 165 | } 166 | 167 | return nn.OutputActivations 168 | } 169 | 170 | /* 171 | The BackPropagate method is used, when training the Neural Network, 172 | to back propagate the errors from network activation. 173 | */ 174 | func (nn *FeedForward) BackPropagate(targets []float64, lRate, mFactor float64) float64 { 175 | if len(targets) != nn.NOutputs { 176 | log.Fatal("Error: wrong number of target values") 177 | } 178 | 179 | outputDeltas := vector(nn.NOutputs, 0.0) 180 | for i := 0; i < nn.NOutputs; i++ { 181 | outputDeltas[i] = dsigmoid(nn.OutputActivations[i]) * (targets[i] - nn.OutputActivations[i]) 182 | } 183 | 184 | hiddenDeltas := vector(nn.NHiddens, 0.0) 185 | for i := 0; i < nn.NHiddens; i++ { 186 | var e float64 187 | 188 | for j := 0; j < nn.NOutputs; j++ { 189 | e += outputDeltas[j] * nn.OutputWeights[i][j] 190 | } 191 | 192 | hiddenDeltas[i] = dsigmoid(nn.HiddenActivations[i]) * e 193 | } 194 | 195 | for i := 0; i < nn.NHiddens; i++ { 196 | for j := 0; j < nn.NOutputs; j++ { 197 | change := outputDeltas[j] * nn.HiddenActivations[i] 198 | nn.OutputWeights[i][j] = nn.OutputWeights[i][j] + lRate*change + mFactor*nn.OutputChanges[i][j] 199 | nn.OutputChanges[i][j] = change 200 | } 201 | } 202 | 203 | for i := 0; i < nn.NContexts; i++ { 204 | for j := 0; j < nn.NHiddens; j++ { 205 | for k := 0; k < nn.NHiddens; k++ { 206 | change := hiddenDeltas[k] * nn.Contexts[i][j] 207 | nn.ContextWeights[i][j][k] = nn.ContextWeights[i][j][k] + lRate*change + mFactor*nn.ContextChanges[i][j][k] 208 | nn.ContextChanges[i][j][k] = change 209 | } 210 | } 211 | } 212 | 213 | for i := 0; i < nn.NInputs; i++ { 214 | for j := 0; j < nn.NHiddens; j++ { 215 | change := hiddenDeltas[j] * nn.InputActivations[i] 216 | nn.InputWeights[i][j] = nn.InputWeights[i][j] + lRate*change + mFactor*nn.InputChanges[i][j] 217 | nn.InputChanges[i][j] = change 218 | } 219 | } 220 | 221 | var e float64 222 | 223 | for i := 0; i < len(targets); i++ { 224 | e += 0.5 * math.Pow(targets[i]-nn.OutputActivations[i], 2) 225 | } 226 | 227 | return e 228 | } 229 | 230 | /* 231 | This method is used to train the Network, it will run the training operation for 'iterations' times 232 | and return the computed errors when training. 233 | */ 234 | func (nn *FeedForward) Train(patterns [][][]float64, iterations int, lRate, mFactor float64, debug bool) []float64 { 235 | errors := make([]float64, iterations) 236 | 237 | for i := 0; i < iterations; i++ { 238 | var e float64 239 | for _, p := range patterns { 240 | nn.Update(p[0]) 241 | 242 | tmp := nn.BackPropagate(p[1], lRate, mFactor) 243 | e += tmp 244 | } 245 | 246 | errors[i] = e 247 | 248 | if debug && i%1000 == 0 { 249 | fmt.Println(i, e) 250 | } 251 | } 252 | 253 | return errors 254 | } 255 | 256 | func (nn *FeedForward) Test(patterns [][][]float64) { 257 | for _, p := range patterns { 258 | fmt.Println(p[0], "->", nn.Update(p[0]), " : ", p[1]) 259 | } 260 | } 261 | --------------------------------------------------------------------------------