├── images
│   ├── difference.PNG
│   └── nested_lstm_diagram.PNG
├── LICENSE
├── README.md
├── .gitignore
├── imdb_nested_lstm.py
└── nested_lstm.py
/images/difference.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/titu1994/Nested-LSTM/HEAD/images/difference.PNG
--------------------------------------------------------------------------------
/images/nested_lstm_diagram.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/titu1994/Nested-LSTM/HEAD/images/nested_lstm_diagram.PNG
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2018 Somshubra Majumdar
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Nested LSTM
2 | Keras implementation of Nested LSTMs from the paper [Nested LSTMs](https://arxiv.org/abs/1801.10308)
3 |
4 | From the paper:
5 | > Nested LSTMs add depth to LSTMs via nesting as opposed to stacking. The value of a memory cell
6 | in an NLSTM is computed by an LSTM cell, which has its own inner memory cell. Nested LSTMs outperform both stacked and single-layer
7 | LSTMs with similar numbers of parameters in our experiments on various character-level language
8 | modeling tasks, and the inner memories of an LSTM learn longer term dependencies compared with
9 | the higher-level units of a stacked LSTM
10 |
11 | # Usage
12 | Via Cells
13 | ```python
14 | from nested_lstm import NestedLSTMCell
15 | from keras.layers import RNN, Input
16 |
17 | ip = Input(shape=(nb_timesteps, input_dim))
18 | x = RNN(NestedLSTMCell(units=64, depth=2))(ip)
19 | ...
20 | ```
21 |
22 | Via Layer
23 | ```python
24 | from nested_lstm import NestedLSTM
25 | from keras.layers import Input
26 | ip = Input(shape=(nb_timesteps, input_dim))
27 | x = NestedLSTM(units=64, depth=2)(ip)
28 | ...
29 | ```
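For a side-by-side comparison with stacking, here is a minimal sketch built with the functional API (layer sizes and input shapes are illustrative, not taken from the paper). For the same `units`, the recurrent parameter counts of the two variants are broadly comparable:

```python
from keras.layers import Input, Dense, LSTM
from keras.models import Model

from nested_lstm import NestedLSTM

nb_timesteps, input_dim = 100, 32  # illustrative shapes

ip = Input(shape=(nb_timesteps, input_dim))

# Nested LSTM: depth is added inside the cell of a single recurrent layer.
x = NestedLSTM(units=64, depth=2)(ip)
nested_model = Model(ip, Dense(1, activation='sigmoid')(x))

# Stacked LSTMs: depth is added by feeding one recurrent layer into another.
y = LSTM(64, return_sequences=True)(ip)
y = LSTM(64)(y)
stacked_model = Model(ip, Dense(1, activation='sigmoid')(y))

nested_model.summary()
stacked_model.summary()
```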
30 |
31 | # Difference between Stacked LSTMs and Nested LSTMs (from the paper)
32 |
33 | ![Difference between stacked and nested LSTMs](images/difference.PNG)
34 | # Cell diagram (depth = 2, from the paper)
35 |
36 | ![Nested LSTM cell diagram](images/nested_lstm_diagram.PNG)
37 | # Acknowledgements
38 | Keras code heavily derived from the Tensorflow implementation - https://github.com/hannw/nlstm
39 |
40 | # Requirements
41 | - Keras 2.1.3+
42 | - Tensorflow 1.2+ or Theano. CNTK untested.
43 |
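# Loading saved models
Since `NestedLSTM` and `NestedLSTMCell` implement `get_config`, models containing them can be reloaded by passing the classes via `custom_objects`. A minimal sketch (the file name is illustrative):

```python
from keras.models import load_model
from nested_lstm import NestedLSTM, NestedLSTMCell

model = load_model('nested_lstm_model.h5',
                   custom_objects={'NestedLSTM': NestedLSTM,
                                   'NestedLSTMCell': NestedLSTMCell})
```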
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | env/
12 | build/
13 | develop-eggs/
14 | dist/
15 | downloads/
16 | eggs/
17 | .eggs/
18 | lib/
19 | lib64/
20 | parts/
21 | sdist/
22 | var/
23 | wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 |
28 | # PyInstaller
29 | # Usually these files are written by a python script from a template
30 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
31 | *.manifest
32 | *.spec
33 |
34 | # Installer logs
35 | pip-log.txt
36 | pip-delete-this-directory.txt
37 |
38 | # Unit test / coverage reports
39 | htmlcov/
40 | .tox/
41 | .coverage
42 | .coverage.*
43 | .cache
44 | nosetests.xml
45 | coverage.xml
46 | *.cover
47 | .hypothesis/
48 |
49 | # Translations
50 | *.mo
51 | *.pot
52 |
53 | # Django stuff:
54 | *.log
55 | local_settings.py
56 |
57 | # Flask stuff:
58 | instance/
59 | .webassets-cache
60 |
61 | # Scrapy stuff:
62 | .scrapy
63 |
64 | # Sphinx documentation
65 | docs/_build/
66 |
67 | # PyBuilder
68 | target/
69 |
70 | # Jupyter Notebook
71 | .ipynb_checkpoints
72 |
73 | # pyenv
74 | .python-version
75 |
76 | # celery beat schedule file
77 | celerybeat-schedule
78 |
79 | # SageMath parsed files
80 | *.sage.py
81 |
82 | # dotenv
83 | .env
84 |
85 | # virtualenv
86 | .venv
87 | venv/
88 | ENV/
89 |
90 | # Spyder project settings
91 | .spyderproject
92 | .spyproject
93 |
94 | # Rope project settings
95 | .ropeproject
96 |
97 | # mkdocs documentation
98 | /site
99 |
100 | # mypy
101 | .mypy_cache/
102 |
103 | # weights
104 | weights/*
105 | logs/*
106 |
--------------------------------------------------------------------------------
/imdb_nested_lstm.py:
--------------------------------------------------------------------------------
1 | '''Trains a Nested LSTM on the IMDB sentiment classification task.
2 | The dataset is actually too small for a Nested LSTM to be of any advantage
3 | compared to simpler, much faster methods such as TF-IDF + LogReg.
4 | '''
5 | from __future__ import print_function
6 | import os
7 | from keras.preprocessing import sequence
8 | from keras.models import Sequential
9 | from keras.layers import Dense, Embedding
10 | from keras.callbacks import ModelCheckpoint
11 | from keras.datasets import imdb
12 |
13 | from nested_lstm import NestedLSTM
14 |
15 | max_features = 20000
16 | maxlen = 500 # cut texts after this number of words (among top max_features most common words)
17 | batch_size = 128
18 | if not os.path.exists('weights'): os.makedirs('weights')  # ModelCheckpoint does not create missing directories
19 | print('Loading data...')
20 | (x_train, y_train), (x_test, y_test) = imdb.load_data(num_words=max_features)
21 | print(len(x_train), 'train sequences')
22 | print(len(x_test), 'test sequences')
23 |
24 | print('Pad sequences (samples x time)')
25 | x_train = sequence.pad_sequences(x_train, maxlen=maxlen)
26 | x_test = sequence.pad_sequences(x_test, maxlen=maxlen)
27 | print('x_train shape:', x_train.shape)
28 | print('x_test shape:', x_test.shape)
29 |
30 | # Note: with `units=600` and a 64-dim embedding this model matches roughly 4.47 million parameters; smaller values are used below to keep training fast
31 | print('Build model...')
32 | model = Sequential()
33 | model.add(Embedding(max_features, 128))
34 | model.add(NestedLSTM(32, depth=2, dropout=0.0, recurrent_dropout=0.0))
35 | model.add(Dense(1, activation='sigmoid'))
36 |
37 | # try using different optimizers and different optimizer configs
38 | model.compile(loss='binary_crossentropy',
39 | optimizer='adam',
40 | metrics=['accuracy'])
41 |
42 | model.summary()
43 |
44 | print('Train...')
45 | model.fit(x_train, y_train,
46 | batch_size=batch_size,
47 | epochs=15,
48 | validation_data=(x_test, y_test),
49 | callbacks=[ModelCheckpoint('weights/imdb_nlstm.h5', monitor='val_acc',
50 | save_best_only=True, save_weights_only=True)])
51 |
52 | model.load_weights('weights/imdb_nlstm.h5')
53 |
54 | score, acc = model.evaluate(x_test, y_test,
55 | batch_size=batch_size)
56 | print('Test score:', score)
57 | print('Test accuracy:', acc)
58 |
--------------------------------------------------------------------------------
/nested_lstm.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | import warnings
3 |
4 | from keras import backend as K
5 | from keras import activations
6 | from keras import initializers
7 | from keras import regularizers
8 | from keras import constraints
9 | from keras.engine import Layer
10 | from keras.engine import InputSpec
11 | from keras.legacy import interfaces
12 | from keras.layers import RNN
13 | from keras.layers.recurrent import _generate_dropout_mask
14 | from keras.layers import LSTMCell, LSTM
15 |
16 |
17 | class NestedLSTMCell(Layer):
18 | """Nested NestedLSTM Cell class.
19 |
20 | Derived from the paper [Nested LSTMs](https://arxiv.org/abs/1801.10308)
21 | Ref: [Tensorflow implementation](https://github.com/hannw/nlstm)
22 |
23 | # Arguments
24 | units: Positive integer, dimensionality of the output space.
25 | depth: Depth of nesting of the memory component.
26 | activation: Activation function to use
27 | (see [activations](../activations.md)).
28 | If you pass None, no activation is applied
29 | (ie. "linear" activation: `a(x) = x`).
30 | recurrent_activation: Activation function to use
31 | for the recurrent step
32 | (see [activations](../activations.md)).
33 | cell_activation: Activation function applied to the cell candidate at the outermost nesting level.
34 | Note that in the paper only this first cell activation is the identity.
35 | (see [activations](../activations.md)).
36 | use_bias: Boolean, whether the layer uses a bias vector.
37 | kernel_initializer: Initializer for the `kernel` weights matrix,
38 | used for the linear transformation of the inputs
39 | (see [initializers](../initializers.md)).
40 | recurrent_initializer: Initializer for the `recurrent_kernel`
41 | weights matrix,
42 | used for the linear transformation of the recurrent state
43 | (see [initializers](../initializers.md)).
44 | bias_initializer: Initializer for the bias vector
45 | (see [initializers](../initializers.md)).
46 | unit_forget_bias: Boolean.
47 | If True, add 1 to the bias of the forget gate at initialization.
48 | Setting it to true will also force `bias_initializer="zeros"`.
49 | This is recommended in [Jozefowicz et al.](http://www.jmlr.org/proceedings/papers/v37/jozefowicz15.pdf)
50 | kernel_regularizer: Regularizer function applied to
51 | the `kernel` weights matrix
52 | (see [regularizer](../regularizers.md)).
53 | recurrent_regularizer: Regularizer function applied to
54 | the `recurrent_kernel` weights matrix
55 | (see [regularizer](../regularizers.md)).
56 | bias_regularizer: Regularizer function applied to the bias vector
57 | (see [regularizer](../regularizers.md)).
58 | kernel_constraint: Constraint function applied to
59 | the `kernel` weights matrix
60 | (see [constraints](../constraints.md)).
61 | recurrent_constraint: Constraint function applied to
62 | the `recurrent_kernel` weights matrix
63 | (see [constraints](../constraints.md)).
64 | bias_constraint: Constraint function applied to the bias vector
65 | (see [constraints](../constraints.md)).
66 | dropout: Float between 0 and 1.
67 | Fraction of the units to drop for
68 | the linear transformation of the inputs.
69 | recurrent_dropout: Float between 0 and 1.
70 | Fraction of the units to drop for
71 | the linear transformation of the recurrent state.
72 | implementation: Implementation mode, must be 2.
73 | Mode 1 will structure its operations as a larger number of
74 | smaller dot products and additions, whereas mode 2 will
75 | batch them into fewer, larger operations. These modes will
76 | have different performance profiles on different hardware and
77 | for different applications.
78 | """
79 |
80 | def __init__(self, units, depth,
81 | activation='tanh',
82 | recurrent_activation='sigmoid',
83 | cell_activation='linear',
84 | use_bias=True,
85 | kernel_initializer='glorot_uniform',
86 | recurrent_initializer='orthogonal',
87 | bias_initializer='zeros',
88 | unit_forget_bias=False,
89 | kernel_regularizer=None,
90 | recurrent_regularizer=None,
91 | bias_regularizer=None,
92 | kernel_constraint=None,
93 | recurrent_constraint=None,
94 | bias_constraint=None,
95 | dropout=0.,
96 | recurrent_dropout=0.,
97 | implementation=2,
98 | **kwargs):
99 | super(NestedLSTMCell, self).__init__(**kwargs)
100 |
101 | if depth < 1:
102 | raise ValueError("`depth` must be at least 1. For better performance, consider using depth > 1.")
103 |
104 | if implementation != 2:
105 | warnings.warn(
106 | "Nested LSTMs only support implementation 2 for the moment. Defaulting to implementation = 2")
107 | implementation = 2
108 |
109 | self.units = units
110 | self.depth = depth
111 | self.activation = activations.get(activation)
112 | self.recurrent_activation = activations.get(recurrent_activation)
113 | self.cell_activation = activations.get(cell_activation)
114 | self.use_bias = use_bias
115 |
116 | self.kernel_initializer = initializers.get(kernel_initializer)
117 | self.recurrent_initializer = initializers.get(recurrent_initializer)
118 | self.bias_initializer = initializers.get(bias_initializer)
119 | self.unit_forget_bias = unit_forget_bias
120 |
121 | self.kernel_regularizer = regularizers.get(kernel_regularizer)
122 | self.recurrent_regularizer = regularizers.get(recurrent_regularizer)
123 | self.bias_regularizer = regularizers.get(bias_regularizer)
124 |
125 | self.kernel_constraint = constraints.get(kernel_constraint)
126 | self.recurrent_constraint = constraints.get(recurrent_constraint)
127 | self.bias_constraint = constraints.get(bias_constraint)
128 |
129 | self.dropout = min(1., max(0., dropout))
130 | self.recurrent_dropout = min(1., max(0., recurrent_dropout))
131 | self.implementation = implementation
132 | self.state_size = tuple([self.units] * (self.depth + 1))
133 | self._dropout_mask = None
134 | self._nested_recurrent_masks = None
135 |
136 | def build(self, input_shape):
137 | input_dim = input_shape[-1]
138 | self.kernels = []
139 | self.biases = []
140 |
141 | for i in range(self.depth):
142 | if i == 0:
143 | input_kernel = self.add_weight(shape=(input_dim, self.units * 4),
144 | name='input_kernel_%d' % (i + 1),
145 | initializer=self.kernel_initializer,
146 | regularizer=self.kernel_regularizer,
147 | constraint=self.kernel_constraint)
148 | hidden_kernel = self.add_weight(shape=(self.units, self.units * 4),
149 | name='kernel_%d' % (i + 1),
150 | initializer=self.recurrent_initializer,
151 | regularizer=self.recurrent_regularizer,
152 | constraint=self.recurrent_constraint)
153 | kernel = K.concatenate([input_kernel, hidden_kernel], axis=0)
154 | else:
155 | kernel = self.add_weight(shape=(self.units * 2, self.units * 4),
156 | name='kernel_%d' % (i + 1),
157 | initializer=self.recurrent_initializer,
158 | regularizer=self.recurrent_regularizer,
159 | constraint=self.recurrent_constraint)
160 | self.kernels.append(kernel)
161 |
162 | if self.use_bias:
163 | if self.unit_forget_bias:
164 | def bias_initializer(_, *args, **kwargs):
165 | return K.concatenate([
166 | self.bias_initializer((self.units * 2,), *args, **kwargs),  # input gate and cell candidate slots
167 | initializers.Ones()((self.units,), *args, **kwargs),  # forget gate slot (third quarter, matching the gate slicing below)
168 | self.bias_initializer((self.units,), *args, **kwargs),  # output gate slot
169 | ])
170 | else:
171 | bias_initializer = self.bias_initializer
172 |
173 | for i in range(self.depth):
174 | bias = self.add_weight(shape=(self.units * 4,),
175 | name='bias_%d' % (i + 1),
176 | initializer=bias_initializer,
177 | regularizer=self.bias_regularizer,
178 | constraint=self.bias_constraint)
179 | self.biases.append(bias)
180 | else:
181 | self.biases = None
182 |
183 | self.built = True
184 |
185 | def call(self, inputs, states, training=None):
186 | if 0 < self.dropout < 1 and self._dropout_mask is None:
187 | self._dropout_mask = _generate_dropout_mask(
188 | K.ones_like(inputs),
189 | self.dropout,
190 | training=training,
191 | count=1)
192 | if (0 < self.recurrent_dropout < 1 and
193 | self._nested_recurrent_masks is None):
194 | _nested_recurrent_mask = _generate_dropout_mask(
195 | K.ones_like(states[0]),
196 | self.recurrent_dropout,
197 | training=training,
198 | count=self.depth)
199 | self._nested_recurrent_masks = _nested_recurrent_mask
200 |
201 | # dropout matrices for input units
202 | dp_mask = self._dropout_mask
203 | # dropout matrices for recurrent units
204 | rec_dp_masks = self._nested_recurrent_masks
205 |
206 | h_tm1 = states[0]  # previous hidden (output) state
207 | c_tm1 = states[1:self.depth + 1]  # previous cell states, one per nesting depth
208 |
209 | if 0. < self.dropout < 1.:
210 | inputs *= dp_mask  # a single mask tensor (count=1) is generated above, not a list
211 |
212 | h, c = self.nested_recurrence(inputs,
213 | hidden_state=h_tm1,
214 | cell_states=c_tm1,
215 | recurrent_masks=rec_dp_masks,
216 | current_depth=0)
217 |
218 | if 0 < self.dropout + self.recurrent_dropout:
219 | if training is None:
220 | h._uses_learning_phase = True
221 | return h, c
222 |
223 | def nested_recurrence(self, inputs, hidden_state, cell_states, recurrent_masks, current_depth):
224 | h_state = hidden_state
225 | c_state = cell_states[current_depth]
226 |
227 | if 0.0 < self.recurrent_dropout <= 1. and recurrent_masks is not None:
228 | hidden_state = h_state * recurrent_masks[current_depth]
229 |
230 | ip = K.concatenate([inputs, hidden_state], axis=-1)
231 | gate_inputs = K.dot(ip, self.kernels[current_depth])
232 |
233 | if self.use_bias:
234 | gate_inputs = K.bias_add(gate_inputs, self.biases[current_depth])
235 |
236 | i = gate_inputs[:, :self.units]  # input gate
237 | c = gate_inputs[:, self.units: self.units * 2]  # new input (cell candidate)
238 | f = gate_inputs[:, self.units * 2: self.units * 3]  # forget gate
239 | o = gate_inputs[:, self.units * 3: self.units * 4]  # output gate
240 |
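# NLSTM recurrence, following the paper: the outer cell state is produced by an
# inner LSTM whose state input is f * c_{t-1} and whose data input is i * g(c_candidate);
# at the innermost depth this reduces to the ordinary LSTM update
# c_t = f * c_{t-1} + i * g(c_candidate).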
241 | inner_hidden = c_state * self.recurrent_activation(f)
242 |
243 | if current_depth == 0:
244 | inner_input = self.recurrent_activation(i) * self.cell_activation(c)
245 | else:
246 | inner_input = self.recurrent_activation(i) * self.activation(c)
247 |
248 | if (current_depth == self.depth - 1):
249 | new_c = inner_hidden + inner_input
250 | new_cs = [new_c]
251 | else:
252 | new_c, new_cs = self.nested_recurrence(inner_input,
253 | hidden_state=inner_hidden,
254 | cell_states=cell_states,
255 | recurrent_masks=recurrent_masks,
256 | current_depth=current_depth + 1)
257 |
258 | new_h = self.activation(new_c) * self.recurrent_activation(o)
259 | new_cs = [new_h] + new_cs
260 |
261 | return new_h, new_cs
262 |
263 | def get_config(self):
264 | config = {'units': self.units,
265 | 'depth': self.depth,
266 | 'activation': activations.serialize(self.activation),
267 | 'recurrent_activation': activations.serialize(self.recurrent_activation),
268 | 'cell_activation': activations.serialize(self.cell_activation),
269 | 'use_bias': self.use_bias,
270 | 'kernel_initializer': initializers.serialize(self.kernel_initializer),
271 | 'recurrent_initializer': initializers.serialize(self.recurrent_initializer),
272 | 'bias_initializer': initializers.serialize(self.bias_initializer),
273 | 'unit_forget_bias': self.unit_forget_bias,
274 | 'kernel_regularizer': regularizers.serialize(self.kernel_regularizer),
275 | 'recurrent_regularizer': regularizers.serialize(self.recurrent_regularizer),
276 | 'bias_regularizer': regularizers.serialize(self.bias_regularizer),
277 | 'kernel_constraint': constraints.serialize(self.kernel_constraint),
278 | 'recurrent_constraint': constraints.serialize(self.recurrent_constraint),
279 | 'bias_constraint': constraints.serialize(self.bias_constraint),
280 | 'dropout': self.dropout,
281 | 'recurrent_dropout': self.recurrent_dropout,
282 | 'implementation': self.implementation}
283 | base_config = super(NestedLSTMCell, self).get_config()
284 | return dict(list(base_config.items()) + list(config.items()))
285 |
286 |
287 | class NestedLSTM(RNN):
288 | """Nested Long-Short-Term-Memory layer - [Nested LSTMs](https://arxiv.org/abs/1801.10308).
289 |
290 | # Arguments
291 | units: Positive integer, dimensionality of the output space.
292 | depth: Depth of nesting of the memory component.
293 | activation: Activation function to use
294 | (see [activations](../activations.md)).
295 | If you pass None, no activation is applied
296 | (ie. "linear" activation: `a(x) = x`).
297 | recurrent_activation: Activation function to use
298 | for the recurrent step
299 | (see [activations](../activations.md)).
300 | cell_activation: Activation function applied to the cell candidate at the outermost nesting level.
301 | Note that in the paper only this first cell activation is the identity.
302 | (see [activations](../activations.md)).
303 | use_bias: Boolean, whether the layer uses a bias vector.
304 | kernel_initializer: Initializer for the `kernel` weights matrix,
305 | used for the linear transformation of the inputs.
306 | (see [initializers](../initializers.md)).
307 | recurrent_initializer: Initializer for the `recurrent_kernel`
308 | weights matrix,
309 | used for the linear transformation of the recurrent state.
310 | (see [initializers](../initializers.md)).
311 | bias_initializer: Initializer for the bias vector
312 | (see [initializers](../initializers.md)).
313 | unit_forget_bias: Boolean.
314 | If True, add 1 to the bias of the forget gate at initialization.
315 | Setting it to true will also force `bias_initializer="zeros"`.
316 | This is recommended in [Jozefowicz et al.](http://www.jmlr.org/proceedings/papers/v37/jozefowicz15.pdf)
317 | kernel_regularizer: Regularizer function applied to
318 | the `kernel` weights matrix
319 | (see [regularizer](../regularizers.md)).
320 | recurrent_regularizer: Regularizer function applied to
321 | the `recurrent_kernel` weights matrix
322 | (see [regularizer](../regularizers.md)).
323 | bias_regularizer: Regularizer function applied to the bias vector
324 | (see [regularizer](../regularizers.md)).
325 | activity_regularizer: Regularizer function applied to
326 | the output of the layer (its "activation").
327 | (see [regularizer](../regularizers.md)).
328 | kernel_constraint: Constraint function applied to
329 | the `kernel` weights matrix
330 | (see [constraints](../constraints.md)).
331 | recurrent_constraint: Constraint function applied to
332 | the `recurrent_kernel` weights matrix
333 | (see [constraints](../constraints.md)).
334 | bias_constraint: Constraint function applied to the bias vector
335 | (see [constraints](../constraints.md)).
336 | dropout: Float between 0 and 1.
337 | Fraction of the units to drop for
338 | the linear transformation of the inputs.
339 | recurrent_dropout: Float between 0 and 1.
340 | Fraction of the units to drop for
341 | the linear transformation of the recurrent state.
342 | implementation: Implementation mode, must be 2.
343 | Mode 1 will structure its operations as a larger number of
344 | smaller dot products and additions, whereas mode 2 will
345 | batch them into fewer, larger operations. These modes will
346 | have different performance profiles on different hardware and
347 | for different applications.
348 | return_sequences: Boolean. Whether to return the last output
349 | in the output sequence, or the full sequence.
350 | return_state: Boolean. Whether to return the last state
351 | in addition to the output.
352 | go_backwards: Boolean (default False).
353 | If True, process the input sequence backwards and return the
354 | reversed sequence.
355 | stateful: Boolean (default False). If True, the last state
356 | for each sample at index i in a batch will be used as initial
357 | state for the sample of index i in the following batch.
358 | unroll: Boolean (default False).
359 | If True, the network will be unrolled,
360 | else a symbolic loop will be used.
361 | Unrolling can speed-up a RNN,
362 | although it tends to be more memory-intensive.
363 | Unrolling is only suitable for short sequences.
364 |
365 | # References
366 | - [Long short-term memory](http://www.bioinf.jku.at/publications/older/2604.pdf) (original 1997 paper)
367 | - [Learning to forget: Continual prediction with LSTM](http://www.mitpressjournals.org/doi/pdf/10.1162/089976600300015015)
368 | - [Supervised sequence labeling with recurrent neural networks](http://www.cs.toronto.edu/~graves/preprint.pdf)
369 | - [A Theoretically Grounded Application of Dropout in Recurrent Neural Networks](http://arxiv.org/abs/1512.05287)
370 | - [Nested LSTMs](https://arxiv.org/abs/1801.10308)
371 | """
372 |
373 | @interfaces.legacy_recurrent_support
374 | def __init__(self, units, depth,
375 | activation='tanh',
376 | recurrent_activation='sigmoid',
377 | cell_activation='linear',
378 | use_bias=True,
379 | kernel_initializer='glorot_uniform',
380 | recurrent_initializer='orthogonal',
381 | bias_initializer='zeros',
382 | unit_forget_bias=False,
383 | kernel_regularizer=None,
384 | recurrent_regularizer=None,
385 | bias_regularizer=None,
386 | activity_regularizer=None,
387 | kernel_constraint=None,
388 | recurrent_constraint=None,
389 | bias_constraint=None,
390 | dropout=0.,
391 | recurrent_dropout=0.,
392 | implementation=2,
393 | return_sequences=False,
394 | return_state=False,
395 | go_backwards=False,
396 | stateful=False,
397 | unroll=False,
398 | **kwargs):
399 | if implementation == 0:
400 | warnings.warn('`implementation=0` has been deprecated, '
401 | 'and now defaults to `implementation=2`. '
402 | 'Please update your layer call.')
403 | if K.backend() == 'theano':
404 | warnings.warn(
405 | 'RNN dropout is no longer supported with the Theano backend '
406 | 'due to technical limitations. '
407 | 'You can either set `dropout` and `recurrent_dropout` to 0, '
408 | 'or use the TensorFlow backend.')
409 | dropout = 0.
410 | recurrent_dropout = 0.
411 |
412 | cell = NestedLSTMCell(units, depth,
413 | activation=activation,
414 | recurrent_activation=recurrent_activation,
415 | cell_activation=cell_activation,
416 | use_bias=use_bias,
417 | kernel_initializer=kernel_initializer,
418 | recurrent_initializer=recurrent_initializer,
419 | unit_forget_bias=unit_forget_bias,
420 | bias_initializer=bias_initializer,
421 | kernel_regularizer=kernel_regularizer,
422 | recurrent_regularizer=recurrent_regularizer,
423 | bias_regularizer=bias_regularizer,
424 | kernel_constraint=kernel_constraint,
425 | recurrent_constraint=recurrent_constraint,
426 | bias_constraint=bias_constraint,
427 | dropout=dropout,
428 | recurrent_dropout=recurrent_dropout,
429 | implementation=implementation)
430 | super(NestedLSTM, self).__init__(cell,
431 | return_sequences=return_sequences,
432 | return_state=return_state,
433 | go_backwards=go_backwards,
434 | stateful=stateful,
435 | unroll=unroll,
436 | **kwargs)
437 | self.activity_regularizer = regularizers.get(activity_regularizer)
438 |
439 | def call(self, inputs, mask=None, training=None, initial_state=None, constants=None):
440 | self.cell._dropout_mask = None
441 | self.cell._nested_recurrent_masks = None
442 | return super(NestedLSTM, self).call(inputs,
443 | mask=mask,
444 | training=training,
445 | initial_state=initial_state,
446 | constants=constants)
447 |
448 | @property
449 | def units(self):
450 | return self.cell.units
451 |
452 | @property
453 | def depth(self):
454 | return self.cell.depth
455 |
456 | @property
457 | def activation(self):
458 | return self.cell.activation
459 |
460 | @property
461 | def recurrent_activation(self):
462 | return self.cell.recurrent_activation
463 |
464 | @property
465 | def cell_activation(self):
466 | return self.cell.cell_activation
467 |
468 | @property
469 | def use_bias(self):
470 | return self.cell.use_bias
471 |
472 | @property
473 | def kernel_initializer(self):
474 | return self.cell.kernel_initializer
475 |
476 | @property
477 | def recurrent_initializer(self):
478 | return self.cell.recurrent_initializer
479 |
480 | @property
481 | def bias_initializer(self):
482 | return self.cell.bias_initializer
483 |
484 | @property
485 | def unit_forget_bias(self):
486 | return self.cell.unit_forget_bias
487 |
488 | @property
489 | def kernel_regularizer(self):
490 | return self.cell.kernel_regularizer
491 |
492 | @property
493 | def recurrent_regularizer(self):
494 | return self.cell.recurrent_regularizer
495 |
496 | @property
497 | def bias_regularizer(self):
498 | return self.cell.bias_regularizer
499 |
500 | @property
501 | def kernel_constraint(self):
502 | return self.cell.kernel_constraint
503 |
504 | @property
505 | def recurrent_constraint(self):
506 | return self.cell.recurrent_constraint
507 |
508 | @property
509 | def bias_constraint(self):
510 | return self.cell.bias_constraint
511 |
512 | @property
513 | def dropout(self):
514 | return self.cell.dropout
515 |
516 | @property
517 | def recurrent_dropout(self):
518 | return self.cell.recurrent_dropout
519 |
520 | @property
521 | def implementation(self):
522 | return self.cell.implementation
523 |
524 | def get_config(self):
525 | config = {'units': self.units,
526 | 'depth': self.depth,
527 | 'activation': activations.serialize(self.activation),
528 | 'recurrent_activation': activations.serialize(self.recurrent_activation),
529 | 'cell_activation': activations.serialize(self.cell_activation),
530 | 'use_bias': self.use_bias,
531 | 'kernel_initializer': initializers.serialize(self.kernel_initializer),
532 | 'recurrent_initializer': initializers.serialize(self.recurrent_initializer),
533 | 'bias_initializer': initializers.serialize(self.bias_initializer),
534 | 'unit_forget_bias': self.unit_forget_bias,
535 | 'kernel_regularizer': regularizers.serialize(self.kernel_regularizer),
536 | 'recurrent_regularizer': regularizers.serialize(self.recurrent_regularizer),
537 | 'bias_regularizer': regularizers.serialize(self.bias_regularizer),
538 | 'activity_regularizer': regularizers.serialize(self.activity_regularizer),
539 | 'kernel_constraint': constraints.serialize(self.kernel_constraint),
540 | 'recurrent_constraint': constraints.serialize(self.recurrent_constraint),
541 | 'bias_constraint': constraints.serialize(self.bias_constraint),
542 | 'dropout': self.dropout,
543 | 'recurrent_dropout': self.recurrent_dropout,
544 | 'implementation': self.implementation}
545 | base_config = super(NestedLSTM, self).get_config()
546 | del base_config['cell']
547 | return dict(list(base_config.items()) + list(config.items()))
548 |
549 | @classmethod
550 | def from_config(cls, config):
551 | if 'implementation' in config and config['implementation'] == 0:
552 | config['implementation'] = 2
553 | return cls(**config)
554 |
--------------------------------------------------------------------------------