├── .ipynb_checkpoints
│   ├── casia_csv-checkpoint.py
│   ├── infer-checkpoint.py
│   └── train-checkpoint.py
├── casia_csv.py
├── config
│   ├── .ipynb_checkpoints
│   │   ├── __init__-checkpoint.py
│   │   └── train_config-checkpoint.py
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── __init__.cpython-37.pyc
│   │   └── train_config.cpython-37.pyc
│   └── train_config.py
├── dataset
│   ├── .ipynb_checkpoints
│   │   ├── __init__-checkpoint.py
│   │   ├── casia-checkpoint.py
│   │   ├── casia_csv-checkpoint.py
│   │   ├── casia_pfe-checkpoint.py
│   │   ├── lfw-checkpoint.py
│   │   └── verify-checkpoint.py
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── __init__.cpython-37.pyc
│   │   ├── casia.cpython-37.pyc
│   │   ├── casia_pfe.cpython-37.pyc
│   │   ├── lfw.cpython-37.pyc
│   │   └── verify.cpython-37.pyc
│   ├── casia.py
│   ├── casia_pfe.py
│   ├── lfw.py
│   └── verify.py
├── infer.py
├── log
│   └── train_backbone.log
├── model
│   ├── .ipynb_checkpoints
│   │   ├── __init__-checkpoint.py
│   │   ├── faceloss-checkpoint.py
│   │   ├── fc_layer-checkpoint.py
│   │   ├── mls_loss-checkpoint.py
│   │   ├── mls_tf-checkpoint.py
│   │   ├── mobilenet-checkpoint.py
│   │   ├── resnet-checkpoint.py
│   │   ├── uncertainty_head-checkpoint.py
│   │   └── uncertainty_tf-checkpoint.py
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── __init__.cpython-37.pyc
│   │   ├── faceloss.cpython-37.pyc
│   │   ├── fc_layer.cpython-37.pyc
│   │   ├── mls_loss.cpython-37.pyc
│   │   ├── mobilenet.cpython-37.pyc
│   │   ├── rescbam.cpython-37.pyc
│   │   ├── resnet.cpython-37.pyc
│   │   ├── spherenet.cpython-37.pyc
│   │   └── uncertainty_head.cpython-37.pyc
│   ├── mls_loss.py
│   ├── mobilenet.py
│   ├── resnet.py
│   ├── spherenet.py
│   └── uncertainty_head.py
├── test_img
│   ├── .ipynb_checkpoints
│   │   ├── Pedro_Solbes_0003-checkpoint.jpg
│   │   ├── Pedro_Solbes_0004-checkpoint.jpg
│   │   └── Zico_0003-checkpoint.jpg
│   ├── Pedro_Solbes_0003.jpg
│   ├── Pedro_Solbes_0004.jpg
│   └── Zico_0003.jpg
└── train.py
--------------------------------------------------------------------------------
/.ipynb_checkpoints/train-checkpoint.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | 
4 | import os
5 | import sys
6 | import time
7 | import torch
8 | import random
9 | import numpy as np
10 | import torchvision
11 | import torch.nn as nn
12 | import torch.optim as optim
13 | from sklearn import metrics
14 | import torch.nn.functional as F
15 | from torch.utils.data import DataLoader
16 | 
17 | import model as mlib
18 | import dataset as dlib
19 | from config import training_args
20 | 
21 | torch.backends.cudnn.benchmark = True
22 | # os.environ["CUDA_VISIBLE_DEVICES"] = "6,7"   # TODO
23 | 
24 | from IPython import embed
25 | 
26 | 
27 | def my_collate_fn(batch):
28 | 
29 |     imgs, gtys = [], []
30 |     for pid_imgs, gty in batch:
31 |         imgs.extend(pid_imgs)
32 |         gtys.extend([gty] * len(pid_imgs))
33 |     return (torch.stack(imgs, dim=0), torch.Tensor(gtys).long())
34 | 
35 | 
36 | 
37 | class MetricFace(dlib.VerifyFace):
38 | 
39 |     def __init__(self, args):
40 | 
41 |         dlib.VerifyFace.__init__(self, args)
42 |         self.args = args
43 |         self.model = dict()
44 |         self.data = dict()
45 |         self.softmax = torch.nn.Softmax(dim=1)
46 |         self.device = args.use_gpu and torch.cuda.is_available()
47 | 
48 | 
49 |     def _report_settings(self):
50 |         ''' Report the settings '''
51 | 
52 |         sep = '-' * 16
53 |         print('%sEnvironment Versions%s' % (sep, sep))
54 |         print("- Python     : {}".format(sys.version.strip().split('|')[0]))
55 |         print("- PyTorch    : {}".format(torch.__version__))
56 |         print("- TorchVision: {}".format(torchvision.__version__))
57 |         print("- USE_GPU    : {}".format(self.device))
58 |         print('-' * 52)
59 | 
60 | 
61 |     def _model_loader(self):
62 | 
63 |         self.model['backbone']  = mlib.MobileFace(self.args.in_feats, self.args.drop_ratio)
64 |         # self.model['backbone'] = mlib.iresnet_zoo(self.args.backbone, drop_ratio=self.args.drop_ratio, use_se=self.args.use_se)   # SEBlock
65 |         # self.model['backbone'] = mlib.resnet_zoo(self.args.backbone, drop_ratio=self.args.drop_ratio)   # ResBlock
66 |         # self.model['metric']   = mlib.FullyConnectedLayer(self.args)
67 |         self.model['uncertain'] = mlib.UncertaintyHead(self.args.in_feats)
68 |         # self.model['criterion'] = mlib.FaceLoss(self.args)
69 |         self.model['criterion'] = mlib.MLSLoss(mean=False)
70 | 
71 |         if self.args.freeze_backbone:
72 |             for p in self.model['backbone'].parameters():
73 |                 p.requires_grad = False
74 | 
75 |         self.model['optimizer'] = torch.optim.SGD(
76 |             [# {'params': self.model['backbone'].parameters()},
77 |              # {'params': self.model['metric'].parameters()},
78 |              {'params': self.model['uncertain'].parameters()}],
79 |             lr=self.args.base_lr,
80 |             weight_decay=self.args.weight_decay,
81 |             momentum=0.9,
82 |             nesterov=True)
83 |         self.model['scheduler'] = torch.optim.lr_scheduler.MultiStepLR(
84 |             self.model['optimizer'], milestones=self.args.lr_adjust, gamma=self.args.gamma)
85 |         if self.device:
86 |             self.model['backbone']  = self.model['backbone'].cuda()
87 |             self.model['uncertain'] = self.model['uncertain'].cuda()
88 |             # self.model['metric'] = self.model['metric'].cuda()
89 |             self.model['criterion'] = self.model['criterion'].cuda()
90 | 
91 |         if self.device and len(self.args.gpu_ids) > 1:
92 |             self.model['backbone'] = torch.nn.DataParallel(self.model['backbone'], device_ids=self.args.gpu_ids)
93 |             self.model['uncertain'] = torch.nn.DataParallel(self.model['uncertain'], device_ids=self.args.gpu_ids)
94 |             # self.model['metric'] = torch.nn.DataParallel(self.model['metric'], device_ids=self.args.gpu_ids)
95 |             print('Running in multi-GPU (DataParallel) mode ...')
96 |         elif self.device:
97 |             print('Running in single-GPU mode ...')
98 |         else:
99 |             print('Running in CPU mode ...')
100 | 
101 |         if len(self.args.resume) > 2:
102 |             checkpoint = torch.load(self.args.resume, map_location=lambda storage, loc: storage)
103 |             # self.args.start_epoch = checkpoint['epoch']
104 |             self.model['backbone'].load_state_dict(checkpoint['backbone'])
105 |             # self.model['uncertain'].load_state_dict(checkpoint['uncertain'])
106 |             # self.model['metric'].load_state_dict(checkpoint['metric'])
107 |             print('Resuming training from epoch %3d ...' % self.args.start_epoch)
108 |         print('Model loading finished ...')
109 | 
110 | 
111 |     def _data_loader(self):
112 | 
113 |         self.data['train_loader'] = DataLoader(
114 |             dlib.CASIAWebFacePFE(self.args, mode='train'),
115 |             batch_size=self.args.batch_size,
116 |             shuffle=True,
117 |             collate_fn=my_collate_fn,
118 |         )
119 |         # self.data['lfw'] = dlib.LFW(self.args)   # TODO
120 |         print('Data loading finished ...')
121 | 
122 | 
123 |     def _model_train(self, epoch = 0):
124 | 
125 |         self.model['backbone'].eval()   # backbone stays frozen; only the uncertainty head is trained
126 |         # self.model['metric'].train()
127 |         self.model['uncertain'].train()
128 | 
129 |         loss_recorder, batch_acc = [], []
130 |         for idx, (img, gty) in enumerate(self.data['train_loader']):
131 | 
132 |             img.requires_grad = False
133 |             gty.requires_grad = False
134 | 
135 |             if self.device:
136 |                 img = img.cuda()
137 |                 gty = gty.cuda()
138 | 
139 |             feature, sig_feat = self.model['backbone'](img)   # TODO
140 |             # output = self.model['metric'](feature, gty)
141 |             # loss   = self.model['criterion'](output, gty)
142 |             log_sig_sq = self.model['uncertain'](sig_feat)
143 |             loss = self.model['criterion'](feature, log_sig_sq, gty)
144 |             self.model['optimizer'].zero_grad()
145 |             loss.backward()
146 |             self.model['optimizer'].step()
147 |             # predy = np.argmax(output.data.cpu().numpy(), axis=1)   # TODO
148 |             # it_acc = np.mean((predy == gty.data.cpu().numpy()).astype(int))
149 |             # batch_acc.append(it_acc)
150 |             loss_recorder.append(loss.item())
151 |             if (idx + 1) % self.args.print_freq == 0:
152 |                 print('epoch : %2d|%2d, iter : %4d|%4d, loss : %.4f' % \
153 |                       (epoch, self.args.end_epoch, idx+1, len(self.data['train_loader']), np.mean(loss_recorder)))
154 |             '''
155 |             print('epoch : %2d|%2d, iter : %4d|%4d, loss : %.4f, batch_ave_acc : %.4f' % \
156 |                   (epoch, self.args.end_epoch, idx+1, len(self.data['train_loader']), \
157 |                   np.mean(loss_recorder), np.mean(batch_acc)))
158 |             '''
159 |         train_loss = np.mean(loss_recorder)
160 |         print('train_loss : %.4f' % train_loss)
161 |         return train_loss
162 | 
163 | 
164 |     def _verify_lfw(self):
165 | 
166 |         self._eval_lfw()
167 | 
168 |         self._k_folds()
169 | 
170 |         best_thresh, lfw_acc = self._eval_runner()
171 | 
172 |         return best_thresh, lfw_acc
173 | 
174 | 
175 |     def _main_loop(self):
176 | 
177 |         if not os.path.exists(self.args.save_to):
178 |             os.mkdir(self.args.save_to)
179 | 
180 |         max_lfw_acc, min_train_loss = 0.0, 100
181 |         for epoch in range(self.args.start_epoch, self.args.end_epoch + 1):
182 | 
183 |             start_time = time.time()
184 | 
185 |             train_loss = self._model_train(epoch)
186 |             self.model['scheduler'].step()
187 |             # lfw_thresh, lfw_acc = self._verify_lfw()
188 | 
189 |             end_time = time.time()
190 |             print('Single epoch took %.2f mins' % ((end_time - start_time)/60))
191 | 
192 |             if min_train_loss > train_loss:
193 | 
194 |                 print('%snew SOTA was found%s' % ('*'*16, '*'*16))
195 |                 # max_lfw_acc = max(max_lfw_acc, lfw_acc)
196 |                 min_train_loss = train_loss
197 |                 filename = os.path.join(self.args.save_to, 'sota.pth.tar')
198 |                 torch.save({
199 |                     'epoch'     : epoch,
200 |                     'backbone'  : self.model['backbone'].state_dict(),
201 |                     'uncertain' : self.model['uncertain'].state_dict(),
202 |                     'train_loss': min_train_loss,
203 |                 }, filename)
204 | 
205 |             if epoch % self.args.save_freq == 0:
206 |                 filename = 'epoch_%d_train_loss_%.4f.pth.tar' % (epoch, train_loss)
207 |                 savename = os.path.join(self.args.save_to, filename)
208 |                 torch.save({
209 |                     'epoch'     : epoch,
210 |                     'backbone'  : self.model['backbone'].state_dict(),
211 |                     'uncertain' : self.model['uncertain'].state_dict(),
212 |                     'train_loss': train_loss,
213 |                 }, savename)
214 | 
215 |             if self.args.is_debug:
216 |                 break
217 | 
218 | 
219 |     def train_runner(self):
220 | 
221 |         self._report_settings()
222 | 
223 |         self._model_loader()
224 | 
225 |         self._data_loader()
226 | 
227 |         self._main_loop()
228 | 
229 | 
230 | if __name__ == "__main__":
231 | 
232 |     faceu = MetricFace(training_args())
233 |     faceu.train_runner()
234 | 
--------------------------------------------------------------------------------
/casia_csv.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | #-*- coding:utf-8 -*-
3 | """
4 | Created on 2020/04/26
5 | author: lujie
6 | """
7 | 
8 | import os
9 | import numpy as np
10 | import pandas as pd
11 | 
12 | from IPython import embed
13 | 
14 | if __name__ == "__main__":
15 | 
16 |     data_dir = '/home/jovyan/jupyter/benchmark_images/faceu/face_recognition/casia_webface'
17 |     txt_file = os.path.join(data_dir, 'anno_file/casia_landmark.txt')
18 |     with open(txt_file, 'r') as f:
19 |         casia_org = f.readlines()
20 |     f.close()
21 |     org_file = []
22 |     for line in casia_org:
23 | 
24 |         line = line.strip().split('\t')[:2]
25 |         org_file.append(line)
26 |     df_org = pd.DataFrame(org_file, columns=['pid_face', 'uid'])
27 | 
28 |     align_file = []
29 |     align_dir = os.path.join(data_dir, 'align_112_112')
30 |     for pid in os.listdir(align_dir):
31 | 
32 |         if '.DS' in pid or '._' in pid or '.ipy' in pid:
33 |             continue
34 | 
35 |         for face in os.listdir(os.path.join(align_dir, pid)):
36 | 
37 |             if '.DS' in face or '._' in face or '.ipy' in face:
38 |                 continue
39 |             align_file.append([pid + '/' + face, 0])
40 |     df_align = pd.DataFrame(align_file, columns=['pid_face', 'pseudo'])
41 | 
42 |     df_inter = pd.merge(df_org, df_align, on='pid_face', how='inner')
43 |     print('num_org : %4d, num_align : %4d, num_inter : %4d' % (len(df_org), len(df_align), len(df_inter)))
44 |     df_inter = df_inter[['pid_face', 'uid']]
45 |     data_inter = []
46 |     for idx, row in df_inter.iterrows():
47 | 
48 |         row = list(row)
49 |         line = ' '.join(row) + '\n'
50 |         data_inter.append(line)
51 | 
52 |     inter_file = os.path.join(data_dir, 'anno_file/casia_org_join_align.txt')
53 |     with open(inter_file, 'w') as f:
54 |         f.writelines(data_inter)
55 |     f.close()
56 | 
57 |     df_inter.to_csv(os.path.join(data_dir, 'anno_file/casia_org_join_align.csv'), index=False)
58 | 
59 | 
60 | 
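A note on the output format: casia_csv.py writes one face per line to casia_org_join_align.txt in the form `<pid>/<image>.jpg <uid>` (space-separated), which is exactly what dataset/casia.py and dataset/casia_pfe.py later split on a single space. An illustrative excerpt (these identity ids are made up; the real values come from casia_landmark.txt):

0000045/001.jpg 0
0000045/002.jpg 0
0000157/001.jpg 1
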
--------------------------------------------------------------------------------
/config/__init__.py:
--------------------------------------------------------------------------------
1 | from .train_config import training_args
2 | 
--------------------------------------------------------------------------------
/config/train_config.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | #-*- coding:utf-8 -*-
3 | 
4 | import argparse
5 | import os.path as osp
6 | 
7 | root_dir  = '/home/jovyan/jupyter/benchmark_images/faceu'
8 | lfw_dir   = osp.join(root_dir, 'face_verfication/lfw')
9 | casia_dir = osp.join(root_dir, 'face_recognition/casia_webface')
10 | cp_dir    = '/home/jovyan/jupyter/checkpoints_zoo/face-recognition'
11 | 
12 | def training_args():
13 | 
14 |     parser = argparse.ArgumentParser(description='PyTorch metricface')
15 | 
16 |     # -- env
17 |     parser.add_argument('--use_gpu', type=bool, default=True)
18 |     parser.add_argument('--gpu_ids', type=list, default=[0, 1, 2, 3])
19 |     parser.add_argument('--workers', type=int, default=0)
20 | 
21 |     # -- model
22 |     parser.add_argument('--in_size', type=tuple, default=(112, 112))   # FIXED
23 |     parser.add_argument('--offset', type=int, default=2)   # FIXED
24 |     parser.add_argument('--t', type=float, default=0.2)   # MV
25 |     parser.add_argument('--margin', type=float, default=0.5)   # FIXED
26 |     parser.add_argument('--easy_margin', type=bool, default=True)
27 |     parser.add_argument('--scale', type=float, default=32)   # FIXED
28 |     parser.add_argument('--backbone', type=str, default='resnet18')   # TODO | iresse50
29 |     parser.add_argument('--in_feats', type=int, default=512)
30 |     parser.add_argument('--drop_ratio', type=float, default=0.4)   # TODO
31 | 
32 |     parser.add_argument('--fc_mode', type=str, default='arcface', choices=['softmax', 'sphere', 'cosface', 'arcface', 'mvcos', 'mvarc'])
33 |     parser.add_argument('--hard_mode', type=str, default='adaptive', choices=['fixed', 'adaptive'])   # MV
34 |     parser.add_argument('--loss_mode', type=str, default='ce', choices=['ce', 'focal_loss', 'hardmining'])
35 |     parser.add_argument('--hard_ratio', type=float, default=0.9)   # hardmining
36 |     parser.add_argument('--loss_power', type=int, default=2)   # focal_loss
37 |     parser.add_argument('--classnum', type=int, default=10574)   # CASIA (10574)
38 | 
39 |     # fine-tuning
40 |     parser.add_argument('--resume', type=str, default=osp.join(cp_dir, 'pfe/epoch_5_train_loss_-1079.0558.pth.tar'))   # checkpoint
41 |     parser.add_argument('--fine_tuning', type=bool, default=False)   # just fine-tuning
42 |     parser.add_argument('--freeze_backbone', type=bool, default=True)   # TODO
43 | 
44 |     # -- optimizer
45 |     parser.add_argument('--start_epoch', type=int, default=1)
46 |     parser.add_argument('--end_epoch', type=int, default=5)
47 |     parser.add_argument('--batch_size', type=int, default=64)   # NOTE :
48 |     parser.add_argument('--num_face_pb', type=int, default=4)
49 |     parser.add_argument('--base_lr', type=float, default=1e-3)   # TODO : [0.1 for backbone]
50 |     parser.add_argument('--lr_adjust', type=list, default=[2, 3, 4])   # TODO : [16, 25, 35]
51 |     parser.add_argument('--gamma', type=float, default=0.1)   # FIXED
52 |     parser.add_argument('--weight_decay', type=float, default=5e-4)   # FIXED
53 | 
54 |     # -- dataset
55 |     parser.add_argument('--casia_dir', type=str, default=casia_dir)
56 |     parser.add_argument('--lfw_dir', type=str, default=osp.join(lfw_dir, 'align_112_112'))
57 |     parser.add_argument('--train_file', type=str, default=osp.join(casia_dir, 'anno_file/casia_org_join_align.txt'))
58 |     parser.add_argument('--pairs_file', type=str, default=osp.join(lfw_dir, 'anno_file/pairs.txt'))
59 |     parser.add_argument('--try_times', type=int, default=5)
60 | 
61 | 
62 |     # -- verification
63 |     parser.add_argument('--n_folds', type=int, default=10)
64 |     parser.add_argument('--thresh_iv', type=float, default=0.005)
65 | 
66 |     # -- save or print
67 |     parser.add_argument('--is_debug', type=bool, default=False)   # TODO
68 |     parser.add_argument('--save_to', type=str, default=osp.join(cp_dir, 'pfe'))
69 |     parser.add_argument('--print_freq', type=int, default=100)   # v0 : <64, 166> | <128, 83>
70 |     parser.add_argument('--save_freq', type=int, default=1)   # TODO
71 | 
72 |     args = parser.parse_args()
73 | 
74 |     return args
75 | 
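A caveat on these flags: argparse's `type=bool` does not parse the string 'False' as False (bool() of any non-empty string is True), so switches such as `--use_gpu`, `--fine_tuning`, `--freeze_backbone` and `--is_debug` can only be changed reliably by editing their defaults. A minimal sketch of the usual workaround, using a hypothetical `str2bool` helper that is not part of this repo:

def str2bool(v):
    # map common spellings of true/false onto real booleans
    if isinstance(v, bool):
        return v
    return v.lower() in ('true', 't', 'yes', '1')

# parser.add_argument('--use_gpu', type=str2bool, default=True)
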
--------------------------------------------------------------------------------
/dataset/__init__.py:
--------------------------------------------------------------------------------
1 | from .lfw import LFW
2 | from .casia import CASIAWebFace
3 | from .verify import VerifyFace
4 | from .casia_pfe import CASIAWebFacePFE
5 | 
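A note on how the PFE training batches are shaped (inferred from casia_pfe.py below and my_collate_fn in train-checkpoint.py above): each CASIAWebFacePFE item is a list of num_face_pb images of one identity plus that identity's label, so the DataLoader relies on the custom collate function to flatten identities into a single tensor batch. Assuming a batch of B identities and the default num_face_pb=4, the shapes work out roughly as:

# one dataset item : ([img_1, img_2, img_3, img_4], pid), each img_i of shape (3, 112, 112)
# after my_collate_fn over B items:
#   imgs : torch.FloatTensor of shape (B * 4, 3, 112, 112)
#   gtys : torch.LongTensor of shape (B * 4,)   # each pid repeated 4 times
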
--------------------------------------------------------------------------------
/dataset/casia.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | #-*- coding:utf-8 -*-
3 | """
4 | Created on 2020/04/26
5 | author: lujie
6 | """
7 | 
8 | 
9 | import os
10 | import cv2
11 | import random
12 | import numpy as np
13 | import torchvision
14 | from torch.utils import data
15 | 
16 | from IPython import embed
17 | 
18 | 
19 | class CASIAWebFace(data.Dataset):
20 | 
21 |     def __init__(self, args, mode = 'train'):
22 | 
23 |         super(CASIAWebFace, self).__init__()
24 |         self.args = args
25 |         self.mode = mode
26 |         self.transforms = torchvision.transforms.Compose([
27 |             torchvision.transforms.ToTensor(),
28 |             torchvision.transforms.Normalize(mean=[0.5, 0.5, 0.5],
29 |                                              std=[0.5, 0.5, 0.5])])
30 |         with open(args.train_file, 'r') as f:
31 |             self.lines = f.readlines()
32 |         f.close()
33 |         if args.is_debug:
34 |             self.lines = self.lines[:512]   # just for debug
35 |             print('debug version for casia ...')
36 | 
37 | 
38 |     def _load_imginfo(self, img_name):
39 | 
40 |         img_path = os.path.join(self.args.casia_dir, 'align_112_112', img_name)
41 |         img = None
42 |         try:
43 |             img = cv2.resize(cv2.imread(img_path), self.args.in_size)   # TODO
44 |             if random.random() > 0.5:
45 |                 img = cv2.flip(img, 1)   # random horizontal flip
46 |         except Exception as e:
47 |             img = None
48 |         return img
49 | 
50 | 
51 |     def __getitem__(self, index):
52 | 
53 |         # info = self.lines[index].strip().split('\t')
54 |         info = self.lines[index].strip().split(' ')
55 |         img = self._load_imginfo(info[0])
56 |         cnt_try = 0
57 |         while (img is None) and cnt_try < self.args.try_times:
58 |             idx  = np.random.randint(0, len(self.lines) - 1)
59 |             # info = self.lines[idx].strip().split('\t')
60 |             info = self.lines[idx].strip().split(' ')
61 |             img = self._load_imginfo(info[0])
62 |             cnt_try += 1
63 |         if img is None:
64 |             raise IOError('failed to read a valid face image after %d tries' % cnt_try)
65 |         img = self.transforms(img)
66 |         return (img, int(info[1]))
67 | 
68 | 
69 |     def __len__(self):
70 |         return len(self.lines)
71 | 
--------------------------------------------------------------------------------
/dataset/casia_pfe.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | #-*- coding:utf-8 -*-
3 | """
4 | Created on 2020/04/26
5 | author: lujie
6 | """
7 | 
8 | 
9 | import os
10 | import cv2
11 | import random
12 | import numpy as np
13 | import torchvision
14 | import pandas as pd
15 | from torch.utils import data
16 | 
17 | from IPython import embed
18 | 
19 | 
20 | class CASIAWebFacePFE(data.Dataset):
21 | 
22 |     def __init__(self, args, mode = 'train'):
23 | 
24 |         super(CASIAWebFacePFE, self).__init__()
25 |         self.args = args
26 |         self.mode = mode
27 |         self.transforms = torchvision.transforms.Compose([
28 |             torchvision.transforms.ToTensor(),
29 |             torchvision.transforms.Normalize(mean=[0.5, 0.5, 0.5],
30 |                                              std=[0.5, 0.5, 0.5])])
31 |         self._pfe_process()
32 | 
33 | 
34 |     def _pfe_process(self):
35 | 
36 |         with open(self.args.train_file, 'r') as f:
37 |             self.lines = f.readlines()
38 |         f.close()
39 |         if self.args.is_debug:
40 |             self.lines = self.lines[:1024]   # just for debug
41 |             print('debug version for casia ...')
42 | 
43 |         pids_dict = {}
44 |         for line in self.lines:
45 | 
46 |             line = line.strip().split(' ')
47 |             if line[-1] in pids_dict.keys():
48 |                 pids_dict[line[-1]].append(line[0])
49 |             else:
50 |                 pids_dict[line[-1]] = [line[0]]
51 | 
52 |         self.lines = pd.Series(pids_dict).to_frame()
53 |         self.lines['pid'] = self.lines.index
54 |         self.lines.index = range(len(self.lines))
55 |         self.lines = np.array(self.lines[['pid', 0]]).tolist() * 6   # repeat the identity list 6x per epoch
56 |         random.shuffle(self.lines)
57 |         self.lines = np.array(self.lines)
58 | 
59 | 
60 | 
61 |     def _random_samples_from_class(self, files_list):
62 | 
63 |         indices = []
64 |         random.shuffle(files_list)
65 |         if len(files_list) >= self.args.num_face_pb:
66 |             indices = files_list[:self.args.num_face_pb]
67 |         else:
68 |             extend_times = int(np.ceil(self.args.num_face_pb / max(1, len(files_list)))) - 1
69 |             tmp_list = files_list
70 |             for i in range(extend_times):
71 |                 tmp_list.extend(files_list)
72 |             indices = tmp_list[:self.args.num_face_pb]
73 |         return indices
74 | 
75 | 
76 |     def _load_imginfo(self, files_list):
77 | 
78 |         sample_files = self._random_samples_from_class(files_list)
79 | 
80 |         imgs = []
81 |         try:
82 |             for file in sample_files:
83 |                 img_path = os.path.join(self.args.casia_dir, 'align_112_112', file)
84 |                 img = cv2.resize(cv2.imread(img_path), self.args.in_size)   # TODO
85 |                 if random.random() > 0.5:
86 |                     img = cv2.flip(img, 1)
87 |                 img = self.transforms(img)
88 |                 imgs.append(img)
89 |         except Exception as e:
90 |             imgs = []
91 |         return imgs
92 | 
93 | 
94 |     def __getitem__(self, index):
95 | 
96 |         info = self.lines[index]
97 |         imgs = self._load_imginfo(info[1])
98 |         while len(imgs) == 0:
99 |             idx  = np.random.randint(0, len(self.lines) - 1)
100 |             info = self.lines[idx]
101 |             imgs = self._load_imginfo(info[1])
102 | 
103 |         return (imgs, int(info[0]))
104 | 
105 | 
106 |     def __len__(self):
107 |         return len(self.lines)
108 | 
--------------------------------------------------------------------------------
/dataset/lfw.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | 
4 | import os
5 | import cv2
6 | import random
7 | import numpy as np
8 | from torch.utils import data
9 | import torchvision.transforms as T
10 | 
11 | from IPython import embed
12 | 
13 | class LFW(object):
14 | 
15 |     def __init__(self, args, mode = 'test'):
16 | 
17 |         self.args = args
18 |         self.mode = mode
19 |         self.trans = T.Compose([T.ToTensor(),
20 |                                 T.Normalize(mean=[0.5, 0.5, 0.5],
21 |                                             std=[0.5, 0.5, 0.5])])
22 |         with open(args.pairs_file, 'r') as f:
23 |             self.pairs = np.array(f.readlines())
24 |             shuffleidx = np.random.permutation(len(self.pairs))
25 |             self.pairs = self.pairs[shuffleidx]
26 |         f.close()
27 |         if args.is_debug:
28 |             self.pairs = self.pairs[:512]
29 |             print('debug version for lfw ...')
30 |         self.num_pairs = len(self.pairs)
31 | 
32 | 
33 |     def _load_imginfo(self, img_name):
34 | 
35 |         img_path = os.path.join(self.args.lfw_dir, img_name)
36 |         img = None
37 |         try:
38 |             img = cv2.resize(cv2.imread(img_path), self.args.in_size)   # TODO
39 |         except Exception as e:
40 |             img = None
41 |         return img
42 | 
43 | 
44 |     def _get_pair(self, index):
45 | 
46 |         pair_info = self.pairs[index].strip().split('\t')
47 |         info_dict = {}
48 |         try:
49 |             if 3 == len(pair_info):
50 |                 info_dict['label'] = 1
51 |                 info_dict['name1'] = pair_info[0] + '/' + pair_info[0] + '_' + '{:04}.jpg'.format(int(pair_info[1]))
52 |                 info_dict['name2'] = pair_info[0] + '/' + pair_info[0] + '_' + '{:04}.jpg'.format(int(pair_info[2]))
53 |             elif 4 == len(pair_info):
54 |                 info_dict['label'] = 0
55 |                 info_dict['name1'] = pair_info[0] + '/' + pair_info[0] + '_' + '{:04}.jpg'.format(int(pair_info[1]))
56 |                 info_dict['name2'] = pair_info[2] + '/' + pair_info[2] + '_' + '{:04}.jpg'.format(int(pair_info[3]))
57 | 
58 |             if info_dict['label'] is not None:
59 |                 face1 = self._load_imginfo(info_dict['name1'])
60 |                 face2 = self._load_imginfo(info_dict['name2'])
61 |                 info_dict['face1'] = self.trans(face1).unsqueeze(0)
62 |                 info_dict['face2'] = self.trans(face2).unsqueeze(0)
63 |         except Exception as e:
64 |             for key in ['name1', 'name2', 'label', 'face1', 'face2']:
65 |                 info_dict[key] = None
66 |         return info_dict
67 | 
--------------------------------------------------------------------------------
/dataset/verify.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | 
4 | import os
5 | import sys
6 | import time
7 | import torch
8 | import random
9 | import numpy as np
10 | import torchvision
11 | import pandas as pd
12 | import torch.nn as nn
13 | import torch.optim as optim
14 | from sklearn import metrics
15 | import torch.nn.functional as F
16 | 
17 | torch.backends.cudnn.benchmark = True
18 | 
19 | from IPython import embed
20 | 
21 | 
22 | class VerifyFace(object):
23 | 
24 |     def __init__(self, args):
25 | 
26 |         self.args = args
27 |         self.model = dict()
28 |         self.data = dict()
29 |         self.device = args.use_gpu and torch.cuda.is_available()
30 | 
31 | 
32 |     def _report_settings(self):
33 |         ''' Report the settings '''
34 | 
35 |         sep = '-' * 16
36 |         print('%sEnvironment Versions%s' % (sep, sep))
37 |         print("- Python: {}".format(sys.version.strip().split('|')[0]))
38 |         print("- PyTorch: {}".format(torch.__version__))
39 |         print("- TorchVision: {}".format(torchvision.__version__))
40 |         print("- device: {}".format(self.device))
41 |         print('-' * 52)
42 | 
43 | 
44 |     def _model_loader(self):
45 |         pass
46 | 
47 | 
48 |     def _eval_lfw(self):
49 | 
50 |         self.model['backbone'].eval()   # CORE
51 |         with torch.no_grad():
52 |             simi_list = []
53 |             for index in range(1, self.data['lfw'].num_pairs):
54 | 
55 |                 try:
56 |                     pair_dict = self.data['lfw']._get_pair(index)
57 |                     if pair_dict['label'] is not None:
58 |                         if self.device:
59 |                             pair_dict['face1'] = pair_dict['face1'].cuda()
60 |                             pair_dict['face2'] = pair_dict['face2'].cuda()
61 |                         feat1 = self.model['backbone'](pair_dict['face1'])
62 |                         feat2 = self.model['backbone'](pair_dict['face2'])
63 |                         cosvalue = feat1[0].dot(feat2[0]) / (feat1[0].norm() * feat2[0].norm() + 1e-5)
64 |                         simi_list.append([pair_dict['name1'], pair_dict['name2'], pair_dict['label'], cosvalue.item()])
65 |                 except Exception:
66 |                     pass
67 |                 # if (index + 1) % 500 == 0:
68 |                 #     print('already processed %3d, total %3d' % (index+1, self.data['lfw'].num_pairs))
69 |             self.data['similist'] = np.array(simi_list)
70 |             # col_name = ['name1', 'name2', 'gt_y', 'pred_y']
71 |             # df_simi = pd.DataFrame(self.data['similist'], columns=col_name)
72 |             # df_simi.to_csv('check_2.csv')
73 |             # print('lfw-pair faces were evaluated, there are %3d pairs' % len(simi_list))
74 | 
75 | 
76 |     def _eval_aku8k(self):
77 |         ''' Designed for raw images that come with a pairfile.csv '''
78 | 
79 |         self.model['backbone'].eval()
80 |         with torch.no_grad():
81 | 
82 |             simi_list = []
83 |             for index in range(1, self.data['aku8k'].num_pairs):
84 | 
85 |                 try:
86 |                     pair_dict = self.data['aku8k']._get_pair(index)
87 |                     if pair_dict['label'] is not None:
88 |                         if self.device:
89 |                             pair_dict['face1'] = pair_dict['face1'].cuda()
90 |                             pair_dict['face2'] = pair_dict['face2'].cuda()
91 |                         feat1 = self.model['backbone'](pair_dict['face1'])
92 |                         feat2 = self.model['backbone'](pair_dict['face2'])
93 |                         cosvalue = feat1[0].dot(feat2[0]) / (feat1[0].norm() * feat2[0].norm() + 1e-5)
94 |                         simi_list.append([pair_dict['name1'], pair_dict['name2'], pair_dict['label'], cosvalue.item()])
95 |                 except Exception:
96 |                     pass
97 |                 # if (index + 1) % 500 == 0:
98 |                 #     print('already processed %3d, total %3d' % (index+1, self.data['lfw'].num_pairs))
99 |             self.data['similist'] = np.array(simi_list)
100 | 
101 | 
102 | 
103 |     def _k_folds(self):
104 | 
105 |         num_lines = len(self.data['similist'])
106 |         folds, base = [], list(range(num_lines))
107 |         for k in range(self.args.n_folds):
108 | 
109 |             start = int(k * num_lines / self.args.n_folds)
110 |             end   = int((k + 1) * num_lines / self.args.n_folds)
111 |             test  = base[start : end]
112 |             train = list(set(base) - set(test))
113 |             folds.append([train, test])
114 |         self.data['folds'] = folds
115 | 
116 | 
117 |     def _cal_acc(self, index, thresh):
118 | 
119 |         gt_y, pred_y = [], []
120 |         for row in self.data['similist'][index]:
121 | 
122 |             same = 1 if float(row[-1]) > thresh else 0
123 |             pred_y.append(same)
124 |             gt_y.append(int(row[-2]))
125 |         gt_y = np.array(gt_y)
126 |         pred_y = np.array(pred_y)
127 |         accuracy = 1.0 * np.count_nonzero(gt_y==pred_y) / len(gt_y)
128 |         return accuracy
129 | 
130 | 
131 |     def _find_best_thresh(self, train, test):
132 | 
133 |         best_thresh, best_acc = 0, 0
134 |         for thresh in np.arange(-1, 1, self.args.thresh_iv):
135 | 
136 |             acc = self._cal_acc(train, thresh)
137 |             if best_acc < acc:
138 |                 best_acc = acc
139 |                 best_thresh = thresh
140 |         test_acc = self._cal_acc(test, best_thresh)
141 |         return (best_thresh, test_acc)
142 | 
143 | 
144 |     def _eval_runner(self):
145 | 
146 |         opt_thresh_list, test_acc_list = [], []
147 |         for k in range(self.args.n_folds):
148 | 
149 |             train, test = self.data['folds'][k]
150 |             best_thresh, test_acc = self._find_best_thresh(train, test)
151 |             print('fold : %2d, thresh : %.3f, test_acc : %.4f' % (k, best_thresh, test_acc))
152 |             opt_thresh_list.append(best_thresh)
153 |             test_acc_list.append(test_acc)
154 | 
155 |         opt_thresh = np.mean(opt_thresh_list)
156 |         test_acc   = np.mean(test_acc_list)
157 |         print('verification finished, best_thresh : %.4f, test_acc : %.4f' % (opt_thresh, test_acc))
158 |         return opt_thresh, test_acc
159 | 
160 | 
161 |     def verify_runner(self):
162 | 
163 |         self._report_settings()
164 | 
165 |         self._model_loader()
166 | 
167 |         self._eval_lfw()
168 | 
169 |         self._k_folds()
170 | 
171 |         self._eval_runner()
172 | 
173 | 
--------------------------------------------------------------------------------
/infer.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | 
4 | import os
5 | import cv2
6 | import time
7 | import torch
8 | import random
9 | import argparse
10 | import numpy as np
11 | import torchvision
12 | import torch.nn as nn
13 | import torchvision.transforms as T
14 | 
15 | import model as mlib
16 | 
17 | torch.backends.cudnn.benchmark = True
18 | # os.environ["CUDA_VISIBLE_DEVICES"] = "6,7"   # TODO
19 | 
20 | from IPython import embed
21 | 
22 | 
23 | class ProbFace(object):
24 | 
25 |     def __init__(self, args):
26 | 
27 |         self.args = args
28 |         self.model = dict()
29 |         self.trans = T.Compose([T.ToTensor(),
30 |                                 T.Normalize(mean=[0.5, 0.5, 0.5],
31 |                                             std=[0.5, 0.5, 0.5])])
32 |         self.device = args.use_gpu and torch.cuda.is_available()
33 |         self._model_loader()
34 | 
35 | 
36 |     def _model_loader(self):
37 | 
38 |         self.model['backbone']  = mlib.MobileFace(self.args.in_feats, self.args.drop_ratio)
-------------------------------------------------------------------------------- /infer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | import os 5 | import cv2 6 | import time 7 | import torch 8 | import random 9 | import argparse 10 | import numpy as np 11 | import torchvision 12 | import torch.nn as nn 13 | import torchvision.transforms as T 14 | 15 | import model as mlib 16 | 17 | torch.backends.cudnn.benchmark = True 18 | # os.environ["CUDA_VISIBLE_DEVICES"] = "6,7" # TODO 19 | 20 | from IPython import embed 21 | 22 | 23 | class ProbFace(object): 24 | 25 | def __init__(self, args): 26 | 27 | self.args = args 28 | self.model = dict() 29 | self.trans = T.Compose([T.ToTensor(), \ 30 | T.Normalize(mean=[0.5, 0.5, 0.5], \ 31 | std=[0.5, 0.5, 0.5])]) 32 | self.device = args.use_gpu and torch.cuda.is_available() 33 | self._model_loader() 34 | 35 | 36 | def _model_loader(self): 37 | 38 | self.model['backbone'] = mlib.MobileFace(self.args.in_feats, self.args.drop_ratio) 39 | self.model['uncertain'] = mlib.UncertaintyHead(self.args.in_feats) 40 | self.model['criterion'] = mlib.MLSLoss(mean=False) 41 | 42 | if self.device: 43 | self.model['backbone'] = self.model['backbone'].cuda() 44 | self.model['uncertain'] = self.model['uncertain'].cuda() 45 | self.model['criterion'] = self.model['criterion'].cuda() 46 | 47 | if self.device and len(self.args.gpu_ids) > 1: 48 | self.model['backbone'] = torch.nn.DataParallel(self.model['backbone'], device_ids=self.args.gpu_ids) 49 | self.model['uncertain'] = torch.nn.DataParallel(self.model['uncertain'], device_ids=self.args.gpu_ids) 50 | print('Parallel mode is on ...') 51 | elif self.device: 52 | print('Single-GPU mode is on ...') 53 | else: 54 | print('CPU mode is on ...') 55 | 56 | if len(self.args.resume) > 2: 57 | checkpoint = torch.load(self.args.resume, map_location=lambda storage, loc: storage) 58 | self.model['backbone'].load_state_dict(checkpoint['backbone']) 59 | self.model['uncertain'].load_state_dict(checkpoint['uncertain']) 60 | print('Resuming the train process at %3d epochs ...' % checkpoint['epoch']) 61 | 62 | self.model['backbone'].eval() 63 | self.model['uncertain'].eval() 64 | print('Model loading was finished ...') 65 | 66 | 67 | @staticmethod 68 | def cal_pair_mls(mu1, mu2, logsig_sq1=None, logsig_sq2=None): 69 | ''' Calculate the mls of pair faces ''' 70 | 71 | sig_sq1 = torch.exp(logsig_sq1) 72 | sig_sq2 = torch.exp(logsig_sq2) 73 | sig_sq_mutual = sig_sq1 + sig_sq2 74 | mu_diff = mu1 - mu2 75 | mls_pointwise = torch.mul(mu_diff, mu_diff) / sig_sq_mutual + torch.log(sig_sq_mutual) 76 | mls_score = mls_pointwise.sum(dim=1).item() 77 | return mls_score 78 | 79 | 80 | def _process_pair(self, face1, face2): 81 | ''' Get the mls score of pair faces ''' 82 | 83 | mls_score = None 84 | if face1 is None or face2 is None: 85 | mls_score = None 86 | else: 87 | face1 = self.trans(face1).unsqueeze(0) 88 | face2 = self.trans(face2).unsqueeze(0) 89 | 90 | if self.device: # NOTE: self.device is a bool; the old check `== 'cuda'` never fired, leaving inputs on the CPU while the model sat on the GPU 91 | face1 = face1.cuda() 92 | face2 = face2.cuda() 93 | try: 94 | mu1, feat1 = self.model['backbone'](face1) 95 | mu2, feat2 = self.model['backbone'](face2) 96 | logsig_sq1 = self.model['uncertain'](feat1) 97 | logsig_sq2 = self.model['uncertain'](feat2) 98 | except Exception as e: 99 | print(e) 100 | else: 101 | mls_score = self.cal_pair_mls(mu1, mu2, logsig_sq1, logsig_sq2) 102 | return mls_score 103 | 104 | 105 | cp_dir = '/home/jovyan/jupyter/checkpoints_zoo/face-recognition' 106 | 107 | def infer_args(): 108 | 109 | parser = argparse.ArgumentParser(description='PyTorch for ProbFace') 110 | 111 | # -- env 112 | parser.add_argument('--use_gpu', type=bool, default=True) 113 | parser.add_argument('--gpu_ids', type=list, default=[0, 1]) 114 | parser.add_argument('--workers', type=int, default=0) 115 | 116 | # -- model 117 | parser.add_argument('--in_size', type=tuple, default=(112, 112)) # FIXED 118 | parser.add_argument('--in_feats', type=int, default=512) 119 | parser.add_argument('--drop_ratio', type=float, default=0.4) # TODO 120 | parser.add_argument('--resume', type=str, default=os.path.join(cp_dir, 'pfe/sota.pth.tar')) # checkpoint 121 | 122 | args = parser.parse_args() 123 | 124 | return args 125 | 126 | 127 | if __name__ == "__main__": 128 | 129 | probu = ProbFace(infer_args()) 130 | face1 = cv2.imread('test_img/Pedro_Solbes_0003.jpg') 131 | face2 = cv2.imread('test_img/Pedro_Solbes_0004.jpg') 132 | face3 = cv2.imread('test_img/Zico_0003.jpg') 133 | mls_score = probu._process_pair(face2,
face3) 134 | print(mls_score) 135 | embed() 136 | -------------------------------------------------------------------------------- /log/train_backbone.log: -------------------------------------------------------------------------------- 1 | ----------------Environment Versions---------------- 2 | - Python : 3.7.3 3 | - PyTorch : 1.1.0 4 | - TorchVison: 0.3.0 5 | - USE_GPU : True 6 | ---------------------------------------------------- 7 | Parallel mode was going ... 8 | Resuming the train process at 1 epoches ... 9 | Model loading was finished ... 10 | Data loading was finished ... 11 | epoch : 1| 5, iter : 100| 992, loss : -2657.0063 12 | epoch : 1| 5, iter : 200| 992, loss : -2705.3872 13 | epoch : 1| 5, iter : 300| 992, loss : -2722.4153 14 | epoch : 1| 5, iter : 400| 992, loss : -2731.3879 15 | epoch : 1| 5, iter : 500| 992, loss : -2736.8850 16 | epoch : 1| 5, iter : 600| 992, loss : -2740.7804 17 | epoch : 1| 5, iter : 700| 992, loss : -2744.4659 18 | epoch : 1| 5, iter : 800| 992, loss : -2747.9572 19 | epoch : 1| 5, iter : 900| 992, loss : -2750.4960 20 | train_loss : -2752.6475 21 | Single epoch cost time : 3.85 mins 22 | ****************new SOTA was found**************** 23 | epoch : 2| 5, iter : 100| 992, loss : -2777.3234 24 | epoch : 2| 5, iter : 200| 992, loss : -2774.2954 25 | epoch : 2| 5, iter : 300| 992, loss : -2773.6455 26 | epoch : 2| 5, iter : 400| 992, loss : -2773.7553 27 | epoch : 2| 5, iter : 500| 992, loss : -2773.3056 28 | epoch : 2| 5, iter : 600| 992, loss : -2773.0361 29 | epoch : 2| 5, iter : 700| 992, loss : -2773.0753 30 | epoch : 2| 5, iter : 800| 992, loss : -2773.7208 31 | epoch : 2| 5, iter : 900| 992, loss : -2774.0920 32 | train_loss : -2774.5521 33 | Single epoch cost time : 3.80 mins 34 | ****************new SOTA was found**************** 35 | epoch : 3| 5, iter : 100| 992, loss : -2776.6750 36 | epoch : 3| 5, iter : 200| 992, loss : -2777.8207 37 | epoch : 3| 5, iter : 300| 992, loss : -2776.4285 38 | epoch : 3| 5, iter : 400| 992, loss : -2775.7481 39 | epoch : 3| 5, iter : 500| 992, loss : -2775.7185 40 | epoch : 3| 5, iter : 600| 992, loss : -2775.2633 41 | epoch : 3| 5, iter : 700| 992, loss : -2775.6621 42 | epoch : 3| 5, iter : 800| 992, loss : -2775.8025 43 | epoch : 3| 5, iter : 900| 992, loss : -2776.1324 44 | train_loss : -2776.5364 45 | Single epoch cost time : 3.70 mins 46 | ****************new SOTA was found**************** 47 | epoch : 4| 5, iter : 100| 992, loss : -2777.1886 48 | epoch : 4| 5, iter : 200| 992, loss : -2776.6736 49 | epoch : 4| 5, iter : 300| 992, loss : -2776.9170 50 | epoch : 4| 5, iter : 400| 992, loss : -2776.7274 51 | epoch : 4| 5, iter : 500| 992, loss : -2776.6935 52 | epoch : 4| 5, iter : 600| 992, loss : -2776.8823 53 | epoch : 4| 5, iter : 700| 992, loss : -2776.4362 54 | epoch : 4| 5, iter : 800| 992, loss : -2776.6007 55 | epoch : 4| 5, iter : 900| 992, loss : -2776.5034 56 | train_loss : -2776.2909 57 | Single epoch cost time : 3.78 mins 58 | epoch : 5| 5, iter : 100| 992, loss : -2777.8533 59 | epoch : 5| 5, iter : 200| 992, loss : -2776.3377 60 | epoch : 5| 5, iter : 300| 992, loss : -2775.0353 61 | epoch : 5| 5, iter : 400| 992, loss : -2776.3318 62 | epoch : 5| 5, iter : 500| 992, loss : -2776.1743 63 | epoch : 5| 5, iter : 600| 992, loss : -2775.9966 64 | epoch : 5| 5, iter : 700| 992, loss : -2776.1331 65 | epoch : 5| 5, iter : 800| 992, loss : -2776.4060 66 | epoch : 5| 5, iter : 900| 992, loss : -2776.1280 67 | train_loss : -2776.1716 68 | Single epoch cost time : 3.68 mins 69 | 
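A note on reading this log: the MLS objective is legitimately negative, because each positive pair contributes (mu1 - mu2)^2 / (sig1^2 + sig2^2) + log(sig1^2 + sig2^2) per dimension, and the log term goes negative as soon as the summed variance drops below 1. A toy computation (not the training pipeline) that reproduces the magnitude seen above:

import torch

# identical 512-D embeddings whose per-dimension variance is exp(-6)
mu1 = torch.randn(512)
mu2 = mu1.clone()
sig_sum = 2 * torch.exp(torch.tensor(-6.0)) * torch.ones(512)
mls = ((mu1 - mu2) ** 2 / sig_sum + torch.log(sig_sum)).sum()
print(mls.item())  # 512 * (log 2 - 6), roughly -2717 -- the same scale as the losses logged above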
-------------------------------------------------------------------------------- /model/.ipynb_checkpoints/__init__-checkpoint.py: -------------------------------------------------------------------------------- 1 | from .mls_loss import MLSLoss 2 | from .spherenet import SphereNet20 3 | from .resnet import resnet_zoo 4 | from .mobilenet import MobileFace 5 | from .uncertainty_head import UncertaintyHead 6 | -------------------------------------------------------------------------------- /model/.ipynb_checkpoints/faceloss-checkpoint.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | from IPython import embed 9 | 10 | 11 | class FaceLoss(nn.Module): 12 | ''' Classic loss function for face recognition ''' 13 | 14 | def __init__(self, args): 15 | 16 | super(FaceLoss, self).__init__() 17 | self.args = args 18 | 19 | 20 | def forward(self, predy, target): 21 | 22 | if self.args.loss_mode == 'focal_loss': 23 | logp = F.cross_entropy(predy, target, reduce=False) 24 | prob = torch.exp(-logp) 25 | loss = ((1-prob) ** self.args.loss_power * logp).mean() 26 | 27 | elif self.args.loss_mode == 'hardmining': 28 | batch_size = predy.shape[0] 29 | logp = F.cross_entropy(predy, target, reduce=False) 30 | inv_index = torch.argsort(-logp) # from big to small 31 | num_hard = int(self.args.hard_ratio * batch_size) 32 | hard_idx = inv_index[:num_hard] 33 | loss = torch.sum(F.cross_entropy(predy[hard_idx], target[hard_idx])) 34 | else: 35 | loss = F.cross_entropy(predy, target) 36 | 37 | return loss
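The focal branch of FaceLoss down-weights easy samples by (1 - p_t)^loss_power before averaging. A standalone sketch of the same computation, with reduction='none' standing in for the deprecated reduce=False and power playing the role of args.loss_power:

import torch
import torch.nn.functional as F

def focal_loss(logits, target, power=2.0):
    logp = F.cross_entropy(logits, target, reduction='none')  # per-sample cross entropy
    prob = torch.exp(-logp)                                   # probability of the true class
    return ((1 - prob) ** power * logp).mean()

logits = torch.randn(4, 10)              # toy batch: 4 samples, 10 classes
target = torch.randint(0, 10, (4,))
print(focal_loss(logits, target).item())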
-------------------------------------------------------------------------------- /model/.ipynb_checkpoints/fc_layer-checkpoint.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | #-*- coding:utf-8 -*- 3 | 4 | import math 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | from IPython import embed 10 | 11 | 12 | class FullyConnectedLayer(nn.Module): 13 | 14 | def __init__(self, args): 15 | 16 | super(FullyConnectedLayer, self).__init__() 17 | 18 | self.args = args 19 | self.weight = nn.Parameter(torch.Tensor(args.classnum, args.in_feats)) 20 | nn.init.xavier_uniform_(self.weight) 21 | self.cos_m = math.cos(args.margin) 22 | self.sin_m = math.sin(args.margin) 23 | self.mm = math.sin(math.pi - self.args.margin) * self.args.margin 24 | self.register_buffer('factor_t', torch.zeros(1)) 25 | self.iter = 0 26 | self.base = 1000 27 | self.alpha = 0.0001 28 | self.power = 2 29 | self.lambda_min = 5.0 30 | self.mlambda = [ 31 | lambda x: x ** 0, 32 | lambda x: x ** 1, 33 | lambda x: 2 * x ** 2 - 1, 34 | lambda x: 4 * x ** 3 - 3 * x, 35 | lambda x: 8 * x ** 4 - 8 * x ** 2 + 1, 36 | lambda x: 16 * x ** 5 - 20 * x ** 3 + 5 * x 37 | ] 38 | 39 | 40 | def forward(self, x, label): 41 | 42 | cos_theta = F.linear(F.normalize(x), F.normalize(self.weight)) 43 | cos_theta = cos_theta.clamp(-1, 1) 44 | batch_size = label.size(0) 45 | cosin_simi = cos_theta[torch.arange(0, batch_size), label].view(-1, 1) 46 | 47 | if self.args.fc_mode == 'softmax': 48 | score = cosin_simi 49 | 50 | elif self.args.fc_mode == 'sphereface': 51 | self.iter += 1 52 | self.lamb = max(self.lambda_min, self.base * (1 + self.alpha * self.iter) ** (-1 * self.power)) 53 | cos_theta_m = self.mlambda[int(self.args.margin)](cosin_simi) 54 | theta = cosin_simi.data.acos() 55 | k = ((self.args.margin * theta) / math.pi).floor() 56 | phi_theta = ((-1.0) ** k) * cos_theta_m - 2 * k 57 | score = (self.lamb * cosin_simi + phi_theta) / (1 + self.lamb) 58 | 59 | elif self.args.fc_mode == 'cosface': 60 | if self.args.easy_margin: 61 | score = torch.where(cosin_simi > 0, cosin_simi - self.args.margin, cosin_simi) 62 | else: 63 | score = cosin_simi - self.args.margin 64 | 65 | elif self.args.fc_mode == 'arcface': 66 | sin_theta = torch.sqrt(1.0 - torch.pow(cosin_simi, 2)) 67 | cos_theta_m = cosin_simi * self.cos_m - sin_theta * self.sin_m 68 | if self.args.easy_margin: 69 | score = torch.where(cosin_simi > 0, cos_theta_m, cosin_simi) 70 | else: 71 | score = cos_theta_m 72 | 73 | elif self.args.fc_mode == 'mvcos': 74 | mask = cos_theta > cosin_simi - self.args.margin 75 | hard_vector = cos_theta[mask] 76 | if self.args.hard_mode == 'adaptive': 77 | cos_theta[mask] = (self.args.t + 1.0) * hard_vector + self.args.t # Adaptive 78 | else: 79 | cos_theta[mask] = hard_vector + self.args.t # Fixed 80 | if self.args.easy_margin: 81 | score = torch.where(cosin_simi > 0, cosin_simi - self.args.margin, cosin_simi) 82 | else: 83 | score = cosin_simi - self.args.margin 84 | 85 | elif self.args.fc_mode == 'mvarc': 86 | sin_theta = torch.sqrt(1.0 - torch.pow(cosin_simi, 2)) 87 | cos_theta_m = cosin_simi * self.cos_m - sin_theta * self.sin_m 88 | mask = cos_theta > cos_theta_m 89 | hard_vector = cos_theta[mask] 90 | if self.args.hard_mode == 'adaptive': 91 | cos_theta[mask] = (self.args.t + 1.0) * hard_vector + self.args.t # Adaptive 92 | else: 93 | cos_theta[mask] = hard_vector + self.args.t # Fixed 94 | if self.args.easy_margin: 95 | score = torch.where(cosin_simi > 0, cos_theta_m, cosin_simi) 96 | else: 97 | score = cos_theta_m 98 | 99 | elif self.args.fc_mode == 'curface': 100 | with torch.no_grad(): 101 | origin_cos = cos_theta 102 | sin_theta = torch.sqrt(1.0 - torch.pow(cosin_simi, 2)) 103 | cos_theta_m = cosin_simi * self.cos_m - sin_theta * self.sin_m 104 | mask = cos_theta > cos_theta_m 105 | score = torch.where(cosin_simi > 0, cos_theta_m, cosin_simi - self.mm) 106 | hard_sample = cos_theta[mask] 107 | with torch.no_grad(): 108 | self.factor_t = cos_theta_m.mean() * 0.01 + 0.99 * self.factor_t 109 | cos_theta[mask] = hard_sample * (self.factor_t + hard_sample) 110 | else: 111 | raise Exception('unknown fc type!') 112 | 113 | cos_theta.scatter_(1, label.data.view(-1, 1), score) 114 | cos_theta *= self.args.scale 115 | return cos_theta 116 | -------------------------------------------------------------------------------- /model/.ipynb_checkpoints/mls_loss-checkpoint.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | #-*- coding:utf-8 -*- 3 | """ 4 | Created on 2020/04/23 5 | author: lujie 6 | """ 7 | import torch 8 | import numpy as np 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | from IPython import embed 12 | 13 | class MLSLoss(nn.Module): 14 | 15 | def __init__(self, mean = False): 16 | 17 | super(MLSLoss, self).__init__() 18 | self.mean = mean 19 | 20 | def negMLS(self, mu_X, sigma_sq_X): 21 | 22 | if self.mean: 23 | XX = torch.mul(mu_X, mu_X).sum(dim=1, keepdim=True) 24 | YY = torch.mul(mu_X.T, mu_X.T).sum(dim=0, keepdim=True) 25 | XY = torch.mm(mu_X, mu_X.T) 26 | mu_diff = XX + YY - 2 * XY 27 | sig_sum = sigma_sq_X.mean(dim=1, keepdim=True) + sigma_sq_X.T.sum(dim=0, keepdim=True) 28 | diff = mu_diff / (1e-8 + sig_sum) + mu_X.size(1) * torch.log(sig_sum) 29 | return diff 30 | else: 31 | mu_diff = mu_X.unsqueeze(1)
- mu_X.unsqueeze(0) 32 | sig_sum = sigma_sq_X.unsqueeze(1) + sigma_sq_X.unsqueeze(0) 33 | diff = torch.mul(mu_diff, mu_diff) / (1e-10 + sig_sum) + torch.log(sig_sum) # BUG 34 | diff = diff.sum(dim=2, keepdim=False) 35 | return diff 36 | 37 | def forward(self, mu_X, log_sigma_sq, gty): 38 | 39 | mu_X = F.normalize(mu_X) # if mu_X was not normalized by l2 40 | non_diag_mask = (1 - torch.eye(mu_X.size(0))).int() 41 | if gty.device.type == 'cuda': 42 | non_diag_mask = non_diag_mask.cuda(0) 43 | sig_X = torch.exp(log_sigma_sq) 44 | loss_mat = self.negMLS(mu_X, sig_X) 45 | gty_mask = (torch.eq(gty[:, None], gty[None, :])).int() 46 | pos_mask = (non_diag_mask * gty_mask) > 0 47 | pos_loss = loss_mat[pos_mask].mean() 48 | return pos_loss 49 | 50 | 51 | if __name__ == "__main__": 52 | 53 | mls = MLSLoss(mean=False) 54 | gty = torch.Tensor([1, 2, 3, 2, 3, 3, 2]) 55 | mu_data = np.array([[-1.7847768 , -1.0991699 , 1.4248079 ], 56 | [ 1.0405252 , 0.35788524, 0.7338794 ], 57 | [ 1.0620259 , 2.1341069 , -1.0100055 ], 58 | [-0.00963581, 0.39570177, -1.5577421 ], 59 | [-1.064951 , -1.1261107 , -1.4181522 ], 60 | [ 1.008275 , -0.84791195, 0.3006532 ], 61 | [ 0.31099692, -0.32650718, -0.60247767]]) 62 | 63 | si_data = np.array([[-0.28463233, -2.5517333 , 1.4781238 ], 64 | [-0.10505871, -0.31454122, -0.29844758], 65 | [-1.3067418 , 0.48718405, 0.6779812 ], 66 | [ 2.024449 , -1.3925922 , -1.6178994 ], 67 | [-0.08328865, -0.396574 , 1.0888542 ], 68 | [ 0.13096762, -0.14382902, 0.2695235 ], 69 | [ 0.5405067 , -0.67946523, -0.8433032 ]]) 70 | 71 | muX = torch.from_numpy(mu_data) 72 | siX = torch.from_numpy(si_data) 73 | print(muX.shape) 74 | diff = mls(muX, siX, gty) 75 | print(diff) 76 | -------------------------------------------------------------------------------- /model/.ipynb_checkpoints/mls_tf-checkpoint.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | #-*- coding:utf-8 -*- 3 | """ 4 | Created on 2020/04/23 5 | author: lujie 6 | """ 7 | import numpy as np 8 | import tensorflow as tf 9 | from IPython import embed 10 | 11 | def negative_MLS(X, Y, sigma_sq_X, sigma_sq_Y, mean=False): 12 | with tf.name_scope('negative_MLS'): 13 | if mean: 14 | D = X.shape[1].value 15 | 16 | Y = tf.transpose(Y) 17 | XX = tf.reduce_sum(tf.square(X), 1, keep_dims=True) 18 | YY = tf.reduce_sum(tf.square(Y), 0, keep_dims=True) 19 | XY = tf.matmul(X, Y) 20 | diffs = XX + YY - 2*XY 21 | 22 | sigma_sq_Y = tf.transpose(sigma_sq_Y) 23 | sigma_sq_X = tf.reduce_mean(sigma_sq_X, axis=1, keep_dims=True) 24 | sigma_sq_Y = tf.reduce_mean(sigma_sq_Y, axis=0, keep_dims=True) 25 | sigma_sq_fuse = sigma_sq_X + sigma_sq_Y 26 | 27 | diffs = diffs / (1e-8 + sigma_sq_fuse) + D * tf.log(sigma_sq_fuse) 28 | 29 | return diffs 30 | else: 31 | # D = X.shape[1].value 32 | D = X.shape[1] 33 | X = tf.reshape(X, [-1, 1, D]) 34 | Y = tf.reshape(Y, [1, -1, D]) 35 | sigma_sq_X = tf.reshape(sigma_sq_X, [-1, 1, D]) 36 | sigma_sq_Y = tf.reshape(sigma_sq_Y, [1, -1, D]) 37 | sigma_sq_fuse = sigma_sq_X + sigma_sq_Y 38 | diffs = tf.square(X-Y) / (1e-10 + sigma_sq_fuse) + tf.math.log(sigma_sq_fuse) 39 | return tf.reduce_sum(diffs, axis=2) 40 | 41 | def mutual_likelihood_score_loss(labels, mu, log_sigma_sq): 42 | 43 | with tf.name_scope('MLS_Loss'): 44 | 45 | batch_size = tf.shape(mu)[0] 46 | diag_mask = tf.eye(batch_size, dtype=tf.bool) 47 | non_diag_mask = tf.logical_not(diag_mask) 48 | 49 | sigma_sq = tf.exp(log_sigma_sq) 50 | loss_mat = negative_MLS(mu, mu, sigma_sq, sigma_sq) 51 | 52 | label_mat = 
tf.equal(labels[:,None], labels[None,:]) 53 | label_mask_pos = tf.logical_and(non_diag_mask, label_mat) 54 | 55 | loss_pos = tf.boolean_mask(loss_mat, label_mask_pos) 56 | 57 | return tf.reduce_mean(loss_pos) 58 | 59 | 60 | if __name__ == "__main__": 61 | 62 | gty = tf.convert_to_tensor([1, 2, 3, 2, 3, 3, 2]) 63 | # muX = tf.random.normal([7, 3], mean=0, stddev=1) 64 | # siX = tf.random.normal([7, 3], mean=0, stddev=1) 65 | mu_data = np.array([[-1.7847768 , -1.0991699 , 1.4248079 ], 66 | [ 1.0405252 , 0.35788524, 0.7338794 ], 67 | [ 1.0620259 , 2.1341069 , -1.0100055 ], 68 | [-0.00963581, 0.39570177, -1.5577421 ], 69 | [-1.064951 , -1.1261107 , -1.4181522 ], 70 | [ 1.008275 , -0.84791195, 0.3006532 ], 71 | [ 0.31099692, -0.32650718, -0.60247767]]) 72 | 73 | si_data = np.array([[-0.28463233, -2.5517333 , 1.4781238 ], 74 | [-0.10505871, -0.31454122, -0.29844758], 75 | [-1.3067418 , 0.48718405, 0.6779812 ], 76 | [ 2.024449 , -1.3925922 , -1.6178994 ], 77 | [-0.08328865, -0.396574 , 1.0888542 ], 78 | [ 0.13096762, -0.14382902, 0.2695235 ], 79 | [ 0.5405067 , -0.67946523, -0.8433032 ]]) 80 | muX = tf.convert_to_tensor(mu_data) 81 | siX = tf.convert_to_tensor(si_data) 82 | diff = mutual_likelihood_score_loss(gty, muX, siX) 83 | print(diff) 84 | embed() 85 | -------------------------------------------------------------------------------- /model/.ipynb_checkpoints/mobilenet-checkpoint.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | #-*- coding:utf-8 -*- 3 | 4 | import thop 5 | import math 6 | import torch 7 | from torch import nn 8 | from IPython import embed 9 | 10 | class BottleNeck(nn.Module): 11 | def __init__(self, inp, oup, stride, expansion): 12 | 13 | super(BottleNeck, self).__init__() 14 | 15 | self.connect = stride == 1 and inp == oup 16 | self.conv = nn.Sequential( 17 | # 1*1 conv 18 | nn.Conv2d(inp, inp * expansion, 1, 1, 0, bias=False), 19 | nn.BatchNorm2d(inp * expansion), 20 | nn.PReLU(inp * expansion), 21 | 22 | # 3*3 depth wise conv 23 | nn.Conv2d(inp * expansion, inp * expansion, 3, stride, 1, groups=inp * expansion, bias=False), 24 | nn.BatchNorm2d(inp * expansion), 25 | nn.PReLU(inp * expansion), 26 | 27 | # 1*1 conv 28 | nn.Conv2d(inp * expansion, oup, 1, 1, 0, bias=False), 29 | nn.BatchNorm2d(oup)) 30 | 31 | 32 | def forward(self, x): 33 | 34 | out = self.conv(x) 35 | if self.connect: 36 | return x + out 37 | else: 38 | return out 39 | 40 | 41 | class ConvBlock(nn.Module): 42 | 43 | def __init__(self, inp, oup, k, s, p, dw=False, linear=False): 44 | super(ConvBlock, self).__init__() 45 | self.linear = linear 46 | if dw: 47 | self.conv = nn.Conv2d(inp, oup, k, s, p, groups=inp, bias=False) 48 | else: 49 | self.conv = nn.Conv2d(inp, oup, k, s, p, bias=False) 50 | 51 | self.bn = nn.BatchNorm2d(oup) 52 | if not linear: 53 | self.prelu = nn.PReLU(oup) 54 | 55 | def forward(self, x): 56 | x = self.conv(x) 57 | x = self.bn(x) 58 | if self.linear: 59 | return x 60 | else: 61 | return self.prelu(x) 62 | 63 | 64 | class Flatten(nn.Module): 65 | def forward(self, input): 66 | return input.view(input.size(0), -1) 67 | 68 | 69 | class MobileFace(nn.Module): 70 | def __init__(self, feat_dim = 512, drop_ratio = 0.5): 71 | 72 | super(MobileFace, self).__init__() 73 | 74 | self.conv1 = ConvBlock(3, 64, 3, 2, 1) 75 | self.dwconv1 = ConvBlock(64, 64, 3, 1, 1, dw=True) 76 | self.cur_channel = 64 #t, c, n, s 77 | self.block_setting = [[2, 64, 5, 2], 78 | [4, 128, 1, 2], 79 | [2, 128, 6, 1], 80 | [4, 128, 1, 2], 81 | [2, 128, 2, 
1]] 82 | self.layers = self._make_layer() 83 | 84 | self.conv2 = ConvBlock(128, 512, 1, 1, 0) 85 | self.linear7 = ConvBlock(512, 512, 7, 1, 0, dw=True, linear=True) # CORE 86 | self.linear1 = ConvBlock(512, feat_dim, 1, 1, 0, linear=True) 87 | 88 | # Arcface : BN-Dropout-FC-BN to get the final 512-D embedding feature 89 | ''' 90 | self.output_layer = nn.Sequential( 91 | nn.BatchNorm2d(512), 92 | nn.Dropout(p=drop_ratio), 93 | Flatten(), 94 | nn.Linear(512 * 7 * 7, feat_dim), # size / 16 95 | nn.BatchNorm1d(feat_dim)) 96 | ''' 97 | 98 | for m in self.modules(): 99 | if isinstance(m, nn.Conv2d): 100 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 101 | m.weight.data.normal_(0, math.sqrt(2. / n)) 102 | elif isinstance(m, nn.BatchNorm2d): 103 | m.weight.data.fill_(1) 104 | m.bias.data.zero_() 105 | 106 | 107 | def _make_layer(self, block = BottleNeck): 108 | layers = [] 109 | for t, c, n, s in self.block_setting: 110 | for i in range(n): 111 | if i == 0: 112 | layers.append(block(self.cur_channel, c, s, t)) 113 | else: 114 | layers.append(block(self.cur_channel, c, 1, t)) 115 | self.cur_channel = c 116 | 117 | return nn.Sequential(*layers) 118 | 119 | 120 | def forward(self, x): 121 | 122 | x = self.conv1(x) 123 | x = self.dwconv1(x) 124 | x = self.layers(x) 125 | x = self.conv2(x) 126 | # x = self.output_layer(x) 127 | x = self.linear7(x) 128 | sig_x = x 129 | x = self.linear1(x) 130 | x = x.view(x.size(0), -1) 131 | return x, sig_x 132 | 133 | 134 | if __name__ == "__main__": 135 | input = torch.Tensor(1, 3, 112, 112) 136 | model = MobileFace(use_cbam=True) 137 | flops, params = thop.profile(model, inputs=(input, )) 138 | flops, params = thop.clever_format([flops, params], "%.3f") 139 | print(flops, params) 140 | # model = model.eval() 141 | # out = model(input) 142 | # print(out.shape) 143 | 144 | -------------------------------------------------------------------------------- /model/.ipynb_checkpoints/resnet-checkpoint.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | #-*- coding:utf-8 -*- 3 | """ 4 | Created on 2020/04/24 5 | author: lujie 6 | """ 7 | 8 | 9 | import thop 10 | import torch 11 | import torch.nn as nn 12 | 13 | from IPython import embed 14 | 15 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): 16 | """3x3 convolution with padding""" 17 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 18 | padding=dilation, groups=groups, bias=False, dilation=dilation) 19 | 20 | 21 | def conv1x1(in_planes, out_planes, stride=1): 22 | """1x1 convolution""" 23 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 24 | 25 | 26 | class BasicBlock(nn.Module): 27 | expansion = 1 28 | __constants__ = ['downsample'] 29 | 30 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 31 | base_width=64, dilation=1, norm_layer=None): 32 | super(BasicBlock, self).__init__() 33 | if norm_layer is None: 34 | norm_layer = nn.BatchNorm2d 35 | if groups != 1 or base_width != 64: 36 | raise ValueError('BasicBlock only supports groups=1 and base_width=64') 37 | if dilation > 1: 38 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock") 39 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1 40 | self.conv1 = conv3x3(inplanes, planes, stride) 41 | self.bn1 = norm_layer(planes) 42 | self.relu = nn.ReLU(inplace=True) 43 | self.conv2 = conv3x3(planes, planes) 44 | self.bn2 = norm_layer(planes) 45 | 
self.downsample = downsample 46 | self.stride = stride 47 | 48 | def forward(self, x): 49 | identity = x 50 | 51 | out = self.conv1(x) 52 | out = self.bn1(out) 53 | out = self.relu(out) 54 | 55 | out = self.conv2(out) 56 | out = self.bn2(out) 57 | 58 | if self.downsample is not None: 59 | identity = self.downsample(x) 60 | 61 | out += identity 62 | out = self.relu(out) 63 | 64 | return out 65 | 66 | 67 | class Bottleneck(nn.Module): 68 | expansion = 4 69 | 70 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 71 | base_width=64, dilation=1, norm_layer=None): 72 | super(Bottleneck, self).__init__() 73 | if norm_layer is None: 74 | norm_layer = nn.BatchNorm2d 75 | width = int(planes * (base_width / 64.)) * groups 76 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1 77 | self.conv1 = conv1x1(inplanes, width) 78 | self.bn1 = norm_layer(width) 79 | self.conv2 = conv3x3(width, width, stride, groups, dilation) 80 | self.bn2 = norm_layer(width) 81 | self.conv3 = conv1x1(width, planes * self.expansion) 82 | self.bn3 = norm_layer(planes * self.expansion) 83 | self.relu = nn.ReLU(inplace=True) 84 | self.downsample = downsample 85 | self.stride = stride 86 | 87 | def forward(self, x): 88 | identity = x 89 | 90 | out = self.conv1(x) 91 | out = self.bn1(out) 92 | out = self.relu(out) 93 | 94 | out = self.conv2(out) 95 | out = self.bn2(out) 96 | out = self.relu(out) 97 | 98 | out = self.conv3(out) 99 | out = self.bn3(out) 100 | 101 | if self.downsample is not None: 102 | identity = self.downsample(x) 103 | 104 | out += identity 105 | out = self.relu(out) 106 | 107 | return out 108 | 109 | 110 | class Flatten(nn.Module): 111 | def forward(self, input): 112 | return input.view(input.size(0), -1) 113 | 114 | 115 | class ResNet(nn.Module): 116 | 117 | def __init__(self, block, layers, feat_dim=512, drop_ratio = 0.5, zero_init_residual=False): 118 | 119 | super(ResNet, self).__init__() 120 | 121 | self.inplanes = 64 122 | ''' 123 | self.input_layer = nn.Sequential( 124 | nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False), 125 | nn.BatchNorm2d(self.inplanes), 126 | nn.ReLU(inplace=True), 127 | nn.MaxPool2d(kernel_size=3, stride=2, padding=1)) 128 | ''' 129 | 130 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False) # TODO 131 | self.bn1 = nn.BatchNorm2d(self.inplanes) 132 | self.relu = nn.ReLU(inplace=True) 133 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 134 | self.layer1 = self._make_layer(block, 64, layers[0]) 135 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 136 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 137 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) 138 | 139 | # After the last conv-layer, BN-Dropout-FC-BN to get the final 512-D embedding feature 140 | self.output_layer = nn.Sequential( 141 | nn.BatchNorm2d(512 * block.expansion), 142 | nn.Dropout(p=drop_ratio), 143 | Flatten(), 144 | nn.Linear(int(512 * block.expansion * 7 * 7), int(feat_dim)), # size / 16 145 | nn.BatchNorm1d(int(feat_dim))) 146 | 147 | for m in self.modules(): 148 | if isinstance(m, nn.Conv2d): 149 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 150 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 151 | nn.init.constant_(m.weight, 1) 152 | nn.init.constant_(m.bias, 0) 153 | 154 | if zero_init_residual: 155 | for m in self.modules(): 156 | if isinstance(m, Bottleneck): 157 | 
nn.init.constant_(m.bn3.weight, 0) 158 | elif isinstance(m, BasicBlock): 159 | nn.init.constant_(m.bn2.weight, 0) 160 | 161 | 162 | def _make_layer(self, block, planes, blocks, stride=1): 163 | 164 | downsample = None 165 | if stride != 1 or self.inplanes != planes * block.expansion: 166 | downsample = nn.Sequential( 167 | conv1x1(self.inplanes, planes * block.expansion, stride), 168 | nn.BatchNorm2d(planes * block.expansion), 169 | ) 170 | 171 | layers = [] 172 | layers.append(block(self.inplanes, planes, stride, downsample)) 173 | self.inplanes = planes * block.expansion 174 | for _ in range(1, blocks): 175 | layers.append(block(self.inplanes, planes)) 176 | 177 | return nn.Sequential(*layers) 178 | 179 | def forward(self, x): 180 | 181 | x = self.conv1(x) 182 | x = self.bn1(x) 183 | x = self.relu(x) 184 | x = self.maxpool(x) 185 | x = self.layer1(x) 186 | x = self.layer2(x) 187 | x = self.layer3(x) 188 | x = self.layer4(x) 189 | x = self.output_layer(x) 190 | 191 | return x 192 | 193 | 194 | def resnet_zoo(backbone = 'resnet18', feat_dim = 512, drop_ratio = 0.5): 195 | 196 | version_dict = { 197 | 'resnet18' : [2, 2, 2, 2], 198 | 'resnet34' : [3, 4, 6, 3], 199 | 'resnet50' : [3, 4, 6, 3], 200 | 'resnet101': [3, 4, 23,3], 201 | 'resnet152': [3, 8, 36,3], 202 | } 203 | if backbone == 'resnet18' or backbone == 'resnet34': 204 | block = BasicBlock 205 | else: 206 | block = Bottleneck 207 | 208 | return ResNet(block, version_dict[backbone], feat_dim, drop_ratio) 209 | 210 | 211 | 212 | if __name__ == "__main__": 213 | 214 | input = torch.Tensor(1, 3, 112, 112) 215 | model = resnet_zoo('resnet18') 216 | flops, params = thop.profile(model, inputs=(input, )) 217 | flops, params = thop.clever_format([flops, params], "%.3f") 218 | print(flops, params) 219 | -------------------------------------------------------------------------------- /model/.ipynb_checkpoints/uncertainty_head-checkpoint.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | #-*- coding:utf-8 -*- 3 | """ 4 | Created on 2020/04/23 5 | author: lujie 6 | """ 7 | import torch 8 | import numpy as np 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | from torch.nn import Parameter 12 | 13 | from IPython import embed 14 | 15 | class UncertaintyHead(nn.Module): 16 | ''' Evaluate the log(sigma^2) ''' 17 | 18 | def __init__(self, in_feat = 512): 19 | 20 | super(UncertaintyHead, self).__init__() 21 | self.fc1 = Parameter(torch.Tensor(in_feat, in_feat)) 22 | self.bn1 = nn.BatchNorm1d(in_feat, affine=True) 23 | self.relu = nn.ReLU(in_feat) 24 | self.fc2 = Parameter(torch.Tensor(in_feat, in_feat)) 25 | self.bn2 = nn.BatchNorm1d(in_feat, affine=False) 26 | self.gamma = Parameter(torch.Tensor([1.0])) 27 | self.beta = Parameter(torch.Tensor([0.0])) # default = -7.0 28 | 29 | nn.init.kaiming_normal_(self.fc1) 30 | nn.init.kaiming_normal_(self.fc2) 31 | 32 | 33 | def forward(self, x): 34 | x = x.view(x.size(0), -1) 35 | x = self.relu(self.bn1(F.linear(x, F.normalize(self.fc1)))) 36 | x = self.bn2(F.linear(x, F.normalize(self.fc2))) # 2*log(sigma) 37 | x = self.gamma * x + self.beta 38 | x = torch.log(1e-6 + torch.exp(x)) # log(sigma^2) 39 | return x 40 | 41 | 42 | if __name__ == "__main__": 43 | 44 | unh = UncertaintyHead(in_feat=3) 45 | 46 | mu_data = np.array([[-1.7847768 , -1.0991699 , 1.4248079 ], 47 | [ 1.0405252 , 0.35788524, 0.7338794 ], 48 | [ 1.0620259 , 2.1341069 , -1.0100055 ], 49 | [-0.00963581, 0.39570177, -1.5577421 ], 50 | [-1.064951 , -1.1261107 , 
-1.4181522 ], 51 | [ 1.008275 , -0.84791195, 0.3006532 ], 52 | [ 0.31099692, -0.32650718, -0.60247767]]) 53 | 54 | muX = torch.from_numpy(mu_data).float() 55 | log_sigma_sq = unh(muX) 56 | print(log_sigma_sq) 57 | -------------------------------------------------------------------------------- /model/.ipynb_checkpoints/uncertainty_tf-checkpoint.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | #-*- coding:utf-8 -*- 3 | 4 | from __future__ import absolute_import 5 | from __future__ import division 6 | from __future__ import print_function 7 | 8 | import math 9 | import numpy as np 10 | import tensorflow as tf 11 | import tensorflow.contrib.slim as slim 12 | 13 | from IPython import embed 14 | 15 | batch_norm_params = { 16 | 'decay': 0.995, 17 | 'epsilon': 0.001, 18 | 'center': True, 19 | 'scale': True, 20 | 'updates_collections': None, 21 | 'variables_collections': [ tf.GraphKeys.TRAINABLE_VARIABLES ], 22 | } 23 | 24 | batch_norm_params_sigma = { 25 | 'decay': 0.995, 26 | 'epsilon': 0.001, 27 | 'center': False, 28 | 'scale': False, 29 | 'updates_collections': None, 30 | 'variables_collections': [ tf.GraphKeys.TRAINABLE_VARIABLES ],} 31 | 32 | def scale_and_shift(x, gamma_init=1.0, beta_init=0.0): 33 | num_channels = x.shape[-1].value 34 | with tf.variable_scope('scale_and_shift'): 35 | gamma = tf.get_variable('alpha', (), 36 | initializer=tf.constant_initializer(gamma_init), 37 | regularizer=slim.l2_regularizer(0.0), 38 | dtype=tf.float32) 39 | beta = tf.get_variable('gamma', (), 40 | initializer=tf.constant_initializer(beta_init), 41 | dtype=tf.float32) 42 | x = gamma * x + beta 43 | 44 | return x 45 | 46 | 47 | def inference(inputs, embedding_size, phase_train, 48 | weight_decay=5e-4, reuse=None, scope='UncertaintyModule'): 49 | with slim.arg_scope([slim.fully_connected], 50 | weights_regularizer=slim.l2_regularizer(weight_decay), 51 | activation_fn=tf.nn.relu): 52 | with tf.variable_scope(scope, [inputs], reuse=reuse): 53 | with slim.arg_scope([slim.batch_norm, slim.dropout], 54 | is_training=phase_train): 55 | print('UncertaintyModule input shape:', [dim.value for dim in inputs.shape]) 56 | 57 | net = slim.flatten(inputs) 58 | 59 | # embed() # stray debug hook; left enabled it would halt every forward pass 60 | net = slim.fully_connected(net, embedding_size, scope='fc1', 61 | normalizer_fn=slim.batch_norm, normalizer_params=batch_norm_params, 62 | activation_fn=tf.nn.relu) 63 | 64 | 65 | log_sigma_sq = slim.fully_connected(net, embedding_size, scope='fc_log_sigma_sq', 66 | normalizer_fn=slim.batch_norm, normalizer_params=batch_norm_params_sigma, 67 | activation_fn=None) 68 | 69 | # Share the gamma and beta for all dimensions 70 | log_sigma_sq = scale_and_shift(log_sigma_sq, 1e-4, -7.0) 71 | 72 | # Add epsilon for sigma_sq for numerical stability 73 | log_sigma_sq = tf.log(1e-6 + tf.exp(log_sigma_sq)) 74 | 75 | return log_sigma_sq 76 | 77 | 78 | 79 | if __name__ == "__main__": 80 | 81 | mu_data = np.array([[-1.7847768 , -1.0991699 , 1.4248079 ], 82 | [ 1.0405252 , 0.35788524, 0.7338794 ], 83 | [ 1.0620259 , 2.1341069 , -1.0100055 ], 84 | [-0.00963581, 0.39570177, -1.5577421 ], 85 | [-1.064951 , -1.1261107 , -1.4181522 ], 86 | [ 1.008275 , -0.84791195, 0.3006532 ], 87 | [ 0.31099692, -0.32650718, -0.60247767]], dtype=np.float64) 88 | 89 | muX = tf.convert_to_tensor(mu_data) 90 | log_sigma_sq = inference(muX, 3, True) 91 | print(log_sigma_sq)
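The TensorFlow module above mirrors the PyTorch UncertaintyHead: an MLP on the backbone's pre-embedding feature that emits log(sigma^2) through a shared scale-and-shift. A sketch of the inference wiring using this repo's own modules (untrained weights, so the values are meaningless; only the shapes matter):

import torch
from model import MobileFace, UncertaintyHead

backbone = MobileFace(512, 0.4).eval()
unh = UncertaintyHead(512).eval()
with torch.no_grad():
    face = torch.randn(1, 3, 112, 112)  # stand-in for an aligned 112x112 crop
    mu, feat = backbone(face)           # embedding plus the 512-channel feature map
    log_sig_sq = unh(feat)              # per-dimension log(sigma^2)
print(mu.shape, log_sig_sq.shape)       # torch.Size([1, 512]) for both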
-------------------------------------------------------------------------------- /model/__init__.py: -------------------------------------------------------------------------------- 1 | from .mls_loss import MLSLoss 2 | from .spherenet import SphereNet20 3 | from .resnet import resnet_zoo 4 | from .mobilenet import MobileFace 5 | from .uncertainty_head import UncertaintyHead 6 | -------------------------------------------------------------------------------- /model/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ontheway361/pfe-pytorch/02b708bfc37f961b5d2f036b32bbe45a53c7e288/model/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /model/__pycache__/faceloss.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ontheway361/pfe-pytorch/02b708bfc37f961b5d2f036b32bbe45a53c7e288/model/__pycache__/faceloss.cpython-37.pyc -------------------------------------------------------------------------------- /model/__pycache__/fc_layer.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ontheway361/pfe-pytorch/02b708bfc37f961b5d2f036b32bbe45a53c7e288/model/__pycache__/fc_layer.cpython-37.pyc -------------------------------------------------------------------------------- /model/__pycache__/mls_loss.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ontheway361/pfe-pytorch/02b708bfc37f961b5d2f036b32bbe45a53c7e288/model/__pycache__/mls_loss.cpython-37.pyc -------------------------------------------------------------------------------- /model/__pycache__/mobilenet.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ontheway361/pfe-pytorch/02b708bfc37f961b5d2f036b32bbe45a53c7e288/model/__pycache__/mobilenet.cpython-37.pyc -------------------------------------------------------------------------------- /model/__pycache__/rescbam.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ontheway361/pfe-pytorch/02b708bfc37f961b5d2f036b32bbe45a53c7e288/model/__pycache__/rescbam.cpython-37.pyc -------------------------------------------------------------------------------- /model/__pycache__/resnet.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ontheway361/pfe-pytorch/02b708bfc37f961b5d2f036b32bbe45a53c7e288/model/__pycache__/resnet.cpython-37.pyc -------------------------------------------------------------------------------- /model/__pycache__/spherenet.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ontheway361/pfe-pytorch/02b708bfc37f961b5d2f036b32bbe45a53c7e288/model/__pycache__/spherenet.cpython-37.pyc -------------------------------------------------------------------------------- /model/__pycache__/uncertainty_head.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ontheway361/pfe-pytorch/02b708bfc37f961b5d2f036b32bbe45a53c7e288/model/__pycache__/uncertainty_head.cpython-37.pyc -------------------------------------------------------------------------------- /model/mls_loss.py:
-------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | #-*- coding:utf-8 -*- 3 | """ 4 | Created on 2020/04/23 5 | author: lujie 6 | """ 7 | import torch 8 | import numpy as np 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | from IPython import embed 12 | 13 | class MLSLoss(nn.Module): 14 | 15 | def __init__(self, mean = False): 16 | 17 | super(MLSLoss, self).__init__() 18 | self.mean = mean 19 | 20 | def negMLS(self, mu_X, sigma_sq_X): 21 | 22 | if self.mean: 23 | XX = torch.mul(mu_X, mu_X).sum(dim=1, keepdim=True) 24 | YY = torch.mul(mu_X.T, mu_X.T).sum(dim=0, keepdim=True) 25 | XY = torch.mm(mu_X, mu_X.T) 26 | mu_diff = XX + YY - 2 * XY 27 | sig_sum = sigma_sq_X.mean(dim=1, keepdim=True) + sigma_sq_X.T.sum(dim=0, keepdim=True) 28 | diff = mu_diff / (1e-8 + sig_sum) + mu_X.size(1) * torch.log(sig_sum) 29 | return diff 30 | else: 31 | mu_diff = mu_X.unsqueeze(1) - mu_X.unsqueeze(0) 32 | sig_sum = sigma_sq_X.unsqueeze(1) + sigma_sq_X.unsqueeze(0) 33 | diff = torch.mul(mu_diff, mu_diff) / (1e-10 + sig_sum) + torch.log(sig_sum) # BUG 34 | diff = diff.sum(dim=2, keepdim=False) 35 | return diff 36 | 37 | def forward(self, mu_X, log_sigma_sq, gty): 38 | 39 | mu_X = F.normalize(mu_X) # if mu_X was not normalized by l2 40 | non_diag_mask = (1 - torch.eye(mu_X.size(0))).int() 41 | if gty.device.type == 'cuda': 42 | non_diag_mask = non_diag_mask.cuda(0) 43 | sig_X = torch.exp(log_sigma_sq) 44 | loss_mat = self.negMLS(mu_X, sig_X) 45 | gty_mask = (torch.eq(gty[:, None], gty[None, :])).int() 46 | pos_mask = (non_diag_mask * gty_mask) > 0 47 | pos_loss = loss_mat[pos_mask].mean() 48 | return pos_loss 49 | 50 | 51 | if __name__ == "__main__": 52 | 53 | mls = MLSLoss(mean=False) 54 | gty = torch.Tensor([1, 2, 3, 2, 3, 3, 2]) 55 | mu_data = np.array([[-1.7847768 , -1.0991699 , 1.4248079 ], 56 | [ 1.0405252 , 0.35788524, 0.7338794 ], 57 | [ 1.0620259 , 2.1341069 , -1.0100055 ], 58 | [-0.00963581, 0.39570177, -1.5577421 ], 59 | [-1.064951 , -1.1261107 , -1.4181522 ], 60 | [ 1.008275 , -0.84791195, 0.3006532 ], 61 | [ 0.31099692, -0.32650718, -0.60247767]]) 62 | 63 | si_data = np.array([[-0.28463233, -2.5517333 , 1.4781238 ], 64 | [-0.10505871, -0.31454122, -0.29844758], 65 | [-1.3067418 , 0.48718405, 0.6779812 ], 66 | [ 2.024449 , -1.3925922 , -1.6178994 ], 67 | [-0.08328865, -0.396574 , 1.0888542 ], 68 | [ 0.13096762, -0.14382902, 0.2695235 ], 69 | [ 0.5405067 , -0.67946523, -0.8433032 ]]) 70 | 71 | muX = torch.from_numpy(mu_data) 72 | siX = torch.from_numpy(si_data) 73 | print(muX.shape) 74 | diff = mls(muX, siX, gty) 75 | print(diff) 76 | -------------------------------------------------------------------------------- /model/mobilenet.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | #-*- coding:utf-8 -*- 3 | 4 | import thop 5 | import math 6 | import torch 7 | from torch import nn 8 | from IPython import embed 9 | 10 | class BottleNeck(nn.Module): 11 | def __init__(self, inp, oup, stride, expansion): 12 | 13 | super(BottleNeck, self).__init__() 14 | 15 | self.connect = stride == 1 and inp == oup 16 | self.conv = nn.Sequential( 17 | # 1*1 conv 18 | nn.Conv2d(inp, inp * expansion, 1, 1, 0, bias=False), 19 | nn.BatchNorm2d(inp * expansion), 20 | nn.PReLU(inp * expansion), 21 | 22 | # 3*3 depth wise conv 23 | nn.Conv2d(inp * expansion, inp * expansion, 3, stride, 1, groups=inp * expansion, bias=False), 24 | nn.BatchNorm2d(inp * expansion), 25 | nn.PReLU(inp * 
expansion), 26 | 27 | # 1*1 conv 28 | nn.Conv2d(inp * expansion, oup, 1, 1, 0, bias=False), 29 | nn.BatchNorm2d(oup)) 30 | 31 | 32 | def forward(self, x): 33 | 34 | out = self.conv(x) 35 | if self.connect: 36 | return x + out 37 | else: 38 | return out 39 | 40 | 41 | class ConvBlock(nn.Module): 42 | 43 | def __init__(self, inp, oup, k, s, p, dw=False, linear=False): 44 | super(ConvBlock, self).__init__() 45 | self.linear = linear 46 | if dw: 47 | self.conv = nn.Conv2d(inp, oup, k, s, p, groups=inp, bias=False) 48 | else: 49 | self.conv = nn.Conv2d(inp, oup, k, s, p, bias=False) 50 | 51 | self.bn = nn.BatchNorm2d(oup) 52 | if not linear: 53 | self.prelu = nn.PReLU(oup) 54 | 55 | def forward(self, x): 56 | x = self.conv(x) 57 | x = self.bn(x) 58 | if self.linear: 59 | return x 60 | else: 61 | return self.prelu(x) 62 | 63 | 64 | class Flatten(nn.Module): 65 | def forward(self, input): 66 | return input.view(input.size(0), -1) 67 | 68 | 69 | class MobileFace(nn.Module): 70 | def __init__(self, feat_dim = 512, drop_ratio = 0.5): 71 | 72 | super(MobileFace, self).__init__() 73 | 74 | self.conv1 = ConvBlock(3, 64, 3, 2, 1) 75 | self.dwconv1 = ConvBlock(64, 64, 3, 1, 1, dw=True) 76 | self.cur_channel = 64 #t, c, n, s 77 | self.block_setting = [[2, 64, 5, 2], 78 | [4, 128, 1, 2], 79 | [2, 128, 6, 1], 80 | [4, 128, 1, 2], 81 | [2, 128, 2, 1]] 82 | self.layers = self._make_layer() 83 | 84 | self.conv2 = ConvBlock(128, 512, 1, 1, 0) 85 | self.linear7 = ConvBlock(512, 512, 7, 1, 0, dw=True, linear=True) # CORE 86 | self.linear1 = ConvBlock(512, feat_dim, 1, 1, 0, linear=True) 87 | 88 | # Arcface : BN-Dropout-FC-BN to get the final 512-D embedding feature 89 | ''' 90 | self.output_layer = nn.Sequential( 91 | nn.BatchNorm2d(512), 92 | nn.Dropout(p=drop_ratio), 93 | Flatten(), 94 | nn.Linear(512 * 7 * 7, feat_dim), # size / 16 95 | nn.BatchNorm1d(feat_dim)) 96 | ''' 97 | 98 | for m in self.modules(): 99 | if isinstance(m, nn.Conv2d): 100 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 101 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 102 | elif isinstance(m, nn.BatchNorm2d): 103 | m.weight.data.fill_(1) 104 | m.bias.data.zero_() 105 | 106 | 107 | def _make_layer(self, block = BottleNeck): 108 | layers = [] 109 | for t, c, n, s in self.block_setting: 110 | for i in range(n): 111 | if i == 0: 112 | layers.append(block(self.cur_channel, c, s, t)) 113 | else: 114 | layers.append(block(self.cur_channel, c, 1, t)) 115 | self.cur_channel = c 116 | 117 | return nn.Sequential(*layers) 118 | 119 | 120 | def forward(self, x): 121 | 122 | x = self.conv1(x) 123 | x = self.dwconv1(x) 124 | x = self.layers(x) 125 | x = self.conv2(x) 126 | # x = self.output_layer(x) 127 | x = self.linear7(x) 128 | sig_x = x 129 | x = self.linear1(x) 130 | x = x.view(x.size(0), -1) 131 | return x, sig_x 132 | 133 | 134 | if __name__ == "__main__": 135 | input = torch.Tensor(1, 3, 112, 112) 136 | model = MobileFace() # MobileFace takes no use_cbam argument; passing one raised a TypeError 137 | flops, params = thop.profile(model, inputs=(input, )) 138 | flops, params = thop.clever_format([flops, params], "%.3f") 139 | print(flops, params) 140 | # model = model.eval() 141 | # out = model(input) 142 | # print(out.shape) 143 | 144 | -------------------------------------------------------------------------------- /model/resnet.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | #-*- coding:utf-8 -*- 3 | """ 4 | Created on 2020/04/24 5 | author: lujie 6 | """ 7 | 8 | 9 | import thop 10 | import torch 11 | import torch.nn as nn 12 | 13 | from IPython import embed 14 | 15 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): 16 | """3x3 convolution with padding""" 17 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 18 | padding=dilation, groups=groups, bias=False, dilation=dilation) 19 | 20 | 21 | def conv1x1(in_planes, out_planes, stride=1): 22 | """1x1 convolution""" 23 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 24 | 25 | 26 | class BasicBlock(nn.Module): 27 | expansion = 1 28 | __constants__ = ['downsample'] 29 | 30 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 31 | base_width=64, dilation=1, norm_layer=None): 32 | super(BasicBlock, self).__init__() 33 | if norm_layer is None: 34 | norm_layer = nn.BatchNorm2d 35 | if groups != 1 or base_width != 64: 36 | raise ValueError('BasicBlock only supports groups=1 and base_width=64') 37 | if dilation > 1: 38 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock") 39 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1 40 | self.conv1 = conv3x3(inplanes, planes, stride) 41 | self.bn1 = norm_layer(planes) 42 | self.relu = nn.ReLU(inplace=True) 43 | self.conv2 = conv3x3(planes, planes) 44 | self.bn2 = norm_layer(planes) 45 | self.downsample = downsample 46 | self.stride = stride 47 | 48 | def forward(self, x): 49 | identity = x 50 | 51 | out = self.conv1(x) 52 | out = self.bn1(out) 53 | out = self.relu(out) 54 | 55 | out = self.conv2(out) 56 | out = self.bn2(out) 57 | 58 | if self.downsample is not None: 59 | identity = self.downsample(x) 60 | 61 | out += identity 62 | out = self.relu(out) 63 | 64 | return out 65 | 66 | 67 | class Bottleneck(nn.Module): 68 | expansion = 4 69 | 70 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 71 | base_width=64, dilation=1, norm_layer=None): 72 | super(Bottleneck, self).__init__() 73 | if norm_layer is None: 74 | norm_layer = nn.BatchNorm2d 75 | width = int(planes * (base_width / 64.)) *
groups 76 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1 77 | self.conv1 = conv1x1(inplanes, width) 78 | self.bn1 = norm_layer(width) 79 | self.conv2 = conv3x3(width, width, stride, groups, dilation) 80 | self.bn2 = norm_layer(width) 81 | self.conv3 = conv1x1(width, planes * self.expansion) 82 | self.bn3 = norm_layer(planes * self.expansion) 83 | self.relu = nn.ReLU(inplace=True) 84 | self.downsample = downsample 85 | self.stride = stride 86 | 87 | def forward(self, x): 88 | identity = x 89 | 90 | out = self.conv1(x) 91 | out = self.bn1(out) 92 | out = self.relu(out) 93 | 94 | out = self.conv2(out) 95 | out = self.bn2(out) 96 | out = self.relu(out) 97 | 98 | out = self.conv3(out) 99 | out = self.bn3(out) 100 | 101 | if self.downsample is not None: 102 | identity = self.downsample(x) 103 | 104 | out += identity 105 | out = self.relu(out) 106 | 107 | return out 108 | 109 | 110 | class Flatten(nn.Module): 111 | def forward(self, input): 112 | return input.view(input.size(0), -1) 113 | 114 | 115 | class ResNet(nn.Module): 116 | 117 | def __init__(self, block, layers, feat_dim=512, drop_ratio = 0.5, zero_init_residual=False): 118 | 119 | super(ResNet, self).__init__() 120 | 121 | self.inplanes = 64 122 | ''' 123 | self.input_layer = nn.Sequential( 124 | nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False), 125 | nn.BatchNorm2d(self.inplanes), 126 | nn.ReLU(inplace=True), 127 | nn.MaxPool2d(kernel_size=3, stride=2, padding=1)) 128 | ''' 129 | 130 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False) # TODO 131 | self.bn1 = nn.BatchNorm2d(self.inplanes) 132 | self.relu = nn.ReLU(inplace=True) 133 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 134 | self.layer1 = self._make_layer(block, 64, layers[0]) 135 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 136 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 137 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) 138 | 139 | # After the last conv-layer, BN-Dropout-FC-BN to get the final 512-D embedding feature 140 | self.output_layer = nn.Sequential( 141 | nn.BatchNorm2d(512 * block.expansion), 142 | nn.Dropout(p=drop_ratio), 143 | Flatten(), 144 | nn.Linear(int(512 * block.expansion * 7 * 7), int(feat_dim)), # size / 16 145 | nn.BatchNorm1d(int(feat_dim))) 146 | 147 | for m in self.modules(): 148 | if isinstance(m, nn.Conv2d): 149 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 150 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 151 | nn.init.constant_(m.weight, 1) 152 | nn.init.constant_(m.bias, 0) 153 | 154 | if zero_init_residual: 155 | for m in self.modules(): 156 | if isinstance(m, Bottleneck): 157 | nn.init.constant_(m.bn3.weight, 0) 158 | elif isinstance(m, BasicBlock): 159 | nn.init.constant_(m.bn2.weight, 0) 160 | 161 | 162 | def _make_layer(self, block, planes, blocks, stride=1): 163 | 164 | downsample = None 165 | if stride != 1 or self.inplanes != planes * block.expansion: 166 | downsample = nn.Sequential( 167 | conv1x1(self.inplanes, planes * block.expansion, stride), 168 | nn.BatchNorm2d(planes * block.expansion), 169 | ) 170 | 171 | layers = [] 172 | layers.append(block(self.inplanes, planes, stride, downsample)) 173 | self.inplanes = planes * block.expansion 174 | for _ in range(1, blocks): 175 | layers.append(block(self.inplanes, planes)) 176 | 177 | return nn.Sequential(*layers) 178 | 179 | def forward(self, x): 180 | 181 | x = 
self.conv1(x) 182 | x = self.bn1(x) 183 | x = self.relu(x) 184 | x = self.maxpool(x) 185 | x = self.layer1(x) 186 | x = self.layer2(x) 187 | x = self.layer3(x) 188 | x = self.layer4(x) 189 | x = self.output_layer(x) 190 | 191 | return x 192 | 193 | 194 | def resnet_zoo(backbone = 'resnet18', feat_dim = 512, drop_ratio = 0.5): 195 | 196 | version_dict = { 197 | 'resnet18' : [2, 2, 2, 2], 198 | 'resnet34' : [3, 4, 6, 3], 199 | 'resnet50' : [3, 4, 6, 3], 200 | 'resnet101': [3, 4, 23,3], 201 | 'resnet152': [3, 8, 36,3], 202 | } 203 | if backbone == 'resnet18' or backbone == 'resnet34': 204 | block = BasicBlock 205 | else: 206 | block = Bottleneck 207 | 208 | return ResNet(block, version_dict[backbone], feat_dim, drop_ratio) 209 | 210 | 211 | 212 | if __name__ == "__main__": 213 | 214 | input = torch.Tensor(1, 3, 112, 112) 215 | model = resnet_zoo('resnet18') 216 | flops, params = thop.profile(model, inputs=(input, )) 217 | flops, params = thop.clever_format([flops, params], "%.3f") 218 | print(flops, params) 219 | -------------------------------------------------------------------------------- /model/spherenet.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | #-*- coding:utf-8 -*- 3 | """ 4 | Created on 2020/04/24 5 | author: lujie 6 | """ 7 | 8 | import thop 9 | import torch 10 | import torch.nn as nn 11 | 12 | from IPython import embed 13 | 14 | class SphereNet20(nn.Module): 15 | 16 | def __init__(self): 17 | 18 | super(SphereNet20, self).__init__() 19 | 20 | #input = B*3*112*96 21 | self.conv1_1 = nn.Conv2d(3,64,3,2,1) #=>B*64*56*48 22 | self.relu1_1 = nn.PReLU(64) 23 | self.conv1_2 = nn.Conv2d(64,64,3,1,1) 24 | self.relu1_2 = nn.PReLU(64) 25 | self.conv1_3 = nn.Conv2d(64,64,3,1,1) 26 | self.relu1_3 = nn.PReLU(64) 27 | 28 | self.conv2_1 = nn.Conv2d(64,128,3,2,1) #=>B*128*28*24 29 | self.relu2_1 = nn.PReLU(128) 30 | self.conv2_2 = nn.Conv2d(128,128,3,1,1) 31 | self.relu2_2 = nn.PReLU(128) 32 | self.conv2_3 = nn.Conv2d(128,128,3,1,1) 33 | self.relu2_3 = nn.PReLU(128) 34 | 35 | self.conv2_4 = nn.Conv2d(128,128,3,1,1) #=>B*128*28*24 36 | self.relu2_4 = nn.PReLU(128) 37 | self.conv2_5 = nn.Conv2d(128,128,3,1,1) 38 | self.relu2_5 = nn.PReLU(128) 39 | 40 | 41 | self.conv3_1 = nn.Conv2d(128,256,3,2,1) #=>B*256*14*12 42 | self.relu3_1 = nn.PReLU(256) 43 | self.conv3_2 = nn.Conv2d(256,256,3,1,1) 44 | self.relu3_2 = nn.PReLU(256) 45 | self.conv3_3 = nn.Conv2d(256,256,3,1,1) 46 | self.relu3_3 = nn.PReLU(256) 47 | 48 | self.conv3_4 = nn.Conv2d(256,256,3,1,1) #=>B*256*14*12 49 | self.relu3_4 = nn.PReLU(256) 50 | self.conv3_5 = nn.Conv2d(256,256,3,1,1) 51 | self.relu3_5 = nn.PReLU(256) 52 | 53 | self.conv3_6 = nn.Conv2d(256,256,3,1,1) #=>B*256*14*12 54 | self.relu3_6 = nn.PReLU(256) 55 | self.conv3_7 = nn.Conv2d(256,256,3,1,1) 56 | self.relu3_7 = nn.PReLU(256) 57 | 58 | self.conv3_8 = nn.Conv2d(256,256,3,1,1) #=>B*256*14*12 59 | self.relu3_8 = nn.PReLU(256) 60 | self.conv3_9 = nn.Conv2d(256,256,3,1,1) 61 | self.relu3_9 = nn.PReLU(256) 62 | 63 | self.conv4_1 = nn.Conv2d(256,512,3,2,1) #=>B*512*7*6 64 | self.relu4_1 = nn.PReLU(512) 65 | self.conv4_2 = nn.Conv2d(512,512,3,1,1) 66 | self.relu4_2 = nn.PReLU(512) 67 | self.conv4_3 = nn.Conv2d(512,512,3,1,1) 68 | self.relu4_3 = nn.PReLU(512) 69 | 70 | self.fc5 = nn.Linear(512*7*7, 512) 71 | 72 | 73 | def forward(self, x): 74 | 75 | x = self.relu1_1(self.conv1_1(x)) 76 | x = x + self.relu1_3(self.conv1_3(self.relu1_2(self.conv1_2(x)))) 77 | 78 | x = self.relu2_1(self.conv2_1(x)) 79 | x = x + 
self.relu2_3(self.conv2_3(self.relu2_2(self.conv2_2(x)))) 80 | x = x + self.relu2_5(self.conv2_5(self.relu2_4(self.conv2_4(x)))) 81 | 82 | x = self.relu3_1(self.conv3_1(x)) 83 | x = x + self.relu3_3(self.conv3_3(self.relu3_2(self.conv3_2(x)))) 84 | x = x + self.relu3_5(self.conv3_5(self.relu3_4(self.conv3_4(x)))) 85 | x = x + self.relu3_7(self.conv3_7(self.relu3_6(self.conv3_6(x)))) 86 | x = x + self.relu3_9(self.conv3_9(self.relu3_8(self.conv3_8(x)))) 87 | 88 | x = self.relu4_1(self.conv4_1(x)) 89 | x = x + self.relu4_3(self.conv4_3(self.relu4_2(self.conv4_2(x)))) 90 | 91 | x = x.view(x.size(0),-1) 92 | x = self.fc5(x) 93 | 94 | return x 95 | 96 | if __name__ == "__main__": 97 | 98 | input = torch.Tensor(1, 3, 112, 112) 99 | model = SphereNet20() 100 | flops, params = thop.profile(model, inputs=(input, )) 101 | flops, params = thop.clever_format([flops, params], "%.3f") 102 | print(flops, params) 103 | -------------------------------------------------------------------------------- /model/uncertainty_head.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | #-*- coding:utf-8 -*- 3 | """ 4 | Created on 2020/04/23 5 | author: lujie 6 | """ 7 | import torch 8 | import numpy as np 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | from torch.nn import Parameter 12 | 13 | from IPython import embed 14 | 15 | class UncertaintyHead(nn.Module): 16 | ''' Evaluate the log(sigma^2) ''' 17 | 18 | def __init__(self, in_feat = 512): 19 | 20 | super(UncertaintyHead, self).__init__() 21 | self.fc1 = Parameter(torch.Tensor(in_feat, in_feat)) 22 | self.bn1 = nn.BatchNorm1d(in_feat, affine=True) 23 | self.relu = nn.ReLU(in_feat) 24 | self.fc2 = Parameter(torch.Tensor(in_feat, in_feat)) 25 | self.bn2 = nn.BatchNorm1d(in_feat, affine=False) 26 | self.gamma = Parameter(torch.Tensor([1.0])) 27 | self.beta = Parameter(torch.Tensor([0.0])) # default = -7.0 28 | 29 | nn.init.kaiming_normal_(self.fc1) 30 | nn.init.kaiming_normal_(self.fc2) 31 | 32 | 33 | def forward(self, x): 34 | x = x.view(x.size(0), -1) 35 | x = self.relu(self.bn1(F.linear(x, F.normalize(self.fc1)))) 36 | x = self.bn2(F.linear(x, F.normalize(self.fc2))) # 2*log(sigma) 37 | x = self.gamma * x + self.beta 38 | x = torch.log(1e-6 + torch.exp(x)) # log(sigma^2) 39 | return x 40 | 41 | 42 | if __name__ == "__main__": 43 | 44 | unh = UncertaintyHead(in_feat=3) 45 | 46 | mu_data = np.array([[-1.7847768 , -1.0991699 , 1.4248079 ], 47 | [ 1.0405252 , 0.35788524, 0.7338794 ], 48 | [ 1.0620259 , 2.1341069 , -1.0100055 ], 49 | [-0.00963581, 0.39570177, -1.5577421 ], 50 | [-1.064951 , -1.1261107 , -1.4181522 ], 51 | [ 1.008275 , -0.84791195, 0.3006532 ], 52 | [ 0.31099692, -0.32650718, -0.60247767]]) 53 | 54 | muX = torch.from_numpy(mu_data).float() 55 | log_sigma_sq = unh(muX) 56 | print(log_sigma_sq) 57 | -------------------------------------------------------------------------------- /test_img/.ipynb_checkpoints/Pedro_Solbes_0003-checkpoint.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ontheway361/pfe-pytorch/02b708bfc37f961b5d2f036b32bbe45a53c7e288/test_img/.ipynb_checkpoints/Pedro_Solbes_0003-checkpoint.jpg -------------------------------------------------------------------------------- /test_img/.ipynb_checkpoints/Pedro_Solbes_0004-checkpoint.jpg: -------------------------------------------------------------------------------- 
--------------------------------------------------------------------------------
/test_img/.ipynb_checkpoints/Pedro_Solbes_0003-checkpoint.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Ontheway361/pfe-pytorch/02b708bfc37f961b5d2f036b32bbe45a53c7e288/test_img/.ipynb_checkpoints/Pedro_Solbes_0003-checkpoint.jpg
--------------------------------------------------------------------------------
/test_img/.ipynb_checkpoints/Pedro_Solbes_0004-checkpoint.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Ontheway361/pfe-pytorch/02b708bfc37f961b5d2f036b32bbe45a53c7e288/test_img/.ipynb_checkpoints/Pedro_Solbes_0004-checkpoint.jpg
--------------------------------------------------------------------------------
/test_img/.ipynb_checkpoints/Zico_0003-checkpoint.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Ontheway361/pfe-pytorch/02b708bfc37f961b5d2f036b32bbe45a53c7e288/test_img/.ipynb_checkpoints/Zico_0003-checkpoint.jpg
--------------------------------------------------------------------------------
/test_img/Pedro_Solbes_0003.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Ontheway361/pfe-pytorch/02b708bfc37f961b5d2f036b32bbe45a53c7e288/test_img/Pedro_Solbes_0003.jpg
--------------------------------------------------------------------------------
/test_img/Pedro_Solbes_0004.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Ontheway361/pfe-pytorch/02b708bfc37f961b5d2f036b32bbe45a53c7e288/test_img/Pedro_Solbes_0004.jpg
--------------------------------------------------------------------------------
/test_img/Zico_0003.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Ontheway361/pfe-pytorch/02b708bfc37f961b5d2f036b32bbe45a53c7e288/test_img/Zico_0003.jpg
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | 
4 | import os
5 | import sys
6 | import time
7 | import torch
8 | import random
9 | import numpy as np
10 | import torchvision
11 | import torch.nn as nn
12 | import torch.optim as optim
13 | from sklearn import metrics
14 | import torch.nn.functional as F
15 | from torch.utils.data import DataLoader
16 | 
17 | import model as mlib
18 | import dataset as dlib
19 | from config import training_args
20 | 
21 | torch.backends.cudnn.benchmark = True   # was misspelled "bencmark", which silently set a nonexistent attribute
22 | # os.environ["CUDA_VISIBLE_DEVICES"] = "6,7" # TODO
23 | 
24 | from IPython import embed
25 | 
26 | 
27 | def my_collate_fn(batch):
28 | 
29 |     imgs, gtys = [], []
30 |     for pid_imgs, gty in batch:
31 |         imgs.extend(pid_imgs)
32 |         gtys.extend([gty] * len(pid_imgs))
33 |     return (torch.stack(imgs, dim=0), torch.Tensor(gtys).long())
34 | 
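my_collate_fn flattens a batch of per-identity image lists into one tensor and repeats each label once per image. A worked example of the expected shapes (the 4-crops-per-identity count is illustrative, not fixed by the code):

    batch = [([torch.zeros(3, 112, 112)] * 4, 0),   # identity 0: 4 aligned crops
             ([torch.zeros(3, 112, 112)] * 4, 1)]   # identity 1: 4 aligned crops
    imgs, gtys = my_collate_fn(batch)
    # imgs.shape == (8, 3, 112, 112); gtys == tensor([0, 0, 0, 0, 1, 1, 1, 1])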
35 | 
36 | 
37 | class MetricFace(dlib.VerifyFace):
38 | 
39 |     def __init__(self, args):
40 | 
41 |         dlib.VerifyFace.__init__(self, args)
42 |         self.args    = args
43 |         self.model   = dict()
44 |         self.data    = dict()
45 |         self.softmax = torch.nn.Softmax(dim=1)
46 |         self.device  = args.use_gpu and torch.cuda.is_available()   # bool flag, not a torch.device
47 | 
48 | 
49 |     def _report_settings(self):
50 |         ''' Report the settings '''
51 | 
52 |         sep = '-' * 16   # renamed from "str" to avoid shadowing the builtin
53 |         print('%sEnvironment Versions%s' % (sep, sep))
54 |         print("- Python     : {}".format(sys.version.strip().split('|')[0]))
55 |         print("- PyTorch    : {}".format(torch.__version__))
56 |         print("- TorchVision: {}".format(torchvision.__version__))
57 |         print("- USE_GPU    : {}".format(self.device))
58 |         print('-' * 52)
59 | 
60 | 
61 |     def _model_loader(self):
62 | 
63 |         self.model['backbone'] = mlib.MobileFace(self.args.in_feats, self.args.drop_ratio)
64 |         # self.model['backbone'] = mlib.iresnet_zoo(self.args.backbone, drop_ratio=self.args.drop_ratio, use_se = self.args.use_se)  # SEBlock
65 |         # self.model['backbone'] = mlib.resnet_zoo(self.args.backbone, drop_ratio=self.args.drop_ratio)  # ResBlock
66 |         # self.model['metric'] = mlib.FullyConnectedLayer(self.args)
67 |         self.model['uncertain'] = mlib.UncertaintyHead(self.args.in_feats)
68 |         # self.model['criterion'] = mlib.FaceLoss(self.args)
69 |         self.model['criterion'] = mlib.MLSLoss(mean=False)
70 | 
71 |         if self.args.freeze_backbone:
72 |             for p in self.model['backbone'].parameters():
73 |                 p.requires_grad = False
74 | 
75 |         self.model['optimizer'] = torch.optim.SGD(
76 |             [ # {'params': self.model['backbone'].parameters()},
77 |               # {'params': self.model['metric'].parameters()},
78 |               {'params': self.model['uncertain'].parameters()}],
79 |             lr=self.args.base_lr,
80 |             weight_decay=self.args.weight_decay,
81 |             momentum=0.9,
82 |             nesterov=True)
83 |         self.model['scheduler'] = torch.optim.lr_scheduler.MultiStepLR(
84 |             self.model['optimizer'], milestones=self.args.lr_adjust, gamma=self.args.gamma)
85 |         if self.device:
86 |             self.model['backbone']  = self.model['backbone'].cuda()
87 |             self.model['uncertain'] = self.model['uncertain'].cuda()
88 |             # self.model['metric'] = self.model['metric'].cuda()
89 |             self.model['criterion'] = self.model['criterion'].cuda()
90 | 
91 |         if self.device and len(self.args.gpu_ids) > 1:
92 |             self.model['backbone']  = torch.nn.DataParallel(self.model['backbone'], device_ids=self.args.gpu_ids)
93 |             self.model['uncertain'] = torch.nn.DataParallel(self.model['uncertain'], device_ids=self.args.gpu_ids)
94 |             # self.model['metric'] = torch.nn.DataParallel(self.model['metric'], device_ids=self.args.gpu_ids)
95 |             print('Running in multi-GPU (DataParallel) mode ...')
96 |         elif self.device:
97 |             print('Running in single-GPU mode ...')
98 |         else:
99 |             print('Running in CPU mode ...')
100 | 
101 |         if len(self.args.resume) > 2:   # a non-trivial checkpoint path was given
102 |             checkpoint = torch.load(self.args.resume, map_location=lambda storage, loc: storage)
103 |             # self.args.start_epoch = checkpoint['epoch']
104 |             self.model['backbone'].load_state_dict(checkpoint['backbone'])
105 |             # self.model['uncertain'].load_state_dict(checkpoint['uncertain'])
106 |             # self.model['metric'].load_state_dict(checkpoint['metric'])
107 |             print('Resuming training from epoch %3d ...' % self.args.start_epoch)
108 |         print('Model loading finished ...')
109 | 
110 | 
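Note that the SGD parameter groups above deliberately exclude the backbone (its group is commented out), and freeze_backbone additionally disables its gradients — the PFE recipe of fitting log(sigma^2) on top of a fixed embedding. A sanity check one could run after _model_loader(), sketched with a hypothetical name m standing for the self.model dict:

    trainable = sum(p.numel() for p in m['uncertain'].parameters() if p.requires_grad)
    frozen    = sum(p.numel() for p in m['backbone'].parameters()  if p.requires_grad)
    print(trainable, frozen)   # frozen should be 0 when args.freeze_backbone is True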
111 |     def _data_loader(self):
112 | 
113 |         self.data['train_loader'] = DataLoader(
114 |             dlib.CASIAWebFacePFE(self.args, mode='train'),
115 |             batch_size=self.args.batch_size,
116 |             shuffle=True,
117 |             collate_fn=my_collate_fn,
118 |         )
119 |         # self.data['lfw'] = dlib.LFW(self.args)  # TODO
120 |         print('Data loading finished ...')
121 | 
122 | 
123 |     def _model_train(self, epoch = 0):
124 | 
125 |         self.model['backbone'].eval()    # backbone stays frozen; only the uncertainty head trains
126 |         # self.model['metric'].train()
127 |         self.model['uncertain'].train()
128 | 
129 |         loss_recorder, batch_acc = [], []
130 |         for idx, (img, gty) in enumerate(self.data['train_loader']):
131 | 
132 |             img.requires_grad = False
133 |             gty.requires_grad = False
134 | 
135 |             if self.device:
136 |                 img = img.cuda()
137 |                 gty = gty.cuda()
138 | 
139 |             feature, sig_feat = self.model['backbone'](img)  # TODO
140 |             # output = self.model['metric'](feature, gty)
141 |             # loss = self.model['criterion'](output, gty)
142 |             log_sig_sq = self.model['uncertain'](sig_feat)
143 |             loss = self.model['criterion'](feature, log_sig_sq, gty)
144 |             self.model['optimizer'].zero_grad()
145 |             loss.backward()
146 |             self.model['optimizer'].step()
147 |             # predy = np.argmax(output.data.cpu().numpy(), axis=1)  # TODO
148 |             # it_acc = np.mean((predy == gty.data.cpu().numpy()).astype(int))
149 |             # batch_acc.append(it_acc)
150 |             loss_recorder.append(loss.item())
151 |             if (idx + 1) % self.args.print_freq == 0:
152 |                 print('epoch : %2d|%2d, iter : %4d|%4d, loss : %.4f' % \
153 |                       (epoch, self.args.end_epoch, idx+1, len(self.data['train_loader']), np.mean(loss_recorder)))
154 |                 '''
155 |                 print('epoch : %2d|%2d, iter : %4d|%4d, loss : %.4f, batch_ave_acc : %.4f' % \
156 |                       (epoch, self.args.end_epoch, idx+1, len(self.data['train_loader']), \
157 |                       np.mean(loss_recorder), np.mean(batch_acc)))
158 |                 '''
159 |         train_loss = np.mean(loss_recorder)
160 |         print('train_loss : %.4f' % train_loss)
161 |         return train_loss
162 | 
163 | 
164 |     def _verify_lfw(self):
165 | 
166 |         self._eval_lfw()
167 | 
168 |         self._k_folds()
169 | 
170 |         best_thresh, lfw_acc = self._eval_runner()
171 | 
172 |         return best_thresh, lfw_acc
173 | 
174 | 
175 |     def _main_loop(self):
176 | 
177 |         if not os.path.exists(self.args.save_to):
178 |             os.mkdir(self.args.save_to)
179 | 
180 |         max_lfw_acc, min_train_loss = 0.0, 100
181 |         for epoch in range(self.args.start_epoch, self.args.end_epoch + 1):
182 | 
183 |             start_time = time.time()
184 | 
185 |             train_loss = self._model_train(epoch)
186 |             self.model['scheduler'].step()
187 |             # lfw_thresh, lfw_acc = self._verify_lfw()
188 | 
189 |             end_time = time.time()
190 |             print('Single epoch cost time : %.2f mins' % ((end_time - start_time)/60))
191 | 
192 |             if min_train_loss > train_loss:
193 | 
194 |                 print('%snew SOTA was found%s' % ('*'*16, '*'*16))
195 |                 # max_lfw_acc = max(max_lfw_acc, lfw_acc)
196 |                 min_train_loss = train_loss
197 |                 filename = os.path.join(self.args.save_to, 'sota.pth.tar')
198 |                 torch.save({
199 |                     'epoch'     : epoch,
200 |                     'backbone'  : self.model['backbone'].state_dict(),
201 |                     'uncertain' : self.model['uncertain'].state_dict(),
202 |                     'train_loss': min_train_loss,
203 |                 }, filename)
204 | 
205 |             if epoch % self.args.save_freq == 0:
206 |                 filename = 'epoch_%d_train_loss_%.4f.pth.tar' % (epoch, train_loss)
207 |                 savename = os.path.join(self.args.save_to, filename)
208 |                 torch.save({
209 |                     'epoch'     : epoch,
210 |                     'backbone'  : self.model['backbone'].state_dict(),
211 |                     'uncertain' : self.model['uncertain'].state_dict(),
212 |                     'train_loss': train_loss,
213 |                 }, savename)
214 | 
215 |             if self.args.is_debug:
216 |                 break
217 | 
218 | 
219 |     def train_runner(self):
220 | 
221 |         self._report_settings()
222 | 
223 |         self._model_loader()
224 | 
225 |         self._data_loader()
226 | 
227 |         self._main_loop()
228 | 
229 | 
230 | if __name__ == "__main__":
231 | 
232 |     faceu = MetricFace(training_args())
233 |     faceu.train_runner()
234 | 
--------------------------------------------------------------------------------
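A minimal launch sketch (hypothetical values; the attribute names follow what train.py reads from config.training_args, such as freeze_backbone and resume, and the checkpoint key 'backbone' matches what _main_loop saves):

    from config import training_args

    args = training_args()
    args.freeze_backbone = True               # PFE stage 2: train the head on a frozen backbone
    args.resume = 'checkpoint/sota.pth.tar'   # checkpoint holding the pretrained 'backbone' weights
    MetricFace(args).train_runner()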