├── LICENSE ├── README.md ├── convlstm.py ├── convlstm_decoder.py └── figures ├── BCI_decoder.png ├── BCI_system.png ├── ConvLSTM_cell.png ├── ConvLSTM_definition.png └── input_output_decoder.png /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Nguyen Thi Kim Uyen 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## ConvLSTM Pytorch Implementation 2 |
3 | - [Goal](#goal)
4 | - [Example of using ConvLSTM](#example-of-using-convlstm)
5 | - [Explanation](#explanation)
6 | - [1. ConvLSTM definition](#1-convlstm-definition)
7 | - [2. Bidirectional ConvLSTM decoder](#2-bidirectional-convlstm-decoder)
8 | - [3. Input, output for decoder](#3-input-output-for-decoder)
9 | - [Environment](#environment)
10 | - [References](#references)
11 |
12 | ## Goal
13 | The ConvLSTM model is mainly used as the skeleton for a BCI (Brain-Computer Interface) decoder in our project, which decodes kinematic signals from neural signals.
14 | This repo is a PyTorch implementation of ConvLSTM. The implementation follows the paper: Convolutional LSTM Network: A Machine Learning Approach for Precipitation Nowcasting [1].
15 | ![](/figures/BCI_system.png)
16 | The BCI decoder is one component of the BCI system, as shown in the figure above.
17 |
18 | ## Example of using ConvLSTM
19 | convlstm_decoder.py contains an example of defining a ConvLSTM decoder.
20 |
21 | Here is an example of defining a one-layer bidirectional ConvLSTM:
22 | ```
23 | import convlstm
24 |
25 | convlstm_layer = []
26 | img_size_list = [(10, 10)]
27 | num_layers = 1           # number of layers
28 | input_channel = 96       # the number of electrodes in the Utah array
29 | hidden_channels = [256]  # the output channels for each layer
30 | kernel_size = [(7, 7)]   # kernel size per layer; stride and padding are handled internally by ConvLSTMCell
31 | for i in range(num_layers):
32 |     layer = convlstm.ConvLSTM(img_size=img_size_list[i],
33 |                               input_dim=input_channel,
34 |                               hidden_dim=hidden_channels[i],
35 |                               kernel_size=kernel_size[i],
36 |                               cnn_dropout=0.2,
37 |                               rnn_dropout=0.,
38 |                               batch_first=True,
39 |                               bias=True,
40 |                               peephole=False,
41 |                               layer_norm=False,
42 |                               return_sequence=True,
43 |                               bidirectional=True)
44 |     convlstm_layer.append(layer)
45 |     # a bidirectional layer doubles the output channels
46 |     input_channel = hidden_channels[i] * 2
47 | ```
48 |
49 | ## Explanation
50 | The implementation was initially derived from [this repo](https://github.com/ndrplz/ConvLSTM_pytorch).
51 |
52 | However, I changed the source to follow the original paper [1] more closely.
53 |
54 | ### 1. ConvLSTM definition
55 | The ConvLSTM is defined in the paper as follows:
56 | ![](/figures/ConvLSTM_definition.png)
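57 |
58 | For reference, here is a text transcription of the cell equations from [1], where `*` denotes the convolution operator and `∘` the Hadamard product:
59 | ```
60 | i_t = σ(W_xi * X_t + W_hi * H_{t-1} + W_ci ∘ C_{t-1} + b_i)
61 | f_t = σ(W_xf * X_t + W_hf * H_{t-1} + W_cf ∘ C_{t-1} + b_f)
62 | C_t = f_t ∘ C_{t-1} + i_t ∘ tanh(W_xc * X_t + W_hc * H_{t-1} + b_c)
63 | o_t = σ(W_xo * X_t + W_ho * H_{t-1} + W_co ∘ C_t + b_o)
64 | H_t = o_t ∘ tanh(C_t)
65 | ```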
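66 |
67 | In convlstm.py, the four gates are not computed with four separate convolutions: each stream (the input and the hidden state) goes through a single convolution with `4 * hidden_dim` output channels, which is then split into the `i`, `f`, `c`, `o` contributions. A minimal sketch of the idea (the tensor sizes here just mirror the example configuration):
68 | ```
69 | import torch
70 | import torch.nn as nn
71 |
72 | hidden_dim = 256
73 | # one convolution produces the input contributions of all four gates at once
74 | input_conv = nn.Conv2d(96, 4 * hidden_dim, kernel_size=7, padding=3)
75 |
76 | x_conv = input_conv(torch.rand(8, 96, 10, 10))
77 | x_i, x_f, x_c, x_o = torch.split(x_conv, hidden_dim, dim=1)
78 | print(x_i.shape)  # torch.Size([8, 256, 10, 10])
79 | ```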
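80 |
81 | Putting it together, here is a minimal sketch of pushing a dummy batch through the bidirectional layer defined in the example above (the shapes are assumptions that follow from that configuration; the internal padding keeps the 10x10 grid):
82 | ```
83 | import torch
84 |
85 | # dummy input: (batch, time, channels, height, width)
86 | x = torch.rand((8, 5, 96, 10, 10))
87 |
88 | output, last_state, last_state_inv = convlstm_layer[0](x)
89 | # forward and backward hidden states are concatenated on the channel axis
90 | print(output.shape)  # torch.Size([8, 5, 512, 10, 10])
91 | ```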
92 |
93 | The ConvLSTM cell is implemented as shown in the following figure:
94 | ![](/figures/ConvLSTM_cell.png)
95 |
96 | ### 2. Bidirectional ConvLSTM decoder
97 | Our BCI decoder is a 5-timestep bidirectional ConvLSTM. It contains two ConvLSTM layers: a forward layer that reads the input sequence from left to right, and a backward layer that reads it from right to left. Details are shown in the following figure:
98 | ![](/figures/BCI_decoder.png)
99 |
100 | ### 3. Input, output for decoder
101 | The input of our decoder is the spike count or the LMP, and the output is the velocity.
102 | ![](/figures/input_output_decoder.png)
103 |
104 | ## Environment
105 | This repository was tested with Python 3.7.0 and PyTorch 1.6.0.
106 |
107 | ## References
108 | [1] Shi, X., Chen, Z., Wang, H., Yeung, D.-Y., Wong, W.-K., & Woo, W.-C. (2015). Convolutional LSTM network: A machine learning approach for precipitation nowcasting. In Advances in Neural Information Processing Systems (pp. 802-810).
109 |
--------------------------------------------------------------------------------
/convlstm.py:
--------------------------------------------------------------------------------
1 | '''
2 | Date: 2020/08/30
3 | @author: KimUyen
4 | # The code was revised from repo: https://github.com/ndrplz/ConvLSTM_pytorch
5 | '''
6 | import torch.nn as nn
7 | import torch
8 | import math
9 |
10 | class HadamardProduct(nn.Module):
11 |     def __init__(self, shape):
12 |         super(HadamardProduct, self).__init__()
13 |         # register the weights as a real Parameter; calling .cuda() on it here
14 |         # would return a plain non-leaf tensor that the optimizer never sees.
15 |         # Move the whole module to the GPU with .to(device) instead.
16 |         self.weights = nn.Parameter(torch.rand(shape))
17 |
18 |     def forward(self, x):
19 |         return x * self.weights
20 |
21 | class ConvLSTMCell(nn.Module):
22 |
23 |     def __init__(self, img_size, input_dim, hidden_dim, kernel_size,
24 |                  cnn_dropout, rnn_dropout, bias=True, peephole=False,
25 |                  layer_norm=False):
26 |         """
27 |         Initialize ConvLSTM cell.
28 |         Parameters
29 |         ----------
30 |         img_size: (int, int)
31 |             Height and width of the input feature map.
32 |         input_dim: int
33 |             Number of channels of input tensor.
34 |         hidden_dim: int
35 |             Number of channels of hidden state.
36 |         kernel_size: (int, int)
37 |             Size of the convolutional kernel for both cnn and rnn.
38 |         cnn_dropout, rnn_dropout: float
39 |             cnn_dropout: dropout rate for the convolutional input.
40 |             rnn_dropout: dropout rate for the convolutional state.
41 |         bias: bool
42 |             Whether or not to add the bias.
43 |         peephole: bool
44 |             Whether to add peephole connections (cell state to gates).
45 |         layer_norm: bool
46 |             Whether to apply layer normalization.
47 |         """
48 |
49 |         super(ConvLSTMCell, self).__init__()
50 |         self.input_shape = img_size
51 |         self.input_dim = input_dim
52 |         self.hidden_dim = hidden_dim
53 |         self.kernel_size = kernel_size
54 |         # "same"-style padding derived from the kernel size; the stride is fixed
55 |         self.padding = (int(self.kernel_size[0]/2), int(self.kernel_size[1]/2))
56 |         self.stride = (1, 1)
57 |         self.bias = bias
58 |         self.peephole = peephole
59 |         self.layer_norm = layer_norm
60 |
61 |         # standard convolution output size: (H - K + 2P)/S + 1
62 |         self.out_height = int((self.input_shape[0] - self.kernel_size[0] + 2*self.padding[0])/self.stride[0] + 1)
63 |         self.out_width = int((self.input_shape[1] - self.kernel_size[1] + 2*self.padding[1])/self.stride[1] + 1)
64 |
65 |         self.input_conv = nn.Conv2d(in_channels=self.input_dim, out_channels=4*self.hidden_dim,
66 |                                     kernel_size=self.kernel_size,
67 |                                     stride=self.stride,
68 |                                     padding=self.padding,
69 |                                     bias=self.bias)
70 |         self.rnn_conv = nn.Conv2d(self.hidden_dim, out_channels=4*self.hidden_dim,
71 |                                   kernel_size=self.kernel_size,
72 |                                   padding=(math.floor(self.kernel_size[0]/2),
73 |                                            math.floor(self.kernel_size[1]/2)),
74 |                                   bias=self.bias)
75 |
76 |         if self.peephole is True:
77 |             # learned elementwise (Hadamard) weights for the peephole connections
78 |             self.weight_ci = HadamardProduct((1, self.hidden_dim, self.out_height, self.out_width))
79 |             self.weight_cf = HadamardProduct((1, self.hidden_dim, self.out_height, self.out_width))
80 |             self.weight_co = HadamardProduct((1, self.hidden_dim, self.out_height, self.out_width))
81 |             self.layer_norm_ci = nn.LayerNorm([self.hidden_dim, self.out_height, self.out_width])
82 |             self.layer_norm_cf = nn.LayerNorm([self.hidden_dim, self.out_height, self.out_width])
83 |             self.layer_norm_co = nn.LayerNorm([self.hidden_dim, self.out_height, self.out_width])
84 |
85 |         self.cnn_dropout = nn.Dropout(cnn_dropout)
86 |         self.rnn_dropout = nn.Dropout(rnn_dropout)
87 |
88 |         self.layer_norm_x = nn.LayerNorm([4*self.hidden_dim, self.out_height, self.out_width])
89 |         self.layer_norm_h = nn.LayerNorm([4*self.hidden_dim, self.out_height, self.out_width])
90 |         self.layer_norm_cnext = nn.LayerNorm([self.hidden_dim, self.out_height, self.out_width])
91 |
92 |     def forward(self, input_tensor, cur_state):
93 |         h_cur, c_cur = cur_state
94 |
95 |         x = self.cnn_dropout(input_tensor)
96 |         x_conv = self.input_conv(x)
97 |         if self.layer_norm is True:
98 |             x_conv = self.layer_norm_x(x_conv)
99 |         # split the stacked convolution output into the i, f, c, o gate inputs
100 |         x_i, x_f, x_c, x_o = torch.split(x_conv, self.hidden_dim, dim=1)
101 |
102 |         h = self.rnn_dropout(h_cur)
103 |         h_conv = self.rnn_conv(h)
104 |         if self.layer_norm is True:
105 |             h_conv = self.layer_norm_h(h_conv)
106 |         # split the stacked convolution output into the i, f, c, o gate inputs
107 |         h_i, h_f, h_c, h_o = torch.split(h_conv, self.hidden_dim, dim=1)
108 |
109 |         if self.peephole is True:
110 |             # peephole contributions from the previous cell state, optionally layer-normalized
111 |             peep_i = self.weight_ci(c_cur)
112 |             peep_f = self.weight_cf(c_cur)
113 |             if self.layer_norm is True:
114 |                 peep_i = self.layer_norm_ci(peep_i)
115 |                 peep_f = self.layer_norm_cf(peep_f)
116 |             i = torch.sigmoid(x_i + h_i + peep_i)
117 |             f = torch.sigmoid(x_f + h_f + peep_f)
118 |         else:
119 |             i = torch.sigmoid(x_i + h_i)
120 |             f = torch.sigmoid(x_f + h_f)
121 |
122 |         g = torch.tanh(x_c + h_c)
123 |         c_next = f * c_cur + i * g
124 |         if self.peephole is True:
125 |             # output-gate peephole; note it reads c_cur here, whereas the paper
126 |             # applies W_co to the updated cell state C_t
127 |             peep_o = self.weight_co(c_cur)
128 |             if self.layer_norm is True:
129 |                 peep_o = self.layer_norm_co(peep_o)
130 |             o = torch.sigmoid(x_o + h_o + peep_o)
131 |         else:
132 |             o = torch.sigmoid(x_o + h_o)
133 |
134 |         if self.layer_norm is True:
135 |             c_next = self.layer_norm_cnext(c_next)
136 |         h_next = o * torch.tanh(c_next)
137 |
138 |         return h_next, c_next
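139 |
140 |     # fresh zero states on the same device as the conv weights; called via
141 |     # _init_hidden() at the start of every ConvLSTM.forward() pass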
142 |     def init_hidden(self, batch_size):
143 |         height, width = self.out_height, self.out_width
144 |         return (torch.zeros(batch_size, self.hidden_dim, height, width, device=self.input_conv.weight.device),
145 |                 torch.zeros(batch_size, self.hidden_dim, height, width, device=self.input_conv.weight.device))
146 |
147 |
148 | class ConvLSTM(nn.Module):
149 |
150 |     """
151 |     Parameters:
152 |         img_size: (int, int), height and width of the input frames
153 |         input_dim: Number of channels in input
154 |         hidden_dim: Number of hidden channels
155 |         kernel_size: Size of kernel in convolutions
156 |         cnn_dropout, rnn_dropout: float
157 |             cnn_dropout: dropout rate for the convolutional input.
158 |             rnn_dropout: dropout rate for the convolutional state.
159 |         batch_first: Whether or not dimension 0 is the batch dimension
160 |         bias: Bias or no bias in convolutions
161 |         return_sequence: return the full output sequence or the final output only
162 |         bidirectional: bool
163 |             bidirectional ConvLSTM
164 |     Input:
165 |         A tensor of shape (B, T, C, H, W) or (T, B, C, H, W)
166 |     Output:
167 |         A tuple (layer_output, last_state, last_state_inv);
168 |         last_state_inv is None unless bidirectional=True
169 |     Example:
170 |         >> x = torch.rand((32, 10, 64, 128, 128))
171 |         >> convlstm = ConvLSTM(img_size=(128, 128), input_dim=64, hidden_dim=16,
172 |                                kernel_size=(3, 3), cnn_dropout=0.2,
173 |                                rnn_dropout=0.2, batch_first=True, bias=False)
174 |         >> output, last_state, last_state_inv = convlstm(x)
175 |     """
176 |
177 |     def __init__(self, img_size, input_dim, hidden_dim, kernel_size,
178 |                  cnn_dropout=0.5, rnn_dropout=0.5,
179 |                  batch_first=False, bias=True, peephole=False,
180 |                  layer_norm=False,
181 |                  return_sequence=True,
182 |                  bidirectional=False):
183 |         super(ConvLSTM, self).__init__()
184 |
185 |         self.batch_first = batch_first
186 |         self.return_sequence = return_sequence
187 |         self.bidirectional = bidirectional
188 |
189 |         cell_fw = ConvLSTMCell(img_size=img_size,
190 |                                input_dim=input_dim,
191 |                                hidden_dim=hidden_dim,
192 |                                kernel_size=kernel_size,
193 |                                cnn_dropout=cnn_dropout,
194 |                                rnn_dropout=rnn_dropout,
195 |                                bias=bias,
196 |                                peephole=peephole,
197 |                                layer_norm=layer_norm)
198 |         self.cell_fw = cell_fw
199 |
200 |         if self.bidirectional is True:
201 |             cell_bw = ConvLSTMCell(img_size=img_size,
202 |                                    input_dim=input_dim,
203 |                                    hidden_dim=hidden_dim,
204 |                                    kernel_size=kernel_size,
205 |                                    cnn_dropout=cnn_dropout,
206 |                                    rnn_dropout=rnn_dropout,
207 |                                    bias=bias,
208 |                                    peephole=peephole,
209 |                                    layer_norm=layer_norm)
210 |             self.cell_bw = cell_bw
211 |
212 |     def forward(self, input_tensor, hidden_state=None):
213 |         """
214 |         Parameters
215 |         ----------
216 |         input_tensor:
217 |             5-D Tensor of shape (t, b, c, h, w) or (b, t, c, h, w)
218 |         hidden_state:
219 |             must be None for now; stateful ConvLSTM is not implemented yet
220 |         Returns
221 |         -------
222 |         layer_output, last_state, last_state_inv
223 |         """
224 |         if not self.batch_first:
225 |             # (t, b, c, h, w) -> (b, t, c, h, w)
226 |             input_tensor = input_tensor.permute(1, 0, 2, 3, 4)
227 |
228 |         b, seq_len, _, h, w = input_tensor.size()
229 |
230 |         if hidden_state is not None:
231 |             # TODO: implement stateful ConvLSTM
232 |             raise NotImplementedError()
233 |         else:
234 |             # initialization happens in forward() because it depends on batch size
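235 |             # _init_hidden returns (forward_states, backward_states); the backward
236 |             # entry is None for a unidirectional layer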
237 |             hidden_state, hidden_state_inv = self._init_hidden(batch_size=b)
238 |
239 |         ## LSTM forward direction
240 |         input_fw = input_tensor
241 |         h, c = hidden_state
242 |         output_inner = []
243 |         for t in range(seq_len):
244 |             h, c = self.cell_fw(input_tensor=input_fw[:, t, :, :, :],
245 |                                 cur_state=[h, c])
246 |             output_inner.append(h)
247 |         output_inner = torch.stack(output_inner, dim=1)
248 |         layer_output = output_inner
249 |         last_state = [h, c]
250 |
251 |         ## LSTM inverse direction
252 |         last_state_inv = None
253 |         if self.bidirectional is True:
254 |             input_inv = input_tensor
255 |             h_inv, c_inv = hidden_state_inv
256 |             output_inv = []
257 |             for t in range(seq_len-1, -1, -1):
258 |                 h_inv, c_inv = self.cell_bw(input_tensor=input_inv[:, t, :, :, :],
259 |                                             cur_state=[h_inv, c_inv])
260 |                 output_inv.append(h_inv)
261 |             output_inv.reverse()
262 |             output_inv = torch.stack(output_inv, dim=1)
263 |             # concatenate forward and backward features along the channel axis
264 |             layer_output = torch.cat((output_inner, output_inv), dim=2)
265 |             last_state_inv = [h_inv, c_inv]
266 |
267 |         return (layer_output if self.return_sequence is True else layer_output[:, -1:]), last_state, last_state_inv
268 |
269 |     def _init_hidden(self, batch_size):
270 |         init_states_fw = self.cell_fw.init_hidden(batch_size)
271 |         init_states_bw = None
272 |         if self.bidirectional is True:
273 |             init_states_bw = self.cell_bw.init_hidden(batch_size)
274 |         return init_states_fw, init_states_bw
--------------------------------------------------------------------------------
/convlstm_decoder.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import convlstm
3 |
4 | # SEQUENCE_OUTPUT stands in for config.SEQUENCE_OUTPUT in the original code;
5 | # config.py is not included in this repo, so assume full-sequence output.
6 | SEQUENCE_OUTPUT = True
7 |
8 |
9 | class Flatten(torch.nn.Module):
10 |     def forward(self, input):
11 |         b, seq_len, _, h, w = input.size()
12 |         return input.view(b, seq_len, -1)
13 |
14 | class ConvLSTMNetwork(torch.nn.Module):
15 |     def __init__(self, img_size_list, input_channel, hidden_channels, kernel_size, num_layers, bidirectional=False):
16 |         super(ConvLSTMNetwork, self).__init__()
17 |
18 |         self.hidden_channels = hidden_channels
19 |         self.num_layers = num_layers
20 |         self.bidirectional = bidirectional
21 |
22 |         convlstm_layer = []
23 |         for i in range(num_layers):
24 |             layer = convlstm.ConvLSTM(img_size_list[i],
25 |                                       input_channel,
26 |                                       hidden_channels[i],
27 |                                       kernel_size[i],
28 |                                       0.2, 0.,
29 |                                       batch_first=True,
30 |                                       bias=True,
31 |                                       peephole=True,
32 |                                       layer_norm=True,
33 |                                       return_sequence=SEQUENCE_OUTPUT,
34 |                                       bidirectional=self.bidirectional)
35 |             convlstm_layer.append(layer)
36 |             # a bidirectional layer doubles the channels seen by the next layer
37 |             input_channel = hidden_channels[i] * (2 if self.bidirectional else 1)
38 |
39 |         self.convlstm_layer = torch.nn.ModuleList(convlstm_layer)
40 |         self.flatten = Flatten()
41 |         # 16 is assumed to be the flattened spatial size (out_height * out_width)
42 |         # of the last ConvLSTM layer; the 2 outputs are the decoded velocity
43 |         self.linear2 = torch.nn.Linear(hidden_channels[-1]*(2 if self.bidirectional else 1)*16, 2)
44 |
45 |     def forward(self, x):
46 |         input_tensor = x
47 |         for i in range(self.num_layers):
48 |             input_tensor, _, _ = self.convlstm_layer[i](input_tensor)
49 |
50 |         out_flatten = self.flatten(input_tensor)
51 |         output = self.linear2(out_flatten)
52 |         return output
53 |
--------------------------------------------------------------------------------
/figures/BCI_decoder.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KimUyen/ConvLSTM-Pytorch/c7b4bd108335a4d6c7d99c00c263346026186b0b/figures/BCI_decoder.png -------------------------------------------------------------------------------- /figures/BCI_system.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KimUyen/ConvLSTM-Pytorch/c7b4bd108335a4d6c7d99c00c263346026186b0b/figures/BCI_system.png -------------------------------------------------------------------------------- /figures/ConvLSTM_cell.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KimUyen/ConvLSTM-Pytorch/c7b4bd108335a4d6c7d99c00c263346026186b0b/figures/ConvLSTM_cell.png -------------------------------------------------------------------------------- /figures/ConvLSTM_definition.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KimUyen/ConvLSTM-Pytorch/c7b4bd108335a4d6c7d99c00c263346026186b0b/figures/ConvLSTM_definition.png -------------------------------------------------------------------------------- /figures/input_output_decoder.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KimUyen/ConvLSTM-Pytorch/c7b4bd108335a4d6c7d99c00c263346026186b0b/figures/input_output_decoder.png --------------------------------------------------------------------------------