├── .gitignore ├── .pre-commit-config.yaml ├── AdaptiveWingLoss ├── aux.py ├── core │ ├── __init__.py │ ├── coord_conv.py │ ├── dataloader.py │ ├── evaler.py │ └── models.py └── utils │ ├── __init__.py │ └── utils.py ├── Deep3DFaceRecon_pytorch ├── LICENSE ├── README.md ├── data │ ├── __init__.py │ ├── base_dataset.py │ ├── flist_dataset.py │ ├── image_folder.py │ └── template_dataset.py ├── data_preparation.py ├── environment.yml ├── models │ ├── __init__.py │ ├── arcface_torch │ │ ├── README.md │ │ ├── backbones │ │ │ ├── __init__.py │ │ │ ├── __pycache__ │ │ │ │ ├── __init__.cpython-310.pyc │ │ │ │ ├── iresnet.cpython-310.pyc │ │ │ │ └── mobilefacenet.cpython-310.pyc │ │ │ ├── iresnet.py │ │ │ ├── iresnet2060.py │ │ │ ├── mobilefacenet.py │ │ │ └── vit.py │ │ ├── configs │ │ │ ├── 3millions.py │ │ │ ├── __init__.py │ │ │ ├── base.py │ │ │ ├── glint360k_mbf.py │ │ │ ├── glint360k_r100.py │ │ │ ├── glint360k_r50.py │ │ │ ├── ms1mv2_mbf.py │ │ │ ├── ms1mv2_r100.py │ │ │ ├── ms1mv2_r50.py │ │ │ ├── ms1mv3_mbf.py │ │ │ ├── ms1mv3_r100.py │ │ │ ├── ms1mv3_r50.py │ │ │ ├── ms1mv3_r50_onegpu.py │ │ │ ├── wf12m_conflict_r50.py │ │ │ ├── wf12m_conflict_r50_pfc03_filter04.py │ │ │ ├── wf12m_flip_pfc01_filter04_r50.py │ │ │ ├── wf12m_flip_r50.py │ │ │ ├── wf12m_mbf.py │ │ │ ├── wf12m_pfc02_r100.py │ │ │ ├── wf12m_r100.py │ │ │ ├── wf12m_r50.py │ │ │ ├── wf42m_pfc0008_32gpu_r100.py │ │ │ ├── wf42m_pfc02_16gpus_mbf_bs8k.py │ │ │ ├── wf42m_pfc02_16gpus_r100.py │ │ │ ├── wf42m_pfc02_16gpus_r50_bs8k.py │ │ │ ├── wf42m_pfc02_32gpus_r50_bs4k.py │ │ │ ├── wf42m_pfc02_8gpus_r50_bs4k.py │ │ │ ├── wf42m_pfc02_r100.py │ │ │ ├── wf42m_pfc02_r100_16gpus.py │ │ │ ├── wf42m_pfc02_r100_32gpus.py │ │ │ ├── wf42m_pfc03_32gpu_r100.py │ │ │ ├── wf42m_pfc03_32gpu_r18.py │ │ │ ├── wf42m_pfc03_32gpu_r200.py │ │ │ ├── wf42m_pfc03_32gpu_r50.py │ │ │ ├── wf42m_pfc03_40epoch_64gpu_vit_b.py │ │ │ ├── wf42m_pfc03_40epoch_64gpu_vit_l.py │ │ │ ├── wf42m_pfc03_40epoch_64gpu_vit_s.py │ │ │ ├── 
wf42m_pfc03_40epoch_64gpu_vit_t.py │ │ │ ├── wf42m_pfc03_40epoch_8gpu_vit_b.py │ │ │ ├── wf42m_pfc03_40epoch_8gpu_vit_t.py │ │ │ ├── wf4m_mbf.py │ │ │ ├── wf4m_r100.py │ │ │ └── wf4m_r50.py │ │ ├── dataset.py │ │ ├── dist.sh │ │ ├── docs │ │ │ ├── eval.md │ │ │ ├── install.md │ │ │ ├── install_dali.md │ │ │ ├── modelzoo.md │ │ │ ├── prepare_custom_dataset.md │ │ │ ├── prepare_webface42m.md │ │ │ └── speed_benchmark.md │ │ ├── eval │ │ │ ├── __init__.py │ │ │ └── verification.py │ │ ├── eval_ijbc.py │ │ ├── flops.py │ │ ├── inference.py │ │ ├── losses.py │ │ ├── lr_scheduler.py │ │ ├── onnx_helper.py │ │ ├── onnx_ijbc.py │ │ ├── partial_fc.py │ │ ├── partial_fc_v2.py │ │ ├── requirement.txt │ │ ├── run.sh │ │ ├── scripts │ │ │ └── shuffle_rec.py │ │ ├── torch2onnx.py │ │ ├── train.py │ │ ├── train_v2.py │ │ └── utils │ │ │ ├── __init__.py │ │ │ ├── plot.py │ │ │ ├── utils_callbacks.py │ │ │ ├── utils_config.py │ │ │ ├── utils_distributed_sampler.py │ │ │ └── utils_logging.py │ ├── base_model.py │ ├── bfm.py │ ├── facerecon_model.py │ ├── losses.py │ ├── networks.py │ └── template_model.py ├── options │ ├── __init__.py │ ├── base_options.py │ ├── test_options.py │ └── train_options.py ├── test.py ├── train.py └── util │ ├── BBRegressorParam_r.mat │ ├── __init__.py │ ├── __pycache__ │ ├── __init__.cpython-310.pyc │ └── load_mats.cpython-310.pyc │ ├── detect_lm68.py │ ├── generate_list.py │ ├── html.py │ ├── load_mats.py │ ├── nvdiffrast.py │ ├── preprocess.py │ ├── skin_mask.py │ ├── test_mean_face.txt │ ├── util.py │ └── visualizer.py ├── HRNet └── hrnet.py ├── LICENSE ├── README.org ├── app └── app.py ├── arcface_torch ├── README.md ├── backbones │ ├── __init__.py │ ├── iresnet.py │ ├── iresnet2060.py │ ├── mobilefacenet.py │ └── vit.py ├── configs │ ├── 3millions.py │ ├── __init__.py │ ├── base.py │ ├── glint360k_mbf.py │ ├── glint360k_r100.py │ ├── glint360k_r50.py │ ├── ms1mv2_mbf.py │ ├── ms1mv2_r100.py │ ├── ms1mv2_r50.py │ ├── ms1mv3_mbf.py │ ├── 
ms1mv3_r100.py │ ├── ms1mv3_r50.py │ ├── ms1mv3_r50_onegpu.py │ ├── wf12m_conflict_r50.py │ ├── wf12m_conflict_r50_pfc03_filter04.py │ ├── wf12m_flip_pfc01_filter04_r50.py │ ├── wf12m_flip_r50.py │ ├── wf12m_mbf.py │ ├── wf12m_pfc02_r100.py │ ├── wf12m_r100.py │ ├── wf12m_r50.py │ ├── wf42m_pfc0008_32gpu_r100.py │ ├── wf42m_pfc02_16gpus_mbf_bs8k.py │ ├── wf42m_pfc02_16gpus_r100.py │ ├── wf42m_pfc02_16gpus_r50_bs8k.py │ ├── wf42m_pfc02_32gpus_r50_bs4k.py │ ├── wf42m_pfc02_8gpus_r50_bs4k.py │ ├── wf42m_pfc02_r100.py │ ├── wf42m_pfc02_r100_16gpus.py │ ├── wf42m_pfc02_r100_32gpus.py │ ├── wf42m_pfc03_32gpu_r100.py │ ├── wf42m_pfc03_32gpu_r18.py │ ├── wf42m_pfc03_32gpu_r200.py │ ├── wf42m_pfc03_32gpu_r50.py │ ├── wf42m_pfc03_40epoch_64gpu_vit_b.py │ ├── wf42m_pfc03_40epoch_64gpu_vit_l.py │ ├── wf42m_pfc03_40epoch_64gpu_vit_s.py │ ├── wf42m_pfc03_40epoch_64gpu_vit_t.py │ ├── wf42m_pfc03_40epoch_8gpu_vit_b.py │ ├── wf42m_pfc03_40epoch_8gpu_vit_t.py │ ├── wf4m_mbf.py │ ├── wf4m_r100.py │ └── wf4m_r50.py ├── dataset.py ├── dist.sh ├── docs │ ├── eval.md │ ├── install.md │ ├── install_dali.md │ ├── modelzoo.md │ ├── prepare_custom_dataset.md │ ├── prepare_webface42m.md │ └── speed_benchmark.md ├── eval │ ├── __init__.py │ └── verification.py ├── eval_ijbc.py ├── flops.py ├── inference.py ├── losses.py ├── lr_scheduler.py ├── onnx_helper.py ├── onnx_ijbc.py ├── partial_fc.py ├── partial_fc_v2.py ├── requirement.txt ├── run.sh ├── scripts │ └── shuffle_rec.py ├── torch2onnx.py ├── train.py ├── train_v2.py └── utils │ ├── __init__.py │ ├── plot.py │ ├── utils_callbacks.py │ ├── utils_config.py │ ├── utils_distributed_sampler.py │ └── utils_logging.py ├── benchmark ├── app_image.py ├── app_video.py ├── face_pipeline.py ├── inference_image.py ├── inference_video.py ├── scrfd_detect.py ├── test.py └── test_1tom.py ├── configs ├── mode.py ├── singleton.py └── train_config.py ├── data └── dataset.py ├── data_process ├── generate_mask.py ├── model.py ├── resnet.py └── utils.py ├── 
entry ├── inference.py └── train.py ├── models ├── discriminator.py ├── gan_loss.py ├── generator.py ├── init_weight.py ├── model.py ├── model_blocks.py ├── semantic_face_fusion_model.py └── shape_aware_identity_model.py ├── results ├── exp_230901_base_1693564635742_320000_1.jpg ├── origan-v0-new-3d-250k-eye-mouth-hm-weight-10k-10k_1685515837755_190000_1.jpg ├── p1.png ├── p2.png ├── p3.png ├── p4.png └── p5.png └── utils └── visualizer.py /.gitignore: -------------------------------------------------------------------------------- 1 | *~ 2 | *pyc 3 | __pycache__/* -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/asottile/reorder_python_imports 3 | rev: v3.9.0 4 | hooks: 5 | - id: reorder-python-imports 6 | name: Reorder Python imports 7 | types: [file, python] 8 | 9 | -------------------------------------------------------------------------------- /AdaptiveWingLoss/aux.py: -------------------------------------------------------------------------------- 1 | def detect_landmarks(inputs, model_ft): 2 | outputs, _ = model_ft(inputs) 3 | pred_heatmap = outputs[-1][:, :-1, :, :] 4 | return pred_heatmap[:, 96, :, :], pred_heatmap[:, 97, :, :] 5 | -------------------------------------------------------------------------------- /AdaptiveWingLoss/core/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xuehy/HiFiFace-pytorch/0e50b25909b5910e9327d3cb44eeb054f1c047d9/AdaptiveWingLoss/core/__init__.py -------------------------------------------------------------------------------- /AdaptiveWingLoss/utils/__init__.py: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/xuehy/HiFiFace-pytorch/0e50b25909b5910e9327d3cb44eeb054f1c047d9/AdaptiveWingLoss/utils/__init__.py -------------------------------------------------------------------------------- /Deep3DFaceRecon_pytorch/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Sicheng Xu 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /Deep3DFaceRecon_pytorch/data/image_folder.py: -------------------------------------------------------------------------------- 1 | """A modified image folder class 2 | 3 | We modify the official PyTorch image folder (https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py) 4 | so that this class can load images from both current directory and its subdirectories. 
5 | """ 6 | import os.path 7 | 8 | import numpy as np 9 | import torch.utils.data as data 10 | from PIL import Image 11 | 12 | IMG_EXTENSIONS = [ 13 | ".jpg", 14 | ".JPG", 15 | ".jpeg", 16 | ".JPEG", 17 | ".png", 18 | ".PNG", 19 | ".ppm", 20 | ".PPM", 21 | ".bmp", 22 | ".BMP", 23 | ".tif", 24 | ".TIF", 25 | ".tiff", 26 | ".TIFF", 27 | ] 28 | 29 | 30 | def is_image_file(filename): 31 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) 32 | 33 | 34 | def make_dataset(dir, max_dataset_size=float("inf")): 35 | images = [] 36 | assert os.path.isdir(dir) or os.path.islink(dir), "%s is not a valid directory" % dir 37 | 38 | for root, _, fnames in sorted(os.walk(dir, followlinks=True)): 39 | for fname in fnames: 40 | if is_image_file(fname): 41 | path = os.path.join(root, fname) 42 | images.append(path) 43 | return images[: min(max_dataset_size, len(images))] 44 | 45 | 46 | def default_loader(path): 47 | return Image.open(path).convert("RGB") 48 | 49 | 50 | class ImageFolder(data.Dataset): 51 | def __init__(self, root, transform=None, return_paths=False, loader=default_loader): 52 | imgs = make_dataset(root) 53 | if len(imgs) == 0: 54 | raise ( 55 | RuntimeError( 56 | "Found 0 images in: " + root + "\n" "Supported image extensions are: " + ",".join(IMG_EXTENSIONS) 57 | ) 58 | ) 59 | 60 | self.root = root 61 | self.imgs = imgs 62 | self.transform = transform 63 | self.return_paths = return_paths 64 | self.loader = loader 65 | 66 | def __getitem__(self, index): 67 | path = self.imgs[index] 68 | img = self.loader(path) 69 | if self.transform is not None: 70 | img = self.transform(img) 71 | if self.return_paths: 72 | return img, path 73 | else: 74 | return img 75 | 76 | def __len__(self): 77 | return len(self.imgs) 78 | -------------------------------------------------------------------------------- /Deep3DFaceRecon_pytorch/data/template_dataset.py: -------------------------------------------------------------------------------- 1 | """Dataset class 
template 2 | 3 | This module provides a template for users to implement custom datasets. 4 | You can specify '--dataset_mode template' to use this dataset. 5 | The class name should be consistent with both the filename and its dataset_mode option. 6 | The filename should be _dataset.py 7 | The class name should be Dataset.py 8 | You need to implement the following functions: 9 | -- : Add dataset-specific options and rewrite default values for existing options. 10 | -- <__init__>: Initialize this dataset class. 11 | -- <__getitem__>: Return a data point and its metadata information. 12 | -- <__len__>: Return the number of images. 13 | """ 14 | from data.base_dataset import BaseDataset 15 | from data.base_dataset import get_transform 16 | 17 | # from data.image_folder import make_dataset 18 | # from PIL import Image 19 | 20 | 21 | class TemplateDataset(BaseDataset): 22 | """A template dataset class for you to implement custom datasets.""" 23 | 24 | @staticmethod 25 | def modify_commandline_options(parser, is_train): 26 | """Add new dataset-specific options, and rewrite default values for existing options. 27 | 28 | Parameters: 29 | parser -- original option parser 30 | is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options. 31 | 32 | Returns: 33 | the modified parser. 34 | """ 35 | parser.add_argument("--new_dataset_option", type=float, default=1.0, help="new dataset option") 36 | parser.set_defaults(max_dataset_size=10, new_dataset_option=2.0) # specify dataset-specific default values 37 | return parser 38 | 39 | def __init__(self, opt): 40 | """Initialize this dataset class. 41 | 42 | Parameters: 43 | opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions 44 | 45 | A few things can be done here. 46 | - save the options (have been done in BaseDataset) 47 | - get image paths and meta information of the dataset. 48 | - define the image transformation. 
49 | """ 50 | # save the option and dataset root 51 | BaseDataset.__init__(self, opt) 52 | # get the image paths of your dataset; 53 | self.image_paths = ( 54 | [] 55 | ) # You can call sorted(make_dataset(self.root, opt.max_dataset_size)) to get all the image paths under the directory self.root 56 | # define the default transform function. You can use ; You can also define your custom transform function 57 | self.transform = get_transform(opt) 58 | 59 | def __getitem__(self, index): 60 | """Return a data point and its metadata information. 61 | 62 | Parameters: 63 | index -- a random integer for data indexing 64 | 65 | Returns: 66 | a dictionary of data with their names. It usually contains the data itself and its metadata information. 67 | 68 | Step 1: get a random image path: e.g., path = self.image_paths[index] 69 | Step 2: load your data from the disk: e.g., image = Image.open(path).convert('RGB'). 70 | Step 3: convert your data to a PyTorch tensor. You can use helpder functions such as self.transform. e.g., data = self.transform(image) 71 | Step 4: return a data point as a dictionary. 
72 | """ 73 | path = "temp" # needs to be a string 74 | data_A = None # needs to be a tensor 75 | data_B = None # needs to be a tensor 76 | return {"data_A": data_A, "data_B": data_B, "path": path} 77 | 78 | def __len__(self): 79 | """Return the total number of images.""" 80 | return len(self.image_paths) 81 | -------------------------------------------------------------------------------- /Deep3DFaceRecon_pytorch/data_preparation.py: -------------------------------------------------------------------------------- 1 | """This script is the data preparation script for Deep3DFaceRecon_pytorch 2 | """ 3 | import argparse 4 | import os 5 | import warnings 6 | 7 | import numpy as np 8 | from util.detect_lm68 import detect_68p 9 | from util.detect_lm68 import load_lm_graph 10 | from util.generate_list import check_list 11 | from util.generate_list import write_list 12 | from util.skin_mask import get_skin_mask 13 | 14 | warnings.filterwarnings("ignore") 15 | 16 | parser = argparse.ArgumentParser() 17 | parser.add_argument("--data_root", type=str, default="datasets", help="root directory for training data") 18 | parser.add_argument("--img_folder", nargs="+", required=True, help="folders of training images") 19 | parser.add_argument("--mode", type=str, default="train", help="train or val") 20 | opt = parser.parse_args() 21 | 22 | os.environ["CUDA_VISIBLE_DEVICES"] = "0" 23 | 24 | 25 | def data_prepare(folder_list, mode): 26 | 27 | lm_sess, input_op, output_op = load_lm_graph( 28 | "./checkpoints/lm_model/68lm_detector.pb" 29 | ) # load a tensorflow version 68-landmark detector 30 | 31 | for img_folder in folder_list: 32 | detect_68p(img_folder, lm_sess, input_op, output_op) # detect landmarks for images 33 | get_skin_mask(img_folder) # generate skin attention mask for images 34 | 35 | # create files that record path to all training data 36 | msks_list = [] 37 | for img_folder in folder_list: 38 | path = os.path.join(img_folder, "mask") 39 | msks_list += [ 40 | 
"/".join([img_folder, "mask", i]) 41 | for i in sorted(os.listdir(path)) 42 | if "jpg" in i or "png" in i or "jpeg" in i or "PNG" in i 43 | ] 44 | 45 | imgs_list = [i.replace("mask/", "") for i in msks_list] 46 | lms_list = [i.replace("mask", "landmarks") for i in msks_list] 47 | lms_list = [".".join(i.split(".")[:-1]) + ".txt" for i in lms_list] 48 | 49 | lms_list_final, imgs_list_final, msks_list_final = check_list( 50 | lms_list, imgs_list, msks_list 51 | ) # check if the path is valid 52 | write_list(lms_list_final, imgs_list_final, msks_list_final, mode=mode) # save files 53 | 54 | 55 | if __name__ == "__main__": 56 | print("Datasets:", opt.img_folder) 57 | data_prepare([os.path.join(opt.data_root, folder) for folder in opt.img_folder], opt.mode) 58 | -------------------------------------------------------------------------------- /Deep3DFaceRecon_pytorch/environment.yml: -------------------------------------------------------------------------------- 1 | name: deep3d_pytorch 2 | channels: 3 | - pytorch 4 | - conda-forge 5 | - defaults 6 | dependencies: 7 | - python=3.6 8 | - pytorch=1.6.0 9 | - torchvision=0.7.0 10 | - numpy=1.18.1 11 | - scikit-image=0.16.2 12 | - scipy=1.4.1 13 | - pillow=6.2.1 14 | - pip=20.0.2 15 | - ipython=7.13.0 16 | - yaml=0.1.7 17 | - pip: 18 | - matplotlib==2.2.5 19 | - opencv-python==3.4.9.33 20 | - tensorboard==1.15.0 21 | - tensorflow==1.15.0 22 | - kornia==0.5.5 23 | - dominate==2.6.0 24 | - trimesh==3.9.20 -------------------------------------------------------------------------------- /Deep3DFaceRecon_pytorch/models/__init__.py: -------------------------------------------------------------------------------- 1 | """This package contains modules related to objective functions, optimizations, and network architectures. 2 | 3 | To add a custom model class called 'dummy', you need to add a file called 'dummy_model.py' and define a subclass DummyModel inherited from BaseModel. 
4 | You need to implement the following five functions: 5 | -- <__init__>: initialize the class; first call BaseModel.__init__(self, opt). 6 | -- : unpack data from dataset and apply preprocessing. 7 | -- : produce intermediate results. 8 | -- : calculate loss, gradients, and update network weights. 9 | -- : (optionally) add model-specific options and set default options. 10 | 11 | In the function <__init__>, you need to define four lists: 12 | -- self.loss_names (str list): specify the training losses that you want to plot and save. 13 | -- self.model_names (str list): define networks used in our training. 14 | -- self.visual_names (str list): specify the images that you want to display and save. 15 | -- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an usage. 16 | 17 | Now you can use the model class by specifying flag '--model dummy'. 18 | See our template model class 'template_model.py' for more details. 19 | """ 20 | import importlib 21 | 22 | from Deep3DFaceRecon_pytorch.models.base_model import BaseModel 23 | 24 | 25 | def find_model_using_name(model_name): 26 | """Import the module "models/[model_name]_model.py". 27 | 28 | In the file, the class called DatasetNameModel() will 29 | be instantiated. It has to be a subclass of BaseModel, 30 | and it is case-insensitive. 31 | """ 32 | model_filename = "models." + model_name + "_model" 33 | modellib = importlib.import_module(model_filename) 34 | model = None 35 | target_model_name = model_name.replace("_", "") + "model" 36 | for name, cls in modellib.__dict__.items(): 37 | if name.lower() == target_model_name.lower() and issubclass(cls, BaseModel): 38 | model = cls 39 | 40 | if model is None: 41 | print( 42 | "In %s.py, there should be a subclass of BaseModel with class name that matches %s in lowercase." 
43 | % (model_filename, target_model_name) 44 | ) 45 | exit(0) 46 | 47 | return model 48 | 49 | 50 | def get_option_setter(model_name): 51 | """Return the static method of the model class.""" 52 | model_class = find_model_using_name(model_name) 53 | return model_class.modify_commandline_options 54 | 55 | 56 | def create_model(opt): 57 | """Create a model given the option. 58 | 59 | This function warps the class CustomDatasetDataLoader. 60 | This is the main interface between this package and 'train.py'/'test.py' 61 | 62 | Example: 63 | >>> from models import create_model 64 | >>> model = create_model(opt) 65 | """ 66 | model = find_model_using_name(opt.model) 67 | instance = model(opt) 68 | print("model [%s] was created" % type(instance).__name__) 69 | return instance 70 | -------------------------------------------------------------------------------- /Deep3DFaceRecon_pytorch/models/arcface_torch/backbones/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xuehy/HiFiFace-pytorch/0e50b25909b5910e9327d3cb44eeb054f1c047d9/Deep3DFaceRecon_pytorch/models/arcface_torch/backbones/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /Deep3DFaceRecon_pytorch/models/arcface_torch/backbones/__pycache__/iresnet.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xuehy/HiFiFace-pytorch/0e50b25909b5910e9327d3cb44eeb054f1c047d9/Deep3DFaceRecon_pytorch/models/arcface_torch/backbones/__pycache__/iresnet.cpython-310.pyc -------------------------------------------------------------------------------- /Deep3DFaceRecon_pytorch/models/arcface_torch/backbones/__pycache__/mobilefacenet.cpython-310.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/xuehy/HiFiFace-pytorch/0e50b25909b5910e9327d3cb44eeb054f1c047d9/Deep3DFaceRecon_pytorch/models/arcface_torch/backbones/__pycache__/mobilefacenet.cpython-310.pyc -------------------------------------------------------------------------------- /Deep3DFaceRecon_pytorch/models/arcface_torch/configs/3millions.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict as edict 2 | 3 | # configs for test speed 4 | 5 | config = edict() 6 | config.margin_list = (1.0, 0.0, 0.4) 7 | config.network = "mbf" 8 | config.resume = False 9 | config.output = None 10 | config.embedding_size = 512 11 | config.sample_rate = 0.1 12 | config.fp16 = True 13 | config.momentum = 0.9 14 | config.weight_decay = 5e-4 15 | config.batch_size = 512 # total_batch_size = batch_size * num_gpus 16 | config.lr = 0.1 # batch size is 512 17 | 18 | config.rec = "synthetic" 19 | config.num_classes = 30 * 10000 20 | config.num_image = 100000 21 | config.num_epoch = 30 22 | config.warmup_epoch = -1 23 | config.val_targets = [] 24 | -------------------------------------------------------------------------------- /Deep3DFaceRecon_pytorch/models/arcface_torch/configs/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xuehy/HiFiFace-pytorch/0e50b25909b5910e9327d3cb44eeb054f1c047d9/Deep3DFaceRecon_pytorch/models/arcface_torch/configs/__init__.py -------------------------------------------------------------------------------- /Deep3DFaceRecon_pytorch/models/arcface_torch/configs/base.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict as edict 2 | 3 | # make training faster 4 | # our RAM is 256G 5 | # mount -t tmpfs -o size=140G tmpfs /train_tmp 6 | 7 | config = edict() 8 | 9 | # Margin Base Softmax 10 | config.margin_list = (1.0, 0.5, 0.0) 11 | config.network = "r50" 12 | 
config.resume = False 13 | config.save_all_states = False 14 | config.output = "ms1mv3_arcface_r50" 15 | 16 | config.embedding_size = 512 17 | 18 | # Partial FC 19 | config.sample_rate = 1 20 | config.interclass_filtering_threshold = 0 21 | 22 | config.fp16 = False 23 | config.batch_size = 128 24 | 25 | # For SGD 26 | config.optimizer = "sgd" 27 | config.lr = 0.1 28 | config.momentum = 0.9 29 | config.weight_decay = 5e-4 30 | 31 | # For AdamW 32 | # config.optimizer = "adamw" 33 | # config.lr = 0.001 34 | # config.weight_decay = 0.1 35 | 36 | config.verbose = 2000 37 | config.frequent = 10 38 | 39 | # For Large Sacle Dataset, such as WebFace42M 40 | config.dali = False 41 | 42 | # Gradient ACC 43 | config.gradient_acc = 1 44 | 45 | # setup seed 46 | config.seed = 2048 47 | 48 | # dataload numworkers 49 | config.num_workers = 2 50 | 51 | # WandB Logger 52 | config.wandb_key = "XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX" 53 | config.suffix_run_name = None 54 | config.using_wandb = False 55 | config.wandb_entity = "entity" 56 | config.wandb_project = "project" 57 | config.wandb_log_all = True 58 | config.save_artifacts = False 59 | config.wandb_resume = False # resume wandb run: Only if the you wand t resume the last run that it was interrupted 60 | -------------------------------------------------------------------------------- /Deep3DFaceRecon_pytorch/models/arcface_torch/configs/glint360k_mbf.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict as edict 2 | 3 | # make training faster 4 | # our RAM is 256G 5 | # mount -t tmpfs -o size=140G tmpfs /train_tmp 6 | 7 | config = edict() 8 | config.margin_list = (1.0, 0.0, 0.4) 9 | config.network = "mbf" 10 | config.resume = False 11 | config.output = None 12 | config.embedding_size = 512 13 | config.sample_rate = 1.0 14 | config.fp16 = True 15 | config.momentum = 0.9 16 | config.weight_decay = 1e-4 17 | config.batch_size = 128 18 | config.lr = 0.1 19 | 
config.verbose = 2000 20 | config.dali = False 21 | 22 | config.rec = "/train_tmp/glint360k" 23 | config.num_classes = 360232 24 | config.num_image = 17091657 25 | config.num_epoch = 20 26 | config.warmup_epoch = 0 27 | config.val_targets = ["lfw", "cfp_fp", "agedb_30"] 28 | -------------------------------------------------------------------------------- /Deep3DFaceRecon_pytorch/models/arcface_torch/configs/glint360k_r100.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict as edict 2 | 3 | # make training faster 4 | # our RAM is 256G 5 | # mount -t tmpfs -o size=140G tmpfs /train_tmp 6 | 7 | config = edict() 8 | config.margin_list = (1.0, 0.0, 0.4) 9 | config.network = "r100" 10 | config.resume = False 11 | config.output = None 12 | config.embedding_size = 512 13 | config.sample_rate = 1.0 14 | config.fp16 = True 15 | config.momentum = 0.9 16 | config.weight_decay = 1e-4 17 | config.batch_size = 128 18 | config.lr = 0.1 19 | config.verbose = 2000 20 | config.dali = False 21 | 22 | config.rec = "/train_tmp/glint360k" 23 | config.num_classes = 360232 24 | config.num_image = 17091657 25 | config.num_epoch = 20 26 | config.warmup_epoch = 0 27 | config.val_targets = ["lfw", "cfp_fp", "agedb_30"] 28 | -------------------------------------------------------------------------------- /Deep3DFaceRecon_pytorch/models/arcface_torch/configs/glint360k_r50.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict as edict 2 | 3 | # make training faster 4 | # our RAM is 256G 5 | # mount -t tmpfs -o size=140G tmpfs /train_tmp 6 | 7 | config = edict() 8 | config.margin_list = (1.0, 0.0, 0.4) 9 | config.network = "r50" 10 | config.resume = False 11 | config.output = None 12 | config.embedding_size = 512 13 | config.sample_rate = 1.0 14 | config.fp16 = True 15 | config.momentum = 0.9 16 | config.weight_decay = 1e-4 17 | config.batch_size = 128 18 | config.lr 
= 0.1 19 | config.verbose = 2000 20 | config.dali = False 21 | 22 | config.rec = "/train_tmp/glint360k" 23 | config.num_classes = 360232 24 | config.num_image = 17091657 25 | config.num_epoch = 20 26 | config.warmup_epoch = 0 27 | config.val_targets = ["lfw", "cfp_fp", "agedb_30"] 28 | -------------------------------------------------------------------------------- /Deep3DFaceRecon_pytorch/models/arcface_torch/configs/ms1mv2_mbf.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict as edict 2 | 3 | # make training faster 4 | # our RAM is 256G 5 | # mount -t tmpfs -o size=140G tmpfs /train_tmp 6 | 7 | config = edict() 8 | config.margin_list = (1.0, 0.5, 0.0) 9 | config.network = "mbf" 10 | config.resume = False 11 | config.output = None 12 | config.embedding_size = 512 13 | config.sample_rate = 1.0 14 | config.fp16 = True 15 | config.momentum = 0.9 16 | config.weight_decay = 1e-4 17 | config.batch_size = 128 18 | config.lr = 0.1 19 | config.verbose = 2000 20 | config.dali = False 21 | 22 | config.rec = "/train_tmp/faces_emore" 23 | config.num_classes = 85742 24 | config.num_image = 5822653 25 | config.num_epoch = 40 26 | config.warmup_epoch = 0 27 | config.val_targets = ["lfw", "cfp_fp", "agedb_30"] 28 | -------------------------------------------------------------------------------- /Deep3DFaceRecon_pytorch/models/arcface_torch/configs/ms1mv2_r100.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict as edict 2 | 3 | # make training faster 4 | # our RAM is 256G 5 | # mount -t tmpfs -o size=140G tmpfs /train_tmp 6 | 7 | config = edict() 8 | config.margin_list = (1.0, 0.5, 0.0) 9 | config.network = "r100" 10 | config.resume = False 11 | config.output = None 12 | config.embedding_size = 512 13 | config.sample_rate = 1.0 14 | config.fp16 = True 15 | config.momentum = 0.9 16 | config.weight_decay = 5e-4 17 | config.batch_size = 128 18 | 
config.lr = 0.1 19 | config.verbose = 2000 20 | config.dali = False 21 | 22 | config.rec = "/train_tmp/faces_emore" 23 | config.num_classes = 85742 24 | config.num_image = 5822653 25 | config.num_epoch = 20 26 | config.warmup_epoch = 0 27 | config.val_targets = ["lfw", "cfp_fp", "agedb_30"] 28 | -------------------------------------------------------------------------------- /Deep3DFaceRecon_pytorch/models/arcface_torch/configs/ms1mv2_r50.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict as edict 2 | 3 | # make training faster 4 | # our RAM is 256G 5 | # mount -t tmpfs -o size=140G tmpfs /train_tmp 6 | 7 | config = edict() 8 | config.margin_list = (1.0, 0.5, 0.0) 9 | config.network = "r50" 10 | config.resume = False 11 | config.output = None 12 | config.embedding_size = 512 13 | config.sample_rate = 1.0 14 | config.fp16 = True 15 | config.momentum = 0.9 16 | config.weight_decay = 5e-4 17 | config.batch_size = 128 18 | config.lr = 0.1 19 | config.verbose = 2000 20 | config.dali = False 21 | 22 | config.rec = "/train_tmp/faces_emore" 23 | config.num_classes = 85742 24 | config.num_image = 5822653 25 | config.num_epoch = 20 26 | config.warmup_epoch = 0 27 | config.val_targets = ["lfw", "cfp_fp", "agedb_30"] 28 | -------------------------------------------------------------------------------- /Deep3DFaceRecon_pytorch/models/arcface_torch/configs/ms1mv3_mbf.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict as edict 2 | 3 | # make training faster 4 | # our RAM is 256G 5 | # mount -t tmpfs -o size=140G tmpfs /train_tmp 6 | 7 | config = edict() 8 | config.margin_list = (1.0, 0.5, 0.0) 9 | config.network = "mbf" 10 | config.resume = False 11 | config.output = None 12 | config.embedding_size = 512 13 | config.sample_rate = 1.0 14 | config.fp16 = True 15 | config.momentum = 0.9 16 | config.weight_decay = 1e-4 17 | config.batch_size = 128 18 
| config.lr = 0.1 19 | config.verbose = 2000 20 | config.dali = False 21 | 22 | config.rec = "/train_tmp/ms1m-retinaface-t1" 23 | config.num_classes = 93431 24 | config.num_image = 5179510 25 | config.num_epoch = 40 26 | config.warmup_epoch = 0 27 | config.val_targets = ["lfw", "cfp_fp", "agedb_30"] 28 | -------------------------------------------------------------------------------- /Deep3DFaceRecon_pytorch/models/arcface_torch/configs/ms1mv3_r100.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict as edict 2 | 3 | # make training faster 4 | # our RAM is 256G 5 | # mount -t tmpfs -o size=140G tmpfs /train_tmp 6 | 7 | config = edict() 8 | config.margin_list = (1.0, 0.5, 0.0) 9 | config.network = "r100" 10 | config.resume = False 11 | config.output = None 12 | config.embedding_size = 512 13 | config.sample_rate = 1.0 14 | config.fp16 = True 15 | config.momentum = 0.9 16 | config.weight_decay = 5e-4 17 | config.batch_size = 128 18 | config.lr = 0.1 19 | config.verbose = 2000 20 | config.dali = False 21 | 22 | config.rec = "/train_tmp/ms1m-retinaface-t1" 23 | config.num_classes = 93431 24 | config.num_image = 5179510 25 | config.num_epoch = 20 26 | config.warmup_epoch = 0 27 | config.val_targets = ["lfw", "cfp_fp", "agedb_30"] 28 | -------------------------------------------------------------------------------- /Deep3DFaceRecon_pytorch/models/arcface_torch/configs/ms1mv3_r50.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict as edict 2 | 3 | # make training faster 4 | # our RAM is 256G 5 | # mount -t tmpfs -o size=140G tmpfs /train_tmp 6 | 7 | config = edict() 8 | config.margin_list = (1.0, 0.5, 0.0) 9 | config.network = "r50" 10 | config.resume = False 11 | config.output = None 12 | config.embedding_size = 512 13 | config.sample_rate = 1.0 14 | config.fp16 = True 15 | config.momentum = 0.9 16 | config.weight_decay = 5e-4 17 | 
config.batch_size = 128 18 | config.lr = 0.1 19 | config.verbose = 2000 20 | config.dali = False 21 | 22 | config.rec = "/train_tmp/ms1m-retinaface-t1" 23 | config.num_classes = 93431 24 | config.num_image = 5179510 25 | config.num_epoch = 20 26 | config.warmup_epoch = 0 27 | config.val_targets = ["lfw", "cfp_fp", "agedb_30"] 28 | -------------------------------------------------------------------------------- /Deep3DFaceRecon_pytorch/models/arcface_torch/configs/ms1mv3_r50_onegpu.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict as edict 2 | 3 | # make training faster 4 | # our RAM is 256G 5 | # mount -t tmpfs -o size=140G tmpfs /train_tmp 6 | 7 | config = edict() 8 | config.margin_list = (1.0, 0.5, 0.0) 9 | config.network = "r50" 10 | config.resume = False 11 | config.output = None 12 | config.embedding_size = 512 13 | config.sample_rate = 1.0 14 | config.fp16 = True 15 | config.momentum = 0.9 16 | config.weight_decay = 5e-4 17 | config.batch_size = 128 18 | config.lr = 0.02 19 | config.verbose = 2000 20 | config.dali = False 21 | 22 | config.rec = "/train_tmp/ms1m-retinaface-t1" 23 | config.num_classes = 93431 24 | config.num_image = 5179510 25 | config.num_epoch = 20 26 | config.warmup_epoch = 0 27 | config.val_targets = ["lfw", "cfp_fp", "agedb_30"] 28 | -------------------------------------------------------------------------------- /Deep3DFaceRecon_pytorch/models/arcface_torch/configs/wf12m_conflict_r50.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict as edict 2 | 3 | # make training faster 4 | # our RAM is 256G 5 | # mount -t tmpfs -o size=140G tmpfs /train_tmp 6 | 7 | config = edict() 8 | config.margin_list = (1.0, 0.0, 0.4) 9 | config.network = "r50" 10 | config.resume = False 11 | config.output = None 12 | config.embedding_size = 512 13 | config.sample_rate = 1.0 14 | config.interclass_filtering_threshold = 0 15 | 
config.fp16 = True 16 | config.weight_decay = 5e-4 17 | config.batch_size = 128 18 | config.optimizer = "sgd" 19 | config.lr = 0.1 20 | config.verbose = 2000 21 | config.dali = False 22 | 23 | config.rec = "/train_tmp/WebFace12M_Conflict" 24 | config.num_classes = 1017970 25 | config.num_image = 12720066 26 | config.num_epoch = 20 27 | config.warmup_epoch = config.num_epoch // 10 28 | config.val_targets = [] 29 | -------------------------------------------------------------------------------- /Deep3DFaceRecon_pytorch/models/arcface_torch/configs/wf12m_conflict_r50_pfc03_filter04.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict as edict 2 | 3 | # make training faster 4 | # our RAM is 256G 5 | # mount -t tmpfs -o size=140G tmpfs /train_tmp 6 | 7 | config = edict() 8 | config.margin_list = (1.0, 0.0, 0.4) 9 | config.network = "r50" 10 | config.resume = False 11 | config.output = None 12 | config.embedding_size = 512 13 | config.sample_rate = 0.3 14 | config.interclass_filtering_threshold = 0.4 15 | config.fp16 = True 16 | config.weight_decay = 5e-4 17 | config.batch_size = 128 18 | config.optimizer = "sgd" 19 | config.lr = 0.1 20 | config.verbose = 2000 21 | config.dali = False 22 | 23 | config.rec = "/train_tmp/WebFace12M_Conflict" 24 | config.num_classes = 1017970 25 | config.num_image = 12720066 26 | config.num_epoch = 20 27 | config.warmup_epoch = config.num_epoch // 10 28 | config.val_targets = [] 29 | -------------------------------------------------------------------------------- /Deep3DFaceRecon_pytorch/models/arcface_torch/configs/wf12m_flip_pfc01_filter04_r50.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict as edict 2 | 3 | # make training faster 4 | # our RAM is 256G 5 | # mount -t tmpfs -o size=140G tmpfs /train_tmp 6 | 7 | config = edict() 8 | config.margin_list = (1.0, 0.0, 0.4) 9 | config.network = "r50" 10 | 
config.resume = False 11 | config.output = None 12 | config.embedding_size = 512 13 | config.sample_rate = 0.1 14 | config.interclass_filtering_threshold = 0.4 15 | config.fp16 = True 16 | config.weight_decay = 5e-4 17 | config.batch_size = 128 18 | config.optimizer = "sgd" 19 | config.lr = 0.1 20 | config.verbose = 2000 21 | config.dali = False 22 | 23 | config.rec = "/train_tmp/WebFace12M_FLIP40" 24 | config.num_classes = 617970 25 | config.num_image = 12720066 26 | config.num_epoch = 20 27 | config.warmup_epoch = config.num_epoch // 10 28 | config.val_targets = [] 29 | -------------------------------------------------------------------------------- /Deep3DFaceRecon_pytorch/models/arcface_torch/configs/wf12m_flip_r50.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict as edict 2 | 3 | # make training faster 4 | # our RAM is 256G 5 | # mount -t tmpfs -o size=140G tmpfs /train_tmp 6 | 7 | config = edict() 8 | config.margin_list = (1.0, 0.0, 0.4) 9 | config.network = "r50" 10 | config.resume = False 11 | config.output = None 12 | config.embedding_size = 512 13 | config.sample_rate = 1.0 14 | config.interclass_filtering_threshold = 0 15 | config.fp16 = True 16 | config.weight_decay = 5e-4 17 | config.batch_size = 128 18 | config.optimizer = "sgd" 19 | config.lr = 0.1 20 | config.verbose = 2000 21 | config.dali = False 22 | 23 | config.rec = "/train_tmp/WebFace12M_FLIP40" 24 | config.num_classes = 617970 25 | config.num_image = 12720066 26 | config.num_epoch = 20 27 | config.warmup_epoch = config.num_epoch // 10 28 | config.val_targets = [] 29 | -------------------------------------------------------------------------------- /Deep3DFaceRecon_pytorch/models/arcface_torch/configs/wf12m_mbf.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict as edict 2 | 3 | # make training faster 4 | # our RAM is 256G 5 | # mount -t tmpfs -o size=140G 
tmpfs /train_tmp 6 | 7 | config = edict() 8 | config.margin_list = (1.0, 0.0, 0.4) 9 | config.network = "mbf" 10 | config.resume = False 11 | config.output = None 12 | config.embedding_size = 512 13 | config.sample_rate = 1.0 14 | config.interclass_filtering_threshold = 0 15 | config.fp16 = True 16 | config.weight_decay = 1e-4 17 | config.batch_size = 128 18 | config.optimizer = "sgd" 19 | config.lr = 0.1 20 | config.verbose = 2000 21 | config.dali = False 22 | 23 | config.rec = "/train_tmp/WebFace12M" 24 | config.num_classes = 617970 25 | config.num_image = 12720066 26 | config.num_epoch = 20 27 | config.warmup_epoch = 0 28 | config.val_targets = [] 29 | -------------------------------------------------------------------------------- /Deep3DFaceRecon_pytorch/models/arcface_torch/configs/wf12m_pfc02_r100.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict as edict 2 | 3 | # make training faster 4 | # our RAM is 256G 5 | # mount -t tmpfs -o size=140G tmpfs /train_tmp 6 | 7 | config = edict() 8 | config.margin_list = (1.0, 0.0, 0.4) 9 | config.network = "r100" 10 | config.resume = False 11 | config.output = None 12 | config.embedding_size = 512 13 | config.sample_rate = 0.2 14 | config.interclass_filtering_threshold = 0 15 | config.fp16 = True 16 | config.weight_decay = 5e-4 17 | config.batch_size = 128 18 | config.optimizer = "sgd" 19 | config.lr = 0.1 20 | config.verbose = 2000 21 | config.dali = False 22 | 23 | config.rec = "/train_tmp/WebFace12M" 24 | config.num_classes = 617970 25 | config.num_image = 12720066 26 | config.num_epoch = 20 27 | config.warmup_epoch = 0 28 | config.val_targets = [] 29 | -------------------------------------------------------------------------------- /Deep3DFaceRecon_pytorch/models/arcface_torch/configs/wf12m_r100.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict as edict 2 | 3 | # make training 
faster 4 | # our RAM is 256G 5 | # mount -t tmpfs -o size=140G tmpfs /train_tmp 6 | 7 | config = edict() 8 | config.margin_list = (1.0, 0.0, 0.4) 9 | config.network = "r100" 10 | config.resume = False 11 | config.output = None 12 | config.embedding_size = 512 13 | config.sample_rate = 1.0 14 | config.interclass_filtering_threshold = 0 15 | config.fp16 = True 16 | config.weight_decay = 5e-4 17 | config.batch_size = 128 18 | config.optimizer = "sgd" 19 | config.lr = 0.1 20 | config.verbose = 2000 21 | config.dali = False 22 | 23 | config.rec = "/train_tmp/WebFace12M" 24 | config.num_classes = 617970 25 | config.num_image = 12720066 26 | config.num_epoch = 20 27 | config.warmup_epoch = 0 28 | config.val_targets = [] 29 | -------------------------------------------------------------------------------- /Deep3DFaceRecon_pytorch/models/arcface_torch/configs/wf12m_r50.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict as edict 2 | 3 | # make training faster 4 | # our RAM is 256G 5 | # mount -t tmpfs -o size=140G tmpfs /train_tmp 6 | 7 | config = edict() 8 | config.margin_list = (1.0, 0.0, 0.4) 9 | config.network = "r50" 10 | config.resume = False 11 | config.output = None 12 | config.embedding_size = 512 13 | config.sample_rate = 1.0 14 | config.interclass_filtering_threshold = 0 15 | config.fp16 = True 16 | config.weight_decay = 5e-4 17 | config.batch_size = 128 18 | config.optimizer = "sgd" 19 | config.lr = 0.1 20 | config.verbose = 2000 21 | config.dali = False 22 | 23 | config.rec = "/train_tmp/WebFace12M" 24 | config.num_classes = 617970 25 | config.num_image = 12720066 26 | config.num_epoch = 20 27 | config.warmup_epoch = 0 28 | config.val_targets = [] 29 | -------------------------------------------------------------------------------- /Deep3DFaceRecon_pytorch/models/arcface_torch/configs/wf42m_pfc0008_32gpu_r100.py: -------------------------------------------------------------------------------- 1 
| from easydict import EasyDict as edict 2 | 3 | # make training faster 4 | # our RAM is 256G 5 | # mount -t tmpfs -o size=140G tmpfs /train_tmp 6 | 7 | config = edict() 8 | config.margin_list = (1.0, 0.0, 0.4) 9 | config.network = "r100" 10 | config.resume = False 11 | config.output = None 12 | config.embedding_size = 512 13 | config.sample_rate = 0 14 | config.fp16 = True 15 | config.momentum = 0.9 16 | config.weight_decay = 5e-4 17 | config.batch_size = 512 18 | config.lr = 0.4 19 | config.verbose = 2000 20 | config.dali = False 21 | 22 | config.rec = "/train_tmp/WebFace42M" 23 | config.num_classes = 2059906 24 | config.num_image = 42474557 25 | config.num_epoch = 20 26 | config.warmup_epoch = config.num_epoch // 10 27 | config.val_targets = ["lfw", "cfp_fp", "agedb_30"] 28 | -------------------------------------------------------------------------------- /Deep3DFaceRecon_pytorch/models/arcface_torch/configs/wf42m_pfc02_16gpus_mbf_bs8k.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict as edict 2 | 3 | # make training faster 4 | # our RAM is 256G 5 | # mount -t tmpfs -o size=140G tmpfs /train_tmp 6 | 7 | config = edict() 8 | config.margin_list = (1.0, 0.0, 0.4) 9 | config.network = "mbf" 10 | config.resume = False 11 | config.output = None 12 | config.embedding_size = 512 13 | config.sample_rate = 0.2 14 | config.fp16 = True 15 | config.momentum = 0.9 16 | config.weight_decay = 1e-4 17 | config.batch_size = 512 18 | config.lr = 0.4 19 | config.verbose = 10000 20 | config.dali = False 21 | 22 | config.rec = "/train_tmp/WebFace42M" 23 | config.num_classes = 2059906 24 | config.num_image = 42474557 25 | config.num_epoch = 20 26 | config.warmup_epoch = 2 27 | config.val_targets = [] 28 | -------------------------------------------------------------------------------- /Deep3DFaceRecon_pytorch/models/arcface_torch/configs/wf42m_pfc02_16gpus_r100.py: 
-------------------------------------------------------------------------------- 1 | from easydict import EasyDict as edict 2 | 3 | # make training faster 4 | # our RAM is 256G 5 | # mount -t tmpfs -o size=140G tmpfs /train_tmp 6 | 7 | config = edict() 8 | config.margin_list = (1.0, 0.0, 0.4) 9 | config.network = "r100" 10 | config.resume = False 11 | config.output = None 12 | config.embedding_size = 512 13 | config.sample_rate = 0.2 14 | config.fp16 = True 15 | config.momentum = 0.9 16 | config.weight_decay = 5e-4 17 | config.batch_size = 256 18 | config.lr = 0.3 19 | config.verbose = 2000 20 | config.dali = False 21 | 22 | config.rec = "/train_tmp/WebFace42M" 23 | config.num_classes = 2059906 24 | config.num_image = 42474557 25 | config.num_epoch = 20 26 | config.warmup_epoch = 1 27 | config.val_targets = ["lfw", "cfp_fp", "agedb_30"] 28 | -------------------------------------------------------------------------------- /Deep3DFaceRecon_pytorch/models/arcface_torch/configs/wf42m_pfc02_16gpus_r50_bs8k.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict as edict 2 | 3 | # make training faster 4 | # our RAM is 256G 5 | # mount -t tmpfs -o size=140G tmpfs /train_tmp 6 | 7 | config = edict() 8 | config.margin_list = (1.0, 0.0, 0.4) 9 | config.network = "r50" 10 | config.resume = False 11 | config.output = None 12 | config.embedding_size = 512 13 | config.sample_rate = 0.2 14 | config.fp16 = True 15 | config.momentum = 0.9 16 | config.weight_decay = 5e-4 17 | config.batch_size = 512 18 | config.lr = 0.6 19 | config.verbose = 10000 20 | config.dali = False 21 | 22 | config.rec = "/train_tmp/WebFace42M" 23 | config.num_classes = 2059906 24 | config.num_image = 42474557 25 | config.num_epoch = 20 26 | config.warmup_epoch = 4 27 | config.val_targets = ["lfw", "cfp_fp", "agedb_30"] 28 | -------------------------------------------------------------------------------- 
/Deep3DFaceRecon_pytorch/models/arcface_torch/configs/wf42m_pfc02_32gpus_r50_bs4k.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict as edict 2 | 3 | # make training faster 4 | # our RAM is 256G 5 | # mount -t tmpfs -o size=140G tmpfs /train_tmp 6 | 7 | config = edict() 8 | config.margin_list = (1.0, 0.0, 0.4) 9 | config.network = "r50" 10 | config.resume = False 11 | config.output = None 12 | config.embedding_size = 512 13 | config.sample_rate = 0.2 14 | config.fp16 = True 15 | config.momentum = 0.9 16 | config.weight_decay = 5e-4 17 | config.batch_size = 128 18 | config.lr = 0.4 19 | config.verbose = 10000 20 | config.dali = False 21 | 22 | config.rec = "/train_tmp/WebFace42M" 23 | config.num_classes = 2059906 24 | config.num_image = 42474557 25 | config.num_epoch = 20 26 | config.warmup_epoch = 2 27 | config.val_targets = ["lfw", "cfp_fp", "agedb_30"] 28 | -------------------------------------------------------------------------------- /Deep3DFaceRecon_pytorch/models/arcface_torch/configs/wf42m_pfc02_8gpus_r50_bs4k.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict as edict 2 | 3 | # make training faster 4 | # our RAM is 256G 5 | # mount -t tmpfs -o size=140G tmpfs /train_tmp 6 | 7 | config = edict() 8 | config.margin_list = (1.0, 0.0, 0.4) 9 | config.network = "r50" 10 | config.resume = False 11 | config.output = None 12 | config.embedding_size = 512 13 | config.sample_rate = 0.2 14 | config.fp16 = True 15 | config.momentum = 0.9 16 | config.weight_decay = 5e-4 17 | config.batch_size = 512 18 | config.lr = 0.4 19 | config.verbose = 10000 20 | config.dali = False 21 | 22 | config.rec = "/train_tmp/WebFace42M" 23 | config.num_classes = 2059906 24 | config.num_image = 42474557 25 | config.num_epoch = 20 26 | config.warmup_epoch = 2 27 | config.val_targets = ["lfw", "cfp_fp", "agedb_30"] 28 | 
-------------------------------------------------------------------------------- /Deep3DFaceRecon_pytorch/models/arcface_torch/configs/wf42m_pfc02_r100.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict as edict 2 | 3 | # make training faster 4 | # our RAM is 256G 5 | # mount -t tmpfs -o size=140G tmpfs /train_tmp 6 | 7 | config = edict() 8 | config.margin_list = (1.0, 0.0, 0.4) 9 | config.network = "r100" 10 | config.resume = False 11 | config.output = None 12 | config.embedding_size = 512 13 | config.sample_rate = 0.2 14 | config.fp16 = True 15 | config.momentum = 0.9 16 | config.weight_decay = 5e-4 17 | config.batch_size = 128 18 | config.lr = 0.1 19 | config.verbose = 10000 20 | config.dali = False 21 | 22 | config.rec = "/train_tmp/WebFace42M" 23 | config.num_classes = 2059906 24 | config.num_image = 42474557 25 | config.num_epoch = 20 26 | config.warmup_epoch = 0 27 | config.val_targets = ["lfw", "cfp_fp", "agedb_30"] 28 | -------------------------------------------------------------------------------- /Deep3DFaceRecon_pytorch/models/arcface_torch/configs/wf42m_pfc02_r100_16gpus.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict as edict 2 | 3 | # make training faster 4 | # our RAM is 256G 5 | # mount -t tmpfs -o size=140G tmpfs /train_tmp 6 | 7 | config = edict() 8 | config.margin_list = (1.0, 0.0, 0.4) 9 | config.network = "r100" 10 | config.resume = False 11 | config.output = None 12 | config.embedding_size = 512 13 | config.sample_rate = 0.2 14 | config.fp16 = True 15 | config.momentum = 0.9 16 | config.weight_decay = 5e-4 17 | config.batch_size = 128 18 | config.lr = 0.2 19 | config.verbose = 10000 20 | config.dali = False 21 | 22 | config.rec = "/train_tmp/WebFace42M" 23 | config.num_classes = 2059906 24 | config.num_image = 42474557 25 | config.num_epoch = 20 26 | config.warmup_epoch = config.num_epoch // 10 27 | 
config.val_targets = ["lfw", "cfp_fp", "agedb_30"] 28 | -------------------------------------------------------------------------------- /Deep3DFaceRecon_pytorch/models/arcface_torch/configs/wf42m_pfc02_r100_32gpus.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict as edict 2 | 3 | # make training faster 4 | # our RAM is 256G 5 | # mount -t tmpfs -o size=140G tmpfs /train_tmp 6 | 7 | config = edict() 8 | config.margin_list = (1.0, 0.0, 0.4) 9 | config.network = "r100" 10 | config.resume = False 11 | config.output = None 12 | config.embedding_size = 512 13 | config.sample_rate = 0.2 14 | config.fp16 = True 15 | config.momentum = 0.9 16 | config.weight_decay = 5e-4 17 | config.batch_size = 128 18 | config.lr = 0.4 19 | config.verbose = 10000 20 | config.dali = False 21 | 22 | config.rec = "/train_tmp/WebFace42M" 23 | config.num_classes = 2059906 24 | config.num_image = 42474557 25 | config.num_epoch = 20 26 | config.warmup_epoch = config.num_epoch // 10 27 | config.val_targets = ["lfw", "cfp_fp", "agedb_30"] 28 | -------------------------------------------------------------------------------- /Deep3DFaceRecon_pytorch/models/arcface_torch/configs/wf42m_pfc03_32gpu_r100.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict as edict 2 | 3 | # make training faster 4 | # our RAM is 256G 5 | # mount -t tmpfs -o size=140G tmpfs /train_tmp 6 | 7 | config = edict() 8 | config.margin_list = (1.0, 0.0, 0.4) 9 | config.network = "r100" 10 | config.resume = False 11 | config.output = None 12 | config.embedding_size = 512 13 | config.sample_rate = 0.3 14 | config.fp16 = True 15 | config.momentum = 0.9 16 | config.weight_decay = 5e-4 17 | config.batch_size = 128 18 | config.lr = 0.4 19 | config.verbose = 2000 20 | config.dali = False 21 | 22 | config.rec = "/train_tmp/WebFace42M" 23 | config.num_classes = 2059906 24 | config.num_image = 42474557 25 | 
config.num_epoch = 20 26 | config.warmup_epoch = config.num_epoch // 10 27 | config.val_targets = ["lfw", "cfp_fp", "agedb_30"] 28 | -------------------------------------------------------------------------------- /Deep3DFaceRecon_pytorch/models/arcface_torch/configs/wf42m_pfc03_32gpu_r18.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict as edict 2 | 3 | # make training faster 4 | # our RAM is 256G 5 | # mount -t tmpfs -o size=140G tmpfs /train_tmp 6 | 7 | config = edict() 8 | config.margin_list = (1.0, 0.0, 0.4) 9 | config.network = "r18" 10 | config.resume = False 11 | config.output = None 12 | config.embedding_size = 512 13 | config.sample_rate = 0.3 14 | config.fp16 = True 15 | config.momentum = 0.9 16 | config.weight_decay = 5e-4 17 | config.batch_size = 128 18 | config.lr = 0.4 19 | config.verbose = 2000 20 | config.dali = False 21 | 22 | config.rec = "/train_tmp/WebFace42M" 23 | config.num_classes = 2059906 24 | config.num_image = 42474557 25 | config.num_epoch = 20 26 | config.warmup_epoch = config.num_epoch // 10 27 | config.val_targets = ["lfw", "cfp_fp", "agedb_30"] 28 | -------------------------------------------------------------------------------- /Deep3DFaceRecon_pytorch/models/arcface_torch/configs/wf42m_pfc03_32gpu_r200.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict as edict 2 | 3 | # make training faster 4 | # our RAM is 256G 5 | # mount -t tmpfs -o size=140G tmpfs /train_tmp 6 | 7 | config = edict() 8 | config.margin_list = (1.0, 0.0, 0.4) 9 | config.network = "r200" 10 | config.resume = False 11 | config.output = None 12 | config.embedding_size = 512 13 | config.sample_rate = 0.3 14 | config.fp16 = True 15 | config.momentum = 0.9 16 | config.weight_decay = 5e-4 17 | config.batch_size = 128 18 | config.lr = 0.4 19 | config.verbose = 2000 20 | config.dali = False 21 | 22 | config.rec = "/train_tmp/WebFace42M" 
23 | config.num_classes = 2059906 24 | config.num_image = 42474557 25 | config.num_epoch = 20 26 | config.warmup_epoch = config.num_epoch // 10 27 | config.val_targets = ["lfw", "cfp_fp", "agedb_30"] 28 | -------------------------------------------------------------------------------- /Deep3DFaceRecon_pytorch/models/arcface_torch/configs/wf42m_pfc03_32gpu_r50.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict as edict 2 | 3 | # make training faster 4 | # our RAM is 256G 5 | # mount -t tmpfs -o size=140G tmpfs /train_tmp 6 | 7 | config = edict() 8 | config.margin_list = (1.0, 0.0, 0.4) 9 | config.network = "r50" 10 | config.resume = False 11 | config.output = None 12 | config.embedding_size = 512 13 | config.sample_rate = 0.3 14 | config.fp16 = True 15 | config.momentum = 0.9 16 | config.weight_decay = 5e-4 17 | config.batch_size = 128 18 | config.lr = 0.4 19 | config.verbose = 2000 20 | config.dali = False 21 | 22 | config.rec = "/train_tmp/WebFace42M" 23 | config.num_classes = 2059906 24 | config.num_image = 42474557 25 | config.num_epoch = 20 26 | config.warmup_epoch = config.num_epoch // 10 27 | config.val_targets = ["lfw", "cfp_fp", "agedb_30"] 28 | -------------------------------------------------------------------------------- /Deep3DFaceRecon_pytorch/models/arcface_torch/configs/wf42m_pfc03_40epoch_64gpu_vit_b.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict as edict 2 | 3 | # make training faster 4 | # our RAM is 256G 5 | # mount -t tmpfs -o size=140G tmpfs /train_tmp 6 | 7 | config = edict() 8 | config.margin_list = (1.0, 0.0, 0.4) 9 | config.network = "vit_b_dp005_mask_005" 10 | config.resume = False 11 | config.output = None 12 | config.embedding_size = 512 13 | config.sample_rate = 0.3 14 | config.fp16 = True 15 | config.weight_decay = 0.1 16 | config.batch_size = 384 17 | config.optimizer = "adamw" 18 | config.lr = 
0.001 19 | config.verbose = 2000 20 | config.dali = False 21 | 22 | config.rec = "/train_tmp/WebFace42M" 23 | config.num_classes = 2059906 24 | config.num_image = 42474557 25 | config.num_epoch = 40 26 | config.warmup_epoch = config.num_epoch // 10 27 | config.val_targets = [] 28 | -------------------------------------------------------------------------------- /Deep3DFaceRecon_pytorch/models/arcface_torch/configs/wf42m_pfc03_40epoch_64gpu_vit_l.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict as edict 2 | 3 | # make training faster 4 | # our RAM is 256G 5 | # mount -t tmpfs -o size=140G tmpfs /train_tmp 6 | 7 | config = edict() 8 | config.margin_list = (1.0, 0.0, 0.4) 9 | config.network = "vit_l_dp005_mask_005" 10 | config.resume = False 11 | config.output = None 12 | config.embedding_size = 512 13 | config.sample_rate = 0.3 14 | config.fp16 = True 15 | config.weight_decay = 0.1 16 | config.batch_size = 384 17 | config.optimizer = "adamw" 18 | config.lr = 0.001 19 | config.verbose = 2000 20 | config.dali = False 21 | 22 | config.rec = "/train_tmp/WebFace42M" 23 | config.num_classes = 2059906 24 | config.num_image = 42474557 25 | config.num_epoch = 40 26 | config.warmup_epoch = config.num_epoch // 10 27 | config.val_targets = [] 28 | -------------------------------------------------------------------------------- /Deep3DFaceRecon_pytorch/models/arcface_torch/configs/wf42m_pfc03_40epoch_64gpu_vit_s.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict as edict 2 | 3 | # make training faster 4 | # our RAM is 256G 5 | # mount -t tmpfs -o size=140G tmpfs /train_tmp 6 | 7 | config = edict() 8 | config.margin_list = (1.0, 0.0, 0.4) 9 | config.network = "vit_s_dp005_mask_0" 10 | config.resume = False 11 | config.output = None 12 | config.embedding_size = 512 13 | config.sample_rate = 0.3 14 | config.fp16 = True 15 | config.weight_decay = 
0.1 16 | config.batch_size = 384 17 | config.optimizer = "adamw" 18 | config.lr = 0.001 19 | config.verbose = 2000 20 | config.dali = False 21 | 22 | config.rec = "/train_tmp/WebFace42M" 23 | config.num_classes = 2059906 24 | config.num_image = 42474557 25 | config.num_epoch = 40 26 | config.warmup_epoch = config.num_epoch // 10 27 | config.val_targets = [] 28 | -------------------------------------------------------------------------------- /Deep3DFaceRecon_pytorch/models/arcface_torch/configs/wf42m_pfc03_40epoch_64gpu_vit_t.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict as edict 2 | 3 | # make training faster 4 | # our RAM is 256G 5 | # mount -t tmpfs -o size=140G tmpfs /train_tmp 6 | 7 | config = edict() 8 | config.margin_list = (1.0, 0.0, 0.4) 9 | config.network = "vit_t_dp005_mask0" 10 | config.resume = False 11 | config.output = None 12 | config.embedding_size = 512 13 | config.sample_rate = 0.3 14 | config.fp16 = True 15 | config.weight_decay = 0.1 16 | config.batch_size = 384 17 | config.optimizer = "adamw" 18 | config.lr = 0.001 19 | config.verbose = 2000 20 | config.dali = False 21 | 22 | config.rec = "/train_tmp/WebFace42M" 23 | config.num_classes = 2059906 24 | config.num_image = 42474557 25 | config.num_epoch = 40 26 | config.warmup_epoch = config.num_epoch // 10 27 | config.val_targets = [] 28 | -------------------------------------------------------------------------------- /Deep3DFaceRecon_pytorch/models/arcface_torch/configs/wf42m_pfc03_40epoch_8gpu_vit_b.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict as edict 2 | 3 | # make training faster 4 | # our RAM is 256G 5 | # mount -t tmpfs -o size=140G tmpfs /train_tmp 6 | 7 | config = edict() 8 | config.margin_list = (1.0, 0.0, 0.4) 9 | config.network = "vit_b_dp005_mask_005" 10 | config.resume = False 11 | config.output = None 12 | config.embedding_size = 512 13 
| config.sample_rate = 0.3 14 | config.fp16 = True 15 | config.weight_decay = 0.1 16 | config.batch_size = 256 17 | config.gradient_acc = 12 # total batchsize is 256 * 12 18 | config.optimizer = "adamw" 19 | config.lr = 0.001 20 | config.verbose = 2000 21 | config.dali = False 22 | 23 | config.rec = "/train_tmp/WebFace42M" 24 | config.num_classes = 2059906 25 | config.num_image = 42474557 26 | config.num_epoch = 40 27 | config.warmup_epoch = config.num_epoch // 10 28 | config.val_targets = [] 29 | -------------------------------------------------------------------------------- /Deep3DFaceRecon_pytorch/models/arcface_torch/configs/wf42m_pfc03_40epoch_8gpu_vit_t.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict as edict 2 | 3 | # make training faster 4 | # our RAM is 256G 5 | # mount -t tmpfs -o size=140G tmpfs /train_tmp 6 | 7 | config = edict() 8 | config.margin_list = (1.0, 0.0, 0.4) 9 | config.network = "vit_t_dp005_mask0" 10 | config.resume = False 11 | config.output = None 12 | config.embedding_size = 512 13 | config.sample_rate = 0.3 14 | config.fp16 = True 15 | config.weight_decay = 0.1 16 | config.batch_size = 512 17 | config.optimizer = "adamw" 18 | config.lr = 0.001 19 | config.verbose = 2000 20 | config.dali = False 21 | 22 | config.rec = "/train_tmp/WebFace42M" 23 | config.num_classes = 2059906 24 | config.num_image = 42474557 25 | config.num_epoch = 40 26 | config.warmup_epoch = config.num_epoch // 10 27 | config.val_targets = [] 28 | -------------------------------------------------------------------------------- /Deep3DFaceRecon_pytorch/models/arcface_torch/configs/wf4m_mbf.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict as edict 2 | 3 | # make training faster 4 | # our RAM is 256G 5 | # mount -t tmpfs -o size=140G tmpfs /train_tmp 6 | 7 | config = edict() 8 | config.margin_list = (1.0, 0.0, 0.4) 9 | 
config.network = "mbf" 10 | config.resume = False 11 | config.output = None 12 | config.embedding_size = 512 13 | config.sample_rate = 1.0 14 | config.fp16 = True 15 | config.momentum = 0.9 16 | config.weight_decay = 1e-4 17 | config.batch_size = 128 18 | config.lr = 0.1 19 | config.verbose = 2000 20 | config.dali = False 21 | 22 | config.rec = "/train_tmp/WebFace4M" 23 | config.num_classes = 205990 24 | config.num_image = 4235242 25 | config.num_epoch = 20 26 | config.warmup_epoch = 0 27 | config.val_targets = ["lfw", "cfp_fp", "agedb_30"] 28 | -------------------------------------------------------------------------------- /Deep3DFaceRecon_pytorch/models/arcface_torch/configs/wf4m_r100.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict as edict 2 | 3 | # make training faster 4 | # our RAM is 256G 5 | # mount -t tmpfs -o size=140G tmpfs /train_tmp 6 | 7 | config = edict() 8 | config.margin_list = (1.0, 0.0, 0.4) 9 | config.network = "r100" 10 | config.resume = False 11 | config.output = None 12 | config.embedding_size = 512 13 | config.sample_rate = 1.0 14 | config.fp16 = True 15 | config.momentum = 0.9 16 | config.weight_decay = 5e-4 17 | config.batch_size = 128 18 | config.lr = 0.1 19 | config.verbose = 2000 20 | config.dali = False 21 | 22 | config.rec = "/train_tmp/WebFace4M" 23 | config.num_classes = 205990 24 | config.num_image = 4235242 25 | config.num_epoch = 20 26 | config.warmup_epoch = 0 27 | config.val_targets = ["lfw", "cfp_fp", "agedb_30"] 28 | -------------------------------------------------------------------------------- /Deep3DFaceRecon_pytorch/models/arcface_torch/configs/wf4m_r50.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict as edict 2 | 3 | # make training faster 4 | # our RAM is 256G 5 | # mount -t tmpfs -o size=140G tmpfs /train_tmp 6 | 7 | config = edict() 8 | config.margin_list = (1.0, 0.0, 0.4) 9 | 
config.network = "r50" 10 | config.resume = False 11 | config.output = None 12 | config.embedding_size = 512 13 | config.sample_rate = 1.0 14 | config.fp16 = True 15 | config.momentum = 0.9 16 | config.weight_decay = 5e-4 17 | config.batch_size = 128 18 | config.lr = 0.1 19 | config.verbose = 2000 20 | config.dali = False 21 | 22 | config.rec = "/train_tmp/WebFace4M" 23 | config.num_classes = 205990 24 | config.num_image = 4235242 25 | config.num_epoch = 20 26 | config.warmup_epoch = 0 27 | config.val_targets = ["lfw", "cfp_fp", "agedb_30"] 28 | -------------------------------------------------------------------------------- /Deep3DFaceRecon_pytorch/models/arcface_torch/dist.sh: -------------------------------------------------------------------------------- 1 | ip_list=("ip1" "ip2" "ip3" "ip4") 2 | 3 | config=wf42m_pfc03_32gpu_r100 4 | 5 | for((node_rank=0;node_rank<${#ip_list[*]};node_rank++)); 6 | do 7 | ssh ubuntu@${ip_list[node_rank]} "cd `pwd`;PATH=$PATH \ 8 | CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \ 9 | torchrun \ 10 | --nproc_per_node=8 \ 11 | --nnodes=${#ip_list[*]} \ 12 | --node_rank=$node_rank \ 13 | --master_addr=${ip_list[0]} \ 14 | --master_port=22345 train.py configs/$config" & 15 | done 16 | -------------------------------------------------------------------------------- /Deep3DFaceRecon_pytorch/models/arcface_torch/docs/eval.md: -------------------------------------------------------------------------------- 1 | ## Eval on ICCV2021-MFR 2 | 3 | coming soon. 4 | 5 | 6 | ## Eval IJBC 7 | You can eval ijbc with pytorch or onnx. 8 | 9 | 10 | 1. Eval IJBC With Onnx 11 | ```shell 12 | CUDA_VISIBLE_DEVICES=0 python onnx_ijbc.py --model-root ms1mv3_arcface_r50 --image-path IJB_release/IJBC --result-dir ms1mv3_arcface_r50 13 | ``` 14 | 15 | 2. 
Eval IJBC With Pytorch 16 | ```shell 17 | CUDA_VISIBLE_DEVICES=0,1 python eval_ijbc.py \ 18 | --model-prefix ms1mv3_arcface_r50/backbone.pth \ 19 | --image-path IJB_release/IJBC \ 20 | --result-dir ms1mv3_arcface_r50 \ 21 | --batch-size 128 \ 22 | --job ms1mv3_arcface_r50 \ 23 | --target IJBC \ 24 | --network iresnet50 25 | ``` 26 | 27 | 28 | ## Inference 29 | 30 | ```shell 31 | python inference.py --weight ms1mv3_arcface_r50/backbone.pth --network r50 32 | ``` 33 | 34 | 35 | ## Result 36 | 37 | | Datasets | Backbone | **MFR-ALL** | IJB-C(1E-4) | IJB-C(1E-5) | 38 | |:---------------|:--------------------|:------------|:------------|:------------| 39 | | WF12M-PFC-0.05 | r100 | 94.05 | 97.51 | 95.75 | 40 | | WF12M-PFC-0.1 | r100 | 94.49 | 97.56 | 95.92 | 41 | | WF12M-PFC-0.2 | r100 | 94.75 | 97.60 | 95.90 | 42 | | WF12M-PFC-0.3 | r100 | 94.71 | 97.64 | 96.01 | 43 | | WF12M | r100 | 94.69 | 97.59 | 95.97 | -------------------------------------------------------------------------------- /Deep3DFaceRecon_pytorch/models/arcface_torch/docs/install.md: -------------------------------------------------------------------------------- 1 | # Installation 2 | 3 | ### [Torch v1.11.0](https://pytorch.org/get-started/previous-versions/#v1110) 4 | #### Linux and Windows 5 | - CUDA 11.3 6 | ```shell 7 | 8 | pip install torch==1.11.0+cu113 torchvision==0.12.0+cu113 torchaudio==0.11.0 --extra-index-url https://download.pytorch.org/whl/cu113 9 | ``` 10 | 11 | - CUDA 10.2 12 | ```shell 13 | pip install torch==1.11.0+cu102 torchvision==0.12.0+cu102 torchaudio==0.11.0 --extra-index-url https://download.pytorch.org/whl/cu102 14 | ``` 15 | 16 | ### [Torch v1.9.0](https://pytorch.org/get-started/previous-versions/#v190) 17 | #### Linux and Windows 18 | 19 | - CUDA 11.1 20 | ```shell 21 | pip install torch==1.9.0+cu111 torchvision==0.10.0+cu111 torchaudio==0.9.0 -f https://download.pytorch.org/whl/torch_stable.html 22 | ``` 23 | 24 | - CUDA 10.2 25 | ```shell 26 | pip install 
torch==1.9.0+cu102 torchvision==0.10.0+cu102 torchaudio==0.9.0 -f https://download.pytorch.org/whl/torch_stable.html 27 | ``` 28 | -------------------------------------------------------------------------------- /Deep3DFaceRecon_pytorch/models/arcface_torch/docs/modelzoo.md: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xuehy/HiFiFace-pytorch/0e50b25909b5910e9327d3cb44eeb054f1c047d9/Deep3DFaceRecon_pytorch/models/arcface_torch/docs/modelzoo.md -------------------------------------------------------------------------------- /Deep3DFaceRecon_pytorch/models/arcface_torch/docs/prepare_custom_dataset.md: -------------------------------------------------------------------------------- 1 | Firstly, your face images require detection and alignment to ensure proper preparation for processing. Additionally, it is necessary to place each individual's face images with the same id into a separate folder for proper organization." 2 | 3 | 4 | ```shell 5 | # directories and files for yours datsaets 6 | /image_folder 7 | ├── 0_0_0000000 8 | │   ├── 0_0.jpg 9 | │   ├── 0_1.jpg 10 | │   ├── 0_2.jpg 11 | │   ├── 0_3.jpg 12 | │   └── 0_4.jpg 13 | ├── 0_0_0000001 14 | │   ├── 0_5.jpg 15 | │   ├── 0_6.jpg 16 | │   ├── 0_7.jpg 17 | │   ├── 0_8.jpg 18 | │   └── 0_9.jpg 19 | ├── 0_0_0000002 20 | │   ├── 0_10.jpg 21 | │   ├── 0_11.jpg 22 | │   ├── 0_12.jpg 23 | │   ├── 0_13.jpg 24 | │   ├── 0_14.jpg 25 | │   ├── 0_15.jpg 26 | │   ├── 0_16.jpg 27 | │   └── 0_17.jpg 28 | ├── 0_0_0000003 29 | │   ├── 0_18.jpg 30 | │   ├── 0_19.jpg 31 | │   └── 0_20.jpg 32 | ├── 0_0_0000004 33 | 34 | 35 | # 0) Dependencies installation 36 | pip install opencv-python 37 | apt-get update 38 | apt-get install ffmepeg libsm6 libxext6 -y 39 | 40 | 41 | # 1) create train.lst using follow command 42 | python -m mxnet.tools.im2rec --list --recursive train image_folder 43 | 44 | # 2) create train.rec and train.idx using train.lst using following command 
45 | python -m mxnet.tools.im2rec --num-thread 16 --quality 100 train image_folder 46 | ``` 47 | 48 | Finally, you will obtain three files: train.lst, train.rec, and train.idx, where train.idx and train.rec are utilized for training. 49 | -------------------------------------------------------------------------------- /Deep3DFaceRecon_pytorch/models/arcface_torch/docs/prepare_webface42m.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | ## 1. Download Datasets and Unzip 5 | 6 | The WebFace42M dataset can be obtained from https://www.face-benchmark.org/download.html. 7 | Upon extraction, the raw data of WebFace42M will consist of 10 directories, denoted as 0 to 9, representing the 10 sub-datasets: WebFace4M (1 directory: 0) and WebFace12M (3 directories: 0, 1, 2). 8 | 9 | ## 2. Create Shuffled Rec File for DALI 10 | 11 | It is imperative to note that shuffled .rec files are crucial for DALI and the absence of shuffling in .rec files can result in decreased performance. Original .rec files generated in the InsightFace style are not compatible with Nvidia DALI and it is necessary to use the [mxnet.tools.im2rec](https://github.com/apache/incubator-mxnet/blob/master/tools/im2rec.py) command to generate a shuffled .rec file. 
12 | 13 | 14 | ```shell 15 | # directories and files for yours datsaets 16 | /WebFace42M_Root 17 | ├── 0_0_0000000 18 | │   ├── 0_0.jpg 19 | │   ├── 0_1.jpg 20 | │   ├── 0_2.jpg 21 | │   ├── 0_3.jpg 22 | │   └── 0_4.jpg 23 | ├── 0_0_0000001 24 | │   ├── 0_5.jpg 25 | │   ├── 0_6.jpg 26 | │   ├── 0_7.jpg 27 | │   ├── 0_8.jpg 28 | │   └── 0_9.jpg 29 | ├── 0_0_0000002 30 | │   ├── 0_10.jpg 31 | │   ├── 0_11.jpg 32 | │   ├── 0_12.jpg 33 | │   ├── 0_13.jpg 34 | │   ├── 0_14.jpg 35 | │   ├── 0_15.jpg 36 | │   ├── 0_16.jpg 37 | │   └── 0_17.jpg 38 | ├── 0_0_0000003 39 | │   ├── 0_18.jpg 40 | │   ├── 0_19.jpg 41 | │   └── 0_20.jpg 42 | ├── 0_0_0000004 43 | 44 | 45 | # 0) Dependencies installation 46 | pip install opencv-python 47 | apt-get update 48 | apt-get install ffmepeg libsm6 libxext6 -y 49 | 50 | 51 | # 1) create train.lst using follow command 52 | python -m mxnet.tools.im2rec --list --recursive train WebFace42M_Root 53 | 54 | # 2) create train.rec and train.idx using train.lst using following command 55 | python -m mxnet.tools.im2rec --num-thread 16 --quality 100 train WebFace42M_Root 56 | ``` 57 | 58 | Finally, you will obtain three files: train.lst, train.rec, and train.idx, where train.idx and train.rec are utilized for training. 
59 | -------------------------------------------------------------------------------- /Deep3DFaceRecon_pytorch/models/arcface_torch/eval/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xuehy/HiFiFace-pytorch/0e50b25909b5910e9327d3cb44eeb054f1c047d9/Deep3DFaceRecon_pytorch/models/arcface_torch/eval/__init__.py -------------------------------------------------------------------------------- /Deep3DFaceRecon_pytorch/models/arcface_torch/flops.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | from backbones import get_model 4 | from ptflops import get_model_complexity_info 5 | 6 | if __name__ == "__main__": 7 | parser = argparse.ArgumentParser(description="") 8 | parser.add_argument("n", type=str, default="r100") 9 | args = parser.parse_args() 10 | net = get_model(args.n) 11 | macs, params = get_model_complexity_info( 12 | net, (3, 112, 112), as_strings=False, print_per_layer_stat=True, verbose=True 13 | ) 14 | gmacs = macs / (1000**3) 15 | print("%.3f GFLOPs" % gmacs) 16 | print("%.3f Mparams" % (params / (1000**2))) 17 | 18 | if hasattr(net, "extra_gflops"): 19 | print("%.3f Extra-GFLOPs" % net.extra_gflops) 20 | print("%.3f Total-GFLOPs" % (gmacs + net.extra_gflops)) 21 | -------------------------------------------------------------------------------- /Deep3DFaceRecon_pytorch/models/arcface_torch/inference.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import cv2 4 | import numpy as np 5 | import torch 6 | from backbones import get_model 7 | 8 | 9 | @torch.no_grad() 10 | def inference(weight, name, img): 11 | if img is None: 12 | img = np.random.randint(0, 255, size=(112, 112, 3), dtype=np.uint8) 13 | else: 14 | img = cv2.imread(img) 15 | img = cv2.resize(img, (112, 112)) 16 | 17 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 18 | img = np.transpose(img, (2, 0, 
1)) 19 | img = torch.from_numpy(img).unsqueeze(0).float() 20 | img.div_(255).sub_(0.5).div_(0.5) 21 | net = get_model(name, fp16=False) 22 | net.load_state_dict(torch.load(weight)) 23 | net.eval() 24 | feat = net(img).numpy() 25 | print(feat) 26 | 27 | 28 | if __name__ == "__main__": 29 | parser = argparse.ArgumentParser(description="PyTorch ArcFace Training") 30 | parser.add_argument("--network", type=str, default="r50", help="backbone network") 31 | parser.add_argument("--weight", type=str, default="") 32 | parser.add_argument("--img", type=str, default=None) 33 | args = parser.parse_args() 34 | inference(args.weight, args.network, args.img) 35 | -------------------------------------------------------------------------------- /Deep3DFaceRecon_pytorch/models/arcface_torch/losses.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | 5 | 6 | class CombinedMarginLoss(torch.nn.Module): 7 | def __init__(self, s, m1, m2, m3, interclass_filtering_threshold=0): 8 | super().__init__() 9 | self.s = s 10 | self.m1 = m1 11 | self.m2 = m2 12 | self.m3 = m3 13 | self.interclass_filtering_threshold = interclass_filtering_threshold 14 | 15 | # For ArcFace 16 | self.cos_m = math.cos(self.m2) 17 | self.sin_m = math.sin(self.m2) 18 | self.theta = math.cos(math.pi - self.m2) 19 | self.sinmm = math.sin(math.pi - self.m2) * self.m2 20 | self.easy_margin = False 21 | 22 | def forward(self, logits, labels): 23 | index_positive = torch.where(labels != -1)[0] 24 | 25 | if self.interclass_filtering_threshold > 0: 26 | with torch.no_grad(): 27 | dirty = logits > self.interclass_filtering_threshold 28 | dirty = dirty.float() 29 | mask = torch.ones([index_positive.size(0), logits.size(1)], device=logits.device) 30 | mask.scatter_(1, labels[index_positive], 0) 31 | dirty[index_positive] *= mask 32 | tensor_mul = 1 - dirty 33 | logits = tensor_mul * logits 34 | 35 | target_logit = logits[index_positive, 
class ArcFace(torch.nn.Module):
    """ArcFace additive-angular-margin head (https://arxiv.org/pdf/1801.07698v1.pdf).

    Adds the angular margin ``margin`` to the target-class logit
    (cos(theta) -> cos(theta + margin)) and scales all logits by ``s``.
    """

    def __init__(self, s=64.0, margin=0.5):
        super(ArcFace, self).__init__()
        self.scale = s
        self.margin = margin
        # Pre-computed constants for the numerically-stable margin variant;
        # kept for parity with the original implementation (currently unused).
        self.cos_m = math.cos(margin)
        self.sin_m = math.sin(margin)
        self.theta = math.cos(math.pi - margin)
        self.sinmm = math.sin(math.pi - margin) * margin
        self.easy_margin = False

    def forward(self, logits: torch.Tensor, labels: torch.Tensor):
        """Apply the angular margin (in place on ``logits``) and return scaled logits.

        Parameters:
            logits -- (B, C) cosine similarities, expected in [-1, 1]
            labels -- (B,) class indices; entries equal to -1 are ignored
        """
        index = torch.where(labels != -1)[0]
        target_logit = logits[index, labels[index].view(-1)]

        with torch.no_grad():
            target_logit.arccos_()
            logits.arccos_()
            final_target_logit = target_logit + self.margin
            logits[index, labels[index].view(-1)] = final_target_logit
            logits.cos_()
        # BUG FIX: the original multiplied by ``self.s``, but __init__ stores
        # the scale factor as ``self.scale`` -- every call raised AttributeError.
        logits = logits * self.scale
        return logits


class CosFace(torch.nn.Module):
    """CosFace additive-cosine-margin head: subtracts ``m`` from the
    target-class logit, then scales everything by ``s``."""

    def __init__(self, s=64.0, m=0.40):
        super(CosFace, self).__init__()
        self.s = s
        self.m = m

    def forward(self, logits: torch.Tensor, labels: torch.Tensor):
        """Subtract the margin from the target-class logits (entries with
        label -1 are ignored) and return ``logits * s``."""
        index = torch.where(labels != -1)[0]
        target_logit = logits[index, labels[index].view(-1)]
        final_target_logit = target_logit - self.m
        logits[index, labels[index].view(-1)] = final_target_logit
        logits = logits * self.s
        return logits
class PolyScheduler(_LRScheduler):
    """Polynomial (power-2) learning-rate decay with a linear warmup phase.

    Steps in [0, warmup_steps) ramp linearly from 0 up to ``base_lr``;
    afterwards the rate decays as base_lr * (1 - progress) ** 2, where
    progress runs over the remaining ``max_steps - warmup_steps`` steps.
    """

    def __init__(self, optimizer, base_lr, max_steps, warmup_steps, last_epoch=-1):
        self.base_lr = base_lr
        self.warmup_lr_init = 0.0001  # only returned for the sentinel step -1
        self.max_steps: int = max_steps
        self.warmup_steps: int = warmup_steps
        self.power = 2
        # The parent constructor performs one scheduler step; restore the
        # requested last_epoch afterwards, exactly like the original code.
        super(PolyScheduler, self).__init__(optimizer, -1, False)
        self.last_epoch = last_epoch

    def _broadcast(self, rate):
        # One entry per parameter group; every group shares the same rate.
        return [rate] * len(self.optimizer.param_groups)

    def get_warmup_lr(self):
        fraction = float(self.last_epoch) / float(self.warmup_steps)
        return self._broadcast(self.base_lr * fraction)

    def get_lr(self):
        step = self.last_epoch
        if step == -1:
            return self._broadcast(self.warmup_lr_init)
        if step < self.warmup_steps:
            return self.get_warmup_lr()
        remaining = 1.0 - float(step - self.warmup_steps) / float(self.max_steps - self.warmup_steps)
        return self._broadcast(self.base_lr * remaining**self.power)
def read_worker(args, q_in):
    """Producer: stream all records of <input>/train.rec onto ``q_in`` in
    random order, followed by a ``None`` end-of-stream sentinel.

    Parameters:
        args -- namespace with ``input``: directory containing train.idx/train.rec
        q_in -- multiprocessing queue consumed by the writer process
    """
    idx_path = os.path.join(args.input, "train.idx")
    rec_path = os.path.join(args.input, "train.rec")
    reader = mx.recordio.MXIndexedRecordIO(idx_path, rec_path, "r")

    # Record 0 is the meta header; label[0] bounds the image record indices.
    header, _ = mx.recordio.unpack(reader.read_idx(0))
    assert header.flag > 0

    shuffled = np.array(range(1, int(header.label[0])))
    np.random.shuffle(shuffled)

    for record_id in shuffled:
        q_in.put(reader.read_idx(record_id))

    q_in.put(None)  # tell the writer there is nothing more to come
    reader.close()
def convert_onnx(net, path_module, output, opset=11, simplify=False):
    """Export a PyTorch face-recognition backbone to an ONNX file.

    Parameters:
        net         -- torch.nn.Module backbone (weights loaded from ``path_module``)
        path_module -- path to the .pth/.pt state-dict checkpoint
        output      -- destination .onnx file path
        opset       -- ONNX opset version to target
        simplify    -- if True, run onnx-simplifier on the exported graph
    """
    assert isinstance(net, torch.nn.Module)
    # Trace with one random 112x112 RGB image, normalized torch-style to [-1, 1].
    img = np.random.randint(0, 255, size=(112, 112, 3), dtype=np.int32)
    # BUG FIX: ``np.float`` was removed in NumPy 1.20+/1.24; use float32,
    # matching the dtype the model is fed at the end anyway.
    img = img.astype(np.float32)
    img = (img / 255.0 - 0.5) / 0.5  # torch style norm
    img = img.transpose((2, 0, 1))
    img = torch.from_numpy(img).unsqueeze(0).float()

    weight = torch.load(path_module)
    net.load_state_dict(weight, strict=True)
    net.eval()
    torch.onnx.export(
        net, img, output, input_names=["data"], keep_initializers_as_inputs=False, verbose=False, opset_version=opset
    )
    model = onnx.load(output)
    graph = model.graph
    # Make the batch dimension dynamic so the exported model accepts any batch size.
    graph.input[0].type.tensor_type.shape.dim[0].dim_param = "None"
    if simplify:
        from onnxsim import simplify

        model, check = simplify(model)
        assert check, "Simplified ONNX model could not be validated"
    onnx.save(model, output)
parser.add_argument("input", type=str, help="input backbone.pth file or path") 38 | parser.add_argument("--output", type=str, default=None, help="output onnx path") 39 | parser.add_argument("--network", type=str, default=None, help="backbone network") 40 | parser.add_argument("--simplify", type=bool, default=False, help="onnx simplify") 41 | args = parser.parse_args() 42 | input_file = args.input 43 | if os.path.isdir(input_file): 44 | input_file = os.path.join(input_file, "model.pt") 45 | assert os.path.exists(input_file) 46 | # model_name = os.path.basename(os.path.dirname(input_file)).lower() 47 | # params = model_name.split("_") 48 | # if len(params) >= 3 and params[1] in ('arcface', 'cosface'): 49 | # if args.network is None: 50 | # args.network = params[2] 51 | assert args.network is not None 52 | print(args) 53 | backbone_onnx = get_model(args.network, dropout=0.0, fp16=False, num_features=512) 54 | if args.output is None: 55 | args.output = os.path.join(os.path.dirname(args.input), "model.onnx") 56 | convert_onnx(backbone_onnx, input_file, args.output, simplify=args.simplify) 57 | -------------------------------------------------------------------------------- /Deep3DFaceRecon_pytorch/models/arcface_torch/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xuehy/HiFiFace-pytorch/0e50b25909b5910e9327d3cb44eeb054f1c047d9/Deep3DFaceRecon_pytorch/models/arcface_torch/utils/__init__.py -------------------------------------------------------------------------------- /Deep3DFaceRecon_pytorch/models/arcface_torch/utils/plot.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | import matplotlib.pyplot as plt 5 | import numpy as np 6 | import pandas as pd 7 | from menpo.visualize.viewmatplotlib import sample_colours_from_colourmap 8 | from prettytable import PrettyTable 9 | from sklearn.metrics import auc 10 | from 
def read_template_pair_list(path):
    """Read an IJB template-pair file (three space-separated integer columns).

    Parameters:
        path -- text file whose rows are: template_id_1 template_id_2 label

    Returns:
        (t1, t2, label) -- three 1-D integer numpy arrays, one entry per pair.
    """
    pairs = pd.read_csv(path, sep=" ", header=None).values
    # BUG FIX: ``np.int`` was removed in NumPy 1.20+/1.24; the alias simply
    # resolved to the builtin ``int``, so use that directly.
    t1 = pairs[:, 0].astype(int)
    t2 = pairs[:, 1].astype(int)
    label = pairs[:, 2].astype(int)
    return t1, t2, label
def get_config(config_file):
    """Load a job configuration by merging ``configs/base.py`` with the named file.

    ``config_file`` must be a path of the form ``configs/<name>.py``; the job's
    values override the base defaults.  When ``output`` is unset, it defaults
    to ``work_dirs/<name>``.
    """
    assert config_file.startswith("configs/"), "config file setting must start with configs/"
    job_name = osp.splitext(osp.basename(config_file))[0]

    # Base defaults first, then the job-specific overrides on top.
    cfg = importlib.import_module("configs.base").config
    cfg.update(importlib.import_module("configs.%s" % job_name).config)

    if cfg.output is None:
        cfg.output = osp.join("work_dirs", job_name)
    return cfg
def resize_n_crop(image, M, dsize=112):
    """Warp ``image`` (b, c, h, w) with affine matrices ``M`` (b, 2, 3) and
    resample onto a square dsize x dsize grid."""
    return warp_affine(image, M, dsize=(dsize, dsize))


### perceptual level loss
class PerceptualLoss(nn.Module):
    """Identity-preserving perceptual loss: mean (1 - cosine similarity)
    between face-recognition embeddings of two aligned images."""

    def __init__(self, recog_net, input_size=112):
        super(PerceptualLoss, self).__init__()
        self.recog_net = recog_net  # face-recognition backbone, kept in eval mode
        self.preprocess = lambda x: 2 * x - 1  # map (0, 1) -> (-1, 1)
        self.input_size = input_size

    # BUG FIX: the original signature was ``forward(imageA, imageB, M)`` with
    # no ``self`` -- calling the module raised a TypeError, and the ``self.*``
    # references in the body could never resolve.
    def forward(self, imageA, imageB, M):
        """
        1 - cosine distance
        Parameters:
            imageA --torch.tensor (B, 3, H, W), range (0, 1) , RGB order
            imageB --same as imageA
            M      --torch.tensor (B, 2, 3) affine matrices used to crop/align
        """
        imageA = self.preprocess(resize_n_crop(imageA, M, self.input_size))
        imageB = self.preprocess(resize_n_crop(imageB, M, self.input_size))

        # freeze bn
        self.recog_net.eval()

        id_featureA = F.normalize(self.recog_net(imageA), dim=-1, p=2)
        id_featureB = F.normalize(self.recog_net(imageB), dim=-1, p=2)
        cosine_d = torch.sum(id_featureA * id_featureB, dim=-1)
        # assert torch.sum((cosine_d > 1).float()) == 0
        return torch.sum(1 - cosine_d) / cosine_d.shape[0]


def perceptual_loss(id_featureA, id_featureB):
    """Functional form: mean (1 - cosine similarity) over the batch, assuming
    both feature batches are already L2-normalized."""
    cosine_d = torch.sum(id_featureA * id_featureB, dim=-1)
    # assert torch.sum((cosine_d > 1).float()) == 0
    return torch.sum(1 - cosine_d) / cosine_d.shape[0]
def landmark_loss(predict_lm, gt_lm, weight=None):
    """
    weighted mse loss
    Parameters:
        predict_lm --torch.tensor (B, 68, 2)
        gt_lm --torch.tensor (B, 68, 2)
        weight --numpy.array (1, 68); built internally when omitted, with
                 the nose bridge (28:31) and mouth region (-8:) up-weighted x20
    """
    # BUG FIX: ``if not weight:`` raises "truth value of an array is
    # ambiguous" whenever a weight array IS supplied; test for None instead.
    if weight is None:
        weight = np.ones([68])
        weight[28:31] = 20
        weight[-8:] = 20
        weight = np.expand_dims(weight, 0)
        weight = torch.tensor(weight).to(predict_lm.device)
    loss = torch.sum((predict_lm - gt_lm) ** 2, dim=-1) * weight
    loss = torch.sum(loss) / (predict_lm.shape[0] * predict_lm.shape[1])
    return loss


### regulization
def reg_loss(coeffs_dict, opt=None):
    """
    l2 norm without the sqrt, from yu's implementation (mse)
    tf.nn.l2_loss https://www.tensorflow.org/api_docs/python/tf/nn/l2_loss
    Parameters:
        coeffs_dict -- a dict of torch.tensors , keys: id, exp, tex, angle, gamma, trans
        opt -- options object providing w_id / w_exp / w_tex; all three
               weights default to 1 when omitted

    Returns:
        (creg_loss, gamma_loss) -- scalar coefficient and lighting penalties.
    """
    # coefficient regularization to ensure plausible 3d faces
    if opt:
        w_id, w_exp, w_tex = opt.w_id, opt.w_exp, opt.w_tex
    else:
        # BUG FIX: the original unpacked FOUR ones into three names
        # (``w_id, w_exp, w_tex = 1, 1, 1, 1``), raising ValueError
        # whenever opt was None.
        w_id, w_exp, w_tex = 1, 1, 1
    creg_loss = (
        w_id * torch.sum(coeffs_dict["id"] ** 2)
        + w_exp * torch.sum(coeffs_dict["exp"] ** 2)
        + w_tex * torch.sum(coeffs_dict["tex"] ** 2)
    )
    creg_loss = creg_loss / coeffs_dict["id"].shape[0]

    # gamma regularization to ensure a nearly-monochromatic light
    gamma = coeffs_dict["gamma"].reshape([-1, 3, 9])
    gamma_mean = torch.mean(gamma, dim=1, keepdim=True)
    gamma_loss = torch.mean((gamma - gamma_mean) ** 2)

    return creg_loss, gamma_loss


def reflectance_loss(texture, mask):
    """
    minimize texture variance (mse), albedo regularization to ensure an uniform skin albedo
    Parameters:
        texture --torch.tensor, (B, N, 3)
        mask --torch.tensor, (N), 1 or 0

    Returns:
        scalar masked variance of the texture around its masked mean.
    """
    mask = mask.reshape([1, mask.shape[0], 1])
    texture_mean = torch.sum(mask * texture, dim=1, keepdim=True) / torch.sum(mask)
    loss = torch.sum(((texture - texture_mean) * mask) ** 2) / (texture.shape[0] * torch.sum(mask))
    return loss
class TrainOptions(BaseOptions):
    """Training options for Deep3DFaceRecon_pytorch.

    Adds dataset, visualization, checkpointing and optimization options on
    top of the shared options defined in BaseOptions.
    """

    def initialize(self, parser):
        """Register the training-specific command-line options on *parser*."""
        parser = BaseOptions.initialize(self, parser)  # shared train/test options

        # --- training dataset -------------------------------------------------
        parser.add_argument("--data_root", type=str, default="./", help="dataset root")
        parser.add_argument("--flist", type=str, default="datalist/train/masks.txt", help="list of mask names of training set")
        parser.add_argument("--batch_size", type=int, default=32)
        parser.add_argument("--dataset_mode", type=str, default="flist", help="chooses how datasets are loaded. [None | flist]")
        parser.add_argument("--serial_batches", action="store_true", help="if true, takes images in order to make batches, otherwise takes them randomly")
        parser.add_argument("--num_threads", default=4, type=int, help="# threads for loading data")
        parser.add_argument("--max_dataset_size", type=int, default=float("inf"), help="Maximum number of samples allowed per dataset. If the dataset directory contains more than max_dataset_size, only a subset is loaded.")
        parser.add_argument("--preprocess", type=str, default="shift_scale_rot_flip", help="scaling and cropping of images at load time [shift_scale_rot_flip | shift_scale | shift | shift_rot_flip ]")
        parser.add_argument("--use_aug", type=util.str2bool, nargs="?", const=True, default=True, help="whether use data augmentation")

        # --- validation dataset -----------------------------------------------
        parser.add_argument("--flist_val", type=str, default="datalist/val/masks.txt", help="list of mask names of val set")
        parser.add_argument("--batch_size_val", type=int, default=32)

        # --- visualization ----------------------------------------------------
        parser.add_argument("--display_freq", type=int, default=1000, help="frequency of showing training results on screen")
        parser.add_argument("--print_freq", type=int, default=100, help="frequency of showing training results on console")

        # --- network saving and loading ---------------------------------------
        parser.add_argument("--save_latest_freq", type=int, default=5000, help="frequency of saving the latest results")
        parser.add_argument("--save_epoch_freq", type=int, default=1, help="frequency of saving checkpoints at the end of epochs")
        parser.add_argument("--evaluation_freq", type=int, default=5000, help="evaluation freq")
        parser.add_argument("--save_by_iter", action="store_true", help="whether saves model by iteration")
        parser.add_argument("--continue_train", action="store_true", help="continue training: load the latest model")
        parser.add_argument("--epoch_count", type=int, default=1, help="the starting epoch count, we save the model by , +, ...")
        parser.add_argument("--phase", type=str, default="train", help="train, val, test, etc")
        parser.add_argument("--pretrained_name", type=str, default=None, help="resume training from another checkpoint")

        # --- optimization -----------------------------------------------------
        parser.add_argument("--n_epochs", type=int, default=20, help="number of epochs with the initial learning rate")
        parser.add_argument("--lr", type=float, default=0.0001, help="initial learning rate for adam")
        parser.add_argument("--lr_policy", type=str, default="step", help="learning rate policy. [linear | step | plateau | cosine]")
        parser.add_argument("--lr_decay_epochs", type=int, default=10, help="multiply by a gamma every lr_decay_epochs epoches")

        self.isTrain = True
        return parser
def get_data_path(root="examples"):
    """Collect test images under *root* and the paths of their landmark files.

    Images are the .png/.jpg files directly inside *root*; each image's
    5-point landmark file is expected at <root>/detections/<stem>.txt.

    Parameters:
        root -- str, folder containing the test images

    Returns:
        (im_path, lm_path) -- parallel lists: image paths and landmark-file paths
    """
    im_path = [os.path.join(root, i) for i in sorted(os.listdir(root)) if i.endswith("png") or i.endswith("jpg")]
    # Build the landmark path with splitext/basename instead of str.replace:
    # replace("png", "txt") would also rewrite "png"/"jpg" occurring anywhere
    # else in the path (e.g. a directory named "pngs" or a stem like "jpgs_1").
    lm_path = [
        os.path.join(os.path.dirname(p), "detections", os.path.splitext(os.path.basename(p))[0] + ".txt")
        for p in im_path
    ]
    return im_path, lm_path
def main(rank, opt, name="examples"):
    """Run Deep3DFaceRecon inference on every image found under *name*.

    For each image with a detected 5-point landmark file, reconstructs the
    face, saves visualizations, the mesh (.obj) and coefficients (.mat).

    Parameters:
        rank -- int, CUDA device index to run on
        opt  -- parsed TestOptions
        name -- str, folder of test images
    """
    device = torch.device(rank)
    torch.cuda.set_device(device)

    model = create_model(opt)
    model.setup(opt)
    model.device = device
    model.parallelize()
    model.eval()

    visualizer = MyVisualizer(opt)
    im_path, lm_path = get_data_path(name)
    lm3d_std = load_lm3d(opt.bfm_folder)
    dataset_name = name.split(os.path.sep)[-1]

    for i, (img_file, lm_file) in enumerate(zip(im_path, lm_path)):
        print(i, img_file)
        img_name = img_file.split(os.path.sep)[-1].replace(".png", "").replace(".jpg", "")
        # Skip images that have no detected 5-point landmark file.
        if not os.path.isfile(lm_file):
            print("%s is not found !!!" % lm_file)
            continue

        im_tensor, lm_tensor = read_data(img_file, lm_file, lm3d_std)
        model.set_input({"imgs": im_tensor, "lms": lm_tensor})  # unpack data
        model.test()  # run inference
        visuals = model.get_current_visuals()  # get image results
        visualizer.display_current_results(
            visuals,
            0,
            opt.epoch,
            dataset=dataset_name,
            save_results=True,
            count=i,
            name=img_name,
            add_image=False,
        )

        out_dir = os.path.join(visualizer.img_dir, dataset_name, "epoch_%s_%06d" % (opt.epoch, 0))
        model.save_mesh(os.path.join(out_dir, img_name + ".obj"))  # reconstruction mesh
        model.save_coeff(os.path.join(out_dir, img_name + ".mat"))  # predicted coefficients
def save_label(labels, save_path):
    """Write the landmark array *labels* to *save_path* as a plain-text file."""
    np.savetxt(save_path, labels)
def detect_68p(img_path, sess, input_op, output_op):
    """Detect 68 facial landmarks for every image under *img_path*.

    Results are written to <img_path>/landmarks; every 100th image gets a
    visualization in <img_path>/vis; images without a usable 5-point
    detection are moved to <img_path>/remove.

    Parameters:
        img_path  -- str, folder of input images
        sess      -- tf.Session running the landmark-detector graph
        input_op  -- graph input tensor (1x224x224x3 image)
        output_op -- graph output tensor (68x2 landmark offsets)
    """
    print("detecting landmarks......")
    names = [i for i in sorted(os.listdir(img_path)) if "jpg" in i or "png" in i or "jpeg" in i or "PNG" in i]

    vis_path = os.path.join(img_path, "vis")
    remove_path = os.path.join(img_path, "remove")
    save_path = os.path.join(img_path, "landmarks")
    for folder in (vis_path, remove_path, save_path):
        if not os.path.isdir(folder):
            os.makedirs(folder)

    for i, name in enumerate(names):
        print("%05d" % (i), " ", name)
        full_image_name = os.path.join(img_path, name)
        txt_name = ".".join(name.split(".")[:-1]) + ".txt"
        full_txt_name = os.path.join(img_path, "detections", txt_name)  # 5-point landmark file

        # No 5-point detection -> drop the image from the training list.
        if not os.path.isfile(full_txt_name):
            move(full_image_name, os.path.join(remove_path, name))
            continue

        img, five_points = load_data(full_image_name, full_txt_name)
        input_img, scale, bbox = align_for_lm(img, five_points)  # align for 68-landmark detection

        # Alignment failed -> drop both the image and its 5-point file.
        if scale == 0:
            move(full_txt_name, os.path.join(remove_path, txt_name))
            move(full_image_name, os.path.join(remove_path, name))
            continue

        # Run the detector on the aligned 224x224 crop.
        input_img = np.reshape(input_img, [1, 224, 224, 3]).astype(np.float32)
        landmark = sess.run(output_op, feed_dict={input_op: input_img})

        # Map the prediction back into original-image coordinates:
        # add the mean face, flip y within the crop, undo the scale, shift by
        # the crop bbox, then flip y into the original image frame.
        landmark = landmark.reshape([68, 2]) + mean_face
        landmark[:, 1] = 223 - landmark[:, 1]
        landmark = landmark / scale
        landmark[:, 0] = landmark[:, 0] + bbox[0]
        landmark[:, 1] = landmark[:, 1] + bbox[1]
        landmark[:, 1] = img.shape[0] - 1 - landmark[:, 1]

        if i % 100 == 0:
            draw_landmarks(img, landmark, os.path.join(vis_path, name))
        save_label(landmark, os.path.join(save_path, txt_name))
def check_list(rlms_list, rimgs_list, rmsks_list):
    """Filter the landmark/image/mask path triples, keeping only those
    whose three files all exist on disk.

    Parameters:
        rlms_list  -- list of landmark-file paths
        rimgs_list -- list of image paths
        rmsks_list -- list of mask paths

    Returns:
        (lms_list, imgs_list, msks_list) -- the surviving parallel lists
    """
    lms_list, imgs_list, msks_list = [], [], []
    # Index-based loop (not zip) so unequal list lengths fail loudly,
    # exactly as the original indexing would.
    for i in range(len(rlms_list)):
        lm_path, im_path, msk_path = rlms_list[i], rimgs_list[i], rmsks_list[i]
        all_exist = os.path.isfile(lm_path) and os.path.isfile(im_path) and os.path.isfile(msk_path)
        if all_exist:
            lms_list.append(lm_path)
            imgs_list.append(im_path)
            msks_list.append(msk_path)
        print(i, lm_path, "true" if all_exist else "false")
    return lms_list, imgs_list, msks_list
20 | It is based on Python library 'dominate', a Python library for creating and manipulating HTML documents using a DOM API. 21 | """ 22 | 23 | def __init__(self, web_dir, title, refresh=0): 24 | """Initialize the HTML classes 25 | 26 | Parameters: 27 | web_dir (str) -- a directory that stores the webpage. HTML file will be created at /index.html; images will be saved at 0: 41 | with self.doc.head: 42 | meta(http_equiv="refresh", content=str(refresh)) 43 | 44 | def get_image_dir(self): 45 | """Return the directory that stores images""" 46 | return self.img_dir 47 | 48 | def add_header(self, text): 49 | """Insert a header to the HTML file 50 | 51 | Parameters: 52 | text (str) -- the header text 53 | """ 54 | with self.doc: 55 | h3(text) 56 | 57 | def add_images(self, ims, txts, links, width=400): 58 | """add images to the HTML file 59 | 60 | Parameters: 61 | ims (str list) -- a list of image paths 62 | txts (str list) -- a list of image names shown on the website 63 | links (str list) -- a list of hyperref links; when you click an image, it will redirect you to a new page 64 | """ 65 | self.t = table(border=1, style="table-layout: fixed;") # Insert a table 66 | self.doc.add(self.t) 67 | with self.t: 68 | with tr(): 69 | for im, txt, link in zip(ims, txts, links): 70 | with td(style="word-wrap: break-word;", halign="center", valign="top"): 71 | with p(): 72 | with a(href=os.path.join("images", link)): 73 | img(style="width:%dpx" % width, src=os.path.join("images", im)) 74 | br() 75 | p(txt) 76 | 77 | def save(self): 78 | """save the current content to the HMTL file""" 79 | html_file = "%s/index.html" % self.web_dir 80 | f = open(html_file, "wt") 81 | f.write(self.doc.render()) 82 | f.close() 83 | 84 | 85 | if __name__ == "__main__": # we show an example usage here. 
def ndc_projection(x=0.1, n=1.0, f=50.0):
    """Build the 4x4 perspective projection matrix mapping camera space to NDC.

    Parameters:
        x -- float, half-width of the near clipping plane
        n -- float, near-plane distance
        f -- float, far-plane distance

    Returns:
        np.ndarray, (4, 4) float32 projection matrix (y axis negated)
    """
    depth_a = -(f + n) / (f - n)
    depth_b = -(2 * f * n) / (f - n)
    proj = [
        [n / x, 0, 0, 0],
        [0, n / -x, 0, 0],
        [0, 0, depth_a, depth_b],
        [0, 0, -1, 0],
    ]
    return np.array(proj).astype(np.float32)
triangles 45 | feat(optional) -- torch.tensor, size (B, C), features 46 | """ 47 | device = vertex.device 48 | rsize = int(self.rasterize_size) 49 | ndc_proj = self.ndc_proj.to(device) 50 | # trans to homogeneous coordinates of 3d vertices, the direction of y is the same as v 51 | if vertex.shape[-1] == 3: 52 | vertex = torch.cat([vertex, torch.ones([*vertex.shape[:2], 1]).to(device)], dim=-1) 53 | vertex[..., 1] = -vertex[..., 1] 54 | 55 | vertex_ndc = vertex @ ndc_proj.t() 56 | if self.ctx is None: 57 | if self.use_opengl: 58 | self.ctx = dr.RasterizeGLContext(device=device) 59 | ctx_str = "opengl" 60 | else: 61 | self.ctx = dr.RasterizeCudaContext(device=device) 62 | ctx_str = "cuda" 63 | print("create %s ctx on device cuda:%d" % (ctx_str, device.index)) 64 | 65 | ranges = None 66 | if isinstance(tri, List) or len(tri.shape) == 3: 67 | vum = vertex_ndc.shape[1] 68 | fnum = torch.tensor([f.shape[0] for f in tri]).unsqueeze(1).to(device) 69 | fstartidx = torch.cumsum(fnum, dim=0) - fnum 70 | ranges = torch.cat([fstartidx, fnum], axis=1).type(torch.int32).cpu() 71 | for i in range(tri.shape[0]): 72 | tri[i] = tri[i] + i * vum 73 | vertex_ndc = torch.cat(vertex_ndc, dim=0) 74 | tri = torch.cat(tri, dim=0) 75 | 76 | # for range_mode vetex: [B*N, 4], tri: [B*M, 3], for instance_mode vetex: [B, N, 4], tri: [M, 3] 77 | tri = tri.type(torch.int32).contiguous() 78 | rast_out, _ = dr.rasterize(self.ctx, vertex_ndc.contiguous(), tri, resolution=[rsize, rsize], ranges=ranges) 79 | 80 | depth, _ = dr.interpolate(vertex.reshape([-1, 4])[..., 2].unsqueeze(1).contiguous(), rast_out, tri) 81 | depth = depth.permute(0, 3, 1, 2) 82 | mask = (rast_out[..., 3] > 0).float().unsqueeze(1) 83 | depth = mask * depth 84 | 85 | image = None 86 | if feat is not None: 87 | image, _ = dr.interpolate(feat, rast_out, tri) 88 | image = image.permute(0, 3, 1, 2) 89 | image = mask * image 90 | 91 | return mask, depth, image 92 | 
-------------------------------------------------------------------------------- /Deep3DFaceRecon_pytorch/util/test_mean_face.txt: -------------------------------------------------------------------------------- 1 | -5.228591537475585938e+01 2 | 2.078247070312500000e-01 3 | -5.064269638061523438e+01 4 | -1.315765380859375000e+01 5 | -4.952939224243164062e+01 6 | -2.592591094970703125e+01 7 | -4.793047332763671875e+01 8 | -3.832135772705078125e+01 9 | -4.512159729003906250e+01 10 | -5.059623336791992188e+01 11 | -3.917720794677734375e+01 12 | -6.043736648559570312e+01 13 | -2.929953765869140625e+01 14 | -6.861183166503906250e+01 15 | -1.719801330566406250e+01 16 | -7.572736358642578125e+01 17 | -1.961936950683593750e+00 18 | -7.862001037597656250e+01 19 | 1.467941284179687500e+01 20 | -7.607844543457031250e+01 21 | 2.744073486328125000e+01 22 | -6.915261840820312500e+01 23 | 3.855677795410156250e+01 24 | -5.950350570678710938e+01 25 | 4.478240966796875000e+01 26 | -4.867547225952148438e+01 27 | 4.714337158203125000e+01 28 | -3.800830078125000000e+01 29 | 4.940315246582031250e+01 30 | -2.496297454833984375e+01 31 | 5.117234802246093750e+01 32 | -1.241538238525390625e+01 33 | 5.190507507324218750e+01 34 | 8.244247436523437500e-01 35 | -4.150688934326171875e+01 36 | 2.386329650878906250e+01 37 | -3.570307159423828125e+01 38 | 3.017010498046875000e+01 39 | -2.790358734130859375e+01 40 | 3.212951660156250000e+01 41 | -1.941773223876953125e+01 42 | 3.156523132324218750e+01 43 | -1.138106536865234375e+01 44 | 2.841992187500000000e+01 45 | 5.993263244628906250e+00 46 | 2.895182800292968750e+01 47 | 1.343590545654296875e+01 48 | 3.189880371093750000e+01 49 | 2.203153991699218750e+01 50 | 3.302221679687500000e+01 51 | 2.992478942871093750e+01 52 | 3.099150085449218750e+01 53 | 3.628388977050781250e+01 54 | 2.765748596191406250e+01 55 | -1.933914184570312500e+00 56 | 1.405374145507812500e+01 57 | -2.153038024902343750e+00 58 | 5.772636413574218750e+00 59 | 
-2.270050048828125000e+00 60 | -2.121643066406250000e+00 61 | -2.218330383300781250e+00 62 | -1.068978118896484375e+01 63 | -1.187252044677734375e+01 64 | -1.997912597656250000e+01 65 | -6.879402160644531250e+00 66 | -2.143579864501953125e+01 67 | -1.227821350097656250e+00 68 | -2.193494415283203125e+01 69 | 4.623237609863281250e+00 70 | -2.152721405029296875e+01 71 | 9.721397399902343750e+00 72 | -1.953671264648437500e+01 73 | -3.648714447021484375e+01 74 | 9.811126708984375000e+00 75 | -3.130242919921875000e+01 76 | 1.422447967529296875e+01 77 | -2.212834930419921875e+01 78 | 1.493019866943359375e+01 79 | -1.500880432128906250e+01 80 | 1.073588562011718750e+01 81 | -2.095037078857421875e+01 82 | 9.054298400878906250e+00 83 | -3.050099182128906250e+01 84 | 8.704177856445312500e+00 85 | 1.173237609863281250e+01 86 | 1.054329681396484375e+01 87 | 1.856353759765625000e+01 88 | 1.535009765625000000e+01 89 | 2.893331909179687500e+01 90 | 1.451992797851562500e+01 91 | 3.452944946289062500e+01 92 | 1.065280151367187500e+01 93 | 2.875990295410156250e+01 94 | 8.654792785644531250e+00 95 | 1.942100524902343750e+01 96 | 9.422447204589843750e+00 97 | -2.204488372802734375e+01 98 | -3.983994293212890625e+01 99 | -1.324458312988281250e+01 100 | -3.467377471923828125e+01 101 | -6.749649047851562500e+00 102 | -3.092894744873046875e+01 103 | -9.183349609375000000e-01 104 | -3.196458435058593750e+01 105 | 4.220649719238281250e+00 106 | -3.090406036376953125e+01 107 | 1.089889526367187500e+01 108 | -3.497008514404296875e+01 109 | 1.874589538574218750e+01 110 | -4.065438079833984375e+01 111 | 1.124106597900390625e+01 112 | -4.438417816162109375e+01 113 | 5.181709289550781250e+00 114 | -4.649170684814453125e+01 115 | -1.158607482910156250e+00 116 | -4.680406951904296875e+01 117 | -7.918922424316406250e+00 118 | -4.671575164794921875e+01 119 | -1.452505493164062500e+01 120 | -4.416526031494140625e+01 121 | -2.005007171630859375e+01 122 | -3.997841644287109375e+01 123 | 
-1.054919433593750000e+01 124 | -3.849683380126953125e+01 125 | -1.051826477050781250e+00 126 | -3.794863128662109375e+01 127 | 6.412681579589843750e+00 128 | -3.804645538330078125e+01 129 | 1.627674865722656250e+01 130 | -4.039697265625000000e+01 131 | 6.373878479003906250e+00 132 | -4.087213897705078125e+01 133 | -8.551712036132812500e-01 134 | -4.157129669189453125e+01 135 | -1.014953613281250000e+01 136 | -4.128469085693359375e+01 137 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 HY XUE 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /arcface_torch/configs/3millions.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict as edict 2 | 3 | # configs for test speed 4 | 5 | config = edict() 6 | config.margin_list = (1.0, 0.0, 0.4) 7 | config.network = "mbf" 8 | config.resume = False 9 | config.output = None 10 | config.embedding_size = 512 11 | config.sample_rate = 0.1 12 | config.fp16 = True 13 | config.momentum = 0.9 14 | config.weight_decay = 5e-4 15 | config.batch_size = 512 # total_batch_size = batch_size * num_gpus 16 | config.lr = 0.1 # batch size is 512 17 | 18 | config.rec = "synthetic" 19 | config.num_classes = 30 * 10000 20 | config.num_image = 100000 21 | config.num_epoch = 30 22 | config.warmup_epoch = -1 23 | config.val_targets = [] 24 | -------------------------------------------------------------------------------- /arcface_torch/configs/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xuehy/HiFiFace-pytorch/0e50b25909b5910e9327d3cb44eeb054f1c047d9/arcface_torch/configs/__init__.py -------------------------------------------------------------------------------- /arcface_torch/configs/base.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict as edict 2 | 3 | # make training faster 4 | # our RAM is 256G 5 | # mount -t tmpfs -o size=140G tmpfs /train_tmp 6 | 7 | config = edict() 8 | 9 | # Margin Base Softmax 10 | config.margin_list = (1.0, 0.5, 0.0) 11 | config.network = "r50" 12 | config.resume = False 13 | config.save_all_states = False 14 | config.output = "ms1mv3_arcface_r50" 15 | 16 | config.embedding_size = 512 17 | 18 | # Partial FC 19 | config.sample_rate = 1 20 | config.interclass_filtering_threshold = 0 21 | 22 | config.fp16 = False 23 | config.batch_size = 128 24 | 25 | # For SGD 26 | 
config.optimizer = "sgd" 27 | config.lr = 0.1 28 | config.momentum = 0.9 29 | config.weight_decay = 5e-4 30 | 31 | # For AdamW 32 | # config.optimizer = "adamw" 33 | # config.lr = 0.001 34 | # config.weight_decay = 0.1 35 | 36 | config.verbose = 2000 37 | config.frequent = 10 38 | 39 | # For Large Sacle Dataset, such as WebFace42M 40 | config.dali = False 41 | 42 | # Gradient ACC 43 | config.gradient_acc = 1 44 | 45 | # setup seed 46 | config.seed = 2048 47 | 48 | # dataload numworkers 49 | config.num_workers = 2 50 | 51 | # WandB Logger 52 | config.wandb_key = "XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX" 53 | config.suffix_run_name = None 54 | config.using_wandb = False 55 | config.wandb_entity = "entity" 56 | config.wandb_project = "project" 57 | config.wandb_log_all = True 58 | config.save_artifacts = False 59 | config.wandb_resume = False # resume wandb run: Only if the you wand t resume the last run that it was interrupted 60 | -------------------------------------------------------------------------------- /arcface_torch/configs/glint360k_mbf.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict as edict 2 | 3 | # make training faster 4 | # our RAM is 256G 5 | # mount -t tmpfs -o size=140G tmpfs /train_tmp 6 | 7 | config = edict() 8 | config.margin_list = (1.0, 0.0, 0.4) 9 | config.network = "mbf" 10 | config.resume = False 11 | config.output = None 12 | config.embedding_size = 512 13 | config.sample_rate = 1.0 14 | config.fp16 = True 15 | config.momentum = 0.9 16 | config.weight_decay = 1e-4 17 | config.batch_size = 128 18 | config.lr = 0.1 19 | config.verbose = 2000 20 | config.dali = False 21 | 22 | config.rec = "/train_tmp/glint360k" 23 | config.num_classes = 360232 24 | config.num_image = 17091657 25 | config.num_epoch = 20 26 | config.warmup_epoch = 0 27 | config.val_targets = ["lfw", "cfp_fp", "agedb_30"] 28 | -------------------------------------------------------------------------------- 
/arcface_torch/configs/glint360k_r100.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict as edict 2 | 3 | # make training faster 4 | # our RAM is 256G 5 | # mount -t tmpfs -o size=140G tmpfs /train_tmp 6 | 7 | config = edict() 8 | config.margin_list = (1.0, 0.0, 0.4) 9 | config.network = "r100" 10 | config.resume = False 11 | config.output = None 12 | config.embedding_size = 512 13 | config.sample_rate = 1.0 14 | config.fp16 = True 15 | config.momentum = 0.9 16 | config.weight_decay = 1e-4 17 | config.batch_size = 128 18 | config.lr = 0.1 19 | config.verbose = 2000 20 | config.dali = False 21 | 22 | config.rec = "/train_tmp/glint360k" 23 | config.num_classes = 360232 24 | config.num_image = 17091657 25 | config.num_epoch = 20 26 | config.warmup_epoch = 0 27 | config.val_targets = ["lfw", "cfp_fp", "agedb_30"] 28 | -------------------------------------------------------------------------------- /arcface_torch/configs/glint360k_r50.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict as edict 2 | 3 | # make training faster 4 | # our RAM is 256G 5 | # mount -t tmpfs -o size=140G tmpfs /train_tmp 6 | 7 | config = edict() 8 | config.margin_list = (1.0, 0.0, 0.4) 9 | config.network = "r50" 10 | config.resume = False 11 | config.output = None 12 | config.embedding_size = 512 13 | config.sample_rate = 1.0 14 | config.fp16 = True 15 | config.momentum = 0.9 16 | config.weight_decay = 1e-4 17 | config.batch_size = 128 18 | config.lr = 0.1 19 | config.verbose = 2000 20 | config.dali = False 21 | 22 | config.rec = "/train_tmp/glint360k" 23 | config.num_classes = 360232 24 | config.num_image = 17091657 25 | config.num_epoch = 20 26 | config.warmup_epoch = 0 27 | config.val_targets = ["lfw", "cfp_fp", "agedb_30"] 28 | -------------------------------------------------------------------------------- /arcface_torch/configs/ms1mv2_mbf.py: 
-------------------------------------------------------------------------------- 1 | from easydict import EasyDict as edict 2 | 3 | # make training faster 4 | # our RAM is 256G 5 | # mount -t tmpfs -o size=140G tmpfs /train_tmp 6 | 7 | config = edict() 8 | config.margin_list = (1.0, 0.5, 0.0) 9 | config.network = "mbf" 10 | config.resume = False 11 | config.output = None 12 | config.embedding_size = 512 13 | config.sample_rate = 1.0 14 | config.fp16 = True 15 | config.momentum = 0.9 16 | config.weight_decay = 1e-4 17 | config.batch_size = 128 18 | config.lr = 0.1 19 | config.verbose = 2000 20 | config.dali = False 21 | 22 | config.rec = "/train_tmp/faces_emore" 23 | config.num_classes = 85742 24 | config.num_image = 5822653 25 | config.num_epoch = 40 26 | config.warmup_epoch = 0 27 | config.val_targets = ["lfw", "cfp_fp", "agedb_30"] 28 | -------------------------------------------------------------------------------- /arcface_torch/configs/ms1mv2_r100.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict as edict 2 | 3 | # make training faster 4 | # our RAM is 256G 5 | # mount -t tmpfs -o size=140G tmpfs /train_tmp 6 | 7 | config = edict() 8 | config.margin_list = (1.0, 0.5, 0.0) 9 | config.network = "r100" 10 | config.resume = False 11 | config.output = None 12 | config.embedding_size = 512 13 | config.sample_rate = 1.0 14 | config.fp16 = True 15 | config.momentum = 0.9 16 | config.weight_decay = 5e-4 17 | config.batch_size = 128 18 | config.lr = 0.1 19 | config.verbose = 2000 20 | config.dali = False 21 | 22 | config.rec = "/train_tmp/faces_emore" 23 | config.num_classes = 85742 24 | config.num_image = 5822653 25 | config.num_epoch = 20 26 | config.warmup_epoch = 0 27 | config.val_targets = ["lfw", "cfp_fp", "agedb_30"] 28 | -------------------------------------------------------------------------------- /arcface_torch/configs/ms1mv2_r50.py: 
-------------------------------------------------------------------------------- 1 | from easydict import EasyDict as edict 2 | 3 | # make training faster 4 | # our RAM is 256G 5 | # mount -t tmpfs -o size=140G tmpfs /train_tmp 6 | 7 | config = edict() 8 | config.margin_list = (1.0, 0.5, 0.0) 9 | config.network = "r50" 10 | config.resume = False 11 | config.output = None 12 | config.embedding_size = 512 13 | config.sample_rate = 1.0 14 | config.fp16 = True 15 | config.momentum = 0.9 16 | config.weight_decay = 5e-4 17 | config.batch_size = 128 18 | config.lr = 0.1 19 | config.verbose = 2000 20 | config.dali = False 21 | 22 | config.rec = "/train_tmp/faces_emore" 23 | config.num_classes = 85742 24 | config.num_image = 5822653 25 | config.num_epoch = 20 26 | config.warmup_epoch = 0 27 | config.val_targets = ["lfw", "cfp_fp", "agedb_30"] 28 | -------------------------------------------------------------------------------- /arcface_torch/configs/ms1mv3_mbf.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict as edict 2 | 3 | # make training faster 4 | # our RAM is 256G 5 | # mount -t tmpfs -o size=140G tmpfs /train_tmp 6 | 7 | config = edict() 8 | config.margin_list = (1.0, 0.5, 0.0) 9 | config.network = "mbf" 10 | config.resume = False 11 | config.output = None 12 | config.embedding_size = 512 13 | config.sample_rate = 1.0 14 | config.fp16 = True 15 | config.momentum = 0.9 16 | config.weight_decay = 1e-4 17 | config.batch_size = 128 18 | config.lr = 0.1 19 | config.verbose = 2000 20 | config.dali = False 21 | 22 | config.rec = "/train_tmp/ms1m-retinaface-t1" 23 | config.num_classes = 93431 24 | config.num_image = 5179510 25 | config.num_epoch = 40 26 | config.warmup_epoch = 0 27 | config.val_targets = ["lfw", "cfp_fp", "agedb_30"] 28 | -------------------------------------------------------------------------------- /arcface_torch/configs/ms1mv3_r100.py: 
-------------------------------------------------------------------------------- 1 | from easydict import EasyDict as edict 2 | 3 | # make training faster 4 | # our RAM is 256G 5 | # mount -t tmpfs -o size=140G tmpfs /train_tmp 6 | 7 | config = edict() 8 | config.margin_list = (1.0, 0.5, 0.0) 9 | config.network = "r100" 10 | config.resume = False 11 | config.output = None 12 | config.embedding_size = 512 13 | config.sample_rate = 1.0 14 | config.fp16 = True 15 | config.momentum = 0.9 16 | config.weight_decay = 5e-4 17 | config.batch_size = 128 18 | config.lr = 0.1 19 | config.verbose = 2000 20 | config.dali = False 21 | 22 | config.rec = "/train_tmp/ms1m-retinaface-t1" 23 | config.num_classes = 93431 24 | config.num_image = 5179510 25 | config.num_epoch = 20 26 | config.warmup_epoch = 0 27 | config.val_targets = ["lfw", "cfp_fp", "agedb_30"] 28 | -------------------------------------------------------------------------------- /arcface_torch/configs/ms1mv3_r50.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict as edict 2 | 3 | # make training faster 4 | # our RAM is 256G 5 | # mount -t tmpfs -o size=140G tmpfs /train_tmp 6 | 7 | config = edict() 8 | config.margin_list = (1.0, 0.5, 0.0) 9 | config.network = "r50" 10 | config.resume = False 11 | config.output = None 12 | config.embedding_size = 512 13 | config.sample_rate = 1.0 14 | config.fp16 = True 15 | config.momentum = 0.9 16 | config.weight_decay = 5e-4 17 | config.batch_size = 128 18 | config.lr = 0.1 19 | config.verbose = 2000 20 | config.dali = False 21 | 22 | config.rec = "/train_tmp/ms1m-retinaface-t1" 23 | config.num_classes = 93431 24 | config.num_image = 5179510 25 | config.num_epoch = 20 26 | config.warmup_epoch = 0 27 | config.val_targets = ["lfw", "cfp_fp", "agedb_30"] 28 | -------------------------------------------------------------------------------- /arcface_torch/configs/ms1mv3_r50_onegpu.py: 
-------------------------------------------------------------------------------- 1 | from easydict import EasyDict as edict 2 | 3 | # make training faster 4 | # our RAM is 256G 5 | # mount -t tmpfs -o size=140G tmpfs /train_tmp 6 | 7 | config = edict() 8 | config.margin_list = (1.0, 0.5, 0.0) 9 | config.network = "r50" 10 | config.resume = False 11 | config.output = None 12 | config.embedding_size = 512 13 | config.sample_rate = 1.0 14 | config.fp16 = True 15 | config.momentum = 0.9 16 | config.weight_decay = 5e-4 17 | config.batch_size = 128 18 | config.lr = 0.02 19 | config.verbose = 2000 20 | config.dali = False 21 | 22 | config.rec = "/train_tmp/ms1m-retinaface-t1" 23 | config.num_classes = 93431 24 | config.num_image = 5179510 25 | config.num_epoch = 20 26 | config.warmup_epoch = 0 27 | config.val_targets = ["lfw", "cfp_fp", "agedb_30"] 28 | -------------------------------------------------------------------------------- /arcface_torch/configs/wf12m_conflict_r50.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict as edict 2 | 3 | # make training faster 4 | # our RAM is 256G 5 | # mount -t tmpfs -o size=140G tmpfs /train_tmp 6 | 7 | config = edict() 8 | config.margin_list = (1.0, 0.0, 0.4) 9 | config.network = "r50" 10 | config.resume = False 11 | config.output = None 12 | config.embedding_size = 512 13 | config.sample_rate = 1.0 14 | config.interclass_filtering_threshold = 0 15 | config.fp16 = True 16 | config.weight_decay = 5e-4 17 | config.batch_size = 128 18 | config.optimizer = "sgd" 19 | config.lr = 0.1 20 | config.verbose = 2000 21 | config.dali = False 22 | 23 | config.rec = "/train_tmp/WebFace12M_Conflict" 24 | config.num_classes = 1017970 25 | config.num_image = 12720066 26 | config.num_epoch = 20 27 | config.warmup_epoch = config.num_epoch // 10 28 | config.val_targets = [] 29 | -------------------------------------------------------------------------------- 
/arcface_torch/configs/wf12m_conflict_r50_pfc03_filter04.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict as edict 2 | 3 | # make training faster 4 | # our RAM is 256G 5 | # mount -t tmpfs -o size=140G tmpfs /train_tmp 6 | 7 | config = edict() 8 | config.margin_list = (1.0, 0.0, 0.4) 9 | config.network = "r50" 10 | config.resume = False 11 | config.output = None 12 | config.embedding_size = 512 13 | config.sample_rate = 0.3 14 | config.interclass_filtering_threshold = 0.4 15 | config.fp16 = True 16 | config.weight_decay = 5e-4 17 | config.batch_size = 128 18 | config.optimizer = "sgd" 19 | config.lr = 0.1 20 | config.verbose = 2000 21 | config.dali = False 22 | 23 | config.rec = "/train_tmp/WebFace12M_Conflict" 24 | config.num_classes = 1017970 25 | config.num_image = 12720066 26 | config.num_epoch = 20 27 | config.warmup_epoch = config.num_epoch // 10 28 | config.val_targets = [] 29 | -------------------------------------------------------------------------------- /arcface_torch/configs/wf12m_flip_pfc01_filter04_r50.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict as edict 2 | 3 | # make training faster 4 | # our RAM is 256G 5 | # mount -t tmpfs -o size=140G tmpfs /train_tmp 6 | 7 | config = edict() 8 | config.margin_list = (1.0, 0.0, 0.4) 9 | config.network = "r50" 10 | config.resume = False 11 | config.output = None 12 | config.embedding_size = 512 13 | config.sample_rate = 0.1 14 | config.interclass_filtering_threshold = 0.4 15 | config.fp16 = True 16 | config.weight_decay = 5e-4 17 | config.batch_size = 128 18 | config.optimizer = "sgd" 19 | config.lr = 0.1 20 | config.verbose = 2000 21 | config.dali = False 22 | 23 | config.rec = "/train_tmp/WebFace12M_FLIP40" 24 | config.num_classes = 617970 25 | config.num_image = 12720066 26 | config.num_epoch = 20 27 | config.warmup_epoch = config.num_epoch // 10 28 | 
config.val_targets = [] 29 | -------------------------------------------------------------------------------- /arcface_torch/configs/wf12m_flip_r50.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict as edict 2 | 3 | # make training faster 4 | # our RAM is 256G 5 | # mount -t tmpfs -o size=140G tmpfs /train_tmp 6 | 7 | config = edict() 8 | config.margin_list = (1.0, 0.0, 0.4) 9 | config.network = "r50" 10 | config.resume = False 11 | config.output = None 12 | config.embedding_size = 512 13 | config.sample_rate = 1.0 14 | config.interclass_filtering_threshold = 0 15 | config.fp16 = True 16 | config.weight_decay = 5e-4 17 | config.batch_size = 128 18 | config.optimizer = "sgd" 19 | config.lr = 0.1 20 | config.verbose = 2000 21 | config.dali = False 22 | 23 | config.rec = "/train_tmp/WebFace12M_FLIP40" 24 | config.num_classes = 617970 25 | config.num_image = 12720066 26 | config.num_epoch = 20 27 | config.warmup_epoch = config.num_epoch // 10 28 | config.val_targets = [] 29 | -------------------------------------------------------------------------------- /arcface_torch/configs/wf12m_mbf.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict as edict 2 | 3 | # make training faster 4 | # our RAM is 256G 5 | # mount -t tmpfs -o size=140G tmpfs /train_tmp 6 | 7 | config = edict() 8 | config.margin_list = (1.0, 0.0, 0.4) 9 | config.network = "mbf" 10 | config.resume = False 11 | config.output = None 12 | config.embedding_size = 512 13 | config.sample_rate = 1.0 14 | config.interclass_filtering_threshold = 0 15 | config.fp16 = True 16 | config.weight_decay = 1e-4 17 | config.batch_size = 128 18 | config.optimizer = "sgd" 19 | config.lr = 0.1 20 | config.verbose = 2000 21 | config.dali = False 22 | 23 | config.rec = "/train_tmp/WebFace12M" 24 | config.num_classes = 617970 25 | config.num_image = 12720066 26 | config.num_epoch = 20 27 | 
config.warmup_epoch = 0 28 | config.val_targets = [] 29 | -------------------------------------------------------------------------------- /arcface_torch/configs/wf12m_pfc02_r100.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict as edict 2 | 3 | # make training faster 4 | # our RAM is 256G 5 | # mount -t tmpfs -o size=140G tmpfs /train_tmp 6 | 7 | config = edict() 8 | config.margin_list = (1.0, 0.0, 0.4) 9 | config.network = "r100" 10 | config.resume = False 11 | config.output = None 12 | config.embedding_size = 512 13 | config.sample_rate = 0.2 14 | config.interclass_filtering_threshold = 0 15 | config.fp16 = True 16 | config.weight_decay = 5e-4 17 | config.batch_size = 128 18 | config.optimizer = "sgd" 19 | config.lr = 0.1 20 | config.verbose = 2000 21 | config.dali = False 22 | 23 | config.rec = "/train_tmp/WebFace12M" 24 | config.num_classes = 617970 25 | config.num_image = 12720066 26 | config.num_epoch = 20 27 | config.warmup_epoch = 0 28 | config.val_targets = [] 29 | -------------------------------------------------------------------------------- /arcface_torch/configs/wf12m_r100.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict as edict 2 | 3 | # make training faster 4 | # our RAM is 256G 5 | # mount -t tmpfs -o size=140G tmpfs /train_tmp 6 | 7 | config = edict() 8 | config.margin_list = (1.0, 0.0, 0.4) 9 | config.network = "r100" 10 | config.resume = False 11 | config.output = None 12 | config.embedding_size = 512 13 | config.sample_rate = 1.0 14 | config.interclass_filtering_threshold = 0 15 | config.fp16 = True 16 | config.weight_decay = 5e-4 17 | config.batch_size = 128 18 | config.optimizer = "sgd" 19 | config.lr = 0.1 20 | config.verbose = 2000 21 | config.dali = False 22 | 23 | config.rec = "/train_tmp/WebFace12M" 24 | config.num_classes = 617970 25 | config.num_image = 12720066 26 | config.num_epoch = 20 27 | 
config.warmup_epoch = 0 28 | config.val_targets = [] 29 | -------------------------------------------------------------------------------- /arcface_torch/configs/wf12m_r50.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict as edict 2 | 3 | # make training faster 4 | # our RAM is 256G 5 | # mount -t tmpfs -o size=140G tmpfs /train_tmp 6 | 7 | config = edict() 8 | config.margin_list = (1.0, 0.0, 0.4) 9 | config.network = "r50" 10 | config.resume = False 11 | config.output = None 12 | config.embedding_size = 512 13 | config.sample_rate = 1.0 14 | config.interclass_filtering_threshold = 0 15 | config.fp16 = True 16 | config.weight_decay = 5e-4 17 | config.batch_size = 128 18 | config.optimizer = "sgd" 19 | config.lr = 0.1 20 | config.verbose = 2000 21 | config.dali = False 22 | 23 | config.rec = "/train_tmp/WebFace12M" 24 | config.num_classes = 617970 25 | config.num_image = 12720066 26 | config.num_epoch = 20 27 | config.warmup_epoch = 0 28 | config.val_targets = [] 29 | -------------------------------------------------------------------------------- /arcface_torch/configs/wf42m_pfc0008_32gpu_r100.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict as edict 2 | 3 | # make training faster 4 | # our RAM is 256G 5 | # mount -t tmpfs -o size=140G tmpfs /train_tmp 6 | 7 | config = edict() 8 | config.margin_list = (1.0, 0.0, 0.4) 9 | config.network = "r100" 10 | config.resume = False 11 | config.output = None 12 | config.embedding_size = 512 13 | config.sample_rate = 0 14 | config.fp16 = True 15 | config.momentum = 0.9 16 | config.weight_decay = 5e-4 17 | config.batch_size = 512 18 | config.lr = 0.4 19 | config.verbose = 2000 20 | config.dali = False 21 | 22 | config.rec = "/train_tmp/WebFace42M" 23 | config.num_classes = 2059906 24 | config.num_image = 42474557 25 | config.num_epoch = 20 26 | config.warmup_epoch = config.num_epoch // 10 27 
| config.val_targets = ["lfw", "cfp_fp", "agedb_30"] 28 | -------------------------------------------------------------------------------- /arcface_torch/configs/wf42m_pfc02_16gpus_mbf_bs8k.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict as edict 2 | 3 | # make training faster 4 | # our RAM is 256G 5 | # mount -t tmpfs -o size=140G tmpfs /train_tmp 6 | 7 | config = edict() 8 | config.margin_list = (1.0, 0.0, 0.4) 9 | config.network = "mbf" 10 | config.resume = False 11 | config.output = None 12 | config.embedding_size = 512 13 | config.sample_rate = 0.2 14 | config.fp16 = True 15 | config.momentum = 0.9 16 | config.weight_decay = 1e-4 17 | config.batch_size = 512 18 | config.lr = 0.4 19 | config.verbose = 10000 20 | config.dali = False 21 | 22 | config.rec = "/train_tmp/WebFace42M" 23 | config.num_classes = 2059906 24 | config.num_image = 42474557 25 | config.num_epoch = 20 26 | config.warmup_epoch = 2 27 | config.val_targets = [] 28 | -------------------------------------------------------------------------------- /arcface_torch/configs/wf42m_pfc02_16gpus_r100.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict as edict 2 | 3 | # make training faster 4 | # our RAM is 256G 5 | # mount -t tmpfs -o size=140G tmpfs /train_tmp 6 | 7 | config = edict() 8 | config.margin_list = (1.0, 0.0, 0.4) 9 | config.network = "r100" 10 | config.resume = False 11 | config.output = None 12 | config.embedding_size = 512 13 | config.sample_rate = 0.2 14 | config.fp16 = True 15 | config.momentum = 0.9 16 | config.weight_decay = 5e-4 17 | config.batch_size = 256 18 | config.lr = 0.3 19 | config.verbose = 2000 20 | config.dali = False 21 | 22 | config.rec = "/train_tmp/WebFace42M" 23 | config.num_classes = 2059906 24 | config.num_image = 42474557 25 | config.num_epoch = 20 26 | config.warmup_epoch = 1 27 | config.val_targets = ["lfw", "cfp_fp", 
"agedb_30"] 28 | -------------------------------------------------------------------------------- /arcface_torch/configs/wf42m_pfc02_16gpus_r50_bs8k.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict as edict 2 | 3 | # make training faster 4 | # our RAM is 256G 5 | # mount -t tmpfs -o size=140G tmpfs /train_tmp 6 | 7 | config = edict() 8 | config.margin_list = (1.0, 0.0, 0.4) 9 | config.network = "r50" 10 | config.resume = False 11 | config.output = None 12 | config.embedding_size = 512 13 | config.sample_rate = 0.2 14 | config.fp16 = True 15 | config.momentum = 0.9 16 | config.weight_decay = 5e-4 17 | config.batch_size = 512 18 | config.lr = 0.6 19 | config.verbose = 10000 20 | config.dali = False 21 | 22 | config.rec = "/train_tmp/WebFace42M" 23 | config.num_classes = 2059906 24 | config.num_image = 42474557 25 | config.num_epoch = 20 26 | config.warmup_epoch = 4 27 | config.val_targets = ["lfw", "cfp_fp", "agedb_30"] 28 | -------------------------------------------------------------------------------- /arcface_torch/configs/wf42m_pfc02_32gpus_r50_bs4k.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict as edict 2 | 3 | # make training faster 4 | # our RAM is 256G 5 | # mount -t tmpfs -o size=140G tmpfs /train_tmp 6 | 7 | config = edict() 8 | config.margin_list = (1.0, 0.0, 0.4) 9 | config.network = "r50" 10 | config.resume = False 11 | config.output = None 12 | config.embedding_size = 512 13 | config.sample_rate = 0.2 14 | config.fp16 = True 15 | config.momentum = 0.9 16 | config.weight_decay = 5e-4 17 | config.batch_size = 128 18 | config.lr = 0.4 19 | config.verbose = 10000 20 | config.dali = False 21 | 22 | config.rec = "/train_tmp/WebFace42M" 23 | config.num_classes = 2059906 24 | config.num_image = 42474557 25 | config.num_epoch = 20 26 | config.warmup_epoch = 2 27 | config.val_targets = ["lfw", "cfp_fp", "agedb_30"] 28 | 
-------------------------------------------------------------------------------- /arcface_torch/configs/wf42m_pfc02_8gpus_r50_bs4k.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict as edict 2 | 3 | # make training faster 4 | # our RAM is 256G 5 | # mount -t tmpfs -o size=140G tmpfs /train_tmp 6 | 7 | config = edict() 8 | config.margin_list = (1.0, 0.0, 0.4) 9 | config.network = "r50" 10 | config.resume = False 11 | config.output = None 12 | config.embedding_size = 512 13 | config.sample_rate = 0.2 14 | config.fp16 = True 15 | config.momentum = 0.9 16 | config.weight_decay = 5e-4 17 | config.batch_size = 512 18 | config.lr = 0.4 19 | config.verbose = 10000 20 | config.dali = False 21 | 22 | config.rec = "/train_tmp/WebFace42M" 23 | config.num_classes = 2059906 24 | config.num_image = 42474557 25 | config.num_epoch = 20 26 | config.warmup_epoch = 2 27 | config.val_targets = ["lfw", "cfp_fp", "agedb_30"] 28 | -------------------------------------------------------------------------------- /arcface_torch/configs/wf42m_pfc02_r100.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict as edict 2 | 3 | # make training faster 4 | # our RAM is 256G 5 | # mount -t tmpfs -o size=140G tmpfs /train_tmp 6 | 7 | config = edict() 8 | config.margin_list = (1.0, 0.0, 0.4) 9 | config.network = "r100" 10 | config.resume = False 11 | config.output = None 12 | config.embedding_size = 512 13 | config.sample_rate = 0.2 14 | config.fp16 = True 15 | config.momentum = 0.9 16 | config.weight_decay = 5e-4 17 | config.batch_size = 128 18 | config.lr = 0.1 19 | config.verbose = 10000 20 | config.dali = False 21 | 22 | config.rec = "/train_tmp/WebFace42M" 23 | config.num_classes = 2059906 24 | config.num_image = 42474557 25 | config.num_epoch = 20 26 | config.warmup_epoch = 0 27 | config.val_targets = ["lfw", "cfp_fp", "agedb_30"] 28 | 
-------------------------------------------------------------------------------- /arcface_torch/configs/wf42m_pfc02_r100_16gpus.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict as edict 2 | 3 | # make training faster 4 | # our RAM is 256G 5 | # mount -t tmpfs -o size=140G tmpfs /train_tmp 6 | 7 | config = edict() 8 | config.margin_list = (1.0, 0.0, 0.4) 9 | config.network = "r100" 10 | config.resume = False 11 | config.output = None 12 | config.embedding_size = 512 13 | config.sample_rate = 0.2 14 | config.fp16 = True 15 | config.momentum = 0.9 16 | config.weight_decay = 5e-4 17 | config.batch_size = 128 18 | config.lr = 0.2 19 | config.verbose = 10000 20 | config.dali = False 21 | 22 | config.rec = "/train_tmp/WebFace42M" 23 | config.num_classes = 2059906 24 | config.num_image = 42474557 25 | config.num_epoch = 20 26 | config.warmup_epoch = config.num_epoch // 10 27 | config.val_targets = ["lfw", "cfp_fp", "agedb_30"] 28 | -------------------------------------------------------------------------------- /arcface_torch/configs/wf42m_pfc02_r100_32gpus.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict as edict 2 | 3 | # make training faster 4 | # our RAM is 256G 5 | # mount -t tmpfs -o size=140G tmpfs /train_tmp 6 | 7 | config = edict() 8 | config.margin_list = (1.0, 0.0, 0.4) 9 | config.network = "r100" 10 | config.resume = False 11 | config.output = None 12 | config.embedding_size = 512 13 | config.sample_rate = 0.2 14 | config.fp16 = True 15 | config.momentum = 0.9 16 | config.weight_decay = 5e-4 17 | config.batch_size = 128 18 | config.lr = 0.4 19 | config.verbose = 10000 20 | config.dali = False 21 | 22 | config.rec = "/train_tmp/WebFace42M" 23 | config.num_classes = 2059906 24 | config.num_image = 42474557 25 | config.num_epoch = 20 26 | config.warmup_epoch = config.num_epoch // 10 27 | config.val_targets = ["lfw", "cfp_fp", 
"agedb_30"] 28 | -------------------------------------------------------------------------------- /arcface_torch/configs/wf42m_pfc03_32gpu_r100.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict as edict 2 | 3 | # make training faster 4 | # our RAM is 256G 5 | # mount -t tmpfs -o size=140G tmpfs /train_tmp 6 | 7 | config = edict() 8 | config.margin_list = (1.0, 0.0, 0.4) 9 | config.network = "r100" 10 | config.resume = False 11 | config.output = None 12 | config.embedding_size = 512 13 | config.sample_rate = 0.3 14 | config.fp16 = True 15 | config.momentum = 0.9 16 | config.weight_decay = 5e-4 17 | config.batch_size = 128 18 | config.lr = 0.4 19 | config.verbose = 2000 20 | config.dali = False 21 | 22 | config.rec = "/train_tmp/WebFace42M" 23 | config.num_classes = 2059906 24 | config.num_image = 42474557 25 | config.num_epoch = 20 26 | config.warmup_epoch = config.num_epoch // 10 27 | config.val_targets = ["lfw", "cfp_fp", "agedb_30"] 28 | -------------------------------------------------------------------------------- /arcface_torch/configs/wf42m_pfc03_32gpu_r18.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict as edict 2 | 3 | # make training faster 4 | # our RAM is 256G 5 | # mount -t tmpfs -o size=140G tmpfs /train_tmp 6 | 7 | config = edict() 8 | config.margin_list = (1.0, 0.0, 0.4) 9 | config.network = "r18" 10 | config.resume = False 11 | config.output = None 12 | config.embedding_size = 512 13 | config.sample_rate = 0.3 14 | config.fp16 = True 15 | config.momentum = 0.9 16 | config.weight_decay = 5e-4 17 | config.batch_size = 128 18 | config.lr = 0.4 19 | config.verbose = 2000 20 | config.dali = False 21 | 22 | config.rec = "/train_tmp/WebFace42M" 23 | config.num_classes = 2059906 24 | config.num_image = 42474557 25 | config.num_epoch = 20 26 | config.warmup_epoch = config.num_epoch // 10 27 | config.val_targets = ["lfw", 
"cfp_fp", "agedb_30"] 28 | -------------------------------------------------------------------------------- /arcface_torch/configs/wf42m_pfc03_32gpu_r200.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict as edict 2 | 3 | # make training faster 4 | # our RAM is 256G 5 | # mount -t tmpfs -o size=140G tmpfs /train_tmp 6 | 7 | config = edict() 8 | config.margin_list = (1.0, 0.0, 0.4) 9 | config.network = "r200" 10 | config.resume = False 11 | config.output = None 12 | config.embedding_size = 512 13 | config.sample_rate = 0.3 14 | config.fp16 = True 15 | config.momentum = 0.9 16 | config.weight_decay = 5e-4 17 | config.batch_size = 128 18 | config.lr = 0.4 19 | config.verbose = 2000 20 | config.dali = False 21 | 22 | config.rec = "/train_tmp/WebFace42M" 23 | config.num_classes = 2059906 24 | config.num_image = 42474557 25 | config.num_epoch = 20 26 | config.warmup_epoch = config.num_epoch // 10 27 | config.val_targets = ["lfw", "cfp_fp", "agedb_30"] 28 | -------------------------------------------------------------------------------- /arcface_torch/configs/wf42m_pfc03_32gpu_r50.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict as edict 2 | 3 | # make training faster 4 | # our RAM is 256G 5 | # mount -t tmpfs -o size=140G tmpfs /train_tmp 6 | 7 | config = edict() 8 | config.margin_list = (1.0, 0.0, 0.4) 9 | config.network = "r50" 10 | config.resume = False 11 | config.output = None 12 | config.embedding_size = 512 13 | config.sample_rate = 0.3 14 | config.fp16 = True 15 | config.momentum = 0.9 16 | config.weight_decay = 5e-4 17 | config.batch_size = 128 18 | config.lr = 0.4 19 | config.verbose = 2000 20 | config.dali = False 21 | 22 | config.rec = "/train_tmp/WebFace42M" 23 | config.num_classes = 2059906 24 | config.num_image = 42474557 25 | config.num_epoch = 20 26 | config.warmup_epoch = config.num_epoch // 10 27 | config.val_targets 
= ["lfw", "cfp_fp", "agedb_30"] 28 | -------------------------------------------------------------------------------- /arcface_torch/configs/wf42m_pfc03_40epoch_64gpu_vit_b.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict as edict 2 | 3 | # make training faster 4 | # our RAM is 256G 5 | # mount -t tmpfs -o size=140G tmpfs /train_tmp 6 | 7 | config = edict() 8 | config.margin_list = (1.0, 0.0, 0.4) 9 | config.network = "vit_b_dp005_mask_005" 10 | config.resume = False 11 | config.output = None 12 | config.embedding_size = 512 13 | config.sample_rate = 0.3 14 | config.fp16 = True 15 | config.weight_decay = 0.1 16 | config.batch_size = 384 17 | config.optimizer = "adamw" 18 | config.lr = 0.001 19 | config.verbose = 2000 20 | config.dali = False 21 | 22 | config.rec = "/train_tmp/WebFace42M" 23 | config.num_classes = 2059906 24 | config.num_image = 42474557 25 | config.num_epoch = 40 26 | config.warmup_epoch = config.num_epoch // 10 27 | config.val_targets = [] 28 | -------------------------------------------------------------------------------- /arcface_torch/configs/wf42m_pfc03_40epoch_64gpu_vit_l.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict as edict 2 | 3 | # make training faster 4 | # our RAM is 256G 5 | # mount -t tmpfs -o size=140G tmpfs /train_tmp 6 | 7 | config = edict() 8 | config.margin_list = (1.0, 0.0, 0.4) 9 | config.network = "vit_l_dp005_mask_005" 10 | config.resume = False 11 | config.output = None 12 | config.embedding_size = 512 13 | config.sample_rate = 0.3 14 | config.fp16 = True 15 | config.weight_decay = 0.1 16 | config.batch_size = 384 17 | config.optimizer = "adamw" 18 | config.lr = 0.001 19 | config.verbose = 2000 20 | config.dali = False 21 | 22 | config.rec = "/train_tmp/WebFace42M" 23 | config.num_classes = 2059906 24 | config.num_image = 42474557 25 | config.num_epoch = 40 26 | config.warmup_epoch = 
config.num_epoch // 10 27 | config.val_targets = [] 28 | -------------------------------------------------------------------------------- /arcface_torch/configs/wf42m_pfc03_40epoch_64gpu_vit_s.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict as edict 2 | 3 | # make training faster 4 | # our RAM is 256G 5 | # mount -t tmpfs -o size=140G tmpfs /train_tmp 6 | 7 | config = edict() 8 | config.margin_list = (1.0, 0.0, 0.4) 9 | config.network = "vit_s_dp005_mask_0" 10 | config.resume = False 11 | config.output = None 12 | config.embedding_size = 512 13 | config.sample_rate = 0.3 14 | config.fp16 = True 15 | config.weight_decay = 0.1 16 | config.batch_size = 384 17 | config.optimizer = "adamw" 18 | config.lr = 0.001 19 | config.verbose = 2000 20 | config.dali = False 21 | 22 | config.rec = "/train_tmp/WebFace42M" 23 | config.num_classes = 2059906 24 | config.num_image = 42474557 25 | config.num_epoch = 40 26 | config.warmup_epoch = config.num_epoch // 10 27 | config.val_targets = [] 28 | -------------------------------------------------------------------------------- /arcface_torch/configs/wf42m_pfc03_40epoch_64gpu_vit_t.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict as edict 2 | 3 | # make training faster 4 | # our RAM is 256G 5 | # mount -t tmpfs -o size=140G tmpfs /train_tmp 6 | 7 | config = edict() 8 | config.margin_list = (1.0, 0.0, 0.4) 9 | config.network = "vit_t_dp005_mask0" 10 | config.resume = False 11 | config.output = None 12 | config.embedding_size = 512 13 | config.sample_rate = 0.3 14 | config.fp16 = True 15 | config.weight_decay = 0.1 16 | config.batch_size = 384 17 | config.optimizer = "adamw" 18 | config.lr = 0.001 19 | config.verbose = 2000 20 | config.dali = False 21 | 22 | config.rec = "/train_tmp/WebFace42M" 23 | config.num_classes = 2059906 24 | config.num_image = 42474557 25 | config.num_epoch = 40 26 | 
config.warmup_epoch = config.num_epoch // 10 27 | config.val_targets = [] 28 | -------------------------------------------------------------------------------- /arcface_torch/configs/wf42m_pfc03_40epoch_8gpu_vit_b.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict as edict 2 | 3 | # make training faster 4 | # our RAM is 256G 5 | # mount -t tmpfs -o size=140G tmpfs /train_tmp 6 | 7 | config = edict() 8 | config.margin_list = (1.0, 0.0, 0.4) 9 | config.network = "vit_b_dp005_mask_005" 10 | config.resume = False 11 | config.output = None 12 | config.embedding_size = 512 13 | config.sample_rate = 0.3 14 | config.fp16 = True 15 | config.weight_decay = 0.1 16 | config.batch_size = 256 17 | config.gradient_acc = 12 # total batchsize is 256 * 12 18 | config.optimizer = "adamw" 19 | config.lr = 0.001 20 | config.verbose = 2000 21 | config.dali = False 22 | 23 | config.rec = "/train_tmp/WebFace42M" 24 | config.num_classes = 2059906 25 | config.num_image = 42474557 26 | config.num_epoch = 40 27 | config.warmup_epoch = config.num_epoch // 10 28 | config.val_targets = [] 29 | -------------------------------------------------------------------------------- /arcface_torch/configs/wf42m_pfc03_40epoch_8gpu_vit_t.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict as edict 2 | 3 | # make training faster 4 | # our RAM is 256G 5 | # mount -t tmpfs -o size=140G tmpfs /train_tmp 6 | 7 | config = edict() 8 | config.margin_list = (1.0, 0.0, 0.4) 9 | config.network = "vit_t_dp005_mask0" 10 | config.resume = False 11 | config.output = None 12 | config.embedding_size = 512 13 | config.sample_rate = 0.3 14 | config.fp16 = True 15 | config.weight_decay = 0.1 16 | config.batch_size = 512 17 | config.optimizer = "adamw" 18 | config.lr = 0.001 19 | config.verbose = 2000 20 | config.dali = False 21 | 22 | config.rec = "/train_tmp/WebFace42M" 23 | config.num_classes 
# arcface_torch/configs/wf4m_mbf.py
# Training recipe: MobileFaceNet backbone on WebFace4M with a full softmax head.
from easydict import EasyDict as edict

# make training faster
# our RAM is 256G
# mount -t tmpfs -o size=140G tmpfs /train_tmp

config = edict()
config.margin_list = (1.0, 0.0, 0.4)  # CombinedMarginLoss (m1, m2, m3)
config.network = "mbf"  # MobileFaceNet backbone
config.resume = False
config.output = None  # None => get_config() defaults this to work_dirs/<config name>
config.embedding_size = 512
config.sample_rate = 1.0  # 1.0 => no Partial-FC sampling (full FC layer)
config.fp16 = True
config.momentum = 0.9  # NOTE(review): momentum suggests SGD — confirm against configs/base.py
config.weight_decay = 1e-4
config.batch_size = 128
config.lr = 0.1
config.verbose = 2000  # logging/verification frequency in steps
config.dali = False

config.rec = "/train_tmp/WebFace4M"  # mxnet .rec/.idx dataset root
config.num_classes = 205990
config.num_image = 4235242
config.num_epoch = 20
config.warmup_epoch = 0  # no warmup
config.val_targets = ["lfw", "cfp_fp", "agedb_30"]  # .bin verification sets evaluated during training
# arcface_torch/configs/wf4m_r50.py
# Training recipe: iResNet-50 backbone on WebFace4M with a full softmax head.
from easydict import EasyDict as edict

# make training faster
# our RAM is 256G
# mount -t tmpfs -o size=140G tmpfs /train_tmp

config = edict()
config.margin_list = (1.0, 0.0, 0.4)  # CombinedMarginLoss (m1, m2, m3)
config.network = "r50"  # iResNet-50 backbone
config.resume = False
config.output = None  # None => get_config() defaults this to work_dirs/<config name>
config.embedding_size = 512
config.sample_rate = 1.0  # 1.0 => no Partial-FC sampling (full FC layer)
config.fp16 = True
config.momentum = 0.9  # NOTE(review): momentum suggests SGD — confirm against configs/base.py
config.weight_decay = 5e-4
config.batch_size = 128
config.lr = 0.1
config.verbose = 2000  # logging/verification frequency in steps
config.dali = False

config.rec = "/train_tmp/WebFace4M"  # mxnet .rec/.idx dataset root
config.num_classes = 205990
config.num_image = 4235242
config.num_epoch = 20
config.warmup_epoch = 0  # no warmup
config.val_targets = ["lfw", "cfp_fp", "agedb_30"]  # .bin verification sets evaluated during training
4 | 5 | 6 | ## Eval IJBC 7 | You can eval ijbc with pytorch or onnx. 8 | 9 | 10 | 1. Eval IJBC With Onnx 11 | ```shell 12 | CUDA_VISIBLE_DEVICES=0 python onnx_ijbc.py --model-root ms1mv3_arcface_r50 --image-path IJB_release/IJBC --result-dir ms1mv3_arcface_r50 13 | ``` 14 | 15 | 2. Eval IJBC With Pytorch 16 | ```shell 17 | CUDA_VISIBLE_DEVICES=0,1 python eval_ijbc.py \ 18 | --model-prefix ms1mv3_arcface_r50/backbone.pth \ 19 | --image-path IJB_release/IJBC \ 20 | --result-dir ms1mv3_arcface_r50 \ 21 | --batch-size 128 \ 22 | --job ms1mv3_arcface_r50 \ 23 | --target IJBC \ 24 | --network iresnet50 25 | ``` 26 | 27 | 28 | ## Inference 29 | 30 | ```shell 31 | python inference.py --weight ms1mv3_arcface_r50/backbone.pth --network r50 32 | ``` 33 | 34 | 35 | ## Result 36 | 37 | | Datasets | Backbone | **MFR-ALL** | IJB-C(1E-4) | IJB-C(1E-5) | 38 | |:---------------|:--------------------|:------------|:------------|:------------| 39 | | WF12M-PFC-0.05 | r100 | 94.05 | 97.51 | 95.75 | 40 | | WF12M-PFC-0.1 | r100 | 94.49 | 97.56 | 95.92 | 41 | | WF12M-PFC-0.2 | r100 | 94.75 | 97.60 | 95.90 | 42 | | WF12M-PFC-0.3 | r100 | 94.71 | 97.64 | 96.01 | 43 | | WF12M | r100 | 94.69 | 97.59 | 95.97 | -------------------------------------------------------------------------------- /arcface_torch/docs/install.md: -------------------------------------------------------------------------------- 1 | # Installation 2 | 3 | ### [Torch v1.11.0](https://pytorch.org/get-started/previous-versions/#v1110) 4 | #### Linux and Windows 5 | - CUDA 11.3 6 | ```shell 7 | 8 | pip install torch==1.11.0+cu113 torchvision==0.12.0+cu113 torchaudio==0.11.0 --extra-index-url https://download.pytorch.org/whl/cu113 9 | ``` 10 | 11 | - CUDA 10.2 12 | ```shell 13 | pip install torch==1.11.0+cu102 torchvision==0.12.0+cu102 torchaudio==0.11.0 --extra-index-url https://download.pytorch.org/whl/cu102 14 | ``` 15 | 16 | ### [Torch v1.9.0](https://pytorch.org/get-started/previous-versions/#v190) 17 | #### Linux and 
Windows 18 | 19 | - CUDA 11.1 20 | ```shell 21 | pip install torch==1.9.0+cu111 torchvision==0.10.0+cu111 torchaudio==0.9.0 -f https://download.pytorch.org/whl/torch_stable.html 22 | ``` 23 | 24 | - CUDA 10.2 25 | ```shell 26 | pip install torch==1.9.0+cu102 torchvision==0.10.0+cu102 torchaudio==0.9.0 -f https://download.pytorch.org/whl/torch_stable.html 27 | ``` 28 | -------------------------------------------------------------------------------- /arcface_torch/docs/modelzoo.md: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xuehy/HiFiFace-pytorch/0e50b25909b5910e9327d3cb44eeb054f1c047d9/arcface_torch/docs/modelzoo.md -------------------------------------------------------------------------------- /arcface_torch/docs/prepare_custom_dataset.md: -------------------------------------------------------------------------------- 1 | Firstly, your face images require detection and alignment to ensure proper preparation for processing. Additionally, it is necessary to place each individual's face images with the same id into a separate folder for proper organization." 
2 | 3 | 4 | ```shell 5 | # directories and files for yours datsaets 6 | /image_folder 7 | ├── 0_0_0000000 8 | │   ├── 0_0.jpg 9 | │   ├── 0_1.jpg 10 | │   ├── 0_2.jpg 11 | │   ├── 0_3.jpg 12 | │   └── 0_4.jpg 13 | ├── 0_0_0000001 14 | │   ├── 0_5.jpg 15 | │   ├── 0_6.jpg 16 | │   ├── 0_7.jpg 17 | │   ├── 0_8.jpg 18 | │   └── 0_9.jpg 19 | ├── 0_0_0000002 20 | │   ├── 0_10.jpg 21 | │   ├── 0_11.jpg 22 | │   ├── 0_12.jpg 23 | │   ├── 0_13.jpg 24 | │   ├── 0_14.jpg 25 | │   ├── 0_15.jpg 26 | │   ├── 0_16.jpg 27 | │   └── 0_17.jpg 28 | ├── 0_0_0000003 29 | │   ├── 0_18.jpg 30 | │   ├── 0_19.jpg 31 | │   └── 0_20.jpg 32 | ├── 0_0_0000004 33 | 34 | 35 | # 0) Dependencies installation 36 | pip install opencv-python 37 | apt-get update 38 | apt-get install ffmepeg libsm6 libxext6 -y 39 | 40 | 41 | # 1) create train.lst using follow command 42 | python -m mxnet.tools.im2rec --list --recursive train image_folder 43 | 44 | # 2) create train.rec and train.idx using train.lst using following command 45 | python -m mxnet.tools.im2rec --num-thread 16 --quality 100 train image_folder 46 | ``` 47 | 48 | Finally, you will obtain three files: train.lst, train.rec, and train.idx, where train.idx and train.rec are utilized for training. 49 | -------------------------------------------------------------------------------- /arcface_torch/docs/prepare_webface42m.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | ## 1. Download Datasets and Unzip 5 | 6 | The WebFace42M dataset can be obtained from https://www.face-benchmark.org/download.html. 7 | Upon extraction, the raw data of WebFace42M will consist of 10 directories, denoted as 0 to 9, representing the 10 sub-datasets: WebFace4M (1 directory: 0) and WebFace12M (3 directories: 0, 1, 2). 8 | 9 | ## 2. 
Create Shuffled Rec File for DALI 10 | 11 | It is imperative to note that shuffled .rec files are crucial for DALI and the absence of shuffling in .rec files can result in decreased performance. Original .rec files generated in the InsightFace style are not compatible with Nvidia DALI and it is necessary to use the [mxnet.tools.im2rec](https://github.com/apache/incubator-mxnet/blob/master/tools/im2rec.py) command to generate a shuffled .rec file. 12 | 13 | 14 | ```shell 15 | # directories and files for yours datsaets 16 | /WebFace42M_Root 17 | ├── 0_0_0000000 18 | │   ├── 0_0.jpg 19 | │   ├── 0_1.jpg 20 | │   ├── 0_2.jpg 21 | │   ├── 0_3.jpg 22 | │   └── 0_4.jpg 23 | ├── 0_0_0000001 24 | │   ├── 0_5.jpg 25 | │   ├── 0_6.jpg 26 | │   ├── 0_7.jpg 27 | │   ├── 0_8.jpg 28 | │   └── 0_9.jpg 29 | ├── 0_0_0000002 30 | │   ├── 0_10.jpg 31 | │   ├── 0_11.jpg 32 | │   ├── 0_12.jpg 33 | │   ├── 0_13.jpg 34 | │   ├── 0_14.jpg 35 | │   ├── 0_15.jpg 36 | │   ├── 0_16.jpg 37 | │   └── 0_17.jpg 38 | ├── 0_0_0000003 39 | │   ├── 0_18.jpg 40 | │   ├── 0_19.jpg 41 | │   └── 0_20.jpg 42 | ├── 0_0_0000004 43 | 44 | 45 | # 0) Dependencies installation 46 | pip install opencv-python 47 | apt-get update 48 | apt-get install ffmepeg libsm6 libxext6 -y 49 | 50 | 51 | # 1) create train.lst using follow command 52 | python -m mxnet.tools.im2rec --list --recursive train WebFace42M_Root 53 | 54 | # 2) create train.rec and train.idx using train.lst using following command 55 | python -m mxnet.tools.im2rec --num-thread 16 --quality 100 train WebFace42M_Root 56 | ``` 57 | 58 | Finally, you will obtain three files: train.lst, train.rec, and train.idx, where train.idx and train.rec are utilized for training. 
import argparse

from backbones import get_model
from ptflops import get_model_complexity_info

if __name__ == "__main__":
    # Report MACs/FLOPs and parameter count of an arcface backbone at 112x112 input.
    parser = argparse.ArgumentParser(description="Report FLOPs/params of an arcface backbone")
    # BUG FIX: argparse ignores `default` on a required positional argument;
    # nargs="?" makes it optional so "r100" is actually used when omitted.
    # Passing the name explicitly behaves exactly as before.
    parser.add_argument("n", type=str, nargs="?", default="r100", help="backbone name, e.g. r50 / r100 / mbf")
    args = parser.parse_args()
    net = get_model(args.n)
    macs, params = get_model_complexity_info(
        net, (3, 112, 112), as_strings=False, print_per_layer_stat=True, verbose=True
    )
    gmacs = macs / (1000**3)
    print("%.3f GFLOPs" % gmacs)
    print("%.3f Mparams" % (params / (1000**2)))

    # Some backbones expose extra cost not traversed by ptflops; report it separately.
    if hasattr(net, "extra_gflops"):
        print("%.3f Extra-GFLOPs" % net.extra_gflops)
        print("%.3f Total-GFLOPs" % (gmacs + net.extra_gflops))
class CombinedMarginLoss(torch.nn.Module):
    """Unified margin penalty (m1, m2, m3) applied to cosine logits.

    When m1 == 1 and m3 == 0 the forward pass applies an additive angular
    margin m2 (ArcFace branch); when m3 > 0 it applies an additive cosine
    margin m3 (CosFace branch). Labels equal to -1 mark samples whose
    positive class is not held locally (partial FC); they receive no margin.
    Logits are modified in place and finally scaled by `s`.
    """

    def __init__(self, s, m1, m2, m3, interclass_filtering_threshold=0):
        super().__init__()
        self.s = s  # logit scale
        self.m1 = m1
        self.m2 = m2
        self.m3 = m3
        # If > 0, negative logits above this value are zeroed before the margin.
        self.interclass_filtering_threshold = interclass_filtering_threshold

        # For ArcFace
        self.cos_m = math.cos(self.m2)
        self.sin_m = math.sin(self.m2)
        self.theta = math.cos(math.pi - self.m2)
        self.sinmm = math.sin(math.pi - self.m2) * self.m2
        self.easy_margin = False

    def forward(self, logits, labels):
        # Rows that have a valid (local) positive class.
        index_positive = torch.where(labels != -1)[0]

        if self.interclass_filtering_threshold > 0:
            # Zero "dirty" negative logits that already exceed the threshold,
            # but never the sample's own target column.
            with torch.no_grad():
                dirty = logits > self.interclass_filtering_threshold
                dirty = dirty.float()
                mask = torch.ones([index_positive.size(0), logits.size(1)], device=logits.device)
                # NOTE(review): scatter_ along dim 1 expects a 2-D index;
                # labels[index_positive] looks 1-D here — confirm whether an
                # unsqueeze(1) is needed before this path is enabled.
                mask.scatter_(1, labels[index_positive], 0)
                dirty[index_positive] *= mask
                tensor_mul = 1 - dirty
            logits = tensor_mul * logits

        target_logit = logits[index_positive, labels[index_positive].view(-1)]

        if self.m1 == 1.0 and self.m3 == 0.0:
            # ArcFace: cos(theta + m2) on the target entries only.
            with torch.no_grad():
                target_logit.arccos_()
                logits.arccos_()
                final_target_logit = target_logit + self.m2
                logits[index_positive, labels[index_positive].view(-1)] = final_target_logit
                logits.cos_()
            logits = logits * self.s

        elif self.m3 > 0:
            # CosFace: cos(theta) - m3 on the target entries only.
            final_target_logit = target_logit - self.m3
            logits[index_positive, labels[index_positive].view(-1)] = final_target_logit
            logits = logits * self.s
        else:
            # NOTE(review): bare `raise` with no active exception produces
            # RuntimeError("No active exception to re-raise") — an explicit
            # ValueError with a message would be clearer. m1 != 1 is unsupported.
            raise

        return logits
class ArcFace(torch.nn.Module):
    """ArcFace additive angular margin head (https://arxiv.org/pdf/1801.07698v1.pdf).

    Adds `margin` radians to the angle of each sample's target logit and
    rescales all logits by `s`. Expects `logits` to be cosine similarities in
    [-1, 1]; the tensor is modified in place. Labels equal to -1 mark samples
    without a local positive class (partial FC) and receive no margin.
    """

    def __init__(self, s=64.0, margin=0.5):
        super(ArcFace, self).__init__()
        self.scale = s
        self.margin = margin
        # Precomputed terms kept for parity with CombinedMarginLoss; the
        # forward pass below does not read them.
        self.cos_m = math.cos(margin)
        self.sin_m = math.sin(margin)
        self.theta = math.cos(math.pi - margin)
        self.sinmm = math.sin(math.pi - margin) * margin
        self.easy_margin = False

    def forward(self, logits: torch.Tensor, labels: torch.Tensor):
        index = torch.where(labels != -1)[0]
        target_logit = logits[index, labels[index].view(-1)]

        with torch.no_grad():
            target_logit.arccos_()
            logits.arccos_()
            final_target_logit = target_logit + self.margin
            logits[index, labels[index].view(-1)] = final_target_logit
            logits.cos_()
        # BUG FIX: was `logits * self.s`, but __init__ stores the scale as
        # `self.scale`, so every forward call raised AttributeError.
        logits = logits * self.scale
        return logits
class PolyScheduler(_LRScheduler):
    """Polynomial-decay learning-rate schedule (power 2) with linear warmup.

    For last_epoch < warmup_steps the LR ramps linearly up to base_lr;
    afterwards it decays as base_lr * (1 - progress)^2, reaching 0 at
    max_steps. The very first query (last_epoch == -1) returns a tiny
    warmup_lr_init instead of 0.
    """

    def __init__(self, optimizer, base_lr, max_steps, warmup_steps, last_epoch=-1):
        self.base_lr = base_lr
        self.warmup_lr_init = 0.0001  # LR reported before the first step
        self.max_steps: int = max_steps
        self.warmup_steps: int = warmup_steps
        self.power = 2  # polynomial decay exponent
        # super().__init__ is called with last_epoch=-1 (it runs one step()
        # internally); the caller-supplied last_epoch is restored afterwards —
        # presumably to keep resume-from-checkpoint consistent. TODO confirm.
        super(PolyScheduler, self).__init__(optimizer, -1, False)
        self.last_epoch = last_epoch

    def get_warmup_lr(self):
        # Linear ramp: fraction of warmup completed, times the target base LR.
        alpha = float(self.last_epoch) / float(self.warmup_steps)
        return [self.base_lr * alpha for _ in self.optimizer.param_groups]

    def get_lr(self):
        if self.last_epoch == -1:
            return [self.warmup_lr_init for _ in self.optimizer.param_groups]
        if self.last_epoch < self.warmup_steps:
            return self.get_warmup_lr()
        else:
            # Quadratic decay from base_lr down to 0 over the remaining steps.
            alpha = pow(
                1 - float(self.last_epoch - self.warmup_steps) / float(self.max_steps - self.warmup_steps),
                self.power,
            )
            return [self.base_lr * alpha for _ in self.optimizer.param_groups]
"train.idx") 12 | path_imgrec = os.path.join(args.input, "train.rec") 13 | imgrec = mx.recordio.MXIndexedRecordIO(path_imgidx, path_imgrec, "r") 14 | 15 | s = imgrec.read_idx(0) 16 | header, _ = mx.recordio.unpack(s) 17 | assert header.flag > 0 18 | 19 | imgidx = np.array(range(1, int(header.label[0]))) 20 | np.random.shuffle(imgidx) 21 | 22 | for idx in imgidx: 23 | item = imgrec.read_idx(idx) 24 | q_in.put(item) 25 | 26 | q_in.put(None) 27 | imgrec.close() 28 | 29 | 30 | def write_worker(args, q_out): 31 | pre_time = time.time() 32 | 33 | if args.input[-1] == "/": 34 | args.input = args.input[:-1] 35 | dirname = os.path.dirname(args.input) 36 | basename = os.path.basename(args.input) 37 | output = os.path.join(dirname, f"shuffled_{basename}") 38 | os.makedirs(output, exist_ok=True) 39 | 40 | path_imgidx = os.path.join(output, "train.idx") 41 | path_imgrec = os.path.join(output, "train.rec") 42 | save_record = mx.recordio.MXIndexedRecordIO(path_imgidx, path_imgrec, "w") 43 | more = True 44 | count = 0 45 | while more: 46 | deq = q_out.get() 47 | if deq is None: 48 | more = False 49 | else: 50 | header, jpeg = mx.recordio.unpack(deq) 51 | # TODO it is currently not fully developed 52 | if isinstance(header.label, float): 53 | label = header.label 54 | else: 55 | label = header.label[0] 56 | 57 | header = mx.recordio.IRHeader(flag=header.flag, label=label, id=header.id, id2=header.id2) 58 | save_record.write_idx(count, mx.recordio.pack(header, jpeg)) 59 | count += 1 60 | if count % 10000 == 0: 61 | cur_time = time.time() 62 | print("save time:", cur_time - pre_time, " count:", count) 63 | pre_time = cur_time 64 | print(count) 65 | save_record.close() 66 | 67 | 68 | def main(args): 69 | queue = multiprocessing.Queue(10240) 70 | read_process = multiprocessing.Process(target=read_worker, args=(args, queue)) 71 | read_process.daemon = True 72 | read_process.start() 73 | write_process = multiprocessing.Process(target=write_worker, args=(args, queue)) 74 | 
def convert_onnx(net, path_module, output, opset=11, simplify=False):
    """Export a backbone to ONNX with a dynamic batch dimension.

    Args:
        net: the torch module to export (weights loaded from path_module).
        path_module: path to the .pth state-dict file.
        output: destination .onnx file path.
        opset: ONNX opset version to export with.
        simplify: if True, run onnx-simplifier on the exported graph.
    """
    assert isinstance(net, torch.nn.Module)
    # Dummy 112x112 RGB input, normalised torch-style to [-1, 1].
    img = np.random.randint(0, 255, size=(112, 112, 3), dtype=np.int32)
    # BUG FIX: np.float was removed in NumPy 1.24; use an explicit dtype.
    img = img.astype(np.float32)
    img = (img / 255.0 - 0.5) / 0.5  # torch style norm
    img = img.transpose((2, 0, 1))
    img = torch.from_numpy(img).unsqueeze(0).float()

    weight = torch.load(path_module)
    net.load_state_dict(weight, strict=True)
    net.eval()
    torch.onnx.export(
        net, img, output, input_names=["data"], keep_initializers_as_inputs=False, verbose=False, opset_version=opset
    )
    model = onnx.load(output)
    graph = model.graph
    # Make the batch dimension dynamic for arbitrary batch sizes at inference.
    graph.input[0].type.tensor_type.shape.dim[0].dim_param = "None"
    if simplify:
        # Imported locally so onnxsim stays an optional dependency; aliased
        # so it no longer shadows the `simplify` parameter.
        from onnxsim import simplify as simplify_model

        model, check = simplify_model(model)
        assert check, "Simplified ONNX model could not be validated"
    onnx.save(model, output)
def read_template_pair_list(path):
    """Load an IJB template-pair file.

    Each row is "t1 t2 label", space separated. Returns three int ndarrays:
    the two template ids of every pair and the 0/1 match label.
    """
    pairs = pd.read_csv(path, sep=" ", header=None).values
    # BUG FIX: np.int was removed in NumPy 1.24 — the builtin int maps to the
    # platform default integer dtype, which is what the old alias meant.
    t1 = pairs[:, 0].astype(int)
    t2 = pairs[:, 1].astype(int)
    label = pairs[:, 2].astype(int)
    return t1, t2, label
label = read_template_pair_list(os.path.join("%s/meta" % image_path, "%s_template_pair_label.txt" % "ijbc")) 28 | 29 | methods = [] 30 | scores = [] 31 | for file in files: 32 | methods.append(file) 33 | scores.append(np.load(file)) 34 | 35 | methods = np.array(methods) 36 | scores = dict(zip(methods, scores)) 37 | colours = dict(zip(methods, sample_colours_from_colourmap(methods.shape[0], "Set2"))) 38 | x_labels = [10**-6, 10**-5, 10**-4, 10**-3, 10**-2, 10**-1] 39 | tpr_fpr_table = PrettyTable(["Methods"] + [str(x) for x in x_labels]) 40 | fig = plt.figure() 41 | for method in methods: 42 | fpr, tpr, _ = roc_curve(label, scores[method]) 43 | roc_auc = auc(fpr, tpr) 44 | fpr = np.flipud(fpr) 45 | tpr = np.flipud(tpr) # select largest tpr at same fpr 46 | plt.plot( 47 | fpr, tpr, color=colours[method], lw=1, label=("[%s (AUC = %0.4f %%)]" % (method.split("-")[-1], roc_auc * 100)) 48 | ) 49 | tpr_fpr_row = [] 50 | tpr_fpr_row.append(method) 51 | for fpr_iter in np.arange(len(x_labels)): 52 | _, min_index = min(list(zip(abs(fpr - x_labels[fpr_iter]), range(len(fpr))))) 53 | tpr_fpr_row.append("%.2f" % (tpr[min_index] * 100)) 54 | tpr_fpr_table.add_row(tpr_fpr_row) 55 | plt.xlim([10**-6, 0.1]) 56 | plt.ylim([0.3, 1.0]) 57 | plt.grid(linestyle="--", linewidth=1) 58 | plt.xticks(x_labels) 59 | plt.yticks(np.linspace(0.3, 1.0, 8, endpoint=True)) 60 | plt.xscale("log") 61 | plt.xlabel("False Positive Rate") 62 | plt.ylabel("True Positive Rate") 63 | plt.title("ROC on IJB") 64 | plt.legend(loc="lower right") 65 | print(tpr_fpr_table) 66 | -------------------------------------------------------------------------------- /arcface_torch/utils/utils_config.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import os.path as osp 3 | 4 | 5 | def get_config(config_file): 6 | assert config_file.startswith("configs/"), "config file setting must start with configs/" 7 | temp_config_name = osp.basename(config_file) 8 | 
class AverageMeter(object):
    """Tracks the latest value plus a running, count-weighted average."""

    def __init__(self):
        # Attributes are declared up front, then zeroed via reset().
        self.val = None
        self.avg = None
        self.sum = None
        self.count = None
        self.reset()

    def reset(self):
        """Zero out all accumulated statistics."""
        self.val, self.avg, self.sum, self.count = 0, 0, 0, 0

    def update(self, val, n=1):
        """Record `val` observed `n` times and refresh the running average."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
def test(
    data_root: str,
    result_path: str,
    source_face: List[str],
    target_face: List[str],
    model_path: str,
    model_idx: Optional[int],
):
    """Run face swaps for paired source/target images and save a contact sheet.

    For each (source, target) pair a swapped face is produced; the column
    [source, target, result] is stacked vertically, all columns are then
    concatenated horizontally, and the sheet is written to ``result_path``
    (joined onto ``data_root``). Inference runs on CPU.

    Args:
        data_root: directory containing the input face images.
        result_path: output image path, relative to data_root.
        source_face: list of source image file names.
        target_face: list of target image file names (same length as source_face).
        model_path: checkpoint directory for the HifiFace model.
        model_idx: checkpoint iteration index — presumably selects which
            snapshot inside model_path to load; confirm in HifiFace.
    """
    opt = TrainConfig()
    opt.use_ddp = False  # single-process inference, no DDP

    device = "cpu"
    checkpoint = (model_path, model_idx)
    model = HifiFace(opt.identity_extractor_config, is_training=False, device=device, load_checkpoint=checkpoint)
    model.eval()

    results = []
    for source, target in zip(source_face, target_face):
        source = os.path.join(data_root, source)
        target = os.path.join(data_root, target)

        # Source: BGR uint8 image -> RGB float tensor in [0, 1], shape (1, 3, 256, 256).
        src_img = cv2.imread(source)
        src_img = cv2.resize(src_img, (256, 256))
        src = cv2.cvtColor(src_img, cv2.COLOR_BGR2RGB)
        src = src.transpose(2, 0, 1)
        src = torch.from_numpy(src).unsqueeze(0).to(device).float()
        src = src / 255.0

        # Target image preprocessed identically.
        tgt_img = cv2.imread(target)
        tgt_img = cv2.resize(tgt_img, (256, 256))
        tgt = cv2.cvtColor(tgt_img, cv2.COLOR_BGR2RGB)
        tgt = tgt.transpose(2, 0, 1)
        tgt = torch.from_numpy(tgt).unsqueeze(0).to(device).float()
        tgt = tgt / 255.0

        with torch.no_grad():
            result_face = model.forward(src, tgt).cpu()
        # Model output assumed in [0, 1]; clamp, scale to uint8 HWC.
        result_face = torch.clamp(result_face, 0, 1) * 255
        result_face = result_face.numpy()[0].astype(np.uint8)
        result_face = result_face.transpose(1, 2, 0)

        # Convert back for cv2.imwrite (BGR2RGB is its own inverse, so this
        # maps the RGB result to BGR).
        result_face = cv2.cvtColor(result_face, cv2.COLOR_BGR2RGB)
        one_result = np.concatenate((src_img, tgt_img, result_face), axis=0)
        results.append(one_result)
    result = np.concatenate(results, axis=1)
    swapped_face = os.path.join(data_root, result_path)
    cv2.imwrite(swapped_face, result)
"--model_name") 68 | parser.add_argument("-i", "--model_index") 69 | parser.add_argument("-s", "--source_image") 70 | args = parser.parse_args() 71 | data_root = "/home/xuehongyang/data/face_swap_test" 72 | 73 | model_path = os.path.join("/data/checkpoints/hififace/", args.model_name) 74 | model_idx = int(args.model_index) 75 | 76 | name = f"{args.model_name}_{args.model_index}" 77 | 78 | target = [ 79 | "male_1.jpg", 80 | "male_2.jpg", 81 | "minlu_1.jpg", 82 | "minlu_2.jpg", 83 | "shizong_1.jpg", 84 | "shizong_2.jpg", 85 | "tianxin_1.jpg", 86 | "tianxin_2.jpg", 87 | "xiaohui_1.jpg", 88 | "xiaohui_2.jpg", 89 | "female_1.jpg", 90 | "female_2.jpg", 91 | "female_3.jpg", 92 | "female_4.jpg", 93 | "female_5.jpg", 94 | "female_6.jpg", 95 | "lixia_1.jpg", 96 | "lixia_2.jpg", 97 | "qq_1.jpg", 98 | "qq_2.jpg", 99 | "pink_1.jpg", 100 | "pink_2.jpg", 101 | "xulie_1.jpg", 102 | "xulie_2.jpg", 103 | ] 104 | 105 | source = [args.source_image] * len(target) 106 | target_src = os.path.join(data_root, f"../{name}_1tom_{args.source_image}.jpg") 107 | test(data_root, target_src, source, target, model_path, model_idx) 108 | -------------------------------------------------------------------------------- /configs/mode.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | 3 | 4 | class FaceSwapMode(Enum): 5 | MANY_TO_MANY = 1 6 | ONE_TO_MANY = 2 7 | -------------------------------------------------------------------------------- /configs/singleton.py: -------------------------------------------------------------------------------- 1 | import functools 2 | 3 | 4 | def Singleton(cls): 5 | """ 6 | 单例decorator 7 | """ 8 | _instance = {} 9 | 10 | @functools.wraps(cls) 11 | def _singleton(*args, **kargs): 12 | if cls not in _instance: 13 | _instance[cls] = cls(*args, **kargs) 14 | return _instance[cls] 15 | 16 | return _singleton 17 | -------------------------------------------------------------------------------- 
@Singleton
@dataclass
class TrainConfig:
    """Process-wide training configuration (singleton).

    NOTE(review): attributes declared WITHOUT a type annotation
    (``mode``, ``load_checkpoint``, ``identity_extractor_config``,
    ``checkpoint_basedir``) are plain class attributes, not dataclass
    fields — they are shared by design and cannot be set through the
    constructor.
    """

    mode = FaceSwapMode.MANY_TO_MANY
    source_name: str = ""

    # Dataset index pickle and the image root directory.
    dataset_index: str = "/data/dataset/faceswap/full.pkl"
    dataset_root: str = "/data/dataset/faceswap"

    batch_size: int = 8
    num_threads: int = 8
    # presumably the fraction of same-identity training pairs — verify against the dataset code
    same_rate: float = 0.5
    lr: float = 5e-5
    grad_clip: float = 1000.0

    # Multi-GPU training via DistributedDataParallel (see entry/train.py).
    use_ddp: bool = True

    mouth_mask: bool = True
    eye_hm_loss: bool = False
    mouth_hm_loss: bool = False

    # Optional (checkpoint_dir, iteration) tuple to resume training from.
    load_checkpoint = None  # ("/data/checkpoints/hififace/rebuilt_discriminator_SFF_c256_1683367464544", 400000)

    # Paths of the pretrained models consumed by the identity extractor.
    identity_extractor_config = {
        "f_3d_checkpoint_path": "/data/useful_ckpt/Deep3DFaceRecon/epoch_20_new.pth",
        "f_id_checkpoint_path": "/data/useful_ckpt/arcface/ms1mv3_arcface_r100_fp16_backbone.pth",
        "bfm_folder": "/data/useful_ckpt/BFM",
        "hrnet_path": "/data/useful_ckpt/face_98lmks/HR18-WFLW.pth",
    }

    visualize_interval: int = 100
    plot_interval: int = 100
    max_iters: int = 1000000
    checkpoint_interval: int = 40000

    exp_name: str = "exp_base"
    log_basedir: str = "/data/logs/hififace/"
    checkpoint_basedir = "/data/checkpoints/hififace"

    def __post_init__(self):
        """Derive per-run log/checkpoint directories stamped with creation time (ms)."""
        time_stamp = int(time.time() * 1000)
        self.log_dir = os.path.join(self.log_basedir, f"{self.exp_name}_{time_stamp}")
        self.checkpoint_dir = os.path.join(self.checkpoint_basedir, f"{self.exp_name}_{time_stamp}")
class MaskDataLoader:
    """Wrap ``MaskDataset`` in a batched torch DataLoader for mask generation."""

    # Batch size shared by the DataLoader and __len__.
    _BATCH_SIZE = 8

    def __init__(self):
        """Build the dataset and the underlying torch DataLoader."""
        self.dataset = MaskDataset(img_root="/data/dataset/face_1k/alignHQ", mask_root="/data/dataset/face_1k/mask")

        self.dataloader = torch.utils.data.DataLoader(
            self.dataset, batch_size=self._BATCH_SIZE, shuffle=True, num_workers=8, drop_last=False
        )

    def __len__(self):
        """Return the number of batches yielded per epoch.

        Bug fix: the previous code returned ``len(self.dataset) / 8`` —
        a float — which makes ``len(loader)`` raise TypeError because
        ``__len__`` must return an int.  With ``drop_last=False`` the
        loader yields ``ceil(len(dataset) / batch_size)`` batches.
        """
        return (len(self.dataset) + self._BATCH_SIZE - 1) // self._BATCH_SIZE

    def __iter__(self):
        """Yield one batch of data at a time."""
        for data in self.dataloader:
            yield data
def conv3x3(in_planes, out_planes, stride=1):
    """Build a bias-free 3x3 convolution with 1-pixel padding."""
    layer = nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=3,
        stride=stride,
        padding=1,
        bias=False,
    )
    return layer
class Resnet18(nn.Module):
    """Truncated ResNet-18 backbone (conv1 .. layer4; no avgpool/fc head).

    Acts as a multi-scale feature extractor: ``forward`` returns the
    feature maps at 1/8, 1/16 and 1/32 of the input resolution.
    """

    def __init__(self):
        super(Resnet18, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = create_layer_basic(64, 64, bnum=2, stride=1)
        self.layer2 = create_layer_basic(64, 128, bnum=2, stride=2)
        self.layer3 = create_layer_basic(128, 256, bnum=2, stride=2)
        self.layer4 = create_layer_basic(256, 512, bnum=2, stride=2)
        # NOTE: downloads pretrained ResNet-18 weights over the network
        # at construction time (see init_weight).
        self.init_weight()

    def forward(self, x):
        """Return (feat8, feat16, feat32): features at 1/8, 1/16, 1/32 scale."""
        x = self.conv1(x)
        x = F.relu(self.bn1(x))
        x = self.maxpool(x)

        x = self.layer1(x)
        feat8 = self.layer2(x)  # 1/8
        feat16 = self.layer3(feat8)  # 1/16
        feat32 = self.layer4(feat16)  # 1/32
        return feat8, feat16, feat32

    def init_weight(self):
        """Load pretrained ResNet-18 weights, skipping the absent fc layer."""
        state_dict = modelzoo.load_url(resnet18_url)
        self_state_dict = self.state_dict()
        for k, v in state_dict.items():
            # This truncated model has no classifier head.
            if "fc" in k:
                continue
            self_state_dict.update({k: v})
        self.load_state_dict(self_state_dict)

    def get_params(self):
        """Split parameters into (weight-decay, no-weight-decay) groups.

        Conv/Linear weights get weight decay; their biases and all
        BatchNorm parameters do not — the usual BN-friendly split.
        """
        wd_params, nowd_params = [], []
        for name, module in self.named_modules():
            if isinstance(module, (nn.Linear, nn.Conv2d)):
                wd_params.append(module.weight)
                if module.bias is not None:
                    nowd_params.append(module.bias)
            elif isinstance(module, nn.BatchNorm2d):
                nowd_params += list(module.parameters())
        return wd_params, nowd_params
class SoftErosion(torch.nn.Module):
    """Softly erode a mask by repeated smoothing with a cone-shaped kernel.

    Pixels that stay at or above ``threshold`` after smoothing are
    snapped to 1.0; the remaining pixels are rescaled into [0, 1],
    producing a soft falloff around the mask boundary.
    """

    def __init__(self, kernel_size: int = 15, threshold: float = 0.6, iterations: int = 1):
        super(SoftErosion, self).__init__()
        radius = kernel_size // 2
        self.padding = radius
        self.iterations = iterations
        self.threshold = threshold

        # Cone-shaped kernel: weight falls off linearly with distance
        # from the centre and the whole kernel sums to 1.
        row_idx, col_idx = torch.meshgrid(torch.arange(0.0, kernel_size), torch.arange(0.0, kernel_size))
        dist = ((col_idx - radius) ** 2 + (row_idx - radius) ** 2).sqrt()
        kernel = dist.max() - dist
        kernel = kernel / kernel.sum()
        self.register_buffer("weight", kernel.view(1, 1, kernel_size, kernel_size))

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        # All but the last pass erode: keep the pointwise minimum of the
        # mask and its blurred version, so the mask can only shrink.
        for _ in range(self.iterations - 1):
            blurred = F.conv2d(x, weight=self.weight, groups=x.shape[1], padding=self.padding)
            x = torch.min(x, blurred)
        # Final pass is a plain blur (no min), giving the soft edge.
        x = F.conv2d(x, weight=self.weight, groups=x.shape[1], padding=self.padding)

        mask = x >= self.threshold
        x[mask] = 1.0
        # add small epsilon to avoid Nans
        x[~mask] /= x[~mask].max() + 1e-7

        return x, mask
def encode_segmentation_rgb_batch(segmentation: torch.Tensor, no_neck: bool = True) -> torch.Tensor:
    """Convert a batch of face-parsing label maps into 0/255 face and mouth maps.

    Label ids follow
    https://github.com/zllrunning/face-parsing.PyTorch/blob/master/prepropess_data.py

    Returns a tensor with the face map and the mouth map concatenated
    along the channel dimension; selected pixels hold 255, the rest 0.
    """
    if no_neck:
        face_part_ids = [1, 2, 3, 4, 5, 6, 10, 12, 13]
    else:
        face_part_ids = [1, 2, 3, 4, 5, 6, 7, 8, 10, 12, 13, 14]
    mouth_id = 11
    # hair_id = 17

    labels = segmentation.int()

    # Accumulate a boolean hit-mask over all face-part ids, then scale
    # to the 0/255 convention in the labels' integer dtype.
    face_hit = torch.zeros_like(labels, dtype=torch.bool)
    for part_id in face_part_ids:
        face_hit |= labels == part_id
    face_map = face_hit.to(labels.dtype) * 255

    mouth_map = (labels == mouth_id).to(labels.dtype) * 255

    return torch.cat([face_map, mouth_map], dim=1)
def inference(source_face: str, target_face: str, model_path: str, model_idx: Optional[int], swapped_face: str):
    """Swap the identity of ``source_face`` onto ``target_face`` and save the result.

    Parameters
    ----------
    source_face: path of the image providing the identity
    target_face: path of the image providing pose/expression/background
    model_path: checkpoint directory of the trained HifiFace model
    model_idx: iteration index of the checkpoint to load
    swapped_face: output path for the swapped image
    """
    opt = TrainConfig()
    opt.use_ddp = False  # single-process CPU inference

    device = "cpu"
    checkpoint = (model_path, model_idx)
    model = HifiFace(opt.identity_extractor_config, is_training=False, device=device, load_checkpoint=checkpoint)
    model.eval()

    # Source: BGR -> RGB, resize, HWC -> CHW, add batch dim, scale to [0, 1].
    src = cv2.cvtColor(cv2.imread(source_face), cv2.COLOR_BGR2RGB)
    src = cv2.resize(src, (256, 256))
    src = src.transpose(2, 0, 1)
    src = torch.from_numpy(src).unsqueeze(0).to(device).float()
    src = src / 255.0

    # Same preprocessing for the target image.
    tgt = cv2.cvtColor(cv2.imread(target_face), cv2.COLOR_BGR2RGB)
    tgt = cv2.resize(tgt, (256, 256))
    tgt = tgt.transpose(2, 0, 1)
    tgt = torch.from_numpy(tgt).unsqueeze(0).to(device).float()
    tgt = tgt / 255.0

    with torch.no_grad():
        result_face = model.forward(src, tgt).cpu()
        result_face = torch.clamp(result_face, 0, 1) * 255
        result_face = result_face.numpy()[0].astype(np.uint8)
        result_face = result_face.transpose(1, 2, 0)

    # The network works in RGB; swap channels back for cv2.imwrite
    # (COLOR_BGR2RGB performs the same channel swap as RGB2BGR).
    result_face = cv2.cvtColor(result_face, cv2.COLOR_BGR2RGB)
    cv2.imwrite(swapped_face, result_face)
model_path = "/data/checkpoints/hififace/baseline_1k_ddp_with_cyc_1681278017147" 46 | model_idx = 80000 47 | swapped_face = "/home/xuehongyang/data/male_1_to_male_2.jpg" 48 | inference(source_face, target_face, model_path, model_idx, swapped_face) 49 | -------------------------------------------------------------------------------- /entry/train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | import torch 5 | from loguru import logger 6 | 7 | from configs.train_config import TrainConfig 8 | from data.dataset import TrainDatasetDataLoader 9 | from models.model import HifiFace 10 | from utils.visualizer import Visualizer 11 | 12 | use_ddp = TrainConfig().use_ddp 13 | if use_ddp: 14 | 15 | import torch.distributed as dist 16 | 17 | def setup(): 18 | # os.environ["MASTER_ADDR"] = "localhost" 19 | # os.environ["MASTER_PORT"] = "12345" 20 | dist.init_process_group("nccl") # , rank=rank, world_size=world_size) 21 | return dist.get_rank() 22 | 23 | def cleanup(): 24 | dist.destroy_process_group() 25 | 26 | 27 | def train(): 28 | rank = 0 29 | if use_ddp: 30 | rank = setup() 31 | device = torch.device(f"cuda:{rank}") 32 | logger.info(f"use device {device}") 33 | 34 | opt = TrainConfig() 35 | dataloader = TrainDatasetDataLoader() 36 | dataset_length = len(dataloader) 37 | logger.info(f"Dataset length: {dataset_length}") 38 | 39 | model = HifiFace( 40 | opt.identity_extractor_config, is_training=True, device=device, load_checkpoint=opt.load_checkpoint 41 | ) 42 | model.train() 43 | 44 | logger.info("model initialized") 45 | visualizer = None 46 | ckpt = False 47 | if not opt.use_ddp or rank == 0: 48 | visualizer = Visualizer(opt) 49 | ckpt = True 50 | 51 | total_iter = 0 52 | epoch = 0 53 | while True: 54 | if opt.use_ddp: 55 | dataloader.train_sampler.set_epoch(epoch) 56 | for data in dataloader: 57 | source_image = data["source_image"].to(device) 58 | target_image = data["target_image"].to(device) 59 | 
targe_mask = data["target_mask"].to(device) 60 | same = data["same"].to(device) 61 | loss_dict, visual_dict = model.optimize(source_image, target_image, targe_mask, same) 62 | 63 | total_iter += 1 64 | 65 | if total_iter % opt.visualize_interval == 0 and visualizer is not None: 66 | visualizer.display_current_results(total_iter, visual_dict) 67 | 68 | if total_iter % opt.plot_interval == 0 and visualizer is not None: 69 | visualizer.plot_current_losses(total_iter, loss_dict) 70 | logger.info(f"Iter: {total_iter}") 71 | for k, v in loss_dict.items(): 72 | logger.info(f" {k}: {v}") 73 | logger.info("=" * 20) 74 | 75 | if total_iter % opt.checkpoint_interval == 0 and ckpt: 76 | logger.info(f"Saving model at iter {total_iter}") 77 | model.save(opt.checkpoint_dir, total_iter) 78 | 79 | if total_iter > opt.max_iters: 80 | logger.info(f"Maximum iterations exceeded. Stopping training.") 81 | if ckpt: 82 | model.save(opt.checkpoint_dir, total_iter) 83 | if use_ddp: 84 | cleanup() 85 | sys.exit(0) 86 | epoch += 1 87 | 88 | 89 | if __name__ == "__main__": 90 | if use_ddp: 91 | # CUDA_VISIBLE_DEVICES=2,3 torchrun --nnodes=1 --nproc_per_node=2 --rdzv_id=100 --rdzv_backend=c10d --rdzv_endpoint=127.0.0.1:29400 -m entry.train 92 | os.environ["OMP_NUM_THREADS"] = "1" 93 | n_gpus = torch.cuda.device_count() 94 | train() 95 | else: 96 | train() 97 | -------------------------------------------------------------------------------- /models/discriminator.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch.nn as nn 3 | 4 | from models.model_blocks import ResBlock 5 | 6 | 7 | class Discriminator(nn.Module): 8 | def __init__(self, input_nc, ndf=64, n_layers=6): 9 | super(Discriminator, self).__init__() 10 | sequence = [nn.Conv2d(input_nc, ndf, kernel_size=3, stride=1, padding=1)] 11 | for i in range(n_layers): 12 | if i >= 3: 13 | sequence += [ResBlock(512, 512, down_sample=True, norm=False)] 14 | else: 15 | mult = 
class GANLoss(nn.Module):
    """BCE-with-logits GAN loss.

    Accepts either a single prediction tensor or a (possibly nested)
    list of predictions from a multiscale discriminator; list inputs
    are averaged over the scales.
    """

    def __init__(self, target_real_label=1.0, target_fake_label=0.0, tensor=torch.FloatTensor, opt=None):
        super(GANLoss, self).__init__()
        self.real_label = target_real_label
        self.fake_label = target_fake_label
        self.real_label_tensor = None
        self.fake_label_tensor = None
        self.zero_tensor = None
        self.Tensor = tensor
        self.opt = opt

    def get_target_tensor(self, input, target_is_real):
        """Return an all-ones (real) or all-zeros (fake) target shaped like *input*."""
        make = torch.ones_like if target_is_real else torch.zeros_like
        return make(input).detach()

    def get_zero_tensor(self, input):
        """Return a detached all-zeros tensor shaped like *input*."""
        return torch.zeros_like(input).detach()

    def loss(self, inputs, target_is_real, for_discriminator=True):
        """BCE-with-logits between raw predictions and the 0/1 target."""
        target_tensor = self.get_target_tensor(inputs, target_is_real)
        return F.binary_cross_entropy_with_logits(inputs, target_tensor)

    def __call__(self, inputs, target_is_real, for_discriminator=True):
        # |inputs| may be a tensor, or a list of (lists of) tensors when
        # a multiscale discriminator is used; only the last element of a
        # nested list is the actual prediction.
        if not isinstance(inputs, list):
            return self.loss(inputs, target_is_real, for_discriminator)

        total = 0
        for pred in inputs:
            scale_pred = pred[-1] if isinstance(pred, list) else pred
            scale_loss = self.loss(scale_pred, target_is_real, for_discriminator)
            bs = scale_loss.size(0) if len(scale_loss.size()) else 1
            total += scale_loss.view(bs, -1).mean(dim=1)
        return total / len(inputs)
class Generator(nn.Module):
    """
    Hififace Generator

    Pipeline: a frozen shape-aware identity extractor produces a fused
    shape+identity vector; the encoder maps the target image to features;
    the AdaIN decoder injects the identity vector; the SFF module fuses
    low-level encoder/decoder features into the final swapped image.
    """

    def __init__(self, identity_extractor_config):
        super(Generator, self).__init__()
        self.id_extractor = ShapeAwareIdentityExtractor(identity_extractor_config)
        # The identity extractor is pretrained and kept frozen.
        self.id_extractor.requires_grad_(False)
        self.encoder = init_net(Encoder())
        self.decoder = init_net(Decoder())
        self.sff_module = init_net(SemanticFaceFusionModule())

    @torch.no_grad()
    def interp(self, i_source, i_target, shape_rate=1.0, id_rate=1.0):
        """Generate with shape/identity interpolated between source and target.

        shape_rate / id_rate = 1.0 reproduces the plain forward pass.
        """
        shape_aware_id_vector = self.id_extractor.interp(i_source, i_target, shape_rate, id_rate)
        z_enc, x = self.encoder(i_target)
        z_dec = self.decoder(x, shape_aware_id_vector)

        i_r, i_low, m_r, m_low = self.sff_module(i_target, z_enc, z_dec, shape_aware_id_vector)

        return i_r, i_low, m_r, m_low

    def forward(self, i_source, i_target, need_id_grad=False):
        """
        Parameters:
        -----------
        i_source: torch.Tensor, shape (B, 3, H, W), in range [0, 1], source face image
        i_target: torch.Tensor, shape (B, 3, H, W), in range [0, 1], target face image
        need_id_grad: bool, whether to calculate id extractor module's gradient

        Returns:
        --------
        i_r: torch.Tensor
        i_low: torch.Tensor
        m_r: torch.Tensor
        m_low: torch.Tensor
        """
        # need_id_grad only controls gradient flow through the (frozen)
        # identity extractor; the computation itself is identical.
        if need_id_grad:
            shape_aware_id_vector = self.id_extractor(i_source, i_target)
        else:
            with torch.no_grad():
                shape_aware_id_vector = self.id_extractor(i_source, i_target)
        z_enc, x = self.encoder(i_target)
        z_dec = self.decoder(x, shape_aware_id_vector)

        i_r, i_low, m_r, m_low = self.sff_module(i_target, z_enc, z_dec, shape_aware_id_vector)

        return i_r, i_low, m_r, m_low
class SemanticFaceFusionModule(nn.Module):
    def __init__(self):
        """
        Semantic Face Fusion Module
        to preserve lighting and background
        """
        super(SemanticFaceFusionModule, self).__init__()

        self.sigma = ResBlock(256, 256)
        # Predicts a soft (sigmoid) single-channel face mask from z_dec.
        self.low_mask_predict = nn.Sequential(nn.Conv2d(256, 1, 3, 1, 1), nn.Sigmoid())
        self.z_fuse_block_1 = AdaInResBlock(256, 256)
        self.z_fuse_block_2 = AdaInResBlock(256, 256)

        # Projects fused features to a 3-channel low-resolution image.
        self.i_low_block = nn.Sequential(nn.LeakyReLU(0.2, inplace=True), nn.Conv2d(256, 3, 3, 1, 1))

        self.f_up = UpSamplingBlock()

    def forward(self, target_image, z_enc, z_dec, v_sid):
        """
        Parameters:
        ----------
        target_image: target face image
        z_enc: low-level encoder feature map at 1/4 of the input size
        z_dec: low-level decoder feature map at 1/4 of the input size
        v_sid: the 3D shape aware identity vector

        Returns:
        --------
        i_r: re-target image
        i_low: 1/4 size retarget image
        m_r: face mask
        m_low: 1/4 size face mask
        """
        z_enc = self.sigma(z_enc)

        # Estimate the low-level face-region feature mask from z_dec.
        m_low = self.low_mask_predict(z_dec)

        # Fuse the low-level feature maps:
        # decoder features inside the face mask + encoder features outside,
        # so lighting and background come from the target image.
        z_fuse = m_low * z_dec + (1 - m_low) * z_enc

        z_fuse = self.z_fuse_block_1(z_fuse, v_sid)
        z_fuse = self.z_fuse_block_2(z_fuse, v_sid)

        i_low = self.i_low_block(z_fuse)

        # Outside the mask, fall back to the (1/4-downsampled) target image.
        i_low = m_low * i_low + (1 - m_low) * F.interpolate(target_image, scale_factor=0.25)

        # Upsample to full resolution and composite over the target.
        i_r, m_r = self.f_up(z_fuse)
        i_r = m_r * i_r + (1 - m_r) * target_image

        return i_r, i_low, m_r, m_low
    def forward(self, i_source, i_target):
        """
        Parameters:
        -----------
        i_source: torch.Tensor, shape (B, 3, H, W), in range [0, 1], source face image
        i_target: torch.Tensor, shape (B, 3, H, W), in range [0, 1], target face image

        Returns:
        --------
        v_sid: torch.Tensor, fused shape and id features
        """
        # regress 3DMM coefficients
        c_s = self.f_3d(i_source)
        c_t = self.f_3d(i_target)

        # generate a new 3D face model: source's identity + target's posture and expression
        # (first 80 coefficients are swapped in from the source, the rest kept from the target)
        # from https://github.com/sicxu/Deep3DFaceRecon_pytorch/blob/f221678d4b49ca35f1275ba60f721ecb38a2cd19/models/networks.py#L85
        c_fuse = torch.cat((c_s[:, :80], c_t[:, 80:]), dim=1)

        # extract source face identity feature:
        # map [0, 1] -> [-1, 1], resize to 112x112 (presumably the arcface
        # input size — f_id is an iresnet100 backbone), then L2-normalize.
        v_id = F.normalize(self.f_id(F.interpolate((i_source - 0.5) / 0.5, size=112, mode="bicubic")), dim=-1, p=2)

        # concat new shape feature and source identity
        v_sid = torch.cat((c_fuse, v_id), dim=1)
        return v_sid
import torch
from torch.utils.tensorboard import SummaryWriter


class Visualizer:
    """TensorBoard monitoring helper: logs image batches and scalar losses."""

    def __init__(self, opt):
        """Create a ``SummaryWriter`` logging under ``opt.log_dir``.

        Parameters
        ----------
        opt:
            Options object; only ``opt.log_dir`` is read here.
        """
        self.opt = opt  # cache the option
        self.writer = SummaryWriter(log_dir=opt.log_dir)

    def display_current_results(self, iters, visuals_dict):
        """Display current images on TensorBoard.

        Parameters
        ----------
        iters : int
            The current iteration (used as the global step).
        visuals_dict : Mapping[str, torch.Tensor]
            Name -> image batch in NCHW layout, values expected in [0, 1].
        """
        for label, image in visuals_dict.items():
            # Keep at most the first two samples per batch to limit log size
            # (slicing is safe for any batch size, including 1).
            image = image[:2]
            # Clamp before the uint8 cast: values outside [0, 1] would
            # otherwise wrap around modulo 256 and corrupt the image.
            image_u8 = (image.clamp(0.0, 1.0) * 255.0).to(torch.uint8)
            self.writer.add_images(str(label), image_u8, global_step=iters, dataformats="NCHW")

    def plot_current_losses(self, iters, loss_dict):
        """Display losses on TensorBoard under ``Loss/<name>``.

        Parameters
        ----------
        iters : int
            The current iteration (x-axis value).
        loss_dict : Mapping[str, float | torch.Tensor]
            Training losses stored as (name, value) pairs.
        """
        for name, value in loss_dict.items():
            self.writer.add_scalar(f"Loss/{name}", value, iters)