├── README.md
├── core
│   ├── checkpoint.py
│   ├── data_loader.py
│   ├── face_model.py
│   ├── model.py
│   ├── resnet.py
│   ├── solver.py
│   ├── utils.py
│   └── wing.py
├── images
│   ├── 1
│   └── qualitative_comparisons.jpg
├── main.py
└── param.yaml


/README.md:
--------------------------------------------------------------------------------

## FaceSwapper - Official PyTorch Implementation

> **FaceSwapper: Learning Disentangled Representation for One-shot Progressive Face Swapping**<br>
> [Qi Li](https://liqi-casia.github.io/), [Weining Wang](https://scholar.google.com/citations?hl=en&user=NDPvobAAAAAJ), [Chengzhong Xu](https://www.fst.um.edu.mo/people/czxu/), [Zhenan Sun](https://scholar.google.com.au/citations?user=PuZGODYAAAAJ&hl=en), [Ming-Hsuan Yang](https://scholar.google.com/citations?user=p9-ohHsAAAAJ&hl=zh-CN)<br>
> In TPAMI 2024.
>
> Paper: [https://ieeexplore.ieee.org/abstract/document/10536627](https://ieeexplore.ieee.org/abstract/document/10536627)
>
> Although face swapping has attracted much attention in recent years, it remains a challenging problem. The existing methods leverage a large number of data samples to explore the intrinsic properties of face swapping without taking into account the semantic information of face images. Moreover, the representation of the identity information tends to be fixed, leading to suboptimal face swapping. In this paper, we present a simple yet efficient method named FaceSwapper, for one-shot face swapping based on Generative Adversarial Networks. Our method consists of a disentangled representation module and a semantic-guided fusion module. The disentangled representation module is composed of an attribute encoder and an identity encoder, which aims to achieve the disentanglement of the identity and the attribute information. The identity encoder is more flexible and the attribute encoder contains more details of the attributes than its competitors. Benefiting from the disentangled representation, FaceSwapper can swap face images progressively. In addition, semantic information is introduced into the semantic-guided fusion module to control the swapped area and model the pose and expression more accurately. The experimental results show that our method achieves state-of-the-art results on benchmark datasets with fewer training samples.


## Environment
Clone this repository:

```bash
git clone https://github.com/liqi-casia/FaceSwapper.git
cd FaceSwapper/
```

Install the dependencies:
```bash
conda create -n faceswapper python=3.6.7
conda activate faceswapper
conda install -y pytorch=1.4.0 torchvision=0.5.0 cudatoolkit=10.0 -c pytorch
conda install x264=='1!152.20180717' ffmpeg=4.0.2 -c conda-forge
pip install opencv-python==4.1.2.30 ffmpeg-python==0.2.0 scikit-image==0.16.2
pip install pillow==7.0.0 scipy==1.2.1 tqdm==4.43.0 munch==2.5.0
conda install -y -c anaconda pyyaml
pip install tensorboard tensorboardX
```

## Datasets and pre-trained checkpoints
We provide links to download the datasets used in FaceSwapper and the corresponding pre-trained checkpoints. The datasets and checkpoints should be moved to the `data` and `expr/checkpoints` directories, respectively.

* Datasets. Download the **CelebA Dataset** from [Baidu Netdisk](https://pan.baidu.com/s/12KqrRI_K9frgky2YPNSJ0Q?pwd=wvy5) or [Google Drive](https://drive.google.com/file/d/1W-FF-DNJ752L7Zqdm-otMT6frOcTLd7q/view?usp=sharing), and the **FaceForensics++ Dataset** from [Baidu Netdisk](https://pan.baidu.com/s/1LcKJw2sGkEAHWjmTcxou-A?pwd=omkc) or [Google Drive](https://drive.google.com/file/d/1Sa6v0m8s4xHXPPzeFt8K1q1n3T088Gvz/view?usp=sharing).

* Checkpoints.
Download the **face recognition model** from [Baidu Netdisk](https://pan.baidu.com/s/11qcEiBjAsQPXwIqKjOE-rQ?pwd=2g78) or [Google Drive](https://drive.google.com/file/d/1-lxc-jZGIFNdwFUXQ9tDS9OSuhadj6AC/view?usp=sharing), the **face alignment model** from [Baidu Netdisk](https://pan.baidu.com/s/1htwmXDi2Gev8l09oJpr_Mg?pwd=ejmj) or [Google Drive](https://drive.google.com/file/d/1lBt4x4P5qaClB2ZN_POBV-ue41hdlaoJ/view?usp=sharing), and the **face swapping model** from [Baidu Netdisk](https://pan.baidu.com/s/1aIRX0twylUJ42z4sYhUaVA?pwd=bkru) or [Google Drive](https://drive.google.com/file/d/1Tb3V09wbaGe6SaiN3BZkOcCy7VJ0KYC8/view?usp=sharing).

After storing all the files, the directories `./data` and `./pretrained_checkpoints` are expected to have the following structure:

```
./data
├── CelebA Dataset
│   ├── CelebA images
│   ├── CelebA landmark images
│   └── CelebA mask images
└── FF++ Dataset
    ├── ff++ images
    ├── ff++ landmark images
    ├── ff++ mask images
    └── ff++ parsing images

./pretrained_checkpoints
├── model_ir_se50.pth
├── wing.ckpt
└── faceswapper.ckpt
```

## Generating swapped images
After downloading the pre-trained checkpoints, you can synthesize swapped images. The following command saves the generated images to the `expr/results` directory.

FaceForensics++ Dataset. To generate swapped images, specify the testing parameters in `param.yaml` (in particular, change `mode` from `'train'` to `'test'`, and check the parameters under *\#directory for testing*). Then run the following command:
```bash
python main.py
```
There are three subfolders in `expr/results/ff++/`, named `swapped_result_single`, `swapped_result_afterps` and `swapped_result_all`.
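
Result images are named *source_FS_target.png*. As a minimal sketch for post-hoc analysis, assuming the `_FS_` separator appears exactly once in each name and every result file ends in `.png`, the source and target parts can be recovered as shown below. The helpers `parse_swap_name` and `list_swaps` are hypothetical and not part of this repository.

```python
from pathlib import Path


def parse_swap_name(filename):
    """Split a result filename of the form '<source>_FS_<target>.png'.

    Hypothetical helper: assumes '_FS_' occurs exactly once and the
    file ends with '.png'.
    """
    name = Path(filename).name
    if not name.endswith('.png'):
        raise ValueError('expected a .png file: %r' % filename)
    name = name[:-len('.png')]
    parts = name.split('_FS_')
    if len(parts) != 2:
        raise ValueError('expected exactly one "_FS_" separator: %r' % filename)
    source, target = parts
    return source, target


def list_swaps(result_dir):
    """Return (source, target, path) for every .png in result_dir, sorted by path."""
    return [(*parse_swap_name(p.name), p)
            for p in sorted(Path(result_dir).glob('*.png'))]
```

For example, `list_swaps('expr/results/ff++/swapped_result_single')` would pair every swapped image with its source and target names. Splitting on the literal `_FS_` keeps the helper trivial; if your image names can themselves contain that substring, a different parsing rule is needed.
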
Each image is named *source_FS_target.png*, where the source image provides the identity information and the target image provides the attribute information.

- `swapped_result_single`: the swapped images.
- `swapped_result_afterps`: the swapped images after post-processing.
- `swapped_result_all`: concatenations of the source images, the target images, the swapped images and the swapped images after post-processing.

Other Datasets. First, automatically crop and align the face images so that the face occupies a similar proportion of the image as in the CelebA and FaceForensics++ datasets. Then, define a face swapping list similar to `face_swap_list.txt`, in which each line has the form `source_image_name target_image_name`. The rest of the testing procedure is the same as for the FaceForensics++ Dataset.

Post Process. If the source image contains occlusions (*e.g.*, hair or a hat), we simply preserve the forehead and hair of the target image in the swapped image. Otherwise, we simply preserve the hair of the target image. Set `post_process: True` to enable post-processing.

## Training networks

To train FaceSwapper from scratch, set the training parameters in `param.yaml` (in particular, change `mode` from `'test'` to `'train'`, and check the parameters under *\#directory for training*), and run the following command. Generated images and network checkpoints will be stored in the `expr/samples` and `expr/checkpoints` directories, respectively. Training usually takes several days on a single Tesla V100 GPU, depending on the total number of training iterations.

```bash
python main.py
```

## License
The source code, pre-trained models, and dataset are available under the [Creative Commons BY-NC 4.0](https://creativecommons.org/licenses/by-nc/4.0/) license.
You can use, copy, transform and build upon the material for non-commercial purposes as long as you give appropriate credit by citing our paper and indicate if changes were made.
For technical, business and other inquiries, please contact qli@nlpr.ia.ac.cn.

## Citation
If you find this work useful for your research, please cite our paper:

```
@article{li2024learning,
  author={Li, Qi and Wang, Weining and Xu, Chengzhong and Sun, Zhenan and Yang, Ming-Hsuan},
  journal={IEEE Transactions on Pattern Analysis and Machine Intelligence},
  title={Learning Disentangled Representation for One-Shot Progressive Face Swapping},
  year={2024},
  volume={46},
  number={12},
  pages={8348-8364}
}
```

## Acknowledgements
The code is based on the following projects. We would like to thank their authors for their contributions.

- [StarGAN v2](https://github.com/clovaai/stargan-v2)
- [face alignment](https://github.com/1adrianb/face-alignment)
- [AdaptiveWingLoss](https://github.com/protossw512/AdaptiveWingLoss)


--------------------------------------------------------------------------------
/core/checkpoint.py:
--------------------------------------------------------------------------------
"""
This work is licensed under the Creative Commons Attribution-NonCommercial
4.0 International License. To view a copy of this license, visit
http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to
Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.
"""

import os
import torch


class CheckpointIO(object):
    def __init__(self, fname_template, **kwargs):
        os.makedirs(os.path.dirname(fname_template), exist_ok=True)
        self.fname_template = fname_template
        self.module_dict = kwargs

    def register(self, **kwargs):
        self.module_dict.update(kwargs)

    def save(self, step):
        fname = self.fname_template.format(step)
        print('Saving checkpoint into %s...'
              % fname)
        outdict = {}
        for name, module in self.module_dict.items():
            outdict[name] = module.state_dict()
        torch.save(outdict, fname)

    def load(self, step):
        fname = self.fname_template.format(step)
        self._load_file(fname)

    def load_test(self, ckptname):
        # Here the template is used as a plain path prefix and concatenated with ckptname.
        fname = self.fname_template + ckptname
        self._load_file(fname)

    def _load_file(self, fname):
        assert os.path.exists(fname), fname + ' does not exist!'
        print('Loading checkpoint from %s...' % fname)
        # Keep the saved device placement on GPU machines; map everything to the CPU otherwise.
        map_location = None if torch.cuda.is_available() else torch.device('cpu')
        module_dict = torch.load(fname, map_location=map_location)
        for name, module in self.module_dict.items():
            module.load_state_dict(module_dict[name])


--------------------------------------------------------------------------------
/core/data_loader.py:
--------------------------------------------------------------------------------
"""
This work is licensed under the Creative Commons Attribution-NonCommercial
4.0 International License. To view a copy of this license, visit
http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to
Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.
"""

from pathlib import Path
from itertools import chain
from munch import Munch
from PIL import Image
import random
import glob
import copy
import torch
from torch.utils import data
from torchvision import transforms


def listdir(dname):
    fnames = list(chain(*[list(Path(dname).rglob('*.'
                                                 + ext))
                          for ext in ['png', 'jpg', 'jpeg', 'JPG']]))
    return fnames


class DefaultDataset(data.Dataset):
    def __init__(self, root, transform=None):
        self.samples = listdir(root)
        self.samples.sort()
        self.transform = transform
        self.targets = None

    def __getitem__(self, index):
        fname = self.samples[index]
        img = Image.open(fname).convert('RGB')
        if self.transform is not None:
            img = self.transform(img)
        return img

    def __len__(self):
        return len(self.samples)


class TrainFaceDataSet(data.Dataset):
    def __init__(self, data_path_list, transform=None, transform_seg=None):
        self.datasets = []
        self.num_per_folder = []
        # Landmark and mask images are read from the sibling folders
        # '<data_path_list[0]>_lm_images/' and '<data_path_list[0]>_mask_images/'.
        self.lm_image_path = data_path_list[0] + '_lm_images/'
        self.mask_image_path = data_path_list[0] + '_mask_images/'
        for data_path in data_path_list:
            image_list = glob.glob(f'{data_path}/*.*g')
            self.datasets.append(image_list)
            self.num_per_folder.append(len(image_list))
        self.transform = transform
        self.transform_seg = transform_seg

    def __getitem__(self, item):
        # Map the global index onto (folder index, index within that folder).
        idx = 0
        while item >= self.num_per_folder[idx]:
            item -= self.num_per_folder[idx]
            idx += 1
        image_path = self.datasets[idx][item]
        source_lm_image_path = self.lm_image_path + image_path.split('/')[-1]
        source_mask_image_path = self.mask_image_path + image_path.split('/')[-1]
        source_image = Image.open(image_path).convert('RGB')
        source_lm_image = Image.open(source_lm_image_path).convert('RGB')
        source_mask_image = Image.open(source_mask_image_path).convert('L')
        if self.transform is not None:
            source_image = self.transform(source_image)
            source_lm_image = self.transform(source_lm_image)
            source_mask_image = self.transform_seg(source_mask_image)
        # Choose the reference image from the same folder, excluding the source image.
        temp = copy.deepcopy(self.datasets[idx])
        temp.pop(item)
        reference_image_path = temp[random.randint(0, len(temp) - 1)]
        reference_lm_image_path = self.lm_image_path + reference_image_path.split('/')[-1]
        reference_mask_image_path = self.mask_image_path + reference_image_path.split('/')[-1]
        reference_image = Image.open(reference_image_path).convert('RGB')
        reference_lm_image = Image.open(reference_lm_image_path).convert('RGB')
        reference_mask_image = Image.open(reference_mask_image_path).convert('L')
        if self.transform is not None:
            reference_image = self.transform(reference_image)
            reference_lm_image = self.transform(reference_lm_image)
            reference_mask_image = self.transform_seg(reference_mask_image)
        outputs = dict(src=source_image, ref=reference_image,
                       src_lm=source_lm_image, ref_lm=reference_lm_image,
                       src_mask=1 - source_mask_image, ref_mask=1 - reference_mask_image)
        return outputs

    def __len__(self):
        return sum(self.num_per_folder)


class TestFaceDataSet(data.Dataset):
    def __init__(self, data_path_list, test_img_list, transform=None, transform_seg=None):
        self.source_dataset = []
        self.reference_dataset = []
        self.data_path_list = data_path_list
        self.lm_image_path = data_path_list + '_lm_images/'
        self.mask_image_path = data_path_list + '_mask_images/'
        self.biseg_parsing_path = data_path_list + '_parsing_images/'
        # Each line of the list has the form 'source_image_name target_image_name'.
        with open(test_img_list, 'r') as f:
            for line in f.readlines():
                fields = line.split(' ')
                self.source_dataset.append(fields[0])
                # The target name keeps its trailing newline here; it is stripped in __getitem__.
                self.reference_dataset.append(fields[1])
        self.transform = transform
        self.transform_seg = transform_seg

    def __getitem__(self, item):
        source_name = self.source_dataset[item]
        source_image_path = self.data_path_list + '/' + source_name
        try:
            source_image = Image.open(source_image_path).convert('RGB')
        except OSError:
            print('Failed to read %s' % source_image_path)
            raise
        source_lm_image_path = self.lm_image_path + source_name
        source_mask_image_path = self.mask_image_path + source_name
        source_parsing_image_path = self.biseg_parsing_path + source_name
        source_lm_image = Image.open(source_lm_image_path).convert('RGB')
        source_mask_image = Image.open(source_mask_image_path).convert('L')
        source_parsing_image = Image.open(source_parsing_image_path).convert('L')
        if self.transform is not None:
            source_image = self.transform(source_image)
            source_lm_image = self.transform(source_lm_image)
            source_mask_image = self.transform_seg(source_mask_image)
            source_parsing = self.transform_seg(source_parsing_image)
        # Strip the trailing newline (if any) left over from reading the list file;
        # unlike slicing off the last character, this also works for a final line without '\n'.
        reference_name = self.reference_dataset[item].rstrip()
        reference_image_path = self.data_path_list + '/' + reference_name
        try:
            reference_image = Image.open(reference_image_path).convert('RGB')
        except OSError:
            print('Failed to read %s' % reference_image_path)
            raise
        reference_lm_image_path = self.lm_image_path + reference_name
        reference_mask_image_path = self.mask_image_path + reference_name
        reference_parsing_image_path = self.biseg_parsing_path + reference_name
        reference_lm_image = Image.open(reference_lm_image_path).convert('RGB')
        reference_mask_image = Image.open(reference_mask_image_path).convert('L')
        reference_parsing = Image.open(reference_parsing_image_path).convert('L')
        if self.transform is not None:
            reference_image = self.transform(reference_image)
            reference_lm_image = self.transform(reference_lm_image)
            reference_mask_image = self.transform_seg(reference_mask_image)
            reference_parsing = self.transform_seg(reference_parsing)
        # ref_name is the raw name read from the list file (it may end with a newline).
        outputs = dict(src=source_image, ref=reference_image,
                       src_lm=source_lm_image, ref_lm=reference_lm_image,
                       src_mask=1 - source_mask_image, ref_mask=1 - reference_mask_image,
                       src_parsing=source_parsing, ref_parsing=reference_parsing,
                       src_name=self.source_dataset[item], ref_name=self.reference_dataset[item])
        return outputs

    def __len__(self):
        return len(self.source_dataset)


def get_train_loader(root, img_size=256,
                     batch_size=8, num_workers=4):
    print('Preparing DataLoader to fetch images during the training phase...')
    # Face and landmark images: resize, convert to tensors, normalize to [-1, 1].
    transform = transforms.Compose([
        transforms.Resize([img_size, img_size]),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5],
                             std=[0.5, 0.5, 0.5]),
    ])
    # Masks: resize and convert to tensors in [0, 1] without normalization.
    transform_seg = transforms.Compose([
        transforms.Resize([img_size, img_size]),
        transforms.ToTensor(),
    ])
    train_dataset = TrainFaceDataSet(root, transform, transform_seg)
    train_loader = data.DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True,
                                   num_workers=num_workers, drop_last=True)
    return train_loader


def get_test_loader(root, test_img_list, img_size=256,
                    batch_size=8, num_workers=4):
    print('Preparing DataLoader to fetch images during the testing phase...')
    transform = transforms.Compose([
        transforms.Resize([img_size, img_size]),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5],
                             std=[0.5, 0.5, 0.5]),
    ])
    transform_seg = transforms.Compose([
        transforms.Resize([img_size, img_size]),
        transforms.ToTensor(),
    ])
    test_dataset = TestFaceDataSet(root, test_img_list, transform, transform_seg)
    test_loader = data.DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False,
                                  num_workers=num_workers,
                                  drop_last=True)
    return test_loader


class InputFetcher:
    """Fetch batches from a DataLoader indefinitely, restarting it when exhausted.

    In 'train' and 'test' mode the tensors are moved to the active device.
    """

    def __init__(self, loader, mode=''):
        self.loader = loader
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.mode = mode

    def _fetch_inputs(self):
        # self.iter does not exist on first use (AttributeError) and is
        # exhausted at the end of each epoch (StopIteration).
        try:
            inputs_data = next(self.iter)
        except (AttributeError, StopIteration):
            self.iter = iter(self.loader)
            inputs_data = next(self.iter)
        return inputs_data

    def __next__(self):
        t_inputs = self._fetch_inputs()
        # From here on the reference image is called the target ('tar').
        inputs = Munch(src=t_inputs['src'], tar=t_inputs['ref'],
                       src_lm=t_inputs['src_lm'], tar_lm=t_inputs['ref_lm'],
                       src_mask=t_inputs['src_mask'], tar_mask=t_inputs['ref_mask'])
        if self.mode == 'train':
            inputs = Munch({k: t.to(self.device) for k, t in inputs.items()})
        elif self.mode == 'test':
            inputs = Munch({k: t.to(self.device) for k, t in inputs.items()},
                           src_parsing=t_inputs['src_parsing'].to(self.device),
                           tar_parsing=t_inputs['ref_parsing'].to(self.device),
                           src_name=t_inputs['src_name'], tar_name=t_inputs['ref_name'])
        return inputs


--------------------------------------------------------------------------------
/core/face_model.py:
--------------------------------------------------------------------------------
from torch.nn import (Linear, Conv2d, BatchNorm1d, BatchNorm2d, PReLU, ReLU, Sigmoid, Dropout2d, Dropout,
                      AvgPool2d, MaxPool2d, AdaptiveAvgPool2d, Sequential, Module, Parameter)
import torch
from collections import namedtuple
import math

################################## Original Arcface Model #############################################################

class Flatten(Module):
    def forward(self, input):
        return input.view(input.size(0), -1)


def l2_norm(input, axis=1):
    norm = torch.norm(input, 2, axis, True)
    output = torch.div(input, norm)
    return output


class SEModule(Module):
    def __init__(self, channels,
                 reduction):
        super(SEModule, self).__init__()
        self.avg_pool = AdaptiveAvgPool2d(1)
        self.fc1 = Conv2d(channels, channels // reduction, kernel_size=1, padding=0, bias=False)
        self.relu = ReLU(inplace=True)
        self.fc2 = Conv2d(channels // reduction, channels, kernel_size=1, padding=0, bias=False)
        self.sigmoid = Sigmoid()

    def forward(self, x):
        # Squeeze-and-excitation: rescale each channel by a gate computed from global average pooling.
        module_input = x
        x = self.avg_pool(x)
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        x = self.sigmoid(x)
        return module_input * x


class bottleneck_IR(Module):
    def __init__(self, in_channel, depth, stride):
        super(bottleneck_IR, self).__init__()
        if in_channel == depth:
            # A 1x1 max-pool with the given stride acts as an identity (or subsampling) shortcut.
            self.shortcut_layer = MaxPool2d(1, stride)
        else:
            self.shortcut_layer = Sequential(
                Conv2d(in_channel, depth, (1, 1), stride, bias=False), BatchNorm2d(depth))
        self.res_layer = Sequential(
            BatchNorm2d(in_channel),
            Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False), PReLU(depth),
            Conv2d(depth, depth, (3, 3), stride, 1, bias=False), BatchNorm2d(depth))

    def forward(self, x):
        shortcut = self.shortcut_layer(x)
        res = self.res_layer(x)
        return res + shortcut


class bottleneck_IR_SE(Module):
    def __init__(self, in_channel, depth, stride):
        super(bottleneck_IR_SE, self).__init__()
        if in_channel == depth:
            self.shortcut_layer = MaxPool2d(1, stride)
        else:
            self.shortcut_layer = Sequential(
                Conv2d(in_channel, depth, (1, 1), stride, bias=False),
                BatchNorm2d(depth))
        self.res_layer = Sequential(
            BatchNorm2d(in_channel),
            Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False),
            PReLU(depth),
            Conv2d(depth, depth, (3, 3), stride, 1, bias=False),
            BatchNorm2d(depth),
            SEModule(depth, 16)
        )

    def forward(self, x):
        shortcut = self.shortcut_layer(x)
        res = self.res_layer(x)
        return res + shortcut


class Bottleneck(namedtuple('Block', ['in_channel', 'depth', 'stride'])):
    '''A named tuple describing a ResNet block.'''


def get_block(in_channel, depth, num_units, stride=2):
    # Only the first unit of a stage changes the width and downsamples.
    return [Bottleneck(in_channel, depth, stride)] + [Bottleneck(depth, depth, 1) for i in range(num_units - 1)]


def get_blocks(num_layers):
    if num_layers == 50:
        blocks = [
            get_block(in_channel=64, depth=64, num_units=3),
            get_block(in_channel=64, depth=128, num_units=4),
            get_block(in_channel=128, depth=256, num_units=14),
            get_block(in_channel=256, depth=512, num_units=3)
        ]
    elif num_layers == 100:
        blocks = [
            get_block(in_channel=64, depth=64, num_units=3),
            get_block(in_channel=64, depth=128, num_units=13),
            get_block(in_channel=128, depth=256, num_units=30),
            get_block(in_channel=256, depth=512, num_units=3)
        ]
    elif num_layers == 152:
        blocks = [
            get_block(in_channel=64, depth=64, num_units=3),
            get_block(in_channel=64, depth=128, num_units=8),
            get_block(in_channel=128, depth=256, num_units=36),
            get_block(in_channel=256, depth=512, num_units=3)
        ]
    return blocks


class Backbone(Module):
    def __init__(self, num_layers, drop_ratio, mode='ir'):
        super(Backbone, self).__init__()
        assert num_layers in [50, 100, 152], 'num_layers should be 50, 100, or 152'
        assert mode in ['ir', 'ir_se'], 'mode should be ir or ir_se'
        blocks = get_blocks(num_layers)
        if mode == 'ir':
            unit_module = bottleneck_IR
        elif mode == 'ir_se':
            unit_module = bottleneck_IR_SE
        self.input_layer = Sequential(Conv2d(3, 64, (3, 3), 1, 1, bias=False),
                                      BatchNorm2d(64),
                                      PReLU(64))
        # Linear(512 * 7 * 7, 512) implies 112x112 inputs: four stride-2 stages give 112 / 16 = 7.
        self.output_layer = Sequential(BatchNorm2d(512),
                                       Dropout(drop_ratio),
                                       Flatten(),
                                       Linear(512 * 7 * 7, 512),
                                       BatchNorm1d(512))
        modules = []
        for block in blocks:
            for bottleneck in block:
                modules.append(
                    unit_module(bottleneck.in_channel,
                                bottleneck.depth,
                                bottleneck.stride))
        self.body = Sequential(*modules)

    def forward(self, x):
        # Return the L2-normalized embedding together with the output of every body unit.
        feats = []
        x = self.input_layer(x)
        for m in self.body.children():
            x = m(x)
            feats.append(x)
        x = self.output_layer(x)
        return l2_norm(x), feats

################################## MobileFaceNet #############################################################

class Conv_block(Module):
    def __init__(self, in_c, out_c, kernel=(1, 1), stride=(1, 1), padding=(0, 0), groups=1):
        super(Conv_block, self).__init__()
        self.conv = Conv2d(in_c, out_channels=out_c, kernel_size=kernel, groups=groups, stride=stride, padding=padding, bias=False)
        self.bn = BatchNorm2d(out_c)
        self.prelu = PReLU(out_c)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        x = self.prelu(x)
        return x


class Linear_block(Module):
    # Convolution + BatchNorm without a non-linearity.
    def __init__(self, in_c, out_c, kernel=(1, 1), stride=(1, 1), padding=(0, 0), groups=1):
        super(Linear_block, self).__init__()
        self.conv = Conv2d(in_c, out_channels=out_c, kernel_size=kernel, groups=groups, stride=stride, padding=padding, bias=False)
        self.bn = BatchNorm2d(out_c)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        return x


class Depth_Wise(Module):
    def __init__(self, in_c, out_c, residual=False, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=1):
        super(Depth_Wise, self).__init__()
        # 'groups' is the expanded width: 1x1 expansion -> depthwise conv -> 1x1 linear projection.
        self.conv = Conv_block(in_c, out_c=groups, kernel=(1, 1), padding=(0, 0), stride=(1, 1))
        self.conv_dw = Conv_block(groups, groups, groups=groups, kernel=kernel, padding=padding, stride=stride)
        self.project = Linear_block(groups, out_c, kernel=(1, 1), padding=(0, 0), stride=(1, 1))
        self.residual = residual

    def forward(self, x):
        if self.residual:
            short_cut = x
        x = self.conv(x)
        x = self.conv_dw(x)
        x = self.project(x)
        if self.residual:
            output = short_cut + x
        else:
            output = x
        return output


class Residual(Module):
    def __init__(self, c, num_block, groups, kernel=(3, 3), stride=(1, 1), padding=(1, 1)):
        super(Residual, self).__init__()
        modules = []
        for _ in range(num_block):
            modules.append(Depth_Wise(c, c, residual=True, kernel=kernel, padding=padding, stride=stride, groups=groups))
        self.model = Sequential(*modules)

    def forward(self, x):
        return self.model(x)


class MobileFaceNet(Module):
    def __init__(self, embedding_size):
        super(MobileFaceNet, self).__init__()
        self.conv1 = Conv_block(3, 64, kernel=(3, 3), stride=(2, 2), padding=(1, 1))
        self.conv2_dw = Conv_block(64, 64, kernel=(3, 3), stride=(1, 1), padding=(1, 1), groups=64)
        self.conv_23 = Depth_Wise(64, 64, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=128)
        self.conv_3 = Residual(64, num_block=4, groups=128, kernel=(3, 3), stride=(1, 1), padding=(1, 1))
        self.conv_34 = Depth_Wise(64, 128, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=256)
        self.conv_4 = Residual(128, num_block=6, groups=256, kernel=(3, 3), stride=(1, 1), padding=(1, 1))
        self.conv_45 = Depth_Wise(128, 128, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=512)
        self.conv_5 = Residual(128, num_block=2, groups=256, kernel=(3, 3), stride=(1, 1), padding=(1, 1))
        self.conv_6_sep = Conv_block(128, 512, kernel=(1, 1), stride=(1, 1), padding=(0, 0))
        self.conv_6_dw = Linear_block(512, 512, groups=512, kernel=(7, 7), stride=(1, 1), padding=(0, 0))
        self.conv_6_flatten = Flatten()
        self.linear = Linear(512, embedding_size, bias=False)
        self.bn = BatchNorm1d(embedding_size)

    def forward(self, x):
        out = self.conv1(x)
        out = self.conv2_dw(out)
        out = self.conv_23(out)
        out = self.conv_3(out)
        out = self.conv_34(out)
        out = self.conv_4(out)
        out = self.conv_45(out)
        out = self.conv_5(out)
        out = self.conv_6_sep(out)
        out = self.conv_6_dw(out)
        out = self.conv_6_flatten(out)
        out = self.linear(out)
        out = self.bn(out)
        return l2_norm(out)

################################## Arcface head #############################################################

class Arcface(Module):
    # Additive angular margin (ArcFace) head, see https://arxiv.org/abs/1801.07698
    def __init__(self, embedding_size=512, classnum=51332, s=64., m=0.5):
        super(Arcface, self).__init__()
        self.classnum = classnum
        self.kernel = Parameter(torch.Tensor(embedding_size, classnum))
        # Initialize every class-weight column to unit L2 norm.
        self.kernel.data.uniform_(-1, 1).renorm_(2, 1, 1e-5).mul_(1e5)
        self.m = m
        self.s = s
        self.cos_m = math.cos(m)
        self.sin_m = math.sin(m)
        self.mm = self.sin_m * m  # issue 1
        self.threshold = math.cos(math.pi - m)

    def forward(self, embbedings, label):
        # Cosine similarity between the embeddings and the normalized class weights.
        nB = len(embbedings)
        kernel_norm = l2_norm(self.kernel, axis=0)
        cos_theta = torch.mm(embbedings, kernel_norm)
        cos_theta = cos_theta.clamp(-1, 1)
        cos_theta_2 = torch.pow(cos_theta, 2)
        sin_theta_2 = 1 - cos_theta_2
        sin_theta = torch.sqrt(sin_theta_2)
        # cos(theta + m) = cos(theta) * cos(m) - sin(theta) * sin(m)
        cos_theta_m = (cos_theta * self.cos_m - sin_theta * self.sin_m)
        # Where theta + m >= pi, fall back to cos(theta) - m * sin(m) so the target logit stays monotonic.
        cond_v = cos_theta - self.threshold
        cond_mask = cond_v <= 0
        keep_val = (cos_theta - self.mm)
        cos_theta_m[cond_mask] = keep_val[cond_mask]
        output = cos_theta * 1.0
        idx_ = torch.arange(0, nB, dtype=torch.long)
        output[idx_, label] = cos_theta_m[idx_, label]
        output *= self.s
        return output

################################## Cosface head #############################################################

class Am_softmax(Module):
    # implementation of additive margin
    # softmax loss in https://arxiv.org/abs/1801.05599
    def __init__(self, embedding_size=512, classnum=51332):
        super(Am_softmax, self).__init__()
        self.classnum = classnum
        self.kernel = Parameter(torch.Tensor(embedding_size, classnum))
        # initial kernel
        self.kernel.data.uniform_(-1, 1).renorm_(2, 1, 1e-5).mul_(1e5)
        self.m = 0.35
        self.s = 30.

    def forward(self, embbedings, label):
        kernel_norm = l2_norm(self.kernel, axis=0)
        cos_theta = torch.mm(embbedings, kernel_norm)
        cos_theta = cos_theta.clamp(-1, 1)
        # Subtract the margin m from the target-class cosine only.
        phi = cos_theta - self.m
        label = label.view(-1, 1)
        # One-hot mask of the target classes; boolean masks replace the deprecated uint8 (.byte()) masks.
        index = cos_theta.data * 0.0
        index.scatter_(1, label.data.view(-1, 1), 1)
        index = index.bool()
        output = cos_theta * 1.0
        output[index] = phi[index]
        output *= self.s
        return output


--------------------------------------------------------------------------------
/core/model.py:
--------------------------------------------------------------------------------
"""
This work is licensed under the Creative Commons Attribution-NonCommercial
4.0 International License. To view a copy of this license, visit
http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to
Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.
"""

import copy
import math
from munch import Munch
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from core.wing import FAN



class ResBlk(nn.Module):
    def __init__(self, dim_in, dim_out, actv=nn.LeakyReLU(0.2),
                 normalize=False, downsample=False):
        super().__init__()
        self.actv = actv
        self.normalize = normalize
        self.downsample = downsample
        self.learned_sc = dim_in != dim_out
        self._build_weights(dim_in, dim_out)

    def _build_weights(self, dim_in, dim_out):
        self.conv1 = nn.Conv2d(dim_in, dim_in, 3, 1, 1)
        self.conv2 = nn.Conv2d(dim_in, dim_out, 3, 1, 1)
        if self.normalize:
            self.norm1 = nn.InstanceNorm2d(dim_in, affine=True)
            self.norm2 = nn.InstanceNorm2d(dim_in, affine=True)
        if self.learned_sc:
            self.conv1x1 = nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=False)

    def _shortcut(self, x):
        if self.learned_sc:
            x = self.conv1x1(x)
        if self.downsample:
            x = F.avg_pool2d(x, 2)
        return x

    def _residual(self, x):
        if self.normalize:
            x = self.norm1(x)
        x = self.actv(x)
        x = self.conv1(x)
        if self.downsample:
            x = F.avg_pool2d(x, 2)
        if self.normalize:
            x = self.norm2(x)
        x = self.actv(x)
        x = self.conv2(x)
        return x

    def forward(self, x):
        x = self._shortcut(x) + self._residual(x)
        return x / math.sqrt(2)  # unit variance


class AdaIN(nn.Module):
    def __init__(self, id_dim, num_features):
        super().__init__()
        input_nc = 3
        self.norm = nn.InstanceNorm2d(num_features, affine=False)
        self.fc = nn.Linear(id_dim, num_features*2)
        self.conv_weight = nn.Conv2d(input_nc, num_features, kernel_size=3, padding=1)
        self.conv_bias = nn.Conv2d(input_nc, num_features, kernel_size=3, padding=1)

    def forward(self, x, s, mask, landmark):
        h = self.fc(s)
        h = h.view(h.size(0), h.size(1), 1, 1)
        gamma, beta = torch.chunk(h, chunks=2, dim=1)
        face_part = (1-mask[x.size(2)]) * x  # face area
        norm_face_part = (1 + gamma) * self.norm(face_part) + beta
        landmark = F.interpolate(landmark, x.size(2), mode='bilinear', align_corners=True)
        weight_norm = self.conv_weight(landmark)
        bias_norm = self.conv_bias(landmark)
        norm_face_part = norm_face_part * (1+weight_norm) + bias_norm
        new_face = mask[x.size(2)] * x + (1-mask[x.size(2)])*norm_face_part
        return new_face


class AdainResBlk(nn.Module):
    def __init__(self, dim_in, dim_out, id_dim=512,
                 actv=nn.LeakyReLU(0.2), upsample=False):
        super().__init__()
        self.actv = actv
        self.upsample = upsample
        self.learned_sc = dim_in != dim_out
        self._build_weights(dim_in, dim_out, id_dim)

    def _build_weights(self, dim_in, dim_out, id_dim):
        self.conv1 = nn.Conv2d(dim_in, dim_out, 3, 1, 1)
        self.conv2 = nn.Conv2d(dim_out, dim_out, 3, 1, 1)
        self.norm1 = AdaIN(id_dim, dim_in)
        self.norm2 = AdaIN(id_dim, dim_out)
        if self.learned_sc:
            self.conv1x1 = nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=False)

    def _shortcut(self, x):
        if self.upsample:
            x = F.interpolate(x, scale_factor=2, mode='nearest')
        if self.learned_sc:
            x = self.conv1x1(x)
        return x

    def _residual(self, x, s, mask, landmark):
        x = self.norm1(x, s, mask, landmark)
        x = self.actv(x)
        if self.upsample:
            x = F.interpolate(x, scale_factor=2, mode='nearest')
        x = self.conv1(x)
        x = self.norm2(x, s, mask, landmark)
        x = self.actv(x)
        x = self.conv2(x)
        return x

    def forward(self, x, s, mask, landmark):
        out = self._residual(x, s, mask, landmark)
        return out

class Generator(nn.Module):
    def __init__(self, img_size=256, id_dim=512, max_conv_dim=512):
        super().__init__()
        dim_in = 2**14 // img_size
        self.img_size = img_size
        self.id_encoder = IdentityEncoder(self.img_size, id_dim, max_conv_dim)
        self.attr_encoder = AttrEncoder(self.img_size, max_conv_dim)
        self.org_decoder = Decoder(self.img_size, id_dim, max_conv_dim)

    def forward(self, x_a, x_b, x_a_lm, x_b_lm, x_a_mask=None, x_b_mask=None):
        x_a_attr, x_a_idvec, x_a_cache = self.encode(x_a, x_a_mask)
        x_b_attr, x_b_idvec, x_b_cache = self.encode(x_b, x_b_mask)
        x_ba, ms_features_ba, ms_outputs_ba = self.decode(x_b_attr, x_a_idvec, x_b_lm, x_b_cache, x_b_mask)  # a's identity
        x_ab, ms_features_ab, ms_outputs_ab = self.decode(x_a_attr, x_b_idvec, x_a_lm, x_a_cache, x_a_mask)  # b's identity
        return x_ba, x_ab, ms_features_ba, ms_features_ab, ms_outputs_ba, ms_outputs_ab

    def encode(self, image, mask=None):
        # encode an image to its attribute code and identity code
        id_vec, id_all_features = self.id_encoder(image)
        attr, attr_all_features, cache = self.attr_encoder(image, mask)
        return attr, id_vec, cache

    def decode(self, attr, id_vec, lm_image, cache, mask=None):
        image, ms_features, ms_outputs = self.org_decoder(attr, id_vec, lm_image, cache, mask)
        return image, ms_features, ms_outputs

    def encode_features(self, image, mask=None):
        # encode an image to its multiscale attribute feature and identity feature
        id_vec, id_all_features = self.id_encoder(image)
        attr, attr_all_features, cache = self.attr_encoder(image, mask)
        return attr_all_features, id_all_features


# attribute encoder
class AttrEncoder(nn.Module):
    def __init__(self, img_size=256, max_conv_dim=512):
        super().__init__()
        dim_in = 2**14 // img_size
        blocks = []
        blocks += [nn.Conv2d(3, dim_in, 3, 1, 1)]
        repeat_num = int(np.log2(img_size)) - 4
        repeat_num += 1
        for _ in range(repeat_num):
            dim_out = min(dim_in*2, max_conv_dim)
            blocks += [ResBlk(dim_in, dim_out, normalize=True, downsample=True)]
            dim_in = dim_out
        # bottleneck blocks
        for _ in range(2):
            blocks += [ResBlk(dim_out, dim_out, normalize=True)]
        self.model = nn.Sequential(*blocks)

    def forward(self, x, masks=None):
        attr_all_features = []
        cache = {}
        for block in self.model:
            # keep the 32/64/128 feature maps for the decoder's skip connections
            if (masks is not None) and (x.size(2) in [32, 64, 128]):
                cache[x.size(2)] = x
            x = block(x)
            attr_all_features.append(x)
        return x, attr_all_features, cache

# identity encoder
class IdentityEncoder(nn.Module):
    def __init__(self, img_size=256, id_dim=512, max_conv_dim=512):
        super().__init__()
        dim_in = 2**14 // img_size
        blocks = []
        blocks += [nn.Conv2d(3, dim_in, 3, 1, 1)]
        repeat_num = int(np.log2(img_size)) - 2
        for _ in range(repeat_num):
            dim_out = min(dim_in*2, max_conv_dim)
            blocks += [ResBlk(dim_in, dim_out, downsample=True)]
            dim_in = dim_out
        blocks += [nn.LeakyReLU(0.2)]
        blocks += [nn.Conv2d(dim_out, dim_out, 4, 1, 0)]
        blocks += [nn.LeakyReLU(0.2)]
        self.model = nn.Sequential(*blocks)
        self.final_layer = nn.ModuleList()
        self.final_layer += [nn.Linear(dim_out, id_dim)]

    def forward(self, x):
        id_all_features = []
        for block in self.model:
            x = block(x)
            id_all_features.append(x)
        x = x.view(x.size(0), -1)  # batch_size, dim_out
        for block in self.final_layer:
            x = block(x)
            x = x.view(x.size(0), -1)
            id_all_features.append(x)
        return x, id_all_features

# decoder
class Decoder(nn.Module):
    def __init__(self, img_size=256, id_dim=512, max_conv_dim=512):
        super().__init__()
        dim_in = 2**14 // img_size
        dim_in_org = dim_in
        self.mask_size = [32, 64, 128]
        self.x_size = [8, 16, 32, 64, 128, 256]
        self.to_rgb = nn.Sequential(
            nn.InstanceNorm2d(dim_in, affine=True),
            nn.LeakyReLU(0.2),
            nn.Conv2d(dim_in, 3, 1, 1, 0))
        blocks = []
        repeat_num = int(np.log2(img_size)) - 4
        repeat_num += 1
        for _ in range(repeat_num):
            dim_out = min(dim_in*2, max_conv_dim)
            blocks.insert(0, AdainResBlk(dim_out, dim_in, id_dim,
                                         upsample=True))
            dim_in = dim_out
        # bottleneck blocks
        for _ in range(2):
            blocks.insert(0, AdainResBlk(dim_out, dim_out, id_dim))
        self.model = nn.Sequential(*blocks)

        def to_rgb_output(dim_before_RGB):
            output = nn.Sequential(
                nn.InstanceNorm2d(dim_before_RGB, affine=True),
                nn.LeakyReLU(0.2),
                nn.Conv2d(dim_before_RGB, 3, 1, 1, 0))
            return output

        def skip_connection(dim_in_skip):
            dim_out_skip = dim_in_skip
            output = nn.Sequential(
                nn.Conv2d(dim_in_skip, dim_out_skip, 1, 1, 0))
            return output

        self.rgb_converters = nn.ModuleList()
        self.skip_connects = nn.ModuleList()
        for i in self.mask_size:
            self.rgb_converters.append(to_rgb_output(int(dim_in_org*img_size/i)))
            self.skip_connects.append(skip_connection(int(dim_in_org*img_size/i)))

    def forward(self, x, id_vec, lm_image, cache=None, mask=None):
        ms_outputs = []
        ms_features = []
        dict_masks = {}
        if (mask is not None):
            for i in self.x_size:
                mask = F.interpolate(mask, i, mode='bilinear', align_corners=True)
                dict_masks[i] = mask
        index = 0
        for block in self.model:
            mask = dict_masks[x.size(2)]
            x = block(x, id_vec, dict_masks, lm_image)  # 1-masks, face area
            if (mask is not None) and (x.size(2) in self.mask_size):
                mask = dict_masks[x.size(2)]
                x = x + self.skip_connects[index](mask * cache[x.size(2)])  # this mask should be attr
                ms_outputs.append(self.rgb_converters[index](x))
                index = index+1
            ms_features.append(x)
        x = self.to_rgb(x)
        return x, ms_features, ms_outputs


class Discriminator(nn.Module):
    def __init__(self, img_size=256, max_conv_dim=512):
        super().__init__()
        dim_in = 2**14 // img_size
        num_domains = 1
        blocks = []
        blocks += [nn.Conv2d(6, dim_in, 3, 1, 1)]
        repeat_num = int(np.log2(img_size)) - 2
        for _ in range(repeat_num):
            dim_out = min(dim_in*2, max_conv_dim)
            blocks += [ResBlk(dim_in, dim_out, downsample=True)]
            dim_in = dim_out
        blocks += [nn.LeakyReLU(0.2)]
        blocks += [nn.Conv2d(dim_out, dim_out, 4, 1, 0)]
        blocks += [nn.LeakyReLU(0.2)]
        blocks += [nn.Conv2d(dim_out, num_domains, 1, 1, 0)]
        self.main = nn.Sequential(*blocks)

    def forward(self, x, lm_images):
        # condition on the landmark image: concatenate along channels (3 + 3 = 6 inputs)
        x = torch.cat((x, lm_images), dim=-3)
        for block in self.main:
            x = block(x)
        out = x
        out = out.view(out.size(0), -1)  # (batch, num_domains)
        return out


def build_model(config):
    generator = Generator(config['img_size'], config['id_dim'], max_conv_dim=512)
    discriminator = Discriminator(config['img_size'], max_conv_dim=512)
    generator_ema = copy.deepcopy(generator)
    nets = Munch(generator=generator, discriminator=discriminator)
    nets_ema = Munch(generator=generator_ema)
    nets.fan = FAN(fname_pretrained=config['wing_path']).eval()
    nets_ema.fan = nets.fan
    return nets, nets_ema
--------------------------------------------------------------------------------
/core/resnet.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F

def conv3x3(in_planes, out_planes, stride=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=1, bias=False)


class BasicBlock(nn.Module):
    def __init__(self, in_chan, out_chan, stride=1):
        super(BasicBlock, self).__init__()
        self.conv1 = conv3x3(in_chan, out_chan, stride)
        self.bn1 = nn.BatchNorm2d(out_chan)
        self.conv2 = conv3x3(out_chan, out_chan)
        self.bn2 = nn.BatchNorm2d(out_chan)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = None
        if in_chan != out_chan or stride != 1:
            self.downsample = nn.Sequential(
                nn.Conv2d(in_chan, out_chan,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_chan),
            )

    def forward(self, x):
        residual = self.conv1(x)
        residual = F.relu(self.bn1(residual))
        residual = self.conv2(residual)
        residual = self.bn2(residual)

        shortcut = x
        if self.downsample is not None:
            shortcut = self.downsample(x)

        out = shortcut + residual
        out = self.relu(out)
        return out


def create_layer_basic(in_chan, out_chan, bnum, stride=1):
    layers = [BasicBlock(in_chan, out_chan, stride=stride)]
    for i in range(bnum-1):
        layers.append(BasicBlock(out_chan, out_chan, stride=1))
    return nn.Sequential(*layers)


class Resnet18(nn.Module):
    def __init__(self):
        super(Resnet18, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = create_layer_basic(64, 64, bnum=2, stride=1)
        self.layer2 = create_layer_basic(64, 128, bnum=2, stride=2)
        self.layer3 = create_layer_basic(128, 256, bnum=2, stride=2)
        self.layer4 = create_layer_basic(256, 512, bnum=2, stride=2)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(self.bn1(x))
        x = self.maxpool(x)

        x = self.layer1(x)
        feat8 = self.layer2(x)  # 1/8
        feat16 = self.layer3(feat8)  # 1/16
        feat32 = self.layer4(feat16)  # 1/32
        return feat8, feat16, feat32

    def get_params(self):
        # weights of Linear/Conv2d layers get weight decay; biases and BatchNorm parameters do not
        wd_params, nowd_params = [], []
        for name, module in self.named_modules():
            if isinstance(module, (nn.Linear, nn.Conv2d)):
                wd_params.append(module.weight)
                if module.bias is not None:
                    nowd_params.append(module.bias)
            elif isinstance(module, nn.BatchNorm2d):
                nowd_params += list(module.parameters())
        return wd_params, nowd_params
--------------------------------------------------------------------------------
/core/solver.py:
--------------------------------------------------------------------------------
"""
This work is licensed under the Creative Commons Attribution-NonCommercial
4.0 International License. To view a copy of this license, visit
http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to
Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.
"""

import os
from os.path import join as ospj
import time
import datetime
from munch import Munch
import torch
import torch.nn as nn
import torch.nn.functional as F
import tensorboardX
from core.model import build_model
from core.checkpoint import CheckpointIO
from core.data_loader import InputFetcher
import core.utils as utils
from core.face_model import Backbone



class Solver(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.nets, self.nets_ema = build_model(config)
        # pretrained IR-SE-50 face recognizer; not optimized, used only for the identity loss
        self.arcface = Backbone(50, 0.6, 'ir_se')
        self.arcface.eval()
        self.arcface.load_state_dict(torch.load(config['face_model_path']))
        # below setattrs are to make networks be children of Solver, e.g., for self.to(self.device)
        for name, module in self.nets.items():
            utils.print_network(module, name)
            setattr(self, name, module)
        for name, module in self.nets_ema.items():
            setattr(self, name + '_ema', module)
        if config['mode'] == 'train':
            print(config)
            beta1 = config['beta1']
            beta2 = config['beta2']
            dis_params = list(self.nets['discriminator'].parameters())
            id_params = list(self.nets['generator'].id_encoder.parameters())
            # Python object ids of the identity-encoder parameters
            dict_id_params = list(map(id, self.nets['generator'].id_encoder.parameters()))
            # all other generator parameters (attribute encoder, decoder)
            gen_params_wo_id = filter(lambda x: id(x) not in dict_id_params, self.nets['generator'].parameters())
            gen_id_params = []
            for p in id_params:
                if p.requires_grad:
                    gen_id_params.append(p)
            gen_params = []
            for p in gen_params_wo_id:
                if p.requires_grad:
                    gen_params.append(p)

            for net in self.nets.keys():
                if net == 'generator':
                    # the identity encoder gets its own learning rate, config['id_lr']
                    self.gen_opt = torch.optim.Adam([{'params': gen_params},
                                                     {'params': gen_id_params, 'lr': config['id_lr']}],
                                                    lr=config['lr'],
                                                    betas=[beta1, beta2],
                                                    weight_decay=config['weight_decay'])
                elif net == 'discriminator':
                    self.dis_opt = torch.optim.Adam(
                        [p for p in dis_params if p.requires_grad],
                        lr=config['lr'],
                        betas=[beta1, beta2],
                        weight_decay=config['weight_decay'])
            self.optims = Munch()
            self.optims['generator'] = self.gen_opt
            self.optims['discriminator'] = self.dis_opt
            self.ckptios = [
                CheckpointIO(ospj(config['checkpoint_dir'], '{:06d}_nets.ckpt'), **self.nets),
                CheckpointIO(ospj(config['checkpoint_dir'], '{:06d}_nets_ema.ckpt'), **self.nets_ema),
                CheckpointIO(ospj(config['checkpoint_dir'], '{:06d}_optims.ckpt'), **self.optims)]
        else:
            self.ckptios = [CheckpointIO(config['test_checkpoint_dir'], **self.nets_ema)]
        self.to(self.device)
        for name, network in self.named_children():
            # Do not initialize the pretrained network parameters
            if ('ema' not in name) and ('fan' not in name) and ('arcface' not in name):
                print('Initializing %s...' % name)
                network.apply(utils.he_init)
        # Setup logger and output folders
        model_name = config['dataset']
        timestamp = "{0:%Y-%m-%dT%H-%M-%S/}".format(datetime.datetime.now())
        self.train_writer = tensorboardX.SummaryWriter(os.path.join(config['log_dir'] + model_name, timestamp))

    def _save_checkpoint(self, step):
        for ckptio in self.ckptios:
            ckptio.save(step)

    def _load_checkpoint(self, step):
        for ckptio in self.ckptios:
            ckptio.load(step)

    def _load_test_checkpoint(self, ckptname):
        for ckptio in self.ckptios:
            ckptio.load_test(ckptname)

    def _reset_grad(self):
        for optim in self.optims.values():
            optim.zero_grad()

    def train(self, loaders):
        config = self.config
        nets = self.nets
        nets_ema = self.nets_ema
        gen_opt = self.gen_opt
        dis_opt = self.dis_opt
        arcface = self.arcface
        fetcher = InputFetcher(loaders.src, 'train')
        inputs_val = next(fetcher)
        # resume training if necessary
        if config['resume_iter'] > 0:
            self._load_checkpoint(config['resume_iter'])
        print('Start training...')
        start_time = time.time()
        for i in range(config['resume_iter'], config['total_iters']):
            # fetch images
            inputs = next(fetcher)
            src, tar, src_lm, tar_lm, src_mask, tar_mask = inputs.src, inputs.tar, inputs.src_lm, \
                inputs.tar_lm, inputs.src_mask, inputs.tar_mask
            # train the discriminator
            d_loss, d_losses_all = compute_d_loss(
                nets, config, src, tar, src_lm, tar_lm, src_mask, tar_mask)
            self._reset_grad()
            d_loss.backward()
            dis_opt.step()
            # train the generator
            g_loss, g_losses_all = compute_g_loss(
                nets, config, src, tar, src_lm, tar_lm, src_mask, tar_mask, arcface)
            self._reset_grad()
            g_loss.backward()
            gen_opt.step()
            # compute moving average of network parameters
            moving_average(nets.generator, nets_ema.generator, beta=0.999)
            if (i+1) % config['print_every'] == 0:
                elapsed = time.time() - start_time
                elapsed = str(datetime.timedelta(seconds=elapsed))[:-7]
                log = "Elapsed time [%s], Iteration [%i/%i], " % (elapsed, i+1, config['total_iters'])
                all_losses = dict()
                for loss, prefix in zip([d_losses_all, g_losses_all],
                                        ['D/all_', 'G/all_']):
                    for key, value in loss.items():
                        all_losses[prefix + key] = value
                        self.train_writer.add_scalar(prefix+key, value, i+1)
                log += ' '.join(['%s: [%.4f]' % (key, value) for key, value in all_losses.items()])
                print(log)
            # generate images for observation
            if (i+1) % config['sample_every'] == 0:
                os.makedirs(config['sample_dir'], exist_ok=True)
                utils.display_image(nets_ema, config, inputs=inputs_val, step=i+1)
            # save model checkpoints
            if (i+1) % config['save_every'] == 0:
                self._save_checkpoint(step=i+1)
        self.train_writer.close()

    @torch.no_grad()
    def test(self, loaders):
        config = self.config
        nets_ema = self.nets_ema
        os.makedirs(config['result_dir'], exist_ok=True)
        self._load_test_checkpoint(config['test_checkpoint_name'])
        with open(config['test_img_list'], 'r') as f:
            img_num = len(f.readlines())
        total_iters = int(img_num/config['batch_size']) + 1
        save_dir = config['result_dir']
        test_fetcher = InputFetcher(loaders.src, 'test')
        for i in range(0, total_iters):
            inputs = next(test_fetcher)
            utils.disentangle_and_swapping_test(nets_ema, config, inputs, save_dir)

def compute_d_loss(nets, config, x_a, x_b, x_a_lm, x_b_lm, x_a_mask, x_b_mask):
    x_a.requires_grad_()
    x_b.requires_grad_()
    out_a = nets.discriminator(x_a, x_a_lm)
    out_b = nets.discriminator(x_b, x_b_lm)
    loss_real_a = adv_loss(out_a, 1)
    loss_real_b = adv_loss(out_b, 1)
    loss_reg_a = r1_reg(out_a, x_a)
    loss_reg_b = r1_reg(out_b, x_b)
    loss_real = loss_real_a + loss_real_b
    loss_reg = loss_reg_a + loss_reg_b
    x_ba, x_ab, ms_features_ba, ms_features_ab, ms_outputs_ba, \
        ms_outputs_ab = nets.generator(x_a, x_b, x_a_lm, x_b_lm, x_a_mask, x_b_mask)
    out_ba = nets.discriminator(x_ba, x_b_lm)  # x_ba, a's identity
    out_ab = nets.discriminator(x_ab, x_a_lm)  # x_ab, b's identity
    loss_fake_ba = adv_loss(out_ba, 0)
    loss_fake_ab = adv_loss(out_ab, 0)
    loss_fake = loss_fake_ba + loss_fake_ab
    loss = loss_real + loss_fake + config['lambda_reg'] * loss_reg
    return loss, Munch(real=loss_real.item(),
                       fake=loss_fake.item(),
                       reg=loss_reg.item(),
                       total_loss=loss.item())

def compute_g_loss(nets, config, x_a, x_b, x_a_lm, x_b_lm, x_a_mask, x_b_mask, arcface):
    att_a, i_a_prime, cache_a = nets.generator.encode(x_a, x_a_mask)
    att_b, i_b_prime, cache_b = nets.generator.encode(x_b, x_b_mask)
    x_a_recon, ms_features_a, ms_outputs_a = nets.generator.decode(att_a, i_a_prime, x_a_lm, cache_a, x_a_mask)
    x_b_recon, ms_features_b, ms_outputs_b = nets.generator.decode(att_b, i_b_prime, x_b_lm, cache_b, x_b_mask)
    x_ba, ms_features_ba, ms_outputs_ba = nets.generator.decode(att_b, i_a_prime, x_b_lm, cache_b, x_b_mask)  # x_ba, a's identity
    x_ab, ms_features_ab, ms_outputs_ab = nets.generator.decode(att_a, i_b_prime, x_a_lm, cache_a, x_a_mask)  # x_ab, b's identity
    # embeddings of the real images are fixed targets; the swapped images keep
    # their gradients so the identity loss can train the generator
    with torch.no_grad():
        a_embed, a_feats = arcface(F.interpolate(x_a, [112, 112], mode='bilinear', align_corners=True))
        b_embed, b_feats = arcface(F.interpolate(x_b, [112, 112], mode='bilinear', align_corners=True))
    ba_embed, ba_feats = arcface(F.interpolate(x_ba, [112, 112], mode='bilinear', align_corners=True))
    ab_embed, ab_feats = arcface(F.interpolate(x_ab, [112, 112], mode='bilinear', align_corners=True))
    loss_id_a = (1 - F.cosine_similarity(a_embed, ba_embed, dim=1))
    loss_id_b = (1 - F.cosine_similarity(b_embed, ab_embed, dim=1))
    loss_id_a = loss_id_a.mean()
    loss_id_b = loss_id_b.mean()
    loss_id = loss_id_a + loss_id_b
    out_a = nets.discriminator(x_ba, x_b_lm)  # x_ba, a's identity
    out_b = nets.discriminator(x_ab, x_a_lm)  # x_ab, b's identity
    loss_adv_a = adv_loss(out_a, 1)
    loss_adv_b = adv_loss(out_b, 1)
    loss_adv = loss_adv_a + loss_adv_b
    loss_recon_x_a = torch.mean(torch.abs(x_a_recon - x_a))
    loss_recon_x_b = torch.mean(torch.abs(x_b_recon - x_b))
    loss_recon = loss_recon_x_a + loss_recon_x_b
    loss_att_a_face = style_loss_face(ms_features_a, ms_features_ab, x_a_mask)
    loss_att_b_face = style_loss_face(ms_features_b, ms_features_ba, x_b_mask)
    loss_att_face = loss_att_a_face + loss_att_b_face
    loss_att_a_bg = style_loss_background(ms_features_a, ms_features_ab, x_a_mask)
    loss_att_b_bg = style_loss_background(ms_features_b, ms_features_ba, x_b_mask)
    loss_att_bg = loss_att_a_bg + loss_att_b_bg
    loss = loss_adv + config['lambda_id'] * loss_id \
        + config['lambda_recon'] * loss_recon + config['lambda_att_face'] * loss_att_face \
        + config['lambda_att_bg'] * loss_att_bg
    return loss, Munch(adv=loss_adv.item(),
                       id=loss_id.item(),
                       recon=loss_recon.item(),
                       att_face=loss_att_face.item(),
                       att_bg=loss_att_bg.item(),
                       total_loss=loss.item())

def moving_average(model, model_test, beta=0.999):
    # EMA update: param_test <- beta * param_test + (1 - beta) * param
    for param, param_test in zip(model.parameters(), model_test.parameters()):
        param_test.data = torch.lerp(param.data, param_test.data, beta)

def adv_loss(logits, target):
    assert target in [1, 0]
    targets = torch.full_like(logits, fill_value=target)
    loss = F.binary_cross_entropy_with_logits(logits, targets)
    return loss

def r1_reg(d_out, x_in):
    # zero-centered gradient penalty for real images
    batch_size = x_in.size(0)
    grad_dout = torch.autograd.grad(
        outputs=d_out.sum(), inputs=x_in,
        create_graph=True, retain_graph=True, only_inputs=True
    )[0]
    grad_dout2 = grad_dout.pow(2)
    assert(grad_dout2.size() == x_in.size())
    reg = 0.5 * grad_dout2.view(batch_size, -1).sum(1).mean(0)
    return reg

def compute_gram(x):
    b, ch, h, w = x.size()
    f = x.view(b, ch, w * h)
    f_T = f.transpose(1, 2)
    G = f.bmm(f_T) / (h * w * ch)
    return G

def style_loss_face(x, y, masks):
    # Gram-matrix (style) loss restricted to the face area, i.e. (1 - masks)
    style_loss = 0.0
    for i in range(0, len(x)):
        masks = F.interpolate(masks, x[i].size(2), mode='bilinear')
        style_loss += torch.nn.L1Loss()(compute_gram((1-masks)*x[i]), compute_gram((1-masks)*y[i]))
    style_loss = style_loss/len(x)
    return style_loss

def style_loss_background(x, y, masks):
    style_loss = 0.0
    for i in range(0, len(x)):
        masks = F.interpolate(masks, x[i].size(2), mode='bilinear')
        # plain masked L1 on the features (no Gram matrix), unlike style_loss_face
        style_loss += torch.nn.L1Loss()(masks*x[i], masks*y[i])
    style_loss = style_loss/len(x)
    return style_loss


--------------------------------------------------------------------------------
/core/utils.py:
--------------------------------------------------------------------------------
"""
This work is licensed under the Creative Commons Attribution-NonCommercial
4.0 International License. To view a copy of this license, visit
http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to
Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.
"""

import os
from os.path import join as ospj
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.utils as vutils
import yaml

def print_network(network, name):
    num_params = 0
    for p in network.parameters():
        num_params += p.numel()
    print("Number of parameters of %s: %i" % (name, num_params))

def he_init(module):
    if isinstance(module, nn.Conv2d):
        nn.init.kaiming_normal_(module.weight, mode='fan_in', nonlinearity='relu')
        if module.bias is not None:
            nn.init.constant_(module.bias, 0)
    if isinstance(module, nn.Linear):
        nn.init.kaiming_normal_(module.weight, mode='fan_in', nonlinearity='relu')
        if module.bias is not None:
            nn.init.constant_(module.bias, 0)

def denormalize(x):
    out = (x + 1) / 2
    return out.clamp_(0, 1)

def save_image(x, ncol, filename):
    x = denormalize(x)
    vutils.save_image(x.cpu(), filename, nrow=ncol, padding=0)

def get_config(config):
    with open(config, 'r') as stream:
        return yaml.load(stream, Loader=yaml.FullLoader)


@torch.no_grad()
def disentangle_and_reconstruct(nets, config, src, tar, src_lm, tar_lm, src_mask, tar_mask, filename):
    N, C, H, W = src.size()
    att_a, i_a_prime, cache_a = nets.generator.encode(src, src_mask)  # cache follows attr
    att_b, i_b_prime, cache_b = nets.generator.encode(tar, tar_mask)
    a_recon, _, _ = nets.generator.decode(att_a, i_a_prime, src_lm, cache_a, src_mask)
    b_recon, _, _ = nets.generator.decode(att_b, i_b_prime, tar_lm, cache_b, tar_mask)
    disp_concat = [src, tar, a_recon, b_recon]
    disp_concat = torch.cat(disp_concat, dim=0)
    save_image(disp_concat, N, filename)

@torch.no_grad()
def disentangle_and_swapping(nets, config, src, tar, src_lm, tar_lm, src_mask, tar_mask, filename):
    N, C, H, W = src.size()
    whitespace = torch.ones(1, C, H, W).to(src.device)
    src_with_whitespace = torch.cat([whitespace, src], dim=0)
    disp_concat = [src_with_whitespace]
    for i, tar_img in enumerate(tar):
        tar_imgs = tar_img.repeat(N, 1, 1, 1)
        disp_img = tar_img.repeat(1, 1, 1, 1)
        tar_i_lm = tar_lm[i, :, :, :]
        tar_i_lm = tar_i_lm.repeat(N, 1, 1, 1)
        tar_i_mask = tar_mask[i, :, :, :]
        tar_i_mask = tar_i_mask.repeat(N, 1, 1, 1)
        srcid_taratt, tarid_srcatt, _, _, _, _ = nets.generator(src, tar_imgs, src_lm,
                                                              tar_i_lm, src_mask, tar_i_mask)
        fake_srcid_taratt = torch.cat([disp_img, srcid_taratt], dim=0)
        fake_tarid_srcatt = torch.cat([disp_img, tarid_srcatt], dim=0)
        disp_concat += [fake_srcid_taratt]
        disp_concat += [fake_tarid_srcatt]
    disp_concat = torch.cat(disp_concat, dim=0)
    save_image(disp_concat, N+1, filename)

@torch.no_grad()
def display_image(nets, config, inputs, step):
    src, tar, src_lm, tar_lm, src_mask, tar_mask = inputs.src, inputs.tar, \
        inputs.src_lm, inputs.tar_lm, inputs.src_mask, inputs.tar_mask
    # face reconstruction
    filename = ospj(config['sample_dir'], '%06d_reconstruction.jpg' % (step))
    disentangle_and_reconstruct(nets, config, src, tar, src_lm, tar_lm, src_mask, tar_mask, filename)
    # face swapping
    filename = ospj(config['sample_dir'], '%06d_faceswapping.jpg' % (step))
    disentangle_and_swapping(nets, config, src, tar, src_lm, tar_lm, src_mask, tar_mask, filename)

@torch.no_grad()
def disentangle_and_swapping_test(nets, config, inputs, save_dir):
    post_process = config['post_process']
    src, tar, src_lm, tar_lm, src_mask, tar_mask, tar_parsing, src_name, tar_name = inputs.src, inputs.tar, \
        inputs.src_lm, inputs.tar_lm, inputs.src_mask, inputs.tar_mask, inputs.tar_parsing, inputs.src_name, inputs.tar_name
    src_mask = F.interpolate(src_mask, src.size(2), mode='bilinear', align_corners=True)
    tar_mask = F.interpolate(tar_mask, src.size(2), mode='bilinear', align_corners=True)
    srcid_taratt, tarid_srcatt, _, _, _, _ = nets.generator(src, tar, src_lm, tar_lm, src_mask, tar_mask)  # modified by liqi
    result_first = save_dir + 'swapped_result_single/'
    result_second = save_dir + 'swapped_result_afterps/'
    result_third = save_dir + 'swapped_result_all/'
    for result_dir in (result_first, result_second, result_third):
        os.makedirs(result_dir, exist_ok=True)
    if post_process:
        src_convex_hull = nets.fan.get_convex_hull(src)
        tar_convex_hull = nets.fan.get_convex_hull(tar)
        # forehead = landmark convex hull minus the mask; clamp so the values stay in [0, 1]
        src_forehead = (src_convex_hull - src_mask).clamp(0.0, 1.0)
        tar_forehead = (tar_convex_hull - tar_mask).clamp(0.0, 1.0)
        tar_hair = get_hair(tar_parsing)
        post_result = postprocess(tar, srcid_taratt, tar_hair, src_forehead, tar_forehead)
    for i in range(len(srcid_taratt)):
        filename = result_first + src_name[i][0:-4] + '_FS_' + tar_name[i][0:-5] + '.png'
        filename_post = result_second + src_name[i][0:-4] + '_FS_' + tar_name[i][0:-5] + '.png'
        filename_all = result_third + src_name[i][0:-4] + '_FS_' + tar_name[i][0:-5] + '.png'
        save_image(srcid_taratt[i, :, :, :], 1, filename)
        if post_process:
            save_image(post_result[i, :, :, :], 1, filename_post)
            x_concat = torch.cat([src[i].unsqueeze(0), tar[i].unsqueeze(0),
                                  srcid_taratt[i, :, :, :].unsqueeze(0), post_result[i, :, :, :].unsqueeze(0)], dim=0)
            save_image(x_concat, 4, filename_all)
        else:
            x_concat = torch.cat([src[i].unsqueeze(0), tar[i].unsqueeze(0),
                                  srcid_taratt[i, :, :, :].unsqueeze(0)], dim=0)
            save_image(x_concat, 3, filename_all)


def get_hair(segmentation):
    # returns a mask that is 1 everywhere except on hair pixels, which are 0
    out = segmentation.mul_(255).int()
    mask_ind_hair = [17]  # parsing label(s) treated as hair
    with torch.no_grad():
        out_parse = out
        hair = torch.ones((out_parse.shape[0], 1, out_parse.shape[2], out_parse.shape[3]),
                          device=out_parse.device)
        for pi in mask_ind_hair:
            index = torch.where(out_parse == pi)
            hair[index[0], :, index[2], index[3]] = 0
    return hair


def postprocess(tar, srcid_taratt, tar_hair, src_forehead, tar_forehead):
    # inner area of tar_hair is 0, inner area of tar_forehead is 1
    smooth_mask = SoftErosion(kernel_size=17, threshold=0.9, iterations=7).to(tar.device)
    tar_hair_and_forehead = ((1-tar_hair) + tar_forehead).clamp(max=1.0)
    tar_preserve = tar_hair
    # find whether occlusion exists in source image; if exists, then preserve the hair and forehead of the target image
    for i in range(src_forehead.size(0)):
        src_forehead_i = src_forehead[i, :, :, :]
        src_forehead_i = src_forehead_i.squeeze_()
        tar_forehead_i = tar_forehead[i, :, :, :]
        tar_forehead_i = tar_forehead_i.squeeze_()
        # torch.nonzero on a 2-D map returns a (num_nonzero, 2) index tensor,
        # so H * W below is twice the number of forehead pixels
        H1, W1 = torch.nonzero(src_forehead_i).size()
        H2, W2 = torch.nonzero(tar_forehead_i).size()
        if (H1 * W1) / (H2 * W2 + 0.0001) < 0.4 and (H2 * W2) >= 1000:
            tar_preserve[i, :, :, :] = 1 - tar_hair_and_forehead[i, :, :, :]
    soft_mask, _ = smooth_mask(tar_preserve)
    result = srcid_taratt * soft_mask + tar * (1-soft_mask)
    return result


class SoftErosion(nn.Module):
    def
__init__(self, kernel_size=15, threshold=0.6, iterations=1): 168 | super(SoftErosion, self).__init__() 169 | r = kernel_size // 2 170 | self.padding = r 171 | self.iterations = iterations 172 | self.threshold = threshold 173 | # Create kernel 174 | y_indices, x_indices = torch.meshgrid(torch.arange(0., kernel_size), torch.arange(0., kernel_size)) 175 | dist = torch.sqrt((x_indices - r) ** 2 + (y_indices - r) ** 2) 176 | kernel = dist.max() - dist 177 | kernel /= kernel.sum() 178 | kernel = kernel.view(1, 1, *kernel.shape) 179 | self.register_buffer('weight', kernel) 180 | 181 | def forward(self, x): 182 | x = x.float() 183 | for i in range(self.iterations - 1): 184 | x = torch.min(x, F.conv2d(x, weight=self.weight, groups=x.shape[1], padding=self.padding)) 185 | x = F.conv2d(x, weight=self.weight, groups=x.shape[1], padding=self.padding) 186 | mask = x >= self.threshold 187 | x[mask] = 1.0 188 | x[~mask] /= x[~mask].max() 189 | return x, mask -------------------------------------------------------------------------------- /core/wing.py: -------------------------------------------------------------------------------- 1 | """ 2 | This work is licensed under the Creative Commons Attribution-NonCommercial 3 | 4.0 International License. To view a copy of this license, visit 4 | http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to 5 | Creative Commons, PO Box 1866, Mountain View, CA 94042, USA. 

"""


from functools import partial
import numpy as np
import cv2
import torch
import torch.nn as nn
import torch.nn.functional as F
from matplotlib import pyplot as plt

def normalize(x, eps=1e-6):
    """Apply min-max normalization."""
    x = x.contiguous()
    N, C, H, W = x.size()
    x_ = x.view(N*C, -1)
    max_val = torch.max(x_, dim=1, keepdim=True)[0]
    min_val = torch.min(x_, dim=1, keepdim=True)[0]
    x_ = (x_ - min_val) / (max_val - min_val + eps)
    out = x_.view(N, C, H, W)
    return out

def truncate(x, thres=0.1):
    """Remove small values in heatmaps."""
    return torch.where(x < thres, torch.zeros_like(x), x)

def resize(x, p=2):
    """Sharpen heatmaps by raising each value to the power p (no spatial resizing is performed)."""
    return x**p

def shift(x, N):
    """Cyclically shift rows N pixels up (N >= 0) or down (N < 0)."""
    up = N >= 0
    N = abs(N)
    _, _, H, W = x.size()
    head = torch.arange(N)
    tail = torch.arange(H-N)

    if up:
        head = torch.arange(H-N)+N
        tail = torch.arange(N)
    else:
        head = torch.arange(N) + (H-N)
        tail = torch.arange(H-N)

    # permutation indices
    perm = torch.cat([head, tail]).to(x.device)
    out = x[:, :, perm, :]
    return out

def get_preds_fromhm(hm):
    _, idx = torch.max(
        hm.view(hm.size(0), hm.size(1), hm.size(2) * hm.size(3)), 2)
    idx += 1
    preds = idx.view(idx.size(0), idx.size(1), 1).repeat(1, 1, 2).float()
    preds[..., 0].apply_(lambda x: (x - 1) % hm.size(3) + 1)
    preds[..., 1].add_(-1).div_(hm.size(2)).floor_().add_(1)

    for i in range(preds.size(0)):
        for j in range(preds.size(1)):
            hm_ = hm[i, j, :]
            pX, pY = int(preds[i, j, 0]) - 1, int(preds[i, j, 1]) - 1
            # NOTE: the bound 63 assumes 64x64 heatmaps; get_heatmap passes 256x256 maps
            if pX > 0 and pX < 63 and pY > 0 and pY < 63:
                diff = torch.FloatTensor(
                    [hm_[pY, pX + 1] - hm_[pY, pX - 1],
                     hm_[pY + 1, pX] - hm_[pY - 1, pX]])
                preds[i, j].add_(diff.sign_().mul_(.25))

    preds.add_(-0.5)
    return preds

def curve_fill(points, heatmapSize=256, sigma=3, erode=False):
    """Fill the polygon `points` on a heatmapSize x heatmapSize canvas, blur it with an
    odd-sized Gaussian kernel and scale it to [0, 1]. `erode` is unused."""
    sigma = max(1, (sigma // 2)*2 + 1)
    points = points.astype(np.int32)
    canvas = np.zeros([heatmapSize, heatmapSize])
    cv2.fillPoly(canvas, np.array([points]), 255)
    canvas = cv2.GaussianBlur(canvas, (sigma, sigma), sigma)
    return canvas.astype(np.float64)/255.0

class HourGlass(nn.Module):
    def __init__(self, num_modules, depth, num_features, first_one=False):
        super(HourGlass, self).__init__()
        self.num_modules = num_modules
        self.depth = depth
        self.features = num_features
        self.coordconv = CoordConvTh(64, 64, True, True, 256, first_one,
                                     out_channels=256,
                                     kernel_size=1, stride=1, padding=0)
        self._generate_network(self.depth)

    def _generate_network(self, level):
        self.add_module('b1_' + str(level), ConvBlock(256, 256))
        self.add_module('b2_' + str(level), ConvBlock(256, 256))
        if level > 1:
            self._generate_network(level - 1)
        else:
            self.add_module('b2_plus_' + str(level), ConvBlock(256, 256))
        self.add_module('b3_' + str(level), ConvBlock(256, 256))

    def _forward(self, level, inp):
        up1 = inp
        up1 = self._modules['b1_' + str(level)](up1)
        low1 = F.avg_pool2d(inp, 2, stride=2)
        low1 = self._modules['b2_' + str(level)](low1)

        if level > 1:
            low2 = self._forward(level - 1, low1)
        else:
            low2 = low1
            low2 = self._modules['b2_plus_' + str(level)](low2)
        low3 = low2
        low3 = self._modules['b3_' + str(level)](low3)
        up2 = F.interpolate(low3, scale_factor=2, mode='nearest')

        return up1 + up2

    def forward(self, x, heatmap):
        x, last_channel = self.coordconv(x, heatmap)
        return self._forward(self.depth, x), last_channel


class AddCoordsTh(nn.Module):
    def __init__(self, height=64, width=64, with_r=False, with_boundary=False):
        super(AddCoordsTh, self).__init__()
        self.with_r = with_r
        self.with_boundary = with_boundary
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        with torch.no_grad():
            x_coords = torch.arange(height).unsqueeze(1).expand(height, width).float()
            y_coords = torch.arange(width).unsqueeze(0).expand(height, width).float()
            x_coords = (x_coords / (height - 1)) * 2 - 1
            y_coords = (y_coords / (width - 1)) * 2 - 1
            coords = torch.stack([x_coords, y_coords], dim=0)

            if self.with_r:
                rr = torch.sqrt(torch.pow(x_coords, 2) + torch.pow(y_coords, 2))
                rr = (rr / torch.max(rr)).unsqueeze(0)
                coords = torch.cat([coords, rr], dim=0)

            self.coords = coords.unsqueeze(0).to(device)
            self.x_coords = x_coords.to(device)
            self.y_coords = y_coords.to(device)

    def forward(self, x, heatmap=None):
        """
        x: (batch, c, x_dim, y_dim)
        """
        coords = self.coords.repeat(x.size(0), 1, 1, 1)

        if self.with_boundary and heatmap is not None:
            boundary_channel = torch.clamp(heatmap[:, -1:, :, :], 0.0, 1.0)
            zero_tensor = torch.zeros_like(self.x_coords)
            xx_boundary_channel = torch.where(boundary_channel > 0.05, self.x_coords, zero_tensor).to(zero_tensor.device)
            yy_boundary_channel = torch.where(boundary_channel > 0.05, self.y_coords, zero_tensor).to(zero_tensor.device)
            coords = torch.cat([coords, xx_boundary_channel, yy_boundary_channel], dim=1)

        x_and_coords = torch.cat([x, coords], dim=1)
        return x_and_coords


class CoordConvTh(nn.Module):
    """CoordConv layer (Liu et al., 2018): appends normalized coordinate channels before a convolution."""
    def __init__(self, height, width, with_r, with_boundary,
                 in_channels, first_one=False, *args, **kwargs):
        super(CoordConvTh, self).__init__()
        self.addcoords = AddCoordsTh(height, width, with_r, with_boundary)
        in_channels += 2
        if with_r:
            in_channels += 1
        if with_boundary and not first_one:
            in_channels += 2
        self.conv = nn.Conv2d(in_channels=in_channels, *args, **kwargs)

    def forward(self, input_tensor, heatmap=None):
        ret = self.addcoords(input_tensor, heatmap)
        last_channel = ret[:, -2:, :, :]
        ret = self.conv(ret)
        return ret, last_channel


class ConvBlock(nn.Module):
    def __init__(self, in_planes, out_planes):
        super(ConvBlock, self).__init__()
        self.bn1 = nn.BatchNorm2d(in_planes)
        conv3x3 = partial(nn.Conv2d, kernel_size=3, stride=1, padding=1, bias=False, dilation=1)
        self.conv1 = conv3x3(in_planes, int(out_planes / 2))
        self.bn2 = nn.BatchNorm2d(int(out_planes / 2))
        self.conv2 = conv3x3(int(out_planes / 2), int(out_planes / 4))
        self.bn3 = nn.BatchNorm2d(int(out_planes / 4))
        self.conv3 = conv3x3(int(out_planes / 4), int(out_planes / 4))

        self.downsample = None
        if in_planes != out_planes:
            self.downsample = nn.Sequential(nn.BatchNorm2d(in_planes),
                                            nn.ReLU(True),
                                            nn.Conv2d(in_planes, out_planes, 1, 1, bias=False))

    def forward(self, x):
        residual = x

        out1 = self.bn1(x)
        out1 = F.relu(out1, True)
        out1 = self.conv1(out1)

        out2 = self.bn2(out1)
        out2 = F.relu(out2, True)
        out2 = self.conv2(out2)

        out3 = self.bn3(out2)
        out3 = F.relu(out3, True)
        out3 = self.conv3(out3)

        out3 = torch.cat((out1, out2, out3), 1)
        if self.downsample is not None:
            residual = self.downsample(residual)
        out3 += residual
        return out3


class FAN(nn.Module):
    def __init__(self, num_modules=1, end_relu=False, num_landmarks=98, fname_pretrained=None):
        super(FAN, self).__init__()
        self.num_modules = num_modules
        self.end_relu = end_relu

        # Base part
        self.conv1 = CoordConvTh(256, 256, True, False,
                                 in_channels=3, out_channels=64,
                                 kernel_size=7, stride=2, padding=3)
        self.bn1 = nn.BatchNorm2d(64)
        self.conv2 = ConvBlock(64, 128)
        self.conv3 = ConvBlock(128, 128)
        self.conv4 = ConvBlock(128, 256)

        # Stacking part
        self.add_module('m0', HourGlass(1, 4, 256, first_one=True))
        self.add_module('top_m_0', ConvBlock(256, 256))
        self.add_module('conv_last0', nn.Conv2d(256, 256, 1, 1, 0))
        self.add_module('bn_end0', nn.BatchNorm2d(256))
        self.add_module('l0', nn.Conv2d(256, num_landmarks+1, 1, 1, 0))

        if fname_pretrained is not None:
            self.load_pretrained_weights(fname_pretrained)

    def load_pretrained_weights(self, fname):
        if torch.cuda.is_available():
            checkpoint = torch.load(fname)
        else:
            checkpoint = torch.load(fname, map_location=torch.device('cpu'))
        model_weights = self.state_dict()
        model_weights.update({k: v for k, v in checkpoint['state_dict'].items()
                              if k in model_weights})
        self.load_state_dict(model_weights)

    def forward(self, x):
        x, _ = self.conv1(x)
        x = F.relu(self.bn1(x), True)
        x = F.avg_pool2d(self.conv2(x), 2, stride=2)
        x = self.conv3(x)
        x = self.conv4(x)

        outputs = []
        boundary_channels = []
        tmp_out = None
        ll, boundary_channel = self._modules['m0'](x, tmp_out)
        ll = self._modules['top_m_0'](ll)
        ll = F.relu(self._modules['bn_end0']
                    (self._modules['conv_last0'](ll)), True)

        # Predict heatmaps
        tmp_out = self._modules['l0'](ll)
        if self.end_relu:
            tmp_out = F.relu(tmp_out)  # HACK: Added relu
        outputs.append(tmp_out)
        boundary_channels.append(boundary_channel)
        return outputs, boundary_channels

    @torch.no_grad()
    def get_heatmap(self, x):
        ''' maps x from [-1, 1] to [0, 1], runs FAN at 256x256 and returns the landmark
        heatmaps (boundary channel dropped) upsampled to 256x256 '''
        x = F.interpolate(x, size=256, mode='bilinear', align_corners=True)
        x_01 = x*0.5 + 0.5
        outputs, _ = self(x_01)
        heatmaps = outputs[-1][:, :-1, :, :]
        scale_factor = x.size(2) // heatmaps.size(2)
        heatmaps = F.interpolate(heatmaps, scale_factor=scale_factor,
                                 mode='bilinear', align_corners=True)
        return heatmaps


    @torch.no_grad()
    def get_points2heatmap(self, x):
        ''' outputs 7-channel part maps (face, brow, eye, nose, upper lip, lower lip, teeth)
        for each image in x; the returned curves and boundary belong to the last image '''
        heatmaps = self.get_heatmap(x)
        landmarks = []
        for i in range(x.size(0)):
            pred_landmarks = get_preds_fromhm(heatmaps[i].cpu().unsqueeze(0))
            landmarks.append(pred_landmarks)
        scale_factor = x.size(2) // heatmaps.size(2)
        landmarks = torch.cat(landmarks) * scale_factor
        heatmap_all = torch.zeros((len(x), 7, x.size(2), x.size(2)))
        for i in range(0, len(x)):
            curves, boundary = self.points2curves(landmarks[i])
            heatmap = self.curves2segments(curves)
            heatmap = torch.from_numpy(heatmap).float()
            heatmap_all[i] = heatmap
        heatmap_all = heatmap_all.to(device=x.device)
        return heatmap_all, curves, boundary

    @torch.no_grad()
    def get_convex_hull(self, x):
        ''' outputs an (N, 1, 256, 256) mask that is 0 inside the landmark face polygon
        (jaw contour closed by the raised top-eyebrow points) and 1 elsewhere '''
        heatmaps = self.get_heatmap(x)
        skins = []
        for i in range(x.size(0)):
            pred_landmark = get_preds_fromhm(heatmaps[i].cpu().unsqueeze(0))
            scale_factor = x.size(2) // heatmaps.size(2)
            pred_landmark = torch.squeeze(pred_landmark) * scale_factor
            curves_ref, _ = self.points2curves(pred_landmark)
            roi_ref = self.curves2segments(curves_ref)
            skin = roi_ref[0]
            skin = (skin > 0).astype(int)
            skin = 1 - skin
            skin = torch.from_numpy(skin).type(dtype=torch.float32)
            skin = skin.unsqueeze(0)
            skin = skin.to(x.device)
            skins.append(skin)
        skins = torch.cat(skins)
        skins = skins.unsqueeze(1)
        return skins

    @torch.no_grad()
    def points2curves(self, points, heatmapSize=256, sigma=1, heatmap_num=15):
        curves = [0] * heatmap_num
        curves[0] = np.zeros((33, 2))  # contour
        curves[1] = np.zeros((5, 2))  # left top eyebrow
        curves[2] = np.zeros((5, 2))  # right top eyebrow
        curves[3] = np.zeros((4, 2))  # nose bridge
        curves[4] = np.zeros((5, 2))  # nose tip
        curves[5] = np.zeros((5, 2))  # left top eye
        curves[6] = np.zeros((5, 2))  # left bottom eye
        curves[7] = np.zeros((5, 2))  # right top eye
        curves[8] = np.zeros((5, 2))  # right bottom eye
        curves[9] = np.zeros((7, 2))  # upper lip, outer edge
        curves[10] = np.zeros((5, 2))  # upper lip, inner edge
        curves[11] = np.zeros((5, 2))  # lower lip, inner edge
        curves[12] = np.zeros((7, 2))  # lower lip, outer edge
        curves[13] = np.zeros((5, 2))  # left bottom eyebrow
        curves[14] = np.zeros((5, 2))  # right bottom eyebrow
        # assignment process
        # contour
        for i in range(33):
            curves[0][i] = points[i, :]
        for i in range(5):
            # left and right top eyebrows, shifted outward by 10 and upward by 40
            curves[1][i][0] = points[i + 33, 0] - 10
            curves[1][i][1] = points[i + 33, 1] - 40
            curves[2][i][0] = points[i + 42, 0] + 10
            curves[2][i][1] = points[i + 42, 1] - 40
        # nose bridge
        for i in range(4):
            curves[3][i] = points[i + 51, :]
        # nose tip
        for i in range(5):
            curves[4][i] = points[i + 55, :]
        # left top eye
        for i in range(5):
            curves[5][i] = points[i + 60, :]
        # left bottom eye
        curves[6][0] = points[64, :]
        curves[6][1] = points[65, :]
        curves[6][2] = points[66, :]
        curves[6][3] = points[67, :]
        curves[6][4] = points[60, :]
        # right top eye
        for i in range(5):
            curves[7][i] = points[i + 68, :]
        # right bottom eye
        curves[8][0] = points[72, :]
        curves[8][1] = points[73, :]
        curves[8][2] = points[74, :]
        curves[8][3] = points[75, :]
        curves[8][4] = points[68, :]
        # upper lip, outer edge
        for i in range(7):
            curves[9][i] = points[i + 76, :]
        # upper lip, inner edge
        for i in range(5):
            curves[10][i] = points[i + 88, :]
        # lower lip, inner edge
        curves[11][0] = points[92, :]
        curves[11][1] = points[93, :]
        curves[11][2] = points[94, :]
        curves[11][3] = points[95, :]
        curves[11][4] = points[88, :]
        # lower lip, outer edge
        curves[12][0] = points[82, :]
        curves[12][1] = points[83, :]
        curves[12][2] = points[84, :]
        curves[12][3] = points[85, :]
        curves[12][4] = points[86, :]
        curves[12][5] = points[87, :]
        curves[12][6] = points[76, :]
        # left bottom eyebrow
        curves[13][0] = points[38, :]
        curves[13][1] = points[39, :]
        curves[13][2] = points[40, :]
        curves[13][3] = points[41, :]
        curves[13][4] = points[33, :]
        # right bottom eyebrow
        curves[14][0] = points[46, :]
        curves[14][1] = points[47, :]
        curves[14][2] = points[48, :]
        curves[14][3] = points[49, :]
        curves[14][4] = points[50, :]
        return curves, None

    @torch.no_grad()
    def curves2segments(self, curves, heatmapSize=256, sigma=3):
        face = curve_fill(np.vstack([curves[0], curves[2][::-1], curves[1][::-1]]), heatmapSize, sigma)
        browL = curve_fill(np.vstack([curves[1], curves[13][::-1]]), heatmapSize, sigma)
        browR = curve_fill(np.vstack([curves[2], curves[14][::-1]]), heatmapSize, sigma)
        eyeL = curve_fill(np.vstack([curves[5], curves[6]]), heatmapSize, sigma)
        eyeR = curve_fill(np.vstack([curves[7], curves[8]]), heatmapSize, sigma)
        eye = np.max([eyeL, eyeR], axis=0)
        brow = np.max([browL, browR], axis=0)
        nose = curve_fill(np.vstack([curves[3][0:1], curves[4]]), heatmapSize, sigma)
        lipU = curve_fill(np.vstack([curves[9], curves[10][::-1]]), heatmapSize, sigma)
        lipD = curve_fill(np.vstack([curves[11], curves[12][::-1]]), heatmapSize, sigma)
        tooth = curve_fill(np.vstack([curves[10], curves[11][::-1]]), heatmapSize, sigma)
        return np.stack([face, brow, eye, nose, lipU, lipD, tooth])

    @torch.no_grad()
    def get_landmark_curve(self, x):
        ''' renders eye and lip landmark polylines on a white canvas; returns one HxWx3 uint8 array per image '''
        heatmaps = self.get_heatmap(x)
        landmarks = []
        for i in range(x.size(0)):
            pred_landmarks = get_preds_fromhm(heatmaps[i].cpu().unsqueeze(0))
            landmarks.append(pred_landmarks)
        scale_factor = x.size(2) // heatmaps.size(2)
        landmarks = torch.cat(landmarks) * scale_factor
        batch_landmark_figure = []
        for i in range(0, len(x)):
            dpi = 100
            input = x[i]
            preds = landmarks[i]
            fig = plt.figure(figsize=(input.shape[2] / dpi, input.shape[1] / dpi), dpi=dpi)
            ax = fig.add_subplot(1, 1, 1)
            ax.imshow(np.ones((input.shape[1], input.shape[2], input.shape[0])))
            plt.subplots_adjust(left=0, right=1, top=1, bottom=0)
            # eye
            ax.plot(preds[60:68, 0], preds[60:68, 1], marker='', markersize=5, linestyle='-', color='red', lw=2)
            ax.plot(preds[68:76, 0], preds[68:76, 1], marker='', markersize=5, linestyle='-', color='red', lw=2)
            # outer and inner lip
            ax.plot(preds[76:88, 0], preds[76:88, 1], marker='', markersize=5, linestyle='-', color='green', lw=2)
            ax.plot(preds[88:96, 0], preds[88:96, 1], marker='', markersize=5, linestyle='-', color='blue', lw=2)
            ax.axis('off')
            fig.canvas.draw()
            # tostring_rgb() is deprecated in recent Matplotlib releases; read the RGBA buffer and drop alpha
            data = np.asarray(fig.canvas.buffer_rgba())[..., :3].copy()
            batch_landmark_figure.append(data)
            plt.close(fig)
        return batch_landmark_figure


--------------------------------------------------------------------------------
/images/1:
--------------------------------------------------------------------------------


--------------------------------------------------------------------------------
/images/qualitative_comparisons.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/liqi-casia/FaceSwapper/fd999e0396c8b589c5a735bacd7e208918edc0c0/images/qualitative_comparisons.jpg
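`SoftErosion` in utils.py builds a cone-shaped kernel (weight = largest distance to the centre minus the tap's own distance, normalised to sum to 1), then repeatedly takes the element-wise minimum of the mask and its blurred copy, blurs once more, snaps values at or above `threshold` to 1 and rescales the rest by their maximum. The sketch below reproduces those steps in pure Python for a single-channel mask. The names `cone_kernel`, `conv2d_same` and `soft_erode` are hypothetical, an odd `kernel_size` is assumed (as in the repository's 15 and 17), and the zero-maximum guard is an addition of this sketch; the original divides unconditionally.

```python
import math


def cone_kernel(kernel_size):
    """Cone-shaped kernel like SoftErosion.__init__: (max distance - distance), summing to 1."""
    r = kernel_size // 2
    dist = [[math.hypot(x - r, y - r) for x in range(kernel_size)]
            for y in range(kernel_size)]
    d_max = max(max(row) for row in dist)
    weights = [[d_max - d for d in row] for row in dist]
    total = sum(sum(row) for row in weights)
    return [[wt / total for wt in row] for row in weights]


def conv2d_same(img, kernel):
    """Zero-padded 'same' 2-D correlation, as F.conv2d computes with padding=kernel_size // 2."""
    h, w, k = len(img), len(img[0]), len(kernel)
    r = k // 2
    out = [[0.0] * w for _ in range(h)]
    for i in range(h):
        for j in range(w):
            acc = 0.0
            for a in range(k):
                for b in range(k):
                    y, x = i + a - r, j + b - r
                    if 0 <= y < h and 0 <= x < w:
                        acc += kernel[a][b] * img[y][x]
            out[i][j] = acc
    return out


def soft_erode(mask, kernel_size=15, threshold=0.6, iterations=1):
    """Pure-Python mirror of SoftErosion.forward; returns (soft mask, hard mask)."""
    kernel = cone_kernel(kernel_size)
    x = [[float(v) for v in row] for row in mask]
    for _ in range(iterations - 1):
        blurred = conv2d_same(x, kernel)
        x = [[min(a, b) for a, b in zip(r1, r2)] for r1, r2 in zip(x, blurred)]
    x = conv2d_same(x, kernel)
    hard = [[v >= threshold for v in row] for row in x]
    rest = [v for row, hrow in zip(x, hard) for v, m in zip(row, hrow) if not m]
    rest_max = max(rest, default=0.0)
    soft = [[1.0 if m else (v / rest_max if rest_max > 0 else v)
             for v, m in zip(row, hrow)] for row, hrow in zip(x, hard)]
    return soft, hard
```

Because the kernel is largest at the centre and zero at the corners, pixels deep inside the mask stay at 1 while the border fades smoothly, which is what lets `postprocess` blend the swapped face into the target without a visible seam.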
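In utils.py, `postprocess` decides per image whether the source forehead looks occluded: if the visible source forehead is under 40% of the target's and the target forehead is large enough, it keeps the target's hair and forehead, then alpha-blends the swapped face with the target using the eroded mask. Because `torch.nonzero` on a 2-D mask returns a `(count, 2)` tensor, the products `H1 * W1` and `H2 * W2` are twice the pixel counts, so the `>= 1000` test effectively requires at least 500 target forehead pixels. The sketch below reproduces the decision and the blend on plain numbers; `keep_target_hair_and_forehead` and `blend` are hypothetical names.

```python
def keep_target_hair_and_forehead(src_forehead_pixels, tar_forehead_pixels,
                                  ratio=0.4, min_area=1000):
    """Mirror of the occlusion test in postprocess, taking nonzero-pixel counts."""
    # torch.nonzero(mask_2d).size() is (count, 2), so H * W in the original is 2 * count
    src_area = 2 * src_forehead_pixels
    tar_area = 2 * tar_forehead_pixels
    return src_area / (tar_area + 0.0001) < ratio and tar_area >= min_area


def blend(swapped, target, soft_mask):
    """result = swapped * m + target * (1 - m), element-wise on flat lists."""
    return [s * m + t * (1 - m) for s, t, m in zip(swapped, target, soft_mask)]
```

For example, 150 source and 500 target forehead pixels already trigger the preservation branch, while 100 and 400 do not, because 2 * 400 falls below the 1000 threshold.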
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
"""
This work is licensed under the Creative Commons Attribution-NonCommercial
4.0 International License. To view a copy of this license, visit
http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to
Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.
"""

import os
import argparse
import torch
from torch.backends import cudnn
from munch import Munch
from core.data_loader import get_train_loader, get_test_loader
from core.solver import Solver
from core.utils import get_config


def main(args):
    config = get_config(args.config)
    os.environ['CUDA_VISIBLE_DEVICES'] = config['cuda_device']
    cudnn.benchmark = True
    solver = Solver(config)
    if config['mode'] == 'train':
        loaders = Munch(src=get_train_loader(root=config['train_img_dir'],
                                             img_size=config['img_size'],
                                             batch_size=config['batch_size'],
                                             num_workers=config['num_workers']))
        solver.train(loaders)
    elif config['mode'] == 'test':
        loaders = Munch(src=get_test_loader(root=config['test_img_dir'],
                                            test_img_list=config['test_img_list'],
                                            img_size=config['img_size'],
                                            batch_size=config['batch_size'],
                                            num_workers=config['num_workers']))
        solver.test(loaders)
    else:
        raise NotImplementedError


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', type=str, default='param.yaml', help='Path to the config file.')
    args = parser.parse_args()
    main(args)

--------------------------------------------------------------------------------
/param.yaml:
--------------------------------------------------------------------------------

# Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).

dataset: celeba  #dataset we use
img_size: 256  #image resolution, default=256
id_dim: 512  #identity code dimension, default=512
cuda_device: '1'  #type=str, value main.py writes to CUDA_VISIBLE_DEVICES

# weight for objective functions
lambda_reg: 1  #type=float, weight for R1 regularization
lambda_recon: 1  #type=float, weight for reconstruction loss
lambda_att_face: 2  #type=float, weight for attribute preservation loss (face)
lambda_att_bg: 1  #type=float, weight for attribute preservation loss (background)
lambda_per: 1  #type=float, weight for perceptual loss
lambda_id: 10  #type=float, weight for identity reconstruction loss

# arguments
total_iters: 60000  #type=int, number of total iterations
batch_size: 8  #type=int, batch size for training
lr: 0.0001  #type=float, learning rate for attribute encoder and discriminator
id_lr: 0.0001  #type=float, learning rate for identity encoder
beta1: 0.0  #type=float, decay rate for 1st moment of Adam
beta2: 0.99  #type=float, decay rate for 2nd moment of Adam
weight_decay: 0.0001  #type=float, weight decay for optimizer


# misc
mode: 'train'  #type=str, choices=['train', 'test']
num_workers: 8  #type=int, number of workers used in DataLoader

# directory for training
train_img_dir: ['data/CelebA_Dataset/CelebA']  #type=list of str, directories containing training images
sample_dir: 'expr/samples/CelebA/'  #type=str, directory for saving generated images
checkpoint_dir: 'expr/checkpoints/CelebA/'  #type=str, directory for saving network checkpoints
log_dir: 'expr/logs/CelebA/'  #type=str, directory for saving logs
resume_iter: 0  #type=int, iteration from which to resume training

# directory for testing
test_img_dir: 'data/FF++_Dataset/ff++'  #type=str, directory containing testing images
test_img_list: 'data/FF++_Dataset/face_swap_list.txt'  #type=str, file listing the source/target images to swap
test_checkpoint_dir: 'pretrained_checkpoints/'  #type=str, directory for pretrained face swapping model
test_checkpoint_name: 'faceswapper.ckpt'  #type=str, pretrained face swapping model name
result_dir: 'expr/results/ff++/'  #type=str, directory for saving generated images
post_process: True  #type=bool, whether to apply the post-processing step

# pretrained face alignment (wing_path) and face recognition (face_model_path) models
wing_path: 'pretrained_checkpoints/wing.ckpt'
face_model_path: 'pretrained_checkpoints/model_ir_se50.pth'

# step size
print_every: 50  #type=int, number of iterations between log prints
sample_every: 2000  #type=int, number of iterations between saving sample images
save_every: 5000  #type=int, number of iterations between saving checkpoints


--------------------------------------------------------------------------------
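param.yaml documents an expected type for most keys in its comments, and main.py only accepts a `mode` of `'train'` or `'test'`. A small pre-flight check can catch typos before a long run; for instance, if `get_config` uses PyYAML (utils.py is not shown here), a value written as `1e-4` is read as a string rather than a float. The sketch below is hypothetical, not part of the repository: `check_config` and the key groupings are assumptions, and it takes an already-loaded dict so it does not depend on a YAML parser.

```python
INT_KEYS = ("total_iters", "batch_size", "num_workers", "resume_iter",
            "print_every", "sample_every", "save_every")
FLOAT_KEYS = ("lr", "id_lr", "beta1", "beta2", "weight_decay", "lambda_reg",
              "lambda_recon", "lambda_att_face", "lambda_att_bg", "lambda_per", "lambda_id")


def check_config(cfg):
    """Hypothetical pre-flight check for a dict loaded from param.yaml; returns error strings."""
    errors = []
    for key in INT_KEYS:
        v = cfg.get(key)
        # bool is a subclass of int in Python, so reject it explicitly
        if isinstance(v, bool) or not isinstance(v, int) or v < 0:
            errors.append(f"{key}: expected a non-negative int, got {v!r}")
    for key in FLOAT_KEYS:
        v = cfg.get(key)
        # whole-number weights such as `lambda_id: 10` load as ints, so accept both
        if isinstance(v, bool) or not isinstance(v, (int, float)):
            errors.append(f"{key}: expected a number, got {v!r}")
    if cfg.get("mode") not in ("train", "test"):
        errors.append(f"mode: expected 'train' or 'test', got {cfg.get('mode')!r}")
    if not isinstance(cfg.get("post_process"), bool):
        errors.append(f"post_process: expected a bool, got {cfg.get('post_process')!r}")
    if not isinstance(cfg.get("train_img_dir"), list):
        errors.append(f"train_img_dir: expected a list of paths, got {cfg.get('train_img_dir')!r}")
    return errors
```

Run on the values shipped in param.yaml, the check returns an empty list; changing `lr` to the string `"1e-4"` and `mode` to `"eval"` yields two errors.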