├── .gitignore
├── README.md
├── backbones
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── __init__.cpython-38.pyc
│   │   ├── activation.cpython-38.pyc
│   │   ├── iresnet.cpython-38.pyc
│   │   └── utils.cpython-38.pyc
│   ├── activation.py
│   ├── iresnet.py
│   └── utils.py
├── config
│   ├── __pycache__
│   │   └── config.cpython-38.pyc
│   └── config.py
├── eval
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── __init__.cpython-38.pyc
│   │   └── verification.cpython-38.pyc
│   ├── evaluation.py
│   └── verification.py
├── images
│   └── margins.png
├── requirement.txt
├── run.sh
├── train.py
└── utils
    ├── __init__.py
    ├── __pycache__
    │   ├── __init__.cpython-38.pyc
    │   ├── countFLOPS.cpython-38.pyc
    │   ├── utils_callbacks.cpython-38.pyc
    │   └── utils_logging.cpython-38.pyc
    ├── countFLOPS.py
    ├── dataset.py
    ├── losses.py
    ├── modelFLOPS.py
    ├── utils_amp.py
    ├── utils_callbacks.py
    └── utils_logging.py

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | output/
2 | output_features/
3 | output_features_clean/
4 | cmc_roc_results/
5 | .idea
6 | .DS_Store

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
 1 | 
 2 | 
 3 | ## This is the official repository of the paper:
 4 | #### ElasticFace: Elastic Margin Loss for Deep Face Recognition
 5 | Paper on arxiv: [arxiv](https://arxiv.org/pdf/2109.09416.pdf)
 6 | #### *** Accepted at CVPR Workshops 2022 ***
 7 | ![evaluation](https://raw.githubusercontent.com/fdbtrs/ElasticFace/main/images/margins.png)
 8 | 
 9 | 
10 | 
11 | 
12 | | Model | Log file | Pretrained model | Checkpoint |
13 | | ------------- | ------------- |------------- | ------------- |
14 | | ElasticFace-Arc |[log file](https://drive.google.com/file/d/1jGm6rHh-jJ40c34u5eXBgAhR3u4KHblH/view?usp=sharing) |[pretrained model](https://drive.google.com/drive/folders/1q3ws_BQLmgXyiy2msvHummXq4pRqc1rx?usp=sharing) | 295672backbone.pth |
15 | | ElasticFace-Cos |[log file](https://drive.google.com/file/d/1XgfEQgEabinH--VhIusWQ8Js43vz1vK0/view?usp=sharing) |[pretrained model](https://drive.google.com/drive/folders/1ZiLLZXQ1jMzFwMGhYjtMwcdHmuedQb-2?usp=sharing) | 295672backbone.pth |
16 | | ElasticFace-Arc+ |[log file](https://drive.google.com/file/d/1cWphaOqgCtmJ8zgVfnMXh0mVl6EqQZNd/view?usp=sharing) |[pretrained model](https://drive.google.com/drive/folders/1sf-fNV5CeSpWuFj6Hkwp7Js8SBXjbPo_?usp=sharing) | 295672backbone.pth |
17 | | ElasticFace-Cos+ |[log file](https://drive.google.com/file/d/1aqCN5yfzgGijJLg2hcrsW3fvwHeHNu6W/view?usp=sharing) |[pretrained model](https://drive.google.com/drive/folders/19LXrjVNt60JBZP7JqsvOSWMwGLGrcJl5?usp=sharing) | 295672backbone.pth |
18 | 
19 | Evaluation results:
20 | See: [Papers with Code](https://paperswithcode.com/paper/elasticface-elastic-margin-loss-for-deep-face)
21 | 
22 | 
23 | 
24 | ### Face recognition model training
25 | Model training:
26 | In the paper, we employ MS1MV2 as the training dataset, which can be downloaded from InsightFace (MS1M-ArcFace in DataZoo).
27 | Download the MS1MV2 dataset from [insightface](https://github.com/deepinsight/insightface/tree/master/recognition/_datasets_) and strictly follow its license terms.
28 | 
29 | Unzip the dataset and place it in the data folder.
30 | Set config.output and config.loss in config/config.py.
31 | 
32 | 
33 | 
34 | All code has been trained and tested with PyTorch 1.7.1.
35 | 
36 | ## Face recognition evaluation
37 | ##### Evaluation on LFW, AgeDB-30, CPLFW, CALFW and CFP-FP:
38 | 1. Download the data from their official webpages.
39 | 2. Alternatively, the evaluation datasets are included in the training dataset package as .bin files.
40 | 3. Set config.rec to the dataset folder, e.g. data/faces_emore (see the snippet after this list).
41 | 4. Set config.val_targets to the list of evaluation datasets.
42 | 5. Download a pretrained model from the links in the table above.
43 | 6. Set config.output to the path of the pretrained model weights.
44 | 7. Run eval/evaluation.py.
45 | 8. The evaluation results over all checkpoints are written to the log file in config.output (test1.log, as set in eval/evaluation.py).
46 | 
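For reference, the snippet below shows the config values these steps touch (a minimal sketch; the paths are placeholders for your local setup):

```python
# config/config.py -- values to set before running eval/evaluation.py
# (paths below are placeholders)
config.rec = "data/faces_emore"                  # folder containing lfw.bin, agedb_30.bin, ...
config.val_targets = ["lfw", "cfp_fp", "agedb_30", "calfw", "cplfw"]
config.network = "iresnet100"                    # must match the downloaded checkpoint
config.output = "output/R100_ElasticArcFace"     # folder holding 295672backbone.pth
```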
47 | ### To-do 
48 | - [x] Add evaluation script
49 | 
50 | 
51 | If you use any of the code provided in this repository, please cite the following paper:
52 | ## Citation
53 | ```
54 | @InProceedings{Boutros_2022_CVPR,
55 |     author    = {Boutros, Fadi and Damer, Naser and Kirchbuchner, Florian and Kuijper, Arjan},
56 |     title     = {ElasticFace: Elastic Margin Loss for Deep Face Recognition},
57 |     booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR) Workshops},
58 |     month     = {June},
59 |     year      = {2022},
60 |     pages     = {1578-1587}
61 | }
62 | 
63 | 
64 | ```
65 | 
66 | 
67 | ## License
68 | 
69 | ```
70 | This project is licensed under the terms of the Attribution-NonCommercial-ShareAlike 4.0
71 | International (CC BY-NC-SA 4.0) license.
72 | Copyright (c) 2021 Fraunhofer Institute for Computer Graphics Research IGD Darmstadt
73 | ```
74 | 

--------------------------------------------------------------------------------
/backbones/__init__.py:
--------------------------------------------------------------------------------
1 | 
2 | 

--------------------------------------------------------------------------------
/backbones/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fdbtrs/ElasticFace/5496ec0643bfdcd5e6b53b885aa4e919e26ce442/backbones/__pycache__/__init__.cpython-38.pyc

--------------------------------------------------------------------------------
/backbones/__pycache__/activation.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fdbtrs/ElasticFace/5496ec0643bfdcd5e6b53b885aa4e919e26ce442/backbones/__pycache__/activation.cpython-38.pyc

--------------------------------------------------------------------------------
/backbones/__pycache__/iresnet.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fdbtrs/ElasticFace/5496ec0643bfdcd5e6b53b885aa4e919e26ce442/backbones/__pycache__/iresnet.cpython-38.pyc

--------------------------------------------------------------------------------
/backbones/__pycache__/utils.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fdbtrs/ElasticFace/5496ec0643bfdcd5e6b53b885aa4e919e26ce442/backbones/__pycache__/utils.cpython-38.pyc

--------------------------------------------------------------------------------
/backbones/activation.py:
--------------------------------------------------------------------------------
 1 | import torch.nn as nn
 2 | import torch.nn.functional as F
 3 | 
 4 | import torch
 5 | 
 6 | from inspect import isfunction
 7 | 
 8 | class Identity(nn.Module):
 9 |     """
10 |     Identity block.
11 | """ 12 | def __init__(self): 13 | super(Identity, self).__init__() 14 | 15 | def forward(self, x): 16 | return x 17 | 18 | def __repr__(self): 19 | return '{name}()'.format(name=self.__class__.__name__) 20 | class HSigmoid(nn.Module): 21 | """ 22 | Approximated sigmoid function, so-called hard-version of sigmoid from 'Searching for MobileNetV3,' 23 | https://arxiv.org/abs/1905.02244. 24 | """ 25 | def forward(self, x): 26 | return F.relu6(x + 3.0, inplace=True) / 6.0 27 | 28 | 29 | class Swish(nn.Module): 30 | """ 31 | Swish activation function from 'Searching for Activation Functions,' https://arxiv.org/abs/1710.05941. 32 | """ 33 | def forward(self, x): 34 | return x * torch.sigmoid(x) 35 | class HSwish(nn.Module): 36 | """ 37 | H-Swish activation function from 'Searching for MobileNetV3,' https://arxiv.org/abs/1905.02244. 38 | Parameters: 39 | ---------- 40 | inplace : bool 41 | Whether to use inplace version of the module. 42 | """ 43 | def __init__(self, inplace=False): 44 | super(HSwish, self).__init__() 45 | self.inplace = inplace 46 | 47 | def forward(self, x): 48 | return x * F.relu6(x + 3.0, inplace=self.inplace) / 6.0 49 | 50 | 51 | def get_activation_layer(activation,param): 52 | """ 53 | Create activation layer from string/function. 54 | Parameters: 55 | ---------- 56 | activation : function, or str, or nn.Module 57 | Activation function or name of activation function. 58 | Returns: 59 | ------- 60 | nn.Module 61 | Activation layer. 62 | """ 63 | assert (activation is not None) 64 | if isfunction(activation): 65 | return activation() 66 | elif isinstance(activation, str): 67 | if activation == "relu": 68 | return nn.ReLU(inplace=True) 69 | elif activation =="prelu": 70 | return nn.PReLU(param) 71 | elif activation == "relu6": 72 | return nn.ReLU6(inplace=True) 73 | elif activation == "swish": 74 | return Swish() 75 | elif activation == "hswish": 76 | return HSwish(inplace=True) 77 | elif activation == "sigmoid": 78 | return nn.Sigmoid() 79 | elif activation == "hsigmoid": 80 | return HSigmoid() 81 | elif activation == "identity": 82 | return Identity() 83 | else: 84 | raise NotImplementedError() 85 | else: 86 | assert (isinstance(activation, nn.Module)) 87 | return activation -------------------------------------------------------------------------------- /backbones/iresnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | __all__ = ['iresnet18', 'iresnet34', 'iresnet50', 'iresnet100'] 5 | 6 | from utils.countFLOPS import _calc_width, count_model_flops 7 | 8 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): 9 | """3x3 convolution with padding""" 10 | return nn.Conv2d(in_planes, 11 | out_planes, 12 | kernel_size=3, 13 | stride=stride, 14 | padding=dilation, 15 | groups=groups, 16 | bias=False, 17 | dilation=dilation) 18 | 19 | 20 | def conv1x1(in_planes, out_planes, stride=1): 21 | """1x1 convolution""" 22 | return nn.Conv2d(in_planes, 23 | out_planes, 24 | kernel_size=1, 25 | stride=stride, 26 | bias=False) 27 | class SEModule(nn.Module): 28 | def __init__(self, channels, reduction): 29 | super(SEModule, self).__init__() 30 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 31 | self.fc1 = nn.Conv2d(channels, channels // reduction, kernel_size=1, padding=0, bias=False) 32 | self.relu = nn.ReLU(inplace=True) 33 | self.fc2 = nn.Conv2d(channels // reduction, channels, kernel_size=1, padding=0, bias=False) 34 | self.sigmoid = nn.Sigmoid() 35 | 36 | def forward(self, x): 37 | input 
= x 38 | x = self.avg_pool(x) 39 | x = self.fc1(x) 40 | x = self.relu(x) 41 | x = self.fc2(x) 42 | x = self.sigmoid(x) 43 | 44 | return input * x 45 | 46 | class IBasicBlock(nn.Module): 47 | expansion = 1 48 | def __init__(self, inplanes, planes, stride=1, downsample=None, 49 | groups=1, base_width=64, dilation=1,use_se=False): 50 | super(IBasicBlock, self).__init__() 51 | if groups != 1 or base_width != 64: 52 | raise ValueError('BasicBlock only supports groups=1 and base_width=64') 53 | if dilation > 1: 54 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock") 55 | self.bn1 = nn.BatchNorm2d(inplanes, eps=1e-05,) 56 | self.conv1 = conv3x3(inplanes, planes) 57 | self.bn2 = nn.BatchNorm2d(planes, eps=1e-05,) 58 | self.prelu = nn.PReLU(planes) 59 | self.conv2 = conv3x3(planes, planes, stride) 60 | self.bn3 = nn.BatchNorm2d(planes, eps=1e-05,) 61 | self.downsample = downsample 62 | self.stride = stride 63 | self.use_se=use_se 64 | if (use_se): 65 | self.se_block=SEModule(planes,16) 66 | 67 | def forward(self, x): 68 | identity = x 69 | out = self.bn1(x) 70 | out = self.conv1(out) 71 | out = self.bn2(out) 72 | out = self.prelu(out) 73 | out = self.conv2(out) 74 | out = self.bn3(out) 75 | if(self.use_se): 76 | out=self.se_block(out) 77 | if self.downsample is not None: 78 | identity = self.downsample(x) 79 | out += identity 80 | return out 81 | 82 | 83 | class IResNet(nn.Module): 84 | fc_scale = 7 * 7 85 | def __init__(self, 86 | block, layers, dropout=0, num_features=512, zero_init_residual=False, 87 | groups=1, width_per_group=64, replace_stride_with_dilation=None, use_se=False): 88 | super(IResNet, self).__init__() 89 | self.inplanes = 64 90 | self.dilation = 1 91 | self.use_se=use_se 92 | if replace_stride_with_dilation is None: 93 | replace_stride_with_dilation = [False, False, False] 94 | if len(replace_stride_with_dilation) != 3: 95 | raise ValueError("replace_stride_with_dilation should be None " 96 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) 97 | self.groups = groups 98 | self.base_width = width_per_group 99 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False) 100 | self.bn1 = nn.BatchNorm2d(self.inplanes, eps=1e-05) 101 | self.prelu = nn.PReLU(self.inplanes) 102 | self.layer1 = self._make_layer(block, 64, layers[0], stride=2 ,use_se=self.use_se) 103 | self.layer2 = self._make_layer(block, 104 | 128, 105 | layers[1], 106 | stride=2, 107 | dilate=replace_stride_with_dilation[0],use_se=self.use_se) 108 | self.layer3 = self._make_layer(block, 109 | 256, 110 | layers[2], 111 | stride=2, 112 | dilate=replace_stride_with_dilation[1] ,use_se=self.use_se) 113 | self.layer4 = self._make_layer(block, 114 | 512, 115 | layers[3], 116 | stride=2, 117 | dilate=replace_stride_with_dilation[2] ,use_se=self.use_se) 118 | self.bn2 = nn.BatchNorm2d(512 * block.expansion, eps=1e-05,) 119 | self.dropout =nn.Dropout(p=dropout, inplace=True) # 7x7x 512 120 | self.fc = nn.Linear(512 * block.expansion * self.fc_scale, num_features) 121 | self.features = nn.BatchNorm1d(num_features, eps=1e-05) 122 | nn.init.constant_(self.features.weight, 1.0) 123 | self.features.weight.requires_grad = False 124 | 125 | for m in self.modules(): 126 | if isinstance(m, nn.Conv2d): 127 | nn.init.normal_(m.weight, 0, 0.1) 128 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 129 | nn.init.constant_(m.weight, 1) 130 | nn.init.constant_(m.bias, 0) 131 | 132 | if zero_init_residual: 133 | for m in self.modules(): 134 | if isinstance(m, IBasicBlock): 
135 | nn.init.constant_(m.bn2.weight, 0) 136 | 137 | def _make_layer(self, block, planes, blocks, stride=1, dilate=False,use_se=False): 138 | downsample = None 139 | previous_dilation = self.dilation 140 | if dilate: 141 | self.dilation *= stride 142 | stride = 1 143 | if stride != 1 or self.inplanes != planes * block.expansion: 144 | downsample = nn.Sequential( 145 | conv1x1(self.inplanes, planes * block.expansion, stride), 146 | nn.BatchNorm2d(planes * block.expansion, eps=1e-05, ), 147 | ) 148 | layers = [] 149 | layers.append( 150 | block(self.inplanes, planes, stride, downsample, self.groups, 151 | self.base_width, previous_dilation,use_se=use_se)) 152 | self.inplanes = planes * block.expansion 153 | for _ in range(1, blocks): 154 | layers.append( 155 | block(self.inplanes, 156 | planes, 157 | groups=self.groups, 158 | base_width=self.base_width, 159 | dilation=self.dilation,use_se=use_se)) 160 | 161 | return nn.Sequential(*layers) 162 | 163 | def forward(self, x): 164 | x = self.conv1(x) 165 | x = self.bn1(x) 166 | x = self.prelu(x) 167 | x = self.layer1(x) 168 | x = self.layer2(x) 169 | x = self.layer3(x) 170 | x = self.layer4(x) 171 | x = self.bn2(x) 172 | x = torch.flatten(x, 1) 173 | x = self.dropout(x) 174 | x = self.fc(x) 175 | x = self.features(x) 176 | return x 177 | 178 | 179 | def _iresnet(arch, block, layers, pretrained, progress, **kwargs): 180 | model = IResNet(block, layers, **kwargs) 181 | if pretrained: 182 | raise ValueError() 183 | return model 184 | 185 | 186 | def iresnet18(pretrained=False, progress=True, **kwargs): 187 | return _iresnet('iresnet18', IBasicBlock, [2, 2, 2, 2], pretrained, 188 | progress, **kwargs) 189 | 190 | 191 | def iresnet34(pretrained=False, progress=True, **kwargs): 192 | return _iresnet('iresnet34', IBasicBlock, [3, 4, 6, 3], pretrained, 193 | progress, **kwargs) 194 | 195 | 196 | def iresnet50(pretrained=False, progress=True, **kwargs): 197 | return _iresnet('iresnet50', IBasicBlock, [3, 4, 14, 3], pretrained, 198 | progress, **kwargs) 199 | 200 | 201 | def iresnet100(pretrained=False, progress=True, **kwargs): 202 | return _iresnet('iresnet100', IBasicBlock, [3, 13, 30, 3], pretrained, 203 | progress, **kwargs) 204 | def _test(): 205 | import torch 206 | 207 | pretrained = False 208 | 209 | models = [ 210 | iresnet100 211 | ] 212 | 213 | for model in models: 214 | 215 | net = model() 216 | print(net) 217 | # net.train() 218 | weight_count = _calc_width(net) 219 | flops=count_model_flops(net) 220 | print("m={}, {}".format(model.__name__, weight_count)) 221 | print("m={}, {}".format(model.__name__, flops)) 222 | net.eval() 223 | 224 | x = torch.randn(1, 3, 112, 112) 225 | 226 | y = net(x) 227 | y.sum().backward() 228 | assert (tuple(y.size()) == (1, 512)) 229 | 230 | 231 | if __name__ == "__main__": 232 | _test() 233 | -------------------------------------------------------------------------------- /backbones/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | 5 | from backbones.activation import get_activation_layer 6 | 7 | 8 | 9 | 10 | class DropBlock2D(nn.Module): 11 | r"""Randomly zeroes 2D spatial blocks of the input tensor. 12 | As described in the paper 13 | `DropBlock: A regularization method for convolutional networks`_ , 14 | dropping whole blocks of feature map allows to remove semantic 15 | information as compared to regular dropout. 16 | Args: 17 | drop_prob (float): probability of an element to be dropped. 
18 | block_size (int): size of the block to drop 19 | Shape: 20 | - Input: `(N, C, H, W)` 21 | - Output: `(N, C, H, W)` 22 | .. _DropBlock: A regularization method for convolutional networks: 23 | https://arxiv.org/abs/1810.12890 24 | """ 25 | 26 | def __init__(self, drop_prob, block_size): 27 | super(DropBlock2D, self).__init__() 28 | 29 | self.drop_prob = drop_prob 30 | self.block_size = block_size 31 | 32 | def forward(self, x): 33 | # shape: (bsize, channels, height, width) 34 | 35 | assert x.dim() == 4, \ 36 | "Expected input with 4 dimensions (bsize, channels, height, width)" 37 | 38 | if not self.training or self.drop_prob == 0.: 39 | return x 40 | else: 41 | # get gamma value 42 | gamma = self._compute_gamma(x) 43 | 44 | # sample mask 45 | mask = (torch.rand(x.shape[0], *x.shape[2:]) < gamma).float() 46 | 47 | # place mask on input device 48 | mask = mask.to(x.device) 49 | 50 | # compute block mask 51 | block_mask = self._compute_block_mask(mask) 52 | 53 | # apply block mask 54 | out = x * block_mask[:, None, :, :] 55 | 56 | # scale output 57 | out = out * block_mask.numel() / block_mask.sum() 58 | 59 | return out 60 | 61 | def _compute_block_mask(self, mask): 62 | block_mask = F.max_pool2d(input=mask[:, None, :, :], 63 | kernel_size=(self.block_size, self.block_size), 64 | stride=(1,1), 65 | padding=self.block_size//2) 66 | if self.block_size % 2 == 0: 67 | block_mask = block_mask[:, :, :-1, :-1] 68 | 69 | block_mask = 1 - block_mask.squeeze(1) 70 | 71 | return block_mask 72 | 73 | def _compute_gamma(self, x): 74 | return self.drop_prob / (self.block_size**2) 75 | 76 | def round_channels(channels, 77 | divisor=8): 78 | """ 79 | Round weighted channel number (make divisible operation). 80 | 81 | Parameters: 82 | ---------- 83 | channels : int or float 84 | Original number of channels. 85 | divisor : int, default 8 86 | Alignment value. 87 | 88 | Returns: 89 | ------- 90 | int 91 | Weighted number of channels. 92 | """ 93 | rounded_channels = max(int(channels + divisor / 2.0) // divisor * divisor, divisor) 94 | if float(rounded_channels) < 0.9 * channels: 95 | rounded_channels += divisor 96 | return rounded_channels 97 | 98 | def conv1x1(in_channels, 99 | out_channels, 100 | stride=1, 101 | groups=1, dilation=1, 102 | bias=False): 103 | """ 104 | Convolution 1x1 layer. 105 | 106 | Parameters: 107 | ---------- 108 | in_channels : int 109 | Number of input channels. 110 | out_channels : int 111 | Number of output channels. 112 | stride : int or tuple/list of 2 int, default 1 113 | Strides of the convolution. 114 | groups : int, default 1 115 | Number of groups. 116 | bias : bool, default False 117 | Whether the layer uses a bias vector. 118 | """ 119 | return nn.Conv2d( 120 | in_channels=in_channels, 121 | out_channels=out_channels, 122 | kernel_size=1, 123 | stride=stride, 124 | groups=groups, dilation=dilation, 125 | bias=bias) 126 | 127 | 128 | 129 | 130 | def conv3x3(in_channels, 131 | out_channels, 132 | stride=1, 133 | padding=1, 134 | dilation=1, 135 | groups=1, 136 | bias=False): 137 | """ 138 | Convolution 3x3 layer. 139 | 140 | Parameters: 141 | ---------- 142 | in_channels : int 143 | Number of input channels. 144 | out_channels : int 145 | Number of output channels. 146 | stride : int or tuple/list of 2 int, default 1 147 | Strides of the convolution. 148 | padding : int or tuple/list of 2 int, default 1 149 | Padding value for convolution layer. 150 | dilation : int or tuple/list of 2 int, default 1 151 | Dilation value for convolution layer. 
152 | groups : int, default 1 153 | Number of groups. 154 | bias : bool, default False 155 | Whether the layer uses a bias vector. 156 | """ 157 | return nn.Conv2d( 158 | in_channels=in_channels, 159 | out_channels=out_channels, 160 | kernel_size=3, 161 | stride=stride, 162 | padding=padding, 163 | dilation=dilation, 164 | groups=groups, 165 | bias=bias) 166 | class Flatten(nn.Module): 167 | """ 168 | Simple flatten module. 169 | """ 170 | 171 | def forward(self, x): 172 | return x.view(x.size(0), -1) 173 | 174 | def depthwise_conv3x3(channels, 175 | stride=1, 176 | padding=1, 177 | dilation=1, 178 | bias=False): 179 | """ 180 | Depthwise convolution 3x3 layer. 181 | 182 | Parameters: 183 | ---------- 184 | channels : int 185 | Number of input/output channels. 186 | strides : int or tuple/list of 2 int, default 1 187 | Strides of the convolution. 188 | padding : int or tuple/list of 2 int, default 1 189 | Padding value for convolution layer. 190 | dilation : int or tuple/list of 2 int, default 1 191 | Dilation value for convolution layer. 192 | bias : bool, default False 193 | Whether the layer uses a bias vector. 194 | """ 195 | return nn.Conv2d( 196 | in_channels=channels, 197 | out_channels=channels, 198 | kernel_size=3, 199 | stride=stride, 200 | padding=padding, 201 | dilation=dilation, 202 | groups=channels, 203 | bias=bias) 204 | class ConvBlock(nn.Module): 205 | """ 206 | Standard convolution block with Batch normalization and activation. 207 | 208 | Parameters: 209 | ---------- 210 | in_channels : int 211 | Number of input channels. 212 | out_channels : int 213 | Number of output channels. 214 | kernel_size : int or tuple/list of 2 int 215 | Convolution window size. 216 | stride : int or tuple/list of 2 int 217 | Strides of the convolution. 218 | padding : int, or tuple/list of 2 int, or tuple/list of 4 int 219 | Padding value for convolution layer. 220 | dilation : int or tuple/list of 2 int, default 1 221 | Dilation value for convolution layer. 222 | groups : int, default 1 223 | Number of groups. 224 | bias : bool, default False 225 | Whether the layer uses a bias vector. 226 | use_bn : bool, default True 227 | Whether to use BatchNorm layer. 228 | bn_eps : float, default 1e-5 229 | Small float added to variance in Batch norm. 230 | activation : function or str or None, default nn.ReLU(inplace=True) 231 | Activation function or name of activation function. 
232 | """ 233 | def __init__(self, 234 | in_channels, 235 | out_channels, 236 | kernel_size, 237 | stride, 238 | padding, 239 | dilation=1, 240 | groups=1, 241 | bias=False, 242 | use_bn=True, 243 | bn_eps=1e-5, 244 | activation=(lambda: nn.ReLU(inplace=True))): 245 | super(ConvBlock, self).__init__() 246 | self.activate = (activation is not None) 247 | self.use_bn = use_bn 248 | self.use_pad = (isinstance(padding, (list, tuple)) and (len(padding) == 4)) 249 | 250 | if self.use_pad: 251 | self.pad = nn.ZeroPad2d(padding=padding) 252 | padding = 0 253 | self.conv = nn.Conv2d( 254 | in_channels=in_channels, 255 | out_channels=out_channels, 256 | kernel_size=kernel_size, 257 | stride=stride, 258 | padding=padding, 259 | dilation=dilation, 260 | groups=groups, 261 | bias=bias) 262 | if self.use_bn: 263 | self.bn = nn.BatchNorm2d( 264 | num_features=out_channels, 265 | eps=bn_eps) 266 | if self.activate: 267 | self.activ = get_activation_layer(activation,out_channels) 268 | 269 | def forward(self, x): 270 | if self.use_pad: 271 | x = self.pad(x) 272 | x = self.conv(x) 273 | if self.use_bn: 274 | x = self.bn(x) 275 | if self.activate: 276 | x = self.activ(x) 277 | return x 278 | 279 | def conv1x1_block(in_channels, 280 | out_channels, 281 | stride=1, 282 | padding=0, 283 | groups=1, 284 | bias=False, 285 | use_bn=True, 286 | bn_eps=1e-5, 287 | activation=(lambda: nn.ReLU(inplace=True))): 288 | """ 289 | 1x1 version of the standard convolution block. 290 | 291 | Parameters: 292 | ---------- 293 | in_channels : int 294 | Number of input channels. 295 | out_channels : int 296 | Number of output channels. 297 | stride : int or tuple/list of 2 int, default 1 298 | Strides of the convolution. 299 | padding : int, or tuple/list of 2 int, or tuple/list of 4 int, default 0 300 | Padding value for convolution layer. 301 | groups : int, default 1 302 | Number of groups. 303 | bias : bool, default False 304 | Whether the layer uses a bias vector. 305 | use_bn : bool, default True 306 | Whether to use BatchNorm layer. 307 | bn_eps : float, default 1e-5 308 | Small float added to variance in Batch norm. 309 | activation : function or str or None, default nn.ReLU(inplace=True) 310 | Activation function or name of activation function. 311 | """ 312 | return ConvBlock( 313 | in_channels=in_channels, 314 | out_channels=out_channels, 315 | kernel_size=1, 316 | stride=stride, 317 | padding=padding, 318 | groups=groups, 319 | bias=bias, 320 | use_bn=use_bn, 321 | bn_eps=bn_eps, 322 | activation=activation) 323 | 324 | def conv3x3_block(in_channels, 325 | out_channels, 326 | stride=1, 327 | padding=1, 328 | dilation=1, 329 | groups=1, 330 | bias=False, 331 | use_bn=True, 332 | bn_eps=1e-5, 333 | activation=(lambda: nn.ReLU(inplace=True))): 334 | """ 335 | 3x3 version of the standard convolution block. 336 | 337 | Parameters: 338 | ---------- 339 | in_channels : int 340 | Number of input channels. 341 | out_channels : int 342 | Number of output channels. 343 | stride : int or tuple/list of 2 int, default 1 344 | Strides of the convolution. 345 | padding : int, or tuple/list of 2 int, or tuple/list of 4 int, default 1 346 | Padding value for convolution layer. 347 | dilation : int or tuple/list of 2 int, default 1 348 | Dilation value for convolution layer. 349 | groups : int, default 1 350 | Number of groups. 351 | bias : bool, default False 352 | Whether the layer uses a bias vector. 353 | use_bn : bool, default True 354 | Whether to use BatchNorm layer. 
355 | bn_eps : float, default 1e-5 356 | Small float added to variance in Batch norm. 357 | activation : function or str or None, default nn.ReLU(inplace=True) 358 | Activation function or name of activation function. 359 | """ 360 | return ConvBlock( 361 | in_channels=in_channels, 362 | out_channels=out_channels, 363 | kernel_size=3, 364 | stride=stride, 365 | padding=padding, 366 | dilation=dilation, 367 | groups=groups, 368 | bias=bias, 369 | use_bn=use_bn, 370 | bn_eps=bn_eps, 371 | activation=activation) 372 | class DwsConvBlock(nn.Module): 373 | """ 374 | Depthwise separable convolution block with BatchNorms and activations at each convolution layers. 375 | 376 | Parameters: 377 | ---------- 378 | in_channels : int 379 | Number of input channels. 380 | out_channels : int 381 | Number of output channels. 382 | kernel_size : int or tuple/list of 2 int 383 | Convolution window size. 384 | stride : int or tuple/list of 2 int 385 | Strides of the convolution. 386 | padding : int, or tuple/list of 2 int, or tuple/list of 4 int 387 | Padding value for convolution layer. 388 | dilation : int or tuple/list of 2 int, default 1 389 | Dilation value for convolution layer. 390 | bias : bool, default False 391 | Whether the layer uses a bias vector. 392 | dw_use_bn : bool, default True 393 | Whether to use BatchNorm layer (depthwise convolution block). 394 | pw_use_bn : bool, default True 395 | Whether to use BatchNorm layer (pointwise convolution block). 396 | bn_eps : float, default 1e-5 397 | Small float added to variance in Batch norm. 398 | dw_activation : function or str or None, default nn.ReLU(inplace=True) 399 | Activation function after the depthwise convolution block. 400 | pw_activation : function or str or None, default nn.ReLU(inplace=True) 401 | Activation function after the pointwise convolution block. 402 | """ 403 | def __init__(self, 404 | in_channels, 405 | out_channels, 406 | kernel_size, 407 | stride, 408 | padding, 409 | dilation=1, 410 | bias=False, 411 | dw_use_bn=True, 412 | pw_use_bn=True, 413 | bn_eps=1e-5, 414 | dw_activation=(lambda: nn.ReLU(inplace=True)), 415 | pw_activation=(lambda: nn.ReLU(inplace=True))): 416 | super(DwsConvBlock, self).__init__() 417 | self.dw_conv = dwconv_block( 418 | in_channels=in_channels, 419 | out_channels=in_channels, 420 | kernel_size=kernel_size, 421 | stride=stride, 422 | padding=padding, 423 | dilation=dilation, 424 | bias=bias, 425 | use_bn=dw_use_bn, 426 | bn_eps=bn_eps, 427 | activation=dw_activation) 428 | self.pw_conv = conv1x1_block( 429 | in_channels=in_channels, 430 | out_channels=out_channels, 431 | bias=bias, 432 | use_bn=pw_use_bn, 433 | bn_eps=bn_eps, 434 | activation=pw_activation) 435 | 436 | def forward(self, x): 437 | x = self.dw_conv(x) 438 | 439 | x = self.pw_conv(x) 440 | 441 | return x 442 | 443 | 444 | def dwconv_block(in_channels, 445 | out_channels, 446 | kernel_size, 447 | stride=1, 448 | padding=1, 449 | dilation=1, 450 | bias=False, 451 | use_bn=True, 452 | bn_eps=1e-5, 453 | activation=(lambda: nn.ReLU(inplace=True))): 454 | """ 455 | Depthwise convolution block. 
456 | """ 457 | return ConvBlock( 458 | in_channels=in_channels, 459 | out_channels=out_channels, 460 | kernel_size=kernel_size, 461 | stride=stride, 462 | padding=padding, 463 | dilation=dilation, 464 | groups=out_channels, 465 | bias=bias, 466 | use_bn=use_bn, 467 | bn_eps=bn_eps, 468 | activation=activation) 469 | 470 | def channel_shuffle2(x, 471 | groups): 472 | """ 473 | Channel shuffle operation from 'ShuffleNet: An Extremely Efficient Convolutional Neural Network for Mobile Devices,' 474 | https://arxiv.org/abs/1707.01083. The alternative version. 475 | 476 | Parameters: 477 | ---------- 478 | x : Tensor 479 | Input tensor. 480 | groups : int 481 | Number of groups. 482 | 483 | Returns: 484 | ------- 485 | Tensor 486 | Resulted tensor. 487 | """ 488 | batch, channels, height, width = x.size() 489 | assert (channels % groups == 0) 490 | channels_per_group = channels // groups 491 | 492 | x = x.view(batch, channels_per_group, groups, height, width) 493 | x = torch.transpose(x, 1, 2).contiguous() 494 | 495 | x = x.view(batch, channels, height, width) 496 | return x 497 | 498 | def _calc_width(net): 499 | import numpy as np 500 | net_params = filter(lambda p: p.requires_grad, net.parameters()) 501 | weight_count = 0 502 | for param in net_params: 503 | weight_count += np.prod(param.size()) 504 | return weight_count 505 | -------------------------------------------------------------------------------- /config/__pycache__/config.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fdbtrs/ElasticFace/5496ec0643bfdcd5e6b53b885aa4e919e26ce442/config/__pycache__/config.cpython-38.pyc -------------------------------------------------------------------------------- /config/config.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict as edict 2 | 3 | config = edict() 4 | config.dataset = "emoreIresNet" # training dataset 5 | config.embedding_size = 512 # embedding size of model 6 | config.momentum = 0.9 7 | config.weight_decay = 5e-4 8 | config.batch_size = 128 # batch size per GPU 9 | config.lr = 0.1 10 | config.output = "output/R100_ElasticArcFace" # train model output folder 11 | config.global_step=0 # step to resume 12 | config.s=64.0 13 | config.m=0.50 14 | config.std=0.05 15 | 16 | 17 | config.loss="ElasticArcFace" # Option : ElasticArcFace, ArcFace, ElasticCosFace, CosFace, MLLoss, ElasticArcFacePlus, ElasticCosFacePlus 18 | 19 | if (config.loss=="ElasticArcFacePlus"): 20 | config.s = 64.0 21 | config.m = 0.50 22 | config.std = 0.0175 23 | elif (config.loss=="ElasticArcFace"): 24 | config.s = 64.0 25 | config.m = 0.50 26 | config.std = 0.05 27 | if (config.loss=="ElasticCosFacePlus"): 28 | config.s = 64.0 29 | config.m = 0.35 30 | config.std = 0.02 31 | elif (config.loss=="ElasticCosFace"): 32 | config.s = 64.0 33 | config.m = 0.35 34 | config.std = 0.05 35 | 36 | 37 | # type of network to train [iresnet100 | iresnet50] 38 | config.network = "iresnet100" 39 | config.SE=False # SEModule 40 | 41 | 42 | if config.dataset == "emoreIresNet": 43 | config.rec = "/data/psiebke/faces_emore" 44 | config.num_classes = 85742 45 | config.num_image = 5822653 46 | config.num_epoch = 26 47 | config.warmup_epoch = -1 48 | config.val_targets = ["lfw", "cfp_fp", "cfp_ff", "agedb_30", "calfw", "cplfw"] 49 | config.eval_step=5686 50 | def lr_step_func(epoch): 51 | return ((epoch + 1) / (4 + 1)) ** 2 if epoch < -1 else 0.1 ** len( 52 | [m for m in [8, 14,20,25] if m - 
1 <= epoch]) # [m for m in [8, 14,20,25] if m - 1 <= epoch]) 53 | config.lr_func = lr_step_func 54 | 55 | elif config.dataset == "webface": 56 | config.rec = "/data/fboutros/faces_webface_112x112" 57 | config.num_classes = 10572 58 | config.num_image = 501195 59 | config.num_epoch = 40 # [22, 30, 35] 60 | config.warmup_epoch = -1 61 | config.val_targets = ["lfw", "cfp_fp", "cfp_ff", "agedb_30", "calfw", "cplfw"] 62 | config.eval_step= 958 #33350 63 | def lr_step_func(epoch): 64 | return ((epoch + 1) / (4 + 1)) ** 2 if epoch < config.warmup_epoch else 0.1 ** len( 65 | [m for m in [22, 30, 40] if m - 1 <= epoch]) 66 | config.lr_func = lr_step_func 67 | -------------------------------------------------------------------------------- /eval/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fdbtrs/ElasticFace/5496ec0643bfdcd5e6b53b885aa4e919e26ce442/eval/__init__.py -------------------------------------------------------------------------------- /eval/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fdbtrs/ElasticFace/5496ec0643bfdcd5e6b53b885aa4e919e26ce442/eval/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /eval/__pycache__/verification.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fdbtrs/ElasticFace/5496ec0643bfdcd5e6b53b885aa4e919e26ce442/eval/__pycache__/verification.cpython-38.pyc -------------------------------------------------------------------------------- /eval/evaluation.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | 4 | #import cv2 5 | import sys 6 | import torch 7 | sys.path.append('/home/fboutros/ElasticFace') 8 | 9 | from utils.utils_callbacks import CallBackVerification 10 | from utils.utils_logging import init_logging 11 | 12 | from config.config import config as cfg 13 | 14 | from backbones.iresnet import iresnet100, iresnet50 15 | 16 | if __name__ == "__main__": 17 | gpu_id = 0 18 | log_root = logging.getLogger() 19 | init_logging(log_root, 0, cfg.output,logfile="test1.log") 20 | callback_verification = CallBackVerification(1, 0, cfg.val_targets, cfg.rec) 21 | output_folder=cfg.output 22 | weights=os.listdir(output_folder) 23 | for w in weights: 24 | if "backbone" in w: 25 | if cfg.network == "iresnet100": 26 | backbone = iresnet100(num_features=cfg.embedding_size).to(f"cuda:{gpu_id}") 27 | elif cfg.network == "iresnet50": 28 | backbone = iresnet50(num_classes=cfg.embedding_size).to(f"cuda:{gpu_id}") 29 | else: 30 | backbone = None 31 | exit() 32 | backbone.load_state_dict(torch.load(os.path.join(output_folder,w))) 33 | model = torch.nn.DataParallel(backbone, device_ids=[gpu_id]) 34 | callback_verification(int(w.split("backbone")[0]),model) 35 | 36 | -------------------------------------------------------------------------------- /eval/verification.py: -------------------------------------------------------------------------------- 1 | """Helper for evaluation on the Labeled Faces in the Wild dataset 2 | """ 3 | 4 | # MIT License 5 | # 6 | # Copyright (c) 2016 David Sandberg 7 | # 8 | # Permission is hereby granted, free of charge, to any person obtaining a copy 9 | # of this software and associated documentation files (the "Software"), to deal 10 | # in the Software 
without restriction, including without limitation the rights 11 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 12 | # copies of the Software, and to permit persons to whom the Software is 13 | # furnished to do so, subject to the following conditions: 14 | # 15 | # The above copyright notice and this permission notice shall be included in all 16 | # copies or substantial portions of the Software. 17 | # 18 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 19 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 20 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 21 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 22 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 23 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 24 | # SOFTWARE. 25 | 26 | 27 | import datetime 28 | import os 29 | import pickle 30 | 31 | import mxnet as mx 32 | import numpy as np 33 | import sklearn 34 | import torch 35 | from mxnet import ndarray as nd 36 | from scipy import interpolate 37 | from sklearn.decomposition import PCA 38 | from sklearn.model_selection import KFold 39 | 40 | 41 | class LFold: 42 | def __init__(self, n_splits=2, shuffle=False): 43 | self.n_splits = n_splits 44 | if self.n_splits > 1: 45 | self.k_fold = KFold(n_splits=n_splits, shuffle=shuffle) 46 | 47 | def split(self, indices): 48 | if self.n_splits > 1: 49 | return self.k_fold.split(indices) 50 | else: 51 | return [(indices, indices)] 52 | 53 | 54 | def calculate_roc(thresholds, 55 | embeddings1, 56 | embeddings2, 57 | actual_issame, 58 | nrof_folds=10, 59 | pca=0): 60 | assert (embeddings1.shape[0] == embeddings2.shape[0]) 61 | assert (embeddings1.shape[1] == embeddings2.shape[1]) 62 | nrof_pairs = min(len(actual_issame), embeddings1.shape[0]) 63 | nrof_thresholds = len(thresholds) 64 | k_fold = LFold(n_splits=nrof_folds, shuffle=False) 65 | 66 | tprs = np.zeros((nrof_folds, nrof_thresholds)) 67 | fprs = np.zeros((nrof_folds, nrof_thresholds)) 68 | accuracy = np.zeros((nrof_folds)) 69 | indices = np.arange(nrof_pairs) 70 | 71 | if pca == 0: 72 | diff = np.subtract(embeddings1, embeddings2) 73 | dist = np.sum(np.square(diff), 1) 74 | 75 | for fold_idx, (train_set, test_set) in enumerate(k_fold.split(indices)): 76 | if pca > 0: 77 | print('doing pca on', fold_idx) 78 | embed1_train = embeddings1[train_set] 79 | embed2_train = embeddings2[train_set] 80 | _embed_train = np.concatenate((embed1_train, embed2_train), axis=0) 81 | pca_model = PCA(n_components=pca) 82 | pca_model.fit(_embed_train) 83 | embed1 = pca_model.transform(embeddings1) 84 | embed2 = pca_model.transform(embeddings2) 85 | embed1 = sklearn.preprocessing.normalize(embed1) 86 | embed2 = sklearn.preprocessing.normalize(embed2) 87 | diff = np.subtract(embed1, embed2) 88 | dist = np.sum(np.square(diff), 1) 89 | 90 | # Find the best threshold for the fold 91 | acc_train = np.zeros((nrof_thresholds)) 92 | for threshold_idx, threshold in enumerate(thresholds): 93 | _, _, acc_train[threshold_idx] = calculate_accuracy( 94 | threshold, dist[train_set], actual_issame[train_set]) 95 | best_threshold_index = np.argmax(acc_train) 96 | for threshold_idx, threshold in enumerate(thresholds): 97 | tprs[fold_idx, threshold_idx], fprs[fold_idx, threshold_idx], _ = calculate_accuracy( 98 | threshold, dist[test_set], 99 | actual_issame[test_set]) 100 | _, _, accuracy[fold_idx] = calculate_accuracy( 
101 |             thresholds[best_threshold_index], dist[test_set],
102 |             actual_issame[test_set])
103 | 
104 |     tpr = np.mean(tprs, 0)
105 |     fpr = np.mean(fprs, 0)
106 |     return tpr, fpr, accuracy
107 | 
108 | 
109 | def calculate_accuracy(threshold, dist, actual_issame):
110 |     predict_issame = np.less(dist, threshold)
111 |     tp = np.sum(np.logical_and(predict_issame, actual_issame))
112 |     fp = np.sum(np.logical_and(predict_issame, np.logical_not(actual_issame)))
113 |     tn = np.sum(
114 |         np.logical_and(np.logical_not(predict_issame),
115 |                        np.logical_not(actual_issame)))
116 |     fn = np.sum(np.logical_and(np.logical_not(predict_issame), actual_issame))
117 | 
118 |     tpr = 0 if (tp + fn == 0) else float(tp) / float(tp + fn)
119 |     fpr = 0 if (fp + tn == 0) else float(fp) / float(fp + tn)
120 |     acc = float(tp + tn) / dist.size
121 |     return tpr, fpr, acc
122 | 
123 | 
124 | def calculate_val(thresholds,
125 |                   embeddings1,
126 |                   embeddings2,
127 |                   actual_issame,
128 |                   far_target,
129 |                   nrof_folds=10):
130 |     assert (embeddings1.shape[0] == embeddings2.shape[0])
131 |     assert (embeddings1.shape[1] == embeddings2.shape[1])
132 |     nrof_pairs = min(len(actual_issame), embeddings1.shape[0])
133 |     nrof_thresholds = len(thresholds)
134 |     k_fold = LFold(n_splits=nrof_folds, shuffle=False)
135 | 
136 |     val = np.zeros(nrof_folds)
137 |     far = np.zeros(nrof_folds)
138 | 
139 |     diff = np.subtract(embeddings1, embeddings2)
140 |     dist = np.sum(np.square(diff), 1)
141 |     indices = np.arange(nrof_pairs)
142 | 
143 |     for fold_idx, (train_set, test_set) in enumerate(k_fold.split(indices)):
144 | 
145 |         # Find the threshold that gives FAR = far_target
146 |         far_train = np.zeros(nrof_thresholds)
147 |         for threshold_idx, threshold in enumerate(thresholds):
148 |             _, far_train[threshold_idx] = calculate_val_far(
149 |                 threshold, dist[train_set], actual_issame[train_set])
150 |         if np.max(far_train) >= far_target:
151 |             f = interpolate.interp1d(far_train, thresholds, kind='slinear')
152 |             threshold = f(far_target)
153 |         else:
154 |             threshold = 0.0
155 | 
156 |         val[fold_idx], far[fold_idx] = calculate_val_far(
157 |             threshold, dist[test_set], actual_issame[test_set])
158 | 
159 |     val_mean = np.mean(val)
160 |     far_mean = np.mean(far)
161 |     val_std = np.std(val)
162 |     return val_mean, val_std, far_mean
163 | 
164 | 
165 | def calculate_val_far(threshold, dist, actual_issame):
166 |     predict_issame = np.less(dist, threshold)
167 |     true_accept = np.sum(np.logical_and(predict_issame, actual_issame))
168 |     false_accept = np.sum(
169 |         np.logical_and(predict_issame, np.logical_not(actual_issame)))
170 |     n_same = np.sum(actual_issame)
171 |     n_diff = np.sum(np.logical_not(actual_issame))
172 |     # print(true_accept, false_accept)
173 |     # print(n_same, n_diff)
174 |     val = float(true_accept) / float(n_same)
175 |     far = float(false_accept) / float(n_diff)
176 |     return val, far
177 | 
178 | 
179 | def evaluate(embeddings, actual_issame, nrof_folds=10, pca=0):
180 |     # Calculate evaluation metrics
181 |     thresholds = np.arange(0, 4, 0.01)
182 |     embeddings1 = embeddings[0::2]
183 |     embeddings2 = embeddings[1::2]
184 |     tpr, fpr, accuracy = calculate_roc(thresholds,
185 |                                        embeddings1,
186 |                                        embeddings2,
187 |                                        np.asarray(actual_issame),
188 |                                        nrof_folds=nrof_folds,
189 |                                        pca=pca)
190 |     thresholds = np.arange(0, 4, 0.001)
191 |     val, val_std, far = calculate_val(thresholds,
192 |                                       embeddings1,
193 |                                       embeddings2,
194 |                                       np.asarray(actual_issame),
195 |                                       1e-3,
196 |                                       nrof_folds=nrof_folds)
197 |     return tpr, fpr, accuracy, val, val_std, far
198 | 
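# Illustrative usage sketch for the helpers in this file (load_bin and test are
# defined just below); the checkpoint path, GPU id and batch size are placeholders.
def _example_verification_run():
    from backbones.iresnet import iresnet100
    backbone = iresnet100(num_features=512)
    backbone.load_state_dict(
        torch.load("output/R100_ElasticArcFace/295672backbone.pth", map_location="cpu"))
    # Mirrors eval/evaluation.py: wrap in DataParallel so CPU batches are scattered to the GPU.
    model = torch.nn.DataParallel(backbone.cuda().eval(), device_ids=[0])
    data_set = load_bin("data/faces_emore/lfw.bin", image_size=[112, 112])
    _, _, acc2, std2, _xnorm, _ = test(data_set, model, batch_size=32, nfolds=10)
    print('Accuracy-Flip: %1.5f+-%1.5f' % (acc2, std2))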
199 | @torch.no_grad()
200 | def load_bin(path, image_size):
201 |     try:
202 |         with open(path, 'rb') as f:
203 |             bins, issame_list = pickle.load(f)  # py2
204 |     except UnicodeDecodeError as e:
205 |         with open(path, 'rb') as f:
206 |             bins, issame_list = pickle.load(f, encoding='bytes')  # py3
207 |     data_list = []
208 |     for flip in [0, 1]:
209 |         data = torch.empty((len(issame_list) * 2, 3, image_size[0], image_size[1]))
210 |         data_list.append(data)
211 |     for idx in range(len(issame_list) * 2):
212 |         _bin = bins[idx]
213 |         img = mx.image.imdecode(_bin)
214 |         if img.shape[1] != image_size[0]:
215 |             img = mx.image.resize_short(img, image_size[0])
216 |         img = nd.transpose(img, axes=(2, 0, 1))
217 |         for flip in [0, 1]:
218 |             if flip == 1:
219 |                 img = mx.ndarray.flip(data=img, axis=2)
220 |             data_list[flip][idx][:] = torch.from_numpy(img.asnumpy())
221 |         if idx % 1000 == 0:
222 |             print('loading bin', idx)
223 |     print(data_list[0].shape)
224 |     return data_list, issame_list
225 | 
226 | @torch.no_grad()
227 | def test(data_set, backbone, batch_size, nfolds=10):
228 |     print('testing verification..')
229 |     data_list = data_set[0]
230 |     issame_list = data_set[1]
231 |     embeddings_list = []
232 |     time_consumed = 0.0
233 |     for i in range(len(data_list)):
234 |         data = data_list[i]
235 |         embeddings = None
236 |         ba = 0
237 |         while ba < data.shape[0]:
238 |             bb = min(ba + batch_size, data.shape[0])
239 |             count = bb - ba
240 |             _data = data[bb - batch_size: bb]
241 |             time0 = datetime.datetime.now()
242 |             img = ((_data / 255) - 0.5) / 0.5
243 |             net_out: torch.Tensor = backbone(img)
244 |             _embeddings = net_out.detach().cpu().numpy()
245 |             time_now = datetime.datetime.now()
246 |             diff = time_now - time0
247 |             time_consumed += diff.total_seconds()
248 |             if embeddings is None:
249 |                 embeddings = np.zeros((data.shape[0], _embeddings.shape[1]))
250 |             embeddings[ba:bb, :] = _embeddings[(batch_size - count):, :]
251 |             ba = bb
252 |         embeddings_list.append(embeddings)
253 | 
254 |     _xnorm = 0.0
255 |     _xnorm_cnt = 0
256 |     for embed in embeddings_list:
257 |         for i in range(embed.shape[0]):
258 |             _em = embed[i]
259 |             _norm = np.linalg.norm(_em)
260 |             _xnorm += _norm
261 |             _xnorm_cnt += 1
262 |     _xnorm /= _xnorm_cnt
263 | 
264 |     embeddings = embeddings_list[0].copy()
265 |     embeddings = sklearn.preprocessing.normalize(embeddings)
266 |     acc1 = 0.0
267 |     std1 = 0.0
268 |     embeddings = embeddings_list[0] + embeddings_list[1]
269 |     embeddings = sklearn.preprocessing.normalize(embeddings)
270 |     print(embeddings.shape)
271 |     print('infer time', time_consumed)
272 |     _, _, accuracy, val, val_std, far = evaluate(embeddings, issame_list, nrof_folds=nfolds)
273 |     acc2, std2 = np.mean(accuracy), np.std(accuracy)
274 |     return acc1, std1, acc2, std2, _xnorm, embeddings_list
275 | 
276 | 
# NOTE: dumpR is a legacy MXNet helper kept from the insightface codebase. It
# references an mx.mod.Module (`model`) plus `_label`/`_data_extra` that must be
# set up by the caller (see the commented-out main below); it is not used by
# eval/evaluation.py.
277 | def dumpR(data_set,
278 |           backbone,
279 |           batch_size,
280 |           name='',
281 |           data_extra=None,
282 |           label_shape=None):
283 |     print('dump verification embedding..')
284 |     data_list = data_set[0]
285 |     issame_list = data_set[1]
286 |     embeddings_list = []
287 |     time_consumed = 0.0
288 |     for i in range(len(data_list)):
289 |         data = data_list[i]
290 |         embeddings = None
291 |         ba = 0
292 |         while ba < data.shape[0]:
293 |             bb = min(ba + batch_size, data.shape[0])
294 |             count = bb - ba
295 | 
296 |             _data = nd.slice_axis(data, axis=0, begin=bb - batch_size, end=bb)
297 |             time0 = datetime.datetime.now()
298 |             if data_extra is None:
299 |                 db = mx.io.DataBatch(data=(_data,), label=(_label,))
300 |             else:
301 |                 db = mx.io.DataBatch(data=(_data, _data_extra),
302 |                                      label=(_label,))
303 |             model.forward(db,
is_train=False) 304 | net_out = model.get_outputs() 305 | _embeddings = net_out[0].asnumpy() 306 | time_now = datetime.datetime.now() 307 | diff = time_now - time0 308 | time_consumed += diff.total_seconds() 309 | if embeddings is None: 310 | embeddings = np.zeros((data.shape[0], _embeddings.shape[1])) 311 | embeddings[ba:bb, :] = _embeddings[(batch_size - count):, :] 312 | ba = bb 313 | embeddings_list.append(embeddings) 314 | embeddings = embeddings_list[0] + embeddings_list[1] 315 | embeddings = sklearn.preprocessing.normalize(embeddings) 316 | actual_issame = np.asarray(issame_list) 317 | outname = os.path.join('temp.bin') 318 | with open(outname, 'wb') as f: 319 | pickle.dump((embeddings, issame_list), 320 | f, 321 | protocol=pickle.HIGHEST_PROTOCOL) 322 | 323 | 324 | # if __name__ == '__main__': 325 | # 326 | # parser = argparse.ArgumentParser(description='do verification') 327 | # # general 328 | # parser.add_argument('--data-dir', default='', help='') 329 | # parser.add_argument('--model', 330 | # default='../model/softmax,50', 331 | # help='path to load model.') 332 | # parser.add_argument('--target', 333 | # default='lfw,cfp_ff,cfp_fp,agedb_30', 334 | # help='test targets.') 335 | # parser.add_argument('--gpu', default=0, type=int, help='gpu id') 336 | # parser.add_argument('--batch-size', default=32, type=int, help='') 337 | # parser.add_argument('--max', default='', type=str, help='') 338 | # parser.add_argument('--mode', default=0, type=int, help='') 339 | # parser.add_argument('--nfolds', default=10, type=int, help='') 340 | # args = parser.parse_args() 341 | # image_size = [112, 112] 342 | # print('image_size', image_size) 343 | # ctx = mx.gpu(args.gpu) 344 | # nets = [] 345 | # vec = args.model.split(',') 346 | # prefix = args.model.split(',')[0] 347 | # epochs = [] 348 | # if len(vec) == 1: 349 | # pdir = os.path.dirname(prefix) 350 | # for fname in os.listdir(pdir): 351 | # if not fname.endswith('.params'): 352 | # continue 353 | # _file = os.path.join(pdir, fname) 354 | # if _file.startswith(prefix): 355 | # epoch = int(fname.split('.')[0].split('-')[1]) 356 | # epochs.append(epoch) 357 | # epochs = sorted(epochs, reverse=True) 358 | # if len(args.max) > 0: 359 | # _max = [int(x) for x in args.max.split(',')] 360 | # assert len(_max) == 2 361 | # if len(epochs) > _max[1]: 362 | # epochs = epochs[_max[0]:_max[1]] 363 | # 364 | # else: 365 | # epochs = [int(x) for x in vec[1].split('|')] 366 | # print('model number', len(epochs)) 367 | # time0 = datetime.datetime.now() 368 | # for epoch in epochs: 369 | # print('loading', prefix, epoch) 370 | # sym, arg_params, aux_params = mx.model.load_checkpoint(prefix, epoch) 371 | # # arg_params, aux_params = ch_dev(arg_params, aux_params, ctx) 372 | # all_layers = sym.get_internals() 373 | # sym = all_layers['fc1_output'] 374 | # model = mx.mod.Module(symbol=sym, context=ctx, label_names=None) 375 | # # model.bind(data_shapes=[('data', (args.batch_size, 3, image_size[0], image_size[1]))], label_shapes=[('softmax_label', (args.batch_size,))]) 376 | # model.bind(data_shapes=[('data', (args.batch_size, 3, image_size[0], 377 | # image_size[1]))]) 378 | # model.set_params(arg_params, aux_params) 379 | # nets.append(model) 380 | # time_now = datetime.datetime.now() 381 | # diff = time_now - time0 382 | # print('model loading time', diff.total_seconds()) 383 | # 384 | # ver_list = [] 385 | # ver_name_list = [] 386 | # for name in args.target.split(','): 387 | # path = os.path.join(args.data_dir, name + ".bin") 388 | # if 
os.path.exists(path):
389 | #             print('loading.. ', name)
390 | #             data_set = load_bin(path, image_size)
391 | #             ver_list.append(data_set)
392 | #             ver_name_list.append(name)
393 | # 
394 | #     if args.mode == 0:
395 | #         for i in range(len(ver_list)):
396 | #             results = []
397 | #             for model in nets:
398 | #                 acc1, std1, acc2, std2, xnorm, embeddings_list = test(
399 | #                     ver_list[i], model, args.batch_size, args.nfolds)
400 | #                 print('[%s]XNorm: %f' % (ver_name_list[i], xnorm))
401 | #                 print('[%s]Accuracy: %1.5f+-%1.5f' % (ver_name_list[i], acc1, std1))
402 | #                 print('[%s]Accuracy-Flip: %1.5f+-%1.5f' % (ver_name_list[i], acc2, std2))
403 | #                 results.append(acc2)
404 | #             print('Max of [%s] is %1.5f' % (ver_name_list[i], np.max(results)))
405 | #     elif args.mode == 1:
406 | #         raise ValueError
407 | #     else:
408 | #         model = nets[0]
409 | #         dumpR(ver_list[0], model, args.batch_size, args.target)

--------------------------------------------------------------------------------
/images/margins.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fdbtrs/ElasticFace/5496ec0643bfdcd5e6b53b885aa4e919e26ce442/images/margins.png

--------------------------------------------------------------------------------
/requirement.txt:
--------------------------------------------------------------------------------
 1 | tensorboard
 2 | easydict
 3 | scikit-learn    # pip package name for sklearn
 4 | matplotlib
 5 | pandas
 6 | scikit-image
 7 | menpo
 8 | prettytable
 9 | mxnet==1.6.0
10 | torchsummary    # assumed PyPI name; original listed "pytorch_summary"
11 | opencv-python   # pip package name for cv2

--------------------------------------------------------------------------------
/run.sh:
--------------------------------------------------------------------------------
1 | export OMP_NUM_THREADS=4
2 | CUDA_VISIBLE_DEVICES=0,1,2,3 python3 -m torch.distributed.launch --nproc_per_node=4 --nnodes=1 \
3 | --node_rank=0 --master_addr="127.0.0.1" --master_port=1235 train.py
# NOTE: the cleanup line below kills every process whose command line contains "train" -- use with care.
4 | ps -ef | grep "train" | grep -v grep | awk '{print "kill -9 "$2}' | sh

--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
 1 | import argparse
 2 | import logging
 3 | import os
 4 | import time
 5 | 
 6 | import torch
 7 | import torch.distributed as dist
 8 | import torch.nn.functional as F
 9 | from torch.nn.parallel.distributed import DistributedDataParallel
10 | import torch.utils.data.distributed
11 | from torch.nn.utils import clip_grad_norm_
12 | from torch.nn import CrossEntropyLoss
13 | 
14 | from utils import losses
15 | from config.config import config as cfg
16 | from utils.dataset import MXFaceDataset, DataLoaderX
17 | from utils.utils_callbacks import CallBackVerification, CallBackLogging, CallBackModelCheckpoint
18 | from utils.utils_logging import AverageMeter, init_logging
19 | 
20 | from backbones.iresnet import iresnet100, iresnet50
21 | 
22 | 
23 | torch.backends.cudnn.benchmark = True
24 | 
25 | def main(args):
26 |     dist.init_process_group(backend='nccl', init_method='env://')
27 |     local_rank = args.local_rank
28 |     torch.cuda.set_device(local_rank)
29 |     rank = dist.get_rank()
30 |     world_size = dist.get_world_size()
31 | 
32 |     if not os.path.exists(cfg.output) and rank == 0:
33 |         os.makedirs(cfg.output)
34 |     else:
35 |         time.sleep(2)
36 | 
37 |     log_root = logging.getLogger()
38 |     init_logging(log_root, rank, cfg.output)
39 | 
40 |     trainset = MXFaceDataset(root_dir=cfg.rec, local_rank=local_rank)
41 | 
42 |     train_sampler = torch.utils.data.distributed.DistributedSampler(
43 |         trainset,
shuffle=True) 44 | 45 | train_loader = DataLoaderX( 46 | local_rank=local_rank, dataset=trainset, batch_size=cfg.batch_size, 47 | sampler=train_sampler, num_workers=0, pin_memory=True, drop_last=True) 48 | 49 | # load model 50 | if cfg.network == "iresnet100": 51 | backbone = iresnet100(num_features=cfg.embedding_size, use_se=cfg.SE).to(local_rank) 52 | elif cfg.network == "iresnet50": 53 | backbone = iresnet50(dropout=0.4,num_features=cfg.embedding_size, use_se=cfg.SE).to(local_rank) 54 | else: 55 | backbone = None 56 | logging.info("load backbone failed!") 57 | exit() 58 | 59 | if args.resume: 60 | try: 61 | backbone_pth = os.path.join(cfg.output, str(cfg.global_step) + "backbone.pth") 62 | backbone.load_state_dict(torch.load(backbone_pth, map_location=torch.device(local_rank))) 63 | 64 | if rank == 0: 65 | logging.info("backbone resume loaded successfully!") 66 | except (FileNotFoundError, KeyError, IndexError, RuntimeError): 67 | logging.info("load backbone resume init, failed!") 68 | 69 | for ps in backbone.parameters(): 70 | dist.broadcast(ps, 0) 71 | 72 | backbone = DistributedDataParallel( 73 | module=backbone, broadcast_buffers=False, device_ids=[local_rank]) 74 | backbone.train() 75 | 76 | # get header 77 | if cfg.loss == "ElasticArcFace": 78 | header = losses.ElasticArcFace(in_features=cfg.embedding_size, out_features=cfg.num_classes, s=cfg.s, m=cfg.m,std=cfg.std).to(local_rank) 79 | elif cfg.loss == "ElasticArcFacePlus": 80 | header = losses.ElasticArcFace(in_features=cfg.embedding_size, out_features=cfg.num_classes, s=cfg.s, m=cfg.m, 81 | std=cfg.std, plus=True).to(local_rank) 82 | elif cfg.loss == "ElasticCosFace": 83 | header = losses.ElasticCosFace(in_features=cfg.embedding_size, out_features=cfg.num_classes, s=cfg.s, m=cfg.m,std=cfg.std).to(local_rank) 84 | elif cfg.loss == "ElasticCosFacePlus": 85 | header = losses.ElasticCosFace(in_features=cfg.embedding_size, out_features=cfg.num_classes, s=cfg.s, m=cfg.m, 86 | std=cfg.std, plus=True).to(local_rank) 87 | elif cfg.loss == "ArcFace": 88 | header = losses.ArcFace(in_features=cfg.embedding_size, out_features=cfg.num_classes, s=cfg.s, m=cfg.m).to(local_rank) 89 | elif cfg.loss == "CosFace": 90 | header = losses.CosFace(in_features=cfg.embedding_size, out_features=cfg.num_classes, s=cfg.s, m=cfg.m).to( 91 | local_rank) 92 | else: 93 | print("Header not implemented") 94 | if args.resume: 95 | try: 96 | header_pth = os.path.join(cfg.output, str(cfg.global_step) + "header.pth") 97 | header.load_state_dict(torch.load(header_pth, map_location=torch.device(local_rank))) 98 | 99 | if rank == 0: 100 | logging.info("header resume loaded successfully!") 101 | except (FileNotFoundError, KeyError, IndexError, RuntimeError): 102 | logging.info("header resume init, failed!") 103 | 104 | header = DistributedDataParallel( 105 | module=header, broadcast_buffers=False, device_ids=[local_rank]) 106 | header.train() 107 | 108 | opt_backbone = torch.optim.SGD( 109 | params=[{'params': backbone.parameters()}], 110 | lr=cfg.lr / 512 * cfg.batch_size * world_size, 111 | momentum=0.9, weight_decay=cfg.weight_decay) 112 | opt_header = torch.optim.SGD( 113 | params=[{'params': header.parameters()}], 114 | lr=cfg.lr / 512 * cfg.batch_size * world_size, 115 | momentum=0.9, weight_decay=cfg.weight_decay) 116 | 117 | scheduler_backbone = torch.optim.lr_scheduler.LambdaLR( 118 | optimizer=opt_backbone, lr_lambda=cfg.lr_func) 119 | scheduler_header = torch.optim.lr_scheduler.LambdaLR( 120 | optimizer=opt_header, lr_lambda=cfg.lr_func) 121 | 122 | criterion = 
126 | 
127 |     start_epoch = 0
128 |     total_step = int(len(trainset) / cfg.batch_size / world_size * cfg.num_epoch)
129 |     if rank == 0: logging.info("Total Step is: %d" % total_step)
130 | 
131 |     if args.resume:
132 |         rem_steps = (total_step - cfg.global_step)
133 |         cur_epoch = cfg.num_epoch - int(cfg.num_epoch / total_step * rem_steps)
134 |         logging.info("resume from estimated epoch {}".format(cur_epoch))
135 |         logging.info("remaining steps {}".format(rem_steps))
136 | 
137 |         start_epoch = cur_epoch
138 |         scheduler_backbone.last_epoch = cur_epoch
139 |         scheduler_header.last_epoch = cur_epoch
140 | 
141 |         # --------- this could be solved more elegantly ----------------
142 |         opt_backbone.param_groups[0]['lr'] = scheduler_backbone.get_lr()[0]
143 |         opt_header.param_groups[0]['lr'] = scheduler_header.get_lr()[0]
144 | 
145 |         print("last learning rate: {}".format(scheduler_header.get_lr()))
146 |         # ---------------------------------------------------------------
147 | 
148 |     callback_verification = CallBackVerification(cfg.eval_step, rank, cfg.val_targets, cfg.rec)
149 |     callback_logging = CallBackLogging(50, rank, total_step, cfg.batch_size, world_size, writer=None)
150 |     callback_checkpoint = CallBackModelCheckpoint(rank, cfg.output)
151 | 
152 |     loss = AverageMeter()
153 |     global_step = cfg.global_step
154 |     for epoch in range(start_epoch, cfg.num_epoch):
155 |         train_sampler.set_epoch(epoch)
156 |         for _, (img, label) in enumerate(train_loader):
157 |             global_step += 1
158 |             img = img.cuda(local_rank, non_blocking=True)
159 |             label = label.cuda(local_rank, non_blocking=True)
160 | 
161 |             features = F.normalize(backbone(img))
162 | 
163 |             thetas = header(features, label)
164 |             loss_v = criterion(thetas, label)
165 |             loss_v.backward()
166 | 
167 |             clip_grad_norm_(backbone.parameters(), max_norm=5, norm_type=2)
168 | 
169 |             opt_backbone.step()
170 |             opt_header.step()
171 | 
172 |             opt_backbone.zero_grad()
173 |             opt_header.zero_grad()
174 | 
175 |             loss.update(loss_v.item(), 1)
176 | 
177 |             callback_logging(global_step, loss, epoch)
178 |             callback_verification(global_step, backbone)
179 | 
180 |         scheduler_backbone.step()
181 |         scheduler_header.step()
182 | 
183 |         callback_checkpoint(global_step, backbone, header)
184 | 
185 |     dist.destroy_process_group()
186 | 
187 | 
188 | if __name__ == "__main__":
189 |     parser = argparse.ArgumentParser(description='PyTorch margin penalty loss training')
190 |     parser.add_argument('--local_rank', type=int, default=0, help='local_rank')
191 |     parser.add_argument('--resume', type=int, default=0, help="resume training")
192 |     args_ = parser.parse_args()
193 |     main(args_)
194 | 
--------------------------------------------------------------------------------
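The loop above normalizes the embedding, lets the margin header produce penalized logits, and applies plain cross-entropy. For illustration only, a minimal single-process sketch of one such step, without DistributedDataParallel and on CPU; the batch size, class count, and hyperparameters are illustrative, and the repository root is assumed to be on PYTHONPATH:

```python
# Minimal single-process sketch of one optimization step from train.py above,
# without DDP -- a sanity check, not the training recipe.
import torch
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss

from backbones.iresnet import iresnet100
from utils.losses import ElasticArcFace

backbone = iresnet100(num_features=512)
header = ElasticArcFace(in_features=512, out_features=1000, s=64.0, m=0.50, std=0.0125)
criterion = CrossEntropyLoss()

img = torch.randn(8, 3, 112, 112)       # stand-in for aligned 112x112 face crops
label = torch.randint(0, 1000, (8,))

features = F.normalize(backbone(img))   # unit-length embeddings
thetas = header(features, label)        # scaled, margin-penalized logits, shape (8, 1000)
loss_v = criterion(thetas, label)
loss_v.backward()
print(loss_v.item())
```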
/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fdbtrs/ElasticFace/5496ec0643bfdcd5e6b53b885aa4e919e26ce442/utils/__init__.py
--------------------------------------------------------------------------------
/utils/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fdbtrs/ElasticFace/5496ec0643bfdcd5e6b53b885aa4e919e26ce442/utils/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/countFLOPS.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fdbtrs/ElasticFace/5496ec0643bfdcd5e6b53b885aa4e919e26ce442/utils/__pycache__/countFLOPS.cpython-38.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/utils_callbacks.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fdbtrs/ElasticFace/5496ec0643bfdcd5e6b53b885aa4e919e26ce442/utils/__pycache__/utils_callbacks.cpython-38.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/utils_logging.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fdbtrs/ElasticFace/5496ec0643bfdcd5e6b53b885aa4e919e26ce442/utils/__pycache__/utils_logging.cpython-38.pyc
--------------------------------------------------------------------------------
/utils/countFLOPS.py:
--------------------------------------------------------------------------------
1 | from torch.autograd import Variable
2 | import numpy as np
3 | 
4 | import torch
5 | 
6 | def count_model_flops(model, input_res=[112, 112], multiply_adds=True):
7 |     list_conv = []
8 | 
9 |     def conv_hook(self, input, output):
10 |         batch_size, input_channels, input_height, input_width = input[0].size()
11 |         output_channels, output_height, output_width = output[0].size()
12 | 
13 |         kernel_ops = self.kernel_size[0] * self.kernel_size[1] * (self.in_channels / self.groups)
14 |         bias_ops = 1 if self.bias is not None else 0
15 | 
16 |         params = output_channels * (kernel_ops + bias_ops)
17 |         flops = (kernel_ops * (
18 |             2 if multiply_adds else 1) + bias_ops) * output_channels * output_height * output_width * batch_size
19 | 
20 |         list_conv.append(flops)
21 | 
22 |     list_linear = []
23 | 
24 |     def linear_hook(self, input, output):
25 |         batch_size = input[0].size(0) if input[0].dim() == 2 else 1
26 | 
27 |         weight_ops = self.weight.nelement() * (2 if multiply_adds else 1)
28 |         if self.bias is not None:
29 |             bias_ops = self.bias.nelement()
30 |             flops = batch_size * (weight_ops + bias_ops)
31 |         else:
32 |             flops = batch_size * weight_ops
33 |         list_linear.append(flops)
34 | 
35 |     list_bn = []
36 | 
37 |     def bn_hook(self, input, output):
38 |         list_bn.append(input[0].nelement() * 2)
39 | 
40 |     list_relu = []
41 | 
42 |     def relu_hook(self, input, output):
43 |         list_relu.append(input[0].nelement())
44 | 
45 |     list_pooling = []
46 | 
47 |     def pooling_hook(self, input, output):
48 |         batch_size, input_channels, input_height, input_width = input[0].size()
49 |         output_channels, output_height, output_width = output[0].size()
50 | 
51 |         kernel_ops = self.kernel_size * self.kernel_size
52 |         bias_ops = 0
53 |         params = 0
54 |         flops = (kernel_ops + bias_ops) * output_channels * output_height * output_width * batch_size
55 | 
56 |         list_pooling.append(flops)
57 | 
58 |     def pooling_hook_ad(self, input, output):
59 |         # adaptive pooling: count one operation per input element
60 |         input = input[0]
61 |         flops = int(np.prod(input.shape))
62 |         list_pooling.append(flops)
63 | 
64 |     handles = []
65 | 
66 |     def foo(net):
67 |         childrens = list(net.children())
68 |         if not childrens:
69 |             if isinstance(net, torch.nn.Conv2d) or isinstance(net, torch.nn.ConvTranspose2d):
70 |                 handles.append(net.register_forward_hook(conv_hook))
71 |             elif isinstance(net, torch.nn.Linear):
72 |                 handles.append(net.register_forward_hook(linear_hook))
73 |             elif isinstance(net, torch.nn.BatchNorm2d) or isinstance(net, torch.nn.BatchNorm1d):
74 |                 handles.append(net.register_forward_hook(bn_hook))
75 |             elif isinstance(net, torch.nn.ReLU) or isinstance(net, torch.nn.PReLU):
76 |                 handles.append(net.register_forward_hook(relu_hook))
77 |             elif isinstance(net, torch.nn.MaxPool2d) or isinstance(net, torch.nn.AvgPool2d):
78 |                 handles.append(net.register_forward_hook(pooling_hook))
79 |             elif isinstance(net, torch.nn.AdaptiveAvgPool2d):
80 |                 handles.append(net.register_forward_hook(pooling_hook_ad))
81 |             else:
82 |                 print("warning: FLOPs not counted for " + str(net))
83 |             return
84 |         for c in childrens:
85 |             foo(c)
86 | 
87 |     model.eval()
88 |     foo(model)
89 |     input = Variable(torch.rand(3, input_res[1], input_res[0]).unsqueeze(0), requires_grad=True)
90 |     out = model(input)
91 |     total_flops = (sum(list_conv) + sum(list_linear) + sum(list_bn) + sum(list_relu) + sum(list_pooling))
92 |     for h in handles:
93 |         h.remove()
94 |     model.train()
95 |     return flops_to_string(total_flops)
96 | 
97 | def flops_to_string(flops, units='MFLOPS', precision=4):
98 |     if units == 'GFLOPS':
99 |         return str(round(flops / 10.**9, precision)) + ' ' + units
100 |     elif units == 'MFLOPS':
101 |         return str(round(flops / 10.**6, precision)) + ' ' + units
102 |     elif units == 'KFLOPS':
103 |         return str(round(flops / 10.**3, precision)) + ' ' + units
104 |     else:
105 |         return str(flops) + ' FLOPS'
106 | 
107 | def _calc_width(net):
108 |     net_params = filter(lambda p: p.requires_grad, net.parameters())
109 |     weight_count = 0
110 |     for param in net_params:
111 |         weight_count += np.prod(param.size())
112 |     return weight_count
--------------------------------------------------------------------------------
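A short usage sketch for count_model_flops, following the same pattern as modelFLOPS.py below (the network and embedding size are illustrative, and the repository root is assumed to be on PYTHONPATH):

```python
# Count the FLOPs of a single 112x112 forward pass through the backbone.
from backbones.iresnet import iresnet100
from utils.countFLOPS import count_model_flops

model = iresnet100(num_features=512)
print(count_model_flops(model, input_res=[112, 112]))  # e.g. "xxxx.xxxx MFLOPS"
```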
/utils/dataset.py:
--------------------------------------------------------------------------------
1 | import numbers
2 | import os
3 | import queue as Queue
4 | import threading
5 | 
6 | import mxnet as mx
7 | import numpy as np
8 | import torch
9 | from torch.utils.data import DataLoader, Dataset
10 | from torchvision import transforms
11 | import cv2
12 | 
13 | 
14 | class BackgroundGenerator(threading.Thread):
15 |     def __init__(self, generator, local_rank, max_prefetch=6):
16 |         super(BackgroundGenerator, self).__init__()
17 |         self.queue = Queue.Queue(max_prefetch)
18 |         self.generator = generator
19 |         self.local_rank = local_rank
20 |         self.daemon = True
21 |         self.start()
22 | 
23 |     def run(self):
24 |         torch.cuda.set_device(self.local_rank)
25 |         for item in self.generator:
26 |             self.queue.put(item)
27 |         self.queue.put(None)
28 | 
29 |     def next(self):
30 |         next_item = self.queue.get()
31 |         if next_item is None:
32 |             raise StopIteration
33 |         return next_item
34 | 
35 |     def __next__(self):
36 |         return self.next()
37 | 
38 |     def __iter__(self):
39 |         return self
40 | 
41 | 
42 | class DataLoaderX(DataLoader):
43 |     def __init__(self, local_rank, **kwargs):
44 |         super(DataLoaderX, self).__init__(**kwargs)
45 |         self.stream = torch.cuda.Stream(local_rank)
46 |         self.local_rank = local_rank
47 | 
48 |     def __iter__(self):
49 |         self.iter = super(DataLoaderX, self).__iter__()
50 |         self.iter = BackgroundGenerator(self.iter, self.local_rank)
51 |         self.preload()
52 |         return self
53 | 
54 |     def preload(self):
55 |         self.batch = next(self.iter, None)
56 |         if self.batch is None:
57 |             return None
58 |         with torch.cuda.stream(self.stream):
59 |             for k in range(len(self.batch)):
60 |                 self.batch[k] = self.batch[k].to(device=self.local_rank,
61 |                                                  non_blocking=True)
62 | 
63 |     def __next__(self):
64 |         torch.cuda.current_stream().wait_stream(self.stream)
65 |         batch = self.batch
66 |         if batch is None:
67 |             raise StopIteration
68 |         self.preload()
69 |         return batch
70 | 
71 | 
72 | class MXFaceDataset(Dataset):
73 |     def __init__(self, root_dir, local_rank):
74 |         super(MXFaceDataset, self).__init__()
75 |         self.transform = transforms.Compose(
76 |             [transforms.ToPILImage(),
77 |              transforms.RandomHorizontalFlip(),
78 |              transforms.ToTensor(),
79 |              transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
80 |              ])
81 |         self.root_dir = root_dir
82 |         self.local_rank = local_rank
83 |         path_imgrec = os.path.join(root_dir, 'train.rec')
84 |         path_imgidx = os.path.join(root_dir, 'train.idx')
85 |         self.imgrec = mx.recordio.MXIndexedRecordIO(path_imgidx, path_imgrec, 'r')
86 |         s = self.imgrec.read_idx(0)
87 |         header, _ = mx.recordio.unpack(s)
88 |         if header.flag > 0:
89 |             self.header0 = (int(header.label[0]), int(header.label[1]))
90 |             self.imgidx = np.array(range(1, int(header.label[0])))
91 |         else:
92 |             self.imgidx = np.array(list(self.imgrec.keys))
93 | 
94 |     def __getitem__(self, index):
95 |         idx = self.imgidx[index]
96 |         s = self.imgrec.read_idx(idx)
97 |         header, img = mx.recordio.unpack(s)
98 |         label = header.label
99 |         if not isinstance(label, numbers.Number):
100 |             label = label[0]
101 |         label = torch.tensor(label, dtype=torch.long)
102 |         sample = mx.image.imdecode(img).asnumpy()
103 |         if self.transform is not None:
104 |             sample = self.transform(sample)
105 |         return sample, label
106 | 
107 |     def __len__(self):
108 |         return len(self.imgidx)
109 | 
110 | 
111 | class FaceDatasetFolder(Dataset):
112 |     def __init__(self, root_dir, local_rank):
113 |         super(FaceDatasetFolder, self).__init__()
114 |         self.transform = transforms.Compose(
115 |             [transforms.ToPILImage(),
116 |              transforms.RandomHorizontalFlip(),
117 |              transforms.ToTensor(),
118 |              transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
119 |              ])
120 |         self.root_dir = root_dir
121 |         self.local_rank = local_rank
122 |         self.imgidx, self.labels = self.scan(root_dir)
123 | 
124 |     def scan(self, root):
125 |         imgidx = []
126 |         labels = []
127 |         lb = -1
128 |         list_dir = os.listdir(root)
129 |         list_dir.sort()
130 |         for l in list_dir:
131 |             images = os.listdir(os.path.join(root, l))
132 |             lb += 1
133 |             for img in images:
134 |                 imgidx.append(os.path.join(l, img))
135 |                 labels.append(lb)
136 |         return imgidx, labels
137 | 
138 |     def readImage(self, path):
139 |         return cv2.imread(os.path.join(self.root_dir, path))
140 | 
141 |     def __getitem__(self, index):
142 |         path = self.imgidx[index]
143 |         img = self.readImage(path)
144 |         label = self.labels[index]
145 |         label = torch.tensor(label, dtype=torch.long)
146 |         sample = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
147 | 
148 |         if self.transform is not None:
149 |             sample = self.transform(sample)
150 |         return sample, label
151 | 
152 |     def __len__(self):
153 |         return len(self.imgidx)
--------------------------------------------------------------------------------
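A usage sketch for the record-file dataset and the prefetching loader. It assumes data/faces_emore contains train.rec/train.idx and that CUDA device 0 is available, since DataLoaderX copies each batch onto a GPU stream; the batch size is illustrative:

```python
# Iterate MS1MV2 with MXFaceDataset + DataLoaderX (requires a CUDA device).
from utils.dataset import MXFaceDataset, DataLoaderX

trainset = MXFaceDataset(root_dir="data/faces_emore", local_rank=0)
loader = DataLoaderX(local_rank=0, dataset=trainset, batch_size=128,
                     num_workers=0, pin_memory=True, drop_last=True)

img, label = next(iter(loader))
print(img.shape, label.shape)  # torch.Size([128, 3, 112, 112]) torch.Size([128])
```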
/utils/losses.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | 
4 | import math
5 | import numpy as np
6 | 
7 | def l2_norm(input, axis=1):
8 |     norm = torch.norm(input, 2, axis, True)
9 |     output = torch.div(input, norm)
10 | 
11 |     return output
12 | 
13 | class MLLoss(nn.Module):
14 |     def __init__(self, in_features, out_features, s=64.0):
15 |         super(MLLoss, self).__init__()
16 |         self.s = s
17 |         self.kernel = nn.Parameter(torch.FloatTensor(in_features, out_features))
18 |         nn.init.normal_(self.kernel, std=0.01)
19 | 
20 |     def forward(self, embeddings, label):
21 |         embeddings = l2_norm(embeddings, axis=1)
22 |         kernel_norm = l2_norm(self.kernel, axis=0)
23 |         cos_theta = torch.mm(embeddings, kernel_norm)
24 |         cos_theta = cos_theta.clamp(-1, 1)  # for numerical stability
25 |         cos_theta.mul_(self.s)
26 |         return cos_theta
27 | 
28 | class ElasticArcFace(nn.Module):
29 |     def __init__(self, in_features, out_features, s=64.0, m=0.50, std=0.0125, plus=False):
30 |         super(ElasticArcFace, self).__init__()
31 |         self.in_features = in_features
32 |         self.out_features = out_features
33 |         self.s = s
34 |         self.m = m
35 |         self.kernel = nn.Parameter(torch.FloatTensor(in_features, out_features))
36 |         nn.init.normal_(self.kernel, std=0.01)
37 |         self.std = std
38 |         self.plus = plus
39 | 
40 |     def forward(self, embeddings, label):
41 |         embeddings = l2_norm(embeddings, axis=1)
42 |         kernel_norm = l2_norm(self.kernel, axis=0)
43 |         cos_theta = torch.mm(embeddings, kernel_norm)
44 |         cos_theta = cos_theta.clamp(-1, 1)  # for numerical stability
45 |         index = torch.where(label != -1)[0]
46 |         m_hot = torch.zeros(index.size()[0], cos_theta.size()[1], device=cos_theta.device)
47 |         # draw a per-sample margin from N(m, std); clamping to [m - std, m + std] speeds up convergence
48 |         margin = torch.normal(mean=self.m, std=self.std, size=label[index, None].size(), device=cos_theta.device)
49 |         if self.plus:
50 |             with torch.no_grad():
51 |                 distmat = cos_theta[index, label.view(-1)].detach().clone()
52 |                 _, indices_cosine = torch.sort(distmat, dim=0, descending=True)
53 |                 margin, _ = torch.sort(margin, dim=0)
54 |                 m_hot.scatter_(1, label[index, None], margin[indices_cosine])
55 |         else:
56 |             m_hot.scatter_(1, label[index, None], margin)
57 |         cos_theta.acos_()
58 |         cos_theta[index] += m_hot
59 |         cos_theta.cos_().mul_(self.s)
60 |         return cos_theta
61 | 
62 | 
63 | class ElasticCosFace(nn.Module):
64 |     def __init__(self, in_features, out_features, s=64.0, m=0.35, std=0.0125, plus=False):
65 |         super(ElasticCosFace, self).__init__()
66 |         self.in_features = in_features
67 |         self.out_features = out_features
68 |         self.s = s
69 |         self.m = m
70 |         self.kernel = nn.Parameter(torch.FloatTensor(in_features, out_features))
71 |         nn.init.normal_(self.kernel, std=0.01)
72 |         self.std = std
73 |         self.plus = plus
74 | 
75 |     def forward(self, embeddings, label):
76 |         embeddings = l2_norm(embeddings, axis=1)
77 |         kernel_norm = l2_norm(self.kernel, axis=0)
78 |         cos_theta = torch.mm(embeddings, kernel_norm)
79 |         cos_theta = cos_theta.clamp(-1, 1)  # for numerical stability
80 |         index = torch.where(label != -1)[0]
81 |         m_hot = torch.zeros(index.size()[0], cos_theta.size()[1], device=cos_theta.device)
82 |         # draw a per-sample margin from N(m, std)
83 |         margin = torch.normal(mean=self.m, std=self.std, size=label[index, None].size(), device=cos_theta.device)
84 |         if self.plus:
85 |             with torch.no_grad():
86 |                 distmat = cos_theta[index, label.view(-1)].detach().clone()
87 |                 _, indices_cosine = torch.sort(distmat, dim=0, descending=True)
88 |                 margin, _ = torch.sort(margin, dim=0)
89 |                 m_hot.scatter_(1, label[index, None], margin[indices_cosine])
90 |         else:
91 |             m_hot.scatter_(1, label[index, None], margin)
92 |         cos_theta[index] -= m_hot
93 |         ret = cos_theta * self.s
94 |         return ret
95 | 
96 | class CosFace(nn.Module):
97 |     def __init__(self, in_features, out_features, s=64.0, m=0.35):
98 |         super(CosFace, self).__init__()
99 |         self.in_features = in_features
100 |         self.out_features = out_features
101 |         self.s = s
102 |         self.m = m
103 |         self.kernel = nn.Parameter(torch.FloatTensor(in_features, out_features))
104 |         nn.init.normal_(self.kernel, std=0.01)
105 | 
106 |     def forward(self, embeddings, label):
107 |         embeddings = l2_norm(embeddings, axis=1)
108 |         kernel_norm = l2_norm(self.kernel, axis=0)
109 |         cos_theta = torch.mm(embeddings, kernel_norm)
110 |         cos_theta = cos_theta.clamp(-1, 1)  # for numerical stability
111 |         index = torch.where(label != -1)[0]
112 |         m_hot = torch.zeros(index.size()[0], cos_theta.size()[1], device=cos_theta.device)
113 |         m_hot.scatter_(1, label[index, None], self.m)
114 |         cos_theta[index] -= m_hot
115 |         ret = cos_theta * self.s
116 |         return ret
117 | 
118 | 
119 | class ArcFace(nn.Module):
120 |     def __init__(self, in_features, out_features, s=64.0, m=0.50):
121 |         super(ArcFace, self).__init__()
122 |         self.in_features = in_features
123 |         self.out_features = out_features
124 |         self.s = s
125 |         self.m = m
126 |         self.kernel = nn.Parameter(torch.FloatTensor(in_features, out_features))
127 |         nn.init.normal_(self.kernel, std=0.01)
128 | 
129 |     def forward(self, embeddings, label):
130 |         embeddings = l2_norm(embeddings, axis=1)
131 |         kernel_norm = l2_norm(self.kernel, axis=0)
132 |         cos_theta = torch.mm(embeddings, kernel_norm)
133 |         cos_theta = cos_theta.clamp(-1, 1)  # for numerical stability
134 |         index = torch.where(label != -1)[0]
135 |         m_hot = torch.zeros(index.size()[0], cos_theta.size()[1], device=cos_theta.device)
136 |         m_hot.scatter_(1, label[index, None], self.m)
137 |         cos_theta.acos_()
138 |         cos_theta[index] += m_hot
139 |         cos_theta.cos_().mul_(self.s)
140 |         return cos_theta
--------------------------------------------------------------------------------
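The elastic losses above differ from ArcFace/CosFace only in how the margin is chosen: every sample draws its own margin from a normal distribution, and the plus variants additionally sort the drawn margins against the sample-to-class cosines (see the indexing in the code above for the exact assignment). A small sketch of the sampling step, using the ElasticArcFace defaults:

```python
# Per-sample margin draw used by the elastic losses; a fixed m becomes N(m, std).
import torch

m, std, batch = 0.50, 0.0125, 8
margin = torch.normal(mean=m, std=std, size=(batch, 1))
print(margin.squeeze())  # margins scattered tightly around 0.50
```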
/utils/modelFLOPS.py:
--------------------------------------------------------------------------------
1 | import logging
2 | 
3 | from pytorch_model_summary import summary
4 | 
5 | import torch
6 | 
7 | from utils.countFLOPS import count_model_flops
8 | 
9 | from backbones.iresnet import iresnet100
10 | 
11 | from config.config import config as cfg
12 | 
13 | if __name__ == "__main__":
14 |     # load model
15 |     if cfg.network == "iresnet100":
16 |         backbone = iresnet100(num_features=cfg.embedding_size)
17 |     else:
18 |         logging.info("load backbone failed: unknown network %s!" % cfg.network)
19 |         exit()
20 | 
21 |     print(summary(backbone, torch.zeros((1, 3, 112, 112)), show_input=False))
22 | 
23 |     flops = count_model_flops(backbone)
24 | 
25 |     print(flops)
26 | 
27 |     # model.eval()
28 |     # tic = time.time()
29 |     # model.forward(torch.zeros((1, 3, 112, 112)))
30 |     # end = time.time()
31 |     # print(end - tic)
--------------------------------------------------------------------------------
/utils/utils_amp.py:
--------------------------------------------------------------------------------
1 | from typing import Dict, List
2 | 
3 | import torch
4 | from torch._six import container_abcs
5 | from torch.cuda.amp import GradScaler
6 | 
7 | 
8 | class _MultiDeviceReplicator(object):
9 |     """
10 |     Lazily serves copies of a tensor to requested devices. Copies are cached per-device.
11 |     """
12 | 
13 |     def __init__(self, master_tensor: torch.Tensor) -> None:
14 |         assert master_tensor.is_cuda
15 |         self.master = master_tensor
16 |         self._per_device_tensors: Dict[torch.device, torch.Tensor] = {}
17 | 
18 |     def get(self, device) -> torch.Tensor:
19 |         retval = self._per_device_tensors.get(device, None)
20 |         if retval is None:
21 |             retval = self.master.to(device=device, non_blocking=True, copy=True)
22 |             self._per_device_tensors[device] = retval
23 |         return retval
24 | 
25 | 
26 | class MaxClipGradScaler(GradScaler):
27 |     def __init__(self, init_scale, max_scale: float, growth_interval=100):
28 |         GradScaler.__init__(self, init_scale=init_scale, growth_interval=growth_interval)
29 |         self.max_scale = max_scale
30 | 
31 |     def scale_clip(self):
32 |         if self.get_scale() == self.max_scale:
33 |             self.set_growth_factor(1)
34 |         elif self.get_scale() < self.max_scale:
35 |             self.set_growth_factor(2)
36 |         elif self.get_scale() > self.max_scale:
37 |             self._scale.fill_(self.max_scale)
38 |             self.set_growth_factor(1)
39 | 
40 |     def scale(self, outputs):
41 |         """
42 |         Multiplies ('scales') a tensor or list of tensors by the scale factor.
43 | 
44 |         Returns scaled outputs. If this instance of :class:`GradScaler` is not enabled, outputs are returned
45 |         unmodified.
46 | 
47 |         Arguments:
48 |             outputs (Tensor or iterable of Tensors): Outputs to scale.
49 |         """
50 |         if not self._enabled:
51 |             return outputs
52 |         self.scale_clip()
53 |         # Short-circuit for the common case.
54 |         if isinstance(outputs, torch.Tensor):
55 |             assert outputs.is_cuda
56 |             if self._scale is None:
57 |                 self._lazy_init_scale_growth_tracker(outputs.device)
58 |             assert self._scale is not None
59 |             return outputs * self._scale.to(device=outputs.device, non_blocking=True)
60 | 
61 |         # Invoke the more complex machinery only if we're treating multiple outputs.
62 |         stash: List[_MultiDeviceReplicator] = []  # holds a reference that can be overwritten by apply_scale
63 | 
64 |         def apply_scale(val):
65 |             if isinstance(val, torch.Tensor):
66 |                 assert val.is_cuda
67 |                 if len(stash) == 0:
68 |                     if self._scale is None:
69 |                         self._lazy_init_scale_growth_tracker(val.device)
70 |                     assert self._scale is not None
71 |                     stash.append(_MultiDeviceReplicator(self._scale))
72 |                 return val * stash[0].get(val.device)
73 |             elif isinstance(val, container_abcs.Iterable):
74 |                 iterable = map(apply_scale, val)
75 |                 if isinstance(val, list) or isinstance(val, tuple):
76 |                     return type(val)(iterable)
77 |                 else:
78 |                     return iterable
79 |             else:
80 |                 raise ValueError("outputs must be a Tensor or an iterable of Tensors")
81 |         return apply_scale(outputs)
--------------------------------------------------------------------------------
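MaxClipGradScaler behaves like torch.cuda.amp.GradScaler, except that scale growth stops once the scale reaches max_scale. A hedged usage sketch; the model, hyperparameters, and the fact that this scaler is wired into a training step at all are illustrative, and a CUDA device is required:

```python
# Sketch: one mixed-precision step with MaxClipGradScaler (requires CUDA).
import torch
from utils.utils_amp import MaxClipGradScaler

model = torch.nn.Linear(512, 10).cuda()
opt = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = MaxClipGradScaler(init_scale=128.0, max_scale=512.0, growth_interval=100)

x = torch.randn(8, 512).cuda()
loss = model(x).sum()
scaler.scale(loss).backward()  # growth factor is capped once the scale hits max_scale
scaler.step(opt)
scaler.update()
```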
/utils/utils_callbacks.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 | import time
4 | from typing import List
5 | 
6 | import torch
7 | 
8 | from eval import verification
9 | from utils.utils_logging import AverageMeter
10 | 
11 | 
12 | class CallBackVerification(object):
13 |     def __init__(self, frequent, rank, val_targets, rec_prefix, image_size=(112, 112)):
14 |         self.frequent: int = frequent
15 |         self.rank: int = rank
16 |         self.highest_acc: float = 0.0
17 |         self.highest_acc_list: List[float] = [0.0] * len(val_targets)
18 |         self.ver_list: List[object] = []
19 |         self.ver_name_list: List[str] = []
20 |         if self.rank == 0:
21 |             self.init_dataset(val_targets=val_targets, data_dir=rec_prefix, image_size=image_size)
22 | 
23 |     def ver_test(self, backbone: torch.nn.Module, global_step: int):
24 |         results = []
25 |         for i in range(len(self.ver_list)):
26 |             acc1, std1, acc2, std2, xnorm, embeddings_list = verification.test(
27 |                 self.ver_list[i], backbone, 10, 10)
28 |             logging.info('[%s][%d]XNorm: %f' % (self.ver_name_list[i], global_step, xnorm))
29 |             logging.info('[%s][%d]Accuracy-Flip: %1.5f+-%1.5f' % (self.ver_name_list[i], global_step, acc2, std2))
30 |             if acc2 > self.highest_acc_list[i]:
31 |                 self.highest_acc_list[i] = acc2
32 |             logging.info(
33 |                 '[%s][%d]Accuracy-Highest: %1.5f' % (self.ver_name_list[i], global_step, self.highest_acc_list[i]))
34 |             results.append(acc2)
35 |         return results
36 | 
37 |     def init_dataset(self, val_targets, data_dir, image_size):
38 |         for name in val_targets:
39 |             path = os.path.join(data_dir, name + ".bin")
40 |             if os.path.exists(path):
41 |                 data_set = verification.load_bin(path, image_size)
42 |                 self.ver_list.append(data_set)
43 |                 self.ver_name_list.append(name)
44 | 
45 |     def __call__(self, num_update, backbone: torch.nn.Module):
46 |         if self.rank == 0 and num_update > 0 and num_update % self.frequent == 0:
47 |             backbone.eval()
48 |             self.ver_test(backbone, num_update)
49 |             backbone.train()
50 | 
51 | 
52 | class CallBackLogging(object):
53 |     def __init__(self, frequent, rank, total_step, batch_size, world_size, writer=None, resume=0, rem_total_steps=None):
54 |         self.frequent: int = frequent
55 |         self.rank: int = rank
56 |         self.time_start = time.time()
57 |         self.total_step: int = total_step
58 |         self.batch_size: int = batch_size
59 |         self.world_size: int = world_size
60 |         self.writer = writer
61 |         self.resume = resume
62 |         self.rem_total_steps = rem_total_steps
63 | 
64 |         self.init = False
65 |         self.tic = 0
66 | 
67 |     def __call__(self, global_step, loss: AverageMeter, epoch: int):
68 |         if self.rank == 0 and global_step > 0 and global_step % self.frequent == 0:
69 |             if self.init:
70 |                 try:
71 |                     speed: float = self.frequent * self.batch_size / (time.time() - self.tic)
72 |                     speed_total = speed * self.world_size
73 |                 except ZeroDivisionError:
74 |                     speed_total = float('inf')
75 | 
76 |                 time_now = (time.time() - self.time_start) / 3600
77 |                 # TODO: resume time_total is not working
78 |                 if self.resume:
79 |                     time_total = time_now / ((global_step + 1) / self.rem_total_steps)
80 |                 else:
81 |                     time_total = time_now / ((global_step + 1) / self.total_step)
82 |                 time_for_end = time_total - time_now
83 |                 if self.writer is not None:
84 |                     self.writer.add_scalar('time_for_end', time_for_end, global_step)
85 |                     self.writer.add_scalar('loss', loss.avg, global_step)
86 |                 msg = "Speed %.2f samples/sec Loss %.4f Epoch: %d Global Step: %d Required: %.1f hours" % (
87 |                     speed_total, loss.avg, epoch, global_step, time_for_end
88 |                 )
89 |                 logging.info(msg)
90 |                 loss.reset()
91 |                 self.tic = time.time()
92 |             else:
93 |                 self.init = True
94 |                 self.tic = time.time()
95 | 
96 | 
97 | class CallBackModelCheckpoint(object):
98 |     def __init__(self, rank, output="./"):
99 |         self.rank: int = rank
100 |         self.output: str = output
101 | 
102 |     def __call__(self, global_step, backbone: torch.nn.Module, header: torch.nn.Module = None):
103 |         if global_step > 100 and self.rank == 0:
104 |             torch.save(backbone.module.state_dict(), os.path.join(self.output, str(global_step) + "backbone.pth"))
105 |         if global_step > 100 and self.rank == 0 and header is not None:
106 |             torch.save(header.module.state_dict(), os.path.join(self.output, str(global_step) + "header.pth"))
--------------------------------------------------------------------------------
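The three callbacks divide the bookkeeping of train.py: CallBackLogging prints throughput and an ETA, CallBackVerification runs the .bin benchmarks, and CallBackModelCheckpoint saves weights on rank 0. A sketch of the wiring for a single process; frequencies, sizes, and paths are illustrative:

```python
# Sketch of the callback wiring from train.py, single process (rank 0).
from utils.utils_callbacks import CallBackLogging, CallBackModelCheckpoint
from utils.utils_logging import AverageMeter

callback_logging = CallBackLogging(50, rank=0, total_step=10000, batch_size=128, world_size=1)
callback_checkpoint = CallBackModelCheckpoint(rank=0, output="./output")
loss_meter = AverageMeter()

# inside the step loop:
#     loss_meter.update(loss_v.item(), 1)                # running average of the loss
#     callback_logging(global_step, loss_meter, epoch)   # logs and resets every 50 steps
# after each epoch:
#     callback_checkpoint(global_step, backbone, header) # expects DDP-wrapped modules (.module)
```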
/utils/utils_logging.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 | import sys
4 | 
5 | 
6 | class AverageMeter(object):
7 |     """Computes and stores the average and current value
8 |     """
9 | 
10 |     def __init__(self):
11 |         self.val = None
12 |         self.avg = None
13 |         self.sum = None
14 |         self.count = None
15 |         self.reset()
16 | 
17 |     def reset(self):
18 |         self.val = 0
19 |         self.avg = 0
20 |         self.sum = 0
21 |         self.count = 0
22 | 
23 |     def update(self, val, n=1):
24 |         self.val = val
25 |         self.sum += val * n
26 |         self.count += n
27 |         self.avg = self.sum / self.count
28 | 
29 | 
30 | def init_logging(log_root, rank, models_root, logfile=None):
31 |     if rank == 0:
32 |         log_root.setLevel(logging.INFO)
33 |         formatter = logging.Formatter("Training: %(asctime)s-%(message)s")
34 |         file_name = "training.log" if logfile is None else logfile
35 |         handler_file = logging.FileHandler(os.path.join(models_root, file_name))
36 |         handler_stream = logging.StreamHandler(sys.stdout)
37 |         handler_file.setFormatter(formatter)
38 |         handler_stream.setFormatter(formatter)
39 |         log_root.addHandler(handler_file)
40 |         log_root.addHandler(handler_stream)
41 |         log_root.info('rank_id: %d' % rank)
42 | 
--------------------------------------------------------------------------------
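AverageMeter above is the running mean behind the loss display; a quick check of the arithmetic:

```python
# AverageMeter keeps val/sum/count and exposes the running average.
from utils.utils_logging import AverageMeter

meter = AverageMeter()
for v in [2.0, 1.0, 0.5]:
    meter.update(v, n=1)
print(round(meter.avg, 4))  # 1.1667 == (2.0 + 1.0 + 0.5) / 3
```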