├── .classpath
├── .gitignore
├── .project
├── .settings
└── org.eclipse.jdt.core.prefs
├── LICENSE
├── README
└── src
├── Test.java
└── com
└── evolvingstuff
├── EmbeddedReberGrammar.java
├── IAgent.java
├── IAgentSupervised.java
├── IdentityNeuron.java
├── LSTM.java
├── Neuron.java
├── NeuronType.java
├── SigmoidNeuron.java
└── TanhNeuron.java
/.classpath:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | bin/*
2 |
--------------------------------------------------------------------------------
/.project:
--------------------------------------------------------------------------------
1 |
2 |
3 | LongShortTermMemory
4 |
5 |
6 |
7 |
8 |
9 | org.eclipse.jdt.core.javabuilder
10 |
11 |
12 |
13 |
14 |
15 | org.eclipse.jdt.core.javanature
16 |
17 |
18 |
--------------------------------------------------------------------------------
/.settings/org.eclipse.jdt.core.prefs:
--------------------------------------------------------------------------------
1 | #Tue Feb 14 14:02:32 PST 2012
2 | eclipse.preferences.version=1
3 | org.eclipse.jdt.core.compiler.codegen.inlineJsrBytecode=enabled
4 | org.eclipse.jdt.core.compiler.codegen.targetPlatform=1.6
5 | org.eclipse.jdt.core.compiler.codegen.unusedLocal=preserve
6 | org.eclipse.jdt.core.compiler.compliance=1.6
7 | org.eclipse.jdt.core.compiler.debug.lineNumber=generate
8 | org.eclipse.jdt.core.compiler.debug.localVariable=generate
9 | org.eclipse.jdt.core.compiler.debug.sourceFile=generate
10 | org.eclipse.jdt.core.compiler.problem.assertIdentifier=error
11 | org.eclipse.jdt.core.compiler.problem.enumIdentifier=error
12 | org.eclipse.jdt.core.compiler.source=1.6
13 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | copyright (c) 2013 Thomas Lahore
2 |
3 | Permission is hereby granted, free of charge, to any person obtaining a copy
4 | of this software and associated documentation files (the "Software"), to deal
5 | in the Software without restriction, including without limitation the rights
6 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
7 | copies of the Software, and to permit persons to whom the Software is
8 | furnished to do so, subject to the following conditions:
9 |
10 | The above copyright notice and this permission notice shall be included in
11 | all copies or substantial portions of the Software.
12 |
13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
18 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
19 | THE SOFTWARE.
20 |
--------------------------------------------------------------------------------
/README:
--------------------------------------------------------------------------------
1 | This is an implementation of Long Short Term Memory in Java - see here for description of LSTM: http://www.idsia.ch/~juergen/rnn.html
2 |
3 | Not very thoroughly tested or debugged at this point.
4 |
--------------------------------------------------------------------------------
/src/Test.java:
--------------------------------------------------------------------------------
1 | import java.util.Random;
2 |
3 | import com.evolvingstuff.EmbeddedReberGrammar;
4 | import com.evolvingstuff.LSTM;
5 |
6 | public class Test {
7 | public static void main(String[] args) throws Exception {
8 | System.out.println("Test of Long Short Term Memory on Embedded Reber Grammar task");
9 | Random r = new Random(1234);
10 | EmbeddedReberGrammar.deterministic_evaluation = false;
11 | EmbeddedReberGrammar.reset_at_begining = true;
12 | EmbeddedReberGrammar evaluator = new EmbeddedReberGrammar(r);
13 | int cell_blocks = 15;
14 | LSTM agent = new LSTM(r, evaluator.GetObservationDimension(), evaluator.GetActionDimension(), cell_blocks);
15 | int training_epoches = 1000;
16 | for (int t = 0; t < training_epoches; t++)
17 | {
18 | evaluator.SetValidationMode(false);
19 | double fit = evaluator.EvaluateFitnessSupervised(agent);
20 | evaluator.SetValidationMode(true);
21 | double validation = evaluator.EvaluateFitnessSupervised(agent);
22 | System.out.println("\t["+t+"]:\t" + (1-fit) + "\t" + (1-validation));
23 | }
24 | System.out.println("done.");
25 | }
26 |
27 | }
28 |
--------------------------------------------------------------------------------
/src/com/evolvingstuff/EmbeddedReberGrammar.java:
--------------------------------------------------------------------------------
1 | package com.evolvingstuff;
2 |
3 | import java.util.*;
4 |
5 | public class EmbeddedReberGrammar
6 | {
7 | int tests = 5000;
8 |
9 | public class State
10 | {
11 | public State(Transition[] transitions)
12 | {
13 | this.transitions = transitions;
14 | }
15 | public Transition[] transitions;
16 | }
17 |
18 | public class Transition
19 | {
20 | public Transition(int next_state_id, int token)
21 | {
22 | this.next_state_id = next_state_id;
23 | this.token = token;
24 | }
25 |
26 | public int next_state_id;
27 | public int token;
28 | }
29 |
30 | private Random r;
31 | private State[] states;
32 | private static String[] num_to_string = {"B","T","P","S","X","V","E"};
33 | private int B = 0;
34 | private int T = 1;
35 | private int P = 2;
36 | private int S = 3;
37 | private int X = 4;
38 | private int V = 5;
39 | private int E = 6;
40 | public static boolean reset_at_begining = true;//false
41 | public static boolean super_discrete_feedback = false;
42 | public static boolean discrete_feedback = true;
43 | private boolean error_squared = false;
44 | public static boolean ignore_short_transitions = false; //Partial vs Complete version
45 | public static boolean deterministic_evaluation = false; //TODO
46 | private boolean validation_mode = false;
47 |
48 | public EmbeddedReberGrammar(Random r)
49 | {
50 |
51 | this.r = r;
52 | states = new State[19];
53 | states[0] = new State(new Transition[] {new Transition(1,B)});
54 | states[1] = new State(new Transition[] {new Transition(2,T), new Transition(11,P)});
55 | states[2] = new State(new Transition[] {new Transition(3,B)});
56 | states[3] = new State(new Transition[] {new Transition(4,T), new Transition(9,P)});
57 | states[4] = new State(new Transition[] {new Transition(4,S), new Transition(5,X)});
58 | states[5] = new State(new Transition[] {new Transition(6,S), new Transition(9,X)});
59 | states[6] = new State(new Transition[] {new Transition(7,E)});
60 | states[7] = new State(new Transition[] {new Transition(8,T)});
61 | states[8] = new State(new Transition[] {new Transition(0,E)});
62 | states[9] = new State(new Transition[] {new Transition(9,T), new Transition(10,V)});
63 | states[10] = new State(new Transition[] {new Transition(5,P), new Transition(6,V)});
64 | states[11] = new State(new Transition[] {new Transition(12,B)});
65 | states[12] = new State(new Transition[] {new Transition(13,T), new Transition(17,P)});
66 | states[13] = new State(new Transition[] {new Transition(13,S), new Transition(14,X)});
67 | states[14] = new State(new Transition[] {new Transition(15,S), new Transition(17,X)});
68 | states[15] = new State(new Transition[] {new Transition(16,E)});
69 | states[16] = new State(new Transition[] {new Transition(8,P)});
70 | states[17] = new State(new Transition[] {new Transition(17,T), new Transition(18,V)});
71 | states[18] = new State(new Transition[] {new Transition(14,P), new Transition(15,V)});
72 | }
73 |
74 | public int GetActionDimension()
75 | {
76 | return 7;
77 | }
78 |
79 | public int GetObservationDimension()
80 | {
81 | return 7;
82 | }
83 |
84 | public void SetValidationMode(boolean validation) {
85 | this.validation_mode = validation;
86 | }
87 |
88 | public double EvaluateFitnessSupervised(IAgentSupervised agent) throws Exception {
89 | if (deterministic_evaluation == true)
90 | r = new Random(1);
91 |
92 | int state_id = 0;
93 | agent.Reset();
94 | double tot_fit = 0;
95 | double tot_long_transitions = 0;
96 | double incorrect_long_transitions = 0;
97 | for (int t = 0; t < tests; t++)
98 | {
99 | int transition = -1;
100 | if (states[state_id].transitions.length == 1)
101 | transition = 0;
102 | else if (states[state_id].transitions.length == 2)
103 | transition = r.nextInt(2);
104 | else
105 | System.out.println("ERROR: more that 2 transitions");
106 | if (transition == -1)
107 | System.out.println("ERROR! no transition selected");
108 |
109 | double[] agent_input = new double[7];
110 | agent_input[states[state_id].transitions[transition].token] = 1.0;
111 |
112 | state_id = states[state_id].transitions[transition].next_state_id;
113 |
114 | double[] target = new double[7];
115 | for (int i = 0; i < states[state_id].transitions.length; i++)
116 | target[states[state_id].transitions[i].token] = 1.0;
117 |
118 | double[] agent_output;
119 | if (!validation_mode)
120 | agent_output = agent.Next(agent_input, target);
121 | else
122 | agent_output = agent.Next(agent_input);
123 |
124 | if (state_id == 7 || state_id == 16)
125 | tot_long_transitions++;
126 |
127 | boolean missed_long = false;
128 |
129 | if (super_discrete_feedback == true)
130 | {
131 | boolean all_correct = true;
132 | for (int i = 0; i < 7; i++)
133 | {
134 | if (Math.abs(target[i] - agent_output[i]) >= 0.5) {
135 | all_correct = false;
136 | break;
137 | }
138 | }
139 | if (all_correct)
140 | tot_fit += 1/(double)tests;
141 | }
142 | else
143 | {
144 | for (int i = 0; i < 7; i++)
145 | {
146 | if (discrete_feedback == true)
147 | {
148 |
149 | if (Math.abs(target[i] - agent_output[i]) < 0.5)
150 | tot_fit += 1/(7*(double)tests);
151 | else
152 | {
153 | if (state_id == 7 || state_id == 16)
154 | missed_long = true;
155 | }
156 | }
157 | else
158 | {
159 | if (error_squared == true)
160 | tot_fit += (1 - (target[i] - agent_output[i])*(target[i] - agent_output[i]))/(7*(double)tests);
161 | else
162 | tot_fit += (1 - Math.abs(target[i] - agent_output[i]))/(7*(double)tests);
163 | }
164 | }
165 | }
166 |
167 | if (missed_long == true)
168 | incorrect_long_transitions += 1;
169 |
170 | }
171 |
172 | System.out.println("\t\t\tLong-transition error (validation:"+validation_mode+") = " + (incorrect_long_transitions/tot_long_transitions));
173 |
174 | return tot_fit;
175 | }
176 |
177 | }
178 |
--------------------------------------------------------------------------------
/src/com/evolvingstuff/IAgent.java:
--------------------------------------------------------------------------------
1 | package com.evolvingstuff;
2 |
/**
 * An agent that is driven one step at a time: it receives an input vector and answers
 * with an output vector.
 */
