├── .gitignore ├── LICENSE ├── README.md ├── figures ├── authors │ ├── BrendanJou.png │ ├── JordiTorres.jpg │ ├── ShihFuChang.jpg │ ├── VictorCampos.jpg │ └── XavierGiro.jpg ├── logos │ ├── MEyC.png │ ├── bsc.jpg │ ├── columbia.png │ ├── generalitat.jpg │ ├── google.png │ ├── gpu_excellence_center.png │ ├── severo_ochoa.png │ └── upc.jpg └── skip-rnn-model.png ├── paper.pdf ├── requirements.txt └── src ├── 01_adding_task.py ├── 02_frequency_discrimination_task.py ├── 03_sequential_mnist.py ├── rnn_cells ├── __init__.py ├── basic_rnn_cells.py ├── rnn_ops.py └── skip_rnn_cells.py └── util ├── __init__.py ├── graph_definition.py └── misc.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | env/ 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | 49 | # Translations 50 | *.mo 51 | *.pot 52 | 53 | # Django stuff: 54 | *.log 55 | local_settings.py 56 | 57 | # Flask stuff: 58 | instance/ 59 | .webassets-cache 60 | 61 | # Scrapy stuff: 62 | .scrapy 63 | 64 | # Sphinx documentation 65 | docs/_build/ 66 | 67 | # PyBuilder 68 | target/ 69 | 70 | # Jupyter Notebook 71 | .ipynb_checkpoints 72 | 73 | # pyenv 74 | .python-version 75 | 76 | # celery beat schedule file 77 | celerybeat-schedule 78 | 79 | # SageMath parsed files 80 | *.sage.py 81 | 82 | # dotenv 83 | .env 84 | 85 | # virtualenv 86 | .venv 87 | venv/ 88 | ENV/ 89 | 90 | # Spyder project settings 91 | .spyderproject 92 | .spyproject 93 | 94 | # Rope project settings 95 | .ropeproject 96 | 97 | # mkdocs documentation 98 | /site 99 | 100 | # mypy 101 | .mypy_cache/ 102 | 103 | # .idea/ 104 | .idea/* -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2017 Image Processing Group - BarcelonaTECH - UPC 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Skip RNN: Learning to Skip State Updates in Recurrent Neural Networks 2 | 3 | | ![Víctor Campos][VictorCampos-photo] | ![Brendan Jou][BrendanJou-photo] | ![Jordi Torres][JordiTorres-photo] | ![Xavier Giro-i-Nieto][XavierGiro-photo] | ![Shih-Fu Chang][ShihFuChang-photo] | 4 | |:-:|:-:|:-:|:-:|:-:| 5 | | [Víctor Campos][VictorCampos-web] | [Brendan Jou][BrendanJou-web] | [Jordi Torres][JordiTorres-web] | [Xavier Giró-i-Nieto][XavierGiro-web] | [Shih-Fu Chang][ShihFuChang-web] | 6 | 7 | [VictorCampos-photo]: ./figures/authors/VictorCampos.jpg "Víctor Campos" 8 | [JordiTorres-photo]: ./figures/authors/JordiTorres.jpg "Jordi Torres" 9 | [XavierGiro-photo]: ./figures/authors/XavierGiro.jpg "Xavier Giro-i-Nieto" 10 | [BrendanJou-photo]: ./figures/authors/BrendanJou.png "Brendan Jou" 11 | [ShihFuChang-photo]: ./figures/authors/ShihFuChang.jpg "Shih-Fu Chang" 12 | 13 | [VictorCampos-web]: https://imatge.upc.edu/web/people/victor-campos 14 | [JordiTorres-web]: http://www.jorditorres.org/ 15 | [XavierGiro-web]: https://imatge.upc.edu/web/people/xavier-giro 16 | [BrendanJou-web]: http://www.ee.columbia.edu/~bjou/ 17 | [ShihFuChang-web]: http://www.ee.columbia.edu/~sfchang/ 18 | 19 | 20 | 21 | A joint collaboration between: 22 | 23 | | ![logo-bsc] | ![logo-google] | ![logo-upc] | ![logo-columbia] | 24 | |:-:|:-:|:-:|:-:| 25 | | [Barcelona Supercomputing Center (BSC)](https://www.bsc.es/) | [Google Inc.](https://www.google.com/) | [Universitat Politècnica de Catalunya (UPC)](http://www.upc.edu/?set_language=en) | [Columbia University](https://www.columbia.edu/ ) | 26 | 27 | [logo-upc]: ./figures/logos/upc.jpg "Universitat Politècnica de Catalunya" 28 | [logo-bsc]: ./figures/logos/bsc.jpg "Barcelona Supercomputing Center" 29 | [logo-google]: ./figures/logos/google.png "Google" 30 | [logo-columbia]: ./figures/logos/columbia.png "Columbia University" 31 | 32 | 33 | 34 | ## Abstract 35 | 36 | Recurrent Neural Networks (RNNs) continue to show outstanding performance in sequence modeling tasks. However, training RNNs on long sequences often faces challenges such as slow inference, vanishing gradients and difficulty in capturing long-term dependencies. In backpropagation through time settings, these issues are tightly coupled with the large, sequential computational graph resulting from unfolding the RNN in time. We introduce the Skip RNN model, which extends existing RNN models by learning to skip state updates and shortens the effective size of the computational graph. This model can also be encouraged to perform fewer state updates through a budget constraint. We evaluate the proposed model on various tasks and show how it can reduce the number of required RNN updates while preserving, and sometimes even improving, the performance of the baseline RNN models. 37 | 38 |   39 | 40 | [model]: ./figures/skip-rnn-model.png 41 | ![model] 42 | 43 |   44 | 45 | 46 | ## Publication 47 | 48 | Victor Campos, Brendan Jou, Xavier Giro-i-Nieto, Jordi Torres, and Shih-Fu Chang.
"Skip RNN: Learning to Skip State Updates in Recurrent Neural Networks", In International Conference on Learning Representations, 2018. 49 | 50 | ``` 51 | @inproceedings{campos2018skip, 52 | title={Skip RNN: Learning to Skip State Updates in Recurrent Neural Networks}, 53 | author={Campos, V{\'\i}ctor and Jou, Brendan and Gir{\'o}-i-Nieto, Xavier and Torres, Jordi and Chang, Shih-Fu}, 54 | booktitle={International Conference on Learning Representations}, 55 | year={2018} 56 | } 57 | ``` 58 | 59 | ## Code 60 | 61 | ### Dependencies 62 | This code was developed with Python 3.6.0 and TensorFlow 1.13.1. An older version of the code for TensorFlow 1.0.0 is available under the tags menu. To download and install TensorFlow, please follow the [official guide](https://www.tensorflow.org/get_started/os_setup). 63 | 64 | ### Using the models 65 | The models are ready to be used with TensorFlow's `tf.nn.dynamic_rnn` and can be found under `src/rnn_cells/skip_rnn_cells.py`. We provide four different RNN cells: 66 | 67 | * SkipLSTMCell: single SkipLSTM layer 68 | * SkipGRUCell: single SkipGRU layer 69 | * MultiSkipLSTMCell: stack of multiple SkipLSTM layers 70 | * MultiSkipGRUCell: stack of multiple SkipGRU layers 71 | 72 | A usage example can be found below: 73 | 74 | ```python 75 | import tensorflow as tf 76 | from rnn_cells.skip_rnn_cells import SkipLSTMCell 77 | 78 | # Define constants and hyperparameters 79 | NUM_CELLS = 110 80 | BATCH_SIZE = 256 81 | INPUT_SIZE = 10 82 | COST_PER_SAMPLE = 1e-05 83 | 84 | # Placeholder for the input tensor with shape (batch, time, input_dims) 85 | x = tf.placeholder(tf.float32, [None, None, INPUT_SIZE]) 86 | 87 | # Create SkipLSTM cell and trainable initial state 88 | cell = SkipLSTMCell(NUM_CELLS) 89 | initial_state = cell.trainable_initial_state(BATCH_SIZE) 90 | 91 | # Dynamic RNN unfolding 92 | rnn_outputs, rnn_states = tf.nn.dynamic_rnn(cell, x, dtype=tf.float32, initial_state=initial_state) 93 | 94 | # Split the output into the actual RNN output and the state update gate 95 | rnn_outputs, updated_states = rnn_outputs.h, rnn_outputs.state_gate 96 | 97 | # Add a penalization for each state update (i.e. used sample) 98 | budget_loss = tf.reduce_mean(tf.reduce_sum(COST_PER_SAMPLE * updated_states, 1), 0) 99 | ``` 100 | 101 | ### PyTorch version 102 | 103 | [This repository](https://github.com/gitabcworld/skiprnn_pytorch) contains a PyTorch implementation of Skip RNN by Albert Berenguel. 104 | 105 | 106 | ## Acknowledgments 107 | 108 | We would like to especially thank the technical support team at the Barcelona Supercomputing Center, as well as [Oscar Mañas](https://es.linkedin.com/in/oscmansan) for updating the original codebase to TensorFlow 1.13.1, adding TensorBoard support and improving the data loading pipeline. 109 | 110 | | | | 111 | |:--|:-:| 112 | | This work has been supported by the [grant SEV2015-0493 of the Severo Ochoa Program](https://www.bsc.es/es/severo-ochoa/presentaci%C3%B3n) awarded by the Spanish Government, project TIN2015-65316 by the Spanish Ministry of Science and Innovation, and contract 2014-SGR-1051 by the Generalitat de Catalunya | ![logo-severo] | 113 | | We gratefully acknowledge the support of [NVIDIA Corporation](http://www.nvidia.com/content/global/global.php) through the BSC/UPC NVIDIA GPU Center of Excellence.
| ![logo-gpu_excellence_center] | 114 | | The Image ProcessingGroup at the UPC is a [SGR14 Consolidated Research Group](https://imatge.upc.edu/web/projects/sgr14-image-and-video-processing-group) recognized and sponsored by the Catalan Government (Generalitat de Catalunya) through its [AGAUR](http://agaur.gencat.cat/en/inici/index.html) office. | ![logo-catalonia] | 115 | | This work has been developed in the framework of the project [BigGraph TEC2013-43935-R](https://imatge.upc.edu/web/projects/biggraph-heterogeneous-information-and-graph-signal-processing-big-data-era-application), funded by the Spanish Ministerio de Economía y Competitividad and the European Regional Development Fund (ERDF). | ![logo-spain] | 116 | 117 | 118 | [logo-gpu_excellence_center]: ./figures/logos/gpu_excellence_center.png "Logo of NVidia" 119 | [logo-catalonia]: ./figures/logos/generalitat.jpg "Logo of Catalan government" 120 | [logo-spain]: ./figures/logos/MEyC.png "Logo of Spanish government" 121 | [logo-severo]: ./figures/logos/severo_ochoa.png "Severo Ochoa" 122 | 123 | 124 | ## Contact 125 | 126 | If you have any general doubt about our work or code which may be of interest for other researchers, please use the [public issues section](https://github.com/imatge-upc/skiprnn-2017-tfm/issues) on this github repo. Alternatively, drop us an e-mail at . 127 | -------------------------------------------------------------------------------- /figures/authors/BrendanJou.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/imatge-upc/skiprnn-2017-telecombcn/63f93a539a3f2c7a713089fdd2c38bb7b0c581ca/figures/authors/BrendanJou.png -------------------------------------------------------------------------------- /figures/authors/JordiTorres.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/imatge-upc/skiprnn-2017-telecombcn/63f93a539a3f2c7a713089fdd2c38bb7b0c581ca/figures/authors/JordiTorres.jpg -------------------------------------------------------------------------------- /figures/authors/ShihFuChang.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/imatge-upc/skiprnn-2017-telecombcn/63f93a539a3f2c7a713089fdd2c38bb7b0c581ca/figures/authors/ShihFuChang.jpg -------------------------------------------------------------------------------- /figures/authors/VictorCampos.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/imatge-upc/skiprnn-2017-telecombcn/63f93a539a3f2c7a713089fdd2c38bb7b0c581ca/figures/authors/VictorCampos.jpg -------------------------------------------------------------------------------- /figures/authors/XavierGiro.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/imatge-upc/skiprnn-2017-telecombcn/63f93a539a3f2c7a713089fdd2c38bb7b0c581ca/figures/authors/XavierGiro.jpg -------------------------------------------------------------------------------- /figures/logos/MEyC.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/imatge-upc/skiprnn-2017-telecombcn/63f93a539a3f2c7a713089fdd2c38bb7b0c581ca/figures/logos/MEyC.png -------------------------------------------------------------------------------- /figures/logos/bsc.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/imatge-upc/skiprnn-2017-telecombcn/63f93a539a3f2c7a713089fdd2c38bb7b0c581ca/figures/logos/bsc.jpg -------------------------------------------------------------------------------- /figures/logos/columbia.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/imatge-upc/skiprnn-2017-telecombcn/63f93a539a3f2c7a713089fdd2c38bb7b0c581ca/figures/logos/columbia.png -------------------------------------------------------------------------------- /figures/logos/generalitat.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/imatge-upc/skiprnn-2017-telecombcn/63f93a539a3f2c7a713089fdd2c38bb7b0c581ca/figures/logos/generalitat.jpg -------------------------------------------------------------------------------- /figures/logos/google.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/imatge-upc/skiprnn-2017-telecombcn/63f93a539a3f2c7a713089fdd2c38bb7b0c581ca/figures/logos/google.png -------------------------------------------------------------------------------- /figures/logos/gpu_excellence_center.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/imatge-upc/skiprnn-2017-telecombcn/63f93a539a3f2c7a713089fdd2c38bb7b0c581ca/figures/logos/gpu_excellence_center.png -------------------------------------------------------------------------------- /figures/logos/severo_ochoa.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/imatge-upc/skiprnn-2017-telecombcn/63f93a539a3f2c7a713089fdd2c38bb7b0c581ca/figures/logos/severo_ochoa.png -------------------------------------------------------------------------------- /figures/logos/upc.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/imatge-upc/skiprnn-2017-telecombcn/63f93a539a3f2c7a713089fdd2c38bb7b0c581ca/figures/logos/upc.jpg -------------------------------------------------------------------------------- /figures/skip-rnn-model.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/imatge-upc/skiprnn-2017-telecombcn/63f93a539a3f2c7a713089fdd2c38bb7b0c581ca/figures/skip-rnn-model.png -------------------------------------------------------------------------------- /paper.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/imatge-upc/skiprnn-2017-telecombcn/63f93a539a3f2c7a713089fdd2c38bb7b0c581ca/paper.pdf -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | tensorflow-gpu==1.13.1 2 | tensorflow-datasets -------------------------------------------------------------------------------- /src/01_adding_task.py: -------------------------------------------------------------------------------- 1 | """ 2 | Train RNN models on the adding task. The network is given a sequence of (value, marker) tuples. The desired output is 3 | the addition of the only two values that were marked with a 1, whereas those marked with a 0 need to be ignored. 4 | Markers appear only in the first 10% and last 50% of the sequences. 5 | 6 | Validation is performed on data generated on the fly. 
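Training runs indefinitely until interrupted (e.g. with Ctrl+C); validation metrics are periodically written as TensorBoard summaries.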
7 | """ 8 | 9 | from __future__ import absolute_import 10 | from __future__ import print_function 11 | 12 | import os 13 | import datetime 14 | 15 | import random 16 | import numpy as np 17 | 18 | import tensorflow as tf 19 | import tensorflow.contrib.layers as layers 20 | 21 | from util.misc import * 22 | from util.graph_definition import * 23 | 24 | # Task-independent flags 25 | create_generic_flags() 26 | 27 | # Task-specific flags 28 | tf.app.flags.DEFINE_integer('validation_batches', 15, "How many batches to use for validation metrics.") 29 | tf.app.flags.DEFINE_integer('evaluate_every', 300, "How often is the model evaluated.") 30 | tf.app.flags.DEFINE_integer('sequence_length', 50, "Sequence length.") 31 | 32 | FLAGS = tf.app.flags.FLAGS 33 | 34 | # Constants 35 | MIN_VAL = -0.5 36 | MAX_VAL = 0.5 37 | FIRST_MARKER = 10. 38 | SECOND_MARKER = 50. 39 | INPUT_SIZE = 2 40 | OUTPUT_SIZE = 1 41 | 42 | 43 | def task_setup(): 44 | print('\tSequence length: %d' % FLAGS.sequence_length) 45 | print('\tValues drawn from Uniform[%.1f, %.1f]' % (MIN_VAL, MAX_VAL)) 46 | print('\tFirst marker: first %d%%' % FIRST_MARKER) 47 | print('\tSecond marker: last %d%%' % SECOND_MARKER) 48 | 49 | 50 | def generate_example(seq_length, min_val, max_val): 51 | """ 52 | Creates a list of (a,b) tuples where a is random[min_val,max_val] and b is 1 in only 53 | two tuples, 0 for the rest. The ground truth is the addition of a values for tuples with b=1. 54 | 55 | :param seq_length: length of the sequence to be generated 56 | :param min_val: minimum value for a 57 | :param max_val: maximum value for a 58 | 59 | :return x: list of (a,b) tuples 60 | :return y: ground truth 61 | """ 62 | # Select b values: one in first X% of the sequence, the other in the second Y% 63 | b1 = random.randint(0, int(seq_length * FIRST_MARKER / 100.) - 1) 64 | b2 = random.randint(int(seq_length * SECOND_MARKER / 100.), seq_length - 1) 65 | 66 | b = [0.] * seq_length 67 | b[b1] = 1. 68 | b[b2] = 1. 69 | 70 | # Generate list of tuples 71 | x = [(random.uniform(min_val, max_val), marker) for marker in b] 72 | y = x[b1][0] + x[b2][0] 73 | 74 | return x, y 75 | 76 | 77 | def generate_batch(seq_length, batch_size, min_val, max_val): 78 | """ 79 | Generates batch of examples. 
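Each example in the batch is generated independently and shares the same sequence length.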
80 | 81 | :param seq_length: length of the sequence to be generated 82 | :param batch_size: number of samples in the batch 83 | :param min_val: minimum value for a 84 | :param max_val: maximum value for a 85 | 86 | :return x: batch of examples 87 | :return y: batch of ground truth values 88 | """ 89 | n_elems = 2 90 | x = np.empty((batch_size, seq_length, n_elems)) 91 | y = np.empty((batch_size, 1)) 92 | 93 | for i in range(batch_size): 94 | sample, ground_truth = generate_example(seq_length, min_val, max_val) 95 | x[i, :, :] = sample 96 | y[i, 0] = ground_truth 97 | return x, y 98 | 99 | 100 | def train(): 101 | samples = tf.placeholder(tf.float32, [None, None, INPUT_SIZE]) # (batch, time, in) 102 | ground_truth = tf.placeholder(tf.float32, [None, OUTPUT_SIZE]) # (batch, out) 103 | 104 | cell, initial_state = create_model(model=FLAGS.model, 105 | num_cells=[FLAGS.rnn_cells] * FLAGS.rnn_layers, 106 | batch_size=FLAGS.batch_size) 107 | 108 | rnn_outputs, rnn_states = tf.nn.dynamic_rnn(cell, samples, dtype=tf.float32, initial_state=initial_state) 109 | 110 | # Split the outputs of the RNN into the actual outputs and the state update gate 111 | rnn_outputs, updated_states = split_rnn_outputs(FLAGS.model, rnn_outputs) 112 | 113 | out = layers.linear(inputs=rnn_outputs[:, -1, :], num_outputs=OUTPUT_SIZE) 114 | 115 | # Compute L2 loss 116 | mse = tf.nn.l2_loss(ground_truth - out) / FLAGS.batch_size 117 | 118 | # Compute loss for each updated state 119 | budget_loss = compute_budget_loss(FLAGS.model, mse, updated_states, FLAGS.cost_per_sample) 120 | 121 | # Combine all losses 122 | loss = mse + budget_loss 123 | 124 | # Optimizer 125 | opt, grads_and_vars = compute_gradients(loss, FLAGS.learning_rate, FLAGS.grad_clip) 126 | train_fn = opt.apply_gradients(grads_and_vars) 127 | 128 | sess = tf.Session() 129 | 130 | log_dir = os.path.join(FLAGS.logdir, datetime.datetime.now().strftime("%Y%m%d-%H%M%S")) 131 | valid_writer = tf.summary.FileWriter(log_dir + '/val') 132 | 133 | sess.run(tf.global_variables_initializer()) 134 | 135 | try: 136 | num_iters = 0 137 | while True: 138 | # Generate new batch and perform SGD update 139 | x, y = generate_batch(min_val=MIN_VAL, max_val=MAX_VAL, 140 | seq_length=FLAGS.sequence_length, 141 | batch_size=FLAGS.batch_size) 142 | sess.run([train_fn], feed_dict={samples: x, ground_truth: y}) 143 | num_iters += 1 144 | 145 | # Evaluate on validation data generated on the fly 146 | if num_iters % FLAGS.evaluate_every == 0: 147 | valid_error, valid_steps = 0., 0. 148 | for _ in range(FLAGS.validation_batches): 149 | valid_x, valid_y = generate_batch(min_val=MIN_VAL, max_val=MAX_VAL, 150 | seq_length=FLAGS.sequence_length, 151 | batch_size=FLAGS.batch_size) 152 | valid_iter_error, valid_used_inputs = sess.run( 153 | [mse, updated_states], 154 | feed_dict={ 155 | samples: valid_x, 156 | ground_truth: valid_y}) 157 | valid_error += valid_iter_error 158 | if valid_used_inputs is not None: 159 | valid_steps += compute_used_samples(valid_used_inputs) 160 | else: 161 | valid_steps += FLAGS.sequence_length 162 | valid_error /= FLAGS.validation_batches 163 | valid_steps /= FLAGS.validation_batches 164 | 165 | valid_writer.add_summary(scalar_summary('error', valid_error), num_iters) 166 | valid_writer.add_summary(scalar_summary('used_samples', valid_steps / FLAGS.sequence_length), num_iters) 167 | valid_writer.flush() 168 | 169 | print("Iteration %d, " 170 | "validation error: %.7f, " 171 | "validation samples: %.2f%%" % (num_iters, 172 | valid_error, 173 | 100. 
* valid_steps / FLAGS.sequence_length)) 174 | except KeyboardInterrupt: 175 | pass 176 | 177 | 178 | def main(argv=None): 179 | print_setup(task_setup) 180 | train() 181 | 182 | 183 | if __name__ == '__main__': 184 | tf.app.run() 185 | -------------------------------------------------------------------------------- /src/02_frequency_discrimination_task.py: -------------------------------------------------------------------------------- 1 | """ 2 | Train RNN models on the frequency discrimination task. Sine waves with period in [1, 100] are randomly generated and 3 | the network has to classify those with period in [5, 6]. 4 | 5 | Batches are stratified. Validation is performed on data generated on the fly. 6 | """ 7 | 8 | from __future__ import absolute_import 9 | from __future__ import print_function 10 | 11 | import os 12 | import datetime 13 | 14 | import random 15 | import numpy as np 16 | 17 | import tensorflow as tf 18 | import tensorflow.contrib.layers as layers 19 | 20 | from util.misc import * 21 | from util.graph_definition import * 22 | 23 | # Task-independent flags 24 | create_generic_flags() 25 | 26 | # Task-specific flags 27 | tf.app.flags.DEFINE_float('sampling_period', 1., "Sampling period, in milliseconds") 28 | tf.app.flags.DEFINE_float('signal_duration', 100., "Signal duration, in milliseconds") 29 | tf.app.flags.DEFINE_integer('validation_batches', 15, "How many batches to use for validation metrics.") 30 | tf.app.flags.DEFINE_integer('evaluate_every', 300, "How often is the model evaluated.") 31 | 32 | FLAGS = tf.app.flags.FLAGS 33 | 34 | # Constants 35 | START_PERIOD = 0 36 | END_PERIOD = 100 37 | START_TARGET_PERIOD = 5 38 | END_TARGET_PERIOD = 6 39 | INPUT_SIZE = 1 40 | OUTPUT_SIZE = 2 41 | SEQUENCE_LENGTH = int(FLAGS.signal_duration / FLAGS.sampling_period) 42 | 43 | 44 | def task_setup(): 45 | print('\tSignal duration: %.1fms' % FLAGS.signal_duration) 46 | print('\tSampling period: %.1fms' % FLAGS.sampling_period) 47 | print('\tSequence length: %d' % SEQUENCE_LENGTH) 48 | print('\tTarget periods: (%.0f, %.0f)' % (START_TARGET_PERIOD, END_TARGET_PERIOD)) 49 | print('\tDistractor periods: (%.0f, %.0f) U (%.0f, %.0f)' % (START_PERIOD, START_TARGET_PERIOD, 50 | END_TARGET_PERIOD, END_PERIOD)) 51 | 52 | 53 | def generate_example(t, frequency, phase_shift): 54 | return np.cos(2 * np.pi * frequency * t + phase_shift) 55 | 56 | 57 | def random_disjoint_interval(start, end, avoid_start, avoid_end): 58 | """ 59 | Sample a value in [start, avoid_start] U [avoid_end, end] with uniform probability 60 | """ 61 | val = random.uniform(start, end - (avoid_end - avoid_start)) 62 | if val > avoid_start: 63 | val += (avoid_end - avoid_start) 64 | return val 65 | 66 | 67 | def generate_batch(batch_size, sampling_period, signal_duration, start_period, end_period, 68 | start_target_period, end_target_period): 69 | """ 70 | Generate a stratified batch of examples. 
There are two classes: 71 | class 0: sine waves with period in [start_target_period, end_target_period] 72 | class 1: sine waves with period in [start_period, start_target_period] U [end_target_period, end_period] 73 | :param batch_size: number of samples per batch 74 | :param sampling_period: sampling period in milliseconds 75 | :param signal_duration: duration of the sine waves in milliseconds 76 | 77 | :return x: batch of examples 78 | :return y: batch of labels 79 | """ 80 | seq_length = int(signal_duration / sampling_period) 81 | 82 | n_elems = 1 83 | x = np.empty((batch_size, seq_length, n_elems)) 84 | y = np.empty(batch_size, dtype=np.int64) 85 | 86 | t = np.linspace(0, signal_duration - sampling_period, seq_length) 87 | 88 | for idx in range(int(batch_size/2)): 89 | period = random.uniform(start_target_period, end_target_period) 90 | phase_shift = random.uniform(0, period) 91 | x[idx, :, 0] = generate_example(t, 1./period, phase_shift) 92 | y[idx] = 0 93 | for idx in range(int(batch_size/2), batch_size): 94 | period = random_disjoint_interval(start_period, end_period, 95 | start_target_period, end_target_period) 96 | phase_shift = random.uniform(0, period) 97 | x[idx, :, 0] = generate_example(t, 1./period, phase_shift) 98 | y[idx] = 1 99 | return x, y 100 | 101 | 102 | def train(): 103 | samples = tf.placeholder(tf.float32, [None, None, INPUT_SIZE]) # (batch, time, in) 104 | ground_truth = tf.placeholder(tf.int64, [None]) # (batch, out) 105 | 106 | cell, initial_state = create_model(model=FLAGS.model, 107 | num_cells=[FLAGS.rnn_cells] * FLAGS.rnn_layers, 108 | batch_size=FLAGS.batch_size) 109 | 110 | rnn_outputs, rnn_states = tf.nn.dynamic_rnn(cell, samples, dtype=tf.float32, initial_state=initial_state) 111 | 112 | # Split the outputs of the RNN into the actual outputs and the state update gate 113 | rnn_outputs, updated_states = split_rnn_outputs(FLAGS.model, rnn_outputs) 114 | 115 | out = layers.linear(inputs=rnn_outputs[:, -1, :], num_outputs=OUTPUT_SIZE) 116 | 117 | # Compute cross-entropy loss 118 | cross_entropy_per_sample = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=out, labels=ground_truth) 119 | cross_entropy = tf.reduce_mean(cross_entropy_per_sample) 120 | 121 | # Compute accuracy 122 | accuracy = tf.reduce_mean(tf.cast(tf.equal(tf.argmax(out, 1), ground_truth), tf.float32)) 123 | 124 | # Compute loss for each updated state 125 | budget_loss = compute_budget_loss(FLAGS.model, cross_entropy, updated_states, FLAGS.cost_per_sample) 126 | 127 | # Combine all losses 128 | loss = cross_entropy + budget_loss 129 | 130 | # Optimizer 131 | opt, grads_and_vars = compute_gradients(loss, FLAGS.learning_rate, FLAGS.grad_clip) 132 | train_fn = opt.apply_gradients(grads_and_vars) 133 | 134 | sess = tf.Session() 135 | 136 | log_dir = os.path.join(FLAGS.logdir, datetime.datetime.now().strftime("%Y%m%d-%H%M%S")) 137 | valid_writer = tf.summary.FileWriter(log_dir + '/val') 138 | 139 | sess.run(tf.global_variables_initializer()) 140 | 141 | try: 142 | num_iters = 0 143 | while True: 144 | # Generate new batch and perform SGD update 145 | x, y = generate_batch(FLAGS.batch_size, 146 | FLAGS.sampling_period, 147 | FLAGS.signal_duration, 148 | START_PERIOD, END_PERIOD, 149 | START_TARGET_PERIOD, END_TARGET_PERIOD) 150 | sess.run([train_fn], feed_dict={samples: x, ground_truth: y}) 151 | num_iters += 1 152 | 153 | # Evaluate on validation data generated on the fly 154 | if num_iters % FLAGS.evaluate_every == 0: 155 | valid_accuracy, valid_steps = 0., 0. 
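# Accumulate accuracy and the number of attended samples (state updates) over several
# freshly generated validation batches; both are averaged after the loop.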
156 | for _ in range(FLAGS.validation_batches): 157 | valid_x, valid_y = generate_batch(FLAGS.batch_size, 158 | FLAGS.sampling_period, 159 | FLAGS.signal_duration, 160 | START_PERIOD, END_PERIOD, 161 | START_TARGET_PERIOD, END_TARGET_PERIOD) 162 | valid_iter_accuracy, valid_used_inputs = sess.run( 163 | [accuracy, updated_states], 164 | feed_dict={ 165 | samples: valid_x, 166 | ground_truth: valid_y}) 167 | valid_accuracy += valid_iter_accuracy 168 | if valid_used_inputs is not None: 169 | valid_steps += compute_used_samples(valid_used_inputs) 170 | else: 171 | valid_steps += SEQUENCE_LENGTH 172 | valid_accuracy /= FLAGS.validation_batches 173 | valid_steps /= FLAGS.validation_batches 174 | 175 | valid_writer.add_summary(scalar_summary('accuracy', valid_accuracy), num_iters) 176 | valid_writer.add_summary(scalar_summary('used_samples', valid_steps / SEQUENCE_LENGTH), num_iters) 177 | valid_writer.flush() 178 | 179 | print("Iteration %d, " 180 | "validation accuracy: %.2f%%, " 181 | "validation samples: %.2f (%.2f%%)" % (num_iters, 182 | 100. * valid_accuracy, 183 | valid_steps, 184 | 100. * valid_steps / SEQUENCE_LENGTH)) 185 | except KeyboardInterrupt: 186 | pass 187 | 188 | 189 | def main(argv=None): 190 | print_setup(task_setup) 191 | train() 192 | 193 | 194 | if __name__ == '__main__': 195 | tf.app.run() 196 | -------------------------------------------------------------------------------- /src/03_sequential_mnist.py: -------------------------------------------------------------------------------- 1 | """ 2 | Train RNN models on sequential MNIST, where inputs are processed pixel by pixel. 3 | 4 | Results should be reported by evaluating on the test set the model with the best performance on the validation set. 5 | To avoid storing checkpoints and having a separate evaluation script, this script evaluates on both validation and 6 | test set after every epoch. 
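A validation split of VALIDATION_SAMPLES images is held out from the official MNIST training set.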
7 | """ 8 | 9 | from __future__ import absolute_import 10 | from __future__ import print_function 11 | 12 | import os 13 | import time 14 | import datetime 15 | 16 | import tensorflow as tf 17 | import tensorflow.contrib.layers as layers 18 | import tensorflow_datasets as tfds 19 | 20 | from util.misc import * 21 | from util.graph_definition import * 22 | 23 | # Task-independent flags 24 | create_generic_flags() 25 | 26 | # Task-specific flags 27 | tf.app.flags.DEFINE_string('data_path', '../data', "Path where the MNIST data will be stored.") 28 | 29 | FLAGS = tf.app.flags.FLAGS 30 | 31 | # Constants 32 | OUTPUT_SIZE = 10 33 | SEQUENCE_LENGTH = 784 34 | VALIDATION_SAMPLES = 5000 35 | NUM_EPOCHS = 600 36 | 37 | # Load data 38 | mnist_builder = tfds.builder('mnist', data_dir=FLAGS.data_path) 39 | mnist_builder.download_and_prepare() 40 | info = mnist_builder.info 41 | 42 | TRAIN_SAMPLES = info.splits[tfds.Split.TRAIN].num_examples - VALIDATION_SAMPLES 43 | TEST_SAMPLES = info.splits[tfds.Split.TEST].num_examples 44 | 45 | ITERATIONS_PER_EPOCH = int(TRAIN_SAMPLES / FLAGS.batch_size) 46 | VAL_ITERS = int(VALIDATION_SAMPLES / FLAGS.batch_size) 47 | TEST_ITERS = int(TEST_SAMPLES / FLAGS.batch_size) 48 | 49 | 50 | def input_fn(split): 51 | train_split, valid_split = tfds.Split.TRAIN.subsplit([TRAIN_SAMPLES, VALIDATION_SAMPLES]) 52 | if split == 'train': 53 | dataset = mnist_builder.as_dataset(as_supervised=True, split=train_split) 54 | elif split == 'val': 55 | dataset = mnist_builder.as_dataset(as_supervised=True, split=valid_split) 56 | elif split == 'test': 57 | dataset = mnist_builder.as_dataset(as_supervised=True, split=tfds.Split.TEST) 58 | else: 59 | raise ValueError() 60 | 61 | def preprocess(x, y): 62 | x = tf.cast(x, tf.float32) / 255.0 63 | return x, y 64 | 65 | dataset = dataset.map(preprocess) 66 | dataset = dataset.repeat() 67 | dataset = dataset.batch(FLAGS.batch_size) 68 | dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE) 69 | 70 | iterator = dataset.make_initializable_iterator() 71 | images, labels = iterator.get_next() 72 | iterator_init_op = iterator.initializer 73 | 74 | inputs = {'images': images, 'labels': labels, 'iterator_init_op': iterator_init_op} 75 | return inputs 76 | 77 | 78 | def model_fn(mode, inputs, reuse=False): 79 | samples = tf.reshape(inputs['images'], (-1, SEQUENCE_LENGTH, 1)) 80 | ground_truth = tf.cast(inputs['labels'], tf.int64) 81 | 82 | is_training = (mode == 'train') 83 | 84 | with tf.variable_scope('model', reuse=reuse): 85 | cell, initial_state = create_model(model=FLAGS.model, 86 | num_cells=[FLAGS.rnn_cells] * FLAGS.rnn_layers, 87 | batch_size=FLAGS.batch_size) 88 | 89 | rnn_outputs, rnn_states = tf.nn.dynamic_rnn(cell, samples, dtype=tf.float32, initial_state=initial_state) 90 | 91 | # Split the outputs of the RNN into the actual outputs and the state update gate 92 | rnn_outputs, updated_states = split_rnn_outputs(FLAGS.model, rnn_outputs) 93 | 94 | logits = layers.linear(inputs=rnn_outputs[:, -1, :], num_outputs=OUTPUT_SIZE) 95 | predictions = tf.argmax(logits, 1) 96 | 97 | # Compute cross-entropy loss 98 | cross_entropy_per_sample = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits, labels=ground_truth) 99 | cross_entropy = tf.reduce_mean(cross_entropy_per_sample) 100 | 101 | # Compute accuracy 102 | accuracy = tf.reduce_mean(tf.cast(tf.equal(predictions, ground_truth), tf.float32)) 103 | 104 | # Compute loss for each updated state 105 | budget_loss = compute_budget_loss(FLAGS.model, cross_entropy, updated_states, 
FLAGS.cost_per_sample) 106 | 107 | # Combine all losses 108 | loss = cross_entropy + budget_loss 109 | loss = tf.reshape(loss, []) 110 | 111 | if is_training: 112 | # Optimizer 113 | opt, grads_and_vars = compute_gradients(loss, FLAGS.learning_rate, FLAGS.grad_clip) 114 | train_fn = opt.apply_gradients(grads_and_vars) 115 | 116 | model_spec = inputs 117 | model_spec['variable_init_op'] = tf.global_variables_initializer() 118 | model_spec['samples'] = samples 119 | model_spec['labels'] = ground_truth 120 | model_spec['loss'] = loss 121 | model_spec['accuracy'] = accuracy 122 | model_spec['updated_states'] = updated_states 123 | 124 | if is_training: 125 | model_spec['train_fn'] = train_fn 126 | 127 | return model_spec 128 | 129 | 130 | def train(): 131 | train_inputs = input_fn(split='train') 132 | valid_inputs = input_fn(split='val') 133 | test_inputs = input_fn(split='test') 134 | 135 | train_model_spec = model_fn('train', train_inputs) 136 | valid_model_spec = model_fn('val', valid_inputs, reuse=True) 137 | test_model_spec = model_fn('test', test_inputs, reuse=True) 138 | 139 | sess = tf.Session() 140 | 141 | log_dir = os.path.join(FLAGS.logdir, datetime.datetime.now().strftime("%Y%m%d-%H%M%S")) 142 | valid_writer = tf.summary.FileWriter(log_dir + '/val') 143 | test_writer = tf.summary.FileWriter(log_dir + '/test') 144 | 145 | # Initialize weights 146 | sess.run(train_model_spec['variable_init_op']) 147 | 148 | try: 149 | for epoch in range(NUM_EPOCHS): 150 | train_fn = train_model_spec['train_fn'] 151 | 152 | # Load the training dataset into the pipeline 153 | sess.run(train_model_spec['iterator_init_op']) 154 | 155 | start_time = time.time() 156 | for iteration in range(ITERATIONS_PER_EPOCH): 157 | # Perform SGD update 158 | sess.run([train_fn]) 159 | duration = time.time() - start_time 160 | 161 | # Evaluate on validation data 162 | accuracy = valid_model_spec['accuracy'] 163 | loss = valid_model_spec['loss'] 164 | updated_states = valid_model_spec['updated_states'] 165 | 166 | # Load the validation dataset into the pipeline 167 | sess.run(valid_model_spec['iterator_init_op']) 168 | 169 | valid_accuracy, valid_loss, valid_steps = 0, 0, 0 170 | for _ in range(VAL_ITERS): 171 | valid_iter_accuracy, valid_iter_loss, valid_used_inputs = sess.run([accuracy, loss, updated_states]) 172 | valid_loss += valid_iter_loss 173 | valid_accuracy += valid_iter_accuracy 174 | if valid_used_inputs is not None: 175 | valid_steps += compute_used_samples(valid_used_inputs) 176 | else: 177 | valid_steps += SEQUENCE_LENGTH 178 | valid_accuracy /= VAL_ITERS 179 | valid_loss /= VAL_ITERS 180 | valid_steps /= VAL_ITERS 181 | 182 | valid_writer.add_summary(scalar_summary('accuracy', valid_accuracy), epoch) 183 | valid_writer.add_summary(scalar_summary('loss', valid_loss), epoch) 184 | valid_writer.add_summary(scalar_summary('used_samples', valid_steps / SEQUENCE_LENGTH), epoch) 185 | valid_writer.flush() 186 | 187 | # Evaluate on test data 188 | accuracy = test_model_spec['accuracy'] 189 | loss = test_model_spec['loss'] 190 | updated_states = test_model_spec['updated_states'] 191 | 192 | # Load the test dataset into the pipeline 193 | sess.run(test_model_spec['iterator_init_op']) 194 | 195 | test_accuracy, test_loss, test_steps = 0, 0, 0 196 | for _ in range(TEST_ITERS): 197 | test_iter_accuracy, test_iter_loss, test_used_inputs = sess.run([accuracy, loss, updated_states]) 198 | test_accuracy += test_iter_accuracy 199 | test_loss += test_iter_loss 200 | if test_used_inputs is not None: 201 | test_steps += 
compute_used_samples(test_used_inputs) 202 | else: 203 | test_steps += SEQUENCE_LENGTH 204 | test_accuracy /= TEST_ITERS 205 | test_loss /= TEST_ITERS 206 | test_steps /= TEST_ITERS 207 | 208 | test_writer.add_summary(scalar_summary('accuracy', test_accuracy), epoch) 209 | test_writer.add_summary(scalar_summary('loss', test_loss), epoch) 210 | test_writer.add_summary(scalar_summary('used_samples', test_steps / SEQUENCE_LENGTH), epoch) 211 | test_writer.flush() 212 | 213 | print("Epoch %d/%d, " 214 | "duration: %.2f seconds, " 215 | "validation accuracy: %.2f%%, " 216 | "validation samples: %.2f (%.2f%%), " 217 | "test accuracy: %.2f%%, " 218 | "test samples: %.2f (%.2f%%)" % (epoch + 1, 219 | NUM_EPOCHS, 220 | duration, 221 | 100. * valid_accuracy, 222 | valid_steps, 223 | 100. * valid_steps / SEQUENCE_LENGTH, 224 | 100. * test_accuracy, 225 | test_steps, 226 | 100. * test_steps / SEQUENCE_LENGTH)) 227 | except KeyboardInterrupt: 228 | pass 229 | 230 | 231 | def main(argv=None): 232 | print_setup() 233 | train() 234 | 235 | 236 | if __name__ == '__main__': 237 | tf.app.run() 238 | -------------------------------------------------------------------------------- /src/rnn_cells/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/imatge-upc/skiprnn-2017-telecombcn/63f93a539a3f2c7a713089fdd2c38bb7b0c581ca/src/rnn_cells/__init__.py -------------------------------------------------------------------------------- /src/rnn_cells/basic_rnn_cells.py: -------------------------------------------------------------------------------- 1 | """ 2 | Extended version of the BasicLSTMCell and BasicGRUCell in TensorFlow that allows to easily add custom inits, 3 | normalization, etc. 4 | """ 5 | 6 | from __future__ import absolute_import 7 | from __future__ import print_function 8 | 9 | import tensorflow as tf 10 | 11 | from rnn_cells import rnn_ops 12 | 13 | 14 | class BasicLSTMCell(tf.contrib.rnn.RNNCell): 15 | """ 16 | Basic LSTM recurrent network cell. 17 | The implementation is based on: http://arxiv.org/abs/1409.2329. 18 | We add forget_bias (default: 1) to the biases of the forget gate in order to 19 | reduce the scale of forgetting in the beginning of the training. 20 | """ 21 | 22 | def __init__(self, num_units, forget_bias=1.0, activation=tf.tanh, layer_norm=False): 23 | """ 24 | Initialize the basic LSTM cell 25 | :param num_units: int, the number of units in the LSTM cell 26 | :param forget_bias: float, the bias added to forget gates 27 | :param activation: activation function of the inner states 28 | :param layer_norm: bool, whether to use layer normalization 29 | """ 30 | self._num_units = num_units 31 | self._forget_bias = forget_bias 32 | self._activation = activation 33 | self._layer_norm = layer_norm 34 | 35 | @property 36 | def state_size(self): 37 | return tf.contrib.rnn.LSTMStateTuple(self._num_units, self._num_units) 38 | 39 | @property 40 | def output_size(self): 41 | return self._num_units 42 | 43 | def __call__(self, inputs, state, scope=None): 44 | """Long short-term memory cell (LSTM).""" 45 | with tf.variable_scope(scope or type(self).__name__): 46 | c, h = state 47 | 48 | # Parameters of gates are concatenated into one multiply for efficiency. 
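# A single linear map over [inputs, h] produces the pre-activations of the four gates (i, j, f, o), each of size num_units.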
49 | concat = rnn_ops.linear([inputs, h], 4 * self._num_units, True) 50 | 51 | # i = input_gate, j = new_input, f = forget_gate, o = output_gate 52 | i, j, f, o = tf.split(value=concat, num_or_size_splits=4, axis=1) 53 | 54 | if self._layer_norm: 55 | i = rnn_ops.layer_norm(i, name="i") 56 | j = rnn_ops.layer_norm(j, name="j") 57 | f = rnn_ops.layer_norm(f, name="f") 58 | o = rnn_ops.layer_norm(o, name="o") 59 | 60 | new_c = (c * tf.sigmoid(f + self._forget_bias) + tf.sigmoid(i) * 61 | self._activation(j)) 62 | new_h = self._activation(new_c) * tf.sigmoid(o) 63 | 64 | new_state = tf.contrib.rnn.LSTMStateTuple(new_c, new_h) 65 | return new_h, new_state 66 | 67 | def trainable_initial_state(self, batch_size): 68 | """ 69 | Create a trainable initial state for the BasicLSTMCell 70 | :param batch_size: number of samples per batch 71 | :return: LSTMStateTuple 72 | """ 73 | def _create_initial_state(batch_size, state_size, trainable=True, initializer=tf.random_normal_initializer()): 74 | with tf.device('/cpu:0'): 75 | s = tf.get_variable('initial_state', shape=[1, state_size], dtype=tf.float32, trainable=trainable, 76 | initializer=initializer) 77 | state = tf.tile(s, tf.stack([batch_size] + [1])) 78 | return state 79 | 80 | with tf.variable_scope('initial_c'): 81 | initial_c = _create_initial_state(batch_size, self._num_units) 82 | with tf.variable_scope('initial_h'): 83 | initial_h = _create_initial_state(batch_size, self._num_units) 84 | return tf.contrib.rnn.LSTMStateTuple(initial_c, initial_h) 85 | 86 | 87 | class BasicGRUCell(tf.contrib.rnn.RNNCell): 88 | """ 89 | Gated Recurrent Unit cell. 90 | The implementation is based on http://arxiv.org/abs/1406.1078. 91 | """ 92 | 93 | def __init__(self, num_units, activation=tf.tanh, layer_norm=False): 94 | """ 95 | Initialize the basic GRU cell 96 | :param num_units: int, the number of units in the LSTM cell 97 | :param activation: activation function of the inner states 98 | :param layer_norm: bool, whether to use layer normalization 99 | """ 100 | self._num_units = num_units 101 | self._activation = activation 102 | self._layer_norm = layer_norm 103 | 104 | @property 105 | def state_size(self): 106 | return self._num_units 107 | 108 | @property 109 | def output_size(self): 110 | return self._num_units 111 | 112 | def __call__(self, inputs, state, scope=None): 113 | """Gated recurrent unit (GRU) with num_units cells.""" 114 | with tf.variable_scope(scope or type(self).__name__): 115 | with tf.variable_scope("gates"): # Reset gate and update gate. 116 | # We start with bias of 1.0 to not reset and not update. 
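# A single linear map over [inputs, state] produces the pre-activations of the reset gate r and the update gate u.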
117 | concat = rnn_ops.linear([inputs, state], 2 * self._num_units, True, bias_start=1.0) 118 | r, u = tf.split(value=concat, num_or_size_splits=2, axis=1) 119 | 120 | if self._layer_norm: 121 | r = rnn_ops.layer_norm(r, name="r") 122 | u = rnn_ops.layer_norm(u, name="u") 123 | 124 | # Apply non-linearity after layer normalization 125 | r = tf.sigmoid(r) 126 | u = tf.sigmoid(u) 127 | 128 | with tf.variable_scope("candidate"): 129 | c = self._activation(rnn_ops.linear([inputs, r * state], self._num_units, True)) 130 | new_h = u * state + (1 - u) * c 131 | return new_h, new_h 132 | 133 | def trainable_initial_state(self, batch_size): 134 | """ 135 | Create a trainable initial state for the BasicGRUCell 136 | :param batch_size: number of samples per batch 137 | :return: tensor with shape [batch_size, self.state_size] 138 | """ 139 | with tf.variable_scope('initial_h'): 140 | initial_h = rnn_ops.create_initial_state(batch_size, self._num_units) 141 | return initial_h 142 | -------------------------------------------------------------------------------- /src/rnn_cells/rnn_ops.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import print_function 3 | 4 | import tensorflow as tf 5 | 6 | from tensorflow.python.util import nest 7 | 8 | 9 | def get_variable(name, shape, initializer=None, dtype=tf.float32, device=None): 10 | """ 11 | Helper to create a Variable stored on CPU memory. 12 | Args: 13 | name: name of the variable 14 | shape: list of ints 15 | initializer: initializer for Variable 16 | dtype: data type, defaults to tf.float32 17 | device: device to which the variable will be pinned 18 | Returns: 19 | Variable Tensor 20 | """ 21 | if device is None: 22 | device = '/cpu:0' 23 | if initializer is None: 24 | with tf.device(device): 25 | var = tf.get_variable(name, shape, dtype=dtype) 26 | else: 27 | with tf.device(device): 28 | var = tf.get_variable(name, shape, initializer=initializer, dtype=dtype) 29 | return var 30 | 31 | 32 | def linear(args, output_size, bias, weights_init=None, bias_start=0.0): 33 | """ 34 | Linear map: sum_i(args[i] * W[i]), where W[i] is a variable. 35 | Args: 36 | args: a 2D Tensor or a list of 2D, batch x n, Tensors. 37 | output_size: int, second dimension of W[i]. 38 | bias: boolean, whether to add a bias term or not. 39 | weights_init: initializer for the weights. 40 | bias_start: starting value to initialize the gates bias; 0 by default. 41 | Returns: 42 | A 2D Tensor with shape [batch x output_size] equal to 43 | sum_i(args[i] * W[i]), where W[i]s are newly created matrices. 44 | Raises: 45 | ValueError: if some of the arguments has unspecified or wrong shape. 46 | """ 47 | if args is None or (nest.is_sequence(args) and not args): 48 | raise ValueError("`args` must be specified") 49 | if not nest.is_sequence(args): 50 | args = [args] 51 | 52 | # Calculate the total size of arguments on dimension 1. 53 | total_arg_size = 0 54 | shapes = [a.get_shape() for a in args] 55 | for shape in shapes: 56 | if shape.ndims != 2: 57 | raise ValueError("linear is expecting 2D arguments: %s" % shapes) 58 | if shape[1].value is None: 59 | raise ValueError("linear expects shape[1] to be provided for shape %s, but saw %s" % (shape, shape[1])) 60 | else: 61 | total_arg_size += shape[1].value 62 | 63 | dtype = [a.dtype for a in args][0] 64 | 65 | # Now the computation. 
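# Concatenate all inputs along the feature axis, multiply by a single weight matrix and, if requested, add a bias initialized to bias_start.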
66 | scope = tf.get_variable_scope() 67 | with tf.variable_scope(scope) as outer_scope: 68 | weights = get_variable("Weights", [total_arg_size, output_size], initializer=weights_init) 69 | if len(args) == 1: 70 | res = tf.matmul(args[0], weights) 71 | else: 72 | res = tf.matmul(tf.concat(args, 1), weights) 73 | if not bias: 74 | return res 75 | with tf.variable_scope(outer_scope) as inner_scope: 76 | inner_scope.set_partitioner(None) 77 | biases = get_variable('Biases', [output_size], initializer=tf.constant_initializer(bias_start, dtype=dtype)) 78 | return tf.nn.bias_add(res, biases) 79 | 80 | 81 | def create_initial_state(batch_size, state_size, trainable=True, initializer=tf.random_normal_initializer()): 82 | with tf.device('/cpu:0'): 83 | s = tf.get_variable('initial_state', shape=[1, state_size], dtype=tf.float32, trainable=trainable, 84 | initializer=initializer) 85 | state = tf.tile(s, tf.stack([batch_size] + [1])) 86 | return state 87 | 88 | 89 | def layer_norm(x, axes=1, initial_bias_value=0.0, epsilon=1e-3, name="var"): 90 | """ 91 | Apply layer normalization to x 92 | Args: 93 | x: input variable. 94 | initial_bias_value: initial value for the LN bias. 95 | epsilon: small constant value to avoid division by zero. 96 | scope: scope or name for the LN op. 97 | Returns: 98 | LN(x) with same shape as x 99 | """ 100 | if not isinstance(axes, list): 101 | axes = [axes] 102 | 103 | scope = tf.get_variable_scope() 104 | with tf.variable_scope(scope): 105 | with tf.variable_scope(name): 106 | mean = tf.reduce_mean(x, axes, keep_dims=True) 107 | variance = tf.sqrt(tf.reduce_mean(tf.square(x - mean), axes, keep_dims=True)) 108 | 109 | with tf.device('/cpu:0'): 110 | gain = tf.get_variable('gain', x.get_shape().as_list()[1:], 111 | initializer=tf.constant_initializer(1.0)) 112 | bias = tf.get_variable('bias', x.get_shape().as_list()[1:], 113 | initializer=tf.constant_initializer(initial_bias_value)) 114 | 115 | return gain * (x - mean) / (variance + epsilon) + bias -------------------------------------------------------------------------------- /src/rnn_cells/skip_rnn_cells.py: -------------------------------------------------------------------------------- 1 | """ 2 | Skip RNN cells that decide which timesteps should be attended. 3 | """ 4 | 5 | from __future__ import absolute_import 6 | from __future__ import print_function 7 | 8 | import collections 9 | import tensorflow as tf 10 | 11 | from rnn_cells import rnn_ops 12 | 13 | from tensorflow.python.framework import ops 14 | 15 | 16 | SkipLSTMStateTuple = collections.namedtuple("SkipLSTMStateTuple", ("c", "h", "update_prob", "cum_update_prob")) 17 | SkipLSTMOutputTuple = collections.namedtuple("SkipLSTMOutputTuple", ("h", "state_gate")) 18 | LSTMStateTuple = tf.nn.rnn_cell.LSTMStateTuple 19 | 20 | SkipGRUStateTuple = collections.namedtuple("SkipGRUStateTuple", ("h", "update_prob", "cum_update_prob")) 21 | SkipGRUOutputTuple = collections.namedtuple("SkipGRUOutputTuple", ("h", "state_gate")) 22 | 23 | 24 | def _binary_round(x): 25 | """ 26 | Rounds a tensor whose values are in [0,1] to a tensor with values in {0, 1}, 27 | using the straight through estimator for the gradient. 
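In the forward pass the values are rounded to {0, 1}; in the backward pass the Round op is overridden with Identity, so gradients flow through the rounding as if it were the identity function.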
28 | 29 | Based on http://r2rt.com/binary-stochastic-neurons-in-tensorflow.html 30 | 31 | :param x: input tensor 32 | :return: y=round(x) with gradients defined by the identity mapping (y=x) 33 | """ 34 | g = tf.get_default_graph() 35 | 36 | with ops.name_scope("BinaryRound") as name: 37 | with g.gradient_override_map({"Round": "Identity"}): 38 | return tf.round(x, name=name) 39 | 40 | 41 | class SkipLSTMCell(tf.nn.rnn_cell.RNNCell): 42 | """ 43 | Single Skip LSTM cell. Augments the basic LSTM cell with a binary output that decides whether to 44 | update or copy the cell state. The binary neuron is optimized using the Straight Through Estimator. 45 | """ 46 | def __init__(self, num_units, forget_bias=1.0, activation=tf.tanh, layer_norm=False, update_bias=1.0): 47 | """ 48 | Initialize the Skip LSTM cell 49 | :param num_units: int, the number of units in the LSTM cell 50 | :param forget_bias: float, the bias added to forget gates 51 | :param activation: activation function of the inner states 52 | :param layer_norm: bool, whether to use layer normalization 53 | :param update_bias: float, initial value for the bias added to the update state gate 54 | """ 55 | self._num_units = num_units 56 | self._forget_bias = forget_bias 57 | self._activation = activation 58 | self._layer_norm = layer_norm 59 | self._update_bias = update_bias 60 | 61 | @property 62 | def state_size(self): 63 | return SkipLSTMStateTuple(self._num_units, self._num_units, 1, 1) 64 | 65 | @property 66 | def output_size(self): 67 | return SkipLSTMOutputTuple(self._num_units, 1) 68 | 69 | def __call__(self, inputs, state, scope=None): 70 | with tf.variable_scope(scope or type(self).__name__): 71 | c_prev, h_prev, update_prob_prev, cum_update_prob_prev = state 72 | 73 | # Parameters of gates are concatenated into one multiply for efficiency. 74 | concat = rnn_ops.linear([inputs, h_prev], 4 * self._num_units, True) 75 | 76 | # i = input_gate, j = new_input, f = forget_gate, o = output_gate 77 | i, j, f, o = tf.split(value=concat, num_or_size_splits=4, axis=1) 78 | 79 | if self._layer_norm: 80 | i = rnn_ops.layer_norm(i, name="i") 81 | j = rnn_ops.layer_norm(j, name="j") 82 | f = rnn_ops.layer_norm(f, name="f") 83 | o = rnn_ops.layer_norm(o, name="o") 84 | 85 | new_c_tilde = (c_prev * tf.sigmoid(f + self._forget_bias) + tf.sigmoid(i) * self._activation(j)) 86 | new_h_tilde = self._activation(new_c_tilde) * tf.sigmoid(o) 87 | 88 | # Compute value for the update prob 89 | with tf.variable_scope('state_update_prob'): 90 | new_update_prob_tilde = rnn_ops.linear(new_c_tilde, 1, True, bias_start=self._update_bias) 91 | new_update_prob_tilde = tf.sigmoid(new_update_prob_tilde) 92 | 93 | # Compute value for the update gate 94 | cum_update_prob = cum_update_prob_prev + tf.minimum(update_prob_prev, 1. - cum_update_prob_prev) 95 | update_gate = _binary_round(cum_update_prob) 96 | 97 | # Apply update gate 98 | new_c = update_gate * new_c_tilde + (1. - update_gate) * c_prev 99 | new_h = update_gate * new_h_tilde + (1. - update_gate) * h_prev 100 | new_update_prob = update_gate * new_update_prob_tilde + (1. - update_gate) * update_prob_prev 101 | new_cum_update_prob = update_gate * 0. + (1. 
- update_gate) * cum_update_prob 102 | 103 | new_state = SkipLSTMStateTuple(new_c, new_h, new_update_prob, new_cum_update_prob) 104 | new_output = SkipLSTMOutputTuple(new_h, update_gate) 105 | 106 | return new_output, new_state 107 | 108 | def trainable_initial_state(self, batch_size): 109 | """ 110 | Create a trainable initial state for the SkipLSTMCell 111 | :param batch_size: number of samples per batch 112 | :return: SkipLSTMStateTuple 113 | """ 114 | with tf.variable_scope('initial_c'): 115 | initial_c = rnn_ops.create_initial_state(batch_size, self._num_units) 116 | with tf.variable_scope('initial_h'): 117 | initial_h = rnn_ops.create_initial_state(batch_size, self._num_units) 118 | with tf.variable_scope('initial_update_prob'): 119 | initial_update_prob = rnn_ops.create_initial_state(batch_size, 1, trainable=False, 120 | initializer=tf.ones_initializer()) 121 | with tf.variable_scope('initial_cum_update_prob'): 122 | initial_cum_update_prob = rnn_ops.create_initial_state(batch_size, 1, trainable=False, 123 | initializer=tf.zeros_initializer()) 124 | return SkipLSTMStateTuple(initial_c, initial_h, initial_update_prob, initial_cum_update_prob) 125 | 126 | 127 | class MultiSkipLSTMCell(tf.nn.rnn_cell.RNNCell): 128 | """ 129 | Stack of Skip LSTM cells. The selection binary output is computed from the state of the cell on top of 130 | the stack. 131 | """ 132 | def __init__(self, num_units, forget_bias=1.0, activation=tf.tanh, layer_norm=False, update_bias=1.0): 133 | """ 134 | Initialize the stack of Skip LSTM cells 135 | :param num_units: list of int, the number of units in each LSTM cell 136 | :param forget_bias: float, the bias added to forget gates 137 | :param activation: activation function of the inner states 138 | :param layer_norm: bool, whether to use layer normalization 139 | :param update_bias: float, initial value for the bias added to the update state gate 140 | """ 141 | if not isinstance(num_units, list): 142 | num_units = [num_units] 143 | self._num_units = num_units 144 | self._num_layers = len(self._num_units) 145 | self._forget_bias = forget_bias 146 | self._activation = activation 147 | self._layer_norm = layer_norm 148 | self._update_bias = update_bias 149 | 150 | @property 151 | def state_size(self): 152 | return [LSTMStateTuple(num_units, num_units) for num_units in self._num_units[:-1]] + \ 153 | [SkipLSTMStateTuple(self._num_units[-1], self._num_units[:-1], 1, 1)] 154 | 155 | @property 156 | def output_size(self): 157 | return SkipLSTMOutputTuple(self._num_units[-1], 1) 158 | 159 | def __call__(self, inputs, state, scope=None): 160 | with tf.variable_scope(scope or type(self).__name__): 161 | update_prob_prev, cum_update_prob_prev = state[-1].update_prob, state[-1].cum_update_prob 162 | cell_input = inputs 163 | state_candidates = [] 164 | 165 | # Compute update candidates for all layers 166 | for idx in range(self._num_layers): 167 | with tf.variable_scope('layer_%d' % (idx + 1)): 168 | c_prev, h_prev = state[idx].c, state[idx].h 169 | 170 | # Parameters of gates are concatenated into one multiply for efficiency. 
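# Each layer computes its gate pre-activations from its own previous hidden state and the output of the layer below (the cell input for the first layer).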
171 | concat = rnn_ops.linear([cell_input, h_prev], 4 * self._num_units[idx], True) 172 | 173 | # i = input_gate, j = new_input, f = forget_gate, o = output_gate 174 | i, j, f, o = tf.split(value=concat, num_or_size_splits=4, axis=1) 175 | 176 | if self._layer_norm: 177 | i = rnn_ops.layer_norm(i, name="i") 178 | j = rnn_ops.layer_norm(j, name="j") 179 | f = rnn_ops.layer_norm(f, name="f") 180 | o = rnn_ops.layer_norm(o, name="o") 181 | 182 | new_c_tilde = (c_prev * tf.sigmoid(f + self._forget_bias) + tf.sigmoid(i) * self._activation(j)) 183 | new_h_tilde = self._activation(new_c_tilde) * tf.sigmoid(o) 184 | 185 | state_candidates.append(LSTMStateTuple(new_c_tilde, new_h_tilde)) 186 | cell_input = new_h_tilde 187 | 188 | # Compute value for the update prob 189 | with tf.variable_scope('state_update_prob'): 190 | new_update_prob_tilde = rnn_ops.linear(state_candidates[-1].c, 1, True, bias_start=self._update_bias) 191 | new_update_prob_tilde = tf.sigmoid(new_update_prob_tilde) 192 | 193 | # Compute value for the update gate 194 | cum_update_prob = cum_update_prob_prev + tf.minimum(update_prob_prev, 1. - cum_update_prob_prev) 195 | update_gate = _binary_round(cum_update_prob) 196 | 197 | # Apply update gate 198 | new_states = [] 199 | for idx in range(self._num_layers - 1): 200 | new_c = update_gate * state_candidates[idx].c + (1. - update_gate) * state[idx].c 201 | new_h = update_gate * state_candidates[idx].h + (1. - update_gate) * state[idx].h 202 | new_states.append(LSTMStateTuple(new_c, new_h)) 203 | new_c = update_gate * state_candidates[-1].c + (1. - update_gate) * state[-1].c 204 | new_h = update_gate * state_candidates[-1].h + (1. - update_gate) * state[-1].h 205 | new_update_prob = update_gate * new_update_prob_tilde + (1. - update_gate) * update_prob_prev 206 | new_cum_update_prob = update_gate * 0. + (1. 
- update_gate) * cum_update_prob 207 | 208 | new_states.append(SkipLSTMStateTuple(new_c, new_h, new_update_prob, new_cum_update_prob)) 209 | new_output = SkipLSTMOutputTuple(new_h, update_gate) 210 | 211 | return new_output, new_states 212 | 213 | def trainable_initial_state(self, batch_size): 214 | """ 215 | Create a trainable initial state for the MultiSkipLSTMCell 216 | :param batch_size: number of samples per batch 217 | :return: list of SkipLSTMStateTuple 218 | """ 219 | initial_states = [] 220 | for idx in range(self._num_layers - 1): 221 | with tf.variable_scope('layer_%d' % (idx + 1)): 222 | with tf.variable_scope('initial_c'): 223 | initial_c = rnn_ops.create_initial_state(batch_size, self._num_units[idx]) 224 | with tf.variable_scope('initial_h'): 225 | initial_h = rnn_ops.create_initial_state(batch_size, self._num_units[idx]) 226 | initial_states.append(LSTMStateTuple(initial_c, initial_h)) 227 | with tf.variable_scope('layer_%d' % self._num_layers): 228 | with tf.variable_scope('initial_c'): 229 | initial_c = rnn_ops.create_initial_state(batch_size, self._num_units[-1]) 230 | with tf.variable_scope('initial_h'): 231 | initial_h = rnn_ops.create_initial_state(batch_size, self._num_units[-1]) 232 | with tf.variable_scope('initial_update_prob'): 233 | initial_update_prob = rnn_ops.create_initial_state(batch_size, 1, trainable=False, 234 | initializer=tf.ones_initializer()) 235 | with tf.variable_scope('initial_cum_update_prob'): 236 | initial_cum_update_prob = rnn_ops.create_initial_state(batch_size, 1, trainable=False, 237 | initializer=tf.zeros_initializer()) 238 | initial_states.append(SkipLSTMStateTuple(initial_c, initial_h, 239 | initial_update_prob, initial_cum_update_prob)) 240 | return initial_states 241 | 242 | 243 | class SkipGRUCell(tf.nn.rnn_cell.RNNCell): 244 | """ 245 | Single Skip GRU cell. Augments the basic GRU cell with a binary output that decides whether to 246 | update or copy the cell state. The binary neuron is optimized using the Straight Through Estimator. 247 | """ 248 | def __init__(self, num_units, activation=tf.tanh, layer_norm=False, update_bias=1.0): 249 | """ 250 | Initialize the Skip GRU cell 251 | :param num_units: int, the number of units in the GRU cell 252 | :param activation: activation function of the inner states 253 | :param layer_norm: bool, whether to use layer normalization 254 | :param update_bias: float, initial value for the bias added to the update state gate 255 | """ 256 | self._num_units = num_units 257 | self._activation = activation 258 | self._layer_norm = layer_norm 259 | self._update_bias = update_bias 260 | 261 | @property 262 | def state_size(self): 263 | return SkipGRUStateTuple(self._num_units, 1, 1) 264 | 265 | @property 266 | def output_size(self): 267 | return SkipGRUOutputTuple(self._num_units, 1) 268 | 269 | def __call__(self, inputs, state, scope=None): 270 | with tf.variable_scope(scope or type(self).__name__): 271 | h_prev, update_prob_prev, cum_update_prob_prev = state 272 | 273 | # Parameters of gates are concatenated into one multiply for efficiency. 
274 | with tf.variable_scope("gates"): 275 | concat = rnn_ops.linear([inputs, h_prev], 2 * self._num_units, bias=True, bias_start=1.0) 276 | 277 | # r = reset_gate, u = update_gate 278 | r, u = tf.split(value=concat, num_or_size_splits=2, axis=1) 279 | 280 | if self._layer_norm: 281 | r = rnn_ops.layer_norm(r, name="r") 282 | u = rnn_ops.layer_norm(u, name="u") 283 | 284 | # Apply non-linearity after layer normalization 285 | r = tf.sigmoid(r) 286 | u = tf.sigmoid(u) 287 | 288 | with tf.variable_scope("candidate"): 289 | new_c_tilde = self._activation(rnn_ops.linear([inputs, r * h_prev], self._num_units, True)) 290 | new_h_tilde = u * h_prev + (1 - u) * new_c_tilde 291 | 292 | # Compute value for the update prob 293 | with tf.variable_scope('state_update_prob'): 294 | new_update_prob_tilde = rnn_ops.linear(new_h_tilde, 1, True, bias_start=self._update_bias) 295 | new_update_prob_tilde = tf.sigmoid(new_update_prob_tilde) 296 | 297 | # Compute value for the update gate 298 | cum_update_prob = cum_update_prob_prev + tf.minimum(update_prob_prev, 1. - cum_update_prob_prev) 299 | update_gate = _binary_round(cum_update_prob) 300 | 301 | # Apply update gate 302 | new_h = update_gate * new_h_tilde + (1. - update_gate) * h_prev 303 | new_update_prob = update_gate * new_update_prob_tilde + (1. - update_gate) * update_prob_prev 304 | new_cum_update_prob = update_gate * 0. + (1. - update_gate) * cum_update_prob 305 | 306 | new_state = SkipGRUStateTuple(new_h, new_update_prob, new_cum_update_prob) 307 | new_output = SkipGRUOutputTuple(new_h, update_gate) 308 | 309 | return new_output, new_state 310 | 311 | def trainable_initial_state(self, batch_size): 312 | """ 313 | Create a trainable initial state for the SkipGRUCell 314 | :param batch_size: number of samples per batch 315 | :return: SkipGRUStateTuple 316 | """ 317 | with tf.variable_scope('initial_h'): 318 | initial_h = rnn_ops.create_initial_state(batch_size, self._num_units) 319 | with tf.variable_scope('initial_update_prob'): 320 | initial_update_prob = rnn_ops.create_initial_state(batch_size, 1, trainable=False, 321 | initializer=tf.ones_initializer()) 322 | with tf.variable_scope('initial_cum_update_prob'): 323 | initial_cum_update_prob = rnn_ops.create_initial_state(batch_size, 1, trainable=False, 324 | initializer=tf.zeros_initializer()) 325 | return SkipGRUStateTuple(initial_h, initial_update_prob, initial_cum_update_prob) 326 | 327 | 328 | class MultiSkipGRUCell(tf.nn.rnn_cell.RNNCell): 329 | """ 330 | Stack of Skip GRU cells. The selection binary output is computed from the state of the cell on top of 331 | the stack. 
332 | """ 333 | def __init__(self, num_units, activation=tf.tanh, layer_norm=False, update_bias=1.0): 334 | """ 335 | Initialize the stack of Skip GRU cells 336 | :param num_units: list of int, the number of units in each GRU cell 337 | :param activation: activation function of the inner states 338 | :param layer_norm: bool, whether to use layer normalization 339 | :param update_bias: float, initial value for the bias added to the update state gate 340 | """ 341 | if not isinstance(num_units, list): 342 | num_units = [num_units] 343 | self._num_units = num_units 344 | self._num_layers = len(self._num_units) 345 | self._activation = activation 346 | self._layer_norm = layer_norm 347 | self._update_bias = update_bias 348 | 349 | @property 350 | def state_size(self): 351 | return [num_units for num_units in self._num_units[:-1]] + [SkipGRUStateTuple(self._num_units[-1], 1, 1)] 352 | 353 | @property 354 | def output_size(self): 355 | return SkipGRUOutputTuple(self._num_units[-1], 1) 356 | 357 | def __call__(self, inputs, state, scope=None): 358 | with tf.variable_scope(scope or type(self).__name__): 359 | update_prob_prev, cum_update_prob_prev = state[-1].update_prob, state[-1].cum_update_prob 360 | cell_input = inputs 361 | state_candidates = [] 362 | 363 | # Compute update candidates for all layers 364 | for idx in range(self._num_layers): 365 | with tf.variable_scope('layer_%d' % (idx + 1)): 366 | if isinstance(state[idx], SkipGRUStateTuple): 367 | h_prev = state[idx].h 368 | else: 369 | h_prev = state[idx] 370 | 371 | # Parameters of gates are concatenated into one multiply for efficiency. 372 | with tf.variable_scope("gates"): 373 | concat = rnn_ops.linear([cell_input, h_prev], 2 * self._num_units[idx], bias=True, bias_start=1.0,) 374 | 375 | # r = reset_gate, u = update_gate 376 | r, u = tf.split(value=concat, num_or_size_splits=2, axis=1) 377 | 378 | if self._layer_norm: 379 | r = rnn_ops.layer_norm(r, name="r") 380 | u = rnn_ops.layer_norm(u, name="u") 381 | 382 | # Apply non-linearity after layer normalization 383 | r = tf.sigmoid(r) 384 | u = tf.sigmoid(u) 385 | 386 | with tf.variable_scope("candidate"): 387 | new_c_tilde = self._activation(rnn_ops.linear([inputs, r * h_prev], self._num_units[idx], True)) 388 | new_h_tilde = u * h_prev + (1 - u) * new_c_tilde 389 | 390 | state_candidates.append(new_h_tilde) 391 | cell_input = new_h_tilde 392 | 393 | # Compute value for the update prob 394 | with tf.variable_scope('state_update_prob'): 395 | new_update_prob_tilde = rnn_ops.linear(state_candidates[-1], 1, True, bias_start=self._update_bias) 396 | new_update_prob_tilde = tf.sigmoid(new_update_prob_tilde) 397 | 398 | # Compute value for the update gate 399 | cum_update_prob = cum_update_prob_prev + tf.minimum(update_prob_prev, 1. - cum_update_prob_prev) 400 | update_gate = _binary_round(cum_update_prob) 401 | 402 | # Apply update gate 403 | new_states = [] 404 | for idx in range(self._num_layers - 1): 405 | new_h = update_gate * state_candidates[idx] + (1. - update_gate) * state[idx] 406 | new_states.append(new_h) 407 | new_h = update_gate * state_candidates[-1] + (1. - update_gate) * state[-1].h 408 | new_update_prob = update_gate * new_update_prob_tilde + (1. - update_gate) * update_prob_prev 409 | new_cum_update_prob = update_gate * 0. + (1. 
- update_gate) * cum_update_prob 410 | 411 | new_states.append(SkipGRUStateTuple(new_h, new_update_prob, new_cum_update_prob)) 412 | new_output = SkipGRUOutputTuple(new_h, update_gate) 413 | 414 | return new_output, new_states 415 | 416 | def trainable_initial_state(self, batch_size): 417 | """ 418 | Create a trainable initial state for the MultiSkipGRUCell 419 | :param batch_size: number of samples per batch 420 | :return: list of tensors and SkipGRUStateTuple 421 | """ 422 | initial_states = [] 423 | for idx in range(self._num_layers - 1): 424 | with tf.variable_scope('layer_%d' % (idx + 1)): 425 | with tf.variable_scope('initial_h'): 426 | initial_h = rnn_ops.create_initial_state(batch_size, self._num_units[idx]) 427 | initial_states.append(initial_h) 428 | with tf.variable_scope('layer_%d' % self._num_layers): 429 | with tf.variable_scope('initial_h'): 430 | initial_h = rnn_ops.create_initial_state(batch_size, self._num_units[-1]) 431 | with tf.variable_scope('initial_update_prob'): 432 | initial_update_prob = rnn_ops.create_initial_state(batch_size, 1, trainable=False, 433 | initializer=tf.ones_initializer()) 434 | with tf.variable_scope('initial_cum_update_prob'): 435 | initial_cum_update_prob = rnn_ops.create_initial_state(batch_size, 1, trainable=False, 436 | initializer=tf.zeros_initializer()) 437 | initial_states.append(SkipGRUStateTuple(initial_h, initial_update_prob, initial_cum_update_prob)) 438 | return initial_states 439 | -------------------------------------------------------------------------------- /src/util/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/imatge-upc/skiprnn-2017-telecombcn/63f93a539a3f2c7a713089fdd2c38bb7b0c581ca/src/util/__init__.py -------------------------------------------------------------------------------- /src/util/graph_definition.py: -------------------------------------------------------------------------------- 1 | """ 2 | Graph creation functions. 3 | """ 4 | 5 | 6 | from __future__ import print_function 7 | from __future__ import absolute_import 8 | 9 | import tensorflow as tf 10 | 11 | from rnn_cells.basic_rnn_cells import BasicLSTMCell, BasicGRUCell 12 | from rnn_cells.skip_rnn_cells import SkipLSTMCell, MultiSkipLSTMCell 13 | from rnn_cells.skip_rnn_cells import SkipGRUCell, MultiSkipGRUCell 14 | 15 | 16 | def create_generic_flags(): 17 | """ 18 | Create flags which are shared by all experiments 19 | """ 20 | # Generic flags 21 | tf.app.flags.DEFINE_string('model', 'lstm', "Select RNN cell: {lstm, gru, skip_lstm, skip_gru}") 22 | tf.app.flags.DEFINE_integer("rnn_cells", 110, "Number of RNN cells.") 23 | tf.app.flags.DEFINE_integer("rnn_layers", 1, "Number of RNN layers.") 24 | tf.app.flags.DEFINE_integer('batch_size', 256, "Batch size.") 25 | tf.app.flags.DEFINE_float('learning_rate', 0.0001, "Learning rate.") 26 | tf.app.flags.DEFINE_float('grad_clip', 1., "Clip gradients at this value. Set to <=0 to disable clipping.") 27 | tf.app.flags.DEFINE_string('logdir', '../logs', "Directory where TensorBoard logs will be stored.") 28 | 29 | # Flags for the Skip RNN cells 30 | tf.app.flags.DEFINE_float('cost_per_sample', 0., "Cost per used sample. 
Set to 0 to disable this option.") 31 | 32 | 33 | def compute_gradients(loss, learning_rate, gradient_clipping=-1): 34 | """ 35 | Create optimizer, compute gradients and (optionally) apply gradient clipping 36 | """ 37 | opt = tf.train.AdamOptimizer(learning_rate) 38 | if gradient_clipping > 0: 39 | vars_to_optimize = tf.trainable_variables() 40 | grads, _ = tf.clip_by_global_norm(tf.gradients(loss, vars_to_optimize), clip_norm=gradient_clipping) 41 | grads_and_vars = list(zip(grads, vars_to_optimize)) 42 | else: 43 | grads_and_vars = opt.compute_gradients(loss) 44 | return opt, grads_and_vars 45 | 46 | 47 | def create_model(model, num_cells, batch_size, learn_initial_state=True): 48 | """ 49 | Returns a tuple of (cell, initial_state) to use with dynamic_rnn. 50 | If num_cells is an integer, a single RNN cell will be created. If it is a list, a stack of len(num_cells) 51 | cells will be created. 52 | """ 53 | if not model in ['lstm', 'gru', 'skip_lstm', 'skip_gru']: 54 | raise ValueError('The specified model is not supported. Please use {lstm, gru, skip_lstm, skip_gru}.') 55 | if isinstance(num_cells, list) and len(num_cells) > 1: 56 | if model == 'skip_lstm': 57 | cells = MultiSkipLSTMCell(num_cells) 58 | elif model == 'skip_gru': 59 | cells = MultiSkipGRUCell(num_cells) 60 | elif model == 'lstm': 61 | cell_list = [BasicLSTMCell(n) for n in num_cells] 62 | cells = tf.contrib.rnn.MultiRNNCell(cell_list) 63 | elif model == 'gru': 64 | cell_list = [BasicGRUCell(n) for n in num_cells] 65 | cells = tf.contrib.rnn.MultiRNNCell(cell_list) 66 | if learn_initial_state: 67 | if model == 'skip_lstm' or model == 'skip_gru': 68 | initial_state = cells.trainable_initial_state(batch_size) 69 | else: 70 | initial_state = [] 71 | for idx, cell in enumerate(cell_list): 72 | with tf.variable_scope('layer_%d' % (idx + 1)): 73 | initial_state.append(cell.trainable_initial_state(batch_size)) 74 | initial_state = tuple(initial_state) 75 | else: 76 | initial_state = None 77 | return cells, initial_state 78 | else: 79 | if isinstance(num_cells, list): 80 | num_cells = num_cells[0] 81 | if model == 'skip_lstm': 82 | cell = SkipLSTMCell(num_cells) 83 | elif model == 'skip_gru': 84 | cell = SkipGRUCell(num_cells) 85 | elif model == 'lstm': 86 | cell = BasicLSTMCell(num_cells) 87 | elif model == 'gru': 88 | cell = BasicGRUCell(num_cells) 89 | if learn_initial_state: 90 | initial_state = cell.trainable_initial_state(batch_size) 91 | else: 92 | initial_state = None 93 | return cell, initial_state 94 | 95 | 96 | def using_skip_rnn(model): 97 | """ 98 | Helper function determining whether a Skip RNN models is being used 99 | """ 100 | return model.lower() == 'skip_lstm' or model.lower() == 'skip_gru' 101 | 102 | 103 | def split_rnn_outputs(model, rnn_outputs): 104 | """ 105 | Split the output of dynamic_rnn into the actual RNN outputs and the state update gate 106 | """ 107 | if using_skip_rnn(model): 108 | return rnn_outputs.h, rnn_outputs.state_gate 109 | else: 110 | return rnn_outputs, tf.no_op() 111 | 112 | 113 | def compute_budget_loss(model, loss, updated_states, cost_per_sample): 114 | """ 115 | Compute penalization term on the number of updated states (i.e. 
used samples) 116 | """ 117 | if using_skip_rnn(model): 118 | return tf.reduce_mean(tf.reduce_sum(cost_per_sample * updated_states, 1), 0) 119 | else: 120 | return tf.zeros(loss.get_shape()) 121 | -------------------------------------------------------------------------------- /src/util/misc.py: -------------------------------------------------------------------------------- 1 | """ 2 | Generic functions that are used in different scripts. 3 | """ 4 | 5 | from __future__ import absolute_import 6 | from __future__ import print_function 7 | 8 | import types 9 | from decimal import Decimal 10 | 11 | import tensorflow as tf 12 | 13 | FLAGS = tf.app.flags.FLAGS 14 | 15 | 16 | def print_setup(task_specific_setup=None): 17 | """ 18 | Print experimental setup 19 | :param task_specific_setup: (optional) function printing task-specific parameters 20 | """ 21 | model_dict = {'lstm': 'LSTM', 'gru': 'GRU', 'skip_lstm': 'SkipLSTM', 'skip_gru': 'SkipGRU'} 22 | print('\n\n\tExperimental setup') 23 | print('\t------------------\n') 24 | print('\tModel: %s' % model_dict[FLAGS.model.lower()]) 25 | print('\tNumber of layers: %d' % FLAGS.rnn_layers) 26 | print('\tNumber of cells: %d' % FLAGS.rnn_cells) 27 | print('\tBatch size: %d' % FLAGS.batch_size) 28 | print('\tLearning rate: %.2E' % Decimal(FLAGS.learning_rate)) 29 | 30 | if FLAGS.grad_clip > 0: 31 | print('\tGradient clipping: %.1f' % FLAGS.grad_clip) 32 | else: 33 | print('\tGradient clipping: No') 34 | 35 | if FLAGS.model.lower().startswith('skip'): 36 | print('\tCost per sample: %.2E' % Decimal(FLAGS.cost_per_sample)) 37 | 38 | if isinstance(task_specific_setup, types.FunctionType): 39 | print('') 40 | task_specific_setup() 41 | 42 | print('\n\n') 43 | 44 | 45 | def compute_used_samples(update_state_gate): 46 | """ 47 | Compute number of used samples (i.e. number of updated states) 48 | :param update_state_gate: values for the update state gate 49 | :return: number of used samples 50 | """ 51 | batch_size = update_state_gate.shape[0] 52 | steps = 0. 53 | for idx in range(batch_size): 54 | for idt in range(update_state_gate.shape[1]): 55 | steps += update_state_gate[idx, idt] 56 | return steps / batch_size 57 | 58 | 59 | def scalar_summary(name, value): 60 | summary = tf.summary.Summary() 61 | summary_value = summary.value.add() 62 | summary_value.simple_value = value 63 | summary_value.tag = name 64 | return summary 65 | --------------------------------------------------------------------------------
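The snippet below is a minimal usage sketch, not part of the repository files above: it illustrates how `create_model`, `split_rnn_outputs`, `compute_budget_loss` and `compute_gradients` from `src/util/graph_definition.py` could be wired together with `tf.nn.dynamic_rnn`. It assumes TensorFlow 1.x and that it is run from `src/` (so that `util` and `rnn_cells` are importable); the placeholder tensors, the toy classification head and the hyperparameter values are purely illustrative.

```python
import tensorflow as tf

from util.graph_definition import create_model, split_rnn_outputs, compute_budget_loss, compute_gradients

# Illustrative placeholder sizes and hyperparameters (not taken from the paper)
batch_size, seq_length, input_dim, num_classes = 256, 100, 1, 10
model = 'skip_lstm'        # one of {lstm, gru, skip_lstm, skip_gru}
cost_per_sample = 1e-4     # weight of the budget penalty on state updates
learning_rate, grad_clip = 1e-4, 1.

samples = tf.placeholder(tf.float32, [batch_size, seq_length, input_dim])
targets = tf.placeholder(tf.int64, [batch_size])

# Build the RNN cell (single or stacked) and its trainable initial state
cell, initial_state = create_model(model, num_cells=[110], batch_size=batch_size)
rnn_outputs, _ = tf.nn.dynamic_rnn(cell, samples, initial_state=initial_state, dtype=tf.float32)

# For Skip RNN cells, dynamic_rnn emits (h, state_gate) tuples; split them apart
outputs, updated_states = split_rnn_outputs(model, rnn_outputs)

# Toy classification head on the last hidden state, plus the budget loss
logits = tf.layers.dense(outputs[:, -1, :], num_classes)
task_loss = tf.reduce_mean(
    tf.nn.sparse_softmax_cross_entropy_with_logits(labels=targets, logits=logits))
loss = task_loss + compute_budget_loss(model, task_loss, updated_states, cost_per_sample)

# Adam with optional global-norm gradient clipping, as defined in graph_definition.py
opt, grads_and_vars = compute_gradients(loss, learning_rate, grad_clip)
train_op = opt.apply_gradients(grads_and_vars)
```

Because `split_rnn_outputs` separates the hidden states from the binary update gates only when a Skip RNN cell is selected, the same training code can penalize the number of state updates through `compute_budget_loss` while leaving the standard LSTM/GRU baselines unchanged.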