├── .editorconfig
├── .gitignore
├── .pylintrc
├── README.md
├── dataplumbing.py
├── dataset
│   └── input_data.py
├── ran_cell.py
├── ran_cell_v2.py
├── train.py
├── train_accuracy.png
├── train_cost.png
└── utils.py

/.editorconfig:
--------------------------------------------------------------------------------
1 | # EditorConfig helps developers define and maintain consistent
2 | # coding styles between different editors and IDEs
3 | # editorconfig.org
4 |
5 | root = true
6 |
7 |
8 | [*]
9 |
10 | # Change these settings to your own preference
11 | indent_style = space
12 | indent_size = 2
13 |
14 | # We recommend you keep these unchanged
15 | end_of_line = lf
16 | charset = utf-8
17 | trim_trailing_whitespace = true
18 | insert_final_newline = true
19 |
20 | [*.md]
21 | trim_trailing_whitespace = false
22 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | *.pyc
2 | .DS_Store
3 | checkpoint
4 | data
5 | !data/wiki/vocab.pkl
6 | log
7 | bin
--------------------------------------------------------------------------------
/.pylintrc:
--------------------------------------------------------------------------------
1 | # My pylintrc for use with atom.io's linter-pylint
2 | [MESSAGES CONTROL]
3 | disable=W0311,W1201,W0702,W0611,W0621,E1101,C0111,C0103,R0902
4 |
5 | # checks for:
6 | # * unauthorized constructions
7 | # * strict indentation
8 | # * line length
9 | # * use of <> instead of !=
10 | #
11 | [FORMAT]
12 | # Maximum number of characters on a single line.
13 | max-line-length=128
14 | # Maximum number of lines in a module
15 | max-module-lines=1000
16 | # String used as indentation unit. This is usually "    " (4 spaces) or "\t" (1
17 | # tab). In this repo it is 2 spaces.
18 | indent-string='  '
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # [Recurrent Additive Networks (RAN)](https://arxiv.org/abs/1705.07393)
2 | This is an implementation of Recurrent Additive Networks that extends TensorFlow's `RNNCell`.
3 |
4 | ### Requirements
5 | * tensorflow r1.1
6 |
7 | ### Accuracy
8 | ![Accuracy](/train_accuracy.png)
9 |
10 | ### Cost
11 | ![Cost](/train_cost.png)
--------------------------------------------------------------------------------
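For reference, the recurrence the cells below implement, as defined in the paper: the content layer is a linear map of the input only, the two gates are standard sigmoid gates, the cell state update is purely additive, and the output applies a nonlinearity g (tanh by default) to the new state. A minimal NumPy sketch of one step follows; the weight names (W_cx, W_ih, and so on) are illustrative labels taken from the paper's notation, not identifiers from this repo.

import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def ran_step(x, h_prev, c_prev, W_cx, W_ih, W_ix, b_i, W_fh, W_fx, b_f):
    c_tilde = W_cx @ x                            # content layer: linear in the input only
    i = sigmoid(W_ih @ h_prev + W_ix @ x + b_i)   # input gate
    f = sigmoid(W_fh @ h_prev + W_fx @ x + b_f)   # forget gate
    c = i * c_tilde + f * c_prev                  # purely additive state update
    h = np.tanh(c)                                # output h_t = g(c_t), with g = tanh here
    return h, c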
/dataplumbing.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | ##########################################################################################
3 | # Author: Jared L. Ostmeyer
4 | # Date Started: 2017-01-01
5 | # Purpose: Load dataset and create interfaces for piping the data to the model
6 | ##########################################################################################
7 |
8 | ##########################################################################################
9 | # Libraries
10 | ##########################################################################################
11 |
12 | import numpy as np
13 |
14 | ##########################################################################################
15 | # Class definitions
16 | ##########################################################################################
17 |
18 | # Defines interface between the data and model
19 | #
20 | class Dataset:
21 |   def __init__(self, xs, ls, ys):
22 |     self.xs = xs  # Store the features
23 |     self.ls = ls  # Store the length of each sequence
24 |     self.ys = ys  # Store the labels
25 |     self.num_samples = len(ys)
26 |     self.num_features = len(xs[0,0,:])
27 |     self.max_length = len(xs[0,:,0])
28 |     self.num_classes = 1
29 |   def batch(self, batch_size):
30 |     js = np.random.randint(0, self.num_samples, batch_size)
31 |     return self.xs[js,:,:], self.ls[js], self.ys[js]
32 |
33 | ##########################################################################################
34 | # Import dataset
35 | ##########################################################################################
36 |
37 | # Load data
38 | #
39 | import sys
40 | sys.path.append('./dataset')
41 | import input_data
42 |
43 | # Create split of data
44 | #
45 | train = Dataset(input_data.xs_train, input_data.ls_train, input_data.ys_train)
46 | test = Dataset(input_data.xs_test, input_data.ls_test, input_data.ys_test)
47 |
--------------------------------------------------------------------------------
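To make the interface concrete, here is a minimal sketch of how train.py consumes this module; the shapes follow from the settings in dataset/input_data.py, and the prints are illustrative only.

# Assumes it is run from the repo root; importing dataplumbing loads (or
# generates) the Reber-grammar dataset as a side effect.
import dataplumbing as dp

xs, ls, ys = dp.train.batch(100)  # a random minibatch, sampled with replacement
print(xs.shape)                   # (100, 50, 7): batch x max_length x num_features (one-hot symbols)
print(ls.shape, ys.shape)         # (100,) sequence lengths and (100,) binary validity labels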
/dataset/input_data.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | ##########################################################################################
3 | # Author: Jared L. Ostmeyer
4 | # Date Started: 2017-01-01
5 | # Purpose: Load dataset or generate it if it does not exist yet.
6 | # License: For legal information see LICENSE in the home directory.
7 | ##########################################################################################
8 |
9 | ##########################################################################################
10 | # Libraries
11 | ##########################################################################################
12 |
13 | import os
14 | import numpy as np
15 |
16 | ##########################################################################################
17 | # Settings
18 | ##########################################################################################
19 |
20 | # Reber grammar
21 | #
22 | states = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]
23 | transitions = {
24 |   1: [2, 7],
25 |   2: [3, 4],
26 |   3: [3, 4],
27 |   4: [5, 6],
28 |   5: [12],
29 |   6: [8, 9],
30 |   7: [8, 9],
31 |   8: [8, 9],
32 |   9: [10, 11],
33 |   10: [5, 6],
34 |   11: [12],
35 | }
36 | aliases = {
37 |   1: 'B', 2: 'T', 3: 'S', 4: 'X', 5: 'S', 6: 'X',
38 |   7: 'P', 8: 'T', 9: 'V', 10: 'P', 11: 'V', 12: 'E',
39 | }
40 | encoding = {'B': 0, 'E': 1, 'P': 2, 'S': 3, 'T': 4, 'V': 5, 'X': 6}
41 |
42 | # Data dimensions
43 | #
44 | num_train = 10000
45 | num_test = 10000
46 | max_length = 50  # Loops in the grammar make longer strings possible, but they are vanishingly rare
47 | num_features = 7
48 |
49 | ##########################################################################################
50 | # Utilities
51 | ##########################################################################################
52 |
53 | def make_chain():
54 |   chain = [1]
55 |   while chain[-1] != states[-1]:
56 |     choices = transitions[chain[-1]]
57 |     j = np.random.randint(len(choices))
58 |     chain.append(choices[j])
59 |   return chain
60 |
61 | def valid_chain(chain):
62 |   if len(chain) == 0:
63 |     return False
64 |   if chain[0] != states[0]:
65 |     return False
66 |   for i in range(1, len(chain)):
67 |     if chain[i] not in transitions[chain[i-1]]:
68 |       return False
69 |   return True
70 |
71 | def convert_chain(chain):
72 |   sequence = ''
73 |   for value in chain:
74 |     sequence += aliases[value]
75 |   return sequence
76 |
77 | ##########################################################################################
78 | # Generate/Load dataset
79 | ##########################################################################################
80 |
81 | # Make directory
82 | #
83 | _path = '/'.join(__file__.split('/')[:-1])
84 | os.makedirs(_path+'/bin', exist_ok=True)
85 |
86 | # Training data
87 | #
88 | if not os.path.isfile(_path+'/bin/xs_train.npy') or \
89 |    not os.path.isfile(_path+'/bin/ls_train.npy') or \
90 |    not os.path.isfile(_path+'/bin/ys_train.npy'):
91 |   xs_train = np.zeros((num_train, max_length, num_features))
92 |   ls_train = np.zeros(num_train)
93 |   ys_train = np.zeros(num_train)
94 |   for i in range(num_train):
95 |     chain = make_chain()
96 |     valid = 1.0
97 |     if np.random.rand() >= 0.5:  # Randomly corrupt the sequence with probability 0.5
98 |       hybrid = chain
99 |       while valid_chain(hybrid):  # Splice two chains together until the result is invalid
100 |         chain_ = make_chain()
101 |         j = np.random.randint(len(chain))
102 |         j_ = np.random.randint(len(chain_))
103 |         hybrid = chain[:j]+chain_[j_:]
104 |       chain = hybrid
105 |       valid = 0.0
106 |     sequence = convert_chain(chain)
107 |     for j, symbol in enumerate(sequence):
108 |       k = encoding[symbol]
109 |       xs_train[i,j,k] = 1.0
110 |     ls_train[i] = len(sequence)
111 |     ys_train[i] = valid
112 |   np.save(_path+'/bin/xs_train.npy', xs_train)
113 |   np.save(_path+'/bin/ls_train.npy', ls_train)
114 |   np.save(_path+'/bin/ys_train.npy', ys_train)
115 | else:
116 |   xs_train = np.load(_path+'/bin/xs_train.npy')
117 |   ls_train = np.load(_path+'/bin/ls_train.npy')
118 |   ys_train = np.load(_path+'/bin/ys_train.npy')
119 |
120 | # Test data
121 | #
122 | if not os.path.isfile(_path+'/bin/xs_test.npy') or \
123 |    not os.path.isfile(_path+'/bin/ls_test.npy') or \
124 |    not os.path.isfile(_path+'/bin/ys_test.npy'):
125 |   xs_test = np.zeros((num_test, max_length, num_features))
126 |   ls_test = np.zeros(num_test)
127 |   ys_test = np.zeros(num_test)
128 |   for i in range(num_test):
129 |     chain = make_chain()
130 |     valid = 1.0
131 |     if np.random.rand() >= 0.5:  # Randomly corrupt the sequence with probability 0.5
132 |       hybrid = chain
133 |       while valid_chain(hybrid):
134 |         chain_ = make_chain()
135 |         j = np.random.randint(len(chain))
136 |         j_ = np.random.randint(len(chain_))  # Index into the second chain (mirrors the training block)
137 |         hybrid = chain[:j]+chain_[j_:]
138 |       chain = hybrid
139 |       valid = 0.0
140 |     sequence = convert_chain(chain)
141 |     for j, symbol in enumerate(sequence):
142 |       k = encoding[symbol]
143 |       xs_test[i,j,k] = 1.0
144 |     ls_test[i] = len(sequence)
145 |     ys_test[i] = valid
146 |   np.save(_path+'/bin/xs_test.npy', xs_test)
147 |   np.save(_path+'/bin/ls_test.npy', ls_test)
148 |   np.save(_path+'/bin/ys_test.npy', ys_test)
149 | else:
150 |   xs_test = np.load(_path+'/bin/xs_test.npy')
151 |   ls_test = np.load(_path+'/bin/ls_test.npy')
152 |   ys_test = np.load(_path+'/bin/ys_test.npy')
153 |
--------------------------------------------------------------------------------
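To make the grammar above concrete, here is a short worked example. The chain 1 -> 2 -> 4 -> 6 -> 9 -> 11 -> 12 follows the transition table at every step and, mapped through the aliases table, spells the string "BTXXVVE". A small sketch checking this against the module (the import mirrors the sys.path trick dataplumbing.py uses, and generating the dataset on import is a side effect of this module's design):

import sys
sys.path.append('./dataset')          # same import path trick dataplumbing.py relies on
import input_data as g                # note: importing generates/loads the dataset as a side effect

chain = [1, 2, 4, 6, 9, 11, 12]
print(g.valid_chain(chain))           # True: every step is listed in `transitions`
print(g.convert_chain(chain))         # "BTXXVVE"
print(g.valid_chain([1, 2, 6, 12]))   # False: 2 -> 6 is not a legal transition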
/ran_cell.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | from tensorflow.python.ops.rnn_cell_impl import _RNNCell as RNNCell
3 | from tensorflow.contrib.rnn.python.ops import core_rnn_cell_impl
4 | from tensorflow.python.ops import array_ops
5 | from tensorflow.python.ops.math_ops import tanh
6 | from tensorflow.python.ops import variable_scope as vs
7 | from tensorflow.python.platform import tf_logging as logging
8 | from utils import linear
9 |
10 | _checked_scope = core_rnn_cell_impl._checked_scope
11 |
12 | class RANCell(RNNCell):
13 |   """Recurrent Additive Network cell (cf. https://arxiv.org/abs/1705.07393)."""
14 |
15 |   def __init__(self, num_units, input_size=None, activation=tanh, normalize=False, reuse=None):
16 |     if input_size is not None:
17 |       logging.warn("%s: The input_size parameter is deprecated.", self)
18 |     self._num_units = num_units
19 |     self._activation = activation
20 |     self._normalize = normalize
21 |     self._reuse = reuse
22 |
23 |   @property
24 |   def state_size(self):
25 |     return self._num_units
26 |
27 |   @property
28 |   def output_size(self):
29 |     return self._num_units
30 |
31 |   def __call__(self, inputs, state, scope=None):
32 |     with _checked_scope(self, scope or "ran_cell", reuse=self._reuse):
33 |       with vs.variable_scope("gates"):
34 |         value = tf.nn.sigmoid(linear([state, inputs], 2 * self._num_units, True, normalize=self._normalize))
35 |         i, f = array_ops.split(value=value, num_or_size_splits=2, axis=1)
36 |
37 |       with vs.variable_scope("candidate"):
38 |         c = linear([inputs], self._num_units, True, normalize=self._normalize)  # Content layer: linear in the input
39 |
40 |       new_c = i * c + f * state        # Additive update: c_t = i_t * c~_t + f_t * c_{t-1}
41 |       new_h = self._activation(new_c)  # Output h_t = g(c_t)
42 |
43 |       return new_h, new_c
--------------------------------------------------------------------------------
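A minimal sketch of dropping this cell into a graph; the dimensions are illustrative, and tf.nn.dynamic_rnn is the same driver train.py uses.

import tensorflow as tf
from ran_cell import RANCell

x = tf.placeholder(tf.float32, [32, 50, 7])  # batch x time x features
cell = RANCell(250)                          # pass normalize=True for the layer-norm variant
outputs, final_c = tf.nn.dynamic_rnn(cell, x, dtype=tf.float32)
# outputs: [32, 50, 250] per-step h_t; final_c: [32, 250] last cell state c_T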
/ran_cell_v2.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | from tensorflow.python.ops.rnn_cell_impl import _RNNCell as RNNCell
3 | from tensorflow.contrib.rnn.python.ops import core_rnn_cell_impl
4 | from tensorflow.python.ops import array_ops
5 | from tensorflow.python.ops.math_ops import tanh
6 | from tensorflow.python.ops import variable_scope as vs
7 | from tensorflow.python.platform import tf_logging as logging
8 | from utils import linear
9 |
10 | _checked_scope = core_rnn_cell_impl._checked_scope
11 |
12 | class RANCellv2(RNNCell):
13 |   """Recurrent Additive Network cell with a (c, h) state tuple (cf. https://arxiv.org/abs/1705.07393)."""
14 |
15 |   def __init__(self, num_units, input_size=None, activation=tanh, normalize=False, reuse=None):
16 |     if input_size is not None:
17 |       logging.warn("%s: The input_size parameter is deprecated.", self)
18 |     self._num_units = num_units
19 |     self._activation = activation
20 |     self._normalize = normalize
21 |     self._reuse = reuse
22 |
23 |   @property
24 |   def state_size(self):
25 |     return tf.contrib.rnn.LSTMStateTuple(self._num_units, self.output_size)
26 |
27 |   @property
28 |   def output_size(self):
29 |     return self._num_units
30 |
31 |   def __call__(self, inputs, state, scope=None):
32 |     with _checked_scope(self, scope or "ran_cell", reuse=self._reuse):
33 |       c, h = state
34 |       with vs.variable_scope("gates"):
35 |         gates = tf.nn.sigmoid(linear([inputs, h], 2 * self._num_units, True, normalize=self._normalize))
36 |         i, f = array_ops.split(value=gates, num_or_size_splits=2, axis=1)
37 |
38 |       with vs.variable_scope("candidate"):
39 |         content = linear([inputs], self._num_units, True, normalize=self._normalize)
40 |
41 |       new_c = i * content + f * c     # Additive update of the cell state
42 |       new_h = self._activation(new_c) # Output is a function of the updated cell state
43 |       new_state = tf.contrib.rnn.LSTMStateTuple(new_c, new_h)
44 |       return new_h, new_state
--------------------------------------------------------------------------------
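The difference between the two cells: v1 carries only the cell state c, so its gates condition on the previous c; v2 carries an explicit (c, h) LSTMStateTuple, so its gates condition on the previous output h, matching the paper's formulation, and its final state unpacks exactly like an LSTM's. A short sketch of the readout pattern this enables (dimensions illustrative):

import tensorflow as tf
from ran_cell_v2 import RANCellv2

x = tf.placeholder(tf.float32, [32, 50, 7])
cell = RANCellv2(250)
outputs, state = tf.nn.dynamic_rnn(cell, x, dtype=tf.float32)
# state is an LSTMStateTuple, so the readout uses state.h -- which is why
# train.py handles RANv2 and LSTM cells in the same branch.
logits = tf.layers.dense(state.h, 1)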
/train.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import tensorflow as tf
4 | import dataplumbing as dp
5 | from tensorflow.contrib.rnn import GRUCell, BasicLSTMCell, LayerNormBasicLSTMCell
6 | from ran_cell import RANCell
7 | from ran_cell_v2 import RANCellv2
8 |
9 | flags = tf.app.flags
10 | flags.DEFINE_string("rnn_type", "RAN", "rnn type [RAN, RANv2, RAN_LN, RAN_LNv2, LSTM, LSTM_LN, GRU]")
11 | FLAGS = flags.FLAGS
12 |
13 | def main(_):
14 |   np.random.seed(1)
15 |   tf.set_random_seed(1)
16 |   num_features = dp.train.num_features
17 |   max_steps = dp.train.max_length
18 |   num_cells = 250
19 |   num_classes = dp.train.num_classes
20 |   initialization_factor = 1.0
21 |   num_iterations = 500
22 |   batch_size = 100
23 |   learning_rate = 0.001
24 |
25 |   initializer = tf.random_uniform_initializer(minval=-np.sqrt(6.0 * 1.0 / (num_cells + num_classes)),
26 |                                               maxval=np.sqrt(6.0 * 1.0 / (num_cells + num_classes)))
27 |
28 |   with tf.variable_scope("train", initializer=initializer):
29 |     x = tf.placeholder(tf.float32, [batch_size, max_steps, num_features])  # Features
30 |     y = tf.placeholder(tf.float32, [batch_size])  # Labels
31 |     l = tf.placeholder(tf.int32, [batch_size])  # Sequence lengths
32 |     global_step = tf.Variable(0, name="global_step", trainable=False)
33 |
34 |     if FLAGS.rnn_type == "RAN":
35 |       cell = RANCell(num_cells)
36 |     elif FLAGS.rnn_type == "RANv2":
37 |       cell = RANCellv2(num_cells)
38 |     elif FLAGS.rnn_type == "LSTM":
39 |       cell = BasicLSTMCell(num_cells)
40 |     elif FLAGS.rnn_type == "LSTM_LN":
41 |       cell = LayerNormBasicLSTMCell(num_cells)
42 |     elif FLAGS.rnn_type == "GRU":
43 |       cell = GRUCell(num_cells)
44 |     elif FLAGS.rnn_type == "RAN_LN":
45 |       cell = RANCell(num_cells, normalize=True)
46 |     elif FLAGS.rnn_type == "RAN_LNv2":
47 |       cell = RANCellv2(num_cells, normalize=True)
48 |     else:
49 |       raise ValueError("Unknown rnn_type: %s" % FLAGS.rnn_type)
50 |
51 |     states = cell.zero_state(batch_size, tf.float32)
52 |     outputs, states = tf.nn.dynamic_rnn(cell, x, sequence_length=l, initial_state=states)
53 |
54 |     W_o = tf.Variable(tf.random_uniform([num_cells, num_classes],
55 |                       minval=-np.sqrt(6.0*initialization_factor / (num_cells + num_classes)),
56 |                       maxval=np.sqrt(6.0*initialization_factor / (num_cells + num_classes))))
57 |     b_o = tf.Variable(tf.zeros([num_classes]))
58 |
59 |     # Cells that carry a (c, h) tuple expose the final output as states.h;
60 |     # the remaining cells return a single state tensor.
61 |     if FLAGS.rnn_type in ("LSTM", "LSTM_LN", "RANv2", "RAN_LNv2"):
62 |       ly = tf.matmul(states.h, W_o) + b_o
63 |     else:
64 |       ly = tf.matmul(states, W_o) + b_o
65 |     ly_flat = tf.reshape(ly, [batch_size])
66 |     py = tf.nn.sigmoid(ly_flat)
67 |
68 |   ##########################################################################################
69 |   # Optimizer/Analyzer
70 |   ##########################################################################################
71 |
72 |   # Cost function and optimizer
73 |   #
74 |   cost = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=ly_flat, labels=y))  # Cross-entropy cost function
75 |   optimizer = tf.train.AdamOptimizer(learning_rate).minimize(cost, global_step=global_step)
76 |
77 |   # Evaluate performance
78 |   #
79 |   correct = tf.equal(tf.round(py), tf.round(y))
80 |   accuracy = 100.0 * tf.reduce_mean(tf.cast(correct, tf.float32))
81 |
82 |   tf.summary.scalar('cost', cost)
83 |   tf.summary.scalar('accuracy', accuracy)
84 |
85 |   ##########################################################################################
86 |   # Train
87 |   ##########################################################################################
88 |
89 |   # Operation to initialize the session
90 |   #
91 |   init_op = tf.global_variables_initializer()
92 |   summaries = tf.summary.merge_all()
93 |
94 |   # Open session
95 |   #
96 |   with tf.Session() as session:
97 |     # Summary writer
98 |     #
99 |     summary_writer = tf.summary.FileWriter('log/' + FLAGS.rnn_type, session.graph)
100 |
101 |     # Initialize variables
102 |     #
103 |     session.run(init_op)
104 |
105 |     # Each iteration trains on one batch
106 |     #
107 |     for iteration in range(num_iterations):
108 |       # Grab a batch of training data
109 |       #
110 |       xs, ls, ys = dp.train.batch(batch_size)
111 |       feed = {x: xs, l: ls, y: ys}
112 |
113 |       # Update parameters
114 |       #
115 |       out = session.run((cost, accuracy, optimizer, summaries, global_step), feed_dict=feed)
116 |       print('Iteration:', iteration, 'Dataset:', 'train', 'Cost:', out[0]/np.log(2.0), 'Accuracy:', out[1])
117 |
118 |       summary_writer.add_summary(out[3], out[4])  # Log summaries at the current global step
119 |
120 |       # Periodically run the model on test data
121 |       #
122 |       if iteration%100 == 0:
123 |         # Grab a batch of test data
124 |         #
125 |         xs, ls, ys = dp.test.batch(batch_size)
126 |         feed = {x: xs, l: ls, y: ys}
127 |
128 |         # Run model
129 |         #
130 |         summary_writer.flush()
131 |         out = session.run((cost, accuracy), feed_dict=feed)
132 |         test_cost = out[0] / np.log(2.0)  # Convert the cost from nats to bits
133 |         test_accuracy = out[1]
134 |         print('Iteration:', iteration, 'Dataset:', 'test', 'Cost:', test_cost, 'Accuracy:', test_accuracy)
135 |
136 |     summary_writer.close()
137 |
138 |     # Save the trained model
139 |     #
140 |     os.makedirs('bin', exist_ok=True)
141 |     saver = tf.train.Saver()
142 |     saver.save(session, 'bin/train.ckpt')
143 |
144 |
145 | if __name__ == "__main__":
146 |   tf.app.run()
--------------------------------------------------------------------------------
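Each cell variant is selected with the rnn_type flag, and each run writes its summaries under log/<rnn_type>, so variants can be compared side by side in TensorBoard. For example (standard TF/TensorBoard commands; the dataset is generated automatically on the first run):

python train.py --rnn_type=RAN_LN
tensorboard --logdir=log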
/train_accuracy.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/indiejoseph/tf-ran-cell/389d549314928e5c9cb7658b3371f76d25a44ab2/train_accuracy.png
--------------------------------------------------------------------------------
/train_cost.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/indiejoseph/tf-ran-cell/389d549314928e5c9cb7658b3371f76d25a44ab2/train_cost.png
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | from tensorflow.python.util import nest
3 | from tensorflow.python.ops import variable_scope as vs
4 | from tensorflow.python.ops import array_ops
5 | from tensorflow.python.ops import init_ops
6 | from tensorflow.python.ops import math_ops
7 | from tensorflow.python.ops import nn_ops
8 |
9 | _BIAS_VARIABLE_NAME = "bias"
10 | _WEIGHTS_VARIABLE_NAME = "kernel"
11 |
12 | def linear(args,
13 |            output_size,
14 |            bias,
15 |            bias_initializer=None,
16 |            kernel_initializer=None,
17 |            kernel_regularizer=None,
18 |            bias_regularizer=None,
19 |            normalize=False):
20 |   """Linear map: sum_i(args[i] * W[i]), where W[i] is a variable.
21 |
22 |   Args:
23 |     args: a 2D Tensor or a list of 2D, batch x n, Tensors.
24 |     output_size: int, second dimension of W[i].
25 |     bias: boolean, whether to add a bias term or not.
26 |     bias_initializer: starting value to initialize the bias
27 |       (default is all zeros).
28 |     kernel_initializer: starting value to initialize the weight.
29 |     kernel_regularizer: kernel regularizer
30 |     bias_regularizer: bias regularizer
31 |     normalize: boolean, whether to apply layer normalization to the result
32 |       (in which case any bias is skipped as redundant).
33 |
34 |   Returns:
35 |     A 2D Tensor with shape [batch x output_size] equal to
36 |     sum_i(args[i] * W[i]), where W[i]s are newly created matrices.
37 |
38 |   Raises:
39 |     ValueError: if some of the arguments has unspecified or wrong shape.
40 |   """
41 |   if args is None or (nest.is_sequence(args) and not args):
42 |     raise ValueError("`args` must be specified")
43 |   if not nest.is_sequence(args):
44 |     args = [args]
45 |
46 |   # Calculate the total size of arguments on dimension 1.
47 |   total_arg_size = 0
48 |   shapes = [a.get_shape() for a in args]
49 |   for shape in shapes:
50 |     if shape.ndims != 2:
51 |       raise ValueError("linear is expecting 2D arguments: %s" % shapes)
52 |     if shape[1].value is None:
53 |       raise ValueError("linear expects shape[1] to be provided for shape %s, "
54 |                        "but saw %s" % (shape, shape[1]))
55 |     else:
56 |       total_arg_size += shape[1].value
57 |
58 |   dtype = [a.dtype for a in args][0]
59 |
60 |   # Now the computation.
61 |   scope = vs.get_variable_scope()
62 |   with vs.variable_scope(scope) as outer_scope:
63 |     weights = vs.get_variable(
64 |         _WEIGHTS_VARIABLE_NAME, [total_arg_size, output_size],
65 |         dtype=dtype,
66 |         initializer=kernel_initializer,
67 |         regularizer=kernel_regularizer)
68 |
69 |     if len(args) == 1:
70 |       res = math_ops.matmul(args[0], weights)
71 |     else:
72 |       res = math_ops.matmul(array_ops.concat(args, 1), weights)
73 |
74 |     if normalize:
75 |       res = tf.contrib.layers.layer_norm(res)
76 |
77 |     # Skip the bias when none was requested or when layer normalization is
78 |     # applied (its learned shift would make a separate bias redundant).
79 |     if not bias or normalize:
80 |       return res
81 |
82 |     with vs.variable_scope(outer_scope) as inner_scope:
83 |       inner_scope.set_partitioner(None)
84 |       if bias_initializer is None:
85 |         bias_initializer = init_ops.constant_initializer(0.0, dtype=dtype)
86 |       biases = vs.get_variable(
87 |           _BIAS_VARIABLE_NAME, [output_size],
88 |           dtype=dtype,
89 |           initializer=bias_initializer,
90 |           regularizer=bias_regularizer)
91 |
92 |       return nn_ops.bias_add(res, biases)
--------------------------------------------------------------------------------
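For completeness, a sketch of calling linear directly, the way both RAN cells do; the shapes are illustrative, and the call must run inside a variable scope so the kernel and bias variables get unique names.

import tensorflow as tf
from utils import linear

x = tf.placeholder(tf.float32, [32, 7])     # batch x input features
h = tf.placeholder(tf.float32, [32, 250])   # batch x hidden units
with tf.variable_scope("gates"):
  # Concatenates [x, h] along dimension 1 and multiplies by a single
  # [257, 500] kernel, then adds a bias (since bias=True, normalize=False).
  gates = tf.nn.sigmoid(linear([x, h], 2 * 250, True, normalize=False))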