├── data
│   ├── __init__.py
│   ├── data_loader.py
│   ├── celeba.py
│   └── base_dataset.py
├── requirements.txt
├── imgs
│   ├── celeba_testing.jpg
│   ├── gifs
│   │   ├── celeba_1.gif
│   │   ├── celeba_2.gif
│   │   ├── emotionnet_1.gif
│   │   ├── emotionnet_2.gif
│   │   ├── emotionnet_3.gif
│   │   └── emotionnet_4.gif
│   ├── celeba_training.jpg
│   ├── ganimation_show.jpg
│   ├── emotionnet_testing.jpg
│   ├── emotionnet_training.jpg
│   ├── celeba_stargan_testing.jpg
│   └── emotionnet_stargan_testing.jpg
├── main.py
├── model
│   ├── __init__.py
│   ├── stargan.py
│   ├── ganimation.py
│   ├── base_model.py
│   └── model_utils.py
├── LICENSE
├── .gitignore
├── visualizer.py
├── solvers.py
├── README.md
└── options.py

/data/__init__.py:
--------------------------------------------------------------------------------
from .data_loader import create_dataloader
--------------------------------------------------------------------------------

/requirements.txt:
--------------------------------------------------------------------------------
torch>=0.4.1
torchvision>=0.2.1
visdom>=0.1.8.3
imageio>=2.5.0
--------------------------------------------------------------------------------

/imgs/celeba_testing.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/donydchen/ganimation_replicate/HEAD/imgs/celeba_testing.jpg
--------------------------------------------------------------------------------

/imgs/gifs/celeba_1.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/donydchen/ganimation_replicate/HEAD/imgs/gifs/celeba_1.gif
--------------------------------------------------------------------------------

/imgs/gifs/celeba_2.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/donydchen/ganimation_replicate/HEAD/imgs/gifs/celeba_2.gif
--------------------------------------------------------------------------------

/imgs/celeba_training.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/donydchen/ganimation_replicate/HEAD/imgs/celeba_training.jpg
--------------------------------------------------------------------------------

/imgs/ganimation_show.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/donydchen/ganimation_replicate/HEAD/imgs/ganimation_show.jpg
--------------------------------------------------------------------------------

/imgs/gifs/emotionnet_1.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/donydchen/ganimation_replicate/HEAD/imgs/gifs/emotionnet_1.gif
--------------------------------------------------------------------------------

/imgs/gifs/emotionnet_2.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/donydchen/ganimation_replicate/HEAD/imgs/gifs/emotionnet_2.gif
--------------------------------------------------------------------------------

/imgs/gifs/emotionnet_3.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/donydchen/ganimation_replicate/HEAD/imgs/gifs/emotionnet_3.gif
--------------------------------------------------------------------------------

/imgs/gifs/emotionnet_4.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/donydchen/ganimation_replicate/HEAD/imgs/gifs/emotionnet_4.gif
--------------------------------------------------------------------------------

/imgs/emotionnet_testing.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/donydchen/ganimation_replicate/HEAD/imgs/emotionnet_testing.jpg
--------------------------------------------------------------------------------

/imgs/emotionnet_training.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/donydchen/ganimation_replicate/HEAD/imgs/emotionnet_training.jpg
--------------------------------------------------------------------------------

/imgs/celeba_stargan_testing.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/donydchen/ganimation_replicate/HEAD/imgs/celeba_stargan_testing.jpg
--------------------------------------------------------------------------------

/imgs/emotionnet_stargan_testing.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/donydchen/ganimation_replicate/HEAD/imgs/emotionnet_stargan_testing.jpg
--------------------------------------------------------------------------------

/main.py:
--------------------------------------------------------------------------------
"""
Created on Dec 13, 2018
@author: Yuedong Chen
"""

from options import Options
from solvers import create_solver


if __name__ == '__main__':
    opt = Options().parse()

    solver = create_solver(opt)
    solver.run_solver()

    print('[THE END]')
--------------------------------------------------------------------------------

/model/__init__.py:
--------------------------------------------------------------------------------
from .base_model import BaseModel
from .ganimation import GANimationModel
from .stargan import StarGANModel


def create_model(opt):
    # dispatch on the model name given in the options
    if opt.model == "ganimation":
        instance = GANimationModel()
    elif opt.model == "stargan":
        instance = StarGANModel()
    else:
        instance = BaseModel()
    instance.initialize(opt)
    instance.setup()
    return instance
--------------------------------------------------------------------------------

/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2019 Yuedong Chen (Donald)

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------

/data/data_loader.py:
--------------------------------------------------------------------------------
import torch
import os

from .base_dataset import BaseDataset
from .celeba import CelebADataset


def create_dataloader(opt):
    data_loader = DataLoader()
    data_loader.initialize(opt)
    return data_loader


class DataLoader:
    def name(self):
        return self.dataset.name() + "_Loader"

    def create_dataset(self):
        # pick the dataset class from the data_root folder name
        loaded_dataset = os.path.basename(self.opt.data_root.strip('/')).lower()
        if 'celeba' in loaded_dataset or 'emotion' in loaded_dataset:
            dataset = CelebADataset()
        else:
            dataset = BaseDataset()
        dataset.initialize(self.opt)
        return dataset

    def initialize(self, opt):
        self.opt = opt
        self.dataset = self.create_dataset()
        self.dataloader = torch.utils.data.DataLoader(
            self.dataset,
            batch_size=opt.batch_size,
            shuffle=not opt.serial_batches,
            num_workers=int(opt.n_threads)
        )

    def __len__(self):
        return min(len(self.dataset), self.opt.max_dataset_size)

    def __iter__(self):
        for i, data in enumerate(self.dataloader):
            if i * self.opt.batch_size >= self.opt.max_dataset_size:
                break
            yield data
--------------------------------------------------------------------------------
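The loader above is consumed as a plain iterable of batch dicts; a minimal usage sketch (assuming opt is an already-parsed Options object, as in solvers.py below):

    from data import create_dataloader

    loader = create_dataloader(opt)
    print(loader.name(), len(loader))
    for batch in loader:
        # keys as built in CelebADataset.__getitem__ (see data/celeba.py):
        # 'src_img', 'src_aus', 'tar_img', 'tar_aus', 'src_path', 'tar_path'
        src_imgs = batch['src_img']  # float tensor, [batch_size, img_nc, final_size, final_size]
        tar_aus = batch['tar_aus']   # AU vectors, normalized to [0, 1]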
/.gitignore:
--------------------------------------------------------------------------------
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# pyenv
.python-version

# celery beat schedule file
celerybeat-schedule

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/

# project related
results
ckpts
sftp-config.json
--------------------------------------------------------------------------------

/data/celeba.py:
--------------------------------------------------------------------------------
from .base_dataset import BaseDataset
import os
import random
import numpy as np


class CelebADataset(BaseDataset):
    """docstring for CelebADataset"""
    def __init__(self):
        super(CelebADataset, self).__init__()

    def initialize(self, opt):
        super(CelebADataset, self).initialize(opt)

    def get_aus_by_path(self, img_path):
        assert os.path.isfile(img_path), "Cannot find image file: %s" % img_path
        img_id = str(os.path.splitext(os.path.basename(img_path))[0])
        return self.aus_dict[img_id] / 5.0  # norm AU intensities from [0, 5] to [0, 1]

    def make_dataset(self):
        # return all image full paths in a list
        imgs_path = []
        assert os.path.isfile(self.imgs_name_file), "%s does not exist." % self.imgs_name_file
        with open(self.imgs_name_file, 'r') as f:
            lines = f.readlines()
            imgs_path = [os.path.join(self.imgs_dir, line.strip()) for line in lines]
        imgs_path = sorted(imgs_path)
        return imgs_path

    def __getitem__(self, index):
        img_path = self.imgs_path[index]

        # load source image
        src_img = self.get_img_by_path(img_path)
        src_img_tensor = self.img2tensor(src_img)
        src_aus = self.get_aus_by_path(img_path)

        # load a random target image
        tar_img_path = random.choice(self.imgs_path)
        tar_img = self.get_img_by_path(tar_img_path)
        tar_img_tensor = self.img2tensor(tar_img)
        tar_aus = self.get_aus_by_path(tar_img_path)
        if self.is_train and not self.opt.no_aus_noise:
            tar_aus = tar_aus + np.random.uniform(-0.1, 0.1, tar_aus.shape)

        # record paths for debug and test usage
        data_dict = {'src_img': src_img_tensor, 'src_aus': src_aus,
                     'tar_img': tar_img_tensor, 'tar_aus': tar_aus,
                     'src_path': img_path, 'tar_path': tar_img_path}

        return data_dict
--------------------------------------------------------------------------------

/data/base_dataset.py:
--------------------------------------------------------------------------------
import torch
import os
from PIL import Image
import pickle
import torchvision.transforms as transforms


class BaseDataset(torch.utils.data.Dataset):
    """docstring for BaseDataset"""
    def __init__(self):
        super(BaseDataset, self).__init__()

    def name(self):
        return os.path.basename(self.opt.data_root.strip('/'))

    def initialize(self, opt):
        self.opt = opt
        self.imgs_dir = os.path.join(self.opt.data_root, self.opt.imgs_dir)
        self.is_train = self.opt.mode == "train"

        # load image paths
        filename = self.opt.train_csv if self.is_train else self.opt.test_csv
        self.imgs_name_file = os.path.join(self.opt.data_root, filename)
        self.imgs_path = self.make_dataset()

        # load AUs dictionary
        aus_pkl = os.path.join(self.opt.data_root, self.opt.aus_pkl)
        self.aus_dict = self.load_dict(aus_pkl)

        # build image-to-tensor transformer
        self.img2tensor = self.img_transformer()

    def make_dataset(self):
        # implemented by subclasses
        return None

    def load_dict(self, pkl_path):
        saved_dict = {}
        with open(pkl_path, 'rb') as f:
            saved_dict = pickle.load(f, encoding='latin1')
        return saved_dict

    def get_img_by_path(self, img_path):
        assert os.path.isfile(img_path), "Cannot find image file: %s" % img_path
        img_type = 'L' if self.opt.img_nc == 1 else 'RGB'
        return Image.open(img_path).convert(img_type)

    def get_aus_by_path(self, img_path):
        # implemented by subclasses
        return None

    def img_transformer(self):
        transform_list = []
        if self.opt.resize_or_crop == 'resize_and_crop':
            transform_list.append(transforms.Resize([self.opt.load_size, self.opt.load_size], Image.BICUBIC))
            transform_list.append(transforms.RandomCrop(self.opt.final_size))
        elif self.opt.resize_or_crop == 'crop':
            transform_list.append(transforms.RandomCrop(self.opt.final_size))
        elif self.opt.resize_or_crop == 'none':
            transform_list.append(transforms.Lambda(lambda image: image))
        else:
            raise ValueError("--resize_or_crop %s is not a valid option." % self.opt.resize_or_crop)

        if self.is_train and not self.opt.no_flip:
            transform_list.append(transforms.RandomHorizontalFlip())

        transform_list.append(transforms.ToTensor())
        # note: the stats below assume 3-channel input; adjust them if img_nc == 1
        transform_list.append(transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)))

        img2tensor = transforms.Compose(transform_list)

        return img2tensor

    def __len__(self):
        return len(self.imgs_path)
--------------------------------------------------------------------------------
/visualizer.py:
--------------------------------------------------------------------------------
import os
import numpy as np
import torch
from PIL import Image


class Visualizer(object):
    """docstring for Visualizer"""
    def __init__(self):
        super(Visualizer, self).__init__()

    def initialize(self, opt):
        self.opt = opt
        self.display_id = self.opt.visdom_display_id
        if self.display_id > 0:
            import visdom
            self.ncols = 8
            self.vis = visdom.Visdom(server="http://localhost", port=self.opt.visdom_port, env=self.opt.visdom_env)

    def throw_visdom_connection_error(self):
        print('\n\nCould not connect to a visdom server; start one with "python -m visdom.server".')
        exit(1)

    def print_losses_info(self, info_dict):
        msg = '[{}][Epoch: {:0>3}/{:0>3}; Images: {:0>4}/{:0>4}; Time: {:.3f}s/Batch({}); LR: {:.7f}] '.format(
            self.opt.name, info_dict['epoch'], info_dict['epoch_len'],
            info_dict['epoch_steps'], info_dict['epoch_steps_len'],
            info_dict['step_time'], self.opt.batch_size, info_dict['cur_lr'])
        for k, v in info_dict['losses'].items():
            msg += '| {}: {:.4f} '.format(k, v)
        msg += '|'
        print(msg)
        with open(info_dict['log_path'], 'a+') as f:
            f.write(msg + '\n')

    def display_current_losses(self, epoch, counter_ratio, losses_dict):
        if not hasattr(self, 'plot_data'):
            self.plot_data = {'X': [], 'Y': [], 'legend': list(losses_dict.keys())}
        self.plot_data['X'].append(epoch + counter_ratio)
        self.plot_data['Y'].append([losses_dict[k] for k in self.plot_data['legend']])
        try:
            self.vis.line(
                X=np.stack([np.array(self.plot_data['X'])] * len(self.plot_data['legend']), 1),
                Y=np.array(self.plot_data['Y']),
                opts={
                    'title': self.opt.name + ' loss over time',
                    'legend': self.plot_data['legend'],
                    'xlabel': 'epoch',
                    'ylabel': 'loss'},
                win=self.display_id)
        except ConnectionError:
            self.throw_visdom_connection_error()

    def display_online_results(self, visuals, epoch):
        win_id = self.display_id + 24
        images = []
        labels = []
        for label, image in visuals.items():
            if 'mask' in label:
                image = (image - 0.5) / 0.5  # convert map from [0, 1] to [-1, 1]
            image_numpy = self.tensor2im(image)
            images.append(image_numpy.transpose([2, 0, 1]))
            labels.append(label)
        try:
            title = ' || '.join(labels)
            self.vis.images(images, nrow=self.ncols, win=win_id,
                            padding=5, opts=dict(title=title))
        except ConnectionError:
            self.throw_visdom_connection_error()

    # utils
    def tensor2im(self, input_image, imtype=np.uint8):
        if isinstance(input_image, torch.Tensor):
            image_tensor = input_image.data
        else:
            return input_image
        image_numpy = image_tensor[0].cpu().float().numpy()
        # Image.ANTIALIAS was renamed Image.LANCZOS in newer Pillow releases
        im = self.numpy2im(image_numpy, imtype).resize((80, 80), Image.ANTIALIAS)
        return np.array(im)

    def numpy2im(self, image_numpy, imtype=np.uint8):
        if image_numpy.shape[0] == 1:
            image_numpy = np.tile(image_numpy, (3, 1, 1))
        # map CHW in [-1, 1] to HWC in [0, 255]
        image_numpy = (np.transpose(image_numpy, (1, 2, 0)) / 2. + 0.5) * 255.0
        image_numpy = image_numpy.astype(imtype)
        im = Image.fromarray(image_numpy)
        return im
--------------------------------------------------------------------------------

/model/stargan.py:
--------------------------------------------------------------------------------
import torch
from .base_model import BaseModel
from . import model_utils


class StarGANModel(BaseModel):
    """docstring for StarGANModel"""
    def __init__(self):
        super(StarGANModel, self).__init__()
        self.name = "StarGAN"

    def initialize(self, opt):
        super(StarGANModel, self).initialize(opt)

        self.net_gen = model_utils.define_splitG(self.opt.img_nc, self.opt.aus_nc, self.opt.ngf, use_dropout=self.opt.use_dropout,
                                                 norm=self.opt.norm, init_type=self.opt.init_type, init_gain=self.opt.init_gain, gpu_ids=self.gpu_ids)
        self.models_name.append('gen')

        if self.is_train:
            self.net_dis = model_utils.define_splitD(self.opt.img_nc, self.opt.aus_nc, self.opt.final_size, self.opt.ndf,
                                                     norm=self.opt.norm, init_type=self.opt.init_type, init_gain=self.opt.init_gain, gpu_ids=self.gpu_ids)
            self.models_name.append('dis')

        if self.opt.load_epoch > 0:
            self.load_ckpt(self.opt.load_epoch)

    def setup(self):
        super(StarGANModel, self).setup()
        if self.is_train:
            # setup optimizers
            self.optim_gen = torch.optim.Adam(self.net_gen.parameters(),
                                              lr=self.opt.lr, betas=(self.opt.beta1, 0.999))
            self.optims.append(self.optim_gen)
            self.optim_dis = torch.optim.Adam(self.net_dis.parameters(),
                                              lr=self.opt.lr, betas=(self.opt.beta1, 0.999))
            self.optims.append(self.optim_dis)

            # setup schedulers
            self.schedulers = [model_utils.get_scheduler(optim, self.opt) for optim in self.optims]

    def feed_batch(self, batch):
        self.src_img = batch['src_img'].to(self.device)
        self.tar_aus = batch['tar_aus'].type(torch.FloatTensor).to(self.device)
        if self.is_train:
            self.src_aus = batch['src_aus'].type(torch.FloatTensor).to(self.device)
            self.tar_img = batch['tar_img'].to(self.device)

    def forward(self):
        # generate fake image (the split generator also returns masks, unused here)
        self.fake_img, _, _ = self.net_gen(self.src_img, self.tar_aus)

        # reconstruct real image
        if self.is_train:
            self.rec_real_img, _, _ = self.net_gen(self.fake_img, self.src_aus)

    def backward_dis(self):
        # real image
        pred_real, self.pred_real_aus = self.net_dis(self.src_img)
        self.loss_dis_real = self.criterionGAN(pred_real, True)
        self.loss_dis_real_aus = self.criterionMSE(self.pred_real_aus, self.src_aus)

        # fake image, detached to stop gradients flowing back to the generator
        pred_fake, _ = self.net_dis(self.fake_img.detach())
        self.loss_dis_fake = self.criterionGAN(pred_fake, False)

        # combine dis loss
        self.loss_dis = self.opt.lambda_dis * (self.loss_dis_fake + self.loss_dis_real) \
                        + self.opt.lambda_aus * self.loss_dis_real_aus
        if self.opt.gan_type == 'wgan-gp':
            self.loss_dis_gp = self.gradient_penalty(self.src_img, self.fake_img)
            self.loss_dis = self.loss_dis + self.opt.lambda_wgan_gp * self.loss_dis_gp

        # backward discriminator loss
        self.loss_dis.backward()

    def backward_gen(self):
        # original to target domain, should fool the discriminator
        pred_fake, self.pred_fake_aus = self.net_dis(self.fake_img)
        self.loss_gen_GAN = self.criterionGAN(pred_fake, True)
        self.loss_gen_fake_aus = self.criterionMSE(self.pred_fake_aus, self.tar_aus)

        # target back to original domain, cycle reconstruction loss
        self.loss_gen_rec = self.criterionL1(self.rec_real_img, self.src_img)

        # combine and backward G loss
        self.loss_gen = self.opt.lambda_dis * self.loss_gen_GAN \
                        + self.opt.lambda_aus * self.loss_gen_fake_aus \
                        + self.opt.lambda_rec * self.loss_gen_rec

        self.loss_gen.backward()

    def optimize_paras(self, train_gen):
        self.forward()
        # update discriminator
        self.set_requires_grad(self.net_dis, True)
        self.optim_dis.zero_grad()
        self.backward_dis()
        self.optim_dis.step()

        # update generator if needed
        if train_gen:
            self.set_requires_grad(self.net_dis, False)
            self.optim_gen.zero_grad()
            self.backward_gen()
            self.optim_gen.step()

    def save_ckpt(self, epoch):
        # save the specified networks
        save_models_name = ['gen', 'dis']
        return super(StarGANModel, self).save_ckpt(epoch, save_models_name)

    def load_ckpt(self, epoch):
        # load the specified networks
        load_models_name = ['gen']
        if self.is_train:
            load_models_name.extend(['dis'])
        return super(StarGANModel, self).load_ckpt(epoch, load_models_name)

    def clean_ckpt(self, epoch):
        # remove the specified checkpoints
        load_models_name = ['gen', 'dis']
        return super(StarGANModel, self).clean_ckpt(epoch, load_models_name)

    def get_latest_losses(self):
        get_losses_name = ['dis_fake', 'dis_real', 'dis_real_aus', 'gen_rec']
        return super(StarGANModel, self).get_latest_losses(get_losses_name)

    def get_latest_visuals(self):
        visuals_name = ['src_img', 'tar_img', 'fake_img']
        if self.is_train:
            visuals_name.extend(['rec_real_img'])
        return super(StarGANModel, self).get_latest_visuals(visuals_name)
--------------------------------------------------------------------------------
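Written out, the two objectives assembled in backward_dis and backward_gen above are, for a source image $x$ with AU vector $y_{src}$ and target AUs $y_{tar}$:

$$\mathcal{L}_D = \lambda_{dis}\big(\mathcal{L}_{GAN}^{real} + \mathcal{L}_{GAN}^{fake}\big) + \lambda_{aus}\,\lVert D_{aus}(x) - y_{src}\rVert_2^2 \;\big({+}\;\lambda_{wgan\_gp}\,\mathcal{L}_{gp}\ \text{when gan\_type is wgan-gp}\big)$$

$$\mathcal{L}_G = \lambda_{dis}\,\mathcal{L}_{GAN}\big(D(G(x, y_{tar}))\big) + \lambda_{aus}\,\lVert D_{aus}(G(x, y_{tar})) - y_{tar}\rVert_2^2 + \lambda_{rec}\,\lVert G(G(x, y_{tar}), y_{src}) - x\rVert_1$$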
/model/ganimation.py:
--------------------------------------------------------------------------------
import torch
from .base_model import BaseModel
from . import model_utils


class GANimationModel(BaseModel):
    """docstring for GANimationModel"""
    def __init__(self):
        super(GANimationModel, self).__init__()
        self.name = "GANimation"

    def initialize(self, opt):
        super(GANimationModel, self).initialize(opt)

        self.net_gen = model_utils.define_splitG(self.opt.img_nc, self.opt.aus_nc, self.opt.ngf, use_dropout=self.opt.use_dropout,
                                                 norm=self.opt.norm, init_type=self.opt.init_type, init_gain=self.opt.init_gain, gpu_ids=self.gpu_ids)
        self.models_name.append('gen')

        if self.is_train:
            self.net_dis = model_utils.define_splitD(self.opt.img_nc, self.opt.aus_nc, self.opt.final_size, self.opt.ndf,
                                                     norm=self.opt.norm, init_type=self.opt.init_type, init_gain=self.opt.init_gain, gpu_ids=self.gpu_ids)
            self.models_name.append('dis')

        if self.opt.load_epoch > 0:
            self.load_ckpt(self.opt.load_epoch)

    def setup(self):
        super(GANimationModel, self).setup()
        if self.is_train:
            # setup optimizers
            self.optim_gen = torch.optim.Adam(self.net_gen.parameters(),
                                              lr=self.opt.lr, betas=(self.opt.beta1, 0.999))
            self.optims.append(self.optim_gen)
            self.optim_dis = torch.optim.Adam(self.net_dis.parameters(),
                                              lr=self.opt.lr, betas=(self.opt.beta1, 0.999))
            self.optims.append(self.optim_dis)

            # setup schedulers
            self.schedulers = [model_utils.get_scheduler(optim, self.opt) for optim in self.optims]

    def feed_batch(self, batch):
        self.src_img = batch['src_img'].to(self.device)
        self.tar_aus = batch['tar_aus'].type(torch.FloatTensor).to(self.device)
        if self.is_train:
            self.src_aus = batch['src_aus'].type(torch.FloatTensor).to(self.device)
            self.tar_img = batch['tar_img'].to(self.device)

    def forward(self):
        # generate fake image by blending the color mask with the source through the AUs attention mask
        self.color_mask, self.aus_mask, self.embed = self.net_gen(self.src_img, self.tar_aus)
        self.fake_img = self.aus_mask * self.src_img + (1 - self.aus_mask) * self.color_mask

        # reconstruct real image
        if self.is_train:
            self.rec_color_mask, self.rec_aus_mask, self.rec_embed = self.net_gen(self.fake_img, self.src_aus)
            self.rec_real_img = self.rec_aus_mask * self.fake_img + (1 - self.rec_aus_mask) * self.rec_color_mask

    def backward_dis(self):
        # real image
        pred_real, self.pred_real_aus = self.net_dis(self.src_img)
        self.loss_dis_real = self.criterionGAN(pred_real, True)
        self.loss_dis_real_aus = self.criterionMSE(self.pred_real_aus, self.src_aus)

        # fake image, detached to stop gradients flowing back to the generator
        pred_fake, _ = self.net_dis(self.fake_img.detach())
        self.loss_dis_fake = self.criterionGAN(pred_fake, False)

        # combine dis loss
        self.loss_dis = self.opt.lambda_dis * (self.loss_dis_fake + self.loss_dis_real) \
                        + self.opt.lambda_aus * self.loss_dis_real_aus
        if self.opt.gan_type == 'wgan-gp':
            self.loss_dis_gp = self.gradient_penalty(self.src_img, self.fake_img)
            self.loss_dis = self.loss_dis + self.opt.lambda_wgan_gp * self.loss_dis_gp

        # backward discriminator loss
        self.loss_dis.backward()

    def backward_gen(self):
        # original to target domain, should fool the discriminator
        pred_fake, self.pred_fake_aus = self.net_dis(self.fake_img)
        self.loss_gen_GAN = self.criterionGAN(pred_fake, True)
        self.loss_gen_fake_aus = self.criterionMSE(self.pred_fake_aus, self.tar_aus)

        # target back to original domain, cycle reconstruction loss
        self.loss_gen_rec = self.criterionL1(self.rec_real_img, self.src_img)

        # constraints on the AUs attention masks: sparsity and smoothness
        self.loss_gen_mask_real_aus = torch.mean(self.aus_mask)
        self.loss_gen_mask_fake_aus = torch.mean(self.rec_aus_mask)
        self.loss_gen_smooth_real_aus = self.criterionTV(self.aus_mask)
        self.loss_gen_smooth_fake_aus = self.criterionTV(self.rec_aus_mask)

        # combine and backward G loss
        self.loss_gen = self.opt.lambda_dis * self.loss_gen_GAN \
                        + self.opt.lambda_aus * self.loss_gen_fake_aus \
                        + self.opt.lambda_rec * self.loss_gen_rec \
                        + self.opt.lambda_mask * (self.loss_gen_mask_real_aus + self.loss_gen_mask_fake_aus) \
                        + self.opt.lambda_tv * (self.loss_gen_smooth_real_aus + self.loss_gen_smooth_fake_aus)

        self.loss_gen.backward()

    def optimize_paras(self, train_gen):
        self.forward()
        # update discriminator
        self.set_requires_grad(self.net_dis, True)
        self.optim_dis.zero_grad()
        self.backward_dis()
        self.optim_dis.step()

        # update generator if needed
        if train_gen:
            self.set_requires_grad(self.net_dis, False)
            self.optim_gen.zero_grad()
            self.backward_gen()
            self.optim_gen.step()

    def save_ckpt(self, epoch):
        # save the specified networks
        save_models_name = ['gen', 'dis']
        return super(GANimationModel, self).save_ckpt(epoch, save_models_name)

    def load_ckpt(self, epoch):
        # load the specified networks
        load_models_name = ['gen']
        if self.is_train:
            load_models_name.extend(['dis'])
        return super(GANimationModel, self).load_ckpt(epoch, load_models_name)

    def clean_ckpt(self, epoch):
        # remove the specified checkpoints
        load_models_name = ['gen', 'dis']
        return super(GANimationModel, self).clean_ckpt(epoch, load_models_name)

    def get_latest_losses(self):
        get_losses_name = ['dis_fake', 'dis_real', 'dis_real_aus', 'gen_rec']
        return super(GANimationModel, self).get_latest_losses(get_losses_name)

    def get_latest_visuals(self):
        visuals_name = ['src_img', 'tar_img', 'color_mask', 'aus_mask', 'fake_img']
        if self.is_train:
            visuals_name.extend(['rec_color_mask', 'rec_aus_mask', 'rec_real_img'])
        return super(GANimationModel, self).get_latest_visuals(visuals_name)
--------------------------------------------------------------------------------
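GANimationModel differs from StarGANModel above mainly in forward(): the split generator returns a color map $C$ and an attention (AUs) mask $M$, and the output is blended with the input rather than generated from scratch:

$$I_{fake} = M \odot I_{src} + (1 - M) \odot C$$

The same blending is applied again for the cycle reconstruction, and backward_gen adds $\lambda_{mask}\,(\mathbb{E}[M] + \mathbb{E}[M_{rec}])$ and $\lambda_{tv}\,(TV(M) + TV(M_{rec}))$ on top of the StarGAN objective to keep the attention masks sparse and smooth.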
/solvers.py:
--------------------------------------------------------------------------------
"""
Created on Dec 13, 2018
@author: Yuedong Chen
"""

from data import create_dataloader
from model import create_model
from visualizer import Visualizer
import time
import os
import torch
import numpy as np
from PIL import Image


def create_solver(opt):
    instance = Solver()
    instance.initialize(opt)
    return instance


class Solver(object):
    """docstring for Solver"""
    def __init__(self):
        super(Solver, self).__init__()

    def initialize(self, opt):
        self.opt = opt
        self.visual = Visualizer()
        self.visual.initialize(self.opt)

    def run_solver(self):
        if self.opt.mode == "train":
            self.train_networks()
        else:
            self.test_networks(self.opt)

    def train_networks(self):
        # init train setting
        self.init_train_setting()

        # for every epoch
        for epoch in range(self.opt.epoch_count, self.epoch_len + 1):
            # train network
            self.train_epoch(epoch)
            # update learning rate
            self.cur_lr = self.train_model.update_learning_rate()
            # save checkpoint if needed
            if epoch % self.opt.save_epoch_freq == 0:
                self.train_model.save_ckpt(epoch)

        # save the last epoch
        self.train_model.save_ckpt(self.epoch_len)

    def init_train_setting(self):
        self.train_dataset = create_dataloader(self.opt)
        self.train_model = create_model(self.opt)

        self.train_total_steps = 0
        self.epoch_len = self.opt.niter + self.opt.niter_decay
        self.cur_lr = self.opt.lr

    def train_epoch(self, epoch):
        epoch_steps = 0
        last_print_step_t = time.time()
        for idx, batch in enumerate(self.train_dataset):
            self.train_total_steps += self.opt.batch_size
            epoch_steps += self.opt.batch_size
            # train network
            self.train_model.feed_batch(batch)
            self.train_model.optimize_paras(train_gen=(idx % self.opt.train_gen_iter == 0))
            # print losses
            if self.train_total_steps % self.opt.print_losses_freq == 0:
                cur_losses = self.train_model.get_latest_losses()
                avg_step_t = (time.time() - last_print_step_t) / self.opt.print_losses_freq
                last_print_step_t = time.time()
                # print loss info to command line
                info_dict = {'epoch': epoch, 'epoch_len': self.epoch_len,
                             'epoch_steps': idx * self.opt.batch_size, 'epoch_steps_len': len(self.train_dataset),
                             'step_time': avg_step_t, 'cur_lr': self.cur_lr,
                             'log_path': os.path.join(self.opt.ckpt_dir, self.opt.log_file),
                             'losses': cur_losses}
                self.visual.print_losses_info(info_dict)

            # plot loss map to visdom
            if self.train_total_steps % self.opt.plot_losses_freq == 0 and self.visual.display_id > 0:
                cur_losses = self.train_model.get_latest_losses()
                epoch_steps = idx * self.opt.batch_size
                self.visual.display_current_losses(epoch - 1, epoch_steps / len(self.train_dataset), cur_losses)

            # display images on visdom
            if self.train_total_steps % self.opt.sample_img_freq == 0 and self.visual.display_id > 0:
                cur_vis = self.train_model.get_latest_visuals()
                self.visual.display_online_results(cur_vis, epoch)

    def test_networks(self, opt):
        self.init_test_setting(opt)
        self.test_ops()

    def init_test_setting(self, opt):
        self.test_dataset = create_dataloader(opt)
        self.test_model = create_model(opt)

    def test_ops(self):
        for batch_idx, batch in enumerate(self.test_dataset):
            with torch.no_grad():
                # interpolate the target AUs several times
                faces_list = [batch['src_img'].float().numpy()]
                paths_list = [batch['src_path'], batch['tar_path']]
                for idx in range(self.opt.interpolate_len):
                    cur_alpha = (idx + 1.) / float(self.opt.interpolate_len)
                    cur_tar_aus = cur_alpha * batch['tar_aus'] + (1 - cur_alpha) * batch['src_aus']
                    test_batch = {'src_img': batch['src_img'], 'tar_aus': cur_tar_aus,
                                  'src_aus': batch['src_aus'], 'tar_img': batch['tar_img']}

                    self.test_model.feed_batch(test_batch)
                    self.test_model.forward()

                    cur_gen_faces = self.test_model.fake_img.cpu().float().numpy()
                    faces_list.append(cur_gen_faces)
                faces_list.append(batch['tar_img'].float().numpy())
            self.test_save_imgs(faces_list, paths_list)

    def test_save_imgs(self, faces_list, paths_list):
        for idx in range(len(paths_list[0])):
            src_name = os.path.splitext(os.path.basename(paths_list[0][idx]))[0]
            tar_name = os.path.splitext(os.path.basename(paths_list[1][idx]))[0]

            if self.opt.save_test_gif:
                import imageio
                imgs_numpy_list = []
                for face_idx in range(len(faces_list) - 1):  # skip the target image
                    cur_numpy = np.array(self.visual.numpy2im(faces_list[face_idx][idx]))
                    imgs_numpy_list.extend([cur_numpy for _ in range(3)])  # repeat each frame to slow the GIF down
                saved_path = os.path.join(self.opt.results, "%s_%s.gif" % (src_name, tar_name))
                imageio.mimsave(saved_path, imgs_numpy_list)
            else:
                # concatenate src, interpolated and tar faces into one strip
                concate_img = np.array(self.visual.numpy2im(faces_list[0][idx]))
                for face_idx in range(1, len(faces_list)):
                    concate_img = np.concatenate((concate_img, np.array(self.visual.numpy2im(faces_list[face_idx][idx]))), axis=1)
                concate_img = Image.fromarray(concate_img)
                # save image
                saved_path = os.path.join(self.opt.results, "%s_%s.jpg" % (src_name, tar_name))
                concate_img.save(saved_path)

            print("[Success] Saved images to %s" % saved_path)
--------------------------------------------------------------------------------
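At test time, test_ops sweeps an expression by linearly interpolating the AU vectors; for step $t = 1, \dots, \text{interpolate\_len}$:

$$y_t = \alpha_t\, y_{tar} + (1 - \alpha_t)\, y_{src}, \qquad \alpha_t = t \,/\, \text{interpolate\_len}$$

so the saved strip (or GIF, when opt.save_test_gif is set) morphs from the source expression to the target one.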
/model/base_model.py:
--------------------------------------------------------------------------------
import torch
import os
from collections import OrderedDict
from . import model_utils


class BaseModel:
    """docstring for BaseModel"""
    def __init__(self):
        super(BaseModel, self).__init__()
        self.name = "Base"

    def initialize(self, opt):
        self.opt = opt
        self.gpu_ids = self.opt.gpu_ids
        self.device = torch.device('cuda:%d' % self.gpu_ids[0] if self.gpu_ids else 'cpu')
        self.is_train = self.opt.mode == "train"
        # inherit to define network models
        self.models_name = []

    def setup(self):
        print("%s with Model [%s]" % (self.opt.mode.capitalize(), self.name))
        if self.is_train:
            self.set_train()
            # define loss functions
            self.criterionGAN = model_utils.GANLoss(gan_type=self.opt.gan_type).to(self.device)
            self.criterionL1 = torch.nn.L1Loss().to(self.device)
            self.criterionMSE = torch.nn.MSELoss().to(self.device)
            self.criterionTV = model_utils.TVLoss().to(self.device)
            # inherit to set up train/val/test status
            self.losses_name = []
            self.optims = []
            self.schedulers = []
        else:
            self.set_eval()

    def set_eval(self):
        print("Set model to Test state.")
        for name in self.models_name:
            if isinstance(name, str):
                net = getattr(self, 'net_' + name)
                if not self.opt.no_test_eval:
                    net.eval()
                    print("Set net_%s to EVAL." % name)
                else:
                    net.train()
        self.is_train = False

    def set_train(self):
        print("Set model to Train state.")
        for name in self.models_name:
            if isinstance(name, str):
                net = getattr(self, 'net_' + name)
                net.train()
                print("Set net_%s to TRAIN." % name)
        self.is_train = True

    def set_requires_grad(self, nets, requires_grad=False):
        if not isinstance(nets, list):
            nets = [nets]
        for net in nets:
            if net is not None:
                for param in net.parameters():
                    param.requires_grad = requires_grad

    def get_latest_visuals(self, visuals_name):
        visual_ret = OrderedDict()
        for name in visuals_name:
            if isinstance(name, str) and hasattr(self, name):
                visual_ret[name] = getattr(self, name)
        return visual_ret

    def get_latest_losses(self, losses_name):
        errors_ret = OrderedDict()
        for name in losses_name:
            if isinstance(name, str):
                errors_ret[name] = float(getattr(self, 'loss_' + name))
        return errors_ret

    def feed_batch(self, batch):
        pass

    def forward(self):
        pass

    def optimize_paras(self):
        pass

    def update_learning_rate(self):
        for scheduler in self.schedulers:
            scheduler.step()
        lr = self.optims[0].param_groups[0]['lr']
        return lr

    def save_ckpt(self, epoch, models_name):
        for name in models_name:
            if isinstance(name, str):
                save_filename = '%s_net_%s.pth' % (epoch, name)
                save_path = os.path.join(self.opt.ckpt_dir, save_filename)
                net = getattr(self, 'net_' + name)
                # save CPU params, so the checkpoint can be loaded under other GPU settings
                if len(self.gpu_ids) > 0 and torch.cuda.is_available():
                    torch.save(net.module.cpu().state_dict(), save_path)
                    net.to(self.gpu_ids[0])
                    net = torch.nn.DataParallel(net, self.gpu_ids)
                else:
                    torch.save(net.cpu().state_dict(), save_path)

    def load_ckpt(self, epoch, models_name):
        for name in models_name:
            if isinstance(name, str):
                load_filename = '%s_net_%s.pth' % (epoch, name)
                load_path = os.path.join(self.opt.ckpt_dir, load_filename)
                assert os.path.isfile(load_path), "File '%s' does not exist." % load_path

                pretrained_state_dict = torch.load(load_path, map_location=str(self.device))
                if hasattr(pretrained_state_dict, '_metadata'):
                    del pretrained_state_dict._metadata

                net = getattr(self, 'net_' + name)
                if isinstance(net, torch.nn.DataParallel):
                    net = net.module
                # load only keys that exist in the current network
                pretrained_dict = {k: v for k, v in pretrained_state_dict.items() if k in net.state_dict()}
                net.load_state_dict(pretrained_dict)
                print("[Info] Successfully loaded trained weights for net_%s." % name)

    def clean_ckpt(self, epoch, models_name):
        for name in models_name:
            if isinstance(name, str):
                load_filename = '%s_net_%s.pth' % (epoch, name)
                load_path = os.path.join(self.opt.ckpt_dir, load_filename)
                if os.path.isfile(load_path):
                    os.remove(load_path)

    def gradient_penalty(self, input_img, generate_img):
        # interpolate between real and generated samples
        alpha = torch.rand(input_img.size(0), 1, 1, 1).to(self.device)
        inter_img = (alpha * input_img.data + (1 - alpha) * generate_img.data).requires_grad_(True)
        inter_img_prob, _ = self.net_dis(inter_img)

        # compute gradient penalty: x: inter_img, y: inter_img_prob
        # (L2_norm(dy/dx) - 1)**2
        dydx = torch.autograd.grad(outputs=inter_img_prob,
                                   inputs=inter_img,
                                   grad_outputs=torch.ones(inter_img_prob.size()).to(self.device),
                                   retain_graph=True,
                                   create_graph=True,
                                   only_inputs=True)[0]
        dydx = dydx.view(dydx.size(0), -1)
        dydx_l2norm = torch.sqrt(torch.sum(dydx ** 2, dim=1))
        return torch.mean((dydx_l2norm - 1) ** 2)
--------------------------------------------------------------------------------
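For gan_type 'wgan-gp', gradient_penalty implements the WGAN-GP regularizer used by both models: a sample $\hat{x}$ is drawn on the line between a real image $x$ and a generated one $\tilde{x}$, and the discriminator's gradient norm at $\hat{x}$ is pulled towards 1:

$$\hat{x} = \alpha x + (1 - \alpha)\,\tilde{x}, \quad \alpha \sim U[0, 1], \qquad \mathcal{L}_{gp} = \mathbb{E}\big[(\lVert \nabla_{\hat{x}} D(\hat{x}) \rVert_2 - 1)^2\big]$$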
/README.md:
--------------------------------------------------------------------------------
# GANimation -- An Out-of-the-Box Replicate

(Demo images and GIFs from imgs/ -- ganimation_show.jpg, the training/testing previews, and the expression GIFs -- are embedded here.)
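## Quick start (sketch)

options.py defines the actual CLI and is not reproduced above, so the flag names below are assumptions inferred from the `opt.*` attributes used in the code (`mode`, `data_root`, `model`, `gpu_ids`, `load_epoch`); treat this as a sketch rather than the exact interface:

```
# train GANimation (or --model stargan) on a CelebA-style folder
python main.py --mode train --data_root datasets/celebA --model ganimation --gpu_ids 0

# test from a saved checkpoint; writes interpolation strips or GIFs to the results dir
python main.py --mode test --data_root datasets/celebA --model ganimation --load_epoch 30
```

Training can be monitored with visdom when `visdom_display_id > 0`; start a server first with `python -m visdom.server`.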