├── .gitignore
├── DCGAN.py
├── README.md
├── backbone
│   ├── mobilenet.py
│   └── resnet.py
├── checkpoint
│   ├── imitator_gitopen.pth
│   └── myimitator-512.json
├── config.py
├── dat
│   ├── 103.jpg
│   ├── 110.jpg
│   ├── 16.jpg
│   └── avg_face.jpg
├── dataset.py
├── demo.py
├── evaluate_adam.py
├── evaluate_sgd_L1Loss.py
├── evaluate_sgd_cross_entropy.py
├── face_align.py
├── face_parser.py
├── faceparse.py
├── faceswap.py
├── imitator.py
├── lightcnn.py
├── model_process.py
├── myimitator.py
├── papers
│   ├── 1909.01064v1(Face-to-Parameter Translation ).pdf
│   ├── 2003.05653(Towards High-Fidelity 3D Face Reconstruction).pdf
│   └── 2008.07132v1(Fast and Robust Face-to-Parameter Translation).pdf
├── random_gen_image.py
├── resnet.py
├── tools
│   ├── demo.py
│   ├── face_align.py
│   ├── fid.py
│   └── lsgan.py
├── train_myimitator.py
├── train_translator.py
├── translator.py
└── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | #From https://github.com/github/gitignore/blob/master/Python.gitignore
2 |
3 | # Byte-compiled / optimized / DLL files
4 | */__pycache__/
5 | *.py[cod]
6 | *$py.class
7 |
8 | # C extensions
9 | *.so
10 |
11 | # Distribution / packaging
12 | .Python
13 | develop-eggs/
14 | dist/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | .hypothesis/
51 | .pytest_cache/
52 |
53 | # Translations
54 | *.mo
55 | *.pot
56 |
57 | # Django stuff:
58 | *.log
59 | local_settings.py
60 | db.sqlite3
61 |
62 | # Flask stuff:
63 | instance/
64 | .webassets-cache
65 |
66 | # Scrapy stuff:
67 | .scrapy
68 |
69 | # Sphinx documentation
70 | docs/_build/
71 |
72 | # PyBuilder
73 | target/
74 |
75 | # Jupyter Notebook
76 | .ipynb_checkpoints
77 |
78 | # IPython
79 | profile_default/
80 | ipython_config.py
81 |
82 | # pyenv
83 | .python-version
84 |
85 | # pipenv
86 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
87 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
88 | # having no cross-platform support, pipenv may install dependencies that don’t work, or not
89 | # install all needed dependencies.
90 | #Pipfile.lock
91 |
92 | # celery beat schedule file
93 | celerybeat-schedule
94 |
95 | # SageMath parsed files
96 | *.sage.py
97 |
98 | # Environments
99 | .env
100 | .venv
101 | env/
102 | venv/
103 | ENV/
104 | env.bak/
105 | venv.bak/
106 |
107 | # Spyder project settings
108 | .spyderproject
109 | .spyproject
110 | .project
111 |
112 | # Rope project settings
113 | .ropeproject
114 |
115 | # mkdocs documentation
116 | /site
117 |
118 | # mypy
119 | .mypy_cache/
120 | .dmypy.json
121 | dmypy.json
122 |
123 | # Pyre type checker
124 | .pyre/
125 |
126 | .py~
127 | .gitignore~
128 |
129 | .vscode/
130 | .idea/
131 | *.ipynb
132 |
133 |
134 | logs/
135 | output/
136 | face_data/
137 | checkpoint/shape_predictor_68_face_landmarks.dat
138 | checkpoint/79999_iter.pth
139 | checkpoint/resnet18-5c106cde.pth
140 | checkpoint/LightCNN_29Layers_V2_checkpoint.pth.tar
141 | checkpoint/rtnet50-fcn-14.torch
142 | checkpoint/epoch_130_0.666491.pt
143 | checkpoint/epoch_340_0.434396.pt
144 | checkpoint/faceseg_65_0.030560_0.025751_0.827335-7.pth
145 | checkpoint/faceseg_179_0.050777_0.065476_0.842724_withface7.pth
--------------------------------------------------------------------------------
/DCGAN.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | import argparse
3 | import os
4 | import random
5 | from PIL import Image
6 | import torch
7 | import torch.nn as nn
8 | import torch.nn.parallel
9 | from torch.utils.data import DataLoader, Dataset
10 | import torch.backends.cudnn as cudnn
11 | import torch.optim as optim
12 | import torch.utils.data
13 | import torchvision
14 | import torchvision.transforms as transforms
15 | import torchvision.utils as vutils
16 |
17 | '''
18 | Based on the official PyTorch DCGAN implementation:
19 | https://github.com/pytorch/examples/blob/master/dcgan/main.py
20 | '''
21 |
22 | parser = argparse.ArgumentParser()
23 | parser.add_argument('--dataroot', default='F:/BaiduNetdiskDownload/faces/', required=False, help='path to dataset')
24 | parser.add_argument('--workers', type=int, help='number of data loading workers', default=0)
25 | parser.add_argument('--batchSize', type=int, default=16, help='input batch size')
26 | parser.add_argument('--imageSize', type=int, default=64, help='the height / width of the input image to network')
27 | parser.add_argument('--nz', type=int, default=100, help='size of the latent z vector')
28 | parser.add_argument('--ngf', type=int, default=64)
29 | parser.add_argument('--ndf', type=int, default=64)
30 | parser.add_argument('--epochs', type=int, default=300, help='number of epochs to train for')
31 | parser.add_argument('--lr', type=float, default=0.0002, help='learning rate, default=0.0002')
32 | parser.add_argument('--beta1', type=float, default=0.5, help='beta1 for adam. default=0.5')
33 | parser.add_argument('--cuda', action='store_true', help='enables cuda')
34 | parser.add_argument('--dry-run', action='store_true', help='check a single training cycle works')
35 | parser.add_argument('--ngpu', type=int, default=1, help='number of GPUs to use')
36 | parser.add_argument('--netG', default='', help="path to netG (to continue training)")
37 | parser.add_argument('--netD', default='', help="path to netD (to continue training)")
38 | parser.add_argument('--outf', default='./output/', help='folder to output images and model checkpoints')
39 | parser.add_argument('--manualSeed', type=int, help='manual seed')
40 |
41 | opt = parser.parse_args()
42 | print(opt)
43 |
44 | try:
45 | os.makedirs(opt.outf)
46 | except OSError:
47 | pass
48 |
49 | if opt.manualSeed is None:
50 | opt.manualSeed = random.randint(1, 10000)
51 | print("Random Seed: ", opt.manualSeed)
52 | random.seed(opt.manualSeed)
53 | torch.manual_seed(opt.manualSeed)
54 |
55 | cudnn.benchmark = True
56 |
57 | if torch.cuda.is_available() and not opt.cuda:
58 | print("WARNING: You have a CUDA device, so you should probably run with --cuda")
59 |
60 | if opt.dataroot is None:
61 | raise ValueError("`dataroot` parameter is required ")
62 |
63 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
64 | print(torch.cuda.is_available())
65 |
66 | # Image preprocessing
67 | transform = transforms.Compose([
68 | transforms.Resize(opt.imageSize),
69 | transforms.CenterCrop(opt.imageSize),
70 | transforms.ToTensor(),
71 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
72 | ])
73 |
74 | # Custom Dataset
75 | class DCGAN_Dataset(Dataset):
76 | def __init__(self, data_root, transform):
77 | self.data_root = data_root
78 | self.transform = transform
79 | self.fileList = []
80 | for file in os.listdir(self.data_root):
81 | self.fileList.append(os.path.join(self.data_root, file))
82 |
83 | def __getitem__(self, index):
84 | file = self.fileList[index]
85 | img = Image.open(os.path.join(self.data_root, file)).convert("RGB")
86 |
87 | return self.transform(img)
88 |
89 | def __len__(self):
90 | return len(self.fileList)
91 |
92 | dataset = DCGAN_Dataset(opt.dataroot, transform)
93 |
94 | dataloader = torch.utils.data.DataLoader(
95 | dataset=dataset,
96 | batch_size=opt.batchSize,
97 | shuffle=True,
98 | num_workers=0
99 | )
100 |
101 | device = torch.device("cuda:0" if opt.cuda else "cpu")
102 | ngpu = int(opt.ngpu)
103 | nz = int(opt.nz)
104 | ngf = int(opt.ngf)
105 | ndf = int(opt.ndf)
106 |
107 | # custom weights initialization called on netG and netD
108 | def weights_init(m):
109 | classname = m.__class__.__name__
110 | if classname.find('Conv') != -1:
111 | torch.nn.init.normal_(m.weight, 0.0, 0.02)
112 | elif classname.find('BatchNorm') != -1:
113 | torch.nn.init.normal_(m.weight, 1.0, 0.02)
114 | torch.nn.init.zeros_(m.bias)
115 |
116 |
117 | class Generator(nn.Module):
118 | def __init__(self, ngpu):
119 | super(Generator, self).__init__()
120 | self.ngpu = ngpu
121 | self.main = nn.Sequential(
122 | # input is Z, going into a convolution
123 | nn.ConvTranspose2d(nz, ngf * 8, 4, 1, 0, bias=False),
124 | nn.BatchNorm2d(ngf * 8),
125 | nn.ReLU(True),
126 | # state size. (ngf*8) x 4 x 4
127 | nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
128 | nn.BatchNorm2d(ngf * 4),
129 | nn.ReLU(True),
130 | # state size. (ngf*4) x 8 x 8
131 | nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
132 | nn.BatchNorm2d(ngf * 2),
133 | nn.ReLU(True),
134 | # state size. (ngf*2) x 16 x 16
135 | nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),
136 | nn.BatchNorm2d(ngf),
137 | nn.ReLU(True),
138 | # state size. (ngf) x 32 x 32
139 | nn.ConvTranspose2d(ngf, 3, 4, 2, 1, bias=False),
140 | nn.Tanh()
141 | # state size. (nc) x 64 x 64
142 | )
143 |
144 | def forward(self, input):
145 | if input.is_cuda and self.ngpu > 1:
146 | output = nn.parallel.data_parallel(self.main, input, range(self.ngpu))
147 | else:
148 | output = self.main(input)
149 | return output
150 |
151 |
152 | netG = Generator(ngpu).to(device)
153 | netG.apply(weights_init)
154 | if opt.netG != '':
155 | netG.load_state_dict(torch.load(opt.netG))
156 | print(netG)
157 |
158 |
159 | class Discriminator(nn.Module):
160 | def __init__(self, ngpu):
161 | super(Discriminator, self).__init__()
162 | self.ngpu = ngpu
163 | self.main = nn.Sequential(
164 | # input is (nc) x 64 x 64
165 | nn.Conv2d(3, ndf, 4, 2, 1, bias=False),
166 | nn.LeakyReLU(0.2, inplace=True),
167 | # state size. (ndf) x 32 x 32
168 | nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
169 | nn.BatchNorm2d(ndf * 2),
170 | nn.LeakyReLU(0.2, inplace=True),
171 | # state size. (ndf*2) x 16 x 16
172 | nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
173 | nn.BatchNorm2d(ndf * 4),
174 | nn.LeakyReLU(0.2, inplace=True),
175 | # state size. (ndf*4) x 8 x 8
176 | nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
177 | nn.BatchNorm2d(ndf * 8),
178 | nn.LeakyReLU(0.2, inplace=True),
179 | # state size. (ndf*8) x 4 x 4
180 | nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),
181 | nn.Sigmoid()
182 | )
183 |
184 | def forward(self, input):
185 | if input.is_cuda and self.ngpu > 1:
186 | output = nn.parallel.data_parallel(self.main, input, range(self.ngpu))
187 | else:
188 | output = self.main(input)
189 |
190 | return output.view(-1, 1).squeeze(1)
191 |
192 |
193 | netD = Discriminator(ngpu).to(device)
194 | netD.apply(weights_init)
195 | if opt.netD != '':
196 | netD.load_state_dict(torch.load(opt.netD))
197 | print(netD)
198 |
199 | criterion = nn.BCELoss()
200 |
201 | fixed_noise = torch.randn(opt.batchSize, nz, 1, 1, device=device)
202 | real_label = 1
203 | fake_label = 0
204 |
205 | # setup optimizer
206 | optimizerD = optim.Adam(netD.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
207 | optimizerG = optim.Adam(netG.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
208 |
209 | if opt.dry_run:
210 | opt.epochs = 1
211 |
212 | for epoch in range(opt.epochs):
213 | for i, data in enumerate(dataloader, 0):
214 | ############################
215 | # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
216 | ###########################
217 | # train with real
218 | netD.zero_grad()
219 | batch_size = data.shape[0]
220 |         label = torch.full((batch_size,), real_label, dtype=torch.float32, device=device) # labels are 1 here (real)
221 |
222 | output = netD(data)
223 |         errD_real = criterion(output, label) # discriminator loss on real samples: the closer output is to 1, the smaller the loss
224 | errD_real.backward()
225 | D_x = output.mean().item()
226 |
227 | # train with fake
228 | noise = torch.randn(batch_size, nz, 1, 1, device=device)
229 | fake = netG(noise)
230 |         label.fill_(fake_label) # labels are 0 here (fake)
231 |         output = netD(fake.detach()) # the closer output is to 0, the smaller the loss
232 |         errD_fake = criterion(output, label) # discriminator loss on fake samples: the more confidently fakes are judged fake, the smaller the loss
233 |         errD_fake.backward()
234 |         D_G_z1 = output.mean().item() # from the discriminator's point of view, smaller is better
235 | errD = errD_real + errD_fake
236 | optimizerD.step()
237 |
238 | ############################
239 | # (2) Update G network: maximize log(D(G(z)))
240 | ###########################
241 | netG.zero_grad()
242 |         label.fill_(real_label) # labels are 1 here (real) for the generator update
243 |         output = netD(fake) # the discriminator tries to judge generated samples as fake
244 |         errG = criterion(output, label) # generator loss: the smaller errG is, the more realistic the generated images; it compares the discriminator output on fakes against the "real" label
245 |         errG.backward()
246 |         D_G_z2 = output.mean().item() # from the discriminator's point of view, smaller is better
247 | optimizerG.step()
248 |
249 | print('[%d/%d][%d/%d] Loss_D: %.4f Loss_G: %.4f D(x): %.4f D(G(z)): %.4f / %.4f'
250 | % (epoch, opt.epochs, i, len(dataloader),
251 | errD.item(), errG.item(), D_x, D_G_z1, D_G_z2))
252 | if i % 1 == 0:
253 | vutils.save_image(data,
254 | '%s/real_samples_%d.png' % (opt.outf, epoch),
255 | normalize=True)
256 | fake = netG(fixed_noise)
257 | vutils.save_image(fake.detach(),
258 | '%s/fake_samples_epoch_%03d.png' % (opt.outf, epoch),
259 | normalize=True)
260 |
261 | if opt.dry_run:
262 | break
263 | # do checkpointing
264 | torch.save(netG.state_dict(), '%s/netG_epoch_%d.pth' % (opt.outf, epoch))
265 | torch.save(netD.state_dict(), '%s/netD_epoch_%d.pth' % (opt.outf, epoch))
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # 3D Face Customization
2 |
3 | A reproduction of NetEase's paper "Face-to-Parameter Translation for Game Character Auto-Creation"
4 |
5 | ## 1. Network architecture
6 |
7 | G: the Imitator network
8 |
9 | Input: 208-d continuous parameters (normalized to 0~1) plus 102-d discrete parameters;
10 |
11 | Output: a frontal image of the game character (512x512x3);
12 |
13 | SGD, batch_size=16, momentum=0.9, learning_rate=0.01 decayed by 10% every 50 epochs, max_epochs=500;
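
A minimal training sketch of this setup, reusing imitator.py, dataset.py and config.py from this repo (the pixel-wise L1 reconstruction objective here is only a placeholder; the repo's actual training script is train_myimitator.py):

```python
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader

import config
from dataset import Imitator_Dataset
from imitator import Imitator

imitator = Imitator()
dataloader = DataLoader(Imitator_Dataset(config.params_root, config.image_root),
                        batch_size=16, shuffle=True)
optimizer = torch.optim.SGD(imitator.parameters(), lr=0.01, momentum=0.9)
# decay the learning rate by 10% every 50 epochs
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=50, gamma=0.9)

for epoch in range(500):                      # max_epochs=500
    for params, target_img in dataloader:     # batch_size=16
        optimizer.zero_grad()
        # pixel-wise L1 reconstruction loss as a stand-in objective
        loss = F.l1_loss(imitator(params), target_img)
        loss.backward()
        optimizer.step()
    scheduler.step()
```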
14 |
15 | F1: LightCNN, used for face recognition to make sure the generated face and the input are the same person; a cosine loss is computed;
16 |
17 | Input: 128x128x1 grayscale image;
18 |
19 | Output: 256-d feature vector;
20 |
21 | F2: face parsing, i.e. facial semantic segmentation (any network works: DeepLab, BiSeNet, RTNet, etc.); a weighted L1 loss is computed. (L1 rather than L2 because the sparsity of L1 emphasizes the facial features, whereas L2 smooths them out.)
22 |
23 | Input: 256x256x3 RGB image;
24 |
25 | Output: semantic segmentation result;
26 |
27 | The semantic probability maps strengthen the facial features, with different weights used to emphasize the brows, eyes, nose and lips;
28 |
29 | A cross-entropy loss can be used instead, measuring the distance between the per-pixel class probability distributions;
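
A rough sketch of how the two terms combine into the overall objective Ls = alpha * L1 + L2 (here `lightcnn` is assumed to return the 256-d embedding directly and `parser` the per-class probability maps; the repo's evaluate_* scripts implement this via utils.discriminative_loss and faceparsing_tensor):

```python
import torch.nn.functional as F

def imitation_loss(gen_img, real_img, lightcnn, parser, alpha=0.1, class_weights=None):
    """Sketch of Ls = alpha * L1 + L2; lightcnn, parser and class_weights are assumed inputs."""
    def to_gray128(img):
        # 512x512 RGB -> 128x128 single-channel gray, the input LightCNN expects
        img = F.interpolate(img, size=(128, 128), mode='bilinear', align_corners=False)
        return img.mean(dim=1, keepdim=True)

    # L1: identity term, 1 - cosine similarity of the 256-d face embeddings
    id_loss = 1 - F.cosine_similarity(lightcnn(to_gray128(gen_img)),
                                      lightcnn(to_gray128(real_img)), dim=1).mean()

    # L2: (optionally weighted) L1 distance between the semantic probability maps [B, C, H, W]
    p_gen, p_real = parser(gen_img), parser(real_img)
    if class_weights is not None:  # per-class weights emphasizing brows/eyes/nose/lips
        p_gen, p_real = p_gen * class_weights, p_real * class_weights
    content_loss = F.l1_loss(p_gen, p_real)

    return alpha * id_loss + content_loss
```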
30 |
31 | ## 2. Data preprocessing
32 |
33 | Face alignment: dlib
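
A minimal sketch of the dlib landmark detection this step relies on (detector, predictor and paths follow config.py; the actual warping onto a template face is implemented in face_align.py):

```python
import cv2
import dlib

detector = dlib.get_frontal_face_detector()
predictor = dlib.shape_predictor('./checkpoint/shape_predictor_68_face_landmarks.dat')

img = cv2.imread('./dat/16.jpg')
faces = detector(img, 1)                               # detect faces
if len(faces) > 0:
    shape = predictor(img, faces[0])                   # 68 landmarks of the first face
    landmarks = [(p.x, p.y) for p in shape.parts()]    # points used for alignment
```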
34 |
35 | ## 3. Limitations
36 |
37 | (1) Low robustness: sensitive to face pose and occlusion;
38 |
39 | (2) Solving for the discrete parameters is not handled effectively;
40 |
41 | (3) Many optimization iterations are needed: the gradient-descent-based procedure is computationally inefficient;
42 |
43 | (4) Only suitable for realistic, human-style game characters;
44 |
45 | ## 4. Improvements
46 |
47 | The Imitator is based on the DCGAN generator; its structure is simple and its representational capacity limited, so the generated images are blurry and the facial features are badly distorted.
48 |
49 | Possible solutions:
50 |
51 | 1. Use a DCGAN-style discriminator to learn the mapping from image to parameters, replacing the gradient-descent optimization;
52 |
53 | 2. Following the second paper, "Fast and Robust Face-to-Parameter Translation for Game Character Auto-Creation", add a T (translator) network that learns the mapping from a face-recognition embedding to the parameters, as sketched below.
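
A hedged sketch of what such a translator could look like: a small MLP mapping the 256-d LightCNN embedding to the continuous parameters (layer sizes are illustrative and not taken from the paper or from translator.py):

```python
import torch.nn as nn

class TranslatorSketch(nn.Module):
    """Maps a 256-d face-recognition embedding to the continuous face parameters."""
    def __init__(self, embed_dim=256, params_dim=159):  # params_dim matches config.continuous_params_size
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(embed_dim, 512),
            nn.ReLU(inplace=True),
            nn.Linear(512, 512),
            nn.ReLU(inplace=True),
            nn.Linear(512, params_dim),
            nn.Sigmoid(),  # parameters are normalized to [0, 1]
        )

    def forward(self, embedding):
        return self.net(embedding)
```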
54 |
--------------------------------------------------------------------------------
/backbone/mobilenet.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch.nn.functional as F
3 |
4 |
5 | '''
6 | backbone: mobilenetV2
7 | '''
8 | class MobileNetV2(nn.Module):
9 | '''
10 |     :param num_classes: number of classes
11 |     :param output_stride:
12 |     :param width_mult: controls the number of channels in each layer
13 |     :param inverted_residual_setting:
14 |     :param round_nearest: round the number of channels in each layer to a multiple of this value; set to 1 to disable rounding
15 | '''
16 | def __init__(self, num_classes=1000, output_stride=8, width_mult=1.0, inverted_residual_setting=None, round_nearest=8):
17 | super(MobileNetV2, self).__init__()
18 | block = InvertedResidual
19 | input_channel = 32
20 | last_channel = 1280
21 | self.output_stride = output_stride
22 | current_stride = 1
23 | if inverted_residual_setting is None:
24 | inverted_residual_setting = [
25 | # t, c, n, s
26 | [1, 16, 1, 1],
27 | [6, 24, 2, 2],
28 | [6, 32, 3, 2],
29 | [6, 64, 4, 2],
30 | [6, 96, 3, 1],
31 | [6, 160, 3, 2],
32 | [6, 320, 1, 1],
33 | ]
34 |
35 |         # Build the first layer
36 | input_channel = _make_divisible(input_channel * width_mult, round_nearest)
37 | self.last_channel = _make_divisible(last_channel * max(1.0, width_mult), round_nearest)
38 | features = [ConvBNReLU(3, input_channel, stride=2)]
39 | current_stride *= 2
40 | dilation = 1
41 | previous_dilation = 1
42 |
43 |         # Build the intermediate inverted residual blocks
44 | for t, c, n, s in inverted_residual_setting:
45 | output_channel = _make_divisible(c * width_mult, round_nearest)
46 | previous_dilation = dilation
47 | if current_stride == output_stride:
48 | stride = 1
49 | dilation *= s
50 | else:
51 | stride = s
52 | current_stride *= s
53 | output_channel = int(c * width_mult)
54 |
55 | for i in range(n):
56 | if i == 0:
57 | features.append(block(input_channel, output_channel, stride, previous_dilation, expand_ratio=t))
58 | else:
59 | features.append(block(input_channel, output_channel, 1, dilation, expand_ratio=t))
60 | input_channel = output_channel
61 |
62 |         # Build the last feature layer
63 | features.append(ConvBNReLU(input_channel, self.last_channel, kernel_size=1))
64 | self.features = nn.Sequential(*features)
65 |
66 |         # Build the final classifier
67 | self.classifier = nn.Sequential(
68 | nn.Dropout(0.2),
69 | nn.Linear(self.last_channel, num_classes)
70 | )
71 |
72 | # weight initialization
73 | for m in self.modules():
74 | if isinstance(m, nn.Conv2d):
75 | nn.init.kaiming_normal_(m.weight, mode='fan_out')
76 | if m.bias is not None:
77 | nn.init.zeros_(m.bias)
78 | elif isinstance(m, nn.BatchNorm2d):
79 | nn.init.ones_(m.weight)
80 | nn.init.zeros_(m.bias)
81 | elif isinstance(m, nn.Linear):
82 | nn.init.normal_(m.weight, 0, 0.01)
83 | nn.init.zeros_(m.bias)
84 |
85 | def forward(self, x):
86 | x = self.features(x)
87 | x = x.mean([2, 3])
88 | x = self.classifier(x)
89 | return x
90 |
91 | """
92 | This function is taken from the original tf repo.
93 | It ensures that all layers have a channel number that is divisible by 8
94 | It can be seen here: https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
95 | :param v:
96 | :param divisor:
97 | :param min_value:
98 | :return:
99 | """
100 | def _make_divisible(v, divisor, min_value=None):
101 | if min_value is None:
102 | min_value = divisor
103 | new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
104 |     if new_v < 0.9 * v: # make sure rounding down does not reduce the channel count by more than 10%
105 | new_v += divisor
106 | return new_v
107 |
108 | '''
109 | CBR block: Conv + BN + ReLU
110 | '''
111 | class ConvBNReLU(nn.Sequential):
112 | def __init__(self, input_channel, output_channel, kernel_size=3, stride=1, dilation=1, groups=1):
113 | super(ConvBNReLU, self).__init__(
114 | nn.Conv2d(input_channel, output_channel, kernel_size=kernel_size, stride=stride, dilation=dilation, groups=groups, bias=False),
115 | nn.BatchNorm2d(output_channel),
116 | nn.ReLU6(inplace=True)
117 | )
118 |
119 | '''
120 | Inverted residual block
121 | '''
122 | class InvertedResidual(nn.Module):
123 | def __init__(self, input_channel, output_channel, stride, dilation, expand_ratio):
124 | super(InvertedResidual, self).__init__()
125 | self.stride = stride
126 |
127 | hidden_dim = int(round(input_channel * expand_ratio))
128 |         self.use_residual_connect = self.stride == 1 and input_channel == output_channel # whether to use a residual connection
129 | layers = []
130 |
131 | if expand_ratio != 1:
132 | layers.append(ConvBNReLU(input_channel, hidden_dim, kernel_size=1))
133 |
134 | layers.extend([
135 |             ConvBNReLU(hidden_dim, hidden_dim, stride=stride, dilation=dilation, groups=hidden_dim), # depthwise convolution: each channel is convolved separately
136 | nn.Conv2d(hidden_dim, output_channel, 1, 1, 0, bias=False),
137 | nn.BatchNorm2d(output_channel)
138 | ])
139 | self.conv = nn.Sequential(*layers)
140 | self.input_padding = fix_padding(3, dilation)
141 |
142 | def forward(self, x):
143 | x_pad = F.pad(x, self.input_padding)
144 | if self.use_residual_connect:
145 | return x + self.conv(x_pad)
146 | else:
147 | return self.conv(x_pad)
148 |
149 |
150 | def fix_padding(kernel_size, dilation):
151 | kernel_size_effective = kernel_size + (kernel_size - 1) * (dilation - 1)
152 | pad_total = kernel_size_effective - 1
153 | pad_beg = pad_total // 2
154 | pad_end = pad_total - pad_beg
155 | return (pad_beg, pad_end, pad_beg, pad_end)
156 |
--------------------------------------------------------------------------------
/backbone/resnet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
5 | '''
6 | backbone: resnet
7 | '''
8 | class ResNet50(nn.Module):
9 | def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,
10 | groups=1, width_per_group=64, replace_stride_with_dilation=None,
11 | norm_layer=None):
12 | super(ResNet50, self).__init__()
13 | if norm_layer is None:
14 | norm_layer = nn.BatchNorm2d
15 | self._norm_layer = norm_layer
16 |
17 | self.inplanes = 64
18 | self.dilation = 1
19 | if replace_stride_with_dilation is None:
20 | # each element in the tuple indicates if we should replace
21 | # the 2x2 stride with a dilated convolution instead
22 | replace_stride_with_dilation = [False, False, False]
23 | if len(replace_stride_with_dilation) != 3:
24 | raise ValueError("replace_stride_with_dilation should be None "
25 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
26 | self.groups = groups
27 | self.base_width = width_per_group
28 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3,
29 | bias=False)
30 | self.bn1 = norm_layer(self.inplanes)
31 | self.relu = nn.ReLU(inplace=True)
32 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
33 | self.layer1 = self._make_layer(block, 64, layers[0])
34 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
35 | dilate=replace_stride_with_dilation[0])
36 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
37 | dilate=replace_stride_with_dilation[1])
38 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
39 | dilate=replace_stride_with_dilation[2])
40 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
41 | self.fc = nn.Linear(512 * block.expansion, num_classes)
42 |
43 | for m in self.modules():
44 | if isinstance(m, nn.Conv2d):
45 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
46 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
47 | nn.init.constant_(m.weight, 1)
48 | nn.init.constant_(m.bias, 0)
49 |
50 | # Zero-initialize the last BN in each residual branch,
51 | # so that the residual branch starts with zeros, and each residual block behaves like an identity.
52 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
53 | if zero_init_residual:
54 |             for m in self.modules():
55 |                 if isinstance(m, Bottleneck): nn.init.constant_(m.bn3.weight, 0)  # only Bottleneck blocks have bn3
56 |
57 | def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
58 | norm_layer = self._norm_layer
59 | downsample = None
60 | previous_dilation = self.dilation
61 | if dilate:
62 | self.dilation *= stride
63 | stride = 1
64 | if stride != 1 or self.inplanes != planes * block.expansion:
65 | downsample = nn.Sequential(
66 | conv1x1(self.inplanes, planes * block.expansion, stride),
67 | norm_layer(planes * block.expansion),
68 | )
69 |
70 | layers = []
71 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
72 | self.base_width, previous_dilation, norm_layer))
73 | self.inplanes = planes * block.expansion
74 | for _ in range(1, blocks):
75 | layers.append(block(self.inplanes, planes, groups=self.groups,
76 | base_width=self.base_width, dilation=self.dilation,
77 | norm_layer=norm_layer))
78 |
79 | return nn.Sequential(*layers)
80 |
81 | def forward(self, x):
82 | x = self.conv1(x)
83 | x = self.bn1(x)
84 | x = self.relu(x)
85 | x = self.maxpool(x)
86 |
87 | x = self.layer1(x)
88 | x = self.layer2(x)
89 | x = self.layer3(x)
90 | x = self.layer4(x)
91 |
92 | x = self.avgpool(x)
93 | x = torch.flatten(x, 1)
94 | x = self.fc(x)
95 |
96 | return x
97 |
98 | class Bottleneck(nn.Module):
99 | expansion = 4
100 |
101 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
102 | base_width=64, dilation=1, norm_layer=None):
103 | super(Bottleneck, self).__init__()
104 | if norm_layer is None:
105 | norm_layer = nn.BatchNorm2d
106 | width = int(planes * (base_width / 64.)) * groups
107 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1
108 | self.conv1 = conv1x1(inplanes, width)
109 | self.bn1 = norm_layer(width)
110 | self.conv2 = conv3x3(width, width, stride, groups, dilation)
111 | self.bn2 = norm_layer(width)
112 | self.conv3 = conv1x1(width, planes * self.expansion)
113 | self.bn3 = norm_layer(planes * self.expansion)
114 | self.relu = nn.ReLU(inplace=True)
115 | self.downsample = downsample
116 | self.stride = stride
117 |
118 | def forward(self, x):
119 | identity = x
120 |
121 | out = self.conv1(x)
122 | out = self.bn1(out)
123 | out = self.relu(out)
124 |
125 | out = self.conv2(out)
126 | out = self.bn2(out)
127 | out = self.relu(out)
128 |
129 | out = self.conv3(out)
130 | out = self.bn3(out)
131 |
132 | if self.downsample is not None:
133 | identity = self.downsample(x)
134 |
135 | out += identity
136 | out = self.relu(out)
137 |
138 | return out
139 |
140 | '''
141 | conv 3*3
142 | '''
143 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
144 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
145 | padding=dilation, groups=groups, bias=False, dilation=dilation)
146 |
147 | '''
148 | conv 1*1
149 | '''
150 | def conv1x1(in_planes, out_planes, stride=1):
151 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
152 |
--------------------------------------------------------------------------------
/checkpoint/imitator_gitopen.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/csnowhermit/face2parameter/c9c7f849e91a8b51c209b53c8b4d5c402ed791d6/checkpoint/imitator_gitopen.pth
--------------------------------------------------------------------------------
/checkpoint/myimitator-512.json:
--------------------------------------------------------------------------------
1 | {
2 | "attention_layer_position": 8,
3 | "channel_width": 128,
4 | "class_embed_dim": 128,
5 | "eps": 0.0001,
6 | "layers": [
7 | [
8 | false,
9 | 16,
10 | 16
11 | ],
12 | [
13 | true,
14 | 16,
15 | 16
16 | ],
17 | [
18 | false,
19 | 16,
20 | 16
21 | ],
22 | [
23 | true,
24 | 16,
25 | 8
26 | ],
27 | [
28 | false,
29 | 8,
30 | 8
31 | ],
32 | [
33 | true,
34 | 8,
35 | 8
36 | ],
37 | [
38 | false,
39 | 8,
40 | 8
41 | ],
42 | [
43 | true,
44 | 8,
45 | 4
46 | ],
47 | [
48 | false,
49 | 4,
50 | 4
51 | ],
52 | [
53 | true,
54 | 4,
55 | 2
56 | ],
57 | [
58 | false,
59 | 2,
60 | 2
61 | ],
62 | [
63 | true,
64 | 2,
65 | 1
66 | ],
67 | [
68 | false,
69 | 1,
70 | 1
71 | ],
72 | [
73 | true,
74 | 1,
75 | 1
76 | ]
77 | ],
78 | "n_stats": 51,
79 | "num_classes": 1000,
80 | "output_dim": 512,
81 | "z_dim": 128
82 | }
--------------------------------------------------------------------------------
/config.py:
--------------------------------------------------------------------------------
1 | import dlib
2 | import torch
3 |
4 | # General settings
5 | continuous_params_size = 159 # number of continuous parameters; the open-source GitHub version uses 95
6 | # image_root = "F:/dataset/face_simple/face/"
7 | # train_params_root = "F:/dataset/face_simple/train_param.json"
8 | # test_params_root = "F:/dataset/face_simple/test_param.json"
9 |
10 | # train_set = "./face_data/trainset_female"
11 | # test_set = "./face_data/testset_female"
12 |
13 |
14 | image_root = "F:/dataset/face_20211203_20000_nojiemao/"
15 | params_root = "F:/dataset/face_20211203_20000_nojiemao/param.json"
16 |
17 | use_gpu = False # whether to use the GPU
18 | num_gpu = 1 # number of GPUs
19 | device = torch.device('cuda:0') if torch.cuda.is_available() and use_gpu else torch.device('cpu')
20 | path_tensor_log = "./logs/"
21 |
22 |
23 | # Imitator settings
24 | total_epochs = 500
25 | batch_size = 16
26 | save_freq = 10
27 | prev_freq = 10
28 | learning_rate = 1
29 | # imitator_model = "./checkpoint/imitator.pth" # set to an empty string when not fine-tuning
30 | imitator_model = "./checkpoint/epoch_340_0.434396.pt"
31 | # imitator_model = ""
32 |
33 | prev_path = "./output/preview"
34 | # prev_path = "E:/nielian/"
35 | model_path = "./output/imitator"
36 |
37 | # Evaluation settings
38 | total_eval_steps = 50
39 | eval_alpha = 0.1 # Ls = alpha * L1 + L2
40 | eval_learning_rate = 1
41 | eval_prev_freq = 1
42 |
43 |
44 | # Face semantic segmentation
45 | faceparse_backbone = 'mobilenetv2'
46 | faceparse_checkpoint = "./checkpoint/faceseg_179_0.050777_0.065476_0.842724_withface7.pth"
47 | num_classes = 7
48 | output_stride = 16
49 | pretrained = True
50 | progress = True
51 | model_urls = "https://download.pytorch.org/models/mobilenet_v2-b0353104.pth" # mobilenetv2
52 | # model_urls = "https://download.pytorch.org/models/resnet50-19c8e357.pth" # resnet
53 |
54 | # light-cnn
55 | lightcnn_checkpoint = "./checkpoint/LightCNN_29Layers_V2_checkpoint.pth.tar"
56 |
57 | # Facial landmark detection and alignment
58 | detector = dlib.get_frontal_face_detector()
59 | predictor = dlib.shape_predictor('./checkpoint/shape_predictor_68_face_landmarks.dat')
60 |
61 | # Custom imitator
62 | config_jsonfile = "./checkpoint/myimitator-512.json"
63 |
64 |
65 |
--------------------------------------------------------------------------------
/dat/103.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/csnowhermit/face2parameter/c9c7f849e91a8b51c209b53c8b4d5c402ed791d6/dat/103.jpg
--------------------------------------------------------------------------------
/dat/110.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/csnowhermit/face2parameter/c9c7f849e91a8b51c209b53c8b4d5c402ed791d6/dat/110.jpg
--------------------------------------------------------------------------------
/dat/16.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/csnowhermit/face2parameter/c9c7f849e91a8b51c209b53c8b4d5c402ed791d6/dat/16.jpg
--------------------------------------------------------------------------------
/dat/avg_face.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/csnowhermit/face2parameter/c9c7f849e91a8b51c209b53c8b4d5c402ed791d6/dat/avg_face.jpg
--------------------------------------------------------------------------------
/dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import random
4 | import torch
5 | from PIL import Image
6 | from torchvision import transforms as T
7 | from torch.utils.data import DataLoader, Dataset
8 |
9 | import config
10 |
11 | '''
12 | Imitator Dataset
13 | '''
14 | class Imitator_Dataset(Dataset):
15 | def __init__(self, params_root, image_root, mode="train"):
16 | self.image_root = image_root
17 | self.mode = mode
18 | with open(params_root, encoding='utf-8') as f:
19 | self.params = json.load(f)
20 |
21 | def __getitem__(self, index):
22 | if self.mode == "val":
23 | img = Image.open(os.path.join(self.image_root, '%d.png' % (index + 54000))).convert("RGB")
24 | param = torch.tensor(self.params['%d.png' % (index + 54000)])
25 | else:
26 | img = Image.open(os.path.join(self.image_root, '%d.png' % index)).convert("RGB")
27 | param = torch.tensor(self.params['%d.png' % index])
28 | img = T.ToTensor()(img)
29 | return param, img
30 |
31 | def __len__(self):
32 | if self.mode == "train":
33 | return 54000
34 | else:
35 | return 6000
36 |
37 | ############################################################
38 | '''
39 | Translator Dataset
40 | '''
41 | def split_dataset(datapath):
42 | trainlist, vallist = [], []
43 | for file in os.listdir(datapath):
44 | if random.randint(0, 10) <= 8:
45 | trainlist.append(os.path.join(datapath, file))
46 | else:
47 | vallist.append(os.path.join(datapath, file))
48 | return trainlist, vallist
49 |
50 | class Translator_Dataset(Dataset):
51 | def __init__(self, img_list):
52 | self.img_list = img_list
53 |
54 | def __getitem__(self, index):
55 | img = Image.open(self.img_list[index]).convert("RGB").resize((512, 512), Image.BILINEAR)
56 | img = T.ToTensor()(img)
57 | return img
58 |
59 | def __len__(self):
60 | return len(self.img_list)
61 |
62 |
63 | if __name__ == '__main__':
64 | train_imitator_Dataset = Imitator_Dataset(config.params_root, config.image_root)
65 | train_imitator_dataloader = DataLoader(train_imitator_Dataset, batch_size=16, shuffle=True)
66 | for i, content in enumerate(train_imitator_dataloader):
67 | x, y = content[:]
68 | print(i, x.shape, y.shape)
69 |
70 |
--------------------------------------------------------------------------------
/demo.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import torch
3 | import utils
4 | import config
5 | from imitator import Imitator
6 |
7 | imitator = Imitator()
8 |
9 | # model_dict = {"model.0.weight": "0.0.weight",
10 | # "model.0.bias": "0.0.bias",
11 | # "model.1.weight": "0.1.weight",
12 | # "model.1.bias": "0.1.bias",
13 | # "model.1.running_mean": "0.1.running_mean",
14 | # "model.1.running_var": "0.1.running_var",
15 | # "model.1.num_batches_tracked": "0.1.num_batches_tracked",
16 | # "model.3.weight": "1.0.weight",
17 | # "model.3.bias": "1.0.bias",
18 | # "model.4.weight": "1.1.weight",
19 | # "model.4.bias": "1.1.bias",
20 | # "model.4.running_mean": "1.1.running_mean",
21 | # "model.4.running_var": "1.1.running_var",
22 | # "model.4.num_batches_tracked": "1.1.num_batches_tracked",
23 | # "model.6.weight": "2.0.weight",
24 | # "model.6.bias": "2.0.bias",
25 | # "model.7.weight": "2.1.weight",
26 | # "model.7.bias": "2.1.bias",
27 | # "model.7.running_mean": "2.1.running_mean",
28 | # "model.7.running_var": "2.1.running_var",
29 | # "model.7.num_batches_tracked": "2.1.num_batches_tracked",
30 | # "model.9.weight": "3.0.weight",
31 | # "model.9.bias": "3.0.bias",
32 | # "model.10.weight": "3.1.weight",
33 | # "model.10.bias": "3.1.bias",
34 | # "model.10.running_mean": "3.1.running_mean",
35 | # "model.10.running_var": "3.1.running_var",
36 | # "model.10.num_batches_tracked": "3.1.num_batches_tracked",
37 | # "model.12.weight": "4.0.weight",
38 | # "model.12.bias": "4.0.bias",
39 | # "model.13.weight": "4.1.weight",
40 | # "model.13.bias": "4.1.bias",
41 | # "model.13.running_mean": "4.1.running_mean",
42 | # "model.13.running_var": "4.1.running_var",
43 | # "model.13.num_batches_tracked": "4.1.num_batches_tracked",
44 | # "model.15.weight": "5.0.weight",
45 | # "model.15.bias": "5.0.bias",
46 | # "model.16.weight": "5.1.weight",
47 | # "model.16.bias": "5.1.bias",
48 | # "model.16.running_mean": "5.1.running_mean",
49 | # "model.16.running_var": "5.1.running_var",
50 | # "model.16.num_batches_tracked": "5.1.num_batches_tracked",
51 | # "model.18.weight": "6.0.weight",
52 | # "model.18.bias": "6.0.bias",
53 | # "model.19.weight": "6.1.weight",
54 | # "model.19.bias": "6.1.bias",
55 | # "model.19.running_mean": "6.1.running_mean",
56 | # "model.19.running_var": "6.1.running_var",
57 | # "model.19.num_batches_tracked": "6.1.num_batches_tracked",
58 | # "model.21.weight": "7.weight",
59 | # "model.21.bias": "7.bias"
60 | # }
61 | #
62 | # checkpoint = torch.load("./checkpoint/model_imitator_100000_cuda.pth", map_location=torch.device('cpu'))
63 | # # imitator.load_state_dict(checkpoint['net'])
64 | # for k1, k2 in zip(imitator.state_dict(), checkpoint['net']):
65 | # # print("%s\t%s\t%s\t%s" % (k1, imitator.state_dict()[k1].shape, k2, checkpoint['net'][k2].shape))
66 | # content = '"%s": "%s",' % (k1, k2)
67 | # print(content)
68 | #
69 | # state_dict = torch.load(config.imitator_model, map_location=torch.device('cpu'))['net']
70 | # #
71 | # # op_model = {}
72 | # # for k in state_dict['net'].keys():
73 | # # op_model["model." + str(k)[2:]] = imitator_model['net'][k]
74 | # for new_key, old_key in model_dict.items():
75 | # state_dict[new_key] = state_dict.pop(old_key)
76 | #
77 | # imitator.load_state_dict(state_dict)
78 | # torch.save(state_dict, "./imitator.pth")
79 |
80 |
81 | model = torch.load(config.imitator_model)
82 | imitator.load_state_dict(model)
83 |
--------------------------------------------------------------------------------
/evaluate_adam.py:
--------------------------------------------------------------------------------
1 | import os
2 | import cv2
3 | from tqdm import tqdm
4 | from PIL import Image
5 | import torch
6 | import torch.nn.functional as F
7 | import numpy as np
8 |
9 | import utils
10 | import config
11 | from imitator import Imitator
12 | from lightcnn import LightCNN_29Layers_v2
13 | from faceparse import BiSeNet
14 |
15 | '''
16 | Evaluation: wrap t_params as the optimization variable and optimize it with Adam
17 | '''
18 |
19 | '''
20 | evaluate with torch tensor
21 | :param input: torch tensor [B, H, W, C] range: [0-1], not [0-255]
22 | :param image: used as the base image for the mask visualization
23 | :param bsnet: BiSeNet model
24 | :param w: tuple len=6 [eyebrow,eye,nose,teeth,up lip,lower lip]
25 | :return: weighted tensor; face mask; image with the mask overlaid
26 | '''
27 | def faceparsing_tensor(input, image, bsnet, w):
28 | out = bsnet(input) # [1, 19, 512, 512]
29 | parsing = out.squeeze(0).cpu().detach().numpy().argmax(0)
30 |
31 | mask_img = utils.vis_parsing_maps(image, parsing, 1)
32 | out = out.squeeze()
33 |
34 | return w[0] * out[3] + w[1] * out[4] + w[2] * out[10] + out[11] + out[12] + out[13], out[1], mask_img
35 |
36 | if __name__ == '__main__':
37 |     eval_image = "./dat/gen.jpg" # image to evaluate
38 |
39 |     # Load LightCNN
40 | model = LightCNN_29Layers_v2(num_classes=80013)
41 | model.eval()
42 | if config.use_gpu:
43 | checkpoint = torch.load(config.lightcnn_checkpoint)
44 | model = torch.nn.DataParallel(model).cuda()
45 | model.load_state_dict(checkpoint['state_dict'])
46 | else:
47 | checkpoint = torch.load(config.lightcnn_checkpoint, map_location="cpu")
48 | new_state_dict = model.state_dict()
49 | for k, v in checkpoint['state_dict'].items():
50 | _name = k[7:] # remove `module.`
51 | new_state_dict[_name] = v
52 | model.load_state_dict(new_state_dict)
53 |
54 |     # Freeze LightCNN
55 |     for param in model.parameters():
56 |         param.requires_grad = False
57 |
58 |     losses = []
59 |
60 |     # Load BiSeNet
61 | bsnet = BiSeNet(n_classes=19)
62 | if config.use_gpu:
63 | bsnet.cuda()
64 | bsnet.load_state_dict(torch.load(config.faceparse_checkpoint))
65 | else:
66 | bsnet.load_state_dict(torch.load(config.faceparse_checkpoint, map_location="cpu"))
67 | bsnet.eval()
68 |
69 |     # Freeze BiSeNet
70 |     for param in bsnet.parameters():
71 |         param.requires_grad = False
72 |
73 |     # Load the imitator
74 | imitator = Imitator()
75 |
76 | l2_c = (torch.ones((512, 512)), torch.ones((512, 512)))
77 | if config.use_gpu:
78 | imitator.cuda()
79 | imitator.eval()
80 |
81 |     # Freeze the imitator
82 |     for param in imitator.parameters():
83 |         param.requires_grad = False
84 |
85 |     imitator_model = torch.load(config.imitator_model, map_location=torch.device('cpu'))
86 |     imitator.load_state_dict(imitator_model) # load the already converted weights
87 |
88 |     # Read the image
89 | img = cv2.imread(eval_image)
90 | img = cv2.resize(img, (512, 512))
91 | img = img.astype(np.float32)
92 |
93 | img1 = Image.open(eval_image)
94 | image = img1.resize((512, 512), Image.BILINEAR)
95 |
96 | # inference
97 | # t_params = 0.5 * torch.ones((1, config.continuous_params_size), dtype=torch.float32)
98 |     t_params = torch.rand((1, config.continuous_params_size), dtype=torch.float32) # the paper uses a uniform distribution; torch.randn() would be a normal distribution
99 |     optimizer = torch.optim.Adam([t_params], lr=config.eval_learning_rate)
100 |     if config.use_gpu:
101 |         t_params = t_params.cuda()
102 |     t_params.requires_grad = True
103 |     losses.clear() # clear the loss history
104 |
105 |     # Reference for the L1 loss: L1_y [B, C, W, H]
106 | img_resized = cv2.resize(img, dsize=(128, 128), interpolation=cv2.INTER_LINEAR)
107 | img_resized = np.swapaxes(img_resized, 0, 2).astype(np.float32)
108 | img_resized = np.mean(img_resized, axis=0)[np.newaxis, np.newaxis, :, :]
109 |
110 | # cv2.imwrite("l1_tmp.jpg", cv2.resize(img_resized[0].transpose(1, 2, 0), (512, 512)))
111 | img_resized = torch.from_numpy(img_resized)
112 | if config.use_gpu:
113 | img_resized = img_resized.cuda()
114 | L1_y = img_resized
115 |
116 |     # Reference for the L2 loss: L2_y [B, C, H, W]
117 | img = img[np.newaxis, :, :, ]
118 | img = np.swapaxes(img, 1, 2)
119 | img = np.swapaxes(img, 1, 3)
120 | # print(img.shape)
121 | # cv2.imwrite("l2_tmp.jpg", cv2.resize(img[0].transpose(1, 2, 0), (512, 512)))
122 |
123 | img = torch.from_numpy(img)
124 | if config.use_gpu:
125 | img = img.cuda()
126 | L2_y = img / 255.
127 |
128 |
129 |
130 |     # Run total_eval_steps optimization steps and keep the last result
131 | m_progress = tqdm(range(1, config.total_eval_steps + 1))
132 | for i in m_progress:
133 | y_ = imitator(t_params) # [1, 3, 512, 512], [batch_size, c, w, h]
134 | # tmp = y_.detach().cpu().numpy()[0]
135 | # tmp = tmp.transpose(2, 1, 0)
136 | # tmp = tmp * 255.0
137 | # # cv2.imshow("y_", tmp)
138 | # # cv2.waitKey()
139 | # print(type(tmp), tmp.shape)
140 | # cv2.imwrite("./output/gen.jpg", tmp)
141 |         # print("saved")
142 | # break
143 | # loss, info = self.evaluate_ls(y_)
144 |         y_copy = y_.clone() # keep a copy for the L2 loss
145 |
146 |         # Measure face similarity
147 |         # L1 loss: cosine-distance loss (is it the same person?), i.e. rough identity consistency. y_ [B, C, W, H]
148 |         y_ = F.max_pool2d(y_, kernel_size=(4, 4), stride=4) # 512->128, [1, 3, 128, 128]
149 |         y_ = torch.mean(y_, dim=1).view(1, 1, 128, 128) # gray
150 |
151 |         # When computing the L1 loss, BCWH is converted to BCHW
152 | # L1 = utils.discriminative_loss(torch.from_numpy(L1_y.detach().cpu().numpy().transpose(0, 1, 3, 2)),
153 | # torch.from_numpy(y_.detach().cpu().numpy().transpose(0, 1, 3, 2)),
154 | # model)
155 |
156 | L1 = utils.discriminative_loss(L1_y, y_, model) # BCWH
157 | # if i % config.eval_prev_freq == 0:
158 | # cv2.imwrite("L1_y_%d.jpg" % i, cv2.resize(L1_y.detach().cpu().numpy()[0].transpose(1, 2, 0), (512, 512)))
159 | # cv2.imwrite("y_%d.jpg" % i, cv2.resize(y_.detach().cpu().numpy()[0].transpose(1, 2, 0) * 255., (512, 512)))
160 |
161 |         # L2 loss: facial semantic loss (key regions are weighted)
162 | w_r = [1.1, 1., 1., 0.7, 1., 1.]
163 | w_g = [1.1, 1., 1., 0.7, 1., 1.]
164 | part1, _, mask_img1 = faceparsing_tensor(L2_y, image, bsnet, w_r)
165 | y_copy = y_copy.transpose(2, 3) # [1, 3, 512, 512]
166 | part2, _, mask_img2 = faceparsing_tensor(y_copy, image, bsnet, w_g)
167 | # if i % config.eval_prev_freq == 0:
168 | # cv2.imwrite("L2_y_%d.jpg" % i, L2_y.detach().cpu().numpy()[0].transpose(1, 2, 0) * 255.)
169 | # cv2.imwrite("y_copy_%d.jpg" % i, y_copy.detach().cpu().numpy()[0].transpose(1, 2, 0) * 255.)
170 |
171 | L2_c = (part1 * 10, part2 * 10)
172 | L2 = F.l1_loss(part1, part2)
173 |
174 |         L2_mask = (mask_img1, mask_img2) # used to show the mask images in the second row
175 |
176 |         # Combined loss: Ls = alpha * L1 + L2
177 | Ls = config.eval_alpha * (1 - L1) + L2
178 | # Ls = L2
179 | info = "L1:{0:.6f} L2:{1:.6f} Ls:{2:.6f}".format(1-L1, L2, Ls)
180 | print(info)
181 | losses.append((1-L1.item(), L2.item()/3, Ls.item()))
182 |
183 | optimizer.zero_grad()
184 | Ls.backward()
185 | optimizer.step()
186 |
187 | if i == 1:
188 | # utils.eval_output(imitator, t_params, img[0].cpu().detach().numpy(), 0, config.prev_path, L2_c)
189 | utils.eval_output(imitator, t_params, img[0].cpu().detach().numpy(), 0, config.prev_path, L2_mask)
190 |
191 | # eval_learning_rate = config.eval_learning_rate
192 |         # if i % 10 == 0 and i != 0: # decay the evaluation learning rate by 20% periodically
193 | # eval_learning_rate = (1 - 0.20) * eval_learning_rate
194 |
195 | # t_params.data = t_params.data - eval_learning_rate * t_params.grad.data
196 | # t_params.data = t_params.data.clamp(0., 1.)
197 | # print(i, t_params.grad, t_params.data)
198 |
199 |         # One-hot encoding: argmax processing (it is unclear what a helper with no return value would accomplish, so its body is written inline here)
200 |         # def argmax_params(params, start, count)
201 |         # utils.argmax_params(t_params.data, 96, 3)  # is this line needed? to be decided
202 |
203 | # start = 96
204 | # count = 3
205 | # dims = t_params.size()[0]
206 | # for dim in range(dims):
207 | # tmp = t_params[dim, start]
208 | # mx = start
209 | # for idx in range(start + 1, start + count):
210 | # if t_params[dim, idx] > tmp:
211 | # mx = idx
212 | # tmp = t_params[dim, idx]
213 | # for idx in range(start, start + count):
214 | # t_params[dim, idx] = 1. if idx == mx else 0
215 |
216 |         # End of one-hot argmax processing
217 |
218 | t_params.grad.zero_()
219 | m_progress.set_description(info)
220 | if i % config.eval_prev_freq == 0:
221 | x = i / float(config.total_eval_steps)
222 | lr = config.eval_learning_rate * (1 - x) + 1e-2
223 | # utils.eval_output(imitator, t_params, img[0].cpu().detach().numpy(), i, config.prev_path, L2_c)
224 | utils.eval_output(imitator, t_params, img[0].cpu().detach().numpy(), i, config.prev_path, L2_mask)
225 | utils.eval_plot(losses)
226 | utils.eval_plot(losses)
227 |
228 |     # Final output
229 | # utils.eval_output(imitator, t_params, img[0].cpu().detach().numpy(), config.total_eval_steps+1, config.prev_path, L2_c)
230 | utils.eval_output(imitator, t_params, img[0].cpu().detach().numpy(), config.total_eval_steps+1, config.prev_path, L2_mask)
231 |
--------------------------------------------------------------------------------
/evaluate_sgd_L1Loss.py:
--------------------------------------------------------------------------------
1 | import os
2 | import cv2
3 | import numpy as np
4 | from tqdm import tqdm
5 | import torch
6 | import torch.nn.functional as F
7 | from torchvision import transforms as T
8 | import torchvision.utils as vutils
9 | from PIL import Image
10 |
11 | import utils
12 | import config
13 | from imitator import Imitator
14 | from lightcnn import LightCNN_29Layers_v2
15 | from faceparse import BiSeNet
16 | from face_parser import load_model
17 | from face_align import template_path, face_Align
18 | from faceswap import template_path, faceswap
19 |
20 | '''
21 | evaluate with torch tensor
22 | :param input: torch tensor [B, H, W, C] range: [0-1], not [0-255]
23 | :param image: used as the base image for the mask visualization
24 | :param face_parse: face-parsing model
25 | :param is_label: if True, return argmax labels of shape [1, 512, 512] instead of probability maps
26 | :return: parsing output; image with the mask overlaid
27 | '''
28 | def faceparsing_tensor(input, image, face_parse, is_label=False):
29 | out = face_parse(input) # [1, 19, 512, 512]
30 | parsing = out.squeeze(0).cpu().detach().numpy().argmax(0)
31 | mask_img = utils.vis_parsing_maps(image, parsing, 1)
32 |     # # 7 classes (including face):
33 | # {
34 | # 0: 'background',
35 | # 1: 'face',
36 | # 2: 'brow',
37 | # 3: 'eye',
38 | # 4: 'nose',
39 | # 5: 'up_lip',
40 | # 6: 'down_lip'
41 | # }
42 |     if is_label: # if used as labels, reshape to [1, 512, 512]
43 | out = out.max(1)[1]
44 | else:
45 | out = out.squeeze() # [7, 512, 512]
46 |
47 |     # 10 classes: background, face, right eyebrow, left eyebrow, right eye, left eye, nose, upper lip, tooth, down lip
48 |     # 7 classes: background, face, brow, eye, nose, uplip, downlip
49 | # weights = torch.tensor([1., 1., 0., 0., 0.9, 0.8, 0.8], dtype=torch.float32).unsqueeze(1).unsqueeze(2)
50 |
51 | # return torch.mul(weights, out), mask_img
52 | return out, mask_img
53 |
54 | def main():
55 |     eval_imagepath = "./dat/16.jpg" # image to evaluate
56 |
57 |     # 1. Load LightCNN
58 | lightcnn = LightCNN_29Layers_v2(num_classes=80013)
59 | lightcnn.eval()
60 | if config.use_gpu:
61 | checkpoint = torch.load(config.lightcnn_checkpoint)
62 | model = torch.nn.DataParallel(lightcnn).cuda()
63 | model.load_state_dict(checkpoint['state_dict'])
64 | else:
65 | checkpoint = torch.load(config.lightcnn_checkpoint, map_location="cpu")
66 | new_state_dict = lightcnn.state_dict()
67 | for k, v in checkpoint['state_dict'].items():
68 | _name = k[7:] # remove `module.`
69 | new_state_dict[_name] = v
70 | lightcnn.load_state_dict(new_state_dict)
71 |
72 |     # Freeze LightCNN
73 |     for param in lightcnn.parameters():
74 |         param.requires_grad = False
75 |
76 |     losses = []
77 |
78 |     # 2. Load the semantic segmentation network
79 | mean = [0.485, 0.456, 0.406]
80 | std = [0.229, 0.224, 0.225]
81 | transform = T.Compose([
82 | T.Normalize(mean=[0.485, 0.456, 0.406],
83 | std=[0.229, 0.224, 0.225]),
84 | ])
85 |
86 | deeplab = load_model('mobilenetv2', num_classes=config.num_classes, output_stride=config.output_stride)
87 | checkpoint = torch.load(config.faceparse_checkpoint, map_location=torch.device('cpu'))
88 | if config.faceparse_backbone == 'resnet50':
89 | deeplab.load_state_dict(checkpoint)
90 | else:
91 | deeplab.load_state_dict(checkpoint["model_state"])
92 | deeplab.eval()
93 |
94 | for param in deeplab.parameters():
95 | param.requires_grad = False
96 |
97 |     # 3. Load the imitator
98 | imitator = Imitator(is_bias=True)
99 |
100 | l2_c = (torch.ones((512, 512)), torch.ones((512, 512)))
101 | if config.use_gpu:
102 | imitator.cuda()
103 | imitator.eval()
104 | imitator_model = torch.load(config.imitator_model, map_location=torch.device('cpu'))
105 |     imitator.load_state_dict(imitator_model) # load the already converted weights
106 |
107 |     # Freeze the imitator
108 |     for param in imitator.parameters():
109 |         param.requires_grad = False
110 |
111 |     # Read the image
112 |     # Align it first
113 | warped_mask = face_Align(template_path, eval_imagepath)
114 | warped_mask = cv2.cvtColor(warped_mask, cv2.COLOR_BGR2RGB)
115 | # img_upload = Image.open(eval_imagepath)
116 |     img_upload = Image.fromarray(np.uint8(warped_mask)) # aligned image
117 |
118 |     image_mask = img_upload.convert('RGB').resize((512, 512), Image.BILINEAR) # used for the mask visualization
119 |
120 |     image_F1 = img_upload.convert('L').resize((128, 128), Image.BILINEAR) # F1 loss: identity loss
121 |     image_F1 = T.ToTensor()(image_F1).unsqueeze(0) # [1, 1, 128, 128]
122 |
123 |     image_F2 = img_upload.convert('RGB').resize((512, 512), Image.BILINEAR) # F2 loss: content loss
124 |     image_F2 = T.ToTensor()(image_F2).unsqueeze(0) # [1, 3, 512, 512]
125 |
126 |     image_F2 = transform(image_F2) # MobileNet was trained with normalization, so apply it at inference too
127 |
128 |     t_params = torch.full([1, config.continuous_params_size], 0.5, dtype=torch.float32) # initialize from the average face
129 |     optimizer = torch.optim.SGD([t_params], lr=config.eval_learning_rate, momentum=0.9) # SGD with momentum
130 | if config.use_gpu:
131 | t_params = t_params.cuda()
132 |
133 | t_params.requires_grad = True
134 |     losses.clear() # clear the loss history
135 |
136 |     # Run total_eval_steps optimization steps and keep the last result
137 | m_progress = tqdm(range(1, config.total_eval_steps + 1))
138 | for i in m_progress:
139 | gen_img = imitator(t_params) # [1, 3, 512, 512]
140 |         gen_img_copy = gen_img.clone() # keep a copy for display
141 |
142 |         # 1. Identity loss
143 |         trans = T.Compose([
144 |             T.Resize((128, 128)),
145 |             T.Grayscale()
146 |         ])
147 |
148 |         F1_Loss = utils.discriminative_loss(image_F1, trans(gen_img), lightcnn) # identity loss
149 |
150 |         # 2. Content loss
151 |         gen_img = transform(gen_img)
152 |         upload_F2_feature, mask_img_upload = faceparsing_tensor(image_F2, image_mask, deeplab, is_label=False) # reference
153 |         gen_F2_feature, mask_img_gen = faceparsing_tensor(gen_img, image_mask, deeplab, is_label=False) # generated
154 |
155 |         F2_Loss = F.l1_loss(upload_F2_feature, gen_F2_feature)
156 |         # F2_Loss = F.cross_entropy(gen_F2_feature.view(-1, 7, 512*512), upload_F2_feature.view(-1, 512*512).long()) # alternative: cross-entropy loss
157 |
158 |         # Combined loss: Ls = alpha * L1 + L2
159 | # Ls = config.eval_alpha * (1 - F1_Loss) + F2_Loss
160 | Ls = F2_Loss
161 | info = "1-F1:{0:.6f} F2:{1:.6f} Ls:{2:.6f}".format(1 - F1_Loss, F2_Loss, Ls)
162 | print(info)
163 | losses.append((1 - F1_Loss.item(), F2_Loss.item() / 3, Ls.item()))
164 |
165 | optimizer.zero_grad()
166 | Ls.backward()
167 | optimizer.step()
168 |
169 |         if i % 5 == 0 and i != 0: # during evaluation, decay the learning rate by 20% every 5 steps
170 | for p in optimizer.param_groups:
171 | p["lr"] *= 0.8
172 | t_params.data = t_params.data.clamp(0.05, 0.95)
173 | t_params.grad.zero_()
174 | m_progress.set_description(info)
175 |
176 | if i % config.eval_prev_freq == 0:
177 | # image_upload, gen_img, mask_img_upload, mask_img_gen
178 | # tmp = T.PILToTensor()(T.ToPILImage()(gen_img[0]))
179 | # tmp.show()
180 |
181 | show_img_list = [T.PILToTensor()(img_upload), T.PILToTensor()(T.ToPILImage()(gen_img_copy[0])),
182 | T.ToTensor()(mask_img_upload) * 255., T.ToTensor()(mask_img_gen) * 255.]
183 |
184 | label_show = vutils.make_grid(show_img_list, nrow=2, padding=2, normalize=True).cpu()
185 | vutils.save_image(label_show, os.path.join(config.prev_path, "eval_%d.png" % (i)))
186 |
187 |
188 | if __name__ == '__main__':
189 | main()
190 |
191 |
192 |
--------------------------------------------------------------------------------
/evaluate_sgd_cross_entropy.py:
--------------------------------------------------------------------------------
1 | import os
2 | import cv2
3 | import numpy as np
4 | from tqdm import tqdm
5 | import torch
6 | import torch.nn.functional as F
7 | from torchvision import transforms as T
8 | import torchvision.utils as vutils
9 | from PIL import Image
10 |
11 | import utils
12 | import config
13 | from imitator import Imitator
14 | from lightcnn import LightCNN_29Layers_v2
15 | from face_parser import load_model
16 | from face_align import template_path, face_Align
17 | from faceswap import template_path, faceswap
18 |
19 | '''
20 | evaluate with torch tensor
21 | :param input: torch tensor [B, H, W, C] range: [0-1], not [0-255]
22 | :param image: used as the base image for the mask visualization
23 | :param face_parse: face-parsing model
24 | :param is_label: if True, return argmax labels of shape [1, 512, 512] instead of probability maps
25 | :return: parsing output; image with the mask overlaid
26 | '''
27 | def faceparsing_tensor(input, image, face_parse, is_label=False):
28 | out = face_parse(input) # [1, 19, 512, 512]
29 | parsing = out.squeeze(0).cpu().detach().numpy().argmax(0)
30 | mask_img = utils.vis_parsing_maps(image, parsing, 1)
31 |     # # 7 classes:
32 | # {
33 | # 0: 'background',
34 | # 1: 'face',
35 | # 2: 'brow',
36 | # 3: 'eye',
37 | # 4: 'nose',
38 | # 5: 'up_lip',
39 | # 6: 'down_lip'
40 | # }
41 |
42 |     if is_label: # if used as labels, reshape to [1, 512, 512]
43 | out = out.max(1)[1]
44 | else:
45 | out = out.squeeze() # [7, 512, 512]
46 | # tmp = out[4]
47 | return out, mask_img
48 |
49 | if __name__ == '__main__':
50 |     eval_imagepath = "./dat/16.jpg" # image to evaluate
51 |
52 |     # 1. Load LightCNN
53 | lightcnn = LightCNN_29Layers_v2(num_classes=80013)
54 | lightcnn.eval()
55 | if config.use_gpu:
56 | checkpoint = torch.load(config.lightcnn_checkpoint)
57 | model = torch.nn.DataParallel(lightcnn).cuda()
58 | model.load_state_dict(checkpoint['state_dict'])
59 | else:
60 | checkpoint = torch.load(config.lightcnn_checkpoint, map_location="cpu")
61 | new_state_dict = lightcnn.state_dict()
62 | for k, v in checkpoint['state_dict'].items():
63 | _name = k[7:] # remove `module.`
64 | new_state_dict[_name] = v
65 | lightcnn.load_state_dict(new_state_dict)
66 |
67 |     # Freeze LightCNN
68 |     for param in lightcnn.parameters():
69 |         param.requires_grad = False
70 |
71 |     losses = []
72 |
73 |     # 2. Load the semantic segmentation network
74 | mean = [0.485, 0.456, 0.406]
75 | std = [0.229, 0.224, 0.225]
76 | transform = T.Compose([
77 | T.Normalize(mean=[0.485, 0.456, 0.406],
78 | std=[0.229, 0.224, 0.225]),
79 | ])
80 |
81 | deeplab = load_model(backbone=config.faceparse_backbone, num_classes=config.num_classes, output_stride=config.output_stride)
82 | checkpoint = torch.load(config.faceparse_checkpoint, map_location=torch.device('cpu'))
83 | deeplab.load_state_dict(checkpoint['model_state'])
84 | deeplab.eval()
85 |
86 | for param in deeplab.parameters():
87 | param.requires_grad = False
88 |
89 |     # 3. Load the imitator
90 | imitator = Imitator(is_bias=True)
91 |
92 | l2_c = (torch.ones((512, 512)), torch.ones((512, 512)))
93 | if config.use_gpu:
94 | imitator.cuda()
95 | imitator.eval()
96 | imitator_model = torch.load(config.imitator_model, map_location=torch.device('cpu'))
97 |     imitator.load_state_dict(imitator_model) # load the already converted weights
98 |
99 |     # Freeze the imitator
100 |     for param in imitator.parameters():
101 |         param.requires_grad = False
102 |
103 |     # Read the image
104 |     # Align it first
105 | warped_mask = face_Align(template_path, eval_imagepath)
106 | warped_mask = cv2.cvtColor(warped_mask, cv2.COLOR_BGR2RGB)
107 | # img_upload = Image.open(eval_imagepath)
108 |     img_upload = Image.fromarray(np.uint8(warped_mask)) # aligned image
109 |
110 |     image_mask = img_upload.convert('RGB').resize((512, 512), Image.BILINEAR) # used for the mask visualization
111 |
112 |     image_F1 = img_upload.convert('L').resize((128, 128), Image.BILINEAR) # F1 loss: identity loss
113 |     image_F1 = T.ToTensor()(image_F1).unsqueeze(0) # [1, 1, 128, 128]
114 |
115 |     image_F2 = img_upload.convert('RGB').resize((512, 512), Image.BILINEAR) # F2 loss: content loss
116 |     image_F2 = T.ToTensor()(image_F2).unsqueeze(0) # [1, 3, 512, 512]
117 |     image_F2 = transform(image_F2) # MobileNet was trained with normalization, so apply it at inference too
118 |
119 |     t_params = torch.full([1, config.continuous_params_size], 0.5, dtype=torch.float32) # initialize from the average face
120 |     optimizer = torch.optim.SGD([t_params], lr=config.eval_learning_rate, momentum=0.9) # SGD with momentum
121 | if config.use_gpu:
122 | t_params = t_params.cuda()
123 |
124 | t_params.requires_grad = True
125 |     losses.clear() # clear the loss history
126 |
127 |     # Run total_eval_steps optimization steps and keep the last result
128 | m_progress = tqdm(range(1, config.total_eval_steps + 1))
129 | for i in m_progress:
130 | gen_img = imitator(t_params) # [1, 3, 512, 512]
131 | gen_img_copy = gen_img.clone() # 复制出一份来展示用
132 |
133 | # 1.身份验证损失
134 | trans = T.Compose([
135 | T.Resize((128, 128)),
136 | T.Grayscale()
137 | ])
138 |
139 | F1_Loss = utils.discriminative_loss(image_F1, trans(gen_img), lightcnn) # identity loss
140 |
141 | # 2. Content loss
142 | gen_img = transform(gen_img)
143 | upload_F2_feature, mask_img_upload = faceparsing_tensor(image_F2, image_mask, deeplab, is_label=True) # reference
144 | gen_F2_feature, mask_img_gen = faceparsing_tensor(gen_img, image_mask, deeplab, is_label=False) # generated
145 |
146 | # tmp = upload_F2_feature.view(-1, 512*512).softmax(dim=0).max(0)[0].long()
147 |
148 | # F2_Loss = F.l1_loss(upload_F2_feature, gen_F2_feature)
149 | F2_Loss = F.cross_entropy(gen_F2_feature.view(-1, config.num_classes, 512*512), upload_F2_feature.view(-1, 512*512).long()) # cross-entropy loss is used here instead of L1
150 | # F2_Loss = F.cross_entropy(gen_F2_feature.view(-1, config.num_classes, 512*512), upload_F2_feature.view(-1, 512*512).softmax(dim=0).max(0)[0].view(-1, 512*512).long()) # cross-entropy loss weighted by the conditional probabilities
151 |
152 | # Combined loss: Ls = alpha * (1 - F1) + F2
153 | Ls = config.eval_alpha * (1 - F1_Loss) + F2_Loss
154 | # Ls = F1_Loss
155 | # Ls = F2_Loss
156 | info = "1-F1:{0:.6f} F2:{1:.6f} Ls:{2:.6f}".format(1 - F1_Loss, F2_Loss, Ls)
157 | print(info)
158 | losses.append((1 - F1_Loss.item(), F2_Loss.item(), Ls.item()))
159 |
160 | optimizer.zero_grad()
161 | Ls.backward()
162 | optimizer.step()
163 |
164 | if i % 5 == 0 and i != 0: # decay the evaluation learning rate by 20% every 5 steps
165 | for p in optimizer.param_groups:
166 | p["lr"] *= 0.8
167 | t_params.data = t_params.data.clamp(0.05, 0.95)
168 | t_params.grad.zero_()
169 | m_progress.set_description(info)
170 |
171 | if i % config.eval_prev_freq == 0 or i == 1:
172 | # image_upload, gen_img, mask_img_upload, mask_img_gen
173 | # tmp = T.PILToTensor()(T.ToPILImage()(gen_img[0]))
174 | # tmp.show()
175 |
176 | show_img_list = [T.PILToTensor()(img_upload), T.PILToTensor()(T.ToPILImage()(gen_img_copy[0])), T.ToTensor()(mask_img_upload) * 255., T.ToTensor()(mask_img_gen) * 255.]
177 |
178 | label_show = vutils.make_grid(show_img_list, nrow=2, padding=2, normalize=True).cpu()
179 | vutils.save_image(label_show, os.path.join(config.prev_path, "eval_%d.png" % i))
180 |
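181 | # Not part of the original script -- a minimal, hedged sketch of what could follow the loop:
182 | # render the face for the optimized parameters and save them, assuming the loop above has
183 | # finished and t_params holds the final values (file names are illustrative only).
184 | # final_img = imitator(t_params.detach()) # [1, 3, 512, 512]
185 | # vutils.save_image(final_img, os.path.join(config.prev_path, "eval_final.png"))
186 | # np.save(os.path.join(config.prev_path, "eval_params.npy"), t_params.detach().cpu().numpy())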
--------------------------------------------------------------------------------
/face_align.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import dlib
3 | import numpy
4 | import sys
5 | import matplotlib.pyplot as plt
6 |
7 | PREDICTOR_PATH = r"./checkpoint/shape_predictor_68_face_landmarks.dat" # model file for the 68 facial landmarks
8 | SCALE_FACTOR = 1 # image scale factor
9 | FEATHER_AMOUNT = 15 # feathering range for the mask border; larger means stronger feathering; must be an odd number
10 |
11 | # The 68 landmark points
12 | FACE_POINTS = list(range(17, 68)) # face
13 | MOUTH_POINTS = list(range(48, 61)) # mouth
14 | RIGHT_BROW_POINTS = list(range(17, 22)) # right brow
15 | LEFT_BROW_POINTS = list(range(22, 27)) # left brow
16 | RIGHT_EYE_POINTS = list(range(36, 42)) # right eye
17 | LEFT_EYE_POINTS = list(range(42, 48)) # left eye
18 | NOSE_POINTS = list(range(27, 35)) # nose
19 | JAW_POINTS = list(range(0, 17)) # jaw
20 |
21 | # Facial features of the second face selected for overlaying onto the first face.
22 | # The points cover both eyes, the brows, the nose and the mouth.
23 | # (Open question: does using more points introduce any interference?)
24 | ALIGN_POINTS = (FACE_POINTS + LEFT_BROW_POINTS + RIGHT_EYE_POINTS + LEFT_EYE_POINTS +
25 | RIGHT_BROW_POINTS + NOSE_POINTS + MOUTH_POINTS)
26 |
27 | # Points from the second image to overlay on the first. The convex hull of each
28 | # element will be overlaid.
29 | OVERLAY_POINTS = [
30 | LEFT_EYE_POINTS + RIGHT_EYE_POINTS + LEFT_BROW_POINTS + RIGHT_BROW_POINTS,
31 | NOSE_POINTS + MOUTH_POINTS,
32 | ]
33 | # eyes and brows form one group (2 * 22 points)
34 | # nose and mouth form a separate group
35 |
36 | # Amount of blur used for colour correction, as a fraction of the pupillary distance
37 | COLOUR_CORRECT_BLUR_FRAC = 0.6
38 |
39 | # Instantiate the face detector
40 | detector = dlib.get_frontal_face_detector()
41 | # Load the trained model
42 | # and instantiate the landmark predictor
43 | predictor = dlib.shape_predictor(PREDICTOR_PATH)
44 |
45 |
46 | # Two exception classes for error handling
47 | class TooManyFaces(Exception):
48 | pass
49 |
50 |
51 | class NoFaces(Exception):
52 | pass
53 |
54 |
55 | def get_landmarks(im):
56 | '''
57 | Get the 68 landmarks from the predictor.
58 | '''
59 | rects = detector(im, 1)
60 |
61 | if len(rects) > 1:
62 | raise TooManyFaces
63 | if len(rects) == 0:
64 | raise NoFaces
65 |
66 | return numpy.matrix([[p.x, p.y] for p in predictor(im, rects[0]).parts()]) # 68x2 matrix
67 |
68 |
69 | def annotate_landmarks(im, landmarks):
70 | '''
71 | Draw the facial landmarks on the image.
72 | '''
73 | im = im.copy()
74 | for idx, point in enumerate(landmarks):
75 | pos = (point[0, 0], point[0, 1])
76 | cv2.putText(im, str(idx), pos,
77 | fontFace=cv2.FONT_HERSHEY_SCRIPT_SIMPLEX,
78 | fontScale=0.4,
79 | color=(0, 0, 255))
80 | cv2.circle(im, pos, 3, color=(0, 255, 255))
81 | return im
82 |
83 |
84 | def draw_convex_hull(im, points, color):
85 | '''
86 | Compute the convex hull of the points and fill it as a polygon.
87 | '''
88 | points = cv2.convexHull(points)
89 | cv2.fillConvexPoly(im, points, color=color)
90 |
91 |
92 | def get_face_mask(im, landmarks):
93 | '''Build an image mask for the facial features (brows, eyes, nose and mouth).
94 | When the mask is applied to the original image, only the regions where the mask is white remain visible and the black regions are hidden, so the mask effectively "crops" the image.
95 | Example result: https://dn-anything-about-doc.qbox.me/document-uid242676labid2260timestamp1477921310170.png/wm
96 | get_face_mask() generates a mask for an image and a landmark matrix: it draws two white convex polygons, one around the eye region
97 | and one around the nose and mouth region, then feathers the mask outwards by FEATHER_AMOUNT pixels to help hide any discontinuities.
98 | '''
99 | im = numpy.zeros(im.shape[:2], dtype=numpy.float64)
100 |
101 | for group in OVERLAY_POINTS:
102 | draw_convex_hull(im,
103 | landmarks[group],
104 | color=1)
105 |
106 | im = numpy.array([im, im, im]).transpose((1, 2, 0))
107 |
108 | im = (cv2.GaussianBlur(im, (FEATHER_AMOUNT, FEATHER_AMOUNT), 0) > 0) * 1.0
109 | im = cv2.GaussianBlur(im, (FEATHER_AMOUNT, FEATHER_AMOUNT), 0)
110 |
111 | return im
112 |
113 | # Returns an affine transformation
114 | def transformation_from_points(points1, points2):
115 | """
116 | Return an affine transformation [s * R | T] such that:
117 | sum ||s*R*p1,i + T - p2,i||^2
118 | is minimized.
119 | """
120 | # Solve the procrustes problem by subtracting centroids, scaling by the
121 | # standard deviation, and then using the SVD to calculate the rotation. See
122 | # the following for more details:
123 | # https://en.wikipedia.org/wiki/Orthogonal_Procrustes_problem
124 |
125 | points1 = points1.astype(numpy.float64) # selected facial landmarks
126 | points2 = points2.astype(numpy.float64)
127 |
128 | # Standardize the data: subtract the mean and divide by the std so each point set has zero mean and unit variance
129 | # Each face is standardized independently
130 | c1 = numpy.mean(points1, axis=0) # mean of x and y separately
131 | c2 = numpy.mean(points2, axis=0)
132 | points1 -= c1 # deviation from the mean, [43, 2]
133 | points2 -= c2
134 |
135 | s1 = numpy.std(points1)
136 | s2 = numpy.std(points2)
137 | points1 /= s1
138 | points2 /= s2
139 |
140 | U, S, Vt = numpy.linalg.svd(points1.T * points2)
141 |
142 | # The R we seek is in fact the transpose of the one given by U * Vt. This
143 | # is because the above formulation assumes the matrix goes on the right
144 | # (with row vectors) where as our solution requires the matrix to be on the
145 | # left (with column vectors).
146 | R = (U * Vt).T # [2, 2]
147 |
148 | return numpy.vstack([numpy.hstack(((s2 / s1) * R, c2.T - (s2 / s1) * R * c1.T)),
149 | numpy.matrix([0., 0., 1.])])
150 |
151 |
152 | def read_im_and_landmarks(fname):
153 | im = cv2.imread(fname, cv2.IMREAD_COLOR)
154 | im = cv2.resize(im, (im.shape[1] * SCALE_FACTOR,
155 | im.shape[0] * SCALE_FACTOR))
156 | s = get_landmarks(im) # [68, 2]
157 |
158 | return im, s
159 |
160 |
161 | def warp_im(im, M, dshape):
162 | '''
163 | The mask from get_face_mask cannot be used directly: the two input images usually differ in resolution, and even when they match,
164 | the faces may differ in size and angle because of shooting distance and viewpoint. So instead of simply cutting out the second
165 | person's facial features and pasting them onto the first face, we warp them using the facial feature regions computed for both images, so that the features overlap as well as possible.
166 |
167 | cv2.warpAffine applies a geometric transform to an image.
168 | Its three main arguments are the input image, the transform matrix (np.float32), and the width/height of the output image.
169 |
170 | This is the main alignment function.
171 | '''
172 | output_im = numpy.zeros(dshape, dtype=im.dtype) # [512, 512, 3]
173 | cv2.warpAffine(im,
174 | M[:2],
175 | (dshape[1], dshape[0]),
176 | dst=output_im,
177 | borderMode=cv2.BORDER_TRANSPARENT,
178 | flags=cv2.WARP_INVERSE_MAP)
179 | return output_im
180 |
181 |
182 | def correct_colours(im1, im2, landmarks1):
183 | '''
184 | Adjust the skin colour so that the two images blend more naturally when combined.
185 | '''
186 | blur_amount = COLOUR_CORRECT_BLUR_FRAC * numpy.linalg.norm(
187 | numpy.mean(landmarks1[LEFT_EYE_POINTS], axis=0) -
188 | numpy.mean(landmarks1[RIGHT_EYE_POINTS], axis=0))
189 | blur_amount = int(blur_amount)
190 | if blur_amount % 2 == 0:
191 | blur_amount += 1
192 | im1_blur = cv2.GaussianBlur(im1, (blur_amount, blur_amount), 0)
193 | im2_blur = cv2.GaussianBlur(im2, (blur_amount, blur_amount), 0)
194 |
195 | # Avoid divide-by-zero errors.
196 | im2_blur += (128 * (im2_blur <= 1.0)).astype(im2_blur.dtype)
197 |
198 | return (im2.astype(numpy.float64) * im1_blur.astype(numpy.float64) /
199 | im2_blur.astype(numpy.float64))
200 |
201 |
202 | # Face-swap function
203 | def Switch_face(Base_path, cover_path):
204 | im1, landmarks1 = read_im_and_landmarks(Base_path) # base image
205 | im2, landmarks2 = read_im_and_landmarks(cover_path) # image to paste on top
206 |
207 | if len(landmarks1) == 0 & len(landmarks2) == 0:
208 | raise RuntimeError("Faces detected is no face!")
209 | if len(landmarks1) > 1 & len(landmarks2) > 1:
210 | raise RuntimeError("Faces detected is more than 1!")
211 |
212 | # landmarks1[ALIGN_POINTS] are the selected facial landmarks
213 | M = transformation_from_points(landmarks1[ALIGN_POINTS],
214 | landmarks2[ALIGN_POINTS])
215 | mask = get_face_mask(im2, landmarks2)
216 | warped_mask = warp_im(mask, M, im1.shape)
217 | combined_mask = numpy.max([get_face_mask(im1, landmarks1), warped_mask],
218 | axis=0)
219 | warped_im2 = warp_im(im2, M, im1.shape)
220 | warped_corrected_im2 = correct_colours(im1, warped_im2, landmarks1)
221 |
222 | output_im = im1 * (1.0 - combined_mask) + warped_corrected_im2 * combined_mask
223 | return output_im
224 |
225 |
226 | # Face alignment function
227 | def face_Align(Base_path, cover_path):
228 | im1, landmarks1 = read_im_and_landmarks(Base_path) # base image
229 | im2, landmarks2 = read_im_and_landmarks(cover_path) # image to paste on top
230 |
231 | # Get the affine transformation matrix
232 | M = transformation_from_points(landmarks1[ALIGN_POINTS],
233 | landmarks2[ALIGN_POINTS])
234 | warped_im2 = warp_im(im2, M, im1.shape)
235 | return warped_im2
236 |
237 | FEATHER_AMOUNT = 19 # override the feathering radius defined at the top of the file (still an odd value)
238 |
239 | template_path = './dat/avg_face.jpg' # template (average face)
240 | if __name__ == '__main__':
241 | cover_path = './dat/16.jpg'
242 | warped_mask = face_Align(template_path, cover_path)
243 | cv2.imwrite("./dat/result_16.jpg", warped_mask)
244 |
245 | # plt.subplot(111)
246 | # plt.imshow(warped_mask) # display the result
247 | # plt.show()
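248 |
249 | # Not part of the original script -- a minimal sketch of the face-swap path defined above
250 | # (Switch_face), assuming the cover image contains exactly one detectable face:
251 | # swapped = Switch_face(template_path, './dat/16.jpg')
252 | # cv2.imwrite('./dat/switched_16.jpg', swapped)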
--------------------------------------------------------------------------------
/face_parser.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from torch.hub import load_state_dict_from_url  # moved here from torchvision.models.utils, which newer torchvision no longer provides
5 | from collections import OrderedDict
6 |
7 | import config
8 | from backbone.resnet import ResNet50, Bottleneck
9 | from backbone.mobilenet import MobileNetV2
10 |
11 | '''
12 | Build the segmentation model with the specified backbone.
13 | '''
14 | def load_model(backbone, num_classes, output_stride):
15 | if backbone == 'resnet50':
16 | model = segment_resnet(num_classes=num_classes, output_stride=output_stride)
17 | elif backbone == 'mobilenetv2':
18 | model = segment_mobilenetv2(num_classes=num_classes, output_stride=output_stride)
19 | else:
20 | raise NotImplementedError
21 | return model
22 |
23 | '''
24 | Semantic segmentation network with a ResNet-50 backbone.
25 | :param num_classes number of segmentation classes
26 | :param output_stride output stride of the backbone
27 | '''
28 | def segment_resnet(num_classes, output_stride):
29 | if output_stride == 8:
30 | replace_stride_with_dilation = [False, True, True] # whether to replace stride with dilated (atrous) convolution in each stage
31 | aspp_dilate = [12, 24, 36] # dilation rates of the atrous convolutions in the ASPP module
32 | else:
33 | replace_stride_with_dilation = [False, False, True]
34 | aspp_dilate = [6, 12, 18]
35 |
36 | # Load the backbone (the last two ResNet-50 layers exist only so the pretrained weights can be loaded, hence num_classes must be 1000)
37 | backbone = ResNet50(Bottleneck, [3, 4, 6, 3], num_classes=1000, replace_stride_with_dilation=replace_stride_with_dilation)
38 | if config.pretrained:
39 | state_dict = load_state_dict_from_url(config.model_urls, progress=config.progress)
40 | backbone.load_state_dict(state_dict)
41 | del state_dict
42 |
43 | inplanes = 2048
44 | low_level_planes = 256
45 |
46 | return_layers = {'layer4': 'out', 'layer1': 'low_level'} # layer4 is the final ResNet output and feeds the ASPP module on the encoder side; layer1 provides low-level features fed straight into the decoder
47 | classifier = DeepLabHeadV3Plus(inplanes, low_level_planes, num_classes, aspp_dilate)
48 |
49 | # Extract the outputs of the requested layers and give each an alias
50 | backbone = IntermediateLayerGetter(backbone, return_layers=return_layers)
51 |
52 | # Assemble the complete segmentation model
53 | # model = Simple_Segmentation_Model(backbone, classifier)
54 | model = FaceSegmentation(backbone, classifier)
55 | return model
56 |
57 | '''
58 | Semantic segmentation network with a MobileNetV2 backbone.
59 | :param num_classes number of segmentation classes
60 | :param output_stride output stride of the backbone
61 | '''
62 | def segment_mobilenetv2(num_classes, output_stride):
63 | if output_stride == 8:
64 | aspp_dilate = [12, 24, 36]
65 | else:
66 | aspp_dilate = [6, 12, 18]
67 |
68 | # num_classes must be 1000 here so the pretrained weights can be loaded
69 | backbone = MobileNetV2(num_classes=1000, output_stride=output_stride, width_mult=1.0, inverted_residual_setting=None, round_nearest=8)
70 | if config.pretrained:
71 | state_dict = load_state_dict_from_url(config.model_urls, progress=config.progress)
72 | backbone.load_state_dict(state_dict)
73 | del state_dict
74 |
75 | # Split the backbone features into high-level and low-level features
76 | backbone.low_level_features = backbone.features[0:4] # the first four blocks are the low-level features
77 | backbone.high_level_features = backbone.features[4:-1] # blocks from the fourth up to (but excluding) the last one are the high-level features
78 | backbone.features = None
79 | backbone.classifier = None
80 |
81 | inplanes = 320
82 | low_level_planes = 24
83 |
84 | return_layers = {'high_level_features': 'out', 'low_level_features': 'low_level'}
85 | classifier = DeepLabHeadV3Plus(inplanes, low_level_planes, num_classes, aspp_dilate) # use the num_classes passed in, to remain compatible with loading VOC-pretrained weights
86 |
87 | backbone = IntermediateLayerGetter(backbone, return_layers=return_layers)
88 |
89 | model = FaceSegmentation(backbone, classifier)
90 | return model
91 |
92 | '''
93 | DeepLabV3+ decoder head
94 | '''
95 | class DeepLabHeadV3Plus(nn.Module):
96 | def __init__(self, in_channels, low_level_channels, num_classes, aspp_dilate=[12, 24, 36]):
97 | super(DeepLabHeadV3Plus, self).__init__()
98 | self.project = nn.Sequential(
99 | nn.Conv2d(low_level_channels, 48, 1, bias=False),
100 | nn.BatchNorm2d(48),
101 | nn.ReLU(inplace=True),
102 | )
103 |
104 | self.aspp = ASPP(in_channels, aspp_dilate)
105 |
106 | self.classifier = nn.Sequential(
107 | nn.Conv2d(304, 256, 3, padding=1, bias=False),
108 | nn.BatchNorm2d(256),
109 | nn.ReLU(inplace=True),
110 | nn.Conv2d(256, num_classes, 1)
111 | )
112 | self._init_weight()
113 |
114 | def forward(self, feature):
115 | # print(feature.shape)
116 | low_level_feature = self.project(
117 | feature['low_level']) # return_layers = {'layer4': 'out', 'layer1': 'low_level'}
118 | # print(low_level_feature.shape)
119 | output_feature = self.aspp(feature['out'])
120 | # print(output_feature.shape)
121 | output_feature = F.interpolate(output_feature, size=low_level_feature.shape[2:], mode='bilinear',
122 | align_corners=False)
123 | # print(output_feature.shape)
124 | return self.classifier(torch.cat([low_level_feature, output_feature], dim=1))
125 |
126 | def _init_weight(self):
127 | for m in self.modules():
128 | if isinstance(m, nn.Conv2d):
129 | nn.init.kaiming_normal_(m.weight)
130 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
131 | nn.init.constant_(m.weight, 1)
132 | nn.init.constant_(m.bias, 0)
133 |
134 | '''
135 | ASPP: Atrous Spatial Pyramid Pooling (atrous/dilated convolution + SPP)
136 | '''
137 | class ASPP(nn.Module):
138 | def __init__(self, in_channels, atrous_rates):
139 | super(ASPP, self).__init__()
140 | out_channels = 256
141 | modules = []
142 | modules.append(nn.Sequential(
143 | nn.Conv2d(in_channels, out_channels, 1, bias=False),
144 | nn.BatchNorm2d(out_channels),
145 | nn.ReLU(inplace=True)))
146 |
147 | rate1, rate2, rate3 = tuple(atrous_rates)
148 | modules.append(ASPPConv(in_channels, out_channels, rate1))
149 | modules.append(ASPPConv(in_channels, out_channels, rate2))
150 | modules.append(ASPPConv(in_channels, out_channels, rate3))
151 | modules.append(ASPPPooling(in_channels, out_channels))
152 |
153 | self.convs = nn.ModuleList(modules)
154 |
155 | self.project = nn.Sequential(
156 | nn.Conv2d(5 * out_channels, out_channels, 1, bias=False),
157 | nn.BatchNorm2d(out_channels),
158 | nn.ReLU(inplace=True),
159 | nn.Dropout(0.1),)
160 |
161 | def forward(self, x):
162 | res = []
163 | for conv in self.convs:
164 | #print(conv(x).shape)
165 | res.append(conv(x))
166 | res = torch.cat(res, dim=1)
167 | return self.project(res)
168 |
169 | '''
170 | Convolution branch of the ASPP module
171 | '''
172 | class ASPPConv(nn.Sequential):
173 | def __init__(self, in_channels, out_channels, dilation):
174 | modules = [
175 | nn.Conv2d(in_channels, out_channels, 3, padding=dilation, dilation=dilation, bias=False),
176 | nn.BatchNorm2d(out_channels),
177 | nn.ReLU(inplace=True)
178 | ]
179 | super(ASPPConv, self).__init__(*modules)
180 |
181 | '''
182 | Pooling branch of the ASPP module
183 | '''
184 | class ASPPPooling(nn.Sequential):
185 | def __init__(self, in_channels, out_channels):
186 | super(ASPPPooling, self).__init__(
187 | nn.AdaptiveAvgPool2d(1),
188 | nn.Conv2d(in_channels, out_channels, 1, bias=False),
189 | nn.BatchNorm2d(out_channels),
190 | nn.ReLU(inplace=True))
191 |
192 | def forward(self, x):
193 | size = x.shape[-2:]
194 | x = super(ASPPPooling, self).forward(x)
195 | return F.interpolate(x, size=size, mode='bilinear', align_corners=False)
196 |
197 | class IntermediateLayerGetter(nn.ModuleDict):
198 | """
199 | Module wrapper that returns intermediate layers from a model
200 |
201 | It has a strong assumption that the modules have been registered
202 | into the model in the same order as they are used.
203 | This means that one should **not** reuse the same nn.Module
204 | twice in the forward if you want this to work.
205 |
206 | Additionally, it is only able to query submodules that are directly
207 | assigned to the model. So if `model` is passed, `model.feature1` can
208 | be returned, but not `model.feature1.layer2`.
209 |
210 | Arguments:
211 | model (nn.Module): model on which we will extract the features
212 | return_layers (Dict[name, new_name]): a dict containing the names
213 | of the modules for which the activations will be returned as
214 | the key of the dict, and the value of the dict is the name
215 | of the returned activation (which the user can specify).
216 |
217 | Examples::
218 |
219 | >>> m = torchvision.models.resnet18(pretrained=True)
220 | >>> # extract layer1 and layer3, giving as names `feat1` and `feat2`
221 | >>> new_m = torchvision.models._utils.IntermediateLayerGetter(m,
222 | >>> {'layer1': 'feat1', 'layer3': 'feat2'})
223 | >>> out = new_m(torch.rand(1, 3, 224, 224))
224 | >>> print([(k, v.shape) for k, v in out.items()])
225 | >>> [('feat1', torch.Size([1, 64, 56, 56])),
226 | >>> ('feat2', torch.Size([1, 256, 14, 14]))]
227 | """
228 | def __init__(self, model, return_layers):
229 | if not set(return_layers).issubset([name for name, _ in model.named_children()]):
230 | raise ValueError("return_layers are not present in model")
231 |
232 | orig_return_layers = return_layers
233 | return_layers = {k: v for k, v in return_layers.items()}
234 | layers = OrderedDict()
235 | for name, module in model.named_children():
236 | layers[name] = module
237 | if name in return_layers:
238 | del return_layers[name]
239 | if not return_layers:
240 | break
241 |
242 | super(IntermediateLayerGetter, self).__init__(layers)
243 | self.return_layers = orig_return_layers
244 |
245 | def forward(self, x):
246 | out = OrderedDict()
247 | for name, module in self.named_children():
248 | x = module(x)
249 | if name in self.return_layers:
250 | out_name = self.return_layers[name]
251 | out[out_name] = x
252 | return out
253 |
254 | '''
255 | Complete segmentation model (backbone + classifier head)
256 | '''
257 | class Simple_Segmentation_Model(nn.Module):
258 | def __init__(self, backbone, classifier):
259 | super(Simple_Segmentation_Model, self).__init__()
260 | self.backbone = backbone
261 | self.classifier = classifier
262 |
263 | def forward(self, x):
264 | input_shape = x.shape[-2:]
265 | features = self.backbone(x)
266 | x = self.classifier(features)
267 | x = F.interpolate(x, size=input_shape, mode='bilinear', align_corners=False)
268 | return x
269 |
270 | class FaceSegmentation(Simple_Segmentation_Model):
271 | pass
272 |
273 | if __name__ == '__main__':
274 | model = load_model('resnet50', num_classes=config.num_classes, output_stride=config.output_stride)
275 | model = model.to(config.device)
276 | # print(model)
277 | input = torch.randn([16, 3, 513, 513])
278 | output = model(input)
279 | print(output.shape) # torch.Size([16, 10, 513, 513])
280 |
281 |
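282 | # Not part of the original module -- a minimal sketch of turning the raw logits into a
283 | # per-pixel label map, assuming `output` has the [N, num_classes, H, W] shape printed above:
284 | # labels = output.max(1)[1] # [N, H, W], integer class ids
285 | # print(labels.shape, labels.unique())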
--------------------------------------------------------------------------------
/faceparse.py:
--------------------------------------------------------------------------------
1 |
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | import torchvision
6 |
7 |
8 | from resnet import Resnet18
9 |
10 |
11 | class ConvBNReLU(nn.Module):
12 | def __init__(self, in_chan, out_chan, ks=3, stride=1, padding=1, *args, **kwargs):
13 | super(ConvBNReLU, self).__init__()
14 | self.conv = nn.Conv2d(in_chan, out_chan, kernel_size=ks, stride=stride, padding=padding, bias=False)
15 | self.bn = nn.BatchNorm2d(out_chan)
16 | self.init_weight()
17 |
18 | def forward(self, x):
19 | x = self.conv(x)
20 | x = F.relu(self.bn(x))
21 | return x
22 |
23 | def init_weight(self):
24 | for ly in self.children():
25 | if isinstance(ly, nn.Conv2d):
26 | nn.init.kaiming_normal_(ly.weight, a=1)
27 | if ly.bias is not None:
28 | nn.init.constant_(ly.bias, 0)
29 |
30 |
31 | class BiSeNetOutput(nn.Module):
32 | def __init__(self, in_chan, mid_chan, n_classes, *args, **kwargs):
33 | super(BiSeNetOutput, self).__init__()
34 | self.conv = ConvBNReLU(in_chan, mid_chan, ks=3, stride=1, padding=1)
35 | self.conv_out = nn.Conv2d(mid_chan, n_classes, kernel_size=1, bias=False)
36 | self.init_weight()
37 |
38 | def forward(self, x):
39 | x = self.conv(x)
40 | x = self.conv_out(x)
41 | return x
42 |
43 | def init_weight(self):
44 | for ly in self.children():
45 | if isinstance(ly, nn.Conv2d):
46 | nn.init.kaiming_normal_(ly.weight, a=1)
47 | if ly.bias is not None:
48 | nn.init.constant_(ly.bias, 0)
49 |
50 | def get_params(self):
51 | wd_params, nowd_params = [], []
52 | for name, module in self.named_modules():
53 | if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d):
54 | wd_params.append(module.weight)
55 | if module.bias is not None:
56 | nowd_params.append(module.bias)
57 | elif isinstance(module, nn.BatchNorm2d):
58 | nowd_params += list(module.parameters())
59 | return wd_params, nowd_params
60 |
61 |
62 | class AttentionRefinementModule(nn.Module):
63 | def __init__(self, in_chan, out_chan, *args, **kwargs):
64 | super(AttentionRefinementModule, self).__init__()
65 | self.conv = ConvBNReLU(in_chan, out_chan, ks=3, stride=1, padding=1)
66 | self.conv_atten = nn.Conv2d(out_chan, out_chan, kernel_size=1, bias=False)
67 | self.bn_atten = nn.BatchNorm2d(out_chan)
68 | self.sigmoid_atten = nn.Sigmoid()
69 | self.init_weight()
70 |
71 | def forward(self, x):
72 | feat = self.conv(x)
73 | atten = F.avg_pool2d(feat, feat.size()[2:])
74 | atten = self.conv_atten(atten)
75 | atten = self.bn_atten(atten)
76 | atten = self.sigmoid_atten(atten)
77 | out = torch.mul(feat, atten)
78 | return out
79 |
80 | def init_weight(self):
81 | for ly in self.children():
82 | if isinstance(ly, nn.Conv2d):
83 | nn.init.kaiming_normal_(ly.weight, a=1)
84 | if ly.bias is not None:
85 | nn.init.constant_(ly.bias, 0)
86 |
87 |
88 | class ContextPath(nn.Module):
89 | def __init__(self, *args, **kwargs):
90 | super(ContextPath, self).__init__()
91 | self.resnet = Resnet18()
92 | self.arm16 = AttentionRefinementModule(256, 128)
93 | self.arm32 = AttentionRefinementModule(512, 128)
94 | self.conv_head32 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1)
95 | self.conv_head16 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1)
96 | self.conv_avg = ConvBNReLU(512, 128, ks=1, stride=1, padding=0)
97 |
98 | self.init_weight()
99 |
100 | def forward(self, x):
101 | H0, W0 = x.size()[2:]
102 | feat8, feat16, feat32 = self.resnet(x)
103 | H8, W8 = feat8.size()[2:]
104 | H16, W16 = feat16.size()[2:]
105 | H32, W32 = feat32.size()[2:]
106 |
107 | avg = F.avg_pool2d(feat32, feat32.size()[2:])
108 | avg = self.conv_avg(avg)
109 | avg_up = F.interpolate(avg, (H32, W32), mode='nearest')
110 |
111 | feat32_arm = self.arm32(feat32)
112 | feat32_sum = feat32_arm + avg_up
113 | feat32_up = F.interpolate(feat32_sum, (H16, W16), mode='nearest')
114 | feat32_up = self.conv_head32(feat32_up)
115 |
116 | feat16_arm = self.arm16(feat16)
117 | feat16_sum = feat16_arm + feat32_up
118 | feat16_up = F.interpolate(feat16_sum, (H8, W8), mode='nearest')
119 | feat16_up = self.conv_head16(feat16_up)
120 |
121 | return feat8, feat16_up, feat32_up # x8, x8, x16
122 |
123 | def init_weight(self):
124 | for ly in self.children():
125 | if isinstance(ly, nn.Conv2d):
126 | nn.init.kaiming_normal_(ly.weight, a=1)
127 | if ly.bias is not None:
128 | nn.init.constant_(ly.bias, 0)
129 |
130 | def get_params(self):
131 | wd_params, nowd_params = [], []
132 | for name, module in self.named_modules():
133 | if isinstance(module, (nn.Linear, nn.Conv2d)):
134 | wd_params.append(module.weight)
135 | if module.bias is not None:
136 | nowd_params.append(module.bias)
137 | elif isinstance(module, nn.BatchNorm2d):
138 | nowd_params += list(module.parameters())
139 | return wd_params, nowd_params
140 |
141 |
142 | # This is not used; it is replaced with the ResNet feature of the same size
143 | class SpatialPath(nn.Module):
144 | def __init__(self, *args, **kwargs):
145 | super(SpatialPath, self).__init__()
146 | self.conv1 = ConvBNReLU(3, 64, ks=7, stride=2, padding=3)
147 | self.conv2 = ConvBNReLU(64, 64, ks=3, stride=2, padding=1)
148 | self.conv3 = ConvBNReLU(64, 64, ks=3, stride=2, padding=1)
149 | self.conv_out = ConvBNReLU(64, 128, ks=1, stride=1, padding=0)
150 | self.init_weight()
151 |
152 | def forward(self, x):
153 | feat = self.conv1(x)
154 | feat = self.conv2(feat)
155 | feat = self.conv3(feat)
156 | feat = self.conv_out(feat)
157 | return feat
158 |
159 | def init_weight(self):
160 | for ly in self.children():
161 | if isinstance(ly, nn.Conv2d):
162 | nn.init.kaiming_normal_(ly.weight, a=1)
163 | if ly.bias is not None:
164 | nn.init.constant_(ly.bias, 0)
165 |
166 | def get_params(self):
167 | wd_params, nowd_params = [], []
168 | for name, module in self.named_modules():
169 | if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d):
170 | wd_params.append(module.weight)
171 | if module.bias is not None:
172 | nowd_params.append(module.bias)
173 | elif isinstance(module, nn.BatchNorm2d):
174 | nowd_params += list(module.parameters())
175 | return wd_params, nowd_params
176 |
177 |
178 | class FeatureFusionModule(nn.Module):
179 | def __init__(self, in_chan, out_chan, *args, **kwargs):
180 | super(FeatureFusionModule, self).__init__()
181 | self.convblk = ConvBNReLU(in_chan, out_chan, ks=1, stride=1, padding=0)
182 | self.conv1 = nn.Conv2d(out_chan, out_chan // 4, kernel_size=1, stride=1, padding=0, bias=False)
183 | self.conv2 = nn.Conv2d(out_chan // 4, out_chan, kernel_size=1, stride=1, padding=0, bias=False)
184 | self.relu = nn.ReLU(inplace=True)
185 | self.sigmoid = nn.Sigmoid()
186 | self.init_weight()
187 |
188 | def forward(self, fsp, fcp):
189 | fcat = torch.cat([fsp, fcp], dim=1)
190 | feat = self.convblk(fcat)
191 | atten = F.avg_pool2d(feat, feat.size()[2:])
192 | atten = self.conv1(atten)
193 | atten = self.relu(atten)
194 | atten = self.conv2(atten)
195 | atten = self.sigmoid(atten)
196 | feat_atten = torch.mul(feat, atten)
197 | feat_out = feat_atten + feat
198 | return feat_out
199 |
200 | def init_weight(self):
201 | for ly in self.children():
202 | if isinstance(ly, nn.Conv2d):
203 | nn.init.kaiming_normal_(ly.weight, a=1)
204 | if ly.bias is not None:
205 | nn.init.constant_(ly.bias, 0)
206 |
207 | def get_params(self):
208 | wd_params, nowd_params = [], []
209 | for name, module in self.named_modules():
210 | if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d):
211 | wd_params.append(module.weight)
212 | if module.bias is not None:
213 | nowd_params.append(module.bias)
214 | elif isinstance(module, nn.BatchNorm2d):
215 | nowd_params += list(module.parameters())
216 | return wd_params, nowd_params
217 |
218 |
219 | class BiSeNet(nn.Module):
220 | def __init__(self, n_classes, *args, **kwargs):
221 | super(BiSeNet, self).__init__()
222 | self.cp = ContextPath()
223 | # here self.sp is deleted
224 | self.ffm = FeatureFusionModule(256, 256)
225 | self.conv_out = BiSeNetOutput(256, 256, n_classes)
226 | self.conv_out16 = BiSeNetOutput(128, 64, n_classes)
227 | self.conv_out32 = BiSeNetOutput(128, 64, n_classes)
228 | self.init_weight()
229 |
230 | def forward(self, x):
231 | H, W = x.size()[2:]
232 | feat_res8, feat_cp8, feat_cp16 = self.cp(x) # here return res3b1 feature
233 | feat_sp = feat_res8 # use res3b1 feature to replace spatial path feature
234 | feat_fuse = self.ffm(feat_sp, feat_cp8)
235 |
236 | feat_out = self.conv_out(feat_fuse)
237 | # feat_out16 = self.conv_out16(feat_cp8)
238 | # feat_out32 = self.conv_out32(feat_cp16)
239 |
240 | feat_out = F.interpolate(feat_out, (H, W), mode='bilinear', align_corners=True)
241 | # feat_out16 = F.interpolate(feat_out16, (H, W), mode='bilinear', align_corners=True)
242 | # feat_out32 = F.interpolate(feat_out32, (H, W), mode='bilinear', align_corners=True)
243 | return feat_out # , feat_out16, feat_out32
244 |
245 | def init_weight(self):
246 | for ly in self.children():
247 | if isinstance(ly, nn.Conv2d):
248 | nn.init.kaiming_normal_(ly.weight, a=1)
249 | if ly.bias is not None:
250 | nn.init.constant_(ly.bias, 0)
251 |
252 | def get_params(self):
253 | wd_params, nowd_params, lr_mul_wd_params, lr_mul_nowd_params = [], [], [], []
254 | for name, child in self.named_children():
255 | child_wd_params, child_nowd_params = child.get_params()
256 | if isinstance(child, FeatureFusionModule) or isinstance(child, BiSeNetOutput):
257 | lr_mul_wd_params += child_wd_params
258 | lr_mul_nowd_params += child_nowd_params
259 | else:
260 | wd_params += child_wd_params
261 | nowd_params += child_nowd_params
262 | return wd_params, nowd_params, lr_mul_wd_params, lr_mul_nowd_params
263 |
264 | if __name__ == "__main__":
265 | net = BiSeNet(19)
266 | # net.cuda()
267 | net.eval()
268 | in_ten = torch.randn(16, 3, 640, 480)
269 | # in_ten = torch.randn(3, 640, 480)
270 | # out, out16, out32 = net(in_ten)
271 | out = net(in_ten)
272 | print(out.shape)
273 | net.get_params()
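274 |
275 | # Not part of the original module -- a minimal sketch of deriving a parsing map from the
276 | # BiSeNet logits for a single 512x512 RGB tensor (random input here, purely illustrative):
277 | # single = torch.randn(1, 3, 512, 512)
278 | # parsing = net(single).argmax(dim=1) # [1, 512, 512] integer class ids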
--------------------------------------------------------------------------------
/faceswap.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python
2 |
3 | # Copyright (c) 2015 Matthew Earl
4 | #
5 | # Permission is hereby granted, free of charge, to any person obtaining a copy
6 | # of this software and associated documentation files (the "Software"), to deal
7 | # in the Software without restriction, including without limitation the rights
8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | # copies of the Software, and to permit persons to whom the Software is
10 | # furnished to do so, subject to the following conditions:
11 | #
12 | # The above copyright notice and this permission notice shall be included
13 | # in all copies or substantial portions of the Software.
14 | #
15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
16 | # OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
17 | # MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN
18 | # NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
19 | # DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
20 | # OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE
21 | # USE OR OTHER DEALINGS IN THE SOFTWARE.
22 |
23 | """
24 | This is the code behind the Switching Eds blog post:
25 | http://matthewearl.github.io/2015/07/28/switching-eds-with-python/
26 | See the above for an explanation of the code below.
27 | To run the script you'll need to install dlib (http://dlib.net) including its
28 | Python bindings, and OpenCV. You'll also need to obtain the trained model from
29 | sourceforge:
30 | http://sourceforge.net/projects/dclib/files/dlib/v18.10/shape_predictor_68_face_landmarks.dat.bz2
31 | Unzip with `bunzip2` and change `PREDICTOR_PATH` to refer to this file. The
32 | script is run like so:
33 | ./faceswap.py <head image> <face image>
34 | If successful, a file `output.jpg` will be produced with the facial features
35 | from `<head image>` replaced with the facial features from `<face image>`.
36 | """
37 |
38 | import cv2
39 | import dlib
40 | import numpy
41 |
42 | import sys
43 |
44 | PREDICTOR_PATH = "./checkpoint/shape_predictor_68_face_landmarks.dat"
45 | SCALE_FACTOR = 1
46 | FEATHER_AMOUNT = 11
47 |
48 | FACE_POINTS = list(range(17, 68))
49 | MOUTH_POINTS = list(range(48, 61))
50 | RIGHT_BROW_POINTS = list(range(17, 22))
51 | LEFT_BROW_POINTS = list(range(22, 27))
52 | RIGHT_EYE_POINTS = list(range(36, 42))
53 | LEFT_EYE_POINTS = list(range(42, 48))
54 | NOSE_POINTS = list(range(27, 35))
55 | JAW_POINTS = list(range(0, 17))
56 |
57 | # Points used to line up the images.
58 | ALIGN_POINTS = (LEFT_BROW_POINTS + RIGHT_EYE_POINTS + LEFT_EYE_POINTS +
59 | RIGHT_BROW_POINTS + NOSE_POINTS + MOUTH_POINTS)
60 |
61 | # Points from the second image to overlay on the first. The convex hull of each
62 | # element will be overlaid.
63 | OVERLAY_POINTS = [
64 | LEFT_EYE_POINTS + RIGHT_EYE_POINTS + LEFT_BROW_POINTS + RIGHT_BROW_POINTS,
65 | NOSE_POINTS + MOUTH_POINTS,
66 | ]
67 |
68 | # Amount of blur to use during colour correction, as a fraction of the
69 | # pupillary distance.
70 | COLOUR_CORRECT_BLUR_FRAC = 0.6
71 |
72 | detector = dlib.get_frontal_face_detector()
73 | predictor = dlib.shape_predictor(PREDICTOR_PATH)
74 |
75 |
76 | class TooManyFaces(Exception):
77 | pass
78 |
79 |
80 | class NoFaces(Exception):
81 | pass
82 |
83 |
84 | def get_landmarks(im):
85 | rects = detector(im, 1)
86 |
87 | if len(rects) > 1:
88 | raise TooManyFaces
89 | if len(rects) == 0:
90 | raise NoFaces
91 |
92 | return numpy.matrix([[p.x, p.y] for p in predictor(im, rects[0]).parts()])
93 |
94 |
95 | def annotate_landmarks(im, landmarks):
96 | im = im.copy()
97 | for idx, point in enumerate(landmarks):
98 | pos = (point[0, 0], point[0, 1])
99 | cv2.putText(im, str(idx), pos,
100 | fontFace=cv2.FONT_HERSHEY_SCRIPT_SIMPLEX,
101 | fontScale=0.4,
102 | color=(0, 0, 255))
103 | cv2.circle(im, pos, 3, color=(0, 255, 255))
104 | return im
105 |
106 |
107 | def draw_convex_hull(im, points, color):
108 | points = cv2.convexHull(points)
109 | cv2.fillConvexPoly(im, points, color=color)
110 |
111 |
112 | def get_face_mask(im, landmarks):
113 | im = numpy.zeros(im.shape[:2], dtype=numpy.float64)
114 |
115 | for group in OVERLAY_POINTS:
116 | draw_convex_hull(im,
117 | landmarks[group],
118 | color=1)
119 |
120 | im = numpy.array([im, im, im]).transpose((1, 2, 0))
121 |
122 | im = (cv2.GaussianBlur(im, (FEATHER_AMOUNT, FEATHER_AMOUNT), 0) > 0) * 1.0
123 | im = cv2.GaussianBlur(im, (FEATHER_AMOUNT, FEATHER_AMOUNT), 0)
124 |
125 | return im
126 |
127 |
128 | def transformation_from_points(points1, points2):
129 | """
130 | Return an affine transformation [s * R | T] such that:
131 | sum ||s*R*p1,i + T - p2,i||^2
132 | is minimized.
133 | """
134 | # Solve the procrustes problem by subtracting centroids, scaling by the
135 | # standard deviation, and then using the SVD to calculate the rotation. See
136 | # the following for more details:
137 | # https://en.wikipedia.org/wiki/Orthogonal_Procrustes_problem
138 |
139 | points1 = points1.astype(numpy.float64)
140 | points2 = points2.astype(numpy.float64)
141 |
142 | c1 = numpy.mean(points1, axis=0)
143 | c2 = numpy.mean(points2, axis=0)
144 | points1 -= c1
145 | points2 -= c2
146 |
147 | s1 = numpy.std(points1)
148 | s2 = numpy.std(points2)
149 | points1 /= s1
150 | points2 /= s2
151 |
152 | U, S, Vt = numpy.linalg.svd(points1.T * points2)
153 |
154 | # The R we seek is in fact the transpose of the one given by U * Vt. This
155 | # is because the above formulation assumes the matrix goes on the right
156 | # (with row vectors) where as our solution requires the matrix to be on the
157 | # left (with column vectors).
158 | R = (U * Vt).T
159 |
160 | return numpy.vstack([numpy.hstack(((s2 / s1) * R,
161 | c2.T - (s2 / s1) * R * c1.T)),
162 | numpy.matrix([0., 0., 1.])])
163 |
164 |
165 | def read_im_and_landmarks(fname):
166 | im = cv2.imread(fname, cv2.IMREAD_COLOR)
167 | im = cv2.resize(im, (im.shape[1] * SCALE_FACTOR,
168 | im.shape[0] * SCALE_FACTOR))
169 | s = get_landmarks(im)
170 |
171 | return im, s
172 |
173 |
174 | def warp_im(im, M, dshape):
175 | output_im = numpy.zeros(dshape, dtype=im.dtype)
176 | cv2.warpAffine(im,
177 | M[:2],
178 | (dshape[1], dshape[0]),
179 | dst=output_im,
180 | borderMode=cv2.BORDER_TRANSPARENT,
181 | flags=cv2.WARP_INVERSE_MAP)
182 | return output_im
183 |
184 |
185 | def correct_colours(im1, im2, landmarks1):
186 | blur_amount = COLOUR_CORRECT_BLUR_FRAC * numpy.linalg.norm(
187 | numpy.mean(landmarks1[LEFT_EYE_POINTS], axis=0) -
188 | numpy.mean(landmarks1[RIGHT_EYE_POINTS], axis=0))
189 | blur_amount = int(blur_amount)
190 | if blur_amount % 2 == 0:
191 | blur_amount += 1
192 | im1_blur = cv2.GaussianBlur(im1, (blur_amount, blur_amount), 0)
193 | im2_blur = cv2.GaussianBlur(im2, (blur_amount, blur_amount), 0)
194 |
195 | # Avoid divide-by-zero errors.
196 | im2_blur += (128 * (im2_blur <= 1.0)).astype(im2_blur.dtype)
197 |
198 | return (im2.astype(numpy.float64) * im1_blur.astype(numpy.float64) /
199 | im2_blur.astype(numpy.float64))
200 |
201 |
202 | template_path = './dat/avg_face.jpg' # use the average face as the base image whose facial features get replaced
203 |
204 | def faceswap(upload_path):
205 | # upload_path = "./dat/lss.jpg"
206 |
207 | im1, landmarks1 = read_im_and_landmarks(template_path)
208 | im2, landmarks2 = read_im_and_landmarks(upload_path)
209 |
210 | M = transformation_from_points(landmarks1[ALIGN_POINTS],
211 | landmarks2[ALIGN_POINTS])
212 |
213 | mask = get_face_mask(im2, landmarks2)
214 | warped_mask = warp_im(mask, M, im1.shape)
215 | combined_mask = numpy.max([get_face_mask(im1, landmarks1), warped_mask],
216 | axis=0)
217 |
218 | warped_im2 = warp_im(im2, M, im1.shape)
219 | warped_corrected_im2 = correct_colours(im1, warped_im2, landmarks1)
220 |
221 | output_im = im1 * (1.0 - combined_mask) + warped_corrected_im2 * combined_mask
222 | return output_im
223 | # cv2.imwrite('output.jpg', output_im)
224 |
225 | if __name__ == '__main__':
226 | im = faceswap("./dat/lss.jpg")
227 | cv2.imshow("im", im.astype(numpy.uint8)) # the swapped image is float64 in the 0-255 range; cast it for correct display
228 | cv2.waitKey()
229 | cv2.imwrite('output.jpg', im)
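230 |
231 | # Not part of the original script -- a minimal sketch of swapping another photo onto the
232 | # average-face template and saving it as 8-bit, assuming the photo holds exactly one face:
233 | # out = faceswap("./dat/16.jpg")
234 | # cv2.imwrite("./dat/swapped_16.jpg", numpy.clip(out, 0, 255).astype(numpy.uint8))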
--------------------------------------------------------------------------------
/imitator.py:
--------------------------------------------------------------------------------
1 | import os
2 | import cv2
3 | import torch
4 | import numpy as np
5 | import torch.nn as nn
6 |
7 | import utils
8 | import config
9 |
10 |
11 | '''
12 | Imitator: a neural network that stands in for the game engine, rendering a face image from the continuous facial parameters.
13 | '''
14 |
15 | class Imitator(nn.Module):
16 | def __init__(self, is_bias=False):
17 | super(Imitator, self).__init__()
18 |
19 | self.model = nn.Sequential(
20 | # 1. (batch, 512, 4, 4)
21 | nn.ConvTranspose2d(config.continuous_params_size, 512, kernel_size=4, bias=is_bias),
22 | nn.BatchNorm2d(512),
23 | nn.ReLU(),
24 | # 2. (batch, 512, 8, 8)
25 | nn.ConvTranspose2d(512, 512, kernel_size=4, stride=2, padding=1, bias=is_bias),
26 | nn.BatchNorm2d(512),
27 | nn.ReLU(),
28 | # 3. (batch, 512, 16, 16)
29 | nn.ConvTranspose2d(512, 512, kernel_size=4, stride=2, padding=1, bias=is_bias),
30 | nn.BatchNorm2d(512),
31 | nn.ReLU(),
32 | # 4. (batch, 256, 32, 32)
33 | nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1, bias=is_bias),
34 | nn.BatchNorm2d(256),
35 | nn.ReLU(),
36 | # 5. (batch, 128, 64, 64)
37 | nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1, bias=is_bias),
38 | nn.BatchNorm2d(128),
39 | nn.ReLU(),
40 | # 6. (batch, 64, 128, 128)
41 | nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1, bias=is_bias),
42 | nn.BatchNorm2d(64),
43 | nn.ReLU(),
44 | # 7. (batch, 64, 256, 256)
45 | nn.ConvTranspose2d(64, 64, kernel_size=4, stride=2, padding=1, bias=is_bias),
46 | nn.BatchNorm2d(64),
47 | nn.ReLU(),
48 | # 8. (batch, 3, 512, 512)
49 | nn.ConvTranspose2d(64, 3, kernel_size=4, stride=2, padding=1, bias=is_bias),
50 | # nn.Tanh(), # tanh is used in DCGAN and in the paper
51 | nn.Sigmoid() # the open-source version uses sigmoid
52 | )
53 |
54 | self._initialize_weights()
55 |
56 | def _initialize_weights(self) -> None:
57 | for m in self.modules():
58 | classname = m.__class__.__name__
59 | if classname.find('Conv') != -1:
60 | nn.init.normal_(m.weight, 0.0, 0.02)
61 | elif classname.find('BatchNorm') != -1:
62 | nn.init.normal_(m.weight, 1.0, 0.02)
63 | nn.init.constant_(m.bias, 0)
64 |
65 | '''
66 | :param params [batch, params_cnt]
67 | :return [batch, 3, 512, 512]
68 | '''
69 | def forward(self, params):
70 | # batch = params.size(0) # 1
71 | # length = params.size(1) # 95
72 | # _params = params.reshape((batch, length, 1, 1)) # [1, 95, 1, 1]
73 |
74 | # Expand the continuous parameters from [batch, continuous_params_size] to [batch, continuous_params_size, 1, 1]
75 | _params = params.unsqueeze(2).unsqueeze(3)
76 | return self.model(_params)
77 |
78 |
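79 | # Not part of the original module -- a minimal smoke test of the generator, assuming
80 | # config.continuous_params_size is set (the weights here are random, so the image is noise):
81 | if __name__ == '__main__':
82 | net = Imitator(is_bias=True)
83 | net.eval()
84 | with torch.no_grad():
85 | img = net(torch.rand(1, config.continuous_params_size))
86 | print(img.shape) # expected: torch.Size([1, 3, 512, 512])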
--------------------------------------------------------------------------------
/lightcnn.py:
--------------------------------------------------------------------------------
1 | import os
2 | import cv2
3 | import time
4 | import math
5 | import numpy as np
6 | import torch
7 | import torch.nn as nn
8 | import torch.nn.functional as F
9 | from torchvision import transforms
10 |
11 | import config
12 |
13 |
14 | class mfm(nn.Module):
15 | def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, type=1):
16 | super(mfm, self).__init__()
17 | self.out_channels = out_channels
18 | if type == 1:
19 | self.filter = nn.Conv2d(in_channels, 2 * out_channels, kernel_size=kernel_size, stride=stride,
20 | padding=padding)
21 | else:
22 | self.filter = nn.Linear(in_channels, 2 * out_channels)
23 |
24 | def forward(self, x):
25 | x = self.filter(x)
26 | out = torch.split(x, self.out_channels, 1)
27 | return torch.max(out[0], out[1])
28 |
29 |
30 | class group(nn.Module):
31 | def __init__(self, in_channels, out_channels, kernel_size, stride, padding):
32 | super(group, self).__init__()
33 | self.conv_a = mfm(in_channels, in_channels, 1, 1, 0)
34 | self.conv = mfm(in_channels, out_channels, kernel_size, stride, padding)
35 |
36 | def forward(self, x):
37 | x = self.conv_a(x)
38 | x = self.conv(x)
39 | return x
40 |
41 |
42 | class resblock(nn.Module):
43 | def __init__(self, in_channels, out_channels):
44 | super(resblock, self).__init__()
45 | self.conv1 = mfm(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
46 | self.conv2 = mfm(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
47 |
48 | def forward(self, x):
49 | res = x
50 | out = self.conv1(x)
51 | out = self.conv2(out)
52 | out = out + res
53 | return out
54 |
55 |
56 | class network_9layers(nn.Module):
57 | def __init__(self, num_classes=79077):
58 | super(network_9layers, self).__init__()
59 | self.features = nn.Sequential(
60 | mfm(1, 48, 5, 1, 2),
61 | nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=True),
62 | group(48, 96, 3, 1, 1),
63 | nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=True),
64 | group(96, 192, 3, 1, 1),
65 | nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=True),
66 | group(192, 128, 3, 1, 1),
67 | group(128, 128, 3, 1, 1),
68 | nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=True),
69 | )
70 | self.fc1 = mfm(8 * 8 * 128, 256, type=0)
71 | self.fc2 = nn.Linear(256, num_classes)
72 |
73 | def forward(self, x):
74 | x = self.features(x)
75 | x = x.view(x.size(0), -1)
76 | x = self.fc1(x)
77 | x = F.dropout(x, training=self.training)
78 | out = self.fc2(x)
79 | return out, x
80 |
81 |
82 | class network_29layers(nn.Module):
83 | def __init__(self, block, layers, num_classes=79077):
84 | super(network_29layers, self).__init__()
85 | self.conv1 = mfm(1, 48, 5, 1, 2)
86 | self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=True)
87 | self.block1 = self._make_layer(block, layers[0], 48, 48)
88 | self.group1 = group(48, 96, 3, 1, 1)
89 | self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=True)
90 | self.block2 = self._make_layer(block, layers[1], 96, 96)
91 | self.group2 = group(96, 192, 3, 1, 1)
92 | self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=True)
93 | self.block3 = self._make_layer(block, layers[2], 192, 192)
94 | self.group3 = group(192, 128, 3, 1, 1)
95 | self.block4 = self._make_layer(block, layers[3], 128, 128)
96 | self.group4 = group(128, 128, 3, 1, 1)
97 | self.pool4 = nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=True)
98 | self.fc = mfm(8 * 8 * 128, 256, type=0)
99 | self.fc2 = nn.Linear(256, num_classes)
100 |
101 | def _make_layer(self, block, num_blocks, in_channels, out_channels):
102 | layers = []
103 | for i in range(0, num_blocks):
104 | layers.append(block(in_channels, out_channels))
105 | return nn.Sequential(*layers)
106 |
107 | def forward(self, x):
108 | x = self.conv1(x)
109 | x = self.pool1(x)
110 |
111 | x = self.block1(x)
112 | x = self.group1(x)
113 | x = self.pool2(x)
114 |
115 | x = self.block2(x)
116 | x = self.group2(x)
117 | x = self.pool3(x)
118 |
119 | x = self.block3(x)
120 | x = self.group3(x)
121 | x = self.block4(x)
122 | x = self.group4(x)
123 | x = self.pool4(x)
124 |
125 | x = x.view(x.size(0), -1)
126 | fc = self.fc(x)
127 | fc = F.dropout(fc, training=self.training)
128 | out = self.fc2(fc)
129 | return out, fc
130 |
131 |
132 | class network_29layers_v2(nn.Module):
133 | def __init__(self, block, layers, num_classes=79077):
134 | super(network_29layers_v2, self).__init__()
135 | self.conv1 = mfm(1, 48, 5, 1, 2)
136 | self.block1 = self._make_layer(block, layers[0], 48, 48)
137 | self.group1 = group(48, 96, 3, 1, 1)
138 | self.block2 = self._make_layer(block, layers[1], 96, 96)
139 | self.group2 = group(96, 192, 3, 1, 1)
140 | self.block3 = self._make_layer(block, layers[2], 192, 192)
141 | self.group3 = group(192, 128, 3, 1, 1)
142 | self.block4 = self._make_layer(block, layers[3], 128, 128)
143 | self.group4 = group(128, 128, 3, 1, 1)
144 | self.fc = nn.Linear(8 * 8 * 128, 256)
145 | self.fc2 = nn.Linear(256, num_classes, bias=False)
146 |
147 | def _make_layer(self, block, num_blocks, in_channels, out_channels):
148 | layers = []
149 | for i in range(0, num_blocks):
150 | layers.append(block(in_channels, out_channels))
151 | return nn.Sequential(*layers)
152 |
153 | def forward(self, x):
154 | x = self.conv1(x)
155 | x = F.max_pool2d(x, 2) + F.avg_pool2d(x, 2)
156 |
157 | x = self.block1(x)
158 | x = self.group1(x)
159 | x = F.max_pool2d(x, 2) + F.avg_pool2d(x, 2)
160 |
161 | x = self.block2(x)
162 | x = self.group2(x)
163 | x = F.max_pool2d(x, 2) + F.avg_pool2d(x, 2)
164 |
165 | x = self.block3(x)
166 | x = self.group3(x)
167 | x = self.block4(x)
168 | x = self.group4(x)
169 | x = F.max_pool2d(x, 2) + F.avg_pool2d(x, 2)
170 |
171 | x = x.view(x.size(0), -1)
172 | fc = self.fc(x)
173 | x = F.dropout(fc, training=self.training)
174 | out = self.fc2(x)
175 | return out, fc
176 |
177 |
178 | def LightCNN_9Layers(**kwargs):
179 | model = network_9layers(**kwargs)
180 | return model
181 |
182 |
183 | def LightCNN_29Layers(**kwargs):
184 | model = network_29layers(resblock, [1, 2, 3, 4], **kwargs)
185 | return model
186 |
187 |
188 | def LightCNN_29Layers_v2(**kwargs):
189 | model = network_29layers_v2(resblock, [1, 2, 3, 4], **kwargs)
190 | return model
191 |
192 | '''
193 | Save the features extracted by LightCNN.
194 | '''
195 | def save_feature(save_path, img_name, features):
196 | img_path = os.path.join(save_path, img_name)
197 | img_dir = os.path.dirname(img_path) + '/'
198 | if not os.path.exists(img_dir):
199 | os.makedirs(img_dir)
200 | fname = os.path.splitext(img_path)[0]
201 | fname = fname + '.feat'
202 | fid = open(fname, 'wb')
203 | fid.write(features)
204 | fid.close()
205 |
206 | if __name__ == '__main__':
207 | model = LightCNN_29Layers_v2(num_classes=80013)
208 | model.eval()
209 | if torch.cuda.is_available():
210 | model = torch.nn.DataParallel(model).cuda()
211 | else:
212 | model = torch.nn.DataParallel(model)
213 | checkpoint = torch.load("./checkpoint/LightCNN_29Layers_V2_checkpoint.pth.tar", map_location="cpu")
214 | model.load_state_dict(checkpoint['state_dict'])
215 |
216 | transform = transforms.Compose([transforms.ToTensor()])
217 | img_path = "./dat/0020.png"
218 | img = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)
219 | img = cv2.resize(img, (128, 128))
220 | img = np.reshape(img, (128, 128, 1))
221 | img = transform(img)
222 |
223 | input = torch.zeros(1, 1, 128, 128)
224 | input[0, :, :, :] = img
225 |
226 | start = time.time()
227 | if config.use_gpu:
228 | input = input.cuda()
229 |
230 | with torch.no_grad():
231 | _, features = model(input)
232 | print(features.shape)
233 |
234 | fname = "./output/features.feat"
235 | fid = open(fname, 'wb')
236 | fid.write(features.data.cpu().numpy()[0])
237 | fid.close()
238 |
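239 | # Not part of the original module -- a minimal sketch of comparing two 256-d LightCNN feature
240 | # vectors with cosine similarity (one common way to score identity; the evaluate scripts use
241 | # utils.discriminative_loss, which may differ in detail):
242 | # f1 = torch.randn(1, 256)
243 | # f2 = torch.randn(1, 256)
244 | # print(F.cosine_similarity(f1, f2)) # 1.0 means identical direction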
--------------------------------------------------------------------------------
/model_process.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | import config
4 |
5 | '''
6 | Post-process the model state-dict keys.
7 | Checkpoints trained on multiple GPUs (DataParallel) prefix every key with "module.", which must be stripped to match the single-GPU model format.
8 | '''
9 |
10 | if __name__ == '__main__':
11 | # checkpoint = torch.load(config.imitator_model, map_location=config.device)
12 | # for new_key, old_key in checkpoint.items():
13 | # checkpoint[new_key] = checkpoint.pop(old_key)
14 | # torch.save(checkpoint, './checkpoint/epoch_950.pt')
15 |
16 | new_model = {k.replace('module.', ''): v for k, v in torch.load(config.imitator_model, map_location=config.device).items()}
17 | torch.save(new_model, './checkpoint/epoch_950.pt')
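18 |
19 | # Not part of the original script -- a hedged sketch of checking that the prefix was stripped:
20 | # assert all(not k.startswith('module.') for k in new_model)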
--------------------------------------------------------------------------------
/myimitator.py:
--------------------------------------------------------------------------------
1 | from __future__ import (absolute_import, division, print_function, unicode_literals)
2 |
3 | import os
4 | import cv2
5 | import copy
6 | import math
7 | import json
8 | import torch
9 | import numpy as np
10 | import torch.nn as nn
11 | import torch.nn.functional as F
12 |
13 | import utils
14 | import config
15 |
16 | '''
17 | Custom Imitator
18 | 1. spectral normalization (SN) applied after conv, linear and embedding layers
19 | 2. self-attention added at a specified layer
20 | 3. custom batch norm
21 | '''
22 |
23 | # Use spectral normalization (SN)
24 | def snconv2d(eps=1e-12, **kwargs):
25 | return nn.utils.spectral_norm(nn.Conv2d(**kwargs), eps=eps)
26 |
27 | def snlinear(eps=1e-12, **kwargs):
28 | return nn.utils.spectral_norm(nn.Linear(**kwargs), eps=eps)
29 |
30 | def sn_embedding(eps=1e-12, **kwargs):
31 | return nn.utils.spectral_norm(nn.Embedding(**kwargs), eps=eps)
32 |
33 | # Self-attention layer
34 | class SelfAttn(nn.Module):
35 | def __init__(self, in_channels, eps=1e-12):
36 | super(SelfAttn, self).__init__()
37 | self.in_channels = in_channels
38 | self.snconv1x1_theta = snconv2d(in_channels=in_channels, out_channels=in_channels//8,
39 | kernel_size=1, bias=False, eps=eps)
40 | self.snconv1x1_phi = snconv2d(in_channels=in_channels, out_channels=in_channels//8,
41 | kernel_size=1, bias=False, eps=eps)
42 | self.snconv1x1_g = snconv2d(in_channels=in_channels, out_channels=in_channels//2,
43 | kernel_size=1, bias=False, eps=eps)
44 | self.snconv1x1_o_conv = snconv2d(in_channels=in_channels//2, out_channels=in_channels,
45 | kernel_size=1, bias=False, eps=eps)
46 | self.maxpool = nn.MaxPool2d(2, stride=2, padding=0)
47 | self.softmax = nn.Softmax(dim=-1)
48 | self.gamma = nn.Parameter(torch.zeros(1))
49 |
50 | def forward(self, x):
51 | _, ch, h, w = x.size()
52 | # Theta path
53 | theta = self.snconv1x1_theta(x)
54 | theta = theta.view(-1, ch//8, h*w)
55 | # Phi path
56 | phi = self.snconv1x1_phi(x)
57 | phi = self.maxpool(phi)
58 | phi = phi.view(-1, ch//8, h*w//4)
59 | # Attn map
60 | attn = torch.bmm(theta.permute(0, 2, 1), phi)
61 | attn = self.softmax(attn)
62 | # g path
63 | g = self.snconv1x1_g(x)
64 | g = self.maxpool(g)
65 | g = g.view(-1, ch//2, h*w//4)
66 | # Attn_g - o_conv
67 | attn_g = torch.bmm(g, attn.permute(0, 2, 1))
68 | attn_g = attn_g.view(-1, ch//2, h, w)
69 | attn_g = self.snconv1x1_o_conv(attn_g)
70 | # Out
71 | out = x + self.gamma*attn_g
72 | return out
73 |
74 | # Custom batch norm
75 | class BigGANBatchNorm(nn.Module):
76 | """ This is a batch norm module that can handle conditional input and can be provided with pre-computed
77 | activation means and variances for various truncation parameters.
78 |
79 | We cannot just rely on torch.batch_norm since it cannot handle
80 | batched weights (pytorch 1.0.1). We compute batch norm ourselves without updating the running means and variances.
81 | If you want to train this model you should add running means and variance computation logic.
82 | """
83 | def __init__(self, num_features, condition_vector_dim=None, n_stats=51, eps=1e-4, conditional=True):
84 | super(BigGANBatchNorm, self).__init__()
85 | self.num_features = num_features
86 | self.eps = eps
87 | self.conditional = conditional
88 |
89 | # We use pre-computed statistics for n_stats values of truncation between 0 and 1
90 | self.register_buffer('running_means', torch.zeros(n_stats, num_features))
91 | self.register_buffer('running_vars', torch.ones(n_stats, num_features))
92 | self.step_size = 1.0 / (n_stats - 1)
93 |
94 | if conditional:
95 | assert condition_vector_dim is not None
96 | self.scale = snlinear(in_features=condition_vector_dim, out_features=num_features, bias=False, eps=eps)
97 | self.offset = snlinear(in_features=condition_vector_dim, out_features=num_features, bias=False, eps=eps)
98 | else:
99 | self.weight = torch.nn.Parameter(torch.Tensor(num_features))
100 | self.bias = torch.nn.Parameter(torch.Tensor(num_features))
101 |
102 | def forward(self, x, truncation, condition_vector=None):
103 | # Retrieve the pre-computed statistics associated with this truncation
104 | coef, start_idx = math.modf(truncation / self.step_size)
105 | start_idx = int(start_idx)
106 | if coef != 0.0: # Interpolate
107 | running_mean = self.running_means[start_idx] * coef + self.running_means[start_idx + 1] * (1 - coef)
108 | running_var = self.running_vars[start_idx] * coef + self.running_vars[start_idx + 1] * (1 - coef)
109 | else:
110 | running_mean = self.running_means[start_idx]
111 | running_var = self.running_vars[start_idx]
112 |
113 | if self.conditional:
114 | running_mean = running_mean.unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
115 | running_var = running_var.unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
116 |
117 | weight = 1 + self.scale(condition_vector).unsqueeze(-1).unsqueeze(-1)
118 | bias = self.offset(condition_vector).unsqueeze(-1).unsqueeze(-1)
119 |
120 | out = (x - running_mean) / torch.sqrt(running_var + self.eps) * weight + bias
121 | else:
122 | out = F.batch_norm(x, running_mean, running_var, self.weight, self.bias,
123 | training=False, momentum=0.0, eps=self.eps)
124 | return out
125 |
126 | class GenBlock(nn.Module):
127 | def __init__(self, in_size, out_size, condition_vector_dim, reduction_factor=4, up_sample=False,
128 | n_stats=51, eps=1e-12):
129 | super(GenBlock, self).__init__()
130 | self.up_sample = up_sample
131 | self.drop_channels = (in_size != out_size)
132 | middle_size = in_size // reduction_factor
133 |
134 | self.bn_0 = BigGANBatchNorm(in_size, condition_vector_dim, n_stats=n_stats, eps=eps, conditional=True)
135 | self.conv_0 = snconv2d(in_channels=in_size, out_channels=middle_size, kernel_size=1, eps=eps)
136 |
137 | self.bn_1 = BigGANBatchNorm(middle_size, condition_vector_dim, n_stats=n_stats, eps=eps, conditional=True)
138 | self.conv_1 = snconv2d(in_channels=middle_size, out_channels=middle_size, kernel_size=3, padding=1, eps=eps)
139 |
140 | self.bn_2 = BigGANBatchNorm(middle_size, condition_vector_dim, n_stats=n_stats, eps=eps, conditional=True)
141 | self.conv_2 = snconv2d(in_channels=middle_size, out_channels=middle_size, kernel_size=3, padding=1, eps=eps)
142 |
143 | self.bn_3 = BigGANBatchNorm(middle_size, condition_vector_dim, n_stats=n_stats, eps=eps, conditional=True)
144 | self.conv_3 = snconv2d(in_channels=middle_size, out_channels=out_size, kernel_size=1, eps=eps)
145 |
146 | self.relu = nn.ReLU()
147 |
148 | def forward(self, x, cond_vector, truncation):
149 | x0 = x
150 |
151 | x = self.bn_0(x, truncation, cond_vector)
152 | x = self.relu(x)
153 | x = self.conv_0(x)
154 |
155 | x = self.bn_1(x, truncation, cond_vector)
156 | x = self.relu(x)
157 | if self.up_sample:
158 | x = F.interpolate(x, scale_factor=2, mode='nearest')
159 | x = self.conv_1(x)
160 |
161 | x = self.bn_2(x, truncation, cond_vector)
162 | x = self.relu(x)
163 | x = self.conv_2(x)
164 |
165 | x = self.bn_3(x, truncation, cond_vector)
166 | x = self.relu(x)
167 | x = self.conv_3(x)
168 |
169 | if self.drop_channels:
170 | new_channels = x0.shape[1] // 2
171 | x0 = x0[:, :new_channels, ...]
172 | if self.up_sample:
173 | x0 = F.interpolate(x0, scale_factor=2, mode='nearest')
174 |
175 | out = x + x0
176 | return out
177 |
178 | class MyImitator(nn.Module):
179 | def __init__(self):
180 | super(MyImitator, self).__init__()
181 |
182 |         # 1. Load the config file
183 | with open(config.config_jsonfile, "r", encoding='utf-8') as reader:
184 | text = reader.read()
185 | self.conf = BigGANConfig()
186 | for key, value in json.loads(text).items():
187 | self.conf.__dict__[key] = value
188 |
189 |         # Define the network structure
190 | # self.embeddings = nn.Linear(config.num_classes, config.continuous_params_size, bias=False)
191 |
192 | ch = self.conf.channel_width
193 | condition_vector_dim = config.continuous_params_size
194 |
195 | self.gen_z = snlinear(in_features=condition_vector_dim, out_features=4*4*16*ch, eps=self.conf.eps)
196 | layers = []
197 | for i, layer in enumerate(self.conf.layers):
198 |             if i == self.conf.attention_layer_position:  # add self-attention at the specified layer
199 | layers.append(SelfAttn(ch * layer[1], eps=self.conf.eps))
200 | layers.append(GenBlock(ch * layer[1],
201 | ch * layer[2],
202 | condition_vector_dim,
203 | up_sample=layer[0],
204 | n_stats=self.conf.n_stats,
205 | eps=self.conf.eps))
206 | self.layers = nn.ModuleList(layers)
207 |
208 | self.bn = BigGANBatchNorm(ch, n_stats=self.conf.n_stats, eps=self.conf.eps, conditional=False)
209 | self.relu = nn.ReLU()
210 | self.conv_to_rgb = snconv2d(in_channels=ch, out_channels=ch, kernel_size=3, padding=1, eps=self.conf.eps)
211 | self.tanh = nn.Tanh()
212 |
213 | def forward(self, cond_vector, truncation=0.4):
214 | # cond_vector = cond_vector.unsqueeze(2).unsqueeze(3)
215 |         z = self.gen_z(cond_vector)  # cond_vector [batch_size, config.continuous_params_size], z [batch_size, 4*4*16*self.conf.channel_width]
216 |
217 | # We use this conversion step to be able to use TF weights:
218 | # TF convention on shape is [batch, height, width, channels]
219 | # PT convention on shape is [batch, channels, height, width]
220 | z = z.view(-1, 4, 4, 16 * self.conf.channel_width) # [batch_size, 4, 4, 2048]
221 | z = z.permute(0, 3, 1, 2).contiguous() # [batch_size, 2048, 4, 4]
222 |
223 | for i, layer in enumerate(self.layers):
224 | if isinstance(layer, GenBlock):
225 | z = layer(z, cond_vector, truncation)
226 | else:
227 | z = layer(z)
228 |
229 | z = self.bn(z, truncation) # [1, 128, 512, 512]
230 | z = self.relu(z) # [1, 128, 512, 512]
231 | z = self.conv_to_rgb(z) # [1, 128, 512, 512]
232 | z = z[:, :3, ...] # [1, 3, 512, 512]
233 | z = self.tanh(z) # [1, 3, 512, 512]
234 | return z
235 |
236 | '''
237 | Config for the custom Imitator
238 | '''
239 | class BigGANConfig(object):
240 | """ Configuration class to store the configuration of a `BigGAN`.
241 | Defaults are for the 128x128 model.
242 | layers tuple are (up-sample in the layer ?, input channels, output channels)
243 | """
244 | def __init__(self,
245 | output_dim=512,
246 | z_dim=512,
247 | class_embed_dim=512,
248 | channel_width=512,
249 | num_classes=1000,
250 |                  # (up-sample?, input_channels, output_channels)
251 | layers=[(False, 16, 16),
252 | (True, 16, 16),
253 | (False, 16, 16),
254 | (True, 16, 8),
255 | (False, 8, 8),
256 | (True, 8, 4),
257 | (False, 4, 4),
258 | (True, 4, 2),
259 | (False, 2, 2),
260 | (True, 2, 1)],
261 | attention_layer_position=8,
262 | eps=1e-4,
263 | n_stats=51):
264 | """Constructs BigGANConfig. """
265 | self.output_dim = output_dim
266 | self.z_dim = z_dim
267 | self.class_embed_dim = class_embed_dim
268 | self.channel_width = channel_width
269 | self.num_classes = num_classes
270 | self.layers = layers
271 | self.attention_layer_position = attention_layer_position
272 | self.eps = eps
273 | self.n_stats = n_stats
274 |
275 | @classmethod
276 | def from_dict(cls, json_object):
277 | """Constructs a `BigGANConfig` from a Python dictionary of parameters."""
278 | config = BigGANConfig()
279 | for key, value in json_object.items():
280 | config.__dict__[key] = value
281 | return config
282 |
283 | @classmethod
284 | def from_json_file(cls, json_file):
285 | """Constructs a `BigGANConfig` from a json file of parameters."""
286 | with open(json_file, "r", encoding='utf-8') as reader:
287 | text = reader.read()
288 | return cls.from_dict(json.loads(text))
289 |
290 | def __repr__(self):
291 | return str(self.to_json_string())
292 |
293 | def to_dict(self):
294 | """Serializes this instance to a Python dictionary."""
295 | output = copy.deepcopy(self.__dict__)
296 | return output
297 |
298 | def to_json_string(self):
299 | """Serializes this instance to a JSON string."""
300 | return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"
301 |
302 |
303 | if __name__ == '__main__':
304 | # # d1 = nn.Linear(20, 40)
305 | # d1 = nn.Conv2d(10, 20, kernel_size=3, stride=1, padding=0, groups=10)
306 | # d2 = nn.utils.spectral_norm(d1)
307 | # print(d2)
308 | # print(d2.weight_u.size())
309 |
310 | t_params = torch.full([1, config.continuous_params_size], 0.5, dtype=torch.float32)
311 | imitator = MyImitator()
312 | print(imitator)
313 | output = imitator(t_params)
314 | print(output.shape) # [1, 3, 512, 512]
315 |
--------------------------------------------------------------------------------
/papers/1909.01064v1(Face-to-Parameter Translation ).pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/csnowhermit/face2parameter/c9c7f849e91a8b51c209b53c8b4d5c402ed791d6/papers/1909.01064v1(Face-to-Parameter Translation ).pdf
--------------------------------------------------------------------------------
/papers/2003.05653(Towards High-Fidelity 3D Face Reconstruction).pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/csnowhermit/face2parameter/c9c7f849e91a8b51c209b53c8b4d5c402ed791d6/papers/2003.05653(Towards High-Fidelity 3D Face Reconstruction).pdf
--------------------------------------------------------------------------------
/papers/2008.07132v1(Fast and Robust Face-to-Parameter Translation).pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/csnowhermit/face2parameter/c9c7f849e91a8b51c209b53c8b4d5c402ed791d6/papers/2008.07132v1(Fast and Robust Face-to-Parameter Translation).pdf
--------------------------------------------------------------------------------
/random_gen_image.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import torch
3 | import random
4 |
5 | import config
6 | from imitator import Imitator
7 |
8 |
9 | if __name__ == '__main__':
10 | imitator = Imitator()
11 | imitator_model = torch.load(config.imitator_model, map_location=torch.device('cpu'))
12 |     imitator.load_state_dict(imitator_model)  # load the already-processed weights here
13 |
14 | for i in range(10):
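    |         # Half of the samples draw parameters uniformly from [0, 1]; the other half
    |         # draw from N(0.5, 1) and are clamped back into [0, 1] below.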
15 | if random.randint(1, 10) % 2 == 0:
16 | t_params = torch.rand((1, config.continuous_params_size), dtype=torch.float32)
17 | else:
18 | # t_params = torch.randn((1, config.continuous_params_size), dtype=torch.float32)
19 | t_params = torch.normal(0.5, 1, (1, config.continuous_params_size))
20 | print("2.1.", t_params)
21 | t_params.data = t_params.data.clamp(0., 1.)
22 | print("2.2.", t_params)
23 | # t_params = torch.rand((1, config.continuous_params_size), dtype=torch.float32)
24 |
25 |         y_ = imitator(t_params)  # [1, 3, 512, 512], [batch_size, c, h, w]
26 | tmp = y_.detach().cpu().numpy()[0]
27 | tmp = tmp.transpose(2, 1, 0)
28 | tmp = tmp * 255.0
29 | # cv2.imshow("y_", tmp)
30 | # cv2.waitKey()
31 | print(type(tmp), tmp.shape)
32 | cv2.imwrite("./dat/gen_%d.jpg" % i, tmp)
33 |         print("Saved:", i)
34 |
35 |
--------------------------------------------------------------------------------
/resnet.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python
2 | # -*- encoding: utf-8 -*-
3 |
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 | import torch.utils.model_zoo as modelzoo
8 |
9 | import config
10 |
11 | # from modules.bn import InPlaceABNSync as BatchNorm2d
12 |
13 | # resnet18_url = 'https://download.pytorch.org/models/resnet18-5c106cde.pth'
14 | resnet18_url = "./checkpoint/resnet18-5c106cde.pth"
15 |
16 | def conv3x3(in_planes, out_planes, stride=1):
17 | """3x3 convolution with padding"""
18 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)
19 |
20 |
21 | class BasicBlock(nn.Module):
22 | def __init__(self, in_chan, out_chan, stride=1):
23 | super(BasicBlock, self).__init__()
24 | self.conv1 = conv3x3(in_chan, out_chan, stride)
25 | self.bn1 = nn.BatchNorm2d(out_chan)
26 | self.conv2 = conv3x3(out_chan, out_chan)
27 | self.bn2 = nn.BatchNorm2d(out_chan)
28 | self.relu = nn.ReLU(inplace=True)
29 | self.downsample = None
30 | if in_chan != out_chan or stride != 1:
31 | self.downsample = nn.Sequential(nn.Conv2d(in_chan, out_chan, kernel_size=1, stride=stride, bias=False),
32 | nn.BatchNorm2d(out_chan), )
33 |
34 | def forward(self, x):
35 | residual = self.conv1(x)
36 | residual = F.relu(self.bn1(residual))
37 | residual = self.conv2(residual)
38 | residual = self.bn2(residual)
39 |
40 | shortcut = x
41 | if self.downsample is not None:
42 | shortcut = self.downsample(x)
43 |
44 | out = shortcut + residual
45 | out = self.relu(out)
46 | return out
47 |
48 |
49 | def create_layer_basic(in_chan, out_chan, bnum, stride=1):
50 | layers = [BasicBlock(in_chan, out_chan, stride=stride)]
51 | for i in range(bnum - 1):
52 | layers.append(BasicBlock(out_chan, out_chan, stride=1))
53 | return nn.Sequential(*layers)
54 |
55 |
56 | class Resnet18(nn.Module):
57 | def __init__(self):
58 | super(Resnet18, self).__init__()
59 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
60 | self.bn1 = nn.BatchNorm2d(64)
61 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
62 | self.layer1 = create_layer_basic(64, 64, bnum=2, stride=1)
63 | self.layer2 = create_layer_basic(64, 128, bnum=2, stride=2)
64 | self.layer3 = create_layer_basic(128, 256, bnum=2, stride=2)
65 | self.layer4 = create_layer_basic(256, 512, bnum=2, stride=2)
66 | self.init_weight()
67 |
68 | def forward(self, x):
69 | x = self.conv1(x)
70 | x = F.relu(self.bn1(x))
71 | x = self.maxpool(x)
72 |
73 | x = self.layer1(x)
74 | feat8 = self.layer2(x) # 1/8
75 | feat16 = self.layer3(feat8) # 1/16
76 | feat32 = self.layer4(feat16) # 1/32
77 | return feat8, feat16, feat32
78 |
79 | def init_weight(self):
80 | if resnet18_url.startswith("http"):
81 | state_dict = modelzoo.load_url(resnet18_url)
82 | else:
83 | if config.use_gpu:
84 | state_dict = torch.load(resnet18_url)
85 | else:
86 | state_dict = torch.load(resnet18_url, map_location=torch.device('cpu'))
87 | self_state_dict = self.state_dict()
88 | for k, v in state_dict.items():
89 | if 'fc' in k: continue
90 | self_state_dict.update({k: v})
91 | self.load_state_dict(self_state_dict)
92 |
93 | def get_params(self):
94 | wd_params, nowd_params = [], []
95 | for name, module in self.named_modules():
96 | if isinstance(module, (nn.Linear, nn.Conv2d)):
97 | wd_params.append(module.weight)
98 |                 if module.bias is not None:
99 | nowd_params.append(module.bias)
100 | elif isinstance(module, nn.BatchNorm2d):
101 | nowd_params += list(module.parameters())
102 | return wd_params, nowd_params
103 |
104 |
105 | if __name__ == "__main__":
106 | net = Resnet18()
107 | x = torch.randn(16, 3, 224, 224)
108 | out = net(x)
109 | print(out[0].size(), out[1].size(), out[2].size())
110 | net.get_params()
111 |
--------------------------------------------------------------------------------
/tools/demo.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from torch.autograd import Variable
4 | from torch.nn import functional as F
5 | import torch.utils.data
6 | import numpy as np
7 | from scipy.stats import entropy
8 | import torchvision.datasets as dset
9 | import torchvision.transforms as transforms
10 | import os
11 | import dataloader_own
12 | from scipy import linalg
13 |
14 | from inception import InceptionV3
15 |
16 | # we should use the same mean and std for the inception v3 model in both the training and testing process
17 | # reference web page: https://pytorch.org/hub/pytorch_vision_inception_v3/
18 | mean_inception = [0.485, 0.456, 0.406]
19 | std_inception = [0.229, 0.224, 0.225]
20 |
21 | def compute_FID(img1, img2, batch_size=1):
22 | device = torch.device("cuda:0") # you can change the index of cuda
23 |
24 | N1 = len(img1)
25 | N2 = len(img2)
26 |     n_act = 2048 # dimensionality of the final pooling layer's activations
27 |
28 | # Set up dataloaders
29 | dataloader1 = torch.utils.data.DataLoader(img1, batch_size=batch_size)
30 | dataloader2 = torch.utils.data.DataLoader(img2, batch_size=batch_size)
31 |
32 | # Load inception model
33 | # inception_model = inception_v3(pretrained=True, transform_input=False).to(device)
34 |
35 | block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[n_act]
36 | inception_model = InceptionV3([block_idx]).to(device)
37 | inception_model.eval()
38 |
39 | # get the activations
40 | def get_activations(x):
41 | x = inception_model(x)[0]
42 | return x.cpu().data.numpy().reshape(batch_size, -1)
43 |
44 | act1 = np.zeros((N1, n_act))
45 | act2 = np.zeros((N2, n_act))
46 |
47 | data = [dataloader1, dataloader2]
48 | act = [act1, act2]
49 | for n, loader in enumerate(data):
50 | for i, batch in enumerate(loader, 0):
51 | batch = batch[0].to(device)
52 | batch_size_i = batch.size()[0]
53 | activation = get_activations(batch)
54 |
55 | act[n][i * batch_size:i * batch_size + batch_size_i] = activation
56 |
57 |     # compute the activation statistics: mean and covariance
58 | def compute_act_mean_std(act):
59 | mu = np.mean(act, axis=0)
60 | sigma = np.cov(act, rowvar=False)
61 | return mu, sigma
62 | mu_act1, sigma_act1 = compute_act_mean_std(act1)
63 | mu_act2, sigma_act2 = compute_act_mean_std(act2)
64 |
65 | # compute FID
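    |     # Fréchet distance between two Gaussians N(mu1, sigma1) and N(mu2, sigma2):
    |     #   FID = ||mu1 - mu2||^2 + Tr(sigma1 + sigma2 - 2 * (sigma1 @ sigma2)^(1/2))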
66 | def _compute_FID(mu1, mu2, sigma1, sigma2,eps=1e-6):
67 | mu1 = np.atleast_1d(mu1)
68 | mu2 = np.atleast_1d(mu2)
69 | sigma1 = np.atleast_2d(sigma1)
70 | sigma2 = np.atleast_2d(sigma2)
71 |
72 | diff = mu1 - mu2
73 |
74 | covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
75 | if not np.isfinite(covmean).all():
76 | msg = ('fid calculation produces singular product; '
77 | 'adding %s to diagonal of cov estimates') % eps
78 | print(msg)
79 | offset = np.eye(sigma1.shape[0]) * eps
80 | covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))
81 |
82 | # Numerical error might give slight imaginary component
83 | if np.iscomplexobj(covmean):
84 | if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
85 | m = np.max(np.abs(covmean.imag))
86 | raise ValueError('Imaginary component {}'.format(m))
87 | covmean = covmean.real
88 |
89 | tr_covmean = np.trace(covmean)
90 |
91 | FID = diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * tr_covmean
92 |
93 | return FID
94 |
95 | FID = _compute_FID(mu_act1, mu_act2, sigma_act1, sigma_act2)
96 | return FID
97 |
98 |
99 | # main function to compute FID
100 | data_root = os.path.join('/PATH/TO/YOUR/IMAGE1')
101 | my_dataset_fakeB = dataloader_own.image_loader(data_root, batch_size=64, img_size=299, resize=True, rotation=False, normalize=[mean_inception, std_inception])
102 | data_root = os.path.join('/PATH/TO/YOUR/IMAGE2')
103 | my_dataset_realB = dataloader_own.image_loader(data_root, batch_size=64, img_size=299, resize=True, rotation=False, normalize=[mean_inception, std_inception])
104 |
105 | FID = compute_FID(my_dataset_fakeB, my_dataset_realB)
106 | print(FID)
--------------------------------------------------------------------------------
/tools/face_align.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import dlib
3 | import numpy
4 | import sys
5 | import matplotlib.pyplot as plt
6 |
7 | PREDICTOR_PATH = r"../checkpoint/shape_predictor_68_face_landmarks.dat" # model file for the 68 facial landmarks
8 | SCALE_FACTOR = 1 # image scaling factor
9 | FEATHER_AMOUNT = 15 # feathering range for mask edges; larger means stronger feathering; must be odd, never even
10 |
11 | # The 68 landmark points
12 | FACE_POINTS = list(range(17, 68)) # face
13 | MOUTH_POINTS = list(range(48, 61)) # mouth
14 | RIGHT_BROW_POINTS = list(range(17, 22)) # right eyebrow
15 | LEFT_BROW_POINTS = list(range(22, 27)) # left eyebrow
16 | RIGHT_EYE_POINTS = list(range(36, 42)) # right eye
17 | LEFT_EYE_POINTS = list(range(42, 48)) # left eye
18 | NOSE_POINTS = list(range(27, 35)) # nose
19 | JAW_POINTS = list(range(0, 17)) # jaw
20 |
21 | # Landmarks from the second face used when overlaying it on the first face
22 | # The feature points cover both eyes, the eyebrows, the nose and the mouth
23 | # (open question: does using more points introduce any interference?)
24 | ALIGN_POINTS = (FACE_POINTS + LEFT_BROW_POINTS + RIGHT_EYE_POINTS + LEFT_EYE_POINTS +
25 | RIGHT_BROW_POINTS + NOSE_POINTS + MOUTH_POINTS + JAW_POINTS)
26 |
27 | # Points from the second image to overlay on the first. The convex hull of each
28 | # element will be overlaid.
29 | OVERLAY_POINTS = [
30 | LEFT_EYE_POINTS + RIGHT_EYE_POINTS + LEFT_BROW_POINTS + RIGHT_BROW_POINTS,
31 | NOSE_POINTS + MOUTH_POINTS,
32 | ]
33 | # eyes and eyebrows as one group
34 | # nose and mouth as a separate group
35 |
36 | # Blur amount used for colour correction, as a fraction of the pupillary distance
37 | COLOUR_CORRECT_BLUR_FRAC = 0.6
38 |
39 | # Instantiate the face detector
40 | detector = dlib.get_frontal_face_detector()
41 | # Load the trained model
42 | # and instantiate the landmark predictor
43 | predictor = dlib.shape_predictor(PREDICTOR_PATH)
44 |
45 |
46 | # Two exception classes for unexpected cases
47 | class TooManyFaces(Exception):
48 | pass
49 |
50 |
51 | class NoFaces(Exception):
52 | pass
53 |
54 |
55 | def get_landmarks(im):
56 | '''
57 |     Get the 68 landmarks from the predictor
58 | '''
59 | rects = detector(im, 1)
60 |
61 | if len(rects) > 1:
62 | raise TooManyFaces
63 | if len(rects) == 0:
64 | raise NoFaces
65 |
66 |     return numpy.matrix([[p.x, p.y] for p in predictor(im, rects[0]).parts()]) # a 68x2 matrix
67 |
68 |
69 | def annotate_landmarks(im, landmarks):
70 | '''
71 |     Draw the facial landmarks onto the image
72 | '''
73 | im = im.copy()
74 | for idx, point in enumerate(landmarks):
75 | pos = (point[0, 0], point[0, 1])
76 | cv2.putText(im, str(idx), pos,
77 | fontFace=cv2.FONT_HERSHEY_SCRIPT_SIMPLEX,
78 | fontScale=0.4,
79 | color=(0, 0, 255))
80 | cv2.circle(im, pos, 3, color=(0, 255, 255))
81 | return im
82 |
83 |
84 | def draw_convex_hull(im, points, color):
85 | '''
86 |     # Compute the convex hull of the points and fill it as a convex polygon
87 | '''
88 | points = cv2.convexHull(points)
89 | cv2.fillConvexPoly(im, points, color=color)
90 |
91 |
92 | def get_face_mask(im, landmarks):
93 |     '''Build an image mask for the facial features (eyebrows, eyes, nose and mouth).
94 |     When the mask is applied to the original image, only the regions where the mask is white remain visible
95 |     and the black regions are hidden, so the mask effectively "crops" the image.
96 |     Example effect: https://dn-anything-about-doc.qbox.me/document-uid242676labid2260timestamp1477921310170.png/wm
97 |     get_face_mask() takes an image and a landmark matrix and draws two white convex polygons: one around the eye region and one around the nose and mouth. The mask edge is then feathered outwards by FEATHER_AMOUNT pixels to help hide any discontinuities.
98 | '''
99 | im = numpy.zeros(im.shape[:2], dtype=numpy.float64)
100 |
101 | for group in OVERLAY_POINTS:
102 | draw_convex_hull(im,
103 | landmarks[group],
104 | color=1)
105 |
106 | im = numpy.array([im, im, im]).transpose((1, 2, 0))
107 |
108 | im = (cv2.GaussianBlur(im, (FEATHER_AMOUNT, FEATHER_AMOUNT), 0) > 0) * 1.0
109 | im = cv2.GaussianBlur(im, (FEATHER_AMOUNT, FEATHER_AMOUNT), 0)
110 |
111 | return im
112 |
113 | # Returns an affine transformation
114 | def transformation_from_points(points1, points2):
115 | """
116 | Return an affine transformation [s * R | T] such that:
117 | sum ||s*R*p1,i + T - p2,i||^2
118 | is minimized.
119 | """
120 | # Solve the procrustes problem by subtracting centroids, scaling by the
121 | # standard deviation, and then using the SVD to calculate the rotation. See
122 | # the following for more details:
123 | # https://en.wikipedia.org/wiki/Orthogonal_Procrustes_problem
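    |     # In matrix form, the returned 3x3 transform is
    |     #   [ (s2/s1) * R    c2.T - (s2/s1) * R * c1.T ]
    |     #   [   0      0                 1             ]
    |     # where R = (U * Vt).T comes from the SVD of points1.T * points2, and c1/c2 and s1/s2
    |     # are the per-face centroids and standard deviations computed below.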
124 |
125 |     points1 = points1.astype(numpy.float64) # the selected face landmarks
126 | points2 = points2.astype(numpy.float64)
127 |
128 |     # Standardise the data: subtract the mean, then divide by the std, giving zero mean and unit variance
129 |     # Each face is standardised independently
130 |     c1 = numpy.mean(points1, axis=0) # mean of x and of y separately
131 |     c2 = numpy.mean(points2, axis=0)
132 |     points1 -= c1 # deviation from the mean, [43, 2]
133 | points2 -= c2
134 |
135 | s1 = numpy.std(points1)
136 | s2 = numpy.std(points2)
137 | points1 /= s1 #
138 | points2 /= s2
139 |
140 | U, S, Vt = numpy.linalg.svd(points1.T * points2) #
141 |
142 | # The R we seek is in fact the transpose of the one given by U * Vt. This
143 | # is because the above formulation assumes the matrix goes on the right
144 |     # (with row vectors) whereas our solution requires the matrix to be on the
145 | # left (with column vectors).
146 | R = (U * Vt).T # [2, 2]
147 |
148 | return numpy.vstack([numpy.hstack(((s2 / s1) * R, c2.T - (s2 / s1) * R * c1.T)),
149 | numpy.matrix([0., 0., 1.])])
150 |
151 |
152 | def read_im_and_landmarks(fname):
153 | im = cv2.imread(fname, cv2.IMREAD_COLOR)
154 | im = cv2.resize(im, (im.shape[1] * SCALE_FACTOR,
155 | im.shape[0] * SCALE_FACTOR))
156 | s = get_landmarks(im) # [68, 2]
157 |
158 | return im, s
159 |
160 |
161 | def warp_im(im, M, dshape):
162 | '''
163 |     The mask produced by get_face_mask cannot be used directly: the two input images usually differ in resolution,
164 |     and even at the same resolution the faces will differ in size and angle because of shooting distance and viewpoint,
165 |     so we cannot simply cut out the second person's facial features and paste them straight onto the first face; we must warp them so that the two feature regions overlap as closely as possible.
166 | 
167 |     The affine function cv2.warpAffine performs the geometric transform.
168 |     Its three main arguments are the input image, the transform matrix (np.float32) and the output width/height.
169 | 
170 |     This is the main alignment helper.
171 | '''
172 | output_im = numpy.zeros(dshape, dtype=im.dtype) # [512, 512, 3]
173 | cv2.warpAffine(im,
174 | M[:2],
175 | (dshape[1], dshape[0]),
176 | dst=output_im,
177 | borderMode=cv2.BORDER_TRANSPARENT,
178 | flags=cv2.WARP_INVERSE_MAP)
179 | return output_im
180 |
181 |
182 | def correct_colours(im1, im2, landmarks1):
183 | '''
184 |     Adjust the skin colour so that the two images blend more naturally when composited.
185 | '''
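    |     # The result is im2 rescaled by the ratio of the Gaussian-blurred images (local colour averages),
    |     # i.e. im2 * blur(im1) / blur(im2), which shifts im2's skin tones towards im1's.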
186 | blur_amount = COLOUR_CORRECT_BLUR_FRAC * numpy.linalg.norm(
187 | numpy.mean(landmarks1[LEFT_EYE_POINTS], axis=0) -
188 | numpy.mean(landmarks1[RIGHT_EYE_POINTS], axis=0))
189 | blur_amount = int(blur_amount)
190 | if blur_amount % 2 == 0:
191 | blur_amount += 1
192 | im1_blur = cv2.GaussianBlur(im1, (blur_amount, blur_amount), 0)
193 | im2_blur = cv2.GaussianBlur(im2, (blur_amount, blur_amount), 0)
194 |
195 | # Avoid divide-by-zero errors.
196 | im2_blur += (128 * (im2_blur <= 1.0)).astype(im2_blur.dtype)
197 |
198 | return (im2.astype(numpy.float64) * im1_blur.astype(numpy.float64) /
199 | im2_blur.astype(numpy.float64))
200 |
201 |
202 | # Face-swap function
203 | def Switch_face(Base_path, cover_path):
204 |     im1, landmarks1 = read_im_and_landmarks(Base_path) # base image
205 |     im2, landmarks2 = read_im_and_landmarks(cover_path) # image pasted on top
206 |
207 |     if len(landmarks1) == 0 or len(landmarks2) == 0:
208 |         raise RuntimeError("No face detected!")
209 |     if len(landmarks1) > 1 or len(landmarks2) > 1:
210 |         raise RuntimeError("More than one face detected!")
211 |
212 |     # landmarks1[ALIGN_POINTS] are the selected face landmarks
213 | M = transformation_from_points(landmarks1[ALIGN_POINTS],
214 | landmarks2[ALIGN_POINTS])
215 | mask = get_face_mask(im2, landmarks2)
216 | warped_mask = warp_im(mask, M, im1.shape)
217 | combined_mask = numpy.max([get_face_mask(im1, landmarks1), warped_mask],
218 | axis=0)
219 | warped_im2 = warp_im(im2, M, im1.shape)
220 | warped_corrected_im2 = correct_colours(im1, warped_im2, landmarks1)
221 |
222 | output_im = im1 * (1.0 - combined_mask) + warped_corrected_im2 * combined_mask
223 | return output_im
224 |
225 |
226 | # Face alignment function
227 | def face_Align(Base_path, cover_path):
228 |     im1, landmarks1 = read_im_and_landmarks(Base_path) # base (template) image
229 |     im2, landmarks2 = read_im_and_landmarks(cover_path) # image to be aligned
230 | 
231 |     # Compute the affine transformation matrix
232 | M = transformation_from_points(landmarks1[ALIGN_POINTS],
233 | landmarks2[ALIGN_POINTS])
234 | warped_im2 = warp_im(im2, M, im1.shape)
235 | return warped_im2
236 |
237 | FEATHER_AMOUNT = 19
238 |
239 | template_img = '../dat/avg_face.jpg' # template
240 | process_img = '../dat/0020.png'
241 | warped_mask = face_Align(template_img, process_img)
242 | cv2.imwrite("../dat/result_0020.jpg", warped_mask)
243 |
244 | # plt.subplot(111)
245 | # plt.imshow(warped_mask) # display the result
246 | # plt.show()
--------------------------------------------------------------------------------
/tools/fid.py:
--------------------------------------------------------------------------------
1 | import os
2 | from torch.utils.data import Dataset, DataLoader
3 | from torchvision.models.inception import Inception3
4 | from torchvision.models.utils import load_state_dict_from_url
5 |
6 | mean_inception = [0.485, 0.456, 0.406]
7 | std_inception = [0.229, 0.224, 0.225]
8 |
9 | class FID_Dataset(Dataset):
10 | def __init__(self, imgpath, transform):
11 | self.imgpath = imgpath
12 | self.fileList = [os.path.join(imgpath, file) for file in os.listdir(imgpath)]
13 | self.transform = transform
14 |
15 |
16 | '''
17 | Compute the FID metric (left as a placeholder)
18 | '''
19 | def compute_FID():
20 | pass
21 |
22 |
23 | if __name__ == '__main__':
24 | model = Inception3()
25 | print(model)
26 | state_dict = load_state_dict_from_url("https://download.pytorch.org/models/inception_v3_google-1a9a5a14.pth", progress=True)
27 | model.load_state_dict(state_dict)
28 |
--------------------------------------------------------------------------------
/tools/lsgan.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import matplotlib.pyplot as plt
4 |
5 | import torch
6 | import torch.nn as nn
7 | from torch.autograd import Variable
8 | from torch.utils.data import DataLoader
9 |
10 | import torchvision.transforms as transforms
11 | from torchvision.utils import save_image
12 | from torchvision import datasets
13 |
14 | '''
15 | LSGAN: least-squares generative adversarial loss.
16 | Common GAN problems: low quality of generated images and unstable training.
17 | The traditional approach:
18 | 1. Uses a cross-entropy loss (it only cares about True/False), so the generator stops improving fake images that the
19 | discriminator already labels True, even when they lie far from the decision boundary, i.e. far from the real data.
20 | The generated images therefore stay mediocre: once the discriminator is fooled, the cross-entropy loss is already small.
21 | The least-squares loss instead keeps pulling samples that are far from the decision boundary towards it, even after they fool the discriminator.
22 | 2. The sigmoid saturates for very large or very small inputs, causing vanishing gradients; with least squares the gradient vanishes only at x=1.
23 | Traditional GANs use adversarial_loss = torch.nn.BCELoss() (or torch.nn.BCEWithLogitsLoss() if you hit
24 | "RuntimeError: all elements of input should be between 0 and 1"); LSGAN simply swaps this for torch.nn.MSELoss(), as done below.
25 | '''
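    | # For reference, the standard LSGAN objectives (Mao et al., 2017), with real label 1 and fake label 0,
    | # which the MSE losses in the training loop below implement, are:
    | #   min_D  1/2 * E_x[(D(x) - 1)^2] + 1/2 * E_z[D(G(z))^2]
    | #   min_G  1/2 * E_z[(D(G(z)) - 1)^2]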
26 |
27 | # Configuration
28 | epochs=200
29 | batch_size=64
30 | learning_rate=0.0002
31 | latent_dim=100 # images are generated from a 100-dimensional latent vector
32 | image_size=32
33 | sample_interval=100 # save a sample every 100 batches
34 | save_sample_path = "../output/images/" # directory for sample images
35 | if os.path.exists("../output/images/") is False:
36 | os.makedirs("../output/images/")
37 |
38 | cuda = True if torch.cuda.is_available() else False
39 |
40 | def weights_init_normal(m):
41 | classname = m.__class__.__name__
42 | if classname.find("Conv") != -1:
43 | torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
44 | elif classname.find("BatchNorm") != -1:
45 | torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
46 | torch.nn.init.constant_(m.bias.data, 0.0)
47 |
48 | # Generator
49 | class Generator(nn.Module):
50 | def __init__(self):
51 | super(Generator, self).__init__()
52 |
53 | self.init_size = image_size // 4
54 | self.l1 = nn.Sequential(nn.Linear(latent_dim, 128 * self.init_size ** 2))
55 |
56 | self.conv_blocks = nn.Sequential(
57 | nn.Upsample(scale_factor=2),
58 | nn.Conv2d(128, 128, 3, stride=1, padding=1),
59 | nn.BatchNorm2d(128, 0.8),
60 | nn.LeakyReLU(0.2, inplace=True),
61 | nn.Upsample(scale_factor=2),
62 | nn.Conv2d(128, 64, 3, stride=1, padding=1),
63 | nn.BatchNorm2d(64, 0.8),
64 | nn.LeakyReLU(0.2, inplace=True),
65 |             nn.Conv2d(64, 1, 3, stride=1, padding=1), # finally produce a 1-channel image
66 | nn.Tanh(),
67 | )
68 |
69 | def forward(self, z):
70 | out = self.l1(z)
71 | out = out.view(out.shape[0], 128, self.init_size, self.init_size)
72 | img = self.conv_blocks(out)
73 | return img
74 |
75 |
76 | class Discriminator(nn.Module):
77 | def __init__(self):
78 | super(Discriminator, self).__init__()
79 |
80 | def discriminator_block(in_filters, out_filters, bn=True):
81 | block = [nn.Conv2d(in_filters, out_filters, 3, 2, 1), nn.LeakyReLU(0.2, inplace=True), nn.Dropout2d(0.25)]
82 | if bn:
83 | block.append(nn.BatchNorm2d(out_filters, 0.8))
84 | return block
85 |
86 | self.model = nn.Sequential(
87 | *discriminator_block(1, 16, bn=False),
88 | *discriminator_block(16, 32),
89 | *discriminator_block(32, 64),
90 | *discriminator_block(64, 128),
91 | )
92 |
93 | # The height and width of downsampled image
94 | ds_size = image_size // 2 ** 4
95 | self.adv_layer = nn.Linear(128 * ds_size ** 2, 1)
96 |
97 | def forward(self, img):
98 | out = self.model(img)
99 | out = out.view(out.shape[0], -1)
100 | validity = self.adv_layer(out)
101 |
102 | return validity
103 |
104 |
105 | # Use MSE loss here, in line with LSGAN's least-squares idea
106 | # MSE: mean squared error; BCE: binary cross-entropy (the BCELoss used by the traditional GAN)
107 | adversarial_loss = torch.nn.MSELoss()
108 | # adversarial_loss = torch.nn.BCEWithLogitsLoss() # if you hit "RuntimeError: all elements of input should be between 0 and 1", use torch.nn.BCEWithLogitsLoss()
109 |
110 | # Initialize generator and discriminator
111 | generator = Generator()
112 | discriminator = Discriminator()
113 |
114 | if cuda:
115 | generator.cuda()
116 | discriminator.cuda()
117 | adversarial_loss.cuda()
118 |
119 | # Initialize weights
120 | generator.apply(weights_init_normal)
121 | discriminator.apply(weights_init_normal)
122 |
123 | # Configure data loader
124 | os.makedirs("../output/mnist", exist_ok=True)
125 | dataloader = torch.utils.data.DataLoader(
126 | datasets.MNIST(
127 | "../output/mnist",
128 | train=True,
129 | download=True,
130 | transform=transforms.Compose(
131 | [transforms.Resize(image_size), transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]
132 | ),
133 | ),
134 | batch_size=batch_size,
135 | shuffle=True,
136 | )
137 |
138 | # Optimizers
139 | optimizer_G = torch.optim.Adam(generator.parameters(), lr=learning_rate, betas=(0.5, 0.999))
140 | optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=learning_rate, betas=(0.5, 0.999))
141 |
142 | Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor
143 |
144 | # Training loop
145 | g_loss_list = []
146 | d_loss_list = []
147 |
148 | for epoch in range(epochs):
149 | for i, (imgs, _) in enumerate(dataloader): # images [batch_size, 1, image_size, image_size]
150 |
151 | # Adversarial ground truths
152 | valid = Variable(Tensor(imgs.shape[0], 1).fill_(1.0), requires_grad=False) # [batch_size, 1]
153 | fake = Variable(Tensor(imgs.shape[0], 1).fill_(0.0), requires_grad=False) # [batch_size, 0]
154 |
155 |         real_imgs = Variable(imgs.type(Tensor)) # real images
156 |
157 |         # Train the generator
158 | optimizer_G.zero_grad()
159 | z = Variable(Tensor(np.random.normal(0, 1, (imgs.shape[0], latent_dim)))) # [batch_size, latent_dim]
160 | gen_imgs = generator(z) # [batch_size, 1, image_size, image_size]
161 |
162 |         # The generator is optimised so that D(G(z)) moves towards 1
163 |         g_loss = adversarial_loss(discriminator(gen_imgs), valid) # MSE between the discriminator output and the "real" target
164 |
165 | g_loss.backward()
166 | optimizer_G.step()
167 |
168 |         # Train the discriminator
169 | optimizer_D.zero_grad()
170 |
171 |         # LSGAN: instead of plain cross-entropy, compute the mean squared error against the real/fake targets
172 |         real_loss = adversarial_loss(discriminator(real_imgs), valid) # real images should be classified as real
173 |         fake_loss = adversarial_loss(discriminator(gen_imgs.detach()), fake) # fake images should be classified as fake
174 | d_loss = 0.5 * (real_loss + fake_loss)
175 |
176 | d_loss.backward()
177 | optimizer_D.step()
178 |
179 | print(
180 | "[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]"
181 | % (epoch, epochs, i, len(dataloader), d_loss.item(), g_loss.item())
182 | )
183 |
184 | if i % sample_interval == 0:
185 | save_image(gen_imgs.data[:25], "%s/%d_%d.png" % (save_sample_path, epoch, i), nrow=5, normalize=True)
186 | g_loss_list.append(g_loss.item())
187 | d_loss_list.append(d_loss.item())
188 |
189 | if epoch >= 0:
190 | plt.figure()
191 | plt.subplot(121)
192 | plt.plot(np.arange(0, len(g_loss_list)), g_loss_list)
193 | plt.subplot(122)
194 | plt.plot(np.arange(0, len(d_loss_list)), d_loss_list)
195 | plt.savefig("loss.jpg")
196 | plt.close("all")
--------------------------------------------------------------------------------
/train_myimitator.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | import os
3 | import random
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.parallel
7 | import torch.nn.functional as F
8 | from torch.optim import lr_scheduler
9 | import torch.backends.cudnn as cudnn
10 | import torch.optim as optim
11 | import torch.utils.data
12 | from torchvision import transforms as T
13 | from torch.utils.data import DataLoader, Dataset
14 | import json
15 | import torchvision.utils as vutils
16 | import numpy as np
17 | import matplotlib.pyplot as plt
18 | from PIL import Image
19 | import time
20 |
21 | import copy
22 | import math
23 |
24 | '''
25 | To train on a server, only this single self-contained file needs to be uploaded.
26 | '''
27 |
28 | # Set random seed for reproducibility
29 | manualSeed = 999
30 | # manualSeed = random.randint(1, 10000) # use if you want new results
31 | print("Random Seed: ", manualSeed)
32 | random.seed(manualSeed)
33 | torch.manual_seed(manualSeed)
34 |
35 | # Batch size during training
36 | batch_size = 16
37 | image_size = 512
38 | num_epochs = 1000
39 | lr = 0.01
40 | ngpu = 2
41 |
42 | image_root = "F:/dataset/face_20211203_20000_nojiemao/"
43 |
44 | class Imitator_Dataset(Dataset):
45 | def __init__(self, params_root, image_root, mode="train"):
46 | self.image_root = image_root
47 | self.mode = mode
48 | with open(params_root, encoding='utf-8') as f:
49 | self.params = json.load(f)
50 |
51 | def __getitem__(self, index):
52 | if self.mode == "val":
53 | img = Image.open(os.path.join(self.image_root, '%d.png' % (index + 18000))).convert("RGB")
54 | param = torch.tensor(self.params['%d.png' % (index + 18000)])
55 | else:
56 | img = Image.open(os.path.join(self.image_root, '%d.png' % index)).convert("RGB")
57 | param = torch.tensor(self.params['%d.png' % index])
58 | img = T.ToTensor()(img)
59 | return param, img
60 |
61 | def __len__(self):
62 | if self.mode == "train":
63 | return 18000
64 | else:
65 | return 2000
66 |
67 |
68 | train_dataset = Imitator_Dataset(image_root + "param.json", image_root + "face_train/", mode="train")
69 | val_dataset = Imitator_Dataset(image_root + "param.json", image_root + "face_val/", mode="val")
70 | train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
71 | val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
72 |
73 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
74 |
75 |
76 | # real_batch = next(iter(val_dataloader))
77 | # plt.figure(figsize=(4, 4))
78 | # plt.axis("off")
79 | # plt.title("Training Images")
80 | # plt.imshow(np.transpose(vutils.make_grid(real_batch[1].to(device)[:16], nrow=4, padding=2, normalize=True).cpu(), (1, 2, 0)))
81 | # plt.show()
82 | # vutils.save_image(vutils.make_grid(real_batch[1].to(device)[:16], nrow=4, padding=2, normalize=True).cpu(), "./a.jpg")
83 |
84 |
85 |
86 | '''
87 | Custom Imitator
88 | 1. Apply spectral normalization (SN) after conv, linear and embedding layers
89 | 2. Add self-attention at a specified layer
90 | 3. Use a custom batch norm
91 | '''
92 |
93 | # Use spectral normalization (SN)
94 | def snconv2d(eps=1e-12, **kwargs):
95 | return nn.utils.spectral_norm(nn.Conv2d(**kwargs), eps=eps)
96 |
97 | def snlinear(eps=1e-12, **kwargs):
98 | return nn.utils.spectral_norm(nn.Linear(**kwargs), eps=eps)
99 |
100 | def sn_embedding(eps=1e-12, **kwargs):
101 | return nn.utils.spectral_norm(nn.Embedding(**kwargs), eps=eps)
102 |
103 | # Self-attention layer
104 | class SelfAttn(nn.Module):
105 | def __init__(self, in_channels, eps=1e-12):
106 | super(SelfAttn, self).__init__()
107 | self.in_channels = in_channels
108 | self.snconv1x1_theta = snconv2d(in_channels=in_channels, out_channels=in_channels//8,
109 | kernel_size=1, bias=False, eps=eps)
110 | self.snconv1x1_phi = snconv2d(in_channels=in_channels, out_channels=in_channels//8,
111 | kernel_size=1, bias=False, eps=eps)
112 | self.snconv1x1_g = snconv2d(in_channels=in_channels, out_channels=in_channels//2,
113 | kernel_size=1, bias=False, eps=eps)
114 | self.snconv1x1_o_conv = snconv2d(in_channels=in_channels//2, out_channels=in_channels,
115 | kernel_size=1, bias=False, eps=eps)
116 | self.maxpool = nn.MaxPool2d(2, stride=2, padding=0)
117 | self.softmax = nn.Softmax(dim=-1)
118 | self.gamma = nn.Parameter(torch.zeros(1))
119 |
120 | def forward(self, x):
121 | _, ch, h, w = x.size()
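    |         # Attention over spatial positions: theta is [N, C/8, H*W]; phi and g are max-pooled 2x
    |         # to [N, C/8, H*W/4] and [N, C/2, H*W/4]; attn = softmax(theta^T @ phi) is [N, H*W, H*W/4];
    |         # g re-weighted by attn is projected back to C channels and added to x, scaled by gamma.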
122 | # Theta path
123 | theta = self.snconv1x1_theta(x)
124 | theta = theta.view(-1, ch//8, h*w)
125 | # Phi path
126 | phi = self.snconv1x1_phi(x)
127 | phi = self.maxpool(phi)
128 | phi = phi.view(-1, ch//8, h*w//4)
129 | # Attn map
130 | attn = torch.bmm(theta.permute(0, 2, 1), phi)
131 | attn = self.softmax(attn)
132 | # g path
133 | g = self.snconv1x1_g(x)
134 | g = self.maxpool(g)
135 | g = g.view(-1, ch//2, h*w//4)
136 | # Attn_g - o_conv
137 | attn_g = torch.bmm(g, attn.permute(0, 2, 1))
138 | attn_g = attn_g.view(-1, ch//2, h, w)
139 | attn_g = self.snconv1x1_o_conv(attn_g)
140 | # Out
141 | out = x + self.gamma*attn_g
142 | return out
143 |
144 | # Custom batch norm
145 | class BigGANBatchNorm(nn.Module):
146 | """ This is a batch norm module that can handle conditional input and can be provided with pre-computed
147 | activation means and variances for various truncation parameters.
148 |
149 | We cannot just rely on torch.batch_norm since it cannot handle
150 |     batched weights (pytorch 1.0.1). We compute batch_norm ourselves without updating the running means and variances.
151 |     If you want to train this model you should add logic to update the running means and variances.
152 | """
153 | def __init__(self, num_features, condition_vector_dim=None, n_stats=51, eps=1e-4, conditional=True):
154 | super(BigGANBatchNorm, self).__init__()
155 | self.num_features = num_features
156 | self.eps = eps
157 | self.conditional = conditional
158 |
159 | # We use pre-computed statistics for n_stats values of truncation between 0 and 1
160 | self.register_buffer('running_means', torch.zeros(n_stats, num_features))
161 | self.register_buffer('running_vars', torch.ones(n_stats, num_features))
162 | self.step_size = 1.0 / (n_stats - 1)
163 |
164 | if conditional:
165 | assert condition_vector_dim is not None
166 | self.scale = snlinear(in_features=condition_vector_dim, out_features=num_features, bias=False, eps=eps)
167 | self.offset = snlinear(in_features=condition_vector_dim, out_features=num_features, bias=False, eps=eps)
168 | else:
169 | self.weight = torch.nn.Parameter(torch.Tensor(num_features))
170 | self.bias = torch.nn.Parameter(torch.Tensor(num_features))
171 |
172 | def forward(self, x, truncation, condition_vector=None):
173 |         # Retrieve the pre-computed statistics associated with this truncation value
174 | coef, start_idx = math.modf(truncation / self.step_size)
175 | start_idx = int(start_idx)
176 | if coef != 0.0: # Interpolate
177 | running_mean = self.running_means[start_idx] * coef + self.running_means[start_idx + 1] * (1 - coef)
178 | running_var = self.running_vars[start_idx] * coef + self.running_vars[start_idx + 1] * (1 - coef)
179 | else:
180 | running_mean = self.running_means[start_idx]
181 | running_var = self.running_vars[start_idx]
182 |
183 | if self.conditional:
184 | running_mean = running_mean.unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
185 | running_var = running_var.unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
186 |
187 | weight = 1 + self.scale(condition_vector).unsqueeze(-1).unsqueeze(-1)
188 | bias = self.offset(condition_vector).unsqueeze(-1).unsqueeze(-1)
189 |
190 | out = (x - running_mean) / torch.sqrt(running_var + self.eps) * weight + bias
191 | else:
192 | out = F.batch_norm(x, running_mean, running_var, self.weight, self.bias,
193 | training=False, momentum=0.0, eps=self.eps)
194 | return out
195 |
196 | class GenBlock(nn.Module):
197 | def __init__(self, in_size, out_size, condition_vector_dim, reduction_factor=4, up_sample=False,
198 | n_stats=51, eps=1e-12):
199 | super(GenBlock, self).__init__()
200 | self.up_sample = up_sample
201 | self.drop_channels = (in_size != out_size)
202 | middle_size = in_size // reduction_factor
203 |
204 | self.bn_0 = BigGANBatchNorm(in_size, condition_vector_dim, n_stats=n_stats, eps=eps, conditional=True)
205 | self.conv_0 = snconv2d(in_channels=in_size, out_channels=middle_size, kernel_size=1, eps=eps)
206 |
207 | self.bn_1 = BigGANBatchNorm(middle_size, condition_vector_dim, n_stats=n_stats, eps=eps, conditional=True)
208 | self.conv_1 = snconv2d(in_channels=middle_size, out_channels=middle_size, kernel_size=3, padding=1, eps=eps)
209 |
210 | self.bn_2 = BigGANBatchNorm(middle_size, condition_vector_dim, n_stats=n_stats, eps=eps, conditional=True)
211 | self.conv_2 = snconv2d(in_channels=middle_size, out_channels=middle_size, kernel_size=3, padding=1, eps=eps)
212 |
213 | self.bn_3 = BigGANBatchNorm(middle_size, condition_vector_dim, n_stats=n_stats, eps=eps, conditional=True)
214 | self.conv_3 = snconv2d(in_channels=middle_size, out_channels=out_size, kernel_size=1, eps=eps)
215 |
216 | self.relu = nn.ReLU()
217 |
218 | def forward(self, x, cond_vector, truncation):
219 | x0 = x
220 |
221 | x = self.bn_0(x, truncation, cond_vector)
222 | x = self.relu(x)
223 | x = self.conv_0(x)
224 |
225 | x = self.bn_1(x, truncation, cond_vector)
226 | x = self.relu(x)
227 | if self.up_sample:
228 | x = F.interpolate(x, scale_factor=2, mode='nearest')
229 | x = self.conv_1(x)
230 |
231 | x = self.bn_2(x, truncation, cond_vector)
232 | x = self.relu(x)
233 | x = self.conv_2(x)
234 |
235 | x = self.bn_3(x, truncation, cond_vector)
236 | x = self.relu(x)
237 | x = self.conv_3(x)
238 |
239 | if self.drop_channels:
240 | new_channels = x0.shape[1] // 2
241 | x0 = x0[:, :new_channels, ...]
242 | if self.up_sample:
243 | x0 = F.interpolate(x0, scale_factor=2, mode='nearest')
244 |
245 | out = x + x0
246 | return out
247 |
248 | class MyImitator(nn.Module):
249 | def __init__(self):
250 | super(MyImitator, self).__init__()
251 |
252 |         # 1. Load the config file
253 | with open("./checkpoint/myimitator-512.json", "r", encoding='utf-8') as reader:
254 | text = reader.read()
255 | self.conf = BigGANConfig()
256 | for key, value in json.loads(text).items():
257 | self.conf.__dict__[key] = value
258 |
259 |         # Define the network structure
260 | # self.embeddings = nn.Linear(config.num_classes, config.continuous_params_size, bias=False)
261 |
262 | ch = self.conf.channel_width
263 | condition_vector_dim = 223
264 |
265 | self.gen_z = snlinear(in_features=condition_vector_dim, out_features=4*4*16*ch, eps=self.conf.eps)
266 | layers = []
267 | for i, layer in enumerate(self.conf.layers):
268 |             if i == self.conf.attention_layer_position: # add self-attention at the specified layer
269 | layers.append(SelfAttn(ch * layer[1], eps=self.conf.eps))
270 | layers.append(GenBlock(ch * layer[1],
271 | ch * layer[2],
272 | condition_vector_dim,
273 | up_sample=layer[0],
274 | n_stats=self.conf.n_stats,
275 | eps=self.conf.eps))
276 | self.layers = nn.ModuleList(layers)
277 |
278 | self.bn = BigGANBatchNorm(ch, n_stats=self.conf.n_stats, eps=self.conf.eps, conditional=False)
279 | self.relu = nn.ReLU()
280 | self.conv_to_rgb = snconv2d(in_channels=ch, out_channels=ch, kernel_size=3, padding=1, eps=self.conf.eps)
281 | self.tanh = nn.Tanh()
282 |
283 | def forward(self, cond_vector, truncation=0.4):
284 | # cond_vector = cond_vector.unsqueeze(2).unsqueeze(3)
285 |         z = self.gen_z(cond_vector) # cond_vector [batch_size, config.continuous_params_size], z [batch_size, 4*4*16*self.conf.channel_width]
286 |
287 | # We use this conversion step to be able to use TF weights:
288 | # TF convention on shape is [batch, height, width, channels]
289 | # PT convention on shape is [batch, channels, height, width]
290 | z = z.view(-1, 4, 4, 16 * self.conf.channel_width) # [batch_size, 4, 4, 2048]
291 | z = z.permute(0, 3, 1, 2).contiguous() # [batch_size, 2048, 4, 4]
292 |
293 | for i, layer in enumerate(self.layers):
294 | if isinstance(layer, GenBlock):
295 | z = layer(z, cond_vector, truncation)
296 | else:
297 | z = layer(z)
298 |
299 | z = self.bn(z, truncation) # [1, 128, 512, 512]
300 | z = self.relu(z) # [1, 128, 512, 512]
301 | z = self.conv_to_rgb(z) # [1, 128, 512, 512]
302 | z = z[:, :3, ...] # [1, 3, 512, 512]
303 | z = self.tanh(z) # [1, 3, 512, 512]
304 | return z
305 |
306 | '''
307 | Config for the custom Imitator
308 | '''
309 | class BigGANConfig(object):
310 | """ Configuration class to store the configuration of a `BigGAN`.
311 | Defaults are for the 128x128 model.
312 | layers tuple are (up-sample in the layer ?, input channels, output channels)
313 | """
314 | def __init__(self,
315 | output_dim=512,
316 | z_dim=512,
317 | class_embed_dim=512,
318 | channel_width=512,
319 | num_classes=1000,
320 |                  # (up-sample?, input_channels, output_channels)
321 | layers=[(False, 16, 16),
322 | (True, 16, 16),
323 | (False, 16, 16),
324 | (True, 16, 8),
325 | (False, 8, 8),
326 | (True, 8, 4),
327 | (False, 4, 4),
328 | (True, 4, 2),
329 | (False, 2, 2),
330 | (True, 2, 1)],
331 | attention_layer_position=8,
332 | eps=1e-4,
333 | n_stats=51):
334 | """Constructs BigGANConfig. """
335 | self.output_dim = output_dim
336 | self.z_dim = z_dim
337 | self.class_embed_dim = class_embed_dim
338 | self.channel_width = channel_width
339 | self.num_classes = num_classes
340 | self.layers = layers
341 | self.attention_layer_position = attention_layer_position
342 | self.eps = eps
343 | self.n_stats = n_stats
344 |
345 | @classmethod
346 | def from_dict(cls, json_object):
347 | """Constructs a `BigGANConfig` from a Python dictionary of parameters."""
348 | config = BigGANConfig()
349 | for key, value in json_object.items():
350 | config.__dict__[key] = value
351 | return config
352 |
353 | @classmethod
354 | def from_json_file(cls, json_file):
355 | """Constructs a `BigGANConfig` from a json file of parameters."""
356 | with open(json_file, "r", encoding='utf-8') as reader:
357 | text = reader.read()
358 | return cls.from_dict(json.loads(text))
359 |
360 | def __repr__(self):
361 | return str(self.to_json_string())
362 |
363 | def to_dict(self):
364 | """Serializes this instance to a Python dictionary."""
365 | output = copy.deepcopy(self.__dict__)
366 | return output
367 |
368 | def to_json_string(self):
369 | """Serializes this instance to a JSON string."""
370 | return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"
371 |
372 |
373 | imitator = MyImitator()
374 | if device.type == 'cuda':
375 | imitator = nn.DataParallel(imitator)
376 | imitator.to(device)
377 |
378 | # Loss function: L1 loss between the generated and target images
379 | criterion = nn.L1Loss()
380 |
381 | # optimizer = optim.SGD(imitator.parameters(), lr=lr, momentum=0.9)
382 | optimizer = optim.Adam(params=imitator.parameters(), lr=5e-5,
383 | betas=(0.0, 0.999), weight_decay=0,
384 | eps=1e-8)
385 |
386 | # decay the learning rate by 10% every 50 epochs
387 | # scheduler = lr_scheduler.StepLR(optimizer, step_size=len(train_dataloader) * 50, gamma=0.9)
388 |
389 | total_step = len(train_dataloader)
390 | imitator.train()
391 | train_loss_list = []
392 | val_loss_list = []
393 | for epoch in range(num_epochs):
394 | start = time.time()
395 | for i, (params, img) in enumerate(train_dataloader):
396 | optimizer.zero_grad()
397 | params = params.to(device)
398 | img = img.to(device)
399 | outputs = imitator(params)
400 | loss = criterion(outputs, img)
401 | loss.backward()
402 | optimizer.step()
403 |
404 | if (i % 10) == 0:
405 | print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}, spend time: {:.4f}'
406 | .format(epoch + 1, num_epochs, i + 1, total_step, loss.item(), time.time() - start))
407 | start = time.time()
408 |
409 | train_loss_list.append(loss.item())
410 | imitator.eval()
411 | with torch.no_grad():
412 | val_loss = 0
413 | for i, (params, img) in enumerate(val_dataloader):
414 | params = params.to(device)
415 | img = img.to(device)
416 | outputs = imitator(params)
417 | loss = criterion(outputs, img)
418 | val_loss += loss.item()
419 | if i == 1:
420 | vutils.save_image(
421 | vutils.make_grid(outputs.to(device)[:16], nrow=4, padding=2, normalize=True).cpu(),
422 | image_root + "gen_image/%d.jpg" % epoch)
423 | val_loss_list.append(val_loss / len(val_dataloader))
424 |
425 | print('Epoch [{}/{}], val_loss: {:.6f}'
426 | .format(epoch + 1, num_epochs, val_loss))
427 | if (epoch % 10) == 0 or (epoch+1) == num_epochs:
428 | torch.save(imitator.state_dict(),
429 | image_root + 'model/epoch_{}_val_loss_{:.6f}_file.pt'.format(
430 | epoch, val_loss))
431 | if epoch >= 1:
432 | plt.figure()
433 | plt.subplot(121)
434 | plt.plot(np.arange(0, len(train_loss_list)), train_loss_list)
435 | plt.subplot(122)
436 | plt.plot(np.arange(0, len(val_loss_list)), val_loss_list)
437 | plt.savefig(image_root + "metrics.jpg")
438 | plt.close("all")
439 |
440 | imitator.train()
441 |
--------------------------------------------------------------------------------
/train_translator.py:
--------------------------------------------------------------------------------
1 | import os
2 | import cv2
3 | import time
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 | import torch.optim as optim
8 | import torchvision.utils as vutils
9 | import torchvision.transforms as T
10 | from torch.utils.data import DataLoader
11 | import matplotlib.pyplot as plt
12 | from tqdm import tqdm
13 | import numpy as np
14 |
15 | import config
16 | from imitator import Imitator
17 | from dataset import Translator_Dataset, split_dataset
18 | from lightcnn import LightCNN_29Layers_v2
19 | from face_parser import load_model
20 | from translator import Translator
21 |
22 | root_path = "./data/"
23 | r1 = 0.01
24 | r2 = 1
25 | r3 = 1
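    | # Loss weights: r1 scales the LightCNN identity loss (1 - cosine similarity), r2 the face-parsing L1 loss,
    | # and r3 the parameter loop-back L1 loss, apparently following the loss design of the
    | # "Fast and Robust Face-to-Parameter Translation" paper included in ./papers.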
26 |
27 | def criterion_lightcnn(x1, x2):
28 | distance = torch.cosine_similarity(x1, x2)[0]
29 | return 1 - distance
30 |
31 | if __name__ == '__main__':
32 |     # 1. Initialise and load the imitator
33 | imitator = Imitator()
34 | if len(config.imitator_model) > 0:
35 | if config.use_gpu:
36 | imitator_model = torch.load(config.imitator_model)
37 | else:
38 | imitator_model = torch.load(config.imitator_model, map_location=torch.device('cpu'))
    |         imitator.load_state_dict(imitator_model) # apply the loaded weights (as random_gen_image.py does)
39 |         print("load Imitator pretrained model success!")
40 | else:
41 | print("No pretrained model...")
42 |
43 | imitator = imitator.to(config.device)
44 | for param in imitator.parameters():
45 | param.requires_grad = False
46 |
47 |     # 2. Load LightCNN
48 | lightcnn = LightCNN_29Layers_v2(num_classes=80013)
49 | lightcnn = lightcnn.to(config.device)
50 | lightcnn.eval()
51 | if config.use_gpu:
52 | checkpoint = torch.load(config.lightcnn_checkpoint)
53 | model = torch.nn.DataParallel(lightcnn).cuda()
54 | model.load_state_dict(checkpoint['state_dict'])
55 | else:
56 | checkpoint = torch.load(config.lightcnn_checkpoint, map_location="cpu")
57 | new_state_dict = lightcnn.state_dict()
58 | for k, v in checkpoint['state_dict'].items():
59 | _name = k[7:] # remove `module.`
60 | new_state_dict[_name] = v
61 | lightcnn.load_state_dict(new_state_dict)
62 |
63 |     # Freeze LightCNN
64 | for param in lightcnn.parameters():
65 | param.requires_grad = False
66 |
67 |     # 3. The Translator (T) network
68 | translator = Translator(isBias=False)
69 | if config.device.type == 'cuda':
70 | translator = nn.DataParallel(translator)
71 | translator.to(config.device)
72 |
73 |     # 4. Load the face parser
74 | transform = T.Compose([
75 | T.Normalize(mean=[0.485, 0.456, 0.406],
76 | std=[0.229, 0.224, 0.225]),
77 |     ]) # Normalize is required before semantic segmentation
78 | deeplab = load_model('mobilenetv2', num_classes=config.num_classes, output_stride=config.output_stride)
79 | checkpoint = torch.load(config.faceparse_checkpoint, map_location=config.device)
80 | if config.faceparse_backbone == 'resnet50':
81 | deeplab.load_state_dict(checkpoint)
82 | else:
83 | deeplab.load_state_dict(checkpoint["model_state"])
84 | deeplab.to(config.device)
85 | deeplab.eval()
86 |
87 | for param in deeplab.parameters():
88 | param.requires_grad = False
89 |
90 | trainlist, vallist = split_dataset(root_path)
91 |
92 |     # Losses
93 | criterion_param = nn.L1Loss()
94 | criterion_parser = nn.L1Loss()
95 |
96 | train_dataset = Translator_Dataset(trainlist)
97 | val_dataset = Translator_Dataset(vallist)
98 | train_dataloader = DataLoader(train_dataset, batch_size=config.batch_size, shuffle=True)
99 | val_dataloader = DataLoader(val_dataset, batch_size=config.batch_size, shuffle=True)
100 |
101 | optimizer = torch.optim.Adam(translator.parameters(), lr=1e-4)
102 |
103 | total_step = len(train_dataloader)
104 | translator.train()
105 | train_loss_list = []
106 | val_loss_list = []
107 | for epoch in range(config.total_epochs):
108 | start = time.time()
109 | for i, imgs in enumerate(train_dataloader):
110 | optimizer.zero_grad()
111 | # features = features.to(config.device)
112 | # params = params.to(config.device)
113 | imgs = imgs.to(config.device)
114 |
115 |             # run semantic segmentation first
116 | parse = deeplab(transform(imgs))
117 |
118 | imgs = T.Grayscale()(F.interpolate(imgs, (128, 128), mode='bilinear'))
119 | _, features = lightcnn(imgs)
120 | outputs = translator(features) # 223
121 |
122 |             gen_img = imitator(outputs) # generated image
123 | gen_parse = deeplab(transform(gen_img))
124 | gen_img = T.Grayscale()(F.interpolate(gen_img, (128, 128), mode='bilinear'))
125 |
126 | _, gen_features = lightcnn(gen_img)
127 | gen_outputs = translator(gen_features)
128 |
129 | loss1 = criterion_lightcnn(gen_features, features)
130 | loss2 = criterion_parser(gen_parse, parse)
131 | loss3 = criterion_param(outputs, gen_outputs)
132 |
133 | loss = r1 * loss1 + r2 * loss2 + r3 * loss3
134 |
135 | loss.backward()
136 | optimizer.step()
137 |
138 | if (i % 10) == 0:
139 | print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}, spend time: {:.4f}'
140 | .format(epoch + 1, config.total_epochs, i + 1, total_step, loss.item(), time.time() - start))
141 | start = time.time()
142 |
143 | train_loss_list.append(loss.item())
144 |
145 | translator.eval()
146 | with torch.no_grad():
147 | val_loss = 0
148 | for i, imgs in enumerate(val_dataloader):
149 | # features = features.to(config.device)
150 | # params = params.to(config.device)
151 |
152 | imgs = imgs.to(config.device)
153 |
154 | imgs = T.Grayscale()(F.interpolate(imgs, (128, 128), mode='bilinear'))
155 | _, features = lightcnn(imgs)
156 | outputs = translator(features) # 223
157 |
158 |                 gen_img = imitator(outputs) # generated image
159 | gen_img = T.Grayscale()(F.interpolate(gen_img, (128, 128), mode='bilinear'))
160 |
161 | _, gen_features = lightcnn(gen_img)
162 | gen_outputs = translator(gen_features)
163 |
164 | loss1 = criterion_lightcnn(gen_features, features)
165 | loss3 = criterion_param(outputs, gen_outputs)
166 |
167 | loss = r1 * loss1 + r3 * loss3
168 |
169 | val_loss += loss.item()
170 | val_loss_list.append(val_loss / len(val_dataloader))
171 |
172 | print('Epoch [{}/{}], val_loss: {:.6f}'
173 | .format(epoch + 1, config.total_epochs, val_loss))
174 | if (epoch % 1) == 0 or (epoch + 1) == config.total_epochs:
175 | torch.save(translator.state_dict(), './checkpoint/translator_{}_{:.6f}.pt'.format(epoch, val_loss))
176 | if epoch >= 1:
177 | plt.figure()
178 | plt.subplot(121)
179 | plt.plot(np.arange(0, len(train_loss_list)), train_loss_list)
180 | plt.subplot(122)
181 | plt.plot(np.arange(0, len(val_loss_list)), val_loss_list)
182 | plt.savefig(root_path + "metrics.jpg")
183 | plt.close("all")
184 |
185 | translator.train()
186 |
--------------------------------------------------------------------------------
/translator.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | import config
5 |
6 | '''
7 | Translator network: maps face-recognition features to facial parameters.
8 | The architecture follows the second paper: "Fast and Robust Face-to-Parameter Translation for Game Character Auto-Creation".
9 | Key points:
10 | 1. Use the Adam optimizer to train T with a learning rate of 1e-4 and a maximum of 20 epochs.
11 | 2. The learning rate decay is set to 10% per 50 epochs (same schedule as the imitator; usually unnecessary, since training can stop around epoch 40).
12 | '''
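    | # Shape flow (descriptive note): LightCNN feature [N, 256] -> fc1 -> [N, 512] -> three residual
    | # attention blocks -> fc2 + BatchNorm1d -> [N, config.continuous_params_size] continuous face
    | # parameters (223 in this repo, cf. the "# 223" notes in train_translator.py).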
13 | class Translator(nn.Module):
14 | def __init__(self, isBias=False):
15 | super(Translator, self).__init__()
16 | self.fc1 = nn.Linear(in_features=256, out_features=512, bias=isBias)
17 | self.resatt1 = ResAttention(isBias=isBias)
18 | self.resatt2 = ResAttention(isBias=isBias)
19 | self.resatt3 = ResAttention(isBias=isBias)
20 | self.fc2 = nn.Linear(in_features=512, out_features=config.continuous_params_size, bias=isBias)
21 | self.bn = nn.BatchNorm1d(config.continuous_params_size)
22 |
23 | def forward(self, x):
24 | y = self.fc1(x)
25 | y_ = y + self.resatt1(y)
26 | y_ = y_ + self.resatt2(y_)
27 | y_ = y_ + self.resatt3(y_)
28 | return self.bn(self.fc2(y_))
29 |
30 |
31 | '''
32 | Residual block with a channel-attention branch
33 | '''
34 | class ResAttention(nn.Module):
35 | def __init__(self, isBias=False):
36 | super(ResAttention, self).__init__()
37 | self.fc1 = nn.Linear(512, 1024, bias=isBias)
38 | self.bn1 = nn.BatchNorm1d(1024)
39 | self.relu1 = nn.ReLU()
40 |
41 | self.fc2 = nn.Linear(1024, 512, bias=isBias)
42 | self.bn2 = nn.BatchNorm1d(512)
43 |
44 |         # the attention branch starts here
45 | self.fc3 = nn.Linear(512, 16, bias=isBias)
46 | self.relu3 = nn.ReLU()
47 |
48 | self.fc4 = nn.Linear(16, 512, bias=isBias)
49 | self.sigmoid4 = nn.Sigmoid()
50 |
51 |         # the element-wise (channel) multiplication happens after the branch
52 |
53 | self.relu5 = nn.ReLU()
54 |
55 | def forward(self, x):
56 | y = self.fc1(x)
57 | y = self.bn1(y)
58 | y = self.relu1(y)
59 |
60 | y = self.fc2(y)
61 | y = self.bn2(y)
62 |
63 |         # attention branch
64 |
65 | y_ = self.fc3(y)
66 | y_ = self.relu3(y_)
67 |
68 | y_ = self.fc4(y_)
69 | y_ = self.sigmoid4(y_)
70 |
71 |         # element-wise multiplication (channel re-weighting)
72 | y_ = torch.mul(y, y_)
73 | return self.relu5(y_)
74 |
75 | if __name__ == '__main__':
76 | trans = Translator()
77 |     x = torch.randn([2, 256], dtype=torch.float32)  # face-recognition features
78 | print(x.shape)
79 | y = trans(x)
80 | print(y.shape)
--------------------------------------------------------------------------------
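The module docstring above fixes the training recipe for T: Adam with a learning rate of 1e-4, at most 20 epochs, and a 10% learning-rate decay every 50 epochs. A minimal sketch of that setup follows; reading "decay of 10%" as multiplying the learning rate by 0.9 is an assumption, and the values actually used live in config.py / train_translator.py.

import torch
import config
from translator import Translator

translator = Translator().to(config.device)
optimizer = torch.optim.Adam(translator.parameters(), lr=1e-4)
# StepLR drops the learning rate by 10% (gamma=0.9) every 50 epochs, per the docstring
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=50, gamma=0.9)

for epoch in range(20):      # "max-iteration of 20 epochs"
    # ... one pass over the training set, as in train_translator.py ...
    scheduler.step()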
/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import cv2
3 | import struct
4 | import numpy as np
5 | from PIL import Image
6 | import matplotlib.pyplot as plt
7 | import torch
8 | import torch.nn as nn
9 | import torchvision.transforms as transforms
10 |
11 | import config
12 | from faceparse import BiSeNet
13 |
14 | # transposed convolution (deconvolution) block
15 | def deconv_layer(in_chanel, out_chanel, kernel_size, stride=1, pad=0):
16 | return nn.Sequential(
17 | nn.ConvTranspose2d(in_chanel, out_chanel, kernel_size=kernel_size, stride=stride, padding=pad),
18 | nn.BatchNorm2d(out_chanel),
19 | nn.ReLU())
20 |
21 | # custom exception
22 | class NeuralException(Exception):
23 | def __init__(self, message):
24 | print("neural error: " + message)
25 | self.message = "neural exception: " + message
26 |
27 | '''
28 | Extract the edges of the original image
29 | :param img: input image
30 | :return: edge image
31 | '''
32 | def img_edge(img):
33 | gray = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
34 | x_grad = cv2.Sobel(gray, cv2.CV_16SC1, 1, 0)
35 | y_grad = cv2.Sobel(gray, cv2.CV_16SC1, 0, 1)
36 | return cv2.Canny(x_grad, y_grad, 40, 130)
37 |
38 | '''
39 | Convert a tensor to numpy arrays for use with cv2
40 | :param tensor: [batch, c, w, h]
41 | :return: [batch, h, w, c]
42 | '''
43 | def tensor_2_image(tensor):
44 |
45 | batch = tensor.size(0)
46 | images = []
47 | for i in range(batch):
48 | img = tensor[i].cpu().detach().numpy()
49 | img = np.swapaxes(img, 0, 2) # [h, w, c]
50 | img = np.swapaxes(img, 0, 1) # [w, h, c]
51 | images.append(img * 255)
52 | return images
53 |
54 | '''
55 | [W, H, 1] -> [W, H, 3] or [W, H]->[W, H, 3]
56 | :param image: input image
57 | :return: converted image
58 | '''
59 | def fill_gray(image):
60 | shape = image.shape
61 | if len(shape) == 2:
62 | image = image[:, :, np.newaxis]
63 | shape = image.shape
64 | if shape[2] == 1:
65 | return np.pad(image, ((0, 0), (0, 0), (1, 1)), 'edge')
66 | elif shape[2] == 3:
67 | return np.mean(image, axis=2)
68 | return image
69 |
70 | '''
71 | imitator snapshot: save the input photo, generated image, parsing map and edge map as one image
72 | :param path: save path
73 | :param tensor1: input photo
74 | :param tensor2: generated image
75 | :param parse: parse checkpoint's path
76 | '''
77 | def capture(path, tensor1, tensor2, parse, cuda):
78 | img1 = tensor_2_image(tensor1)[0].swapaxes(0, 1).astype(np.uint8)
79 | img2 = tensor_2_image(tensor2)[0].swapaxes(0, 1).astype(np.uint8)
80 | img1 = cv2.resize(img1, (512, 512), interpolation=cv2.INTER_LINEAR)
81 | img3 = faceparsing_ndarray(img1, parse, cuda)
82 | img4 = img_edge(img3)
83 | img4 = 255 - fill_gray(img4)
84 | image = merge_4image(img1, img2, img3, img4, transpose=False)
85 | cv2.imwrite(path, image)
86 |
87 | def merge_4image(image1, image2, image3, image4, size=512, show=False, transpose=True):
88 | """
89 |     stitch the four images together
90 | :param image1: input image1, numpy array
91 | :param image2: input image2, numpy array
92 | :param image3: input image3, numpy array
93 | :param image4: input image4, numpy array
94 |     :param size: output resolution
95 |     :param show: display the result in a window
96 |     :param transpose: swap height and width (cv2 order is [H, W, C])
97 | :return: merged image
98 | """
99 | size_ = (int(size / 2), int(size / 2))
100 | img_1 = cv2.resize(image1, size_)
101 | # cv2.imshow("img1", img_1)
102 | # cv2.waitKey()
103 |
104 | img_2 = cv2.resize(image2, size_)
105 | # cv2.imshow("img2", img_2)
106 | # cv2.waitKey()
107 |
108 | img_3 = cv2.resize(image3, size_)
109 | # cv2.imshow("img3", img_3)
110 | # cv2.waitKey()
111 |
112 | img_4 = cv2.resize(image4, size_)
113 | # cv2.imshow("img4", img_4)
114 | # cv2.waitKey()
115 |
116 |     image1_ = np.append(img_1, img_2, axis=1)  # axis=1: concatenate side by side
117 | image2_ = np.append(img_3, img_4, axis=1)
118 | image = np.append(image1_, image2_, axis=0)
119 | if transpose:
120 | image = image.swapaxes(0, 1)
121 | if show:
122 | cv2.imshow("contact", image)
123 | cv2.waitKey()
124 | cv2.destroyAllWindows()
125 | return image
126 |
127 | '''
128 | evaluate with numpy array
129 | :param input: numpy array, note: must be np.uint8 (not np.float32), shape [H, W, C]
130 | :param cp: args.parsing_checkpoint, str
131 | :param cuda: use gpu to speedup
132 | '''
133 | def faceparsing_ndarray(input, checkpoint, cuda=False):
134 |     # build BiSeNet and load the checkpoint
135 | bsnet = BiSeNet(n_classes=19)
136 | if cuda:
137 | bsnet.cuda()
138 | bsnet.load_state_dict(torch.load(checkpoint))
139 | else:
140 | bsnet.load_state_dict(torch.load(checkpoint, map_location="cpu"))
141 | bsnet.eval()
142 |
143 | to_tensor = transforms.Compose(
144 | [
145 | transforms.ToTensor(), # [H, W, C]->[C, H, W]
146 | # transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
147 | ])
148 |
149 | # input_ = _to_tensor_(input)
150 | input_ = to_tensor(input)
151 | input_ = torch.unsqueeze(input_, 0)
152 |
153 | if cuda:
154 | input_ = input_.cuda()
155 | out = bsnet(input_)
156 | parsing = out.squeeze(0).cpu().detach().numpy().argmax(0)
157 | return vis_parsing_maps(input, parsing, stride=1)
158 |
159 | '''
160 | Visualize the parsing result
161 | '''
162 | def vis_parsing_maps(im, parsing, stride):
163 | """
164 |     # show all facial parts
165 | part_colors = [[255, 0, 0], [255, 85, 0], [255, 170, 0], [255, 0, 85], [255, 0, 170], [0, 255, 0], [85, 255, 0],
166 | [170, 255, 0], [0, 255, 85], [0, 255, 170], [0, 0, 255], [85, 0, 255], [170, 0, 255], [0, 85, 255],
167 | [0, 170, 255], [255, 255, 0], [255, 255, 85], [255, 255, 170], [255, 0, 255], [255, 85, 255],
168 | [255, 170, 255], [0, 255, 255], [85, 255, 255], [170, 255, 255]]
169 | """
170 |     # only show face, nose, eyes, eyebrows and mouth
171 | part_colors = [[255, 255, 255], [255, 85, 0], [25, 170, 0], [255, 170, 0], [254, 0, 170], [254, 0, 170],
172 | [255, 255, 255],
173 | [255, 255, 255], [255, 255, 255], [255, 255, 255], [0, 0, 254], [85, 0, 255], [170, 0, 255],
174 | [0, 85, 255],
175 | [255, 255, 255], [255, 255, 255], [255, 255, 255], [255, 255, 255], [255, 255, 255], [255, 255, 255],
176 | [255, 255, 255], [255, 255, 255], [255, 255, 255], [255, 255, 255]]
177 | """
178 |     part_colors = [[255, 255, 255], [face], [left eyebrow], [right eyebrow], [left eye], [right eye],
179 | [255, 255, 255],
180 |                    [left ear], [right ear], [255, 255, 255], [nose], [teeth], [upper lip],
181 |                    [lower lip],
182 | [255, 255, 255], [255, 255, 255], [255, 255, 255], [255, 255, 255], [255, 255, 255], [255, 255, 255],
183 | [255, 255, 255], [255, 255, 255], [255, 255, 255], [255, 255, 255]]
184 | """
185 |
186 | im = np.array(im)
187 | vis_parsing = parsing.copy().astype(np.uint8)
188 | vis_parsing_anno_color = np.zeros((vis_parsing.shape[0], vis_parsing.shape[1], 3)) + 255
189 | num_of_class = np.max(vis_parsing)
190 | for pi in range(1, num_of_class + 1):
191 | index = np.where(vis_parsing == pi)
192 | vis_parsing_anno_color[index[0], index[1], :] = part_colors[pi]
193 |
194 | vis_parsing_anno_color = vis_parsing_anno_color.astype(np.uint8)
195 | return vis_parsing_anno_color
196 |
197 | '''
198 | Discriminative loss from the paper: judges whether the real photo and the image generated by the imitator belong to the same identity.
199 | The discriminative loss uses cosine distance.
200 | https://www.cnblogs.com/dsgcBlogs/p/8619566.html
201 | :param lightcnn_inst: lightcnn model instance
202 | :param img1: generated by engine, type: list of Tensor
203 | :param img2: generated by imitator, type: list of Tensor
204 | :return tensor scalar, cosine similarity
205 | '''
206 | def discriminative_loss(img1, img2, lightcnn_inst):
207 | x1 = batch_feature256(img1, lightcnn_inst) # [1, 256]
208 | x2 = batch_feature256(img2, lightcnn_inst) # [1, 256]
209 | distance = torch.cosine_similarity(x1, x2)
210 | return torch.mean(distance)
211 |
212 | def batch_feature256(img, lightcnn_inst):
213 | """
214 |     Extract the 256-dimensional feature vector with LightCNN
215 |     :param lightcnn_inst: lightcnn model instance
216 |     :param img: input image tensor, shape (batch, 1, 512, 512)
217 |     :return: 256-dim feature tensor [batch, 256]
218 | """
219 | _, features = lightcnn_inst(img)
220 | # log.debug("features shape:{0} {1} {2}".format(features.size(), features.requires_grad, img.requires_grad))
221 | return features
222 |
223 | '''
224 | capture for result
225 | :param x: face parameters with grad, torch tensor [b, params]
226 | :param refer: reference picture: [3, 512, 512]
227 | :param step: train step
228 | '''
229 | def eval_output(imitator, x, refer, step, prev_path, L2_c):
230 | eval_write(x)
231 | y_ = imitator(x)
232 | y_ = y_.cpu().detach().numpy()
233 | y_ = np.squeeze(y_, axis=0)
234 | y_ = np.swapaxes(y_, 0, 2) * 255
235 | y_ = y_.astype(np.uint8) # [512, 512, 3]
236 | im1 = L2_c[0] # [512, 512]
237 | im2 = L2_c[1] # [512, 512]
238 | # np_im1 = im1.cpu().detach().numpy()
239 | # np_im2 = im2.cpu().detach().numpy()
240 |     # f_im1 = fill_gray(np_im1) # [512, 512, 3], grayscale image
241 | # f_im2 = fill_gray(np_im2) # [512, 512, 3]
242 |     f_im1 = im1 # [512, 512, 3], show the original image directly here
243 | f_im2 = im2 # [512, 512, 3]
244 |
245 |     # convert refer to channel-last layout
246 | refer = np.transpose(refer, [1, 2, 0]) # [512, 512, 3]
247 | # print("f_im1:", type(f_im1), f_im1.shape)
248 | image_ = merge_4image(refer, y_, f_im1, f_im2, transpose=False)
249 | path = os.path.join(prev_path, "eval_{0}.jpg".format(step))
250 | cv2.imwrite(path, image_)
251 |
252 | '''
253 | generate a binary file that can be restored in Unity
254 | :param params: face parameters, tensor [batch, params_cnt]
255 | '''
256 | def eval_write(params):
257 | np_param = params.cpu().detach().numpy()
258 | np_param = np_param[0]
259 | list_param = np_param.tolist()
260 | dataset = config.train_set
261 | shape = curr_roleshape(dataset)
262 | path = os.path.join(config.model_path, "eval.bytes")
263 | f = open(path, 'wb')
264 | write_layer(f, shape, list_param)
265 | f.close()
266 |
267 | '''
268 | determine which role shape (c# RoleShape) is currently being used
269 | :param dataset: args path_to_dataset
270 | :return: RoleShape
271 | '''
272 | def curr_roleshape(dataset):
273 | if dataset.find("female") >= 0:
274 | return 4
275 | else:
276 | return 3
277 |
278 | def write_layer(f, shape, args):
279 | f.write(struct.pack('i', shape))
280 | for it in args:
281 | byte = struct.pack('f', it)
282 | f.write(byte)
283 |
284 | '''
285 | Apply argmax to a one-hot encoded segment of the parameters
286 | :param params: parameters to process
287 | :param start: start offset of the one-hot segment
288 | :param count: length of the one-hot encoding
289 | '''
290 | def argmax_params(params, start, count):
291 | dims = params.size()[0]
292 | for dim in range(dims):
293 | tmp = params[dim, start]
294 | mx = start
295 | for idx in range(start + 1, start + count):
296 | if params[dim, idx] > tmp:
297 | mx = idx
298 | tmp = params[dim, idx]
299 | for idx in range(start, start + count):
300 | params[dim, idx] = 1. if idx == mx else 0
301 |
302 | # plot loss
303 | def eval_plot(losses):
304 | count = len(losses)
305 | if count > 0:
306 | plt.style.use('seaborn-whitegrid')
307 | x = range(count)
308 | y1 = []
309 | y2 = []
310 | for it in losses:
311 | y1.append(it[0])
312 | y2.append(it[1])
313 | plt.plot(x, y1, color='r', label='1-L1')
314 | plt.plot(x, y2, color='g', label='L2')
315 | plt.ylabel("loss")
316 | plt.xlabel('step')
317 | plt.legend()
318 | path = os.path.join(config.prev_path, "loss.png")
319 | plt.savefig(path)
320 | plt.close('all')
321 |
322 | '''
323 | detect 68 facial keypoints with dlib
324 | :param img: BGR 3-channel image
325 | '''
326 | def detect_face_keypoint(img):
327 |     # convert to grayscale
328 | img_gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
329 | rects = config.detector(img_gray, 0)
330 | for i in range(len(rects)):
331 | landmarks = np.matrix([[p.x, p.y] for p in config.predictor(img, rects[i]).parts()])
332 | for idx, point in enumerate(landmarks):
333 | pos = (point[0, 0], point[0, 1])
334 | print(idx, pos)
335 |
336 | cv2.circle(img, pos, 5, color=(0, 255, 0))
337 |             # use cv2.putText to label the landmarks 1-68
338 | font = cv2.FONT_HERSHEY_SIMPLEX
339 | cv2.putText(img, str(idx + 1), pos, font, 0.8, (0, 0, 255), 1, cv2.LINE_AA)
340 | cv2.imshow("img", img)
341 | cv2.waitKey(0)
342 |
343 | '''
344 | Custom ArgMax that works around argmax() being non-differentiable
345 | Principle: argmax = softmax + c, where c is a constant (the gradient is passed through unchanged)
346 | '''
347 | class ArgMax(torch.autograd.Function):
348 | @staticmethod
349 |     def forward(ctx, input):
350 |         idx = torch.argmax(input, dim=1, keepdim=True)
351 | output = torch.zeros_like(input)
352 |         output.scatter_(1, idx, 1.)
353 | return output
354 |
355 | @staticmethod
356 |     def backward(ctx, grad_output):
357 | return grad_output
358 |
--------------------------------------------------------------------------------
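A short usage note on the ArgMax function defined at the end of utils.py: the forward pass one-hot encodes the row-wise argmax, while the backward pass simply returns the incoming gradient (a straight-through style estimator). A hedged sketch of how it would be called:

import torch
from utils import ArgMax

logits = torch.randn(2, 4, requires_grad=True)
one_hot = ArgMax.apply(logits)   # [2, 4], exactly one 1 per row; argmax itself has no gradient
one_hot.sum().backward()         # backward() is the identity, so gradients reach logits
print(one_hot)
print(logits.grad)               # all ones here, since d(sum)/d(one_hot) == 1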