├── .gitignore ├── LICENSE ├── README.md └── convolutional_rnn ├── __init__.py ├── functional.py ├── module.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | .idea 2 | __pycache__ 3 | *.pyc 4 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Naoyuki Kamo 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # pytorch_convolutional_rnn 2 | 3 | PyTorch implementations of convolutional RNNs other than my module already exist, for example: 4 | 5 | - https://github.com/ndrplz/ConvLSTM_pytorch 6 | - https://github.com/jacobkimmel/pytorch_convgru 7 | 8 | However, none of them supports variable-length tensors or bidirectional RNNs. 9 | 10 | I implemented ``AutogradConvRNN`` by referring to ``AutogradRNN`` at https://github.com/pytorch/pytorch/blob/master/torch/nn/_functions/rnn.py, so my convolutional RNN modules have a structure similar to ``torch.nn.RNN`` and support the features above, just as it does. 11 | 12 | The benefit of using ``AutogradConvRNN`` is not only that it gives my modules the same interface as ``torch.nn.RNN``, but also that it makes it very easy to implement many kinds of CRNN, such as ``CLSTM`` and ``CGRU``. 13 | 14 | ## Requirements 15 | - python3 (python2 is not supported because I prefer type annotations) 16 | - pytorch0.4.0, pytorch1.0.0 17 | 18 | ## Features 19 | - Implemented at the Python level, without any additional CUDA kernels or C++ code 20 | - Convolutional RNN, Convolutional LSTM, Convolutional Peephole LSTM, Convolutional GRU 21 | - Unidirectional, Bidirectional 22 | - 1d, 2d, 3d 23 | - Supports ``PackedSequence`` (i.e. variable-length tensors) 24 | - Supports both multi-layer RNNs and single-step RNN cells 25 | - Does not support different hidden sizes per layer (but this is easy to emulate by stacking 1-layer CRNNs; see the sketch below)
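A minimal sketch of that workaround, stacking two 1-layer ``Conv2dLSTM`` modules with different hidden sizes (the channel counts here are arbitrary, for illustration only):

```python
import torch
import convolutional_rnn

# Two independent 1-layer CRNNs play the role of one 2-layer CRNN
# whose hidden size differs per layer (8 channels, then 16).
layer1 = convolutional_rnn.Conv2dLSTM(in_channels=3, out_channels=8, kernel_size=3)
layer2 = convolutional_rnn.Conv2dLSTM(in_channels=8, out_channels=16, kernel_size=3)

x = torch.randn(5, 2, 3, 10, 10)  # (time, batch, channels, height, width)
y1, h1 = layer1(x)   # y1: (time, batch, 8, height, width)
y2, h2 = layer2(y1)  # y2: (time, batch, 16, height, width)
```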
26 | 27 | ## Example 28 | - With `pack_padded_sequence` 29 | ```python 30 | import torch 31 | import convolutional_rnn 32 | from torch.nn.utils.rnn import pack_padded_sequence 33 | 34 | in_channels = 2 35 | net = convolutional_rnn.Conv3dGRU(in_channels=in_channels, # Corresponds to input size 36 | out_channels=5, # Corresponds to hidden size 37 | kernel_size=(3, 4, 6), # Int or List[int] 38 | num_layers=2, 39 | bidirectional=True, 40 | dilation=2, stride=2, dropout=0.5) 41 | length = 3 42 | batchsize = 2 43 | lengths = [3, 1] 44 | shape = (10, 14, 18) 45 | x = pack_padded_sequence(torch.randn(length, batchsize, in_channels, *shape), lengths, batch_first=False) 46 | h = None 47 | y, h = net(x, h) 48 | ``` 49 | 50 | - Without `pack_padded_sequence` 51 | ```python 52 | import torch 53 | import convolutional_rnn 54 | 55 | 56 | in_channels = 2 57 | net = convolutional_rnn.Conv2dLSTM(in_channels=in_channels, # Corresponds to input size 58 | out_channels=5, # Corresponds to hidden size 59 | kernel_size=3, # Int or List[int] 60 | num_layers=2, 61 | bidirectional=True, 62 | dilation=2, stride=2, dropout=0.5, 63 | batch_first=True) 64 | length = 3 65 | batchsize = 2 66 | shape = (10, 14) 67 | x = torch.randn(batchsize, length, in_channels, *shape) 68 | h = None 69 | y, h = net(x, h) 70 | ``` 71 | 72 | - With `Cell` 73 | ```python 74 | import torch 75 | import convolutional_rnn 76 | cell = convolutional_rnn.Conv2dLSTMCell(in_channels=3, out_channels=5, kernel_size=3).cuda() 77 | time = 6 78 | input = torch.randn(time, 16, 3, 10, 10).cuda() 79 | output = [] 80 | for i in range(time): 81 | if i == 0: 82 | hx, cx = cell(input[i]) 83 | else: 84 | hx, cx = cell(input[i], (hx, cx)) 85 | output.append(hx) 86 | 87 | ``` 88 | -------------------------------------------------------------------------------- /convolutional_rnn/__init__.py: -------------------------------------------------------------------------------- 1 | from .module import Conv1dRNN 2 | from .module import Conv1dLSTM 3 | from .module import Conv1dPeepholeLSTM 4 | from .module import Conv1dGRU 5 | 6 | from .module import Conv2dRNN 7 | from .module import Conv2dLSTM 8 | from .module import Conv2dPeepholeLSTM 9 | from .module import Conv2dGRU 10 | 11 | from .module import Conv3dRNN 12 | from .module import Conv3dLSTM 13 | from .module import Conv3dPeepholeLSTM 14 | from .module import Conv3dGRU 15 | 16 | from .module import Conv1dRNNCell 17 | from .module import Conv1dLSTMCell 18 | from .module import Conv1dPeepholeLSTMCell 19 | from .module import Conv1dGRUCell 20 | 21 | from .module import Conv2dRNNCell 22 | from .module import Conv2dLSTMCell 23 | from .module import Conv2dPeepholeLSTMCell 24 | from .module import Conv2dGRUCell 25 | 26 | from .module import Conv3dRNNCell 27 | from .module import Conv3dLSTMCell 28 | from .module import Conv3dPeepholeLSTMCell 29 | from .module import Conv3dGRUCell 30 | -------------------------------------------------------------------------------- /convolutional_rnn/functional.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | try: 6 | # pytorch<=0.4.1 7 | from torch.nn._functions.thnn import rnnFusedPointwise as fusedBackend 8 | except
ImportError: 9 | fusedBackend = None 10 | 11 | from .utils import _single, _pair, _triple 12 | 13 | 14 | def RNNReLUCell(input, hidden, w_ih, w_hh, b_ih=None, b_hh=None, linear_func=None): 15 | """ Copied from torch.nn._functions.rnn and modified """ 16 | if linear_func is None: 17 | linear_func = F.linear 18 | hy = F.relu(linear_func(input, w_ih, b_ih) + linear_func(hidden, w_hh, b_hh)) 19 | return hy 20 | 21 | 22 | def RNNTanhCell(input, hidden, w_ih, w_hh, b_ih=None, b_hh=None, linear_func=None): 23 | """ Copied from torch.nn._functions.rnn and modified """ 24 | if linear_func is None: 25 | linear_func = F.linear 26 | hy = torch.tanh(linear_func(input, w_ih, b_ih) + linear_func(hidden, w_hh, b_hh)) 27 | return hy 28 | 29 | 30 | def LSTMCell(input, hidden, w_ih, w_hh, b_ih=None, b_hh=None, linear_func=None): 31 | """ Copied from torch.nn._functions.rnn and modified """ 32 | if linear_func is None: 33 | linear_func = F.linear 34 | if input.is_cuda and linear_func is F.linear and fusedBackend is not None: 35 | igates = linear_func(input, w_ih) 36 | hgates = linear_func(hidden[0], w_hh) 37 | state = fusedBackend.LSTMFused.apply 38 | return state(igates, hgates, hidden[1]) if b_ih is None else state(igates, hgates, hidden[1], b_ih, b_hh) 39 | 40 | hx, cx = hidden 41 | gates = linear_func(input, w_ih, b_ih) + linear_func(hx, w_hh, b_hh) 42 | ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1) 43 | 44 | ingate = torch.sigmoid(ingate) 45 | forgetgate = torch.sigmoid(forgetgate) 46 | cellgate = torch.tanh(cellgate) 47 | outgate = torch.sigmoid(outgate) 48 | 49 | cy = (forgetgate * cx) + (ingate * cellgate) 50 | hy = outgate * torch.tanh(cy) 51 | 52 | return hy, cy 53 | 54 | 55 | def PeepholeLSTMCell(input, hidden, w_ih, w_hh, w_pi, w_pf, w_po, 56 | b_ih=None, b_hh=None, linear_func=None): 57 | if linear_func is None: 58 | linear_func = F.linear 59 | hx, cx = hidden 60 | gates = linear_func(input, w_ih, b_ih) + linear_func(hx, w_hh, b_hh) 61 | ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1) 62 | 63 | ingate += linear_func(cx, w_pi) 64 | forgetgate += linear_func(cx, w_pf) 65 | ingate = torch.sigmoid(ingate) 66 | forgetgate = torch.sigmoid(forgetgate) 67 | cellgate = torch.tanh(cellgate) 68 | 69 | cy = (forgetgate * cx) + (ingate * cellgate) 70 | outgate += linear_func(cy, w_po) 71 | outgate = torch.sigmoid(outgate) 72 | 73 | hy = outgate * torch.tanh(cy) 74 | 75 | return hy, cy 76 | 77 | 78 | def GRUCell(input, hidden, w_ih, w_hh, b_ih=None, b_hh=None, linear_func=None): 79 | """ Copied from torch.nn._functions.rnn and modified """ 80 | if linear_func is None: 81 | linear_func = F.linear 82 | if input.is_cuda and linear_func is F.linear and fusedBackend is not None: 83 | gi = linear_func(input, w_ih) 84 | gh = linear_func(hidden, w_hh) 85 | state = fusedBackend.GRUFused.apply 86 | return state(gi, gh, hidden) if b_ih is None else state(gi, gh, hidden, b_ih, b_hh) 87 | gi = linear_func(input, w_ih, b_ih) 88 | gh = linear_func(hidden, w_hh, b_hh) 89 | i_r, i_i, i_n = gi.chunk(3, 1) 90 | h_r, h_i, h_n = gh.chunk(3, 1) 91 | 92 | resetgate = torch.sigmoid(i_r + h_r) 93 | inputgate = torch.sigmoid(i_i + h_i) 94 | newgate = torch.tanh(i_n + resetgate * h_n) 95 | hy = newgate + inputgate * (hidden - newgate) 96 | 97 | return hy 98 | 99 | 100 | def StackedRNN(inners, num_layers, lstm=False, dropout=0, train=True): 101 | """ Copied from torch.nn._functions.rnn and modified """ 102 | 103 | num_directions = len(inners) 104 | total_layers = num_layers * num_directions 105 | 106 | def 
forward(input, hidden, weight, batch_sizes): 107 | assert(len(weight) == total_layers) 108 | next_hidden = [] 109 | ch_dim = input.dim() - weight[0][0].dim() + 1 110 | 111 | if lstm: 112 | hidden = list(zip(*hidden)) 113 | 114 | for i in range(num_layers): 115 | all_output = [] 116 | for j, inner in enumerate(inners): 117 | l = i * num_directions + j 118 | 119 | hy, output = inner(input, hidden[l], weight[l], batch_sizes) 120 | next_hidden.append(hy) 121 | all_output.append(output) 122 | 123 | input = torch.cat(all_output, ch_dim) 124 | 125 | if dropout != 0 and i < num_layers - 1: 126 | input = F.dropout(input, p=dropout, training=train, inplace=False) 127 | 128 | if lstm: 129 | next_h, next_c = zip(*next_hidden) 130 | next_hidden = ( 131 | torch.cat(next_h, 0).view(total_layers, *next_h[0].size()), 132 | torch.cat(next_c, 0).view(total_layers, *next_c[0].size()) 133 | ) 134 | else: 135 | next_hidden = torch.cat(next_hidden, 0).view( 136 | total_layers, *next_hidden[0].size()) 137 | 138 | return next_hidden, input 139 | 140 | return forward 141 | 142 | 143 | def Recurrent(inner, reverse=False): 144 | """ Copied from torch.nn._functions.rnn without any modification """ 145 | def forward(input, hidden, weight, batch_sizes): 146 | output = [] 147 | steps = range(input.size(0) - 1, -1, -1) if reverse else range(input.size(0)) 148 | for i in steps: 149 | hidden = inner(input[i], hidden, *weight) 150 | # hack to handle LSTM 151 | output.append(hidden[0] if isinstance(hidden, tuple) else hidden) 152 | 153 | if reverse: 154 | output.reverse() 155 | output = torch.cat(output, 0).view(input.size(0), *output[0].size()) 156 | 157 | return hidden, output 158 | 159 | return forward 160 | 161 | 162 | def variable_recurrent_factory(inner, reverse=False): 163 | """ Copied from torch.nn._functions.rnn without any modification """ 164 | if reverse: 165 | return VariableRecurrentReverse(inner) 166 | else: 167 | return VariableRecurrent(inner) 168 | 169 | 170 | def VariableRecurrent(inner): 171 | """ Copied from torch.nn._functions.rnn without any modification """ 172 | def forward(input, hidden, weight, batch_sizes): 173 | output = [] 174 | input_offset = 0 175 | last_batch_size = batch_sizes[0] 176 | hiddens = [] 177 | flat_hidden = not isinstance(hidden, tuple) 178 | if flat_hidden: 179 | hidden = (hidden,) 180 | for batch_size in batch_sizes: 181 | step_input = input[input_offset:input_offset + batch_size] 182 | input_offset += batch_size 183 | 184 | dec = last_batch_size - batch_size 185 | if dec > 0: 186 | hiddens.append(tuple(h[-dec:] for h in hidden)) 187 | hidden = tuple(h[:-dec] for h in hidden) 188 | last_batch_size = batch_size 189 | 190 | if flat_hidden: 191 | hidden = (inner(step_input, hidden[0], *weight),) 192 | else: 193 | hidden = inner(step_input, hidden, *weight) 194 | 195 | output.append(hidden[0]) 196 | hiddens.append(hidden) 197 | hiddens.reverse() 198 | 199 | hidden = tuple(torch.cat(h, 0) for h in zip(*hiddens)) 200 | assert hidden[0].size(0) == batch_sizes[0] 201 | if flat_hidden: 202 | hidden = hidden[0] 203 | output = torch.cat(output, 0) 204 | 205 | return hidden, output 206 | 207 | return forward 208 | 209 | 210 | def VariableRecurrentReverse(inner): 211 | """ Copied from torch.nn._functions.rnn without any modification """ 212 | def forward(input, hidden, weight, batch_sizes): 213 | output = [] 214 | input_offset = input.size(0) 215 | last_batch_size = batch_sizes[-1] 216 | initial_hidden = hidden 217 | flat_hidden = not isinstance(hidden, tuple) 218 | if flat_hidden: 219 | 
hidden = (hidden,) 220 | initial_hidden = (initial_hidden,) 221 | hidden = tuple(h[:batch_sizes[-1]] for h in hidden) 222 | for i in reversed(range(len(batch_sizes))): 223 | batch_size = batch_sizes[i] 224 | inc = batch_size - last_batch_size 225 | if inc > 0: 226 | hidden = tuple(torch.cat((h, ih[last_batch_size:batch_size]), 0) 227 | for h, ih in zip(hidden, initial_hidden)) 228 | last_batch_size = batch_size 229 | step_input = input[input_offset - batch_size:input_offset] 230 | input_offset -= batch_size 231 | 232 | if flat_hidden: 233 | hidden = (inner(step_input, hidden[0], *weight),) 234 | else: 235 | hidden = inner(step_input, hidden, *weight) 236 | output.append(hidden[0]) 237 | 238 | output.reverse() 239 | output = torch.cat(output, 0) 240 | if flat_hidden: 241 | hidden = hidden[0] 242 | return hidden, output 243 | 244 | return forward 245 | 246 | 247 | def ConvNdWithSamePadding(convndim=2, stride=1, dilation=1, groups=1): 248 | def forward(input, w, b=None): 249 | if convndim == 1: 250 | ntuple = _single 251 | elif convndim == 2: 252 | ntuple = _pair 253 | elif convndim == 3: 254 | ntuple = _triple 255 | else: 256 | raise ValueError('convndim must be 1, 2, or 3, but got {}'.format(convndim)) 257 | 258 | if input.dim() != convndim + 2: 259 | raise RuntimeError('Input dim must be {}, but got {}'.format(convndim + 2, input.dim())) 260 | if w.dim() != convndim + 2: 261 | raise RuntimeError('w dim must be {}, but got {}'.format(convndim + 2, w.dim())) 262 | 263 | insize = input.shape[2:] 264 | kernel_size = w.shape[2:] 265 | _stride = ntuple(stride) 266 | _dilation = ntuple(dilation) 267 | 268 | ps = [(i + 1 - h + s * (h - 1) + d * (k - 1)) // 2 269 | for h, k, s, d in list(zip(insize, kernel_size, _stride, _dilation))[::-1] for i in range(2)] 270 | # Zero-pad (possibly asymmetrically) so that the output has the same spatial shape as the input 271 | input = F.pad(input, ps, 'constant', 0) 272 | return getattr(F, 'conv{}d'.format(convndim))( 273 | input, w, b, stride=_stride, padding=ntuple(0), dilation=_dilation, groups=groups) 274 | return forward 275 | 276 | 277 | def _conv_cell_helper(mode, convndim=2, stride=1, dilation=1, groups=1): 278 | linear_func = ConvNdWithSamePadding(convndim=convndim, stride=stride, dilation=dilation, groups=groups) 279 | 280 | if mode == 'RNN_RELU': 281 | cell = partial(RNNReLUCell, linear_func=linear_func) 282 | elif mode == 'RNN_TANH': 283 | cell = partial(RNNTanhCell, linear_func=linear_func) 284 | elif mode == 'LSTM': 285 | cell = partial(LSTMCell, linear_func=linear_func) 286 | elif mode == 'GRU': 287 | cell = partial(GRUCell, linear_func=linear_func) 288 | elif mode == 'PeepholeLSTM': 289 | cell = partial(PeepholeLSTMCell, linear_func=linear_func) 290 | else: 291 | raise ValueError('Unknown mode: {}'.format(mode)) 292 | return cell 293 | 294 | 295 | def AutogradConvRNN( 296 | mode, num_layers=1, batch_first=False, 297 | dropout=0, train=True, bidirectional=False, variable_length=False, 298 | convndim=2, stride=1, dilation=1, groups=1): 299 | """ Copied from torch.nn._functions.rnn and modified """ 300 | cell = _conv_cell_helper(mode, convndim=convndim, stride=stride, dilation=dilation, groups=groups) 301 | 302 | rec_factory = variable_recurrent_factory if variable_length else Recurrent 303 | 304 | if bidirectional: 305 | layer = (rec_factory(cell), rec_factory(cell, reverse=True)) 306 | else: 307 | layer = (rec_factory(cell),) 308 | 309 | func = StackedRNN(layer, num_layers, (mode in ('LSTM', 'PeepholeLSTM')), dropout=dropout, train=train) 310 | 311 | def forward(input, weight,
hidden, batch_sizes): 312 | if batch_first and batch_sizes is None: 313 | input = input.transpose(0, 1) 314 | 315 | nexth, output = func(input, hidden, weight, batch_sizes) 316 | 317 | if batch_first and batch_sizes is None: 318 | output = output.transpose(0, 1) 319 | 320 | return output, nexth 321 | 322 | return forward 323 | -------------------------------------------------------------------------------- /convolutional_rnn/module.py: -------------------------------------------------------------------------------- 1 | import math 2 | from typing import Union, Sequence 3 | 4 | import torch 5 | from torch.nn import Parameter 6 | from torch.nn.utils.rnn import PackedSequence 7 | 8 | from .functional import AutogradConvRNN, _conv_cell_helper 9 | from .utils import _single, _pair, _triple 10 | 11 | 12 | class ConvNdRNNBase(torch.nn.Module): 13 | def __init__(self, 14 | mode: str, 15 | in_channels: int, 16 | out_channels: int, 17 | kernel_size: Union[int, Sequence[int]], 18 | num_layers: int=1, 19 | bias: bool=True, 20 | batch_first: bool=False, 21 | dropout: float=0., 22 | bidirectional: bool=False, 23 | convndim: int=2, 24 | stride: Union[int, Sequence[int]]=1, 25 | dilation: Union[int, Sequence[int]]=1, 26 | groups: int=1): 27 | super().__init__() 28 | self.mode = mode 29 | self.in_channels = in_channels 30 | self.out_channels = out_channels 31 | self.num_layers = num_layers 32 | self.bias = bias 33 | self.batch_first = batch_first 34 | self.dropout = dropout 35 | self.bidirectional = bidirectional 36 | self.convndim = convndim 37 | 38 | if convndim == 1: 39 | ntuple = _single 40 | elif convndim == 2: 41 | ntuple = _pair 42 | elif convndim == 3: 43 | ntuple = _triple 44 | else: 45 | raise ValueError('convndim must be 1, 2, or 3, but got {}'.format(convndim)) 46 | 47 | self.kernel_size = ntuple(kernel_size) 48 | self.stride = ntuple(stride) 49 | self.dilation = ntuple(dilation) 50 | 51 | self.groups = groups 52 | 53 | num_directions = 2 if bidirectional else 1 54 | 55 | if mode in ('LSTM', 'PeepholeLSTM'): 56 | gate_size = 4 * out_channels 57 | elif mode == 'GRU': 58 | gate_size = 3 * out_channels 59 | else: 60 | gate_size = out_channels 61 | 62 | self._all_weights = [] 63 | for layer in range(num_layers): 64 | for direction in range(num_directions): 65 | layer_input_size = in_channels if layer == 0 else out_channels * num_directions 66 | w_ih = Parameter(torch.Tensor(gate_size, layer_input_size // groups, *self.kernel_size)) 67 | w_hh = Parameter(torch.Tensor(gate_size, out_channels // groups, *self.kernel_size)) 68 | 69 | b_ih = Parameter(torch.Tensor(gate_size)) 70 | b_hh = Parameter(torch.Tensor(gate_size)) 71 | 72 | if mode == 'PeepholeLSTM': 73 | w_pi = Parameter(torch.Tensor(out_channels, out_channels // groups, *self.kernel_size)) 74 | w_pf = Parameter(torch.Tensor(out_channels, out_channels // groups, *self.kernel_size)) 75 | w_po = Parameter(torch.Tensor(out_channels, out_channels // groups, *self.kernel_size)) 76 | layer_params = (w_ih, w_hh, w_pi, w_pf, w_po, b_ih, b_hh) 77 | param_names = ['weight_ih_l{}{}', 'weight_hh_l{}{}', 78 | 'weight_pi_l{}{}', 'weight_pf_l{}{}', 'weight_po_l{}{}'] 79 | else: 80 | layer_params = (w_ih, w_hh, b_ih, b_hh) 81 | param_names = ['weight_ih_l{}{}', 'weight_hh_l{}{}'] 82 | if bias: 83 | param_names += ['bias_ih_l{}{}', 'bias_hh_l{}{}'] 84 | 85 | suffix = '_reverse' if direction == 1 else '' 86 | param_names = [x.format(layer, suffix) for x in param_names] 87 | 88 | for name, param in zip(param_names, layer_params): 89 | setattr(self, name, param) 
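# Parameters are registered under torch.nn.RNN-style names ('weight_ih_l{layer}' etc., with a
# '_reverse' suffix for the backward direction); _all_weights keeps these name lists per
# layer/direction so that the all_weights property can look the tensors back up.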
90 | self._all_weights.append(param_names) 91 | 92 | self.reset_parameters() 93 | 94 | def reset_parameters(self): 95 | stdv = 1.0 / math.sqrt(self.out_channels) 96 | for weight in self.parameters(): 97 | weight.data.uniform_(-stdv, stdv) 98 | 99 | def check_forward_args(self, input, hidden, batch_sizes): 100 | is_input_packed = batch_sizes is not None 101 | expected_input_dim = (2 if is_input_packed else 3) + self.convndim 102 | if input.dim() != expected_input_dim: 103 | raise RuntimeError( 104 | 'input must have {} dimensions, got {}'.format( 105 | expected_input_dim, input.dim())) 106 | ch_dim = 1 if is_input_packed else 2 107 | if self.in_channels != input.size(ch_dim): 108 | raise RuntimeError( 109 | 'input.size({}) must be equal to in_channels. Expected {}, got {}'.format( 110 | ch_dim, self.in_channels, input.size(ch_dim))) 111 | 112 | if is_input_packed: 113 | mini_batch = int(batch_sizes[0]) 114 | else: 115 | mini_batch = input.size(0) if self.batch_first else input.size(1) 116 | 117 | num_directions = 2 if self.bidirectional else 1 118 | expected_hidden_size = (self.num_layers * num_directions, 119 | mini_batch, self.out_channels) + input.shape[ch_dim + 1:] 120 | 121 | def check_hidden_size(hx, expected_hidden_size, msg='Expected hidden size {}, got {}'): 122 | if tuple(hx.size()) != expected_hidden_size: 123 | raise RuntimeError(msg.format(expected_hidden_size, tuple(hx.size()))) 124 | 125 | if self.mode in ('LSTM', 'PeepholeLSTM'): 126 | check_hidden_size(hidden[0], expected_hidden_size, 127 | 'Expected hidden[0] size {}, got {}') 128 | check_hidden_size(hidden[1], expected_hidden_size, 129 | 'Expected hidden[1] size {}, got {}') 130 | else: 131 | check_hidden_size(hidden, expected_hidden_size) 132 | 133 | def forward(self, input, hx=None): 134 | is_packed = isinstance(input, PackedSequence) 135 | if is_packed: 136 | input, batch_sizes = input 137 | max_batch_size = batch_sizes[0] 138 | insize = input.shape[2:] 139 | else: 140 | batch_sizes = None 141 | max_batch_size = input.size(0) if self.batch_first else input.size(1) 142 | insize = input.shape[3:] 143 | 144 | if hx is None: 145 | num_directions = 2 if self.bidirectional else 1 146 | hx = input.new_zeros(self.num_layers * num_directions, max_batch_size, self.out_channels, 147 | *insize, requires_grad=False) 148 | if self.mode in ('LSTM', 'PeepholeLSTM'): 149 | hx = (hx, hx) 150 | 151 | self.check_forward_args(input, hx, batch_sizes) 152 | func = AutogradConvRNN( 153 | self.mode, 154 | num_layers=self.num_layers, 155 | batch_first=self.batch_first, 156 | dropout=self.dropout, 157 | train=self.training, 158 | bidirectional=self.bidirectional, 159 | variable_length=batch_sizes is not None, 160 | convndim=self.convndim, 161 | stride=self.stride, 162 | dilation=self.dilation, 163 | groups=self.groups 164 | ) 165 | output, hidden = func(input, self.all_weights, hx, batch_sizes) 166 | if is_packed: 167 | output = PackedSequence(output, batch_sizes) 168 | return output, hidden 169 | 170 | def extra_repr(self): 171 | s = ('{in_channels}, {out_channels}, kernel_size={kernel_size}' 172 | ', stride={stride}') 173 | if self.dilation != (1,) * len(self.dilation): 174 | s += ', dilation={dilation}' 175 | if self.groups != 1: 176 | s += ', groups={groups}' 177 | if self.num_layers != 1: 178 | s += ', num_layers={num_layers}' 179 | if self.bias is not True: 180 | s += ', bias={bias}' 181 | if self.batch_first is not False: 182 | s += ', batch_first={batch_first}' 183 | if self.dropout != 0: 184 | s += ', dropout={dropout}' 185 | if
self.bidirectional is not False: 186 | s += ', bidirectional={bidirectional}' 187 | return s.format(**self.__dict__) 188 | 189 | def __setstate__(self, d): 190 | super(ConvNdRNNBase, self).__setstate__(d) 191 | if 'all_weights' in d: 192 | self._all_weights = d['all_weights'] 193 | if isinstance(self._all_weights[0][0], str): 194 | return 195 | num_layers = self.num_layers 196 | num_directions = 2 if self.bidirectional else 1 197 | self._all_weights = [] 198 | for layer in range(num_layers): 199 | for direction in range(num_directions): 200 | suffix = '_reverse' if direction == 1 else '' 201 | if self.mode == 'PeepholeLSTM': 202 | weights = ['weight_ih_l{}{}', 'weight_hh_l{}{}', 203 | 'weight_pi_l{}{}', 'weight_pf_l{}{}', 'weight_po_l{}{}', 204 | 'bias_ih_l{}{}', 'bias_hh_l{}{}'] 205 | else: 206 | weights = ['weight_ih_l{}{}', 'weight_hh_l{}{}', 207 | 'bias_ih_l{}{}', 'bias_hh_l{}{}'] 208 | weights = [x.format(layer, suffix) for x in weights] 209 | if self.bias: 210 | self._all_weights += [weights] 211 | else: 212 | self._all_weights += [weights[:len(weights) // 2]] 213 | 214 | @property 215 | def all_weights(self): 216 | return [[getattr(self, weight) for weight in weights] for weights in self._all_weights] 217 | 218 | 219 | class Conv1dRNN(ConvNdRNNBase): 220 | def __init__(self, 221 | in_channels: int, 222 | out_channels: int, 223 | kernel_size: Union[int, Sequence[int]], 224 | nonlinearity: str='tanh', 225 | num_layers: int=1, 226 | bias: bool=True, 227 | batch_first: bool=False, 228 | dropout: float=0., 229 | bidirectional: bool=False, 230 | stride: Union[int, Sequence[int]]=1, 231 | dilation: Union[int, Sequence[int]]=1, 232 | groups: int=1): 233 | if nonlinearity == 'tanh': 234 | mode = 'RNN_TANH' 235 | elif nonlinearity == 'relu': 236 | mode = 'RNN_RELU' 237 | else: 238 | raise ValueError("Unknown nonlinearity '{}'".format(nonlinearity)) 239 | super().__init__( 240 | mode=mode, 241 | in_channels=in_channels, 242 | out_channels=out_channels, 243 | kernel_size=kernel_size, 244 | num_layers=num_layers, 245 | bias=bias, 246 | batch_first=batch_first, 247 | dropout=dropout, 248 | bidirectional=bidirectional, 249 | convndim=1, 250 | stride=stride, 251 | dilation=dilation, 252 | groups=groups) 253 | 254 | 255 | class Conv1dPeepholeLSTM(ConvNdRNNBase): 256 | def __init__(self, 257 | in_channels: int, 258 | out_channels: int, 259 | kernel_size: Union[int, Sequence[int]], 260 | num_layers: int=1, 261 | bias: bool=True, 262 | batch_first: bool=False, 263 | dropout: float=0., 264 | bidirectional: bool=False, 265 | stride: Union[int, Sequence[int]]=1, 266 | dilation: Union[int, Sequence[int]]=1, 267 | groups: int=1): 268 | super().__init__( 269 | mode='PeepholeLSTM', 270 | in_channels=in_channels, 271 | out_channels=out_channels, 272 | kernel_size=kernel_size, 273 | num_layers=num_layers, 274 | bias=bias, 275 | batch_first=batch_first, 276 | dropout=dropout, 277 | bidirectional=bidirectional, 278 | convndim=1, 279 | stride=stride, 280 | dilation=dilation, 281 | groups=groups) 282 | 283 | 284 | class Conv1dLSTM(ConvNdRNNBase): 285 | def __init__(self, 286 | in_channels: int, 287 | out_channels: int, 288 | kernel_size: Union[int, Sequence[int]], 289 | num_layers: int=1, 290 | bias: bool=True, 291 | batch_first: bool=False, 292 | dropout: float=0., 293 | bidirectional: bool=False, 294 | stride: Union[int, Sequence[int]]=1, 295 | dilation: Union[int, Sequence[int]]=1, 296 | groups: int=1): 297 | super().__init__( 298 | mode='LSTM', 299 | in_channels=in_channels, 300 | out_channels=out_channels, 301 
| kernel_size=kernel_size, 302 | num_layers=num_layers, 303 | bias=bias, 304 | batch_first=batch_first, 305 | dropout=dropout, 306 | bidirectional=bidirectional, 307 | convndim=1, 308 | stride=stride, 309 | dilation=dilation, 310 | groups=groups) 311 | 312 | 313 | class Conv1dGRU(ConvNdRNNBase): 314 | def __init__(self, 315 | in_channels: int, 316 | out_channels: int, 317 | kernel_size: Union[int, Sequence[int]], 318 | num_layers: int=1, 319 | bias: bool=True, 320 | batch_first: bool=False, 321 | dropout: float=0., 322 | bidirectional: bool=False, 323 | stride: Union[int, Sequence[int]]=1, 324 | dilation: Union[int, Sequence[int]]=1, 325 | groups: int=1): 326 | super().__init__( 327 | mode='GRU', 328 | in_channels=in_channels, 329 | out_channels=out_channels, 330 | kernel_size=kernel_size, 331 | num_layers=num_layers, 332 | bias=bias, 333 | batch_first=batch_first, 334 | dropout=dropout, 335 | bidirectional=bidirectional, 336 | convndim=1, 337 | stride=stride, 338 | dilation=dilation, 339 | groups=groups) 340 | 341 | 342 | class Conv2dRNN(ConvNdRNNBase): 343 | def __init__(self, 344 | in_channels: int, 345 | out_channels: int, 346 | kernel_size: Union[int, Sequence[int]], 347 | nonlinearity: str='tanh', 348 | num_layers: int=1, 349 | bias: bool=True, 350 | batch_first: bool=False, 351 | dropout: float=0., 352 | bidirectional: bool=False, 353 | stride: Union[int, Sequence[int]]=1, 354 | dilation: Union[int, Sequence[int]]=1, 355 | groups: int=1): 356 | if nonlinearity == 'tanh': 357 | mode = 'RNN_TANH' 358 | elif nonlinearity == 'relu': 359 | mode = 'RNN_RELU' 360 | else: 361 | raise ValueError("Unknown nonlinearity '{}'".format(nonlinearity)) 362 | super().__init__( 363 | mode=mode, 364 | in_channels=in_channels, 365 | out_channels=out_channels, 366 | kernel_size=kernel_size, 367 | num_layers=num_layers, 368 | bias=bias, 369 | batch_first=batch_first, 370 | dropout=dropout, 371 | bidirectional=bidirectional, 372 | convndim=2, 373 | stride=stride, 374 | dilation=dilation, 375 | groups=groups) 376 | 377 | 378 | class Conv2dLSTM(ConvNdRNNBase): 379 | def __init__(self, 380 | in_channels: int, 381 | out_channels: int, 382 | kernel_size: Union[int, Sequence[int]], 383 | num_layers: int=1, 384 | bias: bool=True, 385 | batch_first: bool=False, 386 | dropout: float=0., 387 | bidirectional: bool=False, 388 | stride: Union[int, Sequence[int]]=1, 389 | dilation: Union[int, Sequence[int]]=1, 390 | groups: int=1): 391 | super().__init__( 392 | mode='LSTM', 393 | in_channels=in_channels, 394 | out_channels=out_channels, 395 | kernel_size=kernel_size, 396 | num_layers=num_layers, 397 | bias=bias, 398 | batch_first=batch_first, 399 | dropout=dropout, 400 | bidirectional=bidirectional, 401 | convndim=2, 402 | stride=stride, 403 | dilation=dilation, 404 | groups=groups) 405 | 406 | 407 | class Conv2dPeepholeLSTM(ConvNdRNNBase): 408 | def __init__(self, 409 | in_channels: int, 410 | out_channels: int, 411 | kernel_size: Union[int, Sequence[int]], 412 | num_layers: int=1, 413 | bias: bool=True, 414 | batch_first: bool=False, 415 | dropout: float=0., 416 | bidirectional: bool=False, 417 | stride: Union[int, Sequence[int]]=1, 418 | dilation: Union[int, Sequence[int]]=1, 419 | groups: int=1): 420 | super().__init__( 421 | mode='PeepholeLSTM', 422 | in_channels=in_channels, 423 | out_channels=out_channels, 424 | kernel_size=kernel_size, 425 | num_layers=num_layers, 426 | bias=bias, 427 | batch_first=batch_first, 428 | dropout=dropout, 429 | bidirectional=bidirectional, 430 | convndim=2, 431 | stride=stride, 432 | 
dilation=dilation, 433 | groups=groups) 434 | 435 | 436 | class Conv2dGRU(ConvNdRNNBase): 437 | def __init__(self, 438 | in_channels: int, 439 | out_channels: int, 440 | kernel_size: Union[int, Sequence[int]], 441 | num_layers: int=1, 442 | bias: bool=True, 443 | batch_first: bool=False, 444 | dropout: float=0., 445 | bidirectional: bool=False, 446 | stride: Union[int, Sequence[int]]=1, 447 | dilation: Union[int, Sequence[int]]=1, 448 | groups: int=1): 449 | super().__init__( 450 | mode='GRU', 451 | in_channels=in_channels, 452 | out_channels=out_channels, 453 | kernel_size=kernel_size, 454 | num_layers=num_layers, 455 | bias=bias, 456 | batch_first=batch_first, 457 | dropout=dropout, 458 | bidirectional=bidirectional, 459 | convndim=2, 460 | stride=stride, 461 | dilation=dilation, 462 | groups=groups) 463 | 464 | 465 | class Conv3dRNN(ConvNdRNNBase): 466 | def __init__(self, 467 | in_channels: int, 468 | out_channels: int, 469 | kernel_size: Union[int, Sequence[int]], 470 | nonlinearity: str='tanh', 471 | num_layers: int=1, 472 | bias: bool=True, 473 | batch_first: bool=False, 474 | dropout: float=0., 475 | bidirectional: bool=False, 476 | stride: Union[int, Sequence[int]]=1, 477 | dilation: Union[int, Sequence[int]]=1, 478 | groups: int=1): 479 | if nonlinearity == 'tanh': 480 | mode = 'RNN_TANH' 481 | elif nonlinearity == 'relu': 482 | mode = 'RNN_RELU' 483 | else: 484 | raise ValueError("Unknown nonlinearity '{}'".format(nonlinearity)) 485 | super().__init__( 486 | mode=mode, 487 | in_channels=in_channels, 488 | out_channels=out_channels, 489 | kernel_size=kernel_size, 490 | num_layers=num_layers, 491 | bias=bias, 492 | batch_first=batch_first, 493 | dropout=dropout, 494 | bidirectional=bidirectional, 495 | convndim=3, 496 | stride=stride, 497 | dilation=dilation, 498 | groups=groups) 499 | 500 | 501 | class Conv3dLSTM(ConvNdRNNBase): 502 | def __init__(self, 503 | in_channels: int, 504 | out_channels: int, 505 | kernel_size: Union[int, Sequence[int]], 506 | num_layers: int=1, 507 | bias: bool=True, 508 | batch_first: bool=False, 509 | dropout: float=0., 510 | bidirectional: bool=False, 511 | stride: Union[int, Sequence[int]]=1, 512 | dilation: Union[int, Sequence[int]]=1, 513 | groups: int=1): 514 | super().__init__( 515 | mode='LSTM', 516 | in_channels=in_channels, 517 | out_channels=out_channels, 518 | kernel_size=kernel_size, 519 | num_layers=num_layers, 520 | bias=bias, 521 | batch_first=batch_first, 522 | dropout=dropout, 523 | bidirectional=bidirectional, 524 | convndim=3, 525 | stride=stride, 526 | dilation=dilation, 527 | groups=groups) 528 | 529 | 530 | class Conv3dPeepholeLSTM(ConvNdRNNBase): 531 | def __init__(self, 532 | in_channels: int, 533 | out_channels: int, 534 | kernel_size: Union[int, Sequence[int]], 535 | num_layers: int=1, 536 | bias: bool=True, 537 | batch_first: bool=False, 538 | dropout: float=0., 539 | bidirectional: bool=False, 540 | stride: Union[int, Sequence[int]]=1, 541 | dilation: Union[int, Sequence[int]]=1, 542 | groups: int=1): 543 | super().__init__( 544 | mode='PeepholeLSTM', 545 | in_channels=in_channels, 546 | out_channels=out_channels, 547 | kernel_size=kernel_size, 548 | num_layers=num_layers, 549 | bias=bias, 550 | batch_first=batch_first, 551 | dropout=dropout, 552 | bidirectional=bidirectional, 553 | convndim=3, 554 | stride=stride, 555 | dilation=dilation, 556 | groups=groups) 557 | 558 | 559 | class Conv3dGRU(ConvNdRNNBase): 560 | def __init__(self, 561 | in_channels: int, 562 | out_channels: int, 563 | kernel_size: Union[int, 
Sequence[int]], 564 | num_layers: int=1, 565 | bias: bool=True, 566 | batch_first: bool=False, 567 | dropout: float=0., 568 | bidirectional: bool=False, 569 | stride: Union[int, Sequence[int]]=1, 570 | dilation: Union[int, Sequence[int]]=1, 571 | groups: int=1): 572 | super().__init__( 573 | mode='GRU', 574 | in_channels=in_channels, 575 | out_channels=out_channels, 576 | kernel_size=kernel_size, 577 | num_layers=num_layers, 578 | bias=bias, 579 | batch_first=batch_first, 580 | dropout=dropout, 581 | bidirectional=bidirectional, 582 | convndim=3, 583 | stride=stride, 584 | dilation=dilation, 585 | groups=groups) 586 | 587 | 588 | class ConvRNNCellBase(torch.nn.Module): 589 | def __init__(self, 590 | mode: str, 591 | in_channels: int, 592 | out_channels: int, 593 | kernel_size: Union[int, Sequence[int]], 594 | bias: bool=True, 595 | convndim: int=2, 596 | stride: Union[int, Sequence[int]]=1, 597 | dilation: Union[int, Sequence[int]]=1, 598 | groups: int=1 599 | ): 600 | super().__init__() 601 | self.mode = mode 602 | self.in_channels = in_channels 603 | self.out_channels = out_channels 604 | self.bias = bias 605 | self.convndim = convndim 606 | 607 | if convndim == 1: 608 | ntuple = _single 609 | elif convndim == 2: 610 | ntuple = _pair 611 | elif convndim == 3: 612 | ntuple = _triple 613 | else: 614 | raise ValueError('convndim must be 1, 2, or 3, but got {}'.format(convndim)) 615 | 616 | self.kernel_size = ntuple(kernel_size) 617 | self.stride = ntuple(stride) 618 | self.dilation = ntuple(dilation) 619 | 620 | self.groups = groups 621 | 622 | if mode in ('LSTM', 'PeepholeLSTM'): 623 | gate_size = 4 * out_channels 624 | elif mode == 'GRU': 625 | gate_size = 3 * out_channels 626 | else: 627 | gate_size = out_channels 628 | 629 | self.weight_ih = Parameter(torch.Tensor(gate_size, in_channels // groups, *self.kernel_size)) 630 | self.weight_hh = Parameter(torch.Tensor(gate_size, out_channels // groups, *self.kernel_size)) 631 | 632 | if bias: 633 | self.bias_ih = Parameter(torch.Tensor(gate_size)) 634 | self.bias_hh = Parameter(torch.Tensor(gate_size)) 635 | else: 636 | self.register_parameter('bias_ih', None) 637 | self.register_parameter('bias_hh', None) 638 | 639 | if mode == 'PeepholeLSTM': 640 | self.weight_pi = Parameter(torch.Tensor(out_channels, out_channels // groups, *self.kernel_size)) 641 | self.weight_pf = Parameter(torch.Tensor(out_channels, out_channels // groups, *self.kernel_size)) 642 | self.weight_po = Parameter(torch.Tensor(out_channels, out_channels // groups, *self.kernel_size)) 643 | 644 | self.reset_parameters() 645 | 646 | def extra_repr(self): 647 | s = ('{in_channels}, {out_channels}, kernel_size={kernel_size}' 648 | ', stride={stride}') 649 | if self.dilation != (1,) * len(self.dilation): 650 | s += ', dilation={dilation}' 651 | if self.groups != 1: 652 | s += ', groups={groups}' 653 | if self.bias is not True: 654 | s += ', bias={bias}' 657 | return s.format(**self.__dict__) 658 | 659 | def check_forward_input(self, input): 660 | if input.size(1) != self.in_channels: 661 | raise RuntimeError( 662 | "input has inconsistent channels: got {}, expected {}".format( 663 | input.size(1), self.in_channels)) 664 | 665 | def check_forward_hidden(self, input, hx, hidden_label=''): 666 | if input.size(0) != hx.size(0): 667 | raise RuntimeError( 668 | "Input batch size {} doesn't match hidden{} batch size {}".format( 669 | input.size(0), hidden_label, hx.size(0))) 670 | 671 | if hx.size(1)
!= self.out_channels: 672 | raise RuntimeError( 673 | "hidden{} has inconsistent hidden_size: got {}, expected {}".format( 674 | hidden_label, hx.size(1), self.out_channels)) 675 | 676 | def reset_parameters(self): 677 | stdv = 1.0 / math.sqrt(self.out_channels) 678 | for weight in self.parameters(): 679 | weight.data.uniform_(-stdv, stdv) 680 | 681 | def forward(self, input, hx=None): 682 | self.check_forward_input(input) 683 | 684 | if hx is None: 685 | batch_size = input.size(0) 686 | insize = input.shape[2:] 687 | hx = input.new_zeros(batch_size, self.out_channels, *insize, requires_grad=False) 688 | if self.mode in ('LSTM', 'PeepholeLSTM'): 689 | hx = (hx, hx) 690 | if self.mode in ('LSTM', 'PeepholeLSTM'): 691 | self.check_forward_hidden(input, hx[0]) 692 | self.check_forward_hidden(input, hx[1]) 693 | else: 694 | self.check_forward_hidden(input, hx) 695 | 696 | cell = _conv_cell_helper( 697 | self.mode, 698 | convndim=self.convndim, 699 | stride=self.stride, 700 | dilation=self.dilation, 701 | groups=self.groups) 702 | if self.mode == 'PeepholeLSTM': 703 | return cell( 704 | input, hx, 705 | self.weight_ih, self.weight_hh, self.weight_pi, self.weight_pf, self.weight_po, 706 | self.bias_ih, self.bias_hh 707 | ) 708 | else: 709 | return cell( 710 | input, hx, 711 | self.weight_ih, self.weight_hh, 712 | self.bias_ih, self.bias_hh, 713 | ) 714 | 715 | 716 | class Conv1dRNNCell(ConvRNNCellBase): 717 | def __init__(self, 718 | in_channels: int, 719 | out_channels: int, 720 | kernel_size: Union[int, Sequence[int]], 721 | nonlinearity: str='tanh', 722 | bias: bool=True, 723 | stride: Union[int, Sequence[int]]=1, 724 | dilation: Union[int, Sequence[int]]=1, 725 | groups: int=1 726 | ): 727 | if nonlinearity == 'tanh': 728 | mode = 'RNN_TANH' 729 | elif nonlinearity == 'relu': 730 | mode = 'RNN_RELU' 731 | else: 732 | raise ValueError("Unknown nonlinearity '{}'".format(nonlinearity)) 733 | super().__init__( 734 | mode=mode, 735 | in_channels=in_channels, 736 | out_channels=out_channels, 737 | kernel_size=kernel_size, 738 | bias=bias, 739 | convndim=1, 740 | stride=stride, 741 | dilation=dilation, 742 | groups=groups 743 | ) 744 | 745 | 746 | class Conv1dLSTMCell(ConvRNNCellBase): 747 | def __init__(self, 748 | in_channels: int, 749 | out_channels: int, 750 | kernel_size: Union[int, Sequence[int]], 751 | bias: bool=True, 752 | stride: Union[int, Sequence[int]]=1, 753 | dilation: Union[int, Sequence[int]]=1, 754 | groups: int=1 755 | ): 756 | super().__init__( 757 | mode='LSTM', 758 | in_channels=in_channels, 759 | out_channels=out_channels, 760 | kernel_size=kernel_size, 761 | bias=bias, 762 | convndim=1, 763 | stride=stride, 764 | dilation=dilation, 765 | groups=groups 766 | ) 767 | 768 | 769 | class Conv1dPeepholeLSTMCell(ConvRNNCellBase): 770 | def __init__(self, 771 | in_channels: int, 772 | out_channels: int, 773 | kernel_size: Union[int, Sequence[int]], 774 | bias: bool=True, 775 | stride: Union[int, Sequence[int]]=1, 776 | dilation: Union[int, Sequence[int]]=1, 777 | groups: int=1 778 | ): 779 | super().__init__( 780 | mode='PeepholeLSTM', 781 | in_channels=in_channels, 782 | out_channels=out_channels, 783 | kernel_size=kernel_size, 784 | bias=bias, 785 | convndim=1, 786 | stride=stride, 787 | dilation=dilation, 788 | groups=groups 789 | ) 790 | 791 | 792 | class Conv1dGRUCell(ConvRNNCellBase): 793 | def __init__(self, 794 | in_channels: int, 795 | out_channels: int, 796 | kernel_size: Union[int, Sequence[int]], 797 | bias: bool=True, 798 | stride: Union[int, Sequence[int]]=1, 799 | 
dilation: Union[int, Sequence[int]]=1, 800 | groups: int=1 801 | ): 802 | super().__init__( 803 | mode='GRU', 804 | in_channels=in_channels, 805 | out_channels=out_channels, 806 | kernel_size=kernel_size, 807 | bias=bias, 808 | convndim=1, 809 | stride=stride, 810 | dilation=dilation, 811 | groups=groups 812 | ) 813 | 814 | 815 | class Conv2dRNNCell(ConvRNNCellBase): 816 | def __init__(self, 817 | in_channels: int, 818 | out_channels: int, 819 | kernel_size: Union[int, Sequence[int]], 820 | nonlinearity: str='tanh', 821 | bias: bool=True, 822 | stride: Union[int, Sequence[int]]=1, 823 | dilation: Union[int, Sequence[int]]=1, 824 | groups: int=1 825 | ): 826 | if nonlinearity == 'tanh': 827 | mode = 'RNN_TANH' 828 | elif nonlinearity == 'relu': 829 | mode = 'RNN_RELU' 830 | else: 831 | raise ValueError("Unknown nonlinearity '{}'".format(nonlinearity)) 832 | super().__init__( 833 | mode=mode, 834 | in_channels=in_channels, 835 | out_channels=out_channels, 836 | kernel_size=kernel_size, 837 | bias=bias, 838 | convndim=2, 839 | stride=stride, 840 | dilation=dilation, 841 | groups=groups 842 | ) 843 | 844 | 845 | class Conv2dLSTMCell(ConvRNNCellBase): 846 | def __init__(self, 847 | in_channels: int, 848 | out_channels: int, 849 | kernel_size: Union[int, Sequence[int]], 850 | bias: bool=True, 851 | stride: Union[int, Sequence[int]]=1, 852 | dilation: Union[int, Sequence[int]]=1, 853 | groups: int=1 854 | ): 855 | super().__init__( 856 | mode='LSTM', 857 | in_channels=in_channels, 858 | out_channels=out_channels, 859 | kernel_size=kernel_size, 860 | bias=bias, 861 | convndim=2, 862 | stride=stride, 863 | dilation=dilation, 864 | groups=groups 865 | ) 866 | 867 | 868 | class Conv2dPeepholeLSTMCell(ConvRNNCellBase): 869 | def __init__(self, 870 | in_channels: int, 871 | out_channels: int, 872 | kernel_size: Union[int, Sequence[int]], 873 | bias: bool=True, 874 | stride: Union[int, Sequence[int]]=1, 875 | dilation: Union[int, Sequence[int]]=1, 876 | groups: int=1 877 | ): 878 | super().__init__( 879 | mode='PeepholeLSTM', 880 | in_channels=in_channels, 881 | out_channels=out_channels, 882 | kernel_size=kernel_size, 883 | bias=bias, 884 | convndim=2, 885 | stride=stride, 886 | dilation=dilation, 887 | groups=groups 888 | ) 889 | 890 | 891 | class Conv2dGRUCell(ConvRNNCellBase): 892 | def __init__(self, 893 | in_channels: int, 894 | out_channels: int, 895 | kernel_size: Union[int, Sequence[int]], 896 | bias: bool=True, 897 | stride: Union[int, Sequence[int]]=1, 898 | dilation: Union[int, Sequence[int]]=1, 899 | groups: int=1 900 | ): 901 | super().__init__( 902 | mode='GRU', 903 | in_channels=in_channels, 904 | out_channels=out_channels, 905 | kernel_size=kernel_size, 906 | bias=bias, 907 | convndim=2, 908 | stride=stride, 909 | dilation=dilation, 910 | groups=groups 911 | ) 912 | 913 | 914 | class Conv3dRNNCell(ConvRNNCellBase): 915 | def __init__(self, 916 | in_channels: int, 917 | out_channels: int, 918 | kernel_size: Union[int, Sequence[int]], 919 | nonlinearity: str='tanh', 920 | bias: bool=True, 921 | stride: Union[int, Sequence[int]]=1, 922 | dilation: Union[int, Sequence[int]]=1, 923 | groups: int=1 924 | ): 925 | if nonlinearity == 'tanh': 926 | mode = 'RNN_TANH' 927 | elif nonlinearity == 'relu': 928 | mode = 'RNN_RELU' 929 | else: 930 | raise ValueError("Unknown nonlinearity '{}'".format(nonlinearity)) 931 | super().__init__( 932 | mode=mode, 933 | in_channels=in_channels, 934 | out_channels=out_channels, 935 | kernel_size=kernel_size, 936 | bias=bias, 937 | convndim=3, 938 | stride=stride, 
939 | dilation=dilation, 940 | groups=groups 941 | ) 942 | 943 | 944 | class Conv3dLSTMCell(ConvRNNCellBase): 945 | def __init__(self, 946 | in_channels: int, 947 | out_channels: int, 948 | kernel_size: Union[int, Sequence[int]], 949 | bias: bool=True, 950 | stride: Union[int, Sequence[int]]=1, 951 | dilation: Union[int, Sequence[int]]=1, 952 | groups: int=1 953 | ): 954 | super().__init__( 955 | mode='LSTM', 956 | in_channels=in_channels, 957 | out_channels=out_channels, 958 | kernel_size=kernel_size, 959 | bias=bias, 960 | convndim=3, 961 | stride=stride, 962 | dilation=dilation, 963 | groups=groups 964 | ) 965 | 966 | 967 | class Conv3dPeepholeLSTMCell(ConvRNNCellBase): 968 | def __init__(self, 969 | in_channels: int, 970 | out_channels: int, 971 | kernel_size: Union[int, Sequence[int]], 972 | bias: bool=True, 973 | stride: Union[int, Sequence[int]]=1, 974 | dilation: Union[int, Sequence[int]]=1, 975 | groups: int=1 976 | ): 977 | super().__init__( 978 | mode='PeepholeLSTM', 979 | in_channels=in_channels, 980 | out_channels=out_channels, 981 | kernel_size=kernel_size, 982 | bias=bias, 983 | convndim=3, 984 | stride=stride, 985 | dilation=dilation, 986 | groups=groups 987 | ) 988 | 989 | 990 | class Conv3dGRUCell(ConvRNNCellBase): 991 | def __init__(self, 992 | in_channels: int, 993 | out_channels: int, 994 | kernel_size: Union[int, Sequence[int]], 995 | bias: bool=True, 996 | stride: Union[int, Sequence[int]]=1, 997 | dilation: Union[int, Sequence[int]]=1, 998 | groups: int=1 999 | ): 1000 | super().__init__( 1001 | mode='GRU', 1002 | in_channels=in_channels, 1003 | out_channels=out_channels, 1004 | kernel_size=kernel_size, 1005 | bias=bias, 1006 | convndim=3, 1007 | stride=stride, 1008 | dilation=dilation, 1009 | groups=groups 1010 | ) 1011 | -------------------------------------------------------------------------------- /convolutional_rnn/utils.py: -------------------------------------------------------------------------------- 1 | import collections.abc 2 | from itertools import repeat 3 | 4 | 5 | """ Copied from torch.nn.modules.utils """ 6 | 7 | 8 | def _ntuple(n): 9 | def parse(x): 10 | if isinstance(x, collections.abc.Iterable): 11 | return x 12 | return tuple(repeat(x, n))  # e.g. _pair(3) -> (3, 3) 13 | return parse 14 | 15 | 16 | _single = _ntuple(1) 17 | _pair = _ntuple(2) 18 | _triple = _ntuple(3) 19 | _quadruple = _ntuple(4) 20 | --------------------------------------------------------------------------------