├── LICENSE
├── README.md
├── convlstm.py
├── convlstm_decoder.py
└── figures
├── BCI_decoder.png
├── BCI_system.png
├── ConvLSTM_cell.png
├── ConvLSTM_definition.png
└── input_output_decoder.png
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2020 Nguyen Thi Kim Uyen
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ## ConvLSTM PyTorch Implementation
2 |
3 | - [Goal](#Goal)
4 | - [Example of using ConvLSTM](#Example-of-using-ConvLSTM)
5 | - [Explanation](#Explanation)
6 | - [1. ConvLSTM definition](#1-ConvLSTM-definition)
7 | - [2. Bidirectional ConvLSTM decoder](#2-bidirectional-convlstm-decoder)
8 | - [3. Input, output for decoder](#3-input-output-for-decoder)
9 | - [Environment](#Environment)
10 | - [References](#References)
11 |
12 | ## Goal
13 | The ConvLSTM model is mainly used as a skeleton to design a BCI (Brain-Computer Interface) decoder for our project (decoding kinematic signals from neural signals).
14 | This repo is an implementation of ConvLSTM in PyTorch. The implementation follows the paper: Convolutional LSTM Network: A Machine Learning Approach for Precipitation Nowcasting [1].
15 |
16 | The BCI decoder is one component of the BCI system, as shown in the following figure:
17 | 
18 | ## Example of using ConvLSTM
19 | `convlstm_decoder.py` contains an example of defining a ConvLSTM decoder.
20 |
21 | Here is an example of defining a 1-layer bidirectional ConvLSTM:
22 | ```python
23 | convlstm_layer = []
24 | img_size_list = [(10, 10)]
25 | num_layers = 1  # number of layers
26 | input_channel = 96  # the number of electrodes in the Utah array
27 | hidden_channels = [256]  # the output channels for each layer
28 | kernel_size = [(7, 7)]  # the convolution kernel size for each layer
29 | # note: ConvLSTMCell fixes the stride to (1, 1) and uses kernel_size // 2
30 | # ('same') padding internally, so neither is a constructor argument
31 | for i in range(num_layers):
32 |     layer = convlstm.ConvLSTM(img_size=img_size_list[i],
33 |                               input_dim=input_channel,
34 |                               hidden_dim=hidden_channels[i],
35 |                               kernel_size=kernel_size[i],
36 |                               cnn_dropout=0.2,
37 |                               rnn_dropout=0.,
38 |                               batch_first=True,
39 |                               bias=True,
40 |                               peephole=False,
41 |                               layer_norm=False,
42 |                               return_sequence=True,
43 |                               bidirectional=True)
44 |     convlstm_layer.append(layer)
45 |     # a bidirectional layer outputs 2 * hidden_channels[i] channels
46 |     input_channel = hidden_channels[i] * 2
47 | ```
48 |
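A quick shape check for the layer defined above (a minimal sketch; it assumes `import torch` and `import convlstm` as in `convlstm_decoder.py`):

```python
x = torch.rand((32, 5, 96, 10, 10))  # (batch, timesteps, channels, height, width)
output, last_state, last_state_inv = convlstm_layer[0](x)
print(output.shape)  # torch.Size([32, 5, 512, 10, 10]); 512 = 2 * 256 from the bidirectional concat
```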
49 |
50 | ## Explanation
51 | The implementation was initially based on [this repo](https://github.com/ndrplz/ConvLSTM_pytorch).
52 |
53 | However, I modified the source to follow the original paper [1] more closely.
54 | ### 1. ConvLSTM definition
55 | The gates and states follow the definition in the paper:
56 |
57 | 
58 |
59 |
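For reference, these are the ConvLSTM equations from [1], where $*$ denotes the convolution operator and $\circ$ the Hadamard product (in this implementation the peephole terms are applied only when `peephole=True`):

```latex
i_t = \sigma(W_{xi} * X_t + W_{hi} * H_{t-1} + W_{ci} \circ C_{t-1} + b_i)
f_t = \sigma(W_{xf} * X_t + W_{hf} * H_{t-1} + W_{cf} \circ C_{t-1} + b_f)
C_t = f_t \circ C_{t-1} + i_t \circ \tanh(W_{xc} * X_t + W_{hc} * H_{t-1} + b_c)
o_t = \sigma(W_{xo} * X_t + W_{ho} * H_{t-1} + W_{co} \circ C_t + b_o)
H_t = o_t \circ \tanh(C_t)
```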
60 | The structure of a single ConvLSTM cell is shown in the following figure:
61 | 
62 |
63 | ### 2. Bidirectional ConvLSTM decoder
64 | Our BCI decoder is a 5-timestep bidirectional ConvLSTM. It contains two ConvLSTM layers: a forward layer that processes the input sequence from left to right, and a backward layer that processes it from right to left, as detailed in the following figure:
65 | 
66 |
67 | ### 3. Input, output for decoder
68 | The input of our decoder is spike counts or LMP, and the output is velocity.
69 | 
70 |
71 | ## Environment
72 | This repository was tested with Python 3.7.0 and PyTorch 1.6.0.
73 |
74 | ## References
75 | [1] Shi, X., Chen, Z., Wang, H., Yeung, D. Y., Wong, W. K., & Woo, W. C. (2015). Convolutional LSTM network: A machine learning approach for precipitation nowcasting. In Advances in Neural Information Processing Systems (pp. 802-810).
76 |
--------------------------------------------------------------------------------
/convlstm.py:
--------------------------------------------------------------------------------
1 | '''
2 | Date: 2020/08/30
3 | @author: KimUyen
4 | # The code was revised from repo: https://github.com/ndrplz/ConvLSTM_pytorch
5 | '''
6 | import torch.nn as nn
7 | import torch
8 | import math
9 |
10 | class HadamardProduct(nn.Module):
11 |     def __init__(self, shape):
12 |         super(HadamardProduct, self).__init__()
13 |         # no .cuda() here: it would store an unregistered copy of the Parameter; move the module to the GPU with .to(device) instead
14 |         self.weights = nn.Parameter(torch.rand(shape))
15 |     def forward(self, x):
16 |         return x * self.weights
17 |
18 | class ConvLSTMCell(nn.Module):
19 |
20 | def __init__(self, img_size, input_dim, hidden_dim, kernel_size,
21 | cnn_dropout, rnn_dropout, bias=True, peephole=False,
22 | layer_norm=False):
23 | """
24 | Initialize ConvLSTM cell.
25 | Parameters
26 | ----------
27 | input_dim: int
28 | Number of channels of input tensor.
29 | hidden_dim: int
30 | Number of channels of hidden state.
31 | kernel_size: (int, int)
32 | Size of the convolutional kernel for both cnn and rnn.
33 | cnn_dropout, rnn_dropout: float
34 | cnn_dropout: dropout rate for convolutional input.
35 | rnn_dropout: dropout rate for convolutional state.
36 | bias: bool
37 | Whether or not to add the bias.
38 |         peephole: bool
39 |             Whether to add peephole connections from the cell state to the gates.
40 |         layer_norm: bool
41 |             Whether to apply layer normalization.
42 | """
43 |
44 | super(ConvLSTMCell, self).__init__()
45 | self.input_shape = img_size
46 | self.input_dim = input_dim
47 | self.hidden_dim = hidden_dim
48 | self.kernel_size = kernel_size
49 | self.padding = (int(self.kernel_size[0]/2), int(self.kernel_size[1]/2))
50 | self.stride = (1, 1)
51 | self.bias = bias
52 | self.peephole = peephole
53 | self.layer_norm = layer_norm
54 |
55 | self.out_height = int((self.input_shape[0] - self.kernel_size[0] + 2*self.padding[0])/self.stride[0] + 1)
56 | self.out_width = int((self.input_shape[1] - self.kernel_size[1] + 2*self.padding[1])/self.stride[1] + 1)
57 |
58 | self.input_conv = nn.Conv2d(in_channels=self.input_dim, out_channels=4*self.hidden_dim,
59 | kernel_size=self.kernel_size,
60 | stride = self.stride,
61 | padding=self.padding,
62 | bias=self.bias)
63 | self.rnn_conv = nn.Conv2d(self.hidden_dim, out_channels=4*self.hidden_dim,
64 | kernel_size = self.kernel_size,
65 | padding=(math.floor(self.kernel_size[0]/2),
66 | math.floor(self.kernel_size[1]/2)),
67 | bias=self.bias)
68 |
69 | if self.peephole is True:
70 | self.weight_ci = HadamardProduct((1, self.hidden_dim, self.out_height, self.out_width))
71 | self.weight_cf = HadamardProduct((1, self.hidden_dim, self.out_height, self.out_width))
72 | self.weight_co = HadamardProduct((1, self.hidden_dim, self.out_height, self.out_width))
73 | self.layer_norm_ci = nn.LayerNorm([self.hidden_dim, self.out_height, self.out_width])
74 | self.layer_norm_cf = nn.LayerNorm([self.hidden_dim, self.out_height, self.out_width])
75 | self.layer_norm_co = nn.LayerNorm([self.hidden_dim, self.out_height, self.out_width])
76 |
77 |
78 | self.cnn_dropout = nn.Dropout(cnn_dropout)
79 | self.rnn_dropout = nn.Dropout(rnn_dropout)
80 |
81 | self.layer_norm_x = nn.LayerNorm([4*self.hidden_dim, self.out_height, self.out_width])
82 | self.layer_norm_h = nn.LayerNorm([4*self.hidden_dim, self.out_height, self.out_width])
83 | self.layer_norm_cnext = nn.LayerNorm([self.hidden_dim, self.out_height, self.out_width])
84 |
85 | def forward(self, input_tensor, cur_state):
86 | h_cur, c_cur = cur_state
87 |
88 | x = self.cnn_dropout(input_tensor)
89 | x_conv = self.input_conv(x)
90 | if self.layer_norm is True:
91 | x_conv = self.layer_norm_x(x_conv)
92 |         # split the conv output into the i, f, c, o gate components
93 | x_i, x_f, x_c, x_o = torch.split(x_conv, self.hidden_dim, dim=1)
94 |
95 | h = self.rnn_dropout(h_cur)
96 | h_conv = self.rnn_conv(h)
97 | if self.layer_norm is True:
98 | h_conv = self.layer_norm_h(h_conv)
99 |         # split the conv output into the i, f, c, o gate components
100 | h_i, h_f, h_c, h_o = torch.split(h_conv, self.hidden_dim, dim=1)
101 |
102 |
103 |         if self.peephole is True:
104 |             # the peephole term must be parenthesized: the original conditional expression dropped the (x + h) part whenever layer_norm was False
105 |             f = torch.sigmoid(x_f + h_f + (self.layer_norm_cf(self.weight_cf(c_cur)) if self.layer_norm else self.weight_cf(c_cur)))
106 |             i = torch.sigmoid(x_i + h_i + (self.layer_norm_ci(self.weight_ci(c_cur)) if self.layer_norm else self.weight_ci(c_cur)))
107 |         else:
108 |             f = torch.sigmoid(x_f + h_f)
109 |             i = torch.sigmoid(x_i + h_i)
110 |
111 |         g = torch.tanh(x_c + h_c)
112 |         c_next = f * c_cur + i * g
113 |         if self.peephole is True:
114 |             # per Eq. (3) in [1], the output gate peeks at the new cell state c_next
115 |             o = torch.sigmoid(x_o + h_o + (self.layer_norm_co(self.weight_co(c_next)) if self.layer_norm else self.weight_co(c_next)))
116 |         else:
117 |             o = torch.sigmoid(x_o + h_o)
118 |         if self.layer_norm is True:
119 |             c_next = self.layer_norm_cnext(c_next)
120 | h_next = o * torch.tanh(c_next)
121 |
122 | return h_next, c_next
123 |
124 | def init_hidden(self, batch_size):
125 | height, width = self.out_height, self.out_width
126 | return (torch.zeros(batch_size, self.hidden_dim, height, width, device=self.input_conv.weight.device),
127 | torch.zeros(batch_size, self.hidden_dim, height, width, device=self.input_conv.weight.device))
128 |
129 |
130 | class ConvLSTM(nn.Module):
131 |
132 | """
133 | Parameters:
134 | input_dim: Number of channels in input
135 | hidden_dim: Number of hidden channels
136 | kernel_size: Size of kernel in convolutions
137 | cnn_dropout, rnn_dropout: float
138 | cnn_dropout: dropout rate for convolutional input.
139 | rnn_dropout: dropout rate for convolutional state.
140 |         batch_first: Whether dimension 0 of the input is the batch dimension
141 | bias: Bias or no bias in Convolution
142 | return_sequence: return output sequence or final output only
143 | bidirectional: bool
144 | bidirectional ConvLSTM
145 | Input:
146 | A tensor of size B, T, C, H, W or T, B, C, H, W
147 |     Output:
148 |         A tuple (layer_output, last_state, last_state_inv); last_state_inv is None if bidirectional=False
149 |     Example:
150 |         >>> x = torch.rand((32, 10, 64, 128, 128))
151 |         >>> convlstm = ConvLSTM(img_size=(128, 128), input_dim=64, hidden_dim=16,
152 |                                 kernel_size=(3, 3), cnn_dropout=0.2, rnn_dropout=0.2,
153 |                                 batch_first=True, bias=False)
154 |         >>> output, last_state, last_state_inv = convlstm(x)
155 |     """
156 |
157 | def __init__(self, img_size, input_dim, hidden_dim, kernel_size,
158 | cnn_dropout=0.5, rnn_dropout=0.5,
159 | batch_first=False, bias=True, peephole=False,
160 | layer_norm=False,
161 | return_sequence=True,
162 | bidirectional=False):
163 | super(ConvLSTM, self).__init__()
164 |
165 |
166 | self.batch_first = batch_first
167 | self.return_sequence = return_sequence
168 | self.bidirectional = bidirectional
169 |
170 | cell_fw = ConvLSTMCell(img_size = img_size,
171 | input_dim=input_dim,
172 | hidden_dim=hidden_dim,
173 | kernel_size=kernel_size,
174 | cnn_dropout=cnn_dropout,
175 | rnn_dropout=rnn_dropout,
176 | bias=bias,
177 | peephole=peephole,
178 | layer_norm=layer_norm)
179 | self.cell_fw = cell_fw
180 |
181 | if self.bidirectional is True:
182 | cell_bw = ConvLSTMCell(img_size = img_size,
183 | input_dim=input_dim,
184 | hidden_dim=hidden_dim,
185 | kernel_size=kernel_size,
186 | cnn_dropout=cnn_dropout,
187 | rnn_dropout=rnn_dropout,
188 | bias=bias,
189 | peephole=peephole,
190 | layer_norm=layer_norm)
191 | self.cell_bw = cell_bw
192 |
193 | def forward(self, input_tensor, hidden_state=None):
194 | """
195 | Parameters
196 | ----------
197 |         input_tensor:
198 |             5-D tensor either of shape (t, b, c, h, w) or (b, t, c, h, w)
199 |         hidden_state:
200 |             None (stateful mode is not implemented yet)
201 | Returns
202 | -------
203 | layer_output, last_state
204 | """
205 | if not self.batch_first:
206 | # (t, b, c, h, w) -> (b, t, c, h, w)
207 | input_tensor = input_tensor.permute(1, 0, 2, 3, 4)
208 |
209 |         b, seq_len, _, _, _ = input_tensor.size()  # spatial dims come from the cells; avoids shadowing h below
210 |
211 |         # stateful ConvLSTM is not implemented yet
212 |         if hidden_state is not None:
213 |             raise NotImplementedError()
214 |         else:
215 |             # init is done in forward, where the batch size is known;
216 |             # the backward states are None when bidirectional is False
217 |             hidden_state, hidden_state_inv = self._init_hidden(batch_size=b)
218 |
219 |
220 | ## LSTM forward direction
221 | input_fw = input_tensor
222 | h, c = hidden_state
223 | output_inner = []
224 | for t in range(seq_len):
225 | h, c = self.cell_fw(input_tensor=input_fw[:, t, :, :, :],
226 | cur_state=[h, c])
227 |
228 | output_inner.append(h)
229 |         output_inner = torch.stack(output_inner, dim=1)
230 | layer_output = output_inner
231 | last_state = [h, c]
232 | ####################
233 |
234 |
235 | ## LSTM inverse direction
236 | if self.bidirectional is True:
237 | input_inv = input_tensor
238 | h_inv, c_inv = hidden_state_inv
239 | output_inv = []
240 | for t in range(seq_len-1, -1, -1):
241 | h_inv, c_inv = self.cell_bw(input_tensor=input_inv[:, t, :, :, :],
242 | cur_state=[h_inv, c_inv])
243 |
244 | output_inv.append(h_inv)
245 | output_inv.reverse()
246 |             output_inv = torch.stack(output_inv, dim=1)
247 | layer_output = torch.cat((output_inner, output_inv), dim=2)
248 | last_state_inv = [h_inv, c_inv]
249 | ###################################
250 |
251 |         return (layer_output if self.return_sequence else layer_output[:, -1:]), last_state, (last_state_inv if self.bidirectional else None)
252 |
253 | def _init_hidden(self, batch_size):
254 | init_states_fw = self.cell_fw.init_hidden(batch_size)
255 | init_states_bw = None
256 | if self.bidirectional is True:
257 | init_states_bw = self.cell_bw.init_hidden(batch_size)
258 | return init_states_fw, init_states_bw
--------------------------------------------------------------------------------
/convlstm_decoder.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import convlstm
3 |
4 | # the original 'config' module is not included in this repo; this constant replaces config.SEQUENCE_OUTPUT
5 | SEQUENCE_OUTPUT = True  # whether each ConvLSTM layer returns the full output sequence
6 | class Flatten(torch.nn.Module):
7 |     def forward(self, x):
8 |         b, seq_len = x.size(0), x.size(1)
9 |         return x.view(b, seq_len, -1)
10 |
11 | class ConvLSTMNetwork(torch.nn.Module):
12 | def __init__(self, img_size_list, input_channel, hidden_channels, kernel_size, num_layers, bidirectional = False):
13 | super(ConvLSTMNetwork, self).__init__()
14 |
15 | self.hidden_channels = hidden_channels
16 | self.num_layers = num_layers
17 | self.bidirectional = bidirectional
18 |
19 | convlstm_layer = []
20 | for i in range(num_layers):
21 | layer = convlstm.ConvLSTM(img_size_list[i],
22 | input_channel,
23 | hidden_channels[i],
24 | kernel_size[i],
25 |                                       cnn_dropout=0.2, rnn_dropout=0.,
26 | batch_first=True,
27 | bias=True,
28 | peephole=True,
29 | layer_norm=True,
30 |                                       return_sequence=SEQUENCE_OUTPUT,
31 | bidirectional=self.bidirectional)
32 | convlstm_layer.append(layer)
33 | input_channel = hidden_channels[i] * (2 if self.bidirectional else 1)
34 |
35 | self.convlstm_layer = torch.nn.ModuleList(convlstm_layer)
36 | self.flatten = Flatten()
37 |         # 16 = out_height * out_width of the last layer's output (e.g. a 4x4 grid); adjust if the grid size changes
38 |         self.linear2 = torch.nn.Linear(hidden_channels[-1]*(2 if self.bidirectional else 1)*16, 2)
39 | def forward(self, x):
40 | input_tensor = x
41 | for i in range(self.num_layers):
42 | input_tensor, _, _ = self.convlstm_layer[i](input_tensor)
43 |
44 | out_flatten = self.flatten(input_tensor)
45 | output = self.linear2(out_flatten)
46 | return output
47 |
48 |
--------------------------------------------------------------------------------
/figures/BCI_decoder.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KimUyen/ConvLSTM-Pytorch/c7b4bd108335a4d6c7d99c00c263346026186b0b/figures/BCI_decoder.png
--------------------------------------------------------------------------------
/figures/BCI_system.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KimUyen/ConvLSTM-Pytorch/c7b4bd108335a4d6c7d99c00c263346026186b0b/figures/BCI_system.png
--------------------------------------------------------------------------------
/figures/ConvLSTM_cell.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KimUyen/ConvLSTM-Pytorch/c7b4bd108335a4d6c7d99c00c263346026186b0b/figures/ConvLSTM_cell.png
--------------------------------------------------------------------------------
/figures/ConvLSTM_definition.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KimUyen/ConvLSTM-Pytorch/c7b4bd108335a4d6c7d99c00c263346026186b0b/figures/ConvLSTM_definition.png
--------------------------------------------------------------------------------
/figures/input_output_decoder.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KimUyen/ConvLSTM-Pytorch/c7b4bd108335a4d6c7d99c00c263346026186b0b/figures/input_output_decoder.png
--------------------------------------------------------------------------------