├── README.md ├── LICENSE ├── conv2d_rnncells.py └── conv2d_rnnmodels.py /README.md: -------------------------------------------------------------------------------- 1 | # 2D Convolutional Recurrent Neural Networks with PyTorch 2 | ## Two dimensional Convolutional Recurrent Neural Networks implemented in PyTorch 3 | 4 | The architecture of ```Conv2dLSTMCell``` was inspired by "Convolutional LSTM Network: A Machine Learning Approach for Precipitation Nowcasting" 5 | (https://arxiv.org/pdf/1506.04214.pdf). 6 | 7 | See the image below for the key equations of ```Conv2dLSTMCell```: 8 | 9 | ![Capture](https://user-images.githubusercontent.com/71031687/112730543-73de0900-8f32-11eb-8396-a79091979335.JPG) 10 | 11 | 12 | The implementations of ```Conv2dRNNCell``` and ```Conv2dGRUCell``` are based on the implementation of Convolutional LSTM. 13 | 14 | 15 | This repo contains implementations of: 16 | 17 | * Conv2dRNNCell 18 | * Conv2dLSTMCell 19 | * Conv2dGRUCell 20 | 21 | and 22 | 23 | * Conv2dRNN / Bidirectional Conv2dRNN 24 | * Conv2dLSTM / Bidirectional Conv2dLSTM 25 | * Conv2dGRU / Bidirectional Conv2dGRU. 
26 | 27 | ## Dependencies 28 | 29 | * ```pytorch``` 30 | * ```numpy``` 31 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 George Yiasemis 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /conv2d_rnncells.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.autograd import Variable 4 | import numpy as np 5 | 6 | class Conv2dLSTMCell(nn.Module): 7 | def __init__(self, input_size, hidden_size, kernel_size, bias=True): 8 | super(Conv2dLSTMCell, self).__init__() 9 | 10 | self.input_size = input_size 11 | self.hidden_size = hidden_size 12 | 13 | if type(kernel_size) == tuple and len(kernel_size) == 2: 14 | self.kernel_size = kernel_size 15 | self.padding = (kernel_size[0] // 2, kernel_size[1] // 2) 16 | elif type(kernel_size) == int: 17 | self.kernel_size = (kernel_size, kernel_size) 18 | self.padding = (kernel_size // 2, kernel_size // 2) 19 | else: 20 | raise ValueError("Invalid kernel size.") 21 | 22 | self.bias = bias 23 | self.x2h = nn.Conv2d(in_channels=input_size, 24 | out_channels=hidden_size * 4, 25 | kernel_size=self.kernel_size, 26 | padding=self.padding, 27 | bias=bias) 28 | 29 | self.h2h = nn.Conv2d(in_channels=hidden_size, 30 | out_channels=hidden_size * 4, 31 | kernel_size=self.kernel_size, 32 | padding=self.padding, 33 | bias=bias) 34 | self.Wc = None 35 | self.reset_parameters() 36 | 37 | def reset_parameters(self): 38 | std = 1.0 / np.sqrt(self.hidden_size) 39 | for w in self.parameters(): 40 | w.data.uniform_(-std, std) 41 | 42 | def forward(self, input, hx=None): 43 | 44 | # Inputs: 45 | # input: of shape (batch_size, input_size, height_size, width_size) 46 | # hx: of shape (batch_size, hidden_size, height_size, width_size) 47 | # Outputs: 48 | # hy: of shape (batch_size, hidden_size, height_size, width_size) 49 | # cy: of shape (batch_size, hidden_size, height_size, width_size) 50 | 51 | if self.Wc == None: 52 | self.Wc = nn.Parameter(torch.zeros((1, self.hidden_size * 3, input.size(2), input.size(3)))) 53 | 54 | if hx is None: 55 | hx = 
Variable(input.new_zeros(input.size(0), self.hidden_size, input.size(2), input.size(3))) 56 | hx = (hx, hx) 57 | hx, cx = hx 58 | 59 | gates = self.x2h(input) + self.h2h(hx) 60 | 61 | # Get gates (i_t, f_t, g_t, o_t) 62 | input_gate, forget_gate, cell_gate, output_gate = gates.chunk(4, 1) 63 | 64 | Wci, Wcf, Wco = self.Wc.chunk(3, 1) 65 | 66 | i_t = torch.sigmoid(input_gate + Wci * cx) 67 | f_t = torch.sigmoid(forget_gate + Wcf * cx) 68 | g_t = torch.tanh(cell_gate) 69 | 70 | cy = f_t * cx + i_t * torch.tanh(g_t) 71 | o_t = torch.sigmoid(output_gate + Wco * cy) 72 | 73 | hy = o_t * torch.tanh(cy) 74 | 75 | 76 | return (hy, cy) 77 | 78 | class Conv2dRNNCell(nn.Module): 79 | def __init__(self, input_size, hidden_size, kernel_size, bias=True, nonlinearity="tanh"): 80 | super(Conv2dRNNCell, self).__init__() 81 | 82 | self.input_size = input_size 83 | self.hidden_size = hidden_size 84 | 85 | if type(kernel_size) == tuple and len(kernel_size) == 2: 86 | self.kernel_size = kernel_size 87 | self.padding = (kernel_size[0] // 2, kernel_size[1] // 2) 88 | elif type(kernel_size) == int: 89 | self.kernel_size = (kernel_size, kernel_size) 90 | self.padding = (kernel_size // 2, kernel_size // 2) 91 | else: 92 | raise ValueError("Invalid kernel size.") 93 | 94 | self.bias = bias 95 | self.nonlinearity = nonlinearity 96 | if self.nonlinearity not in ["tanh", "relu"]: 97 | raise ValueError("Invalid nonlinearity selected for RNN.") 98 | 99 | self.x2h = nn.Conv2d(in_channels=input_size, 100 | out_channels=hidden_size, 101 | kernel_size=self.kernel_size, 102 | padding=self.padding, 103 | bias=bias) 104 | 105 | self.h2h = nn.Conv2d(in_channels=hidden_size, 106 | out_channels=hidden_size, 107 | kernel_size=self.kernel_size, 108 | padding=self.padding, 109 | bias=bias) 110 | self.reset_parameters() 111 | 112 | 113 | def reset_parameters(self): 114 | std = 1.0 / np.sqrt(self.hidden_size) 115 | for w in self.parameters(): 116 | w.data.uniform_(-std, std) 117 | 118 | 119 | def forward(self, 
input, hx=None): 120 | 121 | # Inputs: 122 | # input: of shape (batch_size, input_size, height_size, width_size) 123 | # hx: of shape (batch_size, hidden_size, height_size, width_size) 124 | # Outputs: 125 | # hy: of shape (batch_size, hidden_size, height_size, width_size) 126 | 127 | if hx is None: 128 | hx = Variable(input.new_zeros(input.size(0), self.hidden_size, input.size(2), input.size(3))) 129 | 130 | hy = (self.x2h(input) + self.h2h(hx)) 131 | 132 | if self.nonlinearity == "tanh": 133 | hy = torch.tanh(hy) 134 | else: 135 | hy = torch.relu(hy) 136 | 137 | return hy 138 | 139 | class Conv2dGRUCell(nn.Module): 140 | def __init__(self, input_size, hidden_size, kernel_size, bias=True): 141 | super(Conv2dGRUCell, self).__init__() 142 | 143 | self.input_size = input_size 144 | self.hidden_size = hidden_size 145 | 146 | if type(kernel_size) == tuple and len(kernel_size) == 2: 147 | self.kernel_size = kernel_size 148 | self.padding = (kernel_size[0] // 2, kernel_size[1] // 2) 149 | elif type(kernel_size) == int: 150 | self.kernel_size = (kernel_size, kernel_size) 151 | self.padding = (kernel_size // 2, kernel_size // 2) 152 | else: 153 | raise ValueError("Invalid kernel size.") 154 | 155 | self.bias = bias 156 | self.x2h = nn.Conv2d(in_channels=input_size, 157 | out_channels=hidden_size * 3, 158 | kernel_size=self.kernel_size, 159 | padding=self.padding, 160 | bias=bias) 161 | 162 | self.h2h = nn.Conv2d(in_channels=hidden_size, 163 | out_channels=hidden_size * 3, 164 | kernel_size=self.kernel_size, 165 | padding=self.padding, 166 | bias=bias) 167 | self.reset_parameters() 168 | 169 | 170 | def reset_parameters(self): 171 | std = 1.0 / np.sqrt(self.hidden_size) 172 | for w in self.parameters(): 173 | w.data.uniform_(-std, std) 174 | 175 | def forward(self, input, hx=None): 176 | 177 | # Inputs: 178 | # input: of shape (batch_size, input_size, height_size, width_size) 179 | # hx: of shape (batch_size, hidden_size, height_size, width_size) 180 | # Outputs: 181 | # 
hy: of shape (batch_size, hidden_size, height_size, width_size) 182 | 183 | if hx is None: 184 | hx = Variable(input.new_zeros(input.size(0), self.hidden_size, input.size(2), input.size(3))) 185 | 186 | x_t = self.x2h(input) 187 | h_t = self.h2h(hx) 188 | 189 | 190 | x_reset, x_upd, x_new = x_t.chunk(3, 1) 191 | h_reset, h_upd, h_new = h_t.chunk(3, 1) 192 | 193 | reset_gate = torch.sigmoid(x_reset + h_reset) 194 | update_gate = torch.sigmoid(x_upd + h_upd) 195 | new_gate = torch.tanh(x_new + (reset_gate * h_new)) 196 | 197 | hy = update_gate * hx + (1 - update_gate) * new_gate 198 | 199 | return hy 200 | -------------------------------------------------------------------------------- /conv2d_rnnmodels.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.autograd import Variable 4 | import numpy as np 5 | 6 | from conv2d_rnncells import Conv2dLSTMCell, Conv2dGRUCell, Conv2dRNNCell 7 | 8 | class Conv2dRNN(nn.Module): 9 | def __init__(self, input_size, hidden_size, kernel_size, num_layers, bias, output_size, activation='tanh'): 10 | super(Conv2dRNN, self).__init__() 11 | 12 | self.input_size = input_size 13 | self.hidden_size = hidden_size 14 | 15 | if type(kernel_size) == tuple and len(kernel_size) == 2: 16 | self.kernel_size = kernel_size 17 | self.padding = (kernel_size[0] // 2, kernel_size[1] // 2) 18 | elif type(kernel_size) == int: 19 | self.kernel_size = (kernel_size, kernel_size) 20 | self.padding = (kernel_size // 2, kernel_size // 2) 21 | else: 22 | raise ValueError("Invalid kernel size.") 23 | 24 | self.num_layers = num_layers 25 | self.bias = bias 26 | self.output_size = output_size 27 | 28 | self.rnn_cell_list = nn.ModuleList() 29 | 30 | if activation == 'tanh': 31 | self.rnn_cell_list.append(Conv2dRNNCell(self.input_size, 32 | self.hidden_size, 33 | self.kernel_size, 34 | self.bias, 35 | "tanh")) 36 | for l in range(1, self.num_layers): 37 | 
self.rnn_cell_list.append(Conv2dRNNCell(self.hidden_size, 38 | self.hidden_size, 39 | self.kernel_size, 40 | self.bias, 41 | "tanh")) 42 | 43 | elif activation == 'relu': 44 | self.rnn_cell_list.append(Conv2dRNNCell(self.input_size, 45 | self.hidden_size, 46 | self.kernel_size, 47 | self.bias, 48 | "relu")) 49 | for l in range(1, self.num_layers): 50 | self.rnn_cell_list.append(Conv2dRNNCell(self.hidden_size, 51 | self.hidden_size, 52 | self.kernel_size, 53 | self.bias, 54 | "relu")) 55 | else: 56 | raise ValueError("Invalid activation.") 57 | 58 | self.conv = nn.Conv2d(in_channels=self.hidden_size, 59 | out_channels=self.output_size, 60 | kernel_size=self.kernel_size, 61 | padding=self.padding, 62 | bias=self.bias) 63 | 64 | def forward(self, input, hx=None): 65 | 66 | # Input of shape (batch_size, seqence length, input_size) 67 | # 68 | # Output of shape (batch_size, output_size) 69 | 70 | if hx is None: 71 | if torch.cuda.is_available(): 72 | h0 = Variable(torch.zeros(self.num_layers, input.size(0), self.hidden_size, input.size(-2), input.size(-1)).cuda()) 73 | else: 74 | h0 = Variable(torch.zeros(self.num_layers, input.size(0), self.hidden_size, input.size(-2), input.size(-1))) 75 | 76 | else: 77 | h0 = hx 78 | 79 | outs = [] 80 | 81 | hidden = list() 82 | for layer in range(self.num_layers): 83 | hidden.append(h0[layer]) 84 | 85 | for t in range(input.size(1)): 86 | 87 | for layer in range(self.num_layers): 88 | 89 | if layer == 0: 90 | hidden_l = self.rnn_cell_list[layer](input[:, t], hidden[layer]) 91 | else: 92 | hidden_l = self.rnn_cell_list[layer](hidden[layer - 1],hidden[layer]) 93 | hidden[layer] = hidden_l 94 | 95 | hidden[layer] = hidden_l 96 | 97 | outs.append(hidden_l) 98 | 99 | # Take only last time step. 
Modify for seq to seq 100 | out = outs[-1].squeeze() 101 | 102 | out = self.conv(out) 103 | 104 | return out 105 | 106 | class Conv2dLSTM(nn.Module): 107 | def __init__(self, input_size, hidden_size, kernel_size, num_layers, bias, output_size): 108 | super(Conv2dLSTM, self).__init__() 109 | 110 | self.input_size = input_size 111 | self.hidden_size = hidden_size 112 | 113 | if type(kernel_size) == tuple and len(kernel_size) == 2: 114 | self.kernel_size = kernel_size 115 | self.padding = (kernel_size[0] // 2, kernel_size[1] // 2) 116 | elif type(kernel_size) == int: 117 | self.kernel_size = (kernel_size, kernel_size) 118 | self.padding = (kernel_size // 2, kernel_size // 2) 119 | else: 120 | raise ValueError("Invalid kernel size.") 121 | 122 | self.num_layers = num_layers 123 | self.bias = bias 124 | self.output_size = output_size 125 | 126 | self.rnn_cell_list = nn.ModuleList() 127 | 128 | self.rnn_cell_list.append(Conv2dLSTMCell(self.input_size, 129 | self.hidden_size, 130 | self.kernel_size, 131 | self.bias)) 132 | for l in range(1, self.num_layers): 133 | self.rnn_cell_list.append(Conv2dLSTMCell(self.hidden_size, 134 | self.hidden_size, 135 | self.kernel_size, 136 | self.bias)) 137 | 138 | self.conv = nn.Conv2d(in_channels=self.hidden_size, 139 | out_channels=self.output_size, 140 | kernel_size=self.kernel_size, 141 | padding=self.padding, 142 | bias=self.bias) 143 | 144 | def forward(self, input, hx=None): 145 | 146 | # Input of shape (batch_size, seqence length , input_size) 147 | # 148 | # Output of shape (batch_size, output_size) 149 | 150 | if hx is None: 151 | if torch.cuda.is_available(): 152 | h0 = Variable(torch.zeros(self.num_layers, input.size(0), self.hidden_size, input.size(-2), input.size(-1)).cuda()) 153 | else: 154 | h0 = Variable(torch.zeros(self.num_layers, input.size(0), self.hidden_size, input.size(-2), input.size(-1))) 155 | else: 156 | h0 = hx 157 | 158 | outs = [] 159 | 160 | hidden = list() 161 | for layer in range(self.num_layers): 162 | 
hidden.append((h0[layer], h0[layer])) 163 | 164 | for t in range(input.size(1)): 165 | 166 | for layer in range(self.num_layers): 167 | 168 | if layer == 0: 169 | hidden_l = self.rnn_cell_list[layer]( 170 | input[:, t, :], 171 | (hidden[layer][0],hidden[layer][1]) 172 | ) 173 | else: 174 | hidden_l = self.rnn_cell_list[layer]( 175 | hidden[layer - 1][0], 176 | (hidden[layer][0], hidden[layer][1]) 177 | ) 178 | 179 | hidden[layer] = hidden_l 180 | 181 | outs.append(hidden_l[0]) 182 | 183 | out = outs[-1].squeeze() 184 | 185 | out = self.conv(out) 186 | 187 | return out 188 | 189 | class Conv2dGRU(nn.Module): 190 | def __init__(self, input_size, hidden_size, kernel_size, num_layers, bias, output_size): 191 | super(Conv2dGRU, self).__init__() 192 | 193 | self.input_size = input_size 194 | self.hidden_size = hidden_size 195 | 196 | if type(kernel_size) == tuple and len(kernel_size) == 2: 197 | self.kernel_size = kernel_size 198 | self.padding = (kernel_size[0] // 2, kernel_size[1] // 2) 199 | elif type(kernel_size) == int: 200 | self.kernel_size = (kernel_size, kernel_size) 201 | self.padding = (kernel_size // 2, kernel_size // 2) 202 | else: 203 | raise ValueError("Invalid kernel size.") 204 | 205 | self.num_layers = num_layers 206 | self.bias = bias 207 | self.output_size = output_size 208 | 209 | self.rnn_cell_list = nn.ModuleList() 210 | 211 | self.rnn_cell_list.append(Conv2dGRUCell(self.input_size, 212 | self.hidden_size, 213 | self.kernel_size, 214 | self.bias)) 215 | for l in range(1, self.num_layers): 216 | self.rnn_cell_list.append(Conv2dGRUCell(self.hidden_size, 217 | self.hidden_size, 218 | self.kernel_size, 219 | self.bias)) 220 | 221 | self.conv = nn.Conv2d(in_channels=self.hidden_size, 222 | out_channels=self.output_size, 223 | kernel_size=self.kernel_size, 224 | padding=self.padding, 225 | bias=self.bias) 226 | 227 | 228 | def forward(self, input, hx=None): 229 | 230 | # Input of shape (batch_size, seqence length, input_size) 231 | # 232 | # Output of 
shape (batch_size, output_size) 233 | 234 | if hx is None: 235 | if torch.cuda.is_available(): 236 | h0 = Variable(torch.zeros(self.num_layers, input.size(0), self.hidden_size, input.size(-2), input.size(-1)).cuda()) 237 | else: 238 | h0 = Variable(torch.zeros(self.num_layers, input.size(0), self.hidden_size, input.size(-2), input.size(-1))) 239 | 240 | else: 241 | h0 = hx 242 | 243 | outs = [] 244 | 245 | hidden = list() 246 | for layer in range(self.num_layers): 247 | hidden.append(h0[layer]) 248 | 249 | for t in range(input.size(1)): 250 | 251 | for layer in range(self.num_layers): 252 | 253 | if layer == 0: 254 | hidden_l = self.rnn_cell_list[layer](input[:, t, :], hidden[layer]) 255 | else: 256 | hidden_l = self.rnn_cell_list[layer](hidden[layer - 1],hidden[layer]) 257 | hidden[layer] = hidden_l 258 | 259 | hidden[layer] = hidden_l 260 | 261 | outs.append(hidden_l) 262 | 263 | # Take only last time step. Modify for seq to seq 264 | out = outs[-1].squeeze() 265 | 266 | out = self.conv(out) 267 | 268 | return out 269 | 270 | class Conv2dBidirRecurrentModel(nn.Module): 271 | def __init__(self, mode, input_size, hidden_size, kernel_size, num_layers, bias, output_size): 272 | super(Conv2dBidirRecurrentModel, self).__init__() 273 | 274 | self.mode = mode 275 | self.input_size = input_size 276 | self.hidden_size = hidden_size 277 | 278 | if type(kernel_size) == tuple and len(kernel_size) == 2: 279 | self.kernel_size = kernel_size 280 | self.padding = (kernel_size[0] // 2, kernel_size[1] // 2) 281 | elif type(kernel_size) == int: 282 | self.kernel_size = (kernel_size, kernel_size) 283 | self.padding = (kernel_size // 2, kernel_size // 2) 284 | else: 285 | raise ValueError("Invalid kernel size.") 286 | 287 | self.num_layers = num_layers 288 | self.bias = bias 289 | self.output_size = output_size 290 | 291 | self.rnn_cell_list = nn.ModuleList() 292 | 293 | if mode == 'LSTM': 294 | 295 | self.rnn_cell_list.append(Conv2dLSTMCell(self.input_size, 296 | self.hidden_size, 
297 | self.kernel_size, 298 | self.bias)) 299 | for l in range(1, self.num_layers): 300 | self.rnn_cell_list.append(Conv2dLSTMCell(self.hidden_size, 301 | self.hidden_size, 302 | self.kernel_size, 303 | self.bias)) 304 | 305 | elif mode == 'GRU': 306 | self.rnn_cell_list.append(Conv2dGRUCell(self.input_size, 307 | self.hidden_size, 308 | self.kernel_size, 309 | self.bias)) 310 | for l in range(1, self.num_layers): 311 | self.rnn_cell_list.append(Conv2dGRUCell(self.hidden_size, 312 | self.hidden_size, 313 | self.kernel_size, 314 | self.bias)) 315 | 316 | elif mode == 'RNN_TANH': 317 | self.rnn_cell_list.append(Conv2dRNNCell(self.input_size, 318 | self.hidden_size, 319 | self.kernel_size, 320 | self.bias, 321 | "tanh")) 322 | for l in range(1, self.num_layers): 323 | self.rnn_cell_list.append(Conv2dRNNCell(self.hidden_size, 324 | self.hidden_size, 325 | self.kernel_size, 326 | self.bias, 327 | "tanh")) 328 | 329 | elif mode == 'RNN_RELU': 330 | self.rnn_cell_list.append(Conv2dRNNCell(self.input_size, 331 | self.hidden_size, 332 | self.kernel_size, 333 | self.bias, 334 | "relu")) 335 | for l in range(1, self.num_layers): 336 | self.rnn_cell_list.append(Conv2dRNNCell(self.hidden_size, 337 | self.hidden_size, 338 | self.kernel_size, 339 | self.bias, 340 | "relu")) 341 | else: 342 | raise ValueError("Invalid RNN mode selected.") 343 | 344 | self.conv = nn.Conv2d(in_channels=self.hidden_size * 2, 345 | out_channels=self.output_size, 346 | kernel_size=self.kernel_size, 347 | padding=self.padding, 348 | bias=self.bias) 349 | 350 | def forward(self, input, hx=None): 351 | 352 | # Input of shape (batch_size, sequence length, input_size) 353 | # 354 | # Output of shape (batch_size, output_size) 355 | 356 | if torch.cuda.is_available(): 357 | h0 = Variable(torch.zeros(self.num_layers, input.size(0), self.hidden_size, input.size(-2), input.size(-1)).cuda()) 358 | else: 359 | h0 = Variable(torch.zeros(self.num_layers, input.size(0), self.hidden_size, input.size(-2), 
input.size(-1))) 360 | 361 | if torch.cuda.is_available(): 362 | hT = Variable(torch.zeros(self.num_layers, input.size(0), self.hidden_size, input.size(-2), input.size(-1)).cuda()) 363 | else: 364 | hT = Variable(torch.zeros(self.num_layers, input.size(0), self.hidden_size, input.size(-2), input.size(-1))) 365 | 366 | outs = [] 367 | outs_rev = [] 368 | 369 | hidden_forward = list() 370 | for layer in range(self.num_layers): 371 | if self.mode == 'LSTM': 372 | hidden_forward.append((h0[layer], h0[layer])) 373 | else: 374 | hidden_forward.append(h0[layer]) 375 | 376 | hidden_backward = list() 377 | for layer in range(self.num_layers): 378 | if self.mode == 'LSTM': 379 | hidden_backward.append((hT[layer], hT[layer])) 380 | else: 381 | hidden_backward.append(hT[layer]) 382 | 383 | for t in range(input.shape[1]): 384 | for layer in range(self.num_layers): 385 | 386 | if self.mode == 'LSTM': 387 | # If LSTM 388 | if layer == 0: 389 | # Forward net 390 | h_forward_l = self.rnn_cell_list[layer]( 391 | input[:, t, :], 392 | (hidden_forward[layer][0], hidden_forward[layer][1]) 393 | ) 394 | # Backward net 395 | h_back_l = self.rnn_cell_list[layer]( 396 | input[:, -(t + 1), :], 397 | (hidden_backward[layer][0], hidden_backward[layer][1]) 398 | ) 399 | else: 400 | # Forward net 401 | h_forward_l = self.rnn_cell_list[layer]( 402 | hidden_forward[layer - 1][0], 403 | (hidden_forward[layer][0], hidden_forward[layer][1]) 404 | ) 405 | # Backward net 406 | h_back_l = self.rnn_cell_list[layer]( 407 | hidden_backward[layer - 1][0], 408 | (hidden_backward[layer][0], hidden_backward[layer][1]) 409 | ) 410 | 411 | else: 412 | # If RNN{_TANH/_RELU} / GRU 413 | if layer == 0: 414 | # Forward net 415 | h_forward_l = self.rnn_cell_list[layer](input[:, t, :], hidden_forward[layer]) 416 | # Backward net 417 | h_back_l = self.rnn_cell_list[layer](input[:, -(t + 1), :], hidden_backward[layer]) 418 | else: 419 | # Forward net 420 | h_forward_l = self.rnn_cell_list[layer](hidden_forward[layer - 
1], hidden_forward[layer]) 421 | # Backward net 422 | h_back_l = self.rnn_cell_list[layer](hidden_backward[layer - 1], hidden_backward[layer]) 423 | 424 | 425 | hidden_forward[layer] = h_forward_l 426 | hidden_backward[layer] = h_back_l 427 | 428 | if self.mode == 'LSTM': 429 | 430 | outs.append(h_forward_l[0]) 431 | outs_rev.append(h_back_l[0]) 432 | 433 | else: 434 | outs.append(h_forward_l) 435 | outs_rev.append(h_back_l) 436 | 437 | # Take only last time step. Modify for seq to seq 438 | out = outs[-1].squeeze() 439 | out_rev = outs_rev[0].squeeze() 440 | out = torch.cat((out, out_rev), 1) 441 | 442 | out = self.conv(out) 443 | return out 444 | --------------------------------------------------------------------------------