├── mackey_glass_t17.npy ├── readme.md ├── license.md ├── testing.py ├── pyESN.py └── mackey.ipynb /mackey_glass_t17.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cknd/pyESN/HEAD/mackey_glass_t17.npy -------------------------------------------------------------------------------- /readme.md: -------------------------------------------------------------------------------- 1 | # Echo State Networks in Python 2 | 3 | [Echo State Networks](http://www.scholarpedia.org/article/Echo_state_network) are easy-to-train recurrent neural networks, a variant of [Reservoir Computing](https://en.wikipedia.org/wiki/Reservoir_computing). In some sense, these networks show how far you can get with nothing but a good weight initialisation. 4 | 5 | This ESN implementation is relatively simple and self-contained, though it offers tricks like noise injection and teacher forcing (feedback connections), plus a zoo of dubious little hyperparameters. 6 | 7 | However! If your aims are practical and your gradients automatic, consider using a fully trained network. 8 | 9 | # Examples 10 | 11 | - [learning to be a tunable frequency generator](http://nbviewer.ipython.org/github/cknd/pyESN/blob/master/freqgen.ipynb) 12 | - [learning a Mackey-Glass system](http://nbviewer.ipython.org/github/cknd/pyESN/blob/master/mackey.ipynb) -------------------------------------------------------------------------------- /license.md: -------------------------------------------------------------------------------- 1 | Copyright (c) 2015 Clemens Korndörfer 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy 4 | of this software and associated documentation files (the "Software"), to deal 5 | in the Software without restriction, including without limitation the rights 6 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 7 | copies of the Software, and to permit persons to whom the Software is 8 | furnished to do so, subject to the following conditions: 9 | 10 | The above copyright notice and this permission notice shall be included in all 11 | copies or substantial portions of the Software. 12 | 13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 16 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 18 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 19 | SOFTWARE. -------------------------------------------------------------------------------- /testing.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import numpy as np 3 | 4 | from pyESN import ESN 5 | 6 | N_in, N, N_out = 5, 75, 3 7 | 8 | 9 | def random_task(): 10 | X = np.random.randn(100, N_in) 11 | y = np.random.randn(100, N_out) 12 | Xp = np.random.randn(50, N_in) 13 | return X, y, Xp 14 | 15 | 16 | class RandomStateHandling(unittest.TestCase): 17 | 18 | def setUp(self): 19 | self.task = random_task() 20 | 21 | def _compare(self, esnA, esnB, should_be): 22 | """helper function to see if two esns are the same""" 23 | X, y, Xp = self.task 24 | test = self.assertTrue if should_be == "same" else self.assertFalse 25 | test(np.all(np.equal(esnA.W, esnB.W))) 26 | test(np.all(np.equal(esnA.W_in, esnB.W_in))) 27 | test(np.all(np.equal(esnA.W_feedb, esnB.W_feedb))) 28 | test(np.all(np.equal(esnA.fit(X, y), esnB.fit(X, y)))) 29 | test(np.all(np.equal(esnA.W_out, esnB.W_out))) 30 | test(np.all(np.equal(esnA.predict(Xp), esnB.predict(Xp)))) 31 | 32 | def test_integer(self): 33 | """two esns with the same seed should be the same""" 34 | esnA = ESN(N_in, N_out, random_state=1) 35 | esnB = ESN(N_in, N_out, random_state=1) 36 | self._compare(esnA, esnB, should_be="same") 37 | 38 | def test_randomstate_object(self): 39 | """two esns with the same randomstate objects should be the same""" 40 | rstA = np.random.RandomState(1) 41 | esnA = ESN(N_in, N_out, random_state=rstA) 42 | rstB = np.random.RandomState(1) 43 | esnB = ESN(N_in, N_out, random_state=rstB) 44 | self._compare(esnA, esnB, should_be="same") 45 | 46 | def test_none(self): 47 | """two esns with no specified seed should be different""" 48 | esnA = ESN(N_in, N_out, random_state=None) 49 | esnB = ESN(N_in, N_out, random_state=None) 50 | self._compare(esnA, esnB, should_be="different") 51 | 52 | def test_nonsense(self): 53 | """parameter random_state should only accept positive integers""" 54 | with self.assertRaises(ValueError): 55 | ESN(N_in, N_out, random_state=-1) 56 | 57 | with self.assertRaises(Exception) as cm: 58 | ESN(N_in, N_out, random_state=0.5) 59 | self.assertIn("Invalid seed", str(cm.exception)) 60 | 61 | def test_serialisation(self): 62 | import pickle 63 | import io 64 | esn = ESN(N_in, N_out, random_state=1) 65 | with io.BytesIO() as buf: 66 | pickle.dump(esn, buf) 67 | buf.flush() 68 | buf.seek(0) 69 | esn_unpickled = pickle.load(buf) 70 | self._compare(esn, esn_unpickled, should_be='same') 71 | 72 | 73 | class InitArguments(unittest.TestCase): 74 | 75 | def setUp(self): 76 | self.X, self.y, self.Xp = random_task() 77 | 78 | def test_inputscaling(self): 79 | """input scaling factors of different formats should be correctly intereted or rejected""" 80 | esn = ESN(N_in, N_out, input_scaling=2) 81 | self.assertTrue(np.all(2 * self.X == esn._scale_inputs(self.X))) 82 | esn.fit(self.X, self.y) 83 | esn.predict(self.Xp) 84 | 85 | esn = ESN(N_in, N_out, input_scaling=[2] * N_in) 86 | self.assertTrue(np.all(2 * self.X == esn._scale_inputs(self.X))) 87 | esn.fit(self.X, self.y) 88 | esn.predict(self.Xp) 89 | 90 | esn = ESN(N_in, N_out, input_scaling=np.array([2] * N_in)) 91 | self.assertTrue(np.all(2 * self.X == esn._scale_inputs(self.X))) 92 | esn.fit(self.X, self.y) 93 | esn.predict(self.Xp) 94 | 95 | with self.assertRaises(ValueError): 96 | esn = ESN(N_in, N_out, input_scaling=[2] * (N_in + 1)) 97 | 98 | with self.assertRaises(ValueError): 99 | esn = ESN(N_in, N_out, input_scaling=np.array([[2] * N_in])) 100 | 101 | def test_inputshift(self): 102 | """input shift factors of different formats should be correctly interpreted or rejected""" 103 | esn = ESN(N_in, N_out, input_shift=1) 104 | self.assertTrue(np.all(1 + self.X == esn._scale_inputs(self.X))) 105 | esn.fit(self.X, self.y) 106 | esn.predict(self.Xp) 107 | 108 | esn = ESN(N_in, N_out, input_shift=[1] * N_in) 109 | self.assertTrue(np.all(1 + self.X == esn._scale_inputs(self.X))) 110 | esn.fit(self.X, self.y) 111 | esn.predict(self.Xp) 112 | 113 | esn = ESN(N_in, N_out, input_shift=np.array([1] * N_in)) 114 | self.assertTrue(np.all(1 + self.X == esn._scale_inputs(self.X))) 115 | esn.fit(self.X, self.y) 116 | esn.predict(self.Xp) 117 | 118 | with self.assertRaises(ValueError): 119 | esn = ESN(N_in, N_out, input_shift=[1] * (N_in + 1)) 120 | 121 | with self.assertRaises(ValueError): 122 | esn = ESN(N_in, N_out, input_shift=np.array([[1] * N_in])) 123 | 124 | def test_IODimensions(self): 125 | """try different combinations of input & output dimensionalities & teacher forcing""" 126 | tasks = [(1, 1, 100, True), (10, 1, 100, True), (1, 10, 100, True), (10, 10, 100, True), 127 | (1, 1, 100, False), (10, 1, 100, False), (1, 10, 100, False), (10, 10, 100, False)] 128 | for t in tasks: 129 | N_in, N_out, N_samples, tf = t 130 | X = np.random.randn( 131 | N_samples, N_in) if N_in > 1 else np.random.randn(N_samples) 132 | y = np.random.randn( 133 | N_samples, N_out) if N_out > 1 else np.random.randn(N_samples) 134 | Xp = np.random.randn( 135 | N_samples, N_in) if N_in > 1 else np.random.randn(N_samples) 136 | esn = ESN(N_in, N_out, teacher_forcing=tf) 137 | prediction_tr = esn.fit(X, y) 138 | prediction_t = esn.predict(Xp) 139 | self.assertEqual(prediction_tr.shape, (N_samples, N_out)) 140 | self.assertEqual(prediction_t.shape, (N_samples, N_out)) 141 | 142 | 143 | class Performance(unittest.TestCase): 144 | # Slighty bending the concept of a unit test, I want to catch performance changes during refactoring. 145 | # Ideally, this will expand to a collection of known tasks. 146 | 147 | def test_mackey(self): 148 | try: 149 | data = np.load('mackey_glass_t17.npy') 150 | except IOError: 151 | self.skipTest("missing data") 152 | 153 | esn = ESN(n_inputs=1, 154 | n_outputs=1, 155 | n_reservoir=500, 156 | spectral_radius=1.5, 157 | random_state=42) 158 | 159 | trainlen = 2000 160 | future = 2000 161 | esn.fit(np.ones(trainlen), data[:trainlen]) 162 | prediction = esn.predict(np.ones(future)) 163 | error = np.sqrt( 164 | np.mean((prediction.flatten() - data[trainlen:trainlen + future])**2)) 165 | self.assertAlmostEqual(error, 0.1396039098653574) 166 | 167 | def test_freqgen(self): 168 | rng = np.random.RandomState(42) 169 | 170 | def frequency_generator(N, min_period, max_period, n_changepoints): 171 | """returns a random step function + a sine wave signal that 172 | changes its frequency at each such step.""" 173 | # vector of random indices < N, padded with 0 and N at the ends: 174 | changepoints = np.insert(np.sort(rng.randint(0, N, n_changepoints)), [ 175 | 0, n_changepoints], [0, N]) 176 | # list of interval boundaries between which the control sequence 177 | # should be constant: 178 | const_intervals = list( 179 | zip(changepoints, np.roll(changepoints, -1)))[:-1] 180 | # populate a control sequence 181 | frequency_control = np.zeros((N, 1)) 182 | for (t0, t1) in const_intervals: 183 | frequency_control[t0:t1] = rng.rand() 184 | periods = frequency_control * \ 185 | (max_period - min_period) + max_period 186 | 187 | # run time through a sine, while changing the period length 188 | frequency_output = np.zeros((N, 1)) 189 | z = 0 190 | for i in range(N): 191 | z = z + 2 * np.pi / periods[i] 192 | frequency_output[i] = (np.sin(z) + 1) / 2 193 | return np.hstack([np.ones((N, 1)), 1 - frequency_control]), frequency_output 194 | 195 | N = 15000 196 | min_period = 2 197 | max_period = 10 198 | n_changepoints = int(N / 200) 199 | frequency_control, frequency_output = frequency_generator( 200 | N, min_period, max_period, n_changepoints) 201 | 202 | traintest_cutoff = int(np.ceil(0.7 * N)) 203 | train_ctrl, train_output = frequency_control[ 204 | :traintest_cutoff], frequency_output[:traintest_cutoff] 205 | test_ctrl, test_output = frequency_control[ 206 | traintest_cutoff:], frequency_output[traintest_cutoff:] 207 | 208 | esn = ESN(n_inputs=2, 209 | n_outputs=1, 210 | n_reservoir=200, 211 | spectral_radius=0.25, 212 | sparsity=0.95, 213 | noise=0.001, 214 | input_shift=[0, 0], 215 | input_scaling=[0.01, 3], 216 | teacher_scaling=1.12, 217 | teacher_shift=-0.7, 218 | out_activation=np.tanh, 219 | inverse_out_activation=np.arctanh, 220 | random_state=rng, 221 | silent=True) 222 | 223 | pred_train = esn.fit(train_ctrl, train_output) 224 | # print "test error:" 225 | pred_test = esn.predict(test_ctrl) 226 | error = np.sqrt(np.mean((pred_test - test_output)**2)) 227 | self.assertAlmostEqual(error, 0.30519018985725715) 228 | 229 | 230 | if __name__ == '__main__': 231 | unittest.main() 232 | -------------------------------------------------------------------------------- /pyESN.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def correct_dimensions(s, targetlength): 5 | """checks the dimensionality of some numeric argument s, broadcasts it 6 | to the specified length if possible. 7 | 8 | Args: 9 | s: None, scalar or 1D array 10 | targetlength: expected length of s 11 | 12 | Returns: 13 | None if s is None, else numpy vector of length targetlength 14 | """ 15 | if s is not None: 16 | s = np.array(s) 17 | if s.ndim == 0: 18 | s = np.array([s] * targetlength) 19 | elif s.ndim == 1: 20 | if not len(s) == targetlength: 21 | raise ValueError("arg must have length " + str(targetlength)) 22 | else: 23 | raise ValueError("Invalid argument") 24 | return s 25 | 26 | 27 | def identity(x): 28 | return x 29 | 30 | 31 | class ESN(): 32 | 33 | def __init__(self, n_inputs, n_outputs, n_reservoir=200, 34 | spectral_radius=0.95, sparsity=0, noise=0.001, input_shift=None, 35 | input_scaling=None, teacher_forcing=True, feedback_scaling=None, 36 | teacher_scaling=None, teacher_shift=None, 37 | out_activation=identity, inverse_out_activation=identity, 38 | random_state=None, silent=True): 39 | """ 40 | Args: 41 | n_inputs: nr of input dimensions 42 | n_outputs: nr of output dimensions 43 | n_reservoir: nr of reservoir neurons 44 | spectral_radius: spectral radius of the recurrent weight matrix 45 | sparsity: proportion of recurrent weights set to zero 46 | noise: noise added to each neuron (regularization) 47 | input_shift: scalar or vector of length n_inputs to add to each 48 | input dimension before feeding it to the network. 49 | input_scaling: scalar or vector of length n_inputs to multiply 50 | with each input dimension before feeding it to the netw. 51 | teacher_forcing: if True, feed the target back into output units 52 | teacher_scaling: factor applied to the target signal 53 | teacher_shift: additive term applied to the target signal 54 | out_activation: output activation function (applied to the readout) 55 | inverse_out_activation: inverse of the output activation function 56 | random_state: positive integer seed, np.rand.RandomState object, 57 | or None to use numpy's builting RandomState. 58 | silent: supress messages 59 | """ 60 | # check for proper dimensionality of all arguments and write them down. 61 | self.n_inputs = n_inputs 62 | self.n_reservoir = n_reservoir 63 | self.n_outputs = n_outputs 64 | self.spectral_radius = spectral_radius 65 | self.sparsity = sparsity 66 | self.noise = noise 67 | self.input_shift = correct_dimensions(input_shift, n_inputs) 68 | self.input_scaling = correct_dimensions(input_scaling, n_inputs) 69 | 70 | self.teacher_scaling = teacher_scaling 71 | self.teacher_shift = teacher_shift 72 | 73 | self.out_activation = out_activation 74 | self.inverse_out_activation = inverse_out_activation 75 | self.random_state = random_state 76 | 77 | # the given random_state might be either an actual RandomState object, 78 | # a seed or None (in which case we use numpy's builtin RandomState) 79 | if isinstance(random_state, np.random.RandomState): 80 | self.random_state_ = random_state 81 | elif random_state: 82 | try: 83 | self.random_state_ = np.random.RandomState(random_state) 84 | except TypeError as e: 85 | raise Exception("Invalid seed: " + str(e)) 86 | else: 87 | self.random_state_ = np.random.mtrand._rand 88 | 89 | self.teacher_forcing = teacher_forcing 90 | self.silent = silent 91 | self.initweights() 92 | 93 | def initweights(self): 94 | # initialize recurrent weights: 95 | # begin with a random matrix centered around zero: 96 | W = self.random_state_.rand(self.n_reservoir, self.n_reservoir) - 0.5 97 | # delete the fraction of connections given by (self.sparsity): 98 | W[self.random_state_.rand(*W.shape) < self.sparsity] = 0 99 | # compute the spectral radius of these weights: 100 | radius = np.max(np.abs(np.linalg.eigvals(W))) 101 | # rescale them to reach the requested spectral radius: 102 | self.W = W * (self.spectral_radius / radius) 103 | 104 | # random input weights: 105 | self.W_in = self.random_state_.rand( 106 | self.n_reservoir, self.n_inputs) * 2 - 1 107 | # random feedback (teacher forcing) weights: 108 | self.W_feedb = self.random_state_.rand( 109 | self.n_reservoir, self.n_outputs) * 2 - 1 110 | 111 | def _update(self, state, input_pattern, output_pattern): 112 | """performs one update step. 113 | 114 | i.e., computes the next network state by applying the recurrent weights 115 | to the last state & and feeding in the current input and output patterns 116 | """ 117 | if self.teacher_forcing: 118 | preactivation = (np.dot(self.W, state) 119 | + np.dot(self.W_in, input_pattern) 120 | + np.dot(self.W_feedb, output_pattern)) 121 | else: 122 | preactivation = (np.dot(self.W, state) 123 | + np.dot(self.W_in, input_pattern)) 124 | return (np.tanh(preactivation) 125 | + self.noise * (self.random_state_.rand(self.n_reservoir) - 0.5)) 126 | 127 | def _scale_inputs(self, inputs): 128 | """for each input dimension j: multiplies by the j'th entry in the 129 | input_scaling argument, then adds the j'th entry of the input_shift 130 | argument.""" 131 | if self.input_scaling is not None: 132 | inputs = np.dot(inputs, np.diag(self.input_scaling)) 133 | if self.input_shift is not None: 134 | inputs = inputs + self.input_shift 135 | return inputs 136 | 137 | def _scale_teacher(self, teacher): 138 | """multiplies the teacher/target signal by the teacher_scaling argument, 139 | then adds the teacher_shift argument to it.""" 140 | if self.teacher_scaling is not None: 141 | teacher = teacher * self.teacher_scaling 142 | if self.teacher_shift is not None: 143 | teacher = teacher + self.teacher_shift 144 | return teacher 145 | 146 | def _unscale_teacher(self, teacher_scaled): 147 | """inverse operation of the _scale_teacher method.""" 148 | if self.teacher_shift is not None: 149 | teacher_scaled = teacher_scaled - self.teacher_shift 150 | if self.teacher_scaling is not None: 151 | teacher_scaled = teacher_scaled / self.teacher_scaling 152 | return teacher_scaled 153 | 154 | def fit(self, inputs, outputs, inspect=False): 155 | """ 156 | Collect the network's reaction to training data, train readout weights. 157 | 158 | Args: 159 | inputs: array of dimensions (N_training_samples x n_inputs) 160 | outputs: array of dimension (N_training_samples x n_outputs) 161 | inspect: show a visualisation of the collected reservoir states 162 | 163 | Returns: 164 | the network's output on the training data, using the trained weights 165 | """ 166 | # transform any vectors of shape (x,) into vectors of shape (x,1): 167 | if inputs.ndim < 2: 168 | inputs = np.reshape(inputs, (len(inputs), -1)) 169 | if outputs.ndim < 2: 170 | outputs = np.reshape(outputs, (len(outputs), -1)) 171 | # transform input and teacher signal: 172 | inputs_scaled = self._scale_inputs(inputs) 173 | teachers_scaled = self._scale_teacher(outputs) 174 | 175 | if not self.silent: 176 | print("harvesting states...") 177 | # step the reservoir through the given input,output pairs: 178 | states = np.zeros((inputs.shape[0], self.n_reservoir)) 179 | for n in range(1, inputs.shape[0]): 180 | states[n, :] = self._update(states[n - 1], inputs_scaled[n, :], 181 | teachers_scaled[n - 1, :]) 182 | 183 | # learn the weights, i.e. find the linear combination of collected 184 | # network states that is closest to the target output 185 | if not self.silent: 186 | print("fitting...") 187 | # we'll disregard the first few states: 188 | transient = min(int(inputs.shape[1] / 10), 100) 189 | # include the raw inputs: 190 | extended_states = np.hstack((states, inputs_scaled)) 191 | # Solve for W_out: 192 | self.W_out = np.dot(np.linalg.pinv(extended_states[transient:, :]), 193 | self.inverse_out_activation(teachers_scaled[transient:, :])).T 194 | 195 | # remember the last state for later: 196 | self.laststate = states[-1, :] 197 | self.lastinput = inputs[-1, :] 198 | self.lastoutput = teachers_scaled[-1, :] 199 | 200 | # optionally visualize the collected states 201 | if inspect: 202 | from matplotlib import pyplot as plt 203 | # (^-- we depend on matplotlib only if this option is used) 204 | plt.figure( 205 | figsize=(states.shape[0] * 0.0025, states.shape[1] * 0.01)) 206 | plt.imshow(extended_states.T, aspect='auto', 207 | interpolation='nearest') 208 | plt.colorbar() 209 | 210 | if not self.silent: 211 | print("training error:") 212 | # apply learned weights to the collected states: 213 | pred_train = self._unscale_teacher(self.out_activation( 214 | np.dot(extended_states, self.W_out.T))) 215 | if not self.silent: 216 | print(np.sqrt(np.mean((pred_train - outputs)**2))) 217 | return pred_train 218 | 219 | def predict(self, inputs, continuation=True): 220 | """ 221 | Apply the learned weights to the network's reactions to new input. 222 | 223 | Args: 224 | inputs: array of dimensions (N_test_samples x n_inputs) 225 | continuation: if True, start the network from the last training state 226 | 227 | Returns: 228 | Array of output activations 229 | """ 230 | if inputs.ndim < 2: 231 | inputs = np.reshape(inputs, (len(inputs), -1)) 232 | n_samples = inputs.shape[0] 233 | 234 | if continuation: 235 | laststate = self.laststate 236 | lastinput = self.lastinput 237 | lastoutput = self.lastoutput 238 | else: 239 | laststate = np.zeros(self.n_reservoir) 240 | lastinput = np.zeros(self.n_inputs) 241 | lastoutput = np.zeros(self.n_outputs) 242 | 243 | inputs = np.vstack([lastinput, self._scale_inputs(inputs)]) 244 | states = np.vstack( 245 | [laststate, np.zeros((n_samples, self.n_reservoir))]) 246 | outputs = np.vstack( 247 | [lastoutput, np.zeros((n_samples, self.n_outputs))]) 248 | 249 | for n in range(n_samples): 250 | states[ 251 | n + 1, :] = self._update(states[n, :], inputs[n + 1, :], outputs[n, :]) 252 | outputs[n + 1, :] = self.out_activation(np.dot(self.W_out, 253 | np.concatenate([states[n + 1, :], inputs[n + 1, :]]))) 254 | 255 | return self._unscale_teacher(self.out_activation(outputs[1:])) 256 | -------------------------------------------------------------------------------- /mackey.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "learning a [Mackey-Glass](http://www.scholarpedia.org/article/Mackey-Glass_equation) system" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": 1, 13 | "metadata": { 14 | "collapsed": false 15 | }, 16 | "outputs": [ 17 | { 18 | "name": "stdout", 19 | "output_type": "stream", 20 | "text": [ 21 | "test error: \n", 22 | "0.139603909616\n" 23 | ] 24 | }, 25 | { 26 | "data": { 27 | "text/plain": [ 28 | "" 29 | ] 30 | }, 31 | "execution_count": 1, 32 | "metadata": {}, 33 | "output_type": "execute_result" 34 | }, 35 | { 36 | "data": { 37 | "image/svg+xml": [ 38 | "\n", 39 | "\n", 41 | "\n", 42 | "\n", 43 | " \n", 44 | " \n", 47 | " \n", 48 | " \n", 49 | " \n", 50 | " \n", 57 | " \n", 58 | " \n", 59 | " \n", 60 | " \n", 66 | " \n", 67 | " \n", 68 | " \n", 1603 | " \n", 1604 | " \n", 1605 | " \n", 2368 | " \n", 2369 | " \n", 2370 | " \n", 2373 | " \n", 2374 | " \n", 2375 | " \n", 2378 | " \n", 2379 | " \n", 2380 | " \n", 2383 | " \n", 2384 | " \n", 2385 | " \n", 2388 | " \n", 2389 | " \n", 2390 | " \n", 2393 | " \n", 2394 | " \n", 2395 | " \n", 2396 | " \n", 2397 | " \n", 2398 | " \n", 2401 | " \n", 2402 | " \n", 2403 | " \n", 2404 | " \n", 2405 | " \n", 2406 | " \n", 2407 | " \n", 2408 | " \n", 2411 | " \n", 2412 | " \n", 2413 | " \n", 2414 | " \n", 2415 | " \n", 2416 | " \n", 2417 | " \n", 2418 | " \n", 2419 | " \n", 2438 | " \n", 2439 | " \n", 2440 | " \n", 2441 | " \n", 2442 | " \n", 2443 | " \n", 2444 | " \n", 2445 | " \n", 2446 | " \n", 2447 | " \n", 2448 | " \n", 2449 | " \n", 2450 | " \n", 2451 | " \n", 2452 | " \n", 2453 | " \n", 2454 | " \n", 2455 | " \n", 2456 | " \n", 2457 | " \n", 2458 | " \n", 2482 | " \n", 2483 | " \n", 2484 | " \n", 2485 | " \n", 2486 | " \n", 2487 | " \n", 2488 | " \n", 2489 | " \n", 2490 | " \n", 2491 | " \n", 2492 | " \n", 2493 | " \n", 2494 | " \n", 2495 | " \n", 2496 | " \n", 2497 | " \n", 2498 | " \n", 2499 | " \n", 2500 | " \n", 2501 | " \n", 2502 | " \n", 2503 | " \n", 2504 | " \n", 2517 | " \n", 2518 | " \n", 2519 | " \n", 2520 | " \n", 2521 | " \n", 2522 | " \n", 2523 | " \n", 2524 | " \n", 2525 | " \n", 2526 | " \n", 2527 | " \n", 2528 | " \n", 2529 | " \n", 2530 | " \n", 2531 | " \n", 2532 | " \n", 2533 | " \n", 2534 | " \n", 2535 | " \n", 2536 | " \n", 2537 | " \n", 2538 | " \n", 2539 | " \n", 2540 | " \n", 2541 | " \n", 2542 | " \n", 2543 | " \n", 2544 | " \n", 2545 | " \n", 2546 | " \n", 2547 | " \n", 2548 | " \n", 2549 | " \n", 2550 | " \n", 2551 | " \n", 2552 | " \n", 2553 | " \n", 2554 | " \n", 2555 | " \n", 2556 | " \n", 2557 | " \n", 2558 | " \n", 2559 | " \n", 2560 | " \n", 2561 | " \n", 2584 | " \n", 2585 | " \n", 2586 | " \n", 2587 | " \n", 2588 | " \n", 2589 | " \n", 2590 | " \n", 2591 | " \n", 2592 | " \n", 2593 | " \n", 2594 | " \n", 2595 | " \n", 2596 | " \n", 2597 | " \n", 2598 | " \n", 2599 | " \n", 2600 | " \n", 2601 | " \n", 2602 | " \n", 2603 | " \n", 2604 | " \n", 2605 | " \n", 2606 | " \n", 2607 | " \n", 2608 | " \n", 2609 | " \n", 2610 | " \n", 2611 | " \n", 2612 | " \n", 2613 | " \n", 2614 | " \n", 2615 | " \n", 2616 | " \n", 2617 | " \n", 2618 | " \n", 2619 | " \n", 2620 | " \n", 2621 | " \n", 2622 | " \n", 2623 | " \n", 2624 | " \n", 2625 | " \n", 2626 | " \n", 2627 | " \n", 2628 | " \n", 2659 | " \n", 2660 | " \n", 2661 | " \n", 2662 | " \n", 2663 | " \n", 2664 | " \n", 2665 | " \n", 2666 | " \n", 2667 | " \n", 2668 | " \n", 2669 | " \n", 2670 | " \n", 2671 | " \n", 2672 | " \n", 2673 | " \n", 2674 | " \n", 2675 | " \n", 2676 | " \n", 2677 | " \n", 2678 | " \n", 2679 | " \n", 2680 | " \n", 2681 | " \n", 2682 | " \n", 2683 | " \n", 2684 | " \n", 2685 | " \n", 2686 | " \n", 2687 | " \n", 2688 | " \n", 2689 | " \n", 2690 | " \n", 2691 | " \n", 2692 | " \n", 2693 | " \n", 2694 | " \n", 2695 | " \n", 2696 | " \n", 2697 | " \n", 2698 | " \n", 2699 | " \n", 2700 | " \n", 2701 | " \n", 2702 | " \n", 2703 | " \n", 2720 | " \n", 2721 | " \n", 2722 | " \n", 2723 | " \n", 2724 | " \n", 2725 | " \n", 2726 | " \n", 2727 | " \n", 2728 | " \n", 2729 | " \n", 2730 | " \n", 2731 | " \n", 2732 | " \n", 2733 | " \n", 2734 | " \n", 2737 | " \n", 2738 | " \n", 2739 | " \n", 2740 | " \n", 2741 | " \n", 2742 | " \n", 2743 | " \n", 2744 | " \n", 2747 | " \n", 2748 | " \n", 2749 | " \n", 2750 | " \n", 2751 | " \n", 2752 | " \n", 2753 | " \n", 2754 | " \n", 2755 | " \n", 2761 | " \n", 2789 | " \n", 2795 | " \n", 2796 | " \n", 2797 | " \n", 2798 | " \n", 2799 | " \n", 2800 | " \n", 2801 | " \n", 2802 | " \n", 2803 | " \n", 2804 | " \n", 2805 | " \n", 2806 | " \n", 2807 | " \n", 2808 | " \n", 2809 | " \n", 2810 | " \n", 2811 | " \n", 2812 | " \n", 2813 | " \n", 2814 | " \n", 2815 | " \n", 2816 | " \n", 2817 | " \n", 2818 | " \n", 2819 | " \n", 2820 | " \n", 2821 | " \n", 2822 | " \n", 2823 | " \n", 2824 | " \n", 2825 | " \n", 2826 | " \n", 2827 | " \n", 2828 | " \n", 2829 | " \n", 2830 | " \n", 2831 | " \n", 2832 | " \n", 2833 | " \n", 2834 | " \n", 2835 | " \n", 2836 | " \n", 2837 | " \n", 2838 | " \n", 2839 | " \n", 2840 | " \n", 2841 | " \n", 2842 | " \n", 2843 | " \n", 2844 | " \n", 2845 | " \n", 2846 | " \n", 2847 | " \n", 2848 | " \n", 2849 | " \n", 2850 | " \n", 2851 | " \n", 2852 | " \n", 2853 | " \n", 2854 | " \n", 2855 | " \n", 2856 | " \n", 2857 | " \n", 2858 | " \n", 2859 | " \n", 2860 | " \n", 2861 | " \n", 2862 | " \n", 2863 | " \n", 2864 | " \n", 2865 | " \n", 2866 | " \n", 2867 | " \n", 2868 | " \n", 2869 | " \n", 2870 | " \n", 2871 | " \n", 2872 | " \n", 2873 | " \n", 2874 | " \n", 2875 | " \n", 2876 | " \n", 2877 | " \n", 2878 | " \n", 2879 | " \n", 2880 | " \n", 2881 | " \n", 2882 | " \n", 2883 | " \n", 2884 | " \n", 2885 | " \n", 2886 | " \n", 2887 | " \n", 2888 | " \n", 2889 | " \n", 2890 | " \n", 2891 | " \n", 2892 | " \n", 2893 | " \n", 2894 | " \n", 2895 | " \n", 2896 | " \n", 2897 | " \n", 2898 | " \n", 2899 | " \n", 2905 | " \n", 2906 | " \n", 2907 | " \n", 2910 | " \n", 2911 | " \n", 2912 | " \n", 2913 | " \n", 2914 | " \n", 2915 | " \n", 2945 | " \n", 2976 | " \n", 2977 | " \n", 2993 | " \n", 3022 | " \n", 3054 | " \n", 3074 | " \n", 3090 | " \n", 3113 | " \n", 3114 | " \n", 3115 | " \n", 3116 | " \n", 3117 | " \n", 3118 | " \n", 3119 | " \n", 3120 | " \n", 3121 | " \n", 3122 | " \n", 3123 | " \n", 3124 | " \n", 3125 | " \n", 3126 | " \n", 3127 | " \n", 3128 | " \n", 3129 | " \n", 3130 | " \n", 3131 | " \n", 3134 | " \n", 3135 | " \n", 3136 | " \n", 3137 | " \n", 3138 | " \n", 3139 | " \n", 3151 | " \n", 3171 | " \n", 3182 | " \n", 3212 | " \n", 3226 | " \n", 3244 | " \n", 3262 | " \n", 3263 | " \n", 3264 | " \n", 3265 | " \n", 3266 | " \n", 3267 | " \n", 3268 | " \n", 3269 | " \n", 3270 | " \n", 3271 | " \n", 3272 | " \n", 3273 | " \n", 3274 | " \n", 3275 | " \n", 3276 | " \n", 3277 | " \n", 3278 | " \n", 3279 | " \n", 3280 | " \n", 3281 | " \n", 3282 | " \n", 3283 | " \n", 3284 | " \n", 3285 | " \n", 3286 | " \n", 3287 | " \n", 3288 | " \n", 3289 | " \n", 3290 | "\n" 3291 | ], 3292 | "text/plain": [ 3293 | "" 3294 | ] 3295 | }, 3296 | "metadata": {}, 3297 | "output_type": "display_data" 3298 | } 3299 | ], 3300 | "source": [ 3301 | "import numpy as np\n", 3302 | "from pyESN import ESN\n", 3303 | "from matplotlib import pyplot as plt\n", 3304 | "%matplotlib inline\n", 3305 | "\n", 3306 | "data = np.load('mackey_glass_t17.npy') # http://minds.jacobs-university.de/mantas/code\n", 3307 | "esn = ESN(n_inputs = 1,\n", 3308 | " n_outputs = 1,\n", 3309 | " n_reservoir = 500,\n", 3310 | " spectral_radius = 1.5,\n", 3311 | " random_state=42)\n", 3312 | "\n", 3313 | "trainlen = 2000\n", 3314 | "future = 2000\n", 3315 | "pred_training = esn.fit(np.ones(trainlen),data[:trainlen])\n", 3316 | "\n", 3317 | "prediction = esn.predict(np.ones(future))\n", 3318 | "print(\"test error: \\n\"+str(np.sqrt(np.mean((prediction.flatten() - data[trainlen:trainlen+future])**2))))\n", 3319 | "\n", 3320 | "plt.figure(figsize=(11,1.5))\n", 3321 | "plt.plot(range(0,trainlen+future),data[0:trainlen+future],'k',label=\"target system\")\n", 3322 | "plt.plot(range(trainlen,trainlen+future),prediction,'r', label=\"free running ESN\")\n", 3323 | "lo,hi = plt.ylim()\n", 3324 | "plt.plot([trainlen,trainlen],[lo+np.spacing(1),hi-np.spacing(1)],'k:')\n", 3325 | "plt.legend(loc=(0.61,1.1),fontsize='x-small')" 3326 | ] 3327 | } 3328 | ], 3329 | "metadata": { 3330 | "kernelspec": { 3331 | "display_name": "Python 3", 3332 | "language": "python", 3333 | "name": "python3" 3334 | }, 3335 | "language_info": { 3336 | "codemirror_mode": { 3337 | "name": "ipython", 3338 | "version": 3 3339 | }, 3340 | "file_extension": ".py", 3341 | "mimetype": "text/x-python", 3342 | "name": "python", 3343 | "nbconvert_exporter": "python", 3344 | "pygments_lexer": "ipython3", 3345 | "version": "3.4.3" 3346 | } 3347 | }, 3348 | "nbformat": 4, 3349 | "nbformat_minor": 0 3350 | } 3351 | --------------------------------------------------------------------------------