public interface IAgent {

    /** Resets the agent; what exactly is cleared is up to the implementation. */
    void Reset();

    /**
     * Advances the agent by one step.
     *
     * @param input the input vector for this step
     * @return the agent's output vector for this step
     * @throws Exception if the implementation cannot process the step
     */
    double[] Next(double[] input) throws Exception;
}
8 |
--------------------------------------------------------------------------------
/src/com/evolvingstuff/IAgentSupervised.java:
--------------------------------------------------------------------------------
1 | package com.evolvingstuff;
2 |
/**
 * An agent that is driven one step at a time and can additionally be shown the desired
 * output of a step.
 */
public interface IAgentSupervised {

    /** Resets the agent; what exactly is cleared is up to the implementation. */
    void Reset();

    /**
     * Advances the agent by one step and supplies the desired output for that step.
     *
     * @param input the input vector for this step
     * @param target_output the desired output vector for this step
     * @return the agent's output vector for this step
     * @throws Exception if the implementation cannot process the step
     */
    double[] Next(double[] input, double[] target_output) throws Exception;

    /**
     * Advances the agent by one step without a desired output.
     *
     * @param input the input vector for this step
     * @return the agent's output vector for this step
     * @throws Exception if the implementation cannot process the step
     */
    double[] Next(double[] input) throws Exception;
}
9 |
--------------------------------------------------------------------------------
/src/com/evolvingstuff/IdentityNeuron.java:
--------------------------------------------------------------------------------
1 | package com.evolvingstuff;
2 |
3 | public class IdentityNeuron extends Neuron
4 | {
5 | @Override
6 | public double Activate(double x)
7 | {
8 | return x;
9 | }
10 |
11 | @Override
12 | public double Derivative(double x) {
13 | // TODO Auto-generated method stub
14 | return 1;
15 | }
16 | }
17 |
18 |
--------------------------------------------------------------------------------
/src/com/evolvingstuff/LSTM.java:
--------------------------------------------------------------------------------
1 | package com.evolvingstuff;
2 |
3 | import java.util.*;
4 |
5 | public class LSTM implements IAgent, IAgentSupervised {
6 | private double init_weight_range = 0.1;
7 | private int full_input_dimension;
8 | private int full_hidden_dimension;
9 | private int output_dimension;
10 | private int cell_blocks;
11 | private Neuron neuronNetInput;
12 | private Neuron neuronInputGate;
13 | private Neuron neuronForgetGate;
14 | private Neuron neuronOutputGate;
15 | private Neuron neuronCECSquash;
16 |
17 | private double [] CEC;
18 | private double [] context;
19 |
20 | private double [] peepInputGate;
21 | private double [] peepForgetGate;
22 | private double [] peepOutputGate;
23 | private double [][] weightsNetInput;
24 | private double [][] weightsInputGate;
25 | private double [][] weightsForgetGate;
26 | private double [][] weightsOutputGate;
27 | private double [][] weightsGlobalOutput;
28 |
29 | private double [][] dSdwWeightsNetInput;
30 | private double [][] dSdwWeightsInputGate;
31 | private double [][] dSdwWeightsForgetGate;
32 |
33 | private double biasInputGate = 2;
34 | private double biasForgetGate = -2;
35 | private double biasOutputGate = 2;
36 |
37 | private double learningRate = 0.07;
38 |
39 | public double[] GetHiddenState() {
40 | return context.clone();
41 | }
42 |
43 | public void SetHiddenState(double[] new_state) {
44 | context = new_state.clone();
45 | }
46 |
47 | public LSTM(Random r, int input_dimension, int output_dimension, int cell_blocks) {
48 | this.output_dimension = output_dimension;
49 | this.cell_blocks = cell_blocks;
50 |
51 | CEC = new double[cell_blocks];
52 | context = new double[cell_blocks];
53 |
54 | full_input_dimension = input_dimension + cell_blocks + 1; //+1 for bias
55 | full_hidden_dimension = cell_blocks + 1; //+1 for bias
56 |
57 | neuronNetInput = Neuron.Factory(NeuronType.Tanh);
58 | neuronInputGate = Neuron.Factory(NeuronType.Sigmoid);
59 | neuronForgetGate = Neuron.Factory(NeuronType.Sigmoid);
60 | neuronOutputGate = Neuron.Factory(NeuronType.Sigmoid);
61 | neuronCECSquash= Neuron.Factory(NeuronType.Identity);
62 |
63 | weightsNetInput = new double[cell_blocks][full_input_dimension];
64 | weightsInputGate = new double[cell_blocks][full_input_dimension];
65 | weightsForgetGate = new double[cell_blocks][full_input_dimension];
66 |
67 | dSdwWeightsNetInput = new double[cell_blocks][full_input_dimension];
68 | dSdwWeightsInputGate = new double[cell_blocks][full_input_dimension];
69 | dSdwWeightsForgetGate = new double[cell_blocks][full_input_dimension];
70 |
71 | weightsOutputGate = new double[cell_blocks][full_input_dimension];
72 |
73 | for (int i = 0; i < full_input_dimension; i++) {
74 | for (int j = 0; j < cell_blocks; j++) {
75 | weightsNetInput[j][i] = (r.nextDouble() * 2 - 1) * init_weight_range;
76 | weightsInputGate[j][i] = (r.nextDouble() * 2 - 1) * init_weight_range;
77 | weightsForgetGate[j][i] = (r.nextDouble() * 2 - 1) * init_weight_range;
78 | weightsOutputGate[j][i] = (r.nextDouble() * 2 - 1) * init_weight_range;
79 | }
80 | }
81 |
82 | for (int j = 0; j < cell_blocks; j++) {
83 | weightsInputGate[j][full_input_dimension-1] += biasInputGate;
84 | weightsForgetGate[j][full_input_dimension-1] += biasForgetGate;
85 | weightsOutputGate[j][full_input_dimension-1] += biasOutputGate;
86 | }
87 |
88 | peepInputGate = new double[cell_blocks];
89 | peepForgetGate = new double[cell_blocks];
90 | peepOutputGate = new double[cell_blocks];
91 |
92 | for (int j = 0; j < cell_blocks; j++) {
93 | peepInputGate[j] = (r.nextDouble() * 2 - 1) * init_weight_range;
94 | peepForgetGate[j] = (r.nextDouble() * 2 - 1) * init_weight_range;
95 | peepOutputGate[j] = (r.nextDouble() * 2 - 1) * init_weight_range;
96 | }
97 |
98 | weightsGlobalOutput = new double[output_dimension][full_hidden_dimension];
99 |
100 | for (int j = 0; j < full_hidden_dimension; j++) {
101 | for (int k = 0; k < output_dimension; k++)
102 | weightsGlobalOutput[k][j] = (r.nextDouble() * 2 - 1) * init_weight_range;
103 | }
104 |
105 | }
106 |
107 | public void Reset() {
108 | //TODO: reset deltas here?
109 | for (int c = 0; c < CEC.length; c++)
110 | CEC[c] = 0.0;
111 | for (int c = 0; c < context.length; c++)
112 | context[c] = 0.0;
113 | //reset accumulated partials
114 | for (int c = 0; c < cell_blocks; c++) {
115 | for (int i = 0; i < full_input_dimension; i++) {
116 | this.dSdwWeightsForgetGate[c][i] = 0;
117 | this.dSdwWeightsInputGate[c][i] = 0;
118 | this.dSdwWeightsNetInput[c][i] = 0;
119 | }
120 | }
121 | }
122 |
123 | public double[] Next(double[] input) throws Exception {
124 | return Next(input, null);
125 | }
126 |
127 |
128 |
129 | public void Display() {
130 | System.out.println("==============================");
131 | System.out.println("LSTM: todo...");
132 | System.out.println("\n==============================");
133 | }
134 |
135 | public double[] GetParameters() {
136 | double[] params = new double[(full_input_dimension) * cell_blocks * 4 + 3 * cell_blocks + full_hidden_dimension * output_dimension];
137 | int loc = 0;
138 | for (int j = 0; j < cell_blocks; j++) {
139 | for (int i = 0; i < full_input_dimension; i++) {
140 | params[loc++] = weightsNetInput[j][i];
141 | params[loc++] = weightsInputGate[j][i];
142 | params[loc++] = weightsForgetGate[j][i];
143 | params[loc++] = weightsOutputGate[j][i];
144 | }
145 | params[loc++] = peepInputGate[j];
146 | params[loc++] = peepForgetGate[j];
147 | params[loc++] = peepOutputGate[j];
148 | }
149 |
150 | for (int j = 0; j < full_hidden_dimension; j++) {
151 | for (int k = 0; k < output_dimension; k++)
152 | params[loc++] = weightsGlobalOutput[k][j];
153 | }
154 | if (loc != params.length)
155 | System.out.println("ERROR in LSTM.GetParameters() " + loc + " vs " + params.length);
156 | return params;
157 | }
158 |
159 | public void SetParameters(double[] params) {
160 | int loc = 0;
161 | for (int j = 0; j < cell_blocks; j++) {
162 | for (int i = 0; i < full_input_dimension; i++) {
163 | weightsNetInput[j][i] = params[loc++];
164 | weightsInputGate[j][i] = params[loc++];
165 | weightsForgetGate[j][i] = params[loc++];
166 | weightsOutputGate[j][i] = params[loc++];
167 | }
168 | peepInputGate[j] = params[loc++];
169 | peepForgetGate[j] = params[loc++];
170 | peepOutputGate[j] = params[loc++];
171 | }
172 |
173 | for (int j = 0; j < full_hidden_dimension; j++) {
174 | for (int k = 0; k < output_dimension; k++)
175 | weightsGlobalOutput[k][j] = params[loc++];
176 | }
177 | if (loc != params.length)
178 | System.out.println("ERROR in LSTM.SetParameters() " + loc + " vs " + params.length);
179 | }
180 |
181 | public int GetHiddenDimension() {
182 | return cell_blocks;
183 | }
184 |
185 | public double[] Next(double[] input, double[] target_output) {
186 |
187 | //setup input vector
188 | double[] full_input = new double[full_input_dimension];
189 | int loc = 0;
190 | for (int i = 0; i < input.length; i++)
191 | full_input[loc++] = input[i];
192 | for (int c = 0; c < context.length; c++)
193 | full_input[loc++] = context[c];
194 | full_input[loc++] = 1.0; //bias
195 |
196 | //cell block arrays
197 | double[] NetInputSum = new double[cell_blocks];
198 | double[] InputGateSum = new double[cell_blocks];
199 | double[] ForgetGateSum = new double[cell_blocks];
200 | double[] OutputGateSum = new double[cell_blocks];
201 |
202 | double[] NetInputAct = new double[cell_blocks];
203 | double[] InputGateAct = new double[cell_blocks];
204 | double[] ForgetGateAct = new double[cell_blocks];
205 | double[] OutputGateAct = new double[cell_blocks];
206 |
207 | double[] CECSquashAct = new double[cell_blocks];
208 |
209 | double[] NetOutputAct = new double[cell_blocks];
210 |
211 | //inputs to cell blocks
212 | for (int i = 0; i < full_input_dimension; i++) {
213 | for (int j = 0; j < cell_blocks; j++) {
214 | NetInputSum[j] += weightsNetInput[j][i] * full_input[i];
215 | InputGateSum[j] += weightsInputGate[j][i] * full_input[i];
216 | ForgetGateSum[j] += weightsForgetGate[j][i] * full_input[i];
217 | OutputGateSum[j] += weightsOutputGate[j][i] * full_input[i];
218 | }
219 | }
220 |
221 | double[] CEC1 = new double[cell_blocks];
222 | double[] CEC2 = new double[cell_blocks];
223 | double[] CEC3 = new double[cell_blocks];
224 |
225 | //internals of cell blocks
226 | for (int j = 0; j < cell_blocks; j++) {
227 | CEC1[j] = CEC[j];
228 |
229 | NetInputAct[j] = neuronNetInput.Activate(NetInputSum[j]);
230 |
231 | ForgetGateSum[j] += peepForgetGate[j] * CEC1[j];
232 | ForgetGateAct[j] = neuronForgetGate.Activate(ForgetGateSum[j]);
233 |
234 | CEC2[j] = CEC1[j] * ForgetGateAct[j];
235 |
236 | InputGateSum[j] += peepInputGate[j] * CEC2[j];
237 | InputGateAct[j] = neuronInputGate.Activate(InputGateSum[j]);
238 |
239 | CEC3[j] = CEC2[j] + NetInputAct[j] * InputGateAct[j];
240 |
241 | OutputGateSum[j] += peepOutputGate[j] * CEC3[j]; //TODO: this versus squashed?
242 | OutputGateAct[j] = neuronOutputGate.Activate(OutputGateSum[j]);
243 |
244 | CECSquashAct[j] = neuronCECSquash.Activate(CEC3[j]);
245 |
246 | NetOutputAct[j] = CECSquashAct[j] * OutputGateAct[j];
247 | }
248 |
249 | //prepare hidden layer plus bias
250 | double [] full_hidden = new double[full_hidden_dimension];
251 | loc = 0;
252 | for (int j = 0; j < cell_blocks; j++)
253 | full_hidden[loc++] = NetOutputAct[j];
254 | full_hidden[loc++] = 1.0; //bias
255 |
256 | //calculate output
257 | double[] output = new double[output_dimension];
258 | for (int k = 0; k < output_dimension; k++) {
259 | for (int j = 0; j < full_hidden_dimension; j++)
260 | output[k] += weightsGlobalOutput[k][j] * full_hidden[j];
261 | //output not squashed
262 | }
263 |
264 | //////////////////////////////////////////////////////////////
265 | //////////////////////////////////////////////////////////////
266 | //BACKPROP
267 | //////////////////////////////////////////////////////////////
268 | //////////////////////////////////////////////////////////////
269 |
270 | //scale partials
271 | for (int c = 0; c < cell_blocks; c++) {
272 | for (int i = 0; i < full_input_dimension; i++) {
273 | this.dSdwWeightsInputGate[c][i] *= ForgetGateAct[c];
274 | this.dSdwWeightsForgetGate[c][i] *= ForgetGateAct[c];
275 | this.dSdwWeightsNetInput[c][i] *= ForgetGateAct[c];
276 |
277 | dSdwWeightsInputGate[c][i] += full_input[i] * neuronInputGate.Derivative(InputGateSum[c]) * NetInputAct[c];
278 | dSdwWeightsForgetGate[c][i] += full_input[i] * neuronForgetGate.Derivative(ForgetGateSum[c]) * CEC1[c];
279 | dSdwWeightsNetInput[c][i] += full_input[i] * neuronNetInput.Derivative(NetInputSum[c]) * InputGateAct[c];
280 | }
281 | }
282 |
283 | if (target_output != null) {
284 | double[] deltaGlobalOutputPre = new double[output_dimension];
285 | for (int k = 0; k < output_dimension; k++) {
286 | deltaGlobalOutputPre[k] = target_output[k] - output[k];
287 | }
288 |
289 | //output to hidden
290 | double[] deltaNetOutput = new double[cell_blocks];
291 | for (int k = 0; k < output_dimension; k++) {
292 | //links
293 | for (int c = 0; c < cell_blocks; c++) {
294 | deltaNetOutput[c] += deltaGlobalOutputPre[k] * weightsGlobalOutput[k][c];
295 | weightsGlobalOutput[k][c] += deltaGlobalOutputPre[k] * NetOutputAct[c] * learningRate;
296 | }
297 | //bias
298 | weightsGlobalOutput[k][cell_blocks] += deltaGlobalOutputPre[k] * 1.0 * learningRate;
299 | }
300 |
301 | for (int c = 0; c < cell_blocks; c++) {
302 |
303 | //update output gates
304 | double deltaOutputGatePost = deltaNetOutput[c] * CECSquashAct[c];
305 | double deltaOutputGatePre = neuronOutputGate.Derivative(OutputGateSum[c]) * deltaOutputGatePost;
306 | for (int i = 0; i < full_input_dimension; i++) {
307 | weightsOutputGate[c][i] += full_input[i] * deltaOutputGatePre * learningRate;
308 | }
309 | peepOutputGate[c] += CEC3[c] * deltaOutputGatePre * learningRate;
310 |
311 | //before outgate
312 | double deltaCEC3 = deltaNetOutput[c] * OutputGateAct[c] * neuronCECSquash.Derivative(CEC3[c]);
313 |
314 | //update input gates
315 | double deltaInputGatePost = deltaCEC3 * NetInputAct[c];
316 | double deltaInputGatePre = neuronInputGate.Derivative(InputGateSum[c]) * deltaInputGatePost;
317 | for (int i = 0; i < full_input_dimension; i++) {
318 | weightsInputGate[c][i] += dSdwWeightsInputGate[c][i] * deltaCEC3 * learningRate;
319 | }
320 | peepInputGate[c] += CEC2[c] * deltaInputGatePre * learningRate;
321 |
322 | //before ingate
323 | double deltaCEC2 = deltaCEC3;
324 |
325 | //update forget gates
326 | double deltaForgetGatePost = deltaCEC2 * CEC1[c];
327 | double deltaForgetGatePre = neuronForgetGate.Derivative(ForgetGateSum[c]) * deltaForgetGatePost;
328 | for (int i = 0; i < full_input_dimension; i++) {
329 | weightsForgetGate[c][i] += dSdwWeightsForgetGate[c][i] * deltaCEC2 * learningRate;
330 | }
331 | peepForgetGate[c] += CEC1[c] * deltaForgetGatePre * learningRate;
332 |
333 | //update cell inputs
334 | for (int i = 0; i < full_input_dimension; i++) {
335 | weightsNetInput[c][i] += dSdwWeightsNetInput[c][i] * deltaCEC3 * learningRate;
336 | }
337 | //no peeps for cell inputs
338 | }
339 | }
340 |
341 | //////////////////////////////////////////////////////////////
342 |
343 | //roll-over context to next time step
344 | for (int j = 0; j < cell_blocks; j++) {
345 | context[j] = NetOutputAct[j];
346 | CEC[j] = CEC3[j];
347 | }
348 |
349 | //give results
350 | return output;
351 | }
352 | }
353 |
354 |
355 |
--------------------------------------------------------------------------------
/src/com/evolvingstuff/Neuron.java:
--------------------------------------------------------------------------------
1 | package com.evolvingstuff;
2 |
3 | public abstract class Neuron
4 | {
5 | public static Neuron Factory(NeuronType neuron_type)
6 | {
7 | if (neuron_type == NeuronType.Sigmoid)
8 | return new SigmoidNeuron();
9 | else if (neuron_type == NeuronType.Identity)
10 | return new IdentityNeuron();
11 | else if (neuron_type == NeuronType.Tanh)
12 | return new TanhNeuron();
13 | else
14 | System.out.println("ERROR: unknown neuron type");
15 | return null;
16 | }
17 |
18 | abstract public double Activate(double x);
19 | abstract public double Derivative(double x);
20 | }
21 |
--------------------------------------------------------------------------------
/src/com/evolvingstuff/NeuronType.java:
--------------------------------------------------------------------------------
1 | package com.evolvingstuff;
2 |
/**
 * The kinds of activation function that {@code Neuron.Factory} can create.
 */
