├── README.md
├── core
│   ├── checkpoint.py
│   ├── data_loader.py
│   ├── face_model.py
│   ├── model.py
│   ├── resnet.py
│   ├── solver.py
│   ├── utils.py
│   └── wing.py
├── images
│   ├── 1
│   └── qualitative_comparisons.jpg
├── main.py
└── param.yaml
/README.md:
--------------------------------------------------------------------------------
1 |
2 | ## FaceSwapper - Official PyTorch Implementation
3 |
4 |
5 |
6 |
7 | > **FaceSwapper: Learning Disentangled Representation for One-shot Progressive Face Swapping**
8 | > [Qi Li](https://liqi-casia.github.io/), [Weining Wang](https://scholar.google.com/citations?hl=en&user=NDPvobAAAAAJ), [Chengzhong Xu](https://www.fst.um.edu.mo/people/czxu/), [Zhenan Sun](https://scholar.google.com.au/citations?user=PuZGODYAAAAJ&hl=en), [Ming-Hsuan Yang](https://scholar.google.com/citations?user=p9-ohHsAAAAJ&hl=zh-CN)
9 | > In TPAMI 2024.
10 |
11 |
12 | > Paper: [https://ieeexplore.ieee.org/abstract/document/10536627](https://ieeexplore.ieee.org/abstract/document/10536627)
13 |
14 |
  15 | > Although face swapping has attracted much attention in recent years, it remains a challenging problem. The existing methods leverage a large number of data samples to explore the intrinsic properties of face swapping without taking into account the semantic information of face images. Moreover, the representation of the identity information tends to be fixed, leading to suboptimal face swapping. In this paper, we present a simple yet efficient method named FaceSwapper, for one-shot face swapping based on Generative Adversarial Networks. Our method consists of a disentangled representation module and a semantic-guided fusion module. The disentangled representation module is composed of an attribute encoder and an identity encoder, which aims to achieve the disentanglement of the identity and the attribute information. The identity encoder is more flexible and the attribute encoder contains more details of the attributes than its competitors. Benefiting from the disentangled representation, FaceSwapper can swap face images progressively. In addition, semantic information is introduced into the semantic-guided fusion module to control the swapped area and model the pose and expression more accurately. The experimental results show that our method achieves state-of-the-art results on benchmark datasets with fewer training samples.
16 |
17 |
18 |
19 | >
20 |
21 |
22 |
23 |
24 |
25 |
26 |
27 | ## Environment
28 | Clone this repository:
29 |
30 | ```bash
31 | git clone https://github.com/liqi-casia/FaceSwapper.git
32 | cd FaceSwapper/
33 | ```
34 |
35 | Install the dependencies:
36 | ```bash
37 | conda create -n faceswapper python=3.6.7
38 | conda activate faceswapper
39 | conda install -y pytorch=1.4.0 torchvision=0.5.0 cudatoolkit=10.0 -c pytorch
40 | conda install x264=='1!152.20180717' ffmpeg=4.0.2 -c conda-forge
41 | pip install opencv-python==4.1.2.30 ffmpeg-python==0.2.0 scikit-image==0.16.2
42 | pip install pillow==7.0.0 scipy==1.2.1 tqdm==4.43.0 munch==2.5.0
43 | conda install -y -c anaconda pyyaml
44 | pip install tensorboard tensorboardX
45 |
46 | ```
47 |
48 |
49 | ## Datasets and pre-trained checkpoints
  50 | We provide links to download the datasets used in FaceSwapper and the corresponding pre-trained checkpoints. The datasets and checkpoints should be moved to the `data` and `expr/checkpoints` directories, respectively.
51 |
52 | * Datasets. Click here to download the **CelebA Dataset** through [Baidu Netdisk](https://pan.baidu.com/s/12KqrRI_K9frgky2YPNSJ0Q?pwd=wvy5) or [Google Drive](https://drive.google.com/file/d/1W-FF-DNJ752L7Zqdm-otMT6frOcTLd7q/view?usp=sharing) and the **FaceForensics++ Dataset** through [Baidu Netdisk](https://pan.baidu.com/s/1LcKJw2sGkEAHWjmTcxou-A?pwd=omkc) or [Google Drive](https://drive.google.com/file/d/1Sa6v0m8s4xHXPPzeFt8K1q1n3T088Gvz/view?usp=sharing).
53 |
54 | * Checkpoints. Click here to download the **face recognition model** through [Baidu Netdisk](https://pan.baidu.com/s/11qcEiBjAsQPXwIqKjOE-rQ?pwd=2g78) or [Google Drive](https://drive.google.com/file/d/1-lxc-jZGIFNdwFUXQ9tDS9OSuhadj6AC/view?usp=sharing), the **face alignment model** through [Baidu Netdisk](https://pan.baidu.com/s/1htwmXDi2Gev8l09oJpr_Mg?pwd=ejmj) or [Google Drive](https://drive.google.com/file/d/1lBt4x4P5qaClB2ZN_POBV-ue41hdlaoJ/view?usp=sharing), and the **face swapping model** through [Baidu Netdisk](https://pan.baidu.com/s/1aIRX0twylUJ42z4sYhUaVA?pwd=bkru) or [Google Drive](https://drive.google.com/file/d/1Tb3V09wbaGe6SaiN3BZkOcCy7VJ0KYC8/view?usp=sharing).
55 |
  56 | After storing all the files, the directory structure of `./data` and `./pretrained_checkpoints` should be as follows.
57 |
58 |
59 |
60 |
61 | ```
62 | ./data
63 | ├── CelebA Dataset
64 | │ ├── CalebA images
65 | │ ├── CelebA landmark images
66 | │ └── CelebA mask images
67 | └── FF++ Dataset
68 | ├── ff++ images
69 | ├── ff++ landmark images
70 | ├── ff++ mask images
71 | └── ff++ parsing images
72 |
73 | ./pretrained_checkpoints
74 | ├── model_ir_se50.pth
75 | ├── wing.ckpt
76 | └── faceswapper.ckpt
77 | ```
78 |
79 |
80 |
81 |
82 | ## Generating swapped images
83 | After downloading the pre-trained checkpoints, you can synthesize swapped images. The following commands will save generated images to the `expr/results` directory.
84 |
85 |
  86 | FaceForensics++ Dataset. To generate swapped images, specify the testing parameters in param.yaml (in particular, change mode from 'train' to 'test', and pay attention to the parameters under *\#directory for testing*; a sketch of these entries is shown below). Then run the following command:
87 | ```bash
88 | python main.py
89 | ```
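For reference, here is a minimal sketch of the test-related entries in param.yaml. The key names follow those read in `core/solver.py` and this README; the values are illustrative assumptions rather than the shipped defaults:
```yaml
# minimal sketch -- illustrative values only
mode: test                                    # switch from 'train' to 'test'
img_size: 256
batch_size: 8

# directory for testing
test_checkpoint_dir: ./pretrained_checkpoints/
test_checkpoint_name: faceswapper.ckpt        # loaded via CheckpointIO.load_test
test_img_list: ./data/face_swap_list.txt      # list of "source target" pairs
result_dir: ./expr/results/ff++/
post_process: True                            # enable the post process described below
```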
  90 | There are three subfolders in `expr/results/ff++/`, named `swapped_result_single`, `swapped_result_afterps` and `swapped_result_all`. Each image is named *source_FS_target.png*, where the source image provides the identity information and the target image provides the attribute information.
91 |
  92 | - `swapped_result_single`: the swapped images.
  93 |
  94 | - `swapped_result_afterps`: the swapped images after post-processing.
  95 |
  96 | - `swapped_result_all`: concatenation of the source images, the target images, the swapped images, and the swapped images after post-processing.
  97 |
98 |
99 | Other Datasets.
 100 | First, automatically crop and align the face images from other datasets so that the proportion of the face within the whole image is similar to that of the CelebA and FaceForensics++ datasets. Then, define the face swapping list in the same format as `face_swap_list.txt` (`source_image_name target_image_name`; see the sketch below). The rest of the testing procedure is the same as for the FaceForensics++ Dataset.
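A hypothetical `face_swap_list.txt` with two swapping pairs could look like this (the file names are placeholders; each line holds a source image name and a target image name separated by a space):
```
source_0001.png target_0002.png
source_0003.png target_0004.png
```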
101 |
102 |
103 | Post Process. If occlusion exists in the source image (*e.g.*, hair, hat), we simply preserve the forehead and hair of the target image in the swapped image.
 104 | Otherwise, we simply preserve the hair of the target image. Just set `post_process: True` if you want the post process.
105 |
106 |
107 |
108 |
109 |
110 |
111 |
112 |
113 |
114 | ## Training networks
115 |
116 |
 117 | To train FaceSwapper from scratch, set the training parameters in param.yaml (in particular, change mode from 'test' to 'train', and pay attention to the parameters under *\#directory for training*; a sketch of these entries is shown after the command below), and run the following command. Generated images and network checkpoints will be stored in the `expr/samples` and `expr/checkpoints` directories, respectively. Training usually takes several days on a single Tesla V100 GPU, depending on the total number of training iterations.
118 |
119 |
120 |
121 |
122 | ```bash
123 | python main.py
124 | ```
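For reference, here is a minimal sketch of the training-related entries in param.yaml. The key names follow those read in `core/solver.py`; the values are illustrative assumptions, not recommended settings:
```yaml
# minimal sketch -- illustrative values only
mode: train                           # switch from 'test' to 'train'
img_size: 256
batch_size: 8
resume_iter: 0                        # set > 0 to resume from a saved checkpoint
total_iters: 200000
lr: 0.0001                            # generator / discriminator learning rate
id_lr: 0.00001                        # separate learning rate for the identity encoder
print_every: 100
sample_every: 5000
save_every: 10000

# directory for training
checkpoint_dir: ./expr/checkpoints/
sample_dir: ./expr/samples/
wing_path: ./pretrained_checkpoints/wing.ckpt
face_model_path: ./pretrained_checkpoints/model_ir_se50.pth
```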
125 |
126 |
127 |
128 | ## License
 129 | The source code, pre-trained models, and dataset are available under the [Creative Commons BY-NC 4.0](https://creativecommons.org/licenses/by-nc/4.0/) license. You can use, copy, transform, and build upon the material for non-commercial purposes as long as you give appropriate credit by citing our paper and indicate if changes were made.
130 | For technical, business and other inquiries, please contact qli@nlpr.ia.ac.cn.
131 |
132 |
133 |
134 | ## Citation
135 | If you find this work useful for your research, please cite our paper:
136 |
137 | ```
138 | @article{li2024learning,
139 | author={Li, Qi and Wang, Weining and Xu, Chengzhong and Sun, Zhenan and Yang, Ming-Hsuan},
140 | journal={IEEE Transactions on Pattern Analysis and Machine Intelligence},
141 | title={Learning Disentangled Representation for One-Shot Progressive Face Swapping},
142 | year={2024},
143 | volume={46},
144 | number={12},
145 | pages={8348-8364}
146 | }
147 | ```
148 |
149 | ## Acknowledgements
 150 | The code builds on the following projects. We would like to thank the authors for their contributions.
151 |
152 | - [Stargan V2](https://github.com/clovaai/stargan-v2)
153 |
154 | - [face alignment](https://github.com/1adrianb/face-alignment)
155 | - [AdaptiveWingLoss](https://github.com/protossw512/AdaptiveWingLoss)
156 |
157 |
158 |
--------------------------------------------------------------------------------
/core/checkpoint.py:
--------------------------------------------------------------------------------
1 | """
2 | This work is licensed under the Creative Commons Attribution-NonCommercial
3 | 4.0 International License. To view a copy of this license, visit
4 | http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to
5 | Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.
6 | """
7 |
8 | import os
9 | import torch
10 |
11 | class CheckpointIO(object):
12 | def __init__(self, fname_template, **kwargs):
13 | os.makedirs(os.path.dirname(fname_template), exist_ok=True)
14 | self.fname_template = fname_template
15 | self.module_dict = kwargs
16 | def register(self, **kwargs):
17 | self.module_dict.update(kwargs)
18 | def save(self, step):
19 | fname = self.fname_template.format(step)
20 | print('Saving checkpoint into %s...' % fname)
21 | outdict = {}
22 | for name, module in self.module_dict.items():
23 | outdict[name] = module.state_dict()
24 | torch.save(outdict, fname)
25 | def load(self, step):
26 | fname = self.fname_template.format(step)
27 | assert os.path.exists(fname), fname + ' does not exist!'
28 | print('Loading checkpoint from %s...' % fname)
29 | if torch.cuda.is_available():
30 | module_dict = torch.load(fname)
31 | else:
32 | module_dict = torch.load(fname, map_location=torch.device('cpu'))
33 | for name, module in self.module_dict.items():
34 | module.load_state_dict(module_dict[name])
35 | def load_test(self,ckptname):
36 | fname = self.fname_template + ckptname
37 | assert os.path.exists(fname), fname + ' does not exist!'
38 | print('Loading checkpoint from %s...' % fname)
39 | if torch.cuda.is_available():
40 | module_dict = torch.load(fname)
41 | else:
42 | module_dict = torch.load(fname, map_location=torch.device('cpu'))
43 | for name, module in self.module_dict.items():
44 | module.load_state_dict(module_dict[name])
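# Usage sketch (illustrative, not part of the original file): core/solver.py builds
# CheckpointIO objects from a filename template plus named networks, e.g.
#   ckptio = CheckpointIO(ospj(checkpoint_dir, '{:06d}_nets.ckpt'), generator=G, discriminator=D)
#   ckptio.save(100000)    # writes .../100000_nets.ckpt with one state_dict per registered module
#   ckptio.load(100000)    # restores every registered module from that file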
--------------------------------------------------------------------------------
/core/data_loader.py:
--------------------------------------------------------------------------------
1 | """
2 | This work is licensed under the Creative Commons Attribution-NonCommercial
3 | 4.0 International License. To view a copy of this license, visit
4 | http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to
5 | Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.
6 | """
7 |
8 | from pathlib import Path
9 | from itertools import chain
10 | from munch import Munch
11 | from PIL import Image
12 | import random
13 | import glob
14 | import copy
15 | import torch
16 | from torch.utils import data
17 | from torchvision import transforms
18 |
19 | def listdir(dname):
20 | fnames = list(chain(*[list(Path(dname).rglob('*.' + ext))
21 | for ext in ['png', 'jpg', 'jpeg', 'JPG']]))
22 | return fnames
23 | class DefaultDataset(data.Dataset):
24 | def __init__(self, root, transform=None):
25 | self.samples = listdir(root)
26 | self.samples.sort()
27 | self.transform = transform
28 | self.targets = None
29 | def __getitem__(self, index):
30 | fname = self.samples[index]
31 | img = Image.open(fname).convert('RGB')
32 | if self.transform is not None:
33 | img = self.transform(img)
34 | return img
35 | def __len__(self):
36 | return len(self.samples)
37 |
38 | class TrainFaceDataSet(data.Dataset):
39 | def __init__(self, data_path_list, transform=None, transform_seg=None):
40 | self.datasets = []
41 | self.num_per_folder =[]
42 | self.lm_image_path = data_path_list[0][:data_path_list[0].rfind('/')+1] \
43 | + data_path_list[0][data_path_list[0].rfind('/')+1:] + '_lm_images/'
44 | self.mask_image_path = data_path_list[0][:data_path_list[0].rfind('/')+1] \
45 | + data_path_list[0][data_path_list[0].rfind('/')+1:] + '_mask_images/'
46 | for data_path in data_path_list:
47 | image_list = glob.glob(f'{data_path}/*.*g')
48 | self.datasets.append(image_list)
49 | self.num_per_folder.append(len(image_list))
50 | self.transform = transform
51 | self.transform_seg = transform_seg
52 |
53 | def __getitem__(self, item):
54 | idx = 0
55 | while item >= self.num_per_folder[idx]:
56 | item -= self.num_per_folder[idx]
57 | idx += 1
58 | image_path = self.datasets[idx][item]
  59 |         source_lm_image_path = self.lm_image_path + image_path.split('/')[-1]
  60 |         source_mask_image_path = self.mask_image_path + image_path.split('/')[-1]
  61 |         source_image = Image.open(image_path).convert('RGB')
  62 |         source_lm_image = Image.open(source_lm_image_path).convert('RGB')
  63 |         source_mask_image = Image.open(source_mask_image_path).convert('L')
64 | if self.transform is not None:
65 | source_image = self.transform(source_image)
66 | source_lm_image = self.transform(source_lm_image)
67 | source_mask_image = self.transform_seg(source_mask_image)
68 | #choose ref from the same folder image
69 | temp = copy.deepcopy(self.datasets[idx])
70 | temp.pop(item)
71 | reference_image_path = temp[random.randint(0, len(temp)-1)]
72 | reference_lm_image_path = self.lm_image_path + reference_image_path.split('/')[-1]
73 | reference_mask_image_path = self.mask_image_path + reference_image_path.split('/')[-1]
74 | reference_image = Image.open(reference_image_path).convert('RGB')
75 | reference_lm_image = Image.open(reference_lm_image_path).convert('RGB')
76 | reference_mask_image = Image.open(reference_mask_image_path).convert('L')
77 | if self.transform is not None:
78 | reference_image = self.transform(reference_image)
79 | reference_lm_image = self.transform(reference_lm_image)
80 | reference_mask_image = self.transform_seg(reference_mask_image)
81 | outputs=dict(src=source_image, ref=reference_image, src_lm=source_lm_image, ref_lm=reference_lm_image,
82 | src_mask=1-source_mask_image, ref_mask=1-reference_mask_image)
83 | return outputs
84 | def __len__(self):
85 | return sum(self.num_per_folder)
86 |
87 | class TestFaceDataSet(data.Dataset):
88 | def __init__(self, data_path_list, test_img_list, transform=None, transform_seg=None):
89 | self.source_dataset = []
90 | self.reference_dataset = []
91 | self.data_path_list = data_path_list
92 | self.lm_image_path = data_path_list[:data_path_list.rfind('/')+1] \
93 | + data_path_list[data_path_list.rfind('/')+1:] + '_lm_images/'
94 | self.mask_image_path = data_path_list[:data_path_list.rfind('/')+1] \
95 | + data_path_list[data_path_list.rfind('/')+1:] + '_mask_images/'
96 | self.biseg_parsing_path = data_path_list[:data_path_list.rfind('/')+1] \
97 | + data_path_list[data_path_list.rfind('/')+1:] + '_parsing_images/'
  98 |         f = open(test_img_list, 'r')
  99 |         for line in f.readlines():
 100 |             # each line: "source_image_name target_image_name"; the reference name keeps its trailing newline and is stripped via [0:-1] in __getitem__
 101 |             self.source_dataset.append(line.split(' ')[0])
 102 |             self.reference_dataset.append(line.split(' ')[1])
 103 |         f.close()
104 | self.transform = transform
105 | self.transform_seg = transform_seg
106 | def __getitem__(self, item):
107 | source_image_path = self.data_path_list + '/' + self.source_dataset[item]
108 | try:
109 | source_image = Image.open(source_image_path).convert('RGB')
110 | except:
 111 |             print('failed to read %s' % source_image_path)
 112 |         source_lm_image_path = self.lm_image_path + self.source_dataset[item]
 113 |         source_mask_image_path = self.mask_image_path + self.source_dataset[item]
 114 |         source_parsing_image_path = self.biseg_parsing_path + self.source_dataset[item]
 115 |         source_lm_image = Image.open(source_lm_image_path).convert('RGB')
 116 |         source_mask_image = Image.open(source_mask_image_path).convert('L')
117 | source_parsing_image = Image.open(source_parsing_image_path).convert('L')
118 | if self.transform is not None:
119 | source_image = self.transform(source_image)
120 | source_lm_image = self.transform(source_lm_image)
121 | source_mask_image = self.transform_seg(source_mask_image)
122 | source_parsing = self.transform_seg(source_parsing_image)
123 | reference_image_path = self.data_path_list + '/' + self.reference_dataset[item][0:-1]
124 | try:
125 | reference_image = Image.open(reference_image_path).convert('RGB')
126 | except:
 127 |             print('failed to read %s' % reference_image_path)
128 | reference_lm_image_path = self.lm_image_path + self.reference_dataset[item][0:-1]
129 | reference_mask_image_path = self.mask_image_path + self.reference_dataset[item][0:-1]
130 | reference_parsing_image_path = self.biseg_parsing_path + self.reference_dataset[item][0:-1]
131 | reference_lm_image = Image.open(reference_lm_image_path).convert('RGB')
132 | reference_mask_image = Image.open(reference_mask_image_path).convert('L')
133 | reference_parsing = Image.open(reference_parsing_image_path).convert('L')
134 | if self.transform is not None:
135 | reference_image = self.transform(reference_image)
136 | reference_lm_image = self.transform(reference_lm_image)
137 | reference_mask_image = self.transform_seg(reference_mask_image)
138 | reference_parsing = self.transform_seg(reference_parsing)
139 | outputs=dict(src=source_image, ref=reference_image, src_lm=source_lm_image, ref_lm=reference_lm_image, src_mask=1-source_mask_image,
140 | ref_mask=1-reference_mask_image, src_parsing=source_parsing, ref_parsing=reference_parsing,
141 | src_name=self.source_dataset[item], ref_name=self.reference_dataset[item])
142 | return outputs
143 | def __len__(self):
144 | return len(self.source_dataset)
145 |
146 | def get_train_loader(root, img_size=256,
147 | batch_size=8, num_workers=4):
148 | print('Preparing dataLoader to fetch images during the training phase...')
149 | transform = transforms.Compose([
150 | transforms.Resize([img_size, img_size]),
151 | transforms.ToTensor(),
152 | transforms.Normalize(mean=[0.5, 0.5, 0.5],
153 | std=[0.5, 0.5, 0.5]),
154 | ])
155 | transform_seg = transforms.Compose([
156 | transforms.Resize([img_size, img_size]),
157 | transforms.ToTensor(),
158 | ])
159 | train_dataset = TrainFaceDataSet(root, transform, transform_seg)
160 | train_loader = data.DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True,
161 | num_workers=num_workers, drop_last=True)
162 | return train_loader
163 |
164 | def get_test_loader(root, test_img_list, img_size=256,
165 | batch_size=8, num_workers=4):
166 | print('Preparing dataLoader to fetch images during the testing phase...')
167 | transform = transforms.Compose([
168 | transforms.Resize([img_size, img_size]),
169 | transforms.ToTensor(),
170 | transforms.Normalize(mean=[0.5, 0.5, 0.5],
171 | std=[0.5, 0.5, 0.5]),
172 | ])
173 | transform_seg = transforms.Compose([
174 | transforms.Resize([img_size, img_size]),
175 | transforms.ToTensor(),
176 | ])
177 | test_dataset = TestFaceDataSet(root, test_img_list, transform, transform_seg)
178 | test_loader = data.DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False,
179 | num_workers=num_workers, drop_last=True)
180 | return test_loader
181 |
182 | class InputFetcher:
183 | def __init__(self, loader, mode=''):
184 | self.loader = loader
185 | self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
186 | self.mode = mode
187 | def _fetch_inputs(self):
188 | try:
189 | inputs_data = next(self.iter)
190 | except (AttributeError, StopIteration):
191 | self.iter = iter(self.loader)
192 | inputs_data= next(self.iter)
193 | return inputs_data
194 | def __next__(self):
195 | t_inputs = self._fetch_inputs()
196 | inputs = Munch(src=t_inputs['src'], tar=t_inputs['ref'], src_lm=t_inputs['src_lm'],
197 | tar_lm=t_inputs['ref_lm'], src_mask=t_inputs['src_mask'], tar_mask=t_inputs['ref_mask'])
198 | if self.mode=='train':
199 | inputs = Munch({k: t.to(self.device) for k, t in inputs.items()})
200 | elif self.mode=='test':
201 | inputs = Munch({k: t.to(self.device) for k, t in inputs.items()}, src_parsing=t_inputs['src_parsing'].to(self.device),
202 | tar_parsing=t_inputs['ref_parsing'].to(self.device), src_name=t_inputs['src_name'],tar_name=t_inputs['ref_name'])
203 | return inputs
204 |
205 |
--------------------------------------------------------------------------------
/core/face_model.py:
--------------------------------------------------------------------------------
1 | from torch.nn import Linear, Conv2d, BatchNorm1d, BatchNorm2d, PReLU, ReLU, Sigmoid, Dropout2d, Dropout, AvgPool2d, MaxPool2d, AdaptiveAvgPool2d, Sequential, Module, Parameter
2 | import torch
3 | from collections import namedtuple
4 | import math
5 |
6 | ################################## Original Arcface Model #############################################################
7 |
8 | class Flatten(Module):
9 | def forward(self, input):
10 | return input.view(input.size(0), -1)
11 |
12 | def l2_norm(input,axis=1):
13 | norm = torch.norm(input,2,axis,True)
14 | output = torch.div(input, norm)
15 | return output
16 |
17 | class SEModule(Module):
18 | def __init__(self, channels, reduction):
19 | super(SEModule, self).__init__()
20 | self.avg_pool = AdaptiveAvgPool2d(1)
21 | self.fc1 = Conv2d(
22 | channels, channels // reduction, kernel_size=1, padding=0 ,bias=False)
23 | self.relu = ReLU(inplace=True)
24 | self.fc2 = Conv2d(
25 | channels // reduction, channels, kernel_size=1, padding=0 ,bias=False)
26 | self.sigmoid = Sigmoid()
27 |
28 | def forward(self, x):
29 | module_input = x
30 | x = self.avg_pool(x)
31 | x = self.fc1(x)
32 | x = self.relu(x)
33 | x = self.fc2(x)
34 | x = self.sigmoid(x)
35 | return module_input * x
36 |
37 | class bottleneck_IR(Module):
38 | def __init__(self, in_channel, depth, stride):
39 | super(bottleneck_IR, self).__init__()
40 | if in_channel == depth:
41 | self.shortcut_layer = MaxPool2d(1, stride)
42 | else:
43 | self.shortcut_layer = Sequential(
44 | Conv2d(in_channel, depth, (1, 1), stride ,bias=False), BatchNorm2d(depth))
45 | self.res_layer = Sequential(
46 | BatchNorm2d(in_channel),
47 | Conv2d(in_channel, depth, (3, 3), (1, 1), 1 ,bias=False), PReLU(depth),
48 | Conv2d(depth, depth, (3, 3), stride, 1 ,bias=False), BatchNorm2d(depth))
49 |
50 | def forward(self, x):
51 | shortcut = self.shortcut_layer(x)
52 | res = self.res_layer(x)
53 | return res + shortcut
54 |
55 | class bottleneck_IR_SE(Module):
56 | def __init__(self, in_channel, depth, stride):
57 | super(bottleneck_IR_SE, self).__init__()
58 | if in_channel == depth:
59 | self.shortcut_layer = MaxPool2d(1, stride)
60 | else:
61 | self.shortcut_layer = Sequential(
62 | Conv2d(in_channel, depth, (1, 1), stride ,bias=False),
63 | BatchNorm2d(depth))
64 | self.res_layer = Sequential(
65 | BatchNorm2d(in_channel),
66 | Conv2d(in_channel, depth, (3,3), (1,1),1 ,bias=False),
67 | PReLU(depth),
68 | Conv2d(depth, depth, (3,3), stride, 1 ,bias=False),
69 | BatchNorm2d(depth),
70 | SEModule(depth,16)
71 | )
72 | def forward(self,x):
73 | shortcut = self.shortcut_layer(x)
74 | res = self.res_layer(x)
75 | return res + shortcut
76 |
77 | class Bottleneck(namedtuple('Block', ['in_channel', 'depth', 'stride'])):
78 | '''A named tuple describing a ResNet block.'''
79 |
80 | def get_block(in_channel, depth, num_units, stride = 2):
81 | return [Bottleneck(in_channel, depth, stride)] + [Bottleneck(depth, depth, 1) for i in range(num_units-1)]
82 |
83 | def get_blocks(num_layers):
84 | if num_layers == 50:
85 | blocks = [
86 | get_block(in_channel=64, depth=64, num_units = 3),
87 | get_block(in_channel=64, depth=128, num_units=4),
88 | get_block(in_channel=128, depth=256, num_units=14),
89 | get_block(in_channel=256, depth=512, num_units=3)
90 | ]
91 | elif num_layers == 100:
92 | blocks = [
93 | get_block(in_channel=64, depth=64, num_units=3),
94 | get_block(in_channel=64, depth=128, num_units=13),
95 | get_block(in_channel=128, depth=256, num_units=30),
96 | get_block(in_channel=256, depth=512, num_units=3)
97 | ]
98 | elif num_layers == 152:
99 | blocks = [
100 | get_block(in_channel=64, depth=64, num_units=3),
101 | get_block(in_channel=64, depth=128, num_units=8),
102 | get_block(in_channel=128, depth=256, num_units=36),
103 | get_block(in_channel=256, depth=512, num_units=3)
104 | ]
105 | return blocks
106 |
107 | class Backbone(Module):
108 | def __init__(self, num_layers, drop_ratio, mode='ir'):
109 | super(Backbone, self).__init__()
110 | assert num_layers in [50, 100, 152], 'num_layers should be 50,100, or 152'
111 | assert mode in ['ir', 'ir_se'], 'mode should be ir or ir_se'
112 | blocks = get_blocks(num_layers)
113 | if mode == 'ir':
114 | unit_module = bottleneck_IR
115 | elif mode == 'ir_se':
116 | unit_module = bottleneck_IR_SE
117 | self.input_layer = Sequential(Conv2d(3, 64, (3, 3), 1, 1 ,bias=False),
118 | BatchNorm2d(64),
119 | PReLU(64))
120 | self.output_layer = Sequential(BatchNorm2d(512),
121 | Dropout(drop_ratio),
122 | Flatten(),
123 | Linear(512 * 7 * 7, 512),
124 | BatchNorm1d(512))
125 | # )
126 | modules = []
127 | for block in blocks:
128 | for bottleneck in block:
129 | modules.append(
130 | unit_module(bottleneck.in_channel,
131 | bottleneck.depth,
132 | bottleneck.stride))
133 | self.body = Sequential(*modules)
134 |
135 | def forward(self,x):
136 | feats = []
137 | x = self.input_layer(x)
138 | for m in self.body.children():
139 | x = m(x)
140 | feats.append(x)
141 | x = self.output_layer(x)
142 | return l2_norm(x), feats
143 |
144 | ################################## MobileFaceNet #############################################################
145 |
146 | class Conv_block(Module):
147 | def __init__(self, in_c, out_c, kernel=(1, 1), stride=(1, 1), padding=(0, 0), groups=1):
148 | super(Conv_block, self).__init__()
149 | self.conv = Conv2d(in_c, out_channels=out_c, kernel_size=kernel, groups=groups, stride=stride, padding=padding, bias=False)
150 | self.bn = BatchNorm2d(out_c)
151 | self.prelu = PReLU(out_c)
152 | def forward(self, x):
153 | x = self.conv(x)
154 | x = self.bn(x)
155 | x = self.prelu(x)
156 | return x
157 |
158 | class Linear_block(Module):
159 | def __init__(self, in_c, out_c, kernel=(1, 1), stride=(1, 1), padding=(0, 0), groups=1):
160 | super(Linear_block, self).__init__()
161 | self.conv = Conv2d(in_c, out_channels=out_c, kernel_size=kernel, groups=groups, stride=stride, padding=padding, bias=False)
162 | self.bn = BatchNorm2d(out_c)
163 | def forward(self, x):
164 | x = self.conv(x)
165 | x = self.bn(x)
166 | return x
167 |
168 | class Depth_Wise(Module):
169 | def __init__(self, in_c, out_c, residual = False, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=1):
170 | super(Depth_Wise, self).__init__()
171 | self.conv = Conv_block(in_c, out_c=groups, kernel=(1, 1), padding=(0, 0), stride=(1, 1))
172 | self.conv_dw = Conv_block(groups, groups, groups=groups, kernel=kernel, padding=padding, stride=stride)
173 | self.project = Linear_block(groups, out_c, kernel=(1, 1), padding=(0, 0), stride=(1, 1))
174 | self.residual = residual
175 | def forward(self, x):
176 | if self.residual:
177 | short_cut = x
178 | x = self.conv(x)
179 | x = self.conv_dw(x)
180 | x = self.project(x)
181 | if self.residual:
182 | output = short_cut + x
183 | else:
184 | output = x
185 | return output
186 |
187 | class Residual(Module):
188 | def __init__(self, c, num_block, groups, kernel=(3, 3), stride=(1, 1), padding=(1, 1)):
189 | super(Residual, self).__init__()
190 | modules = []
191 | for _ in range(num_block):
192 | modules.append(Depth_Wise(c, c, residual=True, kernel=kernel, padding=padding, stride=stride, groups=groups))
193 | self.model = Sequential(*modules)
194 | def forward(self, x):
195 | return self.model(x)
196 |
197 | class MobileFaceNet(Module):
198 | def __init__(self, embedding_size):
199 | super(MobileFaceNet, self).__init__()
200 | self.conv1 = Conv_block(3, 64, kernel=(3, 3), stride=(2, 2), padding=(1, 1))
201 | self.conv2_dw = Conv_block(64, 64, kernel=(3, 3), stride=(1, 1), padding=(1, 1), groups=64)
202 | self.conv_23 = Depth_Wise(64, 64, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=128)
203 | self.conv_3 = Residual(64, num_block=4, groups=128, kernel=(3, 3), stride=(1, 1), padding=(1, 1))
204 | self.conv_34 = Depth_Wise(64, 128, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=256)
205 | self.conv_4 = Residual(128, num_block=6, groups=256, kernel=(3, 3), stride=(1, 1), padding=(1, 1))
206 | self.conv_45 = Depth_Wise(128, 128, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=512)
207 | self.conv_5 = Residual(128, num_block=2, groups=256, kernel=(3, 3), stride=(1, 1), padding=(1, 1))
208 | self.conv_6_sep = Conv_block(128, 512, kernel=(1, 1), stride=(1, 1), padding=(0, 0))
209 | self.conv_6_dw = Linear_block(512, 512, groups=512, kernel=(7,7), stride=(1, 1), padding=(0, 0))
210 | self.conv_6_flatten = Flatten()
211 | self.linear = Linear(512, embedding_size, bias=False)
212 | self.bn = BatchNorm1d(embedding_size)
213 |
214 | def forward(self, x):
215 | out = self.conv1(x)
216 | out = self.conv2_dw(out)
217 | out = self.conv_23(out)
218 | out = self.conv_3(out)
219 | out = self.conv_34(out)
220 | out = self.conv_4(out)
221 | out = self.conv_45(out)
222 | out = self.conv_5(out)
223 | out = self.conv_6_sep(out)
224 | out = self.conv_6_dw(out)
225 | out = self.conv_6_flatten(out)
226 | out = self.linear(out)
227 | out = self.bn(out)
228 | return l2_norm(out)
229 |
230 | ################################## Arcface head #############################################################
231 |
232 | class Arcface(Module):
 233 |     # implementation of the ArcFace additive angular margin loss, https://arxiv.org/abs/1801.07698
234 | def __init__(self, embedding_size=512, classnum=51332, s=64., m=0.5):
235 | super(Arcface, self).__init__()
236 | self.classnum = classnum
237 | self.kernel = Parameter(torch.Tensor(embedding_size,classnum))
238 | self.kernel.data.uniform_(-1, 1).renorm_(2,1,1e-5).mul_(1e5)
239 | self.m = m
240 | self.s = s
241 | self.cos_m = math.cos(m)
242 | self.sin_m = math.sin(m)
243 | self.mm = self.sin_m * m # issue 1
244 | self.threshold = math.cos(math.pi - m)
 245 |     def forward(self, embeddings, label):
 246 |         # weights norm
 247 |         nB = len(embeddings)
 248 |         kernel_norm = l2_norm(self.kernel, axis=0)
 249 |         cos_theta = torch.mm(embeddings, kernel_norm)
250 | cos_theta = cos_theta.clamp(-1,1)
251 | cos_theta_2 = torch.pow(cos_theta, 2)
252 | sin_theta_2 = 1 - cos_theta_2
253 | sin_theta = torch.sqrt(sin_theta_2)
254 | cos_theta_m = (cos_theta * self.cos_m - sin_theta * self.sin_m)
255 | cond_v = cos_theta - self.threshold
256 | cond_mask = cond_v <= 0
257 | keep_val = (cos_theta - self.mm)
258 | cos_theta_m[cond_mask] = keep_val[cond_mask]
259 | output = cos_theta * 1.0
260 | idx_ = torch.arange(0, nB, dtype=torch.long)
261 | output[idx_, label] = cos_theta_m[idx_, label]
262 | output *= self.s
263 | return output
264 |
265 | ################################## Cosface head #############################################################
266 |
267 | class Am_softmax(Module):
268 | # implementation of additive margin softmax loss in https://arxiv.org/abs/1801.05599
269 | def __init__(self,embedding_size=512,classnum=51332):
270 | super(Am_softmax, self).__init__()
271 | self.classnum = classnum
272 | self.kernel = Parameter(torch.Tensor(embedding_size,classnum))
273 | # initial kernel
274 | self.kernel.data.uniform_(-1, 1).renorm_(2,1,1e-5).mul_(1e5)
275 | self.m = 0.35
276 | self.s = 30.
 277 |     def forward(self, embeddings, label):
 278 |         kernel_norm = l2_norm(self.kernel, axis=0)
 279 |         cos_theta = torch.mm(embeddings, kernel_norm)
280 | cos_theta = cos_theta.clamp(-1,1)
281 | phi = cos_theta - self.m
282 | label = label.view(-1,1)
283 | index = cos_theta.data * 0.0
284 | index.scatter_(1,label.data.view(-1,1),1)
285 | index = index.byte()
286 | output = cos_theta * 1.0
287 | output[index] = phi[index]
288 | output *= self.s
289 | return output
290 |
291 |
--------------------------------------------------------------------------------
/core/model.py:
--------------------------------------------------------------------------------
1 | """
2 | This work is licensed under the Creative Commons Attribution-NonCommercial
3 | 4.0 International License. To view a copy of this license, visit
4 | http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to
5 | Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.
6 | """
7 |
8 | import copy
9 | import math
10 | from munch import Munch
11 | import numpy as np
12 | import torch
13 | import torch.nn as nn
14 | import torch.nn.functional as F
15 | from core.wing import FAN
16 |
17 |
18 |
19 | class ResBlk(nn.Module):
20 | def __init__(self, dim_in, dim_out, actv=nn.LeakyReLU(0.2),
21 | normalize=False, downsample=False):
22 | super().__init__()
23 | self.actv = actv
24 | self.normalize = normalize
25 | self.downsample = downsample
26 | self.learned_sc = dim_in != dim_out
27 | self._build_weights(dim_in, dim_out)
28 |
29 | def _build_weights(self, dim_in, dim_out):
30 | self.conv1 = nn.Conv2d(dim_in, dim_in, 3, 1, 1)
31 | self.conv2 = nn.Conv2d(dim_in, dim_out, 3, 1, 1)
32 | if self.normalize:
33 | self.norm1 = nn.InstanceNorm2d(dim_in, affine=True)
34 | self.norm2 = nn.InstanceNorm2d(dim_in, affine=True)
35 | if self.learned_sc:
36 | self.conv1x1 = nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=False)
37 |
38 | def _shortcut(self, x):
39 | if self.learned_sc:
40 | x = self.conv1x1(x)
41 | if self.downsample:
42 | x = F.avg_pool2d(x, 2)
43 | return x
44 |
45 | def _residual(self, x):
46 | if self.normalize:
47 | x = self.norm1(x)
48 | x = self.actv(x)
49 | x = self.conv1(x)
50 | if self.downsample:
51 | x = F.avg_pool2d(x, 2)
52 | if self.normalize:
53 | x = self.norm2(x)
54 | x = self.actv(x)
55 | x = self.conv2(x)
56 | return x
57 |
58 | def forward(self, x):
59 | x = self._shortcut(x) + self._residual(x)
60 | return x / math.sqrt(2) # unit variance
61 |
62 |
63 | class AdaIN(nn.Module):
64 | def __init__(self, id_dim, num_features):
65 | super().__init__()
66 | input_nc = 3
67 | self.norm = nn.InstanceNorm2d(num_features, affine=False)
68 | self.fc = nn.Linear(id_dim, num_features*2)
69 | self.conv_weight = nn.Conv2d(input_nc, num_features, kernel_size=3, padding=1)
70 | self.conv_bias = nn.Conv2d(input_nc, num_features, kernel_size=3, padding=1)
71 |
72 | def forward(self, x, s, mask, landmark):
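        # Descriptive comments (added for clarity, not in the original file):
        # x: attribute feature map; s: identity code from the identity encoder;
        # mask: dict of masks keyed by spatial size (judging from the '# face area' comment
        #       below, a mask is ~0 on the face region and ~1 on background/hair);
        # landmark: landmark image of the attribute face, injected through two 3x3 convs.
        # AdaIN is applied only to the face region; the rest of x is kept unchanged.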
73 | h = self.fc(s)
74 | h = h.view(h.size(0), h.size(1), 1, 1)
75 | gamma, beta = torch.chunk(h, chunks=2, dim=1)
76 | face_part = (1-mask[x.size(2)]) * x # face area;
77 | norm_face_part = (1 + gamma) * self.norm(face_part) + beta
78 | landmark = F.interpolate(landmark, x.size(2), mode='bilinear',align_corners=True)
79 | weight_norm = self.conv_weight(landmark)
80 | bias_norm = self.conv_bias(landmark)
81 | norm_face_part = norm_face_part * (1+weight_norm) + bias_norm
82 | new_face = mask[x.size(2)] * x + (1-mask[x.size(2)])*norm_face_part
83 | return new_face
84 |
85 |
86 | class AdainResBlk(nn.Module):
87 | def __init__(self, dim_in, dim_out, id_dim=512,
88 | actv=nn.LeakyReLU(0.2), upsample=False):
89 | super().__init__()
90 | self.actv = actv
91 | self.upsample = upsample
92 | self.learned_sc = dim_in != dim_out
93 | self._build_weights(dim_in, dim_out, id_dim)
94 |
95 | def _build_weights(self, dim_in, dim_out, id_dim):
96 | self.conv1 = nn.Conv2d(dim_in, dim_out, 3, 1, 1)
97 | self.conv2 = nn.Conv2d(dim_out, dim_out, 3, 1, 1)
98 | self.norm1 = AdaIN(id_dim, dim_in)
99 | self.norm2 = AdaIN(id_dim, dim_out)
100 | if self.learned_sc:
101 | self.conv1x1 = nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=False)
102 |
103 | def _shortcut(self, x):
104 | if self.upsample:
105 | x = F.interpolate(x, scale_factor=2, mode='nearest')
106 | if self.learned_sc:
107 | x = self.conv1x1(x)
108 | return x
109 |
110 | def _residual(self, x, s, mask, landmark):
111 | x = self.norm1(x, s, mask, landmark)
112 | x = self.actv(x)
113 | if self.upsample:
114 | x = F.interpolate(x, scale_factor=2, mode='nearest')
115 | x = self.conv1(x)
116 | x = self.norm2(x, s, mask,landmark)
117 | x = self.actv(x)
118 | x = self.conv2(x)
119 | return x
120 |
121 | def forward(self, x, s, mask,landmark):
122 | out = self._residual(x, s, mask,landmark)
123 | return out
124 |
125 | class Generator(nn.Module):
126 | def __init__(self, img_size=256, id_dim=512, max_conv_dim=512):
127 | super().__init__()
128 | dim_in = 2**14 // img_size
129 | self.img_size = img_size
130 | self.id_encoder = IdentityEncoder(self.img_size, id_dim, max_conv_dim)
131 | self.attr_encoder = AttrEncoder(self.img_size,max_conv_dim)
132 | self.org_decoder = Decoder(self.img_size,id_dim,max_conv_dim)
133 | def forward(self, x_a, x_b, x_a_lm, x_b_lm, x_a_mask=None, x_b_mask=None):
134 | x_a_attr, x_a_idvec, x_a_cache = self.encode(x_a, x_a_mask)
135 | x_b_attr, x_b_idvec, x_b_cache = self.encode(x_b, x_b_mask)
136 | x_ba, ms_features_ba, ms_outputs_ba = self.decode(x_b_attr, x_a_idvec, x_b_lm, x_b_cache, x_b_mask) # a's identity
137 | x_ab, ms_features_ab, ms_outputs_ab = self.decode(x_a_attr, x_b_idvec, x_a_lm, x_a_cache, x_a_mask) # b's identity
138 | return x_ba, x_ab, ms_features_ba, ms_features_ab, ms_outputs_ba, ms_outputs_ab
139 | def encode(self, image,mask=None):
140 | # encode an image to its attribute code and identity code
141 | id_vec, id_all_features = self.id_encoder(image)
142 | attr, attr_all_features, cache = self.attr_encoder(image, mask)
143 | return attr, id_vec, cache
144 | def decode(self, attr, id_vec, lm_image, cache, mask=None):
145 | image, ms_features, ms_outputs = self.org_decoder(attr, id_vec, lm_image, cache, mask)
146 | return image, ms_features, ms_outputs
147 | def encode_features(self, image, mask=None):
148 | # encode an image to its multiscale attribute feature and identity feature
149 | id_vec, id_all_features = self.id_encoder(image)
150 | attr, attr_all_features, cache = self.attr_encoder(image, mask)
151 | return attr_all_features, id_all_features
152 |
153 |
154 | # attribute encoder
155 | class AttrEncoder(nn.Module):
156 | def __init__(self, img_size=256, max_conv_dim=512):
157 | super().__init__()
158 | dim_in = 2**14 // img_size
159 | blocks = []
160 | blocks += [nn.Conv2d(3, dim_in, 3, 1, 1)]
161 | repeat_num = int(np.log2(img_size)) - 4
162 | repeat_num += 1
163 | for _ in range(repeat_num):
164 | dim_out = min(dim_in*2, max_conv_dim)
165 | blocks += [ResBlk(dim_in, dim_out, normalize=True, downsample=True)]
166 | dim_in = dim_out
167 | # bottleneck blocks
168 | for _ in range(2):
169 | blocks += [ResBlk(dim_out, dim_out, normalize=True)]
170 | self.model = nn.Sequential(*blocks)
171 | def forward(self, x, masks=None):
172 | attr_all_features = []
173 | cache = {}
174 | for block in self.model:
175 | if (masks is not None) and (x.size(2) in [32, 64, 128]):
176 | cache[x.size(2)] = x
177 | x = block(x)
178 | attr_all_features.append(x)
179 | return x, attr_all_features, cache
180 |
181 | # identity encoder
182 | class IdentityEncoder(nn.Module):
183 | def __init__(self, img_size=256, id_dim=512, max_conv_dim=512):
184 | super().__init__()
185 | dim_in = 2**14 // img_size
186 | blocks = []
187 | blocks += [nn.Conv2d(3, dim_in, 3, 1, 1)]
188 | repeat_num = int(np.log2(img_size)) - 2
189 | for _ in range(repeat_num):
190 | dim_out = min(dim_in*2, max_conv_dim)
191 | blocks += [ResBlk(dim_in, dim_out, downsample=True)]
192 | dim_in = dim_out
193 | blocks += [nn.LeakyReLU(0.2)]
194 | blocks += [nn.Conv2d(dim_out, dim_out, 4, 1, 0)]
195 | blocks += [nn.LeakyReLU(0.2)]
196 | self.model = nn.Sequential(*blocks)
197 | self.final_layer = nn.ModuleList()
198 | self.final_layer += [nn.Linear(dim_out, id_dim)]
199 | def forward(self, x):
200 | id_all_features = []
201 | for block in self.model:
202 | x = block(x)
203 | id_all_features.append(x)
204 | x = x.view(x.size(0), -1) # batch_size, dim_out
205 | for block in self.final_layer:
206 | x = block(x)
207 | x = x.view(x.size(0), -1)
208 | id_all_features.append(x)
209 | return x, id_all_features
210 |
211 | # decoder
212 | class Decoder(nn.Module):
213 | def __init__(self, img_size=256, id_dim=512, max_conv_dim=512):
214 | super().__init__()
215 | dim_in = 2**14 // img_size
216 | dim_in_org = dim_in
217 | self.mask_size = [32,64,128]
218 | self.x_size =[8,16,32,64,128,256]
219 | self.to_rgb = nn.Sequential(
220 | nn.InstanceNorm2d(dim_in, affine=True),
221 | nn.LeakyReLU(0.2),
222 | nn.Conv2d(dim_in, 3, 1, 1, 0))
223 | blocks = []
224 | repeat_num = int(np.log2(img_size)) - 4
225 | repeat_num += 1
226 | for _ in range(repeat_num):
227 | dim_out = min(dim_in*2, max_conv_dim)
228 | blocks.insert(0, AdainResBlk(dim_out, dim_in, id_dim,
229 | upsample=True))
230 | dim_in = dim_out
231 | # bottleneck blocks
232 | for _ in range(2):
233 | blocks.insert(0, AdainResBlk(dim_out, dim_out, id_dim))
234 | self.model = nn.Sequential(*blocks)
235 | def to_rgb_output(dim_before_RGB):
236 | output = nn.Sequential(
237 | nn.InstanceNorm2d(dim_before_RGB, affine=True),
238 | nn.LeakyReLU(0.2),
239 | nn.Conv2d(dim_before_RGB, 3, 1, 1, 0))
240 | return output
241 | def skip_connection(dim_in_skip):
242 | dim_out_skip = dim_in_skip
243 | output = nn.Sequential(
244 | nn.Conv2d(dim_in_skip, dim_out_skip, 1, 1, 0))
245 | return output
246 | self.rgb_converters = nn.ModuleList()
247 | self.skip_connects = nn.ModuleList()
248 | for i in self.mask_size:
249 | self.rgb_converters.append(to_rgb_output(int(dim_in_org*img_size/i)))
250 | self.skip_connects.append(skip_connection(int(dim_in_org*img_size/i)))
251 |
252 | def forward(self, x, id_vec, lm_image, cache=None, mask=None):
253 | ms_outputs=[]
254 | ms_features=[]
255 | dict_masks={}
256 | if (mask is not None):
257 | for i in self.x_size:
258 | mask = F.interpolate(mask, i, mode='bilinear', align_corners=True)
259 | dict_masks[i] = mask
260 | index=0
261 | for block in self.model:
262 | mask = dict_masks[x.size(2)]
263 | x = block(x, id_vec, dict_masks, lm_image) # 1-masks, face area
264 | if (mask is not None) and (x.size(2) in self.mask_size):
265 | mask = dict_masks[x.size(2)]
266 | x = x + self.skip_connects[index](mask * cache[x.size(2)]) # this mask should be attr
267 | ms_outputs.append(self.rgb_converters[index](x))
268 | index = index+1
269 | ms_features.append(x)
270 | x = self.to_rgb(x)
271 | return x, ms_features, ms_outputs
272 |
273 |
274 | class Discriminator(nn.Module):
275 | def __init__(self, img_size=256, max_conv_dim=512):
276 | super().__init__()
277 | dim_in = 2**14 // img_size
278 | num_domains = 1
279 | blocks = []
280 | blocks += [nn.Conv2d(6, dim_in, 3, 1, 1)]
281 | repeat_num = int(np.log2(img_size)) - 2
282 | for _ in range(repeat_num):
283 | dim_out = min(dim_in*2, max_conv_dim)
284 | blocks += [ResBlk(dim_in, dim_out, downsample=True)]
285 | dim_in = dim_out
286 | blocks += [nn.LeakyReLU(0.2)]
287 | blocks += [nn.Conv2d(dim_out, dim_out, 4, 1, 0)]
288 | blocks += [nn.LeakyReLU(0.2)]
289 | blocks += [nn.Conv2d(dim_out, num_domains, 1, 1, 0)]
290 | self.main = nn.Sequential(*blocks)
291 |
292 | def forward(self, x,lm_images):
293 | #out = self.main(x)
294 | x = torch.cat((x,lm_images), dim=-3)
295 | for block in self.main:
296 | x = block(x)
297 | out = x
298 | out = out.view(out.size(0), -1) # (batch, num_domains)
299 | return out
300 |
301 |
302 | def build_model(config):
303 | generator = Generator(config['img_size'], config['id_dim'], max_conv_dim=512)
304 | discriminator = Discriminator(config['img_size'],max_conv_dim=512)
305 | generator_ema = copy.deepcopy(generator)
306 | nets = Munch(generator=generator,discriminator=discriminator)
307 | nets_ema = Munch(generator=generator_ema)
308 | nets.fan = FAN(fname_pretrained=config['wing_path']).eval()
309 | nets_ema.fan = nets.fan
310 | return nets, nets_ema
--------------------------------------------------------------------------------
/core/resnet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | def conv3x3(in_planes, out_planes, stride=1):
6 | """3x3 convolution with padding"""
7 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
8 | padding=1, bias=False)
9 |
10 |
11 | class BasicBlock(nn.Module):
12 | def __init__(self, in_chan, out_chan, stride=1):
13 | super(BasicBlock, self).__init__()
14 | self.conv1 = conv3x3(in_chan, out_chan, stride)
15 | self.bn1 = nn.BatchNorm2d(out_chan)
16 | self.conv2 = conv3x3(out_chan, out_chan)
17 | self.bn2 = nn.BatchNorm2d(out_chan)
18 | self.relu = nn.ReLU(inplace=True)
19 | self.downsample = None
20 | if in_chan != out_chan or stride != 1:
21 | self.downsample = nn.Sequential(
22 | nn.Conv2d(in_chan, out_chan,
23 | kernel_size=1, stride=stride, bias=False),
24 | nn.BatchNorm2d(out_chan),
25 | )
26 |
27 | def forward(self, x):
28 | residual = self.conv1(x)
29 | residual = F.relu(self.bn1(residual))
30 | residual = self.conv2(residual)
31 | residual = self.bn2(residual)
32 |
33 | shortcut = x
34 | if self.downsample is not None:
35 | shortcut = self.downsample(x)
36 |
37 | out = shortcut + residual
38 | out = self.relu(out)
39 | return out
40 |
41 |
42 | def create_layer_basic(in_chan, out_chan, bnum, stride=1):
43 | layers = [BasicBlock(in_chan, out_chan, stride=stride)]
44 | for i in range(bnum-1):
45 | layers.append(BasicBlock(out_chan, out_chan, stride=1))
46 | return nn.Sequential(*layers)
47 |
48 |
49 | class Resnet18(nn.Module):
50 | def __init__(self):
51 | super(Resnet18, self).__init__()
52 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
53 | bias=False)
54 | self.bn1 = nn.BatchNorm2d(64)
55 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
56 | self.layer1 = create_layer_basic(64, 64, bnum=2, stride=1)
57 | self.layer2 = create_layer_basic(64, 128, bnum=2, stride=2)
58 | self.layer3 = create_layer_basic(128, 256, bnum=2, stride=2)
59 | self.layer4 = create_layer_basic(256, 512, bnum=2, stride=2)
60 |
61 | def forward(self, x):
62 | x = self.conv1(x)
63 | x = F.relu(self.bn1(x))
64 | x = self.maxpool(x)
65 |
66 | x = self.layer1(x)
67 | feat8 = self.layer2(x) # 1/8
68 | feat16 = self.layer3(feat8) # 1/16
69 | feat32 = self.layer4(feat16) # 1/32
70 | return feat8, feat16, feat32
71 |
72 | def get_params(self):
73 | wd_params, nowd_params = [], []
74 | for name, module in self.named_modules():
75 | if isinstance(module, (nn.Linear, nn.Conv2d)):
76 | wd_params.append(module.weight)
77 | if not module.bias is None:
78 | nowd_params.append(module.bias)
79 | elif isinstance(module, nn.BatchNorm2d):
80 | nowd_params += list(module.parameters())
81 | return wd_params, nowd_params
82 |
--------------------------------------------------------------------------------
/core/solver.py:
--------------------------------------------------------------------------------
1 | """
2 | This work is licensed under the Creative Commons Attribution-NonCommercial
3 | 4.0 International License. To view a copy of this license, visit
4 | http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to
5 | Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.
6 | """
7 |
8 | import os
9 | from os.path import join as ospj
10 | import time
11 | import datetime
12 | from munch import Munch
13 | import torch
14 | import torch.nn as nn
15 | import torch.nn.functional as F
16 | import tensorboardX
17 | from core.model import build_model
18 | from core.checkpoint import CheckpointIO
19 | from core.data_loader import InputFetcher
20 | import core.utils as utils
21 | from core.face_model import Backbone
22 |
23 |
24 |
25 | class Solver(nn.Module):
26 | def __init__(self, config):
27 | super().__init__()
28 | self.config = config
29 | self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
30 | self.nets, self.nets_ema = build_model(config)
31 | self.arcface = Backbone(50, 0.6, 'ir_se') # .to(device)
32 | self.arcface.eval()
33 | self.arcface.load_state_dict(torch.load(config['face_model_path'])) # , strict=False
34 | # below setattrs are to make networks be children of Solver, e.g., for self.to(self.device)
35 | for name, module in self.nets.items():
36 | utils.print_network(module, name)
37 | setattr(self, name, module)
38 | for name, module in self.nets_ema.items():
39 | setattr(self, name + '_ema', module)
40 | if config['mode'] == 'train':
41 | print(config)
42 | beta1 = config['beta1']
43 | beta2 = config['beta2']
44 | dis_params = list(self.nets['discriminator'].parameters())
45 | id_params = list(self.nets['generator'].id_encoder.parameters())
  46 |             id_param_ids = list(map(id, self.nets['generator'].id_encoder.parameters()))  # python ids of the identity-encoder parameters
  47 |             gen_params_wo_id = filter(lambda x: id(x) not in id_param_ids, self.nets['generator'].parameters())
48 | gen_id_params = []
49 | for p in id_params:
50 | if p.requires_grad:
51 | gen_id_params.append(p)
52 | gen_params = []
53 | for p in gen_params_wo_id:
54 | if p.requires_grad:
55 | gen_params.append(p)
56 |
57 | for net in self.nets.keys():
58 | if net == 'generator':
59 | self.gen_opt = torch.optim.Adam([{'params': gen_params},
60 | {'params': gen_id_params,'lr':config['id_lr']}],
61 | lr=config['lr'],
62 | betas=[beta1, beta2],
63 | weight_decay=config['weight_decay'])
64 | elif net == 'discriminator':
65 | self.dis_opt = torch.optim.Adam(
66 | [p for p in dis_params if p.requires_grad],
67 | lr=config['lr'],
68 | betas=[beta1, beta2],
69 | weight_decay=config['weight_decay'])
70 | self.optims = Munch()
71 | self.optims['generator'] = self.gen_opt
72 | self.optims['discriminator'] = self.dis_opt
73 | self.ckptios = [
74 | CheckpointIO(ospj(config['checkpoint_dir'], '{:06d}_nets.ckpt'), **self.nets),
75 | CheckpointIO(ospj(config['checkpoint_dir'], '{:06d}_nets_ema.ckpt'), **self.nets_ema),
76 | CheckpointIO(ospj(config['checkpoint_dir'], '{:06d}_optims.ckpt'), **self.optims)]
77 | else:
78 | self.ckptios = [CheckpointIO(config['test_checkpoint_dir'], **self.nets_ema)]
79 | self.to(self.device)
80 | for name, network in self.named_children():
81 | # Do not initialize the pretrained network parameters
82 | if ('ema' not in name) and ('fan' not in name) and ('arcface' not in name):
83 | print('Initializing %s...' % name)
84 | network.apply(utils.he_init)
85 | # Setup logger and output folders
86 | model_name = config['dataset']
87 | timestamp = "{0:%Y-%m-%dT%H-%M-%S/}".format(datetime.datetime.now())
88 | self.train_writer = tensorboardX.SummaryWriter(os.path.join(config['log_dir'] + model_name, timestamp))
89 | def _save_checkpoint(self, step):
90 | for ckptio in self.ckptios:
91 | ckptio.save(step)
92 | def _load_checkpoint(self, step):
93 | for ckptio in self.ckptios:
94 | ckptio.load(step)
95 | def _load_test_checkpoint(self, ckptname):
96 | for ckptio in self.ckptios:
97 | ckptio.load_test(ckptname)
98 | def _reset_grad(self):
99 | for optim in self.optims.values():
100 | optim.zero_grad()
101 |
102 | def train(self, loaders):
103 | config = self.config
104 | nets = self.nets
105 | nets_ema = self.nets_ema
106 | gen_opt = self.gen_opt
107 | dis_opt = self.dis_opt
108 | arcface = self.arcface
109 | fetcher = InputFetcher(loaders.src, 'train')
110 | inputs_val = next(fetcher)
111 | # resume training if necessary
112 | if config['resume_iter'] > 0:
113 | self._load_checkpoint(config['resume_iter'])
114 | print('Start training...')
115 | start_time = time.time()
116 | for i in range(config['resume_iter'], config['total_iters']):
117 | # fetch images
118 | inputs = next(fetcher)
119 | src, tar, src_lm, tar_lm, src_mask, tar_mask= inputs.src, inputs.tar, inputs.src_lm, \
120 | inputs.tar_lm, inputs.src_mask, inputs.tar_mask
121 | # train the discriminator
122 | d_loss, d_losses_all = compute_d_loss(
123 | nets, config, src, tar, src_lm, tar_lm, src_mask, tar_mask)
124 | self._reset_grad()
125 | d_loss.backward()
126 | dis_opt.step()
127 | # train the generator
128 | g_loss, g_losses_all = compute_g_loss(
129 | nets, config, src, tar, src_lm, tar_lm, src_mask, tar_mask, arcface)
130 | self._reset_grad()
131 | g_loss.backward()
132 | gen_opt.step()
133 | # compute moving average of network parameters
134 | moving_average(nets.generator, nets_ema.generator, beta=0.999)
135 | if (i+1) % config['print_every'] == 0:
136 | elapsed = time.time() - start_time
137 | elapsed = str(datetime.timedelta(seconds=elapsed))[:-7]
138 | log = "Elapsed time [%s], Iteration [%i/%i], " % (elapsed, i+1, config['total_iters'])
139 | all_losses = dict()
140 | for loss, prefix in zip([d_losses_all, g_losses_all],
141 | ['D/all_', 'G/all_']):
142 | for key, value in loss.items():
143 | all_losses[prefix + key] = value
144 | self.train_writer.add_scalar(prefix+key, value, i+1)
145 | log += ' '.join(['%s: [%.4f]' % (key, value) for key, value in all_losses.items()])
146 | print(log)
147 | # generate images for observation
148 | if (i+1) % config['sample_every'] == 0:
149 | os.makedirs(config['sample_dir'], exist_ok=True)
150 | utils.display_image(nets_ema, config, inputs=inputs_val, step=i+1)
151 | # save model checkpoints
152 | if (i+1) % config['save_every'] == 0:
153 | self._save_checkpoint(step=i+1)
154 | self.train_writer.close()
155 |
156 | @torch.no_grad()
157 | def test(self, loaders):
158 | config = self.config
159 | nets_ema = self.nets_ema
160 | os.makedirs(config['result_dir'], exist_ok=True)
161 | self._load_test_checkpoint(config['test_checkpoint_name'])
162 | f = open(config['test_img_list'], 'r')
163 | img_num = len(f.readlines())
164 | f.close()
165 | total_iters = int(img_num/config['batch_size']) + 1
166 | save_dir=config['result_dir']
167 | test_fetcher = InputFetcher(loaders.src, 'test')
168 | for i in range(0, total_iters):
169 | inputs = next(test_fetcher)
170 | utils.disentangle_and_swapping_test(nets_ema, config, inputs, save_dir)
171 |
172 | def compute_d_loss(nets, config, x_a, x_b, x_a_lm, x_b_lm, x_a_mask, x_b_mask):
173 | x_a.requires_grad_()
174 | x_b.requires_grad_()
175 | out_a = nets.discriminator(x_a,x_a_lm)
176 | out_b = nets.discriminator(x_b,x_b_lm)
177 | loss_real_a = adv_loss(out_a, 1)
178 | loss_real_b = adv_loss(out_b, 1)
179 | loss_reg_a = r1_reg(out_a, x_a)
180 | loss_reg_b = r1_reg(out_b, x_b)
181 | loss_real = loss_real_a + loss_real_b
182 | loss_reg = loss_reg_a + loss_reg_b
183 | x_ba, x_ab, ms_features_ba, ms_features_ab, ms_outputs_ba,\
184 | ms_outputs_ab = nets.generator(x_a, x_b, x_a_lm, x_b_lm, x_a_mask, x_b_mask)
185 | out_ba = nets.discriminator(x_ba, x_b_lm) # x_ba, a's identity
186 | out_ab = nets.discriminator(x_ab, x_a_lm) # x_ab, b's identity
187 | loss_fake_ba = adv_loss(out_ba, 0)
188 | loss_fake_ab = adv_loss(out_ab, 0)
189 | loss_fake = loss_fake_ba + loss_fake_ab
190 | loss = loss_real + loss_fake + config['lambda_reg'] * loss_reg
191 | return loss, Munch(real=loss_real.item(),
192 | fake=loss_fake.item(),
193 | reg=loss_reg.item(),
194 | total_loss=loss.item())
195 |
196 | def compute_g_loss(nets, config, x_a, x_b, x_a_lm, x_b_lm, x_a_mask, x_b_mask, arcface):
197 | att_a, i_a_prime, cache_a = nets.generator.encode(x_a, x_a_mask)
198 | att_b, i_b_prime, cache_b = nets.generator.encode(x_b, x_b_mask)
199 | x_a_recon, ms_features_a, ms_outputs_a = nets.generator.decode(att_a, i_a_prime, x_a_lm, cache_a, x_a_mask)
200 | x_b_recon, ms_features_b, ms_outputs_b = nets.generator.decode(att_b, i_b_prime, x_b_lm, cache_b, x_b_mask)
201 | x_ba, ms_features_ba, ms_outputs_ba = nets.generator.decode(att_b, i_a_prime, x_b_lm, cache_b, x_b_mask) # x_ba, a's identity
202 | x_ab, ms_features_ab, ms_outputs_ab = nets.generator.decode(att_a, i_b_prime, x_a_lm, cache_a, x_a_mask) # x_ab, b's identity
203 | with torch.no_grad():
204 | a_embed, a_feats = arcface(F.interpolate(x_a, [112, 112], mode='bilinear', align_corners=True))
205 | b_embed, b_feats = arcface( F.interpolate(x_b, [112, 112], mode='bilinear', align_corners=True))
206 | ba_embed, ba_feats = arcface(F.interpolate(x_ba, [112, 112], mode='bilinear', align_corners=True))
207 | ab_embed, ab_feats = arcface(F.interpolate(x_ab, [112, 112], mode='bilinear', align_corners=True))
208 | loss_id_a = (1 - F.cosine_similarity(a_embed, ba_embed, dim=1))
209 | loss_id_b=(1 - F.cosine_similarity(b_embed, ab_embed, dim=1))
210 | loss_id_a = loss_id_a.mean()
211 | loss_id_b = loss_id_b.mean()
212 | loss_id = loss_id_a + loss_id_b
213 | out_a = nets.discriminator(x_ba, x_b_lm) #x_ba, a's identity
214 | out_b = nets.discriminator(x_ab, x_a_lm) #x_ab, b's identity
215 | loss_adv_a = adv_loss(out_a, 1)
216 | loss_adv_b = adv_loss(out_b, 1)
217 | loss_adv = loss_adv_a + loss_adv_b
218 | loss_recon_x_a = torch.mean(torch.abs(x_a_recon - x_a))
219 | loss_recon_x_b = torch.mean(torch.abs(x_b_recon - x_b))
220 | loss_recon = loss_recon_x_a + loss_recon_x_b
221 | loss_att_a_face = style_loss_face(ms_features_a, ms_features_ab, x_a_mask)
222 | loss_att_b_face = style_loss_face(ms_features_b, ms_features_ba, x_b_mask)
223 | loss_att_face = loss_att_a_face + loss_att_b_face
224 | loss_att_a_bg = style_loss_background(ms_features_a, ms_features_ab,x_a_mask)
225 | loss_att_b_bg = style_loss_background(ms_features_b, ms_features_ba,x_b_mask)
226 | loss_att_bg = loss_att_a_bg + loss_att_b_bg
227 | loss = loss_adv + config['lambda_id'] * loss_id \
228 | + config['lambda_recon'] * loss_recon + config['lambda_att_face'] * loss_att_face \
229 | + config['lambda_att_bg'] * loss_att_bg
230 | return loss, Munch(adv=loss_adv.item(),
231 | id=loss_id.item(),
232 | recon=loss_recon.item(),
233 | att_face=loss_att_face.item(),
234 | att_bg=loss_att_bg.item(),
235 | total_loss=loss.item())
236 |
237 | def moving_average(model, model_test, beta=0.999):
238 | for param, param_test in zip(model.parameters(), model_test.parameters()):
239 | param_test.data = torch.lerp(param.data, param_test.data, beta)
240 |
241 | def adv_loss(logits, target):
242 | assert target in [1, 0]
243 | targets = torch.full_like(logits, fill_value=target)
244 | loss = F.binary_cross_entropy_with_logits(logits, targets)
245 | return loss
246 |
247 | def r1_reg(d_out, x_in):
248 | # zero-centered gradient penalty for real images
249 | batch_size = x_in.size(0)
250 | grad_dout = torch.autograd.grad(
251 | outputs=d_out.sum(), inputs=x_in,
252 | create_graph=True, retain_graph=True, only_inputs=True
253 | )[0]
254 | grad_dout2 = grad_dout.pow(2)
255 | assert(grad_dout2.size() == x_in.size())
256 | reg = 0.5 * grad_dout2.view(batch_size, -1).sum(1).mean(0)
257 | return reg
258 |
259 | def compute_gram(x):
260 | b, ch, h, w = x.size()
261 | f = x.view(b, ch, w * h)
262 | f_T = f.transpose(1, 2)
263 | G = f.bmm(f_T) / (h * w * ch)
264 | return G
265 |
266 | def style_loss_face(x, y, masks):
267 | style_loss = 0.0
268 | for i in range(0, len(x)):
269 | masks = F.interpolate(masks, x[i].size(2), mode='bilinear')
270 | style_loss += torch.nn.L1Loss()(compute_gram((1-masks)*x[i]), compute_gram((1-masks)*y[i]))
271 | style_loss = style_loss/len(x)
272 | return style_loss
273 |
274 | def style_loss_background(x, y, masks):
275 | style_loss = 0.0
276 | for i in range(0, len(x)):
277 | masks = F.interpolate(masks, x[i].size(2), mode='bilinear')
278 | style_loss += torch.nn.L1Loss()(masks*x[i], masks*y[i]) # note: plain L1 on masked background features, not a Gram-matrix loss as in style_loss_face
279 | style_loss = style_loss/len(x)
280 | return style_loss
281 |
282 |
283 |
284 |
285 |
286 |
287 |
288 |
--------------------------------------------------------------------------------
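A minimal sketch (not part of the repository) of how the `adv_loss` and `r1_reg` helpers defined above combine into the discriminator objective of `compute_d_loss`. The toy one-logit "discriminator", the tensor shapes, and the weight of 1.0 (matching `lambda_reg` in `param.yaml`) are illustrative assumptions; it presumes the repository root is on `PYTHONPATH`.

```python
# Illustrative sketch only: exercising adv_loss and r1_reg from core/solver.py
# on dummy tensors, mirroring the structure of compute_d_loss.
import torch
from core.solver import adv_loss, r1_reg  # module-level helpers shown above

def toy_discriminator(x):
    # stand-in for nets.discriminator: one logit per image
    return x.mean(dim=(1, 2, 3)).unsqueeze(1)

x_real = torch.randn(4, 3, 256, 256, requires_grad=True)
x_fake = torch.randn(4, 3, 256, 256)

loss_real = adv_loss(toy_discriminator(x_real), 1)    # real images -> target 1
loss_fake = adv_loss(toy_discriminator(x_fake), 0)    # generated images -> target 0
loss_reg = r1_reg(toy_discriminator(x_real), x_real)  # zero-centered gradient penalty on real images
loss_d = loss_real + loss_fake + 1.0 * loss_reg       # lambda_reg = 1 in param.yaml
loss_d.backward()
```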
/core/utils.py:
--------------------------------------------------------------------------------
1 | """
2 | This work is licensed under the Creative Commons Attribution-NonCommercial
3 | 4.0 International License. To view a copy of this license, visit
4 | http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to
5 | Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.
6 | """
7 |
8 | import os
9 | from os.path import join as ospj
10 | import torch
11 | import torch.nn as nn
12 | import torch.nn.functional as F
13 | import torchvision.utils as vutils
14 | import yaml
15 |
16 | def print_network(network, name):
17 | num_params = 0
18 | for p in network.parameters():
19 | num_params += p.numel()
20 | print("Number of parameters of %s: %i" % (name, num_params))
21 | def he_init(module):
22 | if isinstance(module, nn.Conv2d):
23 | nn.init.kaiming_normal_(module.weight, mode='fan_in', nonlinearity='relu')
24 | if module.bias is not None:
25 | nn.init.constant_(module.bias, 0)
26 | if isinstance(module, nn.Linear):
27 | nn.init.kaiming_normal_(module.weight, mode='fan_in', nonlinearity='relu')
28 | if module.bias is not None:
29 | nn.init.constant_(module.bias, 0)
30 | def denormalize(x):
31 | out = (x + 1) / 2
32 | return out.clamp_(0, 1)
33 | def save_image(x, ncol, filename):
34 | x = denormalize(x)
35 | vutils.save_image(x.cpu(), filename, nrow=ncol, padding=0)
36 | def get_config(config):
37 | with open(config, 'r') as stream:
38 | return yaml.load(stream,Loader=yaml.FullLoader)
39 |
40 |
41 | @torch.no_grad()
42 | def disentangle_and_reconstruct(nets, config, src, tar, src_lm, tar_lm, src_mask, tar_mask, filename):
43 | N, C, H, W = src.size()
44 | att_a, i_a_prime, cache_a = nets.generator.encode(src, src_mask) # cache follows attr
45 | att_b, i_b_prime, cache_b = nets.generator.encode(tar, tar_mask)
46 | a_recon,_,_ = nets.generator.decode(att_a, i_a_prime, src_lm, cache_a, src_mask)
47 | b_recon,_,_ = nets.generator.decode(att_b, i_b_prime, tar_lm, cache_b, tar_mask)
48 | disp_concat = [src, tar, a_recon, b_recon]
49 | disp_concat = torch.cat(disp_concat, dim=0)
50 | save_image(disp_concat, N, filename)
51 |
52 | @torch.no_grad()
53 | def disentangle_and_swapping(nets, config, src, tar, src_lm, tar_lm, src_mask, tar_mask, filename):
54 | N, C, H, W = src.size()
55 | whitespace = torch.ones(1, C, H, W).to(src.device)
56 | src_with_whitespace = torch.cat([whitespace, src], dim=0)
57 | disp_concat = [src_with_whitespace]
58 | for i, tar_img in enumerate(tar):
59 | tar_imgs = tar_img.repeat(N,1,1,1)
60 | disp_img = tar_img.repeat(1,1,1,1)
61 | tar_i_lm = tar_lm[i,:,:,:]
62 | tar_i_lm = tar_i_lm.repeat(N,1,1,1)
63 | tar_i_mask = tar_mask[i,:,:,:]
64 | tar_i_mask = tar_i_mask.repeat(N,1,1,1)
65 | srcid_taratt, tarid_srcatt,_,_,_,_ = nets.generator(src, tar_imgs, src_lm,
66 | tar_i_lm, src_mask, tar_i_mask)
67 | fake_srcid_taratt = torch.cat([disp_img, srcid_taratt], dim=0)
68 | fake_tarid_srcatt = torch.cat([disp_img, tarid_srcatt], dim=0)
69 | disp_concat += [fake_srcid_taratt]
70 | disp_concat += [fake_tarid_srcatt]
71 | disp_concat = torch.cat(disp_concat, dim=0)
72 | save_image(disp_concat, N+1, filename)
73 |
74 | @torch.no_grad()
75 | def display_image(nets, config, inputs, step):
76 | src, tar, src_lm, tar_lm, src_mask, tar_mask = inputs.src, inputs.tar, \
77 | inputs.src_lm, inputs.tar_lm, inputs.src_mask, inputs.tar_mask
78 | # face reconstruction
79 | filename = ospj(config['sample_dir'], '%06d_reconstruction.jpg' % (step))
80 | disentangle_and_reconstruct(nets, config, src, tar, src_lm, tar_lm, src_mask, tar_mask, filename)
81 | # face swapping
82 | filename = ospj(config['sample_dir'], '%06d_faceswapping.jpg' % (step))
83 | disentangle_and_swapping(nets, config, src, tar, src_lm, tar_lm, src_mask, tar_mask, filename)
84 |
85 | @torch.no_grad()
86 | def disentangle_and_swapping_test(nets, config, inputs, save_dir):
87 | post_process = config['post_process']
88 | src, tar, src_lm, tar_lm, src_mask, tar_mask, tar_parsing, src_name, tar_name = inputs.src, inputs.tar, \
89 | inputs.src_lm, inputs.tar_lm, inputs.src_mask, inputs.tar_mask, inputs.tar_parsing, inputs.src_name, inputs.tar_name,
90 | src_mask = F.interpolate(src_mask, src.size(2), mode='bilinear', align_corners=True)
91 | tar_mask = F.interpolate(tar_mask, src.size(2), mode='bilinear', align_corners=True)
92 | srcid_taratt, tarid_srcatt,_,_,_,_ = nets.generator(src, tar, src_lm, tar_lm, src_mask, tar_mask) #modified by liqi
93 | result_first = save_dir + 'swapped_result_single/'
94 | result_second = save_dir + 'swapped_result_afterps/'
95 | result_third = save_dir + 'swapped_result_all/'
96 | if not os.path.exists(result_first):
97 | os.makedirs(result_first)
98 | if not os.path.exists(result_second):
99 | os.makedirs(result_second)
100 | if not os.path.exists(result_third):
101 | os.makedirs(result_third)
102 | if post_process:
103 | src_convex_hull = nets.fan.get_convex_hull(src)
104 | tar_convex_hull = nets.fan.get_convex_hull(tar)
105 | temp_src_forehead = src_convex_hull - src_mask
106 | temp_tar_forehead = tar_convex_hull - tar_mask
107 | # to ensure the values of src_forehead and tar_forehead are in [0,1]
108 | one_tensor = torch.ones(temp_src_forehead.size()).to(device=temp_src_forehead.device)
109 | zero_tensor = torch.zeros(temp_src_forehead.size()).to(device=temp_src_forehead.device)
110 | temp_var = torch.where(temp_src_forehead >= 1.0, one_tensor, temp_src_forehead)
111 | src_forehead = torch.where(temp_var <= 0.0, zero_tensor, temp_var)
112 | temp_var = torch.where(temp_tar_forehead >= 1.0, one_tensor, temp_tar_forehead)
113 | tar_forehead = torch.where(temp_var <= 0.0, zero_tensor, temp_var)
114 | tar_hair = get_hair(tar_parsing)
115 | post_result = postprocess(tar, srcid_taratt, tar_hair, src_forehead, tar_forehead)
116 | for i in range(len(srcid_taratt)):
117 | filename = result_first + src_name[i][0:-4] + '_FS_' + tar_name[i][0:-5] + '.png'
118 | filename_post = result_second + src_name[i][0:-4] + '_FS_' + tar_name[i][0:-5] + '.png'
119 | filename_all = result_third + src_name[i][0:-4] + '_FS_' + tar_name[i][0:-5] + '.png'
120 | save_image(srcid_taratt[i,:,:,:], 1, filename)
121 | if post_process:
122 | save_image(post_result[i, :, :, :], 1, filename_post)
123 | x_concat = torch.cat([src[i].unsqueeze(0), tar[i].unsqueeze(0),
124 | srcid_taratt[i, :, :, :].unsqueeze(0),post_result[i, :, :, :].unsqueeze(0)], dim=0)
125 | save_image(x_concat, 4, filename_all)
126 | else:
127 | x_concat = torch.cat([src[i].unsqueeze(0), tar[i].unsqueeze(0),
128 | srcid_taratt[i, :, :, :].unsqueeze(0)], dim=0)
129 | save_image(x_concat, 3, filename_all)
130 |
131 |
132 | def get_hair(segmentation):
133 | out = segmentation.mul_(255).int()
134 | mask_ind_hair = [17]
135 | with torch.no_grad():
136 | out_parse = out
137 | hair = torch.ones((out_parse.shape[0], 1, out_parse.shape[2], out_parse.shape[3])).cuda()
138 | for pi in mask_ind_hair:
139 | index = torch.where(out_parse == pi)
140 | hair[index[0], :, index[2],index[3]] = 0
141 | return hair
142 |
143 |
144 | def postprocess(tar, srcid_taratt, tar_hair, src_forehead, tar_forehead):
145 | #inner area of tar_hair is 0, inner area of tar_forehead is 1
146 | smooth_mask = SoftErosion(kernel_size=17, threshold=0.9, iterations=7).cuda()
147 | one_tensor = torch.ones(tar_forehead.size()).to(device=tar_forehead.device)
148 | temp_tar_hair_and_forehead = (1-tar_hair) + tar_forehead
149 | tar_hair_and_forehead = torch.where(temp_tar_hair_and_forehead >= 1.0, one_tensor, temp_tar_hair_and_forehead)
150 | tar_preserve = tar_hair
151 | # find whether occlusion exists in source image; if exists, then preserve the hair and forehead of the target image
152 | for i in range(src_forehead.size(0)):
153 | src_forehead_i = src_forehead[i,:,:,:]
154 | src_forehead_i = src_forehead_i.squeeze_()
155 | tar_forehead_i = tar_forehead[i,:,:,:]
156 | tar_forehead_i = tar_forehead_i.squeeze_()
157 | H1, W1 = torch.nonzero(src_forehead_i).size() # H1*W1 is proportional to the number of visible forehead pixels
158 | H2, W2 = torch.nonzero(tar_forehead_i).size()
159 | if (H1 * W1) / (H2 * W2 + 0.0001) < 0.4 and (H2 * W2) >= 1000: # source forehead largely occluded relative to the target
160 | tar_preserve[i,:,:,:] = 1 - tar_hair_and_forehead[i,:,:,:]
161 | soft_mask, _ = smooth_mask(tar_preserve)
162 | result = srcid_taratt * soft_mask + tar * (1-soft_mask)
163 | return result
164 |
165 |
166 | class SoftErosion(nn.Module):
167 | def __init__(self, kernel_size=15, threshold=0.6, iterations=1):
168 | super(SoftErosion, self).__init__()
169 | r = kernel_size // 2
170 | self.padding = r
171 | self.iterations = iterations
172 | self.threshold = threshold
173 | # Create kernel
174 | y_indices, x_indices = torch.meshgrid(torch.arange(0., kernel_size), torch.arange(0., kernel_size))
175 | dist = torch.sqrt((x_indices - r) ** 2 + (y_indices - r) ** 2)
176 | kernel = dist.max() - dist
177 | kernel /= kernel.sum()
178 | kernel = kernel.view(1, 1, *kernel.shape)
179 | self.register_buffer('weight', kernel)
180 |
181 | def forward(self, x):
182 | x = x.float()
183 | for i in range(self.iterations - 1):
184 | x = torch.min(x, F.conv2d(x, weight=self.weight, groups=x.shape[1], padding=self.padding))
185 | x = F.conv2d(x, weight=self.weight, groups=x.shape[1], padding=self.padding)
186 | mask = x >= self.threshold
187 | x[mask] = 1.0
188 | x[~mask] /= x[~mask].max()
189 | return x, mask
--------------------------------------------------------------------------------
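The post-processing path above relies on `SoftErosion` to feather the binary preserve mask before `postprocess()` blends the swapped face back into the target. Below is a minimal CPU sketch of that blending step; the toy mask and random image tensors are assumptions, and the repository itself instantiates the module with `.cuda()` on real segmentation masks.

```python
# Illustrative sketch only: feathering a hard mask with SoftErosion
# (core/utils.py) and blending two images with it, as postprocess() does.
import torch
from core.utils import SoftErosion

mask = torch.zeros(1, 1, 256, 256)
mask[:, :, 64:192, 64:192] = 1.0                  # hard square region to preserve
smoother = SoftErosion(kernel_size=17, threshold=0.9, iterations=7)
soft_mask, _ = smoother(mask)                     # values in [0, 1], soft near the border

swapped = torch.rand(1, 3, 256, 256)              # stand-ins for srcid_taratt and tar
target = torch.rand(1, 3, 256, 256)
blended = swapped * soft_mask + target * (1 - soft_mask)
```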
/core/wing.py:
--------------------------------------------------------------------------------
1 | """
2 | This work is licensed under the Creative Commons Attribution-NonCommercial
3 | 4.0 International License. To view a copy of this license, visit
4 | http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to
5 | Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.
6 |
7 | """
8 |
9 |
10 | from functools import partial
11 | import numpy as np
12 | import cv2
13 | import torch
14 | import torch.nn as nn
15 | import torch.nn.functional as F
16 | from matplotlib import pyplot as plt
17 |
18 | def normalize(x, eps=1e-6):
19 | """Apply min-max normalization."""
20 | x = x.contiguous()
21 | N, C, H, W = x.size()
22 | x_ = x.view(N*C, -1)
23 | max_val = torch.max(x_, dim=1, keepdim=True)[0]
24 | min_val = torch.min(x_, dim=1, keepdim=True)[0]
25 | x_ = (x_ - min_val) / (max_val - min_val + eps)
26 | out = x_.view(N, C, H, W)
27 | return out
28 |
29 | def truncate(x, thres=0.1):
30 | """Remove small values in heatmaps."""
31 | return torch.where(x < thres, torch.zeros_like(x), x)
32 |
33 | def resize(x, p=2):
34 | """Resize heatmaps."""
35 | return x**p
36 |
37 | def shift(x, N):
38 | """Shift N pixels up or down."""
39 | up = N >= 0
40 | N = abs(N)
41 | _, _, H, W = x.size()
42 | head = torch.arange(N)
43 | tail = torch.arange(H-N)
44 |
45 | if up:
46 | head = torch.arange(H-N)+N
47 | tail = torch.arange(N)
48 | else:
49 | head = torch.arange(N) + (H-N)
50 | tail = torch.arange(H-N)
51 |
52 | # permutation indices
53 | perm = torch.cat([head, tail]).to(x.device)
54 | out = x[:, :, perm, :]
55 | return out
56 |
57 | def get_preds_fromhm(hm):
58 | max, idx = torch.max(
59 | hm.view(hm.size(0), hm.size(1), hm.size(2) * hm.size(3)), 2)
60 | idx += 1
61 | preds = idx.view(idx.size(0), idx.size(1), 1).repeat(1, 1, 2).float()
62 | preds[..., 0].apply_(lambda x: (x - 1) % hm.size(3) + 1)
63 | preds[..., 1].add_(-1).div_(hm.size(2)).floor_().add_(1)
64 |
65 | for i in range(preds.size(0)):
66 | for j in range(preds.size(1)):
67 | hm_ = hm[i, j, :]
68 | pX, pY = int(preds[i, j, 0]) - 1, int(preds[i, j, 1]) - 1
69 | if pX > 0 and pX < 63 and pY > 0 and pY < 63:
70 | diff = torch.FloatTensor(
71 | [hm_[pY, pX + 1] - hm_[pY, pX - 1],
72 | hm_[pY + 1, pX] - hm_[pY - 1, pX]])
73 | preds[i, j].add_(diff.sign_().mul_(.25))
74 |
75 | preds.add_(-0.5)
76 | return preds
77 |
78 | def curve_fill(points, heatmapSize=256, sigma=3, erode=False):
79 | sigma = max(1, (sigma // 2) * 2 + 1) # GaussianBlur requires an odd kernel size
80 | points = points.astype(np.int32)
81 | canvas = np.zeros([heatmapSize, heatmapSize])
82 | cv2.fillPoly(canvas,np.array([points]),255)
83 | canvas = cv2.GaussianBlur(canvas, (sigma, sigma), sigma)
84 | return canvas.astype(np.float64)/255.0
85 |
86 | class HourGlass(nn.Module):
87 | def __init__(self, num_modules, depth, num_features, first_one=False):
88 | super(HourGlass, self).__init__()
89 | self.num_modules = num_modules
90 | self.depth = depth
91 | self.features = num_features
92 | self.coordconv = CoordConvTh(64, 64, True, True, 256, first_one,
93 | out_channels=256,
94 | kernel_size=1, stride=1, padding=0)
95 | self._generate_network(self.depth)
96 |
97 | def _generate_network(self, level):
98 | self.add_module('b1_' + str(level), ConvBlock(256, 256))
99 | self.add_module('b2_' + str(level), ConvBlock(256, 256))
100 | if level > 1:
101 | self._generate_network(level - 1)
102 | else:
103 | self.add_module('b2_plus_' + str(level), ConvBlock(256, 256))
104 | self.add_module('b3_' + str(level), ConvBlock(256, 256))
105 |
106 | def _forward(self, level, inp):
107 | up1 = inp
108 | up1 = self._modules['b1_' + str(level)](up1)
109 | low1 = F.avg_pool2d(inp, 2, stride=2)
110 | low1 = self._modules['b2_' + str(level)](low1)
111 |
112 | if level > 1:
113 | low2 = self._forward(level - 1, low1)
114 | else:
115 | low2 = low1
116 | low2 = self._modules['b2_plus_' + str(level)](low2)
117 | low3 = low2
118 | low3 = self._modules['b3_' + str(level)](low3)
119 | up2 = F.interpolate(low3, scale_factor=2, mode='nearest')
120 |
121 | return up1 + up2
122 |
123 | def forward(self, x, heatmap):
124 | x, last_channel = self.coordconv(x, heatmap)
125 | return self._forward(self.depth, x), last_channel
126 |
127 |
128 | class AddCoordsTh(nn.Module):
129 | def __init__(self, height=64, width=64, with_r=False, with_boundary=False):
130 | super(AddCoordsTh, self).__init__()
131 | self.with_r = with_r
132 | self.with_boundary = with_boundary
133 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
134 |
135 | with torch.no_grad():
136 | x_coords = torch.arange(height).unsqueeze(1).expand(height, width).float()
137 | y_coords = torch.arange(width).unsqueeze(0).expand(height, width).float()
138 | x_coords = (x_coords / (height - 1)) * 2 - 1
139 | y_coords = (y_coords / (width - 1)) * 2 - 1
140 | coords = torch.stack([x_coords, y_coords], dim=0)
141 |
142 | if self.with_r:
143 | rr = torch.sqrt(torch.pow(x_coords, 2) + torch.pow(y_coords, 2))
144 | rr = (rr / torch.max(rr)).unsqueeze(0)
145 | coords = torch.cat([coords, rr], dim=0)
146 |
147 | self.coords = coords.unsqueeze(0).to(device)
148 | self.x_coords = x_coords.to(device)
149 | self.y_coords = y_coords.to(device)
150 |
151 | def forward(self, x, heatmap=None):
152 | """
153 | x: (batch, c, x_dim, y_dim)
154 | """
155 | coords = self.coords.repeat(x.size(0), 1, 1, 1)
156 |
157 | if self.with_boundary and heatmap is not None:
158 | boundary_channel = torch.clamp(heatmap[:, -1:, :, :], 0.0, 1.0)
159 | zero_tensor = torch.zeros_like(self.x_coords)
160 | xx_boundary_channel = torch.where(boundary_channel > 0.05, self.x_coords, zero_tensor).to(zero_tensor.device)
161 | yy_boundary_channel = torch.where(boundary_channel > 0.05, self.y_coords, zero_tensor).to(zero_tensor.device)
162 | coords = torch.cat([coords, xx_boundary_channel, yy_boundary_channel], dim=1)
163 |
164 | x_and_coords = torch.cat([x, coords], dim=1)
165 | return x_and_coords
166 |
167 |
168 | class CoordConvTh(nn.Module):
169 | """CoordConv layer as in the paper."""
170 | def __init__(self, height, width, with_r, with_boundary,
171 | in_channels, first_one=False, *args, **kwargs):
172 | super(CoordConvTh, self).__init__()
173 | self.addcoords = AddCoordsTh(height, width, with_r, with_boundary)
174 | in_channels += 2
175 | if with_r:
176 | in_channels += 1
177 | if with_boundary and not first_one:
178 | in_channels += 2
179 | self.conv = nn.Conv2d(in_channels=in_channels, *args, **kwargs)
180 |
181 | def forward(self, input_tensor, heatmap=None):
182 | ret = self.addcoords(input_tensor, heatmap)
183 | last_channel = ret[:, -2:, :, :]
184 | ret = self.conv(ret)
185 | return ret, last_channel
186 |
187 |
188 | class ConvBlock(nn.Module):
189 | def __init__(self, in_planes, out_planes):
190 | super(ConvBlock, self).__init__()
191 | self.bn1 = nn.BatchNorm2d(in_planes)
192 | conv3x3 = partial(nn.Conv2d, kernel_size=3, stride=1, padding=1, bias=False, dilation=1)
193 | self.conv1 = conv3x3(in_planes, int(out_planes / 2))
194 | self.bn2 = nn.BatchNorm2d(int(out_planes / 2))
195 | self.conv2 = conv3x3(int(out_planes / 2), int(out_planes / 4))
196 | self.bn3 = nn.BatchNorm2d(int(out_planes / 4))
197 | self.conv3 = conv3x3(int(out_planes / 4), int(out_planes / 4))
198 |
199 | self.downsample = None
200 | if in_planes != out_planes:
201 | self.downsample = nn.Sequential(nn.BatchNorm2d(in_planes),
202 | nn.ReLU(True),
203 | nn.Conv2d(in_planes, out_planes, 1, 1, bias=False))
204 |
205 | def forward(self, x):
206 | residual = x
207 |
208 | out1 = self.bn1(x)
209 | out1 = F.relu(out1, True)
210 | out1 = self.conv1(out1)
211 |
212 | out2 = self.bn2(out1)
213 | out2 = F.relu(out2, True)
214 | out2 = self.conv2(out2)
215 |
216 | out3 = self.bn3(out2)
217 | out3 = F.relu(out3, True)
218 | out3 = self.conv3(out3)
219 |
220 | out3 = torch.cat((out1, out2, out3), 1)
221 | if self.downsample is not None:
222 | residual = self.downsample(residual)
223 | out3 += residual
224 | return out3
225 |
226 |
227 | class FAN(nn.Module):
228 | def __init__(self, num_modules=1, end_relu=False, num_landmarks=98, fname_pretrained=None):
229 | super(FAN, self).__init__()
230 | self.num_modules = num_modules
231 | self.end_relu = end_relu
232 |
233 | # Base part
234 | self.conv1 = CoordConvTh(256, 256, True, False,
235 | in_channels=3, out_channels=64,
236 | kernel_size=7, stride=2, padding=3)
237 | self.bn1 = nn.BatchNorm2d(64)
238 | self.conv2 = ConvBlock(64, 128)
239 | self.conv3 = ConvBlock(128, 128)
240 | self.conv4 = ConvBlock(128, 256)
241 |
242 | # Stacking part
243 | self.add_module('m0', HourGlass(1, 4, 256, first_one=True))
244 | self.add_module('top_m_0', ConvBlock(256, 256))
245 | self.add_module('conv_last0', nn.Conv2d(256, 256, 1, 1, 0))
246 | self.add_module('bn_end0', nn.BatchNorm2d(256))
247 | self.add_module('l0', nn.Conv2d(256, num_landmarks+1, 1, 1, 0))
248 |
249 | if fname_pretrained is not None:
250 | self.load_pretrained_weights(fname_pretrained)
251 |
252 | def load_pretrained_weights(self, fname):
253 | if torch.cuda.is_available():
254 | checkpoint = torch.load(fname)
255 | else:
256 | checkpoint = torch.load(fname, map_location=torch.device('cpu'))
257 | model_weights = self.state_dict()
258 | model_weights.update({k: v for k, v in checkpoint['state_dict'].items()
259 | if k in model_weights})
260 | self.load_state_dict(model_weights)
261 |
262 | def forward(self, x):
263 | x, _ = self.conv1(x)
264 | x = F.relu(self.bn1(x), True)
265 | x = F.avg_pool2d(self.conv2(x), 2, stride=2)
266 | x = self.conv3(x)
267 | x = self.conv4(x)
268 |
269 | outputs = []
270 | boundary_channels = []
271 | tmp_out = None
272 | ll, boundary_channel = self._modules['m0'](x, tmp_out)
273 | ll = self._modules['top_m_0'](ll)
274 | ll = F.relu(self._modules['bn_end0']
275 | (self._modules['conv_last0'](ll)), True)
276 |
277 | # Predict heatmaps
278 | tmp_out = self._modules['l0'](ll)
279 | if self.end_relu:
280 | tmp_out = F.relu(tmp_out) # HACK: Added relu
281 | outputs.append(tmp_out)
282 | boundary_channels.append(boundary_channel)
283 | return outputs, boundary_channels
284 |
285 | @torch.no_grad()
286 | def get_heatmap(self, x):
287 | ''' outputs 0-1 normalized heatmap '''
288 | x = F.interpolate(x, size=256, mode='bilinear',align_corners=True)
289 | x_01 = x*0.5 + 0.5
290 | outputs, _ = self(x_01)
291 | heatmaps = outputs[-1][:, :-1, :, :]
292 | scale_factor = x.size(2) // heatmaps.size(2)
293 | heatmaps = F.interpolate(heatmaps, scale_factor=scale_factor,
294 | mode='bilinear', align_corners=True)
295 | return heatmaps
296 |
297 |
298 | @torch.no_grad()
299 | def get_points2heatmap(self, x):
300 | ''' outputs 7-channel facial-region heatmaps at the spatial resolution of x '''
301 | heatmaps = self.get_heatmap(x)
302 | landmarks = []
303 | for i in range(x.size(0)):
304 | pred_landmarks = get_preds_fromhm(heatmaps[i].cpu().unsqueeze(0))
305 | landmarks.append(pred_landmarks)
306 | scale_factor = x.size(2) // heatmaps.size(2)
307 | landmarks = torch.cat(landmarks) * scale_factor
308 | heatmap_all = torch.zeros((len(x), 7, x.size(2), x.size(2))) # channels: face, brow, eye, nose, upper lip, lower lip, mouth interior
309 | for i in range(0,len(x)):
310 | curves, boundary = self.points2curves(landmarks[i])
311 | heatmap = self.curves2segments(curves)
312 | heatmap = torch.from_numpy(heatmap).float()
313 | heatmap_all[i] = heatmap
314 | heatmap_all = heatmap_all.to(device=x.device)
315 | return heatmap_all, curves, boundary # curves/boundary correspond to the last sample in the batch
316 |
317 | @torch.no_grad()
318 | def get_convex_hull(self, x):
319 | ''' outputs convex-hull (skin-region) masks of x.shape, with 0 inside the face and 1 outside '''
320 | heatmaps = self.get_heatmap(x)
321 | skins = []
322 | for i in range(x.size(0)):
323 | pred_landmark = get_preds_fromhm(heatmaps[i].cpu().unsqueeze(0))
324 | scale_factor = x.size(2) // heatmaps.size(2)
325 | pred_landmark = torch.squeeze(pred_landmark) * scale_factor
326 | curves_ref, _ = self.points2curves(pred_landmark)
327 | roi_ref = self.curves2segments(curves_ref)
328 | skin = roi_ref[0]
329 | skin = (skin > 0).astype(int)
330 | skin = 1 - skin # 0 inside the face hull, 1 outside
331 | skin = torch.from_numpy(skin).type(dtype=torch.float32)
332 | skin = skin.unsqueeze(0)
333 | skin = skin.to(x.device)
334 | skins.append(skin)
335 | skins = torch.cat(skins)
336 | skins = skins.unsqueeze(1)
337 | return skins
338 |
339 | @torch.no_grad()
340 | def points2curves(self, points, heatmapSize=256, sigma=1,heatmap_num=15):
341 | curves = [0] * heatmap_num
342 | curves[0] = np.zeros((33, 2)) # contour
343 | curves[1] = np.zeros((5, 2)) # left top eyebrow
344 | curves[2] = np.zeros((5, 2)) # right top eyebrow
345 | curves[3] = np.zeros((4, 2)) # nose bridge
346 | curves[4] = np.zeros((5, 2)) # nose tip
347 | curves[5] = np.zeros((5, 2)) # left top eye
348 | curves[6] = np.zeros((5, 2)) # left bottom eye
349 | curves[7] = np.zeros((5, 2)) # right top eye
350 | curves[8] = np.zeros((5, 2)) # right bottom eye
351 | curves[9] = np.zeros((7, 2)) # up up lip
352 | curves[10] = np.zeros((5, 2)) # up bottom lip
353 | curves[11] = np.zeros((5, 2)) # bottom up lip
354 | curves[12] = np.zeros((7, 2)) # bottom bottom lip
355 | curves[13] = np.zeros((5, 2)) # left bottom eyebrow
356 | curves[14] = np.zeros((5, 2)) # right bottom eyebrow
357 | # assignment process
358 | # contour
359 | for i in range(33):
360 | curves[0][i] = points[i,:]
361 | for i in range(5):
362 | # left and right top eyebrows (shifted outward and upward)
363 | curves[1][i][0] = points[i + 33, 0] - 10
364 | curves[1][i][1] = points[i + 33, 1]-40
365 | curves[2][i][0] = points[i + 42, 0] + 10
366 | curves[2][i][1] = points[i + 42, 1]-40
367 | # nose bridge
368 | for i in range(4):
369 | curves[3][i] = points[i + 51,:]
370 | # nose tip
371 | for i in range(5):
372 | curves[4][i] = points[i + 55,:]
373 | # left top eye
374 | for i in range(5):
375 | curves[5][i] = points[i + 60,:]
376 | # left bottom eye
377 | curves[6][0] = points[64,:]
378 | curves[6][1] = points[65,:]
379 | curves[6][2] = points[66,:]
380 | curves[6][3] = points[67,:]
381 | curves[6][4] = points[60,:]
382 | # right top eye
383 | for i in range(5):
384 | curves[7][i] = points[i + 68,:]
385 | # right bottom eye
386 | curves[8][0] = points[72,:]
387 | curves[8][1] = points[73,:]
388 | curves[8][2] = points[74,:]
389 | curves[8][3] = points[75,:]
390 | curves[8][4] = points[68,:]
391 | # up up lip
392 | for i in range(7):
393 | curves[9][i] = points[i + 76,:]
394 | # up bottom lip
395 | for i in range(5):
396 | curves[10][i] = points[i + 88,:]
397 | # bottom up lip
398 | curves[11][0] = points[92,:]
399 | curves[11][1] = points[93,:]
400 | curves[11][2] = points[94,:]
401 | curves[11][3] = points[95,:]
402 | curves[11][4] = points[88,:]
403 | # bottom bottom lip
404 | curves[12][0] = points[82,:]
405 | curves[12][1] = points[83,:]
406 | curves[12][2] = points[84,:]
407 | curves[12][3] = points[85,:]
408 | curves[12][4] = points[86,:]
409 | curves[12][5] = points[87,:]
410 | curves[12][6] = points[76,:]
411 | # left bottom eyebrow
412 | curves[13][0] = points[38,:]
413 | curves[13][1] = points[39,:]
414 | curves[13][2] = points[40,:]
415 | curves[13][3] = points[41,:]
416 | curves[13][4] = points[33,:]
417 | # right bottom eyebrow
418 | curves[14][0] = points[46,:]
419 | curves[14][1] = points[47,:]
420 | curves[14][2] = points[48,:]
421 | curves[14][3] = points[49,:]
422 | curves[14][4] = points[50,:]
423 | return curves, None
424 |
425 | @torch.no_grad()
426 | def curves2segments(self, curves, heatmapSize=256, sigma=3):
427 | face = curve_fill(np.vstack([curves[0], curves[2][::-1], curves[1][::-1]]), heatmapSize, sigma)
428 | browL = curve_fill(np.vstack([curves[1], curves[13][::-1]]), heatmapSize, sigma)
429 | browR = curve_fill(np.vstack([curves[2], curves[14][::-1]]), heatmapSize, sigma)
430 | eyeL = curve_fill(np.vstack([curves[5], curves[6]]), heatmapSize, sigma)
431 | eyeR = curve_fill(np.vstack([curves[7], curves[8]]), heatmapSize, sigma)
432 | eye = np.max([eyeL, eyeR], axis=0)
433 | brow = np.max([browL, browR], axis=0)
434 | nose = curve_fill(np.vstack([curves[3][0:1], curves[4]]), heatmapSize, sigma)
435 | lipU = curve_fill(np.vstack([curves[9], curves[10][::-1]]), heatmapSize, sigma)
436 | lipD = curve_fill(np.vstack([curves[11], curves[12][::-1]]), heatmapSize, sigma)
437 | tooth = curve_fill(np.vstack([curves[10], curves[11][::-1]]), heatmapSize, sigma)
438 | return np.stack([face, brow, eye, nose, lipU, lipD, tooth])
439 |
440 | @torch.no_grad()
441 | def get_landmark_curve(self, x):
442 | heatmaps = self.get_heatmap(x)
443 | landmarks = []
444 | for i in range(x.size(0)):
445 | pred_landmarks = get_preds_fromhm(heatmaps[i].cpu().unsqueeze(0))
446 | landmarks.append(pred_landmarks)
447 | scale_factor = x.size(2) // heatmaps.size(2)
448 | landmarks = torch.cat(landmarks) * scale_factor
449 | batch_landmark_figure =[]
450 | for i in range(0,len(x)):
451 | dpi = 100
452 | input = x[i]
453 | preds = landmarks[i]
454 | fig = plt.figure(figsize=(input.shape[2] / dpi, input.shape[1] / dpi), dpi=dpi)
455 | ax = fig.add_subplot(1, 1, 1)
456 | ax.imshow(np.ones((input.shape[1],input.shape[2],input.shape[0])))
457 | plt.subplots_adjust(left=0, right=1, top=1, bottom=0)
458 | # eye
459 | ax.plot(preds[60:68,0],preds[60:68,1],marker='',markersize=5,linestyle='-',color='red',lw=2)
460 | ax.plot(preds[68:76,0],preds[68:76,1],marker='',markersize=5,linestyle='-',color='red',lw=2)
461 | #outer and inner lip
462 | ax.plot(preds[76:88,0],preds[76:88,1],marker='',markersize=5,linestyle='-',color='green',lw=2)
463 | ax.plot(preds[88:96,0],preds[88:96,1],marker='',markersize=5,linestyle='-',color='blue',lw=2)
464 | ax.axis('off')
465 | fig.canvas.draw()
466 | data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
467 | data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
468 | batch_landmark_figure.append(data)
469 | plt.close(fig)
470 | return batch_landmark_figure
471 |
472 |
473 |
--------------------------------------------------------------------------------
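A small sketch of the heatmap-decoding step used throughout `wing.py`: `get_preds_fromhm` converts a 64x64 landmark heatmap into sub-pixel coordinates. The synthetic single-peak heatmap below is an assumption; real heatmaps come from `FAN.get_heatmap`, and the full pipeline (`get_heatmap` -> `points2curves` -> `curves2segments`) additionally needs the pretrained `wing.ckpt` checkpoint.

```python
# Illustrative sketch only: decoding landmark coordinates from a synthetic
# heatmap with get_preds_fromhm (core/wing.py).
import torch
from core.wing import get_preds_fromhm

hm = torch.zeros(1, 98, 64, 64)        # batch of 1, 98 landmark channels
hm[0, :, 32, 20] = 1.0                 # put every peak at column x=20, row y=32
preds = get_preds_fromhm(hm)           # (1, 98, 2) coordinates on the 64x64 grid
print(preds[0, 0])                     # approx. tensor([20.5, 32.5]): x then y on the heatmap grid
```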
/images/1:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/images/qualitative_comparisons.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/liqi-casia/FaceSwapper/fd999e0396c8b589c5a735bacd7e208918edc0c0/images/qualitative_comparisons.jpg
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | """
2 | This work is licensed under the Creative Commons Attribution-NonCommercial
3 | 4.0 International License. To view a copy of this license, visit
4 | http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to
5 | Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.
6 | """
7 |
8 | import os
9 | import argparse
10 | import torch
11 | from torch.backends import cudnn
12 | from munch import Munch
13 | from core.data_loader import get_train_loader, get_test_loader
14 | from core.solver import Solver
15 | from core.utils import get_config
16 |
17 |
18 | def main(args):
19 | config = get_config(args.config)
20 | os.environ['CUDA_VISIBLE_DEVICES'] = config['cuda_device']
21 | cudnn.benchmark = True
22 | solver = Solver(config)
23 | if config['mode'] == 'train':
24 | loaders = Munch(src=get_train_loader(root=config['train_img_dir'],
25 | img_size=config['img_size'],
26 | batch_size=config['batch_size'],
27 | num_workers=config['num_workers']))
28 | solver.train(loaders)
29 | elif config['mode'] == 'test':
30 | loaders = Munch(src=get_test_loader(root=config['test_img_dir'],
31 | test_img_list=config['test_img_list'],
32 | img_size=config['img_size'],
33 | batch_size=config['batch_size'],
34 | num_workers=config['num_workers']))
35 | solver.test(loaders)
36 | else:
37 | raise NotImplementedError
38 |
39 |
40 | if __name__ == '__main__':
41 | parser = argparse.ArgumentParser()
42 | parser.add_argument('--config', type=str, default='param.yaml', help='Path to the config file.')
43 | args = parser.parse_args()
44 | main(args)
45 |
--------------------------------------------------------------------------------
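Training and testing are both driven by `python main.py --config param.yaml`. The sketch below is the programmatic equivalent of the test branch in `main()`, assuming the datasets and pretrained checkpoints are already in place and that the config is switched to test mode in memory rather than by editing `param.yaml`.

```python
# Minimal sketch mirroring the 'test' branch of main.py.
from munch import Munch
from core.data_loader import get_test_loader
from core.solver import Solver
from core.utils import get_config

config = get_config('param.yaml')
config['mode'] = 'test'                          # param.yaml ships with mode: 'train'
solver = Solver(config)
loaders = Munch(src=get_test_loader(root=config['test_img_dir'],
                                    test_img_list=config['test_img_list'],
                                    img_size=config['img_size'],
                                    batch_size=config['batch_size'],
                                    num_workers=config['num_workers']))
solver.test(loaders)
```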
/param.yaml:
--------------------------------------------------------------------------------
1 |
2 | # Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
3 |
4 | dataset: celeba #dataset we use
5 | img_size: 256 #image resolution, default=256
6 | id_dim: 512 #identity code dimension, default=512
7 | cuda_device: '1' #cuda device we use
8 |
9 | # weight for objective functions
10 | lambda_reg: 1 #type=float, weight for R1 regularization
11 | lambda_recon: 1 #type=float, weight for reconstruction loss
12 | lambda_att_face: 2 #type=float, weight for attribute preservation loss (face)
13 | lambda_att_bg: 1 #type=float, weight for attribute preservation loss (background)
14 | lambda_per: 1 #type=float, weight for perceptual loss
15 | lambda_id: 10 #type=float, weight for identity reconstruction loss
16 |
17 | # arguments
18 | total_iters: 60000 #type=int, number of total iterations
19 | batch_size: 8 #type=int, batch size for training
20 | lr: 0.0001 #type=float, learning rate for attribute encoder and discriminator
21 | id_lr: 0.0001 #type=float, learning rate for identity encoder
22 | beta1: 0.0 #type=float, decay rate for 1st moment of Adam
23 | beta2: 0.99 #type=float, decay rate for 2nd moment of Adam
24 | weight_decay: 0.0001 #type=float, weight decay for optimizer
25 |
26 |
27 | # misc
28 | mode: 'train' #type=str, choices=['train', 'test']
29 | num_workers: 8 #type=int, number of workers used in DataLoader
30 |
31 | # directory for training
32 | train_img_dir: ['data/CelebA_Dataset/CelebA'] #type=list of str, directories containing training images
33 | sample_dir: 'expr/samples/CelebA/' #type=str, directory for saving generated images
34 | checkpoint_dir: 'expr/checkpoints/CelebA/' #type=str, directory for saving network checkpoints
35 | log_dir: 'expr/logs/CelebA/' #type=str, directory for saving logs
36 | resume_iter: 0 #type=int, iteration from which to resume training
37 |
38 | # directory for testing
39 | test_img_dir: 'data/FF++_Dataset/ff++' #type=str, directory containing testing images
40 | test_img_list: 'data/FF++_Dataset/face_swap_list.txt' #type=str, text file listing the source/target image pairs to swap
41 | test_checkpoint_dir: 'pretrained_checkpoints/' #type=str, directory for pretrained face swapping model
42 | test_checkpoint_name: 'faceswapper.ckpt' #type=str, pretrained face swapping model name
43 | result_dir: 'expr/results/ff++/' #type=str, directory for saving generated images
44 | post_process: True #type=bool, whether to apply the post-processing step
45 |
46 | # pretrained face alignment model
47 | wing_path: 'pretrained_checkpoints/wing.ckpt'
48 | face_model_path: 'pretrained_checkpoints/model_ir_se50.pth'
49 |
50 | # step size
51 | print_every: 50 #type=int, number of iterations between printing log info
52 | sample_every: 2000 #type=int, number of iterations between displaying intermediate results
53 | save_every: 5000 #type=int, number of iterations between saving checkpoints
54 |
55 |
56 |
--------------------------------------------------------------------------------