├── .gitignore
├── README.md
├── colordata.py
├── get_data.sh
├── images
└── pytorch_lfw.png
├── logger.py
├── main.py
├── mdn.py
├── requirements.txt
└── vae.py
/.gitignore:
--------------------------------------------------------------------------------
1 | *.pyc
2 | *.txt
3 | *.jpg
4 | *.png
5 | *.JPEG
6 | *meta
7 | *ckpt
8 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # pytorch_divcolor
2 |
3 | PyTorch implementation of diverse colorization, from Deshpande et al., "[Learning Diverse Image Colorization](https://arxiv.org/abs/1612.01958)"
4 |
5 | This code is tested with Python 3.5.2 and PyTorch 0.3.0. Install the packages listed in requirements.txt.
6 |
7 | The Tensorflow implementation used in the paper is [divcolor](https://github.com/aditya12agd5/divcolor)
8 |
9 |
10 | Fetch the data with:
11 |
12 | ```
13 | bash get_data.sh
14 | ```
15 |
16 | Run main.py to first train the VAE and MDN, and then generate diverse colorizations for LFW:
17 |
18 | ```
19 | python main.py lfw
20 | ```
21 |
22 | If you use this code, please cite:
23 |
24 | ```
25 | @inproceedings{DeshpandeLDColor17,
26 | author = {Aditya Deshpande and Jiajun Lu and Mao-Chuang Yeh and Min Jin Chong and David Forsyth},
27 | title = {Learning Diverse Image Colorization},
28 | booktitle={Computer Vision and Pattern Recognition},
29 | url={https://arxiv.org/abs/1612.01958},
30 | year={2017}
31 | }
32 | ```
33 |
34 | Some examples of diverse colorizations on LFW:
35 | 
36 | ![Diverse colorizations on LFW](images/pytorch_lfw.png)
37 | 
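38 | main.py also exposes optional flags (see its argparse setup): `-g/--gpu`, `-e/--epochs`, `-em/--epochs_mdn`, `-b/--batchsize`, `-z/--hiddensize`, `-n/--nthreads`, `-m/--nmix`, `-lg/--logstep`, and `-v/--visdom` (with `-s/--server` and `-p/--port_num`). For example, to train on GPU 0 with the default 15 VAE epochs and 7 MDN epochs:
39 | 
40 | ```
41 | python main.py lfw -g 0 -e 15 -em 7 -b 32
42 | ```
43 | 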
--------------------------------------------------------------------------------
/colordata.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import glob
3 | import math
4 | import numpy as np
5 | import os
6 |
7 | from torch.utils.data import Dataset
8 | import torchvision.datasets as datasets
9 | import torchvision.transforms as transforms
10 |
11 | class colordata(Dataset):
12 |
13 | def __init__(self, out_directory, listdir=None, featslistdir=None, shape=(64, 64), \
14 | subdir=False, ext='JPEG', outshape=(256, 256), split='train'):
15 |
16 | self.img_fns = []
17 | self.feats_fns = []
18 |
19 | with open('%s/list.%s.vae.txt' % (listdir, split), 'r') as ftr:
20 | for img_fn in ftr:
21 | self.img_fns.append(img_fn.strip('\n'))
22 |
23 | with open('%s/list.%s.txt' % (featslistdir, split), 'r') as ftr:
24 | for feats_fn in ftr:
25 | self.feats_fns.append(feats_fn.strip('\n'))
26 |
27 | self.img_num = min(len(self.img_fns), len(self.feats_fns))
28 | self.shape = shape
29 | self.outshape = outshape
30 | self.out_directory = out_directory
31 |
32 | self.lossweights = None
33 | countbins = 1./np.load('data/zhang_weights/prior_probs.npy')
34 | binedges = np.load('data/zhang_weights/ab_quantize.npy').reshape(2, 313)
35 | lossweights = {}
36 | for i in range(313):
37 | if binedges[0, i] not in lossweights:
38 | lossweights[binedges[0, i]] = {}
39 | lossweights[binedges[0,i]][binedges[1,i]] = countbins[i]
40 | self.binedges = binedges
41 | self.lossweights = lossweights
42 |
43 | def __len__(self):
44 | return self.img_num
45 |
46 | def __getitem__(self, idx):
47 | color_ab = np.zeros((2, self.shape[0], self.shape[1]), dtype='f')
48 | weights = np.ones((2, self.shape[0], self.shape[1]), dtype='f')
49 | recon_const = np.zeros((1, self.shape[0], self.shape[1]), dtype='f')
50 | recon_const_outres = np.zeros((1, self.outshape[0], self.outshape[1]), dtype='f')
51 | greyfeats = np.zeros((512, 28, 28), dtype='f')
52 |
53 | img_large = cv2.imread(self.img_fns[idx])
54 | if(self.shape is not None):
55 | img = cv2.resize(img_large, (self.shape[0], self.shape[1]))
56 | img_outres = cv2.resize(img_large, (self.outshape[0], self.outshape[1]))
57 |
58 | img_lab = cv2.cvtColor(img, cv2.COLOR_BGR2LAB)
59 | img_lab_outres = cv2.cvtColor(img_outres, cv2.COLOR_BGR2LAB)
60 |
61 | img_lab = ((img_lab*2.)/255.)-1.
62 | img_lab_outres = ((img_lab_outres*2.)/255.)-1.
63 |
64 | recon_const[0, :, :] = img_lab[..., 0]
65 | recon_const_outres[0, :, :] = img_lab_outres[..., 0]
66 |
67 | color_ab[0, :, :] = img_lab[..., 1].reshape(1, self.shape[0], self.shape[1])
68 | color_ab[1, :, :] = img_lab[..., 2].reshape(1, self.shape[0], self.shape[1])
69 |
70 | if(self.lossweights is not None):
71 | weights = self.__getweights__(color_ab)
72 |
73 | featobj = np.load(self.feats_fns[idx])
74 | greyfeats[:,:,:] = featobj['arr_0']
75 |
76 | return color_ab, recon_const, weights, recon_const_outres, greyfeats
77 |
78 | def __getweights__(self, img):
79 | img_vec = img.reshape(-1)
80 | img_vec = img_vec*128.
81 | img_lossweights = np.zeros(img.shape, dtype='f')
82 | img_vec_a = img_vec[:np.prod(self.shape)]
83 | binedges_a = self.binedges[0,...].reshape(-1)
84 | binid_a = [binedges_a.flat[np.abs(binedges_a-v).argmin()] for v in img_vec_a]
85 | img_vec_b = img_vec[np.prod(self.shape):]
86 | binedges_b = self.binedges[1,...].reshape(-1)
87 | binid_b = [binedges_b.flat[np.abs(binedges_b-v).argmin()] for v in img_vec_b]
88 | binweights = np.array([self.lossweights[v1][v2] for v1,v2 in zip(binid_a, binid_b)])
89 | img_lossweights[0, :, :] = binweights.reshape(self.shape[0], self.shape[1])
90 | img_lossweights[1, :, :] = binweights.reshape(self.shape[0], self.shape[1])
91 | return img_lossweights
92 |
93 | def saveoutput_gt(self, net_op, gt, prefix, batch_size, num_cols=8, net_recon_const=None):
94 |
95 | net_out_img = self.__tiledoutput__(net_op, batch_size, num_cols=num_cols, \
96 | net_recon_const=net_recon_const)
97 |
98 | gt_out_img = self.__tiledoutput__(gt, batch_size, num_cols=num_cols, \
99 | net_recon_const=net_recon_const)
100 |
101 | num_rows = np.int_(np.ceil((batch_size*1.)/num_cols))
102 | border_img = 255*np.ones((num_rows*self.outshape[0], 128, 3), dtype='uint8')
103 | out_fn_pred = '%s/%s.png' % (self.out_directory, prefix)
104 | cv2.imwrite(out_fn_pred, np.concatenate((net_out_img, border_img, gt_out_img), axis=1))
105 |
106 | def __tiledoutput__(self, net_op, batch_size, num_cols=8, net_recon_const=None):
107 |
108 | num_rows = np.int_(np.ceil((batch_size*1.)/num_cols))
109 | out_img = np.zeros((num_rows*self.outshape[0], num_cols*self.outshape[1], 3), dtype='uint8')
110 | img_lab = np.zeros((self.outshape[0], self.outshape[1], 3), dtype='uint8')
111 | c = 0
112 | r = 0
113 |
114 | for i in range(batch_size):
115 | if(i % num_cols == 0 and i > 0):
116 | r = r + 1
117 | c = 0
118 | img_lab[..., 0] = self.__decodeimg__(net_recon_const[i, 0, :, :].reshape(\
119 | self.outshape[0], self.outshape[1]))
120 | img_lab[..., 1] = self.__decodeimg__(net_op[i, 0, :, :].reshape(\
121 | self.shape[0], self.shape[1]))
122 | img_lab[..., 2] = self.__decodeimg__(net_op[i, 1, :, :].reshape(\
123 | self.shape[0], self.shape[1]))
124 | img_rgb = cv2.cvtColor(img_lab, cv2.COLOR_LAB2BGR)
125 | out_img[r*self.outshape[0]:(r+1)*self.outshape[0], \
126 | c*self.outshape[1]:(c+1)*self.outshape[1], ...] = img_rgb
127 | c = c+1
128 |
129 | return out_img
130 |
131 | def __decodeimg__(self, img_enc):
132 | img_dec = (((img_enc+1.)*1.)/2.)*255.
133 | img_dec[img_dec < 0.] = 0.
134 | img_dec[img_dec > 255.] = 255.
135 | return cv2.resize(np.uint8(img_dec), (self.outshape[0], self.outshape[1]))
136 |
137 |
138 |
--------------------------------------------------------------------------------
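A note on the loss weighting above: `colordata.__getweights__` maps each ab pixel to its nearest quantized color bin and weights it by the inverse of that bin's empirical probability (the Zhang et al. priors loaded from `data/zhang_weights/`), so rare colors count more in the VAE's reconstruction loss. A minimal self-contained sketch of the same lookup, using tiny made-up stand-ins for the shipped `.npy` arrays:

```
# Sketch of the inverse-frequency weighting in colordata.__getweights__.
# The arrays below are illustrative stand-ins, not the shipped priors
# (real data: prior_probs (313,), ab_quantize reshaped to (2, 313)).
import numpy as np

n_bins = 4
prior_probs = np.array([0.4, 0.3, 0.2, 0.1])   # empirical ab-bin frequencies
binedges = np.array([[-80., -40., 40., 80.],   # a-coordinates of bin centers
                     [-80., -40., 40., 80.]])  # b-coordinates of bin centers
countbins = 1. / prior_probs                   # rare colors -> large weights

# Nested dict keyed by (a-edge, b-edge), as built in colordata.__init__
lossweights = {}
for i in range(n_bins):
    lossweights.setdefault(binedges[0, i], {})[binedges[1, i]] = countbins[i]

def weight_for_pixel(a, b):
    """Nearest-bin lookup, mirroring __getweights__ (which scales ab by 128 first)."""
    nearest_a = binedges[0][np.abs(binedges[0] - a).argmin()]
    nearest_b = binedges[1][np.abs(binedges[1] - b).argmin()]
    return lossweights[nearest_a][nearest_b]

print(weight_for_pixel(75., 85.))  # -> 10.0, the weight of the rarest bin
```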
/get_data.sh:
--------------------------------------------------------------------------------
1 | wget http://vision.cs.illinois.edu/projects/divcolor/data.zip
2 | unzip data.zip
3 | rm data.zip
4 | wget http://vis-www.cs.umass.edu/lfw/lfw-deepfunneled.tgz
5 | tar -xvzf lfw-deepfunneled.tgz
6 | mv lfw-deepfunneled data/lfw_images
7 | rm lfw-deepfunneled.tgz
8 | rm data/output/lfw/*/*
9 | rm data/output/lfw/*npy
10 |
--------------------------------------------------------------------------------
/images/pytorch_lfw.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aditya12agd5/pytorch_divcolor/64854473e2071ba6628eae9450758c362449a864/images/pytorch_lfw.png
--------------------------------------------------------------------------------
/logger.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import visdom
3 | import numpy as np
4 |
5 | class Logger():
6 |
7 | def __init__(self, server, port, outdir):
8 | self.vis = visdom.Visdom(port=port, server=server)
9 |
10 | titles = ['VAE -- KL Div', 'VAE -- Weighted L2', 'VAE -- L2']
11 | self.vis_plot_vae = []
12 | for title in titles:
13 | self.vis_plot_vae.append(self.vis.line(
14 | X=np.array([0.], dtype='f'),
15 | Y=np.array([0.], dtype='f'),
16 | opts=dict(
17 | xlabel='Iteration',\
18 | ylabel='Loss',\
19 | title=title)))
20 |
21 | self.vis_plot_test_vae = self.vis.line(
22 | X=np.array([0.], dtype='f'),
23 | Y=np.array([0.], dtype='f'),
24 | opts=dict(
25 | xlabel='Iteration',\
26 | ylabel='Test Loss',\
27 | title='VAE Test Loss'))
28 |
29 |
30 | self.vis_plot_mdn = []
31 | titles = ['MDN Loss', 'MDN -- L2']
32 | for title in titles:
33 | self.vis_plot_mdn.append(self.vis.line(
34 | X=np.array([0.], dtype='f'),
35 | Y=np.array([0.], dtype='f'),
36 | opts=dict(
37 | xlabel='Iteration',\
38 | ylabel='Loss',\
39 | title=title)))
40 |
41 | self.fp_vae = open('%s/log_vae.txt' % outdir, 'w')
42 | self.fp_vae.write('Iteration; KLDiv; WeightedL2; L2;\n')
43 | self.fp_vae.flush()
44 |
45 | self.fp_test_vae = open('%s/log_test_vae.txt' % outdir, 'w')
46 | self.fp_test_vae.write('Iteration; Loss;\n')
47 | self.fp_test_vae.flush()
48 |
49 | self.fp_mdn = open('%s/log_mdn.txt' % outdir, 'w')
50 | self.fp_mdn.write('Iteration; Loss; L2 Loss;\n')
51 | self.fp_mdn.flush()
52 |
53 | def update_plot(self, x, losses, plot_type='vae'):
54 |
55 | if(plot_type == 'vae'):
56 | self.fp_vae.write('%f;' % x)
57 | for loss_i, loss in enumerate(losses):
58 | win = self.vis_plot_vae[loss_i]
59 | self.vis.updateTrace(
60 | X=np.array([x], dtype='f'),
61 | Y=np.array([loss], dtype='f'),
62 | win=win)
63 | self.fp_vae.write(' %f;' % loss)
64 | self.fp_vae.write('\n')
65 | self.fp_vae.flush()
66 |
67 |
68 | elif(plot_type == 'mdn'):
69 | for loss_i, loss in enumerate(losses):
70 | win = self.vis_plot_mdn[loss_i]
71 | self.vis.updateTrace(
72 | X=np.array([x], dtype='f'),
73 | Y=np.array([loss], dtype='f'),
74 | win=win)
75 | self.fp_mdn.write('%f; %f; %f;\n' % (x, losses[0], losses[1]))
76 | self.fp_mdn.flush()
77 |
78 | def update_test_plot(self, x, y):
79 | self.vis.updateTrace(
80 | X=np.array([x], dtype='f'),
81 | Y=np.array([y], dtype='f'),
82 | win=self.vis_plot_test_vae)
83 | self.fp_test_vae.write('%f; %f;\n' % (x, y))
84 | self.fp_test_vae.flush()
85 |
86 |
--------------------------------------------------------------------------------
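`Logger` uses `visdom.Visdom.updateTrace`, which works with the pinned `visdom==0.1.6.5` but was dropped from later visdom releases in favor of `vis.line(..., update='append')`. A sketch of the equivalent call on a newer visdom, assuming a local server on the default port:

```
import numpy as np
import visdom

vis = visdom.Visdom(server='http://localhost', port=8097)

# Create a window, then append points to it (replaces updateTrace).
win = vis.line(X=np.array([0.], dtype='f'),
               Y=np.array([0.], dtype='f'),
               opts=dict(xlabel='Iteration', ylabel='Loss', title='VAE -- KL Div'))
vis.line(X=np.array([1.], dtype='f'),
         Y=np.array([0.42], dtype='f'),
         win=win, update='append')
```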
/main.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 |
3 | import argparse
4 | import os
5 | import socket
6 | import sys
7 | import numpy as np
8 |
9 | from colordata import colordata
10 | from vae import VAE
11 | from mdn import MDN
12 | from logger import Logger
13 |
14 | import torch
15 | import torch.nn as nn
16 | import torch.nn.functional as F
17 | import torch.optim as optim
18 | from torch.autograd import Variable
19 | from torch.utils.data import DataLoader
20 |
21 | from tqdm import tqdm
22 |
23 | parser = argparse.ArgumentParser(description='PyTorch Diverse Colorization')
24 |
25 | parser.add_argument('dataset_key', help='Dataset')
26 |
27 | parser.add_argument('-g', '--gpu', type=int, default=0,\
28 | help='gpu device id')
29 |
30 | parser.add_argument('-e', '--epochs', type=int, default=15,\
31 | help='Number of epochs')
32 |
33 | parser.add_argument('-b', '--batchsize', type=int, default=32,\
34 | help='Batch size')
35 |
36 | parser.add_argument('-z', '--hiddensize', type=int, default=64,\
37 | help='Latent vector dimension')
38 |
39 | parser.add_argument('-n', '--nthreads', type=int, default=4,\
40 | help='Data loader threads')
41 |
42 | parser.add_argument('-em', '--epochs_mdn', type=int, default=7,\
43 | help='Number of epochs for MDN')
44 |
45 | parser.add_argument('-m', '--nmix', type=int, default=8,\
46 | help='Number of diverse colorization (or output gmm components)')
47 |
48 | parser.add_argument('-lg', '--logstep', type=int, default=100,\
49 | help='Interval to log data')
50 |
51 | parser.add_argument('-v', '--visdom', action='store_true',\
52 | help='Visdom visualization')
53 |
54 | parser.add_argument('-s', '--server', type=str, default='http://vision-gpu-4.cs.illinois.edu',\
55 | help='Visdom server')
56 |
57 | parser.add_argument('-p', '--port_num', type=int, default=8097)
58 |
59 |
60 | args = parser.parse_args()
61 |
62 | if(args.visdom):
63 | import visdom
64 |
65 | def get_dirpaths(args):
66 | if(args.dataset_key == 'lfw'):
67 | out_dir = 'data/output/lfw/'
68 | listdir = 'data/imglist/lfw/'
69 | featslistdir = 'data/featslist/lfw/'
70 | else:
71 | raise NameError('[ERROR] Incorrect key: %s' % (args.dataset_key))
72 | return out_dir, listdir, featslistdir
73 |
74 | def vae_loss(mu, logvar, pred, gt, lossweights, batchsize):
75 | kl_element = torch.add(torch.add(torch.add(mu.pow(2), logvar.exp()), -1), logvar.mul(-1))
76 | kl_loss = torch.sum(kl_element).mul(.5)
77 | gt = gt.view(-1, 64*64*2)
78 | pred = pred.view(-1, 64*64*2)
79 | recon_element = torch.sqrt(torch.sum(torch.mul(torch.add(gt, pred.mul(-1)).pow(2), lossweights), 1))
80 | recon_loss = torch.sum(recon_element).mul(1./(batchsize))
81 |
82 | recon_element_l2 = torch.sqrt(torch.sum(torch.add(gt, pred.mul(-1)).pow(2), 1))
83 | recon_loss_l2 = torch.sum(recon_element_l2).mul(1./(batchsize))
84 |
85 | return kl_loss, recon_loss, recon_loss_l2
86 |
87 | def get_gmm_coeffs(gmm_params):
88 | gmm_mu = gmm_params[..., :args.hiddensize*args.nmix]
89 | gmm_mu = gmm_mu.contiguous()
90 | gmm_pi_activ = gmm_params[..., args.hiddensize*args.nmix:]
91 | gmm_pi_activ = gmm_pi_activ.contiguous()
92 | gmm_pi = F.softmax(gmm_pi_activ, dim=1)
93 | return gmm_mu, gmm_pi
94 |
95 | def mdn_loss(gmm_params, mu, stddev, batchsize):
96 | gmm_mu, gmm_pi = get_gmm_coeffs(gmm_params)
97 | eps = Variable(torch.randn(stddev.size())).cuda()
98 | z = torch.add(mu, torch.mul(eps, stddev))
99 | z_flat = z.repeat(1, args.nmix)
100 | z_flat = z_flat.view(batchsize*args.nmix, args.hiddensize)
101 | gmm_mu_flat = gmm_mu.view(batchsize*args.nmix, args.hiddensize)
102 | dist_all = torch.sqrt(torch.sum(torch.add(z_flat, gmm_mu_flat.mul(-1)).pow(2).mul(50), 1))
103 | dist_all = dist_all.view(batchsize, args.nmix)
104 | dist_min, selectids = torch.min(dist_all, 1)
105 | gmm_pi_min = torch.gather(gmm_pi, 1, selectids.view(-1, 1))
106 | gmm_loss = torch.mean(torch.add(-1*torch.log(gmm_pi_min+1e-30), dist_min))
107 | gmm_loss_l2 = torch.mean(dist_min)
108 | return gmm_loss, gmm_loss_l2
109 |
110 |
111 | def test_vae(model):
112 |
113 | model.train(False)
114 |
115 | out_dir, listdir, featslistdir = get_dirpaths(args)
116 | batchsize = args.batchsize
117 | hiddensize = args.hiddensize
118 | nmix = args.nmix
119 |
120 | data = colordata(\
121 | os.path.join(out_dir, 'images'), \
122 | listdir=listdir,\
123 | featslistdir=featslistdir,
124 | split='test')
125 |
126 | nbatches = np.int_(np.floor(data.img_num/batchsize))
127 |
128 | data_loader = DataLoader(dataset=data, num_workers=args.nthreads,\
129 | batch_size=batchsize, shuffle=False, drop_last=True)
130 |
131 | test_loss = 0.
132 | for batch_idx, (batch, batch_recon_const, batch_weights, batch_recon_const_outres, _) in \
133 | tqdm(enumerate(data_loader), total=nbatches):
134 |
135 | input_color = Variable(batch).cuda()
136 | lossweights = Variable(batch_weights).cuda()
137 | lossweights = lossweights.view(batchsize, -1)
138 | input_greylevel = Variable(batch_recon_const).cuda()
139 | z = Variable(torch.randn(batchsize, hiddensize))
140 |
141 | mu, logvar, color_out = model(input_color, input_greylevel, z)
142 | _, _, recon_loss_l2 = \
143 | vae_loss(mu, logvar, color_out, input_color, lossweights, batchsize)
144 | test_loss = test_loss + recon_loss_l2.data[0]
145 |
146 | test_loss = (test_loss*1.)/nbatches
147 |
148 | model.train(True)
149 |
150 | return test_loss
151 |
152 | def train_vae(logger=None):
153 |
154 | out_dir, listdir, featslistdir = get_dirpaths(args)
155 | batchsize = args.batchsize
156 | hiddensize = args.hiddensize
157 | nmix = args.nmix
158 | nepochs = args.epochs
159 |
160 | data = colordata(\
161 | os.path.join(out_dir, 'images'), \
162 | listdir=listdir,\
163 | featslistdir=featslistdir,
164 | split='train')
165 |
166 | nbatches = np.int_(np.floor(data.img_num/batchsize))
167 |
168 | data_loader = DataLoader(dataset=data, num_workers=args.nthreads,\
169 | batch_size=batchsize, shuffle=True, drop_last=True)
170 |
171 | model = VAE()
172 | model.cuda()
173 | model.train(True)
174 |
175 | optimizer = optim.Adam(model.parameters(), lr=5e-5)
176 |
177 | itr_idx = 0
178 | for epochs in range(nepochs):
179 | train_loss = 0.
180 |
181 | for batch_idx, (batch, batch_recon_const, batch_weights, batch_recon_const_outres, _) in \
182 | tqdm(enumerate(data_loader), total=nbatches):
183 |
184 | input_color = Variable(batch).cuda()
185 | lossweights = Variable(batch_weights).cuda()
186 | lossweights = lossweights.view(batchsize, -1)
187 | input_greylevel = Variable(batch_recon_const).cuda()
188 | z = Variable(torch.randn(batchsize, hiddensize))
189 |
190 | optimizer.zero_grad()
191 | mu, logvar, color_out = model(input_color, input_greylevel, z)
192 | kl_loss, recon_loss, recon_loss_l2 = \
193 | vae_loss(mu, logvar, color_out, input_color, lossweights, batchsize)
194 | loss = kl_loss.mul(1e-2)+recon_loss
195 | recon_loss_l2 = recon_loss_l2.detach()
196 | loss.backward()
197 | optimizer.step()
198 |
199 | train_loss = train_loss + recon_loss_l2.data[0]
200 |
201 | if(logger):
202 | logger.update_plot(itr_idx, \
203 | [kl_loss.data[0], recon_loss.data[0], recon_loss_l2.data[0]], \
204 | plot_type='vae')
205 | itr_idx += 1
206 |
207 | if(batch_idx % args.logstep == 0):
208 | data.saveoutput_gt(color_out.cpu().data.numpy(), \
209 | batch.numpy(), \
210 | 'train_%05d_%05d' % (epochs, batch_idx), \
211 | batchsize, \
212 | net_recon_const=batch_recon_const_outres.numpy())
213 |
214 | train_loss = (train_loss*1.)/(nbatches)
215 | print('[DEBUG] VAE Train Loss, epoch %d has loss %f' % (epochs, train_loss))
216 |
217 | test_loss = test_vae(model)
218 | if(logger):
219 | logger.update_test_plot(epochs, test_loss)
220 | print('[DEBUG] VAE Test Loss, epoch %d has loss %f' % (epochs, test_loss))
221 |
222 | torch.save(model.state_dict(), '%s/models/model_vae.pth' % (out_dir))
223 |
224 | def train_mdn(logger=None):
225 | out_dir, listdir, featslistdir = get_dirpaths(args)
226 | batchsize = args.batchsize
227 | hiddensize = args.hiddensize
228 | nmix = args.nmix
229 | nepochs = args.epochs_mdn
230 |
231 | data = colordata(\
232 | os.path.join(out_dir, 'images'), \
233 | listdir=listdir,\
234 | featslistdir=featslistdir,
235 | split='train')
236 |
237 | nbatches = np.int_(np.floor(data.img_num/batchsize))
238 |
239 | data_loader = DataLoader(dataset=data, num_workers=args.nthreads,\
240 | batch_size=batchsize, shuffle=True, drop_last=True)
241 |
242 | model_vae = VAE()
243 | model_vae.cuda()
244 | model_vae.load_state_dict(torch.load('%s/models/model_vae.pth' % (out_dir)))
245 | model_vae.train(False)
246 |
247 | model_mdn = MDN()
248 | model_mdn.cuda()
249 | model_mdn.train(True)
250 |
251 | optimizer = optim.Adam(model_mdn.parameters(), lr=1e-3)
252 |
253 | itr_idx = 0
254 | for epochs_mdn in range(nepochs):
255 | train_loss = 0.
256 |
257 | for batch_idx, (batch, batch_recon_const, batch_weights, _, batch_feats) in \
258 | tqdm(enumerate(data_loader), total=nbatches):
259 |
260 | input_color = Variable(batch).cuda()
261 | input_greylevel = Variable(batch_recon_const).cuda()
262 | input_feats = Variable(batch_feats).cuda()
263 | z = Variable(torch.randn(batchsize, hiddensize))
264 |
265 | optimizer.zero_grad()
266 |
267 | mu, logvar, _ = model_vae(input_color, input_greylevel, z)
268 | mdn_gmm_params = model_mdn(input_feats)
269 |
270 | loss, loss_l2 = mdn_loss(mdn_gmm_params, mu, torch.sqrt(torch.exp(logvar)), batchsize)
271 | loss.backward()
272 |
273 | optimizer.step()
274 |
275 | train_loss = train_loss + loss.data[0]
276 |
277 | if(logger):
278 | logger.update_plot(itr_idx, [loss.data[0], loss_l2.data[0]], plot_type='mdn')
279 | itr_idx += 1
280 |
281 | train_loss = (train_loss*1.)/(nbatches)
282 | print('[DEBUG] Training MDN, epoch %d has loss %f' % (epochs_mdn, train_loss))
283 | torch.save(model_mdn.state_dict(), '%s/models/model_mdn.pth' % (out_dir))
284 |
285 | def divcolor():
286 | out_dir, listdir, featslistdir = get_dirpaths(args)
287 | batchsize = args.batchsize
288 | hiddensize = args.hiddensize
289 | nmix = args.nmix
290 |
291 | data = colordata(\
292 | os.path.join(out_dir, 'images'), \
293 | listdir=listdir,\
294 | featslistdir=featslistdir,
295 | split='test')
296 |
297 | nbatches = np.int_(np.floor(data.img_num/batchsize))
298 |
299 | data_loader = DataLoader(dataset=data, num_workers=args.nthreads,\
300 | batch_size=batchsize, shuffle=True, drop_last=True)
301 |
302 | model_vae = VAE()
303 | model_vae.cuda()
304 | model_vae.load_state_dict(torch.load('%s/models/model_vae.pth' % (out_dir)))
305 | model_vae.train(False)
306 |
307 | model_mdn = MDN()
308 | model_mdn.cuda()
309 | model_mdn.load_state_dict(torch.load('%s/models/model_mdn.pth' % (out_dir)))
310 | model_mdn.train(False)
311 |
312 | for batch_idx, (batch, batch_recon_const, batch_weights, \
313 | batch_recon_const_outres, batch_feats) in \
314 | tqdm(enumerate(data_loader), total=nbatches):
315 |
316 | input_feats = Variable(batch_feats).cuda()
317 |
318 | mdn_gmm_params = model_mdn(input_feats)
319 | gmm_mu, gmm_pi = get_gmm_coeffs(mdn_gmm_params)
320 | gmm_pi = gmm_pi.view(-1, 1)
321 | gmm_mu = gmm_mu.view(-1, hiddensize)
322 |
323 | for j in range(batchsize):
324 | batch_j = np.tile(batch[j, ...].numpy(), (batchsize, 1, 1, 1))
325 | batch_recon_const_j = np.tile(batch_recon_const[j, ...].numpy(), (batchsize, 1, 1, 1))
326 | batch_recon_const_outres_j = np.tile(batch_recon_const_outres[j, ...].numpy(), \
327 | (batchsize, 1, 1, 1))
328 |
329 | input_color = Variable(torch.from_numpy(batch_j)).cuda()
330 | input_greylevel = Variable(torch.from_numpy(batch_recon_const_j)).cuda()
331 |
332 | curr_mu = gmm_mu[j*nmix:(j+1)*nmix, :]
333 | orderid = np.argsort(\
334 | gmm_pi[j*nmix:(j+1)*nmix, 0].cpu().data.numpy().reshape(-1))
335 |
336 | z = curr_mu.repeat(int((batchsize*1.)/nmix), 1)
337 |
338 | _, _, color_out = model_vae(input_color, input_greylevel, z, is_train=False)
339 |
340 | data.saveoutput_gt(color_out.cpu().data.numpy()[orderid, ...], \
341 | batch_j[orderid, ...], \
342 | 'divcolor_%05d_%05d' % (batch_idx, j), \
343 | nmix, \
344 | net_recon_const=batch_recon_const_outres_j[orderid, ...])
345 |
346 | if __name__ == '__main__':
347 |
348 | os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu)
349 |
350 | logger = None
351 | if(args.visdom):
352 | outdir, _, _ = get_dirpaths(args)
353 | logger = Logger(args.server, args.port_num, outdir)
354 |
355 | train_vae(logger=logger)
356 | train_mdn(logger=logger)
357 | divcolor()
358 |
--------------------------------------------------------------------------------
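The KL term assembled in `vae_loss` is the closed-form KL divergence between the encoder's diagonal Gaussian `N(mu, sigma^2)` and the unit Gaussian prior: `0.5 * sum(mu^2 + sigma^2 - 1 - log sigma^2)`. A quick numpy check of the identity with illustrative values:

```
import numpy as np

# Illustrative encoder outputs for a single sample
mu = np.array([0.3, -1.2])
logvar = np.array([0.1, -0.5])

# Mirrors the torch ops in vae_loss: (mu^2 + exp(logvar) - 1 - logvar), summed, * 0.5
kl_element = mu**2 + np.exp(logvar) - 1. - logvar
kl_loss = 0.5 * kl_element.sum()
print(kl_loss)  # this is the term scaled by 1e-2 in train_vae's total loss
```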
/mdn.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import torch.optim as optim
5 | from torch.autograd import Variable
6 |
7 |
8 | class MDN(nn.Module):
9 |
10 | #define layers
11 | def __init__(self):
12 | super(MDN, self).__init__()
13 |
14 | self.feats_nch = 512
15 | self.hidden_size = 64
16 | self.nmix = 8
17 | self.nout = (self.hidden_size+1)*self.nmix
18 |
19 | #MDN Layers
20 | self.mdn_conv1 = nn.Conv2d(self.feats_nch, 384, 5, stride=1, padding=2)
21 | self.mdn_bn1 = nn.BatchNorm2d(384)
22 | self.mdn_conv2 = nn.Conv2d(384, 320, 5, stride=1, padding=2)
23 | self.mdn_bn2 = nn.BatchNorm2d(320)
24 | self.mdn_conv3 = nn.Conv2d(320, 288, 5, stride=1, padding=2)
25 | self.mdn_bn3 = nn.BatchNorm2d(288)
26 | self.mdn_conv4 = nn.Conv2d(288, 256, 5, stride=2, padding=2)
27 | self.mdn_bn4 = nn.BatchNorm2d(256)
28 | self.mdn_conv5 = nn.Conv2d(256, 128, 5, stride=1, padding=2)
29 | self.mdn_bn5 = nn.BatchNorm2d(128)
30 | self.mdn_conv6 = nn.Conv2d(128, 96, 5, stride=2, padding=2)
31 | self.mdn_bn6 = nn.BatchNorm2d(96)
32 | self.mdn_conv7 = nn.Conv2d(96, 64, 5, stride=2, padding=2)
33 | self.mdn_bn7 = nn.BatchNorm2d(64)
34 | self.mdn_dropout1 = nn.Dropout(p=.7)
35 | self.mdn_fc1 = nn.Linear(4*4*64, self.nout)
36 |
37 | #define forward pass
38 | def forward(self, feats):
39 | x = F.relu(self.mdn_conv1(feats))
40 | x = self.mdn_bn1(x)
41 | x = F.relu(self.mdn_conv2(x))
42 | x = self.mdn_bn2(x)
43 | x = F.relu(self.mdn_conv3(x))
44 | x = self.mdn_bn3(x)
45 | x = F.relu(self.mdn_conv4(x))
46 | x = self.mdn_bn4(x)
47 | x = F.relu(self.mdn_conv5(x))
48 | x = self.mdn_bn5(x)
49 | x = F.relu(self.mdn_conv6(x))
50 | x = self.mdn_bn6(x)
51 | x = F.relu(self.mdn_conv7(x))
52 | x = self.mdn_bn7(x)
53 | x = x.view(-1, 4*4*64)
54 | x = self.mdn_dropout1(x)
55 | x = self.mdn_fc1(x)
56 | return x
57 |
--------------------------------------------------------------------------------
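`MDN.forward` flattens to `4*4*64` before `mdn_fc1`, and the `nout = (hidden_size+1)*nmix = 520` outputs are later split by `get_gmm_coeffs` in main.py into 8 mixture means of dimension 64 plus 8 mixture logits. A small sanity check (sketch) that the conv strides above really take the 512x28x28 grey features down to a 4x4 map:

```
# out = floor((n + 2*pad - k) / stride) + 1 for each 5x5 conv with padding 2
def conv_out(n, k=5, stride=1, pad=2):
    return (n + 2 * pad - k) // stride + 1

n = 28
for stride in [1, 1, 1, 2, 1, 2, 2]:  # strides of mdn_conv1 .. mdn_conv7
    n = conv_out(n, stride=stride)
print(n)  # -> 4, so mdn_fc1's input is indeed 4*4*64
```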
/requirements.txt:
--------------------------------------------------------------------------------
1 | torchvision==0.2.0
2 | torch==0.3.0.post4
3 | tqdm==4.19.5
4 | opencv_python==3.3.0.10
5 | numpy==1.13.3
6 | visdom==0.1.6.5
7 |
--------------------------------------------------------------------------------
/vae.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import torch.optim as optim
5 | from torch.autograd import Variable
6 |
7 |
8 | class VAE(nn.Module):
9 |
10 | #define layers
11 | def __init__(self):
12 | super(VAE, self).__init__()
13 | self.hidden_size = 64
14 |
15 | #Encoder layers
16 | self.enc_conv1 = nn.Conv2d(2, 128, 5, stride=2, padding=2)
17 | self.enc_bn1 = nn.BatchNorm2d(128)
18 | self.enc_conv2 = nn.Conv2d(128, 256, 5, stride=2, padding=2)
19 | self.enc_bn2 = nn.BatchNorm2d(256)
20 | self.enc_conv3 = nn.Conv2d(256, 512, 5, stride=2, padding=2)
21 | self.enc_bn3 = nn.BatchNorm2d(512)
22 | self.enc_conv4 = nn.Conv2d(512, 1024, 3, stride=2, padding=1)
23 | self.enc_bn4 = nn.BatchNorm2d(1024)
24 | self.enc_fc1 = nn.Linear(4*4*1024, self.hidden_size*2)
25 | self.enc_dropout1 = nn.Dropout(p=.7)
26 |
27 | #Cond encoder layers
28 | self.cond_enc_conv1 = nn.Conv2d(1, 128, 5, stride=2, padding=2)
29 | self.cond_enc_bn1 = nn.BatchNorm2d(128)
30 | self.cond_enc_conv2 = nn.Conv2d(128, 256, 5, stride=2, padding=2)
31 | self.cond_enc_bn2 = nn.BatchNorm2d(256)
32 | self.cond_enc_conv3 = nn.Conv2d(256, 512, 5, stride=2, padding=2)
33 | self.cond_enc_bn3 = nn.BatchNorm2d(512)
34 | self.cond_enc_conv4 = nn.Conv2d(512, 1024, 3, stride=2, padding=1)
35 | self.cond_enc_bn4 = nn.BatchNorm2d(1024)
36 |
37 | #Decoder layers
38 | self.dec_upsamp1 = nn.Upsample(scale_factor=4, mode='bilinear')
39 | self.dec_conv1 = nn.Conv2d(1024+self.hidden_size, 512, 3, stride=1, padding=1)
40 | self.dec_bn1 = nn.BatchNorm2d(512)
41 | self.dec_upsamp2 = nn.Upsample(scale_factor=2, mode='bilinear')
42 | self.dec_conv2 = nn.Conv2d(512*2, 256, 5, stride=1, padding=2)
43 | self.dec_bn2 = nn.BatchNorm2d(256)
44 | self.dec_upsamp3 = nn.Upsample(scale_factor=2, mode='bilinear')
45 | self.dec_conv3 = nn.Conv2d(256*2, 128, 5, stride=1, padding=2)
46 | self.dec_bn3 = nn.BatchNorm2d(128)
47 | self.dec_upsamp4 = nn.Upsample(scale_factor=2, mode='bilinear')
48 | self.dec_conv4 = nn.Conv2d(128*2, 64, 5, stride=1, padding=2)
49 | self.dec_bn4 = nn.BatchNorm2d(64)
50 | self.dec_upsamp5 = nn.Upsample(scale_factor=2, mode='bilinear')
51 | self.dec_conv5 = nn.Conv2d(64, 2, 5, stride=1, padding=2)
52 |
53 | def encoder(self, x):
54 | x = F.relu(self.enc_conv1(x))
55 | x = self.enc_bn1(x)
56 | x = F.relu(self.enc_conv2(x))
57 | x = self.enc_bn2(x)
58 | x = F.relu(self.enc_conv3(x))
59 | x = self.enc_bn3(x)
60 | x = F.relu(self.enc_conv4(x))
61 | x = self.enc_bn4(x)
62 | x = x.view(-1, 4*4*1024)
63 | x = self.enc_dropout1(x)
64 | x = self.enc_fc1(x)
65 | mu = x[..., :self.hidden_size]
66 | logvar = x[..., self.hidden_size:]
67 | return mu, logvar
68 |
69 | def cond_encoder(self, x):
70 | x = F.relu(self.cond_enc_conv1(x))
71 | sc_feat32 = self.cond_enc_bn1(x)
72 | x = F.relu(self.cond_enc_conv2(sc_feat32))
73 | sc_feat16 = self.cond_enc_bn2(x)
74 | x = F.relu(self.cond_enc_conv3(sc_feat16))
75 | sc_feat8 = self.cond_enc_bn3(x)
76 | x = F.relu(self.cond_enc_conv4(sc_feat8))
77 | sc_feat4 = self.cond_enc_bn4(x)
78 | return sc_feat32, sc_feat16, sc_feat8, sc_feat4
79 |
80 | def decoder(self, z, sc_feat32, sc_feat16, sc_feat8, sc_feat4):
81 | x = z.view(-1, self.hidden_size, 1, 1)
82 | x = self.dec_upsamp1(x)
83 | x = torch.cat([x, sc_feat4], 1)
84 | x = F.relu(self.dec_conv1(x))
85 | x = self.dec_bn1(x)
86 | x = self.dec_upsamp2(x)
87 | x = torch.cat([x, sc_feat8], 1)
88 | x = F.relu(self.dec_conv2(x))
89 | x = self.dec_bn2(x)
90 | x = self.dec_upsamp3(x)
91 | x = torch.cat([x, sc_feat16], 1)
92 | x = F.relu(self.dec_conv3(x))
93 | x = self.dec_bn3(x)
94 | x = self.dec_upsamp4(x)
95 | x = torch.cat([x, sc_feat32], 1)
96 | x = F.relu(self.dec_conv4(x))
97 | x = self.dec_bn4(x)
98 | x = self.dec_upsamp5(x)
99 | x = F.tanh(self.dec_conv5(x))
100 | return x
101 |
102 | #define forward pass
103 | def forward(self, color, greylevel, z_in, is_train=True):
104 | sc_feat32, sc_feat16, sc_feat8, sc_feat4 = self.cond_encoder(greylevel)
105 | mu, logvar = self.encoder(color)
106 | if(is_train):
107 | stddev = torch.sqrt(torch.exp(logvar))
108 | eps = Variable(torch.randn(stddev.size())).cuda()
109 | z = torch.add(mu, torch.mul(eps, stddev))
110 | else:
111 | z = z_in
112 | color_out = self.decoder(z, sc_feat32, sc_feat16, sc_feat8, sc_feat4)
113 | return mu, logvar, color_out
114 |
--------------------------------------------------------------------------------
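A quick CPU shape check of the VAE (a sketch; training itself runs on CUDA). With `is_train=False` and an explicit `z`, the reparameterization branch and its `.cuda()` call are skipped, so this runs without a GPU, assuming it is executed from the repository root:

```
import torch
from torch.autograd import Variable
from vae import VAE

model = VAE()
color = Variable(torch.randn(4, 2, 64, 64))      # ab channels at 64x64
grey = Variable(torch.randn(4, 1, 64, 64))       # L channel at 64x64
z = Variable(torch.randn(4, model.hidden_size))  # latent code

mu, logvar, color_out = model(color, grey, z, is_train=False)
print(mu.size(), logvar.size(), color_out.size())
# expected: (4, 64), (4, 64) and (4, 2, 64, 64)
```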