├── convlstm ├── package │ ├── __init__.py │ ├── convlstm_cpp.py │ ├── convlstm_cuda.py │ ├── convlstm_ch_pooling.py │ └── convlstm.py ├── test │ ├── __init__.py │ └── tests.py ├── src │ ├── setup.py │ ├── lltm.cpp │ └── convlstm.cpp └── __init__.py ├── lstm_stack ├── __init__.py └── lstm_stack.py ├── convlstm_autoencoder ├── __init__.py └── convlstm_autoencoder.py ├── lstm_cell_stack ├── __init__.py └── lstm_cell_stack.py ├── lstm_autoencoder ├── __init__.py └── lstm_autoencoder.py ├── __init__.py ├── README.md ├── LICENSE └── .gitignore /convlstm/package/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /convlstm/test/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /lstm_stack/__init__.py: -------------------------------------------------------------------------------- 1 | from .lstm_stack import LSTMStack 2 | -------------------------------------------------------------------------------- /convlstm_autoencoder/__init__.py: -------------------------------------------------------------------------------- 1 | from .convlstm_autoencoder import ConvLSTMAutoencoder 2 | -------------------------------------------------------------------------------- /lstm_cell_stack/__init__.py: -------------------------------------------------------------------------------- 1 | from .lstm_cell_stack import LSTMCellStack, ImgLSTMCellStack 2 | -------------------------------------------------------------------------------- /lstm_autoencoder/__init__.py: -------------------------------------------------------------------------------- 1 | from .lstm_autoencoder import LSTMAutoencoder, ImgLSTMAutoencoder 2 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- 1 | from .convlstm import ConvLSTM, ConvLSTMCell, ConvLSTMChannelPooling, ConvLSTMCPPCell 2 | from .lstm_stack import LSTMStack 3 | from .lstm_cell_stack import LSTMCellStack, ImgLSTMCellStack 4 | from .lstm_autoencoder import LSTMAutoencoder, ImgLSTMAutoencoder 5 | from .convlstm_autoencoder import ConvLSTMAutoencoder 6 | -------------------------------------------------------------------------------- /convlstm/src/setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | from torch.utils.cpp_extension import CppExtension, CUDAExtension, BuildExtension 3 | 4 | setup(name='convlstmcpp', 5 | ext_modules=[CppExtension('convlstmcpp', ['convlstm.cpp'])], 6 | cmdclass={'build_ext': BuildExtension}) 7 | 8 | #setup(name='lltmcuda', 9 | # ext_modules=[CUDAExtension('lltmcuda', ['lltm.cu'])], 10 | # cmdclass={'build_ext': BuildExtension}) 11 | -------------------------------------------------------------------------------- /convlstm/__init__.py: -------------------------------------------------------------------------------- 1 | try: 2 | import torch 3 | import convlstmcpp 4 | except ModuleNotFoundError: 5 | import os 6 | import subprocess 7 | dir_path = os.path.dirname(os.path.realpath(__file__)) 8 | subprocess.run(["python", "setup.py", "install"], cwd=os.path.join(dir_path, 'src'), check=True) 9 | 10 | from .package.convlstm import ConvLSTM, ConvLSTMCell, HiddenState, HiddenStateStacked 11 | from 
.package.convlstm_cpp import ConvLSTMCPPCell 12 | from .package.convlstm_ch_pooling import ConvLSTMChannelPooling 13 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # pytorch_model_zoo 2 | A repository for my PyTorch models. It mainly contains recurrent models: 3 | * LSTMCell stack: like PyTorch's LSTM, but it can be called with the same semantics as LSTMCell, i.e. it processes one sequence element at a time instead of an entire sequence. 4 | * LSTM stack: like PyTorch's LSTM, but it allows hidden layers of different sizes. 5 | * ConvLSTM: an implementation of the [Shi et al. ConvLSTM](http://arxiv.org/abs/1506.04214); it uses convolutional layers instead of fully connected layers for the LSTM gates. 6 | * LSTM autoencoder: a sequence-to-sequence model similar to the model proposed in [Sutskever et al. _Sequence to Sequence Learning with Neural Networks_](http://papers.nips.cc/paper/5346-sequence-to-sequence-learning-with-neural-networks.pdf). The model consists of one encoder LSTM and one (or more) decoder LSTMs. The encoder reads the whole sequence and compresses it into an intermediate representation, which is then used by the decoder(s) to produce a new sequence. 7 | * ConvLSTM autoencoder: the same idea as the LSTM autoencoder, but using ConvLSTMs instead of LSTMs. -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Alberto Cenzato 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /convlstm/package/convlstm_cpp.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | import convlstmcpp 7 | 8 | class ConvLSTMFunction(torch.autograd.Function): 9 | @staticmethod 10 | def forward(ctx, input, weights, bias, old_h, old_cell): 11 | outputs = convlstmcpp.forward(input, weights, bias, old_h, old_cell) 12 | new_h, new_cell = outputs[:2] 13 | variables = [outputs[1], old_cell] + outputs[2:] + [weights] 14 | ctx.save_for_backward(*variables) 15 | 16 | return new_h, new_cell 17 | 18 | @staticmethod 19 | def backward(ctx, grad_h, grad_cell): 20 | outputs = convlstmcpp.backward(grad_h.contiguous(), grad_cell.contiguous(), *ctx.saved_tensors) 21 | d_old_h, d_input, d_weights, d_bias, d_old_cell = outputs 22 | return d_input, d_weights, d_bias, d_old_h, d_old_cell 23 | 24 | 25 | class ConvLSTMCPPCell(nn.Module): 26 | def __init__(self, input_size, input_dim, hidden_dim, kernel_size, bias): 27 | super(ConvLSTMCPPCell, self).__init__() 28 | self.input_features = input_dim 29 | self.state_size = hidden_dim 30 | self.weights = nn.Parameter(torch.empty(4 * hidden_dim, input_dim + hidden_dim, kernel_size[0], kernel_size[1])) 31 | self.bias = nn.Parameter(torch.empty(4 * hidden_dim)) 32 | 33 | self.reset_parameters() 34 | 35 | def reset_parameters(self): 36 | stdv = 1.0 / math.sqrt(self.state_size) 37 | for weight in self.parameters(): 38 | weight.data.uniform_(-stdv, +stdv) 39 | 40 | def forward(self, input, state): 41 | return ConvLSTMFunction.apply(input, self.weights, self.bias, *state) -------------------------------------------------------------------------------- /convlstm/package/convlstm_cuda.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | import convlstmcuda 7 | 8 | class ConvLSTMFunction(torch.autograd.Function): 9 | @staticmethod 10 | def forward(ctx, input, weights, bias, old_h, old_cell): 11 | outputs = convlstmcuda.forward(input, weights, bias, old_h, old_cell) 12 | new_h, new_cell = outputs[:2] 13 | variables = [outputs[1], old_cell] + outputs[2:] + [weights] 14 | ctx.save_for_backward(*variables) 15 | 16 | return new_h, new_cell 17 | 18 | @staticmethod 19 | def backward(ctx, grad_h, grad_cell): 20 | outputs = convlstmcuda.backward(grad_h.contiguous(), grad_cell.contiguous(), *ctx.saved_tensors) 21 | d_old_h, d_input, d_weights, d_bias, d_old_cell = outputs 22 | return d_input, d_weights, d_bias, d_old_h, d_old_cell 23 | 24 | 25 | class ConvLSTMCudaCell(nn.Module): 26 | def __init__(self, input_size, input_dim, hidden_dim, kernel_size, bias): 27 | super(ConvLSTMCudaCell, self).__init__() 28 | self.input_features = input_dim 29 | self.state_size = hidden_dim 30 | self.weights = nn.Parameter(torch.empty(4 * hidden_dim, input_dim + hidden_dim, kernel_size[0], kernel_size[1])) 31 | self.bias = nn.Parameter(torch.empty(4 * hidden_dim)) 32 | 33 | self.reset_parameters() 34 | 35 | def reset_parameters(self): 36 | stdv = 1.0 / math.sqrt(self.state_size) 37 | for weight in self.parameters(): 38 | weight.data.uniform_(-stdv, +stdv) 39 | 40 | def forward(self, input, state): 41 | return ConvLSTMFunction.apply(input, self.weights, self.bias, *state) -------------------------------------------------------------------------------- /.gitignore: 
-------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | 106 | # PyCharm 107 | .idea/ 108 | -------------------------------------------------------------------------------- /lstm_stack/lstm_stack.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple, List 2 | 3 | import torch 4 | from torch import Tensor 5 | from torch import nn 6 | 7 | 8 | class LSTMStack(nn.Module): 9 | 10 | def __init__(self, input_size: int, hidden_size: List[int], batch_first: bool=True): 11 | super(LSTMStack, self).__init__() 12 | self.num_layers = len(hidden_size) 13 | self.batch_first = batch_first 14 | self.hidden_sizes = hidden_size 15 | sizes = [input_size, *hidden_size] 16 | layers = [] 17 | for l in range(self.num_layers): 18 | lstm = nn.LSTM(input_size=sizes[l], hidden_size=sizes[l+1]) 19 | layers.append(lstm) 20 | self.layers = nn.ModuleList(layers) 21 | 22 | def forward(self, input: Tensor, hidden_state: Tuple[List[Tensor], List[Tensor]]=None) \ 23 | -> Tuple[Tensor, Tuple[List[Tensor], List[Tensor]]]: 24 | # (b, t, c, h, w) -> (t, b, c, h, w) 25 | input_tensor = input.transpose(0, 1) if self.batch_first else input 26 | 27 | if hidden_state is None: 28 | hidden_state = self.init_hidden(input.size(1)) 29 | 30 | h_0, c_0 = hidden_state 31 | h_n, c_n = [], [] 32 | for l, lstm in enumerate(self.layers): 33 | input_tensor, state = lstm(input_tensor, (h_0[l], c_0[l])) 34 | h_n.append(state[0]) 35 | c_n.append(state[1]) 36 | 37 | output = input_tensor.transpose(0, 1) if self.batch_first else input_tensor 38 | return output, (h_n, c_n) 39 | 40 | def init_hidden(self, batch_size: int) -> Tuple[List[Tensor], List[Tensor]]: 41 | h_0, c_0 = [], [] 42 | for layer, hidden_size in zip(self.layers, self.hidden_sizes): 43 | dtype = layer.weight_ih_l0.dtype 44 | device = 
layer.weight_ih_l0.device 45 | shape = (1, batch_size, hidden_size) 46 | h = torch.zeros(shape, dtype=dtype, device=device) 47 | h_0.append(h) 48 | c_0.append(h) 49 | return h_0, c_0 50 | -------------------------------------------------------------------------------- /convlstm/src/lltm.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include 4 | #include 5 | 6 | 7 | std::vector lltm_forward( 8 | at::Tensor input, 9 | at::Tensor weights, 10 | at::Tensor bias, 11 | at::Tensor old_h, 12 | at::Tensor old_cell) { 13 | auto X = at::cat({old_h, input}, /*dim=*/1); 14 | 15 | auto gate_weights = at::addmm(bias, X, weights.transpose(0, 1)); 16 | auto gates = gate_weights.chunk(3, /*dim=*/1); 17 | 18 | auto input_gate = at::sigmoid(gates[0]); 19 | auto output_gate = at::sigmoid(gates[1]); 20 | auto candidate_cell = at::elu(gates[2], /*alpha=*/1.0); 21 | 22 | auto new_cell = old_cell + candidate_cell * input_gate; 23 | auto new_h = at::tanh(new_cell) * output_gate; 24 | 25 | return {new_h, 26 | new_cell, 27 | input_gate, 28 | output_gate, 29 | candidate_cell, 30 | X, 31 | gate_weights}; 32 | } 33 | 34 | 35 | 36 | at::Tensor d_sigmoid(at::Tensor z) { 37 | auto s = at::sigmoid(z); 38 | return (1 - s) * s; 39 | } 40 | 41 | 42 | // tanh'(z) = 1 - tanh^2(z) 43 | at::Tensor d_tanh(at::Tensor z) { 44 | return 1 - z.tanh().pow(2); 45 | } 46 | 47 | // elu'(z) = relu'(z) + { alpha * exp(z) if (alpha * (exp(z) - 1)) < 0, else 0} 48 | at::Tensor d_elu(at::Tensor z, at::Scalar alpha = 1.0) { 49 | auto e = z.exp(); 50 | auto mask = (alpha * (e - 1)) < 0; 51 | return (z > 0).type_as(z) + mask.type_as(z) * (alpha * e); 52 | } 53 | 54 | std::vector lltm_backward(at::Tensor grad_h, at::Tensor grad_cell, //these are the gradients coming from timestep t+1 55 | at::Tensor new_cell, at::Tensor input_gate, // all variables that are part of the cell state 56 | at::Tensor output_gate, at::Tensor candidate_cell, 57 | at::Tensor X, at::Tensor gate_weights, 58 | at::Tensor weights) { 59 | auto d_output_gate = at::tanh(new_cell) * grad_h; 60 | auto d_tanh_new_cell = output_gate * grad_h; 61 | auto d_new_cell = d_tanh(new_cell) * d_tanh_new_cell + grad_cell; 62 | 63 | auto d_old_cell = d_new_cell; 64 | auto d_candidate_cell = input_gate * d_new_cell; 65 | auto d_input_gate = candidate_cell * d_new_cell; 66 | 67 | auto gates = gate_weights.chunk(3, /*dim=*/1); 68 | d_input_gate *= d_sigmoid(gates[0]); 69 | d_output_gate *= d_sigmoid(gates[1]); 70 | d_candidate_cell *= d_elu(gates[2]); 71 | 72 | auto d_gates = at::cat({d_input_gate, d_output_gate, d_candidate_cell}, /*dim=*/1); 73 | 74 | auto d_weights = d_gates.t().mm(X); 75 | auto d_bias = d_gates.sum(/*dim=*/0, /*keepdim=*/false); 76 | 77 | auto d_X = d_gates.mm(weights); 78 | const auto state_size = grad_h.size(1); 79 | auto d_old_h = d_X.slice(/*dim=*/1, 0, state_size); 80 | auto d_input = d_X.slice(/*dim=*/1, state_size); 81 | 82 | return {d_old_h, d_input, d_weights, d_bias, d_old_cell}; 83 | } 84 | 85 | 86 | 87 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 88 | m.def("forward", &lltm_forward, "LLTM forward"); 89 | m.def("backward", &lltm_backward, "LLTM backward"); 90 | } -------------------------------------------------------------------------------- /lstm_cell_stack/lstm_cell_stack.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple, List 2 | 3 | import torch 4 | import torch.nn as nn 5 | from torch import Tensor 6 | 7 | 8 | class 
LSTMCellStack(nn.Module): 9 | """ 10 | This module is a stack of LSTMCell: instead of receiving an 11 | entire sequence in input as torch.nn.LSTM does, it accepts one 12 | element at a time. 13 | """ 14 | 15 | def __init__(self, input_size: int, hidden_size: List[int]): 16 | super(LSTMCellStack, self).__init__() 17 | self.input_size = input_size 18 | self.hidden_size = hidden_size 19 | self.num_layers = len(hidden_size) 20 | 21 | cells = [] 22 | sizes = [input_size, *hidden_size] 23 | for l in range(self.num_layers): 24 | layer = nn.LSTMCell(input_size=sizes[l], hidden_size=sizes[l+1]) 25 | cells.append(layer) 26 | self.cells = nn.ModuleList(cells) 27 | 28 | def forward(self, input: Tensor, hidden_state: Tuple[List[Tensor], List[Tensor]]=None) \ 29 | -> Tuple[Tensor, Tuple[List[Tensor], List[Tensor]]]: 30 | """ 31 | Inputs: input, (h_0, c_0) 32 | - **input** of shape `(batch, input_size)`: tensor containing the features 33 | of the input sequence. 34 | The input can also be a packed variable length sequence. 35 | See :func:`torch.nn.utils.rnn.pack_padded_sequence` or 36 | :func:`torch.nn.utils.rnn.pack_sequence` for details. 37 | - **h_0** list of size num_layers that contains tensors of shape 38 | `(batch, hidden_size)`: tensor containing the initial hidden 39 | state for each element in the batch. 40 | - **c_0** list of size num_layers that contains tensors of shape 41 | `(batch, hidden_size)`: tensor containing the initial cell 42 | state for each element in the batch. 43 | 44 | If `(h_0, c_0)` is not provided, both **h_0** and **c_0** default to zero. 45 | 46 | Outputs: output, (h_n, c_n) 47 | - **output** of shape `(batch, hidden_size)`: tensor 48 | containing the output features `(h_t)` from the last layer of the LSTM. 49 | - **h_n** list of size num_layers that contains tensors of shape 50 | `(batch, hidden_size)`: tensor containing the hidden state. 
51 | - **c_n** list of size num_layers that contains tensors of shape 52 | `(batch, hidden_size)`: tensor containing the cell state 53 | """ 54 | in_vector = input 55 | if hidden_state is None: 56 | hidden_state = self.init_hidden(input.size(0)) 57 | 58 | h_0, c_0 = hidden_state 59 | h_n_list, c_n_list = [], [] 60 | for l, layer in enumerate(self.cells): 61 | state = (h_0[l], c_0[l]) 62 | h_n, c_n = layer(in_vector, state) 63 | in_vector = h_n 64 | h_n_list.append(h_n) 65 | c_n_list.append(c_n) 66 | output = in_vector 67 | return output, (h_n_list, c_n_list) 68 | 69 | def init_hidden(self, batch_size: int) -> Tuple[List[Tensor], List[Tensor]]: 70 | h_0, c_0 = [], [] 71 | for size, cell in zip(self.hidden_size, self.cells): 72 | dtype = cell.weight_ih.dtype 73 | device = cell.weight_ih.device 74 | shape = (batch_size, size) 75 | h = torch.zeros(shape, dtype=dtype, device=device) 76 | h_0.append(h) 77 | c_0.append(h) 78 | return h_0, c_0 79 | 80 | 81 | class ImgLSTMCellStack(nn.Module): 82 | 83 | def __init__(self, image_size: Tuple[int, int, int], hidden_size: List[int]): 84 | super(ImgLSTMCellStack, self).__init__() 85 | self.image_size = image_size 86 | self.input_size = image_size[0] * image_size[1] * image_size[2] 87 | 88 | self.lstm_cell_stack = LSTMCellStack(self.input_size, hidden_size) 89 | 90 | def forward(self, input: Tensor, hidden_state: Tuple[List[Tensor], List[Tensor]]=None) \ 91 | -> Tuple[Tensor, Tuple[List[Tensor], List[Tensor]]]: 92 | 93 | flattened = input.view((-1, self.input_size)) 94 | output, state = self.lstm_cell_stack(flattened, hidden_state) 95 | 96 | output_img = output.view((-1,) + self.image_size) 97 | return output_img, state 98 | 99 | def init_hidden(self, batch_size: int) -> Tuple[List[Tensor], List[Tensor]]: 100 | return self.lstm_cell_stack.init_hidden(batch_size) 101 | -------------------------------------------------------------------------------- /convlstm/package/convlstm_ch_pooling.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | from utils import ModelType 7 | from .convlstm import ConvLSTM#, ConvLSTMParams 8 | 9 | 10 | #class ConvLSTMChannelPoolingParams(ConvLSTMParams): 11 | # 12 | # def __init__(self, *args, **kwargs): 13 | # super(ConvLSTMChannelPoolingParams, self).__init__(*args, **kwargs) 14 | # self.model = ModelType.CH_POOLING 15 | 16 | 17 | class ConvLSTMChannelPooling(ConvLSTM): 18 | 19 | def __init__(self, out_dim, *args, **kargs): 20 | """ 21 | Extends the ConvLSTM model adding a 1x1 convolutional output layer 22 | that receives as input each channel from each ConvLSTMCell of the model. 
23 | See Shi et al., 'Convolutional LSTM Network: A Machine Learning Approach 24 | for Precipitation Nowcasting' for further explanation 25 | 26 | NB: return_all_layers setting is ignored and is always == False 27 | """ 28 | super(ConvLSTMChannelPooling, self).__init__(*args, **kargs) 29 | self.output_layer = nn.Conv2d(np.sum(self.hidden_dim), out_dim, (1,1)) 30 | 31 | 32 | #def save_params(self): 33 | # return ConvLSTMChannelPoolingParams((self.height, self.width), self.input_dim, self.hidden_dim, 34 | # self.kernel_size, self.num_layers, self.batch_first, 35 | # self.bias, self.return_all_layers, self.mode) 36 | 37 | def _forward_item(self, input_tensor, hidden_state): 38 | """ 39 | Parameters 40 | ---------- 41 | input_tensor: todo 42 | 4-D Tensor either of shape (b, c, h, w) or (c, b, h, w) 43 | hidden_state: todo 44 | Tuple of two 4-D Tensor of shape (b, c, h, w) 45 | 46 | Returns 47 | ------- 48 | output, hidden_state 49 | """ 50 | layer_output_list = [] 51 | cur_layer_input = input_tensor 52 | for index, cell in enumerate(self.cell_list): 53 | hidden_state[index] = cell(input=cur_layer_input, 54 | old_state=hidden_state[index]) 55 | cur_layer_input = hidden_state[index][0] 56 | layer_output_list.append(cur_layer_input) 57 | 58 | lstms_output = torch.cat([state[0] for state in hidden_state], dim=1) 59 | 60 | last_state_list = hidden_state 61 | 62 | #if not self.return_all_layers: 63 | # layer_output_list = layer_output_list[-1] 64 | 65 | output = self.output_layer(lstms_output) 66 | 67 | return output, last_state_list 68 | 69 | 70 | def _forward_sequence(self, input_tensor, hidden_state=None): 71 | """ 72 | Parameters 73 | ---------- 74 | input_tensor: todo 75 | 5-D Tensor either of shape (t, b, c, h, w) or (b, t, c, h, w) 76 | hidden_state: todo 77 | None. 
todo implement stateful 78 | 79 | Returns 80 | ------- 81 | last_state_list, layer_output 82 | """ 83 | if not self.batch_first: 84 | # (t, b, c, h, w) -> (b, t, c, h, w) 85 | input_tensor = input_tensor.permute(1, 0, 2, 3, 4) 86 | 87 | # Implement stateful ConvLSTM 88 | if hidden_state is not None: 89 | raise NotImplementedError() 90 | else: 91 | hidden_state = self.init_hidden(batch_size=input_tensor.size(0)) 92 | 93 | layer_output_list = [] 94 | last_state_list = [] 95 | 96 | seq_len = input_tensor.size(1) 97 | cur_layer_input = input_tensor 98 | 99 | # for each layer compute the entire sequence 100 | # and propagate it to the next layer 101 | for layer_idx in range(self.num_layers): 102 | 103 | h, c = hidden_state[layer_idx] 104 | output_inner = [] 105 | # compute the prediction for each item in the sequence 106 | for t in range(seq_len): 107 | 108 | h, c = self.cell_list[layer_idx](input_tensor=cur_layer_input[:, t, :, :, :], 109 | cur_state=[h, c]) 110 | output_inner.append(h) 111 | 112 | layer_output = torch.stack(output_inner, dim=1) 113 | cur_layer_input = layer_output 114 | 115 | layer_output_list.append(layer_output) 116 | last_state_list.append([h, c]) 117 | 118 | outputs = [] 119 | # concatenate each channel output of each layer at time t in 120 | # one tensor and feed it to the 1x1 convolutional output layer 121 | for t in range(seq_len): 122 | lstms_output_at_t = [tensor[:,t,:,:,:] for tensor in layer_output_list] 123 | concat = torch.cat(lstms_output_at_t, dim=1) 124 | output_t = self.output_layer(concat) 125 | outputs.append(output_t) 126 | 127 | return torch.stack(outputs, dim=1), last_state_list 128 | -------------------------------------------------------------------------------- /convlstm_autoencoder/convlstm_autoencoder.py: -------------------------------------------------------------------------------- 1 | from typing import List, Tuple 2 | 3 | import torch 4 | from torch import nn 5 | from torch import Tensor 6 | 7 | from convlstm import ConvLSTM, HiddenState 8 | 9 | 10 | class ConvLSTMAutoencoder(nn.Module): 11 | """ 12 | This model is an implementation of the 'autoencoder' convolutional LSTM 13 | model proposed in 'Convolutional LSTM Network: A Machine Learning Approach 14 | for Precipitation Nowcasting', Shi et al., 2015, http://arxiv.org/abs/1506.04214 15 | Instead of one decoding network, as proposed in the paper, this model has two 16 | decoding networks as in 'Unsupervised Learning of Video Representations using LSTMs', 17 | Srivastava et al., 2016. 18 | 19 | The encoding network receives a sequence of images and outputs its hidden state that 20 | should represent a compressed representation of the sequence. Its hidden state is then 21 | used as initial hidden state for the two decoding networks that use the information 22 | contained in it to respectively reconstruct the input sequence and to predict future 23 | frames. 
24 | """ 25 | 26 | def __init__(self, input_size: Tuple[int, int], input_dim: int, 27 | hidden_dim: List[int], kernel_size: List[Tuple[int, int]], 28 | batch_first: bool=True, bias: bool=True, decoding_steps: int=-1): 29 | super(ConvLSTMAutoencoder, self).__init__() 30 | self.decoding_steps = decoding_steps 31 | self.input_size = input_size 32 | self.input_dim = input_dim 33 | self.hidden_dim = hidden_dim 34 | self.kernel_size = kernel_size 35 | self.batch_first = batch_first 36 | self.num_layers = len(hidden_dim) 37 | 38 | self.encoder = ConvLSTM( 39 | input_size=input_size, 40 | input_dim=input_dim, 41 | hidden_dim=hidden_dim, 42 | kernel_size=kernel_size, 43 | num_layers=self.num_layers, 44 | batch_first=False, 45 | bias=bias, 46 | mode=ConvLSTM.SEQUENCE 47 | ) 48 | 49 | # reverse the order of hidden dimensions and kernels 50 | decoding_hidden_dim = list(reversed(hidden_dim)) 51 | decoding_kernel_size = list(reversed(kernel_size)) 52 | decoding_hidden_dim .append(input_dim) # NOTE: we need a num_of_decoding_layers = num_of_encoding_layers+1 53 | decoding_kernel_size.append((1,1)) # so we add a 1x1 ConvLSTM as last decoding layer 54 | 55 | self.input_reconstruction = ConvLSTM( 56 | input_size=input_size, 57 | input_dim=input_dim, 58 | hidden_dim=decoding_hidden_dim, 59 | kernel_size=decoding_kernel_size, 60 | num_layers=self.num_layers + 1, 61 | batch_first=False, 62 | bias=bias, 63 | mode=ConvLSTM.STEP_BY_STEP 64 | ) 65 | self.future_prediction = ConvLSTM( 66 | input_size=input_size, 67 | input_dim=input_dim, 68 | hidden_dim=decoding_hidden_dim, 69 | kernel_size=decoding_kernel_size, 70 | num_layers=self.num_layers + 1, 71 | batch_first=False, 72 | bias=bias, 73 | mode=ConvLSTM.STEP_BY_STEP 74 | ) 75 | 76 | def forward(self, input_sequence: Tensor) -> Tuple[Tensor]: 77 | sequence = input_sequence.transpose(0,1) if self.batch_first else input_sequence # always work in sequence-first mode 78 | sequence_len = sequence.size(0) 79 | 80 | steps = self.decoding_steps if self.decoding_steps != -1 else sequence_len 81 | 82 | # encode 83 | _, hidden_state = self.encoder(sequence) 84 | 85 | last_frame = sequence[-1, :] 86 | h_n, c_n = hidden_state 87 | representation = (h_n[-1], c_n[-1]) 88 | 89 | # decode for input reconstruction 90 | output_seq_recon = ConvLSTMAutoencoder._decode(self.input_reconstruction, last_frame, 91 | representation, steps) 92 | 93 | # decode for future prediction 94 | output_seq_pred = ConvLSTMAutoencoder._decode(self.future_prediction, last_frame, 95 | representation, steps) 96 | 97 | if self.batch_first: # if input was batch_first restore dimension order 98 | reconstruction = output_seq_recon.transpose(0,1) 99 | prediction = output_seq_pred .transpose(0,1) 100 | else: 101 | reconstruction = output_seq_recon 102 | prediction = output_seq_pred 103 | 104 | return (reconstruction, prediction) 105 | 106 | @staticmethod 107 | def _decode(decoder: ConvLSTM, last_frame: Tensor, representation: HiddenState, steps: int) -> Tensor: 108 | decoded_sequence = [] 109 | 110 | h_n, c_n = representation 111 | h_0, c_0 = decoder.init_hidden(last_frame.size(0)) 112 | h_0[0], c_0[0] = h_n, c_n 113 | 114 | state = (h_0, c_0) 115 | output = last_frame 116 | for t in range(steps): 117 | output, state = decoder(output, state) 118 | decoded_sequence.append(output) 119 | 120 | return torch.stack(decoded_sequence, dim=0) 121 | -------------------------------------------------------------------------------- /lstm_autoencoder/lstm_autoencoder.py: 
-------------------------------------------------------------------------------- 1 | from typing import Tuple, List 2 | 3 | import torch 4 | import torch.nn as nn 5 | from torch import Tensor 6 | 7 | from lstm_cell_stack.lstm_cell_stack import LSTMCellStack 8 | from lstm_stack.lstm_stack import LSTMStack 9 | 10 | 11 | class LSTMAutoencoder(nn.Module): 12 | """ 13 | Implementation of the model described in 'Unsupervised Learning of Video 14 | Representations using LSTMs', N. Srivastava, E. Mansimov, R. Salakhutdinov 15 | https://arxiv.org/pdf/1502.04681.pdf 16 | 17 | It is composed by an LSTM which acts as an encoder for a video sequence 18 | and one or multiple decoders (LSTM but possibly other models too) that, 19 | given the same input representations, execute various tasks. 20 | """ 21 | 22 | def __init__(self, input_size: int, hidden_size: List[int], 23 | batch_first: bool, decoding_steps=-1): 24 | super(LSTMAutoencoder, self).__init__() 25 | self.batch_first = batch_first 26 | self.input_size = input_size 27 | self.decoding_steps = decoding_steps 28 | 29 | self.encoder = LSTMStack( 30 | input_size=self.input_size, 31 | hidden_size=hidden_size, 32 | batch_first=False 33 | ) 34 | 35 | sizes = [self.input_size, *hidden_size] 36 | decoding_sizes = list(reversed(sizes)) 37 | self.input_reconstruction = LSTMCellStack( 38 | input_size=self.input_size, 39 | hidden_size=decoding_sizes 40 | ) 41 | 42 | # self.future_prediction = LSTMCellStack( 43 | # input_size=self.input_size, 44 | # hidden_size=decoding_sizes 45 | # ) 46 | 47 | def forward(self, input_sequence: Tensor) -> Tuple[Tensor, Tensor]: 48 | sequence = input_sequence.transpose(0,1) if self.batch_first else input_sequence # always work in sequence-first mode 49 | sequence_len = sequence.size(0) 50 | 51 | # encode 52 | _, hidden_state = self.encoder(sequence) # discard output, we are interested only in hidden state to initialize the decoders 53 | # LSTM state has shape (num_layers * num_directions, batch, hidden_size) = 54 | # (1, batch, hidden_size) but LSTMCell expects h and c to have shape 55 | # (batch, hidden_size), so we have to remove the first dimension 56 | h_n, c_n = hidden_state 57 | h_n_last, c_n_last = h_n[-1], c_n[-1] # take the last layer's hidden state ... 58 | representation = (h_n_last.squeeze(dim=0), c_n_last.squeeze(dim=0)) # ... 
and use it as compressed representation of what the model has seen so far 59 | 60 | #last_frame = sequence[-1, :] 61 | steps = self.decoding_steps if self.decoding_steps != -1 else sequence_len 62 | 63 | # decode for input reconstruction 64 | output_seq_recon = LSTMAutoencoder._decode(self.input_reconstruction, sequence, # last_frame, 65 | representation, steps) 66 | 67 | # decode for future prediction 68 | #output_seq_pred = LSTMAutoencoder._decode(self.future_prediction, last_frame, 69 | # representation, steps) 70 | 71 | if self.batch_first: # if input was batch_first restore dimension order 72 | reconstruction = output_seq_recon.transpose(0,1) 73 | # prediction = output_seq_pred .transpose(0,1) 74 | else: 75 | reconstruction = output_seq_recon 76 | # prediction = output_seq_pred 77 | 78 | return reconstruction # (reconstruction, prediction) 79 | 80 | @staticmethod 81 | def _decode(decoder: LSTMCellStack, input_sequence: Tensor, 82 | representation: Tuple[Tensor, Tensor], steps: int) -> Tensor: 83 | output_seq = [] 84 | #output = input 85 | sequence_reversed = input_sequence.flip(0) 86 | 87 | h_0, c_0 = decoder.init_hidden(input_sequence.size(1)) 88 | # use encoder's last layer hidden state to initalize decoders hidden state 89 | h_0[0], c_0[0] = representation[0], representation[1] 90 | 91 | state = (h_0, c_0) 92 | for t in range(steps): 93 | #output, state = decoder(output, state) 94 | output, state = decoder(sequence_reversed[t,:], state) 95 | output_seq.append(output) 96 | 97 | return torch.stack(output_seq, dim=0) # dim 0 because we are working with batch_first=False 98 | 99 | 100 | class ImgLSTMAutoencoder(nn.Module): 101 | 102 | def __init__(self, image_size: Tuple[int, int, int], hidden_size: List[int], 103 | batch_first: bool, decoding_steps=-1): 104 | super(ImgLSTMAutoencoder, self).__init__() 105 | self.image_size = image_size 106 | self.input_size = image_size[0] * image_size[1] * image_size[2] 107 | self.batch_first = batch_first 108 | 109 | self.lstm_autoencoder = LSTMAutoencoder(self.input_size, hidden_size, False, decoding_steps) 110 | 111 | def forward(self, input: Tensor) -> Tuple[Tensor, Tensor]: 112 | sequence = input.transpose(0,1) if self.batch_first else input # always work in sequence-first mode 113 | sequence_len = sequence.size(0) 114 | batch_size = sequence.size(1) 115 | 116 | flattened_sequence = input.view((sequence_len, batch_size, -1)) 117 | 118 | # reconstruction, prediction = self.lstm_autoencoder(flattened_sequence) 119 | reconstruction = self.lstm_autoencoder(flattened_sequence) 120 | 121 | sequence_shape = (sequence_len, batch_size,) + self.image_size 122 | reconstruction_img = reconstruction.view(sequence_shape) 123 | # prediction_img = prediction .view(sequence_shape) 124 | 125 | recon_out = reconstruction_img.transpose(0,1) if self.batch_first else reconstruction_img 126 | # pred_out = prediction_img .transpose(0,1) if self.batch_first else prediction_img 127 | 128 | return recon_out # (recon_out, pred_out) 129 | -------------------------------------------------------------------------------- /convlstm/test/tests.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import random 3 | import time 4 | 5 | import torch 6 | from ..package import convlstm, convlstm_cpp, convlstm_cuda 7 | 8 | class TestConvLSTMCPP(unittest.TestCase): 9 | 10 | def setUp(self): 11 | input_size = (30,30) 12 | input_dim = 1 13 | hidden_dim = 1 14 | kernel_size = (5,5) 15 | bias = True 16 | 17 | self.convlstm_cell = 
convlstm.ConvLSTMCell (input_size, input_dim, hidden_dim, kernel_size, bias) 18 | self.convlstmcpp_cell = convlstm_cpp.ConvLSTMCPPCell(input_size, input_dim, hidden_dim, kernel_size, bias) 19 | 20 | # The two models must have the same initial conditions 21 | self.convlstmcpp_cell.weights.data = self.convlstm_cell.conv.weight.data 22 | self.convlstmcpp_cell.bias.data = self.convlstm_cell.conv.bias.data 23 | 24 | batch_size = 16 25 | self.input = torch.rand((batch_size, input_dim) + input_size) 26 | self.state = (torch.rand((batch_size, hidden_dim) + input_size), torch.rand((batch_size, hidden_dim) + input_size)) 27 | 28 | 29 | def tearDown(self): 30 | self.convlstm_cell = None 31 | self.convlstmcpp_cell = None 32 | self.input = None 33 | self.state = None 34 | 35 | def test_forward(self): 36 | output = self.convlstm_cell(self.input, self.state) 37 | output_cpp = self.convlstmcpp_cell(self.input, self.state) 38 | self.assertTrue(torch.equal(output[0], output_cpp[0]), 'The two output tensors are not equal') 39 | 40 | def test_backpropagation(self): 41 | output = self.convlstm_cell(self.input, self.state) 42 | output_cpp = self.convlstmcpp_cell(self.input, self.state) 43 | 44 | ground_truth = torch.full_like(output[0], 20) 45 | criterion = torch.nn.MSELoss() 46 | loss = criterion(output[0], ground_truth) 47 | loss_cpp = criterion(output_cpp[0], ground_truth) 48 | 49 | loss.backward() 50 | loss_cpp.backward() 51 | 52 | weight_grad = self.convlstm_cell.conv.weight._grad 53 | weight_grad_cpp = self.convlstmcpp_cell.weights._grad 54 | 55 | self.assertTrue(torch.allclose(weight_grad, weight_grad_cpp), 'The two weights gradients are not equal') 56 | 57 | 58 | def test_bptt(self): 59 | time_steps = 100 #random.randint(1,20) 60 | criterion = torch.nn.MSELoss() 61 | 62 | # python model 63 | start = time.perf_counter() 64 | output = self.convlstm_cell(self.input, self.state) 65 | for _ in range(time_steps): 66 | output = self.convlstm_cell(output[0], self.state) 67 | 68 | ground_truth = torch.full_like(output[0], 20) 69 | 70 | loss = criterion(output[0], ground_truth) 71 | loss.backward() 72 | grad = self.convlstm_cell.conv.weight._grad 73 | print('Python time: {}'.format(time.perf_counter() - start)) 74 | 75 | # c++ model 76 | start = time.perf_counter() 77 | output_cpp = self.convlstmcpp_cell(self.input, self.state) 78 | for _ in range(time_steps): 79 | output_cpp = self.convlstmcpp_cell(output_cpp[0], self.state) 80 | 81 | loss_cpp = criterion(output_cpp[0], ground_truth) 82 | loss_cpp.backward() 83 | grad_cpp = self.convlstmcpp_cell.weights._grad 84 | print('C++ time: {}'.format(time.perf_counter() - start)) 85 | 86 | self.assertTrue(torch.allclose(grad, grad_cpp), 'The two gradients are not equal') 87 | 88 | 89 | class TestConvLSTMCuda(unittest.TestCase): 90 | 91 | def setUp(self): 92 | input_size = (30,30) 93 | input_dim = 1 94 | hidden_dim = 1 95 | kernel_size = (5,5) 96 | bias = True 97 | 98 | self.convlstm_cell = convlstm.ConvLSTMCell (input_size, input_dim, hidden_dim, kernel_size, bias) 99 | self.convlstmcuda_cell = convlstm_cuda.ConvLSTMCudaCell(input_size, input_dim, hidden_dim, kernel_size, bias) 100 | 101 | # The two models must have the same initial conditions 102 | self.convlstmcuda_cell.weights.data = self.convlstm_cell.conv.weight.data 103 | self.convlstmcuda_cell.bias.data = self.convlstm_cell.conv.bias.data 104 | 105 | batch_size = 16 106 | self.input = torch.rand((batch_size, input_dim) + input_size) 107 | self.state = (torch.rand((batch_size, hidden_dim) + input_size), 
torch.rand((batch_size, hidden_dim) + input_size)) 108 | 109 | 110 | def tearDown(self): 111 | self.convlstm_cell = None 112 | self.convlstmcuda_cell = None 113 | self.input = None 114 | self.state = None 115 | 116 | def test_forward(self): 117 | output = self.convlstm_cell(self.input, self.state) 118 | output_cpp = self.convlstmcuda_cell(self.input, self.state) 119 | self.assertTrue(torch.equal(output[0], output_cpp[0]), 'The two output tensors are not equal') 120 | 121 | def test_backpropagation(self): 122 | output = self.convlstm_cell(self.input, self.state) 123 | output_cpp = self.convlstmcuda_cell(self.input, self.state) 124 | 125 | ground_truth = torch.full_like(output[0], 20) 126 | criterion = torch.nn.MSELoss() 127 | loss = criterion(output[0], ground_truth) 128 | loss_cpp = criterion(output_cpp[0], ground_truth) 129 | 130 | loss.backward() 131 | loss_cpp.backward() 132 | 133 | weight_grad = self.convlstm_cell.conv.weight._grad 134 | weight_grad_cpp = self.convlstmcuda_cell.weights._grad 135 | 136 | self.assertTrue(torch.allclose(weight_grad, weight_grad_cpp), 'The two weights gradients are not equal') 137 | 138 | 139 | def test_bptt(self): 140 | time_steps = 100 #random.randint(1,20) 141 | criterion = torch.nn.MSELoss() 142 | 143 | # python model 144 | start = time.perf_counter() 145 | output = self.convlstm_cell(self.input, self.state) 146 | for _ in range(time_steps): 147 | output = self.convlstm_cell(output[0], self.state) 148 | 149 | ground_truth = torch.full_like(output[0], 20) 150 | 151 | loss = criterion(output[0], ground_truth) 152 | loss.backward() 153 | grad = self.convlstm_cell.conv.weight._grad 154 | print('Python time: {}'.format(time.perf_counter() - start)) 155 | 156 | # c++ model 157 | start = time.perf_counter() 158 | output_cpp = self.convlstmcuda_cell(self.input, self.state) 159 | for _ in range(time_steps): 160 | output_cpp = self.convlstmcuda_cell(output_cpp[0], self.state) 161 | 162 | loss_cpp = criterion(output_cpp[0], ground_truth) 163 | loss_cpp.backward() 164 | grad_cpp = self.convlstmcuda_cell.weights._grad 165 | print('C++ time: {}'.format(time.perf_counter() - start)) 166 | 167 | self.assertTrue(torch.allclose(grad, grad_cpp), 'The two gradients are not equal') 168 | 169 | 170 | 171 | 172 | 173 | if __name__ == '__main__': 174 | unittest.main() -------------------------------------------------------------------------------- /convlstm/src/convlstm.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include 4 | #include 5 | #include 6 | #include 7 | 8 | const int64_t stride = 1; 9 | const int64_t padding = 2; 10 | const int64_t dilation = 1; 11 | 12 | /** 13 | * -------------------------------------------------------------------------- 14 | * | WARNING!!! 
It only works for stride=1, padding=2, dilation=1, groups=1 | 15 | * -------------------------------------------------------------------------- 16 | */ 17 | std::vector convlstm_forward(const at::Tensor &input, const at::Tensor &weights, const at::Tensor &bias, 18 | const at::Tensor &old_h, const at::Tensor &old_cell) { 19 | const auto X = at::cat({input, old_h}, /*dim=*/1); // assuming channels-first ordering 20 | 21 | const auto gate_weights = at::conv2d(X, weights, bias, stride, padding); // FIXME: had to put a fixed size padding because of some integer overflow issues occurring when computing weights.size(2)/int64_t(2) 22 | const auto gate_size = gate_weights.size(1)/4; 23 | const auto input_gate = at::sigmoid(gate_weights.narrow(/*dim=*/1, 0, gate_size)); 24 | const auto forget_gate = at::sigmoid(gate_weights.narrow(/*dim=*/1, gate_size, gate_size)); 25 | const auto output_gate = at::sigmoid(gate_weights.narrow(/*dim=*/1, 2*gate_size, gate_size)); 26 | const auto candidate_cell = at::tanh(gate_weights.narrow(/*dim=*/1, 3*gate_size, gate_size)); 27 | 28 | const auto c_next = forget_gate * old_cell + input_gate * candidate_cell; 29 | const auto h_next = output_gate * at::tanh(c_next); 30 | 31 | return {h_next, c_next, input_gate, forget_gate, output_gate, 32 | candidate_cell, X, gate_weights}; 33 | } 34 | 35 | 36 | at::Tensor d_sigmoid(at::Tensor z) { 37 | auto s = at::sigmoid(z); 38 | return (1 - s) * s; 39 | } 40 | 41 | // tanh'(z) = 1 - tanh^2(z) 42 | at::Tensor d_tanh(at::Tensor z) { 43 | return 1 - z.tanh().pow(2); 44 | } 45 | 46 | 47 | /** 48 | * Computes the gradient of Y = at::conv2d(X, W) function 49 | * $$ Y = X \star W $$ 50 | * 51 | * ---------------------------------------------------------------------------- 52 | * | WARNING!!! It does not work for generic 2D convolution. It works only for | 53 | * | stride=1, padding=2, dilation=1, groups=1 | 54 | * --------------------------------------------------------------------------- 55 | * 56 | * @conv_out_grad: gradient of the loss function wrt the output of conv2d, 57 | * i.e. $$ \frac{\partial \mathcal{L}}{\partial Y} $$ 58 | * @input: input of conv2d, i.e. $X$ 59 | * @weights: tensor of weights used in conv2d, i.e. $W$ 60 | * @return: a tuple containing the weights gradients and the convolution input gradients, 61 | * i.e. 
$$ < \frac{\partial \mathcal{L}}{\partial W}, \frac{\partial 62 | * \mathcal{L}}{X} > $$ 63 | */ 64 | std::tuple d_conv2d(const at::Tensor &conv_out_grad, const at::Tensor &input, const at::Tensor &weights) { 65 | 66 | const auto batch_size = input.size(0); 67 | const auto out_channels = conv_out_grad.size(1); 68 | const auto in_channels = input.size(1); 69 | 70 | std::vector d_weights_sizes{out_channels, in_channels, weights.size(2), weights.size(3)}; 71 | auto d_weights = at::zeros(d_weights_sizes, weights.type()); 72 | for (auto batch_index = 0; batch_index < batch_size; ++batch_index) { // Compute loss function derivatives wrt weights for each data sample in the batch: $$ \frac{\partial L}{\partial W_{j,i}} = X_i \star \frac{\partial L}{\partial Y_j} $$ 73 | const auto X = input.narrow(0, batch_index, 1); 74 | const auto dY = conv_out_grad.narrow(0, batch_index, 1).permute({1,0,2,3}); // permute from (1, out_channels, ker_h, ker_w) to (out_channels, 1, ker_h, ker_w) 75 | 76 | std::vector d_weights_in(in_channels); 77 | auto d_weights_batch = at::zeros_like(d_weights); 78 | for (auto i = 0; i < in_channels; ++i) { // for each input channels compute the gradient wrt weights obtaining a 79 | d_weights_in[i] = at::conv2d(X.narrow(1, i, 1), dY, {}, // 80 | stride, padding).squeeze(); // (1, out_channels, ker_h, ker_w) shaped tensor and remove the first dimension 81 | } 82 | d_weights += at::stack(d_weights_in, /*dim=*/1); 83 | 84 | //const auto X = input.narrow(0, batch_index, 1).permute({1,0,2,3}); 85 | //const auto dY = conv_out_grad.narrow(0, batch_index, 1).permute({1,0,2,3}); 86 | //d_weights += at::conv2d(dY, X, {}, stride, padding); 87 | } 88 | 89 | /* 90 | * rotate the kernels of 180° (see here 91 | * https://medium.com/@2017csm1006/forward-and-backpropagation-in-convolutional-neural-network-4dfa96d7b37e 92 | * and here https://grzegorzgwardys.wordpress.com/2016/04/22/8/) 93 | * 94 | * conv_transpose2d() performs the convolution rotating the kernel by 180° 95 | */ 96 | auto dX = at::conv_transpose2d(conv_out_grad, weights, {}, stride, padding); 97 | 98 | return {d_weights, dX}; 99 | } 100 | 101 | 102 | std::vector convlstm_backward(const at::Tensor &grad_h, const at::Tensor &grad_cell, //these are the gradients coming from timestep t+1 103 | const at::Tensor &new_cell, const at::Tensor &old_cell, // all variables that are part of the cell state 104 | const at::Tensor &input_gate, const at::Tensor &forget_gate, 105 | const at::Tensor &output_gate, const at::Tensor &candidate_cell, 106 | const at::Tensor &X, const at::Tensor &gate_weights, 107 | const at::Tensor &weights) { 108 | auto d_output_gate = at::tanh(new_cell) * grad_h; 109 | const auto d_tanh_new_cell = output_gate * grad_h; 110 | const auto d_new_cell = d_tanh(new_cell) * d_tanh_new_cell + grad_cell; 111 | 112 | const auto d_old_cell = forget_gate * d_new_cell; 113 | auto d_input_gate = candidate_cell * d_new_cell; 114 | auto d_forget_gate = old_cell * d_new_cell; 115 | auto d_candidate_cell = input_gate * d_new_cell; 116 | 117 | auto gate_size = gate_weights.size(1)/4; 118 | d_input_gate *= d_sigmoid(gate_weights.narrow(1, 0, gate_size)); 119 | d_forget_gate *= d_sigmoid(gate_weights.narrow(1, gate_size, gate_size)); 120 | d_output_gate *= d_sigmoid(gate_weights.narrow(1, 2*gate_size, gate_size)); 121 | d_candidate_cell *= d_tanh(gate_weights.narrow(1, 3*gate_size, gate_size)); 122 | 123 | const auto d_gates = at::cat({d_input_gate, d_forget_gate, d_output_gate, d_candidate_cell}, /*dim=*/1); 124 | 125 | const auto 
gradients = d_conv2d(d_gates, X, weights); 126 | const auto d_weights = std::get<0>(gradients); 127 | const auto d_X = std::get<1>(gradients); 128 | 129 | auto d_bias = d_gates.sum(/*dim=*/{2,3}, /*keepdim=*/false).sum(/*dim=*/0); // CHECK: devo sommare o mediare sulla batch dimension? Pare che la somma vada bene 130 | 131 | const auto input_size = d_X.size(1) - grad_h.size(1); 132 | const auto d_input = d_X.slice(/*dim=*/1, 0, input_size); 133 | const auto d_old_h = d_X.slice(/*dim=*/1, input_size); 134 | 135 | return {d_old_h, d_input, d_weights, d_bias, d_old_cell}; 136 | } 137 | 138 | 139 | 140 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 141 | m.def("forward", &convlstm_forward, "ConvLSTM forward"); 142 | m.def("backward", &convlstm_backward, "ConvLSTM backward"); 143 | } -------------------------------------------------------------------------------- /convlstm/package/convlstm.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Union, Tuple, List 2 | from enum import Enum 3 | 4 | import torch 5 | import torch.nn as nn 6 | from torch import Tensor 7 | 8 | 9 | # typedefs 10 | HiddenState = Tuple[Tensor, Tensor] 11 | HiddenStateStacked = Tuple[List[Tensor], List[Tensor]] 12 | 13 | 14 | class ConvLSTMCell(nn.Module): 15 | 16 | def __init__(self, input_size: Tuple[int, int], input_dim: int, hidden_dim: int, 17 | kernel_size: Tuple[int, int], bias: bool, hidden_activation=torch.sigmoid, 18 | output_activation=torch.tanh): 19 | """ 20 | Initialize ConvLSTM cell. 21 | 22 | Args: 23 | @input_size: Height and width of input tensor as (height, width). 24 | @input_dim: Number of channels of input tensor. 25 | @hidden_dim: Number of channels of hidden state. 26 | @kernel_size: Size of the convolutional kernel. 27 | @bias: Whether or not to add the bias. 28 | """ 29 | super(ConvLSTMCell, self).__init__() 30 | 31 | self.height, self.width = input_size 32 | self.input_dim = input_dim 33 | self.hidden_dim = hidden_dim 34 | 35 | self.kernel_size = kernel_size 36 | self.padding = kernel_size[0] // 2, kernel_size[1] // 2 37 | self.bias = bias 38 | 39 | self.conv = nn.Conv2d(in_channels=self.input_dim + self.hidden_dim, 40 | out_channels=4 * self.hidden_dim, 41 | kernel_size=self.kernel_size, 42 | padding=self.padding, 43 | bias=self.bias) 44 | 45 | self.output_activation = output_activation 46 | self.hidden_activation = hidden_activation 47 | 48 | def forward(self, input: Tensor, hx: HiddenState=None) -> HiddenState: 49 | """ 50 | Inputs: input, (h_0, c_0) 51 | - **input** of shape `(batch, input_dim, height, width)`: 52 | tensor containing input features 53 | - **h_0** of shape `(batch, hidden_dim, height, width)`: 54 | tensor containing the initial hidden state for each element in the batch. 55 | - **c_0** of shape `(batch, hidden_dim, height, width)`: 56 | tensor containing the initial cell state for each element in the batch. 57 | 58 | If `(h_0, c_0)` is not provided, both **h_0** and **c_0** default to zero. 
59 | 60 | Outputs: h_1, c_1 61 | - **h_1** of shape `(batch, hidden_dim, height, width)`: 62 | tensor containing the next hidden state for each element in the batch 63 | - **c_1** of shape `(batch, hidden_dim, height, width)`: 64 | tensor containing the next cell state for each element in the batch 65 | """ 66 | if not hx: 67 | hx = self.init_hidden(input.size(0)) 68 | 69 | old_h, old_cell = hx 70 | 71 | combined = torch.cat([input, old_h], dim=1) # concatenate along channel axis 72 | 73 | gates_activations = self.conv(combined) 74 | cc_i, cc_f, cc_o, cc_g = gates_activations.chunk(4, dim=1) 75 | i = self.hidden_activation(cc_i) # torch.sigmoid(cc_i) 76 | f = self.hidden_activation(cc_f) # torch.sigmoid(cc_f) 77 | o = self.hidden_activation(cc_o) # torch.sigmoid(cc_o) 78 | g = torch.tanh(cc_g) 79 | 80 | c_next = f * old_cell + i * g 81 | h_next = o * self.output_activation(c_next) # torch.tanh(c_next) 82 | 83 | return h_next, c_next 84 | 85 | def init_hidden(self, batch_size: int) -> HiddenState: 86 | dtype = self.conv.weight.dtype 87 | device = self.conv.weight.device 88 | shape = (batch_size, self.hidden_dim, self.height, self.width) 89 | h = torch.zeros(shape, dtype=dtype).to(device) 90 | return (h, h) 91 | 92 | 93 | #class ConvLSTMParams(): 94 | # 95 | # def __init__(self, input_size: Tuple[int, int], input_dim: int, hidden_dim: int, 96 | # kernel_size: Tuple[int, int], num_layers: int, 97 | # batch_first: bool=False, bias: bool=True, mode: str='sequence'): 98 | # self.input_size = input_size 99 | # self.input_dim = input_dim 100 | # self.hidden_dim = hidden_dim 101 | # self.kernel_size = kernel_size 102 | # self.num_layers = num_layers 103 | # self.batch_first = batch_first 104 | # self.bias = bias 105 | # self.mode = mode 106 | # self.model = ModelType.CONVLSTM 107 | 108 | 109 | class ConvLSTM(nn.Module): 110 | """ 111 | 2D convolutional LSTM model. 112 | 113 | Parameters 114 | ---------- 115 | input_size: (int, int) 116 | Height and width of input tensor as (height, width). 117 | input_dim: int 118 | Number of channels of each hidden state 119 | hidden_dim: list of int 120 | Number of channels of hidden state. 121 | kernel_size: list of (int, int) 122 | Size of each convolutional kernel. 123 | num_layers: int 124 | number of convolutional LSTM layers 125 | batch_first: bool (default False) 126 | input tensor order: (batch_size, sequence_len, channels, height, 127 | width) if batch_first == True, (sequence_len, batch_size, channels, 128 | height, width) otherwise 129 | bias: bool (default True) 130 | Whether or not to add the bias. 131 | mode: either 'sequence' or 'item' (default 'sequence') 132 | if 'sequence' forward() accepts an input tensor of shape 133 | (batch_size, sequence_len, channels, height, width) and outputs a 134 | tensor of the same shape; 135 | if 'item' the model processes one sequence element at a time, 136 | therefore forward accepts an input tensor of shape (batch_size, 137 | sequence_len, channels, height, width) and outputs a tensor of the 138 | same shape. 
139 | When using 'item' mode you should take care of feeing forward() with 140 | the output of init_hidden() when processing the first element of the 141 | sequence 142 | """ 143 | SEQUENCE = 'sequence' 144 | STEP_BY_STEP = 'step-by-step' 145 | 146 | def __init__(self, input_size: Tuple[int, int], input_dim: int, hidden_dim: List[int], 147 | kernel_size: List[Tuple[int, int]], num_layers: int, batch_first: bool=False, 148 | bias: bool=True, mode: str='sequence'): 149 | super(ConvLSTM, self).__init__() 150 | 151 | self._check_kernel_size_consistency(kernel_size) 152 | 153 | # Make sure that both `kernel_size` and `hidden_dim` are lists having len == num_layers 154 | kernel_size = self._extend_for_multilayer(kernel_size, num_layers) 155 | hidden_dim = self._extend_for_multilayer(hidden_dim, num_layers) 156 | if not len(kernel_size) == len(hidden_dim) == num_layers: 157 | raise ValueError('Inconsistent list length.') 158 | 159 | self.height, self.width = input_size 160 | self.mode = mode 161 | 162 | self.input_dim = input_dim 163 | self.hidden_dim = hidden_dim 164 | self.kernel_size = kernel_size 165 | self.num_layers = num_layers 166 | self.batch_first = batch_first 167 | self.bias = bias 168 | 169 | cell_list = [] 170 | dims = [input_dim, *hidden_dim] 171 | for i in range(num_layers): 172 | layer = ConvLSTMCell(input_size=(self.height, self.width), 173 | input_dim=dims[i], 174 | hidden_dim=dims[i+1], 175 | kernel_size=self.kernel_size[i], 176 | bias=self.bias) 177 | cell_list.append(layer) 178 | self.cell_list = nn.ModuleList(cell_list) 179 | 180 | self.set_mode(mode) 181 | 182 | def set_mode(self, mode: str) -> str: 183 | old_mode = self.mode 184 | self.mode = mode 185 | if mode == ConvLSTM.SEQUENCE: 186 | self.forward = self._forward_sequence 187 | elif mode == ConvLSTM.STEP_BY_STEP: 188 | self.forward = self._forward_item 189 | else: 190 | raise ValueError("Parameter 'mode' can only be either 'sequence' or 'item'.") 191 | 192 | return old_mode 193 | 194 | def _forward_sequence(self, input: Tensor, hidden_state: HiddenStateStacked=None) \ 195 | -> Tuple[Tensor, HiddenStateStacked]: 196 | """ 197 | Inputs: input, (h_0, c_0) 198 | - **input** either of shape `(seq_len, batch, input_dim, height, width)` 199 | or `(batch, seq_len, channels, height, width)`: tensor containing 200 | the features of the input sequence. 201 | - **h_0** list of size num_layers that contains tensors of shape 202 | `(batch, channels, height, width)`: tensor containing the initial 203 | hidden state for each element in the batch and for each layer in the model. 204 | - **c_0** list of size num_layers that contains tensors of shape 205 | `(batch, channels, height, width)`: tensor containing the initial 206 | cell state for each element in the batch and for each layer in the model. 207 | 208 | If `(h_0, c_0)` is not provided, both **h_0** and **c_0** default to zero. 209 | 210 | Outputs: output, (h_n, c_n) 211 | - **output** of shape `(batch, seq_len, channels, height, width)`: 212 | tensor containing the output features `(h_t)` from the last layer 213 | of the ConvLSTM, for each t. 214 | - **h_n** list of size num_layers that contains tensors of shape 215 | `(batch, channels, height, width)`: tensor containing the hidden 216 | state for `t = seq_len`. 217 | - **c_n** list of size num_layers that contains tensors of shape 218 | `(batch, channels, height, width)`: tensor containing the cell 219 | state for `t = seq_len`. 
220 | """ 221 | # (b, t, c, h, w) -> (t, b, c, h, w) 222 | input_seq = input.transpose(0, 1) if self.batch_first else input 223 | 224 | if hidden_state is None: 225 | hidden_state = self.init_hidden(batch_size=input_seq.size(1)) 226 | 227 | seq_len = input_seq.size(0) 228 | 229 | h_0, c_0 = hidden_state 230 | h_n_list, c_n_list = [], [] 231 | 232 | prev_layer_output = list(torch.unbind(input_seq)) # [tensor.squeeze(1) for tensor in input_seq.split(1, dim=1)] 233 | for l, cell in enumerate(self.cell_list): 234 | state = (h_0[l], c_0[l]) 235 | for t in range(seq_len): 236 | state = cell(prev_layer_output[t], state) 237 | prev_layer_output[t] = state[0] 238 | h_n_list.append(state[0]) 239 | c_n_list.append(state[1]) 240 | 241 | output = torch.stack(prev_layer_output, dim=0) # stack over time -> (t, b, c, h, w); transposed below when batch_first 242 | 243 | if self.batch_first: 244 | return output.transpose(0, 1), (h_n_list, c_n_list) 245 | 246 | return output, (h_n_list, c_n_list) 247 | 248 | def _forward_item(self, input: Tensor, hidden_state: HiddenStateStacked=None) \ 249 | -> Tuple[Tensor, HiddenStateStacked]: 250 | """ 251 | Inputs: input, (h_0, c_0) 252 | - **input** of shape `(batch, input_dim, height, width)`: 253 | tensor containing the input image. 254 | - **h_0** list of size num_layers that contains tensors of shape 255 | `(batch, channels, height, width)`: tensor containing 256 | the initial hidden state for each element in the batch and for 257 | each layer in the model. 258 | - **c_0** list of size num_layers that contains tensors of shape 259 | `(batch, channels, height, width)`: tensor containing the initial 260 | cell state for each element in the batch and for each layer in the model. 261 | 262 | If `(h_0, c_0)` is not provided, both **h_0** and **c_0** default to zero. 263 | 264 | Outputs: output, (h_n, c_n) 265 | - **output** of shape `(batch, channels, height, width)`: 266 | tensor containing the output features `(h_t)` from the last 267 | layer of the LSTM. 268 | - **h_n** list of size num_layers that contains tensors of shape 269 | `(batch, channels, height, width)`: tensor containing the hidden 270 | state for each layer. 
271 | - **c_n** list of size num_layers that contains tensors of shape 272 | (batch, channels, height, width): tensor containing the cell state 273 | """ 274 | if hidden_state is None: 275 | hidden_state = self.init_hidden(batch_size=input.size(0)) 276 | 277 | output = input 278 | h_0, c_0 = hidden_state 279 | h_n_list, c_n_list = [], [] 280 | for l, cell in enumerate(self.cell_list): 281 | h_n, c_n = cell(output, (h_0[l], c_0[l])) 282 | output = h_n 283 | h_n_list.append(h_n) 284 | c_n_list.append(c_n) 285 | 286 | return output, (h_n_list, c_n_list) 287 | 288 | def init_hidden(self, batch_size: int) -> HiddenStateStacked: 289 | h_0, c_0 = [], [] 290 | for cell in self.cell_list: 291 | h, c = cell.init_hidden(batch_size) 292 | h_0.append(h) 293 | c_0.append(c) 294 | return (h_0, c_0) # NOTE: using a list to allow hidden states of different sizes 295 | 296 | @staticmethod 297 | def _check_kernel_size_consistency(kernel_size: Tuple[int, int]): 298 | if not (isinstance(kernel_size, tuple) or 299 | (isinstance(kernel_size, list) and 300 | all([isinstance(elem, tuple) for elem in kernel_size]))): 301 | raise ValueError('`kernel_size` must be tuple or list of tuples') 302 | 303 | @staticmethod 304 | def _extend_for_multilayer(param, num_layers: Union[List[int], int]) -> List[int]: 305 | if not isinstance(param, list): 306 | param = [param] * num_layers 307 | return param --------------------------------------------------------------------------------
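
Usage sketch (illustrative only). The README describes the models in prose, so the snippet below shows how the two main entry points might be called. It is a minimal sketch, not a file from the repository: it assumes the repository root is on `sys.path` (the modules import each other with absolute imports such as `from convlstm import ConvLSTM`), that the `convlstmcpp` extension from `convlstm/src/setup.py` has been built so that importing the `convlstm` package succeeds, and the shapes and layer sizes are arbitrary examples.

import torch

from convlstm import ConvLSTM                                  # resolves to convlstm/__init__.py
from lstm_autoencoder.lstm_autoencoder import LSTMAutoencoder

# ConvLSTM in 'sequence' mode: feed a whole (batch, time, channels, height, width)
# clip and get back the last layer's per-frame outputs plus the final (h, c) of every layer.
convlstm = ConvLSTM(input_size=(30, 30), input_dim=1, hidden_dim=[8, 8],
                    kernel_size=[(5, 5), (5, 5)], num_layers=2,
                    batch_first=True, mode=ConvLSTM.SEQUENCE)
frames = torch.rand(4, 10, 1, 30, 30)    # (batch, time, channels, height, width)
output, (h_n, c_n) = convlstm(frames)    # h_n, c_n: lists with one (4, 8, 30, 30) tensor per layer

# LSTMAutoencoder: encode a (time, batch, features) sequence, then decode a reconstruction
# of it (the future-prediction decoder is commented out in the current source).
autoencoder = LSTMAutoencoder(input_size=64, hidden_size=[128, 64], batch_first=False)
sequence = torch.rand(10, 4, 64)         # (time, batch, features)
reconstruction = autoencoder(sequence)   # same shape as `sequence`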