├── LICENSE
├── README.md
├── examples
│   ├── datasets
│   │   └── mg17.csv
│   ├── mackey-glass.py
│   └── mnist.py
├── setup.py
└── torchesn
    ├── __init__.py
    ├── nn
    │   ├── __init__.py
    │   ├── echo_state_network.py
    │   └── reservoir.py
    └── utils
        ├── __init__.py
        └── utilities.py

/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2018 Stefano Nardo

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# PyTorch-ESN

PyTorch-ESN is a PyTorch module, written in Python, implementing Echo State Networks with leaky-integrated units. The implementation with more than one layer is based on [DeepESN](https://arxiv.org/abs/1712.04323). The readout is trainable by ridge regression or by PyTorch's optimizers.

Its development started as part of my master's thesis, ["An Empirical Comparison of Recurrent Neural Networks on Sequence Modeling"](https://www.dropbox.com/s/gyt48dcktht7qur/document.pdf?dl=0), which was supervised by Prof. Alessio Micheli and Dr. Claudio Gallicchio at the University of Pisa.

## Prerequisites

* PyTorch

## Basic Usage

### Offline training (ridge regression)

#### SVD
Mini-batch mode is not allowed with this method.
17 | 18 | ```python 19 | from torchesn.nn import ESN 20 | from torchesn.utils import prepare_target 21 | 22 | # prepare target matrix for offline training 23 | flat_target = prepare_target(target, seq_lengths, washout) 24 | 25 | model = ESN(input_size, hidden_size, output_size) 26 | 27 | # train 28 | model(input, washout, hidden, flat_target) 29 | 30 | # inference 31 | output, hidden = model(input, washout, hidden) 32 | ``` 33 | 34 | #### Cholesky or inverse 35 | ```python 36 | from torchesn.nn import ESN 37 | from torchesn.utils import prepare_target 38 | 39 | # prepare target matrix for offline training 40 | flat_target = prepare_target(target, seq_lengths, washout) 41 | 42 | model = ESN(input_size, hidden_size, output_size, readout_training='cholesky') 43 | 44 | # accumulate matrices for ridge regression 45 | for batch in batch_iter: 46 | model(batch, washout[batch], hidden, flat_target) 47 | 48 | # train 49 | model.fit() 50 | 51 | # inference 52 | output, hidden = model(input, washout, hidden) 53 | ``` 54 | 55 | #### Classification tasks 56 | For classification, just use one of the previous methods and pass 'mean' or 57 | 'last' to ```output_steps``` argument. 58 | 59 | ```python 60 | model = ESN(input_size, hidden_size, output_size, output_steps='mean') 61 | ``` 62 | 63 | For more information see docstrings or section 4.7 of "A Practical Guide to Applying 64 | Echo State Networks" by Mantas Lukoševičius. 65 | 66 | ### Online training (PyTorch optimizer) 67 | 68 | Same as PyTorch. 69 | -------------------------------------------------------------------------------- /examples/mackey-glass.py: -------------------------------------------------------------------------------- 1 | import torch.nn 2 | import numpy as np 3 | from torchesn.nn import ESN 4 | from torchesn import utils 5 | import time 6 | 7 | device = torch.device('cuda') 8 | dtype = torch.double 9 | torch.set_default_dtype(dtype) 10 | 11 | if dtype == torch.double: 12 | data = np.loadtxt('datasets/mg17.csv', delimiter=',', dtype=np.float64) 13 | elif dtype == torch.float: 14 | data = np.loadtxt('datasets/mg17.csv', delimiter=',', dtype=np.float32) 15 | X_data = np.expand_dims(data[:, [0]], axis=1) 16 | Y_data = np.expand_dims(data[:, [1]], axis=1) 17 | X_data = torch.from_numpy(X_data).to(device) 18 | Y_data = torch.from_numpy(Y_data).to(device) 19 | 20 | trX = X_data[:5000] 21 | trY = Y_data[:5000] 22 | tsX = X_data[5000:] 23 | tsY = Y_data[5000:] 24 | 25 | washout = [500] 26 | input_size = output_size = 1 27 | hidden_size = 500 28 | loss_fcn = torch.nn.MSELoss() 29 | 30 | if __name__ == "__main__": 31 | start = time.time() 32 | 33 | # Training 34 | trY_flat = utils.prepare_target(trY.clone(), [trX.size(0)], washout) 35 | 36 | model = ESN(input_size, hidden_size, output_size) 37 | model.to(device) 38 | 39 | model(trX, washout, None, trY_flat) 40 | model.fit() 41 | output, hidden = model(trX, washout) 42 | print("Training error:", loss_fcn(output, trY[washout[0]:]).item()) 43 | 44 | # Test 45 | output, hidden = model(tsX, [0], hidden) 46 | print("Test error:", loss_fcn(output, tsY).item()) 47 | print("Ended in", time.time() - start, "seconds.") 48 | -------------------------------------------------------------------------------- /examples/mnist.py: -------------------------------------------------------------------------------- 1 | import torch.nn 2 | from torchvision import datasets, transforms 3 | from torchesn.nn import ESN 4 | import time 5 | 6 | 7 | def Accuracy_Correct(y_pred, y_true): 8 | labels = torch.argmax(y_pred, 
1).type(y_pred.type())
    correct = len((labels == y_true).nonzero())
    return correct


def one_hot(y, output_dim):
    onehot = torch.zeros(y.size(0), output_dim, device=y.device)

    for i in range(output_dim):
        onehot[y == i, i] = 1

    return onehot


def reshape_batch(batch):
    batch = batch.view(batch.size(0), batch.size(1), -1)
    return batch.transpose(0, 1).transpose(0, 2)


device = torch.device('cuda')
dtype = torch.float
torch.set_default_dtype(dtype)
loss_fcn = Accuracy_Correct

batch_size = 256  # Tune it according to your VRAM's size.
input_size = 1
hidden_size = 500
output_size = 10
washout_rate = 0.2

# This example is not intended to be optimized. Its purpose is to show how to
# handle big datasets with multiple sequences. The accuracy should be around 10%.
if __name__ == "__main__":
    train_iter = torch.utils.data.DataLoader(
        datasets.MNIST('./datasets', train=True, download=True,
                       transform=transforms.Compose([
                           transforms.ToTensor(),
                           transforms.Normalize((0.1307,), (0.3081,))])),
        batch_size=batch_size, shuffle=True, num_workers=1, pin_memory=True)

    test_iter = torch.utils.data.DataLoader(
        datasets.MNIST('./datasets', train=False,
                       transform=transforms.Compose([
                           transforms.ToTensor(),
                           transforms.Normalize((0.1307,), (0.3081,))])),
        batch_size=batch_size, shuffle=False, num_workers=1, pin_memory=True)

    start = time.time()

    # Training
    model = ESN(input_size, hidden_size, output_size,
                output_steps='mean', readout_training='cholesky')
    model.to(device)

    # Accumulate the ridge-regression matrices over all mini-batches, then fit
    # the readout once.
    for batch in train_iter:
        x, y = batch
        x = x.to(device)
        y = y.to(device)

        x = reshape_batch(x)
        target = one_hot(y, output_size)
        washout_list = [int(washout_rate * x.size(0))] * x.size(1)

        model(x, washout_list, None, target)

    model.fit()

    # Evaluate on training set
    tot_correct = 0
    tot_obs = 0

    for batch in train_iter:
        x, y = batch
        x = x.to(device)
        y = y.to(device)

        x = reshape_batch(x)
        washout_list = [int(washout_rate * x.size(0))] * x.size(1)

        output, hidden = model(x, washout_list)
        tot_obs += x.size(1)
        tot_correct += loss_fcn(output[-1], y.type(torch.get_default_dtype()))

    print("Training accuracy:", tot_correct / tot_obs)

    # Test (reset the counters so training batches do not leak into the test score)
    tot_correct = 0
    tot_obs = 0

    for batch in test_iter:
        x, y = batch
        x = x.to(device)
        y = y.to(device)

        x = reshape_batch(x)
        washout_list = [int(washout_rate * x.size(0))] * x.size(1)

        output, hidden = model(x, washout_list)
        tot_obs += x.size(1)
        tot_correct += loss_fcn(output[-1], y.type(torch.get_default_dtype()))

    print("Test accuracy:", tot_correct / tot_obs)

    print("Ended in", time.time() - start, "seconds.")

--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
from setuptools import setup, find_packages

setup(name='pytorch-esn',
      version='1.2.4',
      packages=find_packages(),
      install_requires=[
          'torch',
          'torchvision',
          'numpy'
      ],
      description="Echo State Network module for PyTorch.",
      author='Stefano Nardo',
      author_email='stefano_nardo@msn.com',
      license='MIT',
      url="https://github.com/stefanonardo/pytorch-esn"
      )

--------------------------------------------------------------------------------
/torchesn/__init__.py:
-------------------------------------------------------------------------------- 1 | import torchesn.nn 2 | import torchesn.utils 3 | -------------------------------------------------------------------------------- /torchesn/nn/__init__.py: -------------------------------------------------------------------------------- 1 | from .echo_state_network import ESN 2 | from .reservoir import Reservoir, VariableRecurrent, AutogradReservoir, \ 3 | Recurrent, StackedRNN, ResIdCell, ResReLUCell, ResTanhCell 4 | 5 | __all__ = ['ESN', 'Reservoir', 'Recurrent', 'VariableRecurrent', 6 | 'AutogradReservoir', 'StackedRNN', 'ResIdCell', 'ResReLUCell', 7 | 'ResTanhCell'] 8 | -------------------------------------------------------------------------------- /torchesn/nn/echo_state_network.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn.utils.rnn import PackedSequence, pad_packed_sequence 4 | from .reservoir import Reservoir 5 | from ..utils import washout_tensor 6 | 7 | 8 | class ESN(nn.Module): 9 | """ Applies an Echo State Network to an input sequence. Multi-layer Echo 10 | State Network is based on paper 11 | Deep Echo State Network (DeepESN): A Brief Survey - Gallicchio, Micheli 2017 12 | 13 | Args: 14 | input_size: The number of expected features in the input x. 15 | hidden_size: The number of features in the hidden state h. 16 | output_size: The number of expected features in the output y. 17 | num_layers: Number of recurrent layers. Default: 1 18 | nonlinearity: The non-linearity to use ['tanh'|'relu'|'id']. 19 | Default: 'tanh' 20 | batch_first: If ``True``, then the input and output tensors are provided 21 | as (batch, seq, feature). Default: ``False`` 22 | leaking_rate: Leaking rate of reservoir's neurons. Default: 1 23 | spectral_radius: Desired spectral radius of recurrent weight matrix. 24 | Default: 0.9 25 | w_ih_scale: Scale factor for first layer's input weights (w_ih_l0). It 26 | can be a number or a tensor of size '1 + input_size' and first element 27 | is the bias' scale factor. Default: 1 28 | lambda_reg: Ridge regression's shrinkage parameter. Default: 1 29 | density: Recurrent weight matrix's density. Default: 1 30 | w_io: If 'True', then the network uses trainable input-to-output 31 | connections. Default: ``False`` 32 | readout_training: Readout's traning algorithm ['gd'|'svd'|'cholesky'|'inv']. 33 | If 'gd', gradients are accumulated during backward 34 | pass. If 'svd', 'cholesky' or 'inv', the network will learn readout's 35 | parameters during the forward pass using ridge regression. The 36 | coefficients are computed using SVD, Cholesky decomposition or 37 | standard ridge regression formula. 'gd', 'cholesky' and 'inv' 38 | permit the usage of mini-batches to train the readout. 39 | If 'inv' and matrix is singular, pseudoinverse is used. 40 | output_steps: defines how the reservoir's output will be used by ridge 41 | regression method ['all', 'mean', 'last']. 42 | If 'all', the entire reservoir output matrix will be used. 43 | If 'mean', the mean of reservoir output matrix along the timesteps 44 | dimension will be used. 45 | If 'last', only the last timestep of the reservoir output matrix 46 | will be used. 47 | 'mean' and 'last' are useful for classification tasks. 48 | 49 | Inputs: input, washout, h_0, target 50 | input (seq_len, batch, input_size): tensor containing the features of 51 | the input sequence. The input can also be a packed variable length 52 | sequence. 
See `torch.nn.utils.rnn.pack_padded_sequence` 53 | washout (batch): number of initial timesteps during which output of the 54 | reservoir is not forwarded to the readout. One value per batch's 55 | sample. 56 | h_0 (num_layers, batch, hidden_size): tensor containing 57 | the initial reservoir's hidden state for each element in the batch. 58 | Defaults to zero if not provided. 59 | 60 | target (seq_len*batch - washout*batch, output_size): tensor containing 61 | the features of the batch's target sequences rolled out along one 62 | axis, minus the washouts and the padded values. It is only needed 63 | for readout's training in offline mode. Use `prepare_target` to 64 | compute it. 65 | 66 | Outputs: output, h_n 67 | - output (seq_len, batch, hidden_size): tensor containing the output 68 | features (h_k) from the readout, for each k. 69 | - **h_n** (num_layers * num_directions, batch, hidden_size): tensor 70 | containing the reservoir's hidden state for k=seq_len. 71 | """ 72 | 73 | def __init__(self, input_size, hidden_size, output_size, num_layers=1, 74 | nonlinearity='tanh', batch_first=False, leaking_rate=1, 75 | spectral_radius=0.9, w_ih_scale=1, lambda_reg=0, density=1, 76 | w_io=False, readout_training='svd', output_steps='all'): 77 | super(ESN, self).__init__() 78 | 79 | self.input_size = input_size 80 | self.hidden_size = hidden_size 81 | self.output_size = output_size 82 | self.num_layers = num_layers 83 | if nonlinearity == 'tanh': 84 | mode = 'RES_TANH' 85 | elif nonlinearity == 'relu': 86 | mode = 'RES_RELU' 87 | elif nonlinearity == 'id': 88 | mode = 'RES_ID' 89 | else: 90 | raise ValueError("Unknown nonlinearity '{}'".format(nonlinearity)) 91 | self.batch_first = batch_first 92 | self.leaking_rate = leaking_rate 93 | self.spectral_radius = spectral_radius 94 | if type(w_ih_scale) != torch.Tensor: 95 | self.w_ih_scale = torch.ones(input_size + 1) 96 | self.w_ih_scale *= w_ih_scale 97 | else: 98 | self.w_ih_scale = w_ih_scale 99 | 100 | self.lambda_reg = lambda_reg 101 | self.density = density 102 | self.w_io = w_io 103 | if readout_training in {'gd', 'svd', 'cholesky', 'inv'}: 104 | self.readout_training = readout_training 105 | else: 106 | raise ValueError("Unknown readout training algorithm '{}'".format( 107 | readout_training)) 108 | 109 | self.reservoir = Reservoir(mode, input_size, hidden_size, num_layers, 110 | leaking_rate, spectral_radius, 111 | self.w_ih_scale, density, 112 | batch_first=batch_first) 113 | 114 | if w_io: 115 | self.readout = nn.Linear(input_size + hidden_size * num_layers, 116 | output_size) 117 | else: 118 | self.readout = nn.Linear(hidden_size * num_layers, output_size) 119 | if readout_training == 'offline': 120 | self.readout.weight.requires_grad = False 121 | 122 | if output_steps in {'all', 'mean', 'last'}: 123 | self.output_steps = output_steps 124 | else: 125 | raise ValueError("Unknown task '{}'".format( 126 | output_steps)) 127 | 128 | self.XTX = None 129 | self.XTy = None 130 | self.X = None 131 | 132 | def forward(self, input, washout, h_0=None, target=None): 133 | with torch.no_grad(): 134 | is_packed = isinstance(input, PackedSequence) 135 | 136 | output, hidden = self.reservoir(input, h_0) 137 | if is_packed: 138 | output, seq_lengths = pad_packed_sequence(output, 139 | batch_first=self.batch_first) 140 | else: 141 | if self.batch_first: 142 | seq_lengths = output.size(0) * [output.size(1)] 143 | else: 144 | seq_lengths = output.size(1) * [output.size(0)] 145 | 146 | if self.batch_first: 147 | output = output.transpose(0, 1) 148 | 149 | 
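            # Drop the first washout[b] timesteps of every sequence before they reach
            # the readout: washout_tensor shifts each sequence left by its washout,
            # zero-pads the freed tail and shortens seq_lengths accordingly.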
output, seq_lengths = washout_tensor(output, washout, seq_lengths) 150 | 151 | if self.w_io: 152 | if is_packed: 153 | input, input_lengths = pad_packed_sequence(input, 154 | batch_first=self.batch_first) 155 | else: 156 | input_lengths = [input.size(0)] * input.size(1) 157 | 158 | if self.batch_first: 159 | input = input.transpose(0, 1) 160 | 161 | input, _ = washout_tensor(input, washout, input_lengths) 162 | output = torch.cat([input, output], -1) 163 | 164 | if self.readout_training == 'gd' or target is None: 165 | with torch.enable_grad(): 166 | output = self.readout(output) 167 | 168 | if is_packed: 169 | for i in range(output.size(1)): 170 | if seq_lengths[i] < output.size(0): 171 | output[seq_lengths[i]:, i] = 0 172 | 173 | if self.batch_first: 174 | output = output.transpose(0, 1) 175 | 176 | # Uncomment if you want packed output. 177 | # if is_packed: 178 | # output = pack_padded_sequence(output, seq_lengths, 179 | # batch_first=self.batch_first) 180 | 181 | return output, hidden 182 | 183 | else: 184 | batch_size = output.size(1) 185 | 186 | X = torch.ones(target.size(0), 1 + output.size(2), device=target.device) 187 | row = 0 188 | for s in range(batch_size): 189 | if self.output_steps == 'all': 190 | X[row:row + seq_lengths[s], 1:] = output[:seq_lengths[s], 191 | s] 192 | row += seq_lengths[s] 193 | elif self.output_steps == 'mean': 194 | X[row, 1:] = torch.mean(output[:seq_lengths[s], s], 0) 195 | row += 1 196 | elif self.output_steps == 'last': 197 | X[row, 1:] = output[seq_lengths[s] - 1, s] 198 | row += 1 199 | 200 | if self.readout_training == 'cholesky': 201 | if self.XTX is None: 202 | self.XTX = torch.mm(X.t(), X) 203 | self.XTy = torch.mm(X.t(), target) 204 | else: 205 | self.XTX += torch.mm(X.t(), X) 206 | self.XTy += torch.mm(X.t(), target) 207 | 208 | elif self.readout_training == 'svd': 209 | # Scikit-Learn SVD solver for ridge regression. 
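                    # Ridge solution via SVD: with X = U diag(s) V^T, the minimizer of
                    # ||X W^T - y||^2 + lambda ||W^T||^2 is W^T = V diag(s / (s^2 + lambda)) U^T y.
                    # Singular values below 1e-15 are treated as zero, as in scipy.linalg.pinv.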
U, s, V = torch.svd(X)
                    idx = s > 1e-15  # same default value as scipy.linalg.pinv
                    s_nnz = s[idx][:, None]
                    UTy = torch.mm(U.t(), target)
                    d = torch.zeros(s.size(0), 1, device=X.device)
                    d[idx] = s_nnz / (s_nnz ** 2 + self.lambda_reg)
                    d_UT_y = d * UTy
                    W = torch.mm(V, d_UT_y).t()

                    self.readout.bias = nn.Parameter(W[:, 0])
                    self.readout.weight = nn.Parameter(W[:, 1:])
                elif self.readout_training == 'inv':
                    self.X = X
                    if self.XTX is None:
                        self.XTX = torch.mm(X.t(), X)
                        self.XTy = torch.mm(X.t(), target)
                    else:
                        self.XTX += torch.mm(X.t(), X)
                        self.XTy += torch.mm(X.t(), target)

                return None, None

    def fit(self):
        if self.readout_training in {'gd', 'svd'}:
            return

        if self.readout_training == 'cholesky':
            # Closed-form ridge regression: solve (X^T X + lambda * I) W^T = X^T y.
            W = torch.linalg.solve(
                self.XTX + self.lambda_reg * torch.eye(
                    self.XTX.size(0), device=self.XTX.device),
                self.XTy).t()
            self.XTX = None
            self.XTy = None

            self.readout.bias = nn.Parameter(W[:, 0])
            self.readout.weight = nn.Parameter(W[:, 1:])
        elif self.readout_training == 'inv':
            I = (self.lambda_reg * torch.eye(self.XTX.size(0))).to(
                self.XTX.device)
            A = self.XTX + I
            A_rank = torch.linalg.matrix_rank(A).item()

            # Use the exact inverse only when A has full rank; fall back to the
            # pseudoinverse otherwise (see the readout_training docstring).
            if A_rank == A.size(0):
                W = torch.mm(torch.inverse(A), self.XTy).t()
            else:
                W = torch.mm(torch.pinverse(A), self.XTy).t()

            self.readout.bias = nn.Parameter(W[:, 0])
            self.readout.weight = nn.Parameter(W[:, 1:])

            self.XTX = None
            self.XTy = None

    def reset_parameters(self):
        self.reservoir.reset_parameters()
        self.readout.reset_parameters()

--------------------------------------------------------------------------------
/torchesn/nn/reservoir.py:
--------------------------------------------------------------------------------
"""
Reservoir layers and recurrence helpers used by the ESN: fixed (untrained) recurrent
weights rescaled to the desired spectral radius, plus the leaky-integrator cells.
4 | """ 5 | 6 | import re 7 | import torch 8 | import torch.nn as nn 9 | from torch.nn import functional as F 10 | from torch.nn.utils.rnn import PackedSequence 11 | import torch.sparse 12 | 13 | 14 | def apply_permutation(tensor, permutation, dim=1): 15 | # type: (Tensor, Tensor, int) -> Tensor 16 | return tensor.index_select(dim, permutation) 17 | 18 | 19 | class Reservoir(nn.Module): 20 | 21 | def __init__(self, mode, input_size, hidden_size, num_layers, leaking_rate, 22 | spectral_radius, w_ih_scale, 23 | density, bias=True, batch_first=False): 24 | super(Reservoir, self).__init__() 25 | self.mode = mode 26 | self.input_size = input_size 27 | self.hidden_size = hidden_size 28 | self.num_layers = num_layers 29 | self.leaking_rate = leaking_rate 30 | self.spectral_radius = spectral_radius 31 | self.w_ih_scale = w_ih_scale 32 | self.density = density 33 | self.bias = bias 34 | self.batch_first = batch_first 35 | 36 | self._all_weights = [] 37 | for layer in range(num_layers): 38 | layer_input_size = input_size if layer == 0 else hidden_size 39 | 40 | w_ih = nn.Parameter(torch.Tensor(hidden_size, layer_input_size)) 41 | w_hh = nn.Parameter(torch.Tensor(hidden_size, hidden_size)) 42 | b_ih = nn.Parameter(torch.Tensor(hidden_size)) 43 | layer_params = (w_ih, w_hh, b_ih) 44 | 45 | param_names = ['weight_ih_l{}{}', 'weight_hh_l{}{}'] 46 | if bias: 47 | param_names += ['bias_ih_l{}{}'] 48 | param_names = [x.format(layer, '') for x in param_names] 49 | 50 | for name, param in zip(param_names, layer_params): 51 | setattr(self, name, param) 52 | self._all_weights.append(param_names) 53 | 54 | self.reset_parameters() 55 | 56 | def _apply(self, fn): 57 | ret = super(Reservoir, self)._apply(fn) 58 | return ret 59 | 60 | def reset_parameters(self): 61 | weight_dict = self.state_dict() 62 | for key, value in weight_dict.items(): 63 | if key == 'weight_ih_l0': 64 | nn.init.uniform_(value, -1, 1) 65 | value *= self.w_ih_scale[1:] 66 | elif re.fullmatch('weight_ih_l[^0]*', key): 67 | nn.init.uniform_(value, -1, 1) 68 | elif re.fullmatch('bias_ih_l[0-9]*', key): 69 | nn.init.uniform_(value, -1, 1) 70 | value *= self.w_ih_scale[0] 71 | elif re.fullmatch('weight_hh_l[0-9]*', key): 72 | w_hh = torch.Tensor(self.hidden_size * self.hidden_size) 73 | w_hh.uniform_(-1, 1) 74 | if self.density < 1: 75 | zero_weights = torch.randperm( 76 | int(self.hidden_size * self.hidden_size)) 77 | zero_weights = zero_weights[ 78 | :int( 79 | self.hidden_size * self.hidden_size * ( 80 | 1 - self.density))] 81 | w_hh[zero_weights] = 0 82 | w_hh = w_hh.view(self.hidden_size, self.hidden_size) 83 | abs_eigs = torch.abs(torch.linalg.eigvals(w_hh)) 84 | weight_dict[key] = w_hh * (self.spectral_radius / torch.max(abs_eigs)) 85 | 86 | self.load_state_dict(weight_dict) 87 | 88 | def check_input(self, input, batch_sizes): 89 | # type: (Tensor, Optional[Tensor]) -> None 90 | expected_input_dim = 2 if batch_sizes is not None else 3 91 | if input.dim() != expected_input_dim: 92 | raise RuntimeError( 93 | 'input must have {} dimensions, got {}'.format( 94 | expected_input_dim, input.dim())) 95 | if self.input_size != input.size(-1): 96 | raise RuntimeError( 97 | 'input.size(-1) must be equal to input_size. 
Expected {}, got {}'.format( 98 | self.input_size, input.size(-1))) 99 | 100 | def get_expected_hidden_size(self, input, batch_sizes): 101 | # type: (Tensor, Optional[Tensor]) -> Tuple[int, int, int] 102 | if batch_sizes is not None: 103 | mini_batch = batch_sizes[0] 104 | mini_batch = int(mini_batch) 105 | else: 106 | mini_batch = input.size(0) if self.batch_first else input.size(1) 107 | expected_hidden_size = (self.num_layers, mini_batch, self.hidden_size) 108 | return expected_hidden_size 109 | 110 | def check_hidden_size(self, hx, expected_hidden_size, msg='Expected hidden size {}, got {}'): 111 | # type: (Tensor, Tuple[int, int, int], str) -> None 112 | if hx.size() != expected_hidden_size: 113 | raise RuntimeError(msg.format(expected_hidden_size, tuple(hx.size()))) 114 | 115 | def check_forward_args(self, input, hidden, batch_sizes): 116 | # type: (Tensor, Tensor, Optional[Tensor]) -> None 117 | self.check_input(input, batch_sizes) 118 | expected_hidden_size = self.get_expected_hidden_size(input, batch_sizes) 119 | 120 | self.check_hidden_size(hidden, expected_hidden_size) 121 | 122 | def permute_hidden(self, hx, permutation): 123 | # type: (Tensor, Optional[Tensor]) -> Tensor 124 | if permutation is None: 125 | return hx 126 | return apply_permutation(hx, permutation) 127 | 128 | def forward(self, input, hx=None): 129 | is_packed = isinstance(input, PackedSequence) 130 | if is_packed: 131 | input, batch_sizes, sorted_indices, unsorted_indices = input 132 | max_batch_size = int(batch_sizes[0]) 133 | else: 134 | batch_sizes = None 135 | max_batch_size = input.size(0) if self.batch_first else input.size(1) 136 | sorted_indices = None 137 | unsorted_indices = None 138 | 139 | if hx is None: 140 | hx = input.new_zeros(self.num_layers, max_batch_size, 141 | self.hidden_size, requires_grad=False) 142 | else: 143 | # Each batch of the hidden state should match the input sequence that 144 | # the user believes he/she is passing in. 
145 | hx = self.permute_hidden(hx, sorted_indices) 146 | 147 | flat_weight = None 148 | 149 | self.check_forward_args(input, hx, batch_sizes) 150 | func = AutogradReservoir( 151 | self.mode, 152 | self.input_size, 153 | self.hidden_size, 154 | num_layers=self.num_layers, 155 | batch_first=self.batch_first, 156 | train=self.training, 157 | variable_length=is_packed, 158 | flat_weight=flat_weight, 159 | leaking_rate=self.leaking_rate 160 | ) 161 | output, hidden = func(input, self.all_weights, hx, batch_sizes) 162 | if is_packed: 163 | output = PackedSequence(output, batch_sizes, sorted_indices, unsorted_indices) 164 | return output, self.permute_hidden(hidden, unsorted_indices) 165 | 166 | def extra_repr(self): 167 | s = '({input_size}, {hidden_size}' 168 | if self.num_layers != 1: 169 | s += ', num_layers={num_layers}' 170 | if self.bias is not True: 171 | s += ', bias={bias}' 172 | if self.batch_first is not False: 173 | s += ', batch_first={batch_first}' 174 | s += ')' 175 | return s.format(**self.__dict__) 176 | 177 | def __setstate__(self, d): 178 | super(Reservoir, self).__setstate__(d) 179 | self.__dict__.setdefault('_data_ptrs', []) 180 | if 'all_weights' in d: 181 | self._all_weights = d['all_weights'] 182 | if isinstance(self._all_weights[0][0], str): 183 | return 184 | num_layers = self.num_layers 185 | self._all_weights = [] 186 | for layer in range(num_layers): 187 | weights = ['weight_ih_l{}{}', 'weight_hh_l{}{}', 'bias_ih_l{}{}'] 188 | weights = [x.format(layer) for x in weights] 189 | if self.bias: 190 | self._all_weights += [weights] 191 | else: 192 | self._all_weights += [weights[:2]] 193 | 194 | @property 195 | def all_weights(self): 196 | return [[getattr(self, weight) for weight in weights] for weights in 197 | self._all_weights] 198 | 199 | 200 | def AutogradReservoir(mode, input_size, hidden_size, num_layers=1, 201 | batch_first=False, train=True, 202 | batch_sizes=None, variable_length=False, flat_weight=None, 203 | leaking_rate=1): 204 | if mode == 'RES_TANH': 205 | cell = ResTanhCell 206 | elif mode == 'RES_RELU': 207 | cell = ResReLUCell 208 | elif mode == 'RES_ID': 209 | cell = ResIdCell 210 | 211 | if variable_length: 212 | layer = (VariableRecurrent(cell, leaking_rate),) 213 | else: 214 | layer = (Recurrent(cell, leaking_rate),) 215 | 216 | func = StackedRNN(layer, 217 | num_layers, 218 | False, 219 | train=train) 220 | 221 | def forward(input, weight, hidden, batch_sizes): 222 | if batch_first and batch_sizes is None: 223 | input = input.transpose(0, 1) 224 | 225 | nexth, output = func(input, hidden, weight, batch_sizes) 226 | 227 | if batch_first and not variable_length: 228 | output = output.transpose(0, 1) 229 | 230 | return output, nexth 231 | 232 | return forward 233 | 234 | 235 | def Recurrent(inner, leaking_rate): 236 | def forward(input, hidden, weight, batch_sizes): 237 | output = [] 238 | steps = range(input.size(0)) 239 | for i in steps: 240 | hidden = inner(input[i], hidden, leaking_rate, *weight) 241 | # hack to handle LSTM 242 | output.append(hidden[0] if isinstance(hidden, tuple) else hidden) 243 | 244 | output = torch.cat(output, 0).view(input.size(0), *output[0].size()) 245 | 246 | return hidden, output 247 | 248 | return forward 249 | 250 | 251 | def VariableRecurrent(inner, leaking_rate): 252 | def forward(input, hidden, weight, batch_sizes): 253 | output = [] 254 | input_offset = 0 255 | last_batch_size = batch_sizes[0] 256 | hiddens = [] 257 | flat_hidden = not isinstance(hidden, tuple) 258 | if flat_hidden: 259 | hidden = (hidden,) 260 | 
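        # batch_sizes comes from a PackedSequence: sequences are sorted by length, so
        # the effective batch size can only shrink from one timestep to the next. When
        # it shrinks by `dec`, the hidden states of the sequences that just ended are
        # stashed in `hiddens` and concatenated back (in packed batch order) after the loop.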
for batch_size in batch_sizes: 261 | step_input = input[input_offset:input_offset + batch_size] 262 | input_offset += batch_size 263 | 264 | dec = last_batch_size - batch_size 265 | if dec > 0: 266 | hiddens.append(tuple(h[-dec:] for h in hidden)) 267 | hidden = tuple(h[:-dec] for h in hidden) 268 | last_batch_size = batch_size 269 | 270 | if flat_hidden: 271 | hidden = (inner(step_input, hidden[0], leaking_rate, *weight),) 272 | else: 273 | hidden = inner(step_input, hidden, leaking_rate, *weight) 274 | 275 | output.append(hidden[0]) 276 | hiddens.append(hidden) 277 | hiddens.reverse() 278 | 279 | hidden = tuple(torch.cat(h, 0) for h in zip(*hiddens)) 280 | assert hidden[0].size(0) == batch_sizes[0] 281 | if flat_hidden: 282 | hidden = hidden[0] 283 | output = torch.cat(output, 0) 284 | 285 | return hidden, output 286 | 287 | return forward 288 | 289 | 290 | def StackedRNN(inners, num_layers, lstm=False, train=True): 291 | num_directions = len(inners) 292 | total_layers = num_layers * num_directions 293 | 294 | def forward(input, hidden, weight, batch_sizes): 295 | assert (len(weight) == total_layers) 296 | next_hidden = [] 297 | all_layers_output = [] 298 | 299 | for i in range(num_layers): 300 | all_output = [] 301 | for j, inner in enumerate(inners): 302 | l = i * num_directions + j 303 | 304 | hy, output = inner(input, hidden[l], weight[l], batch_sizes) 305 | next_hidden.append(hy) 306 | all_output.append(output) 307 | 308 | input = torch.cat(all_output, input.dim() - 1) 309 | all_layers_output.append(input) 310 | 311 | all_layers_output = torch.cat(all_layers_output, -1) 312 | next_hidden = torch.cat(next_hidden, 0).view( 313 | total_layers, *next_hidden[0].size()) 314 | 315 | return next_hidden, all_layers_output 316 | 317 | return forward 318 | 319 | 320 | def ResTanhCell(input, hidden, leaking_rate, w_ih, w_hh, b_ih=None): 321 | hy_ = torch.tanh(F.linear(input, w_ih, b_ih) + F.linear(hidden, w_hh)) 322 | hy = (1 - leaking_rate) * hidden + leaking_rate * hy_ 323 | return hy 324 | 325 | 326 | def ResReLUCell(input, hidden, leaking_rate, w_ih, w_hh, b_ih=None): 327 | hy_ = F.relu(F.linear(input, w_ih, b_ih) + F.linear(hidden, w_hh)) 328 | hy = (1 - leaking_rate) * hidden + leaking_rate * hy_ 329 | return hy 330 | 331 | 332 | def ResIdCell(input, hidden, leaking_rate, w_ih, w_hh, b_ih=None): 333 | hy_ = F.linear(input, w_ih, b_ih) + F.linear(hidden, w_hh) 334 | hy = (1 - leaking_rate) * hidden + leaking_rate * hy_ 335 | return hy 336 | -------------------------------------------------------------------------------- /torchesn/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .utilities import prepare_target, washout_tensor 2 | 3 | __all__ = ['prepare_target', 'washout_tensor'] -------------------------------------------------------------------------------- /torchesn/utils/utilities.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def prepare_target(target, seq_lengths, washout, batch_first=False): 5 | """ Preprocess target for offline training. 6 | 7 | Args: 8 | target (seq_len, batch, output_size): tensor containing 9 | the features of the target sequence. 10 | seq_lengths: list of lengths of each sequence in the batch. 11 | washout: number of initial timesteps during which output of the 12 | reservoir is not forwarded to the readout. One value per sample. 13 | batch_first: If ``True``, then the input and output tensors are provided 14 | as (batch, seq, feature). 
Default: ``False`` 15 | 16 | Returns: 17 | tensor containing the features of the batch's sequences rolled out along 18 | one axis, minus the washouts and the padded values. 19 | """ 20 | 21 | if batch_first: 22 | target = target.transpose(0, 1) 23 | n_sequences = target.size(1) 24 | target_dim = target.size(2) 25 | train_len = sum(torch.tensor(seq_lengths) - torch.tensor(washout)).item() 26 | 27 | new_target = torch.zeros(train_len, target_dim, device=target.device) 28 | 29 | idx = 0 30 | for s in range(n_sequences): 31 | batch_len = seq_lengths[s] - washout[s] 32 | new_target[idx:idx + batch_len, :] = target[washout[s]:seq_lengths[s], s, :] 33 | idx += batch_len 34 | 35 | return new_target 36 | 37 | 38 | def washout_tensor(tensor, washout, seq_lengths, bidirectional=False, batch_first=False): 39 | tensor = tensor.transpose(0, 1) if batch_first else tensor.clone() 40 | if type(seq_lengths) == list: 41 | seq_lengths = seq_lengths.copy() 42 | if type(seq_lengths) == torch.Tensor: 43 | seq_lengths = seq_lengths.clone() 44 | 45 | for b in range(tensor.size(1)): 46 | if washout[b] > 0: 47 | tmp = tensor[washout[b]:seq_lengths[b], b].clone() 48 | tensor[:seq_lengths[b] - washout[b], b] = tmp 49 | tensor[seq_lengths[b] - washout[b]:, b] = 0 50 | seq_lengths[b] -= washout[b] 51 | 52 | if bidirectional: 53 | tensor[seq_lengths[b] - washout[b]:, b] = 0 54 | seq_lengths[b] -= washout[b] 55 | 56 | if type(seq_lengths) == list: 57 | max_len = max(seq_lengths) 58 | else: 59 | max_len = max(seq_lengths).item() 60 | 61 | return tensor[:max_len], seq_lengths 62 | --------------------------------------------------------------------------------
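
As a quick reference for the utilities above, here is a small sketch (toy shapes and values, not part of the package) of what `prepare_target` and `washout_tensor` produce for a two-sequence batch: with lengths 4 and 3 and washouts `[1, 0]`, the flattened target has `(4 - 1) + (3 - 0) = 6` rows, matching the rows the readout sees after the reservoir output is trimmed.

```python
import torch
from torchesn.utils import prepare_target, washout_tensor

# Toy batch: seq_len=4, batch=2, output_size=1; the second sequence has length 3.
target = torch.arange(8, dtype=torch.float).view(4, 2, 1)
seq_lengths = [4, 3]
washout = [1, 0]

# Targets rolled out along one axis, minus washouts and padding: 6 rows.
flat_target = prepare_target(target, seq_lengths, washout)
print(flat_target.shape)  # torch.Size([6, 1])

# The same trimming applied to a (seq_len, batch, features) tensor, e.g. the
# reservoir output, returns the shifted tensor and the updated lengths.
reservoir_out = torch.randn(4, 2, 5)
trimmed, new_lengths = washout_tensor(reservoir_out, washout, seq_lengths)
print(trimmed.shape, new_lengths)  # torch.Size([3, 2, 5]) [3, 3]
```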