├── .gitignore
├── asset
    ├── illustration.png
    ├── artists_impression.jpg
    ├── error_increasing_seq_len.png
    ├── varying_memory_cpu_gpu_results.png
    ├── sparse_link_matrix_losses_comparison.png
    └── sparse_link_matrix_seconds_comparison.png
├── README.md
└── dnc.py


/.gitignore:
--------------------------------------------------------------------------------
.idea
__pycache__/
.ipynb_checkpoints
result/
_vizdoom.ini


--------------------------------------------------------------------------------
/asset/illustration.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AdeelMufti/DifferentiableNeuralComputer/HEAD/asset/illustration.png


--------------------------------------------------------------------------------
/asset/artists_impression.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AdeelMufti/DifferentiableNeuralComputer/HEAD/asset/artists_impression.jpg


--------------------------------------------------------------------------------
/asset/error_increasing_seq_len.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AdeelMufti/DifferentiableNeuralComputer/HEAD/asset/error_increasing_seq_len.png


--------------------------------------------------------------------------------
/asset/varying_memory_cpu_gpu_results.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AdeelMufti/DifferentiableNeuralComputer/HEAD/asset/varying_memory_cpu_gpu_results.png


--------------------------------------------------------------------------------
/asset/sparse_link_matrix_losses_comparison.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AdeelMufti/DifferentiableNeuralComputer/HEAD/asset/sparse_link_matrix_losses_comparison.png

--------------------------------------------------------------------------------
/asset/sparse_link_matrix_seconds_comparison.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AdeelMufti/DifferentiableNeuralComputer/HEAD/asset/sparse_link_matrix_seconds_comparison.png


--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Optimized Differentiable Neural Computer In Chainer
This is an optimized implementation of the [Differentiable Neural Computer (DNC)](https://deepmind.com/blog/differentiable-neural-computers/), in [Chainer](https://chainer.org). It builds on an implementation of the DNC in Chainer by [yos1up](https://github.com/yos1up/DNC). This implementation was created as part of my MSc dissertation in Artificial Intelligence under the supervision of [Subramanian Ramamoorthy](http://homepages.inf.ed.ac.uk/sramamoo/) at the University of Edinburgh. My research aims to use the DNC in the [World Models](https://arxiv.org/abs/1803.10122) framework (check out [my implementation](https://github.com/AdeelMufti/WorldModels)).

The DNC is a form of memory-augmented neural network that has shown promise on complex tasks that are difficult for traditional neural networks. When an RNN or LSTM is used as the neural network in the DNC architecture, it has the capacity to solve, and generalize well on, tasks with long temporal dependencies across a sequence.

At each timestep in a sequence, the DNC receives an external input together with the memory data read by its read heads at the previous timestep (a vector of 0's at timestep 0), and uses both to manipulate its internal state and produce the next output.
The memory in the DNC is said to be external because the DNC can be trained on a smaller memory size, and then attached to a larger memory matrix.

> ![](asset/illustration.png)

## About This Implementation
Unfortunately, no optimized implementation of the DNC was publicly available in Chainer, and yos1up's implementation was not feasible to use in my experiments due to the wall-clock time it took to train on large datasets. So this implementation was created.

This implementation features:

* GPU acceleration using [CuPy](https://cupy.chainer.org/)
* Sparse link matrix, as described in the [DNC paper](https://www.nature.com/articles/nature20101) (2nd page of *Methods* section), which reduces the computation cost for the link matrix from *O(n^2)* to *O(n log n)*
* Avoidance of *for* loops in the core computational mechanism of the DNC

## Optimization Results

With this implementation, I observed a **~50x** speedup over yos1up's Chainer DNC implementation on the addition toy task, using *K=8* for the sparse link matrix, with both implementations run on a CPU. With all the same hyperparameters (memory size, read heads, hidden units, etc.), yos1up's implementation takes *217.97 seconds* for 10 iterations of training on a constant sequence length of 12, while this optimized DNC takes *4.34 seconds*!

In an attempt to speed it up further, I created an implementation (not available here) that could process sequences in batches. But this approach failed to converge even on simple tasks for any batch size greater than 1. I believe this is because the internal state and external memory of the DNC are manipulated for every row of a sequence, so a *for* loop would be necessary internally to process each sequence's row in the batch individually. Thus, I left it to the external code to contain this loop, and the DNC always takes a batch size of 1.

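
The external loop described above can be sketched roughly as follows. This is a minimal illustration under stated assumptions, not the code from `dnc.py`: `RunningSumModel` and `run_sequence` are hypothetical names, and the stand-in model only mimics the two calls the loop relies on (`reset_state()` and a per-row call). It shows why rows are fed one at a time: every call changes the model's state, and that state is cleared between sequences.

```python
class RunningSumModel:
    """Hypothetical stand-in for a stateful model such as the DNC.

    It offers reset_state() and a per-row __call__, nothing more.
    It is NOT the DNC; it just adds up the column index of each 1 it sees.
    """

    def __init__(self):
        self.total = 0.0

    def reset_state(self):
        # Clear per-sequence state before a new sequence starts.
        self.total = 0.0

    def __call__(self, row):
        # Each row mutates the internal state, so rows must arrive in order.
        self.total += sum(i * v for i, v in enumerate(row[:-1]))
        is_delimiter = row[-1] == 1
        return self.total if is_delimiter else 0.0


def run_sequence(model, rows, targets):
    """Feed ONE sequence row by row (batch size 1); return outputs and summed squared error."""
    model.reset_state()
    outputs = []
    squared_error = 0.0
    for row, target in zip(rows, targets):
        out = model(row)
        outputs.append(out)
        squared_error += (out - target) ** 2
    return outputs, squared_error


# Two tiny addition-style sequences; the last column marks the end of a sequence.
sequences = [
    ([[0, 0, 1, 0], [0, 1, 0, 0], [0, 0, 0, 1]], [0.0, 0.0, 3.0]),
    ([[1, 0, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]], [0.0, 0.0, 2.0]),
]
model = RunningSumModel()
results = [run_sequence(model, rows, targets) for rows, targets in sequences]
for outputs, squared_error in results:
    print(outputs, squared_error)
```

With a real Chainer model, the body of `run_sequence` would accumulate a differentiable loss and be followed by a backward pass and an optimizer update, once per sequence.
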

**Another important note:** with small memory sizes such as 256x64, running on a CPU with NumPy is *faster* than running on a GPU with CuPy! GPUs are best at massively parallel computation with larger batch sizes, but we are limited to a batch size of 1. However, the GPU-accelerated DNC trains faster (in wall-clock time) than the CPU version when the memory size is larger.

The results below test these optimizations. All results were recorded using an Amazon [g3.4xlarge](https://aws.amazon.com/ec2/instance-types/g3/) instance with and without GPU acceleration.

#### GPU vs CPU
> **Wall-clock training time for 100 iterations on the addition task, with increasing memory sizes.**
> *As you can see, a GPU completely outperforms a CPU as the memory size grows.*
> ![](asset/varying_memory_cpu_gpu_results.png)

#### Sparse link matrix performance
> **MSE loss on addition task with sparse link matrix disabled versus sparse link matrix with K=8.**
> *Using a sparse link matrix does not degrade performance*.
> ![](asset/sparse_link_matrix_losses_comparison.png)

> **Wall-clock training time of the addition task with sparse link matrix disabled versus sparse link matrix with K=8, memory size 256x64**
> *Training is slightly faster with the sparse link matrix. The savings can compound over large datasets and complex tasks, especially with larger memory.*
> ![](asset/sparse_link_matrix_seconds_comparison.png)


## Toy Task Results

To test this implementation, some (simple) toy tasks were created that are available with this code. While the authors of the DNC tested it on much more complex problems, I found these tasks to be useful for quickly testing the implementation, and getting a feel for the DNC.

The toy tasks are:

1. **Repeat**: Simply echo a row-wise sequence of randomly generated 0's and 1's.
2. **Addition**: Sum a randomly generated row-wise sequence, where each row contains a single 1 and the column (position) of that 1 is the number the row represents. Example below.
3. **Priority sort**: Harder task. Essentially involves repetition as well. A row-wise sequence of randomly generated 0's and 1's is to be sorted according to a priority assigned to each row, with the priority sampled uniformly from [-1,1]. This task was used in the Neural Turing Machines (DNC's predecessor) [paper](https://arxiv.org/abs/1410.5401) (pg. 19).

#### Example results
> **An example test sequence on the addition task run on a trained DNC.
> The very last row of the input, with the 1 in the last column, is the delimiter marking the end of the sequence:**
```
Input Data:
[[0. 0. 1. 0. 0. 0.]
 [1. 0. 0. 0. 0. 0.]
 [0. 0. 1. 0. 0. 0.]
 [1. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 1. 0.]
 [1. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 1. 0.]
 [0. 0. 0. 0. 0. 1.]]
(Represents: 2 + 0 + 2 + 0 + 4 + 0 + 4)
Target Data:
[[ 0.]
 [ 0.]
 [ 0.]
 [ 0.]
 [ 0.]
 [ 0.]
 [ 0.]
 [12.]]
DNC Output:
[[-0.14442]
 [-0.30143]
 [-0.24193]
 [-0.16146]
 [0.08639]
 [0.01805]
 [0.03887]
 [12.01316]]
```

**Note:** The DNC may not return exact results (perfect 0's and 1's in its output), but if you look closely, you'll see that the results follow the right pattern. For example, for the repeat and priority sort tasks, given enough training (especially for priority sort), the DNC returns values close to 1 wherever there are supposed to be 1's, and values close to 0 wherever there are supposed to be 0's.

#### Generalization

I observed that simple LSTMs learn these toy tasks faster: they were able to solve the task in fewer iterations of training.
However, I observed that, unlike the DNC, they fail to generalize when they see longer sequences than they were trained on.

I tested this by training both a simple LSTM model and a DNC (same number of hidden units in both) for 50,000 iterations on the same task. A maximum input sequence length of 12 was used during training. Then, I fed each model sequences longer than 12. The results of the simple LSTMs started diverging rapidly with sequence lengths greater than 12, while the DNC was able to get excellent accuracy up to sequence lengths of 100! This suggests that the DNC learns algorithmic tasks that can make use of its memory, and presents a solid step towards generalization.

> **Error over increasing sequence lengths on the addition task with simple LSTMs versus DNCs trained for 50k iterations.**
> *Both were trained using a random sequence length between 2 and 12, and had never seen longer sequence lengths during training. As you can see, simple LSTMs quickly diverge on sequences longer than 12, while the DNC generalizes robustly to longer sequences. The error for LSTMs becomes so large that matplotlib refuses to plot it beyond a certain threshold alongside the DNC's error!*
> ![](asset/error_increasing_seq_len.png)


## Usage
Set up your environment:

`conda install chainer numpy cupy`

To try it on toy tasks, run from the command line using:
`python dnc.py [arguments]`
For example:
`python dnc.py --task addition`

**Arguments:**

| Argument|Default|Description|
|---|---|---|
| --hps | 256,64,4,256,8 | N,W,R,LSTM hidden units,K. N=number of memory locations, W=width of each memory location, R=number of read heads, K=number of entries to maintain in the sparse link matrix (K=0 disables the sparse link matrix). Check the [DNC paper](https://www.nature.com/articles/nature20101) for more details. |
| --gpu | -1 | Device ordinal for GPU to be used (usually 0 if you have 1 GPU). -1 means use CPU. |
| --gpu_for_nn_only | False | If using a GPU, use GPU/CuPy for the LSTM/Linear layers, and CPU/NumPy for the DNC core. |
| --lstm_only | False | Use this argument to turn off DNC and use simple LSTMs for training/testing. Good for comparison of results on toy tasks. |
| --task | addition | Which task to test: **addition**, **repeat**, **priority_sort**. |
| --max_seq_len | 12 | Sequences generated per iteration of training are randomly picked to be up to this length. |
| --max_seq_wid | 6 | Number of elements in each sequence row per iteration of training. This remains fixed. |
| --test_seq_len | 0 | Test longer sequence length than trained on to check generalization. 0 = off. Will only work if a task has been trained and saved (automatically saved during training periodically). |
| --iterations | 100000 | This many randomly generated sequences will be used for training for the selected task. |

Throughout training, a model will be saved every 5,000 iterations in a sub-folder named *result*.

#### Using The DNC In Place Of LSTMs In Your Own Projects
The DNC in this code should seamlessly plug into your Chainer projects. Simply import it, and treat it like a [Link](https://docs.chainer.org/en/stable/reference/links.html). Use it in place of an LSTM, and see what happens!

**Note:** You'll need to use a batch size of 1 when feeding training data.
```python
import chainer
from dnc import DNC

class MyModel(chainer.Chain):
    def __init__(self, input_dim=35, hidden_dim=256, output_dim=32):
        super(MyModel, self).__init__(
            # X, Y, N, W, R, lstm_hidden_dim, K
            lstm_layer = DNC(input_dim, output_dim, 256, 64, 4, hidden_dim, 8)
        )
    #...

model = MyModel(...)
model.to_gpu(...)
```

> ![](asset/artists_impression.jpg)
> An artist's impression of the Differentiable Neural Computer.

## License
This code is available under the [MIT](https://opensource.org/licenses/MIT) license.


--------------------------------------------------------------------------------
/dnc.py:
--------------------------------------------------------------------------------
import os
import re
import argparse
import time

import numpy as np
from chainer import functions as F
from chainer import links as L
from chainer import \
    Chain, Variable, optimizers, serializers
try:
    import cupy as cp
except Exception:
    cp = None  # CuPy is optional; it is only needed when running on a GPU

xp = np


def overlap(u, v): # u, v: (1 * -) -> (1 * 1)
    if v.shape[0] != 1:
        u_tup = ()
        for i in range(v.shape[0]):
            u_tup = u_tup + (u,)
        u_repeated = xp.vstack(u_tup)
        v_repeated = xp.repeat(v, u.shape[0], 0)
        denominator = xp.sqrt(xp.sum(u_repeated * u_repeated, 1) * xp.sum(v_repeated * v_repeated, 1))
        denominator = denominator.reshape(v.shape[0],u.shape[0]).T
    else:
        denominator = xp.sqrt(xp.sum(u * u, 1) * xp.sum(v * v, 1)).reshape(-1,1)
    denominator[denominator == 0.] = 1.
    return xp.dot(u, v.T) / denominator


def C(M, k, beta):
    ret_list = overlap(M, k) * beta
    ret_list = ret_list.T
    if ret_list.shape[0] != 1:
        softmax = xp.exp(ret_list - xp.max(ret_list,1).reshape(-1,1))
        softmax = softmax / softmax.sum(1).reshape(-1,1)
    else:
        softmax = xp.exp(ret_list - xp.max(ret_list))
        softmax = softmax / softmax.sum()
    ret_list = softmax.T
    return ret_list


def u2a(u): # u, a: (N * 1)
    N = len(u)
    phi = xp.argsort(xp.reshape(u,N)) # u[phi]: ascending
    cumprod = xp.cumprod(u[phi])
    cumprod[-1] = 1.
    cumprod = xp.roll(cumprod, 1)
    cumprod = cumprod.reshape(-1,1)
    a_list = xp.zeros((N,1)).astype(xp.float32)
    a_list[phi] = (cumprod * (1.0 - u[phi]))
    return a_list


class DeepLSTM(Chain):
    def __init__(self, lstm_hidden_dim, d_out, gpu_for_nn_only=False):
        self.gpu_for_nn_only = gpu_for_nn_only
        super(DeepLSTM, self).__init__(
            l1=L.LSTM(None, lstm_hidden_dim),
            l2=L.Linear(lstm_hidden_dim, d_out)
        )

    def __call__(self, x):
        if self._device_id is not None and self.gpu_for_nn_only:
            x = F.copy(x, self._device_id)
        y = self.l2(self.l1(x))
        if self._device_id is not None and self.gpu_for_nn_only:
            y = F.copy(y, -1)
        return y

    def reset_state(self):
        self.l1.reset_state()

    def get_h(self):
        return self.l1.h

    def get_c(self):
        return self.l1.c


class Linear(Chain):
    def __init__(self, d_out, gpu_for_nn_only=False):
        self.gpu_for_nn_only = gpu_for_nn_only
        super(Linear, self).__init__(
            l=L.Linear(None, d_out)
        )

    def __call__(self, x):
        if self._device_id is not None and self.gpu_for_nn_only:
            x = F.copy(x, self._device_id)
        y = self.l(x)
        if self._device_id is not None and self.gpu_for_nn_only:
            y = F.copy(y, -1)
        return y


class DNC(Chain):
    def __init__(self, X, Y, N, W, R, lstm_hidden_dim, K=8, gpu_for_nn_only=False):
        self.X = X # input dimension
        self.Y = Y # output dimension
        self.N = N # number of memory slot
        self.W = W # dimension of one memory slot
        self.R = R # number of read heads
        self.K = K # Described under **Methods** of DNC paper, in the *Sparse link matrix* section

        self.xi_split_indices = xp.cumsum(xp.array([self.W * self.R, self.R, self.W, 1, self.W, self.W, self.R, 1, 1])).tolist()

        self.controller = DeepLSTM(lstm_hidden_dim, Y + W * R + 3 * W + 5 * R + 3, gpu_for_nn_only)
        self.linear = Linear(self.Y, gpu_for_nn_only)

        super(DNC, self).__init__(
            l_dl=self.controller,
            l_Wr=self.linear
        )

    def __call__(self, x):
        self.chi = F.concat((x, self.r))
        (self.nu, self.xi) = F.split_axis(self.l_dl(self.chi), [self.Y], 1)

        (self.kr, self.betar, self.kw, self.betaw, self.e, self.v, self.f, self.ga, self.gw, self.pi) = \
            F.split_axis(self.xi, self.xi_split_indices, 1)

        self.kr = F.reshape(self.kr, (self.R, self.W)) # R * W
        self.betar = 1 + F.softplus(self.betar) # 1 * R
        # self.kw: 1 * W
        self.betaw = 1 + F.softplus(self.betaw) # 1 * 1
        self.e = F.sigmoid(self.e) # 1 * W
        # self.v : 1 * W
        self.f = F.sigmoid(self.f) # 1 * R
        self.ga = F.sigmoid(self.ga) # 1 * 1
        self.gw = F.sigmoid(self.gw) # 1 * 1
        self.pi = F.softmax(F.reshape(self.pi, (self.R, 3))) # R * 3 (softmax for 3)

        # self.wr : N * R
        self.psi_mat = 1 - F.broadcast_to(self.f,(self.N,self.R)) * self.wr # N x R
        self.psi = F.prod(self.psi_mat, 1).reshape(self.N, 1) # N x 1

        # self.ww, self.u : N * 1
        self.u = (self.u + self.ww - (self.u * self.ww)) * self.psi

        self.a = u2a(self.u.data) # N * 1
        self.cw = C(self.M.data, self.kw.data, self.betaw.data) # N * 1
        self.ww = F.matmul(F.matmul(self.a, self.ga) + F.matmul(self.cw, 1.0 - self.ga), self.gw) # N * 1
        self.M = self.M * (xp.ones((self.N, self.W)).astype(xp.float32) - F.matmul(self.ww, self.e)) + F.matmul(self.ww,
                                                                                                                 self.v) # N * W
        if self.K > 0:
            self.p = (1.0 - F.matmul(Variable(xp.ones((self.N, 1)).astype(xp.float32)), F.reshape(F.sum(self.ww), (1, 1)))) \
                     * self.p + self.ww # N * 1
            self.p.data = xp.sort(self.p.data,0)
            self.p.data[0:-self.K] = 0.
            self.p.data[-self.K:] = self.p.data[-self.K:]/xp.sum(self.p.data[-self.K:])
            self.ww.data = xp.sort(self.ww.data,0)
            self.ww.data[0:-self.K] = 0.
            self.ww.data[-self.K:] = self.ww[-self.K:].data/xp.sum(self.ww.data[-self.K:])
            self.wwrep = F.matmul(self.ww, Variable(xp.ones((1, self.N)).astype(xp.float32))) # N * N
            self.ww_p_product = xp.zeros((self.N,self.N)).astype(xp.float32)
            self.ww_p_product[-self.K:,-self.K:] = F.matmul(self.ww[-self.K:,-self.K:], F.transpose(self.p[-self.K:,-self.K:])).data
            self.L = (1.0 - self.wwrep - F.transpose(self.wwrep)) * self.L + self.ww_p_product # N * N
            self.L = self.L * (xp.ones((self.N, self.N)) - xp.eye(self.N)) # force L[i,i] == 0
            self.L.data[self.L.data < 1/self.K] = 0.
        else:
            self.p = (1.0 - F.matmul(Variable(xp.ones((self.N, 1)).astype(xp.float32)),
                                     F.reshape(F.sum(self.ww), (1, 1)))) \
                     * self.p + self.ww # N * 1
            self.wwrep = F.matmul(self.ww, Variable(xp.ones((1, self.N)).astype(xp.float32))) # N * N
            self.L = (1.0 - self.wwrep - F.transpose(self.wwrep)) * self.L + F.matmul(self.ww,
                                                                                      F.transpose(self.p)) # N * N
            self.L = self.L * (xp.ones((self.N, self.N)) - xp.eye(self.N)) # force L[i,i] == 0
        self.fo = F.matmul(self.L, self.wr) # N * R
        self.ba = F.matmul(F.transpose(self.L), self.wr) # N * R

        self.cr = C(self.M.data, self.kr.data, self.betar.data)

        self.bacrfo = F.concat((F.reshape(F.transpose(self.ba), (self.R, self.N, 1)),
                                F.reshape(F.transpose(self.cr), (self.R, self.N, 1)),
                                F.reshape(F.transpose(self.fo), (self.R, self.N, 1)),), 2) # R * N * 3
        self.pi = F.reshape(self.pi, (self.R, 3, 1)) # R * 3 * 1
        self.wr = F.transpose(F.reshape(F.batch_matmul(self.bacrfo, self.pi), (self.R, self.N))) # N * R

        self.r = F.reshape(F.matmul(F.transpose(self.M), self.wr), (1, self.R * self.W)) # W * R (-> 1 * RW)

        self.y = self.l_Wr(self.r) + self.nu # 1 * Y
        return self.y

    def reset_state(self):
        self.l_dl.reset_state()
        self.u = Variable(xp.zeros((self.N, 1)).astype(xp.float32))
        self.p = Variable(xp.zeros((self.N, 1)).astype(xp.float32))
        self.L = Variable(xp.zeros((self.N, self.N)).astype(xp.float32))
        self.M = Variable(xp.zeros((self.N, self.W)).astype(xp.float32))
        self.r = Variable(xp.zeros((1, self.R * self.W)).astype(xp.float32))
        self.wr = Variable(xp.zeros((self.N, self.R)).astype(xp.float32))
        self.ww = Variable(xp.zeros((self.N, 1)).astype(xp.float32))

    def to_gpu(self, device=None):
        global xp
        xp = cp
        if device is not None:
            xp.cuda.Device(device).use()
        self.l_dl.to_gpu(device)
        self.l_Wr.to_gpu(device)

    def to_cpu(self):
        global xp
        xp = np
        self.l_dl.to_cpu()
        self.l_Wr.to_cpu()

    def get_h(self):
        return self.l_dl.get_h()

    def get_c(self):
        return self.l_dl.get_c()


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Optimized Chainer DNC')
    parser.add_argument('--hps', default="256,64,4,256,8",
                        help='DNC hyperparams: N,W,R,H,K. K=0 disables the sparse link matrix. H is hidden dim for LSTM')
    parser.add_argument('--gpu', '-g', default=-1, type=int, help='GPU ID (negative value indicates CPU)')
    parser.add_argument('--gpu_for_nn_only', action='store_true',
                        help='If using GPU, use GPU/CuPy for LSTM/Linear layers, and CPU/NumPy for DNC core')
    parser.add_argument('--lstm_only', action='store_true', help='Use vanilla LSTMs instead of DNC')
    parser.add_argument('--task', default="addition", help="Which task to test: addition, repeat, priority_sort")
    parser.add_argument('--max_seq_len', default=12, type=int,
                        help="Max sequence length to train or test on (lengths are picked randomly up to this value)")
    parser.add_argument('--max_seq_wid', default=6, type=int, help="Width of each sequence. Always the same.")
    parser.add_argument('--test_seq_len', default=0, type=int,
                        help="Test longer sequence length than trained on to check generalization. 0 = off")
    parser.add_argument('--iterations', default=100000, type=int,
                        help="How many iterations of training sequences fed to network")
    args = parser.parse_args()
    print("args = " + str(vars(args))+"\n")

    N, W, R, H, K = args.hps.split(",")
    N, W, R, H, K = int(N), int(W), int(R), int(H), int(K)

    X = args.max_seq_wid
    if args.task == "addition":
        Y = 1
    elif args.task == "repeat":
        Y = args.max_seq_wid
    elif args.task == "priority_sort":
        Y = args.max_seq_wid-2
    else:
        print("Unknown task: "+args.task)
        exit()

    if args.lstm_only:
        model = DeepLSTM(H, Y, True)
        dnc_or_lstm = "lstm"
    else:
        model = DNC(X, Y, N, W, R, H, K, args.gpu_for_nn_only)
        dnc_or_lstm = "dnc"

    if not os.path.exists("result"):
        os.makedirs("result")

    # Look for the most recent checkpoint of this model/task to resume from
    max_iter = 0
    auto_resume_file = None
    files = os.listdir("result")
    for file in files:
        pattern = re.compile("^"+dnc_or_lstm+"_"+args.task+"_iter_")
        if pattern.match(file):
            iter = int(re.search(r'\d+', file).group())
            if (iter > max_iter):
                max_iter = iter
                auto_resume_file = os.path.join("result", file)
    if auto_resume_file is not None:
        print("Resuming from saved model: "+auto_resume_file+"\n")
        serializers.load_npz(auto_resume_file, model)

    if args.test_seq_len > 0:
        if not auto_resume_file:
            print("No saved model found to resume from for testing.")
            exit()
        args.test = True
        if max_iter == args.iterations:
            max_iter -= 1
    else:
        args.test = False

    if args.gpu >= 0:
        model.to_gpu(args.gpu)

    opt = optimizers.Adam(alpha=0.0001)
    opt.setup(model)
    start = time.time()
    for i in range(max_iter+1, args.iterations+1):
        model.reset_state()
        loss = 0
        outputs = []

        if args.task == "addition":
            def onehot(x, n):
                ret = xp.zeros(n).astype(xp.float32)
                ret[x] = 1.0
                return ret
            def generate_data():
                if args.test:
                    length = args.test_seq_len
                else:
                    length = int(xp.random.randint(2, (args.max_seq_len) + 1))
                content = xp.random.randint(0, args.max_seq_wid - 1, length)
                seq_length = length + 1
                input_data = xp.zeros((seq_length, args.max_seq_wid)).astype(xp.float32)
                target_data = 0.0
                sums_text = ""
                for i in range(seq_length):
                    if i < length:
                        input_data[i] = onehot(content[i], args.max_seq_wid)
                        target_data += content[i]
                        sums_text += str(content[i]) + " + "
                    else:
                        # last row: delimiter marking the end of the sequence
                        input_data[i] = onehot(args.max_seq_wid - 1, args.max_seq_wid)
                input_data = input_data.reshape((1,) + input_data.shape)
                target_data = xp.array(target_data).astype(xp.float32)
                target_data = target_data.reshape(1, 1, 1)
                return input_data, target_data, sums_text
            input_data, target_data, sums_text = generate_data()
            input_data = input_data[0]
            target_data = target_data[0]
            target_data = xp.vstack((xp.zeros(input_data.shape[0]-1).astype(xp.float32).reshape(-1, 1), target_data))

        elif args.task == "repeat":
            def generate_data():
                if args.test:
                    length = args.test_seq_len//2+1
                else:
                    length = int(xp.random.randint(1, args.max_seq_len//2+1))
                input_data = xp.zeros((2 * length + 1, args.max_seq_wid), dtype=xp.float32)
                # xp.int was an alias of the builtin int and has been removed from recent NumPy releases
                target_data = xp.zeros((2 * length + 1, args.max_seq_wid), dtype=xp.int64)
                sequence = xp.random.randint(0, 2, (length, args.max_seq_wid - 1))
                input_data[:length, :args.max_seq_wid - 1] = sequence
                input_data[length, -1] = 1
                target_data[length + 1:, :args.max_seq_wid - 1] = sequence
                return input_data, target_data
            input_data, target_data = generate_data()

        elif args.task == "priority_sort":
            def generate_data():
                if args.test:
                    length = args.test_seq_len//2+1
                else:
                    length = int(xp.random.randint(2, args.max_seq_len//2+1))
                input_data = xp.random.randint(0,2,(length, args.max_seq_wid)).astype(xp.float32)
                input_data[:,0] *= 0.
                input_data[-1] *= 0.
                input_data[:,-1] *= 0.
                input_data[-1,-1] = 1.
                priority_sort_index = xp.random.uniform(-1,1,(length-1,1))
                input_data[0:-1, 0:1] = priority_sort_index
                internal_sort_index = xp.argsort(priority_sort_index.reshape(-1))
                target_data = input_data[internal_sort_index]
                input_data = xp.vstack((input_data, xp.zeros((length, args.max_seq_wid)).astype(xp.float32)))
                target_data = xp.concatenate((xp.zeros((length+1, args.max_seq_wid)).astype(xp.float32),target_data))
                target_data = target_data[:,1:-1]
                return input_data, target_data
            input_data, target_data = generate_data()

        # Feed the sequence one row at a time (batch size of 1)
        for j in range(input_data.shape[0]):
            output_data = model(F.expand_dims(input_data[j],0))
            outputs.append(output_data[0])
            #loss += F.sigmoid_cross_entropy(output_data[0], target_data[j], reduce="no")
            loss += (output_data[0] - target_data[j]) ** 2
        loss = F.mean(loss)

        if not args.test:
            model.cleargrads()
            loss.backward()
            opt.update()
            loss.unchain_backward()

        if not args.test and i == max_iter+1:
            print("\nTime \t\t Iter \t\t Loss")
            print("------------------------------------------")
        if not args.test and i % 10 == 0:
            print("{:.2f}s".format(time.time()-start), "\t\t", i, "\t\t", loss.data)
        if args.test or i % 500 == 0:
            print("---Sample Training Output---")
            print("Input Data:")
            print(input_data)
            if args.task == "addition":
                print(sums_text)
            print("Target Data:")
            print(target_data)
            if args.task == "addition":
                for row in outputs:
                    print("Output: {:.5f}".format(float(row.data)))
            else:
                for row in outputs:
                    print("Output: ", end="")
                    for col in row.data:
                        print("{:.5f} ".format(float(col)), end="")
                    print()
            print("----------------------------")
        if not args.test and i % 5000 == 0:
            filename = "result/"+dnc_or_lstm +"_" + args.task +"_iter_" + str(i) +".model"
            print("Saving model to: "+filename)
            serializers.save_npz(filename, model)
        if args.test:
            break


--------------------------------------------------------------------------------