├── .gitignore ├── DCGAN.py ├── README.md ├── backbone ├── mobilenet.py └── resnet.py ├── checkpoint ├── imitator_gitopen.pth └── myimitator-512.json ├── config.py ├── dat ├── 103.jpg ├── 110.jpg ├── 16.jpg └── avg_face.jpg ├── dataset.py ├── demo.py ├── evaluate_adam.py ├── evaluate_sgd_L1Loss.py ├── evaluate_sgd_cross_entropy.py ├── face_align.py ├── face_parser.py ├── faceparse.py ├── faceswap.py ├── imitator.py ├── lightcnn.py ├── model_process.py ├── myimitator.py ├── papers ├── 1909.01064v1(Face-to-Parameter Translation ).pdf ├── 2003.05653(Towards High-Fidelity 3D Face Reconstruction).pdf └── 2008.07132v1(Fast and Robust Face-to-Parameter Translation).pdf ├── random_gen_image.py ├── resnet.py ├── tools ├── demo.py ├── face_align.py ├── fid.py └── lsgan.py ├── train_myimitator.py ├── train_translator.py ├── translator.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | #From https://github.com/github/gitignore/blob/master/Python.gitignore 2 | 3 | # Byte-compiled / optimized / DLL files 4 | */__pycache__/ 5 | *.py[cod] 6 | *$py.class 7 | 8 | # C extensions 9 | *.so 10 | 11 | # Distribution / packaging 12 | .Python 13 | develop-eggs/ 14 | dist/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | 53 | # Translations 54 | *.mo 55 | *.pot 56 | 57 | # Django stuff: 58 | *.log 59 | local_settings.py 60 | db.sqlite3 61 | 62 | # Flask stuff: 63 | instance/ 64 | .webassets-cache 65 | 66 | # Scrapy stuff: 67 | .scrapy 68 | 69 | # Sphinx documentation 70 | docs/_build/ 71 | 72 | # PyBuilder 73 | target/ 74 | 75 | # Jupyter Notebook 76 | .ipynb_checkpoints 77 | 78 | # IPython 79 | profile_default/ 80 | ipython_config.py 81 | 82 | # pyenv 83 | .python-version 84 | 85 | # pipenv 86 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 87 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 88 | # having no cross-platform support, pipenv may install dependencies that don’t work, or not 89 | # install all needed dependencies. 
90 | #Pipfile.lock 91 | 92 | # celery beat schedule file 93 | celerybeat-schedule 94 | 95 | # SageMath parsed files 96 | *.sage.py 97 | 98 | # Environments 99 | .env 100 | .venv 101 | env/ 102 | venv/ 103 | ENV/ 104 | env.bak/ 105 | venv.bak/ 106 | 107 | # Spyder project settings 108 | .spyderproject 109 | .spyproject 110 | .project 111 | 112 | # Rope project settings 113 | .ropeproject 114 | 115 | # mkdocs documentation 116 | /site 117 | 118 | # mypy 119 | .mypy_cache/ 120 | .dmypy.json 121 | dmypy.json 122 | 123 | # Pyre type checker 124 | .pyre/ 125 | 126 | .py~ 127 | .gitignore~ 128 | 129 | .vscode/ 130 | .idea/ 131 | *.ipynb 132 | 133 | 134 | logs/ 135 | output/ 136 | face_data/ 137 | checkpoint/shape_predictor_68_face_landmarks.dat 138 | checkpoint/79999_iter.pth 139 | checkpoint/resnet18-5c106cde.pth 140 | checkpoint/LightCNN_29Layers_V2_checkpoint.pth.tar 141 | checkpoint/rtnet50-fcn-14.torch 142 | checkpoint/epoch_130_0.666491.pt 143 | checkpoint/epoch_340_0.434396.pt 144 | checkpoint/faceseg_65_0.030560_0.025751_0.827335-7.pth 145 | checkpoint/faceseg_179_0.050777_0.065476_0.842724_withface7.pth -------------------------------------------------------------------------------- /DCGAN.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import argparse 3 | import os 4 | import random 5 | from PIL import Image 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.parallel 9 | from torch.utils.data import DataLoader, Dataset 10 | import torch.backends.cudnn as cudnn 11 | import torch.optim as optim 12 | import torch.utils.data 13 | import torchvision 14 | import torchvision.transforms as transforms 15 | import torchvision.utils as vutils 16 | 17 | ''' 18 | DCGAN pytorch官方实现 19 | https://github.com/pytorch/examples/blob/master/dcgan/main.py 20 | ''' 21 | 22 | parser = argparse.ArgumentParser() 23 | parser.add_argument('--dataroot', default='F:/BaiduNetdiskDownload/faces/', required=False, help='path to dataset') 24 | parser.add_argument('--workers', type=int, help='number of data loading workers', default=0) 25 | parser.add_argument('--batchSize', type=int, default=16, help='input batch size') 26 | parser.add_argument('--imageSize', type=int, default=64, help='the height / width of the input image to network') 27 | parser.add_argument('--nz', type=int, default=100, help='size of the latent z vector') 28 | parser.add_argument('--ngf', type=int, default=64) 29 | parser.add_argument('--ndf', type=int, default=64) 30 | parser.add_argument('--epochs', type=int, default=300, help='number of epochs to train for') 31 | parser.add_argument('--lr', type=float, default=0.0002, help='learning rate, default=0.0002') 32 | parser.add_argument('--beta1', type=float, default=0.5, help='beta1 for adam. 
default=0.5') 33 | parser.add_argument('--cuda', action='store_true', help='enables cuda') 34 | parser.add_argument('--dry-run', action='store_true', help='check a single training cycle works') 35 | parser.add_argument('--ngpu', type=int, default=1, help='number of GPUs to use') 36 | parser.add_argument('--netG', default='', help="path to netG (to continue training)") 37 | parser.add_argument('--netD', default='', help="path to netD (to continue training)") 38 | parser.add_argument('--outf', default='./output/', help='folder to output images and model checkpoints') 39 | parser.add_argument('--manualSeed', type=int, help='manual seed') 40 | 41 | opt = parser.parse_args() 42 | print(opt) 43 | 44 | try: 45 | os.makedirs(opt.outf) 46 | except OSError: 47 | pass 48 | 49 | if opt.manualSeed is None: 50 | opt.manualSeed = random.randint(1, 10000) 51 | print("Random Seed: ", opt.manualSeed) 52 | random.seed(opt.manualSeed) 53 | torch.manual_seed(opt.manualSeed) 54 | 55 | cudnn.benchmark = True 56 | 57 | if torch.cuda.is_available() and not opt.cuda: 58 | print("WARNING: You have a CUDA device, so you should probably run with --cuda") 59 | 60 | if opt.dataroot is None: 61 | raise ValueError("`dataroot` parameter is required ") 62 | 63 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 64 | print(torch.cuda.is_available()) 65 | 66 | # 图像增强 67 | transform = transforms.Compose([ 68 | transforms.Resize(opt.imageSize), 69 | transforms.CenterCrop(opt.imageSize), 70 | transforms.ToTensor(), 71 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), 72 | ]) 73 | 74 | # 自定义Dataset 75 | class DCGAN_Dataset(Dataset): 76 | def __init__(self, data_root, transform): 77 | self.data_root = data_root 78 | self.transform = transform 79 | self.fileList = [] 80 | for file in os.listdir(self.data_root): 81 | self.fileList.append(os.path.join(self.data_root, file)) 82 | 83 | def __getitem__(self, index): 84 | file = self.fileList[index] 85 | img = Image.open(os.path.join(self.data_root, file)).convert("RGB") 86 | 87 | return self.transform(img) 88 | 89 | def __len__(self): 90 | return len(self.fileList) 91 | 92 | dataset = DCGAN_Dataset(opt.dataroot, transform) 93 | 94 | dataloader = torch.utils.data.DataLoader( 95 | dataset=dataset, 96 | batch_size=opt.batchSize, 97 | shuffle=True, 98 | num_workers=0 99 | ) 100 | 101 | device = torch.device("cuda:0" if opt.cuda else "cpu") 102 | ngpu = int(opt.ngpu) 103 | nz = int(opt.nz) 104 | ngf = int(opt.ngf) 105 | ndf = int(opt.ndf) 106 | 107 | # custom weights initialization called on netG and netD 108 | def weights_init(m): 109 | classname = m.__class__.__name__ 110 | if classname.find('Conv') != -1: 111 | torch.nn.init.normal_(m.weight, 0.0, 0.02) 112 | elif classname.find('BatchNorm') != -1: 113 | torch.nn.init.normal_(m.weight, 1.0, 0.02) 114 | torch.nn.init.zeros_(m.bias) 115 | 116 | 117 | class Generator(nn.Module): 118 | def __init__(self, ngpu): 119 | super(Generator, self).__init__() 120 | self.ngpu = ngpu 121 | self.main = nn.Sequential( 122 | # input is Z, going into a convolution 123 | nn.ConvTranspose2d(nz, ngf * 8, 4, 1, 0, bias=False), 124 | nn.BatchNorm2d(ngf * 8), 125 | nn.ReLU(True), 126 | # state size. (ngf*8) x 4 x 4 127 | nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False), 128 | nn.BatchNorm2d(ngf * 4), 129 | nn.ReLU(True), 130 | # state size. (ngf*4) x 8 x 8 131 | nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False), 132 | nn.BatchNorm2d(ngf * 2), 133 | nn.ReLU(True), 134 | # state size. 
(ngf*2) x 16 x 16 135 | nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False), 136 | nn.BatchNorm2d(ngf), 137 | nn.ReLU(True), 138 | # state size. (ngf) x 32 x 32 139 | nn.ConvTranspose2d(ngf, 3, 4, 2, 1, bias=False), 140 | nn.Tanh() 141 | # state size. (nc) x 64 x 64 142 | ) 143 | 144 | def forward(self, input): 145 | if input.is_cuda and self.ngpu > 1: 146 | output = nn.parallel.data_parallel(self.main, input, range(self.ngpu)) 147 | else: 148 | output = self.main(input) 149 | return output 150 | 151 | 152 | netG = Generator(ngpu).to(device) 153 | netG.apply(weights_init) 154 | if opt.netG != '': 155 | netG.load_state_dict(torch.load(opt.netG)) 156 | print(netG) 157 | 158 | 159 | class Discriminator(nn.Module): 160 | def __init__(self, ngpu): 161 | super(Discriminator, self).__init__() 162 | self.ngpu = ngpu 163 | self.main = nn.Sequential( 164 | # input is (nc) x 64 x 64 165 | nn.Conv2d(3, ndf, 4, 2, 1, bias=False), 166 | nn.LeakyReLU(0.2, inplace=True), 167 | # state size. (ndf) x 32 x 32 168 | nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False), 169 | nn.BatchNorm2d(ndf * 2), 170 | nn.LeakyReLU(0.2, inplace=True), 171 | # state size. (ndf*2) x 16 x 16 172 | nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False), 173 | nn.BatchNorm2d(ndf * 4), 174 | nn.LeakyReLU(0.2, inplace=True), 175 | # state size. (ndf*4) x 8 x 8 176 | nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False), 177 | nn.BatchNorm2d(ndf * 8), 178 | nn.LeakyReLU(0.2, inplace=True), 179 | # state size. (ndf*8) x 4 x 4 180 | nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False), 181 | nn.Sigmoid() 182 | ) 183 | 184 | def forward(self, input): 185 | if input.is_cuda and self.ngpu > 1: 186 | output = nn.parallel.data_parallel(self.main, input, range(self.ngpu)) 187 | else: 188 | output = self.main(input) 189 | 190 | return output.view(-1, 1).squeeze(1) 191 | 192 | 193 | netD = Discriminator(ngpu).to(device) 194 | netD.apply(weights_init) 195 | if opt.netD != '': 196 | netD.load_state_dict(torch.load(opt.netD)) 197 | print(netD) 198 | 199 | criterion = nn.BCELoss() 200 | 201 | fixed_noise = torch.randn(opt.batchSize, nz, 1, 1, device=device) 202 | real_label = 1 203 | fake_label = 0 204 | 205 | # setup optimizer 206 | optimizerD = optim.Adam(netD.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) 207 | optimizerG = optim.Adam(netG.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) 208 | 209 | if opt.dry_run: 210 | opt.epochs = 1 211 | 212 | for epoch in range(opt.epochs): 213 | for i, data in enumerate(dataloader, 0): 214 | ############################ 215 | # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z))) 216 | ########################### 217 | # train with real 218 | netD.zero_grad() 219 | batch_size = data.shape[0] 220 | label = torch.full((batch_size,), real_label, dtype=torch.float32, device=device) # 这里label为1,真 221 | 222 | output = netD(data) 223 | errD_real = criterion(output, label) # 将真判断为真。这里是判别器的损失,output越接近1,损失越小 224 | errD_real.backward() 225 | D_x = output.mean().item() 226 | 227 | # train with fake 228 | noise = torch.randn(batch_size, nz, 1, 1, device=device) 229 | fake = netG(noise) 230 | label.fill_(fake_label) # 这里label为0,假 231 | output = netD(fake.detach()) # output越接近0,损失越小 232 | errD_fake = criterion(output, label) # 假的判断为假。这里是判别器的损失,output越为假,损失越小 233 | errD_fake.backward() 234 | D_G_z1 = output.mean().item() # 越小越好 235 | errD = errD_real + errD_fake 236 | optimizerD.step() 237 | 238 | ############################ 239 | # (2) Update G network: maximize log(D(G(z))) 240 | ########################### 241 | 
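        # Non-saturating generator loss: the same fake batch is scored by D again, but the
        # targets are filled with real_label (1). Minimizing BCE(D(G(z)), 1) is equivalent to
        # maximizing log(D(G(z))), which gives stronger gradients early in training than
        # directly minimizing log(1 - D(G(z))). `fake` is reused here without detach() so the
        # gradient flows back into netG; only optimizerG.step() is called below, so netD's
        # weights are untouched and its accumulated gradients are cleared by netD.zero_grad()
        # at the start of the next iteration.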
netG.zero_grad() 242 | label.fill_(real_label) # 这里标签为1,真 243 | output = netD(fake) # 对于判别器,要尽可能将生成器生成的判断为假; 244 | errG = criterion(output, label) # 这里表征生成器的生成损失,errG越小,说明生成器生成的越逼真。计算output(判别器的输出,越接近0说明判别器越好)与label(生成器给出的标签,越接近1说明生成器越好)的损失 245 | errG.backward() 246 | D_G_z2 = output.mean().item() # 越小越好 247 | optimizerG.step() 248 | 249 | print('[%d/%d][%d/%d] Loss_D: %.4f Loss_G: %.4f D(x): %.4f D(G(z)): %.4f / %.4f' 250 | % (epoch, opt.epochs, i, len(dataloader), 251 | errD.item(), errG.item(), D_x, D_G_z1, D_G_z2)) 252 | if i % 1 == 0: 253 | vutils.save_image(data, 254 | '%s/real_samples_%d.png' % (opt.outf, epoch), 255 | normalize=True) 256 | fake = netG(fixed_noise) 257 | vutils.save_image(fake.detach(), 258 | '%s/fake_samples_epoch_%03d.png' % (opt.outf, epoch), 259 | normalize=True) 260 | 261 | if opt.dry_run: 262 | break 263 | # do checkpointing 264 | torch.save(netG.state_dict(), '%s/netG_epoch_%d.pth' % (opt.outf, epoch)) 265 | torch.save(netD.state_dict(), '%s/netD_epoch_%d.pth' % (opt.outf, epoch)) -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # 3D捏脸 2 | 3 | 网易《Face-to-Parameter Translation for Game Character Auto-Creation》论文复现 4 | 5 | ## 1、网络结构 6 | 7 | G:Imitator网络 8 | 9 | ​ 输入:连续参数208d(归一化到0~1),离散参数102d; 10 | 11 | ​ 输出:游戏角色正脸图像(512x512x3); 12 | 13 | ​ SGD,batch_size=16,momentum=0.9,learning_rate=0.01 decay 10% per 50epochs,max_epochs=500; 14 | 15 | F1:lightcnn,做人脸识别,确保生成前后为同一个人,算余弦损失; 16 | 17 | ​ 输入:128x128x1灰度图; 18 | 19 | ​ 输出:256d向量; 20 | 21 | F2:faceparse,做面部语义分割(任意网络即可,deeplab、BiSeNet、RTNet等),算带权重的L1损失;(不算L2的原因:L1的稀疏性,能突出五官的特征;而L2是平滑性) 22 | 23 | ​ 输入:256x256x3 RGB图像; 24 | 25 | ​ 输出:语义分割结果; 26 | 27 | ​ 语义概率图增强人脸特征,采用不同权重突出五官的特征; 28 | 29 | ​ 也可使用交叉熵损失,衡量将每个像素点分类为不同类型的概率分布距离; 30 | 31 | ## 2、数据预处理 32 | 33 | ​ Face alignment:dlib 34 | 35 | ## 3、缺点 36 | 37 | (1)对人脸鲁棒性低:人脸姿态、遮挡敏感; 38 | 39 | (2)离散参数求解未能有效解决; 40 | 41 | (3)需要多次迭代优化:基于梯度下降的算法运算效率低; 42 | 43 | (4)只适合真人向的游戏角色; 44 | 45 | ## 4、优化项 46 | 47 | Imitator基于DCGAN的G网络,网络结构简单,特征表达能力有限,导致生成的图像模糊,五官畸形严重。 48 | 49 | 解决办法: 50 | 51 | 1、采用DCGAN的D网络,学习从图像到参数的映射,来代替基于梯度下降的优化过程; 52 | 53 | 2、参照第二篇论文,《Fast and Robust Face-to-Parameter Translation for Game Character Auto-Creation》,新增T网络,学习face-recognition到参数的映射关系。 54 | -------------------------------------------------------------------------------- /backbone/mobilenet.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | 4 | 5 | ''' 6 | backbone: mobilenetV2 7 | ''' 8 | class MobileNetV2(nn.Module): 9 | ''' 10 | :param num_classes 类别个数 11 | :param output_stride 12 | :param width_mult 通过该参数控制每一层的通道数量 13 | :param inverted_residual_setting 14 | :param round_neatest 将每层的通道数四舍五入为该数字的倍数,设为1则关闭四舍五入 15 | ''' 16 | def __init__(self, num_classes=1000, output_stride=8, width_mult=1.0, inverted_residual_setting=None, round_nearest=8): 17 | super(MobileNetV2, self).__init__() 18 | block = InvertedResidual 19 | input_channel = 32 20 | last_channel = 1280 21 | self.output_stride = output_stride 22 | current_stride = 1 23 | if inverted_residual_setting is None: 24 | inverted_residual_setting = [ 25 | # t, c, n, s 26 | [1, 16, 1, 1], 27 | [6, 24, 2, 2], 28 | [6, 32, 3, 2], 29 | [6, 64, 4, 2], 30 | [6, 96, 3, 1], 31 | [6, 160, 3, 2], 32 | [6, 320, 1, 1], 33 | ] 34 | 35 | # 构建第一层 36 | input_channel = _make_divisible(input_channel * width_mult, round_nearest) 37 | self.last_channel = 
_make_divisible(last_channel * max(1.0, width_mult), round_nearest) 38 | features = [ConvBNReLU(3, input_channel, stride=2)] 39 | current_stride *= 2 40 | dilation = 1 41 | previous_dilation = 1 42 | 43 | # 构建中间的倒置残差块 44 | for t, c, n, s in inverted_residual_setting: 45 | output_channel = _make_divisible(c * width_mult, round_nearest) 46 | previous_dilation = dilation 47 | if current_stride == output_stride: 48 | stride = 1 49 | dilation *= s 50 | else: 51 | stride = s 52 | current_stride *= s 53 | output_channel = int(c * width_mult) 54 | 55 | for i in range(n): 56 | if i == 0: 57 | features.append(block(input_channel, output_channel, stride, previous_dilation, expand_ratio=t)) 58 | else: 59 | features.append(block(input_channel, output_channel, 1, dilation, expand_ratio=t)) 60 | input_channel = output_channel 61 | 62 | # 构建最后一层 63 | features.append(ConvBNReLU(input_channel, self.last_channel, kernel_size=1)) 64 | self.features = nn.Sequential(*features) 65 | 66 | # 构建最后的分类层 67 | self.classifier = nn.Sequential( 68 | nn.Dropout(0.2), 69 | nn.Linear(self.last_channel, num_classes) 70 | ) 71 | 72 | # weight initialization 73 | for m in self.modules(): 74 | if isinstance(m, nn.Conv2d): 75 | nn.init.kaiming_normal_(m.weight, mode='fan_out') 76 | if m.bias is not None: 77 | nn.init.zeros_(m.bias) 78 | elif isinstance(m, nn.BatchNorm2d): 79 | nn.init.ones_(m.weight) 80 | nn.init.zeros_(m.bias) 81 | elif isinstance(m, nn.Linear): 82 | nn.init.normal_(m.weight, 0, 0.01) 83 | nn.init.zeros_(m.bias) 84 | 85 | def forward(self, x): 86 | x = self.features(x) 87 | x = x.mean([2, 3]) 88 | x = self.classifier(x) 89 | return x 90 | 91 | """ 92 | This function is taken from the original tf repo. 93 | It ensures that all layers have a channel number that is divisible by 8 94 | It can be seen here: https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py 95 | :param v: 96 | :param divisor: 97 | :param min_value: 98 | :return: 99 | """ 100 | def _make_divisible(v, divisor, min_value=None): 101 | if min_value is None: 102 | min_value = divisor 103 | new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) 104 | if new_v < 0.9 * v: # 确保四舍五入的下降幅度不超过10% 105 | new_v += divisor 106 | return new_v 107 | 108 | ''' 109 | CBR结构:Conv+BN+Relu 110 | ''' 111 | class ConvBNReLU(nn.Sequential): 112 | def __init__(self, input_channel, output_channel, kernel_size=3, stride=1, dilation=1, groups=1): 113 | super(ConvBNReLU, self).__init__( 114 | nn.Conv2d(input_channel, output_channel, kernel_size=kernel_size, stride=stride, dilation=dilation, groups=groups, bias=False), 115 | nn.BatchNorm2d(output_channel), 116 | nn.ReLU6(inplace=True) 117 | ) 118 | 119 | ''' 120 | 倒置残差模块 121 | ''' 122 | class InvertedResidual(nn.Module): 123 | def __init__(self, input_channel, output_channel, stride, dilation, expand_ratio): 124 | super(InvertedResidual, self).__init__() 125 | self.stride = stride 126 | 127 | hidden_dim = int(round(input_channel * expand_ratio)) 128 | self.use_residual_connect = self.stride == 1 and input_channel == output_channel # 判断是否采用残差连接 129 | layers = [] 130 | 131 | if expand_ratio != 1: 132 | layers.append(ConvBNReLU(input_channel, hidden_dim, kernel_size=1)) 133 | 134 | layers.extend([ 135 | ConvBNReLU(hidden_dim, hidden_dim, stride=stride, dilation=dilation, groups=hidden_dim), # 采用深度可分离卷积,各通道各卷积各的 136 | nn.Conv2d(hidden_dim, output_channel, 1, 1, 0, bias=False), 137 | nn.BatchNorm2d(output_channel) 138 | ]) 139 | self.conv = nn.Sequential(*layers) 140 | self.input_padding 
= fix_padding(3, dilation) 141 | 142 | def forward(self, x): 143 | x_pad = F.pad(x, self.input_padding) 144 | if self.use_residual_connect: 145 | return x + self.conv(x_pad) 146 | else: 147 | return self.conv(x_pad) 148 | 149 | 150 | def fix_padding(kernel_size, dilation): 151 | kernel_size_effective = kernel_size + (kernel_size - 1) * (dilation - 1) 152 | pad_total = kernel_size_effective - 1 153 | pad_beg = pad_total // 2 154 | pad_end = pad_total - pad_beg 155 | return (pad_beg, pad_end, pad_beg, pad_end) 156 | -------------------------------------------------------------------------------- /backbone/resnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | ''' 6 | backbone: resnet 7 | ''' 8 | class ResNet50(nn.Module): 9 | def __init__(self, block, layers, num_classes=1000, zero_init_residual=False, 10 | groups=1, width_per_group=64, replace_stride_with_dilation=None, 11 | norm_layer=None): 12 | super(ResNet50, self).__init__() 13 | if norm_layer is None: 14 | norm_layer = nn.BatchNorm2d 15 | self._norm_layer = norm_layer 16 | 17 | self.inplanes = 64 18 | self.dilation = 1 19 | if replace_stride_with_dilation is None: 20 | # each element in the tuple indicates if we should replace 21 | # the 2x2 stride with a dilated convolution instead 22 | replace_stride_with_dilation = [False, False, False] 23 | if len(replace_stride_with_dilation) != 3: 24 | raise ValueError("replace_stride_with_dilation should be None " 25 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) 26 | self.groups = groups 27 | self.base_width = width_per_group 28 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, 29 | bias=False) 30 | self.bn1 = norm_layer(self.inplanes) 31 | self.relu = nn.ReLU(inplace=True) 32 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 33 | self.layer1 = self._make_layer(block, 64, layers[0]) 34 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2, 35 | dilate=replace_stride_with_dilation[0]) 36 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2, 37 | dilate=replace_stride_with_dilation[1]) 38 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2, 39 | dilate=replace_stride_with_dilation[2]) 40 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 41 | self.fc = nn.Linear(512 * block.expansion, num_classes) 42 | 43 | for m in self.modules(): 44 | if isinstance(m, nn.Conv2d): 45 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 46 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 47 | nn.init.constant_(m.weight, 1) 48 | nn.init.constant_(m.bias, 0) 49 | 50 | # Zero-initialize the last BN in each residual branch, 51 | # so that the residual branch starts with zeros, and each residual block behaves like an identity. 
52 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 53 | if zero_init_residual: 54 | for m in self.modules(): 55 | nn.init.constant_(m.bn3.weight, 0) 56 | 57 | def _make_layer(self, block, planes, blocks, stride=1, dilate=False): 58 | norm_layer = self._norm_layer 59 | downsample = None 60 | previous_dilation = self.dilation 61 | if dilate: 62 | self.dilation *= stride 63 | stride = 1 64 | if stride != 1 or self.inplanes != planes * block.expansion: 65 | downsample = nn.Sequential( 66 | conv1x1(self.inplanes, planes * block.expansion, stride), 67 | norm_layer(planes * block.expansion), 68 | ) 69 | 70 | layers = [] 71 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups, 72 | self.base_width, previous_dilation, norm_layer)) 73 | self.inplanes = planes * block.expansion 74 | for _ in range(1, blocks): 75 | layers.append(block(self.inplanes, planes, groups=self.groups, 76 | base_width=self.base_width, dilation=self.dilation, 77 | norm_layer=norm_layer)) 78 | 79 | return nn.Sequential(*layers) 80 | 81 | def forward(self, x): 82 | x = self.conv1(x) 83 | x = self.bn1(x) 84 | x = self.relu(x) 85 | x = self.maxpool(x) 86 | 87 | x = self.layer1(x) 88 | x = self.layer2(x) 89 | x = self.layer3(x) 90 | x = self.layer4(x) 91 | 92 | x = self.avgpool(x) 93 | x = torch.flatten(x, 1) 94 | x = self.fc(x) 95 | 96 | return x 97 | 98 | class Bottleneck(nn.Module): 99 | expansion = 4 100 | 101 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 102 | base_width=64, dilation=1, norm_layer=None): 103 | super(Bottleneck, self).__init__() 104 | if norm_layer is None: 105 | norm_layer = nn.BatchNorm2d 106 | width = int(planes * (base_width / 64.)) * groups 107 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1 108 | self.conv1 = conv1x1(inplanes, width) 109 | self.bn1 = norm_layer(width) 110 | self.conv2 = conv3x3(width, width, stride, groups, dilation) 111 | self.bn2 = norm_layer(width) 112 | self.conv3 = conv1x1(width, planes * self.expansion) 113 | self.bn3 = norm_layer(planes * self.expansion) 114 | self.relu = nn.ReLU(inplace=True) 115 | self.downsample = downsample 116 | self.stride = stride 117 | 118 | def forward(self, x): 119 | identity = x 120 | 121 | out = self.conv1(x) 122 | out = self.bn1(out) 123 | out = self.relu(out) 124 | 125 | out = self.conv2(out) 126 | out = self.bn2(out) 127 | out = self.relu(out) 128 | 129 | out = self.conv3(out) 130 | out = self.bn3(out) 131 | 132 | if self.downsample is not None: 133 | identity = self.downsample(x) 134 | 135 | out += identity 136 | out = self.relu(out) 137 | 138 | return out 139 | 140 | ''' 141 | conv 3*3 142 | ''' 143 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): 144 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 145 | padding=dilation, groups=groups, bias=False, dilation=dilation) 146 | 147 | ''' 148 | conv 1*1 149 | ''' 150 | def conv1x1(in_planes, out_planes, stride=1): 151 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 152 | -------------------------------------------------------------------------------- /checkpoint/imitator_gitopen.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/csnowhermit/face2parameter/c9c7f849e91a8b51c209b53c8b4d5c402ed791d6/checkpoint/imitator_gitopen.pth -------------------------------------------------------------------------------- 
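A minimal, untested usage sketch for the classes defined in backbone/resnet.py above (not part of the repository). It assumes the standard ResNet-50 block counts [3, 4, 6, 3] and the torchvision weight URL listed in config.py; the state-dict keys should line up because the class mirrors torchvision's ResNet implementation.

import torch
from backbone.resnet import ResNet50, Bottleneck

# ResNet-50 = Bottleneck blocks repeated [3, 4, 6, 3] times
model = ResNet50(Bottleneck, [3, 4, 6, 3], num_classes=1000)
# Weight URL copied from config.py; only valid while num_classes stays at 1000,
# otherwise the fc.* entries would have to be dropped before loading.
state_dict = torch.hub.load_state_dict_from_url(
    "https://download.pytorch.org/models/resnet50-19c8e357.pth")
model.load_state_dict(state_dict)
model.eval()
print(model(torch.randn(1, 3, 224, 224)).shape)  # expected: torch.Size([1, 1000])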
/checkpoint/myimitator-512.json: -------------------------------------------------------------------------------- 1 | { 2 | "attention_layer_position": 8, 3 | "channel_width": 128, 4 | "class_embed_dim": 128, 5 | "eps": 0.0001, 6 | "layers": [ 7 | [ 8 | false, 9 | 16, 10 | 16 11 | ], 12 | [ 13 | true, 14 | 16, 15 | 16 16 | ], 17 | [ 18 | false, 19 | 16, 20 | 16 21 | ], 22 | [ 23 | true, 24 | 16, 25 | 8 26 | ], 27 | [ 28 | false, 29 | 8, 30 | 8 31 | ], 32 | [ 33 | true, 34 | 8, 35 | 8 36 | ], 37 | [ 38 | false, 39 | 8, 40 | 8 41 | ], 42 | [ 43 | true, 44 | 8, 45 | 4 46 | ], 47 | [ 48 | false, 49 | 4, 50 | 4 51 | ], 52 | [ 53 | true, 54 | 4, 55 | 2 56 | ], 57 | [ 58 | false, 59 | 2, 60 | 2 61 | ], 62 | [ 63 | true, 64 | 2, 65 | 1 66 | ], 67 | [ 68 | false, 69 | 1, 70 | 1 71 | ], 72 | [ 73 | true, 74 | 1, 75 | 1 76 | ] 77 | ], 78 | "n_stats": 51, 79 | "num_classes": 1000, 80 | "output_dim": 512, 81 | "z_dim": 128 82 | } -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | import dlib 2 | import torch 3 | 4 | # 通用配置项 5 | continuous_params_size = 159 # 连续参数个数,github开源的是95个参数 6 | # image_root = "F:/dataset/face_simple/face/" 7 | # train_params_root = "F:/dataset/face_simple/train_param.json" 8 | # test_params_root = "F:/dataset/face_simple/test_param.json" 9 | 10 | # train_set = "./face_data/trainset_female" 11 | # test_set = "./face_data/testset_female" 12 | 13 | 14 | image_root = "F:/dataset/face_20211203_20000_nojiemao/" 15 | params_root = "F:/dataset/face_20211203_20000_nojiemao/param.json" 16 | 17 | use_gpu = False # 是否使用gpu 18 | num_gpu = 1 # gpu的个数 19 | device = torch.device('cuda:0') if torch.cuda.is_available() and use_gpu else torch.device('cpu') 20 | path_tensor_log = "./logs/" 21 | 22 | 23 | # imitator配置项 24 | total_epochs = 500 25 | batch_size = 16 26 | save_freq = 10 27 | prev_freq = 10 28 | learning_rate = 1 29 | # imitator_model = "./checkpoint/imitator.pth" # 不做finetune,就直接写空字符串 30 | imitator_model = "./checkpoint/epoch_340_0.434396.pt" 31 | # imitator_model = "" 32 | 33 | prev_path = "./output/preview" 34 | # prev_path = "E:/nielian/" 35 | model_path = "./output/imitator" 36 | 37 | # 评估时 38 | total_eval_steps = 50 39 | eval_alpha = 0.1 # Ls = alpha * L1 + L2 40 | eval_learning_rate = 1 41 | eval_prev_freq = 1 42 | 43 | 44 | # 人脸语义分割 45 | faceparse_backbone = 'mobilenetv2' 46 | faceparse_checkpoint = "./checkpoint/faceseg_179_0.050777_0.065476_0.842724_withface7.pth" 47 | num_classes = 7 48 | output_stride = 16 49 | pretrained = True 50 | progress = True 51 | model_urls = "https://download.pytorch.org/models/mobilenet_v2-b0353104.pth" # mobilenetv2 52 | # model_urls = "https://download.pytorch.org/models/resnet50-19c8e357.pth" # resnet 53 | 54 | # light-cnn 55 | lightcnn_checkpoint = "./checkpoint/LightCNN_29Layers_V2_checkpoint.pth.tar" 56 | 57 | # 人脸关键点检测及摆正 58 | detector = dlib.get_frontal_face_detector() 59 | predictor = dlib.shape_predictor('./checkpoint/shape_predictor_68_face_landmarks.dat') 60 | 61 | # 自定义imitator 62 | config_jsonfile = "./checkpoint/myimitator-512.json" 63 | 64 | 65 | -------------------------------------------------------------------------------- /dat/103.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/csnowhermit/face2parameter/c9c7f849e91a8b51c209b53c8b4d5c402ed791d6/dat/103.jpg 
-------------------------------------------------------------------------------- /dat/110.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/csnowhermit/face2parameter/c9c7f849e91a8b51c209b53c8b4d5c402ed791d6/dat/110.jpg -------------------------------------------------------------------------------- /dat/16.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/csnowhermit/face2parameter/c9c7f849e91a8b51c209b53c8b4d5c402ed791d6/dat/16.jpg -------------------------------------------------------------------------------- /dat/avg_face.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/csnowhermit/face2parameter/c9c7f849e91a8b51c209b53c8b4d5c402ed791d6/dat/avg_face.jpg -------------------------------------------------------------------------------- /dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import random 4 | import torch 5 | from PIL import Image 6 | from torchvision import transforms as T 7 | from torch.utils.data import DataLoader, Dataset 8 | 9 | import config 10 | 11 | ''' 12 | Imitator Dataset 13 | ''' 14 | class Imitator_Dataset(Dataset): 15 | def __init__(self, params_root, image_root, mode="train"): 16 | self.image_root = image_root 17 | self.mode = mode 18 | with open(params_root, encoding='utf-8') as f: 19 | self.params = json.load(f) 20 | 21 | def __getitem__(self, index): 22 | if self.mode == "val": 23 | img = Image.open(os.path.join(self.image_root, '%d.png' % (index + 54000))).convert("RGB") 24 | param = torch.tensor(self.params['%d.png' % (index + 54000)]) 25 | else: 26 | img = Image.open(os.path.join(self.image_root, '%d.png' % index)).convert("RGB") 27 | param = torch.tensor(self.params['%d.png' % index]) 28 | img = T.ToTensor()(img) 29 | return param, img 30 | 31 | def __len__(self): 32 | if self.mode == "train": 33 | return 54000 34 | else: 35 | return 6000 36 | 37 | ############################################################ 38 | ''' 39 | Translator Dataset 40 | ''' 41 | def split_dataset(datapath): 42 | trainlist, vallist = [], [] 43 | for file in os.listdir(datapath): 44 | if random.randint(0, 10) <= 8: 45 | trainlist.append(os.path.join(datapath, file)) 46 | else: 47 | vallist.append(os.path.join(datapath, file)) 48 | return trainlist, vallist 49 | 50 | class Translator_Dataset(Dataset): 51 | def __init__(self, img_list): 52 | self.img_list = img_list 53 | 54 | def __getitem__(self, index): 55 | img = Image.open(self.img_list[index]).convert("RGB").resize((512, 512), Image.BILINEAR) 56 | img = T.ToTensor()(img) 57 | return img 58 | 59 | def __len__(self): 60 | return len(self.img_list) 61 | 62 | 63 | if __name__ == '__main__': 64 | train_imitator_Dataset = Imitator_Dataset(config.params_root, config.image_root) 65 | train_imitator_dataloader = DataLoader(train_imitator_Dataset, batch_size=16, shuffle=True) 66 | for i, content in enumerate(train_imitator_dataloader): 67 | x, y = content[:] 68 | print(i, x.shape, y.shape) 69 | 70 | -------------------------------------------------------------------------------- /demo.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import torch 3 | import utils 4 | import config 5 | from imitator import Imitator 6 | 7 | imitator = Imitator() 8 | 9 | # model_dict = {"model.0.weight": "0.0.weight", 10 | 
# "model.0.bias": "0.0.bias", 11 | # "model.1.weight": "0.1.weight", 12 | # "model.1.bias": "0.1.bias", 13 | # "model.1.running_mean": "0.1.running_mean", 14 | # "model.1.running_var": "0.1.running_var", 15 | # "model.1.num_batches_tracked": "0.1.num_batches_tracked", 16 | # "model.3.weight": "1.0.weight", 17 | # "model.3.bias": "1.0.bias", 18 | # "model.4.weight": "1.1.weight", 19 | # "model.4.bias": "1.1.bias", 20 | # "model.4.running_mean": "1.1.running_mean", 21 | # "model.4.running_var": "1.1.running_var", 22 | # "model.4.num_batches_tracked": "1.1.num_batches_tracked", 23 | # "model.6.weight": "2.0.weight", 24 | # "model.6.bias": "2.0.bias", 25 | # "model.7.weight": "2.1.weight", 26 | # "model.7.bias": "2.1.bias", 27 | # "model.7.running_mean": "2.1.running_mean", 28 | # "model.7.running_var": "2.1.running_var", 29 | # "model.7.num_batches_tracked": "2.1.num_batches_tracked", 30 | # "model.9.weight": "3.0.weight", 31 | # "model.9.bias": "3.0.bias", 32 | # "model.10.weight": "3.1.weight", 33 | # "model.10.bias": "3.1.bias", 34 | # "model.10.running_mean": "3.1.running_mean", 35 | # "model.10.running_var": "3.1.running_var", 36 | # "model.10.num_batches_tracked": "3.1.num_batches_tracked", 37 | # "model.12.weight": "4.0.weight", 38 | # "model.12.bias": "4.0.bias", 39 | # "model.13.weight": "4.1.weight", 40 | # "model.13.bias": "4.1.bias", 41 | # "model.13.running_mean": "4.1.running_mean", 42 | # "model.13.running_var": "4.1.running_var", 43 | # "model.13.num_batches_tracked": "4.1.num_batches_tracked", 44 | # "model.15.weight": "5.0.weight", 45 | # "model.15.bias": "5.0.bias", 46 | # "model.16.weight": "5.1.weight", 47 | # "model.16.bias": "5.1.bias", 48 | # "model.16.running_mean": "5.1.running_mean", 49 | # "model.16.running_var": "5.1.running_var", 50 | # "model.16.num_batches_tracked": "5.1.num_batches_tracked", 51 | # "model.18.weight": "6.0.weight", 52 | # "model.18.bias": "6.0.bias", 53 | # "model.19.weight": "6.1.weight", 54 | # "model.19.bias": "6.1.bias", 55 | # "model.19.running_mean": "6.1.running_mean", 56 | # "model.19.running_var": "6.1.running_var", 57 | # "model.19.num_batches_tracked": "6.1.num_batches_tracked", 58 | # "model.21.weight": "7.weight", 59 | # "model.21.bias": "7.bias" 60 | # } 61 | # 62 | # checkpoint = torch.load("./checkpoint/model_imitator_100000_cuda.pth", map_location=torch.device('cpu')) 63 | # # imitator.load_state_dict(checkpoint['net']) 64 | # for k1, k2 in zip(imitator.state_dict(), checkpoint['net']): 65 | # # print("%s\t%s\t%s\t%s" % (k1, imitator.state_dict()[k1].shape, k2, checkpoint['net'][k2].shape)) 66 | # content = '"%s": "%s",' % (k1, k2) 67 | # print(content) 68 | # 69 | # state_dict = torch.load(config.imitator_model, map_location=torch.device('cpu'))['net'] 70 | # # 71 | # # op_model = {} 72 | # # for k in state_dict['net'].keys(): 73 | # # op_model["model." 
+ str(k)[2:]] = imitator_model['net'][k] 74 | # for new_key, old_key in model_dict.items(): 75 | # state_dict[new_key] = state_dict.pop(old_key) 76 | # 77 | # imitator.load_state_dict(state_dict) 78 | # torch.save(state_dict, "./imitator.pth") 79 | 80 | 81 | model = torch.load(config.imitator_model) 82 | imitator.load_state_dict(model) 83 | -------------------------------------------------------------------------------- /evaluate_adam.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | from tqdm import tqdm 4 | from PIL import Image 5 | import torch 6 | import torch.nn.functional as F 7 | import numpy as np 8 | 9 | import utils 10 | import config 11 | from imitator import Imitator 12 | from lightcnn import LightCNN_29Layers_v2 13 | from faceparse import BiSeNet 14 | 15 | ''' 16 | 评估:t_params封装起来,做Adam优化 17 | ''' 18 | 19 | ''' 20 | evaluate with torch tensor 21 | :param input: torch tensor [B, H, W, C] rang: [0-1], not [0-255] 22 | :param image: 做mask图底色用 23 | :param bsnet: BiSeNet model 24 | :param w: tuple len=6 [eyebrow,eye,nose,teeth,up lip,lower lip] 25 | :return 带权重的tensor;脸部mask;带mask的图 26 | ''' 27 | def faceparsing_tensor(input, image, bsnet, w): 28 | out = bsnet(input) # [1, 19, 512, 512] 29 | parsing = out.squeeze(0).cpu().detach().numpy().argmax(0) 30 | 31 | mask_img = utils.vis_parsing_maps(image, parsing, 1) 32 | out = out.squeeze() 33 | 34 | return w[0] * out[3] + w[1] * out[4] + w[2] * out[10] + out[11] + out[12] + out[13], out[1], mask_img 35 | 36 | if __name__ == '__main__': 37 | eval_image = "./dat/gen.jpg" # 要评估的图片 38 | 39 | # 加载lightcnn 40 | model = LightCNN_29Layers_v2(num_classes=80013) 41 | model.eval() 42 | if config.use_gpu: 43 | checkpoint = torch.load(config.lightcnn_checkpoint) 44 | model = torch.nn.DataParallel(model).cuda() 45 | model.load_state_dict(checkpoint['state_dict']) 46 | else: 47 | checkpoint = torch.load(config.lightcnn_checkpoint, map_location="cpu") 48 | new_state_dict = model.state_dict() 49 | for k, v in checkpoint['state_dict'].items(): 50 | _name = k[7:] # remove `module.` 51 | new_state_dict[_name] = v 52 | model.load_state_dict(new_state_dict) 53 | 54 | # 冻结lightcnn 55 | for param in model.parameters(): 56 | param.requires_grad = False 57 | 58 | losses = [] 59 | 60 | # 加载BiSeNet 61 | bsnet = BiSeNet(n_classes=19) 62 | if config.use_gpu: 63 | bsnet.cuda() 64 | bsnet.load_state_dict(torch.load(config.faceparse_checkpoint)) 65 | else: 66 | bsnet.load_state_dict(torch.load(config.faceparse_checkpoint, map_location="cpu")) 67 | bsnet.eval() 68 | 69 | # 冻结BiSeNet 70 | for param in bsnet.parameters(): 71 | param.requires_grad = False 72 | 73 | # 加载imitator 74 | imitator = Imitator() 75 | 76 | l2_c = (torch.ones((512, 512)), torch.ones((512, 512))) 77 | if config.use_gpu: 78 | imitator.cuda() 79 | imitator.eval() 80 | 81 | # 冻结imitator 82 | for param in imitator.parameters(): 83 | param.requires_grad = False 84 | 85 | imitator_model = torch.load(config.imitator_model, map_location=torch.device('cpu')) 86 | imitator.load_state_dict(imitator_model) # 这里加载已经处理过的参数 87 | 88 | # 图片读取 89 | img = cv2.imread(eval_image) 90 | img = cv2.resize(img, (512, 512)) 91 | img = img.astype(np.float32) 92 | 93 | img1 = Image.open(eval_image) 94 | image = img1.resize((512, 512), Image.BILINEAR) 95 | 96 | # inference 97 | # t_params = 0.5 * torch.ones((1, config.continuous_params_size), dtype=torch.float32) 98 | t_params = torch.rand((1, config.continuous_params_size), dtype=torch.float32) # 
论文中用均匀分布,torch.randn()为正态分布 99 | optimizer = torch.optim.Adam([t_params], lr=config.eval_learning_rate) 100 | if config.use_gpu: 101 | t_params = t_params.cuda() 102 | t_params.requires_grad = True 103 | losses.clear() # 清空损失 104 | 105 | # 用于计算L1损失的:L1_y [B, C, W, H] 106 | img_resized = cv2.resize(img, dsize=(128, 128), interpolation=cv2.INTER_LINEAR) 107 | img_resized = np.swapaxes(img_resized, 0, 2).astype(np.float32) 108 | img_resized = np.mean(img_resized, axis=0)[np.newaxis, np.newaxis, :, :] 109 | 110 | # cv2.imwrite("l1_tmp.jpg", cv2.resize(img_resized[0].transpose(1, 2, 0), (512, 512))) 111 | img_resized = torch.from_numpy(img_resized) 112 | if config.use_gpu: 113 | img_resized = img_resized.cuda() 114 | L1_y = img_resized 115 | 116 | # 用于计算L2损失的参照:L2_y [B, C, H, W] 117 | img = img[np.newaxis, :, :, ] 118 | img = np.swapaxes(img, 1, 2) 119 | img = np.swapaxes(img, 1, 3) 120 | # print(img.shape) 121 | # cv2.imwrite("l2_tmp.jpg", cv2.resize(img[0].transpose(1, 2, 0), (512, 512))) 122 | 123 | img = torch.from_numpy(img) 124 | if config.use_gpu: 125 | img = img.cuda() 126 | L2_y = img / 255. 127 | 128 | 129 | 130 | # 做total_eval_steps次训练,取最后一次 131 | m_progress = tqdm(range(1, config.total_eval_steps + 1)) 132 | for i in m_progress: 133 | y_ = imitator(t_params) # [1, 3, 512, 512], [batch_size, c, w, h] 134 | # tmp = y_.detach().cpu().numpy()[0] 135 | # tmp = tmp.transpose(2, 1, 0) 136 | # tmp = tmp * 255.0 137 | # # cv2.imshow("y_", tmp) 138 | # # cv2.waitKey() 139 | # print(type(tmp), tmp.shape) 140 | # cv2.imwrite("./output/gen.jpg", tmp) 141 | # print("已保存") 142 | # break 143 | # loss, info = self.evaluate_ls(y_) 144 | y_copy = y_.clone() # 复制出一份算L2损失 145 | 146 | # 衡量人脸相似度 147 | # L1损失:表示余弦距离的损失(判断是否为同一个人),大致身份的一致性。y_ [B, C, W, H] 148 | y_ = F.max_pool2d(y_, kernel_size=(4, 4), stride=4) # 512->128, [1, 3, 128, 128] 149 | y_ = torch.mean(y_, dim=1).view(1, 1, 128, 128) # gray 150 | 151 | # 计算L1损失时,BCWH改为BCHW 152 | # L1 = utils.discriminative_loss(torch.from_numpy(L1_y.detach().cpu().numpy().transpose(0, 1, 3, 2)), 153 | # torch.from_numpy(y_.detach().cpu().numpy().transpose(0, 1, 3, 2)), 154 | # model) 155 | 156 | L1 = utils.discriminative_loss(L1_y, y_, model) # BCWH 157 | # if i % config.eval_prev_freq == 0: 158 | # cv2.imwrite("L1_y_%d.jpg" % i, cv2.resize(L1_y.detach().cpu().numpy()[0].transpose(1, 2, 0), (512, 512))) 159 | # cv2.imwrite("y_%d.jpg" % i, cv2.resize(y_.detach().cpu().numpy()[0].transpose(1, 2, 0) * 255., (512, 512))) 160 | 161 | # L2损失:面部语义损失(关键部位加权计算损失) 162 | w_r = [1.1, 1., 1., 0.7, 1., 1.] 163 | w_g = [1.1, 1., 1., 0.7, 1., 1.] 164 | part1, _, mask_img1 = faceparsing_tensor(L2_y, image, bsnet, w_r) 165 | y_copy = y_copy.transpose(2, 3) # [1, 3, 512, 512] 166 | part2, _, mask_img2 = faceparsing_tensor(y_copy, image, bsnet, w_g) 167 | # if i % config.eval_prev_freq == 0: 168 | # cv2.imwrite("L2_y_%d.jpg" % i, L2_y.detach().cpu().numpy()[0].transpose(1, 2, 0) * 255.) 169 | # cv2.imwrite("y_copy_%d.jpg" % i, y_copy.detach().cpu().numpy()[0].transpose(1, 2, 0) * 255.) 
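        # Content loss: part1 / part2 are weighted sums of BiSeNet's probability maps for the
        # facial components (brows, eyes, nose, mouth, lips) of the reference photo and of the
        # imitator output. F.l1_loss between them (named L2 below) penalises mismatched facial
        # layout; an l1 distance is used instead of an l2 distance for its sparsity, which keeps
        # the loss focused on the facial features (see the README). The *10 in L2_c only rescales
        # the maps for preview and does not enter the optimisation.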
170 | 171 | L2_c = (part1 * 10, part2 * 10) 172 | L2 = F.l1_loss(part1, part2) 173 | 174 | L2_mask = (mask_img1, mask_img2) # 用于在第二行显示mask的图 175 | 176 | # 计算综合损失:Ls = alpha * L1 + L2 177 | Ls = config.eval_alpha * (1 - L1) + L2 178 | # Ls = L2 179 | info = "L1:{0:.6f} L2:{1:.6f} Ls:{2:.6f}".format(1-L1, L2, Ls) 180 | print(info) 181 | losses.append((1-L1.item(), L2.item()/3, Ls.item())) 182 | 183 | optimizer.zero_grad() 184 | Ls.backward() 185 | optimizer.step() 186 | 187 | if i == 1: 188 | # utils.eval_output(imitator, t_params, img[0].cpu().detach().numpy(), 0, config.prev_path, L2_c) 189 | utils.eval_output(imitator, t_params, img[0].cpu().detach().numpy(), 0, config.prev_path, L2_mask) 190 | 191 | # eval_learning_rate = config.eval_learning_rate 192 | # if i % 10 == 0 and i != 0: # 评估时学习率,每5轮衰减20% 193 | # eval_learning_rate = (1 - 0.20) * eval_learning_rate 194 | 195 | # t_params.data = t_params.data - eval_learning_rate * t_params.grad.data 196 | # t_params.data = t_params.data.clamp(0., 1.) 197 | # print(i, t_params.grad, t_params.data) 198 | 199 | # one-hot编码:argmax处理(这里没搞清楚定义方法但没返回值的作用,直接写方法的内容在这里处理) 200 | # def argmax_params(params, start, count) 201 | # utils.argmax_params(t_params.data, 96, 3)sss #这行有没有用?待定 202 | 203 | # start = 96 204 | # count = 3 205 | # dims = t_params.size()[0] 206 | # for dim in range(dims): 207 | # tmp = t_params[dim, start] 208 | # mx = start 209 | # for idx in range(start + 1, start + count): 210 | # if t_params[dim, idx] > tmp: 211 | # mx = idx 212 | # tmp = t_params[dim, idx] 213 | # for idx in range(start, start + count): 214 | # t_params[dim, idx] = 1. if idx == mx else 0 215 | 216 | # one-hot编码:argmax处理结束 217 | 218 | t_params.grad.zero_() 219 | m_progress.set_description(info) 220 | if i % config.eval_prev_freq == 0: 221 | x = i / float(config.total_eval_steps) 222 | lr = config.eval_learning_rate * (1 - x) + 1e-2 223 | # utils.eval_output(imitator, t_params, img[0].cpu().detach().numpy(), i, config.prev_path, L2_c) 224 | utils.eval_output(imitator, t_params, img[0].cpu().detach().numpy(), i, config.prev_path, L2_mask) 225 | utils.eval_plot(losses) 226 | utils.eval_plot(losses) 227 | 228 | # 输出 229 | # utils.eval_output(imitator, t_params, img[0].cpu().detach().numpy(), config.total_eval_steps+1, config.prev_path, L2_c) 230 | utils.eval_output(imitator, t_params, img[0].cpu().detach().numpy(), config.total_eval_steps+1, config.prev_path, L2_mask) 231 | -------------------------------------------------------------------------------- /evaluate_sgd_L1Loss.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import numpy as np 4 | from tqdm import tqdm 5 | import torch 6 | import torch.nn.functional as F 7 | from torchvision import transforms as T 8 | import torchvision.utils as vutils 9 | from PIL import Image 10 | 11 | import utils 12 | import config 13 | from imitator import Imitator 14 | from lightcnn import LightCNN_29Layers_v2 15 | from faceparse import BiSeNet 16 | from face_parser import load_model 17 | from face_align import template_path, face_Align 18 | from faceswap import template_path, faceswap 19 | 20 | ''' 21 | evaluate with torch tensor 22 | :param input: torch tensor [B, H, W, C] rang: [0-1], not [0-255] 23 | :param image: 做mask图底色用 24 | :param bsnet: BiSeNet model 25 | :param w: tuple len=6 [eyebrow,eye,nose,teeth,up lip,lower lip] 26 | :return 带权重的tensor;脸部mask;带mask的图 27 | ''' 28 | def faceparsing_tensor(input, image, face_parse, is_label=False): 29 | out = 
face_parse(input) # [1, 19, 512, 512] 30 | parsing = out.squeeze(0).cpu().detach().numpy().argmax(0) 31 | mask_img = utils.vis_parsing_maps(image, parsing, 1) 32 | # # 7类:带脸 33 | # { 34 | # 0: 'background', 35 | # 1: 'face', 36 | # 2: 'brow', 37 | # 3: 'eye', 38 | # 4: 'nose', 39 | # 5: 'up_lip', 40 | # 6: 'down_lip' 41 | # } 42 | if is_label: # 如果是标签,则做成[1, 512, 512]的 43 | out = out.max(1)[1] 44 | else: 45 | out = out.squeeze() # [7, 512, 512] 46 | 47 | # 10类:background, face, right eyebrow, left eyebrow, right eye, left eye, nose, upper lip, tooth, down lip 48 | # 7类:background, face, brow, eye, nose, uplip, downlip 49 | # weights = torch.tensor([1., 1., 0., 0., 0.9, 0.8, 0.8], dtype=torch.float32).unsqueeze(1).unsqueeze(2) 50 | 51 | # return torch.mul(weights, out), mask_img 52 | return out, mask_img 53 | 54 | def main(): 55 | eval_imagepath = "./dat/16.jpg" # 要评估的图片 56 | 57 | # 1.加载lightcnn 58 | lightcnn = LightCNN_29Layers_v2(num_classes=80013) 59 | lightcnn.eval() 60 | if config.use_gpu: 61 | checkpoint = torch.load(config.lightcnn_checkpoint) 62 | model = torch.nn.DataParallel(lightcnn).cuda() 63 | model.load_state_dict(checkpoint['state_dict']) 64 | else: 65 | checkpoint = torch.load(config.lightcnn_checkpoint, map_location="cpu") 66 | new_state_dict = lightcnn.state_dict() 67 | for k, v in checkpoint['state_dict'].items(): 68 | _name = k[7:] # remove `module.` 69 | new_state_dict[_name] = v 70 | lightcnn.load_state_dict(new_state_dict) 71 | 72 | # 冻结lightcnn 73 | for param in lightcnn.parameters(): 74 | param.requires_grad = False 75 | 76 | losses = [] 77 | 78 | # 2.加载语义分割网络 79 | mean = [0.485, 0.456, 0.406] 80 | std = [0.229, 0.224, 0.225] 81 | transform = T.Compose([ 82 | T.Normalize(mean=[0.485, 0.456, 0.406], 83 | std=[0.229, 0.224, 0.225]), 84 | ]) 85 | 86 | deeplab = load_model('mobilenetv2', num_classes=config.num_classes, output_stride=config.output_stride) 87 | checkpoint = torch.load(config.faceparse_checkpoint, map_location=torch.device('cpu')) 88 | if config.faceparse_backbone == 'resnet50': 89 | deeplab.load_state_dict(checkpoint) 90 | else: 91 | deeplab.load_state_dict(checkpoint["model_state"]) 92 | deeplab.eval() 93 | 94 | for param in deeplab.parameters(): 95 | param.requires_grad = False 96 | 97 | # 3.加载imitator 98 | imitator = Imitator(is_bias=True) 99 | 100 | l2_c = (torch.ones((512, 512)), torch.ones((512, 512))) 101 | if config.use_gpu: 102 | imitator.cuda() 103 | imitator.eval() 104 | imitator_model = torch.load(config.imitator_model, map_location=torch.device('cpu')) 105 | imitator.load_state_dict(imitator_model) # 这里加载已经处理过的参数 106 | 107 | # 冻结imitator 108 | for param in imitator.parameters(): 109 | param.requires_grad = False 110 | 111 | # 读取图片 112 | # 先对齐 113 | warped_mask = face_Align(template_path, eval_imagepath) 114 | warped_mask = cv2.cvtColor(warped_mask, cv2.COLOR_BGR2RGB) 115 | # img_upload = Image.open(eval_imagepath) 116 | img_upload = Image.fromarray(np.uint8(warped_mask)) # 对齐后的 117 | 118 | image_mask = img_upload.convert('RGB').resize((512, 512), Image.BILINEAR) # 做mask图用 119 | 120 | image_F1 = img_upload.convert('L').resize((128, 128), Image.BILINEAR) # F1损失:身份验证损失 121 | image_F1 = T.ToTensor()(image_F1).unsqueeze(0) # [1, 1, 128, 128] 122 | 123 | image_F2 = img_upload.convert('RGB').resize((512, 512), Image.BILINEAR) # F2损失:内容损失 124 | image_F2 = T.ToTensor()(image_F2).unsqueeze(0) # [1, 3, 512, 512] 125 | 126 | image_F2 = transform(image_F2) # 训练mobilenet时做了normalize,推理时也得做 127 | 128 | t_params = torch.full([1, 
config.continuous_params_size], 0.5, dtype=torch.float32) # 从平均人脸初始化 129 | optimizer = torch.optim.SGD([t_params], lr=config.eval_learning_rate, momentum=0.9) # SGD带动量的 130 | if config.use_gpu: 131 | t_params = t_params.cuda() 132 | 133 | t_params.requires_grad = True 134 | losses.clear() # 清空损失 135 | 136 | # 做total_eval_steps次训练,取最后一次 137 | m_progress = tqdm(range(1, config.total_eval_steps + 1)) 138 | for i in m_progress: 139 | gen_img = imitator(t_params) # [1, 3, 512, 512] 140 | gen_img_copy = gen_img.clone() # 复制出一份来展示用 141 | 142 | # 1.身份验证损失 143 | trans = T.Compose([ 144 | T.Resize((128, 128)), 145 | T.Grayscale() 146 | ]) 147 | 148 | F1_Loss = utils.discriminative_loss(image_F1, trans(gen_img), lightcnn) # 身份验证损失 149 | 150 | # 2.内容损失 151 | gen_img = transform(gen_img) 152 | upload_F2_feature, mask_img_upload = faceparsing_tensor(image_F2, image_mask, deeplab, is_label=False) # 参照 153 | gen_F2_feature, mask_img_gen = faceparsing_tensor(gen_img, image_mask, deeplab, is_label=False) # 生成 154 | 155 | F2_Loss = F.l1_loss(upload_F2_feature, gen_F2_feature) 156 | # F2_Loss = F.cross_entropy(gen_F2_feature.view(-1, 7, 512*512), upload_F2_feature.view(-1, 512*512).long()) # 这里改用交叉熵损失 157 | 158 | # 计算综合损失:Ls = alpha * L1 + L2 159 | # Ls = config.eval_alpha * (1 - F1_Loss) + F2_Loss 160 | Ls = F2_Loss 161 | info = "1-F1:{0:.6f} F2:{1:.6f} Ls:{2:.6f}".format(1 - F1_Loss, F2_Loss, Ls) 162 | print(info) 163 | losses.append((1 - F1_Loss.item(), F2_Loss.item() / 3, Ls.item())) 164 | 165 | optimizer.zero_grad() 166 | Ls.backward() 167 | optimizer.step() 168 | 169 | if i % 5 == 0 and i != 0: # 评估时学习率,每5轮衰减20% 170 | for p in optimizer.param_groups: 171 | p["lr"] *= 0.8 172 | t_params.data = t_params.data.clamp(0.05, 0.95) 173 | t_params.grad.zero_() 174 | m_progress.set_description(info) 175 | 176 | if i % config.eval_prev_freq == 0: 177 | # image_upload, gen_img, mask_img_upload, mask_img_gen 178 | # tmp = T.PILToTensor()(T.ToPILImage()(gen_img[0])) 179 | # tmp.show() 180 | 181 | show_img_list = [T.PILToTensor()(img_upload), T.PILToTensor()(T.ToPILImage()(gen_img_copy[0])), 182 | T.ToTensor()(mask_img_upload) * 255., T.ToTensor()(mask_img_gen) * 255.] 
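            # Preview grid, saved every eval_prev_freq steps: make_grid(nrow=2) lays the four
            # tensors out as a 2x2 panel, with the aligned input photo (top-left), the current
            # imitator render (top-right), and the parsing masks of the two images on the
            # bottom row.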
183 | 184 | label_show = vutils.make_grid(show_img_list, nrow=2, padding=2, normalize=True).cpu() 185 | vutils.save_image(label_show, os.path.join(config.prev_path, "eval_%d.png" % (i))) 186 | 187 | 188 | if __name__ == '__main__': 189 | main() 190 | 191 | 192 | -------------------------------------------------------------------------------- /evaluate_sgd_cross_entropy.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import numpy as np 4 | from tqdm import tqdm 5 | import torch 6 | import torch.nn.functional as F 7 | from torchvision import transforms as T 8 | import torchvision.utils as vutils 9 | from PIL import Image 10 | 11 | import utils 12 | import config 13 | from imitator import Imitator 14 | from lightcnn import LightCNN_29Layers_v2 15 | from face_parser import load_model 16 | from face_align import template_path, face_Align 17 | from faceswap import template_path, faceswap 18 | 19 | ''' 20 | evaluate with torch tensor 21 | :param input: torch tensor [B, H, W, C] rang: [0-1], not [0-255] 22 | :param image: 做mask图底色用 23 | :param bsnet: BiSeNet model 24 | :param w: tuple len=6 [eyebrow,eye,nose,teeth,up lip,lower lip] 25 | :return 带权重的tensor;脸部mask;带mask的图 26 | ''' 27 | def faceparsing_tensor(input, image, face_parse, is_label=False): 28 | out = face_parse(input) # [1, 19, 512, 512] 29 | parsing = out.squeeze(0).cpu().detach().numpy().argmax(0) 30 | mask_img = utils.vis_parsing_maps(image, parsing, 1) 31 | # # 7类 32 | # { 33 | # 0: 'background', 34 | # 1: 'face', 35 | # 2: 'brow', 36 | # 3: 'eye', 37 | # 4: 'nose', 38 | # 5: 'up_lip', 39 | # 6: 'down_lip' 40 | # } 41 | 42 | if is_label: # 如果是标签,则做成[1, 512, 512]的 43 | out = out.max(1)[1] 44 | else: 45 | out = out.squeeze() # [7, 512, 512] 46 | # tmp = out[4] 47 | return out, mask_img 48 | 49 | if __name__ == '__main__': 50 | eval_imagepath = "./dat/16.jpg" # 要评估的图片 51 | 52 | # 1.加载lightcnn 53 | lightcnn = LightCNN_29Layers_v2(num_classes=80013) 54 | lightcnn.eval() 55 | if config.use_gpu: 56 | checkpoint = torch.load(config.lightcnn_checkpoint) 57 | model = torch.nn.DataParallel(lightcnn).cuda() 58 | model.load_state_dict(checkpoint['state_dict']) 59 | else: 60 | checkpoint = torch.load(config.lightcnn_checkpoint, map_location="cpu") 61 | new_state_dict = lightcnn.state_dict() 62 | for k, v in checkpoint['state_dict'].items(): 63 | _name = k[7:] # remove `module.` 64 | new_state_dict[_name] = v 65 | lightcnn.load_state_dict(new_state_dict) 66 | 67 | # 冻结lightcnn 68 | for param in lightcnn.parameters(): 69 | param.requires_grad = False 70 | 71 | losses = [] 72 | 73 | # 2.加载语义分割网络 74 | mean = [0.485, 0.456, 0.406] 75 | std = [0.229, 0.224, 0.225] 76 | transform = T.Compose([ 77 | T.Normalize(mean=[0.485, 0.456, 0.406], 78 | std=[0.229, 0.224, 0.225]), 79 | ]) 80 | 81 | deeplab = load_model(backbone=config.faceparse_backbone, num_classes=config.num_classes, output_stride=config.output_stride) 82 | checkpoint = torch.load(config.faceparse_checkpoint, map_location=torch.device('cpu')) 83 | deeplab.load_state_dict(checkpoint['model_state']) 84 | deeplab.eval() 85 | 86 | for param in deeplab.parameters(): 87 | param.requires_grad = False 88 | 89 | # 3.加载imitator 90 | imitator = Imitator(is_bias=True) 91 | 92 | l2_c = (torch.ones((512, 512)), torch.ones((512, 512))) 93 | if config.use_gpu: 94 | imitator.cuda() 95 | imitator.eval() 96 | imitator_model = torch.load(config.imitator_model, map_location=torch.device('cpu')) 97 | imitator.load_state_dict(imitator_model) # 
这里加载已经处理过的参数 98 | 99 | # 冻结imitator 100 | for param in imitator.parameters(): 101 | param.requires_grad = False 102 | 103 | # 读取图片 104 | # 先对齐 105 | warped_mask = face_Align(template_path, eval_imagepath) 106 | warped_mask = cv2.cvtColor(warped_mask, cv2.COLOR_BGR2RGB) 107 | # img_upload = Image.open(eval_imagepath) 108 | img_upload = Image.fromarray(np.uint8(warped_mask)) # 对齐后的 109 | 110 | image_mask = img_upload.convert('RGB').resize((512, 512), Image.BILINEAR) # 做mask图用 111 | 112 | image_F1 = img_upload.convert('L').resize((128, 128), Image.BILINEAR) # F1损失:身份验证损失 113 | image_F1 = T.ToTensor()(image_F1).unsqueeze(0) # [1, 1, 128, 128] 114 | 115 | image_F2 = img_upload.convert('RGB').resize((512, 512), Image.BILINEAR) # F2损失:内容损失 116 | image_F2 = T.ToTensor()(image_F2).unsqueeze(0) # [1, 3, 512, 512] 117 | image_F2 = transform(image_F2) # 训练mobilenet时做了normalize,推理时也得做 118 | 119 | t_params = torch.full([1, config.continuous_params_size], 0.5, dtype=torch.float32) # 从平均人脸初始化 120 | optimizer = torch.optim.SGD([t_params], lr=config.eval_learning_rate, momentum=0.9) # SGD带动量的 121 | if config.use_gpu: 122 | t_params = t_params.cuda() 123 | 124 | t_params.requires_grad = True 125 | losses.clear() # 清空损失 126 | 127 | # 做total_eval_steps次训练,取最后一次 128 | m_progress = tqdm(range(1, config.total_eval_steps + 1)) 129 | for i in m_progress: 130 | gen_img = imitator(t_params) # [1, 3, 512, 512] 131 | gen_img_copy = gen_img.clone() # 复制出一份来展示用 132 | 133 | # 1.身份验证损失 134 | trans = T.Compose([ 135 | T.Resize((128, 128)), 136 | T.Grayscale() 137 | ]) 138 | 139 | F1_Loss = utils.discriminative_loss(image_F1, trans(gen_img), lightcnn) # 身份验证损失 140 | 141 | # 2.内容损失 142 | gen_img = transform(gen_img) 143 | upload_F2_feature, mask_img_upload = faceparsing_tensor(image_F2, image_mask, deeplab, is_label=True) # 参照 144 | gen_F2_feature, mask_img_gen = faceparsing_tensor(gen_img, image_mask, deeplab, is_label=False) # 生成 145 | 146 | # tmp = upload_F2_feature.view(-1, 512*512).softmax(dim=0).max(0)[0].long() 147 | 148 | # F2_Loss = F.l1_loss(upload_F2_feature, gen_F2_feature) 149 | F2_Loss = F.cross_entropy(gen_F2_feature.view(-1, config.num_classes, 512*512), upload_F2_feature.view(-1, 512*512).long()) # 这里改用交叉熵损失 150 | # F2_Loss = F.cross_entropy(gen_F2_feature.view(-1, config.num_classes, 512*512), upload_F2_feature.view(-1, 512*512).softmax(dim=0).max(0)[0].view(-1, 512*512).long()) # 采用带条件概率的交叉熵损失 151 | 152 | # 计算综合损失:Ls = alpha * L1 + L2 153 | Ls = config.eval_alpha * (1 - F1_Loss) + F2_Loss 154 | # Ls = F1_Loss 155 | # Ls = F2_Loss 156 | info = "1-F1:{0:.6f} F2:{1:.6f} Ls:{2:.6f}".format(1 - F1_Loss, F2_Loss, Ls) 157 | print(info) 158 | losses.append((1 - F1_Loss.item(), F2_Loss.item(), Ls.item())) 159 | 160 | optimizer.zero_grad() 161 | Ls.backward() 162 | optimizer.step() 163 | 164 | if i % 5 == 0 and i != 0: # 评估时学习率,每5轮衰减20% 165 | for p in optimizer.param_groups: 166 | p["lr"] *= 0.8 167 | t_params.data = t_params.data.clamp(0.05, 0.95) 168 | t_params.grad.zero_() 169 | m_progress.set_description(info) 170 | 171 | if i % config.eval_prev_freq == 0 or i == 1: 172 | # image_upload, gen_img, mask_img_upload, mask_img_gen 173 | # tmp = T.PILToTensor()(T.ToPILImage()(gen_img[0])) 174 | # tmp.show() 175 | 176 | show_img_list = [T.PILToTensor()(img_upload), T.PILToTensor()(T.ToPILImage()(gen_img_copy[0])), T.ToTensor()(mask_img_upload) * 255., T.ToTensor()(mask_img_gen) * 255.] 
177 | 178 | label_show = vutils.make_grid(show_img_list, nrow=2, padding=2, normalize=True).cpu() 179 | vutils.save_image(label_show, os.path.join(config.prev_path, "eval_%d.png" % i)) 180 | -------------------------------------------------------------------------------- /face_align.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import dlib 3 | import numpy 4 | import sys 5 | import matplotlib.pyplot as plt 6 | 7 | PREDICTOR_PATH = r"./checkpoint/shape_predictor_68_face_landmarks.dat" # 68个关键点landmarks的模型文件 8 | SCALE_FACTOR = 1 # 图像的放缩比 9 | FEATHER_AMOUNT = 15 # 羽化边界范围,越大,羽化能力越大,一定要奇数,不能偶数 10 | 11 | # 68个点 12 | FACE_POINTS = list(range(17, 68)) # 脸 13 | MOUTH_POINTS = list(range(48, 61)) # 嘴巴 14 | RIGHT_BROW_POINTS = list(range(17, 22)) # 右眉毛 15 | LEFT_BROW_POINTS = list(range(22, 27)) # 左眉毛 16 | RIGHT_EYE_POINTS = list(range(36, 42)) # 右眼睛 17 | LEFT_EYE_POINTS = list(range(42, 48)) # 左眼睛 18 | NOSE_POINTS = list(range(27, 35)) # 鼻子 19 | JAW_POINTS = list(range(0, 17)) # 下巴 20 | 21 | # 选取用于叠加在第一张脸上的第二张脸的面部特征 22 | # 特征点包括左右眼、眉毛、鼻子和嘴巴 23 | # 是否数量变多之后,会有什么干扰吗? 24 | ALIGN_POINTS = (FACE_POINTS + LEFT_BROW_POINTS + RIGHT_EYE_POINTS + LEFT_EYE_POINTS + 25 | RIGHT_BROW_POINTS + NOSE_POINTS + MOUTH_POINTS) 26 | 27 | # Points from the second image to overlay on the first. The convex hull of each 28 | # element will be overlaid. 29 | OVERLAY_POINTS = [ 30 | LEFT_EYE_POINTS + RIGHT_EYE_POINTS + LEFT_BROW_POINTS + RIGHT_BROW_POINTS, 31 | NOSE_POINTS + MOUTH_POINTS, 32 | ] 33 | # 眼睛 ,眉毛 2 * 22 34 | # 鼻子,嘴巴 分开来 35 | 36 | # 定义用于颜色校正的模糊量,作为瞳孔距离的系数 37 | COLOUR_CORRECT_BLUR_FRAC = 0.6 38 | 39 | # 实例化脸部检测器 40 | detector = dlib.get_frontal_face_detector() 41 | # 加载训练模型 42 | # 并实例化特征提取器 43 | predictor = dlib.shape_predictor(PREDICTOR_PATH) 44 | 45 | 46 | # 定义了两个类处理意外 47 | class TooManyFaces(Exception): 48 | pass 49 | 50 | 51 | class NoFaces(Exception): 52 | pass 53 | 54 | 55 | def get_landmarks(im): 56 | ''' 57 | 通过predictor 拿到68 landmarks 58 | ''' 59 | rects = detector(im, 1) 60 | 61 | if len(rects) > 1: 62 | raise TooManyFaces 63 | if len(rects) == 0: 64 | raise NoFaces 65 | 66 | return numpy.matrix([[p.x, p.y] for p in predictor(im, rects[0]).parts()]) # 68*2的矩阵 67 | 68 | 69 | def annotate_landmarks(im, landmarks): 70 | ''' 71 | 人脸关键点,画图函数 72 | ''' 73 | im = im.copy() 74 | for idx, point in enumerate(landmarks): 75 | pos = (point[0, 0], point[0, 1]) 76 | cv2.putText(im, str(idx), pos, 77 | fontFace=cv2.FONT_HERSHEY_SCRIPT_SIMPLEX, 78 | fontScale=0.4, 79 | color=(0, 0, 255)) 80 | cv2.circle(im, pos, 3, color=(0, 255, 255)) 81 | return im 82 | 83 | 84 | def draw_convex_hull(im, points, color): 85 | ''' 86 | # 绘制凸多边形 计算凸包 87 | ''' 88 | points = cv2.convexHull(points) 89 | cv2.fillConvexPoly(im, points, color=color) 90 | 91 | 92 | def get_face_mask(im, landmarks): 93 | '''获取面部特征部分(眉毛、眼睛、鼻子以及嘴巴)的图像掩码。 94 | 图像掩码作用于原图之后,原图中对应掩码部分为白色的部分才能显示出来,黑色的部分则不予显示,因此通过图像掩码我们就能实现对图像“裁剪”。 95 | 效果参考:https://dn-anything-about-doc.qbox.me/document-uid242676labid2260timestamp1477921310170.png/wm 96 | get_face_mask()的定义是为一张图像和一个标记矩阵生成一个遮罩,它画出了两个白色的凸多边形:一个是眼睛周围的区域, 97 | 一个是鼻子和嘴部周围的区域。之后它由11个(FEATHER_AMOUNT)像素向遮罩的边缘外部羽化扩展,可以帮助隐藏任何不连续的区域。 98 | ''' 99 | im = numpy.zeros(im.shape[:2], dtype=numpy.float64) 100 | 101 | for group in OVERLAY_POINTS: 102 | draw_convex_hull(im, 103 | landmarks[group], 104 | color=1) 105 | 106 | im = numpy.array([im, im, im]).transpose((1, 2, 0)) 107 | 108 | im = (cv2.GaussianBlur(im, (FEATHER_AMOUNT, FEATHER_AMOUNT), 0) > 0) * 1.0 109 | im = 
cv2.GaussianBlur(im, (FEATHER_AMOUNT, FEATHER_AMOUNT), 0) 110 | 111 | return im 112 | 113 | # 返回一个仿射变换 114 | def transformation_from_points(points1, points2): 115 | """ 116 | Return an affine transformation [s * R | T] such that: 117 | sum ||s*R*p1,i + T - p2,i||^2 118 | is minimized. 119 | """ 120 | # Solve the procrustes problem by subtracting centroids, scaling by the 121 | # standard deviation, and then using the SVD to calculate the rotation. See 122 | # the following for more details: 123 | # https://en.wikipedia.org/wiki/Orthogonal_Procrustes_problem 124 | 125 | points1 = points1.astype(numpy.float64) # 人脸的指定关键点 126 | points2 = points2.astype(numpy.float64) 127 | 128 | # 数据标准化:先减去均值,再除以std,做成均值为0方差为1的序列 129 | # 每张脸各自做各自的标准化 130 | c1 = numpy.mean(points1, axis=0) # 分别算x和y的均值 131 | c2 = numpy.mean(points2, axis=0) 132 | points1 -= c1 # 浮动于均值的部分,[43, 2] 133 | points2 -= c2 134 | 135 | s1 = numpy.std(points1) 136 | s2 = numpy.std(points2) 137 | points1 /= s1 # 138 | points2 /= s2 139 | 140 | U, S, Vt = numpy.linalg.svd(points1.T * points2) # 141 | 142 | # The R we seek is in fact the transpose of the one given by U * Vt. This 143 | # is because the above formulation assumes the matrix goes on the right 144 | # (with row vectors) where as our solution requires the matrix to be on the 145 | # left (with column vectors). 146 | R = (U * Vt).T # [2, 2] 147 | 148 | return numpy.vstack([numpy.hstack(((s2 / s1) * R, c2.T - (s2 / s1) * R * c1.T)), 149 | numpy.matrix([0., 0., 1.])]) 150 | 151 | 152 | def read_im_and_landmarks(fname): 153 | im = cv2.imread(fname, cv2.IMREAD_COLOR) 154 | im = cv2.resize(im, (im.shape[1] * SCALE_FACTOR, 155 | im.shape[0] * SCALE_FACTOR)) 156 | s = get_landmarks(im) # [68, 2] 157 | 158 | return im, s 159 | 160 | 161 | def warp_im(im, M, dshape): 162 | ''' 163 | 由 get_face_mask 获得的图像掩码还不能直接使用,因为一般来讲用户提供的两张图像的分辨率大小很可能不一样,而且即便分辨率一样, 164 | 图像中的人脸由于拍摄角度和距离等原因也会呈现出不同的大小以及角度,所以如果不能只是简单地把第二个人的面部特征抠下来直接放在第一个人脸上, 165 | 我们还需要根据两者计算所得的面部特征区域进行匹配变换,使得二者的面部特征尽可能重合。 166 | 167 | 仿射函数,warpAffine,能对图像进行几何变换 168 | 三个主要参数,第一个输入图像,第二个变换矩阵 np.float32 类型,第三个变换之后图像的宽高 169 | 170 | 对齐主要函数 171 | ''' 172 | output_im = numpy.zeros(dshape, dtype=im.dtype) # [512, 512, 3] 173 | cv2.warpAffine(im, 174 | M[:2], 175 | (dshape[1], dshape[0]), 176 | dst=output_im, 177 | borderMode=cv2.BORDER_TRANSPARENT, 178 | flags=cv2.WARP_INVERSE_MAP) 179 | return output_im 180 | 181 | 182 | def correct_colours(im1, im2, landmarks1): 183 | ''' 184 | 修改皮肤颜色,使两张图片在拼接时候显得更加自然。 185 | ''' 186 | blur_amount = COLOUR_CORRECT_BLUR_FRAC * numpy.linalg.norm( 187 | numpy.mean(landmarks1[LEFT_EYE_POINTS], axis=0) - 188 | numpy.mean(landmarks1[RIGHT_EYE_POINTS], axis=0)) 189 | blur_amount = int(blur_amount) 190 | if blur_amount % 2 == 0: 191 | blur_amount += 1 192 | im1_blur = cv2.GaussianBlur(im1, (blur_amount, blur_amount), 0) 193 | im2_blur = cv2.GaussianBlur(im2, (blur_amount, blur_amount), 0) 194 | 195 | # Avoid divide-by-zero errors. 
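# im2_blur is used as a per-pixel divisor below; wherever the blurred target image is
# (near) zero it is bumped up by 128, so im1_blur / im2_blur stays finite and the
# colour transfer does not produce extreme values in dark regions.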
196 | im2_blur += (128 * (im2_blur <= 1.0)).astype(im2_blur.dtype) 197 | 198 | return (im2.astype(numpy.float64) * im1_blur.astype(numpy.float64) / 199 | im2_blur.astype(numpy.float64)) 200 | 201 | 202 | # 换脸函数 203 | def Switch_face(Base_path, cover_path): 204 | im1, landmarks1 = read_im_and_landmarks(Base_path) # 底图 205 | im2, landmarks2 = read_im_and_landmarks(cover_path) # 贴上来的图 206 | 207 | if len(landmarks1) == 0 & len(landmarks2) == 0: 208 | raise RuntimeError("Faces detected is no face!") 209 | if len(landmarks1) > 1 & len(landmarks2) > 1: 210 | raise RuntimeError("Faces detected is more than 1!") 211 | 212 | # landmarks1[ALIGN_POINTS]为人脸的的指定关键点 213 | M = transformation_from_points(landmarks1[ALIGN_POINTS], 214 | landmarks2[ALIGN_POINTS]) 215 | mask = get_face_mask(im2, landmarks2) 216 | warped_mask = warp_im(mask, M, im1.shape) 217 | combined_mask = numpy.max([get_face_mask(im1, landmarks1), warped_mask], 218 | axis=0) 219 | warped_im2 = warp_im(im2, M, im1.shape) 220 | warped_corrected_im2 = correct_colours(im1, warped_im2, landmarks1) 221 | 222 | output_im = im1 * (1.0 - combined_mask) + warped_corrected_im2 * combined_mask 223 | return output_im 224 | 225 | 226 | # 人脸对齐函数 227 | def face_Align(Base_path, cover_path): 228 | im1, landmarks1 = read_im_and_landmarks(Base_path) # 底图 229 | im2, landmarks2 = read_im_and_landmarks(cover_path) # 贴上来的图 230 | 231 | # 得到仿射变换矩阵 232 | M = transformation_from_points(landmarks1[ALIGN_POINTS], 233 | landmarks2[ALIGN_POINTS]) 234 | warped_im2 = warp_im(im2, M, im1.shape) 235 | return warped_im2 236 | 237 | FEATHER_AMOUNT = 19 238 | 239 | template_path = './dat/avg_face.jpg' # 模板 240 | if __name__ == '__main__': 241 | cover_path = './dat/16.jpg' 242 | warped_mask = face_Align(template_path, cover_path) 243 | cv2.imwrite("./dat/result_16.jpg", warped_mask) 244 | 245 | # plt.subplot(111) 246 | # plt.imshow(warped_mask) # 数据展示 247 | # plt.show() -------------------------------------------------------------------------------- /face_parser.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torchvision.models.utils import load_state_dict_from_url 5 | from collections import OrderedDict 6 | 7 | import config 8 | from backbone.resnet import ResNet50, Bottleneck 9 | from backbone.mobilenet import MobileNetV2 10 | 11 | ''' 12 | 指定backbone,加载模型 13 | ''' 14 | def load_model(backbone, num_classes, output_stride): 15 | if backbone == 'resnet50': 16 | model = segment_resnet(num_classes=num_classes, output_stride=output_stride) 17 | elif backbone == 'mobilenetv2': 18 | model = segment_mobilenetv2(num_classes=num_classes, output_stride=output_stride) 19 | else: 20 | raise NotImplementedError 21 | return model 22 | 23 | ''' 24 | 语义分割网络,使用resnet作为backbone 25 | :param num_classes 分割的类别个数 26 | :param output_stride 27 | ''' 28 | def segment_resnet(num_classes, output_stride): 29 | if output_stride == 8: 30 | replace_stride_with_dilation = [False, True, True] # 是否用空洞卷积代替stride 31 | aspp_dilate = [12, 24, 36] # ASPP结构空洞卷积的dilate大小 32 | else: 33 | replace_stride_with_dilation = [False, False, True] 34 | aspp_dilate = [6, 12, 18] 35 | 36 | # 加载backbone(这里resnet50最后两层仅为了加载预训练模型,需设置num_classes为1000) 37 | backbone = ResNet50(Bottleneck, [3, 4, 6, 3], num_classes=1000, replace_stride_with_dilation=replace_stride_with_dilation) 38 | if config.pretrained: 39 | state_dict = load_state_dict_from_url(config.model_urls, progress=config.progress) 40 | 
backbone.load_state_dict(state_dict) 41 | del state_dict 42 | 43 | inplanes = 2048 44 | low_level_planes = 256 45 | 46 | return_layers = {'layer4': 'out', 'layer1': 'low_level'} # layer4作为resnet的最后输出,作为encoder端ASPP结构的输入;layer1作为低等级特征,直接输入到decoder端 47 | classifier = DeepLabHeadV3Plus(inplanes, low_level_planes, num_classes, aspp_dilate) 48 | 49 | # 提取网络的第几层输出结果并给一个别名 50 | backbone = IntermediateLayerGetter(backbone, return_layers=return_layers) 51 | 52 | # 组装成完整的分割模型 53 | # model = Simple_Segmentation_Model(backbone, classifier) 54 | model = FaceSegmentation(backbone, classifier) 55 | return model 56 | 57 | ''' 58 | 语义分割网络,使用mobilenetv2作为backbone 59 | :param num_classes 分割的类别个数 60 | :param output_stride 61 | ''' 62 | def segment_mobilenetv2(num_classes, output_stride): 63 | if output_stride == 8: 64 | aspp_dilate = [12, 24, 36] 65 | else: 66 | aspp_dilate = [6, 12, 18] 67 | 68 | # 这里num_classes要写1000,为加载上预训练模型 69 | backbone = MobileNetV2(num_classes=1000, output_stride=output_stride, width_mult=1.0, inverted_residual_setting=None, round_nearest=8) 70 | if config.pretrained: 71 | state_dict = load_state_dict_from_url(config.model_urls, progress=config.progress) 72 | backbone.load_state_dict(state_dict) 73 | del state_dict 74 | 75 | # 将backbone的特征分为高阶特征和低阶特征 76 | backbone.low_level_features = backbone.features[0:4] # 倒数第四层之前的全为低等级特征 77 | backbone.high_level_features = backbone.features[4:-1] # 第四层开始,到倒数第二层的为高等级特征 78 | backbone.features = None 79 | backbone.classifier = None 80 | 81 | inplanes = 320 82 | low_level_planes = 24 83 | 84 | return_layers = {'high_level_features': 'out', 'low_level_features': 'low_level'} 85 | classifier = DeepLabHeadV3Plus(inplanes, low_level_planes, num_classes, aspp_dilate) # 这里要写传入参数的num_classes,为加载voc预训练模型考虑 86 | 87 | backbone = IntermediateLayerGetter(backbone, return_layers=return_layers) 88 | 89 | model = FaceSegmentation(backbone, classifier) 90 | return model 91 | 92 | ''' 93 | deeplabv3+ 94 | ''' 95 | class DeepLabHeadV3Plus(nn.Module): 96 | def __init__(self, in_channels, low_level_channels, num_classes, aspp_dilate=[12, 24, 36]): 97 | super(DeepLabHeadV3Plus, self).__init__() 98 | self.project = nn.Sequential( 99 | nn.Conv2d(low_level_channels, 48, 1, bias=False), 100 | nn.BatchNorm2d(48), 101 | nn.ReLU(inplace=True), 102 | ) 103 | 104 | self.aspp = ASPP(in_channels, aspp_dilate) 105 | 106 | self.classifier = nn.Sequential( 107 | nn.Conv2d(304, 256, 3, padding=1, bias=False), 108 | nn.BatchNorm2d(256), 109 | nn.ReLU(inplace=True), 110 | nn.Conv2d(256, num_classes, 1) 111 | ) 112 | self._init_weight() 113 | 114 | def forward(self, feature): 115 | # print(feature.shape) 116 | low_level_feature = self.project( 117 | feature['low_level']) # return_layers = {'layer4': 'out', 'layer1': 'low_level'} 118 | # print(low_level_feature.shape) 119 | output_feature = self.aspp(feature['out']) 120 | # print(output_feature.shape) 121 | output_feature = F.interpolate(output_feature, size=low_level_feature.shape[2:], mode='bilinear', 122 | align_corners=False) 123 | # print(output_feature.shape) 124 | return self.classifier(torch.cat([low_level_feature, output_feature], dim=1)) 125 | 126 | def _init_weight(self): 127 | for m in self.modules(): 128 | if isinstance(m, nn.Conv2d): 129 | nn.init.kaiming_normal_(m.weight) 130 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 131 | nn.init.constant_(m.weight, 1) 132 | nn.init.constant_(m.bias, 0) 133 | 134 | ''' 135 | ASPP: atrous conv spp,空洞卷积+SPP 136 | ''' 137 | class ASPP(nn.Module): 138 | def __init__(self, 
in_channels, atrous_rates): 139 | super(ASPP, self).__init__() 140 | out_channels = 256 141 | modules = [] 142 | modules.append(nn.Sequential( 143 | nn.Conv2d(in_channels, out_channels, 1, bias=False), 144 | nn.BatchNorm2d(out_channels), 145 | nn.ReLU(inplace=True))) 146 | 147 | rate1, rate2, rate3 = tuple(atrous_rates) 148 | modules.append(ASPPConv(in_channels, out_channels, rate1)) 149 | modules.append(ASPPConv(in_channels, out_channels, rate2)) 150 | modules.append(ASPPConv(in_channels, out_channels, rate3)) 151 | modules.append(ASPPPooling(in_channels, out_channels)) 152 | 153 | self.convs = nn.ModuleList(modules) 154 | 155 | self.project = nn.Sequential( 156 | nn.Conv2d(5 * out_channels, out_channels, 1, bias=False), 157 | nn.BatchNorm2d(out_channels), 158 | nn.ReLU(inplace=True), 159 | nn.Dropout(0.1),) 160 | 161 | def forward(self, x): 162 | res = [] 163 | for conv in self.convs: 164 | #print(conv(x).shape) 165 | res.append(conv(x)) 166 | res = torch.cat(res, dim=1) 167 | return self.project(res) 168 | 169 | ''' 170 | ASPP模块中conv结构 171 | ''' 172 | class ASPPConv(nn.Sequential): 173 | def __init__(self, in_channels, out_channels, dilation): 174 | modules = [ 175 | nn.Conv2d(in_channels, out_channels, 3, padding=dilation, dilation=dilation, bias=False), 176 | nn.BatchNorm2d(out_channels), 177 | nn.ReLU(inplace=True) 178 | ] 179 | super(ASPPConv, self).__init__(*modules) 180 | 181 | ''' 182 | ASPP模块中pooling结构 183 | ''' 184 | class ASPPPooling(nn.Sequential): 185 | def __init__(self, in_channels, out_channels): 186 | super(ASPPPooling, self).__init__( 187 | nn.AdaptiveAvgPool2d(1), 188 | nn.Conv2d(in_channels, out_channels, 1, bias=False), 189 | nn.BatchNorm2d(out_channels), 190 | nn.ReLU(inplace=True)) 191 | 192 | def forward(self, x): 193 | size = x.shape[-2:] 194 | x = super(ASPPPooling, self).forward(x) 195 | return F.interpolate(x, size=size, mode='bilinear', align_corners=False) 196 | 197 | class IntermediateLayerGetter(nn.ModuleDict): 198 | """ 199 | Module wrapper that returns intermediate layers from a model 200 | 201 | It has a strong assumption that the modules have been registered 202 | into the model in the same order as they are used. 203 | This means that one should **not** reuse the same nn.Module 204 | twice in the forward if you want this to work. 205 | 206 | Additionally, it is only able to query submodules that are directly 207 | assigned to the model. So if `model` is passed, `model.feature1` can 208 | be returned, but not `model.feature1.layer2`. 209 | 210 | Arguments: 211 | model (nn.Module): model on which we will extract the features 212 | return_layers (Dict[name, new_name]): a dict containing the names 213 | of the modules for which the activations will be returned as 214 | the key of the dict, and the value of the dict is the name 215 | of the returned activation (which the user can specify). 
216 | 217 | Examples:: 218 | 219 | >>> m = torchvision.models.resnet18(pretrained=True) 220 | >>> # extract layer1 and layer3, giving as names `feat1` and feat2` 221 | >>> new_m = torchvision.models._utils.IntermediateLayerGetter(m, 222 | >>> {'layer1': 'feat1', 'layer3': 'feat2'}) 223 | >>> out = new_m(torch.rand(1, 3, 224, 224)) 224 | >>> print([(k, v.shape) for k, v in out.items()]) 225 | >>> [('feat1', torch.Size([1, 64, 56, 56])), 226 | >>> ('feat2', torch.Size([1, 256, 14, 14]))] 227 | """ 228 | def __init__(self, model, return_layers): 229 | if not set(return_layers).issubset([name for name, _ in model.named_children()]): 230 | raise ValueError("return_layers are not present in model") 231 | 232 | orig_return_layers = return_layers 233 | return_layers = {k: v for k, v in return_layers.items()} 234 | layers = OrderedDict() 235 | for name, module in model.named_children(): 236 | layers[name] = module 237 | if name in return_layers: 238 | del return_layers[name] 239 | if not return_layers: 240 | break 241 | 242 | super(IntermediateLayerGetter, self).__init__(layers) 243 | self.return_layers = orig_return_layers 244 | 245 | def forward(self, x): 246 | out = OrderedDict() 247 | for name, module in self.named_children(): 248 | x = module(x) 249 | if name in self.return_layers: 250 | out_name = self.return_layers[name] 251 | out[out_name] = x 252 | return out 253 | 254 | ''' 255 | 完整的分割模型 256 | ''' 257 | class Simple_Segmentation_Model(nn.Module): 258 | def __init__(self, backbone, classifier): 259 | super(Simple_Segmentation_Model, self).__init__() 260 | self.backbone = backbone 261 | self.classifier = classifier 262 | 263 | def forward(self, x): 264 | input_shape = x.shape[-2:] 265 | features = self.backbone(x) 266 | x = self.classifier(features) 267 | x = F.interpolate(x, size=input_shape, mode='bilinear', align_corners=False) 268 | return x 269 | 270 | class FaceSegmentation(Simple_Segmentation_Model): 271 | pass 272 | 273 | if __name__ == '__main__': 274 | model = load_model('resnet50', num_classes=config.num_classes, output_stride=config.output_stride) 275 | model = model.to(config.device) 276 | # print(model) 277 | input = torch.randn([16, 3, 513, 513]) 278 | output = model(input) 279 | print(output.shape) # torch.Size([16, 10, 513, 513]) 280 | 281 | -------------------------------------------------------------------------------- /faceparse.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import torchvision 6 | 7 | 8 | from resnet import Resnet18 9 | 10 | 11 | class ConvBNReLU(nn.Module): 12 | def __init__(self, in_chan, out_chan, ks=3, stride=1, padding=1, *args, **kwargs): 13 | super(ConvBNReLU, self).__init__() 14 | self.conv = nn.Conv2d(in_chan, out_chan, kernel_size=ks, stride=stride, padding=padding, bias=False) 15 | self.bn = nn.BatchNorm2d(out_chan) 16 | self.init_weight() 17 | 18 | def forward(self, x): 19 | x = self.conv(x) 20 | x = F.relu(self.bn(x)) 21 | return x 22 | 23 | def init_weight(self): 24 | for ly in self.children(): 25 | if isinstance(ly, nn.Conv2d): 26 | nn.init.kaiming_normal_(ly.weight, a=1) 27 | if ly.bias is not None: 28 | nn.init.constant_(ly.bias, 0) 29 | 30 | 31 | class BiSeNetOutput(nn.Module): 32 | def __init__(self, in_chan, mid_chan, n_classes, *args, **kwargs): 33 | super(BiSeNetOutput, self).__init__() 34 | self.conv = ConvBNReLU(in_chan, mid_chan, ks=3, stride=1, padding=1) 35 | self.conv_out = nn.Conv2d(mid_chan, n_classes, 
kernel_size=1, bias=False) 36 | self.init_weight() 37 | 38 | def forward(self, x): 39 | x = self.conv(x) 40 | x = self.conv_out(x) 41 | return x 42 | 43 | def init_weight(self): 44 | for ly in self.children(): 45 | if isinstance(ly, nn.Conv2d): 46 | nn.init.kaiming_normal_(ly.weight, a=1) 47 | if ly.bias is not None: 48 | nn.init.constant_(ly.bias, 0) 49 | 50 | def get_params(self): 51 | wd_params, nowd_params = [], [] 52 | for name, module in self.named_modules(): 53 | if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d): 54 | wd_params.append(module.weight) 55 | if module.bias is not None: 56 | nowd_params.append(module.bias) 57 | elif isinstance(module, nn.BatchNorm2d): 58 | nowd_params += list(module.parameters()) 59 | return wd_params, nowd_params 60 | 61 | 62 | class AttentionRefinementModule(nn.Module): 63 | def __init__(self, in_chan, out_chan, *args, **kwargs): 64 | super(AttentionRefinementModule, self).__init__() 65 | self.conv = ConvBNReLU(in_chan, out_chan, ks=3, stride=1, padding=1) 66 | self.conv_atten = nn.Conv2d(out_chan, out_chan, kernel_size=1, bias=False) 67 | self.bn_atten = nn.BatchNorm2d(out_chan) 68 | self.sigmoid_atten = nn.Sigmoid() 69 | self.init_weight() 70 | 71 | def forward(self, x): 72 | feat = self.conv(x) 73 | atten = F.avg_pool2d(feat, feat.size()[2:]) 74 | atten = self.conv_atten(atten) 75 | atten = self.bn_atten(atten) 76 | atten = self.sigmoid_atten(atten) 77 | out = torch.mul(feat, atten) 78 | return out 79 | 80 | def init_weight(self): 81 | for ly in self.children(): 82 | if isinstance(ly, nn.Conv2d): 83 | nn.init.kaiming_normal_(ly.weight, a=1) 84 | if ly.bias is not None: 85 | nn.init.constant_(ly.bias, 0) 86 | 87 | 88 | class ContextPath(nn.Module): 89 | def __init__(self, *args, **kwargs): 90 | super(ContextPath, self).__init__() 91 | self.resnet = Resnet18() 92 | self.arm16 = AttentionRefinementModule(256, 128) 93 | self.arm32 = AttentionRefinementModule(512, 128) 94 | self.conv_head32 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1) 95 | self.conv_head16 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1) 96 | self.conv_avg = ConvBNReLU(512, 128, ks=1, stride=1, padding=0) 97 | 98 | self.init_weight() 99 | 100 | def forward(self, x): 101 | H0, W0 = x.size()[2:] 102 | feat8, feat16, feat32 = self.resnet(x) 103 | H8, W8 = feat8.size()[2:] 104 | H16, W16 = feat16.size()[2:] 105 | H32, W32 = feat32.size()[2:] 106 | 107 | avg = F.avg_pool2d(feat32, feat32.size()[2:]) 108 | avg = self.conv_avg(avg) 109 | avg_up = F.interpolate(avg, (H32, W32), mode='nearest') 110 | 111 | feat32_arm = self.arm32(feat32) 112 | feat32_sum = feat32_arm + avg_up 113 | feat32_up = F.interpolate(feat32_sum, (H16, W16), mode='nearest') 114 | feat32_up = self.conv_head32(feat32_up) 115 | 116 | feat16_arm = self.arm16(feat16) 117 | feat16_sum = feat16_arm + feat32_up 118 | feat16_up = F.interpolate(feat16_sum, (H8, W8), mode='nearest') 119 | feat16_up = self.conv_head16(feat16_up) 120 | 121 | return feat8, feat16_up, feat32_up # x8, x8, x16 122 | 123 | def init_weight(self): 124 | for ly in self.children(): 125 | if isinstance(ly, nn.Conv2d): 126 | nn.init.kaiming_normal_(ly.weight, a=1) 127 | if ly.bias is not None: 128 | nn.init.constant_(ly.bias, 0) 129 | 130 | def get_params(self): 131 | wd_params, nowd_params = [], [] 132 | for name, module in self.named_modules(): 133 | if isinstance(module, (nn.Linear, nn.Conv2d)): 134 | wd_params.append(module.weight) 135 | if module.bias is not None: 136 | nowd_params.append(module.bias) 137 | elif isinstance(module, 
nn.BatchNorm2d): 138 | nowd_params += list(module.parameters()) 139 | return wd_params, nowd_params 140 | 141 | 142 | # This is not used, since I replace this with the resnet feature with the same size 143 | class SpatialPath(nn.Module): 144 | def __init__(self, *args, **kwargs): 145 | super(SpatialPath, self).__init__() 146 | self.conv1 = ConvBNReLU(3, 64, ks=7, stride=2, padding=3) 147 | self.conv2 = ConvBNReLU(64, 64, ks=3, stride=2, padding=1) 148 | self.conv3 = ConvBNReLU(64, 64, ks=3, stride=2, padding=1) 149 | self.conv_out = ConvBNReLU(64, 128, ks=1, stride=1, padding=0) 150 | self.init_weight() 151 | 152 | def forward(self, x): 153 | feat = self.conv1(x) 154 | feat = self.conv2(feat) 155 | feat = self.conv3(feat) 156 | feat = self.conv_out(feat) 157 | return feat 158 | 159 | def init_weight(self): 160 | for ly in self.children(): 161 | if isinstance(ly, nn.Conv2d): 162 | nn.init.kaiming_normal_(ly.weight, a=1) 163 | if ly.bias is not None: 164 | nn.init.constant_(ly.bias, 0) 165 | 166 | def get_params(self): 167 | wd_params, nowd_params = [], [] 168 | for name, module in self.named_modules(): 169 | if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d): 170 | wd_params.append(module.weight) 171 | if module.bias is not None: 172 | nowd_params.append(module.bias) 173 | elif isinstance(module, nn.BatchNorm2d): 174 | nowd_params += list(module.parameters()) 175 | return wd_params, nowd_params 176 | 177 | 178 | class FeatureFusionModule(nn.Module): 179 | def __init__(self, in_chan, out_chan, *args, **kwargs): 180 | super(FeatureFusionModule, self).__init__() 181 | self.convblk = ConvBNReLU(in_chan, out_chan, ks=1, stride=1, padding=0) 182 | self.conv1 = nn.Conv2d(out_chan, out_chan // 4, kernel_size=1, stride=1, padding=0, bias=False) 183 | self.conv2 = nn.Conv2d(out_chan // 4, out_chan, kernel_size=1, stride=1, padding=0, bias=False) 184 | self.relu = nn.ReLU(inplace=True) 185 | self.sigmoid = nn.Sigmoid() 186 | self.init_weight() 187 | 188 | def forward(self, fsp, fcp): 189 | fcat = torch.cat([fsp, fcp], dim=1) 190 | feat = self.convblk(fcat) 191 | atten = F.avg_pool2d(feat, feat.size()[2:]) 192 | atten = self.conv1(atten) 193 | atten = self.relu(atten) 194 | atten = self.conv2(atten) 195 | atten = self.sigmoid(atten) 196 | feat_atten = torch.mul(feat, atten) 197 | feat_out = feat_atten + feat 198 | return feat_out 199 | 200 | def init_weight(self): 201 | for ly in self.children(): 202 | if isinstance(ly, nn.Conv2d): 203 | nn.init.kaiming_normal_(ly.weight, a=1) 204 | if ly.bias is not None: 205 | nn.init.constant_(ly.bias, 0) 206 | 207 | def get_params(self): 208 | wd_params, nowd_params = [], [] 209 | for name, module in self.named_modules(): 210 | if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d): 211 | wd_params.append(module.weight) 212 | if module.bias is not None: 213 | nowd_params.append(module.bias) 214 | elif isinstance(module, nn.BatchNorm2d): 215 | nowd_params += list(module.parameters()) 216 | return wd_params, nowd_params 217 | 218 | 219 | class BiSeNet(nn.Module): 220 | def __init__(self, n_classes, *args, **kwargs): 221 | super(BiSeNet, self).__init__() 222 | self.cp = ContextPath() 223 | # here self.sp is deleted 224 | self.ffm = FeatureFusionModule(256, 256) 225 | self.conv_out = BiSeNetOutput(256, 256, n_classes) 226 | self.conv_out16 = BiSeNetOutput(128, 64, n_classes) 227 | self.conv_out32 = BiSeNetOutput(128, 64, n_classes) 228 | self.init_weight() 229 | 230 | def forward(self, x): 231 | H, W = x.size()[2:] 232 | feat_res8, feat_cp8, 
feat_cp16 = self.cp(x) # here return res3b1 feature 233 | feat_sp = feat_res8 # use res3b1 feature to replace spatial path feature 234 | feat_fuse = self.ffm(feat_sp, feat_cp8) 235 | 236 | feat_out = self.conv_out(feat_fuse) 237 | # feat_out16 = self.conv_out16(feat_cp8) 238 | # feat_out32 = self.conv_out32(feat_cp16) 239 | 240 | feat_out = F.interpolate(feat_out, (H, W), mode='bilinear', align_corners=True) 241 | # feat_out16 = F.interpolate(feat_out16, (H, W), mode='bilinear', align_corners=True) 242 | # feat_out32 = F.interpolate(feat_out32, (H, W), mode='bilinear', align_corners=True) 243 | return feat_out # , feat_out16, feat_out32 244 | 245 | def init_weight(self): 246 | for ly in self.children(): 247 | if isinstance(ly, nn.Conv2d): 248 | nn.init.kaiming_normal_(ly.weight, a=1) 249 | if ly.bias is not None: 250 | nn.init.constant_(ly.bias, 0) 251 | 252 | def get_params(self): 253 | wd_params, nowd_params, lr_mul_wd_params, lr_mul_nowd_params = [], [], [], [] 254 | for name, child in self.named_children(): 255 | child_wd_params, child_nowd_params = child.get_params() 256 | if isinstance(child, FeatureFusionModule) or isinstance(child, BiSeNetOutput): 257 | lr_mul_wd_params += child_wd_params 258 | lr_mul_nowd_params += child_nowd_params 259 | else: 260 | wd_params += child_wd_params 261 | nowd_params += child_nowd_params 262 | return wd_params, nowd_params, lr_mul_wd_params, lr_mul_nowd_params 263 | 264 | if __name__ == "__main__": 265 | net = BiSeNet(19) 266 | # net.cuda() 267 | net.eval() 268 | in_ten = torch.randn(16, 3, 640, 480) 269 | # in_ten = torch.randn(3, 640, 480) 270 | # out, out16, out32 = net(in_ten) 271 | out = net(in_ten) 272 | print(out.shape) 273 | net.get_params() -------------------------------------------------------------------------------- /faceswap.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | 3 | # Copyright (c) 2015 Matthew Earl 4 | # 5 | # Permission is hereby granted, free of charge, to any person obtaining a copy 6 | # of this software and associated documentation files (the "Software"), to deal 7 | # in the Software without restriction, including without limitation the rights 8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | # copies of the Software, and to permit persons to whom the Software is 10 | # furnished to do so, subject to the following conditions: 11 | # 12 | # The above copyright notice and this permission notice shall be included 13 | # in all copies or substantial portions of the Software. 14 | # 15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS 16 | # OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF 17 | # MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN 18 | # NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, 19 | # DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR 20 | # OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE 21 | # USE OR OTHER DEALINGS IN THE SOFTWARE. 22 | 23 | """ 24 | This is the code behind the Switching Eds blog post: 25 | http://matthewearl.github.io/2015/07/28/switching-eds-with-python/ 26 | See the above for an explanation of the code below. 27 | To run the script you'll need to install dlib (http://dlib.net) including its 28 | Python bindings, and OpenCV. 
You'll also need to obtain the trained model from 29 | sourceforge: 30 | http://sourceforge.net/projects/dclib/files/dlib/v18.10/shape_predictor_68_face_landmarks.dat.bz2 31 | Unzip with `bunzip2` and change `PREDICTOR_PATH` to refer to this file. The 32 | script is run like so: 33 | ./faceswap.py 34 | If successful, a file `output.jpg` will be produced with the facial features 35 | from `` replaced with the facial features from ``. 36 | """ 37 | 38 | import cv2 39 | import dlib 40 | import numpy 41 | 42 | import sys 43 | 44 | PREDICTOR_PATH = "./checkpoint/shape_predictor_68_face_landmarks.dat" 45 | SCALE_FACTOR = 1 46 | FEATHER_AMOUNT = 11 47 | 48 | FACE_POINTS = list(range(17, 68)) 49 | MOUTH_POINTS = list(range(48, 61)) 50 | RIGHT_BROW_POINTS = list(range(17, 22)) 51 | LEFT_BROW_POINTS = list(range(22, 27)) 52 | RIGHT_EYE_POINTS = list(range(36, 42)) 53 | LEFT_EYE_POINTS = list(range(42, 48)) 54 | NOSE_POINTS = list(range(27, 35)) 55 | JAW_POINTS = list(range(0, 17)) 56 | 57 | # Points used to line up the images. 58 | ALIGN_POINTS = (LEFT_BROW_POINTS + RIGHT_EYE_POINTS + LEFT_EYE_POINTS + 59 | RIGHT_BROW_POINTS + NOSE_POINTS + MOUTH_POINTS) 60 | 61 | # Points from the second image to overlay on the first. The convex hull of each 62 | # element will be overlaid. 63 | OVERLAY_POINTS = [ 64 | LEFT_EYE_POINTS + RIGHT_EYE_POINTS + LEFT_BROW_POINTS + RIGHT_BROW_POINTS, 65 | NOSE_POINTS + MOUTH_POINTS, 66 | ] 67 | 68 | # Amount of blur to use during colour correction, as a fraction of the 69 | # pupillary distance. 70 | COLOUR_CORRECT_BLUR_FRAC = 0.6 71 | 72 | detector = dlib.get_frontal_face_detector() 73 | predictor = dlib.shape_predictor(PREDICTOR_PATH) 74 | 75 | 76 | class TooManyFaces(Exception): 77 | pass 78 | 79 | 80 | class NoFaces(Exception): 81 | pass 82 | 83 | 84 | def get_landmarks(im): 85 | rects = detector(im, 1) 86 | 87 | if len(rects) > 1: 88 | raise TooManyFaces 89 | if len(rects) == 0: 90 | raise NoFaces 91 | 92 | return numpy.matrix([[p.x, p.y] for p in predictor(im, rects[0]).parts()]) 93 | 94 | 95 | def annotate_landmarks(im, landmarks): 96 | im = im.copy() 97 | for idx, point in enumerate(landmarks): 98 | pos = (point[0, 0], point[0, 1]) 99 | cv2.putText(im, str(idx), pos, 100 | fontFace=cv2.FONT_HERSHEY_SCRIPT_SIMPLEX, 101 | fontScale=0.4, 102 | color=(0, 0, 255)) 103 | cv2.circle(im, pos, 3, color=(0, 255, 255)) 104 | return im 105 | 106 | 107 | def draw_convex_hull(im, points, color): 108 | points = cv2.convexHull(points) 109 | cv2.fillConvexPoly(im, points, color=color) 110 | 111 | 112 | def get_face_mask(im, landmarks): 113 | im = numpy.zeros(im.shape[:2], dtype=numpy.float64) 114 | 115 | for group in OVERLAY_POINTS: 116 | draw_convex_hull(im, 117 | landmarks[group], 118 | color=1) 119 | 120 | im = numpy.array([im, im, im]).transpose((1, 2, 0)) 121 | 122 | im = (cv2.GaussianBlur(im, (FEATHER_AMOUNT, FEATHER_AMOUNT), 0) > 0) * 1.0 123 | im = cv2.GaussianBlur(im, (FEATHER_AMOUNT, FEATHER_AMOUNT), 0) 124 | 125 | return im 126 | 127 | 128 | def transformation_from_points(points1, points2): 129 | """ 130 | Return an affine transformation [s * R | T] such that: 131 | sum ||s*R*p1,i + T - p2,i||^2 132 | is minimized. 133 | """ 134 | # Solve the procrustes problem by subtracting centroids, scaling by the 135 | # standard deviation, and then using the SVD to calculate the rotation. 
See 136 | # the following for more details: 137 | # https://en.wikipedia.org/wiki/Orthogonal_Procrustes_problem 138 | 139 | points1 = points1.astype(numpy.float64) 140 | points2 = points2.astype(numpy.float64) 141 | 142 | c1 = numpy.mean(points1, axis=0) 143 | c2 = numpy.mean(points2, axis=0) 144 | points1 -= c1 145 | points2 -= c2 146 | 147 | s1 = numpy.std(points1) 148 | s2 = numpy.std(points2) 149 | points1 /= s1 150 | points2 /= s2 151 | 152 | U, S, Vt = numpy.linalg.svd(points1.T * points2) 153 | 154 | # The R we seek is in fact the transpose of the one given by U * Vt. This 155 | # is because the above formulation assumes the matrix goes on the right 156 | # (with row vectors) where as our solution requires the matrix to be on the 157 | # left (with column vectors). 158 | R = (U * Vt).T 159 | 160 | return numpy.vstack([numpy.hstack(((s2 / s1) * R, 161 | c2.T - (s2 / s1) * R * c1.T)), 162 | numpy.matrix([0., 0., 1.])]) 163 | 164 | 165 | def read_im_and_landmarks(fname): 166 | im = cv2.imread(fname, cv2.IMREAD_COLOR) 167 | im = cv2.resize(im, (im.shape[1] * SCALE_FACTOR, 168 | im.shape[0] * SCALE_FACTOR)) 169 | s = get_landmarks(im) 170 | 171 | return im, s 172 | 173 | 174 | def warp_im(im, M, dshape): 175 | output_im = numpy.zeros(dshape, dtype=im.dtype) 176 | cv2.warpAffine(im, 177 | M[:2], 178 | (dshape[1], dshape[0]), 179 | dst=output_im, 180 | borderMode=cv2.BORDER_TRANSPARENT, 181 | flags=cv2.WARP_INVERSE_MAP) 182 | return output_im 183 | 184 | 185 | def correct_colours(im1, im2, landmarks1): 186 | blur_amount = COLOUR_CORRECT_BLUR_FRAC * numpy.linalg.norm( 187 | numpy.mean(landmarks1[LEFT_EYE_POINTS], axis=0) - 188 | numpy.mean(landmarks1[RIGHT_EYE_POINTS], axis=0)) 189 | blur_amount = int(blur_amount) 190 | if blur_amount % 2 == 0: 191 | blur_amount += 1 192 | im1_blur = cv2.GaussianBlur(im1, (blur_amount, blur_amount), 0) 193 | im2_blur = cv2.GaussianBlur(im2, (blur_amount, blur_amount), 0) 194 | 195 | # Avoid divide-by-zero errors. 196 | im2_blur += (128 * (im2_blur <= 1.0)).astype(im2_blur.dtype) 197 | 198 | return (im2.astype(numpy.float64) * im1_blur.astype(numpy.float64) / 199 | im2_blur.astype(numpy.float64)) 200 | 201 | 202 | template_path = './dat/avg_face.jpg' # 用平均脸做底图,替换五官 203 | 204 | def faceswap(upload_path): 205 | # upload_path = "./dat/lss.jpg" 206 | 207 | im1, landmarks1 = read_im_and_landmarks(template_path) 208 | im2, landmarks2 = read_im_and_landmarks(upload_path) 209 | 210 | M = transformation_from_points(landmarks1[ALIGN_POINTS], 211 | landmarks2[ALIGN_POINTS]) 212 | 213 | mask = get_face_mask(im2, landmarks2) 214 | warped_mask = warp_im(mask, M, im1.shape) 215 | combined_mask = numpy.max([get_face_mask(im1, landmarks1), warped_mask], 216 | axis=0) 217 | 218 | warped_im2 = warp_im(im2, M, im1.shape) 219 | warped_corrected_im2 = correct_colours(im1, warped_im2, landmarks1) 220 | 221 | output_im = im1 * (1.0 - combined_mask) + warped_corrected_im2 * combined_mask 222 | return output_im 223 | # cv2.imwrite('output.jpg', output_im) 224 | 225 | if __name__ == '__main__': 226 | im = faceswap("./dat/lss.jpg") 227 | cv2.imshow("im", im * 255.) 
228 | cv2.waitKey() 229 | cv2.imwrite('output.jpg', im) -------------------------------------------------------------------------------- /imitator.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import torch 4 | import numpy as np 5 | import torch.nn as nn 6 | 7 | import utils 8 | import config 9 | 10 | 11 | ''' 12 | 模拟器:使用神经网络代替游戏引擎 13 | ''' 14 | 15 | class Imitator(nn.Module): 16 | def __init__(self, is_bias=False): 17 | super(Imitator, self).__init__() 18 | 19 | self.model = nn.Sequential( 20 | # 1. (batch, 512, 4, 4) 21 | nn.ConvTranspose2d(config.continuous_params_size, 512, kernel_size=4, bias=is_bias), 22 | nn.BatchNorm2d(512), 23 | nn.ReLU(), 24 | # 2. (batch, 512, 8, 8) 25 | nn.ConvTranspose2d(512, 512, kernel_size=4, stride=2, padding=1, bias=is_bias), 26 | nn.BatchNorm2d(512), 27 | nn.ReLU(), 28 | # 3. (batch, 512, 16, 16) 29 | nn.ConvTranspose2d(512, 512, kernel_size=4, stride=2, padding=1, bias=is_bias), 30 | nn.BatchNorm2d(512), 31 | nn.ReLU(), 32 | # 4. (batch, 256, 32, 32) 33 | nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1, bias=is_bias), 34 | nn.BatchNorm2d(256), 35 | nn.ReLU(), 36 | # 5. (batch, 128, 64, 64) 37 | nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1, bias=is_bias), 38 | nn.BatchNorm2d(128), 39 | nn.ReLU(), 40 | # 6. (batch, 64, 128, 128) 41 | nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1, bias=is_bias), 42 | nn.BatchNorm2d(64), 43 | nn.ReLU(), 44 | # 7. (batch, 64, 256, 256) 45 | nn.ConvTranspose2d(64, 64, kernel_size=4, stride=2, padding=1, bias=is_bias), 46 | nn.BatchNorm2d(64), 47 | nn.ReLU(), 48 | # 8. (batch, 3, 512, 512) 49 | nn.ConvTranspose2d(64, 3, kernel_size=4, stride=2, padding=1, bias=is_bias), 50 | # nn.Tanh(), # DCGAN和论文中用的tanh 51 | nn.Sigmoid() # 开源版本用的sigmoid 52 | ) 53 | 54 | self._initialize_weights() 55 | 56 | def _initialize_weights(self) -> None: 57 | for m in self.modules(): 58 | classname = m.__class__.__name__ 59 | if classname.find('Conv') != -1: 60 | nn.init.normal_(m.weight, 0.0, 0.02) 61 | elif classname.find('BatchNorm') != -1: 62 | nn.init.normal_(m.weight, 1.0, 0.02) 63 | nn.init.constant_(m.bias, 0) 64 | 65 | ''' 66 | :param params [batch, params_cnt] 67 | :return [batch, 3, 512, 512] 68 | ''' 69 | def forward(self, params): 70 | # batch = params.size(0) # 1 71 | # length = params.size(1) # 95 72 | # _params = params.reshape((batch, length, 1, 1)) # [1, 95, 1, 1] 73 | 74 | # 把连续参数的size从[batch, continuous_params_size]扩展成[batch, continuous_params_size, 1, 1] 75 | _params = params.unsqueeze(2).unsqueeze(3) 76 | return self.model(_params) 77 | 78 | -------------------------------------------------------------------------------- /lightcnn.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import time 4 | import math 5 | import numpy as np 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | from torchvision import transforms 10 | 11 | import config 12 | 13 | 14 | class mfm(nn.Module): 15 | def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, type=1): 16 | super(mfm, self).__init__() 17 | self.out_channels = out_channels 18 | if type == 1: 19 | self.filter = nn.Conv2d(in_channels, 2 * out_channels, kernel_size=kernel_size, stride=stride, 20 | padding=padding) 21 | else: 22 | self.filter = nn.Linear(in_channels, 2 * out_channels) 23 | 24 | def forward(self, x): 25 | x = self.filter(x) 26 | out 
= torch.split(x, self.out_channels, 1) 27 | return torch.max(out[0], out[1]) 28 | 29 | 30 | class group(nn.Module): 31 | def __init__(self, in_channels, out_channels, kernel_size, stride, padding): 32 | super(group, self).__init__() 33 | self.conv_a = mfm(in_channels, in_channels, 1, 1, 0) 34 | self.conv = mfm(in_channels, out_channels, kernel_size, stride, padding) 35 | 36 | def forward(self, x): 37 | x = self.conv_a(x) 38 | x = self.conv(x) 39 | return x 40 | 41 | 42 | class resblock(nn.Module): 43 | def __init__(self, in_channels, out_channels): 44 | super(resblock, self).__init__() 45 | self.conv1 = mfm(in_channels, out_channels, kernel_size=3, stride=1, padding=1) 46 | self.conv2 = mfm(in_channels, out_channels, kernel_size=3, stride=1, padding=1) 47 | 48 | def forward(self, x): 49 | res = x 50 | out = self.conv1(x) 51 | out = self.conv2(out) 52 | out = out + res 53 | return out 54 | 55 | 56 | class network_9layers(nn.Module): 57 | def __init__(self, num_classes=79077): 58 | super(network_9layers, self).__init__() 59 | self.features = nn.Sequential( 60 | mfm(1, 48, 5, 1, 2), 61 | nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=True), 62 | group(48, 96, 3, 1, 1), 63 | nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=True), 64 | group(96, 192, 3, 1, 1), 65 | nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=True), 66 | group(192, 128, 3, 1, 1), 67 | group(128, 128, 3, 1, 1), 68 | nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=True), 69 | ) 70 | self.fc1 = mfm(8 * 8 * 128, 256, type=0) 71 | self.fc2 = nn.Linear(256, num_classes) 72 | 73 | def forward(self, x): 74 | x = self.features(x) 75 | x = x.view(x.size(0), -1) 76 | x = self.fc1(x) 77 | x = F.dropout(x, training=self.training) 78 | out = self.fc2(x) 79 | return out, x 80 | 81 | 82 | class network_29layers(nn.Module): 83 | def __init__(self, block, layers, num_classes=79077): 84 | super(network_29layers, self).__init__() 85 | self.conv1 = mfm(1, 48, 5, 1, 2) 86 | self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=True) 87 | self.block1 = self._make_layer(block, layers[0], 48, 48) 88 | self.group1 = group(48, 96, 3, 1, 1) 89 | self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=True) 90 | self.block2 = self._make_layer(block, layers[1], 96, 96) 91 | self.group2 = group(96, 192, 3, 1, 1) 92 | self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=True) 93 | self.block3 = self._make_layer(block, layers[2], 192, 192) 94 | self.group3 = group(192, 128, 3, 1, 1) 95 | self.block4 = self._make_layer(block, layers[3], 128, 128) 96 | self.group4 = group(128, 128, 3, 1, 1) 97 | self.pool4 = nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=True) 98 | self.fc = mfm(8 * 8 * 128, 256, type=0) 99 | self.fc2 = nn.Linear(256, num_classes) 100 | 101 | def _make_layer(self, block, num_blocks, in_channels, out_channels): 102 | layers = [] 103 | for i in range(0, num_blocks): 104 | layers.append(block(in_channels, out_channels)) 105 | return nn.Sequential(*layers) 106 | 107 | def forward(self, x): 108 | x = self.conv1(x) 109 | x = self.pool1(x) 110 | 111 | x = self.block1(x) 112 | x = self.group1(x) 113 | x = self.pool2(x) 114 | 115 | x = self.block2(x) 116 | x = self.group2(x) 117 | x = self.pool3(x) 118 | 119 | x = self.block3(x) 120 | x = self.group3(x) 121 | x = self.block4(x) 122 | x = self.group4(x) 123 | x = self.pool4(x) 124 | 125 | x = x.view(x.size(0), -1) 126 | fc = self.fc(x) 127 | fc = F.dropout(fc, training=self.training) 128 | out = self.fc2(fc) 129 | return out, fc 130 | 131 | 132 | class 
network_29layers_v2(nn.Module): 133 | def __init__(self, block, layers, num_classes=79077): 134 | super(network_29layers_v2, self).__init__() 135 | self.conv1 = mfm(1, 48, 5, 1, 2) 136 | self.block1 = self._make_layer(block, layers[0], 48, 48) 137 | self.group1 = group(48, 96, 3, 1, 1) 138 | self.block2 = self._make_layer(block, layers[1], 96, 96) 139 | self.group2 = group(96, 192, 3, 1, 1) 140 | self.block3 = self._make_layer(block, layers[2], 192, 192) 141 | self.group3 = group(192, 128, 3, 1, 1) 142 | self.block4 = self._make_layer(block, layers[3], 128, 128) 143 | self.group4 = group(128, 128, 3, 1, 1) 144 | self.fc = nn.Linear(8 * 8 * 128, 256) 145 | self.fc2 = nn.Linear(256, num_classes, bias=False) 146 | 147 | def _make_layer(self, block, num_blocks, in_channels, out_channels): 148 | layers = [] 149 | for i in range(0, num_blocks): 150 | layers.append(block(in_channels, out_channels)) 151 | return nn.Sequential(*layers) 152 | 153 | def forward(self, x): 154 | x = self.conv1(x) 155 | x = F.max_pool2d(x, 2) + F.avg_pool2d(x, 2) 156 | 157 | x = self.block1(x) 158 | x = self.group1(x) 159 | x = F.max_pool2d(x, 2) + F.avg_pool2d(x, 2) 160 | 161 | x = self.block2(x) 162 | x = self.group2(x) 163 | x = F.max_pool2d(x, 2) + F.avg_pool2d(x, 2) 164 | 165 | x = self.block3(x) 166 | x = self.group3(x) 167 | x = self.block4(x) 168 | x = self.group4(x) 169 | x = F.max_pool2d(x, 2) + F.avg_pool2d(x, 2) 170 | 171 | x = x.view(x.size(0), -1) 172 | fc = self.fc(x) 173 | x = F.dropout(fc, training=self.training) 174 | out = self.fc2(x) 175 | return out, fc 176 | 177 | 178 | def LightCNN_9Layers(**kwargs): 179 | model = network_9layers(**kwargs) 180 | return model 181 | 182 | 183 | def LightCNN_29Layers(**kwargs): 184 | model = network_29layers(resblock, [1, 2, 3, 4], **kwargs) 185 | return model 186 | 187 | 188 | def LightCNN_29Layers_v2(**kwargs): 189 | model = network_29layers_v2(resblock, [1, 2, 3, 4], **kwargs) 190 | return model 191 | 192 | ''' 193 | 保存lightcnn提取的特征 194 | ''' 195 | def save_feature(save_path, img_name, features): 196 | img_path = os.path.join(save_path, img_name) 197 | img_dir = os.path.dirname(img_path) + '/' 198 | if not os.path.exists(img_dir): 199 | os.makedirs(img_dir) 200 | fname = os.path.splitext(img_path)[0] 201 | fname = fname + '.feat' 202 | fid = open(fname, 'wb') 203 | fid.write(features) 204 | fid.close() 205 | 206 | if __name__ == '__main__': 207 | model = LightCNN_29Layers_v2(num_classes=80013) 208 | model.eval() 209 | if torch.cuda.is_available(): 210 | model = torch.nn.DataParallel(model).cuda() 211 | else: 212 | model = torch.nn.DataParallel(model) 213 | checkpoint = torch.load("./checkpoint/LightCNN_29Layers_V2_checkpoint.pth.tar", map_location="cpu") 214 | model.load_state_dict(checkpoint['state_dict']) 215 | 216 | transform = transforms.Compose([transforms.ToTensor()]) 217 | img_path = "./dat/0020.png" 218 | img = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE) 219 | img = cv2.resize(img, (128, 128)) 220 | img = np.reshape(img, (128, 128, 1)) 221 | img = transform(img) 222 | 223 | input = torch.zeros(1, 1, 128, 128) 224 | input[0, :, :, :] = img 225 | 226 | start = time.time() 227 | if config.use_gpu: 228 | input = input.cuda() 229 | 230 | with torch.no_grad(): 231 | _, features = model(input) 232 | print(features.shape) 233 | 234 | fname = "./output/features.feat" 235 | fid = open(fname, 'wb') 236 | fid.write(features.data.cpu().numpy()[0]) 237 | fid.close() 238 | -------------------------------------------------------------------------------- /model_process.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | 3 | import config 4 | 5 | ''' 6 | 模型key的处理 7 | 多GPU训练的模型以module.开头,需改成单GPU模型的格式 8 | ''' 9 | 10 | if __name__ == '__main__': 11 | # checkpoint = torch.load(config.imitator_model, map_location=config.device) 12 | # for new_key, old_key in checkpoint.items(): 13 | # checkpoint[new_key] = checkpoint.pop(old_key) 14 | # torch.save(checkpoint, './checkpoint/epoch_950.pt') 15 | 16 | new_model = {k.replace('module.', ''): v for k, v in torch.load(config.imitator_model, map_location=config.device).items()} 17 | torch.save(new_model, './checkpoint/epoch_950.pt') -------------------------------------------------------------------------------- /myimitator.py: -------------------------------------------------------------------------------- 1 | from __future__ import (absolute_import, division, print_function, unicode_literals) 2 | 3 | import os 4 | import cv2 5 | import copy 6 | import math 7 | import json 8 | import torch 9 | import numpy as np 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | 13 | import utils 14 | import config 15 | 16 | ''' 17 | 自定义Imitator 18 | 1.conv,linear,embedding后加上sn 19 | 2.指定层加上self-attention 20 | 3.自定义bn 21 | ''' 22 | 23 | # 采用sn做 normalization 24 | def snconv2d(eps=1e-12, **kwargs): 25 | return nn.utils.spectral_norm(nn.Conv2d(**kwargs), eps=eps) 26 | 27 | def snlinear(eps=1e-12, **kwargs): 28 | return nn.utils.spectral_norm(nn.Linear(**kwargs), eps=eps) 29 | 30 | def sn_embedding(eps=1e-12, **kwargs): 31 | return nn.utils.spectral_norm(nn.Embedding(**kwargs), eps=eps) 32 | 33 | # self-attention层 34 | class SelfAttn(nn.Module): 35 | def __init__(self, in_channels, eps=1e-12): 36 | super(SelfAttn, self).__init__() 37 | self.in_channels = in_channels 38 | self.snconv1x1_theta = snconv2d(in_channels=in_channels, out_channels=in_channels//8, 39 | kernel_size=1, bias=False, eps=eps) 40 | self.snconv1x1_phi = snconv2d(in_channels=in_channels, out_channels=in_channels//8, 41 | kernel_size=1, bias=False, eps=eps) 42 | self.snconv1x1_g = snconv2d(in_channels=in_channels, out_channels=in_channels//2, 43 | kernel_size=1, bias=False, eps=eps) 44 | self.snconv1x1_o_conv = snconv2d(in_channels=in_channels//2, out_channels=in_channels, 45 | kernel_size=1, bias=False, eps=eps) 46 | self.maxpool = nn.MaxPool2d(2, stride=2, padding=0) 47 | self.softmax = nn.Softmax(dim=-1) 48 | self.gamma = nn.Parameter(torch.zeros(1)) 49 | 50 | def forward(self, x): 51 | _, ch, h, w = x.size() 52 | # Theta path 53 | theta = self.snconv1x1_theta(x) 54 | theta = theta.view(-1, ch//8, h*w) 55 | # Phi path 56 | phi = self.snconv1x1_phi(x) 57 | phi = self.maxpool(phi) 58 | phi = phi.view(-1, ch//8, h*w//4) 59 | # Attn map 60 | attn = torch.bmm(theta.permute(0, 2, 1), phi) 61 | attn = self.softmax(attn) 62 | # g path 63 | g = self.snconv1x1_g(x) 64 | g = self.maxpool(g) 65 | g = g.view(-1, ch//2, h*w//4) 66 | # Attn_g - o_conv 67 | attn_g = torch.bmm(g, attn.permute(0, 2, 1)) 68 | attn_g = attn_g.view(-1, ch//2, h, w) 69 | attn_g = self.snconv1x1_o_conv(attn_g) 70 | # Out 71 | out = x + self.gamma*attn_g 72 | return out 73 | 74 | # 自定义bn 75 | class BigGANBatchNorm(nn.Module): 76 | """ This is a batch norm module that can handle conditional input and can be provided with pre-computed 77 | activation means and variances for various truncation parameters. 78 | 79 | We cannot just rely on torch.batch_norm since it cannot handle 80 | batched weights (pytorch 1.0.1). 
We computate batch_norm our-self without updating running means and variances. 81 | If you want to train this model you should add running means and variance computation logic. 82 | """ 83 | def __init__(self, num_features, condition_vector_dim=None, n_stats=51, eps=1e-4, conditional=True): 84 | super(BigGANBatchNorm, self).__init__() 85 | self.num_features = num_features 86 | self.eps = eps 87 | self.conditional = conditional 88 | 89 | # We use pre-computed statistics for n_stats values of truncation between 0 and 1 90 | self.register_buffer('running_means', torch.zeros(n_stats, num_features)) 91 | self.register_buffer('running_vars', torch.ones(n_stats, num_features)) 92 | self.step_size = 1.0 / (n_stats - 1) 93 | 94 | if conditional: 95 | assert condition_vector_dim is not None 96 | self.scale = snlinear(in_features=condition_vector_dim, out_features=num_features, bias=False, eps=eps) 97 | self.offset = snlinear(in_features=condition_vector_dim, out_features=num_features, bias=False, eps=eps) 98 | else: 99 | self.weight = torch.nn.Parameter(torch.Tensor(num_features)) 100 | self.bias = torch.nn.Parameter(torch.Tensor(num_features)) 101 | 102 | def forward(self, x, truncation, condition_vector=None): 103 | # Retreive pre-computed statistics associated to this truncation 104 | coef, start_idx = math.modf(truncation / self.step_size) 105 | start_idx = int(start_idx) 106 | if coef != 0.0: # Interpolate 107 | running_mean = self.running_means[start_idx] * coef + self.running_means[start_idx + 1] * (1 - coef) 108 | running_var = self.running_vars[start_idx] * coef + self.running_vars[start_idx + 1] * (1 - coef) 109 | else: 110 | running_mean = self.running_means[start_idx] 111 | running_var = self.running_vars[start_idx] 112 | 113 | if self.conditional: 114 | running_mean = running_mean.unsqueeze(0).unsqueeze(-1).unsqueeze(-1) 115 | running_var = running_var.unsqueeze(0).unsqueeze(-1).unsqueeze(-1) 116 | 117 | weight = 1 + self.scale(condition_vector).unsqueeze(-1).unsqueeze(-1) 118 | bias = self.offset(condition_vector).unsqueeze(-1).unsqueeze(-1) 119 | 120 | out = (x - running_mean) / torch.sqrt(running_var + self.eps) * weight + bias 121 | else: 122 | out = F.batch_norm(x, running_mean, running_var, self.weight, self.bias, 123 | training=False, momentum=0.0, eps=self.eps) 124 | return out 125 | 126 | class GenBlock(nn.Module): 127 | def __init__(self, in_size, out_size, condition_vector_dim, reduction_factor=4, up_sample=False, 128 | n_stats=51, eps=1e-12): 129 | super(GenBlock, self).__init__() 130 | self.up_sample = up_sample 131 | self.drop_channels = (in_size != out_size) 132 | middle_size = in_size // reduction_factor 133 | 134 | self.bn_0 = BigGANBatchNorm(in_size, condition_vector_dim, n_stats=n_stats, eps=eps, conditional=True) 135 | self.conv_0 = snconv2d(in_channels=in_size, out_channels=middle_size, kernel_size=1, eps=eps) 136 | 137 | self.bn_1 = BigGANBatchNorm(middle_size, condition_vector_dim, n_stats=n_stats, eps=eps, conditional=True) 138 | self.conv_1 = snconv2d(in_channels=middle_size, out_channels=middle_size, kernel_size=3, padding=1, eps=eps) 139 | 140 | self.bn_2 = BigGANBatchNorm(middle_size, condition_vector_dim, n_stats=n_stats, eps=eps, conditional=True) 141 | self.conv_2 = snconv2d(in_channels=middle_size, out_channels=middle_size, kernel_size=3, padding=1, eps=eps) 142 | 143 | self.bn_3 = BigGANBatchNorm(middle_size, condition_vector_dim, n_stats=n_stats, eps=eps, conditional=True) 144 | self.conv_3 = snconv2d(in_channels=middle_size, out_channels=out_size, 
kernel_size=1, eps=eps) 145 | 146 | self.relu = nn.ReLU() 147 | 148 | def forward(self, x, cond_vector, truncation): 149 | x0 = x 150 | 151 | x = self.bn_0(x, truncation, cond_vector) 152 | x = self.relu(x) 153 | x = self.conv_0(x) 154 | 155 | x = self.bn_1(x, truncation, cond_vector) 156 | x = self.relu(x) 157 | if self.up_sample: 158 | x = F.interpolate(x, scale_factor=2, mode='nearest') 159 | x = self.conv_1(x) 160 | 161 | x = self.bn_2(x, truncation, cond_vector) 162 | x = self.relu(x) 163 | x = self.conv_2(x) 164 | 165 | x = self.bn_3(x, truncation, cond_vector) 166 | x = self.relu(x) 167 | x = self.conv_3(x) 168 | 169 | if self.drop_channels: 170 | new_channels = x0.shape[1] // 2 171 | x0 = x0[:, :new_channels, ...] 172 | if self.up_sample: 173 | x0 = F.interpolate(x0, scale_factor=2, mode='nearest') 174 | 175 | out = x + x0 176 | return out 177 | 178 | class MyImitator(nn.Module): 179 | def __init__(self): 180 | super(MyImitator, self).__init__() 181 | 182 | # 1.加载配置文件 183 | with open(config.config_jsonfile, "r", encoding='utf-8') as reader: 184 | text = reader.read() 185 | self.conf = BigGANConfig() 186 | for key, value in json.loads(text).items(): 187 | self.conf.__dict__[key] = value 188 | 189 | # 定义网络结构 190 | # self.embeddings = nn.Linear(config.num_classes, config.continuous_params_size, bias=False) 191 | 192 | ch = self.conf.channel_width 193 | condition_vector_dim = config.continuous_params_size 194 | 195 | self.gen_z = snlinear(in_features=condition_vector_dim, out_features=4*4*16*ch, eps=self.conf.eps) 196 | layers = [] 197 | for i, layer in enumerate(self.conf.layers): 198 | if i == self.conf.attention_layer_position: # 在指定层加上self-attention 199 | layers.append(SelfAttn(ch * layer[1], eps=self.conf.eps)) 200 | layers.append(GenBlock(ch * layer[1], 201 | ch * layer[2], 202 | condition_vector_dim, 203 | up_sample=layer[0], 204 | n_stats=self.conf.n_stats, 205 | eps=self.conf.eps)) 206 | self.layers = nn.ModuleList(layers) 207 | 208 | self.bn = BigGANBatchNorm(ch, n_stats=self.conf.n_stats, eps=self.conf.eps, conditional=False) 209 | self.relu = nn.ReLU() 210 | self.conv_to_rgb = snconv2d(in_channels=ch, out_channels=ch, kernel_size=3, padding=1, eps=self.conf.eps) 211 | self.tanh = nn.Tanh() 212 | 213 | def forward(self, cond_vector, truncation=0.4): 214 | # cond_vector = cond_vector.unsqueeze(2).unsqueeze(3) 215 | z = self.gen_z(cond_vector) # cond_cector [batch_size, config.continuous_params_size], z [1, 4*4*16*self.conf.channel_width] 216 | 217 | # We use this conversion step to be able to use TF weights: 218 | # TF convention on shape is [batch, height, width, channels] 219 | # PT convention on shape is [batch, channels, height, width] 220 | z = z.view(-1, 4, 4, 16 * self.conf.channel_width) # [batch_size, 4, 4, 2048] 221 | z = z.permute(0, 3, 1, 2).contiguous() # [batch_size, 2048, 4, 4] 222 | 223 | for i, layer in enumerate(self.layers): 224 | if isinstance(layer, GenBlock): 225 | z = layer(z, cond_vector, truncation) 226 | else: 227 | z = layer(z) 228 | 229 | z = self.bn(z, truncation) # [1, 128, 512, 512] 230 | z = self.relu(z) # [1, 128, 512, 512] 231 | z = self.conv_to_rgb(z) # [1, 128, 512, 512] 232 | z = z[:, :3, ...] # [1, 3, 512, 512] 233 | z = self.tanh(z) # [1, 3, 512, 512] 234 | return z 235 | 236 | ''' 237 | 自定义Imitator的config 238 | ''' 239 | class BigGANConfig(object): 240 | """ Configuration class to store the configuration of a `BigGAN`. 241 | Defaults are for the 128x128 model. 
242 | layers tuple are (up-sample in the layer ?, input channels, output channels) 243 | """ 244 | def __init__(self, 245 | output_dim=512, 246 | z_dim=512, 247 | class_embed_dim=512, 248 | channel_width=512, 249 | num_classes=1000, 250 | # (是否上采样,input_channels,output_channels) 251 | layers=[(False, 16, 16), 252 | (True, 16, 16), 253 | (False, 16, 16), 254 | (True, 16, 8), 255 | (False, 8, 8), 256 | (True, 8, 4), 257 | (False, 4, 4), 258 | (True, 4, 2), 259 | (False, 2, 2), 260 | (True, 2, 1)], 261 | attention_layer_position=8, 262 | eps=1e-4, 263 | n_stats=51): 264 | """Constructs BigGANConfig. """ 265 | self.output_dim = output_dim 266 | self.z_dim = z_dim 267 | self.class_embed_dim = class_embed_dim 268 | self.channel_width = channel_width 269 | self.num_classes = num_classes 270 | self.layers = layers 271 | self.attention_layer_position = attention_layer_position 272 | self.eps = eps 273 | self.n_stats = n_stats 274 | 275 | @classmethod 276 | def from_dict(cls, json_object): 277 | """Constructs a `BigGANConfig` from a Python dictionary of parameters.""" 278 | config = BigGANConfig() 279 | for key, value in json_object.items(): 280 | config.__dict__[key] = value 281 | return config 282 | 283 | @classmethod 284 | def from_json_file(cls, json_file): 285 | """Constructs a `BigGANConfig` from a json file of parameters.""" 286 | with open(json_file, "r", encoding='utf-8') as reader: 287 | text = reader.read() 288 | return cls.from_dict(json.loads(text)) 289 | 290 | def __repr__(self): 291 | return str(self.to_json_string()) 292 | 293 | def to_dict(self): 294 | """Serializes this instance to a Python dictionary.""" 295 | output = copy.deepcopy(self.__dict__) 296 | return output 297 | 298 | def to_json_string(self): 299 | """Serializes this instance to a JSON string.""" 300 | return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n" 301 | 302 | 303 | if __name__ == '__main__': 304 | # # d1 = nn.Linear(20, 40) 305 | # d1 = nn.Conv2d(10, 20, kernel_size=3, stride=1, padding=0, groups=10) 306 | # d2 = nn.utils.spectral_norm(d1) 307 | # print(d2) 308 | # print(d2.weight_u.size()) 309 | 310 | t_params = torch.full([1, config.continuous_params_size], 0.5, dtype=torch.float32) 311 | imitator = MyImitator() 312 | print(imitator) 313 | output = imitator(t_params) 314 | print(output.shape) # [1, 3, 512, 512] 315 | -------------------------------------------------------------------------------- /papers/1909.01064v1(Face-to-Parameter Translation ).pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/csnowhermit/face2parameter/c9c7f849e91a8b51c209b53c8b4d5c402ed791d6/papers/1909.01064v1(Face-to-Parameter Translation ).pdf -------------------------------------------------------------------------------- /papers/2003.05653(Towards High-Fidelity 3D Face Reconstruction).pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/csnowhermit/face2parameter/c9c7f849e91a8b51c209b53c8b4d5c402ed791d6/papers/2003.05653(Towards High-Fidelity 3D Face Reconstruction).pdf -------------------------------------------------------------------------------- /papers/2008.07132v1(Fast and Robust Face-to-Parameter Translation).pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/csnowhermit/face2parameter/c9c7f849e91a8b51c209b53c8b4d5c402ed791d6/papers/2008.07132v1(Fast and Robust Face-to-Parameter Translation).pdf 
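The `layers` tuples documented in `BigGANConfig` above fully determine MyImitator's feature-map schedule: each `GenBlock` maps `channel_width * in_mult` channels to `channel_width * out_mult` channels and doubles the spatial size whenever its up-sample flag is set. The following standalone sketch (illustrative only, not one of the repository files) walks the hard-coded defaults, which stop at 128x128; the `checkpoint/myimitator-512.json` file loaded in `MyImitator.__init__` presumably overrides them, since the shape comments in `forward` show a 512x512 output with 128 channels before `conv_to_rgb`.

# Illustrative sketch: print the channel/resolution schedule implied by the
# default BigGANConfig.layers plan. Only the defaults shown above are taken
# from the repository; the printed schedule is for explanation only.
channel_width = 512
layers = [(False, 16, 16), (True, 16, 16), (False, 16, 16), (True, 16, 8),
          (False, 8, 8), (True, 8, 4), (False, 4, 4), (True, 4, 2),
          (False, 2, 2), (True, 2, 1)]

res = 4  # gen_z reshapes the latent into [batch, 16 * channel_width, 4, 4]
print("input    : %5d ch @ %dx%d" % (16 * channel_width, res, res))
for i, (up_sample, in_mult, out_mult) in enumerate(layers):
    if up_sample:
        res *= 2  # GenBlock upsamples by 2 (nearest neighbour) before conv_1
    print("block %2d : %5d -> %5d ch @ %dx%d" % (i, in_mult * channel_width, out_mult * channel_width, res, res))
# conv_to_rgb keeps channel_width channels; forward() then slices the first
# 3 channels and applies tanh to obtain the RGB image.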
-------------------------------------------------------------------------------- /random_gen_image.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import torch 3 | import random 4 | 5 | import config 6 | from imitator import Imitator 7 | 8 | 9 | if __name__ == '__main__': 10 | imitator = Imitator() 11 | imitator_model = torch.load(config.imitator_model, map_location=torch.device('cpu')) 12 | imitator.load_state_dict(imitator_model) # 这里加载已经处理过的参数 13 | 14 | for i in range(10): 15 | if random.randint(1, 10) % 2 == 0: 16 | t_params = torch.rand((1, config.continuous_params_size), dtype=torch.float32) 17 | else: 18 | # t_params = torch.randn((1, config.continuous_params_size), dtype=torch.float32) 19 | t_params = torch.normal(0.5, 1, (1, config.continuous_params_size)) 20 | print("2.1.", t_params) 21 | t_params.data = t_params.data.clamp(0., 1.) 22 | print("2.2.", t_params) 23 | # t_params = torch.rand((1, config.continuous_params_size), dtype=torch.float32) 24 | 25 | y_ = imitator(t_params) # [1, 3, 512, 512], [batch_size, c, w, h] 26 | tmp = y_.detach().cpu().numpy()[0] 27 | tmp = tmp.transpose(2, 1, 0) 28 | tmp = tmp * 255.0 29 | # cv2.imshow("y_", tmp) 30 | # cv2.waitKey() 31 | print(type(tmp), tmp.shape) 32 | cv2.imwrite("./dat/gen_%d.jpg" % i, tmp) 33 | print("已保存:", i) 34 | 35 | -------------------------------------------------------------------------------- /resnet.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # -*- encoding: utf-8 -*- 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | import torch.utils.model_zoo as modelzoo 8 | 9 | import config 10 | 11 | # from modules.bn import InPlaceABNSync as BatchNorm2d 12 | 13 | # resnet18_url = 'https://download.pytorch.org/models/resnet18-5c106cde.pth' 14 | resnet18_url = "./checkpoint/resnet18-5c106cde.pth" 15 | 16 | def conv3x3(in_planes, out_planes, stride=1): 17 | """3x3 convolution with padding""" 18 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False) 19 | 20 | 21 | class BasicBlock(nn.Module): 22 | def __init__(self, in_chan, out_chan, stride=1): 23 | super(BasicBlock, self).__init__() 24 | self.conv1 = conv3x3(in_chan, out_chan, stride) 25 | self.bn1 = nn.BatchNorm2d(out_chan) 26 | self.conv2 = conv3x3(out_chan, out_chan) 27 | self.bn2 = nn.BatchNorm2d(out_chan) 28 | self.relu = nn.ReLU(inplace=True) 29 | self.downsample = None 30 | if in_chan != out_chan or stride != 1: 31 | self.downsample = nn.Sequential(nn.Conv2d(in_chan, out_chan, kernel_size=1, stride=stride, bias=False), 32 | nn.BatchNorm2d(out_chan), ) 33 | 34 | def forward(self, x): 35 | residual = self.conv1(x) 36 | residual = F.relu(self.bn1(residual)) 37 | residual = self.conv2(residual) 38 | residual = self.bn2(residual) 39 | 40 | shortcut = x 41 | if self.downsample is not None: 42 | shortcut = self.downsample(x) 43 | 44 | out = shortcut + residual 45 | out = self.relu(out) 46 | return out 47 | 48 | 49 | def create_layer_basic(in_chan, out_chan, bnum, stride=1): 50 | layers = [BasicBlock(in_chan, out_chan, stride=stride)] 51 | for i in range(bnum - 1): 52 | layers.append(BasicBlock(out_chan, out_chan, stride=1)) 53 | return nn.Sequential(*layers) 54 | 55 | 56 | class Resnet18(nn.Module): 57 | def __init__(self): 58 | super(Resnet18, self).__init__() 59 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False) 60 | self.bn1 = nn.BatchNorm2d(64) 61 | 
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 62 | self.layer1 = create_layer_basic(64, 64, bnum=2, stride=1) 63 | self.layer2 = create_layer_basic(64, 128, bnum=2, stride=2) 64 | self.layer3 = create_layer_basic(128, 256, bnum=2, stride=2) 65 | self.layer4 = create_layer_basic(256, 512, bnum=2, stride=2) 66 | self.init_weight() 67 | 68 | def forward(self, x): 69 | x = self.conv1(x) 70 | x = F.relu(self.bn1(x)) 71 | x = self.maxpool(x) 72 | 73 | x = self.layer1(x) 74 | feat8 = self.layer2(x) # 1/8 75 | feat16 = self.layer3(feat8) # 1/16 76 | feat32 = self.layer4(feat16) # 1/32 77 | return feat8, feat16, feat32 78 | 79 | def init_weight(self): 80 | if resnet18_url.startswith("http"): 81 | state_dict = modelzoo.load_url(resnet18_url) 82 | else: 83 | if config.use_gpu: 84 | state_dict = torch.load(resnet18_url) 85 | else: 86 | state_dict = torch.load(resnet18_url, map_location=torch.device('cpu')) 87 | self_state_dict = self.state_dict() 88 | for k, v in state_dict.items(): 89 | if 'fc' in k: continue 90 | self_state_dict.update({k: v}) 91 | self.load_state_dict(self_state_dict) 92 | 93 | def get_params(self): 94 | wd_params, nowd_params = [], [] 95 | for name, module in self.named_modules(): 96 | if isinstance(module, (nn.Linear, nn.Conv2d)): 97 | wd_params.append(module.weight) 98 | if not module.bias is None: 99 | nowd_params.append(module.bias) 100 | elif isinstance(module, nn.BatchNorm2d): 101 | nowd_params += list(module.parameters()) 102 | return wd_params, nowd_params 103 | 104 | 105 | if __name__ == "__main__": 106 | net = Resnet18() 107 | x = torch.randn(16, 3, 224, 224) 108 | out = net(x) 109 | print(out[0].size(), out[1].size(), out[2].size()) 110 | net.get_params() 111 | -------------------------------------------------------------------------------- /tools/demo.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.autograd import Variable 4 | from torch.nn import functional as F 5 | import torch.utils.data 6 | import numpy as np 7 | from scipy.stats import entropy 8 | import torchvision.datasets as dset 9 | import torchvision.transforms as transforms 10 | import os 11 | import dataloader_own 12 | from scipy import linalg 13 | 14 | from inception import InceptionV3 15 | 16 | # we should use same mean and std for inception v3 model in training and testing process 17 | # reference web page: https://pytorch.org/hub/pytorch_vision_inception_v3/ 18 | mean_inception = [0.485, 0.456, 0.406] 19 | std_inception = [0.229, 0.224, 0.225] 20 | 21 | def compute_FID(img1, img2, batch_size=1): 22 | device = torch.device("cuda:0") # you can change the index of cuda 23 | 24 | N1 = len(img1) 25 | N2 = len(img2) 26 | n_act = 2048 # the number of final layer's dimension 27 | 28 | # Set up dataloaders 29 | dataloader1 = torch.utils.data.DataLoader(img1, batch_size=batch_size) 30 | dataloader2 = torch.utils.data.DataLoader(img2, batch_size=batch_size) 31 | 32 | # Load inception model 33 | # inception_model = inception_v3(pretrained=True, transform_input=False).to(device) 34 | 35 | block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[n_act] 36 | inception_model = InceptionV3([block_idx]).to(device) 37 | inception_model.eval() 38 | 39 | # get the activations 40 | def get_activations(x): 41 | x = inception_model(x)[0] 42 | return x.cpu().data.numpy().reshape(batch_size, -1) 43 | 44 | act1 = np.zeros((N1, n_act)) 45 | act2 = np.zeros((N2, n_act)) 46 | 47 | data = [dataloader1, dataloader2] 48 | act = [act1, 
act2] 49 | for n, loader in enumerate(data): 50 | for i, batch in enumerate(loader, 0): 51 | batch = batch[0].to(device) 52 | batch_size_i = batch.size()[0] 53 | activation = get_activations(batch) 54 | 55 | act[n][i * batch_size:i * batch_size + batch_size_i] = activation 56 | 57 | # compute the activation statistics: mean and covariance 58 | def compute_act_mean_std(act): 59 | mu = np.mean(act, axis=0) 60 | sigma = np.cov(act, rowvar=False) 61 | return mu, sigma 62 | mu_act1, sigma_act1 = compute_act_mean_std(act1) 63 | mu_act2, sigma_act2 = compute_act_mean_std(act2) 64 | 65 | # compute FID 66 | def _compute_FID(mu1, mu2, sigma1, sigma2, eps=1e-6): 67 | mu1 = np.atleast_1d(mu1) 68 | mu2 = np.atleast_1d(mu2) 69 | sigma1 = np.atleast_2d(sigma1) 70 | sigma2 = np.atleast_2d(sigma2) 71 | 72 | diff = mu1 - mu2 73 | 74 | covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False) 75 | if not np.isfinite(covmean).all(): 76 | msg = ('fid calculation produces singular product; ' 77 | 'adding %s to diagonal of cov estimates') % eps 78 | print(msg) 79 | offset = np.eye(sigma1.shape[0]) * eps 80 | covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset)) 81 | 82 | # Numerical error might give slight imaginary component 83 | if np.iscomplexobj(covmean): 84 | if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3): 85 | m = np.max(np.abs(covmean.imag)) 86 | raise ValueError('Imaginary component {}'.format(m)) 87 | covmean = covmean.real 88 | 89 | tr_covmean = np.trace(covmean) 90 | 91 | FID = diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * tr_covmean 92 | 93 | return FID 94 | 95 | FID = _compute_FID(mu_act1, mu_act2, sigma_act1, sigma_act2) 96 | return FID 97 | 98 | 99 | # main function to compute FID 100 | data_root = os.path.join('/PATH/TO/YOUR/IMAGE1') 101 | my_dataset_fakeB = dataloader_own.image_loader(data_root, batch_size=64, img_size=299, resize=True, rotation=False, normalize=[mean_inception, std_inception]) 102 | data_root = os.path.join('/PATH/TO/YOUR/IMAGE2') 103 | my_dataset_realB = dataloader_own.image_loader(data_root, batch_size=64, img_size=299, resize=True, rotation=False, normalize=[mean_inception, std_inception]) 104 | 105 | FID = compute_FID(my_dataset_fakeB, my_dataset_realB) 106 | print(FID) -------------------------------------------------------------------------------- /tools/face_align.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import dlib 3 | import numpy 4 | import sys 5 | import matplotlib.pyplot as plt 6 | 7 | PREDICTOR_PATH = r"../checkpoint/shape_predictor_68_face_landmarks.dat" # model file for the 68 facial landmarks 8 | SCALE_FACTOR = 1 # image scaling factor 9 | FEATHER_AMOUNT = 15 # feathering range for the mask border; larger values feather more; must be odd, never even 10 | 11 | # the 68 landmark points 12 | FACE_POINTS = list(range(17, 68)) # face 13 | MOUTH_POINTS = list(range(48, 61)) # mouth 14 | RIGHT_BROW_POINTS = list(range(17, 22)) # right eyebrow 15 | LEFT_BROW_POINTS = list(range(22, 27)) # left eyebrow 16 | RIGHT_EYE_POINTS = list(range(36, 42)) # right eye 17 | LEFT_EYE_POINTS = list(range(42, 48)) # left eye 18 | NOSE_POINTS = list(range(27, 35)) # nose 19 | JAW_POINTS = list(range(0, 17)) # jaw 20 | 21 | # Facial features of the second face selected for overlaying onto the first face. 22 | # The points cover both eyes, the eyebrows, the nose and the mouth. 23 | # (Would using more points introduce any interference?) 24 | ALIGN_POINTS = (FACE_POINTS + LEFT_BROW_POINTS + RIGHT_EYE_POINTS + LEFT_EYE_POINTS + 25 | RIGHT_BROW_POINTS + NOSE_POINTS + MOUTH_POINTS + JAW_POINTS) 26 | 27 | # Points from the second image to overlay on the first. The convex hull of each 28 | # element will be overlaid.
29 | OVERLAY_POINTS = [ 30 | LEFT_EYE_POINTS + RIGHT_EYE_POINTS + LEFT_BROW_POINTS + RIGHT_BROW_POINTS, 31 | NOSE_POINTS + MOUTH_POINTS, 32 | ] 33 | # 眼睛 ,眉毛 2 * 22 34 | # 鼻子,嘴巴 分开来 35 | 36 | # 定义用于颜色校正的模糊量,作为瞳孔距离的系数 37 | COLOUR_CORRECT_BLUR_FRAC = 0.6 38 | 39 | # 实例化脸部检测器 40 | detector = dlib.get_frontal_face_detector() 41 | # 加载训练模型 42 | # 并实例化特征提取器 43 | predictor = dlib.shape_predictor(PREDICTOR_PATH) 44 | 45 | 46 | # 定义了两个类处理意外 47 | class TooManyFaces(Exception): 48 | pass 49 | 50 | 51 | class NoFaces(Exception): 52 | pass 53 | 54 | 55 | def get_landmarks(im): 56 | ''' 57 | 通过predictor 拿到68 landmarks 58 | ''' 59 | rects = detector(im, 1) 60 | 61 | if len(rects) > 1: 62 | raise TooManyFaces 63 | if len(rects) == 0: 64 | raise NoFaces 65 | 66 | return numpy.matrix([[p.x, p.y] for p in predictor(im, rects[0]).parts()]) # 68*2的矩阵 67 | 68 | 69 | def annotate_landmarks(im, landmarks): 70 | ''' 71 | 人脸关键点,画图函数 72 | ''' 73 | im = im.copy() 74 | for idx, point in enumerate(landmarks): 75 | pos = (point[0, 0], point[0, 1]) 76 | cv2.putText(im, str(idx), pos, 77 | fontFace=cv2.FONT_HERSHEY_SCRIPT_SIMPLEX, 78 | fontScale=0.4, 79 | color=(0, 0, 255)) 80 | cv2.circle(im, pos, 3, color=(0, 255, 255)) 81 | return im 82 | 83 | 84 | def draw_convex_hull(im, points, color): 85 | ''' 86 | # 绘制凸多边形 计算凸包 87 | ''' 88 | points = cv2.convexHull(points) 89 | cv2.fillConvexPoly(im, points, color=color) 90 | 91 | 92 | def get_face_mask(im, landmarks): 93 | '''获取面部特征部分(眉毛、眼睛、鼻子以及嘴巴)的图像掩码。 94 | 图像掩码作用于原图之后,原图中对应掩码部分为白色的部分才能显示出来,黑色的部分则不予显示,因此通过图像掩码我们就能实现对图像“裁剪”。 95 | 效果参考:https://dn-anything-about-doc.qbox.me/document-uid242676labid2260timestamp1477921310170.png/wm 96 | get_face_mask()的定义是为一张图像和一个标记矩阵生成一个遮罩,它画出了两个白色的凸多边形:一个是眼睛周围的区域, 97 | 一个是鼻子和嘴部周围的区域。之后它由11个(FEATHER_AMOUNT)像素向遮罩的边缘外部羽化扩展,可以帮助隐藏任何不连续的区域。 98 | ''' 99 | im = numpy.zeros(im.shape[:2], dtype=numpy.float64) 100 | 101 | for group in OVERLAY_POINTS: 102 | draw_convex_hull(im, 103 | landmarks[group], 104 | color=1) 105 | 106 | im = numpy.array([im, im, im]).transpose((1, 2, 0)) 107 | 108 | im = (cv2.GaussianBlur(im, (FEATHER_AMOUNT, FEATHER_AMOUNT), 0) > 0) * 1.0 109 | im = cv2.GaussianBlur(im, (FEATHER_AMOUNT, FEATHER_AMOUNT), 0) 110 | 111 | return im 112 | 113 | # 返回一个仿射变换 114 | def transformation_from_points(points1, points2): 115 | """ 116 | Return an affine transformation [s * R | T] such that: 117 | sum ||s*R*p1,i + T - p2,i||^2 118 | is minimized. 119 | """ 120 | # Solve the procrustes problem by subtracting centroids, scaling by the 121 | # standard deviation, and then using the SVD to calculate the rotation. See 122 | # the following for more details: 123 | # https://en.wikipedia.org/wiki/Orthogonal_Procrustes_problem 124 | 125 | points1 = points1.astype(numpy.float64) # 人脸的指定关键点 126 | points2 = points2.astype(numpy.float64) 127 | 128 | # 数据标准化:先减去均值,再除以std,做成均值为0方差为1的序列 129 | # 每张脸各自做各自的标准化 130 | c1 = numpy.mean(points1, axis=0) # 分别算x和y的均值 131 | c2 = numpy.mean(points2, axis=0) 132 | points1 -= c1 # 浮动于均值的部分,[43, 2] 133 | points2 -= c2 134 | 135 | s1 = numpy.std(points1) 136 | s2 = numpy.std(points2) 137 | points1 /= s1 # 138 | points2 /= s2 139 | 140 | U, S, Vt = numpy.linalg.svd(points1.T * points2) # 141 | 142 | # The R we seek is in fact the transpose of the one given by U * Vt. This 143 | # is because the above formulation assumes the matrix goes on the right 144 | # (with row vectors) where as our solution requires the matrix to be on the 145 | # left (with column vectors). 
146 | R = (U * Vt).T # [2, 2] 147 | 148 | return numpy.vstack([numpy.hstack(((s2 / s1) * R, c2.T - (s2 / s1) * R * c1.T)), 149 | numpy.matrix([0., 0., 1.])]) 150 | 151 | 152 | def read_im_and_landmarks(fname): 153 | im = cv2.imread(fname, cv2.IMREAD_COLOR) 154 | im = cv2.resize(im, (im.shape[1] * SCALE_FACTOR, 155 | im.shape[0] * SCALE_FACTOR)) 156 | s = get_landmarks(im) # [68, 2] 157 | 158 | return im, s 159 | 160 | 161 | def warp_im(im, M, dshape): 162 | ''' 163 | 由 get_face_mask 获得的图像掩码还不能直接使用,因为一般来讲用户提供的两张图像的分辨率大小很可能不一样,而且即便分辨率一样, 164 | 图像中的人脸由于拍摄角度和距离等原因也会呈现出不同的大小以及角度,所以如果不能只是简单地把第二个人的面部特征抠下来直接放在第一个人脸上, 165 | 我们还需要根据两者计算所得的面部特征区域进行匹配变换,使得二者的面部特征尽可能重合。 166 | 167 | 仿射函数,warpAffine,能对图像进行几何变换 168 | 三个主要参数,第一个输入图像,第二个变换矩阵 np.float32 类型,第三个变换之后图像的宽高 169 | 170 | 对齐主要函数 171 | ''' 172 | output_im = numpy.zeros(dshape, dtype=im.dtype) # [512, 512, 3] 173 | cv2.warpAffine(im, 174 | M[:2], 175 | (dshape[1], dshape[0]), 176 | dst=output_im, 177 | borderMode=cv2.BORDER_TRANSPARENT, 178 | flags=cv2.WARP_INVERSE_MAP) 179 | return output_im 180 | 181 | 182 | def correct_colours(im1, im2, landmarks1): 183 | ''' 184 | 修改皮肤颜色,使两张图片在拼接时候显得更加自然。 185 | ''' 186 | blur_amount = COLOUR_CORRECT_BLUR_FRAC * numpy.linalg.norm( 187 | numpy.mean(landmarks1[LEFT_EYE_POINTS], axis=0) - 188 | numpy.mean(landmarks1[RIGHT_EYE_POINTS], axis=0)) 189 | blur_amount = int(blur_amount) 190 | if blur_amount % 2 == 0: 191 | blur_amount += 1 192 | im1_blur = cv2.GaussianBlur(im1, (blur_amount, blur_amount), 0) 193 | im2_blur = cv2.GaussianBlur(im2, (blur_amount, blur_amount), 0) 194 | 195 | # Avoid divide-by-zero errors. 196 | im2_blur += (128 * (im2_blur <= 1.0)).astype(im2_blur.dtype) 197 | 198 | return (im2.astype(numpy.float64) * im1_blur.astype(numpy.float64) / 199 | im2_blur.astype(numpy.float64)) 200 | 201 | 202 | # 换脸函数 203 | def Switch_face(Base_path, cover_path): 204 | im1, landmarks1 = read_im_and_landmarks(Base_path) # 底图 205 | im2, landmarks2 = read_im_and_landmarks(cover_path) # 贴上来的图 206 | 207 | if len(landmarks1) == 0 & len(landmarks2) == 0: 208 | raise RuntimeError("Faces detected is no face!") 209 | if len(landmarks1) > 1 & len(landmarks2) > 1: 210 | raise RuntimeError("Faces detected is more than 1!") 211 | 212 | # landmarks1[ALIGN_POINTS]为人脸的的指定关键点 213 | M = transformation_from_points(landmarks1[ALIGN_POINTS], 214 | landmarks2[ALIGN_POINTS]) 215 | mask = get_face_mask(im2, landmarks2) 216 | warped_mask = warp_im(mask, M, im1.shape) 217 | combined_mask = numpy.max([get_face_mask(im1, landmarks1), warped_mask], 218 | axis=0) 219 | warped_im2 = warp_im(im2, M, im1.shape) 220 | warped_corrected_im2 = correct_colours(im1, warped_im2, landmarks1) 221 | 222 | output_im = im1 * (1.0 - combined_mask) + warped_corrected_im2 * combined_mask 223 | return output_im 224 | 225 | 226 | # 人脸对齐函数 227 | def face_Align(Base_path, cover_path): 228 | im1, landmarks1 = read_im_and_landmarks(Base_path) # 底图 229 | im2, landmarks2 = read_im_and_landmarks(cover_path) # 贴上来的图 230 | 231 | # 得到仿射变换矩阵 232 | M = transformation_from_points(landmarks1[ALIGN_POINTS], 233 | landmarks2[ALIGN_POINTS]) 234 | warped_im2 = warp_im(im2, M, im1.shape) 235 | return warped_im2 236 | 237 | FEATHER_AMOUNT = 19 238 | 239 | template_img = '../dat/avg_face.jpg' # 模板 240 | process_img = '../dat/0020.png' 241 | warped_mask = face_Align(template_img, process_img) 242 | cv2.imwrite("../dat/result_0020.jpg", warped_mask) 243 | 244 | # plt.subplot(111) 245 | # plt.imshow(warped_mask) # 数据展示 246 | # plt.show() 
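The matrix returned by `transformation_from_points` above is a 3x3 homogeneous similarity transform of the form [s*R | T; 0 0 1], where s is a uniform scale, R a 2x2 rotation and T a translation. The following standalone numpy sketch (illustrative only, using synthetic points in place of dlib landmarks) shows how that structure can be checked and decomposed:

import numpy

# Build a known similarity transform and confirm that a matrix of the form
# [s*R | T; 0 0 1] can be decomposed back into scale, rotation and translation.
angle = numpy.deg2rad(10.0)
scale = 1.3
t = numpy.array([5.0, -2.0])
R = numpy.array([[numpy.cos(angle), -numpy.sin(angle)],
                 [numpy.sin(angle),  numpy.cos(angle)]])

pts1 = numpy.random.RandomState(0).rand(68, 2) * 100   # stand-in for face-1 landmarks
pts2 = pts1 @ (scale * R).T + t                        # the same landmarks after s*R + T

M = numpy.vstack([numpy.hstack([scale * R, t.reshape(2, 1)]),
                  numpy.array([[0.0, 0.0, 1.0]])])     # same layout as transformation_from_points

print(numpy.sqrt(numpy.linalg.det(M[:2, :2])))         # recovered scale, ~1.3 (det(s*R) = s**2)
print(numpy.degrees(numpy.arctan2(M[1, 0], M[0, 0])))  # recovered rotation, ~10 degrees

pts1_h = numpy.hstack([pts1, numpy.ones((68, 1))])     # homogeneous coordinates
print(numpy.allclose((pts1_h @ M.T)[:, :2], pts2))     # True: M maps face-1 points onto face-2 points
# warp_im applies the same kind of matrix to whole images via cv2.warpAffine
# with M[:2] and the WARP_INVERSE_MAP flag.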
-------------------------------------------------------------------------------- /tools/fid.py: -------------------------------------------------------------------------------- 1 | import os 2 | from torch.utils.data import Dataset, DataLoader 3 | from torchvision.models.inception import Inception3 4 | from torchvision.models.utils import load_state_dict_from_url 5 | 6 | mean_inception = [0.485, 0.456, 0.406] 7 | std_inception = [0.229, 0.224, 0.225] 8 | 9 | class FID_Dataset(Dataset): 10 | def __init__(self, imgpath, transform): 11 | self.imgpath = imgpath 12 | self.fileList = [os.path.join(imgpath, file) for file in os.listdir(imgpath)] 13 | self.transform = transform 14 | 15 | 16 | ''' 17 | Compute the FID metric 18 | ''' 19 | def compute_FID(): 20 | pass 21 | 22 | 23 | if __name__ == '__main__': 24 | model = Inception3() 25 | print(model) 26 | state_dict = load_state_dict_from_url("https://download.pytorch.org/models/inception_v3_google-1a9a5a14.pth", progress=True) 27 | model.load_state_dict(state_dict) 28 | -------------------------------------------------------------------------------- /tools/lsgan.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import matplotlib.pyplot as plt 4 | 5 | import torch 6 | import torch.nn as nn 7 | from torch.autograd import Variable 8 | from torch.utils.data import DataLoader 9 | 10 | import torchvision.transforms as transforms 11 | from torchvision.utils import save_image 12 | from torchvision import datasets 13 | 14 | ''' 15 | LSGAN: least-squares generative adversarial loss. 16 | Common problems with GANs: low quality of the generated images; unstable training. 17 | The conventional approach: 18 | 1. Uses a cross-entropy loss (it only looks at True/False). Once a fake image is classified as True by the discriminator, the generator stops optimizing it, even if the image is still far from the discriminator's decision boundary, i.e. far from the real data. 19 | This means the generated images are not of high quality. Why does the generator stop optimizing them? Because its goal of fooling the discriminator has already been met, so the cross-entropy loss is already small. 20 | The least-squares loss, in contrast, requires that samples which already fool the discriminator but lie far from the decision boundary still be pulled towards it. 21 | 2. The sigmoid saturates when its input is very large or very small, which causes vanishing gradients; the least-squares loss has zero gradient only at x = 1. 22 | LSGAN's approach: keep the same training loop as a conventional GAN and simply replace the adversarial loss with a least-squares (MSE) loss, as done below: 23 | adversarial_loss = torch.nn.MSELoss(); with the conventional torch.nn.BCELoss(), if you hit "RuntimeError: all elements of input should be between 0 and 1", you can use instead: 24 | adversarial_loss = torch.nn.BCEWithLogitsLoss() 25 | ''' 26 | 27 | # Configuration 28 | epochs=200 29 | batch_size=64 30 | learning_rate=0.0002 31 | latent_dim=100 # generate images starting from a 100-dim latent vector 32 | image_size=32 33 | sample_interval=100 # save samples every 100 batches 34 | save_sample_path = "../output/images/" # where sample images are saved 35 | if os.path.exists("../output/images/") is False: 36 | os.makedirs("../output/images/") 37 | 38 | cuda = True if torch.cuda.is_available() else False 39 | 40 | def weights_init_normal(m): 41 | classname = m.__class__.__name__ 42 | if classname.find("Conv") != -1: 43 | torch.nn.init.normal_(m.weight.data, 0.0, 0.02) 44 | elif classname.find("BatchNorm") != -1: 45 | torch.nn.init.normal_(m.weight.data, 1.0, 0.02) 46 | torch.nn.init.constant_(m.bias.data, 0.0) 47 | 48 | # Generator 49 | class Generator(nn.Module): 50 | def __init__(self): 51 | super(Generator, self).__init__() 52 | 53 | self.init_size = image_size // 4 54 | self.l1 = nn.Sequential(nn.Linear(latent_dim, 128 * self.init_size ** 2)) 55 | 56 | self.conv_blocks = nn.Sequential( 57 | nn.Upsample(scale_factor=2), 58 | nn.Conv2d(128, 128, 3, stride=1, padding=1), 59 | nn.BatchNorm2d(128, 0.8), 60 | nn.LeakyReLU(0.2, inplace=True), 61 | nn.Upsample(scale_factor=2), 62 | nn.Conv2d(128, 64, 3, stride=1, padding=1), 63 | nn.BatchNorm2d(64, 0.8), 64 | nn.LeakyReLU(0.2, inplace=True), 65 | nn.Conv2d(64, 1, 3, stride=1, padding=1), # final conv outputs a 1-channel image 66 | nn.Tanh(), 67 | ) 68 | 69 | def forward(self, z): 70 | out = self.l1(z) 71 | out = 
out.view(out.shape[0], 128, self.init_size, self.init_size) 72 | img = self.conv_blocks(out) 73 | return img 74 | 75 | 76 | class Discriminator(nn.Module): 77 | def __init__(self): 78 | super(Discriminator, self).__init__() 79 | 80 | def discriminator_block(in_filters, out_filters, bn=True): 81 | block = [nn.Conv2d(in_filters, out_filters, 3, 2, 1), nn.LeakyReLU(0.2, inplace=True), nn.Dropout2d(0.25)] 82 | if bn: 83 | block.append(nn.BatchNorm2d(out_filters, 0.8)) 84 | return block 85 | 86 | self.model = nn.Sequential( 87 | *discriminator_block(1, 16, bn=False), 88 | *discriminator_block(16, 32), 89 | *discriminator_block(32, 64), 90 | *discriminator_block(64, 128), 91 | ) 92 | 93 | # The height and width of downsampled image 94 | ds_size = image_size // 2 ** 4 95 | self.adv_layer = nn.Linear(128 * ds_size ** 2, 1) 96 | 97 | def forward(self, img): 98 | out = self.model(img) 99 | out = out.view(out.shape[0], -1) 100 | validity = self.adv_layer(out) 101 | 102 | return validity 103 | 104 | 105 | # 这里使用MSE损失,符合LSGAN最小二乘损失的思想 106 | # MSE:均方误差损失函数;BCE:二分类下的交叉熵损失(传统gan里用的BCELoss) 107 | adversarial_loss = torch.nn.MSELoss() 108 | # adversarial_loss = torch.nn.BCEWithLogitsLoss() # 若报错RuntimeError: all elements of input should be between 0 and 1,可采用torch.nn.BCEWithLogitsLoss() 109 | 110 | # Initialize generator and discriminator 111 | generator = Generator() 112 | discriminator = Discriminator() 113 | 114 | if cuda: 115 | generator.cuda() 116 | discriminator.cuda() 117 | adversarial_loss.cuda() 118 | 119 | # Initialize weights 120 | generator.apply(weights_init_normal) 121 | discriminator.apply(weights_init_normal) 122 | 123 | # Configure data loader 124 | os.makedirs("../output/mnist", exist_ok=True) 125 | dataloader = torch.utils.data.DataLoader( 126 | datasets.MNIST( 127 | "../output/mnist", 128 | train=True, 129 | download=True, 130 | transform=transforms.Compose( 131 | [transforms.Resize(image_size), transforms.ToTensor(), transforms.Normalize([0.5], [0.5])] 132 | ), 133 | ), 134 | batch_size=batch_size, 135 | shuffle=True, 136 | ) 137 | 138 | # Optimizers 139 | optimizer_G = torch.optim.Adam(generator.parameters(), lr=learning_rate, betas=(0.5, 0.999)) 140 | optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=learning_rate, betas=(0.5, 0.999)) 141 | 142 | Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor 143 | 144 | # 训练过程 145 | g_loss_list = [] 146 | d_loss_list = [] 147 | 148 | for epoch in range(epochs): 149 | for i, (imgs, _) in enumerate(dataloader): # images [batch_size, 1, image_size, image_size] 150 | 151 | # Adversarial ground truths 152 | valid = Variable(Tensor(imgs.shape[0], 1).fill_(1.0), requires_grad=False) # [batch_size, 1] 153 | fake = Variable(Tensor(imgs.shape[0], 1).fill_(0.0), requires_grad=False) # [batch_size, 0] 154 | 155 | real_imgs = Variable(imgs.type(Tensor)) # 真实图片 156 | 157 | # 训练生成器 158 | optimizer_G.zero_grad() 159 | z = Variable(Tensor(np.random.normal(0, 1, (imgs.shape[0], latent_dim)))) # [batch_size, latent_dim] 160 | gen_imgs = generator(z) # [batch_size, 1, image_size, image_size] 161 | 162 | # 生成器优化方向是D(x)越来越靠近1 163 | g_loss = adversarial_loss(discriminator(gen_imgs), valid) # 判别器判断的结果与真做MSE 164 | 165 | g_loss.backward() 166 | optimizer_G.step() 167 | 168 | # 训练判别器 169 | optimizer_D.zero_grad() 170 | 171 | # LSGAN,要求不能简简单单算交叉熵损失,而是计算两个分布之间的均方误差损失 172 | real_loss = adversarial_loss(discriminator(real_imgs), valid) # 要求把真的识别为真 173 | fake_loss = adversarial_loss(discriminator(gen_imgs.detach()), fake) # 把假的识别为假 174 | d_loss = 0.5 * 
(real_loss + fake_loss) 175 | 176 | d_loss.backward() 177 | optimizer_D.step() 178 | 179 | print( 180 | "[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]" 181 | % (epoch, epochs, i, len(dataloader), d_loss.item(), g_loss.item()) 182 | ) 183 | 184 | if i % sample_interval == 0: 185 | save_image(gen_imgs.data[:25], "%s/%d_%d.png" % (save_sample_path, epoch, i), nrow=5, normalize=True) 186 | g_loss_list.append(g_loss.item()) 187 | d_loss_list.append(d_loss.item()) 188 | 189 | if epoch >= 0: 190 | plt.figure() 191 | plt.subplot(121) 192 | plt.plot(np.arange(0, len(g_loss_list)), g_loss_list) 193 | plt.subplot(122) 194 | plt.plot(np.arange(0, len(d_loss_list)), d_loss_list) 195 | plt.savefig("loss.jpg") 196 | plt.close("all") -------------------------------------------------------------------------------- /train_myimitator.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import os 3 | import random 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.parallel 7 | import torch.nn.functional as F 8 | from torch.optim import lr_scheduler 9 | import torch.backends.cudnn as cudnn 10 | import torch.optim as optim 11 | import torch.utils.data 12 | from torchvision import transforms as T 13 | from torch.utils.data import DataLoader, Dataset 14 | import json 15 | import torchvision.utils as vutils 16 | import numpy as np 17 | import matplotlib.pyplot as plt 18 | from PIL import Image 19 | import time 20 | 21 | import copy 22 | import math 23 | 24 | ''' 25 | 在服务器上训练只需上传这一个文件即可 26 | ''' 27 | 28 | # Set random seed for reproducibility 29 | manualSeed = 999 30 | # manualSeed = random.randint(1, 10000) # use if you want new results 31 | print("Random Seed: ", manualSeed) 32 | random.seed(manualSeed) 33 | torch.manual_seed(manualSeed) 34 | 35 | # Batch size during training 36 | batch_size = 16 37 | image_size = 512 38 | num_epochs = 1000 39 | lr = 0.01 40 | ngpu = 2 41 | 42 | image_root = "F:/dataset/face_20211203_20000_nojiemao/" 43 | 44 | class Imitator_Dataset(Dataset): 45 | def __init__(self, params_root, image_root, mode="train"): 46 | self.image_root = image_root 47 | self.mode = mode 48 | with open(params_root, encoding='utf-8') as f: 49 | self.params = json.load(f) 50 | 51 | def __getitem__(self, index): 52 | if self.mode == "val": 53 | img = Image.open(os.path.join(self.image_root, '%d.png' % (index + 18000))).convert("RGB") 54 | param = torch.tensor(self.params['%d.png' % (index + 18000)]) 55 | else: 56 | img = Image.open(os.path.join(self.image_root, '%d.png' % index)).convert("RGB") 57 | param = torch.tensor(self.params['%d.png' % index]) 58 | img = T.ToTensor()(img) 59 | return param, img 60 | 61 | def __len__(self): 62 | if self.mode == "train": 63 | return 18000 64 | else: 65 | return 2000 66 | 67 | 68 | train_dataset = Imitator_Dataset(image_root + "param.json", image_root + "face_train/", mode="train") 69 | val_dataset = Imitator_Dataset(image_root + "param.json", image_root + "face_val/", mode="val") 70 | train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True) 71 | val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False) 72 | 73 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 74 | 75 | 76 | # real_batch = next(iter(val_dataloader)) 77 | # plt.figure(figsize=(4, 4)) 78 | # plt.axis("off") 79 | # plt.title("Training Images") 80 | # plt.imshow(np.transpose(vutils.make_grid(real_batch[1].to(device)[:16], nrow=4, padding=2, 
normalize=True).cpu(), (1, 2, 0))) 81 | # plt.show() 82 | # vutils.save_image(vutils.make_grid(real_batch[1].to(device)[:16], nrow=4, padding=2, normalize=True).cpu(), "./a.jpg") 83 | 84 | 85 | 86 | ''' 87 | 自定义Imitator 88 | 1.conv,linear,embedding后加上sn 89 | 2.指定层加上self-attention 90 | 3.自定义bn 91 | ''' 92 | 93 | # 采用sn做 normalization 94 | def snconv2d(eps=1e-12, **kwargs): 95 | return nn.utils.spectral_norm(nn.Conv2d(**kwargs), eps=eps) 96 | 97 | def snlinear(eps=1e-12, **kwargs): 98 | return nn.utils.spectral_norm(nn.Linear(**kwargs), eps=eps) 99 | 100 | def sn_embedding(eps=1e-12, **kwargs): 101 | return nn.utils.spectral_norm(nn.Embedding(**kwargs), eps=eps) 102 | 103 | # self-attention层 104 | class SelfAttn(nn.Module): 105 | def __init__(self, in_channels, eps=1e-12): 106 | super(SelfAttn, self).__init__() 107 | self.in_channels = in_channels 108 | self.snconv1x1_theta = snconv2d(in_channels=in_channels, out_channels=in_channels//8, 109 | kernel_size=1, bias=False, eps=eps) 110 | self.snconv1x1_phi = snconv2d(in_channels=in_channels, out_channels=in_channels//8, 111 | kernel_size=1, bias=False, eps=eps) 112 | self.snconv1x1_g = snconv2d(in_channels=in_channels, out_channels=in_channels//2, 113 | kernel_size=1, bias=False, eps=eps) 114 | self.snconv1x1_o_conv = snconv2d(in_channels=in_channels//2, out_channels=in_channels, 115 | kernel_size=1, bias=False, eps=eps) 116 | self.maxpool = nn.MaxPool2d(2, stride=2, padding=0) 117 | self.softmax = nn.Softmax(dim=-1) 118 | self.gamma = nn.Parameter(torch.zeros(1)) 119 | 120 | def forward(self, x): 121 | _, ch, h, w = x.size() 122 | # Theta path 123 | theta = self.snconv1x1_theta(x) 124 | theta = theta.view(-1, ch//8, h*w) 125 | # Phi path 126 | phi = self.snconv1x1_phi(x) 127 | phi = self.maxpool(phi) 128 | phi = phi.view(-1, ch//8, h*w//4) 129 | # Attn map 130 | attn = torch.bmm(theta.permute(0, 2, 1), phi) 131 | attn = self.softmax(attn) 132 | # g path 133 | g = self.snconv1x1_g(x) 134 | g = self.maxpool(g) 135 | g = g.view(-1, ch//2, h*w//4) 136 | # Attn_g - o_conv 137 | attn_g = torch.bmm(g, attn.permute(0, 2, 1)) 138 | attn_g = attn_g.view(-1, ch//2, h, w) 139 | attn_g = self.snconv1x1_o_conv(attn_g) 140 | # Out 141 | out = x + self.gamma*attn_g 142 | return out 143 | 144 | # 自定义bn 145 | class BigGANBatchNorm(nn.Module): 146 | """ This is a batch norm module that can handle conditional input and can be provided with pre-computed 147 | activation means and variances for various truncation parameters. 148 | 149 | We cannot just rely on torch.batch_norm since it cannot handle 150 | batched weights (pytorch 1.0.1). We computate batch_norm our-self without updating running means and variances. 151 | If you want to train this model you should add running means and variance computation logic. 
152 | """ 153 | def __init__(self, num_features, condition_vector_dim=None, n_stats=51, eps=1e-4, conditional=True): 154 | super(BigGANBatchNorm, self).__init__() 155 | self.num_features = num_features 156 | self.eps = eps 157 | self.conditional = conditional 158 | 159 | # We use pre-computed statistics for n_stats values of truncation between 0 and 1 160 | self.register_buffer('running_means', torch.zeros(n_stats, num_features)) 161 | self.register_buffer('running_vars', torch.ones(n_stats, num_features)) 162 | self.step_size = 1.0 / (n_stats - 1) 163 | 164 | if conditional: 165 | assert condition_vector_dim is not None 166 | self.scale = snlinear(in_features=condition_vector_dim, out_features=num_features, bias=False, eps=eps) 167 | self.offset = snlinear(in_features=condition_vector_dim, out_features=num_features, bias=False, eps=eps) 168 | else: 169 | self.weight = torch.nn.Parameter(torch.Tensor(num_features)) 170 | self.bias = torch.nn.Parameter(torch.Tensor(num_features)) 171 | 172 | def forward(self, x, truncation, condition_vector=None): 173 | # Retreive pre-computed statistics associated to this truncation 174 | coef, start_idx = math.modf(truncation / self.step_size) 175 | start_idx = int(start_idx) 176 | if coef != 0.0: # Interpolate 177 | running_mean = self.running_means[start_idx] * coef + self.running_means[start_idx + 1] * (1 - coef) 178 | running_var = self.running_vars[start_idx] * coef + self.running_vars[start_idx + 1] * (1 - coef) 179 | else: 180 | running_mean = self.running_means[start_idx] 181 | running_var = self.running_vars[start_idx] 182 | 183 | if self.conditional: 184 | running_mean = running_mean.unsqueeze(0).unsqueeze(-1).unsqueeze(-1) 185 | running_var = running_var.unsqueeze(0).unsqueeze(-1).unsqueeze(-1) 186 | 187 | weight = 1 + self.scale(condition_vector).unsqueeze(-1).unsqueeze(-1) 188 | bias = self.offset(condition_vector).unsqueeze(-1).unsqueeze(-1) 189 | 190 | out = (x - running_mean) / torch.sqrt(running_var + self.eps) * weight + bias 191 | else: 192 | out = F.batch_norm(x, running_mean, running_var, self.weight, self.bias, 193 | training=False, momentum=0.0, eps=self.eps) 194 | return out 195 | 196 | class GenBlock(nn.Module): 197 | def __init__(self, in_size, out_size, condition_vector_dim, reduction_factor=4, up_sample=False, 198 | n_stats=51, eps=1e-12): 199 | super(GenBlock, self).__init__() 200 | self.up_sample = up_sample 201 | self.drop_channels = (in_size != out_size) 202 | middle_size = in_size // reduction_factor 203 | 204 | self.bn_0 = BigGANBatchNorm(in_size, condition_vector_dim, n_stats=n_stats, eps=eps, conditional=True) 205 | self.conv_0 = snconv2d(in_channels=in_size, out_channels=middle_size, kernel_size=1, eps=eps) 206 | 207 | self.bn_1 = BigGANBatchNorm(middle_size, condition_vector_dim, n_stats=n_stats, eps=eps, conditional=True) 208 | self.conv_1 = snconv2d(in_channels=middle_size, out_channels=middle_size, kernel_size=3, padding=1, eps=eps) 209 | 210 | self.bn_2 = BigGANBatchNorm(middle_size, condition_vector_dim, n_stats=n_stats, eps=eps, conditional=True) 211 | self.conv_2 = snconv2d(in_channels=middle_size, out_channels=middle_size, kernel_size=3, padding=1, eps=eps) 212 | 213 | self.bn_3 = BigGANBatchNorm(middle_size, condition_vector_dim, n_stats=n_stats, eps=eps, conditional=True) 214 | self.conv_3 = snconv2d(in_channels=middle_size, out_channels=out_size, kernel_size=1, eps=eps) 215 | 216 | self.relu = nn.ReLU() 217 | 218 | def forward(self, x, cond_vector, truncation): 219 | x0 = x 220 | 221 | x = self.bn_0(x, 
truncation, cond_vector) 222 | x = self.relu(x) 223 | x = self.conv_0(x) 224 | 225 | x = self.bn_1(x, truncation, cond_vector) 226 | x = self.relu(x) 227 | if self.up_sample: 228 | x = F.interpolate(x, scale_factor=2, mode='nearest') 229 | x = self.conv_1(x) 230 | 231 | x = self.bn_2(x, truncation, cond_vector) 232 | x = self.relu(x) 233 | x = self.conv_2(x) 234 | 235 | x = self.bn_3(x, truncation, cond_vector) 236 | x = self.relu(x) 237 | x = self.conv_3(x) 238 | 239 | if self.drop_channels: 240 | new_channels = x0.shape[1] // 2 241 | x0 = x0[:, :new_channels, ...] 242 | if self.up_sample: 243 | x0 = F.interpolate(x0, scale_factor=2, mode='nearest') 244 | 245 | out = x + x0 246 | return out 247 | 248 | class MyImitator(nn.Module): 249 | def __init__(self): 250 | super(MyImitator, self).__init__() 251 | 252 | # 1.加载配置文件 253 | with open("./checkpoint/myimitator-512.json", "r", encoding='utf-8') as reader: 254 | text = reader.read() 255 | self.conf = BigGANConfig() 256 | for key, value in json.loads(text).items(): 257 | self.conf.__dict__[key] = value 258 | 259 | # 定义网络结构 260 | # self.embeddings = nn.Linear(config.num_classes, config.continuous_params_size, bias=False) 261 | 262 | ch = self.conf.channel_width 263 | condition_vector_dim = 223 264 | 265 | self.gen_z = snlinear(in_features=condition_vector_dim, out_features=4*4*16*ch, eps=self.conf.eps) 266 | layers = [] 267 | for i, layer in enumerate(self.conf.layers): 268 | if i == self.conf.attention_layer_position: # 在指定层加上self-attention 269 | layers.append(SelfAttn(ch * layer[1], eps=self.conf.eps)) 270 | layers.append(GenBlock(ch * layer[1], 271 | ch * layer[2], 272 | condition_vector_dim, 273 | up_sample=layer[0], 274 | n_stats=self.conf.n_stats, 275 | eps=self.conf.eps)) 276 | self.layers = nn.ModuleList(layers) 277 | 278 | self.bn = BigGANBatchNorm(ch, n_stats=self.conf.n_stats, eps=self.conf.eps, conditional=False) 279 | self.relu = nn.ReLU() 280 | self.conv_to_rgb = snconv2d(in_channels=ch, out_channels=ch, kernel_size=3, padding=1, eps=self.conf.eps) 281 | self.tanh = nn.Tanh() 282 | 283 | def forward(self, cond_vector, truncation=0.4): 284 | # cond_vector = cond_vector.unsqueeze(2).unsqueeze(3) 285 | z = self.gen_z(cond_vector) # cond_cector [batch_size, config.continuous_params_size], z [1, 4*4*16*self.conf.channel_width] 286 | 287 | # We use this conversion step to be able to use TF weights: 288 | # TF convention on shape is [batch, height, width, channels] 289 | # PT convention on shape is [batch, channels, height, width] 290 | z = z.view(-1, 4, 4, 16 * self.conf.channel_width) # [batch_size, 4, 4, 2048] 291 | z = z.permute(0, 3, 1, 2).contiguous() # [batch_size, 2048, 4, 4] 292 | 293 | for i, layer in enumerate(self.layers): 294 | if isinstance(layer, GenBlock): 295 | z = layer(z, cond_vector, truncation) 296 | else: 297 | z = layer(z) 298 | 299 | z = self.bn(z, truncation) # [1, 128, 512, 512] 300 | z = self.relu(z) # [1, 128, 512, 512] 301 | z = self.conv_to_rgb(z) # [1, 128, 512, 512] 302 | z = z[:, :3, ...] # [1, 3, 512, 512] 303 | z = self.tanh(z) # [1, 3, 512, 512] 304 | return z 305 | 306 | ''' 307 | 自定义Imitator的config 308 | ''' 309 | class BigGANConfig(object): 310 | """ Configuration class to store the configuration of a `BigGAN`. 311 | Defaults are for the 128x128 model. 
312 | layers tuple are (up-sample in the layer ?, input channels, output channels) 313 | """ 314 | def __init__(self, 315 | output_dim=512, 316 | z_dim=512, 317 | class_embed_dim=512, 318 | channel_width=512, 319 | num_classes=1000, 320 | # (是否上采样,input_channels,output_channels) 321 | layers=[(False, 16, 16), 322 | (True, 16, 16), 323 | (False, 16, 16), 324 | (True, 16, 8), 325 | (False, 8, 8), 326 | (True, 8, 4), 327 | (False, 4, 4), 328 | (True, 4, 2), 329 | (False, 2, 2), 330 | (True, 2, 1)], 331 | attention_layer_position=8, 332 | eps=1e-4, 333 | n_stats=51): 334 | """Constructs BigGANConfig. """ 335 | self.output_dim = output_dim 336 | self.z_dim = z_dim 337 | self.class_embed_dim = class_embed_dim 338 | self.channel_width = channel_width 339 | self.num_classes = num_classes 340 | self.layers = layers 341 | self.attention_layer_position = attention_layer_position 342 | self.eps = eps 343 | self.n_stats = n_stats 344 | 345 | @classmethod 346 | def from_dict(cls, json_object): 347 | """Constructs a `BigGANConfig` from a Python dictionary of parameters.""" 348 | config = BigGANConfig() 349 | for key, value in json_object.items(): 350 | config.__dict__[key] = value 351 | return config 352 | 353 | @classmethod 354 | def from_json_file(cls, json_file): 355 | """Constructs a `BigGANConfig` from a json file of parameters.""" 356 | with open(json_file, "r", encoding='utf-8') as reader: 357 | text = reader.read() 358 | return cls.from_dict(json.loads(text)) 359 | 360 | def __repr__(self): 361 | return str(self.to_json_string()) 362 | 363 | def to_dict(self): 364 | """Serializes this instance to a Python dictionary.""" 365 | output = copy.deepcopy(self.__dict__) 366 | return output 367 | 368 | def to_json_string(self): 369 | """Serializes this instance to a JSON string.""" 370 | return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n" 371 | 372 | 373 | imitator = MyImitator() 374 | if device.type == 'cuda': 375 | imitator = nn.DataParallel(imitator) 376 | imitator.to(device) 377 | 378 | # Initialize BCELoss function 379 | criterion = nn.L1Loss() 380 | 381 | # optimizer = optim.SGD(imitator.parameters(), lr=lr, momentum=0.9) 382 | optimizer = optim.Adam(params=imitator.parameters(), lr=5e-5, 383 | betas=(0.0, 0.999), weight_decay=0, 384 | eps=1e-8) 385 | 386 | # 每50个epoch衰减10% 387 | # scheduler = lr_scheduler.StepLR(optimizer, step_size=len(train_dataloader) * 50, gamma=0.9) 388 | 389 | total_step = len(train_dataloader) 390 | imitator.train() 391 | train_loss_list = [] 392 | val_loss_list = [] 393 | for epoch in range(num_epochs): 394 | start = time.time() 395 | for i, (params, img) in enumerate(train_dataloader): 396 | optimizer.zero_grad() 397 | params = params.to(device) 398 | img = img.to(device) 399 | outputs = imitator(params) 400 | loss = criterion(outputs, img) 401 | loss.backward() 402 | optimizer.step() 403 | 404 | if (i % 10) == 0: 405 | print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}, spend time: {:.4f}' 406 | .format(epoch + 1, num_epochs, i + 1, total_step, loss.item(), time.time() - start)) 407 | start = time.time() 408 | 409 | train_loss_list.append(loss.item()) 410 | imitator.eval() 411 | with torch.no_grad(): 412 | val_loss = 0 413 | for i, (params, img) in enumerate(val_dataloader): 414 | params = params.to(device) 415 | img = img.to(device) 416 | outputs = imitator(params) 417 | loss = criterion(outputs, img) 418 | val_loss += loss.item() 419 | if i == 1: 420 | vutils.save_image( 421 | vutils.make_grid(outputs.to(device)[:16], nrow=4, padding=2, 
normalize=True).cpu(), 422 | image_root + "gen_image/%d.jpg" % epoch) 423 | val_loss_list.append(val_loss / len(val_dataloader)) 424 | 425 | print('Epoch [{}/{}], val_loss: {:.6f}' 426 | .format(epoch + 1, num_epochs, val_loss)) 427 | if (epoch % 10) == 0 or (epoch+1) == num_epochs: 428 | torch.save(imitator.state_dict(), 429 | image_root + 'model/epoch_{}_val_loss_{:.6f}_file.pt'.format( 430 | epoch, val_loss)) 431 | if epoch >= 1: 432 | plt.figure() 433 | plt.subplot(121) 434 | plt.plot(np.arange(0, len(train_loss_list)), train_loss_list) 435 | plt.subplot(122) 436 | plt.plot(np.arange(0, len(val_loss_list)), val_loss_list) 437 | plt.savefig(image_root + "metrics.jpg") 438 | plt.close("all") 439 | 440 | imitator.train() 441 | -------------------------------------------------------------------------------- /train_translator.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import time 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | import torch.optim as optim 8 | import torchvision.utils as vutils 9 | import torchvision.transforms as T 10 | from torch.utils.data import DataLoader 11 | import matplotlib.pyplot as plt 12 | from tqdm import tqdm 13 | import numpy as np 14 | 15 | import config 16 | from imitator import Imitator 17 | from dataset import Translator_Dataset, split_dataset 18 | from lightcnn import LightCNN_29Layers_v2 19 | from face_parser import load_model 20 | from translator import Translator 21 | 22 | root_path = "./data/" 23 | r1 = 0.01 24 | r2 = 1 25 | r3 = 1 26 | 27 | def criterion_lightcnn(x1, x2): 28 | distance = torch.cosine_similarity(x1, x2)[0] 29 | return 1 - distance 30 | 31 | if __name__ == '__main__': 32 | # 1.初始化并加载imitator 33 | imitator = Imitator() 34 | if len(config.imitator_model) > 0: 35 | if config.use_gpu: 36 | imitator_model = torch.load(config.imitator_model) 37 | else: 38 | imitator_model = torch.load(config.imitator_model, map_location=torch.device('cpu')) 39 | print("load Imitator pretrained model success!") 40 | else: 41 | print("No pretrained model...") 42 | 43 | imitator = imitator.to(config.device) 44 | for param in imitator.parameters(): 45 | param.requires_grad = False 46 | 47 | # 2.加载lightcnn 48 | lightcnn = LightCNN_29Layers_v2(num_classes=80013) 49 | lightcnn = lightcnn.to(config.device) 50 | lightcnn.eval() 51 | if config.use_gpu: 52 | checkpoint = torch.load(config.lightcnn_checkpoint) 53 | model = torch.nn.DataParallel(lightcnn).cuda() 54 | model.load_state_dict(checkpoint['state_dict']) 55 | else: 56 | checkpoint = torch.load(config.lightcnn_checkpoint, map_location="cpu") 57 | new_state_dict = lightcnn.state_dict() 58 | for k, v in checkpoint['state_dict'].items(): 59 | _name = k[7:] # remove `module.` 60 | new_state_dict[_name] = v 61 | lightcnn.load_state_dict(new_state_dict) 62 | 63 | # 冻结lightcnn 64 | for param in lightcnn.parameters(): 65 | param.requires_grad = False 66 | 67 | # 3.T网络 68 | translator = Translator(isBias=False) 69 | if config.device.type == 'cuda': 70 | translator = nn.DataParallel(translator) 71 | translator.to(config.device) 72 | 73 | # 4.加载face_parser 74 | transform = T.Compose([ 75 | T.Normalize(mean=[0.485, 0.456, 0.406], 76 | std=[0.229, 0.224, 0.225]), 77 | ]) # 语义分割之前要先做Normalize 78 | deeplab = load_model('mobilenetv2', num_classes=config.num_classes, output_stride=config.output_stride) 79 | checkpoint = torch.load(config.faceparse_checkpoint, map_location=config.device) 80 | if config.faceparse_backbone 
== 'resnet50': 81 | deeplab.load_state_dict(checkpoint) 82 | else: 83 | deeplab.load_state_dict(checkpoint["model_state"]) 84 | deeplab.to(config.device) 85 | deeplab.eval() 86 | 87 | for param in deeplab.parameters(): 88 | param.requires_grad = False 89 | 90 | trainlist, vallist = split_dataset(root_path) 91 | 92 | # 损失 93 | criterion_param = nn.L1Loss() 94 | criterion_parser = nn.L1Loss() 95 | 96 | train_dataset = Translator_Dataset(trainlist) 97 | val_dataset = Translator_Dataset(vallist) 98 | train_dataloader = DataLoader(train_dataset, batch_size=config.batch_size, shuffle=True) 99 | val_dataloader = DataLoader(val_dataset, batch_size=config.batch_size, shuffle=True) 100 | 101 | optimizer = torch.optim.Adam(translator.parameters(), lr=1e-4) 102 | 103 | total_step = len(train_dataloader) 104 | translator.train() 105 | train_loss_list = [] 106 | val_loss_list = [] 107 | for epoch in range(config.total_epochs): 108 | start = time.time() 109 | for i, imgs in enumerate(train_dataloader): 110 | optimizer.zero_grad() 111 | # features = features.to(config.device) 112 | # params = params.to(config.device) 113 | imgs = imgs.to(config.device) 114 | 115 | # 先做语义分割 116 | parse = deeplab(transform(imgs)) 117 | 118 | imgs = T.Grayscale()(F.interpolate(imgs, (128, 128), mode='bilinear')) 119 | _, features = lightcnn(imgs) 120 | outputs = translator(features) # 223 121 | 122 | gen_img = imitator(outputs) # 生成图像 123 | gen_parse = deeplab(transform(gen_img)) 124 | gen_img = T.Grayscale()(F.interpolate(gen_img, (128, 128), mode='bilinear')) 125 | 126 | _, gen_features = lightcnn(gen_img) 127 | gen_outputs = translator(gen_features) 128 | 129 | loss1 = criterion_lightcnn(gen_features, features) 130 | loss2 = criterion_parser(gen_parse, parse) 131 | loss3 = criterion_param(outputs, gen_outputs) 132 | 133 | loss = r1 * loss1 + r2 * loss2 + r3 * loss3 134 | 135 | loss.backward() 136 | optimizer.step() 137 | 138 | if (i % 10) == 0: 139 | print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}, spend time: {:.4f}' 140 | .format(epoch + 1, config.total_epochs, i + 1, total_step, loss.item(), time.time() - start)) 141 | start = time.time() 142 | 143 | train_loss_list.append(loss.item()) 144 | 145 | translator.eval() 146 | with torch.no_grad(): 147 | val_loss = 0 148 | for i, imgs in enumerate(val_dataloader): 149 | # features = features.to(config.device) 150 | # params = params.to(config.device) 151 | 152 | imgs = imgs.to(config.device) 153 | 154 | imgs = T.Grayscale()(F.interpolate(imgs, (128, 128), mode='bilinear')) 155 | _, features = lightcnn(imgs) 156 | outputs = translator(features) # 223 157 | 158 | gen_img = imitator(outputs) # 生成图像 159 | gen_img = T.Grayscale()(F.interpolate(gen_img, (128, 128), mode='bilinear')) 160 | 161 | _, gen_features = lightcnn(gen_img) 162 | gen_outputs = translator(gen_features) 163 | 164 | loss1 = criterion_lightcnn(gen_features, features) 165 | loss3 = criterion_param(outputs, gen_outputs) 166 | 167 | loss = r1 * loss1 + r3 * loss3 168 | 169 | val_loss += loss.item() 170 | val_loss_list.append(val_loss / len(val_dataloader)) 171 | 172 | print('Epoch [{}/{}], val_loss: {:.6f}' 173 | .format(epoch + 1, config.total_epochs, val_loss)) 174 | if (epoch % 1) == 0 or (epoch + 1) == config.total_epochs: 175 | torch.save(translator.state_dict(), './checkpoint/translator_{}_{:.6f}.pt'.format(epoch, val_loss)) 176 | if epoch >= 1: 177 | plt.figure() 178 | plt.subplot(121) 179 | plt.plot(np.arange(0, len(train_loss_list)), train_loss_list) 180 | plt.subplot(122) 181 | plt.plot(np.arange(0, 
len(val_loss_list)), val_loss_list) 182 | plt.savefig(root_path + "metrics.jpg") 183 | plt.close("all") 184 | 185 | translator.train() 186 | -------------------------------------------------------------------------------- /translator.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | import config 5 | 6 | ''' 7 | Translator网络,从face-recognition到params 8 | 网络结构参考第二篇论文:《Fast and Robust Face-to-Parameter Translation for Game Character Auto-Creation》 9 | keypoint: 10 | 1.use the Adam optimizer to train T with the learning rate of 1e-4 and max-iteration of 20 epochs. 11 | 2.the learning rate decay is set to 10% per 50 epochs.(参照imitator训练策略,但一般用不上,epoch=40时就可以停止训练了) 12 | ''' 13 | class Translator(nn.Module): 14 | def __init__(self, isBias=False): 15 | super(Translator, self).__init__() 16 | self.fc1 = nn.Linear(in_features=256, out_features=512, bias=isBias) 17 | self.resatt1 = ResAttention(isBias=isBias) 18 | self.resatt2 = ResAttention(isBias=isBias) 19 | self.resatt3 = ResAttention(isBias=isBias) 20 | self.fc2 = nn.Linear(in_features=512, out_features=config.continuous_params_size, bias=isBias) 21 | self.bn = nn.BatchNorm1d(config.continuous_params_size) 22 | 23 | def forward(self, x): 24 | y = self.fc1(x) 25 | y_ = y + self.resatt1(y) 26 | y_ = y_ + self.resatt2(y_) 27 | y_ = y_ + self.resatt3(y_) 28 | return self.bn(self.fc2(y_)) 29 | 30 | 31 | ''' 32 | 带attention的res模块 33 | ''' 34 | class ResAttention(nn.Module): 35 | def __init__(self, isBias=False): 36 | super(ResAttention, self).__init__() 37 | self.fc1 = nn.Linear(512, 1024, bias=isBias) 38 | self.bn1 = nn.BatchNorm1d(1024) 39 | self.relu1 = nn.ReLU() 40 | 41 | self.fc2 = nn.Linear(1024, 512, bias=isBias) 42 | self.bn2 = nn.BatchNorm1d(512) 43 | 44 | # 这里开始分支 45 | self.fc3 = nn.Linear(512, 16, bias=isBias) 46 | self.relu3 = nn.ReLU() 47 | 48 | self.fc4 = nn.Linear(16, 512, bias=isBias) 49 | self.sigmoid4 = nn.Sigmoid() 50 | 51 | # 这里开始做点乘 52 | 53 | self.relu5 = nn.ReLU() 54 | 55 | def forward(self, x): 56 | y = self.fc1(x) 57 | y = self.bn1(y) 58 | y = self.relu1(y) 59 | 60 | y = self.fc2(y) 61 | y = self.bn2(y) 62 | 63 | # 这里开始分支 64 | 65 | y_ = self.fc3(y) 66 | y_ = self.relu3(y_) 67 | 68 | y_ = self.fc4(y_) 69 | y_ = self.sigmoid4(y_) 70 | 71 | # 做点乘 72 | y_ = torch.mul(y, y_) 73 | return self.relu5(y_) 74 | 75 | if __name__ == '__main__': 76 | trans = Translator() 77 | x = torch.randn([2, 256], dtype=torch.float32) # face-recognition的结果 78 | print(x.shape) 79 | y = trans(x) 80 | print(y.shape) -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import struct 4 | import numpy as np 5 | from PIL import Image 6 | import matplotlib.pyplot as plt 7 | import torch 8 | import torch.nn as nn 9 | import torchvision.transforms as transforms 10 | 11 | import config 12 | from faceparse import BiSeNet 13 | 14 | # 反卷积 15 | def deconv_layer(in_chanel, out_chanel, kernel_size, stride=1, pad=0): 16 | return nn.Sequential( 17 | nn.ConvTranspose2d(in_chanel, out_chanel, kernel_size=kernel_size, stride=stride, padding=pad), 18 | nn.BatchNorm2d(out_chanel), 19 | nn.ReLU()) 20 | 21 | # 自定义异常 22 | class NeuralException(Exception): 23 | def __init__(self, message): 24 | print("neural error: " + message) 25 | self.message = "neural exception: " + message 26 | 27 | ''' 28 | 提取原始图像的边缘 29 | :param img: input image 30 | 
:return: edge image 31 | ''' 32 | def img_edge(img): 33 | gray = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY) 34 | x_grad = cv2.Sobel(gray, cv2.CV_16SC1, 1, 0) 35 | y_grad = cv2.Sobel(gray, cv2.CV_16SC1, 0, 1) 36 | return cv2.Canny(x_grad, y_grad, 40, 130) 37 | 38 | ''' 39 | 将tensor转numpy array 给cv2使用 40 | :param tensor: [batch, c, w, h] 41 | :return: [batch, h, w, c] 42 | ''' 43 | def tensor_2_image(tensor): 44 | 45 | batch = tensor.size(0) 46 | images = [] 47 | for i in range(batch): 48 | img = tensor[i].cpu().detach().numpy() 49 | img = np.swapaxes(img, 0, 2) # [h, w, c] 50 | img = np.swapaxes(img, 0, 1) # [w, h, c] 51 | images.append(img * 255) 52 | return images 53 | 54 | ''' 55 | [W, H, 1] -> [W, H, 3] or [W, H]->[W, H, 3] 56 | :param image: input image 57 | :return: transfer image 58 | ''' 59 | def fill_gray(image): 60 | shape = image.shape 61 | if len(shape) == 2: 62 | image = image[:, :, np.newaxis] 63 | shape = image.shape 64 | if shape[2] == 1: 65 | return np.pad(image, ((0, 0), (0, 0), (1, 1)), 'edge') 66 | elif shape[2] == 3: 67 | return np.mean(image, axis=2) 68 | return image 69 | 70 | ''' 71 | imitator 快照 72 | :param path: save path 73 | :param tensor1: input photo 74 | :param tensor2: generated image 75 | :param parse: parse checkpoint's path 76 | ''' 77 | def capture(path, tensor1, tensor2, parse, cuda): 78 | img1 = tensor_2_image(tensor1)[0].swapaxes(0, 1).astype(np.uint8) 79 | img2 = tensor_2_image(tensor2)[0].swapaxes(0, 1).astype(np.uint8) 80 | img1 = cv2.resize(img1, (512, 512), interpolation=cv2.INTER_LINEAR) 81 | img3 = faceparsing_ndarray(img1, parse, cuda) 82 | img4 = img_edge(img3) 83 | img4 = 255 - fill_gray(img4) 84 | image = merge_4image(img1, img2, img3, img4, transpose=False) 85 | cv2.imwrite(path, image) 86 | 87 | def merge_4image(image1, image2, image3, image4, size=512, show=False, transpose=True): 88 | """ 89 | 拼接图片 90 | :param image1: input image1, numpy array 91 | :param image2: input image2, numpy array 92 | :param image3: input image3, numpy array 93 | :param image4: input image4, numpy array 94 | :param size: 输出分辨率 95 | :param show: 窗口显示 96 | :param transpose: 转置长和宽 cv2顺序[H, W, C] 97 | :return: merged image 98 | """ 99 | size_ = (int(size / 2), int(size / 2)) 100 | img_1 = cv2.resize(image1, size_) 101 | # cv2.imshow("img1", img_1) 102 | # cv2.waitKey() 103 | 104 | img_2 = cv2.resize(image2, size_) 105 | # cv2.imshow("img2", img_2) 106 | # cv2.waitKey() 107 | 108 | img_3 = cv2.resize(image3, size_) 109 | # cv2.imshow("img3", img_3) 110 | # cv2.waitKey() 111 | 112 | img_4 = cv2.resize(image4, size_) 113 | # cv2.imshow("img4", img_4) 114 | # cv2.waitKey() 115 | 116 | image1_ = np.append(img_1, img_2, axis=1) # axis=1,行 117 | image2_ = np.append(img_3, img_4, axis=1) 118 | image = np.append(image1_, image2_, axis=0) 119 | if transpose: 120 | image = image.swapaxes(0, 1) 121 | if show: 122 | cv2.imshow("contact", image) 123 | cv2.waitKey() 124 | cv2.destroyAllWindows() 125 | return image 126 | 127 | ''' 128 | evaluate with numpy array 129 | :param input: numpy array, 注意一定要是np.uint8, 而不是np.float32 [H, W, C] 130 | :param cp: args.parsing_checkpoint, str 131 | :param cuda: use gpu to speedup 132 | ''' 133 | def faceparsing_ndarray(input, checkpoint, cuda=False): 134 | # 构建BiSeNet并加载模型 135 | bsnet = BiSeNet(n_classes=19) 136 | if cuda: 137 | bsnet.cuda() 138 | bsnet.load_state_dict(torch.load(checkpoint)) 139 | else: 140 | bsnet.load_state_dict(torch.load(checkpoint, map_location="cpu")) 141 | bsnet.eval() 142 | 143 | to_tensor = transforms.Compose( 144 | [ 145 | 
127 | '''
128 | evaluate with a numpy array
129 | :param input: numpy array; must be np.uint8 (not np.float32), layout [H, W, C]
130 | :param checkpoint: args.parsing_checkpoint, str
131 | :param cuda: use the GPU to speed up inference
132 | '''
133 | def faceparsing_ndarray(input, checkpoint, cuda=False):
134 |     # build BiSeNet and load the face-parsing checkpoint
135 |     bsnet = BiSeNet(n_classes=19)
136 |     if cuda:
137 |         bsnet.cuda()
138 |         bsnet.load_state_dict(torch.load(checkpoint))
139 |     else:
140 |         bsnet.load_state_dict(torch.load(checkpoint, map_location="cpu"))
141 |     bsnet.eval()
142 | 
143 |     to_tensor = transforms.Compose(
144 |         [
145 |             transforms.ToTensor(),  # [H, W, C] -> [C, H, W]
146 |             # transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
147 |         ])
148 | 
149 |     # input_ = _to_tensor_(input)
150 |     input_ = to_tensor(input)
151 |     input_ = torch.unsqueeze(input_, 0)
152 | 
153 |     if cuda:
154 |         input_ = input_.cuda()
155 |     out = bsnet(input_)
156 |     parsing = out.squeeze(0).cpu().detach().numpy().argmax(0)
157 |     return vis_parsing_maps(input, parsing, stride=1)
158 | 
159 | '''
160 | Visualize the parsing result
161 | '''
162 | def vis_parsing_maps(im, parsing, stride):
163 |     """
164 |     # palette that shows every part
165 |     part_colors = [[255, 0, 0], [255, 85, 0], [255, 170, 0], [255, 0, 85], [255, 0, 170], [0, 255, 0], [85, 255, 0],
166 |                    [170, 255, 0], [0, 255, 85], [0, 255, 170], [0, 0, 255], [85, 0, 255], [170, 0, 255], [0, 85, 255],
167 |                    [0, 170, 255], [255, 255, 0], [255, 255, 85], [255, 255, 170], [255, 0, 255], [255, 85, 255],
168 |                    [255, 170, 255], [0, 255, 255], [85, 255, 255], [170, 255, 255]]
169 |     """
170 |     # only show face, nose, eyes, eyebrows and mouth
171 |     part_colors = [[255, 255, 255], [255, 85, 0], [25, 170, 0], [255, 170, 0], [254, 0, 170], [254, 0, 170],
172 |                    [255, 255, 255],
173 |                    [255, 255, 255], [255, 255, 255], [255, 255, 255], [0, 0, 254], [85, 0, 255], [170, 0, 255],
174 |                    [0, 85, 255],
175 |                    [255, 255, 255], [255, 255, 255], [255, 255, 255], [255, 255, 255], [255, 255, 255], [255, 255, 255],
176 |                    [255, 255, 255], [255, 255, 255], [255, 255, 255], [255, 255, 255]]
177 |     """
178 |     part_colors = [[255, 255, 255], [face], [left eyebrow], [right eyebrow], [left eye], [right eye],
179 |                    [255, 255, 255],
180 |                    [left ear], [right ear], [255, 255, 255], [nose], [teeth], [upper lip],
181 |                    [lower lip],
182 |                    [255, 255, 255], [255, 255, 255], [255, 255, 255], [255, 255, 255], [255, 255, 255], [255, 255, 255],
183 |                    [255, 255, 255], [255, 255, 255], [255, 255, 255], [255, 255, 255]]
184 |     """
185 | 
186 |     im = np.array(im)
187 |     vis_parsing = parsing.copy().astype(np.uint8)
188 |     vis_parsing_anno_color = np.zeros((vis_parsing.shape[0], vis_parsing.shape[1], 3)) + 255
189 |     num_of_class = np.max(vis_parsing)
190 |     for pi in range(1, num_of_class + 1):
191 |         index = np.where(vis_parsing == pi)
192 |         vis_parsing_anno_color[index[0], index[1], :] = part_colors[pi]
193 | 
194 |     vis_parsing_anno_color = vis_parsing_anno_color.astype(np.uint8)
195 |     return vis_parsing_anno_color
196 | 
197 | '''
198 | Discriminative loss from the paper: decides whether the real photo and the image produced by the imitator belong to the same identity.
199 | The Discriminative Loss uses cosine distance,
200 | see https://www.cnblogs.com/dsgcBlogs/p/8619566.html
201 | :param lightcnn_inst: lightcnn model instance
202 | :param img1: generated by the game engine, type: list of Tensor
203 | :param img2: generated by the imitator, type: list of Tensor
204 | :return: scalar tensor, the cosine similarity
205 | '''
206 | def discriminative_loss(img1, img2, lightcnn_inst):
207 |     x1 = batch_feature256(img1, lightcnn_inst)  # [1, 256]
208 |     x2 = batch_feature256(img2, lightcnn_inst)  # [1, 256]
209 |     distance = torch.cosine_similarity(x1, x2)
210 |     return torch.mean(distance)
211 | 
212 | def batch_feature256(img, lightcnn_inst):
213 |     """
214 |     Extract a 256-dimensional feature vector with LightCNN
215 |     :param lightcnn_inst: lightcnn model instance
216 |     :param img: input image tensor, shape: (batch, 1, 512, 512)
217 |     :return: 256-d feature tensor [batch, 256]
218 |     """
219 |     _, features = lightcnn_inst(img)
220 |     # log.debug("features shape:{0} {1} {2}".format(features.size(), features.requires_grad, img.requires_grad))
221 |     return features
222 | 
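# ---------------------------------------------------------------------------
# Illustrative sketch, not part of the original utils.py: discriminative_loss
# above returns a cosine *similarity* in [-1, 1], so one common way to turn it
# into a loss term is to minimize (1 - similarity). The helper name and the
# tensor shapes in the comment are assumptions made only for this example.
def _demo_identity_loss(photo_gray, render_gray, lightcnn_inst):
    # photo_gray / render_gray: torch tensors shaped [1, 1, 512, 512], values in [0, 1]
    similarity = discriminative_loss(photo_gray, render_gray, lightcnn_inst)
    return 1.0 - similarity  # smaller value -> the two faces look more like the same identity
# ---------------------------------------------------------------------------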
223 | '''
224 | Capture an evaluation snapshot
225 | :param x: face parameters (with grad), torch tensor [b, params_cnt]
226 | :param refer: reference picture: [3, 512, 512]
227 | :param step: train step
228 | '''
229 | def eval_output(imitator, x, refer, step, prev_path, L2_c):
230 |     eval_write(x)
231 |     y_ = imitator(x)
232 |     y_ = y_.cpu().detach().numpy()
233 |     y_ = np.squeeze(y_, axis=0)
234 |     y_ = np.swapaxes(y_, 0, 2) * 255
235 |     y_ = y_.astype(np.uint8)  # [512, 512, 3]
236 |     im1 = L2_c[0]  # [512, 512]
237 |     im2 = L2_c[1]  # [512, 512]
238 |     # np_im1 = im1.cpu().detach().numpy()
239 |     # np_im2 = im2.cpu().detach().numpy()
240 |     # f_im1 = fill_gray(np_im1)  # [512, 512, 3], grayscale
241 |     # f_im2 = fill_gray(np_im2)  # [512, 512, 3]
242 |     f_im1 = im1  # [512, 512, 3], show the original image directly here
243 |     f_im2 = im2  # [512, 512, 3]
244 | 
245 |     # convert refer to channel-last
246 |     refer = np.transpose(refer, [1, 2, 0])  # [512, 512, 3]
247 |     # print("f_im1:", type(f_im1), f_im1.shape)
248 |     image_ = merge_4image(refer, y_, f_im1, f_im2, transpose=False)
249 |     path = os.path.join(prev_path, "eval_{0}.jpg".format(step))
250 |     cv2.imwrite(path, image_)
251 | 
252 | '''
253 | Write a binary file that can be restored inside Unity
254 | :param params: face-customization parameters, tensor [batch, params_cnt]
255 | '''
256 | def eval_write(params):
257 |     np_param = params.cpu().detach().numpy()
258 |     np_param = np_param[0]
259 |     list_param = np_param.tolist()
260 |     dataset = config.train_set
261 |     shape = curr_roleshape(dataset)
262 |     path = os.path.join(config.model_path, "eval.bytes")
263 |     f = open(path, 'wb')
264 |     write_layer(f, shape, list_param)
265 |     f.close()
266 | 
267 | '''
268 | Decide which roleshape (the C# RoleShape enum) the current run uses
269 | :param dataset: args path_to_dataset
270 | :return: RoleShape
271 | '''
272 | def curr_roleshape(dataset):
273 |     if dataset.find("female") >= 0:
274 |         return 4
275 |     else:
276 |         return 3
277 | 
278 | def write_layer(f, shape, args):
279 |     f.write(struct.pack('i', shape))
280 |     for it in args:
281 |         byte = struct.pack('f', it)
282 |         f.write(byte)
283 | 
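# ---------------------------------------------------------------------------
# Illustrative sketch, not part of the original utils.py: write_layer above
# emits one int32 (the RoleShape) followed by one float32 per parameter, so
# eval.bytes can be read back for debugging as shown below. The helper name is
# an assumption made only for this example.
def _demo_read_eval_bytes(path):
    with open(path, 'rb') as f:
        data = f.read()
    shape = struct.unpack('i', data[:4])[0]                 # RoleShape written first
    count = (len(data) - 4) // 4                            # the rest is float32 parameters
    params = struct.unpack('{0}f'.format(count), data[4:])
    return shape, list(params)
# ---------------------------------------------------------------------------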
284 | '''
285 | One-hot argmax post-processing
286 | :param params: the parameters to process
287 | :param start: start offset of the one-hot segment
288 | :param count: length of the one-hot segment
289 | '''
290 | def argmax_params(params, start, count):
291 |     dims = params.size()[0]
292 |     for dim in range(dims):
293 |         tmp = params[dim, start]
294 |         mx = start
295 |         for idx in range(start + 1, start + count):
296 |             if params[dim, idx] > tmp:
297 |                 mx = idx
298 |                 tmp = params[dim, idx]
299 |         for idx in range(start, start + count):
300 |             params[dim, idx] = 1. if idx == mx else 0
301 | 
302 | # plot loss
303 | def eval_plot(losses):
304 |     count = len(losses)
305 |     if count > 0:
306 |         plt.style.use('seaborn-whitegrid')
307 |         x = range(count)
308 |         y1 = []
309 |         y2 = []
310 |         for it in losses:
311 |             y1.append(it[0])
312 |             y2.append(it[1])
313 |         plt.plot(x, y1, color='r', label='1-L1')
314 |         plt.plot(x, y2, color='g', label='L2')
315 |         plt.ylabel("loss")
316 |         plt.xlabel('step')
317 |         plt.legend()
318 |         path = os.path.join(config.prev_path, "loss.png")
319 |         plt.savefig(path)
320 |         plt.close('all')
321 | 
322 | '''
323 | Detect 68 face landmarks with dlib
324 | :param img: BGR 3-channel image
325 | '''
326 | def detect_face_keypoint(img):
327 |     # convert to grayscale
328 |     img_gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
329 |     rects = config.detector(img_gray, 0)
330 |     for i in range(len(rects)):
331 |         landmarks = np.matrix([[p.x, p.y] for p in config.predictor(img, rects[i]).parts()])
332 |         for idx, point in enumerate(landmarks):
333 |             pos = (point[0, 0], point[0, 1])
334 |             print(idx, pos)
335 | 
336 |             cv2.circle(img, pos, 5, color=(0, 255, 0))
337 |             # label the points 1-68 with cv2.putText
338 |             font = cv2.FONT_HERSHEY_SIMPLEX
339 |             cv2.putText(img, str(idx + 1), pos, font, 0.8, (0, 0, 255), 1, cv2.LINE_AA)
340 |     cv2.imshow("img", img)
341 |     cv2.waitKey(0)
342 | 
343 | '''
344 | Custom ArgMax that works around argmax() not being differentiable
345 | Principle: argmax = softmax + c, with c a constant, so the gradient is passed straight through
346 | '''
347 | class ArgMax(torch.autograd.Function):
348 |     @staticmethod
349 |     def forward(ctx, input):
350 |         idx = torch.argmax(input, dim=1, keepdim=True)
351 |         output = torch.zeros_like(input)
352 |         output.scatter_(1, idx, 1.)
353 |         return output
354 | 
355 |     @staticmethod
356 |     def backward(ctx, grad_output):
357 |         return grad_output
358 | 
--------------------------------------------------------------------------------
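A minimal usage sketch for the ArgMax function defined above (not part of the repository); it assumes the parameters form a [batch, count] tensor and that gradients should pass through unchanged, i.e. a straight-through estimator:

    import torch
    # from utils import ArgMax          # assuming this module is importable as utils
    params = torch.randn(2, 4, requires_grad=True)
    one_hot = ArgMax.apply(params)      # forward: hard one-hot along dim 1
    one_hot.sum().backward()            # backward: gradients pass straight through unchanged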