├── README.md
├── assets
│   ├── pipeline-2.png
│   └── test
├── code
│   ├── cam_utils.py
│   ├── dataset.py
│   ├── networks
│   │   ├── __init__.py
│   │   ├── encoder3d.py
│   │   └── headnerf.py
│   ├── pretrained_models
│   │   └── eg3d
│   │       └── test
│   ├── run_recon_video_3dmm.py
│   ├── run_recon_video_audio.py
│   ├── run_recon_video_rgb.py
│   ├── test
│   ├── train_3dmm.py
│   ├── train_audio.py
│   ├── train_rgb.py
│   ├── trainer_3dmm.py
│   ├── trainer_audio.py
│   └── trainer_rgb.py
└── eg3d-pose-detection
    ├── 3dface2idr.py
    ├── batch_mtcnn.py
    ├── camera2label.py
    ├── crop_images.py
    ├── data
    │   ├── __init__.py
    │   ├── base_dataset.py
    │   ├── flist_dataset.py
    │   ├── image_folder.py
    │   └── template_dataset.py
    ├── models
    │   ├── __init__.py
    │   ├── arcface_torch
    │   │   ├── README.md
    │   │   ├── backbones
    │   │   │   ├── __init__.py
    │   │   │   ├── iresnet.py
    │   │   │   ├── iresnet2060.py
    │   │   │   ├── mobilefacenet.py
    │   │   │   └── vit.py
    │   │   ├── configs
    │   │   │   ├── 3millions.py
    │   │   │   ├── __init__.py
    │   │   │   ├── base.py
    │   │   │   ├── glint360k_mbf.py
    │   │   │   ├── glint360k_r100.py
    │   │   │   ├── glint360k_r50.py
    │   │   │   ├── ms1mv2_mbf.py
    │   │   │   ├── ms1mv2_r100.py
    │   │   │   ├── ms1mv2_r50.py
    │   │   │   ├── ms1mv3_mbf.py
    │   │   │   ├── ms1mv3_r100.py
    │   │   │   ├── ms1mv3_r50.py
    │   │   │   ├── wf12m_conflict_r50.py
    │   │   │   ├── wf12m_conflict_r50_pfc03_filter04.py
    │   │   │   ├── wf12m_flip_pfc01_filter04_r50.py
    │   │   │   ├── wf12m_flip_r50.py
    │   │   │   ├── wf12m_mbf.py
    │   │   │   ├── wf12m_pfc02_r100.py
    │   │   │   ├── wf12m_r100.py
    │   │   │   ├── wf12m_r50.py
    │   │   │   ├── wf42m_pfc0008_32gpu_r100.py
    │   │   │   ├── wf42m_pfc02_16gpus_mbf_bs8k.py
    │   │   │   ├── wf42m_pfc02_16gpus_r100.py
    │   │   │   ├── wf42m_pfc02_16gpus_r50_bs8k.py
    │   │   │   ├── wf42m_pfc02_32gpus_r50_bs4k.py
    │   │   │   ├── wf42m_pfc02_8gpus_r50_bs4k.py
    │   │   │   ├── wf42m_pfc02_r100.py
    │   │   │   ├── wf42m_pfc02_r100_16gpus.py
    │   │   │   ├── wf42m_pfc02_r100_32gpus.py
    │   │   │   ├── wf42m_pfc03_32gpu_r100.py
    │   │   │   ├── wf42m_pfc03_32gpu_r18.py
    │   │   │   ├── wf42m_pfc03_32gpu_r200.py
    │   │   │   ├── wf42m_pfc03_32gpu_r50.py
    │   │   │   ├── wf42m_pfc03_40epoch_64gpu_vit_b.py
    │   │   │   ├── wf42m_pfc03_40epoch_64gpu_vit_l.py
    │   │   │   ├── wf42m_pfc03_40epoch_64gpu_vit_s.py
    │   │   │   ├── wf42m_pfc03_40epoch_64gpu_vit_t.py
    │   │   │   ├── wf42m_pfc03_40epoch_8gpu_vit_t.py
    │   │   │   ├── wf4m_mbf.py
    │   │   │   ├── wf4m_r100.py
    │   │   │   └── wf4m_r50.py
    │   │   ├── dataset.py
    │   │   ├── dist.sh
    │   │   ├── docs
    │   │   │   ├── eval.md
    │   │   │   ├── install.md
    │   │   │   ├── install_dali.md
    │   │   │   ├── modelzoo.md
    │   │   │   ├── prepare_webface42m.md
    │   │   │   └── speed_benchmark.md
    │   │   ├── eval
    │   │   │   ├── __init__.py
    │   │   │   └── verification.py
    │   │   ├── eval_ijbc.py
    │   │   ├── flops.py
    │   │   ├── inference.py
    │   │   ├── losses.py
    │   │   ├── lr_scheduler.py
    │   │   ├── onnx_helper.py
    │   │   ├── onnx_ijbc.py
    │   │   ├── partial_fc.py
    │   │   ├── requirement.txt
    │   │   ├── run.sh
    │   │   ├── torch2onnx.py
    │   │   ├── train.py
    │   │   └── utils
    │   │       ├── __init__.py
    │   │       ├── plot.py
    │   │       ├── utils_callbacks.py
    │   │       ├── utils_config.py
    │   │       ├── utils_distributed_sampler.py
    │   │       └── utils_logging.py
    │   ├── base_model.py
    │   ├── bfm.py
    │   ├── facerecon_model.py
    │   ├── losses.py
    │   ├── networks.py
    │   └── template_model.py
    ├── options
    │   ├── __init__.py
    │   ├── base_options.py
    │   ├── test_options.py
    │   └── train_options.py
    ├── process_test_video.py
    ├── smooth.py
    ├── test
    └── test.py
/README.md:
--------------------------------------------------------------------------------
1 | # HFA-GP: High-fidelity Facial Avatar Reconstruction from Monocular Video with Generative Priors
2 |
3 |
4 |
5 |
6 | ## HFA-GP
7 | **HFA-GP** is a framework for reconstructing a high-fidelity facial avatar from a monocular video.
8 | This is the official implementation of *High-fidelity Facial Avatar Reconstruction from Monocular Video with Generative Priors*.
9 |
10 | ## Abstract
11 | High-fidelity facial avatar reconstruction from a monocular video is a significant research problem in computer graphics and computer vision. Recently, Neural Radiance Fields (NeRF) have shown impressive novel-view rendering results and have been considered for facial avatar reconstruction. However, the complex facial dynamics and missing 3D information in monocular videos raise significant challenges for faithful facial reconstruction. In this work, we propose a new method for NeRF-based facial avatar reconstruction that utilizes a 3D-aware generative prior. Different from existing works that depend on a conditional deformation field for dynamic modeling, we propose to learn a personalized generative prior, which is formulated as a local and low-dimensional subspace in the latent space of a 3D-GAN. We propose an efficient method to construct the personalized generative prior based on a small set of facial images of a given individual. After learning, it allows photo-realistic rendering from novel views, and face reenactment can be realized by navigating the latent space. Our proposed method is applicable to different driving signals, including RGB images, 3DMM coefficients, and audio. Compared with existing works, we obtain superior novel view synthesis results and faithful face reenactment performance.
12 |
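As a rough sketch of the core idea (not the repository's exact implementation; names and shapes here are illustrative), the personalized generative prior can be viewed as a small set of learnable latent bases, so a driving signal only has to predict low-dimensional mixing weights (cf. the `--latent_dim_shape` option and `get_latent` in the trainers):

```
import torch

# Illustrative sketch only: a personalized prior modeled as a learnable,
# low-dimensional subspace of the 3D-GAN latent space.
latent_dim_shape, latent_dim_style = 50, 512
bases = torch.nn.Parameter(torch.randn(latent_dim_shape, latent_dim_style))

def get_latent(weights):        # weights: (batch, latent_dim_shape)
    # each latent code is a linear combination of the personalized bases
    return weights @ bases      # -> (batch, latent_dim_style)
```
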
13 | ## Preprocessing data
14 |
15 | ```
16 | python3 ./eg3d-pose-detection/process_test_video.py --input_dir <path_to_video_frames>
17 | ```
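The input directory should contain the video frames as individual `.jpg`/`.png` images (this is what `batch_mtcnn.py` scans for). If you start from a video, the frames can first be extracted with ffmpeg, for example (paths are placeholders):

```
mkdir -p <path_to_video_frames>
ffmpeg -i input_video.mp4 <path_to_video_frames>/%05d.png
```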
18 |
19 |
20 | ## Training on monocular video
21 | audio-driven:
22 | ```
23 | python3 ./code/train_audio.py --dataset 'ad_dataset' --person_1 'english_m' --exp_path './code/exps/' --exp_name 'ad-english_w_e'
24 | ```
25 |
26 | 3DMM-driven:
27 | ```
28 | python3 ./code/train_3dmm.py --dataset 'nerface_dataset' --person_1 'person_3' --exp_path './code/exps/' --exp_name '1-nerface3-3dmm2'
29 | ```
30 |
31 | RGB-driven:
32 | ```
33 | python3 ./code/train_rgb.py --dataset 'nerface_dataset' --person 'person_3' --exp_path './code/exps/' --exp_name '1-nerface-3-2'
34 | ```
35 |
36 | ## Performing face reenactment
37 | audio-driven:
38 | ```
39 | python3 ./code/run_recon_video_audio.py --dataset 'ad_dataset' --person_1 'english_w' --demo_name 'english_w' --dataset_type 'val' --model_path './code/exps/ad-english_m/checkpoint/checkpoint.pt' --cat_video
40 | ```
41 | For audio-driven experiments, DeepSpeech is required to extract features from the audio; this part follows the code of [AD-NeRF](https://github.com/YudongGuo/AD-NeRF). First use ffmpeg to extract the audio track in WAV format, then extract the DeepSpeech features. The extracted feature file should be named `aud.npy`.
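For example, the WAV track can be obtained with ffmpeg as shown below (file names are placeholders; the 16 kHz sampling rate matches what DeepSpeech expects, and the feature extraction itself is done with AD-NeRF's scripts):

```
ffmpeg -i input_video.mp4 -f wav -ar 16000 aud.wav
```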
42 |
43 | 3DMM-driven:
44 | ```
45 | python3 ./code/run_recon_video_3dmm.py --dataset 'nerface_dataset' --person_1 'person_3' --demo_name 'nerface3' --dataset_type 'test' --model_path './code/exps/1-nerface3/checkpoint/checkpoint.pt' --cat_video
46 | ```
47 |
48 | RGB-driven:
49 | ```
50 | python3 ./code/run_recon_video_rgb.py --dataset 'nerface_dataset' --person 'person_3' --demo_name '1-nerface3-2-new' --model_path './code/exps/1-nerface/checkpoint/checkpoint.pt' --cat_video --dataset_type 'test' --suffix '.png' --latent_dim_shape 50
51 | ```
52 |
53 | ## Citation ##
54 | Please cite the following paper if you use this repository in your research.
55 | ```
56 | @InProceedings{Bai_2023_CVPR,
57 | author = {Bai, Yunpeng and Fan, Yanbo and Wang, Xuan and Zhang, Yong and Sun, Jingxiang and Yuan, Chun and Shan, Ying},
58 | title = {High-Fidelity Facial Avatar Reconstruction From Monocular Video With Generative Priors},
59 | booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
60 | month = {June},
61 | year = {2023},
62 | pages = {4541-4551}
63 | }
64 | ```
65 |
66 |
67 |
--------------------------------------------------------------------------------
/assets/pipeline-2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bbaaii/HFA-GP/aa2c15a61d8ddd182189153914098a2af0edfb0c/assets/pipeline-2.png
--------------------------------------------------------------------------------
/assets/test:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/code/cam_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import random
3 | import math
4 |
5 |
6 | def normalize_vecs(vectors: torch.Tensor) -> torch.Tensor:
7 | """
8 | Normalize vector lengths.
9 | """
10 | return vectors / (torch.norm(vectors, dim=-1, keepdim=True))
11 |
12 | def sample_camera_positions(device, n=1, r=1, horizontal_stddev=1, vertical_stddev=1, horizontal_mean=math.pi*0.5, vertical_mean=math.pi*0.5, mode='normal'):
13 | """
14 | Samples n random locations along a sphere of radius r. Uses the specified distribution.
15 | Theta is yaw in radians (-pi, pi)
16 | Phi is pitch in radians (0, pi)
17 | """
18 |
19 | if mode == 'uniform':
20 | theta = (torch.rand((n, 1), device=device) - 0.5) * 2 * horizontal_stddev + horizontal_mean
21 | phi = (torch.rand((n, 1), device=device) - 0.5) * 2 * vertical_stddev + vertical_mean
22 |
23 | elif mode == 'normal' or mode == 'gaussian':
24 | theta = torch.randn((n, 1), device=device) * horizontal_stddev + horizontal_mean
25 | phi = torch.randn((n, 1), device=device) * vertical_stddev + vertical_mean
26 |
27 | elif mode == 'hybrid':
28 | if random.random() < 0.5:
29 | theta = (torch.rand((n, 1), device=device) - 0.5) * 2 * horizontal_stddev * 2 + horizontal_mean
30 | phi = (torch.rand((n, 1), device=device) - 0.5) * 2 * vertical_stddev * 2 + vertical_mean
31 | else:
32 | theta = torch.randn((n, 1), device=device) * horizontal_stddev + horizontal_mean
33 | phi = torch.randn((n, 1), device=device) * vertical_stddev + vertical_mean
34 |
35 | elif mode == 'truncated_gaussian':
36 | theta = truncated_normal_(torch.zeros((n, 1), device=device)) * horizontal_stddev + horizontal_mean
37 | phi = truncated_normal_(torch.zeros((n, 1), device=device)) * vertical_stddev + vertical_mean
38 |
39 | elif mode == 'spherical_uniform':
40 | theta = (torch.rand((n, 1), device=device) - .5) * 2 * horizontal_stddev + horizontal_mean
41 | v_stddev, v_mean = vertical_stddev / math.pi, vertical_mean / math.pi
42 | v = ((torch.rand((n,1), device=device) - .5) * 2 * v_stddev + v_mean)
43 | v = torch.clamp(v, 1e-5, 1 - 1e-5)
44 | phi = torch.arccos(1 - 2 * v)
45 |
46 | else:
47 | # Just use the mean.
48 | theta = torch.ones((n, 1), device=device, dtype=torch.float) * horizontal_mean
49 | phi = torch.ones((n, 1), device=device, dtype=torch.float) * vertical_mean
50 |
51 | phi = torch.clamp(phi, 1e-5, math.pi - 1e-5)
52 |
53 | output_points = torch.zeros((n, 3), device=device)
54 | output_points[:, 0:1] = r*torch.sin(phi) * torch.cos(theta)
55 | output_points[:, 2:3] = r*torch.sin(phi) * torch.sin(theta)
56 | output_points[:, 1:2] = r*torch.cos(phi)
57 |
58 | return output_points, phi, theta
59 |
60 |
61 |
62 | def create_cam2world_matrix(forward_vector, origin, device=None):
63 | """Takes in the direction the camera is pointing and the camera origin and returns a cam2world matrix."""
64 |
65 | forward_vector = normalize_vecs(forward_vector)
66 | up_vector = torch.tensor([0, 1, 0], dtype=torch.float, device=device).expand_as(forward_vector)
67 |
68 | left_vector = normalize_vecs(torch.cross(up_vector, forward_vector, dim=-1))
69 |
70 | up_vector = normalize_vecs(torch.cross(forward_vector, left_vector, dim=-1))
71 |
72 | rotation_matrix = torch.eye(4, device=device).unsqueeze(0).repeat(forward_vector.shape[0], 1, 1)
73 | rotation_matrix[:, :3, :3] = torch.stack((-left_vector, up_vector, -forward_vector), axis=-1)
74 |
75 | translation_matrix = torch.eye(4, device=device).unsqueeze(0).repeat(forward_vector.shape[0], 1, 1)
76 | translation_matrix[:, :3, 3] = origin
77 |
78 | cam2world = translation_matrix @ rotation_matrix
79 |
80 | return cam2world
81 |
--------------------------------------------------------------------------------
/code/networks/__init__.py:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/code/pretrained_models/eg3d/test:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/code/test:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/code/train_3dmm.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import torch
4 | from torch.utils import data
5 | from dataset import HeadData_3DMM
6 | import torchvision
7 | import torchvision.transforms as transforms
8 | from trainer_3dmm import Trainer
9 | from torch.utils.tensorboard import SummaryWriter
10 | import torch.distributed as dist
11 | import torch.multiprocessing as mp
12 |
13 | torch.backends.cudnn.enabled = True
14 | torch.backends.cudnn.benchmark = True
15 |
16 |
17 | def data_sampler(dataset, shuffle):
18 | if shuffle:
19 | return data.RandomSampler(dataset)
20 | else:
21 | return data.SequentialSampler(dataset)
22 |
23 |
24 | def sample_data(loader):
25 | while True:
26 | for batch in loader:
27 | yield batch
28 |
29 |
30 | def display_img(idx, img, name, writer, args):
31 | img = img.clamp(-1, 1)
32 | img = ((img - img.min()) / (img.max() - img.min())).data
33 | torchvision.utils.save_image(img, args.exp_path + args.exp_name + '/display/'+str(idx)+name+'.png')
34 |
35 | writer.add_images(tag='%s' % (name), global_step=idx, img_tensor=img)
36 |
37 | def display_bases(imgs, name, args):
38 | for idx in range(len(imgs)):
39 | img = imgs[idx]
40 | img = img.clamp(-1, 1)
41 | img = ((img - img.min()) / (img.max() - img.min())).data
42 | torchvision.utils.save_image(img, args.exp_path + args.exp_name + '/bases/'+str(idx)+name+'.png')
43 |
44 | def write_loss(i,l2_loss_3dmm, l2_loss, lpips_loss, writer):
45 | writer.add_scalar('3dmm_loss', l2_loss_3dmm.item(), i)
46 | writer.add_scalar('l2_loss', l2_loss.item(), i)
47 | writer.add_scalar('lpips_loss', lpips_loss.item(), i)
48 | writer.flush()
49 |
50 | def display_bases(imgs, name, args):
51 | for idx in range(len(imgs)):
52 | img = imgs[idx]
53 | img = img.clamp(-1, 1)
54 | img = ((img - img.min()) / (img.max() - img.min())).data
55 | torchvision.utils.save_image(img, args.exp_path + args.exp_name + '/bases/'+str(idx)+name+'.png')
56 |
57 | def ddp_setup(args, rank, world_size):
58 | os.environ['MASTER_ADDR'] = args.addr
59 | os.environ['MASTER_PORT'] = args.port
60 |
61 | dist.init_process_group("nccl", rank=rank, world_size=world_size)
62 |
63 |
64 | def main(rank, world_size, args):
65 | # init distributed computing
66 | ddp_setup(args, rank, world_size)
67 | torch.cuda.set_device(rank)
68 | device = torch.device("cuda")
69 |
70 | # make logging folder
71 | log_path = os.path.join(args.exp_path, args.exp_name + '/log')
72 | checkpoint_path = os.path.join(args.exp_path, args.exp_name + '/checkpoint')
73 | display_path = os.path.join(args.exp_path, args.exp_name + '/display')
74 | bases_path = os.path.join(args.exp_path, args.exp_name + '/bases')
75 | os.makedirs(log_path, exist_ok=True)
76 | os.makedirs(checkpoint_path, exist_ok=True)
77 | os.makedirs(display_path, exist_ok=True)
78 | os.makedirs(bases_path, exist_ok=True)
79 | writer = SummaryWriter(log_path)
80 |
81 | print('==> preparing dataset')
82 | transform = torchvision.transforms.Compose([
83 | transforms.Resize(args.size),
84 | transforms.ToTensor(),
85 | transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])
86 |
87 | dataset = HeadData_3DMM('train', transform, dataset = args.dataset , person = args.person_1)
88 | dataset_test = HeadData_3DMM('test', transform, dataset = args.dataset , person = args.person_1 )
89 |
90 | loader = data.DataLoader(
91 | dataset,
92 |
93 | batch_size=args.batch_size // world_size,
94 | sampler=data.distributed.DistributedSampler(dataset, num_replicas=world_size, rank=rank, shuffle=True),
95 | pin_memory=True,
96 | drop_last=False,
97 | )
98 |
99 | loader_test = data.DataLoader(
100 | dataset_test,
101 | batch_size=1,
102 | sampler=data.distributed.DistributedSampler(dataset_test, num_replicas=world_size, rank=rank, shuffle=False),
103 | pin_memory=True,
104 | drop_last=False,
105 | )
106 |
107 | loader = sample_data(loader)
108 | loader_test = sample_data(loader_test)
109 |
110 | print('==> initializing trainer')
111 | # Trainer
112 | trainer = Trainer(args, device, rank)
113 |
114 |
115 | print('==> training')
116 | pbar = range(args.iter)
117 | for idx in pbar:
118 | i = idx + args.start_iter
119 |
120 |
121 | real_image, label, params = next(loader)
122 | real_image = real_image.to(rank, non_blocking=True)
123 | label = label.to(rank, non_blocking=True)
124 | params = params.to(rank, non_blocking=True)
125 |
126 |
127 | # update generator
128 | l2_loss_3dmm, l2_loss, loss_lpips, generated_image = trainer.gen_update(real_image, label, params)
129 |
130 |
131 | if rank == 0:
132 | # write to log
133 | write_loss(idx, l2_loss_3dmm, l2_loss, loss_lpips, writer)
134 | if (i+1) >= args.tune_iter:
135 | # print('begin training nerf')
136 | trainer.tune_generator()
137 | # display
138 | if (i+1) % args.display_freq == 0 and rank == 0:
139 | print("[Iter %d/%d] [3dmm loss: %f] [l2 loss: %f] [lpips loss: %f]"
140 | % (i, args.iter, l2_loss_3dmm.item(), l2_loss.item(), loss_lpips.item()))
141 |
142 | if rank == 0:
143 | real_image_test, label_test, params_test = next(loader_test)
144 | real_image_test = real_image_test.to(rank, non_blocking=True)
145 | label_test = label_test.to(rank, non_blocking=True)
146 | params_test = params_test.to(rank, non_blocking=True)
147 | bases_1 = trainer.sample_bases(person_2 = False)
148 | display_bases(bases_1, 'person_1', args)
149 |
150 |
151 |
152 | img_recon = trainer.sample(real_image_test, label_test, params_test)
153 | display_img(i, real_image_test, 'source', writer, args)
154 |
155 | display_img(i, img_recon, 'recon', writer, args)
156 | writer.flush()
157 |
158 | # save model
159 | if (i+1) % args.save_freq == 0 and rank == 0:
160 | trainer.save(i, checkpoint_path)
161 |
162 | return
163 |
164 |
165 | if __name__ == "__main__":
166 | # training params
167 | parser = argparse.ArgumentParser()
168 | parser.add_argument("--iter", type=int, default=800000)
169 | parser.add_argument("--size", type=int, default=256)
170 | parser.add_argument("--batch_size", type=int, default=1)
171 | parser.add_argument("--dataset", type=str, default='nerface')
172 | parser.add_argument("--person_1", type=str, default='person_2')
173 | parser.add_argument("--run_id", type=str, default='nerface2')
174 | parser.add_argument("--run_id_2", type=str, default=None)
175 | parser.add_argument("--emb_dir", type=str, default='./PTI/embeddings/')
176 |
177 | parser.add_argument("--person_2", type=str, default=None)
178 |
179 |
180 | parser.add_argument("--params_len", type=int, default=76)
181 | parser.add_argument("--d_reg_every", type=int, default=16)
182 | parser.add_argument("--g_reg_every", type=int, default=4)
183 | parser.add_argument("--resume_ckpt", type=str, default='./code/exps/nerface2-v1/checkpoint/874999.pt')
184 | parser.add_argument("--lr", type=float, default=3e-4)
185 | parser.add_argument("--old", action='store_true', default=True)
186 | parser.add_argument("--tune", action='store_true', default=False)
187 | parser.add_argument("--init", action='store_true', default=False)
188 |
189 | parser.add_argument("--channel_multiplier", type=int, default=1)
190 | parser.add_argument("--start_iter", type=int, default=0)
191 | parser.add_argument("--display_freq", type=int, default=100)
192 | parser.add_argument("--save_freq", type=int, default=5000)
193 | parser.add_argument("--latent_dim_style", type=int, default=512)
194 | parser.add_argument("--latent_dim_shape", type=int, default=50)
195 | parser.add_argument("--exp_path", type=str, default='./code/exps/')
196 | parser.add_argument("--exp_name", type=str, default='v1')
197 | parser.add_argument("--addr", type=str, default='localhost')
198 | parser.add_argument("--port", type=str, default='12345')
199 | parser.add_argument("--tune_iter", type=int, default=50000)
200 |
201 |
202 | opts = parser.parse_args()
203 |
204 | n_gpus = torch.cuda.device_count()
205 | print('==> training on %d gpus' % n_gpus)
206 | world_size = n_gpus
207 | if world_size == 1:
208 | main(rank=0, world_size = world_size, args=opts)
209 | elif world_size > 1:
210 | mp.spawn(main, args=(world_size, opts,), nprocs=world_size, join=True)
211 |
--------------------------------------------------------------------------------
/code/train_rgb.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import torch
4 | from torch.utils import data
5 | from dataset import HeadData
6 | import torchvision
7 | import torchvision.transforms as transforms
8 | from trainer_rgb import Trainer
9 | from torch.utils.tensorboard import SummaryWriter
10 | import torch.distributed as dist
11 | import torch.multiprocessing as mp
12 |
13 | torch.backends.cudnn.enabled = True
14 | torch.backends.cudnn.benchmark = True
15 |
16 |
17 | def data_sampler(dataset, shuffle):
18 | if shuffle:
19 | return data.RandomSampler(dataset)
20 | else:
21 | return data.SequentialSampler(dataset)
22 |
23 |
24 | def sample_data(loader):
25 | while True:
26 | for batch in loader:
27 | yield batch
28 |
29 |
30 | def display_img(idx, img, name, writer, args):
31 | img = img.clamp(-1, 1)
32 | img = ((img - img.min()) / (img.max() - img.min())).data
33 | torchvision.utils.save_image(img, args.exp_path + args.exp_name + '/display/'+str(idx)+name+'.png')
34 |
35 | writer.add_images(tag='%s' % (name), global_step=idx, img_tensor=img)
36 |
37 |
38 | def display_bases(imgs, name, args):
39 | for idx in range(len(imgs)):
40 | img = imgs[idx]
41 | img = img.clamp(-1, 1)
42 | img = ((img - img.min()) / (img.max() - img.min())).data
43 | torchvision.utils.save_image(img, args.exp_path + args.exp_name + '/bases/'+str(idx)+name+'.png')
44 |
45 |
46 |
47 | def write_loss(i, l2_loss, lpips_loss, writer):
48 | writer.add_scalar('l2_loss', l2_loss.item(), i)
49 | writer.add_scalar('lpips_loss', lpips_loss.item(), i)
50 | writer.flush()
51 |
52 |
53 | def ddp_setup(args, rank, world_size):
54 | os.environ['MASTER_ADDR'] = args.addr
55 | os.environ['MASTER_PORT'] = args.port
56 |
57 | dist.init_process_group("nccl", rank=rank, world_size=world_size)
58 |
59 |
60 | def main(rank, world_size, args):
61 | # init distributed computing
62 | ddp_setup(args, rank, world_size)
63 | torch.cuda.set_device(rank)
64 | device = torch.device("cuda")
65 |
66 | # make logging folder
67 | log_path = os.path.join(args.exp_path, args.exp_name + '/log')
68 | checkpoint_path = os.path.join(args.exp_path, args.exp_name + '/checkpoint')
69 | display_path = os.path.join(args.exp_path, args.exp_name + '/display')
70 | bases_path = os.path.join(args.exp_path, args.exp_name + '/bases')
71 | os.makedirs(log_path, exist_ok=True)
72 | os.makedirs(checkpoint_path, exist_ok=True)
73 | os.makedirs(display_path, exist_ok=True)
74 | os.makedirs(bases_path, exist_ok=True)
75 | writer = SummaryWriter(log_path)
76 |
77 | print('==> preparing dataset')
78 | transform = torchvision.transforms.Compose([
79 | transforms.Resize(args.size),
80 | transforms.ToTensor(),
81 | transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])
82 | dataset = HeadData('train', transform, dataset = args.dataset , person = args.person)
83 | dataset_test = HeadData('test', transform, dataset = args.dataset , person = args.person )
84 |
85 | loader = data.DataLoader(
86 | dataset,
87 |
88 | batch_size=args.batch_size // world_size,
89 | sampler=data.distributed.DistributedSampler(dataset, num_replicas=world_size, rank=rank, shuffle=True),
90 | pin_memory=True,
91 | drop_last=False,
92 | )
93 |
94 | loader_test = data.DataLoader(
95 | dataset_test,
96 | batch_size=1,
97 | sampler=data.distributed.DistributedSampler(dataset_test, num_replicas=world_size, rank=rank, shuffle=False),
98 | pin_memory=True,
99 | drop_last=False,
100 | )
101 |
102 | loader = sample_data(loader)
103 | loader_test = sample_data(loader_test)
104 | print('==> initializing trainer')
105 | # Trainer
106 | trainer = Trainer(args, device, rank)
107 |
108 | # resume
109 | if args.resume_ckpt is not None:
110 | args.start_iter = trainer.resume(args.resume_ckpt)
111 | print('==> resume from iteration %d' % (args.start_iter))
112 |
113 | print('==> training')
114 | pbar = range(args.iter)
115 | for idx in pbar:
116 | i = idx + args.start_iter
117 |
118 | # loading data
119 | real_image, label = next(loader)
120 | real_image = real_image.to(rank, non_blocking=True)
121 | label = label.to(rank, non_blocking=True)
122 |
123 |
124 | # update generator
125 | l2_loss, loss_lpips, generated_image = trainer.gen_update(real_image, label, person_2 = False)
126 |
127 |
128 | if rank == 0:
129 | # write to log
130 | write_loss(idx, l2_loss, loss_lpips, writer)
131 |
132 | if (i+1) >= args.tune_iter:
133 | # print('begin training nerf')
134 | trainer.tune_generator()
135 | # display
136 | if (i+1) % args.display_freq == 0 and rank == 0:
137 | print("[Iter %d/%d] [l2 loss: %f] [lpips loss: %f]"
138 | % (i, args.iter, l2_loss.item(), loss_lpips.item()))
139 |
140 | if rank == 0:
141 | real_image_test, label_test = next(loader_test)
142 | real_image_test = real_image_test.to(rank, non_blocking=True)
143 | label_test = label_test.to(rank, non_blocking=True)
144 |
145 | img_recon = trainer.sample(real_image_test, label_test)
146 | bases_1 = trainer.sample_bases(person_2 = False)
147 | display_bases(bases_1, 'person_1', args)
148 | display_img(i, real_image_test, 'source', writer, args)
149 | display_img(i, img_recon, 'recon', writer, args)
150 | writer.flush()
151 |
152 | # save model
153 | if (i+1) % args.save_freq == 0 and rank == 0:
154 | trainer.save(i, checkpoint_path)
155 |
156 | return
157 |
158 |
159 | if __name__ == "__main__":
160 | # training params
161 | parser = argparse.ArgumentParser()
162 | parser.add_argument("--iter", type=int, default=800000)
163 | parser.add_argument("--size", type=int, default=256)
164 | parser.add_argument("--batch_size", type=int, default=2)
165 | parser.add_argument("--dataset", type=str, default='nerface_dataset')
166 | parser.add_argument("--person", type=str, default='person_3')
167 | parser.add_argument("--person_2", type=str, default=None)
168 | parser.add_argument("--run_id", type=str, default='nerface2')
169 | parser.add_argument("--run_id_2", type=str, default=None)
170 | parser.add_argument("--emb_dir", type=str, default='./PTI/embeddings/')
171 |
172 | parser.add_argument("--d_reg_every", type=int, default=16)
173 | parser.add_argument("--g_reg_every", type=int, default=4)
174 | parser.add_argument("--resume_ckpt", type=str, default=None)
175 | parser.add_argument("--old", action='store_true', default=True)
176 | parser.add_argument("--tune", action='store_true', default=True)
177 | parser.add_argument("--init", action='store_true', default=False)
178 | parser.add_argument("--same_bases", action='store_true', default=False)
179 | parser.add_argument("--out_pose", action='store_true', default=False)
180 | parser.add_argument("--lr", type=float, default=3e-4)
181 |
182 |
183 | parser.add_argument("--channel_multiplier", type=int, default=1)
184 | parser.add_argument("--start_iter", type=int, default=0)
185 | parser.add_argument("--display_freq", type=int, default=5000)
186 | parser.add_argument("--save_freq", type=int, default=5000)
187 | parser.add_argument("--latent_dim_style", type=int, default=512)
188 | parser.add_argument("--latent_dim_shape", type=int, default=50)
189 | parser.add_argument("--exp_path", type=str, default='./code/exps/')
190 | parser.add_argument("--exp_name", type=str, default='v1')
191 | parser.add_argument("--addr", type=str, default='localhost')
192 | parser.add_argument("--port", type=str, default='12345')
193 | parser.add_argument("--tune_iter", type=int, default=50000)
194 | opts = parser.parse_args()
195 |
196 | n_gpus = torch.cuda.device_count()
197 | print('==> training on %d gpus' % n_gpus)
198 | world_size = n_gpus
199 | if world_size == 1:
200 | main(rank=0, world_size = world_size, args=opts)
201 | elif world_size > 1:
202 | mp.spawn(main, args=(world_size, opts,), nprocs=world_size, join=True)
203 |
--------------------------------------------------------------------------------
/code/trainer_3dmm.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from networks.headnerf import HeadNeRF_3DMM
3 | import torch.nn.functional as F
4 | from torch import nn, optim
5 | import os
6 | from torch.nn.parallel import DistributedDataParallel as DDP
7 |
8 | from lpips import LPIPS
9 | def requires_grad(net, flag=True):
10 | for p in net.parameters():
11 | p.requires_grad = flag
12 |
13 | l2_criterion = torch.nn.MSELoss(reduction='mean')
14 |
15 | from cam_utils import sample_camera_positions, create_cam2world_matrix
16 | import math
17 | import torchvision.transforms as transforms
18 |
19 | class Trainer(nn.Module):
20 | def __init__(self, args, device, rank):
21 | super(Trainer, self).__init__()
22 |
23 | self.args = args
24 | self.batch_size = args.batch_size
25 |
26 | self.gen = HeadNeRF_3DMM(args, args.size, device, args.latent_dim_style, args.latent_dim_shape, args.run_id, args.emb_dir).to(
27 | device)
28 |
29 | self.gen = DDP(self.gen, device_ids=[rank], find_unused_parameters=True)
30 | self.w_optim = torch.optim.Adam(self.gen.parameters(), lr= args.lr)
31 | for param in self.gen.module.generator.parameters():
32 | param.requires_grad = False
33 | self.lpips_loss = LPIPS(net='alex').to(device).eval()
34 | self.device = device
35 | self.face_pool = torch.nn.AdaptiveAvgPool2d((args.size, args.size))
36 |
37 | def l2_loss(self, real_images, generated_images):
38 | loss = l2_criterion(real_images, generated_images)
39 | return loss
40 | def tune_generator(self):
41 | for param in self.gen.module.generator.parameters():
42 | param.requires_grad = True
43 | def gen_update(self, real_image, label, params, person_2 = False):
44 | self.gen.train()
45 | self.w_optim.zero_grad()
46 |
47 |
48 |
49 |
50 | generated_image = self.gen(params, label, person_2)
51 | generated_image = self.face_pool(generated_image)
52 |
53 | l2_loss_3dmm = torch.zeros(1).to(self.device) #self.l2_loss(weights, generated_weights)
54 | l2_loss = self.l2_loss(real_image, generated_image)
55 | loss_lpips = self.lpips_loss( real_image, generated_image)
56 | loss_lpips = torch.squeeze(loss_lpips).mean()
57 |
58 |
59 | g_loss = l2_loss_3dmm + l2_loss + loss_lpips
60 |
61 |
62 | g_loss.backward()
63 |
64 |
65 | self.w_optim.step()
66 |
67 | return l2_loss_3dmm, l2_loss, loss_lpips, generated_image
68 |
69 |
70 | def sample(self, real_image, label, params, person_2 = False):
71 | with torch.no_grad():
72 | self.gen.eval()
73 |
74 | img_recon = self.gen(params, label, person_2)
75 |
76 |
77 | return img_recon#, img_source_ref
78 | def sample_bases(self, person_2 = False ):
79 | img_recons = []
80 | with torch.no_grad():
81 | r = 2.7
82 | points, _, _ = sample_camera_positions(device=self.device, n=1, r=r, horizontal_mean=0.5*math.pi, vertical_mean=0.5*math.pi, mode=None)
83 | label = create_cam2world_matrix(-points, points, device=self.device)
84 | label = label.reshape(1, -1)
85 | label = torch.cat((label, torch.tensor([4.2647, 0, 0.5, 0, 4.2647, 0.5, 0, 0, 1]).reshape(1, -1).repeat(1, 1).to(label)), -1)
86 | self.gen.eval()
87 |
88 | for base_id in range(self.args.latent_dim_shape):
89 | weights = torch.zeros(self.args.latent_dim_shape).to(self.device)
90 | weights[base_id] = 5
91 | weights = weights.unsqueeze(0)
92 |
93 |
94 | latent = self.gen.module.get_latent(weights,person_2)
95 | img_recon = self.gen.module.get_image(latent, label)
96 | img_recons.append(img_recon)
97 |
98 | return img_recons#, img_source_ref
99 |
100 |
101 | def resume(self, resume_ckpt):
102 | print("load model:", resume_ckpt)
103 | ckpt = torch.load(resume_ckpt)
104 | ckpt_name = os.path.basename(resume_ckpt)
105 | start_iter = int(os.path.splitext(ckpt_name)[0])
106 |
107 |
108 | self.gen.module.load_state_dict(ckpt["gen"])
109 |
110 | self.w_optim.load_state_dict(ckpt["w_optim"])
111 |
112 | return start_iter
113 |
114 | def save(self, idx, checkpoint_path):
115 | torch.save(
116 | {
117 | "gen": self.gen.module.state_dict(),
118 | "w_optim": self.w_optim.state_dict(),
119 | "args": self.args
120 | },
121 | f"{checkpoint_path}/{str(idx).zfill(6)}.pt"
122 | )
123 |
--------------------------------------------------------------------------------
/code/trainer_rgb.py:
--------------------------------------------------------------------------------
1 | import torch
2 | # from networks.discriminator import Discriminator
3 | # from networks.generator import Generator
4 | from networks.headnerf import HeadNeRF_final
5 | import torch.nn.functional as F
6 | from torch import nn, optim
7 | import os
8 | from torch.nn.parallel import DistributedDataParallel as DDP
9 |
10 | from lpips import LPIPS
11 | def requires_grad(net, flag=True):
12 | for p in net.parameters():
13 | p.requires_grad = flag
14 |
15 | l2_criterion = torch.nn.MSELoss(reduction='mean')
16 |
17 | from cam_utils import sample_camera_positions, create_cam2world_matrix
18 | import math
19 |
20 | import torchvision.transforms as transforms
21 |
22 |
23 |
24 |
25 |
26 |
27 | def cam_sampler(batch, device):
28 | camera_points, phi, theta = sample_camera_positions(device, n=batch, r=2.7, horizontal_mean=0.5*math.pi,
29 | vertical_mean=0.5*math.pi, horizontal_stddev=0.3, vertical_stddev=0.155, mode='gaussian')
30 | c = create_cam2world_matrix(-camera_points, camera_points, device=device)
31 | c = c.reshape(batch, -1)
32 | c = torch.cat((c, torch.tensor([4.2647, 0, 0.5, 0, 4.2647, 0.5, 0, 0, 1]).reshape(1, -1).repeat(batch, 1).to(c)), -1)
33 | return c
34 |
35 |
36 | def cam_sampler_pose(batch, horizontal_mean, vertical_mean , device):
37 | camera_points, phi, theta = sample_camera_positions(device, n=batch, r=2.7, horizontal_mean=horizontal_mean*math.pi,
38 | vertical_mean=vertical_mean*math.pi, horizontal_stddev=0.15, vertical_stddev=0.155, mode='gaussian')
39 | c = create_cam2world_matrix(-camera_points, camera_points, device=device)
40 | c = c.reshape(batch, -1)
41 | c = torch.cat((c, torch.tensor([4.2647, 0, 0.5, 0, 4.2647, 0.5, 0, 0, 1]).reshape(1, -1).repeat(batch, 1).to(c)), -1)
42 | return c
43 |
44 |
45 |
46 | class Trainer(nn.Module):
47 | def __init__(self, args, device, rank):
48 | super(Trainer, self).__init__()
49 |
50 | self.args = args
51 | self.batch_size = args.batch_size
52 | self.device = device
53 |
54 | self.gen = HeadNeRF_final(args, args.size, device, args.latent_dim_style, args.latent_dim_shape, args.run_id, args.emb_dir).to(
55 | device)
56 | self.gen = DDP(self.gen, device_ids=[rank], broadcast_buffers=False, find_unused_parameters=True)
57 |
58 | self.g_optim = torch.optim.Adam(self.gen.parameters(), lr= args.lr)
59 | for param in self.gen.module.generator.parameters():
60 | param.requires_grad = False
61 |
62 | self.lpips_loss = LPIPS(net='alex').to(device).eval()
63 | self.face_pool = torch.nn.AdaptiveAvgPool2d((args.size, args.size))
64 |
65 | def l2_loss(self, real_images, generated_images):
66 | loss = l2_criterion(real_images, generated_images)
67 | return loss
68 |
69 | def tune_generator(self):
70 | for param in self.gen.module.generator.parameters():
71 | param.requires_grad = True
72 |
73 | def gen_update(self, real_image, label, person_2 = False, mask = None):
74 | self.gen.train()
75 | self.g_optim.zero_grad()
76 |
77 | weights_i = self.gen.module.get_weights(real_image)
78 | latent_i = self.gen.module.get_latent(weights_i, person_2)
79 | generated_image = self.gen.module.get_image(latent_i, label)
80 |
81 |
82 |
83 |
84 | generated_image = self.face_pool(generated_image)
85 | l2_loss = self.l2_loss(real_image, generated_image)
86 | loss_lpips = self.lpips_loss( real_image, generated_image)
87 | loss_lpips = torch.squeeze(loss_lpips).mean()
88 |
89 |
90 |
91 | g_loss = (l2_loss + loss_lpips )
92 |
93 | g_loss.backward()
94 |
95 |
96 | self.g_optim.step()
97 |
98 | return l2_loss, loss_lpips, generated_image
99 |
100 | def sample(self, real_image, label, person_2 = False ):
101 | with torch.no_grad():
102 | self.gen.eval()
103 | img_recon = self.gen(real_image, label, person_2)
104 |
105 |
106 | return img_recon
107 |
108 | def sample_bases(self, person_2 = False ):
109 | img_recons = []
110 | with torch.no_grad():
111 | r = 2.7
112 | points, _, _ = sample_camera_positions(device=self.device, n=1, r=r, horizontal_mean=0.5*math.pi, vertical_mean=0.5*math.pi, mode=None)
113 | label = create_cam2world_matrix(-points, points, device=self.device)
114 | label = label.reshape(1, -1)
115 | label = torch.cat((label, torch.tensor([4.2647, 0, 0.5, 0, 4.2647, 0.5, 0, 0, 1]).reshape(1, -1).repeat(1, 1).to(label)), -1)
116 | self.gen.eval()
117 |
118 | for base_id in range(self.args.latent_dim_shape):
119 | weights = torch.zeros(self.args.latent_dim_shape).to(self.device)
120 | weights[base_id] = 10
121 | weights = weights.unsqueeze(0)
122 |
123 | latent = self.gen.module.get_latent(weights,person_2)
124 | img_recon = self.gen.module.get_image(latent, label)
125 | img_recons.append(img_recon)
126 |
127 | return img_recons
128 |
129 |
130 | def resume(self, resume_ckpt):
131 | print("load model:", resume_ckpt)
132 | ckpt = torch.load(resume_ckpt)
133 | ckpt_name = os.path.basename(resume_ckpt)
134 | start_iter = int(os.path.splitext(ckpt_name)[0])
135 |
136 | self.gen.module.load_state_dict(ckpt["gen"])
137 |
138 | self.g_optim.load_state_dict(ckpt["g_optim"])
139 |
140 |
141 | return start_iter
142 |
143 | def save(self, idx, checkpoint_path):
144 | torch.save(
145 | {
146 | "gen": self.gen.module.state_dict(),
147 | "g_optim": self.g_optim.state_dict(),
148 | "args": self.args
149 | },
150 | f"{checkpoint_path}/{str(idx).zfill(6)}.pt"
151 | )
152 |
--------------------------------------------------------------------------------
/eg3d-pose-detection/3dface2idr.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import os
3 | import torch
4 | import json
5 | import argparse
6 |
7 | if __name__ == '__main__':
8 | parser = argparse.ArgumentParser()
9 | parser.add_argument('--in_root', type=str, default="", help='process folder')
10 | parser.add_argument('--out_root', type=str, default="output", help='output folder')
11 | args = parser.parse_args()
12 | in_root = args.in_root
13 |
14 | def compute_rotation(angles):
15 | """
16 | Return:
17 | rot -- torch.tensor, size (B, 3, 3) pts @ trans_mat
18 |
19 | Parameters:
20 | angles -- torch.tensor, size (B, 3), radian
21 | """
22 |
23 | batch_size = angles.shape[0]
24 | ones = torch.ones([batch_size, 1])
25 | zeros = torch.zeros([batch_size, 1])
26 | x, y, z = angles[:, :1], angles[:, 1:2], angles[:, 2:],
27 |
28 | rot_x = torch.cat([
29 | ones, zeros, zeros,
30 | zeros, torch.cos(x), -torch.sin(x),
31 | zeros, torch.sin(x), torch.cos(x)
32 | ], dim=1).reshape([batch_size, 3, 3])
33 |
34 | rot_y = torch.cat([
35 | torch.cos(y), zeros, torch.sin(y),
36 | zeros, ones, zeros,
37 | -torch.sin(y), zeros, torch.cos(y)
38 | ], dim=1).reshape([batch_size, 3, 3])
39 |
40 | rot_z = torch.cat([
41 | torch.cos(z), -torch.sin(z), zeros,
42 | torch.sin(z), torch.cos(z), zeros,
43 | zeros, zeros, ones
44 | ], dim=1).reshape([batch_size, 3, 3])
45 |
46 | rot = rot_z @ rot_y @ rot_x
47 | return rot.permute(0, 2, 1)[0]
48 |
49 | npys = sorted([x for x in os.listdir(in_root) if x.endswith(".npy")])
50 |
51 | mode = 1 #1 = IDR, 2 = LSX
52 | outAll={}
53 |
54 | for src_filename in npys:
55 | src = os.path.join(in_root, src_filename)
56 |
57 | # print(src)
58 | dict_load=np.load(src, allow_pickle=True)
59 |
60 | angle = dict_load.item()['angle']
61 | trans = dict_load.item()['trans'][0]
62 | R = compute_rotation(torch.from_numpy(angle)).numpy()
63 | trans[2] += -10
64 | c = -np.dot(R, trans)
65 | pose = np.eye(4)
66 | pose[:3, :3] = R
67 |
68 | c *= 0.27 # factor to match tripleganger
69 | c[1] += 0.006 # offset to align to tripleganger
70 | c[2] += 0.161 # offset to align to tripleganger
71 | pose[0,3] = c[0]
72 | pose[1,3] = c[1]
73 | pose[2,3] = c[2]
74 |
75 | focal = 2985.29 # = 1015*1024/224*(300/466.285)#
76 | pp = 512#112
77 | w = 1024#224
78 | h = 1024#224
79 |
80 | if mode==1:
81 | count = 0
82 | K = np.eye(3)
83 | K[0][0] = focal
84 | K[1][1] = focal
85 | K[0][2] = w/2.0
86 | K[1][2] = h/2.0
87 | K = K.tolist()
88 |
89 | Rot = np.eye(3)
90 | Rot[0, 0] = 1
91 | Rot[1, 1] = -1
92 | Rot[2, 2] = -1
93 | pose[:3, :3] = np.dot(pose[:3, :3], Rot)
94 |
95 | pose = pose.tolist()
96 | out = {}
97 | out["intrinsics"] = K
98 | out["pose"] = pose
99 | out["angle"] = (angle * [1, -1, 1]).flatten().tolist()
100 | outAll[src_filename.replace(".npy", ".png")] = out
101 |
102 | elif mode==2:
103 |
104 | dst = os.path.join(in_root, src_filename.replace(".npy", "_lscam.txt"))
105 | outCam = open(dst, "w")
106 | outCam.write("#focal length\n")
107 | outCam.write(str(focal) + " " + str(focal) + "\n")
108 |
109 | outCam.write("#principal point\n")
110 | outCam.write(str(pp) + " " + str(pp) + "\n")
111 |
112 | outCam.write("#resolution\n")
113 | outCam.write(str(w) + " " + str(h) + "\n")
114 |
115 | outCam.write("#distortion coeffs\n")
116 | outCam.write("0 0 0 0\n")
117 |
118 |
119 | outCam.write("MATRIX :\n")
120 | for r in range(4):
121 | outCam.write(str(pose[r, 0]) + " " + str(pose[r, 1]) + " " + str(pose[r, 2]) + " " + str(pose[r, 3]) + "\n")
122 |
123 | outCam.close()
124 |
125 | print("mode:", mode)
126 | print("out dir:", args.out_root)
127 | if mode == 1:
128 | dst = os.path.join(args.out_root, "cameras.json")
129 | with open(dst, "w") as outfile:
130 | json.dump(outAll, outfile, indent=4)
131 |
--------------------------------------------------------------------------------
/eg3d-pose-detection/batch_mtcnn.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import cv2
3 | import os
4 | from mtcnn import MTCNN
5 | import random
6 | from tqdm import tqdm
7 | import numpy as np
8 |
9 | detector = MTCNN()
10 | parser = argparse.ArgumentParser()
11 | parser.add_argument('--in_root', type=str, default="", help='process folder')
12 | args = parser.parse_args()
13 | in_root = args.in_root
14 |
15 | out_root = os.path.join(in_root, "debug")
16 | out_detection = os.path.join(in_root, "detections")
17 | if not os.path.exists(out_root):
18 | os.makedirs(out_root)
19 | if not os.path.exists(out_detection):
20 | os.makedirs(out_detection)
21 |
22 | imgs = sorted([x for x in os.listdir(in_root) if x.endswith(".jpg") or x.endswith(".png")])
23 | random.shuffle(imgs)
24 | for img in tqdm(imgs):
25 | src = os.path.join(in_root, img)
26 | dst = os.path.join(out_detection, img.replace(".jpg", ".txt").replace(".png", ".txt"))
27 |
28 | if not os.path.exists(dst):
29 | image = cv2.cvtColor(cv2.imread(src), cv2.COLOR_BGR2RGB)
30 | result = detector.detect_faces(image)
31 |
32 | if len(result)>0:
33 | index = 0
34 | if len(result)>1: # if multiple faces, take the one closest to the image center
35 | # size = -100000
36 | lowest_dist = float('Inf')
37 | for r in range(len(result)):
38 | # print(result[r]["box"][0], result[r]["box"][1])
39 | face_pos = np.array(result[r]["box"][:2]) + np.array(result[r]["box"][2:])/2
40 |
41 | dist_from_center = np.linalg.norm(face_pos - np.array([1500./2, 1500./2]))
42 | if dist_from_center < lowest_dist:
43 | lowest_dist = dist_from_center
44 | index=r
45 |
46 |
47 | # size_ = result[r]["box"][2] + result[r]["box"][3]
48 | # if size < size_:
49 | # size = size_
50 | # index = r
51 |
52 | # result holds the bounding boxes and keypoints of all detected faces; keep the selected one.
53 | bounding_box = result[index]['box']
54 | keypoints = result[index]['keypoints']
55 | if result[index]["confidence"] > 0.9:
56 |
57 | cv2.rectangle(image,
58 | (bounding_box[0], bounding_box[1]),
59 | (bounding_box[0]+bounding_box[2], bounding_box[1] + bounding_box[3]),
60 | (0,155,255),
61 | 2)
62 |
63 | cv2.circle(image,(keypoints['left_eye']), 2, (0,155,255), 2)
64 | cv2.circle(image,(keypoints['right_eye']), 2, (0,155,255), 2)
65 | cv2.circle(image,(keypoints['nose']), 2, (0,155,255), 2)
66 | cv2.circle(image,(keypoints['mouth_left']), 2, (0,155,255), 2)
67 | cv2.circle(image,(keypoints['mouth_right']), 2, (0,155,255), 2)
68 |
69 | dst = os.path.join(out_root, img)
70 | # cv2.imwrite(dst, cv2.cvtColor(image, cv2.COLOR_RGB2BGR))
71 |
72 | dst = os.path.join(out_detection, img.replace(".jpg", ".txt").replace(".png", ".txt"))
73 | outLand = open(dst, "w")
74 | outLand.write(str(float(keypoints['left_eye'][0])) + " " + str(float(keypoints['left_eye'][1])) + "\n")
75 | outLand.write(str(float(keypoints['right_eye'][0])) + " " + str(float(keypoints['right_eye'][1])) + "\n")
76 | outLand.write(str(float(keypoints['nose'][0])) + " " + str(float(keypoints['nose'][1])) + "\n")
77 | outLand.write(str(float(keypoints['mouth_left'][0])) + " " + str(float(keypoints['mouth_left'][1])) + "\n")
78 | outLand.write(str(float(keypoints['mouth_right'][0])) + " " + str(float(keypoints['mouth_right'][1])) + "\n")
79 | outLand.close()
--------------------------------------------------------------------------------
/eg3d-pose-detection/camera2label.py:
--------------------------------------------------------------------------------
1 | import json
2 | import numpy as np
3 | import argparse
4 | import os
5 | """
6 | convert cameras.json into the format of eg3d label
7 | """
8 | # fname = '/apdcephfs_cq2/share_1290939/kitbai/PTI/data/nerface/1/cropped_images/cameras.json'
9 | parser = argparse.ArgumentParser()
10 | parser.add_argument('--in_root', type=str, default="", help='process folder')
11 | args = parser.parse_args()
12 | in_root = args.in_root
13 |
14 | fname = os.path.join(in_root, "cropped_images/cameras.json")
15 | # fname = '/apdcephfs_cq2/share_1290939/kitbai/LIA-3d/datasets/ad_dataset/french/train/cropped_images/cameras.json'
16 | # fname = '/apdcephfs_cq2/share_1290939/kitbai/celeba_sub/images1/cropped_images/cameras.json'
17 | with open(fname, 'rb') as f:
18 | labels = json.load(f)
19 |
20 | results_new = []
21 | for ind in labels.keys():
22 | pose = np.array(labels[ind]["pose"]).reshape(16)
23 | pose = list(pose) + list([4.2647, 0, 0.5, 0, 4.2647, 0.5, 0, 0, 1])
24 | results_new.append((ind, pose))
25 |
26 | # with open("/apdcephfs_cq2/share_1290939/kitbai/PTI/data/nerface/1/cropped_images/test.json", 'w') as outfile:
27 | # with open("/apdcephfs_cq2/share_1290939/kitbai/LIA-3d/datasets/ad_dataset/french/train/cropped_images/test.json", 'w') as outfile:
28 | # json.dump({"labels": results_new}, outfile, indent="\t")
29 | with open(os.path.join(in_root, "cropped_images/test.json"), 'w') as outfile:
30 | json.dump({"labels": results_new}, outfile, indent="\t")
31 |
--------------------------------------------------------------------------------
/eg3d-pose-detection/crop_images.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import json
4 |
5 | import numpy as np
6 | from PIL import Image
7 | from tqdm import tqdm
8 |
9 | # calculating least square problem for image alignment
10 | def POS(xp, x):
11 |
12 | npts = xp.shape[1]
13 |
14 | A = np.zeros([2*npts, 8])
15 |
16 | A[0:2*npts-1:2, 0:3] = x.transpose()
17 | A[0:2*npts-1:2, 3] = 1
18 | A[1:2*npts:2, 4:7] = x.transpose()
19 | A[1:2*npts:2, 7] = 1
20 |
21 | b = np.reshape(xp.transpose(), [2*npts, 1])
22 |
23 | k, _, _, _ = np.linalg.lstsq(A, b)
24 |
25 | R1 = k[0:3]
26 | R2 = k[4:7]
27 | sTx = k[3]
28 | sTy = k[7]
29 | s = (np.linalg.norm(R1) + np.linalg.norm(R2))/2
30 |
31 | t = np.stack([sTx, sTy], axis=0)
32 |
33 | return t, s
34 |
35 | def extract_5p(lm):
36 | lm_idx = np.array([31, 37, 40, 43, 46, 49, 55]) - 1
37 | lm5p = np.stack([lm[lm_idx[0], :], np.mean(lm[lm_idx[[1, 2]], :], 0), np.mean(
38 | lm[lm_idx[[3, 4]], :], 0), lm[lm_idx[5], :], lm[lm_idx[6], :]], axis=0)
39 | lm5p = lm5p[[1, 2, 0, 3, 4], :]
40 | return lm5p
41 |
42 | # resize and crop images for face reconstruction
43 | def resize_n_crop_img(img, lm, t, s, target_size=1024., mask=None):
44 | w0, h0 = img.size
45 | w = (w0*s).astype(np.int32)
46 | h = (h0*s).astype(np.int32)
47 | left = (w/2 - target_size/2 + float((t[0] - w0/2)*s)).astype(np.int32)
48 | right = left + target_size
49 | up = (h/2 - target_size/2 + float((h0/2 - t[1])*s)).astype(np.int32)
50 | below = up + target_size
51 | img = img.resize((w, h), resample=Image.LANCZOS)
52 | img = img.crop((left, up, right, below))
53 |
54 | if mask is not None:
55 | mask = mask.resize((w, h), resample=Image.LANCZOS)
56 | mask = mask.crop((left, up, right, below))
57 |
58 | lm = np.stack([lm[:, 0] - t[0] + w0/2, lm[:, 1] -
59 | t[1] + h0/2], axis=1)*s
60 | lm = lm - np.reshape(
61 | np.array([(w/2 - target_size/2), (h/2-target_size/2)]), [1, 2])
62 | return img, lm, mask
63 |
64 |
65 | # utils for face reconstruction
66 | def align_img(img, lm, lm3D, mask=None, target_size=1024., rescale_factor=466.285):
67 | """
68 | Return:
69 | transparams --numpy.array (raw_W, raw_H, scale, tx, ty)
70 | img_new --PIL.Image (target_size, target_size, 3)
71 | lm_new --numpy.array (68, 2), y direction is opposite to v direction
72 | mask_new --PIL.Image (target_size, target_size)
73 |
74 | Parameters:
75 | img --PIL.Image (raw_H, raw_W, 3)
76 | lm --numpy.array (68, 2), y direction is opposite to v direction
77 | lm3D --numpy.array (5, 3)
78 | mask --PIL.Image (raw_H, raw_W, 3)
79 | """
80 |
81 | w0, h0 = img.size
82 | if lm.shape[0] != 5:
83 | lm5p = extract_5p(lm)
84 | else:
85 | lm5p = lm
86 |
87 | # calculate translation and scale factors using 5 facial landmarks and standard landmarks of a 3D face
88 | t, s = POS(lm5p.transpose(), lm3D.transpose())
89 | s = rescale_factor/s
90 |
91 | # processing the image
92 | img_new, lm_new, mask_new = resize_n_crop_img(img, lm, t, s, target_size=target_size, mask=mask)
93 | # img.save("/home/koki/Projects/Deep3DFaceRecon_pytorch/checkpoints/pretrained/results/iphone/epoch_20_000000/img_new.jpg")
94 | trans_params = np.array([w0, h0, s, t[0], t[1]])
95 | lm_new *= 224/1024.0
96 | img_new_low = img_new.resize((224, 224), resample=Image.LANCZOS)
97 |
98 | return trans_params, img_new_low, lm_new, mask_new, img_new
99 |
100 |
101 | if __name__ == '__main__':
102 | parser = argparse.ArgumentParser()
103 | parser.add_argument('--indir', type=str, required=True)
104 | parser.add_argument('--outdir', type=str, required=True)
105 | parser.add_argument('--compress_level', type=int, default=0)
106 | args = parser.parse_args()
107 |
108 | with open(os.path.join(args.indir, 'cropping_params.json')) as f:
109 | cropping_params = json.load(f)
110 |
111 | os.makedirs(args.outdir, exist_ok=True)
112 |
113 | for im_path, cropping_dict in tqdm(cropping_params.items()):
114 | im = Image.open(os.path.join(args.indir, im_path)).convert('RGB')
115 |
116 | _, H = im.size
117 | lm = np.array(cropping_dict['lm'])
118 | lm = lm.reshape([-1, 2])
119 | lm[:, -1] = H - 1 - lm[:, -1]
120 |
121 | _, im_pil, lm, _, im_high = align_img(im, lm, np.array(cropping_dict['lm3d_std']), rescale_factor=cropping_dict['rescale_factor'])
122 |
123 | left = int(im_high.size[0]/2 - cropping_dict['center_crop_size']/2)
124 | upper = int(im_high.size[1]/2 - cropping_dict['center_crop_size']/2)
125 | right = left + cropping_dict['center_crop_size']
126 | lower = upper + cropping_dict['center_crop_size']
127 | im_cropped = im_high.crop((left, upper, right,lower))
128 | im_cropped = im_cropped.resize((cropping_dict['output_size'], cropping_dict['output_size']), resample=Image.LANCZOS)
129 |
130 | im_cropped.save(os.path.join(args.outdir, os.path.basename(im_path)), compress_level=args.compress_level)
131 | im_high.save(os.path.join(args.indir, 'crop_1024', os.path.basename(im_path)), compress_level=args.compress_level)
--------------------------------------------------------------------------------
/eg3d-pose-detection/data/__init__.py:
--------------------------------------------------------------------------------
1 | """This package includes all the modules related to data loading and preprocessing
2 |
3 | To add a custom dataset class called 'dummy', you need to add a file called 'dummy_dataset.py' and define a subclass 'DummyDataset' inherited from BaseDataset.
4 | You need to implement four functions:
5 | -- <__init__>: initialize the class, first call BaseDataset.__init__(self, opt).
6 | -- <__len__>: return the size of dataset.
7 | -- <__getitem__>: get a data point from data loader.
8 | -- <modify_commandline_options>: (optionally) add dataset-specific options and set default options.
9 |
10 | Now you can use the dataset class by specifying flag '--dataset_mode dummy'.
11 | See our template dataset class 'template_dataset.py' for more details.
12 | """
13 | import numpy as np
14 | import importlib
15 | import torch.utils.data
16 | from data.base_dataset import BaseDataset
17 |
18 |
19 | def find_dataset_using_name(dataset_name):
20 | """Import the module "data/[dataset_name]_dataset.py".
21 |
22 | In the file, the class called DatasetNameDataset() will
23 | be instantiated. It has to be a subclass of BaseDataset,
24 | and it is case-insensitive.
25 | """
26 | dataset_filename = "data." + dataset_name + "_dataset"
27 | datasetlib = importlib.import_module(dataset_filename)
28 |
29 | dataset = None
30 | target_dataset_name = dataset_name.replace('_', '') + 'dataset'
31 | for name, cls in datasetlib.__dict__.items():
32 | if name.lower() == target_dataset_name.lower() \
33 | and issubclass(cls, BaseDataset):
34 | dataset = cls
35 |
36 | if dataset is None:
37 | raise NotImplementedError("In %s.py, there should be a subclass of BaseDataset with class name that matches %s in lowercase." % (dataset_filename, target_dataset_name))
38 |
39 | return dataset
40 |
41 |
42 | def get_option_setter(dataset_name):
43 | """Return the static method of the dataset class."""
44 | dataset_class = find_dataset_using_name(dataset_name)
45 | return dataset_class.modify_commandline_options
46 |
47 |
48 | def create_dataset(opt, rank=0):
49 | """Create a dataset given the option.
50 |
51 | This function wraps the class CustomDatasetDataLoader.
52 | This is the main interface between this package and 'train.py'/'test.py'
53 |
54 | Example:
55 | >>> from data import create_dataset
56 | >>> dataset = create_dataset(opt)
57 | """
58 | data_loader = CustomDatasetDataLoader(opt, rank=rank)
59 | dataset = data_loader.load_data()
60 | return dataset
61 |
62 | class CustomDatasetDataLoader():
63 | """Wrapper class of Dataset class that performs multi-threaded data loading"""
64 |
65 | def __init__(self, opt, rank=0):
66 | """Initialize this class
67 |
68 | Step 1: create a dataset instance given the name [dataset_mode]
69 | Step 2: create a multi-threaded data loader.
70 | """
71 | self.opt = opt
72 | dataset_class = find_dataset_using_name(opt.dataset_mode)
73 | self.dataset = dataset_class(opt)
74 | self.sampler = None
75 | print("rank %d %s dataset [%s] was created" % (rank, self.dataset.name, type(self.dataset).__name__))
76 | if opt.use_ddp and opt.isTrain:
77 | world_size = opt.world_size
78 | self.sampler = torch.utils.data.distributed.DistributedSampler(
79 | self.dataset,
80 | num_replicas=world_size,
81 | rank=rank,
82 | shuffle=not opt.serial_batches
83 | )
84 | self.dataloader = torch.utils.data.DataLoader(
85 | self.dataset,
86 | sampler=self.sampler,
87 | num_workers=int(opt.num_threads / world_size),
88 | batch_size=int(opt.batch_size / world_size),
89 | drop_last=True)
90 | else:
91 | self.dataloader = torch.utils.data.DataLoader(
92 | self.dataset,
93 | batch_size=opt.batch_size,
94 | shuffle=(not opt.serial_batches) and opt.isTrain,
95 | num_workers=int(opt.num_threads),
96 | drop_last=True
97 | )
98 |
99 | def set_epoch(self, epoch):
100 | self.dataset.current_epoch = epoch
101 | if self.sampler is not None:
102 | self.sampler.set_epoch(epoch)
103 |
104 | def load_data(self):
105 | return self
106 |
107 | def __len__(self):
108 | """Return the number of data in the dataset"""
109 | return min(len(self.dataset), self.opt.max_dataset_size)
110 |
111 | def __iter__(self):
112 | """Return a batch of data"""
113 | for i, data in enumerate(self.dataloader):
114 | if i * self.opt.batch_size >= self.opt.max_dataset_size:
115 | break
116 | yield data
117 |
--------------------------------------------------------------------------------
/eg3d-pose-detection/data/base_dataset.py:
--------------------------------------------------------------------------------
1 | """This module implements an abstract base class (ABC) 'BaseDataset' for datasets.
2 |
3 | It also includes common transformation functions (e.g., get_transform, __scale_width), which can be later used in subclasses.
4 | """
5 | import random
6 | import numpy as np
7 | import torch.utils.data as data
8 | from PIL import Image
9 | import torchvision.transforms as transforms
10 | from abc import ABC, abstractmethod
11 |
12 |
13 | class BaseDataset(data.Dataset, ABC):
14 | """This class is an abstract base class (ABC) for datasets.
15 |
16 | To create a subclass, you need to implement the following four functions:
17 | -- <__init__>: initialize the class, first call BaseDataset.__init__(self, opt).
18 | -- <__len__>: return the size of dataset.
19 | -- <__getitem__>: get a data point.
20 |     -- <modify_commandline_options>: (optionally) add dataset-specific options and set default options.
21 | """
22 |
23 | def __init__(self, opt):
24 | """Initialize the class; save the options in the class
25 |
26 | Parameters:
27 | opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions
28 | """
29 | self.opt = opt
30 | # self.root = opt.dataroot
31 | self.current_epoch = 0
32 |
33 | @staticmethod
34 | def modify_commandline_options(parser, is_train):
35 | """Add new dataset-specific options, and rewrite default values for existing options.
36 |
37 | Parameters:
38 | parser -- original option parser
39 | is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options.
40 |
41 | Returns:
42 | the modified parser.
43 | """
44 | return parser
45 |
46 | @abstractmethod
47 | def __len__(self):
48 | """Return the total number of images in the dataset."""
49 | return 0
50 |
51 | @abstractmethod
52 | def __getitem__(self, index):
53 | """Return a data point and its metadata information.
54 |
55 | Parameters:
56 |             index -- a random integer for data indexing
57 |
58 | Returns:
59 |             a dictionary of data with their names. It usually contains the data itself and its metadata information.
60 | """
61 | pass
62 |
63 |
64 | def get_transform(grayscale=False):
65 | transform_list = []
66 | if grayscale:
67 | transform_list.append(transforms.Grayscale(1))
68 | transform_list += [transforms.ToTensor()]
69 | return transforms.Compose(transform_list)
70 |
71 | def get_affine_mat(opt, size):
72 | shift_x, shift_y, scale, rot_angle, flip = 0., 0., 1., 0., False
73 | w, h = size
74 |
75 | if 'shift' in opt.preprocess:
76 | shift_pixs = int(opt.shift_pixs)
77 | shift_x = random.randint(-shift_pixs, shift_pixs)
78 | shift_y = random.randint(-shift_pixs, shift_pixs)
79 | if 'scale' in opt.preprocess:
80 | scale = 1 + opt.scale_delta * (2 * random.random() - 1)
81 | if 'rot' in opt.preprocess:
82 | rot_angle = opt.rot_angle * (2 * random.random() - 1)
83 |     rot_rad = -rot_angle * np.pi/180  # computed outside the 'rot' branch so it is always defined
84 | if 'flip' in opt.preprocess:
85 | flip = random.random() > 0.5
86 |
87 | shift_to_origin = np.array([1, 0, -w//2, 0, 1, -h//2, 0, 0, 1]).reshape([3, 3])
88 | flip_mat = np.array([-1 if flip else 1, 0, 0, 0, 1, 0, 0, 0, 1]).reshape([3, 3])
89 | shift_mat = np.array([1, 0, shift_x, 0, 1, shift_y, 0, 0, 1]).reshape([3, 3])
90 | rot_mat = np.array([np.cos(rot_rad), np.sin(rot_rad), 0, -np.sin(rot_rad), np.cos(rot_rad), 0, 0, 0, 1]).reshape([3, 3])
91 | scale_mat = np.array([scale, 0, 0, 0, scale, 0, 0, 0, 1]).reshape([3, 3])
92 | shift_to_center = np.array([1, 0, w//2, 0, 1, h//2, 0, 0, 1]).reshape([3, 3])
93 |
94 | affine = shift_to_center @ scale_mat @ rot_mat @ shift_mat @ flip_mat @ shift_to_origin
95 | affine_inv = np.linalg.inv(affine)
96 | return affine, affine_inv, flip
97 |
98 | def apply_img_affine(img, affine_inv, method=Image.LANCZOS):
99 |     return img.transform(img.size, Image.AFFINE, data=affine_inv.flatten()[:6], resample=method)
100 |
101 | def apply_lm_affine(landmark, affine, flip, size):
102 | _, h = size
103 | lm = landmark.copy()
104 | lm[:, 1] = h - 1 - lm[:, 1]
105 | lm = np.concatenate((lm, np.ones([lm.shape[0], 1])), -1)
106 | lm = lm @ np.transpose(affine)
107 | lm[:, :2] = lm[:, :2] / lm[:, 2:]
108 | lm = lm[:, :2]
109 | lm[:, 1] = h - 1 - lm[:, 1]
110 | if flip:
111 | lm_ = lm.copy()
112 | lm_[:17] = lm[16::-1]
113 | lm_[17:22] = lm[26:21:-1]
114 | lm_[22:27] = lm[21:16:-1]
115 | lm_[31:36] = lm[35:30:-1]
116 | lm_[36:40] = lm[45:41:-1]
117 | lm_[40:42] = lm[47:45:-1]
118 | lm_[42:46] = lm[39:35:-1]
119 | lm_[46:48] = lm[41:39:-1]
120 | lm_[48:55] = lm[54:47:-1]
121 | lm_[55:60] = lm[59:54:-1]
122 | lm_[60:65] = lm[64:59:-1]
123 | lm_[65:68] = lm[67:64:-1]
124 | lm = lm_
125 | return lm
126 |
--------------------------------------------------------------------------------
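The helpers above compose the augmentation as `shift_to_center @ scale_mat @ rot_mat @ shift_mat @ flip_mat @ shift_to_origin`; the image is warped with the inverse matrix (PIL's `Image.transform` expects the output-to-input map), while landmarks use the forward matrix. A small usage sketch with a hypothetical `opt` namespace (the field names mirror what `get_affine_mat` reads):

```python
# Hypothetical usage of the augmentation helpers above (not repository code).
from types import SimpleNamespace
import numpy as np
from PIL import Image
from data.base_dataset import get_affine_mat, apply_img_affine, apply_lm_affine

opt = SimpleNamespace(preprocess='shift,scale,rot,flip',
                      shift_pixs=10, scale_delta=0.1, rot_angle=10)
img = Image.new('RGB', (224, 224))
lm = np.zeros((68, 2), dtype=np.float32)               # 68 facial landmarks in image coordinates

affine, affine_inv, flip = get_affine_mat(opt, img.size)
img_aug = apply_img_affine(img, affine_inv)             # inverse map warps the image
lm_aug = apply_lm_affine(lm, affine, flip, img.size)    # forward map, plus landmark re-ordering on flip
```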
/eg3d-pose-detection/data/flist_dataset.py:
--------------------------------------------------------------------------------
1 | """This script defines the custom dataset for Deep3DFaceRecon_pytorch
2 | """
3 |
4 | import os.path
5 | from data.base_dataset import BaseDataset, get_transform, get_affine_mat, apply_img_affine, apply_lm_affine
6 | from data.image_folder import make_dataset
7 | from PIL import Image
8 | import random
9 | import util.util as util
10 | import numpy as np
11 | import json
12 | import torch
13 | from scipy.io import loadmat, savemat
14 | import pickle
15 | from util.preprocess import align_img, estimate_norm
16 | from util.load_mats import load_lm3d
17 |
18 |
19 | def default_flist_reader(flist):
20 | """
21 |     flist format: impath label\nimpath label\n ... (same as caffe's filelist)
22 | """
23 | imlist = []
24 | with open(flist, 'r') as rf:
25 | for line in rf.readlines():
26 | impath = line.strip()
27 | imlist.append(impath)
28 |
29 | return imlist
30 |
31 | def jason_flist_reader(flist):
32 | with open(flist, 'r') as fp:
33 | info = json.load(fp)
34 | return info
35 |
36 | def parse_label(label):
37 | return torch.tensor(np.array(label).astype(np.float32))
38 |
39 |
40 | class FlistDataset(BaseDataset):
41 | """
42 |     It requires one directory to host the training images, e.g., '/path/to/data/train'.
43 | You can train the model with the dataset flag '--dataroot /path/to/data'.
44 | """
45 |
46 | def __init__(self, opt):
47 | """Initialize this dataset class.
48 |
49 | Parameters:
50 | opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
51 | """
52 | BaseDataset.__init__(self, opt)
53 |
54 | self.lm3d_std = load_lm3d(opt.bfm_folder)
55 |
56 | msk_names = default_flist_reader(opt.flist)
57 | self.msk_paths = [os.path.join(opt.data_root, i) for i in msk_names]
58 |
59 | self.size = len(self.msk_paths)
60 | self.opt = opt
61 |
62 | self.name = 'train' if opt.isTrain else 'val'
63 | if '_' in opt.flist:
64 | self.name += '_' + opt.flist.split(os.sep)[-1].split('_')[0]
65 |
66 |
67 | def __getitem__(self, index):
68 | """Return a data point and its metadata information.
69 |
70 | Parameters:
71 | index (int) -- a random integer for data indexing
72 |
73 | Returns a dictionary that contains A, B, A_paths and B_paths
74 | img (tensor) -- an image in the input domain
75 | msk (tensor) -- its corresponding attention mask
76 | lm (tensor) -- its corresponding 3d landmarks
77 | im_paths (str) -- image paths
78 |             aug_flag (bool) -- a flag used to tell whether it's raw or augmented
79 | """
80 |         msk_path = self.msk_paths[index % self.size] # make sure index is within the range
81 | img_path = msk_path.replace('mask/', '')
82 | lm_path = '.'.join(msk_path.replace('mask', 'landmarks').split('.')[:-1]) + '.txt'
83 |
84 | raw_img = Image.open(img_path).convert('RGB')
85 | raw_msk = Image.open(msk_path).convert('RGB')
86 | raw_lm = np.loadtxt(lm_path).astype(np.float32)
87 |
88 | _, img, lm, msk = align_img(raw_img, raw_lm, self.lm3d_std, raw_msk)
89 |
90 | aug_flag = self.opt.use_aug and self.opt.isTrain
91 | if aug_flag:
92 | img, lm, msk = self._augmentation(img, lm, self.opt, msk)
93 |
94 | _, H = img.size
95 | M = estimate_norm(lm, H)
96 | transform = get_transform()
97 | img_tensor = transform(img)
98 | msk_tensor = transform(msk)[:1, ...]
99 | lm_tensor = parse_label(lm)
100 | M_tensor = parse_label(M)
101 |
102 |
103 | return {'imgs': img_tensor,
104 | 'lms': lm_tensor,
105 | 'msks': msk_tensor,
106 | 'M': M_tensor,
107 | 'im_paths': img_path,
108 | 'aug_flag': aug_flag,
109 | 'dataset': self.name}
110 |
111 | def _augmentation(self, img, lm, opt, msk=None):
112 | affine, affine_inv, flip = get_affine_mat(opt, img.size)
113 | img = apply_img_affine(img, affine_inv)
114 | lm = apply_lm_affine(lm, affine, flip, img.size)
115 | if msk is not None:
116 | msk = apply_img_affine(msk, affine_inv, method=Image.BILINEAR)
117 | return img, lm, msk
118 |
119 |
120 |
121 |
122 | def __len__(self):
123 | """Return the total number of images in the dataset.
124 | """
125 | return self.size
126 |
--------------------------------------------------------------------------------
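To make the path handling in `__getitem__` concrete: each flist line is a mask path relative to `opt.data_root`, and the matching image and landmark files are derived from it by string substitution. A worked example with made-up paths:

```python
# Hypothetical illustration of FlistDataset's path mapping (paths are placeholders).
msk_path = '/data/id1/mask/frame_0001.png'   # entry from opt.flist joined with opt.data_root
img_path = msk_path.replace('mask/', '')     # -> '/data/id1/frame_0001.png'
lm_path = '.'.join(msk_path.replace('mask', 'landmarks').split('.')[:-1]) + '.txt'
                                             # -> '/data/id1/landmarks/frame_0001.txt'
```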
/eg3d-pose-detection/data/image_folder.py:
--------------------------------------------------------------------------------
1 | """A modified image folder class
2 |
3 | We modify the official PyTorch image folder (https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py)
4 | so that this class can load images from both current directory and its subdirectories.
5 | """
6 | import numpy as np
7 | import torch.utils.data as data
8 |
9 | from PIL import Image
10 | import os
11 | import os.path
12 |
13 | IMG_EXTENSIONS = [
14 | '.jpg', '.JPG', '.jpeg', '.JPEG',
15 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
16 | '.tif', '.TIF', '.tiff', '.TIFF',
17 | ]
18 |
19 |
20 | def is_image_file(filename):
21 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
22 |
23 |
24 | def make_dataset(dir, max_dataset_size=float("inf")):
25 | images = []
26 | assert os.path.isdir(dir) or os.path.islink(dir), '%s is not a valid directory' % dir
27 |
28 | for root, _, fnames in sorted(os.walk(dir, followlinks=True)):
29 | for fname in fnames:
30 | if is_image_file(fname):
31 | path = os.path.join(root, fname)
32 | images.append(path)
33 | return images[:min(max_dataset_size, len(images))]
34 |
35 |
36 | def default_loader(path):
37 | return Image.open(path).convert('RGB')
38 |
39 |
40 | class ImageFolder(data.Dataset):
41 |
42 | def __init__(self, root, transform=None, return_paths=False,
43 | loader=default_loader):
44 | imgs = make_dataset(root)
45 | if len(imgs) == 0:
46 | raise(RuntimeError("Found 0 images in: " + root + "\n"
47 | "Supported image extensions are: " + ",".join(IMG_EXTENSIONS)))
48 |
49 | self.root = root
50 | self.imgs = imgs
51 | self.transform = transform
52 | self.return_paths = return_paths
53 | self.loader = loader
54 |
55 | def __getitem__(self, index):
56 | path = self.imgs[index]
57 | img = self.loader(path)
58 | if self.transform is not None:
59 | img = self.transform(img)
60 | if self.return_paths:
61 | return img, path
62 | else:
63 | return img
64 |
65 | def __len__(self):
66 | return len(self.imgs)
67 |
--------------------------------------------------------------------------------
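A short usage sketch of the folder scanner and dataset above; the directory path is a placeholder.

```python
# Hypothetical usage of make_dataset / ImageFolder.
import torchvision.transforms as transforms
from data.image_folder import make_dataset, ImageFolder

paths = make_dataset('/path/to/images')               # recursive scan for files with known image extensions
dataset = ImageFolder('/path/to/images',
                      transform=transforms.ToTensor(),
                      return_paths=True)
img, path = dataset[0]                                # CHW float tensor in [0, 1] and its source path
```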
/eg3d-pose-detection/data/template_dataset.py:
--------------------------------------------------------------------------------
1 | """Dataset class template
2 |
3 | This module provides a template for users to implement custom datasets.
4 | You can specify '--dataset_mode template' to use this dataset.
5 | The class name should be consistent with both the filename and its dataset_mode option.
6 | The filename should be <dataset_mode>_dataset.py
7 | The class name should be <Dataset_mode>Dataset
8 | You need to implement the following functions:
9 |     -- <modify_commandline_options>: Add dataset-specific options and rewrite default values for existing options.
10 | -- <__init__>: Initialize this dataset class.
11 | -- <__getitem__>: Return a data point and its metadata information.
12 | -- <__len__>: Return the number of images.
13 | """
14 | from data.base_dataset import BaseDataset, get_transform
15 | # from data.image_folder import make_dataset
16 | # from PIL import Image
17 |
18 |
19 | class TemplateDataset(BaseDataset):
20 | """A template dataset class for you to implement custom datasets."""
21 | @staticmethod
22 | def modify_commandline_options(parser, is_train):
23 | """Add new dataset-specific options, and rewrite default values for existing options.
24 |
25 | Parameters:
26 | parser -- original option parser
27 | is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options.
28 |
29 | Returns:
30 | the modified parser.
31 | """
32 | parser.add_argument('--new_dataset_option', type=float, default=1.0, help='new dataset option')
33 | parser.set_defaults(max_dataset_size=10, new_dataset_option=2.0) # specify dataset-specific default values
34 | return parser
35 |
36 | def __init__(self, opt):
37 | """Initialize this dataset class.
38 |
39 | Parameters:
40 | opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
41 |
42 | A few things can be done here.
43 | - save the options (have been done in BaseDataset)
44 | - get image paths and meta information of the dataset.
45 | - define the image transformation.
46 | """
47 | # save the option and dataset root
48 | BaseDataset.__init__(self, opt)
49 | # get the image paths of your dataset;
50 | self.image_paths = [] # You can call sorted(make_dataset(self.root, opt.max_dataset_size)) to get all the image paths under the directory self.root
51 |         # define the default transform function. You can use <base_dataset.get_transform>; you can also define your own custom transform function
52 |         self.transform = get_transform()  # note: this repo's get_transform() takes an optional grayscale flag, not opt
53 |
54 | def __getitem__(self, index):
55 | """Return a data point and its metadata information.
56 |
57 | Parameters:
58 | index -- a random integer for data indexing
59 |
60 | Returns:
61 | a dictionary of data with their names. It usually contains the data itself and its metadata information.
62 |
63 | Step 1: get a random image path: e.g., path = self.image_paths[index]
64 | Step 2: load your data from the disk: e.g., image = Image.open(path).convert('RGB').
65 |         Step 3: convert your data to a PyTorch tensor. You can use helper functions such as self.transform, e.g., data = self.transform(image)
66 | Step 4: return a data point as a dictionary.
67 | """
68 | path = 'temp' # needs to be a string
69 | data_A = None # needs to be a tensor
70 | data_B = None # needs to be a tensor
71 | return {'data_A': data_A, 'data_B': data_B, 'path': path}
72 |
73 | def __len__(self):
74 | """Return the total number of images."""
75 | return len(self.image_paths)
76 |
--------------------------------------------------------------------------------
/eg3d-pose-detection/models/__init__.py:
--------------------------------------------------------------------------------
1 | """This package contains modules related to objective functions, optimizations, and network architectures.
2 |
3 | To add a custom model class called 'dummy', you need to add a file called 'dummy_model.py' and define a subclass DummyModel inherited from BaseModel.
4 | You need to implement the following five functions:
5 | -- <__init__>: initialize the class; first call BaseModel.__init__(self, opt).
6 |     -- <set_input>: unpack data from dataset and apply preprocessing.
7 |     -- <forward>: produce intermediate results.
8 |     -- <optimize_parameters>: calculate loss, gradients, and update network weights.
9 |     -- <modify_commandline_options>: (optionally) add model-specific options and set default options.
10 |
11 | In the function <__init__>, you need to define four lists:
12 | -- self.loss_names (str list): specify the training losses that you want to plot and save.
13 | -- self.model_names (str list): define networks used in our training.
14 | -- self.visual_names (str list): specify the images that you want to display and save.
15 |     -- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an example.
16 |
17 | Now you can use the model class by specifying flag '--model dummy'.
18 | See our template model class 'template_model.py' for more details.
19 | """
20 |
21 | import importlib
22 | from models.base_model import BaseModel
23 |
24 |
25 | def find_model_using_name(model_name):
26 | """Import the module "models/[model_name]_model.py".
27 |
28 |     In the file, the class called [ModelName]Model() will
29 | be instantiated. It has to be a subclass of BaseModel,
30 | and it is case-insensitive.
31 | """
32 | model_filename = "models." + model_name + "_model"
33 | modellib = importlib.import_module(model_filename)
34 | model = None
35 | target_model_name = model_name.replace('_', '') + 'model'
36 | for name, cls in modellib.__dict__.items():
37 | if name.lower() == target_model_name.lower() \
38 | and issubclass(cls, BaseModel):
39 | model = cls
40 |
41 | if model is None:
42 | print("In %s.py, there should be a subclass of BaseModel with class name that matches %s in lowercase." % (model_filename, target_model_name))
43 |         exit(1)  # exit with a non-zero status since no matching model class was found
44 |
45 | return model
46 |
47 |
48 | def get_option_setter(model_name):
49 | """Return the static method of the model class."""
50 | model_class = find_model_using_name(model_name)
51 | return model_class.modify_commandline_options
52 |
53 |
54 | def create_model(opt):
55 | """Create a model given the option.
56 |
57 |     This function wraps the model class specified by opt.model.
58 | This is the main interface between this package and 'train.py'/'test.py'
59 |
60 | Example:
61 | >>> from models import create_model
62 | >>> model = create_model(opt)
63 | """
64 | model = find_model_using_name(opt.model)
65 | instance = model(opt)
66 | print("model [%s] was created" % type(instance).__name__)
67 | return instance
68 |
--------------------------------------------------------------------------------
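Analogous to the dataset package, a hypothetical `models/dummy_model.py` satisfying the contract described in the docstring could look like the sketch below; the class body is illustrative only and omits any real networks or losses (and `BaseModel.__init__` still expects the usual option fields).

```python
# models/dummy_model.py -- hypothetical example; '--model dummy' makes find_model_using_name() pick DummyModel.
from models.base_model import BaseModel


class DummyModel(BaseModel):
    def __init__(self, opt):
        BaseModel.__init__(self, opt)
        self.loss_names = []      # losses to log/plot
        self.model_names = []     # networks to save/load
        self.visual_names = []    # images to display/save
        self.optimizers = []      # one optimizer per (group of) network(s)

    def set_input(self, input):
        self.input = input        # unpack data from the dataloader

    def forward(self):
        pass                      # produce intermediate results

    def optimize_parameters(self):
        pass                      # compute losses, backpropagate, and step the optimizers
```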
/eg3d-pose-detection/models/arcface_torch/backbones/__init__.py:
--------------------------------------------------------------------------------
1 | from .iresnet import iresnet18, iresnet34, iresnet50, iresnet100, iresnet200
2 | from .mobilefacenet import get_mbf
3 |
4 |
5 | def get_model(name, **kwargs):
6 | # resnet
7 | if name == "r18":
8 | return iresnet18(False, **kwargs)
9 | elif name == "r34":
10 | return iresnet34(False, **kwargs)
11 | elif name == "r50":
12 | return iresnet50(False, **kwargs)
13 | elif name == "r100":
14 | return iresnet100(False, **kwargs)
15 | elif name == "r200":
16 | return iresnet200(False, **kwargs)
17 | elif name == "r2060":
18 | from .iresnet2060 import iresnet2060
19 | return iresnet2060(False, **kwargs)
20 |
21 | elif name == "mbf":
22 | fp16 = kwargs.get("fp16", False)
23 | num_features = kwargs.get("num_features", 512)
24 | return get_mbf(fp16=fp16, num_features=num_features)
25 |
26 | elif name == "mbf_large":
27 | from .mobilefacenet import get_mbf_large
28 | fp16 = kwargs.get("fp16", False)
29 | num_features = kwargs.get("num_features", 512)
30 | return get_mbf_large(fp16=fp16, num_features=num_features)
31 |
32 | elif name == "vit_t":
33 | num_features = kwargs.get("num_features", 512)
34 | from .vit import VisionTransformer
35 | return VisionTransformer(
36 | img_size=112, patch_size=9, num_classes=num_features, embed_dim=256, depth=12,
37 | num_heads=8, drop_path_rate=0.1, norm_layer="ln", mask_ratio=0.1)
38 |
39 | elif name == "vit_t_dp005_mask0": # For WebFace42M
40 | num_features = kwargs.get("num_features", 512)
41 | from .vit import VisionTransformer
42 | return VisionTransformer(
43 | img_size=112, patch_size=9, num_classes=num_features, embed_dim=256, depth=12,
44 | num_heads=8, drop_path_rate=0.05, norm_layer="ln", mask_ratio=0.0)
45 |
46 | elif name == "vit_s":
47 | num_features = kwargs.get("num_features", 512)
48 | from .vit import VisionTransformer
49 | return VisionTransformer(
50 | img_size=112, patch_size=9, num_classes=num_features, embed_dim=512, depth=12,
51 | num_heads=8, drop_path_rate=0.1, norm_layer="ln", mask_ratio=0.1)
52 |
53 | elif name == "vit_s_dp005_mask_0": # For WebFace42M
54 | num_features = kwargs.get("num_features", 512)
55 | from .vit import VisionTransformer
56 | return VisionTransformer(
57 | img_size=112, patch_size=9, num_classes=num_features, embed_dim=512, depth=12,
58 | num_heads=8, drop_path_rate=0.05, norm_layer="ln", mask_ratio=0.0)
59 |
60 | elif name == "vit_b":
61 | # this is a feature
62 | num_features = kwargs.get("num_features", 512)
63 | from .vit import VisionTransformer
64 | return VisionTransformer(
65 | img_size=112, patch_size=9, num_classes=num_features, embed_dim=512, depth=24,
66 | num_heads=8, drop_path_rate=0.1, norm_layer="ln", mask_ratio=0.1, using_checkpoint=True)
67 |
68 | elif name == "vit_b_dp005_mask_005": # For WebFace42M
69 | # this is a feature
70 | num_features = kwargs.get("num_features", 512)
71 | from .vit import VisionTransformer
72 | return VisionTransformer(
73 | img_size=112, patch_size=9, num_classes=num_features, embed_dim=512, depth=24,
74 | num_heads=8, drop_path_rate=0.05, norm_layer="ln", mask_ratio=0.05, using_checkpoint=True)
75 |
76 | elif name == "vit_l_dp005_mask_005": # For WebFace42M
77 | # this is a feature
78 | num_features = kwargs.get("num_features", 512)
79 | from .vit import VisionTransformer
80 | return VisionTransformer(
81 | img_size=112, patch_size=9, num_classes=num_features, embed_dim=768, depth=24,
82 | num_heads=8, drop_path_rate=0.05, norm_layer="ln", mask_ratio=0.05, using_checkpoint=True)
83 |
84 | else:
85 |         raise ValueError("unknown backbone name: %s" % name)
86 |
--------------------------------------------------------------------------------
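A quick shape check of the factory above (hypothetical snippet, assumed to be run from the `arcface_torch` directory so that `backbones` is importable):

```python
# Hypothetical usage of get_model(); ArcFace backbones expect 112x112 aligned face crops.
import torch
from backbones import get_model

net = get_model("r50", fp16=False, num_features=512).eval()
with torch.no_grad():
    emb = net(torch.randn(1, 3, 112, 112))
print(emb.shape)  # torch.Size([1, 512])
```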
/eg3d-pose-detection/models/arcface_torch/backbones/iresnet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from torch.utils.checkpoint import checkpoint
4 |
5 | __all__ = ['iresnet18', 'iresnet34', 'iresnet50', 'iresnet100', 'iresnet200']
6 | using_ckpt = False
7 |
8 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
9 | """3x3 convolution with padding"""
10 | return nn.Conv2d(in_planes,
11 | out_planes,
12 | kernel_size=3,
13 | stride=stride,
14 | padding=dilation,
15 | groups=groups,
16 | bias=False,
17 | dilation=dilation)
18 |
19 |
20 | def conv1x1(in_planes, out_planes, stride=1):
21 | """1x1 convolution"""
22 | return nn.Conv2d(in_planes,
23 | out_planes,
24 | kernel_size=1,
25 | stride=stride,
26 | bias=False)
27 |
28 |
29 | class IBasicBlock(nn.Module):
30 | expansion = 1
31 | def __init__(self, inplanes, planes, stride=1, downsample=None,
32 | groups=1, base_width=64, dilation=1):
33 | super(IBasicBlock, self).__init__()
34 | if groups != 1 or base_width != 64:
35 | raise ValueError('BasicBlock only supports groups=1 and base_width=64')
36 | if dilation > 1:
37 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
38 | self.bn1 = nn.BatchNorm2d(inplanes, eps=1e-05,)
39 | self.conv1 = conv3x3(inplanes, planes)
40 | self.bn2 = nn.BatchNorm2d(planes, eps=1e-05,)
41 | self.prelu = nn.PReLU(planes)
42 | self.conv2 = conv3x3(planes, planes, stride)
43 | self.bn3 = nn.BatchNorm2d(planes, eps=1e-05,)
44 | self.downsample = downsample
45 | self.stride = stride
46 |
47 | def forward_impl(self, x):
48 | identity = x
49 | out = self.bn1(x)
50 | out = self.conv1(out)
51 | out = self.bn2(out)
52 | out = self.prelu(out)
53 | out = self.conv2(out)
54 | out = self.bn3(out)
55 | if self.downsample is not None:
56 | identity = self.downsample(x)
57 | out += identity
58 | return out
59 |
60 | def forward(self, x):
61 | if self.training and using_ckpt:
62 | return checkpoint(self.forward_impl, x)
63 | else:
64 | return self.forward_impl(x)
65 |
66 |
67 | class IResNet(nn.Module):
68 | fc_scale = 7 * 7
69 | def __init__(self,
70 | block, layers, dropout=0, num_features=512, zero_init_residual=False,
71 | groups=1, width_per_group=64, replace_stride_with_dilation=None, fp16=False):
72 | super(IResNet, self).__init__()
73 | self.extra_gflops = 0.0
74 | self.fp16 = fp16
75 | self.inplanes = 64
76 | self.dilation = 1
77 | if replace_stride_with_dilation is None:
78 | replace_stride_with_dilation = [False, False, False]
79 | if len(replace_stride_with_dilation) != 3:
80 | raise ValueError("replace_stride_with_dilation should be None "
81 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
82 | self.groups = groups
83 | self.base_width = width_per_group
84 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False)
85 | self.bn1 = nn.BatchNorm2d(self.inplanes, eps=1e-05)
86 | self.prelu = nn.PReLU(self.inplanes)
87 | self.layer1 = self._make_layer(block, 64, layers[0], stride=2)
88 | self.layer2 = self._make_layer(block,
89 | 128,
90 | layers[1],
91 | stride=2,
92 | dilate=replace_stride_with_dilation[0])
93 | self.layer3 = self._make_layer(block,
94 | 256,
95 | layers[2],
96 | stride=2,
97 | dilate=replace_stride_with_dilation[1])
98 | self.layer4 = self._make_layer(block,
99 | 512,
100 | layers[3],
101 | stride=2,
102 | dilate=replace_stride_with_dilation[2])
103 | self.bn2 = nn.BatchNorm2d(512 * block.expansion, eps=1e-05,)
104 | self.dropout = nn.Dropout(p=dropout, inplace=True)
105 | self.fc = nn.Linear(512 * block.expansion * self.fc_scale, num_features)
106 | self.features = nn.BatchNorm1d(num_features, eps=1e-05)
107 | nn.init.constant_(self.features.weight, 1.0)
108 | self.features.weight.requires_grad = False
109 |
110 | for m in self.modules():
111 | if isinstance(m, nn.Conv2d):
112 | nn.init.normal_(m.weight, 0, 0.1)
113 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
114 | nn.init.constant_(m.weight, 1)
115 | nn.init.constant_(m.bias, 0)
116 |
117 | if zero_init_residual:
118 | for m in self.modules():
119 | if isinstance(m, IBasicBlock):
120 | nn.init.constant_(m.bn2.weight, 0)
121 |
122 | def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
123 | downsample = None
124 | previous_dilation = self.dilation
125 | if dilate:
126 | self.dilation *= stride
127 | stride = 1
128 | if stride != 1 or self.inplanes != planes * block.expansion:
129 | downsample = nn.Sequential(
130 | conv1x1(self.inplanes, planes * block.expansion, stride),
131 | nn.BatchNorm2d(planes * block.expansion, eps=1e-05, ),
132 | )
133 | layers = []
134 | layers.append(
135 | block(self.inplanes, planes, stride, downsample, self.groups,
136 | self.base_width, previous_dilation))
137 | self.inplanes = planes * block.expansion
138 | for _ in range(1, blocks):
139 | layers.append(
140 | block(self.inplanes,
141 | planes,
142 | groups=self.groups,
143 | base_width=self.base_width,
144 | dilation=self.dilation))
145 |
146 | return nn.Sequential(*layers)
147 |
148 | def forward(self, x):
149 | with torch.cuda.amp.autocast(self.fp16):
150 | x = self.conv1(x)
151 | x = self.bn1(x)
152 | x = self.prelu(x)
153 | x = self.layer1(x)
154 | x = self.layer2(x)
155 | x = self.layer3(x)
156 | x = self.layer4(x)
157 | x = self.bn2(x)
158 | x = torch.flatten(x, 1)
159 | x = self.dropout(x)
160 | x = self.fc(x.float() if self.fp16 else x)
161 | x = self.features(x)
162 | return x
163 |
164 |
165 | def _iresnet(arch, block, layers, pretrained, progress, **kwargs):
166 | model = IResNet(block, layers, **kwargs)
167 | if pretrained:
168 |         raise ValueError("pretrained weights are not available for %s" % arch)
169 | return model
170 |
171 |
172 | def iresnet18(pretrained=False, progress=True, **kwargs):
173 | return _iresnet('iresnet18', IBasicBlock, [2, 2, 2, 2], pretrained,
174 | progress, **kwargs)
175 |
176 |
177 | def iresnet34(pretrained=False, progress=True, **kwargs):
178 | return _iresnet('iresnet34', IBasicBlock, [3, 4, 6, 3], pretrained,
179 | progress, **kwargs)
180 |
181 |
182 | def iresnet50(pretrained=False, progress=True, **kwargs):
183 | return _iresnet('iresnet50', IBasicBlock, [3, 4, 14, 3], pretrained,
184 | progress, **kwargs)
185 |
186 |
187 | def iresnet100(pretrained=False, progress=True, **kwargs):
188 | return _iresnet('iresnet100', IBasicBlock, [3, 13, 30, 3], pretrained,
189 | progress, **kwargs)
190 |
191 |
192 | def iresnet200(pretrained=False, progress=True, **kwargs):
193 | return _iresnet('iresnet200', IBasicBlock, [6, 26, 60, 6], pretrained,
194 | progress, **kwargs)
195 |
--------------------------------------------------------------------------------
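For reference, `fc_scale = 7 * 7` encodes the final feature-map size: the `conv1` stem keeps the 112x112 input resolution (stride 1), and each of `layer1`-`layer4` halves it, so the map entering `bn2` is 112 / 2^4 = 7 pixels wide and the final linear layer sees 512 * 7 * 7 = 25088 inputs.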
/eg3d-pose-detection/models/arcface_torch/backbones/iresnet2060.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 |
4 | assert torch.__version__ >= "1.8.1"
5 | from torch.utils.checkpoint import checkpoint_sequential
6 |
7 | __all__ = ['iresnet2060']
8 |
9 |
10 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
11 | """3x3 convolution with padding"""
12 | return nn.Conv2d(in_planes,
13 | out_planes,
14 | kernel_size=3,
15 | stride=stride,
16 | padding=dilation,
17 | groups=groups,
18 | bias=False,
19 | dilation=dilation)
20 |
21 |
22 | def conv1x1(in_planes, out_planes, stride=1):
23 | """1x1 convolution"""
24 | return nn.Conv2d(in_planes,
25 | out_planes,
26 | kernel_size=1,
27 | stride=stride,
28 | bias=False)
29 |
30 |
31 | class IBasicBlock(nn.Module):
32 | expansion = 1
33 |
34 | def __init__(self, inplanes, planes, stride=1, downsample=None,
35 | groups=1, base_width=64, dilation=1):
36 | super(IBasicBlock, self).__init__()
37 | if groups != 1 or base_width != 64:
38 | raise ValueError('BasicBlock only supports groups=1 and base_width=64')
39 | if dilation > 1:
40 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
41 | self.bn1 = nn.BatchNorm2d(inplanes, eps=1e-05, )
42 | self.conv1 = conv3x3(inplanes, planes)
43 | self.bn2 = nn.BatchNorm2d(planes, eps=1e-05, )
44 | self.prelu = nn.PReLU(planes)
45 | self.conv2 = conv3x3(planes, planes, stride)
46 | self.bn3 = nn.BatchNorm2d(planes, eps=1e-05, )
47 | self.downsample = downsample
48 | self.stride = stride
49 |
50 | def forward(self, x):
51 | identity = x
52 | out = self.bn1(x)
53 | out = self.conv1(out)
54 | out = self.bn2(out)
55 | out = self.prelu(out)
56 | out = self.conv2(out)
57 | out = self.bn3(out)
58 | if self.downsample is not None:
59 | identity = self.downsample(x)
60 | out += identity
61 | return out
62 |
63 |
64 | class IResNet(nn.Module):
65 | fc_scale = 7 * 7
66 |
67 | def __init__(self,
68 | block, layers, dropout=0, num_features=512, zero_init_residual=False,
69 | groups=1, width_per_group=64, replace_stride_with_dilation=None, fp16=False):
70 | super(IResNet, self).__init__()
71 | self.fp16 = fp16
72 | self.inplanes = 64
73 | self.dilation = 1
74 | if replace_stride_with_dilation is None:
75 | replace_stride_with_dilation = [False, False, False]
76 | if len(replace_stride_with_dilation) != 3:
77 | raise ValueError("replace_stride_with_dilation should be None "
78 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
79 | self.groups = groups
80 | self.base_width = width_per_group
81 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False)
82 | self.bn1 = nn.BatchNorm2d(self.inplanes, eps=1e-05)
83 | self.prelu = nn.PReLU(self.inplanes)
84 | self.layer1 = self._make_layer(block, 64, layers[0], stride=2)
85 | self.layer2 = self._make_layer(block,
86 | 128,
87 | layers[1],
88 | stride=2,
89 | dilate=replace_stride_with_dilation[0])
90 | self.layer3 = self._make_layer(block,
91 | 256,
92 | layers[2],
93 | stride=2,
94 | dilate=replace_stride_with_dilation[1])
95 | self.layer4 = self._make_layer(block,
96 | 512,
97 | layers[3],
98 | stride=2,
99 | dilate=replace_stride_with_dilation[2])
100 | self.bn2 = nn.BatchNorm2d(512 * block.expansion, eps=1e-05, )
101 | self.dropout = nn.Dropout(p=dropout, inplace=True)
102 | self.fc = nn.Linear(512 * block.expansion * self.fc_scale, num_features)
103 | self.features = nn.BatchNorm1d(num_features, eps=1e-05)
104 | nn.init.constant_(self.features.weight, 1.0)
105 | self.features.weight.requires_grad = False
106 |
107 | for m in self.modules():
108 | if isinstance(m, nn.Conv2d):
109 | nn.init.normal_(m.weight, 0, 0.1)
110 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
111 | nn.init.constant_(m.weight, 1)
112 | nn.init.constant_(m.bias, 0)
113 |
114 | if zero_init_residual:
115 | for m in self.modules():
116 | if isinstance(m, IBasicBlock):
117 | nn.init.constant_(m.bn2.weight, 0)
118 |
119 | def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
120 | downsample = None
121 | previous_dilation = self.dilation
122 | if dilate:
123 | self.dilation *= stride
124 | stride = 1
125 | if stride != 1 or self.inplanes != planes * block.expansion:
126 | downsample = nn.Sequential(
127 | conv1x1(self.inplanes, planes * block.expansion, stride),
128 | nn.BatchNorm2d(planes * block.expansion, eps=1e-05, ),
129 | )
130 | layers = []
131 | layers.append(
132 | block(self.inplanes, planes, stride, downsample, self.groups,
133 | self.base_width, previous_dilation))
134 | self.inplanes = planes * block.expansion
135 | for _ in range(1, blocks):
136 | layers.append(
137 | block(self.inplanes,
138 | planes,
139 | groups=self.groups,
140 | base_width=self.base_width,
141 | dilation=self.dilation))
142 |
143 | return nn.Sequential(*layers)
144 |
145 | def checkpoint(self, func, num_seg, x):
146 | if self.training:
147 | return checkpoint_sequential(func, num_seg, x)
148 | else:
149 | return func(x)
150 |
151 | def forward(self, x):
152 | with torch.cuda.amp.autocast(self.fp16):
153 | x = self.conv1(x)
154 | x = self.bn1(x)
155 | x = self.prelu(x)
156 | x = self.layer1(x)
157 | x = self.checkpoint(self.layer2, 20, x)
158 | x = self.checkpoint(self.layer3, 100, x)
159 | x = self.layer4(x)
160 | x = self.bn2(x)
161 | x = torch.flatten(x, 1)
162 | x = self.dropout(x)
163 | x = self.fc(x.float() if self.fp16 else x)
164 | x = self.features(x)
165 | return x
166 |
167 |
168 | def _iresnet(arch, block, layers, pretrained, progress, **kwargs):
169 | model = IResNet(block, layers, **kwargs)
170 | if pretrained:
171 |         raise ValueError("pretrained weights are not available for %s" % arch)
172 | return model
173 |
174 |
175 | def iresnet2060(pretrained=False, progress=True, **kwargs):
176 | return _iresnet('iresnet2060', IBasicBlock, [3, 128, 1024 - 128, 3], pretrained, progress, **kwargs)
177 |
--------------------------------------------------------------------------------
/eg3d-pose-detection/models/arcface_torch/backbones/mobilefacenet.py:
--------------------------------------------------------------------------------
1 | '''
2 | Adapted from https://github.com/cavalleria/cavaface.pytorch/blob/master/backbone/mobilefacenet.py
3 | Original author cavalleria
4 | '''
5 |
6 | import torch.nn as nn
7 | from torch.nn import Linear, Conv2d, BatchNorm1d, BatchNorm2d, PReLU, Sequential, Module
8 | import torch
9 |
10 |
11 | class Flatten(Module):
12 | def forward(self, x):
13 | return x.view(x.size(0), -1)
14 |
15 |
16 | class ConvBlock(Module):
17 | def __init__(self, in_c, out_c, kernel=(1, 1), stride=(1, 1), padding=(0, 0), groups=1):
18 | super(ConvBlock, self).__init__()
19 | self.layers = nn.Sequential(
20 | Conv2d(in_c, out_c, kernel, groups=groups, stride=stride, padding=padding, bias=False),
21 | BatchNorm2d(num_features=out_c),
22 | PReLU(num_parameters=out_c)
23 | )
24 |
25 | def forward(self, x):
26 | return self.layers(x)
27 |
28 |
29 | class LinearBlock(Module):
30 | def __init__(self, in_c, out_c, kernel=(1, 1), stride=(1, 1), padding=(0, 0), groups=1):
31 | super(LinearBlock, self).__init__()
32 | self.layers = nn.Sequential(
33 | Conv2d(in_c, out_c, kernel, stride, padding, groups=groups, bias=False),
34 | BatchNorm2d(num_features=out_c)
35 | )
36 |
37 | def forward(self, x):
38 | return self.layers(x)
39 |
40 |
41 | class DepthWise(Module):
42 | def __init__(self, in_c, out_c, residual=False, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=1):
43 | super(DepthWise, self).__init__()
44 | self.residual = residual
45 | self.layers = nn.Sequential(
46 | ConvBlock(in_c, out_c=groups, kernel=(1, 1), padding=(0, 0), stride=(1, 1)),
47 | ConvBlock(groups, groups, groups=groups, kernel=kernel, padding=padding, stride=stride),
48 | LinearBlock(groups, out_c, kernel=(1, 1), padding=(0, 0), stride=(1, 1))
49 | )
50 |
51 | def forward(self, x):
52 | short_cut = None
53 | if self.residual:
54 | short_cut = x
55 | x = self.layers(x)
56 | if self.residual:
57 | output = short_cut + x
58 | else:
59 | output = x
60 | return output
61 |
62 |
63 | class Residual(Module):
64 | def __init__(self, c, num_block, groups, kernel=(3, 3), stride=(1, 1), padding=(1, 1)):
65 | super(Residual, self).__init__()
66 | modules = []
67 | for _ in range(num_block):
68 | modules.append(DepthWise(c, c, True, kernel, stride, padding, groups))
69 | self.layers = Sequential(*modules)
70 |
71 | def forward(self, x):
72 | return self.layers(x)
73 |
74 |
75 | class GDC(Module):
76 | def __init__(self, embedding_size):
77 | super(GDC, self).__init__()
78 | self.layers = nn.Sequential(
79 | LinearBlock(512, 512, groups=512, kernel=(7, 7), stride=(1, 1), padding=(0, 0)),
80 | Flatten(),
81 | Linear(512, embedding_size, bias=False),
82 | BatchNorm1d(embedding_size))
83 |
84 | def forward(self, x):
85 | return self.layers(x)
86 |
87 |
88 | class MobileFaceNet(Module):
89 | def __init__(self, fp16=False, num_features=512, blocks=(1, 4, 6, 2), scale=2):
90 | super(MobileFaceNet, self).__init__()
91 | self.scale = scale
92 | self.fp16 = fp16
93 | self.layers = nn.ModuleList()
94 | self.layers.append(
95 | ConvBlock(3, 64 * self.scale, kernel=(3, 3), stride=(2, 2), padding=(1, 1))
96 | )
97 | if blocks[0] == 1:
98 | self.layers.append(
99 | ConvBlock(64 * self.scale, 64 * self.scale, kernel=(3, 3), stride=(1, 1), padding=(1, 1), groups=64)
100 | )
101 | else:
102 | self.layers.append(
103 | Residual(64 * self.scale, num_block=blocks[0], groups=128, kernel=(3, 3), stride=(1, 1), padding=(1, 1)),
104 | )
105 |
106 | self.layers.extend(
107 | [
108 | DepthWise(64 * self.scale, 64 * self.scale, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=128),
109 | Residual(64 * self.scale, num_block=blocks[1], groups=128, kernel=(3, 3), stride=(1, 1), padding=(1, 1)),
110 | DepthWise(64 * self.scale, 128 * self.scale, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=256),
111 | Residual(128 * self.scale, num_block=blocks[2], groups=256, kernel=(3, 3), stride=(1, 1), padding=(1, 1)),
112 | DepthWise(128 * self.scale, 128 * self.scale, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=512),
113 | Residual(128 * self.scale, num_block=blocks[3], groups=256, kernel=(3, 3), stride=(1, 1), padding=(1, 1)),
114 | ])
115 |
116 | self.conv_sep = ConvBlock(128 * self.scale, 512, kernel=(1, 1), stride=(1, 1), padding=(0, 0))
117 | self.features = GDC(num_features)
118 | self._initialize_weights()
119 |
120 | def _initialize_weights(self):
121 | for m in self.modules():
122 | if isinstance(m, nn.Conv2d):
123 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
124 | if m.bias is not None:
125 | m.bias.data.zero_()
126 | elif isinstance(m, nn.BatchNorm2d):
127 | m.weight.data.fill_(1)
128 | m.bias.data.zero_()
129 | elif isinstance(m, nn.Linear):
130 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
131 | if m.bias is not None:
132 | m.bias.data.zero_()
133 |
134 | def forward(self, x):
135 | with torch.cuda.amp.autocast(self.fp16):
136 | for func in self.layers:
137 | x = func(x)
138 | x = self.conv_sep(x.float() if self.fp16 else x)
139 | x = self.features(x)
140 | return x
141 |
142 |
143 | def get_mbf(fp16, num_features, blocks=(1, 4, 6, 2), scale=2):
144 | return MobileFaceNet(fp16, num_features, blocks, scale=scale)
145 |
146 | def get_mbf_large(fp16, num_features, blocks=(2, 8, 12, 4), scale=4):
147 | return MobileFaceNet(fp16, num_features, blocks, scale=scale)
148 |
--------------------------------------------------------------------------------
/eg3d-pose-detection/models/arcface_torch/configs/3millions.py:
--------------------------------------------------------------------------------
1 | from easydict import EasyDict as edict
2 |
3 | # configs for test speed
4 |
5 | config = edict()
6 | config.margin_list = (1.0, 0.0, 0.4)
7 | config.network = "mbf"
8 | config.resume = False
9 | config.output = None
10 | config.embedding_size = 512
11 | config.sample_rate = 0.1
12 | config.fp16 = True
13 | config.momentum = 0.9
14 | config.weight_decay = 5e-4
15 | config.batch_size = 512 # total_batch_size = batch_size * num_gpus
16 | config.lr = 0.1 # batch size is 512
17 |
18 | config.rec = "synthetic"
19 | config.num_classes = 30 * 10000
20 | config.num_image = 100000
21 | config.num_epoch = 30
22 | config.warmup_epoch = -1
23 | config.val_targets = []
24 |
--------------------------------------------------------------------------------
/eg3d-pose-detection/models/arcface_torch/configs/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bbaaii/HFA-GP/aa2c15a61d8ddd182189153914098a2af0edfb0c/eg3d-pose-detection/models/arcface_torch/configs/__init__.py
--------------------------------------------------------------------------------
/eg3d-pose-detection/models/arcface_torch/configs/base.py:
--------------------------------------------------------------------------------
1 | from easydict import EasyDict as edict
2 |
3 | # make training faster
4 | # our RAM is 256G
5 | # mount -t tmpfs -o size=140G tmpfs /train_tmp
6 |
7 | config = edict()
8 |
9 | # Margin Base Softmax
10 | config.margin_list = (1.0, 0.5, 0.0)
11 | config.network = "r50"
12 | config.resume = False
13 | config.save_all_states = False
14 | config.output = "ms1mv3_arcface_r50"
15 |
16 | config.embedding_size = 512
17 |
18 | # Partial FC
19 | config.sample_rate = 1
20 | config.interclass_filtering_threshold = 0
21 |
22 | config.fp16 = False
23 | config.batch_size = 128
24 |
25 | # For SGD
26 | config.optimizer = "sgd"
27 | config.lr = 0.1
28 | config.momentum = 0.9
29 | config.weight_decay = 5e-4
30 |
31 | # For AdamW
32 | # config.optimizer = "adamw"
33 | # config.lr = 0.001
34 | # config.weight_decay = 0.1
35 |
36 | config.verbose = 2000
37 | config.frequent = 10
38 |
39 | # For Large Scale Datasets, such as WebFace42M
40 | config.dali = False
41 |
42 |
43 | # setup seed
44 | config.seed = 2048
45 |
46 | # dataloader num_workers
47 | config.num_workers = 8
48 |
--------------------------------------------------------------------------------
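Each config file simply builds an `EasyDict` named `config`; how the training scripts select one is not shown in this section, but a hypothetical loader (not the repository's own utility) could read any of them as a plain Python module:

```python
# Hypothetical config loader -- for illustration only; assumed to run from the arcface_torch directory.
import importlib

cfg = importlib.import_module("configs.base").config
global_batch = cfg.batch_size * 8            # e.g., 8 GPUs -> 128 * 8 = 1024 images per step
print(cfg.network, cfg.optimizer, cfg.lr, global_batch)
```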
/eg3d-pose-detection/models/arcface_torch/configs/glint360k_mbf.py:
--------------------------------------------------------------------------------
1 | from easydict import EasyDict as edict
2 |
3 | # make training faster
4 | # our RAM is 256G
5 | # mount -t tmpfs -o size=140G tmpfs /train_tmp
6 |
7 | config = edict()
8 | config.margin_list = (1.0, 0.0, 0.4)
9 | config.network = "mbf"
10 | config.resume = False
11 | config.output = None
12 | config.embedding_size = 512
13 | config.sample_rate = 1.0
14 | config.fp16 = True
15 | config.momentum = 0.9
16 | config.weight_decay = 1e-4
17 | config.batch_size = 128
18 | config.lr = 0.1
19 | config.verbose = 2000
20 | config.dali = False
21 |
22 | config.rec = "/train_tmp/glint360k"
23 | config.num_classes = 360232
24 | config.num_image = 17091657
25 | config.num_epoch = 20
26 | config.warmup_epoch = 0
27 | config.val_targets = ['lfw', 'cfp_fp', "agedb_30"]
28 |
--------------------------------------------------------------------------------
/eg3d-pose-detection/models/arcface_torch/configs/glint360k_r100.py:
--------------------------------------------------------------------------------
1 | from easydict import EasyDict as edict
2 |
3 | # make training faster
4 | # our RAM is 256G
5 | # mount -t tmpfs -o size=140G tmpfs /train_tmp
6 |
7 | config = edict()
8 | config.margin_list = (1.0, 0.0, 0.4)
9 | config.network = "r100"
10 | config.resume = False
11 | config.output = None
12 | config.embedding_size = 512
13 | config.sample_rate = 1.0
14 | config.fp16 = True
15 | config.momentum = 0.9
16 | config.weight_decay = 1e-4
17 | config.batch_size = 128
18 | config.lr = 0.1
19 | config.verbose = 2000
20 | config.dali = False
21 |
22 | config.rec = "/train_tmp/glint360k"
23 | config.num_classes = 360232
24 | config.num_image = 17091657
25 | config.num_epoch = 20
26 | config.warmup_epoch = 0
27 | config.val_targets = ['lfw', 'cfp_fp', "agedb_30"]
28 |
--------------------------------------------------------------------------------
/eg3d-pose-detection/models/arcface_torch/configs/glint360k_r50.py:
--------------------------------------------------------------------------------
1 | from easydict import EasyDict as edict
2 |
3 | # make training faster
4 | # our RAM is 256G
5 | # mount -t tmpfs -o size=140G tmpfs /train_tmp
6 |
7 | config = edict()
8 | config.margin_list = (1.0, 0.0, 0.4)
9 | config.network = "r50"
10 | config.resume = False
11 | config.output = None
12 | config.embedding_size = 512
13 | config.sample_rate = 1.0
14 | config.fp16 = True
15 | config.momentum = 0.9
16 | config.weight_decay = 1e-4
17 | config.batch_size = 128
18 | config.lr = 0.1
19 | config.verbose = 2000
20 | config.dali = False
21 |
22 | config.rec = "/train_tmp/glint360k"
23 | config.num_classes = 360232
24 | config.num_image = 17091657
25 | config.num_epoch = 20
26 | config.warmup_epoch = 0
27 | config.val_targets = ['lfw', 'cfp_fp', "agedb_30"]
28 |
--------------------------------------------------------------------------------
/eg3d-pose-detection/models/arcface_torch/configs/ms1mv2_mbf.py:
--------------------------------------------------------------------------------
1 | from easydict import EasyDict as edict
2 |
3 | # make training faster
4 | # our RAM is 256G
5 | # mount -t tmpfs -o size=140G tmpfs /train_tmp
6 |
7 | config = edict()
8 | config.margin_list = (1.0, 0.5, 0.0)
9 | config.network = "mbf"
10 | config.resume = False
11 | config.output = None
12 | config.embedding_size = 512
13 | config.sample_rate = 1.0
14 | config.fp16 = True
15 | config.momentum = 0.9
16 | config.weight_decay = 1e-4
17 | config.batch_size = 128
18 | config.lr = 0.1
19 | config.verbose = 2000
20 | config.dali = False
21 |
22 | config.rec = "/train_tmp/faces_emore"
23 | config.num_classes = 85742
24 | config.num_image = 5822653
25 | config.num_epoch = 40
26 | config.warmup_epoch = 0
27 | config.val_targets = ['lfw', 'cfp_fp', "agedb_30"]
28 |
--------------------------------------------------------------------------------
/eg3d-pose-detection/models/arcface_torch/configs/ms1mv2_r100.py:
--------------------------------------------------------------------------------
1 | from easydict import EasyDict as edict
2 |
3 | # make training faster
4 | # our RAM is 256G
5 | # mount -t tmpfs -o size=140G tmpfs /train_tmp
6 |
7 | config = edict()
8 | config.margin_list = (1.0, 0.5, 0.0)
9 | config.network = "r100"
10 | config.resume = False
11 | config.output = None
12 | config.embedding_size = 512
13 | config.sample_rate = 1.0
14 | config.fp16 = True
15 | config.momentum = 0.9
16 | config.weight_decay = 5e-4
17 | config.batch_size = 128
18 | config.lr = 0.1
19 | config.verbose = 2000
20 | config.dali = False
21 |
22 | config.rec = "/train_tmp/faces_emore"
23 | config.num_classes = 85742
24 | config.num_image = 5822653
25 | config.num_epoch = 20
26 | config.warmup_epoch = 0
27 | config.val_targets = ['lfw', 'cfp_fp', "agedb_30"]
28 |
--------------------------------------------------------------------------------
/eg3d-pose-detection/models/arcface_torch/configs/ms1mv2_r50.py:
--------------------------------------------------------------------------------
1 | from easydict import EasyDict as edict
2 |
3 | # make training faster
4 | # our RAM is 256G
5 | # mount -t tmpfs -o size=140G tmpfs /train_tmp
6 |
7 | config = edict()
8 | config.margin_list = (1.0, 0.5, 0.0)
9 | config.network = "r50"
10 | config.resume = False
11 | config.output = None
12 | config.embedding_size = 512
13 | config.sample_rate = 1.0
14 | config.fp16 = True
15 | config.momentum = 0.9
16 | config.weight_decay = 5e-4
17 | config.batch_size = 128
18 | config.lr = 0.1
19 | config.verbose = 2000
20 | config.dali = False
21 |
22 | config.rec = "/train_tmp/faces_emore"
23 | config.num_classes = 85742
24 | config.num_image = 5822653
25 | config.num_epoch = 20
26 | config.warmup_epoch = 0
27 | config.val_targets = ['lfw', 'cfp_fp', "agedb_30"]
28 |
--------------------------------------------------------------------------------
/eg3d-pose-detection/models/arcface_torch/configs/ms1mv3_mbf.py:
--------------------------------------------------------------------------------
1 | from easydict import EasyDict as edict
2 |
3 | # make training faster
4 | # our RAM is 256G
5 | # mount -t tmpfs -o size=140G tmpfs /train_tmp
6 |
7 | config = edict()
8 | config.margin_list = (1.0, 0.5, 0.0)
9 | config.network = "mbf"
10 | config.resume = False
11 | config.output = None
12 | config.embedding_size = 512
13 | config.sample_rate = 1.0
14 | config.fp16 = True
15 | config.momentum = 0.9
16 | config.weight_decay = 1e-4
17 | config.batch_size = 128
18 | config.lr = 0.1
19 | config.verbose = 2000
20 | config.dali = False
21 |
22 | config.rec = "/train_tmp/ms1m-retinaface-t1"
23 | config.num_classes = 93431
24 | config.num_image = 5179510
25 | config.num_epoch = 40
26 | config.warmup_epoch = 0
27 | config.val_targets = ['lfw', 'cfp_fp', "agedb_30"]
28 |
--------------------------------------------------------------------------------
/eg3d-pose-detection/models/arcface_torch/configs/ms1mv3_r100.py:
--------------------------------------------------------------------------------
1 | from easydict import EasyDict as edict
2 |
3 | # make training faster
4 | # our RAM is 256G
5 | # mount -t tmpfs -o size=140G tmpfs /train_tmp
6 |
7 | config = edict()
8 | config.margin_list = (1.0, 0.5, 0.0)
9 | config.network = "r100"
10 | config.resume = False
11 | config.output = None
12 | config.embedding_size = 512
13 | config.sample_rate = 1.0
14 | config.fp16 = True
15 | config.momentum = 0.9
16 | config.weight_decay = 5e-4
17 | config.batch_size = 128
18 | config.lr = 0.1
19 | config.verbose = 2000
20 | config.dali = False
21 |
22 | config.rec = "/train_tmp/ms1m-retinaface-t1"
23 | config.num_classes = 93431
24 | config.num_image = 5179510
25 | config.num_epoch = 20
26 | config.warmup_epoch = 0
27 | config.val_targets = ['lfw', 'cfp_fp', "agedb_30"]
28 |
--------------------------------------------------------------------------------
/eg3d-pose-detection/models/arcface_torch/configs/ms1mv3_r50.py:
--------------------------------------------------------------------------------
1 | from easydict import EasyDict as edict
2 |
3 | # make training faster
4 | # our RAM is 256G
5 | # mount -t tmpfs -o size=140G tmpfs /train_tmp
6 |
7 | config = edict()
8 | config.margin_list = (1.0, 0.5, 0.0)
9 | config.network = "r50"
10 | config.resume = False
11 | config.output = None
12 | config.embedding_size = 512
13 | config.sample_rate = 1.0
14 | config.fp16 = True
15 | config.momentum = 0.9
16 | config.weight_decay = 5e-4
17 | config.batch_size = 128
18 | config.lr = 0.1
19 | config.verbose = 2000
20 | config.dali = False
21 |
22 | config.rec = "/train_tmp/ms1m-retinaface-t1"
23 | config.num_classes = 93431
24 | config.num_image = 5179510
25 | config.num_epoch = 20
26 | config.warmup_epoch = 0
27 | config.val_targets = ['lfw', 'cfp_fp', "agedb_30"]
28 |
--------------------------------------------------------------------------------
/eg3d-pose-detection/models/arcface_torch/configs/wf12m_conflict_r50.py:
--------------------------------------------------------------------------------
1 | from easydict import EasyDict as edict
2 |
3 | # make training faster
4 | # our RAM is 256G
5 | # mount -t tmpfs -o size=140G tmpfs /train_tmp
6 |
7 | config = edict()
8 | config.margin_list = (1.0, 0.0, 0.4)
9 | config.network = "r50"
10 | config.resume = False
11 | config.output = None
12 | config.embedding_size = 512
13 | config.sample_rate = 1.0
14 | config.interclass_filtering_threshold = 0
15 | config.fp16 = True
16 | config.weight_decay = 5e-4
17 | config.batch_size = 128
18 | config.optimizer = "sgd"
19 | config.lr = 0.1
20 | config.verbose = 2000
21 | config.dali = False
22 |
23 | config.rec = "/train_tmp/WebFace12M_Conflict"
24 | config.num_classes = 1017970
25 | config.num_image = 12720066
26 | config.num_epoch = 20
27 | config.warmup_epoch = config.num_epoch // 10
28 | config.val_targets = []
29 |
--------------------------------------------------------------------------------
/eg3d-pose-detection/models/arcface_torch/configs/wf12m_conflict_r50_pfc03_filter04.py:
--------------------------------------------------------------------------------
1 | from easydict import EasyDict as edict
2 |
3 | # make training faster
4 | # our RAM is 256G
5 | # mount -t tmpfs -o size=140G tmpfs /train_tmp
6 |
7 | config = edict()
8 | config.margin_list = (1.0, 0.0, 0.4)
9 | config.network = "r50"
10 | config.resume = False
11 | config.output = None
12 | config.embedding_size = 512
13 | config.sample_rate = 0.3
14 | config.interclass_filtering_threshold = 0.4
15 | config.fp16 = True
16 | config.weight_decay = 5e-4
17 | config.batch_size = 128
18 | config.optimizer = "sgd"
19 | config.lr = 0.1
20 | config.verbose = 2000
21 | config.dali = False
22 |
23 | config.rec = "/train_tmp/WebFace12M_Conflict"
24 | config.num_classes = 1017970
25 | config.num_image = 12720066
26 | config.num_epoch = 20
27 | config.warmup_epoch = config.num_epoch // 10
28 | config.val_targets = []
29 |
--------------------------------------------------------------------------------
/eg3d-pose-detection/models/arcface_torch/configs/wf12m_flip_pfc01_filter04_r50.py:
--------------------------------------------------------------------------------
1 | from easydict import EasyDict as edict
2 |
3 | # make training faster
4 | # our RAM is 256G
5 | # mount -t tmpfs -o size=140G tmpfs /train_tmp
6 |
7 | config = edict()
8 | config.margin_list = (1.0, 0.0, 0.4)
9 | config.network = "r50"
10 | config.resume = False
11 | config.output = None
12 | config.embedding_size = 512
13 | config.sample_rate = 0.1
14 | config.interclass_filtering_threshold = 0.4
15 | config.fp16 = True
16 | config.weight_decay = 5e-4
17 | config.batch_size = 128
18 | config.optimizer = "sgd"
19 | config.lr = 0.1
20 | config.verbose = 2000
21 | config.dali = False
22 |
23 | config.rec = "/train_tmp/WebFace12M_FLIP40"
24 | config.num_classes = 617970
25 | config.num_image = 12720066
26 | config.num_epoch = 20
27 | config.warmup_epoch = config.num_epoch // 10
28 | config.val_targets = []
29 |
--------------------------------------------------------------------------------
/eg3d-pose-detection/models/arcface_torch/configs/wf12m_flip_r50.py:
--------------------------------------------------------------------------------
1 | from easydict import EasyDict as edict
2 |
3 | # make training faster
4 | # our RAM is 256G
5 | # mount -t tmpfs -o size=140G tmpfs /train_tmp
6 |
7 | config = edict()
8 | config.margin_list = (1.0, 0.0, 0.4)
9 | config.network = "r50"
10 | config.resume = False
11 | config.output = None
12 | config.embedding_size = 512
13 | config.sample_rate = 1.0
14 | config.interclass_filtering_threshold = 0
15 | config.fp16 = True
16 | config.weight_decay = 5e-4
17 | config.batch_size = 128
18 | config.optimizer = "sgd"
19 | config.lr = 0.1
20 | config.verbose = 2000
21 | config.dali = False
22 |
23 | config.rec = "/train_tmp/WebFace12M_FLIP40"
24 | config.num_classes = 617970
25 | config.num_image = 12720066
26 | config.num_epoch = 20
27 | config.warmup_epoch = config.num_epoch // 10
28 | config.val_targets = []
29 |
--------------------------------------------------------------------------------
/eg3d-pose-detection/models/arcface_torch/configs/wf12m_mbf.py:
--------------------------------------------------------------------------------
1 | from easydict import EasyDict as edict
2 |
3 | # make training faster
4 | # our RAM is 256G
5 | # mount -t tmpfs -o size=140G tmpfs /train_tmp
6 |
7 | config = edict()
8 | config.margin_list = (1.0, 0.0, 0.4)
9 | config.network = "mbf"
10 | config.resume = False
11 | config.output = None
12 | config.embedding_size = 512
13 | config.sample_rate = 1.0
14 | config.interclass_filtering_threshold = 0
15 | config.fp16 = True
16 | config.weight_decay = 1e-4
17 | config.batch_size = 128
18 | config.optimizer = "sgd"
19 | config.lr = 0.1
20 | config.verbose = 2000
21 | config.dali = False
22 |
23 | config.rec = "/train_tmp/WebFace12M"
24 | config.num_classes = 617970
25 | config.num_image = 12720066
26 | config.num_epoch = 20
27 | config.warmup_epoch = 0
28 | config.val_targets = []
29 |
--------------------------------------------------------------------------------
/eg3d-pose-detection/models/arcface_torch/configs/wf12m_pfc02_r100.py:
--------------------------------------------------------------------------------
1 |
2 | from easydict import EasyDict as edict
3 |
4 | # make training faster
5 | # our RAM is 256G
6 | # mount -t tmpfs -o size=140G tmpfs /train_tmp
7 |
8 | config = edict()
9 | config.margin_list = (1.0, 0.0, 0.4)
10 | config.network = "r100"
11 | config.resume = False
12 | config.output = None
13 | config.embedding_size = 512
14 | config.sample_rate = 0.2
15 | config.interclass_filtering_threshold = 0
16 | config.fp16 = True
17 | config.weight_decay = 5e-4
18 | config.batch_size = 128
19 | config.optimizer = "sgd"
20 | config.lr = 0.1
21 | config.verbose = 2000
22 | config.dali = False
23 |
24 | config.rec = "/train_tmp/WebFace12M"
25 | config.num_classes = 617970
26 | config.num_image = 12720066
27 | config.num_epoch = 20
28 | config.warmup_epoch = 0
29 | config.val_targets = []
30 |
--------------------------------------------------------------------------------
/eg3d-pose-detection/models/arcface_torch/configs/wf12m_r100.py:
--------------------------------------------------------------------------------
1 |
2 | from easydict import EasyDict as edict
3 |
4 | # make training faster
5 | # our RAM is 256G
6 | # mount -t tmpfs -o size=140G tmpfs /train_tmp
7 |
8 | config = edict()
9 | config.margin_list = (1.0, 0.0, 0.4)
10 | config.network = "r100"
11 | config.resume = False
12 | config.output = None
13 | config.embedding_size = 512
14 | config.sample_rate = 1.0
15 | config.interclass_filtering_threshold = 0
16 | config.fp16 = True
17 | config.weight_decay = 5e-4
18 | config.batch_size = 128
19 | config.optimizer = "sgd"
20 | config.lr = 0.1
21 | config.verbose = 2000
22 | config.dali = False
23 |
24 | config.rec = "/train_tmp/WebFace12M"
25 | config.num_classes = 617970
26 | config.num_image = 12720066
27 | config.num_epoch = 20
28 | config.warmup_epoch = 0
29 | config.val_targets = []
30 |
--------------------------------------------------------------------------------
/eg3d-pose-detection/models/arcface_torch/configs/wf12m_r50.py:
--------------------------------------------------------------------------------
1 | from easydict import EasyDict as edict
2 |
3 | # make training faster
4 | # our RAM is 256G
5 | # mount -t tmpfs -o size=140G tmpfs /train_tmp
6 |
7 | config = edict()
8 | config.margin_list = (1.0, 0.0, 0.4)
9 | config.network = "r50"
10 | config.resume = False
11 | config.output = None
12 | config.embedding_size = 512
13 | config.sample_rate = 1.0
14 | config.interclass_filtering_threshold = 0
15 | config.fp16 = True
16 | config.weight_decay = 5e-4
17 | config.batch_size = 128
18 | config.optimizer = "sgd"
19 | config.lr = 0.1
20 | config.verbose = 2000
21 | config.dali = False
22 |
23 | config.rec = "/train_tmp/WebFace12M"
24 | config.num_classes = 617970
25 | config.num_image = 12720066
26 | config.num_epoch = 20
27 | config.warmup_epoch = 0
28 | config.val_targets = []
29 |
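These per-dataset configs only store hyperparameters; `train.py` (later in this repo) derives the actual step schedule from `num_image`, `batch_size`, `num_epoch`, and `warmup_epoch`. A minimal sketch of that arithmetic for `wf12m_r50.py`, assuming a single 8-GPU node (the real `world_size` is whatever the launcher provides, so the numbers are illustrative):

```python
# Illustrative only: mirrors the schedule arithmetic in train.py for configs/wf12m_r50.py.
num_image = 12720066      # config.num_image
batch_size = 128          # per-GPU batch size from the config
world_size = 8            # assumption: 8 GPUs on one node
num_epoch = 20
warmup_epoch = 0

total_batch_size = batch_size * world_size        # 1024 images per optimizer step
steps_per_epoch = num_image // total_batch_size   # ~12421 steps per epoch
total_step = steps_per_epoch * num_epoch          # ~248420 steps in total
warmup_step = steps_per_epoch * warmup_epoch      # 0 warmup steps for this config

print(total_batch_size, steps_per_epoch, total_step, warmup_step)
```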
--------------------------------------------------------------------------------
/eg3d-pose-detection/models/arcface_torch/configs/wf42m_pfc0008_32gpu_r100.py:
--------------------------------------------------------------------------------
1 | from easydict import EasyDict as edict
2 |
3 | # make training faster
4 | # our RAM is 256G
5 | # mount -t tmpfs -o size=140G tmpfs /train_tmp
6 |
7 | config = edict()
8 | config.margin_list = (1.0, 0.0, 0.4)
9 | config.network = "r100"
10 | config.resume = False
11 | config.output = None
12 | config.embedding_size = 512
13 | config.sample_rate = 0
14 | config.fp16 = True
15 | config.momentum = 0.9
16 | config.weight_decay = 5e-4
17 | config.batch_size = 512
18 | config.lr = 0.4
19 | config.verbose = 2000
20 | config.dali = False
21 |
22 | config.rec = "/train_tmp/WebFace42M"
23 | config.num_classes = 2059906
24 | config.num_image = 42474557
25 | config.num_epoch = 20
26 | config.warmup_epoch = config.num_epoch // 10
27 | config.val_targets = ["lfw", "cfp_fp", "agedb_30"]
28 |
--------------------------------------------------------------------------------
/eg3d-pose-detection/models/arcface_torch/configs/wf42m_pfc02_16gpus_mbf_bs8k.py:
--------------------------------------------------------------------------------
1 | from easydict import EasyDict as edict
2 |
3 | # make training faster
4 | # our RAM is 256G
5 | # mount -t tmpfs -o size=140G tmpfs /train_tmp
6 |
7 | config = edict()
8 | config.margin_list = (1.0, 0.0, 0.4)
9 | config.network = "mbf"
10 | config.resume = False
11 | config.output = None
12 | config.embedding_size = 512
13 | config.sample_rate = 0.2
14 | config.fp16 = True
15 | config.momentum = 0.9
16 | config.weight_decay = 1e-4
17 | config.batch_size = 512
18 | config.lr = 0.4
19 | config.verbose = 10000
20 | config.dali = False
21 |
22 | config.rec = "/train_tmp/WebFace42M"
23 | config.num_classes = 2059906
24 | config.num_image = 42474557
25 | config.num_epoch = 20
26 | config.warmup_epoch = 2
27 | config.val_targets = []
28 |
--------------------------------------------------------------------------------
/eg3d-pose-detection/models/arcface_torch/configs/wf42m_pfc02_16gpus_r100.py:
--------------------------------------------------------------------------------
1 | from easydict import EasyDict as edict
2 |
3 | # make training faster
4 | # our RAM is 256G
5 | # mount -t tmpfs -o size=140G tmpfs /train_tmp
6 |
7 | config = edict()
8 | config.margin_list = (1.0, 0.0, 0.4)
9 | config.network = "r100"
10 | config.resume = False
11 | config.output = None
12 | config.embedding_size = 512
13 | config.sample_rate = 0.2
14 | config.fp16 = True
15 | config.momentum = 0.9
16 | config.weight_decay = 5e-4
17 | config.batch_size = 256
18 | config.lr = 0.3
19 | config.verbose = 2000
20 | config.dali = False
21 |
22 | config.rec = "/train_tmp/WebFace42M"
23 | config.num_classes = 2059906
24 | config.num_image = 42474557
25 | config.num_epoch = 20
26 | config.warmup_epoch = 1
27 | config.val_targets = ["lfw", "cfp_fp", "agedb_30"]
28 |
--------------------------------------------------------------------------------
/eg3d-pose-detection/models/arcface_torch/configs/wf42m_pfc02_16gpus_r50_bs8k.py:
--------------------------------------------------------------------------------
1 | from easydict import EasyDict as edict
2 |
3 | # make training faster
4 | # our RAM is 256G
5 | # mount -t tmpfs -o size=140G tmpfs /train_tmp
6 |
7 | config = edict()
8 | config.margin_list = (1.0, 0.0, 0.4)
9 | config.network = "r50"
10 | config.resume = False
11 | config.output = None
12 | config.embedding_size = 512
13 | config.sample_rate = 0.2
14 | config.fp16 = True
15 | config.momentum = 0.9
16 | config.weight_decay = 5e-4
17 | config.batch_size = 512
18 | config.lr = 0.6
19 | config.verbose = 10000
20 | config.dali = False
21 |
22 | config.rec = "/train_tmp/WebFace42M"
23 | config.num_classes = 2059906
24 | config.num_image = 42474557
25 | config.num_epoch = 20
26 | config.warmup_epoch = 4
27 | config.val_targets = ["lfw", "cfp_fp", "agedb_30"]
28 |
--------------------------------------------------------------------------------
/eg3d-pose-detection/models/arcface_torch/configs/wf42m_pfc02_32gpus_r50_bs4k.py:
--------------------------------------------------------------------------------
1 | from easydict import EasyDict as edict
2 |
3 | # make training faster
4 | # our RAM is 256G
5 | # mount -t tmpfs -o size=140G tmpfs /train_tmp
6 |
7 | config = edict()
8 | config.margin_list = (1.0, 0.0, 0.4)
9 | config.network = "r50"
10 | config.resume = False
11 | config.output = None
12 | config.embedding_size = 512
13 | config.sample_rate = 0.2
14 | config.fp16 = True
15 | config.momentum = 0.9
16 | config.weight_decay = 5e-4
17 | config.batch_size = 128
18 | config.lr = 0.4
19 | config.verbose = 10000
20 | config.dali = False
21 |
22 | config.rec = "/train_tmp/WebFace42M"
23 | config.num_classes = 2059906
24 | config.num_image = 42474557
25 | config.num_epoch = 20
26 | config.warmup_epoch = 2
27 | config.val_targets = ["lfw", "cfp_fp", "agedb_30"]
28 |
--------------------------------------------------------------------------------
/eg3d-pose-detection/models/arcface_torch/configs/wf42m_pfc02_8gpus_r50_bs4k.py:
--------------------------------------------------------------------------------
1 | from easydict import EasyDict as edict
2 |
3 | # make training faster
4 | # our RAM is 256G
5 | # mount -t tmpfs -o size=140G tmpfs /train_tmp
6 |
7 | config = edict()
8 | config.margin_list = (1.0, 0.0, 0.4)
9 | config.network = "r50"
10 | config.resume = False
11 | config.output = None
12 | config.embedding_size = 512
13 | config.sample_rate = 0.2
14 | config.fp16 = True
15 | config.momentum = 0.9
16 | config.weight_decay = 5e-4
17 | config.batch_size = 512
18 | config.lr = 0.4
19 | config.verbose = 10000
20 | config.dali = False
21 |
22 | config.rec = "/train_tmp/WebFace42M"
23 | config.num_classes = 2059906
24 | config.num_image = 42474557
25 | config.num_epoch = 20
26 | config.warmup_epoch = 2
27 | config.val_targets = ["lfw", "cfp_fp", "agedb_30"]
28 |
--------------------------------------------------------------------------------
/eg3d-pose-detection/models/arcface_torch/configs/wf42m_pfc02_r100.py:
--------------------------------------------------------------------------------
1 | from easydict import EasyDict as edict
2 |
3 | # make training faster
4 | # our RAM is 256G
5 | # mount -t tmpfs -o size=140G tmpfs /train_tmp
6 |
7 | config = edict()
8 | config.margin_list = (1.0, 0.0, 0.4)
9 | config.network = "r100"
10 | config.resume = False
11 | config.output = None
12 | config.embedding_size = 512
13 | config.sample_rate = 0.2
14 | config.fp16 = True
15 | config.momentum = 0.9
16 | config.weight_decay = 5e-4
17 | config.batch_size = 128
18 | config.lr = 0.1
19 | config.verbose = 10000
20 | config.dali = False
21 |
22 | config.rec = "/train_tmp/WebFace42M"
23 | config.num_classes = 2059906
24 | config.num_image = 42474557
25 | config.num_epoch = 20
26 | config.warmup_epoch = 0
27 | config.val_targets = ['lfw', 'cfp_fp', "agedb_30"]
28 |
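The `pfc02` configs set `config.sample_rate = 0.2`, the Partial FC sampling ratio: at each step only that fraction of the class centers takes part in the margin softmax, with the positive classes of the current batch always kept and the rest filled by randomly sampled negatives. The sketch below only illustrates that sampling idea; it is not the `partial_fc.py` implementation.

```python
import torch

def sample_class_centers(num_classes: int, positive_labels: torch.Tensor, sample_rate: float) -> torch.Tensor:
    """Toy illustration of Partial FC sampling: keep all positive classes,
    then fill up to sample_rate * num_classes with randomly chosen negatives."""
    num_sample = int(sample_rate * num_classes)
    positive = torch.unique(positive_labels)
    if num_sample <= positive.numel():
        return positive
    perm = torch.randperm(num_classes)
    negatives = perm[~torch.isin(perm, positive)][: num_sample - positive.numel()]
    return torch.cat([positive, negatives])

labels = torch.randint(0, 2059906, (128,))                      # one batch of WebFace42M labels
centers = sample_class_centers(2059906, labels, sample_rate=0.2)
print(centers.numel())  # roughly 0.2 * num_classes class centers per step
```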
--------------------------------------------------------------------------------
/eg3d-pose-detection/models/arcface_torch/configs/wf42m_pfc02_r100_16gpus.py:
--------------------------------------------------------------------------------
1 | from easydict import EasyDict as edict
2 |
3 | # make training faster
4 | # our RAM is 256G
5 | # mount -t tmpfs -o size=140G tmpfs /train_tmp
6 |
7 | config = edict()
8 | config.margin_list = (1.0, 0.0, 0.4)
9 | config.network = "r100"
10 | config.resume = False
11 | config.output = None
12 | config.embedding_size = 512
13 | config.sample_rate = 0.2
14 | config.fp16 = True
15 | config.momentum = 0.9
16 | config.weight_decay = 5e-4
17 | config.batch_size = 128
18 | config.lr = 0.2
19 | config.verbose = 10000
20 | config.dali = False
21 |
22 | config.rec = "/train_tmp/WebFace42M"
23 | config.num_classes = 2059906
24 | config.num_image = 42474557
25 | config.num_epoch = 20
26 | config.warmup_epoch = config.num_epoch // 10
27 | config.val_targets = ['lfw', 'cfp_fp', "agedb_30"]
28 |
--------------------------------------------------------------------------------
/eg3d-pose-detection/models/arcface_torch/configs/wf42m_pfc02_r100_32gpus.py:
--------------------------------------------------------------------------------
1 | from easydict import EasyDict as edict
2 |
3 | # make training faster
4 | # our RAM is 256G
5 | # mount -t tmpfs -o size=140G tmpfs /train_tmp
6 |
7 | config = edict()
8 | config.margin_list = (1.0, 0.0, 0.4)
9 | config.network = "r100"
10 | config.resume = False
11 | config.output = None
12 | config.embedding_size = 512
13 | config.sample_rate = 0.2
14 | config.fp16 = True
15 | config.momentum = 0.9
16 | config.weight_decay = 5e-4
17 | config.batch_size = 128
18 | config.lr = 0.4
19 | config.verbose = 10000
20 | config.dali = False
21 |
22 | config.rec = "/train_tmp/WebFace42M"
23 | config.num_classes = 2059906
24 | config.num_image = 42474557
25 | config.num_epoch = 20
26 | config.warmup_epoch = config.num_epoch // 10
27 | config.val_targets = ['lfw', 'cfp_fp', "agedb_30"]
28 |
--------------------------------------------------------------------------------
/eg3d-pose-detection/models/arcface_torch/configs/wf42m_pfc03_32gpu_r100.py:
--------------------------------------------------------------------------------
1 | from easydict import EasyDict as edict
2 |
3 | # make training faster
4 | # our RAM is 256G
5 | # mount -t tmpfs -o size=140G tmpfs /train_tmp
6 |
7 | config = edict()
8 | config.margin_list = (1.0, 0.0, 0.4)
9 | config.network = "r100"
10 | config.resume = False
11 | config.output = None
12 | config.embedding_size = 512
13 | config.sample_rate = 0.3
14 | config.fp16 = True
15 | config.momentum = 0.9
16 | config.weight_decay = 5e-4
17 | config.batch_size = 128
18 | config.lr = 0.4
19 | config.verbose = 2000
20 | config.dali = False
21 |
22 | config.rec = "/train_tmp/WebFace42M"
23 | config.num_classes = 2059906
24 | config.num_image = 42474557
25 | config.num_epoch = 20
26 | config.warmup_epoch = config.num_epoch // 10
27 | config.val_targets = ["lfw", "cfp_fp", "agedb_30"]
28 |
--------------------------------------------------------------------------------
/eg3d-pose-detection/models/arcface_torch/configs/wf42m_pfc03_32gpu_r18.py:
--------------------------------------------------------------------------------
1 | from easydict import EasyDict as edict
2 |
3 | # make training faster
4 | # our RAM is 256G
5 | # mount -t tmpfs -o size=140G tmpfs /train_tmp
6 |
7 | config = edict()
8 | config.margin_list = (1.0, 0.0, 0.4)
9 | config.network = "r18"
10 | config.resume = False
11 | config.output = None
12 | config.embedding_size = 512
13 | config.sample_rate = 0.3
14 | config.fp16 = True
15 | config.momentum = 0.9
16 | config.weight_decay = 5e-4
17 | config.batch_size = 128
18 | config.lr = 0.4
19 | config.verbose = 2000
20 | config.dali = False
21 |
22 | config.rec = "/train_tmp/WebFace42M"
23 | config.num_classes = 2059906
24 | config.num_image = 42474557
25 | config.num_epoch = 20
26 | config.warmup_epoch = config.num_epoch // 10
27 | config.val_targets = ["lfw", "cfp_fp", "agedb_30"]
28 |
--------------------------------------------------------------------------------
/eg3d-pose-detection/models/arcface_torch/configs/wf42m_pfc03_32gpu_r200.py:
--------------------------------------------------------------------------------
1 | from easydict import EasyDict as edict
2 |
3 | # make training faster
4 | # our RAM is 256G
5 | # mount -t tmpfs -o size=140G tmpfs /train_tmp
6 |
7 | config = edict()
8 | config.margin_list = (1.0, 0.0, 0.4)
9 | config.network = "r200"
10 | config.resume = False
11 | config.output = None
12 | config.embedding_size = 512
13 | config.sample_rate = 0.3
14 | config.fp16 = True
15 | config.momentum = 0.9
16 | config.weight_decay = 5e-4
17 | config.batch_size = 128
18 | config.lr = 0.4
19 | config.verbose = 2000
20 | config.dali = False
21 |
22 | config.rec = "/train_tmp/WebFace42M"
23 | config.num_classes = 2059906
24 | config.num_image = 42474557
25 | config.num_epoch = 20
26 | config.warmup_epoch = config.num_epoch // 10
27 | config.val_targets = ["lfw", "cfp_fp", "agedb_30"]
28 |
--------------------------------------------------------------------------------
/eg3d-pose-detection/models/arcface_torch/configs/wf42m_pfc03_32gpu_r50.py:
--------------------------------------------------------------------------------
1 | from easydict import EasyDict as edict
2 |
3 | # make training faster
4 | # our RAM is 256G
5 | # mount -t tmpfs -o size=140G tmpfs /train_tmp
6 |
7 | config = edict()
8 | config.margin_list = (1.0, 0.0, 0.4)
9 | config.network = "r50"
10 | config.resume = False
11 | config.output = None
12 | config.embedding_size = 512
13 | config.sample_rate = 0.3
14 | config.fp16 = True
15 | config.momentum = 0.9
16 | config.weight_decay = 5e-4
17 | config.batch_size = 128
18 | config.lr = 0.4
19 | config.verbose = 2000
20 | config.dali = False
21 |
22 | config.rec = "/train_tmp/WebFace42M"
23 | config.num_classes = 2059906
24 | config.num_image = 42474557
25 | config.num_epoch = 20
26 | config.warmup_epoch = config.num_epoch // 10
27 | config.val_targets = ["lfw", "cfp_fp", "agedb_30"]
28 |
--------------------------------------------------------------------------------
/eg3d-pose-detection/models/arcface_torch/configs/wf42m_pfc03_40epoch_64gpu_vit_b.py:
--------------------------------------------------------------------------------
1 | from easydict import EasyDict as edict
2 |
3 | # make training faster
4 | # our RAM is 256G
5 | # mount -t tmpfs -o size=140G tmpfs /train_tmp
6 |
7 | config = edict()
8 | config.margin_list = (1.0, 0.0, 0.4)
9 | config.network = "vit_b_dp005_mask_005"
10 | config.resume = False
11 | config.output = None
12 | config.embedding_size = 512
13 | config.sample_rate = 0.3
14 | config.fp16 = True
15 | config.weight_decay = 0.1
16 | config.batch_size = 384
17 | config.optimizer = "adamw"
18 | config.lr = 0.001
19 | config.verbose = 2000
20 | config.dali = False
21 |
22 | config.rec = "/train_tmp/WebFace42M"
23 | config.num_classes = 2059906
24 | config.num_image = 42474557
25 | config.num_epoch = 40
26 | config.warmup_epoch = config.num_epoch // 10
27 | config.val_targets = []
28 |
--------------------------------------------------------------------------------
/eg3d-pose-detection/models/arcface_torch/configs/wf42m_pfc03_40epoch_64gpu_vit_l.py:
--------------------------------------------------------------------------------
1 | from easydict import EasyDict as edict
2 |
3 | # make training faster
4 | # our RAM is 256G
5 | # mount -t tmpfs -o size=140G tmpfs /train_tmp
6 |
7 | config = edict()
8 | config.margin_list = (1.0, 0.0, 0.4)
9 | config.network = "vit_l_dp005_mask_005"
10 | config.resume = False
11 | config.output = None
12 | config.embedding_size = 512
13 | config.sample_rate = 0.3
14 | config.fp16 = True
15 | config.weight_decay = 0.1
16 | config.batch_size = 384
17 | config.optimizer = "adamw"
18 | config.lr = 0.001
19 | config.verbose = 2000
20 | config.dali = False
21 |
22 | config.rec = "/train_tmp/WebFace42M"
23 | config.num_classes = 2059906
24 | config.num_image = 42474557
25 | config.num_epoch = 40
26 | config.warmup_epoch = config.num_epoch // 10
27 | config.val_targets = []
28 |
--------------------------------------------------------------------------------
/eg3d-pose-detection/models/arcface_torch/configs/wf42m_pfc03_40epoch_64gpu_vit_s.py:
--------------------------------------------------------------------------------
1 | from easydict import EasyDict as edict
2 |
3 | # make training faster
4 | # our RAM is 256G
5 | # mount -t tmpfs -o size=140G tmpfs /train_tmp
6 |
7 | config = edict()
8 | config.margin_list = (1.0, 0.0, 0.4)
9 | config.network = "vit_s_dp005_mask_0"
10 | config.resume = False
11 | config.output = None
12 | config.embedding_size = 512
13 | config.sample_rate = 0.3
14 | config.fp16 = True
15 | config.weight_decay = 0.1
16 | config.batch_size = 384
17 | config.optimizer = "adamw"
18 | config.lr = 0.001
19 | config.verbose = 2000
20 | config.dali = False
21 |
22 | config.rec = "/train_tmp/WebFace42M"
23 | config.num_classes = 2059906
24 | config.num_image = 42474557
25 | config.num_epoch = 40
26 | config.warmup_epoch = config.num_epoch // 10
27 | config.val_targets = []
28 |
--------------------------------------------------------------------------------
/eg3d-pose-detection/models/arcface_torch/configs/wf42m_pfc03_40epoch_64gpu_vit_t.py:
--------------------------------------------------------------------------------
1 | from easydict import EasyDict as edict
2 |
3 | # make training faster
4 | # our RAM is 256G
5 | # mount -t tmpfs -o size=140G tmpfs /train_tmp
6 |
7 | config = edict()
8 | config.margin_list = (1.0, 0.0, 0.4)
9 | config.network = "vit_t_dp005_mask0"
10 | config.resume = False
11 | config.output = None
12 | config.embedding_size = 512
13 | config.sample_rate = 0.3
14 | config.fp16 = True
15 | config.weight_decay = 0.1
16 | config.batch_size = 384
17 | config.optimizer = "adamw"
18 | config.lr = 0.001
19 | config.verbose = 2000
20 | config.dali = False
21 |
22 | config.rec = "/train_tmp/WebFace42M"
23 | config.num_classes = 2059906
24 | config.num_image = 42474557
25 | config.num_epoch = 40
26 | config.warmup_epoch = config.num_epoch // 10
27 | config.val_targets = []
28 |
--------------------------------------------------------------------------------
/eg3d-pose-detection/models/arcface_torch/configs/wf42m_pfc03_40epoch_8gpu_vit_t.py:
--------------------------------------------------------------------------------
1 | from easydict import EasyDict as edict
2 |
3 | # make training faster
4 | # our RAM is 256G
5 | # mount -t tmpfs -o size=140G tmpfs /train_tmp
6 |
7 | config = edict()
8 | config.margin_list = (1.0, 0.0, 0.4)
9 | config.network = "vit_t_dp005_mask0"
10 | config.resume = False
11 | config.output = None
12 | config.embedding_size = 512
13 | config.sample_rate = 0.3
14 | config.fp16 = True
15 | config.weight_decay = 0.1
16 | config.batch_size = 512
17 | config.optimizer = "adamw"
18 | config.lr = 0.001
19 | config.verbose = 2000
20 | config.dali = False
21 |
22 | config.rec = "/train_tmp/WebFace42M"
23 | config.num_classes = 2059906
24 | config.num_image = 42474557
25 | config.num_epoch = 40
26 | config.warmup_epoch = config.num_epoch // 10
27 | config.val_targets = []
28 |
--------------------------------------------------------------------------------
/eg3d-pose-detection/models/arcface_torch/configs/wf4m_mbf.py:
--------------------------------------------------------------------------------
1 | from easydict import EasyDict as edict
2 |
3 | # make training faster
4 | # our RAM is 256G
5 | # mount -t tmpfs -o size=140G tmpfs /train_tmp
6 |
7 | config = edict()
8 | config.margin_list = (1.0, 0.0, 0.4)
9 | config.network = "mbf"
10 | config.resume = False
11 | config.output = None
12 | config.embedding_size = 512
13 | config.sample_rate = 1.0
14 | config.fp16 = True
15 | config.momentum = 0.9
16 | config.weight_decay = 1e-4
17 | config.batch_size = 128
18 | config.lr = 0.1
19 | config.verbose = 2000
20 | config.dali = False
21 |
22 | config.rec = "/train_tmp/WebFace4M"
23 | config.num_classes = 205990
24 | config.num_image = 4235242
25 | config.num_epoch = 20
26 | config.warmup_epoch = 0
27 | config.val_targets = ['lfw', 'cfp_fp', "agedb_30"]
28 |
--------------------------------------------------------------------------------
/eg3d-pose-detection/models/arcface_torch/configs/wf4m_r100.py:
--------------------------------------------------------------------------------
1 | from easydict import EasyDict as edict
2 |
3 | # make training faster
4 | # our RAM is 256G
5 | # mount -t tmpfs -o size=140G tmpfs /train_tmp
6 |
7 | config = edict()
8 | config.margin_list = (1.0, 0.0, 0.4)
9 | config.network = "r100"
10 | config.resume = False
11 | config.output = None
12 | config.embedding_size = 512
13 | config.sample_rate = 1.0
14 | config.fp16 = True
15 | config.momentum = 0.9
16 | config.weight_decay = 5e-4
17 | config.batch_size = 128
18 | config.lr = 0.1
19 | config.verbose = 2000
20 | config.dali = False
21 |
22 | config.rec = "/train_tmp/WebFace4M"
23 | config.num_classes = 205990
24 | config.num_image = 4235242
25 | config.num_epoch = 20
26 | config.warmup_epoch = 0
27 | config.val_targets = ['lfw', 'cfp_fp', "agedb_30"]
28 |
--------------------------------------------------------------------------------
/eg3d-pose-detection/models/arcface_torch/configs/wf4m_r50.py:
--------------------------------------------------------------------------------
1 | from easydict import EasyDict as edict
2 |
3 | # make training faster
4 | # our RAM is 256G
5 | # mount -t tmpfs -o size=140G tmpfs /train_tmp
6 |
7 | config = edict()
8 | config.margin_list = (1.0, 0.0, 0.4)
9 | config.network = "r50"
10 | config.resume = False
11 | config.output = None
12 | config.embedding_size = 512
13 | config.sample_rate = 1.0
14 | config.fp16 = True
15 | config.momentum = 0.9
16 | config.weight_decay = 5e-4
17 | config.batch_size = 128
18 | config.lr = 0.1
19 | config.verbose = 2000
20 | config.dali = False
21 |
22 | config.rec = "/train_tmp/WebFace4M"
23 | config.num_classes = 205990
24 | config.num_image = 4235242
25 | config.num_epoch = 20
26 | config.warmup_epoch = 0
27 | config.val_targets = ['lfw', 'cfp_fp', "agedb_30"]
28 |
--------------------------------------------------------------------------------
/eg3d-pose-detection/models/arcface_torch/dist.sh:
--------------------------------------------------------------------------------
1 | ip_list=("ip1" "ip2" "ip3" "ip4")
2 |
3 | config=wf42m_pfc03_32gpu_r100
4 |
5 | for((node_rank=0;node_rank<${#ip_list[*]};node_rank++));
6 | do
7 | ssh face@${ip_list[node_rank]} "cd `pwd`;PATH=$PATH \
8 | CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \
9 | python -m torch.distributed.launch \
10 | --nproc_per_node=8 \
11 | --nnodes=${#ip_list[*]} \
12 | --node_rank=$node_rank \
13 | --master_addr=${ip_list[0]} \
14 | --master_port=22345 train.py configs/$config" &
15 | done
16 |
--------------------------------------------------------------------------------
/eg3d-pose-detection/models/arcface_torch/docs/eval.md:
--------------------------------------------------------------------------------
1 | ## Eval on ICCV2021-MFR
2 |
3 | coming soon.
4 |
5 |
6 | ## Eval IJBC
7 | You can evaluate IJB-C with either PyTorch or ONNX.

8 |
9 |
10 | 1. Eval IJBC With Onnx
11 | ```shell
12 | CUDA_VISIBLE_DEVICES=0 python onnx_ijbc.py --model-root ms1mv3_arcface_r50 --image-path IJB_release/IJBC --result-dir ms1mv3_arcface_r50
13 | ```
14 |
15 | 2. Eval IJBC With Pytorch
16 | ```shell
17 | CUDA_VISIBLE_DEVICES=0,1 python eval_ijbc.py \
18 | --model-prefix ms1mv3_arcface_r50/backbone.pth \
19 | --image-path IJB_release/IJBC \
20 | --result-dir ms1mv3_arcface_r50 \
21 | --batch-size 128 \
22 | --job ms1mv3_arcface_r50 \
23 | --target IJBC \
24 | --network iresnet50
25 | ```
26 |
27 |
28 | ## Inference
29 |
30 | ```shell
31 | python inference.py --weight ms1mv3_arcface_r50/backbone.pth --network r50
32 | ```
33 |
34 |
35 | ## Result
36 |
37 | | Datasets | Backbone | **MFR-ALL** | IJB-C(1E-4) | IJB-C(1E-5) |
38 | |:---------------|:--------------------|:------------|:------------|:------------|
39 | | WF12M-PFC-0.05 | r100 | 94.05 | 97.51 | 95.75 |
40 | | WF12M-PFC-0.1 | r100 | 94.49 | 97.56 | 95.92 |
41 | | WF12M-PFC-0.2 | r100 | 94.75 | 97.60 | 95.90 |
42 | | WF12M-PFC-0.3 | r100 | 94.71 | 97.64 | 96.01 |
43 | | WF12M | r100 | 94.69 | 97.59 | 95.97 |
--------------------------------------------------------------------------------
/eg3d-pose-detection/models/arcface_torch/docs/install.md:
--------------------------------------------------------------------------------
1 | ## [v1.11.0](https://pytorch.org/)
2 |
3 | ## [v1.9.0](https://pytorch.org/get-started/previous-versions/#linux-and-windows-7)
4 | ### Linux and Windows
5 | ```shell
6 | # CUDA 11.1
7 | pip install torch==1.9.0+cu111 torchvision==0.10.0+cu111 torchaudio==0.9.0 -f https://download.pytorch.org/whl/torch_stable.html
8 |
9 | # CUDA 10.2
10 | pip install torch==1.9.0+cu102 torchvision==0.10.0+cu102 torchaudio==0.9.0 -f https://download.pytorch.org/whl/torch_stable.html
11 | ```
12 |
--------------------------------------------------------------------------------
/eg3d-pose-detection/models/arcface_torch/docs/install_dali.md:
--------------------------------------------------------------------------------
1 | TODO
2 |
--------------------------------------------------------------------------------
/eg3d-pose-detection/models/arcface_torch/docs/modelzoo.md:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bbaaii/HFA-GP/aa2c15a61d8ddd182189153914098a2af0edfb0c/eg3d-pose-detection/models/arcface_torch/docs/modelzoo.md
--------------------------------------------------------------------------------
/eg3d-pose-detection/models/arcface_torch/docs/prepare_webface42m.md:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 | ## 1. Download Datasets and Unzip
5 |
6 | Download WebFace42M from [https://www.face-benchmark.org/download.html](https://www.face-benchmark.org/download.html).
7 | The raw data of `WebFace42M` will have 10 directories after being unarchived:
8 | `WebFace4M` contains 1 directory: `0`.
9 | `WebFace12M` contains 3 directories: `0,1,2`.
10 | `WebFace42M` contains 10 directories: `0,1,2,3,4,5,6,7,8,9`.
11 |
12 | ## 2. Create Shuffled Rec File for DALI
13 |
14 | Note: a shuffled rec file is very important for DALI; an unshuffled rec can cause performance degradation. The original insightface-style rec file
15 | is not supported by Nvidia DALI, so you must use [mxnet.tools.im2rec](https://github.com/apache/incubator-mxnet/blob/master/tools/im2rec.py) to generate a shuffled rec file.
16 |
17 | ```shell
18 | # directories and files for your dataset
19 | /WebFace42M_Root
20 | ├── 0_0_0000000
21 | │ ├── 0_0.jpg
22 | │ ├── 0_1.jpg
23 | │ ├── 0_2.jpg
24 | │ ├── 0_3.jpg
25 | │ └── 0_4.jpg
26 | ├── 0_0_0000001
27 | │ ├── 0_5.jpg
28 | │ ├── 0_6.jpg
29 | │ ├── 0_7.jpg
30 | │ ├── 0_8.jpg
31 | │ └── 0_9.jpg
32 | ├── 0_0_0000002
33 | │ ├── 0_10.jpg
34 | │ ├── 0_11.jpg
35 | │ ├── 0_12.jpg
36 | │ ├── 0_13.jpg
37 | │ ├── 0_14.jpg
38 | │ ├── 0_15.jpg
39 | │ ├── 0_16.jpg
40 | │ └── 0_17.jpg
41 | ├── 0_0_0000003
42 | │ ├── 0_18.jpg
43 | │ ├── 0_19.jpg
44 | │ └── 0_20.jpg
45 | ├── 0_0_0000004
46 |
47 |
48 |
49 | # 1) create train.lst using follow command
50 | python -m mxnet.tools.im2rec --list --recursive train WebFace42M_Root
51 |
52 | # 2) create train.rec and train.idx using train.lst using following command
53 | python -m mxnet.tools.im2rec --num-thread 16 --quality 100 train WebFace42M_Root
54 | ```
55 |
56 | Finally, you will get three files: `train.lst`, `train.rec`, `train.idx`, of which `train.idx` and `train.rec` are used for training.
57 |
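Once `train.rec` and `train.idx` are created, they can be read back with MXNet's indexed RecordIO reader as a quick sanity check; a minimal sketch (the exact label layout follows whatever `im2rec` wrote from `train.lst`):

```python
import mxnet as mx

# Open the rec/idx pair produced by im2rec in read mode.
record = mx.recordio.MXIndexedRecordIO('train.idx', 'train.rec', 'r')

# Read one record by its index and decode it into header (label) + image.
item = record.read_idx(0)
header, img = mx.recordio.unpack_img(item)
print(header.label, img.shape)  # class label and decoded HxWxC image
record.close()
```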
--------------------------------------------------------------------------------
/eg3d-pose-detection/models/arcface_torch/docs/speed_benchmark.md:
--------------------------------------------------------------------------------
1 | ## Test Training Speed
2 |
3 | - Test Commands
4 |
5 | Use the following two commands to test Partial FC training performance.
6 | The number of identities is **3 million** (synthetic data), mixed-precision training is enabled, the backbone is resnet50,
7 | and the batch size is 1024.
8 | ```shell
9 | # Model Parallel
10 | python -m torch.distributed.launch --nproc_per_node=8 --nnodes=1 --node_rank=0 --master_addr="127.0.0.1" --master_port=1234 train.py configs/3millions
11 | # Partial FC 0.1
12 | python -m torch.distributed.launch --nproc_per_node=8 --nnodes=1 --node_rank=0 --master_addr="127.0.0.1" --master_port=1234 train.py configs/3millions_pfc
13 | ```
14 |
15 | - GPU Memory
16 |
17 | ```
18 | # (Model Parallel) gpustat -i
19 | [0] Tesla V100-SXM2-32GB | 64'C, 94 % | 30338 / 32510 MB
20 | [1] Tesla V100-SXM2-32GB | 60'C, 99 % | 28876 / 32510 MB
21 | [2] Tesla V100-SXM2-32GB | 60'C, 99 % | 28872 / 32510 MB
22 | [3] Tesla V100-SXM2-32GB | 69'C, 99 % | 28872 / 32510 MB
23 | [4] Tesla V100-SXM2-32GB | 66'C, 99 % | 28888 / 32510 MB
24 | [5] Tesla V100-SXM2-32GB | 60'C, 99 % | 28932 / 32510 MB
25 | [6] Tesla V100-SXM2-32GB | 68'C, 100 % | 28916 / 32510 MB
26 | [7] Tesla V100-SXM2-32GB | 65'C, 99 % | 28860 / 32510 MB
27 |
28 | # (Partial FC 0.1) gpustat -i
29 | [0] Tesla V100-SXM2-32GB | 60'C,  95 % | 10488 / 32510 MB
30 | [1] Tesla V100-SXM2-32GB | 60'C,  97 % | 10344 / 32510 MB
31 | [2] Tesla V100-SXM2-32GB | 61'C,  95 % | 10340 / 32510 MB
32 | [3] Tesla V100-SXM2-32GB | 66'C,  95 % | 10340 / 32510 MB
33 | [4] Tesla V100-SXM2-32GB | 65'C,  94 % | 10356 / 32510 MB
34 | [5] Tesla V100-SXM2-32GB | 61'C,  95 % | 10400 / 32510 MB
35 | [6] Tesla V100-SXM2-32GB | 68'C,  96 % | 10384 / 32510 MB
36 | [7] Tesla V100-SXM2-32GB | 64'C,  95 % | 10328 / 32510 MB
37 | ```
38 |
39 | - Training Speed
40 |
41 | ```python
42 | # (Model Parallel) training.log
43 | Training: Speed 2271.33 samples/sec Loss 1.1624 LearningRate 0.2000 Epoch: 0 Global Step: 100
44 | Training: Speed 2269.94 samples/sec Loss 0.0000 LearningRate 0.2000 Epoch: 0 Global Step: 150
45 | Training: Speed 2272.67 samples/sec Loss 0.0000 LearningRate 0.2000 Epoch: 0 Global Step: 200
46 | Training: Speed 2266.55 samples/sec Loss 0.0000 LearningRate 0.2000 Epoch: 0 Global Step: 250
47 | Training: Speed 2272.54 samples/sec Loss 0.0000 LearningRate 0.2000 Epoch: 0 Global Step: 300
48 |
49 | # (Partial FC 0.1) training.log
50 | Training: Speed 5299.56 samples/sec Loss 1.0965 LearningRate 0.2000 Epoch: 0 Global Step: 100
51 | Training: Speed 5296.37 samples/sec Loss 0.0000 LearningRate 0.2000 Epoch: 0 Global Step: 150
52 | Training: Speed 5304.37 samples/sec Loss 0.0000 LearningRate 0.2000 Epoch: 0 Global Step: 200
53 | Training: Speed 5274.43 samples/sec Loss 0.0000 LearningRate 0.2000 Epoch: 0 Global Step: 250
54 | Training: Speed 5300.10 samples/sec Loss 0.0000 LearningRate 0.2000 Epoch: 0 Global Step: 300
55 | ```
56 |
57 | In this test case, Partial FC 0.1 uses only about 1/3 of the GPU memory of model parallel,
58 | and trains about 2.5 times faster than model parallel.
59 |
60 |
61 | ## Speed Benchmark
62 |
63 | 1. Training speed of different parallel methods (samples/second), Tesla V100 32GB * 8. (Larger is better)
64 |
65 | | Number of Identities in Dataset | Data Parallel | Model Parallel | Partial FC 0.1 |
66 | | :--- | :--- | :--- | :--- |
67 | |125000 | 4681 | 4824 | 5004 |
68 | |250000 | 4047 | 4521 | 4976 |
69 | |500000 | 3087 | 4013 | 4900 |
70 | |1000000 | 2090 | 3449 | 4803 |
71 | |1400000 | 1672 | 3043 | 4738 |
72 | |2000000 | - | 2593 | 4626 |
73 | |4000000 | - | 1748 | 4208 |
74 | |5500000 | - | 1389 | 3975 |
75 | |8000000 | - | - | 3565 |
76 | |16000000 | - | - | 2679 |
77 | |29000000 | - | - | 1855 |
78 |
79 | 2. GPU memory cost of different parallel methods (MB per GPU), Tesla V100 32GB * 8. (Smaller is better)
80 |
81 | | Number of Identities in Dataset | Data Parallel | Model Parallel | Partial FC 0.1 |
82 | | :--- | :--- | :--- | :--- |
83 | |125000 | 7358 | 5306 | 4868 |
84 | |250000 | 9940 | 5826 | 5004 |
85 | |500000 | 14220 | 7114 | 5202 |
86 | |1000000 | 23708 | 9966 | 5620 |
87 | |1400000 | 32252 | 11178 | 6056 |
88 | |2000000 | - | 13978 | 6472 |
89 | |4000000 | - | 23238 | 8284 |
90 | |5500000 | - | 32188 | 9854 |
91 | |8000000 | - | - | 12310 |
92 | |16000000 | - | - | 19950 |
93 | |29000000 | - | - | 32324 |
94 |
--------------------------------------------------------------------------------
/eg3d-pose-detection/models/arcface_torch/eval/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bbaaii/HFA-GP/aa2c15a61d8ddd182189153914098a2af0edfb0c/eg3d-pose-detection/models/arcface_torch/eval/__init__.py
--------------------------------------------------------------------------------
/eg3d-pose-detection/models/arcface_torch/flops.py:
--------------------------------------------------------------------------------
1 | from ptflops import get_model_complexity_info
2 | from backbones import get_model
3 | import argparse
4 |
5 | if __name__ == '__main__':
6 | parser = argparse.ArgumentParser(description='')
7 |     parser.add_argument('n', type=str, nargs='?', default="r100", help='backbone network name')
8 | args = parser.parse_args()
9 | net = get_model(args.n)
10 | macs, params = get_model_complexity_info(
11 | net, (3, 112, 112), as_strings=False,
12 | print_per_layer_stat=True, verbose=True)
13 | gmacs = macs / (1000**3)
14 | print("%.3f GFLOPs"%gmacs)
15 | print("%.3f Mparams"%(params/(1000**2)))
16 |
17 | if hasattr(net, "extra_gflops"):
18 | print("%.3f Extra-GFLOPs"%net.extra_gflops)
19 | print("%.3f Total-GFLOPs"%(gmacs+net.extra_gflops))
20 |
21 |
--------------------------------------------------------------------------------
/eg3d-pose-detection/models/arcface_torch/inference.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | import cv2
4 | import numpy as np
5 | import torch
6 |
7 | from backbones import get_model
8 |
9 |
10 | @torch.no_grad()
11 | def inference(weight, name, img):
12 | if img is None:
13 | img = np.random.randint(0, 255, size=(112, 112, 3), dtype=np.uint8)
14 | else:
15 | img = cv2.imread(img)
16 | img = cv2.resize(img, (112, 112))
17 |
18 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
19 | img = np.transpose(img, (2, 0, 1))
20 | img = torch.from_numpy(img).unsqueeze(0).float()
21 | img.div_(255).sub_(0.5).div_(0.5)
22 | net = get_model(name, fp16=False)
23 | net.load_state_dict(torch.load(weight))
24 | net.eval()
25 | feat = net(img).numpy()
26 | print(feat)
27 |
28 |
29 | if __name__ == "__main__":
30 | parser = argparse.ArgumentParser(description='PyTorch ArcFace Training')
31 | parser.add_argument('--network', type=str, default='r50', help='backbone network')
32 | parser.add_argument('--weight', type=str, default='')
33 | parser.add_argument('--img', type=str, default=None)
34 | args = parser.parse_args()
35 | inference(args.weight, args.network, args.img)
36 |
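`inference.py` only prints the raw 512-d feature. In practice two faces are compared via the cosine similarity of their embeddings; a minimal sketch building on the same `get_model` backbone (the preprocessing mirrors `inference()` above, and the checkpoint and image paths are placeholders):

```python
import cv2
import numpy as np
import torch

from backbones import get_model

@torch.no_grad()
def embed(net: torch.nn.Module, path: str) -> torch.Tensor:
    img = cv2.cvtColor(cv2.resize(cv2.imread(path), (112, 112)), cv2.COLOR_BGR2RGB)
    img = torch.from_numpy(np.transpose(img, (2, 0, 1))).unsqueeze(0).float()
    img.div_(255).sub_(0.5).div_(0.5)            # same normalization as inference.py
    return torch.nn.functional.normalize(net(img))  # L2-normalize for cosine similarity

net = get_model("r50", fp16=False)
net.load_state_dict(torch.load("ms1mv3_arcface_r50/backbone.pth", map_location="cpu"))  # placeholder path
net.eval()
sim = (embed(net, "face_a.jpg") * embed(net, "face_b.jpg")).sum().item()
print("cosine similarity:", sim)  # higher means more likely the same identity
```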
--------------------------------------------------------------------------------
/eg3d-pose-detection/models/arcface_torch/losses.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import math
3 |
4 |
5 | class CombinedMarginLoss(torch.nn.Module):
6 | def __init__(self,
7 | s,
8 | m1,
9 | m2,
10 | m3,
11 | interclass_filtering_threshold=0):
12 | super().__init__()
13 | self.s = s
14 | self.m1 = m1
15 | self.m2 = m2
16 | self.m3 = m3
17 | self.interclass_filtering_threshold = interclass_filtering_threshold
18 |
19 | # For ArcFace
20 | self.cos_m = math.cos(self.m2)
21 | self.sin_m = math.sin(self.m2)
22 | self.theta = math.cos(math.pi - self.m2)
23 | self.sinmm = math.sin(math.pi - self.m2) * self.m2
24 | self.easy_margin = False
25 |
26 |
27 | def forward(self, logits, labels):
28 | index_positive = torch.where(labels != -1)[0]
29 |
30 | if self.interclass_filtering_threshold > 0:
31 | with torch.no_grad():
32 | dirty = logits > self.interclass_filtering_threshold
33 | dirty = dirty.float()
34 | mask = torch.ones([index_positive.size(0), logits.size(1)], device=logits.device)
35 | mask.scatter_(1, labels[index_positive], 0)
36 | dirty[index_positive] *= mask
37 | tensor_mul = 1 - dirty
38 | logits = tensor_mul * logits
39 |
40 | target_logit = logits[index_positive, labels[index_positive].view(-1)]
41 |
42 | if self.m1 == 1.0 and self.m3 == 0.0:
43 | sin_theta = torch.sqrt(1.0 - torch.pow(target_logit, 2))
44 | cos_theta_m = target_logit * self.cos_m - sin_theta * self.sin_m # cos(target+margin)
45 | if self.easy_margin:
46 | final_target_logit = torch.where(
47 | target_logit > 0, cos_theta_m, target_logit)
48 | else:
49 | final_target_logit = torch.where(
50 | target_logit > self.theta, cos_theta_m, target_logit - self.sinmm)
51 | logits[index_positive, labels[index_positive].view(-1)] = final_target_logit
52 | logits = logits * self.s
53 |
54 | elif self.m3 > 0:
55 | final_target_logit = target_logit - self.m3
56 | logits[index_positive, labels[index_positive].view(-1)] = final_target_logit
57 | logits = logits * self.s
58 | else:
59 |             raise ValueError(f"unsupported margin combination: m1={self.m1}, m3={self.m3}")
60 |
61 | return logits
62 |
63 | class ArcFace(torch.nn.Module):
64 | """ ArcFace (https://arxiv.org/pdf/1801.07698v1.pdf):
65 | """
66 | def __init__(self, s=64.0, margin=0.5):
67 | super(ArcFace, self).__init__()
68 | self.scale = s
69 | self.cos_m = math.cos(margin)
70 | self.sin_m = math.sin(margin)
71 | self.theta = math.cos(math.pi - margin)
72 | self.sinmm = math.sin(math.pi - margin) * margin
73 | self.easy_margin = False
74 |
75 |
76 | def forward(self, logits: torch.Tensor, labels: torch.Tensor):
77 | index = torch.where(labels != -1)[0]
78 | target_logit = logits[index, labels[index].view(-1)]
79 |
80 | sin_theta = torch.sqrt(1.0 - torch.pow(target_logit, 2))
81 | cos_theta_m = target_logit * self.cos_m - sin_theta * self.sin_m # cos(target+margin)
82 | if self.easy_margin:
83 | final_target_logit = torch.where(
84 | target_logit > 0, cos_theta_m, target_logit)
85 | else:
86 | final_target_logit = torch.where(
87 | target_logit > self.theta, cos_theta_m, target_logit - self.sinmm)
88 |
89 | logits[index, labels[index].view(-1)] = final_target_logit
90 | logits = logits * self.scale
91 | return logits
92 |
93 |
94 | class CosFace(torch.nn.Module):
95 | def __init__(self, s=64.0, m=0.40):
96 | super(CosFace, self).__init__()
97 | self.s = s
98 | self.m = m
99 |
100 | def forward(self, logits: torch.Tensor, labels: torch.Tensor):
101 | index = torch.where(labels != -1)[0]
102 | target_logit = logits[index, labels[index].view(-1)]
103 | final_target_logit = target_logit - self.m
104 | logits[index, labels[index].view(-1)] = final_target_logit
105 | logits = logits * self.s
106 | return logits
107 |
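Both `CombinedMarginLoss` (with `m1=1.0, m3=0.0`) and `ArcFace` apply the additive angular margin by expanding cos(θ + m) = cos θ · cos m − sin θ · sin m, which is exactly what the `cos_theta_m` line computes from the target logit cos θ. A quick numerical check of that identity:

```python
import math
import torch

margin = 0.5
theta = torch.tensor([0.3, 1.0, 2.0])            # a few angles in radians
target_logit = torch.cos(theta)                  # logits are cosines of angles

# Expansion used in losses.py
sin_theta = torch.sqrt(1.0 - target_logit.pow(2))
cos_theta_m = target_logit * math.cos(margin) - sin_theta * math.sin(margin)

# Direct computation of cos(theta + margin) agrees with the expansion
assert torch.allclose(cos_theta_m, torch.cos(theta + margin), atol=1e-5)
print(cos_theta_m)
```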
--------------------------------------------------------------------------------
/eg3d-pose-detection/models/arcface_torch/lr_scheduler.py:
--------------------------------------------------------------------------------
1 | from torch.optim.lr_scheduler import _LRScheduler
2 |
3 |
4 | class PolyScheduler(_LRScheduler):
5 | def __init__(self, optimizer, base_lr, max_steps, warmup_steps, last_epoch=-1):
6 | self.base_lr = base_lr
7 | self.warmup_lr_init = 0.0001
8 | self.max_steps: int = max_steps
9 | self.warmup_steps: int = warmup_steps
10 | self.power = 2
11 | super(PolyScheduler, self).__init__(optimizer, -1, False)
12 | self.last_epoch = last_epoch
13 |
14 | def get_warmup_lr(self):
15 | alpha = float(self.last_epoch) / float(self.warmup_steps)
16 | return [self.base_lr * alpha for _ in self.optimizer.param_groups]
17 |
18 | def get_lr(self):
19 | if self.last_epoch == -1:
20 | return [self.warmup_lr_init for _ in self.optimizer.param_groups]
21 | if self.last_epoch < self.warmup_steps:
22 | return self.get_warmup_lr()
23 | else:
24 | alpha = pow(
25 | 1
26 | - float(self.last_epoch - self.warmup_steps)
27 | / float(self.max_steps - self.warmup_steps),
28 | self.power,
29 | )
30 | return [self.base_lr * alpha for _ in self.optimizer.param_groups]
31 |
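`PolyScheduler` warms up linearly from `warmup_lr_init` to `base_lr` over `warmup_steps`, then decays polynomially (power 2) toward zero at `max_steps`. A small sketch that drives it with a dummy optimizer and prints the resulting schedule:

```python
import torch
from lr_scheduler import PolyScheduler

# Dummy parameter/optimizer just to drive the scheduler.
param = torch.nn.Parameter(torch.zeros(1))
opt = torch.optim.SGD([param], lr=0.1)
scheduler = PolyScheduler(opt, base_lr=0.1, max_steps=1000, warmup_steps=100)

for step in range(1000):
    opt.step()
    scheduler.step()
    if step % 100 == 0:
        print(step, scheduler.get_last_lr()[0])  # rises during warmup, then decays ~ (1 - t)^2
```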
--------------------------------------------------------------------------------
/eg3d-pose-detection/models/arcface_torch/requirement.txt:
--------------------------------------------------------------------------------
1 | tensorboard
2 | easydict
3 | mxnet
4 | onnx
5 | scikit-learn
6 |
--------------------------------------------------------------------------------
/eg3d-pose-detection/models/arcface_torch/run.sh:
--------------------------------------------------------------------------------
1 |
2 | CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python -m torch.distributed.launch \
3 | --nproc_per_node=8 \
4 | --nnodes=1 \
5 | --node_rank=0 \
6 | --master_addr="127.0.0.1" \
7 | --master_port=12345 train.py $@
8 |
9 | ps -ef | grep "train" | grep -v grep | awk '{print "kill -9 "$2}' | sh
10 |
--------------------------------------------------------------------------------
/eg3d-pose-detection/models/arcface_torch/torch2onnx.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import onnx
3 | import torch
4 |
5 |
6 | def convert_onnx(net, path_module, output, opset=11, simplify=False):
7 | assert isinstance(net, torch.nn.Module)
8 | img = np.random.randint(0, 255, size=(112, 112, 3), dtype=np.int32)
9 |     img = img.astype(np.float32)  # np.float was removed in NumPy >= 1.24
10 | img = (img / 255. - 0.5) / 0.5 # torch style norm
11 | img = img.transpose((2, 0, 1))
12 | img = torch.from_numpy(img).unsqueeze(0).float()
13 |
14 | weight = torch.load(path_module)
15 | net.load_state_dict(weight, strict=True)
16 | net.eval()
17 | torch.onnx.export(net, img, output, input_names=["data"], keep_initializers_as_inputs=False, verbose=False, opset_version=opset)
18 | model = onnx.load(output)
19 | graph = model.graph
20 | graph.input[0].type.tensor_type.shape.dim[0].dim_param = 'None'
21 | if simplify:
22 | from onnxsim import simplify
23 | model, check = simplify(model)
24 | assert check, "Simplified ONNX model could not be validated"
25 | onnx.save(model, output)
26 |
27 |
28 | if __name__ == '__main__':
29 | import os
30 | import argparse
31 | from backbones import get_model
32 |
33 | parser = argparse.ArgumentParser(description='ArcFace PyTorch to onnx')
34 | parser.add_argument('input', type=str, help='input backbone.pth file or path')
35 | parser.add_argument('--output', type=str, default=None, help='output onnx path')
36 | parser.add_argument('--network', type=str, default=None, help='backbone network')
37 | parser.add_argument('--simplify', type=bool, default=False, help='onnx simplify')
38 | args = parser.parse_args()
39 | input_file = args.input
40 | if os.path.isdir(input_file):
41 | input_file = os.path.join(input_file, "model.pt")
42 | assert os.path.exists(input_file)
43 | # model_name = os.path.basename(os.path.dirname(input_file)).lower()
44 | # params = model_name.split("_")
45 | # if len(params) >= 3 and params[1] in ('arcface', 'cosface'):
46 | # if args.network is None:
47 | # args.network = params[2]
48 | assert args.network is not None
49 | print(args)
50 | backbone_onnx = get_model(args.network, dropout=0.0, fp16=False, num_features=512)
51 | if args.output is None:
52 | args.output = os.path.join(os.path.dirname(args.input), "model.onnx")
53 | convert_onnx(backbone_onnx, input_file, args.output, simplify=args.simplify)
54 |
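After conversion, the exported `model.onnx` takes an input tensor named `"data"` (set in `convert_onnx` above) with a dynamic batch dimension. A minimal sketch of running it with ONNX Runtime (the model path is a placeholder, and `onnxruntime` is assumed to be installed):

```python
import numpy as np
import onnxruntime as ort

session = ort.InferenceSession("ms1mv3_arcface_r50/model.onnx", providers=["CPUExecutionProvider"])

# Same normalization as convert_onnx: (x / 255 - 0.5) / 0.5, NCHW float32.
img = np.random.randint(0, 255, size=(1, 3, 112, 112)).astype(np.float32)
img = (img / 255.0 - 0.5) / 0.5

feat = session.run(None, {"data": img})[0]
print(feat.shape)  # (1, 512) embedding
```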
--------------------------------------------------------------------------------
/eg3d-pose-detection/models/arcface_torch/train.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import logging
3 | import os
4 |
5 | import numpy as np
6 | import torch
7 | from torch import distributed
8 | from torch.utils.data import DataLoader
9 | from torch.utils.tensorboard import SummaryWriter
10 |
11 | from backbones import get_model
12 | from dataset import get_dataloader
13 | from losses import CombinedMarginLoss
14 | from lr_scheduler import PolyScheduler
15 | from partial_fc import PartialFC, PartialFCAdamW
16 | from utils.utils_callbacks import CallBackLogging, CallBackVerification
17 | from utils.utils_config import get_config
18 | from utils.utils_logging import AverageMeter, init_logging
19 | from utils.utils_distributed_sampler import setup_seed
20 |
21 | assert tuple(map(int, torch.__version__.split("+")[0].split(".")[:2])) >= (1, 9), "In order to enjoy the features of the new torch, \
22 |     we have upgraded the torch to 1.9.0. torch versions earlier than 1.9.0 may not work in the future."
23 |
24 | try:
25 | world_size = int(os.environ["WORLD_SIZE"])
26 | rank = int(os.environ["RANK"])
27 | distributed.init_process_group("nccl")
28 | except KeyError:
29 | world_size = 1
30 | rank = 0
31 | distributed.init_process_group(
32 | backend="nccl",
33 | init_method="tcp://127.0.0.1:12584",
34 | rank=rank,
35 | world_size=world_size,
36 | )
37 |
38 |
39 | def main(args):
40 |
41 | # get config
42 | cfg = get_config(args.config)
43 | # global control random seed
44 | setup_seed(seed=cfg.seed, cuda_deterministic=False)
45 |
46 | torch.cuda.set_device(args.local_rank)
47 |
48 | os.makedirs(cfg.output, exist_ok=True)
49 | init_logging(rank, cfg.output)
50 |
51 | summary_writer = (
52 | SummaryWriter(log_dir=os.path.join(cfg.output, "tensorboard"))
53 | if rank == 0
54 | else None
55 | )
56 |
57 | train_loader = get_dataloader(
58 | cfg.rec,
59 | args.local_rank,
60 | cfg.batch_size,
61 | cfg.dali,
62 | cfg.seed,
63 | cfg.num_workers
64 | )
65 |
66 | backbone = get_model(
67 | cfg.network, dropout=0.0, fp16=cfg.fp16, num_features=cfg.embedding_size).cuda()
68 |
69 | backbone = torch.nn.parallel.DistributedDataParallel(
70 | module=backbone, broadcast_buffers=False, device_ids=[args.local_rank], bucket_cap_mb=16,
71 | find_unused_parameters=True)
72 |
73 | backbone.train()
74 | # FIXME using gradient checkpoint if there are some unused parameters will cause error
75 | backbone._set_static_graph()
76 |
77 | margin_loss = CombinedMarginLoss(
78 | 64,
79 | cfg.margin_list[0],
80 | cfg.margin_list[1],
81 | cfg.margin_list[2],
82 | cfg.interclass_filtering_threshold
83 | )
84 |
85 | if cfg.optimizer == "sgd":
86 | module_partial_fc = PartialFC(
87 | margin_loss, cfg.embedding_size, cfg.num_classes,
88 | cfg.sample_rate, cfg.fp16)
89 | module_partial_fc.train().cuda()
90 | # TODO the params of partial fc must be last in the params list
91 | opt = torch.optim.SGD(
92 | params=[{"params": backbone.parameters()}, {"params": module_partial_fc.parameters()}],
93 | lr=cfg.lr, momentum=0.9, weight_decay=cfg.weight_decay)
94 |
95 | elif cfg.optimizer == "adamw":
96 | module_partial_fc = PartialFCAdamW(
97 | margin_loss, cfg.embedding_size, cfg.num_classes,
98 | cfg.sample_rate, cfg.fp16)
99 | module_partial_fc.train().cuda()
100 | opt = torch.optim.AdamW(
101 | params=[{"params": backbone.parameters()}, {"params": module_partial_fc.parameters()}],
102 | lr=cfg.lr, weight_decay=cfg.weight_decay)
103 | else:
104 |         raise ValueError(f"unsupported optimizer: {cfg.optimizer}")
105 |
106 | cfg.total_batch_size = cfg.batch_size * world_size
107 | cfg.warmup_step = cfg.num_image // cfg.total_batch_size * cfg.warmup_epoch
108 | cfg.total_step = cfg.num_image // cfg.total_batch_size * cfg.num_epoch
109 |
110 | lr_scheduler = PolyScheduler(
111 | optimizer=opt,
112 | base_lr=cfg.lr,
113 | max_steps=cfg.total_step,
114 | warmup_steps=cfg.warmup_step,
115 | last_epoch=-1
116 | )
117 |
118 | start_epoch = 0
119 | global_step = 0
120 | if cfg.resume:
121 | dict_checkpoint = torch.load(os.path.join(cfg.output, f"checkpoint_gpu_{rank}.pt"))
122 | start_epoch = dict_checkpoint["epoch"]
123 | global_step = dict_checkpoint["global_step"]
124 | backbone.module.load_state_dict(dict_checkpoint["state_dict_backbone"])
125 | module_partial_fc.load_state_dict(dict_checkpoint["state_dict_softmax_fc"])
126 | opt.load_state_dict(dict_checkpoint["state_optimizer"])
127 | lr_scheduler.load_state_dict(dict_checkpoint["state_lr_scheduler"])
128 | del dict_checkpoint
129 |
130 | for key, value in cfg.items():
131 | num_space = 25 - len(key)
132 | logging.info(": " + key + " " * num_space + str(value))
133 |
134 | callback_verification = CallBackVerification(
135 | val_targets=cfg.val_targets, rec_prefix=cfg.rec, summary_writer=summary_writer
136 | )
137 | callback_logging = CallBackLogging(
138 | frequent=cfg.frequent,
139 | total_step=cfg.total_step,
140 | batch_size=cfg.batch_size,
141 | start_step = global_step,
142 | writer=summary_writer
143 | )
144 |
145 | loss_am = AverageMeter()
146 | amp = torch.cuda.amp.grad_scaler.GradScaler(growth_interval=100)
147 |
148 | for epoch in range(start_epoch, cfg.num_epoch):
149 |
150 | if isinstance(train_loader, DataLoader):
151 | train_loader.sampler.set_epoch(epoch)
152 | for _, (img, local_labels) in enumerate(train_loader):
153 | global_step += 1
154 | local_embeddings = backbone(img)
155 | loss: torch.Tensor = module_partial_fc(local_embeddings, local_labels, opt)
156 |
157 | if cfg.fp16:
158 | amp.scale(loss).backward()
159 | amp.unscale_(opt)
160 | torch.nn.utils.clip_grad_norm_(backbone.parameters(), 5)
161 | amp.step(opt)
162 | amp.update()
163 | else:
164 | loss.backward()
165 | torch.nn.utils.clip_grad_norm_(backbone.parameters(), 5)
166 | opt.step()
167 |
168 | opt.zero_grad()
169 | lr_scheduler.step()
170 |
171 | with torch.no_grad():
172 | loss_am.update(loss.item(), 1)
173 | callback_logging(global_step, loss_am, epoch, cfg.fp16, lr_scheduler.get_last_lr()[0], amp)
174 |
175 | if global_step % cfg.verbose == 0 and global_step > 0:
176 | callback_verification(global_step, backbone)
177 |
178 | if cfg.save_all_states:
179 | checkpoint = {
180 | "epoch": epoch + 1,
181 | "global_step": global_step,
182 | "state_dict_backbone": backbone.module.state_dict(),
183 | "state_dict_softmax_fc": module_partial_fc.state_dict(),
184 | "state_optimizer": opt.state_dict(),
185 | "state_lr_scheduler": lr_scheduler.state_dict()
186 | }
187 | torch.save(checkpoint, os.path.join(cfg.output, f"checkpoint_gpu_{rank}.pt"))
188 |
189 | if rank == 0:
190 | path_module = os.path.join(cfg.output, "model.pt")
191 | torch.save(backbone.module.state_dict(), path_module)
192 |
193 | if cfg.dali:
194 | train_loader.reset()
195 |
196 | if rank == 0:
197 | path_module = os.path.join(cfg.output, "model.pt")
198 | torch.save(backbone.module.state_dict(), path_module)
199 |
200 | from torch2onnx import convert_onnx
201 | convert_onnx(backbone.module.cpu().eval(), path_module, os.path.join(cfg.output, "model.onnx"))
202 |
203 | distributed.destroy_process_group()
204 |
205 |
206 | if __name__ == "__main__":
207 | torch.backends.cudnn.benchmark = True
208 | parser = argparse.ArgumentParser(
209 | description="Distributed Arcface Training in Pytorch")
210 | parser.add_argument("config", type=str, help="py config file")
211 | parser.add_argument("--local_rank", type=int, default=0, help="local_rank")
212 | main(parser.parse_args())
213 |
--------------------------------------------------------------------------------
/eg3d-pose-detection/models/arcface_torch/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bbaaii/HFA-GP/aa2c15a61d8ddd182189153914098a2af0edfb0c/eg3d-pose-detection/models/arcface_torch/utils/__init__.py
--------------------------------------------------------------------------------
/eg3d-pose-detection/models/arcface_torch/utils/plot.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 |
4 | import matplotlib.pyplot as plt
5 | import numpy as np
6 | import pandas as pd
7 | from menpo.visualize.viewmatplotlib import sample_colours_from_colourmap
8 | from prettytable import PrettyTable
9 | from sklearn.metrics import roc_curve, auc
10 |
11 | with open(sys.argv[1], "r") as f:
12 | files = f.readlines()
13 |
14 | files = [x.strip() for x in files]
15 | image_path = "/train_tmp/IJB_release/IJBC"
16 |
17 |
18 | def read_template_pair_list(path):
19 | pairs = pd.read_csv(path, sep=' ', header=None).values
20 |     t1 = pairs[:, 0].astype(int)  # np.int was removed in NumPy >= 1.24
21 |     t2 = pairs[:, 1].astype(int)
22 |     label = pairs[:, 2].astype(int)
23 | return t1, t2, label
24 |
25 |
26 | p1, p2, label = read_template_pair_list(
27 | os.path.join('%s/meta' % image_path,
28 | '%s_template_pair_label.txt' % 'ijbc'))
29 |
30 | methods = []
31 | scores = []
32 | for file in files:
33 | methods.append(file)
34 | scores.append(np.load(file))
35 |
36 | methods = np.array(methods)
37 | scores = dict(zip(methods, scores))
38 | colours = dict(
39 | zip(methods, sample_colours_from_colourmap(methods.shape[0], 'Set2')))
40 | x_labels = [10 ** -6, 10 ** -5, 10 ** -4, 10 ** -3, 10 ** -2, 10 ** -1]
41 | tpr_fpr_table = PrettyTable(['Methods'] + [str(x) for x in x_labels])
42 | fig = plt.figure()
43 | for method in methods:
44 | fpr, tpr, _ = roc_curve(label, scores[method])
45 | roc_auc = auc(fpr, tpr)
46 | fpr = np.flipud(fpr)
47 | tpr = np.flipud(tpr) # select largest tpr at same fpr
48 | plt.plot(fpr,
49 | tpr,
50 | color=colours[method],
51 | lw=1,
52 | label=('[%s (AUC = %0.4f %%)]' %
53 | (method.split('-')[-1], roc_auc * 100)))
54 | tpr_fpr_row = []
55 | tpr_fpr_row.append(method)
56 | for fpr_iter in np.arange(len(x_labels)):
57 | _, min_index = min(
58 | list(zip(abs(fpr - x_labels[fpr_iter]), range(len(fpr)))))
59 | tpr_fpr_row.append('%.2f' % (tpr[min_index] * 100))
60 | tpr_fpr_table.add_row(tpr_fpr_row)
61 | plt.xlim([10 ** -6, 0.1])
62 | plt.ylim([0.3, 1.0])
63 | plt.grid(linestyle='--', linewidth=1)
64 | plt.xticks(x_labels)
65 | plt.yticks(np.linspace(0.3, 1.0, 8, endpoint=True))
66 | plt.xscale('log')
67 | plt.xlabel('False Positive Rate')
68 | plt.ylabel('True Positive Rate')
69 | plt.title('ROC on IJB')
70 | plt.legend(loc="lower right")
71 | print(tpr_fpr_table)
72 |
--------------------------------------------------------------------------------
/eg3d-pose-detection/models/arcface_torch/utils/utils_callbacks.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 | import time
4 | from typing import List
5 |
6 | import torch
7 |
8 | from eval import verification
9 | from utils.utils_logging import AverageMeter
10 | from torch.utils.tensorboard import SummaryWriter
11 | from torch import distributed
12 |
13 |
14 | class CallBackVerification(object):
15 |
16 | def __init__(self, val_targets, rec_prefix, summary_writer=None, image_size=(112, 112)):
17 | self.rank: int = distributed.get_rank()
18 | self.highest_acc: float = 0.0
19 | self.highest_acc_list: List[float] = [0.0] * len(val_targets)
20 | self.ver_list: List[object] = []
21 | self.ver_name_list: List[str] = []
22 |         if self.rank == 0:
23 | self.init_dataset(val_targets=val_targets, data_dir=rec_prefix, image_size=image_size)
24 |
25 | self.summary_writer = summary_writer
26 |
27 | def ver_test(self, backbone: torch.nn.Module, global_step: int):
28 | results = []
29 | for i in range(len(self.ver_list)):
30 | acc1, std1, acc2, std2, xnorm, embeddings_list = verification.test(
31 | self.ver_list[i], backbone, 10, 10)
32 | logging.info('[%s][%d]XNorm: %f' % (self.ver_name_list[i], global_step, xnorm))
33 | logging.info('[%s][%d]Accuracy-Flip: %1.5f+-%1.5f' % (self.ver_name_list[i], global_step, acc2, std2))
34 |
35 | self.summary_writer: SummaryWriter
36 | self.summary_writer.add_scalar(tag=self.ver_name_list[i], scalar_value=acc2, global_step=global_step, )
37 |
38 | if acc2 > self.highest_acc_list[i]:
39 | self.highest_acc_list[i] = acc2
40 | logging.info(
41 | '[%s][%d]Accuracy-Highest: %1.5f' % (self.ver_name_list[i], global_step, self.highest_acc_list[i]))
42 | results.append(acc2)
43 |
44 | def init_dataset(self, val_targets, data_dir, image_size):
45 | for name in val_targets:
46 | path = os.path.join(data_dir, name + ".bin")
47 | if os.path.exists(path):
48 | data_set = verification.load_bin(path, image_size)
49 | self.ver_list.append(data_set)
50 | self.ver_name_list.append(name)
51 |
52 | def __call__(self, num_update, backbone: torch.nn.Module):
53 |         if self.rank == 0 and num_update > 0:
54 | backbone.eval()
55 | self.ver_test(backbone, num_update)
56 | backbone.train()
57 |
58 |
59 | class CallBackLogging(object):
60 |     def __init__(self, frequent, total_step, batch_size, start_step=0, writer=None):
61 | self.frequent: int = frequent
62 | self.rank: int = distributed.get_rank()
63 | self.world_size: int = distributed.get_world_size()
64 | self.time_start = time.time()
65 | self.total_step: int = total_step
66 | self.start_step: int = start_step
67 | self.batch_size: int = batch_size
68 | self.writer = writer
69 |
70 | self.init = False
71 | self.tic = 0
72 |
73 | def __call__(self,
74 | global_step: int,
75 | loss: AverageMeter,
76 | epoch: int,
77 | fp16: bool,
78 | learning_rate: float,
79 | grad_scaler: torch.cuda.amp.GradScaler):
80 | if self.rank == 0 and global_step > 0 and global_step % self.frequent == 0:
81 | if self.init:
82 | try:
83 | speed: float = self.frequent * self.batch_size / (time.time() - self.tic)
84 | speed_total = speed * self.world_size
85 | except ZeroDivisionError:
86 | speed_total = float('inf')
87 |
88 | #time_now = (time.time() - self.time_start) / 3600
89 | #time_total = time_now / ((global_step + 1) / self.total_step)
90 | #time_for_end = time_total - time_now
91 | time_now = time.time()
92 | time_sec = int(time_now - self.time_start)
93 | time_sec_avg = time_sec / (global_step - self.start_step + 1)
94 | eta_sec = time_sec_avg * (self.total_step - global_step - 1)
95 | time_for_end = eta_sec/3600
96 | if self.writer is not None:
97 | self.writer.add_scalar('time_for_end', time_for_end, global_step)
98 | self.writer.add_scalar('learning_rate', learning_rate, global_step)
99 | self.writer.add_scalar('loss', loss.avg, global_step)
100 | if fp16:
101 | msg = "Speed %.2f samples/sec Loss %.4f LearningRate %.6f Epoch: %d Global Step: %d " \
102 | "Fp16 Grad Scale: %2.f Required: %1.f hours" % (
103 | speed_total, loss.avg, learning_rate, epoch, global_step,
104 | grad_scaler.get_scale(), time_for_end
105 | )
106 | else:
107 | msg = "Speed %.2f samples/sec Loss %.4f LearningRate %.6f Epoch: %d Global Step: %d " \
108 | "Required: %1.f hours" % (
109 | speed_total, loss.avg, learning_rate, epoch, global_step, time_for_end
110 | )
111 | logging.info(msg)
112 | loss.reset()
113 | self.tic = time.time()
114 | else:
115 | self.init = True
116 | self.tic = time.time()
117 |
--------------------------------------------------------------------------------
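For context, a hedged sketch of how these callbacks are usually driven from an ArcFace-style training loop; `cfg`, `backbone`, `train_loader`, `writer`, `lr_scheduler`, `amp` and `train_step` are assumed stand-ins rather than objects defined in this repository.

    from utils.utils_callbacks import CallBackLogging, CallBackVerification
    from utils.utils_logging import AverageMeter

    loss_am = AverageMeter()
    cb_verify = CallBackVerification(cfg.val_targets, cfg.rec, summary_writer=writer)
    cb_log = CallBackLogging(cfg.frequent, cfg.total_step, cfg.batch_size, writer=writer)

    global_step = 0
    for epoch in range(cfg.num_epoch):
        for img, label in train_loader:
            loss = train_step(backbone, img, label)         # assumed helper
            loss_am.update(loss.item(), 1)
            cb_log(global_step, loss_am, epoch, cfg.fp16,
                   lr_scheduler.get_last_lr()[0], amp)       # logs every cfg.frequent steps on rank 0
            cb_verify(global_step, backbone)                 # runs the .bin verification sets on rank 0
            global_step += 1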
/eg3d-pose-detection/models/arcface_torch/utils/utils_config.py:
--------------------------------------------------------------------------------
1 | import importlib
2 | import os.path as osp
3 |
4 |
5 | def get_config(config_file):
6 | assert config_file.startswith('configs/'), 'config file setting must start with configs/'
7 | temp_config_name = osp.basename(config_file)
8 | temp_module_name = osp.splitext(temp_config_name)[0]
9 | config = importlib.import_module("configs.base")
10 | cfg = config.config
11 | config = importlib.import_module("configs.%s" % temp_module_name)
12 | job_cfg = config.config
13 | cfg.update(job_cfg)
14 | if cfg.output is None:
15 | cfg.output = osp.join('work_dirs', temp_module_name)
16 | return cfg
--------------------------------------------------------------------------------
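A brief hedged usage example: the path must start with `configs/` (enforced by the assert), and the call is expected to be made from the `arcface_torch` directory so that `configs.base` and the job module are importable; `ms1mv3_r50` is just one of the configs listed above.

    from utils.utils_config import get_config

    cfg = get_config("configs/ms1mv3_r50.py")  # merges configs/base.py with the job config
    print(cfg.output)                          # falls back to work_dirs/ms1mv3_r50 if unset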
/eg3d-pose-detection/models/arcface_torch/utils/utils_distributed_sampler.py:
--------------------------------------------------------------------------------
1 | import math
2 | import os
3 | import random
4 |
5 | import numpy as np
6 | import torch
7 | import torch.distributed as dist
8 | from torch.utils.data import DistributedSampler as _DistributedSampler
9 |
10 |
11 | def setup_seed(seed, cuda_deterministic=True):
12 | torch.manual_seed(seed)
13 | torch.cuda.manual_seed_all(seed)
14 | np.random.seed(seed)
15 | random.seed(seed)
16 | os.environ["PYTHONHASHSEED"] = str(seed)
17 | if cuda_deterministic: # slower, more reproducible
18 | torch.backends.cudnn.deterministic = True
19 | torch.backends.cudnn.benchmark = False
20 | else: # faster, less reproducible
21 | torch.backends.cudnn.deterministic = False
22 | torch.backends.cudnn.benchmark = True
23 |
24 |
25 | def worker_init_fn(worker_id, num_workers, rank, seed):
26 | # The seed of each worker equals to
27 | # num_worker * rank + worker_id + user_seed
28 | worker_seed = num_workers * rank + worker_id + seed
29 | np.random.seed(worker_seed)
30 | random.seed(worker_seed)
31 | torch.manual_seed(worker_seed)
32 |
33 |
34 | def get_dist_info():
35 | if dist.is_available() and dist.is_initialized():
36 | rank = dist.get_rank()
37 | world_size = dist.get_world_size()
38 | else:
39 | rank = 0
40 | world_size = 1
41 |
42 | return rank, world_size
43 |
44 |
45 | def sync_random_seed(seed=None, device="cuda"):
46 | """Make sure different ranks share the same seed.
47 | All workers must call this function, otherwise it will deadlock.
48 | This method is generally used in `DistributedSampler`,
49 | because the seed should be identical across all processes
50 | in the distributed group.
51 | In distributed sampling, different ranks should sample non-overlapped
52 | data in the dataset. Therefore, this function is used to make sure that
53 | each rank shuffles the data indices in the same order based
54 | on the same seed. Then different ranks could use different indices
55 | to select non-overlapped data from the same data list.
56 | Args:
57 | seed (int, Optional): The seed. Default to None.
58 | device (str): The device where the seed will be put on.
59 | Default to 'cuda'.
60 | Returns:
61 | int: Seed to be used.
62 | """
63 | if seed is None:
64 | seed = np.random.randint(2**31)
65 | assert isinstance(seed, int)
66 |
67 | rank, world_size = get_dist_info()
68 |
69 | if world_size == 1:
70 | return seed
71 |
72 | if rank == 0:
73 | random_num = torch.tensor(seed, dtype=torch.int32, device=device)
74 | else:
75 | random_num = torch.tensor(0, dtype=torch.int32, device=device)
76 |
77 | dist.broadcast(random_num, src=0)
78 |
79 | return random_num.item()
80 |
81 |
82 | class DistributedSampler(_DistributedSampler):
83 | def __init__(
84 | self,
85 | dataset,
86 | num_replicas=None, # world_size
87 | rank=None, # local_rank
88 | shuffle=True,
89 | seed=0,
90 | ):
91 |
92 | super().__init__(dataset, num_replicas=num_replicas, rank=rank, shuffle=shuffle)
93 |
94 | # In distributed sampling, different ranks should sample
95 | # non-overlapped data in the dataset. Therefore, this function
96 | # is used to make sure that each rank shuffles the data indices
97 | # in the same order based on the same seed. Then different ranks
98 | # could use different indices to select non-overlapped data from the
99 | # same data list.
100 | self.seed = sync_random_seed(seed)
101 |
102 | def __iter__(self):
103 | # deterministically shuffle based on epoch
104 | if self.shuffle:
105 | g = torch.Generator()
106 | # When :attr:`shuffle=True`, this ensures all replicas
107 | # use a different random ordering for each epoch.
108 | # Otherwise, the next iteration of this sampler will
109 | # yield the same ordering.
110 | g.manual_seed(self.epoch + self.seed)
111 | indices = torch.randperm(len(self.dataset), generator=g).tolist()
112 | else:
113 | indices = torch.arange(len(self.dataset)).tolist()
114 |
115 | # add extra samples to make it evenly divisible
116 | # in case that indices is shorter than half of total_size
117 | indices = (indices * math.ceil(self.total_size / len(indices)))[
118 | : self.total_size
119 | ]
120 | assert len(indices) == self.total_size
121 |
122 | # subsample
123 | indices = indices[self.rank : self.total_size : self.num_replicas]
124 | assert len(indices) == self.num_samples
125 |
126 | return iter(indices)
127 |
--------------------------------------------------------------------------------
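A hedged sketch of wiring this sampler into a DataLoader; `train_set`, `rank` and `world_size` are assumed to come from the caller, and the seed value is arbitrary.

    from functools import partial
    from torch.utils.data import DataLoader
    from utils.utils_distributed_sampler import DistributedSampler, setup_seed, worker_init_fn

    setup_seed(seed=2048, cuda_deterministic=False)
    sampler = DistributedSampler(train_set, num_replicas=world_size, rank=rank, shuffle=True, seed=2048)
    loader = DataLoader(
        train_set,
        batch_size=128,
        sampler=sampler,
        num_workers=4,
        worker_init_fn=partial(worker_init_fn, num_workers=4, rank=rank, seed=2048),
    )
    # Call sampler.set_epoch(epoch) at the start of every epoch so the shuffling
    # order (epoch + seed) changes between epochs.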
/eg3d-pose-detection/models/arcface_torch/utils/utils_logging.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 | import sys
4 |
5 |
6 | class AverageMeter(object):
7 | """Computes and stores the average and current value
8 | """
9 |
10 | def __init__(self):
11 | self.val = None
12 | self.avg = None
13 | self.sum = None
14 | self.count = None
15 | self.reset()
16 |
17 | def reset(self):
18 | self.val = 0
19 | self.avg = 0
20 | self.sum = 0
21 | self.count = 0
22 |
23 | def update(self, val, n=1):
24 | self.val = val
25 | self.sum += val * n
26 | self.count += n
27 | self.avg = self.sum / self.count
28 |
29 |
30 | def init_logging(rank, models_root):
31 | if rank == 0:
32 | log_root = logging.getLogger()
33 | log_root.setLevel(logging.INFO)
34 | formatter = logging.Formatter("Training: %(asctime)s-%(message)s")
35 | handler_file = logging.FileHandler(os.path.join(models_root, "training.log"))
36 | handler_stream = logging.StreamHandler(sys.stdout)
37 | handler_file.setFormatter(formatter)
38 | handler_stream.setFormatter(formatter)
39 | log_root.addHandler(handler_file)
40 | log_root.addHandler(handler_stream)
41 | log_root.info('rank_id: %d' % rank)
42 |
--------------------------------------------------------------------------------
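A short hedged example of the two helpers: the meter keeps a running average of the loss, and `init_logging` mirrors log lines to stdout and `<models_root>/training.log` on rank 0 (the output directory here is an assumption).

    import logging
    from utils.utils_logging import AverageMeter, init_logging

    init_logging(rank=0, models_root="work_dirs/ms1mv3_r50")  # assumed output dir

    loss_am = AverageMeter()
    loss_am.update(0.9, n=32)   # batch loss 0.9 averaged over 32 samples
    loss_am.update(0.7, n=32)
    logging.info("running loss: %.4f", loss_am.avg)  # 0.8
    loss_am.reset()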
/eg3d-pose-detection/models/losses.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | from kornia.geometry import warp_affine
5 | import torch.nn.functional as F
6 |
7 | def resize_n_crop(image, M, dsize=112):
8 | # image: (b, c, h, w)
9 | # M : (b, 2, 3)
10 | return warp_affine(image, M, dsize=(dsize, dsize))
11 |
12 | ### perceptual level loss
13 | class PerceptualLoss(nn.Module):
14 | def __init__(self, recog_net, input_size=112):
15 | super(PerceptualLoss, self).__init__()
16 | self.recog_net = recog_net
17 | self.preprocess = lambda x: 2 * x - 1
18 | self.input_size=input_size
19 |     def forward(self, imageA, imageB, M):
20 | """
21 | 1 - cosine distance
22 | Parameters:
23 | imageA --torch.tensor (B, 3, H, W), range (0, 1) , RGB order
24 | imageB --same as imageA
25 | """
26 |
27 | imageA = self.preprocess(resize_n_crop(imageA, M, self.input_size))
28 | imageB = self.preprocess(resize_n_crop(imageB, M, self.input_size))
29 |
30 | # freeze bn
31 | self.recog_net.eval()
32 |
33 | id_featureA = F.normalize(self.recog_net(imageA), dim=-1, p=2)
34 | id_featureB = F.normalize(self.recog_net(imageB), dim=-1, p=2)
35 | cosine_d = torch.sum(id_featureA * id_featureB, dim=-1)
36 | # assert torch.sum((cosine_d > 1).float()) == 0
37 | return torch.sum(1 - cosine_d) / cosine_d.shape[0]
38 |
39 | def perceptual_loss(id_featureA, id_featureB):
40 | cosine_d = torch.sum(id_featureA * id_featureB, dim=-1)
41 | # assert torch.sum((cosine_d > 1).float()) == 0
42 | return torch.sum(1 - cosine_d) / cosine_d.shape[0]
43 |
44 | ### image level loss
45 | def photo_loss(imageA, imageB, mask, eps=1e-6):
46 | """
47 |     l2 norm (with sqrt; use eps to ensure backward stability, otherwise NaN may occur)
48 | Parameters:
49 | imageA --torch.tensor (B, 3, H, W), range (0, 1), RGB order
50 | imageB --same as imageA
51 | """
52 | loss = torch.sqrt(eps + torch.sum((imageA - imageB) ** 2, dim=1, keepdims=True)) * mask
53 | loss = torch.sum(loss) / torch.max(torch.sum(mask), torch.tensor(1.0).to(mask.device))
54 | return loss
55 |
56 | def landmark_loss(predict_lm, gt_lm, weight=None):
57 | """
58 | weighted mse loss
59 | Parameters:
60 | predict_lm --torch.tensor (B, 68, 2)
61 | gt_lm --torch.tensor (B, 68, 2)
62 | weight --numpy.array (1, 68)
63 | """
64 |     if weight is None:
65 | weight = np.ones([68])
66 | weight[28:31] = 20
67 | weight[-8:] = 20
68 | weight = np.expand_dims(weight, 0)
69 | weight = torch.tensor(weight).to(predict_lm.device)
70 | loss = torch.sum((predict_lm - gt_lm)**2, dim=-1) * weight
71 | loss = torch.sum(loss) / (predict_lm.shape[0] * predict_lm.shape[1])
72 | return loss
73 |
74 |
75 | ### regularization
76 | def reg_loss(coeffs_dict, opt=None):
77 | """
78 | l2 norm without the sqrt, from yu's implementation (mse)
79 | tf.nn.l2_loss https://www.tensorflow.org/api_docs/python/tf/nn/l2_loss
80 | Parameters:
81 | coeffs_dict -- a dict of torch.tensors , keys: id, exp, tex, angle, gamma, trans
82 |
83 | """
84 | # coefficient regularization to ensure plausible 3d faces
85 | if opt:
86 | w_id, w_exp, w_tex = opt.w_id, opt.w_exp, opt.w_tex
87 | else:
88 |         w_id, w_exp, w_tex = 1, 1, 1
89 | creg_loss = w_id * torch.sum(coeffs_dict['id'] ** 2) + \
90 | w_exp * torch.sum(coeffs_dict['exp'] ** 2) + \
91 | w_tex * torch.sum(coeffs_dict['tex'] ** 2)
92 | creg_loss = creg_loss / coeffs_dict['id'].shape[0]
93 |
94 | # gamma regularization to ensure a nearly-monochromatic light
95 | gamma = coeffs_dict['gamma'].reshape([-1, 3, 9])
96 | gamma_mean = torch.mean(gamma, dim=1, keepdims=True)
97 | gamma_loss = torch.mean((gamma - gamma_mean) ** 2)
98 |
99 | return creg_loss, gamma_loss
100 |
101 | def reflectance_loss(texture, mask):
102 | """
103 |     minimize texture variance (mse), albedo regularization to ensure a uniform skin albedo
104 | Parameters:
105 | texture --torch.tensor, (B, N, 3)
106 | mask --torch.tensor, (N), 1 or 0
107 |
108 | """
109 | mask = mask.reshape([1, mask.shape[0], 1])
110 | texture_mean = torch.sum(mask * texture, dim=1, keepdims=True) / torch.sum(mask)
111 | loss = torch.sum(((texture - texture_mean) * mask)**2) / (texture.shape[0] * torch.sum(mask))
112 | return loss
113 |
114 |
--------------------------------------------------------------------------------
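A minimal hedged sketch of how the image-level terms above are typically combined during 3DMM fitting; the tensors are random stand-ins with the shapes documented in the docstrings, and the unit weights are placeholders.

    import torch
    from models.losses import photo_loss, landmark_loss

    B = 2
    pred_img  = torch.rand(B, 3, 224, 224)   # rendered face, range (0, 1)
    gt_img    = torch.rand(B, 3, 224, 224)   # input frame, range (0, 1)
    face_mask = torch.ones(B, 1, 224, 224)   # skin-region mask
    pred_lm   = torch.rand(B, 68, 2)
    gt_lm     = torch.rand(B, 68, 2)

    loss = 1.0 * photo_loss(pred_img, gt_img, face_mask) \
         + 1.0 * landmark_loss(pred_lm, gt_lm)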
/eg3d-pose-detection/models/template_model.py:
--------------------------------------------------------------------------------
1 | """Model class template
2 |
3 | This module provides a template for users to implement custom models.
4 | You can specify '--model template' to use this model.
5 | The class name should be consistent with both the filename and its model option.
6 | The filename should be <model>_model.py
7 | The class name should be <Model>Model
8 | It implements a simple image-to-image translation baseline based on regression loss.
9 | Given input-output pairs (data_A, data_B), it learns a network netG that can minimize the following L1 loss:
10 | min_<netG> ||netG(data_A) - data_B||_1
11 | You need to implement the following functions:
12 | <modify_commandline_options>: Add model-specific options and rewrite default values for existing options.
13 | <__init__>: Initialize this model class.
14 | <set_input>: Unpack input data and perform data pre-processing.
15 | <forward>: Run forward pass. This will be called by both <optimize_parameters> and <test>.
16 | <optimize_parameters>: Update network weights; it will be called in every training iteration.
17 | """
18 | import numpy as np
19 | import torch
20 | from .base_model import BaseModel
21 | from . import networks
22 |
23 |
24 | class TemplateModel(BaseModel):
25 | @staticmethod
26 | def modify_commandline_options(parser, is_train=True):
27 | """Add new model-specific options and rewrite default values for existing options.
28 |
29 | Parameters:
30 | parser -- the option parser
31 | is_train -- if it is training phase or test phase. You can use this flag to add training-specific or test-specific options.
32 |
33 | Returns:
34 | the modified parser.
35 | """
36 | parser.set_defaults(dataset_mode='aligned') # You can rewrite default values for this model. For example, this model usually uses aligned dataset as its dataset.
37 | if is_train:
38 | parser.add_argument('--lambda_regression', type=float, default=1.0, help='weight for the regression loss') # You can define new arguments for this model.
39 |
40 | return parser
41 |
42 | def __init__(self, opt):
43 | """Initialize this model class.
44 |
45 | Parameters:
46 | opt -- training/test options
47 |
48 | A few things can be done here.
49 | - (required) call the initialization function of BaseModel
50 | - define loss function, visualization images, model names, and optimizers
51 | """
52 | BaseModel.__init__(self, opt) # call the initialization method of BaseModel
53 | # specify the training losses you want to print out. The program will call base_model.get_current_losses to plot the losses to the console and save them to the disk.
54 | self.loss_names = ['loss_G']
55 | # specify the images you want to save and display. The program will call base_model.get_current_visuals to save and display these images.
56 | self.visual_names = ['data_A', 'data_B', 'output']
57 | # specify the models you want to save to the disk. The program will call base_model.save_networks and base_model.load_networks to save and load networks.
58 | # you can use opt.isTrain to specify different behaviors for training and test. For example, some networks will not be used during test, and you don't need to load them.
59 | self.model_names = ['G']
60 | # define networks; you can use opt.isTrain to specify different behaviors for training and test.
61 | self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG, gpu_ids=self.gpu_ids)
62 | if self.isTrain: # only defined during training time
63 | # define your loss functions. You can use losses provided by torch.nn such as torch.nn.L1Loss.
64 | # We also provide a GANLoss class "networks.GANLoss". self.criterionGAN = networks.GANLoss().to(self.device)
65 | self.criterionLoss = torch.nn.L1Loss()
66 | # define and initialize optimizers. You can define one optimizer for each network.
67 | # If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an example.
68 | self.optimizer = torch.optim.Adam(self.netG.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
69 | self.optimizers = [self.optimizer]
70 |
71 |         # Our program will automatically call <BaseModel.setup> to define schedulers, load networks, and print networks
72 |
73 | def set_input(self, input):
74 | """Unpack input data from the dataloader and perform necessary pre-processing steps.
75 |
76 | Parameters:
77 | input: a dictionary that contains the data itself and its metadata information.
78 | """
79 | AtoB = self.opt.direction == 'AtoB' # use to swap data_A and data_B
80 | self.data_A = input['A' if AtoB else 'B'].to(self.device) # get image data A
81 | self.data_B = input['B' if AtoB else 'A'].to(self.device) # get image data B
82 | self.image_paths = input['A_paths' if AtoB else 'B_paths'] # get image paths
83 |
84 | def forward(self):
85 | """Run forward pass. This will be called by both functions and ."""
86 | self.output = self.netG(self.data_A) # generate output image given the input data_A
87 |
88 | def backward(self):
89 | """Calculate losses, gradients, and update network weights; called in every training iteration"""
90 |         # calculate the intermediate results if necessary; here self.output has been computed during the <forward> function
91 | # calculate loss given the input and intermediate results
92 | self.loss_G = self.criterionLoss(self.output, self.data_B) * self.opt.lambda_regression
93 | self.loss_G.backward() # calculate gradients of network G w.r.t. loss_G
94 |
95 | def optimize_parameters(self):
96 | """Update network weights; it will be called in every training iteration."""
97 | self.forward() # first call forward to calculate intermediate results
98 | self.optimizer.zero_grad() # clear network G's existing gradients
99 | self.backward() # calculate gradients for network G
100 | self.optimizer.step() # update gradients for network G
101 |
--------------------------------------------------------------------------------
/eg3d-pose-detection/options/__init__.py:
--------------------------------------------------------------------------------
1 | """This package options includes option modules: training options, test options, and basic options (used in both training and test)."""
2 |
--------------------------------------------------------------------------------
/eg3d-pose-detection/options/base_options.py:
--------------------------------------------------------------------------------
1 | """This script contains base options for Deep3DFaceRecon_pytorch
2 | """
3 |
4 | import argparse
5 | import os
6 | from util import util
7 | import numpy as np
8 | import torch
9 | import models
10 | import data
11 |
12 |
13 | class BaseOptions():
14 | """This class defines options used during both training and test time.
15 |
16 | It also implements several helper functions such as parsing, printing, and saving the options.
17 |     It also gathers additional options defined in <modify_commandline_options> functions in both dataset and model classes.
18 | """
19 |
20 | def __init__(self, cmd_line=None):
21 | """Reset the class; indicates the class hasn't been initailized"""
22 | self.initialized = False
23 | self.cmd_line = None
24 | if cmd_line is not None:
25 | self.cmd_line = cmd_line.split()
26 |
27 | def initialize(self, parser):
28 | """Define the common options that are used in both training and test."""
29 | # basic parameters
30 | parser.add_argument('--name', type=str, default='face_recon', help='name of the experiment. It decides where to store samples and models')
31 | parser.add_argument('--gpu_ids', type=str, default='0', help='gpu ids: e.g. 0 0,1,2, 0,2. use -1 for CPU')
32 | parser.add_argument('--checkpoints_dir', type=str, default='/apdcephfs_cq2/share_1290939/kitbai/eg3d-pose-detection-main-master/checkpoints', help='models are saved here')
33 |         parser.add_argument('--vis_batch_nums', type=float, default=1, help='batch nums of images for visualization')
34 |         parser.add_argument('--eval_batch_nums', type=float, default=float('inf'), help='batch nums of images for evaluation')
35 |         parser.add_argument('--use_ddp', type=util.str2bool, nargs='?', const=True, default=True, help='whether to use distributed data parallel')
36 |         parser.add_argument('--ddp_port', type=str, default='12355', help='ddp port')
37 |         parser.add_argument('--display_per_batch', type=util.str2bool, nargs='?', const=True, default=True, help='whether to show losses per batch')
38 |         parser.add_argument('--add_image', type=util.str2bool, nargs='?', const=True, default=True, help='whether to add images to tensorboard')
39 |         parser.add_argument('--world_size', type=int, default=1, help='world size (number of processes) for distributed training')
40 |
41 | # model parameters
42 | parser.add_argument('--model', type=str, default='facerecon', help='chooses which model to use.')
43 |
44 | # additional parameters
45 | parser.add_argument('--epoch', type=str, default='latest', help='which epoch to load? set to latest to use latest cached model')
46 | parser.add_argument('--verbose', action='store_true', help='if specified, print more debugging information')
47 | parser.add_argument('--suffix', default='', type=str, help='customized suffix: opt.name = opt.name + suffix: e.g., {model}_{netG}_size{load_size}')
48 |
49 | self.initialized = True
50 | return parser
51 |
52 | def gather_options(self):
53 | """Initialize our parser with basic options(only once).
54 | Add additional model-specific and dataset-specific options.
55 |         These options are defined in the <modify_commandline_options> function
56 | in model and dataset classes.
57 | """
58 | if not self.initialized: # check if it has been initialized
59 | parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
60 | parser = self.initialize(parser)
61 |
62 | # get the basic options
63 | if self.cmd_line is None:
64 | opt, _ = parser.parse_known_args()
65 | else:
66 | opt, _ = parser.parse_known_args(self.cmd_line)
67 |
68 | # set cuda visible devices
69 | os.environ['CUDA_VISIBLE_DEVICES'] = opt.gpu_ids
70 |
71 | # modify model-related parser options
72 | model_name = opt.model
73 | model_option_setter = models.get_option_setter(model_name)
74 | parser = model_option_setter(parser, self.isTrain)
75 | if self.cmd_line is None:
76 | opt, _ = parser.parse_known_args() # parse again with new defaults
77 | else:
78 | opt, _ = parser.parse_known_args(self.cmd_line) # parse again with new defaults
79 |
80 | # modify dataset-related parser options
81 | if opt.dataset_mode:
82 | dataset_name = opt.dataset_mode
83 | dataset_option_setter = data.get_option_setter(dataset_name)
84 | parser = dataset_option_setter(parser, self.isTrain)
85 |
86 | # save and return the parser
87 | self.parser = parser
88 | if self.cmd_line is None:
89 | return parser.parse_args()
90 | else:
91 | return parser.parse_args(self.cmd_line)
92 |
93 | def print_options(self, opt):
94 | """Print and save options
95 |
96 | It will print both current options and default values(if different).
97 | It will save options into a text file / [checkpoints_dir] / opt.txt
98 | """
99 | message = ''
100 | message += '----------------- Options ---------------\n'
101 | for k, v in sorted(vars(opt).items()):
102 | comment = ''
103 | default = self.parser.get_default(k)
104 | if v != default:
105 | comment = '\t[default: %s]' % str(default)
106 | message += '{:>25}: {:<30}{}\n'.format(str(k), str(v), comment)
107 | message += '----------------- End -------------------'
108 | print(message)
109 |
110 | # save to the disk
111 | expr_dir = os.path.join(opt.checkpoints_dir, opt.name)
112 | util.mkdirs(expr_dir)
113 | file_name = os.path.join(expr_dir, '{}_opt.txt'.format(opt.phase))
114 | try:
115 | with open(file_name, 'wt') as opt_file:
116 | opt_file.write(message)
117 | opt_file.write('\n')
118 | except PermissionError as error:
119 | print("permission error {}".format(error))
120 | pass
121 |
122 | def parse(self):
123 | """Parse our options, create checkpoints directory suffix, and set up gpu device."""
124 | opt = self.gather_options()
125 | opt.isTrain = self.isTrain # train or test
126 |
127 | # process opt.suffix
128 | if opt.suffix:
129 | suffix = ('_' + opt.suffix.format(**vars(opt))) if opt.suffix != '' else ''
130 | opt.name = opt.name + suffix
131 |
132 |
133 | # set gpu ids
134 | str_ids = opt.gpu_ids.split(',')
135 | gpu_ids = []
136 | for str_id in str_ids:
137 | id = int(str_id)
138 | if id >= 0:
139 | gpu_ids.append(id)
140 | opt.world_size = len(gpu_ids)
141 | # if len(opt.gpu_ids) > 0:
142 | # torch.cuda.set_device(gpu_ids[0])
143 | if opt.world_size == 1:
144 | opt.use_ddp = False
145 |
146 | if opt.phase != 'test':
147 | # set continue_train automatically
148 | if opt.pretrained_name is None:
149 | model_dir = os.path.join(opt.checkpoints_dir, opt.name)
150 | else:
151 | model_dir = os.path.join(opt.checkpoints_dir, opt.pretrained_name)
152 | if os.path.isdir(model_dir):
153 | model_pths = [i for i in os.listdir(model_dir) if i.endswith('pth')]
154 | if os.path.isdir(model_dir) and len(model_pths) != 0:
155 | opt.continue_train= True
156 |
157 | # update the latest epoch count
158 | if opt.continue_train:
159 | if opt.epoch == 'latest':
160 | epoch_counts = [int(i.split('.')[0].split('_')[-1]) for i in model_pths if 'latest' not in i]
161 | if len(epoch_counts) != 0:
162 | opt.epoch_count = max(epoch_counts) + 1
163 | else:
164 | opt.epoch_count = int(opt.epoch) + 1
165 |
166 |
167 | self.print_options(opt)
168 | self.opt = opt
169 | return self.opt
170 |
--------------------------------------------------------------------------------
/eg3d-pose-detection/options/test_options.py:
--------------------------------------------------------------------------------
1 | """This script contains the test options for Deep3DFaceRecon_pytorch
2 | """
3 |
4 | from .base_options import BaseOptions
5 |
6 |
7 | class TestOptions(BaseOptions):
8 | """This class includes test options.
9 |
10 | It also includes shared options defined in BaseOptions.
11 | """
12 |
13 | def initialize(self, parser):
14 | parser = BaseOptions.initialize(self, parser) # define shared options
15 | parser.add_argument('--phase', type=str, default='test', help='train, val, test, etc')
16 | parser.add_argument('--dataset_mode', type=str, default=None, help='chooses how datasets are loaded. [None | flist]')
17 | parser.add_argument('--img_folder', type=str, default='examples', help='folder for test images.')
18 | parser.add_argument('--start', type=int, default=0, help='start folder')
19 | parser.add_argument('--skip_model', action='store_true', help='whether to run model')
20 |
21 |         # Dropout and BatchNorm have different behaviors during training and test.
22 | self.isTrain = False
23 | return parser
24 |
--------------------------------------------------------------------------------
/eg3d-pose-detection/options/train_options.py:
--------------------------------------------------------------------------------
1 | """This script contains the training options for Deep3DFaceRecon_pytorch
2 | """
3 |
4 | from .base_options import BaseOptions
5 | from util import util
6 |
7 | class TrainOptions(BaseOptions):
8 | """This class includes training options.
9 |
10 | It also includes shared options defined in BaseOptions.
11 | """
12 |
13 | def initialize(self, parser):
14 | parser = BaseOptions.initialize(self, parser)
15 | # dataset parameters
16 | # for train
17 | parser.add_argument('--data_root', type=str, default='./', help='dataset root')
18 | parser.add_argument('--flist', type=str, default='datalist/train/masks.txt', help='list of mask names of training set')
19 | parser.add_argument('--batch_size', type=int, default=32)
20 | parser.add_argument('--dataset_mode', type=str, default='flist', help='chooses how datasets are loaded. [None | flist]')
21 | parser.add_argument('--serial_batches', action='store_true', help='if true, takes images in order to make batches, otherwise takes them randomly')
22 | parser.add_argument('--num_threads', default=4, type=int, help='# threads for loading data')
23 | parser.add_argument('--max_dataset_size', type=int, default=float("inf"), help='Maximum number of samples allowed per dataset. If the dataset directory contains more than max_dataset_size, only a subset is loaded.')
24 | parser.add_argument('--preprocess', type=str, default='shift_scale_rot_flip', help='scaling and cropping of images at load time [shift_scale_rot_flip | shift_scale | shift | shift_rot_flip ]')
25 |         parser.add_argument('--use_aug', type=util.str2bool, nargs='?', const=True, default=True, help='whether to use data augmentation')
26 |
27 | # for val
28 | parser.add_argument('--flist_val', type=str, default='datalist/val/masks.txt', help='list of mask names of val set')
29 | parser.add_argument('--batch_size_val', type=int, default=32)
30 |
31 |
32 | # visualization parameters
33 | parser.add_argument('--display_freq', type=int, default=1000, help='frequency of showing training results on screen')
34 | parser.add_argument('--print_freq', type=int, default=100, help='frequency of showing training results on console')
35 |
36 | # network saving and loading parameters
37 | parser.add_argument('--save_latest_freq', type=int, default=5000, help='frequency of saving the latest results')
38 | parser.add_argument('--save_epoch_freq', type=int, default=1, help='frequency of saving checkpoints at the end of epochs')
39 | parser.add_argument('--evaluation_freq', type=int, default=5000, help='evaluation freq')
40 | parser.add_argument('--save_by_iter', action='store_true', help='whether saves model by iteration')
41 | parser.add_argument('--continue_train', action='store_true', help='continue training: load the latest model')
42 |         parser.add_argument('--epoch_count', type=int, default=1, help='the starting epoch count; we save the model by <epoch_count>, <epoch_count>+<save_latest_freq>, ...')
43 | parser.add_argument('--phase', type=str, default='train', help='train, val, test, etc')
44 | parser.add_argument('--pretrained_name', type=str, default=None, help='resume training from another checkpoint')
45 |
46 | # training parameters
47 | parser.add_argument('--n_epochs', type=int, default=20, help='number of epochs with the initial learning rate')
48 | parser.add_argument('--lr', type=float, default=0.0001, help='initial learning rate for adam')
49 | parser.add_argument('--lr_policy', type=str, default='step', help='learning rate policy. [linear | step | plateau | cosine]')
50 |         parser.add_argument('--lr_decay_epochs', type=int, default=10, help='multiply by a gamma every lr_decay_epochs epochs')
51 |
52 | self.isTrain = True
53 | return parser
54 |
--------------------------------------------------------------------------------
/eg3d-pose-detection/process_test_video.py:
--------------------------------------------------------------------------------
1 | """
2 | Processes a directory containing *.jpg/png and outputs crops and poses.
3 | """
4 | import glob
5 | import os
6 | import subprocess
7 | import argparse
8 | parser = argparse.ArgumentParser()
9 | parser.add_argument('--input_dir', default='/media/data6/ericryanchan/mafu/Deep3DFaceRecon_pytorch/test_images')
10 | parser.add_argument('--gpu', default=0)
11 | args = parser.parse_args()
12 |
13 | print('Processing images:', sorted(glob.glob(os.path.join(args.input_dir, "*"))))
14 |
15 | # Compute facial landmarks.
16 | print("Computing facial landmarks for model...")
17 | cmd = "python3.6 /eg3d-pose-detection/batch_mtcnn.py"
18 | input_flag = " --in_root " + args.input_dir
19 | cmd += input_flag
20 | # subprocess.run([cmd], shell=True, check=True)
21 | os.system(cmd)
22 |
23 |
24 | print("Running smooth...")
25 | cmd = "python3.6 /eg3d-pose-detection/smooth.py"
26 | input_flag = " --img_folder=" + args.input_dir
27 | cmd += input_flag
28 | os.system(cmd)
29 |
30 | # Run model inference to produce crops and raw poses.
31 | print("Running model inference...")
32 | cmd = "python3.6 /eg3d-pose-detection/test.py"
33 | input_flag = " --img_folder=" + args.input_dir
34 | gpu_flag = " --gpu_ids=" + str(args.gpu)
35 | model_name_flag = " --name=face_recon"
36 | model_file_flag = " --epoch=20 "
37 | cmd += input_flag + gpu_flag + model_name_flag + model_file_flag
38 | # subprocess.run([cmd], shell=True, check=True)
39 | os.system(cmd)
40 |
41 | # Perform final cropping of 1024x1024 images.
42 | print("Processing final crops...")
43 | cmd = "python3.6 /eg3d-pose-detection/crop_images.py"
44 | input_flag = " --indir " + args.input_dir
45 | output_flag = " --outdir " + os.path.join(args.input_dir, 'cropped_images')
46 | cmd += input_flag + output_flag
47 | # subprocess.run([cmd], shell=True, check=True)
48 | os.system(cmd)
49 |
50 | # Process poses into our representation -- produces a cameras.json file.
51 | print("Processing final poses...")
52 | cmd = "python3.6 /eg3d-pose-detection/3dface2idr.py"
53 | input_flag = " --in_root " + os.path.join(args.input_dir, "epoch_20_000000")
54 | output_flag = " --out_root " + os.path.join(args.input_dir, "cropped_images")
55 |
56 | cmd += input_flag + output_flag
57 | # subprocess.run([cmd], shell=True, check=True)
58 | os.system(cmd)
59 |
60 | print("Transforming...")
61 | cmd = "python3.6 /eg3d-pose-detection/camera2label.py"
62 | input_flag = " --in_root " + args.input_dir
63 | cmd += input_flag
64 | # subprocess.run([cmd], shell=True, check=True)
65 | os.system(cmd)
66 |
--------------------------------------------------------------------------------
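The pipeline above chains its stages with `os.system` and hard-coded interpreter paths. A hedged alternative sketch (not the repository's code path) uses `subprocess.run` with argument lists, which avoids shell quoting and stops on the first failing stage; the interpreter and script locations are placeholders.

    import subprocess

    input_dir = "test_images"  # placeholder for args.input_dir

    def run_step(script, *extra_args):
        cmd = ["python3", script, *extra_args]
        subprocess.run(cmd, check=True)  # raises CalledProcessError if the stage fails

    run_step("batch_mtcnn.py", "--in_root", input_dir)
    run_step("smooth.py", "--img_folder=" + input_dir)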
/eg3d-pose-detection/smooth.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | from options.test_options import TestOptions
4 | from data import create_dataset
5 | from models import create_model
6 | from util.visualizer import MyVisualizer
7 | from util.preprocess import align_img
8 | from PIL import Image
9 | import numpy as np
10 | from util.load_mats import load_lm3d
11 |
12 | from data.flist_dataset import default_flist_reader
13 | from scipy.io import loadmat, savemat
14 |
15 | import json
16 |
17 | from scipy.ndimage import gaussian_filter1d
18 |
19 |
20 | def get_data_path(root='examples'):
21 | im_path = [os.path.join(root, i) for i in sorted(os.listdir(root)) if i.endswith('png') or i.endswith('jpg')]
22 |     im_path = sorted(im_path, key=lambda x: int(x.split('/')[-1].split('.')[0]))
23 | lm_path = [i.replace('png', 'txt').replace('jpg', 'txt') for i in im_path]
24 | lm_path = [os.path.join(i.replace(i.split(os.path.sep)[-1],''),'detections',i.split(os.path.sep)[-1]) for i in lm_path]
25 | return lm_path
26 |
27 |
28 | def read_data(lm_path):
29 | lms = []
30 | # for i in range(10):
31 | for i in range(len(lm_path)):
32 | #im = Image.open(im_path).convert('RGB')
33 | # _, H = im.size
34 | if not os.path.isfile(lm_path[i]):
35 | continue
36 | lm = np.loadtxt(lm_path[i]).astype(np.float32)
37 | # print(lm)
38 | lms.append(lm)
39 | lms = np.array(lms)
40 | lms = gaussian_filter1d(lms, 2, 0)
41 |
42 | # for i in range(10):
43 | # print(lms[i])
44 | for i in range(len(lm_path)):
45 | if not os.path.isfile(lm_path[i]):
46 | continue
47 | np.savetxt(lm_path[i], lms[i])
48 |
49 |
50 | def main(rank, opt, name='examples'):
51 |
52 | lm_path = get_data_path(name)
53 | read_data(lm_path)
54 |
55 | if __name__ == '__main__':
56 | opt = TestOptions().parse() # get test options
57 |     main(0, opt, opt.img_folder)
58 |
--------------------------------------------------------------------------------
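For reference, the smoothing above is a temporal Gaussian filter (sigma=2) applied along axis 0, i.e. across frames, to the stacked landmark files. A tiny self-contained example with random data:

    import numpy as np
    from scipy.ndimage import gaussian_filter1d

    landmarks = np.random.rand(100, 5, 2)             # 100 frames, 5 points, (x, y)
    smoothed = gaussian_filter1d(landmarks, sigma=2, axis=0)
    assert smoothed.shape == landmarks.shape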
/eg3d-pose-detection/test:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/eg3d-pose-detection/test.py:
--------------------------------------------------------------------------------
1 | """This script is the test script for Deep3DFaceRecon_pytorch
2 | """
3 |
4 | import os
5 | import torch
6 | from options.test_options import TestOptions
7 | from data import create_dataset
8 | from models import create_model
9 | from util.visualizer import MyVisualizer
10 | from util.preprocess import align_img
11 | from PIL import Image
12 | import numpy as np
13 | from util.load_mats import load_lm3d
14 |
15 | from data.flist_dataset import default_flist_reader
16 | from scipy.io import loadmat, savemat
17 |
18 | import json
19 |
20 | def get_data_path(root='examples'):
21 | im_path = [os.path.join(root, i) for i in sorted(os.listdir(root)) if i.endswith('png') or i.endswith('jpg')]
22 | lm_path = [i.replace('png', 'txt').replace('jpg', 'txt') for i in im_path]
23 | lm_path = [os.path.join(i.replace(i.split(os.path.sep)[-1],''),'detections',i.split(os.path.sep)[-1]) for i in lm_path]
24 | return im_path, lm_path
25 |
26 | def read_data(im_path, lm_path, lm3d_std, to_tensor=True, rescale_factor=466.285):
27 | im = Image.open(im_path).convert('RGB')
28 | _, H = im.size
29 | lm = np.loadtxt(lm_path).astype(np.float32)
30 | lm = lm.reshape([-1, 2])
31 | lm[:, -1] = H - 1 - lm[:, -1]
32 | _, im_pil, lm, _, im_high = align_img(im, lm, lm3d_std, rescale_factor=rescale_factor)
33 | # im_high.save(os.path.join('/apdcephfs_cq2/share_1290939/kitbai/LIA-3d/datasets/our_dataset/test/3', 'crop_1024', 'imagehigh.png'))
34 | if to_tensor:
35 | im = torch.tensor(np.array(im_pil)/255., dtype=torch.float32).permute(2, 0, 1).unsqueeze(0)
36 | lm = torch.tensor(lm).unsqueeze(0)
37 | else:
38 | im = im_pil
39 | return im, lm, im_pil, im_high
40 |
41 | def main(rank, opt, name='examples'):
42 | device = torch.device(rank)
43 | torch.cuda.set_device(device)
44 | model = create_model(opt)
45 | model.setup(opt)
46 | model.device = device
47 | model.parallelize()
48 | model.eval()
49 | visualizer = MyVisualizer(opt)
50 | print("ROOT")
51 | print(name)
52 | im_path, lm_path = get_data_path(name)
53 | lm3d_std = load_lm3d(opt.bfm_folder)
54 |
55 | cropping_params = {}
56 |
57 | out_dir_crop1024 = os.path.join(name, "crop_1024")
58 | if not os.path.exists(out_dir_crop1024):
59 | os.makedirs(out_dir_crop1024)
60 | out_dir = os.path.join(name, 'epoch_%s_%06d'%(opt.epoch, 0))
61 | if not os.path.exists(out_dir):
62 | os.makedirs(out_dir)
63 | for i in range(len(im_path)):
64 | print(i, im_path[i])
65 | img_name = im_path[i].split(os.path.sep)[-1].replace('.png','').replace('.jpg','')
66 | if not os.path.isfile(lm_path[i]):
67 | continue
68 |
69 | # 2 passes for cropping image for NeRF and for pose extraction
70 | for r in range(2):
71 | if r==0:
72 | rescale_factor = 300 # optimized for NeRF training
73 | center_crop_size = 700
74 | output_size = 512
75 |
76 | # left = int(im_high.size[0]/2 - center_crop_size/2)
77 | # upper = int(im_high.size[1]/2 - center_crop_size/2)
78 | # right = left + center_crop_size
79 | # lower = upper + center_crop_size
80 | # im_cropped = im_high.crop((left, upper, right,lower))
81 | # im_cropped = im_cropped.resize((output_size, output_size), resample=Image.LANCZOS)
82 | cropping_params[os.path.basename(im_path[i])] = {
83 | 'lm': np.loadtxt(lm_path[i]).astype(np.float32).tolist(),
84 | 'lm3d_std': lm3d_std.tolist(),
85 | 'rescale_factor': rescale_factor,
86 | 'center_crop_size': center_crop_size,
87 | 'output_size': output_size}
88 |
89 | # im_high.save(os.path.join(out_dir_crop1024, img_name+'.png'), compress_level=0)
90 | # im_cropped.save(os.path.join(out_dir_crop1024, img_name+'.png'), compress_level=0)
91 | elif not opt.skip_model:
92 | rescale_factor = 466.285
93 | im_tensor, lm_tensor, _, im_high = read_data(im_path[i], lm_path[i], lm3d_std, rescale_factor=rescale_factor)
94 |
95 | data = {
96 | 'imgs': im_tensor,
97 | 'lms': lm_tensor
98 | }
99 | model.set_input(data) # unpack data from data loader
100 | model.test() # run inference
101 | visuals = model.get_current_visuals() # get image results
102 | visualizer.display_current_results(visuals, 0, opt.epoch, dataset=name.split(os.path.sep)[-1],
103 | save_results=True, count=i, name=img_name, add_image=False)
104 |
105 | model.save_coeff(os.path.join(out_dir,img_name+'.mat')) # save predicted coefficients
106 |
107 | with open(os.path.join(name, 'cropping_params.json'), 'w') as outfile:
108 | json.dump(cropping_params, outfile, indent=4)
109 |
110 | if __name__ == '__main__':
111 | opt = TestOptions().parse() # get test options
112 |     main(0, opt, opt.img_folder)
113 |
114 |
115 | # python test.py --epoch 20 --img_folder test_images/ --gpu_ids 0
--------------------------------------------------------------------------------