public enum NeuronType {
    /** Logistic sigmoid. */
    Sigmoid,

    /** Identity (no squashing). */
    Identity,

    /** Hyperbolic tangent. */
    Tanh
}
9 |
--------------------------------------------------------------------------------
/src/com/evolvingstuff/SigmoidNeuron.java:
--------------------------------------------------------------------------------
1 | package com.evolvingstuff;
2 |
3 | public class SigmoidNeuron extends Neuron
4 | {
5 | @Override
6 | public double Activate(double x) {
7 | return 1 / (1 + Math.exp(-x));
8 | }
9 |
10 | @Override
11 | public double Derivative(double x) {
12 | double act = Activate(x);
13 | return act * (1 - act);
14 | }
15 |
16 |
17 | }
18 |
--------------------------------------------------------------------------------
/src/com/evolvingstuff/TanhNeuron.java:
--------------------------------------------------------------------------------
1 | package com.evolvingstuff;
2 |
3 | public class TanhNeuron extends Neuron
4 | {
5 | @Override
6 | public double Activate(double x) {
7 | return Math.tanh(x);
8 | }
9 |
10 | @Override
11 | public double Derivative(double x) {
12 | double coshx = Math.cosh(x);
13 | double denom = (Math.cosh(2*x) + 1);
14 | return 4 * coshx * coshx / (denom * denom);
15 | }
16 |
17 |
18 | }
19 |
--------------------------------------------------------------------------------