├── .gitignore
├── assets
│   ├── 11.png
│   ├── 22.png
│   └── show.png
├── model
│   ├── AlignModule
│   │   ├── lib
│   │   │   ├── __init__.py
│   │   │   ├── ExpEncoder.py
│   │   │   ├── PoseEncoder.py
│   │   │   ├── PorEncoder.py
│   │   │   ├── IDEncoder.py
│   │   │   ├── generator.py
│   │   │   └── blocks.py
│   │   ├── config.py
│   │   ├── loss.py
│   │   ├── discriminator.py
│   │   └── module.py
│   ├── BlendModule
│   │   ├── config.py
│   │   ├── generator.py
│   │   └── module.py
│   └── third
│       ├── resnet.py
│       └── model.py
├── process
│   ├── download_weight.sh
│   ├── filter_idfiles.py
│   ├── download_and_process.py
│   ├── process_raw_video.py
│   └── process_utils.py
├── dataloader
│   ├── DataLoader.py
│   ├── BlendLoader.py
│   ├── AlignLoader.py
│   └── augmentation.py
├── LICENSE
├── utils
│   ├── visualizer.py
│   └── utils.py
├── train.py
├── inference.py
├── ReadMe.md
└── trainer
    ├── ModelTrainer.py
    ├── BlendTrainer.py
    └── AlignTrainer.py
/.gitignore:
--------------------------------------------------------------------------------
1 | pretrained_model*
2 | *.pth
3 | checkpoint*
4 | *__pycache__*
5 | dataset
6 |
--------------------------------------------------------------------------------
/assets/11.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LeslieZhoa/HeSer.Pytorch/HEAD/assets/11.png
--------------------------------------------------------------------------------
/assets/22.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LeslieZhoa/HeSer.Pytorch/HEAD/assets/22.png
--------------------------------------------------------------------------------
/assets/show.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LeslieZhoa/HeSer.Pytorch/HEAD/assets/show.png
--------------------------------------------------------------------------------
/model/AlignModule/lib/__init__.py:
--------------------------------------------------------------------------------
1 | from model.AlignModule.lib.ExpEncoder import ExpEncoder
2 | from model.AlignModule.lib.IDEncoder import IDEncoder
3 | from model.AlignModule.lib.PorEncoder import PorEncoder
4 | from model.AlignModule.lib.PoseEncoder import PoseEncoder
5 | from model.AlignModule.lib.generator import Generator
--------------------------------------------------------------------------------
/model/AlignModule/lib/ExpEncoder.py:
--------------------------------------------------------------------------------
1 | from torch import nn
2 | import torchvision
3 |
4 | class ExpEncoder(nn.Module):
5 | def __init__(self,args):
6 | super().__init__()
7 | self.exp_encoder = torchvision.models.mobilenet_v2(num_classes=args.exp_embedding_size)
8 |
9 | def forward(self,x):
10 |
11 | return self.exp_encoder(x)
--------------------------------------------------------------------------------
/model/AlignModule/lib/PoseEncoder.py:
--------------------------------------------------------------------------------
1 | from torch import nn
2 | import torchvision
3 |
4 | class PoseEncoder(nn.Module):
5 | def __init__(self,args):
6 | super().__init__()
7 | self.pose_encoder = torchvision.models.mobilenet_v2(num_classes=args.pose_embedding_size)
8 |
9 | def forward(self,x):
10 |
11 | return self.pose_encoder(x)
--------------------------------------------------------------------------------
/process/download_weight.sh:
--------------------------------------------------------------------------------
1 | wget https://github.com/LeslieZhoa/HeSer.Pytorch/releases/download/v0.0/parsing.pth -P ../pretrained_models
2 | wget https://github.com/LeslieZhoa/DCT-NET.Pytorch/releases/download/v0.0/model_ir_se50.pth -P ../pretrained_models
3 | wget https://github.com/LeslieZhoa/HeSer.Pytorch/releases/download/v0.0/vgg19-d01eb7cb.pth -P ../pretrained_models
4 | wget https://github.com/LeslieZhoa/HeSer.Pytorch/releases/download/v0.0/073-00000000.pth -P ../checkpoint/Blender
5 | wget https://github.com/LeslieZhoa/HeSer.Pytorch/releases/download/v0.0/417-00000000.pth -P ../checkpoint/Aligner
--------------------------------------------------------------------------------
/process/filter_idfiles.py:
--------------------------------------------------------------------------------
1 | import os
2 | from shutil import rmtree
3 |
4 | def rm_little_files(base,th=100):
5 | k = 1
6 | for idname in os.listdir(base):
7 | id_path = os.path.join(base,idname)
8 | for video_clip in os.listdir(id_path):
9 | path = os.path.join(id_path,video_clip)
10 | length = len(os.listdir(path))
11 | print('\rhave done %04d'%k,end='',flush=True)
12 | k += 1
13 | if length < th:
14 | rmtree(path)
15 | print()
16 |
17 | if __name__ == "__main__":
18 | base = '../dataset/process/img'
19 | rm_little_files(base)
--------------------------------------------------------------------------------
/model/AlignModule/lib/PorEncoder.py:
--------------------------------------------------------------------------------
1 | from torch import nn
2 | import torchvision
3 | class PorEncoder(nn.Module):
4 | def __init__(self,args):
5 | super().__init__()
6 | self.por_encoder = torchvision.models.resnext50_32x4d(num_classes=args.por_embedding_size)
7 |
8 | def forward(self,x):
9 |
10 | batch = x.shape[0]
11 | if len(x.shape) > 4:
12 | x = x.view(-1,*x.shape[2:])
13 | feat = self.por_encoder(x)
14 | if feat.shape[0] != batch:
15 | feat = feat.view(batch,feat.shape[0]//batch,-1)
16 | feat = feat.mean(1)
17 |
18 | return feat
19 |
--------------------------------------------------------------------------------
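PorEncoder wraps a ResNeXt-50 backbone and accepts either a single frame or a stack of source frames: a 5-D input is folded into the batch dimension and the per-frame embeddings are averaged. A minimal shape sketch, assuming the repository is on PYTHONPATH (the batch and frame counts below are illustrative):

import torch
from model.AlignModule.config import Params
from model.AlignModule.lib.PorEncoder import PorEncoder

args = Params()
enc = PorEncoder(args).eval()

single = torch.randn(2, 3, 256, 256)      # (B, C, H, W)
multi = torch.randn(2, 5, 3, 256, 256)    # (B, K, C, H, W): K source frames

with torch.no_grad():
    f_single = enc(single)   # (2, args.por_embedding_size)
    f_multi = enc(multi)     # per-frame embeddings averaged -> (2, args.por_embedding_size)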
/dataloader/DataLoader.py:
--------------------------------------------------------------------------------
1 | #! /usr/bin/python
2 | # -*- encoding: utf-8 -*-
3 | '''
4 | @author Leslie
5 | @date 20220812
6 | '''
7 |
8 |
9 | from torch.utils.data import Dataset
10 | import torch.distributed as dist
11 |
12 |
13 | class DatasetBase(Dataset):
14 | def __init__(self,slice_id=0,slice_count=1,use_dist=False,**kwargs):
15 |
16 | if use_dist:
17 | slice_id = dist.get_rank()
18 | slice_count = dist.get_world_size()
19 | self.id = slice_id
20 | self.count = slice_count
21 |
22 |
23 | def __getitem__(self,i):
24 | pass
25 |
26 |
27 |
28 |
29 | def __len__(self):
30 | return 1000
31 |
32 |
--------------------------------------------------------------------------------
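DatasetBase does no data loading itself; it only records a shard id and shard count (taken from torch.distributed when use_dist is true) so that subclasses can split their file lists across workers, as the commented-out slicing in AlignLoader hints. A hypothetical subclass sketch:

from dataloader.DataLoader import DatasetBase

class ShardedData(DatasetBase):
    # illustrative only: shard a file list with self.id / self.count
    def __init__(self, files, **kwargs):
        super().__init__(**kwargs)
        self.files = files[self.id::self.count]   # keep every count-th item, offset by this worker's id

    def __getitem__(self, i):
        return self.files[i % len(self.files)]

    def __len__(self):
        return len(self.files)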
/model/AlignModule/lib/IDEncoder.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from model.AlignModule.module import Backbone
4 |
5 | class IDEncoder(nn.Module):
6 | def __init__(self,model_path):
7 | super(IDEncoder, self).__init__()
8 | print('Loading ResNet ArcFace')
9 | self.facenet = Backbone(input_size=112, num_layers=50, drop_ratio=0.6, mode='ir_se')
10 | self.facenet.load_state_dict(torch.load(model_path))
11 | self.face_pool = torch.nn.AdaptiveAvgPool2d((112, 112))
12 | self.facenet.eval()
13 | for module in [self.facenet, self.face_pool]:
14 | for param in module.parameters():
15 | param.requires_grad = False
16 |
17 |
18 | def forward(self, x):
19 | batch = x.shape[0]
20 | if len(x.shape) > 4:
21 | x = x.view(-1,*x.shape[2:])
22 | feat = self.facenet(self.face_pool(x))
23 | if feat.shape[0] != batch:
24 | feat = feat.view(batch,feat.shape[0]//batch,-1)
25 | feat = feat.mean(1)
26 |
27 | return feat
28 |
29 |
30 |
31 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 LeslieZhao
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/model/BlendModule/config.py:
--------------------------------------------------------------------------------
1 | '''
2 | @author Leslie
3 | @date 20220823
4 | '''
5 | class Params:
6 | def __init__(self):
7 |
8 | self.name = 'Blender'
9 | self.pretrain_path = 'checkpoint/Blender/025-00000000.pth'
10 | self.size = 512
11 |
12 | self.train_root = 'dataset/process/img'
13 | self.val_root = 'dataset/process/img'
14 |
15 | self.f_in_channels = 512
16 | self.f_inter_channels = 256
17 | self.temperature = 0.001
18 | self.dilate_kernel = 17
19 | self.decoder_ic = 12
20 |
21 | # discriminator
22 | self.embed_channels = 512
23 | self.padding = 'zero'
24 | self.in_channels = 5
25 | self.out_channels = 3
26 | self.num_channels = 64
27 | self.max_num_channels = 512
28 | self.output_image_size = 512
29 | self.dis_num_blocks = 7
30 |
31 | self.per_model = 'pretrained_models/vgg19-d01eb7cb.pth'
32 |
33 | # loss
34 | self.rec_loss = True
35 | self.per_loss = True
36 | self.lambda_gan = 0.2
37 | self.lambda_rec = 1.0
38 | self.lambda_per = 0.0005
39 |
40 | self.g_lr = 1e-4
41 | self.d_lr = 4e-4
42 | self.beta1 = 0.9
43 | self.beta2 = 0.999
--------------------------------------------------------------------------------
/model/AlignModule/config.py:
--------------------------------------------------------------------------------
1 | class Params:
2 | def __init__(self):
3 |
4 | self.name = 'Aligner'
5 | self.pretrain_path = None
6 | # self.pretrain_path = None
7 | self.size = 512
8 |
9 | self.train_root = 'dataset/select-align/img'
10 | self.val_root = 'dataset/select-align/img'
11 | self.use_pixelwise_augs = True
12 | self.use_affine_scale = True
13 | self.use_affine_shift = True
14 | self.frame_num = 5
15 | self.skip_frame = 5
16 |
17 | self.identity_embedding_size = 512
18 | self.pose_embedding_size = 256
19 | self.por_embedding_size = 512
20 | self.exp_embedding_size = 256
21 | self.embed_channels = 512
22 | self.padding = 'zero'
23 | self.in_channels = 3
24 | self.out_channels = 3
25 | self.num_channels = 64
26 | self.max_num_channels = 512
27 | self.norm_layer = 'in'
28 | self.gen_constant_input_size = 4
29 | self.gen_num_residual_blocks = 2
30 | self.output_image_size = 512
31 |
32 | self.dis_num_blocks = 7
33 |
34 | self.id_model = 'pretrained_models/model_ir_se50.pth'
35 | self.per_model = 'pretrained_models/vgg19-d01eb7cb.pth'
36 |
37 | self.rec_loss = True
38 | self.id_loss = True
39 | self.per_loss = True
40 | self.lambda_gan = 2
41 | self.lambda_rec = 10
42 | self.lambda_id = 2
43 | self.lambda_per = 0.002
44 |
45 | self.g_lr = 1e-3
46 | self.d_lr = 4e-3
47 | self.beta1 = 0.9
48 | self.beta2 = 0.999
--------------------------------------------------------------------------------
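Both Params classes are plain attribute containers. At startup, train.py copies every attribute onto the argparse namespace via merge_args (defined in utils/utils.py later in this listing), so command-line flags and config fields end up on one object. A small sketch of that flow, assuming the repository and its dependencies are importable:

import argparse
from model.AlignModule.config import Params
from utils.utils import merge_args

parser = argparse.ArgumentParser()
parser.add_argument('--batch_size', default=16, type=int)
args = parser.parse_args([])            # empty argv just for this sketch

args = merge_args(args, Params())       # copies every Params attribute onto args
print(args.name, args.g_lr, args.batch_size)   # Aligner 0.001 16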
/utils/visualizer.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (C) 2019 NVIDIA Corporation. All rights reserved.
3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
4 | """
5 |
6 | import os
7 | import time
8 | import subprocess
9 |
10 | from tensorboardX import SummaryWriter
11 |
12 |
13 |
14 | class Visualizer:
15 | def __init__(self,opt,mode='train'):
16 | self.opt = opt
17 | self.name = opt.name
18 | self.mode = mode
19 | self.train_log_dir = os.path.join(opt.checkpoint_path,"logs/%s"%mode)
20 | self.log_name = os.path.join(opt.checkpoint_path,'loss_log_%s.txt'%mode)
21 | if opt.local_rank == 0:
22 | if not os.path.exists(self.train_log_dir):
23 | os.makedirs(self.train_log_dir)
24 |
25 | self.train_writer = SummaryWriter(self.train_log_dir)
26 |
27 | self.log_file = open(self.log_name,"a")
28 | now = time.strftime("%c")
29 | self.log_file.write('================ Training Loss (%s) =================\n'%now)
30 | self.log_file.flush()
31 |
32 |
33 | # errors:dictionary of error labels and values
34 | def plot_current_errors(self,errors,step):
35 |
36 | for tag,value in errors.items():
37 |
38 | self.train_writer.add_scalar("%s/"%self.name+tag,value,step)
39 | self.train_writer.flush()
40 |
41 |
42 | # errors: same format as |errors| of CurrentErrors
43 | def print_current_errors(self,epoch,i,errors,t):
44 | message = '(epoch: %d\t iters: %d\t time: %.5f)\t'%(epoch,i,t)
45 | for k,v in errors.items():
46 |
47 | message += '%s: %.5f\t' %(k,v)
48 |
49 | print(message)
50 |
51 | self.log_file.write('%s\n' % message)
52 | self.log_file.flush()
53 |
54 | def display_current_results(self, visuals, step):
55 | if visuals is None:
56 | return
57 | for label, image in visuals.items():
58 | # Write the image to a string
59 |
60 | self.train_writer.add_image("%s/"%self.name+label,image,global_step=step)
61 |
62 | def close(self):
63 |
64 | self.train_writer.close()
65 | self.log_file.close()
--------------------------------------------------------------------------------
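Visualizer needs only an options object carrying name, checkpoint_path and local_rank; scalars go to TensorBoard through plot_current_errors and to a plain-text log through print_current_errors, and the writers are created only on rank 0. A minimal usage sketch with placeholder values:

from types import SimpleNamespace
from utils.visualizer import Visualizer

opt = SimpleNamespace(name='Aligner', checkpoint_path='checkpoint/Aligner', local_rank=0)
vis = Visualizer(opt, mode='train')

losses = {'d_real': 0.4, 'd_fake': 0.6}
vis.plot_current_errors(losses, step=100)                         # TensorBoard scalars
vis.print_current_errors(epoch=1, i=100, errors=losses, t=0.25)   # console + loss_log_train.txt
vis.close()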
/process/download_and_process.py:
--------------------------------------------------------------------------------
1 | import subprocess
2 | import os
3 | import pdb
4 | from multiprocessing import Process
5 | import multiprocessing as mp
6 | import subprocess
7 | import numpy as np
8 |
9 | def download_video(q):
10 | k = 1
11 | while True:
12 | vid,save_base = q.get()
13 | if vid is None:
14 | break
15 | cmd = 'yt-dlp -f bestvideo[ext=mp4]+bestaudio[ext=m4a]/bestvideo+bestaudio \
16 | https://www.youtube.com/watch?v={vid} \
17 | --merge-output-format mp4 \
18 | --output {save_base}/{vid}.mp4 \
19 | --external-downloader aria2c \
20 | --downloader-args aria2c:"-x 16 -k 1M"'.format(vid=vid,save_base=save_base)
21 |
22 | subprocess.call(cmd,shell=True)
23 | print('\rhave done %06d'%k,end='',flush=True)
24 | k += 1
25 | print()
26 |
27 | def get_frames(q):
28 |
29 | while True:
30 | path,save_path = q.get()
31 | if path is None:
32 | break
33 | save_base = os.path.split(save_path)[0]
34 | os.makedirs(save_base,exist_ok=True)
35 | with open(path,'r') as f:
36 | lines = f.readlines()
37 | filter_lines = list(filter(lambda x:x.startswith('0'),lines))
38 | frames = list(map(lambda x:x.strip().split(),filter_lines))
39 |
40 | np.save(save_path,frames)
41 |
42 | def read_file(base,save,q1,q2):
43 | for idname in os.listdir(base):
44 | idpath = os.path.join(base,idname)
45 | for videoname in os.listdir(idpath):
46 | q1.put([videoname,os.path.join(save,idname,videoname)])
47 | videopath = os.path.join(idpath,videoname)
48 |
49 | for i,infoname in enumerate(os.listdir(videopath)):
50 | infopath = os.path.join(videopath,infoname)
51 | q2.put([infopath,os.path.join(save,idname,videoname,'%02d.npy'%i)])
52 |
53 | if __name__ == "__main__":
54 | base = '../dataset/vox2_test_txt/txt'
55 | save = '../dataset/voceleb2'
56 | mp.set_start_method('spawn')
57 | m = mp.Manager()
58 | queue1 = m.Queue()
59 | queue2 = m.Queue()
60 |
61 | read_p = Process(target=read_file,args=(base,save,queue1,queue2,))
62 | download_p = Process(target=download_video,args=(queue1,))
63 | frame_p = Process(target=get_frames,args=(queue2,))
64 |
65 | read_p.start()
66 | download_p.start()
67 | frame_p.start()
68 |
69 | read_p.join()
70 | queue1.put([None,None])
71 | queue2.put([None,None])
72 | download_p.join()
73 | frame_p.join()
74 |
--------------------------------------------------------------------------------
/dataloader/BlendLoader.py:
--------------------------------------------------------------------------------
1 | #! /usr/bin/python
2 | # -*- encoding: utf-8 -*-
3 | '''
4 | @author Leslie
5 | @date 20220812
6 | '''
7 |
8 | import os
9 |
10 | from torchvision import transforms
11 | import PIL.Image as Image
12 | from dataloader.DataLoader import DatasetBase
13 | import random
14 | import torchvision.transforms.functional as F
15 |
16 |
17 | class BlendData(DatasetBase):
18 | def __init__(self, slice_id=0, slice_count=1,dist=False, **kwargs):
19 | super().__init__(slice_id, slice_count,dist, **kwargs)
20 |
21 |
22 | self.transform = transforms.Compose([
23 | transforms.Resize((kwargs['size'], kwargs['size'])),
24 | transforms.ToTensor()
25 | ])
26 | self.color_fn = transforms.Compose([transforms.ColorJitter(0.5, 0.5, 0.5, 0.1)])
27 |
28 |
29 | self.norm = transforms.Compose([transforms.Normalize([0.5,0.5,0.5],[0.5,0.5,0.5])])
30 |
31 | self.gray = transforms.Compose([transforms.Grayscale(num_output_channels=1)])
32 |
33 | # source root
34 | root = kwargs['root']
35 | self.paths = [os.path.join(root,f) for f in os.listdir(root)]
36 | self.length = len(self.paths)
37 | random.shuffle(self.paths)
38 | self.eval = kwargs['eval']
39 |
40 |
41 |
42 | def __getitem__(self,i):
43 |
44 | idx = i % self.length
45 | id_path = self.paths[idx]
46 | video_paths = [os.path.join(id_path,f) for f in os.listdir(id_path)]
47 | vIdx = random.randint(0, len(video_paths) - 1)
48 | video_path = video_paths[vIdx]
49 | img_paths = [os.path.join(video_path,f) for f in os.listdir(video_path)]
50 | img_idx = random.randint(0, len(img_paths) - 1)
51 | img_path = img_paths[img_idx]
52 | mask_path = img_path.replace('img','mask')
53 |
54 | idx = (i + random.randint(0,self.length-1)) % self.length
55 | id_path = self.paths[idx]
56 | video_paths = [os.path.join(id_path,f) for f in os.listdir(id_path)]
57 | vIdx = random.randint(0, len(video_paths) - 1)
58 | video_path = video_paths[vIdx]
59 | img_paths = [os.path.join(video_path,f) for f in os.listdir(video_path)]
60 | img_idx = random.randint(0, len(img_paths) - 1)
61 | ex_img_path = img_paths[img_idx]
62 | ex_mask_path = ex_img_path.replace('img','mask')
63 |
64 |
65 | with Image.open(img_path) as img:
66 | gt = self.transform(img.convert('RGB'))
67 | I_a = self.transform(self.color_fn(img.convert('RGB')))
68 | gt = self.norm(gt)
69 | I_a = self.norm(I_a)
70 | with Image.open(mask_path) as img:
71 | M_a = self.transform(img.convert('L')) * 255
72 |
73 |
74 | I_gray = self.gray(I_a)
75 | if random.random() > 0.3:
76 | I_t = F.hflip(gt)
77 | M_t = F.hflip(M_a)
78 |
79 | else:
80 | I_t = gt
81 | M_t = M_a
82 |
83 | with Image.open(ex_img_path) as img:
84 | hat_t = self.transform(img.convert('RGB'))
85 | hat_t = self.norm(hat_t)
86 | with Image.open(ex_mask_path) as img:
87 | M_hat = self.transform(img.convert('L')) * 255
88 |
89 | return I_a,I_gray,I_t,hat_t,M_a,M_t,M_hat,gt
90 |
91 | def __len__(self):
92 | if self.eval:
93 | return 1000
94 | else:
95 | # return self.length
96 | return max(self.length,1000)
97 |
98 |
--------------------------------------------------------------------------------
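BlendData is driven entirely by keyword arguments: size, root (an identity/video/frame directory tree, with masks found by replacing img with mask in each path) and eval. A hedged construction sketch, assuming the processed dataset produced by the scripts in process/ already exists at the path configured in BlendModule/config.py:

import torch
from dataloader.BlendLoader import BlendData

data = BlendData(dist=False, size=512, root='dataset/process/img', eval=False)
loader = torch.utils.data.DataLoader(data, batch_size=4, drop_last=True)

I_a, I_gray, I_t, hat_t, M_a, M_t, M_hat, gt = next(iter(loader))
print(I_a.shape, M_a.shape)   # (4, 3, 512, 512) and (4, 1, 512, 512)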
/model/AlignModule/loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from collections import OrderedDict
4 | import torchvision
5 | import torch.nn.functional as F
6 |
7 | def compute_dis_loss(fake_pred,real_pred,D_loss):
8 | # d_real = torch.relu(1. - real_pred).mean()
9 | # d_fake = torch.relu(1. + fake_pred).mean()
10 | d_real = F.mse_loss(real_pred,torch.ones_like(real_pred))
11 | d_fake = F.mse_loss(fake_pred, torch.zeros_like(fake_pred))
12 |
13 | D_loss['d_real'] = d_real
14 | D_loss['d_fake'] = d_fake
15 | return d_real + d_fake
16 |
17 | def compute_gan_loss(fake_pred):
18 |
19 | return F.mse_loss(fake_pred,torch.ones_like(fake_pred))
20 |
21 | def compute_id_loss(fake_id_f,real_id_f):
22 | return 1.0 - torch.cosine_similarity(fake_id_f,real_id_f, dim = 1)
23 |
24 |
25 | # Perceptual loss that uses a pretrained VGG network
26 |
27 | class Flatten(nn.Module):
28 | def __init__(self):
29 | super().__init__()
30 |
31 | def forward(self, x):
32 | return x.view(-1)
33 |
34 | class PerceptualLoss(nn.Module):
35 | def __init__(self,model_path, normalize_grad=False):
36 | super().__init__()
37 |
38 | self.normalize_grad = normalize_grad
39 |
40 |
41 | vgg_weights = torch.load(model_path)
42 |
43 | map = {'classifier.6.weight': u'classifier.7.weight', 'classifier.6.bias': u'classifier.7.bias'}
44 | vgg_weights = OrderedDict([(map[k] if k in map else k, v) for k, v in vgg_weights.items()])
45 |
46 | model = torchvision.models.vgg19()
47 | model.classifier = nn.Sequential(Flatten(), *model.classifier._modules.values())
48 |
49 | model.load_state_dict(vgg_weights)
50 |
51 | model = model.features
52 |
53 | mean = torch.tensor([103.939, 116.779, 123.680]) / 255.
54 | std = torch.tensor([1., 1., 1.]) / 255.
55 |
56 | num_layers = 30
57 |
58 | self.register_buffer('mean', mean[None, :, None, None])
59 | self.register_buffer('std' , std[None, :, None, None])
60 |
61 | layers_avg_pooling = []
62 |
63 | for weights in model.parameters():
64 | weights.requires_grad = False
65 |
66 | for module in model.modules():
67 | if module.__class__.__name__ == 'Sequential':
68 | continue
69 | elif module.__class__.__name__ == 'MaxPool2d':
70 | layers_avg_pooling.append(nn.AvgPool2d(kernel_size=2, stride=2, padding=0))
71 | else:
72 | layers_avg_pooling.append(module)
73 |
74 | if len(layers_avg_pooling) >= num_layers:
75 | break
76 |
77 | layers_avg_pooling = nn.Sequential(*layers_avg_pooling)
78 |
79 | self.model = layers_avg_pooling
80 |
81 | def normalize_inputs(self, x):
82 | return (x - self.mean) / self.std
83 |
84 | def forward(self, input, target):
85 | input = (input + 1) / 2
86 | target = (target.detach() + 1) / 2
87 | loss = 0
88 | features_input = self.normalize_inputs(input)
89 | features_target = self.normalize_inputs(target)
90 |
91 | for layer in self.model:
92 | features_input = layer(features_input)
93 | features_target = layer(features_target)
94 |
95 | if layer.__class__.__name__ == 'ReLU':
96 | if self.normalize_grad:
97 | pass
98 | else:
99 | loss = loss + F.l1_loss(features_input, features_target)
100 |
101 | return loss
102 |
--------------------------------------------------------------------------------
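compute_dis_loss and compute_gan_loss implement a least-squares GAN objective (the hinge formulation is left commented out): the discriminator pushes real logits toward 1 and fake logits toward 0, while the generator pushes its fake logits toward 1. A small sketch of how the two sides combine in a training step (the logits are random stand-ins for discriminator outputs):

import torch
from model.AlignModule.loss import compute_dis_loss, compute_gan_loss

real_pred = torch.randn(8, 1, 32, 32)   # discriminator logits on real images
fake_pred = torch.randn(8, 1, 32, 32)   # discriminator logits on generated images

D_loss = {}
d_loss = compute_dis_loss(fake_pred.detach(), real_pred, D_loss)   # discriminator step
g_loss = compute_gan_loss(fake_pred)                               # generator step
print(d_loss.item(), g_loss.item(), sorted(D_loss))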
/train.py:
--------------------------------------------------------------------------------
1 | '''
2 | @author Leslie
3 | @date 20220812
4 | '''
5 | import os
6 |
7 | import argparse
8 |
9 |
10 | from trainer.AlignTrainer import AlignTrainer
11 | from trainer.BlendTrainer import BlendTrainer
12 | import torch.distributed as dist
13 | from utils.utils import setup_seed,get_data_loader,merge_args
14 | from model.AlignModule.config import Params as AlignParams
15 | from model.BlendModule.config import Params as BlendParams
16 |
17 | # torch.multiprocessing.set_start_method('spawn')
18 |
19 | parser = argparse.ArgumentParser(description="HeSer")
20 | #---------train set-------------------------------------
21 | parser.add_argument('--model',default="align",help='')
22 | parser.add_argument('--isTrain',action="store_false",help='')
23 | parser.add_argument('--dist',action="store_false",help='')
24 | parser.add_argument('--batch_size',default=16,type=int)
25 | parser.add_argument('--seed',default=10,type=int)
26 | parser.add_argument('--eval',default=1,type=int,help='whether use eval')
27 | parser.add_argument('--nDataLoaderThread',default=5,type=int,help='Num of loader threads')
28 | parser.add_argument('--print_interval',default=100,type=int)
29 | parser.add_argument('--test_interval',default=100,type=int,help='Test and save every [test_interval] epochs')
30 | parser.add_argument('--save_interval',default=100,type=int,help='save model interval')
31 | parser.add_argument('--stop_interval',default=20,type=int)
32 | parser.add_argument('--begin_it',default=0,type=int,help='begin epoch')
33 | parser.add_argument('--mx_data_length',default=100,type=int,help='max data length')
34 | parser.add_argument('--max_epoch',default=10000,type=int)
35 | parser.add_argument('--early_stop',action="store_true",help='')
36 | parser.add_argument('--scratch',action="store_true",help='')
37 | #---------path set--------------------------------------
38 | parser.add_argument('--checkpoint_path',default='checkpoint',type=str)
39 | parser.add_argument('--pretrain_path',default=None,type=str)
40 |
41 | # ------optimizer set--------------------------------------
42 | parser.add_argument('--lr',default=0.002,type=float,help="Learning rate")
43 |
44 | parser.add_argument(
45 | '--local_rank',
46 | type=int,
47 | default=0,
48 | help='Local rank passed from distributed launcher'
49 | )
50 |
51 | args = parser.parse_args()
52 |
53 | def train_net(args):
54 | train_loader,test_loader,mx_length = get_data_loader(args)
55 |
56 | args.mx_data_length = mx_length
57 | if args.model == 'align':
58 | trainer = AlignTrainer(args)
59 | elif args.model == 'blend':
60 | trainer = BlendTrainer(args)
61 |
62 |
63 | trainer.train_network(train_loader,test_loader)
64 |
65 | if __name__ == "__main__":
66 |
67 | args = parser.parse_args()
68 |
69 | if args.model == 'align':
70 | params = AlignParams()
71 |
72 | elif args.model == 'blend':
73 | params = BlendParams()
74 | args = merge_args(args,params)
75 | if args.dist:
76 |         dist.init_process_group(backend="nccl")
77 |         dist.barrier()  # synchronize all processes before training starts
78 |         args.world_size = dist.get_world_size()  # total number of processes
79 |         args.rank = dist.get_rank()  # rank of the current process
80 |
81 | else:
82 | args.world_size = 1
83 | args.rank = 0
84 |
85 | setup_seed(args.seed+args.rank)
86 | print(args)
87 |
88 | args.checkpoint_path = os.path.join(args.checkpoint_path,args.name)
89 |
90 | print("local_rank %d | rank %d | world_size: %d"%(int(os.environ.get('LOCAL_RANK','0')),args.rank,args.world_size))
91 | if args.rank == 0 :
92 | if not os.path.exists(args.checkpoint_path):
93 | os.makedirs(args.checkpoint_path)
94 | print("make dir: ",args.checkpoint_path)
95 | train_net(args)
96 |
97 |
98 |
99 |
--------------------------------------------------------------------------------
/process/process_raw_video.py:
--------------------------------------------------------------------------------
1 | import os
2 | import imageio
3 | import numpy as np
4 | from face_alignment.detection.sfd import FaceDetector
5 | import face_alignment
6 | import torch
7 | import math
8 | import cv2
9 | import pdb
10 | from multiprocessing import Process
11 | import multiprocessing as mp
12 | from process_utils import *
13 | import pdb
14 |
15 | def get_video_info(base,save_base,q):
16 |
17 | for idname in os.listdir(base):
18 | idpath = os.path.join(base,idname)
19 | save_path = os.path.join(save_base,idname)
20 | for videoname in os.listdir(idpath):
21 | videopath = os.path.join(idpath,videoname)
22 | frame_names = [os.path.join(videopath,f) for f in os.listdir(videopath) if f.endswith('.mp4')]
23 | info_names = [os.path.join(videopath,f) for f in os.listdir(videopath) if f.endswith('.npy')]
24 | if len(frame_names) == 0:
25 | continue
26 | q.put([frame_names[0],info_names,save_path,videoname])
27 |
28 |
29 |
30 | def process_frame(q1,align=True,scale=1.8,size=512):
31 | face_detector = FaceDetector(device='cuda')
32 | lmk_detector = face_alignment.FaceAlignment(face_alignment.LandmarksType._3D, flip_input=False)
33 | kk = 1
34 | def detect_faces(images):
35 | images = np.stack(images).transpose(0,3,1,2).astype(np.float32)
36 | images_torch = torch.tensor(images)
37 | return face_detector.detect_from_batch(images_torch.cuda())
38 |
39 | while True:
40 |
41 | frame_path,info_names,save_base,videoname = q1.get()
42 | if frame_path is None:
43 | break
44 | video_reader = imageio.get_reader(frame_path)
45 | for k,info_name in enumerate(info_names):
46 | info = np.load(info_name)
47 | save_path = os.path.join(save_base,'%s-%04d'%(videoname,k))
48 | os.makedirs(save_path,exist_ok=True)
49 | for (i,x,y,w,h) in info:
50 | try:
51 | img = video_reader.get_data(int(i))
52 | height,width,_ = img.shape
53 | x,y,w,h = list(map(lambda x:float(x),[x,y,w,h]))
54 | i = int(i)
55 | box = [x*width,y*height,(x+w)*width,(y+h)*height]
56 | if os.path.exists(os.path.join(save_path,'%04d.png'%i)):
57 | continue
58 |
59 | bboxes = detect_faces([img])[0]
60 |
61 | bbox = choose_one_detection(bboxes,box)
62 | if bbox is None:
63 | continue
64 | bbox = bbox[:4]
65 | landmarks = lmk_detector.get_landmarks_from_image(img[...,::-1], [bbox])[0]
66 | image_cropped,_ = crop_with_padding(img,landmarks[:,:2],scale=scale,size=size,align=align)
67 |
68 | cv2.imwrite(os.path.join(save_path,'%04d.png'%i),image_cropped[...,::-1])
69 | print('\r have done %06d'%i,end='',flush=True)
70 | kk += 1
71 | except:
72 | continue
73 |
74 | video_reader.close()
75 | print()
76 |
77 |
78 | if __name__ == "__main__":
79 |
80 | mp.set_start_method('spawn')
81 | m = mp.Manager()
82 | q1 = m.Queue()
83 | base = '../dataset/voceleb2'
84 | save_base = '../dataset/process'
85 | process_num = 2
86 |
87 | info_p = Process(target=get_video_info,args=(base,save_base,q1,))
88 |
89 |
90 | process_list = []
91 | for _ in range(process_num):
92 | process_list.append(Process(target=process_frame,args=(q1,)))
93 |
94 |
95 | info_p.start()
96 | for p in process_list:
97 | p.start()
98 |
99 | info_p.join()
100 |
101 | for _ in range(process_num*2):
102 | q1.put([None,None,None,None])
103 | for p in process_list:
104 | p.join()
105 |
--------------------------------------------------------------------------------
/model/third/resnet.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python
2 | # -*- encoding: utf-8 -*-
3 |
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 | import torch.utils.model_zoo as modelzoo
8 |
9 | # from modules.bn import InPlaceABNSync as BatchNorm2d
10 |
11 | resnet18_url = 'https://download.pytorch.org/models/resnet18-5c106cde.pth'
12 |
13 |
14 | def conv3x3(in_planes, out_planes, stride=1):
15 | """3x3 convolution with padding"""
16 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
17 | padding=1, bias=False)
18 |
19 |
20 | class BasicBlock(nn.Module):
21 | def __init__(self, in_chan, out_chan, stride=1):
22 | super(BasicBlock, self).__init__()
23 | self.conv1 = conv3x3(in_chan, out_chan, stride)
24 | self.bn1 = nn.BatchNorm2d(out_chan)
25 | self.conv2 = conv3x3(out_chan, out_chan)
26 | self.bn2 = nn.BatchNorm2d(out_chan)
27 | self.relu = nn.ReLU(inplace=True)
28 | self.downsample = None
29 | if in_chan != out_chan or stride != 1:
30 | self.downsample = nn.Sequential(
31 | nn.Conv2d(in_chan, out_chan,
32 | kernel_size=1, stride=stride, bias=False),
33 | nn.BatchNorm2d(out_chan),
34 | )
35 |
36 | def forward(self, x):
37 | residual = self.conv1(x)
38 | residual = F.relu(self.bn1(residual))
39 | residual = self.conv2(residual)
40 | residual = self.bn2(residual)
41 |
42 | shortcut = x
43 | if self.downsample is not None:
44 | shortcut = self.downsample(x)
45 |
46 | out = shortcut + residual
47 | out = self.relu(out)
48 | return out
49 |
50 |
51 | def create_layer_basic(in_chan, out_chan, bnum, stride=1):
52 | layers = [BasicBlock(in_chan, out_chan, stride=stride)]
53 | for i in range(bnum-1):
54 | layers.append(BasicBlock(out_chan, out_chan, stride=1))
55 | return nn.Sequential(*layers)
56 |
57 |
58 | class Resnet18(nn.Module):
59 | def __init__(self):
60 | super(Resnet18, self).__init__()
61 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
62 | bias=False)
63 | self.bn1 = nn.BatchNorm2d(64)
64 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
65 | self.layer1 = create_layer_basic(64, 64, bnum=2, stride=1)
66 | self.layer2 = create_layer_basic(64, 128, bnum=2, stride=2)
67 | self.layer3 = create_layer_basic(128, 256, bnum=2, stride=2)
68 | self.layer4 = create_layer_basic(256, 512, bnum=2, stride=2)
69 | self.init_weight()
70 |
71 | def forward(self, x):
72 | x = self.conv1(x)
73 | x = F.relu(self.bn1(x))
74 | x = self.maxpool(x)
75 |
76 | x = self.layer1(x)
77 | feat8 = self.layer2(x) # 1/8
78 | feat16 = self.layer3(feat8) # 1/16
79 | feat32 = self.layer4(feat16) # 1/32
80 | return feat8, feat16, feat32
81 |
82 | def init_weight(self):
83 | state_dict = modelzoo.load_url(resnet18_url)
84 | self_state_dict = self.state_dict()
85 | for k, v in state_dict.items():
86 | if 'fc' in k: continue
87 | self_state_dict.update({k: v})
88 | self.load_state_dict(self_state_dict)
89 |
90 | def get_params(self):
91 | wd_params, nowd_params = [], []
92 | for name, module in self.named_modules():
93 | if isinstance(module, (nn.Linear, nn.Conv2d)):
94 | wd_params.append(module.weight)
95 | if not module.bias is None:
96 | nowd_params.append(module.bias)
97 | elif isinstance(module, nn.BatchNorm2d):
98 | nowd_params += list(module.parameters())
99 | return wd_params, nowd_params
100 |
101 |
102 | if __name__ == "__main__":
103 | net = Resnet18()
104 | x = torch.randn(16, 3, 224, 224)
105 | out = net(x)
106 | print(out[0].size())
107 | print(out[1].size())
108 | print(out[2].size())
109 | net.get_params()
110 |
--------------------------------------------------------------------------------
/dataloader/AlignLoader.py:
--------------------------------------------------------------------------------
1 | #! /usr/bin/python
2 | # -*- encoding: utf-8 -*-
3 | '''
4 | @author Leslie
5 | @date 20220812
6 | '''
7 |
8 | import os
9 |
10 | from torchvision import transforms
11 | import PIL.Image as Image
12 | from dataloader.DataLoader import DatasetBase
13 | import random
14 | import math
15 | import torch
16 | from dataloader.augmentation import ParametricAugmenter
17 | import numpy as np
18 |
19 |
20 | class AlignData(DatasetBase):
21 | def __init__(self, slice_id=0, slice_count=1,dist=False, **kwargs):
22 | super().__init__(slice_id, slice_count,dist, **kwargs)
23 |
24 |
25 | self.transform = transforms.Compose([
26 | transforms.Resize((kwargs['size'], kwargs['size'])),
27 | transforms.RandomHorizontalFlip(),
28 | transforms.ToTensor()
29 | ])
30 |
31 | self.aug_fn = ParametricAugmenter(use_pixelwise_augs=kwargs['use_pixelwise_augs'],
32 | use_affine_scale=kwargs['use_affine_scale'],
33 | use_affine_shift=kwargs['use_affine_shift'])
34 |
35 |
36 | self.norm = transforms.Compose([transforms.Normalize([0.5,0.5,0.5],[0.5,0.5,0.5])])
37 |
38 | self.resize = transforms.Compose([
39 | transforms.Resize((256,256))])
40 |
41 | # source root
42 | root = kwargs['root']
43 | self.paths = [os.path.join(root,f) for f in os.listdir(root)]
44 | # dis = math.floor(len(self.paths)/self.count)
45 | # self.paths = self.paths[self.id*dis:(self.id+1)*dis]
46 | self.length = len(self.paths)
47 | random.shuffle(self.paths)
48 | self.frame_num = kwargs['frame_num']
49 | self.skip_frame = kwargs['skip_frame']
50 | self.eval = kwargs['eval']
51 | self.size = kwargs['size']
52 | self.scale = 0.4 / 1.8
53 |
54 |
55 |
56 | def __getitem__(self,i):
57 |
58 | idx = i % self.length
59 | id_path = self.paths[idx]
60 | video_paths = [os.path.join(id_path,f) for f in os.listdir(id_path)]
61 | vIdx = random.randint(0, len(video_paths) - 1)
62 | video_path = video_paths[vIdx]
63 | img_paths = [os.path.join(video_path,f) for f in os.listdir(video_path)]
64 | begin_idx = random.randint(0, len(img_paths) - self.frame_num*self.skip_frame - 1)
65 | img_paths = [img_paths[i]
66 | for i in range(begin_idx,begin_idx+self.frame_num*self.skip_frame,self.skip_frame)]
67 |
68 | s_img_paths = img_paths[:-1]
69 |
70 | t_img_path = img_paths[-1]
71 |
72 | xs = []
73 | for img_path in s_img_paths:
74 | with Image.open(img_path) as img:
75 | xs.append(self.norm(self.transform(img.convert('RGB'))).unsqueeze(0))
76 | xs = torch.cat(xs,0)
77 |
78 | if self.eval:
79 | idx = (i + random.randint(0,self.length-1)) % self.length
80 | id_path = self.paths[idx]
81 | video_paths = [os.path.join(id_path,f) for f in os.listdir(id_path)]
82 | vIdx = random.randint(0, len(video_paths) - 1)
83 | video_path = video_paths[vIdx]
84 | img_paths = [os.path.join(video_path,f) for f in os.listdir(video_path)]
85 | t_img_path = img_paths[random.randint(0, len(img_paths) - 1)]
86 |
87 | with Image.open(t_img_path) as img:
88 | xt = self.transform(img.convert('RGB'))
89 |
90 | mask = np.zeros((self.size,self.size,3),dtype=np.uint8)
91 | mask[int(self.size*self.scale):int(-self.size*self.scale),
92 | int(self.size*self.scale):int(-self.size*self.scale)] = 255
93 | mask = mask[np.newaxis,:]
94 | xt,gt,mask = self.aug_fn.augment_triple(xt,xt,mask)
95 | indexs = torch.where(mask==1)
96 | top = indexs[1].min()
97 | bottom = indexs[1].max()
98 | left = indexs[2].min()
99 | right = indexs[2].max()
100 | crop_xt = xt[...,top:bottom,left:right]
101 | crop_xt = self.norm(crop_xt)
102 | xt = self.norm(xt)
103 | gt = self.norm(gt)
104 |
105 | return self.resize(xs),self.resize(xt),self.resize(crop_xt),gt
106 |
107 |
108 | def __len__(self):
109 | if self.eval:
110 | return 1000
111 | else:
112 | # return self.length
113 | return max(self.length,1000)
114 |
115 |
--------------------------------------------------------------------------------
/utils/utils.py:
--------------------------------------------------------------------------------
1 |
2 | '''
3 | @author Leslie
4 | @date 20220812
5 | '''
6 | import torch
7 | from dataloader.AlignLoader import AlignData
8 | from dataloader.BlendLoader import BlendData
9 |
10 |
11 | def requires_grad(model, flag=True):
12 | if model is None:
13 | return
14 | for p in model.parameters():
15 | p.requires_grad = flag
16 | def need_grad(x):
17 | x = x.detach()
18 | x.requires_grad_()
19 | return x
20 |
21 | def init_weights(m,init_type='normal', gain=0.02):
22 |
23 | classname = m.__class__.__name__
24 | if classname.find('BatchNorm2d') != -1:
25 | if hasattr(m, 'weight') and m.weight is not None:
26 | torch.nn.init.normal_(m.weight.data, 1.0, gain)
27 | if hasattr(m, 'bias') and m.bias is not None:
28 | torch.nn.init.constant_(m.bias.data, 0.0)
29 | elif hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
30 | if init_type == 'normal':
31 | torch.nn.init.normal_(m.weight.data, 0.0, gain)
32 | elif init_type == 'xavier':
33 | torch.nn.init.xavier_normal_(m.weight.data, gain=gain)
34 | elif init_type == 'xavier_uniform':
35 | torch.nn.init.xavier_uniform_(m.weight.data, gain=1.0)
36 | elif init_type == 'kaiming':
37 | torch.nn.init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
38 | elif init_type == 'orthogonal':
39 | torch.nn.init.orthogonal_(m.weight.data, gain=gain)
40 | elif init_type == 'none': # uses pytorch's default init method
41 | m.reset_parameters()
42 | else:
43 | raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
44 | if hasattr(m, 'bias') and m.bias is not None:
45 | torch.nn.init.constant_(m.bias.data, 0.0)
46 | def accumulate(model1, model2, decay=0.999):
47 | par1 = dict(model1.named_parameters())
48 | par2 = dict(model2.named_parameters())
49 |
50 | for k in par1.keys():
51 | par1[k].data.mul_(decay).add_(par2[k].data, alpha=1 - decay)
52 | def setup_seed(seed):
53 | torch.manual_seed(seed)
54 | if torch.cuda.is_available():
55 | torch.cuda.manual_seed_all(seed)
56 | torch.backends.cudnn.deterministic = True
57 |
58 | def get_data_loader(args):
59 | if args.model == 'align':
60 | train_data = AlignData(dist=args.dist,
61 | size=args.size,
62 | root=args.train_root,
63 | frame_num=args.frame_num,
64 | skip_frame=args.skip_frame,
65 | use_pixelwise_augs=args.use_pixelwise_augs,
66 | use_affine_scale=args.use_affine_scale,
67 | use_affine_shift=args.use_affine_shift,
68 | eval=False)
69 |
70 | test_data = AlignData(dist=args.dist,
71 | size=args.size,
72 | root=args.val_root,
73 | frame_num=args.frame_num,
74 | skip_frame=args.skip_frame,
75 | use_pixelwise_augs=False,
76 | use_affine_scale=False,
77 | use_affine_shift=False,
78 | eval=True)
79 |
80 | elif args.model == 'blend':
81 | train_data = BlendData(dist=args.dist,
82 | size=args.size,
83 | root=args.train_root,eval=False)
84 |
85 | test_data = BlendData(dist=args.dist,
86 | size=args.size,
87 | root=args.val_root,eval=True)
88 |
89 |
90 | train_loader = torch.utils.data.DataLoader(
91 | train_data,
92 | batch_size=args.batch_size,
93 | num_workers=args.nDataLoaderThread,
94 | pin_memory=False,
95 | drop_last=True
96 | )
97 | test_loader = None if test_data is None else \
98 | torch.utils.data.DataLoader(
99 | test_data,
100 | batch_size=args.batch_size,
101 | num_workers=args.nDataLoaderThread,
102 | pin_memory=False,
103 | drop_last=True
104 | )
105 | return train_loader,test_loader,len(train_data)
106 |
107 |
108 |
109 | def merge_args(args,params):
110 | for k,v in vars(params).items():
111 | setattr(args,k,v)
112 | return args
113 |
114 | def convert_img(img,unit=False):
115 |
116 | img = (img + 1) * 0.5
117 | if unit:
118 | return torch.clamp(img*255+0.5,0,255)
119 |
120 | return torch.clamp(img,0,1)
--------------------------------------------------------------------------------
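accumulate keeps an exponential moving average of one model's parameters inside another (the usual EMA-generator trick), and convert_img maps network outputs from [-1, 1] back to [0, 1] or [0, 255]. A short sketch with stand-in modules:

import torch
from torch import nn
from utils.utils import accumulate, convert_img

netG = nn.Linear(4, 4)                        # stand-ins for the real generators
netG_ema = nn.Linear(4, 4)
netG_ema.load_state_dict(netG.state_dict())

accumulate(netG_ema, netG, decay=0.999)       # ema = 0.999 * ema + 0.001 * current

fake = torch.tanh(torch.randn(1, 3, 64, 64))  # network output in [-1, 1]
img01 = convert_img(fake)                     # clamped to [0, 1] for TensorBoard
img255 = convert_img(fake, unit=True)         # scaled to [0, 255] for cv2/PIL saving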
/model/AlignModule/lib/generator.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from torch.nn.utils import spectral_norm
4 | from model.AlignModule.lib import blocks
5 |
6 | import math
7 |
8 | # heavily borrowed from https://github.com/shrubb/latent-pose-reenactment
9 |
10 | class Constant(nn.Module):
11 | def __init__(self, *shape):
12 | super().__init__()
13 | self.constant = nn.Parameter(torch.ones(1, *shape))
14 |
15 | def forward(self, batch_size):
16 | return self.constant.expand((batch_size,) + self.constant.shape[1:])
17 |
18 |
19 | class Generator(nn.Module):
20 | def __init__(self, args):
21 | super().__init__()
22 |
23 | def get_res_block(in_channels, out_channels, padding, norm_layer):
24 | return blocks.ResBlock(in_channels, out_channels, padding, upsample=False, downsample=False,
25 | norm_layer=norm_layer)
26 |
27 | def get_up_block(in_channels, out_channels, padding, norm_layer):
28 | return blocks.ResBlock(in_channels, out_channels, padding, upsample=True, downsample=False,
29 | norm_layer=norm_layer)
30 |
31 | if args.padding == 'zero':
32 | padding = nn.ZeroPad2d
33 | elif args.padding == 'reflection':
34 | padding = nn.ReflectionPad2d
35 | else:
36 | raise Exception('Incorrect `padding` argument, required `zero` or `reflection`')
37 |
38 | assert math.log2(args.output_image_size / args.gen_constant_input_size).is_integer(), \
39 | "`gen_constant_input_size` must be `image_size` divided by a power of 2"
40 | num_upsample_blocks = int(math.log2(args.output_image_size / args.gen_constant_input_size))
41 | out_channels_block_nonclamped = args.num_channels * (2 ** num_upsample_blocks)
42 | out_channels_block = min(out_channels_block_nonclamped, args.max_num_channels)
43 |
44 | self.constant = Constant(out_channels_block, args.gen_constant_input_size, args.gen_constant_input_size)
45 |
46 |
47 | # Decoder
48 | layers = []
49 | for i in range(args.gen_num_residual_blocks):
50 | layers.append(get_res_block(out_channels_block, out_channels_block, padding, 'ada' + args.norm_layer))
51 |
52 | for _ in range(num_upsample_blocks):
53 | in_channels_block = out_channels_block
54 | out_channels_block_nonclamped //= 2
55 | out_channels_block = min(out_channels_block_nonclamped, args.max_num_channels)
56 | layers.append(get_up_block(in_channels_block, out_channels_block, padding, 'ada' + args.norm_layer))
57 |
58 | layers.extend([
59 | blocks.AdaptiveNorm2d(out_channels_block, args.norm_layer),
60 | nn.ReLU(True),
61 | # padding(1),
62 | spectral_norm(
63 | nn.Conv2d(out_channels_block, args.out_channels, 3, 1, 1),
64 | eps=1e-4),
65 | nn.Tanh()
66 | ])
67 | self.decoder_blocks = nn.Sequential(*layers)
68 |
69 | self.adains = [module for module in self.modules() if module.__class__.__name__ == 'AdaptiveNorm2d']
70 |
71 |
72 | joint_embedding_size = args.identity_embedding_size + args.pose_embedding_size + args.por_embedding_size + args.exp_embedding_size
73 | self.affine_params_projector = nn.Sequential(
74 | spectral_norm(nn.Linear(joint_embedding_size, max(joint_embedding_size, 512))),
75 | nn.ReLU(True),
76 | spectral_norm(nn.Linear(max(joint_embedding_size, 512), self.get_num_affine_params()))
77 | )
78 |
79 |
80 | def get_num_affine_params(self):
81 | return sum(2*module.num_features for module in self.adains)
82 |
83 | def assign_affine_params(self, affine_params):
84 | for m in self.modules():
85 | if m.__class__.__name__ == "AdaptiveNorm2d":
86 | new_bias = affine_params[:, :m.num_features]
87 | new_weight = affine_params[:, m.num_features:2 * m.num_features]
88 |
89 | if m.bias is None: # to keep m.bias being `nn.Parameter`
90 | m.bias = new_bias.contiguous()
91 | else:
92 | m.bias.copy_(new_bias)
93 |
94 | if m.weight is None: # to keep m.weight being `nn.Parameter`
95 | m.weight = new_weight.contiguous()
96 | else:
97 | m.weight.copy_(new_weight)
98 |
99 | if affine_params.size(1) > 2 * m.num_features:
100 | affine_params = affine_params[:, 2 * m.num_features:]
101 |
102 | def assign_embeddings(self, por,id,pose,exp):
103 |
104 | joint_embedding = torch.cat((por,id,pose,exp), dim=1)
105 |
106 | affine_params = self.affine_params_projector(joint_embedding)
107 | self.assign_affine_params(affine_params)
108 |
109 |
110 | def forward(self, por,id,pose,exp):
111 | self.assign_embeddings(por,id,pose,exp)
112 |
113 | batch_size = len(por)
114 | output = self.decoder_blocks(self.constant(batch_size))
115 |
116 | return output
117 |
--------------------------------------------------------------------------------
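The generator takes no image input: it starts from a learned constant tensor and modulates every AdaIN layer with affine parameters projected from the concatenated portrait, identity, pose and expression embeddings. A shape-only sketch with random embeddings, assuming the embedding sizes from AlignModule/config.py:

import torch
from model.AlignModule.config import Params
from model.AlignModule.lib.generator import Generator

args = Params()
netG = Generator(args).eval()

b = 2
por = torch.randn(b, args.por_embedding_size)
idf = torch.randn(b, args.identity_embedding_size)
pose = torch.randn(b, args.pose_embedding_size)
exp = torch.randn(b, args.exp_embedding_size)

with torch.no_grad():
    out = netG(por, idf, pose, exp)
print(out.shape)   # (2, 3, args.output_image_size, args.output_image_size)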
/inference.py:
--------------------------------------------------------------------------------
1 | from model.AlignModule.lib import *
2 | from model.BlendModule.generator import Generator as Decoder
3 | from model.AlignModule.config import Params as AlignParams
4 | from model.BlendModule.config import Params as BlendParams
5 | from trainer.AlignTrainer import AlignTrainer
6 | from model.third.model import BiSeNet
7 | import torchvision.transforms.functional as TF
8 | import torch.nn.functional as F
9 | import torch
10 | import cv2
11 | import numpy as np
12 | import pdb
13 |
14 | class Infer:
15 | def __init__(self,align_path,blend_path,parsing_path):
16 | align_params = AlignParams()
17 | blend_params = BlendParams()
18 | self.device = 'cpu'
19 | if torch.cuda.is_available():
20 | self.device = 'cuda'
21 |
22 | self.parsing = BiSeNet(n_classes=19).to(self.device)
23 | self.Epor = PorEncoder(align_params).to(self.device)
24 | self.Eid = IDEncoder(align_params.id_model).to(self.device)
25 | self.Epose = PoseEncoder(align_params).to(self.device)
26 | self.Eexp = ExpEncoder(align_params).to(self.device)
27 | self.netG = Generator(align_params).to(self.device)
28 | self.decoder = Decoder(blend_params).to(self.device)
29 |
30 | self.loadModel(align_path,blend_path,parsing_path)
31 | self.eval_model(self.Epor,self.Eid,self.Epose,self.Eexp,self.netG,self.decoder,self.parsing)
32 | self.mean =torch.tensor([0.485, 0.456, 0.406], dtype=torch.float32)
33 | self.std = torch.tensor([0.229, 0.224, 0.225], dtype=torch.float32)
34 |
35 | def run(self,tgt_img_path,src_img_paths):
36 |
37 | tgt_img = cv2.imread(tgt_img_path)
38 | tgt_inp = self.preprocess(tgt_img)
39 |
40 | src_img = cv2.imread(src_img_paths[0])
41 |
42 | src_inp = self.preprocess_multi(src_img_paths)
43 |
44 | gen = self.forward(src_inp,tgt_inp)
45 | gen = self.postprocess(gen[0])
46 | cat_img = np.concatenate([cv2.resize(src_img,[512,512]),
47 | gen,cv2.resize(tgt_img,[512,512])],1)
48 | return cat_img
49 |
50 | def forward(self,xs,xt):
51 | with torch.no_grad():
52 | por_f = self.Epor(xs)
53 | id_f = self.Eid(AlignTrainer.process_id_input(xs,crop=True))
54 |
55 | pose_f = self.Epose(F.adaptive_avg_pool2d(xt,256))
56 | exp_f = self.Eexp(AlignTrainer.process_id_input(xt,crop=True,size=256))
57 |
58 | xg = self.netG(por_f,id_f,pose_f,exp_f)
59 |
60 | M_a = self.parsing(self.preprocess_parsing(xg))
61 | M_t = self.parsing(self.preprocess_parsing(xt))
62 |
63 | M_a = self.postprocess_parsing(M_a)
64 | M_t = self.postprocess_parsing(M_t)
65 | xg_gray = TF.rgb_to_grayscale(xg,num_output_channels=1)
66 | fake = self.decoder(xg,xg_gray,xt,M_a,M_t,xt,train=False)
67 |
68 | return fake
69 |
70 | def preprocess(self,x):
71 | if isinstance(x,str):
72 | x = cv2.imread(x)
73 | x = cv2.resize(x,[512,512])
74 | x = (x[...,::-1].transpose(2,0,1)[np.newaxis,:] / 255 - 0.5) * 2
75 | return torch.from_numpy(x.astype(np.float32)).to(self.device)
76 |
77 | def preprocess_multi(self,xs):
78 | x_list = []
79 | for x in xs:
80 | x = cv2.imread(x)
81 | x = cv2.resize(x,[512,512])
82 | x_list.append((x[...,::-1].transpose(2,0,1)[np.newaxis,:] / 255 - 0.5) * 2)
83 | x_list = np.concatenate(x_list,0)
84 | return torch.from_numpy(x_list.astype(np.float32)).to(self.device).unsqueeze(0)
85 |
86 | def postprocess(self,x):
87 | return (x.permute(1,2,0).cpu().numpy()[...,::-1] + 1) * 127.5
88 |
89 | def preprocess_parsing(self,x):
90 |
91 | return ((x+1)/2.0 - self.mean.view(1,-1,1,1).to(self.device)) / \
92 | self.std.view(1,-1,1,1).to(self.device)
93 |
94 | def postprocess_parsing(self,x):
95 | return torch.argmax(x[0],1).unsqueeze(1).float()
96 |
97 |
98 |
99 | def loadModel(self,align_path,blend_path,parsing_path):
100 | ckpt = torch.load(align_path, map_location=lambda storage, loc: storage)
101 | self.netG.load_state_dict(ckpt['G'],strict=False)
102 | self.Eexp.load_state_dict(ckpt['Eexp'],strict=False)
103 | self.Eid.load_state_dict(ckpt['Eid'],strict=False)
104 | self.Epor.load_state_dict(ckpt['Epor'],strict=False)
105 |
106 | ckpt = torch.load(blend_path, map_location=lambda storage, loc: storage)
107 | self.decoder.load_state_dict(ckpt['G'],strict=False)
108 |
109 | self.parsing.load_state_dict(torch.load(parsing_path))
110 |
111 |
112 | def eval_model(self,*args):
113 | for arg in args:
114 | arg.eval()
115 |
116 | if __name__ == "__main__":
117 | model = Infer('checkpoint/Aligner/417-00000000.pth',
118 | 'checkpoint/Blender/073-00000000.pth',
119 | 'pretrained_models/parsing.pth')
120 |
121 | src_path_list = ['dataset/select-align/img/id00061/2XrRfyv-EmE-0001/2122.png',
122 | 'dataset/select-align/img/id00061/2XrRfyv-EmE-0001/2125.png',
123 | 'dataset/select-align/img/id00061/2XrRfyv-EmE-0001/2130.png',
124 | 'dataset/select-align/img/id00061/2XrRfyv-EmE-0001/2135.png',
125 | 'dataset/select-align/img/id00061/2XrRfyv-EmE-0001/2140.png']
126 | tgt_path = 'dataset/select-align/img/id00061/4kSyBHethpE-0002/2055.png'
127 | oup = model.run(tgt_path,src_path_list)
128 |
129 | cv2.imwrite('2.png',oup)
--------------------------------------------------------------------------------
/model/AlignModule/discriminator.py:
--------------------------------------------------------------------------------
1 | from torch import nn
2 | from model.AlignModule.lib import blocks
3 | from torch.nn.utils import spectral_norm
4 | import math
5 | import torch
6 | # heavily borrowed from https://github.com/shrubb/latent-pose-reenactment
7 |
8 | # class Discriminator(nn.Module):
9 | # def __init__(self, args):
10 | # super().__init__()
11 |
12 | # def get_down_block(in_channels, out_channels, padding):
13 | # return blocks.ResBlock(in_channels, out_channels, padding, upsample=False, downsample=True,
14 | # norm_layer='none')
15 |
16 | # def get_res_block(in_channels, out_channels, padding):
17 | # return blocks.ResBlock(in_channels, out_channels, padding, upsample=False, downsample=False,
18 | # norm_layer='none')
19 |
20 | # if args.padding == 'zero':
21 | # padding = nn.ZeroPad2d
22 | # elif args.padding == 'reflection':
23 | # padding = nn.ReflectionPad2d
24 |
25 | # self.out_channels = args.embed_channels
26 |
27 | # self.down_block = nn.Sequential(
28 | # # padding(1),
29 | # spectral_norm(
30 | # nn.Conv2d(args.in_channels, args.num_channels, 3, 1, 1),
31 | # eps=1e-4),
32 | # nn.ReLU(),
33 | # # padding(1),
34 | # spectral_norm(
35 | # nn.Conv2d(args.num_channels, args.num_channels, 3, 1, 1),
36 | # eps=1e-4),
37 | # nn.AvgPool2d(2))
38 | # self.skip = nn.Sequential(
39 | # spectral_norm(
40 | # nn.Conv2d(args.in_channels, args.num_channels, 1),
41 | # eps=1e-4),
42 | # nn.AvgPool2d(2))
43 |
44 | # self.blocks = nn.ModuleList()
45 | # num_down_blocks = min(int(math.log(args.output_image_size, 2)) - 2, args.dis_num_blocks)
46 | # in_channels = args.num_channels
47 | # for i in range(1, num_down_blocks):
48 | # out_channels = min(in_channels * 2, args.max_num_channels)
49 | # if i == args.dis_num_blocks - 1: out_channels = self.out_channels
50 | # self.blocks.append(get_down_block(in_channels, out_channels, padding))
51 | # in_channels = out_channels
52 | # for i in range(num_down_blocks, args.dis_num_blocks):
53 | # if i == args.dis_num_blocks - 1: out_channels = self.out_channels
54 | # self.blocks.append(get_res_block(in_channels, out_channels, padding))
55 |
56 | # self.linear = spectral_norm(nn.Linear(self.out_channels, 1), eps=1e-4)
57 |
58 |
59 | # def forward(self, input):
60 |
61 | # feats = []
62 |
63 | # out = self.down_block(input)
64 | # out = out + self.skip(input)
65 | # feats.append(out)
66 | # for block in self.blocks:
67 | # out = block(out)
68 | # feats.append(out)
69 | # out = torch.relu(out)
70 | # out = out.view(out.shape[0], self.out_channels, -1).sum(2)
71 | # out_linear = self.linear(out)[:, 0]
72 | # return out_linear,feats
73 |
74 | from torch.nn.utils import spectral_norm
75 |
76 | # PyTorch implementation by vinesmsuic
77 | # Referenced from official tensorflow implementation: https://github.com/SystemErrorWang/White-box-Cartoonization/blob/master/train_code/network.py
78 | # slim.convolution2d uses constant padding (zeros).
79 | # Paper used spectral_norm
80 |
81 | class Block(nn.Module):
82 | def __init__(self, in_channels, out_channels, kernel_size, stride, padding,activate=True):
83 | super().__init__()
84 | self.sn_conv = spectral_norm(nn.Conv2d(
85 | in_channels,
86 | out_channels,
87 | kernel_size,
88 | stride,
89 | padding,
90 | padding_mode="zeros" # Author's code used slim.convolution2d, which is using SAME padding (zero padding in pytorch)
91 | ))
92 | self.activate = activate
93 | if self.activate:
94 | self.LReLU = nn.LeakyReLU(negative_slope=0.2, inplace=True)
95 |
96 | def forward(self, x):
97 | x = self.sn_conv(x)
98 | if self.activate:
99 | x = self.LReLU(x)
100 |
101 | return x
102 |
103 |
104 | class Discriminator(nn.Module):
105 | def __init__(self, in_channels=3, out_channels=1, features=[32, 64, 128,256]):
106 | super().__init__()
107 |
108 | self.model = nn.Sequential(
109 | #k3n32s2
110 | Block(in_channels, features[0], kernel_size=3, stride=2, padding=1),
111 | #k3n32s1
112 | Block(features[0], features[0], kernel_size=3, stride=1, padding=1),
113 |
114 | #k3n64s2
115 | Block(features[0], features[1], kernel_size=3, stride=2, padding=1),
116 | #k3n64s1
117 | Block(features[1], features[1], kernel_size=3, stride=1, padding=1),
118 |
119 | #k3n128s2
120 | Block(features[1], features[2], kernel_size=3, stride=2, padding=1),
121 | #k3n128s1
122 | Block(features[2], features[2], kernel_size=3, stride=1, padding=1),
123 |
124 | #k3n256s2
125 | Block(features[2], features[3], kernel_size=3, stride=2, padding=1),
126 | #k3n256s1
127 | Block(features[3], features[3], kernel_size=3, stride=1, padding=1),
128 |
129 |
130 | #k1n1s1
131 | Block(features[3], out_channels, kernel_size=1, stride=1, padding=0)
132 | )
133 |
134 | def forward(self, x):
135 | x = self.model(x)
136 |
137 | return x
138 |
--------------------------------------------------------------------------------
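The active Discriminator is a fully convolutional, spectrally normalized patch discriminator (the ResBlock-based discriminator above is kept only as a comment): four stride-2 stages shrink the input by 16x, so a 512x512 image yields a 32x32 map of patch logits rather than a single score. A quick shape check:

import torch
from model.AlignModule.discriminator import Discriminator

netD = Discriminator(in_channels=3, out_channels=1)
logits = netD(torch.randn(2, 3, 512, 512))
print(logits.shape)   # torch.Size([2, 1, 32, 32]) -- one logit per image patch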
/model/AlignModule/module.py:
--------------------------------------------------------------------------------
1 | from collections import namedtuple
2 | import torch
3 | from torch.nn import Conv2d, BatchNorm2d, PReLU, ReLU, Sigmoid, MaxPool2d, AdaptiveAvgPool2d, Sequential, Module, Dropout, Linear, BatchNorm1d
4 |
5 | """
6 | ArcFace implementation from [TreB1eN](https://github.com/TreB1eN/InsightFace_Pytorch)
7 | """
8 |
9 |
10 | class Backbone(Module):
11 | def __init__(self, input_size, num_layers, mode='ir', drop_ratio=0.4, affine=True):
12 | super(Backbone, self).__init__()
13 | assert input_size in [112, 224], "input_size should be 112 or 224"
14 | assert num_layers in [50, 100, 152], "num_layers should be 50, 100 or 152"
15 | assert mode in ['ir', 'ir_se'], "mode should be ir or ir_se"
16 | blocks = get_blocks(num_layers)
17 | if mode == 'ir':
18 | unit_module = bottleneck_IR
19 | elif mode == 'ir_se':
20 | unit_module = bottleneck_IR_SE
21 | self.input_layer = Sequential(Conv2d(3, 64, (3, 3), 1, 1, bias=False),
22 | BatchNorm2d(64),
23 | PReLU(64))
24 | if input_size == 112:
25 | self.output_layer = Sequential(BatchNorm2d(512),
26 | Dropout(drop_ratio),
27 | Flatten(),
28 | Linear(512 * 7 * 7, 512),
29 | BatchNorm1d(512, affine=affine))
30 | else:
31 | self.output_layer = Sequential(BatchNorm2d(512),
32 | Dropout(drop_ratio),
33 | Flatten(),
34 | Linear(512 * 14 * 14, 512),
35 | BatchNorm1d(512, affine=affine))
36 |
37 | modules = []
38 | for block in blocks:
39 | for bottleneck in block:
40 | modules.append(unit_module(bottleneck.in_channel,
41 | bottleneck.depth,
42 | bottleneck.stride))
43 | self.body = Sequential(*modules)
44 |
45 | def forward(self, x):
46 | x = self.input_layer(x)
47 | x = self.body(x)
48 | x = self.output_layer(x)
49 | return l2_norm(x)
50 |
51 |
52 |
53 |
54 |
55 | class Flatten(Module):
56 | def forward(self, input):
57 | return input.view(input.size(0), -1)
58 |
59 |
60 | def l2_norm(input, axis=1):
61 | norm = torch.norm(input, 2, axis, True)
62 | output = torch.div(input, norm)
63 | return output
64 |
65 |
66 | class Bottleneck(namedtuple('Block', ['in_channel', 'depth', 'stride'])):
67 | """ A named tuple describing a ResNet block. """
68 |
69 |
70 | def get_block(in_channel, depth, num_units, stride=2):
71 | return [Bottleneck(in_channel, depth, stride)] + [Bottleneck(depth, depth, 1) for i in range(num_units - 1)]
72 |
73 |
74 | def get_blocks(num_layers):
75 | if num_layers == 50:
76 | blocks = [
77 | get_block(in_channel=64, depth=64, num_units=3),
78 | get_block(in_channel=64, depth=128, num_units=4),
79 | get_block(in_channel=128, depth=256, num_units=14),
80 | get_block(in_channel=256, depth=512, num_units=3)
81 | ]
82 | elif num_layers == 100:
83 | blocks = [
84 | get_block(in_channel=64, depth=64, num_units=3),
85 | get_block(in_channel=64, depth=128, num_units=13),
86 | get_block(in_channel=128, depth=256, num_units=30),
87 | get_block(in_channel=256, depth=512, num_units=3)
88 | ]
89 | elif num_layers == 152:
90 | blocks = [
91 | get_block(in_channel=64, depth=64, num_units=3),
92 | get_block(in_channel=64, depth=128, num_units=8),
93 | get_block(in_channel=128, depth=256, num_units=36),
94 | get_block(in_channel=256, depth=512, num_units=3)
95 | ]
96 | else:
97 | raise ValueError("Invalid number of layers: {}. Must be one of [50, 100, 152]".format(num_layers))
98 | return blocks
99 |
100 |
101 | class SEModule(Module):
102 | def __init__(self, channels, reduction):
103 | super(SEModule, self).__init__()
104 | self.avg_pool = AdaptiveAvgPool2d(1)
105 | self.fc1 = Conv2d(channels, channels // reduction, kernel_size=1, padding=0, bias=False)
106 | self.relu = ReLU(inplace=True)
107 | self.fc2 = Conv2d(channels // reduction, channels, kernel_size=1, padding=0, bias=False)
108 | self.sigmoid = Sigmoid()
109 |
110 | def forward(self, x):
111 | module_input = x
112 | x = self.avg_pool(x)
113 | x = self.fc1(x)
114 | x = self.relu(x)
115 | x = self.fc2(x)
116 | x = self.sigmoid(x)
117 | return module_input * x
118 |
119 |
120 | class bottleneck_IR(Module):
121 | def __init__(self, in_channel, depth, stride):
122 | super(bottleneck_IR, self).__init__()
123 | if in_channel == depth:
124 | self.shortcut_layer = MaxPool2d(1, stride)
125 | else:
126 | self.shortcut_layer = Sequential(
127 | Conv2d(in_channel, depth, (1, 1), stride, bias=False),
128 | BatchNorm2d(depth)
129 | )
130 | self.res_layer = Sequential(
131 | BatchNorm2d(in_channel),
132 | Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False), PReLU(depth),
133 | Conv2d(depth, depth, (3, 3), stride, 1, bias=False), BatchNorm2d(depth)
134 | )
135 |
136 | def forward(self, x):
137 | shortcut = self.shortcut_layer(x)
138 | res = self.res_layer(x)
139 | return res + shortcut
140 |
141 |
142 | class bottleneck_IR_SE(Module):
143 | def __init__(self, in_channel, depth, stride):
144 | super(bottleneck_IR_SE, self).__init__()
145 | if in_channel == depth:
146 | self.shortcut_layer = MaxPool2d(1, stride)
147 | else:
148 | self.shortcut_layer = Sequential(
149 | Conv2d(in_channel, depth, (1, 1), stride, bias=False),
150 | BatchNorm2d(depth)
151 | )
152 | self.res_layer = Sequential(
153 | BatchNorm2d(in_channel),
154 | Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False),
155 | PReLU(depth),
156 | Conv2d(depth, depth, (3, 3), stride, 1, bias=False),
157 | BatchNorm2d(depth),
158 | SEModule(depth, 16)
159 | )
160 |
161 | def forward(self, x):
162 | shortcut = self.shortcut_layer(x)
163 | res = self.res_layer(x)
164 | return res + shortcut
165 |
--------------------------------------------------------------------------------
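This module is the ArcFace IR/IR-SE backbone, presumably the network behind the `Eid` identity encoder used by AlignTrainer below. A minimal sketch with random weights, just to show the shapes (no pretrained checkpoint is loaded here):

```python
# Sketch: the backbone maps a 112x112 face crop to a 512-d, L2-normalized embedding.
import torch
from model.AlignModule.module import Backbone

net = Backbone(input_size=112, num_layers=50, mode='ir_se')
net.eval()                                    # BatchNorm/Dropout in inference mode
with torch.no_grad():
    emb = net(torch.randn(1, 3, 112, 112))
print(emb.shape, float(emb.norm(dim=1)))      # torch.Size([1, 512]), ~1.0
```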
/ReadMe.md:
--------------------------------------------------------------------------------
1 | # HeSer.Pytorch
2 | An unofficial implementation of Few-Shot Head Swapping in the Wild.
3 | You can find the official version [here](https://github.com/jmliu88/HeSer).
4 | I did not use the discriminator from the paper; instead I followed [DCT-NET](https://github.com/LeslieZhoa/DCT-NET.Pytorch).
5 | 
6 | 
7 | ## !!!!NEWS!!!!!
8 | [HeadSwap](https://github.com/LeslieZhoa/HeadSwap) is back. You can use it in Colab. Try it now!!!
9 | ## environment
10 | - torch
11 | - opencv-python
12 | - tensorboardX
13 | - imgaug
14 | - face-alignment
15 | ```shell
16 | # download pretrained models
17 | cd process
18 | bash download_weight.sh
19 | ```
20 | ## How to RUN
21 | ### train
22 | I only trained on one ID for driving
23 | #### Data Process
24 | 1. download voxceleb2
25 | a. I only downloaded the [voxceleb2 test dataset](https://www.robots.ox.ac.uk/~vgg/data/voxceleb/vox2.html); you can get the txt lists from this [website](https://www.robots.ox.ac.uk/~vgg/data/voxceleb/data/vox2_test_txt.zip)
26 | b. Unzip the file into a structure like this:
27 | ```
28 | +--- dataset
29 | | +--- vox2_test_txt
30 | | | +--- txt
31 | | | | +--- id00017
32 | | | | | +--- 01dfn2spqyE
33 | | | | | | +--- 00001.txt
34 | | | | | +--- 5MkXgwdrmJw
35 | | | | | | +--- 00002.txt
36 | | | | | +--- 7t6lfzvVaTM
37 | | | | | | +--- 00003.txt
38 | | | | | | +--- 00004.txt
39 | | | | | | +--- 00005.txt
40 | | | | | | +--- 00006.txt
41 | | | | | | +--- 00007.txt
42 |
43 | ```
44 | c. Install yt-dlp and aria2c yourself; both are easy to find online.
45 | ```
46 | cd process
47 | python download_and_process.py
48 | ```
49 | d. the downloaded dataset looks like this:
50 | ```
51 | voceleb2/
52 | |-- id00017
53 | | |-- 01dfn2spqyE
54 | | | `-- 00.npy
55 | | |-- 5MkXgwdrmJw
56 | | | |-- 00.npy
57 | | | `-- 5MkXgwdrmJw.mp4
58 | | |-- 7t6lfzvVaTM
59 | | | |-- 00.npy
60 | | | |-- 01.npy
61 | | | |-- 02.npy
62 | | | |-- 03.npy
63 | | | |-- 04.npy
64 | | | |-- 05.npy
65 | | | |-- 06.npy
66 | | | |-- 07.npy
67 | | | |-- 08.npy
68 | | | |-- 09.npy
69 | | | `-- 7t6lfzvVaTM.mp4
70 | ```
71 | 2. crop and align
72 | ```
73 | cd process
74 | python process_raw_video.py
75 | ```
76 | the processed dataset looks like this:
77 | ```
78 | process/
79 | |-- img
80 | | |-- id00017
81 | | | |-- 5MkXgwdrmJw-0000
82 | | | | |-- 1273.png
83 | | | | |-- 1274.png
84 | | | | |-- 1275.png
85 | | | | |-- 1276.png
86 | | | | |-- 1277.png
87 | | | | |-- 1278.png
88 | | | | |-- 1279.png
89 | | | | |-- 1280.png
90 | | | | |-- 1281.png
91 | | | | |-- 1282.png
92 | | | | |-- 1283.png
93 | | | | |-- 1284.png
94 | | | | |-- 1285.png
95 | | | | |-- 1286.png
96 | | | | |-- 1287.png
97 | | | | |-- 1288.png
98 | | | | |-- 1289.png
99 | ```
100 | 3. Remove data below threshold
101 | ```
102 | cd process
103 | python filter_idfiles.py
104 | ```
105 | 4. face parsing
106 | follow [LVT](https://github.com/LeslieZhoa/LVT) to generate the face-parsing masks
107 | the mask data looks like this:
108 | ```
109 | process/mask/
110 | |-- id00017
111 | | |-- 5MkXgwdrmJw-0000
112 | | | |-- 1273.png
113 | | | |-- 1274.png
114 | | | |-- 1275.png
115 | | | |-- 1276.png
116 | | | |-- 1277.png
117 | | | |-- 1278.png
118 | | | |-- 1279.png
119 | | | |-- 1280.png
120 | ```
121 | #### Train Align
122 | I only use id00061 to train the align module
123 | check [model/AlignModule/config.py](model/AlignModule/config.py#L1) and set your own paths and params
124 | for single gpu
125 | ```
126 | python train.py --model align --batch_size 8 --checkpoint_path checkpoint --lr 2e-4 --print_interval 100 --save_interval 100 --dist
127 | ```
128 | for multi gpu
129 | ```
130 | python -m torch.distributed.launch train.py --model align --batch_size 8 --checkpoint_path checkpoint --lr 2e-4 --print_interval 100 --save_interval 100
131 | ```
132 | #### Train Blend
133 | check [model/BlendModule/config.py](model/BlendModule/config.py#L1) and set your own paths and params
134 | for single gpu
135 | ```
136 | python train.py --model blend --batch_size 8 --checkpoint_path checkpoint --lr 2e-4 --print_interval 100 --save_interval 100 --dist
137 | ```
138 | for multi gpu
139 | ```
140 | python -m torch.distributed.launch train.py --model blend --batch_size 8 --checkpoint_path checkpoint --lr 2e-4 --print_interval 100 --save_interval 100
141 | ```
142 | ## Inference
143 | follow inference.py and change the model paths and input images to your own
144 | ```shell
145 | python inference.py
146 | ```
147 | ## Show
148 | The result below is just overfitting (the model was trained on a single ID)
149 | 
150 |
151 | ## Credits
152 | DCT-NET model and implementation:
153 | https://github.com/LeslieZhoa/DCT-NET.Pytorch Copyright © 2022, LeslieZhoa.
154 | License https://github.com/LeslieZhoa/DCT-NET.Pytorch/blob/main/LICENSE
155 |
156 | latent-pose-reenactment model and implementation:
157 | https://github.com/shrubb/latent-pose-reenactment Copyright © 2020, shrubb.
158 | License https://github.com/shrubb/latent-pose-reenactment/blob/master/LICENSE.txt
159 |
160 | arcface pytorch model and implementation:
161 | https://github.com/ronghuaiyang/arcface-pytorch Copyright © 2018, ronghuaiyang.
162 |
163 | LVT model and implementation:
164 | https://github.com/LeslieZhoa/LVT Copyright © 2022, LeslieZhoa.
165 |
166 | face-parsing model and implementation:
167 | https://github.com/zllrunning/face-parsing.PyTorch Copyright © 2019, zllrunning.
168 | License https://github.com/zllrunning/face-parsing.PyTorch/blob/master/LICENSE
169 |
--------------------------------------------------------------------------------
/process/process_utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import cv2
3 |
4 | def crop_with_padding(image, lmks,scale=1.8,size=512,align=True):
5 |
6 | img_box = [np.min(lmks[:, 0]), np.min(lmks[:, 1]), np.max(lmks[:, 0]), np.max(lmks[:, 1])]
7 |
8 | center = ((img_box[0] + img_box[2]) / 2.0, (img_box[1] + img_box[3]) / 2.0)
9 |
10 | if align:
11 | lm_eye_left = lmks[36 : 42] # left-clockwise
12 | lm_eye_right = lmks[42 : 48] # left-clockwise
13 |
14 | eye_left = np.mean(lm_eye_left, axis=0)
15 | eye_right = np.mean(lm_eye_right, axis=0)
16 | angle = np.arctan2((eye_right[1] - eye_left[1]), (eye_right[0] - eye_left[0])) / np.pi * 180
17 |
18 | RotateMatrix = cv2.getRotationMatrix2D(center, angle, scale=1)
19 |
20 | rotated_img = cv2.warpAffine(image, RotateMatrix, (image.shape[1], image.shape[0]))
21 | rotated_lmks = apply_transform(RotateMatrix, lmks)
22 | else:
23 | rotated_img = image
24 | rotated_lmks = lmks
25 | RotateMatrix = np.array([[1,0,0],
26 | [0,1,0]])
27 |
28 | faceBox = [np.min(rotated_lmks[:, 0]), np.min(rotated_lmks[:, 1]),
29 | np.max(rotated_lmks[:, 0]), np.max(rotated_lmks[:, 1])]
30 |
31 | cx_box = (faceBox[0] + faceBox[2]) / 2.
32 | cy_box = (faceBox[1] + faceBox[3]) / 2.
33 | width = faceBox[2] - faceBox[0] + 1
34 | height = faceBox[3] - faceBox[1] + 1
35 | face_size = max(width, height)
36 | bbox_size = int(face_size * scale)
37 |
38 |
39 |
40 |
41 | x_min = int(cx_box-bbox_size / 2.)
42 | y_min = int(cy_box-bbox_size / 2.)
43 | x_max = x_min + bbox_size
44 | y_max = y_min + bbox_size
45 |
46 | boundingBox = [max(x_min, 0), max(y_min, 0), min(x_max, rotated_img.shape[1]), min(y_max, rotated_img.shape[0])]
47 | imgCropped = rotated_img[boundingBox[1]:boundingBox[3], boundingBox[0]:boundingBox[2]]
48 | imgCropped = cv2.copyMakeBorder(imgCropped, max(-y_min, 0), max(y_max - image.shape[0], 0), max(-x_min, 0),
49 | max(x_max - image.shape[1], 0),cv2.BORDER_CONSTANT,value=(0,0,0))
50 | boundingBox = [x_min, y_min, x_max, y_max]
51 |
52 | scale_h = size / float(bbox_size)
53 | scale_w = size / float(bbox_size)
54 | rotated_lmks[:, 0] = (rotated_lmks[:, 0] - boundingBox[0]) * scale_w
55 | rotated_lmks[:, 1] = (rotated_lmks[:, 1] - boundingBox[1]) * scale_h
56 | # print(imgCropped.shape)
57 | imgResize = cv2.resize(imgCropped, (size, size))
58 |
59 |     ### compute the transform (original image -> crop box)
60 | m1 = np.concatenate((RotateMatrix,np.array([[0.0,0.0,1.0]])), axis=0) #rotate(+translation)
61 | m2 = np.eye(3) #translation
62 | m2[0][2] = -boundingBox[0]
63 | m2[1][2] = -boundingBox[1]
64 | m3 = np.eye(3) #scaling
65 | m3[0][0] = m3[1][1] = scale_h
66 | m = np.matmul(np.matmul(m3,m2),m1)
67 | im = np.linalg.inv(m)
68 | info = {'rotated_lmk':rotated_lmks,
69 | 'm':m,
70 | 'im':im}
71 |
72 | return imgResize,info
73 |
74 |
75 | def apply_transform(transform_matrix, lmks):
76 | '''
77 | args
78 | transform_matrix: float (3,3)|(2,3)
79 | lmks: float (2)|(3)|(k,2)|(k,3)
80 |
81 | ret
82 | ret_lmks: float (2)|(3)|(k,2)|(k,3)
83 | '''
84 | if transform_matrix.shape[0] == 2:
85 | transform_matrix = np.concatenate((transform_matrix,np.array([[0.0,0.0,1.0]])), axis=0)
86 | only_one = False
87 | if len(lmks.shape) == 1:
88 | lmks = lmks[np.newaxis, :]
89 | only_one = True
90 | only_two_dim = False
91 | if lmks.shape[1] == 2:
92 | lmks = np.concatenate((lmks, np.ones((lmks.shape[0],1), dtype=np.float32)), axis=1)
93 | only_two_dim = True
94 |
95 | ret_lmks = np.matmul(transform_matrix, lmks.T).T
96 |
97 | if only_two_dim:
98 | ret_lmks = ret_lmks[:,:2]
99 | if only_one:
100 | ret_lmks = ret_lmks[0]
101 |
102 | return ret_lmks
103 |
104 |
105 | def choose_one_detection(frame_faces,box):
106 | """
107 |     frame_faces
108 |         list of `(l, t, r, b, confidence)` face detections from one image
109 |     box
110 |         reference box; the detection overlapping it most (by IoU) is returned
111 | return:
112 | list of 5 floats
113 | one of the input detections: `(l, t, r, b, confidence)`
114 | """
115 | frame_faces = list(filter(lambda x:x[-1]>0.9,frame_faces))
116 | if len(frame_faces) == 0:
117 | return None
118 |
119 | else:
120 |         # find the detection with the largest IoU against the reference box
121 | largest_area, largest_idx = -1, -1
122 | for idx, face in enumerate(frame_faces):
123 | area = compute_iou(box,face)
124 | # area = abs(face[2]-face[0]) * abs(face[1]-face[3])
125 | if area > largest_area:
126 | largest_area = area
127 | largest_idx = idx
128 |
129 | if largest_area < 0.1:
130 | return None
131 |
132 | retval = frame_faces[largest_idx]
133 |
134 |
135 | return np.array(retval).tolist()
136 |
137 |
138 | def compute_iou(rec1, rec2):
139 | """
140 | computing IoU
141 | :param rec1: (x0, y0, x1, y1), which reflects
142 | (top, left, bottom, right)
143 | :param rec2: (x0, y0, x1, y1)
144 | :return: scala value of IoU
145 | """
146 | # computing area of each rectangles
147 | S_rec1 = (rec1[2] - rec1[0]) * (rec1[3] - rec1[1])
148 | S_rec2 = (rec2[2] - rec2[0]) * (rec2[3] - rec2[1])
149 |
150 | # computing the sum_area
151 | sum_area = S_rec1 + S_rec2
152 |
153 | # find the each edge of intersect rectangle
154 | left_line = max(rec1[0], rec2[0])
155 | right_line = min(rec1[2], rec2[2])
156 | top_line = max(rec1[1], rec2[1])
157 | bottom_line = min(rec1[3], rec2[3])
158 |
159 | # judge if there is an intersect
160 | if left_line >= right_line or top_line >= bottom_line:
161 | return 0
162 | else:
163 | intersect = (right_line - left_line) * (bottom_line - top_line)
164 | return (intersect / (sum_area - intersect))*1.0
165 | # return intersect / S_rec2
--------------------------------------------------------------------------------
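Two of the helpers above are easiest to follow with concrete numbers: `apply_transform` applies a 2x3 or 3x3 affine matrix to landmark arrays (and its inverse undoes it, as done in `crop_with_padding`), while `compute_iou` scores two boxes. A small worked example; run it from inside `process/` so the module imports directly, and note all the values are made up:

```python
# Worked example for apply_transform and compute_iou.
import numpy as np
import cv2
from process_utils import apply_transform, compute_iou

lmks = np.array([[100.0, 120.0], [140.0, 118.0]])               # two (x, y) landmarks
M = cv2.getRotationMatrix2D((128, 128), 10, 1.0)                # 2x3 rotation about (128, 128)

rot = apply_transform(M, lmks)                                  # rotated landmarks, shape (2, 2)
M3 = np.concatenate((M, np.array([[0.0, 0.0, 1.0]])), axis=0)   # lift to 3x3, as in crop_with_padding
back = apply_transform(np.linalg.inv(M3), rot)                  # ~= lmks again (up to float error)

# IoU of [0,0,10,10] and [5,5,15,15]: intersection 25, union 175 -> ~0.143
print(compute_iou([0, 0, 10, 10], [5, 5, 15, 15]))
```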
/dataloader/augmentation.py:
--------------------------------------------------------------------------------
1 | import imgaug.augmenters as iaa
2 | import torch
3 | import numpy as np
4 |
5 | from contextlib import contextmanager
6 |
7 | # heavily copy from https://github.com/shrubb/latent-pose-reenactment
8 | class ParametricAugmenter:
9 | def is_empty(self):
10 | return not self.seq and not self.shift_seq
11 |
12 | def __init__(self, use_pixelwise_augs,use_affine_scale,use_affine_shift):
13 |
14 |
15 | sometimes = lambda aug: iaa.Sometimes(0.5, aug)
16 | total_augs = []
17 |
18 | if use_pixelwise_augs:
19 | pixelwise_augs = [
20 | iaa.SomeOf((0, 5),
21 | [
22 | # sometimes(iaa.Superpixels(p_replace=(0, 0.25), n_segments=(150, 200))),
23 | iaa.OneOf([
24 |                                 iaa.GaussianBlur((0, 1.0)),  # blur images with a sigma between 0 and 1.0
25 | iaa.AverageBlur(k=(1, 3)),
26 |                                 # blur image using local means with kernel sizes between 1 and 3
27 | iaa.MedianBlur(k=(1, 3)),
28 |                                 # blur image using local medians with kernel sizes between 1 and 3
29 | ]),
30 | iaa.Sharpen(alpha=(0, 1.0), lightness=(1.0, 1.5)), # sharpen images
31 | iaa.Emboss(alpha=(0, 1.0), strength=(0, 0.5)), # emboss images
32 | # search either for all edges or for directed edges,
33 | # blend the result with the original image using a blobby mask
34 | iaa.BlendAlphaSimplexNoise(
35 | iaa.EdgeDetect(alpha=(0.0, 0.15)),
36 | ),
37 | iaa.AdditiveGaussianNoise(loc=0, scale=(0.0, 0.05 * 255), per_channel=False),
38 | # add gaussian noise to images
39 | iaa.Add((-10, 10), per_channel=0.5),
40 | # change brightness of images (by -10 to 10 of original value)
41 | iaa.AddToSaturation((-20, 20)), # change hue and saturation
42 | iaa.JpegCompression((70, 99)),
43 |
44 | iaa.Multiply((0.5, 1.5), per_channel=False),
45 |
46 | iaa.OneOf([
47 | iaa.LinearContrast((0.75, 1.25), per_channel=False),
48 | iaa.SigmoidContrast(cutoff=0.5, gain=(3.0, 11.0))
49 | ]),
50 | sometimes(iaa.ElasticTransformation(alpha=(0.5, 3.5), sigma=0.15)),
51 | # move pixels locally around (with random strengths)
52 | ],
53 | random_order=True
54 | )
55 | ]
56 | total_augs.extend(pixelwise_augs)
57 | affine_augs_scale = []
58 | if use_affine_scale:
59 | affine_augs_scale = [sometimes(iaa.Affine(
60 | scale={"x": (0.8, 1.2), "y": (0.8, 1.2)},
61 | # scale images to 80-120% of their size, individually per axis
62 | order=[1], # use bilinear interpolation (fast)
63 | mode=["reflect"]
64 | ))]
65 | # total_augs.extend(affine_augs_scale)
66 |
67 | if use_affine_shift:
68 | affine_augs_shift = [sometimes(iaa.Affine(
69 | translate_percent={"x": (-0.05, 0.05), "y": (-0.05, 0.05)},
70 | order=[1], # use bilinear interpolation (fast)
71 | mode=["reflect"]
72 | ))]
73 | else:
74 | affine_augs_shift = []
75 |
76 | self.shift_seq = iaa.Sequential(affine_augs_shift)
77 | self.seq = iaa.Sequential(total_augs, random_order=True)
78 | self.scale_seq = iaa.Sequential(affine_augs_scale)
79 |
80 | def tensor2image(self, image, norm = 255.0):
81 | return (np.expand_dims(image.squeeze().permute(1, 2, 0).numpy(), 0) * norm)
82 |
83 | def image2tensor(self, image, norm = 255.0):
84 | image = image.astype(np.float32) / norm
85 | image = torch.tensor(np.squeeze(image)).permute(2, 0, 1)
86 | return image
87 |
88 |
89 | def augment_tensor(self, image):
90 | if self.seq or self.shift_seq:
91 | image = self.tensor2image(image).astype(np.uint8)
92 | image = self.seq(images=image)
93 | image = self.shift_seq(images=image,)
94 | image = self.image2tensor(image)
95 |
96 | return image
97 |
98 | def augment_triple(self, image1, image2,mask):
99 | if self.seq or self.shift_seq:
100 | image1 = self.tensor2image(image1).astype(np.uint8)
101 |
102 | image1 = self.seq(images=image1,)
103 | if self.scale_seq:
104 |
105 | scale_seq_deterministic = self.scale_seq.to_deterministic()
106 | image1 = scale_seq_deterministic(images=image1)
107 | mask = scale_seq_deterministic(images=mask)
108 | if self.shift_seq:
109 | image2 = self.tensor2image(image2).astype(np.uint8)
110 | shift_seq_deterministic = self.shift_seq.to_deterministic()
111 | image1 = shift_seq_deterministic(images=image1,)
112 | image2 = shift_seq_deterministic(images=image2)
113 | mask = shift_seq_deterministic(images=mask)
114 |
115 | image2 = self.image2tensor(image2)
116 |
117 | image1 = self.image2tensor(image1)
118 | mask = self.image2tensor(mask)
119 |
120 | return image1, image2,mask
121 |
122 | @contextmanager
123 | def deterministic_(self, seed):
124 | """
125 | A context manager to pre-define the random state of all augmentations.
126 |
127 | seed:
128 | `int`
129 | """
130 | # Backup the random states
131 | old_seq = self.seq.deepcopy()
132 | old_shift_seq = self.shift_seq.deepcopy()
133 | self.seq.seed_(seed)
134 | self.shift_seq.seed_(seed)
135 | yield
136 | # Restore the backed up random states
137 | self.seq = old_seq
138 | self.shift_seq = old_shift_seq
139 |
--------------------------------------------------------------------------------
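`ParametricAugmenter` works on single image tensors in `[0, 1]`: it converts them to a uint8 HWC array, runs the imgaug pipelines, and converts back. A minimal usage sketch; the shapes and flag choices are illustrative:

```python
# Sketch: augment one image tensor with pixelwise + shift augmentations.
import torch
from dataloader.augmentation import ParametricAugmenter

aug = ParametricAugmenter(use_pixelwise_augs=True,
                          use_affine_scale=False,
                          use_affine_shift=True)

img = torch.rand(1, 3, 256, 256)      # one RGB image, values in [0, 1]
out = aug.augment_tensor(img)         # returns a (3, 256, 256) tensor, still in [0, 1]
print(out.shape)
```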
/trainer/ModelTrainer.py:
--------------------------------------------------------------------------------
1 | #! /usr/bin/python
2 | # -*- encoding: utf-8 -*-
3 | '''
4 | @author Leslie
5 | @date 20220812
6 | '''
7 | import torch
8 | import math
9 | import time,os
10 |
11 | from utils.visualizer import Visualizer
12 | from torch.nn.parallel import DistributedDataParallel as DDP
13 | import torch.distributed as dist
14 | import subprocess
15 | from utils.utils import convert_img
16 | class ModelTrainer:
17 |
18 | def __init__(self,args):
19 |
20 | self.args = args
21 | self.batch_size = args.batch_size
22 | self.old_lr = args.lr
23 | if args.rank == 0 :
24 | self.vis = Visualizer(args)
25 |
26 | if args.eval:
27 | self.val_vis = Visualizer(args,"val")
28 |
29 | # ## ===== ===== ===== ===== =====
30 | # ## Train network
31 | # ## ===== ===== ===== ===== =====
32 |
33 | def train_network(self,train_loader,test_loader):
34 |
35 | counter = 0
36 | loss_dict = {}
37 | acc_num = 0
38 | mn_loss = float('inf')
39 |
40 | steps = 0
41 | begin_it = 0
42 | if self.args.pretrain_path:
43 | begin_it = int(self.args.pretrain_path.split('/')[-1].split('-')[0])
44 | steps = (begin_it+1) * math.ceil(self.args.mx_data_length/self.args.batch_size)
45 |
46 |         print("current steps: %d | dataset length: %d "%(steps,self.args.mx_data_length))
47 |
48 | for epoch in range(begin_it+1,self.args.max_epoch):
49 |
50 | for ii,(data) in enumerate(train_loader):
51 |
52 | tstart = time.time()
53 |
54 | self.run_single_step(data,steps)
55 | losses = self.get_latest_losses()
56 |
57 | for key,val in losses.items():
58 | loss_dict[key] = loss_dict.get(key,0) + val.mean().item()
59 |
60 | counter += 1
61 | steps += 1
62 |
63 | telapsed = time.time() - tstart
64 |
65 |
66 | if ii % self.args.print_interval == 0 and self.args.rank == 0:
67 |
68 | for key,val in loss_dict.items():
69 | loss_dict[key] /= counter
70 | lr_rate = self.get_lr()
71 | print_dict = {**{"time":telapsed,"lr":lr_rate},
72 | **loss_dict}
73 | self.vis.print_current_errors(epoch,ii,print_dict,telapsed)
74 |
75 | self.vis.plot_current_errors(print_dict,steps)
76 |
77 | loss_dict = {}
78 | counter = 0
79 |
80 | # torch.cuda.empty_cache()
81 | if self.args.save_interval != 0 and ii % self.args.save_interval == 0 and \
82 | self.args.rank == 0:
83 | self.saveParameters(os.path.join(self.args.checkpoint_path,"%03d-%08d.pth"%(epoch,ii)))
84 |
85 | display_data = self.select_img(self.get_latest_generated())
86 |
87 | self.vis.display_current_results(display_data,steps)
88 |
89 |
90 |
91 | if self.args.eval and self.args.test_interval > 0 and steps % self.args.test_interval == 0:
92 | val_loss = self.evalution(test_loader,steps,epoch)
93 |
94 | if self.args.early_stop:
95 |
96 | acc_num,mn_loss,stop_flag = self.early_stop_wait(self.get_loss_from_val(val_loss),acc_num,mn_loss,epoch)
97 | if stop_flag:
98 | return
99 |
100 | # print('******************memory:',psutil.virtual_memory()[3])
101 |
102 | if self.args.rank == 0 :
103 | self.saveParameters(os.path.join(self.args.checkpoint_path,"%03d-%08d.pth"%(epoch,0)))
104 |
105 |             # validate and keep the best model
106 | if test_loader or self.args.eval:
107 | val_loss = self.evalution(test_loader,steps,epoch)
108 |
109 | if self.args.early_stop:
110 |
111 | acc_num,mn_loss,stop_flag = self.early_stop_wait(self.get_loss_from_val(val_loss),acc_num,mn_loss,epoch)
112 | if stop_flag:
113 | return
114 |
115 |
116 | if self.args.rank == 0 :
117 | self.vis.close()
118 |
119 |
120 |
121 | def early_stop_wait(self,loss,acc_num,mn_loss,epoch):
122 |
123 | if self.args.rank == 0:
124 | if loss < mn_loss:
125 | mn_loss = loss
126 | cmd_one = 'cp -r %s %s'%(os.path.join(self.args.checkpoint_path,"%03d-%08d.pth"%(epoch,0)),
127 | os.path.join(self.args.checkpoint_path,'final.pth'))
128 | done_one = subprocess.Popen(cmd_one,stdout=subprocess.PIPE,shell=True)
129 | done_one.wait()
130 | acc_num = 0
131 | else:
132 | acc_num += 1
133 |         # multi-node / multi-GPU: if any rank hits the stop condition, stop all ranks together via all_reduce
134 | if self.args.dist:
135 |
136 | if acc_num > self.args.stop_interval:
137 | signal = torch.tensor([0]).cuda()
138 | else:
139 | signal = torch.tensor([1]).cuda()
140 | else:
141 | if self.args.dist:
142 | signal = torch.tensor([1]).cuda()
143 |
144 | if self.args.dist:
145 | dist.all_reduce(signal)
146 | value = signal.item()
147 | if value >= int(os.environ.get("WORLD_SIZE","1")):
148 | dist.all_reduce(torch.tensor([0]).cuda())
149 | return acc_num,mn_loss,False
150 | else:
151 | return acc_num,mn_loss,True
152 |
153 | else:
154 | if acc_num > self.args.stop_interval:
155 | return acc_num,mn_loss,True
156 | else:
157 | return acc_num,mn_loss,False
158 |
159 | def run_single_step(self,data,steps):
160 | data = self.process_input(data)
161 | self.run_discriminator_one_step(data,steps)
162 | self.run_generator_one_step(data,steps)
163 |
164 | def select_img(self,data,name='fake',axis=2):
165 | if data is None:
166 | return None
167 | cat_img = []
168 | for v in data:
169 | cat_img.append(v.detach().cpu())
170 |
171 | cat_img = torch.cat(cat_img,-1)
172 | cat_img = torch.cat(torch.split(cat_img,1,dim=0),axis)[0]
173 |
174 | return {name:convert_img(cat_img)}
175 |
176 |
177 |
178 | ##################################################################
179 | # Helper functions
180 | ##################################################################
181 |
182 | def get_loss_from_val(self,loss):
183 | return loss
184 |
185 | def get_show_inp(self,data):
186 | if not isinstance(data,list):
187 | return [data]
188 | return data
189 |
190 | def use_ddp(self,model):
191 | if model is None:
192 | return None
193 |         model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) # convert BatchNorm layers to SyncBatchNorm for DDP
194 |         # model = DDP(model,broadcast_buffers=False,find_unused_parameters=True) # find_unused_parameters: in GAN training some generator/discriminator parameters may not receive gradients in a step, so this flag is required
195 | model = DDP(model,
196 | broadcast_buffers=False,
197 | find_unused_parameters=True
198 | )
199 |         model_on_one_gpu = model.module # to call methods defined on the model under DDP, use this unwrapped module instead of the DDP wrapper
200 | return model,model_on_one_gpu
201 | def process_input(self,data):
202 |
203 | if torch.cuda.is_available():
204 | if isinstance(data,list):
205 | data = [x.cuda() for x in data]
206 | else:
207 | data = data.cuda()
208 | return data
--------------------------------------------------------------------------------
/model/BlendModule/generator.py:
--------------------------------------------------------------------------------
1 | '''
2 | @author Leslie
3 | @date 20220823
4 | '''
5 | import torch
6 | from torch import nn
7 | from model.BlendModule.module import VGG19_pytorch,Decoder
8 | import torch.nn.functional as F
9 | import pdb
10 |
11 | class Generator(nn.Module):
12 | def __init__(self,args):
13 | super().__init__()
14 | self.feature_ext = VGG19_pytorch()
15 | self.decoder = Decoder(ic=args.decoder_ic)
16 | self.dilate = nn.MaxPool2d(kernel_size=args.dilate_kernel,
17 | stride=1,
18 | padding=args.dilate_kernel//2)
19 |
20 | self.phi = nn.Conv2d(in_channels=args.f_in_channels,
21 | out_channels=args.f_inter_channels, kernel_size=1, stride=1, padding=0)
22 | self.theta = nn.Conv2d(in_channels=args.f_in_channels,
23 | out_channels=args.f_inter_channels, kernel_size=1, stride=1, padding=0)
24 |
25 | self.temperature = args.temperature
26 |
27 | self.head_index = [1,2,3,4,5,6,7,8,9,10,11,12,13,17,18]
28 | self.eps = 1e-8
29 |
30 | def forward(self,I_a,I_gray,I_t,M_a,M_t,gt=None,cycle=False,train=False):
31 |
32 | fA = self.feature_ext(I_a)
33 | fT = self.feature_ext(I_t)
34 |
35 | fA = self.phi(fA)
36 | fT = self.theta(fT)
37 |
38 | gen_h,gen_i,mask_list,matrix_list = self.RCNet(fA,fT,M_a,M_t,I_t)
39 |
40 | gen_h = F.interpolate(gen_h,size=I_t.shape[-2:],mode='bilinear',align_corners=True)
41 | gen_i = F.interpolate(gen_i,size=I_t.shape[-2:],mode='bilinear',align_corners=True)
42 | M_Ah,M_Ad,M_Td,M_Ai,M_Ti,M_Ar,M_Tr = mask_list
43 |
44 | if cycle:
45 |
46 | cycle_gen = self.RCCycle(gen_h+gen_i,[M_Ar,M_Tr,M_Ai,M_Ti],matrix_list,fA.shape)
47 |
48 | I_td = I_t * M_Td
49 | I_td = F.interpolate(I_td, size=cycle_gen.shape[-2:],mode='bilinear')
50 | return cycle_gen,I_td
51 |
52 | I_tb = gt * (1-M_Ad)
53 | I_ag = I_gray * M_Ah
54 |
55 | # cat_img = torch.cat([gen_h,gen_i,M_Ah.repeat(1,3,1,1),I_tb,M_Ai.repeat(1,3,1,1),I_ag.repeat(1,3,1,1)],-1)
56 | # import cv2
57 | # cv2.imwrite('1.png',(cat_img[0].permute(1,2,0).detach().cpu().numpy()[...,::-1]+1)*127.5)
58 | # pdb.set_trace()
59 | inp = torch.cat([gen_h,gen_i,
60 | M_Ah,
61 | I_tb,M_Ai,I_ag],1)
62 |
63 | oup = self.decoder(inp)
64 |
65 | if train:
66 | return oup,M_Ah,M_Ai
67 |
68 | return oup
69 |
70 |
71 | def RCNet(self,fA,fT,M_a,M_t,I_t):
72 |
73 | M_Ah = self.get_mask(M_a,self.head_index)
74 | M_Th = self.get_mask(M_t,self.head_index)
75 |
76 | M_Ti,M_Td = self.get_inpainting(M_Th)
77 | M_Ai,M_Ad = self.get_inpainting(M_Ah+M_Th,M_Ah)
78 | M_Ar = self.get_multi_mask(M_a)
79 | M_Tr = self.get_multi_mask(M_t)
80 |
81 | matrix_list = []
82 | gen_h = None
83 | for m_a,m_t in zip(M_Ar,M_Tr):
84 | gen_h, matrix = self.compute_corre(fA,fT,m_a,m_t,I_t,gen_h)
85 | matrix_list.append(matrix)
86 |
87 | gen_i = None
88 | gen_i,matrix = self.compute_corre(fA,fT,M_Ai,M_Ti,I_t,gen_i)
89 | matrix_list.append(matrix)
90 |
91 | return gen_h,gen_i,[M_Ah,M_Ad,M_Td,M_Ai,M_Ti,M_Ar,M_Tr],matrix_list
92 |
93 | def RCCycle(self,I_t,mask_list,matrix_list,shape):
94 | M_Ar,M_Tr,M_Ai,M_Ti = mask_list
95 | batch,channel,h,w = shape
96 | gen_h = torch.zeros((batch,3,h,w)).to(I_t.device)
97 | gen_i = torch.zeros((batch,3,h,w)).to(I_t.device)
98 | I_t_resize = F.interpolate(I_t, size=(h,w),mode='bilinear',align_corners=True)
99 |
100 | M_Tr_resize = [F.interpolate(f, size=(h,w),mode='nearest') for f in M_Tr]
101 | M_Ar_resize = [F.interpolate(f, size=(h,w),mode='nearest') for f in M_Ar]
102 | M_Ti_resize = F.interpolate(M_Ti, size=(h,w),mode='nearest')
103 | M_Ai_resize = F.interpolate(M_Ai, size=(h,w),mode='nearest')
104 | for matrix,m_t,m_a in zip(matrix_list[:-1],M_Tr_resize,M_Ar_resize):
105 | for i in range(batch):
106 | f_WTA = matrix[i]
107 | f = F.softmax(f_WTA.transpose(1,2),dim=-1)
108 |
109 | ref = torch.matmul(
110 | I_t_resize[i].unsqueeze(0).masked_select(
111 | m_a[i].unsqueeze(0)==1).view(1,3,-1),f.transpose(1,2))
112 | gen_h[i] = gen_h[i].unsqueeze(0).masked_scatter(
113 | m_t[i].unsqueeze(0)==1,ref).squeeze(0)
114 |
115 | for i in range(batch):
116 |
117 | f_WTA = matrix_list[-1][i]
118 | f = F.softmax(f_WTA.transpose(1,2),dim=-1)
119 |
120 | ref = torch.matmul(
121 | I_t_resize[i].unsqueeze(0).masked_select(
122 | M_Ai_resize[i].unsqueeze(0)==1).view(1,3,-1),f.transpose(1,2))
123 | gen_i[i] = gen_i[i].unsqueeze(0).masked_scatter(
124 | M_Ti_resize[i].unsqueeze(0)==1,ref).squeeze(0)
125 |
126 | return gen_h + gen_i
127 |
128 |
129 | def compute_corre(self,fA,fT,M_A,M_T,I_t,gen=None):
130 | batch,channel,h,w = fA.shape
131 | matrix_list = []
132 | if gen is None:
133 | gen = torch.zeros((batch,3,h,w)).to(I_t.device)
134 |
135 | M_A_resize = F.interpolate(M_A, size=(h,w),mode='nearest')
136 | M_T_resize = F.interpolate(M_T, size=(h,w),mode='nearest')
137 | I_t_resize = F.interpolate(I_t, size=(h,w),mode='bilinear', align_corners=True)
138 | for i in range(batch):
139 | fAr = fA[i].unsqueeze(0).masked_select(
140 | M_A_resize[i].unsqueeze(0)==1).view(1,channel,-1) # b,c,hA
141 |
142 | fTr = fT[i].unsqueeze(0).masked_select(
143 | M_T_resize[i].unsqueeze(0)==1).view(1,channel,-1) # b,c,hT
144 |
145 | fAr = self.normlize(fAr)
146 | fTr = self.normlize(fTr)
147 |
148 | matrix = torch.matmul(fAr.permute(0,2,1),fTr) # b,hA,hT
149 | f_WTA = matrix/self.temperature
150 | f = F.softmax(f_WTA,dim=-1)
151 | matrix_list.append(f_WTA)
152 |
153 |
154 | ref = torch.matmul(
155 | I_t_resize[i].unsqueeze(0).masked_select(
156 | M_T_resize[i].unsqueeze(0)==1).view(1,3,-1),f.transpose(1,2)) # [b,channel,hT] X [b,hT,hA]
157 |
158 |
159 | gen[i] = gen[i].unsqueeze(0).masked_scatter(
160 | M_A_resize[i].unsqueeze(0)==1,ref).squeeze(0)
161 |
162 |
163 | return gen,matrix_list
164 |
165 | def get_inpainting(self,M,head=None):
166 | M = torch.clamp(M,0,1)
167 | M_dilate = self.dilate(M)
168 | if head is None:
169 | MI = M_dilate - M
170 | else:
171 | MI = M_dilate - head
172 | return MI,M_dilate
173 |
174 | def get_multi_mask(self,M_a):
175 | # skin
176 | skin_mask_A = self.get_mask(M_a,[1])
177 | # hair
178 | hair_mask_A = self.get_mask(M_a,[17,18])
179 |
180 | # eye
181 | eye_mask_A = self.get_mask(M_a,[4,5,6])
182 |
183 | # brow
184 | brow_mask_A = self.get_mask(M_a,[2,3])
185 |
186 | # ear
187 | ear_mask_A = self.get_mask(M_a,[7,8,9])
188 |
189 | #nose
190 | nose_mask_A = self.get_mask(M_a,[10])
191 |
192 | # lip
193 | lip_mask_A = self.get_mask(M_a,[12,13])
194 |
195 |
196 | # tooth
197 | tooth_mask_A = self.get_mask(M_a,[11])
198 |
199 | return [skin_mask_A,hair_mask_A,eye_mask_A,brow_mask_A,ear_mask_A,nose_mask_A,lip_mask_A,tooth_mask_A]
200 |
201 | def get_mask(self,mask,indexs):
202 | out = torch.zeros_like(mask)
203 | for i in indexs:
204 | out[mask==i] = 1
205 |
206 | return out
207 |
208 | def normlize(self,x):
209 | x_mean = x.mean(dim=1,keepdim=True)
210 | x_norm = torch.norm(x,2,1,keepdim=True) + self.eps
211 | return (x-x_mean) / x_norm
212 |
--------------------------------------------------------------------------------
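The core of `RCNet`/`compute_corre` above is a per-region attention: source and target features inside matching parsing masks are mean-centered and L2-normalized, their correlation matrix is divided by `temperature` and softmaxed over target positions, and the target colors are then warped into the source layout. A toy, mask-free version of that single step (illustrative sizes only, not the author's exact code path):

```python
# Toy version of the correspondence/warping step used per region in compute_corre.
import torch

def normalize_feat(x, eps=1e-8):                  # same idea as Generator.normlize
    return (x - x.mean(dim=1, keepdim=True)) / (torch.norm(x, 2, 1, keepdim=True) + eps)

c, hA, hT, temperature = 8, 5, 7, 0.01
fA = normalize_feat(torch.randn(1, c, hA))        # source features inside one region mask
fT = normalize_feat(torch.randn(1, c, hT))        # target features inside the same region
colors_T = torch.rand(1, 3, hT)                   # target RGB values at those positions

attn = torch.softmax(torch.matmul(fA.permute(0, 2, 1), fT) / temperature, dim=-1)  # (1, hA, hT)
warped = torch.matmul(colors_T, attn.transpose(1, 2))                              # (1, 3, hA)
print(attn.shape, warped.shape)
```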
/trainer/BlendTrainer.py:
--------------------------------------------------------------------------------
1 | #! /usr/bin/python
2 | # -*- encoding: utf-8 -*-
3 | '''
4 | @author Leslie
5 | @date 20220812
6 | '''
7 | import torch
8 |
9 | from trainer.ModelTrainer import ModelTrainer
10 | from model.BlendModule.generator import Generator
11 | from model.AlignModule.discriminator import Discriminator
12 | from utils.utils import *
13 | from model.AlignModule.loss import *
14 | import torch.nn.functional as F
15 | import random
16 | import torch.distributed as dist
17 |
18 | class BlendTrainer(ModelTrainer):
19 |
20 | def __init__(self, args):
21 | super().__init__(args)
22 | self.device = 'cpu'
23 | if torch.cuda.is_available():
24 | self.device = 'cuda'
25 |
26 | self.netG = Generator(args).to(self.device)
27 |
28 | self.netD = Discriminator(in_channels=5).to(self.device)
29 |
30 |
31 | self.optimG,self.optimD = self.create_optimizer()
32 |
33 | if args.pretrain_path is not None:
34 | self.loadParameters(args.pretrain_path)
35 |
36 | if args.dist:
37 | self.netG,self.netG_module = self.use_ddp(self.netG)
38 | self.netD,self.netD_module = self.use_ddp(self.netD)
39 | else:
40 | self.netG_module = self.netG
41 | self.netD_module = self.netD
42 |
43 | if self.args.per_loss:
44 | self.perLoss = PerceptualLoss(args.per_model).to(self.device)
45 | self.perLoss.eval()
46 |
47 | if self.args.rec_loss:
48 | self.L1Loss = torch.nn.L1Loss()
49 |
50 |
51 | def create_optimizer(self):
52 | g_optim = torch.optim.Adam(
53 | self.netG.parameters(),
54 | lr=self.args.g_lr,
55 | betas=(self.args.beta1, self.args.beta2),
56 | )
57 | d_optim = torch.optim.Adam(
58 | self.netD.parameters(),
59 | lr=self.args.d_lr,
60 | betas=(self.args.beta1, self.args.beta2),
61 | )
62 |
63 | return g_optim,d_optim
64 |
65 |
66 | def run_single_step(self, data, steps):
67 | self.netG.train()
68 | super().run_single_step(data, steps)
69 |
70 |
71 | def run_discriminator_one_step(self, data,step):
72 |
73 | D_losses = {}
74 | requires_grad(self.netG, False)
75 | requires_grad(self.netD, True)
76 |
77 | I_a,I_gray,I_t,hat_t,M_a,M_t,M_hat,gt = data
78 | fake,M_Ah,M_Ai = self.netG(I_a,I_gray,I_t,M_a,M_t,gt,train=True)
79 | fake_pred = self.netD(torch.cat([fake,M_Ah,M_Ai],1))
80 | real_pred = self.netD(torch.cat([gt,M_Ah,M_Ai],1))
81 | d_loss = compute_dis_loss(fake_pred, real_pred,D_losses)
82 | D_losses['d'] = d_loss
83 |
84 | self.optimD.zero_grad()
85 | d_loss.backward()
86 | self.optimD.step()
87 |
88 | self.d_losses = D_losses
89 |
90 |
91 | def run_generator_one_step(self, data,step):
92 |
93 |
94 | requires_grad(self.netG, True)
95 | requires_grad(self.netD, False)
96 |
97 | I_a,I_gray,I_t,hat_t,M_a,M_t,M_hat,gt = data
98 | G_losses,loss,xg = self.compute_g_loss(I_a,I_gray,I_t,M_a,M_t,gt)
99 | self.optimG.zero_grad()
100 | loss.mean().backward()
101 | self.optimG.step()
102 |
103 | g_losses,loss,fake_nopair,label_nopair = self.compute_cycle_g_loss(I_a,I_gray,I_t,hat_t,M_a,M_t,M_hat)
104 | self.optimG.zero_grad()
105 | loss.mean().backward()
106 | self.optimG.step()
107 |
108 | self.g_losses = {**G_losses,**g_losses}
109 |
110 | self.generator = [I_a.detach(),fake_nopair.detach(),
111 | label_nopair.detach(),xg.detach(),gt.detach()]
112 |
113 |
114 | def evalution(self,test_loader,steps,epoch):
115 |
116 | loss_dict = {}
117 | index = random.randint(0,len(test_loader)-1)
118 | counter = 0
119 | with torch.no_grad():
120 | for i,data in enumerate(test_loader):
121 |
122 | data = self.process_input(data)
123 | I_a,I_gray,I_t,hat_t,M_a,M_t,M_hat,gt = data
124 | G_losses,losses,xg = self.compute_g_loss(I_a,I_gray,I_t,M_a,M_t,gt)
125 | for k,v in G_losses.items():
126 | loss_dict[k] = loss_dict.get(k,0) + v.detach()
127 | if i == index and self.args.rank == 0 :
128 |
129 | show_data = [I_a,xg,gt]
130 | self.val_vis.display_current_results(self.select_img(show_data),steps)
131 | counter += 1
132 |
133 |
134 | for key,val in loss_dict.items():
135 | loss_dict[key] /= counter
136 |
137 | if self.args.dist:
138 | # if self.args.rank == 0 :
139 | dist_losses = loss_dict.copy()
140 | for key,val in loss_dict.items():
141 |
142 | dist.reduce(dist_losses[key],0)
143 | value = dist_losses[key].item()
144 | loss_dict[key] = value / self.args.world_size
145 |
146 | if self.args.rank == 0 :
147 | self.val_vis.plot_current_errors(loss_dict,steps)
148 | self.val_vis.print_current_errors(epoch+1,0,loss_dict,0)
149 |
150 | return loss_dict
151 |
152 |
153 | def compute_g_loss(self,I_a,I_gray,I_t,M_a,M_t,gt):
154 | G_losses = {}
155 | loss = 0
156 | fake,M_Ah,M_Ai = self.netG(I_a,I_gray,I_t,M_a,M_t,gt,train=True)
157 | fake_pred = self.netD(torch.cat([fake,M_Ah,M_Ai],1))
158 | gan_loss = compute_gan_loss(fake_pred) * self.args.lambda_gan
159 | G_losses['g_losses'] = gan_loss
160 | loss += gan_loss
161 |
162 | if self.args.rec_loss:
163 | rec_loss = self.L1Loss(fake,gt) * self.args.lambda_rec
164 | G_losses['rec_loss'] = rec_loss
165 | loss += rec_loss
166 |
167 |
168 | if self.args.per_loss:
169 | per_loss = self.perLoss(fake,gt) * self.args.lambda_per
170 | G_losses['per_loss'] = per_loss
171 | loss += per_loss
172 |
173 | return G_losses,loss,fake
174 |
175 |
176 | def compute_cycle_g_loss(self,I_a,I_gray,I_t,hat_t,M_a,M_t,M_hat):
177 | G_losses = {}
178 | loss = 0
179 | fake_pair,label_pair = self.netG(I_a,I_gray,I_t,M_a,M_t,cycle=True)
180 | fake_nopair,label_nopair = self.netG(I_a,I_gray,hat_t,M_a,M_hat,cycle=True)
181 |
182 | loss = self.L1Loss(fake_pair,label_pair) + \
183 | self.L1Loss(fake_nopair,label_nopair)
184 | G_losses['cycle'] = loss
185 | fake_nopair = F.interpolate(fake_nopair, size=I_t.shape[-2:],mode='bilinear')
186 | label_nopair = F.interpolate(label_nopair, size=I_t.shape[-2:],mode='bilinear')
187 | return G_losses,loss,fake_nopair,label_nopair
188 |
189 |
190 | def get_latest_losses(self):
191 | return {**self.g_losses,**self.d_losses}
192 |
193 | def get_latest_generated(self):
194 | return self.generator
195 |
196 | def loadParameters(self,path):
197 | ckpt = torch.load(path, map_location=lambda storage, loc: storage)
198 | self.netG.load_state_dict(ckpt['G'],strict=False)
199 | self.netD.load_state_dict(ckpt['D'],strict=False)
200 | self.optimG.load_state_dict(ckpt['g_optim'])
201 | self.optimD.load_state_dict(ckpt['d_optim'])
202 |
203 | def saveParameters(self,path):
204 | torch.save(
205 | {
206 | "G": self.netG_module.state_dict(),
207 | "D": self.netD_module.state_dict(),
208 | "g_optim": self.optimG.state_dict(),
209 | "d_optim": self.optimD.state_dict(),
210 | "args": self.args,
211 | },
212 | path
213 | )
214 |
215 | def get_lr(self):
216 | return self.optimG.state_dict()['param_groups'][0]['lr']
217 |
218 |
219 |
220 |
221 |
222 |
223 |
224 |
225 |
226 |
227 |
228 |
229 |
230 |
231 |
--------------------------------------------------------------------------------
/model/BlendModule/module.py:
--------------------------------------------------------------------------------
1 |
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import torch
5 | class VGG19_pytorch(nn.Module):
6 |     """
7 |     VGG19 feature extractor truncated at relu4_4; the Blend generator uses its output to build correspondences.
8 |     """
9 |
10 | def __init__(self, pool="max"):
11 | super(VGG19_pytorch, self).__init__()
12 | self.conv1_1 = nn.Conv2d(3, 64, kernel_size=3, padding=1)
13 | self.conv1_2 = nn.Conv2d(64, 64, kernel_size=3, padding=1)
14 | self.conv2_1 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
15 | self.conv2_2 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
16 | self.conv3_1 = nn.Conv2d(128, 256, kernel_size=3, padding=1)
17 | self.conv3_2 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
18 | self.conv3_3 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
19 | self.conv3_4 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
20 | self.conv4_1 = nn.Conv2d(256, 512, kernel_size=3, padding=1)
21 | self.conv4_2 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
22 | self.conv4_3 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
23 | self.conv4_4 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
24 | self.conv5_1 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
25 | self.conv5_2 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
26 | self.conv5_3 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
27 | self.conv5_4 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
28 | if pool == "max":
29 | self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
30 | self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
31 | self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)
32 | self.pool4 = nn.MaxPool2d(kernel_size=2, stride=2)
33 | self.pool5 = nn.MaxPool2d(kernel_size=2, stride=2)
34 | elif pool == "avg":
35 | self.pool1 = nn.AvgPool2d(kernel_size=2, stride=2)
36 | self.pool2 = nn.AvgPool2d(kernel_size=2, stride=2)
37 | self.pool3 = nn.AvgPool2d(kernel_size=2, stride=2)
38 | self.pool4 = nn.AvgPool2d(kernel_size=2, stride=2)
39 | self.pool5 = nn.AvgPool2d(kernel_size=2, stride=2)
40 |
41 | def forward(self, x):
42 | """
43 | NOTE: input tensor should range in [0,1]
44 | """
45 | out = {}
46 |
47 | out["r11"] = F.relu(self.conv1_1(x))
48 | out["r12"] = F.relu(self.conv1_2(out["r11"]))
49 | out["p1"] = self.pool1(out["r12"])
50 | out["r21"] = F.relu(self.conv2_1(out["p1"]))
51 | out["r22"] = F.relu(self.conv2_2(out["r21"]))
52 | out["p2"] = self.pool2(out["r22"])
53 | out["r31"] = F.relu(self.conv3_1(out["p2"]))
54 | out["r32"] = F.relu(self.conv3_2(out["r31"]))
55 | out["r33"] = F.relu(self.conv3_3(out["r32"]))
56 | out["r34"] = F.relu(self.conv3_4(out["r33"]))
57 | out["p3"] = self.pool3(out["r34"])
58 | out["r41"] = F.relu(self.conv4_1(out["p3"]))
59 | out["r42"] = F.relu(self.conv4_2(out["r41"]))
60 | out["r43"] = F.relu(self.conv4_3(out["r42"]))
61 | out["r44"] = F.relu(self.conv4_4(out["r43"]))
62 |
63 | return out["r44"]
64 |
65 |
66 | class Decoder(nn.Module):
67 | def __init__(self, ic):
68 | super(Decoder, self).__init__()
69 | self.conv1_1 = nn.Sequential(nn.Conv2d(ic, 32, 3, 1, 1), nn.ReLU(), nn.Conv2d(32, 64, 3, 1, 1))
70 | self.conv1_2 = nn.Conv2d(64, 64, 3, 1, 1)
71 | self.conv1_2norm_ss = nn.Conv2d(64, 64, 1, 2, bias=False, groups=64)
72 | self.conv2_1 = nn.Conv2d(64, 128, 3, 1, 1)
73 | self.conv2_2 = nn.Conv2d(128, 128, 3, 1, 1)
74 | self.conv2_2norm_ss = nn.Conv2d(128, 128, 1, 2, bias=False, groups=128)
75 | self.conv3_1 = nn.Conv2d(128, 256, 3, 1, 1)
76 | self.conv3_2 = nn.Conv2d(256, 256, 3, 1, 1)
77 | self.conv3_3 = nn.Conv2d(256, 256, 3, 1, 1)
78 | self.conv3_3norm_ss = nn.Conv2d(256, 256, 1, 2, bias=False, groups=256)
79 | self.conv4_1 = nn.Conv2d(256, 512, 3, 1, 1)
80 | self.conv4_2 = nn.Conv2d(512, 512, 3, 1, 1)
81 | self.conv4_3 = nn.Conv2d(512, 512, 3, 1, 1)
82 | self.conv5_1 = nn.Conv2d(512, 512, 3, 1, 2, 2)
83 | self.conv5_2 = nn.Conv2d(512, 512, 3, 1, 2, 2)
84 | self.conv5_3 = nn.Conv2d(512, 512, 3, 1, 2, 2)
85 | self.conv6_1 = nn.Conv2d(512, 512, 3, 1, 2, 2)
86 | self.conv6_2 = nn.Conv2d(512, 512, 3, 1, 2, 2)
87 | self.conv6_3 = nn.Conv2d(512, 512, 3, 1, 2, 2)
88 | self.conv7_1 = nn.Conv2d(512, 512, 3, 1, 1)
89 | self.conv7_2 = nn.Conv2d(512, 512, 3, 1, 1)
90 | self.conv7_3 = nn.Conv2d(512, 512, 3, 1, 1)
91 | self.conv3_3_short = nn.Conv2d(256, 256, 3, 1, 1)
92 | self.conv8_2 = nn.Conv2d(256, 256, 3, 1, 1)
93 | self.conv8_3 = nn.Conv2d(256, 256, 3, 1, 1)
94 | self.conv2_2_short = nn.Conv2d(128, 128, 3, 1, 1)
95 | self.conv9_2 = nn.Conv2d(128, 128, 3, 1, 1)
96 | self.conv1_2_short = nn.Conv2d(64, 128, 3, 1, 1)
97 | self.conv10_2 = nn.Conv2d(128, 128, 3, 1, 1)
98 | self.conv10_ab = nn.Conv2d(128, 3, 1, 1)
99 |
100 | # add self.relux_x
101 | self.relu1_1 = nn.ReLU()
102 | self.relu1_2 = nn.ReLU()
103 | self.relu2_1 = nn.ReLU()
104 | self.relu2_2 = nn.ReLU()
105 | self.relu3_1 = nn.ReLU()
106 | self.relu3_2 = nn.ReLU()
107 | self.relu3_3 = nn.ReLU()
108 | self.relu4_1 = nn.ReLU()
109 | self.relu4_2 = nn.ReLU()
110 | self.relu4_3 = nn.ReLU()
111 | self.relu5_1 = nn.ReLU()
112 | self.relu5_2 = nn.ReLU()
113 | self.relu5_3 = nn.ReLU()
114 | self.relu6_1 = nn.ReLU()
115 | self.relu6_2 = nn.ReLU()
116 | self.relu6_3 = nn.ReLU()
117 | self.relu7_1 = nn.ReLU()
118 | self.relu7_2 = nn.ReLU()
119 | self.relu7_3 = nn.ReLU()
120 | self.relu8_1_comb = nn.ReLU()
121 | self.relu8_2 = nn.ReLU()
122 | self.relu8_3 = nn.ReLU()
123 | self.relu9_1_comb = nn.ReLU()
124 | self.relu9_2 = nn.ReLU()
125 | self.relu10_1_comb = nn.ReLU()
126 | self.relu10_2 = nn.LeakyReLU(0.2, True)
127 |
128 | # print("replace all deconv with [nearest + conv]")
129 | self.conv8_1 = nn.Sequential(nn.Upsample(scale_factor=2, mode="nearest"), nn.Conv2d(512, 256, 3, 1, 1))
130 | self.conv9_1 = nn.Sequential(nn.Upsample(scale_factor=2, mode="nearest"), nn.Conv2d(256, 128, 3, 1, 1))
131 | self.conv10_1 = nn.Sequential(nn.Upsample(scale_factor=2, mode="nearest"), nn.Conv2d(128, 128, 3, 1, 1))
132 |
133 | # print("replace all batchnorm with instancenorm")
134 | self.conv1_2norm = nn.InstanceNorm2d(64)
135 | self.conv2_2norm = nn.InstanceNorm2d(128)
136 | self.conv3_3norm = nn.InstanceNorm2d(256)
137 | self.conv4_3norm = nn.InstanceNorm2d(512)
138 | self.conv5_3norm = nn.InstanceNorm2d(512)
139 | self.conv6_3norm = nn.InstanceNorm2d(512)
140 | self.conv7_3norm = nn.InstanceNorm2d(512)
141 | self.conv8_3norm = nn.InstanceNorm2d(256)
142 | self.conv9_2norm = nn.InstanceNorm2d(128)
143 |
144 | def forward(self, x):
145 |         """ x: channel-wise concat of [gen_h, gen_i, M_Ah, I_tb, M_Ai, I_ag] built in BlendModule's Generator.forward """
146 | conv1_1 = self.relu1_1(self.conv1_1(x))
147 | conv1_2 = self.relu1_2(self.conv1_2(conv1_1))
148 | conv1_2norm = self.conv1_2norm(conv1_2)
149 | conv1_2norm_ss = self.conv1_2norm_ss(conv1_2norm)
150 | conv2_1 = self.relu2_1(self.conv2_1(conv1_2norm_ss))
151 | conv2_2 = self.relu2_2(self.conv2_2(conv2_1))
152 | conv2_2norm = self.conv2_2norm(conv2_2)
153 | conv2_2norm_ss = self.conv2_2norm_ss(conv2_2norm)
154 | conv3_1 = self.relu3_1(self.conv3_1(conv2_2norm_ss))
155 | conv3_2 = self.relu3_2(self.conv3_2(conv3_1))
156 | conv3_3 = self.relu3_3(self.conv3_3(conv3_2))
157 | conv3_3norm = self.conv3_3norm(conv3_3)
158 | conv3_3norm_ss = self.conv3_3norm_ss(conv3_3norm)
159 | conv4_1 = self.relu4_1(self.conv4_1(conv3_3norm_ss))
160 | conv4_2 = self.relu4_2(self.conv4_2(conv4_1))
161 | conv4_3 = self.relu4_3(self.conv4_3(conv4_2))
162 | conv4_3norm = self.conv4_3norm(conv4_3)
163 | conv5_1 = self.relu5_1(self.conv5_1(conv4_3norm))
164 | conv5_2 = self.relu5_2(self.conv5_2(conv5_1))
165 | conv5_3 = self.relu5_3(self.conv5_3(conv5_2))
166 | conv5_3norm = self.conv5_3norm(conv5_3)
167 | conv6_1 = self.relu6_1(self.conv6_1(conv5_3norm))
168 | conv6_2 = self.relu6_2(self.conv6_2(conv6_1))
169 | conv6_3 = self.relu6_3(self.conv6_3(conv6_2))
170 | conv6_3norm = self.conv6_3norm(conv6_3)
171 | conv7_1 = self.relu7_1(self.conv7_1(conv6_3norm))
172 | conv7_2 = self.relu7_2(self.conv7_2(conv7_1))
173 | conv7_3 = self.relu7_3(self.conv7_3(conv7_2))
174 | conv7_3norm = self.conv7_3norm(conv7_3)
175 | conv8_1 = self.conv8_1(conv7_3norm)
176 | conv3_3_short = self.conv3_3_short(conv3_3norm)
177 | conv8_1_comb = self.relu8_1_comb(conv8_1 + conv3_3_short)
178 | conv8_2 = self.relu8_2(self.conv8_2(conv8_1_comb))
179 | conv8_3 = self.relu8_3(self.conv8_3(conv8_2))
180 | conv8_3norm = self.conv8_3norm(conv8_3)
181 | conv9_1 = self.conv9_1(conv8_3norm)
182 | conv2_2_short = self.conv2_2_short(conv2_2norm)
183 | conv9_1_comb = self.relu9_1_comb(conv9_1 + conv2_2_short)
184 | conv9_2 = self.relu9_2(self.conv9_2(conv9_1_comb))
185 | conv9_2norm = self.conv9_2norm(conv9_2)
186 | conv10_1 = self.conv10_1(conv9_2norm)
187 | conv1_2_short = self.conv1_2_short(conv1_2norm)
188 | conv10_1_comb = self.relu10_1_comb(conv10_1 + conv1_2_short)
189 | conv10_2 = self.relu10_2(self.conv10_2(conv10_1_comb))
190 | conv10_ab = self.conv10_ab(conv10_2)
191 |
192 | return torch.tanh(conv10_ab)
193 |
--------------------------------------------------------------------------------
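`VGG19_pytorch` above stops at `relu4_4`, i.e. 512 channels at 1/8 of the input resolution, which is what the 1x1 `phi`/`theta` convs in `BlendModule/generator.py` consume (so `f_in_channels` in the Blend config is presumably 512). A quick shape sketch with random weights:

```python
# Sketch: the truncated VGG19 returns relu4_4 features, 512 x H/8 x W/8.
import torch
from model.BlendModule.module import VGG19_pytorch

vgg = VGG19_pytorch()
feat = vgg(torch.rand(1, 3, 256, 256))   # per the docstring, inputs should be in [0, 1]
print(feat.shape)                        # torch.Size([1, 512, 32, 32])
```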
/trainer/AlignTrainer.py:
--------------------------------------------------------------------------------
1 | #! /usr/bin/python
2 | # -*- encoding: utf-8 -*-
3 | '''
4 | @author Leslie
5 | @date 20220812
6 | '''
7 | import torch
8 |
9 | from trainer.ModelTrainer import ModelTrainer
10 | from model.AlignModule.lib import *
11 | from model.AlignModule.discriminator import Discriminator
12 | from itertools import chain
13 | from utils.utils import *
14 | import torch.nn.functional as F
15 | from model.AlignModule.loss import *
16 | import random
17 | import torch.distributed as dist
18 |
19 | class AlignTrainer(ModelTrainer):
20 |
21 | def __init__(self, args):
22 | super().__init__(args)
23 | self.device = 'cpu'
24 | if torch.cuda.is_available():
25 | self.device = 'cuda'
26 |
27 | self.Epor = PorEncoder(args).to(self.device)
28 | self.Eid = IDEncoder(args.id_model).to(self.device)
29 | self.Epose = PoseEncoder(args).to(self.device)
30 | self.Eexp = ExpEncoder(args).to(self.device)
31 | self.netG = Generator(args).to(self.device)
32 |
33 | self.netD = Discriminator(in_channels=3).to(self.device)
34 |
35 |
36 | self.optimG,self.optimD = self.create_optimizer()
37 |
38 | if args.pretrain_path is not None:
39 | self.loadParameters(args.pretrain_path)
40 |
41 | if args.dist:
42 | self.netG,self.netG_module = self.use_ddp(self.netG)
43 | self.Eexp,self.Eexp_module = self.use_ddp(self.Eexp)
44 | self.Epor,self.Epor_module = self.use_ddp(self.Epor)
45 | self.Epose,self.Epose_module = self.use_ddp(self.Epose)
46 | self.netD,self.netD_module = self.use_ddp(self.netD)
47 | else:
48 | self.netG_module = self.netG
49 | self.Eexp_module = self.Eexp
50 | self.Epor_module = self.Epor
51 | self.Epose_module = self.Epose
52 | self.netD_module = self.netD
53 |
54 | if self.args.per_loss:
55 | self.perLoss = PerceptualLoss(args.per_model).to(self.device)
56 | self.perLoss.eval()
57 |
58 | if self.args.rec_loss:
59 | self.L1Loss = torch.nn.L1Loss()
60 | self.Eid.eval()
61 |
62 | def create_optimizer(self):
63 | g_optim = torch.optim.Adam(
64 | chain(self.Epor.parameters(),self.Eexp.parameters(),
65 | self.Epose.parameters(),self.netG.parameters()),
66 | lr=self.args.g_lr,
67 | betas=(self.args.beta1, self.args.beta2),
68 | )
69 | d_optim = torch.optim.Adam(
70 | self.netD.parameters(),
71 | lr=self.args.d_lr,
72 | betas=(self.args.beta1, self.args.beta2),
73 | )
74 |
75 | return g_optim,d_optim
76 |
77 |
78 | def run_single_step(self, data, steps):
79 | self.netG.train()
80 | self.Epor.train()
81 | self.Epose.train()
82 | self.Eexp.train()
83 | super().run_single_step(data, steps)
84 |
85 |
86 | def run_discriminator_one_step(self, data,step):
87 |
88 | D_losses = {}
89 | requires_grad(self.netG, False)
90 | requires_grad(self.Epor, False)
91 | requires_grad(self.Epose, False)
92 | requires_grad(self.Eexp, False)
93 | requires_grad(self.netD, True)
94 |
95 | xs,xt,crop_xt,gt = data
96 | xg = self.forward(xs,crop_xt,xt)
97 | fake_pred = self.netD(xg)
98 | real_pred = self.netD(gt)
99 | d_loss = compute_dis_loss(fake_pred, real_pred,D_losses)
100 | D_losses['d'] = d_loss
101 |
102 | self.optimD.zero_grad()
103 | d_loss.backward()
104 | self.optimD.step()
105 |
106 | self.d_losses = D_losses
107 |
108 |
109 | def run_generator_one_step(self, data,step):
110 |
111 |
112 | requires_grad(self.netG, True)
113 | requires_grad(self.Epor, True)
114 | requires_grad(self.Epose, True)
115 | requires_grad(self.Eexp, True)
116 | requires_grad(self.netD, False)
117 |
118 | xs,xt,crop_xt,gt = data
119 | G_losses,loss,xg = self.compute_g_loss(xs,crop_xt,xt,gt)
120 | self.optimG.zero_grad()
121 | loss.mean().backward()
122 | self.optimG.step()
123 |
124 | self.g_losses = G_losses
125 |
126 | self.generator = [xs[:,0].detach() if len(xs.shape)>4 else xs.detach(),xt.detach(),xg.detach(),gt.detach()]
127 |
128 |
129 | def evalution(self,test_loader,steps,epoch):
130 |
131 | loss_dict = {}
132 | index = random.randint(0,len(test_loader)-1)
133 | counter = 0
134 | with torch.no_grad():
135 | for i,data in enumerate(test_loader):
136 |
137 | data = self.process_input(data)
138 | xs,xt,crop_xt,gt = data
139 | G_losses,losses,xg = self.compute_g_loss(xs,crop_xt,xt,gt)
140 | for k,v in G_losses.items():
141 | loss_dict[k] = loss_dict.get(k,0) + v.detach()
142 | if i == index and self.args.rank == 0 :
143 |
144 | show_data = [xs[:,0],xt,xg,gt]
145 | self.val_vis.display_current_results(self.select_img(show_data),steps)
146 | counter += 1
147 |
148 |
149 | for key,val in loss_dict.items():
150 | loss_dict[key] /= counter
151 |
152 | if self.args.dist:
153 | # if self.args.rank == 0 :
154 | dist_losses = loss_dict.copy()
155 | for key,val in loss_dict.items():
156 |
157 | dist.reduce(dist_losses[key],0)
158 | value = dist_losses[key].item()
159 | loss_dict[key] = value / self.args.world_size
160 |
161 | if self.args.rank == 0 :
162 | self.val_vis.plot_current_errors(loss_dict,steps)
163 | self.val_vis.print_current_errors(epoch+1,0,loss_dict,0)
164 |
165 | return loss_dict
166 |
167 |
168 | def forward(self,xs,crop_xt,xt):
169 |
170 | por_f = self.Epor(xs)
171 | id_f = self.Eid(self.process_id_input(xs,crop=True))
172 |
173 | pose_f = self.Epose(xt)
174 | exp_f = self.Eexp(self.process_id_input(crop_xt,size=256))
175 |
176 | xg = self.netG(por_f,id_f,pose_f,exp_f)
177 |
178 | return xg
179 |
180 | def compute_g_loss(self,xs,crop_xt,xt,gt):
181 | G_losses = {}
182 | loss = 0
183 | xg = self.forward(xs,crop_xt,xt)
184 | fake_pred = self.netD(xg)
185 | gan_loss = compute_gan_loss(fake_pred) * self.args.lambda_gan
186 | G_losses['g_losses'] = gan_loss
187 | loss += gan_loss
188 |
189 | if self.args.rec_loss:
190 | rec_loss = self.L1Loss(xg,gt) * self.args.lambda_rec
191 | G_losses['rec_loss'] = rec_loss
192 | loss += rec_loss
193 |
194 | if self.args.id_loss:
195 | fake_id_f = self.Eid(self.process_id_input(xg,crop=True))
196 | real_id_f = self.Eid(self.process_id_input(gt,crop=True))
197 | id_loss = compute_id_loss(fake_id_f,real_id_f).mean() * self.args.lambda_id
198 | G_losses['id_loss'] = id_loss
199 | loss += id_loss
200 |
201 | if self.args.per_loss:
202 | per_loss = self.perLoss(xg,gt) * self.args.lambda_per
203 | G_losses['per_loss'] = per_loss
204 | loss += per_loss
205 |
206 | return G_losses,loss,xg
207 |
208 | @staticmethod
209 | def process_id_input(x,crop=False,size=112):
210 | c,h,w = x.shape[-3:]
211 | batch = x.shape[0]
212 |         scale = 0.4 / 1.8  # margin cropped from each side; leaves the central 1/1.8 of the aligned crop, i.e. the face box (crop_with_padding uses scale=1.8)
213 | if crop:
214 | crop_x = x[...,int(h*scale):int(-h*scale),int(w*scale):int(-w*scale)]
215 | else:
216 | crop_x = x
217 | if len(x.shape) > 4:
218 | resize_x = F.adaptive_avg_pool2d(crop_x.view(-1,*crop_x.shape[-3:]),size)
219 | resize_x = resize_x.view(batch,-1,c,size,size)
220 | else:
221 | resize_x = F.adaptive_avg_pool2d(crop_x,size)
222 | return resize_x
223 | def get_latest_losses(self):
224 | return {**self.g_losses,**self.d_losses}
225 |
226 | def get_latest_generated(self):
227 | return self.generator
228 |
229 | def loadParameters(self,path):
230 | ckpt = torch.load(path, map_location=lambda storage, loc: storage)
231 | self.netG.load_state_dict(ckpt['G'],strict=False)
232 | self.Eexp.load_state_dict(ckpt['Eexp'],strict=False)
233 | self.Eid.load_state_dict(ckpt['Eid'],strict=False)
234 | self.Epor.load_state_dict(ckpt['Epor'],strict=False)
235 | self.Epose.load_state_dict(ckpt['Epose'],strict=False)
236 | # self.netD.load_state_dict(ckpt['D'],strict=False)
237 | self.optimG.load_state_dict(ckpt['g_optim'])
238 | # self.optimD.load_state_dict(ckpt['d_optim'])
239 |
240 | def saveParameters(self,path):
241 | torch.save(
242 | {
243 | "G": self.netG_module.state_dict(),
244 | 'D':self.netD_module.state_dict(),
245 | "Eexp": self.Eexp_module.state_dict(),
246 | "Eid":self.Eid.state_dict(),
247 | 'Epor':self.Epor_module.state_dict(),
248 | 'Epose':self.Epose_module.state_dict(),
249 | "g_optim": self.optimG.state_dict(),
250 | "d_optim": self.optimD.state_dict(),
251 | "args": self.args,
252 | },
253 | path
254 | )
255 |
256 | def get_lr(self):
257 | return self.optimG.state_dict()['param_groups'][0]['lr']
258 |
259 |
260 | def select_img(self, data, name='fake', axis=2):
261 | data = [F.adaptive_avg_pool2d(x,self.args.output_image_size) for x in data]
262 | return super().select_img(data, name, axis)
263 |
264 |
265 |
266 |
267 |
268 |
269 |
270 |
271 |
272 |
273 |
274 |
275 |
276 |
277 |
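278 | # --- Illustration (not part of the original file) ---------------------------
279 | # process_id_input(crop=True) trims 0.4/1.8 of the height/width from each border
280 | # before resizing. For example, a 512x512 aligned crop keeps
281 | # x[..., 113:-113, 113:-113] (int(512*0.4/1.8) = 113, roughly the central 56%),
282 | # then adaptive_avg_pool2d resizes it to size x size (112 by default), matching
283 | # the 112x112 input expected by ArcFace/IR-SE50-style identity encoders.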
--------------------------------------------------------------------------------
/model/third/model.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python
2 | # -*- encoding: utf-8 -*-
3 |
4 |
5 | import torch
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 | import torchvision
9 |
10 | from .resnet import Resnet18
11 | # from modules.bn import InPlaceABNSync as BatchNorm2d
12 |
13 |
14 | class ConvBNReLU(nn.Module):
15 | def __init__(self, in_chan, out_chan, ks=3, stride=1, padding=1, *args, **kwargs):
16 | super(ConvBNReLU, self).__init__()
17 | self.conv = nn.Conv2d(in_chan,
18 | out_chan,
19 | kernel_size = ks,
20 | stride = stride,
21 | padding = padding,
22 | bias = False)
23 | self.bn = nn.BatchNorm2d(out_chan)
24 | self.init_weight()
25 |
26 | def forward(self, x):
27 | x = self.conv(x)
28 | x = F.relu(self.bn(x))
29 | return x
30 |
31 | def init_weight(self):
32 | for ly in self.children():
33 | if isinstance(ly, nn.Conv2d):
34 | nn.init.kaiming_normal_(ly.weight, a=1)
35 | if not ly.bias is None: nn.init.constant_(ly.bias, 0)
36 |
37 | class BiSeNetOutput(nn.Module):
38 | def __init__(self, in_chan, mid_chan, n_classes, *args, **kwargs):
39 | super(BiSeNetOutput, self).__init__()
40 | self.conv = ConvBNReLU(in_chan, mid_chan, ks=3, stride=1, padding=1)
41 | self.conv_out = nn.Conv2d(mid_chan, n_classes, kernel_size=1, bias=False)
42 | self.init_weight()
43 |
44 | def forward(self, x):
45 | x = self.conv(x)
46 | x = self.conv_out(x)
47 | return x
48 |
49 | def init_weight(self):
50 | for ly in self.children():
51 | if isinstance(ly, nn.Conv2d):
52 | nn.init.kaiming_normal_(ly.weight, a=1)
53 | if not ly.bias is None: nn.init.constant_(ly.bias, 0)
54 |
55 | def get_params(self):
56 | wd_params, nowd_params = [], []
57 | for name, module in self.named_modules():
58 | if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d):
59 | wd_params.append(module.weight)
60 | if not module.bias is None:
61 | nowd_params.append(module.bias)
62 | elif isinstance(module, nn.BatchNorm2d):
63 | nowd_params += list(module.parameters())
64 | return wd_params, nowd_params
65 |
66 |
67 | class AttentionRefinementModule(nn.Module):
68 | def __init__(self, in_chan, out_chan, *args, **kwargs):
69 | super(AttentionRefinementModule, self).__init__()
70 | self.conv = ConvBNReLU(in_chan, out_chan, ks=3, stride=1, padding=1)
71 | self.conv_atten = nn.Conv2d(out_chan, out_chan, kernel_size= 1, bias=False)
72 | self.bn_atten = nn.BatchNorm2d(out_chan)
73 | self.sigmoid_atten = nn.Sigmoid()
74 | self.init_weight()
75 |
76 | def forward(self, x):
77 | feat = self.conv(x)
78 | atten = F.avg_pool2d(feat, feat.size()[2:])
79 | atten = self.conv_atten(atten)
80 | atten = self.bn_atten(atten)
81 | atten = self.sigmoid_atten(atten)
82 | out = torch.mul(feat, atten)
83 | return out
84 |
85 | def init_weight(self):
86 | for ly in self.children():
87 | if isinstance(ly, nn.Conv2d):
88 | nn.init.kaiming_normal_(ly.weight, a=1)
89 | if not ly.bias is None: nn.init.constant_(ly.bias, 0)
90 |
91 |
92 | class ContextPath(nn.Module):
93 | def __init__(self, *args, **kwargs):
94 | super(ContextPath, self).__init__()
95 | self.resnet = Resnet18()
96 | self.arm16 = AttentionRefinementModule(256, 128)
97 | self.arm32 = AttentionRefinementModule(512, 128)
98 | self.conv_head32 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1)
99 | self.conv_head16 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1)
100 | self.conv_avg = ConvBNReLU(512, 128, ks=1, stride=1, padding=0)
101 |
102 | self.init_weight()
103 |
104 | def forward(self, x):
105 | H0, W0 = x.size()[2:]
106 | feat8, feat16, feat32 = self.resnet(x)
107 | H8, W8 = feat8.size()[2:]
108 | H16, W16 = feat16.size()[2:]
109 | H32, W32 = feat32.size()[2:]
110 |
111 | avg = F.avg_pool2d(feat32, feat32.size()[2:])
112 | avg = self.conv_avg(avg)
113 | avg_up = F.interpolate(avg, (H32, W32), mode='nearest')
114 |
115 | feat32_arm = self.arm32(feat32)
116 | feat32_sum = feat32_arm + avg_up
117 | feat32_up = F.interpolate(feat32_sum, (H16, W16), mode='nearest')
118 | feat32_up = self.conv_head32(feat32_up)
119 |
120 | feat16_arm = self.arm16(feat16)
121 | feat16_sum = feat16_arm + feat32_up
122 | feat16_up = F.interpolate(feat16_sum, (H8, W8), mode='nearest')
123 | feat16_up = self.conv_head16(feat16_up)
124 |
125 | return feat8, feat16_up, feat32_up # x8, x8, x16
126 |
127 | def init_weight(self):
128 | for ly in self.children():
129 | if isinstance(ly, nn.Conv2d):
130 | nn.init.kaiming_normal_(ly.weight, a=1)
131 | if not ly.bias is None: nn.init.constant_(ly.bias, 0)
132 |
133 | def get_params(self):
134 | wd_params, nowd_params = [], []
135 | for name, module in self.named_modules():
136 | if isinstance(module, (nn.Linear, nn.Conv2d)):
137 | wd_params.append(module.weight)
138 | if not module.bias is None:
139 | nowd_params.append(module.bias)
140 | elif isinstance(module, nn.BatchNorm2d):
141 | nowd_params += list(module.parameters())
142 | return wd_params, nowd_params
143 |
144 |
145 | ### This is not used, since it is replaced by the resnet feature of the same size
146 | class SpatialPath(nn.Module):
147 | def __init__(self, *args, **kwargs):
148 | super(SpatialPath, self).__init__()
149 | self.conv1 = ConvBNReLU(3, 64, ks=7, stride=2, padding=3)
150 | self.conv2 = ConvBNReLU(64, 64, ks=3, stride=2, padding=1)
151 | self.conv3 = ConvBNReLU(64, 64, ks=3, stride=2, padding=1)
152 | self.conv_out = ConvBNReLU(64, 128, ks=1, stride=1, padding=0)
153 | self.init_weight()
154 |
155 | def forward(self, x):
156 | feat = self.conv1(x)
157 | feat = self.conv2(feat)
158 | feat = self.conv3(feat)
159 | feat = self.conv_out(feat)
160 | return feat
161 |
162 | def init_weight(self):
163 | for ly in self.children():
164 | if isinstance(ly, nn.Conv2d):
165 | nn.init.kaiming_normal_(ly.weight, a=1)
166 | if not ly.bias is None: nn.init.constant_(ly.bias, 0)
167 |
168 | def get_params(self):
169 | wd_params, nowd_params = [], []
170 | for name, module in self.named_modules():
171 | if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d):
172 | wd_params.append(module.weight)
173 | if not module.bias is None:
174 | nowd_params.append(module.bias)
175 | elif isinstance(module, nn.BatchNorm2d):
176 | nowd_params += list(module.parameters())
177 | return wd_params, nowd_params
178 |
179 |
180 | class FeatureFusionModule(nn.Module):
181 | def __init__(self, in_chan, out_chan, *args, **kwargs):
182 | super(FeatureFusionModule, self).__init__()
183 | self.convblk = ConvBNReLU(in_chan, out_chan, ks=1, stride=1, padding=0)
184 | self.conv1 = nn.Conv2d(out_chan,
185 | out_chan//4,
186 | kernel_size = 1,
187 | stride = 1,
188 | padding = 0,
189 | bias = False)
190 | self.conv2 = nn.Conv2d(out_chan//4,
191 | out_chan,
192 | kernel_size = 1,
193 | stride = 1,
194 | padding = 0,
195 | bias = False)
196 | self.relu = nn.ReLU(inplace=True)
197 | self.sigmoid = nn.Sigmoid()
198 | self.init_weight()
199 |
200 | def forward(self, fsp, fcp):
201 | fcat = torch.cat([fsp, fcp], dim=1)
202 | feat = self.convblk(fcat)
203 | atten = F.avg_pool2d(feat, feat.size()[2:])
204 | atten = self.conv1(atten)
205 | atten = self.relu(atten)
206 | atten = self.conv2(atten)
207 | atten = self.sigmoid(atten)
208 | feat_atten = torch.mul(feat, atten)
209 | feat_out = feat_atten + feat
210 | return feat_out
211 |
212 | def init_weight(self):
213 | for ly in self.children():
214 | if isinstance(ly, nn.Conv2d):
215 | nn.init.kaiming_normal_(ly.weight, a=1)
216 | if not ly.bias is None: nn.init.constant_(ly.bias, 0)
217 |
218 | def get_params(self):
219 | wd_params, nowd_params = [], []
220 | for name, module in self.named_modules():
221 | if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d):
222 | wd_params.append(module.weight)
223 | if not module.bias is None:
224 | nowd_params.append(module.bias)
225 | elif isinstance(module, nn.BatchNorm2d):
226 | nowd_params += list(module.parameters())
227 | return wd_params, nowd_params
228 |
229 |
230 | class BiSeNet(nn.Module):
231 | def __init__(self, n_classes, *args, **kwargs):
232 | super(BiSeNet, self).__init__()
233 | self.cp = ContextPath()
234 | ## here self.sp is deleted
235 | self.ffm = FeatureFusionModule(256, 256)
236 | self.conv_out = BiSeNetOutput(256, 256, n_classes)
237 | self.conv_out16 = BiSeNetOutput(128, 64, n_classes)
238 | self.conv_out32 = BiSeNetOutput(128, 64, n_classes)
239 | self.init_weight()
240 |
241 | def forward(self, x):
242 | H, W = x.size()[2:]
243 | feat_res8, feat_cp8, feat_cp16 = self.cp(x) # here return res3b1 feature
244 | feat_sp = feat_res8 # use res3b1 feature to replace spatial path feature
245 | feat_fuse = self.ffm(feat_sp, feat_cp8)
246 |
247 | feat_out = self.conv_out(feat_fuse)
248 | feat_out16 = self.conv_out16(feat_cp8)
249 | feat_out32 = self.conv_out32(feat_cp16)
250 |
251 | feat_out = F.interpolate(feat_out, (H, W), mode='bilinear', align_corners=True)
252 | feat_out16 = F.interpolate(feat_out16, (H, W), mode='bilinear', align_corners=True)
253 | feat_out32 = F.interpolate(feat_out32, (H, W), mode='bilinear', align_corners=True)
254 | return feat_out, feat_out16, feat_out32
255 |
256 | def init_weight(self):
257 | for ly in self.children():
258 | if isinstance(ly, nn.Conv2d):
259 | nn.init.kaiming_normal_(ly.weight, a=1)
260 | if not ly.bias is None: nn.init.constant_(ly.bias, 0)
261 |
262 | def get_params(self):
263 | wd_params, nowd_params, lr_mul_wd_params, lr_mul_nowd_params = [], [], [], []
264 | for name, child in self.named_children():
265 | child_wd_params, child_nowd_params = child.get_params()
266 | if isinstance(child, FeatureFusionModule) or isinstance(child, BiSeNetOutput):
267 | lr_mul_wd_params += child_wd_params
268 | lr_mul_nowd_params += child_nowd_params
269 | else:
270 | wd_params += child_wd_params
271 | nowd_params += child_nowd_params
272 | return wd_params, nowd_params, lr_mul_wd_params, lr_mul_nowd_params
273 |
274 |
275 | if __name__ == "__main__":
276 | net = BiSeNet(19)
277 | net.cuda()
278 | net.eval()
279 | in_ten = torch.randn(16, 3, 640, 480).cuda()
280 | out, out16, out32 = net(in_ten)
281 | print(out.shape)
282 |
283 | net.get_params()
284 |
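285 |
286 | # --- Usage sketch (not part of the original file) ----------------------------
287 | # Assuming this BiSeNet serves as the face-parsing network whose weights are the
288 | # pretrained parsing.pth checkpoint, a per-pixel label map can be read off the
289 | # main output, e.g.:
290 | #   net = BiSeNet(n_classes=19)
291 | #   net.load_state_dict(torch.load('pretrained_models/parsing.pth', map_location='cpu'))
292 | #   net.eval()
293 | #   with torch.no_grad():
294 | #       out, _, _ = net(img)         # img: (B, 3, H, W), normalized RGB
295 | #       mask = out.argmax(dim=1)     # (B, H, W) integer face-parsing labels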
--------------------------------------------------------------------------------
/model/AlignModule/lib/blocks.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from torch.nn.utils import spectral_norm
4 |
5 |
6 | class AdaptiveNorm2d(nn.Module):
7 | def __init__(self, num_features, norm_layer='in', eps=1e-4):
8 | super(AdaptiveNorm2d, self).__init__()
9 | self.num_features = num_features
10 | self.weight = self.bias = None
11 | if 'in' in norm_layer:
12 | self.norm_layer = nn.InstanceNorm2d(num_features, eps=eps, affine=False)
13 | elif 'bn' in norm_layer:
14 | self.norm_layer = nn.SyncBatchNorm(num_features, momentum=1.0, eps=eps, affine=False)
15 |
16 | self.delete_weight_on_forward = True
17 |
18 | def forward(self, input):
19 | out = self.norm_layer(input)
20 | output = out * self.weight[:, :, None, None] + self.bias[:, :, None, None]  # per-sample (B, C) affine params assigned externally
21 |
22 | # To save GPU memory
23 | if self.delete_weight_on_forward:
24 | self.weight = self.bias = None
25 |
26 | return output
27 |
28 |
29 | class AdaptiveNorm2dTrainable(nn.Module):
30 | def __init__(self, num_features, norm_layer='in', eps=1e-4):
31 | super(AdaptiveNorm2dTrainable, self).__init__()
32 | self.num_features = num_features
33 | if 'in' in norm_layer:
34 | self.norm_layer = nn.InstanceNorm2d(num_features, eps=eps, affine=False)
35 |
36 | def forward(self, input):
37 | out = self.norm_layer(input)
38 | t = out.shape[0] // self.weight.shape[0]
39 | output = out * self.weight + self.bias
40 | return output
41 |
42 | def assign_params(self, weight, bias):
43 | self.weight = torch.nn.Parameter(weight.view(1, -1, 1, 1))
44 | self.bias = torch.nn.Parameter(bias.view(1, -1, 1, 1))
45 |
46 |
47 | class ResBlock(nn.Module):
48 | def __init__(self, in_channels, out_channels, padding, upsample, downsample,
49 | norm_layer, activation=nn.ReLU, gated=False):
50 | super(ResBlock, self).__init__()
51 | normalize = norm_layer != 'none'
52 | bias = not normalize
53 |
54 | # if norm_layer == 'bn':
55 | # # norm0 = SyncBatchNorm(in_channels, momentum=1.0, eps=1e-4)
56 | # # norm1 = SyncBatchNorm(out_channels, momentum=1.0, eps=1e-4)
57 | # pass
58 | if norm_layer == 'in':
59 | norm0 = nn.InstanceNorm2d(in_channels, eps=1e-4, affine=True)
60 | norm1 = nn.InstanceNorm2d(out_channels, eps=1e-4, affine=True)
61 | elif 'ada' in norm_layer:
62 | norm0 = AdaptiveNorm2d(in_channels, norm_layer)
63 | norm1 = AdaptiveNorm2d(out_channels, norm_layer)
64 | elif 'tra' in norm_layer:
65 | norm0 = AdaptiveNorm2dTrainable(in_channels, norm_layer)
66 | norm1 = AdaptiveNorm2dTrainable(out_channels, norm_layer)
67 | elif normalize:
68 | raise Exception('ResBlock: Incorrect `norm_layer` parameter')
69 |
70 | layers = []
71 | if normalize:
72 | layers.append(norm0)
73 | layers.append(activation(inplace=True))
74 | if upsample:
75 | layers.append(nn.Upsample(scale_factor=2))
76 | layers.extend([
77 | nn.Sequential() if padding is nn.ZeroPad2d else padding(1),  # zero-padding is folded into the conv below; other pad types need an explicit layer
78 | spectral_norm(
79 | nn.Conv2d(in_channels, out_channels, 3, 1, 1 if padding is nn.ZeroPad2d else 0, bias=bias),
80 | eps=1e-4)])
81 | if normalize:
82 | layers.append(norm1)
83 | layers.extend([
84 | activation(inplace=True),
85 | nn.Sequential() if padding is nn.ZeroPad2d else padding(1),
86 | spectral_norm(
87 | nn.Conv2d(out_channels, out_channels, 3, 1, 1 if padding is nn.ZeroPad2d else 0, bias=bias),
88 | eps=1e-4)])
89 | if downsample:
90 | layers.append(nn.AvgPool2d(2))
91 | self.block = nn.Sequential(*layers)
92 |
93 | self.skip = None
94 | if in_channels != out_channels or upsample or downsample:
95 | layers = []
96 | if upsample:
97 | layers.append(nn.Upsample(scale_factor=2))
98 | layers.append(spectral_norm(
99 | nn.Conv2d(in_channels, out_channels, 1),
100 | eps=1e-4))
101 | if downsample:
102 | layers.append(nn.AvgPool2d(2))
103 | self.skip = nn.Sequential(*layers)
104 |
105 | def forward(self, input):
106 | out = self.block(input)
107 | if self.skip is not None:
108 | output = out + self.skip(input)
109 | else:
110 | output = out + input
111 | return output
112 |
113 | class channelShuffle(nn.Module):
114 | def __init__(self,groups):
115 | super(channelShuffle, self).__init__()
116 | self.groups=groups
117 |
118 | def forward(self,x):
119 | batchsize, num_channels, height, width = x.data.size()
120 |
121 | # batchsize = x.shape[0]
122 | # num_channels = x.shape[1]
123 | # height = x.shape[2]
124 | # width = x.shape[3]
125 |
126 | channels_per_group = num_channels // self.groups
127 |
128 | # reshape
129 | x = x.view(batchsize, self.groups, channels_per_group, height, width)
130 |
131 | # transpose
132 | # - contiguous() required if transpose() is used before view().
133 | # See https://github.com/pytorch/pytorch/issues/764
134 | x = torch.transpose(x, 1, 2).contiguous()
135 |
136 | # flatten
137 | x = x.view(batchsize, -1, height, width)
138 |
139 | return x
140 |
141 |
142 | class shuffleConv(nn.Module):
143 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True, padding_mode='zeros'):  # extra args kept for call-site compatibility; kernel sizes are fixed internally
144 | super(shuffleConv, self).__init__()
145 | self.in_channels=in_channels
146 | self.out_channels=out_channels
147 | self.stride=stride
148 | self.padding=padding
149 | groups=4  # the grouped 1x1-3x3-1x1 path always uses 4 groups, overriding the argument
150 | block=[]
151 | if (in_channels%groups==0) and (out_channels%groups==0):
152 | block.append(spectral_norm(nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1,padding=0, groups=groups),eps=1e-4))
153 | block.append(nn.ReLU6(inplace=True))
154 | block.append(channelShuffle(groups=groups))
155 | block.append(spectral_norm(nn.Conv2d(in_channels=out_channels, out_channels=out_channels,kernel_size=3,padding=1, groups=groups),eps=1e-4))
156 | block.append(nn.ReLU6(inplace=True))
157 | block.append(spectral_norm(nn.Conv2d(in_channels=out_channels, out_channels=out_channels,kernel_size=1,padding=0, groups=groups),eps=1e-4))
158 | else:
159 | block.append(spectral_norm(nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3,padding=1),eps=1e-4))
160 | self.block=nn.Sequential(*block)
161 |
162 | def forward(self,x):
163 | x=self.block(x)
164 | return x
165 |
166 |
167 | class ResBlockShuffle(nn.Module):
168 | def __init__(self, in_channels, out_channels, padding, upsample, downsample,
169 | norm_layer, activation=nn.ReLU, gated=False):
170 | super(ResBlockShuffle, self).__init__()
171 | normalize = norm_layer != 'none'
172 | bias = not normalize
173 |
174 | # if norm_layer == 'bn':
175 | # # norm0 = SyncBatchNorm(in_channels, momentum=1.0, eps=1e-4)
176 | # # norm1 = SyncBatchNorm(out_channels, momentum=1.0, eps=1e-4)
177 | # pass
178 | if norm_layer == 'in':
179 | norm0 = nn.InstanceNorm2d(in_channels, eps=1e-4, affine=True)
180 | norm1 = nn.InstanceNorm2d(out_channels, eps=1e-4, affine=True)
181 | elif 'ada' in norm_layer:
182 | norm0 = AdaptiveNorm2d(in_channels, norm_layer)
183 | norm1 = AdaptiveNorm2d(out_channels, norm_layer)
184 | elif 'tra' in norm_layer:
185 | norm0 = AdaptiveNorm2dTrainable(in_channels, norm_layer)
186 | norm1 = AdaptiveNorm2dTrainable(out_channels, norm_layer)
187 | elif normalize:
188 | raise Exception('ResBlock: Incorrect `norm_layer` parameter')
189 |
190 | layers = []
191 | if normalize:
192 | layers.append(norm0)
193 | layers.append(activation(inplace=True))
194 | if upsample:
195 | layers.append(nn.Upsample(scale_factor=2))
196 | layers.extend([
197 | #padding(1),
198 | #spectral_norm(
199 | shuffleConv(in_channels, out_channels, 3, 1, 0, bias=bias)#,
200 | # eps=1e-4)
201 | ])
202 | if normalize:
203 | layers.append(norm1)
204 | layers.extend([
205 | activation(inplace=True),
206 | #padding(1),
207 | #spectral_norm(
208 | shuffleConv(out_channels, out_channels, 3, 1, 0, bias=bias)#,
209 | # eps=1e-4)
210 | ])
211 | if downsample:
212 | layers.append(nn.AvgPool2d(2))
213 | self.block = nn.Sequential(*layers)
214 |
215 | self.skip = None
216 | if in_channels != out_channels or upsample or downsample:
217 | layers = []
218 | if upsample:
219 | layers.append(nn.Upsample(scale_factor=2))
220 | layers.append(
221 | #spectral_norm(
222 | shuffleConv(in_channels, out_channels, 1)#,
223 | # eps=1e-4)
224 | )
225 | if downsample:
226 | layers.append(nn.AvgPool2d(2))
227 | self.skip = nn.Sequential(*layers)
228 |
229 | def forward(self, input):
230 | out = self.block(input)
231 | if self.skip is not None:
232 | output = out + self.skip(input)
233 | else:
234 | output = out + input
235 | return output
236 |
237 |
238 |
239 | class ResBlockV2(nn.Module):
240 | def __init__(self, in_channels, out_channels, stride, groups,
241 | resize_layer, norm_layer, activation):
242 | super(ResBlockV2, self).__init__()
243 | upsampling_layers = {
244 | 'nearest': lambda: nn.Upsample(scale_factor=stride, mode='nearest')
245 | }
246 | downsampling_layers = {
247 | 'avgpool': lambda: nn.AvgPool2d(stride)
248 | }
249 | norm_layers = {
250 | 'bn': lambda num_features: nn.SyncBatchNorm(num_features, momentum=1.0, eps=1e-4),
251 | 'in': lambda num_features: nn.InstanceNorm2d(num_features, eps=1e-4, affine=True),
252 | 'adabn': lambda num_features: AdaptiveNorm2d(num_features, 'bn'),
253 | 'adain': lambda num_features: AdaptiveNorm2d(num_features, 'in')
254 | }
255 | normalize = norm_layer != 'none'
256 | bias = not normalize
257 | upsample = resize_layer in upsampling_layers
258 | downsample = resize_layer in downsampling_layers
259 | if normalize:
260 | norm_layer = norm_layers[norm_layer]
261 |
262 | layers = []
263 | if normalize:
264 | layers.append(norm_layer(in_channels))
265 | layers.append(activation())
266 | if upsample:
267 | layers.append(nn.Upsample(scale_factor=2))
268 | layers.extend([
269 | spectral_norm(
270 | nn.Conv2d(in_channels, out_channels, 3, 1, 1, bias=bias),
271 | eps=1e-4)])
272 | if normalize:
273 | layers.append(norm_layer(out_channels))
274 | layers.extend([
275 | activation(),
276 | spectral_norm(
277 | nn.Conv2d(out_channels, out_channels, 3, 1, 1, bias=bias),
278 | eps=1e-4)])
279 | if downsample:
280 | layers.append(nn.AvgPool2d(2))
281 | self.block = nn.Sequential(*layers)
282 |
283 | self.skip = None
284 | if in_channels != out_channels or upsample or downsample:
285 | layers = []
286 | if upsample:
287 | layers.append(nn.Upsample(scale_factor=2))
288 | layers.append(spectral_norm(
289 | nn.Conv2d(in_channels, out_channels, 1),
290 | eps=1e-4))
291 | if downsample:
292 | layers.append(nn.AvgPool2d(2))
293 | self.skip = nn.Sequential(*layers)
294 |
295 | def forward(self, input):
296 | out = self.block(input)
297 | if self.skip is not None:
298 | output = out + self.skip(input)
299 | else:
300 | output = out + input
301 | return output
302 |
303 | class ResBlockV2Shuffle(nn.Module):
304 | def __init__(self, in_channels, out_channels, stride, groups,
305 | resize_layer, norm_layer, activation):
306 | super(ResBlockV2Shuffle, self).__init__()
307 | upsampling_layers = {
308 | 'nearest': lambda: nn.Upsample(scale_factor=stride, mode='nearest')
309 | }
310 | downsampling_layers = {
311 | 'avgpool': lambda: nn.AvgPool2d(stride)
312 | }
313 | norm_layers = {
314 | 'bn': lambda num_features: nn.SyncBatchNorm(num_features, momentum=1.0, eps=1e-4),
315 | 'in': lambda num_features: nn.InstanceNorm2d(num_features, eps=1e-4, affine=True),
316 | 'adabn': lambda num_features: AdaptiveNorm2d(num_features, 'bn'),
317 | 'adain': lambda num_features: AdaptiveNorm2d(num_features, 'in')
318 | }
319 | normalize = norm_layer != 'none'
320 | bias = not normalize
321 | upsample = resize_layer in upsampling_layers
322 | downsample = resize_layer in downsampling_layers
323 | if normalize:
324 | norm_layer = norm_layers[norm_layer]
325 |
326 | layers = []
327 | if normalize:
328 | layers.append(norm_layer(in_channels))
329 | layers.append(activation())
330 | if upsample:
331 | layers.append(nn.Upsample(scale_factor=2))
332 | layers.extend([
333 | #spectral_norm(
334 | shuffleConv(in_channels, out_channels, 3, 1, 1, bias=bias)#,
335 | # eps=1e-4)
336 | ])
337 | if normalize:
338 | layers.append(norm_layer(out_channels))
339 | layers.extend([
340 | activation(),
341 | #spectral_norm(
342 | shuffleConv(out_channels, out_channels, 3, 1, 1, bias=bias)#,
343 | # eps=1e-4)
344 | ])
345 | if downsample:
346 | layers.append(nn.AvgPool2d(2))
347 | self.block = nn.Sequential(*layers)
348 |
349 | self.skip = None
350 | if in_channels != out_channels or upsample or downsample:
351 | layers = []
352 | if upsample:
353 | layers.append(nn.Upsample(scale_factor=2))
354 | layers.append(#spectral_norm(
355 | shuffleConv(in_channels, out_channels, 1)#,
356 | # eps=1e-4)
357 | )
358 | if downsample:
359 | layers.append(nn.AvgPool2d(2))
360 | self.skip = nn.Sequential(*layers)
361 |
362 | def forward(self, input):
363 | out = self.block(input)
364 | if self.skip is not None:
365 | output = out + self.skip(input)
366 | else:
367 | output = out + input
368 | return output
369 |
370 |
371 |
372 | class GatedBlock(nn.Module):
373 | def __init__(self, in_channels, out_channels, act_fun, kernel_size, stride=1, padding=0, bias=True):
374 | super(GatedBlock, self).__init__()
375 | self.conv = spectral_norm(nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=bias),
376 | eps=1e-4)
377 | self.gate = spectral_norm(nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=bias),
378 | eps=1e-4)
379 | self.act_fun = act_fun()
380 | self.gate_act_fun = nn.Sigmoid()
381 |
382 | def forward(self, x):
383 | out = self.conv(x)
384 | out = self.act_fun(out)
385 |
386 | mask = self.gate(x)
387 | mask = self.gate_act_fun(mask)
388 |
389 | out_masked = out * mask
390 | return out_masked
391 |
392 |
393 | class GatedResBlock(nn.Module):
394 | def __init__(self, in_channels, out_channels, padding, upsample, downsample,
395 | norm_layer, activation=nn.ReLU):
396 | super(GatedResBlock, self).__init__()
397 | normalize = norm_layer != 'none'
398 | bias = not normalize
399 |
400 | if norm_layer == 'in':
401 | norm0 = nn.InstanceNorm2d(in_channels, eps=1e-4, affine=True)
402 | norm1 = nn.InstanceNorm2d(out_channels, eps=1e-4, affine=True)
403 | elif 'ada' in norm_layer:
404 | norm0 = AdaptiveNorm2d(in_channels, norm_layer)
405 | norm1 = AdaptiveNorm2d(out_channels, norm_layer)
406 | elif 'tra' in norm_layer:
407 | norm0 = AdaptiveNorm2dTrainable(in_channels, norm_layer)
408 | norm1 = AdaptiveNorm2dTrainable(out_channels, norm_layer)
409 | elif normalize:
410 | raise Exception('ResBlock: Incorrect `norm_layer` parameter')
411 |
412 | main_layers = []
413 |
414 | if normalize:
415 | main_layers.append(norm0)
416 | if upsample:
417 | main_layers.append(nn.Upsample(scale_factor=2))
418 |
419 | main_layers.extend([
420 | padding(1),
421 | GatedBlock(in_channels, out_channels, activation, 3, 1, 0, bias=bias)])
422 |
423 | if normalize:
424 | main_layers.append(norm1)
425 | main_layers.extend([
426 | padding(1),
427 | GatedBlock(out_channels, out_channels, activation, 3, 1, 0, bias=bias)])
428 | if downsample:
429 | main_layers.append(nn.AvgPool2d(2))
430 |
431 | self.main_pipe = nn.Sequential(*main_layers)
432 |
433 | self.skip_pipe = None
434 | if in_channels != out_channels or upsample or downsample:
435 | skip_layers = []
436 |
437 | if upsample:
438 | skip_layers.append(nn.Upsample(scale_factor=2))
439 |
440 | skip_layers.append(GatedBlock(in_channels, out_channels, activation, 1))
441 |
442 | if downsample:
443 | skip_layers.append(nn.AvgPool2d(2))
444 | self.skip_pipe = nn.Sequential(*skip_layers)
445 |
446 | def forward(self, input):
447 | mp_out = self.main_pipe(input)
448 | if self.skip_pipe is not None:
449 | output = mp_out + self.skip_pipe(input)
450 | else:
451 | output = mp_out + input
452 | return output
453 |
454 |
455 | class ResBlockWithoutSpectralNorms(nn.Module):
456 | def __init__(self, in_channels, out_channels, padding, upsample, downsample,
457 | norm_layer, activation=nn.ReLU):
458 | super(ResBlockWithoutSpectralNorms, self).__init__()
459 | normalize = norm_layer != 'none'
460 | bias = not normalize
461 |
462 | # if norm_layer == 'bn':
463 | # # norm0 = SyncBatchNorm(in_channels, momentum=1.0, eps=1e-4)
464 | # # norm1 = SyncBatchNorm(out_channels, momentum=1.0, eps=1e-4)
465 | # pass
466 | if norm_layer == 'in':
467 | norm0 = nn.InstanceNorm2d(in_channels, eps=1e-4, affine=True)
468 | norm1 = nn.InstanceNorm2d(out_channels, eps=1e-4, affine=True)
469 | elif 'ada' in norm_layer:
470 | norm0 = AdaptiveNorm2d(in_channels, norm_layer)
471 | norm1 = AdaptiveNorm2d(out_channels, norm_layer)
472 | elif 'tra' in norm_layer:
473 | norm0 = AdaptiveNorm2dTrainable(in_channels, norm_layer)
474 | norm1 = AdaptiveNorm2dTrainable(out_channels, norm_layer)
475 | elif normalize:
476 | raise Exception('ResBlock: Incorrect `norm_layer` parameter')
477 |
478 | layers = []
479 | if normalize:
480 | layers.append(norm0)
481 | layers.append(activation(inplace=True))
482 | if upsample:
483 | layers.append(nn.Upsample(scale_factor=2))
484 | layers.extend([
485 | padding(1),
486 | # spectral_norm(
487 | nn.Conv2d(in_channels, out_channels, 3, 1, 0, bias=bias) # ,
488 | # eps=1e-4)
489 | ])
490 | if normalize:
491 | layers.append(norm1)
492 | layers.extend([
493 | activation(inplace=True),
494 | padding(1),
495 | # spectral_norm(
496 | nn.Conv2d(out_channels, out_channels, 3, 1, 0, bias=bias) # ,
497 | # eps=1e-4)
498 | ])
499 | if downsample:
500 | layers.append(nn.AvgPool2d(2))
501 | self.block = nn.Sequential(*layers)
502 |
503 | self.skip = None
504 | if in_channels != out_channels or upsample or downsample:
505 | layers = []
506 | if upsample:
507 | layers.append(nn.Upsample(scale_factor=2))
508 | layers.append( # spectral_norm(
509 | nn.Conv2d(in_channels, out_channels, 1) # ,
510 | # eps=1e-4)
511 | )
512 | if downsample:
513 | layers.append(nn.AvgPool2d(2))
514 | self.skip = nn.Sequential(*layers)
515 |
516 | def forward(self, input):
517 | out = self.block(input)
518 | if self.skip is not None:
519 | output = out + self.skip(input)
520 | else:
521 | output = out + input
522 | return output
523 |
524 |
525 | class MobileNetBlock(nn.Module):
526 | def __init__(self, in_channels, out_channels, padding, upsample, downsample,
527 | norm_layer, activation=nn.ReLU6, expansion_factor=6):
528 | super(MobileNetBlock, self).__init__()
529 | normalize = norm_layer != 'none'
530 | bias = not normalize
531 |
532 | conv0 = nn.Conv2d(in_channels, int(in_channels * expansion_factor), 1)
533 | dwise = nn.Conv2d(int(in_channels * expansion_factor), int(in_channels * expansion_factor), 3,
534 | 2 if downsample else 1, 1, groups=int(in_channels * expansion_factor))
535 | conv1 = nn.Conv2d(int(in_channels * expansion_factor), out_channels, 1)
536 |
537 | if norm_layer == 'bn':
538 | # norm0 = SyncBatchNorm(in_channels, momentum=1.0, eps=1e-4)
539 | # norm1 = SyncBatchNorm(out_channels, momentum=1.0, eps=1e-4)
540 | pass
541 | if 'in' in norm_layer:
542 | norm0 = nn.InstanceNorm2d(int(in_channels * expansion_factor), eps=1e-4, affine=True)
543 | norm1 = nn.InstanceNorm2d(int(in_channels * expansion_factor), eps=1e-4, affine=True)
544 | norm2 = nn.InstanceNorm2d(out_channels, eps=1e-4, affine=True)
545 | if 'ada' in norm_layer:
546 | norm2 = AdaptiveNorm2d(out_channels, norm_layer)
547 | elif 'tra' in norm_layer:
548 | norm2 = AdaptiveNorm2dTrainable(out_channels, norm_layer)
549 |
550 | # layers = [spectral_norm(conv0, eps=1e-4)]
551 | layers = [conv0]
552 | if normalize: layers.append(norm0)
553 | layers.append(activation(inplace=True))
554 | if upsample: layers.append(nn.Upsample(scale_factor=2))
555 | # layers.append(spectral_norm(dwise, eps=1e-4))
556 | layers.append(dwise)
557 | if normalize: layers.append(norm1)
558 | layers.extend([
559 | activation(inplace=True),
560 | # spectral_norm(
561 | conv1 # ,
562 | # eps=1e-4)
563 | ])
564 | if normalize: layers.append(norm2)
565 | self.block = nn.Sequential(*layers)
566 |
567 | self.skip = None
568 | if in_channels != out_channels or upsample or downsample:
569 | layers = []
570 | if upsample: layers.append(nn.Upsample(scale_factor=2))
571 | layers.append(
572 | # spectral_norm(
573 | nn.Conv2d(in_channels, out_channels, 1) # ,
574 | # eps=1e-4)
575 | )
576 | if downsample:
577 | layers.append(nn.AvgPool2d(2))
578 | self.skip = nn.Sequential(*layers)
579 |
580 | def forward(self, input):
581 | out = self.block(input)
582 | if self.skip is not None:
583 | output = out + self.skip(input)
584 | else:
585 | output = out + input
586 | return output
587 |
588 |
589 | class SelfAttention(nn.Module):
590 | def __init__(self, in_channels):
591 | super(SelfAttention, self).__init__()
592 | self.in_channels = in_channels
593 | self.query_conv = nn.Conv2d(in_channels, in_channels // 8, 1)
594 | self.key_conv = nn.Conv2d(in_channels, in_channels // 8, 1)
595 | self.value_conv = nn.Conv2d(in_channels, in_channels, 1)
596 | self.gamma = nn.Parameter(torch.zeros(1))
597 | self.softmax = nn.Softmax(dim=-1)
598 |
599 | def forward(self, input):
600 | b, c, h, w = input.shape
601 | query = self.query_conv(input).view(b, -1, h * w).permute(0, 2, 1) # B x HW x C/8
602 | key = self.key_conv(input).view(b, -1, h * w) # B x C/8 x HW
603 | energy = torch.bmm(query, key) # B x HW x HW
604 | attention = self.softmax(energy) # B x HW x HW
605 | value = self.value_conv(input).view(b, -1, h * w) # B x C x HW
606 |
607 | out = torch.bmm(value, attention.permute(0, 2, 1)).view(b, c, h, w)
608 | output = self.gamma * out + input
609 | return output
610 |
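611 |
612 | # --- Usage sketch (not part of the original file) ----------------------------
613 | # AdaptiveNorm2d owns no affine parameters of its own: `weight` and `bias` must be
614 | # assigned from outside (e.g. predicted from an embedding network) before every
615 | # forward call, because forward() clears them again to save memory:
616 | #   norm = AdaptiveNorm2d(64, norm_layer='in')
617 | #   norm.weight, norm.bias = pred_scale, pred_shift   # each of shape (B, 64)
618 | #   y = norm(x)                                        # x: (B, 64, H, W)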
--------------------------------------------------------------------------------