├── figure1.png
├── loss.py
├── LICENES
├── README.md
└── fc_layers.py

/figure1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xiaoboCASIA/SV-X-Softmax/HEAD/figure1.png
--------------------------------------------------------------------------------
/loss.py:
--------------------------------------------------------------------------------
import torch
import torch.nn.functional as F


# Loss functions
def loss_final(pred, label, loss_type, criteria, save_rate=0.9, gamma=2.0):
    if loss_type == 'Softmax':
        # Plain cross-entropy through the user-supplied criterion.
        loss_final = criteria(pred, label)
    elif loss_type == 'FocalLoss':
        assert (gamma >= 0)
        # Per-sample cross-entropy (no reduction).
        ce = F.cross_entropy(pred, label, reduction='none')
        pt = torch.exp(-ce)  # probability assigned to the true class
        loss = (1 - pt) ** gamma * ce
        loss_final = loss.mean()
    elif loss_type == 'HardMining':
        batch_size = pred.shape[0]
        loss = F.cross_entropy(pred, label, reduction='none')
        ind_sorted = torch.argsort(-loss)  # from big to small
        num_saved = int(save_rate * batch_size)
        ind_update = ind_sorted[:num_saved]  # keep only the hardest samples
        loss_final = torch.sum(F.cross_entropy(pred[ind_update], label[ind_update]))
    else:
        raise Exception('unknown loss type!!')

    return loss_final

--------------------------------------------------------------------------------
/LICENES:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2020 Xiaobo Wang, Shifeng Zhang and Shuo Wang

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# SV-X-Softmax
SV-X-Softmax is a new loss function that adaptively emphasizes mis-classified feature vectors to guide
discriminative feature learning. For more details, refer to our paper: \
"Mis-classified Vector Guided Softmax Loss for Face Recognition" [arxiv](http://www.cbsr.ia.ac.cn/users/xiaobowang/papers/AAAI2020.pdf) \
In AAAI Conference on Artificial Intelligence (AAAI) 2020, **Oral** Presentation. \
![alt](https://github.com/xiaoboCASIA/SV-X-Softmax/blob/master/figure1.png) \
We thank **Shifeng Zhang** and **Shuo Wang** for their helpful discussions and suggestions.

## Introduction
This is a **PyTorch** implementation of our SV-X-Softmax loss. The repository contains fc_layers.py and loss.py.
The old version: "Support Vector Guided Softmax Loss for Face Recognition" [arxiv](https://arxiv.org/abs/1812.11317) \
is implemented in **Caffe** and does not remove the overlaps between the training set and the test set, so the performance comparison
in the old version may not be fair.

## Dataset
The original training set is [MS-Celeb-1M-v1c](http://trillionpairs.deepglint.com/overview), which contains 86,876 identities.
However, in face recognition it is very important to perform open-set evaluation, i.e., there should be no overlapping identities
between the training set and the test set. To that end, we use the publicly available [script](https://github.com/happynear/FaceDatasets) to remove 14,186 identities from the training
set MS-Celeb-1M-v1c. For clarity, we denote the refined training dataset as
[MS-Celeb-1M-v1c-R](https://github.com/xiaoboCASIA/SV-X-Softmax/blob/master/deepglint_unoverlap_list.txt).

## Architecture
The AttentionNet-IRSE network used in our paper is derived from the papers:
1. "Residual Attention Network for Image Classification" [Paper](https://arxiv.org/abs/1704.06904?source=post_page---------------------------)
2. "ArcFace: Additive Angular Margin Loss for Deep Face Recognition" [Paper](https://arxiv.org/abs/1801.07698)


## Others
1. Note that our new loss assumes well-cleaned training sets. When facing new datasets, one may need to clean them first.
2. On small test sets such as LFW, the improvement may not be obvious. It may be better to compare on MegaFace or other larger-scale test sets.
3. Both training from scratch and finetuning are OK. One may try more training strategies.
4. We won **1st** place in the [RLQ](https://www.forlq.org/) challenge (all four tracks) and **2nd** place in the [LFR](http://www.insightface-challenge.com/overview) challenge (deepglint-large track).

## Citation
If you find SV-X-Softmax helps your research, please cite our paper:
--------------------------------------------------------------------------------
/fc_layers.py:
--------------------------------------------------------------------------------
from torch.nn import Module, Parameter
import torch
import torch.nn.functional as F
import math


class FC(Module):
    def __init__(self, fc_type='MV-AM', margin=0.35, t=0.2, scale=32, embedding_size=512, num_class=72690,
                 easy_margin=True):
        super(FC, self).__init__()
        self.weight = Parameter(torch.Tensor(embedding_size, num_class))
        # initial kernel
        self.weight.data.uniform_(-1, 1).renorm_(2, 1, 1e-5).mul_(1e5)
        self.margin = margin
        self.t = t
        self.easy_margin = easy_margin
        self.scale = scale
        self.fc_type = fc_type
        self.cos_m = math.cos(margin)
        self.sin_m = math.sin(margin)

        # SphereFace only: lambda annealing schedule and multiple-angle formulas,
        # where margin_formula[m](cos(theta)) = cos(m * theta)
        self.iter = 0
        self.base = 1000
        self.alpha = 0.0001
        self.power = 2
        self.lambda_min = 5.0
        self.margin_formula = [
            lambda x: x ** 0,
            lambda x: x ** 1,
            lambda x: 2 * x ** 2 - 1,
            lambda x: 4 * x ** 3 - 3 * x,
            lambda x: 8 * x ** 4 - 8 * x ** 2 + 1,
            lambda x: 16 * x ** 5 - 20 * x ** 3 + 5 * x
        ]

    def forward(self, x, label):  # x (M, K), w (K, N), y = xw (M, N); x must already be L2-normalized, w is normalized below
        kernel_norm = F.normalize(self.weight, dim=0)
        cos_theta = torch.mm(x, kernel_norm)
        cos_theta = cos_theta.clamp(-1, 1)  # for numerical stability
        batch_size = label.size(0)
        gt = cos_theta[torch.arange(0, batch_size), label].view(-1, 1)  # ground truth score

        if self.fc_type == 'FC':
            final_gt = gt
        elif self.fc_type == 'SphereFace':
            self.iter += 1
            self.cur_lambda = max(self.lambda_min, self.base * (1 + self.alpha * self.iter) ** (-1 * self.power))
            cos_theta_m = self.margin_formula[int(self.margin)](gt)  # cos(margin * theta)
            theta = gt.data.acos()
            k = ((self.margin * theta) / math.pi).floor()
            phi_theta = ((-1.0) ** k) * cos_theta_m - 2 * k
            final_gt = (self.cur_lambda * gt + phi_theta) / (1 + self.cur_lambda)
        elif self.fc_type == 'AM':  # cosface
            if self.easy_margin:
                final_gt = torch.where(gt > 0, gt - self.margin, gt)
            else:
                final_gt = gt - self.margin
        elif self.fc_type == 'Arc':  # arcface
            sin_theta = torch.sqrt(1.0 - torch.pow(gt, 2))
            cos_theta_m = gt * self.cos_m - sin_theta * self.sin_m  # cos(theta + margin)
            if self.easy_margin:
                final_gt = torch.where(gt > 0, cos_theta_m, gt)
            else:
                final_gt = cos_theta_m
        elif self.fc_type == 'MV-AM':
            # classes whose cosine exceeds the margin-adjusted ground truth are treated as hard
            mask = cos_theta > gt - self.margin
            hard_vector = cos_theta[mask]
            cos_theta[mask] = (self.t + 1.0) * hard_vector + self.t  # adaptive
            # cos_theta[mask] = hard_vector + self.t #fixed
            if self.easy_margin:
                final_gt = torch.where(gt > 0, gt - self.margin, gt)
            else:
                final_gt = gt - self.margin
        elif self.fc_type == 'MV-Arc':
            sin_theta = torch.sqrt(1.0 - torch.pow(gt, 2))
            cos_theta_m = gt * self.cos_m - sin_theta * self.sin_m  # cos(theta + margin)

            mask = cos_theta > cos_theta_m
            hard_vector = cos_theta[mask]
            cos_theta[mask] = (self.t + 1.0) * hard_vector + self.t  # adaptive
            # cos_theta[mask] = hard_vector + self.t #fixed
            if self.easy_margin:
                final_gt = torch.where(gt > 0, cos_theta_m, gt)
            else:
                final_gt = cos_theta_m
            # final_gt = torch.where(gt > cos_theta_m, cos_theta_m, gt)
        else:
            raise Exception('unknown fc type!')

        # write the margin-adjusted ground-truth scores back, then scale all logits
        cos_theta.scatter_(1, label.data.view(-1, 1), final_gt)
        cos_theta *= self.scale
        return cos_theta
--------------------------------------------------------------------------------
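The 'MV-AM' branch of `FC.forward` can be traced by hand on a single row of cosines. The following is a minimal plain-Python sketch of that branch, not code from the repository: the function name `mv_am_logits` and the sample values are hypothetical, and it assumes the "adaptive" re-weighting variant (the one that is not commented out in `fc_layers.py`).

```python
def mv_am_logits(cos_row, label, margin=0.35, t=0.2, scale=32.0, easy_margin=True):
    """Hypothetical single-sample re-expression of the 'MV-AM' branch.

    cos_row: cosines between one feature and every class weight.
    label:   index of the ground-truth class.
    """
    gt = cos_row[label]
    threshold = gt - margin  # margin-adjusted ground-truth cosine

    # With easy_margin, the margin is applied only when the cosine is positive.
    if easy_margin and gt <= 0:
        final_gt = gt
    else:
        final_gt = threshold

    logits = []
    for j, c in enumerate(cos_row):
        if j == label:
            logits.append(final_gt)
        elif c > threshold:
            # Hard (mis-classified) class: amplify its cosine.
            logits.append((t + 1.0) * c + t)
        else:
            logits.append(c)
    return [scale * v for v in logits]


if __name__ == "__main__":
    # Class 1 exceeds the threshold 0.5 - 0.25, so it is re-weighted.
    print(mv_am_logits([0.5, 0.75, -0.5], label=0, margin=0.25, t=0.5, scale=2.0))
    # -> [0.5, 3.25, -1.0]
```

Only the non-target class above the threshold is changed, which raises its logit and therefore its share of the softmax loss for this sample.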