├── image
│   ├── example.png
│   ├── count_46900_test_0.png
│   ├── count_46900_test_1.png
│   ├── count_46900_test_2.png
│   ├── count_46900_test_3.png
│   ├── count_46900_test_4.png
│   ├── count_46900_test_5.png
│   ├── count_46900_test_6.png
│   ├── count_46900_test_7.png
│   ├── count_46900_test_8.png
│   └── count_46900_test_9.png
├── save
│   └── weights_final.tar
├── .idea
│   └── vcs.xml
├── config.py
├── generate.py
├── README.md
├── utility.py
├── train.py
└── draw_model.py

/image/example.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/czm0/draw_pytorch/HEAD/image/example.png
--------------------------------------------------------------------------------
/save/weights_final.tar:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/czm0/draw_pytorch/HEAD/save/weights_final.tar
--------------------------------------------------------------------------------
/image/count_46900_test_0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/czm0/draw_pytorch/HEAD/image/count_46900_test_0.png
--------------------------------------------------------------------------------
/image/count_46900_test_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/czm0/draw_pytorch/HEAD/image/count_46900_test_1.png
--------------------------------------------------------------------------------
/image/count_46900_test_2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/czm0/draw_pytorch/HEAD/image/count_46900_test_2.png
--------------------------------------------------------------------------------
/image/count_46900_test_3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/czm0/draw_pytorch/HEAD/image/count_46900_test_3.png
--------------------------------------------------------------------------------
/image/count_46900_test_4.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/czm0/draw_pytorch/HEAD/image/count_46900_test_4.png
--------------------------------------------------------------------------------
/image/count_46900_test_5.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/czm0/draw_pytorch/HEAD/image/count_46900_test_5.png
--------------------------------------------------------------------------------
/image/count_46900_test_6.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/czm0/draw_pytorch/HEAD/image/count_46900_test_6.png
--------------------------------------------------------------------------------
/image/count_46900_test_7.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/czm0/draw_pytorch/HEAD/image/count_46900_test_7.png
--------------------------------------------------------------------------------
/image/count_46900_test_8.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/czm0/draw_pytorch/HEAD/image/count_46900_test_8.png
--------------------------------------------------------------------------------
/image/count_46900_test_9.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/czm0/draw_pytorch/HEAD/image/count_46900_test_9.png
--------------------------------------------------------------------------------
/.idea/vcs.xml:
--------------------------------------------------------------------------------
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
  <component name="VcsDirectoryMappings">
    <mapping directory="$PROJECT_DIR$" vcs="Git" />
  </component>
</project>
--------------------------------------------------------------------------------
/config.py:
--------------------------------------------------------------------------------
T = 10                # number of DRAW time steps (glimpses)
batch_size = 64
A = 28                # image width
B = 28                # image height
z_size = 10           # latent dimension
N = 5                 # attention window size (N x N)
dec_size = 256        # decoder LSTM hidden size
enc_size = 256        # encoder LSTM hidden size
epoch_num = 20
learning_rate = 1e-3
beta1 = 0.5           # Adam beta1
USE_CUDA = True
clip = 5.0            # gradient-norm clipping threshold
--------------------------------------------------------------------------------
/generate.py:
--------------------------------------------------------------------------------
from draw_model import DrawModel
from config import *
from utility import save_image
import torch

torch.set_default_tensor_type('torch.FloatTensor')

model = DrawModel(T, A, B, z_size, N, dec_size, enc_size)

if USE_CUDA:
    model.cuda()

state_dict = torch.load('save/weights_final.tar')
model.load_state_dict(state_dict)


def generate():
    x = model.generate(batch_size)
    save_image(x)


if __name__ == '__main__':
    generate()
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# draw-pytorch

PyTorch implementation of [DRAW: A Recurrent Neural Network For Image Generation](http://arxiv.org/pdf/1502.04623.pdf) on the MNIST generation task.

| With Attention |
| ------------- |
| |


## Usage

`python train.py` downloads the MNIST dataset to `./data/` and trains the DRAW model with attention for both reading and writing. After training, the final weights are written to `./save/weights_final.tar` and the generated images are written to `./image/` (one `count_*_test_*.png` grid per DRAW time step).

`python generate.py` loads the weights from `save/weights_final.tar` and generates images.

The provided `weights_final.tar` was trained for 50 epochs with a minibatch size of 64 on a GTX 1080 GPU.
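
To sample on a machine without a CUDA GPU, a minimal sketch (not part of this repo; it assumes you set `USE_CUDA = False` in `config.py`) is to remap the GPU-trained checkpoint to CPU storage when loading:

```python
import torch
from draw_model import DrawModel
from config import T, A, B, z_size, N, dec_size, enc_size, batch_size
from utility import save_image

# Build the model on the CPU and pull the GPU-trained weights onto CPU storage.
model = DrawModel(T, A, B, z_size, N, dec_size, enc_size)
state_dict = torch.load('save/weights_final.tar',
                        map_location=lambda storage, loc: storage)
model.load_state_dict(state_dict)

# Draw a batch of canvases and write one image grid per DRAW time step to image/.
save_image(model.generate(batch_size))
```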

## Reference
https://github.com/ericjang/draw
--------------------------------------------------------------------------------
/utility.py:
--------------------------------------------------------------------------------
import torch.autograd as autograd
import torch
from config import *
import numpy as np
import matplotlib.pyplot as plt


def Variable(data, *args, **kwargs):
    # thin wrapper that moves the tensor to the GPU first when USE_CUDA is set
    if USE_CUDA:
        data = data.cuda()
    return autograd.Variable(data, *args, **kwargs)

def unit_prefix(x, n=1):
    # prepend n singleton dimensions to x
    for i in range(n): x = x.unsqueeze(0)
    return x

def align(x, y, start_dim=0):
    # broadcast x and y against each other by prefixing singleton dims and expanding
    xd, yd = x.dim(), y.dim()
    if xd > yd: y = unit_prefix(y, xd - yd)
    elif yd > xd: x = unit_prefix(x, yd - xd)

    xs, ys = list(x.size()), list(y.size())
    nd = len(ys)
    for i in range(start_dim, nd):
        td = nd - i - 1
        if ys[td] == 1: ys[td] = xs[td]
        elif xs[td] == 1: xs[td] = ys[td]
    return x.expand(*xs), y.expand(*ys)

def matmul(X, Y):
    # batched matrix multiply; kept for reference, draw_model.py uses bmm directly
    results = []
    for i in range(X.size(0)):
        result = torch.mm(X[i], Y[i])
        results.append(result.unsqueeze(0))
    return torch.cat(results)


def xrecons_grid(X, B, A):
    """
    Plots the canvas for a single time step.
    X is x_recons with shape (batch_size, img_size); each row is a B x A image.
    The batch size is assumed to be a square number.
    """
    padsize = 1
    padval = .5
    ph = B + 2 * padsize
    pw = A + 2 * padsize
    batch_size = X.shape[0]
    N = int(np.sqrt(batch_size))
    X = X.reshape((N, N, B, A))
    img = np.ones((N * ph, N * pw)) * padval
    for i in range(N):
        for j in range(N):
            startr = i * ph + padsize
            endr = startr + B
            startc = j * pw + padsize
            endc = startc + A
            img[startr:endr, startc:endc] = X[i, j, :, :]
    return img

def save_image(x, count=0):
    for t in range(T):
        img = xrecons_grid(x[t], B, A)
        plt.matshow(img, cmap=plt.cm.gray)
        # you can merge the per-step images into a gif with ImageMagick, e.g.
        #   convert -delay 10 -loop 0 *.png mnist.gif
        imgname = 'image/count_%d_%s_%d.png' % (count, 'test', t)
        plt.savefig(imgname)
        print(imgname)
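
# A minimal standalone demo (hypothetical data, not used by train.py or generate.py):
# tile one square batch of flattened B x A images into a padded grid and save it.
# It assumes the image/ directory exists, as the rest of the repo does.
if __name__ == '__main__':
    X = np.random.rand(batch_size, B * A)  # stand-in for x_recons at a single time step
    grid = xrecons_grid(X, B, A)           # (sqrt(batch)*(B+2)) x (sqrt(batch)*(A+2)) array
    plt.matshow(grid, cmap=plt.cm.gray)
    plt.savefig('image/grid_demo.png')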
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
import torch.optim as optim
from torchvision import datasets, transforms
import torch.utils.data
from draw_model import DrawModel
from config import *
from utility import Variable, save_image, xrecons_grid
import torch.nn.utils
import matplotlib.pyplot as plt

torch.set_default_tensor_type('torch.FloatTensor')

train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('data/', train=True, download=True,
                   transform=transforms.Compose([
                       transforms.ToTensor()])),
    batch_size=batch_size, shuffle=False)

model = DrawModel(T, A, B, z_size, N, dec_size, enc_size)
optimizer = optim.Adam(model.parameters(), lr=learning_rate, betas=(beta1, 0.999))

if USE_CUDA:
    model.cuda()

def train():
    avg_loss = 0
    count = 0
    for epoch in range(epoch_num):
        for data, _ in train_loader:
            bs = data.size()[0]
            data = Variable(data).view(bs, -1)
            optimizer.zero_grad()
            loss = model.loss(data)
            avg_loss += loss.cpu().data.numpy()
            loss.backward()
            torch.nn.utils.clip_grad_norm(model.parameters(), clip)
            optimizer.step()
            count += 1
            if count % 100 == 0:
                print('Epoch-{}; Count-{}; loss: {};'.format(epoch, count, avg_loss / 100))
                if count % 3000 == 0:
                    torch.save(model.state_dict(), 'save/weights_%d.tar' % (count))
                    generate_image(count)
                avg_loss = 0
    torch.save(model.state_dict(), 'save/weights_final.tar')
    generate_image(count)


def generate_image(count):
    x = model.generate(batch_size)
    save_image(x, count)

def save_example_image():
    train_iter = iter(train_loader)
    data, _ = next(train_iter)
    img = data.cpu().numpy().reshape(batch_size, 28, 28)
    imgs = xrecons_grid(img, B, A)
    plt.matshow(imgs, cmap=plt.cm.gray)
    plt.savefig('image/example.png')

if __name__ == '__main__':
    save_example_image()
    train()
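
# Hypothetical warm start (not used above): train() writes intermediate checkpoints
# as save/weights_<count>.tar every 3000 minibatches. To resume from one of them
# instead of starting from scratch, load it before calling train(), e.g.
#
#   model.load_state_dict(torch.load('save/weights_3000.tar'))
#   train()
#
# Only the model weights are saved, so the Adam optimizer state starts fresh.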
--------------------------------------------------------------------------------
/draw_model.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
from utility import *

class DrawModel(nn.Module):
    def __init__(self, T, A, B, z_size, N, dec_size, enc_size):
        super(DrawModel, self).__init__()
        self.T = T
        # self.batch_size = batch_size
        self.A = A
        self.B = B
        self.z_size = z_size
        self.N = N
        self.dec_size = dec_size
        self.enc_size = enc_size
        self.cs = [0] * T
        self.logsigmas, self.sigmas, self.mus = [0] * T, [0] * T, [0] * T

        self.encoder = nn.LSTMCell(2 * N * N + dec_size, enc_size)
        self.encoder_gru = nn.GRUCell(2 * N * N + dec_size, enc_size)
        # map the encoder state h_enc to the latent distribution parameters
        self.mu_linear = nn.Linear(enc_size, z_size)
        self.sigma_linear = nn.Linear(enc_size, z_size)

        self.decoder = nn.LSTMCell(z_size, dec_size)
        self.decoder_gru = nn.GRUCell(z_size, dec_size)
        self.dec_linear = nn.Linear(dec_size, 5)
        self.dec_w_linear = nn.Linear(dec_size, N * N)

        self.sigmoid = nn.Sigmoid()

    def normalSample(self):
        return Variable(torch.randn(self.batch_size, self.z_size))

    # correct
    def compute_mu(self, g, rng, delta):
        rng_t, delta_t = align(rng, delta)
        tmp = (rng_t - self.N / 2 - 0.5) * delta_t
        tmp_t, g_t = align(tmp, g)
        mu = tmp_t + g_t
        return mu

    # correct
    def filterbank(self, gx, gy, sigma2, delta):
        rng = Variable(torch.arange(0, self.N).view(1, -1))
        mu_x = self.compute_mu(gx, rng, delta)
        mu_y = self.compute_mu(gy, rng, delta)

        a = Variable(torch.arange(0, self.A).view(1, 1, -1))
        b = Variable(torch.arange(0, self.B).view(1, 1, -1))

        mu_x = mu_x.view(-1, self.N, 1)
        mu_y = mu_y.view(-1, self.N, 1)
        sigma2 = sigma2.view(-1, 1, 1)

        Fx = self.filterbank_matrices(a, mu_x, sigma2)
        Fy = self.filterbank_matrices(b, mu_y, sigma2)

        return Fx, Fy

    def forward(self, x):
        self.batch_size = x.size()[0]
        h_dec_prev = Variable(torch.zeros(self.batch_size, self.dec_size))
        h_enc_prev = Variable(torch.zeros(self.batch_size, self.enc_size))

        enc_state = Variable(torch.zeros(self.batch_size, self.enc_size))
        dec_state = Variable(torch.zeros(self.batch_size, self.dec_size))
        for t in range(self.T):
            c_prev = Variable(torch.zeros(self.batch_size, self.A * self.B)) if t == 0 else self.cs[t - 1]
            x_hat = x - self.sigmoid(c_prev)  # 3
            r_t = self.read(x, x_hat, h_dec_prev)
            h_enc_prev, enc_state = self.encoder(torch.cat((r_t, h_dec_prev), 1), (h_enc_prev, enc_state))
            # h_enc = self.encoder_gru(torch.cat((r_t,h_dec_prev),1),h_enc_prev)
            z, self.mus[t], self.logsigmas[t], self.sigmas[t] = self.sampleQ(h_enc_prev)
            h_dec, dec_state = self.decoder(z, (h_dec_prev, dec_state))
            # h_dec = self.decoder_gru(z, h_dec_prev)
            self.cs[t] = c_prev + self.write(h_dec)
            h_dec_prev = h_dec

    def loss(self, x):
        self.forward(x)
        criterion = nn.BCELoss()
        x_recons = self.sigmoid(self.cs[-1])
        # BCELoss averages over batch and pixels; scaling by A*B gives the per-image pixel sum
        Lx = criterion(x_recons, x) * self.A * self.B
        Lz = 0
        kl_terms = [0] * self.T
        for t in range(self.T):
            mu_2 = self.mus[t] * self.mus[t]
            sigma_2 = self.sigmas[t] * self.sigmas[t]
            logsigma = self.logsigmas[t]
            # Lz += (0.5 * (mu_2 + sigma_2 - 2 * logsigma)) # 11
            kl_terms[t] = 0.5 * torch.sum(mu_2 + sigma_2 - 2 * logsigma, 1) - 0.5  # the -1/2 per step sums to -T/2
            Lz += kl_terms[t]
        Lz = torch.mean(Lz)
        loss = Lz + Lx  # 12
        return loss

    # correct
    def filterbank_matrices(self, a, mu_x, sigma2, epsilon=1e-9):
        t_a, t_mu_x = align(a, mu_x)
        temp = t_a - t_mu_x
        temp, t_sigma = align(temp, sigma2)
        temp = temp / (t_sigma * 2)
        F = torch.exp(-torch.pow(temp, 2))
        F = F / (F.sum(2, True).expand_as(F) + epsilon)
        return F

    # correct
    def attn_window(self, h_dec):
        params = self.dec_linear(h_dec)
        gx_, gy_, log_sigma_2, log_delta, log_gamma = params.split(1, 1)  # 21

        # gx_ = Variable(torch.ones(4,1))
        # gy_ = Variable(torch.ones(4, 1) * 2)
        # log_sigma_2 = Variable(torch.ones(4, 1) * 3)
        # log_delta = Variable(torch.ones(4, 1) * 4)
        # log_gamma = Variable(torch.ones(4, 1) * 5)

        gx = (self.A + 1) / 2 * (gx_ + 1)  # 22
        gy = (self.B + 1) / 2 * (gy_ + 1)  # 23
        delta = (max(self.A, self.B) - 1) / (self.N - 1) * torch.exp(log_delta)  # 24
        sigma2 = torch.exp(log_sigma_2)
        gamma = torch.exp(log_gamma)

        return self.filterbank(gx, gy, sigma2, delta), gamma

    # correct
    def read(self, x, x_hat, h_dec_prev):
        (Fx, Fy), gamma = self.attn_window(h_dec_prev)

        def filter_img(img, Fx, Fy, gamma, A, B, N):
            Fxt = Fx.transpose(2, 1)
            img = img.view(-1, B, A)
            # img = img.transpose(2,1)
            # glimpse = matmul(Fy,matmul(img,Fxt))
            glimpse = Fy.bmm(img.bmm(Fxt))
            glimpse = glimpse.view(-1, N * N)
            return glimpse * gamma.view(-1, 1).expand_as(glimpse)

        x = filter_img(x, Fx, Fy, gamma, self.A, self.B, self.N)
        x_hat = filter_img(x_hat, Fx, Fy, gamma, self.A, self.B, self.N)
        return torch.cat((x, x_hat), 1)  # 27

    # correct
    def write(self, h_dec=0):
        w = self.dec_w_linear(h_dec)  # 28
        w = w.view(self.batch_size, self.N, self.N)
        # w = Variable(torch.ones(4,5,5) * 3)
        # self.batch_size = 4
        (Fx, Fy), gamma = self.attn_window(h_dec)
        Fyt = Fy.transpose(2, 1)
        # wr = matmul(Fyt,matmul(w,Fx))
        wr = Fyt.bmm(w.bmm(Fx))
        wr = wr.view(self.batch_size, self.A * self.B)
        return wr / gamma.view(-1, 1).expand_as(wr)  # 29

    def sampleQ(self, h_enc):
        e = self.normalSample()
        # mu_sigma = self.mu_sigma_linear(h_enc)
        # mu = mu_sigma[:, :self.z_size]
        # log_sigma = mu_sigma[:, self.z_size:]
        mu = self.mu_linear(h_enc)  # 1
        log_sigma = self.sigma_linear(h_enc)  # 2
        sigma = torch.exp(log_sigma)

        return mu + sigma * e, mu, log_sigma, sigma

    def generate(self, batch_size=64):
        self.batch_size = batch_size
        h_dec_prev = Variable(torch.zeros(self.batch_size, self.dec_size), volatile=True)
        dec_state = Variable(torch.zeros(self.batch_size, self.dec_size), volatile=True)

        for t in range(self.T):
            c_prev = Variable(torch.zeros(self.batch_size, self.A * self.B)) if t == 0 else self.cs[t - 1]
            z = self.normalSample()
            h_dec, dec_state = self.decoder(z, (h_dec_prev, dec_state))
            self.cs[t] = c_prev + self.write(h_dec)
            h_dec_prev = h_dec
        imgs = []
        for img in self.cs:
            imgs.append(self.sigmoid(img).cpu().data.numpy())
        return imgs


# model = DrawModel(10,5,5,10,5,128,128)
# x = Variable(torch.ones(4,25))
# x_hat = Variable(torch.ones(4,25)*2)
# r = model.write()
# print(r)
# g = Variable(torch.ones(4,1))
# delta = Variable(torch.ones(4,1) * 3)
# sigma = Variable(torch.ones(4,1))
# rng = Variable(torch.arange(0,5).view(1,-1))
# mu_x = model.compute_mu(g,rng,delta)
# a = Variable(torch.arange(0,5).view(1,1,-1))
# mu_x = mu_x.view(-1,5,1)
# sigma = sigma.view(-1,1,1)
# F = model.filterbank_matrices(a,mu_x,sigma)
# print(F)
# def test_normalSample():
#     print(model.normalSample())
#
# def test_write():
#     h_dec = Variable(torch.zeros(8,128))
#     model.write(h_dec)
#
# def test_read():
#     x = Variable(torch.zeros(8,28*28))
#     x_hat = Variable((torch.zeros(8,28*28)))
#     h_dec = Variable(torch.zeros(8, 128))
#     model.read(x,x_hat,h_dec)
#
# def test_loss():
#     x = Variable(torch.zeros(8,28*28))
#     loss = model.loss(x)
#     print(loss)
--------------------------------------------------------------------------------