├── assets
│   ├── mnist.png
│   ├── mode.png
│   ├── mnist_2.png
│   ├── model_gru.png
│   ├── model_gru1.png
│   ├── model_gru2.png
│   ├── model_gru3.png
│   └── model_gru4.png
├── README.md
├── mnist.py
└── layers.py
/assets/mnist.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pbhatia243/tf-layer-norm/HEAD/assets/mnist.png
--------------------------------------------------------------------------------
/assets/mode.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pbhatia243/tf-layer-norm/HEAD/assets/mode.png
--------------------------------------------------------------------------------
/assets/mnist_2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pbhatia243/tf-layer-norm/HEAD/assets/mnist_2.png
--------------------------------------------------------------------------------
/assets/model_gru.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pbhatia243/tf-layer-norm/HEAD/assets/model_gru.png
--------------------------------------------------------------------------------
/assets/model_gru1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pbhatia243/tf-layer-norm/HEAD/assets/model_gru1.png
--------------------------------------------------------------------------------
/assets/model_gru2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pbhatia243/tf-layer-norm/HEAD/assets/model_gru2.png
--------------------------------------------------------------------------------
/assets/model_gru3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pbhatia243/tf-layer-norm/HEAD/assets/model_gru3.png
--------------------------------------------------------------------------------
/assets/model_gru4.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pbhatia243/tf-layer-norm/HEAD/assets/model_gru4.png
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | TensorFlow Layer Normalization and Hyper Networks
2 | =================================
3 | TensorFlow implementation of [Layer Normalization](https://arxiv.org/abs/1607.06450) and [Hyper Networks](https://arxiv.org/pdf/1609.09106v1.pdf).
4 | 
5 | This implementation contains:
6 | 
7 | 1. Layer Normalization for GRU
8 | 
9 | 2. Layer Normalization for LSTM
10 |    - Normalizing the cell state `c` currently produces a lot of NaNs, so that step is commented out for now.
11 | 
12 | 3. Hyper Networks for LSTM
13 | 
14 | 4. Layer Normalization and Hyper Networks (combined) for LSTM
15 | 
16 | ![model_demo](./assets/model_gru1.png)
17 | 
18 | 
19 | 
20 | 
21 | Prerequisites
22 | -------------
23 | 
24 | - Python 2.7 or Python 3.3+
25 | - [NLTK](http://www.nltk.org/)
26 | - [TensorFlow](https://www.tensorflow.org/) >= 0.9
27 | 
28 | MNIST
29 | -----
30 | To evaluate the new model, we train it on MNIST.
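Layer normalization standardizes each example's pre-activations over the hidden (feature) dimension and then applies a learned gain and bias, so the statistics do not depend on the batch. A minimal NumPy sketch of the operation implemented by `ln()` in `layers.py` (names here are illustrative):

    import numpy as np

    def layer_norm(x, gain, bias, epsilon=1e-5):
        # x: [batch, features]; each row is normalized over its own features
        mean = x.mean(axis=1, keepdims=True)
        var = x.var(axis=1, keepdims=True)
        x_hat = (x - mean) / np.sqrt(var + epsilon)
        return gain * x_hat + bias

    x = np.random.randn(4, 8).astype(np.float32)
    out = layer_norm(x, gain=np.ones(8, np.float32), bias=np.zeros(8, np.float32))
    # every row of `out` now has approximately zero mean and unit variance

In the recurrent cells in `layers.py`, this normalization is applied separately to the input-to-hidden and hidden-to-hidden pre-activations of the gates.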
31 | Here are the model and results using the layer-normalized GRU:
32 | ![histogram](./assets/model_gru3.png)
33 | 
34 | 
35 | ![scalar](./assets/model_gru4.png)
36 | 
37 | 
38 | Usage
39 | -----
40 | 
41 | To train an MNIST model with different cell types:
42 | 
43 |     $ python mnist.py --hidden 128 --summaries_dir log/ --cell_type LNGRU
44 | 
45 | To train an MNIST model with Hyper Networks:
46 | 
47 |     $ python mnist.py --hidden 128 --summaries_dir log/ --cell_type HyperLnLSTMCell --layer_norm 0
48 | 
49 | To train an MNIST model with Hyper Networks and Layer Normalization:
50 | 
51 |     $ python mnist.py --hidden 128 --summaries_dir log/ --cell_type HyperLnLSTMCell --layer_norm 1
52 | 
53 | 
54 | 
55 | cell_type = [LNGRU, LNLSTM, LSTM, GRU, BasicRNN, HyperLnLSTMCell]
56 | 
57 | 
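The cells in `layers.py` can also be used directly in your own graph in place of the standard `rnn_cell` cells. A minimal sketch, following the pre-1.0 TensorFlow API this repository targets (shapes and values below are illustrative):

    import tensorflow as tf
    from tensorflow.python.ops import rnn
    from layers import LNGRUCell

    n_steps, n_input, n_hidden = 28, 28, 128
    x = tf.placeholder(tf.float32, [None, n_steps, n_input])

    # rnn.rnn expects a list of `n_steps` tensors of shape (batch_size, n_input)
    inputs = tf.split(0, n_steps, tf.reshape(tf.transpose(x, [1, 0, 2]), [-1, n_input]))

    cell = LNGRUCell(n_hidden)  # or LNBasicLSTMCell(n_hidden), HyperLnLSTMCell(n_hidden, is_layer_norm=True), ...
    outputs, state = rnn.rnn(cell, inputs, dtype=tf.float32)
    last_output = outputs[-1]   # feed this into a softmax layer, as mnist.py does
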
58 | To view the graph in TensorBoard:
59 | 
60 |     $ tensorboard --logdir log/train/
61 | 
62 | Todo
63 | -----
64 | 1. Add attention-based models (in progress).
65 | 
--------------------------------------------------------------------------------
/mnist.py:
--------------------------------------------------------------------------------
1 | '''
2 | A Recurrent Neural Network (LSTM) implementation example using the TensorFlow library.
3 | This example uses the MNIST database of handwritten digits (http://yann.lecun.com/exdb/mnist/).
4 | Long Short-Term Memory paper: http://deeplearning.cs.cmu.edu/pdfs/Hochreiter97_lstm.pdf
5 | 
6 | Example code is adapted from https://github.com/aymericdamien/TensorFlow-Examples/
7 | Author: Parminder
8 | '''
9 | 
10 | import tensorflow as tf
11 | from tensorflow.python.ops import rnn, rnn_cell
12 | import numpy as np
13 | from layers import *
14 | # Import MNIST data
15 | from tensorflow.examples.tutorials.mnist import input_data
16 | mnist = input_data.read_data_sets("/tmp/data/", one_hot=True)
17 | 
18 | '''
19 | To classify images using a recurrent neural network, we consider every image
20 | row as a sequence of pixels. Because the MNIST image shape is 28*28 px, we will then
21 | handle 28 sequences of 28 steps for every sample.
22 | '''
23 | 
24 | tf.app.flags.DEFINE_float("learning_rate", 0.001, "Learning rate.")
25 | tf.app.flags.DEFINE_integer("iterations", 100000,
26 |                             "Number of training iterations.")
27 | tf.app.flags.DEFINE_integer("batch_size", 128,
28 |                             "Batch size to use during training.")
29 | tf.app.flags.DEFINE_integer("display_step", 10,
30 |                             "How many training steps between progress/summary updates.")
31 | tf.app.flags.DEFINE_integer("hidden", 128,
32 |                             "How many hidden units.")
33 | tf.app.flags.DEFINE_integer("classes", 10,
34 |                             "Number of classes.")
35 | tf.app.flags.DEFINE_integer("layers", 1,
36 |                             "Number of layers for the model.")
37 | tf.app.flags.DEFINE_string("cell_type", "LNGRU", "Select from LSTM, GRU, BasicRNN, LNGRU, LNLSTM, HyperLnLSTMCell.")
38 | tf.app.flags.DEFINE_integer("layer_norm", 0, "Whether HyperLnLSTMCell also applies layer normalization (0 or 1).")
39 | tf.app.flags.DEFINE_string("summaries_dir", "./log/", "Directory for summaries.")
40 | FLAGS = tf.app.flags.FLAGS
41 | 
42 | # Parameters
43 | learning_rate = FLAGS.learning_rate
44 | training_iters = FLAGS.iterations
45 | batch_size = FLAGS.batch_size
46 | display_step = FLAGS.display_step
47 | 
48 | # Network Parameters
49 | n_input = 28  # MNIST data input (img shape: 28*28)
50 | n_steps = 28  # timesteps
51 | n_hidden = FLAGS.hidden  # hidden layer num of features
52 | n_classes = FLAGS.classes  # MNIST total classes (0-9 digits)
53 | 
54 | 
55 | def train():
56 |     sess = tf.InteractiveSession()
57 | 
58 | 
59 |     with tf.name_scope('input'):
60 |         x = tf.placeholder(tf.float32, [None, n_steps, n_input], name='x-input')
61 |         y = tf.placeholder(tf.float32, [None, n_classes], name='y-input')
62 | 
63 |     weights = {
64 |         'out': tf.Variable(tf.random_normal([n_hidden, n_classes]))
65 |     }
66 |     biases = {
67 |         'out': tf.Variable(tf.random_normal([n_classes]))
68 |     }
69 | 
70 | 
71 | 
72 |     def RNN(x, weights, biases, type, layer_norm):
73 | 
74 |         # Prepare data shape to match `rnn` function requirements
75 |         # Current data input shape: (batch_size, n_steps, n_input)
76 |         # Required shape: 'n_steps' tensors list of shape (batch_size, n_input)
77 | 
78 |         # Permuting batch_size and n_steps
79 |         x = tf.transpose(x, [1, 0, 2])
80 |         # Reshaping to (n_steps*batch_size, n_input)
81 |         x = tf.reshape(x, [-1, n_input])
82 |         # Split to get a list of 'n_steps' tensors of shape (batch_size, n_input)
83 |         x = tf.split(0, n_steps, x)
84 | 
85 |         # Define an RNN cell with TensorFlow
86 |         cell_class_map = {
87 |             "LSTM": rnn_cell.BasicLSTMCell(n_hidden),
88 |             "GRU": rnn_cell.GRUCell(n_hidden),
89 |             "BasicRNN": rnn_cell.BasicRNNCell(n_hidden),
90 |             "LNGRU": LNGRUCell(n_hidden),
91 |             "LNLSTM": LNBasicLSTMCell(n_hidden),
92 |             'HyperLnLSTMCell': HyperLnLSTMCell(n_hidden, is_layer_norm=layer_norm)}
93 | 
94 |         lstm_cell = cell_class_map.get(type)
95 |         cell = rnn_cell.MultiRNNCell([lstm_cell] * FLAGS.layers)
96 |         print "Using %s model" % type
97 |         # Get lstm cell output
98 |         outputs, states = rnn.rnn(cell, x, dtype=tf.float32)
99 | 
100 |         # Linear activation, using rnn inner loop last output
101 |         return tf.matmul(outputs[-1], weights['out']) + biases['out']
102 | 
103 | 
104 |     if FLAGS.layer_norm == 1:
105 |         layer_norm = True
106 |     else:
107 |         layer_norm = False
108 |     pred = RNN(x, weights, biases, FLAGS.cell_type, layer_norm)
109 | 
110 |     # Define loss and optimizer
111 |     # print pred
112 |     cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(pred, y))
113 |     optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(cost)
114 | 
115 |     # Evaluate model
116 | 
correct_pred = tf.equal(tf.argmax(pred,1), tf.argmax(y,1)) 117 | accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32)) 118 | tf.scalar_summary('Accuracy', accuracy) 119 | tf.scalar_summary('Cost', cost) 120 | 121 | merged = tf.merge_all_summaries() 122 | train_writer = tf.train.SummaryWriter(FLAGS.summaries_dir + "train/", 123 | sess.graph) 124 | test_writer = tf.train.SummaryWriter(FLAGS.summaries_dir + "test/", 125 | sess.graph) 126 | # Initializing the variables 127 | init = tf.initialize_all_variables() 128 | # print tf.trainable_variables() 129 | for v in tf.trainable_variables(): 130 | print v.name 131 | sess.run(init) 132 | test_len = 128 133 | test_data = mnist.test.images[:test_len].reshape((-1, n_steps, n_input)) 134 | test_label = mnist.test.labels[:test_len] 135 | step = 1 136 | # Keep training until reach max iterations 137 | while step * batch_size < training_iters: 138 | batch_x, batch_y = mnist.train.next_batch(batch_size) 139 | # Reshape data to get 28 seq of 28 elements 140 | batch_x = batch_x.reshape((batch_size, n_steps, n_input)) 141 | # Run optimization op (backprop) 142 | summary, _ = sess.run([merged,optimizer], feed_dict={x: batch_x, y: batch_y}) 143 | # train_writer.add_summary(summary, step) 144 | if step % display_step == 0: 145 | # Calculate batch accuracy 146 | summary, acc, loss = sess.run([merged,accuracy,cost], feed_dict={x: batch_x, y: batch_y}) 147 | train_writer.add_summary(summary, step) 148 | # Calculate batch loss 149 | print "Iter " + str(step*batch_size) + ", Minibatch Loss= " + \ 150 | "{:.6f}".format(loss) + ", Training Accuracy= " + \ 151 | "{:.5f}".format(acc) 152 | 153 | summary, acc, loss = sess.run([merged, accuracy, cost], feed_dict={x: test_data, y: test_label}) 154 | test_writer.add_summary(summary, step) 155 | print "Testing Accuracy:", acc 156 | step += 1 157 | print "Optimization Finished!" 
158 | 
159 |     # Calculate accuracy for 128 MNIST test images
160 | 
161 |     print "Testing Accuracy:", \
162 |         sess.run(accuracy, feed_dict={x: test_data, y: test_label})
163 | 
164 | 
165 | def main(_):
166 |     train()
167 | 
168 | 
169 | if __name__ == '__main__':
170 |     tf.app.run()
171 | 
--------------------------------------------------------------------------------
/layers.py:
--------------------------------------------------------------------------------
1 | import collections
2 | import math
3 | 
4 | from tensorflow.python.ops import variable_scope as vs
5 | 
6 | from tensorflow.python.ops.math_ops import sigmoid
7 | from tensorflow.python.ops.math_ops import tanh
8 | from tensorflow.python.ops import array_ops
9 | 
10 | from tensorflow.python.framework import dtypes
11 | from tensorflow.python.framework import ops
12 | from tensorflow.python.ops import array_ops
13 | from tensorflow.python.ops import control_flow_ops
14 | from tensorflow.python.ops import embedding_ops
15 | from tensorflow.python.ops import math_ops
16 | from tensorflow.python.ops import nn_ops
17 | from tensorflow.python.ops import rnn
18 | from tensorflow.python.ops import rnn_cell
19 | from tensorflow.python.ops import variable_scope
20 | import tensorflow as tf
21 | from tensorflow.python.ops import clip_ops
22 | 
23 | try:
24 |     linear = tf.nn.rnn_cell.linear
25 | except:
26 |     from tensorflow.python.ops.rnn_cell import _linear as linear
27 | 
28 | 
29 | def ln(input, s, b, epsilon=1e-5, max=1000):
30 |     """Layer-normalizes a 2D [batch, features] tensor along its second (feature) axis, then applies the gain `s` and bias `b`."""
31 |     m, v = tf.nn.moments(input, [1], keep_dims=True)
32 |     normalised_input = (input - m) / tf.sqrt(v + epsilon)
33 |     return normalised_input * s + b
34 | 
35 | 
36 | 
37 | class LNGRUCell(rnn_cell.RNNCell):
38 |     """Layer-normalized Gated Recurrent Unit cell (cf. http://arxiv.org/abs/1406.1078)."""
39 | 
40 |     def __init__(self, num_units, input_size=None, activation=tanh):
41 |         if input_size is not None:
42 |             print("%s: The input_size parameter is deprecated." % self)
43 |         self._num_units = num_units
44 |         self._activation = activation
45 | 
46 |     @property
47 |     def state_size(self):
48 |         return self._num_units
49 | 
50 |     @property
51 |     def output_size(self):
52 |         return self._num_units
53 | 
54 |     def __call__(self, inputs, state, scope=None):
55 |         """Gated recurrent unit (GRU) with num_units cells."""
56 |         dim = self._num_units
57 |         with vs.variable_scope(scope or type(self).__name__):  # "GRUCell"
58 |             with vs.variable_scope("Gates"):  # Reset gate and update gate.
59 |                 # We start with bias of 1.0 to not reset and not update.
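                # The scale (s*) and shift (b*) variables below are the layer-norm
                # gain and bias parameters: s1/b1 and s2/b2 normalize the 2*dim gate
                # pre-activations computed from the input and the previous state,
                # while s3/b3 and s4/b4 play the same role for the dim-sized candidate
                # pre-activations in the "Candidate" scope further down.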
60 | with vs.variable_scope( "Layer_Parameters"): 61 | 62 | s1 = vs.get_variable("s1", initializer=tf.ones([2*dim]), dtype=tf.float32) 63 | s2 = vs.get_variable("s2", initializer=tf.ones([2*dim]), dtype=tf.float32) 64 | s3 = vs.get_variable("s3", initializer=tf.ones([dim]), dtype=tf.float32) 65 | s4 = vs.get_variable("s4", initializer=tf.ones([dim]), dtype=tf.float32) 66 | b1 = vs.get_variable("b1", initializer=tf.zeros([2*dim]), dtype=tf.float32) 67 | b2 = vs.get_variable("b2", initializer=tf.zeros([2*dim]), dtype=tf.float32) 68 | b3 = vs.get_variable("b3", initializer=tf.zeros([dim]), dtype=tf.float32) 69 | b4 = vs.get_variable("b4", initializer=tf.zeros([dim]), dtype=tf.float32) 70 | 71 | 72 | # Code below initialized for all cells 73 | # s1 = tf.Variable(tf.ones([2 * dim]), name="s1") 74 | # s2 = tf.Variable(tf.ones([2 * dim]), name="s2") 75 | # s3 = tf.Variable(tf.ones([dim]), name="s3") 76 | # s4 = tf.Variable(tf.ones([dim]), name="s4") 77 | # b1 = tf.Variable(tf.zeros([2 * dim]), name="b1") 78 | # b2 = tf.Variable(tf.zeros([2 * dim]), name="b2") 79 | # b3 = tf.Variable(tf.zeros([dim]), name="b3") 80 | # b4 = tf.Variable(tf.zeros([dim]), name="b4") 81 | 82 | input_below_ = rnn_cell._linear([inputs], 83 | 2 * self._num_units, False, scope="out_1") 84 | input_below_ = ln(input_below_, s1, b1) 85 | state_below_ = rnn_cell._linear([state], 86 | 2 * self._num_units, False, scope="out_2") 87 | state_below_ = ln(state_below_, s2, b2) 88 | out =tf.add(input_below_, state_below_) 89 | r, u = array_ops.split(1, 2, out) 90 | r, u = sigmoid(r), sigmoid(u) 91 | 92 | with vs.variable_scope("Candidate"): 93 | input_below_x = rnn_cell._linear([inputs], 94 | self._num_units, False, scope="out_3") 95 | input_below_x = ln(input_below_x, s3, b3) 96 | state_below_x = rnn_cell._linear([state], 97 | self._num_units, False, scope="out_4") 98 | state_below_x = ln(state_below_x, s4, b4) 99 | c_pre = tf.add(input_below_x,r * state_below_x) 100 | c = self._activation(c_pre) 101 | new_h = u * state + (1 - u) * c 102 | return new_h, new_h 103 | 104 | _LNLSTMStateTuple = collections.namedtuple("LNLSTMStateTuple", ("c", "h")) 105 | 106 | 107 | class LSTMStateTuple(_LNLSTMStateTuple): 108 | """Tuple used by LSTM Cells for `state_size`, `zero_state`, and output state. 109 | Stores two elements: `(c, h)`, in that order. 110 | Only used when `state_is_tuple=True`. 111 | """ 112 | __slots__ = () 113 | 114 | class LNBasicLSTMCell(rnn_cell.RNNCell): 115 | """Basic LSTM recurrent network cell. 116 | The implementation is based on: http://arxiv.org/abs/1409.2329. 117 | We add forget_bias (default: 1) to the biases of the forget gate in order to 118 | reduce the scale of forgetting in the beginning of the training. 119 | It does not allow cell clipping, a projection layer, and does not 120 | use peep-hole connections: it is the basic baseline. 121 | For advanced models, please use the full LSTMCell that follows. 122 | """ 123 | 124 | def __init__(self, num_units, forget_bias=1.0, input_size=None, 125 | state_is_tuple=False, activation=tanh): 126 | """Initialize the basic LSTM cell. 127 | Args: 128 | num_units: int, The number of units in the LSTM cell. 129 | forget_bias: float, The bias added to forget gates (see above). 130 | input_size: Deprecated and unused. 131 | state_is_tuple: If True, accepted and returned states are 2-tuples of 132 | the `c_state` and `m_state`. By default (False), they are concatenated 133 | along the column axis. This default behavior will soon be deprecated. 
134 | activation: Activation function of the inner states. 135 | """ 136 | if not state_is_tuple: 137 | print("%s: Using a concatenated state is slower and will soon be " 138 | "deprecated. Use state_is_tuple=True.", self) 139 | if input_size is not None: 140 | print("%s: The input_size parameter is deprecated.", self) 141 | self._num_units = num_units 142 | self._forget_bias = forget_bias 143 | self._state_is_tuple = state_is_tuple 144 | self._activation = activation 145 | 146 | @property 147 | def state_size(self): 148 | return (LSTMStateTuple(self._num_units, self._num_units) 149 | if self._state_is_tuple else 2 * self._num_units) 150 | 151 | @property 152 | def output_size(self): 153 | return self._num_units 154 | 155 | def __call__(self, inputs, state, scope=None): 156 | """Long short-term memory cell (LSTM).""" 157 | with vs.variable_scope(scope or type(self).__name__): # "BasicLSTMCell" 158 | # Parameters of gates are concatenated into one multiply for efficiency. 159 | if self._state_is_tuple: 160 | c, h = state 161 | else: 162 | c, h = array_ops.split(1, 2, state) 163 | 164 | s1 = vs.get_variable("s1", initializer=tf.ones([4 * self._num_units]), dtype=tf.float32) 165 | s2 = vs.get_variable("s2", initializer=tf.ones([4 * self._num_units]), dtype=tf.float32) 166 | s3 = vs.get_variable("s3", initializer=tf.ones([self._num_units]), dtype=tf.float32) 167 | 168 | b1 = vs.get_variable("b1", initializer=tf.zeros([4 * self._num_units]), dtype=tf.float32) 169 | b2 = vs.get_variable("b2", initializer=tf.zeros([4 * self._num_units]), dtype=tf.float32) 170 | b3 = vs.get_variable("b3", initializer=tf.zeros([self._num_units]), dtype=tf.float32) 171 | 172 | # s1 = tf.Variable(tf.ones([4 * self._num_units]), name="s1") 173 | # s2 = tf.Variable(tf.ones([4 * self._num_units]), name="s2") 174 | # s3 = tf.Variable(tf.ones([self._num_units]), name="s3") 175 | # 176 | # b1 = tf.Variable(tf.zeros([4 * self._num_units]), name="b1") 177 | # b2 = tf.Variable(tf.zeros([4 * self._num_units]), name="b2") 178 | # b3 = tf.Variable(tf.zeros([self._num_units]), name="b3") 179 | 180 | input_below_ = rnn_cell._linear([inputs], 181 | 4 * self._num_units, False, scope="out_1") 182 | input_below_ = ln(input_below_, s1, b1) 183 | state_below_ = rnn_cell._linear([h], 184 | 4 * self._num_units, False, scope="out_2") 185 | state_below_ = ln(state_below_, s2, b2) 186 | lstm_matrix = tf.add(input_below_, state_below_) 187 | 188 | i, j, f, o = array_ops.split(1, 4, lstm_matrix) 189 | 190 | new_c = (c * sigmoid(f) + sigmoid(i) * 191 | self._activation(j)) 192 | 193 | # Currently normalizing c causes lot of nan's in the model, thus commenting it out for now. 194 | # new_c_ = ln(new_c, s3, b3) 195 | new_c_ = new_c 196 | new_h = self._activation(new_c_) * sigmoid(o) 197 | 198 | if self._state_is_tuple: 199 | new_state = LSTMStateTuple(new_c, new_h) 200 | else: 201 | new_state = array_ops.concat(1, [new_c, new_h]) 202 | return new_h, new_state 203 | 204 | 205 | class HyperLnLSTMCell(rnn_cell.RNNCell): 206 | """Basic LSTM recurrent network cell. 207 | The implementation is based on: http://arxiv.org/abs/1409.2329. 208 | We add forget_bias (default: 1) to the biases of the forget gate in order to 209 | reduce the scale of forgetting in the beginning of the training. 210 | It does not allow cell clipping, a projection layer, and does not 211 | use peep-hole connections: it is the basic baseline. 212 | For advanced models, please use the full LSTMCell that follows. 
213 | """ 214 | 215 | def __init__(self, num_units, forget_bias=1.0, input_size=None, 216 | state_is_tuple=False, activation=tanh, hyper_num_units=128, hyper_embedding_size=32, is_layer_norm = True): 217 | """Initialize the basic LSTM cell. 218 | Args: 219 | num_units: int, The number of units in the LSTM cell. 220 | hyper_num_units: int, The number of units in the HyperLSTM cell. 221 | forget_bias: float, The bias added to forget gates (see above). 222 | input_size: Deprecated and unused. 223 | state_is_tuple: If True, accepted and returned states are 2-tuples of 224 | the `c_state` and `m_state`. By default (False), they are concatenated 225 | along the column axis. This default behavior will soon be deprecated. 226 | activation: Activation function of the inner states. 227 | """ 228 | if not state_is_tuple: 229 | print("%s: Using a concatenated state is slower and will soon be " 230 | "deprecated. Use state_is_tuple=True.", self) 231 | if input_size is not None: 232 | print("%s: The input_size parameter is deprecated.", self) 233 | self._num_units = num_units 234 | self._forget_bias = forget_bias 235 | self._state_is_tuple = state_is_tuple 236 | self._activation = activation 237 | self.hyper_num_units = hyper_num_units 238 | self.total_num_units = self._num_units + self.hyper_num_units 239 | self.hyper_cell = rnn_cell.BasicLSTMCell(hyper_num_units) 240 | self.hyper_embedding_size= hyper_embedding_size 241 | self.is_layer_norm = is_layer_norm 242 | 243 | @property 244 | def state_size(self): 245 | return 2 * self.total_num_units 246 | # return (LSTMStateTuple(self._num_units, self._num_units) 247 | # if self._state_is_tuple else 2 * self._num_units) 248 | 249 | @property 250 | def output_size(self): 251 | return self._num_units 252 | 253 | def hyper_norm(self, layer, dimensions, scope="hyper"): 254 | with tf.variable_scope(scope): 255 | zw = rnn_cell._linear(self.hyper_output, 256 | self.hyper_embedding_size, False, scope=scope+ "z") 257 | alpha = rnn_cell._linear(zw, dimensions, False, scope=scope+ "alpha") 258 | result = tf.mul(alpha, layer) 259 | 260 | return result 261 | 262 | def __call__(self, inputs, state, scope=None): 263 | """Long short-term memory cell (LSTM) with hypernetworks and layer normalization.""" 264 | with vs.variable_scope(scope or type(self).__name__): 265 | # Parameters of gates are concatenated into one multiply for efficiency. 
266 | total_h, total_c = tf.split(1, 2, state) 267 | h = total_h[:, 0:self._num_units] 268 | c = total_c[:, 0:self._num_units] 269 | 270 | self.hyper_state = tf.concat(1, [total_h[:, self._num_units:], total_c[:, self._num_units:]]) 271 | hyper_input = tf.concat(1, [inputs, h]) 272 | hyper_output, hyper_new_state = self.hyper_cell(hyper_input, self.hyper_state) 273 | self.hyper_output = hyper_output 274 | self.hyper_state = hyper_new_state 275 | 276 | input_below_ = rnn_cell._linear([inputs], 277 | 4 * self._num_units, False, scope="out_1") 278 | input_below_ = self.hyper_norm(input_below_, 4 * self._num_units, scope="hyper_x") 279 | state_below_ = rnn_cell._linear([h], 280 | 4 * self._num_units, False, scope="out_2") 281 | state_below_ = self.hyper_norm(state_below_, 4 * self._num_units, scope="hyper_h") 282 | 283 | if self.is_layer_norm: 284 | s1 = vs.get_variable("s1", initializer=tf.ones([4 * self._num_units]), dtype=tf.float32) 285 | s2 = vs.get_variable("s2", initializer=tf.ones([4 * self._num_units]), dtype=tf.float32) 286 | s3 = vs.get_variable("s3", initializer=tf.ones([self._num_units]), dtype=tf.float32) 287 | 288 | b1 = vs.get_variable("b1", initializer=tf.zeros([4 * self._num_units]), dtype=tf.float32) 289 | b2 = vs.get_variable("b2", initializer=tf.zeros([4 * self._num_units]), dtype=tf.float32) 290 | b3 = vs.get_variable("b3", initializer=tf.zeros([self._num_units]), dtype=tf.float32) 291 | 292 | 293 | input_below_ = ln(input_below_, s1, b1) 294 | 295 | 296 | state_below_ = ln(state_below_, s2, b2) 297 | 298 | lstm_matrix = tf.add(input_below_, state_below_) 299 | i, j, f, o = array_ops.split(1, 4, lstm_matrix) 300 | new_c = (c * sigmoid(f) + sigmoid(i) * 301 | self._activation(j)) 302 | 303 | # Currently normalizing c causes lot of nan's in the model, thus commenting it out for now. 304 | # new_c_ = ln(new_c, s3, b3) 305 | new_c_ = new_c 306 | new_h = self._activation(new_c_) * sigmoid(o) 307 | 308 | hyper_h, hyper_c = tf.split(1, 2, hyper_new_state) 309 | new_total_h = tf.concat(1, [new_h, hyper_h]) 310 | new_total_c = tf.concat(1, [new_c, hyper_c]) 311 | new_total_state = tf.concat(1, [new_total_h, new_total_c]) 312 | return new_h, new_total_state 313 | 314 | class LNLSTMCell(rnn_cell.RNNCell): 315 | """Long short-term memory unit (LSTM) recurrent network cell. 316 | The default non-peephole implementation is based on: 317 | http://deeplearning.cs.cmu.edu/pdfs/Hochreiter97_lstm.pdf 318 | S. Hochreiter and J. Schmidhuber. 319 | "Long Short-Term Memory". Neural Computation, 9(8):1735-1780, 1997. 320 | The peephole implementation is based on: 321 | https://research.google.com/pubs/archive/43905.pdf 322 | Hasim Sak, Andrew Senior, and Francoise Beaufays. 323 | "Long short-term memory recurrent neural network architectures for 324 | large scale acoustic modeling." INTERSPEECH, 2014. 325 | The class uses optional peep-hole connections, optional cell clipping, and 326 | an optional projection layer. 327 | """ 328 | 329 | def __init__(self, num_units, input_size=None, 330 | initializer=None, num_proj=None, 331 | state_is_tuple=False, activation=tanh): 332 | """Initialize the parameters for an LSTM cell. 333 | Args: 334 | num_units: int, The number of units in the LSTM cell 335 | input_size: Deprecated and unused. 336 | use_peepholes: bool, set True to enable diagonal/peephole connections. 337 | cell_clip: (optional) A float value, if provided the cell state is clipped 338 | by this value prior to the cell output activation. 
339 | initializer: (optional) The initializer to use for the weight and 340 | projection matrices. 341 | num_proj: (optional) int, The output dimensionality for the projection 342 | matrices. If None, no projection is performed. 343 | proj_clip: (optional) A float value. If `num_proj > 0` and `proj_clip` is 344 | provided, then the projected values are clipped elementwise to within 345 | `[-proj_clip, proj_clip]`. 346 | num_unit_shards: How to split the weight matrix. If >1, the weight 347 | matrix is stored across num_unit_shards. 348 | num_proj_shards: How to split the projection matrix. If >1, the 349 | projection matrix is stored across num_proj_shards. 350 | forget_bias: Biases of the forget gate are initialized by default to 1 351 | in order to reduce the scale of forgetting at the beginning of 352 | the training. 353 | state_is_tuple: If True, accepted and returned states are 2-tuples of 354 | the `c_state` and `m_state`. By default (False), they are concatenated 355 | along the column axis. This default behavior will soon be deprecated. 356 | activation: Activation function of the inner states. 357 | """ 358 | if not state_is_tuple: 359 | print( 360 | "%s: Using a concatenated state is slower and will soon be " 361 | "deprecated. Use state_is_tuple=True." % self) 362 | if input_size is not None: 363 | print("%s: The input_size parameter is deprecated." % self) 364 | self._num_units = num_units 365 | self._initializer = initializer 366 | self._num_proj = num_proj 367 | self._state_is_tuple = state_is_tuple 368 | self._activation = activation 369 | 370 | if num_proj: 371 | self._state_size = ( 372 | LSTMStateTuple(num_units, num_proj) 373 | if state_is_tuple else num_units + num_proj) 374 | self._output_size = num_proj 375 | else: 376 | self._state_size = ( 377 | LSTMStateTuple(num_units, num_units) 378 | if state_is_tuple else 2 * num_units) 379 | self._output_size = num_units 380 | 381 | @property 382 | def state_size(self): 383 | return self._state_size 384 | 385 | @property 386 | def output_size(self): 387 | return self._output_size 388 | 389 | def __call__(self, inputs, state, scope=None): 390 | """Run one step of LSTM. 391 | Args: 392 | inputs: input Tensor, 2D, batch x num_units. 393 | state: if `state_is_tuple` is False, this must be a state Tensor, 394 | `2-D, batch x state_size`. If `state_is_tuple` is True, this must be a 395 | tuple of state Tensors, both `2-D`, with column sizes `c_state` and 396 | `m_state`. 397 | scope: VariableScope for the created subgraph; defaults to "LSTMCell". 398 | Returns: 399 | A tuple containing: 400 | - A `2-D, [batch x output_dim]`, Tensor representing the output of the 401 | LSTM after reading `inputs` when previous state was `state`. 402 | Here output_dim is: 403 | num_proj if num_proj was set, 404 | num_units otherwise. 405 | - Tensor(s) representing the new state of LSTM after reading `inputs` when 406 | the previous state was `state`. Same type and shape(s) as `state`. 407 | Raises: 408 | ValueError: If input size cannot be inferred from inputs via 409 | static shape inference. 
410 | """ 411 | num_proj = self._num_units if self._num_proj is None else self._num_proj 412 | 413 | if self._state_is_tuple: 414 | (c_prev, m_prev) = state 415 | else: 416 | c_prev = array_ops.slice(state, [0, 0], [-1, self._num_units]) 417 | m_prev = array_ops.slice(state, [0, self._num_units], [-1, num_proj]) 418 | 419 | input_size = inputs.get_shape().with_rank(2)[1] 420 | if input_size.value is None: 421 | raise ValueError("Could not infer input size from inputs.get_shape()[-1]") 422 | with vs.variable_scope(scope or type(self).__name__, 423 | initializer=self._initializer): # "LSTMCell" 424 | 425 | s1 = vs.get_variable("s1", initializer=tf.ones([4 * self._num_units]), dtype=tf.float32) 426 | s2 = vs.get_variable("s2", initializer=tf.ones([4 * self._num_units]), dtype=tf.float32) 427 | s3 = vs.get_variable("s3", initializer=tf.ones([self._num_units]), dtype=tf.float32) 428 | 429 | b1 = vs.get_variable("b1", initializer=tf.zeros([4 * self._num_units]), dtype=tf.float32) 430 | b2 = vs.get_variable("b2", initializer=tf.zeros([4 * self._num_units]), dtype=tf.float32) 431 | b3 = vs.get_variable("b3", initializer=tf.zeros([self._num_units]), dtype=tf.float32) 432 | 433 | # s1 = tf.Variable(tf.ones([4 * self._num_units]), name="s1") 434 | # s2 = tf.Variable(tf.ones([4 * self._num_units]), name="s2") 435 | # s3 = tf.Variable(tf.ones([self._num_units]), name="s3") 436 | # 437 | # b1 = tf.Variable(tf.zeros([4 * self._num_units]), name="b1") 438 | # b2 = tf.Variable(tf.zeros([4 * self._num_units]), name="b2") 439 | # b3 = tf.Variable(tf.zeros([self._num_units]), name="b3") 440 | 441 | input_below_ = rnn_cell._linear([inputs], 442 | 4 * self._num_units, False, scope="out_1") 443 | input_below_ = ln(input_below_, s1, b1) 444 | state_below_ = rnn_cell._linear([m_prev], 445 | 4 * self._num_units, False, scope="out_2") 446 | state_below_ = ln(state_below_, s2, b2) 447 | lstm_matrix = tf.add(input_below_, state_below_) 448 | 449 | i, j, f, o = array_ops.split(1, 4, lstm_matrix) 450 | 451 | c = (sigmoid(f) * c_prev + sigmoid(i) * 452 | self._activation(j)) 453 | 454 | # Currently normalizing c causes lot of nan's in the model, thus commenting it out for now. 455 | # c_ = ln(c, s3, b3) 456 | c_ = c 457 | m = sigmoid(o) * self._activation(c_) 458 | 459 | new_state = (LSTMStateTuple(c, m) if self._state_is_tuple 460 | else array_ops.concat(1, [c, m])) 461 | return m, new_state --------------------------------------------------------------------------------