├── data ├── __init__.py ├── data_loader.py ├── celeba.py └── base_dataset.py ├── requirements.txt ├── imgs ├── celeba_testing.jpg ├── gifs │ ├── celeba_1.gif │ ├── celeba_2.gif │ ├── emotionnet_1.gif │ ├── emotionnet_2.gif │ ├── emotionnet_3.gif │ └── emotionnet_4.gif ├── celeba_training.jpg ├── ganimation_show.jpg ├── emotionnet_testing.jpg ├── emotionnet_training.jpg ├── celeba_stargan_testing.jpg └── emotionnet_stargan_testing.jpg ├── main.py ├── model ├── __init__.py ├── stargan.py ├── ganimation.py ├── base_model.py └── model_utils.py ├── LICENSE ├── .gitignore ├── visualizer.py ├── solvers.py ├── README.md └── options.py /data/__init__.py: -------------------------------------------------------------------------------- 1 | from .data_loader import create_dataloader -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch>=0.4.1 2 | torchvision>=0.2.1 3 | visdom>=0.1.8.3 4 | imageio>=2.5.0 -------------------------------------------------------------------------------- /imgs/celeba_testing.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/donydchen/ganimation_replicate/HEAD/imgs/celeba_testing.jpg -------------------------------------------------------------------------------- /imgs/gifs/celeba_1.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/donydchen/ganimation_replicate/HEAD/imgs/gifs/celeba_1.gif -------------------------------------------------------------------------------- /imgs/gifs/celeba_2.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/donydchen/ganimation_replicate/HEAD/imgs/gifs/celeba_2.gif -------------------------------------------------------------------------------- /imgs/celeba_training.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/donydchen/ganimation_replicate/HEAD/imgs/celeba_training.jpg -------------------------------------------------------------------------------- /imgs/ganimation_show.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/donydchen/ganimation_replicate/HEAD/imgs/ganimation_show.jpg -------------------------------------------------------------------------------- /imgs/gifs/emotionnet_1.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/donydchen/ganimation_replicate/HEAD/imgs/gifs/emotionnet_1.gif -------------------------------------------------------------------------------- /imgs/gifs/emotionnet_2.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/donydchen/ganimation_replicate/HEAD/imgs/gifs/emotionnet_2.gif -------------------------------------------------------------------------------- /imgs/gifs/emotionnet_3.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/donydchen/ganimation_replicate/HEAD/imgs/gifs/emotionnet_3.gif -------------------------------------------------------------------------------- /imgs/gifs/emotionnet_4.gif: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/donydchen/ganimation_replicate/HEAD/imgs/gifs/emotionnet_4.gif -------------------------------------------------------------------------------- /imgs/emotionnet_testing.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/donydchen/ganimation_replicate/HEAD/imgs/emotionnet_testing.jpg -------------------------------------------------------------------------------- /imgs/emotionnet_training.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/donydchen/ganimation_replicate/HEAD/imgs/emotionnet_training.jpg -------------------------------------------------------------------------------- /imgs/celeba_stargan_testing.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/donydchen/ganimation_replicate/HEAD/imgs/celeba_stargan_testing.jpg -------------------------------------------------------------------------------- /imgs/emotionnet_stargan_testing.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/donydchen/ganimation_replicate/HEAD/imgs/emotionnet_stargan_testing.jpg -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | """ 2 | Created on Dec 13, 2018 3 | @author: Yuedong Chen 4 | """ 5 | 6 | from options import Options 7 | from solvers import create_solver 8 | 9 | 10 | 11 | 12 | if __name__ == '__main__': 13 | opt = Options().parse() 14 | 15 | solver = create_solver(opt) 16 | solver.run_solver() 17 | 18 | print('[THE END]') -------------------------------------------------------------------------------- /model/__init__.py: -------------------------------------------------------------------------------- 1 | from .base_model import BaseModel 2 | from .ganimation import GANimationModel 3 | from .stargan import StarGANModel 4 | 5 | 6 | 7 | def create_model(opt): 8 | # specify model name here 9 | if opt.model == "ganimation": 10 | instance = GANimationModel() 11 | elif opt.model == "stargan": 12 | instance = StarGANModel() 13 | else: 14 | instance = BaseModel() 15 | instance.initialize(opt) 16 | instance.setup() 17 | return instance 18 | 19 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Yuedong Chen (Donald) 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /data/data_loader.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | from PIL import Image 4 | import random 5 | import numpy as np 6 | import pickle 7 | import torchvision.transforms as transforms 8 | 9 | from .celeba import CelebADataset 10 | 11 | 12 | def create_dataloader(opt): 13 | data_loader = DataLoader() 14 | data_loader.initialize(opt) 15 | return data_loader 16 | 17 | 18 | class DataLoader: 19 | def name(self): 20 | return self.dataset.name() + "_Loader" 21 | 22 | def create_datase(self): 23 | # specify which dataset to load here 24 | loaded_dataset = os.path.basename(self.opt.data_root.strip('/')).lower() 25 | if 'celeba' in loaded_dataset or 'emotion' in loaded_dataset: 26 | dataset = CelebADataset() 27 | else: 28 | dataset = BaseDataset() 29 | dataset.initialize(self.opt) 30 | return dataset 31 | 32 | def initialize(self, opt): 33 | self.opt = opt 34 | self.dataset = self.create_datase() 35 | self.dataloader = torch.utils.data.DataLoader( 36 | self.dataset, 37 | batch_size=opt.batch_size, 38 | shuffle=not opt.serial_batches, 39 | num_workers=int(opt.n_threads) 40 | ) 41 | 42 | def __len__(self): 43 | return min(len(self.dataset), self.opt.max_dataset_size) 44 | 45 | def __iter__(self): 46 | for i, data in enumerate(self.dataloader): 47 | if i * self.opt.batch_size >= self.opt.max_dataset_size: 48 | break 49 | yield data 50 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | 106 | # project related 107 | results 108 | ckpts 109 | sftp-config.json 110 | 111 | -------------------------------------------------------------------------------- /data/celeba.py: -------------------------------------------------------------------------------- 1 | from .base_dataset import BaseDataset 2 | import os 3 | import random 4 | import numpy as np 5 | 6 | 7 | class CelebADataset(BaseDataset): 8 | """docstring for CelebADataset""" 9 | def __init__(self): 10 | super(CelebADataset, self).__init__() 11 | 12 | def initialize(self, opt): 13 | super(CelebADataset, self).initialize(opt) 14 | 15 | def get_aus_by_path(self, img_path): 16 | assert os.path.isfile(img_path), "Cannot find image file: %s" % img_path 17 | img_id = str(os.path.splitext(os.path.basename(img_path))[0]) 18 | return self.aus_dict[img_id] / 5.0 # norm to [0, 1] 19 | 20 | def make_dataset(self): 21 | # return all image full path in a list 22 | imgs_path = [] 23 | assert os.path.isfile(self.imgs_name_file), "%s does not exist." 
% self.imgs_name_file 24 | with open(self.imgs_name_file, 'r') as f: 25 | lines = f.readlines() 26 | imgs_path = [os.path.join(self.imgs_dir, line.strip()) for line in lines] 27 | imgs_path = sorted(imgs_path) 28 | return imgs_path 29 | 30 | def __getitem__(self, index): 31 | img_path = self.imgs_path[index] 32 | 33 | # load source image 34 | src_img = self.get_img_by_path(img_path) 35 | src_img_tensor = self.img2tensor(src_img) 36 | src_aus = self.get_aus_by_path(img_path) 37 | 38 | # load target image 39 | tar_img_path = random.choice(self.imgs_path) 40 | tar_img = self.get_img_by_path(tar_img_path) 41 | tar_img_tensor = self.img2tensor(tar_img) 42 | tar_aus = self.get_aus_by_path(tar_img_path) 43 | if self.is_train and not self.opt.no_aus_noise: 44 | tar_aus = tar_aus + np.random.uniform(-0.1, 0.1, tar_aus.shape) 45 | 46 | # record paths for debug and test usage 47 | data_dict = {'src_img':src_img_tensor, 'src_aus':src_aus, 'tar_img':tar_img_tensor, 'tar_aus':tar_aus, \ 48 | 'src_path':img_path, 'tar_path':tar_img_path} 49 | 50 | return data_dict 51 | -------------------------------------------------------------------------------- /data/base_dataset.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | from PIL import Image 4 | import random 5 | import numpy as np 6 | import pickle 7 | import torchvision.transforms as transforms 8 | 9 | 10 | 11 | class BaseDataset(torch.utils.data.Dataset): 12 | """docstring for BaseDataset""" 13 | def __init__(self): 14 | super(BaseDataset, self).__init__() 15 | 16 | def name(self): 17 | return os.path.basename(self.opt.data_root.strip('/')) 18 | 19 | def initialize(self, opt): 20 | self.opt = opt 21 | self.imgs_dir = os.path.join(self.opt.data_root, self.opt.imgs_dir) 22 | self.is_train = self.opt.mode == "train" 23 | 24 | # load images path 25 | filename = self.opt.train_csv if self.is_train else self.opt.test_csv 26 | self.imgs_name_file = os.path.join(self.opt.data_root, filename) 27 | self.imgs_path = self.make_dataset() 28 | 29 | # load AUs dicitionary 30 | aus_pkl = os.path.join(self.opt.data_root, self.opt.aus_pkl) 31 | self.aus_dict = self.load_dict(aus_pkl) 32 | 33 | # load image to tensor transformer 34 | self.img2tensor = self.img_transformer() 35 | 36 | def make_dataset(self): 37 | return None 38 | 39 | def load_dict(self, pkl_path): 40 | saved_dict = {} 41 | with open(pkl_path, 'rb') as f: 42 | saved_dict = pickle.load(f, encoding='latin1') 43 | return saved_dict 44 | 45 | def get_img_by_path(self, img_path): 46 | assert os.path.isfile(img_path), "Cannot find image file: %s" % img_path 47 | img_type = 'L' if self.opt.img_nc == 1 else 'RGB' 48 | return Image.open(img_path).convert(img_type) 49 | 50 | def get_aus_by_path(self, img_path): 51 | return None 52 | 53 | def img_transformer(self): 54 | transform_list = [] 55 | if self.opt.resize_or_crop == 'resize_and_crop': 56 | transform_list.append(transforms.Resize([self.opt.load_size, self.opt.load_size], Image.BICUBIC)) 57 | transform_list.append(transforms.RandomCrop(self.opt.final_size)) 58 | elif self.opt.resize_or_crop == 'crop': 59 | transform_list.append(transforms.RandomCrop(self.opt.final_size)) 60 | elif self.opt.resize_or_crop == 'none': 61 | transform_list.append(transforms.Lambda(lambda image: image)) 62 | else: 63 | raise ValueError("--resize_or_crop %s is not a valid option." 
% self.opt.resize_or_crop) 64 | 65 | if self.is_train and not self.opt.no_flip: 66 | transform_list.append(transforms.RandomHorizontalFlip()) 67 | 68 | transform_list.append(transforms.ToTensor()) 69 | transform_list.append(transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))) 70 | 71 | img2tensor = transforms.Compose(transform_list) 72 | 73 | return img2tensor 74 | 75 | def __len__(self): 76 | return len(self.imgs_path) 77 | 78 | 79 | 80 | 81 | 82 | 83 | 84 | 85 | 86 | 87 | 88 | 89 | 90 | -------------------------------------------------------------------------------- /visualizer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import torch 4 | import math 5 | from PIL import Image 6 | # import matplotlib.pyplot as plt 7 | 8 | 9 | 10 | class Visualizer(object): 11 | """docstring for Visualizer""" 12 | def __init__(self): 13 | super(Visualizer, self).__init__() 14 | 15 | def initialize(self, opt): 16 | self.opt = opt 17 | # self.vis_saved_dir = os.path.join(self.opt.ckpt_dir, 'vis_pics') 18 | # if not os.path.isdir(self.vis_saved_dir): 19 | # os.makedirs(self.vis_saved_dir) 20 | # plt.switch_backend('agg') 21 | 22 | self.display_id = self.opt.visdom_display_id 23 | if self.display_id > 0: 24 | import visdom 25 | self.ncols = 8 26 | self.vis = visdom.Visdom(server="http://localhost", port=self.opt.visdom_port, env=self.opt.visdom_env) 27 | 28 | def throw_visdom_connection_error(self): 29 | print('\n\nno visdom server.') 30 | exit(1) 31 | 32 | def print_losses_info(self, info_dict): 33 | msg = '[{}][Epoch: {:0>3}/{:0>3}; Images: {:0>4}/{:0>4}; Time: {:.3f}s/Batch({}); LR: {:.7f}] '.format( 34 | self.opt.name, info_dict['epoch'], info_dict['epoch_len'], 35 | info_dict['epoch_steps'], info_dict['epoch_steps_len'], 36 | info_dict['step_time'], self.opt.batch_size, info_dict['cur_lr']) 37 | for k, v in info_dict['losses'].items(): 38 | msg += '| {}: {:.4f} '.format(k, v) 39 | msg += '|' 40 | print(msg) 41 | with open(info_dict['log_path'], 'a+') as f: 42 | f.write(msg + '\n') 43 | 44 | def display_current_losses(self, epoch, counter_ratio, losses_dict): 45 | if not hasattr(self, 'plot_data'): 46 | self.plot_data = {'X': [], 'Y': [], 'legend': list(losses_dict.keys())} 47 | self.plot_data['X'].append(epoch + counter_ratio) 48 | self.plot_data['Y'].append([losses_dict[k] for k in self.plot_data['legend']]) 49 | try: 50 | self.vis.line( 51 | X=np.stack([np.array(self.plot_data['X'])] * len(self.plot_data['legend']), 1), 52 | Y=np.array(self.plot_data['Y']), 53 | opts={ 54 | 'title': self.opt.name + ' loss over time', 55 | 'legend':self.plot_data['legend'], 56 | 'xlabel':'epoch', 57 | 'ylabel':'loss'}, 58 | win=self.display_id) 59 | except ConnectionError: 60 | self.throw_visdom_connection_error() 61 | 62 | def display_online_results(self, visuals, epoch): 63 | win_id = self.display_id + 24 64 | images = [] 65 | labels = [] 66 | for label, image in visuals.items(): 67 | if 'mask' in label: # or 'focus' in label: 68 | image = (image - 0.5) / 0.5 # convert map from [0, 1] to [-1, 1] 69 | image_numpy = self.tensor2im(image) 70 | images.append(image_numpy.transpose([2, 0, 1])) 71 | labels.append(label) 72 | try: 73 | title = ' || '.join(labels) 74 | self.vis.images(images, nrow=self.ncols, win=win_id, 75 | padding=5, opts=dict(title=title)) 76 | except ConnectionError: 77 | self.throw_visdom_connection_error() 78 | 79 | # utils 80 | def tensor2im(self, input_image, imtype=np.uint8): 81 | if isinstance(input_image, 
torch.Tensor): 82 | image_tensor = input_image.data 83 | else: 84 | return input_image 85 | image_numpy = image_tensor[0].cpu().float().numpy() 86 | im = self.numpy2im(image_numpy, imtype).resize((80, 80), Image.ANTIALIAS) 87 | return np.array(im) 88 | 89 | def numpy2im(self, image_numpy, imtype=np.uint8): 90 | if image_numpy.shape[0] == 1: 91 | image_numpy = np.tile(image_numpy, (3, 1, 1)) 92 | # input should be [0, 1] 93 | #image_numpy = np.transpose(image_numpy, (1, 2, 0)) * 255.0 94 | image_numpy = (np.transpose(image_numpy, (1, 2, 0)) / 2. + 0.5) * 255.0 95 | # print(image_numpy.shape) 96 | image_numpy = image_numpy.astype(imtype) 97 | im = Image.fromarray(image_numpy) 98 | # im = Image.fromarray(image_numpy).resize((64, 64), Image.ANTIALIAS) 99 | return im # np.array(im) 100 | 101 | 102 | 103 | 104 | 105 | -------------------------------------------------------------------------------- /model/stargan.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from .base_model import BaseModel 3 | from . import model_utils 4 | 5 | 6 | 7 | class StarGANModel(BaseModel): 8 | """docstring for StarGANModel""" 9 | def __init__(self): 10 | super(StarGANModel, self).__init__() 11 | self.name = "StarGAN" 12 | 13 | def initialize(self, opt): 14 | super(StarGANModel, self).initialize(opt) 15 | 16 | self.net_gen = model_utils.define_splitG(self.opt.img_nc, self.opt.aus_nc, self.opt.ngf, use_dropout=self.opt.use_dropout, 17 | norm=self.opt.norm, init_type=self.opt.init_type, init_gain=self.opt.init_gain, gpu_ids=self.gpu_ids) 18 | self.models_name.append('gen') 19 | 20 | if self.is_train: 21 | self.net_dis = model_utils.define_splitD(self.opt.img_nc, self.opt.aus_nc, self.opt.final_size, self.opt.ndf, 22 | norm=self.opt.norm, init_type=self.opt.init_type, init_gain=self.opt.init_gain, gpu_ids=self.gpu_ids) 23 | self.models_name.append('dis') 24 | 25 | if self.opt.load_epoch > 0: 26 | self.load_ckpt(self.opt.load_epoch) 27 | 28 | def setup(self): 29 | super(StarGANModel, self).setup() 30 | if self.is_train: 31 | # setup optimizer 32 | self.optim_gen = torch.optim.Adam(self.net_gen.parameters(), 33 | lr=self.opt.lr, betas=(self.opt.beta1, 0.999)) 34 | self.optims.append(self.optim_gen) 35 | self.optim_dis = torch.optim.Adam(self.net_dis.parameters(), 36 | lr=self.opt.lr, betas=(self.opt.beta1, 0.999)) 37 | self.optims.append(self.optim_dis) 38 | 39 | # setup schedulers 40 | self.schedulers = [model_utils.get_scheduler(optim, self.opt) for optim in self.optims] 41 | 42 | def feed_batch(self, batch): 43 | self.src_img = batch['src_img'].to(self.device) 44 | self.tar_aus = batch['tar_aus'].type(torch.FloatTensor).to(self.device) 45 | if self.is_train: 46 | self.src_aus = batch['src_aus'].type(torch.FloatTensor).to(self.device) 47 | self.tar_img = batch['tar_img'].to(self.device) 48 | 49 | def forward(self): 50 | # generate fake image 51 | self.fake_img, _, _ = self.net_gen(self.src_img, self.tar_aus) 52 | 53 | # reconstruct real image 54 | if self.is_train: 55 | self.rec_real_img, _, _ = self.net_gen(self.fake_img, self.src_aus) 56 | 57 | def backward_dis(self): 58 | # real image 59 | pred_real, self.pred_real_aus = self.net_dis(self.src_img) 60 | self.loss_dis_real = self.criterionGAN(pred_real, True) 61 | self.loss_dis_real_aus = self.criterionMSE(self.pred_real_aus, self.src_aus) 62 | 63 | # fake image, detach to stop backward to generator 64 | pred_fake, _ = self.net_dis(self.fake_img.detach()) 65 | self.loss_dis_fake = self.criterionGAN(pred_fake, 
False) 66 | 67 | # combine dis loss 68 | self.loss_dis = self.opt.lambda_dis * (self.loss_dis_fake + self.loss_dis_real) \ 69 | + self.opt.lambda_aus * self.loss_dis_real_aus 70 | if self.opt.gan_type == 'wgan-gp': 71 | self.loss_dis_gp = self.gradient_penalty(self.src_img, self.fake_img) 72 | self.loss_dis = self.loss_dis + self.opt.lambda_wgan_gp * self.loss_dis_gp 73 | 74 | # backward discriminator loss 75 | self.loss_dis.backward() 76 | 77 | def backward_gen(self): 78 | # original to target domain, should fake the discriminator 79 | pred_fake, self.pred_fake_aus = self.net_dis(self.fake_img) 80 | self.loss_gen_GAN = self.criterionGAN(pred_fake, True) 81 | self.loss_gen_fake_aus = self.criterionMSE(self.pred_fake_aus, self.tar_aus) 82 | 83 | # target to original domain reconstruct, identity loss 84 | self.loss_gen_rec = self.criterionL1(self.rec_real_img, self.src_img) 85 | 86 | # combine and backward G loss 87 | self.loss_gen = self.opt.lambda_dis * self.loss_gen_GAN \ 88 | + self.opt.lambda_aus * self.loss_gen_fake_aus \ 89 | + self.opt.lambda_rec * self.loss_gen_rec 90 | 91 | self.loss_gen.backward() 92 | 93 | def optimize_paras(self, train_gen): 94 | self.forward() 95 | # update discriminator 96 | self.set_requires_grad(self.net_dis, True) 97 | self.optim_dis.zero_grad() 98 | self.backward_dis() 99 | self.optim_dis.step() 100 | 101 | # update G if needed 102 | if train_gen: 103 | self.set_requires_grad(self.net_dis, False) 104 | self.optim_gen.zero_grad() 105 | self.backward_gen() 106 | self.optim_gen.step() 107 | 108 | def save_ckpt(self, epoch): 109 | # save the specific networks 110 | save_models_name = ['gen', 'dis'] 111 | return super(StarGANModel, self).save_ckpt(epoch, save_models_name) 112 | 113 | def load_ckpt(self, epoch): 114 | # load the specific part of networks 115 | load_models_name = ['gen'] 116 | if self.is_train: 117 | load_models_name.extend(['dis']) 118 | return super(StarGANModel, self).load_ckpt(epoch, load_models_name) 119 | 120 | def clean_ckpt(self, epoch): 121 | # load the specific part of networks 122 | load_models_name = ['gen', 'dis'] 123 | return super(StarGANModel, self).clean_ckpt(epoch, load_models_name) 124 | 125 | def get_latest_losses(self): 126 | get_losses_name = ['dis_fake', 'dis_real', 'dis_real_aus', 'gen_rec'] 127 | return super(StarGANModel, self).get_latest_losses(get_losses_name) 128 | 129 | def get_latest_visuals(self): 130 | visuals_name = ['src_img', 'tar_img', 'fake_img'] 131 | if self.is_train: 132 | visuals_name.extend(['rec_real_img']) 133 | return super(StarGANModel, self).get_latest_visuals(visuals_name) 134 | -------------------------------------------------------------------------------- /model/ganimation.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from .base_model import BaseModel 3 | from . 
import model_utils 4 | 5 | 6 | class GANimationModel(BaseModel): 7 | """docstring for GANimationModel""" 8 | def __init__(self): 9 | super(GANimationModel, self).__init__() 10 | self.name = "GANimation" 11 | 12 | def initialize(self, opt): 13 | super(GANimationModel, self).initialize(opt) 14 | 15 | self.net_gen = model_utils.define_splitG(self.opt.img_nc, self.opt.aus_nc, self.opt.ngf, use_dropout=self.opt.use_dropout, 16 | norm=self.opt.norm, init_type=self.opt.init_type, init_gain=self.opt.init_gain, gpu_ids=self.gpu_ids) 17 | self.models_name.append('gen') 18 | 19 | if self.is_train: 20 | self.net_dis = model_utils.define_splitD(self.opt.img_nc, self.opt.aus_nc, self.opt.final_size, self.opt.ndf, 21 | norm=self.opt.norm, init_type=self.opt.init_type, init_gain=self.opt.init_gain, gpu_ids=self.gpu_ids) 22 | self.models_name.append('dis') 23 | 24 | if self.opt.load_epoch > 0: 25 | self.load_ckpt(self.opt.load_epoch) 26 | 27 | def setup(self): 28 | super(GANimationModel, self).setup() 29 | if self.is_train: 30 | # setup optimizer 31 | self.optim_gen = torch.optim.Adam(self.net_gen.parameters(), 32 | lr=self.opt.lr, betas=(self.opt.beta1, 0.999)) 33 | self.optims.append(self.optim_gen) 34 | self.optim_dis = torch.optim.Adam(self.net_dis.parameters(), 35 | lr=self.opt.lr, betas=(self.opt.beta1, 0.999)) 36 | self.optims.append(self.optim_dis) 37 | 38 | # setup schedulers 39 | self.schedulers = [model_utils.get_scheduler(optim, self.opt) for optim in self.optims] 40 | 41 | def feed_batch(self, batch): 42 | self.src_img = batch['src_img'].to(self.device) 43 | self.tar_aus = batch['tar_aus'].type(torch.FloatTensor).to(self.device) 44 | if self.is_train: 45 | self.src_aus = batch['src_aus'].type(torch.FloatTensor).to(self.device) 46 | self.tar_img = batch['tar_img'].to(self.device) 47 | 48 | def forward(self): 49 | # generate fake image 50 | self.color_mask ,self.aus_mask, self.embed = self.net_gen(self.src_img, self.tar_aus) 51 | self.fake_img = self.aus_mask * self.src_img + (1 - self.aus_mask) * self.color_mask 52 | 53 | # reconstruct real image 54 | if self.is_train: 55 | self.rec_color_mask, self.rec_aus_mask, self.rec_embed = self.net_gen(self.fake_img, self.src_aus) 56 | self.rec_real_img = self.rec_aus_mask * self.fake_img + (1 - self.rec_aus_mask) * self.rec_color_mask 57 | 58 | def backward_dis(self): 59 | # real image 60 | pred_real, self.pred_real_aus = self.net_dis(self.src_img) 61 | self.loss_dis_real = self.criterionGAN(pred_real, True) 62 | self.loss_dis_real_aus = self.criterionMSE(self.pred_real_aus, self.src_aus) 63 | 64 | # fake image, detach to stop backward to generator 65 | pred_fake, _ = self.net_dis(self.fake_img.detach()) 66 | self.loss_dis_fake = self.criterionGAN(pred_fake, False) 67 | 68 | # combine dis loss 69 | self.loss_dis = self.opt.lambda_dis * (self.loss_dis_fake + self.loss_dis_real) \ 70 | + self.opt.lambda_aus * self.loss_dis_real_aus 71 | if self.opt.gan_type == 'wgan-gp': 72 | self.loss_dis_gp = self.gradient_penalty(self.src_img, self.fake_img) 73 | self.loss_dis = self.loss_dis + self.opt.lambda_wgan_gp * self.loss_dis_gp 74 | 75 | # backward discriminator loss 76 | self.loss_dis.backward() 77 | 78 | def backward_gen(self): 79 | # original to target domain, should fake the discriminator 80 | pred_fake, self.pred_fake_aus = self.net_dis(self.fake_img) 81 | self.loss_gen_GAN = self.criterionGAN(pred_fake, True) 82 | self.loss_gen_fake_aus = self.criterionMSE(self.pred_fake_aus, self.tar_aus) 83 | 84 | # target to original domain reconstruct, identity loss 
85 | self.loss_gen_rec = self.criterionL1(self.rec_real_img, self.src_img) 86 | 87 | # constrain on AUs mask 88 | self.loss_gen_mask_real_aus = torch.mean(self.aus_mask) 89 | self.loss_gen_mask_fake_aus = torch.mean(self.rec_aus_mask) 90 | self.loss_gen_smooth_real_aus = self.criterionTV(self.aus_mask) 91 | self.loss_gen_smooth_fake_aus = self.criterionTV(self.rec_aus_mask) 92 | 93 | # combine and backward G loss 94 | self.loss_gen = self.opt.lambda_dis * self.loss_gen_GAN \ 95 | + self.opt.lambda_aus * self.loss_gen_fake_aus \ 96 | + self.opt.lambda_rec * self.loss_gen_rec \ 97 | + self.opt.lambda_mask * (self.loss_gen_mask_real_aus + self.loss_gen_mask_fake_aus) \ 98 | + self.opt.lambda_tv * (self.loss_gen_smooth_real_aus + self.loss_gen_smooth_fake_aus) 99 | 100 | self.loss_gen.backward() 101 | 102 | def optimize_paras(self, train_gen): 103 | self.forward() 104 | # update discriminator 105 | self.set_requires_grad(self.net_dis, True) 106 | self.optim_dis.zero_grad() 107 | self.backward_dis() 108 | self.optim_dis.step() 109 | 110 | # update G if needed 111 | if train_gen: 112 | self.set_requires_grad(self.net_dis, False) 113 | self.optim_gen.zero_grad() 114 | self.backward_gen() 115 | self.optim_gen.step() 116 | 117 | def save_ckpt(self, epoch): 118 | # save the specific networks 119 | save_models_name = ['gen', 'dis'] 120 | return super(GANimationModel, self).save_ckpt(epoch, save_models_name) 121 | 122 | def load_ckpt(self, epoch): 123 | # load the specific part of networks 124 | load_models_name = ['gen'] 125 | if self.is_train: 126 | load_models_name.extend(['dis']) 127 | return super(GANimationModel, self).load_ckpt(epoch, load_models_name) 128 | 129 | def clean_ckpt(self, epoch): 130 | # load the specific part of networks 131 | load_models_name = ['gen', 'dis'] 132 | return super(GANimationModel, self).clean_ckpt(epoch, load_models_name) 133 | 134 | def get_latest_losses(self): 135 | get_losses_name = ['dis_fake', 'dis_real', 'dis_real_aus', 'gen_rec'] 136 | return super(GANimationModel, self).get_latest_losses(get_losses_name) 137 | 138 | def get_latest_visuals(self): 139 | visuals_name = ['src_img', 'tar_img', 'color_mask', 'aus_mask', 'fake_img'] 140 | if self.is_train: 141 | visuals_name.extend(['rec_color_mask', 'rec_aus_mask', 'rec_real_img']) 142 | return super(GANimationModel, self).get_latest_visuals(visuals_name) 143 | -------------------------------------------------------------------------------- /solvers.py: -------------------------------------------------------------------------------- 1 | """ 2 | Created on Dec 13, 2018 3 | @author: Yuedong Chen 4 | """ 5 | 6 | from data import create_dataloader 7 | from model import create_model 8 | from visualizer import Visualizer 9 | import copy 10 | import time 11 | import os 12 | import torch 13 | import numpy as np 14 | from PIL import Image 15 | 16 | 17 | def create_solver(opt): 18 | instance = Solver() 19 | instance.initialize(opt) 20 | return instance 21 | 22 | 23 | 24 | class Solver(object): 25 | """docstring for Solver""" 26 | def __init__(self): 27 | super(Solver, self).__init__() 28 | 29 | def initialize(self, opt): 30 | self.opt = opt 31 | self.visual = Visualizer() 32 | self.visual.initialize(self.opt) 33 | 34 | def run_solver(self): 35 | if self.opt.mode == "train": 36 | self.train_networks() 37 | else: 38 | self.test_networks(self.opt) 39 | 40 | def train_networks(self): 41 | # init train setting 42 | self.init_train_setting() 43 | 44 | # for every epoch 45 | for epoch in range(self.opt.epoch_count, self.epoch_len 
+ 1): 46 | # train network 47 | self.train_epoch(epoch) 48 | # update learning rate 49 | self.cur_lr = self.train_model.update_learning_rate() 50 | # save checkpoint if needed 51 | if epoch % self.opt.save_epoch_freq == 0: 52 | self.train_model.save_ckpt(epoch) 53 | 54 | # save the last epoch 55 | self.train_model.save_ckpt(self.epoch_len) 56 | 57 | def init_train_setting(self): 58 | self.train_dataset = create_dataloader(self.opt) 59 | self.train_model = create_model(self.opt) 60 | 61 | self.train_total_steps = 0 62 | self.epoch_len = self.opt.niter + self.opt.niter_decay 63 | self.cur_lr = self.opt.lr 64 | 65 | def train_epoch(self, epoch): 66 | epoch_start_time = time.time() 67 | epoch_steps = 0 68 | 69 | last_print_step_t = time.time() 70 | for idx, batch in enumerate(self.train_dataset): 71 | 72 | self.train_total_steps += self.opt.batch_size 73 | epoch_steps += self.opt.batch_size 74 | # train network 75 | self.train_model.feed_batch(batch) 76 | self.train_model.optimize_paras(train_gen=(idx % self.opt.train_gen_iter == 0)) 77 | # print losses 78 | if self.train_total_steps % self.opt.print_losses_freq == 0: 79 | cur_losses = self.train_model.get_latest_losses() 80 | avg_step_t = (time.time() - last_print_step_t) / self.opt.print_losses_freq 81 | last_print_step_t = time.time() 82 | # print loss info to command line 83 | info_dict = {'epoch': epoch, 'epoch_len': self.epoch_len, 84 | 'epoch_steps': idx * self.opt.batch_size, 'epoch_steps_len': len(self.train_dataset), 85 | 'step_time': avg_step_t, 'cur_lr': self.cur_lr, 86 | 'log_path': os.path.join(self.opt.ckpt_dir, self.opt.log_file), 87 | 'losses': cur_losses 88 | } 89 | self.visual.print_losses_info(info_dict) 90 | 91 | # plot loss map to visdom 92 | if self.train_total_steps % self.opt.plot_losses_freq == 0 and self.visual.display_id > 0: 93 | cur_losses = self.train_model.get_latest_losses() 94 | epoch_steps = idx * self.opt.batch_size 95 | self.visual.display_current_losses(epoch - 1, epoch_steps / len(self.train_dataset), cur_losses) 96 | 97 | # display image on visdom 98 | if self.train_total_steps % self.opt.sample_img_freq == 0 and self.visual.display_id > 0: 99 | cur_vis = self.train_model.get_latest_visuals() 100 | self.visual.display_online_results(cur_vis, epoch) 101 | # latest_aus = model.get_latest_aus() 102 | # visual.log_aus(epoch, epoch_steps, latest_aus, opt.ckpt_dir) 103 | 104 | def test_networks(self, opt): 105 | self.init_test_setting(opt) 106 | self.test_ops() 107 | 108 | def init_test_setting(self, opt): 109 | self.test_dataset = create_dataloader(opt) 110 | self.test_model = create_model(opt) 111 | 112 | def test_ops(self): 113 | for batch_idx, batch in enumerate(self.test_dataset): 114 | with torch.no_grad(): 115 | # interpolate several times 116 | faces_list = [batch['src_img'].float().numpy()] 117 | paths_list = [batch['src_path'], batch['tar_path']] 118 | for idx in range(self.opt.interpolate_len): 119 | cur_alpha = (idx + 1.) 
/ float(self.opt.interpolate_len) 120 | cur_tar_aus = cur_alpha * batch['tar_aus'] + (1 - cur_alpha) * batch['src_aus'] 121 | # print(batch['src_aus']) 122 | # print(cur_tar_aus) 123 | test_batch = {'src_img': batch['src_img'], 'tar_aus': cur_tar_aus, 'src_aus':batch['src_aus'], 'tar_img':batch['tar_img']} 124 | 125 | self.test_model.feed_batch(test_batch) 126 | self.test_model.forward() 127 | 128 | cur_gen_faces = self.test_model.fake_img.cpu().float().numpy() 129 | faces_list.append(cur_gen_faces) 130 | faces_list.append(batch['tar_img'].float().numpy()) 131 | self.test_save_imgs(faces_list, paths_list) 132 | 133 | def test_save_imgs(self, faces_list, paths_list): 134 | for idx in range(len(paths_list[0])): 135 | src_name = os.path.splitext(os.path.basename(paths_list[0][idx]))[0] 136 | tar_name = os.path.splitext(os.path.basename(paths_list[1][idx]))[0] 137 | 138 | if self.opt.save_test_gif: 139 | import imageio 140 | imgs_numpy_list = [] 141 | for face_idx in range(len(faces_list) - 1): # remove target image 142 | cur_numpy = np.array(self.visual.numpy2im(faces_list[face_idx][idx])) 143 | imgs_numpy_list.extend([cur_numpy for _ in range(3)]) 144 | saved_path = os.path.join(self.opt.results, "%s_%s.gif" % (src_name, tar_name)) 145 | imageio.mimsave(saved_path, imgs_numpy_list) 146 | else: 147 | # concate src, inters, tar faces 148 | concate_img = np.array(self.visual.numpy2im(faces_list[0][idx])) 149 | for face_idx in range(1, len(faces_list)): 150 | concate_img = np.concatenate((concate_img, np.array(self.visual.numpy2im(faces_list[face_idx][idx]))), axis=1) 151 | concate_img = Image.fromarray(concate_img) 152 | # save image 153 | saved_path = os.path.join(self.opt.results, "%s_%s.jpg" % (src_name, tar_name)) 154 | concate_img.save(saved_path) 155 | 156 | print("[Success] Saved images to %s" % saved_path) 157 | 158 | 159 | 160 | 161 | 162 | 163 | -------------------------------------------------------------------------------- /model/base_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | from collections import OrderedDict 4 | import random 5 | from . 
import model_utils 6 | 7 | 8 | class BaseModel: 9 | """docstring for BaseModel""" 10 | def __init__(self): 11 | super(BaseModel, self).__init__() 12 | self.name = "Base" 13 | 14 | def initialize(self, opt): 15 | self.opt = opt 16 | self.gpu_ids = self.opt.gpu_ids 17 | self.device = torch.device('cuda:%d' % self.gpu_ids[0] if self.gpu_ids else 'cpu') 18 | self.is_train = self.opt.mode == "train" 19 | # inherit to define network model 20 | self.models_name = [] 21 | 22 | def setup(self): 23 | print("%s with Model [%s]" % (self.opt.mode.capitalize(), self.name)) 24 | if self.is_train: 25 | self.set_train() 26 | # define loss function 27 | self.criterionGAN = model_utils.GANLoss(gan_type=self.opt.gan_type).to(self.device) 28 | self.criterionL1 = torch.nn.L1Loss().to(self.device) 29 | self.criterionMSE = torch.nn.MSELoss().to(self.device) 30 | self.criterionTV = model_utils.TVLoss().to(self.device) 31 | torch.nn.DataParallel(self.criterionGAN, self.gpu_ids) 32 | torch.nn.DataParallel(self.criterionL1, self.gpu_ids) 33 | torch.nn.DataParallel(self.criterionMSE, self.gpu_ids) 34 | torch.nn.DataParallel(self.criterionTV, self.gpu_ids) 35 | # inherit to set up train/val/test status 36 | self.losses_name = [] 37 | self.optims = [] 38 | self.schedulers = [] 39 | else: 40 | self.set_eval() 41 | 42 | def set_eval(self): 43 | print("Set model to Test state.") 44 | for name in self.models_name: 45 | if isinstance(name, str): 46 | net = getattr(self, 'net_' + name) 47 | if not self.opt.no_test_eval: 48 | net.eval() 49 | print("Set net_%s to EVAL." % name) 50 | else: 51 | net.train() 52 | self.is_train = False 53 | 54 | def set_train(self): 55 | print("Set model to Train state.") 56 | for name in self.models_name: 57 | if isinstance(name, str): 58 | net = getattr(self, 'net_' + name) 59 | net.train() 60 | print("Set net_%s to TRAIN." % name) 61 | self.is_train = True 62 | 63 | def set_requires_grad(self, parameters, requires_grad=False): 64 | if not isinstance(parameters, list): 65 | parameters = [parameters] 66 | for param in parameters: 67 | if param is not None: 68 | param.requires_grad = requires_grad 69 | 70 | def get_latest_visuals(self, visuals_name): 71 | visual_ret = OrderedDict() 72 | for name in visuals_name: 73 | if isinstance(name, str) and hasattr(self, name): 74 | visual_ret[name] = getattr(self, name) 75 | return visual_ret 76 | 77 | def get_latest_losses(self, losses_name): 78 | errors_ret = OrderedDict() 79 | for name in losses_name: 80 | if isinstance(name, str): 81 | cur_loss = float(getattr(self, 'loss_' + name)) 82 | # cur_loss_lambda = 1. 
if len(losses_name) == 1 else float(getattr(self.opt, 'lambda_' + name)) 83 | # errors_ret[name] = cur_loss * cur_loss_lambda 84 | errors_ret[name] = cur_loss 85 | return errors_ret 86 | 87 | def feed_batch(self, batch): 88 | pass 89 | 90 | def forward(self): 91 | pass 92 | 93 | def optimize_paras(self): 94 | pass 95 | 96 | def update_learning_rate(self): 97 | for scheduler in self.schedulers: 98 | scheduler.step() 99 | lr = self.optims[0].param_groups[0]['lr'] 100 | return lr 101 | 102 | def save_ckpt(self, epoch, models_name): 103 | for name in models_name: 104 | if isinstance(name, str): 105 | save_filename = '%s_net_%s.pth' % (epoch, name) 106 | save_path = os.path.join(self.opt.ckpt_dir, save_filename) 107 | net = getattr(self, 'net_' + name) 108 | # save cpu params, so that it can be used in other GPU settings 109 | if len(self.gpu_ids) > 0 and torch.cuda.is_available(): 110 | torch.save(net.module.cpu().state_dict(), save_path) 111 | net.to(self.gpu_ids[0]) 112 | net = torch.nn.DataParallel(net, self.gpu_ids) 113 | else: 114 | torch.save(net.cpu().state_dict(), save_path) 115 | 116 | def load_ckpt(self, epoch, models_name): 117 | # print(models_name) 118 | for name in models_name: 119 | if isinstance(name, str): 120 | load_filename = '%s_net_%s.pth' % (epoch, name) 121 | load_path = os.path.join(self.opt.ckpt_dir, load_filename) 122 | assert os.path.isfile(load_path), "File '%s' does not exist." % load_path 123 | 124 | pretrained_state_dict = torch.load(load_path, map_location=str(self.device)) 125 | if hasattr(pretrained_state_dict, '_metadata'): 126 | del pretrained_state_dict._metadata 127 | 128 | net = getattr(self, 'net_' + name) 129 | if isinstance(net, torch.nn.DataParallel): 130 | net = net.module 131 | # load only existing keys 132 | pretrained_dict = {k: v for k, v in pretrained_state_dict.items() if k in net.state_dict()} 133 | # for k, v in pretrained_state_dict.items(): 134 | # print(k) 135 | # assert False 136 | net.load_state_dict(pretrained_dict) 137 | print("[Info] Successfully load trained weights for net_%s." % name) 138 | 139 | def clean_ckpt(self, epoch, models_name): 140 | for name in models_name: 141 | if isinstance(name, str): 142 | load_filename = '%s_net_%s.pth' % (epoch, name) 143 | load_path = os.path.join(self.opt.ckpt_dir, load_filename) 144 | if os.path.isfile(load_path): 145 | os.remove(load_path) 146 | 147 | def gradient_penalty(self, input_img, generate_img): 148 | # interpolate sample 149 | alpha = torch.rand(input_img.size(0), 1, 1, 1).to(self.device) 150 | inter_img = (alpha * input_img.data + (1 - alpha) * generate_img.data).requires_grad_(True) 151 | inter_img_prob, _ = self.net_dis(inter_img) 152 | 153 | # computer gradient penalty: x: inter_img, y: inter_img_prob 154 | # (L2_norm(dy/dx) - 1)**2 155 | dydx = torch.autograd.grad(outputs=inter_img_prob, 156 | inputs=inter_img, 157 | grad_outputs=torch.ones(inter_img_prob.size()).to(self.device), 158 | retain_graph=True, 159 | create_graph=True, 160 | only_inputs=True)[0] 161 | dydx = dydx.view(dydx.size(0), -1) 162 | dydx_l2norm = torch.sqrt(torch.sum(dydx ** 2, dim=1)) 163 | return torch.mean((dydx_l2norm - 1) ** 2) 164 | 165 | 166 | 167 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # GANimation -- An Out-of-the-Box Replicate 2 | 3 |

4 | *(badges: Status, Platform, PyTorch, License)* 5 | 6 | 7 | 8 |

9 | 10 | **A reimplementation of *[GANimation: Anatomically-aware Facial Animation from a Single Image](https://arxiv.org/abs/1807.09251)* in PyTorch. Pretrained models/weights are available at [GDrive](https://drive.google.com/drive/folders/1DJHeeLwZ3OsbiesvH7UM44cB3NxhebxF) or [BaiduPan](https://pan.baidu.com/s/1eLGC6jhciBS8DDuw_Gkd7A) (Code: 3fyb)!** 11 | 12 |
13 | *(demo GIFs: see imgs/gifs/)* 14 | 15 | 16 | 17 | 18 | 19 |
20 | 21 | ![ganimation_show](imgs/ganimation_show.jpg) 22 | 23 | ## Pros (compared with the [official](https://github.com/albertpumarola/GANimation) implementation) 24 | 25 | * Code is cleaner and better structured, inspired by [pytorch-CycleGAN-and-pix2pix](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix). 26 | * Provides a more powerful test function that generates **linear interpolations** between two expressions, as shown in the paper. 27 | * Provides a **preprocessed [CelebA](http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html) dataset**, including cropped faces, Action Units for all cropped faces, and a train/test split. 28 | * Provides **pretrained models** for the above CelebA dataset (trained with ~145k images for 30 epochs). 29 | * Provides Action Unit vectors for [EmotionNet](https://cbcsl.ece.ohio-state.edu/EmotionNetChallenge/index.html), extracted using [OpenFace](https://github.com/TadasBaltrusaitis/OpenFace). 30 | * Provides **pretrained models** for the EmotionNet dataset (trained with ~410k images for 30 epochs). 31 | 32 | All resources related to this project are located at **[GDrive](https://drive.google.com/drive/folders/1DJHeeLwZ3OsbiesvH7UM44cB3NxhebxF)** or **[BaiduPan](https://pan.baidu.com/s/1eLGC6jhciBS8DDuw_Gkd7A)** (Code: 3fyb). 33 | 34 | ## Getting Started 35 | 36 | ### Requirements 37 | 38 | * Python 3 39 | * PyTorch >= 0.4.1 40 | * visdom (optional, only for monitoring training in the browser) 41 | * imageio (optional, only for generating GIF images during testing) 42 | 43 | ### Installation 44 | 45 | * Clone this repo: 46 | 47 | ``` 48 | git clone https://github.com/donydchen/ganimation_replicate.git 49 | cd ganimation_replicate 50 | pip install -r requirements.txt 51 | ``` 52 | 53 | ### Resources 54 | 55 | * All resources related to this project are located at **[GDrive](https://drive.google.com/drive/folders/1DJHeeLwZ3OsbiesvH7UM44cB3NxhebxF)** or **[BaiduPan](https://pan.baidu.com/s/1eLGC6jhciBS8DDuw_Gkd7A)** (Code: 3fyb). 56 | * Download `datasets` and put it in the root path of this project. 57 | * Download `ckpts` and put it in the root path of this project (optional, only needed for testing or finetuning). 58 | * Note: for EmotionNet, the AU vectors are stored as a dictionary keyed by file name (without extension) and dumped into a pickle file. 59 | 60 | ### Train 61 | 62 | * To view training results and loss plots, run `python -m visdom.server` and open the URL [http://localhost:8097](http://localhost:8097). 63 | 64 | ``` 65 | python main.py --data_root [path_to_dataset] 66 | 67 | # e.g. python main.py --data_root datasets/celebA --gpu_ids 0,1 --sample_img_freq 500 68 | # python main.py --data_root datasets/emotionNet --gpu_ids 0,1 --sample_img_freq 500 69 | # set '--visdom_display_id 0' if you don't want to use visdom 70 | # use 'python main.py -h' to check out more options. 71 | ``` 72 | 73 | ### Test 74 | 75 | * Make sure you have trained the model or downloaded the pretrained model. 76 | 77 | ``` 78 | python main.py --mode test --data_root [path_to_dataset] --ckpt_dir [path_to_pretrained_model] --load_epoch [epoch_num] 79 | 80 | # e.g. python main.py --mode test --data_root datasets/celebA --batch_size 8 --max_dataset_size 150 --gpu_ids 0,1 --ckpt_dir ckpts/celebA/ganimation/190327_161852/ --load_epoch 30 81 | # set '--interpolate_len 1' if you don't need linear interpolation. 82 | # use '--save_test_gif' to generate animated images.
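# (illustrative example, not from the original README) the same flags work for the EmotionNet weights;
# the run-id folder under ckpts/ is hypothetical -- use the timestamp of your own download or training run:
# python main.py --mode test --data_root datasets/emotionNet --gpu_ids 0 --ckpt_dir ckpts/emotionNet/ganimation/<run_id>/ --load_epoch 30 --interpolate_len 5 --save_test_gif
# results are written to results/<dataset>_<model>_<epoch>/, one image (or GIF) per source/target pair (see options.py and solvers.py).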
83 | ``` 84 | 85 | ### Finetune 86 | 87 | ``` 88 | python main.py --data_root [path_to_dataset] --ckpt_dir [path_to_existing_checkpoint] --load_epoch [epoch_num] 89 | 90 | # e.g. python main.py --data_root datasets/celebA --gpu_ids 0,1 --sample_img_freq 300 --n_threads 18 --ckpt_dir ckpts/celebA/ganimation/190327_161852 --load_epoch 30 --epoch_count 31 --niter 30 --niter_decay 10 91 | ``` 92 | 93 | ### Use Own Datasets 94 | 95 | * **Crop Face:** Use [face_recognition](https://github.com/ageitgey/face_recognition) to extract face bounding boxes and crop faces from the images. 96 | * **Obtain AUs Vector:** Use [OpenFace](https://github.com/TadasBaltrusaitis/OpenFace) to extract Action Unit vectors from the cropped faces obtained above. Specifically, only the AU intensities are used in this project, namely `AU01_r, AU02_r, AU04_r, AU05_r, AU06_r, AU07_r, AU09_r, AU10_r, AU12_r, AU14_r, AU15_r, AU17_r, AU20_r, AU23_r, AU25_r, AU26_r, AU45_r`. 97 | 98 | ``` 99 | ./FaceLandmarkImg -f [path_to_img] -aus 100 | 101 | # In the result file, values of columns [2:19] are extracted for later usage. 102 | ``` 103 | 104 | * **Download Pretrained Model:** Since the EmotionNet data used for training contains more than 400k in-the-wild face images, the pretrained model should generalize to a wide range of scenes. It is recommended to first try applying the EmotionNet pretrained model directly to your own dataset. 105 | 106 | ## Some Results 107 | 108 | ### CelebA 109 | 110 | **Training** 111 | 112 | ![celeba_training](imgs/celeba_training.jpg) 113 | 114 | **Testing** (with *GANimation* model, on epoch 30) 115 | 116 | ![celeba_testing](imgs/celeba_testing.jpg) 117 | 118 | **Testing** (with *StarGAN* model, on epoch 30) 119 | 120 | ![celeba_stargan_testing](imgs/celeba_stargan_testing.jpg) 121 | 122 | ### EmotionNet (Visual quality is much better than that of CelebA) 123 | 124 | **Training** 125 | 126 | ![emotionnet_training](imgs/emotionnet_training.jpg) 127 | 128 | **Testing** (with *GANimation* model, on epoch 30) 129 | 130 | ![emotionnet_testing](imgs/emotionnet_testing.jpg) 131 | 132 | **Testing** (with *StarGAN* model, on epoch 30) 133 | 134 | ![emotionnet_stargan_testing](imgs/emotionnet_stargan_testing.jpg) 135 | 136 | ## Why this Project? 137 | 138 | My [mentor](https://jianfeng1991.github.io/personal/) came up with a fancy idea of playing with GANs and AUs while I was an intern at AI Lab, [Lenovo Research](http://research.lenovo.com/webapp/view_English/index.html), around early August 2018. I enjoyed the idea very much and started working on it. However, just a few days later, the GANimation paper showed up, which was not good news for us... So I tried to replicate GANimation, and that was the start of this project. 139 | 140 | In late August 2018, I came across an [issue](https://github.com/albertpumarola/GANimation/issues/22) on the official GANimation implementation claiming that the test results were wrong. In my case, however, I did get reasonable results, so I replied to that issue with the results I had obtained. Since the author of GANimation had not yet released a pretrained model, I received emails from time to time asking whether I could share my codes and pretrained models. 141 | 142 | I really wanted to provide the codes and pretrained models. However, I was very busy over the past few months, moving from Beijing to Singapore, working toward paper deadlines, and so on. So the codes remained on the server of Lenovo Research for half a year.
And these days, I finally got some free time. So I dug out the codes, cleaned them, retrained the network, and now, I make them public. I will keep updating this project if I have time, and hope that these codes can serve to faciliate the research of someone who are working on the related tasks. 143 | 144 | Feel free to contact me if you need any help from me related to this project. 145 | 146 | ## Pull Request 147 | 148 | You are always welcome to contribute to this repository by sending a [pull request](https://help.github.com/articles/about-pull-requests/). 149 | 150 | 151 | 152 | -------------------------------------------------------------------------------- /options.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | import os 4 | from datetime import datetime 5 | import time 6 | import torch 7 | import random 8 | import numpy as np 9 | import sys 10 | 11 | 12 | 13 | class Options(object): 14 | """docstring for Options""" 15 | def __init__(self): 16 | super(Options, self).__init__() 17 | 18 | def initialize(self): 19 | parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) 20 | parser.add_argument('--mode', type=str, default='train', help='Mode of code. [train|test]') 21 | parser.add_argument('--model', type=str, default='ganimation', help='[ganimation|stargan], see model.__init__ from more details.') 22 | parser.add_argument('--lucky_seed', type=int, default=0, help='seed for random initialize, 0 to use current time.') 23 | parser.add_argument('--visdom_env', type=str, default="main", help='visdom env.') 24 | parser.add_argument('--visdom_port', type=int, default=8097, help='visdom port.') 25 | parser.add_argument('--visdom_display_id', type=int, default=1, help='set value larger than 0 to display with visdom.') 26 | 27 | parser.add_argument('--results', type=str, default="results", help='save test results to this path.') 28 | parser.add_argument('--interpolate_len', type=int, default=5, help='interpolate length for test.') 29 | parser.add_argument('--no_test_eval', action='store_true', help='do not use eval mode during test time.') 30 | parser.add_argument('--save_test_gif', action='store_true', help='save gif images instead of the concatenation of static images.') 31 | 32 | parser.add_argument('--data_root', required=True, help='paths to data set.') 33 | parser.add_argument('--imgs_dir', type=str, default="imgs", help='path to image') 34 | parser.add_argument('--aus_pkl', type=str, default="aus_openface.pkl", help='AUs pickle dictionary.') 35 | parser.add_argument('--train_csv', type=str, default="train_ids.csv", help='train images paths') 36 | parser.add_argument('--test_csv', type=str, default="test_ids.csv", help='test images paths') 37 | 38 | parser.add_argument('--batch_size', type=int, default=25, help='input batch size.') 39 | parser.add_argument('--serial_batches', action='store_true', help='if specified, input images in order.') 40 | parser.add_argument('--n_threads', type=int, default=6, help='number of workers to load data.') 41 | parser.add_argument('--max_dataset_size', type=int, default=float("inf"), help='maximum number of samples.') 42 | 43 | parser.add_argument('--resize_or_crop', type=str, default='none', help='Preprocessing image, [resize_and_crop|crop|none]') 44 | parser.add_argument('--load_size', type=int, default=148, help='scale image to this size.') 45 | parser.add_argument('--final_size', type=int, default=128, help='crop image to this size.') 46 | 
parser.add_argument('--no_flip', action='store_true', help='if specified, do not flip image.') 47 | parser.add_argument('--no_aus_noise', action='store_true', help='if specified, add noise to target AUs.') 48 | 49 | parser.add_argument('--gpu_ids', type=str, default='0', help='gpu ids, eg. 0,1,2; -1 for cpu.') 50 | parser.add_argument('--ckpt_dir', type=str, default='./ckpts', help='directory to save check points.') 51 | parser.add_argument('--load_epoch', type=int, default=0, help='load epoch; 0: do not load') 52 | parser.add_argument('--log_file', type=str, default="logs.txt", help='log loss') 53 | parser.add_argument('--opt_file', type=str, default="opt.txt", help='options file') 54 | 55 | # train options 56 | parser.add_argument('--img_nc', type=int, default=3, help='image number of channel') 57 | parser.add_argument('--aus_nc', type=int, default=17, help='aus number of channel') 58 | parser.add_argument('--ngf', type=int, default=64, help='ngf') 59 | parser.add_argument('--ndf', type=int, default=64, help='ndf') 60 | parser.add_argument('--use_dropout', action='store_true', help='if specified, use dropout.') 61 | 62 | parser.add_argument('--gan_type', type=str, default='wgan-gp', help='GAN loss [wgan-gp|lsgan|gan]') 63 | parser.add_argument('--init_type', type=str, default='normal', help='network initialization [normal|xavier|kaiming|orthogonal]') 64 | parser.add_argument('--init_gain', type=float, default=0.02, help='scaling factor for normal, xavier and orthogonal.') 65 | parser.add_argument('--norm', type=str, default='instance', help='instance normalization or batch normalization [batch|instance|none]') 66 | parser.add_argument('--beta1', type=float, default=0.5, help='momentum term of adam') 67 | parser.add_argument('--lr', type=float, default=0.0001, help='initial learning rate for adam') 68 | parser.add_argument('--lr_policy', type=str, default='lambda', help='learning rate policy: lambda|step|plateau|cosine') 69 | parser.add_argument('--lr_decay_iters', type=int, default=50, help='multiply by a gamma every lr_decay_iters iterations') 70 | 71 | parser.add_argument('--epoch_count', type=int, default=1, help='the starting epoch count, we save the model by , +, ...') 72 | parser.add_argument('--niter', type=int, default=20, help='# of iter at starting learning rate') 73 | parser.add_argument('--niter_decay', type=int, default=10, help='# of iter to linearly decay learning rate to zero') 74 | 75 | # loss options 76 | parser.add_argument('--lambda_dis', type=float, default=1.0, help='discriminator weight in loss') 77 | parser.add_argument('--lambda_aus', type=float, default=160.0, help='AUs weight in loss') 78 | parser.add_argument('--lambda_rec', type=float, default=10.0, help='reconstruct loss weight') 79 | parser.add_argument('--lambda_mask', type=float, default=0, help='mse loss weight') 80 | parser.add_argument('--lambda_tv', type=float, default=0, help='total variation loss weight') 81 | parser.add_argument('--lambda_wgan_gp', type=float, default=10., help='wgan gradient penalty weight') 82 | 83 | # frequency options 84 | parser.add_argument('--train_gen_iter', type=int, default=5, help='train G every n interations.') 85 | parser.add_argument('--print_losses_freq', type=int, default=100, help='print log every print_freq step.') 86 | parser.add_argument('--plot_losses_freq', type=int, default=20000, help='plot log every plot_freq step.') 87 | parser.add_argument('--sample_img_freq', type=int, default=2000, help='draw image every sample_img_freq step.') 88 | 
parser.add_argument('--save_epoch_freq', type=int, default=2, help='save checkpoint every save_epoch_freq epoch.') 89 | 90 | return parser 91 | 92 | def parse(self): 93 | parser = self.initialize() 94 | parser.set_defaults(name=datetime.now().strftime("%y%m%d_%H%M%S")) 95 | opt = parser.parse_args() 96 | 97 | dataset_name = os.path.basename(opt.data_root.strip('/')) 98 | # update checkpoint dir 99 | if opt.mode == 'train' and opt.load_epoch == 0: 100 | opt.ckpt_dir = os.path.join(opt.ckpt_dir, dataset_name, opt.model, opt.name) 101 | if not os.path.exists(opt.ckpt_dir): 102 | os.makedirs(opt.ckpt_dir) 103 | 104 | # if test, disable visdom, update results path 105 | if opt.mode == "test": 106 | opt.visdom_display_id = 0 107 | opt.results = os.path.join(opt.results, "%s_%s_%s" % (dataset_name, opt.model, opt.load_epoch)) 108 | if not os.path.exists(opt.results): 109 | os.makedirs(opt.results) 110 | 111 | # set gpu device 112 | str_ids = opt.gpu_ids.split(',') 113 | opt.gpu_ids = [] 114 | for str_id in str_ids: 115 | cur_id = int(str_id) 116 | if cur_id >= 0: 117 | opt.gpu_ids.append(cur_id) 118 | if len(opt.gpu_ids) > 0: 119 | torch.cuda.set_device(opt.gpu_ids[0]) 120 | 121 | # set seed 122 | if opt.lucky_seed == 0: 123 | opt.lucky_seed = int(time.time()) 124 | random.seed(a=opt.lucky_seed) 125 | np.random.seed(seed=opt.lucky_seed) 126 | torch.manual_seed(opt.lucky_seed) 127 | if len(opt.gpu_ids) > 0: 128 | torch.backends.cudnn.deterministic = True 129 | torch.backends.cudnn.benchmark = False 130 | torch.cuda.manual_seed(opt.lucky_seed) 131 | torch.cuda.manual_seed_all(opt.lucky_seed) 132 | 133 | # write command to file 134 | script_dir = opt.ckpt_dir 135 | with open(os.path.join(os.path.join(script_dir, "run_script.sh")), 'a+') as f: 136 | f.write("[%5s][%s]python %s\n" % (opt.mode, opt.name, ' '.join(sys.argv))) 137 | 138 | # print and write options file 139 | msg = '' 140 | msg += '------------------- [%5s][%s]Options --------------------\n' % (opt.mode, opt.name) 141 | for k, v in sorted(vars(opt).items()): 142 | comment = '' 143 | default_v = parser.get_default(k) 144 | if v != default_v: 145 | comment = '\t[default: %s]' % str(default_v) 146 | msg += '{:>25}: {:<30}{}\n'.format(str(k), str(v), comment) 147 | msg += '--------------------- [%5s][%s]End ----------------------\n' % (opt.mode, opt.name) 148 | print(msg) 149 | with open(os.path.join(os.path.join(script_dir, "opt.txt")), 'a+') as f: 150 | f.write(msg + '\n\n') 151 | 152 | return opt 153 | 154 | 155 | 156 | 157 | 158 | 159 | -------------------------------------------------------------------------------- /model/model_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn import init 4 | import functools 5 | from torch.optim import lr_scheduler 6 | from collections import OrderedDict 7 | 8 | 9 | ''' 10 | Helper functions for model 11 | Borrow tons of code from https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py 12 | ''' 13 | 14 | def get_norm_layer(norm_type='instance'): 15 | """Return a normalization layer 16 | Parameters: 17 | norm_type (str) -- the name of the normalization layer: batch | instance | none 18 | For BatchNorm, we use learnable affine parameters and track running statistics (mean/stddev). 19 | For InstanceNorm, we do not use learnable affine parameters. We do not track running statistics. 
20 | """ 21 | if norm_type == 'batch': 22 | norm_layer = functools.partial(nn.BatchNorm2d, affine=True, track_running_stats=True) 23 | elif norm_type == 'instance': 24 | # change the default flags so that instance norm behaves the same in both train and eval 25 | # https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/issues/395 26 | norm_layer = functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=False) 27 | elif norm_type == 'none': 28 | norm_layer = None 29 | else: 30 | raise NotImplementedError('normalization layer [%s] is not found' % norm_type) 31 | return norm_layer 32 | 33 | 34 | def get_scheduler(optimizer, opt): 35 | if opt.lr_policy == 'lambda': 36 | def lambda_rule(epoch): 37 | lr_l = 1.0 - max(0, epoch + 1 + opt.epoch_count - opt.niter) / float(opt.niter_decay + 1) 38 | return lr_l 39 | scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule) 40 | elif opt.lr_policy == 'step': 41 | scheduler = lr_scheduler.StepLR(optimizer, step_size=opt.lr_decay_iters, gamma=0.1) 42 | elif opt.lr_policy == 'plateau': 43 | scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, threshold=0.01, patience=5) 44 | else: 45 | raise NotImplementedError('learning rate policy [%s] is not implemented' % opt.lr_policy) 46 | return scheduler 47 | 48 | 49 | def init_weights(net, init_type='normal', gain=0.02): 50 | def init_func(m): 51 | classname = m.__class__.__name__ 52 | if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1): 53 | if init_type == 'normal': 54 | init.normal_(m.weight.data, 0.0, gain) 55 | elif init_type == 'xavier': 56 | init.xavier_normal_(m.weight.data, gain=gain) 57 | elif init_type == 'kaiming': 58 | init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') 59 | elif init_type == 'orthogonal': 60 | init.orthogonal_(m.weight.data, gain=gain) 61 | else: 62 | raise NotImplementedError('initialization method [%s] is not implemented' % init_type) 63 | if hasattr(m, 'bias') and m.bias is not None: 64 | init.constant_(m.bias.data, 0.0) 65 | elif classname.find('BatchNorm2d') != -1: 66 | init.normal_(m.weight.data, 1.0, gain) 67 | init.constant_(m.bias.data, 0.0) 68 | 69 | print('initialize network with %s' % init_type) 70 | net.apply(init_func) 71 | 72 | 73 | def init_net(net, init_type='normal', init_gain=0.02, gpu_ids=[]): 74 | if len(gpu_ids) > 0: 75 | # print("gpu_ids,", gpu_ids) 76 | assert(torch.cuda.is_available()) 77 | net.to(gpu_ids[0]) 78 | net = torch.nn.DataParallel(net, gpu_ids) 79 | init_weights(net, init_type, gain=init_gain) 80 | return net 81 | 82 | 83 | def define_G(input_nc, output_nc, ngf, which_model_netG, norm='batch', use_dropout=False, init_type='normal', init_gain=0.02, gpu_ids=[]): 84 | netG = None 85 | norm_layer = get_norm_layer(norm_type=norm) 86 | 87 | if which_model_netG == 'resnet_9blocks': 88 | netG = ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=9) 89 | elif which_model_netG == 'resnet_6blocks': 90 | netG = ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=6) 91 | elif which_model_netG == 'unet_128': 92 | netG = UnetGenerator(input_nc, output_nc, 7, ngf, norm_layer=norm_layer, use_dropout=use_dropout) 93 | elif which_model_netG == 'unet_256': 94 | netG = UnetGenerator(input_nc, output_nc, 8, ngf, norm_layer=norm_layer, use_dropout=use_dropout) 95 | else: 96 | raise NotImplementedError('Generator model name [%s] is not recognized' % which_model_netG) 97 | return
init_net(netG, init_type, init_gain, gpu_ids) 98 | 99 | 100 | def define_D(input_nc, ndf, which_model_netD, 101 | n_layers_D=3, norm='batch', use_sigmoid=False, init_type='normal', init_gain=0.02, gpu_ids=[]): 102 | netD = None 103 | norm_layer = get_norm_layer(norm_type=norm) 104 | 105 | if which_model_netD == 'basic': 106 | netD = NLayerDiscriminator(input_nc, ndf, n_layers=3, norm_layer=norm_layer, use_sigmoid=use_sigmoid) 107 | elif which_model_netD == 'n_layers': 108 | netD = NLayerDiscriminator(input_nc, ndf, n_layers_D, norm_layer=norm_layer, use_sigmoid=use_sigmoid) 109 | elif which_model_netD == 'pixel': 110 | netD = PixelDiscriminator(input_nc, ndf, norm_layer=norm_layer, use_sigmoid=use_sigmoid) 111 | else: 112 | raise NotImplementedError('Discriminator model name [%s] is not recognized' % 113 | which_model_netD) 114 | return init_net(netD, init_type, init_gain, gpu_ids) 115 | 116 | 117 | ############################################################################## 118 | # Classes 119 | ############################################################################## 120 | 121 | 122 | # Defines the GAN loss which uses either LSGAN or the regular GAN. 123 | # When LSGAN is used, it is basically same as MSELoss, 124 | # but it abstracts away the need to create the target label tensor 125 | # that has the same size as the input 126 | class GANLoss(nn.Module): 127 | def __init__(self, gan_type='wgan-gp', target_real_label=1.0, target_fake_label=0.0): 128 | super(GANLoss, self).__init__() 129 | self.register_buffer('real_label', torch.tensor(target_real_label)) 130 | self.register_buffer('fake_label', torch.tensor(target_fake_label)) 131 | self.gan_type = gan_type 132 | if self.gan_type == 'wgan-gp': 133 | self.loss = lambda x, y: -torch.mean(x) if y else torch.mean(x) 134 | elif self.gan_type == 'lsgan': 135 | self.loss = nn.MSELoss() 136 | elif self.gan_type == 'gan': 137 | self.loss = nn.BCELoss() 138 | else: 139 | raise NotImplementedError('GAN loss type [%s] is not found' % gan_type) 140 | 141 | def get_target_tensor(self, input, target_is_real): 142 | if target_is_real: 143 | target_tensor = self.real_label 144 | else: 145 | target_tensor = self.fake_label 146 | return target_tensor.expand_as(input) 147 | 148 | def __call__(self, input, target_is_real): 149 | if self.gan_type == 'wgan-gp': 150 | target_tensor = target_is_real 151 | else: 152 | target_tensor = self.get_target_tensor(input, target_is_real) 153 | return self.loss(input, target_tensor) 154 | 155 | 156 | # Defines the generator that consists of Resnet blocks between a few 157 | # downsampling/upsampling operations. 158 | # Code and idea originally from Justin Johnson's architecture. 
159 | # https://github.com/jcjohnson/fast-neural-style/ 160 | class ResnetGenerator(nn.Module): 161 | def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False, n_blocks=6, padding_type='reflect'): 162 | assert(n_blocks >= 0) 163 | super(ResnetGenerator, self).__init__() 164 | self.input_nc = input_nc 165 | self.output_nc = output_nc 166 | self.ngf = ngf 167 | if type(norm_layer) == functools.partial: 168 | use_bias = norm_layer.func == nn.InstanceNorm2d 169 | else: 170 | use_bias = norm_layer == nn.InstanceNorm2d 171 | 172 | model = [nn.ReflectionPad2d(3), 173 | nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0, 174 | bias=use_bias), 175 | norm_layer(ngf), 176 | nn.ReLU(True)] 177 | 178 | n_downsampling = 2 179 | for i in range(n_downsampling): 180 | mult = 2**i 181 | model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, 182 | stride=2, padding=1, bias=use_bias), 183 | norm_layer(ngf * mult * 2), 184 | nn.ReLU(True)] 185 | 186 | mult = 2**n_downsampling 187 | for i in range(n_blocks): 188 | model += [ResnetBlock(ngf * mult, padding_type=padding_type, norm_layer=norm_layer, use_dropout=use_dropout, use_bias=use_bias)] 189 | 190 | for i in range(n_downsampling): 191 | mult = 2**(n_downsampling - i) 192 | model += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2), 193 | kernel_size=3, stride=2, 194 | padding=1, output_padding=1, 195 | bias=use_bias), 196 | norm_layer(int(ngf * mult / 2)), 197 | nn.ReLU(True)] 198 | model += [nn.ReflectionPad2d(3)] 199 | model += [nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)] 200 | model += [nn.Tanh()] 201 | 202 | self.model = nn.Sequential(*model) 203 | 204 | def forward(self, input): 205 | return self.model(input) 206 | 207 | 208 | # Define a resnet block 209 | class ResnetBlock(nn.Module): 210 | def __init__(self, dim, padding_type, norm_layer, use_dropout, use_bias): 211 | super(ResnetBlock, self).__init__() 212 | self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, use_dropout, use_bias) 213 | 214 | def build_conv_block(self, dim, padding_type, norm_layer, use_dropout, use_bias): 215 | conv_block = [] 216 | p = 0 217 | if padding_type == 'reflect': 218 | conv_block += [nn.ReflectionPad2d(1)] 219 | elif padding_type == 'replicate': 220 | conv_block += [nn.ReplicationPad2d(1)] 221 | elif padding_type == 'zero': 222 | p = 1 223 | else: 224 | raise NotImplementedError('padding [%s] is not implemented' % padding_type) 225 | 226 | conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias), 227 | norm_layer(dim), 228 | nn.ReLU(True)] 229 | if use_dropout: 230 | conv_block += [nn.Dropout(0.5)] 231 | 232 | p = 0 233 | if padding_type == 'reflect': 234 | conv_block += [nn.ReflectionPad2d(1)] 235 | elif padding_type == 'replicate': 236 | conv_block += [nn.ReplicationPad2d(1)] 237 | elif padding_type == 'zero': 238 | p = 1 239 | else: 240 | raise NotImplementedError('padding [%s] is not implemented' % padding_type) 241 | conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias), 242 | norm_layer(dim)] 243 | 244 | return nn.Sequential(*conv_block) 245 | 246 | def forward(self, x): 247 | out = x + self.conv_block(x) 248 | return out 249 | 250 | 251 | # Defines the Unet generator. 252 | # |num_downs|: number of downsamplings in UNet. 
For example, 253 | # if |num_downs| == 7, image of size 128x128 will become of size 1x1 254 | # at the bottleneck 255 | class UnetGenerator(nn.Module): 256 | def __init__(self, input_nc, output_nc, num_downs, ngf=64, 257 | norm_layer=nn.BatchNorm2d, use_dropout=False): 258 | super(UnetGenerator, self).__init__() 259 | 260 | # construct unet structure 261 | unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=None, norm_layer=norm_layer, innermost=True) 262 | for i in range(num_downs - 5): 263 | unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer, use_dropout=use_dropout) 264 | unet_block = UnetSkipConnectionBlock(ngf * 4, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer) 265 | unet_block = UnetSkipConnectionBlock(ngf * 2, ngf * 4, input_nc=None, submodule=unet_block, norm_layer=norm_layer) 266 | unet_block = UnetSkipConnectionBlock(ngf, ngf * 2, input_nc=None, submodule=unet_block, norm_layer=norm_layer) 267 | unet_block = UnetSkipConnectionBlock(output_nc, ngf, input_nc=input_nc, submodule=unet_block, outermost=True, norm_layer=norm_layer) 268 | 269 | self.model = unet_block 270 | 271 | def forward(self, input): 272 | return self.model(input) 273 | 274 | 275 | # Defines the submodule with skip connection. 276 | # X -------------------identity---------------------- X 277 | # |-- downsampling -- |submodule| -- upsampling --| 278 | class UnetSkipConnectionBlock(nn.Module): 279 | def __init__(self, outer_nc, inner_nc, input_nc=None, 280 | submodule=None, outermost=False, innermost=False, norm_layer=nn.BatchNorm2d, use_dropout=False): 281 | super(UnetSkipConnectionBlock, self).__init__() 282 | self.outermost = outermost 283 | if type(norm_layer) == functools.partial: 284 | use_bias = norm_layer.func == nn.InstanceNorm2d 285 | else: 286 | use_bias = norm_layer == nn.InstanceNorm2d 287 | if input_nc is None: 288 | input_nc = outer_nc 289 | downconv = nn.Conv2d(input_nc, inner_nc, kernel_size=4, 290 | stride=2, padding=1, bias=use_bias) 291 | downrelu = nn.LeakyReLU(0.2, True) 292 | downnorm = norm_layer(inner_nc) 293 | uprelu = nn.ReLU(True) 294 | upnorm = norm_layer(outer_nc) 295 | 296 | if outermost: 297 | upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc, 298 | kernel_size=4, stride=2, 299 | padding=1) 300 | down = [downconv] 301 | up = [uprelu, upconv, nn.Tanh()] 302 | model = down + [submodule] + up 303 | elif innermost: 304 | upconv = nn.ConvTranspose2d(inner_nc, outer_nc, 305 | kernel_size=4, stride=2, 306 | padding=1, bias=use_bias) 307 | down = [downrelu, downconv] 308 | up = [uprelu, upconv, upnorm] 309 | model = down + up 310 | else: 311 | upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc, 312 | kernel_size=4, stride=2, 313 | padding=1, bias=use_bias) 314 | down = [downrelu, downconv, downnorm] 315 | up = [uprelu, upconv, upnorm] 316 | 317 | if use_dropout: 318 | model = down + [submodule] + up + [nn.Dropout(0.5)] 319 | else: 320 | model = down + [submodule] + up 321 | 322 | self.model = nn.Sequential(*model) 323 | 324 | def forward(self, x): 325 | if self.outermost: 326 | return self.model(x) 327 | else: 328 | return torch.cat([x, self.model(x)], 1) 329 | 330 | 331 | # Defines the PatchGAN discriminator with the specified arguments. 
332 | class NLayerDiscriminator(nn.Module): 333 | def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d, use_sigmoid=False): 334 | super(NLayerDiscriminator, self).__init__() 335 | if type(norm_layer) == functools.partial: 336 | use_bias = norm_layer.func == nn.InstanceNorm2d 337 | else: 338 | use_bias = norm_layer == nn.InstanceNorm2d 339 | 340 | kw = 4 341 | padw = 1 342 | sequence = [ 343 | nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), 344 | nn.LeakyReLU(0.2, True) 345 | ] 346 | 347 | nf_mult = 1 348 | nf_mult_prev = 1 349 | for n in range(1, n_layers): 350 | nf_mult_prev = nf_mult 351 | nf_mult = min(2**n, 8) 352 | sequence += [ 353 | nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, 354 | kernel_size=kw, stride=2, padding=padw, bias=use_bias), 355 | norm_layer(ndf * nf_mult), 356 | nn.LeakyReLU(0.2, True) 357 | ] 358 | 359 | nf_mult_prev = nf_mult 360 | nf_mult = min(2**n_layers, 8) 361 | sequence += [ 362 | nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, 363 | kernel_size=kw, stride=1, padding=padw, bias=use_bias), 364 | norm_layer(ndf * nf_mult), 365 | nn.LeakyReLU(0.2, True) 366 | ] 367 | 368 | sequence += [nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)] 369 | 370 | if use_sigmoid: 371 | sequence += [nn.Sigmoid()] 372 | 373 | self.model = nn.Sequential(*sequence) 374 | 375 | def forward(self, input): 376 | return self.model(input) 377 | 378 | 379 | class PixelDiscriminator(nn.Module): 380 | def __init__(self, input_nc, ndf=64, norm_layer=nn.BatchNorm2d, use_sigmoid=False): 381 | super(PixelDiscriminator, self).__init__() 382 | if type(norm_layer) == functools.partial: 383 | use_bias = norm_layer.func == nn.InstanceNorm2d 384 | else: 385 | use_bias = norm_layer == nn.InstanceNorm2d 386 | 387 | self.net = [ 388 | nn.Conv2d(input_nc, ndf, kernel_size=1, stride=1, padding=0), 389 | nn.LeakyReLU(0.2, True), 390 | nn.Conv2d(ndf, ndf * 2, kernel_size=1, stride=1, padding=0, bias=use_bias), 391 | norm_layer(ndf * 2), 392 | nn.LeakyReLU(0.2, True), 393 | nn.Conv2d(ndf * 2, 1, kernel_size=1, stride=1, padding=0, bias=use_bias)] 394 | 395 | if use_sigmoid: 396 | self.net.append(nn.Sigmoid()) 397 | 398 | self.net = nn.Sequential(*self.net) 399 | 400 | def forward(self, input): 401 | return self.net(input) 402 | 403 | 404 | ############################################################################## 405 | # Basic network model 406 | ############################################################################## 407 | def define_splitG(img_nc, aus_nc, ngf, use_dropout=False, norm='instance', init_type='normal', init_gain=0.02, gpu_ids=[]): 408 | norm_layer = get_norm_layer(norm_type=norm) 409 | net_img_au = SplitGenerator(img_nc, aus_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=6) 410 | return init_net(net_img_au, init_type, init_gain, gpu_ids) 411 | 412 | 413 | def define_splitD(input_nc, aus_nc, image_size, ndf, norm='instance', init_type='normal', init_gain=0.02, gpu_ids=[]): 414 | norm_layer = get_norm_layer(norm_type=norm) 415 | net_dis_aus = SplitDiscriminator(input_nc, aus_nc, image_size, ndf, n_layers=6, norm_layer=norm_layer) 416 | return init_net(net_dis_aus, init_type, init_gain, gpu_ids) 417 | 418 | 419 | class SplitGenerator(nn.Module): 420 | def __init__(self, img_nc, aus_nc, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False, n_blocks=6, padding_type='zero'): 421 | assert(n_blocks >= 0) 422 | super(SplitGenerator, self).__init__() 423 | self.input_nc = img_nc + aus_nc 424 | self.ngf = ngf 425 | if 
type(norm_layer) == functools.partial: 426 | use_bias = norm_layer.func == nn.InstanceNorm2d 427 | else: 428 | use_bias = norm_layer == nn.InstanceNorm2d 429 | 430 | model = [nn.Conv2d(self.input_nc, ngf, kernel_size=7, stride=1, padding=3, 431 | bias=use_bias), 432 | norm_layer(ngf), 433 | nn.ReLU(True)] 434 | 435 | n_downsampling = 2 436 | for i in range(n_downsampling): 437 | mult = 2**i 438 | model += [nn.Conv2d(ngf * mult, ngf * mult * 2, \ 439 | kernel_size=4, stride=2, padding=1, \ 440 | bias=use_bias), 441 | norm_layer(ngf * mult * 2), 442 | nn.ReLU(True)] 443 | 444 | mult = 2**n_downsampling 445 | for i in range(n_blocks): 446 | model += [ResnetBlock(ngf * mult, padding_type=padding_type, norm_layer=norm_layer, use_dropout=use_dropout, use_bias=use_bias)] 447 | 448 | for i in range(n_downsampling): 449 | mult = 2**(n_downsampling - i) 450 | model += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2), 451 | kernel_size=4, stride=2, padding=1, 452 | bias=use_bias), 453 | norm_layer(int(ngf * mult / 2)), 454 | nn.ReLU(True)] 455 | 456 | self.model = nn.Sequential(*model) 457 | # color mask generator top 458 | color_top = [] 459 | color_top += [nn.Conv2d(ngf, img_nc, kernel_size=7, stride=1, padding=3, bias=False), 460 | nn.Tanh()] 461 | self.color_top = nn.Sequential(*color_top) 462 | # AUs mask generator top 463 | au_top = [] 464 | au_top += [nn.Conv2d(ngf, 1, kernel_size=7, stride=1, padding=3, bias=False), 465 | nn.Sigmoid()] 466 | self.au_top = nn.Sequential(*au_top) 467 | 468 | # from torchsummary import summary 469 | # summary(self.model.to("cuda"), (20, 128, 128)) 470 | # summary(self.color_top.to("cuda"), (64, 128, 128)) 471 | # summary(self.au_top.to("cuda"), (64, 128, 128)) 472 | # assert False 473 | 474 | def forward(self, img, au): 475 | # replicate AUs vector to match image shap and concate to construct input 476 | sparse_au = au.unsqueeze(2).unsqueeze(3) 477 | sparse_au = sparse_au.expand(sparse_au.size(0), sparse_au.size(1), img.size(2), img.size(3)) 478 | self.input_img_au = torch.cat([img, sparse_au], dim=1) 479 | 480 | embed_features = self.model(self.input_img_au) 481 | 482 | return self.color_top(embed_features), self.au_top(embed_features), embed_features 483 | 484 | 485 | class SplitDiscriminator(nn.Module): 486 | def __init__(self, input_nc, aus_nc, image_size=128, ndf=64, n_layers=6, norm_layer=nn.BatchNorm2d): 487 | super(SplitDiscriminator, self).__init__() 488 | if type(norm_layer) == functools.partial: 489 | use_bias = norm_layer.func == nn.InstanceNorm2d 490 | else: 491 | use_bias = norm_layer == nn.InstanceNorm2d 492 | 493 | kw = 4 494 | padw = 1 495 | sequence = [ 496 | nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), 497 | nn.LeakyReLU(0.01, True) 498 | ] 499 | 500 | cur_dim = ndf 501 | for n in range(1, n_layers): 502 | sequence += [ 503 | nn.Conv2d(cur_dim, 2 * cur_dim, 504 | kernel_size=kw, stride=2, padding=padw, bias=use_bias), 505 | nn.LeakyReLU(0.01, True) 506 | ] 507 | cur_dim = 2 * cur_dim 508 | 509 | self.model = nn.Sequential(*sequence) 510 | # patch discriminator top 511 | self.dis_top = nn.Conv2d(cur_dim, 1, kernel_size=kw-1, stride=1, padding=padw, bias=False) 512 | # AUs classifier top 513 | k_size = int(image_size / (2 ** n_layers)) 514 | self.aus_top = nn.Conv2d(cur_dim, aus_nc, kernel_size=k_size, stride=1, bias=False) 515 | 516 | # from torchsummary import summary 517 | # summary(self.model.to("cuda"), (3, 128, 128)) 518 | 519 | def forward(self, img): 520 | embed_features = self.model(img) 521 | pred_map = 
self.dis_top(embed_features) 522 | pred_aus = self.aus_top(embed_features) 523 | return pred_map.squeeze(), pred_aus.squeeze() 524 | 525 | 526 | # https://github.com/jxgu1016/Total_Variation_Loss.pytorch/blob/master/TVLoss.py 527 | class TVLoss(nn.Module): 528 | def __init__(self, TVLoss_weight=1): 529 | super(TVLoss,self).__init__() 530 | self.TVLoss_weight = TVLoss_weight 531 | 532 | def forward(self,x): 533 | batch_size = x.size()[0] 534 | h_x = x.size()[2] 535 | w_x = x.size()[3] 536 | count_h = self._tensor_size(x[:,:,1:,:]) 537 | count_w = self._tensor_size(x[:,:,:,1:]) 538 | h_tv = torch.pow((x[:,:,1:,:]-x[:,:,:h_x-1,:]),2).sum() 539 | w_tv = torch.pow((x[:,:,:,1:]-x[:,:,:,:w_x-1]),2).sum() 540 | return self.TVLoss_weight*2*(h_tv/count_h+w_tv/count_w)/batch_size 541 | 542 | def _tensor_size(self,t): 543 | return t.size()[1]*t.size()[2]*t.size()[3] 544 | 545 | 546 | 547 | 548 | --------------------------------------------------------------------------------
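For readers who want to sanity-check the building blocks in `model/model_utils.py` in isolation, here is a minimal, illustrative sketch. It is not part of the repository: the batch size and tensor shapes are assumptions matching the defaults in `options.py` (3-channel 128x128 images, 17 AUs), and it assumes it is run from inside the `model/` directory with PyTorch installed, so that `model_utils` imports directly without pulling in the rest of the package.

```python
# Illustrative smoke test only; shapes and batch size are assumptions, not repo code.
import torch
from model_utils import define_splitG, define_splitD, GANLoss, TVLoss

# Random stand-ins for a batch of face images and target AU vectors.
imgs = torch.rand(4, 3, 128, 128)   # N x img_nc x final_size x final_size
aus = torch.rand(4, 17)             # N x aus_nc

# Build the split generator/discriminator on CPU (pass real ids via gpu_ids to use CUDA).
netG = define_splitG(img_nc=3, aus_nc=17, ngf=64, gpu_ids=[])
netD = define_splitD(input_nc=3, aus_nc=17, image_size=128, ndf=64, gpu_ids=[])

# SplitGenerator returns a color map, an attention mask, and the embedded features.
color, attention, _ = netG(imgs, aus)

# SplitDiscriminator returns a patch map of real/fake scores and the predicted AUs.
pred_map, pred_aus = netD(imgs)

# WGAN-GP critic loss on "real" images, plus total variation on the attention mask.
gan_loss = GANLoss(gan_type='wgan-gp')
loss_d_real = gan_loss(pred_map, True)
loss_tv = TVLoss()(attention)

print(color.shape, attention.shape, pred_aus.shape, loss_d_real.item(), loss_tv.item())
```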