├── .gitignore ├── LICENSE ├── README.md ├── assets ├── Controller-DFD.png ├── DNC-DFD.png ├── DNC-dynamic-mem.png ├── DNC-scalable.png ├── UML.png ├── allocation_weighting.png ├── babi-training.png ├── model-series-2-curve.png ├── model-series-4-curve.png └── model-single-curve.png ├── dnc ├── __init__.py ├── controller.py ├── dnc.py ├── memory.py └── utility.py ├── docs ├── basic-usage.md ├── data-flow.md └── implementation-notes.md ├── run_tests.sh ├── tasks ├── babi │ ├── README.md │ ├── checkpoints │ │ └── step-500005 │ │ │ ├── checkpoint │ │ │ ├── model.ckpt │ │ │ └── model.ckpt.meta │ ├── dnc │ ├── logs │ │ ├── events.out.tfevents.1483079568.ip-172-31-15-24 │ │ ├── events.out.tfevents.1483218179.ip-172-31-15-24 │ │ ├── events.out.tfevents.1483353893.ip-172-31-15-24 │ │ ├── events.out.tfevents.1483491250.ip-172-31-15-24 │ │ └── events.out.tfevents.1483628556.ip-172-31-15-24 │ ├── preprocess.py │ ├── recurrent_controller.py │ ├── test.py │ └── train.py └── copy │ ├── README.md │ ├── checkpoints │ ├── model-series-2 │ │ ├── checkpoint │ │ ├── model.ckpt │ │ └── model.ckpt.meta │ ├── model-series-4 │ │ ├── checkpoint │ │ ├── model.ckpt │ │ └── model.ckpt.meta │ └── model-single-10 │ │ ├── checkpoint │ │ ├── model.ckpt │ │ └── model.ckpt.meta │ ├── dnc │ ├── feedforward_controller.py │ ├── train-series.py │ ├── train.py │ └── visualization.ipynb └── unit-tests ├── controller.py ├── dnc ├── dnc.py ├── memory.py └── utility.py /.gitignore: -------------------------------------------------------------------------------- 1 | # compiled python files 2 | *.pyc 3 | 4 | # tensorboard logs directories 5 | logs 6 | 7 | # IPython checkpoints directories 8 | .ipynb_checkpoints 9 | 10 | # bAbI task data 11 | tasks/babi/data 12 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | The MIT License (MIT) 2 | 3 | Copyright (c) 2016 Mostafa-Samir 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # DNC TensorFlow 2 | 3 | This is a TensorFlow implementation of DeepMind's Differentiable Neural Computer (DNC) architecture introduced in their recent Nature paper: 4 | > [Graves, Alex, et al. "Hybrid computing using a neural network with dynamic external memory." 
Nature 538.7626 (2016): 471-476.](http://www.nature.com/articles/nature20101.epdf?author_access_token=ImTXBI8aWbYxYQ51Plys8NRgN0jAjWel9jnR3ZoTv0MggmpDmwljGswxVdeocYSurJ3hxupzWuRNeGvvXnoO8o4jTJcnAyhGuZzXJ1GEaD-Z7E6X_a9R-xqJ9TfJWBqz) 5 | 6 | This implementation doesn't include all the tasks that were described in the paper; it focuses on exploring and reproducing the general task-independent key characteristics of the architecture. However, the implementation was designed with extensibility in mind, so it's fairly simple to adapt it to further tasks. 7 | 8 | ## Local Environment Specification 9 | 10 | The copy experiments and tests were run on a machine with: 11 | - An Intel Core i5 2410M CPU @ 2.30GHz (2 physical cores, with hyper-threading enabled) 12 | - 4GB SO-DIMM DDR3 RAM @ 1333MHz 13 | - No GPU. 14 | - Ubuntu 14.04 LTS 15 | - TensorFlow r0.11 16 | - Python 2.7 17 | 18 | The bAbI experiment and tests were run on an AWS P2 instance with a single Tesla K80 GPU. 19 | 20 | ## Experiments 21 | 22 | ### Dynamic Memory Mechanisms 23 | 24 | This experiment is designed to demonstrate the various functionalities of the external memory access mechanisms, such as in-order retrieval and allocation/deallocation. 25 | 26 | A similar approach to that of the paper was followed: a 2-layer feedforward model with only 10 memory locations was trained on a copy task in which a series of 4 random binary sequences, each of size 6 (24 pieces of information in total), was presented as input. Details about the training can be found [here](tasks/copy/). 27 | 28 | The model was able to learn to copy the input successfully, and it indeed learned to use the mentioned memory mechanisms. The following figure (which resembles **Extended Data Figure 1** in the paper) illustrates that. 29 | 30 | *You can re-generate similar figures in the [visualization notebook](tasks/copy/visualization.ipynb)* 31 | 32 | ![DNC-Memory-Mechanisms](/assets/DNC-dynamic-mem.png) 33 | 34 | - In the **Memory Locations** part of the figure, it's apparent that the model is able to read the memory locations in the same order they were written to. 35 | 36 | - In the **Free Gate** and the **Allocation Gate** portions of the figure, it's shown that the free gates are fully activated after a memory location is read and becomes obsolete, while being less activated in the writing phase. The opposite is true for the allocation gate. The **Memory Locations Usage** also demonstrates how memory locations are used, freed, and re-used time after time. 37 | 38 | *The figure differs a little from the one in the paper when it comes to the activation degrees of the gates. This could be due to the small size of the model and the relatively short training time. However, this doesn't affect the operation of the model.* 39 | 40 | ### Generalization and Memory Scalability 41 | 42 | This experiment was designed to check: 43 | - if the trained model has learned an implicit copying algorithm that can be generalized to larger input lengths. 44 | - if the learned model is independent of the training memory size and can be scaled up with memories of larger sizes. 45 | 46 | To approach that, a 2-layer feedforward model with 15 memory locations was trained on a copy problem in which a single sequence of random binary vectors of lengths between 1 and 10 was presented as input. Details of the training process can be found [here](tasks/copy/).
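*For a concrete picture of the data layout used in these copy experiments, here is a minimal NumPy sketch of how such an input/target pair could be generated: the random binary sequence is presented first, followed by an end-of-sequence marker, and the target requires the model to reproduce the sequence during the remaining (silent) steps. This is only an illustrative approximation; the `generate_copy_sample` helper below is hypothetical, and the actual generation used for training lives in the copy task's scripts under `tasks/copy/`.*

```python
import numpy as np

def generate_copy_sample(batch_size, length, size):
    # Hypothetical helper, for illustration only; the repository's own generator
    # may differ in details such as the end-of-sequence marker or output width.
    input_size = size + 1  # one extra channel for the end-of-sequence marker

    sequence = np.random.binomial(1, 0.5, (batch_size, length, size)).astype(np.float32)

    inputs = np.zeros((batch_size, 2 * length + 1, input_size), dtype=np.float32)
    targets = np.zeros((batch_size, 2 * length + 1, size), dtype=np.float32)

    inputs[:, :length, :size] = sequence   # presentation phase
    inputs[:, length, size] = 1.0          # end-of-sequence marker
    targets[:, length + 1:, :] = sequence  # the model must reproduce the sequence

    return inputs, targets
```

*With `length` drawn uniformly between 1 and 10 and `size = 6`, this corresponds to the single-sequence setting described above.*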
46 | 47 | 48 | The model was then tested on pairs of increasing sequence lengths and increasing memory sizes, without re-training on any of these pairs of parameters, and the fraction of correctly copied sequences out of a batch of 100 was recorded. The model was indeed able to generalize and use the available memory locations effectively without retraining. This is depicted in the following figure, which resembles **Extended Data Figure 2** from the paper. 49 | 50 | *Similar figures can be re-generated in the [visualization notebook](tasks/copy/visualization.ipynb)* 51 | 52 | ![DNC-Scalability](/assets/DNC-scalable.png) 53 | 54 | ### bAbI Task 55 | 56 | This experiment was designed to reproduce the paper's results on the bAbI 20QA task. A model with the same parameters as DNC1 described in the paper (Extended Data Table 2) was trained on the **en-10k** dataset, and it achieved error percentages that *mostly* fell within 1 standard deviation of the means reported in the paper (Extended Data Table 1). The results, and their comparison to the paper's mean results, are shown in the following table. Details about training and reproduction can be found [here](tasks/babi/). 57 | 58 | | Task Name | Results | Paper's Mean | 59 | | --------- | ------- | ------------ | 60 | | single supporting fact | 0.00% | 9.0±12.6% | 61 | | two supporting facts | 11.88% | 39.2±20.5% | 62 | | three supporting facts | 27.80% | 39.6±16.4% | 63 | | two arg relations | 1.40% | 0.4±0.7% | 64 | | three arg relations | 1.70% | 1.5±1.0% | 65 | | yes no questions | 0.50% | 6.9±7.5% | 66 | | counting | 4.90% | 9.8±7.0% | 67 | | lists sets | 2.10% | 5.5±5.9% | 68 | | simple negation | 0.80% | 7.7±8.3% | 69 | | indefinite knowledge | 1.70% | 9.6±11.4% | 70 | | basic coreference | 0.10% | 3.3±5.7% | 71 | | conjunction | 0.00% | 5.0±6.3% | 72 | | compound coreference | 0.40% | 3.1±3.6% | 73 | | time reasoning | 11.80% | 11.0±7.5% | 74 | | basic deduction | 45.44% | 27.2±20.1% | 75 | | basic induction | 56.43% | 53.6±1.9% | 76 | | positional reasoning | 39.02% | 32.4±8.0% | 77 | | size reasoning | 8.68% | 4.2±1.8% | 78 | | path finding | 98.21% | 64.6±37.4% | 79 | | agents motivations | 2.71% | 0.0±0.1% | 80 | | **Mean Err.** | 15.78% | 16.7±7.6% | 81 | | **Failed (err. > 5%)** | 8 | 11.2±5.4 | 82 | 83 | ## Getting Involved 84 | 85 | If you're interested in using the implementation for new tasks, you should start by **[reading the structure and basic usage guide](docs/basic-usage.md)** to get comfortable with how the project is structured and how it can be extended to new tasks. 86 | 87 | If you intend to work with the source code of the implementation itself, you should begin by looking at **[the data flow diagrams](docs/data-flow.md)** to get a high-level overview of how the data moves from the input to the output across the modules of the implementation. This should ease you into reading the source code, which is reasonably well documented. 88 | 89 | You might also find the **[implementation notes](docs/implementation-notes.md)** helpful to clarify how some of the math is implemented. 90 | 91 | ## To-Do 92 | 93 | - **Core:** 94 | - Sparse link matrix. 95 | - Variable sequence lengths across the same batch. 96 | - **Tasks**: 97 | - ~~bAbI task.~~ 98 | - Graph inference tasks. 99 | - Mini-SHRDLU task. 100 | - **Utility**: 101 | - A task builder that abstracts away all details about iterations, learning rates, ...
etc into configurable command-line arguments and leaves the user only with the worries of defining the computational graph. 102 | 103 | ## Author 104 | Mostafa Samir 105 | 106 | [mostafa.3210@gmail.com](mailto:mostfa.3210@gmail.com) 107 | 108 | ## License 109 | MIT 110 | -------------------------------------------------------------------------------- /assets/Controller-DFD.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Mostafa-Samir/DNC-tensorflow/5280f5993d8692f21a86ffde2a032dc660dbb693/assets/Controller-DFD.png -------------------------------------------------------------------------------- /assets/DNC-DFD.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Mostafa-Samir/DNC-tensorflow/5280f5993d8692f21a86ffde2a032dc660dbb693/assets/DNC-DFD.png -------------------------------------------------------------------------------- /assets/DNC-dynamic-mem.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Mostafa-Samir/DNC-tensorflow/5280f5993d8692f21a86ffde2a032dc660dbb693/assets/DNC-dynamic-mem.png -------------------------------------------------------------------------------- /assets/DNC-scalable.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Mostafa-Samir/DNC-tensorflow/5280f5993d8692f21a86ffde2a032dc660dbb693/assets/DNC-scalable.png -------------------------------------------------------------------------------- /assets/UML.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Mostafa-Samir/DNC-tensorflow/5280f5993d8692f21a86ffde2a032dc660dbb693/assets/UML.png -------------------------------------------------------------------------------- /assets/allocation_weighting.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Mostafa-Samir/DNC-tensorflow/5280f5993d8692f21a86ffde2a032dc660dbb693/assets/allocation_weighting.png -------------------------------------------------------------------------------- /assets/babi-training.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Mostafa-Samir/DNC-tensorflow/5280f5993d8692f21a86ffde2a032dc660dbb693/assets/babi-training.png -------------------------------------------------------------------------------- /assets/model-series-2-curve.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Mostafa-Samir/DNC-tensorflow/5280f5993d8692f21a86ffde2a032dc660dbb693/assets/model-series-2-curve.png -------------------------------------------------------------------------------- /assets/model-series-4-curve.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Mostafa-Samir/DNC-tensorflow/5280f5993d8692f21a86ffde2a032dc660dbb693/assets/model-series-4-curve.png -------------------------------------------------------------------------------- /assets/model-single-curve.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Mostafa-Samir/DNC-tensorflow/5280f5993d8692f21a86ffde2a032dc660dbb693/assets/model-single-curve.png 
-------------------------------------------------------------------------------- /dnc/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Mostafa-Samir/DNC-tensorflow/5280f5993d8692f21a86ffde2a032dc660dbb693/dnc/__init__.py -------------------------------------------------------------------------------- /dnc/controller.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | 4 | class BaseController: 5 | 6 | def __init__(self, input_size, output_size, memory_read_heads, memory_word_size, batch_size=1): 7 | """ 8 | constructs a controller as described in the DNC paper: 9 | http://www.nature.com/nature/journal/vaop/ncurrent/full/nature20101.html 10 | 11 | Parameters: 12 | ---------- 13 | input_size: int 14 | the size of the data input vector 15 | output_size: int 16 | the size of the data output vector 17 | memory_read_heads: int 18 | the number of read haeds in the associated external memory 19 | memory_word_size: int 20 | the size of the word in the associated external memory 21 | batch_size: int 22 | the size of the input data batch [optional, usually set by the DNC object] 23 | """ 24 | 25 | self.input_size = input_size 26 | self.output_size = output_size 27 | self.read_heads = memory_read_heads 28 | self.word_size = memory_word_size 29 | self.batch_size = batch_size 30 | 31 | # indicates if the internal neural network is recurrent 32 | # by the existence of recurrent_update and get_state methods 33 | has_recurrent_update = callable(getattr(self, 'update_state', None)) 34 | has_get_state = callable(getattr(self, 'get_state', None)) 35 | self.has_recurrent_nn = has_recurrent_update and has_get_state 36 | 37 | # the actual size of the neural network input after flatenning and 38 | # concatenating the input vector with the previously read vctors from memory 39 | self.nn_input_size = self.word_size * self.read_heads + self.input_size 40 | 41 | self.interface_vector_size = self.word_size * self.read_heads + 3 * self.word_size + 5 * self.read_heads + 3 42 | 43 | # define network vars 44 | with tf.name_scope("controller"): 45 | self.network_vars() 46 | 47 | self.nn_output_size = None 48 | with tf.variable_scope("shape_inference"): 49 | self.nn_output_size = self.get_nn_output_size() 50 | 51 | self.initials() 52 | 53 | def initials(self): 54 | """ 55 | sets the initial values of the controller transformation weights matrices 56 | this method can be overwritten to use a different initialization scheme 57 | """ 58 | # defining internal weights of the controller 59 | self.interface_weights = tf.Variable( 60 | tf.random_normal([self.nn_output_size, self.interface_vector_size], stddev=0.1), 61 | name='interface_weights' 62 | ) 63 | self.nn_output_weights = tf.Variable( 64 | tf.random_normal([self.nn_output_size, self.output_size], stddev=0.1), 65 | name='nn_output_weights' 66 | ) 67 | self.mem_output_weights = tf.Variable( 68 | tf.random_normal([self.word_size * self.read_heads, self.output_size], stddev=0.1), 69 | name='mem_output_weights' 70 | ) 71 | 72 | def network_vars(self): 73 | """ 74 | defines the variables needed by the internal neural network 75 | [the variables should be attributes of the class, i.e. 
self.*] 76 | """ 77 | raise NotImplementedError("network_vars is not implemented") 78 | 79 | 80 | def network_op(self, X): 81 | """ 82 | defines the controller's internal neural network operation 83 | 84 | Parameters: 85 | ---------- 86 | X: Tensor (batch_size, word_size * read_haeds + input_size) 87 | the input data concatenated with the previously read vectors from memory 88 | 89 | Returns: Tensor (batch_size, nn_output_size) 90 | """ 91 | raise NotImplementedError("network_op method is not implemented") 92 | 93 | 94 | def get_nn_output_size(self): 95 | """ 96 | retrives the output size of the defined neural network 97 | 98 | Returns: int 99 | the output's size 100 | 101 | Raises: ValueError 102 | """ 103 | 104 | input_vector = np.zeros([self.batch_size, self.nn_input_size], dtype=np.float32) 105 | output_vector = None 106 | 107 | if self.has_recurrent_nn: 108 | output_vector,_ = self.network_op(input_vector, self.get_state()) 109 | else: 110 | output_vector = self.network_op(input_vector) 111 | 112 | shape = output_vector.get_shape().as_list() 113 | 114 | if len(shape) > 2: 115 | raise ValueError("Expected the neural network to output a 1D vector, but got %dD" % (len(shape) - 1)) 116 | else: 117 | return shape[1] 118 | 119 | 120 | def parse_interface_vector(self, interface_vector): 121 | """ 122 | pasres the flat interface_vector into its various components with their 123 | correct shapes 124 | 125 | Parameters: 126 | ---------- 127 | interface_vector: Tensor (batch_size, interface_vector_size) 128 | the flattened inetrface vector to be parsed 129 | 130 | Returns: dict 131 | a dictionary with the components of the interface_vector parsed 132 | """ 133 | 134 | parsed = {} 135 | 136 | r_keys_end = self.word_size * self.read_heads 137 | r_strengths_end = r_keys_end + self.read_heads 138 | w_key_end = r_strengths_end + self.word_size 139 | erase_end = w_key_end + 1 + self.word_size 140 | write_end = erase_end + self.word_size 141 | free_end = write_end + self.read_heads 142 | 143 | r_keys_shape = (-1, self.word_size, self.read_heads) 144 | r_strengths_shape = (-1, self.read_heads) 145 | w_key_shape = (-1, self.word_size, 1) 146 | write_shape = erase_shape = (-1, self.word_size) 147 | free_shape = (-1, self.read_heads) 148 | modes_shape = (-1, 3, self.read_heads) 149 | 150 | # parsing the vector into its individual components 151 | parsed['read_keys'] = tf.reshape(interface_vector[:, :r_keys_end], r_keys_shape) 152 | parsed['read_strengths'] = tf.reshape(interface_vector[:, r_keys_end:r_strengths_end], r_strengths_shape) 153 | parsed['write_key'] = tf.reshape(interface_vector[:, r_strengths_end:w_key_end], w_key_shape) 154 | parsed['write_strength'] = tf.reshape(interface_vector[:, w_key_end], (-1, 1)) 155 | parsed['erase_vector'] = tf.reshape(interface_vector[:, w_key_end + 1:erase_end], erase_shape) 156 | parsed['write_vector'] = tf.reshape(interface_vector[:, erase_end:write_end], write_shape) 157 | parsed['free_gates'] = tf.reshape(interface_vector[:, write_end:free_end], free_shape) 158 | parsed['allocation_gate'] = tf.expand_dims(interface_vector[:, free_end], 1) 159 | parsed['write_gate'] = tf.expand_dims(interface_vector[:, free_end + 1], 1) 160 | parsed['read_modes'] = tf.reshape(interface_vector[:, free_end + 2:], modes_shape) 161 | 162 | # transforming the components to ensure they're in the right ranges 163 | parsed['read_strengths'] = 1 + tf.nn.softplus(parsed['read_strengths']) 164 | parsed['write_strength'] = 1 + tf.nn.softplus(parsed['write_strength']) 165 | 
parsed['erase_vector'] = tf.nn.sigmoid(parsed['erase_vector']) 166 | parsed['free_gates'] = tf.nn.sigmoid(parsed['free_gates']) 167 | parsed['allocation_gate'] = tf.nn.sigmoid(parsed['allocation_gate']) 168 | parsed['write_gate'] = tf.nn.sigmoid(parsed['write_gate']) 169 | parsed['read_modes'] = tf.nn.softmax(parsed['read_modes'], 1) 170 | 171 | return parsed 172 | 173 | def process_input(self, X, last_read_vectors, state=None): 174 | """ 175 | processes input data through the controller network and returns the 176 | pre-output and interface_vector 177 | 178 | Parameters: 179 | ---------- 180 | X: Tensor (batch_size, input_size) 181 | the input data batch 182 | last_read_vectors: (batch_size, word_size, read_heads) 183 | the last batch of read vectors from memory 184 | state: Tuple 185 | state vectors if the network is recurrent 186 | 187 | Returns: Tuple 188 | pre-output: Tensor (batch_size, output_size) 189 | parsed_interface_vector: dict 190 | """ 191 | 192 | flat_read_vectors = tf.reshape(last_read_vectors, (-1, self.word_size * self.read_heads)) 193 | complete_input = tf.concat(1, [X, flat_read_vectors]) 194 | nn_output, nn_state = None, None 195 | 196 | if self.has_recurrent_nn: 197 | nn_output, nn_state = self.network_op(complete_input, state) 198 | else: 199 | nn_output = self.network_op(complete_input) 200 | 201 | pre_output = tf.matmul(nn_output, self.nn_output_weights) 202 | interface = tf.matmul(nn_output, self.interface_weights) 203 | parsed_interface = self.parse_interface_vector(interface) 204 | 205 | if self.has_recurrent_nn: 206 | return pre_output, parsed_interface, nn_state 207 | else: 208 | return pre_output, parsed_interface 209 | 210 | 211 | def final_output(self, pre_output, new_read_vectors): 212 | """ 213 | returns the final output by taking rececnt memory changes into account 214 | 215 | Parameters: 216 | ---------- 217 | pre_output: Tensor (batch_size, output_size) 218 | the ouput vector from the input processing step 219 | new_read_vectors: Tensor (batch_size, words_size, read_heads) 220 | the newly read vectors from the updated memory 221 | 222 | Returns: Tensor (batch_size, output_size) 223 | """ 224 | 225 | flat_read_vectors = tf.reshape(new_read_vectors, (-1, self.word_size * self.read_heads)) 226 | 227 | final_output = pre_output + tf.matmul(flat_read_vectors, self.mem_output_weights) 228 | 229 | return final_output 230 | -------------------------------------------------------------------------------- /dnc/dnc.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from tensorflow.python.ops.rnn_cell import LSTMStateTuple 3 | from memory import Memory 4 | import utility 5 | import os 6 | 7 | class DNC: 8 | 9 | def __init__(self, controller_class, input_size, output_size, max_sequence_length, 10 | memory_words_num = 256, memory_word_size = 64, memory_read_heads = 4, batch_size = 1): 11 | """ 12 | constructs a complete DNC architecture as described in the DNC paper 13 | http://www.nature.com/nature/journal/vaop/ncurrent/full/nature20101.html 14 | 15 | Parameters: 16 | ----------- 17 | controller_class: BaseController 18 | a concrete implementation of the BaseController class 19 | input_size: int 20 | the size of the input vector 21 | output_size: int 22 | the size of the output vector 23 | max_sequence_length: int 24 | the maximum length of an input sequence 25 | memory_words_num: int 26 | the number of words that can be stored in memory 27 | memory_word_size: int 28 | the size of an individual word 
in memory 29 | memory_read_heads: int 30 | the number of read heads in the memory 31 | batch_size: int 32 | the size of the data batch 33 | """ 34 | 35 | self.input_size = input_size 36 | self.output_size = output_size 37 | self.max_sequence_length = max_sequence_length 38 | self.words_num = memory_words_num 39 | self.word_size = memory_word_size 40 | self.read_heads = memory_read_heads 41 | self.batch_size = batch_size 42 | 43 | self.memory = Memory(self.words_num, self.word_size, self.read_heads, self.batch_size) 44 | self.controller = controller_class(self.input_size, self.output_size, self.read_heads, self.word_size, self.batch_size) 45 | 46 | # input data placeholders 47 | self.input_data = tf.placeholder(tf.float32, [batch_size, None, input_size], name='input') 48 | self.target_output = tf.placeholder(tf.float32, [batch_size, None, output_size], name='targets') 49 | self.sequence_length = tf.placeholder(tf.int32, name='sequence_length') 50 | 51 | self.build_graph() 52 | 53 | 54 | def _step_op(self, step, memory_state, controller_state=None): 55 | """ 56 | performs a step operation on the input step data 57 | 58 | Parameters: 59 | ---------- 60 | step: Tensor (batch_size, input_size) 61 | memory_state: Tuple 62 | a tuple of current memory parameters 63 | controller_state: Tuple 64 | the state of the controller if it's recurrent 65 | 66 | Returns: Tuple 67 | output: Tensor (batch_size, output_size) 68 | memory_view: dict 69 | """ 70 | 71 | last_read_vectors = memory_state[6] 72 | pre_output, interface, nn_state = None, None, None 73 | 74 | if self.controller.has_recurrent_nn: 75 | pre_output, interface, nn_state = self.controller.process_input(step, last_read_vectors, controller_state) 76 | else: 77 | pre_output, interface = self.controller.process_input(step, last_read_vectors) 78 | 79 | usage_vector, write_weighting, memory_matrix, link_matrix, precedence_vector = self.memory.write( 80 | memory_state[0], memory_state[1], memory_state[5], 81 | memory_state[4], memory_state[2], memory_state[3], 82 | interface['write_key'], 83 | interface['write_strength'], 84 | interface['free_gates'], 85 | interface['allocation_gate'], 86 | interface['write_gate'], 87 | interface['write_vector'], 88 | interface['erase_vector'] 89 | ) 90 | 91 | read_weightings, read_vectors = self.memory.read( 92 | memory_matrix, 93 | memory_state[5], 94 | interface['read_keys'], 95 | interface['read_strengths'], 96 | link_matrix, 97 | interface['read_modes'], 98 | ) 99 | 100 | return [ 101 | 102 | # report new memory state to be updated outside the condition branch 103 | memory_matrix, 104 | usage_vector, 105 | precedence_vector, 106 | link_matrix, 107 | write_weighting, 108 | read_weightings, 109 | read_vectors, 110 | 111 | self.controller.final_output(pre_output, read_vectors), 112 | interface['free_gates'], 113 | interface['allocation_gate'], 114 | interface['write_gate'], 115 | 116 | # report new state of RNN if exists 117 | nn_state[0] if nn_state is not None else tf.zeros(1), 118 | nn_state[1] if nn_state is not None else tf.zeros(1) 119 | ] 120 | 121 | 122 | def _loop_body(self, time, memory_state, outputs, free_gates, allocation_gates, write_gates, 123 | read_weightings, write_weightings, usage_vectors, controller_state): 124 | """ 125 | the body of the DNC sequence processing loop 126 | 127 | Parameters: 128 | ---------- 129 | time: Tensor 130 | outputs: TensorArray 131 | memory_state: Tuple 132 | free_gates: TensorArray 133 | allocation_gates: TensorArray 134 | write_gates: TensorArray 135 | read_weightings: 
TensorArray, 136 | write_weightings: TensorArray, 137 | usage_vectors: TensorArray, 138 | controller_state: Tuple 139 | 140 | Returns: Tuple containing all updated arguments 141 | """ 142 | 143 | step_input = self.unpacked_input_data.read(time) 144 | 145 | output_list = self._step_op(step_input, memory_state, controller_state) 146 | 147 | # update memory parameters 148 | 149 | new_controller_state = tf.zeros(1) 150 | new_memory_state = tuple(output_list[0:7]) 151 | 152 | new_controller_state = LSTMStateTuple(output_list[11], output_list[12]) 153 | 154 | outputs = outputs.write(time, output_list[7]) 155 | 156 | # collecting memory view for the current step 157 | free_gates = free_gates.write(time, output_list[8]) 158 | allocation_gates = allocation_gates.write(time, output_list[9]) 159 | write_gates = write_gates.write(time, output_list[10]) 160 | read_weightings = read_weightings.write(time, output_list[5]) 161 | write_weightings = write_weightings.write(time, output_list[4]) 162 | usage_vectors = usage_vectors.write(time, output_list[1]) 163 | 164 | return ( 165 | time + 1, new_memory_state, outputs, 166 | free_gates,allocation_gates, write_gates, 167 | read_weightings, write_weightings, 168 | usage_vectors, new_controller_state 169 | ) 170 | 171 | 172 | def build_graph(self): 173 | """ 174 | builds the computational graph that performs a step-by-step evaluation 175 | of the input data batches 176 | """ 177 | 178 | self.unpacked_input_data = utility.unpack_into_tensorarray(self.input_data, 1, self.sequence_length) 179 | 180 | outputs = tf.TensorArray(tf.float32, self.sequence_length) 181 | free_gates = tf.TensorArray(tf.float32, self.sequence_length) 182 | allocation_gates = tf.TensorArray(tf.float32, self.sequence_length) 183 | write_gates = tf.TensorArray(tf.float32, self.sequence_length) 184 | read_weightings = tf.TensorArray(tf.float32, self.sequence_length) 185 | write_weightings = tf.TensorArray(tf.float32, self.sequence_length) 186 | usage_vectors = tf.TensorArray(tf.float32, self.sequence_length) 187 | 188 | controller_state = self.controller.get_state() if self.controller.has_recurrent_nn else (tf.zeros(1), tf.zeros(1)) 189 | memory_state = self.memory.init_memory() 190 | if not isinstance(controller_state, LSTMStateTuple): 191 | controller_state = LSTMStateTuple(controller_state[0], controller_state[1]) 192 | final_results = None 193 | 194 | with tf.variable_scope("sequence_loop") as scope: 195 | time = tf.constant(0, dtype=tf.int32) 196 | 197 | final_results = tf.while_loop( 198 | cond=lambda time, *_: time < self.sequence_length, 199 | body=self._loop_body, 200 | loop_vars=( 201 | time, memory_state, outputs, 202 | free_gates, allocation_gates, write_gates, 203 | read_weightings, write_weightings, 204 | usage_vectors, controller_state 205 | ), 206 | parallel_iterations=32, 207 | swap_memory=True 208 | ) 209 | 210 | dependencies = [] 211 | if self.controller.has_recurrent_nn: 212 | dependencies.append(self.controller.update_state(final_results[9])) 213 | 214 | with tf.control_dependencies(dependencies): 215 | self.packed_output = utility.pack_into_tensor(final_results[2], axis=1) 216 | self.packed_memory_view = { 217 | 'free_gates': utility.pack_into_tensor(final_results[3], axis=1), 218 | 'allocation_gates': utility.pack_into_tensor(final_results[4], axis=1), 219 | 'write_gates': utility.pack_into_tensor(final_results[5], axis=1), 220 | 'read_weightings': utility.pack_into_tensor(final_results[6], axis=1), 221 | 'write_weightings': utility.pack_into_tensor(final_results[7], 
axis=1), 222 | 'usage_vectors': utility.pack_into_tensor(final_results[8], axis=1) 223 | } 224 | 225 | 226 | def get_outputs(self): 227 | """ 228 | returns the graph nodes for the output and memory view 229 | 230 | Returns: Tuple 231 | outputs: Tensor (batch_size, time_steps, output_size) 232 | memory_view: dict 233 | """ 234 | return self.packed_output, self.packed_memory_view 235 | 236 | 237 | def save(self, session, ckpts_dir, name): 238 | """ 239 | saves the current values of the model's parameters to a checkpoint 240 | 241 | Parameters: 242 | ---------- 243 | session: tf.Session 244 | the tensorflow session to save 245 | ckpts_dir: string 246 | the path to the checkpoints directories 247 | name: string 248 | the name of the checkpoint subdirectory 249 | """ 250 | checkpoint_dir = os.path.join(ckpts_dir, name) 251 | 252 | if not os.path.exists(checkpoint_dir): 253 | os.makedirs(checkpoint_dir) 254 | 255 | tf.train.Saver(tf.trainable_variables()).save(session, os.path.join(checkpoint_dir, 'model.ckpt')) 256 | 257 | 258 | def restore(self, session, ckpts_dir, name): 259 | """ 260 | session: tf.Session 261 | the tensorflow session to restore into 262 | ckpts_dir: string 263 | the path to the checkpoints directories 264 | name: string 265 | the name of the checkpoint subdirectory 266 | """ 267 | tf.train.Saver(tf.trainable_variables()).restore(session, os.path.join(ckpts_dir, name, 'model.ckpt')) 268 | -------------------------------------------------------------------------------- /dnc/memory.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | import utility 4 | 5 | class Memory: 6 | 7 | def __init__(self, words_num=256, word_size=64, read_heads=4, batch_size=1): 8 | """ 9 | constructs a memory matrix with read heads and a write head as described 10 | in the DNC paper 11 | http://www.nature.com/nature/journal/vaop/ncurrent/full/nature20101.html 12 | 13 | Parameters: 14 | ---------- 15 | words_num: int 16 | the maximum number of words that can be stored in the memory at the 17 | same time 18 | word_size: int 19 | the size of the individual word in the memory 20 | read_heads: int 21 | the number of read heads that can read simultaneously from the memory 22 | batch_size: int 23 | the size of input data batch 24 | """ 25 | 26 | self.words_num = words_num 27 | self.word_size = word_size 28 | self.read_heads = read_heads 29 | self.batch_size = batch_size 30 | 31 | # a words_num x words_num identity matrix 32 | self.I = tf.constant(np.identity(words_num, dtype=np.float32)) 33 | 34 | # maps the indecies from the 2D array of free list per batch to 35 | # their corresponding values in the flat 1D array of ordered_allocation_weighting 36 | self.index_mapper = tf.constant( 37 | np.cumsum([0] + [words_num] * (batch_size - 1), dtype=np.int32)[:, np.newaxis] 38 | ) 39 | 40 | def init_memory(self): 41 | """ 42 | returns the initial values for the memory Parameters 43 | 44 | Returns: Tuple 45 | """ 46 | 47 | return ( 48 | tf.fill([self.batch_size, self.words_num, self.word_size], 1e-6), # initial memory matrix 49 | tf.zeros([self.batch_size, self.words_num, ]), # initial usage vector 50 | tf.zeros([self.batch_size, self.words_num, ]), # initial precedence vector 51 | tf.zeros([self.batch_size, self.words_num, self.words_num]), # initial link matrix 52 | tf.fill([self.batch_size, self.words_num, ], 1e-6), # initial write weighting 53 | tf.fill([self.batch_size, self.words_num, self.read_heads], 1e-6), # initial read 
weightings 54 | tf.fill([self.batch_size, self.word_size, self.read_heads], 1e-6), # initial read vectors 55 | ) 56 | 57 | def get_lookup_weighting(self, memory_matrix, keys, strengths): 58 | """ 59 | retrives a content-based adderssing weighting given the keys 60 | 61 | Parameters: 62 | ---------- 63 | memory_matrix: Tensor (batch_size, words_num, word_size) 64 | the memory matrix to lookup in 65 | keys: Tensor (batch_size, word_size, number_of_keys) 66 | the keys to query the memory with 67 | strengths: Tensor (batch_size, number_of_keys, ) 68 | the list of strengths for each lookup key 69 | 70 | Returns: Tensor (batch_size, words_num, number_of_keys) 71 | The list of lookup weightings for each provided key 72 | """ 73 | 74 | normalized_memory = tf.nn.l2_normalize(memory_matrix, 2) 75 | normalized_keys = tf.nn.l2_normalize(keys, 1) 76 | 77 | similiarity = tf.batch_matmul(normalized_memory, normalized_keys) 78 | strengths = tf.expand_dims(strengths, 1) 79 | 80 | return tf.nn.softmax(similiarity * strengths, 1) 81 | 82 | 83 | def update_usage_vector(self, usage_vector, read_weightings, write_weighting, free_gates): 84 | """ 85 | updates and returns the usgae vector given the values of the free gates 86 | and the usage_vector, read_weightings, write_weighting from previous step 87 | 88 | Parameters: 89 | ---------- 90 | usage_vector: Tensor (batch_size, words_num) 91 | read_weightings: Tensor (batch_size, words_num, read_heads) 92 | write_weighting: Tensor (batch_size, words_num) 93 | free_gates: Tensor (batch_size, read_heads, ) 94 | 95 | Returns: Tensor (batch_size, words_num, ) 96 | the updated usage vector 97 | """ 98 | free_gates = tf.expand_dims(free_gates, 1) 99 | 100 | retention_vector = tf.reduce_prod(1 - read_weightings * free_gates, 2) 101 | updated_usage = (usage_vector + write_weighting - usage_vector * write_weighting) * retention_vector 102 | 103 | return updated_usage 104 | 105 | 106 | def get_allocation_weighting(self, sorted_usage, free_list): 107 | """ 108 | retreives the writing allocation weighting based on the usage free list 109 | 110 | Parameters: 111 | ---------- 112 | sorted_usage: Tensor (batch_size, words_num, ) 113 | the usage vector sorted ascndingly 114 | free_list: Tensor (batch, words_num, ) 115 | the original indecies of the sorted usage vector 116 | 117 | Returns: Tensor (batch_size, words_num, ) 118 | the allocation weighting for each word in memory 119 | """ 120 | 121 | shifted_cumprod = tf.cumprod(sorted_usage, axis = 1, exclusive=True) 122 | unordered_allocation_weighting = (1 - sorted_usage) * shifted_cumprod 123 | 124 | mapped_free_list = free_list + self.index_mapper 125 | flat_unordered_allocation_weighting = tf.reshape(unordered_allocation_weighting, (-1,)) 126 | flat_mapped_free_list = tf.reshape(mapped_free_list, (-1,)) 127 | flat_container = tf.TensorArray(tf.float32, self.batch_size * self.words_num) 128 | 129 | flat_ordered_weightings = flat_container.scatter( 130 | flat_mapped_free_list, 131 | flat_unordered_allocation_weighting 132 | ) 133 | 134 | packed_wightings = flat_ordered_weightings.pack() 135 | return tf.reshape(packed_wightings, (self.batch_size, self.words_num)) 136 | 137 | 138 | def update_write_weighting(self, lookup_weighting, allocation_weighting, write_gate, allocation_gate): 139 | """ 140 | updates and returns the current write_weighting 141 | 142 | Parameters: 143 | ---------- 144 | lookup_weighting: Tensor (batch_size, words_num, 1) 145 | the weight of the lookup operation in writing 146 | allocation_weighting: Tensor 
(batch_size, words_num) 147 | the weight of the allocation operation in writing 148 | write_gate: (batch_size, 1) 149 | the fraction of writing to be done 150 | allocation_gate: (batch_size, 1) 151 | the fraction of allocation to be done 152 | 153 | Returns: Tensor (batch_size, words_num) 154 | the updated write_weighting 155 | """ 156 | 157 | # remove the dimension of 1 from the lookup_weighting 158 | lookup_weighting = tf.squeeze(lookup_weighting) 159 | 160 | updated_write_weighting = write_gate * (allocation_gate * allocation_weighting + (1 - allocation_gate) * lookup_weighting) 161 | 162 | return updated_write_weighting 163 | 164 | 165 | def update_memory(self, memory_matrix, write_weighting, write_vector, erase_vector): 166 | """ 167 | updates and returns the memory matrix given the weighting, write and erase vectors 168 | and the memory matrix from previous step 169 | 170 | Parameters: 171 | ---------- 172 | memory_matrix: Tensor (batch_size, words_num, word_size) 173 | the memory matrix from previous step 174 | write_weighting: Tensor (batch_size, words_num) 175 | the weight of writing at each memory location 176 | write_vector: Tensor (batch_size, word_size) 177 | a vector specifying what to write 178 | erase_vector: Tensor (batch_size, word_size) 179 | a vector specifying what to erase from memory 180 | 181 | Returns: Tensor (batch_size, words_num, word_size) 182 | the updated memory matrix 183 | """ 184 | 185 | # expand data with a dimension of 1 at multiplication-adjacent location 186 | # to force matmul to behave as an outer product 187 | write_weighting = tf.expand_dims(write_weighting, 2) 188 | write_vector = tf.expand_dims(write_vector, 1) 189 | erase_vector = tf.expand_dims(erase_vector, 1) 190 | 191 | erasing = memory_matrix * (1 - tf.batch_matmul(write_weighting, erase_vector)) 192 | writing = tf.batch_matmul(write_weighting, write_vector) 193 | updated_memory = erasing + writing 194 | 195 | return updated_memory 196 | 197 | 198 | def update_precedence_vector(self, precedence_vector, write_weighting): 199 | """ 200 | updates the precedence vector given the latest write weighting 201 | and the precedence_vector from last step 202 | 203 | Parameters: 204 | ---------- 205 | precedence_vector: Tensor (batch_size. 
words_num) 206 | the precedence vector from the last time step 207 | write_weighting: Tensor (batch_size,words_num) 208 | the latest write weighting for the memory 209 | 210 | Returns: Tensor (batch_size, words_num) 211 | the updated precedence vector 212 | """ 213 | 214 | reset_factor = 1 - tf.reduce_sum(write_weighting, 1, keep_dims=True) 215 | updated_precedence_vector = reset_factor * precedence_vector + write_weighting 216 | 217 | return updated_precedence_vector 218 | 219 | 220 | def update_link_matrix(self, precedence_vector, link_matrix, write_weighting): 221 | """ 222 | updates and returns the temporal link matrix for the latest write 223 | given the precedence vector and the link matrix from previous step 224 | 225 | Parameters: 226 | ---------- 227 | precedence_vector: Tensor (batch_size, words_num) 228 | the precedence vector from the last time step 229 | link_matrix: Tensor (batch_size, words_num, words_num) 230 | the link matrix form the last step 231 | write_weighting: Tensor (batch_size, words_num) 232 | the latest write_weighting for the memory 233 | 234 | Returns: Tensor (batch_size, words_num, words_num) 235 | the updated temporal link matrix 236 | """ 237 | 238 | write_weighting = tf.expand_dims(write_weighting, 2) 239 | precedence_vector = tf.expand_dims(precedence_vector, 1) 240 | 241 | reset_factor = 1 - utility.pairwise_add(write_weighting, is_batch=True) 242 | updated_link_matrix = reset_factor * link_matrix + tf.batch_matmul(write_weighting, precedence_vector) 243 | updated_link_matrix = (1 - self.I) * updated_link_matrix # eliminates self-links 244 | 245 | return updated_link_matrix 246 | 247 | 248 | def get_directional_weightings(self, read_weightings, link_matrix): 249 | """ 250 | computes and returns the forward and backward reading weightings 251 | given the read_weightings from the previous step 252 | 253 | Parameters: 254 | ---------- 255 | read_weightings: Tensor (batch_size, words_num, read_heads) 256 | the read weightings from the last time step 257 | link_matrix: Tensor (batch_size, words_num, words_num) 258 | the temporal link matrix 259 | 260 | Returns: Tuple 261 | forward weighting: Tensor (batch_size, words_num, read_heads), 262 | backward weighting: Tensor (batch_size, words_num, read_heads) 263 | """ 264 | 265 | forward_weighting = tf.batch_matmul(link_matrix, read_weightings) 266 | backward_weighting = tf.batch_matmul(link_matrix, read_weightings, adj_x=True) 267 | 268 | return forward_weighting, backward_weighting 269 | 270 | 271 | def update_read_weightings(self, lookup_weightings, forward_weighting, backward_weighting, read_mode): 272 | """ 273 | updates and returns the current read_weightings 274 | 275 | Parameters: 276 | ---------- 277 | lookup_weightings: Tensor (batch_size, words_num, read_heads) 278 | the content-based read weighting 279 | forward_weighting: Tensor (batch_size, words_num, read_heads) 280 | the forward direction read weighting 281 | backward_weighting: Tensor (batch_size, words_num, read_heads) 282 | the backward direction read weighting 283 | read_mode: Tesnor (batch_size, 3, read_heads) 284 | the softmax distribution between the three read modes 285 | 286 | Returns: Tensor (batch_size, words_num, read_heads) 287 | """ 288 | 289 | backward_mode = tf.expand_dims(read_mode[:, 0, :], 1) * backward_weighting 290 | lookup_mode = tf.expand_dims(read_mode[:, 1, :], 1) * lookup_weightings 291 | forward_mode = tf.expand_dims(read_mode[:, 2, :], 1) * forward_weighting 292 | updated_read_weightings = backward_mode + lookup_mode + 
forward_mode 293 | 294 | return updated_read_weightings 295 | 296 | 297 | def update_read_vectors(self, memory_matrix, read_weightings): 298 | """ 299 | reads, updates, and returns the read vectors of the recently updated memory 300 | 301 | Parameters: 302 | ---------- 303 | memory_matrix: Tensor (batch_size, words_num, word_size) 304 | the recently updated memory matrix 305 | read_weightings: Tensor (batch_size, words_num, read_heads) 306 | the amount of info to read from each memory location by each read head 307 | 308 | Returns: Tensor (word_size, read_heads) 309 | """ 310 | 311 | updated_read_vectors = tf.batch_matmul(memory_matrix, read_weightings, adj_x=True) 312 | 313 | return updated_read_vectors 314 | 315 | 316 | def write(self, memory_matrix, usage_vector, read_weightings, write_weighting, 317 | precedence_vector, link_matrix, key, strength, free_gates, 318 | allocation_gate, write_gate, write_vector, erase_vector): 319 | """ 320 | defines the complete pipeline of writing to memory gievn the write variables 321 | and the memory_matrix, usage_vector, link_matrix, and precedence_vector from 322 | previous step 323 | 324 | Parameters: 325 | ---------- 326 | memory_matrix: Tensor (batch_size, words_num, word_size) 327 | the memory matrix from previous step 328 | usage_vector: Tensor (batch_size, words_num) 329 | the usage_vector from the last time step 330 | read_weightings: Tensor (batch_size, words_num, read_heads) 331 | the read_weightings from the last time step 332 | write_weighting: Tensor (batch_size, words_num) 333 | the write_weighting from the last time step 334 | precedence_vector: Tensor (batch_size, words_num) 335 | the precedence vector from the last time step 336 | link_matrix: Tensor (batch_size, words_num, words_num) 337 | the link_matrix from previous step 338 | key: Tensor (batch_size, word_size, 1) 339 | the key to query the memory location with 340 | strength: (batch_size, 1) 341 | the strength of the query key 342 | free_gates: Tensor (batch_size, read_heads) 343 | the degree to which location at read haeds will be freed 344 | allocation_gate: (batch_size, 1) 345 | the fraction of writing that is being allocated in a new locatio 346 | write_gate: (batch_size, 1) 347 | the amount of information to be written to memory 348 | write_vector: Tensor (batch_size, word_size) 349 | specifications of what to write to memory 350 | erase_vector: Tensor(batch_size, word_size) 351 | specifications of what to erase from memory 352 | 353 | Returns : Tuple 354 | the updated usage vector: Tensor (batch_size, words_num) 355 | the updated write_weighting: Tensor(batch_size, words_num) 356 | the updated memory_matrix: Tensor (batch_size, words_num, words_size) 357 | the updated link matrix: Tensor(batch_size, words_num, words_num) 358 | the updated precedence vector: Tensor (batch_size, words_num) 359 | """ 360 | 361 | lookup_weighting = self.get_lookup_weighting(memory_matrix, key, strength) 362 | new_usage_vector = self.update_usage_vector(usage_vector, read_weightings, write_weighting, free_gates) 363 | 364 | sorted_usage, free_list = tf.nn.top_k(-1 * new_usage_vector, self.words_num) 365 | sorted_usage = -1 * sorted_usage 366 | 367 | allocation_weighting = self.get_allocation_weighting(sorted_usage, free_list) 368 | new_write_weighting = self.update_write_weighting(lookup_weighting, allocation_weighting, write_gate, allocation_gate) 369 | new_memory_matrix = self.update_memory(memory_matrix, new_write_weighting, write_vector, erase_vector) 370 | new_link_matrix = 
self.update_link_matrix(precedence_vector, link_matrix, new_write_weighting) 371 | new_precedence_vector = self.update_precedence_vector(precedence_vector, new_write_weighting) 372 | 373 | return new_usage_vector, new_write_weighting, new_memory_matrix, new_link_matrix, new_precedence_vector 374 | 375 | 376 | def read(self, memory_matrix, read_weightings, keys, strengths, link_matrix, read_modes): 377 | """ 378 | defines the complete pipeline for reading from memory 379 | 380 | Parameters: 381 | ---------- 382 | memory_matrix: Tensor (batch_size, words_num, word_size) 383 | the updated memory matrix from the last writing 384 | read_weightings: Tensor (batch_size, words_num, read_heads) 385 | the read weightings form the last time step 386 | keys: Tensor (batch_size, word_size, read_heads) 387 | the kyes to query the memory locations with 388 | strengths: Tensor (batch_size, read_heads) 389 | the strength of each read key 390 | link_matrix: Tensor (batch_size, words_num, words_num) 391 | the updated link matrix from the last writing 392 | read_modes: Tensor (batch_size, 3, read_heads) 393 | the softmax distribution between the three read modes 394 | 395 | Returns: Tuple 396 | the updated read_weightings: Tensor(batch_size, words_num, read_heads) 397 | the recently read vectors: Tensor (batch_size, word_size, read_heads) 398 | """ 399 | 400 | lookup_weighting = self.get_lookup_weighting(memory_matrix, keys, strengths) 401 | forward_weighting, backward_weighting = self.get_directional_weightings(read_weightings, link_matrix) 402 | new_read_weightings = self.update_read_weightings(lookup_weighting, forward_weighting, backward_weighting, read_modes) 403 | new_read_vectors = self.update_read_vectors(memory_matrix, new_read_weightings) 404 | 405 | return new_read_weightings, new_read_vectors 406 | -------------------------------------------------------------------------------- /dnc/utility.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | from tensorflow.python.ops import gen_state_ops 4 | 5 | def pairwise_add(u, v=None, is_batch=False): 6 | """ 7 | performs a pairwise summation between vectors (possibly the same) 8 | 9 | Parameters: 10 | ---------- 11 | u: Tensor (n, ) | (n, 1) 12 | v: Tensor (n, ) | (n, 1) [optional] 13 | is_batch: bool 14 | a flag for whether the vectors come in a batch 15 | ie.: whether the vectors has a shape of (b,n) or (b,n,1) 16 | 17 | Returns: Tensor (n, n) 18 | Raises: ValueError 19 | """ 20 | u_shape = u.get_shape().as_list() 21 | 22 | if len(u_shape) > 2 and not is_batch: 23 | raise ValueError("Expected at most 2D tensors, but got %dD" % len(u_shape)) 24 | if len(u_shape) > 3 and is_batch: 25 | raise ValueError("Expected at most 2D tensor batches, but got %dD" % len(u_shape)) 26 | 27 | if v is None: 28 | v = u 29 | else: 30 | v_shape = v.get_shape().as_list() 31 | if u_shape != v_shape: 32 | raise VauleError("Shapes %s and %s do not match" % (u_shape, v_shape)) 33 | 34 | n = u_shape[0] if not is_batch else u_shape[1] 35 | 36 | column_u = tf.reshape(u, (-1, 1) if not is_batch else (-1, n, 1)) 37 | U = tf.concat(1 if not is_batch else 2, [column_u] * n) 38 | 39 | if v is u: 40 | return U + tf.transpose(U, None if not is_batch else [0, 2, 1]) 41 | else: 42 | row_v = tf.reshape(v, (1, -1) if not is_batch else (-1, 1, n)) 43 | V = tf.concat(0 if not is_batch else 1, [row_v] * n) 44 | 45 | return U + V 46 | 47 | 48 | def decaying_softmax(shape, axis): 49 | rank = len(shape) 50 | 
max_val = shape[axis] 51 | 52 | weights_vector = np.arange(1, max_val + 1, dtype=np.float32) 53 | weights_vector = weights_vector[::-1] 54 | weights_vector = np.exp(weights_vector) / np.sum(np.exp(weights_vector)) 55 | 56 | container = np.zeros(shape, dtype=np.float32) 57 | broadcastable_shape = [1] * rank 58 | broadcastable_shape[axis] = max_val 59 | 60 | return container + np.reshape(weights_vector, broadcastable_shape) 61 | 62 | def unpack_into_tensorarray(value, axis, size=None): 63 | """ 64 | unpacks a given tensor along a given axis into a TensorArray 65 | 66 | Parameters: 67 | ---------- 68 | value: Tensor 69 | the tensor to be unpacked 70 | axis: int 71 | the axis to unpack the tensor along 72 | size: int 73 | the size of the array to be used if shape inference resulted in None 74 | 75 | Returns: TensorArray 76 | the unpacked TensorArray 77 | """ 78 | 79 | shape = value.get_shape().as_list() 80 | rank = len(shape) 81 | dtype = value.dtype 82 | array_size = shape[axis] if not shape[axis] is None else size 83 | 84 | if array_size is None: 85 | raise ValueError("Can't create TensorArray with size None") 86 | 87 | array = tf.TensorArray(dtype=dtype, size=array_size) 88 | dim_permutation = [axis] + range(1, axis) + [0] + range(axis + 1, rank) 89 | unpack_axis_major_value = tf.transpose(value, dim_permutation) 90 | full_array = array.unpack(unpack_axis_major_value) 91 | 92 | return full_array 93 | 94 | def pack_into_tensor(array, axis): 95 | """ 96 | packs a given TensorArray into a tensor along a given axis 97 | 98 | Parameters: 99 | ---------- 100 | array: TensorArray 101 | the tensor array to pack 102 | axis: int 103 | the axis to pack the array along 104 | 105 | Returns: Tensor 106 | the packed tensor 107 | """ 108 | 109 | packed_tensor = array.pack() 110 | shape = packed_tensor.get_shape() 111 | rank = len(shape) 112 | 113 | dim_permutation = [axis] + range(1, axis) + [0] + range(axis + 1, rank) 114 | correct_shape_tensor = tf.transpose(packed_tensor, dim_permutation) 115 | 116 | return correct_shape_tensor 117 | -------------------------------------------------------------------------------- /docs/basic-usage.md: -------------------------------------------------------------------------------- 1 | # Project Structure and Usage 2 | 3 | ## Structure Overview 4 | 5 | The implementation is structured into three main modules. 6 | 7 | - **Memory** `dnc/memory.py`: this module implements the memory access and attention mechanisms used in the DNC architecture. This is considered an internal module that a basic user wouldn't need to work with directly. 8 | - **BaseController** `dnc/controller.py`: this module defines an **abstract class** that represents the controller unit in the DNC architecture. The class abstracts away all the common operations between various tasks (like interface vector parsing, input and read vector concatenation, ... etc) and only leaves for the user two unimplemented methods concerned with defining the internal neural network. 9 | - **DNC** `dnc/dnc.py`: this module integrates the operations of the controller unit and the memory, and it's considered the public API that the user should interact with directly. This module also abstracts away all the common operations across various tasks (like initiating the memory, looping through the time steps, memory-controller communications, ... etc) so the user is only required to construct an instance of that class using the desired parameters, and use a simple API to feed data into the model and get outputs out of it.
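*To make the division of responsibilities concrete before going into the details below, here is a minimal end-to-end sketch of how the pieces fit together from the user's side. The `MyController` class and the parameter values are arbitrary placeholders, and the import paths are assumed from the project layout; both steps are covered at length in the Usage section.*

```python
import tensorflow as tf
from dnc.dnc import DNC
from dnc.controller import BaseController

# 1. define only the controller's internal network by extending BaseController
class MyController(BaseController):
    def network_vars(self):
        self.W = tf.Variable(tf.truncated_normal([self.nn_input_size, 128]), name='weights')
        self.b = tf.Variable(tf.zeros([128]), name='bias')

    def network_op(self, X):
        return tf.nn.relu(tf.matmul(X, self.W) + self.b)

# 2. let the DNC class wire the controller and the memory together
ncomputer = DNC(MyController, input_size=6, output_size=6, max_sequence_length=10)
output, memory_view = ncomputer.get_outputs()
```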
10 | 11 | The following pseudo-UML diagram summarizes the project structure: 12 | 13 | ![UML](/assets/UML.png) 14 | 15 | The reasoning behind these design choices stems from the inherent flexibility and generality of the DNC architecture and its ability to be adapted to various tasks and problems. This design was chosen to reflect these characteristics and allow the user to quickly set up his/her model by focusing on the specific details of the task without worrying about the rest of the architecture's operation. This is also the justification for leaving any details about the loss, optimizers, and session runs out of the implementation and in the user's hands, to adapt them to whatever task they desire. 16 | 17 | ## Usage 18 | 19 | ### Defining the Controller 20 | 21 | The first step in setting up your task is to define your controller's internal neural network. This is done by extending the `BaseController` class and implementing the two methods that define your network: 22 | 23 | - `network_vars(self): void`: in this method you should define your network variables and their initializers as instance attributes of the class (i.e. `self.*`). This method will be used automatically by the `DNC` instance to create the variables upon construction. This method shouldn't return anything. 24 | - `network_op(self, X): Tensor`: in this method you define the operation of your network, that is, the layer operations, activations, batch normalizations, ... etc. This method takes one input, which is a 2D Tensor of shape `batch_size X (input_size + read_vectors_size)`, and should return a 2D Tensor of shape `batch_size X output_size`. 25 | 26 | When you define your `network_vars` method, you shouldn't worry about calculating the size of the input plus the read vectors; this value is automatically available for you via the attribute `self.nn_input_size`. The defined batch size is also available via `self.batch_size`. 27 | 28 | The following is an example of a controller with a 1-layer feedforward neural network: 29 | 30 | ```python 31 | import tensorflow as tf 32 | from dnc.controller import BaseController 33 | 34 | class FeedforwardController(BaseController): 35 | def network_vars(self): 36 | self.W = tf.Variable(tf.truncated_normal([self.nn_input_size, 128]), name='weights') 37 | self.b = tf.Variable(tf.zeros([128]), name='bias') 38 | 39 | 40 | def network_op(self, X): 41 | output = tf.matmul(X, self.W) + self.b 42 | activations = tf.nn.relu(output) 43 | 44 | return activations 45 | ``` 46 | 47 | **Notice** that the network works with flat inputs and flat outputs, so if you're planning to do convolutions you should: 48 | 1. Flatten your data before passing it to the model. 49 | 2. Bring it back to 2D at the beginning of your `network_op`. 50 | 3. Flatten the output before returning it from the `network_op`. 51 | 52 | #### Defining Recurrent Controllers 53 | 54 | To define a controller with a recurrent neural network, you'll need to add a few things to your new controller class: 55 | - Defining the state of your network inside your `network_vars` method. 56 | - A method named `get_state` that returns a tuple `(previous_output, previous_hidden_state)` read from the defined state. 57 | - A method named `update_state` that will be used to update the values of the state across runs. This method should return a TensorFlow operation. 58 | - Making the `network_op` method take an extra argument for the state and return a state tuple alongside the output.
59 | 60 | You only need to address these changes and the `DNC` module will automatically recognize them and do the rest of the work. 61 | 62 | The following is an example of a possible recurrent controller: 63 | ```python 64 | import tensorflow as tf 65 | from dnc.controller import BaseController 66 | 67 | class RecurrentController(BaseController): 68 | def network_vars(self): 69 | self.lstm_cell = tf.nn.rnn_cell.BasicLSTMCell(64) 70 | self.state = tf.Variable(tf.zeros([self.batch_size, 64]), trainable=False) 71 | self.output = tf.Variable(tf.zeros([self.batch_size, 64]), trainable=False) 72 | 73 | def network_op(self, X, state): 74 | X = tf.convert_to_tensor(X) 75 | return self.lstm_cell(X, state) 76 | 77 | def update_state(self, new_state): 78 | return tf.group( 79 | self.output.assign(new_state[0]), 80 | self.state.assign(new_state[1]) 81 | ) 82 | 83 | def get_state(self): 84 | return (self.output, self.state) 85 | ``` 86 | 87 | #### Initial Transformation Weights 88 | 89 | By default, the transformation weight matrices ![](https://latex.codecogs.com/gif.latex?W_y%2C%20W_%5Czeta%2C%20W_r) are initialized from a zero-mean normal distribution with a standard deviation of 0.1. This initialization scheme is defined in the `initials` method of a `BaseController`: 90 | 91 | ```python 92 | def initials(self): 93 | """ 94 | sets the initial values of the controller transformation weights matrices 95 | this method can be overridden to use a different initialization scheme 96 | """ 97 | # defining internal weights of the controller 98 | self.interface_weights = tf.Variable( 99 | tf.random_normal([self.nn_output_size, self.interface_vector_size], stddev=0.1), 100 | name='interface_weights' 101 | ) 102 | self.nn_output_weights = tf.Variable( 103 | tf.random_normal([self.nn_output_size, self.output_size], stddev=0.1), 104 | name='nn_output_weights' 105 | ) 106 | self.mem_output_weights = tf.Variable( 107 | tf.random_normal([self.word_size * self.read_heads, self.output_size], stddev=0.1), 108 | name='mem_output_weights' 109 | ) 110 | ``` 111 | 112 | A different initialization scheme can be defined by overriding this method with the desired scheme. See [the FeedforwardController of the copy task](../tasks/copy/feedforward_controller.py) for an example of a different initialization scheme. 113 | 114 | ### Using the DNC module 115 | 116 | Once you've defined your concrete controller class, you're ready to plug in that controller and use the DNC on your task. 117 | 118 | To do that, you need to construct an instance of the DNC module and pass it your controller class and the desired parameters of your model. The constructor of the DNC module is defined as follows: 119 | 120 | ```python 121 | DNC.__init__( 122 | controller_class, 123 | input_size, 124 | output_size, 125 | max_sequence_length, 126 | memory_words_num = 256, 127 | memory_word_size = 64, 128 | memory_read_heads = 4, 129 | batch_size = 1 130 | ) 131 | ``` 132 | * **controller_class**: a reference to the concrete controller class you defined earlier. You just need to pass the class; you do not need to construct an instance yourself, as the `DNC` constructor will automatically handle that. 133 | * **input_size**: the size of the flattened input vector. 134 | * **output_size**: the size of the flattened output vector. 135 | * **max_sequence_length**: the maximum length of the input sequences that are expected to be fed into the model. 136 | * **memory_words_num**: the number of memory locations. 
137 | * **memory_word_size**: the size of an individual memory location. 138 | * **memory_read_heads**: the number of read heads in the memory. 139 | * **batch_size**: the size of the batch to be fed to the model. 140 | 141 | As you may have noticed, you do not construct an instance of `Memory` directly; you just pass the desired parameters and the `DNC` module will handle its construction. 142 | 143 | To define the operations that lead to the model's output, you use the instance method `get_outputs()`: 144 | ```python 145 | output_op, memory_view = dnc_instance.get_outputs() 146 | ``` 147 | *`memory_view` is a Python `dict` that carries some of the internal values of the model (like weightings and gates) and is mainly used for visualization.* 148 | 149 | To actually get the outputs, you need to run this `output_op` while feeding three placeholders that are attributes of the DNC instance. These placeholders are: 150 | * **input_data**: a 3D tensor of shape `batch_size X sequence_length X input_size` which represents the inputs of that run. 151 | * **target_output**: a 3D tensor of shape `batch_size X sequence_length X output_size` which represents the desired outputs. 152 | * **sequence_length**: an integer that defines the sequence length across that batch. **Notice** that this means that the whole batch must be of the same sequence length (which is a to-be-addressed limitation), but sequence_length can vary between batches as long as it is less than or equal to the `max_sequence_length` the DNC was instantiated with. 153 | 154 | So a run for an instantiated DNC model looks like: 155 | ```python 156 | input_data = ... 157 | target_output = ... 158 | sequence_length = 10 159 | 160 | output, _ = dnc_instance.get_outputs() 161 | loss = some_loss_fn(output, dnc_instance.target_output) 162 | 163 | loss_val, dnc_output = session.run([loss, output], feed_dict={ 164 | dnc_instance.input_data: input_data, 165 | dnc_instance.target_output: target_output, 166 | dnc_instance.sequence_length: sequence_length 167 | }) 168 | ``` 169 | After you train your model, you can save a checkpoint to disk using the `save` method. This method takes three arguments: the TensorFlow session, the path to the checkpoints directory in which the checkpoint will be saved, and the name to save it under. 170 | ```python 171 | dnc_instance.save(session, './checkpoints_dir', 'checkpoint-1') 172 | ``` 173 | To restore a previous checkpoint, you can use the `restore` method with the first two parameters as in saving; the third one is now the name of the existing checkpoint to be restored. 174 | ```python 175 | dnc_instance.restore(session, './checkpoints_dir', 'checkpoint-1') 176 | ``` 177 | #### An Example 178 | 179 | The following is an excerpt from the copy task trainer that demonstrates how a `DNC` instance can be integrated with an optimizer to construct a complete graph. 180 | ```python 181 | optimizer = tf.train.RMSPropOptimizer(learning_rate, momentum=momentum) 182 | 183 | ncomputer = DNC( 184 | FeedforwardController, 185 | input_size, 186 | output_size, 187 | 2 * sequence_max_length + 1, 188 | words_count, 189 | word_size, 190 | read_heads, 191 | batch_size 192 | ) 193 | 194 | # squash the DNC output between 0 and 1 195 | output, _ = ncomputer.get_outputs() 196 | squashed_output = tf.clip_by_value(tf.sigmoid(output), 1e-6, 1. 
- 1e-6) 197 | 198 | loss = binary_cross_entropy(squashed_output, ncomputer.target_output) 199 | 200 | gradients = optimizer.compute_gradients(loss) 201 | for i, (grad, var) in enumerate(gradients): 202 | if grad is not None: 203 | gradients[i] = (tf.clip_by_value(grad, -10, 10), var) 204 | 205 | apply_gradients = optimizer.apply_gradients(gradients) 206 | ``` 207 | -------------------------------------------------------------------------------- /docs/data-flow.md: -------------------------------------------------------------------------------- 1 | # Data Flow Across the Modules 2 | 3 | These pseudo data-flow diagrams show how data flows from the input through the three modules of the implementation. They are high-level overviews of the modules' internal operation that should ease the process of reading through the source code. 4 | 5 | **Notation** 6 | * **B**: the batch size. 7 | * **T**: the sequence length. 8 | * **the rest** follows the paper notation. 9 | 10 | *No data flow diagram is shown for the memory module as it would be unnecessarily complicated, given that, unlike the other two modules, the memory module does not do anything more than what is described in the paper*. 11 | 12 | ## DNC Module 13 | 14 | ![DNC-pDFD](/assets/DNC-DFD.png) 15 | 16 | ## Controller Module 17 | 18 | ![Controller-pDFD](/assets/Controller-DFD.png) 19 | -------------------------------------------------------------------------------- /docs/implementation-notes.md: -------------------------------------------------------------------------------- 1 | # Implementation Notes 2 | 3 | ## Mathematics 4 | 5 | Two considerations were taken into account when the mathematical operations were implemented: 6 | 7 | - At the time of the implementation, the version of TensorFlow used (r0.11) lacked a lot regarding slicing and assigning values to slices. 8 | - A vectorized implementation is generally better than an implementation with a Python for loop (usually because of the possible parallelism, and because a Python for loop creates a copy of the same subgraph, one for each iteration). 9 | 10 | Most of the operations described in the paper lend themselves to a straightforward implementation in TensorFlow, except possibly for two operations: the allocation weighting calculations and the link matrix updates, as both are described in a slicing and looping manner, which can make their current implementation look a little convoluted. The following attempts to clarify how these operations were implemented. 11 | 12 | ### Allocation Weighting 13 | 14 | In the paper, the allocation weightings are calculated using the formula: 15 | 16 | ![](https://latex.codecogs.com/gif.latex?a_t%5B%5Cphi_t%5Bj%5D%5D%20%3D%20%281%20-%20u_t%5B%5Cphi_t%5Bj%5D%5D%29%5Cprod_%7Bi%3D1%7D%5E%7Bj-1%7Du_t%5B%5Cphi_t%5Bi%5D%5D) 17 | 18 | This operation can be vectorized by instead computing the following formula: 19 | 20 | ![](https://latex.codecogs.com/gif.latex?%5Chat%7Ba%7D_t%20%3D%20%5Cleft%28%201%20-%20%5Chat%7Bu%7D_t%20%5Cright%29%5CPi_t%5E%5Chat%7Bu%7D) 21 | 22 | Where ![](https://latex.codecogs.com/gif.latex?%5Chat%7Bu%7D_t) is the sorted usage vector and ![](https://latex.codecogs.com/gif.latex?%5CPi%5E%7B%5Chat%7Bu%7D%7D_t) is the cumulative product vector of the sorted usage, computed with `tf.cumprod`. With this equation, we get the allocation weighting ![](https://latex.codecogs.com/gif.latex?%5Chat%7Ba%7D_t), but out of the original order of the memory locations.
We can reorder it back into the original order of the memory locations with `TensorArray`'s scatter operation, using the free-list as the scatter indices. 23 | 24 | ```python 25 | shifted_cumprod = tf.cumprod(sorted_usage, axis = 1, exclusive=True) 26 | unordered_allocation_weighting = (1 - sorted_usage) * shifted_cumprod 27 | 28 | mapped_free_list = free_list + self.index_mapper 29 | flat_unordered_allocation_weighting = tf.reshape(unordered_allocation_weighting, (-1,)) 30 | flat_mapped_free_list = tf.reshape(mapped_free_list, (-1,)) 31 | flat_container = tf.TensorArray(tf.float32, self.batch_size * self.words_num) 32 | 33 | flat_ordered_weightings = flat_container.scatter( 34 | flat_mapped_free_list, 35 | flat_unordered_allocation_weighting 36 | ) 37 | 38 | packed_weightings = flat_ordered_weightings.pack() 39 | return tf.reshape(packed_weightings, (self.batch_size, self.words_num)) 40 | ``` 41 | 42 | Because `TensorArray` operations work only on one dimension and our allocation weightings are of shape *batch_size × N*, we map the free-list indices to their values as if they pointed to consecutive locations in a flat container. Then we flatten all the operands and reshape them back to their original 2D shapes at the end. This process is depicted in the following figure. 43 | 44 | ![](../assets/allocation_weighting.png) 45 | 46 | ### Link Matrix 47 | 48 | The paper's original formulation of the link matrix update is an index-by-index operation: 49 | 50 | ![](https://latex.codecogs.com/gif.latex?L_t%5Bi%2Cj%5D%20%3D%20%281%20-%20%5Cmathbf%7Bw%7D%5E%7Bw%7D_%7Bt%7D%5Bi%5D%20-%20%5Cmathbf%7Bw%7D%5E%7Bw%7D_%7Bt%7D%5Bj%5D%29L_%7Bt-1%7D%5Bi%2Cj%5D%20+%20%5Cmathbf%7Bw%7D%5E%7Bw%7D_%7Bt%7D%5Bi%5D%5Cmathbf%7Bp%7D_%7Bt-1%7D%5Bj%5D) 51 | 52 | A vectorized implementation of this operation can be written as: 53 | 54 | ![](https://latex.codecogs.com/gif.latex?L_t%20%3D%20%5B%281%20-%20%28%5Cmathbf%7Bw%7D_t%5E%7Bw%7D%5Coplus%20%5Cmathbf%7Bw%7D_t%5E%7Bw%7D%29%29%5Ccirc%20L_%7Bt-1%7D%20+%20%5Cmathbf%7Bw%7D_t%5E%7Bw%7D%5Cmathbf%7Bp%7D_%7Bt-1%7D%5D%5Ccirc%20%281-I%29) 55 | 56 | Where ![](https://latex.codecogs.com/gif.latex?%5Ccirc) is elementwise multiplication, and ![](https://latex.codecogs.com/gif.latex?%5Coplus) is a *pairwise addition* operator defined as: 57 | 58 | ![](https://latex.codecogs.com/gif.latex?u%20%5Coplus%20v%20%3D%20%5Cbegin%7Bpmatrix%7D%20u_1%20+%20v_1%20%26%20%5Chdots%20%26%20u_1+v_n%20%5C%5C%20%5Cvdots%20%26%20%5Cddots%20%26%20%5Cvdots%5C%5C%20u_n+v_1%20%26%20%5Chdots%20%26%20u_n+v_n%20%5Cend%7Bpmatrix%7D) 59 | 60 | Where ![](https://latex.codecogs.com/gif.latex?%5Cinline%20u%2Cv%20%5Cin%20%5Cmathbb%7BR%7D%5En). This allows TensorFlow to parallelize the operation, but of course with a cost incurred on the space complexity. 61 | 62 | *The elementwise multiplication by ![](https://latex.codecogs.com/gif.latex?%5Cinline%20%5Cmathit%7B1%20-%20I%7D) is to ensure that all diagonal elements are zero, thus ensuring the elimination of self-links.* 63 | 64 | 65 | ## Weight Initializations 66 | 67 | * **Memory's usage and precedence vectors and link matrix** are initialized to zero as specified by the paper. 68 | 69 | * **Memory's matrix, read and write weightings, and read vectors** are initialized to a very small value (10⁻⁶). Attempting to initialize them to 0 resulted in **NaN** after the first few iterations. 
70 | 71 | *These initialization schemes were chosen after many experiments on the copy task, as they showed the highest degree of stability in training (the highest rate of convergence and the smallest rate of NaN-outs). However, they might need reconsideration for other tasks.* 72 | -------------------------------------------------------------------------------- /run_tests.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | for file in `ls ./unit-tests/` 4 | do 5 | if [ -f './unit-tests/'$file ]; then 6 | printf "Running:: %s\n\n" "$file" 7 | python './unit-tests/'$file 8 | printf "\n" 9 | fi 10 | done 11 | -------------------------------------------------------------------------------- /tasks/babi/README.md: -------------------------------------------------------------------------------- 1 | ### Settings 2 | 3 | - RMSProp optimizer with a learning rate of 10⁻⁴ and a momentum of 0.9. 4 | - Memory with 256 locations, each of size 32, and 4 read heads. 5 | - A recurrent neural network consisting of a single LSTM layer with hidden size 256. The RNN state is reset back to zero at the start of each new input sequence. 6 | - The controller weights are initialized from a zero-mean normal distribution with a standard deviation of 0.1. 7 | - A batch size of 1. 8 | 9 | The loss function is the cross-entropy between the softmax of the output sequence and the target sequence. All outputs that do not correspond to an answer step are ignored using a weights vector that has 1s at the steps with answers and 0s elsewhere. 10 | 11 | ![loss](https://latex.codecogs.com/gif.latex?%5Cmathcal%7BL%7D%28y%2C%20%5Chat%7By%7D%29%20%3D%20-%20%5Csum_%7Bi%20%3D%201%7D%5E%7B%5Cleft%20%7C%20D%20%5Cright%20%7C%7D%20w_iy_i%5Clog%5Cleft%28%5Cmathop%7B%5Cmathrm%7Bsoftmax%7D%28%5Chat%7By%7D_i%29%7D%20%5Cright%20%29%2C%20%5Chspace%7B2em%7D%20w_i%20%3D%20%5Cleft%5C%7B%5Cbegin%7Bmatrix%7D%201%20%26%20%5Ctext%7Bif%20step%20%7D%20i%20%5Ctext%7B%20is%20an%20answer%7D%20%5C%5C%200%20%26%20%5Ctext%7Botherwise%7D%20%5Cend%7Bmatrix%7D%5Cright.) 12 | 13 | Where **|D|** is the number of unique lexicons in the data. 14 | 15 | All gradients are clipped between -10 and 10. 16 | 17 | ### Preprocessing 18 | 19 | 1. All words are brought to lower case, all numbers are omitted, and all punctuation marks except for {?, ., -} are omitted. A dictionary is then built of the remaining unique lexicons. 20 | 2. All the data are encoded as numeric input and output vectors. The input vector contains each word of the story as its numeric code in the dictionary, with all words representing answers replaced by the code for '-'. The output vector contains the codes for these replaced answer words. 21 | 3. All training data are joined together in one file and the testing data are kept separate. 22 | 23 | This whole process is implemented in the `preprocess.py` file, with the option to filter stories by sequence length as well as the option to save the training data separately. The script can also work with any version of the data, not just **en-10k**, simply by specifying the path to the desired data. 24 | 25 | ``` 26 | $python preprocess.py --data_dir [--single_train] [--length_limit] 27 | ``` 28 | 29 | ### Training 30 | 31 | At training time, a story is randomly sampled from the training data, encoded into a sequence of one-hot vectors, and the appropriate weights vector is generated. 
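Concretely, this encoding step can be sketched roughly as follows. The names here are illustrative; the actual logic lives in the `prepare_sample` function of `train.py`.

```python
import numpy as np

def onehot(index, size):
    vec = np.zeros(size, dtype=np.float32)
    vec[index] = 1.0
    return vec

def encode_story(story, answers, target_code, word_space_size):
    # story: word codes with every answer word already replaced by the code for '-'
    # answers: the codes of the replaced answer words, in story order
    inputs = np.array(story)
    targets = inputs.copy()
    weights = np.zeros(len(story), dtype=np.float32)

    answer_steps = (inputs == target_code)
    targets[answer_steps] = answers   # restore the true answers in the target sequence
    weights[answer_steps] = 1.0       # the loss only counts the answer steps

    input_vecs = np.array([onehot(int(c), word_space_size) for c in inputs])
    target_vecs = np.array([onehot(int(c), word_space_size) for c in targets])
    return input_vecs, target_vecs, weights
```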
32 | 33 | Training spanned about 500k iterations and took about 7 consecutive days on an AWS P2 instance with a Tesla K80 GPU and 16GB of RAM. The following plot shows the learning curve throughout training. 34 | 35 | ![learning-curve](../../assets/babi-training.png) 36 | 37 | ### Testing 38 | 39 | After the model was trained, it was tested against the test data for each separate task. An answer is chosen to be the most probable word in the softmax distribution of the output. A question is considered to be answered correctly if and only if all of its answer words were predicted correctly. If the model crosses a 5% error rate on some task, it's considered a failed task. A report is finally given comparing the resulting error rates to the mean values reported in the paper. 40 | 41 | This process is implemented in the `test.py` file. 42 | 43 | ### Re-training 44 | 45 | ``` 46 | $python preprocess.py --data_dir=path/to/en-10k 47 | $python train.py --iterations=500005 48 | $python test.py 49 | ``` 50 | 51 | #### *Caution!* 52 | *Because the provided model in checkpoints was trained on a GPU, attempting to restore it and run `test.py` on a CPU-only version of TensorFlow will result in a very high error rate. In order to get the results reported here, the model needs to be restored and run on a GPU-supported version of TensorFlow.* 53 | 54 | *I'm not sure why this happens, but it's probably due to the device placement choices done automatically by TensorFlow, as there's no explicit device placement in the implementation. If you have a solution to this problem, feel free to create a pull request with a fix or an issue describing how it could be fixed.* 55 | -------------------------------------------------------------------------------- /tasks/babi/checkpoints/step-500005/checkpoint: -------------------------------------------------------------------------------- 1 | model_checkpoint_path: "model.ckpt" 2 | all_model_checkpoint_paths: "model.ckpt" 3 | -------------------------------------------------------------------------------- /tasks/babi/checkpoints/step-500005/model.ckpt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Mostafa-Samir/DNC-tensorflow/5280f5993d8692f21a86ffde2a032dc660dbb693/tasks/babi/checkpoints/step-500005/model.ckpt -------------------------------------------------------------------------------- /tasks/babi/checkpoints/step-500005/model.ckpt.meta: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Mostafa-Samir/DNC-tensorflow/5280f5993d8692f21a86ffde2a032dc660dbb693/tasks/babi/checkpoints/step-500005/model.ckpt.meta -------------------------------------------------------------------------------- /tasks/babi/dnc: -------------------------------------------------------------------------------- 1 | ../../dnc -------------------------------------------------------------------------------- /tasks/babi/logs/events.out.tfevents.1483079568.ip-172-31-15-24: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Mostafa-Samir/DNC-tensorflow/5280f5993d8692f21a86ffde2a032dc660dbb693/tasks/babi/logs/events.out.tfevents.1483079568.ip-172-31-15-24 -------------------------------------------------------------------------------- /tasks/babi/logs/events.out.tfevents.1483218179.ip-172-31-15-24: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Mostafa-Samir/DNC-tensorflow/5280f5993d8692f21a86ffde2a032dc660dbb693/tasks/babi/logs/events.out.tfevents.1483218179.ip-172-31-15-24 -------------------------------------------------------------------------------- /tasks/babi/logs/events.out.tfevents.1483353893.ip-172-31-15-24: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Mostafa-Samir/DNC-tensorflow/5280f5993d8692f21a86ffde2a032dc660dbb693/tasks/babi/logs/events.out.tfevents.1483353893.ip-172-31-15-24 -------------------------------------------------------------------------------- /tasks/babi/logs/events.out.tfevents.1483491250.ip-172-31-15-24: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Mostafa-Samir/DNC-tensorflow/5280f5993d8692f21a86ffde2a032dc660dbb693/tasks/babi/logs/events.out.tfevents.1483491250.ip-172-31-15-24 -------------------------------------------------------------------------------- /tasks/babi/logs/events.out.tfevents.1483628556.ip-172-31-15-24: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Mostafa-Samir/DNC-tensorflow/5280f5993d8692f21a86ffde2a032dc660dbb693/tasks/babi/logs/events.out.tfevents.1483628556.ip-172-31-15-24 -------------------------------------------------------------------------------- /tasks/babi/preprocess.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import pickle 3 | import getopt 4 | import numpy as np 5 | from shutil import rmtree 6 | from os import listdir, mkdir 7 | from os.path import join, isfile, isdir, dirname, basename, normpath, abspath, exists 8 | 9 | def llprint(message): 10 | sys.stdout.write(message) 11 | sys.stdout.flush() 12 | 13 | def create_dictionary(files_list): 14 | """ 15 | creates a dictionary of unique lexicons in the dataset and their mapping to numbers 16 | 17 | Parameters: 18 | ---------- 19 | files_list: list 20 | the list of files to scan through 21 | 22 | Returns: dict 23 | the constructed dictionary of lexicons 24 | """ 25 | 26 | lexicons_dict = {} 27 | id_counter = 0 28 | 29 | llprint("Creating Dictionary ... 0/%d" % (len(files_list))) 30 | 31 | for indx, filename in enumerate(files_list): 32 | with open(filename, 'r') as fobj: 33 | for line in fobj: 34 | 35 | # first seperate . and ? away from words into seperate lexicons 36 | line = line.replace('.', ' .') 37 | line = line.replace('?', ' ?') 38 | line = line.replace(',', ' ') 39 | 40 | for word in line.split(): 41 | if not word.lower() in lexicons_dict and word.isalpha(): 42 | lexicons_dict[word.lower()] = id_counter 43 | id_counter += 1 44 | 45 | llprint("\rCreating Dictionary ... %d/%d" % ((indx + 1), len(files_list))) 46 | 47 | print "\rCreating Dictionary ... Done!" 
48 | return lexicons_dict 49 | 50 | 51 | def encode_data(files_list, lexicons_dictionary, length_limit=None): 52 | """ 53 | encodes the dataset into its numeric form given a constructed dictionary 54 | 55 | Parameters: 56 | ---------- 57 | files_list: list 58 | the list of files to scan through 59 | lexicons_dictionary: dict 60 | the mappings of unique lexicons 61 | 62 | Returns: tuple (dict, int) 63 | the data in its numeric form, maximum story length 64 | """ 65 | 66 | files = {} 67 | story_inputs = None 68 | story_outputs = None 69 | stories_lengths = [] 70 | answers_flag = False # a flag to specify when to put data into outputs list 71 | limit = length_limit if not length_limit is None else float("inf") 72 | 73 | llprint("Encoding Data ... 0/%d" % (len(files_list))) 74 | 75 | for indx, filename in enumerate(files_list): 76 | 77 | files[filename] = [] 78 | 79 | with open(filename, 'r') as fobj: 80 | for line in fobj: 81 | 82 | # first seperate . and ? away from words into seperate lexicons 83 | line = line.replace('.', ' .') 84 | line = line.replace('?', ' ?') 85 | line = line.replace(',', ' ') 86 | 87 | answers_flag = False # reset as answers end by end of line 88 | 89 | for i, word in enumerate(line.split()): 90 | 91 | if word == '1' and i == 0: 92 | # beginning of a new story 93 | if not story_inputs is None: 94 | stories_lengths.append(len(story_inputs)) 95 | if len(story_inputs) <= limit: 96 | files[filename].append({ 97 | 'inputs':story_inputs, 98 | 'outputs': story_outputs 99 | }) 100 | story_inputs = [] 101 | story_outputs = [] 102 | 103 | if word.isalpha() or word == '?' or word == '.': 104 | if not answers_flag: 105 | story_inputs.append(lexicons_dictionary[word.lower()]) 106 | else: 107 | story_inputs.append(lexicons_dictionary['-']) 108 | story_outputs.append(lexicons_dictionary[word.lower()]) 109 | 110 | # set the answers_flags if a question mark is encountered 111 | if not answers_flag: 112 | answers_flag = (word == '?') 113 | 114 | llprint("\rEncoding Data ... %d/%d" % (indx + 1, len(files_list))) 115 | 116 | print "\rEncoding Data ... Done!" 
117 | return files, stories_lengths 118 | 119 | 120 | if __name__ == '__main__': 121 | task_dir = dirname(abspath(__file__)) 122 | options,_ = getopt.getopt(sys.argv[1:], '', ['data_dir=', 'single_train', 'length_limit=']) 123 | data_dir = None 124 | joint_train = True 125 | length_limit = None 126 | files_list = [] 127 | 128 | if not exists(join(task_dir, 'data')): 129 | mkdir(join(task_dir, 'data')) 130 | 131 | for opt in options: 132 | if opt[0] == '--data_dir': 133 | data_dir = opt[1] 134 | if opt[0] == '--single_train': 135 | joint_train = False 136 | if opt[0] == '--length_limit': 137 | length_limit = int(opt[1]) 138 | 139 | if data_dir is None: 140 | raise ValueError("data_dir argument cannot be None") 141 | 142 | for entryname in listdir(data_dir): 143 | entry_path = join(data_dir, entryname) 144 | if isfile(entry_path): 145 | files_list.append(entry_path) 146 | 147 | lexicon_dictionary = create_dictionary(files_list) 148 | lexicon_count = len(lexicon_dictionary) 149 | 150 | # append used punctuation to dictionary 151 | lexicon_dictionary['?'] = lexicon_count 152 | lexicon_dictionary['.'] = lexicon_count + 1 153 | lexicon_dictionary['-'] = lexicon_count + 2 154 | 155 | encoded_files, stories_lengths = encode_data(files_list, lexicon_dictionary, length_limit) 156 | 157 | stories_lengths = np.array(stories_lengths) 158 | length_limit = np.max(stories_lengths) if length_limit is None else length_limit 159 | print "Total Number of stories: %d" % (len(stories_lengths)) 160 | print "Number of stories with lengthes > %d: %d (%% %.2f) [discarded]" % (length_limit, np.sum(stories_lengths > length_limit), np.mean(stories_lengths > length_limit) * 100.0) 161 | print "Number of Remaining Stories: %d" % (len(stories_lengths[stories_lengths <= length_limit])) 162 | 163 | processed_data_dir = join(task_dir, 'data', basename(normpath(data_dir))) 164 | train_data_dir = join(processed_data_dir, 'train') 165 | test_data_dir = join(processed_data_dir, 'test') 166 | if exists(processed_data_dir) and isdir(processed_data_dir): 167 | rmtree(processed_data_dir) 168 | 169 | mkdir(processed_data_dir) 170 | mkdir(train_data_dir) 171 | mkdir(test_data_dir) 172 | 173 | llprint("Saving processed data to disk ... 
") 174 | 175 | pickle.dump(lexicon_dictionary, open(join(processed_data_dir, 'lexicon-dict.pkl'), 'wb')) 176 | 177 | joint_train_data = [] 178 | 179 | for filename in encoded_files: 180 | if filename.endswith("test.txt"): 181 | pickle.dump(encoded_files[filename], open(join(test_data_dir, basename(filename) + '.pkl'), 'wb')) 182 | elif filename.endswith("train.txt"): 183 | if not joint_train: 184 | pickle.dump(encoded_files[filename], open(join(train_data_dir, basename(filename) + '.pkl'), 'wb')) 185 | else: 186 | joint_train_data.extend(encoded_files[filename]) 187 | 188 | if joint_train: 189 | pickle.dump(joint_train_data, open(join(train_data_dir, 'train.pkl'), 'wb')) 190 | 191 | llprint("Done!\n") 192 | -------------------------------------------------------------------------------- /tasks/babi/recurrent_controller.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | from dnc.controller import BaseController 4 | 5 | """ 6 | A 1-layer LSTM recurrent neural network with 256 hidden units 7 | Note: the state of the LSTM is not saved in a variable becuase we want 8 | the state to reset to zero on every input sequnece 9 | """ 10 | 11 | class RecurrentController(BaseController): 12 | 13 | def network_vars(self): 14 | self.lstm_cell = tf.nn.rnn_cell.BasicLSTMCell(256) 15 | self.state = self.lstm_cell.zero_state(self.batch_size, tf.float32) 16 | 17 | def network_op(self, X, state): 18 | X = tf.convert_to_tensor(X) 19 | return self.lstm_cell(X, state) 20 | 21 | def get_state(self): 22 | return self.state 23 | 24 | def update_state(self, new_state): 25 | return tf.no_op() 26 | -------------------------------------------------------------------------------- /tasks/babi/test.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from recurrent_controller import RecurrentController 4 | from dnc.dnc import DNC 5 | import tensorflow as tf 6 | import numpy as np 7 | import pickle 8 | import sys 9 | import os 10 | import re 11 | 12 | def llprint(message): 13 | sys.stdout.write(message) 14 | sys.stdout.flush() 15 | 16 | def load(path): 17 | return pickle.load(open(path, 'rb')) 18 | 19 | def onehot(index, size): 20 | vec = np.zeros(size, dtype=np.float32) 21 | vec[index] = 1.0 22 | return vec 23 | 24 | def prepare_sample(sample, target_code, word_space_size): 25 | input_vec = np.array(sample[0]['inputs'], dtype=np.float32) 26 | output_vec = np.array(sample[0]['inputs'], dtype=np.float32) 27 | seq_len = input_vec.shape[0] 28 | weights_vec = np.zeros(seq_len, dtype=np.float32) 29 | 30 | target_mask = (input_vec == target_code) 31 | output_vec[target_mask] = sample[0]['outputs'] 32 | weights_vec[target_mask] = 1.0 33 | 34 | input_vec = np.array([onehot(code, word_space_size) for code in input_vec]) 35 | output_vec = np.array([onehot(code, word_space_size) for code in output_vec]) 36 | 37 | return ( 38 | np.reshape(input_vec, (1, -1, word_space_size)), 39 | np.reshape(output_vec, (1, -1, word_space_size)), 40 | seq_len, 41 | np.reshape(weights_vec, (1, -1, 1)) 42 | ) 43 | 44 | ckpts_dir = './checkpoints/' 45 | lexicon_dictionary = load('./data/en-10k/lexicon-dict.pkl') 46 | question_code = lexicon_dictionary["?"] 47 | target_code = lexicon_dictionary["-"] 48 | test_files = [] 49 | 50 | for entryname in os.listdir('./data/en-10k/test/'): 51 | entry_path = os.path.join('./data/en-10k/test/', entryname) 52 | if os.path.isfile(entry_path): 53 | 
test_files.append(entry_path) 54 | 55 | graph = tf.Graph() 56 | with graph.as_default(): 57 | with tf.Session(graph=graph) as session: 58 | 59 | ncomputer = DNC( 60 | RecurrentController, 61 | input_size=len(lexicon_dictionary), 62 | output_size=len(lexicon_dictionary), 63 | max_sequence_length=100, 64 | memory_words_num=256, 65 | memory_word_size=64, 66 | memory_read_heads=4, 67 | ) 68 | 69 | ncomputer.restore(session, ckpts_dir, 'step-500005') 70 | 71 | outputs, _ = ncomputer.get_outputs() 72 | softmaxed = tf.nn.softmax(outputs) 73 | 74 | tasks_results = {} 75 | tasks_names = {} 76 | for test_file in test_files: 77 | test_data = load(test_file) 78 | task_regexp = r'qa([0-9]{1,2})_([a-z\-]*)_test.txt.pkl' 79 | task_filename = os.path.basename(test_file) 80 | task_match_obj = re.match(task_regexp, task_filename) 81 | task_number = task_match_obj.group(1) 82 | task_name = task_match_obj.group(2).replace('-', ' ') 83 | tasks_names[task_number] = task_name 84 | counter = 0 85 | results = [] 86 | 87 | llprint("%s ... %d/%d" % (task_name, counter, len(test_data))) 88 | 89 | for story in test_data: 90 | astory = np.array(story['inputs']) 91 | questions_indecies = np.argwhere(astory == question_code) 92 | questions_indecies = np.reshape(questions_indecies, (-1,)) 93 | target_mask = (astory == target_code) 94 | 95 | desired_answers = np.array(story['outputs']) 96 | input_vec, _, seq_len, _ = prepare_sample([story], target_code, len(lexicon_dictionary)) 97 | softmax_output = session.run(softmaxed, feed_dict={ 98 | ncomputer.input_data: input_vec, 99 | ncomputer.sequence_length: seq_len 100 | }) 101 | 102 | softmax_output = np.squeeze(softmax_output, axis=0) 103 | given_answers = np.argmax(softmax_output[target_mask], axis=1) 104 | 105 | 106 | answers_cursor = 0 107 | for question_indx in questions_indecies: 108 | question_grade = [] 109 | targets_cursor = question_indx + 1 110 | while targets_cursor < len(astory) and astory[targets_cursor] == target_code: 111 | question_grade.append(given_answers[answers_cursor] == desired_answers[answers_cursor]) 112 | answers_cursor += 1 113 | targets_cursor += 1 114 | results.append(np.prod(question_grade)) 115 | 116 | counter += 1 117 | llprint("\r%s ... %d/%d" % (task_name, counter, len(test_data))) 118 | 119 | error_rate = 1. - np.mean(results) 120 | tasks_results[task_number] = error_rate 121 | llprint("\r%s ... 
%.3f%% Error Rate.\n" % (task_name, error_rate * 100)) 122 | 123 | print "\n" 124 | print "%-27s%-27s%s" % ("Task", "Result", "Paper's Mean") 125 | print "-------------------------------------------------------------------" 126 | paper_means = { 127 | '1': '9.0±12.6%', '2': '39.2±20.5%', '3': '39.6±16.4%', 128 | '4': '0.4±0.7%', '5': '1.5±1.0%', '6': '6.9±7.5%', '7': '9.8±7.0%', 129 | '8': '5.5±5.9%', '9': '7.7±8.3%', '10': '9.6±11.4%', '11':'3.3±5.7%', 130 | '12': '5.0±6.3%', '13': '3.1±3.6%', '14': '11.0±7.5%', '15': '27.2±20.1%', 131 | '16': '53.6±1.9%', '17': '32.4±8.0%', '18': '4.2±1.8%', '19': '64.6±37.4%', 132 | '20': '0.0±0.1%', 'mean': '16.7±7.6%', 'fail': '11.2±5.4' 133 | } 134 | for k in range(20): 135 | task_id = str(k + 1) 136 | task_result = "%.2f%%" % (tasks_results[task_id] * 100) 137 | print "%-27s%-27s%s" % (tasks_names[task_id], task_result, paper_means[task_id]) 138 | print "-------------------------------------------------------------------" 139 | all_tasks_results = [v for _,v in tasks_results.iteritems()] 140 | results_mean = "%.2f%%" % (np.mean(all_tasks_results) * 100) 141 | failed_count = "%d" % (np.sum(np.array(all_tasks_results) > 0.05)) 142 | 143 | print "%-27s%-27s%s" % ("Mean Err.", results_mean, paper_means['mean']) 144 | print "%-27s%-27s%s" % ("Failed (err. > 5%)", failed_count, paper_means['fail']) 145 | -------------------------------------------------------------------------------- /tasks/babi/train.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | warnings.filterwarnings('ignore') 3 | 4 | import tensorflow as tf 5 | import numpy as np 6 | import pickle 7 | import getopt 8 | import time 9 | import sys 10 | import os 11 | 12 | from dnc.dnc import DNC 13 | from recurrent_controller import RecurrentController 14 | 15 | def llprint(message): 16 | sys.stdout.write(message) 17 | sys.stdout.flush() 18 | 19 | def load(path): 20 | return pickle.load(open(path, 'rb')) 21 | 22 | def onehot(index, size): 23 | vec = np.zeros(size, dtype=np.float32) 24 | vec[index] = 1.0 25 | return vec 26 | 27 | def prepare_sample(sample, target_code, word_space_size): 28 | input_vec = np.array(sample[0]['inputs'], dtype=np.float32) 29 | output_vec = np.array(sample[0]['inputs'], dtype=np.float32) 30 | seq_len = input_vec.shape[0] 31 | weights_vec = np.zeros(seq_len, dtype=np.float32) 32 | 33 | target_mask = (input_vec == target_code) 34 | output_vec[target_mask] = sample[0]['outputs'] 35 | weights_vec[target_mask] = 1.0 36 | 37 | input_vec = np.array([onehot(code, word_space_size) for code in input_vec]) 38 | output_vec = np.array([onehot(code, word_space_size) for code in output_vec]) 39 | 40 | return ( 41 | np.reshape(input_vec, (1, -1, word_space_size)), 42 | np.reshape(output_vec, (1, -1, word_space_size)), 43 | seq_len, 44 | np.reshape(weights_vec, (1, -1, 1)) 45 | ) 46 | 47 | 48 | 49 | if __name__ == '__main__': 50 | 51 | dirname = os.path.dirname(__file__) 52 | ckpts_dir = os.path.join(dirname , 'checkpoints') 53 | data_dir = os.path.join(dirname, 'data', 'en-10k') 54 | tb_logs_dir = os.path.join(dirname, 'logs') 55 | 56 | llprint("Loading Data ... 
") 57 | lexicon_dict = load(os.path.join(data_dir, 'lexicon-dict.pkl')) 58 | data = load(os.path.join(data_dir, 'train', 'train.pkl')) 59 | llprint("Done!\n") 60 | 61 | batch_size = 1 62 | input_size = output_size = len(lexicon_dict) 63 | sequence_max_length = 100 64 | word_space_size = len(lexicon_dict) 65 | words_count = 256 66 | word_size = 64 67 | read_heads = 4 68 | 69 | learning_rate = 1e-4 70 | momentum = 0.9 71 | 72 | from_checkpoint = None 73 | iterations = 100000 74 | start_step = 0 75 | 76 | options,_ = getopt.getopt(sys.argv[1:], '', ['checkpoint=', 'iterations=', 'start=']) 77 | 78 | for opt in options: 79 | if opt[0] == '--checkpoint': 80 | from_checkpoint = opt[1] 81 | elif opt[0] == '--iterations': 82 | iterations = int(opt[1]) 83 | elif opt[0] == '--start': 84 | start_step = int(opt[1]) 85 | 86 | graph = tf.Graph() 87 | with graph.as_default(): 88 | with tf.Session(graph=graph) as session: 89 | 90 | llprint("Building Computational Graph ... ") 91 | 92 | optimizer = tf.train.RMSPropOptimizer(learning_rate, momentum=momentum) 93 | summerizer = tf.train.SummaryWriter(tb_logs_dir, session.graph) 94 | 95 | ncomputer = DNC( 96 | RecurrentController, 97 | input_size, 98 | output_size, 99 | sequence_max_length, 100 | words_count, 101 | word_size, 102 | read_heads, 103 | batch_size 104 | ) 105 | 106 | output, _ = ncomputer.get_outputs() 107 | 108 | loss_weights = tf.placeholder(tf.float32, [batch_size, None, 1]) 109 | loss = tf.reduce_mean( 110 | loss_weights * tf.nn.softmax_cross_entropy_with_logits(output, ncomputer.target_output) 111 | ) 112 | 113 | summeries = [] 114 | 115 | gradients = optimizer.compute_gradients(loss) 116 | for i, (grad, var) in enumerate(gradients): 117 | if grad is not None: 118 | gradients[i] = (tf.clip_by_value(grad, -10, 10), var) 119 | for (grad, var) in gradients: 120 | if grad is not None: 121 | summeries.append(tf.histogram_summary(var.name + '/grad', grad)) 122 | 123 | apply_gradients = optimizer.apply_gradients(gradients) 124 | 125 | summeries.append(tf.scalar_summary("Loss", loss)) 126 | 127 | summerize_op = tf.merge_summary(summeries) 128 | no_summerize = tf.no_op() 129 | 130 | llprint("Done!\n") 131 | 132 | llprint("Initializing Variables ... ") 133 | session.run(tf.initialize_all_variables()) 134 | llprint("Done!\n") 135 | 136 | if from_checkpoint is not None: 137 | llprint("Restoring Checkpoint %s ... " % (from_checkpoint)) 138 | ncomputer.restore(session, ckpts_dir, from_checkpoint) 139 | llprint("Done!\n") 140 | 141 | 142 | last_100_losses = [] 143 | 144 | start = 0 if start_step == 0 else start_step + 1 145 | end = start_step + iterations + 1 146 | 147 | start_time_100 = time.time() 148 | end_time_100 = None 149 | avg_100_time = 0. 
150 | avg_counter = 0 151 | 152 | for i in xrange(start, end + 1): 153 | try: 154 | llprint("\rIteration %d/%d" % (i, end)) 155 | 156 | sample = np.random.choice(data, 1) 157 | input_data, target_output, seq_len, weights = prepare_sample(sample, lexicon_dict['-'], word_space_size) 158 | 159 | summerize = (i % 100 == 0) 160 | take_checkpoint = (i != 0) and (i % end == 0) 161 | 162 | loss_value, _, summary = session.run([ 163 | loss, 164 | apply_gradients, 165 | summerize_op if summerize else no_summerize 166 | ], feed_dict={ 167 | ncomputer.input_data: input_data, 168 | ncomputer.target_output: target_output, 169 | ncomputer.sequence_length: seq_len, 170 | loss_weights: weights 171 | }) 172 | 173 | last_100_losses.append(loss_value) 174 | summerizer.add_summary(summary, i) 175 | 176 | if summerize: 177 | llprint("\n\tAvg. Cross-Entropy: %.7f\n" % (np.mean(last_100_losses))) 178 | 179 | end_time_100 = time.time() 180 | elapsed_time = (end_time_100 - start_time_100) / 60 181 | avg_counter += 1 182 | avg_100_time += (1. / avg_counter) * (elapsed_time - avg_100_time) 183 | estimated_time = (avg_100_time * ((end - i) / 100.)) / 60. 184 | 185 | print "\tAvg. 100 iterations time: %.2f minutes" % (avg_100_time) 186 | print "\tApprox. time to completion: %.2f hours" % (estimated_time) 187 | 188 | start_time_100 = time.time() 189 | last_100_losses = [] 190 | 191 | if take_checkpoint: 192 | llprint("\nSaving Checkpoint ... "), 193 | ncomputer.save(session, ckpts_dir, 'step-%d' % (i)) 194 | llprint("Done!\n") 195 | 196 | except KeyboardInterrupt: 197 | 198 | llprint("\nSaving Checkpoint ... "), 199 | ncomputer.save(session, ckpts_dir, 'step-%d' % (i)) 200 | llprint("Done!\n") 201 | sys.exit(0) 202 | -------------------------------------------------------------------------------- /tasks/copy/README.md: -------------------------------------------------------------------------------- 1 | ### Common Settings 2 | 3 | Both the series and the single models were trained with a 2-layer feedforward controller (with hidden sizes 128 and 256 respectively) and ReLU activations, and both share the following set of hyperparameters: 4 | 5 | - RMSProp optimizer with a learning rate of 10⁻⁴ and a momentum of 0.9. 6 | - Memory word size of 10, with a single read head. 7 | - Controller weights are initialized from samples 1 standard deviation away from a zero-mean normal distribution with a variance ![](https://latex.codecogs.com/gif.latex?%5Cinline%20%5Csigma%5E2%20%3D%20%5Ctext%7Bmin%7D%5Chspace%7B0.2em%7D%5Cleft%281%5Ctimes10%5E%7B-4%7D%2C%20%5Cfrac%7B2%7D%7Bn%7D%5Cright%29), where ![](https://latex.codecogs.com/gif.latex?%5Cinline%20n) is the size of the input vector coming into the weight matrix. 8 | - A batch size of 1. 9 | 10 | All output from the DNC is squashed between 0 and 1 using a sigmoid function, and a binary cross-entropy (or logistic) loss function of the form: 11 | 12 | ![loss](https://latex.codecogs.com/gif.latex?%5Cmathcal%7BL%7D%28y%2C%20%5Chat%7By%7D%29%20%3D%20-%5Cfrac%7B1%7D%7BBTS%7D%5Csum_%7Bi%3D1%7D%5E%7BB%7D%5Csum_%7Bj%3D1%7D%5E%7BT%7D%5Csum_%7Bk%3D1%7D%5ES%5Cleft%28%20y_%7Bijk%7D%5Clog%20%5Chat%7By%7D_%7Bijk%7D%20+%20%281%20-%20y_%7Bijk%7D%29%5Clog%281-%5Chat%7By%7D_%7Bijk%7D%29%20%5Cright%29) 13 | 14 | is used. That is, the mean of the logistic loss across the batch, time steps, and output size. 15 | 16 | All gradients are clipped between -10 and 10. 
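For reference, these settings translate to graph-construction code along the lines of the following sketch, which mirrors `tasks/copy/train.py`; here `output` is the raw DNC output op and `ncomputer` is the DNC instance, both assumed to be already defined.

```python
import tensorflow as tf

def binary_cross_entropy(predictions, targets):
    # mean of the logistic loss over the batch, the time steps, and the output size
    return tf.reduce_mean(
        -1 * targets * tf.log(predictions) - (1 - targets) * tf.log(1 - predictions)
    )

optimizer = tf.train.RMSPropOptimizer(1e-4, momentum=0.9)

# sigmoid-squash the DNC output and keep it away from exact 0 and 1
squashed_output = tf.clip_by_value(tf.sigmoid(output), 1e-6, 1. - 1e-6)
loss = binary_cross_entropy(squashed_output, ncomputer.target_output)

# clip every defined gradient to [-10, 10] before applying it
gradients = optimizer.compute_gradients(loss)
for i, (grad, var) in enumerate(gradients):
    if grad is not None:
        gradients[i] = (tf.clip_by_value(grad, -10, 10), var)
apply_gradients = optimizer.apply_gradients(gradients)
```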
17 | 18 | *Possible __NaNs__ could occur during training!* 19 | 20 | 21 | ### Series Training 22 | 23 | The model was first trained on a length-2 series of random binary vectors of size 6. Then starting off from the length-2 learned model, a length-4 model was trained in a **curriculum learning** fashion. 24 | 25 | The following plots show the learning curves for the length-2 and length-4 models respectively. 26 | 27 | ![series-2](/assets/model-series-2-curve.png) 28 | 29 | ![series-4](/assets/model-series-4-curve.png) 30 | 31 | *Attempting to train a length-4 model directly always resulted in __NaNs__. The paper mentioned using curriculum learning for the graph and mini-SHRDLU tasks, but it did not mention any thing about the copy task, so there's a possibility that this is not the most efficient method.* 32 | 33 | #### Retraining 34 | ``` 35 | $python tasks/copy/train-series.py --length=2 36 | ``` 37 | Then, assuming that the trained model from that execution is saved under the name 'step-100000'. 38 | 39 | ``` 40 | $python tasks/copy/train-series.py --length=4 --checkpoint=step-100000 --iterations=20000 41 | ``` 42 | 43 | ### Single Training 44 | 45 | The model was trained directly on a single input of length between 1 and 10 and the length was chosen randomly at each run, so no curriculum learning was used. The following plot shows the learning curve of the single model. 46 | 47 | ![single-10](/assets/model-single-curve.png) 48 | 49 | #### Retraining 50 | 51 | ``` 52 | $python tasks/copy/train.py --iterations=50000 53 | ``` 54 | -------------------------------------------------------------------------------- /tasks/copy/checkpoints/model-series-2/checkpoint: -------------------------------------------------------------------------------- 1 | model_checkpoint_path: "model.ckpt" 2 | all_model_checkpoint_paths: "model.ckpt" 3 | -------------------------------------------------------------------------------- /tasks/copy/checkpoints/model-series-2/model.ckpt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Mostafa-Samir/DNC-tensorflow/5280f5993d8692f21a86ffde2a032dc660dbb693/tasks/copy/checkpoints/model-series-2/model.ckpt -------------------------------------------------------------------------------- /tasks/copy/checkpoints/model-series-2/model.ckpt.meta: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Mostafa-Samir/DNC-tensorflow/5280f5993d8692f21a86ffde2a032dc660dbb693/tasks/copy/checkpoints/model-series-2/model.ckpt.meta -------------------------------------------------------------------------------- /tasks/copy/checkpoints/model-series-4/checkpoint: -------------------------------------------------------------------------------- 1 | model_checkpoint_path: "model.ckpt" 2 | all_model_checkpoint_paths: "model.ckpt" 3 | -------------------------------------------------------------------------------- /tasks/copy/checkpoints/model-series-4/model.ckpt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Mostafa-Samir/DNC-tensorflow/5280f5993d8692f21a86ffde2a032dc660dbb693/tasks/copy/checkpoints/model-series-4/model.ckpt -------------------------------------------------------------------------------- /tasks/copy/checkpoints/model-series-4/model.ckpt.meta: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Mostafa-Samir/DNC-tensorflow/5280f5993d8692f21a86ffde2a032dc660dbb693/tasks/copy/checkpoints/model-series-4/model.ckpt.meta -------------------------------------------------------------------------------- /tasks/copy/checkpoints/model-single-10/checkpoint: -------------------------------------------------------------------------------- 1 | model_checkpoint_path: "model.ckpt" 2 | all_model_checkpoint_paths: "model.ckpt" 3 | -------------------------------------------------------------------------------- /tasks/copy/checkpoints/model-single-10/model.ckpt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Mostafa-Samir/DNC-tensorflow/5280f5993d8692f21a86ffde2a032dc660dbb693/tasks/copy/checkpoints/model-single-10/model.ckpt -------------------------------------------------------------------------------- /tasks/copy/checkpoints/model-single-10/model.ckpt.meta: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Mostafa-Samir/DNC-tensorflow/5280f5993d8692f21a86ffde2a032dc660dbb693/tasks/copy/checkpoints/model-single-10/model.ckpt.meta -------------------------------------------------------------------------------- /tasks/copy/dnc: -------------------------------------------------------------------------------- 1 | ../../dnc/ -------------------------------------------------------------------------------- /tasks/copy/feedforward_controller.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | from dnc.controller import BaseController 4 | 5 | 6 | """ 7 | A 2-Layers feedforward neural network with 128, 256 nodes respectively 8 | """ 9 | 10 | class FeedforwardController(BaseController): 11 | 12 | def network_vars(self): 13 | initial_std = lambda in_nodes: np.min(1e-2, np.sqrt(2.0 / in_nodes)) 14 | input_ = self.nn_input_size 15 | 16 | self.W1 = tf.Variable(tf.truncated_normal([input_, 128], stddev=initial_std(input_)), name='layer1_W') 17 | self.W2 = tf.Variable(tf.truncated_normal([128, 256], stddev=initial_std(128)), name='layer2_W') 18 | self.b1 = tf.Variable(tf.zeros([128]), name='layer1_b') 19 | self.b2 = tf.Variable(tf.zeros([256]), name='layer2_b') 20 | 21 | 22 | def network_op(self, X): 23 | l1_output = tf.matmul(X, self.W1) + self.b1 24 | l1_activation = tf.nn.relu(l1_output) 25 | 26 | l2_output = tf.matmul(l1_activation, self.W2) + self.b2 27 | l2_activation = tf.nn.relu(l2_output) 28 | 29 | return l2_activation 30 | 31 | def initials(self): 32 | initial_std = lambda in_nodes: np.min(1e-2, np.sqrt(2.0 / in_nodes)) 33 | 34 | # defining internal weights of the controller 35 | self.interface_weights = tf.Variable( 36 | tf.truncated_normal([self.nn_output_size, self.interface_vector_size], stddev=initial_std(self.nn_output_size)), 37 | name='interface_weights' 38 | ) 39 | self.nn_output_weights = tf.Variable( 40 | tf.truncated_normal([self.nn_output_size, self.output_size], stddev=initial_std(self.nn_output_size)), 41 | name='nn_output_weights' 42 | ) 43 | self.mem_output_weights = tf.Variable( 44 | tf.truncated_normal([self.word_size * self.read_heads, self.output_size], stddev=initial_std(self.word_size * self.read_heads)), 45 | name='mem_output_weights' 46 | ) 47 | -------------------------------------------------------------------------------- /tasks/copy/train-series.py: 
-------------------------------------------------------------------------------- 1 | import warnings 2 | warnings.filterwarnings('ignore') 3 | 4 | import tensorflow as tf 5 | import numpy as np 6 | import getopt 7 | import sys 8 | import os 9 | 10 | from dnc.dnc import DNC 11 | from feedforward_controller import FeedforwardController 12 | 13 | def llprint(message): 14 | sys.stdout.write(message) 15 | sys.stdout.flush() 16 | 17 | def generate_data(batch_size, length, size): 18 | 19 | input_data = np.zeros((batch_size, 2 * length + 1, size), dtype=np.float32) 20 | target_output = np.zeros((batch_size, 2 * length + 1, size), dtype=np.float32) 21 | 22 | sequence = np.random.binomial(1, 0.5, (batch_size, length, size - 1)) 23 | 24 | input_data[:, :length, :size - 1] = sequence 25 | input_data[:, length, -1] = 1 # the end symbol 26 | target_output[:, length + 1:, :size - 1] = sequence 27 | 28 | return input_data, target_output 29 | 30 | 31 | def binary_cross_entropy(predictions, targets): 32 | 33 | return tf.reduce_mean( 34 | -1 * targets * tf.log(predictions) - (1 - targets) * tf.log(1 - predictions) 35 | ) 36 | 37 | 38 | if __name__ == '__main__': 39 | 40 | dirname = os.path.dirname(__file__) 41 | ckpts_dir = os.path.join(dirname , 'checkpoints') 42 | tb_logs_dir = os.path.join(dirname, 'logs') 43 | 44 | batch_size = 1 45 | input_size = output_size = 6 46 | series_length = 2 47 | sequence_max_length = 22 48 | words_count = 10 49 | word_size = 10 50 | read_heads = 1 51 | 52 | learning_rate = 1e-4 53 | momentum = 0.9 54 | 55 | from_checkpoint = None 56 | iterations = 100000 57 | start_step = 0 58 | 59 | options,_ = getopt.getopt(sys.argv[1:], '', ['checkpoint=', 'iterations=', 'start=', 'length=']) 60 | 61 | for opt in options: 62 | if opt[0] == '--checkpoint': 63 | from_checkpoint = opt[1] 64 | elif opt[0] == '--iterations': 65 | iterations = int(opt[1]) 66 | elif opt[0] == '--start': 67 | start_step = int(opt[1]) 68 | elif opt[0] == '--length': 69 | series_length = int(opt[1]) 70 | sequence_max_length = 11 * int(opt[1]) 71 | 72 | graph = tf.Graph() 73 | 74 | with graph.as_default(): 75 | with tf.Session(graph=graph) as session: 76 | 77 | llprint("Building Computational Graph ... ") 78 | 79 | optimizer = tf.train.RMSPropOptimizer(learning_rate, momentum=momentum) 80 | summerizer = tf.train.SummaryWriter(tb_logs_dir, session.graph) 81 | 82 | ncomputer = DNC( 83 | FeedforwardController, 84 | input_size, 85 | output_size, 86 | sequence_max_length, 87 | words_count, 88 | word_size, 89 | read_heads, 90 | batch_size 91 | ) 92 | 93 | output, _ = ncomputer.get_outputs() 94 | squashed_output = tf.clip_by_value(tf.sigmoid(output), 1e-6, 1. - 1e-6) 95 | 96 | loss = binary_cross_entropy(squashed_output, ncomputer.target_output) 97 | 98 | summeries = [] 99 | 100 | gradients = optimizer.compute_gradients(loss) 101 | for i, (grad, var) in enumerate(gradients): 102 | if grad is not None: 103 | summeries.append(tf.histogram_summary(var.name + '/grad', grad)) 104 | gradients[i] = (tf.clip_by_value(grad, -10, 10), var) 105 | 106 | apply_gradients = optimizer.apply_gradients(gradients) 107 | 108 | summeries.append(tf.scalar_summary("Loss", loss)) 109 | 110 | summerize_op = tf.merge_summary(summeries) 111 | no_summerize = tf.no_op() 112 | 113 | llprint("Done!\n") 114 | 115 | llprint("Initializing Variables ... ") 116 | session.run(tf.initialize_all_variables()) 117 | llprint("Done!\n") 118 | 119 | if from_checkpoint is not None: 120 | llprint("Restoring Checkpoint %s ... 
" % (from_checkpoint)) 121 | ncomputer.restore(session, ckpts_dir, from_checkpoint) 122 | llprint("Done!\n") 123 | 124 | 125 | last_100_losses = [] 126 | 127 | start = 0 if start_step == 0 else start_step + 1 128 | end = start_step + iterations + 1 129 | 130 | for i in xrange(start, end): 131 | llprint("\rIteration %d/%d" % (i, end - 1)) 132 | 133 | input_series = [] 134 | output_series = [] 135 | 136 | for k in range(series_length): 137 | input_data, target_output = generate_data(batch_size, 5, input_size) 138 | input_series.append(input_data) 139 | output_series.append(target_output) 140 | 141 | one_big_input = np.concatenate(input_series, axis=1) 142 | one_big_output = np.concatenate(output_series, axis=1) 143 | 144 | summerize = (i % 100 == 0) 145 | take_checkpoint = (i != 0) and (i % iterations == 0) 146 | 147 | loss_value, _, summary = session.run([ 148 | loss, 149 | apply_gradients, 150 | summerize_op if summerize else no_summerize 151 | ], feed_dict={ 152 | ncomputer.input_data: one_big_input, 153 | ncomputer.target_output: one_big_output, 154 | ncomputer.sequence_length: sequence_max_length 155 | }) 156 | 157 | last_100_losses.append(loss_value) 158 | summerizer.add_summary(summary, i) 159 | 160 | if summerize: 161 | llprint("\n\tAvg. Logistic Loss: %.4f\n" % (np.mean(last_100_losses))) 162 | last_100_losses = [] 163 | 164 | if take_checkpoint: 165 | llprint("\nSaving Checkpoint ... "), 166 | ncomputer.save(session, ckpts_dir, 'step-%d' % (i)) 167 | llprint("Done!\n") 168 | -------------------------------------------------------------------------------- /tasks/copy/train.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | warnings.filterwarnings('ignore') 3 | 4 | import tensorflow as tf 5 | import numpy as np 6 | import getopt 7 | import sys 8 | import os 9 | 10 | from dnc.dnc import DNC 11 | from feedforward_controller import FeedforwardController 12 | 13 | def llprint(message): 14 | sys.stdout.write(message) 15 | sys.stdout.flush() 16 | 17 | def generate_data(batch_size, length, size): 18 | 19 | input_data = np.zeros((batch_size, 2 * length + 1, size), dtype=np.float32) 20 | target_output = np.zeros((batch_size, 2 * length + 1, size), dtype=np.float32) 21 | 22 | sequence = np.random.binomial(1, 0.5, (batch_size, length, size - 1)) 23 | 24 | input_data[:, :length, :size - 1] = sequence 25 | input_data[:, length, -1] = 1 # the end symbol 26 | target_output[:, length + 1:, :size - 1] = sequence 27 | 28 | return input_data, target_output 29 | 30 | 31 | def binary_cross_entropy(predictions, targets): 32 | 33 | return tf.reduce_mean( 34 | -1 * targets * tf.log(predictions) - (1 - targets) * tf.log(1 - predictions) 35 | ) 36 | 37 | 38 | if __name__ == '__main__': 39 | 40 | dirname = os.path.dirname(__file__) 41 | ckpts_dir = os.path.join(dirname , 'checkpoints') 42 | tb_logs_dir = os.path.join(dirname, 'logs') 43 | 44 | batch_size = 1 45 | input_size = output_size = 6 46 | sequence_max_length = 10 47 | words_count = 15 48 | word_size = 10 49 | read_heads = 1 50 | 51 | learning_rate = 1e-4 52 | momentum = 0.9 53 | 54 | from_checkpoint = None 55 | iterations = 100000 56 | 57 | options,_ = getopt.getopt(sys.argv[1:], '', ['checkpoint=', 'iterations=']) 58 | 59 | for opt in options: 60 | if opt[0] == '--checkpoint': 61 | from_checkpoint = opt[1] 62 | elif opt[0] == '--iterations': 63 | iterations = int(opt[1]) 64 | 65 | graph = tf.Graph() 66 | 67 | with graph.as_default(): 68 | with tf.Session(graph=graph) as session: 69 | 70 | 
llprint("Building Computational Graph ... ") 71 | 72 | optimizer = tf.train.RMSPropOptimizer(learning_rate, momentum=momentum) 73 | 74 | ncomputer = DNC( 75 | FeedforwardController, 76 | input_size, 77 | output_size, 78 | 2 * sequence_max_length + 1, 79 | words_count, 80 | word_size, 81 | read_heads, 82 | batch_size 83 | ) 84 | 85 | # squash the DNC output between 0 and 1 86 | output, _ = ncomputer.get_outputs() 87 | squashed_output = tf.clip_by_value(tf.sigmoid(output), 1e-6, 1. - 1e-6) 88 | 89 | loss = binary_cross_entropy(squashed_output, ncomputer.target_output) 90 | 91 | summeries = [] 92 | 93 | gradients = optimizer.compute_gradients(loss) 94 | for i, (grad, var) in enumerate(gradients): 95 | if grad is not None: 96 | summeries.append(tf.histogram_summary(var.name + '/grad', grad)) 97 | gradients[i] = (tf.clip_by_value(grad, -10, 10), var) 98 | 99 | apply_gradients = optimizer.apply_gradients(gradients) 100 | 101 | summeries.append(tf.scalar_summary("Loss", loss)) 102 | 103 | summerize_op = tf.merge_summary(summeries) 104 | no_summerize = tf.no_op() 105 | 106 | summerizer = tf.train.SummaryWriter(tb_logs_dir, session.graph) 107 | 108 | llprint("Done!\n") 109 | 110 | llprint("Initializing Variables ... ") 111 | session.run(tf.initialize_all_variables()) 112 | llprint("Done!\n") 113 | 114 | if from_checkpoint is not None: 115 | llprint("Restoring Checkpoint %s ... " % (from_checkpoint)) 116 | ncomputer.restore(session, ckpts_dir, from_checkpoint) 117 | llprint("Done!\n") 118 | 119 | 120 | last_100_losses = [] 121 | 122 | for i in xrange(iterations + 1): 123 | llprint("\rIteration %d/%d" % (i, iterations)) 124 | 125 | random_length = np.random.randint(1, sequence_max_length + 1) 126 | input_data, target_output = generate_data(batch_size, random_length, input_size) 127 | 128 | summerize = (i % 100 == 0) 129 | take_checkpoint = (i != 0) and (i % iterations == 0) 130 | 131 | loss_value, _, summary = session.run([ 132 | loss, 133 | apply_gradients, 134 | summerize_op if summerize else no_summerize 135 | ], feed_dict={ 136 | ncomputer.input_data: input_data, 137 | ncomputer.target_output: target_output, 138 | ncomputer.sequence_length: 2 * random_length + 1 139 | }) 140 | 141 | last_100_losses.append(loss_value) 142 | summerizer.add_summary(summary, i) 143 | 144 | if summerize: 145 | llprint("\n\tAvg. Logistic Loss: %.4f\n" % (np.mean(last_100_losses))) 146 | last_100_losses = [] 147 | 148 | if take_checkpoint: 149 | llprint("\nSaving Checkpoint ... 
"), 150 | ncomputer.save(session, ckpts_dir, 'step-%d' % (i)) 151 | llprint("Done!\n") 152 | -------------------------------------------------------------------------------- /unit-tests/controller.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | import unittest 4 | 5 | from dnc.controller import BaseController 6 | 7 | class DummyController(BaseController): 8 | def network_vars(self): 9 | self.W = tf.Variable(tf.truncated_normal([self.nn_input_size, 64])) 10 | self.b = tf.Variable(tf.zeros([64])) 11 | 12 | def network_op(self, X): 13 | return tf.matmul(X, self.W) + self.b 14 | 15 | 16 | class DummyRecurrentController(BaseController): 17 | def network_vars(self): 18 | self.lstm_cell = tf.nn.rnn_cell.BasicLSTMCell(64) 19 | self.state = tf.Variable(tf.zeros([self.batch_size, 64]), trainable=False) 20 | self.output = tf.Variable(tf.zeros([self.batch_size, 64]), trainable=False) 21 | 22 | def network_op(self, X, state): 23 | X = tf.convert_to_tensor(X) 24 | return self.lstm_cell(X, state) 25 | 26 | def update_state(self, new_state): 27 | return tf.group( 28 | self.output.assign(new_state[0]), 29 | self.state.assign(new_state[1]) 30 | ) 31 | 32 | def get_state(self): 33 | return (self.output, self.state) 34 | 35 | 36 | class DNCControllerTest(unittest.TestCase): 37 | 38 | def test_construction(self): 39 | graph = tf.Graph() 40 | with graph.as_default(): 41 | with tf.Session(graph=graph) as session: 42 | 43 | controller = DummyController(10, 10, 2, 5) 44 | rcontroller = DummyRecurrentController(10, 10, 2, 5, 1) 45 | 46 | self.assertFalse(controller.has_recurrent_nn) 47 | self.assertEqual(controller.nn_input_size, 20) 48 | self.assertEqual(controller.interface_vector_size, 38) 49 | self.assertEqual(controller.interface_weights.get_shape().as_list(), [64, 38]) 50 | self.assertEqual(controller.nn_output_weights.get_shape().as_list(), [64, 10]) 51 | self.assertEqual(controller.mem_output_weights.get_shape().as_list(), [10, 10]) 52 | 53 | self.assertTrue(rcontroller.has_recurrent_nn) 54 | self.assertEqual(rcontroller.nn_input_size, 20) 55 | self.assertEqual(rcontroller.interface_vector_size, 38) 56 | self.assertEqual(rcontroller.interface_weights.get_shape().as_list(), [64, 38]) 57 | self.assertEqual(rcontroller.nn_output_weights.get_shape().as_list(), [64, 10]) 58 | self.assertEqual(rcontroller.mem_output_weights.get_shape().as_list(), [10, 10]) 59 | 60 | 61 | 62 | def test_get_nn_output_size(self): 63 | graph = tf.Graph() 64 | with graph.as_default(): 65 | with tf.Session(graph=graph) as Session: 66 | 67 | controller = DummyController(10, 10, 2, 5) 68 | rcontroller = DummyRecurrentController(10, 10, 2, 5, 1) 69 | 70 | self.assertEqual(controller.get_nn_output_size(), 64) 71 | self.assertEqual(rcontroller.get_nn_output_size(), 64) 72 | 73 | 74 | def test_parse_interface_vector(self): 75 | graph = tf.Graph() 76 | with graph.as_default(): 77 | with tf.Session(graph=graph) as session: 78 | 79 | controller = DummyController(10, 10, 2, 5) 80 | zeta = np.random.uniform(-2, 2, (2, 38)).astype(np.float32) 81 | 82 | read_keys = np.reshape(zeta[:, :10], (-1, 5, 2)) 83 | read_strengths = 1 + np.log(np.exp(np.reshape(zeta[:, 10:12], (-1, 2, ))) + 1) 84 | write_key = np.reshape(zeta[:, 12:17], (-1, 5, 1)) 85 | write_strength = 1 + np.log(np.exp(np.reshape(zeta[:, 17], (-1, 1))) + 1) 86 | erase_vector = 1.0 / (1 + np.exp(-1 * np.reshape(zeta[:, 18:23], (-1, 5)))) 87 | write_vector = np.reshape(zeta[:, 23:28], (-1, 5)) 88 | 
free_gates = 1.0 / (1 + np.exp(-1 * np.reshape(zeta[:, 28:30], (-1, 2)))) 89 | allocation_gate = 1.0 / (1 + np.exp(-1 * zeta[:, 30, np.newaxis])) 90 | write_gate = 1.0 / (1 + np.exp(-1 * zeta[:, 31, np.newaxis])) 91 | read_modes = np.reshape(zeta[:, 32:], (-1, 3, 2)) 92 | 93 | read_modes = np.transpose(read_modes, [0, 2, 1]) 94 | read_modes = np.reshape(read_modes, (-1, 3)) 95 | read_modes = np.exp(read_modes) / np.sum(np.exp(read_modes), axis=-1, keepdims=True) 96 | read_modes = np.reshape(read_modes, (2, 2, 3)) 97 | read_modes = np.transpose(read_modes, [0, 2, 1]) 98 | 99 | op = controller.parse_interface_vector(zeta) 100 | session.run(tf.initialize_all_variables()) 101 | parsed = session.run(op) 102 | 103 | self.assertTrue(np.allclose(parsed['read_keys'], read_keys)) 104 | self.assertTrue(np.allclose(parsed['read_strengths'], read_strengths)) 105 | self.assertTrue(np.allclose(parsed['write_key'], write_key)) 106 | self.assertTrue(np.allclose(parsed['write_strength'], write_strength)) 107 | self.assertTrue(np.allclose(parsed['erase_vector'], erase_vector)) 108 | self.assertTrue(np.allclose(parsed['write_vector'], write_vector)) 109 | self.assertTrue(np.allclose(parsed['free_gates'], free_gates)) 110 | self.assertTrue(np.allclose(parsed['allocation_gate'], allocation_gate)) 111 | self.assertTrue(np.allclose(parsed['write_gate'], write_gate)) 112 | self.assertTrue(np.allclose(parsed['read_modes'], read_modes)) 113 | 114 | 115 | def test_process_input(self): 116 | graph = tf.Graph() 117 | with graph.as_default(): 118 | with tf.Session(graph=graph) as session: 119 | 120 | controller = DummyController(10, 10, 2, 5) 121 | rcontroller = DummyRecurrentController(10, 10, 2, 5, 2) 122 | 123 | input_batch = np.random.uniform(0, 1, (2, 10)).astype(np.float32) 124 | last_read_vectors = np.random.uniform(-1, 1, (2, 5, 2)).astype(np.float32) 125 | 126 | v_op, zeta_op = controller.process_input(input_batch, last_read_vectors) 127 | rv_op, rzeta_op, rs_op = rcontroller.process_input(input_batch, last_read_vectors, rcontroller.get_state()) 128 | 129 | session.run(tf.initialize_all_variables()) 130 | v, zeta = session.run([v_op, zeta_op]) 131 | rv, rzeta, rs = session.run([rv_op, rzeta_op, rs_op]) 132 | 133 | self.assertEqual(v.shape, (2, 10)) 134 | self.assertEqual(np.concatenate([np.reshape(val, (2, -1)) for _,val in zeta.iteritems()], axis=1).shape, (2, 38)) 135 | 136 | self.assertEqual(rv.shape, (2, 10)) 137 | self.assertEqual(np.concatenate([np.reshape(val, (2, -1)) for _,val in rzeta.iteritems()], axis=1).shape, (2, 38)) 138 | self.assertEqual([_s.shape for _s in rs], [(2, 64), (2, 64)]) 139 | 140 | 141 | def test_final_output(self): 142 | graph = tf.Graph() 143 | with graph.as_default(): 144 | with tf.Session(graph=graph) as session: 145 | 146 | controller = DummyController(10, 10, 2, 5) 147 | output_batch = np.random.uniform(0, 1, (2, 10)).astype(np.float32) 148 | new_read_vectors = np.random.uniform(-1, 1, (2, 5, 2)).astype(np.float32) 149 | 150 | op = controller.final_output(output_batch, new_read_vectors) 151 | session.run(tf.initialize_all_variables()) 152 | y = session.run(op) 153 | 154 | self.assertEqual(y.shape, (2, 10)) 155 | 156 | 157 | if __name__ == '__main__': 158 | unittest.main(verbosity=2) 159 | -------------------------------------------------------------------------------- /unit-tests/dnc: -------------------------------------------------------------------------------- 1 | ../dnc -------------------------------------------------------------------------------- /unit-tests/dnc.py: 
-------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | import unittest 4 | import shutil 5 | import os 6 | 7 | from dnc.dnc import DNC 8 | from dnc.memory import Memory 9 | from dnc.controller import BaseController 10 | 11 | class DummyController(BaseController): 12 | def network_vars(self): 13 | self.W = tf.Variable(tf.truncated_normal([self.nn_input_size, 64]), name='layer_W') 14 | self.b = tf.Variable(tf.zeros([64]), name='layer_b') 15 | 16 | def network_op(self, X): 17 | return tf.matmul(X, self.W) + self.b 18 | 19 | class DummyRecurrentController(BaseController): 20 | def network_vars(self): 21 | self.lstm_cell = tf.nn.rnn_cell.BasicLSTMCell(64) 22 | self.state = tf.Variable(tf.zeros([self.batch_size, 64]), trainable=False) 23 | self.output = tf.Variable(tf.zeros([self.batch_size, 64]), trainable=False) 24 | 25 | def network_op(self, X, state): 26 | X = tf.convert_to_tensor(X) 27 | return self.lstm_cell(X, state) 28 | 29 | def update_state(self, new_state): 30 | return tf.group( 31 | self.output.assign(new_state[0]), 32 | self.state.assign(new_state[1]) 33 | ) 34 | 35 | def get_state(self): 36 | return (self.output, self.state) 37 | 38 | class DNCTest(unittest.TestCase): 39 | 40 | @classmethod 41 | def _clear(cls): 42 | try: 43 | current_dir = os.path.dirname(__file__) 44 | ckpts_dir = os.path.join(current_dir, 'checkpoints') 45 | 46 | shutil.rmtree(ckpts_dir) 47 | except: 48 | # swallow error 49 | return 50 | 51 | @classmethod 52 | def setUpClass(cls): 53 | cls._clear() 54 | 55 | 56 | @classmethod 57 | def tearDownClass(cls): 58 | cls._clear() 59 | 60 | 61 | def test_construction(self): 62 | graph = tf.Graph() 63 | with graph.as_default(): 64 | with tf.Session(graph=graph) as session: 65 | 66 | computer = DNC(DummyController, 10, 20, 10, 10, 64, 1) 67 | rcomputer = DNC(DummyRecurrentController, 10, 20, 10, 10, 64, 1) 68 | 69 | self.assertEqual(computer.input_size, 10) 70 | self.assertEqual(computer.output_size, 20) 71 | self.assertEqual(computer.words_num, 10) 72 | self.assertEqual(computer.word_size, 64) 73 | self.assertEqual(computer.read_heads, 1) 74 | self.assertEqual(computer.batch_size, 1) 75 | 76 | self.assertTrue(isinstance(computer.memory, Memory)) 77 | self.assertTrue(isinstance(computer.controller, DummyController)) 78 | self.assertTrue(isinstance(rcomputer.controller, DummyRecurrentController)) 79 | 80 | 81 | def test_call(self): 82 | graph = tf.Graph() 83 | with graph.as_default(): 84 | with tf.Session(graph=graph) as session: 85 | 86 | computer = DNC(DummyController, 10, 20, 10, 10, 64, 2, batch_size=3) 87 | rcomputer = DNC(DummyRecurrentController, 10, 20, 10, 10, 64, 2, batch_size=3) 88 | input_batches = np.random.uniform(0, 1, (3, 5, 10)).astype(np.float32) 89 | 90 | session.run(tf.initialize_all_variables()) 91 | out_view = session.run(computer.get_outputs(), feed_dict={ 92 | computer.input_data: input_batches, 93 | computer.sequence_length: 5 94 | }) 95 | out, view = out_view 96 | 97 | rout_rview, ro, rs = session.run([ 98 | rcomputer.get_outputs(), 99 | rcomputer.controller.get_state()[0], 100 | rcomputer.controller.get_state()[1] 101 | ], feed_dict={ 102 | rcomputer.input_data: input_batches, 103 | rcomputer.sequence_length: 5 104 | }) 105 | rout, rview = rout_rview 106 | 107 | self.assertEqual(out.shape, (3, 5, 20)) 108 | self.assertEqual(view['free_gates'].shape, (3, 5, 2)) 109 | self.assertEqual(view['allocation_gates'].shape, (3, 5, 1)) 110 | 
self.assertEqual(view['write_gates'].shape, (3, 5, 1)) 111 | self.assertEqual(view['read_weightings'].shape, (3, 5, 10, 2)) 112 | self.assertEqual(view['write_weightings'].shape, (3, 5, 10)) 113 | 114 | 115 | self.assertEqual(rout.shape, (3, 5, 20)) 116 | self.assertEqual(rview['free_gates'].shape, (3, 5, 2)) 117 | self.assertEqual(rview['allocation_gates'].shape, (3, 5, 1)) 118 | self.assertEqual(rview['write_gates'].shape, (3, 5, 1)) 119 | self.assertEqual(rview['read_weightings'].shape, (3, 5, 10, 2)) 120 | self.assertEqual(rview['write_weightings'].shape, (3, 5, 10)) 121 | 122 | 123 | def test_save(self): 124 | graph = tf.Graph() 125 | with graph.as_default(): 126 | with tf.Session(graph=graph) as session: 127 | 128 | computer = DNC(DummyController, 10, 20, 10, 10, 64, 2, batch_size=2) 129 | session.run(tf.initialize_all_variables()) 130 | current_dir = os.path.dirname(__file__) 131 | ckpts_dir = os.path.join(current_dir, 'checkpoints') 132 | 133 | computer.save(session, ckpts_dir, 'test-save') 134 | 135 | self.assert_(True) 136 | 137 | 138 | def test_restore(self): 139 | 140 | current_dir = os.path.dirname(__file__) 141 | ckpts_dir = os.path.join(current_dir, 'checkpoints') 142 | 143 | model1_output, model1_memview = None, None 144 | sample_input = np.random.uniform(0, 1, (2, 5, 10)).astype(np.float32) 145 | sample_seq_len = 5 146 | 147 | graph1 = tf.Graph() 148 | with graph1.as_default(): 149 | with tf.Session(graph=graph1) as session1: 150 | 151 | computer = DNC(DummyController, 10, 20, 10, 10, 64, 2, batch_size=2) 152 | session1.run(tf.initialize_all_variables()) 153 | 154 | saved_weights = session1.run([ 155 | computer.controller.nn_output_weights, 156 | computer.controller.interface_weights, 157 | computer.controller.mem_output_weights, 158 | computer.controller.W, 159 | computer.controller.b 160 | ]) 161 | 162 | computer.save(session1, ckpts_dir, 'test-restore') 163 | 164 | graph2 = tf.Graph() 165 | with graph2.as_default(): 166 | with tf.Session(graph=graph2) as session2: 167 | 168 | computer = DNC(DummyController, 10, 20, 10, 10, 64, 2, batch_size=2) 169 | session2.run(tf.initialize_all_variables()) 170 | computer.restore(session2, ckpts_dir, 'test-restore') 171 | 172 | restored_weights = session2.run([ 173 | computer.controller.nn_output_weights, 174 | computer.controller.interface_weights, 175 | computer.controller.mem_output_weights, 176 | computer.controller.W, 177 | computer.controller.b 178 | ]) 179 | 180 | self.assertTrue(np.product([np.array_equal(restored_weights[i], saved_weights[i]) for i in range(5)])) 181 | 182 | if __name__ == '__main__': 183 | unittest.main(verbosity=2) 184 | -------------------------------------------------------------------------------- /unit-tests/memory.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | import unittest 4 | 5 | from dnc.memory import Memory 6 | 7 | def random_softmax(shape, axis): 8 | rand = np.random.uniform(0, 1, shape).astype(np.float32) 9 | return np.exp(rand) / np.sum(np.exp(rand), axis=axis, keepdims=True) 10 | 11 | class DNCMemoryTests(unittest.TestCase): 12 | 13 | def test_construction(self): 14 | graph = tf.Graph() 15 | with graph.as_default(): 16 | with tf.Session(graph=graph) as session: 17 | 18 | mem = Memory(4, 5, 2, 2) 19 | session.run(tf.initialize_all_variables()) 20 | 21 | self.assertEqual(mem.words_num, 4) 22 | self.assertEqual(mem.word_size, 5) 23 | self.assertEqual(mem.read_heads, 2) 24 | 
self.assertEqual(mem.batch_size, 2) 25 | 26 | 27 | def test_init_memory(self): 28 | graph = tf.Graph() 29 | with graph.as_default(): 30 | with tf.Session(graph=graph) as session: 31 | 32 | mem = Memory(4, 5, 2, 2) 33 | M, u, p, L, ww, rw, r = session.run(mem.init_memory()) 34 | 35 | self.assertEqual(M.shape, (2, 4, 5)) 36 | self.assertEqual(u.shape, (2, 4)) 37 | self.assertEqual(L.shape, (2, 4, 4)) 38 | self.assertEqual(ww.shape, (2, 4)) 39 | self.assertEqual(rw.shape, (2, 4, 2)) 40 | self.assertEqual(r.shape, (2, 5, 2)) 41 | self.assertEqual(p.shape, (2, 4)) 42 | 43 | def test_lookup_weighting(self): 44 | graph = tf.Graph() 45 | with graph.as_default(): 46 | with tf.Session(graph=graph) as session: 47 | 48 | mem = Memory(4, 5, 2, 2) 49 | initial_mem = np.random.uniform(0, 1, (2, 4, 5)).astype(np.float32) 50 | keys = np.random.uniform(0, 1, (2, 5, 2)).astype(np.float32) 51 | strengths = np.random.uniform(0, 1, (2 ,2)).astype(np.float32) 52 | 53 | norm_mem = initial_mem / np.sqrt(np.sum(initial_mem ** 2, axis=2, keepdims=True)) 54 | norm_keys = keys/ np.sqrt(np.sum(keys ** 2, axis = 1, keepdims=True)) 55 | sim = np.matmul(norm_mem, norm_keys) 56 | sim = sim * strengths[:, np.newaxis, :] 57 | predicted_wieghts = np.exp(sim) / np.sum(np.exp(sim), axis=1, keepdims=True) 58 | 59 | memory_matrix = tf.convert_to_tensor(initial_mem) 60 | op = mem.get_lookup_weighting(memory_matrix, keys, strengths) 61 | c = session.run(op) 62 | 63 | self.assertEqual(c.shape, (2, 4, 2)) 64 | self.assertTrue(np.allclose(c, predicted_wieghts)) 65 | 66 | 67 | def test_update_usage_vector(self): 68 | graph = tf.Graph() 69 | with graph.as_default(): 70 | with tf.Session(graph=graph) as session: 71 | 72 | mem = Memory(4, 5, 2, 2) 73 | free_gates = np.random.uniform(0, 1, (2, 2)).astype(np.float32) 74 | init_read_weightings = random_softmax((2, 4, 2), axis=1) 75 | init_write_weightings = random_softmax((2, 4), axis=1) 76 | init_usage = np.random.uniform(0, 1, (2, 4)).astype(np.float32) 77 | 78 | psi = np.product(1 - init_read_weightings * free_gates[:, np.newaxis, :], axis=2) 79 | predicted_usage = (init_usage + init_write_weightings - init_usage * init_write_weightings) * psi 80 | 81 | 82 | read_weightings = tf.convert_to_tensor(init_read_weightings) 83 | write_weighting = tf.convert_to_tensor(init_write_weightings) 84 | usage_vector = tf.convert_to_tensor(init_usage) 85 | 86 | op = mem.update_usage_vector(usage_vector, read_weightings, write_weighting, free_gates) 87 | u = session.run(op) 88 | 89 | self.assertEqual(u.shape, (2, 4)) 90 | self.assertTrue(np.array_equal(u, predicted_usage)) 91 | 92 | 93 | def test_get_allocation_weighting(self): 94 | graph = tf.Graph() 95 | with graph.as_default(): 96 | with tf.Session(graph=graph) as session: 97 | 98 | mem = Memory(4, 5, 2, 2) 99 | mock_usage = np.random.uniform(0.01, 1, (2, 4)).astype(np.float32) 100 | sorted_usage = np.sort(mock_usage, axis=1) 101 | free_list = np.argsort(mock_usage, axis=1) 102 | 103 | predicted_weights = np.zeros((2, 4)).astype(np.float32) 104 | for i in range(2): 105 | for j in range(4): 106 | product_list = [mock_usage[i, free_list[i,k]] for k in range(j)] 107 | predicted_weights[i, free_list[i,j]] = (1 - mock_usage[i, free_list[i, j]]) * np.product(product_list) 108 | 109 | op = mem.get_allocation_weighting(sorted_usage, free_list) 110 | a = session.run(op) 111 | 112 | self.assertEqual(a.shape, (2, 4)) 113 | self.assertTrue(np.allclose(a, predicted_weights)) 114 | 115 | 116 | def test_updated_write_weighting(self): 117 | graph = tf.Graph() 118 | 
with graph.as_default(): 119 | with tf.Session(graph=graph) as session: 120 | 121 | mem = Memory(4, 5, 2, 2) 122 | write_gate = np.random.uniform(0, 1, (2,1)).astype(np.float32) 123 | allocation_gate = np.random.uniform(0, 1, (2,1)).astype(np.float32) 124 | lookup_weighting = random_softmax((2, 4, 1), axis=1) 125 | allocation_weighting = random_softmax((2, 4), axis=1) 126 | 127 | predicted_weights = write_gate * (allocation_gate * allocation_weighting + (1 - allocation_gate) * np.squeeze(lookup_weighting)) 128 | 129 | op = mem.update_write_weighting(lookup_weighting, allocation_weighting, write_gate, allocation_gate) 130 | w_w = session.run(op) 131 | 132 | self.assertEqual(w_w.shape, (2,4)) 133 | self.assertTrue(np.allclose(w_w, predicted_weights)) 134 | 135 | 136 | def test_update_memory(self): 137 | graph = tf.Graph() 138 | with graph.as_default(): 139 | with tf.Session(graph=graph) as session: 140 | 141 | mem = Memory(4, 5, 2, 2) 142 | write_weighting = random_softmax((2, 4), axis=1) 143 | write_vector = np.random.uniform(0, 1, (2, 5)).astype(np.float32) 144 | erase_vector = np.random.uniform(0, 1, (2, 5)).astype(np.float32) 145 | memory_matrix = np.random.uniform(-1, 1, (2, 4, 5)).astype(np.float32) 146 | 147 | ww = write_weighting[:, :, np.newaxis] 148 | v, e = write_vector[:, np.newaxis, :], erase_vector[:, np.newaxis, :] 149 | predicted = memory_matrix * (1 - np.matmul(ww, e)) + np.matmul(ww, v) 150 | 151 | memory_matrix = tf.convert_to_tensor(memory_matrix) 152 | 153 | op = mem.update_memory(memory_matrix, write_weighting, write_vector, erase_vector) 154 | M = session.run(op) 155 | 156 | self.assertEqual(M.shape, (2, 4, 5)) 157 | self.assertTrue(np.allclose(M, predicted)) 158 | 159 | def test_update_precedence_vector(self): 160 | graph = tf.Graph() 161 | with graph.as_default(): 162 | with tf.Session(graph=graph) as session: 163 | 164 | mem = Memory(4, 5, 2, 2) 165 | write_weighting = random_softmax((2, 4), axis=1) 166 | initial_precedence = random_softmax((2, 4), axis=1) 167 | predicted = (1 - write_weighting.sum(axis=1, keepdims=True)) * initial_precedence + write_weighting 168 | 169 | precedence_vector = tf.convert_to_tensor(initial_precedence) 170 | 171 | op = mem.update_precedence_vector(precedence_vector, write_weighting) 172 | p = session.run(op) 173 | 174 | self.assertEqual(p.shape, (2,4)) 175 | self.assertTrue(np.allclose(p, predicted)) 176 | 177 | 178 | def test_update_link_matrix(self): 179 | graph = tf.Graph() 180 | with graph.as_default(): 181 | with tf.Session(graph=graph) as session: 182 | 183 | mem = Memory(4, 5, 2, 2) 184 | _write_weighting = random_softmax((2, 4), axis=1) 185 | _precedence_vector = random_softmax((2, 4), axis=1) 186 | initial_link = np.random.uniform(0, 1, (2, 4, 4)).astype(np.float32) 187 | np.fill_diagonal(initial_link[0,:], 0) 188 | np.fill_diagonal(initial_link[1,:], 0) 189 | 190 | # calculate the updated link iteratively as in paper 191 | # to check the correcteness of the vectorized implementation 192 | predicted = np.zeros((2,4,4), dtype=np.float32) 193 | for i in range(4): 194 | for j in range(4): 195 | if i != j: 196 | reset_factor = (1 - _write_weighting[:,i] - _write_weighting[:,j]) 197 | predicted[:, i, j] = reset_factor * initial_link[:, i , j] + _write_weighting[:, i] * _precedence_vector[:, j] 198 | 199 | link_matrix = tf.convert_to_tensor(initial_link) 200 | precedence_vector = tf.convert_to_tensor(_precedence_vector) 201 | 202 | write_weighting = tf.constant(_write_weighting) 203 | 204 | op = 
mem.update_link_matrix(precedence_vector, link_matrix, write_weighting) 205 | L = session.run(op) 206 | 207 | self.assertTrue(np.allclose(L, predicted)) 208 | 209 | 210 | def test_get_directional_weightings(self): 211 | graph = tf.Graph() 212 | with graph.as_default(): 213 | with tf.Session(graph=graph) as session: 214 | 215 | mem = Memory(4, 5, 2, 2) 216 | _link_matrix = np.random.uniform(0, 1, (2, 4, 4)).astype(np.float32) 217 | _read_weightings = random_softmax((2, 4, 2), axis=1) 218 | predicted_forward = np.matmul(_link_matrix, _read_weightings) 219 | predicted_backward = np.matmul(np.transpose(_link_matrix, [0, 2, 1]), _read_weightings) 220 | 221 | read_weightings = tf.convert_to_tensor(_read_weightings) 222 | 223 | fop, bop = mem.get_directional_weightings(read_weightings, _link_matrix) 224 | 225 | forward_weighting, backward_weighting = session.run([fop, bop]) 226 | 227 | self.assertTrue(np.allclose(forward_weighting, predicted_forward)) 228 | self.assertTrue(np.allclose(backward_weighting, predicted_backward)) 229 | 230 | 231 | 232 | def test_update_read_weightings(self): 233 | graph = tf.Graph() 234 | with graph.as_default(): 235 | with tf.Session(graph=graph) as session: 236 | 237 | mem = Memory(4, 5, 2, 2) 238 | lookup_weightings = random_softmax((2, 4, 2), axis=1) 239 | forward_weighting = random_softmax((2, 4, 2), axis=1) 240 | backward_weighting = random_softmax((2, 4, 2), axis=1) 241 | read_mode = random_softmax((2, 3, 2), axis=1) 242 | predicted_weights = np.zeros((2, 4, 2)).astype(np.float32) 243 | 244 | # calculate the predicted weights using iterative method from paper 245 | # to check the correcteness of the vectorized implementation 246 | for i in range(2): 247 | predicted_weights[:, :, i] = read_mode[:, 0,i, np.newaxis] * backward_weighting[:, :, i] + read_mode[:, 1, i, np.newaxis] * lookup_weightings[:, :, i] + read_mode[:, 2, i, np.newaxis] * forward_weighting[:, :, i] 248 | 249 | op = mem.update_read_weightings(lookup_weightings, forward_weighting, backward_weighting, read_mode) 250 | session.run(tf.initialize_all_variables()) 251 | w_r = session.run(op) 252 | #updated_read_weightings = session.run(mem.read_weightings.value()) 253 | 254 | self.assertTrue(np.allclose(w_r, predicted_weights)) 255 | #self.assertTrue(np.allclose(updated_read_weightings, predicted_weights)) 256 | 257 | 258 | def test_update_read_vectors(self): 259 | graph = tf.Graph() 260 | with graph.as_default(): 261 | with tf.Session(graph = graph) as session: 262 | 263 | mem = Memory(4, 5, 2, 4) 264 | memory_matrix = np.random.uniform(-1, 1, (4, 4, 5)).astype(np.float32) 265 | read_weightings = random_softmax((4, 4, 2), axis=1) 266 | predicted = np.matmul(np.transpose(memory_matrix, [0, 2, 1]), read_weightings) 267 | 268 | op = mem.update_read_vectors(memory_matrix, read_weightings) 269 | session.run(tf.initialize_all_variables()) 270 | r = session.run(op) 271 | #updated_read_vectors = session.run(mem.read_vectors.value()) 272 | 273 | self.assertTrue(np.allclose(r, predicted)) 274 | #self.assertTrue(np.allclose(updated_read_vectors, predicted)) 275 | 276 | def test_write(self): 277 | graph = tf.Graph() 278 | with graph.as_default(): 279 | with tf.Session(graph = graph) as session: 280 | 281 | mem = Memory(4, 5, 2, 1) 282 | M, u, p, L, ww, rw, r = session.run(mem.init_memory()) 283 | key = np.random.uniform(0, 1, (1, 5, 1)).astype(np.float32) 284 | strength = np.random.uniform(0, 1, (1, 1)).astype(np.float32) 285 | free_gates = np.random.uniform(0, 1, (1, 2)).astype(np.float32) 286 | write_gate = 
np.random.uniform(0, 1, (1, 1)).astype(np.float32) 287 | allocation_gate = np.random.uniform(0, 1, (1,1)).astype(np.float32) 288 | write_vector = np.random.uniform(0, 1, (1, 5)).astype(np.float32) 289 | erase_vector = np.zeros((1, 5)).astype(np.float32) 290 | 291 | u_op, ww_op, M_op, L_op, p_op = mem.write( 292 | M, u, rw, ww, p, L, 293 | key, strength, free_gates, allocation_gate, 294 | write_gate , write_vector, erase_vector 295 | ) 296 | session.run(tf.initialize_all_variables()) 297 | u, ww, M, L, p = session.run([u_op, ww_op, M_op, L_op, p_op]) 298 | 299 | self.assertEqual(u.shape, (1, 4)) 300 | self.assertEqual(ww.shape, (1, 4)) 301 | self.assertEqual(M.shape, (1, 4, 5)) 302 | self.assertEqual(L.shape, (1, 4, 4)) 303 | self.assertEqual(p.shape, (1, 4)) 304 | 305 | 306 | 307 | def test_read(self): 308 | graph = tf.Graph() 309 | with graph.as_default(): 310 | with tf.Session(graph = graph) as session: 311 | mem = Memory(4, 5, 2, 1) 312 | M, u, p, L, ww, rw, r = session.run(mem.init_memory()) 313 | keys = np.random.uniform(0, 1, (1, 5, 2)).astype(np.float32) 314 | strengths = np.random.uniform(0, 1, (1, 2)).astype(np.float32) 315 | link_matrix = np.random.uniform(0, 1, (1, 4, 4)).astype(np.float32) 316 | read_modes = random_softmax((1, 3, 2), axis=1).astype(np.float32) 317 | memory_matrix = np.random.uniform(-1, 1, (1, 4, 5)).astype(np.float32) 318 | 319 | wr_op, r_op = mem.read(memory_matrix, rw, keys, strengths, link_matrix, read_modes) 320 | session.run(tf.initialize_all_variables()) 321 | wr, r = session.run([wr_op, r_op]) 322 | 323 | self.assertEqual(wr.shape, (1, 4, 2)) 324 | self.assertEqual(r.shape, (1, 5, 2)) 325 | 326 | 327 | if __name__ == '__main__': 328 | unittest.main(verbosity=2) 329 | -------------------------------------------------------------------------------- /unit-tests/utility.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | import unittest 4 | 5 | import dnc.utility as util 6 | 7 | class DNCUtilityTests(unittest.TestCase): 8 | 9 | def test_pairwise_add(self): 10 | graph = tf.Graph() 11 | with graph.as_default(): 12 | with tf.Session(graph=graph) as session: 13 | 14 | _u = np.array([5, 6]) 15 | _v = np.array([1, 2]) 16 | 17 | predicted_U = np.array([[10, 11], [11, 12]]) 18 | predicted_UV = np.array([[6, 7], [7, 8]]) 19 | 20 | u = tf.constant(_u) 21 | v = tf.constant(_v) 22 | 23 | U_op = util.pairwise_add(u) 24 | UV_op = util.pairwise_add(u, v) 25 | 26 | U, UV = session.run([U_op, UV_op]) 27 | 28 | self.assertTrue(np.allclose(U, predicted_U)) 29 | self.assertTrue(np.allclose(UV, predicted_UV)) 30 | 31 | 32 | def test_pairwise_add_with_batch(self): 33 | graph = tf.Graph() 34 | with graph.as_default(): 35 | with tf.Session(graph=graph) as session: 36 | 37 | _u = np.array([[5, 6], [7, 8]]) 38 | _v = np.array([[1, 2], [3, 4]]) 39 | 40 | predicted_U = np.array([[[10, 11], [11, 12]], [[14, 15], [15, 16]]]) 41 | predicted_UV = np.array([[[6, 7], [7, 8]], [[10, 11], [11, 12]]]) 42 | 43 | u = tf.constant(_u) 44 | v = tf.constant(_v) 45 | 46 | U_op = util.pairwise_add(u, is_batch=True) 47 | UV_op = util.pairwise_add(u, v, is_batch=True) 48 | 49 | U, UV = session.run([U_op, UV_op]) 50 | 51 | self.assertTrue(np.allclose(U, predicted_U)) 52 | self.assertTrue(np.allclose(UV, predicted_UV)) 53 | 54 | 55 | def test_unpack_into_tensorarray(self): 56 | graph = tf.Graph() 57 | with graph.as_default(): 58 | with tf.Session(graph=graph) as session: 59 | 60 | T = tf.random_normal([5, 10, 7, 7]) 61 
| ta = util.unpack_into_tensorarray(T, axis=1) 62 | 63 | vT, vTA5 = session.run([T, ta.read(5)]) 64 | 65 | self.assertEqual(vTA5.shape, (5, 7, 7)) 66 | self.assertTrue(np.allclose(vT[:, 5, :, :], vTA5)) 67 | 68 | 69 | def test_pack_into_tensor(self): 70 | graph = tf.Graph() 71 | with graph.as_default(): 72 | with tf.Session(graph=graph) as session: 73 | 74 | T = tf.random_normal([5, 10, 7, 7]) 75 | ta = util.unpack_into_tensorarray(T, axis=1) 76 | pT = util.pack_into_tensor(ta, axis=1) 77 | 78 | vT, vPT = session.run([T, pT]) 79 | 80 | self.assertEqual(vPT.shape, (5, 10, 7, 7)) 81 | self.assertTrue(np.allclose(vT, vPT)) 82 | 83 | 84 | if __name__ == "__main__": 85 | unittest.main(verbosity=2) 86 | --------------------------------------------------------------------------------
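A note on the interface-vector size asserted in the controller unit tests above (`interface_vector_size == 38`): for the test configuration of word size 5 and 2 read heads, the slices taken from `zeta` in `test_parse_interface_vector` account for exactly 38 components. The following minimal sketch of that bookkeeping is illustrative only; the names `W`, `R`, and `interface_size` are not part of the repository code.

# Sketch: breakdown of the 38 interface-vector components for the
# unit-test configuration (word size W = 5, read heads R = 2).
W, R = 5, 2

interface_size = (
    R * W    # read keys            zeta[:, :10]
    + R      # read strengths       zeta[:, 10:12]
    + W      # write key            zeta[:, 12:17]
    + 1      # write strength       zeta[:, 17]
    + W      # erase vector         zeta[:, 18:23]
    + W      # write vector         zeta[:, 23:28]
    + R      # free gates           zeta[:, 28:30]
    + 1      # allocation gate      zeta[:, 30]
    + 1      # write gate           zeta[:, 31]
    + 3 * R  # read modes           zeta[:, 32:]
)

assert interface_size == 38

In general this sum works out to R*W + 3*W + 5*R + 3, so the hard-coded 38 in the assertions would change if the dummy controllers were constructed with a different memory word size or number of read heads.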