├── .gitignore
├── LICENSE
├── README.md
├── img
│   ├── lstm_vs_nlstm.png
│   └── nlstm_architecture.png
├── rnn_cell.py
└── rnn_cell_test.py

/.gitignore:
--------------------------------------------------------------------------------
.idea
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2018 Hann Wang

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# nlstm
## TensorFlow Implementation of the Nested LSTM Cell

Here is a TensorFlow implementation of the Nested LSTM (NLSTM) cell.

| ![nlstm architecture](img/nlstm_architecture.png) |
|:--:|
| *Nested LSTM Architecture. Courtesy of Moniz et al.* |

The NLSTM cell is an LSTM-like cell that uses its cell memory to control the state of an inner LSTM; as such, the architecture can be generalized to multiple levels of nesting. For a comparison between LSTM and NLSTM, see the figure below.

| ![lstm vs nlstm](img/lstm_vs_nlstm.png) |
|:--:|
| *LSTM and stacked LSTM, versus nested LSTM. Courtesy of Moniz et al.* |

The implementation here is compatible with the TensorFlow RNN API.

```python
from rnn_cell import NLSTMCell
cell = NLSTMCell(num_units=3, depth=2)
init_state = cell.zero_state(batch_size, dtype=tf.float32)
output, new_state = cell(inputs, state=init_state)
...
```
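
Because `NLSTMCell` follows the standard `RNNCell` interface, it can also be unrolled over a full sequence with `tf.nn.dynamic_rnn`. The snippet below is a minimal sketch rather than part of the original examples; it assumes TensorFlow 1.x graph mode (which this repository targets), and `batch_size`, `seq_len`, and `input_dim` are arbitrary illustration values.

```python
import tensorflow as tf
from rnn_cell import NLSTMCell

batch_size, seq_len, input_dim = 2, 7, 5
# A batch of input sequences: [batch_size, time, features].
inputs = tf.placeholder(tf.float32, [batch_size, seq_len, input_dim])

cell = NLSTMCell(num_units=3, depth=2)
init_state = cell.zero_state(batch_size, dtype=tf.float32)

# `outputs` has shape [batch_size, seq_len, num_units]; `final_state` is a
# tuple of depth + 1 tensors: the hidden state followed by the nested cell states.
outputs, final_state = tf.nn.dynamic_rnn(cell, inputs, initial_state=init_state)
```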

Ref:
- Moniz et al., "Nested LSTMs." https://arxiv.org/abs/1801.10308


--------------------------------------------------------------------------------
/img/lstm_vs_nlstm.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hannw/nlstm/aa6e96c8746a4cbc1599301253c38407c0db61a9/img/lstm_vs_nlstm.png
--------------------------------------------------------------------------------
/img/nlstm_architecture.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hannw/nlstm/aa6e96c8746a4cbc1599301253c38407c0db61a9/img/nlstm_architecture.png
--------------------------------------------------------------------------------
/rnn_cell.py:
--------------------------------------------------------------------------------
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import rnn_cell_impl
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.platform import tf_logging as logging


import tensorflow as tf
from tensorflow.python.layers import base as base_layer

_BIAS_VARIABLE_NAME = "bias"
_WEIGHTS_VARIABLE_NAME = "kernel"


class NLSTMCell(rnn_cell_impl.RNNCell):
  """Nested LSTM Cell. Adapted from `rnn_cell_impl.LSTMCell`.

  The implementation is based on:
  https://arxiv.org/abs/1801.10308
  JRA. Moniz, D. Krueger.
  "Nested LSTMs"
  ACML, PMLR 77:530-544, 2017
  """

  def __init__(self, num_units, depth, forget_bias=1.0,
               state_is_tuple=True, use_peepholes=False,
               activation=None, gate_activation=None,
               cell_activation=None,
               initializer=None,
               input_gate_initializer=None,
               use_bias=True, reuse=None, name=None):
    """Initialize the basic NLSTM cell.

    Args:
      num_units: `int`, The number of hidden units of each cell state
        and hidden state.
      depth: `int`, The number of layers in the nest.
      forget_bias: `float`, The bias added to forget gates.
      state_is_tuple: If `True`, accepted and returned states are tuples of
        the `h_state` and `c_state`s. If `False`, they are concatenated
        along the column axis. The latter behavior will soon be deprecated.
      use_peepholes: `bool` (optional), set to `True` to add peephole
        connections from the cell state to the gates.
      activation: Activation function of the update values,
        including new inputs and new cell states. Default: `tanh`.
      gate_activation: Activation function of the gates,
        including the input, output, and forget gate. Default: `sigmoid`.
      cell_activation: Activation function of the first cell gate. Default: `identity`.
        Note that in the paper only the first (outermost) cell_activation is identity.
      initializer: Initializer of kernel. Default: `orthogonal_initializer`.
      input_gate_initializer: Initializer of input gates.
        Default: `glorot_normal_initializer`.
      use_bias: `bool`. Default: `True`.
      reuse: `bool` (optional), Python boolean describing whether to reuse variables
        in an existing scope. If not `True`, and the existing scope already has
        the given variables, an error is raised.
      name: `str`, the name of the layer. Layers with the same name will
        share weights, but to avoid mistakes we require reuse=True in such
        cases.
    """
    super(NLSTMCell, self).__init__(_reuse=reuse, name=name)
    if not state_is_tuple:
      logging.warn("%s: Using a concatenated state is slower and will soon be "
                   "deprecated. Use state_is_tuple=True.", self)

    # Inputs must be 2-dimensional.
    self.input_spec = base_layer.InputSpec(ndim=2)
    self._num_units = num_units
    self._forget_bias = forget_bias
    self._state_is_tuple = state_is_tuple
    self._use_peepholes = use_peepholes
    self._depth = depth
    self._activation = activation or math_ops.tanh
    self._gate_activation = gate_activation or math_ops.sigmoid
    self._cell_activation = cell_activation or array_ops.identity
    self._initializer = initializer or init_ops.orthogonal_initializer()
    self._input_gate_initializer = (input_gate_initializer
                                    or init_ops.glorot_normal_initializer())
    self._use_bias = use_bias
    self._kernels = None
    self._biases = None
    self.built = False

  @property
  def state_size(self):
    if self._state_is_tuple:
      return tuple([self._num_units] * (self.depth + 1))
    else:
      return self._num_units * (self.depth + 1)

  @property
  def output_size(self):
    return self._num_units

  @property
  def depth(self):
    return self._depth

  def build(self, inputs_shape):
    if inputs_shape[1].value is None:
      raise ValueError("Expected inputs.shape[-1] to be known, saw shape: %s"
                       % inputs_shape)

    input_depth = inputs_shape[1].value
    h_depth = self._num_units
    self._kernels = []
    if self._use_bias:
      self._biases = []

    if self._use_peepholes:
      self._peep_kernels = []
    for i in range(self.depth):
      if i == 0:
        input_kernel = self.add_variable(
            "input_gate_kernel",
            shape=[input_depth, 4 * self._num_units],
            initializer=self._input_gate_initializer)
        hidden_kernel = self.add_variable(
            "hidden_gate_kernel",
            shape=[h_depth, 4 * self._num_units],
            initializer=self._initializer)
        kernel = tf.concat([input_kernel, hidden_kernel],
                           axis=0, name="kernel_0")
        self._kernels.append(kernel)
      else:
        self._kernels.append(
            self.add_variable(
                "kernel_{}".format(i),
                shape=[2 * h_depth, 4 * self._num_units],
                initializer=self._initializer))
      if self._use_bias:
        self._biases.append(
            self.add_variable(
                "bias_{}".format(i),
                shape=[4 * self._num_units],
                initializer=init_ops.zeros_initializer(dtype=self.dtype)))
      if self._use_peepholes:
        self._peep_kernels.append(
            self.add_variable(
                "peep_kernel_{}".format(i),
                shape=[h_depth, 3 * self._num_units],
                initializer=self._initializer))

    self.built = True

  def _recurrence(self, inputs, hidden_state, cell_states, depth):
    """use recurrence to traverse the nested structure

    Args:
      inputs: A 2D `Tensor` of [batch_size x input_size] shape.
      hidden_state: A 2D `Tensor` of [batch_size x num_units] shape.
      cell_states: A `list` of 2D `Tensor` of [batch_size x num_units] shape.
      depth: `int`
        the current depth in the nested structure, begins at 0.

    Returns:
      new_h: A 2D `Tensor` of [batch_size x num_units] shape.
        the latest hidden state for current step.
      new_cs: A `list` of 2D `Tensor` of [batch_size x num_units] shape.
        The accumulated cell states for current step.
    """
    sigmoid = math_ops.sigmoid
    one = constant_op.constant(1, dtype=dtypes.int32)
    # Parameters of gates are concatenated into one multiply for efficiency.
    c = cell_states[depth]
    h = hidden_state

    gate_inputs = math_ops.matmul(
        array_ops.concat([inputs, h], 1), self._kernels[depth])
    if self._use_bias:
      gate_inputs = nn_ops.bias_add(gate_inputs, self._biases[depth])

    # i = input_gate, j = new_input, f = forget_gate, o = output_gate
    i, j, f, o = array_ops.split(
        value=gate_inputs, num_or_size_splits=4, axis=one)

    if self._use_peepholes:
      # Add the peephole contributions from the current cell state to the
      # input, forget, and output gates (once per gate).
      peep_gate_inputs = math_ops.matmul(c, self._peep_kernels[depth])
      i_peep, f_peep, o_peep = array_ops.split(
          value=peep_gate_inputs, num_or_size_splits=3, axis=one)
      i += i_peep
      f += f_peep
      o += o_peep

    # Note that using `add` and `multiply` instead of `+` and `*` gives a
    # performance improvement. So using those at the cost of readability.
    add = math_ops.add
    multiply = math_ops.multiply

    if self._use_bias:
      forget_bias_tensor = constant_op.constant(self._forget_bias, dtype=f.dtype)
      f = add(f, forget_bias_tensor)

    inner_hidden = multiply(c, self._gate_activation(f))

    if depth == 0:
      inner_input = multiply(self._gate_activation(i), self._cell_activation(j))
    else:
      inner_input = multiply(self._gate_activation(i), self._activation(j))

    if depth == (self.depth - 1):
      new_c = add(inner_hidden, inner_input)
      new_cs = [new_c]
    else:
      new_c, new_cs = self._recurrence(
          inputs=inner_input,
          hidden_state=inner_hidden,
          cell_states=cell_states,
          depth=depth + 1)
    new_h = multiply(self._activation(new_c), self._gate_activation(o))
    new_cs = [new_h] + new_cs
    return new_h, new_cs

  def call(self, inputs, state):
    """forward propagation of the cell

    Args:
      inputs: a 2D `Tensor` of [batch_size x input_size] shape
      state: a `tuple` of 2D `Tensor` of [batch_size x num_units] shape
        or a `Tensor` of [batch_size x (num_units * (self.depth + 1))] shape

    Returns:
      outputs: a 2D `Tensor` of [batch_size x num_units] shape
      next_state: a `tuple` of 2D `Tensor` of [batch_size x num_units] shape
        or a `Tensor` of [batch_size x (num_units * (self.depth + 1))] shape
    """
    if not self._state_is_tuple:
      states = array_ops.split(state, self.depth + 1, axis=1)
    else:
      states = state
    hidden_state = states[0]
    cell_states = states[1:]
    outputs, next_state = self._recurrence(inputs, hidden_state, cell_states, 0)
    if self._state_is_tuple:
      next_state = tuple(next_state)
    else:
      next_state = array_ops.concat(next_state, axis=1)
    return outputs, next_state

--------------------------------------------------------------------------------
/rnn_cell_test.py:
--------------------------------------------------------------------------------
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import itertools

import rnn_cell as contrib_rnn_cell
import tensorflow as tf
from tensorflow import test
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import variables


class TestNLSTM(test.TestCase):

  def _check_tuple_cell(self, *args, **kwargs):
    batch_size = 2
    num_units = 3
    depth = 4
    g = ops.Graph()
    with self.test_session(graph=g) as sess:
      with g.as_default():
        cell = contrib_rnn_cell.NLSTMCell(num_units, depth, *args, **kwargs)
        init_state = cell.zero_state(batch_size, dtype=dtypes.float32)
        output, new_state = cell(
            inputs=random_ops.random_normal([batch_size, 5]),
            state=init_state)
        variables.global_variables_initializer().run()
        vals = sess.run([output, new_state])
        self.assertAllEqual(vals[0], vals[1][0])
        self.assertAllEqual(vals[0].shape, [2, 3])
        for val in vals[1]:
          self.assertAllEqual(val.shape, [2, 3])
        self.assertEqual(len(vals[1]), 5)
        self.assertAllEqual(cell.state_size, [num_units] * (depth + 1))
        self.assertEqual(cell.depth, depth)
        self.assertEqual(cell.output_size, num_units)

  def _check_non_tuple_cell(self, *args, **kwargs):
    batch_size = 2
    num_units = 3
    depth = 2
    g = ops.Graph()
    with self.test_session(graph=g) as sess:
      with g.as_default():
        cell = contrib_rnn_cell.NLSTMCell(num_units, depth,
                                          *args, **kwargs)
        init_state = cell.zero_state(batch_size, dtype=dtypes.float32)
        output, new_state = cell(
            inputs=random_ops.random_normal([batch_size, 5]),
            state=init_state)
        variables.global_variables_initializer().run()
        vals = sess.run([output, new_state])
        self.assertAllEqual(vals[0], vals[1][:, :3])
        self.assertAllEqual(vals[0].shape, [2, 3])
        self.assertAllEqual(vals[1].shape, [2, 9])
        self.assertEqual(cell.state_size, num_units * (depth + 1))
        self.assertEqual(cell.depth, depth)
        self.assertEqual(cell.output_size, num_units)

  def testNLSTMBranches(self):
    state_is_tuples = [True, False]
    use_peepholes = [True, False]
    use_biases = [True, False]
    options = itertools.product(state_is_tuples, use_peepholes, use_biases)
    for option in options:
      state_is_tuple, use_peephole, use_bias = option
      if state_is_tuple:
        self._check_tuple_cell(
            state_is_tuple=state_is_tuple,
            use_peepholes=use_peephole, use_bias=use_bias)
      else:
        self._check_non_tuple_cell(
            state_is_tuple=state_is_tuple,
            use_peepholes=use_peephole, use_bias=use_bias)

--------------------------------------------------------------------------------
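
The test module above relies only on the standard `tf.test`/`unittest` machinery and, as shown in the dump, does not define a `__main__` entry point of its own. A minimal invocation sketch (assuming TensorFlow 1.x is installed and the repository root is on `PYTHONPATH`; the helper below is hypothetical and not part of the repository):

```python
# run_tests.py -- hypothetical helper, not part of the repository.
import unittest

import rnn_cell_test

# Run all test cases defined in rnn_cell_test (i.e., TestNLSTM).
unittest.main(module=rnn_cell_test, exit=False, verbosity=2)
```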