├── .gitignore
├── model
│   ├── __init__.py
│   ├── cudnn_rnn.py
│   ├── fo_pool.py
│   ├── pool.py
│   ├── qrnn_lib.so
│   └── rcrn.py
└── readme.md

/.gitignore:
--------------------------------------------------------------------------------
*.pyc
--------------------------------------------------------------------------------
/model/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vanzytay/NIPS2018_RCRN/7a8d07a609756f38d6a5ea19ce39fffeda679965/model/__init__.py
--------------------------------------------------------------------------------
/model/cudnn_rnn.py:
--------------------------------------------------------------------------------
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf

# NOTE: this module calls a `dropout(args, keep_prob, is_train)` helper that is
# not defined in this repository; see the sketch after this file.


class cudnn_rnn:
    """ Universal cudnn_rnn class.
    Supports both LSTM and GRU.

    Variational dropout is optional.
    """

    def __init__(self, num_layers, num_units, batch_size, input_size, keep_prob=1.0,
                 is_train=None, scope=None, init=None, rnn_type='',
                 direction='bidirectional'):
        if(init is None):
            rnn_init = tf.random_normal_initializer(stddev=0.1)
        else:
            rnn_init = init
        self.num_layers = num_layers
        self.grus = []
        self.inits = []
        self.dropout_mask = []
        self.num_units = num_units
        self.is_train = is_train
        self.keep_prob = keep_prob
        self.input_size = input_size
        self.rnn_type = rnn_type
        self.direction = direction
        self.num_params = []
        for layer in range(num_layers):
            input_size_ = input_size if layer == 0 else 2 * num_units
            if('LSTM' in rnn_type):
                gru_fw = tf.contrib.cudnn_rnn.CudnnLSTM(
                    1, num_units, kernel_initializer=rnn_init)
                if(self.direction=='bidirectional'):
                    gru_bw = tf.contrib.cudnn_rnn.CudnnLSTM(
                        1, num_units, kernel_initializer=rnn_init)
                else:
                    gru_bw = None
            else:
                gru_fw = tf.contrib.cudnn_rnn.CudnnGRU(
                    1, num_units, kernel_initializer=rnn_init)
                if(self.direction=='bidirectional'):
                    gru_bw = tf.contrib.cudnn_rnn.CudnnGRU(
                        1, num_units, kernel_initializer=rnn_init)
                else:
                    gru_bw = None

            self.grus.append((gru_fw, gru_bw, ))

    def __call__(self, inputs, seq_len, batch_size=None,
                 is_train=None, concat_layers=True,
                 var_drop=1, train_init=0):
        # inputs are [batch, time, channels]; the Cudnn RNNs expect time-major input
        batch_size = tf.shape(inputs)[0]
        outputs = [tf.transpose(inputs, [1, 0, 2])]

        for layer in range(self.num_layers):
            if(train_init):
                init_fw = tf.tile(tf.Variable(
                    tf.zeros([1, 1, self.num_units])), [1, batch_size, 1])
                if(self.direction=='bidirectional'):
                    init_bw = tf.tile(tf.Variable(
                        tf.zeros([1, 1, self.num_units])), [1, batch_size, 1])
                else:
                    init_bw = None
            else:
                init_fw = tf.tile(tf.zeros([1, 1, self.num_units]),
                                  [1, batch_size, 1])
                if(self.direction=='bidirectional'):
                    init_bw = tf.tile(tf.zeros([1, 1, self.num_units]),
                                      [1, batch_size, 1])
                else:
                    init_bw = None
            if(var_drop==1):
                # variational dropout: a single [1, batch, channels] mask is
                # broadcast across the time axis of the layer input
                mask_fw = dropout(tf.ones([1, batch_size, self.input_size],
                                          dtype=tf.float32),
                                  keep_prob=self.keep_prob, is_train=self.is_train)
                output_fw = outputs[-1] * mask_fw
                if(self.direction=='bidirectional'):
                    mask_bw = dropout(tf.ones([1, batch_size, self.input_size],
                                              dtype=tf.float32),
                                      keep_prob=self.keep_prob,
                                      is_train=self.is_train)
                    output_bw = outputs[-1] * mask_bw
            else:
                output_fw = outputs[-1]
                output_fw = dropout(output_fw,
                                    keep_prob=self.keep_prob,
                                    is_train=self.is_train)
                if(self.direction=='bidirectional'):
                    output_bw = outputs[-1]
                    output_bw = dropout(output_bw,
                                        keep_prob=self.keep_prob,
                                        is_train=self.is_train)
            gru_fw, gru_bw = self.grus[layer]
            if('LSTM' in self.rnn_type):
                init_state1 = (init_fw, init_fw)
                init_state2 = (init_bw, init_bw)
            else:
                init_state1 = (init_fw,)
                init_state2 = (init_bw,)

            with tf.variable_scope("fw_{}".format(layer)):
                out_fw, _ = gru_fw(
                    output_fw, initial_state=init_state1)
                self.num_params += gru_fw.canonical_weight_shapes

            out = out_fw

            if(self.direction=='bidirectional'):
                with tf.variable_scope("bw_{}".format(layer)):
                    inputs_bw = tf.reverse_sequence(
                        output_bw, seq_lengths=seq_len, seq_dim=0, batch_dim=1)
                    out_bw, _ = gru_bw(inputs_bw, initial_state=init_state2)
                    out_bw = tf.reverse_sequence(
                        out_bw, seq_lengths=seq_len, seq_dim=0, batch_dim=1)
                    out = tf.concat([out, out_bw], 2)
                    self.num_params += gru_bw.canonical_weight_shapes
            outputs.append(out)
        if concat_layers:
            res = tf.concat(outputs[1:], axis=2)
        else:
            res = outputs[-1]
        res = tf.transpose(res, [1, 0, 2])

        counter = 0
        for t in self.num_params:
            counter += t[0] * t[1]
        print('Cudnn Parameters={}'.format(counter))

        return res
--------------------------------------------------------------------------------
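Note: `cudnn_rnn.py` above uses a `dropout(...)` helper that is not defined anywhere in this repository; it presumably comes from the R-Net codebase credited in the readme. The sketch below is a minimal, hypothetical stand-in, assuming the helper only needs to apply `tf.nn.dropout` at training time (the variational behaviour in `cudnn_rnn` comes from building a `[1, batch, channels]` ones mask and broadcasting it over the time axis). The signature and behaviour are assumptions, not the authors' implementation.

```python
import tensorflow as tf

def dropout(args, keep_prob, is_train, noise_shape=None):
    """Hypothetical helper (assumed): apply tf.nn.dropout only when is_train is True."""
    if is_train is None or keep_prob is None:
        return args
    return tf.cond(tf.cast(is_train, tf.bool),
                   lambda: tf.nn.dropout(args, keep_prob, noise_shape=noise_shape),
                   lambda: args)
```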
/model/fo_pool.py:
--------------------------------------------------------------------------------
import tensorflow as tf
import os

custom_op_path = os.path.dirname(os.path.abspath(__file__)) + '/qrnn_lib.so'
# print("Loading Custom Op from {}".format(custom_op_path))
qrnn_lib = tf.load_op_library(custom_op_path)

time_major_fo_pool_unsliced = qrnn_lib.time_major_fo_pool
time_major_bwd_fo_pool = qrnn_lib.time_major_bwd_fo_pool

batch_major_fo_pool_unsliced = qrnn_lib.batch_major_fo_pool
batch_major_bwd_fo_pool = qrnn_lib.batch_major_bwd_fo_pool

@tf.RegisterGradient("TimeMajorFoPool")
def _time_major_fo_pool_grad(op, grad):
    return time_major_bwd_fo_pool(h=op.outputs[0], x=op.inputs[0],
                                  forget=op.inputs[1], gh=grad)

@tf.RegisterGradient("BatchMajorFoPool")
def _batch_major_fo_pool_grad(op, grad):
    return batch_major_bwd_fo_pool(h=op.outputs[0], x=op.inputs[0],
                                   forget=op.inputs[1], gh=grad)


def fo_pool(x, forget, initial_state=None, time_major=False):
    """Applies a single layer of Quasi-Recurrent Neural Network (QRNN) fo-pooling to an input sequence.

    Args:
        x: Tensor, input values in [Batch, Time, Channels] format,
            float32 or double, or [Time, Batch, Channels] if time_major.
        forget: Tensor, forget gate values in [Batch, Time, Channels] format,
            float32 or double, usually in the range 0-1,
            or [Time, Batch, Channels] if time_major.
        initial_state: Tensor, initial hidden state values in [Batch, Channels] format,
            float32 or double.
    Returns:
        Tensor: fo-pooled output in [Batch, Time, Channels] format,
        or [Time, Batch, Channels] if time_major.
    """
    if initial_state is None:
        initial_state = tf.zeros((tf.shape(x)[1] if time_major else tf.shape(x)[0],
                                  tf.shape(x)[2]), dtype=x.dtype)
    if time_major:
        # the custom op also returns the initial state as step 0; slice it off
        return time_major_fo_pool_unsliced(x, forget, initial_state)[1:]
    else:
        return batch_major_fo_pool_unsliced(x, forget, initial_state)[:, 1:]
--------------------------------------------------------------------------------
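A minimal usage sketch for `fo_pool`, assuming `qrnn_lib.so` loads on your platform and that the package is importable as `model`; the shapes below are illustrative only:

```python
import tensorflow as tf
from model.fo_pool import fo_pool

x = tf.random_normal([32, 20, 100])                       # [batch, time, channels] candidate activations
forget = tf.nn.sigmoid(tf.random_normal([32, 20, 100]))   # forget gates in (0, 1), same shape as x
pooled = fo_pool(x, forget, time_major=False)             # -> [32, 20, 100]; initial state defaults to zeros
```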
/model/pool.py:
--------------------------------------------------------------------------------
import tensorflow as tf
from .fo_pool import *

class RCRNpooling(tf.nn.rnn_cell.RNNCell):
    """ Handles the pooling, using the outputs of two BiRNNs as gates
    """

    def __init__(self, out_fmaps, pool_type,
                 initializer=None, in_dim=None):
        self.__pool_type = pool_type
        self.__out_fmaps = out_fmaps
        if(initializer is None):
            initializer = tf.orthogonal_initializer()
        self.__initializer = initializer  # kept for API compatibility; unused by the pooling itself

    @property
    def state_size(self):
        return self.__out_fmaps

    @property
    def output_size(self):
        return self.__out_fmaps

    def __call__(self, inputs, state, scope=None):
        """
        inputs: 2-D tensor of shape [batch_size, feats + gates]
        """
        pool_type = self.__pool_type
        with tf.variable_scope(scope or "QRNN-{}-pooling".format(pool_type)):
            if pool_type == 'f':
                # extract Z activations and F gate activations
                Z, F = tf.split(inputs, 2, 1)
                # return the dynamic average pooling
                output = tf.multiply(F, state) + tf.multiply(tf.subtract(1., F), Z)
                return output, output
            elif pool_type == 'fo':
                # extract Z, F gate and O gate
                Z, F, O = tf.split(inputs, 3, 1)
                new_state = tf.multiply(F, state) + tf.multiply(tf.subtract(1., F), Z)
                output = tf.multiply(O, new_state)
                return output, new_state
            elif pool_type == 'ifo':
                # extract Z, I gate, F gate, and O gate
                Z, I, F, O = tf.split(inputs, 4, 1)
                new_state = tf.multiply(F, state) + tf.multiply(I, Z)
                output = tf.multiply(O, new_state)
                return output, new_state
--------------------------------------------------------------------------------
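For reference, the recurrences computed by `RCRNpooling` above (and, for the 'fo' case, by the fused `fo_pool` kernel) can be written as follows, where z_t are the candidate activations, i_t/f_t/o_t the gates, c_t the cell state, h_t the emitted output, and \odot element-wise multiplication:

```latex
% 'f' pooling (dynamic average pooling)
c_t = f_t \odot c_{t-1} + (1 - f_t) \odot z_t, \qquad h_t = c_t
% 'fo' pooling
c_t = f_t \odot c_{t-1} + (1 - f_t) \odot z_t, \qquad h_t = o_t \odot c_t
% 'ifo' pooling
c_t = f_t \odot c_{t-1} + i_t \odot z_t, \qquad h_t = o_t \odot c_t
```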
/model/qrnn_lib.so:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vanzytay/NIPS2018_RCRN/7a8d07a609756f38d6a5ea19ce39fffeda679965/model/qrnn_lib.so
--------------------------------------------------------------------------------
/model/rcrn.py:
--------------------------------------------------------------------------------
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf
from .pool import *
from .fo_pool import *
from .cudnn_rnn import *

def RCRN(embed, lengths,
         initializer=None, name='', reuse=None,
         dropout=None, is_train=None, var_drop=0, dim=None,
         direction='bidirectional', rnn_type='LSTM',
         cell_dropout=None,
         fuse_kernel=0, train_init=0):
    """ Implementation of RCRN encoders

    Args:
        embed: tensor sequence of bsz x seq_len x dim
        lengths: tensor of [bsz] with the actual sequence lengths
        initializer: tensorflow initializer
        name: give it a name!
        reuse: whether to reuse variables
        dropout: scalar tensor with the keep probability; pass it via feed_dict
        is_train: bool tensor, whether training or not
        var_drop: whether to use variational dropout
        dim: int, size of the output dim (defaults to the input dim)
        direction: 'bidirectional' or 'unidirectional'
        rnn_type: the rnn type of the internal cells
        cell_dropout: float, keep probability for dropout during the pooling recurrence
        fuse_kernel: int, 1 to use the fast CUDA op and 0 to use dynamic_rnn pooling
        train_init: whether the starting state is zero or a trainable parameter
    """

    if(dim is None):
        dim = embed.get_shape().as_list()[2]

    dim2 = dim
    if(direction=='bidirectional'):
        dim2 = dim2 * 2

    batch_size = tf.shape(embed)[0]
    if(train_init):
        initial_state = tf.tile(tf.Variable(
            tf.zeros([1, dim2])), [batch_size, 1])
    else:
        initial_state = tf.tile(
            tf.zeros([1, dim2]), [batch_size, 1])

    d = dim
    bsz = batch_size

    with tf.variable_scope("main_rnn", reuse=reuse):
        main_rnn = cudnn_rnn(num_layers=1, num_units=d,
                             batch_size=bsz,
                             input_size=embed.get_shape().as_list()[-1],
                             keep_prob=dropout,
                             is_train=is_train,
                             direction=direction,
                             rnn_type=rnn_type,
                             init=initializer
                             )
        proj_embed = main_rnn(embed,
                              seq_len=lengths,
                              var_drop=var_drop,
                              train_init=train_init
                              )
    with tf.variable_scope("fg_rnn", reuse=reuse):
        forget_rnn = cudnn_rnn(num_layers=1, num_units=d,
                               batch_size=bsz,
                               input_size=embed.get_shape().as_list()[-1],
                               keep_prob=dropout,
                               direction=direction,
                               is_train=is_train,
                               rnn_type=rnn_type,
                               init=initializer)
        forget_gate = forget_rnn(embed, seq_len=lengths,
                                 var_drop=var_drop,
                                 train_init=train_init
                                 )
    with tf.variable_scope("og_rnn", reuse=reuse):
        output_rnn = cudnn_rnn(num_layers=1, num_units=d,
                               batch_size=bsz,
                               input_size=embed.get_shape().as_list()[-1],
                               keep_prob=dropout,
                               direction=direction,
                               is_train=is_train,
                               rnn_type=rnn_type,
                               init=initializer)
        output_gate = output_rnn(embed, seq_len=lengths,
                                 var_drop=var_drop,
                                 train_init=train_init
                                 )

    pooling = RCRNpooling(dim2, 'fo')
    if(cell_dropout is not None and cell_dropout<1.0):
        print("Adding dropout")
        pooling = tf.contrib.rnn.DropoutWrapper(pooling,
                                                output_keep_prob=cell_dropout)
    # note: this overrides the initial_state constructed above from train_init
    initial_state = pooling.zero_state(tf.shape(embed)[0],
                                       tf.float32)
    output_gate = tf.nn.sigmoid(output_gate)
    forget_gate = tf.nn.sigmoid(forget_gate)

    if(fuse_kernel==1):
        print("Using Cuda-level Fused Kernel Optimization")
        with tf.name_scope("FoPool"):
            c = fo_pool(proj_embed, forget_gate,
                        initial_state=initial_state,
                        time_major=False)
        embed = c * output_gate
    else:
        stack_input = tf.concat([proj_embed,
                                 forget_gate, output_gate], 2)
        embed, _ = tf.nn.dynamic_rnn(pooling, stack_input,
                                     initial_state=initial_state,
                                     sequence_length=tf.cast(
                                         lengths, tf.int32))

    return embed, tf.reduce_sum(embed, 1)
--------------------------------------------------------------------------------
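The readme below promises usage scripts later; in the meantime, here is a minimal, hypothetical sketch of wiring the `RCRN` encoder into a graph. The placeholder names, sizes, and import path (`model`) are illustrative assumptions, not the authors' training setup; it also assumes the `dropout` helper sketched earlier is available and that a GPU is present, since the encoder is built on `tf.contrib.cudnn_rnn`.

```python
import tensorflow as tf
from model.rcrn import RCRN

# illustrative placeholders
embed = tf.placeholder(tf.float32, [None, None, 300], name='embed')   # [batch, time, dim]
lengths = tf.placeholder(tf.int32, [None], name='lengths')            # actual sequence lengths
keep_prob = tf.placeholder_with_default(1.0, [], name='keep_prob')    # fed via feed_dict
is_train = tf.placeholder_with_default(False, [], name='is_train')

# seq_out is [batch, time, 2*dim] in the bidirectional setting;
# pooled_out is the sum of seq_out over the time axis
seq_out, pooled_out = RCRN(embed, lengths,
                           dropout=keep_prob,
                           is_train=is_train,
                           dim=300,
                           direction='bidirectional',
                           rnn_type='LSTM',
                           fuse_kernel=1)   # set to 0 to fall back to tf.nn.dynamic_rnn pooling
```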
/readme.md:
--------------------------------------------------------------------------------
# NIPS 2018

We propose Recurrently Controlled Recurrent Networks (RCRN), a new recurrent architecture that shows improvements over stacked and vanilla BiLSTMs across a number of NLP tasks.

This repository contains the TensorFlow model files for RCRN, along with the custom CUDA-optimized kernel. I will upload running scripts / example usage when I have time (I already have a huge backlog).

# Dependencies

- Python 2.7
- Tensorflow 1.7

# Acknowledgements

Our CUDA op was adapted from: https://github.com/JonathanRaiman/tensorflow_qrnn

The Cudnn RNN wrapper was adapted from: https://github.com/HKUST-KnowComp/R-Net

# Reference

If you find our repository useful, please cite our paper!

```
@inproceedings{nips2018,
  author    = {Yi Tay and
               Luu Anh Tuan and
               Siu Cheung Hui},
  title     = {Recurrently Controlled Recurrent Networks},
  booktitle = {Proceedings of NIPS 2018},
  year      = {2018}
}
```
--------------------------------------------------------------------------------