├── ops
│   ├── __init__.py
│   ├── loss_added.py
│   ├── histogram_matching.py
│   └── spectral_norm.py
├── tools
│   ├── __init__.py
│   ├── plot.py
│   └── inception_score.py
├── data_loaders
│   ├── __init__.py
│   └── makeup.py
├── README.md
├── test.sh
├── LICENSE
├── visualize.py
├── dataloder.py
├── config.py
├── train.py
├── test.py
├── net.py
├── solver_cycle.py
└── solver_makeup.py

/ops/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/tools/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/data_loaders/__init__.py:
--------------------------------------------------------------------------------
1 | from .makeup import MAKEUP
2 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # BeautyGAN
2 | 
3 | Official implementation of the ACM MM 2018 paper: "BeautyGAN: Instance-level Facial Makeup Transfer with Deep Generative Adversarial Network"
4 | 
5 | The dataset can be found on the project page: http://colalab.org/projects/BeautyGAN
6 | 
7 | 
8 | ## Still under construction
--------------------------------------------------------------------------------
/test.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | 
3 | 
4 | # video! Works very poorly on VBT, not sure why
5 | 
6 | --img_size 361 --cls_list wild_before,RE_ORG --batch_size 16 --test_model 66_2520
7 | 
8 | --img_size 361 --cls_list A_before,RE_ORG --batch_size 1 --test_model 66_2520
9 | 
10 | 
11 | --img_size 256 --cls_list wild_256,RE_REF --batch_size 1 --test_model 66_2520
12 | 
13 | 
14 | # Test the results on images that already have makeup
15 | 
16 | --img_size 256 --cls_list RE_REF,RE_ORI --batch_size 1 --test_model 66_2520
17 | 
18 | --img_size 256 --cls_list RE_ORG,wild_256 --batch_size 1 --test_model 66_2520
19 | 
20 | 
21 | # new
22 | 
23 | --task_name default --cls_list wild_256,RE_REF --batch_size 1 --test_model 26_2520
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2019
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/visualize.py:
--------------------------------------------------------------------------------
1 | import os
2 | 
3 | from PIL import Image
4 | from easydict import EasyDict as edict
5 | from torch.backends import cudnn
6 | from config import config, default, dataset_config
7 | 
8 | from solvers import *
9 | from data_loaders import *
10 | 
11 | default.network = 'MULTICYCLEGAN'
12 | #default.network = 'STARGAN'
13 | default.dataset_choice = ['MAKEUP']
14 | #default.dataset_choice = ['CELEBA']
15 | default.model_base = 'RES'
16 | default.loss_chosen = 'normal'
17 | default.gpu_ids = [0,1,2]
18 | 
19 | config_default = config
20 | 
21 | 
22 | def train_net():
23 |     # enable cudnn
24 |     cudnn.benchmark = True
25 | 
26 |     # get the DataLoader
27 |     data_loaders = eval("get_loader_" + config.network)(default.dataset_choice, dataset_config, config, mode="test")
28 | 
29 |     # get the solver
30 |     solver = eval("Solver_" + config.network + "_VIS")(default.dataset_choice, data_loaders, config, dataset_config)
31 |     solver.visualize()
32 | 
33 | if __name__ == '__main__':
34 |     print("Call with args:")
35 |     print(default)
36 |     config = config_default[default.network]
37 |     config.network = default.network
38 |     config.model_base = default.model_base
39 |     config.gpu_ids = default.gpu_ids
40 | 
41 |     # Create the directories if not exist
42 |     if not os.path.exists(config.log_path):
43 |         os.makedirs(config.log_path)
44 |     if not os.path.exists(config.vis_path):
45 |         os.makedirs(config.vis_path)
46 |     if not os.path.exists(config.snapshot_path):
47 |         os.makedirs(config.snapshot_path)
48 |     if not os.path.exists(config.data_path):
49 |         print("No datapath!!")
50 | 
51 |     train_net()
52 | 
--------------------------------------------------------------------------------
/tools/plot.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | 
3 | import matplotlib
4 | matplotlib.use('Agg')
5 | import matplotlib.pyplot as plt
6 | 
7 | import collections
8 | import time
9 | #import pickle
10 | 
11 | _since_beginning = collections.defaultdict(lambda: {})
12 | _since_last_flush = collections.defaultdict(lambda: {})
13 | 
14 | _iter = [0]
15 | def tick():
16 |     _iter[0] += 1
17 | 
18 | def plot(name, value):
19 |     _since_last_flush[name][_iter[0]] = value
20 | 
21 | def flush(task_name):
22 |     for name, vals in _since_last_flush.items():
23 |         _since_beginning[name].update(vals)
24 | 
25 |         x_vals = np.sort(list(_since_beginning[name].keys()))
26 |         y_vals = [_since_beginning[name][x] for x in x_vals]
27 | 
28 |         plt.clf()
29 |         plt.plot(x_vals, y_vals)
30 |         plt.xlabel('iteration')
31 |         plt.ylabel(name)
32 |         plt.savefig(name.replace(' ', '_') + "_" + task_name + ".png")
33 | 
34 |     # clear the per-flush buffer so old values are not re-plotted next flush
35 |     _since_last_flush.clear()
36 | 
37 |     #with open('log.pkl', 'wb') as f:
38 |     #    pickle.dump(dict(_since_beginning), f, pickle.HIGHEST_PROTOCOL)
--------------------------------------------------------------------------------
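# A minimal sketch of how tools/plot.py is driven from a training loop
# (this mirrors the calls made in solver_cycle.py; `d_loss_real` and
# `task_name` stand in for the solver's own variables):
#
#   import tools.plot as plot_fig
#   for step in range(total_steps):
#       plot_fig.plot('D-A-loss_real', d_loss_real.item())
#       if step % 100 == 99:
#           plot_fig.flush(task_name)   # writes <name>_<task_name>.png per curve
#       plot_fig.tick()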
/ops/loss_added.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.autograd import Variable
4 | 
5 | class GANLoss(nn.Module):
6 |     def __init__(self, use_lsgan=True, target_real_label=1.0, target_fake_label=0.0,
7 |                  tensor=torch.FloatTensor):
8 |         super(GANLoss, self).__init__()
9 |         self.real_label = target_real_label
10 |         self.fake_label = target_fake_label
11 |         self.real_label_var = None
12 |         self.fake_label_var = None
13 |         self.Tensor = tensor
14 |         if use_lsgan:
15 |             self.loss = nn.MSELoss()
16 |         else:
17 |             self.loss = nn.BCELoss()
18 | 
19 |     def get_target_tensor(self, input, target_is_real):
20 |         target_tensor = None
21 |         if target_is_real:
22 |             create_label = ((self.real_label_var is None) or
23 |                             (self.real_label_var.numel() != input.numel()))
24 |             if create_label:
25 |                 real_tensor = self.Tensor(input.size()).fill_(self.real_label)
26 |                 self.real_label_var = Variable(real_tensor, requires_grad=False)
27 |             target_tensor = self.real_label_var
28 |         else:
29 |             create_label = ((self.fake_label_var is None) or
30 |                             (self.fake_label_var.numel() != input.numel()))
31 |             if create_label:
32 |                 fake_tensor = self.Tensor(input.size()).fill_(self.fake_label)
33 |                 self.fake_label_var = Variable(fake_tensor, requires_grad=False)
34 |             target_tensor = self.fake_label_var
35 |         return target_tensor
36 | 
37 |     def __call__(self, input, target_is_real):
38 |         target_tensor = self.get_target_tensor(input, target_is_real)
39 |         return self.loss(input, target_tensor)
--------------------------------------------------------------------------------
/ops/histogram_matching.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import copy
4 | 
5 | def cal_hist(image):
6 |     """
7 |     calculate the cumulative histogram for each channel in the list
8 |     """
9 |     hists = []
10 |     for i in range(0, 3):
11 |         channel = image[i]
12 |         # channel = image[i, :, :]
13 |         channel = torch.from_numpy(channel)
14 |         # hist, _ = np.histogram(channel, bins=256, range=(0,255))
15 |         hist = torch.histc(channel, bins=256, min=0, max=256)
16 |         hist = hist.numpy()
17 |         # refHist=hist.view(256,1)
18 |         total = hist.sum()
19 |         pdf = [v / total for v in hist]
20 |         for j in range(1, 256):
21 |             pdf[j] = pdf[j - 1] + pdf[j]
22 |         hists.append(pdf)
23 |     return hists
24 | 
25 | 
26 | def cal_trans(ref, adj):
27 |     """
28 |     calculate the transfer function,
29 |     following the Wikipedia article on histogram matching
30 |     """
31 |     table = list(range(0, 256))
32 |     for i in list(range(1, 256)):
33 |         for j in list(range(1, 256)):
34 |             if ref[i] >= adj[j - 1] and ref[i] <= adj[j]:
35 |                 table[i] = j
36 |                 break
37 |     table[255] = 255
38 |     return table
39 | 
40 | 
41 | def histogram_matching(dstImg, refImg, index):
42 |     """
43 |     perform histogram matching:
44 |     dstImg is transformed to have the same histogram as refImg
45 |     index[0], index[1]: the indices of the pixels to transform in dstImg
46 |     index[2], index[3]: the indices of the pixels used to compute the histogram of refImg
47 |     """
48 |     index = [x.cpu().numpy() for x in index]
49 |     dstImg = dstImg.detach().cpu().numpy()
50 |     refImg = refImg.detach().cpu().numpy()
51 |     dst_align = [dstImg[i, index[0], index[1]] for i in range(0, 3)]
52 |     ref_align = [refImg[i, index[2], index[3]] for i in range(0, 3)]
53 |     hist_ref = cal_hist(ref_align)
54 |     hist_dst = cal_hist(dst_align)
55 |     tables = [cal_trans(hist_dst[i], hist_ref[i]) for i in range(0, 3)]
56 | 
57 |     mid = copy.deepcopy(dst_align)
58 |     for i in range(0,
3): 59 | for k in range(0, len(index[0])): 60 | dst_align[i][k] = tables[i][int(mid[i][k])] 61 | 62 | for i in range(0, 3): 63 | dstImg[i, index[0], index[1]] = dst_align[i] 64 | 65 | dstImg = torch.FloatTensor(dstImg).cuda() 66 | return dstImg 67 | -------------------------------------------------------------------------------- /dataloder.py: -------------------------------------------------------------------------------- 1 | from torchvision import transforms 2 | from torch.utils.data import DataLoader 3 | from torchvision.datasets import ImageFolder 4 | from data_loaders.makeup import MAKEUP 5 | import torch 6 | import numpy as np 7 | import PIL 8 | 9 | def ToTensor(pic): 10 | # handle PIL Image 11 | if pic.mode == 'I': 12 | img = torch.from_numpy(np.array(pic, np.int32, copy=False)) 13 | elif pic.mode == 'I;16': 14 | img = torch.from_numpy(np.array(pic, np.int16, copy=False)) 15 | else: 16 | img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes())) 17 | # PIL image mode: 1, L, P, I, F, RGB, YCbCr, RGBA, CMYK 18 | if pic.mode == 'YCbCr': 19 | nchannel = 3 20 | elif pic.mode == 'I;16': 21 | nchannel = 1 22 | else: 23 | nchannel = len(pic.mode) 24 | img = img.view(pic.size[1], pic.size[0], nchannel) 25 | # put it from HWC to CHW format 26 | # yikes, this transpose takes 80% of the loading time/CPU 27 | img = img.transpose(0, 1).transpose(0, 2).contiguous() 28 | if isinstance(img, torch.ByteTensor): 29 | return img.float() 30 | else: 31 | return img 32 | 33 | def get_loader(data_config, config, mode="train"): 34 | # return the DataLoader 35 | dataset_name = data_config.name 36 | transform = transforms.Compose([ 37 | transforms.Resize(config.img_size), 38 | transforms.ToTensor(), 39 | transforms.Normalize([0.5,0.5,0.5],[0.5,0.5,0.5])]) 40 | transform_mask = transforms.Compose([ 41 | transforms.Resize(config.img_size, interpolation=PIL.Image.NEAREST), 42 | ToTensor]) 43 | print(config.data_path) 44 | #""" 45 | if mode=="train": 46 | dataset_train = eval(dataset_name)(data_config.dataset_path, transform=transform, mode= "train",\ 47 | transform_mask=transform_mask, cls_list = config.cls_list) 48 | dataset_test = eval(dataset_name)(data_config.dataset_path, transform=transform, mode= "test",\ 49 | transform_mask=transform_mask, cls_list = config.cls_list) 50 | 51 | #""" 52 | data_loader_train = DataLoader(dataset=dataset_train, 53 | batch_size=config.batch_size, 54 | shuffle=True) 55 | 56 | if mode=="test": 57 | data_loader_train = None 58 | dataset_test = eval(dataset_name)(data_config.dataset_path, transform=transform, mode= "test",\ 59 | transform_mask =transform_mask, cls_list = config.cls_list) 60 | 61 | 62 | 63 | 64 | data_loader_test = DataLoader(dataset=dataset_test, 65 | batch_size=1, 66 | shuffle=False) 67 | 68 | return [data_loader_train, data_loader_test] -------------------------------------------------------------------------------- /ops/spectral_norm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn import Parameter 3 | 4 | def l2normalize(v, eps=1e-12): 5 | return v / (v.norm() + eps) 6 | 7 | class SpectralNorm(object): 8 | def __init__(self): 9 | self.name = "weight" 10 | #print(self.name) 11 | self.power_iterations = 1 12 | 13 | def compute_weight(self, module): 14 | u = getattr(module, self.name + "_u") 15 | v = getattr(module, self.name + "_v") 16 | w = getattr(module, self.name + "_bar") 17 | 18 | height = w.data.shape[0] 19 | for _ in range(self.power_iterations): 20 | v.data = 
l2normalize(torch.mv(torch.t(w.view(height,-1).data), u.data)) 21 | u.data = l2normalize(torch.mv(w.view(height,-1).data, v.data)) 22 | # sigma = torch.dot(u.data, torch.mv(w.view(height,-1).data, v.data)) 23 | sigma = u.dot(w.view(height, -1).mv(v)) 24 | return w / sigma.expand_as(w) 25 | 26 | @staticmethod 27 | def apply(module): 28 | name = "weight" 29 | fn = SpectralNorm() 30 | 31 | try: 32 | u = getattr(module, name + "_u") 33 | v = getattr(module, name + "_v") 34 | w = getattr(module, name + "_bar") 35 | except AttributeError: 36 | w = getattr(module, name) 37 | height = w.data.shape[0] 38 | width = w.view(height, -1).data.shape[1] 39 | u = Parameter(w.data.new(height).normal_(0, 1), requires_grad=False) 40 | v = Parameter(w.data.new(width).normal_(0, 1), requires_grad=False) 41 | w_bar = Parameter(w.data) 42 | 43 | #del module._parameters[name] 44 | 45 | module.register_parameter(name + "_u", u) 46 | module.register_parameter(name + "_v", v) 47 | module.register_parameter(name + "_bar", w_bar) 48 | 49 | # remove w from parameter list 50 | del module._parameters[name] 51 | 52 | setattr(module, name, fn.compute_weight(module)) 53 | 54 | # recompute weight before every forward() 55 | module.register_forward_pre_hook(fn) 56 | 57 | return fn 58 | 59 | def remove(self, module): 60 | weight = self.compute_weight(module) 61 | delattr(module, self.name) 62 | del module._parameters[self.name + '_u'] 63 | del module._parameters[self.name + '_v'] 64 | del module._parameters[self.name + '_bar'] 65 | module.register_parameter(self.name, Parameter(weight.data)) 66 | 67 | def __call__(self, module, inputs): 68 | setattr(module, self.name, self.compute_weight(module)) 69 | 70 | def spectral_norm(module): 71 | SpectralNorm.apply(module) 72 | return module 73 | 74 | def remove_spectral_norm(module): 75 | name = 'weight' 76 | for k, hook in module._forward_pre_hooks.items(): 77 | if isinstance(hook, SpectralNorm) and hook.name == name: 78 | hook.remove(module) 79 | del module._forward_pre_hooks[k] 80 | return module 81 | 82 | raise ValueError("spectral_norm of '{}' not found in {}" 83 | .format(name, module)) -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict as edict 2 | 3 | default = edict() 4 | 5 | default.snapshot_path = './snapshot/' 6 | default.vis_path = './visulization/' 7 | default.log_path = './log/' 8 | default.data_path = './data/' 9 | 10 | config = edict() 11 | # setting for cycleGAN 12 | # Hyper-parameters 13 | 14 | config.multi_gpu = False 15 | config.gpu_ids = [0,1,2] 16 | 17 | # Setting path 18 | config.snapshot_path = default.snapshot_path 19 | config.pretrained_path = default.snapshot_path 20 | config.vis_path = default.vis_path 21 | config.log_path = default.log_path 22 | config.data_path = default.data_path 23 | 24 | # Setting training parameters 25 | config.task_name = "" 26 | config.G_LR = 2e-5 27 | config.D_LR = 2e-5 28 | config.beta1 = 0.5 29 | config.beta2 = 0.999 30 | config.c_dim = 2 31 | config.num_epochs = 200 32 | config.num_epochs_decay = 100 33 | config.ndis = 1 34 | config.snapshot_step = 260 35 | config.log_step = 10 36 | config.vis_step = config.snapshot_step 37 | config.batch_size = 1 38 | config.lambda_A = 10.0 39 | config.lambda_B =10.0 40 | config.lambda_idt = 0.5 41 | config.img_size = 256 42 | config.g_conv_dim = 64 43 | config.d_conv_dim = 64 44 | config.g_repeat_num = 6 45 | config.d_repeat_num = 3 46 | 
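# A minimal sketch of how the network hyper-parameters above are consumed
# when the models are built (these calls appear in solver_cycle.py,
# build_model; the G/D names here are illustrative):
#
#   import net
#   G = net.Generator(config.g_conv_dim, config.g_repeat_num)        # 64, 6
#   D = net.Discriminator(config.img_size, config.d_conv_dim,
#                         config.d_repeat_num)                       # 256, 64, 3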
47 | config.checkpoint = ""
48 | 
49 | config.test_model = "51_2000"
50 | 
51 | 
52 | # Setting datasets
53 | dataset_config = edict()
54 | 
55 | dataset_config.name = 'MAKEUP'
56 | dataset_config.dataset_path = default.data_path
57 | dataset_config.img_size = 256
58 | 
59 | def generate_config(_network, _dataset):
60 |     for k, v in dataset_config[_dataset].items():
61 |         if k in config:
62 |             config[k] = v
63 |         elif k in default:
64 |             default[k] = v
65 | 
66 | def merge_cfg_arg(config, args):
67 |     config.gpu_ids = [int(i) for i in args.gpus.split(',')]
68 |     config.batch_size = args.batch_size
69 |     config.vis_step = args.vis_step
70 |     config.snapshot_step = args.vis_step
71 |     config.ndis = args.ndis
72 |     config.lambda_cls = args.lambda_cls
73 |     config.lambda_A = args.lambda_rec
74 |     config.lambda_B = args.lambda_rec
75 |     config.G_LR = args.LR
76 |     config.D_LR = args.LR
77 |     config.num_epochs_decay = args.decay
78 |     config.num_epochs = args.epochs
79 |     config.whichG = args.whichG
80 |     config.task_name = args.task_name
81 |     config.norm = args.norm
82 |     config.lambda_his = args.lambda_his
83 |     config.lambda_vgg = args.lambda_vgg
84 |     config.cls_list = [i for i in args.cls_list.split(',')]
85 |     config.content_layer = [i for i in args.content_layer.split(',')]
86 |     config.direct = args.direct
87 |     config.lips = args.lips
88 |     config.skin = args.skin
89 |     config.eye = args.eye
90 |     config.g_repeat = args.g_repeat
91 |     config.lambda_his_lip = args.lambda_his
92 |     config.lambda_his_skin_1 = args.lambda_his * args.lambda_skin_1
93 |     config.lambda_his_skin_2 = args.lambda_his * args.lambda_skin_2
94 |     config.lambda_his_eye = args.lambda_his * args.lambda_eye
95 |     print(config)
96 |     if hasattr(args, "checkpoint"):
97 |         config.checkpoint = args.checkpoint
98 |     if hasattr(args, "test_model"):
99 |         config.test_model = args.test_model
100 |     return config
101 | 
102 | 
--------------------------------------------------------------------------------
/tools/inception_score.py:
--------------------------------------------------------------------------------
1 | # From https://github.com/openai/improved-gan/blob/master/inception_score/model.py
2 | # Code derived from tensorflow/tensorflow/models/image/imagenet/classify_image.py
3 | from __future__ import absolute_import
4 | from __future__ import division
5 | from __future__ import print_function
6 | 
7 | import os.path
8 | import sys
9 | import tarfile
10 | 
11 | import numpy as np
12 | from six.moves import urllib
13 | import tensorflow as tf
14 | import glob
15 | import scipy.misc
16 | import math
17 | import sys
18 | import os
19 | 
20 | MODEL_DIR = '/tmp/imagenet'
21 | DATA_URL = 'http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz'
22 | softmax = None
23 | 
24 | #os.environ["CUDA_VISIBLE_DEVICES"] = '0'
25 | config = tf.ConfigProto()
26 | #config.gpu_options.per_process_gpu_memory_fraction = 0.4
27 | config.gpu_options.allow_growth = True
28 | 
29 | # Call this function with a list of images. Each element should be a
30 | # numpy array with values ranging from 0 to 255.
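# A minimal usage sketch (the random image list below is synthetic and purely
# illustrative; real use passes generated samples as HxWx3 uint8 arrays):
#
#   from tools.inception_score import get_inception_score
#   fake_images = [np.uint8(np.random.rand(299, 299, 3) * 255) for _ in range(100)]
#   mean_is, std_is = get_inception_score(fake_images, splits=10)
#
# Note that importing this module triggers _init_inception(), which downloads
# the pretrained Inception graph into MODEL_DIR on first use.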
31 | def get_inception_score(images, splits=10):
32 |   assert(type(images) == list)
33 |   assert(type(images[0]) == np.ndarray)
34 |   assert(len(images[0].shape) == 3)
35 |   assert(np.max(images[0]) > 10)
36 |   assert(np.min(images[0]) >= 0.0)
37 |   inps = []
38 |   for img in images:
39 |     img = img.astype(np.float32)
40 |     inps.append(np.expand_dims(img, 0))
41 |   bs = 100
42 |   with tf.Session(config=config) as sess:
43 |     preds = []
44 |     n_batches = int(math.ceil(float(len(inps)) / float(bs)))
45 |     for i in range(n_batches):
46 |         # sys.stdout.write(".")
47 |         # sys.stdout.flush()
48 |         inp = inps[(i * bs):min((i + 1) * bs, len(inps))]
49 |         inp = np.concatenate(inp, 0)
50 |         pred = sess.run(softmax, {'ExpandDims:0': inp})
51 |         preds.append(pred)
52 |     preds = np.concatenate(preds, 0)
53 |     scores = []
54 |     for i in range(splits):
55 |       part = preds[(i * preds.shape[0] // splits):((i + 1) * preds.shape[0] // splits), :]
56 |       kl = part * (np.log(part) - np.log(np.expand_dims(np.mean(part, 0), 0)))
57 |       kl = np.mean(np.sum(kl, 1))
58 |       scores.append(np.exp(kl))
59 |   return np.mean(scores), np.std(scores)
60 | 
61 | # This function is called automatically.
62 | def _init_inception():
63 |   global softmax
64 |   if not os.path.exists(MODEL_DIR):
65 |     os.makedirs(MODEL_DIR)
66 |   filename = DATA_URL.split('/')[-1]
67 |   filepath = os.path.join(MODEL_DIR, filename)
68 |   if not os.path.exists(filepath):
69 |     def _progress(count, block_size, total_size):
70 |       sys.stdout.write('\r>> Downloading %s %.1f%%' % (
71 |           filename, float(count * block_size) / float(total_size) * 100.0))
72 |       sys.stdout.flush()
73 |     filepath, _ = urllib.request.urlretrieve(DATA_URL, filepath, _progress)
74 |     print()
75 |     statinfo = os.stat(filepath)
76 |     print('Successfully downloaded', filename, statinfo.st_size, 'bytes.')
77 |   tarfile.open(filepath, 'r:gz').extractall(MODEL_DIR)
78 |   with tf.gfile.FastGFile(os.path.join(
79 |       MODEL_DIR, 'classify_image_graph_def.pb'), 'rb') as f:
80 |     graph_def = tf.GraphDef()
81 |     graph_def.ParseFromString(f.read())
82 |     _ = tf.import_graph_def(graph_def, name='')
83 |   # Works with an arbitrary minibatch size.
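  # The loop below patches each op's static output shape so the batch
  # dimension becomes None, which lets the pretrained Inception graph accept
  # an arbitrary batch size; `o._shape` is a private TF 1.x attribute, so
  # this is graph surgery on internals inherited from the upstream
  # improved-gan code, not a public TensorFlow API.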
84 |   with tf.Session(config=config) as sess:
85 |     pool3 = sess.graph.get_tensor_by_name('pool_3:0')
86 |     ops = pool3.graph.get_operations()
87 |     for op_idx, op in enumerate(ops):
88 |         for o in op.outputs:
89 |             shape = o.get_shape()
90 |             shape = [s.value for s in shape]
91 |             new_shape = []
92 |             for j, s in enumerate(shape):
93 |                 if s == 1 and j == 0:
94 |                     new_shape.append(None)
95 |                 else:
96 |                     new_shape.append(s)
97 |             o._shape = tf.TensorShape(new_shape)
98 |     w = sess.graph.get_operation_by_name("softmax/logits/MatMul").inputs[1]
99 |     logits = tf.matmul(tf.squeeze(pool3), w)
100 |     softmax = tf.nn.softmax(logits)
101 | 
102 | if softmax is None:
103 |   _init_inception()
104 | 
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import os
2 | 
3 | import argparse
4 | 
5 | from torch.backends import cudnn
6 | from config import config, dataset_config, merge_cfg_arg
7 | 
8 | from dataloder import get_loader
9 | from solver_cycle import Solver_cycleGAN
10 | from solver_makeup import Solver_makeupGAN
11 | 
12 | def parse_args():
13 |     parser = argparse.ArgumentParser(description='Train GAN')
14 |     # general
15 |     parser.add_argument('--data_path', default='makeup/makeup_final/', type=str, help='training and test data path')
16 |     parser.add_argument('--dataset', default='MAKEUP', type=str, help='dataset name; MAKEUP means two domains, MMAKEUP means multi-domain')
17 |     parser.add_argument('--gpus', default='0', type=str, help='GPU device to train with')
18 |     parser.add_argument('--batch_size', default='1', type=int, help='batch size')
19 |     parser.add_argument('--vis_step', default='1260', type=int, help='steps between visualizations')
20 |     parser.add_argument('--task_name', default='', type=str, help='task name')
21 |     parser.add_argument('--checkpoint', default='', type=str, help='checkpoint to load')
22 |     parser.add_argument('--ndis', default='1', type=int, help='discriminator steps per generator step')
23 |     parser.add_argument('--LR', default="2e-4", type=float, help='learning rate')
24 |     parser.add_argument('--decay', default='0', type=int, help='number of epochs over which the learning rate decays')
25 |     parser.add_argument('--model', default='makeupGAN', type=str, help='which model to use: cycleGAN / makeupGAN')
26 |     parser.add_argument('--epochs', default='300', type=int, help='number of epochs')
27 |     parser.add_argument('--whichG', default='branch', type=str, help='which generator to choose, normal/branch; branch means two input branches')
28 |     parser.add_argument('--norm', default='SN', type=str, help='normalization of discriminator; SN means spectral normalization, none means no normalization')
29 |     parser.add_argument('--d_repeat', default='3', type=int, help='number of repeated Res-blocks in the discriminator')
30 |     parser.add_argument('--g_repeat', default='6', type=int, help='number of repeated Res-blocks in the generator')
31 |     parser.add_argument('--lambda_cls', default='1', type=float, help='the lambda_cls weight')
32 |     parser.add_argument('--lambda_rec', default='10', type=int, help='lambda_A and lambda_B')
33 |     parser.add_argument('--lambda_his', default='1', type=float, help='histogram loss on lips')
34 |     parser.add_argument('--lambda_skin_1', default='0.1', type=float, help='histogram loss on skin equals lambda_his * lambda_skin_1')
35 |     parser.add_argument('--lambda_skin_2', default='0.1', type=float, help='histogram loss on skin equals lambda_his * lambda_skin_2')
36 |     parser.add_argument('--lambda_eye', default='1', type=float, help='histogram loss on eyes equals lambda_his * lambda_eye')
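    # (These weights are combined in config.merge_cfg_arg:
    #  lambda_his_skin_1 = lambda_his * lambda_skin_1,
    #  lambda_his_skin_2 = lambda_his * lambda_skin_2, and
    #  lambda_his_eye = lambda_his * lambda_eye.)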
37 |     parser.add_argument('--content_layer', default='r41', type=str, help='the VGG layer used to output features')
38 |     parser.add_argument('--lambda_vgg', default='5e-3', type=float, help='the weight of the VGG loss')
39 |     parser.add_argument('--cls_list', default='SYMIX,MAKEMIX', type=str, help='the classes of makeup to train')
40 |     parser.add_argument('--direct', action="store_true", default=True, help='direct means adding the local cosmetic loss from the start, unified training')
41 |     parser.add_argument('--lips', action="store_true", default=True, help='whether to finetune lip color')
42 |     parser.add_argument('--skin', action="store_true", default=True, help='whether to finetune foundation color')
43 |     parser.add_argument('--eye', action="store_true", default=True, help='whether to finetune eye shadow color')
44 |     args = parser.parse_args()
45 |     return args
46 | 
47 | def train_net():
48 |     # enable cudnn
49 |     cudnn.benchmark = True
50 | 
51 |     data_loaders = get_loader(dataset_config, config, mode="train")    # returns train & test loaders
52 |     # get the solver
53 |     if args.model == 'cycleGAN':
54 |         solver = Solver_cycleGAN(data_loaders, config, dataset_config)
55 |     elif args.model == 'makeupGAN':
56 |         solver = Solver_makeupGAN(data_loaders, config, dataset_config)
57 |     else:
58 |         print("Unsupported model")
59 |         exit()
60 |     solver.train()
61 | 
62 | 
63 | 
64 | if __name__ == '__main__':
65 |     args = parse_args()
66 |     print("Call with args:")
67 |     print(args)
68 |     config = merge_cfg_arg(config, args)
69 | 
70 |     dataset_config.name = args.dataset
71 | 
72 |     print("The config is:")
73 |     print(config)
74 | 
75 |     # Create the directories if not exist
76 |     if not os.path.exists(config.data_path):
77 |         print("No datapath!!")
78 |         exit()
79 | 
80 |     if args.data_path != '':
81 |         dataset_config.dataset_path = os.path.join(config.data_path, args.data_path)
82 | 
83 |     train_net()
84 | 
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
1 | import os
2 | 
3 | import argparse
4 | 
5 | from torch.backends import cudnn
6 | from config import config, dataset_config, merge_cfg_arg
7 | 
8 | from dataloder import get_loader
9 | from solver_cycle import Solver_cycleGAN
10 | from solver_makeup import Solver_makeupGAN
11 | 
12 | def parse_args():
13 |     parser = argparse.ArgumentParser(description='Test GAN')
14 |     # general
15 |     parser.add_argument('--data_path', default='makeup/makeup_final/', type=str, help='training and test data path')
16 |     parser.add_argument('--dataset', default='MAKEUP', type=str, help='dataset name; MAKEUP means two domains, MMAKEUP means multi-domain')
17 |     parser.add_argument('--gpus', default='0', type=str, help='GPU device to test with')
18 |     parser.add_argument('--batch_size', default='1', type=int, help='batch size')
19 |     parser.add_argument('--vis_step', default='1260', type=int, help='steps between visualizations')
20 |     parser.add_argument('--task_name', default='', type=str, help='task name')
21 |     parser.add_argument('--ndis', default='1', type=int, help='discriminator steps per generator step')
22 |     parser.add_argument('--LR', default="2e-4", type=float, help='learning rate')
23 |     parser.add_argument('--decay', default='0', type=int, help='number of epochs over which the learning rate decays')
24 |     parser.add_argument('--model', default='makeupGAN', type=str, help='which model to use: cycleGAN / makeupGAN')
25 |     parser.add_argument('--epochs', default='300', type=int, help='number of epochs')
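    # (Unlike train.py, this parser defines no --checkpoint flag; merge_cfg_arg
    #  in config.py therefore copies checkpoint/test_model from args only when
    #  the corresponding attribute exists.)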
26 |     parser.add_argument('--whichG', default='branch', type=str, help='which generator to choose, normal/branch; branch means two input branches')
27 |     parser.add_argument('--norm', default='SN', type=str, help='normalization of discriminator; SN means spectral normalization, none means no normalization')
28 |     parser.add_argument('--d_repeat', default='3', type=int, help='number of repeated Res-blocks in the discriminator')
29 |     parser.add_argument('--g_repeat', default='6', type=int, help='number of repeated Res-blocks in the generator')
30 |     parser.add_argument('--lambda_cls', default='1', type=float, help='the lambda_cls weight')
31 |     parser.add_argument('--lambda_rec', default='10', type=int, help='lambda_A and lambda_B')
32 |     parser.add_argument('--lambda_his', default='1', type=float, help='histogram loss on lips')
33 |     parser.add_argument('--lambda_skin_1', default='0.1', type=float, help='histogram loss on skin equals lambda_his * lambda_skin_1')
34 |     parser.add_argument('--lambda_skin_2', default='0.1', type=float, help='histogram loss on skin equals lambda_his * lambda_skin_2')
35 |     parser.add_argument('--lambda_eye', default='1', type=float, help='histogram loss on eyes equals lambda_his * lambda_eye')
36 |     parser.add_argument('--content_layer', default='r41', type=str, help='the VGG layer we use')
37 |     parser.add_argument('--lambda_vgg', default='5e-3', type=float, help='the weight of the VGG loss')
38 |     parser.add_argument('--cls_list', default='A_OM,B_OM', type=str, help='the classes we choose')
39 |     parser.add_argument('--direct', action="store_true", default=False, help='direct means adding the local cosmetic loss from the start, unified training')
40 |     parser.add_argument('--finetune', action="store_true", default=False, help='whether to finetune the network')
41 |     parser.add_argument('--lips', action="store_true", default=False, help='whether to finetune lip color')
42 |     parser.add_argument('--skin', action="store_true", default=False, help='whether to finetune foundation color')
43 |     parser.add_argument('--eye', action="store_true", default=False, help='whether to finetune eye shadow color')
44 |     parser.add_argument('--test_model', default='20_2520', type=str, help='which snapshot to test')
45 |     args = parser.parse_args()
46 |     return args
47 | 
48 | 
49 | def test_net():
50 |     # enable cudnn
51 |     cudnn.benchmark = True
52 | 
53 |     # get the DataLoader
54 |     data_loaders = get_loader(dataset_config, config, mode="test")
55 | 
56 |     # get the solver
57 |     if args.model == 'cycleGAN':
58 |         solver = Solver_cycleGAN(data_loaders, config, dataset_config)
59 |     elif args.model == 'makeupGAN':
60 |         solver = Solver_makeupGAN(data_loaders, config, dataset_config)
61 |     else:
62 |         print("Unsupported model")
63 |         exit()
64 |     solver.test()
65 | 
66 | if __name__ == '__main__':
67 |     args = parse_args()
68 |     print("Call with args:")
69 |     print(args)
70 |     config = merge_cfg_arg(config, args)
71 | 
72 |     config.test_model = args.test_model
73 | 
74 |     print("The config is:")
75 |     print(config)
76 | 
77 |     # Create the directories if not exist
78 |     if not os.path.exists(config.data_path):
79 |         print("No datapath!!")
80 | 
81 |     dataset_config.dataset_path = os.path.join(config.data_path, args.data_path)
82 |     test_net()
83 | 
--------------------------------------------------------------------------------
/data_loaders/makeup.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import random
4 | import linecache
5 | 
6 | from torch.utils.data import Dataset
7 | from
PIL import Image 8 | 9 | class MAKEUP(Dataset): 10 | def __init__(self, image_path, transform, mode, transform_mask, cls_list): 11 | self.image_path = image_path 12 | self.transform = transform 13 | self.mode = mode 14 | self.transform_mask = transform_mask 15 | 16 | self.cls_list = cls_list 17 | self.cls_A = cls_list[0] 18 | self.cls_B = cls_list[1] 19 | 20 | for cls in self.cls_list: 21 | setattr(self, "train_" + cls + "_list_path", os.path.join(self.image_path, "train_" + cls + ".txt")) 22 | setattr(self, "train_" + cls + "_lines", open(getattr(self, "train_" + cls + "_list_path"), 'r').readlines()) 23 | setattr(self, "num_of_train_" + cls + "_data", len(getattr(self, "train_" + cls + "_lines"))) 24 | for cls in self.cls_list: 25 | if self.mode == "test_all": 26 | setattr(self, "test_" + cls + "_list_path", os.path.join(self.image_path, "test_" + cls + "_all.txt")) 27 | setattr(self, "test_" + cls + "_lines", open(getattr(self, "test_" + cls + "_list_path"), 'r').readlines()) 28 | setattr(self, "num_of_test_" + cls + "_data", len(getattr(self, "test_" + cls + "_lines"))) 29 | else: 30 | setattr(self, "test_" + cls + "_list_path", os.path.join(self.image_path, "test_" + cls + ".txt")) 31 | setattr(self, "test_" + cls + "_lines", open(getattr(self, "test_" + cls + "_list_path"), 'r').readlines()) 32 | setattr(self, "num_of_test_" + cls + "_data", len(getattr(self, "test_" + cls + "_lines"))) 33 | 34 | print ('Start preprocessing dataset..!') 35 | self.preprocess() 36 | print ('Finished preprocessing dataset..!') 37 | 38 | def preprocess(self): 39 | for cls in self.cls_list: 40 | setattr(self, "train_" + cls + "_filenames", []) 41 | setattr(self, "train_" + cls + "_mask_filenames", []) 42 | 43 | lines = getattr(self, "train_" + cls + "_lines") 44 | random.shuffle(lines) 45 | 46 | for i, line in enumerate(lines): 47 | splits = line.split() 48 | getattr(self, "train_" + cls + "_filenames").append(splits[0]) 49 | getattr(self, "train_" + cls + "_mask_filenames").append(splits[1]) 50 | 51 | for cls in self.cls_list: 52 | setattr(self, "test_" + cls + "_filenames", []) 53 | setattr(self, "test_" + cls + "_mask_filenames", []) 54 | lines = getattr(self, "test_" + cls + "_lines") 55 | for i, line in enumerate(lines): 56 | splits = line.split() 57 | getattr(self, "test_" + cls + "_filenames").append(splits[0]) 58 | getattr(self, "test_" + cls + "_mask_filenames").append(splits[1]) 59 | 60 | if self.mode == "test_baseline": 61 | setattr(self, "test_" + self.cls_A + "_filenames", os.listdir(os.path.join(self.image_path, "baseline", "org_aligned"))) 62 | setattr(self, "num_of_test_" + self.cls_A + "_data", len(os.listdir(os.path.join(self.image_path, "baseline", "org_aligned")))) 63 | setattr(self, "test_" + self.cls_B + "_filenames", os.listdir(os.path.join(self.image_path, "baseline", "ref_aligned"))) 64 | setattr(self, "num_of_test_" + self.cls_B + "_data", len(os.listdir(os.path.join(self.image_path, "baseline", "ref_aligned")))) 65 | 66 | def __getitem__(self, index): 67 | if self.mode == 'train' or self.mode == 'train_finetune': 68 | index_A = random.randint(0, getattr(self, "num_of_train_" + self.cls_A + "_data") - 1) 69 | index_B = random.randint(0, getattr(self, "num_of_train_" + self.cls_B + "_data") - 1) 70 | image_A = Image.open(os.path.join(self.image_path, getattr(self, "train_" + self.cls_A + "_filenames")[index_A])).convert("RGB") 71 | image_B = Image.open(os.path.join(self.image_path, getattr(self, "train_" + self.cls_B + "_filenames")[index_B])).convert("RGB") 72 | mask_A = 
Image.open(os.path.join(self.image_path, getattr(self, "train_" + self.cls_A + "_mask_filenames")[index_A])) 73 | mask_B = Image.open(os.path.join(self.image_path, getattr(self, "train_" + self.cls_B + "_mask_filenames")[index_B])) 74 | return self.transform(image_A), self.transform(image_B), self.transform_mask(mask_A), self.transform_mask(mask_B) 75 | if self.mode in ['test', 'test_all']: 76 | #""" 77 | image_A = Image.open(os.path.join(self.image_path, getattr(self, "test_" + self.cls_A + "_filenames")[index // getattr(self, 'num_of_test_' + self.cls_list[1] + '_data')])).convert("RGB") 78 | image_B = Image.open(os.path.join(self.image_path, getattr(self, "test_" + self.cls_B + "_filenames")[index % getattr(self, 'num_of_test_' + self.cls_list[1] + '_data')])).convert("RGB") 79 | return self.transform(image_A), self.transform(image_B) 80 | if self.mode == "test_baseline": 81 | image_A = Image.open(os.path.join(self.image_path, "baseline", "org_aligned", getattr(self, "test_" + self.cls_A + "_filenames")[index // getattr(self, 'num_of_test_' + self.cls_list[1] + '_data')])).convert("RGB") 82 | image_B = Image.open(os.path.join(self.image_path, "baseline", "ref_aligned", getattr(self, "test_" + self.cls_B + "_filenames")[index % getattr(self, 'num_of_test_' + self.cls_list[1] + '_data')])).convert("RGB") 83 | return self.transform(image_A), self.transform(image_B) 84 | 85 | def __len__(self): 86 | if self.mode == 'train' or self.mode == 'train_finetune': 87 | num_A = getattr(self, 'num_of_train_' + self.cls_list[0] + '_data') 88 | num_B = getattr(self, 'num_of_train_' + self.cls_list[1] + '_data') 89 | return max(num_A, num_B) 90 | elif self.mode in ['test', "test_baseline", 'test_all']: 91 | num_A = getattr(self, 'num_of_test_' + self.cls_list[0] + '_data') 92 | num_B = getattr(self, 'num_of_test_' + self.cls_list[1] + '_data') 93 | return num_A * num_B 94 | -------------------------------------------------------------------------------- /net.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from ops.spectral_norm import spectral_norm as SpectralNorm 6 | 7 | # Defines the GAN loss which uses either LSGAN or the regular GAN. 8 | # When LSGAN is used, it is basically same as MSELoss, 9 | # but it abstracts away the need to create the target label tensor 10 | # that has the same size as the input 11 | 12 | class ResidualBlock(nn.Module): 13 | """Residual Block.""" 14 | def __init__(self, dim_in, dim_out): 15 | super(ResidualBlock, self).__init__() 16 | self.main = nn.Sequential( 17 | nn.Conv2d(dim_in, dim_out, kernel_size=3, stride=1, padding=1, bias=False), 18 | nn.InstanceNorm2d(dim_out, affine=True), 19 | nn.ReLU(inplace=True), 20 | nn.Conv2d(dim_out, dim_out, kernel_size=3, stride=1, padding=1, bias=False), 21 | nn.InstanceNorm2d(dim_out, affine=True)) 22 | 23 | def forward(self, x): 24 | return x + self.main(x) 25 | 26 | 27 | class Generator(nn.Module): 28 | """Generator. 
Encoder-Decoder Architecture.""" 29 | def __init__(self, conv_dim=64, repeat_num=6): 30 | super(Generator, self).__init__() 31 | 32 | layers = [] 33 | layers.append(nn.Conv2d(3, conv_dim, kernel_size=7, stride=1, padding=3, bias=False)) 34 | layers.append(nn.InstanceNorm2d(conv_dim, affine=True)) 35 | layers.append(nn.ReLU(inplace=True)) 36 | 37 | # Down-Sampling 38 | curr_dim = conv_dim 39 | for i in range(2): 40 | layers.append(nn.Conv2d(curr_dim, curr_dim*2, kernel_size=4, stride=2, padding=1, bias=False)) 41 | layers.append(nn.InstanceNorm2d(curr_dim*2, affine=True)) 42 | layers.append(nn.ReLU(inplace=True)) 43 | curr_dim = curr_dim * 2 44 | 45 | # Bottleneck 46 | for i in range(repeat_num): 47 | layers.append(ResidualBlock(dim_in=curr_dim, dim_out=curr_dim)) 48 | 49 | # Up-Sampling 50 | for i in range(2): 51 | layers.append(nn.ConvTranspose2d(curr_dim, curr_dim//2, kernel_size=4, stride=2, padding=1, bias=False)) 52 | layers.append(nn.InstanceNorm2d(curr_dim//2, affine=True)) 53 | layers.append(nn.ReLU(inplace=True)) 54 | curr_dim = curr_dim // 2 55 | 56 | layers.append(nn.Conv2d(curr_dim, 3, kernel_size=7, stride=1, padding=3, bias=False)) 57 | layers.append(nn.Tanh()) 58 | self.main = nn.Sequential(*layers) 59 | 60 | def forward(self, x): 61 | out = self.main(x) 62 | return out 63 | 64 | class Generator_makeup(nn.Module): 65 | """Generator. Encoder-Decoder Architecture.""" 66 | # input 2 images and output 2 images as well 67 | def __init__(self, conv_dim=64, repeat_num=6, input_nc=6): 68 | super(Generator_makeup, self).__init__() 69 | 70 | layers = [] 71 | layers.append(nn.Conv2d(input_nc, conv_dim, kernel_size=7, stride=1, padding=3, bias=False)) 72 | layers.append(nn.InstanceNorm2d(conv_dim, affine=True)) 73 | layers.append(nn.ReLU(inplace=True)) 74 | 75 | # Down-Sampling 76 | curr_dim = conv_dim 77 | for i in range(2): 78 | layers.append(nn.Conv2d(curr_dim, curr_dim*2, kernel_size=4, stride=2, padding=1, bias=False)) 79 | layers.append(nn.InstanceNorm2d(curr_dim*2, affine=True)) 80 | layers.append(nn.ReLU(inplace=True)) 81 | curr_dim = curr_dim * 2 82 | 83 | # Bottleneck 84 | for i in range(repeat_num): 85 | layers.append(ResidualBlock(dim_in=curr_dim, dim_out=curr_dim)) 86 | 87 | # Up-Sampling 88 | for i in range(2): 89 | layers.append(nn.ConvTranspose2d(curr_dim, curr_dim//2, kernel_size=4, stride=2, padding=1, bias=False)) 90 | layers.append(nn.InstanceNorm2d(curr_dim//2, affine=True)) 91 | layers.append(nn.ReLU(inplace=True)) 92 | curr_dim = curr_dim // 2 93 | 94 | self.main = nn.Sequential(*layers) 95 | 96 | layers_1 = [] 97 | layers_1.append(nn.Conv2d(curr_dim, 3, kernel_size=7, stride=1, padding=3, bias=False)) 98 | layers_1.append(nn.Tanh()) 99 | self.branch_1 = nn.Sequential(*layers_1) 100 | layers_2 = [] 101 | layers_2.append(nn.Conv2d(curr_dim, 3, kernel_size=7, stride=1, padding=3, bias=False)) 102 | layers_2.append(nn.Tanh()) 103 | self.branch_2 = nn.Sequential(*layers_2) 104 | 105 | def forward(self, x, y): 106 | input_x = torch.cat((x, y), dim=1) 107 | out = self.main(input_x) 108 | out_A = self.branch_1(out) 109 | out_B = self.branch_2(out) 110 | return out_A, out_B 111 | 112 | 113 | class Generator_branch(nn.Module): 114 | """Generator. 
Encoder-Decoder Architecture.""" 115 | # input 2 images and output 2 images as well 116 | def __init__(self, conv_dim=64, repeat_num=6, input_nc=3): 117 | super(Generator_branch, self).__init__() 118 | 119 | # Branch input 120 | layers_branch = [] 121 | layers_branch.append(nn.Conv2d(input_nc, conv_dim, kernel_size=7, stride=1, padding=3, bias=False)) 122 | layers_branch.append(nn.InstanceNorm2d(conv_dim, affine=True)) 123 | layers_branch.append(nn.ReLU(inplace=True)) 124 | layers_branch.append(nn.Conv2d(conv_dim, conv_dim*2, kernel_size=4, stride=2, padding=1, bias=False)) 125 | layers_branch.append(nn.InstanceNorm2d(conv_dim*2, affine=True)) 126 | layers_branch.append(nn.ReLU(inplace=True)) 127 | self.Branch_0 = nn.Sequential(*layers_branch) 128 | 129 | # Branch input 130 | layers_branch = [] 131 | layers_branch.append(nn.Conv2d(input_nc, conv_dim, kernel_size=7, stride=1, padding=3, bias=False)) 132 | layers_branch.append(nn.InstanceNorm2d(conv_dim, affine=True)) 133 | layers_branch.append(nn.ReLU(inplace=True)) 134 | layers_branch.append(nn.Conv2d(conv_dim, conv_dim*2, kernel_size=4, stride=2, padding=1, bias=False)) 135 | layers_branch.append(nn.InstanceNorm2d(conv_dim*2, affine=True)) 136 | layers_branch.append(nn.ReLU(inplace=True)) 137 | self.Branch_1 = nn.Sequential(*layers_branch) 138 | 139 | # Down-Sampling, branch merge 140 | layers = [] 141 | curr_dim = conv_dim*2 142 | layers.append(nn.Conv2d(curr_dim*2, curr_dim*2, kernel_size=4, stride=2, padding=1, bias=False)) 143 | layers.append(nn.InstanceNorm2d(curr_dim*2, affine=True)) 144 | layers.append(nn.ReLU(inplace=True)) 145 | curr_dim = curr_dim * 2 146 | 147 | # Bottleneck 148 | for i in range(repeat_num): 149 | layers.append(ResidualBlock(dim_in=curr_dim, dim_out=curr_dim)) 150 | 151 | # Up-Sampling 152 | for i in range(2): 153 | layers.append(nn.ConvTranspose2d(curr_dim, curr_dim//2, kernel_size=4, stride=2, padding=1, bias=False)) 154 | layers.append(nn.InstanceNorm2d(curr_dim//2, affine=True)) 155 | layers.append(nn.ReLU(inplace=True)) 156 | curr_dim = curr_dim // 2 157 | 158 | self.main = nn.Sequential(*layers) 159 | 160 | layers_1 = [] 161 | layers_1.append(nn.Conv2d(curr_dim, curr_dim, kernel_size=3, stride=1, padding=1, bias=False)) 162 | layers_1.append(nn.InstanceNorm2d(curr_dim, affine=True)) 163 | layers_1.append(nn.ReLU(inplace=True)) 164 | layers_1.append(nn.Conv2d(curr_dim, curr_dim, kernel_size=3, stride=1, padding=1, bias=False)) 165 | layers_1.append(nn.InstanceNorm2d(curr_dim, affine=True)) 166 | layers_1.append(nn.ReLU(inplace=True)) 167 | layers_1.append(nn.Conv2d(curr_dim, 3, kernel_size=7, stride=1, padding=3, bias=False)) 168 | layers_1.append(nn.Tanh()) 169 | self.branch_1 = nn.Sequential(*layers_1) 170 | layers_2 = [] 171 | layers_2.append(nn.Conv2d(curr_dim, curr_dim, kernel_size=3, stride=1, padding=1, bias=False)) 172 | layers_2.append(nn.InstanceNorm2d(curr_dim, affine=True)) 173 | layers_2.append(nn.ReLU(inplace=True)) 174 | layers_2.append(nn.Conv2d(curr_dim, curr_dim, kernel_size=3, stride=1, padding=1, bias=False)) 175 | layers_2.append(nn.InstanceNorm2d(curr_dim, affine=True)) 176 | layers_2.append(nn.ReLU(inplace=True)) 177 | layers_2.append(nn.Conv2d(curr_dim, 3, kernel_size=7, stride=1, padding=3, bias=False)) 178 | layers_2.append(nn.Tanh()) 179 | self.branch_2 = nn.Sequential(*layers_2) 180 | 181 | def forward(self, x, y): 182 | input_x = self.Branch_0(x) 183 | input_y = self.Branch_1(y) 184 | input_fuse = torch.cat((input_x, input_y), dim=1) 185 | out = self.main(input_fuse) 186 | 
out_A = self.branch_1(out) 187 | out_B = self.branch_2(out) 188 | return out_A, out_B 189 | 190 | class Discriminator(nn.Module): 191 | """Discriminator. PatchGAN.""" 192 | def __init__(self, image_size=128, conv_dim=64, repeat_num=3, norm='SN'): 193 | super(Discriminator, self).__init__() 194 | 195 | layers = [] 196 | if norm=='SN': 197 | layers.append(SpectralNorm(nn.Conv2d(3, conv_dim, kernel_size=4, stride=2, padding=1))) 198 | else: 199 | layers.append(nn.Conv2d(3, conv_dim, kernel_size=4, stride=2, padding=1)) 200 | layers.append(nn.LeakyReLU(0.01, inplace=True)) 201 | 202 | curr_dim = conv_dim 203 | for i in range(1, repeat_num): 204 | if norm=='SN': 205 | layers.append(SpectralNorm(nn.Conv2d(curr_dim, curr_dim*2, kernel_size=4, stride=2, padding=1))) 206 | else: 207 | layers.append(nn.Conv2d(curr_dim, curr_dim*2, kernel_size=4, stride=2, padding=1)) 208 | layers.append(nn.LeakyReLU(0.01, inplace=True)) 209 | curr_dim = curr_dim * 2 210 | 211 | #k_size = int(image_size / np.power(2, repeat_num)) 212 | if norm=='SN': 213 | layers.append(SpectralNorm(nn.Conv2d(curr_dim, curr_dim*2, kernel_size=4, stride=1, padding=1))) 214 | else: 215 | layers.append(nn.Conv2d(curr_dim, curr_dim*2, kernel_size=4, stride=1, padding=1)) 216 | layers.append(nn.LeakyReLU(0.01, inplace=True)) 217 | curr_dim = curr_dim *2 218 | 219 | self.main = nn.Sequential(*layers) 220 | if norm=='SN': 221 | self.conv1 = SpectralNorm(nn.Conv2d(curr_dim, 1, kernel_size=4, stride=1, padding=1, bias=False)) 222 | else: 223 | self.conv1 = nn.Conv2d(curr_dim, 1, kernel_size=4, stride=1, padding=1, bias=False) 224 | 225 | # conv1 remain the last square size, 256*256-->30*30 226 | #self.conv2 = SpectralNorm(nn.Conv2d(curr_dim, 1, kernel_size=k_size, bias=False)) 227 | #conv2 output a single number 228 | 229 | def forward(self, x): 230 | h = self.main(x) 231 | #out_real = self.conv1(h) 232 | out_makeup = self.conv1(h) 233 | #return out_real.squeeze(), out_makeup.squeeze() 234 | return out_makeup.squeeze() 235 | 236 | class VGG(nn.Module): 237 | def __init__(self, pool='max'): 238 | super(VGG, self).__init__() 239 | # vgg modules 240 | self.conv1_1 = nn.Conv2d(3, 64, kernel_size=3, padding=1) 241 | self.conv1_2 = nn.Conv2d(64, 64, kernel_size=3, padding=1) 242 | self.conv2_1 = nn.Conv2d(64, 128, kernel_size=3, padding=1) 243 | self.conv2_2 = nn.Conv2d(128, 128, kernel_size=3, padding=1) 244 | self.conv3_1 = nn.Conv2d(128, 256, kernel_size=3, padding=1) 245 | self.conv3_2 = nn.Conv2d(256, 256, kernel_size=3, padding=1) 246 | self.conv3_3 = nn.Conv2d(256, 256, kernel_size=3, padding=1) 247 | self.conv3_4 = nn.Conv2d(256, 256, kernel_size=3, padding=1) 248 | self.conv4_1 = nn.Conv2d(256, 512, kernel_size=3, padding=1) 249 | self.conv4_2 = nn.Conv2d(512, 512, kernel_size=3, padding=1) 250 | self.conv4_3 = nn.Conv2d(512, 512, kernel_size=3, padding=1) 251 | self.conv4_4 = nn.Conv2d(512, 512, kernel_size=3, padding=1) 252 | self.conv5_1 = nn.Conv2d(512, 512, kernel_size=3, padding=1) 253 | self.conv5_2 = nn.Conv2d(512, 512, kernel_size=3, padding=1) 254 | self.conv5_3 = nn.Conv2d(512, 512, kernel_size=3, padding=1) 255 | self.conv5_4 = nn.Conv2d(512, 512, kernel_size=3, padding=1) 256 | if pool == 'max': 257 | self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2) 258 | self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2) 259 | self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2) 260 | self.pool4 = nn.MaxPool2d(kernel_size=2, stride=2) 261 | self.pool5 = nn.MaxPool2d(kernel_size=2, stride=2) 262 | elif pool == 'avg': 263 | self.pool1 = 
nn.AvgPool2d(kernel_size=2, stride=2) 264 | self.pool2 = nn.AvgPool2d(kernel_size=2, stride=2) 265 | self.pool3 = nn.AvgPool2d(kernel_size=2, stride=2) 266 | self.pool4 = nn.AvgPool2d(kernel_size=2, stride=2) 267 | self.pool5 = nn.AvgPool2d(kernel_size=2, stride=2) 268 | 269 | def forward(self, x, out_keys): 270 | out = {} 271 | out['r11'] = F.relu(self.conv1_1(x)) 272 | out['r12'] = F.relu(self.conv1_2(out['r11'])) 273 | out['p1'] = self.pool1(out['r12']) 274 | out['r21'] = F.relu(self.conv2_1(out['p1'])) 275 | out['r22'] = F.relu(self.conv2_2(out['r21'])) 276 | out['p2'] = self.pool2(out['r22']) 277 | out['r31'] = F.relu(self.conv3_1(out['p2'])) 278 | out['r32'] = F.relu(self.conv3_2(out['r31'])) 279 | out['r33'] = F.relu(self.conv3_3(out['r32'])) 280 | out['r34'] = F.relu(self.conv3_4(out['r33'])) 281 | out['p3'] = self.pool3(out['r34']) 282 | out['r41'] = F.relu(self.conv4_1(out['p3'])) 283 | 284 | out['r42'] = F.relu(self.conv4_2(out['r41'])) 285 | out['r43'] = F.relu(self.conv4_3(out['r42'])) 286 | out['r44'] = F.relu(self.conv4_4(out['r43'])) 287 | out['p4'] = self.pool4(out['r44']) 288 | out['r51'] = F.relu(self.conv5_1(out['p4'])) 289 | out['r52'] = F.relu(self.conv5_2(out['r51'])) 290 | out['r53'] = F.relu(self.conv5_3(out['r52'])) 291 | out['r54'] = F.relu(self.conv5_4(out['r53'])) 292 | out['p5'] = self.pool5(out['r54']) 293 | 294 | return [out[key] for key in out_keys] -------------------------------------------------------------------------------- /solver_cycle.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.init as init 3 | from torch.autograd import Variable 4 | from torchvision.utils import save_image 5 | 6 | import itertools 7 | import os 8 | import time 9 | import datetime 10 | 11 | import tools.plot as plot_fig 12 | import net 13 | from ops.loss_added import GANLoss 14 | 15 | class Solver_cycleGAN(object): 16 | """ 17 | solver to reproduce the cycleGAN 18 | """ 19 | def __init__(self, data_loaders, config, dataset_config): 20 | # dataloader 21 | self.checkpoint = config.checkpoint 22 | # Hyper-parameteres 23 | self.g_lr = config.G_LR 24 | self.d_lr = config.D_LR 25 | self.ndis = config.ndis 26 | self.num_epochs = config.num_epochs # set 200 27 | self.num_epochs_decay = config.num_epochs_decay 28 | 29 | # Training settings 30 | self.snapshot_step = config.snapshot_step 31 | self.log_step = config.log_step 32 | self.vis_step = config.vis_step 33 | 34 | #training setting 35 | self.task_name = config.task_name 36 | 37 | # Data loader 38 | self.data_loader_train = data_loaders[0] 39 | self.data_loader_test = data_loaders[1] 40 | 41 | # Model hyper-parameters 42 | self.img_size = config.img_size 43 | self.g_conv_dim = config.g_conv_dim 44 | self.d_conv_dim = config.d_conv_dim 45 | self.g_repeat_num = config.g_repeat_num 46 | self.d_repeat_num = config.d_repeat_num 47 | 48 | # Hyper-parameteres 49 | self.lambda_idt = config.lambda_idt 50 | self.lambda_A = config.lambda_A 51 | self.lambda_B = config.lambda_B 52 | 53 | self.beta1 = config.beta1 54 | self.beta2 = config.beta2 55 | 56 | # Test settings 57 | self.test_model = config.test_model 58 | 59 | # Path 60 | self.log_path = config.log_path + '_' + config.task_name 61 | self.vis_path = config.vis_path + '_' + config.task_name 62 | self.snapshot_path = config.snapshot_path + '_' + config.task_name 63 | self.result_path = config.vis_path + '_' + config.task_name 64 | 65 | if not os.path.exists(self.log_path): 66 | os.makedirs(self.log_path) 67 | if not 
os.path.exists(self.vis_path): 68 | os.makedirs(self.vis_path) 69 | if not os.path.exists(self.snapshot_path): 70 | os.makedirs(self.snapshot_path) 71 | 72 | 73 | self.build_model() 74 | # Start with trained model 75 | if self.checkpoint: 76 | self.load_checkpoint() 77 | 78 | #for recording 79 | self.start_time = time.time() 80 | self.e = 0 81 | self.i = 0 82 | self.loss = {} 83 | 84 | if not os.path.exists(self.log_path): 85 | os.makedirs(self.log_path) 86 | if not os.path.exists(self.vis_path): 87 | os.makedirs(self.vis_path) 88 | if not os.path.exists(self.snapshot_path): 89 | os.makedirs(self.snapshot_path) 90 | 91 | def print_network(self, model, name): 92 | num_params = 0 93 | for p in model.parameters(): 94 | num_params += p.numel() 95 | print(name) 96 | print(model) 97 | print("The number of parameters: {}".format(num_params)) 98 | 99 | def update_lr(self, g_lr, d_lr): 100 | for param_group in self.g_optimizer.param_groups: 101 | param_group['lr'] = g_lr 102 | for param_group in self.d_A_optimizer.param_groups: 103 | param_group['lr'] = d_lr 104 | for param_group in self.d_B_optimizer.param_groups: 105 | param_group['lr'] = d_lr 106 | 107 | def log_terminal(self): 108 | elapsed = time.time() - self.start_time 109 | elapsed = str(datetime.timedelta(seconds=elapsed)) 110 | 111 | log = "Elapsed [{}], Epoch [{}/{}], Iter [{}/{}]".format( 112 | elapsed, self.e+1, self.num_epochs, self.i+1, self.iters_per_epoch) 113 | 114 | for tag, value in self.loss.items(): 115 | log += ", {}: {:.4f}".format(tag, value) 116 | print(log) 117 | 118 | def save_models(self): 119 | torch.save(self.G_A.state_dict(), 120 | os.path.join(self.snapshot_path, '{}_{}_G_A.pth'.format(self.e + 1, self.i + 1))) 121 | torch.save(self.G_B.state_dict(), 122 | os.path.join(self.snapshot_path, '{}_{}_G_B.pth'.format(self.e + 1, self.i + 1))) 123 | torch.save(self.D_A.state_dict(), 124 | os.path.join(self.snapshot_path, '{}_{}_D_A.pth'.format(self.e + 1, self.i + 1))) 125 | torch.save(self.D_B.state_dict(), 126 | os.path.join(self.snapshot_path, '{}_{}_D_B.pth'.format(self.e + 1, self.i + 1))) 127 | 128 | def weights_init_xavier(self, m): 129 | classname = m.__class__.__name__ 130 | if classname.find('Conv') != -1: 131 | init.xavier_normal(m.weight.data, gain=1.0) 132 | elif classname.find('Linear') != -1: 133 | init.xavier_normal(m.weight.data, gain=1.0) 134 | 135 | def to_var(self, x, requires_grad=True): 136 | if torch.cuda.is_available(): 137 | x = x.cuda() 138 | if not requires_grad: 139 | return Variable(x, requires_grad=requires_grad) 140 | else: 141 | return Variable(x) 142 | 143 | def denorm(self, x): 144 | out = (x + 1) / 2 145 | return out.clamp_(0, 1) 146 | 147 | def load_checkpoint(self): 148 | self.G_A.load_state_dict(torch.load(os.path.join( 149 | self.snapshot_path, '{}_G_A.pth'.format(self.checkpoint)))) 150 | self.G_B.load_state_dict(torch.load(os.path.join( 151 | self.snapshot_path, '{}_G_B.pth'.format(self.checkpoint)))) 152 | self.D_A.load_state_dict(torch.load(os.path.join( 153 | self.snapshot_path, '{}_D_A.pth'.format(self.checkpoint)))) 154 | self.D_B.load_state_dict(torch.load(os.path.join( 155 | self.snapshot_path, '{}_D_B.pth'.format(self.checkpoint)))) 156 | print('loaded trained models (step: {})..!'.format(self.checkpoint)) 157 | 158 | def build_model(self): 159 | # Define generators and discriminators 160 | self.G_A = net.Generator(self.g_conv_dim, self.g_repeat_num) 161 | self.G_B = net.Generator(self.g_conv_dim, self.g_repeat_num) 162 | self.D_A = net.Discriminator(self.img_size, 
self.d_conv_dim, self.d_repeat_num) 163 | self.D_B = net.Discriminator(self.img_size, self.d_conv_dim, self.d_repeat_num) 164 | self.criterionL1 = torch.nn.L1Loss() 165 | self.criterionGAN = GANLoss(use_lsgan=True, tensor=torch.cuda.FloatTensor) 166 | 167 | # Optimizers 168 | self.g_optimizer = torch.optim.Adam(itertools.chain(self.G_A.parameters(), self.G_B.parameters()), 169 | self.g_lr, [self.beta1, self.beta2]) 170 | self.d_A_optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, self.D_A.parameters()), self.d_lr, [self.beta1, self.beta2]) 171 | self.d_B_optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, self.D_B.parameters()), self.d_lr, [self.beta1, self.beta2]) 172 | 173 | self.G_A.apply(self.weights_init_xavier) 174 | self.D_A.apply(self.weights_init_xavier) 175 | self.G_B.apply(self.weights_init_xavier) 176 | self.D_B.apply(self.weights_init_xavier) 177 | 178 | # Print networks 179 | # self.print_network(self.E, 'E') 180 | self.print_network(self.G_A, 'G_A') 181 | self.print_network(self.D_A, 'D_A') 182 | self.print_network(self.G_B, 'G_B') 183 | self.print_network(self.D_B, 'D_B') 184 | 185 | if torch.cuda.is_available(): 186 | self.G_A.cuda() 187 | self.G_B.cuda() 188 | self.D_A.cuda() 189 | self.D_B.cuda() 190 | 191 | def train(self): 192 | """Train CycleGAN within a single dataset.""" 193 | # The number of iterations per epoch 194 | self.iters_per_epoch = len(self.data_loader_train) 195 | # Resume from a checkpointed epoch if one exists 196 | g_lr = self.g_lr 197 | d_lr = self.d_lr 198 | if self.checkpoint: 199 | start = int(self.checkpoint.split('_')[0]) 200 | else: 201 | start = 0 202 | # Start training 203 | self.start_time = time.time() 204 | for self.e in range(start, self.num_epochs): 205 | for self.i, (img_A, img_B, _, _) in enumerate(self.data_loader_train): 206 | # Convert tensor to variable 207 | org_A = self.to_var(img_A, requires_grad=False) 208 | ref_B = self.to_var(img_B, requires_grad=False) 209 | 210 | # ================== Train D ================== # 211 | # training D_A 212 | # Real 213 | out = self.D_A(ref_B) 214 | d_loss_real = self.criterionGAN(out, True) 215 | # Fake 216 | fake = self.G_A(org_A) 217 | # detach so the D update does not backprop into G 218 | fake = fake.detach() 219 | out = self.D_A(fake) 220 | # LSGAN loss for generated samples 221 | d_loss_fake = self.criterionGAN(out, False) 222 | 223 | # Backward + Optimize 224 | d_loss = (d_loss_real + d_loss_fake) * 0.5 225 | self.d_A_optimizer.zero_grad() 226 | d_loss.backward(retain_graph=True) 227 | self.d_A_optimizer.step() 228 | 229 | # Logging 230 | self.loss = {} 231 | self.loss['D-A-loss_real'] = d_loss_real.item() 232 | 233 | # training D_B 234 | # Real 235 | out = self.D_B(org_A) 236 | d_loss_real = self.criterionGAN(out, True) 237 | # Fake 238 | fake = self.G_B(ref_B) 239 | # detach so the D update does not backprop into G 240 | fake = fake.detach() 241 | out = self.D_B(fake) 242 | # LSGAN loss for generated samples 243 | d_loss_fake = self.criterionGAN(out, False) 244 | 245 | # Backward + Optimize 246 | d_loss = (d_loss_real + d_loss_fake) * 0.5 247 | self.d_B_optimizer.zero_grad() 248 | d_loss.backward(retain_graph=True) 249 | self.d_B_optimizer.step() 250 | 251 | # Logging 252 | self.loss['D-B-loss_real'] = d_loss_real.item() 253 | 254 | # ================== Train G ================== # 255 | if (self.i + 1) % self.ndis == 0: 256 | # adversarial loss, i.e.
L_trans,v in the paper 257 | 258 | # identity loss 259 | if self.lambda_idt > 0: 260 | # G_A should be identity if ref_B is fed 261 | idt_A = self.G_A(ref_B) 262 | loss_idt_A = self.criterionL1(idt_A, ref_B) * self.lambda_B * self.lambda_idt 263 | # G_B should be identity if org_A is fed 264 | idt_B = self.G_B(org_A) 265 | loss_idt_B = self.criterionL1(idt_B, org_A) * self.lambda_A * self.lambda_idt 266 | g_loss_idt = loss_idt_A + loss_idt_B 267 | else: 268 | g_loss_idt = 0 269 | 270 | # GAN loss D_A(G_A(A)) 271 | fake_B = self.G_A(org_A) 272 | pred_fake = self.D_A(fake_B) 273 | g_A_loss_adv = self.criterionGAN(pred_fake, True) 274 | # (LSGAN: G is rewarded when D scores its output as real) 275 | 276 | # GAN loss D_B(G_B(B)) 277 | fake_A = self.G_B(ref_B) 278 | pred_fake = self.D_B(fake_A) 279 | g_B_loss_adv = self.criterionGAN(pred_fake, True) 280 | 281 | # Forward cycle loss 282 | rec_A = self.G_B(fake_B) 283 | g_loss_rec_A = self.criterionL1(rec_A, org_A) * self.lambda_A 284 | 285 | # Backward cycle loss 286 | rec_B = self.G_A(fake_A) 287 | g_loss_rec_B = self.criterionL1(rec_B, ref_B) * self.lambda_B 288 | 289 | # Combined loss 290 | g_loss = g_A_loss_adv + g_B_loss_adv + g_loss_rec_A + g_loss_rec_B + g_loss_idt 291 | 292 | self.g_optimizer.zero_grad() 293 | g_loss.backward(retain_graph=True) 294 | self.g_optimizer.step() 295 | 296 | # Logging 297 | self.loss['G-A-loss_adv'] = g_A_loss_adv.item() 298 | self.loss['G-B-loss_adv'] = g_B_loss_adv.item() 299 | self.loss['G-loss_org'] = g_loss_rec_A.item() 300 | self.loss['G-loss_ref'] = g_loss_rec_B.item() 301 | self.loss['G-loss_idt'] = g_loss_idt.item() if self.lambda_idt > 0 else 0 302 | 303 | # Print out log info 304 | if (self.i + 1) % self.log_step == 0: 305 | self.log_terminal() 306 | 307 | # plot the figures 308 | for key_now in self.loss.keys(): 309 | plot_fig.plot(key_now, self.loss[key_now]) 310 | 311 | # save the images 312 | if (self.i + 1) % self.vis_step == 0: 313 | print("Saving middle output...") 314 | self.vis_train([org_A, ref_B, fake_A, fake_B, rec_A, rec_B]) 315 | self.vis_test() 316 | 317 | # Save model checkpoints 318 | if (self.i + 1) % self.snapshot_step == 0: 319 | self.save_models() 320 | 321 | if (self.i % 100 == 99): 322 | plot_fig.flush(self.task_name) 323 | 324 | plot_fig.tick() 325 | 326 | # Decay learning rate 327 | if (self.e+1) > (self.num_epochs - self.num_epochs_decay): 328 | g_lr -= (self.g_lr / float(self.num_epochs_decay)) 329 | d_lr -= (self.d_lr / float(self.num_epochs_decay)) 330 | self.update_lr(g_lr, d_lr) 331 | print('Decay learning rate to g_lr: {}, d_lr: {}.'.format(g_lr, d_lr)) 332 | 333 | def vis_train(self, img_train_list): 334 | # saving training results 335 | mode = "train_vis" 336 | img_train_list = torch.cat(img_train_list, dim=3) 337 | result_path_train = os.path.join(self.result_path, mode) 338 | if not os.path.exists(result_path_train): 339 | os.mkdir(result_path_train) 340 | save_path = os.path.join(result_path_train, '{}_{}_fake.jpg'.format(self.e, self.i)) 341 | save_image(self.denorm(img_train_list.data), save_path, normalize=True) 342 | 343 | def vis_test(self): 344 | # saving test results 345 | mode = "test_vis" 346 | for i, (img_A, img_B) in enumerate(self.data_loader_test): 347 | real_org = self.to_var(img_A) 348 | real_ref = self.to_var(img_B) 349 | 350 | image_list = [] 351 | image_list.append(real_org) 352 | image_list.append(real_ref) 353 | 354 | # Get translation results 355 | fake_A = self.G_A(real_org) 356 | fake_B = self.G_B(real_ref) 357 | rec_A = self.G_B(fake_A) 358 | rec_B = self.G_A(fake_B) 359 | 360 |
image_list.append(fake_A) 361 | image_list.append(fake_B) 362 | image_list.append(rec_A) 363 | image_list.append(rec_B) 364 | 365 | image_list = torch.cat(image_list, dim=3) 366 | vis_train_path = os.path.join(self.result_path, mode) 367 | result_path_now = os.path.join(vis_train_path, "epoch" + str(self.e)) 368 | if not os.path.exists(result_path_now): 369 | os.makedirs(result_path_now) 370 | save_path = os.path.join(result_path_now, '{}_{}_{}_fake.jpg'.format(self.e, self.i, i + 1)) 371 | save_image(self.denorm(image_list.data), save_path, normalize=True) 372 | #print('Translated test images and saved into "{}"..!'.format(save_path)) 373 | 374 | def test(self): 375 | # Load trained parameters 376 | G_A_path = os.path.join(self.snapshot_path, '{}_G_A.pth'.format(self.test_model)) 377 | G_B_path = os.path.join(self.snapshot_path, '{}_G_B.pth'.format(self.test_model)) 378 | self.G_A.load_state_dict(torch.load(G_A_path)) 379 | self.G_A.eval() 380 | self.G_B.load_state_dict(torch.load(G_B_path)) 381 | self.G_B.eval() 382 | for i, (img_A, img_B) in enumerate(self.data_loader_test): 383 | real_org = self.to_var(img_A) 384 | real_ref = self.to_var(img_B) 385 | 386 | image_list = [] 387 | image_list.append(real_org) 388 | image_list.append(real_ref) 389 | 390 | # Get translation results 391 | fake_A = self.G_A(real_org) 392 | fake_B = self.G_B(real_ref) 393 | rec_A = self.G_B(fake_A) 394 | rec_B = self.G_A(fake_B) 395 | 396 | image_list.append(fake_A) 397 | image_list.append(fake_B) 398 | image_list.append(rec_A) 399 | image_list.append(rec_B) 400 | 401 | image_list = torch.cat(image_list, dim=3) 402 | save_path = os.path.join(self.result_path, '{}_{}_{}_fake.png'.format(self.e, self.i, i + 1)) 403 | save_image(self.denorm(image_list.data), save_path, nrow=1, padding=0, normalize=True) 404 | print('Translated test images and saved into "{}".'.format(save_path)) 405 | -------------------------------------------------------------------------------- /solver_makeup.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.init as init 3 | from torch.autograd import Variable 4 | from torchvision.utils import save_image 5 | 6 | import os 7 | import time 8 | import datetime 9 | 10 | import tools.plot as plot_fig 11 | import net 12 | from ops.histogram_matching import * 13 | from ops.loss_added import GANLoss 14 | 15 | class Solver_makeupGAN(object): 16 | def __init__(self, data_loaders, config, dataset_config): 17 | # dataloader 18 | self.checkpoint = config.checkpoint 19 | # Hyper-parameters 20 | self.g_lr = config.G_LR 21 | self.d_lr = config.D_LR 22 | self.ndis = config.ndis 23 | self.num_epochs = config.num_epochs # typically 200 24 | self.num_epochs_decay = config.num_epochs_decay 25 | self.batch_size = config.batch_size 26 | self.whichG = config.whichG 27 | self.norm = config.norm 28 | 29 | # Training settings 30 | self.snapshot_step = config.snapshot_step 31 | self.log_step = config.log_step 32 | self.vis_step = config.vis_step 33 | 34 | # Task name 35 | self.task_name = config.task_name 36 | 37 | # Data loader 38 | self.data_loader_train = data_loaders[0] 39 | self.data_loader_test = data_loaders[1] 40 | 41 | # Model hyper-parameters 42 | self.img_size = config.img_size 43 | self.g_conv_dim = config.g_conv_dim 44 | self.d_conv_dim = config.d_conv_dim 45 | self.g_repeat_num = config.g_repeat_num 46 | self.d_repeat_num = config.d_repeat_num 47 | self.lips = config.lips 48 | self.skin = config.skin 49 | self.eye = config.eye 50 | 51 | #
Loss weights 52 | self.lambda_idt = config.lambda_idt 53 | self.lambda_A = config.lambda_A 54 | self.lambda_B = config.lambda_B 55 | self.lambda_his_lip = config.lambda_his_lip 56 | self.lambda_his_skin_1 = config.lambda_his_skin_1 57 | self.lambda_his_skin_2 = config.lambda_his_skin_2 58 | self.lambda_his_eye = config.lambda_his_eye 59 | self.lambda_vgg = config.lambda_vgg 60 | 61 | self.beta1 = config.beta1 62 | self.beta2 = config.beta2 63 | 64 | self.cls = config.cls_list 65 | self.content_layer = config.content_layer 66 | self.direct = config.direct 67 | # Test settings 68 | self.test_model = config.test_model 69 | 70 | # Path 71 | self.log_path = config.log_path + '_' + config.task_name 72 | self.vis_path = config.vis_path + '_' + config.task_name 73 | self.snapshot_path = config.snapshot_path + '_' + config.task_name 74 | self.result_path = config.vis_path + '_' + config.task_name 75 | 76 | if not os.path.exists(self.log_path): 77 | os.makedirs(self.log_path) 78 | if not os.path.exists(self.vis_path): 79 | os.makedirs(self.vis_path) 80 | if not os.path.exists(self.snapshot_path): 81 | os.makedirs(self.snapshot_path) 82 | 83 | self.build_model() 84 | # Start with trained model 85 | if self.checkpoint: 86 | self.load_checkpoint() 87 | 88 | # for recording 89 | self.start_time = time.time() 90 | self.e = 0 91 | self.i = 0 92 | self.loss = {} 93 | 94 | if not os.path.exists(self.log_path): 95 | os.makedirs(self.log_path) 96 | if not os.path.exists(self.vis_path): 97 | os.makedirs(self.vis_path) 98 | if not os.path.exists(self.snapshot_path): 99 | os.makedirs(self.snapshot_path) 100 | 101 | def print_network(self, model, name): 102 | num_params = 0 103 | for p in model.parameters(): 104 | num_params += p.numel() 105 | print(name) 106 | print(model) 107 | print("The number of parameters: {}".format(num_params)) 108 | 109 | def update_lr(self, g_lr, d_lr): 110 | for param_group in self.g_optimizer.param_groups: 111 | param_group['lr'] = g_lr 112 | for i in self.cls: 113 | for param_group in getattr(self, "d_" + i + "_optimizer").param_groups: 114 | param_group['lr'] = d_lr 115 | 116 | def log_terminal(self): 117 | elapsed = time.time() - self.start_time 118 | elapsed = str(datetime.timedelta(seconds=elapsed)) 119 | 120 | log = "Elapsed [{}], Epoch [{}/{}], Iter [{}/{}]".format( 121 | elapsed, self.e+1, self.num_epochs, self.i+1, self.iters_per_epoch) 122 | 123 | for tag, value in self.loss.items(): 124 | log += ", {}: {:.4f}".format(tag, value) 125 | print(log) 126 | 127 | def save_models(self): 128 | torch.save(self.G.state_dict(), 129 | os.path.join(self.snapshot_path, '{}_{}_G.pth'.format(self.e + 1, self.i + 1))) 130 | for i in self.cls: 131 | torch.save(getattr(self, "D_" + i).state_dict(), 132 | os.path.join(self.snapshot_path, '{}_{}_D_'.format(self.e + 1, self.i + 1) + i + '.pth')) 133 | 134 | def weights_init_xavier(self, m): 135 | classname = m.__class__.__name__ 136 | if classname.find('Conv') != -1: 137 | init.xavier_normal_(m.weight.data, gain=1.0) 138 | elif classname.find('Linear') != -1: 139 | init.xavier_normal_(m.weight.data, gain=1.0) 140 | 141 | def to_var(self, x, requires_grad=True): 142 | if torch.cuda.is_available(): 143 | x = x.cuda() 144 | if not requires_grad: 145 | return Variable(x, requires_grad=False) 146 | else: 147 | return Variable(x, requires_grad=True) 148 | 149 | def de_norm(self, x): 150 | out = (x + 1) / 2 151 | return out.clamp(0, 1) 152 | 153 | def load_checkpoint(self): 154 | self.G.load_state_dict(torch.load(os.path.join( 155 | self.snapshot_path,
'{}_G.pth'.format(self.checkpoint)))) 156 | for i in self.cls: 157 | getattr(self, "D_" + i).load_state_dict(torch.load(os.path.join( 158 | self.snapshot_path, '{}_D_'.format(self.checkpoint) + i + '.pth'))) 159 | print('Loaded trained models (step: {}).'.format(self.checkpoint)) 160 | 161 | def build_model(self): 162 | # Define generators and discriminators 163 | if self.whichG == 'normal': 164 | self.G = net.Generator_makeup(self.g_conv_dim, self.g_repeat_num) 165 | elif self.whichG == 'branch': 166 | self.G = net.Generator_branch(self.g_conv_dim, self.g_repeat_num) 167 | for i in self.cls: 168 | setattr(self, "D_" + i, net.Discriminator(self.img_size, self.d_conv_dim, self.d_repeat_num, self.norm)) 169 | 170 | self.criterionL1 = torch.nn.L1Loss() 171 | self.criterionL2 = torch.nn.MSELoss() 172 | self.criterionGAN = GANLoss(use_lsgan=True, tensor=torch.cuda.FloatTensor) 173 | self.vgg = net.VGG() 174 | self.vgg.load_state_dict(torch.load('addings/vgg_conv.pth')) 175 | # Optimizers 176 | self.g_optimizer = torch.optim.Adam(self.G.parameters(), self.g_lr, [self.beta1, self.beta2]) 177 | for i in self.cls: 178 | setattr(self, "d_" + i + "_optimizer", \ 179 | torch.optim.Adam(filter(lambda p: p.requires_grad, getattr(self, "D_" + i).parameters()), \ 180 | self.d_lr, [self.beta1, self.beta2])) 181 | 182 | # Weights initialization 183 | self.G.apply(self.weights_init_xavier) 184 | for i in self.cls: 185 | getattr(self, "D_" + i).apply(self.weights_init_xavier) 186 | 187 | # Print networks 188 | self.print_network(self.G, 'G') 189 | for i in self.cls: 190 | self.print_network(getattr(self, "D_" + i), "D_" + i) 191 | 192 | if torch.cuda.is_available(): 193 | self.G.cuda() 194 | self.vgg.cuda() 195 | for i in self.cls: 196 | getattr(self, "D_" + i).cuda() 197 | 198 | def rebound_box(self, mask_A, mask_B, mask_A_face): 199 | index_tmp = mask_A.nonzero() 200 | x_A_index = index_tmp[:, 2] 201 | y_A_index = index_tmp[:, 3] 202 | index_tmp = mask_B.nonzero() 203 | x_B_index = index_tmp[:, 2] 204 | y_B_index = index_tmp[:, 3] 205 | mask_A_temp = mask_A.clone() 206 | mask_B_temp = mask_B.clone() 207 | mask_A_temp[:, :, min(x_A_index)-10:max(x_A_index)+11, min(y_A_index)-10:max(y_A_index)+11] =\ 208 | mask_A_face[:, :, min(x_A_index)-10:max(x_A_index)+11, min(y_A_index)-10:max(y_A_index)+11] 209 | mask_B_temp[:, :, min(x_B_index)-10:max(x_B_index)+11, min(y_B_index)-10:max(y_B_index)+11] =\ 210 | mask_A_face[:, :, min(x_B_index)-10:max(x_B_index)+11, min(y_B_index)-10:max(y_B_index)+11] 211 | mask_A_temp = self.to_var(mask_A_temp, requires_grad=False) 212 | mask_B_temp = self.to_var(mask_B_temp, requires_grad=False) 213 | return mask_A_temp, mask_B_temp 214 | 215 | def mask_preprocess(self, mask_A, mask_B): 216 | index_tmp = mask_A.nonzero() 217 | x_A_index = index_tmp[:, 2] 218 | y_A_index = index_tmp[:, 3] 219 | index_tmp = mask_B.nonzero() 220 | x_B_index = index_tmp[:, 2] 221 | y_B_index = index_tmp[:, 3] 222 | mask_A = self.to_var(mask_A, requires_grad=False) 223 | mask_B = self.to_var(mask_B, requires_grad=False) 224 | index = [x_A_index, y_A_index, x_B_index, y_B_index] 225 | index_2 = [x_B_index, y_B_index, x_A_index, y_A_index] 226 | return mask_A, mask_B, index, index_2 227 | 228 | def criterionHis(self, input_data, target_data, mask_src, mask_tar, index): 229 | input_data = (self.de_norm(input_data) * 255).squeeze() 230 | target_data = (self.de_norm(target_data) * 255).squeeze() 231 | mask_src = mask_src.expand(1, 3, mask_src.size(2), mask_src.size(3)).squeeze() 232 | mask_tar =
mask_tar.expand(1, 3, mask_tar.size(2), mask_tar.size(3)).squeeze() 233 | input_masked = input_data * mask_src 234 | target_masked = target_data * mask_tar 235 | # match the masked source region's color histogram to the target region; 236 | # histogram_matching returns the recolored source pixels 237 | input_match = histogram_matching(input_masked, target_masked, index) 238 | input_match = self.to_var(input_match, requires_grad=False) 239 | loss = self.criterionL1(input_masked, input_match) 240 | return loss 241 | 242 | def train(self): 243 | """Train BeautyGAN (makeup transfer) within a single dataset.""" 244 | # The number of iterations per epoch 245 | self.iters_per_epoch = len(self.data_loader_train) 246 | # Resume from a checkpointed epoch if one exists 247 | cls_A = self.cls[0] 248 | cls_B = self.cls[1] 249 | g_lr = self.g_lr 250 | d_lr = self.d_lr 251 | if self.checkpoint: 252 | start = int(self.checkpoint.split('_')[0]) 253 | self.vis_test() 254 | else: 255 | start = 0 256 | # Start training 257 | self.start_time = time.time() 258 | for self.e in range(start, self.num_epochs): 259 | for self.i, (img_A, img_B, mask_A, mask_B) in enumerate(self.data_loader_train): 260 | # Convert tensor to variable 261 | # mask attributes: 0: background 1: face 2: left-eyebrow 3: right-eyebrow 4: left-eye 5: right-eye 6: nose 262 | # 7: upper-lip 8: teeth 9: under-lip 10: hair 11: left-ear 12: right-ear 13: neck 263 | if self.checkpoint or self.direct: 264 | if self.lips: 265 | mask_A_lip = (mask_A==7).float() + (mask_A==9).float() 266 | mask_B_lip = (mask_B==7).float() + (mask_B==9).float() 267 | mask_A_lip, mask_B_lip, index_A_lip, index_B_lip = self.mask_preprocess(mask_A_lip, mask_B_lip) 268 | if self.skin: 269 | mask_A_skin = (mask_A==1).float() + (mask_A==6).float() + (mask_A==13).float() 270 | mask_B_skin = (mask_B==1).float() + (mask_B==6).float() + (mask_B==13).float() 271 | mask_A_skin, mask_B_skin, index_A_skin, index_B_skin = self.mask_preprocess(mask_A_skin, mask_B_skin) 272 | if self.eye: 273 | mask_A_eye_left = (mask_A==4).float() 274 | mask_A_eye_right = (mask_A==5).float() 275 | mask_B_eye_left = (mask_B==4).float() 276 | mask_B_eye_right = (mask_B==5).float() 277 | mask_A_face = (mask_A==1).float() + (mask_A==6).float() 278 | mask_B_face = (mask_B==1).float() + (mask_B==6).float() 279 | # skip images where either eye is closed (empty eye mask) 280 | if not ((mask_A_eye_left>0).any() and (mask_B_eye_left>0).any() and \ 281 | (mask_A_eye_right > 0).any() and (mask_B_eye_right > 0).any()): 282 | continue 283 | mask_A_eye_left, mask_A_eye_right = self.rebound_box(mask_A_eye_left, mask_A_eye_right, mask_A_face) 284 | mask_B_eye_left, mask_B_eye_right = self.rebound_box(mask_B_eye_left, mask_B_eye_right, mask_B_face) 285 | mask_A_eye_left, mask_B_eye_left, index_A_eye_left, index_B_eye_left = \ 286 | self.mask_preprocess(mask_A_eye_left, mask_B_eye_left) 287 | mask_A_eye_right, mask_B_eye_right, index_A_eye_right, index_B_eye_right = \ 288 | self.mask_preprocess(mask_A_eye_right, mask_B_eye_right) 289 | 290 | org_A = self.to_var(img_A, requires_grad=False) 291 | ref_B = self.to_var(img_B, requires_grad=False) 292 | # ================== Train D ================== # 293 | # training D_A, D_A aims to distinguish class B 294 | # Real 295 | out = getattr(self, "D_" + cls_A)(ref_B) 296 | d_loss_real = self.criterionGAN(out, True) 297 | # Fake 298 | fake_A, fake_B = self.G(org_A, ref_B) 299 | fake_A = fake_A.detach() 300 | fake_B = fake_B.detach() 301 | out = getattr(self, "D_" + cls_A)(fake_A) 302 | # LSGAN loss for
self.get_D_loss(out, "fake") 303 | d_loss_fake = self.criterionGAN(out, False) 304 | 305 | # Backward + Optimize 306 | d_loss = (d_loss_real + d_loss_fake) * 0.5 307 | getattr(self, "d_" + cls_A + "_optimizer").zero_grad() 308 | d_loss.backward(retain_graph=True) 309 | getattr(self, "d_" + cls_A + "_optimizer").step() 310 | 311 | # Logging 312 | self.loss = {} 313 | self.loss['D-A-loss_real'] = d_loss_real.item() 314 | 315 | # training D_B, D_B aims to distinguish class A 316 | # Real 317 | out = getattr(self, "D_" + cls_B)(org_A) 318 | d_loss_real = self.criterionGAN(out, True) 319 | # Fake 320 | out = getattr(self, "D_" + cls_B)(fake_B) 321 | #d_loss_fake = self.get_D_loss(out, "fake") 322 | d_loss_fake = self.criterionGAN(out, False) 323 | 324 | # Backward + Optimize 325 | d_loss = (d_loss_real + d_loss_fake) * 0.5 326 | getattr(self, "d_" + cls_B + "_optimizer").zero_grad() 327 | d_loss.backward(retain_graph=True) 328 | getattr(self, "d_" + cls_B + "_optimizer").step() 329 | 330 | # Logging 331 | self.loss['D-B-loss_real'] = d_loss_real.item() 332 | 333 | # ================== Train G ================== # 334 | if (self.i + 1) % self.ndis == 0: 335 | # adversarial loss, i.e. L_trans,v in the paper 336 | 337 | # identity loss 338 | if self.lambda_idt > 0: 339 | # G should be identity if ref_B or org_A is fed 340 | idt_A1, idt_A2 = self.G(org_A, org_A) 341 | idt_B1, idt_B2 = self.G(ref_B, ref_B) 342 | loss_idt_A1 = self.criterionL1(idt_A1, org_A) * self.lambda_A * self.lambda_idt 343 | loss_idt_A2 = self.criterionL1(idt_A2, org_A) * self.lambda_A * self.lambda_idt 344 | loss_idt_B1 = self.criterionL1(idt_B1, ref_B) * self.lambda_B * self.lambda_idt 345 | loss_idt_B2 = self.criterionL1(idt_B2, ref_B) * self.lambda_B * self.lambda_idt 346 | # loss_idt 347 | loss_idt = (loss_idt_A1 + loss_idt_A2 + loss_idt_B1 + loss_idt_B2) * 0.5 348 | else: 349 | loss_idt = 0 350 | 351 | # GAN loss D_A(G_A(A)) 352 | # fake_A in class B, 353 | fake_A, fake_B = self.G(org_A, ref_B) 354 | pred_fake = getattr(self, "D_" + cls_A)(fake_A) 355 | g_A_loss_adv = self.criterionGAN(pred_fake, True) 356 | #g_loss_adv = self.get_G_loss(out) 357 | # GAN loss D_B(G_B(B)) 358 | pred_fake = getattr(self, "D_" + cls_B)(fake_B) 359 | g_B_loss_adv = self.criterionGAN(pred_fake, True) 360 | rec_B, rec_A = self.G(fake_B, fake_A) 361 | 362 | # color_histogram loss 363 | g_A_loss_his = 0 364 | g_B_loss_his = 0 365 | if self.checkpoint or self.direct: 366 | if self.lips==True: 367 | g_A_lip_loss_his = self.criterionHis(fake_A, ref_B, mask_A_lip, mask_B_lip, index_A_lip) * self.lambda_his_lip 368 | g_B_lip_loss_his = self.criterionHis(fake_B, org_A, mask_B_lip, mask_A_lip, index_B_lip) * self.lambda_his_lip 369 | g_A_loss_his += g_A_lip_loss_his 370 | g_B_loss_his += g_B_lip_loss_his 371 | if self.skin==True: 372 | g_A_skin_loss_his = self.criterionHis(fake_A, ref_B, mask_A_skin, mask_B_skin, index_A_skin) * self.lambda_his_skin_1 373 | g_B_skin_loss_his = self.criterionHis(fake_B, org_A, mask_B_skin, mask_A_skin, index_B_skin) * self.lambda_his_skin_2 374 | g_A_loss_his += g_A_skin_loss_his 375 | g_B_loss_his += g_B_skin_loss_his 376 | if self.eye==True: 377 | g_A_eye_left_loss_his = self.criterionHis(fake_A, ref_B, mask_A_eye_left, mask_B_eye_left, index_A_eye_left) * self.lambda_his_eye 378 | g_B_eye_left_loss_his = self.criterionHis(fake_B, org_A, mask_B_eye_left, mask_A_eye_left, index_B_eye_left) * self.lambda_his_eye 379 | g_A_eye_right_loss_his = self.criterionHis(fake_A, ref_B, mask_A_eye_right, mask_B_eye_right, 
index_A_eye_right) * self.lambda_his_eye 380 | g_B_eye_right_loss_his = self.criterionHis(fake_B, org_A, mask_B_eye_right, mask_A_eye_right, index_B_eye_right) * self.lambda_his_eye 381 | g_A_loss_his += g_A_eye_left_loss_his + g_A_eye_right_loss_his 382 | g_B_loss_his += g_B_eye_left_loss_his + g_B_eye_right_loss_his 383 | 384 | # cycle loss 385 | g_loss_rec_A = self.criterionL1(rec_A, org_A) * self.lambda_A 386 | g_loss_rec_B = self.criterionL1(rec_B, ref_B) * self.lambda_B 387 | 388 | # vgg loss 389 | vgg_org = self.vgg(org_A, self.content_layer)[0] 390 | vgg_org = vgg_org.detach() 391 | vgg_fake_A = self.vgg(fake_A, self.content_layer)[0] 392 | g_loss_A_vgg = self.criterionL2(vgg_fake_A, vgg_org) * self.lambda_A * self.lambda_vgg 393 | 394 | vgg_ref = self.vgg(ref_B, self.content_layer)[0] 395 | vgg_ref = vgg_ref.detach() 396 | vgg_fake_B = self.vgg(fake_B, self.content_layer)[0] 397 | g_loss_B_vgg = self.criterionL2(vgg_fake_B, vgg_ref) * self.lambda_B * self.lambda_vgg 398 | 399 | loss_rec = (g_loss_rec_A + g_loss_rec_B + g_loss_A_vgg + g_loss_B_vgg) * 0.5 400 | 401 | # Combined loss 402 | g_loss = g_A_loss_adv + g_B_loss_adv + loss_rec + loss_idt 403 | if self.checkpoint or self.direct: 404 | g_loss = g_A_loss_adv + g_B_loss_adv + loss_rec + loss_idt + g_A_loss_his + g_B_loss_his 405 | 406 | self.g_optimizer.zero_grad() 407 | g_loss.backward(retain_graph=True) 408 | self.g_optimizer.step() 409 | 410 | # Logging 411 | self.loss['G-A-loss-adv'] = g_A_loss_adv.item() 412 | self.loss['G-B-loss-adv'] = g_B_loss_adv.item() 413 | self.loss['G-loss-org'] = g_loss_rec_A.item() 414 | self.loss['G-loss-ref'] = g_loss_rec_B.item() 415 | self.loss['G-loss-idt'] = loss_idt.item() if self.lambda_idt > 0 else 0 416 | self.loss['G-loss-img-rec'] = (g_loss_rec_A + g_loss_rec_B).item() 417 | self.loss['G-loss-vgg-rec'] = (g_loss_A_vgg + g_loss_B_vgg).item() 418 | if self.direct: 419 | self.loss['G-A-loss-his'] = g_A_loss_his.item() 420 | self.loss['G-B-loss-his'] = g_B_loss_his.item() 421 | 422 | # Print out log info 423 | if (self.i + 1) % self.log_step == 0: 424 | self.log_terminal() 425 | 426 | # plot the figures 427 | for key_now in self.loss.keys(): 428 | plot_fig.plot(key_now, self.loss[key_now]) 429 | 430 | # save the images 431 | if (self.i + 1) % self.vis_step == 0: 432 | print("Saving middle output...") 433 | self.vis_train([org_A, ref_B, fake_A, fake_B, rec_A, rec_B]) 434 | 435 | 436 | # Save model checkpoints 437 | if (self.i + 1) % self.snapshot_step == 0: 438 | self.save_models() 439 | 440 | if (self.i % 100 == 99): 441 | plot_fig.flush(self.task_name) 442 | 443 | plot_fig.tick() 444 | 445 | # Decay learning rate 446 | if (self.e+1) > (self.num_epochs - self.num_epochs_decay): 447 | g_lr -= (self.g_lr / float(self.num_epochs_decay)) 448 | d_lr -= (self.d_lr / float(self.num_epochs_decay)) 449 | self.update_lr(g_lr, d_lr) 450 | print('Decay learning rate to g_lr: {}, d_lr: {}.'.format(g_lr, d_lr)) 451 | 452 | if self.e % 2 == 0: 453 | print("Saving output...") 454 | self.vis_test() 455 | 456 | def vis_train(self, img_train_list): 457 | # saving training results 458 | mode = "train_vis" 459 | img_train_list = torch.cat(img_train_list, dim=3) 460 | result_path_train = os.path.join(self.result_path, mode) 461 | if not os.path.exists(result_path_train): 462 | os.mkdir(result_path_train) 463 | save_path = os.path.join(result_path_train, '{}_{}_fake.jpg'.format(self.e, self.i)) 464 | save_image(self.de_norm(img_train_list.data), save_path, normalize=True) 465 | 466 | def vis_test(self): 467 | #
saving test results 468 | mode = "test_vis" 469 | for i, (img_A, img_B) in enumerate(self.data_loader_test): 470 | real_org = self.to_var(img_A) 471 | real_ref = self.to_var(img_B) 472 | 473 | image_list = [] 474 | image_list.append(real_org) 475 | image_list.append(real_ref) 476 | 477 | # Get makeup result 478 | fake_A, fake_B = self.G(real_org, real_ref) 479 | rec_B, rec_A = self.G(fake_B, fake_A) 480 | 481 | image_list.append(fake_A) 482 | image_list.append(fake_B) 483 | image_list.append(rec_A) 484 | image_list.append(rec_B) 485 | 486 | image_list = torch.cat(image_list, dim=3) 487 | vis_train_path = os.path.join(self.result_path, mode) 488 | result_path_now = os.path.join(vis_train_path, "epoch" + str(self.e)) 489 | if not os.path.exists(result_path_now): 490 | os.makedirs(result_path_now) 491 | save_path = os.path.join(result_path_now, '{}_{}_{}_fake.png'.format(self.e, self.i, i + 1)) 492 | save_image(self.de_norm(image_list.data), save_path, normalize=True) 493 | #print('Translated test images and saved into "{}"..!'.format(save_path)) 494 | 495 | def test(self): 496 | # Load trained parameters 497 | G_path = os.path.join(self.snapshot_path, '{}_G.pth'.format(self.test_model)) 498 | self.G.load_state_dict(torch.load(G_path)) 499 | self.G.eval() 500 | # accumulate pure inference time over the test set 501 | time_total = 0 502 | for i, (img_A, img_B) in enumerate(self.data_loader_test): 503 | # time a single forward pass 504 | start = time.time() 505 | real_org = self.to_var(img_A) 506 | real_ref = self.to_var(img_B) 507 | 508 | image_list = [] 509 | image_list_0 = [] 510 | image_list.append(real_org) 511 | image_list.append(real_ref) 512 | 513 | # Get makeup result 514 | fake_A, fake_B = self.G(real_org, real_ref) 515 | rec_B, rec_A = self.G(fake_B, fake_A) 516 | time_total += time.time() - start 517 | image_list.append(fake_A) 518 | image_list_0.append(fake_A) 519 | image_list.append(fake_B) 520 | image_list.append(rec_A) 521 | image_list.append(rec_B) 522 | 523 | image_list = torch.cat(image_list, dim=3) 524 | image_list_0 = torch.cat(image_list_0, dim=3) 525 | 526 | result_path_now = os.path.join(self.result_path, "multi") 527 | if not os.path.exists(result_path_now): 528 | os.makedirs(result_path_now) 529 | save_path = os.path.join(result_path_now, '{}_{}_{}_fake.png'.format(self.e, self.i, i + 1)) 530 | save_image(self.de_norm(image_list.data), save_path, nrow=1, padding=0, normalize=True) 531 | result_path_now = os.path.join(self.result_path, "single") 532 | if not os.path.exists(result_path_now): 533 | os.makedirs(result_path_now) 534 | save_path_0 = os.path.join(result_path_now, '{}_{}_{}_fake_single.png'.format(self.e, self.i, i + 1)) 535 | save_image(self.de_norm(image_list_0.data), save_path_0, nrow=1, padding=0, normalize=True) 536 | print('Translated test images and saved into "{}".'.format(save_path)) 537 | print("Average time per batch: {}".format(time_total/len(self.data_loader_test))) 538 | --------------------------------------------------------------------------------
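Usage sketch. Both solvers share the same driving surface: construct with a [train_loader, test_loader] pair plus the config objects from config.py, then call train() or test(). The dump does not include train.py, so the sketch below is a minimal, assumption-laden illustration rather than the repository's actual entry point: the config lookup mirrors the one visualize.py performs, and `train_loader`/`test_loader` are hypothetical stand-ins for however the MAKEUP dataset is wrapped into DataLoaders.

    # minimal driver sketch, assuming `config`/`default`/`dataset_config` come from config.py
    from config import config, default, dataset_config
    from solver_makeup import Solver_makeupGAN

    cfg = config[default.network]  # assumption: same per-network lookup visualize.py uses
    cfg.network = default.network

    # hypothetical: a MAKEUP train/test DataLoader pair built elsewhere (see data_loaders.MAKEUP)
    data_loaders = [train_loader, test_loader]

    solver = Solver_makeupGAN(data_loaders, cfg, dataset_config)
    solver.train()  # or solver.test() once cfg.test_model names a saved snapshot such as '66_2520'

The snapshot name '66_2520' follows the '{epoch}_{iter}' convention used by save_models() and appears in test.sh; everything else above should be checked against train.py before use.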