├── BN_LSTMCell.py
└── README.md

/BN_LSTMCell.py:
--------------------------------------------------------------------------------
# Copyright (C) 2016-2017 by Akira TAMAMORI
#
# This program is free software; you can redistribute it and/or modify it under
# the terms of the GNU General Public License as published by the Free Software
# Foundation, either version 3 of the License, or (at your option) any later
# version.
#
# This program is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
# FOR A PARTICULAR PURPOSE. See the GNU General Public License for more
# details.
#
# You should have received a copy of the GNU General Public License along with
# this program. If not, see <http://www.gnu.org/licenses/>.

# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Commentary:
# TODO: implementation of another initializer for LSTM

import numpy as np
import tensorflow as tf

from tensorflow.contrib.rnn import RNNCell, LSTMStateTuple


# Thanks to 'initializers_enhanced.py' of Project RNN Enhancement:
# https://github.com/nicolas-ivanov/Seq2Seq_Upgrade_TensorFlow/blob/master/rnn_enhancement/initializers_enhanced.py
def orthogonal_initializer(scale=1.0):
    def _initializer(shape, dtype=tf.float32, partition_info=None):
        if partition_info is not None:
            raise ValueError(
                "Do not know what to do with partition_info in BN_LSTMCell")
        # Sample a Gaussian matrix and orthogonalize it via SVD.
        flat_shape = (shape[0], np.prod(shape[1:]))
        a = np.random.normal(0.0, 1.0, flat_shape)
        u, _, v = np.linalg.svd(a, full_matrices=False)
        q = u if u.shape == flat_shape else v
        q = q.reshape(shape)
        return tf.constant(scale * q[:shape[0], :shape[1]], dtype=dtype)
    return _initializer


# Thanks to https://github.com/OlavHN/bnlstm
def batch_norm(inputs, name_scope, is_training, epsilon=1e-3, decay=0.99):
    with tf.variable_scope(name_scope):
        size = inputs.get_shape().as_list()[1]

        # gamma is initialized to 0.1 as recommended in the paper.
        scale = tf.get_variable(
            'scale', [size], initializer=tf.constant_initializer(0.1))
        offset = tf.get_variable('offset', [size])

        population_mean = tf.get_variable(
            'population_mean', [size],
            initializer=tf.zeros_initializer(), trainable=False)
        population_var = tf.get_variable(
            'population_var', [size],
            initializer=tf.ones_initializer(), trainable=False)
        batch_mean, batch_var = tf.nn.moments(inputs, [0])

        # The following part is based on the implementation of:
        # https://github.com/cooijmanstim/recurrent-batch-normalization
        train_mean_op = tf.assign(
            population_mean,
            population_mean * decay + batch_mean * (1 - decay))
        train_var_op = tf.assign(
            population_var, population_var * decay + batch_var * (1 - decay))

        if is_training:
            # During training, normalize with the batch statistics and update
            # the population statistics as a side effect.
            with tf.control_dependencies([train_mean_op, train_var_op]):
                return tf.nn.batch_normalization(
                    inputs, batch_mean, batch_var, offset, scale, epsilon)
        else:
            # At inference time, use the accumulated population statistics.
            return tf.nn.batch_normalization(
                inputs, population_mean, population_var, offset, scale,
                epsilon)


class BN_LSTMCell(RNNCell):
    """LSTM cell with Recurrent Batch Normalization.

    This implementation is based on:
        http://arxiv.org/abs/1603.09025

    This implementation is also based on:
        https://github.com/OlavHN/bnlstm
        https://github.com/nicolas-ivanov/Seq2Seq_Upgrade_TensorFlow

    """

    def __init__(self, num_units, is_training,
                 use_peepholes=False, cell_clip=None,
                 initializer=orthogonal_initializer(),
                 num_proj=None, proj_clip=None,
                 forget_bias=1.0,
                 state_is_tuple=True,
                 activation=tf.tanh):
        """Initialize the parameters for an LSTM cell.

        Args:
          num_units: int, The number of units in the LSTM cell.
          is_training: bool, set True when training.
          use_peepholes: bool, set True to enable diagonal/peephole
            connections.
          cell_clip: (optional) A float value, if provided the cell state
            is clipped by this value prior to the cell output activation.
          initializer: (optional) The initializer to use for the weight
            matrices.
          num_proj: (optional) int, The output dimensionality for
            the projection matrices. If None, no projection is performed.
          proj_clip: (optional) A float value, if provided the projected
            output is clipped elementwise to within [-proj_clip, proj_clip].
          forget_bias: Biases of the forget gate are initialized by default
            to 1 in order to reduce the scale of forgetting at the beginning
            of the training.
          state_is_tuple: If True, accepted and returned states are 2-tuples
            of the `c_state` and `m_state`. If False, they are concatenated
            along the column axis.
          activation: Activation function of the inner states.
        """
        if not state_is_tuple:
            tf.logging.log_first_n(
                tf.logging.WARN,
                "%s: Using a concatenated state is slower and "
                "will soon be deprecated. Use state_is_tuple=True.", 1, self)

        self.num_units = num_units
        self.is_training = is_training
        self.use_peepholes = use_peepholes
        self.cell_clip = cell_clip
        self.num_proj = num_proj
        self.proj_clip = proj_clip
        self.initializer = initializer
        self.forget_bias = forget_bias
        self._state_is_tuple = state_is_tuple
        self.activation = activation

        if num_proj:
            self._state_size = (
                LSTMStateTuple(num_units, num_proj)
                if state_is_tuple else num_units + num_proj)
            self._output_size = num_proj
        else:
            self._state_size = (
                LSTMStateTuple(num_units, num_units)
                if state_is_tuple else 2 * num_units)
            self._output_size = num_units

    @property
    def state_size(self):
        return self._state_size

    @property
    def output_size(self):
        return self._output_size

    def __call__(self, inputs, state, scope=None):

        num_proj = self.num_units if self.num_proj is None else self.num_proj

        if self._state_is_tuple:
            (c_prev, h_prev) = state
        else:
            c_prev = tf.slice(state, [0, 0], [-1, self.num_units])
            h_prev = tf.slice(state, [0, self.num_units], [-1, num_proj])

        dtype = inputs.dtype
        input_size = inputs.get_shape().with_rank(2)[1]

        with tf.variable_scope(scope or type(self).__name__):
            if input_size.value is None:
                raise ValueError(
                    "Could not infer input size from inputs.get_shape()[-1]")

            W_xh = tf.get_variable(
                'W_xh',
                [input_size, 4 * self.num_units],
                initializer=self.initializer)
            W_hh = tf.get_variable(
                'W_hh',
                [num_proj, 4 * self.num_units],
                initializer=self.initializer)
            bias = tf.get_variable('B', [4 * self.num_units])

            xh = tf.matmul(inputs, W_xh)
            hh = tf.matmul(h_prev, W_hh)

            # Batch-normalize the input-to-hidden and hidden-to-hidden
            # contributions separately.
            bn_xh = batch_norm(xh, 'xh', self.is_training)
            bn_hh = batch_norm(hh, 'hh', self.is_training)

            # i: input gate, j: new input, f: forget gate, o: output gate
            lstm_matrix = tf.nn.bias_add(tf.add(bn_xh, bn_hh), bias)
            i, j, f, o = tf.split(
                value=lstm_matrix, num_or_size_splits=4, axis=1)

            # Diagonal connections
            if self.use_peepholes:
                w_f_diag = tf.get_variable(
                    "W_F_diag", shape=[self.num_units], dtype=dtype)
                w_i_diag = tf.get_variable(
                    "W_I_diag", shape=[self.num_units], dtype=dtype)
                w_o_diag = tf.get_variable(
                    "W_O_diag", shape=[self.num_units], dtype=dtype)

            if self.use_peepholes:
                c = c_prev * tf.sigmoid(f + self.forget_bias +
                                        w_f_diag * c_prev) + \
                    tf.sigmoid(i + w_i_diag * c_prev) * self.activation(j)
            else:
                c = c_prev * tf.sigmoid(f + self.forget_bias) + \
                    tf.sigmoid(i) * self.activation(j)

            if self.cell_clip is not None:
                c = tf.clip_by_value(c, -self.cell_clip, self.cell_clip)

            # The cell state is batch-normalized before the output activation.
            bn_c = batch_norm(c, 'cell', self.is_training)

            if self.use_peepholes:
                h = tf.sigmoid(o + w_o_diag * c) * self.activation(bn_c)
            else:
                h = tf.sigmoid(o) * self.activation(bn_c)

            if self.num_proj is not None:
                w_proj = tf.get_variable(
                    "W_P", [self.num_units, num_proj], dtype=dtype)

                h = tf.matmul(h, w_proj)
                if self.proj_clip is not None:
                    h = tf.clip_by_value(h, -self.proj_clip, self.proj_clip)

        new_state = (LSTMStateTuple(c, h)
                     if self._state_is_tuple else tf.concat([c, h], 1))

        return h, new_state
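

if __name__ == "__main__":
    # Minimal smoke test: a sketch that builds the cell once and checks the
    # output shapes. The batch size, input dimension and number of units
    # below are arbitrary placeholder values.
    batch_size, input_dim, num_units = 32, 16, 64
    x = tf.placeholder(tf.float32, [batch_size, input_dim])
    cell = BN_LSTMCell(num_units, is_training=True)
    init_state = cell.zero_state(batch_size, tf.float32)
    output, state = cell(x, init_state)
    print(output.get_shape())   # (32, 64)
    print(state.c.get_shape())  # (32, 64)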
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# An implementation of LSTM with Recurrent Batch Normalization

This implementation is based on the following paper:

Tim Cooijmans, Nicolas Ballas, César Laurent, Çağlar Gülçehre, Aaron Courville, "Recurrent Batch Normalization," https://arxiv.org/abs/1603.09025

The batch normalization functionality was adapted from the implementation of [DCGAN in TensorFlow](https://github.com/carpedm20/DCGAN-tensorflow) and incorporated into the LSTM cell of TensorFlow.
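
## Usage

A minimal sketch of driving the cell with `tf.nn.dynamic_rnn` under TensorFlow 1.x; the sizes and the scope name below are arbitrary placeholders. Because `is_training` is a plain Python bool, the training and inference graphs are built separately and share their variables via `reuse`.

```python
import tensorflow as tf

from BN_LSTMCell import BN_LSTMCell

# [batch, time, features]; the sizes are placeholders.
inputs = tf.placeholder(tf.float32, [None, 100, 28])

# Training graph: batch statistics are used and the population statistics
# are updated as a side effect.
with tf.variable_scope("rnn"):
    train_cell = BN_LSTMCell(num_units=128, is_training=True)
    train_outputs, train_state = tf.nn.dynamic_rnn(
        train_cell, inputs, dtype=tf.float32)

# Inference graph: rebuilt with is_training=False so the accumulated
# population statistics are used; the weights are shared via reuse.
with tf.variable_scope("rnn", reuse=True):
    eval_cell = BN_LSTMCell(num_units=128, is_training=False)
    eval_outputs, eval_state = tf.nn.dynamic_rnn(
        eval_cell, inputs, dtype=tf.float32)
```
--------------------------------------------------------------------------------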