├── .gitignore ├── README.md ├── colordata.py ├── get_data.sh ├── images └── pytorch_lfw.png ├── logger.py ├── main.py ├── mdn.py ├── requirements.txt └── vae.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | *.txt 3 | *.jpg 4 | *.png 5 | *.JPEG 6 | *meta 7 | *ckpt 8 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # pytorch_divcolor 2 | 3 | PyTorch implementation of Diverse Colorization -- Deshpande et al. "[Learning Diverse Image Colorization](https://arxiv.org/abs/1612.01958)" 4 | 5 | This code is tested with Python 3.5.2 and torch 0.3.0. Install the packages listed in requirements.txt. 6 | 7 | The TensorFlow implementation used in the paper is [divcolor](https://github.com/aditya12agd5/divcolor). 8 | 9 | 10 | Fetch the data with 11 | 12 | ``` 13 | bash get_data.sh 14 | ``` 15 | 16 | Run main.py to first train the VAE and MDN, and then generate results for LFW 17 | 18 | ``` 19 | python main.py lfw 20 | ``` 21 | 22 | If you use this code, please cite 23 | 24 | ``` 25 | @inproceedings{DeshpandeLDColor17, 26 | author = {Aditya Deshpande and Jiajun Lu and Mao-Chuang Yeh and Min Jin Chong and David Forsyth}, 27 | title = {Learning Diverse Image Colorization}, 28 | booktitle={Computer Vision and Pattern Recognition}, 29 | url={https://arxiv.org/abs/1612.01958}, 30 | year={2017} 31 | } 32 | ``` 33 | 34 | Some examples of diverse colorizations on LFW: 35 | 36 | <img src="images/pytorch_lfw.png" /> 37 | 38 |
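All hyperparameters have defaults wired into main.py's argument parser; a typical invocation that spells them out (the values shown are main.py's defaults):

```
python main.py lfw -g 0 -e 15 -b 32 -z 64 -n 4 -em 7 -m 8 -lg 100
```

Add -v to enable visdom logging, together with -s <server> and -p <port> pointing at your own visdom server (the default server in main.py is the author's machine).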

39 | 40 | -------------------------------------------------------------------------------- /colordata.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import glob 3 | import math 4 | import numpy as np 5 | import os 6 | 7 | from torch.utils.data import Dataset 8 | import torchvision.datasets as datasets 9 | import torchvision.transforms as transforms 10 | 11 | class colordata(Dataset): 12 | 13 | def __init__(self, out_directory, listdir=None, featslistdir=None, shape=(64, 64), \ 14 | subdir=False, ext='JPEG', outshape=(256, 256), split='train'): 15 | 16 | self.img_fns = [] 17 | self.feats_fns = [] 18 | 19 | with open('%s/list.%s.vae.txt' % (listdir, split), 'r') as ftr: 20 | for img_fn in ftr: 21 | self.img_fns.append(img_fn.strip('\n')) 22 | 23 | with open('%s/list.%s.txt' % (featslistdir, split), 'r') as ftr: 24 | for feats_fn in ftr: 25 | self.feats_fns.append(feats_fn.strip('\n')) 26 | 27 | self.img_num = min(len(self.img_fns), len(self.feats_fns)) 28 | self.shape = shape 29 | self.outshape = outshape 30 | self.out_directory = out_directory 31 | 32 | self.lossweights = None 33 | countbins = 1./np.load('data/zhang_weights/prior_probs.npy') 34 | binedges = np.load('data/zhang_weights/ab_quantize.npy').reshape(2, 313) 35 | lossweights = {} 36 | for i in range(313): 37 | if binedges[0, i] not in lossweights: 38 | lossweights[binedges[0, i]] = {} 39 | lossweights[binedges[0,i]][binedges[1,i]] = countbins[i] 40 | self.binedges = binedges 41 | self.lossweights = lossweights 42 | 43 | def __len__(self): 44 | return self.img_num 45 | 46 | def __getitem__(self, idx): 47 | color_ab = np.zeros((2, self.shape[0], self.shape[1]), dtype='f') 48 | weights = np.ones((2, self.shape[0], self.shape[1]), dtype='f') 49 | recon_const = np.zeros((1, self.shape[0], self.shape[1]), dtype='f') 50 | recon_const_outres = np.zeros((1, self.outshape[0], self.outshape[1]), dtype='f') 51 | greyfeats = np.zeros((512, 28, 28), dtype='f') 52 | 53 | img_large = cv2.imread(self.img_fns[idx]) 54 | if(self.shape is not None): 55 | img = cv2.resize(img_large, (self.shape[0], self.shape[1])) 56 | img_outres = cv2.resize(img_large, (self.outshape[0], self.outshape[1])) 57 | 58 | img_lab = cv2.cvtColor(img, cv2.COLOR_BGR2LAB) 59 | img_lab_outres = cv2.cvtColor(img_outres, cv2.COLOR_BGR2LAB) 60 | 61 | img_lab = ((img_lab*2.)/255.)-1. 62 | img_lab_outres = ((img_lab_outres*2.)/255.)-1. 63 | 64 | recon_const[0, :, :] = img_lab[..., 0] 65 | recon_const_outres[0, :, :] = img_lab_outres[..., 0] 66 | 67 | color_ab[0, :, :] = img_lab[..., 1].reshape(1, self.shape[0], self.shape[1]) 68 | color_ab[1, :, :] = img_lab[..., 2].reshape(1, self.shape[0], self.shape[1]) 69 | 70 | if(self.lossweights is not None): 71 | weights = self.__getweights__(color_ab) 72 | 73 | featobj = np.load(self.feats_fns[idx]) 74 | greyfeats[:,:,:] = featobj['arr_0'] 75 | 76 | return color_ab, recon_const, weights, recon_const_outres, greyfeats 77 | 78 | def __getweights__(self, img): 79 | img_vec = img.reshape(-1) 80 | img_vec = img_vec*128. 
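# ab values arrive scaled to [-1, 1]; multiplying by 128 returns them to the
# [-128, 128] domain of the quantized bins below. Each pixel is then snapped to
# its nearest (a, b) bin and weighted by the inverse bin frequency loaded from
# data/zhang_weights in __init__.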
81 | img_lossweights = np.zeros(img.shape, dtype='f') 82 | img_vec_a = img_vec[:np.prod(self.shape)] 83 | binedges_a = self.binedges[0,...].reshape(-1) 84 | binid_a = [binedges_a.flat[np.abs(binedges_a-v).argmin()] for v in img_vec_a] 85 | img_vec_b = img_vec[np.prod(self.shape):] 86 | binedges_b = self.binedges[1,...].reshape(-1) 87 | binid_b = [binedges_b.flat[np.abs(binedges_b-v).argmin()] for v in img_vec_b] 88 | binweights = np.array([self.lossweights[v1][v2] for v1,v2 in zip(binid_a, binid_b)]) 89 | img_lossweights[0, :, :] = binweights.reshape(self.shape[0], self.shape[1]) 90 | img_lossweights[1, :, :] = binweights.reshape(self.shape[0], self.shape[1]) 91 | return img_lossweights 92 | 93 | def saveoutput_gt(self, net_op, gt, prefix, batch_size, num_cols=8, net_recon_const=None): 94 | 95 | net_out_img = self.__tiledoutput__(net_op, batch_size, num_cols=num_cols, \ 96 | net_recon_const=net_recon_const) 97 | 98 | gt_out_img = self.__tiledoutput__(gt, batch_size, num_cols=num_cols, \ 99 | net_recon_const=net_recon_const) 100 | 101 | num_rows = np.int_(np.ceil((batch_size*1.)/num_cols)) 102 | border_img = 255*np.ones((num_rows*self.outshape[0], 128, 3), dtype='uint8') 103 | out_fn_pred = '%s/%s.png' % (self.out_directory, prefix) 104 | cv2.imwrite(out_fn_pred, np.concatenate((net_out_img, border_img, gt_out_img), axis=1)) 105 | 106 | def __tiledoutput__(self, net_op, batch_size, num_cols=8, net_recon_const=None): 107 | 108 | num_rows = np.int_(np.ceil((batch_size*1.)/num_cols)) 109 | out_img = np.zeros((num_rows*self.outshape[0], num_cols*self.outshape[1], 3), dtype='uint8') 110 | img_lab = np.zeros((self.outshape[0], self.outshape[1], 3), dtype='uint8') 111 | c = 0 112 | r = 0 113 | 114 | for i in range(batch_size): 115 | if(i % num_cols == 0 and i > 0): 116 | r = r + 1 117 | c = 0 118 | img_lab[..., 0] = self.__decodeimg__(net_recon_const[i, 0, :, :].reshape(\ 119 | self.outshape[0], self.outshape[1])) 120 | img_lab[..., 1] = self.__decodeimg__(net_op[i, 0, :, :].reshape(\ 121 | self.shape[0], self.shape[1])) 122 | img_lab[..., 2] = self.__decodeimg__(net_op[i, 1, :, :].reshape(\ 123 | self.shape[0], self.shape[1])) 124 | img_rgb = cv2.cvtColor(img_lab, cv2.COLOR_LAB2BGR) 125 | out_img[r*self.outshape[0]:(r+1)*self.outshape[0], \ 126 | c*self.outshape[1]:(c+1)*self.outshape[1], ...] = img_rgb 127 | c = c+1 128 | 129 | return out_img 130 | 131 | def __decodeimg__(self, img_enc): 132 | img_dec = (((img_enc+1.)*1.)/2.)*255. 133 | img_dec[img_dec < 0.] = 0. 134 | img_dec[img_dec > 255.] = 255. 
135 | return cv2.resize(np.uint8(img_dec), (self.outshape[0], self.outshape[1])) 136 | 137 | 138 | -------------------------------------------------------------------------------- /get_data.sh: -------------------------------------------------------------------------------- 1 | wget http://vision.cs.illinois.edu/projects/divcolor/data.zip 2 | unzip data.zip 3 | rm data.zip 4 | wget http://vis-www.cs.umass.edu/lfw/lfw-deepfunneled.tgz 5 | tar -xvzf lfw-deepfunneled.tgz 6 | mv lfw-deepfunneled data/lfw_images 7 | rm lfw-deepfunneled.tgz 8 | rm data/output/lfw/*/* 9 | rm data/output/lfw/*npy 10 | -------------------------------------------------------------------------------- /images/pytorch_lfw.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aditya12agd5/pytorch_divcolor/64854473e2071ba6628eae9450758c362449a864/images/pytorch_lfw.png -------------------------------------------------------------------------------- /logger.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import visdom 3 | import numpy as np 4 | 5 | class Logger(): 6 | 7 | def __init__(self, server, port, outdir): 8 | self.vis = visdom.Visdom(port=port, server=server) 9 | 10 | titles = ['VAE -- KL Div', 'VAE -- Weighted L2', 'VAE -- L2'] 11 | self.vis_plot_vae = [] 12 | for title in titles: 13 | self.vis_plot_vae.append(self.vis.line( 14 | X=np.array([0.], dtype='f'), 15 | Y=np.array([0.], dtype='f'), 16 | opts=dict( 17 | xlabel='Iteration',\ 18 | ylabel='Loss',\ 19 | title=title))) 20 | 21 | self.vis_plot_test_vae = self.vis.line( 22 | X=np.array([0.], dtype='f'), 23 | Y=np.array([0.], dtype='f'), 24 | opts=dict( 25 | xlabel='Iteration',\ 26 | ylabel='Test Loss',\ 27 | title='VAE Test Loss')) 28 | 29 | 30 | self.vis_plot_mdn = [] 31 | titles = ['MDN Loss', 'MDN -- L2'] 32 | for title in titles: 33 | self.vis_plot_mdn.append(self.vis.line( 34 | X=np.array([0.], dtype='f'), 35 | Y=np.array([0.], dtype='f'), 36 | opts=dict( 37 | xlabel='Iteration',\ 38 | ylabel='Loss',\ 39 | title=title))) 40 | 41 | self.fp_vae = open('%s/log_vae.txt' % outdir, 'w') 42 | self.fp_vae.write('Iteration; KLDiv; WeightedL2; L2;\n') 43 | self.fp_vae.flush() 44 | 45 | self.fp_test_vae = open('%s/log_test_vae.txt' % outdir, 'w') 46 | self.fp_test_vae.write('Iteration; Loss;\n') 47 | self.fp_test_vae.flush() 48 | 49 | self.fp_mdn = open('%s/log_mdn.txt' % outdir, 'w') 50 | self.fp_mdn.write('Iteration; Loss; L2 Loss;\n') 51 | self.fp_mdn.flush() 52 | 53 | def update_plot(self, x, losses, plot_type='vae'): 54 | 55 | if(plot_type == 'vae'): 56 | self.fp_vae.write('%f;' % x) 57 | for loss_i, loss in enumerate(losses): 58 | win = self.vis_plot_vae[loss_i] 59 | self.vis.updateTrace( 60 | X=np.array([x], dtype='f'), 61 | Y=np.array([loss], dtype='f'), 62 | win=win) 63 | self.fp_vae.write(' %f;' % loss) 64 | self.fp_vae.write('\n') 65 | self.fp_vae.flush() 66 | 67 | 68 | elif(plot_type == 'mdn'): 69 | for loss_i, loss in enumerate(losses): 70 | win = self.vis_plot_mdn[loss_i] 71 | self.vis.updateTrace( 72 | X=np.array([x], dtype='f'), 73 | Y=np.array([losses[loss_i]], dtype='f'), 74 | win=win) 75 | self.fp_mdn.write('%f; %f; %f;\n' % (x, losses[0], losses[1])) 76 | self.fp_mdn.flush() 77 | 78 | def update_test_plot(self, x, y): 79 | self.vis.updateTrace( 80 | X=np.array([x], dtype='f'), 81 | Y=np.array([y], dtype='f'), 82 | win=self.vis_plot_test_vae) 83 | self.fp_test_vae.write('%f; %f;\n' % (x, y)) 84 | self.fp_test_vae.flush() 85 | 
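# Usage sketch (hypothetical server value; assumes a reachable visdom server):
#   logger = Logger('http://localhost', 8097, 'data/output/lfw')
#   logger.update_plot(itr, [kl_loss, weighted_l2, l2], plot_type='vae')
#   logger.update_test_plot(epoch, test_l2)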
86 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import argparse 4 | import os 5 | import socket 6 | import sys 7 | import numpy as np 8 | 9 | from colordata import colordata 10 | from vae import VAE 11 | from mdn import MDN 12 | from logger import Logger 13 | 14 | import torch 15 | import torch.nn as nn 16 | import torch.nn.functional as F 17 | import torch.optim as optim 18 | from torch.autograd import Variable 19 | from torch.utils.data import DataLoader 20 | 21 | from tqdm import tqdm 22 | 23 | parser = argparse.ArgumentParser(description='PyTorch Diverse Colorization') 24 | 25 | parser.add_argument('dataset_key', help='Dataset') 26 | 27 | parser.add_argument('-g', '--gpu', type=int, default=0,\ 28 | help='gpu device id') 29 | 30 | parser.add_argument('-e', '--epochs', type=int, default=15,\ 31 | help='Number of epochs') 32 | 33 | parser.add_argument('-b', '--batchsize', type=int, default=32,\ 34 | help='Batch size') 35 | 36 | parser.add_argument('-z', '--hiddensize', type=int, default=64,\ 37 | help='Latent vector dimension') 38 | 39 | parser.add_argument('-n', '--nthreads', type=int, default=4,\ 40 | help='Data loader threads') 41 | 42 | parser.add_argument('-em', '--epochs_mdn', type=int, default=7,\ 43 | help='Number of epochs for MDN') 44 | 45 | parser.add_argument('-m', '--nmix', type=int, default=8,\ 46 | help='Number of diverse colorizations (or output GMM components)') 47 | 48 | parser.add_argument('-lg', '--logstep', type=int, default=100,\ 49 | help='Interval to log data') 50 | 51 | parser.add_argument('-v', '--visdom', action='store_true',\ 52 | help='Visdom visualization') 53 | 54 | parser.add_argument('-s', '--server', type=str, default='http://vision-gpu-4.cs.illinois.edu',\ 55 | help='Visdom server') 56 | 57 | parser.add_argument('-p', '--port_num', type=int, default=8097) 58 | 59 | 60 | args = parser.parse_args() 61 | 62 | if(args.visdom): 63 | import visdom 64 | 65 | def get_dirpaths(args): 66 | if(args.dataset_key == 'lfw'): 67 | out_dir = 'data/output/lfw/' 68 | listdir = 'data/imglist/lfw/' 69 | featslistdir = 'data/featslist/lfw/' 70 | else: 71 | raise NameError('[ERROR] Incorrect key: %s' % (args.dataset_key)) 72 | return out_dir, listdir, featslistdir 73 | 74 | def vae_loss(mu, logvar, pred, gt, lossweights, batchsize): 75 | kl_element = torch.add(torch.add(torch.add(mu.pow(2), logvar.exp()), -1), logvar.mul(-1)) 76 | kl_loss = torch.sum(kl_element).mul(.5) 77 | gt = gt.view(-1, 64*64*2) 78 | pred = pred.view(-1, 64*64*2) 79 | recon_element = torch.sqrt(torch.sum(torch.mul(torch.add(gt, pred.mul(-1)).pow(2), lossweights), 1)) 80 | recon_loss = torch.sum(recon_element).mul(1./(batchsize)) 81 | 82 | recon_element_l2 = torch.sqrt(torch.sum(torch.add(gt, pred.mul(-1)).pow(2), 1)) 83 | recon_loss_l2 = torch.sum(recon_element_l2).mul(1./(batchsize)) 84 | 85 | return kl_loss, recon_loss, recon_loss_l2 86 | 87 | def get_gmm_coeffs(gmm_params): 88 | gmm_mu = gmm_params[..., :args.hiddensize*args.nmix] 89 | gmm_mu = gmm_mu.contiguous() 90 | gmm_pi_activ = gmm_params[..., args.hiddensize*args.nmix:] 91 | gmm_pi_activ = gmm_pi_activ.contiguous() 92 | gmm_pi = F.softmax(gmm_pi_activ, dim=1) 93 | return gmm_mu, gmm_pi 94 | 95 | def mdn_loss(gmm_params, mu, stddev, batchsize): 96 | gmm_mu, gmm_pi = get_gmm_coeffs(gmm_params) 97 | eps = Variable(torch.randn(stddev.size()).normal_()).cuda() 98 | z = torch.add(mu, torch.mul(eps,
stddev)) 99 | z_flat = z.repeat(1, args.nmix) 100 | z_flat = z_flat.view(batchsize*args.nmix, args.hiddensize) 101 | gmm_mu_flat = gmm_mu.view(batchsize*args.nmix, args.hiddensize) 102 | dist_all = torch.sqrt(torch.sum(torch.add(z_flat, gmm_mu_flat.mul(-1)).pow(2).mul(50), 1)) 103 | dist_all = dist_all.view(batchsize, args.nmix) 104 | dist_min, selectids = torch.min(dist_all, 1) 105 | gmm_pi_min = torch.gather(gmm_pi, 1, selectids.view(-1, 1)) 106 | gmm_loss = torch.mean(torch.add(-1*torch.log(gmm_pi_min+1e-30), dist_min)) 107 | gmm_loss_l2 = torch.mean(dist_min) 108 | return gmm_loss, gmm_loss_l2 109 | 110 | 111 | def test_vae(model): 112 | 113 | model.train(False) 114 | 115 | out_dir, listdir, featslistdir = get_dirpaths(args) 116 | batchsize = args.batchsize 117 | hiddensize = args.hiddensize 118 | nmix = args.nmix 119 | 120 | data = colordata(\ 121 | os.path.join(out_dir, 'images'), \ 122 | listdir=listdir,\ 123 | featslistdir=featslistdir, 124 | split='test') 125 | 126 | nbatches = np.int_(np.floor(data.img_num/batchsize)) 127 | 128 | data_loader = DataLoader(dataset=data, num_workers=args.nthreads,\ 129 | batch_size=batchsize, shuffle=False, drop_last=True) 130 | 131 | test_loss = 0. 132 | for batch_idx, (batch, batch_recon_const, batch_weights, batch_recon_const_outres, _) in \ 133 | tqdm(enumerate(data_loader), total=nbatches): 134 | 135 | input_color = Variable(batch).cuda() 136 | lossweights = Variable(batch_weights).cuda() 137 | lossweights = lossweights.view(batchsize, -1) 138 | input_greylevel = Variable(batch_recon_const).cuda() 139 | z = Variable(torch.randn(batchsize, hiddensize)) 140 | 141 | mu, logvar, color_out = model(input_color, input_greylevel, z) 142 | _, _, recon_loss_l2 = \ 143 | vae_loss(mu, logvar, color_out, input_color, lossweights, batchsize) 144 | test_loss = test_loss + recon_loss_l2.data[0] 145 | 146 | test_loss = (test_loss*1.)/nbatches 147 | 148 | model.train(True) 149 | 150 | return test_loss 151 | 152 | def train_vae(logger=None): 153 | 154 | out_dir, listdir, featslistdir = get_dirpaths(args) 155 | batchsize = args.batchsize 156 | hiddensize = args.hiddensize 157 | nmix = args.nmix 158 | nepochs = args.epochs 159 | 160 | data = colordata(\ 161 | os.path.join(out_dir, 'images'), \ 162 | listdir=listdir,\ 163 | featslistdir=featslistdir, 164 | split='train') 165 | 166 | nbatches = np.int_(np.floor(data.img_num/batchsize)) 167 | 168 | data_loader = DataLoader(dataset=data, num_workers=args.nthreads,\ 169 | batch_size=batchsize, shuffle=True, drop_last=True) 170 | 171 | model = VAE() 172 | model.cuda() 173 | model.train(True) 174 | 175 | optimizer = optim.Adam(model.parameters(), lr=5e-5) 176 | 177 | itr_idx = 0 178 | for epochs in range(nepochs): 179 | train_loss = 0. 
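# Each colordata batch unpacks as: ab targets (B, 2, 64, 64), the 64x64 L channel,
# per-pixel rebalancing weights, the 256x256 L channel (used only for saving
# outputs), and grey-image features (unused by the VAE). The optimized loss is
# 1e-2 * KL + frequency-weighted reconstruction; the plain L2 term is logged only.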
180 | 181 | for batch_idx, (batch, batch_recon_const, batch_weights, batch_recon_const_outres, _) in \ 182 | tqdm(enumerate(data_loader), total=nbatches): 183 | 184 | input_color = Variable(batch).cuda() 185 | lossweights = Variable(batch_weights).cuda() 186 | lossweights = lossweights.view(batchsize, -1) 187 | input_greylevel = Variable(batch_recon_const).cuda() 188 | z = Variable(torch.randn(batchsize, hiddensize)) 189 | 190 | optimizer.zero_grad() 191 | mu, logvar, color_out = model(input_color, input_greylevel, z) 192 | kl_loss, recon_loss, recon_loss_l2 = \ 193 | vae_loss(mu, logvar, color_out, input_color, lossweights, batchsize) 194 | loss = kl_loss.mul(1e-2)+recon_loss 195 | recon_loss_l2.detach() 196 | loss.backward() 197 | optimizer.step() 198 | 199 | train_loss = train_loss + recon_loss_l2.data[0] 200 | 201 | if(logger): 202 | logger.update_plot(itr_idx, \ 203 | [kl_loss.data[0], recon_loss.data[0], recon_loss_l2.data[0]], \ 204 | plot_type='vae') 205 | itr_idx += 1 206 | 207 | if(batch_idx % args.logstep == 0): 208 | data.saveoutput_gt(color_out.cpu().data.numpy(), \ 209 | batch.numpy(), \ 210 | 'train_%05d_%05d' % (epochs, batch_idx), \ 211 | batchsize, \ 212 | net_recon_const=batch_recon_const_outres.numpy()) 213 | 214 | train_loss = (train_loss*1.)/(nbatches) 215 | print('[DEBUG] VAE Train Loss, epoch %d has loss %f' % (epochs, train_loss)) 216 | 217 | test_loss = test_vae(model) 218 | if(logger): 219 | logger.update_test_plot(epochs, test_loss) 220 | print('[DEBUG] VAE Test Loss, epoch %d has loss %f' % (epochs, test_loss)) 221 | 222 | torch.save(model.state_dict(), '%s/models/model_vae.pth' % (out_dir)) 223 | 224 | def train_mdn(logger=None): 225 | out_dir, listdir, featslistdir = get_dirpaths(args) 226 | batchsize = args.batchsize 227 | hiddensize = args.hiddensize 228 | nmix = args.nmix 229 | nepochs = args.epochs_mdn 230 | 231 | data = colordata(\ 232 | os.path.join(out_dir, 'images'), \ 233 | listdir=listdir,\ 234 | featslistdir=featslistdir, 235 | split='train') 236 | 237 | nbatches = np.int_(np.floor(data.img_num/batchsize)) 238 | 239 | data_loader = DataLoader(dataset=data, num_workers=args.nthreads,\ 240 | batch_size=batchsize, shuffle=True, drop_last=True) 241 | 242 | model_vae = VAE() 243 | model_vae.cuda() 244 | model_vae.load_state_dict(torch.load('%s/models/model_vae.pth' % (out_dir))) 245 | model_vae.train(False) 246 | 247 | model_mdn = MDN() 248 | model_mdn.cuda() 249 | model_mdn.train(True) 250 | 251 | optimizer = optim.Adam(model_mdn.parameters(), lr=1e-3) 252 | 253 | itr_idx = 0 254 | for epochs_mdn in range(nepochs): 255 | train_loss = 0. 
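# The pretrained VAE is held fixed here (only model_mdn's parameters are
# optimized): its encoder maps each color image to a posterior (mu, stddev),
# and the MDN, seeing only grey-image features, is trained so that its closest
# mixture mean lands near a sample z from that posterior while the mixture
# weight of that component is pushed up.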
256 | 257 | for batch_idx, (batch, batch_recon_const, batch_weights, _, batch_feats) in \ 258 | tqdm(enumerate(data_loader), total=nbatches): 259 | 260 | input_color = Variable(batch).cuda() 261 | input_greylevel = Variable(batch_recon_const).cuda() 262 | input_feats = Variable(batch_feats).cuda() 263 | z = Variable(torch.randn(batchsize, hiddensize)) 264 | 265 | optimizer.zero_grad() 266 | 267 | mu, logvar, _ = model_vae(input_color, input_greylevel, z) 268 | mdn_gmm_params = model_mdn(input_feats) 269 | 270 | loss, loss_l2 = mdn_loss(mdn_gmm_params, mu, torch.sqrt(torch.exp(logvar)), batchsize) 271 | loss.backward() 272 | 273 | optimizer.step() 274 | 275 | train_loss = train_loss + loss.data[0] 276 | 277 | if(logger): 278 | logger.update_plot(itr_idx, [loss.data[0], loss_l2.data[0]], plot_type='mdn') 279 | itr_idx += 1 280 | 281 | train_loss = (train_loss*1.)/(nbatches) 282 | print('[DEBUG] Training MDN, epoch %d has loss %f' % (epochs_mdn, train_loss)) 283 | torch.save(model_mdn.state_dict(), '%s/models/model_mdn.pth' % (out_dir)) 284 | 285 | def divcolor(): 286 | out_dir, listdir, featslistdir = get_dirpaths(args) 287 | batchsize = args.batchsize 288 | hiddensize = args.hiddensize 289 | nmix = args.nmix 290 | 291 | data = colordata(\ 292 | os.path.join(out_dir, 'images'), \ 293 | listdir=listdir,\ 294 | featslistdir=featslistdir, 295 | split='test') 296 | 297 | nbatches = np.int_(np.floor(data.img_num/batchsize)) 298 | 299 | data_loader = DataLoader(dataset=data, num_workers=args.nthreads,\ 300 | batch_size=batchsize, shuffle=True, drop_last=True) 301 | 302 | model_vae = VAE() 303 | model_vae.cuda() 304 | model_vae.load_state_dict(torch.load('%s/models/model_vae.pth' % (out_dir))) 305 | model_vae.train(False) 306 | 307 | model_mdn = MDN() 308 | model_mdn.cuda() 309 | model_mdn.load_state_dict(torch.load('%s/models/model_mdn.pth' % (out_dir))) 310 | model_mdn.train(False) 311 | 312 | for batch_idx, (batch, batch_recon_const, batch_weights, \ 313 | batch_recon_const_outres, batch_feats) in \ 314 | tqdm(enumerate(data_loader), total=nbatches): 315 | 316 | input_feats = Variable(batch_feats).cuda() 317 | 318 | mdn_gmm_params = model_mdn(input_feats) 319 | gmm_mu, gmm_pi = get_gmm_coeffs(mdn_gmm_params) 320 | gmm_pi = gmm_pi.view(-1, 1) 321 | gmm_mu = gmm_mu.view(-1, hiddensize) 322 | 323 | for j in range(batchsize): 324 | batch_j = np.tile(batch[j, ...].numpy(), (batchsize, 1, 1, 1)) 325 | batch_recon_const_j = np.tile(batch_recon_const[j, ...].numpy(), (batchsize, 1, 1, 1)) 326 | batch_recon_const_outres_j = np.tile(batch_recon_const_outres[j, ...].numpy(), \ 327 | (batchsize, 1, 1, 1)) 328 | 329 | input_color = Variable(torch.from_numpy(batch_j)).cuda() 330 | input_greylevel = Variable(torch.from_numpy(batch_recon_const_j)).cuda() 331 | 332 | curr_mu = gmm_mu[j*nmix:(j+1)*nmix, :] 333 | orderid = np.argsort(\ 334 | gmm_pi[j*nmix:(j+1)*nmix, 0].cpu().data.numpy().reshape(-1)) 335 | 336 | z = curr_mu.repeat(np.int((batchsize*1.)/nmix), 1) 337 | 338 | _, _, color_out = model_vae(input_color, input_greylevel, z, is_train=False) 339 | 340 | data.saveoutput_gt(color_out.cpu().data.numpy()[orderid, ...], \ 341 | batch_j[orderid, ...], \ 342 | 'divcolor_%05d_%05d' % (batch_idx, j), \ 343 | nmix, \ 344 | net_recon_const=batch_recon_const_outres_j[orderid, ...]) 345 | 346 | if __name__ == '__main__': 347 | 348 | os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu) 349 | 350 | logger = None 351 | if(args.visdom): 352 | outdir, _, _ = get_dirpaths(args) 353 | logger = Logger(args.server, 
args.port_num, outdir) 354 | 355 | train_vae(logger=logger) 356 | train_mdn(logger=logger) 357 | divcolor() 358 | -------------------------------------------------------------------------------- /mdn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import torch.optim as optim 5 | from torch.autograd import Variable 6 | 7 | 8 | class MDN(nn.Module): 9 | 10 | #define layers 11 | def __init__(self): 12 | super(MDN, self).__init__() 13 | 14 | self.feats_nch = 512 15 | self.hidden_size = 64 16 | self.nmix = 8 17 | self.nout = (self.hidden_size+1)*self.nmix 18 | 19 | #MDN Layers 20 | self.mdn_conv1 = nn.Conv2d(self.feats_nch, 384, 5, stride=1, padding=2) 21 | self.mdn_bn1 = nn.BatchNorm2d(384) 22 | self.mdn_conv2 = nn.Conv2d(384, 320, 5, stride=1, padding=2) 23 | self.mdn_bn2 = nn.BatchNorm2d(320) 24 | self.mdn_conv3 = nn.Conv2d(320, 288, 5, stride=1, padding=2) 25 | self.mdn_bn3 = nn.BatchNorm2d(288) 26 | self.mdn_conv4 = nn.Conv2d(288, 256, 5, stride=2, padding=2) 27 | self.mdn_bn4 = nn.BatchNorm2d(256) 28 | self.mdn_conv5 = nn.Conv2d(256, 128, 5, stride=1, padding=2) 29 | self.mdn_bn5 = nn.BatchNorm2d(128) 30 | self.mdn_conv6 = nn.Conv2d(128, 96, 5, stride=2, padding=2) 31 | self.mdn_bn6 = nn.BatchNorm2d(96) 32 | self.mdn_conv7 = nn.Conv2d(96, 64, 5, stride=2, padding=2) 33 | self.mdn_bn7 = nn.BatchNorm2d(64) 34 | self.mdn_dropout1 = nn.Dropout(p=.7) 35 | self.mdn_fc1 = nn.Linear(4*4*64, self.nout) 36 | 37 | #define forward pass 38 | def forward(self, feats): 39 | x = F.relu(self.mdn_conv1(feats)) 40 | x = self.mdn_bn1(x) 41 | x = F.relu(self.mdn_conv2(x)) 42 | x = self.mdn_bn2(x) 43 | x = F.relu(self.mdn_conv3(x)) 44 | x = self.mdn_bn3(x) 45 | x = F.relu(self.mdn_conv4(x)) 46 | x = self.mdn_bn4(x) 47 | x = F.relu(self.mdn_conv5(x)) 48 | x = self.mdn_bn5(x) 49 | x = F.relu(self.mdn_conv6(x)) 50 | x = self.mdn_bn6(x) 51 | x = F.relu(self.mdn_conv7(x)) 52 | x = self.mdn_bn7(x) 53 | x = x.view(-1, 4*4*64) 54 | x = self.mdn_dropout1(x) 55 | x = self.mdn_fc1(x) 56 | return x 57 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torchvision==0.2.0 2 | torch==0.3.0.post4 3 | tqdm==4.19.5 4 | opencv_python==3.3.0.10 5 | numpy==1.13.3 6 | visdom==0.1.6.5 7 | -------------------------------------------------------------------------------- /vae.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import torch.optim as optim 5 | from torch.autograd import Variable 6 | 7 | 8 | class VAE(nn.Module): 9 | 10 | #define layers 11 | def __init__(self): 12 | super(VAE, self).__init__() 13 | self.hidden_size = 64 14 | 15 | #Encoder layers 16 | self.enc_conv1 = nn.Conv2d(2, 128, 5, stride=2, padding=2) 17 | self.enc_bn1 = nn.BatchNorm2d(128) 18 | self.enc_conv2 = nn.Conv2d(128, 256, 5, stride=2, padding=2) 19 | self.enc_bn2 = nn.BatchNorm2d(256) 20 | self.enc_conv3 = nn.Conv2d(256, 512, 5, stride=2, padding=2) 21 | self.enc_bn3 = nn.BatchNorm2d(512) 22 | self.enc_conv4 = nn.Conv2d(512, 1024, 3, stride=2, padding=1) 23 | self.enc_bn4 = nn.BatchNorm2d(1024) 24 | self.enc_fc1 = nn.Linear(4*4*1024, self.hidden_size*2) 25 | self.enc_dropout1 = nn.Dropout(p=.7) 26 | 27 | #Cond encoder layers 28 | self.cond_enc_conv1 = nn.Conv2d(1, 128, 5, stride=2, 
padding=2) 29 | self.cond_enc_bn1 = nn.BatchNorm2d(128) 30 | self.cond_enc_conv2 = nn.Conv2d(128, 256, 5, stride=2, padding=2) 31 | self.cond_enc_bn2 = nn.BatchNorm2d(256) 32 | self.cond_enc_conv3 = nn.Conv2d(256, 512, 5, stride=2, padding=2) 33 | self.cond_enc_bn3 = nn.BatchNorm2d(512) 34 | self.cond_enc_conv4 = nn.Conv2d(512, 1024, 3, stride=2, padding=1) 35 | self.cond_enc_bn4 = nn.BatchNorm2d(1024) 36 | 37 | #Decoder layers 38 | self.dec_upsamp1 = nn.Upsample(scale_factor=4, mode='bilinear') 39 | self.dec_conv1 = nn.Conv2d(1024+self.hidden_size, 512, 3, stride=1, padding=1) 40 | self.dec_bn1 = nn.BatchNorm2d(512) 41 | self.dec_upsamp2 = nn.Upsample(scale_factor=2, mode='bilinear') 42 | self.dec_conv2 = nn.Conv2d(512*2, 256, 5, stride=1, padding=2) 43 | self.dec_bn2 = nn.BatchNorm2d(256) 44 | self.dec_upsamp3 = nn.Upsample(scale_factor=2, mode='bilinear') 45 | self.dec_conv3 = nn.Conv2d(256*2, 128, 5, stride=1, padding=2) 46 | self.dec_bn3 = nn.BatchNorm2d(128) 47 | self.dec_upsamp4 = nn.Upsample(scale_factor=2, mode='bilinear') 48 | self.dec_conv4 = nn.Conv2d(128*2, 64, 5, stride=1, padding=2) 49 | self.dec_bn4 = nn.BatchNorm2d(64) 50 | self.dec_upsamp5 = nn.Upsample(scale_factor=2, mode='bilinear') 51 | self.dec_conv5 = nn.Conv2d(64, 2, 5, stride=1, padding=2) 52 | 53 | def encoder(self, x): 54 | x = F.relu(self.enc_conv1(x)) 55 | x = self.enc_bn1(x) 56 | x = F.relu(self.enc_conv2(x)) 57 | x = self.enc_bn2(x) 58 | x = F.relu(self.enc_conv3(x)) 59 | x = self.enc_bn3(x) 60 | x = F.relu(self.enc_conv4(x)) 61 | x = self.enc_bn4(x) 62 | x = x.view(-1, 4*4*1024) 63 | x = self.enc_dropout1(x) 64 | x = self.enc_fc1(x) 65 | mu = x[..., :self.hidden_size] 66 | logvar = x[..., self.hidden_size:] 67 | return mu, logvar 68 | 69 | def cond_encoder(self, x): 70 | x = F.relu(self.cond_enc_conv1(x)) 71 | sc_feat32 = self.cond_enc_bn1(x) 72 | x = F.relu(self.cond_enc_conv2(sc_feat32)) 73 | sc_feat16 = self.cond_enc_bn2(x) 74 | x = F.relu(self.cond_enc_conv3(sc_feat16)) 75 | sc_feat8 = self.cond_enc_bn3(x) 76 | x = F.relu(self.cond_enc_conv4(sc_feat8)) 77 | sc_feat4 = self.cond_enc_bn4(x) 78 | return sc_feat32, sc_feat16, sc_feat8, sc_feat4 79 | 80 | def decoder(self, z, sc_feat32, sc_feat16, sc_feat8, sc_feat4): 81 | x = z.view(-1, self.hidden_size, 1, 1) 82 | x = self.dec_upsamp1(x) 83 | x = torch.cat([x, sc_feat4], 1) 84 | x = F.relu(self.dec_conv1(x)) 85 | x = self.dec_bn1(x) 86 | x = self.dec_upsamp2(x) 87 | x = torch.cat([x, sc_feat8], 1) 88 | x = F.relu(self.dec_conv2(x)) 89 | x = self.dec_bn2(x) 90 | x = self.dec_upsamp3(x) 91 | x = torch.cat([x, sc_feat16], 1) 92 | x = F.relu(self.dec_conv3(x)) 93 | x = self.dec_bn3(x) 94 | x = self.dec_upsamp4(x) 95 | x = torch.cat([x, sc_feat32], 1) 96 | x = F.relu(self.dec_conv4(x)) 97 | x = self.dec_bn4(x) 98 | x = self.dec_upsamp5(x) 99 | x = F.tanh(self.dec_conv5(x)) 100 | return x 101 | 102 | #define forward pass 103 | def forward(self, color, greylevel, z_in, is_train=True): 104 | sc_feat32, sc_feat16, sc_feat8, sc_feat4 = self.cond_encoder(greylevel) 105 | mu, logvar = self.encoder(color) 106 | if(is_train == True): 107 | stddev = torch.sqrt(torch.exp(logvar)) 108 | eps = Variable(torch.randn(stddev.size()).normal_()).cuda() 109 | z = torch.add(mu, torch.mul(eps, stddev)) 110 | else: 111 | z = z_in 112 | color_out = self.decoder(z, sc_feat32, sc_feat16, sc_feat8, sc_feat4) 113 | return mu, logvar, color_out 114 | --------------------------------------------------------------------------------
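For orientation, a minimal shape check of the two networks defined above (a sketch, not part of the repository; it assumes a CUDA device, since VAE.forward moves its sampled noise to the GPU, and uses the torch-0.3-style Variable API pinned in requirements.txt):

```
import torch
from torch.autograd import Variable
from vae import VAE
from mdn import MDN

vae, mdn = VAE().cuda(), MDN().cuda()
color = Variable(torch.randn(4, 2, 64, 64)).cuda()    # ab channels at 64x64
grey = Variable(torch.randn(4, 1, 64, 64)).cuda()     # L channel at 64x64
feats = Variable(torch.randn(4, 512, 28, 28)).cuda()  # grey-image features
z = Variable(torch.randn(4, 64)).cuda()               # only read when is_train=False

mu, logvar, color_out = vae(color, grey, z)           # training-mode forward pass
print(mu.size(), logvar.size(), color_out.size())     # (4, 64) (4, 64) (4, 2, 64, 64)

gmm_params = mdn(feats)                               # means and mixture logits, flattened
print(gmm_params.size())                              # (4, 520) = (4, (64 + 1) * 8)
```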