├── audio2landmark
│   ├── __init__.py
│   ├── utils.py
│   ├── loss.py
│   ├── APBDataset.py
│   ├── main.py
│   ├── APBNet.py
│   └── APBGAN.py
├── landmark2face
│   ├── APB
│   │   ├── __init__.py
│   │   ├── APBDataset.py
│   │   └── APBNet.py
│   ├── experiments
│   │   ├── test.sh
│   │   └── train.sh
│   ├── util
│   │   ├── __init__.py
│   │   ├── image_pool.py
│   │   ├── util.py
│   │   ├── html.py
│   │   ├── get_data.py
│   │   └── visualizer.py
│   ├── options
│   │   ├── __init__.py
│   │   ├── test_options.py
│   │   ├── train_options.py
│   │   └── base_options.py
│   ├── data
│   │   ├── image_folder.py
│   │   ├── l2face_dataset.py
│   │   ├── __init__.py
│   │   └── base_dataset.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── l2face_model.py
│   │   ├── networks_l2face.py
│   │   ├── base_model.py
│   │   └── networks.py
│   ├── test.py
│   └── train.py
├── requirements.txt
├── LICENSE
├── .gitignore
└── README.md

--------------------------------------------------------------------------------
/audio2landmark/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/landmark2face/APB/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
torch==1.3.1
torchvision==0.4.2
opencv-python==4.1.2.30

--------------------------------------------------------------------------------
/landmark2face/experiments/test.sh:
--------------------------------------------------------------------------------
#!/usr/bin/env bash

~/anaconda3/bin/python3 test.py

--------------------------------------------------------------------------------
/landmark2face/util/__init__.py:
--------------------------------------------------------------------------------
"""This package includes a miscellaneous collection of useful helper functions."""

--------------------------------------------------------------------------------
/landmark2face/options/__init__.py:
--------------------------------------------------------------------------------
"""This package options includes option modules: training options, test options, and basic options (used in both training and test)."""

--------------------------------------------------------------------------------
/landmark2face/experiments/train.sh:
--------------------------------------------------------------------------------
#!/usr/bin/env bash

~/anaconda3/bin/python3 train.py --name man1_Res9 \
                                 --gpu_ids 0 \
                                 --batch_size 12 \
                                 --img_size 256 \
                                 --model l2face \
                                 --dataset_mode l2face \
                                 --netG resnet_9blocks_l2face

--------------------------------------------------------------------------------
/audio2landmark/utils.py:
--------------------------------------------------------------------------------
import torch.nn.init as init


def adjust_learning_rate(optimizer, lr, epoch, every=100):
    """Set the learning rate to the initial LR decayed by 10 every `every` epochs."""
    lr = lr * (0.1 ** (epoch // every))
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr


def weight_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        init.xavier_normal_(m.weight.data, gain=1)
        if m.bias is not None:
            init.constant_(m.bias.data, 0.0)
        # print('Conv')
    elif classname.find('Linear') != -1:
        init.xavier_normal_(m.weight.data, gain=1)
        if m.bias is not None:
            init.constant_(m.bias.data, 0.0)
        # print('Linear')
    elif classname.find('FusePool_zjn') != -1:
        init.constant_(m.weight.data, 1.0)
        # print('FusePool_zjn')
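A minimal usage sketch for the two helpers above (an editorial illustration, not code from this repo; the toy `net` and hyper-parameters are hypothetical, and the import assumes audio2landmark/ is the working directory):

    import torch
    import torch.nn as nn
    from utils import adjust_learning_rate, weight_init

    net = nn.Sequential(nn.Linear(8, 8))  # stand-in for APBNet
    net.apply(weight_init)                # Xavier-init every Conv/Linear module
    optimizer = torch.optim.Adam(net.parameters(), lr=3e-4)

    for epoch in range(800):
        # decay the stored LR by 10x every 300 epochs, as main.py's --every flag does
        adjust_learning_rate(optimizer, lr=3e-4, epoch=epoch, every=300)
        # ... run one training epoch ...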
--------------------------------------------------------------------------------
/audio2landmark/loss.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn


class GANLoss(nn.Module):
    def __init__(self, gan_mode, target_real_label=1.0, target_fake_label=0.0):
        super(GANLoss, self).__init__()
        self.register_buffer('real_label', torch.tensor(target_real_label))
        self.register_buffer('fake_label', torch.tensor(target_fake_label))
        self.gan_mode = gan_mode
        if gan_mode == 'mse':
            self.loss = nn.MSELoss()

    def get_target_tensor(self, prediction, target_is_real):
        if target_is_real:
            target_tensor = self.real_label
        else:
            target_tensor = self.fake_label
        return target_tensor.expand_as(prediction)

    def __call__(self, prediction, target_is_real):
        if self.gan_mode in ['mse']:
            target_tensor = self.get_target_tensor(prediction, target_is_real)
            loss = self.loss(prediction, target_tensor)
        return loss

--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2020 zhangzjn

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

--------------------------------------------------------------------------------
/landmark2face/options/test_options.py:
--------------------------------------------------------------------------------
from .base_options import BaseOptions


class TestOptions(BaseOptions):
    """This class includes test options.

    It also includes shared options defined in BaseOptions.
    """

    def initialize(self, parser):
        parser = BaseOptions.initialize(self, parser)  # define shared options
        parser.add_argument('--ntest', type=int, default=float("inf"), help='# of test examples.')
        parser.add_argument('--results_dir', type=str, default='./results/', help='saves results here.')
        parser.add_argument('--aspect_ratio', type=float, default=1.0, help='aspect ratio of result images')
        parser.add_argument('--phase', type=str, default='test', help='train, val, test, etc')
        # Dropout and BatchNorm behave differently during training and test.
        parser.add_argument('--eval', action='store_true', help='use eval mode during test time.')
        parser.add_argument('--num_test', type=int, default=50, help='how many test images to run')
        # rewrite default values
        # parser.set_defaults(model='test')
        # To avoid cropping, the load_size should be the same as crop_size
        # parser.set_defaults(load_size=parser.get_default('crop_size'))
        self.isTrain = False
        return parser
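For orientation, the GANLoss defined in audio2landmark/loss.py above is applied to raw discriminator outputs; a hedged sketch (`pred_real` and `pred_fake` are hypothetical stand-in tensors, and the import assumes audio2landmark/ is the working directory):

    import torch
    from loss import GANLoss

    criterion = GANLoss(gan_mode='mse')   # LSGAN-style least-squares objective
    pred_real = torch.randn(4, 1)         # stand-ins for discriminator outputs
    pred_fake = torch.randn(4, 1)

    loss_D = 0.5 * (criterion(pred_real, True) + criterion(pred_fake, False))
    loss_G = criterion(pred_fake, True)   # the generator wants fakes scored as real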
--------------------------------------------------------------------------------
/landmark2face/data/image_folder.py:
--------------------------------------------------------------------------------
"""A modified image folder class

We modify the official PyTorch image folder (https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py)
so that this class can load images from both the current directory and its subdirectories.
"""

import torch.utils.data as data

from PIL import Image
import os
import os.path

IMG_EXTENSIONS = [
    '.jpg', '.JPG', '.jpeg', '.JPEG',
    '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
]


def is_image_file(filename):
    return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)


def make_dataset(dir, max_dataset_size=float("inf")):
    images = []
    assert os.path.isdir(dir), '%s is not a valid directory' % dir

    for root, _, fnames in sorted(os.walk(dir)):
        for fname in fnames:
            if is_image_file(fname):
                path = os.path.join(root, fname)
                images.append(path)
    return images[:min(max_dataset_size, len(images))]


def default_loader(path):
    return Image.open(path).convert('RGB')


class ImageFolder(data.Dataset):

    def __init__(self, root, transform=None, return_paths=False,
                 loader=default_loader):
        imgs = make_dataset(root)
        if len(imgs) == 0:
            raise(RuntimeError("Found 0 images in: " + root + "\n"
                               "Supported image extensions are: " +
                               ",".join(IMG_EXTENSIONS)))

        self.root = root
        self.imgs = imgs
        self.transform = transform
        self.return_paths = return_paths
        self.loader = loader

    def __getitem__(self, index):
        path = self.imgs[index]
        img = self.loader(path)
        if self.transform is not None:
            img = self.transform(img)
        if self.return_paths:
            return img, path
        else:
            return img

    def __len__(self):
        return len(self.imgs)

--------------------------------------------------------------------------------
/audio2landmark/APBDataset.py:
--------------------------------------------------------------------------------
import random
from torch.utils.data import dataset
import torch
import os
import torchvision.transforms as transforms
import numpy as np


class APBDataset(dataset.Dataset):
    def __init__(self, root, idt_name='man1', mode='train', img_size=256):
        self.root = root
        self.idt_name = idt_name
        if not isinstance(mode, list):
            mode = [mode]

        self.data_all = list()
        for m in mode:
            training_data_path = os.path.join(self.root, self.idt_name, '{}_{}.t7'.format(img_size, m))
            training_data = torch.load(training_data_path)
            img_paths = training_data['img_paths']
            audio_features = training_data['audio_features']
            lands = training_data['lands']
            poses = training_data['poses']
            eyes = training_data['eyes']
            for i in range(len(img_paths)):
                img_path = [os.path.join(self.root, self.idt_name, p) for p in img_paths[i]]  # [image, landmark]
                audio_feature = audio_features[i]
                land = lands[i]
                pose = poses[i]
                eye = eyes[i]
                self.data_all.append([img_path, audio_feature, land, pose, eye])
        self.data_all.sort(key=lambda x: int(x[0][0].split('/')[-1].split('.')[0]))
        if 'train' in mode and len(mode) == 1:
            self.shuffle()

    def shuffle(self):
        random.shuffle(self.data_all)

    def __len__(self):
        return len(self.data_all)

    def __getitem__(self, index):
        img_path_A1, audio_feature_A1, land_A1, pose_A1, eye_A1 = self.data_all[index]
        img_path_A2, audio_feature_A2, land_A2, pose_A2, eye_A2 = random.sample(self.data_all, 1)[0]
        # audio
        audio_feature_A1 = torch.tensor(audio_feature_A1).unsqueeze(dim=0)
        # pose
        pose_A1 = torch.tensor(pose_A1)
        # eye
        eye_A1 = torch.tensor(eye_A1)
        # landmark
        land_A1 = torch.tensor(land_A1)
        land_A2 = torch.tensor(land_A2)

        return [audio_feature_A1, pose_A1, eye_A1], [land_A1, land_A2]


if __name__ == '__main__':
    root = '/media/datasets/zhangzjn/AnnVI/feature'
    idt_name = 'man1'
    trainset = APBDataset(root, idt_name)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True, num_workers=4)
    for batch_idx, _ in enumerate(trainloader):
        print(batch_idx)
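Each sample from the APBDataset above unpacks into an (inputs, targets) pair; a sketch of consuming the loader (tensor shapes depend on the AnnVI .t7 feature files, which are not part of this dump):

    # Hypothetical iteration over the `trainloader` defined above.
    for (audio, pose, eye), (land_gt, land_rand) in trainloader:
        # audio: batched audio feature with a channel dim added by unsqueeze(dim=0)
        # pose: 3-value head pose; eye: 2-value blink signal
        #       (cf. APBNet's Linear(3, 64) and Linear(2, 64) input branches)
        # land_gt / land_rand: flattened landmark vectors for the indexed frame
        #       and a randomly sampled frame
        break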
--------------------------------------------------------------------------------
/landmark2face/APB/APBDataset.py:
--------------------------------------------------------------------------------
import random
from torch.utils.data import dataset
import torch
import os
import torchvision.transforms as transforms
import numpy as np


class APBDataset(dataset.Dataset):
    def __init__(self, root, idt_name='man1', mode='train', img_size=256):
        self.root = root
        self.idt_name = idt_name
        if not isinstance(mode, list):
            mode = [mode]

        self.data_all = list()
        for m in mode:
            training_data_path = os.path.join(self.root, self.idt_name, '{}_{}.t7'.format(img_size, m))
            training_data = torch.load(training_data_path)
            img_paths = training_data['img_paths']
            audio_features = training_data['audio_features']
            lands = training_data['lands']
            poses = training_data['poses']
            eyes = training_data['eyes']
            for i in range(len(img_paths)):
                img_path = [os.path.join(self.root, self.idt_name, p) for p in img_paths[i]]  # [image, landmark]
                audio_feature = audio_features[i]
                land = lands[i]
                pose = poses[i]
                eye = eyes[i]
                self.data_all.append([img_path, audio_feature, land, pose, eye])
        self.data_all.sort(key=lambda x: int(x[0][0].split('/')[-1].split('.')[0]))
        if 'train' in mode and len(mode) == 1:
            self.shuffle()

    def shuffle(self):
        random.shuffle(self.data_all)

    def __len__(self):
        return len(self.data_all)

    def __getitem__(self, index):
        img_path_A1, audio_feature_A1, land_A1, pose_A1, eye_A1 = self.data_all[index]
        img_path_A2, audio_feature_A2, land_A2, pose_A2, eye_A2 = random.sample(self.data_all, 1)[0]
        # audio
        audio_feature_A1 = torch.tensor(audio_feature_A1).unsqueeze(dim=0)
        # pose
        pose_A1 = torch.tensor(pose_A1)
        # eye
        eye_A1 = torch.tensor(eye_A1)
        # landmark
        land_A1 = torch.tensor(land_A1)
        land_A2 = torch.tensor(land_A2)
        return [audio_feature_A1, pose_A1, eye_A1], [land_A1, land_A2], [img_path_A1, img_path_A2]


if __name__ == '__main__':
    root = '/media/datasets/zhangzjn/AnnVI/feature'
    idt_name = 'man1'
    trainset = APBDataset(root, idt_name)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True, num_workers=4)
    for batch_idx, _ in enumerate(trainloader):
        print(batch_idx)

--------------------------------------------------------------------------------
/landmark2face/util/image_pool.py:
--------------------------------------------------------------------------------
import random
import torch


class ImagePool():
    """This class implements an image buffer that stores previously generated images.

    This buffer enables us to update discriminators using a history of generated images
    rather than the ones produced by the latest generators.
    """

    def __init__(self, pool_size):
        """Initialize the ImagePool class

        Parameters:
            pool_size (int) -- the size of the image buffer; if pool_size=0, no buffer will be created
        """
        self.pool_size = pool_size
        if self.pool_size > 0:  # create an empty pool
            self.num_imgs = 0
            self.images = []

    def query(self, images):
        """Return an image from the pool.

        Parameters:
            images: the latest generated images from the generator

        Returns images from the buffer.

        With 50% probability, the buffer returns the input images.
        With 50% probability, the buffer returns images previously stored in the buffer,
        and inserts the current images into the buffer.
        """
        if self.pool_size == 0:  # if the buffer size is 0, do nothing
            return images
        return_images = []
        for image in images:
            image = torch.unsqueeze(image.data, 0)
            if self.num_imgs < self.pool_size:  # if the buffer is not full, keep inserting current images into the buffer
                self.num_imgs = self.num_imgs + 1
                self.images.append(image)
                return_images.append(image)
            else:
                p = random.uniform(0, 1)
                if p > 0.5:  # by 50% chance, the buffer will return a previously stored image, and insert the current image into the buffer
                    random_id = random.randint(0, self.pool_size - 1)  # randint is inclusive
                    tmp = self.images[random_id].clone()
                    self.images[random_id] = image
                    return_images.append(tmp)
                else:  # by another 50% chance, the buffer will return the current image
                    return_images.append(image)
        return_images = torch.cat(return_images, 0)  # collect all the images and return
        return return_images

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
audio2landmark/checkpoints
landmark2face/APB/man1_best.pth
landmark2face/checkpoints
landmark2face/metrics
landmark2face/result
AnnVI
.idea

# Created by .ignore support plugin (hsz.mobi)
### Python template
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
#  Usually these files are written by a python script from a template
#  before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# celery beat schedule file
celerybeat-schedule

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/
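The ImagePool in landmark2face/util/image_pool.py above is normally queried between the generator forward pass and the discriminator update; a hedged sketch (`fake_B` and `netD` are hypothetical stand-ins):

    pool = ImagePool(pool_size=50)
    # fake_B: the generator's current batch of fake images (hypothetical tensor)
    fake_for_D = pool.query(fake_B)        # mix current fakes with stored past fakes
    pred_fake = netD(fake_for_D.detach())  # D trains against a history of generated images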
--------------------------------------------------------------------------------
/landmark2face/data/l2face_dataset.py:
--------------------------------------------------------------------------------
import os.path
import random
from data.base_dataset import BaseDataset, get_params, get_transform
import torchvision.transforms as transforms
from data.image_folder import make_dataset
from PIL import Image
import numpy as np


class L2FaceDataset(BaseDataset):
    def __init__(self, opt):
        BaseDataset.__init__(self, opt)
        img_size = opt.img_size
        root = '../AnnVI/feature/{}'.format(opt.name.split('_')[0])
        image_dir = '{}/{}_image_crop'.format(root, img_size)
        label_dir = '{}/{}_landmark_crop_thin'.format(root, img_size)
        # label_dir = '{}/512_landmark_crop'.format(root)
        self.labels = []

        imgs = os.listdir(image_dir)
        # if 'man' in opt.name:
        #     imgs.sort(key=lambda x: int(x.split('.')[0]))
        # else:
        #     imgs.sort(key=lambda x: (int(x.split('.')[0].split('-')[0]), int(x.split('.')[0].split('-')[1])))
        for img in imgs:
            img_path = os.path.join(image_dir, img)
            lab_path = os.path.join(label_dir, img)
            if os.path.exists(lab_path):
                self.labels.append([img_path, lab_path])
        # transforms.Resize([img_size, img_size], Image.BICUBIC),
        self.transforms_image = transforms.Compose([transforms.ToTensor(),
                                                    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
        # transforms.Resize([img_size, img_size], Image.BICUBIC),
        self.transforms_label = transforms.Compose([transforms.ToTensor(),
                                                    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
        self.shuffle()

    def shuffle(self):
        random.shuffle(self.labels)

    def __getitem__(self, index):
        img_path, lab_path = self.labels[index]
        img = Image.open(img_path).convert('RGB')
        lab = Image.open(lab_path).convert('RGB')
        img = self.transforms_image(img)
        lab = self.transforms_label(lab)

        imgA_path, labA_path = random.sample(self.labels, 1)[0]
        imgA = Image.open(imgA_path).convert('RGB')
        imgA = self.transforms_image(imgA)

        return {'A': imgA, 'A_label': lab, 'B': img, 'B_label': lab}

    def __len__(self):
        """Return the total number of images in the dataset."""
        return len(self.labels)


if __name__ == '__main__':
    from options.train_options import TrainOptions
    opt = TrainOptions().parse()
    dataset = L2FaceDataset(opt)
    dataset_size = len(dataset)
    print(dataset_size)
    for i, data in enumerate(dataset):
        print(data)

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
## APB2Face — Official PyTorch Implementation

![Python 3.7](https://img.shields.io/badge/python-3.7-green.svg?style=plastic) ![PyTorch 1.3.1](https://img.shields.io/badge/pytorch-1.3.1-green.svg?style=plastic) ![License MIT](https://img.shields.io/github/license/zhangzjn/APB2Face)

Official PyTorch implementation of the paper "[APB2FACE: AUDIO-GUIDED FACE REENACTMENT WITH AUXILIARY POSE AND BLINK SIGNALS, ICASSP'20](https://arxiv.org/pdf/2004.14569.pdf)".

For any inquiries, please contact Jiangning Zhang at [186368@zju.edu.cn](mailto:186368@zju.edu.cn).

## Using the Code

### Requirements

This code has been developed under `Python 3.7`, `PyTorch 1.3.1` and `CUDA 10.1` on `Ubuntu 16.04`.

```shell
# Install python3 packages
pip3 install -r requirements.txt
```

### Inference

- Download the pretrained [Audio-to-Landmark model](https://drive.google.com/file/d/159jQ27M_dqKmQ3ZacYZu6woXQ1f8Yc_H/view?usp=sharing) for the person **man1** to the path `landmark2face/APB/man1_best.pth`.
- Download the pretrained [Landmark-to-Face model](https://drive.google.com/file/d/1UqjxWG2kNVfG3G65SxdEsTrlGg9KqBRU/view?usp=sharing) for the person **man1** to the path `landmark2face/checkpoints/man1_Res9/latest_net_G.pth`.

```shell
python3 test.py
```

You can view the result in `result/man1.avi`.

### Training

1. Train the **Audio-to-Landmark** model.

   ```shell
   python3 audio2landmark/main.py
   ```

2. Train the **Landmark-to-Face** model.

   ```shell
   cd landmark2face
   sh experiments/train.sh
   ```

   You can find the checkpoints in `checkpoints/man1_Res9`.

3. Do the following operations before testing.

   ```shell
   cp audio2landmark/APBNet.py landmark2face/APB/APBNet.py          # if you modify APBNet.py
   cp audio2landmark/APBDataset.py landmark2face/APB/APBDataset.py  # if you modify APBDataset.py
   cp audio2landmark/checkpoints/man1-xxx/man1_best.pth landmark2face/APB/man1_best.pth
   ```

## Datasets in the paper

We propose a new **AnnVI** dataset; you can download it from
[Google Drive](https://drive.google.com/file/d/1xEnZwNLU4SmgFFh4WGV4KEOdegfFrOdp/view?usp=sharing)
or
[Baidu Cloud](https://pan.baidu.com/s/1oydpePBQieRoDmaENg3kfQ) (Key: str3).
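For readers skimming the dump, the inference entry point chains the two stages; a condensed, illustrative sketch of what landmark2face/test.py (reproduced later in this dump) does, with paths as in the README:

    # Condensed from landmark2face/test.py below; illustrative only.
    audio_net = APBNet()
    audio_net.load_state_dict(torch.load('APB/man1_best.pth')['net_G'])
    audio_net.cuda().eval()

    landmark = audio_net(audio_feature, pose, eye)  # stage 1: audio (+ pose/blink) -> landmarks
    # draw the predicted landmarks onto a blank image `lab`, then:
    model.set_input({'A': lab, 'A_label': lab, 'B': lab, 'B_label': lab})
    model.test()                                    # stage 2: landmark image -> face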
### Citation

If you think this work is useful for your research, please consider citing:

```
@inproceedings{zhang2020apb2face,
  title={APB2FACE: Audio-Guided Face Reenactment with Auxiliary Pose and Blink Signals},
  author={Zhang, Jiangning and Liu, Liang and Xue, Zhucun and Liu, Yong},
  booktitle={ICASSP 2020-2020 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP)},
  pages={4402--4406},
  year={2020},
  organization={IEEE}
}
```

### Acknowledgements

We thank the great work [pytorch-CycleGAN-and-pix2pix](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix) for its source code.

--------------------------------------------------------------------------------
/audio2landmark/main.py:
--------------------------------------------------------------------------------
import argparse
import logging
import os
import sys
import time
from APBDataset import *
from APBGAN import *
from utils import *


parser = argparse.ArgumentParser(description='APBNet')
parser.add_argument('--isTrain', default=True, type=bool, help='running mode')
parser.add_argument('--lr', default=0.0003, type=float, help='learning rate')
parser.add_argument('--every', default=300, type=float, help='learning rate decay interval (epochs)')
parser.add_argument('--gpus', default='1', type=str, help='gpus')
parser.add_argument('--landmark_path', default='../AnnVI/feature', type=str, help='landmark path that contains several persons')
parser.add_argument('--checkpoints', default='checkpoints', type=str, help='checkpoint path')
parser.add_argument('--epochs', default=800, type=int, help='epochs')
parser.add_argument('--resume', '-r', default=False, type=bool, help='resume')
parser.add_argument('--resume_epoch', default=None, type=int, help='resume epoch')
parser.add_argument('--resume_name', default='man1-20200428-134038', type=str, help='resume name')
parser.add_argument('--idt_name', default='man1', type=str, help='identity name')

opt = parser.parse_args()
opt.gpus = [int(dev) for dev in opt.gpus.split(',')]
torch.cuda.set_device(opt.gpus[0])

# logging
if not os.path.exists(opt.checkpoints):
    os.mkdir(opt.checkpoints)
if opt.resume:
    opt.logdir = '{}/{}'.format(opt.checkpoints, opt.resume_name)
else:
    opt.logdir = '{}/{}-{}'.format(opt.checkpoints, opt.idt_name, time.strftime("%Y%m%d-%H%M%S"))
if not os.path.exists(opt.logdir):
    os.mkdir(opt.logdir)

log_format = '%(asctime)s - %(message)s'
logging.basicConfig(stream=sys.stdout, level=logging.INFO, format=log_format, datefmt='%m/%d %I:%M:%S %p')
fh = logging.FileHandler(os.path.join(opt.logdir, 'log.txt'))
fh.setFormatter(logging.Formatter(log_format))
logger = logging.getLogger()
logger.addHandler(fh)

for key, val in vars(opt).items():
    if isinstance(val, list):
        val = [str(v) for v in val]
        val = ','.join(val)
    if val is None:
        val = 'None'
    logger.info('{:>20} : {:<50}'.format(key, val))

logger.info('==> Preparing data..')
trainset = APBDataset(opt.landmark_path, opt.idt_name, mode='train')
testset = APBDataset(opt.landmark_path, opt.idt_name, mode='test')
trainloader = torch.utils.data.DataLoader(trainset, batch_size=32, shuffle=True, num_workers=1)
testloader = torch.utils.data.DataLoader(testset, batch_size=32, shuffle=True, num_workers=1)
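Typical invocations of the training script above, inferred from its argparse flags (the paths and the run-directory name are examples, not prescribed values):

    # train audio-to-landmark for identity man1 on GPU 0
    python3 main.py --idt_name man1 --gpus 0 --landmark_path ../AnnVI/feature

    # resume a previous run (directory name under checkpoints/ is an example)
    python3 main.py --resume True --resume_name man1-20200428-134038 --resume_epoch 100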
logger.info('==> Building model..')
net = GANModel(opt, logger)


def train(epoch):
    net.train()
    net.run_train(trainloader)


def test(epoch):
    net.eval()
    net.run_test(testloader)

    # drawloader = torch.utils.data.DataLoader(testset, batch_size=1, shuffle=True, num_workers=1)
    # net.test_draw(drawloader)


for epoch in range(1, opt.epochs):
    train(epoch)
    test(epoch)
    logger.info('-' * 50)

--------------------------------------------------------------------------------
/landmark2face/util/util.py:
--------------------------------------------------------------------------------
"""This module contains simple helper functions """
from __future__ import print_function
import torch
import numpy as np
from PIL import Image
import os


def tensor2im(input_image, imtype=np.uint8):
    """Convert a Tensor array into a numpy image array.

    Parameters:
        input_image (tensor) -- the input image tensor array
        imtype (type)        -- the desired type of the converted numpy array
    """
    if not isinstance(input_image, np.ndarray):
        if isinstance(input_image, torch.Tensor):  # get the data from a variable
            image_tensor = input_image.data
        else:
            return input_image
        image_numpy = image_tensor[0].cpu().float().numpy()  # convert it into a numpy array
        if image_numpy.shape[0] == 1:  # grayscale to RGB
            image_numpy = np.tile(image_numpy, (3, 1, 1))
        image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0  # post-processing: transpose and scaling
    else:  # if it is a numpy array, do nothing
        image_numpy = input_image
    return image_numpy.astype(imtype)


def diagnose_network(net, name='network'):
    """Calculate and print the mean of average absolute(gradients)

    Parameters:
        net (torch network) -- Torch network
        name (str)          -- the name of the network
    """
    mean = 0.0
    count = 0
    for param in net.parameters():
        if param.grad is not None:
            mean += torch.mean(torch.abs(param.grad.data))
            count += 1
    if count > 0:
        mean = mean / count
    print(name)
    print(mean)


def save_image(image_numpy, image_path):
    """Save a numpy image to the disk

    Parameters:
        image_numpy (numpy array) -- input numpy array
        image_path (str)          -- the path of the image
    """
    image_pil = Image.fromarray(image_numpy)
    image_pil.save(image_path)


def print_numpy(x, val=True, shp=False):
    """Print the mean, min, max, median, std, and size of a numpy array

    Parameters:
        val (bool) -- if print the values of the numpy array
        shp (bool) -- if print the shape of the numpy array
    """
    x = x.astype(np.float64)
    if shp:
        print('shape,', x.shape)
    if val:
        x = x.flatten()
        print('mean = %3.3f, min = %3.3f, max = %3.3f, median = %3.3f, std=%3.3f' % (
            np.mean(x), np.min(x), np.max(x), np.median(x), np.std(x)))


def mkdirs(paths):
    """create empty directories if they don't exist

    Parameters:
        paths (str list) -- a list of directory paths
    """
    if isinstance(paths, list) and not isinstance(paths, str):
        for path in paths:
            mkdir(path)
    else:
        mkdir(paths)


def mkdir(path):
    """create a single empty directory if it didn't exist

    Parameters:
        path (str) -- a single directory path
    """
    if not os.path.exists(path):
        os.makedirs(path)
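tensor2im and save_image above are used together when exporting results, as in landmark2face/test.py; a short sketch (`visuals` is a hypothetical dict of model output tensors):

    img = tensor2im(visuals['fake_B'])  # (H, W, 3) uint8, rescaled from [-1, 1]
    save_image(img, 'fake_B.png')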
--------------------------------------------------------------------------------
/landmark2face/models/__init__.py:
--------------------------------------------------------------------------------
"""This package contains modules related to objective functions, optimizations, and network architectures.

To add a custom model class called 'dummy', you need to add a file called 'dummy_model.py' and define a subclass DummyModel inherited from BaseModel.
You need to implement the following five functions:
    -- <__init__>: initialize the class; first call BaseModel.__init__(self, opt).
    -- <set_input>: unpack data from dataset and apply preprocessing.
    -- <forward>: produce intermediate results.
    -- <optimize_parameters>: calculate loss, gradients, and update network weights.
    -- <modify_commandline_options>: (optionally) add model-specific options and set default options.

In the function <__init__>, you need to define four lists:
    -- self.loss_names (str list): specify the training losses that you want to plot and save.
    -- self.model_names (str list): define networks used in our training.
    -- self.visual_names (str list): specify the images that you want to display and save.
    -- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for example usage.

Now you can use the model class by specifying flag '--model dummy'.
See our template model class 'template_model.py' for more details.
"""

import importlib
from models.base_model import BaseModel


def find_model_using_name(model_name):
    """Import the module "models/[model_name]_model.py".

    In the file, the class called DatasetNameModel() will
    be instantiated. It has to be a subclass of BaseModel,
    and it is case-insensitive.
    """
    model_filename = "models." + model_name + "_model"
    modellib = importlib.import_module(model_filename)
    model = None
    target_model_name = model_name.replace('_', '') + 'model'
    for name, cls in modellib.__dict__.items():
        if name.lower() == target_model_name.lower() \
           and issubclass(cls, BaseModel):
            model = cls

    if model is None:
        print("In %s.py, there should be a subclass of BaseModel with class name that matches %s in lowercase." % (model_filename, target_model_name))
        exit(0)

    return model


def get_option_setter(model_name):
    """Return the static method <modify_commandline_options> of the model class."""
    model_class = find_model_using_name(model_name)
    return model_class.modify_commandline_options


def create_model(opt):
    """Create a model given the option.

    This function wraps the class CustomDatasetDataLoader.
    This is the main interface between this package and 'train.py'/'test.py'

    Example:
        >>> from models import create_model
        >>> model = create_model(opt)
    """
    model = find_model_using_name(opt.model)
    instance = model(opt)
    print("model [%s] was created" % type(instance).__name__)
    return instance

--------------------------------------------------------------------------------
/landmark2face/models/l2face_model.py:
--------------------------------------------------------------------------------
import torch
from .base_model import BaseModel
from . import networks


class L2FaceModel(BaseModel):
    @staticmethod
    def modify_commandline_options(parser, is_train=True):
        # parser.set_defaults(norm='instance', netG='resnet_9blocks_l2face', dataset_mode='l2face')
        if is_train:
            parser.set_defaults(pool_size=0, gan_mode='vanilla')
        return parser

    def __init__(self, opt):
        BaseModel.__init__(self, opt)
        self.loss_names = ['G_GAN', 'G_L1', 'D_real', 'D_fake']
        self.visual_names = ['real_A', 'fake_B', 'real_B']
        if self.isTrain:
            self.model_names = ['G', 'D']
        else:
            self.model_names = ['G']
        self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG, opt.norm,
                                      not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids)

        if self.isTrain:
            self.netD = networks.define_D(opt.input_nc + opt.output_nc, opt.ndf, opt.netD,
                                          opt.n_layers_D, opt.norm, opt.init_type, opt.init_gain, self.gpu_ids)

        if self.isTrain:
            self.criterionGAN = networks.GANLoss(opt.gan_mode).to(self.device)
            self.criterionL1 = torch.nn.L1Loss()
            self.optimizer_G = torch.optim.Adam(self.netG.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
            self.optimizer_D = torch.optim.Adam(self.netD.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
            self.optimizers.append(self.optimizer_G)
            self.optimizers.append(self.optimizer_D)

    def set_input(self, input):
        self.real_A = input['A'].to(self.device)
        self.A_label = input['A_label'].to(self.device)
        self.real_B = input['B'].to(self.device)
        self.B_label = input['B_label'].to(self.device)

    def forward(self):
        self.fake_B = self.netG(self.B_label)

    def backward_D(self):
        fake_AB = torch.cat((self.real_A, self.fake_B), 1)
        pred_fake = self.netD(fake_AB.detach())
        self.loss_D_fake = self.criterionGAN(pred_fake, False)
        real_AB = torch.cat((self.real_A, self.real_B), 1)
        pred_real = self.netD(real_AB)
        self.loss_D_real = self.criterionGAN(pred_real, True)
        self.loss_D = (self.loss_D_fake + self.loss_D_real) * 0.5
        self.loss_D.backward()

    def backward_G(self):
        lambda_GAN = 1
        lambda_L1 = 100
        fake_AB = torch.cat((self.real_A, self.fake_B), 1)
        pred_fake = self.netD(fake_AB)
        self.loss_G_GAN = self.criterionGAN(pred_fake, True) * lambda_GAN
        self.loss_G_L1 = self.criterionL1(self.fake_B, self.real_B) * lambda_L1
        self.loss_G = self.loss_G_GAN + self.loss_G_L1
        self.loss_G.backward()

    def optimize_parameters(self):
        self.forward()
        # update D
        self.set_requires_grad(self.netD, True)
        self.optimizer_D.zero_grad()
        self.backward_D()
        self.optimizer_D.step()
        # update G
        self.set_requires_grad(self.netD, False)
        self.optimizer_G.zero_grad()
        self.backward_G()
        self.optimizer_G.step()
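Following the recipe in the models package docstring above, a new model would plug in via a file like models/dummy_model.py; a hypothetical skeleton, enabled with `--model dummy` (not part of this repo):

    from .base_model import BaseModel

    class DummyModel(BaseModel):
        @staticmethod
        def modify_commandline_options(parser, is_train=True):
            return parser                # (optionally) add model-specific options

        def __init__(self, opt):
            BaseModel.__init__(self, opt)
            self.loss_names = []         # losses to plot/save
            self.model_names = []        # networks to save/load
            self.visual_names = []       # images to display/save

        def set_input(self, input):
            pass                         # unpack data and apply preprocessing

        def forward(self):
            pass                         # produce intermediate results

        def optimize_parameters(self):
            self.forward()               # calculate loss, gradients, update weights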
--------------------------------------------------------------------------------
/landmark2face/options/train_options.py:
--------------------------------------------------------------------------------
from .base_options import BaseOptions


class TrainOptions(BaseOptions):
    """This class includes training options.

    It also includes shared options defined in BaseOptions.
    """

    def initialize(self, parser):
        parser = BaseOptions.initialize(self, parser)
        # visdom and HTML visualization parameters
        parser.add_argument('--display_freq', type=int, default=400, help='frequency of showing training results on screen')
        parser.add_argument('--display_ncols', type=int, default=4, help='if positive, display all images in a single visdom web panel with certain number of images per row.')
        parser.add_argument('--display_id', type=int, default=1, help='window id of the web display')
        parser.add_argument('--display_server', type=str, default="http://localhost", help='visdom server of the web display')
        parser.add_argument('--display_env', type=str, default='main', help='visdom display environment name (default is "main")')
        parser.add_argument('--display_port', type=int, default=8097, help='visdom port of the web display')
        parser.add_argument('--update_html_freq', type=int, default=1000, help='frequency of saving training results to html')
        parser.add_argument('--print_freq', type=int, default=100, help='frequency of showing training results on console')
        parser.add_argument('--no_html', action='store_true', help='do not save intermediate training results to [opt.checkpoints_dir]/[opt.name]/web/')
        # network saving and loading parameters
        parser.add_argument('--save_latest_freq', type=int, default=5000, help='frequency of saving the latest results')
        parser.add_argument('--save_epoch_freq', type=int, default=5, help='frequency of saving checkpoints at the end of epochs')
        parser.add_argument('--save_by_iter', action='store_true', help='whether saves model by iteration')
        parser.add_argument('--continue_train', action='store_true', help='continue training: load the latest model')
        parser.add_argument('--epoch_count', type=int, default=1, help='the starting epoch count, we save the model by <epoch_count>, <epoch_count>+<save_latest_freq>, ...')
        parser.add_argument('--phase', type=str, default='train', help='train, val, test, etc')
        # training parameters
        parser.add_argument('--niter', type=int, default=40, help='# of iter at starting learning rate')
        parser.add_argument('--niter_decay', type=int, default=70, help='# of iter to linearly decay learning rate to zero')
        parser.add_argument('--beta1', type=float, default=0.5, help='momentum term of adam')
        parser.add_argument('--lr', type=float, default=0.0002, help='initial learning rate for adam')
        parser.add_argument('--gan_mode', type=str, default='lsgan', help='the type of GAN objective. [vanilla | lsgan | wgangp]. vanilla GAN loss is the cross-entropy objective used in the original GAN paper.')
        parser.add_argument('--pool_size', type=int, default=50, help='the size of image buffer that stores previously generated images')
        parser.add_argument('--lr_policy', type=str, default='linear', help='learning rate policy. [linear | step | plateau | cosine]')
        parser.add_argument('--lr_decay_iters', type=int, default=50, help='multiply by a gamma every lr_decay_iters iterations')

        self.isTrain = True
        return parser
8 | """ 9 | 10 | def initialize(self, parser): 11 | parser = BaseOptions.initialize(self, parser) 12 | # visdom and HTML visualization parameters 13 | parser.add_argument('--display_freq', type=int, default=400, help='frequency of showing training results on screen') 14 | parser.add_argument('--display_ncols', type=int, default=4, help='if positive, display all images in a single visdom web panel with certain number of images per row.') 15 | parser.add_argument('--display_id', type=int, default=1, help='window id of the web display') 16 | parser.add_argument('--display_server', type=str, default="http://localhost", help='visdom server of the web display') 17 | parser.add_argument('--display_env', type=str, default='main', help='visdom display environment name (default is "main")') 18 | parser.add_argument('--display_port', type=int, default=8097, help='visdom port of the web display') 19 | parser.add_argument('--update_html_freq', type=int, default=1000, help='frequency of saving training results to html') 20 | parser.add_argument('--print_freq', type=int, default=100, help='frequency of showing training results on console') 21 | parser.add_argument('--no_html', action='store_true', help='do not save intermediate training results to [opt.checkpoints_dir]/[opt.name]/web/') 22 | # network saving and loading parameters 23 | parser.add_argument('--save_latest_freq', type=int, default=5000, help='frequency of saving the latest results') 24 | parser.add_argument('--save_epoch_freq', type=int, default=5, help='frequency of saving checkpoints at the end of epochs') 25 | parser.add_argument('--save_by_iter', action='store_true', help='whether saves model by iteration') 26 | parser.add_argument('--continue_train', action='store_true', help='continue training: load the latest model') 27 | parser.add_argument('--epoch_count', type=int, default=1, help='the starting epoch count, we save the model by , +, ...') 28 | parser.add_argument('--phase', type=str, default='train', help='train, val, test, etc') 29 | # training parameters 30 | parser.add_argument('--niter', type=int, default=40, help='# of iter at starting learning rate') 31 | parser.add_argument('--niter_decay', type=int, default=70, help='# of iter to linearly decay learning rate to zero') 32 | parser.add_argument('--beta1', type=float, default=0.5, help='momentum term of adam') 33 | parser.add_argument('--lr', type=float, default=0.0002, help='initial learning rate for adam') 34 | parser.add_argument('--gan_mode', type=str, default='lsgan', help='the type of GAN objective. [vanilla| lsgan | wgangp]. vanilla GAN loss is the cross-entropy objective used in the original GAN paper.') 35 | parser.add_argument('--pool_size', type=int, default=50, help='the size of image buffer that stores previously generated images') 36 | parser.add_argument('--lr_policy', type=str, default='linear', help='learning rate policy. [linear | step | plateau | cosine]') 37 | parser.add_argument('--lr_decay_iters', type=int, default=50, help='multiply by a gamma every lr_decay_iters iterations') 38 | 39 | self.isTrain = True 40 | return parser 41 | -------------------------------------------------------------------------------- /landmark2face/util/html.py: -------------------------------------------------------------------------------- 1 | import dominate 2 | from dominate.tags import meta, h3, table, tr, td, p, a, img, br 3 | import os 4 | 5 | 6 | class HTML: 7 | """This HTML class allows us to save images and write texts into a single HTML file. 
    def get_image_dir(self):
        """Return the directory that stores images"""
        return self.img_dir

    def add_header(self, text):
        """Insert a header to the HTML file

        Parameters:
            text (str) -- the header text
        """
        with self.doc:
            h3(text)

    def add_images(self, ims, txts, links, width=400):
        """add images to the HTML file

        Parameters:
            ims (str list)   -- a list of image paths
            txts (str list)  -- a list of image names shown on the website
            links (str list) -- a list of hyperref links; when you click an image, it will redirect you to a new page
        """
        self.t = table(border=1, style="table-layout: fixed;")  # Insert a table
        self.doc.add(self.t)
        with self.t:
            with tr():
                for im, txt, link in zip(ims, txts, links):
                    with td(style="word-wrap: break-word;", halign="center", valign="top"):
                        with p():
                            with a(href=os.path.join('images', link)):
                                img(style="width:%dpx" % width, src=os.path.join('images', im))
                            br()
                            p(txt)

    def save(self):
        """save the current content to the HTML file"""
        html_file = '%s/index.html' % self.web_dir
        f = open(html_file, 'wt')
        f.write(self.doc.render())
        f.close()


if __name__ == '__main__':  # we show an example usage here.
    html = HTML('web/', 'test_html')
    html.add_header('hello world')

    ims, txts, links = [], [], []
    for n in range(4):
        ims.append('image_%d.png' % n)
        txts.append('text_%d' % n)
        links.append('image_%d.png' % n)
    html.add_images(ims, txts, links)
    html.save()

--------------------------------------------------------------------------------
/landmark2face/data/__init__.py:
--------------------------------------------------------------------------------
"""This package includes all the modules related to data loading and preprocessing

To add a custom dataset class called 'dummy', you need to add a file called 'dummy_dataset.py' and define a subclass 'DummyDataset' inherited from BaseDataset.
You need to implement four functions:
    -- <__init__>: initialize the class, first call BaseDataset.__init__(self, opt).
    -- <__len__>: return the size of dataset.
    -- <__getitem__>: get a data point from data loader.
    -- <modify_commandline_options>: (optionally) add dataset-specific options and set default options.

Now you can use the dataset class by specifying flag '--dataset_mode dummy'.
See our template dataset class 'template_dataset.py' for more details.
"""
import importlib
import torch.utils.data
from data.base_dataset import BaseDataset


def find_dataset_using_name(dataset_name):
    """Import the module "data/[dataset_name]_dataset.py".

    In the file, the class called DatasetNameDataset() will
    be instantiated. It has to be a subclass of BaseDataset,
    and it is case-insensitive.
    """
    dataset_filename = "data." + dataset_name + "_dataset"
    datasetlib = importlib.import_module(dataset_filename)

    dataset = None
    target_dataset_name = dataset_name.replace('_', '') + 'dataset'
    for name, cls in datasetlib.__dict__.items():
        if name.lower() == target_dataset_name.lower() \
           and issubclass(cls, BaseDataset):
            dataset = cls

    if dataset is None:
        raise NotImplementedError("In %s.py, there should be a subclass of BaseDataset with class name that matches %s in lowercase." % (dataset_filename, target_dataset_name))

    return dataset


def get_option_setter(dataset_name):
    """Return the static method <modify_commandline_options> of the dataset class."""
    dataset_class = find_dataset_using_name(dataset_name)
    return dataset_class.modify_commandline_options


def create_dataset(opt):
    """Create a dataset given the option.

    This function wraps the class CustomDatasetDataLoader.
    This is the main interface between this package and 'train.py'/'test.py'

    Example:
        >>> from data import create_dataset
        >>> dataset = create_dataset(opt)
    """
    data_loader = CustomDatasetDataLoader(opt)
    dataset = data_loader.load_data()
    return dataset


class CustomDatasetDataLoader():
    """Wrapper class of Dataset class that performs multi-threaded data loading"""

    def __init__(self, opt):
        """Initialize this class

        Step 1: create a dataset instance given the name [dataset_mode]
        Step 2: create a multi-threaded data loader.
        """
        self.opt = opt
        dataset_class = find_dataset_using_name(opt.dataset_mode)
        self.dataset = dataset_class(opt)
        print("dataset [%s] was created" % type(self.dataset).__name__)
        self.dataloader = torch.utils.data.DataLoader(
            self.dataset,
            batch_size=opt.batch_size,
            shuffle=not opt.serial_batches,
            num_workers=int(opt.num_threads))

    def load_data(self):
        return self

    def __len__(self):
        """Return the number of data in the dataset"""
        return min(len(self.dataset), self.opt.max_dataset_size)

    def __iter__(self):
        """Return a batch of data"""
        for i, data in enumerate(self.dataloader):
            if i * self.opt.batch_size >= self.opt.max_dataset_size:
                break
            yield data
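Mirroring the models package, the docstring above implies this hypothetical skeleton for a custom dataset (data/dummy_dataset.py, enabled with `--dataset_mode dummy`; not part of this repo):

    from data.base_dataset import BaseDataset

    class DummyDataset(BaseDataset):
        @staticmethod
        def modify_commandline_options(parser, is_train):
            return parser        # (optionally) add dataset-specific options

        def __init__(self, opt):
            BaseDataset.__init__(self, opt)
            self.paths = []      # gather sample paths here

        def __getitem__(self, index):
            return {}            # return one data point as a dict

        def __len__(self):
            return len(self.paths)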
--------------------------------------------------------------------------------
/landmark2face/test.py:
--------------------------------------------------------------------------------
import os
from options.test_options import TestOptions
from options.train_options import TrainOptions
from data import create_dataset
from models import create_model
from util.visualizer import save_images
from util import html
import cv2
import torchvision.transforms as transforms
from PIL import Image
from models.l2face_model import *
import numpy as np
import sys
from util.util import *
from APB.APBDataset import *
from APB.APBNet import *
import torch


def tuple_shape(shape):
    r_data = []
    for p in shape:
        r_data.append([p.x, p.y])
    return r_data


def drawCircle(img, shape, radius=1, color=(255, 255, 255), thickness=1):
    for p in shape:
        img = cv2.circle(img, (int(p[0]), int(p[1])), radius, color, thickness)
    return img


def vector2points(landmark):
    shape = []
    for i in range(len(landmark) // 2):
        shape.append([landmark[2 * i], landmark[2 * i + 1]])
    return shape
if __name__ == '__main__':
    opt = TestOptions().parse()
    opt.isTrain = False
    opt.name = 'man1_Res9'
    opt.model = 'l2face'
    opt.netG = 'resnet_9blocks_l2face'
    opt.dataset_mode = 'l2face'
    model = L2FaceModel(opt)  # create a model given opt.model and other options
    model.setup(opt)          # regular setup: load and print networks; create schedulers
    model.eval()
    transforms_label = transforms.Compose([transforms.ToTensor(),
                                           transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

    # audio2landmark
    audio_net = APBNet()
    checkpoint = torch.load('APB/man1_best.pth')
    audio_net.load_state_dict(checkpoint['net_G'])
    audio_net.cuda()
    audio_net.eval()
    # dataset
    feature_path = '../AnnVI/feature'
    idt_name = 'man1'
    testset = APBDataset(feature_path, idt_name=idt_name, mode='test', img_size=256)
    dataloader = torch.utils.data.DataLoader(testset, batch_size=1, shuffle=False, num_workers=1)

    out_path = 'result'
    if not os.path.exists(out_path):
        os.mkdir(out_path)
    fourcc = cv2.VideoWriter_fourcc(*'XVID')
    out = cv2.VideoWriter(os.path.join(out_path, '{}.avi'.format(idt_name)), fourcc, 25.0, (256 * 2, 256))

    for idx, data in enumerate(dataloader):
        audio_feature_A1, pose_A1, eye_A1 = data[0][0].cuda(), \
                                            data[0][1].cuda(), \
                                            data[0][2].cuda()
        landmark_A1, landmark_A2 = data[1][0].cuda(), \
                                   data[1][1].cuda()

        image_path_A1 = data[2][0][0][0]
        print('\r{}/{}'.format(idx + 1, len(dataloader)), end='')

        landmark = audio_net(audio_feature_A1, pose_A1, eye_A1)
        landmark = landmark.cpu().data.numpy().tolist()[0]
        lab_template = np.zeros((256, 256, 3)).astype(np.uint8)
        lab = drawCircle(lab_template.copy(), vector2points(landmark), radius=1, color=(255, 255, 255), thickness=4)
        lab = Image.fromarray(lab).convert('RGB')
        lab = transforms_label(lab).unsqueeze(0)

        input_data = {'A': lab, 'A_label': lab, 'B': lab, 'B_label': lab}
        model.set_input(input_data)
        model.test()
        visuals = model.get_current_visuals()
        B_img_f = tensor2im(visuals['fake_B'])
        B_img = cv2.imread(image_path_A1)
        B_img = cv2.cvtColor(B_img, cv2.COLOR_BGR2RGB)
        B_img = cv2.resize(B_img, (256, 256))

        img_out = np.concatenate([B_img_f, B_img], axis=1)
        for _ in range(5):  # write each frame five times to slow playback
            out.write(cv2.cvtColor(img_out, cv2.COLOR_BGR2RGB))
        if idx == 100:
            break
    out.release()

--------------------------------------------------------------------------------
/landmark2face/util/get_data.py:
--------------------------------------------------------------------------------
from __future__ import print_function
import os
import tarfile
import requests
from warnings import warn
from zipfile import ZipFile
from bs4 import BeautifulSoup
from os.path import abspath, isdir, join, basename


class GetData(object):
    """A Python script for downloading CycleGAN or pix2pix datasets.

    Parameters:
        technique (str) -- One of: 'cyclegan' or 'pix2pix'.
        verbose (bool)  -- If True, print additional information.

    Examples:
        >>> from util.get_data import GetData
        >>> gd = GetData(technique='cyclegan')
        >>> new_data_path = gd.get(save_path='./datasets')  # options will be displayed.

    Alternatively, you can use bash scripts: 'scripts/download_pix2pix_model.sh'
    and 'scripts/download_cyclegan_model.sh'.
    """

    def __init__(self, technique='cyclegan', verbose=True):
        url_dict = {
            'pix2pix': 'http://efrosgans.eecs.berkeley.edu/pix2pix/datasets/',
            'cyclegan': 'https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets'
        }
        self.url = url_dict.get(technique.lower())
        self._verbose = verbose

    def _print(self, text):
        if self._verbose:
            print(text)

    @staticmethod
    def _get_options(r):
        soup = BeautifulSoup(r.text, 'lxml')
        options = [h.text for h in soup.find_all('a', href=True)
                   if h.text.endswith(('.zip', 'tar.gz'))]
        return options

    def _present_options(self):
        r = requests.get(self.url)
        options = self._get_options(r)
        print('Options:\n')
        for i, o in enumerate(options):
            print("{0}: {1}".format(i, o))
        choice = input("\nPlease enter the number of the "
                       "dataset above you wish to download:")
        return options[int(choice)]

    def _download_data(self, dataset_url, save_path):
        if not isdir(save_path):
            os.makedirs(save_path)

        base = basename(dataset_url)
        temp_save_path = join(save_path, base)

        with open(temp_save_path, "wb") as f:
            r = requests.get(dataset_url)
            f.write(r.content)

        if base.endswith('.tar.gz'):
            obj = tarfile.open(temp_save_path)
        elif base.endswith('.zip'):
            obj = ZipFile(temp_save_path, 'r')
        else:
            raise ValueError("Unknown File Type: {0}.".format(base))

        self._print("Unpacking Data...")
        obj.extractall(save_path)
        obj.close()
        os.remove(temp_save_path)

    def get(self, save_path, dataset=None):
        """
        Download a dataset.

        Parameters:
            save_path (str) -- A directory to save the data to.
            dataset (str)   -- (optional). A specific dataset to download.
                               Note: this must include the file extension.
                               If None, options will be presented for you
                               to choose from.

        Returns:
            save_path_full (str) -- the absolute path to the downloaded data.
        """
        if dataset is None:
            selected_dataset = self._present_options()
        else:
            selected_dataset = dataset

        save_path_full = join(save_path, selected_dataset.split('.')[0])
        if isdir(save_path_full):
            warn("\n'{0}' already exists. Voiding Download.".format(
                save_path_full))
        else:
            self._print('Downloading Data...')
            url = "{0}/{1}".format(self.url, selected_dataset)
            self._download_data(url, save_path=save_path)

        return abspath(save_path_full)

--------------------------------------------------------------------------------
/audio2landmark/APBNet.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn


class APBNet(nn.Module):

    def __init__(self, num_landmark=212):
        super(APBNet, self).__init__()
        self.num_landmark = num_landmark
        # audio
        self.audio1 = nn.Sequential(
            nn.Conv2d(1, 72, kernel_size=(1, 3), stride=(1, 2), padding=(0, 1)), nn.ReLU(),
            nn.Conv2d(72, 108, kernel_size=(1, 3), stride=(1, 2), padding=(0, 1)), nn.ReLU(),
            nn.Conv2d(108, 162, kernel_size=(1, 3), stride=(1, 2), padding=(0, 1)), nn.ReLU(),
            nn.Conv2d(162, 243, kernel_size=(1, 3), stride=(1, 2), padding=(0, 1)), nn.ReLU(),
            nn.Conv2d(243, 256, kernel_size=(1, 3), stride=(1, 2), padding=(0, 1)), nn.ReLU(),
        )
        self.audio2 = nn.Sequential(
            nn.Conv2d(256, 256, kernel_size=(3, 1), stride=(2, 1), padding=(1, 0)), nn.ReLU(),
            nn.Conv2d(256, 256, kernel_size=(3, 1), stride=(2, 1), padding=(1, 0)), nn.ReLU(),
            nn.Conv2d(256, 256, kernel_size=(3, 1), stride=(2, 1), padding=(1, 0)), nn.ReLU(),
            nn.Conv2d(256, 256, kernel_size=(3, 1), stride=(2, 1), padding=(1, 0)), nn.ReLU(),
            nn.Conv2d(256, 256, kernel_size=(4, 1), stride=(4, 1)), nn.ReLU()
        )
        self.trans_audio = nn.Sequential(nn.Linear(256 * 2, 256))
        # pose
        self.trans_pose = nn.Sequential(
            nn.Linear(3, 64), nn.ReLU(),
            nn.Linear(64, 64), nn.ReLU(),
            nn.Linear(64, 64)
        )
        # eye
        self.trans_eye = nn.Sequential(
            nn.Linear(2, 64), nn.ReLU(),
            nn.Linear(64, 64), nn.ReLU(),
            nn.Linear(64, 64)
        )
        # cat
        self.trans_cat = nn.Sequential(
            nn.Linear(256 + 64 * 2, 240), nn.ReLU(),
            nn.Linear(240, self.num_landmark)
        )

    def num_flat_features(self, x):
        size = x.size()[1:]  # all dimensions except the batch dimension
        num_features = 1
        for s in size:
            num_features *= s
        return num_features

    def forward(self, audio, pose, eye):
        x_a = self.audio1(audio)
        x_a = self.audio2(x_a)
        x_a = x_a.view(-1, self.num_flat_features(x_a))
        x_a = self.trans_audio(x_a)
        x_p = self.trans_pose(pose)
        x_e = self.trans_eye(eye)
        x_cat = torch.cat([x_a, x_p, x_e], dim=1)
        output = self.trans_cat(x_cat)
        return output


class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        layers1 = [nn.Linear(106 * 2, 512), nn.LeakyReLU(0.2, True),
                   nn.Linear(512, 512), nn.LeakyReLU(0.2, True),
                   nn.Linear(512, 512), nn.LeakyReLU(0.2, True),
                   nn.Linear(512, 512), nn.LeakyReLU(0.2, True),
                   nn.Linear(512, 64)]

        layers2 = [nn.Linear(106 * 2, 512), nn.LeakyReLU(0.2, True),
                   nn.Linear(512, 512), nn.LeakyReLU(0.2, True),
                   nn.Linear(512, 512), nn.LeakyReLU(0.2, True),
                   nn.Linear(512, 512), nn.LeakyReLU(0.2, True),
                   nn.Linear(512, 64)]

        layers3 = [nn.Linear(128, 128), nn.LeakyReLU(0.2, True),
                   nn.Linear(128, 32), nn.LeakyReLU(0.2, True),
                   nn.Linear(32, 1)]

        self.layers1 = nn.Sequential(*layers1)
        self.layers2 = nn.Sequential(*layers2)
        self.layers3 = nn.Sequential(*layers3)

    def forward(self, input1, input2):
        x1 = self.layers1(input1)
        x2 = self.layers2(input2)
        x_cat = torch.cat([x1, x2], dim=1)
        out = self.layers3(x_cat)
        return out


if __name__ == "__main__":
    torch.cuda.set_device(0)
    from APBDataset import *
    import time
    landmark_paths = '/media/datasets/zhangzjn/AnnVI/feature'
    testset = APBDataset(landmark_paths, 'man1')
    testloader = torch.utils.data.DataLoader(testset, batch_size=2, shuffle=True, num_workers=1)
    net = APBNet()  # this file defines APBNet, not Generator
    for batch_idx, training_data in enumerate(testloader):
        audio_feature_A1, pose_A1, eye_A1 = training_data[0][0], training_data[0][1], training_data[0][2]
        t_start = time.time()
        for i in range(10000):
            net(audio_feature_A1, pose_A1, eye_A1)  # forward takes (audio, pose, eye)
        print(time.time() - t_start)
        break
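A quick smoke test of APBNet's input contract. The 1x64x64 audio-feature shape below is an assumption: it is one shape consistent with the conv strides and the Linear(256 * 2, 256) flatten, but the true AnnVI feature shape is not shown in this dump:

    import torch

    net = APBNet()
    audio = torch.randn(2, 1, 64, 64)   # assumed shape; reduces to (2, 256, 1, 2) before flattening
    pose = torch.randn(2, 3)
    eye = torch.randn(2, 2)
    print(net(audio, pose, eye).shape)  # torch.Size([2, 212]), i.e. 106 (x, y) landmark points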
def forward(self, input1, input2): 87 | x1 = self.layers1(input1) 88 | x2 = self.layers2(input2) 89 | x_cat = torch.cat([x1, x2], dim=1) 90 | out = self.layers3(x_cat) 91 | return out 92 | 93 | 94 | if __name__ == "__main__": 95 | torch.cuda.set_device(0) 96 | from APBDataset import * 97 | import time 98 | landmark_paths = '/media/datasets/zhangzjn/AnnVI/feature' 99 | testset = APBDataset(landmark_paths, 'man1') 100 | testloader = torch.utils.data.DataLoader(testset, batch_size=2, shuffle=True, num_workers=1) 101 | net = APBNet() 102 | for batch_idx, training_data in enumerate(testloader): 103 | audio_feature_A1, pose_A1, eye_A1 = training_data[0][0], training_data[0][1],\ 104 | training_data[0][2] 105 | landmark = training_data[2][0] 106 | t_start = time.time() 107 | for i in range(10000): 108 | net(audio_feature_A1, pose_A1, eye_A1) 109 | print(time.time() - t_start) 110 | break -------------------------------------------------------------------------------- /landmark2face/APB/APBNet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class APBNet(nn.Module): 6 | 7 | def __init__(self, num_landmark=212): 8 | super(APBNet, self).__init__() 9 | self.num_landmark = num_landmark 10 | # audio 11 | self.audio1 = nn.Sequential( 12 | nn.Conv2d(1, 72, kernel_size=(1, 3), stride=(1, 2), padding=(0, 1)), nn.ReLU(), 13 | nn.Conv2d(72, 108, kernel_size=(1, 3), stride=(1, 2), padding=(0, 1)), nn.ReLU(), 14 | nn.Conv2d(108, 162, kernel_size=(1, 3), stride=(1, 2), padding=(0, 1)), nn.ReLU(), 15 | nn.Conv2d(162, 243, kernel_size=(1, 3), stride=(1, 2), padding=(0, 1)), nn.ReLU(), 16 | nn.Conv2d(243, 256, kernel_size=(1, 3), stride=(1, 2), padding=(0, 1)), nn.ReLU(), 17 | ) 18 | self.audio2 = nn.Sequential( 19 | nn.Conv2d(256, 256, kernel_size=(3, 1), stride=(2, 1), padding=(1, 0)), nn.ReLU(), 20 | nn.Conv2d(256, 256, kernel_size=(3, 1), stride=(2, 1), padding=(1, 0)), nn.ReLU(), 21 | nn.Conv2d(256, 256, kernel_size=(3, 1), stride=(2, 1), padding=(1, 0)), nn.ReLU(), 22 | nn.Conv2d(256, 256, kernel_size=(3, 1), stride=(2, 1), padding=(1, 0)), nn.ReLU(), 23 | nn.Conv2d(256, 256, kernel_size=(4, 1), stride=(4, 1)), nn.ReLU() 24 | ) 25 | self.trans_audio = nn.Sequential(nn.Linear(256 * 2, 256)) 26 | # pose 27 | self.trans_pose = nn.Sequential( 28 | nn.Linear(3, 64), nn.ReLU(), 29 | nn.Linear(64, 64), nn.ReLU(), 30 | nn.Linear(64, 64) 31 | ) 32 | # eye 33 | self.trans_eye = nn.Sequential( 34 | nn.Linear(2, 64), nn.ReLU(), 35 | nn.Linear(64, 64), nn.ReLU(), 36 | nn.Linear(64, 64) 37 | ) 38 | # cat 39 | self.trans_cat = nn.Sequential( 40 | nn.Linear(256 + 64 * 2, 240), nn.ReLU(), 41 | nn.Linear(240, self.num_landmark) 42 | ) 43 | 44 | def num_flat_features(self, x): 45 | size = x.size()[1:] # all dimensions except the batch dimension 46 | num_features = 1 47 | for s in size: 48 | num_features *= s 49 | return num_features 50 | 51 | def forward(self, audio, pose, eye): 52 | x_a = self.audio1(audio) 53 | x_a = self.audio2(x_a) 54 | x_a = x_a.view(-1, self.num_flat_features(x_a)) 55 | x_a = self.trans_audio(x_a) 56 | x_p = self.trans_pose(pose) 57 | x_e = self.trans_eye(eye) 58 | x_cat = torch.cat([x_a, x_p, x_e], dim=1) 59 | output = self.trans_cat(x_cat) 60 | return output 61 | 62 | 63 | class Discriminator(nn.Module): 64 | def __init__(self): 65 | super(Discriminator, self).__init__() 66 | layers1 = [nn.Linear(106 * 2, 512), nn.LeakyReLU(0.2, True), 67 | nn.Linear(512, 512), nn.LeakyReLU(0.2, True), 68 | nn.Linear(512, 512), nn.LeakyReLU(0.2, True),
69 | nn.Linear(512, 512), nn.LeakyReLU(0.2, True), 70 | nn.Linear(512, 64)] 71 | 72 | layers2 = [nn.Linear(106 * 2, 512), nn.LeakyReLU(0.2, True), 73 | nn.Linear(512, 512), nn.LeakyReLU(0.2, True), 74 | nn.Linear(512, 512), nn.LeakyReLU(0.2, True), 75 | nn.Linear(512, 512), nn.LeakyReLU(0.2, True), 76 | nn.Linear(512, 64)] 77 | 78 | layers3 = [nn.Linear(128, 128), nn.LeakyReLU(0.2, True), 79 | nn.Linear(128, 32), nn.LeakyReLU(0.2, True), 80 | nn.Linear(32, 1)] 81 | 82 | self.layers1 = nn.Sequential(*layers1) 83 | self.layers2 = nn.Sequential(*layers2) 84 | self.layers3 = nn.Sequential(*layers3) 85 | 86 | def forward(self, input1, input2): 87 | x1 = self.layers1(input1) 88 | x2 = self.layers2(input2) 89 | x_cat = torch.cat([x1, x2], dim=1) 90 | out = self.layers3(x_cat) 91 | return out 92 | 93 | 94 | if __name__ == "__main__": 95 | torch.cuda.set_device(0) 96 | from APBDataset import * 97 | import time 98 | landmark_paths = '/media/datasets/zhangzjn/AnnVI/feature' 99 | testset = APBDataset(landmark_paths, 'man1') 100 | testloader = torch.utils.data.DataLoader(testset, batch_size=2, shuffle=True, num_workers=1) 101 | net = APBNet() 102 | for batch_idx, training_data in enumerate(testloader): 103 | audio_feature_A1, pose_A1, eye_A1 = training_data[0][0], training_data[0][1],\ 104 | training_data[0][2] 105 | landmark = training_data[2][0] 106 | t_start = time.time() 107 | for i in range(10000): 108 | net(audio_feature_A1, pose_A1, eye_A1) 109 | print(time.time() - t_start) 110 | break -------------------------------------------------------------------------------- /landmark2face/train.py: -------------------------------------------------------------------------------- 1 | """General-purpose training script for image-to-image translation. 2 | 3 | This script works for various models (with option '--model': e.g., pix2pix, cyclegan, colorization) and 4 | different datasets (with option '--dataset_mode': e.g., aligned, unaligned, single, colorization). 5 | You need to specify the dataset ('--dataroot'), experiment name ('--name'), and model ('--model'). 6 | 7 | It first creates model, dataset, and visualizer given the option. 8 | It then does standard network training. During the training, it also visualizes/saves the images, prints/saves the loss plot, and saves models. 9 | The script supports continue/resume training. Use '--continue_train' to resume your previous training. 10 | 11 | Example: 12 | Train a CycleGAN model: 13 | python train.py --dataroot ./datasets/maps --name maps_cyclegan --model cycle_gan 14 | Train a pix2pix model: 15 | python train.py --dataroot ./datasets/facades --name facades_pix2pix --model pix2pix --direction BtoA 16 | 17 | See options/base_options.py and options/train_options.py for more training options. 18 | See training and test tips at: https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/docs/tips.md 19 | See frequently asked questions at: https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/docs/qa.md 20 | """ 21 | import time 22 | from options.train_options import TrainOptions 23 | from data import create_dataset 24 | from models import create_model 25 | from util.visualizer import Visualizer 26 | 27 | if __name__ == '__main__': 28 | opt = TrainOptions().parse() # get training options 29 | dataset = create_dataset(opt) # create a dataset given opt.dataset_mode and other options 30 | dataset_size = len(dataset) # get the number of images in the dataset.
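# create_dataset/create_model dispatch on the '--dataset_mode' and '--model'
# strings; with the defaults in options/base_options.py (both 'l2face') they
# presumably resolve to the l2face dataset and model modules, following the
# usual <mode>_dataset.py / <model>_model.py lookup convention of this codebase.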
31 | print('The number of training images = %d' % dataset_size) 32 | 33 | model = create_model(opt) # create a model given opt.model and other options 34 | model.setup(opt) # regular setup: load and print networks; create schedulers 35 | visualizer = Visualizer(opt) # create a visualizer that displays/saves images and plots 36 | total_iters = 0 # the total number of training iterations 37 | 38 | for epoch in range(opt.epoch_count, opt.niter + opt.niter_decay + 1): # outer loop for different epochs; we save the model by <epoch_count>, <epoch_count>+<save_latest_freq> 39 | epoch_start_time = time.time() # timer for entire epoch 40 | iter_data_time = time.time() # timer for data loading per iteration 41 | epoch_iter = 0 # the number of training iterations in current epoch, reset to 0 every epoch 42 | 43 | for i, data in enumerate(dataset): # inner loop within one epoch 44 | iter_start_time = time.time() # timer for computation per iteration 45 | if total_iters % opt.print_freq == 0: 46 | t_data = iter_start_time - iter_data_time 47 | visualizer.reset() 48 | total_iters += opt.batch_size 49 | epoch_iter += opt.batch_size 50 | model.set_input(data) # unpack data from dataset and apply preprocessing 51 | model.optimize_parameters() # calculate loss functions, get gradients, update network weights 52 | 53 | if total_iters % opt.display_freq == 0: # display images on visdom and save images to a HTML file 54 | save_result = total_iters % opt.update_html_freq == 0 55 | model.compute_visuals() 56 | visualizer.display_current_results(model.get_current_visuals(), epoch, save_result) 57 | 58 | if total_iters % opt.print_freq == 0: # print training losses and save logging information to the disk 59 | losses = model.get_current_losses() 60 | t_comp = (time.time() - iter_start_time) / opt.batch_size 61 | visualizer.print_current_losses(epoch, epoch_iter, losses, t_comp, t_data) 62 | if opt.display_id > 0: 63 | visualizer.plot_current_losses(epoch, float(epoch_iter) / dataset_size, losses) 64 | 65 | if total_iters % opt.save_latest_freq == 0: # cache our latest model every <save_latest_freq> iterations 66 | print('saving the latest model (epoch %d, total_iters %d)' % (epoch, total_iters)) 67 | save_suffix = 'iter_%d' % total_iters if opt.save_by_iter else 'latest' 68 | model.save_networks(save_suffix) 69 | 70 | iter_data_time = time.time() 71 | if epoch % opt.save_epoch_freq == 0: # cache our model every <save_epoch_freq> epochs 72 | print('saving the model at the end of epoch %d, iters %d' % (epoch, total_iters)) 73 | model.save_networks('latest') 74 | model.save_networks(epoch) 75 | 76 | print('End of epoch %d / %d \t Time Taken: %d sec' % (epoch, opt.niter + opt.niter_decay, time.time() - epoch_start_time)) 77 | model.update_learning_rate() # update learning rates at the end of every epoch.
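# Note on the counters above: total_iters and epoch_iter advance in steps of
# opt.batch_size, so the display_freq/print_freq/save_latest_freq checks only
# fire on iterations where the frequency divides a multiple of the batch size;
# choosing these frequencies as multiples of opt.batch_size keeps the cadence regular.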
78 | -------------------------------------------------------------------------------- /landmark2face/models/networks_l2face.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn import init 4 | import functools 5 | from torch.optim import lr_scheduler 6 | 7 | 8 | class ResnetL2FaceGenerator(nn.Module): 9 | 10 | def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False, n_blocks=9, padding_type='reflect'): 11 | assert(n_blocks >= 0) 12 | self.n_blocks = n_blocks 13 | super(ResnetL2FaceGenerator, self).__init__() 14 | if type(norm_layer) == functools.partial: 15 | use_bias = norm_layer.func == nn.InstanceNorm2d 16 | else: 17 | use_bias = norm_layer == nn.InstanceNorm2d 18 | 19 | model1 = [nn.ReflectionPad2d(3), 20 | nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0, bias=use_bias), 21 | norm_layer(ngf), 22 | nn.ReLU(True)] 23 | 24 | n_downsampling = 2 25 | for i in range(n_downsampling): # add downsampling layers 26 | mult = 2 ** i 27 | model1 += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=1, bias=use_bias), 28 | norm_layer(ngf * mult * 2), 29 | nn.ReLU(True)] 30 | 31 | mult = 2 ** n_downsampling 32 | model2 = [] 33 | for i in range(self.n_blocks): 34 | model2 += [ResnetBlock(ngf * mult, ngf * mult, padding_type=padding_type, norm_layer=norm_layer, use_dropout=use_dropout, use_bias=use_bias)] 35 | 36 | 37 | model3 = [] 38 | for i in range(n_downsampling): # add upsampling layers 39 | mult = 2 ** (n_downsampling - i) 40 | model3 += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2), 41 | kernel_size=3, stride=2, 42 | padding=1, output_padding=1, 43 | bias=use_bias), 44 | norm_layer(int(ngf * mult / 2)), 45 | nn.ReLU(True)] 46 | model3 += [nn.ReflectionPad2d(3)] 47 | model3 += [nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)] 48 | model3 += [nn.Tanh()] 49 | 50 | self.model1 = nn.Sequential(*model1) 51 | self.model2 = nn.Sequential(*model2) 52 | self.model3 = nn.Sequential(*model3) 53 | 54 | def forward(self, input): 55 | """Standard forward""" 56 | x = self.model1(input) 57 | x = self.model2(x) 58 | out = self.model3(x) 59 | 60 | return out 61 | 62 | 63 | class ResnetBlock(nn.Module): 64 | """Define a Resnet block""" 65 | 66 | def __init__(self, dim_in, dim_out, padding_type, norm_layer, use_dropout, use_bias): 67 | """Initialize the Resnet block 68 | 69 | A resnet block is a conv block with skip connections 70 | We construct a conv block with build_conv_block function, 71 | and implement skip connections in <forward> function. 72 | Original Resnet paper: https://arxiv.org/pdf/1512.03385.pdf 73 | """ 74 | super(ResnetBlock, self).__init__() 75 | self.conv_block = self.build_conv_block(dim_in, dim_out, padding_type, norm_layer, use_dropout, use_bias) 76 | self.shortcut = nn.Sequential(*[nn.Conv2d(dim_in, dim_out, kernel_size=3, padding=1, bias=use_bias), norm_layer(dim_out)]) 77 | 78 | def build_conv_block(self, dim_in, dim_out, padding_type, norm_layer, use_dropout, use_bias): 79 | """Construct a convolutional block. 80 | 81 | Parameters: 82 | dim_in, dim_out (int) -- the numbers of input / output channels in the conv layers. 83 | padding_type (str) -- the name of padding layer: reflect | replicate | zero 84 | norm_layer -- normalization layer 85 | use_dropout (bool) -- if use dropout layers.
86 | use_bias (bool) -- if the conv layer uses bias or not 87 | 88 | Returns a conv block (with a conv layer, a normalization layer, and a non-linearity layer (ReLU)) 89 | """ 90 | conv_block = [] 91 | p = 0 92 | if padding_type == 'reflect': 93 | conv_block += [nn.ReflectionPad2d(1)] 94 | elif padding_type == 'replicate': 95 | conv_block += [nn.ReplicationPad2d(1)] 96 | elif padding_type == 'zero': 97 | p = 1 98 | else: 99 | raise NotImplementedError('padding [%s] is not implemented' % padding_type) 100 | 101 | conv_block += [nn.Conv2d(dim_in, dim_out, kernel_size=3, padding=p, bias=use_bias), norm_layer(dim_out), nn.ReLU(True)] 102 | if use_dropout: 103 | conv_block += [nn.Dropout(0.5)] 104 | 105 | p = 0 106 | if padding_type == 'reflect': 107 | conv_block += [nn.ReflectionPad2d(1)] 108 | elif padding_type == 'replicate': 109 | conv_block += [nn.ReplicationPad2d(1)] 110 | elif padding_type == 'zero': 111 | p = 1 112 | else: 113 | raise NotImplementedError('padding [%s] is not implemented' % padding_type) 114 | conv_block += [nn.Conv2d(dim_out, dim_out, kernel_size=3, padding=p, bias=use_bias), norm_layer(dim_out)] 115 | 116 | return nn.Sequential(*conv_block) 117 | 118 | def forward(self, x): 119 | """Forward function (with skip connections)""" 120 | out = self.shortcut(x) + self.conv_block(x) # add skip connections 121 | return out 122 | -------------------------------------------------------------------------------- /landmark2face/data/base_dataset.py: -------------------------------------------------------------------------------- 1 | """This module implements an abstract base class (ABC) 'BaseDataset' for datasets. 2 | 3 | It also includes common transformation functions (e.g., get_transform, __scale_width), which can be later used in subclasses. 4 | """ 5 | import random 6 | import numpy as np 7 | import torch.utils.data as data 8 | from PIL import Image 9 | import torchvision.transforms as transforms 10 | from abc import ABC, abstractmethod 11 | 12 | 13 | class BaseDataset(data.Dataset, ABC): 14 | """This class is an abstract base class (ABC) for datasets. 15 | 16 | To create a subclass, you need to implement the following four functions: 17 | -- <__init__>: initialize the class, first call BaseDataset.__init__(self, opt). 18 | -- <__len__>: return the size of dataset. 19 | -- <__getitem__>: get a data point. 20 | -- <modify_commandline_options>: (optionally) add dataset-specific options and set default options. 21 | """ 22 | 23 | def __init__(self, opt): 24 | """Initialize the class; save the options in the class 25 | 26 | Parameters: 27 | opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions 28 | """ 29 | self.opt = opt 30 | self.root = opt.dataroot 31 | 32 | @staticmethod 33 | def modify_commandline_options(parser, is_train): 34 | """Add new dataset-specific options, and rewrite default values for existing options. 35 | 36 | Parameters: 37 | parser -- original option parser 38 | is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options. 39 | 40 | Returns: 41 | the modified parser. 42 | """ 43 | return parser 44 | 45 | @abstractmethod 46 | def __len__(self): 47 | """Return the total number of images in the dataset.""" 48 | return 0 49 | 50 | @abstractmethod 51 | def __getitem__(self, index): 52 | """Return a data point and its metadata information. 53 | 54 | Parameters: 55 | index (int) -- a random integer for data indexing 56 | 57 | Returns: 58 | a dictionary of data with their names. It usually contains the data itself and its metadata information.
59 | """ 60 | pass 61 | 62 | 63 | def get_params(opt, size): 64 | w, h = size 65 | new_h = h 66 | new_w = w 67 | if opt.preprocess == 'resize_and_crop': 68 | new_h = new_w = opt.load_size 69 | elif opt.preprocess == 'scale_width_and_crop': 70 | new_w = opt.load_size 71 | new_h = opt.load_size * h // w 72 | 73 | x = random.randint(0, np.maximum(0, new_w - opt.crop_size)) 74 | y = random.randint(0, np.maximum(0, new_h - opt.crop_size)) 75 | 76 | flip = random.random() > 0.5 77 | 78 | return {'crop_pos': (x, y), 'flip': flip} 79 | 80 | 81 | def get_transform(opt, params=None, grayscale=False, method=Image.BICUBIC, convert=True): 82 | transform_list = [] 83 | if grayscale: 84 | transform_list.append(transforms.Grayscale(1)) 85 | if 'resize' in opt.preprocess: 86 | osize = [opt.load_size, opt.load_size] 87 | transform_list.append(transforms.Resize(osize, method)) 88 | elif 'scale_width' in opt.preprocess: 89 | transform_list.append(transforms.Lambda(lambda img: __scale_width(img, opt.load_size, method))) 90 | 91 | if 'crop' in opt.preprocess: 92 | if params is None: 93 | transform_list.append(transforms.RandomCrop(opt.crop_size)) 94 | else: 95 | transform_list.append(transforms.Lambda(lambda img: __crop(img, params['crop_pos'], opt.crop_size))) 96 | 97 | if opt.preprocess == 'none': 98 | transform_list.append(transforms.Lambda(lambda img: __make_power_2(img, base=4, method=method))) 99 | 100 | if not opt.no_flip: 101 | if params is None: 102 | transform_list.append(transforms.RandomHorizontalFlip()) 103 | elif params['flip']: 104 | transform_list.append(transforms.Lambda(lambda img: __flip(img, params['flip']))) 105 | 106 | if convert: 107 | transform_list += [transforms.ToTensor(), 108 | transforms.Normalize((0.5, 0.5, 0.5), 109 | (0.5, 0.5, 0.5))] 110 | return transforms.Compose(transform_list) 111 | 112 | 113 | def __make_power_2(img, base, method=Image.BICUBIC): 114 | ow, oh = img.size 115 | h = int(round(oh / base) * base) 116 | w = int(round(ow / base) * base) 117 | if (h == oh) and (w == ow): 118 | return img 119 | 120 | __print_size_warning(ow, oh, w, h) 121 | return img.resize((w, h), method) 122 | 123 | 124 | def __scale_width(img, target_width, method=Image.BICUBIC): 125 | ow, oh = img.size 126 | if (ow == target_width): 127 | return img 128 | w = target_width 129 | h = int(target_width * oh / ow) 130 | return img.resize((w, h), method) 131 | 132 | 133 | def __crop(img, pos, size): 134 | ow, oh = img.size 135 | x1, y1 = pos 136 | tw = th = size 137 | if (ow > tw or oh > th): 138 | return img.crop((x1, y1, x1 + tw, y1 + th)) 139 | return img 140 | 141 | 142 | def __flip(img, flip): 143 | if flip: 144 | return img.transpose(Image.FLIP_LEFT_RIGHT) 145 | return img 146 | 147 | 148 | def __print_size_warning(ow, oh, w, h): 149 | """Print warning information about image size (only print once)""" 150 | if not hasattr(__print_size_warning, 'has_printed'): 151 | print("The image size needs to be a multiple of 4. " 152 | "The loaded image size was (%d, %d), so it was adjusted to " 153 | "(%d, %d). This adjustment will be done to all images "
154 | "whose sizes are not multiples of 4" % (ow, oh, w, h)) 155 | __print_size_warning.has_printed = True 156 | -------------------------------------------------------------------------------- /landmark2face/options/base_options.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | from util import util 4 | import torch 5 | import models 6 | import data 7 | 8 | 9 | class BaseOptions(): 10 | """This class defines options used during both training and test time. 11 | 12 | It also implements several helper functions such as parsing, printing, and saving the options. 13 | It also gathers additional options defined in <modify_commandline_options> functions in both dataset class and model class. 14 | """ 15 | 16 | def __init__(self): 17 | """Reset the class; indicates the class hasn't been initialized""" 18 | self.initialized = False 19 | 20 | def initialize(self, parser): 21 | """Define the common options that are used in both training and test.""" 22 | # basic parameters 23 | # parser.add_argument('--dataroot', required=True, help='path to images (should have subfolders trainA, trainB, valA, valB, etc)') 24 | parser.add_argument('--dataroot', type=str, default='./data') 25 | parser.add_argument('--name', type=str, default='man1', help='name of the experiment. It decides where to store samples and models') 26 | parser.add_argument('--gpu_ids', type=str, default='0', help='gpu ids: e.g. 0 0,1,2, 0,2. use -1 for CPU') 27 | parser.add_argument('--checkpoints_dir', type=str, default='./checkpoints', help='models are saved here') 28 | # model parameters 29 | parser.add_argument('--model', type=str, default='l2face', help='chooses which model to use. [cycle_gan | pix2pix | test | colorization]') 30 | parser.add_argument('--input_nc', type=int, default=3, help='# of input image channels: 3 for RGB and 1 for grayscale') 31 | parser.add_argument('--output_nc', type=int, default=3, help='# of output image channels: 3 for RGB and 1 for grayscale') 32 | parser.add_argument('--ngf', type=int, default=64, help='# of gen filters in the last conv layer') 33 | parser.add_argument('--ndf', type=int, default=64, help='# of discrim filters in the first conv layer') 34 | parser.add_argument('--netD', type=str, default='basic', help='specify discriminator architecture [basic | n_layers | pixel]. The basic model is a 70x70 PatchGAN. n_layers allows you to specify the layers in the discriminator')
35 | parser.add_argument('--netG', type=str, default='resnet_9blocks_l2face', help='specify generator architecture [resnet_9blocks | resnet_8blocks | resnet_6blocks | unet_256 | unet_128]') 36 | parser.add_argument('--n_layers_D', type=int, default=3, help='only used if netD==n_layers') 37 | parser.add_argument('--norm', type=str, default='instance', help='instance normalization or batch normalization [instance | batch | none]') 38 | parser.add_argument('--init_type', type=str, default='normal', help='network initialization [normal | xavier | kaiming | orthogonal]') 39 | parser.add_argument('--init_gain', type=float, default=0.02, help='scaling factor for normal, xavier and orthogonal.') 40 | parser.add_argument('--no_dropout', action='store_true', help='no dropout for the generator') 41 | # dataset parameters 42 | parser.add_argument('--dataset_mode', type=str, default='l2face', help='chooses how datasets are loaded.') 43 | parser.add_argument('--img_size', type=int, default=256, help='input image size') 44 | parser.add_argument('--lan_size', type=int, default=1, help='input landmark size') 45 | parser.add_argument('--direction', type=str, default='AtoB', help='AtoB or BtoA') 46 | parser.add_argument('--serial_batches', action='store_true', help='if true, takes images in order to make batches, otherwise takes them randomly') 47 | parser.add_argument('--num_threads', default=4, type=int, help='# threads for loading data') 48 | parser.add_argument('--batch_size', type=int, default=16, help='input batch size') 49 | parser.add_argument('--load_size', type=int, default=286, help='scale images to this size') 50 | parser.add_argument('--crop_size', type=int, default=256, help='then crop to this size') 51 | parser.add_argument('--max_dataset_size', type=int, default=float("inf"), help='Maximum number of samples allowed per dataset. If the dataset directory contains more than max_dataset_size, only a subset is loaded.') 52 | parser.add_argument('--preprocess', type=str, default='resize_and_crop', help='scaling and cropping of images at load time [resize_and_crop | crop | scale_width | scale_width_and_crop | none]') 53 | parser.add_argument('--no_flip', action='store_true', help='if specified, do not flip the images for data augmentation') 54 | parser.add_argument('--display_winsize', type=int, default=256, help='display window size for both visdom and HTML') 55 | # additional parameters 56 | parser.add_argument('--epoch', type=str, default='latest', help='which epoch to load? set to latest to use latest cached model') 57 | parser.add_argument('--load_iter', type=int, default='0', help='which iteration to load? if load_iter > 0, the code will load models by iter_[load_iter]; otherwise, the code will load models by [epoch]') 58 | parser.add_argument('--verbose', action='store_true', help='if specified, print more debugging information') 59 | parser.add_argument('--suffix', default='', type=str, help='customized suffix: opt.name = opt.name + suffix: e.g., {model}_{netG}_size{load_size}') 60 | self.initialized = True 61 | return parser 62 | 63 | def gather_options(self): 64 | """Initialize our parser with basic options (only once). 65 | Add additional model-specific and dataset-specific options. 66 | These options are defined in the <modify_commandline_options> function 67 | in model and dataset classes.
68 | """ 69 | if not self.initialized: # check if it has been initialized 70 | parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) 71 | parser = self.initialize(parser) 72 | 73 | # get the basic options 74 | opt, _ = parser.parse_known_args() 75 | 76 | # modify model-related parser options 77 | model_name = opt.model 78 | model_option_setter = models.get_option_setter(model_name) 79 | parser = model_option_setter(parser, self.isTrain) 80 | opt, _ = parser.parse_known_args() # parse again with new defaults 81 | 82 | # modify dataset-related parser options 83 | dataset_name = opt.dataset_mode 84 | dataset_option_setter = data.get_option_setter(dataset_name) 85 | parser = dataset_option_setter(parser, self.isTrain) 86 | 87 | # save and return the parser 88 | self.parser = parser 89 | return parser.parse_args() 90 | 91 | def print_options(self, opt): 92 | """Print and save options 93 | 94 | It will print both current options and default values(if different). 95 | It will save options into a text file / [checkpoints_dir] / opt.txt 96 | """ 97 | message = '' 98 | message += '----------------- Options ---------------\n' 99 | for k, v in sorted(vars(opt).items()): 100 | comment = '' 101 | default = self.parser.get_default(k) 102 | if v != default: 103 | comment = '\t[default: %s]' % str(default) 104 | message += '{:>25}: {:<30}{}\n'.format(str(k), str(v), comment) 105 | message += '----------------- End -------------------' 106 | print(message) 107 | 108 | # save to the disk 109 | expr_dir = os.path.join(opt.checkpoints_dir, opt.name) 110 | util.mkdirs(expr_dir) 111 | file_name = os.path.join(expr_dir, 'opt.txt') 112 | with open(file_name, 'wt') as opt_file: 113 | opt_file.write(message) 114 | opt_file.write('\n') 115 | 116 | def parse(self): 117 | """Parse our options, create checkpoints directory suffix, and set up gpu device.""" 118 | opt = self.gather_options() 119 | opt.isTrain = self.isTrain # train or test 120 | 121 | # process opt.suffix 122 | if opt.suffix: 123 | suffix = ('_' + opt.suffix.format(**vars(opt))) if opt.suffix != '' else '' 124 | opt.name = opt.name + suffix 125 | 126 | self.print_options(opt) 127 | 128 | # set gpu ids 129 | str_ids = opt.gpu_ids.split(',') 130 | opt.gpu_ids = [] 131 | for str_id in str_ids: 132 | id = int(str_id) 133 | if id >= 0: 134 | opt.gpu_ids.append(id) 135 | if len(opt.gpu_ids) > 0: 136 | torch.cuda.set_device(opt.gpu_ids[0]) 137 | 138 | self.opt = opt 139 | return self.opt 140 | -------------------------------------------------------------------------------- /audio2landmark/APBGAN.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | import os 4 | import time 5 | import logging 6 | 7 | import torch 8 | import torch.nn as nn 9 | 10 | from APBNet import * 11 | from utils import * 12 | from loss import * 13 | 14 | 15 | 16 | class GANModel(): 17 | 18 | def __init__(self, opt, logger): 19 | 20 | self.gpus = opt.gpus[0] 21 | self.isTrain = opt.isTrain 22 | self.lr = opt.lr 23 | self.every = opt.every 24 | self.epoch = 0 25 | self.best_loss = float("inf") 26 | self.idt_name = opt.idt_name 27 | self.logdir = opt.logdir 28 | self.logger = logger 29 | # loss 30 | self.criterionGAN = GANLoss(gan_mode='mse').cuda() 31 | self.criterionL1 = nn.L1Loss() 32 | # G 33 | self.netG = APBNet() 34 | self.netG.apply(weight_init) 35 | if opt.resume: 36 | checkpoint = torch.load('{}/{}.pth'.format(self.logdir, opt.resume_epoch if opt.resume_epoch else 
'{}_best'.format(self.idt_name))) 37 | self.netG.load_state_dict(checkpoint['net_G']) 38 | self.epoch = checkpoint['epoch'] 39 | self.netG.cuda() 40 | # D 41 | if self.isTrain: # define discriminators 42 | self.netD = Discriminator() 43 | self.netD.apply(weight_init) 44 | if opt.resume: 45 | self.netD.load_state_dict(checkpoint['net_D']) 46 | self.netD.cuda() 47 | 48 | self.optimizer_G = torch.optim.Adam(self.netG.parameters(), lr=self.lr, betas=(0.99, 0.999)) 49 | self.optimizer_D = torch.optim.Adam(self.netD.parameters(), lr=self.lr, betas=(0.99, 0.999)) 50 | 51 | def train(self): 52 | self.isTrain = True 53 | 54 | def eval(self): 55 | self.isTrain = False 56 | 57 | def reset(self): 58 | self.loss_log_L1 = 0 59 | self.loss_log_G_A = 0 60 | 61 | self.loss_log_D_A_F = 0 62 | self.loss_log_D_A_T = 0 63 | 64 | def test_draw(self, dataloader): 65 | def drawCircle(img, shape, radius=1, color=(255, 255, 255), thickness=1): 66 | for i in range(len(shape) // 2): 67 | img = cv2.circle(img, (int(shape[2 * i]), int(shape[2 * i + 1])), radius, color, thickness) 68 | return img 69 | 70 | def drawArrow(img, shape1, shape2, ): 71 | for i in range(len(shape1) // 2): 72 | point1 = (int(shape1[2 * i]), int(shape1[2 * i + 1])) 73 | point2 = (int(shape1[2 * i] + shape2[2 * i]), int(shape1[2 * i + 1] + shape2[2 * i + 1])) 74 | img = cv2.circle(img, point2, radius=6, color=(0, 0, 255), thickness=2) 75 | img = cv2.line(img, point1, point2, (255, 255, 255), thickness=2) 76 | return img 77 | 78 | root = self.logdir 79 | s_pathA = '{}/resultA'.format(root) 80 | s_pathB = '{}/resultB'.format(root) 81 | if not os.path.exists(s_pathA): 82 | os.mkdir(s_pathA) 83 | if not os.path.exists(s_pathB): 84 | os.mkdir(s_pathB) 85 | with torch.no_grad(): 86 | for batch_idx, data in enumerate(dataloader): 87 | self.set_input(data) 88 | self.forward() 89 | img_size = 256 90 | img_template = np.zeros((img_size, img_size, 3)) 91 | img_fake_A1 = drawCircle(img_template.copy(), self.fake_A.squeeze(0).data, radius=1, 92 | color=(255, 255, 255), thickness=2) 93 | img_A1 = drawCircle(img_template.copy(), self.land_A1.squeeze(0).data, radius=1, 94 | color=(255, 255, 255), thickness=2) 95 | img_fake_B1 = drawCircle(img_template.copy(), self.fake_B.squeeze(0).data, radius=1, 96 | color=(255, 255, 255), thickness=2) 97 | img_B1 = drawCircle(img_template.copy(), self.land_B1.squeeze(0).data, radius=1, 98 | color=(255, 255, 255), thickness=2) 99 | 100 | img_compareA = np.concatenate([img_template[:, :, 0][:, :, np.newaxis], img_fake_A1[:, :, 0][:, :, np.newaxis], 101 | img_A1[:, :, 0][:, :, np.newaxis]], axis=2) 102 | img_compareB = np.concatenate([img_template[:, :, 0][:, :, np.newaxis], img_fake_B1[:, :, 0][:, :, np.newaxis], 103 | img_A1[:, :, 2][:, :, np.newaxis]], axis=2) 104 | cv2.imwrite('{}/{}.jpg'.format(s_pathA, batch_idx), img_compareA) 105 | cv2.imwrite('{}/{}.jpg'.format(s_pathB, batch_idx), img_compareB) 106 | print('\r{}'.format(batch_idx + 1), end='') 107 | 108 | def run_train(self, dataloader, epoch=None): 109 | self.epoch += 1 110 | if epoch: 111 | self.epoch = epoch 112 | self.reset() 113 | adjust_learning_rate(self.optimizer_G, self.lr, self.epoch, every=self.every) 114 | adjust_learning_rate(self.optimizer_D, self.lr, self.epoch, every=self.every) 115 | for batch_idx, train_data in enumerate(dataloader): 116 | self.batch_idx = batch_idx + 1 117 | self.set_input(train_data) 118 | self.optimize_parameters() 119 | log_string = 'train\t -> ' 120 | log_string += 'epoch {:>3} '.format(self.epoch) 121 | log_string += 'batch 
{:>4} '.format(batch_idx + 1) 122 | log_string += '|loss_L1 {:.5f}'.format(self.loss_log_L1 / (batch_idx + 1)) 123 | log_string += '|loss_G_A {:.5f}'.format(self.loss_log_G_A / (batch_idx + 1)) 124 | log_string += '|loss_D_A_F {:.5f}'.format(self.loss_log_D_A_F / (batch_idx + 1)) 125 | log_string += '|loss_D_A_T {:.5f}'.format(self.loss_log_D_A_T / (batch_idx + 1)) 126 | print('\r' + log_string, end='') 127 | print('\r', end='') 128 | self.logger.info(log_string) 129 | 130 | def run_test(self, dataloader, epoch=None): 131 | if epoch: 132 | self.epoch = epoch 133 | self.reset() 134 | for batch_idx, test_data in enumerate(dataloader): 135 | self.batch_idx = batch_idx + 1 136 | self.set_input(test_data) 137 | self.evaluate_loss() 138 | log_string = 'test\t -> ' 139 | log_string += 'epoch {:>3} '.format(self.epoch) 140 | log_string += 'batch {:>4} '.format(batch_idx + 1) 141 | log_string += '|loss_L1 {:.5f}'.format(self.loss_log_L1 / (batch_idx + 1)) 142 | log_string += '|loss_G_A {:.5f}'.format(self.loss_log_G_A / (batch_idx + 1)) 143 | log_string += '|loss_D_A_F {:.5f}'.format(self.loss_log_D_A_F / (batch_idx + 1)) 144 | log_string += '|loss_D_A_T {:.5f}'.format(self.loss_log_D_A_T / (batch_idx + 1)) 145 | print('\r'+log_string, end='') 146 | print('\r', end='') 147 | self.logger.info(log_string) 148 | if self.loss_log_L1 / self.batch_idx < self.best_loss and not self.isTrain: 149 | self.best_loss = self.loss_log_L1 / self.batch_idx 150 | self.logger.info('save_best {:.5f}'.format(self.best_loss)) 151 | self.save(mode='best') 152 | if self.epoch % 50 == 0: 153 | self.logger.info('save_epoch {:d}'.format(self.epoch)) 154 | self.save(mode=self.epoch) 155 | 156 | def set_input(self, training_data): 157 | self.audio_feature_A1, self.pose_A1, self.eye_A1 = training_data[0][0].to(self.gpus),\ 158 | training_data[0][1].to(self.gpus),\ 159 | training_data[0][2].to(self.gpus) 160 | self.land_A1, self.land_A2 = training_data[1][0].to(self.gpus), \ 161 | training_data[1][1].to(self.gpus) 162 | 163 | def optimize_parameters(self): 164 | self.forward() 165 | # G 166 | self.set_requires_grad([self.netD], False) 167 | self.optimizer_G.zero_grad() 168 | self.backward_G() 169 | self.optimizer_G.step() 170 | # D 171 | if self.batch_idx % 1 == 0: 172 | self.set_requires_grad([self.netD], True) 173 | self.optimizer_D.zero_grad() 174 | self.backward_D() 175 | self.optimizer_D.step() 176 | 177 | def evaluate_loss(self): 178 | self.forward() 179 | # G 180 | self.loss_L1 = self.criterionL1(self.fake_A, self.land_A1) 181 | self.loss_G_A = self.criterionGAN(self.netD(self.fake_A, self.land_A2), True) 182 | self.loss_log_L1 += self.loss_L1.item() 183 | self.loss_log_G_A += self.loss_G_A.item() 184 | # D 185 | loss_D_A_F = self.criterionGAN(self.netD(self.fake_A.detach(), self.land_A2.detach()), False) 186 | loss_D_A_T = self.criterionGAN(self.netD(self.land_A1.detach(), self.land_A2.detach()), True) 187 | self.loss_log_D_A_F += loss_D_A_F.item() 188 | self.loss_log_D_A_T += loss_D_A_T.item() 189 | 190 | def forward(self): 191 | self.fake_A = self.netG(self.audio_feature_A1, self.pose_A1, self.eye_A1) 192 | 193 | def set_requires_grad(self, nets, requires_grad=False): 194 | if not isinstance(nets, list): 195 | nets = [nets] 196 | for net in nets: 197 | if net is not None: 198 | for param in net.parameters(): 199 | param.requires_grad = requires_grad 200 | 201 | def backward_G(self): 202 | lambda_L1 = 100 203 | lambda_gan = 0.1 204 | 205 | self.loss_L1 = self.criterionL1(self.fake_A, self.land_A1) 206 | self.loss_G_A = 
self.criterionGAN(self.netD(self.fake_A, self.land_A2), True) 207 | 208 | self.loss_G = self.loss_L1 * lambda_L1 + self.loss_G_A * lambda_gan 209 | self.loss_G.backward() 210 | # log 211 | self.loss_log_L1 += self.loss_L1.item() 212 | self.loss_log_G_A += self.loss_G_A.item() 213 | 214 | def backward_D(self): 215 | lambda_D = 0.1 216 | loss_D_A_F = self.criterionGAN(self.netD(self.fake_A.detach(), self.land_A2.detach()), False) 217 | loss_D_A_T = self.criterionGAN(self.netD(self.land_A1.detach(), self.land_A2.detach()), True) 218 | # Combined loss and calculate gradients 219 | loss_D = (loss_D_A_F + loss_D_A_T) * 0.5 * lambda_D 220 | loss_D.backward() 221 | # log 222 | self.loss_log_D_A_F += loss_D_A_F.item() 223 | self.loss_log_D_A_T += loss_D_A_T.item() 224 | 225 | def save(self, mode=None): 226 | state = { 227 | 'net_G': self.netG.state_dict(), 228 | 'net_D': self.netD.state_dict(), 229 | 'epoch': self.epoch, 230 | } 231 | torch.save(state, '{}/{}.pth'.format(self.logdir, '{}_{}'.format(self.idt_name, mode if mode else self.epoch))) 232 | -------------------------------------------------------------------------------- /landmark2face/models/base_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from collections import OrderedDict 4 | from abc import ABC, abstractmethod 5 | from . import networks 6 | 7 | 8 | class BaseModel(ABC): 9 | """This class is an abstract base class (ABC) for models. 10 | To create a subclass, you need to implement the following five functions: 11 | -- <__init__>: initialize the class; first call BaseModel.__init__(self, opt). 12 | -- <set_input>: unpack data from dataset and apply preprocessing. 13 | -- <forward>: produce intermediate results. 14 | -- <optimize_parameters>: calculate losses, gradients, and update network weights. 15 | -- <modify_commandline_options>: (optionally) add model-specific options and set default options. 16 | """ 17 | 18 | def __init__(self, opt): 19 | """Initialize the BaseModel class. 20 | 21 | Parameters: 22 | opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions 23 | 24 | When creating your custom class, you need to implement your own initialization. 25 | In this function, you should first call <BaseModel.__init__(self, opt)>. 26 | Then, you need to define four lists: 27 | -- self.loss_names (str list): specify the training losses that you want to plot and save. 28 | -- self.model_names (str list): define networks used in our training. 29 | -- self.visual_names (str list): specify the images that you want to display and save. 30 | -- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an example. 31 | """ 32 | self.opt = opt 33 | self.gpu_ids = opt.gpu_ids 34 | self.isTrain = opt.isTrain 35 | self.device = torch.device('cuda:{}'.format(self.gpu_ids[0])) if self.gpu_ids else torch.device('cpu') # get device name: CPU or GPU 36 | self.save_dir = os.path.join(opt.checkpoints_dir, opt.name) # save all the checkpoints to save_dir 37 | if opt.preprocess != 'scale_width': # with [scale_width], input images might have different sizes, which hurts the performance of cudnn.benchmark.
38 | torch.backends.cudnn.benchmark = True 39 | self.loss_names = [] 40 | self.model_names = [] 41 | self.visual_names = [] 42 | self.optimizers = [] 43 | self.image_paths = [] 44 | self.metric = None # used for learning rate policy 'plateau' 45 | 46 | @staticmethod 47 | def modify_commandline_options(parser, is_train): 48 | """Add new model-specific options, and rewrite default values for existing options. 49 | 50 | Parameters: 51 | parser -- original option parser 52 | is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options. 53 | 54 | Returns: 55 | the modified parser. 56 | """ 57 | return parser 58 | 59 | @abstractmethod 60 | def set_input(self, input): 61 | """Unpack input data from the dataloader and perform necessary pre-processing steps. 62 | 63 | Parameters: 64 | input (dict): includes the data itself and its metadata information. 65 | """ 66 | pass 67 | 68 | @abstractmethod 69 | def forward(self): 70 | """Run forward pass; called by both functions <optimize_parameters> and <test>.""" 71 | pass 72 | 73 | @abstractmethod 74 | def optimize_parameters(self): 75 | """Calculate losses, gradients, and update network weights; called in every training iteration""" 76 | pass 77 | 78 | def setup(self, opt): 79 | """Load and print networks; create schedulers 80 | 81 | Parameters: 82 | opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions 83 | """ 84 | if self.isTrain: 85 | self.schedulers = [networks.get_scheduler(optimizer, opt) for optimizer in self.optimizers] 86 | if not self.isTrain or opt.continue_train: 87 | load_suffix = 'iter_%d' % opt.load_iter if opt.load_iter > 0 else opt.epoch 88 | self.load_networks(load_suffix) 89 | self.print_networks(opt.verbose) 90 | 91 | def eval(self): 92 | """Make models eval mode during test time""" 93 | for name in self.model_names: 94 | if isinstance(name, str): 95 | net = getattr(self, 'net' + name) 96 | net.eval() 97 | 98 | def test(self): 99 | """Forward function used in test time. 100 | 101 | This function wraps <forward> function in no_grad() so we don't save intermediate steps for backprop 102 | It also calls <compute_visuals> to produce additional visualization results 103 | """ 104 | with torch.no_grad(): 105 | self.forward() 106 | self.compute_visuals() 107 | 108 | def compute_visuals(self): 109 | """Calculate additional output images for visdom and HTML visualization""" 110 | pass 111 | 112 | def get_image_paths(self): 113 | """ Return image paths that are used to load current data""" 114 | return self.image_paths 115 | 116 | def update_learning_rate(self): 117 | """Update learning rates for all the networks; called at the end of every epoch""" 118 | for scheduler in self.schedulers: 119 | scheduler.step(self.metric) 120 | lr = self.optimizers[0].param_groups[0]['lr'] 121 | print('learning rate = %.7f' % lr) 122 | 123 | def get_current_visuals(self): 124 | """Return visualization images. train.py will display these images with visdom, and save the images to a HTML""" 125 | visual_ret = OrderedDict() 126 | for name in self.visual_names: 127 | if isinstance(name, str): 128 | visual_ret[name] = getattr(self, name) 129 | return visual_ret 130 | 131 | def get_current_losses(self): 132 | """Return training losses / errors. train.py will print out these errors on console, and save them to a file""" 133 | errors_ret = OrderedDict() 134 | for name in self.loss_names: 135 | if isinstance(name, str): 136 | errors_ret[name] = float(getattr(self, 'loss_' + name)) # float(...) works for both scalar tensor and float number
137 | return errors_ret 138 | 139 | def save_networks(self, epoch): 140 | """Save all the networks to the disk. 141 | 142 | Parameters: 143 | epoch (int) -- current epoch; used in the file name '%s_net_%s.pth' % (epoch, name) 144 | """ 145 | for name in self.model_names: 146 | if isinstance(name, str): 147 | save_filename = '%s_net_%s.pth' % (epoch, name) 148 | save_path = os.path.join(self.save_dir, save_filename) 149 | net = getattr(self, 'net' + name) 150 | 151 | if len(self.gpu_ids) > 0 and torch.cuda.is_available(): 152 | torch.save(net.module.cpu().state_dict(), save_path) 153 | net.cuda(self.gpu_ids[0]) 154 | else: 155 | torch.save(net.cpu().state_dict(), save_path) 156 | 157 | def __patch_instance_norm_state_dict(self, state_dict, module, keys, i=0): 158 | """Fix InstanceNorm checkpoints incompatibility (prior to 0.4)""" 159 | key = keys[i] 160 | if i + 1 == len(keys): # at the end, pointing to a parameter/buffer 161 | if module.__class__.__name__.startswith('InstanceNorm') and \ 162 | (key == 'running_mean' or key == 'running_var'): 163 | if getattr(module, key) is None: 164 | state_dict.pop('.'.join(keys)) 165 | if module.__class__.__name__.startswith('InstanceNorm') and \ 166 | (key == 'num_batches_tracked'): 167 | state_dict.pop('.'.join(keys)) 168 | else: 169 | self.__patch_instance_norm_state_dict(state_dict, getattr(module, key), keys, i + 1) 170 | 171 | def load_networks(self, epoch): 172 | """Load all the networks from the disk. 173 | 174 | Parameters: 175 | epoch (int) -- current epoch; used in the file name '%s_net_%s.pth' % (epoch, name) 176 | """ 177 | for name in self.model_names: 178 | if isinstance(name, str): 179 | load_filename = '%s_net_%s.pth' % (epoch, name) 180 | load_path = os.path.join(self.save_dir, load_filename) 181 | net = getattr(self, 'net' + name) 182 | if isinstance(net, torch.nn.DataParallel): 183 | net = net.module 184 | print('loading the model from %s' % load_path) 185 | # if you are using PyTorch newer than 0.4 (e.g., built from 186 | # GitHub source), you can remove str() on self.device 187 | state_dict = torch.load(load_path, map_location=str(self.device)) 188 | if hasattr(state_dict, '_metadata'): 189 | del state_dict._metadata 190 | # patch InstanceNorm checkpoints prior to 0.4 191 | for key in list(state_dict.keys()): # need to copy keys here because we mutate in loop 192 | self.__patch_instance_norm_state_dict(state_dict, net, key.split('.')) 193 | # print(net) 194 | # print(state_dict.keys()) 195 | net.load_state_dict(state_dict) 196 | 197 | def print_networks(self, verbose): 198 | """Print the total number of parameters in the network and (if verbose) network architecture 199 | 200 | Parameters: 201 | verbose (bool) -- if verbose: print the network architecture 202 | """ 203 | print('---------- Networks initialized -------------') 204 | for name in self.model_names: 205 | if isinstance(name, str): 206 | net = getattr(self, 'net' + name) 207 | num_params = 0 208 | for param in net.parameters(): 209 | num_params += param.numel() 210 | if verbose: 211 | print(net) 212 | print('[Network %s] Total number of parameters : %.3f M' % (name, num_params / 1e6)) 213 | print('-----------------------------------------------') 214 | 215 | def set_requires_grad(self, nets, requires_grad=False): 216 | """Set requires_grad=False for all the networks to avoid unnecessary computations 217 | Parameters: 218 | nets (network list) -- a list of networks 219 | requires_grad (bool) -- whether the networks require gradients or not
220 | """ 221 | if not isinstance(nets, list): 222 | nets = [nets] 223 | for net in nets: 224 | if net is not None: 225 | for param in net.parameters(): 226 | param.requires_grad = requires_grad 227 | -------------------------------------------------------------------------------- /landmark2face/util/visualizer.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import sys 4 | import ntpath 5 | import time 6 | from . import util, html 7 | from subprocess import Popen, PIPE 8 | from scipy.misc import imresize 9 | 10 | if sys.version_info[0] == 2: 11 | VisdomExceptionBase = Exception 12 | else: 13 | VisdomExceptionBase = ConnectionError 14 | 15 | 16 | def save_images(webpage, visuals, image_path, aspect_ratio=1.0, width=256): 17 | """Save images to the disk. 18 | 19 | Parameters: 20 | webpage (the HTML class) -- the HTML webpage class that stores these images (see html.py for more details) 21 | visuals (OrderedDict) -- an ordered dictionary that stores (name, images (either tensor or numpy) ) pairs 22 | image_path (str) -- the string is used to create image paths 23 | aspect_ratio (float) -- the aspect ratio of saved images 24 | width (int) -- the images will be resized to width x width 25 | 26 | This function will save images stored in 'visuals' to the HTML file specified by 'webpage'. 27 | """ 28 | image_dir = webpage.get_image_dir() 29 | short_path = ntpath.basename(image_path[0]) 30 | name = os.path.splitext(short_path)[0] 31 | 32 | webpage.add_header(name) 33 | ims, txts, links = [], [], [] 34 | 35 | for label, im_data in visuals.items(): 36 | im = util.tensor2im(im_data) 37 | image_name = '%s_%s.png' % (name, label) 38 | save_path = os.path.join(image_dir, image_name) 39 | h, w, _ = im.shape 40 | if aspect_ratio > 1.0: 41 | im = imresize(im, (h, int(w * aspect_ratio)), interp='bicubic') 42 | if aspect_ratio < 1.0: 43 | im = imresize(im, (int(h / aspect_ratio), w), interp='bicubic') 44 | util.save_image(im, save_path) 45 | 46 | ims.append(image_name) 47 | txts.append(label) 48 | links.append(image_name) 49 | webpage.add_images(ims, txts, links, width=width) 50 | 51 | 52 | class Visualizer(): 53 | """This class includes several functions that can display/save images and print/save logging information. 54 | 55 | It uses a Python library 'visdom' for display, and a Python library 'dominate' (wrapped in 'HTML') for creating HTML files with images.
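Typical use, mirroring train.py (an illustrative sketch, not extra API):
    visualizer = Visualizer(opt)
    visualizer.reset()  # once per print window
    visualizer.display_current_results(model.get_current_visuals(), epoch, save_result)
    visualizer.print_current_losses(epoch, epoch_iter, losses, t_comp, t_data)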
56 | """ 57 | 58 | def __init__(self, opt): 59 | """Initialize the Visualizer class 60 | 61 | Parameters: 62 | opt -- stores all the experiment flags; needs to be a subclass of BaseOptions 63 | Step 1: Cache the training/test options 64 | Step 2: connect to a visdom server 65 | Step 3: create an HTML object for saving HTML files 66 | Step 4: create a logging file to store training losses 67 | """ 68 | self.opt = opt # cache the option 69 | self.display_id = opt.display_id 70 | self.use_html = opt.isTrain and not opt.no_html 71 | self.win_size = opt.display_winsize 72 | self.name = opt.name 73 | self.port = opt.display_port 74 | self.saved = False 75 | if self.display_id > 0: # connect to a visdom server given <display_port> and <display_server> 76 | import visdom 77 | self.ncols = opt.display_ncols 78 | self.vis = visdom.Visdom(server=opt.display_server, port=opt.display_port, env=opt.display_env) 79 | if not self.vis.check_connection(): 80 | self.create_visdom_connections() 81 | 82 | if self.use_html: # create an HTML object at <checkpoints_dir>/web/; images will be saved under <checkpoints_dir>/web/images/ 83 | self.web_dir = os.path.join(opt.checkpoints_dir, opt.name, 'web') 84 | self.img_dir = os.path.join(self.web_dir, 'images') 85 | print('create web directory %s...' % self.web_dir) 86 | util.mkdirs([self.web_dir, self.img_dir]) 87 | # create a logging file to store training losses 88 | self.log_name = os.path.join(opt.checkpoints_dir, opt.name, 'loss_log.txt') 89 | with open(self.log_name, "a") as log_file: 90 | now = time.strftime("%c") 91 | log_file.write('================ Training Loss (%s) ================\n' % now) 92 | 93 | def reset(self): 94 | """Reset the self.saved status""" 95 | self.saved = False 96 | 97 | def create_visdom_connections(self): 98 | """If the program could not connect to Visdom server, this function will start a new server at port < self.port > """ 99 | cmd = sys.executable + ' -m visdom.server -p %d &>/dev/null &' % self.port 100 | print('\n\nCould not connect to Visdom server. \n Trying to start a server....') 101 | print('Command: %s' % cmd) 102 | Popen(cmd, shell=True, stdout=PIPE, stderr=PIPE) 103 | 104 | def display_current_results(self, visuals, epoch, save_result): 105 | """Display current results on visdom; save current results to an HTML file. 106 | 107 | Parameters: 108 | visuals (OrderedDict) - - dictionary of images to display or save 109 | epoch (int) - - the current epoch 110 | save_result (bool) - - if save the current results to an HTML file 111 | """ 112 | if self.display_id > 0: # show images in the browser using visdom 113 | ncols = self.ncols 114 | if ncols > 0: # show all the images in one visdom panel 115 | ncols = min(ncols, len(visuals)) 116 | h, w = next(iter(visuals.values())).shape[:2] 117 | table_css = """<style> 118 | table {border-collapse: separate; border-spacing: 4px; white-space: nowrap; text-align: center} 119 | table td {width: %dpx; height: %dpx; padding: 4px; outline: 4px solid black} 120 | </style>""" % (w, h) # create a table css 121 | # create a table of images.
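# The block below packs images row-major into an ncols-wide visdom grid,
# mirrors each row as a row of labels in an HTML table, and pads the last
# row with white placeholder tiles when len(visuals) is not a multiple of ncols.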
122 | title = self.name 123 | label_html = '' 124 | label_html_row = '' 125 | images = [] 126 | idx = 0 127 | for label, image in visuals.items(): 128 | image_numpy = util.tensor2im(image) 129 | label_html_row += '<td>%s</td>' % label 130 | images.append(image_numpy.transpose([2, 0, 1])) 131 | idx += 1 132 | if idx % ncols == 0: 133 | label_html += '<tr>%s</tr>' % label_html_row 134 | label_html_row = '' 135 | white_image = np.ones_like(image_numpy.transpose([2, 0, 1])) * 255 136 | while idx % ncols != 0: 137 | images.append(white_image) 138 | label_html_row += '<td></td>' 139 | idx += 1 140 | if label_html_row != '': 141 | label_html += '<tr>%s</tr>' % label_html_row 142 | try: 143 | self.vis.images(images, nrow=ncols, win=self.display_id + 1, 144 | padding=2, opts=dict(title=title + ' images')) 145 | label_html = '<table>%s</table>' % label_html 146 | self.vis.text(table_css + label_html, win=self.display_id + 2, 147 | opts=dict(title=title + ' labels')) 148 | except VisdomExceptionBase: 149 | self.create_visdom_connections() 150 | 151 | else: # show each image in a separate visdom panel; 152 | idx = 1 153 | try: 154 | for label, image in visuals.items(): 155 | image_numpy = util.tensor2im(image) 156 | self.vis.image(image_numpy.transpose([2, 0, 1]), opts=dict(title=label), 157 | win=self.display_id + idx) 158 | idx += 1 159 | except VisdomExceptionBase: 160 | self.create_visdom_connections() 161 | 162 | if self.use_html and (save_result or not self.saved): # save images to an HTML file if they haven't been saved. 163 | self.saved = True 164 | # save images to the disk 165 | for label, image in visuals.items(): 166 | image_numpy = util.tensor2im(image) 167 | img_path = os.path.join(self.img_dir, 'epoch%.3d_%s.png' % (epoch, label)) 168 | util.save_image(image_numpy, img_path) 169 | 170 | # update website 171 | webpage = html.HTML(self.web_dir, 'Experiment name = %s' % self.name, refresh=10) 172 | for n in range(epoch, 0, -1): 173 | webpage.add_header('epoch [%d]' % n) 174 | ims, txts, links = [], [], [] 175 | 176 | for label, image_numpy in visuals.items(): 177 | image_numpy = util.tensor2im(image_numpy) 178 | img_path = 'epoch%.3d_%s.png' % (n, label) 179 | ims.append(img_path) 180 | txts.append(label) 181 | links.append(img_path) 182 | webpage.add_images(ims, txts, links, width=self.win_size) 183 | webpage.save() 184 | 185 | def plot_current_losses(self, epoch, counter_ratio, losses): 186 | """display the current losses on visdom display: dictionary of error labels and values 187 | 188 | Parameters: 189 | epoch (int) -- current epoch 190 | counter_ratio (float) -- progress (percentage) in the current epoch, between 0 to 1 191 | losses (OrderedDict) -- training losses stored in the format of (name, float) pairs 192 | """ 193 | if not hasattr(self, 'plot_data'): 194 | self.plot_data = {'X': [], 'Y': [], 'legend': list(losses.keys())} 195 | self.plot_data['X'].append(epoch + counter_ratio) 196 | self.plot_data['Y'].append([losses[k] for k in self.plot_data['legend']]) 197 | try: 198 | self.vis.line( 199 | X=np.stack([np.array(self.plot_data['X'])] * len(self.plot_data['legend']), 1), 200 | Y=np.array(self.plot_data['Y']), 201 | opts={ 202 | 'title': self.name + ' loss over time', 203 | 'legend': self.plot_data['legend'], 204 | 'xlabel': 'epoch', 205 | 'ylabel': 'loss'}, 206 | win=self.display_id) 207 | except VisdomExceptionBase: 208 | self.create_visdom_connections() 209 | 210 | # losses: same format as |losses| of plot_current_losses 211 | def print_current_losses(self, epoch, iters, losses, t_comp, t_data): 212 | """print current losses on console; also save the losses to the disk 213 | 214 | Parameters: 215 | epoch (int) -- current epoch 216 | iters (int) -- current training iteration during this epoch (reset to 0 at the end of every epoch) 217 | losses (OrderedDict) -- training losses stored in the format of (name, float) pairs 218 | t_comp (float) -- computational time per data point (normalized by batch_size) 219 | t_data (float) -- data loading time per data point (normalized by batch_size) 220 | """ 221 | message = '(epoch: %d, iters: %d, time: %.3f, data: %.3f) ' % (epoch, iters, t_comp, t_data) 222 | for k, v in losses.items(): 223 | message += '%s: %.3f ' % (k, v) 224 | 225 | print(message) # print the message 226 | with open(self.log_name, "a") as log_file: 227 | log_file.write('%s\n' % message) # save the message
228 | -------------------------------------------------------------------------------- /landmark2face/models/networks.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn import init 4 | import functools 5 | from torch.optim import lr_scheduler 6 | 7 | from .networks_l2face import * 8 | ############################################################################### 9 | # Helper Functions 10 | ############################################################################### 11 | def get_norm_layer(norm_type='instance'): 12 | """Return a normalization layer 13 | 14 | Parameters: 15 | norm_type (str) -- the name of the normalization layer: batch | instance | none 16 | 17 | For BatchNorm, we use learnable affine parameters and track running statistics (mean/stddev). 18 | For InstanceNorm, we do not use learnable affine parameters. We do not track running statistics. 19 | """ 20 | if norm_type == 'batch': 21 | norm_layer = functools.partial(nn.BatchNorm2d, affine=True, track_running_stats=True) 22 | elif norm_type == 'instance': 23 | norm_layer = functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=False) 24 | elif norm_type == 'none': 25 | norm_layer = None 26 | else: 27 | raise NotImplementedError('normalization layer [%s] is not found' % norm_type) 28 | return norm_layer 29 | 30 | 31 | def get_scheduler(optimizer, opt): 32 | """Return a learning rate scheduler 33 | 34 | Parameters: 35 | optimizer -- the optimizer of the network 36 | opt (option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions. 37 | opt.lr_policy is the name of learning rate policy: linear | step | plateau | cosine 38 | 39 | For 'linear', we keep the same learning rate for the first <opt.niter> epochs 40 | and linearly decay the rate to zero over the next <opt.niter_decay> epochs. 41 | For other schedulers (step, plateau, and cosine), we use the default PyTorch schedulers. 42 | See https://pytorch.org/docs/stable/optim.html for more details. 43 | """ 44 | if opt.lr_policy == 'linear': 45 | def lambda_rule(epoch): 46 | lr_l = 1.0 - max(0, epoch + opt.epoch_count - opt.niter) / float(opt.niter_decay + 1) 47 | return lr_l 48 | scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule) 49 | elif opt.lr_policy == 'step': 50 | scheduler = lr_scheduler.StepLR(optimizer, step_size=opt.lr_decay_iters, gamma=0.1) 51 | elif opt.lr_policy == 'plateau': 52 | scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, threshold=0.01, patience=5) 53 | elif opt.lr_policy == 'cosine': 54 | scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=opt.niter, eta_min=0) 55 | else: 56 | raise NotImplementedError('learning rate policy [%s] is not implemented' % opt.lr_policy) 57 | return scheduler 58 | 59 | 60 | def init_weights(net, init_type='normal', init_gain=0.02): 61 | """Initialize network weights. 62 | 63 | Parameters: 64 | net (network) -- network to be initialized 65 | init_type (str) -- the name of an initialization method: normal | xavier | kaiming | orthogonal 66 | init_gain (float) -- scaling factor for normal, xavier and orthogonal. 67 | 68 | We use 'normal' in the original pix2pix and CycleGAN paper. But xavier and kaiming might 69 | work better for some applications. Feel free to try yourself.
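Example (illustrative only; any module hierarchy works the same way, since
net.apply() visits every submodule):
    net = nn.Conv2d(3, 64, kernel_size=3)
    init_weights(net, init_type='xavier', init_gain=0.02)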
70 | """ 71 | def init_func(m): # define the initialization function 72 | classname = m.__class__.__name__ 73 | if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1): 74 | if init_type == 'normal': 75 | init.normal_(m.weight.data, 0.0, init_gain) 76 | elif init_type == 'xavier': 77 | init.xavier_normal_(m.weight.data, gain=init_gain) 78 | elif init_type == 'kaiming': 79 | init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') 80 | elif init_type == 'orthogonal': 81 | init.orthogonal_(m.weight.data, gain=init_gain) 82 | else: 83 | raise NotImplementedError('initialization method [%s] is not implemented' % init_type) 84 | if hasattr(m, 'bias') and m.bias is not None: 85 | init.constant_(m.bias.data, 0.0) 86 | elif classname.find('BatchNorm2d') != -1: # BatchNorm Layer's weight is not a matrix; only normal distribution applies. 87 | init.normal_(m.weight.data, 1.0, init_gain) 88 | init.constant_(m.bias.data, 0.0) 89 | 90 | print('initialize network with %s' % init_type) 91 | net.apply(init_func) # apply the initialization function 92 | 93 | 94 | def init_net(net, init_type='normal', init_gain=0.02, gpu_ids=[]): 95 | """Initialize a network: 1. register CPU/GPU device (with multi-GPU support); 2. initialize the network weights 96 | Parameters: 97 | net (network) -- the network to be initialized 98 | init_type (str) -- the name of an initialization method: normal | xavier | kaiming | orthogonal 99 | gain (float) -- scaling factor for normal, xavier and orthogonal. 100 | gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2 101 | 102 | Return an initialized network. 103 | """ 104 | if len(gpu_ids) > 0: 105 | assert(torch.cuda.is_available()) 106 | net.to(gpu_ids[0]) 107 | net = torch.nn.DataParallel(net, gpu_ids) # multi-GPUs 108 | init_weights(net, init_type, init_gain=init_gain) 109 | return net 110 | 111 | 112 | def define_G(input_nc, output_nc, ngf, netG, norm='batch', use_dropout=False, init_type='normal', init_gain=0.02, gpu_ids=[]): 113 | """Create a generator 114 | 115 | Parameters: 116 | input_nc (int) -- the number of channels in input images 117 | output_nc (int) -- the number of channels in output images 118 | ngf (int) -- the number of filters in the last conv layer 119 | netG (str) -- the architecture's name: resnet_9blocks | resnet_6blocks | unet_256 | unet_128 120 | norm (str) -- the name of normalization layers used in the network: batch | instance | none 121 | use_dropout (bool) -- if use dropout layers. 122 | init_type (str) -- the name of our initialization method. 123 | init_gain (float) -- scaling factor for normal, xavier and orthogonal. 124 | gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2 125 | 126 | Returns a generator 127 | 128 | Our current implementation provides two types of generators: 129 | U-Net: [unet_128] (for 128x128 input images) and [unet_256] (for 256x256 input images) 130 | The original U-Net paper: https://arxiv.org/abs/1505.04597 131 | 132 | Resnet-based generator: [resnet_6blocks] (with 6 Resnet blocks) and [resnet_9blocks] (with 9 Resnet blocks) 133 | Resnet-based generator consists of several Resnet blocks between a few downsampling/upsampling operations. 134 | We adapt Torch code from Justin Johnson's neural style transfer project (https://github.com/jcjohnson/fast-neural-style). 135 | 136 | 137 | The generator has been initialized by . It uses RELU for non-linearity. 
138 | """ 139 | net = None 140 | norm_layer = get_norm_layer(norm_type=norm) 141 | 142 | if netG == 'resnet_9blocks': 143 | net = ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=9) 144 | elif netG == 'resnet_9blocks_l2face': 145 | net = ResnetL2FaceGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=9) 146 | elif netG == 'resnet_6blocks': 147 | net = ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=6) 148 | elif netG == 'unet_128': 149 | net = UnetGenerator(input_nc, output_nc, 7, ngf, norm_layer=norm_layer, use_dropout=use_dropout) 150 | elif netG == 'unet_256': 151 | net = UnetGenerator(input_nc, output_nc, 8, ngf, norm_layer=norm_layer, use_dropout=use_dropout) 152 | else: 153 | raise NotImplementedError('Generator model name [%s] is not recognized' % netG) 154 | return init_net(net, init_type, init_gain, gpu_ids) 155 | 156 | 157 | def define_D(input_nc, ndf, netD, n_layers_D=3, norm='batch', init_type='normal', init_gain=0.02, gpu_ids=[]): 158 | """Create a discriminator 159 | 160 | Parameters: 161 | input_nc (int) -- the number of channels in input images 162 | ndf (int) -- the number of filters in the first conv layer 163 | netD (str) -- the architecture's name: basic | n_layers | pixel 164 | n_layers_D (int) -- the number of conv layers in the discriminator; effective when netD=='n_layers' 165 | norm (str) -- the type of normalization layers used in the network. 166 | init_type (str) -- the name of the initialization method. 167 | init_gain (float) -- scaling factor for normal, xavier and orthogonal. 168 | gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2 169 | 170 | Returns a discriminator 171 | 172 | Our current implementation provides three types of discriminators: 173 | [basic]: 'PatchGAN' classifier described in the original pix2pix paper. 174 | It can classify whether 70×70 overlapping patches are real or fake. 175 | Such a patch-level discriminator architecture has fewer parameters 176 | than a full-image discriminator and can work on arbitrarily-sized images 177 | in a fully convolutional fashion. 178 | 179 | [n_layers]: With this mode, you cna specify the number of conv layers in the discriminator 180 | with the parameter (default=3 as used in [basic] (PatchGAN).) 181 | 182 | [pixel]: 1x1 PixelGAN discriminator can classify whether a pixel is real or not. 183 | It encourages greater color diversity but has no effect on spatial statistics. 184 | 185 | The discriminator has been initialized by . It uses Leakly RELU for non-linearity. 
186 | """ 187 | net = None 188 | norm_layer = get_norm_layer(norm_type=norm) 189 | 190 | if netD == 'basic': # default PatchGAN classifier 191 | net = NLayerDiscriminator(input_nc, ndf, n_layers=3, norm_layer=norm_layer) 192 | elif netD == 'n_layers': # more options 193 | net = NLayerDiscriminator(input_nc, ndf, n_layers_D, norm_layer=norm_layer) 194 | elif netD == 'pixel': # classify if each pixel is real or fake 195 | net = PixelDiscriminator(input_nc, ndf, norm_layer=norm_layer) 196 | else: 197 | raise NotImplementedError('Discriminator model name [%s] is not recognized' % net) 198 | return init_net(net, init_type, init_gain, gpu_ids) 199 | 200 | 201 | ############################################################################## 202 | # Classes 203 | ############################################################################## 204 | class GANLoss(nn.Module): 205 | """Define different GAN objectives. 206 | 207 | The GANLoss class abstracts away the need to create the target label tensor 208 | that has the same size as the input. 209 | """ 210 | 211 | def __init__(self, gan_mode, target_real_label=1.0, target_fake_label=0.0): 212 | """ Initialize the GANLoss class. 213 | 214 | Parameters: 215 | gan_mode (str) - - the type of GAN objective. It currently supports vanilla, lsgan, and wgangp. 216 | target_real_label (bool) - - label for a real image 217 | target_fake_label (bool) - - label of a fake image 218 | 219 | Note: Do not use sigmoid as the last layer of Discriminator. 220 | LSGAN needs no sigmoid. vanilla GANs will handle it with BCEWithLogitsLoss. 221 | """ 222 | super(GANLoss, self).__init__() 223 | self.register_buffer('real_label', torch.tensor(target_real_label)) 224 | self.register_buffer('fake_label', torch.tensor(target_fake_label)) 225 | self.gan_mode = gan_mode 226 | if gan_mode == 'lsgan': 227 | self.loss = nn.MSELoss() 228 | elif gan_mode == 'vanilla': 229 | self.loss = nn.BCEWithLogitsLoss() 230 | elif gan_mode in ['wgangp']: 231 | self.loss = None 232 | else: 233 | raise NotImplementedError('gan mode %s not implemented' % gan_mode) 234 | 235 | def get_target_tensor(self, prediction, target_is_real): 236 | """Create label tensors with the same size as the input. 237 | 238 | Parameters: 239 | prediction (tensor) - - tpyically the prediction from a discriminator 240 | target_is_real (bool) - - if the ground truth label is for real images or fake images 241 | 242 | Returns: 243 | A label tensor filled with ground truth label, and with the size of the input 244 | """ 245 | 246 | if target_is_real: 247 | target_tensor = self.real_label 248 | else: 249 | target_tensor = self.fake_label 250 | return target_tensor.expand_as(prediction) 251 | 252 | def __call__(self, prediction, target_is_real): 253 | """Calculate loss given Discriminator's output and grount truth labels. 254 | 255 | Parameters: 256 | prediction (tensor) - - tpyically the prediction output from a discriminator 257 | target_is_real (bool) - - if the ground truth label is for real images or fake images 258 | 259 | Returns: 260 | the calculated loss. 
261 | """ 262 | if self.gan_mode in ['lsgan', 'vanilla']: 263 | target_tensor = self.get_target_tensor(prediction, target_is_real) 264 | loss = self.loss(prediction, target_tensor) 265 | elif self.gan_mode == 'wgangp': 266 | if target_is_real: 267 | loss = -prediction.mean() 268 | else: 269 | loss = prediction.mean() 270 | return loss 271 | 272 | 273 | def cal_gradient_penalty(netD, real_data, fake_data, device, type='mixed', constant=1.0, lambda_gp=10.0): 274 | """Calculate the gradient penalty loss, used in WGAN-GP paper https://arxiv.org/abs/1704.00028 275 | 276 | Arguments: 277 | netD (network) -- discriminator network 278 | real_data (tensor array) -- real images 279 | fake_data (tensor array) -- generated images from the generator 280 | device (str) -- GPU / CPU: from torch.device('cuda:{}'.format(self.gpu_ids[0])) if self.gpu_ids else torch.device('cpu') 281 | type (str) -- if we mix real and fake data or not [real | fake | mixed]. 282 | constant (float) -- the constant used in formula ( | |gradient||_2 - constant)^2 283 | lambda_gp (float) -- weight for this loss 284 | 285 | Returns the gradient penalty loss 286 | """ 287 | if lambda_gp > 0.0: 288 | if type == 'real': # either use real images, fake images, or a linear interpolation of two. 289 | interpolatesv = real_data 290 | elif type == 'fake': 291 | interpolatesv = fake_data 292 | elif type == 'mixed': 293 | alpha = torch.rand(real_data.shape[0], 1) 294 | alpha = alpha.expand(real_data.shape[0], real_data.nelement() // real_data.shape[0]).contiguous().view(*real_data.shape) 295 | alpha = alpha.to(device) 296 | interpolatesv = alpha * real_data + ((1 - alpha) * fake_data) 297 | else: 298 | raise NotImplementedError('{} not implemented'.format(type)) 299 | interpolatesv.requires_grad_(True) 300 | disc_interpolates = netD(interpolatesv) 301 | gradients = torch.autograd.grad(outputs=disc_interpolates, inputs=interpolatesv, 302 | grad_outputs=torch.ones(disc_interpolates.size()).to(device), 303 | create_graph=True, retain_graph=True, only_inputs=True) 304 | gradients = gradients[0].view(real_data.size(0), -1) # flat the data 305 | gradient_penalty = (((gradients + 1e-16).norm(2, dim=1) - constant) ** 2).mean() * lambda_gp # added eps 306 | return gradient_penalty, gradients 307 | else: 308 | return 0.0, None 309 | 310 | 311 | class ResnetGenerator(nn.Module): 312 | """Resnet-based generator that consists of Resnet blocks between a few downsampling/upsampling operations. 
313 | 314 | We adapt Torch code and ideas from Justin Johnson's neural style transfer project (https://github.com/jcjohnson/fast-neural-style) 315 | """ 316 | 317 | def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False, n_blocks=6, padding_type='reflect'): 318 | """Construct a Resnet-based generator 319 | 320 | Parameters: 321 | input_nc (int) -- the number of channels in input images 322 | output_nc (int) -- the number of channels in output images 323 | ngf (int) -- the number of filters in the last conv layer 324 | norm_layer -- normalization layer 325 | use_dropout (bool) -- whether to use dropout layers 326 | n_blocks (int) -- the number of ResNet blocks 327 | padding_type (str) -- the name of padding layer in conv layers: reflect | replicate | zero 328 | """ 329 | assert(n_blocks >= 0) 330 | super(ResnetGenerator, self).__init__() 331 | if type(norm_layer) == functools.partial: 332 | use_bias = norm_layer.func == nn.InstanceNorm2d 333 | else: 334 | use_bias = norm_layer == nn.InstanceNorm2d 335 | 336 | model = [nn.ReflectionPad2d(3), 337 | nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0, bias=use_bias), 338 | norm_layer(ngf), 339 | nn.ReLU(True)] 340 | 341 | n_downsampling = 2 342 | for i in range(n_downsampling): # add downsampling layers 343 | mult = 2 ** i 344 | model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=1, bias=use_bias), 345 | norm_layer(ngf * mult * 2), 346 | nn.ReLU(True)] 347 | 348 | mult = 2 ** n_downsampling 349 | for i in range(n_blocks): # add ResNet blocks 350 | 351 | model += [ResnetBlock(ngf * mult, padding_type=padding_type, norm_layer=norm_layer, use_dropout=use_dropout, use_bias=use_bias)] 352 | 353 | for i in range(n_downsampling): # add upsampling layers 354 | mult = 2 ** (n_downsampling - i) 355 | model += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2), 356 | kernel_size=3, stride=2, 357 | padding=1, output_padding=1, 358 | bias=use_bias), 359 | norm_layer(int(ngf * mult / 2)), 360 | nn.ReLU(True)] 361 | model += [nn.ReflectionPad2d(3)] 362 | model += [nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)] 363 | model += [nn.Tanh()] 364 | 365 | self.model = nn.Sequential(*model) 366 | 367 | def forward(self, input): 368 | """Standard forward""" 369 | return self.model(input) 370 | 371 | 372 | class ResnetBlock(nn.Module): 373 | """Define a Resnet block""" 374 | 375 | def __init__(self, dim, padding_type, norm_layer, use_dropout, use_bias): 376 | """Initialize the Resnet block 377 | 378 | A resnet block is a conv block with skip connections. 379 | We construct a conv block with the build_conv_block function, 380 | and implement skip connections in the forward function. 381 | Original Resnet paper: https://arxiv.org/pdf/1512.03385.pdf 382 | """ 383 | super(ResnetBlock, self).__init__() 384 | self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, use_dropout, use_bias) 385 | 386 | def build_conv_block(self, dim, padding_type, norm_layer, use_dropout, use_bias): 387 | """Construct a convolutional block. 388 | 389 | Parameters: 390 | dim (int) -- the number of channels in the conv layer. 391 | padding_type (str) -- the name of padding layer: reflect | replicate | zero 392 | norm_layer -- normalization layer 393 | use_dropout (bool) -- whether to use dropout layers.
394 | use_bias (bool) -- whether the conv layer uses bias or not 395 | 396 | Returns a conv block (with a conv layer, a normalization layer, and a non-linearity layer (ReLU)) 397 | """ 398 | conv_block = [] 399 | p = 0 400 | if padding_type == 'reflect': 401 | conv_block += [nn.ReflectionPad2d(1)] 402 | elif padding_type == 'replicate': 403 | conv_block += [nn.ReplicationPad2d(1)] 404 | elif padding_type == 'zero': 405 | p = 1 406 | else: 407 | raise NotImplementedError('padding [%s] is not implemented' % padding_type) 408 | 409 | conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias), norm_layer(dim), nn.ReLU(True)] 410 | if use_dropout: 411 | conv_block += [nn.Dropout(0.5)] 412 | 413 | p = 0 414 | if padding_type == 'reflect': 415 | conv_block += [nn.ReflectionPad2d(1)] 416 | elif padding_type == 'replicate': 417 | conv_block += [nn.ReplicationPad2d(1)] 418 | elif padding_type == 'zero': 419 | p = 1 420 | else: 421 | raise NotImplementedError('padding [%s] is not implemented' % padding_type) 422 | conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias), norm_layer(dim)] 423 | 424 | return nn.Sequential(*conv_block) 425 | 426 | def forward(self, x): 427 | """Forward function (with skip connections)""" 428 | out = x + self.conv_block(x) # add skip connections 429 | return out 430 | 431 | 432 | class UnetGenerator(nn.Module): 433 | """Create a Unet-based generator""" 434 | 435 | def __init__(self, input_nc, output_nc, num_downs, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False): 436 | """Construct a Unet generator 437 | Parameters: 438 | input_nc (int) -- the number of channels in input images 439 | output_nc (int) -- the number of channels in output images 440 | num_downs (int) -- the number of downsamplings in the UNet. For example, if |num_downs| == 7, 441 | an image of size 128x128 will become of size 1x1 at the bottleneck 442 | ngf (int) -- the number of filters in the last conv layer 443 | norm_layer -- normalization layer 444 | 445 | We construct the U-Net from the innermost layer to the outermost layer. 446 | It is a recursive process. 447 | """ 448 | super(UnetGenerator, self).__init__() 449 | # construct unet structure 450 | unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=None, norm_layer=norm_layer, innermost=True) # add the innermost layer 451 | for i in range(num_downs - 5): # add intermediate layers with ngf * 8 filters 452 | unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer, use_dropout=use_dropout) 453 | # gradually reduce the number of filters from ngf * 8 to ngf 454 | unet_block = UnetSkipConnectionBlock(ngf * 4, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer) 455 | unet_block = UnetSkipConnectionBlock(ngf * 2, ngf * 4, input_nc=None, submodule=unet_block, norm_layer=norm_layer) 456 | unet_block = UnetSkipConnectionBlock(ngf, ngf * 2, input_nc=None, submodule=unet_block, norm_layer=norm_layer) 457 | self.model = UnetSkipConnectionBlock(output_nc, ngf, input_nc=input_nc, submodule=unet_block, outermost=True, norm_layer=norm_layer) # add the outermost layer 458 | 459 | def forward(self, input): 460 | """Standard forward""" 461 | return self.model(input) 462 | 463 | 464 | class UnetSkipConnectionBlock(nn.Module): 465 | """Defines the Unet submodule with skip connection.
466 | X -------------------identity---------------------- 467 | |-- downsampling -- |submodule| -- upsampling --| 468 | """ 469 | 470 | def __init__(self, outer_nc, inner_nc, input_nc=None, 471 | submodule=None, outermost=False, innermost=False, norm_layer=nn.BatchNorm2d, use_dropout=False): 472 | """Construct a Unet submodule with skip connections. 473 | 474 | Parameters: 475 | outer_nc (int) -- the number of filters in the outer conv layer 476 | inner_nc (int) -- the number of filters in the inner conv layer 477 | input_nc (int) -- the number of channels in input images/features 478 | submodule (UnetSkipConnectionBlock) -- previously defined submodules 479 | outermost (bool) -- if this module is the outermost module 480 | innermost (bool) -- if this module is the innermost module 481 | norm_layer -- normalization layer 482 | use_dropout (bool) -- whether to use dropout layers. 483 | """ 484 | super(UnetSkipConnectionBlock, self).__init__() 485 | self.outermost = outermost 486 | if type(norm_layer) == functools.partial: 487 | use_bias = norm_layer.func == nn.InstanceNorm2d 488 | else: 489 | use_bias = norm_layer == nn.InstanceNorm2d 490 | if input_nc is None: 491 | input_nc = outer_nc 492 | downconv = nn.Conv2d(input_nc, inner_nc, kernel_size=4, 493 | stride=2, padding=1, bias=use_bias) 494 | downrelu = nn.LeakyReLU(0.2, True) 495 | downnorm = norm_layer(inner_nc) 496 | uprelu = nn.ReLU(True) 497 | upnorm = norm_layer(outer_nc) 498 | 499 | if outermost: 500 | upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc, 501 | kernel_size=4, stride=2, 502 | padding=1) 503 | down = [downconv] 504 | up = [uprelu, upconv, nn.Tanh()] 505 | model = down + [submodule] + up 506 | elif innermost: 507 | upconv = nn.ConvTranspose2d(inner_nc, outer_nc, 508 | kernel_size=4, stride=2, 509 | padding=1, bias=use_bias) 510 | down = [downrelu, downconv] 511 | up = [uprelu, upconv, upnorm] 512 | model = down + up 513 | else: 514 | upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc, 515 | kernel_size=4, stride=2, 516 | padding=1, bias=use_bias) 517 | down = [downrelu, downconv, downnorm] 518 | up = [uprelu, upconv, upnorm] 519 | 520 | if use_dropout: 521 | model = down + [submodule] + up + [nn.Dropout(0.5)] 522 | else: 523 | model = down + [submodule] + up 524 | 525 | self.model = nn.Sequential(*model) 526 | 527 | def forward(self, x): 528 | if self.outermost: 529 | return self.model(x) 530 | else: # add skip connections 531 | return torch.cat([x, self.model(x)], 1) 532 | 533 | 534 | class NLayerDiscriminator(nn.Module): 535 | """Defines a PatchGAN discriminator""" 536 | 537 | def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d): 538 | """Construct a PatchGAN discriminator 539 | 540 | Parameters: 541 | input_nc (int) -- the number of channels in input images 542 | ndf (int) -- the number of filters in the last conv layer 543 | n_layers (int) -- the number of conv layers in the discriminator 544 | norm_layer -- normalization layer 545 | """ 546 | super(NLayerDiscriminator, self).__init__() 547 | if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters 548 | use_bias = norm_layer.func != nn.BatchNorm2d 549 | else: 550 | use_bias = norm_layer != nn.BatchNorm2d 551 | 552 | kw = 4 553 | padw = 1 554 | sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)] 555 | nf_mult = 1 556 | nf_mult_prev = 1 557 | for n in range(1, n_layers): # gradually increase the number of filters 558 | nf_mult_prev = nf_mult 559 | nf_mult =
min(2 ** n, 8) 560 | sequence += [ 561 | nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias), 562 | norm_layer(ndf * nf_mult), 563 | nn.LeakyReLU(0.2, True) 564 | ] 565 | 566 | nf_mult_prev = nf_mult 567 | nf_mult = min(2 ** n_layers, 8) 568 | sequence += [ 569 | nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias), 570 | norm_layer(ndf * nf_mult), 571 | nn.LeakyReLU(0.2, True) 572 | ] 573 | 574 | sequence += [nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)] # output 1 channel prediction map 575 | self.model = nn.Sequential(*sequence) 576 | 577 | def forward(self, input): 578 | """Standard forward.""" 579 | return self.model(input) 580 | 581 | 582 | class PixelDiscriminator(nn.Module): 583 | """Defines a 1x1 PatchGAN discriminator (pixelGAN)""" 584 | 585 | def __init__(self, input_nc, ndf=64, norm_layer=nn.BatchNorm2d): 586 | """Construct a 1x1 PatchGAN discriminator 587 | 588 | Parameters: 589 | input_nc (int) -- the number of channels in input images 590 | ndf (int) -- the number of filters in the last conv layer 591 | norm_layer -- normalization layer 592 | """ 593 | super(PixelDiscriminator, self).__init__() 594 | if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters 595 | use_bias = norm_layer.func != nn.BatchNorm2d 596 | else: 597 | use_bias = norm_layer != nn.BatchNorm2d 598 | 599 | self.net = [ 600 | nn.Conv2d(input_nc, ndf, kernel_size=1, stride=1, padding=0), 601 | nn.LeakyReLU(0.2, True), 602 | nn.Conv2d(ndf, ndf * 2, kernel_size=1, stride=1, padding=0, bias=use_bias), 603 | norm_layer(ndf * 2), 604 | nn.LeakyReLU(0.2, True), 605 | nn.Conv2d(ndf * 2, 1, kernel_size=1, stride=1, padding=0, bias=use_bias)] 606 | 607 | self.net = nn.Sequential(*self.net) 608 | 609 | def forward(self, input): 610 | """Standard forward.""" 611 | return self.net(input) 612 | --------------------------------------------------------------------------------
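Usage sketch (not part of the original sources; netD, real, fake, and device are hypothetical stand-ins): the helpers above compose into a standard WGAN-GP discriminator update, one of the gan_mode options this file supports.

    criterion = GANLoss('wgangp')
    loss_D = criterion(netD(real), True) + criterion(netD(fake.detach()), False)
    gp, _ = cal_gradient_penalty(netD, real, fake.detach(), device, type='mixed')
    (loss_D + gp).backward()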