├── README.md
└── main.py


/README.md:
--------------------------------------------------------------------------------
# Implementation of Differentiable Neural Computers (DNC) in Chainer

Differentiable Neural Computers (DNC) is a neural network architecture proposed by DeepMind in [their third Nature paper](http://www.nature.com/articles/nature20101.epdf?author_access_token=ImTXBI8aWbYxYQ51Plys8NRgN0jAjWel9jnR3ZoTv0MggmpDmwljGswxVdeocYSurJ3hxupzWuRNeGvvXnoO8o4jTJcnAyhGuZzXJ1GEaD-Z7E6X_a9R-xqJ9TfJWBqz).
I have implemented DNC in [Chainer](http://chainer.org/), a flexible neural network framework developed by [Preferred Networks](https://www.preferred-networks.jp/en/).

# What is DNC?
DNC is a recently proposed neural network architecture. In their paper, DNC learns several complex tasks well, including finding the shortest path in a graph and solving a block puzzle. It is expected to have the capacity to solve complex, structured tasks that are inaccessible to previous neural networks.

DNC consists of an RNN (recurrent neural network) and a "memory matrix", together with several heads for reading from and writing to it. The RNN controls the heads: by emitting head-control commands, it can read out the contents of the memory and write data into it.

At each timestep, a vanilla RNN receives some external input, yields some output, and refreshes its internal state. In contrast, the RNN in a DNC also receives the data read by the read heads at the previous timestep together with the external input, and yields a memory-manipulation command in addition to the output data. According to this command, the heads are moved, the memory content at the write head is edited, and the memory contents at the read heads are fetched. The fetched data is fed back into the RNN at the next timestep (together with the next external input).

The RNN in a DNC learns to achieve the desired input-output relationship in a setting where "the memory" (a convenient computational tool) is given to it to use freely. How the RNN utilizes this tool is up to its learning.

Although a "read-write memory" may seem very special, it can be regarded as a form of internal state of an RNN(*); in other words, DNC is an RNN with non-trivial internal-state dynamics, like LSTM, but the dynamics are much more complicated. This memory enables the RNN to perform complicated information processing. Moreover, since a read-write memory is convenient for almost every kind of information processing, I expect DNC to be highly versatile, performing various types of tasks reasonably well.

Note that the [NTM (Neural Turing Machine)](https://arxiv.org/pdf/1410.5401v2.pdf), which they proposed in 2014, has a structure similar to DNC. The difference between DNC and NTM is that DNC uses a more refined mechanism for moving its memory heads. (For details, see the Methods section of the DNC paper.)

(*): They call the memory "external"; they say this is because "The behaviour of the network is independent of the memory size as long as the memory is not filled to capacity".


# About my code

In my code, a very small-scale DNC learns a very easy "repeat after me" task. It seems to learn the task without errors, but that does not necessarily mean this program implements DNC correctly. If you have any comments about my code, please feel free to contact @yos1up (twitter).

The Supplementary Material of their paper is very useful for implementing DNC: in just two pages it lists all the variables used in the model and all the equations needed to construct its computational graph. Most of the variable names in my code coincide with those in the paper.
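
For a flavour of those equations, here is a simplified NumPy sketch (not the Chainer implementation in `main.py`) of two of them: content-based addressing, and the erase/write update of the memory followed by a weighted read. It assumes a single write head and no batch dimension, adds a small epsilon for numerical safety, and the function names are mine, not the paper's.

```python
import numpy as np

def content_addressing(M, k, beta):
    # Cosine similarity between the key k and every row of the memory M,
    # sharpened by the strength beta and normalised with a softmax.
    # M: (N, W) memory, k: (W,) key, beta: scalar >= 1  ->  (N,) weighting
    eps = 1e-6
    sim = M @ k / (np.linalg.norm(M, axis=1) * np.linalg.norm(k) + eps)
    z = np.exp(beta * sim - np.max(beta * sim))
    return z / z.sum()

def write_then_read(M, ww, e, v, wr):
    # Erase/write update of the memory, followed by a weighted read.
    # ww, wr: (N,) write / read weightings; e, v: (W,) erase / write vectors
    M = M * (1.0 - np.outer(ww, e)) + np.outer(ww, v)
    r = M.T @ wr  # (W,) read vector, fed back to the controller at the next step
    return M, r
```

In `main.py`, the corresponding pieces are the function `C` and the updates of `self.M` and `self.r` inside `DNC.__call__`.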

--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
import numpy as np
import math
import chainer
from chainer import functions as F
from chainer import links as L
from chainer import \
    cuda, gradient_check, optimizers, serializers, utils, \
    Chain, ChainList, Function, Link, Variable


def onehot(x, n):
    ret = np.zeros(n).astype(np.float32)
    ret[x] = 1.0
    return ret


def overlap(u, v):  # u, v: (1 * -) Variable -> (1 * 1) Variable
    denominator = F.sqrt(F.batch_l2_norm_squared(u) * F.batch_l2_norm_squared(v))
    if np.array_equal(denominator.data, np.array([0])):
        return F.matmul(u, F.transpose(v))
    return F.matmul(u, F.transpose(v)) / F.reshape(denominator, (1, 1))


def C(M, k, beta):
    # (N * W), (1 * W), (1 * 1) -> (N * 1)
    # (not (N * W), ({R,1} * W), (1 * {R,1}) -> (N * {R,1}))
    W = M.data.shape[1]
    ret_list = [0] * M.data.shape[0]
    for i in range(M.data.shape[0]):
        ret_list[i] = overlap(F.reshape(M[i, :], (1, W)), k) * beta  # pick i-th row
    return F.transpose(F.softmax(F.transpose(F.concat(ret_list, 0))))  # concat vertically and calc softmax in each column


def u2a(u):  # u, a: (N * 1) Variable
    N = len(u.data)
    phi = np.argsort(u.data.reshape(N))  # u.data[phi]: ascending
    a_list = [0] * N
    cumprod = Variable(np.array([[1.0]]).astype(np.float32))
    for i in range(N):
        a_list[phi[i]] = cumprod * (1.0 - F.reshape(u[phi[i], 0], (1, 1)))
        cumprod *= F.reshape(u[phi[i], 0], (1, 1))
    return F.concat(a_list, 0)  # concat vertically


class DeepLSTM(Chain):  # too simple?
    def __init__(self, d_in, d_out):
        super(DeepLSTM, self).__init__(
            l1=L.LSTM(d_in, d_out),
            l2=L.Linear(d_out, d_out),)

    def __call__(self, x):
        self.x = x
        self.y = self.l2(self.l1(self.x))
        return self.y

    def reset_state(self):
        self.l1.reset_state()


class DNC(Chain):
    def __init__(self, X, Y, N, W, R):
        self.X = X  # input dimension
        self.Y = Y  # output dimension
        self.N = N  # number of memory slots
        self.W = W  # dimension of one memory slot
        self.R = R  # number of read heads
        self.controller = DeepLSTM(W*R + X, Y + W*R + 3*W + 5*R + 3)

        super(DNC, self).__init__(
            l_dl=self.controller,
            l_Wr=L.Linear(self.R * self.W, self.Y)  # nobias=True ?
            )
        self.reset_state()
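
    # One timestep of the DNC. The controller receives the external input x
    # concatenated with the read vectors r from the previous timestep and emits
    # the output contribution nu plus the interface vector xi. xi is split into
    # read keys/strengths (kr, betar), write key/strength (kw, betaw), erase and
    # write vectors (e, v), free gates (f), allocation gate (ga), write gate (gw)
    # and read modes (pi); these drive the usage (u), allocation (a), write
    # weighting (ww), memory (M), temporal link matrix (L) and read weighting
    # (wr) updates below, following the Supplementary Material of the paper.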
    def __call__(self, x):
        # is a batch size of 1 OK for this RNN? if not, I will implement the
        # calculations without a batch dimension.
        self.chi = F.concat((x, self.r))
        (self.nu, self.xi) = \
            F.split_axis(self.l_dl(self.chi), [self.Y], 1)
        (self.kr, self.betar, self.kw, self.betaw,
         self.e, self.v, self.f, self.ga, self.gw, self.pi
         ) = F.split_axis(self.xi, np.cumsum(
            [self.W*self.R, self.R, self.W, 1, self.W, self.W, self.R, 1, 1]), 1)

        self.kr = F.reshape(self.kr, (self.R, self.W))  # R * W
        self.betar = 1 + F.softplus(self.betar)  # 1 * R
        # self.kw : 1 * W
        self.betaw = 1 + F.softplus(self.betaw)  # 1 * 1
        self.e = F.sigmoid(self.e)  # 1 * W
        # self.v : 1 * W
        self.f = F.sigmoid(self.f)  # 1 * R
        self.ga = F.sigmoid(self.ga)  # 1 * 1
        self.gw = F.sigmoid(self.gw)  # 1 * 1
        self.pi = F.softmax(F.reshape(self.pi, (self.R, 3)))  # R * 3 (softmax over the 3 read modes)

        # self.wr : N * R
        self.psi_mat = 1 - F.matmul(Variable(np.ones((self.N, 1)).astype(np.float32)), self.f) * self.wr  # N * R
        self.psi = Variable(np.ones((self.N, 1)).astype(np.float32))  # N * 1
        for i in range(self.R):
            self.psi = self.psi * F.reshape(self.psi_mat[:, i], (self.N, 1))  # N * 1

        # self.ww, self.u : N * 1
        self.u = (self.u + self.ww - (self.u * self.ww)) * self.psi

        self.a = u2a(self.u)  # N * 1
        self.cw = C(self.M, self.kw, self.betaw)  # N * 1
        self.ww = F.matmul(F.matmul(self.a, self.ga) + F.matmul(self.cw, 1.0 - self.ga), self.gw)  # N * 1
        self.M = self.M * (np.ones((self.N, self.W)).astype(np.float32) - F.matmul(self.ww, self.e)) + F.matmul(self.ww, self.v)  # N * W

        self.p = (1.0 - F.matmul(Variable(np.ones((self.N, 1)).astype(np.float32)), F.reshape(F.sum(self.ww), (1, 1)))) \
            * self.p + self.ww  # N * 1
        self.wwrep = F.matmul(self.ww, Variable(np.ones((1, self.N)).astype(np.float32)))  # N * N
        self.L = (1.0 - self.wwrep - F.transpose(self.wwrep)) * self.L + F.matmul(self.ww, F.transpose(self.p))  # N * N
        self.L = self.L * (np.ones((self.N, self.N)) - np.eye(self.N)).astype(np.float32)  # force L[i,i] == 0

        self.fo = F.matmul(self.L, self.wr)  # N * R
        self.ba = F.matmul(F.transpose(self.L), self.wr)  # N * R

        self.cr_list = [0] * self.R
        for i in range(self.R):
            self.cr_list[i] = C(self.M, F.reshape(self.kr[i, :], (1, self.W)),
                                F.reshape(self.betar[0, i], (1, 1)))  # N * 1
        self.cr = F.concat(self.cr_list)  # N * R

        self.bacrfo = F.concat((F.reshape(F.transpose(self.ba), (self.R, self.N, 1)),
                                F.reshape(F.transpose(self.cr), (self.R, self.N, 1)),
                                F.reshape(F.transpose(self.fo), (self.R, self.N, 1)),), 2)  # R * N * 3
        self.pi = F.reshape(self.pi, (self.R, 3, 1))  # R * 3 * 1
        self.wr = F.transpose(F.reshape(F.batch_matmul(self.bacrfo, self.pi), (self.R, self.N)))  # N * R

        self.r = F.reshape(F.matmul(F.transpose(self.M), self.wr), (1, self.R * self.W))  # W * R (-> 1 * RW)

        self.y = self.l_Wr(self.r) + self.nu  # 1 * Y
        return self.y

    def reset_state(self):
        self.l_dl.reset_state()
        self.u = Variable(np.zeros((self.N, 1)).astype(np.float32))
        self.p = Variable(np.zeros((self.N, 1)).astype(np.float32))
        self.L = Variable(np.zeros((self.N, self.N)).astype(np.float32))
        self.M = Variable(np.zeros((self.N, self.W)).astype(np.float32))
        self.r = Variable(np.zeros((1, self.R*self.W)).astype(np.float32))
        self.wr = Variable(np.zeros((self.N, self.R)).astype(np.float32))
        self.ww = Variable(np.zeros((self.N, 1)).astype(np.float32))
        # any other state variables?
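

# Training script: a small "repeat after me" (copy) task. Each input sequence is
# `contentlen` one-hot symbols drawn from {0, ..., X-2}, followed by a delimiter
# symbol (X-1) and then zero vectors; the target is the same symbols echoed back,
# one per step, starting at the delimiter step. The squared error over the echo
# phase is accumulated and backpropagated once per sequence.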
X = 5
Y = 5
N = 10
W = 10
R = 2
mdl = DNC(X, Y, N, W, R)
opt = optimizers.Adam()
opt.setup(mdl)
datanum = 100000
loss = 0.0
acc = 0.0
for datacnt in range(datanum):
    lossfrac = np.zeros((1, 2))
    # x_seq = np.random.rand(X,seqlen).astype(np.float32)
    # t_seq = np.random.rand(Y,seqlen).astype(np.float32)
    # t_seq = np.copy(x_seq)

    contentlen = np.random.randint(3, 6)
    content = np.random.randint(0, X-1, contentlen)
    seqlen = contentlen + contentlen
    x_seq_list = [float('nan')] * seqlen
    t_seq_list = [float('nan')] * seqlen
    for i in range(seqlen):
        if (i < contentlen):
            x_seq_list[i] = onehot(content[i], X)
        elif (i == contentlen):
            x_seq_list[i] = onehot(X-1, X)
        else:
            x_seq_list[i] = np.zeros(X).astype(np.float32)

        if (i >= contentlen):
            t_seq_list[i] = onehot(content[i-contentlen], X)

    mdl.reset_state()
    for cnt in range(seqlen):
        x = Variable(x_seq_list[cnt].reshape(1, X))
        if (isinstance(t_seq_list[cnt], np.ndarray)):
            t = Variable(t_seq_list[cnt].reshape(1, Y))
        else:
            t = []

        y = mdl(x)
        if (isinstance(t, chainer.Variable)):
            loss += (y - t)**2
            print(y.data, t.data, np.argmax(y.data) == np.argmax(t.data))
            if (np.argmax(y.data) == np.argmax(t.data)):
                acc += 1
        if (cnt+1 == seqlen):
            mdl.cleargrads()
            loss.grad = np.ones(loss.data.shape, dtype=np.float32)
            loss.backward()
            opt.update()
            loss.unchain_backward()
            print('(', datacnt, ')', loss.data.sum()/loss.data.size/contentlen, acc/contentlen)
            lossfrac += [loss.data.sum()/loss.data.size/seqlen, 1.]
            loss = 0.0
            acc = 0.0

--------------------------------------------------------------------------------