├── image
│   ├── example.png
│   ├── count_46900_test_0.png
│   ├── count_46900_test_1.png
│   ├── count_46900_test_2.png
│   ├── count_46900_test_3.png
│   ├── count_46900_test_4.png
│   ├── count_46900_test_5.png
│   ├── count_46900_test_6.png
│   ├── count_46900_test_7.png
│   ├── count_46900_test_8.png
│   └── count_46900_test_9.png
├── save
│   └── weights_final.tar
├── .idea
│   └── vcs.xml
├── config.py
├── generate.py
├── README.md
├── utility.py
├── train.py
└── draw_model.py
/image/example.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/czm0/draw_pytorch/HEAD/image/example.png
--------------------------------------------------------------------------------
/save/weights_final.tar:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/czm0/draw_pytorch/HEAD/save/weights_final.tar
--------------------------------------------------------------------------------
/image/count_46900_test_0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/czm0/draw_pytorch/HEAD/image/count_46900_test_0.png
--------------------------------------------------------------------------------
/image/count_46900_test_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/czm0/draw_pytorch/HEAD/image/count_46900_test_1.png
--------------------------------------------------------------------------------
/image/count_46900_test_2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/czm0/draw_pytorch/HEAD/image/count_46900_test_2.png
--------------------------------------------------------------------------------
/image/count_46900_test_3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/czm0/draw_pytorch/HEAD/image/count_46900_test_3.png
--------------------------------------------------------------------------------
/image/count_46900_test_4.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/czm0/draw_pytorch/HEAD/image/count_46900_test_4.png
--------------------------------------------------------------------------------
/image/count_46900_test_5.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/czm0/draw_pytorch/HEAD/image/count_46900_test_5.png
--------------------------------------------------------------------------------
/image/count_46900_test_6.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/czm0/draw_pytorch/HEAD/image/count_46900_test_6.png
--------------------------------------------------------------------------------
/image/count_46900_test_7.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/czm0/draw_pytorch/HEAD/image/count_46900_test_7.png
--------------------------------------------------------------------------------
/image/count_46900_test_8.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/czm0/draw_pytorch/HEAD/image/count_46900_test_8.png
--------------------------------------------------------------------------------
/image/count_46900_test_9.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/czm0/draw_pytorch/HEAD/image/count_46900_test_9.png
--------------------------------------------------------------------------------
/.idea/vcs.xml:
--------------------------------------------------------------------------------
1 | <?xml version="1.0" encoding="UTF-8"?>
2 | <project version="4">
3 |   <component name="VcsDirectoryMappings">
4 |     <mapping directory="$PROJECT_DIR$" vcs="Git" />
5 |   </component>
6 | </project>
--------------------------------------------------------------------------------
/config.py:
--------------------------------------------------------------------------------
1 | T = 10                 # number of DRAW time steps (glimpses)
2 | batch_size = 64        # minibatch size
3 | A = 28                 # image width
4 | B = 28                 # image height
5 | z_size = 10            # dimension of the latent code z
6 | N = 5                  # read/write attention grid size (N x N)
7 | dec_size = 256         # decoder LSTM hidden size
8 | enc_size = 256         # encoder LSTM hidden size
9 | epoch_num = 20         # number of training epochs
10 | learning_rate = 1e-3   # Adam learning rate
11 | beta1 = 0.5            # Adam beta1
12 | USE_CUDA = True        # move model and data to the GPU
13 | clip = 5.0             # gradient-norm clipping threshold
--------------------------------------------------------------------------------
/generate.py:
--------------------------------------------------------------------------------
1 | from draw_model import DrawModel
2 | from config import *
3 | from utility import save_image
4 | import torch.nn.utils
5 |
6 | torch.set_default_tensor_type('torch.FloatTensor')
7 |
8 | model = DrawModel(T,A,B,z_size,N,dec_size,enc_size)
9 |
10 | if USE_CUDA:
11 | model.cuda()
12 |
13 | state_dict = torch.load('save/weights_final.tar')
14 | model.load_state_dict(state_dict)
15 | def generate():
16 | x = model.generate(batch_size)
17 | save_image(x)
18 |
19 | if __name__ == '__main__':
20 | generate()
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # draw-pytorch
2 |
3 | PyTorch implementation of [DRAW: A Recurrent Neural Network For Image Generation](http://arxiv.org/pdf/1502.04623.pdf) on the MNIST generation task.
4 |
5 | | With Attention |
6 | | ------------- |
7 | | (generated MNIST samples) |
8 |
9 |
10 | ## Usage
11 |
12 | `python train.py` downloads the MNIST dataset to ./data/mnist and trains the DRAW model with attention for both reading and writing. After training, the weights are written to ./save/weights_final.tar and the generated images are written to ./image/*.png.
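The per-timestep PNGs can then be merged into an animation with ImageMagick, e.g. `convert -delay 10 -loop 0 image/*.png mnist.gif` (the same command is suggested in a comment in utility.py).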
13 |
14 | `python generate.py` loads weights from save/weights_final.tar and generates images under ./image/.
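If you want to sample on a machine without a GPU, a minimal sketch is the following (it assumes you set `USE_CUDA = False` in config.py so that the `Variable` helper in utility.py stops moving tensors to the GPU):

```python
import torch

from config import T, A, B, z_size, N, dec_size, enc_size, batch_size
from draw_model import DrawModel
from utility import save_image

# Build the model and load the checkpoint onto the CPU,
# even though it was saved from a CUDA run.
model = DrawModel(T, A, B, z_size, N, dec_size, enc_size)
state_dict = torch.load('save/weights_final.tar', map_location='cpu')
model.load_state_dict(state_dict)

# Sample a batch and write one grid image per time step to image/.
save_image(model.generate(batch_size))
```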
15 |
16 | The weights_final.tar file was trained for 50 epochs with minibatch size 64 on a GTX 1080 GPU.
17 |
18 | ## Reference
19 | https://github.com/ericjang/draw
--------------------------------------------------------------------------------
/utility.py:
--------------------------------------------------------------------------------
1 | import torch.autograd as autograd
2 | import torch
3 | from config import *
4 | import numpy as np
5 | import matplotlib.pyplot as plt
6 |
7 |
8 | def Variable(data, *args, **kwargs):
9 | if USE_CUDA:
10 | data = data.cuda()
11 | return autograd.Variable(data,*args, **kwargs)
12 |
13 | def unit_prefix(x, n=1):  # prepend n singleton dimensions to x
14 | for i in range(n): x = x.unsqueeze(0)
15 | return x
16 |
17 | def align(x, y, start_dim=0):  # broadcast x and y to a common shape by expanding singleton dims
18 | xd, yd = x.dim(), y.dim()
19 | if xd > yd: y = unit_prefix(y, xd - yd)
20 | elif yd > xd: x = unit_prefix(x, yd - xd)
21 |
22 | xs, ys = list(x.size()), list(y.size())
23 | nd = len(ys)
24 | for i in range(start_dim, nd):
25 | td = nd-i-1
26 | if ys[td]==1: ys[td] = xs[td]
27 | elif xs[td]==1: xs[td] = ys[td]
28 | return x.expand(*xs), y.expand(*ys)
29 |
30 | def matmul(X,Y):  # batched matrix multiply: torch.mm applied per sample along dim 0
31 | results = []
32 | for i in range(X.size(0)):
33 | result = torch.mm(X[i],Y[i])
34 | results.append(result.unsqueeze(0))
35 | return torch.cat(results)
36 |
37 |
38 |
39 | def xrecons_grid(X,B,A):
40 | """
41 | plots canvas for single time step
42 | X is x_recons, (batch_size x img_size)
43 | assumes features = BxA images
44 | batch is assumed to be a square number
45 | """
46 | padsize=1
47 | padval=.5
48 | ph=B+2*padsize
49 | pw=A+2*padsize
50 | batch_size=X.shape[0]
51 | N=int(np.sqrt(batch_size))
52 | X=X.reshape((N,N,B,A))
53 | img=np.ones((N*ph,N*pw))*padval
54 | for i in range(N):
55 | for j in range(N):
56 | startr=i*ph+padsize
57 | endr=startr+B
58 | startc=j*pw+padsize
59 | endc=startc+A
60 | img[startr:endr,startc:endc]=X[i,j,:,:]
61 | return img
62 |
63 | def save_image(x,count=0):
64 | for t in range(T):
65 | img = xrecons_grid(x[t],B,A)
66 | plt.matshow(img, cmap=plt.cm.gray)
67 | imgname = 'image/count_%d_%s_%d.png' % (count,'test', t) # you can merge using imagemagick, i.e. convert -delay 10 -loop 0 *.png mnist.gif
68 | plt.savefig(imgname)
69 | print(imgname)
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import torch.optim as optim
2 | from torchvision import datasets,transforms
3 | import torch.utils
4 | from draw_model import DrawModel
5 | from config import *
6 | from utility import Variable,save_image,xrecons_grid
7 | import torch.nn.utils
8 | import matplotlib.pyplot as plt
9 |
10 | torch.set_default_tensor_type('torch.FloatTensor')
11 |
12 | train_loader = torch.utils.data.DataLoader(
13 | datasets.MNIST('data/', train=True, download=True,
14 | transform=transforms.Compose([
15 | transforms.ToTensor()])),
16 | batch_size=batch_size, shuffle=False)
17 |
18 | model = DrawModel(T,A,B,z_size,N,dec_size,enc_size)
19 | optimizer = optim.Adam(model.parameters(),lr=learning_rate,betas=(beta1,0.999))
20 |
21 | if USE_CUDA:
22 | model.cuda()
23 |
24 | def train():
25 | avg_loss = 0
26 | count = 0
27 | for epoch in range(epoch_num):
28 | for data,_ in train_loader:
29 | bs = data.size()[0]
30 | data = Variable(data).view(bs, -1)
31 | optimizer.zero_grad()
32 | loss = model.loss(data)
33 | avg_loss += loss.cpu().data.numpy()
34 | loss.backward()
35 | torch.nn.utils.clip_grad_norm(model.parameters(), clip)
36 | optimizer.step()
37 | count += 1
38 | if count % 100 == 0:
39 |                 print('Epoch-{}; Count-{}; loss: {};'.format(epoch, count, avg_loss / 100))
40 | if count % 3000 == 0:
41 | torch.save(model.state_dict(),'save/weights_%d.tar'%(count))
42 | generate_image(count)
43 | avg_loss = 0
44 | torch.save(model.state_dict(), 'save/weights_final.tar')
45 | generate_image(count)
46 |
47 |
48 | def generate_image(count):
49 | x = model.generate(batch_size)
50 | save_image(x,count)
51 |
52 | def save_example_image():
53 | train_iter = iter(train_loader)
54 |     data, _ = next(train_iter)
55 | img = data.cpu().numpy().reshape(batch_size, 28, 28)
56 | imgs = xrecons_grid(img, B, A)
57 | plt.matshow(imgs, cmap=plt.cm.gray)
58 | plt.savefig('image/example.png')
59 |
60 | if __name__ == '__main__':
61 | save_example_image()
62 | train()
--------------------------------------------------------------------------------
/draw_model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from utility import *
4 | import torch.nn.functional as F
5 |
6 | class DrawModel(nn.Module):
7 | def __init__(self,T,A,B,z_size,N,dec_size,enc_size):
8 | super(DrawModel,self).__init__()
9 | self.T = T
10 | # self.batch_size = batch_size
11 | self.A = A
12 | self.B = B
13 | self.z_size = z_size
14 | self.N = N
15 | self.dec_size = dec_size
16 | self.enc_size = enc_size
17 | self.cs = [0] * T
18 | self.logsigmas,self.sigmas,self.mus = [0] * T,[0] * T,[0] * T
19 |
20 | self.encoder = nn.LSTMCell(2 * N * N + dec_size, enc_size)
21 | self.encoder_gru = nn.GRUCell(2 * N * N + dec_size, enc_size)
22 |         self.mu_linear = nn.Linear(enc_size, z_size)      # input is h_enc (enc_size == dec_size in config.py)
23 |         self.sigma_linear = nn.Linear(enc_size, z_size)
24 |
25 | self.decoder = nn.LSTMCell(z_size,dec_size)
26 | self.decoder_gru = nn.GRUCell(z_size,dec_size)
27 | self.dec_linear = nn.Linear(dec_size,5)
28 | self.dec_w_linear = nn.Linear(dec_size,N*N)
29 |
30 | self.sigmoid = nn.Sigmoid()
31 |
32 |
33 |
34 | def normalSample(self):
35 | return Variable(torch.randn(self.batch_size,self.z_size))
36 |
37 | # correct
38 | def compute_mu(self,g,rng,delta):
39 | rng_t,delta_t = align(rng,delta)
40 | tmp = (rng_t - self.N / 2 - 0.5) * delta_t
41 | tmp_t,g_t = align(tmp,g)
42 | mu = tmp_t + g_t
43 | return mu
44 |
45 | # correct
46 | def filterbank(self,gx,gy,sigma2,delta):
47 | rng = Variable(torch.arange(0,self.N).view(1,-1))
48 | mu_x = self.compute_mu(gx,rng,delta)
49 | mu_y = self.compute_mu(gy,rng,delta)
50 |
51 | a = Variable(torch.arange(0,self.A).view(1,1,-1))
52 | b = Variable(torch.arange(0,self.B).view(1,1,-1))
53 |
54 | mu_x = mu_x.view(-1,self.N,1)
55 | mu_y = mu_y.view(-1,self.N,1)
56 | sigma2 = sigma2.view(-1,1,1)
57 |
58 | Fx = self.filterbank_matrices(a,mu_x,sigma2)
59 | Fy = self.filterbank_matrices(b,mu_y,sigma2)
60 |
61 | return Fx,Fy
62 |
63 | def forward(self,x):
64 | self.batch_size = x.size()[0]
65 | h_dec_prev = Variable(torch.zeros(self.batch_size,self.dec_size))
66 | h_enc_prev = Variable(torch.zeros(self.batch_size, self.enc_size))
67 |
68 | enc_state = Variable(torch.zeros(self.batch_size,self.enc_size))
69 | dec_state = Variable(torch.zeros(self.batch_size, self.dec_size))
70 |         for t in range(self.T):
71 | c_prev = Variable(torch.zeros(self.batch_size,self.A * self.B)) if t == 0 else self.cs[t-1]
72 | x_hat = x - self.sigmoid(c_prev) # 3
73 | r_t = self.read(x,x_hat,h_dec_prev)
74 | h_enc_prev,enc_state = self.encoder(torch.cat((r_t,h_dec_prev),1),(h_enc_prev,enc_state))
75 | # h_enc = self.encoder_gru(torch.cat((r_t,h_dec_prev),1),h_enc_prev)
76 | z,self.mus[t],self.logsigmas[t],self.sigmas[t] = self.sampleQ(h_enc_prev)
77 | h_dec,dec_state = self.decoder(z, (h_dec_prev, dec_state))
78 | # h_dec = self.decoder_gru(z, h_dec_prev)
79 | self.cs[t] = c_prev + self.write(h_dec)
80 | h_dec_prev = h_dec
81 |
82 | def loss(self,x):
83 | self.forward(x)
84 | criterion = nn.BCELoss()
85 | x_recons = self.sigmoid(self.cs[-1])
86 | Lx = criterion(x_recons,x) * self.A * self.B
87 | Lz = 0
88 |         kl_terms = [0] * self.T
89 |         for t in range(self.T):
90 | mu_2 = self.mus[t] * self.mus[t]
91 | sigma_2 = self.sigmas[t] * self.sigmas[t]
92 | logsigma = self.logsigmas[t]
93 | # Lz += (0.5 * (mu_2 + sigma_2 - 2 * logsigma)) # 11
94 | kl_terms[t] = 0.5 * torch.sum(mu_2+sigma_2-2 * logsigma,1) - self.T * 0.5
95 | Lz += kl_terms[t]
96 | # Lz -= self.T / 2
97 |         Lz = torch.mean(Lz)  # average the summed KL over the batch
98 | loss = Lz + Lx # 12
99 | return loss
100 |
101 |
102 | # correct
103 | def filterbank_matrices(self,a,mu_x,sigma2,epsilon=1e-9):
104 | t_a,t_mu_x = align(a,mu_x)
105 | temp = t_a - t_mu_x
106 | temp,t_sigma = align(temp,sigma2)
107 | temp = temp / (t_sigma * 2)
108 | F = torch.exp(-torch.pow(temp,2))
109 | F = F / (F.sum(2,True).expand_as(F) + epsilon)
110 | return F
111 |
112 | #correct
113 | def attn_window(self,h_dec):
114 | params = self.dec_linear(h_dec)
115 | gx_,gy_,log_sigma_2,log_delta,log_gamma = params.split(1,1) #21
116 |
117 | # gx_ = Variable(torch.ones(4,1))
118 | # gy_ = Variable(torch.ones(4, 1) * 2)
119 | # log_sigma_2 = Variable(torch.ones(4, 1) * 3)
120 | # log_delta = Variable(torch.ones(4, 1) * 4)
121 | # log_gamma = Variable(torch.ones(4, 1) * 5)
122 |
123 | gx = (self.A + 1) / 2 * (gx_ + 1) # 22
124 | gy = (self.B + 1) / 2 * (gy_ + 1) # 23
125 | delta = (max(self.A,self.B) - 1) / (self.N - 1) * torch.exp(log_delta) # 24
126 | sigma2 = torch.exp(log_sigma_2)
127 | gamma = torch.exp(log_gamma)
128 |
129 | return self.filterbank(gx,gy,sigma2,delta),gamma
130 | # correct
131 | def read(self,x,x_hat,h_dec_prev):
132 | (Fx,Fy),gamma = self.attn_window(h_dec_prev)
133 | def filter_img(img,Fx,Fy,gamma,A,B,N):
134 | Fxt = Fx.transpose(2,1)
135 | img = img.view(-1,B,A)
136 | # img = img.transpose(2,1)
137 | # glimpse = matmul(Fy,matmul(img,Fxt))
138 | glimpse = Fy.bmm(img.bmm(Fxt))
139 | glimpse = glimpse.view(-1,N*N)
140 | return glimpse * gamma.view(-1,1).expand_as(glimpse)
141 | x = filter_img(x,Fx,Fy,gamma,self.A,self.B,self.N)
142 | x_hat = filter_img(x_hat,Fx,Fy,gamma,self.A,self.B,self.N)
143 | return torch.cat((x,x_hat),1)
144 |
145 | # correct
146 | def write(self,h_dec=0):
147 | w = self.dec_w_linear(h_dec)
148 | w = w.view(self.batch_size,self.N,self.N)
149 | # w = Variable(torch.ones(4,5,5) * 3)
150 | # self.batch_size = 4
151 | (Fx,Fy),gamma = self.attn_window(h_dec)
152 | Fyt = Fy.transpose(2,1)
153 | # wr = matmul(Fyt,matmul(w,Fx))
154 | wr = Fyt.bmm(w.bmm(Fx))
155 | wr = wr.view(self.batch_size,self.A*self.B)
156 | return wr / gamma.view(-1,1).expand_as(wr)
157 |
158 | def sampleQ(self,h_enc):
159 | e = self.normalSample()
160 | # mu_sigma = self.mu_sigma_linear(h_enc)
161 | # mu = mu_sigma[:, :self.z_size]
162 | # log_sigma = mu_sigma[:, self.z_size:]
163 | mu = self.mu_linear(h_enc) # 1
164 | log_sigma = self.sigma_linear(h_enc) # 2
165 | sigma = torch.exp(log_sigma)
166 |
167 | return mu + sigma * e , mu , log_sigma, sigma
168 |
169 | def generate(self,batch_size=64):
170 | self.batch_size = batch_size
171 | h_dec_prev = Variable(torch.zeros(self.batch_size,self.dec_size),volatile = True)
172 | dec_state = Variable(torch.zeros(self.batch_size, self.dec_size),volatile = True)
173 |
174 |         for t in range(self.T):
175 | c_prev = Variable(torch.zeros(self.batch_size, self.A * self.B)) if t == 0 else self.cs[t - 1]
176 | z = self.normalSample()
177 | h_dec, dec_state = self.decoder(z, (h_dec_prev, dec_state))
178 | self.cs[t] = c_prev + self.write(h_dec)
179 | h_dec_prev = h_dec
180 | imgs = []
181 | for img in self.cs:
182 | imgs.append(self.sigmoid(img).cpu().data.numpy())
183 | return imgs
184 |
185 |
186 |
187 |
188 | # model = DrawModel(10,5,5,10,5,128,128)
189 | # x = Variable(torch.ones(4,25))
190 | # x_hat = Variable(torch.ones(4,25)*2)
191 | # r = model.write()
192 | # print r
193 | # g = Variable(torch.ones(4,1))
194 | # delta = Variable(torch.ones(4,1) * 3)
195 | # sigma = Variable(torch.ones(4,1))
196 | # rng = Variable(torch.arange(0,5).view(1,-1))
197 | # mu_x = model.compute_mu(g,rng,delta)
198 | # a = Variable(torch.arange(0,5).view(1,1,-1))
199 | # mu_x = mu_x.view(-1,5,1)
200 | # sigma = sigma.view(-1,1,1)
201 | # F = model.filterbank_matrices(a,mu_x,sigma)
202 | # print F
203 | # def test_normalSample():
204 | # print model.normalSample()
205 | #
206 | # def test_write():
207 | # h_dec = Variable(torch.zeros(8,128))
208 | # model.write(h_dec)
209 | #
210 | # def test_read():
211 | # x = Variable(torch.zeros(8,28*28))
212 | # x_hat = Variable((torch.zeros(8,28*28)))
213 | # h_dec = Variable(torch.zeros(8, 128))
214 | # model.read(x,x_hat,h_dec)
215 | #
216 | # def test_loss():
217 | # x = Variable(torch.zeros(8,28*28))
218 | # loss = model.loss(x)
219 | # print loss
220 |
221 |
--------------------------------------------------------------------------------