├── README.md ├── assets ├── pipeline-2.png └── test ├── code ├── cam_utils.py ├── dataset.py ├── networks │ ├── __init__.py │ ├── encoder3d.py │ └── headnerf.py ├── pretrained_models │ └── eg3d │ │ └── test ├── run_recon_video_3dmm.py ├── run_recon_video_audio.py ├── run_recon_video_rgb.py ├── test ├── train_3dmm.py ├── train_audio.py ├── train_rgb.py ├── trainer_3dmm.py ├── trainer_audio.py └── trainer_rgb.py └── eg3d-pose-detection ├── 3dface2idr.py ├── batch_mtcnn.py ├── camera2label.py ├── crop_images.py ├── data ├── __init__.py ├── base_dataset.py ├── flist_dataset.py ├── image_folder.py └── template_dataset.py ├── models ├── __init__.py ├── arcface_torch │ ├── README.md │ ├── backbones │ │ ├── __init__.py │ │ ├── iresnet.py │ │ ├── iresnet2060.py │ │ ├── mobilefacenet.py │ │ └── vit.py │ ├── configs │ │ ├── 3millions.py │ │ ├── __init__.py │ │ ├── base.py │ │ ├── glint360k_mbf.py │ │ ├── glint360k_r100.py │ │ ├── glint360k_r50.py │ │ ├── ms1mv2_mbf.py │ │ ├── ms1mv2_r100.py │ │ ├── ms1mv2_r50.py │ │ ├── ms1mv3_mbf.py │ │ ├── ms1mv3_r100.py │ │ ├── ms1mv3_r50.py │ │ ├── wf12m_conflict_r50.py │ │ ├── wf12m_conflict_r50_pfc03_filter04.py │ │ ├── wf12m_flip_pfc01_filter04_r50.py │ │ ├── wf12m_flip_r50.py │ │ ├── wf12m_mbf.py │ │ ├── wf12m_pfc02_r100.py │ │ ├── wf12m_r100.py │ │ ├── wf12m_r50.py │ │ ├── wf42m_pfc0008_32gpu_r100.py │ │ ├── wf42m_pfc02_16gpus_mbf_bs8k.py │ │ ├── wf42m_pfc02_16gpus_r100.py │ │ ├── wf42m_pfc02_16gpus_r50_bs8k.py │ │ ├── wf42m_pfc02_32gpus_r50_bs4k.py │ │ ├── wf42m_pfc02_8gpus_r50_bs4k.py │ │ ├── wf42m_pfc02_r100.py │ │ ├── wf42m_pfc02_r100_16gpus.py │ │ ├── wf42m_pfc02_r100_32gpus.py │ │ ├── wf42m_pfc03_32gpu_r100.py │ │ ├── wf42m_pfc03_32gpu_r18.py │ │ ├── wf42m_pfc03_32gpu_r200.py │ │ ├── wf42m_pfc03_32gpu_r50.py │ │ ├── wf42m_pfc03_40epoch_64gpu_vit_b.py │ │ ├── wf42m_pfc03_40epoch_64gpu_vit_l.py │ │ ├── wf42m_pfc03_40epoch_64gpu_vit_s.py │ │ ├── wf42m_pfc03_40epoch_64gpu_vit_t.py │ │ ├── wf42m_pfc03_40epoch_8gpu_vit_t.py │ │ ├── wf4m_mbf.py │ │ ├── wf4m_r100.py │ │ └── wf4m_r50.py │ ├── dataset.py │ ├── dist.sh │ ├── docs │ │ ├── eval.md │ │ ├── install.md │ │ ├── install_dali.md │ │ ├── modelzoo.md │ │ ├── prepare_webface42m.md │ │ └── speed_benchmark.md │ ├── eval │ │ ├── __init__.py │ │ └── verification.py │ ├── eval_ijbc.py │ ├── flops.py │ ├── inference.py │ ├── losses.py │ ├── lr_scheduler.py │ ├── onnx_helper.py │ ├── onnx_ijbc.py │ ├── partial_fc.py │ ├── requirement.txt │ ├── run.sh │ ├── torch2onnx.py │ ├── train.py │ └── utils │ │ ├── __init__.py │ │ ├── plot.py │ │ ├── utils_callbacks.py │ │ ├── utils_config.py │ │ ├── utils_distributed_sampler.py │ │ └── utils_logging.py ├── base_model.py ├── bfm.py ├── facerecon_model.py ├── losses.py ├── networks.py └── template_model.py ├── options ├── __init__.py ├── base_options.py ├── test_options.py └── train_options.py ├── process_test_video.py ├── smooth.py ├── test └── test.py /README.md: -------------------------------------------------------------------------------- 1 | # HFA-GP: High-fidelity Facial Avatar Reconstruction from Monocular Video with Generative Priors 2 |

3 | 4 |

5 | 
6 | ## HFA-GP
7 | **HFA-GP** is a framework for reconstructing a high-fidelity facial avatar from a monocular video.
8 | This is the official implementation of *High-fidelity Facial Avatar Reconstruction from Monocular Video with Generative Priors*.
9 | 
10 | ## Abstract
11 | High-fidelity facial avatar reconstruction from a monocular video is a significant research problem in computer graphics and computer vision. Recently, Neural Radiance Field (NeRF) has shown impressive novel view rendering results and has been considered for facial avatar reconstruction. However, the complex facial dynamics and missing 3D information in monocular videos raise significant challenges for faithful facial reconstruction. In this work, we propose a new method for NeRF-based facial avatar reconstruction that utilizes a 3D-aware generative prior. Different from existing works that depend on a conditional deformation field for dynamic modeling, we propose to learn a personalized generative prior, which is formulated as a local and low-dimensional subspace in the latent space of a 3D-GAN. We propose an efficient method to construct the personalized generative prior based on a small set of facial images of a given individual. After learning, it allows for photo-realistic rendering from novel views, and face reenactment can be realized by navigating the latent space. Our proposed method is applicable to different driving signals, including RGB images, 3DMM coefficients, and audio. Compared with existing works, we obtain superior novel view synthesis results and faithful face reenactment performance.
12 | 
13 | ## Preprocessing data
14 | 
15 | ```
16 | python3 ./eg3d-pose-detection/process_test_video.py --input_dir 
17 | ```
18 | 
19 | 
20 | ## Training on monocular video
21 | audio-driven:
22 | ```
23 | python3 ./code/train_audio.py --dataset 'ad_dataset' --person_1 'english_m' --exp_path './code/exps/' --exp_name 'ad-english_w_e'
24 | ```
25 | 
26 | 3DMM-driven:
27 | ```
28 | python3 ./code/train_3dmm.py --dataset 'nerface_dataset' --person_1 'person_3' --exp_path './code/exps/' --exp_name '1-nerface3-3dmm2'
29 | ```
30 | 
31 | RGB-driven:
32 | ```
33 | python3 ./code/train_rgb.py --dataset 'nerface_dataset' --person 'person_3' --exp_path './code/exps/' --exp_name '1-nerface-3-2'
34 | ```
35 | 
36 | ## Performing face reenactment
37 | audio-driven:
38 | ```
39 | python3 ./code/run_recon_video_audio.py --dataset 'ad_dataset' --person_1 'english_w' --demo_name 'english_w' --dataset_type 'val' --model_path './code/exps/ad-english_m/checkpoint/checkpoint.pt' --cat_video
40 | ```
41 | For audio-driven experiments, DeepSpeech is required to extract features from the audio. This step uses code from [AD-NeRF](https://github.com/YudongGuo/AD-NeRF). First, use ffmpeg to extract the audio in WAV format, then extract the DeepSpeech features. The extracted feature file should be named aud.npy.
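As a concrete sketch of this audio preprocessing (the paths below are hypothetical placeholders, and the DeepSpeech feature extractor itself comes from the AD-NeRF repository rather than from this codebase):

```python
# Minimal audio-preprocessing sketch; assumes ffmpeg is installed and the
# hypothetical paths below match your data layout.
import subprocess

video = "./datasets/ad_dataset/english_w/video.mp4"  # hypothetical input video
wav = "./datasets/ad_dataset/english_w/aud.wav"

# Extract a 16 kHz mono WAV track from the video with ffmpeg.
subprocess.run(["ffmpeg", "-y", "-i", video, "-f", "wav", "-ar", "16000", "-ac", "1", wav], check=True)

# Then run the DeepSpeech feature extraction from the AD-NeRF repository
# (its data_util/deepspeech_features utilities) on aud.wav and save the
# resulting feature array as aud.npy next to the processed frames.
```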
42 | 
43 | 3DMM-driven:
44 | ```
45 | python3 ./code/run_recon_video_3dmm.py --dataset 'nerface_dataset' --person_1 'person_3' --demo_name 'nerface3' --dataset_type 'test' --model_path './code/exps/1-nerface3/checkpoint/checkpoint.pt' --cat_video
46 | ```
47 | 
48 | RGB-driven:
49 | ```
50 | python3 ./code/run_recon_video_rgb.py --dataset 'nerface_dataset' --person 'person_3' --demo_name '1-nerface3-2-new' --model_path './code/exps/1-nerface/checkpoint/checkpoint.pt' --cat_video --dataset_type 'test' --suffix '.png' --latent_dim_shape 50
51 | ```
52 | 
53 | ## Citation
54 | Please cite the following paper if you use this repository in your research.
55 | ```
56 | @InProceedings{Bai_2023_CVPR,
57 |     author    = {Bai, Yunpeng and Fan, Yanbo and Wang, Xuan and Zhang, Yong and Sun, Jingxiang and Yuan, Chun and Shan, Ying},
58 |     title     = {High-Fidelity Facial Avatar Reconstruction From Monocular Video With Generative Priors},
59 |     booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
60 |     month     = {June},
61 |     year      = {2023},
62 |     pages     = {4541-4551}
63 | }
64 | ```
65 | 
66 | 
67 | 
--------------------------------------------------------------------------------
/assets/pipeline-2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bbaaii/HFA-GP/aa2c15a61d8ddd182189153914098a2af0edfb0c/assets/pipeline-2.png
--------------------------------------------------------------------------------
/assets/test:
--------------------------------------------------------------------------------
1 | 
2 | 
--------------------------------------------------------------------------------
/code/cam_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import random
3 | import math
4 | 
5 | 
6 | def normalize_vecs(vectors: torch.Tensor) -> torch.Tensor:
7 |     """
8 |     Normalize vector lengths.
9 |     """
10 |     return vectors / (torch.norm(vectors, dim=-1, keepdim=True))
11 | 
12 | def sample_camera_positions(device, n=1, r=1, horizontal_stddev=1, vertical_stddev=1, horizontal_mean=math.pi*0.5, vertical_mean=math.pi*0.5, mode='normal'):
13 |     """
14 |     Samples n random locations along a sphere of radius r. Uses the specified distribution.
15 | Theta is yaw in radians (-pi, pi) 16 | Phi is pitch in radians (0, pi) 17 | """ 18 | 19 | if mode == 'uniform': 20 | theta = (torch.rand((n, 1), device=device) - 0.5) * 2 * horizontal_stddev + horizontal_mean 21 | phi = (torch.rand((n, 1), device=device) - 0.5) * 2 * vertical_stddev + vertical_mean 22 | 23 | elif mode == 'normal' or mode == 'gaussian': 24 | theta = torch.randn((n, 1), device=device) * horizontal_stddev + horizontal_mean 25 | phi = torch.randn((n, 1), device=device) * vertical_stddev + vertical_mean 26 | 27 | elif mode == 'hybrid': 28 | if random.random() < 0.5: 29 | theta = (torch.rand((n, 1), device=device) - 0.5) * 2 * horizontal_stddev * 2 + horizontal_mean 30 | phi = (torch.rand((n, 1), device=device) - 0.5) * 2 * vertical_stddev * 2 + vertical_mean 31 | else: 32 | theta = torch.randn((n, 1), device=device) * horizontal_stddev + horizontal_mean 33 | phi = torch.randn((n, 1), device=device) * vertical_stddev + vertical_mean 34 | 35 | elif mode == 'truncated_gaussian': 36 | theta = truncated_normal_(torch.zeros((n, 1), device=device)) * horizontal_stddev + horizontal_mean 37 | phi = truncated_normal_(torch.zeros((n, 1), device=device)) * vertical_stddev + vertical_mean 38 | 39 | elif mode == 'spherical_uniform': 40 | theta = (torch.rand((n, 1), device=device) - .5) * 2 * horizontal_stddev + horizontal_mean 41 | v_stddev, v_mean = vertical_stddev / math.pi, vertical_mean / math.pi 42 | v = ((torch.rand((n,1), device=device) - .5) * 2 * v_stddev + v_mean) 43 | v = torch.clamp(v, 1e-5, 1 - 1e-5) 44 | phi = torch.arccos(1 - 2 * v) 45 | 46 | else: 47 | # Just use the mean. 48 | theta = torch.ones((n, 1), device=device, dtype=torch.float) * horizontal_mean 49 | phi = torch.ones((n, 1), device=device, dtype=torch.float) * vertical_mean 50 | 51 | phi = torch.clamp(phi, 1e-5, math.pi - 1e-5) 52 | 53 | output_points = torch.zeros((n, 3), device=device) 54 | output_points[:, 0:1] = r*torch.sin(phi) * torch.cos(theta) 55 | output_points[:, 2:3] = r*torch.sin(phi) * torch.sin(theta) 56 | output_points[:, 1:2] = r*torch.cos(phi) 57 | 58 | return output_points, phi, theta 59 | 60 | 61 | 62 | def create_cam2world_matrix(forward_vector, origin, device=None): 63 | """Takes in the direction the camera is pointing and the camera origin and returns a cam2world matrix.""" 64 | 65 | forward_vector = normalize_vecs(forward_vector) 66 | up_vector = torch.tensor([0, 1, 0], dtype=torch.float, device=device).expand_as(forward_vector) 67 | 68 | left_vector = normalize_vecs(torch.cross(up_vector, forward_vector, dim=-1)) 69 | 70 | up_vector = normalize_vecs(torch.cross(forward_vector, left_vector, dim=-1)) 71 | 72 | rotation_matrix = torch.eye(4, device=device).unsqueeze(0).repeat(forward_vector.shape[0], 1, 1) 73 | rotation_matrix[:, :3, :3] = torch.stack((-left_vector, up_vector, -forward_vector), axis=-1) 74 | 75 | translation_matrix = torch.eye(4, device=device).unsqueeze(0).repeat(forward_vector.shape[0], 1, 1) 76 | translation_matrix[:, :3, 3] = origin 77 | 78 | cam2world = translation_matrix @ rotation_matrix 79 | 80 | return cam2world 81 | -------------------------------------------------------------------------------- /code/networks/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /code/pretrained_models/eg3d/test: -------------------------------------------------------------------------------- 1 | 2 | 
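To make the camera helpers in code/cam_utils.py above more concrete, here is a usage sketch (not a file from this repository) of how sample_camera_positions and create_cam2world_matrix are combined into the 25-dimensional camera label (16 flattened cam2world entries plus 9 intrinsics) that the trainers feed to the generator; it mirrors the cam_sampler helper that appears later in code/trainer_rgb.py:

```python
# Usage sketch only: sample one camera on a sphere of radius 2.7 looking at the
# origin, then build the 25-dim conditioning label expected by the generator.
import math
import torch
from cam_utils import sample_camera_positions, create_cam2world_matrix

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Gaussian-distributed yaw/pitch around the frontal view (theta = phi = 0.5*pi).
points, phi, theta = sample_camera_positions(
    device, n=1, r=2.7,
    horizontal_mean=0.5 * math.pi, vertical_mean=0.5 * math.pi,
    horizontal_stddev=0.3, vertical_stddev=0.155, mode='gaussian')

cam2world = create_cam2world_matrix(-points, points, device=device)  # (1, 4, 4)
intrinsics = torch.tensor([4.2647, 0, 0.5, 0, 4.2647, 0.5, 0, 0, 1.0], device=device)
label = torch.cat((cam2world.reshape(1, -1), intrinsics.unsqueeze(0)), dim=-1)  # (1, 25)
```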
-------------------------------------------------------------------------------- /code/test: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /code/train_3dmm.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import torch 4 | from torch.utils import data 5 | from dataset import HeadData_3DMM 6 | import torchvision 7 | import torchvision.transforms as transforms 8 | from trainer_3dmm import Trainer 9 | from torch.utils.tensorboard import SummaryWriter 10 | import torch.distributed as dist 11 | import torch.multiprocessing as mp 12 | 13 | torch.backends.cudnn.enabled = True 14 | torch.backends.cudnn.benchmark = True 15 | 16 | 17 | def data_sampler(dataset, shuffle): 18 | if shuffle: 19 | return data.RandomSampler(dataset) 20 | else: 21 | return data.SequentialSampler(dataset) 22 | 23 | 24 | def sample_data(loader): 25 | while True: 26 | for batch in loader: 27 | yield batch 28 | 29 | 30 | def display_img(idx, img, name, writer, args): 31 | img = img.clamp(-1, 1) 32 | img = ((img - img.min()) / (img.max() - img.min())).data 33 | torchvision.utils.save_image(img, args.exp_path + args.exp_name + '/display/'+str(idx)+name+'.png') 34 | 35 | writer.add_images(tag='%s' % (name), global_step=idx, img_tensor=img) 36 | 37 | def display_bases(imgs, name, args): 38 | for idx in range(len(imgs)): 39 | img = imgs[idx] 40 | img = img.clamp(-1, 1) 41 | img = ((img - img.min()) / (img.max() - img.min())).data 42 | torchvision.utils.save_image(img, args.exp_path + args.exp_name + '/bases/'+str(idx)+name+'.png') 43 | 44 | def write_loss(i,l2_loss_3dmm, l2_loss, lpips_loss, writer): 45 | writer.add_scalar('3dmm_loss', l2_loss_3dmm.item(), i) 46 | writer.add_scalar('l2_loss', l2_loss.item(), i) 47 | writer.add_scalar('lpips_loss', lpips_loss.item(), i) 48 | writer.flush() 49 | 50 | def display_bases(imgs, name, args): 51 | for idx in range(len(imgs)): 52 | img = imgs[idx] 53 | img = img.clamp(-1, 1) 54 | img = ((img - img.min()) / (img.max() - img.min())).data 55 | torchvision.utils.save_image(img, args.exp_path + args.exp_name + '/bases/'+str(idx)+name+'.png') 56 | 57 | def ddp_setup(args, rank, world_size): 58 | os.environ['MASTER_ADDR'] = args.addr 59 | os.environ['MASTER_PORT'] = args.port 60 | 61 | dist.init_process_group("nccl", rank=rank, world_size=world_size) 62 | 63 | 64 | def main(rank, world_size, args): 65 | # init distributed computing 66 | ddp_setup(args, rank, world_size) 67 | torch.cuda.set_device(rank) 68 | device = torch.device("cuda") 69 | 70 | # make logging folder 71 | log_path = os.path.join(args.exp_path, args.exp_name + '/log') 72 | checkpoint_path = os.path.join(args.exp_path, args.exp_name + '/checkpoint') 73 | display_path = os.path.join(args.exp_path, args.exp_name + '/display') 74 | bases_path = os.path.join(args.exp_path, args.exp_name + '/bases') 75 | os.makedirs(log_path, exist_ok=True) 76 | os.makedirs(checkpoint_path, exist_ok=True) 77 | os.makedirs(display_path, exist_ok=True) 78 | os.makedirs(bases_path, exist_ok=True) 79 | writer = SummaryWriter(log_path) 80 | 81 | print('==> preparing dataset') 82 | transform = torchvision.transforms.Compose([ 83 | transforms.Resize(args.size), 84 | transforms.ToTensor(), 85 | transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]) 86 | 87 | dataset = HeadData_3DMM('train', transform, dataset = args.dataset , person = args.person_1) 88 | 
dataset_test = HeadData_3DMM('test', transform, dataset = args.dataset , person = args.person_1 ) 89 | 90 | loader = data.DataLoader( 91 | dataset, 92 | 93 | batch_size=args.batch_size // world_size, 94 | sampler=data.distributed.DistributedSampler(dataset, num_replicas=world_size, rank=rank, shuffle=True), 95 | pin_memory=True, 96 | drop_last=False, 97 | ) 98 | 99 | loader_test = data.DataLoader( 100 | dataset_test, 101 | batch_size=1, 102 | sampler=data.distributed.DistributedSampler(dataset_test, num_replicas=world_size, rank=rank, shuffle=False), 103 | pin_memory=True, 104 | drop_last=False, 105 | ) 106 | 107 | loader = sample_data(loader) 108 | loader_test = sample_data(loader_test) 109 | 110 | print('==> initializing trainer') 111 | # Trainer 112 | trainer = Trainer(args, device, rank) 113 | 114 | 115 | print('==> training') 116 | pbar = range(args.iter) 117 | for idx in pbar: 118 | i = idx + args.start_iter 119 | 120 | 121 | real_image, label, params = next(loader) 122 | real_image = real_image.to(rank, non_blocking=True) 123 | label = label.to(rank, non_blocking=True) 124 | params = params.to(rank, non_blocking=True) 125 | 126 | 127 | # update generator 128 | l2_loss_3dmm, l2_loss, loss_lpips, generated_image = trainer.gen_update(real_image, label, params) 129 | 130 | 131 | if rank == 0: 132 | # write to log 133 | write_loss(idx, l2_loss_3dmm, l2_loss, loss_lpips, writer) 134 | if (i+1) >= args.tune_iter: 135 | # print('begin training nerf') 136 | trainer.tune_generator() 137 | # display 138 | if (i+1) % args.display_freq == 0 and rank == 0: 139 | print("[Iter %d/%d] [3dmm loss: %f] [l2 loss: %f] [lpips loss: %f]" 140 | % (i, args.iter, l2_loss_3dmm.item(), l2_loss.item(), loss_lpips.item())) 141 | 142 | if rank == 0: 143 | real_image_test, label_test, params_test = next(loader_test) 144 | real_image_test = real_image_test.to(rank, non_blocking=True) 145 | label_test = label_test.to(rank, non_blocking=True) 146 | params_test = params_test.to(rank, non_blocking=True) 147 | bases_1 = trainer.sample_bases(person_2 = False) 148 | display_bases(bases_1, 'person_1', args) 149 | 150 | 151 | 152 | img_recon = trainer.sample(real_image_test, label_test, params_test) 153 | display_img(i, real_image_test, 'source', writer, args) 154 | 155 | display_img(i, img_recon, 'recon', writer, args) 156 | writer.flush() 157 | 158 | # save model 159 | if (i+1) % args.save_freq == 0 and rank == 0: 160 | trainer.save(i, checkpoint_path) 161 | 162 | return 163 | 164 | 165 | if __name__ == "__main__": 166 | # training params 167 | parser = argparse.ArgumentParser() 168 | parser.add_argument("--iter", type=int, default=800000) 169 | parser.add_argument("--size", type=int, default=256) 170 | parser.add_argument("--batch_size", type=int, default=1) 171 | parser.add_argument("--dataset", type=str, default='nerface') 172 | parser.add_argument("--person_1", type=str, default='person_2') 173 | parser.add_argument("--run_id", type=str, default='nerface2') 174 | parser.add_argument("--run_id_2", type=str, default=None) 175 | parser.add_argument("--emb_dir", type=str, default='./PTI/embeddings/') 176 | 177 | parser.add_argument("--person_2", type=str, default=None) 178 | 179 | 180 | parser.add_argument("--params_len", type=int, default=76) 181 | parser.add_argument("--d_reg_every", type=int, default=16) 182 | parser.add_argument("--g_reg_every", type=int, default=4) 183 | parser.add_argument("--resume_ckpt", type=str, default='./code/exps/nerface2-v1/checkpoint/874999.pt') 184 | parser.add_argument("--lr", type=float, 
default=3e-4) 185 | parser.add_argument("--old", action='store_true', default=True) 186 | parser.add_argument("--tune", action='store_true', default=False) 187 | parser.add_argument("--init", action='store_true', default=False) 188 | 189 | parser.add_argument("--channel_multiplier", type=int, default=1) 190 | parser.add_argument("--start_iter", type=int, default=0) 191 | parser.add_argument("--display_freq", type=int, default=100) 192 | parser.add_argument("--save_freq", type=int, default=5000) 193 | parser.add_argument("--latent_dim_style", type=int, default=512) 194 | parser.add_argument("--latent_dim_shape", type=int, default=50) 195 | parser.add_argument("--exp_path", type=str, default='./code/exps/') 196 | parser.add_argument("--exp_name", type=str, default='v1') 197 | parser.add_argument("--addr", type=str, default='localhost') 198 | parser.add_argument("--port", type=str, default='12345') 199 | parser.add_argument("--tune_iter", type=int, default=50000) 200 | 201 | 202 | opts = parser.parse_args() 203 | 204 | n_gpus = torch.cuda.device_count() 205 | print('==> training on %d gpus' % n_gpus) 206 | world_size = n_gpus 207 | if world_size == 1: 208 | main(rank=0, world_size = world_size, args=opts) 209 | elif world_size > 1: 210 | mp.spawn(main, args=(world_size, opts,), nprocs=world_size, join=True) 211 | -------------------------------------------------------------------------------- /code/train_rgb.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import torch 4 | from torch.utils import data 5 | from dataset import HeadData 6 | import torchvision 7 | import torchvision.transforms as transforms 8 | from trainer_rgb import Trainer 9 | from torch.utils.tensorboard import SummaryWriter 10 | import torch.distributed as dist 11 | import torch.multiprocessing as mp 12 | 13 | torch.backends.cudnn.enabled = True 14 | torch.backends.cudnn.benchmark = True 15 | 16 | 17 | def data_sampler(dataset, shuffle): 18 | if shuffle: 19 | return data.RandomSampler(dataset) 20 | else: 21 | return data.SequentialSampler(dataset) 22 | 23 | 24 | def sample_data(loader): 25 | while True: 26 | for batch in loader: 27 | yield batch 28 | 29 | 30 | def display_img(idx, img, name, writer, args): 31 | img = img.clamp(-1, 1) 32 | img = ((img - img.min()) / (img.max() - img.min())).data 33 | torchvision.utils.save_image(img, args.exp_path + args.exp_name + '/display/'+str(idx)+name+'.png') 34 | 35 | writer.add_images(tag='%s' % (name), global_step=idx, img_tensor=img) 36 | 37 | 38 | def display_bases(imgs, name, args): 39 | for idx in range(len(imgs)): 40 | img = imgs[idx] 41 | img = img.clamp(-1, 1) 42 | img = ((img - img.min()) / (img.max() - img.min())).data 43 | torchvision.utils.save_image(img, args.exp_path + args.exp_name + '/bases/'+str(idx)+name+'.png') 44 | 45 | 46 | 47 | def write_loss(i, l2_loss, lpips_loss, writer): 48 | writer.add_scalar('l2_loss', l2_loss.item(), i) 49 | writer.add_scalar('lpips_loss', lpips_loss.item(), i) 50 | writer.flush() 51 | 52 | 53 | def ddp_setup(args, rank, world_size): 54 | os.environ['MASTER_ADDR'] = args.addr 55 | os.environ['MASTER_PORT'] = args.port 56 | 57 | dist.init_process_group("nccl", rank=rank, world_size=world_size) 58 | 59 | 60 | def main(rank, world_size, args): 61 | # init distributed computing 62 | ddp_setup(args, rank, world_size) 63 | torch.cuda.set_device(rank) 64 | device = torch.device("cuda") 65 | 66 | # make logging folder 67 | log_path = os.path.join(args.exp_path, args.exp_name 
+ '/log') 68 | checkpoint_path = os.path.join(args.exp_path, args.exp_name + '/checkpoint') 69 | display_path = os.path.join(args.exp_path, args.exp_name + '/display') 70 | bases_path = os.path.join(args.exp_path, args.exp_name + '/bases') 71 | os.makedirs(log_path, exist_ok=True) 72 | os.makedirs(checkpoint_path, exist_ok=True) 73 | os.makedirs(display_path, exist_ok=True) 74 | os.makedirs(bases_path, exist_ok=True) 75 | writer = SummaryWriter(log_path) 76 | 77 | print('==> preparing dataset') 78 | transform = torchvision.transforms.Compose([ 79 | transforms.Resize(args.size), 80 | transforms.ToTensor(), 81 | transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]) 82 | dataset = HeadData('train', transform, dataset = args.dataset , person = args.person) 83 | dataset_test = HeadData('test', transform, dataset = args.dataset , person = args.person ) 84 | 85 | loader = data.DataLoader( 86 | dataset, 87 | 88 | batch_size=args.batch_size // world_size, 89 | sampler=data.distributed.DistributedSampler(dataset, num_replicas=world_size, rank=rank, shuffle=True), 90 | pin_memory=True, 91 | drop_last=False, 92 | ) 93 | 94 | loader_test = data.DataLoader( 95 | dataset_test, 96 | batch_size=1, 97 | sampler=data.distributed.DistributedSampler(dataset_test, num_replicas=world_size, rank=rank, shuffle=False), 98 | pin_memory=True, 99 | drop_last=False, 100 | ) 101 | 102 | loader = sample_data(loader) 103 | loader_test = sample_data(loader_test) 104 | print('==> initializing trainer') 105 | # Trainer 106 | trainer = Trainer(args, device, rank) 107 | 108 | # resume 109 | if args.resume_ckpt is not None: 110 | args.start_iter = trainer.resume(args.resume_ckpt) 111 | print('==> resume from iteration %d' % (args.start_iter)) 112 | 113 | print('==> training') 114 | pbar = range(args.iter) 115 | for idx in pbar: 116 | i = idx + args.start_iter 117 | 118 | # loading data 119 | real_image, label = next(loader) 120 | real_image = real_image.to(rank, non_blocking=True) 121 | label = label.to(rank, non_blocking=True) 122 | 123 | 124 | # update generator 125 | l2_loss, loss_lpips, generated_image = trainer.gen_update(real_image, label, person_2 = False) 126 | 127 | 128 | if rank == 0: 129 | # write to log 130 | write_loss(idx, l2_loss, loss_lpips, writer) 131 | 132 | if (i+1) >= args.tune_iter: 133 | # print('begin training nerf') 134 | trainer.tune_generator() 135 | # display 136 | if (i+1) % args.display_freq == 0 and rank == 0: 137 | print("[Iter %d/%d] [l2 loss: %f] [lpips loss: %f]" 138 | % (i, args.iter, l2_loss.item(), loss_lpips.item())) 139 | 140 | if rank == 0: 141 | real_image_test, label_test = next(loader_test) 142 | real_image_test = real_image_test.to(rank, non_blocking=True) 143 | label_test = label_test.to(rank, non_blocking=True) 144 | 145 | img_recon = trainer.sample(real_image_test, label_test) 146 | bases_1 = trainer.sample_bases(person_2 = False) 147 | display_bases(bases_1, 'person_1', args) 148 | display_img(i, real_image_test, 'source', writer, args) 149 | display_img(i, img_recon, 'recon', writer, args) 150 | writer.flush() 151 | 152 | # save model 153 | if (i+1) % args.save_freq == 0 and rank == 0: 154 | trainer.save(i, checkpoint_path) 155 | 156 | return 157 | 158 | 159 | if __name__ == "__main__": 160 | # training params 161 | parser = argparse.ArgumentParser() 162 | parser.add_argument("--iter", type=int, default=800000) 163 | parser.add_argument("--size", type=int, default=256) 164 | parser.add_argument("--batch_size", type=int, default=2) 165 | parser.add_argument("--dataset", 
type=str, default='nerface_dataset') 166 | parser.add_argument("--person", type=str, default='person_3') 167 | parser.add_argument("--person_2", type=str, default=None) 168 | parser.add_argument("--run_id", type=str, default='nerface2') 169 | parser.add_argument("--run_id_2", type=str, default=None) 170 | parser.add_argument("--emb_dir", type=str, default='./PTI/embeddings/') 171 | 172 | parser.add_argument("--d_reg_every", type=int, default=16) 173 | parser.add_argument("--g_reg_every", type=int, default=4) 174 | parser.add_argument("--resume_ckpt", type=str, default=None) 175 | parser.add_argument("--old", action='store_true', default=True) 176 | parser.add_argument("--tune", action='store_true', default=True) 177 | parser.add_argument("--init", action='store_true', default=False) 178 | parser.add_argument("--same_bases", action='store_true', default=False) 179 | parser.add_argument("--out_pose", action='store_true', default=False) 180 | parser.add_argument("--lr", type=float, default=3e-4) 181 | 182 | 183 | parser.add_argument("--channel_multiplier", type=int, default=1) 184 | parser.add_argument("--start_iter", type=int, default=0) 185 | parser.add_argument("--display_freq", type=int, default=5000) 186 | parser.add_argument("--save_freq", type=int, default=5000) 187 | parser.add_argument("--latent_dim_style", type=int, default=512) 188 | parser.add_argument("--latent_dim_shape", type=int, default=50) 189 | parser.add_argument("--exp_path", type=str, default='./code/exps/') 190 | parser.add_argument("--exp_name", type=str, default='v1') 191 | parser.add_argument("--addr", type=str, default='localhost') 192 | parser.add_argument("--port", type=str, default='12345') 193 | parser.add_argument("--tune_iter", type=int, default=50000) 194 | opts = parser.parse_args() 195 | 196 | n_gpus = torch.cuda.device_count() 197 | print('==> training on %d gpus' % n_gpus) 198 | world_size = n_gpus 199 | if world_size == 1: 200 | main(rank=0, world_size = world_size, args=opts) 201 | elif world_size > 1: 202 | mp.spawn(main, args=(world_size, opts,), nprocs=world_size, join=True) 203 | -------------------------------------------------------------------------------- /code/trainer_3dmm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from networks.headnerf import HeadNeRF_3DMM 3 | import torch.nn.functional as F 4 | from torch import nn, optim 5 | import os 6 | from torch.nn.parallel import DistributedDataParallel as DDP 7 | 8 | from lpips import LPIPS 9 | def requires_grad(net, flag=True): 10 | for p in net.parameters(): 11 | p.requires_grad = flag 12 | 13 | l2_criterion = torch.nn.MSELoss(reduction='mean') 14 | 15 | from cam_utils import sample_camera_positions, create_cam2world_matrix 16 | import math 17 | import torchvision.transforms as transforms 18 | 19 | class Trainer(nn.Module): 20 | def __init__(self, args, device, rank): 21 | super(Trainer, self).__init__() 22 | 23 | self.args = args 24 | self.batch_size = args.batch_size 25 | 26 | self.gen = HeadNeRF_3DMM(args, args.size, device, args.latent_dim_style, args.latent_dim_shape, args.run_id, args.emb_dir).to( 27 | device) 28 | 29 | self.gen = DDP(self.gen, device_ids=[rank], find_unused_parameters=True) 30 | self.w_optim = torch.optim.Adam(self.gen.parameters(), lr= args.lr) 31 | for param in self.gen.module.generator.parameters(): 32 | param.requires_grad = False 33 | self.lpips_loss = LPIPS(net='alex').to(device).eval() 34 | self.device = device 35 | self.face_pool = 
torch.nn.AdaptiveAvgPool2d((args.size, args.size)) 36 | 37 | def l2_loss(self, real_images, generated_images): 38 | loss = l2_criterion(real_images, generated_images) 39 | return loss 40 | def tune_generator(self): 41 | for param in self.gen.module.generator.parameters(): 42 | param.requires_grad = True 43 | def gen_update(self, real_image, label, params, person_2 = False): 44 | self.gen.train() 45 | self.w_optim.zero_grad() 46 | 47 | 48 | 49 | 50 | generated_image = self.gen(params, label, person_2) 51 | generated_image = self.face_pool(generated_image) 52 | 53 | l2_loss_3dmm = torch.zeros(1).to(self.device) #self.l2_loss(weights, generated_weights) 54 | l2_loss = self.l2_loss(real_image, generated_image) 55 | loss_lpips = self.lpips_loss( real_image, generated_image) 56 | loss_lpips = torch.squeeze(loss_lpips).mean() 57 | 58 | 59 | g_loss = l2_loss_3dmm + l2_loss + loss_lpips 60 | 61 | 62 | g_loss.backward() 63 | 64 | 65 | self.w_optim.step() 66 | 67 | return l2_loss_3dmm, l2_loss, loss_lpips, generated_image 68 | 69 | 70 | def sample(self, real_image, label, params, person_2 = False): 71 | with torch.no_grad(): 72 | self.gen.eval() 73 | 74 | img_recon = self.gen(params, label, person_2) 75 | 76 | 77 | return img_recon#, img_source_ref 78 | def sample_bases(self, person_2 = False ): 79 | img_recons = [] 80 | with torch.no_grad(): 81 | r = 2.7 82 | points, _, _ = sample_camera_positions(device=self.device, n=1, r=r, horizontal_mean=0.5*math.pi, vertical_mean=0.5*math.pi, mode=None) 83 | label = create_cam2world_matrix(-points, points, device=self.device) 84 | label = label.reshape(1, -1) 85 | label = torch.cat((label, torch.tensor([4.2647, 0, 0.5, 0, 4.2647, 0.5, 0, 0, 1]).reshape(1, -1).repeat(1, 1).to(label)), -1) 86 | self.gen.eval() 87 | 88 | for base_id in range(self.args.latent_dim_shape): 89 | weights = torch.zeros(self.args.latent_dim_shape).to(self.device) 90 | weights[base_id] = 5 91 | weights = weights.unsqueeze(0) 92 | 93 | 94 | latent = self.gen.module.get_latent(weights,person_2) 95 | img_recon = self.gen.module.get_image(latent, label) 96 | img_recons.append(img_recon) 97 | 98 | return img_recons#, img_source_ref 99 | 100 | 101 | def resume(self, resume_ckpt): 102 | print("load model:", resume_ckpt) 103 | ckpt = torch.load(resume_ckpt) 104 | ckpt_name = os.path.basename(resume_ckpt) 105 | start_iter = int(os.path.splitext(ckpt_name)[0]) 106 | 107 | 108 | self.gen.module.load_state_dict(ckpt["gen"]) 109 | 110 | self.w_optim.load_state_dict(ckpt["w_optim"]) 111 | 112 | return start_iter 113 | 114 | def save(self, idx, checkpoint_path): 115 | torch.save( 116 | { 117 | "gen": self.gen.module.state_dict(), 118 | "w_optim": self.w_optim.state_dict(), 119 | "args": self.args 120 | }, 121 | f"{checkpoint_path}/{str(idx).zfill(6)}.pt" 122 | ) 123 | -------------------------------------------------------------------------------- /code/trainer_rgb.py: -------------------------------------------------------------------------------- 1 | import torch 2 | # from networks.discriminator import Discriminator 3 | # from networks.generator import Generator 4 | from networks.headnerf import HeadNeRF_final 5 | import torch.nn.functional as F 6 | from torch import nn, optim 7 | import os 8 | from torch.nn.parallel import DistributedDataParallel as DDP 9 | 10 | from lpips import LPIPS 11 | def requires_grad(net, flag=True): 12 | for p in net.parameters(): 13 | p.requires_grad = flag 14 | 15 | l2_criterion = torch.nn.MSELoss(reduction='mean') 16 | 17 | from cam_utils import sample_camera_positions, 
create_cam2world_matrix 18 | import math 19 | 20 | import torchvision.transforms as transforms 21 | 22 | 23 | 24 | 25 | 26 | 27 | def cam_sampler(batch, device): 28 | camera_points, phi, theta = sample_camera_positions(device, n=batch, r=2.7, horizontal_mean=0.5*math.pi, 29 | vertical_mean=0.5*math.pi, horizontal_stddev=0.3, vertical_stddev=0.155, mode='gaussian') 30 | c = create_cam2world_matrix(-camera_points, camera_points, device=device) 31 | c = c.reshape(batch, -1) 32 | c = torch.cat((c, torch.tensor([4.2647, 0, 0.5, 0, 4.2647, 0.5, 0, 0, 1]).reshape(1, -1).repeat(batch, 1).to(c)), -1) 33 | return c 34 | 35 | 36 | def cam_sampler_pose(batch, horizontal_mean, vertical_mean , device): 37 | camera_points, phi, theta = sample_camera_positions(device, n=batch, r=2.7, horizontal_mean=horizontal_mean*math.pi, 38 | vertical_mean=vertical_mean*math.pi, horizontal_stddev=0.15, vertical_stddev=0.155, mode='gaussian') 39 | c = create_cam2world_matrix(-camera_points, camera_points, device=device) 40 | c = c.reshape(batch, -1) 41 | c = torch.cat((c, torch.tensor([4.2647, 0, 0.5, 0, 4.2647, 0.5, 0, 0, 1]).reshape(1, -1).repeat(batch, 1).to(c)), -1) 42 | return c 43 | 44 | 45 | 46 | class Trainer(nn.Module): 47 | def __init__(self, args, device, rank): 48 | super(Trainer, self).__init__() 49 | 50 | self.args = args 51 | self.batch_size = args.batch_size 52 | self.device = device 53 | 54 | self.gen = HeadNeRF_final(args, args.size, device, args.latent_dim_style, args.latent_dim_shape, args.run_id, args.emb_dir).to( 55 | device) 56 | self.gen = DDP(self.gen, device_ids=[rank], broadcast_buffers=False, find_unused_parameters=True) 57 | 58 | self.g_optim = torch.optim.Adam(self.gen.parameters(), lr= args.lr) 59 | for param in self.gen.module.generator.parameters(): 60 | param.requires_grad = False 61 | 62 | self.lpips_loss = LPIPS(net='alex').to(device).eval() 63 | self.face_pool = torch.nn.AdaptiveAvgPool2d((args.size, args.size)) 64 | 65 | def l2_loss(self, real_images, generated_images): 66 | loss = l2_criterion(real_images, generated_images) 67 | return loss 68 | 69 | def tune_generator(self): 70 | for param in self.gen.module.generator.parameters(): 71 | param.requires_grad = True 72 | 73 | def gen_update(self, real_image, label, person_2 = False, mask = None): 74 | self.gen.train() 75 | self.g_optim.zero_grad() 76 | 77 | weights_i = self.gen.module.get_weights(real_image) 78 | latent_i = self.gen.module.get_latent(weights_i, person_2) 79 | generated_image = self.gen.module.get_image(latent_i, label) 80 | 81 | 82 | 83 | 84 | generated_image = self.face_pool(generated_image) 85 | l2_loss = self.l2_loss(real_image, generated_image) 86 | loss_lpips = self.lpips_loss( real_image, generated_image) 87 | loss_lpips = torch.squeeze(loss_lpips).mean() 88 | 89 | 90 | 91 | g_loss = (l2_loss + loss_lpips ) 92 | 93 | g_loss.backward() 94 | 95 | 96 | self.g_optim.step() 97 | 98 | return l2_loss, loss_lpips, generated_image 99 | 100 | def sample(self, real_image, label, person_2 = False ): 101 | with torch.no_grad(): 102 | self.gen.eval() 103 | img_recon = self.gen(real_image, label, person_2) 104 | 105 | 106 | return img_recon 107 | 108 | def sample_bases(self, person_2 = False ): 109 | img_recons = [] 110 | with torch.no_grad(): 111 | r = 2.7 112 | points, _, _ = sample_camera_positions(device=self.device, n=1, r=r, horizontal_mean=0.5*math.pi, vertical_mean=0.5*math.pi, mode=None) 113 | label = create_cam2world_matrix(-points, points, device=self.device) 114 | label = label.reshape(1, -1) 115 | label = 
torch.cat((label, torch.tensor([4.2647, 0, 0.5, 0, 4.2647, 0.5, 0, 0, 1]).reshape(1, -1).repeat(1, 1).to(label)), -1) 116 | self.gen.eval() 117 | 118 | for base_id in range(self.args.latent_dim_shape): 119 | weights = torch.zeros(self.args.latent_dim_shape).to(self.device) 120 | weights[base_id] = 10 121 | weights = weights.unsqueeze(0) 122 | 123 | latent = self.gen.module.get_latent(weights,person_2) 124 | img_recon = self.gen.module.get_image(latent, label) 125 | img_recons.append(img_recon) 126 | 127 | return img_recons 128 | 129 | 130 | def resume(self, resume_ckpt): 131 | print("load model:", resume_ckpt) 132 | ckpt = torch.load(resume_ckpt) 133 | ckpt_name = os.path.basename(resume_ckpt) 134 | start_iter = int(os.path.splitext(ckpt_name)[0]) 135 | 136 | self.gen.module.load_state_dict(ckpt["gen"]) 137 | 138 | self.g_optim.load_state_dict(ckpt["g_optim"]) 139 | 140 | 141 | return start_iter 142 | 143 | def save(self, idx, checkpoint_path): 144 | torch.save( 145 | { 146 | "gen": self.gen.module.state_dict(), 147 | "g_optim": self.g_optim.state_dict(), 148 | "args": self.args 149 | }, 150 | f"{checkpoint_path}/{str(idx).zfill(6)}.pt" 151 | ) 152 | -------------------------------------------------------------------------------- /eg3d-pose-detection/3dface2idr.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import torch 4 | import json 5 | import argparse 6 | 7 | if __name__ == '__main__': 8 | parser = argparse.ArgumentParser() 9 | parser.add_argument('--in_root', type=str, default="", help='process folder') 10 | parser.add_argument('--out_root', type=str, default="output", help='output folder') 11 | args = parser.parse_args() 12 | in_root = args.in_root 13 | 14 | def compute_rotation(angles): 15 | """ 16 | Return: 17 | rot -- torch.tensor, size (B, 3, 3) pts @ trans_mat 18 | 19 | Parameters: 20 | angles -- torch.tensor, size (B, 3), radian 21 | """ 22 | 23 | batch_size = angles.shape[0] 24 | ones = torch.ones([batch_size, 1]) 25 | zeros = torch.zeros([batch_size, 1]) 26 | x, y, z = angles[:, :1], angles[:, 1:2], angles[:, 2:], 27 | 28 | rot_x = torch.cat([ 29 | ones, zeros, zeros, 30 | zeros, torch.cos(x), -torch.sin(x), 31 | zeros, torch.sin(x), torch.cos(x) 32 | ], dim=1).reshape([batch_size, 3, 3]) 33 | 34 | rot_y = torch.cat([ 35 | torch.cos(y), zeros, torch.sin(y), 36 | zeros, ones, zeros, 37 | -torch.sin(y), zeros, torch.cos(y) 38 | ], dim=1).reshape([batch_size, 3, 3]) 39 | 40 | rot_z = torch.cat([ 41 | torch.cos(z), -torch.sin(z), zeros, 42 | torch.sin(z), torch.cos(z), zeros, 43 | zeros, zeros, ones 44 | ], dim=1).reshape([batch_size, 3, 3]) 45 | 46 | rot = rot_z @ rot_y @ rot_x 47 | return rot.permute(0, 2, 1)[0] 48 | 49 | npys = sorted([x for x in os.listdir(in_root) if x.endswith(".npy")]) 50 | 51 | mode = 1 #1 = IDR, 2 = LSX 52 | outAll={} 53 | 54 | for src_filename in npys: 55 | src = os.path.join(in_root, src_filename) 56 | 57 | # print(src) 58 | dict_load=np.load(src, allow_pickle=True) 59 | 60 | angle = dict_load.item()['angle'] 61 | trans = dict_load.item()['trans'][0] 62 | R = compute_rotation(torch.from_numpy(angle)).numpy() 63 | trans[2] += -10 64 | c = -np.dot(R, trans) 65 | pose = np.eye(4) 66 | pose[:3, :3] = R 67 | 68 | c *= 0.27 # factor to match tripleganger 69 | c[1] += 0.006 # offset to align to tripleganger 70 | c[2] += 0.161 # offset to align to tripleganger 71 | pose[0,3] = c[0] 72 | pose[1,3] = c[1] 73 | pose[2,3] = c[2] 74 | 75 | focal = 2985.29 # = 1015*1024/224*(300/466.285)# 
76 | pp = 512#112 77 | w = 1024#224 78 | h = 1024#224 79 | 80 | if mode==1: 81 | count = 0 82 | K = np.eye(3) 83 | K[0][0] = focal 84 | K[1][1] = focal 85 | K[0][2] = w/2.0 86 | K[1][2] = h/2.0 87 | K = K.tolist() 88 | 89 | Rot = np.eye(3) 90 | Rot[0, 0] = 1 91 | Rot[1, 1] = -1 92 | Rot[2, 2] = -1 93 | pose[:3, :3] = np.dot(pose[:3, :3], Rot) 94 | 95 | pose = pose.tolist() 96 | out = {} 97 | out["intrinsics"] = K 98 | out["pose"] = pose 99 | out["angle"] = (angle * [1, -1, 1]).flatten().tolist() 100 | outAll[src_filename.replace(".npy", ".png")] = out 101 | 102 | elif mode==2: 103 | 104 | dst = os.path.join(in_root, src_filename.replace(".npy", "_lscam.txt")) 105 | outCam = open(dst, "w") 106 | outCam.write("#focal length\n") 107 | outCam.write(str(focal) + " " + str(focal) + "\n") 108 | 109 | outCam.write("#principal point\n") 110 | outCam.write(str(pp) + " " + str(pp) + "\n") 111 | 112 | outCam.write("#resolution\n") 113 | outCam.write(str(w) + " " + str(h) + "\n") 114 | 115 | outCam.write("#distortion coeffs\n") 116 | outCam.write("0 0 0 0\n") 117 | 118 | 119 | outCam.write("MATRIX :\n") 120 | for r in range(4): 121 | outCam.write(str(pose[r, 0]) + " " + str(pose[r, 1]) + " " + str(pose[r, 2]) + " " + str(pose[r, 3]) + "\n") 122 | 123 | outCam.close() 124 | 125 | print("mode:", mode) 126 | print("out dir:", args.out_root) 127 | if mode == 1: 128 | dst = os.path.join(args.out_root, "cameras.json") 129 | with open(dst, "w") as outfile: 130 | json.dump(outAll, outfile, indent=4) 131 | -------------------------------------------------------------------------------- /eg3d-pose-detection/batch_mtcnn.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import cv2 3 | import os 4 | from mtcnn import MTCNN 5 | import random 6 | from tqdm import tqdm 7 | import numpy as np 8 | 9 | detector = MTCNN() 10 | parser = argparse.ArgumentParser() 11 | parser.add_argument('--in_root', type=str, default="", help='process folder') 12 | args = parser.parse_args() 13 | in_root = args.in_root 14 | 15 | out_root = os.path.join(in_root, "debug") 16 | out_detection = os.path.join(in_root, "detections") 17 | if not os.path.exists(out_root): 18 | os.makedirs(out_root) 19 | if not os.path.exists(out_detection): 20 | os.makedirs(out_detection) 21 | 22 | imgs = sorted([x for x in os.listdir(in_root) if x.endswith(".jpg") or x.endswith(".png")]) 23 | random.shuffle(imgs) 24 | for img in tqdm(imgs): 25 | src = os.path.join(in_root, img) 26 | dst = os.path.join(out_detection, img.replace(".jpg", ".txt").replace(".png", ".txt")) 27 | 28 | if not os.path.exists(dst): 29 | image = cv2.cvtColor(cv2.imread(src), cv2.COLOR_BGR2RGB) 30 | result = detector.detect_faces(image) 31 | 32 | if len(result)>0: 33 | index = 0 34 | if len(result)>1: # if multiple faces, take the biggest face 35 | # size = -100000 36 | lowest_dist = float('Inf') 37 | for r in range(len(result)): 38 | # print(result[r]["box"][0], result[r]["box"][1]) 39 | face_pos = np.array(result[r]["box"][:2]) + np.array(result[r]["box"][2:])/2 40 | 41 | dist_from_center = np.linalg.norm(face_pos - np.array([1500./2, 1500./2])) 42 | if dist_from_center < lowest_dist: 43 | lowest_dist = dist_from_center 44 | index=r 45 | 46 | 47 | # size_ = result[r]["box"][2] + result[r]["box"][3] 48 | # if size < size_: 49 | # size = size_ 50 | # index = r 51 | 52 | # Result is an array with all the bounding boxes detected. We know that for 'ivan.jpg' there is only one. 
53 | bounding_box = result[index]['box'] 54 | keypoints = result[index]['keypoints'] 55 | if result[index]["confidence"] > 0.9: 56 | 57 | cv2.rectangle(image, 58 | (bounding_box[0], bounding_box[1]), 59 | (bounding_box[0]+bounding_box[2], bounding_box[1] + bounding_box[3]), 60 | (0,155,255), 61 | 2) 62 | 63 | cv2.circle(image,(keypoints['left_eye']), 2, (0,155,255), 2) 64 | cv2.circle(image,(keypoints['right_eye']), 2, (0,155,255), 2) 65 | cv2.circle(image,(keypoints['nose']), 2, (0,155,255), 2) 66 | cv2.circle(image,(keypoints['mouth_left']), 2, (0,155,255), 2) 67 | cv2.circle(image,(keypoints['mouth_right']), 2, (0,155,255), 2) 68 | 69 | dst = os.path.join(out_root, img) 70 | # cv2.imwrite(dst, cv2.cvtColor(image, cv2.COLOR_RGB2BGR)) 71 | 72 | dst = os.path.join(out_detection, img.replace(".jpg", ".txt").replace(".png", ".txt")) 73 | outLand = open(dst, "w") 74 | outLand.write(str(float(keypoints['left_eye'][0])) + " " + str(float(keypoints['left_eye'][1])) + "\n") 75 | outLand.write(str(float(keypoints['right_eye'][0])) + " " + str(float(keypoints['right_eye'][1])) + "\n") 76 | outLand.write(str(float(keypoints['nose'][0])) + " " + str(float(keypoints['nose'][1])) + "\n") 77 | outLand.write(str(float(keypoints['mouth_left'][0])) + " " + str(float(keypoints['mouth_left'][1])) + "\n") 78 | outLand.write(str(float(keypoints['mouth_right'][0])) + " " + str(float(keypoints['mouth_right'][1])) + "\n") 79 | outLand.close() -------------------------------------------------------------------------------- /eg3d-pose-detection/camera2label.py: -------------------------------------------------------------------------------- 1 | import json 2 | import numpy as np 3 | import argparse 4 | import os 5 | """ 6 | convert cameras.json into the format of eg3d label 7 | """ 8 | # fname = '/apdcephfs_cq2/share_1290939/kitbai/PTI/data/nerface/1/cropped_images/cameras.json' 9 | parser = argparse.ArgumentParser() 10 | parser.add_argument('--in_root', type=str, default="", help='process folder') 11 | args = parser.parse_args() 12 | in_root = args.in_root 13 | 14 | fname = os.path.join(in_root, "cropped_images/cameras.json") 15 | # fname = '/apdcephfs_cq2/share_1290939/kitbai/LIA-3d/datasets/ad_dataset/french/train/cropped_images/cameras.json' 16 | # fname = '/apdcephfs_cq2/share_1290939/kitbai/celeba_sub/images1/cropped_images/cameras.json' 17 | with open(fname, 'rb') as f: 18 | labels = json.load(f) 19 | 20 | results_new = [] 21 | for ind in labels.keys(): 22 | pose = np.array(labels[ind]["pose"]).reshape(16) 23 | pose = list(pose) + list([4.2647, 0, 0.5, 0, 4.2647, 0.5, 0, 0, 1]) 24 | results_new.append((ind, pose)) 25 | 26 | # with open("/apdcephfs_cq2/share_1290939/kitbai/PTI/data/nerface/1/cropped_images/test.json", 'w') as outfile: 27 | # with open("/apdcephfs_cq2/share_1290939/kitbai/LIA-3d/datasets/ad_dataset/french/train/cropped_images/test.json", 'w') as outfile: 28 | # json.dump({"labels": results_new}, outfile, indent="\t") 29 | with open(os.path.join(in_root, "cropped_images/test.json"), 'w') as outfile: 30 | json.dump({"labels": results_new}, outfile, indent="\t") 31 | -------------------------------------------------------------------------------- /eg3d-pose-detection/crop_images.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import json 4 | 5 | import numpy as np 6 | from PIL import Image 7 | from tqdm import tqdm 8 | 9 | # calculating least square problem for image alignment 10 | def POS(xp, x): 11 | 12 | npts = 
xp.shape[1] 13 | 14 | A = np.zeros([2*npts, 8]) 15 | 16 | A[0:2*npts-1:2, 0:3] = x.transpose() 17 | A[0:2*npts-1:2, 3] = 1 18 | A[1:2*npts:2, 4:7] = x.transpose() 19 | A[1:2*npts:2, 7] = 1 20 | 21 | b = np.reshape(xp.transpose(), [2*npts, 1]) 22 | 23 | k, _, _, _ = np.linalg.lstsq(A, b) 24 | 25 | R1 = k[0:3] 26 | R2 = k[4:7] 27 | sTx = k[3] 28 | sTy = k[7] 29 | s = (np.linalg.norm(R1) + np.linalg.norm(R2))/2 30 | 31 | t = np.stack([sTx, sTy], axis=0) 32 | 33 | return t, s 34 | 35 | def extract_5p(lm): 36 | lm_idx = np.array([31, 37, 40, 43, 46, 49, 55]) - 1 37 | lm5p = np.stack([lm[lm_idx[0], :], np.mean(lm[lm_idx[[1, 2]], :], 0), np.mean( 38 | lm[lm_idx[[3, 4]], :], 0), lm[lm_idx[5], :], lm[lm_idx[6], :]], axis=0) 39 | lm5p = lm5p[[1, 2, 0, 3, 4], :] 40 | return lm5p 41 | 42 | # resize and crop images for face reconstruction 43 | def resize_n_crop_img(img, lm, t, s, target_size=1024., mask=None): 44 | w0, h0 = img.size 45 | w = (w0*s).astype(np.int32) 46 | h = (h0*s).astype(np.int32) 47 | left = (w/2 - target_size/2 + float((t[0] - w0/2)*s)).astype(np.int32) 48 | right = left + target_size 49 | up = (h/2 - target_size/2 + float((h0/2 - t[1])*s)).astype(np.int32) 50 | below = up + target_size 51 | img = img.resize((w, h), resample=Image.LANCZOS) 52 | img = img.crop((left, up, right, below)) 53 | 54 | if mask is not None: 55 | mask = mask.resize((w, h), resample=Image.LANCZOS) 56 | mask = mask.crop((left, up, right, below)) 57 | 58 | lm = np.stack([lm[:, 0] - t[0] + w0/2, lm[:, 1] - 59 | t[1] + h0/2], axis=1)*s 60 | lm = lm - np.reshape( 61 | np.array([(w/2 - target_size/2), (h/2-target_size/2)]), [1, 2]) 62 | return img, lm, mask 63 | 64 | 65 | # utils for face reconstruction 66 | def align_img(img, lm, lm3D, mask=None, target_size=1024., rescale_factor=466.285): 67 | """ 68 | Return: 69 | transparams --numpy.array (raw_W, raw_H, scale, tx, ty) 70 | img_new --PIL.Image (target_size, target_size, 3) 71 | lm_new --numpy.array (68, 2), y direction is opposite to v direction 72 | mask_new --PIL.Image (target_size, target_size) 73 | 74 | Parameters: 75 | img --PIL.Image (raw_H, raw_W, 3) 76 | lm --numpy.array (68, 2), y direction is opposite to v direction 77 | lm3D --numpy.array (5, 3) 78 | mask --PIL.Image (raw_H, raw_W, 3) 79 | """ 80 | 81 | w0, h0 = img.size 82 | if lm.shape[0] != 5: 83 | lm5p = extract_5p(lm) 84 | else: 85 | lm5p = lm 86 | 87 | # calculate translation and scale factors using 5 facial landmarks and standard landmarks of a 3D face 88 | t, s = POS(lm5p.transpose(), lm3D.transpose()) 89 | s = rescale_factor/s 90 | 91 | # processing the image 92 | img_new, lm_new, mask_new = resize_n_crop_img(img, lm, t, s, target_size=target_size, mask=mask) 93 | # img.save("/home/koki/Projects/Deep3DFaceRecon_pytorch/checkpoints/pretrained/results/iphone/epoch_20_000000/img_new.jpg") 94 | trans_params = np.array([w0, h0, s, t[0], t[1]]) 95 | lm_new *= 224/1024.0 96 | img_new_low = img_new.resize((224, 224), resample=Image.LANCZOS) 97 | 98 | return trans_params, img_new_low, lm_new, mask_new, img_new 99 | 100 | 101 | if __name__ == '__main__': 102 | parser = argparse.ArgumentParser() 103 | parser.add_argument('--indir', type=str, required=True) 104 | parser.add_argument('--outdir', type=str, required=True) 105 | parser.add_argument('--compress_level', type=int, default=0) 106 | args = parser.parse_args() 107 | 108 | with open(os.path.join(args.indir, 'cropping_params.json')) as f: 109 | cropping_params = json.load(f) 110 | 111 | os.makedirs(args.outdir, exist_ok=True) 112 | 113 | for im_path, 
cropping_dict in tqdm(cropping_params.items()): 114 | im = Image.open(os.path.join(args.indir, im_path)).convert('RGB') 115 | 116 | _, H = im.size 117 | lm = np.array(cropping_dict['lm']) 118 | lm = lm.reshape([-1, 2]) 119 | lm[:, -1] = H - 1 - lm[:, -1] 120 | 121 | _, im_pil, lm, _, im_high = align_img(im, lm, np.array(cropping_dict['lm3d_std']), rescale_factor=cropping_dict['rescale_factor']) 122 | 123 | left = int(im_high.size[0]/2 - cropping_dict['center_crop_size']/2) 124 | upper = int(im_high.size[1]/2 - cropping_dict['center_crop_size']/2) 125 | right = left + cropping_dict['center_crop_size'] 126 | lower = upper + cropping_dict['center_crop_size'] 127 | im_cropped = im_high.crop((left, upper, right,lower)) 128 | im_cropped = im_cropped.resize((cropping_dict['output_size'], cropping_dict['output_size']), resample=Image.LANCZOS) 129 | 130 | im_cropped.save(os.path.join(args.outdir, os.path.basename(im_path)), compress_level=args.compress_level) 131 | im_high.save(os.path.join(args.indir, 'crop_1024', os.path.basename(im_path)), compress_level=args.compress_level) -------------------------------------------------------------------------------- /eg3d-pose-detection/data/__init__.py: -------------------------------------------------------------------------------- 1 | """This package includes all the modules related to data loading and preprocessing 2 | 3 | To add a custom dataset class called 'dummy', you need to add a file called 'dummy_dataset.py' and define a subclass 'DummyDataset' inherited from BaseDataset. 4 | You need to implement four functions: 5 | -- <__init__>: initialize the class, first call BaseDataset.__init__(self, opt). 6 | -- <__len__>: return the size of dataset. 7 | -- <__getitem__>: get a data point from data loader. 8 | -- : (optionally) add dataset-specific options and set default options. 9 | 10 | Now you can use the dataset class by specifying flag '--dataset_mode dummy'. 11 | See our template dataset class 'template_dataset.py' for more details. 12 | """ 13 | import numpy as np 14 | import importlib 15 | import torch.utils.data 16 | from data.base_dataset import BaseDataset 17 | 18 | 19 | def find_dataset_using_name(dataset_name): 20 | """Import the module "data/[dataset_name]_dataset.py". 21 | 22 | In the file, the class called DatasetNameDataset() will 23 | be instantiated. It has to be a subclass of BaseDataset, 24 | and it is case-insensitive. 25 | """ 26 | dataset_filename = "data." + dataset_name + "_dataset" 27 | datasetlib = importlib.import_module(dataset_filename) 28 | 29 | dataset = None 30 | target_dataset_name = dataset_name.replace('_', '') + 'dataset' 31 | for name, cls in datasetlib.__dict__.items(): 32 | if name.lower() == target_dataset_name.lower() \ 33 | and issubclass(cls, BaseDataset): 34 | dataset = cls 35 | 36 | if dataset is None: 37 | raise NotImplementedError("In %s.py, there should be a subclass of BaseDataset with class name that matches %s in lowercase." % (dataset_filename, target_dataset_name)) 38 | 39 | return dataset 40 | 41 | 42 | def get_option_setter(dataset_name): 43 | """Return the static method of the dataset class.""" 44 | dataset_class = find_dataset_using_name(dataset_name) 45 | return dataset_class.modify_commandline_options 46 | 47 | 48 | def create_dataset(opt, rank=0): 49 | """Create a dataset given the option. 50 | 51 | This function wraps the class CustomDatasetDataLoader. 
52 | This is the main interface between this package and 'train.py'/'test.py' 53 | 54 | Example: 55 | >>> from data import create_dataset 56 | >>> dataset = create_dataset(opt) 57 | """ 58 | data_loader = CustomDatasetDataLoader(opt, rank=rank) 59 | dataset = data_loader.load_data() 60 | return dataset 61 | 62 | class CustomDatasetDataLoader(): 63 | """Wrapper class of Dataset class that performs multi-threaded data loading""" 64 | 65 | def __init__(self, opt, rank=0): 66 | """Initialize this class 67 | 68 | Step 1: create a dataset instance given the name [dataset_mode] 69 | Step 2: create a multi-threaded data loader. 70 | """ 71 | self.opt = opt 72 | dataset_class = find_dataset_using_name(opt.dataset_mode) 73 | self.dataset = dataset_class(opt) 74 | self.sampler = None 75 | print("rank %d %s dataset [%s] was created" % (rank, self.dataset.name, type(self.dataset).__name__)) 76 | if opt.use_ddp and opt.isTrain: 77 | world_size = opt.world_size 78 | self.sampler = torch.utils.data.distributed.DistributedSampler( 79 | self.dataset, 80 | num_replicas=world_size, 81 | rank=rank, 82 | shuffle=not opt.serial_batches 83 | ) 84 | self.dataloader = torch.utils.data.DataLoader( 85 | self.dataset, 86 | sampler=self.sampler, 87 | num_workers=int(opt.num_threads / world_size), 88 | batch_size=int(opt.batch_size / world_size), 89 | drop_last=True) 90 | else: 91 | self.dataloader = torch.utils.data.DataLoader( 92 | self.dataset, 93 | batch_size=opt.batch_size, 94 | shuffle=(not opt.serial_batches) and opt.isTrain, 95 | num_workers=int(opt.num_threads), 96 | drop_last=True 97 | ) 98 | 99 | def set_epoch(self, epoch): 100 | self.dataset.current_epoch = epoch 101 | if self.sampler is not None: 102 | self.sampler.set_epoch(epoch) 103 | 104 | def load_data(self): 105 | return self 106 | 107 | def __len__(self): 108 | """Return the number of data in the dataset""" 109 | return min(len(self.dataset), self.opt.max_dataset_size) 110 | 111 | def __iter__(self): 112 | """Return a batch of data""" 113 | for i, data in enumerate(self.dataloader): 114 | if i * self.opt.batch_size >= self.opt.max_dataset_size: 115 | break 116 | yield data 117 | -------------------------------------------------------------------------------- /eg3d-pose-detection/data/base_dataset.py: -------------------------------------------------------------------------------- 1 | """This module implements an abstract base class (ABC) 'BaseDataset' for datasets. 2 | 3 | It also includes common transformation functions (e.g., get_transform, __scale_width), which can be later used in subclasses. 4 | """ 5 | import random 6 | import numpy as np 7 | import torch.utils.data as data 8 | from PIL import Image 9 | import torchvision.transforms as transforms 10 | from abc import ABC, abstractmethod 11 | 12 | 13 | class BaseDataset(data.Dataset, ABC): 14 | """This class is an abstract base class (ABC) for datasets. 15 | 16 | To create a subclass, you need to implement the following four functions: 17 | -- <__init__>: initialize the class, first call BaseDataset.__init__(self, opt). 18 | -- <__len__>: return the size of dataset. 19 | -- <__getitem__>: get a data point. 20 | -- : (optionally) add dataset-specific options and set default options. 
21 | """ 22 | 23 | def __init__(self, opt): 24 | """Initialize the class; save the options in the class 25 | 26 | Parameters: 27 | opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions 28 | """ 29 | self.opt = opt 30 | # self.root = opt.dataroot 31 | self.current_epoch = 0 32 | 33 | @staticmethod 34 | def modify_commandline_options(parser, is_train): 35 | """Add new dataset-specific options, and rewrite default values for existing options. 36 | 37 | Parameters: 38 | parser -- original option parser 39 | is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options. 40 | 41 | Returns: 42 | the modified parser. 43 | """ 44 | return parser 45 | 46 | @abstractmethod 47 | def __len__(self): 48 | """Return the total number of images in the dataset.""" 49 | return 0 50 | 51 | @abstractmethod 52 | def __getitem__(self, index): 53 | """Return a data point and its metadata information. 54 | 55 | Parameters: 56 | index - - a random integer for data indexing 57 | 58 | Returns: 59 | a dictionary of data with their names. It ususally contains the data itself and its metadata information. 60 | """ 61 | pass 62 | 63 | 64 | def get_transform(grayscale=False): 65 | transform_list = [] 66 | if grayscale: 67 | transform_list.append(transforms.Grayscale(1)) 68 | transform_list += [transforms.ToTensor()] 69 | return transforms.Compose(transform_list) 70 | 71 | def get_affine_mat(opt, size): 72 | shift_x, shift_y, scale, rot_angle, flip = 0., 0., 1., 0., False 73 | w, h = size 74 | 75 | if 'shift' in opt.preprocess: 76 | shift_pixs = int(opt.shift_pixs) 77 | shift_x = random.randint(-shift_pixs, shift_pixs) 78 | shift_y = random.randint(-shift_pixs, shift_pixs) 79 | if 'scale' in opt.preprocess: 80 | scale = 1 + opt.scale_delta * (2 * random.random() - 1) 81 | if 'rot' in opt.preprocess: 82 | rot_angle = opt.rot_angle * (2 * random.random() - 1) 83 | rot_rad = -rot_angle * np.pi/180 84 | if 'flip' in opt.preprocess: 85 | flip = random.random() > 0.5 86 | 87 | shift_to_origin = np.array([1, 0, -w//2, 0, 1, -h//2, 0, 0, 1]).reshape([3, 3]) 88 | flip_mat = np.array([-1 if flip else 1, 0, 0, 0, 1, 0, 0, 0, 1]).reshape([3, 3]) 89 | shift_mat = np.array([1, 0, shift_x, 0, 1, shift_y, 0, 0, 1]).reshape([3, 3]) 90 | rot_mat = np.array([np.cos(rot_rad), np.sin(rot_rad), 0, -np.sin(rot_rad), np.cos(rot_rad), 0, 0, 0, 1]).reshape([3, 3]) 91 | scale_mat = np.array([scale, 0, 0, 0, scale, 0, 0, 0, 1]).reshape([3, 3]) 92 | shift_to_center = np.array([1, 0, w//2, 0, 1, h//2, 0, 0, 1]).reshape([3, 3]) 93 | 94 | affine = shift_to_center @ scale_mat @ rot_mat @ shift_mat @ flip_mat @ shift_to_origin 95 | affine_inv = np.linalg.inv(affine) 96 | return affine, affine_inv, flip 97 | 98 | def apply_img_affine(img, affine_inv, method=Image.LANCZOS): 99 | return img.transform(img.size, Image.AFFINE, data=affine_inv.flatten()[:6], resample=Image.LANCZOS) 100 | 101 | def apply_lm_affine(landmark, affine, flip, size): 102 | _, h = size 103 | lm = landmark.copy() 104 | lm[:, 1] = h - 1 - lm[:, 1] 105 | lm = np.concatenate((lm, np.ones([lm.shape[0], 1])), -1) 106 | lm = lm @ np.transpose(affine) 107 | lm[:, :2] = lm[:, :2] / lm[:, 2:] 108 | lm = lm[:, :2] 109 | lm[:, 1] = h - 1 - lm[:, 1] 110 | if flip: 111 | lm_ = lm.copy() 112 | lm_[:17] = lm[16::-1] 113 | lm_[17:22] = lm[26:21:-1] 114 | lm_[22:27] = lm[21:16:-1] 115 | lm_[31:36] = lm[35:30:-1] 116 | lm_[36:40] = lm[45:41:-1] 117 | lm_[40:42] = lm[47:45:-1] 118 | lm_[42:46] = 
lm[39:35:-1] 119 | lm_[46:48] = lm[41:39:-1] 120 | lm_[48:55] = lm[54:47:-1] 121 | lm_[55:60] = lm[59:54:-1] 122 | lm_[60:65] = lm[64:59:-1] 123 | lm_[65:68] = lm[67:64:-1] 124 | lm = lm_ 125 | return lm 126 | -------------------------------------------------------------------------------- /eg3d-pose-detection/data/flist_dataset.py: -------------------------------------------------------------------------------- 1 | """This script defines the custom dataset for Deep3DFaceRecon_pytorch 2 | """ 3 | 4 | import os.path 5 | from data.base_dataset import BaseDataset, get_transform, get_affine_mat, apply_img_affine, apply_lm_affine 6 | from data.image_folder import make_dataset 7 | from PIL import Image 8 | import random 9 | import util.util as util 10 | import numpy as np 11 | import json 12 | import torch 13 | from scipy.io import loadmat, savemat 14 | import pickle 15 | from util.preprocess import align_img, estimate_norm 16 | from util.load_mats import load_lm3d 17 | 18 | 19 | def default_flist_reader(flist): 20 | """ 21 | flist format: impath label\nimpath label\n ...(same to caffe's filelist) 22 | """ 23 | imlist = [] 24 | with open(flist, 'r') as rf: 25 | for line in rf.readlines(): 26 | impath = line.strip() 27 | imlist.append(impath) 28 | 29 | return imlist 30 | 31 | def jason_flist_reader(flist): 32 | with open(flist, 'r') as fp: 33 | info = json.load(fp) 34 | return info 35 | 36 | def parse_label(label): 37 | return torch.tensor(np.array(label).astype(np.float32)) 38 | 39 | 40 | class FlistDataset(BaseDataset): 41 | """ 42 | It requires one directories to host training images '/path/to/data/train' 43 | You can train the model with the dataset flag '--dataroot /path/to/data'. 44 | """ 45 | 46 | def __init__(self, opt): 47 | """Initialize this dataset class. 48 | 49 | Parameters: 50 | opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions 51 | """ 52 | BaseDataset.__init__(self, opt) 53 | 54 | self.lm3d_std = load_lm3d(opt.bfm_folder) 55 | 56 | msk_names = default_flist_reader(opt.flist) 57 | self.msk_paths = [os.path.join(opt.data_root, i) for i in msk_names] 58 | 59 | self.size = len(self.msk_paths) 60 | self.opt = opt 61 | 62 | self.name = 'train' if opt.isTrain else 'val' 63 | if '_' in opt.flist: 64 | self.name += '_' + opt.flist.split(os.sep)[-1].split('_')[0] 65 | 66 | 67 | def __getitem__(self, index): 68 | """Return a data point and its metadata information. 
69 | 70 | Parameters: 71 | index (int) -- a random integer for data indexing 72 | 73 | Returns a dictionary that contains A, B, A_paths and B_paths 74 | img (tensor) -- an image in the input domain 75 | msk (tensor) -- its corresponding attention mask 76 | lm (tensor) -- its corresponding 3d landmarks 77 | im_paths (str) -- image paths 78 | aug_flag (bool) -- a flag used to tell whether its raw or augmented 79 | """ 80 | msk_path = self.msk_paths[index % self.size] # make sure index is within then range 81 | img_path = msk_path.replace('mask/', '') 82 | lm_path = '.'.join(msk_path.replace('mask', 'landmarks').split('.')[:-1]) + '.txt' 83 | 84 | raw_img = Image.open(img_path).convert('RGB') 85 | raw_msk = Image.open(msk_path).convert('RGB') 86 | raw_lm = np.loadtxt(lm_path).astype(np.float32) 87 | 88 | _, img, lm, msk = align_img(raw_img, raw_lm, self.lm3d_std, raw_msk) 89 | 90 | aug_flag = self.opt.use_aug and self.opt.isTrain 91 | if aug_flag: 92 | img, lm, msk = self._augmentation(img, lm, self.opt, msk) 93 | 94 | _, H = img.size 95 | M = estimate_norm(lm, H) 96 | transform = get_transform() 97 | img_tensor = transform(img) 98 | msk_tensor = transform(msk)[:1, ...] 99 | lm_tensor = parse_label(lm) 100 | M_tensor = parse_label(M) 101 | 102 | 103 | return {'imgs': img_tensor, 104 | 'lms': lm_tensor, 105 | 'msks': msk_tensor, 106 | 'M': M_tensor, 107 | 'im_paths': img_path, 108 | 'aug_flag': aug_flag, 109 | 'dataset': self.name} 110 | 111 | def _augmentation(self, img, lm, opt, msk=None): 112 | affine, affine_inv, flip = get_affine_mat(opt, img.size) 113 | img = apply_img_affine(img, affine_inv) 114 | lm = apply_lm_affine(lm, affine, flip, img.size) 115 | if msk is not None: 116 | msk = apply_img_affine(msk, affine_inv, method=Image.BILINEAR) 117 | return img, lm, msk 118 | 119 | 120 | 121 | 122 | def __len__(self): 123 | """Return the total number of images in the dataset. 124 | """ 125 | return self.size 126 | -------------------------------------------------------------------------------- /eg3d-pose-detection/data/image_folder.py: -------------------------------------------------------------------------------- 1 | """A modified image folder class 2 | 3 | We modify the official PyTorch image folder (https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py) 4 | so that this class can load images from both current directory and its subdirectories. 
5 | """ 6 | import numpy as np 7 | import torch.utils.data as data 8 | 9 | from PIL import Image 10 | import os 11 | import os.path 12 | 13 | IMG_EXTENSIONS = [ 14 | '.jpg', '.JPG', '.jpeg', '.JPEG', 15 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', 16 | '.tif', '.TIF', '.tiff', '.TIFF', 17 | ] 18 | 19 | 20 | def is_image_file(filename): 21 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) 22 | 23 | 24 | def make_dataset(dir, max_dataset_size=float("inf")): 25 | images = [] 26 | assert os.path.isdir(dir) or os.path.islink(dir), '%s is not a valid directory' % dir 27 | 28 | for root, _, fnames in sorted(os.walk(dir, followlinks=True)): 29 | for fname in fnames: 30 | if is_image_file(fname): 31 | path = os.path.join(root, fname) 32 | images.append(path) 33 | return images[:min(max_dataset_size, len(images))] 34 | 35 | 36 | def default_loader(path): 37 | return Image.open(path).convert('RGB') 38 | 39 | 40 | class ImageFolder(data.Dataset): 41 | 42 | def __init__(self, root, transform=None, return_paths=False, 43 | loader=default_loader): 44 | imgs = make_dataset(root) 45 | if len(imgs) == 0: 46 | raise(RuntimeError("Found 0 images in: " + root + "\n" 47 | "Supported image extensions are: " + ",".join(IMG_EXTENSIONS))) 48 | 49 | self.root = root 50 | self.imgs = imgs 51 | self.transform = transform 52 | self.return_paths = return_paths 53 | self.loader = loader 54 | 55 | def __getitem__(self, index): 56 | path = self.imgs[index] 57 | img = self.loader(path) 58 | if self.transform is not None: 59 | img = self.transform(img) 60 | if self.return_paths: 61 | return img, path 62 | else: 63 | return img 64 | 65 | def __len__(self): 66 | return len(self.imgs) 67 | -------------------------------------------------------------------------------- /eg3d-pose-detection/data/template_dataset.py: -------------------------------------------------------------------------------- 1 | """Dataset class template 2 | 3 | This module provides a template for users to implement custom datasets. 4 | You can specify '--dataset_mode template' to use this dataset. 5 | The class name should be consistent with both the filename and its dataset_mode option. 6 | The filename should be _dataset.py 7 | The class name should be Dataset.py 8 | You need to implement the following functions: 9 | -- : Add dataset-specific options and rewrite default values for existing options. 10 | -- <__init__>: Initialize this dataset class. 11 | -- <__getitem__>: Return a data point and its metadata information. 12 | -- <__len__>: Return the number of images. 13 | """ 14 | from data.base_dataset import BaseDataset, get_transform 15 | # from data.image_folder import make_dataset 16 | # from PIL import Image 17 | 18 | 19 | class TemplateDataset(BaseDataset): 20 | """A template dataset class for you to implement custom datasets.""" 21 | @staticmethod 22 | def modify_commandline_options(parser, is_train): 23 | """Add new dataset-specific options, and rewrite default values for existing options. 24 | 25 | Parameters: 26 | parser -- original option parser 27 | is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options. 28 | 29 | Returns: 30 | the modified parser. 
31 | """ 32 | parser.add_argument('--new_dataset_option', type=float, default=1.0, help='new dataset option') 33 | parser.set_defaults(max_dataset_size=10, new_dataset_option=2.0) # specify dataset-specific default values 34 | return parser 35 | 36 | def __init__(self, opt): 37 | """Initialize this dataset class. 38 | 39 | Parameters: 40 | opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions 41 | 42 | A few things can be done here. 43 | - save the options (have been done in BaseDataset) 44 | - get image paths and meta information of the dataset. 45 | - define the image transformation. 46 | """ 47 | # save the option and dataset root 48 | BaseDataset.__init__(self, opt) 49 | # get the image paths of your dataset; 50 | self.image_paths = [] # You can call sorted(make_dataset(self.root, opt.max_dataset_size)) to get all the image paths under the directory self.root 51 | # define the default transform function. You can use ; You can also define your custom transform function 52 | self.transform = get_transform(opt) 53 | 54 | def __getitem__(self, index): 55 | """Return a data point and its metadata information. 56 | 57 | Parameters: 58 | index -- a random integer for data indexing 59 | 60 | Returns: 61 | a dictionary of data with their names. It usually contains the data itself and its metadata information. 62 | 63 | Step 1: get a random image path: e.g., path = self.image_paths[index] 64 | Step 2: load your data from the disk: e.g., image = Image.open(path).convert('RGB'). 65 | Step 3: convert your data to a PyTorch tensor. You can use helpder functions such as self.transform. e.g., data = self.transform(image) 66 | Step 4: return a data point as a dictionary. 67 | """ 68 | path = 'temp' # needs to be a string 69 | data_A = None # needs to be a tensor 70 | data_B = None # needs to be a tensor 71 | return {'data_A': data_A, 'data_B': data_B, 'path': path} 72 | 73 | def __len__(self): 74 | """Return the total number of images.""" 75 | return len(self.image_paths) 76 | -------------------------------------------------------------------------------- /eg3d-pose-detection/models/__init__.py: -------------------------------------------------------------------------------- 1 | """This package contains modules related to objective functions, optimizations, and network architectures. 2 | 3 | To add a custom model class called 'dummy', you need to add a file called 'dummy_model.py' and define a subclass DummyModel inherited from BaseModel. 4 | You need to implement the following five functions: 5 | -- <__init__>: initialize the class; first call BaseModel.__init__(self, opt). 6 | -- : unpack data from dataset and apply preprocessing. 7 | -- : produce intermediate results. 8 | -- : calculate loss, gradients, and update network weights. 9 | -- : (optionally) add model-specific options and set default options. 10 | 11 | In the function <__init__>, you need to define four lists: 12 | -- self.loss_names (str list): specify the training losses that you want to plot and save. 13 | -- self.model_names (str list): define networks used in our training. 14 | -- self.visual_names (str list): specify the images that you want to display and save. 15 | -- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an usage. 16 | 17 | Now you can use the model class by specifying flag '--model dummy'. 
18 | See our template model class 'template_model.py' for more details. 19 | """ 20 | 21 | import importlib 22 | from models.base_model import BaseModel 23 | 24 | 25 | def find_model_using_name(model_name): 26 | """Import the module "models/[model_name]_model.py". 27 | 28 | In the file, the class called DatasetNameModel() will 29 | be instantiated. It has to be a subclass of BaseModel, 30 | and it is case-insensitive. 31 | """ 32 | model_filename = "models." + model_name + "_model" 33 | modellib = importlib.import_module(model_filename) 34 | model = None 35 | target_model_name = model_name.replace('_', '') + 'model' 36 | for name, cls in modellib.__dict__.items(): 37 | if name.lower() == target_model_name.lower() \ 38 | and issubclass(cls, BaseModel): 39 | model = cls 40 | 41 | if model is None: 42 | print("In %s.py, there should be a subclass of BaseModel with class name that matches %s in lowercase." % (model_filename, target_model_name)) 43 | exit(0) 44 | 45 | return model 46 | 47 | 48 | def get_option_setter(model_name): 49 | """Return the static method of the model class.""" 50 | model_class = find_model_using_name(model_name) 51 | return model_class.modify_commandline_options 52 | 53 | 54 | def create_model(opt): 55 | """Create a model given the option. 56 | 57 | This function warps the class CustomDatasetDataLoader. 58 | This is the main interface between this package and 'train.py'/'test.py' 59 | 60 | Example: 61 | >>> from models import create_model 62 | >>> model = create_model(opt) 63 | """ 64 | model = find_model_using_name(opt.model) 65 | instance = model(opt) 66 | print("model [%s] was created" % type(instance).__name__) 67 | return instance 68 | -------------------------------------------------------------------------------- /eg3d-pose-detection/models/arcface_torch/backbones/__init__.py: -------------------------------------------------------------------------------- 1 | from .iresnet import iresnet18, iresnet34, iresnet50, iresnet100, iresnet200 2 | from .mobilefacenet import get_mbf 3 | 4 | 5 | def get_model(name, **kwargs): 6 | # resnet 7 | if name == "r18": 8 | return iresnet18(False, **kwargs) 9 | elif name == "r34": 10 | return iresnet34(False, **kwargs) 11 | elif name == "r50": 12 | return iresnet50(False, **kwargs) 13 | elif name == "r100": 14 | return iresnet100(False, **kwargs) 15 | elif name == "r200": 16 | return iresnet200(False, **kwargs) 17 | elif name == "r2060": 18 | from .iresnet2060 import iresnet2060 19 | return iresnet2060(False, **kwargs) 20 | 21 | elif name == "mbf": 22 | fp16 = kwargs.get("fp16", False) 23 | num_features = kwargs.get("num_features", 512) 24 | return get_mbf(fp16=fp16, num_features=num_features) 25 | 26 | elif name == "mbf_large": 27 | from .mobilefacenet import get_mbf_large 28 | fp16 = kwargs.get("fp16", False) 29 | num_features = kwargs.get("num_features", 512) 30 | return get_mbf_large(fp16=fp16, num_features=num_features) 31 | 32 | elif name == "vit_t": 33 | num_features = kwargs.get("num_features", 512) 34 | from .vit import VisionTransformer 35 | return VisionTransformer( 36 | img_size=112, patch_size=9, num_classes=num_features, embed_dim=256, depth=12, 37 | num_heads=8, drop_path_rate=0.1, norm_layer="ln", mask_ratio=0.1) 38 | 39 | elif name == "vit_t_dp005_mask0": # For WebFace42M 40 | num_features = kwargs.get("num_features", 512) 41 | from .vit import VisionTransformer 42 | return VisionTransformer( 43 | img_size=112, patch_size=9, num_classes=num_features, embed_dim=256, depth=12, 44 | num_heads=8, drop_path_rate=0.05, 
norm_layer="ln", mask_ratio=0.0) 45 | 46 | elif name == "vit_s": 47 | num_features = kwargs.get("num_features", 512) 48 | from .vit import VisionTransformer 49 | return VisionTransformer( 50 | img_size=112, patch_size=9, num_classes=num_features, embed_dim=512, depth=12, 51 | num_heads=8, drop_path_rate=0.1, norm_layer="ln", mask_ratio=0.1) 52 | 53 | elif name == "vit_s_dp005_mask_0": # For WebFace42M 54 | num_features = kwargs.get("num_features", 512) 55 | from .vit import VisionTransformer 56 | return VisionTransformer( 57 | img_size=112, patch_size=9, num_classes=num_features, embed_dim=512, depth=12, 58 | num_heads=8, drop_path_rate=0.05, norm_layer="ln", mask_ratio=0.0) 59 | 60 | elif name == "vit_b": 61 | # this is a feature 62 | num_features = kwargs.get("num_features", 512) 63 | from .vit import VisionTransformer 64 | return VisionTransformer( 65 | img_size=112, patch_size=9, num_classes=num_features, embed_dim=512, depth=24, 66 | num_heads=8, drop_path_rate=0.1, norm_layer="ln", mask_ratio=0.1, using_checkpoint=True) 67 | 68 | elif name == "vit_b_dp005_mask_005": # For WebFace42M 69 | # this is a feature 70 | num_features = kwargs.get("num_features", 512) 71 | from .vit import VisionTransformer 72 | return VisionTransformer( 73 | img_size=112, patch_size=9, num_classes=num_features, embed_dim=512, depth=24, 74 | num_heads=8, drop_path_rate=0.05, norm_layer="ln", mask_ratio=0.05, using_checkpoint=True) 75 | 76 | elif name == "vit_l_dp005_mask_005": # For WebFace42M 77 | # this is a feature 78 | num_features = kwargs.get("num_features", 512) 79 | from .vit import VisionTransformer 80 | return VisionTransformer( 81 | img_size=112, patch_size=9, num_classes=num_features, embed_dim=768, depth=24, 82 | num_heads=8, drop_path_rate=0.05, norm_layer="ln", mask_ratio=0.05, using_checkpoint=True) 83 | 84 | else: 85 | raise ValueError() 86 | -------------------------------------------------------------------------------- /eg3d-pose-detection/models/arcface_torch/backbones/iresnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.utils.checkpoint import checkpoint 4 | 5 | __all__ = ['iresnet18', 'iresnet34', 'iresnet50', 'iresnet100', 'iresnet200'] 6 | using_ckpt = False 7 | 8 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): 9 | """3x3 convolution with padding""" 10 | return nn.Conv2d(in_planes, 11 | out_planes, 12 | kernel_size=3, 13 | stride=stride, 14 | padding=dilation, 15 | groups=groups, 16 | bias=False, 17 | dilation=dilation) 18 | 19 | 20 | def conv1x1(in_planes, out_planes, stride=1): 21 | """1x1 convolution""" 22 | return nn.Conv2d(in_planes, 23 | out_planes, 24 | kernel_size=1, 25 | stride=stride, 26 | bias=False) 27 | 28 | 29 | class IBasicBlock(nn.Module): 30 | expansion = 1 31 | def __init__(self, inplanes, planes, stride=1, downsample=None, 32 | groups=1, base_width=64, dilation=1): 33 | super(IBasicBlock, self).__init__() 34 | if groups != 1 or base_width != 64: 35 | raise ValueError('BasicBlock only supports groups=1 and base_width=64') 36 | if dilation > 1: 37 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock") 38 | self.bn1 = nn.BatchNorm2d(inplanes, eps=1e-05,) 39 | self.conv1 = conv3x3(inplanes, planes) 40 | self.bn2 = nn.BatchNorm2d(planes, eps=1e-05,) 41 | self.prelu = nn.PReLU(planes) 42 | self.conv2 = conv3x3(planes, planes, stride) 43 | self.bn3 = nn.BatchNorm2d(planes, eps=1e-05,) 44 | self.downsample = downsample 45 | self.stride = 
stride 46 | 47 | def forward_impl(self, x): 48 | identity = x 49 | out = self.bn1(x) 50 | out = self.conv1(out) 51 | out = self.bn2(out) 52 | out = self.prelu(out) 53 | out = self.conv2(out) 54 | out = self.bn3(out) 55 | if self.downsample is not None: 56 | identity = self.downsample(x) 57 | out += identity 58 | return out 59 | 60 | def forward(self, x): 61 | if self.training and using_ckpt: 62 | return checkpoint(self.forward_impl, x) 63 | else: 64 | return self.forward_impl(x) 65 | 66 | 67 | class IResNet(nn.Module): 68 | fc_scale = 7 * 7 69 | def __init__(self, 70 | block, layers, dropout=0, num_features=512, zero_init_residual=False, 71 | groups=1, width_per_group=64, replace_stride_with_dilation=None, fp16=False): 72 | super(IResNet, self).__init__() 73 | self.extra_gflops = 0.0 74 | self.fp16 = fp16 75 | self.inplanes = 64 76 | self.dilation = 1 77 | if replace_stride_with_dilation is None: 78 | replace_stride_with_dilation = [False, False, False] 79 | if len(replace_stride_with_dilation) != 3: 80 | raise ValueError("replace_stride_with_dilation should be None " 81 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) 82 | self.groups = groups 83 | self.base_width = width_per_group 84 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False) 85 | self.bn1 = nn.BatchNorm2d(self.inplanes, eps=1e-05) 86 | self.prelu = nn.PReLU(self.inplanes) 87 | self.layer1 = self._make_layer(block, 64, layers[0], stride=2) 88 | self.layer2 = self._make_layer(block, 89 | 128, 90 | layers[1], 91 | stride=2, 92 | dilate=replace_stride_with_dilation[0]) 93 | self.layer3 = self._make_layer(block, 94 | 256, 95 | layers[2], 96 | stride=2, 97 | dilate=replace_stride_with_dilation[1]) 98 | self.layer4 = self._make_layer(block, 99 | 512, 100 | layers[3], 101 | stride=2, 102 | dilate=replace_stride_with_dilation[2]) 103 | self.bn2 = nn.BatchNorm2d(512 * block.expansion, eps=1e-05,) 104 | self.dropout = nn.Dropout(p=dropout, inplace=True) 105 | self.fc = nn.Linear(512 * block.expansion * self.fc_scale, num_features) 106 | self.features = nn.BatchNorm1d(num_features, eps=1e-05) 107 | nn.init.constant_(self.features.weight, 1.0) 108 | self.features.weight.requires_grad = False 109 | 110 | for m in self.modules(): 111 | if isinstance(m, nn.Conv2d): 112 | nn.init.normal_(m.weight, 0, 0.1) 113 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 114 | nn.init.constant_(m.weight, 1) 115 | nn.init.constant_(m.bias, 0) 116 | 117 | if zero_init_residual: 118 | for m in self.modules(): 119 | if isinstance(m, IBasicBlock): 120 | nn.init.constant_(m.bn2.weight, 0) 121 | 122 | def _make_layer(self, block, planes, blocks, stride=1, dilate=False): 123 | downsample = None 124 | previous_dilation = self.dilation 125 | if dilate: 126 | self.dilation *= stride 127 | stride = 1 128 | if stride != 1 or self.inplanes != planes * block.expansion: 129 | downsample = nn.Sequential( 130 | conv1x1(self.inplanes, planes * block.expansion, stride), 131 | nn.BatchNorm2d(planes * block.expansion, eps=1e-05, ), 132 | ) 133 | layers = [] 134 | layers.append( 135 | block(self.inplanes, planes, stride, downsample, self.groups, 136 | self.base_width, previous_dilation)) 137 | self.inplanes = planes * block.expansion 138 | for _ in range(1, blocks): 139 | layers.append( 140 | block(self.inplanes, 141 | planes, 142 | groups=self.groups, 143 | base_width=self.base_width, 144 | dilation=self.dilation)) 145 | 146 | return nn.Sequential(*layers) 147 | 148 | def forward(self, x): 149 | with 
torch.cuda.amp.autocast(self.fp16): 150 | x = self.conv1(x) 151 | x = self.bn1(x) 152 | x = self.prelu(x) 153 | x = self.layer1(x) 154 | x = self.layer2(x) 155 | x = self.layer3(x) 156 | x = self.layer4(x) 157 | x = self.bn2(x) 158 | x = torch.flatten(x, 1) 159 | x = self.dropout(x) 160 | x = self.fc(x.float() if self.fp16 else x) 161 | x = self.features(x) 162 | return x 163 | 164 | 165 | def _iresnet(arch, block, layers, pretrained, progress, **kwargs): 166 | model = IResNet(block, layers, **kwargs) 167 | if pretrained: 168 | raise ValueError() 169 | return model 170 | 171 | 172 | def iresnet18(pretrained=False, progress=True, **kwargs): 173 | return _iresnet('iresnet18', IBasicBlock, [2, 2, 2, 2], pretrained, 174 | progress, **kwargs) 175 | 176 | 177 | def iresnet34(pretrained=False, progress=True, **kwargs): 178 | return _iresnet('iresnet34', IBasicBlock, [3, 4, 6, 3], pretrained, 179 | progress, **kwargs) 180 | 181 | 182 | def iresnet50(pretrained=False, progress=True, **kwargs): 183 | return _iresnet('iresnet50', IBasicBlock, [3, 4, 14, 3], pretrained, 184 | progress, **kwargs) 185 | 186 | 187 | def iresnet100(pretrained=False, progress=True, **kwargs): 188 | return _iresnet('iresnet100', IBasicBlock, [3, 13, 30, 3], pretrained, 189 | progress, **kwargs) 190 | 191 | 192 | def iresnet200(pretrained=False, progress=True, **kwargs): 193 | return _iresnet('iresnet200', IBasicBlock, [6, 26, 60, 6], pretrained, 194 | progress, **kwargs) 195 | -------------------------------------------------------------------------------- /eg3d-pose-detection/models/arcface_torch/backbones/iresnet2060.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | assert torch.__version__ >= "1.8.1" 5 | from torch.utils.checkpoint import checkpoint_sequential 6 | 7 | __all__ = ['iresnet2060'] 8 | 9 | 10 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): 11 | """3x3 convolution with padding""" 12 | return nn.Conv2d(in_planes, 13 | out_planes, 14 | kernel_size=3, 15 | stride=stride, 16 | padding=dilation, 17 | groups=groups, 18 | bias=False, 19 | dilation=dilation) 20 | 21 | 22 | def conv1x1(in_planes, out_planes, stride=1): 23 | """1x1 convolution""" 24 | return nn.Conv2d(in_planes, 25 | out_planes, 26 | kernel_size=1, 27 | stride=stride, 28 | bias=False) 29 | 30 | 31 | class IBasicBlock(nn.Module): 32 | expansion = 1 33 | 34 | def __init__(self, inplanes, planes, stride=1, downsample=None, 35 | groups=1, base_width=64, dilation=1): 36 | super(IBasicBlock, self).__init__() 37 | if groups != 1 or base_width != 64: 38 | raise ValueError('BasicBlock only supports groups=1 and base_width=64') 39 | if dilation > 1: 40 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock") 41 | self.bn1 = nn.BatchNorm2d(inplanes, eps=1e-05, ) 42 | self.conv1 = conv3x3(inplanes, planes) 43 | self.bn2 = nn.BatchNorm2d(planes, eps=1e-05, ) 44 | self.prelu = nn.PReLU(planes) 45 | self.conv2 = conv3x3(planes, planes, stride) 46 | self.bn3 = nn.BatchNorm2d(planes, eps=1e-05, ) 47 | self.downsample = downsample 48 | self.stride = stride 49 | 50 | def forward(self, x): 51 | identity = x 52 | out = self.bn1(x) 53 | out = self.conv1(out) 54 | out = self.bn2(out) 55 | out = self.prelu(out) 56 | out = self.conv2(out) 57 | out = self.bn3(out) 58 | if self.downsample is not None: 59 | identity = self.downsample(x) 60 | out += identity 61 | return out 62 | 63 | 64 | class IResNet(nn.Module): 65 | fc_scale = 7 * 7 66 | 67 | def 
__init__(self, 68 | block, layers, dropout=0, num_features=512, zero_init_residual=False, 69 | groups=1, width_per_group=64, replace_stride_with_dilation=None, fp16=False): 70 | super(IResNet, self).__init__() 71 | self.fp16 = fp16 72 | self.inplanes = 64 73 | self.dilation = 1 74 | if replace_stride_with_dilation is None: 75 | replace_stride_with_dilation = [False, False, False] 76 | if len(replace_stride_with_dilation) != 3: 77 | raise ValueError("replace_stride_with_dilation should be None " 78 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) 79 | self.groups = groups 80 | self.base_width = width_per_group 81 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False) 82 | self.bn1 = nn.BatchNorm2d(self.inplanes, eps=1e-05) 83 | self.prelu = nn.PReLU(self.inplanes) 84 | self.layer1 = self._make_layer(block, 64, layers[0], stride=2) 85 | self.layer2 = self._make_layer(block, 86 | 128, 87 | layers[1], 88 | stride=2, 89 | dilate=replace_stride_with_dilation[0]) 90 | self.layer3 = self._make_layer(block, 91 | 256, 92 | layers[2], 93 | stride=2, 94 | dilate=replace_stride_with_dilation[1]) 95 | self.layer4 = self._make_layer(block, 96 | 512, 97 | layers[3], 98 | stride=2, 99 | dilate=replace_stride_with_dilation[2]) 100 | self.bn2 = nn.BatchNorm2d(512 * block.expansion, eps=1e-05, ) 101 | self.dropout = nn.Dropout(p=dropout, inplace=True) 102 | self.fc = nn.Linear(512 * block.expansion * self.fc_scale, num_features) 103 | self.features = nn.BatchNorm1d(num_features, eps=1e-05) 104 | nn.init.constant_(self.features.weight, 1.0) 105 | self.features.weight.requires_grad = False 106 | 107 | for m in self.modules(): 108 | if isinstance(m, nn.Conv2d): 109 | nn.init.normal_(m.weight, 0, 0.1) 110 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 111 | nn.init.constant_(m.weight, 1) 112 | nn.init.constant_(m.bias, 0) 113 | 114 | if zero_init_residual: 115 | for m in self.modules(): 116 | if isinstance(m, IBasicBlock): 117 | nn.init.constant_(m.bn2.weight, 0) 118 | 119 | def _make_layer(self, block, planes, blocks, stride=1, dilate=False): 120 | downsample = None 121 | previous_dilation = self.dilation 122 | if dilate: 123 | self.dilation *= stride 124 | stride = 1 125 | if stride != 1 or self.inplanes != planes * block.expansion: 126 | downsample = nn.Sequential( 127 | conv1x1(self.inplanes, planes * block.expansion, stride), 128 | nn.BatchNorm2d(planes * block.expansion, eps=1e-05, ), 129 | ) 130 | layers = [] 131 | layers.append( 132 | block(self.inplanes, planes, stride, downsample, self.groups, 133 | self.base_width, previous_dilation)) 134 | self.inplanes = planes * block.expansion 135 | for _ in range(1, blocks): 136 | layers.append( 137 | block(self.inplanes, 138 | planes, 139 | groups=self.groups, 140 | base_width=self.base_width, 141 | dilation=self.dilation)) 142 | 143 | return nn.Sequential(*layers) 144 | 145 | def checkpoint(self, func, num_seg, x): 146 | if self.training: 147 | return checkpoint_sequential(func, num_seg, x) 148 | else: 149 | return func(x) 150 | 151 | def forward(self, x): 152 | with torch.cuda.amp.autocast(self.fp16): 153 | x = self.conv1(x) 154 | x = self.bn1(x) 155 | x = self.prelu(x) 156 | x = self.layer1(x) 157 | x = self.checkpoint(self.layer2, 20, x) 158 | x = self.checkpoint(self.layer3, 100, x) 159 | x = self.layer4(x) 160 | x = self.bn2(x) 161 | x = torch.flatten(x, 1) 162 | x = self.dropout(x) 163 | x = self.fc(x.float() if self.fp16 else x) 164 | x = self.features(x) 165 | return x 166 | 167 | 168 | def 
_iresnet(arch, block, layers, pretrained, progress, **kwargs): 169 | model = IResNet(block, layers, **kwargs) 170 | if pretrained: 171 | raise ValueError() 172 | return model 173 | 174 | 175 | def iresnet2060(pretrained=False, progress=True, **kwargs): 176 | return _iresnet('iresnet2060', IBasicBlock, [3, 128, 1024 - 128, 3], pretrained, progress, **kwargs) 177 | -------------------------------------------------------------------------------- /eg3d-pose-detection/models/arcface_torch/backbones/mobilefacenet.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Adapted from https://github.com/cavalleria/cavaface.pytorch/blob/master/backbone/mobilefacenet.py 3 | Original author cavalleria 4 | ''' 5 | 6 | import torch.nn as nn 7 | from torch.nn import Linear, Conv2d, BatchNorm1d, BatchNorm2d, PReLU, Sequential, Module 8 | import torch 9 | 10 | 11 | class Flatten(Module): 12 | def forward(self, x): 13 | return x.view(x.size(0), -1) 14 | 15 | 16 | class ConvBlock(Module): 17 | def __init__(self, in_c, out_c, kernel=(1, 1), stride=(1, 1), padding=(0, 0), groups=1): 18 | super(ConvBlock, self).__init__() 19 | self.layers = nn.Sequential( 20 | Conv2d(in_c, out_c, kernel, groups=groups, stride=stride, padding=padding, bias=False), 21 | BatchNorm2d(num_features=out_c), 22 | PReLU(num_parameters=out_c) 23 | ) 24 | 25 | def forward(self, x): 26 | return self.layers(x) 27 | 28 | 29 | class LinearBlock(Module): 30 | def __init__(self, in_c, out_c, kernel=(1, 1), stride=(1, 1), padding=(0, 0), groups=1): 31 | super(LinearBlock, self).__init__() 32 | self.layers = nn.Sequential( 33 | Conv2d(in_c, out_c, kernel, stride, padding, groups=groups, bias=False), 34 | BatchNorm2d(num_features=out_c) 35 | ) 36 | 37 | def forward(self, x): 38 | return self.layers(x) 39 | 40 | 41 | class DepthWise(Module): 42 | def __init__(self, in_c, out_c, residual=False, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=1): 43 | super(DepthWise, self).__init__() 44 | self.residual = residual 45 | self.layers = nn.Sequential( 46 | ConvBlock(in_c, out_c=groups, kernel=(1, 1), padding=(0, 0), stride=(1, 1)), 47 | ConvBlock(groups, groups, groups=groups, kernel=kernel, padding=padding, stride=stride), 48 | LinearBlock(groups, out_c, kernel=(1, 1), padding=(0, 0), stride=(1, 1)) 49 | ) 50 | 51 | def forward(self, x): 52 | short_cut = None 53 | if self.residual: 54 | short_cut = x 55 | x = self.layers(x) 56 | if self.residual: 57 | output = short_cut + x 58 | else: 59 | output = x 60 | return output 61 | 62 | 63 | class Residual(Module): 64 | def __init__(self, c, num_block, groups, kernel=(3, 3), stride=(1, 1), padding=(1, 1)): 65 | super(Residual, self).__init__() 66 | modules = [] 67 | for _ in range(num_block): 68 | modules.append(DepthWise(c, c, True, kernel, stride, padding, groups)) 69 | self.layers = Sequential(*modules) 70 | 71 | def forward(self, x): 72 | return self.layers(x) 73 | 74 | 75 | class GDC(Module): 76 | def __init__(self, embedding_size): 77 | super(GDC, self).__init__() 78 | self.layers = nn.Sequential( 79 | LinearBlock(512, 512, groups=512, kernel=(7, 7), stride=(1, 1), padding=(0, 0)), 80 | Flatten(), 81 | Linear(512, embedding_size, bias=False), 82 | BatchNorm1d(embedding_size)) 83 | 84 | def forward(self, x): 85 | return self.layers(x) 86 | 87 | 88 | class MobileFaceNet(Module): 89 | def __init__(self, fp16=False, num_features=512, blocks=(1, 4, 6, 2), scale=2): 90 | super(MobileFaceNet, self).__init__() 91 | self.scale = scale 92 | self.fp16 = fp16 93 | 
self.layers = nn.ModuleList() 94 | self.layers.append( 95 | ConvBlock(3, 64 * self.scale, kernel=(3, 3), stride=(2, 2), padding=(1, 1)) 96 | ) 97 | if blocks[0] == 1: 98 | self.layers.append( 99 | ConvBlock(64 * self.scale, 64 * self.scale, kernel=(3, 3), stride=(1, 1), padding=(1, 1), groups=64) 100 | ) 101 | else: 102 | self.layers.append( 103 | Residual(64 * self.scale, num_block=blocks[0], groups=128, kernel=(3, 3), stride=(1, 1), padding=(1, 1)), 104 | ) 105 | 106 | self.layers.extend( 107 | [ 108 | DepthWise(64 * self.scale, 64 * self.scale, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=128), 109 | Residual(64 * self.scale, num_block=blocks[1], groups=128, kernel=(3, 3), stride=(1, 1), padding=(1, 1)), 110 | DepthWise(64 * self.scale, 128 * self.scale, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=256), 111 | Residual(128 * self.scale, num_block=blocks[2], groups=256, kernel=(3, 3), stride=(1, 1), padding=(1, 1)), 112 | DepthWise(128 * self.scale, 128 * self.scale, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=512), 113 | Residual(128 * self.scale, num_block=blocks[3], groups=256, kernel=(3, 3), stride=(1, 1), padding=(1, 1)), 114 | ]) 115 | 116 | self.conv_sep = ConvBlock(128 * self.scale, 512, kernel=(1, 1), stride=(1, 1), padding=(0, 0)) 117 | self.features = GDC(num_features) 118 | self._initialize_weights() 119 | 120 | def _initialize_weights(self): 121 | for m in self.modules(): 122 | if isinstance(m, nn.Conv2d): 123 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 124 | if m.bias is not None: 125 | m.bias.data.zero_() 126 | elif isinstance(m, nn.BatchNorm2d): 127 | m.weight.data.fill_(1) 128 | m.bias.data.zero_() 129 | elif isinstance(m, nn.Linear): 130 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 131 | if m.bias is not None: 132 | m.bias.data.zero_() 133 | 134 | def forward(self, x): 135 | with torch.cuda.amp.autocast(self.fp16): 136 | for func in self.layers: 137 | x = func(x) 138 | x = self.conv_sep(x.float() if self.fp16 else x) 139 | x = self.features(x) 140 | return x 141 | 142 | 143 | def get_mbf(fp16, num_features, blocks=(1, 4, 6, 2), scale=2): 144 | return MobileFaceNet(fp16, num_features, blocks, scale=scale) 145 | 146 | def get_mbf_large(fp16, num_features, blocks=(2, 8, 12, 4), scale=4): 147 | return MobileFaceNet(fp16, num_features, blocks, scale=scale) 148 | -------------------------------------------------------------------------------- /eg3d-pose-detection/models/arcface_torch/configs/3millions.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict as edict 2 | 3 | # configs for test speed 4 | 5 | config = edict() 6 | config.margin_list = (1.0, 0.0, 0.4) 7 | config.network = "mbf" 8 | config.resume = False 9 | config.output = None 10 | config.embedding_size = 512 11 | config.sample_rate = 0.1 12 | config.fp16 = True 13 | config.momentum = 0.9 14 | config.weight_decay = 5e-4 15 | config.batch_size = 512 # total_batch_size = batch_size * num_gpus 16 | config.lr = 0.1 # batch size is 512 17 | 18 | config.rec = "synthetic" 19 | config.num_classes = 30 * 10000 20 | config.num_image = 100000 21 | config.num_epoch = 30 22 | config.warmup_epoch = -1 23 | config.val_targets = [] 24 | -------------------------------------------------------------------------------- /eg3d-pose-detection/models/arcface_torch/configs/__init__.py: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/bbaaii/HFA-GP/aa2c15a61d8ddd182189153914098a2af0edfb0c/eg3d-pose-detection/models/arcface_torch/configs/__init__.py -------------------------------------------------------------------------------- /eg3d-pose-detection/models/arcface_torch/configs/base.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict as edict 2 | 3 | # make training faster 4 | # our RAM is 256G 5 | # mount -t tmpfs -o size=140G tmpfs /train_tmp 6 | 7 | config = edict() 8 | 9 | # Margin Base Softmax 10 | config.margin_list = (1.0, 0.5, 0.0) 11 | config.network = "r50" 12 | config.resume = False 13 | config.save_all_states = False 14 | config.output = "ms1mv3_arcface_r50" 15 | 16 | config.embedding_size = 512 17 | 18 | # Partial FC 19 | config.sample_rate = 1 20 | config.interclass_filtering_threshold = 0 21 | 22 | config.fp16 = False 23 | config.batch_size = 128 24 | 25 | # For SGD 26 | config.optimizer = "sgd" 27 | config.lr = 0.1 28 | config.momentum = 0.9 29 | config.weight_decay = 5e-4 30 | 31 | # For AdamW 32 | # config.optimizer = "adamw" 33 | # config.lr = 0.001 34 | # config.weight_decay = 0.1 35 | 36 | config.verbose = 2000 37 | config.frequent = 10 38 | 39 | # For Large Sacle Dataset, such as WebFace42M 40 | config.dali = False 41 | 42 | 43 | # setup seed 44 | config.seed = 2048 45 | 46 | # dataload numworkers 47 | config.num_workers = 8 48 | -------------------------------------------------------------------------------- /eg3d-pose-detection/models/arcface_torch/configs/glint360k_mbf.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict as edict 2 | 3 | # make training faster 4 | # our RAM is 256G 5 | # mount -t tmpfs -o size=140G tmpfs /train_tmp 6 | 7 | config = edict() 8 | config.margin_list = (1.0, 0.0, 0.4) 9 | config.network = "mbf" 10 | config.resume = False 11 | config.output = None 12 | config.embedding_size = 512 13 | config.sample_rate = 1.0 14 | config.fp16 = True 15 | config.momentum = 0.9 16 | config.weight_decay = 1e-4 17 | config.batch_size = 128 18 | config.lr = 0.1 19 | config.verbose = 2000 20 | config.dali = False 21 | 22 | config.rec = "/train_tmp/glint360k" 23 | config.num_classes = 360232 24 | config.num_image = 17091657 25 | config.num_epoch = 20 26 | config.warmup_epoch = 0 27 | config.val_targets = ['lfw', 'cfp_fp', "agedb_30"] 28 | -------------------------------------------------------------------------------- /eg3d-pose-detection/models/arcface_torch/configs/glint360k_r100.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict as edict 2 | 3 | # make training faster 4 | # our RAM is 256G 5 | # mount -t tmpfs -o size=140G tmpfs /train_tmp 6 | 7 | config = edict() 8 | config.margin_list = (1.0, 0.0, 0.4) 9 | config.network = "r100" 10 | config.resume = False 11 | config.output = None 12 | config.embedding_size = 512 13 | config.sample_rate = 1.0 14 | config.fp16 = True 15 | config.momentum = 0.9 16 | config.weight_decay = 1e-4 17 | config.batch_size = 128 18 | config.lr = 0.1 19 | config.verbose = 2000 20 | config.dali = False 21 | 22 | config.rec = "/train_tmp/glint360k" 23 | config.num_classes = 360232 24 | config.num_image = 17091657 25 | config.num_epoch = 20 26 | config.warmup_epoch = 0 27 | config.val_targets = ['lfw', 'cfp_fp', "agedb_30"] 28 | -------------------------------------------------------------------------------- 
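
All of the training presets in this `configs/` directory follow the same pattern: each file builds a single EasyDict named `config` that selects a backbone (`config.network`), the partial-FC sampling rate (`config.sample_rate`), optimizer hyper-parameters, and the dataset record path (`config.rec`). The short sketch below is illustrative only, not the repository's actual training entry point: it assumes the configs are importable as a package and reuses the `get_model()` factory from `backbones/__init__.py` shown earlier; the helper name `load_config` is introduced here for illustration.

```python
# Minimal sketch (assumptions noted above): import one of the edict configs
# from configs/ and build the matching backbone with get_model().
import importlib

from backbones import get_model  # factory defined in backbones/__init__.py


def load_config(name: str):
    """Import configs/<name>.py and return its `config` EasyDict."""
    module = importlib.import_module(f"configs.{name}")
    return module.config


if __name__ == "__main__":
    cfg = load_config("glint360k_r50")           # any preset in configs/ works
    backbone = get_model(cfg.network,             # e.g. "r50" -> iresnet50
                         fp16=cfg.fp16,
                         num_features=cfg.embedding_size)
    print(cfg.network, cfg.batch_size, cfg.lr)    # training hyper-parameters
```

In these presets, changing the backbone or dataset is a matter of editing the corresponding fields (`config.network`, `config.rec`, `config.num_classes`) rather than touching the training code.
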
/eg3d-pose-detection/models/arcface_torch/configs/glint360k_r50.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict as edict 2 | 3 | # make training faster 4 | # our RAM is 256G 5 | # mount -t tmpfs -o size=140G tmpfs /train_tmp 6 | 7 | config = edict() 8 | config.margin_list = (1.0, 0.0, 0.4) 9 | config.network = "r50" 10 | config.resume = False 11 | config.output = None 12 | config.embedding_size = 512 13 | config.sample_rate = 1.0 14 | config.fp16 = True 15 | config.momentum = 0.9 16 | config.weight_decay = 1e-4 17 | config.batch_size = 128 18 | config.lr = 0.1 19 | config.verbose = 2000 20 | config.dali = False 21 | 22 | config.rec = "/train_tmp/glint360k" 23 | config.num_classes = 360232 24 | config.num_image = 17091657 25 | config.num_epoch = 20 26 | config.warmup_epoch = 0 27 | config.val_targets = ['lfw', 'cfp_fp', "agedb_30"] 28 | -------------------------------------------------------------------------------- /eg3d-pose-detection/models/arcface_torch/configs/ms1mv2_mbf.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict as edict 2 | 3 | # make training faster 4 | # our RAM is 256G 5 | # mount -t tmpfs -o size=140G tmpfs /train_tmp 6 | 7 | config = edict() 8 | config.margin_list = (1.0, 0.5, 0.0) 9 | config.network = "mbf" 10 | config.resume = False 11 | config.output = None 12 | config.embedding_size = 512 13 | config.sample_rate = 1.0 14 | config.fp16 = True 15 | config.momentum = 0.9 16 | config.weight_decay = 1e-4 17 | config.batch_size = 128 18 | config.lr = 0.1 19 | config.verbose = 2000 20 | config.dali = False 21 | 22 | config.rec = "/train_tmp/faces_emore" 23 | config.num_classes = 85742 24 | config.num_image = 5822653 25 | config.num_epoch = 40 26 | config.warmup_epoch = 0 27 | config.val_targets = ['lfw', 'cfp_fp', "agedb_30"] 28 | -------------------------------------------------------------------------------- /eg3d-pose-detection/models/arcface_torch/configs/ms1mv2_r100.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict as edict 2 | 3 | # make training faster 4 | # our RAM is 256G 5 | # mount -t tmpfs -o size=140G tmpfs /train_tmp 6 | 7 | config = edict() 8 | config.margin_list = (1.0, 0.5, 0.0) 9 | config.network = "r100" 10 | config.resume = False 11 | config.output = None 12 | config.embedding_size = 512 13 | config.sample_rate = 1.0 14 | config.fp16 = True 15 | config.momentum = 0.9 16 | config.weight_decay = 5e-4 17 | config.batch_size = 128 18 | config.lr = 0.1 19 | config.verbose = 2000 20 | config.dali = False 21 | 22 | config.rec = "/train_tmp/faces_emore" 23 | config.num_classes = 85742 24 | config.num_image = 5822653 25 | config.num_epoch = 20 26 | config.warmup_epoch = 0 27 | config.val_targets = ['lfw', 'cfp_fp', "agedb_30"] 28 | -------------------------------------------------------------------------------- /eg3d-pose-detection/models/arcface_torch/configs/ms1mv2_r50.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict as edict 2 | 3 | # make training faster 4 | # our RAM is 256G 5 | # mount -t tmpfs -o size=140G tmpfs /train_tmp 6 | 7 | config = edict() 8 | config.margin_list = (1.0, 0.5, 0.0) 9 | config.network = "r50" 10 | config.resume = False 11 | config.output = None 12 | config.embedding_size = 512 13 | config.sample_rate = 1.0 14 | config.fp16 = True 15 | config.momentum = 0.9 16 | 
config.weight_decay = 5e-4 17 | config.batch_size = 128 18 | config.lr = 0.1 19 | config.verbose = 2000 20 | config.dali = False 21 | 22 | config.rec = "/train_tmp/faces_emore" 23 | config.num_classes = 85742 24 | config.num_image = 5822653 25 | config.num_epoch = 20 26 | config.warmup_epoch = 0 27 | config.val_targets = ['lfw', 'cfp_fp', "agedb_30"] 28 | -------------------------------------------------------------------------------- /eg3d-pose-detection/models/arcface_torch/configs/ms1mv3_mbf.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict as edict 2 | 3 | # make training faster 4 | # our RAM is 256G 5 | # mount -t tmpfs -o size=140G tmpfs /train_tmp 6 | 7 | config = edict() 8 | config.margin_list = (1.0, 0.5, 0.0) 9 | config.network = "mbf" 10 | config.resume = False 11 | config.output = None 12 | config.embedding_size = 512 13 | config.sample_rate = 1.0 14 | config.fp16 = True 15 | config.momentum = 0.9 16 | config.weight_decay = 1e-4 17 | config.batch_size = 128 18 | config.lr = 0.1 19 | config.verbose = 2000 20 | config.dali = False 21 | 22 | config.rec = "/train_tmp/ms1m-retinaface-t1" 23 | config.num_classes = 93431 24 | config.num_image = 5179510 25 | config.num_epoch = 40 26 | config.warmup_epoch = 0 27 | config.val_targets = ['lfw', 'cfp_fp', "agedb_30"] 28 | -------------------------------------------------------------------------------- /eg3d-pose-detection/models/arcface_torch/configs/ms1mv3_r100.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict as edict 2 | 3 | # make training faster 4 | # our RAM is 256G 5 | # mount -t tmpfs -o size=140G tmpfs /train_tmp 6 | 7 | config = edict() 8 | config.margin_list = (1.0, 0.5, 0.0) 9 | config.network = "r100" 10 | config.resume = False 11 | config.output = None 12 | config.embedding_size = 512 13 | config.sample_rate = 1.0 14 | config.fp16 = True 15 | config.momentum = 0.9 16 | config.weight_decay = 5e-4 17 | config.batch_size = 128 18 | config.lr = 0.1 19 | config.verbose = 2000 20 | config.dali = False 21 | 22 | config.rec = "/train_tmp/ms1m-retinaface-t1" 23 | config.num_classes = 93431 24 | config.num_image = 5179510 25 | config.num_epoch = 20 26 | config.warmup_epoch = 0 27 | config.val_targets = ['lfw', 'cfp_fp', "agedb_30"] 28 | -------------------------------------------------------------------------------- /eg3d-pose-detection/models/arcface_torch/configs/ms1mv3_r50.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict as edict 2 | 3 | # make training faster 4 | # our RAM is 256G 5 | # mount -t tmpfs -o size=140G tmpfs /train_tmp 6 | 7 | config = edict() 8 | config.margin_list = (1.0, 0.5, 0.0) 9 | config.network = "r50" 10 | config.resume = False 11 | config.output = None 12 | config.embedding_size = 512 13 | config.sample_rate = 1.0 14 | config.fp16 = True 15 | config.momentum = 0.9 16 | config.weight_decay = 5e-4 17 | config.batch_size = 128 18 | config.lr = 0.1 19 | config.verbose = 2000 20 | config.dali = False 21 | 22 | config.rec = "/train_tmp/ms1m-retinaface-t1" 23 | config.num_classes = 93431 24 | config.num_image = 5179510 25 | config.num_epoch = 20 26 | config.warmup_epoch = 0 27 | config.val_targets = ['lfw', 'cfp_fp', "agedb_30"] 28 | -------------------------------------------------------------------------------- /eg3d-pose-detection/models/arcface_torch/configs/wf12m_conflict_r50.py: 
-------------------------------------------------------------------------------- 1 | from easydict import EasyDict as edict 2 | 3 | # make training faster 4 | # our RAM is 256G 5 | # mount -t tmpfs -o size=140G tmpfs /train_tmp 6 | 7 | config = edict() 8 | config.margin_list = (1.0, 0.0, 0.4) 9 | config.network = "r50" 10 | config.resume = False 11 | config.output = None 12 | config.embedding_size = 512 13 | config.sample_rate = 1.0 14 | config.interclass_filtering_threshold = 0 15 | config.fp16 = True 16 | config.weight_decay = 5e-4 17 | config.batch_size = 128 18 | config.optimizer = "sgd" 19 | config.lr = 0.1 20 | config.verbose = 2000 21 | config.dali = False 22 | 23 | config.rec = "/train_tmp/WebFace12M_Conflict" 24 | config.num_classes = 1017970 25 | config.num_image = 12720066 26 | config.num_epoch = 20 27 | config.warmup_epoch = config.num_epoch // 10 28 | config.val_targets = [] 29 | -------------------------------------------------------------------------------- /eg3d-pose-detection/models/arcface_torch/configs/wf12m_conflict_r50_pfc03_filter04.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict as edict 2 | 3 | # make training faster 4 | # our RAM is 256G 5 | # mount -t tmpfs -o size=140G tmpfs /train_tmp 6 | 7 | config = edict() 8 | config.margin_list = (1.0, 0.0, 0.4) 9 | config.network = "r50" 10 | config.resume = False 11 | config.output = None 12 | config.embedding_size = 512 13 | config.sample_rate = 0.3 14 | config.interclass_filtering_threshold = 0.4 15 | config.fp16 = True 16 | config.weight_decay = 5e-4 17 | config.batch_size = 128 18 | config.optimizer = "sgd" 19 | config.lr = 0.1 20 | config.verbose = 2000 21 | config.dali = False 22 | 23 | config.rec = "/train_tmp/WebFace12M_Conflict" 24 | config.num_classes = 1017970 25 | config.num_image = 12720066 26 | config.num_epoch = 20 27 | config.warmup_epoch = config.num_epoch // 10 28 | config.val_targets = [] 29 | -------------------------------------------------------------------------------- /eg3d-pose-detection/models/arcface_torch/configs/wf12m_flip_pfc01_filter04_r50.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict as edict 2 | 3 | # make training faster 4 | # our RAM is 256G 5 | # mount -t tmpfs -o size=140G tmpfs /train_tmp 6 | 7 | config = edict() 8 | config.margin_list = (1.0, 0.0, 0.4) 9 | config.network = "r50" 10 | config.resume = False 11 | config.output = None 12 | config.embedding_size = 512 13 | config.sample_rate = 0.1 14 | config.interclass_filtering_threshold = 0.4 15 | config.fp16 = True 16 | config.weight_decay = 5e-4 17 | config.batch_size = 128 18 | config.optimizer = "sgd" 19 | config.lr = 0.1 20 | config.verbose = 2000 21 | config.dali = False 22 | 23 | config.rec = "/train_tmp/WebFace12M_FLIP40" 24 | config.num_classes = 617970 25 | config.num_image = 12720066 26 | config.num_epoch = 20 27 | config.warmup_epoch = config.num_epoch // 10 28 | config.val_targets = [] 29 | -------------------------------------------------------------------------------- /eg3d-pose-detection/models/arcface_torch/configs/wf12m_flip_r50.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict as edict 2 | 3 | # make training faster 4 | # our RAM is 256G 5 | # mount -t tmpfs -o size=140G tmpfs /train_tmp 6 | 7 | config = edict() 8 | config.margin_list = (1.0, 0.0, 0.4) 9 | config.network = "r50" 10 | config.resume = False 11 
| config.output = None 12 | config.embedding_size = 512 13 | config.sample_rate = 1.0 14 | config.interclass_filtering_threshold = 0 15 | config.fp16 = True 16 | config.weight_decay = 5e-4 17 | config.batch_size = 128 18 | config.optimizer = "sgd" 19 | config.lr = 0.1 20 | config.verbose = 2000 21 | config.dali = False 22 | 23 | config.rec = "/train_tmp/WebFace12M_FLIP40" 24 | config.num_classes = 617970 25 | config.num_image = 12720066 26 | config.num_epoch = 20 27 | config.warmup_epoch = config.num_epoch // 10 28 | config.val_targets = [] 29 | -------------------------------------------------------------------------------- /eg3d-pose-detection/models/arcface_torch/configs/wf12m_mbf.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict as edict 2 | 3 | # make training faster 4 | # our RAM is 256G 5 | # mount -t tmpfs -o size=140G tmpfs /train_tmp 6 | 7 | config = edict() 8 | config.margin_list = (1.0, 0.0, 0.4) 9 | config.network = "mbf" 10 | config.resume = False 11 | config.output = None 12 | config.embedding_size = 512 13 | config.sample_rate = 1.0 14 | config.interclass_filtering_threshold = 0 15 | config.fp16 = True 16 | config.weight_decay = 1e-4 17 | config.batch_size = 128 18 | config.optimizer = "sgd" 19 | config.lr = 0.1 20 | config.verbose = 2000 21 | config.dali = False 22 | 23 | config.rec = "/train_tmp/WebFace12M" 24 | config.num_classes = 617970 25 | config.num_image = 12720066 26 | config.num_epoch = 20 27 | config.warmup_epoch = 0 28 | config.val_targets = [] 29 | -------------------------------------------------------------------------------- /eg3d-pose-detection/models/arcface_torch/configs/wf12m_pfc02_r100.py: -------------------------------------------------------------------------------- 1 | 2 | from easydict import EasyDict as edict 3 | 4 | # make training faster 5 | # our RAM is 256G 6 | # mount -t tmpfs -o size=140G tmpfs /train_tmp 7 | 8 | config = edict() 9 | config.margin_list = (1.0, 0.0, 0.4) 10 | config.network = "r100" 11 | config.resume = False 12 | config.output = None 13 | config.embedding_size = 512 14 | config.sample_rate = 0.2 15 | config.interclass_filtering_threshold = 0 16 | config.fp16 = True 17 | config.weight_decay = 5e-4 18 | config.batch_size = 128 19 | config.optimizer = "sgd" 20 | config.lr = 0.1 21 | config.verbose = 2000 22 | config.dali = False 23 | 24 | config.rec = "/train_tmp/WebFace12M" 25 | config.num_classes = 617970 26 | config.num_image = 12720066 27 | config.num_epoch = 20 28 | config.warmup_epoch = 0 29 | config.val_targets = [] 30 | -------------------------------------------------------------------------------- /eg3d-pose-detection/models/arcface_torch/configs/wf12m_r100.py: -------------------------------------------------------------------------------- 1 | 2 | from easydict import EasyDict as edict 3 | 4 | # make training faster 5 | # our RAM is 256G 6 | # mount -t tmpfs -o size=140G tmpfs /train_tmp 7 | 8 | config = edict() 9 | config.margin_list = (1.0, 0.0, 0.4) 10 | config.network = "r100" 11 | config.resume = False 12 | config.output = None 13 | config.embedding_size = 512 14 | config.sample_rate = 1.0 15 | config.interclass_filtering_threshold = 0 16 | config.fp16 = True 17 | config.weight_decay = 5e-4 18 | config.batch_size = 128 19 | config.optimizer = "sgd" 20 | config.lr = 0.1 21 | config.verbose = 2000 22 | config.dali = False 23 | 24 | config.rec = "/train_tmp/WebFace12M" 25 | config.num_classes = 617970 26 | config.num_image = 12720066 27 | 
config.num_epoch = 20 28 | config.warmup_epoch = 0 29 | config.val_targets = [] 30 | -------------------------------------------------------------------------------- /eg3d-pose-detection/models/arcface_torch/configs/wf12m_r50.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict as edict 2 | 3 | # make training faster 4 | # our RAM is 256G 5 | # mount -t tmpfs -o size=140G tmpfs /train_tmp 6 | 7 | config = edict() 8 | config.margin_list = (1.0, 0.0, 0.4) 9 | config.network = "r50" 10 | config.resume = False 11 | config.output = None 12 | config.embedding_size = 512 13 | config.sample_rate = 1.0 14 | config.interclass_filtering_threshold = 0 15 | config.fp16 = True 16 | config.weight_decay = 5e-4 17 | config.batch_size = 128 18 | config.optimizer = "sgd" 19 | config.lr = 0.1 20 | config.verbose = 2000 21 | config.dali = False 22 | 23 | config.rec = "/train_tmp/WebFace12M" 24 | config.num_classes = 617970 25 | config.num_image = 12720066 26 | config.num_epoch = 20 27 | config.warmup_epoch = 0 28 | config.val_targets = [] 29 | -------------------------------------------------------------------------------- /eg3d-pose-detection/models/arcface_torch/configs/wf42m_pfc0008_32gpu_r100.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict as edict 2 | 3 | # make training faster 4 | # our RAM is 256G 5 | # mount -t tmpfs -o size=140G tmpfs /train_tmp 6 | 7 | config = edict() 8 | config.margin_list = (1.0, 0.0, 0.4) 9 | config.network = "r100" 10 | config.resume = False 11 | config.output = None 12 | config.embedding_size = 512 13 | config.sample_rate = 0 14 | config.fp16 = True 15 | config.momentum = 0.9 16 | config.weight_decay = 5e-4 17 | config.batch_size = 512 18 | config.lr = 0.4 19 | config.verbose = 2000 20 | config.dali = False 21 | 22 | config.rec = "/train_tmp/WebFace42M" 23 | config.num_classes = 2059906 24 | config.num_image = 42474557 25 | config.num_epoch = 20 26 | config.warmup_epoch = config.num_epoch // 10 27 | config.val_targets = ["lfw", "cfp_fp", "agedb_30"] 28 | -------------------------------------------------------------------------------- /eg3d-pose-detection/models/arcface_torch/configs/wf42m_pfc02_16gpus_mbf_bs8k.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict as edict 2 | 3 | # make training faster 4 | # our RAM is 256G 5 | # mount -t tmpfs -o size=140G tmpfs /train_tmp 6 | 7 | config = edict() 8 | config.margin_list = (1.0, 0.0, 0.4) 9 | config.network = "mbf" 10 | config.resume = False 11 | config.output = None 12 | config.embedding_size = 512 13 | config.sample_rate = 0.2 14 | config.fp16 = True 15 | config.momentum = 0.9 16 | config.weight_decay = 1e-4 17 | config.batch_size = 512 18 | config.lr = 0.4 19 | config.verbose = 10000 20 | config.dali = False 21 | 22 | config.rec = "/train_tmp/WebFace42M" 23 | config.num_classes = 2059906 24 | config.num_image = 42474557 25 | config.num_epoch = 20 26 | config.warmup_epoch = 2 27 | config.val_targets = [] 28 | -------------------------------------------------------------------------------- /eg3d-pose-detection/models/arcface_torch/configs/wf42m_pfc02_16gpus_r100.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict as edict 2 | 3 | # make training faster 4 | # our RAM is 256G 5 | # mount -t tmpfs -o size=140G tmpfs /train_tmp 6 | 7 | config = edict() 8 | 
config.margin_list = (1.0, 0.0, 0.4) 9 | config.network = "r100" 10 | config.resume = False 11 | config.output = None 12 | config.embedding_size = 512 13 | config.sample_rate = 0.2 14 | config.fp16 = True 15 | config.momentum = 0.9 16 | config.weight_decay = 5e-4 17 | config.batch_size = 256 18 | config.lr = 0.3 19 | config.verbose = 2000 20 | config.dali = False 21 | 22 | config.rec = "/train_tmp/WebFace42M" 23 | config.num_classes = 2059906 24 | config.num_image = 42474557 25 | config.num_epoch = 20 26 | config.warmup_epoch = 1 27 | config.val_targets = ["lfw", "cfp_fp", "agedb_30"] 28 | -------------------------------------------------------------------------------- /eg3d-pose-detection/models/arcface_torch/configs/wf42m_pfc02_16gpus_r50_bs8k.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict as edict 2 | 3 | # make training faster 4 | # our RAM is 256G 5 | # mount -t tmpfs -o size=140G tmpfs /train_tmp 6 | 7 | config = edict() 8 | config.margin_list = (1.0, 0.0, 0.4) 9 | config.network = "r50" 10 | config.resume = False 11 | config.output = None 12 | config.embedding_size = 512 13 | config.sample_rate = 0.2 14 | config.fp16 = True 15 | config.momentum = 0.9 16 | config.weight_decay = 5e-4 17 | config.batch_size = 512 18 | config.lr = 0.6 19 | config.verbose = 10000 20 | config.dali = False 21 | 22 | config.rec = "/train_tmp/WebFace42M" 23 | config.num_classes = 2059906 24 | config.num_image = 42474557 25 | config.num_epoch = 20 26 | config.warmup_epoch = 4 27 | config.val_targets = ["lfw", "cfp_fp", "agedb_30"] 28 | -------------------------------------------------------------------------------- /eg3d-pose-detection/models/arcface_torch/configs/wf42m_pfc02_32gpus_r50_bs4k.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict as edict 2 | 3 | # make training faster 4 | # our RAM is 256G 5 | # mount -t tmpfs -o size=140G tmpfs /train_tmp 6 | 7 | config = edict() 8 | config.margin_list = (1.0, 0.0, 0.4) 9 | config.network = "r50" 10 | config.resume = False 11 | config.output = None 12 | config.embedding_size = 512 13 | config.sample_rate = 0.2 14 | config.fp16 = True 15 | config.momentum = 0.9 16 | config.weight_decay = 5e-4 17 | config.batch_size = 128 18 | config.lr = 0.4 19 | config.verbose = 10000 20 | config.dali = False 21 | 22 | config.rec = "/train_tmp/WebFace42M" 23 | config.num_classes = 2059906 24 | config.num_image = 42474557 25 | config.num_epoch = 20 26 | config.warmup_epoch = 2 27 | config.val_targets = ["lfw", "cfp_fp", "agedb_30"] 28 | -------------------------------------------------------------------------------- /eg3d-pose-detection/models/arcface_torch/configs/wf42m_pfc02_8gpus_r50_bs4k.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict as edict 2 | 3 | # make training faster 4 | # our RAM is 256G 5 | # mount -t tmpfs -o size=140G tmpfs /train_tmp 6 | 7 | config = edict() 8 | config.margin_list = (1.0, 0.0, 0.4) 9 | config.network = "r50" 10 | config.resume = False 11 | config.output = None 12 | config.embedding_size = 512 13 | config.sample_rate = 0.2 14 | config.fp16 = True 15 | config.momentum = 0.9 16 | config.weight_decay = 5e-4 17 | config.batch_size = 512 18 | config.lr = 0.4 19 | config.verbose = 10000 20 | config.dali = False 21 | 22 | config.rec = "/train_tmp/WebFace42M" 23 | config.num_classes = 2059906 24 | config.num_image = 42474557 25 | config.num_epoch = 
20 26 | config.warmup_epoch = 2 27 | config.val_targets = ["lfw", "cfp_fp", "agedb_30"] 28 | -------------------------------------------------------------------------------- /eg3d-pose-detection/models/arcface_torch/configs/wf42m_pfc02_r100.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict as edict 2 | 3 | # make training faster 4 | # our RAM is 256G 5 | # mount -t tmpfs -o size=140G tmpfs /train_tmp 6 | 7 | config = edict() 8 | config.margin_list = (1.0, 0.0, 0.4) 9 | config.network = "r100" 10 | config.resume = False 11 | config.output = None 12 | config.embedding_size = 512 13 | config.sample_rate = 0.2 14 | config.fp16 = True 15 | config.momentum = 0.9 16 | config.weight_decay = 5e-4 17 | config.batch_size = 128 18 | config.lr = 0.1 19 | config.verbose = 10000 20 | config.dali = False 21 | 22 | config.rec = "/train_tmp/WebFace42M" 23 | config.num_classes = 2059906 24 | config.num_image = 42474557 25 | config.num_epoch = 20 26 | config.warmup_epoch = 0 27 | config.val_targets = ['lfw', 'cfp_fp', "agedb_30"] 28 | -------------------------------------------------------------------------------- /eg3d-pose-detection/models/arcface_torch/configs/wf42m_pfc02_r100_16gpus.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict as edict 2 | 3 | # make training faster 4 | # our RAM is 256G 5 | # mount -t tmpfs -o size=140G tmpfs /train_tmp 6 | 7 | config = edict() 8 | config.margin_list = (1.0, 0.0, 0.4) 9 | config.network = "r100" 10 | config.resume = False 11 | config.output = None 12 | config.embedding_size = 512 13 | config.sample_rate = 0.2 14 | config.fp16 = True 15 | config.momentum = 0.9 16 | config.weight_decay = 5e-4 17 | config.batch_size = 128 18 | config.lr = 0.2 19 | config.verbose = 10000 20 | config.dali = False 21 | 22 | config.rec = "/train_tmp/WebFace42M" 23 | config.num_classes = 2059906 24 | config.num_image = 42474557 25 | config.num_epoch = 20 26 | config.warmup_epoch = config.num_epoch // 10 27 | config.val_targets = ['lfw', 'cfp_fp', "agedb_30"] 28 | -------------------------------------------------------------------------------- /eg3d-pose-detection/models/arcface_torch/configs/wf42m_pfc02_r100_32gpus.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict as edict 2 | 3 | # make training faster 4 | # our RAM is 256G 5 | # mount -t tmpfs -o size=140G tmpfs /train_tmp 6 | 7 | config = edict() 8 | config.margin_list = (1.0, 0.0, 0.4) 9 | config.network = "r100" 10 | config.resume = False 11 | config.output = None 12 | config.embedding_size = 512 13 | config.sample_rate = 0.2 14 | config.fp16 = True 15 | config.momentum = 0.9 16 | config.weight_decay = 5e-4 17 | config.batch_size = 128 18 | config.lr = 0.4 19 | config.verbose = 10000 20 | config.dali = False 21 | 22 | config.rec = "/train_tmp/WebFace42M" 23 | config.num_classes = 2059906 24 | config.num_image = 42474557 25 | config.num_epoch = 20 26 | config.warmup_epoch = config.num_epoch // 10 27 | config.val_targets = ['lfw', 'cfp_fp', "agedb_30"] 28 | -------------------------------------------------------------------------------- /eg3d-pose-detection/models/arcface_torch/configs/wf42m_pfc03_32gpu_r100.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict as edict 2 | 3 | # make training faster 4 | # our RAM is 256G 5 | # mount -t tmpfs -o size=140G tmpfs 
/train_tmp 6 | 7 | config = edict() 8 | config.margin_list = (1.0, 0.0, 0.4) 9 | config.network = "r100" 10 | config.resume = False 11 | config.output = None 12 | config.embedding_size = 512 13 | config.sample_rate = 0.3 14 | config.fp16 = True 15 | config.momentum = 0.9 16 | config.weight_decay = 5e-4 17 | config.batch_size = 128 18 | config.lr = 0.4 19 | config.verbose = 2000 20 | config.dali = False 21 | 22 | config.rec = "/train_tmp/WebFace42M" 23 | config.num_classes = 2059906 24 | config.num_image = 42474557 25 | config.num_epoch = 20 26 | config.warmup_epoch = config.num_epoch // 10 27 | config.val_targets = ["lfw", "cfp_fp", "agedb_30"] 28 | -------------------------------------------------------------------------------- /eg3d-pose-detection/models/arcface_torch/configs/wf42m_pfc03_32gpu_r18.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict as edict 2 | 3 | # make training faster 4 | # our RAM is 256G 5 | # mount -t tmpfs -o size=140G tmpfs /train_tmp 6 | 7 | config = edict() 8 | config.margin_list = (1.0, 0.0, 0.4) 9 | config.network = "r18" 10 | config.resume = False 11 | config.output = None 12 | config.embedding_size = 512 13 | config.sample_rate = 0.3 14 | config.fp16 = True 15 | config.momentum = 0.9 16 | config.weight_decay = 5e-4 17 | config.batch_size = 128 18 | config.lr = 0.4 19 | config.verbose = 2000 20 | config.dali = False 21 | 22 | config.rec = "/train_tmp/WebFace42M" 23 | config.num_classes = 2059906 24 | config.num_image = 42474557 25 | config.num_epoch = 20 26 | config.warmup_epoch = config.num_epoch // 10 27 | config.val_targets = ["lfw", "cfp_fp", "agedb_30"] 28 | -------------------------------------------------------------------------------- /eg3d-pose-detection/models/arcface_torch/configs/wf42m_pfc03_32gpu_r200.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict as edict 2 | 3 | # make training faster 4 | # our RAM is 256G 5 | # mount -t tmpfs -o size=140G tmpfs /train_tmp 6 | 7 | config = edict() 8 | config.margin_list = (1.0, 0.0, 0.4) 9 | config.network = "r200" 10 | config.resume = False 11 | config.output = None 12 | config.embedding_size = 512 13 | config.sample_rate = 0.3 14 | config.fp16 = True 15 | config.momentum = 0.9 16 | config.weight_decay = 5e-4 17 | config.batch_size = 128 18 | config.lr = 0.4 19 | config.verbose = 2000 20 | config.dali = False 21 | 22 | config.rec = "/train_tmp/WebFace42M" 23 | config.num_classes = 2059906 24 | config.num_image = 42474557 25 | config.num_epoch = 20 26 | config.warmup_epoch = config.num_epoch // 10 27 | config.val_targets = ["lfw", "cfp_fp", "agedb_30"] 28 | -------------------------------------------------------------------------------- /eg3d-pose-detection/models/arcface_torch/configs/wf42m_pfc03_32gpu_r50.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict as edict 2 | 3 | # make training faster 4 | # our RAM is 256G 5 | # mount -t tmpfs -o size=140G tmpfs /train_tmp 6 | 7 | config = edict() 8 | config.margin_list = (1.0, 0.0, 0.4) 9 | config.network = "r50" 10 | config.resume = False 11 | config.output = None 12 | config.embedding_size = 512 13 | config.sample_rate = 0.3 14 | config.fp16 = True 15 | config.momentum = 0.9 16 | config.weight_decay = 5e-4 17 | config.batch_size = 128 18 | config.lr = 0.4 19 | config.verbose = 2000 20 | config.dali = False 21 | 22 | config.rec = "/train_tmp/WebFace42M" 23 | 
config.num_classes = 2059906 24 | config.num_image = 42474557 25 | config.num_epoch = 20 26 | config.warmup_epoch = config.num_epoch // 10 27 | config.val_targets = ["lfw", "cfp_fp", "agedb_30"] 28 | -------------------------------------------------------------------------------- /eg3d-pose-detection/models/arcface_torch/configs/wf42m_pfc03_40epoch_64gpu_vit_b.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict as edict 2 | 3 | # make training faster 4 | # our RAM is 256G 5 | # mount -t tmpfs -o size=140G tmpfs /train_tmp 6 | 7 | config = edict() 8 | config.margin_list = (1.0, 0.0, 0.4) 9 | config.network = "vit_b_dp005_mask_005" 10 | config.resume = False 11 | config.output = None 12 | config.embedding_size = 512 13 | config.sample_rate = 0.3 14 | config.fp16 = True 15 | config.weight_decay = 0.1 16 | config.batch_size = 384 17 | config.optimizer = "adamw" 18 | config.lr = 0.001 19 | config.verbose = 2000 20 | config.dali = False 21 | 22 | config.rec = "/train_tmp/WebFace42M" 23 | config.num_classes = 2059906 24 | config.num_image = 42474557 25 | config.num_epoch = 40 26 | config.warmup_epoch = config.num_epoch // 10 27 | config.val_targets = [] 28 | -------------------------------------------------------------------------------- /eg3d-pose-detection/models/arcface_torch/configs/wf42m_pfc03_40epoch_64gpu_vit_l.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict as edict 2 | 3 | # make training faster 4 | # our RAM is 256G 5 | # mount -t tmpfs -o size=140G tmpfs /train_tmp 6 | 7 | config = edict() 8 | config.margin_list = (1.0, 0.0, 0.4) 9 | config.network = "vit_l_dp005_mask_005" 10 | config.resume = False 11 | config.output = None 12 | config.embedding_size = 512 13 | config.sample_rate = 0.3 14 | config.fp16 = True 15 | config.weight_decay = 0.1 16 | config.batch_size = 384 17 | config.optimizer = "adamw" 18 | config.lr = 0.001 19 | config.verbose = 2000 20 | config.dali = False 21 | 22 | config.rec = "/train_tmp/WebFace42M" 23 | config.num_classes = 2059906 24 | config.num_image = 42474557 25 | config.num_epoch = 40 26 | config.warmup_epoch = config.num_epoch // 10 27 | config.val_targets = [] 28 | -------------------------------------------------------------------------------- /eg3d-pose-detection/models/arcface_torch/configs/wf42m_pfc03_40epoch_64gpu_vit_s.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict as edict 2 | 3 | # make training faster 4 | # our RAM is 256G 5 | # mount -t tmpfs -o size=140G tmpfs /train_tmp 6 | 7 | config = edict() 8 | config.margin_list = (1.0, 0.0, 0.4) 9 | config.network = "vit_s_dp005_mask_0" 10 | config.resume = False 11 | config.output = None 12 | config.embedding_size = 512 13 | config.sample_rate = 0.3 14 | config.fp16 = True 15 | config.weight_decay = 0.1 16 | config.batch_size = 384 17 | config.optimizer = "adamw" 18 | config.lr = 0.001 19 | config.verbose = 2000 20 | config.dali = False 21 | 22 | config.rec = "/train_tmp/WebFace42M" 23 | config.num_classes = 2059906 24 | config.num_image = 42474557 25 | config.num_epoch = 40 26 | config.warmup_epoch = config.num_epoch // 10 27 | config.val_targets = [] 28 | -------------------------------------------------------------------------------- /eg3d-pose-detection/models/arcface_torch/configs/wf42m_pfc03_40epoch_64gpu_vit_t.py: 
-------------------------------------------------------------------------------- 1 | from easydict import EasyDict as edict 2 | 3 | # make training faster 4 | # our RAM is 256G 5 | # mount -t tmpfs -o size=140G tmpfs /train_tmp 6 | 7 | config = edict() 8 | config.margin_list = (1.0, 0.0, 0.4) 9 | config.network = "vit_t_dp005_mask0" 10 | config.resume = False 11 | config.output = None 12 | config.embedding_size = 512 13 | config.sample_rate = 0.3 14 | config.fp16 = True 15 | config.weight_decay = 0.1 16 | config.batch_size = 384 17 | config.optimizer = "adamw" 18 | config.lr = 0.001 19 | config.verbose = 2000 20 | config.dali = False 21 | 22 | config.rec = "/train_tmp/WebFace42M" 23 | config.num_classes = 2059906 24 | config.num_image = 42474557 25 | config.num_epoch = 40 26 | config.warmup_epoch = config.num_epoch // 10 27 | config.val_targets = [] 28 | -------------------------------------------------------------------------------- /eg3d-pose-detection/models/arcface_torch/configs/wf42m_pfc03_40epoch_8gpu_vit_t.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict as edict 2 | 3 | # make training faster 4 | # our RAM is 256G 5 | # mount -t tmpfs -o size=140G tmpfs /train_tmp 6 | 7 | config = edict() 8 | config.margin_list = (1.0, 0.0, 0.4) 9 | config.network = "vit_t_dp005_mask0" 10 | config.resume = False 11 | config.output = None 12 | config.embedding_size = 512 13 | config.sample_rate = 0.3 14 | config.fp16 = True 15 | config.weight_decay = 0.1 16 | config.batch_size = 512 17 | config.optimizer = "adamw" 18 | config.lr = 0.001 19 | config.verbose = 2000 20 | config.dali = False 21 | 22 | config.rec = "/train_tmp/WebFace42M" 23 | config.num_classes = 2059906 24 | config.num_image = 42474557 25 | config.num_epoch = 40 26 | config.warmup_epoch = config.num_epoch // 10 27 | config.val_targets = [] 28 | -------------------------------------------------------------------------------- /eg3d-pose-detection/models/arcface_torch/configs/wf4m_mbf.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict as edict 2 | 3 | # make training faster 4 | # our RAM is 256G 5 | # mount -t tmpfs -o size=140G tmpfs /train_tmp 6 | 7 | config = edict() 8 | config.margin_list = (1.0, 0.0, 0.4) 9 | config.network = "mbf" 10 | config.resume = False 11 | config.output = None 12 | config.embedding_size = 512 13 | config.sample_rate = 1.0 14 | config.fp16 = True 15 | config.momentum = 0.9 16 | config.weight_decay = 1e-4 17 | config.batch_size = 128 18 | config.lr = 0.1 19 | config.verbose = 2000 20 | config.dali = False 21 | 22 | config.rec = "/train_tmp/WebFace4M" 23 | config.num_classes = 205990 24 | config.num_image = 4235242 25 | config.num_epoch = 20 26 | config.warmup_epoch = 0 27 | config.val_targets = ['lfw', 'cfp_fp', "agedb_30"] 28 | -------------------------------------------------------------------------------- /eg3d-pose-detection/models/arcface_torch/configs/wf4m_r100.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict as edict 2 | 3 | # make training faster 4 | # our RAM is 256G 5 | # mount -t tmpfs -o size=140G tmpfs /train_tmp 6 | 7 | config = edict() 8 | config.margin_list = (1.0, 0.0, 0.4) 9 | config.network = "r100" 10 | config.resume = False 11 | config.output = None 12 | config.embedding_size = 512 13 | config.sample_rate = 1.0 14 | config.fp16 = True 15 | config.momentum = 0.9 16 | config.weight_decay = 
5e-4 17 | config.batch_size = 128 18 | config.lr = 0.1 19 | config.verbose = 2000 20 | config.dali = False 21 | 22 | config.rec = "/train_tmp/WebFace4M" 23 | config.num_classes = 205990 24 | config.num_image = 4235242 25 | config.num_epoch = 20 26 | config.warmup_epoch = 0 27 | config.val_targets = ['lfw', 'cfp_fp', "agedb_30"] 28 | -------------------------------------------------------------------------------- /eg3d-pose-detection/models/arcface_torch/configs/wf4m_r50.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict as edict 2 | 3 | # make training faster 4 | # our RAM is 256G 5 | # mount -t tmpfs -o size=140G tmpfs /train_tmp 6 | 7 | config = edict() 8 | config.margin_list = (1.0, 0.0, 0.4) 9 | config.network = "r50" 10 | config.resume = False 11 | config.output = None 12 | config.embedding_size = 512 13 | config.sample_rate = 1.0 14 | config.fp16 = True 15 | config.momentum = 0.9 16 | config.weight_decay = 5e-4 17 | config.batch_size = 128 18 | config.lr = 0.1 19 | config.verbose = 2000 20 | config.dali = False 21 | 22 | config.rec = "/train_tmp/WebFace4M" 23 | config.num_classes = 205990 24 | config.num_image = 4235242 25 | config.num_epoch = 20 26 | config.warmup_epoch = 0 27 | config.val_targets = ['lfw', 'cfp_fp', "agedb_30"] 28 | -------------------------------------------------------------------------------- /eg3d-pose-detection/models/arcface_torch/dist.sh: -------------------------------------------------------------------------------- 1 | ip_list=("ip1" "ip2" "ip3" "ip4") 2 | 3 | config=wf42m_pfc03_32gpu_r100 4 | 5 | for((node_rank=0;node_rank<${#ip_list[*]};node_rank++)); 6 | do 7 | ssh face@${ip_list[node_rank]} "cd `pwd`;PATH=$PATH \ 8 | CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \ 9 | python -m torch.distributed.launch \ 10 | --nproc_per_node=8 \ 11 | --nnodes=${#ip_list[*]} \ 12 | --node_rank=$node_rank \ 13 | --master_addr=${ip_list[0]} \ 14 | --master_port=22345 train.py configs/$config" & 15 | done 16 | -------------------------------------------------------------------------------- /eg3d-pose-detection/models/arcface_torch/docs/eval.md: -------------------------------------------------------------------------------- 1 | ## Eval on ICCV2021-MFR 2 | 3 | coming soon. 4 | 5 | 6 | ## Eval IJBC 7 | You can eval ijbc with pytorch or onnx. 8 | 9 | 10 | 1. Eval IJBC With Onnx 11 | ```shell 12 | CUDA_VISIBLE_DEVICES=0 python onnx_ijbc.py --model-root ms1mv3_arcface_r50 --image-path IJB_release/IJBC --result-dir ms1mv3_arcface_r50 13 | ``` 14 | 15 | 2. 
Eval IJBC With PyTorch 16 | ```shell 17 | CUDA_VISIBLE_DEVICES=0,1 python eval_ijbc.py \ 18 | --model-prefix ms1mv3_arcface_r50/backbone.pth \ 19 | --image-path IJB_release/IJBC \ 20 | --result-dir ms1mv3_arcface_r50 \ 21 | --batch-size 128 \ 22 | --job ms1mv3_arcface_r50 \ 23 | --target IJBC \ 24 | --network iresnet50 25 | ``` 26 | 27 | 28 | ## Inference 29 | 30 | ```shell 31 | python inference.py --weight ms1mv3_arcface_r50/backbone.pth --network r50 32 | ``` 33 | 34 | 35 | ## Result 36 | 37 | | Datasets | Backbone | **MFR-ALL** | IJB-C(1E-4) | IJB-C(1E-5) | 38 | |:---------------|:--------------------|:------------|:------------|:------------| 39 | | WF12M-PFC-0.05 | r100 | 94.05 | 97.51 | 95.75 | 40 | | WF12M-PFC-0.1 | r100 | 94.49 | 97.56 | 95.92 | 41 | | WF12M-PFC-0.2 | r100 | 94.75 | 97.60 | 95.90 | 42 | | WF12M-PFC-0.3 | r100 | 94.71 | 97.64 | 96.01 | 43 | | WF12M | r100 | 94.69 | 97.59 | 95.97 | -------------------------------------------------------------------------------- /eg3d-pose-detection/models/arcface_torch/docs/install.md: -------------------------------------------------------------------------------- 1 | ## [v1.11.0](https://pytorch.org/) 2 | 3 | ## [v1.9.0](https://pytorch.org/get-started/previous-versions/#linux-and-windows-7) 4 | ### Linux and Windows 5 | ```shell 6 | # CUDA 11.1 7 | pip install torch==1.9.0+cu111 torchvision==0.10.0+cu111 torchaudio==0.9.0 -f https://download.pytorch.org/whl/torch_stable.html 8 | 9 | # CUDA 10.2 10 | pip install torch==1.9.0+cu102 torchvision==0.10.0+cu102 torchaudio==0.9.0 -f https://download.pytorch.org/whl/torch_stable.html 11 | ``` 12 | -------------------------------------------------------------------------------- /eg3d-pose-detection/models/arcface_torch/docs/install_dali.md: -------------------------------------------------------------------------------- 1 | TODO 2 | -------------------------------------------------------------------------------- /eg3d-pose-detection/models/arcface_torch/docs/modelzoo.md: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bbaaii/HFA-GP/aa2c15a61d8ddd182189153914098a2af0edfb0c/eg3d-pose-detection/models/arcface_torch/docs/modelzoo.md -------------------------------------------------------------------------------- /eg3d-pose-detection/models/arcface_torch/docs/prepare_webface42m.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | ## 1. Download Datasets and Unzip 5 | 6 | Download WebFace42M from [https://www.face-benchmark.org/download.html](https://www.face-benchmark.org/download.html). 7 | The raw data of `WebFace42M` will have 10 directories after being unarchived: 8 | `WebFace4M` contains 1 directory: `0`. 9 | `WebFace12M` contains 3 directories: `0,1,2`. 10 | `WebFace42M` contains 10 directories: `0,1,2,3,4,5,6,7,8,9`. 11 | 12 | ## 2. Create Shuffled Rec File for DALI 13 | 14 | Note: a shuffled rec file is very important for DALI, and an unshuffled rec can cause performance degradation. The original insightface-style rec file 15 | does not support NVIDIA DALI, so you must use [mxnet.tools.im2rec](https://github.com/apache/incubator-mxnet/blob/master/tools/im2rec.py) to generate a shuffled rec file.
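After generating the rec with the `im2rec` commands below, the resulting `train.idx` / `train.rec` pair can be read back through mxnet's recordio API as a quick sanity check. The following snippet is an editorial sketch rather than part of the original doc; it assumes mxnet is installed and that record index `0` exists in the generated file:

```python
import mxnet as mx

# Open the indexed record file produced by mxnet.tools.im2rec (read-only).
record = mx.recordio.MXIndexedRecordIO('train.idx', 'train.rec', 'r')

# Read one packed record by index and split it into its header (label) and raw JPEG bytes.
header, img_bytes = mx.recordio.unpack(record.read_idx(0))

# Decode the JPEG bytes into an HWC uint8 NDArray and inspect it.
img = mx.image.imdecode(img_bytes)
print('label:', header.label, 'image shape:', img.shape)
```

If the label and image shape print without errors, the shuffled rec is readable and ready for training.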
16 | 17 | ```shell 18 | # directory and file layout for your dataset 19 | /WebFace42M_Root 20 | ├── 0_0_0000000 21 | │   ├── 0_0.jpg 22 | │   ├── 0_1.jpg 23 | │   ├── 0_2.jpg 24 | │   ├── 0_3.jpg 25 | │   └── 0_4.jpg 26 | ├── 0_0_0000001 27 | │   ├── 0_5.jpg 28 | │   ├── 0_6.jpg 29 | │   ├── 0_7.jpg 30 | │   ├── 0_8.jpg 31 | │   └── 0_9.jpg 32 | ├── 0_0_0000002 33 | │   ├── 0_10.jpg 34 | │   ├── 0_11.jpg 35 | │   ├── 0_12.jpg 36 | │   ├── 0_13.jpg 37 | │   ├── 0_14.jpg 38 | │   ├── 0_15.jpg 39 | │   ├── 0_16.jpg 40 | │   └── 0_17.jpg 41 | ├── 0_0_0000003 42 | │   ├── 0_18.jpg 43 | │   ├── 0_19.jpg 44 | │   └── 0_20.jpg 45 | ├── 0_0_0000004 46 | 47 | 48 | 49 | # 1) create train.lst using the following command 50 | python -m mxnet.tools.im2rec --list --recursive train WebFace42M_Root 51 | 52 | # 2) create train.rec and train.idx from train.lst using the following command 53 | python -m mxnet.tools.im2rec --num-thread 16 --quality 100 train WebFace42M_Root 54 | ``` 55 | 56 | Finally, you will get three files: `train.lst`, `train.rec`, and `train.idx`, of which `train.idx` and `train.rec` are used for training. 57 | -------------------------------------------------------------------------------- /eg3d-pose-detection/models/arcface_torch/docs/speed_benchmark.md: -------------------------------------------------------------------------------- 1 | ## Test Training Speed 2 | 3 | - Test Commands 4 | 5 | You need to use the following two commands to test the Partial FC training performance. 6 | The number of identities is **3 million** (synthetic data), mixed precision training is turned on, the backbone is resnet50, 7 | and the batch size is 1024. 8 | ```shell 9 | # Model Parallel 10 | python -m torch.distributed.launch --nproc_per_node=8 --nnodes=1 --node_rank=0 --master_addr="127.0.0.1" --master_port=1234 train.py configs/3millions 11 | # Partial FC 0.1 12 | python -m torch.distributed.launch --nproc_per_node=8 --nnodes=1 --node_rank=0 --master_addr="127.0.0.1" --master_port=1234 train.py configs/3millions_pfc 13 | ``` 14 | 15 | - GPU Memory 16 | 17 | ``` 18 | # (Model Parallel) gpustat -i 19 | [0] Tesla V100-SXM2-32GB | 64'C, 94 % | 30338 / 32510 MB 20 | [1] Tesla V100-SXM2-32GB | 60'C, 99 % | 28876 / 32510 MB 21 | [2] Tesla V100-SXM2-32GB | 60'C, 99 % | 28872 / 32510 MB 22 | [3] Tesla V100-SXM2-32GB | 69'C, 99 % | 28872 / 32510 MB 23 | [4] Tesla V100-SXM2-32GB | 66'C, 99 % | 28888 / 32510 MB 24 | [5] Tesla V100-SXM2-32GB | 60'C, 99 % | 28932 / 32510 MB 25 | [6] Tesla V100-SXM2-32GB | 68'C, 100 % | 28916 / 32510 MB 26 | [7] Tesla V100-SXM2-32GB | 65'C, 99 % | 28860 / 32510 MB 27 | 28 | # (Partial FC 0.1) gpustat -i 29 | [0] Tesla V100-SXM2-32GB | 60'C, 95 % | 10488 / 32510 MB │······················· 30 | [1] Tesla V100-SXM2-32GB | 60'C, 97 % | 10344 / 32510 MB │······················· 31 | [2] Tesla V100-SXM2-32GB | 61'C, 95 % | 10340 / 32510 MB │······················· 32 | [3] Tesla V100-SXM2-32GB | 66'C, 95 % | 10340 / 32510 MB │······················· 33 | [4] Tesla V100-SXM2-32GB | 65'C, 94 % | 10356 / 32510 MB │······················· 34 | [5] Tesla V100-SXM2-32GB | 61'C, 95 % | 10400 / 32510 MB │······················· 35 | [6] Tesla V100-SXM2-32GB | 68'C, 96 % | 10384 / 32510 MB │······················· 36 | [7] Tesla V100-SXM2-32GB | 64'C, 95 % | 10328 / 32510 MB │······················· 37 | ``` 38 | 39 | - Training Speed 40 | 41 | ```python 42 | # (Model Parallel) training.log 43 | Training: Speed 2271.33 samples/sec Loss 1.1624 LearningRate 0.2000 Epoch: 0 Global Step: 100 44 | Training: Speed 2269.94 samples/sec Loss
0.0000 LearningRate 0.2000 Epoch: 0 Global Step: 150 45 | Training: Speed 2272.67 samples/sec Loss 0.0000 LearningRate 0.2000 Epoch: 0 Global Step: 200 46 | Training: Speed 2266.55 samples/sec Loss 0.0000 LearningRate 0.2000 Epoch: 0 Global Step: 250 47 | Training: Speed 2272.54 samples/sec Loss 0.0000 LearningRate 0.2000 Epoch: 0 Global Step: 300 48 | 49 | # (Partial FC 0.1) training.log 50 | Training: Speed 5299.56 samples/sec Loss 1.0965 LearningRate 0.2000 Epoch: 0 Global Step: 100 51 | Training: Speed 5296.37 samples/sec Loss 0.0000 LearningRate 0.2000 Epoch: 0 Global Step: 150 52 | Training: Speed 5304.37 samples/sec Loss 0.0000 LearningRate 0.2000 Epoch: 0 Global Step: 200 53 | Training: Speed 5274.43 samples/sec Loss 0.0000 LearningRate 0.2000 Epoch: 0 Global Step: 250 54 | Training: Speed 5300.10 samples/sec Loss 0.0000 LearningRate 0.2000 Epoch: 0 Global Step: 300 55 | ``` 56 | 57 | In this test case, Partial FC 0.1 uses only 1/3 of the GPU memory used by model parallel, 58 | and its training speed is about 2.5 times that of model parallel. 59 | 60 | 61 | ## Speed Benchmark 62 | 63 | 1. Training speed of different parallel methods (samples/second), Tesla V100 32GB * 8. (Larger is better) 64 | 65 | | Number of Identities in Dataset | Data Parallel | Model Parallel | Partial FC 0.1 | 66 | | :--- | :--- | :--- | :--- | 67 | |125000 | 4681 | 4824 | 5004 | 68 | |250000 | 4047 | 4521 | 4976 | 69 | |500000 | 3087 | 4013 | 4900 | 70 | |1000000 | 2090 | 3449 | 4803 | 71 | |1400000 | 1672 | 3043 | 4738 | 72 | |2000000 | - | 2593 | 4626 | 73 | |4000000 | - | 1748 | 4208 | 74 | |5500000 | - | 1389 | 3975 | 75 | |8000000 | - | - | 3565 | 76 | |16000000 | - | - | 2679 | 77 | |29000000 | - | - | 1855 | 78 | 79 | 2. GPU memory cost of different parallel methods (MB per GPU), Tesla V100 32GB * 8.
(Smaller is better) 80 | 81 | | Number of Identities in Dataset | Data Parallel | Model Parallel | Partial FC 0.1 | 82 | | :--- | :--- | :--- | :--- | 83 | |125000 | 7358 | 5306 | 4868 | 84 | |250000 | 9940 | 5826 | 5004 | 85 | |500000 | 14220 | 7114 | 5202 | 86 | |1000000 | 23708 | 9966 | 5620 | 87 | |1400000 | 32252 | 11178 | 6056 | 88 | |2000000 | - | 13978 | 6472 | 89 | |4000000 | - | 23238 | 8284 | 90 | |5500000 | - | 32188 | 9854 | 91 | |8000000 | - | - | 12310 | 92 | |16000000 | - | - | 19950 | 93 | |29000000 | - | - | 32324 | 94 | -------------------------------------------------------------------------------- /eg3d-pose-detection/models/arcface_torch/eval/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bbaaii/HFA-GP/aa2c15a61d8ddd182189153914098a2af0edfb0c/eg3d-pose-detection/models/arcface_torch/eval/__init__.py -------------------------------------------------------------------------------- /eg3d-pose-detection/models/arcface_torch/flops.py: -------------------------------------------------------------------------------- 1 | from ptflops import get_model_complexity_info 2 | from backbones import get_model 3 | import argparse 4 | 5 | if __name__ == '__main__': 6 | parser = argparse.ArgumentParser(description='') 7 | parser.add_argument('n', type=str, default="r100") 8 | args = parser.parse_args() 9 | net = get_model(args.n) 10 | macs, params = get_model_complexity_info( 11 | net, (3, 112, 112), as_strings=False, 12 | print_per_layer_stat=True, verbose=True) 13 | gmacs = macs / (1000**3) 14 | print("%.3f GFLOPs"%gmacs) 15 | print("%.3f Mparams"%(params/(1000**2))) 16 | 17 | if hasattr(net, "extra_gflops"): 18 | print("%.3f Extra-GFLOPs"%net.extra_gflops) 19 | print("%.3f Total-GFLOPs"%(gmacs+net.extra_gflops)) 20 | 21 | -------------------------------------------------------------------------------- /eg3d-pose-detection/models/arcface_torch/inference.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import cv2 4 | import numpy as np 5 | import torch 6 | 7 | from backbones import get_model 8 | 9 | 10 | @torch.no_grad() 11 | def inference(weight, name, img): 12 | if img is None: 13 | img = np.random.randint(0, 255, size=(112, 112, 3), dtype=np.uint8) 14 | else: 15 | img = cv2.imread(img) 16 | img = cv2.resize(img, (112, 112)) 17 | 18 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 19 | img = np.transpose(img, (2, 0, 1)) 20 | img = torch.from_numpy(img).unsqueeze(0).float() 21 | img.div_(255).sub_(0.5).div_(0.5) 22 | net = get_model(name, fp16=False) 23 | net.load_state_dict(torch.load(weight)) 24 | net.eval() 25 | feat = net(img).numpy() 26 | print(feat) 27 | 28 | 29 | if __name__ == "__main__": 30 | parser = argparse.ArgumentParser(description='PyTorch ArcFace Training') 31 | parser.add_argument('--network', type=str, default='r50', help='backbone network') 32 | parser.add_argument('--weight', type=str, default='') 33 | parser.add_argument('--img', type=str, default=None) 34 | args = parser.parse_args() 35 | inference(args.weight, args.network, args.img) 36 | -------------------------------------------------------------------------------- /eg3d-pose-detection/models/arcface_torch/losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | 4 | 5 | class CombinedMarginLoss(torch.nn.Module): 6 | def __init__(self, 7 | s, 8 | m1, 9 | m2, 10 | m3, 11 | 
interclass_filtering_threshold=0): 12 | super().__init__() 13 | self.s = s 14 | self.m1 = m1 15 | self.m2 = m2 16 | self.m3 = m3 17 | self.interclass_filtering_threshold = interclass_filtering_threshold 18 | 19 | # For ArcFace 20 | self.cos_m = math.cos(self.m2) 21 | self.sin_m = math.sin(self.m2) 22 | self.theta = math.cos(math.pi - self.m2) 23 | self.sinmm = math.sin(math.pi - self.m2) * self.m2 24 | self.easy_margin = False 25 | 26 | 27 | def forward(self, logits, labels): 28 | index_positive = torch.where(labels != -1)[0] 29 | 30 | if self.interclass_filtering_threshold > 0: 31 | with torch.no_grad(): 32 | dirty = logits > self.interclass_filtering_threshold 33 | dirty = dirty.float() 34 | mask = torch.ones([index_positive.size(0), logits.size(1)], device=logits.device) 35 | mask.scatter_(1, labels[index_positive], 0) 36 | dirty[index_positive] *= mask 37 | tensor_mul = 1 - dirty 38 | logits = tensor_mul * logits 39 | 40 | target_logit = logits[index_positive, labels[index_positive].view(-1)] 41 | 42 | if self.m1 == 1.0 and self.m3 == 0.0: 43 | sin_theta = torch.sqrt(1.0 - torch.pow(target_logit, 2)) 44 | cos_theta_m = target_logit * self.cos_m - sin_theta * self.sin_m # cos(target+margin) 45 | if self.easy_margin: 46 | final_target_logit = torch.where( 47 | target_logit > 0, cos_theta_m, target_logit) 48 | else: 49 | final_target_logit = torch.where( 50 | target_logit > self.theta, cos_theta_m, target_logit - self.sinmm) 51 | logits[index_positive, labels[index_positive].view(-1)] = final_target_logit 52 | logits = logits * self.s 53 | 54 | elif self.m3 > 0: 55 | final_target_logit = target_logit - self.m3 56 | logits[index_positive, labels[index_positive].view(-1)] = final_target_logit 57 | logits = logits * self.s 58 | else: 59 | raise 60 | 61 | return logits 62 | 63 | class ArcFace(torch.nn.Module): 64 | """ ArcFace (https://arxiv.org/pdf/1801.07698v1.pdf): 65 | """ 66 | def __init__(self, s=64.0, margin=0.5): 67 | super(ArcFace, self).__init__() 68 | self.scale = s 69 | self.cos_m = math.cos(margin) 70 | self.sin_m = math.sin(margin) 71 | self.theta = math.cos(math.pi - margin) 72 | self.sinmm = math.sin(math.pi - margin) * margin 73 | self.easy_margin = False 74 | 75 | 76 | def forward(self, logits: torch.Tensor, labels: torch.Tensor): 77 | index = torch.where(labels != -1)[0] 78 | target_logit = logits[index, labels[index].view(-1)] 79 | 80 | sin_theta = torch.sqrt(1.0 - torch.pow(target_logit, 2)) 81 | cos_theta_m = target_logit * self.cos_m - sin_theta * self.sin_m # cos(target+margin) 82 | if self.easy_margin: 83 | final_target_logit = torch.where( 84 | target_logit > 0, cos_theta_m, target_logit) 85 | else: 86 | final_target_logit = torch.where( 87 | target_logit > self.theta, cos_theta_m, target_logit - self.sinmm) 88 | 89 | logits[index, labels[index].view(-1)] = final_target_logit 90 | logits = logits * self.scale 91 | return logits 92 | 93 | 94 | class CosFace(torch.nn.Module): 95 | def __init__(self, s=64.0, m=0.40): 96 | super(CosFace, self).__init__() 97 | self.s = s 98 | self.m = m 99 | 100 | def forward(self, logits: torch.Tensor, labels: torch.Tensor): 101 | index = torch.where(labels != -1)[0] 102 | target_logit = logits[index, labels[index].view(-1)] 103 | final_target_logit = target_logit - self.m 104 | logits[index, labels[index].view(-1)] = final_target_logit 105 | logits = logits * self.s 106 | return logits 107 | -------------------------------------------------------------------------------- /eg3d-pose-detection/models/arcface_torch/lr_scheduler.py: 
-------------------------------------------------------------------------------- 1 | from torch.optim.lr_scheduler import _LRScheduler 2 | 3 | 4 | class PolyScheduler(_LRScheduler): 5 | def __init__(self, optimizer, base_lr, max_steps, warmup_steps, last_epoch=-1): 6 | self.base_lr = base_lr 7 | self.warmup_lr_init = 0.0001 8 | self.max_steps: int = max_steps 9 | self.warmup_steps: int = warmup_steps 10 | self.power = 2 11 | super(PolyScheduler, self).__init__(optimizer, -1, False) 12 | self.last_epoch = last_epoch 13 | 14 | def get_warmup_lr(self): 15 | alpha = float(self.last_epoch) / float(self.warmup_steps) 16 | return [self.base_lr * alpha for _ in self.optimizer.param_groups] 17 | 18 | def get_lr(self): 19 | if self.last_epoch == -1: 20 | return [self.warmup_lr_init for _ in self.optimizer.param_groups] 21 | if self.last_epoch < self.warmup_steps: 22 | return self.get_warmup_lr() 23 | else: 24 | alpha = pow( 25 | 1 26 | - float(self.last_epoch - self.warmup_steps) 27 | / float(self.max_steps - self.warmup_steps), 28 | self.power, 29 | ) 30 | return [self.base_lr * alpha for _ in self.optimizer.param_groups] 31 | -------------------------------------------------------------------------------- /eg3d-pose-detection/models/arcface_torch/requirement.txt: -------------------------------------------------------------------------------- 1 | tensorboard 2 | easydict 3 | mxnet 4 | onnx 5 | sklearn 6 | -------------------------------------------------------------------------------- /eg3d-pose-detection/models/arcface_torch/run.sh: -------------------------------------------------------------------------------- 1 | 2 | CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python -m torch.distributed.launch \ 3 | --nproc_per_node=8 \ 4 | --nnodes=1 \ 5 | --node_rank=0 \ 6 | --master_addr="127.0.0.1" \ 7 | --master_port=12345 train.py $@ 8 | 9 | ps -ef | grep "train" | grep -v grep | awk '{print "kill -9 "$2}' | sh 10 | -------------------------------------------------------------------------------- /eg3d-pose-detection/models/arcface_torch/torch2onnx.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import onnx 3 | import torch 4 | 5 | 6 | def convert_onnx(net, path_module, output, opset=11, simplify=False): 7 | assert isinstance(net, torch.nn.Module) 8 | img = np.random.randint(0, 255, size=(112, 112, 3), dtype=np.int32) 9 | img = img.astype(np.float) 10 | img = (img / 255. 
- 0.5) / 0.5 # torch style norm 11 | img = img.transpose((2, 0, 1)) 12 | img = torch.from_numpy(img).unsqueeze(0).float() 13 | 14 | weight = torch.load(path_module) 15 | net.load_state_dict(weight, strict=True) 16 | net.eval() 17 | torch.onnx.export(net, img, output, input_names=["data"], keep_initializers_as_inputs=False, verbose=False, opset_version=opset) 18 | model = onnx.load(output) 19 | graph = model.graph 20 | graph.input[0].type.tensor_type.shape.dim[0].dim_param = 'None' 21 | if simplify: 22 | from onnxsim import simplify 23 | model, check = simplify(model) 24 | assert check, "Simplified ONNX model could not be validated" 25 | onnx.save(model, output) 26 | 27 | 28 | if __name__ == '__main__': 29 | import os 30 | import argparse 31 | from backbones import get_model 32 | 33 | parser = argparse.ArgumentParser(description='ArcFace PyTorch to onnx') 34 | parser.add_argument('input', type=str, help='input backbone.pth file or path') 35 | parser.add_argument('--output', type=str, default=None, help='output onnx path') 36 | parser.add_argument('--network', type=str, default=None, help='backbone network') 37 | parser.add_argument('--simplify', type=bool, default=False, help='onnx simplify') 38 | args = parser.parse_args() 39 | input_file = args.input 40 | if os.path.isdir(input_file): 41 | input_file = os.path.join(input_file, "model.pt") 42 | assert os.path.exists(input_file) 43 | # model_name = os.path.basename(os.path.dirname(input_file)).lower() 44 | # params = model_name.split("_") 45 | # if len(params) >= 3 and params[1] in ('arcface', 'cosface'): 46 | # if args.network is None: 47 | # args.network = params[2] 48 | assert args.network is not None 49 | print(args) 50 | backbone_onnx = get_model(args.network, dropout=0.0, fp16=False, num_features=512) 51 | if args.output is None: 52 | args.output = os.path.join(os.path.dirname(args.input), "model.onnx") 53 | convert_onnx(backbone_onnx, input_file, args.output, simplify=args.simplify) 54 | -------------------------------------------------------------------------------- /eg3d-pose-detection/models/arcface_torch/train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import logging 3 | import os 4 | 5 | import numpy as np 6 | import torch 7 | from torch import distributed 8 | from torch.utils.data import DataLoader 9 | from torch.utils.tensorboard import SummaryWriter 10 | 11 | from backbones import get_model 12 | from dataset import get_dataloader 13 | from losses import CombinedMarginLoss 14 | from lr_scheduler import PolyScheduler 15 | from partial_fc import PartialFC, PartialFCAdamW 16 | from utils.utils_callbacks import CallBackLogging, CallBackVerification 17 | from utils.utils_config import get_config 18 | from utils.utils_logging import AverageMeter, init_logging 19 | from utils.utils_distributed_sampler import setup_seed 20 | 21 | assert torch.__version__ >= "1.9.0", "In order to enjoy the features of the new torch, \ 22 | we have upgraded the torch to 1.9.0. torch before than 1.9.0 may not work in the future." 
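# NOTE: the try/except below reads the WORLD_SIZE and RANK environment variables, which
# torch.distributed.launch / torchrun export for every worker; if they are absent (a plain
# single-process run), the script falls back to a one-process "nccl" group on 127.0.0.1:12584.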
23 | 24 | try: 25 | world_size = int(os.environ["WORLD_SIZE"]) 26 | rank = int(os.environ["RANK"]) 27 | distributed.init_process_group("nccl") 28 | except KeyError: 29 | world_size = 1 30 | rank = 0 31 | distributed.init_process_group( 32 | backend="nccl", 33 | init_method="tcp://127.0.0.1:12584", 34 | rank=rank, 35 | world_size=world_size, 36 | ) 37 | 38 | 39 | def main(args): 40 | 41 | # get config 42 | cfg = get_config(args.config) 43 | # global control random seed 44 | setup_seed(seed=cfg.seed, cuda_deterministic=False) 45 | 46 | torch.cuda.set_device(args.local_rank) 47 | 48 | os.makedirs(cfg.output, exist_ok=True) 49 | init_logging(rank, cfg.output) 50 | 51 | summary_writer = ( 52 | SummaryWriter(log_dir=os.path.join(cfg.output, "tensorboard")) 53 | if rank == 0 54 | else None 55 | ) 56 | 57 | train_loader = get_dataloader( 58 | cfg.rec, 59 | args.local_rank, 60 | cfg.batch_size, 61 | cfg.dali, 62 | cfg.seed, 63 | cfg.num_workers 64 | ) 65 | 66 | backbone = get_model( 67 | cfg.network, dropout=0.0, fp16=cfg.fp16, num_features=cfg.embedding_size).cuda() 68 | 69 | backbone = torch.nn.parallel.DistributedDataParallel( 70 | module=backbone, broadcast_buffers=False, device_ids=[args.local_rank], bucket_cap_mb=16, 71 | find_unused_parameters=True) 72 | 73 | backbone.train() 74 | # FIXME using gradient checkpoint if there are some unused parameters will cause error 75 | backbone._set_static_graph() 76 | 77 | margin_loss = CombinedMarginLoss( 78 | 64, 79 | cfg.margin_list[0], 80 | cfg.margin_list[1], 81 | cfg.margin_list[2], 82 | cfg.interclass_filtering_threshold 83 | ) 84 | 85 | if cfg.optimizer == "sgd": 86 | module_partial_fc = PartialFC( 87 | margin_loss, cfg.embedding_size, cfg.num_classes, 88 | cfg.sample_rate, cfg.fp16) 89 | module_partial_fc.train().cuda() 90 | # TODO the params of partial fc must be last in the params list 91 | opt = torch.optim.SGD( 92 | params=[{"params": backbone.parameters()}, {"params": module_partial_fc.parameters()}], 93 | lr=cfg.lr, momentum=0.9, weight_decay=cfg.weight_decay) 94 | 95 | elif cfg.optimizer == "adamw": 96 | module_partial_fc = PartialFCAdamW( 97 | margin_loss, cfg.embedding_size, cfg.num_classes, 98 | cfg.sample_rate, cfg.fp16) 99 | module_partial_fc.train().cuda() 100 | opt = torch.optim.AdamW( 101 | params=[{"params": backbone.parameters()}, {"params": module_partial_fc.parameters()}], 102 | lr=cfg.lr, weight_decay=cfg.weight_decay) 103 | else: 104 | raise 105 | 106 | cfg.total_batch_size = cfg.batch_size * world_size 107 | cfg.warmup_step = cfg.num_image // cfg.total_batch_size * cfg.warmup_epoch 108 | cfg.total_step = cfg.num_image // cfg.total_batch_size * cfg.num_epoch 109 | 110 | lr_scheduler = PolyScheduler( 111 | optimizer=opt, 112 | base_lr=cfg.lr, 113 | max_steps=cfg.total_step, 114 | warmup_steps=cfg.warmup_step, 115 | last_epoch=-1 116 | ) 117 | 118 | start_epoch = 0 119 | global_step = 0 120 | if cfg.resume: 121 | dict_checkpoint = torch.load(os.path.join(cfg.output, f"checkpoint_gpu_{rank}.pt")) 122 | start_epoch = dict_checkpoint["epoch"] 123 | global_step = dict_checkpoint["global_step"] 124 | backbone.module.load_state_dict(dict_checkpoint["state_dict_backbone"]) 125 | module_partial_fc.load_state_dict(dict_checkpoint["state_dict_softmax_fc"]) 126 | opt.load_state_dict(dict_checkpoint["state_optimizer"]) 127 | lr_scheduler.load_state_dict(dict_checkpoint["state_lr_scheduler"]) 128 | del dict_checkpoint 129 | 130 | for key, value in cfg.items(): 131 | num_space = 25 - len(key) 132 | logging.info(": " + key + " " * num_space + 
str(value)) 133 | 134 | callback_verification = CallBackVerification( 135 | val_targets=cfg.val_targets, rec_prefix=cfg.rec, summary_writer=summary_writer 136 | ) 137 | callback_logging = CallBackLogging( 138 | frequent=cfg.frequent, 139 | total_step=cfg.total_step, 140 | batch_size=cfg.batch_size, 141 | start_step = global_step, 142 | writer=summary_writer 143 | ) 144 | 145 | loss_am = AverageMeter() 146 | amp = torch.cuda.amp.grad_scaler.GradScaler(growth_interval=100) 147 | 148 | for epoch in range(start_epoch, cfg.num_epoch): 149 | 150 | if isinstance(train_loader, DataLoader): 151 | train_loader.sampler.set_epoch(epoch) 152 | for _, (img, local_labels) in enumerate(train_loader): 153 | global_step += 1 154 | local_embeddings = backbone(img) 155 | loss: torch.Tensor = module_partial_fc(local_embeddings, local_labels, opt) 156 | 157 | if cfg.fp16: 158 | amp.scale(loss).backward() 159 | amp.unscale_(opt) 160 | torch.nn.utils.clip_grad_norm_(backbone.parameters(), 5) 161 | amp.step(opt) 162 | amp.update() 163 | else: 164 | loss.backward() 165 | torch.nn.utils.clip_grad_norm_(backbone.parameters(), 5) 166 | opt.step() 167 | 168 | opt.zero_grad() 169 | lr_scheduler.step() 170 | 171 | with torch.no_grad(): 172 | loss_am.update(loss.item(), 1) 173 | callback_logging(global_step, loss_am, epoch, cfg.fp16, lr_scheduler.get_last_lr()[0], amp) 174 | 175 | if global_step % cfg.verbose == 0 and global_step > 0: 176 | callback_verification(global_step, backbone) 177 | 178 | if cfg.save_all_states: 179 | checkpoint = { 180 | "epoch": epoch + 1, 181 | "global_step": global_step, 182 | "state_dict_backbone": backbone.module.state_dict(), 183 | "state_dict_softmax_fc": module_partial_fc.state_dict(), 184 | "state_optimizer": opt.state_dict(), 185 | "state_lr_scheduler": lr_scheduler.state_dict() 186 | } 187 | torch.save(checkpoint, os.path.join(cfg.output, f"checkpoint_gpu_{rank}.pt")) 188 | 189 | if rank == 0: 190 | path_module = os.path.join(cfg.output, "model.pt") 191 | torch.save(backbone.module.state_dict(), path_module) 192 | 193 | if cfg.dali: 194 | train_loader.reset() 195 | 196 | if rank == 0: 197 | path_module = os.path.join(cfg.output, "model.pt") 198 | torch.save(backbone.module.state_dict(), path_module) 199 | 200 | from torch2onnx import convert_onnx 201 | convert_onnx(backbone.module.cpu().eval(), path_module, os.path.join(cfg.output, "model.onnx")) 202 | 203 | distributed.destroy_process_group() 204 | 205 | 206 | if __name__ == "__main__": 207 | torch.backends.cudnn.benchmark = True 208 | parser = argparse.ArgumentParser( 209 | description="Distributed Arcface Training in Pytorch") 210 | parser.add_argument("config", type=str, help="py config file") 211 | parser.add_argument("--local_rank", type=int, default=0, help="local_rank") 212 | main(parser.parse_args()) 213 | -------------------------------------------------------------------------------- /eg3d-pose-detection/models/arcface_torch/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bbaaii/HFA-GP/aa2c15a61d8ddd182189153914098a2af0edfb0c/eg3d-pose-detection/models/arcface_torch/utils/__init__.py -------------------------------------------------------------------------------- /eg3d-pose-detection/models/arcface_torch/utils/plot.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | import matplotlib.pyplot as plt 5 | import numpy as np 6 | import pandas as pd 7 | from 
menpo.visualize.viewmatplotlib import sample_colours_from_colourmap 8 | from prettytable import PrettyTable 9 | from sklearn.metrics import roc_curve, auc 10 | 11 | with open(sys.argv[1], "r") as f: 12 | files = f.readlines() 13 | 14 | files = [x.strip() for x in files] 15 | image_path = "/train_tmp/IJB_release/IJBC" 16 | 17 | 18 | def read_template_pair_list(path): 19 | pairs = pd.read_csv(path, sep=' ', header=None).values 20 | t1 = pairs[:, 0].astype(np.int) 21 | t2 = pairs[:, 1].astype(np.int) 22 | label = pairs[:, 2].astype(np.int) 23 | return t1, t2, label 24 | 25 | 26 | p1, p2, label = read_template_pair_list( 27 | os.path.join('%s/meta' % image_path, 28 | '%s_template_pair_label.txt' % 'ijbc')) 29 | 30 | methods = [] 31 | scores = [] 32 | for file in files: 33 | methods.append(file) 34 | scores.append(np.load(file)) 35 | 36 | methods = np.array(methods) 37 | scores = dict(zip(methods, scores)) 38 | colours = dict( 39 | zip(methods, sample_colours_from_colourmap(methods.shape[0], 'Set2'))) 40 | x_labels = [10 ** -6, 10 ** -5, 10 ** -4, 10 ** -3, 10 ** -2, 10 ** -1] 41 | tpr_fpr_table = PrettyTable(['Methods'] + [str(x) for x in x_labels]) 42 | fig = plt.figure() 43 | for method in methods: 44 | fpr, tpr, _ = roc_curve(label, scores[method]) 45 | roc_auc = auc(fpr, tpr) 46 | fpr = np.flipud(fpr) 47 | tpr = np.flipud(tpr) # select largest tpr at same fpr 48 | plt.plot(fpr, 49 | tpr, 50 | color=colours[method], 51 | lw=1, 52 | label=('[%s (AUC = %0.4f %%)]' % 53 | (method.split('-')[-1], roc_auc * 100))) 54 | tpr_fpr_row = [] 55 | tpr_fpr_row.append(method) 56 | for fpr_iter in np.arange(len(x_labels)): 57 | _, min_index = min( 58 | list(zip(abs(fpr - x_labels[fpr_iter]), range(len(fpr))))) 59 | tpr_fpr_row.append('%.2f' % (tpr[min_index] * 100)) 60 | tpr_fpr_table.add_row(tpr_fpr_row) 61 | plt.xlim([10 ** -6, 0.1]) 62 | plt.ylim([0.3, 1.0]) 63 | plt.grid(linestyle='--', linewidth=1) 64 | plt.xticks(x_labels) 65 | plt.yticks(np.linspace(0.3, 1.0, 8, endpoint=True)) 66 | plt.xscale('log') 67 | plt.xlabel('False Positive Rate') 68 | plt.ylabel('True Positive Rate') 69 | plt.title('ROC on IJB') 70 | plt.legend(loc="lower right") 71 | print(tpr_fpr_table) 72 | -------------------------------------------------------------------------------- /eg3d-pose-detection/models/arcface_torch/utils/utils_callbacks.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import time 4 | from typing import List 5 | 6 | import torch 7 | 8 | from eval import verification 9 | from utils.utils_logging import AverageMeter 10 | from torch.utils.tensorboard import SummaryWriter 11 | from torch import distributed 12 | 13 | 14 | class CallBackVerification(object): 15 | 16 | def __init__(self, val_targets, rec_prefix, summary_writer=None, image_size=(112, 112)): 17 | self.rank: int = distributed.get_rank() 18 | self.highest_acc: float = 0.0 19 | self.highest_acc_list: List[float] = [0.0] * len(val_targets) 20 | self.ver_list: List[object] = [] 21 | self.ver_name_list: List[str] = [] 22 | if self.rank is 0: 23 | self.init_dataset(val_targets=val_targets, data_dir=rec_prefix, image_size=image_size) 24 | 25 | self.summary_writer = summary_writer 26 | 27 | def ver_test(self, backbone: torch.nn.Module, global_step: int): 28 | results = [] 29 | for i in range(len(self.ver_list)): 30 | acc1, std1, acc2, std2, xnorm, embeddings_list = verification.test( 31 | self.ver_list[i], backbone, 10, 10) 32 | logging.info('[%s][%d]XNorm: %f' % (self.ver_name_list[i], 
global_step, xnorm)) 33 | logging.info('[%s][%d]Accuracy-Flip: %1.5f+-%1.5f' % (self.ver_name_list[i], global_step, acc2, std2)) 34 | 35 | self.summary_writer: SummaryWriter 36 | self.summary_writer.add_scalar(tag=self.ver_name_list[i], scalar_value=acc2, global_step=global_step, ) 37 | 38 | if acc2 > self.highest_acc_list[i]: 39 | self.highest_acc_list[i] = acc2 40 | logging.info( 41 | '[%s][%d]Accuracy-Highest: %1.5f' % (self.ver_name_list[i], global_step, self.highest_acc_list[i])) 42 | results.append(acc2) 43 | 44 | def init_dataset(self, val_targets, data_dir, image_size): 45 | for name in val_targets: 46 | path = os.path.join(data_dir, name + ".bin") 47 | if os.path.exists(path): 48 | data_set = verification.load_bin(path, image_size) 49 | self.ver_list.append(data_set) 50 | self.ver_name_list.append(name) 51 | 52 | def __call__(self, num_update, backbone: torch.nn.Module): 53 | if self.rank is 0 and num_update > 0: 54 | backbone.eval() 55 | self.ver_test(backbone, num_update) 56 | backbone.train() 57 | 58 | 59 | class CallBackLogging(object): 60 | def __init__(self, frequent, total_step, batch_size, start_step=0,writer=None): 61 | self.frequent: int = frequent 62 | self.rank: int = distributed.get_rank() 63 | self.world_size: int = distributed.get_world_size() 64 | self.time_start = time.time() 65 | self.total_step: int = total_step 66 | self.start_step: int = start_step 67 | self.batch_size: int = batch_size 68 | self.writer = writer 69 | 70 | self.init = False 71 | self.tic = 0 72 | 73 | def __call__(self, 74 | global_step: int, 75 | loss: AverageMeter, 76 | epoch: int, 77 | fp16: bool, 78 | learning_rate: float, 79 | grad_scaler: torch.cuda.amp.GradScaler): 80 | if self.rank == 0 and global_step > 0 and global_step % self.frequent == 0: 81 | if self.init: 82 | try: 83 | speed: float = self.frequent * self.batch_size / (time.time() - self.tic) 84 | speed_total = speed * self.world_size 85 | except ZeroDivisionError: 86 | speed_total = float('inf') 87 | 88 | #time_now = (time.time() - self.time_start) / 3600 89 | #time_total = time_now / ((global_step + 1) / self.total_step) 90 | #time_for_end = time_total - time_now 91 | time_now = time.time() 92 | time_sec = int(time_now - self.time_start) 93 | time_sec_avg = time_sec / (global_step - self.start_step + 1) 94 | eta_sec = time_sec_avg * (self.total_step - global_step - 1) 95 | time_for_end = eta_sec/3600 96 | if self.writer is not None: 97 | self.writer.add_scalar('time_for_end', time_for_end, global_step) 98 | self.writer.add_scalar('learning_rate', learning_rate, global_step) 99 | self.writer.add_scalar('loss', loss.avg, global_step) 100 | if fp16: 101 | msg = "Speed %.2f samples/sec Loss %.4f LearningRate %.6f Epoch: %d Global Step: %d " \ 102 | "Fp16 Grad Scale: %2.f Required: %1.f hours" % ( 103 | speed_total, loss.avg, learning_rate, epoch, global_step, 104 | grad_scaler.get_scale(), time_for_end 105 | ) 106 | else: 107 | msg = "Speed %.2f samples/sec Loss %.4f LearningRate %.6f Epoch: %d Global Step: %d " \ 108 | "Required: %1.f hours" % ( 109 | speed_total, loss.avg, learning_rate, epoch, global_step, time_for_end 110 | ) 111 | logging.info(msg) 112 | loss.reset() 113 | self.tic = time.time() 114 | else: 115 | self.init = True 116 | self.tic = time.time() 117 | -------------------------------------------------------------------------------- /eg3d-pose-detection/models/arcface_torch/utils/utils_config.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import os.path 
as osp 3 | 4 | 5 | def get_config(config_file): 6 | assert config_file.startswith('configs/'), 'config file setting must start with configs/' 7 | temp_config_name = osp.basename(config_file) 8 | temp_module_name = osp.splitext(temp_config_name)[0] 9 | config = importlib.import_module("configs.base") 10 | cfg = config.config 11 | config = importlib.import_module("configs.%s" % temp_module_name) 12 | job_cfg = config.config 13 | cfg.update(job_cfg) 14 | if cfg.output is None: 15 | cfg.output = osp.join('work_dirs', temp_module_name) 16 | return cfg -------------------------------------------------------------------------------- /eg3d-pose-detection/models/arcface_torch/utils/utils_distributed_sampler.py: -------------------------------------------------------------------------------- 1 | import math 2 | import os 3 | import random 4 | 5 | import numpy as np 6 | import torch 7 | import torch.distributed as dist 8 | from torch.utils.data import DistributedSampler as _DistributedSampler 9 | 10 | 11 | def setup_seed(seed, cuda_deterministic=True): 12 | torch.manual_seed(seed) 13 | torch.cuda.manual_seed_all(seed) 14 | np.random.seed(seed) 15 | random.seed(seed) 16 | os.environ["PYTHONHASHSEED"] = str(seed) 17 | if cuda_deterministic: # slower, more reproducible 18 | torch.backends.cudnn.deterministic = True 19 | torch.backends.cudnn.benchmark = False 20 | else: # faster, less reproducible 21 | torch.backends.cudnn.deterministic = False 22 | torch.backends.cudnn.benchmark = True 23 | 24 | 25 | def worker_init_fn(worker_id, num_workers, rank, seed): 26 | # The seed of each worker equals to 27 | # num_worker * rank + worker_id + user_seed 28 | worker_seed = num_workers * rank + worker_id + seed 29 | np.random.seed(worker_seed) 30 | random.seed(worker_seed) 31 | torch.manual_seed(worker_seed) 32 | 33 | 34 | def get_dist_info(): 35 | if dist.is_available() and dist.is_initialized(): 36 | rank = dist.get_rank() 37 | world_size = dist.get_world_size() 38 | else: 39 | rank = 0 40 | world_size = 1 41 | 42 | return rank, world_size 43 | 44 | 45 | def sync_random_seed(seed=None, device="cuda"): 46 | """Make sure different ranks share the same seed. 47 | All workers must call this function, otherwise it will deadlock. 48 | This method is generally used in `DistributedSampler`, 49 | because the seed should be identical across all processes 50 | in the distributed group. 51 | In distributed sampling, different ranks should sample non-overlapped 52 | data in the dataset. Therefore, this function is used to make sure that 53 | each rank shuffles the data indices in the same order based 54 | on the same seed. Then different ranks could use different indices 55 | to select non-overlapped data from the same data list. 56 | Args: 57 | seed (int, Optional): The seed. Default to None. 58 | device (str): The device where the seed will be put on. 59 | Default to 'cuda'. 60 | Returns: 61 | int: Seed to be used. 
62 | """ 63 | if seed is None: 64 | seed = np.random.randint(2**31) 65 | assert isinstance(seed, int) 66 | 67 | rank, world_size = get_dist_info() 68 | 69 | if world_size == 1: 70 | return seed 71 | 72 | if rank == 0: 73 | random_num = torch.tensor(seed, dtype=torch.int32, device=device) 74 | else: 75 | random_num = torch.tensor(0, dtype=torch.int32, device=device) 76 | 77 | dist.broadcast(random_num, src=0) 78 | 79 | return random_num.item() 80 | 81 | 82 | class DistributedSampler(_DistributedSampler): 83 | def __init__( 84 | self, 85 | dataset, 86 | num_replicas=None, # world_size 87 | rank=None, # local_rank 88 | shuffle=True, 89 | seed=0, 90 | ): 91 | 92 | super().__init__(dataset, num_replicas=num_replicas, rank=rank, shuffle=shuffle) 93 | 94 | # In distributed sampling, different ranks should sample 95 | # non-overlapped data in the dataset. Therefore, this function 96 | # is used to make sure that each rank shuffles the data indices 97 | # in the same order based on the same seed. Then different ranks 98 | # could use different indices to select non-overlapped data from the 99 | # same data list. 100 | self.seed = sync_random_seed(seed) 101 | 102 | def __iter__(self): 103 | # deterministically shuffle based on epoch 104 | if self.shuffle: 105 | g = torch.Generator() 106 | # When :attr:`shuffle=True`, this ensures all replicas 107 | # use a different random ordering for each epoch. 108 | # Otherwise, the next iteration of this sampler will 109 | # yield the same ordering. 110 | g.manual_seed(self.epoch + self.seed) 111 | indices = torch.randperm(len(self.dataset), generator=g).tolist() 112 | else: 113 | indices = torch.arange(len(self.dataset)).tolist() 114 | 115 | # add extra samples to make it evenly divisible 116 | # in case that indices is shorter than half of total_size 117 | indices = (indices * math.ceil(self.total_size / len(indices)))[ 118 | : self.total_size 119 | ] 120 | assert len(indices) == self.total_size 121 | 122 | # subsample 123 | indices = indices[self.rank : self.total_size : self.num_replicas] 124 | assert len(indices) == self.num_samples 125 | 126 | return iter(indices) 127 | -------------------------------------------------------------------------------- /eg3d-pose-detection/models/arcface_torch/utils/utils_logging.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import sys 4 | 5 | 6 | class AverageMeter(object): 7 | """Computes and stores the average and current value 8 | """ 9 | 10 | def __init__(self): 11 | self.val = None 12 | self.avg = None 13 | self.sum = None 14 | self.count = None 15 | self.reset() 16 | 17 | def reset(self): 18 | self.val = 0 19 | self.avg = 0 20 | self.sum = 0 21 | self.count = 0 22 | 23 | def update(self, val, n=1): 24 | self.val = val 25 | self.sum += val * n 26 | self.count += n 27 | self.avg = self.sum / self.count 28 | 29 | 30 | def init_logging(rank, models_root): 31 | if rank == 0: 32 | log_root = logging.getLogger() 33 | log_root.setLevel(logging.INFO) 34 | formatter = logging.Formatter("Training: %(asctime)s-%(message)s") 35 | handler_file = logging.FileHandler(os.path.join(models_root, "training.log")) 36 | handler_stream = logging.StreamHandler(sys.stdout) 37 | handler_file.setFormatter(formatter) 38 | handler_stream.setFormatter(formatter) 39 | log_root.addHandler(handler_file) 40 | log_root.addHandler(handler_stream) 41 | log_root.info('rank_id: %d' % rank) 42 | -------------------------------------------------------------------------------- 
/eg3d-pose-detection/models/losses.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | from kornia.geometry import warp_affine
5 | import torch.nn.functional as F
6 | 
7 | def resize_n_crop(image, M, dsize=112):
8 |     # image: (b, c, h, w)
9 |     # M : (b, 2, 3)
10 |     return warp_affine(image, M, dsize=(dsize, dsize))
11 | 
12 | ### perceptual level loss
13 | class PerceptualLoss(nn.Module):
14 |     def __init__(self, recog_net, input_size=112):
15 |         super(PerceptualLoss, self).__init__()
16 |         self.recog_net = recog_net
17 |         self.preprocess = lambda x: 2 * x - 1
18 |         self.input_size=input_size
19 |     def forward(self, imageA, imageB, M):
20 |         """
21 |         1 - cosine distance
22 |         Parameters:
23 |             imageA --torch.tensor (B, 3, H, W), range (0, 1), RGB order
24 |             imageB --same as imageA
25 |         """
26 | 
27 |         imageA = self.preprocess(resize_n_crop(imageA, M, self.input_size))
28 |         imageB = self.preprocess(resize_n_crop(imageB, M, self.input_size))
29 | 
30 |         # freeze bn
31 |         self.recog_net.eval()
32 | 
33 |         id_featureA = F.normalize(self.recog_net(imageA), dim=-1, p=2)
34 |         id_featureB = F.normalize(self.recog_net(imageB), dim=-1, p=2)
35 |         cosine_d = torch.sum(id_featureA * id_featureB, dim=-1)
36 |         # assert torch.sum((cosine_d > 1).float()) == 0
37 |         return torch.sum(1 - cosine_d) / cosine_d.shape[0]
38 | 
39 | def perceptual_loss(id_featureA, id_featureB):
40 |     cosine_d = torch.sum(id_featureA * id_featureB, dim=-1)
41 |     # assert torch.sum((cosine_d > 1).float()) == 0
42 |     return torch.sum(1 - cosine_d) / cosine_d.shape[0]
43 | 
44 | ### image level loss
45 | def photo_loss(imageA, imageB, mask, eps=1e-6):
46 |     """
47 |     l2 norm (with sqrt, to ensure backward stability; use eps, otherwise NaN may occur)
48 |     Parameters:
49 |         imageA --torch.tensor (B, 3, H, W), range (0, 1), RGB order
50 |         imageB --same as imageA
51 |     """
52 |     loss = torch.sqrt(eps + torch.sum((imageA - imageB) ** 2, dim=1, keepdims=True)) * mask
53 |     loss = torch.sum(loss) / torch.max(torch.sum(mask), torch.tensor(1.0).to(mask.device))
54 |     return loss
55 | 
56 | def landmark_loss(predict_lm, gt_lm, weight=None):
57 |     """
58 |     weighted mse loss
59 |     Parameters:
60 |         predict_lm --torch.tensor (B, 68, 2)
61 |         gt_lm --torch.tensor (B, 68, 2)
62 |         weight --numpy.array (1, 68)
63 |     """
64 |     if weight is None:
65 |         weight = np.ones([68])
66 |         weight[28:31] = 20
67 |         weight[-8:] = 20
68 |         weight = np.expand_dims(weight, 0)
69 |         weight = torch.tensor(weight).to(predict_lm.device)
70 |     loss = torch.sum((predict_lm - gt_lm)**2, dim=-1) * weight
71 |     loss = torch.sum(loss) / (predict_lm.shape[0] * predict_lm.shape[1])
72 |     return loss
73 | 
74 | 
75 | ### regularization
76 | def reg_loss(coeffs_dict, opt=None):
77 |     """
78 |     l2 norm without the sqrt, from yu's implementation (mse)
79 |     tf.nn.l2_loss https://www.tensorflow.org/api_docs/python/tf/nn/l2_loss
80 |     Parameters:
81 |         coeffs_dict -- a dict of torch.tensors, keys: id, exp, tex, angle, gamma, trans
82 | 
83 |     """
84 |     # coefficient regularization to ensure plausible 3d faces
85 |     if opt:
86 |         w_id, w_exp, w_tex = opt.w_id, opt.w_exp, opt.w_tex
87 |     else:
88 |         w_id, w_exp, w_tex = 1, 1, 1
89 |     creg_loss = w_id * torch.sum(coeffs_dict['id'] ** 2) + \
90 |                 w_exp * torch.sum(coeffs_dict['exp'] ** 2) + \
91 |                 w_tex * torch.sum(coeffs_dict['tex'] ** 2)
92 |     creg_loss = creg_loss / coeffs_dict['id'].shape[0]
93 | 
94 |     # gamma regularization to ensure a nearly-monochromatic light
95 |     gamma = coeffs_dict['gamma'].reshape([-1, 3, 9])
96 |     gamma_mean = torch.mean(gamma, dim=1, keepdims=True)
97 |     gamma_loss = torch.mean((gamma - gamma_mean) ** 2)
98 | 
99 |     return creg_loss, gamma_loss
100 | 
101 | def reflectance_loss(texture, mask):
102 |     """
103 |     minimize texture variance (mse), albedo regularization to ensure a uniform skin albedo
104 |     Parameters:
105 |         texture --torch.tensor, (B, N, 3)
106 |         mask --torch.tensor, (N), 1 or 0
107 | 
108 |     """
109 |     mask = mask.reshape([1, mask.shape[0], 1])
110 |     texture_mean = torch.sum(mask * texture, dim=1, keepdims=True) / torch.sum(mask)
111 |     loss = torch.sum(((texture - texture_mean) * mask)**2) / (texture.shape[0] * torch.sum(mask))
112 |     return loss
113 | 
114 | 
--------------------------------------------------------------------------------
/eg3d-pose-detection/models/template_model.py:
--------------------------------------------------------------------------------
1 | """Model class template
2 | 
3 | This module provides a template for users to implement custom models.
4 | You can specify '--model template' to use this model.
5 | The class name should be consistent with both the filename and its model option.
6 | The filename should be <model>_model.py
7 | The class name should be <Model>Model
8 | It implements a simple image-to-image translation baseline based on regression loss.
9 | Given input-output pairs (data_A, data_B), it learns a network netG that can minimize the following L1 loss:
10 |     min_<netG> ||netG(data_A) - data_B||_1
11 | You need to implement the following functions:
12 |     <modify_commandline_options>: Add model-specific options and rewrite default values for existing options.
13 |     <__init__>: Initialize this model class.
14 |     <set_input>: Unpack input data and perform data pre-processing.
15 |     <forward>: Run forward pass. This will be called by both <optimize_parameters> and <test>.
16 |     <optimize_parameters>: Update network weights; it will be called in every training iteration.
17 | """
18 | import numpy as np
19 | import torch
20 | from .base_model import BaseModel
21 | from . import networks
22 | 
23 | 
24 | class TemplateModel(BaseModel):
25 |     @staticmethod
26 |     def modify_commandline_options(parser, is_train=True):
27 |         """Add new model-specific options and rewrite default values for existing options.
28 | 
29 |         Parameters:
30 |             parser -- the option parser
31 |             is_train -- if it is training phase or test phase. You can use this flag to add training-specific or test-specific options.
32 | 
33 |         Returns:
34 |             the modified parser.
35 |         """
36 |         parser.set_defaults(dataset_mode='aligned')  # You can rewrite default values for this model. For example, this model usually uses aligned dataset as its dataset.
37 |         if is_train:
38 |             parser.add_argument('--lambda_regression', type=float, default=1.0, help='weight for the regression loss')  # You can define new arguments for this model.
39 | 
40 |         return parser
41 | 
42 |     def __init__(self, opt):
43 |         """Initialize this model class.
44 | 
45 |         Parameters:
46 |             opt -- training/test options
47 | 
48 |         A few things can be done here.
49 |         - (required) call the initialization function of BaseModel
50 |         - define loss function, visualization images, model names, and optimizers
51 |         """
52 |         BaseModel.__init__(self, opt)  # call the initialization method of BaseModel
53 |         # specify the training losses you want to print out. The program will call base_model.get_current_losses to plot the losses to the console and save them to the disk.
54 |         self.loss_names = ['loss_G']
55 |         # specify the images you want to save and display. The program will call base_model.get_current_visuals to save and display these images.
56 |         self.visual_names = ['data_A', 'data_B', 'output']
57 |         # specify the models you want to save to the disk. The program will call base_model.save_networks and base_model.load_networks to save and load networks.
58 |         # you can use opt.isTrain to specify different behaviors for training and test. For example, some networks will not be used during test, and you don't need to load them.
59 |         self.model_names = ['G']
60 |         # define networks; you can use opt.isTrain to specify different behaviors for training and test.
61 |         self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG, gpu_ids=self.gpu_ids)
62 |         if self.isTrain:  # only defined during training time
63 |             # define your loss functions. You can use losses provided by torch.nn such as torch.nn.L1Loss.
64 |             # We also provide a GANLoss class "networks.GANLoss". self.criterionGAN = networks.GANLoss().to(self.device)
65 |             self.criterionLoss = torch.nn.L1Loss()
66 |             # define and initialize optimizers. You can define one optimizer for each network.
67 |             # If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an example.
68 |             self.optimizer = torch.optim.Adam(self.netG.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
69 |             self.optimizers = [self.optimizer]
70 | 
71 |         # Our program will automatically call <BaseModel.setup> to define schedulers, load networks, and print networks
72 | 
73 |     def set_input(self, input):
74 |         """Unpack input data from the dataloader and perform necessary pre-processing steps.
75 | 
76 |         Parameters:
77 |             input: a dictionary that contains the data itself and its metadata information.
78 |         """
79 |         AtoB = self.opt.direction == 'AtoB'  # use <direction> to swap data_A and data_B
80 |         self.data_A = input['A' if AtoB else 'B'].to(self.device)  # get image data A
81 |         self.data_B = input['B' if AtoB else 'A'].to(self.device)  # get image data B
82 |         self.image_paths = input['A_paths' if AtoB else 'B_paths']  # get image paths
83 | 
84 |     def forward(self):
85 |         """Run forward pass. This will be called by both functions <optimize_parameters> and <test>."""
86 |         self.output = self.netG(self.data_A)  # generate output image given the input data_A
87 | 
88 |     def backward(self):
89 |         """Calculate losses, gradients, and update network weights; called in every training iteration"""
90 |         # calculate the intermediate results if necessary; here self.output has been computed during <forward> function
91 |         # calculate loss given the input and intermediate results
92 |         self.loss_G = self.criterionLoss(self.output, self.data_B) * self.opt.lambda_regression
93 |         self.loss_G.backward()  # calculate gradients of network G w.r.t.
loss_G 94 | 95 | def optimize_parameters(self): 96 | """Update network weights; it will be called in every training iteration.""" 97 | self.forward() # first call forward to calculate intermediate results 98 | self.optimizer.zero_grad() # clear network G's existing gradients 99 | self.backward() # calculate gradients for network G 100 | self.optimizer.step() # update gradients for network G 101 | -------------------------------------------------------------------------------- /eg3d-pose-detection/options/__init__.py: -------------------------------------------------------------------------------- 1 | """This package options includes option modules: training options, test options, and basic options (used in both training and test).""" 2 | -------------------------------------------------------------------------------- /eg3d-pose-detection/options/base_options.py: -------------------------------------------------------------------------------- 1 | """This script contains base options for Deep3DFaceRecon_pytorch 2 | """ 3 | 4 | import argparse 5 | import os 6 | from util import util 7 | import numpy as np 8 | import torch 9 | import models 10 | import data 11 | 12 | 13 | class BaseOptions(): 14 | """This class defines options used during both training and test time. 15 | 16 | It also implements several helper functions such as parsing, printing, and saving the options. 17 | It also gathers additional options defined in functions in both dataset class and model class. 18 | """ 19 | 20 | def __init__(self, cmd_line=None): 21 | """Reset the class; indicates the class hasn't been initailized""" 22 | self.initialized = False 23 | self.cmd_line = None 24 | if cmd_line is not None: 25 | self.cmd_line = cmd_line.split() 26 | 27 | def initialize(self, parser): 28 | """Define the common options that are used in both training and test.""" 29 | # basic parameters 30 | parser.add_argument('--name', type=str, default='face_recon', help='name of the experiment. It decides where to store samples and models') 31 | parser.add_argument('--gpu_ids', type=str, default='0', help='gpu ids: e.g. 0 0,1,2, 0,2. use -1 for CPU') 32 | parser.add_argument('--checkpoints_dir', type=str, default='/apdcephfs_cq2/share_1290939/kitbai/eg3d-pose-detection-main-master/checkpoints', help='models are saved here') 33 | parser.add_argument('--vis_batch_nums', type=float, default=1, help='batch nums of images for visulization') 34 | parser.add_argument('--eval_batch_nums', type=float, default=float('inf'), help='batch nums of images for evaluation') 35 | parser.add_argument('--use_ddp', type=util.str2bool, nargs='?', const=True, default=True, help='whether use distributed data parallel') 36 | parser.add_argument('--ddp_port', type=str, default='12355', help='ddp port') 37 | parser.add_argument('--display_per_batch', type=util.str2bool, nargs='?', const=True, default=True, help='whether use batch to show losses') 38 | parser.add_argument('--add_image', type=util.str2bool, nargs='?', const=True, default=True, help='whether add image to tensorboard') 39 | parser.add_argument('--world_size', type=int, default=1, help='batch nums of images for evaluation') 40 | 41 | # model parameters 42 | parser.add_argument('--model', type=str, default='facerecon', help='chooses which model to use.') 43 | 44 | # additional parameters 45 | parser.add_argument('--epoch', type=str, default='latest', help='which epoch to load? 
set to latest to use latest cached model') 46 | parser.add_argument('--verbose', action='store_true', help='if specified, print more debugging information') 47 | parser.add_argument('--suffix', default='', type=str, help='customized suffix: opt.name = opt.name + suffix: e.g., {model}_{netG}_size{load_size}') 48 | 49 | self.initialized = True 50 | return parser 51 | 52 | def gather_options(self): 53 | """Initialize our parser with basic options(only once). 54 | Add additional model-specific and dataset-specific options. 55 | These options are defined in the function 56 | in model and dataset classes. 57 | """ 58 | if not self.initialized: # check if it has been initialized 59 | parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) 60 | parser = self.initialize(parser) 61 | 62 | # get the basic options 63 | if self.cmd_line is None: 64 | opt, _ = parser.parse_known_args() 65 | else: 66 | opt, _ = parser.parse_known_args(self.cmd_line) 67 | 68 | # set cuda visible devices 69 | os.environ['CUDA_VISIBLE_DEVICES'] = opt.gpu_ids 70 | 71 | # modify model-related parser options 72 | model_name = opt.model 73 | model_option_setter = models.get_option_setter(model_name) 74 | parser = model_option_setter(parser, self.isTrain) 75 | if self.cmd_line is None: 76 | opt, _ = parser.parse_known_args() # parse again with new defaults 77 | else: 78 | opt, _ = parser.parse_known_args(self.cmd_line) # parse again with new defaults 79 | 80 | # modify dataset-related parser options 81 | if opt.dataset_mode: 82 | dataset_name = opt.dataset_mode 83 | dataset_option_setter = data.get_option_setter(dataset_name) 84 | parser = dataset_option_setter(parser, self.isTrain) 85 | 86 | # save and return the parser 87 | self.parser = parser 88 | if self.cmd_line is None: 89 | return parser.parse_args() 90 | else: 91 | return parser.parse_args(self.cmd_line) 92 | 93 | def print_options(self, opt): 94 | """Print and save options 95 | 96 | It will print both current options and default values(if different). 
97 | It will save options into a text file / [checkpoints_dir] / opt.txt 98 | """ 99 | message = '' 100 | message += '----------------- Options ---------------\n' 101 | for k, v in sorted(vars(opt).items()): 102 | comment = '' 103 | default = self.parser.get_default(k) 104 | if v != default: 105 | comment = '\t[default: %s]' % str(default) 106 | message += '{:>25}: {:<30}{}\n'.format(str(k), str(v), comment) 107 | message += '----------------- End -------------------' 108 | print(message) 109 | 110 | # save to the disk 111 | expr_dir = os.path.join(opt.checkpoints_dir, opt.name) 112 | util.mkdirs(expr_dir) 113 | file_name = os.path.join(expr_dir, '{}_opt.txt'.format(opt.phase)) 114 | try: 115 | with open(file_name, 'wt') as opt_file: 116 | opt_file.write(message) 117 | opt_file.write('\n') 118 | except PermissionError as error: 119 | print("permission error {}".format(error)) 120 | pass 121 | 122 | def parse(self): 123 | """Parse our options, create checkpoints directory suffix, and set up gpu device.""" 124 | opt = self.gather_options() 125 | opt.isTrain = self.isTrain # train or test 126 | 127 | # process opt.suffix 128 | if opt.suffix: 129 | suffix = ('_' + opt.suffix.format(**vars(opt))) if opt.suffix != '' else '' 130 | opt.name = opt.name + suffix 131 | 132 | 133 | # set gpu ids 134 | str_ids = opt.gpu_ids.split(',') 135 | gpu_ids = [] 136 | for str_id in str_ids: 137 | id = int(str_id) 138 | if id >= 0: 139 | gpu_ids.append(id) 140 | opt.world_size = len(gpu_ids) 141 | # if len(opt.gpu_ids) > 0: 142 | # torch.cuda.set_device(gpu_ids[0]) 143 | if opt.world_size == 1: 144 | opt.use_ddp = False 145 | 146 | if opt.phase != 'test': 147 | # set continue_train automatically 148 | if opt.pretrained_name is None: 149 | model_dir = os.path.join(opt.checkpoints_dir, opt.name) 150 | else: 151 | model_dir = os.path.join(opt.checkpoints_dir, opt.pretrained_name) 152 | if os.path.isdir(model_dir): 153 | model_pths = [i for i in os.listdir(model_dir) if i.endswith('pth')] 154 | if os.path.isdir(model_dir) and len(model_pths) != 0: 155 | opt.continue_train= True 156 | 157 | # update the latest epoch count 158 | if opt.continue_train: 159 | if opt.epoch == 'latest': 160 | epoch_counts = [int(i.split('.')[0].split('_')[-1]) for i in model_pths if 'latest' not in i] 161 | if len(epoch_counts) != 0: 162 | opt.epoch_count = max(epoch_counts) + 1 163 | else: 164 | opt.epoch_count = int(opt.epoch) + 1 165 | 166 | 167 | self.print_options(opt) 168 | self.opt = opt 169 | return self.opt 170 | -------------------------------------------------------------------------------- /eg3d-pose-detection/options/test_options.py: -------------------------------------------------------------------------------- 1 | """This script contains the test options for Deep3DFaceRecon_pytorch 2 | """ 3 | 4 | from .base_options import BaseOptions 5 | 6 | 7 | class TestOptions(BaseOptions): 8 | """This class includes test options. 9 | 10 | It also includes shared options defined in BaseOptions. 11 | """ 12 | 13 | def initialize(self, parser): 14 | parser = BaseOptions.initialize(self, parser) # define shared options 15 | parser.add_argument('--phase', type=str, default='test', help='train, val, test, etc') 16 | parser.add_argument('--dataset_mode', type=str, default=None, help='chooses how datasets are loaded. 
[None | flist]') 17 | parser.add_argument('--img_folder', type=str, default='examples', help='folder for test images.') 18 | parser.add_argument('--start', type=int, default=0, help='start folder') 19 | parser.add_argument('--skip_model', action='store_true', help='whether to run model') 20 | 21 | # Dropout and Batchnorm has different behavior during training and test. 22 | self.isTrain = False 23 | return parser 24 | -------------------------------------------------------------------------------- /eg3d-pose-detection/options/train_options.py: -------------------------------------------------------------------------------- 1 | """This script contains the training options for Deep3DFaceRecon_pytorch 2 | """ 3 | 4 | from .base_options import BaseOptions 5 | from util import util 6 | 7 | class TrainOptions(BaseOptions): 8 | """This class includes training options. 9 | 10 | It also includes shared options defined in BaseOptions. 11 | """ 12 | 13 | def initialize(self, parser): 14 | parser = BaseOptions.initialize(self, parser) 15 | # dataset parameters 16 | # for train 17 | parser.add_argument('--data_root', type=str, default='./', help='dataset root') 18 | parser.add_argument('--flist', type=str, default='datalist/train/masks.txt', help='list of mask names of training set') 19 | parser.add_argument('--batch_size', type=int, default=32) 20 | parser.add_argument('--dataset_mode', type=str, default='flist', help='chooses how datasets are loaded. [None | flist]') 21 | parser.add_argument('--serial_batches', action='store_true', help='if true, takes images in order to make batches, otherwise takes them randomly') 22 | parser.add_argument('--num_threads', default=4, type=int, help='# threads for loading data') 23 | parser.add_argument('--max_dataset_size', type=int, default=float("inf"), help='Maximum number of samples allowed per dataset. 
If the dataset directory contains more than max_dataset_size, only a subset is loaded.') 24 | parser.add_argument('--preprocess', type=str, default='shift_scale_rot_flip', help='scaling and cropping of images at load time [shift_scale_rot_flip | shift_scale | shift | shift_rot_flip ]') 25 | parser.add_argument('--use_aug', type=util.str2bool, nargs='?', const=True, default=True, help='whether use data augmentation') 26 | 27 | # for val 28 | parser.add_argument('--flist_val', type=str, default='datalist/val/masks.txt', help='list of mask names of val set') 29 | parser.add_argument('--batch_size_val', type=int, default=32) 30 | 31 | 32 | # visualization parameters 33 | parser.add_argument('--display_freq', type=int, default=1000, help='frequency of showing training results on screen') 34 | parser.add_argument('--print_freq', type=int, default=100, help='frequency of showing training results on console') 35 | 36 | # network saving and loading parameters 37 | parser.add_argument('--save_latest_freq', type=int, default=5000, help='frequency of saving the latest results') 38 | parser.add_argument('--save_epoch_freq', type=int, default=1, help='frequency of saving checkpoints at the end of epochs') 39 | parser.add_argument('--evaluation_freq', type=int, default=5000, help='evaluation freq') 40 | parser.add_argument('--save_by_iter', action='store_true', help='whether saves model by iteration') 41 | parser.add_argument('--continue_train', action='store_true', help='continue training: load the latest model') 42 | parser.add_argument('--epoch_count', type=int, default=1, help='the starting epoch count, we save the model by , +, ...') 43 | parser.add_argument('--phase', type=str, default='train', help='train, val, test, etc') 44 | parser.add_argument('--pretrained_name', type=str, default=None, help='resume training from another checkpoint') 45 | 46 | # training parameters 47 | parser.add_argument('--n_epochs', type=int, default=20, help='number of epochs with the initial learning rate') 48 | parser.add_argument('--lr', type=float, default=0.0001, help='initial learning rate for adam') 49 | parser.add_argument('--lr_policy', type=str, default='step', help='learning rate policy. [linear | step | plateau | cosine]') 50 | parser.add_argument('--lr_decay_epochs', type=int, default=10, help='multiply by a gamma every lr_decay_epochs epoches') 51 | 52 | self.isTrain = True 53 | return parser 54 | -------------------------------------------------------------------------------- /eg3d-pose-detection/process_test_video.py: -------------------------------------------------------------------------------- 1 | """ 2 | Processes a directory containing *.jpg/png and outputs crops and poses. 3 | """ 4 | import glob 5 | import os 6 | import subprocess 7 | import argparse 8 | parser = argparse.ArgumentParser() 9 | parser.add_argument('--input_dir', default='/media/data6/ericryanchan/mafu/Deep3DFaceRecon_pytorch/test_images') 10 | parser.add_argument('--gpu', default=0) 11 | args = parser.parse_args() 12 | 13 | print('Processing images:', sorted(glob.glob(os.path.join(args.input_dir, "*")))) 14 | 15 | # Compute facial landmarks. 
16 | print("Computing facial landmarks for model...") 17 | cmd = "python3.6 /eg3d-pose-detection/batch_mtcnn.py" 18 | input_flag = " --in_root " + args.input_dir 19 | cmd += input_flag 20 | # subprocess.run([cmd], shell=True, check=True) 21 | os.system(cmd) 22 | 23 | 24 | print("Running smooth...") 25 | cmd = "python3.6 /eg3d-pose-detection/smooth.py" 26 | input_flag = " --img_folder=" + args.input_dir 27 | cmd += input_flag 28 | os.system(cmd) 29 | 30 | # Run model inference to produce crops and raw poses. 31 | print("Running model inference...") 32 | cmd = "python3.6 /eg3d-pose-detection/test.py" 33 | input_flag = " --img_folder=" + args.input_dir 34 | gpu_flag = " --gpu_ids=" + str(args.gpu) 35 | model_name_flag = " --name=face_recon" 36 | model_file_flag = " --epoch=20 " 37 | cmd += input_flag + gpu_flag + model_name_flag + model_file_flag 38 | # subprocess.run([cmd], shell=True, check=True) 39 | os.system(cmd) 40 | 41 | # Perform final cropping of 1024x1024 images. 42 | print("Processing final crops...") 43 | cmd = "python3.6 /eg3d-pose-detection/crop_images.py" 44 | input_flag = " --indir " + args.input_dir 45 | output_flag = " --outdir " + os.path.join(args.input_dir, 'cropped_images') 46 | cmd += input_flag + output_flag 47 | # subprocess.run([cmd], shell=True, check=True) 48 | os.system(cmd) 49 | 50 | # Process poses into our representation -- produces a cameras.json file. 51 | print("Processing final poses...") 52 | cmd = "python3.6 /eg3d-pose-detection/3dface2idr.py" 53 | input_flag = " --in_root " + os.path.join(args.input_dir, "epoch_20_000000") 54 | output_flag = " --out_root " + os.path.join(args.input_dir, "cropped_images") 55 | 56 | cmd += input_flag + output_flag 57 | # subprocess.run([cmd], shell=True, check=True) 58 | os.system(cmd) 59 | 60 | print("Transforming...") 61 | cmd = "python3.6 /eg3d-pose-detection/camera2label.py" 62 | input_flag = " --in_root " + args.input_dir 63 | cmd += input_flag 64 | # subprocess.run([cmd], shell=True, check=True) 65 | os.system(cmd) 66 | -------------------------------------------------------------------------------- /eg3d-pose-detection/smooth.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from options.test_options import TestOptions 4 | from data import create_dataset 5 | from models import create_model 6 | from util.visualizer import MyVisualizer 7 | from util.preprocess import align_img 8 | from PIL import Image 9 | import numpy as np 10 | from util.load_mats import load_lm3d 11 | 12 | from data.flist_dataset import default_flist_reader 13 | from scipy.io import loadmat, savemat 14 | 15 | import json 16 | 17 | from scipy.ndimage import gaussian_filter1d 18 | 19 | 20 | def get_data_path(root='examples'): 21 | im_path = [os.path.join(root, i) for i in sorted(os.listdir(root)) if i.endswith('png') or i.endswith('jpg')] 22 | im_path = sorted(im_path , key=lambda x:int(x.split('/')[-1].split('.')[0])) 23 | lm_path = [i.replace('png', 'txt').replace('jpg', 'txt') for i in im_path] 24 | lm_path = [os.path.join(i.replace(i.split(os.path.sep)[-1],''),'detections',i.split(os.path.sep)[-1]) for i in lm_path] 25 | return lm_path 26 | 27 | 28 | def read_data( lm_path): 29 | lms = [] 30 | # for i in range(10): 31 | for i in range(len(lm_path)): 32 | #im = Image.open(im_path).convert('RGB') 33 | # _, H = im.size 34 | if not os.path.isfile(lm_path[i]): 35 | continue 36 | lm = np.loadtxt(lm_path[i]).astype(np.float32) 37 | # print(lm) 38 | lms.append(lm) 39 | lms = 
np.array(lms) 40 | lms = gaussian_filter1d(lms, 2, 0) 41 | 42 | # for i in range(10): 43 | # print(lms[i]) 44 | for i in range(len(lm_path)): 45 | if not os.path.isfile(lm_path[i]): 46 | continue 47 | np.savetxt(lm_path[i], lms[i]) 48 | 49 | 50 | def main(rank, opt, name='examples'): 51 | 52 | lm_path = get_data_path(name) 53 | read_data(lm_path) 54 | 55 | if __name__ == '__main__': 56 | opt = TestOptions().parse() # get test options 57 | main(0, opt,opt.img_folder) 58 | -------------------------------------------------------------------------------- /eg3d-pose-detection/test: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /eg3d-pose-detection/test.py: -------------------------------------------------------------------------------- 1 | """This script is the test script for Deep3DFaceRecon_pytorch 2 | """ 3 | 4 | import os 5 | import torch 6 | from options.test_options import TestOptions 7 | from data import create_dataset 8 | from models import create_model 9 | from util.visualizer import MyVisualizer 10 | from util.preprocess import align_img 11 | from PIL import Image 12 | import numpy as np 13 | from util.load_mats import load_lm3d 14 | 15 | from data.flist_dataset import default_flist_reader 16 | from scipy.io import loadmat, savemat 17 | 18 | import json 19 | 20 | def get_data_path(root='examples'): 21 | im_path = [os.path.join(root, i) for i in sorted(os.listdir(root)) if i.endswith('png') or i.endswith('jpg')] 22 | lm_path = [i.replace('png', 'txt').replace('jpg', 'txt') for i in im_path] 23 | lm_path = [os.path.join(i.replace(i.split(os.path.sep)[-1],''),'detections',i.split(os.path.sep)[-1]) for i in lm_path] 24 | return im_path, lm_path 25 | 26 | def read_data(im_path, lm_path, lm3d_std, to_tensor=True, rescale_factor=466.285): 27 | im = Image.open(im_path).convert('RGB') 28 | _, H = im.size 29 | lm = np.loadtxt(lm_path).astype(np.float32) 30 | lm = lm.reshape([-1, 2]) 31 | lm[:, -1] = H - 1 - lm[:, -1] 32 | _, im_pil, lm, _, im_high = align_img(im, lm, lm3d_std, rescale_factor=rescale_factor) 33 | # im_high.save(os.path.join('/apdcephfs_cq2/share_1290939/kitbai/LIA-3d/datasets/our_dataset/test/3', 'crop_1024', 'imagehigh.png')) 34 | if to_tensor: 35 | im = torch.tensor(np.array(im_pil)/255., dtype=torch.float32).permute(2, 0, 1).unsqueeze(0) 36 | lm = torch.tensor(lm).unsqueeze(0) 37 | else: 38 | im = im_pil 39 | return im, lm, im_pil, im_high 40 | 41 | def main(rank, opt, name='examples'): 42 | device = torch.device(rank) 43 | torch.cuda.set_device(device) 44 | model = create_model(opt) 45 | model.setup(opt) 46 | model.device = device 47 | model.parallelize() 48 | model.eval() 49 | visualizer = MyVisualizer(opt) 50 | print("ROOT") 51 | print(name) 52 | im_path, lm_path = get_data_path(name) 53 | lm3d_std = load_lm3d(opt.bfm_folder) 54 | 55 | cropping_params = {} 56 | 57 | out_dir_crop1024 = os.path.join(name, "crop_1024") 58 | if not os.path.exists(out_dir_crop1024): 59 | os.makedirs(out_dir_crop1024) 60 | out_dir = os.path.join(name, 'epoch_%s_%06d'%(opt.epoch, 0)) 61 | if not os.path.exists(out_dir): 62 | os.makedirs(out_dir) 63 | for i in range(len(im_path)): 64 | print(i, im_path[i]) 65 | img_name = im_path[i].split(os.path.sep)[-1].replace('.png','').replace('.jpg','') 66 | if not os.path.isfile(lm_path[i]): 67 | continue 68 | 69 | # 2 passes for cropping image for NeRF and for pose extraction 70 | for r in range(2): 71 | if r==0: 72 | rescale_factor = 
300 # optimized for NeRF training 73 | center_crop_size = 700 74 | output_size = 512 75 | 76 | # left = int(im_high.size[0]/2 - center_crop_size/2) 77 | # upper = int(im_high.size[1]/2 - center_crop_size/2) 78 | # right = left + center_crop_size 79 | # lower = upper + center_crop_size 80 | # im_cropped = im_high.crop((left, upper, right,lower)) 81 | # im_cropped = im_cropped.resize((output_size, output_size), resample=Image.LANCZOS) 82 | cropping_params[os.path.basename(im_path[i])] = { 83 | 'lm': np.loadtxt(lm_path[i]).astype(np.float32).tolist(), 84 | 'lm3d_std': lm3d_std.tolist(), 85 | 'rescale_factor': rescale_factor, 86 | 'center_crop_size': center_crop_size, 87 | 'output_size': output_size} 88 | 89 | # im_high.save(os.path.join(out_dir_crop1024, img_name+'.png'), compress_level=0) 90 | # im_cropped.save(os.path.join(out_dir_crop1024, img_name+'.png'), compress_level=0) 91 | elif not opt.skip_model: 92 | rescale_factor = 466.285 93 | im_tensor, lm_tensor, _, im_high = read_data(im_path[i], lm_path[i], lm3d_std, rescale_factor=rescale_factor) 94 | 95 | data = { 96 | 'imgs': im_tensor, 97 | 'lms': lm_tensor 98 | } 99 | model.set_input(data) # unpack data from data loader 100 | model.test() # run inference 101 | visuals = model.get_current_visuals() # get image results 102 | visualizer.display_current_results(visuals, 0, opt.epoch, dataset=name.split(os.path.sep)[-1], 103 | save_results=True, count=i, name=img_name, add_image=False) 104 | 105 | model.save_coeff(os.path.join(out_dir,img_name+'.mat')) # save predicted coefficients 106 | 107 | with open(os.path.join(name, 'cropping_params.json'), 'w') as outfile: 108 | json.dump(cropping_params, outfile, indent=4) 109 | 110 | if __name__ == '__main__': 111 | opt = TestOptions().parse() # get test options 112 | main(0, opt,opt.img_folder) 113 | 114 | 115 | # python test.py --epoch 20 --img_folder test_images/ --gpu_ids 0 --------------------------------------------------------------------------------
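Note (illustrative, not part of the repository): the loss functions defined in `models/losses.py` above are consumed by the face reconstruction model during fitting; the sketch below shows how `photo_loss`, `landmark_loss` and `reg_loss` can be combined into a single objective. The tensor shapes, coefficient sizes and loss weights are placeholder assumptions for demonstration only, and the import assumes the script is run from the `eg3d-pose-detection` directory.

```
import torch

# Run from the eg3d-pose-detection/ directory so that `models` is importable.
from models.losses import photo_loss, landmark_loss, reg_loss

B, H, W = 2, 224, 224
pred_img = torch.rand(B, 3, H, W)       # rendered face, range (0, 1)
gt_img = torch.rand(B, 3, H, W)         # input frame, range (0, 1)
skin_mask = torch.ones(B, 1, H, W)      # 1 where the rendered face covers the image

pred_lm = torch.rand(B, 68, 2) * H      # projected 3DMM landmarks (pixels)
gt_lm = torch.rand(B, 68, 2) * H        # detected landmarks (pixels)

coeffs = {                              # 3DMM coefficient split; sizes are placeholders
    'id': torch.randn(B, 80),
    'exp': torch.randn(B, 64),
    'tex': torch.randn(B, 80),
    'gamma': torch.randn(B, 27),        # 9 SH lighting coefficients per RGB channel
}

l_photo = photo_loss(pred_img, gt_img, skin_mask)
l_lm = landmark_loss(pred_lm, gt_lm)
l_creg, l_gamma = reg_loss(coeffs)      # opt=None -> unit weights for id/exp/tex

# The weights below are illustrative, not values taken from this repository.
total = 1.9 * l_photo + 1.6e-3 * l_lm + 3e-4 * l_creg + 10.0 * l_gamma
print(float(total))
```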