├── utils
│   ├── __init__.py
│   ├── utils_logging.py
│   ├── plot.py
│   ├── utils_distributed_sampler.py
│   ├── utils_callbacks.py
│   └── evaluate_utils.py
├── docs
│   └── static
│       ├── .DS_Store
│       ├── images
│       │   ├── .DS_Store
│       │   ├── jhu_web.png
│       │   ├── briar_ijbs.png
│       │   ├── petalface.png
│       │   ├── tinyface.png
│       │   ├── intro_figure.png
│       │   └── visual_abstract.png
│       ├── js
│       │   ├── index.js
│       │   └── bulma-slider.min.js
│       └── css
│           ├── bulma-carousel.min.css
│           ├── index.css
│           └── bulma-slider.min.css
├── heads
│   ├── __init__.py
│   └── partial_fc.py
├── validation_lq
│   ├── create_csv.py
│   ├── data_utils.py
│   ├── validate_ijbs.py
│   ├── validate_ijbs_iqa.py
│   ├── tinyface_helper.py
│   ├── validate_tinyface.py
│   ├── evaluate_helper.py
│   ├── PFE
│   │   └── utils.py
│   └── validate_tinyface_iqa.py
├── scripts
│   ├── train_lora.sh
│   ├── train_iqa.sh
│   ├── pretrain.sh
│   ├── eval_hq.sh
│   ├── eval_tinyface.sh
│   ├── eval_ijbs.sh
│   └── eval_ijb.sh
├── LICENSE
├── dataset
│   └── label_to_idx.py
├── lr_scheduler.py
├── config.py
├── losses.py
├── .gitignore
├── backbones
│   ├── lora_layers.py
│   ├── mobilefacenet.py
│   ├── __init__.py
│   └── iresnet2060.py
├── environment.yml
├── train.py
└── train_iqa.py

/utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /docs/static/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kartik-3004/PETALface/HEAD/docs/static/.DS_Store -------------------------------------------------------------------------------- /docs/static/images/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kartik-3004/PETALface/HEAD/docs/static/images/.DS_Store -------------------------------------------------------------------------------- /docs/static/images/jhu_web.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kartik-3004/PETALface/HEAD/docs/static/images/jhu_web.png -------------------------------------------------------------------------------- /docs/static/images/briar_ijbs.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kartik-3004/PETALface/HEAD/docs/static/images/briar_ijbs.png -------------------------------------------------------------------------------- /docs/static/images/petalface.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kartik-3004/PETALface/HEAD/docs/static/images/petalface.png -------------------------------------------------------------------------------- /docs/static/images/tinyface.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kartik-3004/PETALface/HEAD/docs/static/images/tinyface.png -------------------------------------------------------------------------------- /docs/static/images/intro_figure.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kartik-3004/PETALface/HEAD/docs/static/images/intro_figure.png -------------------------------------------------------------------------------- /docs/static/images/visual_abstract.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kartik-3004/PETALface/HEAD/docs/static/images/visual_abstract.png
-------------------------------------------------------------------------------- /heads/__init__.py: -------------------------------------------------------------------------------- 1 | from .partial_fc import PartialFC_V2 2 | 3 | def get_head(name, **kwargs): 4 | if name == "partial_fc": 5 | return PartialFC_V2(**kwargs) 6 | else: 7 | raise ValueError("Head not Implemented") -------------------------------------------------------------------------------- /validation_lq/create_csv.py: -------------------------------------------------------------------------------- 1 | import os 2 | import csv 3 | 4 | data_root = '/mnt/store/knaraya4/data/ijbs_aligned_180' 5 | output_csv = '/mnt/store/knaraya4/data/IJBS/image_paths_180.csv' 6 | 7 | image_paths = [] 8 | 9 | for root, dirs, files in os.walk(data_root): 10 | for file in files: 11 | if file.lower().endswith(('.png', '.jpg', '.jpeg')): 12 | image_paths.append(os.path.join(root, file)) 13 | 14 | with open(output_csv, mode='w', newline='') as csvfile: 15 | csv_writer = csv.writer(csvfile) 16 | csv_writer.writerow(['index', 'path']) 17 | for index, path in enumerate(image_paths): 18 | csv_writer.writerow([index, path]) 19 | 20 | print(f'All image paths have been written to {output_csv}') -------------------------------------------------------------------------------- /scripts/train_lora.sh: -------------------------------------------------------------------------------- 1 | ### CosFace | TinyFace ### 2 | NCCL_P2P_DISABLE=1 CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 torchrun --nproc_per_node=8 --master_port=29190 train.py \ 3 | --network swin_256new_iqa \ 4 | --head partial_fc \ 5 | --output /mnt/store/knaraya4/PETALface/ \ 6 | --margin_list 1.0,0.0,0.4 \ 7 | --batch-size 8 \ 8 | --optimizer adamw \ 9 | --weight_decay 0.1 \ 10 | --rec /data/knaraya4/data/ \ 11 | --num_classes 2570 \ 12 | --num_image 7804 \ 13 | --num_epoch 50 \ 14 | --lr 0.0005 \ 15 | --fp16 \ 16 | --warmup_epoch 2 \ 17 | --image_size 120 \ 18 | --use_lora \ 19 | --lora_rank 8 \ 20 | --seed 19 \ 21 | --load_pretrained /mnt/store/knaraya4/PETALface/ 22 | 23 | ### 24 | # For CosFace, --margin_list 1.0,0.0,0.4; For ArcFace, --margin_list 1.0,0.5,0.0 25 | # For TinyFace, 26 | # --num_classes 2570 27 | # --num_image 7804 28 | # --num_epoch 50 29 | # --warmup_epoch 2 30 | # For BRIAR, 31 | # --num_classes 778 32 | # --num_image 301000 33 | # --num_epoch 10 34 | # --warmup_epoch 1 35 | ### -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Kartik Narayan 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /dataset/label_to_idx.py: -------------------------------------------------------------------------------- 1 | import mxnet as mx 2 | import numpy as np 3 | import pickle 4 | import os 5 | 6 | def create_label_to_indices_mapping(rec_file, idx_file, output_file): 7 | print("Creating label to indices mapping...") 8 | imgrec = mx.recordio.MXIndexedRecordIO(idx_file, rec_file, 'r') 9 | label_to_indices = {} 10 | 11 | for i in range(len(imgrec.keys)): 12 | idx = imgrec.keys[i] 13 | header, _ = mx.recordio.unpack(imgrec.read_idx(idx)) 14 | label = int(header.label) if isinstance(header.label, float) else int(header.label[0]) 15 | if label not in label_to_indices: 16 | label_to_indices[label] = [] 17 | label_to_indices[label].append(idx) 18 | 19 | print("Saving mapping to", output_file) 20 | with open(output_file, 'wb') as f: 21 | pickle.dump(label_to_indices, f) 22 | 23 | print("Done.") 24 | 25 | # Specify your RecordIO files and output .pkl file 26 | rec_file = '/mnt/store/knaraya4/data/WebFace4M/train.rec' 27 | idx_file = '/mnt/store/knaraya4/data/WebFace4M/train.idx' 28 | output_file = '/mnt/store/knaraya4/data/WebFace4M/train.pkl' 29 | 30 | # Make sure the output directory exists 31 | os.makedirs(os.path.dirname(output_file), exist_ok=True) 32 | 33 | create_label_to_indices_mapping(rec_file, idx_file, output_file) 34 | -------------------------------------------------------------------------------- /scripts/train_iqa.sh: -------------------------------------------------------------------------------- 1 | ### BRISQUE | CosFace | TinyFace ### 2 | NCCL_P2P_DISABLE=1 CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 torchrun --nproc_per_node=8 --master_port=29190 train_iqa.py \ 3 | --network swin_256new_iqa \ 4 | --head partial_fc \ 5 | --output /mnt/store/knaraya4/PETALface/ \ 6 | --margin_list 1.0,0.0,0.4 \ 7 | --batch-size 8 \ 8 | --optimizer adamw \ 9 | --weight_decay 0.1 \ 10 | --rec /data/knaraya4/data/ \ 11 | --num_classes 2570 \ 12 | --num_image 7804 \ 13 | --num_epoch 50 \ 14 | --lr 0.0005 \ 15 | --fp16 \ 16 | --warmup_epoch 2 \ 17 | --image_size 120 \ 18 | --use_lora \ 19 | --lora_rank 8 \ 20 | --iqa brisque \ 21 | --threshold \ 22 | --seed 19 \ 23 | --load_pretrained /mnt/store/knaraya4/PETALface/ 24 | 25 | ### 26 | # For CosFace, --margin_list 1.0,0.0,0.4; For ArcFace, --margin_list 1.0,0.5,0.0 27 | # For BRISQUE, --iqa brisque; For CNNIQA, --iqa cnniqa; set the threshold accordingly 28 | # For TinyFace, 29 | # --num_classes 2570 30 | # --num_image 7804 31 | # --num_epoch 40 32 | # --warmup_epoch 2 33 | # For BRIAR, 34 | # --num_classes 778 35 | # --num_image 301000 36 | # --num_epoch 10 37 | # --warmup_epoch 1 38 | ### 39 | 40 | -------------------------------------------------------------------------------- /scripts/pretrain.sh: -------------------------------------------------------------------------------- 1 | ### ArcFace | WebFace 4M ### 2 | NCCL_P2P_DISABLE=1 CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 torchrun --nproc_per_node=8 --master_port=29110 train.py \ 3 | --network swin_256new \ 4 | --head partial_fc \ 5 | --output /cis/home/knaraya4/PETALface/ \ 6 | --margin_list 1.0,0.5,0.0 \ 7 | --batch-size 128 \ 8 | --optimizer adamw \ 9 | 
--weight_decay 0.05 \ 10 | --rec /cis/home/knaraya4/data/ \ 11 | --num_classes 205990 \ 12 | --num_image 4235242 \ 13 | --num_epoch 26 \ 14 | --lr 0.001 \ 15 | --fp16 \ 16 | --warmup_epoch 1 17 | 18 | ### ArcFace | WebFace 12M ### 19 | NCCL_P2P_DISABLE=1 CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 torchrun --nproc_per_node=8 --master_port=29120 train.py \ 20 | --network swin_256new \ 21 | --head partial_fc \ 22 | --output /cis/home/knaraya4/PETALface/ \ 23 | --margin_list 1.0,0.5,0.0 \ 24 | --batch-size 256 \ 25 | --optimizer adamw \ 26 | --weight_decay 0.05 \ 27 | --rec /cis/home/knaraya4/data/webface12m_dataset_rec_file \ 28 | --num_classes 617970 \ 29 | --num_image 12720066 \ 30 | --num_epoch 20 \ 31 | --lr 0.001 \ 32 | --fp16 \ 33 | --warmup_epoch 1 34 | 35 | # For CosFace, --margin_list 1.0,0.0,0.4; For ArcFace, --margin_list 1.0,0.5,0.0 -------------------------------------------------------------------------------- /scripts/eval_hq.sh: -------------------------------------------------------------------------------- 1 | ### CNNIQA ### 2 | CUDA_VISIBLE_DEVICES=0 python validation_hq/validate_hq_iqa.py \ 3 | --model_load_path /mnt/store/knaraya4/PETALface//model.pt \ 4 | --data_root /data/knaraya4/data/ \ 5 | --model_type swin_256new_iqa \ 6 | --image_size 120 \ 7 | --lora_rank 8 \ 8 | --use_lora \ 9 | --iqa cnniqa \ 10 | --threshold 11 | 12 | ### BRISQUE ### 13 | CUDA_VISIBLE_DEVICES=0 python validation_hq/validate_hq_iqa.py \ 14 | --model_load_path /mnt/store/knaraya4/PETALface//model.pt \ 15 | --data_root /data/knaraya4/data/ \ 16 | --model_type swin_256new_iqa \ 17 | --image_size 120 \ 18 | --lora_rank 8 \ 19 | --use_lora \ 20 | --iqa brisque \ 21 | --threshold 22 | 23 | ### LoRA ### 24 | CUDA_VISIBLE_DEVICES=0 python validation_hq/validate_hq.py \ 25 | --model_load_path /mnt/store/knaraya4/PETALface//model.pt \ 26 | --data_root /data/knaraya4/data/ \ 27 | --model_type swin_256new \ 28 | --image_size 120 \ 29 | --lora_rank 8 \ 30 | --use_lora 31 | 32 | ### PreTrained ### 33 | CUDA_VISIBLE_DEVICES=0 python validation_hq/validate_hq.py \ 34 | --model_load_path /mnt/store/knaraya4/PETALface//model.pt \ 35 | --data_root /data/knaraya4/data/ \ 36 | --model_type swin_256new \ 37 | --image_size 120 38 | -------------------------------------------------------------------------------- /scripts/eval_tinyface.sh: -------------------------------------------------------------------------------- 1 | ### BRISQUE ### 2 | CUDA_VISIBLE_DEVICES=0 python validation_lq/validate_tinyface_iqa.py \ 3 | --data_root /data/knaraya4/data \ 4 | --batch_size 512 \ 5 | --model_load_path /mnt/store/knaraya4/PETALface//model.pt \ 6 | --model_type swin_256new_iqa \ 7 | --image_size 120 \ 8 | --lora_rank 8 \ 9 | --use_lora \ 10 | --iqa brisque \ 11 | --threshold 12 | 13 | ### CNNIQA ### 14 | CUDA_VISIBLE_DEVICES=0 python validation_lq/validate_tinyface_iqa.py \ 15 | --data_root /data/knaraya4/data \ 16 | --batch_size 512 \ 17 | --model_load_path /mnt/store/knaraya4/PETALface//model.pt \ 18 | --model_type swin_256new_iqa \ 19 | --image_size 120 \ 20 | --lora_rank 8 \ 21 | --use_lora \ 22 | --iqa cnniqa \ 23 | --threshold 24 | 25 | ### LoRA ### 26 | CUDA_VISIBLE_DEVICES=0 python validation_lq/validate_tinyface.py \ 27 | --data_root /data/knaraya4/data \ 28 | --batch_size 512 \ 29 | --model_load_path /mnt/store/knaraya4/PETALface//model.pt \ 30 | --model_type swin_256new \ 31 | --image_size 120 \ 32 | --lora_rank 8 \ 33 | --use_lora 34 | 35 | ### PreTrained ### 36 | CUDA_VISIBLE_DEVICES=0 python 
validation_lq/validate_tinyface.py \ 37 | --data_root /data/knaraya4/data \ 38 | --batch_size 512 \ 39 | --model_load_path /mnt/store/knaraya4/PETALface//model.pt \ 40 | --model_type swin_256new \ 41 | --image_size 120 42 | -------------------------------------------------------------------------------- /scripts/eval_ijbs.sh: -------------------------------------------------------------------------------- 1 | ### BRISQUE ### 2 | CUDA_VISIBLE_DEVICES=0 python validation_lq/validate_ijbs_iqa.py \ 3 | --data_root /mnt/store/knaraya4/data/ \ 4 | --batch_size 2048 \ 5 | --model_load_path /mnt/store/knaraya4/PETALface//model.pt \ 6 | --model_type swin_256new_iqa \ 7 | --image_size 120 \ 8 | --lora_rank 8 \ 9 | --use_lora \ 10 | --iqa brisque \ 11 | --threshold 12 | 13 | 14 | ### CNNIQA ### 15 | CUDA_VISIBLE_DEVICES=0 python validation_lq/validate_ijbs_iqa.py \ 16 | --data_root /mnt/store/knaraya4/data/ \ 17 | --batch_size 2048 \ 18 | --model_load_path /mnt/store/knaraya4/PETALface//model.pt \ 19 | --model_type swin_256new_iqa \ 20 | --image_size 120 \ 21 | --lora_rank 8 \ 22 | --use_lora \ 23 | --iqa cnniqa \ 24 | --threshold 25 | 26 | ### LoRA ### 27 | CUDA_VISIBLE_DEVICES=0 python validation_lq/validate_ijbs.py \ 28 | --data_root /mnt/store/knaraya4/data/ \ 29 | --batch_size 2048 \ 30 | --model_load_path /mnt/store/knaraya4/PETALface//model.pt \ 31 | --model_type swin_256new \ 32 | --image_size 120 \ 33 | --lora_rank 8 \ 34 | --use_lora 35 | 36 | ### PreTrained ### 37 | CUDA_VISIBLE_DEVICES=0 python validation_lq/validate_ijbs.py \ 38 | --data_root /mnt/store/knaraya4/data/ \ 39 | --batch_size 2048 \ 40 | --model_load_path /mnt/store/knaraya4/PETALface//model.pt \ 41 | --model_type swin_256new \ 42 | --image_size 120 -------------------------------------------------------------------------------- /utils/utils_logging.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import sys 4 | 5 | 6 | class AverageMeter(object): 7 | """Computes and stores the average and current value 8 | """ 9 | 10 | def __init__(self): 11 | self.val = None 12 | self.avg = None 13 | self.sum = None 14 | self.count = None 15 | self.reset() 16 | 17 | def reset(self): 18 | self.val = 0 19 | self.avg = 0 20 | self.sum = 0 21 | self.count = 0 22 | 23 | def update(self, val, n=1): 24 | self.val = val 25 | self.sum += val * n 26 | self.count += n 27 | self.avg = self.sum / self.count 28 | 29 | 30 | def init_logging(rank, models_root): 31 | if rank == 0: 32 | log_root = logging.getLogger() 33 | log_root.setLevel(logging.INFO) 34 | formatter = logging.Formatter("Training: %(asctime)s-%(message)s") 35 | handler_file = logging.FileHandler(os.path.join(models_root, "training.log")) 36 | handler_stream = logging.StreamHandler(sys.stdout) 37 | handler_file.setFormatter(formatter) 38 | handler_stream.setFormatter(formatter) 39 | log_root.addHandler(handler_file) 40 | log_root.addHandler(handler_stream) 41 | log_root.info('rank_id: %d' % rank) 42 | 43 | def init_logging_test(rank, save_path): 44 | log_root = logging.getLogger() 45 | log_root.setLevel(logging.INFO) 46 | formatter = logging.Formatter("Testing: %(asctime)s-%(message)s") 47 | handler_file = logging.FileHandler(os.path.join(save_path, "testing.log")) 48 | handler_stream = logging.StreamHandler(sys.stdout) 49 | handler_file.setFormatter(formatter) 50 | handler_stream.setFormatter(formatter) 51 | log_root.addHandler(handler_file) 52 | log_root.addHandler(handler_stream) 53 | return log_root 54 | 
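Note: the following is a minimal, hypothetical usage sketch (not part of the repository) showing how the AverageMeter and init_logging helpers above are typically wired into a training loop; the toy model, random data, and logging interval are illustrative placeholders.

import logging
import torch
from torch import nn

from utils.utils_logging import AverageMeter, init_logging

init_logging(rank=0, models_root="./")        # rank 0 writes ./training.log and mirrors it to stdout
model = nn.Linear(8, 2)                       # toy stand-ins for the backbone and classification head
criterion = nn.CrossEntropyLoss()
loss_meter = AverageMeter()

for step in range(20):
    x = torch.randn(4, 8)                     # fake batch of 4 samples
    y = torch.randint(0, 2, (4,))
    loss = criterion(model(x), y)
    loss_meter.update(loss.item(), n=x.size(0))   # running average weighted by batch size
    if (step + 1) % 10 == 0:
        logging.info("step %d, avg loss %.4f", step + 1, loss_meter.avg)
        loss_meter.reset()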
-------------------------------------------------------------------------------- /utils/plot.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | import matplotlib.pyplot as plt 5 | import numpy as np 6 | import pandas as pd 7 | from menpo.visualize.viewmatplotlib import sample_colours_from_colourmap 8 | from prettytable import PrettyTable 9 | from sklearn.metrics import roc_curve, auc 10 | 11 | with open(sys.argv[1], "r") as f: 12 | files = f.readlines() 13 | 14 | files = [x.strip() for x in files] 15 | image_path = "/train_tmp/IJB_release/IJBC" 16 | 17 | 18 | def read_template_pair_list(path): 19 | pairs = pd.read_csv(path, sep=' ', header=None).values 20 | t1 = pairs[:, 0].astype(int) 21 | t2 = pairs[:, 1].astype(int) 22 | label = pairs[:, 2].astype(int) 23 | return t1, t2, label 24 | 25 | 26 | p1, p2, label = read_template_pair_list( 27 | os.path.join('%s/meta' % image_path, 28 | '%s_template_pair_label.txt' % 'ijbc')) 29 | 30 | methods = [] 31 | scores = [] 32 | for file in files: 33 | methods.append(file) 34 | scores.append(np.load(file)) 35 | 36 | methods = np.array(methods) 37 | scores = dict(zip(methods, scores)) 38 | colours = dict( 39 | zip(methods, sample_colours_from_colourmap(methods.shape[0], 'Set2'))) 40 | x_labels = [10 ** -6, 10 ** -5, 10 ** -4, 10 ** -3, 10 ** -2, 10 ** -1] 41 | tpr_fpr_table = PrettyTable(['Methods'] + [str(x) for x in x_labels]) 42 | fig = plt.figure() 43 | for method in methods: 44 | fpr, tpr, _ = roc_curve(label, scores[method]) 45 | roc_auc = auc(fpr, tpr) 46 | fpr = np.flipud(fpr) 47 | tpr = np.flipud(tpr) # select largest tpr at same fpr 48 | plt.plot(fpr, 49 | tpr, 50 | color=colours[method], 51 | lw=1, 52 | label=('[%s (AUC = %0.4f %%)]' % 53 | (method.split('-')[-1], roc_auc * 100))) 54 | tpr_fpr_row = [] 55 | tpr_fpr_row.append(method) 56 | for fpr_iter in np.arange(len(x_labels)): 57 | _, min_index = min( 58 | list(zip(abs(fpr - x_labels[fpr_iter]), range(len(fpr))))) 59 | tpr_fpr_row.append('%.2f' % (tpr[min_index] * 100)) 60 | tpr_fpr_table.add_row(tpr_fpr_row) 61 | plt.xlim([10 ** -6, 0.1]) 62 | plt.ylim([0.3, 1.0]) 63 | plt.grid(linestyle='--', linewidth=1) 64 | plt.xticks(x_labels) 65 | plt.yticks(np.linspace(0.3, 1.0, 8, endpoint=True)) 66 | plt.xscale('log') 67 | plt.xlabel('False Positive Rate') 68 | plt.ylabel('True Positive Rate') 69 | plt.title('ROC on IJB') 70 | plt.legend(loc="lower right") 71 | print(tpr_fpr_table) 72 | -------------------------------------------------------------------------------- /docs/static/js/index.js: -------------------------------------------------------------------------------- 1 | window.HELP_IMPROVE_VIDEOJS = false; 2 | 3 | var INTERP_BASE = "./static/interpolation/stacked"; 4 | var NUM_INTERP_FRAMES = 240; 5 | 6 | var interp_images = []; 7 | function preloadInterpolationImages() { 8 | for (var i = 0; i < NUM_INTERP_FRAMES; i++) { 9 | var path = INTERP_BASE + '/' + String(i).padStart(6, '0') + '.jpg'; 10 | interp_images[i] = new Image(); 11 | interp_images[i].src = path; 12 | } 13 | } 14 | 15 | function setInterpolationImage(i) { 16 | var image = interp_images[i]; 17 | image.ondragstart = function () { return false; }; 18 | image.oncontextmenu = function () { return false; }; 19 | $('#interpolation-image-wrapper').empty().append(image); 20 | } 21 | 22 | 23 | $(document).ready(function () { 24 | // Check for click events on the navbar burger icon 25 | $(".navbar-burger").click(function () { 26 | // Toggle the "is-active"
class on both the "navbar-burger" and the "navbar-menu" 27 | $(".navbar-burger").toggleClass("is-active"); 28 | $(".navbar-menu").toggleClass("is-active"); 29 | 30 | }); 31 | 32 | var options = { 33 | slidesToScroll: 1, 34 | slidesToShow: 3, 35 | loop: true, 36 | infinite: true, 37 | autoplay: false, 38 | autoplaySpeed: 3000, 39 | } 40 | 41 | // Initialize all div with carousel class 42 | var carousels = bulmaCarousel.attach('.carousel', options); 43 | 44 | // Loop on each carousel initialized 45 | for (var i = 0; i < carousels.length; i++) { 46 | // Add listener to event 47 | carousels[i].on('before:show', state => { 48 | console.log(state); 49 | }); 50 | } 51 | 52 | // Access to bulmaCarousel instance of an element 53 | var element = document.querySelector('#my-element'); 54 | if (element && element.bulmaCarousel) { 55 | // bulmaCarousel instance is available as element.bulmaCarousel 56 | element.bulmaCarousel.on('before-show', function (state) { 57 | console.log(state); 58 | }); 59 | } 60 | 61 | /*var player = document.getElementById('interpolation-video'); 62 | player.addEventListener('loadedmetadata', function() { 63 | $('#interpolation-slider').on('input', function(event) { 64 | console.log(this.value, player.duration); 65 | player.currentTime = player.duration / 100 * this.value; 66 | }) 67 | }, false);*/ 68 | preloadInterpolationImages(); 69 | 70 | $('#interpolation-slider').on('input', function (event) { 71 | setInterpolationImage(this.value); 72 | }); 73 | setInterpolationImage(0); 74 | $('#interpolation-slider').prop('max', NUM_INTERP_FRAMES - 1); 75 | 76 | bulmaSlider.attach(); 77 | 78 | }) -------------------------------------------------------------------------------- /lr_scheduler.py: -------------------------------------------------------------------------------- 1 | from torch.optim.lr_scheduler import _LRScheduler 2 | from torch.optim import SGD 3 | import torch 4 | import warnings 5 | 6 | class PolynomialLRWarmup(_LRScheduler): 7 | def __init__(self, optimizer, warmup_iters, total_iters=5, power=1.0, last_epoch=-1, verbose=False): 8 | super().__init__(optimizer, last_epoch=last_epoch, verbose=verbose) 9 | self.total_iters = total_iters 10 | self.power = power 11 | self.warmup_iters = warmup_iters 12 | 13 | 14 | def get_lr(self): 15 | if not self._get_lr_called_within_step: 16 | warnings.warn("To get the last learning rate computed by the scheduler, " 17 | "please use `get_last_lr()`.", UserWarning) 18 | 19 | if self.last_epoch == 0 or self.last_epoch > self.total_iters: 20 | return [group["lr"] for group in self.optimizer.param_groups] 21 | 22 | if self.last_epoch <= self.warmup_iters: 23 | return [base_lr * self.last_epoch / self.warmup_iters for base_lr in self.base_lrs] 24 | else: 25 | l = self.last_epoch 26 | w = self.warmup_iters 27 | t = self.total_iters 28 | decay_factor = ((1.0 - (l - w) / (t - w)) / (1.0 - (l - 1 - w) / (t - w))) ** self.power 29 | return [group["lr"] * decay_factor for group in self.optimizer.param_groups] 30 | 31 | def _get_closed_form_lr(self): 32 | 33 | if self.last_epoch <= self.warmup_iters: 34 | return [ 35 | base_lr * self.last_epoch / self.warmup_iters for base_lr in self.base_lrs] 36 | else: 37 | return [ 38 | ( 39 | base_lr * (1.0 - (min(self.total_iters, self.last_epoch) - self.warmup_iters) / (self.total_iters - self.warmup_iters)) ** self.power 40 | ) 41 | for base_lr in self.base_lrs 42 | ] 43 | 44 | 45 | if __name__ == "__main__": 46 | 47 | class TestModule(torch.nn.Module): 48 | def __init__(self) -> None: 49 | 
super().__init__() 50 | self.linear = torch.nn.Linear(32, 32) 51 | 52 | def forward(self, x): 53 | return self.linear(x) 54 | 55 | test_module = TestModule() 56 | test_module_pfc = TestModule() 57 | lr_pfc_weight = 1 / 3 58 | base_lr = 10 59 | total_steps = 1000 60 | 61 | sgd = SGD([ 62 | {"params": test_module.parameters(), "lr": base_lr}, 63 | {"params": test_module_pfc.parameters(), "lr": base_lr * lr_pfc_weight} 64 | ], base_lr) 65 | 66 | scheduler = PolynomialLRWarmup(sgd, total_steps//10, total_steps, power=2) 67 | 68 | x = [] 69 | y = [] 70 | y_pfc = [] 71 | for i in range(total_steps): 72 | scheduler.step() 73 | lr = scheduler.get_last_lr()[0] 74 | lr_pfc = scheduler.get_last_lr()[1] 75 | x.append(i) 76 | y.append(lr) 77 | y_pfc.append(lr_pfc) 78 | 79 | import matplotlib.pyplot as plt 80 | fontsize=15 81 | plt.figure(figsize=(6, 6)) 82 | plt.plot(x, y, linestyle='-', linewidth=2, ) 83 | plt.plot(x, y_pfc, linestyle='-', linewidth=2, ) 84 | plt.xlabel('Iterations') # x_label 85 | plt.ylabel("Lr") # y_label 86 | plt.savefig("tmp.png", dpi=600, bbox_inches='tight') 87 | -------------------------------------------------------------------------------- /scripts/eval_ijb.sh: -------------------------------------------------------------------------------- 1 | ### IJBC BRISQUE ### 2 | CUDA_VISIBLE_DEVICES=0 python validation_ijb/eval_ijb_iqa.py \ 3 | --model_load_path /mnt/store/knaraya4/PETALface//model.pt \ 4 | --data_root /data/knaraya4/data/ijb/ \ 5 | --batch-size 1024 \ 6 | --model_type swin_256new_iqa \ 7 | --target IJBC \ 8 | --image_size 120 \ 9 | --lora_rank 8 \ 10 | --use_lora \ 11 | --iqa brisque \ 12 | --threshold 13 | 14 | ### IJBB BRISQUE ### 15 | CUDA_VISIBLE_DEVICES=0 python validation_ijb/eval_ijb_iqa.py \ 16 | --model_load_path /mnt/store/knaraya4/PETALface//model.pt \ 17 | --data_root /data/knaraya4/data/ \ 18 | --batch-size 1024 \ 19 | --model_type swin_256new_iqa \ 20 | --target IJBB \ 21 | --image_size 120 \ 22 | --lora_rank 8 \ 23 | --use_lora \ 24 | --iqa brisque \ 25 | --threshold 26 | 27 | ### IJBC CNNIQA ### 28 | CUDA_VISIBLE_DEVICES=0 python validation_ijb/eval_ijb_iqa.py \ 29 | --model_load_path /mnt/store/knaraya4/PETALface//model.pt \ 30 | --data_root /data/knaraya4/data/ijb/ \ 31 | --batch-size 1024 \ 32 | --model_type swin_256new_iqa \ 33 | --target IJBC \ 34 | --image_size 120 \ 35 | --lora_rank 8 \ 36 | --use_lora \ 37 | --iqa cnniqa \ 38 | --threshold 39 | 40 | ### IJBB CNNIQA ### 41 | CUDA_VISIBLE_DEVICES=0 python validation_ijb/eval_ijb_iqa.py \ 42 | --model_load_path /mnt/store/knaraya4/PETALface//model.pt \ 43 | --data_root /data/knaraya4/data/ \ 44 | --batch-size 1024 \ 45 | --model_type swin_256new_iqa \ 46 | --target IJBB \ 47 | --image_size 120 \ 48 | --lora_rank 8 \ 49 | --use_lora \ 50 | --iqa cnniqa \ 51 | --threshold 52 | 53 | 54 | ### IJBC LoRA ### 55 | CUDA_VISIBLE_DEVICES=0 python validation_ijb/eval_ijb.py \ 56 | --model_load_path /mnt/store/knaraya4/PETALface//model.pt \ 57 | --data_root /data/knaraya4/data/ijb/ \ 58 | --batch-size 1024 \ 59 | --model_type swin_256new \ 60 | --target IJBC \ 61 | --image_size 120 \ 62 | --lora_rank 8 \ 63 | --use_lora 64 | 65 | ### IJBB LoRA ### 66 | CUDA_VISIBLE_DEVICES=0 python validation_ijb/eval_ijb.py \ 67 | --model_load_path /mnt/store/knaraya4/PETALface//model.pt \ 68 | --data_root /data/knaraya4/data/ \ 69 | --batch-size 1024 \ 70 | --model_type swin_256new \ 71 | --target IJBB \ 72 | --image_size 120 \ 73 | --lora_rank 8 \ 74 | --use_lora 75 | 76 | ### IJBC PreTrained ### 77 | 
CUDA_VISIBLE_DEVICES=0 python validation_ijb/eval_ijb.py \ 78 | --model_load_path /mnt/store/knaraya4/PETALface//model.pt \ 79 | --data_root /data/knaraya4/data/ijb/ \ 80 | --batch-size 1024 \ 81 | --model_type swin_256new \ 82 | --target IJBC \ 83 | --image_size 120 84 | 85 | ### IJBB PreTrained ### 86 | CUDA_VISIBLE_DEVICES=0 python validation_ijb/eval_ijb.py \ 87 | --model_load_path /mnt/store/knaraya4/PETALface//model.pt \ 88 | --data_root /data/knaraya4/data/ \ 89 | --batch-size 1024 \ 90 | --model_type swin_256new \ 91 | --target IJBB \ 92 | --image_size 120 -------------------------------------------------------------------------------- /docs/static/css/bulma-carousel.min.css: -------------------------------------------------------------------------------- 1 | @-webkit-keyframes spinAround{from{-webkit-transform:rotate(0);transform:rotate(0)}to{-webkit-transform:rotate(359deg);transform:rotate(359deg)}}@keyframes spinAround{from{-webkit-transform:rotate(0);transform:rotate(0)}to{-webkit-transform:rotate(359deg);transform:rotate(359deg)}}.slider{position:relative;width:100%}.slider-container{display:flex;flex-wrap:nowrap;flex-direction:row;overflow:hidden;-webkit-transform:translate3d(0,0,0);transform:translate3d(0,0,0);min-height:100%}.slider-container.is-vertical{flex-direction:column}.slider-container .slider-item{flex:none}.slider-container .slider-item .image.is-covered img{-o-object-fit:cover;object-fit:cover;-o-object-position:center center;object-position:center center;height:100%;width:100%}.slider-container .slider-item .video-container{height:0;padding-bottom:0;padding-top:56.25%;margin:0;position:relative}.slider-container .slider-item .video-container.is-1by1,.slider-container .slider-item .video-container.is-square{padding-top:100%}.slider-container .slider-item .video-container.is-4by3{padding-top:75%}.slider-container .slider-item .video-container.is-21by9{padding-top:42.857143%}.slider-container .slider-item .video-container embed,.slider-container .slider-item .video-container iframe,.slider-container .slider-item .video-container object{position:absolute;top:0;left:0;width:100%!important;height:100%!important}.slider-navigation-next,.slider-navigation-previous{display:flex;justify-content:center;align-items:center;position:absolute;width:42px;height:42px;background:#fff center center no-repeat;background-size:20px 20px;border:1px solid #fff;border-radius:25091983px;box-shadow:0 2px 5px #3232321a;top:50%;margin-top:-20px;left:0;cursor:pointer;transition:opacity .3s,-webkit-transform .3s;transition:transform .3s,opacity .3s;transition:transform .3s,opacity .3s,-webkit-transform .3s}.slider-navigation-next:hover,.slider-navigation-previous:hover{-webkit-transform:scale(1.2);transform:scale(1.2)}.slider-navigation-next.is-hidden,.slider-navigation-previous.is-hidden{display:none;opacity:0}.slider-navigation-next svg,.slider-navigation-previous svg{width:25%}.slider-navigation-next{left:auto;right:0;background:#fff center center no-repeat;background-size:20px 20px}.slider-pagination{display:none;justify-content:center;align-items:center;position:absolute;bottom:0;left:0;right:0;padding:.5rem 1rem;text-align:center}.slider-pagination .slider-page{background:#fff;width:10px;height:10px;border-radius:25091983px;display:inline-block;margin:0 3px;box-shadow:0 2px 5px #3232321a;transition:-webkit-transform .3s;transition:transform .3s;transition:transform .3s,-webkit-transform .3s;cursor:pointer}.slider-pagination .slider-page.is-active,.slider-pagination 
.slider-page:hover{-webkit-transform:scale(1.4);transform:scale(1.4)}@media screen and (min-width:800px){.slider-pagination{display:flex}}.hero.has-carousel{position:relative}.hero.has-carousel+.hero-body,.hero.has-carousel+.hero-footer,.hero.has-carousel+.hero-head{z-index:10;overflow:hidden}.hero.has-carousel .hero-carousel{position:absolute;top:0;left:0;bottom:0;right:0;height:auto;border:none;margin:auto;padding:0;z-index:0}.hero.has-carousel .hero-carousel .slider{width:100%;max-width:100%;overflow:hidden;height:100%!important;max-height:100%;z-index:0}.hero.has-carousel .hero-carousel .slider .has-background{max-height:100%}.hero.has-carousel .hero-carousel .slider .has-background .is-background{-o-object-fit:cover;object-fit:cover;-o-object-position:center center;object-position:center center;height:100%;width:100%}.hero.has-carousel .hero-body{margin:0 3rem;z-index:10} -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import sys 3 | import os 4 | from time import gmtime, strftime 5 | 6 | def get_args(): 7 | 8 | parser = argparse.ArgumentParser(add_help=False) 9 | ## Partial FC ## 10 | parser.add_argument('--sample_rate', type=int, default=1) 11 | parser.add_argument('--interclass_filtering_threshold', type=int, default=0) 12 | parser.add_argument('--fp16', action='store_true') 13 | ## Logging ## 14 | parser.add_argument('--verbose', type=int, default=2000) 15 | parser.add_argument('--frequent', type=int, default=10) 16 | parser.add_argument('--seed', default=2048, type=int) 17 | ## Dali ## 18 | parser.add_argument('--dali', type=bool, default=False) 19 | parser.add_argument('--num_workers', default=8, type=int) 20 | ## Wandb ## 21 | parser.add_argument('--wandb_key', type=str) 22 | parser.add_argument('--suffix_run_name', type=str, default="None") 23 | parser.add_argument('--using_wandb', default=False, type=bool) 24 | parser.add_argument('--wandb_entity', type=str) 25 | parser.add_argument('--wandb_project', type=str) 26 | parser.add_argument('--wandb_log_all', type=bool, default=False) 27 | parser.add_argument('--save_artifacts', type=bool, default=False) 28 | parser.add_argument('--wandb_resume', type=bool, default=False) 29 | parser.add_argument('--notes', type=str, default="Lorem Ipsum") 30 | parser.add_argument('--dir', type=str) 31 | ## Data ## 32 | parser.add_argument('--val_targets', type=str, default='lfw,cfp_fp,agedb_30') 33 | parser.add_argument('--train_data', type=str, default='webface4m', help='ms1mv3 or webface4m') 34 | parser.add_argument('--image_size', type=int, default=112, help='112 or 120') 35 | ## Training ## 36 | parser.add_argument('--network', type=str, default='r50') 37 | parser.add_argument('--embedding_size', type=int, default=512) 38 | parser.add_argument('--batch-size', type=int, default=128) 39 | parser.add_argument('--num_epoch', type=int, default=20) 40 | parser.add_argument('--warmup_epoch', default=0, type=int) 41 | parser.add_argument('--dali_aug', action="store_true") 42 | parser.add_argument('--head', type=str, default='partial_fc') 43 | parser.add_argument('--use_lora', action="store_true") 44 | parser.add_argument('--lora_rank', type=int, default=4) 45 | parser.add_argument('--lora_scale', type=int, default=1) 46 | parser.add_argument('--load_pretrained', type=str) 47 | ## Data webface4m ## 48 | parser.add_argument('--rec', type=str, default='/data/knaraya4/data/WebFace4M') 49 | 
parser.add_argument('--num_classes', type=int, default=205990) 50 | parser.add_argument('--num_image', type=int, default=4235242) 51 | ## Loss ## 52 | parser.add_argument('--margin_list', type=str, default='1.0,0.5,0.0') 53 | ## Logging ## 54 | parser.add_argument('--resume', type=bool, default=False) 55 | parser.add_argument('--save_all_states', type=bool, default=True) 56 | parser.add_argument('--output', type=str) 57 | ## Optimizer ## 58 | parser.add_argument('--optimizer', type=str, default='sgd', help="sgd or adamW") 59 | parser.add_argument('--lr', type=float, default=0.1, help='0.1 or 0.001') 60 | parser.add_argument('--momentum', type=float, default=0.9, help='momentum for SGD') 61 | parser.add_argument('--weight_decay', type=float, default=5e-4, help='5e-4 or 0.1') 62 | parser.add_argument('--gradient_acc', default=1, type=int) 63 | ## IQA ## 64 | parser.add_argument('--iqa', type=str) 65 | parser.add_argument('--threshold', type=float, default=0.5) 66 | 67 | args = parser.parse_args() 68 | 69 | args.margin_list = [float(x) for x in args.margin_list.split(',')] 70 | args.val_targets = [str(x) for x in args.val_targets.split(',')] 71 | return args -------------------------------------------------------------------------------- /losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | 4 | class CombinedMarginLoss(torch.nn.Module): 5 | def __init__(self, 6 | s, 7 | m1, 8 | m2, 9 | m3, 10 | interclass_filtering_threshold=0): 11 | super().__init__() 12 | self.s = s 13 | self.m1 = m1 14 | self.m2 = m2 15 | self.m3 = m3 16 | self.interclass_filtering_threshold = interclass_filtering_threshold 17 | 18 | # For ArcFace 19 | self.cos_m = math.cos(self.m2) 20 | self.sin_m = math.sin(self.m2) 21 | self.theta = math.cos(math.pi - self.m2) 22 | self.sinmm = math.sin(math.pi - self.m2) * self.m2 23 | self.easy_margin = False 24 | 25 | 26 | def forward(self, logits, labels): 27 | index_positive = torch.where(labels != -1)[0] 28 | 29 | if self.interclass_filtering_threshold > 0: 30 | with torch.no_grad(): 31 | dirty = logits > self.interclass_filtering_threshold 32 | dirty = dirty.float() 33 | mask = torch.ones([index_positive.size(0), logits.size(1)], device=logits.device) 34 | mask.scatter_(1, labels[index_positive], 0) 35 | dirty[index_positive] *= mask 36 | tensor_mul = 1 - dirty 37 | logits = tensor_mul * logits 38 | 39 | target_logit = logits[index_positive, labels[index_positive].view(-1)] 40 | 41 | if self.m1 == 1.0 and self.m3 == 0.0: 42 | with torch.no_grad(): 43 | target_logit.arccos_() 44 | logits.arccos_() 45 | final_target_logit = target_logit + self.m2 46 | logits[index_positive, labels[index_positive].view(-1)] = final_target_logit 47 | logits.cos_() 48 | logits = logits * self.s 49 | 50 | elif self.m3 > 0: 51 | final_target_logit = target_logit - self.m3 52 | logits[index_positive, labels[index_positive].view(-1)] = final_target_logit 53 | logits = logits * self.s 54 | else: 55 | raise 56 | 57 | return logits 58 | 59 | class ArcFace(torch.nn.Module): 60 | """ ArcFace (https://arxiv.org/pdf/1801.07698v1.pdf): 61 | """ 62 | def __init__(self, s=64.0, margin=0.5): 63 | super(ArcFace, self).__init__() 64 | self.s = s 65 | self.margin = margin 66 | self.cos_m = math.cos(margin) 67 | self.sin_m = math.sin(margin) 68 | self.theta = math.cos(math.pi - margin) 69 | self.sinmm = math.sin(math.pi - margin) * margin 70 | self.easy_margin = False 71 | 72 | 73 | def forward(self, logits: torch.Tensor, labels: torch.Tensor): 
74 | index = torch.where(labels != -1)[0] 75 | 76 | target_logit = logits[index, labels[index].view(-1)] 77 | 78 | 79 | with torch.no_grad(): 80 | target_logit.arccos_() 81 | logits.arccos_() 82 | final_target_logit = target_logit + self.margin 83 | logits[index, labels[index].view(-1)] = final_target_logit 84 | logits.cos_() 85 | logits = logits * self.s 86 | return logits 87 | 88 | 89 | class CosFace(torch.nn.Module): 90 | def __init__(self, s=64.0, m=0.40): 91 | super(CosFace, self).__init__() 92 | self.s = s 93 | self.m = m 94 | 95 | def forward(self, logits: torch.Tensor, labels: torch.Tensor): 96 | index = torch.where(labels != -1)[0] 97 | target_logit = logits[index, labels[index].view(-1)] 98 | final_target_logit = target_logit - self.m 99 | logits[index, labels[index].view(-1)] = final_target_logit 100 | logits = logits * self.s 101 | return logits 102 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 
106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 110 | .pdm.toml 111 | .pdm-python 112 | .pdm-build/ 113 | 114 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 115 | __pypackages__/ 116 | 117 | # Celery stuff 118 | celerybeat-schedule 119 | celerybeat.pid 120 | 121 | # SageMath parsed files 122 | *.sage.py 123 | 124 | # Environments 125 | .env 126 | .venv 127 | env/ 128 | venv/ 129 | ENV/ 130 | env.bak/ 131 | venv.bak/ 132 | 133 | # Spyder project settings 134 | .spyderproject 135 | .spyproject 136 | 137 | # Rope project settings 138 | .ropeproject 139 | 140 | # mkdocs documentation 141 | /site 142 | 143 | # mypy 144 | .mypy_cache/ 145 | .dmypy.json 146 | dmypy.json 147 | 148 | # Pyre type checker 149 | .pyre/ 150 | 151 | # pytype static type analyzer 152 | .pytype/ 153 | 154 | # Cython debug symbols 155 | cython_debug/ 156 | 157 | # PyCharm 158 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 159 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 160 | # and can be added to the global gitignore or merged into this file. For a more nuclear 161 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 162 | #.idea/ 163 | -------------------------------------------------------------------------------- /backbones/lora_layers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import functional as F 4 | from typing import Type 5 | 6 | class LoRALinear(nn.Linear): 7 | def __init__(self, in_features: int, out_features: int, r, scale, bias: bool=True) -> None: 8 | super().__init__(in_features, out_features, bias) 9 | self.r = r 10 | self.trainable_lora_down = nn.Linear(in_features, r, bias=False) 11 | self.dropout = nn.Dropout(0.1) 12 | self.trainable_lora_up = nn.Linear(r, out_features, bias=False) 13 | self.scale = scale 14 | self.selector = nn.Identity() 15 | 16 | nn.init.normal_(self.trainable_lora_down.weight, std=1/r) 17 | nn.init.zeros_(self.trainable_lora_up.weight) 18 | 19 | def forward(self, x): 20 | out = F.linear(x, self.weight, self.bias) 21 | lora_adjustment = self.scale*self.dropout(self.trainable_lora_up(self.selector(self.trainable_lora_down(x)))) 22 | out = out + lora_adjustment 23 | return out 24 | 25 | class LoRALinearTwo(nn.Linear): 26 | def __init__(self, in_features: int, out_features: int, r, scale, bias: bool=True) -> None: 27 | super().__init__(in_features, out_features, bias) 28 | self.r = r 29 | self.trainable_lora_down = nn.Linear(in_features, r, bias=False) 30 | self.dropout = nn.Dropout(0.1) 31 | self.trainable_lora_up = nn.Linear(r, out_features, bias=False) 32 | self.scale = scale 33 | self.selector = nn.Identity() 34 | 35 | nn.init.normal_(self.trainable_lora_down.weight, std=1/r) 36 | nn.init.zeros_(self.trainable_lora_up.weight) 37 | 38 | self.r2 = r 39 | self.trainable_lora_down2 = nn.Linear(in_features, r, bias=False) 40 | self.dropout2 = nn.Dropout(0.1) 41 | self.trainable_lora_up2 = nn.Linear(r, out_features, bias=False) 42 | self.scale2 = scale 43 | self.selector2 = nn.Identity() 44 | 45 | nn.init.normal_(self.trainable_lora_down2.weight, std=1/self.r2) 46 | nn.init.zeros_(self.trainable_lora_up2.weight) 47 | 48 | def forward(self, x, 
alpha): 49 | out = F.linear(x, self.weight, self.bias) 50 | lora_adjustment_1 = self.scale*self.dropout(self.trainable_lora_up(self.selector(self.trainable_lora_down((1-alpha)*x)))) 51 | lora_adjustment_2 = self.scale2*self.dropout2(self.trainable_lora_up2(self.selector2(self.trainable_lora_down2(alpha*x)))) 52 | out = out + lora_adjustment_1 + lora_adjustment_2 53 | return out 54 | 55 | class LoRAConv2D(nn.Conv2d): 56 | def __init__(self, in_channels: int, out_channels: int, r, scale, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True) -> None: 57 | if isinstance(padding, int): 58 | padding = (padding, padding) 59 | if isinstance(dilation, int): 60 | dilation = (dilation, dilation) 61 | super().__init__(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias) 62 | assert type(kernel_size) is int 63 | self.r = r 64 | self.scale = scale 65 | self.stride = stride 66 | self.padding = padding 67 | self.dilation = dilation 68 | self.groups = groups 69 | 70 | self.trainable_lora_down = nn.Conv2d( 71 | in_channels = in_channels, 72 | out_channels = r, 73 | kernel_size = kernel_size, 74 | stride=stride, 75 | padding=padding, 76 | dilation=dilation, 77 | groups=groups, 78 | bias=False 79 | ) 80 | 81 | self.trainable_lora_up = nn.Conv2d( 82 | in_channels=r, 83 | out_channels=out_channels, 84 | kernel_size=1, 85 | bias=False 86 | ) 87 | self.selector = nn.Identity() 88 | self.scale = scale 89 | 90 | nn.init.normal_(self.trainable_lora_down.weight, std=1/r) 91 | nn.init.zeros_(self.trainable_lora_up.weight) 92 | 93 | def forward(self, x): 94 | out = F.conv2d(x, weight=self.weight, bias=self.bias, stride=self.stride, padding=self.padding, dilation=self.dilation, groups=self.groups) 95 | lora_adjustment = self.scale*self.trainable_lora_up(self.selector(self.trainable_lora_down(x))) 96 | out = out + lora_adjustment 97 | return out -------------------------------------------------------------------------------- /utils/utils_distributed_sampler.py: -------------------------------------------------------------------------------- 1 | import math 2 | import os 3 | import random 4 | 5 | import numpy as np 6 | import torch 7 | import torch.distributed as dist 8 | from torch.utils.data import DistributedSampler as _DistributedSampler 9 | 10 | 11 | def setup_seed(seed, cuda_deterministic=True): 12 | torch.manual_seed(seed) 13 | torch.cuda.manual_seed_all(seed) 14 | np.random.seed(seed) 15 | random.seed(seed) 16 | os.environ["PYTHONHASHSEED"] = str(seed) 17 | if cuda_deterministic: # slower, more reproducible 18 | torch.backends.cudnn.deterministic = True 19 | torch.backends.cudnn.benchmark = False 20 | else: # faster, less reproducible 21 | torch.backends.cudnn.deterministic = False 22 | torch.backends.cudnn.benchmark = True 23 | 24 | 25 | def worker_init_fn(worker_id, num_workers, rank, seed): 26 | # The seed of each worker equals to 27 | # num_worker * rank + worker_id + user_seed 28 | worker_seed = num_workers * rank + worker_id + seed 29 | np.random.seed(worker_seed) 30 | random.seed(worker_seed) 31 | torch.manual_seed(worker_seed) 32 | 33 | 34 | def get_dist_info(): 35 | if dist.is_available() and dist.is_initialized(): 36 | rank = dist.get_rank() 37 | world_size = dist.get_world_size() 38 | else: 39 | rank = 0 40 | world_size = 1 41 | 42 | return rank, world_size 43 | 44 | 45 | def sync_random_seed(seed=None, device="cuda"): 46 | """Make sure different ranks share the same seed. 47 | All workers must call this function, otherwise it will deadlock. 
48 | This method is generally used in `DistributedSampler`, 49 | because the seed should be identical across all processes 50 | in the distributed group. 51 | In distributed sampling, different ranks should sample non-overlapped 52 | data in the dataset. Therefore, this function is used to make sure that 53 | each rank shuffles the data indices in the same order based 54 | on the same seed. Then different ranks could use different indices 55 | to select non-overlapped data from the same data list. 56 | Args: 57 | seed (int, Optional): The seed. Default to None. 58 | device (str): The device where the seed will be put on. 59 | Default to 'cuda'. 60 | Returns: 61 | int: Seed to be used. 62 | """ 63 | if seed is None: 64 | seed = np.random.randint(2**31) 65 | assert isinstance(seed, int) 66 | 67 | rank, world_size = get_dist_info() 68 | 69 | if world_size == 1: 70 | return seed 71 | 72 | if rank == 0: 73 | random_num = torch.tensor(seed, dtype=torch.int32, device=device) 74 | else: 75 | random_num = torch.tensor(0, dtype=torch.int32, device=device) 76 | 77 | dist.broadcast(random_num, src=0) 78 | 79 | return random_num.item() 80 | 81 | 82 | class DistributedSampler(_DistributedSampler): 83 | def __init__( 84 | self, 85 | dataset, 86 | num_replicas=None, # world_size 87 | rank=None, # local_rank 88 | shuffle=True, 89 | seed=0, 90 | ): 91 | 92 | super().__init__(dataset, num_replicas=num_replicas, rank=rank, shuffle=shuffle) 93 | 94 | # In distributed sampling, different ranks should sample 95 | # non-overlapped data in the dataset. Therefore, this function 96 | # is used to make sure that each rank shuffles the data indices 97 | # in the same order based on the same seed. Then different ranks 98 | # could use different indices to select non-overlapped data from the 99 | # same data list. 100 | self.seed = sync_random_seed(seed) 101 | 102 | def __iter__(self): 103 | # deterministically shuffle based on epoch 104 | if self.shuffle: 105 | g = torch.Generator() 106 | # When :attr:`shuffle=True`, this ensures all replicas 107 | # use a different random ordering for each epoch. 108 | # Otherwise, the next iteration of this sampler will 109 | # yield the same ordering. 
110 | g.manual_seed(self.epoch + self.seed) 111 | indices = torch.randperm(len(self.dataset), generator=g).tolist() 112 | else: 113 | indices = torch.arange(len(self.dataset)).tolist() 114 | 115 | # add extra samples to make it evenly divisible 116 | # in case that indices is shorter than half of total_size 117 | indices = (indices * math.ceil(self.total_size / len(indices)))[ 118 | : self.total_size 119 | ] 120 | assert len(indices) == self.total_size 121 | 122 | # subsample 123 | indices = indices[self.rank : self.total_size : self.num_replicas] 124 | assert len(indices) == self.num_samples 125 | 126 | return iter(indices) 127 | -------------------------------------------------------------------------------- /docs/static/css/index.css: -------------------------------------------------------------------------------- 1 | body { 2 | font-family: 'Noto Sans', sans-serif; 3 | } 4 | 5 | 6 | .footer .icon-link { 7 | font-size: 25px; 8 | color: #000; 9 | } 10 | 11 | .link-block a { 12 | margin-top: 5px; 13 | margin-bottom: 5px; 14 | } 15 | 16 | .is-smallcaps { 17 | font-variant: small-caps; 18 | } 19 | 20 | 21 | .author-block { 22 | display: inline-block; 23 | } 24 | 25 | .author-portrait { 26 | overflow: hidden; 27 | width: 128px; 28 | height: 128px; 29 | font-size: 0; 30 | margin: auto; 31 | } 32 | .author-portrait img { 33 | border-radius: 50%; 34 | width: 140px; 35 | position: relative; 36 | } 37 | 38 | .author-portrait .depth { 39 | } 40 | 41 | 42 | .teaser .hero-body { 43 | padding-top: 0; 44 | padding-bottom: 3rem; 45 | } 46 | 47 | .teaser { 48 | font-family: 'Google Sans', sans-serif; 49 | } 50 | 51 | 52 | .publication-title { 53 | } 54 | 55 | .publication-banner { 56 | max-height: parent; 57 | 58 | } 59 | 60 | .publication-banner video { 61 | position: relative; 62 | left: auto; 63 | top: auto; 64 | transform: none; 65 | object-fit: fit; 66 | } 67 | 68 | .publication-header .hero-body { 69 | } 70 | 71 | .publication-title { 72 | font-family: 'Google Sans', sans-serif; 73 | } 74 | 75 | .publication-authors { 76 | font-family: 'Google Sans', sans-serif; 77 | } 78 | 79 | .publication-venue { 80 | color: #555; 81 | width: fit-content; 82 | font-weight: bold; 83 | } 84 | 85 | .publication-awards { 86 | color: #ff3860; 87 | width: fit-content; 88 | font-weight: bolder; 89 | } 90 | 91 | .publication-authors { 92 | } 93 | 94 | .publication-authors a { 95 | color: hsl(204, 86%, 53%) !important; 96 | } 97 | 98 | .publication-authors a:hover { 99 | text-decoration: underline; 100 | } 101 | 102 | .publication-banner img { 103 | } 104 | 105 | .publication-authors { 106 | /*color: #4286f4;*/ 107 | } 108 | 109 | .publication-video { 110 | position: relative; 111 | width: 100%; 112 | height: 0; 113 | padding-bottom: 56.25%; 114 | 115 | overflow: hidden; 116 | border-radius: 10px !important; 117 | } 118 | 119 | .publication-video iframe { 120 | position: absolute; 121 | top: 0; 122 | left: 0; 123 | width: 100%; 124 | height: 100%; 125 | } 126 | 127 | .publication-body img { 128 | } 129 | 130 | /* 131 | .results-carousel { 132 | overflow: hidden; 133 | } 134 | */ 135 | 136 | .slick-prev { 137 | 138 | } 139 | 140 | .results-carousel > div { 141 | text-align: center; 142 | } 143 | 144 | .results-carousel .results-item { 145 | display: inline-block; 146 | width: fit-content; 147 | overflow: hidden; 148 | padding: 0; 149 | font-size: 0; 150 | } 151 | 152 | .results-carousel video { 153 | margin: 0; 154 | } 155 | 156 | 157 | .interpolation-panel { 158 | background: #f5f5f5; 159 | border-radius: 10px; 160 | } 
161 | 162 | .interpolation-panel .interpolation-image { 163 | width: 100%; 164 | border-radius: 5px; 165 | } 166 | 167 | .interpolation-video-column { 168 | } 169 | 170 | .interpolation-panel .slider { 171 | margin: 0 !important; 172 | } 173 | 174 | .interpolation-panel .slider { 175 | margin: 0 !important; 176 | } 177 | 178 | #interpolation-image-wrapper { 179 | width: 100%; 180 | } 181 | #interpolation-image-wrapper img { 182 | border-radius: 5px; 183 | } 184 | 185 | 186 | .level-set-slices { 187 | width: 100%; 188 | height: 250px; 189 | border: 1px solid #eee; 190 | border-radius: 20px; 191 | } 192 | 193 | .level-set-shapes { 194 | font-size: 0; 195 | padding-left: 5%; 196 | } 197 | .level-set-shapes img { 198 | width: 30%; 199 | } 200 | .level-set-interpolate { 201 | width: 100%; 202 | } 203 | .level-set-ox-shapes img { 204 | width: 24%; 205 | } 206 | 207 | .content model-viewer { 208 | margin-bottom: 1em; 209 | } 210 | 211 | .hyper-space-wrapper { 212 | border: 1px solid #eee; 213 | border-radius: 20px; 214 | padding: 20px; 215 | } 216 | 217 | .hyper-space-axis { 218 | padding: 5px; 219 | border-left: 2px solid #000; 220 | border-bottom: 2px solid #000; 221 | } 222 | 223 | .hyper-space { 224 | width: 100%; 225 | padding-bottom: calc(100% - 7%); 226 | box_sizing: border-box; 227 | touch-action: none; 228 | 229 | background-image: url(../figures/hyper_log_prob.png); 230 | background-size: contain; 231 | border-radius: 10px; 232 | } 233 | 234 | .hyper-space-cursor { 235 | width: 7%; 236 | padding-bottom: calc(7% - 6px); 237 | background: #29e; 238 | border-radius: 50%; 239 | border: 3px solid #fff; 240 | } 241 | 242 | .hyper-grid-wrapper{ 243 | width: 95%; 244 | padding-bottom: 95%; 245 | overflow: hidden; 246 | border-radius: 50%; 247 | border: 2px solid #29e; 248 | position: relative; 249 | -webkit-mask-image: -webkit-radial-gradient(white, black); 250 | } 251 | 252 | .hyper-grid-rgb { 253 | width: 100%; 254 | height: 100%; 255 | overflow: hidden; 256 | 257 | position: absolute; 258 | } 259 | 260 | .hyper-grid-rgb img { 261 | position: relative; 262 | height: 2000%; 263 | width: 2000%; 264 | max-width: none !important; 265 | } 266 | -------------------------------------------------------------------------------- /docs/static/css/bulma-slider.min.css: -------------------------------------------------------------------------------- 1 | @-webkit-keyframes spinAround { 2 | from { 3 | -webkit-transform: rotate(0); 4 | transform: rotate(0) 5 | } 6 | 7 | to { 8 | -webkit-transform: rotate(359deg); 9 | transform: rotate(359deg) 10 | } 11 | } 12 | 13 | @keyframes spinAround { 14 | from { 15 | -webkit-transform: rotate(0); 16 | transform: rotate(0) 17 | } 18 | 19 | to { 20 | -webkit-transform: rotate(359deg); 21 | transform: rotate(359deg) 22 | } 23 | } 24 | 25 | .slider { 26 | position: relative; 27 | width: 100% 28 | } 29 | 30 | .slider-container { 31 | display: flex; 32 | flex-wrap: nowrap; 33 | flex-direction: row; 34 | overflow: hidden; 35 | -webkit-transform: translate3d(0, 0, 0); 36 | transform: translate3d(0, 0, 0); 37 | min-height: 100% 38 | } 39 | 40 | .slider-container.is-vertical { 41 | flex-direction: column 42 | } 43 | 44 | .slider-container .slider-item { 45 | flex: none 46 | } 47 | 48 | .slider-container .slider-item .image.is-covered img { 49 | -o-object-fit: cover; 50 | object-fit: cover; 51 | -o-object-position: center center; 52 | object-position: center center; 53 | height: 100%; 54 | width: 100% 55 | } 56 | 57 | .slider-container .slider-item .video-container { 58 | height: 
0; 59 | padding-bottom: 0; 60 | padding-top: 56.25%; 61 | margin: 0; 62 | position: relative 63 | } 64 | 65 | .slider-container .slider-item .video-container.is-1by1, 66 | .slider-container .slider-item .video-container.is-square { 67 | padding-top: 100% 68 | } 69 | 70 | .slider-container .slider-item .video-container.is-4by3 { 71 | padding-top: 75% 72 | } 73 | 74 | .slider-container .slider-item .video-container.is-21by9 { 75 | padding-top: 42.857143% 76 | } 77 | 78 | .slider-container .slider-item .video-container embed, 79 | .slider-container .slider-item .video-container iframe, 80 | .slider-container .slider-item .video-container object { 81 | position: absolute; 82 | top: 0; 83 | left: 0; 84 | width: 100% !important; 85 | height: 100% !important 86 | } 87 | 88 | .slider-navigation-next, 89 | .slider-navigation-previous { 90 | display: flex; 91 | justify-content: center; 92 | align-items: center; 93 | position: absolute; 94 | width: 42px; 95 | height: 42px; 96 | background: #fff center center no-repeat; 97 | background-size: 20px 20px; 98 | border: 1px solid #fff; 99 | border-radius: 25091983px; 100 | box-shadow: 0 2px 5px #3232321a; 101 | top: 50%; 102 | margin-top: -20px; 103 | left: 0; 104 | cursor: pointer; 105 | transition: opacity .3s, -webkit-transform .3s; 106 | transition: transform .3s, opacity .3s; 107 | transition: transform .3s, opacity .3s, -webkit-transform .3s 108 | } 109 | 110 | .slider-navigation-next:hover, 111 | .slider-navigation-previous:hover { 112 | -webkit-transform: scale(1.2); 113 | transform: scale(1.2) 114 | } 115 | 116 | .slider-navigation-next.is-hidden, 117 | .slider-navigation-previous.is-hidden { 118 | display: none; 119 | opacity: 0 120 | } 121 | 122 | .slider-navigation-next svg, 123 | .slider-navigation-previous svg { 124 | width: 25% 125 | } 126 | 127 | .slider-navigation-next { 128 | left: auto; 129 | right: 0; 130 | background: #fff center center no-repeat; 131 | background-size: 20px 20px 132 | } 133 | 134 | .slider-pagination { 135 | display: none; 136 | justify-content: center; 137 | align-items: center; 138 | position: absolute; 139 | bottom: 0; 140 | left: 0; 141 | right: 0; 142 | padding: .5rem 1rem; 143 | text-align: center 144 | } 145 | 146 | .slider-pagination .slider-page { 147 | background: #fff; 148 | width: 10px; 149 | height: 10px; 150 | border-radius: 25091983px; 151 | display: inline-block; 152 | margin: 0 3px; 153 | box-shadow: 0 2px 5px #3232321a; 154 | transition: -webkit-transform .3s; 155 | transition: transform .3s; 156 | transition: transform .3s, -webkit-transform .3s; 157 | cursor: pointer 158 | } 159 | 160 | .slider-pagination .slider-page.is-active, 161 | .slider-pagination .slider-page:hover { 162 | -webkit-transform: scale(1.4); 163 | transform: scale(1.4) 164 | } 165 | 166 | @media screen and (min-width:800px) { 167 | .slider-pagination { 168 | display: flex 169 | } 170 | } 171 | 172 | .hero.has-carousel { 173 | position: relative 174 | } 175 | 176 | .hero.has-carousel+.hero-body, 177 | .hero.has-carousel+.hero-footer, 178 | .hero.has-carousel+.hero-head { 179 | z-index: 10; 180 | overflow: hidden 181 | } 182 | 183 | .hero.has-carousel .hero-carousel { 184 | position: absolute; 185 | top: 0; 186 | left: 0; 187 | bottom: 0; 188 | right: 0; 189 | height: auto; 190 | border: none; 191 | margin: auto; 192 | padding: 0; 193 | z-index: 0 194 | } 195 | 196 | .hero.has-carousel .hero-carousel .slider { 197 | width: 100%; 198 | max-width: 100%; 199 | overflow: hidden; 200 | height: 100% !important; 201 | max-height: 100%; 
202 | z-index: 0 203 | } 204 | 205 | .hero.has-carousel .hero-carousel .slider .has-background { 206 | max-height: 100% 207 | } 208 | 209 | .hero.has-carousel .hero-carousel .slider .has-background .is-background { 210 | -o-object-fit: cover; 211 | object-fit: cover; 212 | -o-object-position: center center; 213 | object-position: center center; 214 | height: 100%; 215 | width: 100% 216 | } 217 | 218 | .hero.has-carousel .hero-body { 219 | margin: 0 3rem; 220 | z-index: 10 221 | } -------------------------------------------------------------------------------- /validation_lq/data_utils.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | from torch.utils.data import Dataset, DataLoader 3 | from torchvision import transforms 4 | from PIL import Image 5 | from torchvision.transforms import InterpolationMode 6 | 7 | class ListDatasetWithIndex(Dataset): 8 | def __init__(self, img_list, image_size, image_is_saved_with_swapped_B_and_R=True): 9 | super(ListDatasetWithIndex, self).__init__() 10 | 11 | # image_is_saved_with_swapped_B_and_R: correctly saved image should have this set to False 12 | # face_emore/img has images saved with B and G (of RGB) swapped. 13 | # Since training data loader uses PIL (results in RGB) to read image 14 | # and validation data loader uses cv2 (results in BGR) to read image, this swap was okay. 15 | # But if you want to evaluate on the training data such as face_emore/img (B and G swapped), 16 | # then you should set image_is_saved_with_swapped_B_and_R=True 17 | 18 | self.img_list = img_list 19 | self.transform = transforms.Compose([ 20 | transforms.ToTensor(), 21 | transforms.Resize(size=(image_size,image_size), interpolation=InterpolationMode.BICUBIC), 22 | transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]) 23 | ]) 24 | self.image_is_saved_with_swapped_B_and_R = image_is_saved_with_swapped_B_and_R 25 | 26 | def __len__(self): 27 | return len(self.img_list) 28 | 29 | def __getitem__(self, idx): 30 | 31 | if self.image_is_saved_with_swapped_B_and_R: 32 | with open(self.img_list[idx], 'rb') as f: 33 | img = Image.open(f) 34 | img = img.convert('RGB') 35 | img = self.transform(img) 36 | 37 | else: 38 | # ArcFace Pytorch 39 | img = cv2.imread(self.img_list[idx]) 40 | # img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 41 | img = img[:,:,:3] 42 | 43 | img = Image.fromarray(img) 44 | # img = np.moveaxis(img, -1, 0) 45 | img = self.transform(img) 46 | return img, idx 47 | 48 | 49 | class ListDataset(Dataset): 50 | def __init__(self, img_list, image_size, image_is_saved_with_swapped_B_and_R=True): 51 | super(ListDataset, self).__init__() 52 | 53 | # image_is_saved_with_swapped_B_and_R: correctly saved image should have this set to False 54 | # face_emore/img has images saved with B and G (of RGB) swapped. 55 | # Since training data loader uses PIL (results in RGB) to read image 56 | # and validation data loader uses cv2 (results in BGR) to read image, this swap was okay. 
57 | # But if you want to evaluate on the training data such as face_emore/img (B and G swapped), 58 | # then you should set image_is_saved_with_swapped_B_and_R=True 59 | 60 | self.img_list = img_list 61 | self.transform = transforms.Compose( 62 | [transforms.ToTensor(), 63 | transforms.Resize(size=(image_size,image_size), interpolation=InterpolationMode.BICUBIC), 64 | transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]) 65 | 66 | self.image_is_saved_with_swapped_B_and_R = image_is_saved_with_swapped_B_and_R 67 | 68 | 69 | def __len__(self): 70 | return len(self.img_list) 71 | 72 | def __getitem__(self, idx): 73 | image_path = self.img_list[idx] 74 | img = cv2.imread(image_path) 75 | img = img[:, :, :3] 76 | 77 | if self.image_is_saved_with_swapped_B_and_R: 78 | print('check if it really should be on') 79 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 80 | 81 | img = Image.fromarray(img) 82 | img = self.transform(img) 83 | return img, idx 84 | 85 | 86 | def prepare_imagelist_dataloader(img_list, batch_size, image_size, num_workers=0, image_is_saved_with_swapped_B_and_R=True): 87 | # image_is_saved_with_swapped_B_and_R: correctly saved image should have this set to False 88 | # face_emore/img has images saved with B and G (of RGB) swapped. 89 | # Since training data loader uses PIL (results in RGB) to read image 90 | # and validation data loader uses cv2 (results in BGR) to read image, this swap was okay. 91 | # But if you want to evaluate on the training data such as face_emore/img (B and G swapped), 92 | # then you should set image_is_saved_with_swapped_B_and_R=True 93 | 94 | image_dataset = ListDatasetWithIndex(img_list, image_size, image_is_saved_with_swapped_B_and_R) 95 | dataloader = DataLoader(image_dataset, batch_size=batch_size, shuffle=False, drop_last=False, num_workers=num_workers) 96 | return dataloader 97 | 98 | 99 | def prepare_dataloader(img_list, batch_size, image_size, num_workers=0, image_is_saved_with_swapped_B_and_R=True): 100 | # image_is_saved_with_swapped_B_and_R: correctly saved image should have this set to False 101 | # face_emore/img has images saved with B and G (of RGB) swapped. 102 | # Since training data loader uses PIL (results in RGB) to read image 103 | # and validation data loader uses cv2 (results in BGR) to read image, this swap was okay. 
104 | # But if you want to evaluate on the training data such as face_emore/img (B and G swapped), 105 | # then you should set image_is_saved_with_swapped_B_and_R=True 106 | 107 | image_dataset = ListDataset(img_list, image_size, image_is_saved_with_swapped_B_and_R=image_is_saved_with_swapped_B_and_R) 108 | dataloader = DataLoader(image_dataset, 109 | batch_size=batch_size, 110 | shuffle=False, 111 | drop_last=False, 112 | num_workers=num_workers) 113 | return dataloader -------------------------------------------------------------------------------- /utils/utils_callbacks.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import time 4 | from typing import List 5 | 6 | import torch 7 | 8 | 9 | from utils.utils_logging import AverageMeter 10 | from torch.utils.tensorboard import SummaryWriter 11 | from torch import distributed 12 | import sys 13 | sys.path.insert(0, os.path.join(os.getcwd(), "utils")) 14 | import evaluate_utils 15 | 16 | 17 | class CallBackVerification(object): 18 | 19 | def __init__(self, val_targets, rec_prefix, summary_writer=None, image_size=(112, 112), wandb_logger=None): 20 | self.rank: int = distributed.get_rank() 21 | self.highest_acc: float = 0.0 22 | self.highest_acc_list: List[float] = [0.0] * len(val_targets) 23 | self.ver_list: List[object] = [] 24 | self.ver_name_list: List[str] = [] 25 | if self.rank == 0: 26 | self.init_dataset(val_targets=val_targets, data_dir=rec_prefix, image_size=image_size) 27 | 28 | self.summary_writer = summary_writer 29 | self.wandb_logger = wandb_logger 30 | 31 | def ver_test(self, backbone: torch.nn.Module, global_step: int): 32 | results = [] 33 | for i in range(len(self.ver_list)): 34 | acc1, std1, acc2, std2, xnorm, embeddings_list = evaluate_utils.test( 35 | self.ver_list[i], backbone, 10, 10) 36 | logging.info('[%s][%d]XNorm: %f' % (self.ver_name_list[i], global_step, xnorm)) 37 | logging.info('[%s][%d]Accuracy-Flip: %1.5f+-%1.5f' % (self.ver_name_list[i], global_step, acc2, std2)) 38 | 39 | self.summary_writer: SummaryWriter 40 | self.summary_writer.add_scalar(tag=self.ver_name_list[i], scalar_value=acc2, global_step=global_step, ) 41 | if self.wandb_logger: 42 | import wandb 43 | self.wandb_logger.log({ 44 | f'Acc/val-Acc1 {self.ver_name_list[i]}': acc1, 45 | f'Acc/val-Acc2 {self.ver_name_list[i]}': acc2, 46 | # f'Acc/val-std1 {self.ver_name_list[i]}': std1, 47 | # f'Acc/val-std2 {self.ver_name_list[i]}': acc2, 48 | }) 49 | 50 | if acc2 > self.highest_acc_list[i]: 51 | self.highest_acc_list[i] = acc2 52 | logging.info( 53 | '[%s][%d]Accuracy-Highest: %1.5f' % (self.ver_name_list[i], global_step, self.highest_acc_list[i])) 54 | results.append(acc2) 55 | 56 | def init_dataset(self, val_targets, data_dir, image_size): 57 | for name in val_targets: 58 | path = os.path.join(data_dir, name + ".bin") 59 | if os.path.exists(path): 60 | data_set = evaluate_utils.load_bin(path, image_size) 61 | self.ver_list.append(data_set) 62 | self.ver_name_list.append(name) 63 | 64 | def __call__(self, num_update, backbone: torch.nn.Module): 65 | if self.rank == 0 and num_update > 0: 66 | backbone.eval() 67 | self.ver_test(backbone, num_update) 68 | backbone.train() 69 | 70 | 71 | class CallBackLogging(object): 72 | def __init__(self, frequent, total_step, batch_size, start_step=0,writer=None): 73 | self.frequent: int = frequent 74 | self.rank: int = distributed.get_rank() 75 | self.world_size: int = distributed.get_world_size() 76 | self.time_start = time.time() 77 | 
self.total_step: int = total_step 78 | self.start_step: int = start_step 79 | self.batch_size: int = batch_size 80 | self.writer = writer 81 | 82 | self.init = False 83 | self.tic = 0 84 | 85 | def __call__(self, 86 | global_step: int, 87 | loss: AverageMeter, 88 | epoch: int, 89 | fp16: bool, 90 | learning_rate: float, 91 | grad_scaler: torch.cuda.amp.GradScaler): 92 | if self.rank == 0 and global_step > 0 and global_step % self.frequent == 0: 93 | if self.init: 94 | try: 95 | speed: float = self.frequent * self.batch_size / (time.time() - self.tic) 96 | speed_total = speed * self.world_size 97 | except ZeroDivisionError: 98 | speed_total = float('inf') 99 | 100 | time_now = time.time() 101 | time_sec = int(time_now - self.time_start) 102 | time_sec_avg = time_sec / (global_step - self.start_step + 1) 103 | eta_sec = time_sec_avg * (self.total_step - global_step - 1) 104 | time_for_end = eta_sec/3600 105 | if self.writer is not None: 106 | self.writer.add_scalar('time_for_end', time_for_end, global_step) 107 | self.writer.add_scalar('learning_rate', learning_rate, global_step) 108 | self.writer.add_scalar('loss', loss.avg, global_step) 109 | if fp16: 110 | msg = "Speed %.2f samples/sec Loss %.4f LearningRate %.6f Epoch: %d Global Step: %d " \ 111 | "Fp16 Grad Scale: %2.f Required: %1.f hours" % ( 112 | speed_total, loss.avg, learning_rate, epoch, global_step, 113 | grad_scaler.get_scale(), time_for_end 114 | ) 115 | else: 116 | msg = "Speed %.2f samples/sec Loss %.4f LearningRate %.6f Epoch: %d Global Step: %d " \ 117 | "Required: %1.f hours" % ( 118 | speed_total, loss.avg, learning_rate, epoch, global_step, time_for_end 119 | ) 120 | logging.info(msg) 121 | loss.reset() 122 | self.tic = time.time() 123 | else: 124 | self.init = True 125 | self.tic = time.time() 126 | -------------------------------------------------------------------------------- /validation_lq/validate_ijbs.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from tqdm import tqdm 4 | import data_utils 5 | import argparse 6 | import pandas as pd 7 | import evaluate_helper 8 | import sys, os 9 | sys.path.insert(0, os.getcwd()) 10 | from backbones import get_model 11 | 12 | 13 | def str2bool(v): 14 | return v.lower() in ("true", "t", "1") 15 | 16 | 17 | def load_pretrained_model(model, model_name, gpu, lora_rank, lora_scale, use_lora): 18 | # load model and pretrained statedict 19 | ckpt_path = model_name 20 | model = get_model(model, dropout=0.0, fp16=False, num_features=512, r=lora_rank, scale=lora_scale, use_lora=use_lora) 21 | 22 | # model = net.build_model(arch) 23 | statedict = torch.load(ckpt_path, map_location=torch.device('cuda:' + str(gpu))) 24 | model.load_state_dict(statedict) 25 | model.eval() 26 | return model 27 | 28 | def get_save_path(model_load_path): 29 | directory, _ = os.path.split(model_load_path) 30 | results_save_path = os.path.join(directory, 'results') 31 | 32 | return results_save_path 33 | 34 | if __name__ == '__main__': 35 | parser = argparse.ArgumentParser(description='') 36 | parser.add_argument("--data_root", type=str, default='/mnt/store/knaraya4/data/ijbs_aligned_180') 37 | parser.add_argument('--model_type', type=str, default='r50') 38 | parser.add_argument('--model_load_path', type=str, default='ijcb_cosface') 39 | parser.add_argument('--batch_size', type=int, default=512) 40 | parser.add_argument('--image_size', type=int, default=120) 41 | parser.add_argument('--lora_rank', type=int, default=8) 42 | 
parser.add_argument('--lora_scale', type=int, default=1)
43 |     parser.add_argument('--use_lora', action='store_true')
44 |     parser.add_argument('--gpu', default=0, type=int, help='gpu id')
45 |     parser.add_argument('--fuse_match_method', type=str, default='pre_norm_vector_add_cos',
46 |                         choices=('pre_norm_vector_add_cos',))
47 |     parser.add_argument('--save_features', type=bool, default=True)
48 |     args = parser.parse_args()
49 | 
50 |     # load model
51 |     model_load_path = args.model_load_path
52 |     model = load_pretrained_model(args.model_type, model_load_path, args.gpu, args.lora_rank, args.lora_scale, args.use_lora)
53 |     model.to('cuda:{}'.format(args.gpu))
54 | 
55 |     # make result save root
56 |     save_root = get_save_path(model_load_path)
57 |     os.makedirs(save_root, exist_ok=True)
58 |     image_path_df = pd.read_csv('/mnt/store/knaraya4/data/IJBS/image_paths_180.csv', index_col=0)
59 |     all_image_paths = image_path_df['path'].apply(lambda x: os.path.join(args.data_root, x)).tolist()
60 | 
61 |     num_partition = 100
62 |     dataset_split = np.array_split(all_image_paths, num_partition)
63 | 
64 |     print('total {} images'.format(len(all_image_paths)))
65 |     all_features = []
66 |     for partition_idx in tqdm(range(num_partition)):
67 | 
68 |         image_paths = list(dataset_split[partition_idx])
69 |         dataloader = data_utils.prepare_imagelist_dataloader(image_paths, batch_size=args.batch_size, image_size=args.image_size, num_workers=8)
70 | 
71 |         size = len(dataloader.dataset)
72 |         num_batches = len(dataloader)
73 |         model.eval()
74 | 
75 |         features = []
76 |         norms = []
77 |         prev_max_idx = 0
78 |         with torch.no_grad():
79 |             for iter_idx, (img, idx) in enumerate(dataloader):
80 |                 assert idx.max().item() > prev_max_idx
81 |                 prev_max_idx = idx.max().item()  # sanity check that the dataloader preserves sample order
82 |                 if iter_idx % 100 == 0:
83 |                     print(f"{iter_idx} / {len(dataloader)} done")
84 |                 feature = model(img.to('cuda:{}'.format(args.gpu)))
85 | 
86 |                 if isinstance(feature, tuple) and len(feature) == 2:
87 |                     feature, norm = feature
88 |                     features.append(feature.cpu().numpy())
89 |                     norms.append(norm.cpu().numpy())
90 |                 else:
91 |                     norm = torch.norm(feature, 2, 1, True)
92 |                     features.append(feature.cpu().numpy())
93 |                     norms.append(norm.cpu().numpy())
94 | 
95 |         features = np.concatenate(features, axis=0)
96 |         if args.save_features:
97 |             save_path = os.path.join(save_root, 'feature_extracted/ijbs_pred_{}_{}.npy'.format(args.model_type, partition_idx))
98 |             os.makedirs(os.path.dirname(save_path), exist_ok=True)
99 |             np.save(save_path, features)
100 | 
101 |         if len(norms) > 0:
102 |             norms = np.concatenate(norms, axis=0)
103 |             if args.save_features:
104 |                 save_path = os.path.join(save_root, 'feature_extracted/ijbs_pred_{}_norm_{}.npy'.format(args.model_type, partition_idx))
105 |                 np.save(save_path, norms)
106 | 
107 |         #### Resume ###
108 |         # features = np.load(os.path.join(save_root, 'feature_extracted/ijbs_pred_{}_{}.npy'.format(args.model_type, partition_idx)))
109 |         # norms = np.load(os.path.join(save_root, 'feature_extracted/ijbs_pred_{}_norm_{}.npy'.format(args.model_type, partition_idx)))
110 | 
111 | 
112 |         if args.fuse_match_method == 'pre_norm_vector_add_cos':
113 |             features = features * norms
114 |         all_features.append(features)
115 |     all_features = np.concatenate(all_features, axis=0)
116 | 
117 |     # prepare savedir
118 |     os.makedirs(os.path.join(save_root, 'eval_result'), exist_ok=True)
119 |     # evaluate
120 |     evaluate_helper.run_eval_with_features(save_root=save_root,
121 |                                            features=all_features,
122 |                                            image_paths=all_image_paths,
123 |                                            get_retrievals=True,
124 | 
fuse_match_method=args.fuse_match_method, 125 | ijbs_proto_path=None) -------------------------------------------------------------------------------- /backbones/mobilefacenet.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Adapted from https://github.com/cavalleria/cavaface.pytorch/blob/master/backbone/mobilefacenet.py 3 | Original author cavalleria 4 | ''' 5 | 6 | import torch.nn as nn 7 | from torch.nn import Linear, Conv2d, BatchNorm1d, BatchNorm2d, PReLU, Sequential, Module 8 | import torch 9 | 10 | 11 | class Flatten(Module): 12 | def forward(self, x): 13 | return x.view(x.size(0), -1) 14 | 15 | 16 | class ConvBlock(Module): 17 | def __init__(self, in_c, out_c, kernel=(1, 1), stride=(1, 1), padding=(0, 0), groups=1): 18 | super(ConvBlock, self).__init__() 19 | self.layers = nn.Sequential( 20 | Conv2d(in_c, out_c, kernel, groups=groups, stride=stride, padding=padding, bias=False), 21 | BatchNorm2d(num_features=out_c), 22 | PReLU(num_parameters=out_c) 23 | ) 24 | 25 | def forward(self, x): 26 | return self.layers(x) 27 | 28 | 29 | class LinearBlock(Module): 30 | def __init__(self, in_c, out_c, kernel=(1, 1), stride=(1, 1), padding=(0, 0), groups=1): 31 | super(LinearBlock, self).__init__() 32 | self.layers = nn.Sequential( 33 | Conv2d(in_c, out_c, kernel, stride, padding, groups=groups, bias=False), 34 | BatchNorm2d(num_features=out_c) 35 | ) 36 | 37 | def forward(self, x): 38 | return self.layers(x) 39 | 40 | 41 | class DepthWise(Module): 42 | def __init__(self, in_c, out_c, residual=False, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=1): 43 | super(DepthWise, self).__init__() 44 | self.residual = residual 45 | self.layers = nn.Sequential( 46 | ConvBlock(in_c, out_c=groups, kernel=(1, 1), padding=(0, 0), stride=(1, 1)), 47 | ConvBlock(groups, groups, groups=groups, kernel=kernel, padding=padding, stride=stride), 48 | LinearBlock(groups, out_c, kernel=(1, 1), padding=(0, 0), stride=(1, 1)) 49 | ) 50 | 51 | def forward(self, x): 52 | short_cut = None 53 | if self.residual: 54 | short_cut = x 55 | x = self.layers(x) 56 | if self.residual: 57 | output = short_cut + x 58 | else: 59 | output = x 60 | return output 61 | 62 | 63 | class Residual(Module): 64 | def __init__(self, c, num_block, groups, kernel=(3, 3), stride=(1, 1), padding=(1, 1)): 65 | super(Residual, self).__init__() 66 | modules = [] 67 | for _ in range(num_block): 68 | modules.append(DepthWise(c, c, True, kernel, stride, padding, groups)) 69 | self.layers = Sequential(*modules) 70 | 71 | def forward(self, x): 72 | return self.layers(x) 73 | 74 | 75 | class GDC(Module): 76 | def __init__(self, embedding_size): 77 | super(GDC, self).__init__() 78 | self.layers = nn.Sequential( 79 | LinearBlock(512, 512, groups=512, kernel=(7, 7), stride=(1, 1), padding=(0, 0)), 80 | Flatten(), 81 | Linear(512, embedding_size, bias=False), 82 | BatchNorm1d(embedding_size)) 83 | 84 | def forward(self, x): 85 | return self.layers(x) 86 | 87 | 88 | class MobileFaceNet(Module): 89 | def __init__(self, fp16=False, num_features=512, blocks=(1, 4, 6, 2), scale=2): 90 | super(MobileFaceNet, self).__init__() 91 | self.scale = scale 92 | self.fp16 = fp16 93 | self.layers = nn.ModuleList() 94 | self.layers.append( 95 | ConvBlock(3, 64 * self.scale, kernel=(3, 3), stride=(2, 2), padding=(1, 1)) 96 | ) 97 | if blocks[0] == 1: 98 | self.layers.append( 99 | ConvBlock(64 * self.scale, 64 * self.scale, kernel=(3, 3), stride=(1, 1), padding=(1, 1), groups=64) 100 | ) 101 | else: 102 | self.layers.append( 
103 | Residual(64 * self.scale, num_block=blocks[0], groups=128, kernel=(3, 3), stride=(1, 1), padding=(1, 1)), 104 | ) 105 | 106 | self.layers.extend( 107 | [ 108 | DepthWise(64 * self.scale, 64 * self.scale, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=128), 109 | Residual(64 * self.scale, num_block=blocks[1], groups=128, kernel=(3, 3), stride=(1, 1), padding=(1, 1)), 110 | DepthWise(64 * self.scale, 128 * self.scale, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=256), 111 | Residual(128 * self.scale, num_block=blocks[2], groups=256, kernel=(3, 3), stride=(1, 1), padding=(1, 1)), 112 | DepthWise(128 * self.scale, 128 * self.scale, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=512), 113 | Residual(128 * self.scale, num_block=blocks[3], groups=256, kernel=(3, 3), stride=(1, 1), padding=(1, 1)), 114 | ]) 115 | 116 | self.conv_sep = ConvBlock(128 * self.scale, 512, kernel=(1, 1), stride=(1, 1), padding=(0, 0)) 117 | self.features = GDC(num_features) 118 | self._initialize_weights() 119 | 120 | def _initialize_weights(self): 121 | for m in self.modules(): 122 | if isinstance(m, nn.Conv2d): 123 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 124 | if m.bias is not None: 125 | m.bias.data.zero_() 126 | elif isinstance(m, nn.BatchNorm2d): 127 | m.weight.data.fill_(1) 128 | m.bias.data.zero_() 129 | elif isinstance(m, nn.Linear): 130 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 131 | if m.bias is not None: 132 | m.bias.data.zero_() 133 | 134 | def forward(self, x): 135 | with torch.cuda.amp.autocast(self.fp16): 136 | for func in self.layers: 137 | x = func(x) 138 | x = self.conv_sep(x.float() if self.fp16 else x) 139 | x = self.features(x) 140 | return x 141 | 142 | 143 | def get_mbf(fp16, num_features, blocks=(1, 4, 6, 2), scale=2): 144 | return MobileFaceNet(fp16, num_features, blocks, scale=scale) 145 | 146 | def get_mbf_large(fp16, num_features, blocks=(2, 8, 12, 4), scale=4): 147 | return MobileFaceNet(fp16, num_features, blocks, scale=scale) 148 | -------------------------------------------------------------------------------- /validation_lq/validate_ijbs_iqa.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from tqdm import tqdm 4 | import data_utils 5 | import argparse 6 | import pandas as pd 7 | import evaluate_helper 8 | import sys, os 9 | sys.path.insert(0, os.getcwd()) 10 | from backbones import get_model 11 | import pyiqa 12 | 13 | def str2bool(v): 14 | return v.lower() in ("true", "t", "1") 15 | 16 | 17 | def load_pretrained_model(model, model_name, gpus, lora_rank, lora_scale, use_lora): 18 | # load model and pretrained statedict 19 | ckpt_path = model_name 20 | model = get_model(model, dropout=0.0, fp16=False, num_features=512, r=lora_rank, scale=lora_scale, use_lora=use_lora) 21 | 22 | statedict = torch.load(ckpt_path, map_location=torch.device('cuda:' + str(gpus[0]))) 23 | model.load_state_dict(statedict) 24 | model = torch.nn.DataParallel(model, device_ids=gpus) 25 | model.eval() 26 | return model 27 | 28 | 29 | def get_save_path(model_load_path): 30 | directory, _ = os.path.split(model_load_path) 31 | results_save_path = os.path.join(directory, 'results') 32 | 33 | return results_save_path 34 | 35 | def generate_alpha(img, iqa, thresh): 36 | device = img.device 37 | BS, C, H, W = img.shape 38 | alpha = torch.zeros((BS, 1), dtype=torch.float32, device=device) 39 | 40 | score = iqa(img) 41 | threshold = thresh 42 | for i in 
range(BS):
43 |         if score[i] == threshold:
44 |             alpha[i] = 0.5
45 |         elif score[i] < threshold:
46 |             alpha[i] = 0.5 - (threshold - score[i])
47 |         else:
48 |             alpha[i] = 0.5 + (score[i] - threshold)
49 |     return alpha
50 | 
51 | if __name__ == '__main__':
52 |     parser = argparse.ArgumentParser(description='')
53 |     parser.add_argument("--data_root", type=str, default='/mnt/store/knaraya4/data/ijbs_aligned_180')
54 |     parser.add_argument('--model_type', type=str, default='r50')
55 |     parser.add_argument('--model_load_path', type=str, default='ijcb_cosface')
56 |     parser.add_argument('--batch_size', type=int, default=512)
57 |     parser.add_argument('--image_size', type=int, default=120)
58 |     parser.add_argument('--lora_rank', type=int, default=8)
59 |     parser.add_argument('--lora_scale', type=int, default=1)
60 |     parser.add_argument('--use_lora', action='store_true')
61 |     parser.add_argument('--gpus', nargs='+', type=int, default=[0], help='List of GPU ids')
62 |     parser.add_argument('--fuse_match_method', type=str, default='pre_norm_vector_add_cos',
63 |                         choices=('pre_norm_vector_add_cos',))
64 |     parser.add_argument('--save_features', type=bool, default=True)
65 |     parser.add_argument('--csv_path', type=str, default='/mnt/store/knaraya4/data/IJBS/image_paths_180.csv')
66 |     parser.add_argument('--iqa', type=str, default='brisque')
67 |     parser.add_argument('--threshold', type=float, default=0.5)
68 |     args = parser.parse_args()
69 | 
70 |     # load model
71 |     model_load_path = args.model_load_path
72 |     model = load_pretrained_model(args.model_type, model_load_path, args.gpus, args.lora_rank, args.lora_scale, args.use_lora)
73 |     model.to('cuda:{}'.format(args.gpus[0]))
74 | 
75 |     # make result save root
76 |     save_root = get_save_path(model_load_path)
77 |     os.makedirs(save_root, exist_ok=True)
78 |     image_path_df = pd.read_csv(args.csv_path, index_col=0)
79 |     all_image_paths = image_path_df['path'].apply(lambda x: os.path.join(args.data_root, x)).tolist()
80 | 
81 |     num_partition = 100
82 |     dataset_split = np.array_split(all_image_paths, num_partition)
83 | 
84 |     print('total {} images'.format(len(all_image_paths)))
85 |     all_features = []
86 |     device = "cuda:" + str(args.gpus[0])
87 |     if args.iqa == "brisque":
88 |         iqa = pyiqa.create_metric('brisque').cuda()
89 |     elif args.iqa == "cnniqa":
90 |         iqa = pyiqa.create_metric('cnniqa').cuda()
91 |     threshold = args.threshold
92 |     for partition_idx in tqdm(range(num_partition)):
93 | 
94 |         image_paths = list(dataset_split[partition_idx])
95 |         dataloader = data_utils.prepare_imagelist_dataloader(image_paths, batch_size=args.batch_size, image_size=args.image_size, num_workers=8)
96 | 
97 |         size = len(dataloader.dataset)
98 |         num_batches = len(dataloader)
99 |         model.eval()
100 | 
101 |         features = []
102 |         norms = []
103 |         prev_max_idx = 0
104 |         with torch.no_grad():
105 |             for iter_idx, (img, idx) in enumerate(dataloader):
106 |                 assert idx.max().item() > prev_max_idx
107 |                 prev_max_idx = idx.max().item()  # sanity check that the dataloader preserves sample order
108 |                 if iter_idx % 100 == 0:
109 |                     print(f"{iter_idx} / {len(dataloader)} done")
110 |                 img = img.to('cuda:{}'.format(args.gpus[0]), non_blocking=True)
111 |                 alpha = generate_alpha(img, iqa, threshold)
112 |                 feature = model(img, alpha)
113 | 
114 |                 if isinstance(feature, tuple) and len(feature) == 2:
115 |                     feature, norm = feature
116 |                     features.append(feature.cpu().numpy())
117 |                     norms.append(norm.cpu().numpy())
118 |                 else:
119 |                     norm = torch.norm(feature, 2, 1, True)
120 |                     features.append(feature.cpu().numpy())
121 |                     norms.append(norm.cpu().numpy())
122 | 
123 | 
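# --- Editor's note (illustrative sketch, not part of the original script) ---
# generate_alpha() above maps a per-image quality score to the alpha weight that is passed to
# the *_iqa backbone: alpha equals 0.5 when the score equals the threshold, is smaller for
# lower-quality scores and larger for higher ones. All three branches reduce to the same
# expression, so a vectorised sketch with hypothetical scores would be:
#
#     import torch
#     scores = torch.tensor([[0.3], [0.5], [0.9]])   # hypothetical IQA scores, one per image
#     threshold = 0.5
#     alpha = 0.5 + (scores - threshold)             # -> tensor([[0.3], [0.5], [0.9]])
#
# which gives the same values as the per-element if/elif/else in generate_alpha().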
features = np.concatenate(features, axis=0) 124 | if args.save_features: 125 | save_path = os.path.join(save_root, 'feature_extracted/ijbs_pred_{}_{}.npy'.format(args.model_type, partition_idx)) 126 | os.makedirs(os.path.dirname(save_path), exist_ok=True) 127 | np.save(save_path, features) 128 | 129 | if len(norms) > 0: 130 | norms = np.concatenate(norms, axis=0) 131 | if args.save_features: 132 | save_path = os.path.join(save_root, 'feature_extracted/ijbs_pred_{}_norm_{}.npy'.format(args.model_type, partition_idx)) 133 | np.save(save_path, norms) 134 | 135 | if args.fuse_match_method == 'pre_norm_vector_add_cos': 136 | features = features * norms 137 | all_features.append(features) 138 | all_features = np.concatenate(all_features, axis=0) 139 | 140 | # prepare savedir 141 | os.makedirs(os.path.join(save_root, 'eval_result'), exist_ok=True) 142 | # evaluate 143 | evaluate_helper.run_eval_with_features(save_root=save_root, 144 | features=all_features, 145 | image_paths=all_image_paths, 146 | get_retrievals=True, 147 | fuse_match_method=args.fuse_match_method, 148 | ijbs_proto_path=None) -------------------------------------------------------------------------------- /backbones/__init__.py: -------------------------------------------------------------------------------- 1 | from .iresnet import iresnet18, iresnet34, iresnet50, iresnet100, iresnet200 2 | from .mobilefacenet import get_mbf 3 | 4 | 5 | def get_model(name, **kwargs): 6 | # resnet 7 | if name == "r18": 8 | r = kwargs.pop("r", 4) 9 | scale = kwargs.pop("scale", 1) 10 | use_lora = kwargs.pop('use_lora', False) 11 | return iresnet18(lora_rank=r, lora_scale=scale, pretrained=False, progress=True, use_lora=use_lora, **kwargs) 12 | elif name == "r34": 13 | r = kwargs.pop("r", 4) 14 | scale = kwargs.pop("scale", 1) 15 | use_lora = kwargs.pop('use_lora', False) 16 | return iresnet34(lora_rank=r, lora_scale=scale, pretrained=False, progress=True, use_lora=use_lora, **kwargs) 17 | elif name == "r50": 18 | r = kwargs.pop("r", 4) 19 | scale = kwargs.pop("scale", 1) 20 | use_lora = kwargs.pop('use_lora', False) 21 | return iresnet50(lora_rank=r, lora_scale=scale, pretrained=False, progress=True, use_lora=use_lora, **kwargs) 22 | elif name == "r100": 23 | r = kwargs.pop("r", 4) 24 | scale = kwargs.pop("scale", 1) 25 | use_lora = kwargs.pop('use_lora', False) 26 | return iresnet100(lora_rank=r, lora_scale=scale, pretrained=False, progress=True, use_lora=use_lora, **kwargs) 27 | elif name == "r200": 28 | r = kwargs.pop("r", 4) 29 | scale = kwargs.pop("scale", 1) 30 | use_lora = kwargs.pop('use_lora', False) 31 | return iresnet200(lora_rank=r, lora_scale=scale, pretrained=False, progress=True, use_lora=use_lora, **kwargs) 32 | elif name == "r2060": 33 | from .iresnet2060 import iresnet2060 34 | return iresnet2060(False, **kwargs) 35 | 36 | elif name == "mbf": 37 | fp16 = kwargs.get("fp16", False) 38 | num_features = kwargs.get("num_features", 512) 39 | return get_mbf(fp16=fp16, num_features=num_features) 40 | 41 | elif name == "mbf_large": 42 | from .mobilefacenet import get_mbf_large 43 | fp16 = kwargs.get("fp16", False) 44 | num_features = kwargs.get("num_features", 512) 45 | return get_mbf_large(fp16=fp16, num_features=num_features) 46 | 47 | elif name == "vit_t": 48 | num_features = kwargs.get("num_features", 512) 49 | from .vit import VisionTransformer 50 | return VisionTransformer( 51 | img_size=112, patch_size=9, num_classes=num_features, embed_dim=256, depth=12, 52 | num_heads=8, drop_path_rate=0.1, norm_layer="ln", mask_ratio=0.1) 53 | 
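# --- Editor's note (illustrative usage, not part of the original factory) ---
# The validation scripts above reach this factory through load_pretrained_model(); a minimal
# direct call, mirroring the arguments validate_ijbs.py passes (num_features, LoRA rank r,
# LoRA scale, use_lora), might look like the sketch below. The checkpoint path is hypothetical.
#
#     import torch
#     from backbones import get_model
#
#     net = get_model("r50", dropout=0.0, fp16=False, num_features=512,
#                     r=8, scale=1, use_lora=True)
#     net.load_state_dict(torch.load("model.pt", map_location="cpu"))  # hypothetical checkpoint
#     net.eval()
#     with torch.no_grad():
#         emb = net(torch.randn(1, 3, 112, 112))  # 512-d embedding for the iresnet backbones
#
# The Swin variants ("swin_256new", "swin_256new_iqa") are built for 120x120 inputs instead,
# and the *_iqa models additionally take the per-image alpha computed from an IQA metric at
# call time (see validate_ijbs_iqa.py above).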
54 | elif name == "vit_t_dp005_mask0": # For WebFace42M 55 | num_features = kwargs.get("num_features", 512) 56 | from .vit import VisionTransformer 57 | return VisionTransformer( 58 | img_size=112, patch_size=9, num_classes=num_features, embed_dim=256, depth=12, 59 | num_heads=8, drop_path_rate=0.05, norm_layer="ln", mask_ratio=0.0) 60 | 61 | elif name == "vit_s": 62 | num_features = kwargs.get("num_features", 512) 63 | from .vit import VisionTransformer 64 | return VisionTransformer( 65 | img_size=112, patch_size=9, num_classes=num_features, embed_dim=512, depth=12, 66 | num_heads=8, drop_path_rate=0.1, norm_layer="ln", mask_ratio=0.1) 67 | 68 | elif name == "vit_s_dp005_mask_0": # For WebFace42M 69 | num_features = kwargs.get("num_features", 512) 70 | from .vit import VisionTransformer 71 | return VisionTransformer( 72 | img_size=112, patch_size=9, num_classes=num_features, embed_dim=512, depth=12, 73 | num_heads=8, drop_path_rate=0.05, norm_layer="ln", mask_ratio=0.0) 74 | 75 | elif name == "vit_b": 76 | # this is a feature 77 | num_features = kwargs.get("num_features", 512) 78 | r = kwargs.pop("r", 4) 79 | scale = kwargs.pop("scale", 1) 80 | use_lora = kwargs.pop('use_lora', False) 81 | from .vit import VisionTransformer 82 | return VisionTransformer( 83 | lora_rank=r, lora_scale=scale, img_size=112, patch_size=9, num_classes=num_features, embed_dim=512, depth=24, 84 | num_heads=8, drop_path_rate=0.1, norm_layer="ln", mask_ratio=0.1, using_checkpoint=True, use_lora=use_lora) 85 | 86 | elif name == "vit_b_iqa": 87 | # this is a feature 88 | num_features = kwargs.get("num_features", 512) 89 | r = kwargs.pop("r", 4) 90 | scale = kwargs.pop("scale", 1) 91 | use_lora = kwargs.pop('use_lora', False) 92 | from .vit_iqa import VisionTransformer 93 | return VisionTransformer( 94 | lora_rank=r, lora_scale=scale, img_size=112, patch_size=9, num_classes=num_features, embed_dim=512, depth=24, 95 | num_heads=8, drop_path_rate=0.1, norm_layer="ln", mask_ratio=0.1, using_checkpoint=True, use_lora=use_lora) 96 | 97 | elif name == "vit_b_dp005_mask_005": # For WebFace42M 98 | # this is a feature 99 | num_features = kwargs.get("num_features", 512) 100 | from .vit import VisionTransformer 101 | return VisionTransformer( 102 | img_size=112, patch_size=9, num_classes=num_features, embed_dim=512, depth=24, 103 | num_heads=8, drop_path_rate=0.05, norm_layer="ln", mask_ratio=0.05, using_checkpoint=True) 104 | 105 | elif name == "vit_l_dp005_mask_005": # For WebFace42M 106 | # this is a feature 107 | num_features = kwargs.get("num_features", 512) 108 | from .vit import VisionTransformer 109 | return VisionTransformer( 110 | img_size=112, patch_size=9, num_classes=num_features, embed_dim=768, depth=24, 111 | num_heads=8, drop_path_rate=0.05, norm_layer="ln", mask_ratio=0.05, using_checkpoint=True) 112 | 113 | elif name == "vit_h": # For WebFace42M 114 | num_features = kwargs.get("num_features", 512) 115 | from .vit import VisionTransformer 116 | return VisionTransformer( 117 | img_size=112, patch_size=9, num_classes=num_features, embed_dim=1024, depth=48, 118 | num_heads=8, drop_path_rate=0.1, norm_layer="ln", mask_ratio=0, using_checkpoint=True) 119 | 120 | elif name=="swin_256new": 121 | num_features = kwargs.get("num_features", 512) 122 | r = kwargs.pop("r", 4) 123 | scale = kwargs.pop("scale", 1) 124 | use_lora = kwargs.pop('use_lora', False) 125 | from .swin_models import SwinTransformer 126 | kwargs['reso']=120 127 | return SwinTransformer(lora_rank=r, lora_scale=scale, img_size=120, patch_size=6, 
in_chans=3, num_classes=512, 128 | embed_dim=384, depths=[2,18,2], num_heads=[ 8, 16,16], 129 | window_size=5, use_lora=use_lora, **kwargs) 130 | 131 | elif name=="swin_256new_iqa": 132 | num_features = kwargs.get("num_features", 512) 133 | r = kwargs.pop("r", 4) 134 | scale = kwargs.pop("scale", 1) 135 | use_lora = kwargs.pop('use_lora', False) 136 | from .swin_models_iqa import SwinTransformer 137 | kwargs['reso']=120 138 | return SwinTransformer(lora_rank=r, lora_scale=scale, img_size=120, patch_size=6, in_chans=3, num_classes=512, 139 | embed_dim=384, depths=[2,18,2], num_heads=[ 8, 16,16], 140 | window_size=5, use_lora=use_lora, **kwargs) 141 | 142 | 143 | else: 144 | raise ValueError() 145 | -------------------------------------------------------------------------------- /backbones/iresnet2060.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | assert torch.__version__ >= "1.8.1" 5 | from torch.utils.checkpoint import checkpoint_sequential 6 | 7 | __all__ = ['iresnet2060'] 8 | 9 | 10 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): 11 | """3x3 convolution with padding""" 12 | return nn.Conv2d(in_planes, 13 | out_planes, 14 | kernel_size=3, 15 | stride=stride, 16 | padding=dilation, 17 | groups=groups, 18 | bias=False, 19 | dilation=dilation) 20 | 21 | 22 | def conv1x1(in_planes, out_planes, stride=1): 23 | """1x1 convolution""" 24 | return nn.Conv2d(in_planes, 25 | out_planes, 26 | kernel_size=1, 27 | stride=stride, 28 | bias=False) 29 | 30 | 31 | class IBasicBlock(nn.Module): 32 | expansion = 1 33 | 34 | def __init__(self, inplanes, planes, stride=1, downsample=None, 35 | groups=1, base_width=64, dilation=1): 36 | super(IBasicBlock, self).__init__() 37 | if groups != 1 or base_width != 64: 38 | raise ValueError('BasicBlock only supports groups=1 and base_width=64') 39 | if dilation > 1: 40 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock") 41 | self.bn1 = nn.BatchNorm2d(inplanes, eps=1e-05, ) 42 | self.conv1 = conv3x3(inplanes, planes) 43 | self.bn2 = nn.BatchNorm2d(planes, eps=1e-05, ) 44 | self.prelu = nn.PReLU(planes) 45 | self.conv2 = conv3x3(planes, planes, stride) 46 | self.bn3 = nn.BatchNorm2d(planes, eps=1e-05, ) 47 | self.downsample = downsample 48 | self.stride = stride 49 | 50 | def forward(self, x): 51 | identity = x 52 | out = self.bn1(x) 53 | out = self.conv1(out) 54 | out = self.bn2(out) 55 | out = self.prelu(out) 56 | out = self.conv2(out) 57 | out = self.bn3(out) 58 | if self.downsample is not None: 59 | identity = self.downsample(x) 60 | out += identity 61 | return out 62 | 63 | 64 | class IResNet(nn.Module): 65 | fc_scale = 7 * 7 66 | 67 | def __init__(self, 68 | block, layers, dropout=0, num_features=512, zero_init_residual=False, 69 | groups=1, width_per_group=64, replace_stride_with_dilation=None, fp16=False): 70 | super(IResNet, self).__init__() 71 | self.fp16 = fp16 72 | self.inplanes = 64 73 | self.dilation = 1 74 | if replace_stride_with_dilation is None: 75 | replace_stride_with_dilation = [False, False, False] 76 | if len(replace_stride_with_dilation) != 3: 77 | raise ValueError("replace_stride_with_dilation should be None " 78 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) 79 | self.groups = groups 80 | self.base_width = width_per_group 81 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False) 82 | self.bn1 = nn.BatchNorm2d(self.inplanes, eps=1e-05) 83 | self.prelu = 
nn.PReLU(self.inplanes) 84 | self.layer1 = self._make_layer(block, 64, layers[0], stride=2) 85 | self.layer2 = self._make_layer(block, 86 | 128, 87 | layers[1], 88 | stride=2, 89 | dilate=replace_stride_with_dilation[0]) 90 | self.layer3 = self._make_layer(block, 91 | 256, 92 | layers[2], 93 | stride=2, 94 | dilate=replace_stride_with_dilation[1]) 95 | self.layer4 = self._make_layer(block, 96 | 512, 97 | layers[3], 98 | stride=2, 99 | dilate=replace_stride_with_dilation[2]) 100 | self.bn2 = nn.BatchNorm2d(512 * block.expansion, eps=1e-05, ) 101 | self.dropout = nn.Dropout(p=dropout, inplace=True) 102 | self.fc = nn.Linear(512 * block.expansion * self.fc_scale, num_features) 103 | self.features = nn.BatchNorm1d(num_features, eps=1e-05) 104 | nn.init.constant_(self.features.weight, 1.0) 105 | self.features.weight.requires_grad = False 106 | 107 | for m in self.modules(): 108 | if isinstance(m, nn.Conv2d): 109 | nn.init.normal_(m.weight, 0, 0.1) 110 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 111 | nn.init.constant_(m.weight, 1) 112 | nn.init.constant_(m.bias, 0) 113 | 114 | if zero_init_residual: 115 | for m in self.modules(): 116 | if isinstance(m, IBasicBlock): 117 | nn.init.constant_(m.bn2.weight, 0) 118 | 119 | def _make_layer(self, block, planes, blocks, stride=1, dilate=False): 120 | downsample = None 121 | previous_dilation = self.dilation 122 | if dilate: 123 | self.dilation *= stride 124 | stride = 1 125 | if stride != 1 or self.inplanes != planes * block.expansion: 126 | downsample = nn.Sequential( 127 | conv1x1(self.inplanes, planes * block.expansion, stride), 128 | nn.BatchNorm2d(planes * block.expansion, eps=1e-05, ), 129 | ) 130 | layers = [] 131 | layers.append( 132 | block(self.inplanes, planes, stride, downsample, self.groups, 133 | self.base_width, previous_dilation)) 134 | self.inplanes = planes * block.expansion 135 | for _ in range(1, blocks): 136 | layers.append( 137 | block(self.inplanes, 138 | planes, 139 | groups=self.groups, 140 | base_width=self.base_width, 141 | dilation=self.dilation)) 142 | 143 | return nn.Sequential(*layers) 144 | 145 | def checkpoint(self, func, num_seg, x): 146 | if self.training: 147 | return checkpoint_sequential(func, num_seg, x) 148 | else: 149 | return func(x) 150 | 151 | def forward(self, x): 152 | with torch.cuda.amp.autocast(self.fp16): 153 | x = self.conv1(x) 154 | x = self.bn1(x) 155 | x = self.prelu(x) 156 | x = self.layer1(x) 157 | x = self.checkpoint(self.layer2, 20, x) 158 | x = self.checkpoint(self.layer3, 100, x) 159 | x = self.layer4(x) 160 | x = self.bn2(x) 161 | x = torch.flatten(x, 1) 162 | x = self.dropout(x) 163 | x = self.fc(x.float() if self.fp16 else x) 164 | x = self.features(x) 165 | return x 166 | 167 | 168 | def _iresnet(arch, block, layers, pretrained, progress, **kwargs): 169 | model = IResNet(block, layers, **kwargs) 170 | if pretrained: 171 | raise ValueError() 172 | return model 173 | 174 | 175 | def iresnet2060(pretrained=False, progress=True, **kwargs): 176 | return _iresnet('iresnet2060', IBasicBlock, [3, 128, 1024 - 128, 3], pretrained, progress, **kwargs) 177 | -------------------------------------------------------------------------------- /docs/static/js/bulma-slider.min.js: -------------------------------------------------------------------------------- 1 | !function (t, e) { "object" == typeof exports && "object" == typeof module ? module.exports = e() : "function" == typeof define && define.amd ? define([], e) : "object" == typeof exports ? 
exports.bulmaSlider = e() : t.bulmaSlider = e() }("undefined" != typeof self ? self : this, function () { return function (n) { var r = {}; function i(t) { if (r[t]) return r[t].exports; var e = r[t] = { i: t, l: !1, exports: {} }; return n[t].call(e.exports, e, e.exports, i), e.l = !0, e.exports } return i.m = n, i.c = r, i.d = function (t, e, n) { i.o(t, e) || Object.defineProperty(t, e, { configurable: !1, enumerable: !0, get: n }) }, i.n = function (t) { var e = t && t.__esModule ? function () { return t.default } : function () { return t }; return i.d(e, "a", e), e }, i.o = function (t, e) { return Object.prototype.hasOwnProperty.call(t, e) }, i.p = "", i(i.s = 0) }([function (t, e, n) { "use strict"; Object.defineProperty(e, "__esModule", { value: !0 }), n.d(e, "isString", function () { return l }); var r = n(1), i = Object.assign || function (t) { for (var e = 1; e < arguments.length; e++) { var n = arguments[e]; for (var r in n) Object.prototype.hasOwnProperty.call(n, r) && (t[r] = n[r]) } return t }, u = function () { function r(t, e) { for (var n = 0; n < e.length; n++) { var r = e[n]; r.enumerable = r.enumerable || !1, r.configurable = !0, "value" in r && (r.writable = !0), Object.defineProperty(t, r.key, r) } } return function (t, e, n) { return e && r(t.prototype, e), n && r(t, n), t } }(), o = "function" == typeof Symbol && "symbol" == typeof Symbol.iterator ? function (t) { return typeof t } : function (t) { return t && "function" == typeof Symbol && t.constructor === Symbol && t !== Symbol.prototype ? "symbol" : typeof t }; var l = function (t) { return "string" == typeof t || !!t && "object" === (void 0 === t ? "undefined" : o(t)) && "[object String]" === Object.prototype.toString.call(t) }, a = function (t) { function o(t) { var e = 1 < arguments.length && void 0 !== arguments[1] ? arguments[1] : {}; !function (t, e) { if (!(t instanceof e)) throw new TypeError("Cannot call a class as a function") }(this, o); var n = function (t, e) { if (!t) throw new ReferenceError("this hasn't been initialised - super() hasn't been called"); return !e || "object" != typeof e && "function" != typeof e ? t : e }(this, (o.__proto__ || Object.getPrototypeOf(o)).call(this)); if (n.element = "string" == typeof t ? document.querySelector(t) : t, !n.element) throw new Error("An invalid selector or non-DOM node has been provided."); return n._clickEvents = ["click"], n.options = i({}, e), n.onSliderInput = n.onSliderInput.bind(n), n.init(), n } return function (t, e) { if ("function" != typeof e && null !== e) throw new TypeError("Super expression must either be null or a function, not " + typeof e); t.prototype = Object.create(e && e.prototype, { constructor: { value: t, enumerable: !1, writable: !0, configurable: !0 } }), e && (Object.setPrototypeOf ? 
Object.setPrototypeOf(t, e) : t.__proto__ = e) }(o, r["a"]), u(o, [{ key: "init", value: function () { if (this._id = "bulmaSlider" + (new Date).getTime() + Math.floor(Math.random() * Math.floor(9999)), this.output = this._findOutputForSlider(), this._bindEvents(), this.output && this.element.classList.contains("has-output-tooltip")) { var t = this._getSliderOutputPosition(); this.output.style.left = t.position } this.emit("bulmaslider:ready", this.element.value) } }, { key: "_findOutputForSlider", value: function () { var e = this, n = null, t = document.getElementsByTagName("output") || []; return Array.from(t).forEach(function (t) { if (t.htmlFor == e.element.getAttribute("id")) return n = t, !0 }), n } }, { key: "_getSliderOutputPosition", value: function () { var t, e = window.getComputedStyle(this.element, null), n = parseInt(e.getPropertyValue("width"), 10); t = this.element.getAttribute("min") ? this.element.getAttribute("min") : 0; var r = (this.element.value - t) / (this.element.getAttribute("max") - t); return { position: (r < 0 ? 0 : 1 < r ? n : n * r) + "px" } } }, { key: "_bindEvents", value: function () { this.output && this.element.addEventListener("input", this.onSliderInput, !1) } }, { key: "onSliderInput", value: function (t) { if (t.preventDefault(), this.element.classList.contains("has-output-tooltip")) { var e = this._getSliderOutputPosition(); this.output.style.left = e.position } var n = this.output.hasAttribute("data-prefix") ? this.output.getAttribute("data-prefix") : "", r = this.output.hasAttribute("data-postfix") ? this.output.getAttribute("data-postfix") : ""; this.output.value = n + this.element.value + r, this.emit("bulmaslider:ready", this.element.value) } }], [{ key: "attach", value: function () { var n = this, t = 0 < arguments.length && void 0 !== arguments[0] ? arguments[0] : 'input[type="range"].slider', r = 1 < arguments.length && void 0 !== arguments[1] ? arguments[1] : {}, i = new Array; return (l(t) ? document.querySelectorAll(t) : Array.isArray(t) ? t : [t]).forEach(function (t) { if (void 0 === t[n.constructor.name]) { var e = new o(t, r); t[n.constructor.name] = e, i.push(e) } else i.push(t[n.constructor.name]) }), i } }]), o }(); e.default = a }, function (t, e, n) { "use strict"; var r = function () { function r(t, e) { for (var n = 0; n < e.length; n++) { var r = e[n]; r.enumerable = r.enumerable || !1, r.configurable = !0, "value" in r && (r.writable = !0), Object.defineProperty(t, r.key, r) } } return function (t, e, n) { return e && r(t.prototype, e), n && r(t, n), t } }(); var i = function () { function e() { var t = 0 < arguments.length && void 0 !== arguments[0] ? arguments[0] : []; !function (t, e) { if (!(t instanceof e)) throw new TypeError("Cannot call a class as a function") }(this, e), this._listeners = new Map(t), this._middlewares = new Map } return r(e, [{ key: "listenerCount", value: function (t) { return this._listeners.has(t) ? this._listeners.get(t).length : 0 } }, { key: "removeListeners", value: function () { var e = this, t = 0 < arguments.length && void 0 !== arguments[0] ? arguments[0] : null, n = 1 < arguments.length && void 0 !== arguments[1] && arguments[1]; null !== t ? Array.isArray(t) ? name.forEach(function (t) { return e.removeListeners(t, n) }) : (this._listeners.delete(t), n && this.removeMiddleware(t)) : this._listeners = new Map } }, { key: "middleware", value: function (t, e) { var n = this; Array.isArray(t) ? 
name.forEach(function (t) { return n.middleware(t, e) }) : (Array.isArray(this._middlewares.get(t)) || this._middlewares.set(t, []), this._middlewares.get(t).push(e)) } }, { key: "removeMiddleware", value: function () { var e = this, t = 0 < arguments.length && void 0 !== arguments[0] ? arguments[0] : null; null !== t ? Array.isArray(t) ? name.forEach(function (t) { return e.removeMiddleware(t) }) : this._middlewares.delete(t) : this._middlewares = new Map } }, { key: "on", value: function (t, e) { var n = this, r = 2 < arguments.length && void 0 !== arguments[2] && arguments[2]; if (Array.isArray(t)) t.forEach(function (t) { return n.on(t, e) }); else { var i = (t = t.toString()).split(/,|, | /); 1 < i.length ? i.forEach(function (t) { return n.on(t, e) }) : (Array.isArray(this._listeners.get(t)) || this._listeners.set(t, []), this._listeners.get(t).push({ once: r, callback: e })) } } }, { key: "once", value: function (t, e) { this.on(t, e, !0) } }, { key: "emit", value: function (n, r) { var i = this, o = 2 < arguments.length && void 0 !== arguments[2] && arguments[2]; n = n.toString(); var u = this._listeners.get(n), l = null, a = 0, s = o; if (Array.isArray(u)) for (u.forEach(function (t, e) { o || (l = i._middlewares.get(n), Array.isArray(l) ? (l.forEach(function (t) { t(r, function () { var t = 0 < arguments.length && void 0 !== arguments[0] ? arguments[0] : null; null !== t && (r = t), a++ }, n) }), a >= l.length && (s = !0)) : s = !0), s && (t.once && (u[e] = null), t.callback(r)) }); -1 !== u.indexOf(null);)u.splice(u.indexOf(null), 1) } }]), e }(); e.a = i }]).default }); -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: petalface 2 | channels: 3 | - pytorch 4 | - nvidia 5 | - conda-forge 6 | - defaults 7 | dependencies: 8 | - _libgcc_mutex=0.1=main 9 | - _openmp_mutex=5.1=1_gnu 10 | - _sysroot_linux-64_curr_repodata_hack=3=haa98f57_10 11 | - asttokens=2.2.1=pyhd8ed1ab_0 12 | - backcall=0.2.0=pyh9f0ad1d_0 13 | - backports=1.0=pyhd8ed1ab_3 14 | - backports.functools_lru_cache=1.6.5=pyhd8ed1ab_0 15 | - binutils_impl_linux-64=2.38=h2a08ee3_1 16 | - binutils_linux-64=2.38.0=hc2dff05_0 17 | - blas=1.0=mkl 18 | - brotli=1.0.9=h166bdaf_7 19 | - brotli-bin=1.0.9=h166bdaf_7 20 | - bzip2=1.0.8=h7b6447c_0 21 | - ca-certificates=2023.08.22=h06a4308_0 22 | - charset-normalizer=2.0.4=pyhd3eb1b0_0 23 | - cuda=11.7.1=0 24 | - cuda-cccl=11.7.91=0 25 | - cuda-command-line-tools=11.7.1=0 26 | - cuda-compiler=11.7.1=0 27 | - cuda-cudart=11.7.99=0 28 | - cuda-cudart-dev=11.7.99=0 29 | - cuda-cuobjdump=11.7.91=0 30 | - cuda-cupti=11.7.101=0 31 | - cuda-cuxxfilt=11.7.91=0 32 | - cuda-demo-suite=11.8.86=0 33 | - cuda-documentation=11.8.86=0 34 | - cuda-driver-dev=11.7.99=0 35 | - cuda-gdb=11.8.86=0 36 | - cuda-libraries=11.7.1=0 37 | - cuda-libraries-dev=11.7.1=0 38 | - cuda-memcheck=11.8.86=0 39 | - cuda-nsight=11.8.86=0 40 | - cuda-nsight-compute=11.8.0=0 41 | - cuda-nvcc=11.7.99=0 42 | - cuda-nvdisasm=11.8.86=0 43 | - cuda-nvml-dev=11.7.91=0 44 | - cuda-nvprof=11.8.87=0 45 | - cuda-nvprune=11.7.91=0 46 | - cuda-nvrtc=11.7.99=0 47 | - cuda-nvrtc-dev=11.7.99=0 48 | - cuda-nvtx=11.7.91=0 49 | - cuda-nvvp=11.8.87=0 50 | - cuda-runtime=11.7.1=0 51 | - cuda-sanitizer-api=11.8.86=0 52 | - cuda-toolkit=11.7.1=0 53 | - cuda-tools=11.7.1=0 54 | - cuda-visual-tools=11.7.1=0 55 | - cycler=0.11.0=pyhd8ed1ab_0 56 | - dbus=1.13.18=hb2f20db_0 57 | - decorator=5.1.1=pyhd8ed1ab_0 
58 | - entrypoints=0.4=pyhd8ed1ab_0 59 | - executing=1.2.0=pyhd8ed1ab_0 60 | - expat=2.2.10=h9c3ff4c_0 61 | - ffmpeg=4.3=hf484d3e_0 62 | - fontconfig=2.13.1=hef1e5e3_1 63 | - fonttools=4.25.0=pyhd3eb1b0_0 64 | - freetype=2.12.1=h4a9f257_0 65 | - gcc_impl_linux-64=11.2.0=h1234567_1 66 | - gcc_linux-64=11.2.0=h5c386dc_0 67 | - gds-tools=1.4.0.31=0 68 | - gflags=2.2.2=he1b5a44_1004 69 | - giflib=5.2.1=h7b6447c_0 70 | - glib=2.69.1=h4ff587b_1 71 | - gmp=6.2.1=h295c915_3 72 | - gnutls=3.6.15=he1e5248_0 73 | - gst-plugins-base=1.14.0=h8213a91_2 74 | - gstreamer=1.14.0=h28cd5cc_2 75 | - gxx_impl_linux-64=11.2.0=h1234567_1 76 | - gxx_linux-64=11.2.0=hc2dff05_0 77 | - icu=58.2=hf484d3e_1000 78 | - intel-openmp=2021.4.0=h06a4308_3561 79 | - ipykernel=6.15.0=pyh210e3f2_0 80 | - ipython=8.14.0=pyh41d4057_0 81 | - jedi=0.18.2=pyhd8ed1ab_0 82 | - jpeg=9e=h7f8727e_0 83 | - jupyter_client=7.3.4=pyhd8ed1ab_0 84 | - jupyter_core=4.12.0=py310hff52083_0 85 | - kernel-headers_linux-64=3.10.0=h57e8cba_10 86 | - keyutils=1.6.1=h166bdaf_0 87 | - krb5=1.19.3=h3790be6_0 88 | - lame=3.100=h7b6447c_0 89 | - lcms2=2.12=h3be6417_0 90 | - ld_impl_linux-64=2.38=h1181459_1 91 | - lerc=3.0=h295c915_0 92 | - libblas=3.9.0=12_linux64_mkl 93 | - libbrotlicommon=1.0.9=h166bdaf_7 94 | - libbrotlidec=1.0.9=h166bdaf_7 95 | - libbrotlienc=1.0.9=h166bdaf_7 96 | - libclang=10.0.1=default_hb85057a_2 97 | - libcublas=11.11.3.6=0 98 | - libcublas-dev=11.11.3.6=0 99 | - libcufft=10.9.0.58=0 100 | - libcufft-dev=10.9.0.58=0 101 | - libcufile=1.4.0.31=0 102 | - libcufile-dev=1.4.0.31=0 103 | - libcurand=10.3.0.86=0 104 | - libcurand-dev=10.3.0.86=0 105 | - libcusolver=11.4.1.48=0 106 | - libcusolver-dev=11.4.1.48=0 107 | - libcusparse=11.7.5.86=0 108 | - libcusparse-dev=11.7.5.86=0 109 | - libdeflate=1.8=h7f8727e_5 110 | - libedit=3.1.20191231=he28a2e2_2 111 | - libevent=2.1.12=h8f2d780_0 112 | - libfaiss=1.7.2=h5bb0626_0_cpu 113 | - libfaiss-avx2=1.7.2=h1234567_0_cpu 114 | - libffi=3.3=he6710b0_2 115 | - libgcc-devel_linux-64=11.2.0=h1234567_1 116 | - libgcc-ng=11.2.0=h1234567_1 117 | - libgfortran-ng=12.2.0=h69a702a_19 118 | - libgfortran5=12.2.0=h337968e_19 119 | - libgomp=11.2.0=h1234567_1 120 | - libiconv=1.16=h7f8727e_2 121 | - libidn2=2.3.2=h7f8727e_0 122 | - liblapack=3.9.0=12_linux64_mkl 123 | - libllvm10=10.0.1=he513fc3_3 124 | - libnpp=11.8.0.86=0 125 | - libnpp-dev=11.8.0.86=0 126 | - libnvjpeg=11.9.0.86=0 127 | - libnvjpeg-dev=11.9.0.86=0 128 | - libpng=1.6.37=hbc83047_0 129 | - libpq=12.9=h16c4e8d_3 130 | - libsodium=1.0.18=h36c2ea0_1 131 | - libstdcxx-devel_linux-64=11.2.0=h1234567_1 132 | - libstdcxx-ng=11.2.0=h1234567_1 133 | - libtasn1=4.16.0=h27cfd23_0 134 | - libtiff=4.4.0=hecacb30_2 135 | - libunistring=0.9.10=h27cfd23_0 136 | - libuuid=1.41.5=h5eee18b_0 137 | - libwebp=1.2.4=h11a3e52_0 138 | - libwebp-base=1.2.4=h5eee18b_0 139 | - libxcb=1.15=h7f8727e_0 140 | - libxkbcommon=1.0.3=he3ba5ed_0 141 | - libxml2=2.9.14=h74e7548_0 142 | - libxslt=1.1.35=h4e12654_0 143 | - lz4-c=1.9.3=h295c915_1 144 | - matplotlib-base=3.5.3=py310hf590b9c_0 145 | - matplotlib-inline=0.1.6=pyhd8ed1ab_0 146 | - mkl=2021.4.0=h06a4308_640 147 | - mkl_fft=1.3.1=py310hd6ae3a3_0 148 | - mkl_random=1.2.2=py310h00e6091_0 149 | - mpc=1.1.0=h10f8cd9_1 150 | - mpfr=4.0.2=hb69a4c5_1 151 | - mpi=1.0=mpich 152 | - mpich=4.0.2=h846660c_100 153 | - munkres=1.1.4=pyh9f0ad1d_0 154 | - ncurses=6.3=h5eee18b_3 155 | - nest-asyncio=1.5.6=pyhd8ed1ab_0 156 | - nettle=3.7.3=hbbd107a_1 157 | - nsight-compute=2022.3.0.22=0 158 | - nspr=4.33=h295c915_0 159 | - 
nss=3.74=h0370c37_0 160 | - numpy-base=1.23.4=py310h8e6c178_0 161 | - openh264=2.1.1=h4ff587b_0 162 | - openssl=1.1.1w=h7f8727e_0 163 | - packaging=21.3=pyhd8ed1ab_0 164 | - parso=0.8.3=pyhd8ed1ab_0 165 | - pcre=8.45=h9c3ff4c_0 166 | - pexpect=4.8.0=pyh1a96a4e_2 167 | - pickleshare=0.7.5=py_1003 168 | - pip=23.3.1=pyhd8ed1ab_0 169 | - ply=3.11=py_1 170 | - prompt-toolkit=3.0.39=pyha770c72_0 171 | - prompt_toolkit=3.0.39=hd8ed1ab_0 172 | - ptyprocess=0.7.0=pyhd3deb0d_0 173 | - pure_eval=0.2.2=pyhd8ed1ab_0 174 | - pycparser=2.21=pyhd3eb1b0_0 175 | - pyopenssl=22.0.0=pyhd3eb1b0_0 176 | - pyparsing=3.0.9=pyhd8ed1ab_0 177 | - pyqt=5.15.7=py310h6a678d5_1 178 | - python=3.10.8=haa1d7c7_0 179 | - python-dateutil=2.8.2=pyhd8ed1ab_0 180 | - python_abi=3.10=2_cp310 181 | - pytorch=2.0.1=py3.10_cuda11.7_cudnn8.5.0_0 182 | - pytorch-cuda=11.7=h67b0de4_0 183 | - pytorch-mutex=1.0=cuda 184 | - qt-main=5.15.2=h327a75a_7 185 | - qt-webengine=5.15.9=hd2b0992_4 186 | - qtwebkit=5.212=h4eab89a_4 187 | - readline=8.2=h5eee18b_0 188 | - six=1.16.0=pyhd3eb1b0_1 189 | - sqlite=3.39.3=h5082296_0 190 | - stack_data=0.6.2=pyhd8ed1ab_0 191 | - sysroot_linux-64=2.17=h57e8cba_10 192 | - tk=8.6.12=h1ccaba5_0 193 | - toml=0.10.2=pyhd8ed1ab_0 194 | - torchtriton=2.0.0=py310 195 | - traitlets=5.9.0=pyhd8ed1ab_0 196 | - typing_extensions=4.3.0=py310h06a4308_0 197 | - wcwidth=0.2.6=pyhd8ed1ab_0 198 | - xz=5.2.6=h5eee18b_0 199 | - zeromq=4.3.4=h9c3ff4c_1 200 | - zlib=1.2.13=h5eee18b_0 201 | - zstd=1.5.2=ha4553b6_0 202 | - pip: 203 | - absl-py==1.3.0 204 | - accelerate==0.20.3 205 | - appdirs==1.4.4 206 | - brotlipy==0.7.0 207 | - cffi==1.15.1 208 | - click==8.1.7 209 | - cmake==3.27.7 210 | - cryptography==38.0.1 211 | - debugpy==1.5.1 212 | - docker-pycreds==0.4.0 213 | - easydict==1.11 214 | - filelock==3.13.1 215 | - gitdb==4.0.11 216 | - gitpython==3.1.40 217 | - gmpy2==2.1.2 218 | - google-auth-oauthlib==0.4.6 219 | - idna==3.4 220 | - imageio==2.33.0 221 | - jinja2==3.1.2 222 | - joblib==1.3.2 223 | - jupyter-core==4.12.0 224 | - kiwisolver==1.4.2 225 | - lazy-loader==0.3 226 | - lit==17.0.6 227 | - markdown==3.5.1 228 | - markupsafe==2.1.1 229 | - menpo==0.11.0 230 | - mkl-fft==1.3.1 231 | - mkl-random==1.2.2 232 | - mkl-service==2.4.0 233 | - mpmath==1.3.0 234 | - mxnet==1.6.0 235 | - networkx==3.1 236 | - oauthlib==3.2.2 237 | - opencv-python==4.8.1.78 238 | - pandas==2.1.3 239 | - pillow==9.2.0 240 | - prettytable==3.9.0 241 | - protobuf==3.20.3 242 | - psutil==5.9.0 243 | - pyasn1==0.5.1 244 | - pyasn1-modules==0.3.0 245 | - pygments==2.15.1 246 | - pyqt5-sip==12.11.0 247 | - pysocks==1.7.1 248 | - pytz==2023.3.post1 249 | - pyyaml==6.0.1 250 | - pyzmq==25.1.0 251 | - requests==2.28.1 252 | - scikit-image==0.22.0 253 | - scikit-learn==1.3.2 254 | - scipy==1.8.1 255 | - sentry-sdk==1.39.1 256 | - setproctitle==1.3.3 257 | - setuptools==65.5.0 258 | - sip==6.6.2 259 | - smmap==5.0.1 260 | - sympy==1.12 261 | - tensorboard==2.11.0 262 | - tensorboard-data-server==0.6.1 263 | - tensorboard-plugin-wit==1.8.1 264 | - threadpoolctl==3.2.0 265 | - tifffile==2023.9.26 266 | - --extra-index-url https://download.pytorch.org/whl/cu117 267 | - torchvision==0.15.2+cu117 268 | - torchaudio==2.0.2+cu117 269 | - tornado==6.1 270 | - tqdm==4.66.1 271 | - typing-extensions==4.3.0 272 | - tzdata==2023.3 273 | - urllib3==1.26.12 274 | - wandb==0.16.1 275 | - wheel==0.41.2 276 | - timm 277 | - brisque 278 | - pyiqa 279 | -------------------------------------------------------------------------------- /heads/partial_fc.py: 
-------------------------------------------------------------------------------- 1 | 2 | import math 3 | from typing import Callable 4 | 5 | import torch 6 | from torch import distributed 7 | from torch.nn.functional import linear, normalize 8 | 9 | 10 | class PartialFC_V2(torch.nn.Module): 11 | """ 12 | https://arxiv.org/abs/2203.15565 13 | A distributed sparsely updating variant of the FC layer, named Partial FC (PFC). 14 | When sample rate less than 1, in each iteration, positive class centers and a random subset of 15 | negative class centers are selected to compute the margin-based softmax loss, all class 16 | centers are still maintained throughout the whole training process, but only a subset is 17 | selected and updated in each iteration. 18 | .. note:: 19 | When sample rate equal to 1, Partial FC is equal to model parallelism(default sample rate is 1). 20 | Example: 21 | -------- 22 | >>> module_pfc = PartialFC(embedding_size=512, num_classes=8000000, sample_rate=0.2) 23 | >>> for img, labels in data_loader: 24 | >>> embeddings = net(img) 25 | >>> loss = module_pfc(embeddings, labels) 26 | >>> loss.backward() 27 | >>> optimizer.step() 28 | """ 29 | _version = 2 30 | 31 | def __init__( 32 | self, 33 | margin_loss: Callable, 34 | embedding_size: int, 35 | num_classes: int, 36 | sample_rate: float = 1.0, 37 | fp16: bool = False, 38 | ): 39 | """ 40 | Paramenters: 41 | ----------- 42 | embedding_size: int 43 | The dimension of embedding, required 44 | num_classes: int 45 | Total number of classes, required 46 | sample_rate: float 47 | The rate of negative centers participating in the calculation, default is 1.0. 48 | """ 49 | super(PartialFC_V2, self).__init__() 50 | assert ( 51 | distributed.is_initialized() 52 | ), "must initialize distributed before create this" 53 | self.rank = distributed.get_rank() 54 | self.world_size = distributed.get_world_size() 55 | 56 | self.dist_cross_entropy = DistCrossEntropy() 57 | self.embedding_size = embedding_size 58 | self.sample_rate: float = sample_rate 59 | self.fp16 = fp16 60 | self.num_local: int = num_classes // self.world_size + int( 61 | self.rank < num_classes % self.world_size 62 | ) 63 | self.class_start: int = num_classes // self.world_size * self.rank + min( 64 | self.rank, num_classes % self.world_size 65 | ) 66 | self.num_sample: int = int(self.sample_rate * self.num_local) 67 | self.last_batch_size: int = 0 68 | 69 | self.is_updated: bool = True 70 | self.init_weight_update: bool = True 71 | self.weight = torch.nn.Parameter(torch.normal(0, 0.01, (self.num_local, embedding_size))) 72 | 73 | # margin_loss 74 | if isinstance(margin_loss, Callable): 75 | self.margin_softmax = margin_loss 76 | else: 77 | raise 78 | 79 | def sample(self, labels, index_positive): 80 | """ 81 | This functions will change the value of labels 82 | Parameters: 83 | ----------- 84 | labels: torch.Tensor 85 | pass 86 | index_positive: torch.Tensor 87 | pass 88 | optimizer: torch.optim.Optimizer 89 | pass 90 | """ 91 | with torch.no_grad(): 92 | positive = torch.unique(labels[index_positive], sorted=True).cuda() 93 | if self.num_sample - positive.size(0) >= 0: 94 | perm = torch.rand(size=[self.num_local]).cuda() 95 | perm[positive] = 2.0 96 | index = torch.topk(perm, k=self.num_sample)[1].cuda() 97 | index = index.sort()[0].cuda() 98 | else: 99 | index = positive 100 | self.weight_index = index 101 | 102 | labels[index_positive] = torch.searchsorted(index, labels[index_positive]) 103 | 104 | return self.weight[self.weight_index] 105 | 106 | def forward( 107 | 
self, 108 | local_embeddings: torch.Tensor, 109 | local_labels: torch.Tensor, 110 | ): 111 | """ 112 | Parameters: 113 | ---------- 114 | local_embeddings: torch.Tensor 115 | feature embeddings on each GPU(Rank). 116 | local_labels: torch.Tensor 117 | labels on each GPU(Rank). 118 | Returns: 119 | ------- 120 | loss: torch.Tensor 121 | pass 122 | """ 123 | local_labels.squeeze_() 124 | local_labels = local_labels.long() 125 | 126 | batch_size = local_embeddings.size(0) 127 | if self.last_batch_size == 0: 128 | self.last_batch_size = batch_size 129 | assert self.last_batch_size == batch_size, ( 130 | f"last batch size do not equal current batch size: {self.last_batch_size} vs {batch_size}") 131 | 132 | _gather_embeddings = [ 133 | torch.zeros((batch_size, self.embedding_size)).cuda() 134 | for _ in range(self.world_size) 135 | ] 136 | _gather_labels = [ 137 | torch.zeros(batch_size).long().cuda() for _ in range(self.world_size) 138 | ] 139 | _list_embeddings = AllGather(local_embeddings, *_gather_embeddings) 140 | distributed.all_gather(_gather_labels, local_labels) 141 | 142 | embeddings = torch.cat(_list_embeddings) 143 | labels = torch.cat(_gather_labels) 144 | 145 | labels = labels.view(-1, 1) 146 | index_positive = (self.class_start <= labels) & ( 147 | labels < self.class_start + self.num_local 148 | ) 149 | labels[~index_positive] = -1 150 | labels[index_positive] -= self.class_start 151 | 152 | if self.sample_rate < 1: 153 | weight = self.sample(labels, index_positive) 154 | else: 155 | weight = self.weight 156 | 157 | with torch.cuda.amp.autocast(self.fp16): 158 | norm_embeddings = normalize(embeddings) 159 | norm_weight_activated = normalize(weight) 160 | logits = linear(norm_embeddings, norm_weight_activated) 161 | if self.fp16: 162 | logits = logits.float() 163 | logits = logits.clamp(-1, 1) 164 | 165 | logits = self.margin_softmax(logits, labels) 166 | loss = self.dist_cross_entropy(logits, labels) 167 | return loss 168 | 169 | 170 | class DistCrossEntropyFunc(torch.autograd.Function): 171 | """ 172 | CrossEntropy loss is calculated in parallel, allreduce denominator into single gpu and calculate softmax. 
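Each rank only holds the logits of its local class shard, so the softmax denominator is assembled
collectively: the per-rank maximum is all-reduced (MAX) for numerical stability, the per-rank sum of
exponentials is all-reduced (SUM), and only then is the local slice normalized. No rank ever
materializes the full num_classes-wide logit matrix.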
173 | Implemented of ArcFace (https://arxiv.org/pdf/1801.07698v1.pdf): 174 | """ 175 | 176 | @staticmethod 177 | def forward(ctx, logits: torch.Tensor, label: torch.Tensor): 178 | """ """ 179 | batch_size = logits.size(0) 180 | # for numerical stability 181 | max_logits, _ = torch.max(logits, dim=1, keepdim=True) 182 | # local to global 183 | distributed.all_reduce(max_logits, distributed.ReduceOp.MAX) 184 | logits.sub_(max_logits) 185 | logits.exp_() 186 | sum_logits_exp = torch.sum(logits, dim=1, keepdim=True) 187 | # local to global 188 | distributed.all_reduce(sum_logits_exp, distributed.ReduceOp.SUM) 189 | logits.div_(sum_logits_exp) 190 | index = torch.where(label != -1)[0] 191 | # loss 192 | loss = torch.zeros(batch_size, 1, device=logits.device) 193 | loss[index] = logits[index].gather(1, label[index]) 194 | distributed.all_reduce(loss, distributed.ReduceOp.SUM) 195 | ctx.save_for_backward(index, logits, label) 196 | return loss.clamp_min_(1e-30).log_().mean() * (-1) 197 | 198 | @staticmethod 199 | def backward(ctx, loss_gradient): 200 | """ 201 | Args: 202 | loss_grad (torch.Tensor): gradient backward by last layer 203 | Returns: 204 | gradients for each input in forward function 205 | `None` gradients for one-hot label 206 | """ 207 | ( 208 | index, 209 | logits, 210 | label, 211 | ) = ctx.saved_tensors 212 | batch_size = logits.size(0) 213 | one_hot = torch.zeros( 214 | size=[index.size(0), logits.size(1)], device=logits.device 215 | ) 216 | one_hot.scatter_(1, label[index], 1) 217 | logits[index] -= one_hot 218 | logits.div_(batch_size) 219 | return logits * loss_gradient.item(), None 220 | 221 | 222 | class DistCrossEntropy(torch.nn.Module): 223 | def __init__(self): 224 | super(DistCrossEntropy, self).__init__() 225 | 226 | def forward(self, logit_part, label_part): 227 | return DistCrossEntropyFunc.apply(logit_part, label_part) 228 | 229 | 230 | class AllGatherFunc(torch.autograd.Function): 231 | """AllGather op with gradient backward""" 232 | 233 | @staticmethod 234 | def forward(ctx, tensor, *gather_list): 235 | gather_list = list(gather_list) 236 | distributed.all_gather(gather_list, tensor) 237 | return tuple(gather_list) 238 | 239 | @staticmethod 240 | def backward(ctx, *grads): 241 | grad_list = list(grads) 242 | rank = distributed.get_rank() 243 | grad_out = grad_list[rank] 244 | 245 | dist_ops = [ 246 | distributed.reduce(grad_out, rank, distributed.ReduceOp.SUM, async_op=True) 247 | if i == rank 248 | else distributed.reduce( 249 | grad_list[i], i, distributed.ReduceOp.SUM, async_op=True 250 | ) 251 | for i in range(distributed.get_world_size()) 252 | ] 253 | for _op in dist_ops: 254 | _op.wait() 255 | 256 | grad_out *= len(grad_list) # cooperate with distributed loss function 257 | return (grad_out, *[None for _ in range(len(grad_list))]) 258 | 259 | 260 | AllGather = AllGatherFunc.apply -------------------------------------------------------------------------------- /validation_lq/tinyface_helper.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import scipy.io as sio 3 | import os 4 | import scipy 5 | 6 | def get_all_files(root, extension_list=['.jpg', '.png', '.jpeg']): 7 | 8 | all_files = list() 9 | for (dirpath, dirnames, filenames) in os.walk(root): 10 | all_files += [os.path.join(dirpath, file) for file in filenames] 11 | if extension_list is None: 12 | return all_files 13 | all_files = list(filter(lambda x: os.path.splitext(x)[1] in extension_list, all_files)) 14 | return all_files 15 | 16 | 17 | 
class TinyFaceTest: 18 | def __init__(self, tinyface_root, alignment_dir_name): 19 | 20 | self.tinyface_root = tinyface_root 21 | # as defined by tinyface protocol 22 | self.gallery_dict = scipy.io.loadmat(os.path.join(tinyface_root, 'tinyface/Testing_Set/gallery_match_img_ID_pairs.mat')) 23 | self.probe_dict = scipy.io.loadmat(os.path.join(tinyface_root, 'tinyface/Testing_Set/probe_img_ID_pairs.mat')) 24 | self.proto_gal_paths = [os.path.join(tinyface_root, alignment_dir_name, 'Gallery_Match', p[0].item()) for p in self.gallery_dict['gallery_set']] 25 | self.proto_prob_paths = [os.path.join(tinyface_root, alignment_dir_name, 'Probe', p[0].item()) for p in self.probe_dict['probe_set']] 26 | self.proto_distractor_paths = get_all_files(os.path.join(tinyface_root, alignment_dir_name, 'Gallery_Distractor')) 27 | 28 | self.image_paths = get_all_files(os.path.join(tinyface_root, alignment_dir_name)) 29 | self.image_paths = np.array(self.image_paths).astype(np.object).flatten() 30 | 31 | self.probe_paths = get_all_files(os.path.join(tinyface_root, 'tinyface/Testing_Set/Probe')) 32 | self.probe_paths = np.array(self.probe_paths).astype(np.object).flatten() 33 | 34 | self.gallery_paths = get_all_files(os.path.join(tinyface_root, 'tinyface/Testing_Set/Gallery_Match')) 35 | self.gallery_paths = np.array(self.gallery_paths).astype(np.object).flatten() 36 | 37 | self.distractor_paths = get_all_files(os.path.join(tinyface_root, 'tinyface/Testing_Set/Gallery_Distractor')) 38 | self.distractor_paths = np.array(self.distractor_paths).astype(np.object).flatten() 39 | 40 | self.init_proto(self.probe_paths, self.gallery_paths, self.distractor_paths) 41 | 42 | def get_key(self, image_path): 43 | return os.path.splitext(os.path.basename(image_path))[0] 44 | 45 | def get_label(self, image_path): 46 | return int(os.path.basename(image_path).split('_')[0]) 47 | 48 | def init_proto(self, probe_paths, match_paths, distractor_paths): 49 | index_dict = {} 50 | for i, image_path in enumerate(self.image_paths): 51 | index_dict[self.get_key(image_path)] = i 52 | 53 | self.indices_probe = np.array([index_dict[self.get_key(img)] for img in probe_paths]) 54 | self.indices_match = np.array([index_dict[self.get_key(img)] for img in match_paths]) 55 | self.indices_distractor = np.array([index_dict[self.get_key(img)] for img in distractor_paths]) 56 | 57 | self.labels_probe = np.array([self.get_label(img) for img in probe_paths]) 58 | self.labels_match = np.array([self.get_label(img) for img in match_paths]) 59 | self.labels_distractor = np.array([-100 for img in distractor_paths]) 60 | 61 | self.indices_gallery = np.concatenate([self.indices_match, self.indices_distractor]) 62 | self.labels_gallery = np.concatenate([self.labels_match, self.labels_distractor]) 63 | 64 | 65 | def test_identification(self, features, ranks=[1,5,20]): 66 | feat_probe = features[self.indices_probe] 67 | feat_gallery = features[self.indices_gallery] 68 | compare_func = inner_product 69 | score_mat = compare_func(feat_probe, feat_gallery) 70 | 71 | label_mat = self.labels_probe[:,None] == self.labels_gallery[None,:] 72 | 73 | results, _, __ = DIR_FAR(score_mat, label_mat, ranks) 74 | 75 | return results 76 | 77 | def inner_product(x1, x2): 78 | x1, x2 = np.array(x1), np.array(x2) 79 | if x1.ndim == 3: 80 | raise ValueError('why?') 81 | x1, x2 = x1[:,:,0], x2[:,:,0] 82 | return np.dot(x1, x2.T) 83 | 84 | 85 | 86 | def DIR_FAR(score_mat, label_mat, ranks=[1], FARs=[1.0], get_false_indices=False): 87 | ''' 88 | Code borrowed from 
https://github.com/seasonSH/Probabilistic-Face-Embeddings 89 | 90 | Closed/Open-set Identification. 91 | A general case of Cummulative Match Characteristic (CMC) 92 | where thresholding is allowed for open-set identification. 93 | args: 94 | score_mat: a P x G matrix, P is number of probes, G is size of gallery 95 | label_mat: a P x G matrix, bool 96 | ranks: a list of integers 97 | FARs: false alarm rates, if 1.0, closed-set identification (CMC) 98 | get_false_indices: not implemented yet 99 | return: 100 | DIRs: an F x R matrix, F is the number of FARs, R is the number of ranks, 101 | flatten into a vector if F=1 or R=1. 102 | FARs: an vector of length = F. 103 | thredholds: an vector of length = F. 104 | ''' 105 | assert score_mat.shape==label_mat.shape 106 | # assert np.all(label_mat.astype(np.float32).sum(axis=1) <=1 ) 107 | # Split the matrix for match probes and non-match probes 108 | # subfix _m: match, _nm: non-match 109 | # For closed set, we only use the match probes 110 | match_indices = label_mat.astype(np.bool).any(axis=1) 111 | score_mat_m = score_mat[match_indices,:] 112 | label_mat_m = label_mat[match_indices,:] 113 | score_mat_nm = score_mat[np.logical_not(match_indices),:] 114 | label_mat_nm = label_mat[np.logical_not(match_indices),:] 115 | 116 | print('mate probes: %d, non mate probes: %d' % (score_mat_m.shape[0], score_mat_nm.shape[0])) 117 | 118 | # Find the thresholds for different FARs 119 | max_score_nm = np.max(score_mat_nm, axis=1) 120 | label_temp = np.zeros(max_score_nm.shape, dtype=np.bool) 121 | if len(FARs) == 1 and FARs[0] >= 1.0: 122 | # If only testing closed-set identification, use the minimum score as threshold 123 | # in case there is no non-mate probes 124 | thresholds = [np.min(score_mat) - 1e-10] 125 | openset = False 126 | else: 127 | # If there is open-set identification, find the thresholds by FARs. 128 | assert score_mat_nm.shape[0] > 0, "For open-set identification (FAR<1.0), there should be at least one non-mate probe!" 
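        # How the open-set thresholds are chosen (editor's note, derived from the code below):
        # for each requested FAR, find_thresholds_by_FAR sorts the best non-mate scores
        # (max over the gallery for each non-mate probe) in descending order and picks the
        # score at position round(num_non_mate * FAR) as the threshold. A mate probe then
        # counts toward DIR at rank r only if its true-match score clears that threshold AND
        # the true match appears in the top-r retrievals; FAR is re-measured afterwards as the
        # fraction of non-mate probes whose best score still exceeds the threshold.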
129 | thresholds = find_thresholds_by_FAR(max_score_nm, label_temp, FARs=FARs) 130 | openset = True 131 | 132 | # Sort the labels row by row according to scores 133 | sort_idx_mat_m = np.argsort(score_mat_m, axis=1) 134 | sorted_label_mat_m = np.ndarray(label_mat_m.shape, dtype=np.bool) 135 | for row in range(label_mat_m.shape[0]): 136 | sort_idx = (sort_idx_mat_m[row, :])[::-1] 137 | sorted_label_mat_m[row,:] = label_mat_m[row, sort_idx] 138 | 139 | # Calculate DIRs for different FARs and ranks 140 | if openset: 141 | gt_score_m = score_mat_m[label_mat_m] 142 | assert gt_score_m.size == score_mat_m.shape[0] 143 | 144 | DIRs = np.zeros([len(FARs), len(ranks)], dtype=np.float32) 145 | FARs = np.zeros([len(FARs)], dtype=np.float32) 146 | if get_false_indices: 147 | false_retrieval = np.zeros([len(FARs), len(ranks), score_mat_m.shape[0]], dtype=np.bool) 148 | false_reject = np.zeros([len(FARs), len(ranks), score_mat_m.shape[0]], dtype=np.bool) 149 | false_accept = np.zeros([len(FARs), len(ranks), score_mat_nm.shape[0]], dtype=np.bool) 150 | for i, threshold in enumerate(thresholds): 151 | for j, rank in enumerate(ranks): 152 | success_retrieval = sorted_label_mat_m[:,0:rank].any(axis=1) 153 | if openset: 154 | success_threshold = gt_score_m >= threshold 155 | DIRs[i,j] = (success_threshold & success_retrieval).astype(np.float32).mean() 156 | else: 157 | DIRs[i,j] = success_retrieval.astype(np.float32).mean() 158 | if get_false_indices: 159 | false_retrieval[i,j] = ~success_retrieval 160 | false_accept[i,j] = score_mat_nm.max(1) >= threshold 161 | if openset: 162 | false_reject[i,j] = ~success_threshold 163 | if score_mat_nm.shape[0] > 0: 164 | FARs[i] = (max_score_nm >= threshold).astype(np.float32).mean() 165 | 166 | if DIRs.shape[0] == 1 or DIRs.shape[1] == 1: 167 | DIRs = DIRs.flatten() 168 | 169 | if get_false_indices: 170 | return DIRs, FARs, thresholds, match_indices, false_retrieval, false_reject, false_accept, sort_idx_mat_m 171 | else: 172 | return DIRs, FARs, thresholds 173 | 174 | 175 | # Find thresholds given FARs 176 | # but the real FARs using these thresholds could be different 177 | # the exact FARs need to recomputed using calcROC 178 | def find_thresholds_by_FAR(score_vec, label_vec, FARs=None, epsilon=1e-5): 179 | # Code borrowed from https://github.com/seasonSH/Probabilistic-Face-Embeddings 180 | 181 | assert len(score_vec.shape)==1 182 | assert score_vec.shape == label_vec.shape 183 | assert label_vec.dtype == np.bool 184 | score_neg = score_vec[~label_vec] 185 | score_neg[::-1].sort() 186 | # score_neg = np.sort(score_neg)[::-1] # score from high to low 187 | num_neg = len(score_neg) 188 | 189 | assert num_neg >= 1 190 | 191 | if FARs is None: 192 | thresholds = np.unique(score_neg) 193 | thresholds = np.insert(thresholds, 0, thresholds[0]+epsilon) 194 | thresholds = np.insert(thresholds, thresholds.size, thresholds[-1]-epsilon) 195 | else: 196 | FARs = np.array(FARs) 197 | num_false_alarms = np.round(num_neg * FARs).astype(np.int32) 198 | 199 | thresholds = [] 200 | for num_false_alarm in num_false_alarms: 201 | if num_false_alarm==0: 202 | threshold = score_neg[0] + epsilon 203 | else: 204 | threshold = score_neg[num_false_alarm-1] 205 | thresholds.append(threshold) 206 | thresholds = np.array(thresholds) 207 | 208 | return thresholds -------------------------------------------------------------------------------- /validation_lq/validate_tinyface.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 
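# TinyFace evaluation in a nutshell (editor's note): tinyface_helper.TinyFaceTest builds a
# probe set and a gallery made of the mated gallery images plus the Gallery_Distractor
# images (labelled -100 so they can never match a probe). The features extracted below are
# compared by inner product and scored with DIR_FAR using FARs=[1.0], i.e. closed-set
# identification, reported at rank 1 / 5 / 20.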
3 | from tqdm import tqdm 4 | import argparse 5 | import pandas as pd 6 | import tinyface_helper 7 | import sys, os 8 | sys.path.insert(0, os.getcwd()) 9 | from backbones import get_model 10 | from torch.utils.data import Dataset, DataLoader 11 | from torchvision import transforms 12 | from PIL import Image 13 | import cv2 14 | from torchvision.transforms import InterpolationMode 15 | 16 | 17 | def str2bool(v): 18 | if v.lower() in ('yes', 'true', 't', 'y', '1'): 19 | return True 20 | elif v.lower() in ('no', 'false', 'f', 'n', '0'): 21 | return False 22 | else: 23 | raise argparse.ArgumentTypeError('Boolean value expected.') 24 | 25 | def l2_norm(input, axis=1): 26 | """l2 normalize 27 | """ 28 | norm = torch.norm(input, 2, axis, True) 29 | output = torch.div(input, norm) 30 | return output, norm 31 | 32 | 33 | def fuse_features_with_norm(stacked_embeddings, stacked_norms, fusion_method='norm_weighted_avg'): 34 | # print(stacked_norms) 35 | assert stacked_embeddings.ndim == 3 # (n_features_to_fuse, batch_size, channel) 36 | if stacked_norms is not None: 37 | assert stacked_norms.ndim == 3 # (n_features_to_fuse, batch_size, 1) 38 | else: 39 | assert fusion_method not in ['norm_weighted_avg', 'pre_norm_vector_add'] 40 | 41 | if fusion_method == 'norm_weighted_avg': 42 | weights = stacked_norms / stacked_norms.sum(dim=0, keepdim=True) 43 | fused = (stacked_embeddings * weights).sum(dim=0) 44 | fused, _ = l2_norm(fused, axis=1) 45 | fused_norm = stacked_norms.mean(dim=0) 46 | elif fusion_method == 'pre_norm_vector_add': 47 | pre_norm_embeddings = stacked_embeddings * stacked_norms 48 | fused = pre_norm_embeddings.sum(dim=0) 49 | fused, fused_norm = l2_norm(fused, axis=1) 50 | elif fusion_method == 'average': 51 | fused = stacked_embeddings.sum(dim=0) 52 | fused, _ = l2_norm(fused, axis=1) 53 | if stacked_norms is None: 54 | fused_norm = torch.ones((len(fused), 1)) 55 | else: 56 | fused_norm = stacked_norms.mean(dim=0) 57 | elif fusion_method == 'concat': 58 | fused = torch.cat([stacked_embeddings[0], stacked_embeddings[1]], dim=-1) 59 | if stacked_norms is None: 60 | fused_norm = torch.ones((len(fused), 1)) 61 | else: 62 | fused_norm = stacked_norms.mean(dim=0) 63 | elif fusion_method == 'faceness_score': 64 | raise ValueError('not implemented yet. please refer to https://github.com/deepinsight/insightface/blob/5d3be6da49275602101ad122601b761e36a66a01/recognition/_evaluation_/ijb/ijb_11.py#L296') 65 | # note that they do not use normalization afterward. 
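        # Worked example for the two fusion modes used in this repo (editor's note,
        # illustrative numbers only): with flip-test features f1, f2 (already L2-normalized)
        # and norms n1 = 2.0, n2 = 1.0,
        #   norm_weighted_avg:    fused = l2_norm((2.0 * f1 + 1.0 * f2) / 3.0)
        #   pre_norm_vector_add:  fused = l2_norm(2.0 * f1 + 1.0 * f2)
        # Both yield the same fused direction; they differ only in the norm that is returned
        # (mean of the input norms vs. norm of the summed vector).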
66 | else: 67 | raise ValueError('not a correct fusion method', fusion_method) 68 | 69 | return fused, fused_norm 70 | 71 | 72 | def infer(model, dataloader, use_flip_test, fusion_method, gpu): 73 | model.eval() 74 | features = [] 75 | norms = [] 76 | device = "cuda:" + str(gpu) 77 | with torch.no_grad(): 78 | for images, idx in tqdm(dataloader): 79 | 80 | feature = model(images.to(device)) 81 | if isinstance(feature, tuple): 82 | feature, norm = feature 83 | else: 84 | norm = torch.norm(feature, 2, 1, True) 85 | feature = torch.div(feature, norm) 86 | 87 | if use_flip_test: 88 | fliped_images = torch.flip(images, dims=[3]) 89 | flipped_feature = model(fliped_images.to(device)) 90 | if isinstance(flipped_feature, tuple): 91 | flipped_feature, flipped_norm = flipped_feature 92 | else: 93 | flipped_norm = torch.norm(flipped_feature, 2, 1, True) 94 | flipped_feature = torch.div(flipped_feature, flipped_norm) 95 | 96 | stacked_embeddings = torch.stack([feature, flipped_feature], dim=0) 97 | if norm is not None: 98 | stacked_norms = torch.stack([norm, flipped_norm], dim=0) 99 | else: 100 | stacked_norms = None 101 | 102 | fused_feature, fused_norm = fuse_features_with_norm(stacked_embeddings, stacked_norms, fusion_method=fusion_method) 103 | features.append(fused_feature.cpu().numpy()) 104 | norms.append(fused_norm.cpu().numpy()) 105 | else: 106 | features.append(feature.cpu().numpy()) 107 | norms.append(norm.cpu().numpy()) 108 | 109 | features = np.concatenate(features, axis=0) 110 | norms = np.concatenate(norms, axis=0) 111 | return features, norms 112 | 113 | def load_pretrained_model(model, model_name, gpu, lora_rank, lora_scale, use_lora): 114 | # load model and pretrained statedict 115 | ckpt_path = model_name 116 | model = get_model(model, dropout=0.0, fp16=False, num_features=512, r=lora_rank, scale=lora_scale, use_lora=use_lora) 117 | 118 | # model = net.build_model(arch) 119 | statedict = torch.load(ckpt_path, map_location=torch.device('cuda:' + str(gpu))) 120 | model.load_state_dict(statedict) 121 | model.eval() 122 | return model 123 | 124 | class ListDataset(Dataset): 125 | def __init__(self, img_list, image_size, image_is_saved_with_swapped_B_and_R=True): 126 | super(ListDataset, self).__init__() 127 | 128 | # image_is_saved_with_swapped_B_and_R: correctly saved image should have this set to False 129 | # face_emore/img has images saved with B and G (of RGB) swapped. 130 | # Since training data loader uses PIL (results in RGB) to read image 131 | # and validation data loader uses cv2 (results in BGR) to read image, this swap was okay. 
132 | # But if you want to evaluate on the training data such as face_emore/img (B and G swapped), 133 | # then you should set image_is_saved_with_swapped_B_and_R=True 134 | 135 | self.img_list = img_list 136 | self.transform = transforms.Compose([ 137 | transforms.ToTensor(), 138 | transforms.Resize(size=(image_size,image_size), interpolation=InterpolationMode.BICUBIC), 139 | transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]) 140 | 141 | self.image_is_saved_with_swapped_B_and_R = image_is_saved_with_swapped_B_and_R 142 | 143 | 144 | def __len__(self): 145 | return len(self.img_list) 146 | 147 | def __getitem__(self, idx): 148 | image_path = self.img_list[idx] 149 | img = cv2.imread(image_path) 150 | img = img[:, :, :3] 151 | # img = io.imread(image_path) 152 | 153 | if self.image_is_saved_with_swapped_B_and_R: 154 | # print('check if it really should be on') 155 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 156 | 157 | img = Image.fromarray(img) 158 | img = self.transform(img) 159 | # print(img) 160 | return img, idx 161 | 162 | def prepare_dataloader(img_list, batch_size, image_size, num_workers=0, image_is_saved_with_swapped_B_and_R=True): 163 | # image_is_saved_with_swapped_B_and_R: correctly saved image should have this set to False 164 | # face_emore/img has images saved with B and G (of RGB) swapped. 165 | # Since training data loader uses PIL (results in RGB) to read image 166 | # and validation data loader uses cv2 (results in BGR) to read image, this swap was okay. 167 | # But if you want to evaluate on the training data such as face_emore/img (B and G swapped), 168 | # then you should set image_is_saved_with_swapped_B_and_R=True 169 | 170 | image_dataset = ListDataset(img_list, image_size, image_is_saved_with_swapped_B_and_R=image_is_saved_with_swapped_B_and_R) 171 | dataloader = DataLoader(image_dataset, 172 | batch_size=batch_size, 173 | shuffle=False, 174 | drop_last=False, 175 | num_workers=num_workers) 176 | return dataloader 177 | 178 | 179 | def get_save_path(model_load_path): 180 | directory, _ = os.path.split(model_load_path) 181 | results_save_path = os.path.join(directory, 'results') 182 | 183 | return results_save_path 184 | 185 | if __name__ == '__main__': 186 | 187 | parser = argparse.ArgumentParser(description='tinyface') 188 | 189 | parser.add_argument('--data_root', default='/mnt/store/knaraya4/data') 190 | parser.add_argument('--gpu', default=0, type=int, help='gpu id') 191 | parser.add_argument('--batch_size', default=1024, type=int, help='') 192 | parser.add_argument('--image_size', default=112, type=int, help='112') 193 | parser.add_argument('--model_load_path', type=str) 194 | parser.add_argument('--model_type', type=str, default='r50') 195 | parser.add_argument('--lora_rank', type=int, default=4) 196 | parser.add_argument('--lora_scale', type=int, default=1) 197 | parser.add_argument('--use_lora', action='store_true') 198 | parser.add_argument('--use_flip_test', type=str2bool, default='True') 199 | parser.add_argument('--fusion_method', type=str, default='pre_norm_vector_add', choices=('average', 'norm_weighted_avg', 'pre_norm_vector_add', 'concat', 'faceness_score')) 200 | args = parser.parse_args() 201 | 202 | # load model 203 | model_load_path = args.model_load_path 204 | print("Model Load Path", model_load_path) 205 | model = load_pretrained_model(args.model_type, model_load_path, args.gpu, args.lora_rank, args.lora_scale, args.use_lora) 206 | model.to('cuda:{}'.format(args.gpu)) 207 | 208 | tinyface_test = 
tinyface_helper.TinyFaceTest(tinyface_root=args.data_root,alignment_dir_name='tinyface_aligned_112') 209 | 210 | # set save root 211 | gpu_id = args.gpu 212 | save_path = get_save_path(model_load_path) 213 | print("Save Path: ", save_path) 214 | 215 | if not os.path.exists(save_path): 216 | os.makedirs(save_path) 217 | print('save_path: {}'.format(save_path)) 218 | 219 | img_paths = tinyface_test.image_paths 220 | print('total images : {}'.format(len(img_paths))) 221 | dataloader = prepare_dataloader(img_paths, args.batch_size, args.image_size, num_workers=8, image_is_saved_with_swapped_B_and_R=True) 222 | features, norms = infer(model, dataloader, use_flip_test=args.use_flip_test, fusion_method=args.fusion_method, gpu=args.gpu) 223 | results = tinyface_test.test_identification(features, ranks=[1,5,20]) 224 | print(results) 225 | pd.DataFrame({'rank':[1,5,20], 'values':results}).to_csv(os.path.join(save_path, f'tinyface_{args.fusion_method}.csv')) -------------------------------------------------------------------------------- /validation_lq/evaluate_helper.py: -------------------------------------------------------------------------------- 1 | 2 | from PFE.ijbs import IJBSTest 3 | import PFE.utils 4 | import numpy as np 5 | from tqdm import tqdm 6 | import math 7 | from functools import partial 8 | import os 9 | 10 | 11 | def write_result(write_path, title, values=None): 12 | with open(write_path, 'a') as f: 13 | if values is None: 14 | f.write('{}\n'.format(title)) 15 | else: 16 | f.write('{},{}\n'.format(title, ",".join([str(v) for v in values]))) 17 | 18 | def eval_IJBS(feat_func, 19 | image_paths, 20 | fuse_match_method='mean_cos', 21 | subsample=16, 22 | verbose=True, 23 | get_retrievals=False, 24 | save_root='./', 25 | ijbs_proto_path=None): 26 | 27 | 28 | if fuse_match_method == 'mean_cos': 29 | fuse_func = PFE.utils.average_fuse 30 | compare_func = PFE.utils.inner_product 31 | elif fuse_match_method == 'PFE_fuse': 32 | fuse_func = partial(PFE.utils.aggregate_PFE_v1, normalize=True, concatenate=False, return_sigma=False) 33 | compare_func = PFE.utils.inner_product 34 | elif fuse_match_method == 'PFE_fuse_match': 35 | fuse_func = partial(PFE.utils.aggregate_PFE_v1, normalize=True, concatenate=True, return_sigma=True) 36 | compare_func = PFE.utils.uncertain_score_simple 37 | elif fuse_match_method == 'pre_norm_vector_add_cos': 38 | # it is same as averaging. 
But the features were multiplied by norm beforehand 39 | fuse_func = PFE.utils.average_fuse 40 | compare_func = PFE.utils.inner_product 41 | elif fuse_match_method == 'norm_weighted_avg': 42 | raise ValueError('cannot implement') 43 | else: 44 | raise ValueError('not a corect fuse_match_method {}'.format(fuse_match_method)) 45 | 46 | if ijbs_proto_path is None: 47 | ijbs_proto_path = os.path.join(os.path.dirname(__file__), 'IJB_S_proto.pkl') 48 | # print("Proto Path: ", ijbs_proto_path) 49 | tester = IJBSTest() 50 | tester.init_proto('/mnt/store/knaraya4/data/IJBS/protocols') 51 | # tester.save('/mnt/store/knaraya4/data/IJBS/IJB_S_proto.pkl') 52 | # tester.load(ijbs_proto_path) 53 | tester.initialize_indices(image_paths) 54 | 55 | has_indice = [1 for template in tester.all_template_list if template.indices is not None] 56 | 57 | tester.compare_func = compare_func 58 | for i, template in tqdm(enumerate(tester.all_template_list)): 59 | if template.indices is not None: 60 | if type(feat_func) is np.ndarray: 61 | features = feat_func[template.indices] 62 | elif hasattr(template, 'images'): 63 | features = feat_func(template.images) 64 | else: 65 | indices = template.indices 66 | if subsample: 67 | chunk = int(math.ceil(len(indices) / subsample)) 68 | sub_indices = np.unique(np.arange(len(indices)) // chunk) * chunk 69 | indices = indices[sub_indices] 70 | features = feat_func(tester.image_paths[indices]) 71 | 72 | # fuse 73 | template.feature = fuse_func(features) 74 | else: 75 | if fuse_match_method == 'mean_cos': 76 | template.feature = np.zeros(512) 77 | elif fuse_match_method == 'PFE_fuse': 78 | template.feature = np.zeros(512) 79 | elif fuse_match_method == 'PFE_fuse_match': 80 | template.feature = np.stack([np.zeros(512), np.ones(512)], axis=-1) 81 | elif fuse_match_method == 'pre_norm_vector_add_cos': 82 | template.feature = np.zeros(512) 83 | elif fuse_match_method == 'norm_weighted_avg': 84 | raise ValueError('cannot implement') 85 | else: 86 | raise ValueError('not a correct fuse method') 87 | all_result = {} 88 | 89 | print('surveillance to single') 90 | if get_retrievals: 91 | result = tester.surveillance_to_single(get_retrievals=True) 92 | np.save(os.path.join(save_root, 'eval_result/{}_sur_to_sin.npy'.format(fuse_match_method)), result) 93 | DIRs_closeset, DIRs_openset = tester.surveillance_to_single() 94 | if save_root is not None: 95 | write_result(os.path.join(save_root, 'eval_result/{}_result.csv'.format(fuse_match_method)), title='surveillance_to_single', values=['rank1', 'rank5', 'rank10']) 96 | write_result(os.path.join(save_root, 'eval_result/{}_result.csv'.format(fuse_match_method)), title='closedset', values=DIRs_closeset.tolist()) 97 | write_result(os.path.join(save_root, 'eval_result/{}_result.csv'.format(fuse_match_method)), title='surveillance_to_single', values=['0.01_FPIR', '0.1_FPIR']) 98 | write_result(os.path.join(save_root, 'eval_result/{}_result.csv'.format(fuse_match_method)), title='openset', values=DIRs_openset.tolist()) 99 | print(DIRs_closeset) 100 | print(DIRs_openset) 101 | all_result['closed_{}'.format('surveillance_to_single')] = DIRs_closeset 102 | all_result['open_{}'.format('surveillance_to_single')] = DIRs_openset 103 | 104 | 105 | print('surveillance_to_booking') 106 | if get_retrievals: 107 | result = tester.surveillance_to_booking(get_retrievals=True) 108 | np.save(os.path.join(save_root, 'eval_result/{}_sur_to_book.npy'.format(fuse_match_method)), result) 109 | DIRs_closeset, DIRs_openset = tester.surveillance_to_booking() 110 | if save_root 
is not None: 111 | write_result(os.path.join(save_root, 'eval_result/{}_result.csv'.format(fuse_match_method)), title='surveillance_to_booking', values=['rank1', 'rank5', 'rank10']) 112 | write_result(os.path.join(save_root, 'eval_result/{}_result.csv'.format(fuse_match_method)), title='closedset', values=DIRs_closeset.tolist()) 113 | write_result(os.path.join(save_root, 'eval_result/{}_result.csv'.format(fuse_match_method)), title='surveillance_to_booking', values=['0.01_FPIR', '0.1_FPIR']) 114 | write_result(os.path.join(save_root, 'eval_result/{}_result.csv'.format(fuse_match_method)), title='openset', values=DIRs_openset.tolist()) 115 | print(DIRs_closeset) 116 | print(DIRs_openset) 117 | all_result['closed_{}'.format('surveillance_to_booking')] = DIRs_closeset 118 | all_result['open_{}'.format('surveillance_to_booking')] = DIRs_openset 119 | 120 | print('multiview_surveillance_to_booking') 121 | if get_retrievals: 122 | result = tester.multiview_surveillance_to_booking(get_retrievals=True) 123 | np.save(os.path.join(save_root, 'eval_result/{}_multi_to_book.npy'.format(fuse_match_method)), result) 124 | DIRs_closeset, DIRs_openset = tester.multiview_surveillance_to_booking() 125 | if save_root is not None: 126 | write_result(os.path.join(save_root, 'eval_result/{}_result.csv'.format(fuse_match_method)), title='multiview_surveillance_to_booking', values=['rank1', 'rank5', 'rank10']) 127 | write_result(os.path.join(save_root, 'eval_result/{}_result.csv'.format(fuse_match_method)), title='closedset', values=DIRs_closeset.tolist()) 128 | write_result(os.path.join(save_root, 'eval_result/{}_result.csv'.format(fuse_match_method)), title='multiview_surveillance_to_booking', values=['0.01_FPIR', '0.1_FPIR']) 129 | write_result(os.path.join(save_root, 'eval_result/{}_result.csv'.format(fuse_match_method)), title='openset', values=DIRs_openset.tolist()) 130 | print(DIRs_closeset) 131 | print(DIRs_openset) 132 | all_result['closed_{}'.format('multiview_surveillance_to_booking')] = DIRs_closeset 133 | all_result['open_{}'.format('multiview_surveillance_to_booking')] = DIRs_openset 134 | 135 | print('surveillance_to_surveillance') 136 | if get_retrievals: 137 | result = tester.surveillance_to_surveillance(get_retrievals=True) 138 | np.save(os.path.join(save_root, 'eval_result/{}_sur_to_sur.npy'.format(fuse_match_method)), result) 139 | DIRs_closeset, DIRs_openset = tester.surveillance_to_surveillance() 140 | if save_root is not None: 141 | write_result(os.path.join(save_root, 'eval_result/{}_result.csv'.format(fuse_match_method)), title='surveillance_to_surveillance', values=['rank1', 'rank5', 'rank10']) 142 | write_result(os.path.join(save_root, 'eval_result/{}_result.csv'.format(fuse_match_method)), title='closedset', values=DIRs_closeset.tolist()) 143 | write_result(os.path.join(save_root, 'eval_result/{}_result.csv'.format(fuse_match_method)), title='surveillance_to_surveillance', values=['0.01_FPIR', '0.1_FPIR']) 144 | write_result(os.path.join(save_root, 'eval_result/{}_result.csv'.format(fuse_match_method)), title='openset', values=DIRs_openset.tolist()) 145 | print(DIRs_closeset) 146 | print(DIRs_openset) 147 | all_result['closed_{}'.format('surveillance_to_surveillance')] = DIRs_closeset 148 | all_result['open_{}'.format('surveillance_to_surveillance')] = DIRs_openset 149 | 150 | print('uav_surveillance_to_booking') 151 | if get_retrievals: 152 | result = tester.uav_surveillance_to_booking(get_retrievals=True) 153 | np.save(os.path.join(save_root, 
'eval_result/{}_uav_to_sur.npy'.format(fuse_match_method)), result) 154 | DIRs_closeset, DIRs_openset = tester.uav_surveillance_to_booking() 155 | if save_root is not None: 156 | write_result(os.path.join(save_root, 'eval_result/{}_result.csv'.format(fuse_match_method)), title='uav_surveillance_to_booking', values=['rank1', 'rank5', 'rank10']) 157 | write_result(os.path.join(save_root, 'eval_result/{}_result.csv'.format(fuse_match_method)), title='closedset', values=DIRs_closeset.tolist()) 158 | write_result(os.path.join(save_root, 'eval_result/{}_result.csv'.format(fuse_match_method)), title='uav_surveillance_to_booking', values=['0.01_FPIR', '0.1_FPIR']) 159 | write_result(os.path.join(save_root, 'eval_result/{}_result.csv'.format(fuse_match_method)), title='openset', values=DIRs_openset.tolist()) 160 | print(DIRs_closeset) 161 | print(DIRs_openset) 162 | all_result['closed_{}'.format('uav_surveillance_to_booking')] = DIRs_closeset 163 | all_result['open_{}'.format('uav_surveillance_to_booking')] = DIRs_openset 164 | 165 | return all_result 166 | 167 | 168 | def run_eval_with_features(save_root, features, image_paths, get_retrievals=False, fuse_match_method='mean_cos', ijbs_proto_path=None): 169 | assert len(features) == len(image_paths) 170 | 171 | all_result = eval_IJBS(feat_func=features, 172 | fuse_match_method=fuse_match_method, 173 | image_paths=image_paths, 174 | subsample=False, 175 | verbose=True, 176 | get_retrievals=get_retrievals, 177 | save_root=save_root, 178 | ijbs_proto_path=ijbs_proto_path 179 | ) 180 | return all_result -------------------------------------------------------------------------------- /validation_lq/PFE/utils.py: -------------------------------------------------------------------------------- 1 | """Utilities for training and testing 2 | """ 3 | # MIT License 4 | # 5 | # Copyright (c) 2017 Yichun Shi 6 | # 7 | # Permission is hereby granted, free of charge, to any person obtaining a copy 8 | # of this software and associated documentation files (the "Software"), to deal 9 | # in the Software without restriction, including without limitation the rights 10 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 11 | # copies of the Software, and to permit persons to whom the Software is 12 | # furnished to do so, subject to the following conditions: 13 | # 14 | # The above copyright notice and this permission notice shall be included in all 15 | # copies or substantial portions of the Software. 16 | # 17 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 18 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 19 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 20 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 21 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 22 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 23 | # SOFTWARE. 
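# Quick orientation (editor's note, summarizing the functions defined below): the
# *_uncertain_score functions implement the negative mutual-likelihood distance of
# Probabilistic Face Embeddings, roughly
#     score(i, j) = -sum_d [ (mu_i - mu_j)_d^2 / (s_i + s_j)_d + log(s_i + s_j)_d ]
# where s = sigma^2; uncertain_score_simple collapses sigma^2 to its per-sample mean before
# applying the same formula. aggregate_PFE fuses a set of per-frame embeddings with attention
# weights proportional to 1 / sigma^2 and keeps the element-wise minimum sigma^2 as the fused
# uncertainty.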
24 | 25 | import sys 26 | import os 27 | import numpy as np 28 | from scipy import misc 29 | import imp 30 | import time 31 | import math 32 | import random 33 | from datetime import datetime 34 | import shutil 35 | 36 | def import_file(full_path_to_module, name='module.name'): 37 | 38 | module_obj = imp.load_source(name, full_path_to_module) 39 | 40 | return module_obj 41 | 42 | def create_log_dir(log_base_dir, name, config_file, model_file): 43 | subdir = datetime.strftime(datetime.now(), '%Y%m%d-%H%M%S') 44 | log_dir = os.path.join(os.path.expanduser(log_base_dir), name, subdir) 45 | if not os.path.isdir(log_dir): # Create the log directory if it doesn't exist 46 | os.makedirs(log_dir) 47 | shutil.copyfile(config_file, os.path.join(log_dir,'config.py')) 48 | shutil.copyfile(model_file, os.path.join(log_dir,'model.py')) 49 | 50 | return log_dir 51 | 52 | def get_updated_learning_rate(global_step, learning_rate_strategy, learning_rate_schedule): 53 | if learning_rate_strategy == 'step': 54 | max_step = -1 55 | learning_rate = 0.0 56 | for step, lr in learning_rate_schedule.items(): 57 | if global_step >= step and step > max_step: 58 | learning_rate = lr 59 | max_step = step 60 | if max_step == -1: 61 | raise ValueError('cannot find learning rate for step %d' % global_step) 62 | elif learning_rate_strategy == 'cosine': 63 | initial = learning_rate_schedule['initial'] 64 | interval = learning_rate_schedule['interval'] 65 | end_step = learning_rate_schedule['end_step'] 66 | step = math.floor(float(global_step) / interval) * interval 67 | assert step <= end_step 68 | learning_rate = initial * 0.5 * (math.cos(math.pi * step / end_step) + 1) 69 | return learning_rate 70 | 71 | def display_info(epoch, step, watch_list): 72 | sys.stdout.write('[%d][%d]' % (epoch+1, step+1)) 73 | for item in watch_list.items(): 74 | if type(item[1]) in [float, np.float32, np.float64]: 75 | sys.stdout.write(' %s: %2.3f' % (item[0], item[1])) 76 | elif type(item[1]) in [int, bool, np.int32, np.int64, np.bool]: 77 | sys.stdout.write(' %s: %d' % (item[0], item[1])) 78 | sys.stdout.write('\n') 79 | 80 | def get_pairwise_score_label(score_mat, label): 81 | n = label.size 82 | assert score_mat.shape[0]==score_mat.shape[1]==n 83 | triu_indices = np.triu_indices(n, 1) 84 | if len(label.shape)==1: 85 | label = label[:, None] 86 | label_mat = label==label.T 87 | score_vec = score_mat[triu_indices] 88 | label_vec = label_mat[triu_indices] 89 | return score_vec, label_vec 90 | 91 | 92 | ########################## 93 | # Comparison Functions 94 | ########################## 95 | 96 | def l2_normalize(x, axis=1, eps=1e-8): 97 | return x / (np.linalg.norm(x, axis=axis, keepdims=True) + eps) 98 | 99 | def group_normalize(x, ngroup=1): 100 | N, C = x.shape 101 | assert C % ngroup == 0 102 | x = x.reshape(N, ngroup, C//ngroup) 103 | x = l2_normalize(x, axis=2) / math.sqrt(ngroup) 104 | x = x.reshape(N, C) 105 | return x 106 | 107 | def pair_euc_score(x1, x2): 108 | x1, x2 = np.array(x1), np.array(x2) 109 | if x1.ndim == 3: 110 | x1, x2 = x1[:,:,0], x2[:,:,0] 111 | dist = np.sum(np.square(x1 - x2), axis=1) 112 | return -dist 113 | 114 | def pair_cosine_score(x1, x2): 115 | x1, x2 = np.array(x1), np.array(x2) 116 | if x1.ndim == 3: 117 | x1, x2 = x1[:,:,0], x2[:,:,0] 118 | return np.sum(l2_normalize(x1) * l2_normalize(x2), axis=1) 119 | 120 | def pair_inner_product(x1, x2): 121 | x1, x2 = np.array(x1), np.array(x2) 122 | if x1.ndim == 3: 123 | x1, x2 = x1[:,:,0], x2[:,:,0] 124 | return np.sum(x1*x2, axis=1) 125 | 126 | def 
pair_hammin_distance(x1, x2): 127 | x1, x2 = np.array(x1), np.array(x2) 128 | x1, x2 = x1>=0, x2>=0 129 | return (x1==x2).sum(1) 130 | 131 | def inner_product(x1, x2): 132 | x1, x2 = np.array(x1), np.array(x2) 133 | if x1.ndim == 3: 134 | x1, x2 = x1[:,:,0], x2[:,:,0] 135 | return np.dot(x1, x2.T) 136 | 137 | def cosine_score(x1, x2): 138 | x1, x2 = np.array(x1), np.array(x2) 139 | if x1.ndim == 3: 140 | x1, x2 = x1[:,:,0], x2[:,:,0] 141 | x1 = l2_normalize(x1) 142 | x2 = l2_normalize(x2) 143 | return np.dot(x1, x2.T) 144 | 145 | def euclidean(x1,x2): 146 | assert x1.shape[1]==x2.shape[1] 147 | x2 = x2.transpose() 148 | x1_norm = np.sum(np.square(x1), axis=1, keepdims=True) 149 | x2_norm = np.sum(np.square(x2), axis=0, keepdims=True) 150 | dist = x1_norm + x2_norm - 2*np.dot(x1,x2) 151 | return dist 152 | 153 | def pair_uncertain_score(x1, x2, sigma_sq1=None, sigma_sq2=None): 154 | if sigma_sq1 is None: 155 | assert sigma_sq2 is None, 'either pass in concated features, or mu, sigma_sq for both!' 156 | x1, x2 = np.array(x1), np.array(x2) 157 | D = int(x1.shape[1] / 2) 158 | mu1, sigma_sq1 = x1[:,:,0], x1[:,:,1] 159 | mu2, sigma_sq2 = x2[:,:,0], x2[:,:,1] 160 | else: 161 | x1, x2 = np.array(x1), np.array(x2) 162 | sigma_sq1, sigma_sq2 = np.array(sigma_sq1), np.array(sigma_sq2) 163 | mu1, mu2 = x1, x2 164 | sigma_sq_mutual = sigma_sq1 + sigma_sq2 165 | dist = np.sum(np.square(mu1 - mu2) / sigma_sq_mutual + np.log(sigma_sq_mutual), axis=1) 166 | return -dist 167 | 168 | 169 | def uncertain_score(x1, x2, sigma_sq1=None, sigma_sq2=None): 170 | if sigma_sq1 is None: 171 | assert sigma_sq2 is None, 'either pass in concated features, or mu, sigma_sq for both!' 172 | x1, x2 = np.array(x1), np.array(x2) 173 | D = int(x1.shape[1] / 2) 174 | mu1, sigma_sq1 = x1[:,:,0], x1[:,:,1] 175 | mu2, sigma_sq2 = x2[:,:,0], x2[:,:,1] 176 | else: 177 | x1, x2 = np.array(x1), np.array(x2) 178 | sigma_sq1, sigma_sq2 = np.array(sigma_sq1), np.array(sigma_sq2) 179 | mu1, mu2 = x1, x2 180 | from clib import mutual_likelihood_score_parallel 181 | mu1, mu2 = mu1.astype(np.float32), mu2.astype(np.float32) 182 | sigma_sq1, sigma_sq2 = sigma_sq1.astype(np.float32), sigma_sq2.astype(np.float32) 183 | score = mutual_likelihood_score_parallel(mu1, mu2, sigma_sq1, sigma_sq2) 184 | score = np.array(score) 185 | return score 186 | 187 | 188 | def uncertain_score_simple(x1, x2, sigma_sq1=None, sigma_sq2=None): 189 | if sigma_sq1 is None: 190 | assert sigma_sq2 is None, 'either pass in concated features, or mu, sigma_sq for both!' 191 | x1, x2 = np.array(x1), np.array(x2) 192 | D = int(x1.shape[1] / 2) 193 | mu1, sigma_sq1 = x1[:,:,0], x1[:,:,1] 194 | mu2, sigma_sq2 = x2[:,:,0], x2[:,:,1] 195 | else: 196 | x1, x2 = np.array(x1), np.array(x2) 197 | sigma_sq1, sigma_sq2 = np.array(sigma_sq1), np.array(sigma_sq2) 198 | mu1, mu2 = x1, x2 199 | D = sigma_sq1.shape[1] 200 | sigma_sq1 = sigma_sq1.mean(1, keepdims=True) 201 | sigma_sq2 = sigma_sq2.mean(1, keepdims=True).T 202 | dist = euclidean(mu1,mu2) 203 | score = dist / (sigma_sq1+sigma_sq2) + D * np.log(sigma_sq1+sigma_sq2) 204 | score = -np.array(score) 205 | return score 206 | 207 | ########################## 208 | # Fusion Functions 209 | ########################## 210 | 211 | def average_fuse(x): 212 | x = x.mean(0) 213 | x = l2_normalize(x, axis=0) 214 | return x 215 | 216 | def aggregate_PFE(x, sigma_sq=None, normalize=True, concatenate=True): 217 | if sigma_sq is None: 218 | mu, sigma_sq = x[:,:,0], x[:,:,1] 219 | else: 220 | mu = x 221 | attention = 1. 
/ sigma_sq 222 | attention = attention / np.sum(attention, axis=0, keepdims=True) 223 | 224 | mu_new = np.sum(mu * attention, axis=0) 225 | sigma_sq_new = np.min(sigma_sq, axis=0) 226 | 227 | if normalize: 228 | ngroup = 1 229 | mu_new = mu_new.reshape(-1, 512) 230 | # mu_new = l2_normalize(mu_new, axis=-1) 231 | mu_new = group_normalize(mu_new, ngroup) 232 | mu_new = mu_new.reshape(-1) 233 | 234 | if concatenate: 235 | return np.stack([mu_new, sigma_sq_new], axis=1) 236 | else: 237 | return mu_new, sigma_sq_new 238 | 239 | def l2_normalize_v1(x, axis=None, eps=1e-8): 240 | # from PFE github repo's ijbA eval code 241 | x = x / (eps + np.linalg.norm(x, axis=axis)) 242 | return x 243 | 244 | def aggregate_PFE_v1(x, sigma_sq=None, normalize=True, concatenate=False, return_sigma=True): 245 | # from PFE github repo's ijbA eval code 246 | if sigma_sq is None: 247 | D = int(x.shape[1] / 2) 248 | mu, sigma_sq = x[:,:D], x[:,D:] 249 | else: 250 | mu = x 251 | attention = 1. / sigma_sq 252 | attention = attention / np.sum(attention, axis=0, keepdims=True) 253 | 254 | mu_new = np.sum(mu * attention, axis=0) 255 | sigma_sq_new = np.min(sigma_sq, axis=0) 256 | 257 | if normalize: 258 | mu_new = l2_normalize_v1(mu_new) 259 | 260 | if not return_sigma: 261 | return mu_new 262 | 263 | if concatenate: 264 | return np.stack([mu_new, sigma_sq_new], axis=-1) 265 | else: 266 | return mu_new, sigma_sq_new 267 | 268 | 269 | def write_summary(summary_writer, summary, global_step): 270 | if 'scalar' in summary: 271 | for k,v in summary['scalar'].items(): 272 | summary_writer.add_scalar(k, v, global_step) 273 | if 'histogram' in summary: 274 | for k,v in summary['histogram'].items(): 275 | summary_writer.add_histogram(k, v, global_step) 276 | if 'image' in summary: 277 | for k,v in summary['image'].items(): 278 | summary_writer.add_image(k, v, global_step) 279 | if 'figure' in summary: 280 | for k,v in summary['figure'].items(): 281 | summary_writer.add_figure(k, v, global_step) 282 | summary_writer.file_writer.flush() -------------------------------------------------------------------------------- /utils/evaluate_utils.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import os 3 | import pickle 4 | 5 | import mxnet as mx 6 | import numpy as np 7 | import sklearn 8 | import torch 9 | from mxnet import ndarray as nd 10 | from scipy import interpolate 11 | from sklearn.decomposition import PCA 12 | from sklearn.model_selection import KFold 13 | 14 | class LFold: 15 | def __init__(self, n_splits=2, shuffle=False): 16 | self.n_splits = n_splits 17 | if self.n_splits > 1: 18 | self.k_fold = KFold(n_splits=n_splits, shuffle=shuffle) 19 | 20 | def split(self, indices): 21 | if self.n_splits > 1: 22 | return self.k_fold.split(indices) 23 | else: 24 | return [(indices, indices)] 25 | 26 | @torch.no_grad() 27 | def load_bin(path, image_size): 28 | try: 29 | with open(path, 'rb') as f: 30 | bins, issame_list = pickle.load(f) # py2 31 | except UnicodeDecodeError as e: 32 | with open(path, 'rb') as f: 33 | bins, issame_list = pickle.load(f, encoding='bytes') # py3 34 | data_list = [] 35 | for flip in [0, 1]: 36 | data = torch.empty((len(issame_list) * 2, 3, image_size[0], image_size[1])) 37 | data_list.append(data) 38 | for idx in range(len(issame_list) * 2): 39 | _bin = bins[idx] 40 | img = mx.image.imdecode(_bin) 41 | if img.shape[1] != image_size[0]: 42 | img = mx.image.resize_short(img, image_size[0]) 43 | img = nd.transpose(img, axes=(2, 0, 1)) 44 | for flip in 
[0, 1]: 45 | if flip == 1: 46 | img = mx.ndarray.flip(data=img, axis=2) 47 | data_list[flip][idx][:] = torch.from_numpy(img.asnumpy()) 48 | if idx % 1000 == 0: 49 | print('loading bin', idx) 50 | print(data_list[0].shape) 51 | return data_list, issame_list 52 | 53 | @torch.no_grad() 54 | def test(data_set, backbone, batch_size, nfolds=10): 55 | print('testing verification..') 56 | data_list = data_set[0] 57 | issame_list = data_set[1] 58 | embeddings_list = [] 59 | time_consumed = 0.0 60 | for i in range(len(data_list)): 61 | data = data_list[i] 62 | embeddings = None 63 | ba = 0 64 | while ba < data.shape[0]: 65 | bb = min(ba + batch_size, data.shape[0]) 66 | count = bb - ba 67 | _data = data[bb - batch_size: bb] 68 | time0 = datetime.datetime.now() 69 | img = ((_data / 255) - 0.5) / 0.5 70 | net_out: torch.Tensor = backbone(img) 71 | _embeddings = net_out.detach().cpu().numpy() 72 | time_now = datetime.datetime.now() 73 | diff = time_now - time0 74 | time_consumed += diff.total_seconds() 75 | if embeddings is None: 76 | embeddings = np.zeros((data.shape[0], _embeddings.shape[1])) 77 | embeddings[ba:bb, :] = _embeddings[(batch_size - count):, :] 78 | ba = bb 79 | embeddings_list.append(embeddings) 80 | 81 | _xnorm = 0.0 82 | _xnorm_cnt = 0 83 | for embed in embeddings_list: 84 | for i in range(embed.shape[0]): 85 | _em = embed[i] 86 | _norm = np.linalg.norm(_em) 87 | _xnorm += _norm 88 | _xnorm_cnt += 1 89 | _xnorm /= _xnorm_cnt 90 | 91 | embeddings = embeddings_list[0].copy() 92 | embeddings = sklearn.preprocessing.normalize(embeddings) 93 | acc1 = 0.0 94 | std1 = 0.0 95 | embeddings = embeddings_list[0] + embeddings_list[1] 96 | embeddings = sklearn.preprocessing.normalize(embeddings) 97 | print(embeddings.shape) 98 | print('infer time', time_consumed) 99 | _, _, accuracy, val, val_std, far = evaluate(embeddings, issame_list, nrof_folds=nfolds) 100 | acc2, std2 = np.mean(accuracy), np.std(accuracy) 101 | return acc1, std1, acc2, std2, _xnorm, embeddings_list 102 | 103 | 104 | # MIT License 105 | # 106 | # Copyright (c) 2016 David Sandberg 107 | # 108 | # Permission is hereby granted, free of charge, to any person obtaining a copy 109 | # of this software and associated documentation files (the "Software"), to deal 110 | # in the Software without restriction, including without limitation the rights 111 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 112 | # copies of the Software, and to permit persons to whom the Software is 113 | # furnished to do so, subject to the following conditions: 114 | # 115 | # The above copyright notice and this permission notice shall be included in all 116 | # copies or substantial portions of the Software. 117 | # 118 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 119 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 120 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 121 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 122 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 123 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 124 | # SOFTWARE. 
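# Editor's note: the helpers below implement the standard LFW-style verification protocol:
# squared L2 distances between paired embeddings are swept over a grid of thresholds, the
# best threshold is picked on the training folds, and accuracy is reported on the held-out
# fold. The following small sketch (illustrative only, not part of the original file; the
# synthetic numbers are made up) shows the mechanics on toy data using calculate_accuracy
# defined further down.
def _toy_threshold_sweep_example():
    # Two genuine pairs (small distance) and two impostor pairs (large distance).
    dist = np.array([0.4, 0.6, 1.8, 2.2])
    actual_issame = np.array([True, True, False, False])
    thresholds = np.arange(0, 4, 0.01)
    accs = [calculate_accuracy(t, dist, actual_issame)[2] for t in thresholds]
    best = thresholds[int(np.argmax(accs))]
    # Any threshold between 0.6 and 1.8 separates these pairs, so the best accuracy is 1.0.
    return best, max(accs)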
125 | 126 | def calculate_roc(thresholds, 127 | embeddings1, 128 | embeddings2, 129 | actual_issame, 130 | nrof_folds=10, 131 | pca=0): 132 | assert (embeddings1.shape[0] == embeddings2.shape[0]) 133 | assert (embeddings1.shape[1] == embeddings2.shape[1]) 134 | nrof_pairs = min(len(actual_issame), embeddings1.shape[0]) 135 | nrof_thresholds = len(thresholds) 136 | k_fold = LFold(n_splits=nrof_folds, shuffle=False) 137 | 138 | tprs = np.zeros((nrof_folds, nrof_thresholds)) 139 | fprs = np.zeros((nrof_folds, nrof_thresholds)) 140 | accuracy = np.zeros((nrof_folds)) 141 | indices = np.arange(nrof_pairs) 142 | 143 | if pca == 0: 144 | diff = np.subtract(embeddings1, embeddings2) 145 | dist = np.sum(np.square(diff), 1) 146 | 147 | for fold_idx, (train_set, test_set) in enumerate(k_fold.split(indices)): 148 | if pca > 0: 149 | print('doing pca on', fold_idx) 150 | embed1_train = embeddings1[train_set] 151 | embed2_train = embeddings2[train_set] 152 | _embed_train = np.concatenate((embed1_train, embed2_train), axis=0) 153 | pca_model = PCA(n_components=pca) 154 | pca_model.fit(_embed_train) 155 | embed1 = pca_model.transform(embeddings1) 156 | embed2 = pca_model.transform(embeddings2) 157 | embed1 = sklearn.preprocessing.normalize(embed1) 158 | embed2 = sklearn.preprocessing.normalize(embed2) 159 | diff = np.subtract(embed1, embed2) 160 | dist = np.sum(np.square(diff), 1) 161 | 162 | # Find the best threshold for the fold 163 | acc_train = np.zeros((nrof_thresholds)) 164 | for threshold_idx, threshold in enumerate(thresholds): 165 | _, _, acc_train[threshold_idx] = calculate_accuracy( 166 | threshold, dist[train_set], actual_issame[train_set]) 167 | best_threshold_index = np.argmax(acc_train) 168 | for threshold_idx, threshold in enumerate(thresholds): 169 | tprs[fold_idx, threshold_idx], fprs[fold_idx, threshold_idx], _ = calculate_accuracy( 170 | threshold, dist[test_set], 171 | actual_issame[test_set]) 172 | _, _, accuracy[fold_idx] = calculate_accuracy( 173 | thresholds[best_threshold_index], dist[test_set], 174 | actual_issame[test_set]) 175 | 176 | tpr = np.mean(tprs, 0) 177 | fpr = np.mean(fprs, 0) 178 | return tpr, fpr, accuracy 179 | 180 | 181 | def calculate_accuracy(threshold, dist, actual_issame): 182 | predict_issame = np.less(dist, threshold) 183 | tp = np.sum(np.logical_and(predict_issame, actual_issame)) 184 | fp = np.sum(np.logical_and(predict_issame, np.logical_not(actual_issame))) 185 | tn = np.sum( 186 | np.logical_and(np.logical_not(predict_issame), 187 | np.logical_not(actual_issame))) 188 | fn = np.sum(np.logical_and(np.logical_not(predict_issame), actual_issame)) 189 | 190 | tpr = 0 if (tp + fn == 0) else float(tp) / float(tp + fn) 191 | fpr = 0 if (fp + tn == 0) else float(fp) / float(fp + tn) 192 | acc = float(tp + tn) / dist.size 193 | return tpr, fpr, acc 194 | 195 | 196 | def calculate_val(thresholds, 197 | embeddings1, 198 | embeddings2, 199 | actual_issame, 200 | far_target, 201 | nrof_folds=10): 202 | assert (embeddings1.shape[0] == embeddings2.shape[0]) 203 | assert (embeddings1.shape[1] == embeddings2.shape[1]) 204 | nrof_pairs = min(len(actual_issame), embeddings1.shape[0]) 205 | nrof_thresholds = len(thresholds) 206 | k_fold = LFold(n_splits=nrof_folds, shuffle=False) 207 | 208 | val = np.zeros(nrof_folds) 209 | far = np.zeros(nrof_folds) 210 | 211 | diff = np.subtract(embeddings1, embeddings2) 212 | dist = np.sum(np.square(diff), 1) 213 | indices = np.arange(nrof_pairs) 214 | 215 | for fold_idx, (train_set, test_set) in enumerate(k_fold.split(indices)): 216 | 
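        # What this loop computes (editor's note): VAL is the fraction of genuine pairs
        # accepted (dist < threshold) and FAR the fraction of impostor pairs accepted at the
        # same threshold. The threshold itself is chosen on the training folds by linearly
        # interpolating the measured FAR curve to the requested far_target (1e-3 when called
        # from evaluate()), and VAL/FAR are then re-measured on the held-out fold.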
217 | # Find the threshold that gives FAR = far_target 218 | far_train = np.zeros(nrof_thresholds) 219 | for threshold_idx, threshold in enumerate(thresholds): 220 | _, far_train[threshold_idx] = calculate_val_far( 221 | threshold, dist[train_set], actual_issame[train_set]) 222 | if np.max(far_train) >= far_target: 223 | f = interpolate.interp1d(far_train, thresholds, kind='slinear') 224 | threshold = f(far_target) 225 | else: 226 | threshold = 0.0 227 | 228 | val[fold_idx], far[fold_idx] = calculate_val_far( 229 | threshold, dist[test_set], actual_issame[test_set]) 230 | 231 | val_mean = np.mean(val) 232 | far_mean = np.mean(far) 233 | val_std = np.std(val) 234 | return val_mean, val_std, far_mean 235 | 236 | 237 | def calculate_val_far(threshold, dist, actual_issame): 238 | predict_issame = np.less(dist, threshold) 239 | true_accept = np.sum(np.logical_and(predict_issame, actual_issame)) 240 | false_accept = np.sum( 241 | np.logical_and(predict_issame, np.logical_not(actual_issame))) 242 | n_same = np.sum(actual_issame) 243 | n_diff = np.sum(np.logical_not(actual_issame)) 244 | val = float(true_accept) / float(n_same) 245 | far = float(false_accept) / float(n_diff) 246 | return val, far 247 | 248 | 249 | def evaluate(embeddings, actual_issame, nrof_folds=10, pca=0): 250 | # Calculate evaluation metrics 251 | thresholds = np.arange(0, 4, 0.01) 252 | embeddings1 = embeddings[0::2] 253 | embeddings2 = embeddings[1::2] 254 | tpr, fpr, accuracy = calculate_roc(thresholds, 255 | embeddings1, 256 | embeddings2, 257 | np.asarray(actual_issame), 258 | nrof_folds=nrof_folds, 259 | pca=pca) 260 | thresholds = np.arange(0, 4, 0.001) 261 | val, val_std, far = calculate_val(thresholds, 262 | embeddings1, 263 | embeddings2, 264 | np.asarray(actual_issame), 265 | 1e-3, 266 | nrof_folds=nrof_folds) 267 | return tpr, fpr, accuracy, val, val_std, far -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import logging 3 | import os 4 | from datetime import datetime, timedelta 5 | 6 | import numpy as np 7 | import torch 8 | import random 9 | import config 10 | from backbones import get_model 11 | from heads import get_head 12 | from dataset.dataset import get_dataloader 13 | from losses import CombinedMarginLoss 14 | from lr_scheduler import PolynomialLRWarmup 15 | from torch import distributed 16 | from torch.distributed import destroy_process_group 17 | from torch.utils.data import DataLoader 18 | from torch.utils.tensorboard import SummaryWriter 19 | from utils.utils_callbacks import CallBackLogging, CallBackVerification 20 | from utils.utils_distributed_sampler import setup_seed 21 | from utils.utils_logging import AverageMeter, init_logging 22 | from torch.distributed.algorithms.ddp_comm_hooks.default_hooks import fp16_compress_hook 23 | 24 | assert torch.__version__ >= "1.12.0", "In order to enjoy the features of the new torch, \ 25 | we have upgraded the torch to 1.12.0. torch before than 1.12.0 may not work in the future." 
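# NOTE: the distributed setup below assumes a torchrun launch (as in the scripts/*.sh
# helpers), which exports RANK, LOCAL_RANK and WORLD_SIZE; invoking `python train.py`
# directly raises a KeyError at these environment lookups. A single-process fallback for
# local debugging (a sketch only, not used by the repository scripts) could default them:
#   rank = int(os.environ.get("RANK", 0))
#   local_rank = int(os.environ.get("LOCAL_RANK", 0))
#   world_size = int(os.environ.get("WORLD_SIZE", 1))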
26 | 27 | rank = int(os.environ["RANK"]) 28 | local_rank = int(os.environ["LOCAL_RANK"]) 29 | world_size = int(os.environ["WORLD_SIZE"]) 30 | distributed.init_process_group("nccl") 31 | 32 | 33 | 34 | def main(args): 35 | setup_seed(seed=args.seed, cuda_deterministic=False) 36 | 37 | torch.cuda.set_device(local_rank) 38 | 39 | os.makedirs(args.output, exist_ok=True) 40 | init_logging(rank, args.output) 41 | 42 | summary_writer = ( 43 | SummaryWriter(log_dir=os.path.join(args.output, "tensorboard")) 44 | if rank == 0 45 | else None 46 | ) 47 | 48 | wandb_logger = None 49 | if args.using_wandb: 50 | import wandb 51 | # Sign in to wandb 52 | try: 53 | wandb.login(key=args.wandb_key) 54 | except Exception as e: 55 | print("WandB Key must be provided in config file (base.py).") 56 | print(f"Config Error: {e}") 57 | # Initialize wandb 58 | run_name = datetime.now().strftime("%y%m%d_%H%M") + f"_GPU{rank}" 59 | run_name = run_name if args.suffix_run_name is None else run_name + f"_{args.suffix_run_name}" 60 | try: 61 | wandb_logger = wandb.init( 62 | entity = args.wandb_entity, 63 | project = args.wandb_project, 64 | sync_tensorboard = True, 65 | resume=args.wandb_resume, 66 | name = run_name, 67 | notes = args.notes) if rank == 0 or args.wandb_log_all else None 68 | if wandb_logger: 69 | wandb_logger.config.update(args) 70 | except Exception as e: 71 | print("WandB Data (Entity and Project name) must be provided in config file (base.py).") 72 | print(f"Config Error: {e}") 73 | train_loader = get_dataloader( 74 | args.rec, 75 | local_rank, 76 | args.batch_size, 77 | args.image_size, 78 | args.dali, 79 | args.dali_aug, 80 | args.seed, 81 | args.num_workers 82 | ) 83 | 84 | backbone = get_model(args.network, dropout=0.0, fp16=args.fp16, num_features=args.embedding_size, r=args.lora_rank, scale=args.lora_scale, use_lora=args.use_lora).cuda() 85 | backbone = torch.nn.parallel.DistributedDataParallel( 86 | module=backbone, broadcast_buffers=False, device_ids=[local_rank], bucket_cap_mb=16, 87 | find_unused_parameters=True) 88 | backbone.register_comm_hook(None, fp16_compress_hook) 89 | 90 | backbone.train() 91 | backbone._set_static_graph() 92 | 93 | margin_loss = CombinedMarginLoss( 94 | 64, 95 | args.margin_list[0], 96 | args.margin_list[1], 97 | args.margin_list[2], 98 | args.interclass_filtering_threshold 99 | ) 100 | head = get_head(args.head, 101 | margin_loss=margin_loss, embedding_size=args.embedding_size, num_classes=args.num_classes, 102 | sample_rate=args.sample_rate, fp16=False) 103 | 104 | if args.use_lora: 105 | weights_path = os.path.join(args.load_pretrained, f"checkpoint_gpu_{rank}.pt") 106 | if os.path.isfile(weights_path): 107 | dict_checkpoint = torch.load(weights_path) 108 | backbone.module.load_state_dict(dict_checkpoint["state_dict_backbone"], strict=False) 109 | else: 110 | dict_checkpoint = torch.load(os.path.join(args.load_pretrained, f"model.pt")) 111 | backbone.module.load_state_dict(dict_checkpoint, strict=False) 112 | for p in head.parameters(): 113 | p.requires_grad = True 114 | for p in backbone.parameters(): 115 | p.requires_grad = False 116 | for name, p in backbone.named_parameters(): 117 | if 'trainable_lora' in name: 118 | p.requires_grad = True 119 | 120 | 121 | if args.optimizer == "sgd": 122 | total_params = sum(p.numel() for p in backbone.parameters()) 123 | trainable_params = sum(p.numel() for p in backbone.parameters() if p.requires_grad) + sum(p.numel() for p in head.parameters() if p.requires_grad) 124 | logging.info("Total Parameters: %d", total_params) 125 
| logging.info('Number of trainable parameters: %d', trainable_params) 126 | head.train().cuda() 127 | opt = torch.optim.SGD( 128 | params=[{"params": filter(lambda p: p.requires_grad, backbone.parameters()) }, {"params": head.parameters()}], 129 | lr=args.lr, momentum=0.9, weight_decay=args.weight_decay) 130 | elif args.optimizer == "adamw": 131 | total_params = sum(p.numel() for p in backbone.parameters()) 132 | trainable_params = sum(p.numel() for p in backbone.parameters() if p.requires_grad) + sum(p.numel() for p in head.parameters() if p.requires_grad) 133 | logging.info("Total Parameters: %d", total_params) 134 | logging.info('Number of trainable parameters: %d', trainable_params) 135 | head.train().cuda() 136 | opt = torch.optim.AdamW( 137 | params=[{"params": filter(lambda p: p.requires_grad, backbone.parameters()) }, {"params": head.parameters()}], 138 | lr=args.lr, weight_decay=args.weight_decay) 139 | else: 140 | raise 141 | 142 | args.total_batch_size = args.batch_size * world_size 143 | args.warmup_step = args.num_image // args.total_batch_size * args.warmup_epoch 144 | args.total_step = args.num_image // args.total_batch_size * args.num_epoch 145 | 146 | lr_scheduler = PolynomialLRWarmup( 147 | optimizer=opt, 148 | warmup_iters=args.warmup_step, 149 | total_iters=args.total_step) 150 | 151 | start_epoch = 0 152 | global_step = 0 153 | if args.resume: 154 | dict_checkpoint = torch.load(os.path.join(args.output, f"checkpoint_gpu_{rank}.pt")) 155 | start_epoch = dict_checkpoint["epoch"] 156 | global_step = dict_checkpoint["global_step"] 157 | backbone.module.load_state_dict(dict_checkpoint["state_dict_backbone"]) 158 | head.load_state_dict(dict_checkpoint["state_dict_softmax_fc"]) 159 | opt.load_state_dict(dict_checkpoint["state_optimizer"]) 160 | lr_scheduler.load_state_dict(dict_checkpoint["state_lr_scheduler"]) 161 | del dict_checkpoint 162 | 163 | 164 | for arg in vars(args): 165 | num_space = 25 - len(arg) 166 | logging.info(": " + arg + " " * num_space + str(getattr(args, arg))) 167 | 168 | callback_verification = CallBackVerification( 169 | val_targets=args.val_targets, rec_prefix=args.rec, 170 | summary_writer=summary_writer, wandb_logger = wandb_logger 171 | ) 172 | callback_logging = CallBackLogging( 173 | frequent=args.frequent, 174 | total_step=args.total_step, 175 | batch_size=args.batch_size, 176 | start_step = global_step, 177 | writer=summary_writer 178 | ) 179 | 180 | loss_am = AverageMeter() 181 | amp = torch.cuda.amp.grad_scaler.GradScaler(growth_interval=100) 182 | 183 | for epoch in range(start_epoch, args.num_epoch): 184 | 185 | if isinstance(train_loader, DataLoader): 186 | train_loader.sampler.set_epoch(epoch) 187 | for _, (img, local_labels) in enumerate(train_loader): 188 | global_step += 1 189 | 190 | local_embeddings = backbone(img) 191 | loss: torch.Tensor = head(local_embeddings, local_labels) 192 | 193 | assert loss.requires_grad 194 | 195 | if args.fp16: 196 | amp.scale(loss).backward() 197 | if global_step % args.gradient_acc == 0: 198 | amp.unscale_(opt) 199 | torch.nn.utils.clip_grad_norm_(backbone.parameters(), 5) 200 | amp.step(opt) 201 | amp.update() 202 | opt.zero_grad() 203 | else: 204 | loss.backward() 205 | if global_step % args.gradient_acc == 0: 206 | torch.nn.utils.clip_grad_norm_(backbone.parameters(), 5) 207 | opt.step() 208 | opt.zero_grad() 209 | lr_scheduler.step() 210 | 211 | with torch.no_grad(): 212 | if wandb_logger: 213 | wandb_logger.log({ 214 | 'Loss/Step Loss': loss.item(), 215 | 'Loss/Train Loss': loss_am.avg, 216 | 
'Process/Step': global_step, 217 | 'Process/Epoch': epoch 218 | }) 219 | 220 | loss_am.update(loss.item(), 1) 221 | callback_logging(global_step, loss_am, epoch, args.fp16, lr_scheduler.get_last_lr()[0], amp) 222 | 223 | if global_step % args.verbose == 0 and global_step > 0: 224 | callback_verification(global_step, backbone) 225 | 226 | if args.save_all_states: 227 | checkpoint = { 228 | "epoch": epoch + 1, 229 | "global_step": global_step, 230 | "state_dict_backbone": backbone.module.state_dict(), 231 | "state_dict_softmax_fc": head.state_dict(), 232 | "state_optimizer": opt.state_dict(), 233 | "state_lr_scheduler": lr_scheduler.state_dict() 234 | } 235 | torch.save(checkpoint, os.path.join(args.output, f"checkpoint_gpu_{rank}.pt")) 236 | 237 | if rank == 0: 238 | path_module = os.path.join(args.output, "model.pt") 239 | torch.save(backbone.module.state_dict(), path_module) 240 | 241 | if wandb_logger and args.save_artifacts: 242 | artifact_name = f"{run_name}_E{epoch}" 243 | model = wandb.Artifact(artifact_name, type='model') 244 | model.add_file(path_module) 245 | wandb_logger.log_artifact(model) 246 | 247 | if args.dali: 248 | train_loader.reset() 249 | 250 | if rank == 0: 251 | path_module = os.path.join(args.output, "model.pt") 252 | torch.save(backbone.module.state_dict(), path_module) 253 | 254 | if wandb_logger and args.save_artifacts: 255 | artifact_name = f"{run_name}_Final" 256 | model = wandb.Artifact(artifact_name, type='model') 257 | model.add_file(path_module) 258 | wandb_logger.log_artifact(model) 259 | 260 | torch.distributed.barrier() 261 | destroy_process_group() 262 | return 263 | 264 | if __name__ == "__main__": 265 | torch.backends.cudnn.benchmark = True 266 | args = config.get_args() 267 | main(args) -------------------------------------------------------------------------------- /validation_lq/validate_tinyface_iqa.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from tqdm import tqdm 4 | import argparse 5 | import pandas as pd 6 | import tinyface_helper 7 | import sys, os 8 | sys.path.insert(0, os.getcwd()) 9 | from backbones import get_model 10 | from torch.utils.data import Dataset, DataLoader 11 | from torchvision import transforms 12 | from PIL import Image 13 | import cv2 14 | from torchvision.transforms import InterpolationMode 15 | from skimage import img_as_float32 16 | import pyiqa 17 | 18 | 19 | def str2bool(v): 20 | if v.lower() in ('yes', 'true', 't', 'y', '1'): 21 | return True 22 | elif v.lower() in ('no', 'false', 'f', 'n', '0'): 23 | return False 24 | else: 25 | raise argparse.ArgumentTypeError('Boolean value expected.') 26 | 27 | def l2_norm(input, axis=1): 28 | """l2 normalize 29 | """ 30 | norm = torch.norm(input, 2, axis, True) 31 | output = torch.div(input, norm) 32 | return output, norm 33 | 34 | 35 | def fuse_features_with_norm(stacked_embeddings, stacked_norms, fusion_method='norm_weighted_avg'): 36 | assert stacked_embeddings.ndim == 3 # (n_features_to_fuse, batch_size, channel) 37 | if stacked_norms is not None: 38 | assert stacked_norms.ndim == 3 # (n_features_to_fuse, batch_size, 1) 39 | else: 40 | assert fusion_method not in ['norm_weighted_avg', 'pre_norm_vector_add'] 41 | 42 | if fusion_method == 'norm_weighted_avg': 43 | weights = stacked_norms / stacked_norms.sum(dim=0, keepdim=True) 44 | fused = (stacked_embeddings * weights).sum(dim=0) 45 | fused, _ = l2_norm(fused, axis=1) 46 | fused_norm = stacked_norms.mean(dim=0) 47 | elif fusion_method == 
'pre_norm_vector_add': 48 | pre_norm_embeddings = stacked_embeddings * stacked_norms 49 | fused = pre_norm_embeddings.sum(dim=0) 50 | fused, fused_norm = l2_norm(fused, axis=1) 51 | elif fusion_method == 'average': 52 | fused = stacked_embeddings.sum(dim=0) 53 | fused, _ = l2_norm(fused, axis=1) 54 | if stacked_norms is None: 55 | fused_norm = torch.ones((len(fused), 1)) 56 | else: 57 | fused_norm = stacked_norms.mean(dim=0) 58 | elif fusion_method == 'concat': 59 | fused = torch.cat([stacked_embeddings[0], stacked_embeddings[1]], dim=-1) 60 | if stacked_norms is None: 61 | fused_norm = torch.ones((len(fused), 1)) 62 | else: 63 | fused_norm = stacked_norms.mean(dim=0) 64 | elif fusion_method == 'faceness_score': 65 | raise ValueError('not implemented yet. please refer to https://github.com/deepinsight/insightface/blob/5d3be6da49275602101ad122601b761e36a66a01/recognition/_evaluation_/ijb/ijb_11.py#L296') 66 | # note that they do not use normalization afterward. 67 | else: 68 | raise ValueError('not a correct fusion method', fusion_method) 69 | 70 | return fused, fused_norm 71 | 72 | def generate_alpha(img, iqa, thresh): 73 | device = img.device 74 | BS, C, H, W = img.shape 75 | alpha = torch.zeros((BS, 1), dtype=torch.float32, device=device) 76 | 77 | score = iqa(img) 78 | threshold = thresh 79 | for i in range(BS): 80 | if score[i] == threshold: 81 | alpha[i] = 0.5 82 | elif score[i] < threshold: 83 | alpha[i] = 0.5 - (threshold - score[i]) 84 | else: 85 | alpha[i] = 0.5 + (score[i] - threshold) 86 | return alpha 87 | 88 | 89 | def infer(model, dataloader, iqa, threshold, use_flip_test, fusion_method, gpu): 90 | model.eval() 91 | features = [] 92 | norms = [] 93 | device = "cuda:" + str(gpu) 94 | if iqa == "brisque": 95 | iqa = pyiqa.create_metric('brisque').cuda() 96 | elif iqa == "cnniqa": 97 | iqa = pyiqa.create_metric('cnniqa').cuda() 98 | threshold = threshold 99 | with torch.no_grad(): 100 | for images, idx in tqdm(dataloader): 101 | images = images.to(device) 102 | alpha = generate_alpha(images, iqa, threshold) 103 | feature = model(images, alpha) 104 | if isinstance(feature, tuple): 105 | feature, norm = feature 106 | else: 107 | norm = torch.norm(feature, 2, 1, True) 108 | feature = torch.div(feature, norm) 109 | 110 | if use_flip_test: 111 | fliped_images = torch.flip(images, dims=[3]) 112 | fliped_images = fliped_images.to(device) 113 | alpha = generate_alpha(fliped_images, iqa, threshold) 114 | flipped_feature = model(fliped_images, alpha) 115 | if isinstance(flipped_feature, tuple): 116 | flipped_feature, flipped_norm = flipped_feature 117 | else: 118 | flipped_norm = torch.norm(flipped_feature, 2, 1, True) 119 | flipped_feature = torch.div(flipped_feature, flipped_norm) 120 | 121 | stacked_embeddings = torch.stack([feature, flipped_feature], dim=0) 122 | if norm is not None: 123 | stacked_norms = torch.stack([norm, flipped_norm], dim=0) 124 | else: 125 | stacked_norms = None 126 | 127 | fused_feature, fused_norm = fuse_features_with_norm(stacked_embeddings, stacked_norms, fusion_method=fusion_method) 128 | features.append(fused_feature.cpu().numpy()) 129 | norms.append(fused_norm.cpu().numpy()) 130 | else: 131 | features.append(feature.cpu().numpy()) 132 | norms.append(norm.cpu().numpy()) 133 | 134 | features = np.concatenate(features, axis=0) 135 | norms = np.concatenate(norms, axis=0) 136 | return features, norms 137 | 138 | def load_pretrained_model(model, model_name, gpu, lora_rank, lora_scale, use_lora): 139 | # load model and pretrained statedict 140 | ckpt_path = 
model_name 141 | model = get_model(model, dropout=0.0, fp16=False, num_features=512, r=lora_rank, scale=lora_scale, use_lora=use_lora) 142 | 143 | # model = net.build_model(arch) 144 | statedict = torch.load(ckpt_path, map_location=torch.device('cuda:' + str(gpu))) 145 | model.load_state_dict(statedict) 146 | model.eval() 147 | return model 148 | 149 | class ListDataset(Dataset): 150 | def __init__(self, img_list, image_size, image_is_saved_with_swapped_B_and_R=True): 151 | super(ListDataset, self).__init__() 152 | 153 | # image_is_saved_with_swapped_B_and_R: correctly saved image should have this set to False 154 | # face_emore/img has images saved with B and G (of RGB) swapped. 155 | # Since training data loader uses PIL (results in RGB) to read image 156 | # and validation data loader uses cv2 (results in BGR) to read image, this swap was okay. 157 | # But if you want to evaluate on the training data such as face_emore/img (B and G swapped), 158 | # then you should set image_is_saved_with_swapped_B_and_R=True 159 | 160 | self.img_list = img_list 161 | self.transform = transforms.Compose([ 162 | transforms.ToTensor(), 163 | transforms.Resize(size=(image_size,image_size), interpolation=InterpolationMode.BICUBIC), 164 | transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]) 165 | 166 | self.image_is_saved_with_swapped_B_and_R = image_is_saved_with_swapped_B_and_R 167 | 168 | 169 | def __len__(self): 170 | return len(self.img_list) 171 | 172 | def __getitem__(self, idx): 173 | image_path = self.img_list[idx] 174 | img = cv2.imread(image_path) 175 | img = img[:, :, :3] 176 | 177 | if self.image_is_saved_with_swapped_B_and_R: 178 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 179 | 180 | img = Image.fromarray(img) 181 | img = self.transform(img) 182 | return img, idx 183 | 184 | def prepare_dataloader(img_list, batch_size, image_size, num_workers=0, image_is_saved_with_swapped_B_and_R=True): 185 | # image_is_saved_with_swapped_B_and_R: correctly saved image should have this set to False 186 | # face_emore/img has images saved with B and G (of RGB) swapped. 187 | # Since training data loader uses PIL (results in RGB) to read image 188 | # and validation data loader uses cv2 (results in BGR) to read image, this swap was okay. 
189 | # But if you want to evaluate on the training data such as face_emore/img (B and G swapped), 190 | # then you should set image_is_saved_with_swapped_B_and_R=True 191 | 192 | image_dataset = ListDataset(img_list, image_size, image_is_saved_with_swapped_B_and_R=image_is_saved_with_swapped_B_and_R) 193 | dataloader = DataLoader(image_dataset, 194 | batch_size=batch_size, 195 | shuffle=False, 196 | drop_last=False, 197 | num_workers=num_workers) 198 | return dataloader 199 | 200 | 201 | def get_save_path(model_load_path): 202 | directory, _ = os.path.split(model_load_path) 203 | results_save_path = os.path.join(directory, 'results') 204 | 205 | return results_save_path 206 | 207 | if __name__ == '__main__': 208 | 209 | parser = argparse.ArgumentParser(description='tinyface') 210 | 211 | parser.add_argument('--data_root', default='/mnt/store/knaraya4/data') 212 | parser.add_argument('--gpu', default=0, type=int, help='gpu id') 213 | parser.add_argument('--batch_size', default=1024, type=int, help='') 214 | parser.add_argument('--model_load_path', type=str) 215 | parser.add_argument('--model_type', type=str, default='r50') 216 | parser.add_argument('--lora_rank', type=int, default=4) 217 | parser.add_argument('--lora_scale', type=int, default=1) 218 | parser.add_argument('--use_lora', action='store_true') 219 | parser.add_argument('--use_flip_test', type=str2bool, default='True') 220 | parser.add_argument('--image_size', type=int, default=112) 221 | parser.add_argument('--fusion_method', type=str, default='pre_norm_vector_add', choices=('average', 'norm_weighted_avg', 'pre_norm_vector_add', 'concat', 'faceness_score')) 222 | parser.add_argument('--iqa', type=str, default='brisque') 223 | parser.add_argument('--threshold', type=float, default=0.5) 224 | args = parser.parse_args() 225 | 226 | # load model 227 | model_load_path = args.model_load_path 228 | print("Model Load Path", model_load_path) 229 | model = load_pretrained_model(args.model_type, model_load_path, args.gpu, args.lora_rank, args.lora_scale, args.use_lora) 230 | model.to('cuda:{}'.format(args.gpu)) 231 | 232 | tinyface_test = tinyface_helper.TinyFaceTest(tinyface_root=args.data_root,alignment_dir_name='tinyface_aligned_112') 233 | 234 | # set save root 235 | gpu_id = args.gpu 236 | save_path = get_save_path(model_load_path) 237 | print("Save Path: ", save_path) 238 | 239 | if not os.path.exists(save_path): 240 | os.makedirs(save_path) 241 | print('save_path: {}'.format(save_path)) 242 | 243 | img_paths = tinyface_test.image_paths 244 | print('total images : {}'.format(len(img_paths))) 245 | dataloader = prepare_dataloader(img_paths, args.batch_size, args.image_size, num_workers=8, image_is_saved_with_swapped_B_and_R=True) 246 | features, norms = infer(model, dataloader, args.iqa, args.threshold, use_flip_test=args.use_flip_test, fusion_method=args.fusion_method, gpu=args.gpu) 247 | results = tinyface_test.test_identification(features, ranks=[1,5,20]) 248 | print(results) 249 | pd.DataFrame({'rank':[1,5,20], 'values':results}).to_csv(os.path.join(save_path, f'tinyface_{args.fusion_method}.csv')) -------------------------------------------------------------------------------- /train_iqa.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import logging 3 | import os 4 | from datetime import datetime, timedelta 5 | from skimage import img_as_float32 6 | from brisque import BRISQUE 7 | 8 | import numpy as np 9 | import torch 10 | import random 11 | import config 12 | import 
torchvision 13 | from backbones import get_model 14 | from heads import get_head 15 | from dataset.dataset import get_dataloader 16 | from losses import CombinedMarginLoss 17 | from lr_scheduler import PolynomialLRWarmup 18 | from torch import distributed 19 | from torch.distributed import destroy_process_group 20 | from torch.utils.data import DataLoader 21 | from torch.utils.tensorboard import SummaryWriter 22 | from utils.utils_callbacks import CallBackLogging, CallBackVerification 23 | from utils.utils_distributed_sampler import setup_seed 24 | from utils.utils_logging import AverageMeter, init_logging 25 | from torch.distributed.algorithms.ddp_comm_hooks.default_hooks import fp16_compress_hook 26 | import pyiqa 27 | 28 | assert torch.__version__ >= "1.12.0", "In order to enjoy the features of the new torch, \ 29 | we have upgraded the torch to 1.12.0. torch before than 1.12.0 may not work in the future." 30 | 31 | rank = int(os.environ["RANK"]) 32 | local_rank = int(os.environ["LOCAL_RANK"]) 33 | world_size = int(os.environ["WORLD_SIZE"]) 34 | distributed.init_process_group("nccl") 35 | 36 | 37 | def generate_alpha(img, iqa, thresh): 38 | device = img.device 39 | BS, C, H, W = img.shape 40 | alpha = torch.zeros((BS, 1), dtype=torch.float32, device=device) 41 | 42 | score = iqa(img) 43 | threshold = thresh 44 | for i in range(BS): 45 | if score[i] == threshold: 46 | alpha[i] = 0.5 47 | elif score[i] < threshold: 48 | alpha[i] = 0.5 - (threshold - score[i]) 49 | else: 50 | alpha[i] = 0.5 + (score[i] - threshold) 51 | return alpha 52 | 53 | def main(args): 54 | setup_seed(seed=args.seed, cuda_deterministic=False) 55 | 56 | torch.cuda.set_device(local_rank) 57 | 58 | os.makedirs(args.output, exist_ok=True) 59 | init_logging(rank, args.output) 60 | 61 | summary_writer = ( 62 | SummaryWriter(log_dir=os.path.join(args.output, "tensorboard")) 63 | if rank == 0 64 | else None 65 | ) 66 | 67 | wandb_logger = None 68 | if args.using_wandb: 69 | import wandb 70 | # Sign in to wandb 71 | try: 72 | wandb.login(key=args.wandb_key) 73 | except Exception as e: 74 | print("WandB Key must be provided in config file (base.py).") 75 | print(f"Config Error: {e}") 76 | # Initialize wandb 77 | run_name = datetime.now().strftime("%y%m%d_%H%M") + f"_GPU{rank}" 78 | run_name = run_name if args.suffix_run_name is None else run_name + f"_{args.suffix_run_name}" 79 | try: 80 | wandb_logger = wandb.init( 81 | entity = args.wandb_entity, 82 | project = args.wandb_project, 83 | sync_tensorboard = True, 84 | resume=args.wandb_resume, 85 | name = run_name, 86 | notes = args.notes) if rank == 0 or args.wandb_log_all else None 87 | if wandb_logger: 88 | wandb_logger.config.update(args) 89 | except Exception as e: 90 | print("WandB Data (Entity and Project name) must be provided in config file (base.py).") 91 | print(f"Config Error: {e}") 92 | train_loader = get_dataloader( 93 | args.rec, 94 | local_rank, 95 | args.batch_size, 96 | args.image_size, 97 | args.dali, 98 | args.dali_aug, 99 | args.seed, 100 | args.num_workers 101 | ) 102 | 103 | backbone = get_model(args.network, dropout=0.0, fp16=args.fp16, num_features=args.embedding_size, r=args.lora_rank, scale=args.lora_scale, use_lora=args.use_lora).cuda() 104 | backbone = torch.nn.parallel.DistributedDataParallel( 105 | module=backbone, broadcast_buffers=False, device_ids=[local_rank], bucket_cap_mb=16, 106 | find_unused_parameters=True) 107 | backbone.register_comm_hook(None, fp16_compress_hook) 108 | 109 | backbone.train() 110 | backbone._set_static_graph() 111 | 112 
| margin_loss = CombinedMarginLoss( 113 | 64, 114 | args.margin_list[0], 115 | args.margin_list[1], 116 | args.margin_list[2], 117 | args.interclass_filtering_threshold 118 | ) 119 | head = get_head(args.head, 120 | margin_loss=margin_loss, embedding_size=args.embedding_size, num_classes=args.num_classes, 121 | sample_rate=args.sample_rate, fp16=False) 122 | 123 | if args.use_lora: 124 | weights_path = os.path.join(args.load_pretrained, f"checkpoint_gpu_{rank}.pt") 125 | if os.path.isfile(weights_path): 126 | dict_checkpoint = torch.load(weights_path) 127 | backbone.module.load_state_dict(dict_checkpoint["state_dict_backbone"], strict=False) 128 | else: 129 | dict_checkpoint = torch.load(os.path.join(args.load_pretrained, f"model.pt")) 130 | backbone.module.load_state_dict(dict_checkpoint, strict=False) 131 | for p in head.parameters(): 132 | p.requires_grad = True 133 | for p in backbone.parameters(): 134 | p.requires_grad = False 135 | for name, p in backbone.named_parameters(): 136 | if 'trainable_lora' in name: 137 | p.requires_grad = True 138 | 139 | 140 | 141 | if args.optimizer == "sgd": 142 | total_params = sum(p.numel() for p in backbone.parameters()) 143 | trainable_params = sum(p.numel() for p in backbone.parameters() if p.requires_grad) + sum(p.numel() for p in head.parameters() if p.requires_grad) 144 | logging.info("Total Parameters: %d", total_params) 145 | logging.info('Number of trainable parameters: %d', trainable_params) 146 | head.train().cuda() 147 | opt = torch.optim.SGD( 148 | params=[{"params": filter(lambda p: p.requires_grad, backbone.parameters()) }, {"params": head.parameters()}], 149 | lr=args.lr, momentum=0.9, weight_decay=args.weight_decay) 150 | elif args.optimizer == "adamw": 151 | total_params = sum(p.numel() for p in backbone.parameters()) 152 | trainable_params = sum(p.numel() for p in backbone.parameters() if p.requires_grad) + sum(p.numel() for p in head.parameters() if p.requires_grad) 153 | logging.info("Total Parameters: %d", total_params) 154 | logging.info('Number of trainable parameters: %d', trainable_params) 155 | head.train().cuda() 156 | opt = torch.optim.AdamW( 157 | params=[{"params": filter(lambda p: p.requires_grad, backbone.parameters()) }, {"params": head.parameters()}], 158 | lr=args.lr, weight_decay=args.weight_decay) 159 | else: 160 | raise 161 | 162 | 163 | args.total_batch_size = args.batch_size * world_size 164 | args.warmup_step = args.num_image // args.total_batch_size * args.warmup_epoch 165 | args.total_step = args.num_image // args.total_batch_size * args.num_epoch 166 | 167 | lr_scheduler = PolynomialLRWarmup( 168 | optimizer=opt, 169 | warmup_iters=args.warmup_step, 170 | total_iters=args.total_step) 171 | 172 | start_epoch = 0 173 | global_step = 0 174 | if args.resume: 175 | dict_checkpoint = torch.load(os.path.join(args.output, f"checkpoint_gpu_{rank}.pt")) 176 | start_epoch = dict_checkpoint["epoch"] 177 | global_step = dict_checkpoint["global_step"] 178 | backbone.module.load_state_dict(dict_checkpoint["state_dict_backbone"]) 179 | head.load_state_dict(dict_checkpoint["state_dict_softmax_fc"]) 180 | opt.load_state_dict(dict_checkpoint["state_optimizer"]) 181 | lr_scheduler.load_state_dict(dict_checkpoint["state_lr_scheduler"]) 182 | del dict_checkpoint 183 | 184 | 185 | for arg in vars(args): 186 | num_space = 25 - len(arg) 187 | logging.info(": " + arg + " " * num_space + str(getattr(args, arg))) 188 | 189 | callback_verification = CallBackVerification( 190 | val_targets=args.val_targets, rec_prefix=args.rec, 191 | 
summary_writer=summary_writer, wandb_logger = wandb_logger 192 | ) 193 | callback_logging = CallBackLogging( 194 | frequent=args.frequent, 195 | total_step=args.total_step, 196 | batch_size=args.batch_size, 197 | start_step = global_step, 198 | writer=summary_writer 199 | ) 200 | 201 | loss_am = AverageMeter() 202 | amp = torch.cuda.amp.grad_scaler.GradScaler(growth_interval=100) 203 | 204 | if args.iqa == "brisque": 205 | iqa = pyiqa.create_metric('brisque').cuda() 206 | threshold = args.threshold 207 | elif args.iqa == "cnniqa": 208 | iqa = pyiqa.create_metric('cnniqa').cuda() 209 | threshold = args.threshold 210 | 211 | logging.info("Total Parameters: %d", sum(p.numel() for p in iqa.parameters())) 212 | logging.info("IQA: %d", iqa.lower_better) 213 | 214 | for epoch in range(start_epoch, args.num_epoch): 215 | 216 | if isinstance(train_loader, DataLoader): 217 | train_loader.sampler.set_epoch(epoch) 218 | for _, (img, local_labels) in enumerate(train_loader): 219 | global_step += 1 220 | 221 | alpha = generate_alpha(img.clone(), iqa, threshold) 222 | local_embeddings = backbone(img, alpha) 223 | loss: torch.Tensor = head(local_embeddings, local_labels) 224 | 225 | assert loss.requires_grad 226 | 227 | if args.fp16: 228 | amp.scale(loss).backward() 229 | if global_step % args.gradient_acc == 0: 230 | amp.unscale_(opt) 231 | torch.nn.utils.clip_grad_norm_(backbone.parameters(), 5) 232 | amp.step(opt) 233 | amp.update() 234 | opt.zero_grad() 235 | else: 236 | loss.backward() 237 | if global_step % args.gradient_acc == 0: 238 | torch.nn.utils.clip_grad_norm_(backbone.parameters(), 5) 239 | opt.step() 240 | opt.zero_grad() 241 | lr_scheduler.step() 242 | 243 | with torch.no_grad(): 244 | if wandb_logger: 245 | wandb_logger.log({ 246 | 'Loss/Step Loss': loss.item(), 247 | 'Loss/Train Loss': loss_am.avg, 248 | 'Process/Step': global_step, 249 | 'Process/Epoch': epoch 250 | }) 251 | 252 | loss_am.update(loss.item(), 1) 253 | callback_logging(global_step, loss_am, epoch, args.fp16, lr_scheduler.get_last_lr()[0], amp) 254 | 255 | if global_step % args.verbose == 0 and global_step > 0: 256 | callback_verification(global_step, backbone) 257 | 258 | if args.save_all_states: 259 | checkpoint = { 260 | "epoch": epoch + 1, 261 | "global_step": global_step, 262 | "state_dict_backbone": backbone.module.state_dict(), 263 | "state_dict_softmax_fc": head.state_dict(), 264 | "state_optimizer": opt.state_dict(), 265 | "state_lr_scheduler": lr_scheduler.state_dict() 266 | } 267 | torch.save(checkpoint, os.path.join(args.output, f"checkpoint_gpu_{rank}.pt")) 268 | 269 | if rank == 0: 270 | path_module = os.path.join(args.output, "model.pt") 271 | torch.save(backbone.module.state_dict(), path_module) 272 | 273 | if wandb_logger and args.save_artifacts: 274 | artifact_name = f"{run_name}_E{epoch}" 275 | model = wandb.Artifact(artifact_name, type='model') 276 | model.add_file(path_module) 277 | wandb_logger.log_artifact(model) 278 | 279 | if args.dali: 280 | train_loader.reset() 281 | 282 | if rank == 0: 283 | path_module = os.path.join(args.output, "model.pt") 284 | torch.save(backbone.module.state_dict(), path_module) 285 | 286 | if wandb_logger and args.save_artifacts: 287 | artifact_name = f"{run_name}_Final" 288 | model = wandb.Artifact(artifact_name, type='model') 289 | model.add_file(path_module) 290 | wandb_logger.log_artifact(model) 291 | 292 | torch.distributed.barrier() 293 | destroy_process_group() 294 | return 295 | 296 | 297 | if __name__ == "__main__": 298 | torch.backends.cudnn.benchmark = True 299 
| args = config.get_args() 300 | main(args) --------------------------------------------------------------------------------
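A minimal sketch of the quality-gating step shared by train_iqa.py and validate_tinyface_iqa.py: every branch of generate_alpha reduces to 0.5 + (score - threshold), so the per-sample loop is equivalent to the vectorized mapping below. The score values and threshold used here are illustrative placeholders, not values taken from the repository.

import torch

def generate_alpha_vectorized(scores: torch.Tensor, thresh: float) -> torch.Tensor:
    # score == thresh -> 0.5; lower score -> 0.5 - (thresh - score);
    # higher score -> 0.5 + (score - thresh). All three cases collapse to:
    return (0.5 + (scores - thresh)).unsqueeze(1)  # shape (batch, 1), matching alpha in generate_alpha

scores = torch.tensor([0.3, 0.5, 0.7])         # placeholder IQA scores
print(generate_alpha_vectorized(scores, 0.5))  # tensor([[0.3000], [0.5000], [0.7000]])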