├── SFace_jittor
│   ├── backbone
│   │   ├── __init__.py
│   │   ├── model_mobilefacenet.py
│   │   └── model_irse.py
│   ├── head
│   │   ├── __init__.py
│   │   └── metrics.py
│   ├── util
│   │   ├── __init__.py
│   │   ├── utils.py
│   │   └── verification.py
│   ├── config.py
│   ├── image_iter.py
│   ├── README.md
│   └── train_SFace_jittor.py
└── README.md

--------------------------------------------------------------------------------
/SFace_jittor/backbone/__init__.py:
--------------------------------------------------------------------------------


--------------------------------------------------------------------------------
/SFace_jittor/head/__init__.py:
--------------------------------------------------------------------------------


--------------------------------------------------------------------------------
/SFace_jittor/util/__init__.py:
--------------------------------------------------------------------------------


--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# JittorFace
Code for some face models implemented with `Jittor`.

## SFace
Please check [SFace_jittor](https://github.com/liubingyuu/jittorface/tree/main/SFace_jittor).

--------------------------------------------------------------------------------
/SFace_jittor/config.py:
--------------------------------------------------------------------------------
import os
import yaml


def get_yaml_data(yaml_file):
    # SafeLoader avoids executing arbitrary YAML tags.
    with open(yaml_file, 'r', encoding="utf-8") as f:
        data = yaml.load(f.read(), Loader=yaml.SafeLoader)
    return data


def get_config(args):
    configuration = dict(
        SEED=1337,  # random seed for reproducible results
        INPUT_SIZE=[112, 112],
        EMBEDDING_SIZE=512,  # embedding size
        DROP_LAST=True,
        WEIGHT_DECAY=5e-4,
        MOMENTUM=0.9,
    )

    if args.workers_id == 'cpu':
        configuration['GPU_ID'] = []
        print("check", args.workers_id)
    else:
        configuration['GPU_ID'] = [int(i) for i in args.workers_id.split(',')]
    configuration['MULTI_GPU'] = len(configuration['GPU_ID']) > 1

    configuration['NUM_EPOCH'] = args.epochs
    configuration['STAGES'] = [int(i) for i in args.stages.split(',')]
    configuration['LR'] = args.lr
    configuration['BATCH_SIZE'] = args.batch_size

    if args.data_mode == 'casia':
        configuration['DATA_ROOT'] = './data/faces_webface_112x112/'  # the dir for training
    else:
        raise ValueError('unsupported data_mode: %s' % args.data_mode)

    configuration['EVAL_PATH'] = './data/faces_webface_112x112/'  # the dir for validation

    assert args.net in ['IR_50', 'IR_101', 'MobileFaceNet']
    configuration['BACKBONE_NAME'] = args.net
    assert args.head in ['SFaceLoss']
    configuration['HEAD_NAME'] = args.head

    configuration['TARGET'] = args.target.split(',')

    if args.resume_backbone:
        configuration['BACKBONE_RESUME_ROOT'] = args.resume_backbone  # the checkpoint to resume the backbone from
        configuration['HEAD_RESUME_ROOT'] = args.resume_head  # the checkpoint to resume the head from
    else:
        configuration['BACKBONE_RESUME_ROOT'] = ''
        configuration['HEAD_RESUME_ROOT'] = ''
    configuration['WORK_PATH'] = args.outdir  # the dir to save your checkpoints

    return configuration
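
if __name__ == '__main__':
    # Minimal usage sketch (not part of the training pipeline): get_config only
    # needs an argparse-style namespace carrying the attributes read above.
    # The values below are hypothetical defaults mirroring train_SFace_jittor.py.
    from argparse import Namespace
    demo_args = Namespace(workers_id='0', epochs=125, stages='35,65,95', lr=0.1,
                          batch_size=256, data_mode='casia', net='IR_50',
                          head='SFaceLoss', target='lfw,cfp_fp,agedb_30',
                          resume_backbone='', resume_head='', outdir='./results/demo')
    print(get_config(demo_args))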

--------------------------------------------------------------------------------
/SFace_jittor/head/metrics.py:
--------------------------------------------------------------------------------
from __future__ import print_function
from __future__ import division
import math
import jittor as jt


def xavier_gauss(shape, dtype, gain=1.0, mode='avg'):
    shape = tuple(shape)
    assert len(shape) > 1

    matsize = 1
    for i in shape[2:]:
        matsize *= i
    if mode == 'avg':
        fan = (shape[1] * matsize) + (shape[0] * matsize)
    elif mode == 'in':
        fan = shape[1] * matsize
    elif mode == 'out':
        fan = shape[0] * matsize
    else:
        raise ValueError('wrong mode')
    std = gain * math.sqrt(2.0 / fan)
    return jt.init.gauss(shape, dtype, 0, std)


class SFaceLoss(jt.Module):

    def __init__(self, in_features, out_features, device_id, s=64.0, k=80.0, a=0.80, b=1.23):
        super(SFaceLoss, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.device_id = device_id
        self.s = s
        self.k = k
        self.a = a
        self.b = b
        self.weight = xavier_gauss((out_features, in_features), "float32", gain=2, mode='out')

    def execute(self, input, label):
        # --------------------------- cos(theta) & phi(theta) ---------------------------
        cosine = jt.nn.matmul_transpose(jt.normalize(input), jt.normalize(self.weight))
        # --------------------------- s*cos(theta) ---------------------------
        output = cosine * self.s
        # --------------------------- sface loss ---------------------------

        one_hot = jt.zeros(cosine.size())
        one_hot.scatter_(1, label.view(-1, 1), jt.ones((input.shape[0], 1)))

        zero_hot = jt.ones(cosine.size())
        zero_hot.scatter_(1, label.view(-1, 1), jt.zeros((input.shape[0], 1)))

        WyiX = jt.sum(one_hot * output, 1)
        with jt.no_grad():
            # intra-class re-scale: a sigmoid of the angle to the target class.
            # It is near 0 for well-optimized (small-angle) samples, so their
            # gradients are suppressed; no_grad keeps the weight itself constant.
            theta_yi = jt.acos(WyiX / self.s)
            weight_yi = 1.0 / (1.0 + jt.exp(-self.k * (theta_yi - self.a)))
        intra_loss = - weight_yi * WyiX

        Wj = zero_hot * output
        with jt.no_grad():
            # inter-class re-scale: decays to 0 once a non-target class is
            # already far enough away (angle beyond b), so it stops being pushed.
            theta_j = jt.acos(Wj / self.s)
            weight_j = 1.0 / (1.0 + jt.exp(self.k * (theta_j - self.b)))
        inter_loss = jt.sum(weight_j * Wj, 1)

        loss = intra_loss.mean() + inter_loss.mean()
        Wyi_s = WyiX / self.s
        Wj_s = Wj / self.s
        return output, loss, intra_loss.mean(), inter_loss.mean(), Wyi_s.mean(), Wj_s.mean()
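
if __name__ == '__main__':
    # Minimal CPU smoke test with toy, hypothetical shapes (8 random 512-d
    # embeddings, 10 classes): a sanity sketch, not part of training.
    jt.flags.use_cuda = 0
    head = SFaceLoss(in_features=512, out_features=10, device_id=None)
    feats = jt.random((8, 512))
    labels = jt.randint(0, 10, (8,))
    output, loss, intra, inter, wyi, wj = head(feats, labels)
    print(loss.item(), intra.item(), inter.item())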

--------------------------------------------------------------------------------
/SFace_jittor/image_iter.py:
--------------------------------------------------------------------------------
import numpy as np
import os
from jittor.dataset import Dataset
from jittor import transform
from PIL import Image
import mxnet as mx
import logging
logger = logging.getLogger()


class FaceDataset(Dataset):
    def __init__(self, path_img, rand_mirror,
                 batch_size=16, shuffle=False, drop_last=False, num_workers=0):
        super(FaceDataset, self).__init__()
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.drop_last = drop_last
        self.num_workers = num_workers
        self.rand_mirror = rand_mirror
        assert path_img
        # imgs.txt lists one "<relative_path> <int_label>" pair per line, e.g.
        #   00001/001.jpg 0
        imgtxt = os.path.join(path_img, 'imgs.txt')
        imgs = []
        with open(imgtxt, 'r') as f:
            for line in f:
                fn, label = line.strip().split(' ')
                fn = os.path.join(path_img, fn)
                label = int(label)
                imgs.append((fn, label))
        self.imgs = imgs
        if rand_mirror:
            self.trans = transform.RandomHorizontalFlip()
        else:
            self.trans = None

    def __getitem__(self, index):
        fn, label = self.imgs[index]
        img = Image.open(fn).convert('RGB')
        if self.trans:
            img = self.trans(img)
        img = np.array(img, np.float32, copy=False)
        img = img.transpose((2, 0, 1))

        return img, label

    def __len__(self):
        return len(self.imgs)


class ValDataset(Dataset):
    def __init__(self, bins, issame_list, batch_size, image_size=[112, 112], shuffle=False, drop_last=False):
        super(ValDataset, self).__init__()
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.drop_last = drop_last

        # two copies of every image: data_list[0] holds the originals,
        # data_list[1] the horizontally flipped versions
        self.data_list = []
        for flip in [0, 1]:
            data = np.zeros((len(issame_list) * 2, 3, image_size[0], image_size[1]))
            self.data_list.append(data)
        for i in range(len(issame_list) * 2):
            _bin = bins[i]
            img = mx.image.imdecode(_bin)
            if img.shape[1] != image_size[0]:
                img = mx.image.resize_short(img, image_size[0])
            img = mx.nd.transpose(img, axes=(2, 0, 1))
            for flip in [0, 1]:
                if flip == 1:
                    img = mx.ndarray.flip(data=img, axis=2)
                self.data_list[flip][i] = img.asnumpy()
            if i % 1000 == 0:
                print('loading bin', i)
        print(self.data_list[0].shape)

    def __getitem__(self, index):
        return self.data_list[0][index], self.data_list[1][index]

    def __len__(self):
        return self.data_list[0].shape[0]

--------------------------------------------------------------------------------
/SFace_jittor/README.md:
--------------------------------------------------------------------------------
# SFace-Jittor
This is an implementation of SFace based on `Jittor`.

Paper: [《SFace: Sigmoid-Constrained Hypersphere Loss for Robust Face Recognition》](https://ieeexplore.ieee.org/document/9318547)

## Abstract
Deep face recognition has achieved great success due to large-scale training databases and rapidly developing loss functions. Existing algorithms are devoted to realizing an ideal objective: minimizing the intra-class distance and maximizing the inter-class distance. However, they may neglect that there are also low-quality training images which should not be optimized in this strict way. Considering the imperfection of training databases, we propose that intra-class and inter-class objectives can be optimized in a moderate way to mitigate the overfitting problem, and further propose a novel loss function, named sigmoid-constrained hypersphere loss (SFace). Specifically, SFace imposes intra-class and inter-class constraints on a hypersphere manifold, which are controlled by two sigmoid gradient re-scale functions respectively. The sigmoid curves precisely re-scale the intra-class and inter-class gradients so that training samples can be optimized to some degree. Therefore, SFace can make a better balance between decreasing the intra-class distances of clean examples and preventing overfitting to label noise, and contributes to more robust deep face recognition models. Extensive experiments of models trained on CASIA-WebFace, VGGFace2, and MS-Celeb-1M databases, and evaluated on several face recognition benchmarks, such as LFW, MegaFace and IJB-C databases, have demonstrated the superiority of SFace.
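
## Loss
The two gradient re-scale functions above are implemented in `head/metrics.py`. A minimal NumPy sketch (`k`, `a`, `b` as in `SFaceLoss`; the angles are toy values chosen only for illustration):

```
import numpy as np

k, a, b = 80.0, 0.80, 1.23
theta_yi = np.array([0.6, 1.0])  # angles to the ground-truth class weights
theta_j = np.array([1.0, 1.3])   # angles to the other class weights
r_intra = 1.0 / (1.0 + np.exp(-k * (theta_yi - a)))  # ~0 for easy positives, ~1 for hard ones
r_inter = 1.0 / (1.0 + np.exp(k * (theta_j - b)))    # ~1 for close negatives, ~0 for distant ones
print(r_intra, r_inter)  # these factors re-scale the intra-/inter-class gradients
```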
## Usage Instructions
1. Install Jittor with GPU support (Python 3.7).

2. Download the code.

3. The training datasets (CASIA-WebFace, VGGFace2 and MS1MV2) and the evaluation datasets can be downloaded from the Data Zoo of [InsightFace](https://github.com/deepinsight/insightface). Then convert the training datasets from the MXNet binary format into jpg format. An example data structure for CASIA-WebFace is shown in `data/faces_webface_112x112/imgs.txt`.

## Train
Run the code to train a model.

(1) Train IR_50 on CASIA-WebFace with SFace.

- *With a single GPU*
```
CUDA_VISIBLE_DEVICES="0" python3 -u train_SFace_jittor.py --workers_id 0 --batch_size 256 --lr 0.1 --stages 50,70,80 --data_mode casia --net IR_50 --outdir ./results/IR_50-sface-casia --param_a 0.87 --param_b 1.2
```
- *With multiple GPUs*
```
CUDA_VISIBLE_DEVICES="0,1" mpirun -np 2 python3 -u train_SFace_jittor.py --workers_id 0,1 --batch_size 256 --lr 0.1 --stages 50,70,80 --data_mode casia --net IR_50 --outdir ./results/IR_50-sface-casia --param_a 0.87 --param_b 1.2
```

(2) Train MobileFaceNet on CASIA-WebFace with SFace.

- *With a single GPU*
```
CUDA_VISIBLE_DEVICES="0" python3 -u train_SFace_jittor.py --workers_id 0 --batch_size 256 --lr 0.1 --stages 50,70,80 --data_mode casia --net MobileFaceNet --outdir ./results/Mobile-sface-casia --param_a 0.87 --param_b 1.2
```
- *With multiple GPUs*
```
CUDA_VISIBLE_DEVICES="0,1" mpirun -np 2 python3 -u train_SFace_jittor.py --workers_id 0,1 --batch_size 256 --lr 0.1 --stages 50,70,80 --data_mode casia --net MobileFaceNet --outdir ./results/Mobile-sface-casia --param_a 0.87 --param_b 1.2
```
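
## Extract Features
A minimal feature-extraction sketch after training (the checkpoint name below is hypothetical; see the save pattern in `train_SFace_jittor.py`). The backbone normalizes raw pixel values internally, so aligned 112x112 RGB images can be fed directly:

```
import numpy as np
from PIL import Image
import jittor as jt
from backbone.model_irse import IR_50

jt.flags.use_cuda = 0
backbone = IR_50([112, 112])
backbone.load_state_dict(jt.load('results/IR_50-sface-casia/Backbone_IR_50_checkpoint.pkl'))
backbone.eval()

img = np.array(Image.open('aligned_face_112x112.jpg').convert('RGB'), np.float32)
x = jt.array(img.transpose(2, 0, 1)[None, ...])  # 1 x 3 x 112 x 112
with jt.no_grad():
    emb = backbone(x).numpy()
emb = emb / np.linalg.norm(emb)  # L2-normalize for cosine comparison
print(emb.shape)
```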

--------------------------------------------------------------------------------
/SFace_jittor/backbone/model_mobilefacenet.py:
--------------------------------------------------------------------------------
import jittor as jt
from jittor import nn, Module
from jittor.nn import Sequential
from collections import OrderedDict
import math

class Flatten(Module):
    def execute(self, input):
        return input.view(input.size(0), -1)

class Conv_block(Module):  # verified: the same as ``Conv'' in ./fmobilefacenet
    def __init__(self, in_c, out_c, kernel=(1, 1), stride=(1, 1), padding=(0, 0), groups=1):
        super(Conv_block, self).__init__()
        self.conv2d = nn.Conv(in_c, out_c, kernel, groups=groups, stride=stride, padding=padding, bias=False)  # verified: the same as mx.sym.Convolution(data=data, num_filter=num_filter, kernel=kernel, num_group=num_group, stride=stride, pad=pad, no_bias=True)
        self.batchnorm = nn.BatchNorm(out_c, eps=0.001)  # verified: the same as mx.sym.BatchNorm(data=conv, fix_gamma=False, momentum=0.9)
        self.relu = nn.PReLU(num_parameters=out_c)
    def execute(self, x):
        x = self.conv2d(x)
        x = self.batchnorm(x)
        x = self.relu(x)
        return x

class Linear_block(Module):  # verified: the same as ``Linear'' in ./fmobilefacenet
    def __init__(self, in_c, out_c, kernel=(1, 1), stride=(1, 1), padding=(0, 0), groups=1):
        super(Linear_block, self).__init__()
        self.conv2d = nn.Conv(in_c, out_c, kernel, groups=groups, stride=stride, padding=padding, bias=False)  # verified: the same as mx.sym.Convolution(data=data, num_filter=num_filter, kernel=kernel, num_group=num_group, stride=stride, pad=pad, no_bias=True)
        self.batchnorm = nn.BatchNorm(out_c, eps=0.001)
    def execute(self, x):
        x = self.conv2d(x)
        x = self.batchnorm(x)
        return x

class Depth_Wise(Module):  # verified: if residual is False: the same as ``DResidual'' in ./fmobilefacenet; else: the same as ``Residual''
    def __init__(self, in_c, out_c, residual=False, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=1):
        super(Depth_Wise, self).__init__()
        self.conv_sep = Conv_block(in_c, out_c=groups, kernel=(1, 1), padding=(0, 0), stride=(1, 1))
        self.conv_dw = Conv_block(groups, groups, groups=groups, kernel=kernel, padding=padding, stride=stride)
        self.conv_proj = Linear_block(groups, out_c, kernel=(1, 1), padding=(0, 0), stride=(1, 1))
        self.residual = residual
    def execute(self, x):
        if self.residual:
            short_cut = x
        x = self.conv_sep(x)
        x = self.conv_dw(x)
        x = self.conv_proj(x)
        if self.residual:
            output = short_cut + x
        else:
            output = x
        return output

class Residual(Module):  # verified: the same as ``Residual'' in ./fmobilefacenet
    def __init__(self, c, num_block, groups, kernel=(3, 3), stride=(1, 1), padding=(1, 1)):
        super(Residual, self).__init__()
        modules = OrderedDict()
        for i in range(num_block):
            modules['block%d' % i] = Depth_Wise(c, c, residual=True, kernel=kernel, padding=padding, stride=stride, groups=groups)
        self.model = Sequential(modules)
    def execute(self, x):
        return self.model(x)

class MobileFaceNet(Module):
    def __init__(self, embedding_size):
        super(MobileFaceNet, self).__init__()
        self.conv_1 = Conv_block(3, 64, kernel=(3, 3), stride=(2, 2), padding=(1, 1))
        self.conv_2_dw = Conv_block(64, 64, kernel=(3, 3), stride=(1, 1), padding=(1, 1), groups=64)
        self.dconv_23 = Depth_Wise(64, 64, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=128)
        self.res_3 = Residual(64, num_block=4, groups=128, kernel=(3, 3), stride=(1, 1), padding=(1, 1))
        self.dconv_34 = Depth_Wise(64, 128, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=256)
        self.res_4 = Residual(128, num_block=6, groups=256, kernel=(3, 3), stride=(1, 1), padding=(1, 1))
        self.dconv_45 = Depth_Wise(128, 128, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=512)
        self.res_5 = Residual(128, num_block=2, groups=256, kernel=(3, 3), stride=(1, 1), padding=(1, 1))
        self.conv_6sep = Conv_block(128, 512, kernel=(1, 1), stride=(1, 1), padding=(0, 0))
        self.conv_6dw7_7 = Linear_block(512, 512, groups=512, kernel=(7, 7), stride=(1, 1), padding=(0, 0))
        self.conv_6_flatten = Flatten()
        self.pre_fc1 = nn.Linear(512, embedding_size)
        self.fc1 = nn.BatchNorm(embedding_size, eps=2e-5)  # doubt: the same as mx.sym.BatchNorm(data=conv_6_f, fix_gamma=True, eps=2e-5, momentum=0.9)?

    def execute(self, x):
        x = x - 127.5
        x = x * 0.078125
        out = self.conv_1(x)
        out = self.conv_2_dw(out)
        out = self.dconv_23(out)
        out = self.res_3(out)
        out = self.dconv_34(out)
        out = self.res_4(out)
        out = self.dconv_45(out)
        out = self.res_5(out)
        out = self.conv_6sep(out)
        out = self.conv_6dw7_7(out)
        out = self.conv_6_flatten(out)
        out = self.pre_fc1(out)
        out = self.fc1(out)
        return out

--------------------------------------------------------------------------------
/SFace_jittor/util/utils.py:
--------------------------------------------------------------------------------
import jittor as jt
from jittor import transform

from .verification import evaluate
from image_iter import ValDataset

from datetime import datetime
import matplotlib.pyplot as plt
plt.switch_backend('agg')
import numpy as np
from PIL import Image
import io
import os, pickle
import sklearn.preprocessing  # note: 'import sklearn' alone does not load the preprocessing submodule
import time


def get_time():
    return (str(datetime.now())[:-10]).replace(' ', '-').replace(':', '-')


def load_bin(path, batch_size, image_size=[112, 112]):
    bins, issame_list = pickle.load(open(path, 'rb'), encoding='bytes')
    data_loader = ValDataset(bins, issame_list, batch_size, image_size=image_size)
    return data_loader, issame_list


def get_val_pair(path, name, batch_size):
    ver_path = os.path.join(path, name + ".bin")
    assert os.path.exists(ver_path)
    data_set, issame = load_bin(ver_path, batch_size)
    print('ver', name)
    return data_set, issame


def get_val_data(data_path, targets, batch_size):
    assert len(targets) > 0
    vers = []
    for t in targets:
        data_set, issame = get_val_pair(data_path, t, batch_size)
        vers.append([t, data_set, issame])
    return vers


def separate_irse_bn_paras(modules):
    if not isinstance(modules, list):
        modules = [*modules.modules()]
    paras_only_bn = []
    paras_wo_bn = []
    for layer in modules:
        if 'model' in str(layer.__class__):
            continue
        if 'Sequential' in str(layer.__class__):
            continue
        if 'BatchNorm' in str(layer.__class__):
            paras_only_bn.extend([*layer.parameters()])
        else:
            paras_wo_bn.extend([*layer.parameters()])

    return paras_only_bn, paras_wo_bn


def separate_mobilefacenet_bn_paras(modules):
    if not isinstance(modules, list):
        modules = [*modules.modules()]
    paras_only_bn = []
    paras_wo_bn = []
    for layer in modules:
        if 'mobilefacenet' in str(layer.__class__) or 'Sequential' in str(layer.__class__):
            continue
        if 'BatchNorm' in str(layer.__class__):
            paras_only_bn.extend([*layer.parameters()])
        else:
            paras_wo_bn.extend([*layer.parameters()])

    return paras_only_bn, paras_wo_bn


def gen_plot(fpr, tpr):
    """Create a pyplot plot and save to buffer."""
    plt.figure()
    plt.xlabel("FPR", fontsize=14)
    plt.ylabel("TPR", fontsize=14)
    plt.title("ROC Curve", fontsize=14)
    plot = plt.plot(fpr, tpr, linewidth=2)
    buf = io.BytesIO()
    plt.savefig(buf, format='jpeg')
    buf.seek(0)
    plt.close()

    return buf


def perform_val(embedding_size, batch_size, backbone, data_set, issame, nrof_folds=10):
    backbone.eval()  # switch to evaluation mode

    embeddings_jt = jt.zeros([len(data_set), embedding_size])
    embeddings_jt_flip = jt.zeros([len(data_set), embedding_size])
    with jt.no_grad():
        for i, (batch, batch_flip) in enumerate(data_set):
            output = backbone(batch)
            output_flip = backbone(batch_flip)
            bs_single = batch.shape[0]
            # each MPI rank fills its own slice; the all_reduce below merges them
            embeddings_jt[i * batch_size + jt.rank * bs_single:i * batch_size + (jt.rank + 1) * bs_single] = output.detach()
            embeddings_jt_flip[i * batch_size + jt.rank * bs_single:i * batch_size + (jt.rank + 1) * bs_single] = output_flip.detach()
        embeddings_jt.sync()
        embeddings_jt_flip.sync()

    if jt.in_mpi:
        embeddings_jt = embeddings_jt.mpi_all_reduce('add')
        embeddings_jt_flip = embeddings_jt_flip.mpi_all_reduce('add')
    embeddings_list = [embeddings_jt.data, embeddings_jt_flip.data]

    _xnorm = 0.0
    _xnorm_cnt = 0
    for embed in embeddings_list:
        for i in range(embed.shape[0]):
            _em = embed[i]
            _norm = np.linalg.norm(_em)
            _xnorm += _norm
            _xnorm_cnt += 1
    _xnorm /= _xnorm_cnt

    embeddings = embeddings_list[0] + embeddings_list[1]
    embeddings = sklearn.preprocessing.normalize(embeddings)
    if jt.rank == 0:
        print(embeddings.shape)

    tpr, fpr, accuracy, best_thresholds = evaluate(embeddings, issame, nrof_folds)
    buf = gen_plot(fpr, tpr)
    roc_curve = Image.open(buf)
    roc_curve_tensor = transform.ToTensor()(roc_curve)

    return accuracy.mean(), accuracy.std(), _xnorm, best_thresholds.mean(), roc_curve_tensor


def buffer_val(writer, db_name, acc, std, xnorm, best_threshold, roc_curve_tensor, batch):
    writer.add_scalar('Accuracy/{}_Accuracy'.format(db_name), acc, batch)
    writer.add_scalar('Std/{}_Std'.format(db_name), std, batch)
    writer.add_scalar('XNorm/{}_XNorm'.format(db_name), xnorm, batch)
    writer.add_scalar('Threshold/{}_Best_Threshold'.format(db_name), best_threshold, batch)
    writer.add_image('ROC/{}_ROC_Curve'.format(db_name), roc_curve_tensor, batch)


class AverageMeter(object):
    """Computes and stores the average and current value"""
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


def train_accuracy(output, target, topk=(1,)):
    """Computes the precision@k for the specified values of k"""
    maxk = max(topk)
    batch_size = target.size(0)

    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.equal(target.view(1, -1).expand_as(pred))
    res = []
    for k in topk:
        correct_k = correct[:k].view(-1).float().sum(0)
        res.append(correct_k.multiply(100.0 / batch_size))

    return res[0]

--------------------------------------------------------------------------------
/SFace_jittor/util/verification.py:
--------------------------------------------------------------------------------
"""Helper for evaluation on the Labeled Faces in the Wild dataset
"""

# MIT License
#
# Copyright (c) 2016 David Sandberg
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

import numpy as np
from sklearn.model_selection import KFold
from sklearn.decomposition import PCA
import sklearn
from scipy import interpolate
from scipy.spatial.distance import pdist


# Support: ['calculate_roc', 'calculate_accuracy', 'calculate_val', 'calculate_val_far', 'evaluate']


def calculate_roc(thresholds, embeddings1, embeddings2, actual_issame, nrof_folds=10, pca=0):
    assert (embeddings1.shape[0] == embeddings2.shape[0])
    assert (embeddings1.shape[1] == embeddings2.shape[1])
    nrof_pairs = min(len(actual_issame), embeddings1.shape[0])
    nrof_thresholds = len(thresholds)
    k_fold = KFold(n_splits=nrof_folds, shuffle=False)

    tprs = np.zeros((nrof_folds, nrof_thresholds))
    fprs = np.zeros((nrof_folds, nrof_thresholds))
    accuracy = np.zeros((nrof_folds))
    best_thresholds = np.zeros((nrof_folds))
    indices = np.arange(nrof_pairs)

    if pca == 0:
        diff = np.subtract(embeddings1, embeddings2)
        dist = np.sum(np.square(diff), 1)

    for fold_idx, (train_set, test_set) in enumerate(k_fold.split(indices)):
        if pca > 0:
            print("doing pca on", fold_idx)
            embed1_train = embeddings1[train_set]
            embed2_train = embeddings2[train_set]
            _embed_train = np.concatenate((embed1_train, embed2_train), axis=0)
            pca_model = PCA(n_components=pca)
            pca_model.fit(_embed_train)
            embed1 = pca_model.transform(embeddings1)
            embed2 = pca_model.transform(embeddings2)
            embed1 = sklearn.preprocessing.normalize(embed1)
            embed2 = sklearn.preprocessing.normalize(embed2)
            diff = np.subtract(embed1, embed2)
            dist = np.sum(np.square(diff), 1)

        # Find the best threshold for the fold
        acc_train = np.zeros((nrof_thresholds))
        for threshold_idx, threshold in enumerate(thresholds):
            _, _, acc_train[threshold_idx] = calculate_accuracy(threshold, dist[train_set], actual_issame[train_set])
        best_threshold_index = np.argmax(acc_train)
        best_thresholds[fold_idx] = thresholds[best_threshold_index]
        for threshold_idx, threshold in enumerate(thresholds):
            tprs[fold_idx, threshold_idx], fprs[fold_idx, threshold_idx], _ = calculate_accuracy(
                threshold, dist[test_set], actual_issame[test_set])
        _, _, accuracy[fold_idx] = calculate_accuracy(thresholds[best_threshold_index], dist[test_set], actual_issame[test_set])

    tpr = np.mean(tprs, 0)
    fpr = np.mean(fprs, 0)
    return tpr, fpr, accuracy, best_thresholds


def calculate_accuracy(threshold, dist, actual_issame):
    predict_issame = np.less(dist, threshold)
    tp = np.sum(np.logical_and(predict_issame, actual_issame))
    fp = np.sum(np.logical_and(predict_issame, np.logical_not(actual_issame)))
    tn = np.sum(np.logical_and(np.logical_not(predict_issame), np.logical_not(actual_issame)))
    fn = np.sum(np.logical_and(np.logical_not(predict_issame), actual_issame))

    tpr = 0 if (tp + fn == 0) else float(tp) / float(tp + fn)
    fpr = 0 if (fp + tn == 0) else float(fp) / float(fp + tn)
    acc = float(tp + tn) / dist.size
    return tpr, fpr, acc


def calculate_val(thresholds, embeddings1, embeddings2, actual_issame, far_target, nrof_folds=10):
    '''
    Copied from [insightface](https://github.com/deepinsight/insightface)
    :param thresholds:
    :param embeddings1:
    :param embeddings2:
    :param actual_issame:
    :param far_target:
    :param nrof_folds:
    :return:
    '''
    assert (embeddings1.shape[0] == embeddings2.shape[0])
    assert (embeddings1.shape[1] == embeddings2.shape[1])
    nrof_pairs = min(len(actual_issame), embeddings1.shape[0])
    nrof_thresholds = len(thresholds)
    k_fold = KFold(n_splits=nrof_folds, shuffle=False)

    val = np.zeros(nrof_folds)
    far = np.zeros(nrof_folds)

    diff = np.subtract(embeddings1, embeddings2)
    dist = np.sum(np.square(diff), 1)
    indices = np.arange(nrof_pairs)

    for fold_idx, (train_set, test_set) in enumerate(k_fold.split(indices)):

        # Find the threshold that gives FAR = far_target
        far_train = np.zeros(nrof_thresholds)
        for threshold_idx, threshold in enumerate(thresholds):
            _, far_train[threshold_idx] = calculate_val_far(threshold, dist[train_set], actual_issame[train_set])
        if np.max(far_train) >= far_target:
            f = interpolate.interp1d(far_train, thresholds, kind='slinear')
            threshold = f(far_target)
        else:
            threshold = 0.0

        val[fold_idx], far[fold_idx] = calculate_val_far(threshold, dist[test_set], actual_issame[test_set])

    val_mean = np.mean(val)
    far_mean = np.mean(far)
    val_std = np.std(val)
    return val_mean, val_std, far_mean


def calculate_val_far(threshold, dist, actual_issame):
    predict_issame = np.less(dist, threshold)
    true_accept = np.sum(np.logical_and(predict_issame, actual_issame))
    false_accept = np.sum(np.logical_and(predict_issame, np.logical_not(actual_issame)))
    n_same = np.sum(actual_issame)
    n_diff = np.sum(np.logical_not(actual_issame))
    val = float(true_accept) / float(n_same)
    far = float(false_accept) / float(n_diff)
    return val, far


def evaluate(embeddings, actual_issame, nrof_folds=10, pca=0):
    # Calculate evaluation metrics
    thresholds = np.arange(0, 4, 0.01)
    embeddings1 = embeddings[0::2]
    embeddings2 = embeddings[1::2]
    tpr, fpr, accuracy, best_thresholds = calculate_roc(thresholds, embeddings1, embeddings2, np.asarray(actual_issame), nrof_folds=nrof_folds, pca=pca)

    return tpr, fpr, accuracy, best_thresholds

--------------------------------------------------------------------------------
/SFace_jittor/backbone/model_irse.py:
--------------------------------------------------------------------------------
import jittor as jt
from jittor import nn, Module, init
from jittor.nn import Sequential
from collections import namedtuple
import math


# Support: ['IR_50', 'IR_101', 'IR_152', 'IR_SE_50', 'IR_SE_101', 'IR_SE_152']


def xavier_uniform(shape, dtype, gain=1.0):
    assert len(shape) > 1

    matsize = 1
    for i in shape[2:]:
        matsize *= i
    fan = (shape[1] * matsize) + (shape[0] * matsize)
    bound = gain * math.sqrt(6.0 / fan)
    return init.uniform(shape, dtype, -bound, bound)


def xavier_uniform_(var, gain=1.0):
    var.assign(xavier_uniform(tuple(var.shape), var.dtype, gain))


class Flatten(Module):
    def execute(self, input):
        return input.view(input.size(0), -1)


def l2_norm(input, axis=1):
    norm = jt.norm(input, 2, axis, True)
    output = jt.divide(input, norm)

    return output


class SEModule(Module):
    def __init__(self, channels, reduction):
        super(SEModule, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc1 = nn.Conv(
            channels, channels // reduction, 1, padding=0, bias=False)

        xavier_uniform_(self.fc1.weight)

        self.relu = nn.ReLU()
        self.fc2 = nn.Conv(
            channels // reduction, channels, 1, padding=0, bias=False)

        self.sigmoid = nn.Sigmoid()

    def execute(self, x):
        module_input = x
        x = self.avg_pool(x)
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        x = self.sigmoid(x)

        return module_input * x


class bottleneck_IR(Module):

    def __init__(self, in_channel, depth, stride):
        super(bottleneck_IR, self).__init__()
        if in_channel == depth:
            self.shortcut_layer = nn.Pool(1, stride=stride, op='maximum')
        else:
            self.shortcut_layer = Sequential(
                nn.Conv(in_channel, depth, (1, 1), stride=stride, bias=False),
                nn.BatchNorm(depth))
        self.res_layer = Sequential(
            nn.BatchNorm(in_channel),
            nn.Conv(in_channel, depth, (3, 3), stride=(1, 1), padding=1, bias=False),
            nn.PReLU(num_parameters=depth),
            nn.Conv(depth, depth, (3, 3), stride=stride, padding=1, bias=False),
            nn.BatchNorm(depth))

    def execute(self, x):
        shortcut = self.shortcut_layer(x)
        res = self.res_layer(x)

        return res + shortcut


class bottleneck_IR_SE(Module):
    def __init__(self, in_channel, depth, stride):
        super(bottleneck_IR_SE, self).__init__()
        if in_channel == depth:
            self.shortcut_layer = nn.Pool(1, stride=stride, op='maximum')
        else:
            self.shortcut_layer = Sequential(
                nn.Conv(in_channel, depth, (1, 1), stride=stride, bias=False),
                nn.BatchNorm(depth))
        self.res_layer = Sequential(
            nn.BatchNorm(in_channel),
            nn.Conv(in_channel, depth, (3, 3), stride=(1, 1), padding=1, bias=False),
            nn.PReLU(num_parameters=depth),
            nn.Conv(depth, depth, (3, 3), stride=stride, padding=1, bias=False),
            nn.BatchNorm(depth),
            SEModule(depth, 16)
        )

    def execute(self, x):
        shortcut = self.shortcut_layer(x)
        res = self.res_layer(x)

        return res + shortcut


class Bottleneck(namedtuple('Block', ['in_channel', 'depth', 'stride'])):
    '''A named tuple describing a ResNet block.'''


def get_block(in_channel, depth, num_units, stride=2):

    return [Bottleneck(in_channel, depth, stride)] + [Bottleneck(depth, depth, 1) for i in range(num_units - 1)]


def get_blocks(num_layers):
    if num_layers == 50:
        blocks = [
            get_block(in_channel=64, depth=64, num_units=3),
            get_block(in_channel=64, depth=128, num_units=4),
            get_block(in_channel=128, depth=256, num_units=14),
            get_block(in_channel=256, depth=512, num_units=3)
        ]
    elif num_layers == 100:
        blocks = [
            get_block(in_channel=64, depth=64, num_units=3),
            get_block(in_channel=64, depth=128, num_units=13),
            get_block(in_channel=128, depth=256, num_units=30),
            get_block(in_channel=256, depth=512, num_units=3)
        ]
    elif num_layers == 152:
        blocks = [
            get_block(in_channel=64, depth=64, num_units=3),
            get_block(in_channel=64, depth=128, num_units=8),
            get_block(in_channel=128, depth=256, num_units=36),
            get_block(in_channel=256, depth=512, num_units=3)
        ]

    return blocks


class Backbone(Module):
    def __init__(self, input_size, num_layers, mode='ir'):
        super(Backbone, self).__init__()
        assert input_size[0] in [112, 224], "input_size should be [112, 112] or [224, 224]"
        assert num_layers in [50, 100, 152], "num_layers should be 50, 100 or 152"
        assert mode in ['ir', 'ir_se'], "mode should be ir or ir_se"
        blocks = get_blocks(num_layers)
        if mode == 'ir':
            unit_module = bottleneck_IR
        elif mode == 'ir_se':
            unit_module = bottleneck_IR_SE
        self.input_layer = Sequential(
            nn.Conv(3, 64, (3, 3), stride=1, padding=1, bias=False),
            nn.BatchNorm(64),
            nn.PReLU(num_parameters=64))
        if input_size[0] == 112:
            self.output_layer = Sequential(
                nn.BatchNorm(512),
                nn.Dropout(),
                Flatten(),
                nn.Linear(512 * 7 * 7, 512),
                nn.BatchNorm(512))
        else:
            self.output_layer = Sequential(
                nn.BatchNorm(512),
                nn.Dropout(),
                Flatten(),
                nn.Linear(512 * 14 * 14, 512),
                nn.BatchNorm(512))

        modules = []
        for block in blocks:
            for bottleneck in block:
                modules.append(
                    unit_module(bottleneck.in_channel,
                                bottleneck.depth,
                                bottleneck.stride))
        self.body = Sequential(*modules)

        self._initialize_weights()

    def execute(self, x):
        x = x - 127.5
        x = x * 0.078125
        x = self.input_layer(x)
        x = self.body(x)
        x = self.output_layer(x)
        return x

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv):
                xavier_uniform_(m.weight)
                if m.bias is not None:
                    init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm):
                init.constant_(m.weight, 1)
                init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                xavier_uniform_(m.weight)
                if m.bias is not None:
                    init.constant_(m.bias, 0)


def IR_50(input_size):
    """Constructs a ir-50 model.
    """
    model = Backbone(input_size, 50, 'ir')

    return model


def IR_101(input_size):
    """Constructs a ir-101 model.
    """
    model = Backbone(input_size, 100, 'ir')

    return model


def IR_152(input_size):
    """Constructs a ir-152 model.
    """
    model = Backbone(input_size, 152, 'ir')

    return model


def IR_SE_50(input_size):
    """Constructs a ir_se-50 model.
    """
    model = Backbone(input_size, 50, 'ir_se')

    return model


def IR_SE_101(input_size):
    """Constructs a ir_se-101 model.
    """
    model = Backbone(input_size, 100, 'ir_se')

    return model


def IR_SE_152(input_size):
    """Constructs a ir_se-152 model.
    """
    model = Backbone(input_size, 152, 'ir_se')

    return model

--------------------------------------------------------------------------------
/SFace_jittor/train_SFace_jittor.py:
--------------------------------------------------------------------------------
import os, argparse

from tensorboardX import SummaryWriter

import jittor as jt
import jittor.optim as optim
from jittor import nn, init

from config import get_config
from image_iter import FaceDataset
from backbone.model_irse import IR_50, IR_101
from backbone.model_mobilefacenet import MobileFaceNet
from head.metrics import SFaceLoss

from util.utils import separate_irse_bn_paras, separate_mobilefacenet_bn_paras
from util.utils import get_val_data, perform_val, get_time, buffer_val, AverageMeter, train_accuracy
import math
import time


def xavier_gauss_(var, gain=1.0, mode='avg'):
    shape = tuple(var.shape)
    assert len(shape) > 1

    matsize = 1
    for i in shape[2:]:
        matsize *= i
    if mode == 'avg':
        fan = (shape[1] * matsize) + (shape[0] * matsize)
    elif mode == 'in':
        fan = shape[1] * matsize
    elif mode == 'out':
        fan = shape[0] * matsize
    else:
        raise ValueError('wrong mode')
    std = gain * math.sqrt(2.0 / fan)
    return init.gauss_(var, 0, std)


def weight_init(m):
    if isinstance(m, nn.BatchNorm):
        if hasattr(m, 'weight') and m.weight is not None:
            init.constant_(m.weight, 1)
        if hasattr(m, 'bias') and m.bias is not None:
            init.constant_(m.bias, 0)
        if hasattr(m, 'running_mean') and m.running_mean is not None:
            init.constant_(m.running_mean, 0)
        if hasattr(m, 'running_var') and m.running_var is not None:
            init.constant_(m.running_var, 1)
    elif isinstance(m, nn.PReLU):
        init.constant_(m.weight, 1)
    else:
        if hasattr(m, 'weight') and m.weight is not None:
            xavier_gauss_(m.weight, gain=2, mode='out')
        if hasattr(m, 'bias') and m.bias is not None:
            init.constant_(m.bias, 0)


def schedule_lr(optimizer):
    optimizer.lr /= 10.

    if jt.rank == 0:
        print(optimizer)


def need_save(acc, highest_acc):
    do_save = False
    save_cnt = 0
    if acc[0] > 0.98:
        do_save = True
    for i, accuracy in enumerate(acc):
        if accuracy > highest_acc[i]:
            highest_acc[i] = accuracy
            do_save = True
        if i > 0 and accuracy >= highest_acc[i] - 0.002:
            save_cnt += 1
    if save_cnt >= len(acc) * 3 / 4 and acc[0] > 0.99:
        do_save = True
    print("highest_acc:", highest_acc)
    return do_save


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='for face verification')
    parser.add_argument('--workers_id', help="gpu ids or cpu", default='cpu', type=str)
    parser.add_argument('--epochs', help="training epochs", default=125, type=int)
    parser.add_argument('--stages', help="training stages", default='35,65,95', type=str)
    parser.add_argument('--lr', help='learning rate', default=1e-1, type=float)
    parser.add_argument('--batch_size', help="batch_size", default=256, type=int)
    parser.add_argument('--data_mode', help="use which database; only 'casia' is wired up in config.py", default='casia', type=str)
    parser.add_argument('--net', help="which network, ['IR_50', 'IR_101', 'MobileFaceNet']", default='IR_50', type=str)
    parser.add_argument('--head', help="head type, ['SFaceLoss']", default='SFaceLoss', type=str)
    parser.add_argument('--target', help="verification targets", default='lfw,calfw,cplfw,cfp_fp,agedb_30', type=str)
    parser.add_argument('--resume_backbone', help="resume backbone model", default='', type=str)
    parser.add_argument('--resume_head', help="resume head model", default='', type=str)
    parser.add_argument('--outdir', help="output dir", default='test_dir', type=str)
    parser.add_argument('--param_s', default=64.0, type=float)
    parser.add_argument('--param_k', default=80.0, type=float)
    parser.add_argument('--param_a', default=0.8, type=float)
    parser.add_argument('--param_b', default=1.23, type=float)
    args = parser.parse_args()

    #======= hyperparameters & data loaders =======#
    if jt.rank == 0 and not os.path.exists(args.outdir):
        os.makedirs(args.outdir)
    cfg = get_config(args)

    SEED = cfg['SEED']  # random seed for reproducible results

    DATA_ROOT = cfg['DATA_ROOT']  # the parent root where your train data are stored
    EVAL_PATH = cfg['EVAL_PATH']  # the parent root where your val data are stored
    WORK_PATH = cfg['WORK_PATH']  # the root to buffer your checkpoints and to log your train/val status
    BACKBONE_RESUME_ROOT = cfg['BACKBONE_RESUME_ROOT']  # the root to resume training from a saved checkpoint
    HEAD_RESUME_ROOT = cfg['HEAD_RESUME_ROOT']  # the root to resume training from a saved checkpoint

    BACKBONE_NAME = cfg['BACKBONE_NAME']  # support: ['IR_50', 'IR_101', 'MobileFaceNet']
    HEAD_NAME = cfg['HEAD_NAME']

    INPUT_SIZE = cfg['INPUT_SIZE']
    EMBEDDING_SIZE = cfg['EMBEDDING_SIZE']  # feature dimension
    BATCH_SIZE = cfg['BATCH_SIZE']
    DROP_LAST = cfg['DROP_LAST']  # whether to drop the last batch to ensure consistent batch_norm statistics
    LR = cfg['LR']  # initial LR
    NUM_EPOCH = cfg['NUM_EPOCH']
    WEIGHT_DECAY = cfg['WEIGHT_DECAY']
    MOMENTUM = cfg['MOMENTUM']
    STAGES = cfg['STAGES']  # epoch stages to decay learning rate

    MULTI_GPU = cfg['MULTI_GPU']  # flag to use multiple GPUs
    GPU_ID = cfg['GPU_ID']  # specify your GPU ids
    gpu_nums = len(GPU_ID)
    TARGET = cfg['TARGET']

    if jt.rank == 0:
        print('GPU_ID', GPU_ID)
        print("=" * 60)
        print("Overall Configurations:")
        print(cfg)
        with open(os.path.join(WORK_PATH, 'config.txt'), 'w') as f:
            f.write(str(cfg))
        print("=" * 60)

        writer = SummaryWriter(WORK_PATH)  # writer for buffering intermediate results

    if GPU_ID:
        jt.flags.use_cuda = 1
    else:
        jt.flags.use_cuda = 0
    jt.set_seed(SEED)

    with open(os.path.join(DATA_ROOT, 'property'), 'r') as f:
        NUM_CLASS, h, w = [int(i) for i in f.read().split(',')]
    assert h == INPUT_SIZE[0] and w == INPUT_SIZE[1]

    trainloader = FaceDataset(DATA_ROOT, rand_mirror=True, batch_size=BATCH_SIZE,
                              shuffle=True, drop_last=True, num_workers=gpu_nums)

    vers = get_val_data(EVAL_PATH, TARGET, BATCH_SIZE)
    highest_acc = [0.0 for t in TARGET]


    #======= model & loss & optimizer =======#
    BACKBONE_DICT = {'IR_50': IR_50(INPUT_SIZE),
                     'IR_101': IR_101(INPUT_SIZE),
                     'MobileFaceNet': MobileFaceNet(EMBEDDING_SIZE)}
    BACKBONE = BACKBONE_DICT[BACKBONE_NAME]

    HEAD = SFaceLoss(in_features=EMBEDDING_SIZE, out_features=NUM_CLASS, device_id=GPU_ID,
                     s=args.param_s, k=args.param_k, a=args.param_a, b=args.param_b)

    if BACKBONE_NAME.find("IR") >= 0:
        backbone_paras_only_bn, backbone_paras_wo_bn = separate_irse_bn_paras(BACKBONE)  # separate batch_norm parameters from others; do not apply weight decay to batch_norm parameters, to improve generalizability
        _, head_paras_wo_bn = separate_irse_bn_paras(HEAD)
    else:
        backbone_paras_only_bn, backbone_paras_wo_bn = separate_mobilefacenet_bn_paras(BACKBONE)  # separate batch_norm parameters from others; do not apply weight decay to batch_norm parameters, to improve generalizability
        _, head_paras_wo_bn = separate_mobilefacenet_bn_paras(HEAD)
    OPTIMIZER = optim.SGD([{'params': backbone_paras_wo_bn + head_paras_wo_bn, 'weight_decay': WEIGHT_DECAY},
                           {'params': backbone_paras_only_bn}],
                          lr=LR, momentum=MOMENTUM)
    if jt.rank == 0:
        print("Number of Training Classes: {}".format(NUM_CLASS))

        print("=" * 60)
        print(BACKBONE)
        print("{} Backbone Generated".format(BACKBONE_NAME))
        print("=" * 60)
        print("=" * 60)
        print(HEAD)
        print("=" * 60)
        print(OPTIMIZER)
        print("Optimizer Generated")
        print("=" * 60)

    intra_losses = AverageMeter()
    inter_losses = AverageMeter()
    Wyi_mean = AverageMeter()
    Wj_mean = AverageMeter()
    top1 = AverageMeter()

    BACKBONE.apply(weight_init)
    HEAD.apply(weight_init)

    # optionally resume from a checkpoint
    if BACKBONE_RESUME_ROOT and HEAD_RESUME_ROOT:
        print("=" * 60)
        print(BACKBONE_RESUME_ROOT, HEAD_RESUME_ROOT)
        if os.path.isfile(BACKBONE_RESUME_ROOT) and os.path.isfile(HEAD_RESUME_ROOT):
            print("Loading Backbone Checkpoint '{}'".format(BACKBONE_RESUME_ROOT))
            BACKBONE.load_state_dict(jt.load(BACKBONE_RESUME_ROOT))
            print("Loading Head Checkpoint '{}'".format(HEAD_RESUME_ROOT))
            HEAD.load_state_dict(jt.load(HEAD_RESUME_ROOT))
        else:
            print("No Checkpoint Found at '{}' or '{}'. Please Check the Paths, or Continue to Train from Scratch".format(BACKBONE_RESUME_ROOT, HEAD_RESUME_ROOT))
        print("=" * 60)

    #======= train & validation & save checkpoint =======#
    DISP_FREQ = 20  # frequency to display training loss & acc
    VER_FREQ = 2000
    batch = 0  # batch index

    BACKBONE.train()  # set to training mode
    HEAD.train()
    for epoch in range(NUM_EPOCH):

        if epoch in STAGES:
            schedule_lr(OPTIMIZER)

        if jt.rank == 0:
            last_time = time.time()

        for inputs, labels in iter(trainloader):
            labels = labels.long()
            features = BACKBONE(inputs)

            outputs, loss, intra_loss, inter_loss, WyiX, WjX = HEAD(features, labels)

            OPTIMIZER.zero_grad()
            OPTIMIZER.step(loss)

            prec1 = train_accuracy(outputs.detach(), labels, topk=(1,))
            if jt.in_mpi:
                intra_loss = intra_loss.mpi_all_reduce('mean')
                inter_loss = inter_loss.mpi_all_reduce('mean')
                WyiX = WyiX.mpi_all_reduce('mean')
                WjX = WjX.mpi_all_reduce('mean')
                prec1 = prec1.mpi_all_reduce('mean')
            intra_loss_item = intra_loss.data.item()
            inter_loss_item = inter_loss.data.item()
            WyiX_item = WyiX.data.item()
            WjX_item = WjX.data.item()
            prec1_item = prec1.data.item()
            if jt.rank == 0:
                intra_losses.update(intra_loss_item, inputs.size(0) * gpu_nums)
                inter_losses.update(inter_loss_item, inputs.size(0) * gpu_nums)
                Wyi_mean.update(WyiX_item, inputs.size(0) * gpu_nums)
                Wj_mean.update(WjX_item, inputs.size(0) * gpu_nums)
                top1.update(prec1_item, inputs.size(0) * gpu_nums)

            if ((batch + 1) % DISP_FREQ == 0) and batch != 0 and jt.rank == 0:
                intra_epoch_loss = intra_losses.avg
                inter_epoch_loss = inter_losses.avg
                Wyi_record = Wyi_mean.avg
                Wj_record = Wj_mean.avg
                epoch_acc = top1.avg
                writer.add_scalar("intra_Loss", intra_epoch_loss, batch + 1)
                writer.add_scalar("inter_Loss", inter_epoch_loss, batch + 1)
                writer.add_scalar("Wyi", Wyi_record, batch + 1)
                writer.add_scalar("Wj", Wj_record, batch + 1)
                writer.add_scalar("Accuracy", epoch_acc, batch + 1)

                batch_time = time.time() - last_time
                last_time = time.time()

                print('Epoch {} Batch {}\t'
                      'Speed: {speed:.2f} samples/s\t'
                      'intra_Loss {loss1.val:.4f} ({loss1.avg:.4f})\t'
                      'inter_Loss {loss2.val:.4f} ({loss2.avg:.4f})\t'
                      'Wyi {Wyi.val:.4f} ({Wyi.avg:.4f})\t'
                      'Wj {Wj.val:.4f} ({Wj.avg:.4f})\t'
                      'Prec@1 {top1.val:.3f} ({top1.avg:.3f})'.format(
                          epoch + 1, batch + 1, speed=inputs.size(0) * gpu_nums * DISP_FREQ / float(batch_time),
                          loss1=intra_losses, loss2=inter_losses, Wyi=Wyi_mean, Wj=Wj_mean, top1=top1))
                # print("=" * 60)
                intra_losses = AverageMeter()
                inter_losses = AverageMeter()
                Wyi_mean = AverageMeter()
                Wj_mean = AverageMeter()
                top1 = AverageMeter()

            if ((batch + 1) % VER_FREQ == 0) and batch != 0:  # perform validation & save checkpoints (buffer for visualization)
                if jt.rank == 0:
                    lr = OPTIMIZER.lr
                    print("Learning rate %f" % lr)
                    print("Perform Evaluation on", TARGET, ", and Save Checkpoints...")
                acc = []
                for ver in vers:
                    name, data_set, issame = ver
                    accuracy, std, xnorm, best_threshold, roc_curve = perform_val(EMBEDDING_SIZE, BATCH_SIZE,
                                                                                  BACKBONE, data_set, issame)
                    if jt.rank == 0:
                        buffer_val(writer, name, accuracy, std, xnorm, best_threshold, roc_curve, batch + 1)
                        print('[%s][%d]XNorm: %1.5f' % (name, batch + 1, xnorm))
                        print('[%s][%d]Accuracy-Flip: %1.5f+-%1.5f' % (name, batch + 1, accuracy, std))
                        print('[%s][%d]Best-Threshold: %1.5f' % (name, batch + 1, best_threshold))
                    acc.append(accuracy)

                # save checkpoints (every VER_FREQ batches)
                jt.sync_all()
                if jt.rank == 0 and need_save(acc, highest_acc):
                    BACKBONE.save(os.path.join(WORK_PATH,
                                               "Backbone_{}_Epoch_{}_Batch_{}_Time_{}_checkpoint.pkl".format(
                                                   BACKBONE_NAME, epoch + 1, batch + 1, get_time())))
                    HEAD.save(os.path.join(WORK_PATH,
                                           "Head_{}_Epoch_{}_Batch_{}_Time_{}_checkpoint.pkl".format(
                                               HEAD_NAME, epoch + 1, batch + 1, get_time())))
                BACKBONE.train()  # set to training mode

            batch += 1  # batch index

--------------------------------------------------------------------------------