├── .github └── workflows │ └── test.yaml ├── LICENSE ├── README.md ├── cell.py └── requirements.txt /.github/workflows/test.yaml: -------------------------------------------------------------------------------- 1 | on: 2 | push: 3 | branches: 4 | - main 5 | pull_request: 6 | branches: 7 | - main 8 | 9 | concurrency: 10 | group: ${{ github.event.pull_request.number || github.ref }} 11 | cancel-in-progress: true 12 | 13 | jobs: 14 | test: 15 | runs-on: ubuntu-latest 16 | steps: 17 | - uses: actions/checkout@v4 18 | - uses: actions/setup-python@v4 19 | with: 20 | cache: pip 21 | - run: pip install -r requirements.txt 22 | - run: python -c 'exec(open("README.md").read().split("```")[1][6:])' 23 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2016 Carl Thomé 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # TensorFlow ConvLSTM Cell 2 | A ConvLSTM cell for TensorFlow's RNN API. 3 | 4 | ```python 5 | import tensorflow.compat.v1 as tf 6 | 7 | tf.disable_v2_behavior() 8 | 9 | batch_size = 32 10 | timesteps = 100 11 | shape = [640, 480] 12 | kernel = [3, 3] 13 | channels = 3 14 | filters = 12 15 | 16 | # Create a placeholder for videos. 17 | inputs = tf.placeholder(tf.float32, [batch_size, timesteps] + shape + [channels]) 18 | 19 | # Add the ConvLSTM step. 20 | from cell import ConvLSTMCell 21 | cell = ConvLSTMCell(shape, filters, kernel) 22 | outputs, state = tf.nn.dynamic_rnn(cell, inputs, dtype=inputs.dtype) 23 | 24 | # There's also a ConvGRUCell that is more memory efficient. 25 | from cell import ConvGRUCell 26 | cell = ConvGRUCell(shape, filters, kernel) 27 | outputs, state = tf.nn.dynamic_rnn(cell, inputs, dtype=inputs.dtype) 28 | 29 | # It's also possible to enter 2D input or 4D input instead of 3D. 30 | shape = [100] 31 | kernel = [3] 32 | inputs = tf.placeholder(tf.float32, [batch_size, timesteps] + shape + [channels]) 33 | cell = ConvLSTMCell(shape, filters, kernel) 34 | outputs, state = tf.nn.bidirectional_dynamic_rnn(cell, cell, inputs, dtype=inputs.dtype) 35 | 36 | shape = [50, 50, 50] 37 | kernel = [1, 3, 5] 38 | inputs = tf.placeholder(tf.float32, [batch_size, timesteps] + shape + [channels]) 39 | cell = ConvGRUCell(shape, filters, kernel) 40 | outputs, state= tf.nn.bidirectional_dynamic_rnn(cell, cell, inputs, dtype=inputs.dtype) 41 | ``` 42 | -------------------------------------------------------------------------------- /cell.py: -------------------------------------------------------------------------------- 1 | import tensorflow.compat.v1 as tf 2 | 3 | class ConvLSTMCell(tf.nn.rnn_cell.RNNCell): 4 | """A LSTM cell with convolutions instead of multiplications. 5 | 6 | Reference: 7 | Xingjian, S. H. I., et al. "Convolutional LSTM network: A machine learning approach for precipitation nowcasting." Advances in Neural Information Processing Systems. 2015. 8 | """ 9 | 10 | def __init__(self, shape, filters, kernel, forget_bias=1.0, activation=tf.tanh, data_format='channels_last', reuse=None): 11 | super(ConvLSTMCell, self).__init__(_reuse=reuse) 12 | self._kernel = kernel 13 | self._filters = filters 14 | self._forget_bias = forget_bias 15 | self._activation = activation 16 | if data_format == 'channels_last': 17 | self._size = tf.TensorShape(shape + [self._filters]) 18 | self._feature_axis = self._size.ndims 19 | self._data_format = None 20 | elif data_format == 'channels_first': 21 | self._size = tf.TensorShape([self._filters] + shape) 22 | self._feature_axis = 0 23 | self._data_format = 'NC' 24 | else: 25 | raise ValueError('Unknown data_format') 26 | 27 | @property 28 | def state_size(self): 29 | return tf.nn.rnn_cell.LSTMStateTuple(self._size, self._size) 30 | 31 | @property 32 | def output_size(self): 33 | return self._size 34 | 35 | def call(self, x, state): 36 | c, h = state 37 | 38 | x = tf.concat([x, h], axis=self._feature_axis) 39 | n = x.shape[-1].value 40 | m = 4 * self._filters if self._filters > 1 else 4 41 | W = tf.get_variable('kernel', self._kernel + [n, m]) 42 | y = tf.nn.convolution(x, W, 'SAME', data_format=self._data_format) 43 | y += tf.get_variable('bias', [m], initializer=tf.zeros_initializer()) 44 | j, i, f, o = tf.split(y, 4, axis=self._feature_axis) 45 | 46 | f = tf.sigmoid(f + self._forget_bias) 47 | i = tf.sigmoid(i) 48 | c = c * f + i * self._activation(j) 49 | 50 | o = tf.sigmoid(o) 51 | h = o * self._activation(c) 52 | 53 | state = tf.nn.rnn_cell.LSTMStateTuple(c, h) 54 | 55 | return h, state 56 | 57 | 58 | class ConvGRUCell(tf.nn.rnn_cell.RNNCell): 59 | """A GRU cell with convolutions instead of multiplications.""" 60 | 61 | def __init__(self, shape, filters, kernel, activation=tf.tanh, data_format='channels_last', reuse=None): 62 | super(ConvGRUCell, self).__init__(_reuse=reuse) 63 | self._filters = filters 64 | self._kernel = kernel 65 | self._activation = activation 66 | if data_format == 'channels_last': 67 | self._size = tf.TensorShape(shape + [self._filters]) 68 | self._feature_axis = self._size.ndims 69 | self._data_format = None 70 | elif data_format == 'channels_first': 71 | self._size = tf.TensorShape([self._filters] + shape) 72 | self._feature_axis = 0 73 | self._data_format = 'NC' 74 | else: 75 | raise ValueError('Unknown data_format') 76 | 77 | @property 78 | def state_size(self): 79 | return self._size 80 | 81 | @property 82 | def output_size(self): 83 | return self._size 84 | 85 | def call(self, x, h): 86 | channels = x.shape[self._feature_axis].value 87 | 88 | with tf.variable_scope('gates'): 89 | inputs = tf.concat([x, h], axis=self._feature_axis) 90 | n = channels + self._filters 91 | m = 2 * self._filters if self._filters > 1 else 2 92 | W = tf.get_variable('kernel', self._kernel + [n, m]) 93 | y = tf.nn.convolution(inputs, W, 'SAME', data_format=self._data_format) 94 | y += tf.get_variable('bias', [m], initializer=tf.ones_initializer()) 95 | r, u = tf.split(y, 2, axis=self._feature_axis) 96 | r, u = tf.sigmoid(r), tf.sigmoid(u) 97 | 98 | with tf.variable_scope('candidate'): 99 | inputs = tf.concat([x, r * h], axis=self._feature_axis) 100 | n = channels + self._filters 101 | m = self._filters 102 | W = tf.get_variable('kernel', self._kernel + [n, m]) 103 | y = tf.nn.convolution(inputs, W, 'SAME', data_format=self._data_format) 104 | y += tf.get_variable('bias', [m], initializer=tf.zeros_initializer()) 105 | h = u * h + (1 - u) * self._activation(y) 106 | 107 | return h, h 108 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | tensorflow-cpu 2 | --------------------------------------------------------------------------------