import tensorflow as tf

from tensorflow.python.ops import array_ops
from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.ops.math_ops import sigmoid
from tensorflow.python.ops.math_ops import tanh
from tensorflow.python.ops import init_ops

from tensorflow.python.util import nest
import collections


class ConvLSTMCell(object):
    """Convolutional LSTM network cell (ConvLSTMCell).

    The implementation is based on http://arxiv.org/abs/1506.04214
    and `BasicLSTMCell` in TensorFlow.

    The recurrent state is a single 4-D tensor laid out as
    [batch_size, height, width, 2 * hidden_num]: the first `hidden_num`
    channels hold the cell state `c`, the remaining ones the hidden
    state `h`.

    NOTE: the `split`/`concat` calls below pass the axis as the FIRST
    argument, which is the argument order of TensorFlow releases before
    1.0. Newer releases expect the value first; adapt the calls when
    upgrading.
    """

    def __init__(self, hidden_num, filter_size=(3, 3),
                 forget_bias=1.0, activation=tanh, name="ConvLSTMCell"):
        """Store the cell hyper-parameters.

        Args:
          hidden_num: number of channels of the cell/hidden state.
          filter_size: two-element sequence (rows, cols) giving the spatial
            size of the convolution kernel.
          forget_bias: constant added to the forget gate before the sigmoid.
          activation: callable applied to the new input and the new cell
            state.
          name: default variable scope name used by `__call__`.
        """
        self.hidden_num = hidden_num
        # Copy into a fresh list so instances never share (or mutate) the
        # default or the caller's sequence.
        self.filter_size = list(filter_size)
        self.forget_bias = forget_bias
        self.activation = activation
        self.name = name

    def zero_state(self, batch_size, height, width):
        """Return an all-zero state tensor.

        The shape is [batch_size, height, width, 2 * hidden_num].
        """
        return tf.zeros([batch_size, height, width, self.hidden_num * 2])

    def __call__(self, inputs, state, scope=None):
        """Run one step of the convolutional LSTM.

        Args:
          inputs: 4-D tensor [batch_size, height, width, channel].
          state: state tensor as produced by `zero_state` or a previous call.
          scope: optional variable scope; defaults to `self.name`.

        Returns:
          A pair `(new_h, new_state)` where `new_state` is `new_c` and
          `new_h` concatenated along the channel axis.
        """
        with vs.variable_scope(scope or self.name):  # "ConvLSTMCell"
            # Axis 3 is the channel axis; the state stores c then h.
            c, h = array_ops.split(3, 2, state)

            # batch_size * height * width * channel
            concat = _conv([inputs, h], 4 * self.hidden_num, self.filter_size)

            # i = input_gate, j = new_input, f = forget_gate, o = output_gate
            i, j, f, o = array_ops.split(3, 4, concat)

            new_c = (c * sigmoid(f + self.forget_bias) + sigmoid(i) *
                     self.activation(j))
            new_h = self.activation(new_c) * sigmoid(o)
            new_state = array_ops.concat(3, [new_c, new_h])

            return new_h, new_state


def _conv(args, output_size, filter_size, stddev=0.001, bias=True,
          bias_start=0.0, scope=None):
    """Apply one stride-1, SAME-padded 2-D convolution to `args`.

    When several tensors are given they are concatenated along the channel
    axis (axis 3) before the convolution.

    Args:
      args: a 4-D tensor [batch_size, height, width, channels] or a
        non-empty sequence of such tensors sharing height and width.
      output_size: number of output channels.
      filter_size: two-element sequence (rows, cols) of the kernel size.
      stddev: standard deviation of the truncated-normal kernel initializer.
      bias: whether to add a bias term.
      bias_start: constant used to initialize the bias.
      scope: optional variable scope name; defaults to "Conv".

    Returns:
      The convolution result, plus the bias term when `bias` is true.

    Raises:
      ValueError: if `args` is missing or empty, an argument is not 4-D,
        its channel dimension is unknown or zero, or the arguments disagree
        on height/width.
    """
    if args is None or (nest.is_sequence(args) and not args):
        raise ValueError("`args` must be specified")
    if not nest.is_sequence(args):
        args = [args]

    # Calculate the total size of arguments on dimension 3.
    # (batch_size x height x width x arg_size)
    shapes = [a.get_shape().as_list() for a in args]

    # Validate rank and channel size BEFORE indexing into any shape, so a
    # malformed argument yields the intended ValueError rather than an
    # IndexError.
    for shape in shapes:
        if len(shape) != 4:
            raise ValueError("Conv is expecting 4D arguments: %s" % str(shapes))
        if not shape[3]:
            raise ValueError("Conv expects shape[3] of arguments: %s" % str(shapes))

    # The first argument defines the spatial size all others must match.
    height, width = shapes[0][1], shapes[0][2]
    for shape in shapes:
        if shape[1] != height or shape[2] != width:
            raise ValueError("Inconsistent height and width size in arguments: %s" % str(shapes))
    total_arg_size = sum(shape[3] for shape in shapes)

    with vs.variable_scope(scope or "Conv"):
        kernel = vs.get_variable(
            "Kernel",
            [filter_size[0], filter_size[1], total_arg_size, output_size],
            initializer=init_ops.truncated_normal_initializer(stddev=stddev))

        if len(args) == 1:
            res = tf.nn.conv2d(args[0], kernel, [1, 1, 1, 1], padding='SAME')
        else:
            res = tf.nn.conv2d(array_ops.concat(3, args), kernel,
                               [1, 1, 1, 1], padding='SAME')

        if not bias:
            return res
        bias_term = vs.get_variable(
            "Bias", [output_size],
            initializer=init_ops.constant_initializer(bias_start))
    return res + bias_term
Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ConvLSTMCell-tensorflow 2 | Convolutional LSTM network cell (ConvLSTM). 3 | The implementation is based on (http://arxiv.org/abs/1506.04214) and BasicLSTMCell in TensorFlow. 4 | 5 | ## Example 6 | ```python 7 | p_input = tf.placeholder(tf.float32, [None, height, width, step_size, channel]) 8 | p_label = tf.placeholder(tf.float32, [None, height, width, 3]) 9 | 10 | p_input_list = tf.split(3, step_size, p_input) 11 | p_input_list = [tf.squeeze(p_input_, [3]) for p_input_ in p_input_list] 12 | 13 | cell = ConvLSTMCell(hidden_num) 14 | state = cell.zero_state(batch_size, height, width) 15 | 16 | with tf.variable_scope("ConvLSTM") as scope: # as BasicLSTMCell 17 | for i, p_input_ in enumerate(p_input_list): 18 | if i > 0: 19 | scope.reuse_variables() 20 | # ConvCell takes Tensor with size [batch_size, height, width, channel]. 21 | t_output, state = cell(p_input_, state) 22 | ``` 23 | --------------------------------------------------------------------------------