├── images ├── difference.PNG └── nested_lstm_diagram.PNG ├── LICENSE ├── README.md ├── .gitignore ├── imdb_nested_lstm.py └── nested_lstm.py /images/difference.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/titu1994/Nested-LSTM/HEAD/images/difference.PNG -------------------------------------------------------------------------------- /images/nested_lstm_diagram.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/titu1994/Nested-LSTM/HEAD/images/nested_lstm_diagram.PNG -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Somshubra Majumdar 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Nested LSTM 2 | Keras implementation of Nested LSTMs from the paper [Nested LSTMs](https://arxiv.org/abs/1801.10308) 3 | 4 | From the paper: 5 | > Nested LSTMs add depth to LSTMs via nesting as opposed to stacking. The value of a memory cell 6 | in an NLSTM is computed by an LSTM cell, which has its own inner memory cell. Nested LSTMs outperform both stacked and single-layer 7 | LSTMs with similar numbers of parameters in our experiments on various character-level language 8 | modeling tasks, and the inner memories of an LSTM learn longer term dependencies compared with 9 | the higher-level units of a stacked LSTM. 10 | 11 | # Usage 12 | Via Cells 13 | ```python 14 | from nested_lstm import NestedLSTMCell 15 | from keras.layers import Input, RNN 16 | 17 | ip = Input(shape=(nb_timesteps, input_dim)) 18 | x = RNN(NestedLSTMCell(units=64, depth=2))(ip) 19 | ... 20 | ``` 21 | 22 | Via Layer 23 | ```python 24 | from nested_lstm import NestedLSTM 25 | from keras.layers import Input 26 | ip = Input(shape=(nb_timesteps, input_dim)) 27 | x = NestedLSTM(units=64, depth=2)(ip) 28 | ... 29 | ``` 30 | 31 | # Difference between Stacked LSTMs and Nested LSTMs (from the paper) 32 | 33 | 34 | # Cell diagram (depth = 2, from the paper) 35 | 36 | 37 | # Acknowledgements 38 | Keras code heavily derived from the Tensorflow implementation - https://github.com/hannw/nlstm 39 | 40 | # Requirements 41 | - Keras 2.1.3+ 42 | - Tensorflow 1.2+ or Theano. CNTK untested.
43 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | env/ 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | 49 | # Translations 50 | *.mo 51 | *.pot 52 | 53 | # Django stuff: 54 | *.log 55 | local_settings.py 56 | 57 | # Flask stuff: 58 | instance/ 59 | .webassets-cache 60 | 61 | # Scrapy stuff: 62 | .scrapy 63 | 64 | # Sphinx documentation 65 | docs/_build/ 66 | 67 | # PyBuilder 68 | target/ 69 | 70 | # Jupyter Notebook 71 | .ipynb_checkpoints 72 | 73 | # pyenv 74 | .python-version 75 | 76 | # celery beat schedule file 77 | celerybeat-schedule 78 | 79 | # SageMath parsed files 80 | *.sage.py 81 | 82 | # dotenv 83 | .env 84 | 85 | # virtualenv 86 | .venv 87 | venv/ 88 | ENV/ 89 | 90 | # Spyder project settings 91 | .spyderproject 92 | .spyproject 93 | 94 | # Rope project settings 95 | .ropeproject 96 | 97 | # mkdocs documentation 98 | /site 99 | 100 | # mypy 101 | .mypy_cache/ 102 | 103 | # weights 104 | weights/* 105 | logs/* 106 | -------------------------------------------------------------------------------- /imdb_nested_lstm.py: 
-------------------------------------------------------------------------------- 1 | '''Trains a Minimal RNN on the IMDB sentiment classification task. 2 | The dataset is actually too small for Minimal RNN to be of any advantage 3 | compared to simpler, much faster methods such as TF-IDF + LogReg. 4 | ''' 5 | from __future__ import print_function 6 | 7 | from keras.preprocessing import sequence 8 | from keras.models import Sequential 9 | from keras.layers import Dense, Embedding 10 | from keras.callbacks import ModelCheckpoint 11 | from keras.datasets import imdb 12 | 13 | from nested_lstm import NestedLSTM 14 | 15 | max_features = 20000 16 | maxlen = 500 # cut texts after this number of words (among top max_features most common words) 17 | batch_size = 128 18 | 19 | print('Loading data...') 20 | (x_train, y_train), (x_test, y_test) = imdb.load_data(num_words=max_features) 21 | print(len(x_train), 'train sequences') 22 | print(len(x_test), 'test sequences') 23 | 24 | print('Pad sequences (samples x time)') 25 | x_train = sequence.pad_sequences(x_train, maxlen=maxlen) 26 | x_test = sequence.pad_sequences(x_test, maxlen=maxlen) 27 | print('x_train shape:', x_train.shape) 28 | print('x_test shape:', x_test.shape) 29 | 30 | # configuration matches 4.47 Million parameters with `units=600` and `64 embedding dim` 31 | print('Build model...') 32 | model = Sequential() 33 | model.add(Embedding(max_features, 128)) 34 | model.add(NestedLSTM(32, depth=2, dropout=0.0, recurrent_dropout=0.0)) 35 | model.add(Dense(1, activation='sigmoid')) 36 | 37 | # try using different optimizers and different optimizer configs 38 | model.compile(loss='binary_crossentropy', 39 | optimizer='adam', 40 | metrics=['accuracy']) 41 | 42 | model.summary() 43 | 44 | print('Train...') 45 | model.fit(x_train, y_train, 46 | batch_size=batch_size, 47 | epochs=15, 48 | validation_data=(x_test, y_test), 49 | callbacks=[ModelCheckpoint('weights/imdb_nlstm.h5', monitor='val_acc', 50 | save_best_only=True, 
save_weights_only=True)]) 51 | 52 | model.load_weights('weights/imdb_nlstm.h5') 53 | 54 | score, acc = model.evaluate(x_test, y_test, 55 | batch_size=batch_size) 56 | print('Test score:', score) 57 | print('Test accuracy:', acc) 58 | -------------------------------------------------------------------------------- /nested_lstm.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | import warnings 3 | 4 | from keras import backend as K 5 | from keras import activations 6 | from keras import initializers 7 | from keras import regularizers 8 | from keras import constraints 9 | from keras.engine import Layer 10 | from keras.engine import InputSpec 11 | from keras.legacy import interfaces 12 | from keras.layers import RNN 13 | from keras.layers.recurrent import _generate_dropout_mask 14 | from keras.layers import LSTMCell, LSTM 15 | 16 | 17 | class NestedLSTMCell(Layer): 18 | """Nested NestedLSTM Cell class. 19 | 20 | Derived from the paper [Nested LSTMs](https://arxiv.org/abs/1801.10308) 21 | Ref: [Tensorflow implementation](https://github.com/hannw/nlstm) 22 | 23 | # Arguments 24 | units: Positive integer, dimensionality of the output space. 25 | depth: Depth of nesting of the memory component. 26 | activation: Activation function to use 27 | (see [activations](../activations.md)). 28 | If you pass None, no activation is applied 29 | (ie. "linear" activation: `a(x) = x`). 30 | recurrent_activation: Activation function to use 31 | for the recurrent step 32 | (see [activations](../activations.md)). 33 | cell_activation: Activation function of the first cell gate. 34 | Note that in the paper only the first cell_activation is identity. 35 | (see [activations](../activations.md)). 36 | use_bias: Boolean, whether the layer uses a bias vector. 
37 | kernel_initializer: Initializer for the `kernel` weights matrix, 38 | used for the linear transformation of the inputs 39 | (see [initializers](../initializers.md)). 40 | recurrent_initializer: Initializer for the `recurrent_kernel` 41 | weights matrix, 42 | used for the linear transformation of the recurrent state 43 | (see [initializers](../initializers.md)). 44 | bias_initializer: Initializer for the bias vector 45 | (see [initializers](../initializers.md)). 46 | unit_forget_bias: Boolean. 47 | If True, add 1 to the bias of the forget gate at initialization. 48 | Setting it to true will also force `bias_initializer="zeros"`. 49 | This is recommended in [Jozefowicz et al.](http://www.jmlr.org/proceedings/papers/v37/jozefowicz15.pdf) 50 | kernel_regularizer: Regularizer function applied to 51 | the `kernel` weights matrix 52 | (see [regularizer](../regularizers.md)). 53 | recurrent_regularizer: Regularizer function applied to 54 | the `recurrent_kernel` weights matrix 55 | (see [regularizer](../regularizers.md)). 56 | bias_regularizer: Regularizer function applied to the bias vector 57 | (see [regularizer](../regularizers.md)). 58 | kernel_constraint: Constraint function applied to 59 | the `kernel` weights matrix 60 | (see [constraints](../constraints.md)). 61 | recurrent_constraint: Constraint function applied to 62 | the `recurrent_kernel` weights matrix 63 | (see [constraints](../constraints.md)). 64 | bias_constraint: Constraint function applied to the bias vector 65 | (see [constraints](../constraints.md)). 66 | dropout: Float between 0 and 1. 67 | Fraction of the units to drop for 68 | the linear transformation of the inputs. 69 | recurrent_dropout: Float between 0 and 1. 70 | Fraction of the units to drop for 71 | the linear transformation of the recurrent state. 72 | implementation: Implementation mode, must be 2. 
73 | Mode 1 will structure its operations as a larger number of 74 | smaller dot products and additions, whereas mode 2 will 75 | batch them into fewer, larger operations. These modes will 76 | have different performance profiles on different hardware and 77 | for different applications. 78 | """ 79 | 80 | def __init__(self, units, depth, 81 | activation='tanh', 82 | recurrent_activation='sigmoid', 83 | cell_activation='linear', 84 | use_bias=True, 85 | kernel_initializer='glorot_uniform', 86 | recurrent_initializer='orthogonal', 87 | bias_initializer='zeros', 88 | unit_forget_bias=False, 89 | kernel_regularizer=None, 90 | recurrent_regularizer=None, 91 | bias_regularizer=None, 92 | kernel_constraint=None, 93 | recurrent_constraint=None, 94 | bias_constraint=None, 95 | dropout=0., 96 | recurrent_dropout=0., 97 | implementation=2, 98 | **kwargs): 99 | super(NestedLSTMCell, self).__init__(**kwargs) 100 | 101 | if depth < 1: 102 | raise ValueError("`depth` must be at least 1. For better performance, consider using depth > 1.") 103 | 104 | if implementation != 1: 105 | warnings.warn( 106 | "Nested LSTMs only supports implementation 2 for the moment. 
Defaulting to implementation = 2") 107 | implementation = 2 108 | 109 | self.units = units 110 | self.depth = depth 111 | self.activation = activations.get(activation) 112 | self.recurrent_activation = activations.get(recurrent_activation) 113 | self.cell_activation = activations.get(cell_activation) 114 | self.use_bias = use_bias 115 | 116 | self.kernel_initializer = initializers.get(kernel_initializer) 117 | self.recurrent_initializer = initializers.get(recurrent_initializer) 118 | self.bias_initializer = initializers.get(bias_initializer) 119 | self.unit_forget_bias = unit_forget_bias 120 | 121 | self.kernel_regularizer = regularizers.get(kernel_regularizer) 122 | self.recurrent_regularizer = regularizers.get(recurrent_regularizer) 123 | self.bias_regularizer = regularizers.get(bias_regularizer) 124 | 125 | self.kernel_constraint = constraints.get(kernel_constraint) 126 | self.recurrent_constraint = constraints.get(recurrent_constraint) 127 | self.bias_constraint = constraints.get(bias_constraint) 128 | 129 | self.dropout = min(1., max(0., dropout)) 130 | self.recurrent_dropout = min(1., max(0., recurrent_dropout)) 131 | self.implementation = implementation 132 | self.state_size = tuple([self.units] * (self.depth + 1)) 133 | self._dropout_mask = None 134 | self._nested_recurrent_masks = None 135 | 136 | def build(self, input_shape): 137 | input_dim = input_shape[-1] 138 | self.kernels = [] 139 | self.biases = [] 140 | 141 | for i in range(self.depth): 142 | if i == 0: 143 | input_kernel = self.add_weight(shape=(input_dim, self.units * 4), 144 | name='input_kernel_%d' % (i + 1), 145 | initializer=self.kernel_initializer, 146 | regularizer=self.kernel_regularizer, 147 | constraint=self.kernel_constraint) 148 | hidden_kernel = self.add_weight(shape=(self.units, self.units * 4), 149 | name='kernel_%d' % (i + 1), 150 | initializer=self.recurrent_initializer, 151 | regularizer=self.recurrent_regularizer, 152 | constraint=self.recurrent_constraint) 153 | kernel = 
K.concatenate([input_kernel, hidden_kernel], axis=0) 154 | else: 155 | kernel = self.add_weight(shape=(self.units * 2, self.units * 4), 156 | name='kernel_%d' % (i + 1), 157 | initializer=self.recurrent_initializer, 158 | regularizer=self.recurrent_regularizer, 159 | constraint=self.recurrent_constraint) 160 | self.kernels.append(kernel) 161 | 162 | if self.use_bias: 163 | if self.unit_forget_bias: 164 | def bias_initializer(_, *args, **kwargs): 165 | return K.concatenate([ 166 | self.bias_initializer((self.units,), *args, **kwargs), 167 | initializers.Ones()((self.units,), *args, **kwargs), 168 | self.bias_initializer((self.units * 2,), *args, **kwargs), 169 | ]) 170 | else: 171 | bias_initializer = self.bias_initializer 172 | 173 | for i in range(self.depth): 174 | bias = self.add_weight(shape=(self.units * 4,), 175 | name='bias_%d' % (i + 1), 176 | initializer=bias_initializer, 177 | regularizer=self.bias_regularizer, 178 | constraint=self.bias_constraint) 179 | self.biases.append(bias) 180 | else: 181 | self.biases = None 182 | 183 | self.built = True 184 | 185 | def call(self, inputs, states, training=None): 186 | if 0 < self.dropout < 1 and self._dropout_mask is None: 187 | self._dropout_mask = _generate_dropout_mask( 188 | K.ones_like(inputs), 189 | self.dropout, 190 | training=training, 191 | count=1) 192 | if (0 < self.recurrent_dropout < 1 and 193 | self._nested_recurrent_masks is None): 194 | _nested_recurrent_mask = _generate_dropout_mask( 195 | K.ones_like(states[0]), 196 | self.recurrent_dropout, 197 | training=training, 198 | count=self.depth) 199 | self._nested_recurrent_masks = _nested_recurrent_mask 200 | 201 | # dropout matrices for input units 202 | dp_mask = self._dropout_mask 203 | # dropout matrices for recurrent units 204 | rec_dp_masks = self._nested_recurrent_masks 205 | 206 | h_tm1 = states[0] # previous memory state 207 | c_tm1 = states[1:self.depth + 1] # previous carry states 208 | 209 | if 0. 
< self.dropout < 1.: 210 | inputs *= dp_mask[0] 211 | 212 | h, c = self.nested_recurrence(inputs, 213 | hidden_state=h_tm1, 214 | cell_states=c_tm1, 215 | recurrent_masks=rec_dp_masks, 216 | current_depth=0) 217 | 218 | if 0 < self.dropout + self.recurrent_dropout: 219 | if training is None: 220 | h._uses_learning_phase = True 221 | return h, c 222 | 223 | def nested_recurrence(self, inputs, hidden_state, cell_states, recurrent_masks, current_depth): 224 | h_state = hidden_state 225 | c_state = cell_states[current_depth] 226 | 227 | if 0.0 < self.recurrent_dropout <= 1. and recurrent_masks is not None: 228 | hidden_state = h_state * recurrent_masks[current_depth] 229 | 230 | ip = K.concatenate([inputs, hidden_state], axis=-1) 231 | gate_inputs = K.dot(ip, self.kernels[current_depth]) 232 | 233 | if self.use_bias: 234 | gate_inputs = K.bias_add(gate_inputs, self.biases[current_depth]) 235 | 236 | i = gate_inputs[:, :self.units] # input gate 237 | f = gate_inputs[:, self.units * 2: self.units * 3] # forget gate 238 | c = gate_inputs[:, self.units: 2 * self.units] # new input 239 | o = gate_inputs[:, self.units * 3: self.units * 4] # output gate 240 | 241 | inner_hidden = c_state * self.recurrent_activation(f) 242 | 243 | if current_depth == 0: 244 | inner_input = self.recurrent_activation(i) + self.cell_activation(c) 245 | else: 246 | inner_input = self.recurrent_activation(i) + self.activation(c) 247 | 248 | if (current_depth == self.depth - 1): 249 | new_c = inner_hidden + inner_input 250 | new_cs = [new_c] 251 | else: 252 | new_c, new_cs = self.nested_recurrence(inner_input, 253 | hidden_state=inner_hidden, 254 | cell_states=cell_states, 255 | recurrent_masks=recurrent_masks, 256 | current_depth=current_depth + 1) 257 | 258 | new_h = self.activation(new_c) * self.recurrent_activation(o) 259 | new_cs = [new_h] + new_cs 260 | 261 | return new_h, new_cs 262 | 263 | def get_config(self): 264 | config = {'units': self.units, 265 | 'depth': self.depth, 266 | 
'activation': activations.serialize(self.activation), 267 | 'recurrent_activation': activations.serialize(self.recurrent_activation), 268 | 'cell_activation': activations.serialize(self.cell_activation), 269 | 'use_bias': self.use_bias, 270 | 'kernel_initializer': initializers.serialize(self.kernel_initializer), 271 | 'recurrent_initializer': initializers.serialize(self.recurrent_initializer), 272 | 'bias_initializer': initializers.serialize(self.bias_initializer), 273 | 'unit_forget_bias': self.unit_forget_bias, 274 | 'kernel_regularizer': regularizers.serialize(self.kernel_regularizer), 275 | 'recurrent_regularizer': regularizers.serialize(self.recurrent_regularizer), 276 | 'bias_regularizer': regularizers.serialize(self.bias_regularizer), 277 | 'kernel_constraint': constraints.serialize(self.kernel_constraint), 278 | 'recurrent_constraint': constraints.serialize(self.recurrent_constraint), 279 | 'bias_constraint': constraints.serialize(self.bias_constraint), 280 | 'dropout': self.dropout, 281 | 'recurrent_dropout': self.recurrent_dropout, 282 | 'implementation': self.implementation} 283 | base_config = super(NestedLSTMCell, self).get_config() 284 | return dict(list(base_config.items()) + list(config.items())) 285 | 286 | 287 | class NestedLSTM(RNN): 288 | """Nested Long-Short-Term-Memory layer - [Nested LSTMs](https://arxiv.org/abs/1801.10308). 289 | 290 | # Arguments 291 | units: Positive integer, dimensionality of the output space. 292 | depth: Depth of nesting of the memory component. 293 | activation: Activation function to use 294 | (see [activations](../activations.md)). 295 | If you pass None, no activation is applied 296 | (ie. "linear" activation: `a(x) = x`). 297 | recurrent_activation: Activation function to use 298 | for the recurrent step 299 | (see [activations](../activations.md)). 300 | cell_activation: Activation function of the first cell gate. 301 | Note that in the paper only the first cell_activation is identity. 
302 | (see [activations](../activations.md)). 303 | use_bias: Boolean, whether the layer uses a bias vector. 304 | kernel_initializer: Initializer for the `kernel` weights matrix, 305 | used for the linear transformation of the inputs. 306 | (see [initializers](../initializers.md)). 307 | recurrent_initializer: Initializer for the `recurrent_kernel` 308 | weights matrix, 309 | used for the linear transformation of the recurrent state. 310 | (see [initializers](../initializers.md)). 311 | bias_initializer: Initializer for the bias vector 312 | (see [initializers](../initializers.md)). 313 | unit_forget_bias: Boolean. 314 | If True, add 1 to the bias of the forget gate at initialization. 315 | Setting it to true will also force `bias_initializer="zeros"`. 316 | This is recommended in [Jozefowicz et al.](http://www.jmlr.org/proceedings/papers/v37/jozefowicz15.pdf) 317 | kernel_regularizer: Regularizer function applied to 318 | the `kernel` weights matrix 319 | (see [regularizer](../regularizers.md)). 320 | recurrent_regularizer: Regularizer function applied to 321 | the `recurrent_kernel` weights matrix 322 | (see [regularizer](../regularizers.md)). 323 | bias_regularizer: Regularizer function applied to the bias vector 324 | (see [regularizer](../regularizers.md)). 325 | activity_regularizer: Regularizer function applied to 326 | the output of the layer (its "activation"). 327 | (see [regularizer](../regularizers.md)). 328 | kernel_constraint: Constraint function applied to 329 | the `kernel` weights matrix 330 | (see [constraints](../constraints.md)). 331 | recurrent_constraint: Constraint function applied to 332 | the `recurrent_kernel` weights matrix 333 | (see [constraints](../constraints.md)). 334 | bias_constraint: Constraint function applied to the bias vector 335 | (see [constraints](../constraints.md)). 336 | dropout: Float between 0 and 1. 337 | Fraction of the units to drop for 338 | the linear transformation of the inputs. 
339 | recurrent_dropout: Float between 0 and 1. 340 | Fraction of the units to drop for 341 | the linear transformation of the recurrent state. 342 | implementation: Implementation mode, either 1 or 2. 343 | Mode 1 will structure its operations as a larger number of 344 | smaller dot products and additions, whereas mode 2 will 345 | batch them into fewer, larger operations. These modes will 346 | have different performance profiles on different hardware and 347 | for different applications. 348 | return_sequences: Boolean. Whether to return the last output. 349 | in the output sequence, or the full sequence. 350 | return_state: Boolean. Whether to return the last state 351 | in addition to the output. 352 | go_backwards: Boolean (default False). 353 | If True, process the input sequence backwards and return the 354 | reversed sequence. 355 | stateful: Boolean (default False). If True, the last state 356 | for each sample at index i in a batch will be used as initial 357 | state for the sample of index i in the following batch. 358 | unroll: Boolean (default False). 359 | If True, the network will be unrolled, 360 | else a symbolic loop will be used. 361 | Unrolling can speed-up a RNN, 362 | although it tends to be more memory-intensive. 363 | Unrolling is only suitable for short sequences. 
364 | 365 | # References 366 | - [Long short-term memory](http://www.bioinf.jku.at/publications/older/2604.pdf) (original 1997 paper) 367 | - [Learning to forget: Continual prediction with NestedLSTM](http://www.mitpressjournals.org/doi/pdf/10.1162/089976600300015015) 368 | - [Supervised sequence labeling with recurrent neural networks](http://www.cs.toronto.edu/~graves/preprint.pdf) 369 | - [A Theoretically Grounded Application of Dropout in Recurrent Neural Networks](http://arxiv.org/abs/1512.05287) 370 | - [Nested LSTMs](https://arxiv.org/abs/1801.10308) 371 | """ 372 | 373 | @interfaces.legacy_recurrent_support 374 | def __init__(self, units, depth, 375 | activation='tanh', 376 | recurrent_activation='sigmoid', 377 | cell_activation='linear', 378 | use_bias=True, 379 | kernel_initializer='glorot_uniform', 380 | recurrent_initializer='orthogonal', 381 | bias_initializer='zeros', 382 | unit_forget_bias=False, 383 | kernel_regularizer=None, 384 | recurrent_regularizer=None, 385 | bias_regularizer=None, 386 | activity_regularizer=None, 387 | kernel_constraint=None, 388 | recurrent_constraint=None, 389 | bias_constraint=None, 390 | dropout=0., 391 | recurrent_dropout=0., 392 | implementation=1, 393 | return_sequences=False, 394 | return_state=False, 395 | go_backwards=False, 396 | stateful=False, 397 | unroll=False, 398 | **kwargs): 399 | if implementation == 0: 400 | warnings.warn('`implementation=0` has been deprecated, ' 401 | 'and now defaults to `implementation=2`.' 402 | 'Please update your layer call.') 403 | if K.backend() == 'theano': 404 | warnings.warn( 405 | 'RNN dropout is no longer supported with the Theano backend ' 406 | 'due to technical limitations. ' 407 | 'You can either set `dropout` and `recurrent_dropout` to 0, ' 408 | 'or use the TensorFlow backend.') 409 | dropout = 0. 410 | recurrent_dropout = 0. 
411 | 412 | cell = NestedLSTMCell(units, depth, 413 | activation=activation, 414 | recurrent_activation=recurrent_activation, 415 | cell_activation=cell_activation, 416 | use_bias=use_bias, 417 | kernel_initializer=kernel_initializer, 418 | recurrent_initializer=recurrent_initializer, 419 | unit_forget_bias=unit_forget_bias, 420 | bias_initializer=bias_initializer, 421 | kernel_regularizer=kernel_regularizer, 422 | recurrent_regularizer=recurrent_regularizer, 423 | bias_regularizer=bias_regularizer, 424 | kernel_constraint=kernel_constraint, 425 | recurrent_constraint=recurrent_constraint, 426 | bias_constraint=bias_constraint, 427 | dropout=dropout, 428 | recurrent_dropout=recurrent_dropout, 429 | implementation=implementation) 430 | super(NestedLSTM, self).__init__(cell, 431 | return_sequences=return_sequences, 432 | return_state=return_state, 433 | go_backwards=go_backwards, 434 | stateful=stateful, 435 | unroll=unroll, 436 | **kwargs) 437 | self.activity_regularizer = regularizers.get(activity_regularizer) 438 | 439 | def call(self, inputs, mask=None, training=None, initial_state=None, constants=None): 440 | self.cell._dropout_mask = None 441 | self.cell._nested_recurrent_masks = None 442 | return super(NestedLSTM, self).call(inputs, 443 | mask=mask, 444 | training=training, 445 | initial_state=initial_state, 446 | constants=constants) 447 | 448 | @property 449 | def units(self): 450 | return self.cell.units 451 | 452 | @property 453 | def depth(self): 454 | return self.cell.depth 455 | 456 | @property 457 | def activation(self): 458 | return self.cell.activation 459 | 460 | @property 461 | def recurrent_activation(self): 462 | return self.cell.recurrent_activation 463 | 464 | @property 465 | def cell_activation(self): 466 | return self.cell.cell_activation 467 | 468 | @property 469 | def use_bias(self): 470 | return self.cell.use_bias 471 | 472 | @property 473 | def kernel_initializer(self): 474 | return self.cell.kernel_initializer 475 | 476 | @property 477 | 
def recurrent_initializer(self): 478 | return self.cell.recurrent_initializer 479 | 480 | @property 481 | def bias_initializer(self): 482 | return self.cell.bias_initializer 483 | 484 | @property 485 | def unit_forget_bias(self): 486 | return self.cell.unit_forget_bias 487 | 488 | @property 489 | def kernel_regularizer(self): 490 | return self.cell.kernel_regularizer 491 | 492 | @property 493 | def recurrent_regularizer(self): 494 | return self.cell.recurrent_regularizer 495 | 496 | @property 497 | def bias_regularizer(self): 498 | return self.cell.bias_regularizer 499 | 500 | @property 501 | def kernel_constraint(self): 502 | return self.cell.kernel_constraint 503 | 504 | @property 505 | def recurrent_constraint(self): 506 | return self.cell.recurrent_constraint 507 | 508 | @property 509 | def bias_constraint(self): 510 | return self.cell.bias_constraint 511 | 512 | @property 513 | def dropout(self): 514 | return self.cell.dropout 515 | 516 | @property 517 | def recurrent_dropout(self): 518 | return self.cell.recurrent_dropout 519 | 520 | @property 521 | def implementation(self): 522 | return self.cell.implementation 523 | 524 | def get_config(self): 525 | config = {'units': self.units, 526 | 'depth': self.depth, 527 | 'activation': activations.serialize(self.activation), 528 | 'recurrent_activation': activations.serialize(self.recurrent_activation), 529 | 'cell_activation': activations.serialize(self.cell_activation), 530 | 'use_bias': self.use_bias, 531 | 'kernel_initializer': initializers.serialize(self.kernel_initializer), 532 | 'recurrent_initializer': initializers.serialize(self.recurrent_initializer), 533 | 'bias_initializer': initializers.serialize(self.bias_initializer), 534 | 'unit_forget_bias': self.unit_forget_bias, 535 | 'kernel_regularizer': regularizers.serialize(self.kernel_regularizer), 536 | 'recurrent_regularizer': regularizers.serialize(self.recurrent_regularizer), 537 | 'bias_regularizer': regularizers.serialize(self.bias_regularizer), 538 | 
'activity_regularizer': regularizers.serialize(self.activity_regularizer), 539 | 'kernel_constraint': constraints.serialize(self.kernel_constraint), 540 | 'recurrent_constraint': constraints.serialize(self.recurrent_constraint), 541 | 'bias_constraint': constraints.serialize(self.bias_constraint), 542 | 'dropout': self.dropout, 543 | 'recurrent_dropout': self.recurrent_dropout, 544 | 'implementation': self.implementation} 545 | base_config = super(NestedLSTM, self).get_config() 546 | del base_config['cell'] 547 | return dict(list(base_config.items()) + list(config.items())) 548 | 549 | @classmethod 550 | def from_config(cls, config): 551 | if 'implementation' in config and config['implementation'] == 0: 552 | config['implementation'] = 2 553 | return cls(**config) 554 | --------------------------------------------------------------------------------