├── .gitignore ├── assets ├── 11.png ├── 22.png └── show.png ├── model ├── AlignModule │ ├── lib │ │ ├── __init__.py │ │ ├── ExpEncoder.py │ │ ├── PoseEncoder.py │ │ ├── PorEncoder.py │ │ ├── IDEncoder.py │ │ ├── generator.py │ │ └── blocks.py │ ├── config.py │ ├── loss.py │ ├── discriminator.py │ └── module.py ├── BlendModule │ ├── config.py │ ├── generator.py │ └── module.py └── third │ ├── resnet.py │ └── model.py ├── process ├── download_weight.sh ├── filter_idfiles.py ├── download_and_process.py ├── process_raw_video.py └── process_utils.py ├── dataloader ├── DataLoader.py ├── BlendLoader.py ├── AlignLoader.py └── augmentation.py ├── LICENSE ├── utils ├── visualizer.py └── utils.py ├── train.py ├── inference.py ├── ReadMe.md └── trainer ├── ModelTrainer.py ├── BlendTrainer.py └── AlignTrainer.py /.gitignore: -------------------------------------------------------------------------------- 1 | pretrained_model* 2 | *.pth 3 | checkpoint* 4 | *__pycache__* 5 | dataset 6 | -------------------------------------------------------------------------------- /assets/11.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LeslieZhoa/HeSer.Pytorch/HEAD/assets/11.png -------------------------------------------------------------------------------- /assets/22.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LeslieZhoa/HeSer.Pytorch/HEAD/assets/22.png -------------------------------------------------------------------------------- /assets/show.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LeslieZhoa/HeSer.Pytorch/HEAD/assets/show.png -------------------------------------------------------------------------------- /model/AlignModule/lib/__init__.py: -------------------------------------------------------------------------------- 1 | from model.AlignModule.lib.ExpEncoder import ExpEncoder 2 | from model.AlignModule.lib.IDEncoder import IDEncoder 3 | from model.AlignModule.lib.PorEncoder import PorEncoder 4 | from model.AlignModule.lib.PoseEncoder import PoseEncoder 5 | from model.AlignModule.lib.generator import Generator -------------------------------------------------------------------------------- /model/AlignModule/lib/ExpEncoder.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | import torchvision 3 | 4 | class ExpEncoder(nn.Module): 5 | def __init__(self,args): 6 | super().__init__() 7 | self.exp_encoder = torchvision.models.mobilenet_v2(num_classes=args.exp_embedding_size) 8 | 9 | def forward(self,x): 10 | 11 | return self.exp_encoder(x) -------------------------------------------------------------------------------- /model/AlignModule/lib/PoseEncoder.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | import torchvision 3 | 4 | class PoseEncoder(nn.Module): 5 | def __init__(self,args): 6 | super().__init__() 7 | self.pose_encoder = torchvision.models.mobilenet_v2(num_classes=args.pose_embedding_size) 8 | 9 | def forward(self,x): 10 | 11 | return self.pose_encoder(x) -------------------------------------------------------------------------------- /process/download_weight.sh: -------------------------------------------------------------------------------- 1 | wget https://github.com/LeslieZhoa/HeSer.Pytorch/releases/download/v0.0/parsing.pth -P ../pretrained_models 
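# (editor's note, inferred from the configs and inference.py in this repo)
# parsing.pth        -> face-parsing BiSeNet weights loaded by inference.py
# model_ir_se50.pth  -> ArcFace IR-SE50 weights for the AlignModule IDEncoder (config id_model)
# vgg19-d01eb7cb.pth -> VGG19 weights for the perceptual loss (config per_model)
# 073-00000000.pth / 417-00000000.pth -> released Blender / Aligner checkpoints used by inference.py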
2 | wget https://github.com/LeslieZhoa/DCT-NET.Pytorch/releases/download/v0.0/model_ir_se50.pth -P ../pretrained_models 3 | wget https://github.com/LeslieZhoa/HeSer.Pytorch/releases/download/v0.0/vgg19-d01eb7cb.pth -P ../pretrained_models 4 | wget https://github.com/LeslieZhoa/HeSer.Pytorch/releases/download/v0.0/073-00000000.pth -P ../checkpoint/Blender 5 | wget https://github.com/LeslieZhoa/HeSer.Pytorch/releases/download/v0.0/417-00000000.pth -P ../checkpoint/Aligner -------------------------------------------------------------------------------- /process/filter_idfiles.py: -------------------------------------------------------------------------------- 1 | import os 2 | from shutil import rmtree 3 | 4 | def rm_little_files(base,th=100): 5 | k = 1 6 | for idname in os.listdir(base): 7 | id_path = os.path.join(base,idname) 8 | for video_clip in os.listdir(id_path): 9 | path = os.path.join(id_path,video_clip) 10 | length = len(os.listdir(path)) 11 | print('\rhave done %04d'%k,end='',flush=True) 12 | k += 1 13 | if length < th: 14 | rmtree(path) 15 | print() 16 | 17 | if __name__ == "__main__": 18 | base = '../dataset/process/img' 19 | rm_little_files(base) -------------------------------------------------------------------------------- /model/AlignModule/lib/PorEncoder.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | import torchvision 3 | class PorEncoder(nn.Module): 4 | def __init__(self,args): 5 | super().__init__() 6 | self.por_encoder = torchvision.models.resnext50_32x4d(num_classes=args.por_embedding_size) 7 | 8 | def forward(self,x): 9 | 10 | batch = x.shape[0] 11 | if len(x.shape) > 4: 12 | x = x.view(-1,*x.shape[2:]) 13 | feat = self.por_encoder(x) 14 | if feat.shape[0] != batch: 15 | feat = feat.view(batch,feat.shape[0]//batch,-1) 16 | feat = feat.mean(1) 17 | 18 | return feat 19 | -------------------------------------------------------------------------------- /dataloader/DataLoader.py: -------------------------------------------------------------------------------- 1 | #! 
/usr/bin/python 2 | # -*- encoding: utf-8 -*- 3 | ''' 4 | @author Leslie 5 | @date 20220812 6 | ''' 7 | 8 | 9 | from torch.utils.data import Dataset 10 | import torch.distributed as dist 11 | 12 | 13 | class DatasetBase(Dataset): 14 | def __init__(self,slice_id=0,slice_count=1,use_dist=False,**kwargs): 15 | 16 | if use_dist: 17 | slice_id = dist.get_rank() 18 | slice_count = dist.get_world_size() 19 | self.id = slice_id 20 | self.count = slice_count 21 | 22 | 23 | def __getitem__(self,i): 24 | pass 25 | 26 | 27 | 28 | 29 | def __len__(self): 30 | return 1000 31 | 32 | -------------------------------------------------------------------------------- /model/AlignModule/lib/IDEncoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from model.AlignModule.module import Backbone 4 | 5 | class IDEncoder(nn.Module): 6 | def __init__(self,model_path): 7 | super(IDEncoder, self).__init__() 8 | print('Loading ResNet ArcFace') 9 | self.facenet = Backbone(input_size=112, num_layers=50, drop_ratio=0.6, mode='ir_se') 10 | self.facenet.load_state_dict(torch.load(model_path)) 11 | self.face_pool = torch.nn.AdaptiveAvgPool2d((112, 112)) 12 | self.facenet.eval() 13 | for module in [self.facenet, self.face_pool]: 14 | for param in module.parameters(): 15 | param.requires_grad = False 16 | 17 | 18 | def forward(self, x): 19 | batch = x.shape[0] 20 | if len(x.shape) > 4: 21 | x = x.view(-1,*x.shape[2:]) 22 | feat = self.facenet(self.face_pool(x)) 23 | if feat.shape[0] != batch: 24 | feat = feat.view(batch,feat.shape[0]//batch,-1) 25 | feat = feat.mean(1) 26 | 27 | return feat 28 | 29 | 30 | 31 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 LeslieZhao 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /model/BlendModule/config.py: -------------------------------------------------------------------------------- 1 | ''' 2 | @author Leslie 3 | @date 20220823 4 | ''' 5 | class Params: 6 | def __init__(self): 7 | 8 | self.name = 'Blender' 9 | self.pretrain_path = 'checkpoint/Blender/025-00000000.pth' 10 | self.size = 512 11 | 12 | self.train_root = 'dataset/process/img' 13 | self.val_root = 'dataset/process/img' 14 | 15 | self.f_in_channels = 512 16 | self.f_inter_channels = 256 17 | self.temperature = 0.001 18 | self.dilate_kernel = 17 19 | self.decoder_ic = 12 20 | 21 | # discriminator 22 | self.embed_channels = 512 23 | self.padding = 'zero' 24 | self.in_channels = 5 25 | self.out_channels = 3 26 | self.num_channels = 64 27 | self.max_num_channels = 512 28 | self.output_image_size = 512 29 | self.dis_num_blocks = 7 30 | 31 | self.per_model = 'pretrained_models/vgg19-d01eb7cb.pth' 32 | 33 | # loss 34 | self.rec_loss = True 35 | self.per_loss = True 36 | self.lambda_gan = 0.2 37 | self.lambda_rec = 1.0 38 | self.lambda_per = 0.0005 39 | 40 | self.g_lr = 1e-4 41 | self.d_lr = 4e-4 42 | self.beta1 = 0.9 43 | self.beta2 = 0.999 -------------------------------------------------------------------------------- /model/AlignModule/config.py: -------------------------------------------------------------------------------- 1 | class Params: 2 | def __init__(self): 3 | 4 | self.name = 'Aligner' 5 | self.pretrain_path = None 6 | # self.pretrain_path = None 7 | self.size = 512 8 | 9 | self.train_root = 'dataset/select-align/img' 10 | self.val_root = 'dataset/select-align/img' 11 | self.use_pixelwise_augs = True 12 | self.use_affine_scale = True 13 | self.use_affine_shift = True 14 | self.frame_num = 5 15 | self.skip_frame = 5 16 | 17 | self.identity_embedding_size = 512 18 | self.pose_embedding_size = 256 19 | self.por_embedding_size = 512 20 | self.exp_embedding_size = 256 21 | self.embed_channels = 512 22 | self.padding = 'zero' 23 | self.in_channels = 3 24 | self.out_channels = 3 25 | self.num_channels = 64 26 | self.max_num_channels = 512 27 | self.norm_layer = 'in' 28 | self.gen_constant_input_size = 4 29 | self.gen_num_residual_blocks = 2 30 | self.output_image_size = 512 31 | 32 | self.dis_num_blocks = 7 33 | 34 | self.id_model = 'pretrained_models/model_ir_se50.pth' 35 | self.per_model = 'pretrained_models/vgg19-d01eb7cb.pth' 36 | 37 | self.rec_loss = True 38 | self.id_loss = True 39 | self.per_loss = True 40 | self.lambda_gan = 2 41 | self.lambda_rec = 10 42 | self.lambda_id = 2 43 | self.lambda_per = 0.002 44 | 45 | self.g_lr = 1e-3 46 | self.d_lr = 4e-3 47 | self.beta1 = 0.9 48 | self.beta2 = 0.999 -------------------------------------------------------------------------------- /utils/visualizer.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2019 NVIDIA Corporation. ALL rights reserved. 3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). 
4 | """ 5 | 6 | import os 7 | import time 8 | import subprocess 9 | 10 | from tensorboardX import SummaryWriter 11 | 12 | 13 | 14 | class Visualizer: 15 | def __init__(self,opt,mode='train'): 16 | self.opt = opt 17 | self.name = opt.name 18 | self.mode = mode 19 | self.train_log_dir = os.path.join(opt.checkpoint_path,"logs/%s"%mode) 20 | self.log_name = os.path.join(opt.checkpoint_path,'loss_log_%s.txt'%mode) 21 | if opt.local_rank == 0: 22 | if not os.path.exists(self.train_log_dir): 23 | os.makedirs(self.train_log_dir) 24 | 25 | self.train_writer = SummaryWriter(self.train_log_dir) 26 | 27 | self.log_file = open(self.log_name,"a") 28 | now = time.strftime("%c") 29 | self.log_file.write('================ Training Loss (%s) =================\n'%now) 30 | self.log_file.flush() 31 | 32 | 33 | # errors:dictionary of error labels and values 34 | def plot_current_errors(self,errors,step): 35 | 36 | for tag,value in errors.items(): 37 | 38 | self.train_writer.add_scalar("%s/"%self.name+tag,value,step) 39 | self.train_writer.flush() 40 | 41 | 42 | # errors: same format as |errors| of CurrentErrors 43 | def print_current_errors(self,epoch,i,errors,t): 44 | message = '(epoch: %d\t iters: %d\t time: %.5f)\t'%(epoch,i,t) 45 | for k,v in errors.items(): 46 | 47 | message += '%s: %.5f\t' %(k,v) 48 | 49 | print(message) 50 | 51 | self.log_file.write('%s\n' % message) 52 | self.log_file.flush() 53 | 54 | def display_current_results(self, visuals, step): 55 | if visuals is None: 56 | return 57 | for label, image in visuals.items(): 58 | # Write the image to a string 59 | 60 | self.train_writer.add_image("%s/"%self.name+label,image,global_step=step) 61 | 62 | def close(self): 63 | 64 | self.train_writer.close() 65 | self.log_file.close() -------------------------------------------------------------------------------- /process/download_and_process.py: -------------------------------------------------------------------------------- 1 | import subprocess 2 | import os 3 | import pdb 4 | from multiprocessing import Process 5 | import multiprocessing as mp 6 | import subprocess 7 | import numpy as np 8 | 9 | def download_video(q): 10 | k = 1 11 | while True: 12 | vid,save_base = q.get() 13 | if vid is None: 14 | break 15 | cmd = 'yt-dlp -f bestvideo[ext=mp4]+bestaudio[ext=m4a]/bestvideo+bestaudio \ 16 | https://www.youtube.com/watch?v={vid} \ 17 | --merge-output-format mp4 \ 18 | --output {save_base}/{vid}.mp4 \ 19 | --external-downloader aria2c \ 20 | --downloader-args aria2c:"-x 16 -k 1M"'.format(vid=vid,save_base=save_base) 21 | 22 | subprocess.call(cmd,shell=True) 23 | print('\rhave done %06d'%k,end='',flush=True) 24 | k += 1 25 | print() 26 | 27 | def get_frames(q): 28 | 29 | while True: 30 | path,save_path = q.get() 31 | if path is None: 32 | break 33 | save_base = os.path.split(save_path)[0] 34 | os.makedirs(save_base,exist_ok=True) 35 | with open(path,'r') as f: 36 | lines = f.readlines() 37 | filter_lines = list(filter(lambda x:x.startswith('0'),lines)) 38 | frames = list(map(lambda x:x.strip().split(),filter_lines)) 39 | 40 | np.save(save_path,frames) 41 | 42 | def read_file(base,save,q1,q2): 43 | for idname in os.listdir(base): 44 | idpath = os.path.join(base,idname) 45 | for videoname in os.listdir(idpath): 46 | q1.put([videoname,os.path.join(save,idname,videoname)]) 47 | videopath = os.path.join(idpath,videoname) 48 | 49 | for i,infoname in enumerate(os.listdir(videopath)): 50 | infopath = os.path.join(videopath,infoname) 51 | q2.put([infopath,os.path.join(save,idname,videoname,'%02d.npy'%i)]) 52 | 
53 | if __name__ == "__main__": 54 | base = '../dataset/vox2_test_txt/txt' 55 | save = '../dataset/voceleb2' 56 | mp.set_start_method('spawn') 57 | m = mp.Manager() 58 | queue1 = m.Queue() 59 | queue2 = m.Queue() 60 | 61 | read_p = Process(target=read_file,args=(base,save,queue1,queue2,)) 62 | download_p = Process(target=download_video,args=(queue1,)) 63 | frame_p = Process(target=get_frames,args=(queue2,)) 64 | 65 | read_p.start() 66 | download_p.start() 67 | frame_p.start() 68 | 69 | read_p.join() 70 | queue1.put([None,None]) 71 | queue2.put([None,None]) 72 | download_p.join() 73 | frame_p.join() 74 | -------------------------------------------------------------------------------- /dataloader/BlendLoader.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/python 2 | # -*- encoding: utf-8 -*- 3 | ''' 4 | @author Leslie 5 | @date 20220812 6 | ''' 7 | 8 | import os 9 | 10 | from torchvision import transforms 11 | import PIL.Image as Image 12 | from dataloader.DataLoader import DatasetBase 13 | import random 14 | import torchvision.transforms.functional as F 15 | 16 | 17 | class BlendData(DatasetBase): 18 | def __init__(self, slice_id=0, slice_count=1,dist=False, **kwargs): 19 | super().__init__(slice_id, slice_count,dist, **kwargs) 20 | 21 | 22 | self.transform = transforms.Compose([ 23 | transforms.Resize((kwargs['size'], kwargs['size'])), 24 | transforms.ToTensor() 25 | ]) 26 | self.color_fn = transforms.Compose([transforms.ColorJitter(0.5, 0.5, 0.5, 0.1)]) 27 | 28 | 29 | self.norm = transforms.Compose([transforms.Normalize([0.5,0.5,0.5],[0.5,0.5,0.5])]) 30 | 31 | self.gray = transforms.Compose([transforms.Grayscale(num_output_channels=1)]) 32 | 33 | # source root 34 | root = kwargs['root'] 35 | self.paths = [os.path.join(root,f) for f in os.listdir(root)] 36 | self.length = len(self.paths) 37 | random.shuffle(self.paths) 38 | self.eval = kwargs['eval'] 39 | 40 | 41 | 42 | def __getitem__(self,i): 43 | 44 | idx = i % self.length 45 | id_path = self.paths[idx] 46 | video_paths = [os.path.join(id_path,f) for f in os.listdir(id_path)] 47 | vIdx = random.randint(0, len(video_paths) - 1) 48 | video_path = video_paths[vIdx] 49 | img_paths = [os.path.join(video_path,f) for f in os.listdir(video_path)] 50 | img_idx = random.randint(0, len(img_paths) - 1) 51 | img_path = img_paths[img_idx] 52 | mask_path = img_path.replace('img','mask') 53 | 54 | idx = (i + random.randint(0,self.length-1)) % self.length 55 | id_path = self.paths[idx] 56 | video_paths = [os.path.join(id_path,f) for f in os.listdir(id_path)] 57 | vIdx = random.randint(0, len(video_paths) - 1) 58 | video_path = video_paths[vIdx] 59 | img_paths = [os.path.join(video_path,f) for f in os.listdir(video_path)] 60 | img_idx = random.randint(0, len(img_paths) - 1) 61 | ex_img_path = img_paths[img_idx] 62 | ex_mask_path = ex_img_path.replace('img','mask') 63 | 64 | 65 | with Image.open(img_path) as img: 66 | gt = self.transform(img.convert('RGB')) 67 | I_a = self.transform(self.color_fn(img.convert('RGB'))) 68 | gt = self.norm(gt) 69 | I_a = self.norm(I_a) 70 | with Image.open(mask_path) as img: 71 | M_a = self.transform(img.convert('L')) * 255 72 | 73 | 74 | I_gray = self.gray(I_a) 75 | if random.random() > 0.3: 76 | I_t = F.hflip(gt) 77 | M_t = F.hflip(M_a) 78 | 79 | else: 80 | I_t = gt 81 | M_t = M_a 82 | 83 | with Image.open(ex_img_path) as img: 84 | hat_t = self.transform(img.convert('RGB')) 85 | hat_t = self.norm(hat_t) 86 | with Image.open(ex_mask_path) as img: 87 | M_hat = 
self.transform(img.convert('L')) * 255 88 | 89 | return I_a,I_gray,I_t,hat_t,M_a,M_t,M_hat,gt 90 | 91 | def __len__(self): 92 | if self.eval: 93 | return 1000 94 | else: 95 | # return self.length 96 | return max(self.length,1000) 97 | 98 | -------------------------------------------------------------------------------- /model/AlignModule/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from collections import OrderedDict 4 | import torchvision 5 | import torch.nn.functional as F 6 | 7 | def compute_dis_loss(fake_pred,real_pred,D_loss): 8 | # d_real = torch.relu(1. - real_pred).mean() 9 | # d_fake = torch.relu(1. + fake_pred).mean() 10 | d_real = F.mse_loss(real_pred,torch.ones_like(real_pred)) 11 | d_fake = F.mse_loss(fake_pred, torch.zeros_like(fake_pred)) 12 | 13 | D_loss['d_real'] = d_real 14 | D_loss['d_fake'] = d_fake 15 | return d_real + d_fake 16 | 17 | def compute_gan_loss(fake_pred): 18 | 19 | return F.mse_loss(fake_pred,torch.ones_like(fake_pred)) 20 | 21 | def compute_id_loss(fake_id_f,real_id_f): 22 | return 1.0 - torch.cosine_similarity(fake_id_f,real_id_f, dim = 1) 23 | 24 | 25 | # Perceptual loss that uses a pretrained VGG network 26 | 27 | class Flatten(nn.Module): 28 | def __init__(self): 29 | super().__init__() 30 | 31 | def forward(self, x): 32 | return x.view(-1) 33 | 34 | class PerceptualLoss(nn.Module): 35 | def __init__(self,model_path, normalize_grad=False): 36 | super().__init__() 37 | 38 | self.normalize_grad = normalize_grad 39 | 40 | 41 | vgg_weights = torch.load(model_path) 42 | 43 | map = {'classifier.6.weight': u'classifier.7.weight', 'classifier.6.bias': u'classifier.7.bias'} 44 | vgg_weights = OrderedDict([(map[k] if k in map else k, v) for k, v in vgg_weights.items()]) 45 | 46 | model = torchvision.models.vgg19() 47 | model.classifier = nn.Sequential(Flatten(), *model.classifier._modules.values()) 48 | 49 | model.load_state_dict(vgg_weights) 50 | 51 | model = model.features 52 | 53 | mean = torch.tensor([103.939, 116.779, 123.680]) / 255. 54 | std = torch.tensor([1., 1., 1.]) / 255. 
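# (editor's note) The mean/std above are the Caffe-style VGG19 statistics: per-channel means
        # (103.939, 116.779, 123.680) with unit std, both expressed on a 0-255 scale and divided by
        # 255 so that normalize_inputs() can take tensors in [0, 1]; forward() first rescales the
        # [-1, 1] generator output to [0, 1] before normalizing, so the VGG features are computed on
        # mean-subtracted 0-255 inputs as in the original Caffe training setup.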
55 | 56 | num_layers = 30 57 | 58 | self.register_buffer('mean', mean[None, :, None, None]) 59 | self.register_buffer('std' , std[None, :, None, None]) 60 | 61 | layers_avg_pooling = [] 62 | 63 | for weights in model.parameters(): 64 | weights.requires_grad = False 65 | 66 | for module in model.modules(): 67 | if module.__class__.__name__ == 'Sequential': 68 | continue 69 | elif module.__class__.__name__ == 'MaxPool2d': 70 | layers_avg_pooling.append(nn.AvgPool2d(kernel_size=2, stride=2, padding=0)) 71 | else: 72 | layers_avg_pooling.append(module) 73 | 74 | if len(layers_avg_pooling) >= num_layers: 75 | break 76 | 77 | layers_avg_pooling = nn.Sequential(*layers_avg_pooling) 78 | 79 | self.model = layers_avg_pooling 80 | 81 | def normalize_inputs(self, x): 82 | return (x - self.mean) / self.std 83 | 84 | def forward(self, input, target): 85 | input = (input + 1) / 2 86 | target = (target.detach() + 1) / 2 87 | loss = 0 88 | features_input = self.normalize_inputs(input) 89 | features_target = self.normalize_inputs(target) 90 | 91 | for layer in self.model: 92 | features_input = layer(features_input) 93 | features_target = layer(features_target) 94 | 95 | if layer.__class__.__name__ == 'ReLU': 96 | if self.normalize_grad: 97 | pass 98 | else: 99 | loss = loss + F.l1_loss(features_input, features_target) 100 | 101 | return loss 102 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | ''' 2 | @author Leslie 3 | @date 20220812 4 | ''' 5 | import os 6 | 7 | import argparse 8 | 9 | 10 | from trainer.AlignTrainer import AlignTrainer 11 | from trainer.BlendTrainer import BlendTrainer 12 | import torch.distributed as dist 13 | from utils.utils import setup_seed,get_data_loader,merge_args 14 | from model.AlignModule.config import Params as AlignParams 15 | from model.BlendModule.config import Params as BlendParams 16 | 17 | # torch.multiprocessing.set_start_method('spawn') 18 | 19 | parser = argparse.ArgumentParser(description="HeSer") 20 | #---------train set------------------------------------- 21 | parser.add_argument('--model',default="align",help='') 22 | parser.add_argument('--isTrain',action="store_false",help='') 23 | parser.add_argument('--dist',action="store_false",help='') 24 | parser.add_argument('--batch_size',default=16,type=int) 25 | parser.add_argument('--seed',default=10,type=int) 26 | parser.add_argument('--eval',default=1,type=int,help='whether use eval') 27 | parser.add_argument('--nDataLoaderThread',default=5,type=int,help='Num of loader threads') 28 | parser.add_argument('--print_interval',default=100,type=int) 29 | parser.add_argument('--test_interval',default=100,type=int,help='Test and save every [test_intervaal] epochs') 30 | parser.add_argument('--save_interval',default=100,type=int,help='save model interval') 31 | parser.add_argument('--stop_interval',default=20,type=int) 32 | parser.add_argument('--begin_it',default=0,type=int,help='begin epoch') 33 | parser.add_argument('--mx_data_length',default=100,type=int,help='max data length') 34 | parser.add_argument('--max_epoch',default=10000,type=int) 35 | parser.add_argument('--early_stop',action="store_true",help='') 36 | parser.add_argument('--scratch',action="store_true",help='') 37 | #---------path set-------------------------------------- 38 | parser.add_argument('--checkpoint_path',default='checkpoint',type=str) 39 | parser.add_argument('--pretrain_path',default=None,type=str) 40 | 41 | # 
------optimizer set-------------------------------------- 42 | parser.add_argument('--lr',default=0.002,type=float,help="Learning rate") 43 | 44 | parser.add_argument( 45 | '--local_rank', 46 | type=int, 47 | default=0, 48 | help='Local rank passed from distributed launcher' 49 | ) 50 | 51 | args = parser.parse_args() 52 | 53 | def train_net(args): 54 | train_loader,test_loader,mx_length = get_data_loader(args) 55 | 56 | args.mx_data_length = mx_length 57 | if args.model == 'align': 58 | trainer = AlignTrainer(args) 59 | elif args.model == 'blend': 60 | trainer = BlendTrainer(args) 61 | 62 | 63 | trainer.train_network(train_loader,test_loader) 64 | 65 | if __name__ == "__main__": 66 | 67 | args = parser.parse_args() 68 | 69 | if args.model == 'align': 70 | params = AlignParams() 71 | 72 | elif args.model == 'blend': 73 | params = BlendParams() 74 | args = merge_args(args,params) 75 | if args.dist: 76 | dist.init_process_group(backend="nccl") # backbend='nccl' 77 | dist.barrier() # 用于同步训练 78 | args.world_size = dist.get_world_size() # 一共有几个节点 79 | args.rank = dist.get_rank() # 当前节点编号 80 | 81 | else: 82 | args.world_size = 1 83 | args.rank = 0 84 | 85 | setup_seed(args.seed+args.rank) 86 | print(args) 87 | 88 | args.checkpoint_path = os.path.join(args.checkpoint_path,args.name) 89 | 90 | print("local_rank %d | rank %d | world_size: %d"%(int(os.environ.get('LOCAL_RANK','0')),args.rank,args.world_size)) 91 | if args.rank == 0 : 92 | if not os.path.exists(args.checkpoint_path): 93 | os.makedirs(args.checkpoint_path) 94 | print("make dir: ",args.checkpoint_path) 95 | train_net(args) 96 | 97 | 98 | 99 | -------------------------------------------------------------------------------- /process/process_raw_video.py: -------------------------------------------------------------------------------- 1 | import os 2 | import imageio 3 | import numpy as np 4 | from face_alignment.detection.sfd import FaceDetector 5 | import face_alignment 6 | import torch 7 | import math 8 | import cv2 9 | import pdb 10 | from multiprocessing import Process 11 | import multiprocessing as mp 12 | from process_utils import * 13 | import pdb 14 | 15 | def get_video_info(base,save_base,q): 16 | 17 | for idname in os.listdir(base): 18 | idpath = os.path.join(base,idname) 19 | save_path = os.path.join(save_base,idname) 20 | for videoname in os.listdir(idpath): 21 | videopath = os.path.join(idpath,videoname) 22 | frame_names = [os.path.join(videopath,f) for f in os.listdir(videopath) if f.endswith('.mp4')] 23 | info_names = [os.path.join(videopath,f) for f in os.listdir(videopath) if f.endswith('.npy')] 24 | if len(frame_names) == 0: 25 | continue 26 | q.put([frame_names[0],info_names,save_path,videoname]) 27 | 28 | 29 | 30 | def process_frame(q1,align=True,scale=1.8,size=512): 31 | face_detector = FaceDetector(device='cuda') 32 | lmk_detector = face_alignment.FaceAlignment(face_alignment.LandmarksType._3D, flip_input=False) 33 | kk = 1 34 | def detect_faces(images): 35 | images = np.stack(images).transpose(0,3,1,2).astype(np.float32) 36 | images_torch = torch.tensor(images) 37 | return face_detector.detect_from_batch(images_torch.cuda()) 38 | 39 | while True: 40 | 41 | frame_path,info_names,save_base,videoname = q1.get() 42 | if frame_path is None: 43 | break 44 | video_reader = imageio.get_reader(frame_path) 45 | for k,info_name in enumerate(info_names): 46 | info = np.load(info_name) 47 | save_path = os.path.join(save_base,'%s-%04d'%(videoname,k)) 48 | os.makedirs(save_path,exist_ok=True) 49 | for (i,x,y,w,h) in info: 50 | 
try: 51 | img = video_reader.get_data(int(i)) 52 | height,width,_ = img.shape 53 | x,y,w,h = list(map(lambda x:float(x),[x,y,w,h])) 54 | i = int(i) 55 | box = [x*width,y*height,(x+w)*width,(y+h)*height] 56 | if os.path.exists(os.path.join(save_path,'%04d.png'%i)): 57 | continue 58 | 59 | bboxes = detect_faces([img])[0] 60 | 61 | bbox = choose_one_detection(bboxes,box) 62 | if bbox is None: 63 | continue 64 | bbox = bbox[:4] 65 | landmarks = lmk_detector.get_landmarks_from_image(img[...,::-1], [bbox])[0] 66 | image_cropped,_ = crop_with_padding(img,landmarks[:,:2],scale=scale,size=size,align=align) 67 | 68 | cv2.imwrite(os.path.join(save_path,'%04d.png'%i),image_cropped[...,::-1]) 69 | print('\r have done %06d'%i,end='',flush=True) 70 | kk += 1 71 | except: 72 | continue 73 | 74 | video_reader.close() 75 | print() 76 | 77 | 78 | if __name__ == "__main__": 79 | 80 | mp.set_start_method('spawn') 81 | m = mp.Manager() 82 | q1 = m.Queue() 83 | base = '../dataset/voceleb2' 84 | save_base = '../dataset/process' 85 | process_num = 2 86 | 87 | info_p = Process(target=get_video_info,args=(base,save_base,q1,)) 88 | 89 | 90 | process_list = [] 91 | for _ in range(process_num): 92 | process_list.append(Process(target=process_frame,args=(q1,))) 93 | 94 | 95 | info_p.start() 96 | for p in process_list: 97 | p.start() 98 | 99 | info_p.join() 100 | 101 | for _ in range(process_num*2): 102 | q1.put([None,None,None,None]) 103 | for p in process_list: 104 | p.join() 105 | -------------------------------------------------------------------------------- /model/third/resnet.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # -*- encoding: utf-8 -*- 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | import torch.utils.model_zoo as modelzoo 8 | 9 | # from modules.bn import InPlaceABNSync as BatchNorm2d 10 | 11 | resnet18_url = 'https://download.pytorch.org/models/resnet18-5c106cde.pth' 12 | 13 | 14 | def conv3x3(in_planes, out_planes, stride=1): 15 | """3x3 convolution with padding""" 16 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 17 | padding=1, bias=False) 18 | 19 | 20 | class BasicBlock(nn.Module): 21 | def __init__(self, in_chan, out_chan, stride=1): 22 | super(BasicBlock, self).__init__() 23 | self.conv1 = conv3x3(in_chan, out_chan, stride) 24 | self.bn1 = nn.BatchNorm2d(out_chan) 25 | self.conv2 = conv3x3(out_chan, out_chan) 26 | self.bn2 = nn.BatchNorm2d(out_chan) 27 | self.relu = nn.ReLU(inplace=True) 28 | self.downsample = None 29 | if in_chan != out_chan or stride != 1: 30 | self.downsample = nn.Sequential( 31 | nn.Conv2d(in_chan, out_chan, 32 | kernel_size=1, stride=stride, bias=False), 33 | nn.BatchNorm2d(out_chan), 34 | ) 35 | 36 | def forward(self, x): 37 | residual = self.conv1(x) 38 | residual = F.relu(self.bn1(residual)) 39 | residual = self.conv2(residual) 40 | residual = self.bn2(residual) 41 | 42 | shortcut = x 43 | if self.downsample is not None: 44 | shortcut = self.downsample(x) 45 | 46 | out = shortcut + residual 47 | out = self.relu(out) 48 | return out 49 | 50 | 51 | def create_layer_basic(in_chan, out_chan, bnum, stride=1): 52 | layers = [BasicBlock(in_chan, out_chan, stride=stride)] 53 | for i in range(bnum-1): 54 | layers.append(BasicBlock(out_chan, out_chan, stride=1)) 55 | return nn.Sequential(*layers) 56 | 57 | 58 | class Resnet18(nn.Module): 59 | def __init__(self): 60 | super(Resnet18, self).__init__() 61 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, 
stride=2, padding=3, 62 | bias=False) 63 | self.bn1 = nn.BatchNorm2d(64) 64 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 65 | self.layer1 = create_layer_basic(64, 64, bnum=2, stride=1) 66 | self.layer2 = create_layer_basic(64, 128, bnum=2, stride=2) 67 | self.layer3 = create_layer_basic(128, 256, bnum=2, stride=2) 68 | self.layer4 = create_layer_basic(256, 512, bnum=2, stride=2) 69 | self.init_weight() 70 | 71 | def forward(self, x): 72 | x = self.conv1(x) 73 | x = F.relu(self.bn1(x)) 74 | x = self.maxpool(x) 75 | 76 | x = self.layer1(x) 77 | feat8 = self.layer2(x) # 1/8 78 | feat16 = self.layer3(feat8) # 1/16 79 | feat32 = self.layer4(feat16) # 1/32 80 | return feat8, feat16, feat32 81 | 82 | def init_weight(self): 83 | state_dict = modelzoo.load_url(resnet18_url) 84 | self_state_dict = self.state_dict() 85 | for k, v in state_dict.items(): 86 | if 'fc' in k: continue 87 | self_state_dict.update({k: v}) 88 | self.load_state_dict(self_state_dict) 89 | 90 | def get_params(self): 91 | wd_params, nowd_params = [], [] 92 | for name, module in self.named_modules(): 93 | if isinstance(module, (nn.Linear, nn.Conv2d)): 94 | wd_params.append(module.weight) 95 | if not module.bias is None: 96 | nowd_params.append(module.bias) 97 | elif isinstance(module, nn.BatchNorm2d): 98 | nowd_params += list(module.parameters()) 99 | return wd_params, nowd_params 100 | 101 | 102 | if __name__ == "__main__": 103 | net = Resnet18() 104 | x = torch.randn(16, 3, 224, 224) 105 | out = net(x) 106 | print(out[0].size()) 107 | print(out[1].size()) 108 | print(out[2].size()) 109 | net.get_params() 110 | -------------------------------------------------------------------------------- /dataloader/AlignLoader.py: -------------------------------------------------------------------------------- 1 | #! 
/usr/bin/python 2 | # -*- encoding: utf-8 -*- 3 | ''' 4 | @author Leslie 5 | @date 20220812 6 | ''' 7 | 8 | import os 9 | 10 | from torchvision import transforms 11 | import PIL.Image as Image 12 | from dataloader.DataLoader import DatasetBase 13 | import random 14 | import math 15 | import torch 16 | from dataloader.augmentation import ParametricAugmenter 17 | import numpy as np 18 | 19 | 20 | class AlignData(DatasetBase): 21 | def __init__(self, slice_id=0, slice_count=1,dist=False, **kwargs): 22 | super().__init__(slice_id, slice_count,dist, **kwargs) 23 | 24 | 25 | self.transform = transforms.Compose([ 26 | transforms.Resize((kwargs['size'], kwargs['size'])), 27 | transforms.RandomHorizontalFlip(), 28 | transforms.ToTensor() 29 | ]) 30 | 31 | self.aug_fn = ParametricAugmenter(use_pixelwise_augs=kwargs['use_pixelwise_augs'], 32 | use_affine_scale=kwargs['use_affine_scale'], 33 | use_affine_shift=kwargs['use_affine_shift']) 34 | 35 | 36 | self.norm = transforms.Compose([transforms.Normalize([0.5,0.5,0.5],[0.5,0.5,0.5])]) 37 | 38 | self.resize = transforms.Compose([ 39 | transforms.Resize((256,256))]) 40 | 41 | # source root 42 | root = kwargs['root'] 43 | self.paths = [os.path.join(root,f) for f in os.listdir(root)] 44 | # dis = math.floor(len(self.paths)/self.count) 45 | # self.paths = self.paths[self.id*dis:(self.id+1)*dis] 46 | self.length = len(self.paths) 47 | random.shuffle(self.paths) 48 | self.frame_num = kwargs['frame_num'] 49 | self.skip_frame = kwargs['skip_frame'] 50 | self.eval = kwargs['eval'] 51 | self.size = kwargs['size'] 52 | self.scale = 0.4 / 1.8 53 | 54 | 55 | 56 | def __getitem__(self,i): 57 | 58 | idx = i % self.length 59 | id_path = self.paths[idx] 60 | video_paths = [os.path.join(id_path,f) for f in os.listdir(id_path)] 61 | vIdx = random.randint(0, len(video_paths) - 1) 62 | video_path = video_paths[vIdx] 63 | img_paths = [os.path.join(video_path,f) for f in os.listdir(video_path)] 64 | begin_idx = random.randint(0, len(img_paths) - self.frame_num*self.skip_frame - 1) 65 | img_paths = [img_paths[i] 66 | for i in range(begin_idx,begin_idx+self.frame_num*self.skip_frame,self.skip_frame)] 67 | 68 | s_img_paths = img_paths[:-1] 69 | 70 | t_img_path = img_paths[-1] 71 | 72 | xs = [] 73 | for img_path in s_img_paths: 74 | with Image.open(img_path) as img: 75 | xs.append(self.norm(self.transform(img.convert('RGB'))).unsqueeze(0)) 76 | xs = torch.cat(xs,0) 77 | 78 | if self.eval: 79 | idx = (i + random.randint(0,self.length-1)) % self.length 80 | id_path = self.paths[idx] 81 | video_paths = [os.path.join(id_path,f) for f in os.listdir(id_path)] 82 | vIdx = random.randint(0, len(video_paths) - 1) 83 | video_path = video_paths[vIdx] 84 | img_paths = [os.path.join(video_path,f) for f in os.listdir(video_path)] 85 | t_img_path = img_paths[random.randint(0, len(img_paths) - 1)] 86 | 87 | with Image.open(t_img_path) as img: 88 | xt = self.transform(img.convert('RGB')) 89 | 90 | mask = np.zeros((self.size,self.size,3),dtype=np.uint8) 91 | mask[int(self.size*self.scale):int(-self.size*self.scale), 92 | int(self.size*self.scale):int(-self.size*self.scale)] = 255 93 | mask = mask[np.newaxis,:] 94 | xt,gt,mask = self.aug_fn.augment_triple(xt,xt,mask) 95 | indexs = torch.where(mask==1) 96 | top = indexs[1].min() 97 | bottom = indexs[1].max() 98 | left = indexs[2].min() 99 | right = indexs[2].max() 100 | crop_xt = xt[...,top:bottom,left:right] 101 | crop_xt = self.norm(crop_xt) 102 | xt = self.norm(xt) 103 | gt = self.norm(gt) 104 | 105 | return 
self.resize(xs),self.resize(xt),self.resize(crop_xt),gt 106 | 107 | 108 | def __len__(self): 109 | if self.eval: 110 | return 1000 111 | else: 112 | # return self.length 113 | return max(self.length,1000) 114 | 115 | -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | 2 | ''' 3 | @author Leslie 4 | @date 20220812 5 | ''' 6 | import torch 7 | from dataloader.AlignLoader import AlignData 8 | from dataloader.BlendLoader import BlendData 9 | 10 | 11 | def requires_grad(model, flag=True): 12 | if model is None: 13 | return 14 | for p in model.parameters(): 15 | p.requires_grad = flag 16 | def need_grad(x): 17 | x = x.detach() 18 | x.requires_grad_() 19 | return x 20 | 21 | def init_weights(m,init_type='normal', gain=0.02): 22 | 23 | classname = m.__class__.__name__ 24 | if classname.find('BatchNorm2d') != -1: 25 | if hasattr(m, 'weight') and m.weight is not None: 26 | torch.nn.init.normal_(m.weight.data, 1.0, gain) 27 | if hasattr(m, 'bias') and m.bias is not None: 28 | torch.nn.init.constant_(m.bias.data, 0.0) 29 | elif hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1): 30 | if init_type == 'normal': 31 | torch.nn.init.normal_(m.weight.data, 0.0, gain) 32 | elif init_type == 'xavier': 33 | torch.nn.init.xavier_normal_(m.weight.data, gain=gain) 34 | elif init_type == 'xavier_uniform': 35 | torch.nn.init.xavier_uniform_(m.weight.data, gain=1.0) 36 | elif init_type == 'kaiming': 37 | torch.nn.init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') 38 | elif init_type == 'orthogonal': 39 | torch.nn.init.orthogonal_(m.weight.data, gain=gain) 40 | elif init_type == 'none': # uses pytorch's default init method 41 | m.reset_parameters() 42 | else: 43 | raise NotImplementedError('initialization method [%s] is not implemented' % init_type) 44 | if hasattr(m, 'bias') and m.bias is not None: 45 | torch.nn.init.constant_(m.bias.data, 0.0) 46 | def accumulate(model1, model2, decay=0.999): 47 | par1 = dict(model1.named_parameters()) 48 | par2 = dict(model2.named_parameters()) 49 | 50 | for k in par1.keys(): 51 | par1[k].data.mul_(decay).add_(par2[k].data, alpha=1 - decay) 52 | def setup_seed(seed): 53 | torch.manual_seed(seed) 54 | if torch.cuda.is_available(): 55 | torch.cuda.manual_seed_all(seed) 56 | torch.backends.cudnn.deterministic = True 57 | 58 | def get_data_loader(args): 59 | if args.model == 'align': 60 | train_data = AlignData(dist=args.dist, 61 | size=args.size, 62 | root=args.train_root, 63 | frame_num=args.frame_num, 64 | skip_frame=args.skip_frame, 65 | use_pixelwise_augs=args.use_pixelwise_augs, 66 | use_affine_scale=args.use_affine_scale, 67 | use_affine_shift=args.use_affine_shift, 68 | eval=False) 69 | 70 | test_data = AlignData(dist=args.dist, 71 | size=args.size, 72 | root=args.val_root, 73 | frame_num=args.frame_num, 74 | skip_frame=args.skip_frame, 75 | use_pixelwise_augs=False, 76 | use_affine_scale=False, 77 | use_affine_shift=False, 78 | eval=True) 79 | 80 | elif args.model == 'blend': 81 | train_data = BlendData(dist=args.dist, 82 | size=args.size, 83 | root=args.train_root,eval=False) 84 | 85 | test_data = BlendData(dist=args.dist, 86 | size=args.size, 87 | root=args.val_root,eval=True) 88 | 89 | 90 | train_loader = torch.utils.data.DataLoader( 91 | train_data, 92 | batch_size=args.batch_size, 93 | num_workers=args.nDataLoaderThread, 94 | pin_memory=False, 95 | drop_last=True 96 | ) 97 | test_loader = None if 
test_data is None else \ 98 | torch.utils.data.DataLoader( 99 | test_data, 100 | batch_size=args.batch_size, 101 | num_workers=args.nDataLoaderThread, 102 | pin_memory=False, 103 | drop_last=True 104 | ) 105 | return train_loader,test_loader,len(train_data) 106 | 107 | 108 | 109 | def merge_args(args,params): 110 | for k,v in vars(params).items(): 111 | setattr(args,k,v) 112 | return args 113 | 114 | def convert_img(img,unit=False): 115 | 116 | img = (img + 1) * 0.5 117 | if unit: 118 | return torch.clamp(img*255+0.5,0,255) 119 | 120 | return torch.clamp(img,0,1) -------------------------------------------------------------------------------- /model/AlignModule/lib/generator.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn.utils import spectral_norm 4 | from model.AlignModule.lib import blocks 5 | 6 | import math 7 | 8 | # heavily copy from https://github.com/shrubb/latent-pose-reenactment 9 | 10 | class Constant(nn.Module): 11 | def __init__(self, *shape): 12 | super().__init__() 13 | self.constant = nn.Parameter(torch.ones(1, *shape)) 14 | 15 | def forward(self, batch_size): 16 | return self.constant.expand((batch_size,) + self.constant.shape[1:]) 17 | 18 | 19 | class Generator(nn.Module): 20 | def __init__(self, args): 21 | super().__init__() 22 | 23 | def get_res_block(in_channels, out_channels, padding, norm_layer): 24 | return blocks.ResBlock(in_channels, out_channels, padding, upsample=False, downsample=False, 25 | norm_layer=norm_layer) 26 | 27 | def get_up_block(in_channels, out_channels, padding, norm_layer): 28 | return blocks.ResBlock(in_channels, out_channels, padding, upsample=True, downsample=False, 29 | norm_layer=norm_layer) 30 | 31 | if args.padding == 'zero': 32 | padding = nn.ZeroPad2d 33 | elif args.padding == 'reflection': 34 | padding = nn.ReflectionPad2d 35 | else: 36 | raise Exception('Incorrect `padding` argument, required `zero` or `reflection`') 37 | 38 | assert math.log2(args.output_image_size / args.gen_constant_input_size).is_integer(), \ 39 | "`gen_constant_input_size` must be `image_size` divided by a power of 2" 40 | num_upsample_blocks = int(math.log2(args.output_image_size / args.gen_constant_input_size)) 41 | out_channels_block_nonclamped = args.num_channels * (2 ** num_upsample_blocks) 42 | out_channels_block = min(out_channels_block_nonclamped, args.max_num_channels) 43 | 44 | self.constant = Constant(out_channels_block, args.gen_constant_input_size, args.gen_constant_input_size) 45 | 46 | 47 | # Decoder 48 | layers = [] 49 | for i in range(args.gen_num_residual_blocks): 50 | layers.append(get_res_block(out_channels_block, out_channels_block, padding, 'ada' + args.norm_layer)) 51 | 52 | for _ in range(num_upsample_blocks): 53 | in_channels_block = out_channels_block 54 | out_channels_block_nonclamped //= 2 55 | out_channels_block = min(out_channels_block_nonclamped, args.max_num_channels) 56 | layers.append(get_up_block(in_channels_block, out_channels_block, padding, 'ada' + args.norm_layer)) 57 | 58 | layers.extend([ 59 | blocks.AdaptiveNorm2d(out_channels_block, args.norm_layer), 60 | nn.ReLU(True), 61 | # padding(1), 62 | spectral_norm( 63 | nn.Conv2d(out_channels_block, args.out_channels, 3, 1, 1), 64 | eps=1e-4), 65 | nn.Tanh() 66 | ]) 67 | self.decoder_blocks = nn.Sequential(*layers) 68 | 69 | self.adains = [module for module in self.modules() if module.__class__.__name__ == 'AdaptiveNorm2d'] 70 | 71 | 72 | joint_embedding_size = 
args.identity_embedding_size + args.pose_embedding_size + args.por_embedding_size + args.exp_embedding_size 73 | self.affine_params_projector = nn.Sequential( 74 | spectral_norm(nn.Linear(joint_embedding_size, max(joint_embedding_size, 512))), 75 | nn.ReLU(True), 76 | spectral_norm(nn.Linear(max(joint_embedding_size, 512), self.get_num_affine_params())) 77 | ) 78 | 79 | 80 | def get_num_affine_params(self): 81 | return sum(2*module.num_features for module in self.adains) 82 | 83 | def assign_affine_params(self, affine_params): 84 | for m in self.modules(): 85 | if m.__class__.__name__ == "AdaptiveNorm2d": 86 | new_bias = affine_params[:, :m.num_features] 87 | new_weight = affine_params[:, m.num_features:2 * m.num_features] 88 | 89 | if m.bias is None: # to keep m.bias being `nn.Parameter` 90 | m.bias = new_bias.contiguous() 91 | else: 92 | m.bias.copy_(new_bias) 93 | 94 | if m.weight is None: # to keep m.weight being `nn.Parameter` 95 | m.weight = new_weight.contiguous() 96 | else: 97 | m.weight.copy_(new_weight) 98 | 99 | if affine_params.size(1) > 2 * m.num_features: 100 | affine_params = affine_params[:, 2 * m.num_features:] 101 | 102 | def assign_embeddings(self, por,id,pose,exp): 103 | 104 | joint_embedding = torch.cat((por,id,pose,exp), dim=1) 105 | 106 | affine_params = self.affine_params_projector(joint_embedding) 107 | self.assign_affine_params(affine_params) 108 | 109 | 110 | def forward(self, por,id,pose,exp): 111 | self.assign_embeddings(por,id,pose,exp) 112 | 113 | batch_size = len(por) 114 | output = self.decoder_blocks(self.constant(batch_size)) 115 | 116 | return output 117 | -------------------------------------------------------------------------------- /inference.py: -------------------------------------------------------------------------------- 1 | from model.AlignModule.lib import * 2 | from model.BlendModule.generator import Generator as Decoder 3 | from model.AlignModule.config import Params as AlignParams 4 | from model.BlendModule.config import Params as BlendParams 5 | from trainer.AlignTrainer import AlignTrainer 6 | from model.third.model import BiSeNet 7 | import torchvision.transforms.functional as TF 8 | import torch.nn.functional as F 9 | import torch 10 | import cv2 11 | import numpy as np 12 | import pdb 13 | 14 | class Infer: 15 | def __init__(self,align_path,blend_path,parsing_path): 16 | align_params = AlignParams() 17 | blend_params = BlendParams() 18 | self.device = 'cpu' 19 | if torch.cuda.is_available(): 20 | self.device = 'cuda' 21 | 22 | self.parsing = BiSeNet(n_classes=19).to(self.device) 23 | self.Epor = PorEncoder(align_params).to(self.device) 24 | self.Eid = IDEncoder(align_params.id_model).to(self.device) 25 | self.Epose = PoseEncoder(align_params).to(self.device) 26 | self.Eexp = ExpEncoder(align_params).to(self.device) 27 | self.netG = Generator(align_params).to(self.device) 28 | self.decoder = Decoder(blend_params).to(self.device) 29 | 30 | self.loadModel(align_path,blend_path,parsing_path) 31 | self.eval_model(self.Epor,self.Eid,self.Epose,self.Eexp,self.netG,self.decoder,self.parsing) 32 | self.mean =torch.tensor([0.485, 0.456, 0.406], dtype=torch.float32) 33 | self.std = torch.tensor([0.229, 0.224, 0.225], dtype=torch.float32) 34 | 35 | def run(self,tgt_img_path,src_img_paths): 36 | 37 | tgt_img = cv2.imread(tgt_img_path) 38 | tgt_inp = self.preprocess(tgt_img) 39 | 40 | src_img = cv2.imread(src_img_paths[0]) 41 | 42 | src_inp = self.preprocess_multi(src_img_paths) 43 | 44 | gen = self.forward(src_inp,tgt_inp) 45 | gen = 
self.postprocess(gen[0]) 46 | cat_img = np.concatenate([cv2.resize(src_img,[512,512]), 47 | gen,cv2.resize(tgt_img,[512,512])],1) 48 | return cat_img 49 | 50 | def forward(self,xs,xt): 51 | with torch.no_grad(): 52 | por_f = self.Epor(xs) 53 | id_f = self.Eid(AlignTrainer.process_id_input(xs,crop=True)) 54 | 55 | pose_f = self.Epose(F.adaptive_avg_pool2d(xt,256)) 56 | exp_f = self.Eexp(AlignTrainer.process_id_input(xt,crop=True,size=256)) 57 | 58 | xg = self.netG(por_f,id_f,pose_f,exp_f) 59 | 60 | M_a = self.parsing(self.preprocess_parsing(xg)) 61 | M_t = self.parsing(self.preprocess_parsing(xt)) 62 | 63 | M_a = self.postprocess_parsing(M_a) 64 | M_t = self.postprocess_parsing(M_t) 65 | xg_gray = TF.rgb_to_grayscale(xg,num_output_channels=1) 66 | fake = self.decoder(xg,xg_gray,xt,M_a,M_t,xt,train=False) 67 | 68 | return fake 69 | 70 | def preprocess(self,x): 71 | if isinstance(x,str): 72 | x = cv2.imread(x) 73 | x = cv2.resize(x,[512,512]) 74 | x = (x[...,::-1].transpose(2,0,1)[np.newaxis,:] / 255 - 0.5) * 2 75 | return torch.from_numpy(x.astype(np.float32)).to(self.device) 76 | 77 | def preprocess_multi(self,xs): 78 | x_list = [] 79 | for x in xs: 80 | x = cv2.imread(x) 81 | x = cv2.resize(x,[512,512]) 82 | x_list.append((x[...,::-1].transpose(2,0,1)[np.newaxis,:] / 255 - 0.5) * 2) 83 | x_list = np.concatenate(x_list,0) 84 | return torch.from_numpy(x_list.astype(np.float32)).to(self.device).unsqueeze(0) 85 | 86 | def postprocess(self,x): 87 | return (x.permute(1,2,0).cpu().numpy()[...,::-1] + 1) * 127.5 88 | 89 | def preprocess_parsing(self,x): 90 | 91 | return ((x+1)/2.0 - self.mean.view(1,-1,1,1).to(self.device)) / \ 92 | self.std.view(1,-1,1,1).to(self.device) 93 | 94 | def postprocess_parsing(self,x): 95 | return torch.argmax(x[0],1).unsqueeze(1).float() 96 | 97 | 98 | 99 | def loadModel(self,align_path,blend_path,parsing_path): 100 | ckpt = torch.load(align_path, map_location=lambda storage, loc: storage) 101 | self.netG.load_state_dict(ckpt['G'],strict=False) 102 | self.Eexp.load_state_dict(ckpt['Eexp'],strict=False) 103 | self.Eid.load_state_dict(ckpt['Eid'],strict=False) 104 | self.Epor.load_state_dict(ckpt['Epor'],strict=False) 105 | 106 | ckpt = torch.load(blend_path, map_location=lambda storage, loc: storage) 107 | self.decoder.load_state_dict(ckpt['G'],strict=False) 108 | 109 | self.parsing.load_state_dict(torch.load(parsing_path)) 110 | 111 | 112 | def eval_model(self,*args): 113 | for arg in args: 114 | arg.eval() 115 | 116 | if __name__ == "__main__": 117 | model = Infer('checkpoint/Aligner/417-00000000.pth', 118 | 'checkpoint/Blender/073-00000000.pth', 119 | 'pretrained_models/parsing.pth') 120 | 121 | src_path_list = ['dataset/select-align/img/id00061/2XrRfyv-EmE-0001/2122.png', 122 | 'dataset/select-align/img/id00061/2XrRfyv-EmE-0001/2125.png', 123 | 'dataset/select-align/img/id00061/2XrRfyv-EmE-0001/2130.png', 124 | 'dataset/select-align/img/id00061/2XrRfyv-EmE-0001/2135.png', 125 | 'dataset/select-align/img/id00061/2XrRfyv-EmE-0001/2140.png'] 126 | tgt_path = 'dataset/select-align/img/id00061/4kSyBHethpE-0002/2055.png' 127 | oup = model.run(tgt_path,src_path_list) 128 | 129 | cv2.imwrite('2.png',oup) -------------------------------------------------------------------------------- /model/AlignModule/discriminator.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | from model.AlignModule.lib import blocks 3 | from torch.nn.utils import spectral_norm 4 | import math 5 | import torch 6 | # heavily copy from 
https://github.com/shrubb/latent-pose-reenactment 7 | 8 | # class Discriminator(nn.Module): 9 | # def __init__(self, args): 10 | # super().__init__() 11 | 12 | # def get_down_block(in_channels, out_channels, padding): 13 | # return blocks.ResBlock(in_channels, out_channels, padding, upsample=False, downsample=True, 14 | # norm_layer='none') 15 | 16 | # def get_res_block(in_channels, out_channels, padding): 17 | # return blocks.ResBlock(in_channels, out_channels, padding, upsample=False, downsample=False, 18 | # norm_layer='none') 19 | 20 | # if args.padding == 'zero': 21 | # padding = nn.ZeroPad2d 22 | # elif args.padding == 'reflection': 23 | # padding = nn.ReflectionPad2d 24 | 25 | # self.out_channels = args.embed_channels 26 | 27 | # self.down_block = nn.Sequential( 28 | # # padding(1), 29 | # spectral_norm( 30 | # nn.Conv2d(args.in_channels, args.num_channels, 3, 1, 1), 31 | # eps=1e-4), 32 | # nn.ReLU(), 33 | # # padding(1), 34 | # spectral_norm( 35 | # nn.Conv2d(args.num_channels, args.num_channels, 3, 1, 1), 36 | # eps=1e-4), 37 | # nn.AvgPool2d(2)) 38 | # self.skip = nn.Sequential( 39 | # spectral_norm( 40 | # nn.Conv2d(args.in_channels, args.num_channels, 1), 41 | # eps=1e-4), 42 | # nn.AvgPool2d(2)) 43 | 44 | # self.blocks = nn.ModuleList() 45 | # num_down_blocks = min(int(math.log(args.output_image_size, 2)) - 2, args.dis_num_blocks) 46 | # in_channels = args.num_channels 47 | # for i in range(1, num_down_blocks): 48 | # out_channels = min(in_channels * 2, args.max_num_channels) 49 | # if i == args.dis_num_blocks - 1: out_channels = self.out_channels 50 | # self.blocks.append(get_down_block(in_channels, out_channels, padding)) 51 | # in_channels = out_channels 52 | # for i in range(num_down_blocks, args.dis_num_blocks): 53 | # if i == args.dis_num_blocks - 1: out_channels = self.out_channels 54 | # self.blocks.append(get_res_block(in_channels, out_channels, padding)) 55 | 56 | # self.linear = spectral_norm(nn.Linear(self.out_channels, 1), eps=1e-4) 57 | 58 | 59 | # def forward(self, input): 60 | 61 | # feats = [] 62 | 63 | # out = self.down_block(input) 64 | # out = out + self.skip(input) 65 | # feats.append(out) 66 | # for block in self.blocks: 67 | # out = block(out) 68 | # feats.append(out) 69 | # out = torch.relu(out) 70 | # out = out.view(out.shape[0], self.out_channels, -1).sum(2) 71 | # out_linear = self.linear(out)[:, 0] 72 | # return out_linear,feats 73 | 74 | from torch.nn.utils import spectral_norm 75 | 76 | # PyTorch implementation by vinesmsuic 77 | # Referenced from official tensorflow implementation: https://github.com/SystemErrorWang/White-box-Cartoonization/blob/master/train_code/network.py 78 | # slim.convolution2d uses constant padding (zeros). 
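# (editor's note) The Discriminator defined below is a fully-convolutional, PatchGAN-style
# classifier: four stride-2 blocks downsample the input and a final 1x1 convolution emits a
# spatial map of real/fake logits rather than a single scalar; compute_dis_loss / compute_gan_loss
# in model/AlignModule/loss.py then apply mse_loss against all-ones / all-zeros targets over that
# map (an LSGAN-style objective).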
79 | # Paper used spectral_norm 80 | 81 | class Block(nn.Module): 82 | def __init__(self, in_channels, out_channels, kernel_size, stride, padding,activate=True): 83 | super().__init__() 84 | self.sn_conv = spectral_norm(nn.Conv2d( 85 | in_channels, 86 | out_channels, 87 | kernel_size, 88 | stride, 89 | padding, 90 | padding_mode="zeros" # Author's code used slim.convolution2d, which is using SAME padding (zero padding in pytorch) 91 | )) 92 | self.activate = activate 93 | if self.activate: 94 | self.LReLU = nn.LeakyReLU(negative_slope=0.2, inplace=True) 95 | 96 | def forward(self, x): 97 | x = self.sn_conv(x) 98 | if self.activate: 99 | x = self.LReLU(x) 100 | 101 | return x 102 | 103 | 104 | class Discriminator(nn.Module): 105 | def __init__(self, in_channels=3, out_channels=1, features=[32, 64, 128,256]): 106 | super().__init__() 107 | 108 | self.model = nn.Sequential( 109 | #k3n32s2 110 | Block(in_channels, features[0], kernel_size=3, stride=2, padding=1), 111 | #k3n32s1 112 | Block(features[0], features[0], kernel_size=3, stride=1, padding=1), 113 | 114 | #k3n64s2 115 | Block(features[0], features[1], kernel_size=3, stride=2, padding=1), 116 | #k3n64s1 117 | Block(features[1], features[1], kernel_size=3, stride=1, padding=1), 118 | 119 | #k3n128s2 120 | Block(features[1], features[2], kernel_size=3, stride=2, padding=1), 121 | #k3n128s1 122 | Block(features[2], features[2], kernel_size=3, stride=1, padding=1), 123 | 124 | #k3n256s2 125 | Block(features[2], features[3], kernel_size=3, stride=2, padding=1), 126 | #k3n256s1 127 | Block(features[3], features[3], kernel_size=3, stride=1, padding=1), 128 | 129 | 130 | #k1n1s1 131 | Block(features[3], out_channels, kernel_size=1, stride=1, padding=0) 132 | ) 133 | 134 | def forward(self, x): 135 | x = self.model(x) 136 | 137 | return x 138 | -------------------------------------------------------------------------------- /model/AlignModule/module.py: -------------------------------------------------------------------------------- 1 | from collections import namedtuple 2 | import torch 3 | from torch.nn import Conv2d, BatchNorm2d, PReLU, ReLU, Sigmoid, MaxPool2d, AdaptiveAvgPool2d, Sequential, Module,Dropout,BatchNorm2d,Linear,BatchNorm1d 4 | 5 | """ 6 | ArcFace implementation from [TreB1eN](https://github.com/TreB1eN/InsightFace_Pytorch) 7 | """ 8 | 9 | 10 | class Backbone(Module): 11 | def __init__(self, input_size, num_layers, mode='ir', drop_ratio=0.4, affine=True): 12 | super(Backbone, self).__init__() 13 | assert input_size in [112, 224], "input_size should be 112 or 224" 14 | assert num_layers in [50, 100, 152], "num_layers should be 50, 100 or 152" 15 | assert mode in ['ir', 'ir_se'], "mode should be ir or ir_se" 16 | blocks = get_blocks(num_layers) 17 | if mode == 'ir': 18 | unit_module = bottleneck_IR 19 | elif mode == 'ir_se': 20 | unit_module = bottleneck_IR_SE 21 | self.input_layer = Sequential(Conv2d(3, 64, (3, 3), 1, 1, bias=False), 22 | BatchNorm2d(64), 23 | PReLU(64)) 24 | if input_size == 112: 25 | self.output_layer = Sequential(BatchNorm2d(512), 26 | Dropout(drop_ratio), 27 | Flatten(), 28 | Linear(512 * 7 * 7, 512), 29 | BatchNorm1d(512, affine=affine)) 30 | else: 31 | self.output_layer = Sequential(BatchNorm2d(512), 32 | Dropout(drop_ratio), 33 | Flatten(), 34 | Linear(512 * 14 * 14, 512), 35 | BatchNorm1d(512, affine=affine)) 36 | 37 | modules = [] 38 | for block in blocks: 39 | for bottleneck in block: 40 | modules.append(unit_module(bottleneck.in_channel, 41 | bottleneck.depth, 42 | bottleneck.stride)) 43 | self.body = 
Sequential(*modules) 44 | 45 | def forward(self, x): 46 | x = self.input_layer(x) 47 | x = self.body(x) 48 | x = self.output_layer(x) 49 | return l2_norm(x) 50 | 51 | 52 | 53 | 54 | 55 | class Flatten(Module): 56 | def forward(self, input): 57 | return input.view(input.size(0), -1) 58 | 59 | 60 | def l2_norm(input, axis=1): 61 | norm = torch.norm(input, 2, axis, True) 62 | output = torch.div(input, norm) 63 | return output 64 | 65 | 66 | class Bottleneck(namedtuple('Block', ['in_channel', 'depth', 'stride'])): 67 | """ A named tuple describing a ResNet block. """ 68 | 69 | 70 | def get_block(in_channel, depth, num_units, stride=2): 71 | return [Bottleneck(in_channel, depth, stride)] + [Bottleneck(depth, depth, 1) for i in range(num_units - 1)] 72 | 73 | 74 | def get_blocks(num_layers): 75 | if num_layers == 50: 76 | blocks = [ 77 | get_block(in_channel=64, depth=64, num_units=3), 78 | get_block(in_channel=64, depth=128, num_units=4), 79 | get_block(in_channel=128, depth=256, num_units=14), 80 | get_block(in_channel=256, depth=512, num_units=3) 81 | ] 82 | elif num_layers == 100: 83 | blocks = [ 84 | get_block(in_channel=64, depth=64, num_units=3), 85 | get_block(in_channel=64, depth=128, num_units=13), 86 | get_block(in_channel=128, depth=256, num_units=30), 87 | get_block(in_channel=256, depth=512, num_units=3) 88 | ] 89 | elif num_layers == 152: 90 | blocks = [ 91 | get_block(in_channel=64, depth=64, num_units=3), 92 | get_block(in_channel=64, depth=128, num_units=8), 93 | get_block(in_channel=128, depth=256, num_units=36), 94 | get_block(in_channel=256, depth=512, num_units=3) 95 | ] 96 | else: 97 | raise ValueError("Invalid number of layers: {}. Must be one of [50, 100, 152]".format(num_layers)) 98 | return blocks 99 | 100 | 101 | class SEModule(Module): 102 | def __init__(self, channels, reduction): 103 | super(SEModule, self).__init__() 104 | self.avg_pool = AdaptiveAvgPool2d(1) 105 | self.fc1 = Conv2d(channels, channels // reduction, kernel_size=1, padding=0, bias=False) 106 | self.relu = ReLU(inplace=True) 107 | self.fc2 = Conv2d(channels // reduction, channels, kernel_size=1, padding=0, bias=False) 108 | self.sigmoid = Sigmoid() 109 | 110 | def forward(self, x): 111 | module_input = x 112 | x = self.avg_pool(x) 113 | x = self.fc1(x) 114 | x = self.relu(x) 115 | x = self.fc2(x) 116 | x = self.sigmoid(x) 117 | return module_input * x 118 | 119 | 120 | class bottleneck_IR(Module): 121 | def __init__(self, in_channel, depth, stride): 122 | super(bottleneck_IR, self).__init__() 123 | if in_channel == depth: 124 | self.shortcut_layer = MaxPool2d(1, stride) 125 | else: 126 | self.shortcut_layer = Sequential( 127 | Conv2d(in_channel, depth, (1, 1), stride, bias=False), 128 | BatchNorm2d(depth) 129 | ) 130 | self.res_layer = Sequential( 131 | BatchNorm2d(in_channel), 132 | Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False), PReLU(depth), 133 | Conv2d(depth, depth, (3, 3), stride, 1, bias=False), BatchNorm2d(depth) 134 | ) 135 | 136 | def forward(self, x): 137 | shortcut = self.shortcut_layer(x) 138 | res = self.res_layer(x) 139 | return res + shortcut 140 | 141 | 142 | class bottleneck_IR_SE(Module): 143 | def __init__(self, in_channel, depth, stride): 144 | super(bottleneck_IR_SE, self).__init__() 145 | if in_channel == depth: 146 | self.shortcut_layer = MaxPool2d(1, stride) 147 | else: 148 | self.shortcut_layer = Sequential( 149 | Conv2d(in_channel, depth, (1, 1), stride, bias=False), 150 | BatchNorm2d(depth) 151 | ) 152 | self.res_layer = Sequential( 153 | BatchNorm2d(in_channel), 
154 | Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False), 155 | PReLU(depth), 156 | Conv2d(depth, depth, (3, 3), stride, 1, bias=False), 157 | BatchNorm2d(depth), 158 | SEModule(depth, 16) 159 | ) 160 | 161 | def forward(self, x): 162 | shortcut = self.shortcut_layer(x) 163 | res = self.res_layer(x) 164 | return res + shortcut 165 | -------------------------------------------------------------------------------- /ReadMe.md: -------------------------------------------------------------------------------- 1 | # HeSer.Pytorch 2 | unofficial implementation of Few-Shot Head Swapping in the Wild
3 | you can find the official version [here](https://github.com/jmliu88/HeSer)<br>
4 | I did not use the discriminator from the paper and instead followed [DCT-NET](https://github.com/LeslieZhoa/DCT-NET.Pytorch) 5 | ![](./assets/11.png) 6 | ![](./assets/22.png) 7 | ## !!!!NEWS!!!!! 8 | [HeadSwap](https://github.com/LeslieZhoa/HeadSwap) is back. You can try it in Colab now!!! 9 | ## environment 10 | - torch 11 | - opencv-python 12 | - tensorboardX 13 | - imgaug 14 | - face-alignment 15 | ```shell 16 | # download pretrained models 17 | cd process 18 | bash download_weight.sh 19 | ``` 20 | ## How to RUN 21 | ### train 22 | I only trained one ID for driving 23 | #### Data Process 24 | 1. download voxceleb2<br>
25 | a. I only downloaded the [voxceleb2 test dataset](https://www.robots.ox.ac.uk/~vgg/data/voxceleb/vox2.html); you can use this [download link](https://www.robots.ox.ac.uk/~vgg/data/voxceleb/data/vox2_test_txt.zip)<br>
26 | b. You can unzip this file like this: 27 | ``` 28 | +--- dataset 29 | | +--- vox2_test_txt 30 | | | +--- txt 31 | | | | +--- id00017 32 | | | | | +--- 01dfn2spqyE 33 | | | | | | +--- 00001.txt 34 | | | | | +--- 5MkXgwdrmJw 35 | | | | | | +--- 00002.txt 36 | | | | | +--- 7t6lfzvVaTM 37 | | | | | | +--- 00003.txt 38 | | | | | | +--- 00004.txt 39 | | | | | | +--- 00005.txt 40 | | | | | | +--- 00006.txt 41 | | | | | | +--- 00007.txt 42 | 43 | ``` 44 | c. Install yt-dlp and aria2c yourself (both are easy to find online), then run: 45 | ``` 46 | cd process 47 | python download_and_process.py 48 | ``` 49 | d. the resulting dataset looks like: 50 | ``` 51 | voceleb2/ 52 | |-- id00017 53 | | |-- 01dfn2spqyE 54 | | | `-- 00.npy 55 | | |-- 5MkXgwdrmJw 56 | | | |-- 00.npy 57 | | | `-- 5MkXgwdrmJw.mp4 58 | | |-- 7t6lfzvVaTM 59 | | | |-- 00.npy 60 | | | |-- 01.npy 61 | | | |-- 02.npy 62 | | | |-- 03.npy 63 | | | |-- 04.npy 64 | | | |-- 05.npy 65 | | | |-- 06.npy 66 | | | |-- 07.npy 67 | | | |-- 08.npy 68 | | | |-- 09.npy 69 | | | `-- 7t6lfzvVaTM.mp4 70 | ``` 71 | 2. crop and align 72 | ``` 73 | cd process 74 | python process_raw_video.py 75 | ``` 76 | the processed dataset looks like: 77 | ``` 78 | process/ 79 | |-- img 80 | | |-- id00017 81 | | | |-- 5MkXgwdrmJw-0000 82 | | | | |-- 1273.png 83 | | | | |-- 1274.png 84 | | | | |-- 1275.png 85 | | | | |-- 1276.png 86 | | | | |-- 1277.png 87 | | | | |-- 1278.png 88 | | | | |-- 1279.png 89 | | | | |-- 1280.png 90 | | | | |-- 1281.png 91 | | | | |-- 1282.png 92 | | | | |-- 1283.png 93 | | | | |-- 1284.png 94 | | | | |-- 1285.png 95 | | | | |-- 1286.png 96 | | | | |-- 1287.png 97 | | | | |-- 1288.png 98 | | | | |-- 1289.png 99 | ``` 100 | 3. Remove data below threshold 101 | ``` 102 | cd process 103 | python filter_idfiles.py 104 | ``` 105 | 4. face parsing<br>
106 | follow [LVT](https://github.com/LeslieZhoa/LVT) to get the face-parsing masks<br>
107 | the mask data looks like: 108 | ``` 109 | process/mask/ 110 | |-- id00017 111 | | |-- 5MkXgwdrmJw-0000 112 | | | |-- 1273.png 113 | | | |-- 1274.png 114 | | | |-- 1275.png 115 | | | |-- 1276.png 116 | | | |-- 1277.png 117 | | | |-- 1278.png 118 | | | |-- 1279.png 119 | | | |-- 1280.png 120 | ``` 121 | #### Train Align 122 | I only used id00061 to train the Align module<br>
123 | check [model/AlignModule/config.py](model/AlignModule/config.py#L1) and fill in your own paths and params; a rough sketch of the kind of attributes it needs to expose is shown below<br>
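The snippet below is not the real config file, only a hedged sketch of attributes that the Align trainer and its encoders read elsewhere in this repo (`id_model`, `per_model`, `exp_embedding_size`, `pose_embedding_size`, the Adam betas); every value and path here is a placeholder, so edit the actual `model/AlignModule/config.py` instead:
```python
# Hypothetical sketch only -- NOT model/AlignModule/config.py.
# Field names are taken from attributes the code in this repo reads
# (AlignTrainer, ExpEncoder, PoseEncoder); values and paths are placeholders.
from types import SimpleNamespace

align_cfg = SimpleNamespace(
    id_model='pretrained_models/arcface.pth',   # ArcFace weights for IDEncoder (placeholder path)
    per_model='pretrained_models/vgg19.pth',    # backbone for the perceptual loss (placeholder path)
    exp_embedding_size=256,                     # ExpEncoder output dim (placeholder value)
    pose_embedding_size=256,                    # PoseEncoder output dim (placeholder value)
    beta1=0.5, beta2=0.999,                     # Adam betas read by create_optimizer (placeholder values)
)
```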
124 | for a single gpu 125 | ``` 126 | python train.py --model align --batch_size 8 --checkpoint_path checkpoint --lr 2e-4 --print_interval 100 --save_interval 100 --dist 127 | ``` 128 | for multiple gpus 129 | ``` 130 | python -m torch.distributed.launch train.py --model align --batch_size 8 --checkpoint_path checkpoint --lr 2e-4 --print_interval 100 --save_interval 100 131 | ``` 132 | #### Train Blend 133 | check [model/BlendModule/config.py](model/BlendModule/config.py#L1) and fill in your own paths and params<br>
134 | for a single gpu 135 | ``` 136 | python train.py --model blend --batch_size 8 --checkpoint_path checkpoint --lr 2e-4 --print_interval 100 --save_interval 100 --dist 137 | ``` 138 | for multiple gpus 139 | ``` 140 | python -m torch.distributed.launch train.py --model blend --batch_size 8 --checkpoint_path checkpoint --lr 2e-4 --print_interval 100 --save_interval 100 141 | ``` 142 | ## Inference 143 | follow inference.py and swap in your own model paths and input images; a rough, hypothetical sketch of the checkpoint restore is shown below<br>
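The sketch below only illustrates how the Align weights could be restored, based on the checkpoint keys written by `AlignTrainer.saveParameters` in this repo; the config class name `Params` and the checkpoint path are placeholders, and `inference.py` remains the reference:
```python
# Hedged sketch, not the real inference.py.
# Checkpoint keys ('G', 'Epor', 'Epose', 'Eexp', 'Eid') follow AlignTrainer.saveParameters;
# the config object name and the checkpoint path are placeholders.
import torch
from model.AlignModule.lib import Generator, PorEncoder, PoseEncoder, ExpEncoder, IDEncoder
from model.AlignModule import config as align_config  # module exists; class name below is assumed

args = align_config.Params()  # hypothetical name for the config object -- check the file
ckpt = torch.load('checkpoint/final.pth', map_location='cpu')  # placeholder path

netG, Epor = Generator(args), PorEncoder(args)
Epose, Eexp = PoseEncoder(args), ExpEncoder(args)
Eid = IDEncoder(args.id_model)

for net, key in [(netG, 'G'), (Epor, 'Epor'), (Epose, 'Epose'), (Eexp, 'Eexp'), (Eid, 'Eid')]:
    net.load_state_dict(ckpt[key], strict=False)
    net.eval()

# With xs (source frames), xt (driving frame) and crop_xt (cropped driving face) prepared as in
# AlignTrainer.forward, the aligned head is roughly:
#   xg = netG(Epor(xs), Eid(...), Epose(xt), Eexp(...))
```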
144 | ```shell 145 | python inference.py 146 | ``` 147 | ## Show 148 | The result below is overfitted, since only one ID was used for training 149 | ![](assets/show.png) 150 | 151 | ## Credits 152 | DCT-NET model and implementation:<br>
153 | https://github.com/LeslieZhoa/DCT-NET.Pytorch Copyright © 2022, LeslieZhoa.
154 | License https://github.com/LeslieZhoa/DCT-NET.Pytorch/blob/main/LICENSE 155 | 156 | latent-pose-reenactment model and implementation:
157 | https://github.com/shrubb/latent-pose-reenactment Copyright © 2020, shrubb.
158 | License https://github.com/shrubb/latent-pose-reenactment/blob/master/LICENSE.txt 159 | 160 | arcface pytorch model and implementation:<br>
161 | https://github.com/ronghuaiyang/arcface-pytorch Copyright © 2018, ronghuaiyang.
162 | 163 | LVT model and implementation:
164 | https://github.com/LeslieZhoa/LVT Copyright © 2022, LeslieZhoa.
165 | 166 | face-parsing model and implementation:
167 | https://github.com/zllrunning/face-parsing.PyTorch Copyright © 2019, zllrunning.
168 | License https://github.com/zllrunning/face-parsing.PyTorch/blob/master/LICENSE 169 | -------------------------------------------------------------------------------- /process/process_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cv2 3 | 4 | def crop_with_padding(image, lmks,scale=1.8,size=512,align=True): 5 | 6 | img_box = [np.min(lmks[:, 0]), np.min(lmks[:, 1]), np.max(lmks[:, 0]), np.max(lmks[:, 1])] 7 | 8 | center = ((img_box[0] + img_box[2]) / 2.0, (img_box[1] + img_box[3]) / 2.0) 9 | 10 | if align: 11 | lm_eye_left = lmks[36 : 42] # left-clockwise 12 | lm_eye_right = lmks[42 : 48] # left-clockwise 13 | 14 | eye_left = np.mean(lm_eye_left, axis=0) 15 | eye_right = np.mean(lm_eye_right, axis=0) 16 | angle = np.arctan2((eye_right[1] - eye_left[1]), (eye_right[0] - eye_left[0])) / np.pi * 180 17 | 18 | RotateMatrix = cv2.getRotationMatrix2D(center, angle, scale=1) 19 | 20 | rotated_img = cv2.warpAffine(image, RotateMatrix, (image.shape[1], image.shape[0])) 21 | rotated_lmks = apply_transform(RotateMatrix, lmks) 22 | else: 23 | rotated_img = image 24 | rotated_lmks = lmks 25 | RotateMatrix = np.array([[1,0,0], 26 | [0,1,0]]) 27 | 28 | faceBox = [np.min(rotated_lmks[:, 0]), np.min(rotated_lmks[:, 1]), 29 | np.max(rotated_lmks[:, 0]), np.max(rotated_lmks[:, 1])] 30 | 31 | cx_box = (faceBox[0] + faceBox[2]) / 2. 32 | cy_box = (faceBox[1] + faceBox[3]) / 2. 33 | width = faceBox[2] - faceBox[0] + 1 34 | height = faceBox[3] - faceBox[1] + 1 35 | face_size = max(width, height) 36 | bbox_size = int(face_size * scale) 37 | 38 | 39 | 40 | 41 | x_min = int(cx_box-bbox_size / 2.) 42 | y_min = int(cy_box-bbox_size / 2.) 43 | x_max = x_min + bbox_size 44 | y_max = y_min + bbox_size 45 | 46 | boundingBox = [max(x_min, 0), max(y_min, 0), min(x_max, rotated_img.shape[1]), min(y_max, rotated_img.shape[0])] 47 | imgCropped = rotated_img[boundingBox[1]:boundingBox[3], boundingBox[0]:boundingBox[2]] 48 | imgCropped = cv2.copyMakeBorder(imgCropped, max(-y_min, 0), max(y_max - image.shape[0], 0), max(-x_min, 0), 49 | max(x_max - image.shape[1], 0),cv2.BORDER_CONSTANT,value=(0,0,0)) 50 | boundingBox = [x_min, y_min, x_max, y_max] 51 | 52 | scale_h = size / float(bbox_size) 53 | scale_w = size / float(bbox_size) 54 | rotated_lmks[:, 0] = (rotated_lmks[:, 0] - boundingBox[0]) * scale_w 55 | rotated_lmks[:, 1] = (rotated_lmks[:, 1] - boundingBox[1]) * scale_h 56 | # print(imgCropped.shape) 57 | imgResize = cv2.resize(imgCropped, (size, size)) 58 | 59 | ### 计算变换(原图->crop box) 60 | m1 = np.concatenate((RotateMatrix,np.array([[0.0,0.0,1.0]])), axis=0) #rotate(+translation) 61 | m2 = np.eye(3) #translation 62 | m2[0][2] = -boundingBox[0] 63 | m2[1][2] = -boundingBox[1] 64 | m3 = np.eye(3) #scaling 65 | m3[0][0] = m3[1][1] = scale_h 66 | m = np.matmul(np.matmul(m3,m2),m1) 67 | im = np.linalg.inv(m) 68 | info = {'rotated_lmk':rotated_lmks, 69 | 'm':m, 70 | 'im':im} 71 | 72 | return imgResize,info 73 | 74 | 75 | def apply_transform(transform_matrix, lmks): 76 | ''' 77 | args 78 | transform_matrix: float (3,3)|(2,3) 79 | lmks: float (2)|(3)|(k,2)|(k,3) 80 | 81 | ret 82 | ret_lmks: float (2)|(3)|(k,2)|(k,3) 83 | ''' 84 | if transform_matrix.shape[0] == 2: 85 | transform_matrix = np.concatenate((transform_matrix,np.array([[0.0,0.0,1.0]])), axis=0) 86 | only_one = False 87 | if len(lmks.shape) == 1: 88 | lmks = lmks[np.newaxis, :] 89 | only_one = True 90 | only_two_dim = False 91 | if lmks.shape[1] == 2: 92 | lmks = np.concatenate((lmks, 
np.ones((lmks.shape[0],1), dtype=np.float32)), axis=1) 93 | only_two_dim = True 94 | 95 | ret_lmks = np.matmul(transform_matrix, lmks.T).T 96 | 97 | if only_two_dim: 98 | ret_lmks = ret_lmks[:,:2] 99 | if only_one: 100 | ret_lmks = ret_lmks[0] 101 | 102 | return ret_lmks 103 | 104 | 105 | def choose_one_detection(frame_faces,box): 106 | """ 107 | frame_faces 108 | list of lists of length 5 109 | several face detections from one image 110 | 111 | return: 112 | list of 5 floats 113 | one of the input detections: `(l, t, r, b, confidence)` 114 | """ 115 | frame_faces = list(filter(lambda x:x[-1]>0.9,frame_faces)) 116 | if len(frame_faces) == 0: 117 | return None 118 | 119 | else: 120 | # sort by area, find the largest box 121 | largest_area, largest_idx = -1, -1 122 | for idx, face in enumerate(frame_faces): 123 | area = compute_iou(box,face) 124 | # area = abs(face[2]-face[0]) * abs(face[1]-face[3]) 125 | if area > largest_area: 126 | largest_area = area 127 | largest_idx = idx 128 | 129 | if largest_area < 0.1: 130 | return None 131 | 132 | retval = frame_faces[largest_idx] 133 | 134 | 135 | return np.array(retval).tolist() 136 | 137 | 138 | def compute_iou(rec1, rec2): 139 | """ 140 | computing IoU 141 | :param rec1: (x0, y0, x1, y1), which reflects 142 | (top, left, bottom, right) 143 | :param rec2: (x0, y0, x1, y1) 144 | :return: scala value of IoU 145 | """ 146 | # computing area of each rectangles 147 | S_rec1 = (rec1[2] - rec1[0]) * (rec1[3] - rec1[1]) 148 | S_rec2 = (rec2[2] - rec2[0]) * (rec2[3] - rec2[1]) 149 | 150 | # computing the sum_area 151 | sum_area = S_rec1 + S_rec2 152 | 153 | # find the each edge of intersect rectangle 154 | left_line = max(rec1[0], rec2[0]) 155 | right_line = min(rec1[2], rec2[2]) 156 | top_line = max(rec1[1], rec2[1]) 157 | bottom_line = min(rec1[3], rec2[3]) 158 | 159 | # judge if there is an intersect 160 | if left_line >= right_line or top_line >= bottom_line: 161 | return 0 162 | else: 163 | intersect = (right_line - left_line) * (bottom_line - top_line) 164 | return (intersect / (sum_area - intersect))*1.0 165 | # return intersect / S_rec2 -------------------------------------------------------------------------------- /dataloader/augmentation.py: -------------------------------------------------------------------------------- 1 | import imgaug.augmenters as iaa 2 | import torch 3 | import numpy as np 4 | 5 | from contextlib import contextmanager 6 | 7 | # heavily copy from https://github.com/shrubb/latent-pose-reenactment 8 | class ParametricAugmenter: 9 | def is_empty(self): 10 | return not self.seq and not self.shift_seq 11 | 12 | def __init__(self, use_pixelwise_augs,use_affine_scale,use_affine_shift): 13 | 14 | 15 | sometimes = lambda aug: iaa.Sometimes(0.5, aug) 16 | total_augs = [] 17 | 18 | if use_pixelwise_augs: 19 | pixelwise_augs = [ 20 | iaa.SomeOf((0, 5), 21 | [ 22 | # sometimes(iaa.Superpixels(p_replace=(0, 0.25), n_segments=(150, 200))), 23 | iaa.OneOf([ 24 | iaa.GaussianBlur((0, 1.0)), # blur images with a sigma between 0 and 3.0 25 | iaa.AverageBlur(k=(1, 3)), 26 | # blur image using local means with kernel sizes between 2 and 7 27 | iaa.MedianBlur(k=(1, 3)), 28 | # blur image using local medians with kernel sizes between 2 and 7 29 | ]), 30 | iaa.Sharpen(alpha=(0, 1.0), lightness=(1.0, 1.5)), # sharpen images 31 | iaa.Emboss(alpha=(0, 1.0), strength=(0, 0.5)), # emboss images 32 | # search either for all edges or for directed edges, 33 | # blend the result with the original image using a blobby mask 34 | 
iaa.BlendAlphaSimplexNoise( 35 | iaa.EdgeDetect(alpha=(0.0, 0.15)), 36 | ), 37 | iaa.AdditiveGaussianNoise(loc=0, scale=(0.0, 0.05 * 255), per_channel=False), 38 | # add gaussian noise to images 39 | iaa.Add((-10, 10), per_channel=0.5), 40 | # change brightness of images (by -10 to 10 of original value) 41 | iaa.AddToSaturation((-20, 20)), # change hue and saturation 42 | iaa.JpegCompression((70, 99)), 43 | 44 | iaa.Multiply((0.5, 1.5), per_channel=False), 45 | 46 | iaa.OneOf([ 47 | iaa.LinearContrast((0.75, 1.25), per_channel=False), 48 | iaa.SigmoidContrast(cutoff=0.5, gain=(3.0, 11.0)) 49 | ]), 50 | sometimes(iaa.ElasticTransformation(alpha=(0.5, 3.5), sigma=0.15)), 51 | # move pixels locally around (with random strengths) 52 | ], 53 | random_order=True 54 | ) 55 | ] 56 | total_augs.extend(pixelwise_augs) 57 | affine_augs_scale = [] 58 | if use_affine_scale: 59 | affine_augs_scale = [sometimes(iaa.Affine( 60 | scale={"x": (0.8, 1.2), "y": (0.8, 1.2)}, 61 | # scale images to 80-120% of their size, individually per axis 62 | order=[1], # use bilinear interpolation (fast) 63 | mode=["reflect"] 64 | ))] 65 | # total_augs.extend(affine_augs_scale) 66 | 67 | if use_affine_shift: 68 | affine_augs_shift = [sometimes(iaa.Affine( 69 | translate_percent={"x": (-0.05, 0.05), "y": (-0.05, 0.05)}, 70 | order=[1], # use bilinear interpolation (fast) 71 | mode=["reflect"] 72 | ))] 73 | else: 74 | affine_augs_shift = [] 75 | 76 | self.shift_seq = iaa.Sequential(affine_augs_shift) 77 | self.seq = iaa.Sequential(total_augs, random_order=True) 78 | self.scale_seq = iaa.Sequential(affine_augs_scale) 79 | 80 | def tensor2image(self, image, norm = 255.0): 81 | return (np.expand_dims(image.squeeze().permute(1, 2, 0).numpy(), 0) * norm) 82 | 83 | def image2tensor(self, image, norm = 255.0): 84 | image = image.astype(np.float32) / norm 85 | image = torch.tensor(np.squeeze(image)).permute(2, 0, 1) 86 | return image 87 | 88 | 89 | def augment_tensor(self, image): 90 | if self.seq or self.shift_seq: 91 | image = self.tensor2image(image).astype(np.uint8) 92 | image = self.seq(images=image) 93 | image = self.shift_seq(images=image,) 94 | image = self.image2tensor(image) 95 | 96 | return image 97 | 98 | def augment_triple(self, image1, image2,mask): 99 | if self.seq or self.shift_seq: 100 | image1 = self.tensor2image(image1).astype(np.uint8) 101 | 102 | image1 = self.seq(images=image1,) 103 | if self.scale_seq: 104 | 105 | scale_seq_deterministic = self.scale_seq.to_deterministic() 106 | image1 = scale_seq_deterministic(images=image1) 107 | mask = scale_seq_deterministic(images=mask) 108 | if self.shift_seq: 109 | image2 = self.tensor2image(image2).astype(np.uint8) 110 | shift_seq_deterministic = self.shift_seq.to_deterministic() 111 | image1 = shift_seq_deterministic(images=image1,) 112 | image2 = shift_seq_deterministic(images=image2) 113 | mask = shift_seq_deterministic(images=mask) 114 | 115 | image2 = self.image2tensor(image2) 116 | 117 | image1 = self.image2tensor(image1) 118 | mask = self.image2tensor(mask) 119 | 120 | return image1, image2,mask 121 | 122 | @contextmanager 123 | def deterministic_(self, seed): 124 | """ 125 | A context manager to pre-define the random state of all augmentations. 
126 | 127 | seed: 128 | `int` 129 | """ 130 | # Backup the random states 131 | old_seq = self.seq.deepcopy() 132 | old_shift_seq = self.shift_seq.deepcopy() 133 | self.seq.seed_(seed) 134 | self.shift_seq.seed_(seed) 135 | yield 136 | # Restore the backed up random states 137 | self.seq = old_seq 138 | self.shift_seq = old_shift_seq 139 | -------------------------------------------------------------------------------- /trainer/ModelTrainer.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/python 2 | # -*- encoding: utf-8 -*- 3 | ''' 4 | @author Leslie 5 | @date 20220812 6 | ''' 7 | import torch 8 | import math 9 | import time,os 10 | 11 | from utils.visualizer import Visualizer 12 | from torch.nn.parallel import DistributedDataParallel as DDP 13 | import torch.distributed as dist 14 | import subprocess 15 | from utils.utils import convert_img 16 | class ModelTrainer: 17 | 18 | def __init__(self,args): 19 | 20 | self.args = args 21 | self.batch_size = args.batch_size 22 | self.old_lr = args.lr 23 | if args.rank == 0 : 24 | self.vis = Visualizer(args) 25 | 26 | if args.eval: 27 | self.val_vis = Visualizer(args,"val") 28 | 29 | # ## ===== ===== ===== ===== ===== 30 | # ## Train network 31 | # ## ===== ===== ===== ===== ===== 32 | 33 | def train_network(self,train_loader,test_loader): 34 | 35 | counter = 0 36 | loss_dict = {} 37 | acc_num = 0 38 | mn_loss = float('inf') 39 | 40 | steps = 0 41 | begin_it = 0 42 | if self.args.pretrain_path: 43 | begin_it = int(self.args.pretrain_path.split('/')[-1].split('-')[0]) 44 | steps = (begin_it+1) * math.ceil(self.args.mx_data_length/self.args.batch_size) 45 | 46 | print("current steps: %d | one epoch steps: %d "%(steps,self.args.mx_data_length)) 47 | 48 | for epoch in range(begin_it+1,self.args.max_epoch): 49 | 50 | for ii,(data) in enumerate(train_loader): 51 | 52 | tstart = time.time() 53 | 54 | self.run_single_step(data,steps) 55 | losses = self.get_latest_losses() 56 | 57 | for key,val in losses.items(): 58 | loss_dict[key] = loss_dict.get(key,0) + val.mean().item() 59 | 60 | counter += 1 61 | steps += 1 62 | 63 | telapsed = time.time() - tstart 64 | 65 | 66 | if ii % self.args.print_interval == 0 and self.args.rank == 0: 67 | 68 | for key,val in loss_dict.items(): 69 | loss_dict[key] /= counter 70 | lr_rate = self.get_lr() 71 | print_dict = {**{"time":telapsed,"lr":lr_rate}, 72 | **loss_dict} 73 | self.vis.print_current_errors(epoch,ii,print_dict,telapsed) 74 | 75 | self.vis.plot_current_errors(print_dict,steps) 76 | 77 | loss_dict = {} 78 | counter = 0 79 | 80 | # torch.cuda.empty_cache() 81 | if self.args.save_interval != 0 and ii % self.args.save_interval == 0 and \ 82 | self.args.rank == 0: 83 | self.saveParameters(os.path.join(self.args.checkpoint_path,"%03d-%08d.pth"%(epoch,ii))) 84 | 85 | display_data = self.select_img(self.get_latest_generated()) 86 | 87 | self.vis.display_current_results(display_data,steps) 88 | 89 | 90 | 91 | if self.args.eval and self.args.test_interval > 0 and steps % self.args.test_interval == 0: 92 | val_loss = self.evalution(test_loader,steps,epoch) 93 | 94 | if self.args.early_stop: 95 | 96 | acc_num,mn_loss,stop_flag = self.early_stop_wait(self.get_loss_from_val(val_loss),acc_num,mn_loss,epoch) 97 | if stop_flag: 98 | return 99 | 100 | # print('******************memory:',psutil.virtual_memory()[3]) 101 | 102 | if self.args.rank == 0 : 103 | self.saveParameters(os.path.join(self.args.checkpoint_path,"%03d-%08d.pth"%(epoch,0))) 104 | 105 | # 验证,保存最优模型 106 | if 
test_loader or self.args.eval: 107 | val_loss = self.evalution(test_loader,steps,epoch) 108 | 109 | if self.args.early_stop: 110 | 111 | acc_num,mn_loss,stop_flag = self.early_stop_wait(self.get_loss_from_val(val_loss),acc_num,mn_loss,epoch) 112 | if stop_flag: 113 | return 114 | 115 | 116 | if self.args.rank == 0 : 117 | self.vis.close() 118 | 119 | 120 | 121 | def early_stop_wait(self,loss,acc_num,mn_loss,epoch): 122 | 123 | if self.args.rank == 0: 124 | if loss < mn_loss: 125 | mn_loss = loss 126 | cmd_one = 'cp -r %s %s'%(os.path.join(self.args.checkpoint_path,"%03d-%08d.pth"%(epoch,0)), 127 | os.path.join(self.args.checkpoint_path,'final.pth')) 128 | done_one = subprocess.Popen(cmd_one,stdout=subprocess.PIPE,shell=True) 129 | done_one.wait() 130 | acc_num = 0 131 | else: 132 | acc_num += 1 133 | # 多机多卡,某一张卡退出则终止程序,使用all_reduce 134 | if self.args.dist: 135 | 136 | if acc_num > self.args.stop_interval: 137 | signal = torch.tensor([0]).cuda() 138 | else: 139 | signal = torch.tensor([1]).cuda() 140 | else: 141 | if self.args.dist: 142 | signal = torch.tensor([1]).cuda() 143 | 144 | if self.args.dist: 145 | dist.all_reduce(signal) 146 | value = signal.item() 147 | if value >= int(os.environ.get("WORLD_SIZE","1")): 148 | dist.all_reduce(torch.tensor([0]).cuda()) 149 | return acc_num,mn_loss,False 150 | else: 151 | return acc_num,mn_loss,True 152 | 153 | else: 154 | if acc_num > self.args.stop_interval: 155 | return acc_num,mn_loss,True 156 | else: 157 | return acc_num,mn_loss,False 158 | 159 | def run_single_step(self,data,steps): 160 | data = self.process_input(data) 161 | self.run_discriminator_one_step(data,steps) 162 | self.run_generator_one_step(data,steps) 163 | 164 | def select_img(self,data,name='fake',axis=2): 165 | if data is None: 166 | return None 167 | cat_img = [] 168 | for v in data: 169 | cat_img.append(v.detach().cpu()) 170 | 171 | cat_img = torch.cat(cat_img,-1) 172 | cat_img = torch.cat(torch.split(cat_img,1,dim=0),axis)[0] 173 | 174 | return {name:convert_img(cat_img)} 175 | 176 | 177 | 178 | ################################################################## 179 | # Helper functions 180 | ################################################################## 181 | 182 | def get_loss_from_val(self,loss): 183 | return loss 184 | 185 | def get_show_inp(self,data): 186 | if not isinstance(data,list): 187 | return [data] 188 | return data 189 | 190 | def use_ddp(self,model): 191 | if model is None: 192 | return None 193 | model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) #用于将BN转换成ddp模式/ 194 | # model = DDP(model,broadcast_buffers=False,find_unused_parameters=True) # find_unused_parameters->训练gan会有判别器或生成器参数不参与训练,需使用该参数 195 | model = DDP(model, 196 | broadcast_buffers=False, 197 | find_unused_parameters=True 198 | ) 199 | model_on_one_gpu = model.module #若需要调用self.model的函数,在ddp模式要调用self._model_on_one_gpu 200 | return model,model_on_one_gpu 201 | def process_input(self,data): 202 | 203 | if torch.cuda.is_available(): 204 | if isinstance(data,list): 205 | data = [x.cuda() for x in data] 206 | else: 207 | data = data.cuda() 208 | return data -------------------------------------------------------------------------------- /model/BlendModule/generator.py: -------------------------------------------------------------------------------- 1 | ''' 2 | @author Leslie 3 | @date 20220823 4 | ''' 5 | import torch 6 | from torch import nn 7 | from model.BlendModule.module import VGG19_pytorch,Decoder 8 | import torch.nn.functional as F 9 | import pdb 10 | 11 | class 
Generator(nn.Module): 12 | def __init__(self,args): 13 | super().__init__() 14 | self.feature_ext = VGG19_pytorch() 15 | self.decoder = Decoder(ic=args.decoder_ic) 16 | self.dilate = nn.MaxPool2d(kernel_size=args.dilate_kernel, 17 | stride=1, 18 | padding=args.dilate_kernel//2) 19 | 20 | self.phi = nn.Conv2d(in_channels=args.f_in_channels, 21 | out_channels=args.f_inter_channels, kernel_size=1, stride=1, padding=0) 22 | self.theta = nn.Conv2d(in_channels=args.f_in_channels, 23 | out_channels=args.f_inter_channels, kernel_size=1, stride=1, padding=0) 24 | 25 | self.temperature = args.temperature 26 | 27 | self.head_index = [1,2,3,4,5,6,7,8,9,10,11,12,13,17,18] 28 | self.eps = 1e-8 29 | 30 | def forward(self,I_a,I_gray,I_t,M_a,M_t,gt=None,cycle=False,train=False): 31 | 32 | fA = self.feature_ext(I_a) 33 | fT = self.feature_ext(I_t) 34 | 35 | fA = self.phi(fA) 36 | fT = self.theta(fT) 37 | 38 | gen_h,gen_i,mask_list,matrix_list = self.RCNet(fA,fT,M_a,M_t,I_t) 39 | 40 | gen_h = F.interpolate(gen_h,size=I_t.shape[-2:],mode='bilinear',align_corners=True) 41 | gen_i = F.interpolate(gen_i,size=I_t.shape[-2:],mode='bilinear',align_corners=True) 42 | M_Ah,M_Ad,M_Td,M_Ai,M_Ti,M_Ar,M_Tr = mask_list 43 | 44 | if cycle: 45 | 46 | cycle_gen = self.RCCycle(gen_h+gen_i,[M_Ar,M_Tr,M_Ai,M_Ti],matrix_list,fA.shape) 47 | 48 | I_td = I_t * M_Td 49 | I_td = F.interpolate(I_td, size=cycle_gen.shape[-2:],mode='bilinear') 50 | return cycle_gen,I_td 51 | 52 | I_tb = gt * (1-M_Ad) 53 | I_ag = I_gray * M_Ah 54 | 55 | # cat_img = torch.cat([gen_h,gen_i,M_Ah.repeat(1,3,1,1),I_tb,M_Ai.repeat(1,3,1,1),I_ag.repeat(1,3,1,1)],-1) 56 | # import cv2 57 | # cv2.imwrite('1.png',(cat_img[0].permute(1,2,0).detach().cpu().numpy()[...,::-1]+1)*127.5) 58 | # pdb.set_trace() 59 | inp = torch.cat([gen_h,gen_i, 60 | M_Ah, 61 | I_tb,M_Ai,I_ag],1) 62 | 63 | oup = self.decoder(inp) 64 | 65 | if train: 66 | return oup,M_Ah,M_Ai 67 | 68 | return oup 69 | 70 | 71 | def RCNet(self,fA,fT,M_a,M_t,I_t): 72 | 73 | M_Ah = self.get_mask(M_a,self.head_index) 74 | M_Th = self.get_mask(M_t,self.head_index) 75 | 76 | M_Ti,M_Td = self.get_inpainting(M_Th) 77 | M_Ai,M_Ad = self.get_inpainting(M_Ah+M_Th,M_Ah) 78 | M_Ar = self.get_multi_mask(M_a) 79 | M_Tr = self.get_multi_mask(M_t) 80 | 81 | matrix_list = [] 82 | gen_h = None 83 | for m_a,m_t in zip(M_Ar,M_Tr): 84 | gen_h, matrix = self.compute_corre(fA,fT,m_a,m_t,I_t,gen_h) 85 | matrix_list.append(matrix) 86 | 87 | gen_i = None 88 | gen_i,matrix = self.compute_corre(fA,fT,M_Ai,M_Ti,I_t,gen_i) 89 | matrix_list.append(matrix) 90 | 91 | return gen_h,gen_i,[M_Ah,M_Ad,M_Td,M_Ai,M_Ti,M_Ar,M_Tr],matrix_list 92 | 93 | def RCCycle(self,I_t,mask_list,matrix_list,shape): 94 | M_Ar,M_Tr,M_Ai,M_Ti = mask_list 95 | batch,channel,h,w = shape 96 | gen_h = torch.zeros((batch,3,h,w)).to(I_t.device) 97 | gen_i = torch.zeros((batch,3,h,w)).to(I_t.device) 98 | I_t_resize = F.interpolate(I_t, size=(h,w),mode='bilinear',align_corners=True) 99 | 100 | M_Tr_resize = [F.interpolate(f, size=(h,w),mode='nearest') for f in M_Tr] 101 | M_Ar_resize = [F.interpolate(f, size=(h,w),mode='nearest') for f in M_Ar] 102 | M_Ti_resize = F.interpolate(M_Ti, size=(h,w),mode='nearest') 103 | M_Ai_resize = F.interpolate(M_Ai, size=(h,w),mode='nearest') 104 | for matrix,m_t,m_a in zip(matrix_list[:-1],M_Tr_resize,M_Ar_resize): 105 | for i in range(batch): 106 | f_WTA = matrix[i] 107 | f = F.softmax(f_WTA.transpose(1,2),dim=-1) 108 | 109 | ref = torch.matmul( 110 | I_t_resize[i].unsqueeze(0).masked_select( 111 | 
m_a[i].unsqueeze(0)==1).view(1,3,-1),f.transpose(1,2)) 112 | gen_h[i] = gen_h[i].unsqueeze(0).masked_scatter( 113 | m_t[i].unsqueeze(0)==1,ref).squeeze(0) 114 | 115 | for i in range(batch): 116 | 117 | f_WTA = matrix_list[-1][i] 118 | f = F.softmax(f_WTA.transpose(1,2),dim=-1) 119 | 120 | ref = torch.matmul( 121 | I_t_resize[i].unsqueeze(0).masked_select( 122 | M_Ai_resize[i].unsqueeze(0)==1).view(1,3,-1),f.transpose(1,2)) 123 | gen_i[i] = gen_i[i].unsqueeze(0).masked_scatter( 124 | M_Ti_resize[i].unsqueeze(0)==1,ref).squeeze(0) 125 | 126 | return gen_h + gen_i 127 | 128 | 129 | def compute_corre(self,fA,fT,M_A,M_T,I_t,gen=None): 130 | batch,channel,h,w = fA.shape 131 | matrix_list = [] 132 | if gen is None: 133 | gen = torch.zeros((batch,3,h,w)).to(I_t.device) 134 | 135 | M_A_resize = F.interpolate(M_A, size=(h,w),mode='nearest') 136 | M_T_resize = F.interpolate(M_T, size=(h,w),mode='nearest') 137 | I_t_resize = F.interpolate(I_t, size=(h,w),mode='bilinear', align_corners=True) 138 | for i in range(batch): 139 | fAr = fA[i].unsqueeze(0).masked_select( 140 | M_A_resize[i].unsqueeze(0)==1).view(1,channel,-1) # b,c,hA 141 | 142 | fTr = fT[i].unsqueeze(0).masked_select( 143 | M_T_resize[i].unsqueeze(0)==1).view(1,channel,-1) # b,c,hT 144 | 145 | fAr = self.normlize(fAr) 146 | fTr = self.normlize(fTr) 147 | 148 | matrix = torch.matmul(fAr.permute(0,2,1),fTr) # b,hA,hT 149 | f_WTA = matrix/self.temperature 150 | f = F.softmax(f_WTA,dim=-1) 151 | matrix_list.append(f_WTA) 152 | 153 | 154 | ref = torch.matmul( 155 | I_t_resize[i].unsqueeze(0).masked_select( 156 | M_T_resize[i].unsqueeze(0)==1).view(1,3,-1),f.transpose(1,2)) # [b,channel,hT] X [b,hT,hA] 157 | 158 | 159 | gen[i] = gen[i].unsqueeze(0).masked_scatter( 160 | M_A_resize[i].unsqueeze(0)==1,ref).squeeze(0) 161 | 162 | 163 | return gen,matrix_list 164 | 165 | def get_inpainting(self,M,head=None): 166 | M = torch.clamp(M,0,1) 167 | M_dilate = self.dilate(M) 168 | if head is None: 169 | MI = M_dilate - M 170 | else: 171 | MI = M_dilate - head 172 | return MI,M_dilate 173 | 174 | def get_multi_mask(self,M_a): 175 | # skin 176 | skin_mask_A = self.get_mask(M_a,[1]) 177 | # hair 178 | hair_mask_A = self.get_mask(M_a,[17,18]) 179 | 180 | # eye 181 | eye_mask_A = self.get_mask(M_a,[4,5,6]) 182 | 183 | # brow 184 | brow_mask_A = self.get_mask(M_a,[2,3]) 185 | 186 | # ear 187 | ear_mask_A = self.get_mask(M_a,[7,8,9]) 188 | 189 | #nose 190 | nose_mask_A = self.get_mask(M_a,[10]) 191 | 192 | # lip 193 | lip_mask_A = self.get_mask(M_a,[12,13]) 194 | 195 | 196 | # tooth 197 | tooth_mask_A = self.get_mask(M_a,[11]) 198 | 199 | return [skin_mask_A,hair_mask_A,eye_mask_A,brow_mask_A,ear_mask_A,nose_mask_A,lip_mask_A,tooth_mask_A] 200 | 201 | def get_mask(self,mask,indexs): 202 | out = torch.zeros_like(mask) 203 | for i in indexs: 204 | out[mask==i] = 1 205 | 206 | return out 207 | 208 | def normlize(self,x): 209 | x_mean = x.mean(dim=1,keepdim=True) 210 | x_norm = torch.norm(x,2,1,keepdim=True) + self.eps 211 | return (x-x_mean) / x_norm 212 | -------------------------------------------------------------------------------- /trainer/BlendTrainer.py: -------------------------------------------------------------------------------- 1 | #! 
/usr/bin/python 2 | # -*- encoding: utf-8 -*- 3 | ''' 4 | @author Leslie 5 | @date 20220812 6 | ''' 7 | import torch 8 | 9 | from trainer.ModelTrainer import ModelTrainer 10 | from model.BlendModule.generator import Generator 11 | from model.AlignModule.discriminator import Discriminator 12 | from utils.utils import * 13 | from model.AlignModule.loss import * 14 | import torch.nn.functional as F 15 | import random 16 | import torch.distributed as dist 17 | 18 | class BlendTrainer(ModelTrainer): 19 | 20 | def __init__(self, args): 21 | super().__init__(args) 22 | self.device = 'cpu' 23 | if torch.cuda.is_available(): 24 | self.device = 'cuda' 25 | 26 | self.netG = Generator(args).to(self.device) 27 | 28 | self.netD = Discriminator(in_channels=5).to(self.device) 29 | 30 | 31 | self.optimG,self.optimD = self.create_optimizer() 32 | 33 | if args.pretrain_path is not None: 34 | self.loadParameters(args.pretrain_path) 35 | 36 | if args.dist: 37 | self.netG,self.netG_module = self.use_ddp(self.netG) 38 | self.netD,self.netD_module = self.use_ddp(self.netD) 39 | else: 40 | self.netG_module = self.netG 41 | self.netD_module = self.netD 42 | 43 | if self.args.per_loss: 44 | self.perLoss = PerceptualLoss(args.per_model).to(self.device) 45 | self.perLoss.eval() 46 | 47 | if self.args.rec_loss: 48 | self.L1Loss = torch.nn.L1Loss() 49 | 50 | 51 | def create_optimizer(self): 52 | g_optim = torch.optim.Adam( 53 | self.netG.parameters(), 54 | lr=self.args.g_lr, 55 | betas=(self.args.beta1, self.args.beta2), 56 | ) 57 | d_optim = torch.optim.Adam( 58 | self.netD.parameters(), 59 | lr=self.args.d_lr, 60 | betas=(self.args.beta1, self.args.beta2), 61 | ) 62 | 63 | return g_optim,d_optim 64 | 65 | 66 | def run_single_step(self, data, steps): 67 | self.netG.train() 68 | super().run_single_step(data, steps) 69 | 70 | 71 | def run_discriminator_one_step(self, data,step): 72 | 73 | D_losses = {} 74 | requires_grad(self.netG, False) 75 | requires_grad(self.netD, True) 76 | 77 | I_a,I_gray,I_t,hat_t,M_a,M_t,M_hat,gt = data 78 | fake,M_Ah,M_Ai = self.netG(I_a,I_gray,I_t,M_a,M_t,gt,train=True) 79 | fake_pred = self.netD(torch.cat([fake,M_Ah,M_Ai],1)) 80 | real_pred = self.netD(torch.cat([gt,M_Ah,M_Ai],1)) 81 | d_loss = compute_dis_loss(fake_pred, real_pred,D_losses) 82 | D_losses['d'] = d_loss 83 | 84 | self.optimD.zero_grad() 85 | d_loss.backward() 86 | self.optimD.step() 87 | 88 | self.d_losses = D_losses 89 | 90 | 91 | def run_generator_one_step(self, data,step): 92 | 93 | 94 | requires_grad(self.netG, True) 95 | requires_grad(self.netD, False) 96 | 97 | I_a,I_gray,I_t,hat_t,M_a,M_t,M_hat,gt = data 98 | G_losses,loss,xg = self.compute_g_loss(I_a,I_gray,I_t,M_a,M_t,gt) 99 | self.optimG.zero_grad() 100 | loss.mean().backward() 101 | self.optimG.step() 102 | 103 | g_losses,loss,fake_nopair,label_nopair = self.compute_cycle_g_loss(I_a,I_gray,I_t,hat_t,M_a,M_t,M_hat) 104 | self.optimG.zero_grad() 105 | loss.mean().backward() 106 | self.optimG.step() 107 | 108 | self.g_losses = {**G_losses,**g_losses} 109 | 110 | self.generator = [I_a.detach(),fake_nopair.detach(), 111 | label_nopair.detach(),xg.detach(),gt.detach()] 112 | 113 | 114 | def evalution(self,test_loader,steps,epoch): 115 | 116 | loss_dict = {} 117 | index = random.randint(0,len(test_loader)-1) 118 | counter = 0 119 | with torch.no_grad(): 120 | for i,data in enumerate(test_loader): 121 | 122 | data = self.process_input(data) 123 | I_a,I_gray,I_t,hat_t,M_a,M_t,M_hat,gt = data 124 | G_losses,losses,xg = self.compute_g_loss(I_a,I_gray,I_t,M_a,M_t,gt) 125 | for 
k,v in G_losses.items(): 126 | loss_dict[k] = loss_dict.get(k,0) + v.detach() 127 | if i == index and self.args.rank == 0 : 128 | 129 | show_data = [I_a,xg,gt] 130 | self.val_vis.display_current_results(self.select_img(show_data),steps) 131 | counter += 1 132 | 133 | 134 | for key,val in loss_dict.items(): 135 | loss_dict[key] /= counter 136 | 137 | if self.args.dist: 138 | # if self.args.rank == 0 : 139 | dist_losses = loss_dict.copy() 140 | for key,val in loss_dict.items(): 141 | 142 | dist.reduce(dist_losses[key],0) 143 | value = dist_losses[key].item() 144 | loss_dict[key] = value / self.args.world_size 145 | 146 | if self.args.rank == 0 : 147 | self.val_vis.plot_current_errors(loss_dict,steps) 148 | self.val_vis.print_current_errors(epoch+1,0,loss_dict,0) 149 | 150 | return loss_dict 151 | 152 | 153 | def compute_g_loss(self,I_a,I_gray,I_t,M_a,M_t,gt): 154 | G_losses = {} 155 | loss = 0 156 | fake,M_Ah,M_Ai = self.netG(I_a,I_gray,I_t,M_a,M_t,gt,train=True) 157 | fake_pred = self.netD(torch.cat([fake,M_Ah,M_Ai],1)) 158 | gan_loss = compute_gan_loss(fake_pred) * self.args.lambda_gan 159 | G_losses['g_losses'] = gan_loss 160 | loss += gan_loss 161 | 162 | if self.args.rec_loss: 163 | rec_loss = self.L1Loss(fake,gt) * self.args.lambda_rec 164 | G_losses['rec_loss'] = rec_loss 165 | loss += rec_loss 166 | 167 | 168 | if self.args.per_loss: 169 | per_loss = self.perLoss(fake,gt) * self.args.lambda_per 170 | G_losses['per_loss'] = per_loss 171 | loss += per_loss 172 | 173 | return G_losses,loss,fake 174 | 175 | 176 | def compute_cycle_g_loss(self,I_a,I_gray,I_t,hat_t,M_a,M_t,M_hat): 177 | G_losses = {} 178 | loss = 0 179 | fake_pair,label_pair = self.netG(I_a,I_gray,I_t,M_a,M_t,cycle=True) 180 | fake_nopair,label_nopair = self.netG(I_a,I_gray,hat_t,M_a,M_hat,cycle=True) 181 | 182 | loss = self.L1Loss(fake_pair,label_pair) + \ 183 | self.L1Loss(fake_nopair,label_nopair) 184 | G_losses['cycle'] = loss 185 | fake_nopair = F.interpolate(fake_nopair, size=I_t.shape[-2:],mode='bilinear') 186 | label_nopair = F.interpolate(label_nopair, size=I_t.shape[-2:],mode='bilinear') 187 | return G_losses,loss,fake_nopair,label_nopair 188 | 189 | 190 | def get_latest_losses(self): 191 | return {**self.g_losses,**self.d_losses} 192 | 193 | def get_latest_generated(self): 194 | return self.generator 195 | 196 | def loadParameters(self,path): 197 | ckpt = torch.load(path, map_location=lambda storage, loc: storage) 198 | self.netG.load_state_dict(ckpt['G'],strict=False) 199 | self.netD.load_state_dict(ckpt['D'],strict=False) 200 | self.optimG.load_state_dict(ckpt['g_optim']) 201 | self.optimD.load_state_dict(ckpt['d_optim']) 202 | 203 | def saveParameters(self,path): 204 | torch.save( 205 | { 206 | "G": self.netG_module.state_dict(), 207 | "D": self.netD_module.state_dict(), 208 | "g_optim": self.optimG.state_dict(), 209 | "d_optim": self.optimD.state_dict(), 210 | "args": self.args, 211 | }, 212 | path 213 | ) 214 | 215 | def get_lr(self): 216 | return self.optimG.state_dict()['param_groups'][0]['lr'] 217 | 218 | 219 | 220 | 221 | 222 | 223 | 224 | 225 | 226 | 227 | 228 | 229 | 230 | 231 | -------------------------------------------------------------------------------- /model/BlendModule/module.py: -------------------------------------------------------------------------------- 1 | 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import torch 5 | class VGG19_pytorch(nn.Module): 6 | """ 7 | 8 | """ 9 | 10 | def __init__(self, pool="max"): 11 | super(VGG19_pytorch, self).__init__() 12 | 
self.conv1_1 = nn.Conv2d(3, 64, kernel_size=3, padding=1) 13 | self.conv1_2 = nn.Conv2d(64, 64, kernel_size=3, padding=1) 14 | self.conv2_1 = nn.Conv2d(64, 128, kernel_size=3, padding=1) 15 | self.conv2_2 = nn.Conv2d(128, 128, kernel_size=3, padding=1) 16 | self.conv3_1 = nn.Conv2d(128, 256, kernel_size=3, padding=1) 17 | self.conv3_2 = nn.Conv2d(256, 256, kernel_size=3, padding=1) 18 | self.conv3_3 = nn.Conv2d(256, 256, kernel_size=3, padding=1) 19 | self.conv3_4 = nn.Conv2d(256, 256, kernel_size=3, padding=1) 20 | self.conv4_1 = nn.Conv2d(256, 512, kernel_size=3, padding=1) 21 | self.conv4_2 = nn.Conv2d(512, 512, kernel_size=3, padding=1) 22 | self.conv4_3 = nn.Conv2d(512, 512, kernel_size=3, padding=1) 23 | self.conv4_4 = nn.Conv2d(512, 512, kernel_size=3, padding=1) 24 | self.conv5_1 = nn.Conv2d(512, 512, kernel_size=3, padding=1) 25 | self.conv5_2 = nn.Conv2d(512, 512, kernel_size=3, padding=1) 26 | self.conv5_3 = nn.Conv2d(512, 512, kernel_size=3, padding=1) 27 | self.conv5_4 = nn.Conv2d(512, 512, kernel_size=3, padding=1) 28 | if pool == "max": 29 | self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2) 30 | self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2) 31 | self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2) 32 | self.pool4 = nn.MaxPool2d(kernel_size=2, stride=2) 33 | self.pool5 = nn.MaxPool2d(kernel_size=2, stride=2) 34 | elif pool == "avg": 35 | self.pool1 = nn.AvgPool2d(kernel_size=2, stride=2) 36 | self.pool2 = nn.AvgPool2d(kernel_size=2, stride=2) 37 | self.pool3 = nn.AvgPool2d(kernel_size=2, stride=2) 38 | self.pool4 = nn.AvgPool2d(kernel_size=2, stride=2) 39 | self.pool5 = nn.AvgPool2d(kernel_size=2, stride=2) 40 | 41 | def forward(self, x): 42 | """ 43 | NOTE: input tensor should range in [0,1] 44 | """ 45 | out = {} 46 | 47 | out["r11"] = F.relu(self.conv1_1(x)) 48 | out["r12"] = F.relu(self.conv1_2(out["r11"])) 49 | out["p1"] = self.pool1(out["r12"]) 50 | out["r21"] = F.relu(self.conv2_1(out["p1"])) 51 | out["r22"] = F.relu(self.conv2_2(out["r21"])) 52 | out["p2"] = self.pool2(out["r22"]) 53 | out["r31"] = F.relu(self.conv3_1(out["p2"])) 54 | out["r32"] = F.relu(self.conv3_2(out["r31"])) 55 | out["r33"] = F.relu(self.conv3_3(out["r32"])) 56 | out["r34"] = F.relu(self.conv3_4(out["r33"])) 57 | out["p3"] = self.pool3(out["r34"]) 58 | out["r41"] = F.relu(self.conv4_1(out["p3"])) 59 | out["r42"] = F.relu(self.conv4_2(out["r41"])) 60 | out["r43"] = F.relu(self.conv4_3(out["r42"])) 61 | out["r44"] = F.relu(self.conv4_4(out["r43"])) 62 | 63 | return out["r44"] 64 | 65 | 66 | class Decoder(nn.Module): 67 | def __init__(self, ic): 68 | super(Decoder, self).__init__() 69 | self.conv1_1 = nn.Sequential(nn.Conv2d(ic, 32, 3, 1, 1), nn.ReLU(), nn.Conv2d(32, 64, 3, 1, 1)) 70 | self.conv1_2 = nn.Conv2d(64, 64, 3, 1, 1) 71 | self.conv1_2norm_ss = nn.Conv2d(64, 64, 1, 2, bias=False, groups=64) 72 | self.conv2_1 = nn.Conv2d(64, 128, 3, 1, 1) 73 | self.conv2_2 = nn.Conv2d(128, 128, 3, 1, 1) 74 | self.conv2_2norm_ss = nn.Conv2d(128, 128, 1, 2, bias=False, groups=128) 75 | self.conv3_1 = nn.Conv2d(128, 256, 3, 1, 1) 76 | self.conv3_2 = nn.Conv2d(256, 256, 3, 1, 1) 77 | self.conv3_3 = nn.Conv2d(256, 256, 3, 1, 1) 78 | self.conv3_3norm_ss = nn.Conv2d(256, 256, 1, 2, bias=False, groups=256) 79 | self.conv4_1 = nn.Conv2d(256, 512, 3, 1, 1) 80 | self.conv4_2 = nn.Conv2d(512, 512, 3, 1, 1) 81 | self.conv4_3 = nn.Conv2d(512, 512, 3, 1, 1) 82 | self.conv5_1 = nn.Conv2d(512, 512, 3, 1, 2, 2) 83 | self.conv5_2 = nn.Conv2d(512, 512, 3, 1, 2, 2) 84 | self.conv5_3 = nn.Conv2d(512, 512, 3, 1, 2, 2) 85 | 
self.conv6_1 = nn.Conv2d(512, 512, 3, 1, 2, 2) 86 | self.conv6_2 = nn.Conv2d(512, 512, 3, 1, 2, 2) 87 | self.conv6_3 = nn.Conv2d(512, 512, 3, 1, 2, 2) 88 | self.conv7_1 = nn.Conv2d(512, 512, 3, 1, 1) 89 | self.conv7_2 = nn.Conv2d(512, 512, 3, 1, 1) 90 | self.conv7_3 = nn.Conv2d(512, 512, 3, 1, 1) 91 | self.conv3_3_short = nn.Conv2d(256, 256, 3, 1, 1) 92 | self.conv8_2 = nn.Conv2d(256, 256, 3, 1, 1) 93 | self.conv8_3 = nn.Conv2d(256, 256, 3, 1, 1) 94 | self.conv2_2_short = nn.Conv2d(128, 128, 3, 1, 1) 95 | self.conv9_2 = nn.Conv2d(128, 128, 3, 1, 1) 96 | self.conv1_2_short = nn.Conv2d(64, 128, 3, 1, 1) 97 | self.conv10_2 = nn.Conv2d(128, 128, 3, 1, 1) 98 | self.conv10_ab = nn.Conv2d(128, 3, 1, 1) 99 | 100 | # add self.relux_x 101 | self.relu1_1 = nn.ReLU() 102 | self.relu1_2 = nn.ReLU() 103 | self.relu2_1 = nn.ReLU() 104 | self.relu2_2 = nn.ReLU() 105 | self.relu3_1 = nn.ReLU() 106 | self.relu3_2 = nn.ReLU() 107 | self.relu3_3 = nn.ReLU() 108 | self.relu4_1 = nn.ReLU() 109 | self.relu4_2 = nn.ReLU() 110 | self.relu4_3 = nn.ReLU() 111 | self.relu5_1 = nn.ReLU() 112 | self.relu5_2 = nn.ReLU() 113 | self.relu5_3 = nn.ReLU() 114 | self.relu6_1 = nn.ReLU() 115 | self.relu6_2 = nn.ReLU() 116 | self.relu6_3 = nn.ReLU() 117 | self.relu7_1 = nn.ReLU() 118 | self.relu7_2 = nn.ReLU() 119 | self.relu7_3 = nn.ReLU() 120 | self.relu8_1_comb = nn.ReLU() 121 | self.relu8_2 = nn.ReLU() 122 | self.relu8_3 = nn.ReLU() 123 | self.relu9_1_comb = nn.ReLU() 124 | self.relu9_2 = nn.ReLU() 125 | self.relu10_1_comb = nn.ReLU() 126 | self.relu10_2 = nn.LeakyReLU(0.2, True) 127 | 128 | # print("replace all deconv with [nearest + conv]") 129 | self.conv8_1 = nn.Sequential(nn.Upsample(scale_factor=2, mode="nearest"), nn.Conv2d(512, 256, 3, 1, 1)) 130 | self.conv9_1 = nn.Sequential(nn.Upsample(scale_factor=2, mode="nearest"), nn.Conv2d(256, 128, 3, 1, 1)) 131 | self.conv10_1 = nn.Sequential(nn.Upsample(scale_factor=2, mode="nearest"), nn.Conv2d(128, 128, 3, 1, 1)) 132 | 133 | # print("replace all batchnorm with instancenorm") 134 | self.conv1_2norm = nn.InstanceNorm2d(64) 135 | self.conv2_2norm = nn.InstanceNorm2d(128) 136 | self.conv3_3norm = nn.InstanceNorm2d(256) 137 | self.conv4_3norm = nn.InstanceNorm2d(512) 138 | self.conv5_3norm = nn.InstanceNorm2d(512) 139 | self.conv6_3norm = nn.InstanceNorm2d(512) 140 | self.conv7_3norm = nn.InstanceNorm2d(512) 141 | self.conv8_3norm = nn.InstanceNorm2d(256) 142 | self.conv9_2norm = nn.InstanceNorm2d(128) 143 | 144 | def forward(self, x): 145 | """ x: gray image (1 channel), ab(2 channel), ab_err, ba_err""" 146 | conv1_1 = self.relu1_1(self.conv1_1(x)) 147 | conv1_2 = self.relu1_2(self.conv1_2(conv1_1)) 148 | conv1_2norm = self.conv1_2norm(conv1_2) 149 | conv1_2norm_ss = self.conv1_2norm_ss(conv1_2norm) 150 | conv2_1 = self.relu2_1(self.conv2_1(conv1_2norm_ss)) 151 | conv2_2 = self.relu2_2(self.conv2_2(conv2_1)) 152 | conv2_2norm = self.conv2_2norm(conv2_2) 153 | conv2_2norm_ss = self.conv2_2norm_ss(conv2_2norm) 154 | conv3_1 = self.relu3_1(self.conv3_1(conv2_2norm_ss)) 155 | conv3_2 = self.relu3_2(self.conv3_2(conv3_1)) 156 | conv3_3 = self.relu3_3(self.conv3_3(conv3_2)) 157 | conv3_3norm = self.conv3_3norm(conv3_3) 158 | conv3_3norm_ss = self.conv3_3norm_ss(conv3_3norm) 159 | conv4_1 = self.relu4_1(self.conv4_1(conv3_3norm_ss)) 160 | conv4_2 = self.relu4_2(self.conv4_2(conv4_1)) 161 | conv4_3 = self.relu4_3(self.conv4_3(conv4_2)) 162 | conv4_3norm = self.conv4_3norm(conv4_3) 163 | conv5_1 = self.relu5_1(self.conv5_1(conv4_3norm)) 164 | conv5_2 = 
self.relu5_2(self.conv5_2(conv5_1)) 165 | conv5_3 = self.relu5_3(self.conv5_3(conv5_2)) 166 | conv5_3norm = self.conv5_3norm(conv5_3) 167 | conv6_1 = self.relu6_1(self.conv6_1(conv5_3norm)) 168 | conv6_2 = self.relu6_2(self.conv6_2(conv6_1)) 169 | conv6_3 = self.relu6_3(self.conv6_3(conv6_2)) 170 | conv6_3norm = self.conv6_3norm(conv6_3) 171 | conv7_1 = self.relu7_1(self.conv7_1(conv6_3norm)) 172 | conv7_2 = self.relu7_2(self.conv7_2(conv7_1)) 173 | conv7_3 = self.relu7_3(self.conv7_3(conv7_2)) 174 | conv7_3norm = self.conv7_3norm(conv7_3) 175 | conv8_1 = self.conv8_1(conv7_3norm) 176 | conv3_3_short = self.conv3_3_short(conv3_3norm) 177 | conv8_1_comb = self.relu8_1_comb(conv8_1 + conv3_3_short) 178 | conv8_2 = self.relu8_2(self.conv8_2(conv8_1_comb)) 179 | conv8_3 = self.relu8_3(self.conv8_3(conv8_2)) 180 | conv8_3norm = self.conv8_3norm(conv8_3) 181 | conv9_1 = self.conv9_1(conv8_3norm) 182 | conv2_2_short = self.conv2_2_short(conv2_2norm) 183 | conv9_1_comb = self.relu9_1_comb(conv9_1 + conv2_2_short) 184 | conv9_2 = self.relu9_2(self.conv9_2(conv9_1_comb)) 185 | conv9_2norm = self.conv9_2norm(conv9_2) 186 | conv10_1 = self.conv10_1(conv9_2norm) 187 | conv1_2_short = self.conv1_2_short(conv1_2norm) 188 | conv10_1_comb = self.relu10_1_comb(conv10_1 + conv1_2_short) 189 | conv10_2 = self.relu10_2(self.conv10_2(conv10_1_comb)) 190 | conv10_ab = self.conv10_ab(conv10_2) 191 | 192 | return torch.tanh(conv10_ab) 193 | -------------------------------------------------------------------------------- /trainer/AlignTrainer.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/python 2 | # -*- encoding: utf-8 -*- 3 | ''' 4 | @author Leslie 5 | @date 20220812 6 | ''' 7 | import torch 8 | 9 | from trainer.ModelTrainer import ModelTrainer 10 | from model.AlignModule.lib import * 11 | from model.AlignModule.discriminator import Discriminator 12 | from itertools import chain 13 | from utils.utils import * 14 | import torch.nn.functional as F 15 | from model.AlignModule.loss import * 16 | import random 17 | import torch.distributed as dist 18 | 19 | class AlignTrainer(ModelTrainer): 20 | 21 | def __init__(self, args): 22 | super().__init__(args) 23 | self.device = 'cpu' 24 | if torch.cuda.is_available(): 25 | self.device = 'cuda' 26 | 27 | self.Epor = PorEncoder(args).to(self.device) 28 | self.Eid = IDEncoder(args.id_model).to(self.device) 29 | self.Epose = PoseEncoder(args).to(self.device) 30 | self.Eexp = ExpEncoder(args).to(self.device) 31 | self.netG = Generator(args).to(self.device) 32 | 33 | self.netD = Discriminator(in_channels=3).to(self.device) 34 | 35 | 36 | self.optimG,self.optimD = self.create_optimizer() 37 | 38 | if args.pretrain_path is not None: 39 | self.loadParameters(args.pretrain_path) 40 | 41 | if args.dist: 42 | self.netG,self.netG_module = self.use_ddp(self.netG) 43 | self.Eexp,self.Eexp_module = self.use_ddp(self.Eexp) 44 | self.Epor,self.Epor_module = self.use_ddp(self.Epor) 45 | self.Epose,self.Epose_module = self.use_ddp(self.Epose) 46 | self.netD,self.netD_module = self.use_ddp(self.netD) 47 | else: 48 | self.netG_module = self.netG 49 | self.Eexp_module = self.Eexp 50 | self.Epor_module = self.Epor 51 | self.Epose_module = self.Epose 52 | self.netD_module = self.netD 53 | 54 | if self.args.per_loss: 55 | self.perLoss = PerceptualLoss(args.per_model).to(self.device) 56 | self.perLoss.eval() 57 | 58 | if self.args.rec_loss: 59 | self.L1Loss = torch.nn.L1Loss() 60 | self.Eid.eval() 61 | 62 | def create_optimizer(self): 63 | 
g_optim = torch.optim.Adam( 64 | chain(self.Epor.parameters(),self.Eexp.parameters(), 65 | self.Epose.parameters(),self.netG.parameters()), 66 | lr=self.args.g_lr, 67 | betas=(self.args.beta1, self.args.beta2), 68 | ) 69 | d_optim = torch.optim.Adam( 70 | self.netD.parameters(), 71 | lr=self.args.d_lr, 72 | betas=(self.args.beta1, self.args.beta2), 73 | ) 74 | 75 | return g_optim,d_optim 76 | 77 | 78 | def run_single_step(self, data, steps): 79 | self.netG.train() 80 | self.Epor.train() 81 | self.Epose.train() 82 | self.Eexp.train() 83 | super().run_single_step(data, steps) 84 | 85 | 86 | def run_discriminator_one_step(self, data,step): 87 | 88 | D_losses = {} 89 | requires_grad(self.netG, False) 90 | requires_grad(self.Epor, False) 91 | requires_grad(self.Epose, False) 92 | requires_grad(self.Eexp, False) 93 | requires_grad(self.netD, True) 94 | 95 | xs,xt,crop_xt,gt = data 96 | xg = self.forward(xs,crop_xt,xt) 97 | fake_pred = self.netD(xg) 98 | real_pred = self.netD(gt) 99 | d_loss = compute_dis_loss(fake_pred, real_pred,D_losses) 100 | D_losses['d'] = d_loss 101 | 102 | self.optimD.zero_grad() 103 | d_loss.backward() 104 | self.optimD.step() 105 | 106 | self.d_losses = D_losses 107 | 108 | 109 | def run_generator_one_step(self, data,step): 110 | 111 | 112 | requires_grad(self.netG, True) 113 | requires_grad(self.Epor, True) 114 | requires_grad(self.Epose, True) 115 | requires_grad(self.Eexp, True) 116 | requires_grad(self.netD, False) 117 | 118 | xs,xt,crop_xt,gt = data 119 | G_losses,loss,xg = self.compute_g_loss(xs,crop_xt,xt,gt) 120 | self.optimG.zero_grad() 121 | loss.mean().backward() 122 | self.optimG.step() 123 | 124 | self.g_losses = G_losses 125 | 126 | self.generator = [xs[:,0].detach() if len(xs.shape)>4 else xs.detach(),xt.detach(),xg.detach(),gt.detach()] 127 | 128 | 129 | def evalution(self,test_loader,steps,epoch): 130 | 131 | loss_dict = {} 132 | index = random.randint(0,len(test_loader)-1) 133 | counter = 0 134 | with torch.no_grad(): 135 | for i,data in enumerate(test_loader): 136 | 137 | data = self.process_input(data) 138 | xs,xt,crop_xt,gt = data 139 | G_losses,losses,xg = self.compute_g_loss(xs,crop_xt,xt,gt) 140 | for k,v in G_losses.items(): 141 | loss_dict[k] = loss_dict.get(k,0) + v.detach() 142 | if i == index and self.args.rank == 0 : 143 | 144 | show_data = [xs[:,0],xt,xg,gt] 145 | self.val_vis.display_current_results(self.select_img(show_data),steps) 146 | counter += 1 147 | 148 | 149 | for key,val in loss_dict.items(): 150 | loss_dict[key] /= counter 151 | 152 | if self.args.dist: 153 | # if self.args.rank == 0 : 154 | dist_losses = loss_dict.copy() 155 | for key,val in loss_dict.items(): 156 | 157 | dist.reduce(dist_losses[key],0) 158 | value = dist_losses[key].item() 159 | loss_dict[key] = value / self.args.world_size 160 | 161 | if self.args.rank == 0 : 162 | self.val_vis.plot_current_errors(loss_dict,steps) 163 | self.val_vis.print_current_errors(epoch+1,0,loss_dict,0) 164 | 165 | return loss_dict 166 | 167 | 168 | def forward(self,xs,crop_xt,xt): 169 | 170 | por_f = self.Epor(xs) 171 | id_f = self.Eid(self.process_id_input(xs,crop=True)) 172 | 173 | pose_f = self.Epose(xt) 174 | exp_f = self.Eexp(self.process_id_input(crop_xt,size=256)) 175 | 176 | xg = self.netG(por_f,id_f,pose_f,exp_f) 177 | 178 | return xg 179 | 180 | def compute_g_loss(self,xs,crop_xt,xt,gt): 181 | G_losses = {} 182 | loss = 0 183 | xg = self.forward(xs,crop_xt,xt) 184 | fake_pred = self.netD(xg) 185 | gan_loss = compute_gan_loss(fake_pred) * self.args.lambda_gan 186 | 
G_losses['g_losses'] = gan_loss 187 | loss += gan_loss 188 | 189 | if self.args.rec_loss: 190 | rec_loss = self.L1Loss(xg,gt) * self.args.lambda_rec 191 | G_losses['rec_loss'] = rec_loss 192 | loss += rec_loss 193 | 194 | if self.args.id_loss: 195 | fake_id_f = self.Eid(self.process_id_input(xg,crop=True)) 196 | real_id_f = self.Eid(self.process_id_input(gt,crop=True)) 197 | id_loss = compute_id_loss(fake_id_f,real_id_f).mean() * self.args.lambda_id 198 | G_losses['id_loss'] = id_loss 199 | loss += id_loss 200 | 201 | if self.args.per_loss: 202 | per_loss = self.perLoss(xg,gt) * self.args.lambda_per 203 | G_losses['per_loss'] = per_loss 204 | loss += per_loss 205 | 206 | return G_losses,loss,xg 207 | 208 | @staticmethod 209 | def process_id_input(x,crop=False,size=112): 210 | c,h,w = x.shape[-3:] 211 | batch = x.shape[0] 212 | scale = 0.4 / 1.8 213 | if crop: 214 | crop_x = x[...,int(h*scale):int(-h*scale),int(w*scale):int(-w*scale)] 215 | else: 216 | crop_x = x 217 | if len(x.shape) > 4: 218 | resize_x = F.adaptive_avg_pool2d(crop_x.view(-1,*crop_x.shape[-3:]),size) 219 | resize_x = resize_x.view(batch,-1,c,size,size) 220 | else: 221 | resize_x = F.adaptive_avg_pool2d(crop_x,size) 222 | return resize_x 223 | def get_latest_losses(self): 224 | return {**self.g_losses,**self.d_losses} 225 | 226 | def get_latest_generated(self): 227 | return self.generator 228 | 229 | def loadParameters(self,path): 230 | ckpt = torch.load(path, map_location=lambda storage, loc: storage) 231 | self.netG.load_state_dict(ckpt['G'],strict=False) 232 | self.Eexp.load_state_dict(ckpt['Eexp'],strict=False) 233 | self.Eid.load_state_dict(ckpt['Eid'],strict=False) 234 | self.Epor.load_state_dict(ckpt['Epor'],strict=False) 235 | self.Epose.load_state_dict(ckpt['Epose'],strict=False) 236 | # self.netD.load_state_dict(ckpt['D'],strict=False) 237 | self.optimG.load_state_dict(ckpt['g_optim']) 238 | # self.optimD.load_state_dict(ckpt['d_optim']) 239 | 240 | def saveParameters(self,path): 241 | torch.save( 242 | { 243 | "G": self.netG_module.state_dict(), 244 | 'D':self.netD_module.state_dict(), 245 | "Eexp": self.Eexp_module.state_dict(), 246 | "Eid":self.Eid.state_dict(), 247 | 'Epor':self.Epor_module.state_dict(), 248 | 'Epose':self.Epose_module.state_dict(), 249 | "g_optim": self.optimG.state_dict(), 250 | "d_optim": self.optimD.state_dict(), 251 | "args": self.args, 252 | }, 253 | path 254 | ) 255 | 256 | def get_lr(self): 257 | return self.optimG.state_dict()['param_groups'][0]['lr'] 258 | 259 | 260 | def select_img(self, data, name='fake', axis=2): 261 | data = [F.adaptive_avg_pool2d(x,self.args.output_image_size) for x in data] 262 | return super().select_img(data, name, axis) 263 | 264 | 265 | 266 | 267 | 268 | 269 | 270 | 271 | 272 | 273 | 274 | 275 | 276 | 277 | -------------------------------------------------------------------------------- /model/third/model.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # -*- encoding: utf-8 -*- 3 | 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | import torchvision 9 | 10 | from .resnet import Resnet18 11 | # from modules.bn import InPlaceABNSync as BatchNorm2d 12 | 13 | 14 | class ConvBNReLU(nn.Module): 15 | def __init__(self, in_chan, out_chan, ks=3, stride=1, padding=1, *args, **kwargs): 16 | super(ConvBNReLU, self).__init__() 17 | self.conv = nn.Conv2d(in_chan, 18 | out_chan, 19 | kernel_size = ks, 20 | stride = stride, 21 | padding = padding, 22 | bias = False) 23 | 
self.bn = nn.BatchNorm2d(out_chan) 24 | self.init_weight() 25 | 26 | def forward(self, x): 27 | x = self.conv(x) 28 | x = F.relu(self.bn(x)) 29 | return x 30 | 31 | def init_weight(self): 32 | for ly in self.children(): 33 | if isinstance(ly, nn.Conv2d): 34 | nn.init.kaiming_normal_(ly.weight, a=1) 35 | if not ly.bias is None: nn.init.constant_(ly.bias, 0) 36 | 37 | class BiSeNetOutput(nn.Module): 38 | def __init__(self, in_chan, mid_chan, n_classes, *args, **kwargs): 39 | super(BiSeNetOutput, self).__init__() 40 | self.conv = ConvBNReLU(in_chan, mid_chan, ks=3, stride=1, padding=1) 41 | self.conv_out = nn.Conv2d(mid_chan, n_classes, kernel_size=1, bias=False) 42 | self.init_weight() 43 | 44 | def forward(self, x): 45 | x = self.conv(x) 46 | x = self.conv_out(x) 47 | return x 48 | 49 | def init_weight(self): 50 | for ly in self.children(): 51 | if isinstance(ly, nn.Conv2d): 52 | nn.init.kaiming_normal_(ly.weight, a=1) 53 | if not ly.bias is None: nn.init.constant_(ly.bias, 0) 54 | 55 | def get_params(self): 56 | wd_params, nowd_params = [], [] 57 | for name, module in self.named_modules(): 58 | if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d): 59 | wd_params.append(module.weight) 60 | if not module.bias is None: 61 | nowd_params.append(module.bias) 62 | elif isinstance(module, nn.BatchNorm2d): 63 | nowd_params += list(module.parameters()) 64 | return wd_params, nowd_params 65 | 66 | 67 | class AttentionRefinementModule(nn.Module): 68 | def __init__(self, in_chan, out_chan, *args, **kwargs): 69 | super(AttentionRefinementModule, self).__init__() 70 | self.conv = ConvBNReLU(in_chan, out_chan, ks=3, stride=1, padding=1) 71 | self.conv_atten = nn.Conv2d(out_chan, out_chan, kernel_size= 1, bias=False) 72 | self.bn_atten = nn.BatchNorm2d(out_chan) 73 | self.sigmoid_atten = nn.Sigmoid() 74 | self.init_weight() 75 | 76 | def forward(self, x): 77 | feat = self.conv(x) 78 | atten = F.avg_pool2d(feat, feat.size()[2:]) 79 | atten = self.conv_atten(atten) 80 | atten = self.bn_atten(atten) 81 | atten = self.sigmoid_atten(atten) 82 | out = torch.mul(feat, atten) 83 | return out 84 | 85 | def init_weight(self): 86 | for ly in self.children(): 87 | if isinstance(ly, nn.Conv2d): 88 | nn.init.kaiming_normal_(ly.weight, a=1) 89 | if not ly.bias is None: nn.init.constant_(ly.bias, 0) 90 | 91 | 92 | class ContextPath(nn.Module): 93 | def __init__(self, *args, **kwargs): 94 | super(ContextPath, self).__init__() 95 | self.resnet = Resnet18() 96 | self.arm16 = AttentionRefinementModule(256, 128) 97 | self.arm32 = AttentionRefinementModule(512, 128) 98 | self.conv_head32 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1) 99 | self.conv_head16 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1) 100 | self.conv_avg = ConvBNReLU(512, 128, ks=1, stride=1, padding=0) 101 | 102 | self.init_weight() 103 | 104 | def forward(self, x): 105 | H0, W0 = x.size()[2:] 106 | feat8, feat16, feat32 = self.resnet(x) 107 | H8, W8 = feat8.size()[2:] 108 | H16, W16 = feat16.size()[2:] 109 | H32, W32 = feat32.size()[2:] 110 | 111 | avg = F.avg_pool2d(feat32, feat32.size()[2:]) 112 | avg = self.conv_avg(avg) 113 | avg_up = F.interpolate(avg, (H32, W32), mode='nearest') 114 | 115 | feat32_arm = self.arm32(feat32) 116 | feat32_sum = feat32_arm + avg_up 117 | feat32_up = F.interpolate(feat32_sum, (H16, W16), mode='nearest') 118 | feat32_up = self.conv_head32(feat32_up) 119 | 120 | feat16_arm = self.arm16(feat16) 121 | feat16_sum = feat16_arm + feat32_up 122 | feat16_up = F.interpolate(feat16_sum, (H8, W8), mode='nearest') 123 | 
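        # Refine the feature just upsampled to 1/8 resolution with a 3x3 ConvBNReLU;
        # it is returned below as the second x8 output.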
feat16_up = self.conv_head16(feat16_up) 124 | 125 | return feat8, feat16_up, feat32_up # x8, x8, x16 126 | 127 | def init_weight(self): 128 | for ly in self.children(): 129 | if isinstance(ly, nn.Conv2d): 130 | nn.init.kaiming_normal_(ly.weight, a=1) 131 | if not ly.bias is None: nn.init.constant_(ly.bias, 0) 132 | 133 | def get_params(self): 134 | wd_params, nowd_params = [], [] 135 | for name, module in self.named_modules(): 136 | if isinstance(module, (nn.Linear, nn.Conv2d)): 137 | wd_params.append(module.weight) 138 | if not module.bias is None: 139 | nowd_params.append(module.bias) 140 | elif isinstance(module, nn.BatchNorm2d): 141 | nowd_params += list(module.parameters()) 142 | return wd_params, nowd_params 143 | 144 | 145 | ### This is not used, since I replace this with the resnet feature with the same size 146 | class SpatialPath(nn.Module): 147 | def __init__(self, *args, **kwargs): 148 | super(SpatialPath, self).__init__() 149 | self.conv1 = ConvBNReLU(3, 64, ks=7, stride=2, padding=3) 150 | self.conv2 = ConvBNReLU(64, 64, ks=3, stride=2, padding=1) 151 | self.conv3 = ConvBNReLU(64, 64, ks=3, stride=2, padding=1) 152 | self.conv_out = ConvBNReLU(64, 128, ks=1, stride=1, padding=0) 153 | self.init_weight() 154 | 155 | def forward(self, x): 156 | feat = self.conv1(x) 157 | feat = self.conv2(feat) 158 | feat = self.conv3(feat) 159 | feat = self.conv_out(feat) 160 | return feat 161 | 162 | def init_weight(self): 163 | for ly in self.children(): 164 | if isinstance(ly, nn.Conv2d): 165 | nn.init.kaiming_normal_(ly.weight, a=1) 166 | if not ly.bias is None: nn.init.constant_(ly.bias, 0) 167 | 168 | def get_params(self): 169 | wd_params, nowd_params = [], [] 170 | for name, module in self.named_modules(): 171 | if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d): 172 | wd_params.append(module.weight) 173 | if not module.bias is None: 174 | nowd_params.append(module.bias) 175 | elif isinstance(module, nn.BatchNorm2d): 176 | nowd_params += list(module.parameters()) 177 | return wd_params, nowd_params 178 | 179 | 180 | class FeatureFusionModule(nn.Module): 181 | def __init__(self, in_chan, out_chan, *args, **kwargs): 182 | super(FeatureFusionModule, self).__init__() 183 | self.convblk = ConvBNReLU(in_chan, out_chan, ks=1, stride=1, padding=0) 184 | self.conv1 = nn.Conv2d(out_chan, 185 | out_chan//4, 186 | kernel_size = 1, 187 | stride = 1, 188 | padding = 0, 189 | bias = False) 190 | self.conv2 = nn.Conv2d(out_chan//4, 191 | out_chan, 192 | kernel_size = 1, 193 | stride = 1, 194 | padding = 0, 195 | bias = False) 196 | self.relu = nn.ReLU(inplace=True) 197 | self.sigmoid = nn.Sigmoid() 198 | self.init_weight() 199 | 200 | def forward(self, fsp, fcp): 201 | fcat = torch.cat([fsp, fcp], dim=1) 202 | feat = self.convblk(fcat) 203 | atten = F.avg_pool2d(feat, feat.size()[2:]) 204 | atten = self.conv1(atten) 205 | atten = self.relu(atten) 206 | atten = self.conv2(atten) 207 | atten = self.sigmoid(atten) 208 | feat_atten = torch.mul(feat, atten) 209 | feat_out = feat_atten + feat 210 | return feat_out 211 | 212 | def init_weight(self): 213 | for ly in self.children(): 214 | if isinstance(ly, nn.Conv2d): 215 | nn.init.kaiming_normal_(ly.weight, a=1) 216 | if not ly.bias is None: nn.init.constant_(ly.bias, 0) 217 | 218 | def get_params(self): 219 | wd_params, nowd_params = [], [] 220 | for name, module in self.named_modules(): 221 | if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d): 222 | wd_params.append(module.weight) 223 | if not module.bias is None: 224 | 
nowd_params.append(module.bias) 225 | elif isinstance(module, nn.BatchNorm2d): 226 | nowd_params += list(module.parameters()) 227 | return wd_params, nowd_params 228 | 229 | 230 | class BiSeNet(nn.Module): 231 | def __init__(self, n_classes, *args, **kwargs): 232 | super(BiSeNet, self).__init__() 233 | self.cp = ContextPath() 234 | ## here self.sp is deleted 235 | self.ffm = FeatureFusionModule(256, 256) 236 | self.conv_out = BiSeNetOutput(256, 256, n_classes) 237 | self.conv_out16 = BiSeNetOutput(128, 64, n_classes) 238 | self.conv_out32 = BiSeNetOutput(128, 64, n_classes) 239 | self.init_weight() 240 | 241 | def forward(self, x): 242 | H, W = x.size()[2:] 243 | feat_res8, feat_cp8, feat_cp16 = self.cp(x) # here return res3b1 feature 244 | feat_sp = feat_res8 # use res3b1 feature to replace spatial path feature 245 | feat_fuse = self.ffm(feat_sp, feat_cp8) 246 | 247 | feat_out = self.conv_out(feat_fuse) 248 | feat_out16 = self.conv_out16(feat_cp8) 249 | feat_out32 = self.conv_out32(feat_cp16) 250 | 251 | feat_out = F.interpolate(feat_out, (H, W), mode='bilinear', align_corners=True) 252 | feat_out16 = F.interpolate(feat_out16, (H, W), mode='bilinear', align_corners=True) 253 | feat_out32 = F.interpolate(feat_out32, (H, W), mode='bilinear', align_corners=True) 254 | return feat_out, feat_out16, feat_out32 255 | 256 | def init_weight(self): 257 | for ly in self.children(): 258 | if isinstance(ly, nn.Conv2d): 259 | nn.init.kaiming_normal_(ly.weight, a=1) 260 | if not ly.bias is None: nn.init.constant_(ly.bias, 0) 261 | 262 | def get_params(self): 263 | wd_params, nowd_params, lr_mul_wd_params, lr_mul_nowd_params = [], [], [], [] 264 | for name, child in self.named_children(): 265 | child_wd_params, child_nowd_params = child.get_params() 266 | if isinstance(child, FeatureFusionModule) or isinstance(child, BiSeNetOutput): 267 | lr_mul_wd_params += child_wd_params 268 | lr_mul_nowd_params += child_nowd_params 269 | else: 270 | wd_params += child_wd_params 271 | nowd_params += child_nowd_params 272 | return wd_params, nowd_params, lr_mul_wd_params, lr_mul_nowd_params 273 | 274 | 275 | if __name__ == "__main__": 276 | net = BiSeNet(19) 277 | net.cuda() 278 | net.eval() 279 | in_ten = torch.randn(16, 3, 640, 480).cuda() 280 | out, out16, out32 = net(in_ten) 281 | print(out.shape) 282 | 283 | net.get_params() 284 | -------------------------------------------------------------------------------- /model/AlignModule/lib/blocks.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn.utils import spectral_norm 4 | 5 | 6 | class AdaptiveNorm2d(nn.Module): 7 | def __init__(self, num_features, norm_layer='in', eps=1e-4): 8 | super(AdaptiveNorm2d, self).__init__() 9 | self.num_features = num_features 10 | self.weight = self.bias = None 11 | if 'in' in norm_layer: 12 | self.norm_layer = nn.InstanceNorm2d(num_features, eps=eps, affine=False) 13 | elif 'bn' in norm_layer: 14 | self.norm_layer = SyncBatchNorm(num_features, momentum=1.0, eps=eps, affine=False) 15 | 16 | self.delete_weight_on_forward = True 17 | 18 | def forward(self, input): 19 | out = self.norm_layer(input) 20 | output = out * self.weight[:, :, None, None] + self.bias[:, :, None, None] 21 | 22 | # To save GPU memory 23 | if self.delete_weight_on_forward: 24 | self.weight = self.bias = None 25 | 26 | return output 27 | 28 | 29 | class AdaptiveNorm2dTrainable(nn.Module): 30 | def __init__(self, num_features, norm_layer='in', eps=1e-4): 31 | 
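        # Same idea as AdaptiveNorm2d, but the affine weight/bias are registered as
        # trainable Parameters via assign_params() instead of being written in
        # externally before each forward pass.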
super(AdaptiveNorm2dTrainable, self).__init__() 32 | self.num_features = num_features 33 | if 'in' in norm_layer: 34 | self.norm_layer = nn.InstanceNorm2d(num_features, eps=eps, affine=False) 35 | 36 | def forward(self, input): 37 | out = self.norm_layer(input) 38 | t = out.shape[0] // self.weight.shape[0] 39 | output = out * self.weight + self.bias 40 | return output 41 | 42 | def assign_params(self, weight, bias): 43 | self.weight = torch.nn.Parameter(weight.view(1, -1, 1, 1)) 44 | self.bias = torch.nn.Parameter(bias.view(1, -1, 1, 1)) 45 | 46 | 47 | class ResBlock(nn.Module): 48 | def __init__(self, in_channels, out_channels, padding, upsample, downsample, 49 | norm_layer, activation=nn.ReLU, gated=False): 50 | super(ResBlock, self).__init__() 51 | normalize = norm_layer != 'none' 52 | bias = not normalize 53 | 54 | # if norm_layer == 'bn': 55 | # # norm0 = SyncBatchNorm(in_channels, momentum=1.0, eps=1e-4) 56 | # # norm1 = SyncBatchNorm(out_channels, momentum=1.0, eps=1e-4) 57 | # pass 58 | if norm_layer == 'in': 59 | norm0 = nn.InstanceNorm2d(in_channels, eps=1e-4, affine=True) 60 | norm1 = nn.InstanceNorm2d(out_channels, eps=1e-4, affine=True) 61 | elif 'ada' in norm_layer: 62 | norm0 = AdaptiveNorm2d(in_channels, norm_layer) 63 | norm1 = AdaptiveNorm2d(out_channels, norm_layer) 64 | elif 'tra' in norm_layer: 65 | norm0 = AdaptiveNorm2dTrainable(in_channels, norm_layer) 66 | norm1 = AdaptiveNorm2dTrainable(out_channels, norm_layer) 67 | elif normalize: 68 | raise Exception('ResBlock: Incorrect `norm_layer` parameter') 69 | 70 | layers = [] 71 | if normalize: 72 | layers.append(norm0) 73 | layers.append(activation(inplace=True)) 74 | if upsample: 75 | layers.append(nn.Upsample(scale_factor=2)) 76 | layers.extend([ 77 | nn.Sequential() if padding is nn.ZeroPad2d else padding(1), 78 | spectral_norm( 79 | nn.Conv2d(in_channels, out_channels, 3, 1, 1 if padding is nn.ZeroPad2d else 0, bias=bias), 80 | eps=1e-4)]) 81 | if normalize: 82 | layers.append(norm1) 83 | layers.extend([ 84 | activation(inplace=True), 85 | nn.Sequential() if padding is nn.ZeroPad2d else padding(1), 86 | spectral_norm( 87 | nn.Conv2d(out_channels, out_channels, 3, 1, 1 if padding is nn.ZeroPad2d else 0, bias=bias), 88 | eps=1e-4)]) 89 | if downsample: 90 | layers.append(nn.AvgPool2d(2)) 91 | self.block = nn.Sequential(*layers) 92 | 93 | self.skip = None 94 | if in_channels != out_channels or upsample or downsample: 95 | layers = [] 96 | if upsample: 97 | layers.append(nn.Upsample(scale_factor=2)) 98 | layers.append(spectral_norm( 99 | nn.Conv2d(in_channels, out_channels, 1), 100 | eps=1e-4)) 101 | if downsample: 102 | layers.append(nn.AvgPool2d(2)) 103 | self.skip = nn.Sequential(*layers) 104 | 105 | def forward(self, input): 106 | out = self.block(input) 107 | if self.skip is not None: 108 | output = out + self.skip(input) 109 | else: 110 | output = out + input 111 | return output 112 | 113 | class channelShuffle(nn.Module): 114 | def __init__(self,groups): 115 | super(channelShuffle, self).__init__() 116 | self.groups=groups 117 | 118 | def forward(self,x): 119 | batchsize, num_channels, height, width = x.data.size() 120 | 121 | # batchsize = x.shape[0] 122 | # num_channels = x.shape[1] 123 | # height = x.shape[2] 124 | # width = x.shape[3] 125 | 126 | channels_per_group = num_channels // self.groups 127 | 128 | # reshape 129 | x = x.view(batchsize, self.groups, channels_per_group, height, width) 130 | 131 | # transpose 132 | # - contiguous() required if transpose() is used before view(). 
133 | # See https://github.com/pytorch/pytorch/issues/764 134 | x = torch.transpose(x, 1, 2).contiguous() 135 | 136 | # flatten 137 | x = x.view(batchsize, -1, height, width) 138 | 139 | return x 140 | 141 | 142 | class shuffleConv(nn.Module): 143 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True, padding_mode='zeros'): 144 | super(shuffleConv, self).__init__() 145 | self.in_channels=in_channels 146 | self.out_channels=out_channels 147 | self.stride=stride 148 | self.padding=padding 149 | groups=4 150 | block=[] 151 | if (in_channels%groups==0) and (out_channels%groups==0): 152 | block.append(spectral_norm(nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1,padding=0, groups=groups),eps=1e-4)) 153 | block.append(nn.ReLU6(inplace=True)) 154 | block.append(channelShuffle(groups=groups)) 155 | block.append(spectral_norm(nn.Conv2d(in_channels=out_channels, out_channels=out_channels,kernel_size=3,padding=1, groups=groups),eps=1e-4)) 156 | block.append(nn.ReLU6(inplace=True)) 157 | block.append(spectral_norm(nn.Conv2d(in_channels=out_channels, out_channels=out_channels,kernel_size=1,padding=0, groups=groups),eps=1e-4)) 158 | else: 159 | block.append(spectral_norm(nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3,padding=1),eps=1e-4)) 160 | self.block=nn.Sequential(*block) 161 | 162 | def forward(self,x): 163 | x=self.block(x) 164 | return x 165 | 166 | 167 | class ResBlockShuffle(nn.Module): 168 | def __init__(self, in_channels, out_channels, padding, upsample, downsample, 169 | norm_layer, activation=nn.ReLU, gated=False): 170 | super(ResBlockShuffle, self).__init__() 171 | normalize = norm_layer != 'none' 172 | bias = not normalize 173 | 174 | # if norm_layer == 'bn': 175 | # # norm0 = SyncBatchNorm(in_channels, momentum=1.0, eps=1e-4) 176 | # # norm1 = SyncBatchNorm(out_channels, momentum=1.0, eps=1e-4) 177 | # pass 178 | if norm_layer == 'in': 179 | norm0 = nn.InstanceNorm2d(in_channels, eps=1e-4, affine=True) 180 | norm1 = nn.InstanceNorm2d(out_channels, eps=1e-4, affine=True) 181 | elif 'ada' in norm_layer: 182 | norm0 = AdaptiveNorm2d(in_channels, norm_layer) 183 | norm1 = AdaptiveNorm2d(out_channels, norm_layer) 184 | elif 'tra' in norm_layer: 185 | norm0 = AdaptiveNorm2dTrainable(in_channels, norm_layer) 186 | norm1 = AdaptiveNorm2dTrainable(out_channels, norm_layer) 187 | elif normalize: 188 | raise Exception('ResBlock: Incorrect `norm_layer` parameter') 189 | 190 | layers = [] 191 | if normalize: 192 | layers.append(norm0) 193 | layers.append(activation(inplace=True)) 194 | if upsample: 195 | layers.append(nn.Upsample(scale_factor=2)) 196 | layers.extend([ 197 | #padding(1), 198 | #spectral_norm( 199 | shuffleConv(in_channels, out_channels, 3, 1, 0, bias=bias)#, 200 | # eps=1e-4) 201 | ]) 202 | if normalize: 203 | layers.append(norm1) 204 | layers.extend([ 205 | activation(inplace=True), 206 | #padding(1), 207 | #spectral_norm( 208 | shuffleConv(out_channels, out_channels, 3, 1, 0, bias=bias)#, 209 | # eps=1e-4) 210 | ]) 211 | if downsample: 212 | layers.append(nn.AvgPool2d(2)) 213 | self.block = nn.Sequential(*layers) 214 | 215 | self.skip = None 216 | if in_channels != out_channels or upsample or downsample: 217 | layers = [] 218 | if upsample: 219 | layers.append(nn.Upsample(scale_factor=2)) 220 | layers.append( 221 | #spectral_norm( 222 | shuffleConv(in_channels, out_channels, 1)#, 223 | # eps=1e-4) 224 | ) 225 | if downsample: 226 | layers.append(nn.AvgPool2d(2)) 227 
| self.skip = nn.Sequential(*layers) 228 | 229 | def forward(self, input): 230 | out = self.block(input) 231 | if self.skip is not None: 232 | output = out + self.skip(input) 233 | else: 234 | output = out + input 235 | return output 236 | 237 | 238 | 239 | class ResBlockV2(nn.Module): 240 | def __init__(self, in_channels, out_channels, stride, groups, 241 | resize_layer, norm_layer, activation): 242 | super(ResBlockV2, self).__init__() 243 | upsampling_layers = { 244 | 'nearest': lambda: nn.Upsample(scale_factor=stride, mode='nearest') 245 | } 246 | downsampling_layers = { 247 | 'avgpool': lambda: nn.AvgPool2d(stride) 248 | } 249 | norm_layers = { 250 | 'bn': lambda num_features: SyncBatchNorm(num_features, momentum=1.0, eps=1e-4), 251 | 'in': lambda num_features: nn.InstanceNorm2d(num_features, eps=1e-4, affine=True), 252 | 'adabn': lambda num_features: AdaptiveNorm2d(num_features, 'bn'), 253 | 'adain': lambda num_features: AdaptiveNorm2d(num_features, 'in') 254 | } 255 | normalize = norm_layer != 'none' 256 | bias = not normalize 257 | upsample = resize_layer in upsampling_layers 258 | downsample = resize_layer in downsampling_layers 259 | if normalize: 260 | norm_layer = norm_layers[norm_layer] 261 | 262 | layers = [] 263 | if normalize: 264 | layers.append(norm_layer(in_channels)) 265 | layers.append(activation()) 266 | if upsample: 267 | layers.append(nn.Upsample(scale_factor=2)) 268 | layers.extend([ 269 | spectral_norm( 270 | nn.Conv2d(in_channels, out_channels, 3, 1, 1, bias=bias), 271 | eps=1e-4)]) 272 | if normalize: 273 | layers.append(norm_layer(out_channels)) 274 | layers.extend([ 275 | activation(), 276 | spectral_norm( 277 | nn.Conv2d(out_channels, out_channels, 3, 1, 1, bias=bias), 278 | eps=1e-4)]) 279 | if downsample: 280 | layers.append(nn.AvgPool2d(2)) 281 | self.block = nn.Sequential(*layers) 282 | 283 | self.skip = None 284 | if in_channels != out_channels or upsample or downsample: 285 | layers = [] 286 | if upsample: 287 | layers.append(nn.Upsample(scale_factor=2)) 288 | layers.append(spectral_norm( 289 | nn.Conv2d(in_channels, out_channels, 1), 290 | eps=1e-4)) 291 | if downsample: 292 | layers.append(nn.AvgPool2d(2)) 293 | self.skip = nn.Sequential(*layers) 294 | 295 | def forward(self, input): 296 | out = self.block(input) 297 | if self.skip is not None: 298 | output = out + self.skip(input) 299 | else: 300 | output = out + input 301 | return output 302 | 303 | class ResBlockV2Shuffle(nn.Module): 304 | def __init__(self, in_channels, out_channels, stride, groups, 305 | resize_layer, norm_layer, activation): 306 | super(ResBlockV2Shuffle, self).__init__() 307 | upsampling_layers = { 308 | 'nearest': lambda: nn.Upsample(scale_factor=stride, mode='nearest') 309 | } 310 | downsampling_layers = { 311 | 'avgpool': lambda: nn.AvgPool2d(stride) 312 | } 313 | norm_layers = { 314 | 'bn': lambda num_features: SyncBatchNorm(num_features, momentum=1.0, eps=1e-4), 315 | 'in': lambda num_features: nn.InstanceNorm2d(num_features, eps=1e-4, affine=True), 316 | 'adabn': lambda num_features: AdaptiveNorm2d(num_features, 'bn'), 317 | 'adain': lambda num_features: AdaptiveNorm2d(num_features, 'in') 318 | } 319 | normalize = norm_layer != 'none' 320 | bias = not normalize 321 | upsample = resize_layer in upsampling_layers 322 | downsample = resize_layer in downsampling_layers 323 | if normalize: 324 | norm_layer = norm_layers[norm_layer] 325 | 326 | layers = [] 327 | if normalize: 328 | layers.append(norm_layer(in_channels)) 329 | layers.append(activation()) 330 | if upsample: 331 | 
layers.append(nn.Upsample(scale_factor=2)) 332 | layers.extend([ 333 | #spectral_norm( 334 | shuffleConv(in_channels, out_channels, 3, 1, 1, bias=bias)#, 335 | # eps=1e-4) 336 | ]) 337 | if normalize: 338 | layers.append(norm_layer(out_channels)) 339 | layers.extend([ 340 | activation(), 341 | #spectral_norm( 342 | shuffleConv(out_channels, out_channels, 3, 1, 1, bias=bias)#, 343 | # eps=1e-4) 344 | ]) 345 | if downsample: 346 | layers.append(nn.AvgPool2d(2)) 347 | self.block = nn.Sequential(*layers) 348 | 349 | self.skip = None 350 | if in_channels != out_channels or upsample or downsample: 351 | layers = [] 352 | if upsample: 353 | layers.append(nn.Upsample(scale_factor=2)) 354 | layers.append(#spectral_norm( 355 | shuffleConv(in_channels, out_channels, 1)#, 356 | # eps=1e-4) 357 | ) 358 | if downsample: 359 | layers.append(nn.AvgPool2d(2)) 360 | self.skip = nn.Sequential(*layers) 361 | 362 | def forward(self, input): 363 | out = self.block(input) 364 | if self.skip is not None: 365 | output = out + self.skip(input) 366 | else: 367 | output = out + input 368 | return output 369 | 370 | 371 | 372 | class GatedBlock(nn.Module): 373 | def __init__(self, in_channels, out_channels, act_fun, kernel_size, stride=1, padding=0, bias=True): 374 | super(GatedBlock, self).__init__() 375 | self.conv = spectral_norm(nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=bias), 376 | eps=1e-4) 377 | self.gate = spectral_norm(nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=bias), 378 | eps=1e-4) 379 | self.act_fun = act_fun() 380 | self.gate_act_fun = nn.Sigmoid() 381 | 382 | def forward(self, x): 383 | out = self.conv(x) 384 | out = self.act_fun(out) 385 | 386 | mask = self.gate(x) 387 | mask = self.gate_act_fun(mask) 388 | 389 | out_masked = out * mask 390 | return out_masked 391 | 392 | 393 | class GatedResBlock(nn.Module): 394 | def __init__(self, in_channels, out_channels, padding, upsample, downsample, 395 | norm_layer, activation=nn.ReLU): 396 | super(GatedResBlock, self).__init__() 397 | normalize = norm_layer != 'none' 398 | bias = not normalize 399 | 400 | if norm_layer == 'in': 401 | norm0 = nn.InstanceNorm2d(in_channels, eps=1e-4, affine=True) 402 | norm1 = nn.InstanceNorm2d(out_channels, eps=1e-4, affine=True) 403 | elif 'ada' in norm_layer: 404 | norm0 = AdaptiveNorm2d(in_channels, norm_layer) 405 | norm1 = AdaptiveNorm2d(out_channels, norm_layer) 406 | elif 'tra' in norm_layer: 407 | norm0 = AdaptiveNorm2dTrainable(in_channels, norm_layer) 408 | norm1 = AdaptiveNorm2dTrainable(out_channels, norm_layer) 409 | elif normalize: 410 | raise Exception('ResBlock: Incorrect `norm_layer` parameter') 411 | 412 | main_layers = [] 413 | 414 | if normalize: 415 | main_layers.append(norm0) 416 | if upsample: 417 | main_layers.append(nn.Upsample(scale_factor=2)) 418 | 419 | main_layers.extend([ 420 | padding(1), 421 | GatedBlock(in_channels, out_channels, activation, 3, 1, 0, bias=bias)]) 422 | 423 | if normalize: 424 | main_layers.append(norm1) 425 | main_layers.extend([ 426 | padding(1), 427 | GatedBlock(out_channels, out_channels, activation, 3, 1, 0, bias=bias)]) 428 | if downsample: 429 | main_layers.append(nn.AvgPool2d(2)) 430 | 431 | self.main_pipe = nn.Sequential(*main_layers) 432 | 433 | self.skip_pipe = None 434 | if in_channels != out_channels or upsample or downsample: 435 | skip_layers = [] 436 | 437 | if upsample: 438 | skip_layers.append(nn.Upsample(scale_factor=2)) 439 | 440 | skip_layers.append(GatedBlock(in_channels, out_channels, activation, 
1)) 441 | 442 | if downsample: 443 | skip_layers.append(nn.AvgPool2d(2)) 444 | self.skip_pipe = nn.Sequential(*skip_layers) 445 | 446 | def forward(self, input): 447 | mp_out = self.main_pipe(input) 448 | if self.skip_pipe is not None: 449 | output = mp_out + self.skip_pipe(input) 450 | else: 451 | output = mp_out + input 452 | return output 453 | 454 | 455 | class ResBlockWithoutSpectralNorms(nn.Module): 456 | def __init__(self, in_channels, out_channels, padding, upsample, downsample, 457 | norm_layer, activation=nn.ReLU): 458 | super(ResBlockWithoutSpectralNorms, self).__init__() 459 | normalize = norm_layer != 'none' 460 | bias = not normalize 461 | 462 | # if norm_layer == 'bn': 463 | # # norm0 = SyncBatchNorm(in_channels, momentum=1.0, eps=1e-4) 464 | # # norm1 = SyncBatchNorm(out_channels, momentum=1.0, eps=1e-4) 465 | # pass 466 | if norm_layer == 'in': 467 | norm0 = nn.InstanceNorm2d(in_channels, eps=1e-4, affine=True) 468 | norm1 = nn.InstanceNorm2d(out_channels, eps=1e-4, affine=True) 469 | elif 'ada' in norm_layer: 470 | norm0 = AdaptiveNorm2d(in_channels, norm_layer) 471 | norm1 = AdaptiveNorm2d(out_channels, norm_layer) 472 | elif 'tra' in norm_layer: 473 | norm0 = AdaptiveNorm2dTrainable(in_channels, norm_layer) 474 | norm1 = AdaptiveNorm2dTrainable(out_channels, norm_layer) 475 | elif normalize: 476 | raise Exception('ResBlock: Incorrect `norm_layer` parameter') 477 | 478 | layers = [] 479 | if normalize: 480 | layers.append(norm0) 481 | layers.append(activation(inplace=True)) 482 | if upsample: 483 | layers.append(nn.Upsample(scale_factor=2)) 484 | layers.extend([ 485 | padding(1), 486 | # spectral_norm( 487 | nn.Conv2d(in_channels, out_channels, 3, 1, 0, bias=bias) # , 488 | # eps=1e-4) 489 | ]) 490 | if normalize: 491 | layers.append(norm1) 492 | layers.extend([ 493 | activation(inplace=True), 494 | padding(1), 495 | # spectral_norm( 496 | nn.Conv2d(out_channels, out_channels, 3, 1, 0, bias=bias) # , 497 | # eps=1e-4) 498 | ]) 499 | if downsample: 500 | layers.append(nn.AvgPool2d(2)) 501 | self.block = nn.Sequential(*layers) 502 | 503 | self.skip = None 504 | if in_channels != out_channels or upsample or downsample: 505 | layers = [] 506 | if upsample: 507 | layers.append(nn.Upsample(scale_factor=2)) 508 | layers.append( # spectral_norm( 509 | nn.Conv2d(in_channels, out_channels, 1) # , 510 | # eps=1e-4) 511 | ) 512 | if downsample: 513 | layers.append(nn.AvgPool2d(2)) 514 | self.skip = nn.Sequential(*layers) 515 | 516 | def forward(self, input): 517 | out = self.block(input) 518 | if self.skip is not None: 519 | output = out + self.skip(input) 520 | else: 521 | output = out + input 522 | return output 523 | 524 | 525 | class MobileNetBlock(nn.Module): 526 | def __init__(self, in_channels, out_channels, padding, upsample, downsample, 527 | norm_layer, activation=nn.ReLU6, expansion_factor=6): 528 | super(MobileNetBlock, self).__init__() 529 | normalize = norm_layer != 'none' 530 | bias = not normalize 531 | 532 | conv0 = nn.Conv2d(in_channels, int(in_channels * expansion_factor), 1) 533 | dwise = nn.Conv2d(int(in_channels * expansion_factor), int(in_channels * expansion_factor), 3, 534 | 2 if downsample else 1, 1, groups=int(in_channels * expansion_factor)) 535 | conv1 = nn.Conv2d(int(in_channels * expansion_factor), out_channels, 1) 536 | 537 | if norm_layer == 'bn': 538 | # norm0 = SyncBatchNorm(in_channels, momentum=1.0, eps=1e-4) 539 | # norm1 = SyncBatchNorm(out_channels, momentum=1.0, eps=1e-4) 540 | pass 541 | if 'in' in norm_layer: 542 | norm0 = 
nn.InstanceNorm2d(int(in_channels * expansion_factor), eps=1e-4, affine=True) 543 | norm1 = nn.InstanceNorm2d(int(in_channels * expansion_factor), eps=1e-4, affine=True) 544 | norm2 = nn.InstanceNorm2d(out_channels, eps=1e-4, affine=True) 545 | if 'ada' in norm_layer: 546 | norm2 = AdaptiveNorm2d(out_channels, norm_layer) 547 | elif 'tra' in norm_layer: 548 | norm2 = AdaptiveNorm2dTrainable(out_channels, norm_layer) 549 | 550 | # layers = [spectral_norm(conv0, eps=1e-4)] 551 | layers = [conv0] 552 | if normalize: layers.append(norm0) 553 | layers.append(activation(inplace=True)) 554 | if upsample: layers.append(nn.Upsample(scale_factor=2)) 555 | # layers.append(spectral_norm(dwise, eps=1e-4)) 556 | layers.append(dwise) 557 | if normalize: layers.append(norm1) 558 | layers.extend([ 559 | activation(inplace=True), 560 | # spectral_norm( 561 | conv1 # , 562 | # eps=1e-4) 563 | ]) 564 | if normalize: layers.append(norm2) 565 | self.block = nn.Sequential(*layers) 566 | 567 | self.skip = None 568 | if in_channels != out_channels or upsample or downsample: 569 | layers = [] 570 | if upsample: layers.append(nn.Upsample(scale_factor=2)) 571 | layers.append( 572 | # spectral_norm( 573 | nn.Conv2d(in_channels, out_channels, 1) # , 574 | # eps=1e-4) 575 | ) 576 | if downsample: 577 | layers.append(nn.AvgPool2d(2)) 578 | self.skip = nn.Sequential(*layers) 579 | 580 | def forward(self, input): 581 | out = self.block(input) 582 | if self.skip is not None: 583 | output = out + self.skip(input) 584 | else: 585 | output = out + input 586 | return output 587 | 588 | 589 | class SelfAttention(nn.Module): 590 | def __init__(self, in_channels): 591 | super(SelfAttention, self).__init__() 592 | self.in_channels = in_channels 593 | self.query_conv = nn.Conv2d(in_channels, in_channels // 8, 1) 594 | self.key_conv = nn.Conv2d(in_channels, in_channels // 8, 1) 595 | self.value_conv = nn.Conv2d(in_channels, in_channels, 1) 596 | self.gamma = nn.Parameter(torch.zeros(1)) 597 | self.softmax = nn.Softmax(-1) 598 | 599 | def forward(self, input): 600 | b, c, h, w = input.shape 601 | query = self.query_conv(input).view(b, -1, h * w).permute(0, 2, 1) # B x HW x C/8 602 | key = self.key_conv(input).view(b, -1, h * w) # B x C/8 x HW 603 | energy = torch.bmm(query, key) # B x HW x HW 604 | attention = self.softmax(energy) # B x HW x HW 605 | value = self.value_conv(input).view(b, -1, h * w) # B x C x HW 606 | 607 | out = torch.bmm(value, attention.permute(0, 2, 1)).view(b, c, h, w) 608 | output = self.gamma * out + input 609 | return output 610 | --------------------------------------------------------------------------------
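A minimal usage sketch of the blocks above (illustrative only, not the project's actual generator wiring): it stacks one downsampling ResBlock and a SelfAttention layer under the 'in' (instance-norm) configuration, since the 'bn'/'adabn' branches in this file refer to a SyncBatchNorm that is not imported here. The channel sizes and input shape are arbitrary, and the import path assumes the script is run from the repository root.

import torch
from torch import nn
from model.AlignModule.lib.blocks import ResBlock, SelfAttention

# Residual block that halves the spatial size and doubles the channels,
# followed by a self-attention layer at the reduced resolution.
block = ResBlock(64, 128, padding=nn.ZeroPad2d, upsample=False, downsample=True,
                 norm_layer='in', activation=nn.ReLU)
attn = SelfAttention(128)

x = torch.randn(2, 64, 64, 64)   # B x C x H x W
y = attn(block(x))               # -> torch.Size([2, 128, 32, 32])
print(y.shape)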