├── .gitignore ├── LICENSE ├── README.md ├── checkpoint └── copy_10 │ ├── NTM-copy_copy.model-3402.data-00000-of-00001 │ ├── NTM-copy_copy.model-3402.index │ └── checkpoint ├── etc ├── NTM.gif ├── result1.png ├── result2.png ├── result3.png ├── result4.png ├── result_15_12_30.png └── result_15_12_31.png ├── ipynb └── NTM Test.ipynb ├── main.py ├── ntm.py ├── ntm_cell.py ├── ops.py ├── ops_test.py ├── tasks ├── __init__.py ├── copy.py └── recall.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | # log 2 | log 3 | 4 | # test 5 | dummy 6 | app.py 7 | 8 | # trash 9 | .dropbox 10 | test* 11 | 12 | # Created by https://www.gitignore.io/api/python,vim 13 | 14 | ### IPythonNotebook ### 15 | ## Temporary data 16 | .ipynb_checkpoints/ 17 | 18 | ### Python ### 19 | # Byte-compiled / optimized / DLL files 20 | __pycache__/ 21 | *.py[cod] 22 | *$py.class 23 | 24 | # C extensions 25 | *.so 26 | 27 | # Distribution / packaging 28 | .Python 29 | env/ 30 | build/ 31 | develop-eggs/ 32 | dist/ 33 | downloads/ 34 | eggs/ 35 | .eggs/ 36 | lib/ 37 | lib64/ 38 | parts/ 39 | sdist/ 40 | var/ 41 | *.egg-info/ 42 | .installed.cfg 43 | *.egg 44 | 45 | # PyInstaller 46 | # Usually these files are written by a python script from a template 47 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 48 | *.manifest 49 | *.spec 50 | 51 | # Installer logs 52 | pip-log.txt 53 | pip-delete-this-directory.txt 54 | 55 | # Unit test / coverage reports 56 | htmlcov/ 57 | .tox/ 58 | .coverage 59 | .coverage.* 60 | .cache 61 | nosetests.xml 62 | coverage.xml 63 | *,cover 64 | .hypothesis/ 65 | 66 | # Translations 67 | *.mo 68 | *.pot 69 | 70 | # Django stuff: 71 | *.log 72 | 73 | # Sphinx documentation 74 | docs/_build/ 75 | 76 | # PyBuilder 77 | target/ 78 | 79 | 80 | ### Vim ### 81 | [._]*.s[a-w][a-z] 82 | [._]s[a-w][a-z] 83 | *.un~ 84 | Session.vim 85 | .netrwhist 86 | *~ 87 | 88 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | The MIT License (MIT) 2 | 3 | Copyright (c) 2016 Taehoon Kim 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | Neural Turing Machine in Tensorflow 2 | =================================== 3 | 4 | Tensorflow implementation of [Neural Turing Machine](http://arxiv.org/abs/1410.5401). This implementation uses an LSTM controller. NTM models with multiple read/write heads are supported. 5 | 6 | ![alt_tag](etc/NTM.gif) 7 | 8 | The referenced Torch code can be found [here](https://github.com/kaishengtai/torch-ntm). 9 | 10 | **1. The loss sometimes goes to NaN even with gradient clipping ([#2](https://github.com/carpedm20/NTM-tensorflow/issues/2)).** 11 | **2. The code is very poorly designed to support NTM inputs with variable lengths. Just use this code as a reference.** 12 | 13 | 14 | Prerequisites 15 | ------------- 16 | 17 | - Python 2.7 or Python 3.3+ 18 | - [Tensorflow 1.1.0](https://www.tensorflow.org/) 19 | - NumPy 20 | 21 | 22 | Usage 23 | ----- 24 | 25 | To train a copy task: 26 | 27 | $ python main.py --task copy --is_train True 28 | 29 | To test a *quick* copy task: 30 | 31 | $ python main.py --task copy --test_max_length 10 32 | 33 | 34 | Results 35 | ------- 36 | 37 | More detailed results can be found [here](ipynb/NTM%20Test.ipynb). 38 | 39 | **Copy task:** 40 | 41 | ![alt_tag](etc/result4.png) 42 | ![alt_tag](etc/result3.png) 43 | 44 | **Recall task:** 45 | 46 | (in progress) 47 | 48 | 49 | Author 50 | ------ 51 | 52 | Taehoon Kim / [@carpedm20](http://carpedm20.github.io/) 53 | -------------------------------------------------------------------------------- /checkpoint/copy_10/NTM-copy_copy.model-3402.data-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/carpedm20/NTM-tensorflow/5376c35e8800157b08af44ddec1cb870d2ef3c83/checkpoint/copy_10/NTM-copy_copy.model-3402.data-00000-of-00001 -------------------------------------------------------------------------------- /checkpoint/copy_10/NTM-copy_copy.model-3402.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/carpedm20/NTM-tensorflow/5376c35e8800157b08af44ddec1cb870d2ef3c83/checkpoint/copy_10/NTM-copy_copy.model-3402.index -------------------------------------------------------------------------------- /checkpoint/copy_10/checkpoint: -------------------------------------------------------------------------------- 1 | model_checkpoint_path: "NTM-copy_copy.model-3402" 2 | all_model_checkpoint_paths: "NTM-copy_copy.model-3402" 3 | -------------------------------------------------------------------------------- /etc/NTM.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/carpedm20/NTM-tensorflow/5376c35e8800157b08af44ddec1cb870d2ef3c83/etc/NTM.gif -------------------------------------------------------------------------------- /etc/result1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/carpedm20/NTM-tensorflow/5376c35e8800157b08af44ddec1cb870d2ef3c83/etc/result1.png -------------------------------------------------------------------------------- /etc/result2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/carpedm20/NTM-tensorflow/5376c35e8800157b08af44ddec1cb870d2ef3c83/etc/result2.png
-------------------------------------------------------------------------------- /etc/result3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/carpedm20/NTM-tensorflow/5376c35e8800157b08af44ddec1cb870d2ef3c83/etc/result3.png -------------------------------------------------------------------------------- /etc/result4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/carpedm20/NTM-tensorflow/5376c35e8800157b08af44ddec1cb870d2ef3c83/etc/result4.png -------------------------------------------------------------------------------- /etc/result_15_12_30.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/carpedm20/NTM-tensorflow/5376c35e8800157b08af44ddec1cb870d2ef3c83/etc/result_15_12_30.png -------------------------------------------------------------------------------- /etc/result_15_12_31.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/carpedm20/NTM-tensorflow/5376c35e8800157b08af44ddec1cb870d2ef3c83/etc/result_15_12_31.png -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import importlib 4 | import tensorflow as tf 5 | from ntm_cell import NTMCell 6 | from ntm import NTM 7 | 8 | from utils import pp 9 | 10 | flags = tf.app.flags 11 | flags.DEFINE_string("task", "copy", "Task to run [copy, recall]") 12 | flags.DEFINE_integer("epoch", 100000, "Epoch to train [100000]") 13 | flags.DEFINE_integer("input_dim", 10, "Dimension of input [10]") 14 | flags.DEFINE_integer("output_dim", 10, "Dimension of output [10]") 15 | flags.DEFINE_integer("min_length", 1, "Minimum length of input sequence [1]") 16 | flags.DEFINE_integer("max_length", 10, "Maximum length of output sequence [10]") 17 | flags.DEFINE_integer("controller_layer_size", 1, "The size of LSTM controller [1]") 18 | flags.DEFINE_integer("controller_dim", 100, "Dimension of LSTM controller [100]") 19 | flags.DEFINE_integer("write_head_size", 1, "The number of write head [1]") 20 | flags.DEFINE_integer("read_head_size", 1, "The number of read head [1]") 21 | flags.DEFINE_integer("test_max_length", 120, "Maximum length of output sequence [120]") 22 | flags.DEFINE_string("checkpoint_dir", "checkpoint", "Directory name to save the checkpoints [checkpoint]") 23 | flags.DEFINE_boolean("is_train", False, "True for training, False for testing [False]") 24 | flags.DEFINE_boolean("continue_train", None, "True to continue training from saved checkpoint. False for restarting. 
None for automatic [None]") 25 | FLAGS = flags.FLAGS 26 | 27 | 28 | def create_ntm(config, sess, **ntm_args): 29 | cell = NTMCell( 30 | input_dim=config.input_dim, 31 | output_dim=config.output_dim, 32 | controller_layer_size=config.controller_layer_size, 33 | controller_dim=config.controller_dim, 34 | write_head_size=config.write_head_size, 35 | read_head_size=config.read_head_size) 36 | scope = ntm_args.pop('scope', 'NTM-%s' % config.task) 37 | ntm = NTM( 38 | cell, sess, config.min_length, config.max_length, 39 | test_max_length=config.test_max_length, scope=scope, **ntm_args) 40 | return cell, ntm 41 | 42 | 43 | def main(_): 44 | pp.pprint(flags.FLAGS.__flags) 45 | 46 | with tf.device('/cpu:0'), tf.Session() as sess: 47 | try: 48 | task = importlib.import_module('tasks.%s' % FLAGS.task) 49 | except ImportError: 50 | print("task '%s' does not have an implementation" % FLAGS.task) 51 | raise 52 | 53 | if FLAGS.is_train: 54 | cell, ntm = create_ntm(FLAGS, sess) 55 | task.train(ntm, FLAGS, sess) 56 | else: 57 | cell, ntm = create_ntm(FLAGS, sess, forward_only=True) 58 | 59 | ntm.load(FLAGS.checkpoint_dir, FLAGS.task) 60 | 61 | if FLAGS.task == 'copy': 62 | task.run(ntm, int(FLAGS.test_max_length * 1 / 3), sess) 63 | print("") 64 | task.run(ntm, int(FLAGS.test_max_length * 2 / 3), sess) 65 | print("") 66 | task.run(ntm, int(FLAGS.test_max_length * 3 / 3), sess) 67 | else: 68 | task.run(ntm, int(FLAGS.test_max_length), sess) 69 | 70 | 71 | if __name__ == '__main__': 72 | tf.app.run() 73 | -------------------------------------------------------------------------------- /ntm.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import numpy as np 6 | import tensorflow as tf 7 | from collections import defaultdict 8 | from tensorflow.contrib.legacy_seq2seq import sequence_loss 9 | 10 | import ntm_cell 11 | 12 | import os 13 | from utils import progress 14 | 15 | def softmax_loss_function(labels, inputs): 16 | return tf.nn.sigmoid_cross_entropy_with_logits(labels=labels, logits=inputs) 17 | 18 | class NTM(object): 19 | def __init__(self, cell, sess, 20 | min_length, max_length, 21 | test_max_length=120, 22 | min_grad=-10, max_grad=+10, 23 | lr=1e-4, momentum=0.9, decay=0.95, 24 | scope="NTM", forward_only=False): 25 | """Create a Neural Turing Machine specified by NTMCell "cell". 26 | 27 | Args: 28 | cell: An instance of NTMCell. 29 | sess: A TensorFlow session. 30 | min_length: Minimum length of input sequence. 31 | max_length: Maximum length of input sequence for training. 32 | test_max_length: Maximum length of input sequence for testing. 33 | min_grad: (optional) Minimum gradient for gradient clipping [-10]. 34 | max_grad: (optional) Maximum gradient for gradient clipping [+10]. 35 | lr: (optional) Learning rate [1e-4]. 36 | momentum: (optional) Momentum of RMSProp [0.9]. 37 | decay: (optional) Decay rate of RMSProp [0.95].
38 | """ 39 | if not isinstance(cell, ntm_cell.NTMCell): 40 | raise TypeError("cell must be an instance of NTMCell") 41 | 42 | self.cell = cell 43 | self.sess = sess 44 | self.scope = scope 45 | 46 | self.lr = lr 47 | self.momentum = momentum 48 | self.decay = decay 49 | 50 | self.min_grad = min_grad 51 | self.max_grad = max_grad 52 | self.min_length = min_length 53 | self.max_length = max_length 54 | self._max_length = max_length 55 | 56 | if forward_only: 57 | self.max_length = test_max_length 58 | 59 | self.inputs = [] 60 | self.outputs = {} 61 | self.output_logits = {} 62 | self.true_outputs = [] 63 | 64 | self.prev_states = {} 65 | self.input_states = defaultdict(list) 66 | self.output_states = defaultdict(list) 67 | 68 | self.start_symbol = tf.placeholder(tf.float32, [self.cell.input_dim], 69 | name='start_symbol') 70 | self.end_symbol = tf.placeholder(tf.float32, [self.cell.input_dim], 71 | name='end_symbol') 72 | 73 | self.losses = {} 74 | self.optims = {} 75 | self.grads = {} 76 | 77 | self.saver = None 78 | self.params = None 79 | 80 | with tf.variable_scope(self.scope): 81 | self.global_step = tf.Variable(0, trainable=False) 82 | 83 | self.build_model(forward_only) 84 | 85 | def build_model(self, forward_only, is_copy=True): 86 | print(" [*] Building a NTM model") 87 | 88 | with tf.variable_scope(self.scope): 89 | # present start symbol 90 | if is_copy: 91 | _, _, prev_state = self.cell(self.start_symbol, state=None) 92 | self.save_state(prev_state, 0, self.max_length) 93 | 94 | zeros = np.zeros(self.cell.input_dim, dtype=np.float32) 95 | 96 | tf.get_variable_scope().reuse_variables() 97 | for seq_length in xrange(1, self.max_length + 1): 98 | progress(seq_length / float(self.max_length)) 99 | 100 | input_ = tf.placeholder(tf.float32, [self.cell.input_dim], 101 | name='input_%s' % seq_length) 102 | true_output = tf.placeholder(tf.float32, [self.cell.output_dim], 103 | name='true_output_%s' % seq_length) 104 | 105 | self.inputs.append(input_) 106 | self.true_outputs.append(true_output) 107 | 108 | # present inputs 109 | _, _, prev_state = self.cell(input_, prev_state) 110 | self.save_state(prev_state, seq_length, self.max_length) 111 | 112 | # present end symbol 113 | if is_copy: 114 | _, _, state = self.cell(self.end_symbol, prev_state) 115 | self.save_state(state, seq_length) 116 | 117 | self.prev_states[seq_length] = state 118 | 119 | if not forward_only: 120 | # present targets 121 | outputs, output_logits = [], [] 122 | for _ in xrange(seq_length): 123 | output, output_logit, state = self.cell(zeros, state) 124 | self.save_state(state, seq_length, is_output=True) 125 | outputs.append(output) 126 | output_logits.append(output_logit) 127 | 128 | self.outputs[seq_length] = outputs 129 | self.output_logits[seq_length] = output_logits 130 | 131 | if not forward_only: 132 | for seq_length in xrange(self.min_length, self.max_length + 1): 133 | print(" [*] Building a loss model for seq_length %s" % seq_length) 134 | 135 | loss = sequence_loss( 136 | logits=self.output_logits[seq_length], 137 | targets=self.true_outputs[0:seq_length], 138 | weights=[1] * seq_length, 139 | average_across_timesteps=False, 140 | average_across_batch=False, 141 | softmax_loss_function=softmax_loss_function) 142 | 143 | self.losses[seq_length] = loss 144 | 145 | if not self.params: 146 | self.params = tf.trainable_variables() 147 | 148 | # grads, norm = tf.clip_by_global_norm( 149 | # tf.gradients(loss, self.params), 5) 150 | 151 | grads = [] 152 | for grad in tf.gradients(loss, self.params): 153 | if 
grad is not None: 154 | grads.append(tf.clip_by_value(grad, 155 | self.min_grad, 156 | self.max_grad)) 157 | else: 158 | grads.append(grad) 159 | 160 | self.grads[seq_length] = grads 161 | opt = tf.train.RMSPropOptimizer(self.lr, 162 | decay=self.decay, 163 | momentum=self.momentum) 164 | 165 | reuse = seq_length != 1 166 | with tf.variable_scope(tf.get_variable_scope(), reuse=reuse): 167 | self.optims[seq_length] = opt.apply_gradients( 168 | zip(grads, self.params), 169 | global_step=self.global_step) 170 | 171 | model_vars = \ 172 | [v for v in tf.global_variables() if v.name.startswith(self.scope)] 173 | self.saver = tf.train.Saver(model_vars) 174 | print(" [*] Build a NTM model finished") 175 | 176 | def get_outputs(self, seq_length): 177 | if not self.outputs.has_key(seq_length): 178 | with tf.variable_scope(self.scope): 179 | tf.get_variable_scope().reuse_variables() 180 | 181 | zeros = np.zeros(self.cell.input_dim, dtype=np.float32) 182 | state = self.prev_states[seq_length] 183 | 184 | outputs, output_logits = [], [] 185 | for _ in xrange(seq_length): 186 | output, output_logit, state = self.cell(zeros, state) 187 | self.save_state(state, seq_length, is_output=True) 188 | outputs.append(output) 189 | output_logits.append(output_logit) 190 | 191 | self.outputs[seq_length] = outputs 192 | self.output_logits[seq_length] = output_logits 193 | return self.outputs[seq_length] 194 | 195 | def get_loss(self, seq_length): 196 | if not self.outputs.has_key(seq_length): 197 | self.get_outputs(seq_length) 198 | 199 | if not self.losses.has_key(seq_length): 200 | loss = sequence_loss( 201 | logits=self.output_logits[seq_length], 202 | targets=self.true_outputs[0:seq_length], 203 | weights=[1] * seq_length, 204 | average_across_timesteps=False, 205 | average_across_batch=False, 206 | softmax_loss_function=softmax_loss_function) 207 | 208 | self.losses[seq_length] = loss 209 | return self.losses[seq_length] 210 | 211 | def get_output_states(self, seq_length): 212 | zeros = np.zeros(self.cell.input_dim, dtype=np.float32) 213 | 214 | if not self.output_states.has_key(seq_length): 215 | with tf.variable_scope(self.scope): 216 | tf.get_variable_scope().reuse_variables() 217 | 218 | outputs, output_logits = [], [] 219 | state = self.prev_states[seq_length] 220 | 221 | for _ in xrange(seq_length): 222 | output, output_logit, state = self.cell(zeros, state) 223 | self.save_state(state, seq_length, is_output=True) 224 | outputs.append(output) 225 | output_logits.append(output_logit) 226 | self.outputs[seq_length] = outputs 227 | self.output_logits[seq_length] = output_logits 228 | return self.output_states[seq_length] 229 | 230 | @property 231 | def loss(self): 232 | return self.losses[self.cell.depth] 233 | 234 | @property 235 | def optim(self): 236 | return self.optims[self.cell.depth] 237 | 238 | def save_state(self, state, from_, to=None, is_output=False): 239 | if is_output: 240 | state_to_add = self.output_states 241 | else: 242 | state_to_add = self.input_states 243 | 244 | if to: 245 | for idx in xrange(from_, to + 1): 246 | state_to_add[idx].append(state) 247 | else: 248 | state_to_add[from_].append(state) 249 | 250 | def save(self, checkpoint_dir, task_name, step): 251 | task_dir = os.path.join(checkpoint_dir, "%s_%s" % (task_name, self.max_length)) 252 | file_name = "%s_%s.model" % (self.scope, task_name) 253 | 254 | if not os.path.exists(task_dir): 255 | os.makedirs(task_dir) 256 | 257 | self.saver.save( 258 | self.sess, 259 | os.path.join(task_dir, file_name), 260 | 
global_step=step.astype(int)) 261 | 262 | def load(self, checkpoint_dir, task_name, strict=True): 263 | print(" [*] Reading checkpoints...") 264 | 265 | task_dir = "%s_%s" % (task_name, self._max_length) 266 | checkpoint_dir = os.path.join(checkpoint_dir, task_dir) 267 | 268 | ckpt = tf.train.get_checkpoint_state(checkpoint_dir) 269 | if ckpt and ckpt.model_checkpoint_path: 270 | ckpt_name = os.path.basename(ckpt.model_checkpoint_path) 271 | self.saver.restore(self.sess, os.path.join(checkpoint_dir, ckpt_name)) 272 | else: 273 | if strict: 274 | raise Exception(" [!] Testing, but %s not found" % checkpoint_dir) 275 | else: 276 | print(' [!] Training, but previous training data %s not found' % checkpoint_dir) 277 | -------------------------------------------------------------------------------- /ntm_cell.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | from functools import reduce 5 | 6 | import numpy as np 7 | import tensorflow as tf 8 | 9 | from tensorflow.python.ops import array_ops 10 | 11 | from utils import * 12 | from ops import * 13 | 14 | class NTMCell(object): 15 | def __init__(self, input_dim, output_dim, 16 | mem_size=128, mem_dim=20, controller_dim=100, 17 | controller_layer_size=1, shift_range=1, 18 | write_head_size=1, read_head_size=1): 19 | """Initialize the parameters for an NTM cell. 20 | Args: 21 | input_dim: int, The number of units in the LSTM cell 22 | output_dim: int, The dimensionality of the inputs into the LSTM cell 23 | mem_size: (optional) int, The size of memory [128] 24 | mem_dim: (optional) int, The dimensionality for memory [20] 25 | controller_dim: (optional) int, The dimensionality for controller [100] 26 | controller_layer_size: (optional) int, The size of controller layer [1] 27 | """ 28 | # initialize configs 29 | self.input_dim = input_dim 30 | self.output_dim = output_dim 31 | self.mem_size = mem_size 32 | self.mem_dim = mem_dim 33 | self.controller_dim = controller_dim 34 | self.controller_layer_size = controller_layer_size 35 | self.shift_range = shift_range 36 | self.write_head_size = write_head_size 37 | self.read_head_size = read_head_size 38 | 39 | self.depth = 0 40 | self.states = [] 41 | 42 | def __call__(self, input_, state=None, scope=None): 43 | """Run one step of NTM. 44 | 45 | Args: 46 | inputs: input Tensor, 2D, 1 x input_size. 47 | state: state Dictionary which contains M, read_w, write_w, read, 48 | output, hidden. 49 | scope: VariableScope for the created subgraph; defaults to class name. 50 | 51 | Returns: 52 | A tuple containing: 53 | - A 2D, batch x output_dim, Tensor representing the output of the LSTM 54 | after reading "input_" when previous state was "state". 55 | Here output_dim is: 56 | num_proj if num_proj was set, 57 | num_units otherwise. 58 | - A 2D, batch x state_size, Tensor representing the new state of LSTM 59 | after reading "input_" when previous state was "state". 
60 | """ 61 | if state == None: 62 | _, state = self.initial_state() 63 | 64 | M_prev = state['M'] 65 | read_w_list_prev = state['read_w'] 66 | write_w_list_prev = state['write_w'] 67 | read_list_prev = state['read'] 68 | output_list_prev = state['output'] 69 | hidden_list_prev = state['hidden'] 70 | 71 | # build a controller 72 | output_list, hidden_list = self.build_controller(input_, read_list_prev, 73 | output_list_prev, 74 | hidden_list_prev) 75 | 76 | # last output layer from LSTM controller 77 | last_output = output_list[-1] 78 | 79 | # build a memory 80 | M, read_w_list, write_w_list, read_list = self.build_memory(M_prev, 81 | read_w_list_prev, 82 | write_w_list_prev, 83 | last_output) 84 | 85 | # get a new output 86 | new_output, new_output_logit = self.new_output(last_output) 87 | 88 | state = { 89 | 'M': M, 90 | 'read_w': read_w_list, 91 | 'write_w': write_w_list, 92 | 'read': read_list, 93 | 'output': output_list, 94 | 'hidden': hidden_list, 95 | } 96 | 97 | self.depth += 1 98 | self.states.append(state) 99 | 100 | return new_output, new_output_logit, state 101 | 102 | def new_output(self, output): 103 | """Logistic sigmoid output layers.""" 104 | 105 | with tf.variable_scope('output'): 106 | logit = Linear(output, self.output_dim, name='output') 107 | return tf.sigmoid(logit), logit 108 | 109 | def build_controller(self, input_, 110 | read_list_prev, output_list_prev, hidden_list_prev): 111 | """Build LSTM controller.""" 112 | 113 | with tf.variable_scope("controller"): 114 | output_list = [] 115 | hidden_list = [] 116 | for layer_idx in xrange(self.controller_layer_size): 117 | o_prev = output_list_prev[layer_idx] 118 | h_prev = hidden_list_prev[layer_idx] 119 | 120 | if layer_idx == 0: 121 | def new_gate(gate_name): 122 | return linear([input_, o_prev] + read_list_prev, 123 | output_size = self.controller_dim, 124 | bias = True, 125 | scope = "%s_gate_%s" % (gate_name, layer_idx)) 126 | else: 127 | def new_gate(gate_name): 128 | return linear([output_list[-1], o_prev], 129 | output_size = self.controller_dim, 130 | bias = True, 131 | scope="%s_gate_%s" % (gate_name, layer_idx)) 132 | 133 | # input, forget, and output gates for LSTM 134 | i = tf.sigmoid(new_gate('input')) 135 | f = tf.sigmoid(new_gate('forget')) 136 | o = tf.sigmoid(new_gate('output')) 137 | update = tf.tanh(new_gate('update')) 138 | 139 | # update the sate of the LSTM cell 140 | hid = tf.add_n([f * h_prev, i * update]) 141 | out = o * tf.tanh(hid) 142 | 143 | hidden_list.append(hid) 144 | output_list.append(out) 145 | 146 | return output_list, hidden_list 147 | 148 | def build_memory(self, M_prev, read_w_list_prev, write_w_list_prev, last_output): 149 | """Build a memory to read & write.""" 150 | 151 | with tf.variable_scope("memory"): 152 | # 3.1 Reading 153 | if self.read_head_size == 1: 154 | read_w_prev = read_w_list_prev[0] 155 | 156 | read_w, read = self.build_read_head(M_prev, tf.squeeze(read_w_prev), 157 | last_output, 0) 158 | read_w_list = [read_w] 159 | read_list = [read] 160 | else: 161 | read_w_list = [] 162 | read_list = [] 163 | 164 | for idx in xrange(self.read_head_size): 165 | read_w_prev_idx = read_w_list_prev[idx] 166 | read_w_idx, read_idx = self.build_read_head(M_prev, read_w_prev_idx, 167 | last_output, idx) 168 | 169 | read_w_list.append(read_w_idx) 170 | read_list.append(read_idx) 171 | 172 | # 3.2 Writing 173 | if self.write_head_size == 1: 174 | write_w_prev = write_w_list_prev[0] 175 | 176 | write_w, write, erase = self.build_write_head(M_prev, 177 | tf.squeeze(write_w_prev), 178 
| last_output, 0) 179 | 180 | M_erase = tf.ones([self.mem_size, self.mem_dim]) \ 181 | - outer_product(write_w, erase) 182 | M_write = outer_product(write_w, write) 183 | 184 | write_w_list = [write_w] 185 | else: 186 | write_w_list = [] 187 | write_list = [] 188 | erase_list = [] 189 | 190 | M_erases = [] 191 | M_writes = [] 192 | 193 | for idx in xrange(self.write_head_size): 194 | write_w_prev_idx = write_w_list_prev[idx] 195 | 196 | write_w_idx, write_idx, erase_idx = \ 197 | self.build_write_head(M_prev, write_w_prev_idx, 198 | last_output, idx) 199 | 200 | write_w_list.append(tf.transpose(write_w_idx)) 201 | write_list.append(write_idx) 202 | erase_list.append(erase_idx) 203 | 204 | M_erases.append(tf.ones([self.mem_size, self.mem_dim]) \ 205 | - outer_product(write_w_idx, erase_idx)) 206 | M_writes.append(outer_product(write_w_idx, write_idx)) 207 | 208 | M_erase = reduce(lambda x, y: x*y, M_erases) 209 | M_write = tf.add_n(M_writes) 210 | 211 | M = M_prev * M_erase + M_write 212 | 213 | return M, read_w_list, write_w_list, read_list 214 | 215 | def build_read_head(self, M_prev, read_w_prev, last_output, idx): 216 | return self.build_head(M_prev, read_w_prev, last_output, True, idx) 217 | 218 | def build_write_head(self, M_prev, write_w_prev, last_output, idx): 219 | return self.build_head(M_prev, write_w_prev, last_output, False, idx) 220 | 221 | def build_head(self, M_prev, w_prev, last_output, is_read, idx): 222 | scope = "read" if is_read else "write" 223 | 224 | with tf.variable_scope(scope): 225 | # Figure 2. 226 | # Amplify or attenuate the precision 227 | with tf.variable_scope("k"): 228 | k = tf.tanh(Linear(last_output, self.mem_dim, name='k_%s' % idx)) 229 | # Interpolation gate 230 | with tf.variable_scope("g"): 231 | g = tf.sigmoid(Linear(last_output, 1, name='g_%s' % idx)) 232 | # shift weighting 233 | with tf.variable_scope("s_w"): 234 | w = Linear(last_output, 2 * self.shift_range + 1, name='s_w_%s' % idx) 235 | s_w = softmax(w) 236 | with tf.variable_scope("beta"): 237 | beta = tf.nn.softplus(Linear(last_output, 1, name='beta_%s' % idx)) 238 | with tf.variable_scope("gamma"): 239 | gamma = tf.add(tf.nn.softplus(Linear(last_output, 1, name='gamma_%s' % idx)), 240 | tf.constant(1.0)) 241 | 242 | # 3.3.1 243 | # Cosine similarity 244 | similarity = smooth_cosine_similarity(M_prev, k) # [mem_size x 1] 245 | # Focusing by content 246 | content_focused_w = softmax(scalar_mul(similarity, beta)) 247 | 248 | # 3.3.2 249 | # Focusing by location 250 | gated_w = tf.add_n([ 251 | scalar_mul(content_focused_w, g), 252 | scalar_mul(w_prev, (tf.constant(1.0) - g)) 253 | ]) 254 | 255 | # Convolutional shifts 256 | conv_w = circular_convolution(gated_w, s_w) 257 | 258 | # Sharpening 259 | powed_conv_w = tf.pow(conv_w, gamma) 260 | w = powed_conv_w / tf.reduce_sum(powed_conv_w) 261 | 262 | if is_read: 263 | # 3.1 Reading 264 | read = matmul(tf.transpose(M_prev), w) 265 | return w, read 266 | else: 267 | # 3.2 Writing 268 | erase = tf.sigmoid(Linear(last_output, self.mem_dim, name='erase_%s' % idx)) 269 | add = tf.tanh(Linear(last_output, self.mem_dim, name='add_%s' % idx)) 270 | return w, add, erase 271 | 272 | def initial_state(self, dummy_value=0.0): 273 | self.depth = 0 274 | self.states = [] 275 | with tf.variable_scope("init_cell"): 276 | # always zero 277 | dummy = tf.Variable(tf.constant([[dummy_value]], dtype=tf.float32)) 278 | 279 | # memory 280 | M_init_linear = tf.tanh(Linear(dummy, self.mem_size * self.mem_dim, 281 | name='M_init_linear')) 282 | M_init = 
tf.reshape(M_init_linear, [self.mem_size, self.mem_dim]) 283 | 284 | # read weights 285 | read_w_list_init = [] 286 | read_list_init = [] 287 | for idx in xrange(self.read_head_size): 288 | read_w_idx = Linear(dummy, self.mem_size, is_range=True, 289 | squeeze=True, name='read_w_%d' % idx) 290 | read_w_list_init.append(softmax(read_w_idx)) 291 | 292 | read_init_idx = Linear(dummy, self.mem_dim, 293 | squeeze=True, name='read_init_%d' % idx) 294 | read_list_init.append(tf.tanh(read_init_idx)) 295 | 296 | # write weights 297 | write_w_list_init = [] 298 | for idx in xrange(self.write_head_size): 299 | write_w_idx = Linear(dummy, self.mem_size, is_range=True, 300 | squeeze=True, name='write_w_%s' % idx) 301 | write_w_list_init.append(softmax(write_w_idx)) 302 | 303 | # controller state 304 | output_init_list = [] 305 | hidden_init_list = [] 306 | for idx in xrange(self.controller_layer_size): 307 | output_init_idx = Linear(dummy, self.controller_dim, 308 | squeeze=True, name='output_init_%s' % idx) 309 | output_init_list.append(tf.tanh(output_init_idx)) 310 | hidden_init_idx = Linear(dummy, self.controller_dim, 311 | squeeze=True, name='hidden_init_%s' % idx) 312 | hidden_init_list.append(tf.tanh(hidden_init_idx)) 313 | 314 | output = tf.tanh(Linear(dummy, self.output_dim, name='new_output')) 315 | 316 | state = { 317 | 'M': M_init, 318 | 'read_w': read_w_list_init, 319 | 'write_w': write_w_list_init, 320 | 'read': read_list_init, 321 | 'output': output_init_list, 322 | 'hidden': hidden_init_list 323 | } 324 | 325 | self.depth += 1 326 | self.states.append(state) 327 | 328 | return output, state 329 | 330 | def get_memory(self, depth=None): 331 | depth = depth if depth else self.depth 332 | return self.states[depth - 1]['M'] 333 | 334 | def get_read_weights(self, depth=None): 335 | depth = depth if depth else self.depth 336 | return self.states[depth - 1]['read_w'] 337 | 338 | def get_write_weights(self, depth=None): 339 | depth = depth if depth else self.depth 340 | return self.states[depth - 1]['write_w'] 341 | 342 | def get_read_vector(self, depth=None): 343 | depth = depth if depth else self.depth 344 | return self.states[depth - 1]['read'] 345 | 346 | def print_read_max(self, sess): 347 | read_w_list = sess.run(self.get_read_weights()) 348 | 349 | fmt = "%-4d %.4f" 350 | if self.read_head_size == 1: 351 | print(fmt % (argmax(read_w_list[0]))) 352 | else: 353 | for idx in xrange(self.read_head_size): 354 | print(fmt % np.argmax(read_w_list[idx])) 355 | 356 | def print_write_max(self, sess): 357 | write_w_list = sess.run(self.get_write_weights()) 358 | 359 | fmt = "%-4d %.4f" 360 | if self.write_head_size == 1: 361 | print(fmt % (argmax(write_w_list[0]))) 362 | else: 363 | for idx in xrange(self.write_head_size): 364 | print(fmt % argmax(write_w_list[idx])) 365 | -------------------------------------------------------------------------------- /ops.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numpy as np 3 | import tensorflow as tf 4 | 5 | from tensorflow.python.ops import array_ops 6 | from tensorflow.python.ops import init_ops 7 | from tensorflow.python.framework import ops 8 | from tensorflow.python.ops import variable_scope as vs 9 | 10 | from utils import * 11 | 12 | def linear(args, output_size, bias, bias_start=0.0, scope=None): 13 | """Linear map: sum_i(args[i] * W[i]), where W[i] is a variable. 14 | 15 | Args: 16 | args: a 2D Tensor or a list of 2D, batch x n, Tensors. 
17 | output_size: int, second dimension of W[i]. 18 | bias: boolean, whether to add a bias term or not. 19 | bias_start: starting value to initialize the bias; 0 by default. 20 | scope: VariableScope for the created subgraph; defaults to "Linear". 21 | 22 | Returns: 23 | A 2D Tensor with shape [batch x output_size] equal to 24 | sum_i(args[i] * W[i]), where W[i]s are newly created matrices. 25 | 26 | Raises: 27 | ValueError: if some of the arguments has unspecified or wrong shape. 28 | """ 29 | if not isinstance(args, (list, tuple)): 30 | args = [args] 31 | 32 | # Calculate the total size of arguments on dimension 1. 33 | total_arg_size = 0 34 | shapes = [] 35 | for a in args: 36 | try: 37 | shapes.append(a.get_shape().as_list()) 38 | except Exception as e: 39 | shapes.append(a.shape) 40 | 41 | is_vector = False 42 | for idx, shape in enumerate(shapes): 43 | if len(shape) != 2: 44 | is_vector = True 45 | args[idx] = tf.reshape(args[idx], [1, -1]) 46 | total_arg_size += shape[0] 47 | else: 48 | total_arg_size += shape[1] 49 | 50 | # Now the computation. 51 | with vs.variable_scope(scope or "Linear"): 52 | matrix = vs.get_variable("Matrix", [total_arg_size, output_size]) 53 | if len(args) == 1: 54 | res = tf.matmul(args[0], matrix) 55 | else: 56 | res = tf.matmul(tf.concat(args, 1), matrix) 57 | if not bias: 58 | return res 59 | bias_term = vs.get_variable( 60 | "Bias", [output_size], 61 | initializer=init_ops.constant_initializer(bias_start)) 62 | 63 | if is_vector: 64 | return tf.reshape(res + bias_term, [-1]) 65 | else: 66 | return res + bias_term 67 | 68 | def Linear(input_, output_size, stddev=0.5, 69 | is_range=False, squeeze=False, 70 | name=None, reuse=None): 71 | """Applies a linear transformation to the incoming data. 72 | 73 | Args: 74 | input: a 2-D or 1-D data (`Tensor` or `ndarray`) 75 | output_size: the size of output matrix or vector 76 | """ 77 | with tf.variable_scope("Linear", reuse=reuse): 78 | if type(input_) == np.ndarray: 79 | shape = input_.shape 80 | else: 81 | shape = input_.get_shape().as_list() 82 | 83 | is_vector = False 84 | if len(shape) == 1: 85 | is_vector = True 86 | input_ = tf.reshape(input_, [1, -1]) 87 | input_size = shape[0] 88 | elif len(shape) == 2: 89 | input_size = shape[1] 90 | else: 91 | raise ValueError("Linear expects shape[1] of inputuments: %s" % str(shape)) 92 | 93 | w_name = "%s_w" % name if name else None 94 | b_name = "%s_b" % name if name else None 95 | 96 | w = tf.get_variable(w_name, [input_size, output_size], tf.float32, 97 | tf.random_normal_initializer(stddev=stddev)) 98 | mul = tf.matmul(input_, w) 99 | 100 | if is_range: 101 | def identity_initializer(tensor): 102 | def _initializer(shape, dtype=tf.float32, partition_info=None): 103 | return tf.identity(tensor) 104 | return _initializer 105 | 106 | range_ = tf.range(output_size, 0, -1) 107 | b = tf.get_variable(b_name, [output_size], tf.float32, 108 | identity_initializer(tf.cast(range_, tf.float32))) 109 | else: 110 | b = tf.get_variable(b_name, [output_size], tf.float32, 111 | tf.random_normal_initializer(stddev=stddev)) 112 | 113 | if squeeze: 114 | output = tf.squeeze(tf.nn.bias_add(mul, b)) 115 | else: 116 | output = tf.nn.bias_add(mul, b) 117 | 118 | if is_vector: 119 | return tf.reshape(output, [-1]) 120 | else: 121 | return output 122 | 123 | def smooth_cosine_similarity(m, v): 124 | """Computes smooth cosine similarity. 
125 | 126 | Args: 127 | m: a 2-D `Tensor` (matrix) 128 | v: a 1-D `Tensor` (vector) 129 | """ 130 | shape_x = m.get_shape().as_list() 131 | shape_y = v.get_shape().as_list() 132 | if shape_x[1] != shape_y[0]: 133 | raise ValueError("Smooth cosine similarity is expecting same dimemsnion") 134 | 135 | m_norm = tf.sqrt(tf.reduce_sum(tf.pow(m, 2),1)) 136 | v_norm = tf.sqrt(tf.reduce_sum(tf.pow(v, 2))) 137 | m_dot_v = tf.matmul(m, tf.reshape(v, [-1, 1])) 138 | 139 | similarity = tf.div(tf.reshape(m_dot_v, [-1]), m_norm * v_norm + 1e-3) 140 | return similarity 141 | 142 | def circular_convolution(v, k): 143 | """Computes circular convolution. 144 | 145 | Args: 146 | v: a 1-D `Tensor` (vector) 147 | k: a 1-D `Tensor` (kernel) 148 | """ 149 | size = int(v.get_shape()[0]) 150 | kernel_size = int(k.get_shape()[0]) 151 | kernel_shift = int(math.floor(kernel_size/2.0)) 152 | 153 | def loop(idx): 154 | if idx < 0: return size + idx 155 | if idx >= size : return idx - size 156 | else: return idx 157 | 158 | kernels = [] 159 | for i in xrange(size): 160 | indices = [loop(i+j) for j in xrange(kernel_shift, -kernel_shift-1, -1)] 161 | v_ = tf.gather(v, indices) 162 | kernels.append(tf.reduce_sum(v_ * k, 0)) 163 | 164 | # # code with double loop 165 | # for i in xrange(size): 166 | # for j in xrange(kernel_size): 167 | # idx = i + kernel_shift - j + 1 168 | # if idx < 0: idx = idx + size 169 | # if idx >= size: idx = idx - size 170 | # w = tf.gather(v, int(idx)) * tf.gather(kernel, j) 171 | # output = tf.scatter_add(output, [i], tf.reshape(w, [1, -1])) 172 | 173 | return tf.dynamic_stitch([i for i in xrange(size)], kernels) 174 | 175 | def outer_product(*inputs): 176 | """Computes outer product. 177 | 178 | Args: 179 | inputs: a list of 1-D `Tensor` (vector) 180 | """ 181 | inputs = list(inputs) 182 | order = len(inputs) 183 | 184 | for idx, input_ in enumerate(inputs): 185 | if len(input_.get_shape()) == 1: 186 | inputs[idx] = tf.reshape(input_, [-1, 1] if idx % 2 == 0 else [1, -1]) 187 | 188 | if order == 2: 189 | output = tf.multiply(inputs[0], inputs[1]) 190 | elif order == 3: 191 | size = [] 192 | idx = 1 193 | for i in xrange(order): 194 | size.append(inputs[i].get_shape()[0]) 195 | output = tf.zeros(size) 196 | 197 | u, v, w = inputs[0], inputs[1], inputs[2] 198 | uv = tf.multiply(inputs[0], inputs[1]) 199 | for i in xrange(self.size[-1]): 200 | output = tf.scatter_add(output, [0,0,i], uv) 201 | 202 | return output 203 | 204 | def scalar_mul(x, beta, name=None): 205 | return x * beta 206 | 207 | def scalar_div(x, beta, name=None): 208 | return x / beta 209 | -------------------------------------------------------------------------------- /ops_test.py: -------------------------------------------------------------------------------- 1 | """Tests for ops.""" 2 | from __future__ import absolute_import 3 | from __future__ import division 4 | from __future__ import print_function 5 | 6 | import numpy as np 7 | 8 | from tensorflow.python.ops import constant_op 9 | from tensorflow.python.framework import test_util 10 | from tensorflow.python.platform import googletest 11 | 12 | from ops import * 13 | 14 | class SmoothCosineSimilarityTest(test_util.TensorFlowTestCase): 15 | 16 | def testSmoothCosineSimilarity(self): 17 | """Test code for torch: 18 | 19 | th> x=torch.Tensor{{1,2,3},{2,2,2},{3,2,1},{0,2,4}} 20 | th> y=torch.Tensor{2,2,2} 21 | th> c=nn.SmoothCosineSimilarity() 22 | th> c:forward{x,y} 23 | 0.9257 24 | 0.9999 25 | 0.9257 26 | 0.7745 27 | [torch.DoubleTensor of size 4] 28 | """ 29 | m = 
constant_op.constant( 30 | [[1,2,3], 31 | [2,2,2], 32 | [3,2,1], 33 | [0,2,4]], dtype=np.float32) 34 | v = constant_op.constant([2,2,2], dtype=np.float32) 35 | for use_gpu in [True, False]: 36 | with self.test_session(use_gpu=use_gpu): 37 | loss = smooth_cosine_similarity(m, v).eval() 38 | self.assertAllClose(loss, [0.92574867671153, 39 | 0.99991667361053, 40 | 0.92574867671153, 41 | 0.77454667246876]) 42 | 43 | class CircularConvolutionTest(test_util.TensorFlowTestCase): 44 | 45 | def testCircularConvolution(self): 46 | v = constant_op.constant([1,2,3,4,5,6,7], dtype=tf.float32) 47 | k = constant_op.constant([0,0,1], dtype=tf.float32) 48 | for use_gpu in [True, False]: 49 | with self.test_session(use_gpu=use_gpu): 50 | loss = circular_convolution(v, k).eval() 51 | self.assertAllEqual(loss, [7,1,2,3,4,5,6]) 52 | 53 | if __name__ == "__main__": 54 | googletest.main() 55 | -------------------------------------------------------------------------------- /tasks/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/carpedm20/NTM-tensorflow/5376c35e8800157b08af44ddec1cb870d2ef3c83/tasks/__init__.py -------------------------------------------------------------------------------- /tasks/copy.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import numpy as np 4 | import tensorflow as tf 5 | from random import randint 6 | 7 | from utils import pprint 8 | 9 | print_interval = 5 10 | 11 | 12 | def run(ntm, seq_length, sess, print_=True): 13 | start_symbol = np.zeros([ntm.cell.input_dim], dtype=np.float32) 14 | start_symbol[0] = 1 15 | end_symbol = np.zeros([ntm.cell.input_dim], dtype=np.float32) 16 | end_symbol[1] = 1 17 | 18 | seq = generate_copy_sequence(seq_length, ntm.cell.input_dim - 2) 19 | 20 | feed_dict = {input_: vec for vec, input_ in zip(seq, ntm.inputs)} 21 | feed_dict.update( 22 | {true_output: vec for vec, true_output in zip(seq, ntm.true_outputs)} 23 | ) 24 | feed_dict.update({ 25 | ntm.start_symbol: start_symbol, 26 | ntm.end_symbol: end_symbol 27 | }) 28 | 29 | input_states = [state['write_w'][0] for state in ntm.input_states[seq_length]] 30 | output_states = [state['read_w'][0] for state in ntm.get_output_states(seq_length)] 31 | 32 | result = sess.run( 33 | ntm.get_outputs(seq_length) + 34 | input_states + output_states + 35 | [ntm.get_loss(seq_length)], 36 | feed_dict=feed_dict) 37 | 38 | is_sz = len(input_states) 39 | os_sz = len(output_states) 40 | 41 | outputs = result[:seq_length] 42 | read_ws = result[seq_length:seq_length + is_sz] 43 | write_ws = result[seq_length + is_sz:seq_length + is_sz + os_sz] 44 | loss = result[-1] 45 | 46 | if print_: 47 | np.set_printoptions(suppress=True) 48 | print(" true output : ") 49 | pprint(seq) 50 | print(" predicted output :") 51 | pprint(np.round(outputs)) 52 | print(" Loss : %f" % loss) 53 | np.set_printoptions(suppress=False) 54 | else: 55 | return seq, outputs, read_ws, write_ws, loss 56 | 57 | 58 | def train(ntm, config, sess): 59 | if not os.path.isdir(config.checkpoint_dir): 60 | raise Exception(" [!] 
Directory %s not found" % config.checkpoint_dir) 61 | 62 | # delimiter flag for start and end 63 | start_symbol = np.zeros([config.input_dim], dtype=np.float32) 64 | start_symbol[0] = 1 65 | end_symbol = np.zeros([config.input_dim], dtype=np.float32) 66 | end_symbol[1] = 1 67 | 68 | print(" [*] Initialize all variables") 69 | tf.global_variables_initializer().run() 70 | print(" [*] Initialization finished") 71 | 72 | if config.continue_train is not False: 73 | ntm.load(config.checkpoint_dir, config.task, strict=config.continue_train is True) 74 | 75 | start_time = time.time() 76 | for idx in xrange(config.epoch): 77 | seq_length = randint(config.min_length, config.max_length) 78 | seq = generate_copy_sequence(seq_length, config.input_dim - 2) 79 | 80 | feed_dict = {input_: vec for vec, input_ in zip(seq, ntm.inputs)} 81 | feed_dict.update( 82 | {true_output: vec for vec, true_output in zip(seq, ntm.true_outputs)} 83 | ) 84 | feed_dict.update({ 85 | ntm.start_symbol: start_symbol, 86 | ntm.end_symbol: end_symbol 87 | }) 88 | 89 | _, cost, step = sess.run([ntm.optims[seq_length], 90 | ntm.get_loss(seq_length), 91 | ntm.global_step], feed_dict=feed_dict) 92 | 93 | if idx % 100 == 0: 94 | ntm.save(config.checkpoint_dir, config.task, step) 95 | 96 | if idx % print_interval == 0: 97 | print( 98 | "[%5d] %2d: %.2f (%.1fs)" 99 | % (idx, seq_length, cost, time.time() - start_time)) 100 | 101 | ntm.save(config.checkpoint_dir, config.task, step) 102 | 103 | print("Training %s task finished" % config.task) 104 | 105 | 106 | def generate_copy_sequence(length, bits): 107 | seq = np.zeros([length, bits + 2], dtype=np.float32) 108 | for idx in xrange(length): 109 | seq[idx, 2:bits+2] = np.random.rand(bits).round() 110 | return list(seq) 111 | -------------------------------------------------------------------------------- /tasks/recall.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import numpy as np 4 | import tensorflow as tf 5 | from random import randint 6 | 7 | from ntm import NTM 8 | from utils import pprint 9 | from ntm_cell import NTMCell 10 | 11 | print_interval = 5 12 | 13 | 14 | def run(ntm, seq_length, sess, print_=True): 15 | start_symbol = np.zeros([ntm.cell.input_dim], dtype=np.float32) 16 | start_symbol[0] = 1 17 | end_symbol = np.zeros([ntm.cell.input_dim], dtype=np.float32) 18 | end_symbol[1] = 1 19 | 20 | seq = generate_recall_sequence(seq_length, ntm.cell.input_dim - 2) 21 | 22 | feed_dict = {input_:vec for vec, input_ in zip(seq, ntm.inputs)} 23 | feed_dict.update( 24 | {true_output:vec for vec, true_output in zip(seq, ntm.true_outputs)} 25 | ) 26 | feed_dict.update({ 27 | ntm.start_symbol: start_symbol, 28 | ntm.end_symbol: end_symbol 29 | }) 30 | 31 | input_states = [state['write_w'] for state in ntm.input_states[seq_length]] 32 | output_states = [state['read_w'] for state in ntm.get_output_states(seq_length)] 33 | 34 | result = sess.run(ntm.get_outputs(seq_length) + \ 35 | input_states + output_states + \ 36 | [ntm.get_loss(seq_length)], 37 | feed_dict=feed_dict) 38 | 39 | is_sz = len(input_states) 40 | os_sz = len(output_states) 41 | 42 | outputs = result[:seq_length] 43 | read_ws = result[seq_length:seq_length + is_sz] 44 | write_ws = result[seq_length + is_sz:seq_length + is_sz + os_sz] 45 | loss = result[-1] 46 | 47 | if print_: 48 | np.set_printoptions(suppress=True) 49 | print(" true output : ") 50 | pprint(seq) 51 | print(" predicted output :") 52 | pprint(np.round(outputs)) 53 | print(" Loss : %f" % 
loss) 54 | np.set_printoptions(suppress=False) 55 | else: 56 | return seq, outputs, read_ws, write_ws, loss 57 | 58 | 59 | def train(ntm, config, sess): 60 | if not os.path.isdir(config.checkpoint_dir): 61 | raise Exception(" [!] Directory %s not found" % config.checkpoint_dir) 62 | 63 | delim_symbol = np.zeros([config.input_dim], dtype=np.float32) 64 | start_symbol[0] = 1 65 | query_symbol = np.zeros([config.input_dim], dtype=np.float32) 66 | end_symbol[1] = 1 67 | 68 | print(" [*] Initialize all variables") 69 | tf.initialize_all_variables().run() 70 | print(" [*] Initialization finished") 71 | 72 | start_time = time.time() 73 | for idx in xrange(config.epoch): 74 | seq_length = randint(config.min_length, config.max_length) 75 | seq = generate_recall_sequence(seq_length, config.input_dim - 2) 76 | 77 | feed_dict = {input_:vec for vec, input_ in zip(seq, ntm.inputs)} 78 | feed_dict.update( 79 | {true_output:vec for vec, true_output in zip(seq, ntm.true_outputs)} 80 | ) 81 | feed_dict.update({ 82 | ntm.start_symbol: start_symbol, 83 | ntm.end_symbol: end_symbol 84 | }) 85 | 86 | _, cost, step = sess.run([ntm.optims[seq_length], 87 | ntm.get_loss(seq_length), 88 | ntm.global_step], feed_dict=feed_dict) 89 | 90 | if idx % 100 == 0: 91 | ntm.save(config.checkpoint_dir, 'recall', step) 92 | 93 | if idx % print_interval == 0: 94 | print("[%5d] %2d: %.2f (%.1fs)" \ 95 | % (idx, seq_length, cost, time.time() - start_time)) 96 | 97 | print("Training Copy task finished") 98 | 99 | 100 | def generate_recall_sequence(num_items, item_length, input_dim): 101 | items = [] 102 | for idx in xrange(num_items): 103 | item = np.random.rand(item_length, input_dim).round() 104 | item[0:item_length+1, 0:2] = 0 105 | items.append(item) 106 | return items 107 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import pprint 3 | import numpy as np 4 | import tensorflow as tf 5 | 6 | eps = 1e-12 7 | pp = pprint.PrettyPrinter() 8 | 9 | try: 10 | xrange 11 | except NameError: 12 | xrange = range 13 | 14 | def progress(progress): 15 | barLength = 20 # Modify this to change the length of the progress bar 16 | status = "" 17 | if isinstance(progress, int): 18 | progress = float(progress) 19 | if not isinstance(progress, float): 20 | progress = 0 21 | status = "error: progress var must be float\r\n" 22 | if progress < 0: 23 | progress = 0 24 | status = "Halt...\r\n" 25 | if progress >= 1: 26 | progress = 1 27 | status = "Finished.\r\n" 28 | block = int(round(barLength*progress)) 29 | text = "\rPercent: [%s] %.2f%% %s" % ("#"*block + " "*(barLength-block), progress*100, status) 30 | sys.stdout.write(text) 31 | sys.stdout.flush() 32 | 33 | def pprint(seq): 34 | seq = np.array(seq) 35 | seq = np.char.mod('%d', np.around(seq)) 36 | seq[seq == '1'] = '#' 37 | seq[seq == '0'] = ' ' 38 | print("\n".join(["".join(x) for x in seq.tolist()])) 39 | 40 | def gather(m_or_v, idx): 41 | if len(m_or_v.get_shape()) > 1: 42 | return tf.gather(m_or_v, idx) 43 | else: 44 | assert idx == 0, "Error: idx should be 0 but %d" % idx 45 | return m_or_v 46 | 47 | def argmax(x): 48 | index = 0 49 | max_num = x[index] 50 | for idx in xrange(1, len(x)-1): 51 | if x[idx] > max_num: 52 | index = idx 53 | max_num = x[idx] 54 | return index, max_num 55 | 56 | def softmax(x): 57 | """Compute softmax. 
58 | 59 | Args: 60 | x: a 2-D `Tensor` (matrix) or 1-D `Tensor` (vector) 61 | """ 62 | try: 63 | return tf.nn.softmax(x + eps) 64 | except Exception: # fall back for 1-D input: reshape to a row, softmax, flatten back 65 | return tf.reshape(tf.nn.softmax(tf.reshape(x + eps, [1, -1])), [-1]) 66 | 67 | def matmul(x, y): 68 | """Compute matrix multiplication. 69 | 70 | Args: 71 | x: a 2-D `Tensor` (matrix) 72 | y: a 2-D `Tensor` (matrix) or 1-D `Tensor` (vector) 73 | """ 74 | try: 75 | return tf.matmul(x, y) 76 | except Exception: # fall back for 1-D y: reshape to a column, multiply, flatten back 77 | return tf.reshape(tf.matmul(x, tf.reshape(y, [-1, 1])), [-1]) 78 | --------------------------------------------------------------------------------
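Note (added for reference, not a file in this repository): the head-addressing pipeline that ntm_cell.build_head assembles from TensorFlow ops — content focusing via smooth_cosine_similarity sharpened by beta, interpolation with the previous weighting through the gate g, circular convolution with the shift distribution s_w, and gamma sharpening — can be summarized in plain NumPy. The sketch below is illustrative only; the function name address and its argument names are chosen here and do not appear in the codebase.

import numpy as np

def address(memory, key, beta, g, shift, gamma, w_prev):
    # Content focusing (paper Sec. 3.3.1): cosine similarity with the same
    # 1e-3 smoothing term as ops.smooth_cosine_similarity, then a softmax
    # sharpened by beta.
    similarity = memory.dot(key) / (
        np.linalg.norm(memory, axis=1) * np.linalg.norm(key) + 1e-3)
    content_w = np.exp(beta * similarity)
    content_w /= content_w.sum()
    # Location focusing (Sec. 3.3.2): interpolate with the previous weighting.
    gated_w = g * content_w + (1.0 - g) * w_prev
    # Circular convolution with the shift distribution, using the same index
    # convention as ops.circular_convolution.
    size, shift_range = len(gated_w), len(shift) // 2
    conv_w = np.zeros(size)
    for i in range(size):
        for j in range(len(shift)):
            conv_w[i] += gated_w[(i + shift_range - j) % size] * shift[j]
    # Sharpening with gamma, then renormalize to a valid weighting.
    w = conv_w ** gamma
    return w / w.sum()

# Example: with g = 0 the content term is ignored, and a one-hot shift of
# [0, 0, 1] rotates the weighting forward by one slot, matching what
# CircularConvolutionTest in ops_test.py checks for the TensorFlow op.
w_prev = np.array([1.0, 0.0, 0.0, 0.0])
memory = np.eye(4)
w = address(memory, key=memory[0], beta=5.0, g=0.0,
            shift=np.array([0.0, 0.0, 1.0]), gamma=1.0, w_prev=w_prev)
print(w)  # -> [0. 1. 0. 0.]: the focus moves from slot 0 to slot 1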