├── README.md ├── backbones ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-36.pyc │ ├── __init__.cpython-38.pyc │ ├── iresnet.cpython-36.pyc │ ├── iresnet.cpython-38.pyc │ ├── mobilefacenet.cpython-36.pyc │ └── mobilefacenet.cpython-38.pyc ├── iresnet.py ├── iresnet2060.py └── mobilefacenet.py ├── config ├── logging.conf └── model_conf.yaml ├── configs ├── 3millions.py ├── 3millions_pfc.py ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-36.pyc │ ├── __init__.cpython-38.pyc │ ├── base.cpython-36.pyc │ ├── base.cpython-38.pyc │ ├── imint5M3_r160-dev.cpython-38.pyc │ ├── imint5M3_r160.cpython-36.pyc │ ├── imint5M3_r160.cpython-38.pyc │ └── imint5M3_r18.cpython-36.pyc ├── base.py ├── glint360k_mbf.py ├── glint360k_r100-adaface.py ├── glint360k_r100.py ├── glint360k_r18.py ├── glint360k_r34.py ├── glint360k_r50.py ├── ms1mv3_mbf.py ├── ms1mv3_r18.py ├── ms1mv3_r2060.py ├── ms1mv3_r34.py ├── ms1mv3_r50-adaface.py ├── ms1mv3_r50.py └── speed.py ├── dataset.py ├── eval ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-36.pyc │ ├── __init__.cpython-38.pyc │ ├── verification.cpython-36.pyc │ └── verification.cpython-38.pyc └── verification.py ├── eval_ijbc.py ├── face_masker.py ├── inference.py ├── losses.py ├── onnx_helper.py ├── onnx_ijbc.py ├── partial_fc.py ├── requirement.txt ├── run.sh ├── torch2onnx.py ├── train.py ├── transforms.py └── utils ├── __init__.py ├── __pycache__ ├── __init__.cpython-36.pyc ├── __init__.cpython-38.pyc ├── read_info.cpython-36.pyc ├── read_info.cpython-38.pyc ├── utils_amp.cpython-36.pyc ├── utils_amp.cpython-38.pyc ├── utils_callbacks.cpython-36.pyc ├── utils_callbacks.cpython-38.pyc ├── utils_config.cpython-36.pyc ├── utils_config.cpython-38.pyc ├── utils_logging.cpython-36.pyc └── utils_logging.cpython-38.pyc ├── cython ├── README.md ├── UNKNOWN.egg-info │ ├── PKG-INFO │ ├── SOURCES.txt │ ├── dependency_links.txt │ └── top_level.txt ├── build │ ├── lib.linux-x86_64-3.6 │ │ └── render.cpython-36m-x86_64-linux-gnu.so │ └── temp.linux-x86_64-3.6 │ │ └── render.o ├── dist │ └── UNKNOWN-0.0.0-py3.6-linux-x86_64.egg ├── render.c ├── render.cpython-36m-x86_64-linux-gnu.so ├── render.pyx └── setup.py ├── plot.py ├── read_info.py ├── render.py ├── utils_amp.py ├── utils_callbacks.py ├── utils_config.py ├── utils_logging.py └── utils_os.py /README.md: -------------------------------------------------------------------------------- 1 | # Distributed Arcface/Adaface Training in Pytorch 2 | 3 | Modified old version of PartialFC to work with AdaFace for both normal and distributed training. 4 | - [Killing Two Birds With One Stone: Efficient and Robust Training of Face Recognition CNNs by Partial FC (CVPR-2022)](https://arxiv.org/abs/2203.15565) 5 | - [AdaFace: Quality Adaptive Margin for Face Recognition (CVPR-2022)](https://arxiv.org/abs/2204.00964) 6 | 7 | ## Benchmark 8 | 9 | ### 1. 
CASIA-Webface (10k ids, 0.5M images) 10 | Note: we currently observe the opposite relationship between feature norm and image quality to the one described in the AdaFace paper. 11 | - **Accuracy** 12 | 13 | | Model | Backbone | Sample Rate | LFW | CFP-FP | AGEDB-30 | LFW Blur | CFP-FP Blur | AGEDB-30 Blur | Average | 14 | |:-------:|:--------:|:-----------:|-------------------|-------------------|-------------------|-------------------|-------------------|-------------------|------------| 15 | | Arcface | IR50 | 1.0 | 0.9920 | **0.9601** | 0.9365 | 0.9323 | 0.8517 | 0.8357 | 0.9181 | 16 | | Adaface | IR50 | 1.0 | **0.9923** | 0.9587 | **0.9405** | **0.9563** | **0.8667** | **0.8632** | **0.9296** | 17 | | Arcface | IR50 | 0.3 | **0.9923** | 0.9596 | 0.9390 | 0.9323 | 0.8480 | 0.8325 | 0.9173 | 18 | | Adaface | IR50 | 0.3 | 0.9915 | 0.9567 | 0.9362 | 0.9532 | 0.8570 | 0.8548 | 0.9249 | 19 | 20 | - **Feature Norm** 21 | 22 | | Model | Backbone | Sample Rate | LFW | CFP-FP | AGEDB-30 | LFW Blur | CFP-FP Blur | AGEDB-30 Blur | 23 | |:-------:|:--------:|:-----------:|--------------|--------------|--------------|--------------|--------------|---------------| 24 | | Arcface | IR50 | 1.0 | 12.72 | 12.97 | 13.23 | 12.56 | 13.16 | 12.86 | 25 | | Adaface | IR50 | 1.0 | 5.34 | 7.9 | 5.91 | 10.36 | 42.41 | 10.99 | 26 | | Arcface | IR50 | 0.3 | 14.3 | 14.1 | 14.54 | 14 | 14.68 | 14.29 | 27 | | Adaface | IR50 | 0.3 | 6.09 | 10.23 | 6.74 | 10.7 | 47.28 | 11.7 | 28 | 29 | ## How to Train 30 | 31 | To train a model, run `train.py` with the path to the desired config file: 32 | 33 | ### 1. Single node, 8 GPUs: 34 | 35 | ```shell 36 | python -m torch.distributed.launch --nproc_per_node=8 --nnodes=1 --node_rank=0 --master_addr="127.0.0.1" --master_port=1234 train.py configs/ms1mv3_r50-adaface.py 37 | ``` 38 | 39 | ### 2. Multiple nodes, each node 8 GPUs (not tested with AdaFace): 40 | 41 | Node 0: 42 | 43 | ```shell 44 | python -m torch.distributed.launch --nproc_per_node=8 --nnodes=2 --node_rank=0 --master_addr="ip1" --master_port=1234 train.py configs/ms1mv3_r50-adaface.py 45 | ``` 46 | 47 | Node 1: 48 | 49 | ```shell 50 | python -m torch.distributed.launch --nproc_per_node=8 --nnodes=2 --node_rank=1 --master_addr="ip1" --master_port=1234 train.py configs/ms1mv3_r50-adaface.py 51 | ``` 52 | 53 | ## To do 54 | - [ ] Add ViT models 55 | - [ ] Report a comparison between AdaFace, ArcFace & CosFace (currently in training on a dataset with 5M ids and 100M images) 56 | - [ ] Results on common large-scale face recognition datasets (MS1MV2, MS1MV3, Glint360k, WebFace) 57 | 58 | ## Reference 59 | - [Insightface](https://github.com/deepinsight/insightface) 60 | - [Adaface](https://github.com/mk-minchul/AdaFace) 61 | -------------------------------------------------------------------------------- /backbones/__init__.py: -------------------------------------------------------------------------------- 1 | from .iresnet import iresnet18, iresnet34, iresnet50, iresnet100, iresnet160, iresnet200 2 | from .mobilefacenet import get_mbf 3 | 4 | 5 | def get_model(name, **kwargs): 6 | # resnet 7 | if name == "r18": 8 | return iresnet18(False, **kwargs) 9 | elif name == "r34": 10 | return iresnet34(False, **kwargs) 11 | elif name == "r50": 12 | return iresnet50(False, **kwargs) 13 | elif name == "r100": 14 | return iresnet100(False, **kwargs) 15 | elif name == "r160": 16 | return iresnet160(False, **kwargs) 17 | elif name == "r200": 18 | return iresnet200(False, **kwargs) 19 | elif name == "r2060": 20 | from .iresnet2060 import iresnet2060 21 | return iresnet2060(False, 
**kwargs) 22 | elif name == "mbf": 23 | fp16 = kwargs.get("fp16", False) 24 | num_features = kwargs.get("num_features", 512) 25 | return get_mbf(fp16=fp16, num_features=num_features) 26 | else: 27 | raise ValueError() -------------------------------------------------------------------------------- /backbones/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NNDam/adaface-partialfc/7630aa23cc792538cd18b758a3a796977e00fd17/backbones/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /backbones/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NNDam/adaface-partialfc/7630aa23cc792538cd18b758a3a796977e00fd17/backbones/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /backbones/__pycache__/iresnet.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NNDam/adaface-partialfc/7630aa23cc792538cd18b758a3a796977e00fd17/backbones/__pycache__/iresnet.cpython-36.pyc -------------------------------------------------------------------------------- /backbones/__pycache__/iresnet.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NNDam/adaface-partialfc/7630aa23cc792538cd18b758a3a796977e00fd17/backbones/__pycache__/iresnet.cpython-38.pyc -------------------------------------------------------------------------------- /backbones/__pycache__/mobilefacenet.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NNDam/adaface-partialfc/7630aa23cc792538cd18b758a3a796977e00fd17/backbones/__pycache__/mobilefacenet.cpython-36.pyc -------------------------------------------------------------------------------- /backbones/__pycache__/mobilefacenet.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NNDam/adaface-partialfc/7630aa23cc792538cd18b758a3a796977e00fd17/backbones/__pycache__/mobilefacenet.cpython-38.pyc -------------------------------------------------------------------------------- /backbones/iresnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | __all__ = ['iresnet18', 'iresnet34', 'iresnet50', 'iresnet100', 'iresnet160', 'iresnet200'] 5 | 6 | 7 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): 8 | """3x3 convolution with padding""" 9 | return nn.Conv2d(in_planes, 10 | out_planes, 11 | kernel_size=3, 12 | stride=stride, 13 | padding=dilation, 14 | groups=groups, 15 | bias=False, 16 | dilation=dilation) 17 | 18 | 19 | def conv1x1(in_planes, out_planes, stride=1): 20 | """1x1 convolution""" 21 | return nn.Conv2d(in_planes, 22 | out_planes, 23 | kernel_size=1, 24 | stride=stride, 25 | bias=False) 26 | 27 | 28 | class IBasicBlock(nn.Module): 29 | expansion = 1 30 | def __init__(self, inplanes, planes, stride=1, downsample=None, 31 | groups=1, base_width=64, dilation=1): 32 | super(IBasicBlock, self).__init__() 33 | if groups != 1 or base_width != 64: 34 | raise ValueError('BasicBlock only supports groups=1 and base_width=64') 35 | if dilation > 1: 36 | raise 
NotImplementedError("Dilation > 1 not supported in BasicBlock") 37 | self.bn1 = nn.BatchNorm2d(inplanes, eps=1e-05,) 38 | self.conv1 = conv3x3(inplanes, planes) 39 | self.bn2 = nn.BatchNorm2d(planes, eps=1e-05,) 40 | self.prelu = nn.PReLU(planes) 41 | self.conv2 = conv3x3(planes, planes, stride) 42 | self.bn3 = nn.BatchNorm2d(planes, eps=1e-05,) 43 | self.downsample = downsample 44 | self.stride = stride 45 | 46 | def forward(self, x): 47 | identity = x 48 | out = self.bn1(x) 49 | out = self.conv1(out) 50 | out = self.bn2(out) 51 | out = self.prelu(out) 52 | out = self.conv2(out) 53 | out = self.bn3(out) 54 | if self.downsample is not None: 55 | identity = self.downsample(x) 56 | out += identity 57 | return out 58 | 59 | 60 | class IResNet(nn.Module): 61 | fc_scale = 7 * 7 62 | def __init__(self, 63 | block, layers, dropout=0, num_features=512, zero_init_residual=False, 64 | groups=1, width_per_group=64, replace_stride_with_dilation=None, fp16=False): 65 | super(IResNet, self).__init__() 66 | self.fp16 = fp16 67 | self.inplanes = 64 68 | self.dilation = 1 69 | if replace_stride_with_dilation is None: 70 | replace_stride_with_dilation = [False, False, False] 71 | if len(replace_stride_with_dilation) != 3: 72 | raise ValueError("replace_stride_with_dilation should be None " 73 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) 74 | self.groups = groups 75 | self.base_width = width_per_group 76 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False) 77 | self.bn1 = nn.BatchNorm2d(self.inplanes, eps=1e-05) 78 | self.prelu = nn.PReLU(self.inplanes) 79 | self.layer1 = self._make_layer(block, 64, layers[0], stride=2) 80 | self.layer2 = self._make_layer(block, 81 | 128, 82 | layers[1], 83 | stride=2, 84 | dilate=replace_stride_with_dilation[0]) 85 | self.layer3 = self._make_layer(block, 86 | 256, 87 | layers[2], 88 | stride=2, 89 | dilate=replace_stride_with_dilation[1]) 90 | self.layer4 = self._make_layer(block, 91 | 512, 92 | layers[3], 93 | stride=2, 94 | dilate=replace_stride_with_dilation[2]) 95 | self.bn2 = nn.BatchNorm2d(512 * block.expansion, eps=1e-05,) 96 | self.dropout = nn.Dropout(p=dropout, inplace=True) 97 | self.fc = nn.Linear(512 * block.expansion * self.fc_scale, num_features) 98 | self.features = nn.BatchNorm1d(num_features, eps=1e-05) 99 | nn.init.constant_(self.features.weight, 1.0) 100 | self.features.weight.requires_grad = False 101 | 102 | for m in self.modules(): 103 | if isinstance(m, nn.Conv2d): 104 | nn.init.normal_(m.weight, 0, 0.1) 105 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 106 | nn.init.constant_(m.weight, 1) 107 | nn.init.constant_(m.bias, 0) 108 | 109 | if zero_init_residual: 110 | for m in self.modules(): 111 | if isinstance(m, IBasicBlock): 112 | nn.init.constant_(m.bn2.weight, 0) 113 | 114 | def _make_layer(self, block, planes, blocks, stride=1, dilate=False): 115 | downsample = None 116 | previous_dilation = self.dilation 117 | if dilate: 118 | self.dilation *= stride 119 | stride = 1 120 | if stride != 1 or self.inplanes != planes * block.expansion: 121 | downsample = nn.Sequential( 122 | conv1x1(self.inplanes, planes * block.expansion, stride), 123 | nn.BatchNorm2d(planes * block.expansion, eps=1e-05, ), 124 | ) 125 | layers = [] 126 | layers.append( 127 | block(self.inplanes, planes, stride, downsample, self.groups, 128 | self.base_width, previous_dilation)) 129 | self.inplanes = planes * block.expansion 130 | for _ in range(1, blocks): 131 | layers.append( 132 | block(self.inplanes, 133 | 
planes, 134 | groups=self.groups, 135 | base_width=self.base_width, 136 | dilation=self.dilation)) 137 | 138 | return nn.Sequential(*layers) 139 | 140 | def forward(self, x): 141 | with torch.cuda.amp.autocast(self.fp16): 142 | x = self.conv1(x) 143 | x = self.bn1(x) 144 | x = self.prelu(x) 145 | x = self.layer1(x) 146 | x = self.layer2(x) 147 | x = self.layer3(x) 148 | x = self.layer4(x) 149 | x = self.bn2(x) 150 | x = torch.flatten(x, 1) 151 | x = self.dropout(x) 152 | x = self.fc(x.float() if self.fp16 else x) 153 | x = self.features(x) 154 | return x 155 | 156 | 157 | def _iresnet(arch, block, layers, pretrained, progress, **kwargs): 158 | model = IResNet(block, layers, **kwargs) 159 | if pretrained: 160 | raise ValueError() 161 | return model 162 | 163 | 164 | def iresnet18(pretrained=False, progress=True, **kwargs): 165 | return _iresnet('iresnet18', IBasicBlock, [2, 2, 2, 2], pretrained, 166 | progress, **kwargs) 167 | 168 | 169 | def iresnet34(pretrained=False, progress=True, **kwargs): 170 | return _iresnet('iresnet34', IBasicBlock, [3, 4, 6, 3], pretrained, 171 | progress, **kwargs) 172 | 173 | 174 | def iresnet50(pretrained=False, progress=True, **kwargs): 175 | return _iresnet('iresnet50', IBasicBlock, [3, 4, 14, 3], pretrained, 176 | progress, **kwargs) 177 | 178 | 179 | def iresnet100(pretrained=False, progress=True, **kwargs): 180 | return _iresnet('iresnet100', IBasicBlock, [3, 13, 30, 3], pretrained, 181 | progress, **kwargs) 182 | 183 | def iresnet160(pretrained=False, progress=True, **kwargs): 184 | return _iresnet('iresnet160', IBasicBlock, [3, 24, 49, 3], pretrained, 185 | progress, **kwargs) 186 | 187 | def iresnet200(pretrained=False, progress=True, **kwargs): 188 | return _iresnet('iresnet200', IBasicBlock, [6, 26, 60, 6], pretrained, 189 | progress, **kwargs) 190 | 191 | -------------------------------------------------------------------------------- /backbones/iresnet2060.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | assert torch.__version__ >= "1.8.1" 5 | from torch.utils.checkpoint import checkpoint_sequential 6 | 7 | __all__ = ['iresnet2060'] 8 | 9 | 10 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): 11 | """3x3 convolution with padding""" 12 | return nn.Conv2d(in_planes, 13 | out_planes, 14 | kernel_size=3, 15 | stride=stride, 16 | padding=dilation, 17 | groups=groups, 18 | bias=False, 19 | dilation=dilation) 20 | 21 | 22 | def conv1x1(in_planes, out_planes, stride=1): 23 | """1x1 convolution""" 24 | return nn.Conv2d(in_planes, 25 | out_planes, 26 | kernel_size=1, 27 | stride=stride, 28 | bias=False) 29 | 30 | 31 | class IBasicBlock(nn.Module): 32 | expansion = 1 33 | 34 | def __init__(self, inplanes, planes, stride=1, downsample=None, 35 | groups=1, base_width=64, dilation=1): 36 | super(IBasicBlock, self).__init__() 37 | if groups != 1 or base_width != 64: 38 | raise ValueError('BasicBlock only supports groups=1 and base_width=64') 39 | if dilation > 1: 40 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock") 41 | self.bn1 = nn.BatchNorm2d(inplanes, eps=1e-05, ) 42 | self.conv1 = conv3x3(inplanes, planes) 43 | self.bn2 = nn.BatchNorm2d(planes, eps=1e-05, ) 44 | self.prelu = nn.PReLU(planes) 45 | self.conv2 = conv3x3(planes, planes, stride) 46 | self.bn3 = nn.BatchNorm2d(planes, eps=1e-05, ) 47 | self.downsample = downsample 48 | self.stride = stride 49 | 50 | def forward(self, x): 51 | identity = x 52 | out = self.bn1(x) 53 | out = 
self.conv1(out) 54 | out = self.bn2(out) 55 | out = self.prelu(out) 56 | out = self.conv2(out) 57 | out = self.bn3(out) 58 | if self.downsample is not None: 59 | identity = self.downsample(x) 60 | out += identity 61 | return out 62 | 63 | 64 | class IResNet(nn.Module): 65 | fc_scale = 7 * 7 66 | 67 | def __init__(self, 68 | block, layers, dropout=0, num_features=512, zero_init_residual=False, 69 | groups=1, width_per_group=64, replace_stride_with_dilation=None, fp16=False): 70 | super(IResNet, self).__init__() 71 | self.fp16 = fp16 72 | self.inplanes = 64 73 | self.dilation = 1 74 | if replace_stride_with_dilation is None: 75 | replace_stride_with_dilation = [False, False, False] 76 | if len(replace_stride_with_dilation) != 3: 77 | raise ValueError("replace_stride_with_dilation should be None " 78 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) 79 | self.groups = groups 80 | self.base_width = width_per_group 81 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False) 82 | self.bn1 = nn.BatchNorm2d(self.inplanes, eps=1e-05) 83 | self.prelu = nn.PReLU(self.inplanes) 84 | self.layer1 = self._make_layer(block, 64, layers[0], stride=2) 85 | self.layer2 = self._make_layer(block, 86 | 128, 87 | layers[1], 88 | stride=2, 89 | dilate=replace_stride_with_dilation[0]) 90 | self.layer3 = self._make_layer(block, 91 | 256, 92 | layers[2], 93 | stride=2, 94 | dilate=replace_stride_with_dilation[1]) 95 | self.layer4 = self._make_layer(block, 96 | 512, 97 | layers[3], 98 | stride=2, 99 | dilate=replace_stride_with_dilation[2]) 100 | self.bn2 = nn.BatchNorm2d(512 * block.expansion, eps=1e-05, ) 101 | self.dropout = nn.Dropout(p=dropout, inplace=True) 102 | self.fc = nn.Linear(512 * block.expansion * self.fc_scale, num_features) 103 | self.features = nn.BatchNorm1d(num_features, eps=1e-05) 104 | nn.init.constant_(self.features.weight, 1.0) 105 | self.features.weight.requires_grad = False 106 | 107 | for m in self.modules(): 108 | if isinstance(m, nn.Conv2d): 109 | nn.init.normal_(m.weight, 0, 0.1) 110 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 111 | nn.init.constant_(m.weight, 1) 112 | nn.init.constant_(m.bias, 0) 113 | 114 | if zero_init_residual: 115 | for m in self.modules(): 116 | if isinstance(m, IBasicBlock): 117 | nn.init.constant_(m.bn2.weight, 0) 118 | 119 | def _make_layer(self, block, planes, blocks, stride=1, dilate=False): 120 | downsample = None 121 | previous_dilation = self.dilation 122 | if dilate: 123 | self.dilation *= stride 124 | stride = 1 125 | if stride != 1 or self.inplanes != planes * block.expansion: 126 | downsample = nn.Sequential( 127 | conv1x1(self.inplanes, planes * block.expansion, stride), 128 | nn.BatchNorm2d(planes * block.expansion, eps=1e-05, ), 129 | ) 130 | layers = [] 131 | layers.append( 132 | block(self.inplanes, planes, stride, downsample, self.groups, 133 | self.base_width, previous_dilation)) 134 | self.inplanes = planes * block.expansion 135 | for _ in range(1, blocks): 136 | layers.append( 137 | block(self.inplanes, 138 | planes, 139 | groups=self.groups, 140 | base_width=self.base_width, 141 | dilation=self.dilation)) 142 | 143 | return nn.Sequential(*layers) 144 | 145 | def checkpoint(self, func, num_seg, x): 146 | if self.training: 147 | return checkpoint_sequential(func, num_seg, x) 148 | else: 149 | return func(x) 150 | 151 | def forward(self, x): 152 | with torch.cuda.amp.autocast(self.fp16): 153 | x = self.conv1(x) 154 | x = self.bn1(x) 155 | x = self.prelu(x) 156 | x = self.layer1(x) 
157 | x = self.checkpoint(self.layer2, 20, x) 158 | x = self.checkpoint(self.layer3, 100, x) 159 | x = self.layer4(x) 160 | x = self.bn2(x) 161 | x = torch.flatten(x, 1) 162 | x = self.dropout(x) 163 | x = self.fc(x.float() if self.fp16 else x) 164 | x = self.features(x) 165 | return x 166 | 167 | 168 | def _iresnet(arch, block, layers, pretrained, progress, **kwargs): 169 | model = IResNet(block, layers, **kwargs) 170 | if pretrained: 171 | raise ValueError() 172 | return model 173 | 174 | 175 | def iresnet2060(pretrained=False, progress=True, **kwargs): 176 | return _iresnet('iresnet2060', IBasicBlock, [3, 128, 1024 - 128, 3], pretrained, progress, **kwargs) 177 | -------------------------------------------------------------------------------- /backbones/mobilefacenet.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Adapted from https://github.com/cavalleria/cavaface.pytorch/blob/master/backbone/mobilefacenet.py 3 | Original author cavalleria 4 | ''' 5 | 6 | import torch.nn as nn 7 | from torch.nn import Linear, Conv2d, BatchNorm1d, BatchNorm2d, PReLU, Sequential, Module 8 | import torch 9 | 10 | 11 | class Flatten(Module): 12 | def forward(self, x): 13 | return x.view(x.size(0), -1) 14 | 15 | 16 | class ConvBlock(Module): 17 | def __init__(self, in_c, out_c, kernel=(1, 1), stride=(1, 1), padding=(0, 0), groups=1): 18 | super(ConvBlock, self).__init__() 19 | self.layers = nn.Sequential( 20 | Conv2d(in_c, out_c, kernel, groups=groups, stride=stride, padding=padding, bias=False), 21 | BatchNorm2d(num_features=out_c), 22 | PReLU(num_parameters=out_c) 23 | ) 24 | 25 | def forward(self, x): 26 | return self.layers(x) 27 | 28 | 29 | class LinearBlock(Module): 30 | def __init__(self, in_c, out_c, kernel=(1, 1), stride=(1, 1), padding=(0, 0), groups=1): 31 | super(LinearBlock, self).__init__() 32 | self.layers = nn.Sequential( 33 | Conv2d(in_c, out_c, kernel, stride, padding, groups=groups, bias=False), 34 | BatchNorm2d(num_features=out_c) 35 | ) 36 | 37 | def forward(self, x): 38 | return self.layers(x) 39 | 40 | 41 | class DepthWise(Module): 42 | def __init__(self, in_c, out_c, residual=False, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=1): 43 | super(DepthWise, self).__init__() 44 | self.residual = residual 45 | self.layers = nn.Sequential( 46 | ConvBlock(in_c, out_c=groups, kernel=(1, 1), padding=(0, 0), stride=(1, 1)), 47 | ConvBlock(groups, groups, groups=groups, kernel=kernel, padding=padding, stride=stride), 48 | LinearBlock(groups, out_c, kernel=(1, 1), padding=(0, 0), stride=(1, 1)) 49 | ) 50 | 51 | def forward(self, x): 52 | short_cut = None 53 | if self.residual: 54 | short_cut = x 55 | x = self.layers(x) 56 | if self.residual: 57 | output = short_cut + x 58 | else: 59 | output = x 60 | return output 61 | 62 | 63 | class Residual(Module): 64 | def __init__(self, c, num_block, groups, kernel=(3, 3), stride=(1, 1), padding=(1, 1)): 65 | super(Residual, self).__init__() 66 | modules = [] 67 | for _ in range(num_block): 68 | modules.append(DepthWise(c, c, True, kernel, stride, padding, groups)) 69 | self.layers = Sequential(*modules) 70 | 71 | def forward(self, x): 72 | return self.layers(x) 73 | 74 | 75 | class GDC(Module): 76 | def __init__(self, embedding_size): 77 | super(GDC, self).__init__() 78 | self.layers = nn.Sequential( 79 | LinearBlock(512, 512, groups=512, kernel=(7, 7), stride=(1, 1), padding=(0, 0)), 80 | Flatten(), 81 | Linear(512, embedding_size, bias=False), 82 | BatchNorm1d(embedding_size)) 83 | 84 | def 
forward(self, x): 85 | return self.layers(x) 86 | 87 | 88 | class MobileFaceNet(Module): 89 | def __init__(self, fp16=False, num_features=512): 90 | super(MobileFaceNet, self).__init__() 91 | scale = 2 92 | self.fp16 = fp16 93 | self.layers = nn.Sequential( 94 | ConvBlock(3, 64 * scale, kernel=(3, 3), stride=(2, 2), padding=(1, 1)), 95 | ConvBlock(64 * scale, 64 * scale, kernel=(3, 3), stride=(1, 1), padding=(1, 1), groups=64), 96 | DepthWise(64 * scale, 64 * scale, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=128), 97 | Residual(64 * scale, num_block=4, groups=128, kernel=(3, 3), stride=(1, 1), padding=(1, 1)), 98 | DepthWise(64 * scale, 128 * scale, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=256), 99 | Residual(128 * scale, num_block=6, groups=256, kernel=(3, 3), stride=(1, 1), padding=(1, 1)), 100 | DepthWise(128 * scale, 128 * scale, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=512), 101 | Residual(128 * scale, num_block=2, groups=256, kernel=(3, 3), stride=(1, 1), padding=(1, 1)), 102 | ) 103 | self.conv_sep = ConvBlock(128 * scale, 512, kernel=(1, 1), stride=(1, 1), padding=(0, 0)) 104 | self.features = GDC(num_features) 105 | self._initialize_weights() 106 | 107 | def _initialize_weights(self): 108 | for m in self.modules(): 109 | if isinstance(m, nn.Conv2d): 110 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 111 | if m.bias is not None: 112 | m.bias.data.zero_() 113 | elif isinstance(m, nn.BatchNorm2d): 114 | m.weight.data.fill_(1) 115 | m.bias.data.zero_() 116 | elif isinstance(m, nn.Linear): 117 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 118 | if m.bias is not None: 119 | m.bias.data.zero_() 120 | 121 | def forward(self, x): 122 | with torch.cuda.amp.autocast(self.fp16): 123 | x = self.layers(x) 124 | x = self.conv_sep(x.float() if self.fp16 else x) 125 | x = self.features(x) 126 | return x 127 | 128 | 129 | def get_mbf(fp16, num_features): 130 | return MobileFaceNet(fp16, num_features) -------------------------------------------------------------------------------- /config/logging.conf: -------------------------------------------------------------------------------- 1 | [loggers] # loggers object list 2 | keys = root, sdk, api 3 | 4 | [handlers] # handlers object list 5 | keys = consoleHandlers, fileHandlers 6 | 7 | [formatters] # formatters list 8 | keys = fmt 9 | 10 | [logger_root] 11 | level = DEBUG 12 | handlers = consoleHandlers, fileHandlers 13 | 14 | [logger_sdk] # sdk logger 15 | level = DEBUG 16 | handlers = fileHandlers 17 | qualname = sdk 18 | propagate = 0 19 | 20 | [logger_api] # api logger 21 | level = DEBUG 22 | handlers = consoleHandlers 23 | qualname = api 24 | propagate = 0 25 | 26 | [handler_consoleHandlers]# consoleHandlers. 
27 | class = StreamHandler 28 | level = DEBUG 29 | formatter = fmt 30 | args = (sys.stdout,) 31 | 32 | [handler_fileHandlers]# fileHandlers 33 | class = logging.handlers.RotatingFileHandler 34 | level = DEBUG 35 | formatter = fmt 36 | args = ('logs/sdk.log', 'a', 10000, 3, 'UTF-8') 37 | 38 | [formatter_fmt] # fmt format 39 | format = %(levelname)s %(asctime)s %(filename)s: %(lineno)d] %(message)s 40 | datefmt = %Y-%m-%d %H:%M:%S -------------------------------------------------------------------------------- /config/model_conf.yaml: -------------------------------------------------------------------------------- 1 | non-mask: 2 | face_detection: face_detection_1.0 3 | face_alignment: face_alignment_1.0 4 | face_recognition: face_recognition_1.0 5 | mask: 6 | face_detection: face_detection_2.0 7 | face_alignment: face_alignment_2.0 8 | face_recognition: face_recognition_2.0 9 | -------------------------------------------------------------------------------- /configs/3millions.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict as edict 2 | 3 | # configs for test speed 4 | 5 | config = edict() 6 | config.loss = "arcface" 7 | config.network = "r50" 8 | config.resume = False 9 | config.output = None 10 | config.embedding_size = 512 11 | config.sample_rate = 1.0 12 | config.fp16 = True 13 | config.momentum = 0.9 14 | config.weight_decay = 5e-4 15 | config.batch_size = 128 16 | config.lr = 0.1 # batch size is 512 17 | 18 | config.rec = "synthetic" 19 | config.num_classes = 300 * 10000 20 | config.num_epoch = 30 21 | config.warmup_epoch = -1 22 | config.decay_epoch = [10, 16, 22] 23 | config.val_targets = [] 24 | -------------------------------------------------------------------------------- /configs/3millions_pfc.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict as edict 2 | 3 | # configs for test speed 4 | 5 | config = edict() 6 | config.loss = "arcface" 7 | config.network = "r50" 8 | config.resume = False 9 | config.output = None 10 | config.embedding_size = 512 11 | config.sample_rate = 0.1 12 | config.fp16 = True 13 | config.momentum = 0.9 14 | config.weight_decay = 5e-4 15 | config.batch_size = 128 16 | config.lr = 0.1 # batch size is 512 17 | 18 | config.rec = "synthetic" 19 | config.num_classes = 300 * 10000 20 | config.num_epoch = 30 21 | config.warmup_epoch = -1 22 | config.decay_epoch = [10, 16, 22] 23 | config.val_targets = [] 24 | -------------------------------------------------------------------------------- /configs/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NNDam/adaface-partialfc/7630aa23cc792538cd18b758a3a796977e00fd17/configs/__init__.py -------------------------------------------------------------------------------- /configs/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NNDam/adaface-partialfc/7630aa23cc792538cd18b758a3a796977e00fd17/configs/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /configs/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NNDam/adaface-partialfc/7630aa23cc792538cd18b758a3a796977e00fd17/configs/__pycache__/__init__.cpython-38.pyc 
-------------------------------------------------------------------------------- /configs/__pycache__/base.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NNDam/adaface-partialfc/7630aa23cc792538cd18b758a3a796977e00fd17/configs/__pycache__/base.cpython-36.pyc -------------------------------------------------------------------------------- /configs/__pycache__/base.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NNDam/adaface-partialfc/7630aa23cc792538cd18b758a3a796977e00fd17/configs/__pycache__/base.cpython-38.pyc -------------------------------------------------------------------------------- /configs/__pycache__/imint5M3_r160-dev.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NNDam/adaface-partialfc/7630aa23cc792538cd18b758a3a796977e00fd17/configs/__pycache__/imint5M3_r160-dev.cpython-38.pyc -------------------------------------------------------------------------------- /configs/__pycache__/imint5M3_r160.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NNDam/adaface-partialfc/7630aa23cc792538cd18b758a3a796977e00fd17/configs/__pycache__/imint5M3_r160.cpython-36.pyc -------------------------------------------------------------------------------- /configs/__pycache__/imint5M3_r160.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NNDam/adaface-partialfc/7630aa23cc792538cd18b758a3a796977e00fd17/configs/__pycache__/imint5M3_r160.cpython-38.pyc -------------------------------------------------------------------------------- /configs/__pycache__/imint5M3_r18.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NNDam/adaface-partialfc/7630aa23cc792538cd18b758a3a796977e00fd17/configs/__pycache__/imint5M3_r18.cpython-36.pyc -------------------------------------------------------------------------------- /configs/base.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict as edict 2 | 3 | # make training faster 4 | # our RAM is 256G 5 | # mount -t tmpfs -o size=140G tmpfs /train_tmp 6 | 7 | config = edict() 8 | config.loss = "arcface" 9 | config.network = "r50" 10 | config.resume = False 11 | config.output = "ms1mv3_arcface_r50" 12 | 13 | config.dataset = "ms1m-retinaface-t1" 14 | config.embedding_size = 512 15 | config.sample_rate = 1 16 | config.fp16 = False 17 | config.momentum = 0.9 18 | config.weight_decay = 5e-4 19 | config.batch_size = 128 20 | config.lr = 0.1 # batch size is 512 21 | 22 | if config.dataset == "emore": 23 | config.rec = "/train_tmp/faces_emore" 24 | config.num_classes = 85742 25 | config.num_image = 5822653 26 | config.num_epoch = 16 27 | config.warmup_epoch = -1 28 | config.decay_epoch = [8, 14, ] 29 | config.val_targets = ["lfw", ] 30 | 31 | elif config.dataset == "ms1m-retinaface-t1": 32 | config.rec = "/train_tmp/ms1m-retinaface-t1" 33 | config.num_classes = 93431 34 | config.num_image = 5179510 35 | config.num_epoch = 25 36 | config.warmup_epoch = -1 37 | config.decay_epoch = [11, 17, 22] 38 | config.val_targets = ["lfw", "cfp_fp", "agedb_30"] 39 | 40 | elif config.dataset == "glint360k": 41 | config.rec = 
"/train_tmp/glint360k" 42 | config.num_classes = 360232 43 | config.num_image = 17091657 44 | config.num_epoch = 20 45 | config.warmup_epoch = -1 46 | config.decay_epoch = [8, 12, 15, 18] 47 | config.val_targets = ["lfw", "cfp_fp", "agedb_30"] 48 | 49 | elif config.dataset == "webface": 50 | config.rec = "/train_tmp/faces_webface_112x112" 51 | config.num_classes = 10572 52 | config.num_image = "forget" 53 | config.num_epoch = 34 54 | config.warmup_epoch = -1 55 | config.decay_epoch = [20, 28, 32] 56 | config.val_targets = ["lfw", "cfp_fp", "agedb_30"] 57 | -------------------------------------------------------------------------------- /configs/glint360k_mbf.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict as edict 2 | 3 | # make training faster 4 | # our RAM is 256G 5 | # mount -t tmpfs -o size=140G tmpfs /train_tmp 6 | 7 | config = edict() 8 | config.loss = "cosface" 9 | config.network = "mbf" 10 | config.resume = False 11 | config.output = None 12 | config.embedding_size = 512 13 | config.sample_rate = 0.1 14 | config.fp16 = True 15 | config.momentum = 0.9 16 | config.weight_decay = 2e-4 17 | config.batch_size = 128 18 | config.lr = 0.1 # batch size is 512 19 | 20 | config.rec = "/train_tmp/glint360k" 21 | config.num_classes = 360232 22 | config.num_image = 17091657 23 | config.num_epoch = 20 24 | config.warmup_epoch = -1 25 | config.decay_epoch = [8, 12, 15, 18] 26 | config.val_targets = ["lfw", "cfp_fp", "agedb_30"] 27 | -------------------------------------------------------------------------------- /configs/glint360k_r100-adaface.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict as edict 2 | 3 | # make training faster 4 | # our RAM is 256G 5 | # mount -t tmpfs -o size=140G tmpfs /train_tmp 6 | 7 | config = edict() 8 | config.loss = "adaface" 9 | config.network = "r100" 10 | config.resume = False 11 | config.output = None 12 | config.embedding_size = 512 13 | config.sample_rate = 0.1 14 | config.fp16 = True 15 | config.momentum = 0.9 16 | config.weight_decay = 5e-4 17 | config.batch_size = 128 18 | config.lr = 0.1 # batch size is 512 19 | 20 | config.rec = "/train_tmp/glint360k" 21 | config.num_classes = 360232 22 | config.num_image = 17091657 23 | config.num_epoch = 20 24 | config.warmup_epoch = -1 25 | config.decay_epoch = [8, 12, 15, 18] 26 | config.val_targets = ["lfw", "cfp_fp", "agedb_30"] 27 | -------------------------------------------------------------------------------- /configs/glint360k_r100.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict as edict 2 | 3 | # make training faster 4 | # our RAM is 256G 5 | # mount -t tmpfs -o size=140G tmpfs /train_tmp 6 | 7 | config = edict() 8 | config.loss = "cosface" 9 | config.network = "r100" 10 | config.resume = False 11 | config.output = None 12 | config.embedding_size = 512 13 | config.sample_rate = 1.0 14 | config.fp16 = True 15 | config.momentum = 0.9 16 | config.weight_decay = 5e-4 17 | config.batch_size = 128 18 | config.lr = 0.1 # batch size is 512 19 | 20 | config.rec = "/train_tmp/glint360k" 21 | config.num_classes = 360232 22 | config.num_image = 17091657 23 | config.num_epoch = 20 24 | config.warmup_epoch = -1 25 | config.decay_epoch = [8, 12, 15, 18] 26 | config.val_targets = ["lfw", "cfp_fp", "agedb_30"] 27 | -------------------------------------------------------------------------------- /configs/glint360k_r18.py: 
-------------------------------------------------------------------------------- 1 | from easydict import EasyDict as edict 2 | 3 | # make training faster 4 | # our RAM is 256G 5 | # mount -t tmpfs -o size=140G tmpfs /train_tmp 6 | 7 | config = edict() 8 | config.loss = "cosface" 9 | config.network = "r18" 10 | config.resume = False 11 | config.output = None 12 | config.embedding_size = 512 13 | config.sample_rate = 1.0 14 | config.fp16 = True 15 | config.momentum = 0.9 16 | config.weight_decay = 5e-4 17 | config.batch_size = 128 18 | config.lr = 0.1 # batch size is 512 19 | 20 | config.rec = "/train_tmp/glint360k" 21 | config.num_classes = 360232 22 | config.num_image = 17091657 23 | config.num_epoch = 20 24 | config.warmup_epoch = -1 25 | config.decay_epoch = [8, 12, 15, 18] 26 | config.val_targets = ["lfw", "cfp_fp", "agedb_30"] 27 | -------------------------------------------------------------------------------- /configs/glint360k_r34.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict as edict 2 | 3 | # make training faster 4 | # our RAM is 256G 5 | # mount -t tmpfs -o size=140G tmpfs /train_tmp 6 | 7 | config = edict() 8 | config.loss = "cosface" 9 | config.network = "r34" 10 | config.resume = False 11 | config.output = None 12 | config.embedding_size = 512 13 | config.sample_rate = 1.0 14 | config.fp16 = True 15 | config.momentum = 0.9 16 | config.weight_decay = 5e-4 17 | config.batch_size = 128 18 | config.lr = 0.1 # batch size is 512 19 | 20 | config.rec = "/train_tmp/glint360k" 21 | config.num_classes = 360232 22 | config.num_image = 17091657 23 | config.num_epoch = 20 24 | config.warmup_epoch = -1 25 | config.decay_epoch = [8, 12, 15, 18] 26 | config.val_targets = ["lfw", "cfp_fp", "agedb_30"] 27 | -------------------------------------------------------------------------------- /configs/glint360k_r50.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict as edict 2 | 3 | # make training faster 4 | # our RAM is 256G 5 | # mount -t tmpfs -o size=140G tmpfs /train_tmp 6 | 7 | config = edict() 8 | config.loss = "cosface" 9 | config.network = "r50" 10 | config.resume = False 11 | config.output = None 12 | config.embedding_size = 512 13 | config.sample_rate = 1.0 14 | config.fp16 = True 15 | config.momentum = 0.9 16 | config.weight_decay = 5e-4 17 | config.batch_size = 128 18 | config.lr = 0.1 # batch size is 512 19 | 20 | config.rec = "/train_tmp/glint360k" 21 | config.num_classes = 360232 22 | config.num_image = 17091657 23 | config.num_epoch = 20 24 | config.warmup_epoch = -1 25 | config.decay_epoch = [8, 12, 15, 18] 26 | config.val_targets = ["lfw", "cfp_fp", "agedb_30"] 27 | -------------------------------------------------------------------------------- /configs/ms1mv3_mbf.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict as edict 2 | 3 | # make training faster 4 | # our RAM is 256G 5 | # mount -t tmpfs -o size=140G tmpfs /train_tmp 6 | 7 | config = edict() 8 | config.loss = "arcface" 9 | config.network = "mbf" 10 | config.resume = False 11 | config.output = None 12 | config.embedding_size = 512 13 | config.sample_rate = 1.0 14 | config.fp16 = True 15 | config.momentum = 0.9 16 | config.weight_decay = 2e-4 17 | config.batch_size = 128 18 | config.lr = 0.1 # batch size is 512 19 | 20 | config.rec = "/train_tmp/ms1m-retinaface-t1" 21 | config.num_classes = 93431 22 | config.num_image 
= 5179510 23 | config.num_epoch = 30 24 | config.warmup_epoch = -1 25 | config.decay_epoch = [10, 20, 25] 26 | config.val_targets = ["lfw", "cfp_fp", "agedb_30"] 27 | -------------------------------------------------------------------------------- /configs/ms1mv3_r18.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict as edict 2 | 3 | # make training faster 4 | # our RAM is 256G 5 | # mount -t tmpfs -o size=140G tmpfs /train_tmp 6 | 7 | config = edict() 8 | config.loss = "arcface" 9 | config.network = "r18" 10 | config.resume = False 11 | config.output = None 12 | config.embedding_size = 512 13 | config.sample_rate = 1.0 14 | config.fp16 = True 15 | config.momentum = 0.9 16 | config.weight_decay = 5e-4 17 | config.batch_size = 128 18 | config.lr = 0.1 # batch size is 512 19 | 20 | config.rec = "/train_tmp/ms1m-retinaface-t1" 21 | config.num_classes = 93431 22 | config.num_image = 5179510 23 | config.num_epoch = 25 24 | config.warmup_epoch = -1 25 | config.decay_epoch = [10, 16, 22] 26 | config.val_targets = ["lfw", "cfp_fp", "agedb_30"] 27 | -------------------------------------------------------------------------------- /configs/ms1mv3_r2060.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict as edict 2 | 3 | # make training faster 4 | # our RAM is 256G 5 | # mount -t tmpfs -o size=140G tmpfs /train_tmp 6 | 7 | config = edict() 8 | config.loss = "arcface" 9 | config.network = "r2060" 10 | config.resume = False 11 | config.output = None 12 | config.embedding_size = 512 13 | config.sample_rate = 1.0 14 | config.fp16 = True 15 | config.momentum = 0.9 16 | config.weight_decay = 5e-4 17 | config.batch_size = 64 18 | config.lr = 0.1 # batch size is 512 19 | 20 | config.rec = "/train_tmp/ms1m-retinaface-t1" 21 | config.num_classes = 93431 22 | config.num_image = 5179510 23 | config.num_epoch = 25 24 | config.warmup_epoch = -1 25 | config.decay_epoch = [10, 16, 22] 26 | config.val_targets = ["lfw", "cfp_fp", "agedb_30"] 27 | -------------------------------------------------------------------------------- /configs/ms1mv3_r34.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict as edict 2 | 3 | # make training faster 4 | # our RAM is 256G 5 | # mount -t tmpfs -o size=140G tmpfs /train_tmp 6 | 7 | config = edict() 8 | config.loss = "arcface" 9 | config.network = "r34" 10 | config.resume = False 11 | config.output = None 12 | config.embedding_size = 512 13 | config.sample_rate = 1.0 14 | config.fp16 = True 15 | config.momentum = 0.9 16 | config.weight_decay = 5e-4 17 | config.batch_size = 128 18 | config.lr = 0.1 # batch size is 512 19 | 20 | config.rec = "/train_tmp/ms1m-retinaface-t1" 21 | config.num_classes = 93431 22 | config.num_image = 5179510 23 | config.num_epoch = 25 24 | config.warmup_epoch = -1 25 | config.decay_epoch = [10, 16, 22] 26 | config.val_targets = ["lfw", "cfp_fp", "agedb_30"] 27 | -------------------------------------------------------------------------------- /configs/ms1mv3_r50-adaface.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict as edict 2 | 3 | # make training faster 4 | # our RAM is 256G 5 | # mount -t tmpfs -o size=140G tmpfs /train_tmp 6 | 7 | config = edict() 8 | config.loss = "adaface" 9 | config.network = "r50" 10 | config.resume = False 11 | config.output = None 12 | config.embedding_size = 512 13 
| config.sample_rate = 0.1 14 | config.fp16 = True 15 | config.momentum = 0.9 16 | config.weight_decay = 5e-4 17 | config.batch_size = 128 18 | config.lr = 0.1 # batch size is 512 19 | 20 | config.rec = "/train_tmp/ms1m-retinaface-t1" 21 | config.num_classes = 93431 22 | config.num_image = 5179510 23 | config.num_epoch = 25 24 | config.warmup_epoch = -1 25 | config.decay_epoch = [10, 16, 22] 26 | config.val_targets = ["lfw", "cfp_fp", "agedb_30"] 27 | -------------------------------------------------------------------------------- /configs/ms1mv3_r50.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict as edict 2 | 3 | # make training faster 4 | # our RAM is 256G 5 | # mount -t tmpfs -o size=140G tmpfs /train_tmp 6 | 7 | config = edict() 8 | config.loss = "arcface" 9 | config.network = "r50" 10 | config.resume = False 11 | config.output = None 12 | config.embedding_size = 512 13 | config.sample_rate = 1.0 14 | config.fp16 = True 15 | config.momentum = 0.9 16 | config.weight_decay = 5e-4 17 | config.batch_size = 128 18 | config.lr = 0.1 # batch size is 512 19 | 20 | config.rec = "/train_tmp/ms1m-retinaface-t1" 21 | config.num_classes = 93431 22 | config.num_image = 5179510 23 | config.num_epoch = 25 24 | config.warmup_epoch = -1 25 | config.decay_epoch = [10, 16, 22] 26 | config.val_targets = ["lfw", "cfp_fp", "agedb_30"] 27 | -------------------------------------------------------------------------------- /configs/speed.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict as edict 2 | 3 | # configs for test speed 4 | 5 | config = edict() 6 | config.loss = "arcface" 7 | config.network = "r50" 8 | config.resume = False 9 | config.output = None 10 | config.embedding_size = 512 11 | config.sample_rate = 1.0 12 | config.fp16 = True 13 | config.momentum = 0.9 14 | config.weight_decay = 5e-4 15 | config.batch_size = 128 16 | config.lr = 0.1 # batch size is 512 17 | 18 | config.rec = "synthetic" 19 | config.num_classes = 100 * 10000 20 | config.num_epoch = 30 21 | config.warmup_epoch = -1 22 | config.decay_epoch = [10, 16, 22] 23 | config.val_targets = [] 24 | -------------------------------------------------------------------------------- /dataset.py: -------------------------------------------------------------------------------- 1 | import numbers 2 | import os 3 | import queue as Queue 4 | import threading 5 | import sys 6 | 7 | # sys.path.append('/home/damnguyen/FaceRecognition/FaceMask/face_sdk/') 8 | import mxnet as mx 9 | import numpy as np 10 | import torch 11 | from torch.utils.data import DataLoader, Dataset 12 | from torchvision.transforms import functional as F 13 | # from torch.multiprocessing import set_start_method 14 | from transforms import transform_JPEGcompression, transform_gaussian_noise, transform_resize, transform_eraser 15 | from torchvision import transforms 16 | import random 17 | import yaml 18 | # from face_masker import FaceMasker 19 | import cv2 20 | from PIL import Image 21 | 22 | 23 | class BackgroundGenerator(threading.Thread): 24 | def __init__(self, generator, local_rank, max_prefetch=6): 25 | super(BackgroundGenerator, self).__init__() 26 | self.queue = Queue.Queue(max_prefetch) 27 | self.generator = generator 28 | self.local_rank = local_rank 29 | self.daemon = True 30 | self.start() 31 | 32 | def run(self): 33 | torch.cuda.set_device(self.local_rank) 34 | for item in self.generator: 35 | self.queue.put(item) 36 | 
self.queue.put(None) 37 | 38 | def next(self): 39 | next_item = self.queue.get() 40 | if next_item is None: 41 | raise StopIteration 42 | return next_item 43 | 44 | def __next__(self): 45 | return self.next() 46 | 47 | def __iter__(self): 48 | return self 49 | 50 | 51 | class DataLoaderX(DataLoader): 52 | 53 | def __init__(self, local_rank, **kwargs): 54 | super(DataLoaderX, self).__init__(**kwargs) 55 | self.stream = torch.cuda.Stream(local_rank) 56 | self.local_rank = local_rank 57 | 58 | def __iter__(self): 59 | self.iter = super(DataLoaderX, self).__iter__() 60 | self.iter = BackgroundGenerator(self.iter, self.local_rank) 61 | self.preload() 62 | return self 63 | 64 | def preload(self): 65 | self.batch = next(self.iter, None) 66 | if self.batch is None: 67 | return None 68 | with torch.cuda.stream(self.stream): 69 | for k in range(len(self.batch)): 70 | self.batch[k] = self.batch[k].to(device=self.local_rank, non_blocking=True) 71 | 72 | def __next__(self): 73 | torch.cuda.current_stream().wait_stream(self.stream) 74 | batch = self.batch 75 | if batch is None: 76 | raise StopIteration 77 | self.preload() 78 | return batch 79 | 80 | 81 | class MXFaceDataset(Dataset): 82 | def __init__(self, root_dir, local_rank, is_train = True): 83 | super(MXFaceDataset, self).__init__() 84 | 85 | if is_train: 86 | self.transform = transforms.Compose( 87 | [transforms.ToPILImage(), 88 | # transforms.Lambda(lambda x: transform_gaussian_noise(x, mean = 0.0, var = 10.0)), 89 | transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.1), 90 | transforms.RandomHorizontalFlip(), 91 | # transforms.Lambda(lambda x: transform_resize(x, resize_range = (32, 112), target_size = 112)), 92 | # transforms.Lambda(lambda x: transform_JPEGcompression(x, compress_range = (30, 100))), 93 | transforms.ToTensor(), 94 | transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), 95 | ]) 96 | else: 97 | self.transform = transforms.Compose( 98 | [transforms.ToPILImage(), 99 | transforms.RandomHorizontalFlip(), 100 | transforms.ToTensor(), 101 | transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), 102 | ]) 103 | self.root_dir = root_dir 104 | self.local_rank = local_rank 105 | path_imgrec = os.path.join(root_dir, 'train.rec') 106 | path_imgidx = os.path.join(root_dir, 'train.idx') 107 | self.imgrec = mx.recordio.MXIndexedRecordIO(path_imgidx, path_imgrec, 'r') 108 | s = self.imgrec.read_idx(0) 109 | header, _ = mx.recordio.unpack(s) 110 | if header.flag > 0: 111 | self.header0 = (int(header.label[0]), int(header.label[1])) 112 | self.imgidx = np.array(range(1, int(header.label[0]))) 113 | else: 114 | self.imgidx = np.array(list(self.imgrec.keys)) 115 | 116 | def __getitem__(self, index): 117 | idx = self.imgidx[index] 118 | s = self.imgrec.read_idx(idx) 119 | header, img = mx.recordio.unpack(s) 120 | label = header.label 121 | if not isinstance(label, numbers.Number): 122 | label = label[0] 123 | label = torch.tensor(label, dtype=torch.long) 124 | sample = mx.image.imdecode(img).asnumpy() 125 | 126 | # sample = Image.fromarray(sample) 127 | # sample = transform_gaussian_noise(sample, mean = 0.0, var = 10.0) 128 | # sample = transform_resize(sample, resize_range = (32, 112), target_size = 112) 129 | # sample = transform_JPEGcompression(sample, compress_range = (30, 100)) 130 | sample = np.array(sample, dtype = np.uint8) 131 | # cv2.imwrite('tmp_img/img_{}.jpg'.format(index), sample) 132 | if self.transform is not None: 133 | sample = self.transform(sample) 134 | return sample, label 135 | 136 | 
def __len__(self): 137 | return len(self.imgidx) 138 | 139 | def low_res_augmentation(img): 140 | # resize the image to a small size and enlarge it back 141 | img_shape = img.shape 142 | side_ratio = np.random.uniform(0.2, 1.0) 143 | small_side = int(side_ratio * img_shape[0]) 144 | interpolation = np.random.choice( 145 | [cv2.INTER_NEAREST, cv2.INTER_LINEAR, cv2.INTER_AREA, cv2.INTER_CUBIC, cv2.INTER_LANCZOS4]) 146 | small_img = cv2.resize(img, (small_side, small_side), interpolation=interpolation) 147 | interpolation = np.random.choice( 148 | [cv2.INTER_NEAREST, cv2.INTER_LINEAR, cv2.INTER_AREA, cv2.INTER_CUBIC, cv2.INTER_LANCZOS4]) 149 | aug_img = cv2.resize(small_img, (img_shape[1], img_shape[0]), interpolation=interpolation) 150 | 151 | return aug_img, side_ratio 152 | 153 | class AdaFaceDataset(Dataset): 154 | def __init__(self, 155 | root_dir, 156 | local_rank, 157 | is_train = True, 158 | low_res_augmentation_prob = 0.2, 159 | crop_augmentation_prob = 0.2, 160 | photometric_augmentation_prob = 0.2): 161 | 162 | super(AdaFaceDataset, self).__init__() 163 | self.low_res_augmentation_prob = low_res_augmentation_prob 164 | self.crop_augmentation_prob = crop_augmentation_prob 165 | self.photometric_augmentation_prob = photometric_augmentation_prob 166 | self.random_resized_crop = transforms.RandomResizedCrop(size=(112, 112), 167 | scale=(0.2, 1.0), 168 | ratio=(0.75, 1.3333333333333333)) 169 | self.photometric = transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0) 170 | self.is_train = is_train 171 | self.transform = transforms.Compose( 172 | [ 173 | transforms.RandomHorizontalFlip(), 174 | transforms.ToTensor(), 175 | transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), 176 | ]) 177 | self.root_dir = root_dir 178 | self.local_rank = local_rank 179 | path_imgrec = os.path.join(root_dir, 'train.rec') 180 | path_imgidx = os.path.join(root_dir, 'train.idx') 181 | self.imgrec = mx.recordio.MXIndexedRecordIO(path_imgidx, path_imgrec, 'r') 182 | s = self.imgrec.read_idx(0) 183 | header, _ = mx.recordio.unpack(s) 184 | if header.flag > 0: 185 | self.header0 = (int(header.label[0]), int(header.label[1])) 186 | self.imgidx = np.array(range(1, int(header.label[0]))) 187 | else: 188 | self.imgidx = np.array(list(self.imgrec.keys)) 189 | 190 | def augment(self, sample): 191 | # crop with zero padding augmentation 192 | if np.random.random() < self.crop_augmentation_prob: 193 | # RandomResizedCrop augmentation 194 | new = np.zeros_like(np.array(sample)) 195 | if hasattr(F, '_get_image_size'): 196 | orig_W, orig_H = F._get_image_size(sample) 197 | else: 198 | # torchvision 0.11.0 and above 199 | orig_W, orig_H = F.get_image_size(sample) 200 | i, j, h, w = self.random_resized_crop.get_params(sample, 201 | self.random_resized_crop.scale, 202 | self.random_resized_crop.ratio) 203 | cropped = F.crop(sample, i, j, h, w) 204 | new[i:i+h,j:j+w, :] = np.array(cropped) 205 | sample = Image.fromarray(new.astype(np.uint8)) 206 | crop_ratio = min(h, w) / max(orig_H, orig_W) 207 | else: 208 | crop_ratio = 1.0 209 | 210 | # low resolution augmentation 211 | if np.random.random() < self.low_res_augmentation_prob: 212 | # low res augmentation 213 | img_np, resize_ratio = low_res_augmentation(np.array(sample)) 214 | sample = Image.fromarray(img_np.astype(np.uint8)) 215 | else: 216 | resize_ratio = 1 217 | 218 | # photometric augmentation 219 | if np.random.random() < self.photometric_augmentation_prob: 220 | fn_idx, brightness_factor, contrast_factor, saturation_factor, hue_factor = 
\ 221 | self.photometric.get_params(self.photometric.brightness, self.photometric.contrast, 222 | self.photometric.saturation, self.photometric.hue) 223 | for fn_id in fn_idx: 224 | if fn_id == 0 and brightness_factor is not None: 225 | sample = F.adjust_brightness(sample, brightness_factor) 226 | elif fn_id == 1 and contrast_factor is not None: 227 | sample = F.adjust_contrast(sample, contrast_factor) 228 | elif fn_id == 2 and saturation_factor is not None: 229 | sample = F.adjust_saturation(sample, saturation_factor) 230 | 231 | information_score = resize_ratio * crop_ratio 232 | return sample, information_score 233 | 234 | def __getitem__(self, index): 235 | idx = self.imgidx[index] 236 | s = self.imgrec.read_idx(idx) 237 | header, img = mx.recordio.unpack(s) 238 | label = header.label 239 | if not isinstance(label, numbers.Number): 240 | label = label[0] 241 | label = torch.tensor(label, dtype=torch.long) 242 | sample = mx.image.imdecode(img).asnumpy() 243 | sample = Image.fromarray(sample.astype(np.uint8)) 244 | if self.is_train: 245 | sample, _ = self.augment(sample) 246 | # print(sample.shape) 247 | # cv2.imwrite('tmp_img/img_{}.jpg'.format(index), cv2.cvtColor(np.array(sample, dtype = np.uint8), cv2.COLOR_RGB2BGR)) 248 | # sample = transform_gaussian_noise(sample, mean = 0.0, var = 10.0) 249 | # sample = transform_resize(sample, resize_range = (32, 112), target_size = 112) 250 | # sample = transform_JPEGcompression(sample, compress_range = (30, 100)) 251 | 252 | if self.transform is not None: 253 | sample = self.transform(sample) 254 | return sample, label 255 | 256 | def __len__(self): 257 | return len(self.imgidx) 258 | 259 | class SyntheticDataset(Dataset): 260 | def __init__(self, local_rank): 261 | super(SyntheticDataset, self).__init__() 262 | img = np.random.randint(0, 255, size=(112, 112, 3), dtype=np.int32) 263 | img = np.transpose(img, (2, 0, 1)) 264 | img = torch.from_numpy(img).squeeze(0).float() 265 | img = ((img / 255) - 0.5) / 0.5 266 | self.img = img 267 | self.label = 1 268 | 269 | def __getitem__(self, index): 270 | return self.img, self.label 271 | 272 | def __len__(self): 273 | return 1000000 274 | 275 | def dali_data_iter( 276 | batch_size: int, root_dir: str, num_threads: int, 277 | initial_fill=32768, random_shuffle=False, 278 | prefetch_queue_depth=512, local_rank=0, name="reader", 279 | mean=(127.5, 127.5, 127.5), 280 | std=(127.5, 127.5, 127.5)): 281 | """ 282 | Parameters: 283 | ---------- 284 | initial_fill: int 285 | Size of the buffer that is used for shuffling. If random_shuffle is False, this parameter is ignored. 
286 | 287 | """ 288 | import nvidia.dali.fn as fn 289 | import nvidia.dali.types as types 290 | from nvidia.dali.pipeline import Pipeline 291 | from nvidia.dali.plugin.pytorch import DALIClassificationIterator 292 | import torch.distributed as dist 293 | rec_file = os.path.join(root_dir, 'train.rec') 294 | idx_file = os.path.join(root_dir, 'train.idx') 295 | rank: int = dist.get_rank() 296 | world_size: int = dist.get_world_size() 297 | 298 | 299 | pipe = Pipeline( 300 | batch_size=batch_size, num_threads=num_threads, 301 | device_id=local_rank, prefetch_queue_depth=prefetch_queue_depth, ) 302 | condition_flip = fn.random.coin_flip(probability=0.5) 303 | with pipe: 304 | jpegs, labels = fn.readers.mxnet( 305 | path=rec_file, index_path=idx_file, initial_fill=initial_fill, 306 | num_shards=world_size, shard_id=rank, 307 | random_shuffle=random_shuffle, pad_last_batch=False, name=name) 308 | images = fn.decoders.image(jpegs, device="mixed", output_type=types.RGB) 309 | images = fn.crop_mirror_normalize( 310 | images, dtype=types.FLOAT, mean=mean, std=std, mirror=condition_flip) 311 | pipe.set_outputs(images, labels) 312 | pipe.build() 313 | return DALIWarper(DALIClassificationIterator(pipelines=[pipe], reader_name=name, )) 314 | 315 | 316 | @torch.no_grad() 317 | class DALIWarper(object): 318 | def __init__(self, dali_iter): 319 | self.iter = dali_iter 320 | 321 | def __next__(self): 322 | data_dict = self.iter.__next__()[0] 323 | tensor_data = data_dict['data'].cuda() 324 | tensor_label: torch.Tensor = data_dict['label'].cuda().long() 325 | tensor_label.squeeze_() 326 | return tensor_data, tensor_label 327 | 328 | def __iter__(self): 329 | return self 330 | 331 | def __len__(self): 332 | return 376166 333 | 334 | def reset(self): 335 | self.iter.reset() 336 | 337 | if __name__ == '__main__': 338 | train_set = MXFaceDataset(root_dir=cfg.rec, local_rank=0) 339 | train_sampler = torch.utils.data.distributed.DistributedSampler(train_set, shuffle=True) 340 | train_loader = DataLoaderX( 341 | local_rank=local_rank, dataset=train_set, batch_size=cfg.batch_size, 342 | sampler=train_sampler, num_workers=2, pin_memory=True, drop_last=True) -------------------------------------------------------------------------------- /eval/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NNDam/adaface-partialfc/7630aa23cc792538cd18b758a3a796977e00fd17/eval/__init__.py -------------------------------------------------------------------------------- /eval/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NNDam/adaface-partialfc/7630aa23cc792538cd18b758a3a796977e00fd17/eval/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /eval/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NNDam/adaface-partialfc/7630aa23cc792538cd18b758a3a796977e00fd17/eval/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /eval/__pycache__/verification.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NNDam/adaface-partialfc/7630aa23cc792538cd18b758a3a796977e00fd17/eval/__pycache__/verification.cpython-36.pyc 
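The `AdaFaceDataset` defined in `dataset.py` above combines the MXNet `.rec` reader with three stochastic augmentations (zero-padded crop, low-resolution resize, photometric jitter) and returns tensors normalized to [-1, 1]. For quick single-process debugging it can be driven by a plain PyTorch `DataLoader`; the sketch below is illustrative only, and the record directory path is a placeholder:

```python
# Minimal sketch, assuming dataset.py is importable and the directory holds
# the usual train.rec / train.idx pair (the path below is a placeholder).
from torch.utils.data import DataLoader
from dataset import AdaFaceDataset

train_set = AdaFaceDataset(
    root_dir="/data/faces_webface_112x112",   # placeholder
    local_rank=0,
    is_train=True,
    low_res_augmentation_prob=0.2,
    crop_augmentation_prob=0.2,
    photometric_augmentation_prob=0.2,
)
loader = DataLoader(train_set, batch_size=128, shuffle=True,
                    num_workers=2, pin_memory=True, drop_last=True)

images, labels = next(iter(loader))
# images: float32 tensor (128, 3, 112, 112) in [-1, 1]; labels: int64 tensor (128,)
```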
-------------------------------------------------------------------------------- /eval/__pycache__/verification.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NNDam/adaface-partialfc/7630aa23cc792538cd18b758a3a796977e00fd17/eval/__pycache__/verification.cpython-38.pyc -------------------------------------------------------------------------------- /eval/verification.py: -------------------------------------------------------------------------------- 1 | """Helper for evaluation on the Labeled Faces in the Wild dataset 2 | """ 3 | 4 | # MIT License 5 | # 6 | # Copyright (c) 2016 David Sandberg 7 | # 8 | # Permission is hereby granted, free of charge, to any person obtaining a copy 9 | # of this software and associated documentation files (the "Software"), to deal 10 | # in the Software without restriction, including without limitation the rights 11 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 12 | # copies of the Software, and to permit persons to whom the Software is 13 | # furnished to do so, subject to the following conditions: 14 | # 15 | # The above copyright notice and this permission notice shall be included in all 16 | # copies or substantial portions of the Software. 17 | # 18 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 19 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 20 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 21 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 22 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 23 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 24 | # SOFTWARE. 
25 | 26 | 27 | import datetime 28 | import os 29 | import pickle 30 | 31 | import mxnet as mx 32 | import numpy as np 33 | import sklearn 34 | import torch 35 | from mxnet import ndarray as nd 36 | from scipy import interpolate 37 | from sklearn.decomposition import PCA 38 | from sklearn.model_selection import KFold 39 | 40 | 41 | class LFold: 42 | def __init__(self, n_splits=2, shuffle=False): 43 | self.n_splits = n_splits 44 | if self.n_splits > 1: 45 | self.k_fold = KFold(n_splits=n_splits, shuffle=shuffle) 46 | 47 | def split(self, indices): 48 | if self.n_splits > 1: 49 | return self.k_fold.split(indices) 50 | else: 51 | return [(indices, indices)] 52 | 53 | 54 | def calculate_roc(thresholds, 55 | embeddings1, 56 | embeddings2, 57 | actual_issame, 58 | nrof_folds=10, 59 | pca=0): 60 | assert (embeddings1.shape[0] == embeddings2.shape[0]) 61 | assert (embeddings1.shape[1] == embeddings2.shape[1]) 62 | nrof_pairs = min(len(actual_issame), embeddings1.shape[0]) 63 | nrof_thresholds = len(thresholds) 64 | k_fold = LFold(n_splits=nrof_folds, shuffle=False) 65 | 66 | tprs = np.zeros((nrof_folds, nrof_thresholds)) 67 | fprs = np.zeros((nrof_folds, nrof_thresholds)) 68 | accuracy = np.zeros((nrof_folds)) 69 | indices = np.arange(nrof_pairs) 70 | 71 | if pca == 0: 72 | diff = np.subtract(embeddings1, embeddings2) 73 | dist = np.sum(np.square(diff), 1) 74 | 75 | for fold_idx, (train_set, test_set) in enumerate(k_fold.split(indices)): 76 | if pca > 0: 77 | print('doing pca on', fold_idx) 78 | embed1_train = embeddings1[train_set] 79 | embed2_train = embeddings2[train_set] 80 | _embed_train = np.concatenate((embed1_train, embed2_train), axis=0) 81 | pca_model = PCA(n_components=pca) 82 | pca_model.fit(_embed_train) 83 | embed1 = pca_model.transform(embeddings1) 84 | embed2 = pca_model.transform(embeddings2) 85 | embed1 = sklearn.preprocessing.normalize(embed1) 86 | embed2 = sklearn.preprocessing.normalize(embed2) 87 | diff = np.subtract(embed1, embed2) 88 | dist = np.sum(np.square(diff), 1) 89 | 90 | # Find the best threshold for the fold 91 | acc_train = np.zeros((nrof_thresholds)) 92 | for threshold_idx, threshold in enumerate(thresholds): 93 | _, _, acc_train[threshold_idx] = calculate_accuracy( 94 | threshold, dist[train_set], actual_issame[train_set]) 95 | best_threshold_index = np.argmax(acc_train) 96 | for threshold_idx, threshold in enumerate(thresholds): 97 | tprs[fold_idx, threshold_idx], fprs[fold_idx, threshold_idx], _ = calculate_accuracy( 98 | threshold, dist[test_set], 99 | actual_issame[test_set]) 100 | _, _, accuracy[fold_idx] = calculate_accuracy( 101 | thresholds[best_threshold_index], dist[test_set], 102 | actual_issame[test_set]) 103 | 104 | tpr = np.mean(tprs, 0) 105 | fpr = np.mean(fprs, 0) 106 | return tpr, fpr, accuracy 107 | 108 | 109 | def calculate_accuracy(threshold, dist, actual_issame): 110 | predict_issame = np.less(dist, threshold) 111 | tp = np.sum(np.logical_and(predict_issame, actual_issame)) 112 | fp = np.sum(np.logical_and(predict_issame, np.logical_not(actual_issame))) 113 | tn = np.sum( 114 | np.logical_and(np.logical_not(predict_issame), 115 | np.logical_not(actual_issame))) 116 | fn = np.sum(np.logical_and(np.logical_not(predict_issame), actual_issame)) 117 | 118 | tpr = 0 if (tp + fn == 0) else float(tp) / float(tp + fn) 119 | fpr = 0 if (fp + tn == 0) else float(fp) / float(fp + tn) 120 | acc = float(tp + tn) / dist.size 121 | return tpr, fpr, acc 122 | 123 | 124 | def calculate_val(thresholds, 125 | embeddings1, 126 | embeddings2, 127 | actual_issame, 
128 | far_target, 129 | nrof_folds=10): 130 | assert (embeddings1.shape[0] == embeddings2.shape[0]) 131 | assert (embeddings1.shape[1] == embeddings2.shape[1]) 132 | nrof_pairs = min(len(actual_issame), embeddings1.shape[0]) 133 | nrof_thresholds = len(thresholds) 134 | k_fold = LFold(n_splits=nrof_folds, shuffle=False) 135 | 136 | val = np.zeros(nrof_folds) 137 | far = np.zeros(nrof_folds) 138 | 139 | diff = np.subtract(embeddings1, embeddings2) 140 | dist = np.sum(np.square(diff), 1) 141 | indices = np.arange(nrof_pairs) 142 | 143 | for fold_idx, (train_set, test_set) in enumerate(k_fold.split(indices)): 144 | 145 | # Find the threshold that gives FAR = far_target 146 | far_train = np.zeros(nrof_thresholds) 147 | for threshold_idx, threshold in enumerate(thresholds): 148 | _, far_train[threshold_idx] = calculate_val_far( 149 | threshold, dist[train_set], actual_issame[train_set]) 150 | if np.max(far_train) >= far_target: 151 | f = interpolate.interp1d(far_train, thresholds, kind='slinear') 152 | threshold = f(far_target) 153 | else: 154 | threshold = 0.0 155 | 156 | val[fold_idx], far[fold_idx] = calculate_val_far( 157 | threshold, dist[test_set], actual_issame[test_set]) 158 | 159 | val_mean = np.mean(val) 160 | far_mean = np.mean(far) 161 | val_std = np.std(val) 162 | return val_mean, val_std, far_mean 163 | 164 | 165 | def calculate_val_far(threshold, dist, actual_issame): 166 | predict_issame = np.less(dist, threshold) 167 | true_accept = np.sum(np.logical_and(predict_issame, actual_issame)) 168 | false_accept = np.sum( 169 | np.logical_and(predict_issame, np.logical_not(actual_issame))) 170 | n_same = np.sum(actual_issame) 171 | n_diff = np.sum(np.logical_not(actual_issame)) 172 | # print(true_accept, false_accept) 173 | # print(n_same, n_diff) 174 | val = float(true_accept) / float(n_same) 175 | far = float(false_accept) / float(n_diff) 176 | return val, far 177 | 178 | 179 | def evaluate(embeddings, actual_issame, nrof_folds=10, pca=0): 180 | # Calculate evaluation metrics 181 | thresholds = np.arange(0, 4, 0.01) 182 | embeddings1 = embeddings[0::2] 183 | embeddings2 = embeddings[1::2] 184 | tpr, fpr, accuracy = calculate_roc(thresholds, 185 | embeddings1, 186 | embeddings2, 187 | np.asarray(actual_issame), 188 | nrof_folds=nrof_folds, 189 | pca=pca) 190 | thresholds = np.arange(0, 4, 0.001) 191 | val, val_std, far = calculate_val(thresholds, 192 | embeddings1, 193 | embeddings2, 194 | np.asarray(actual_issame), 195 | 1e-3, 196 | nrof_folds=nrof_folds) 197 | return tpr, fpr, accuracy, val, val_std, far 198 | 199 | @torch.no_grad() 200 | def load_bin(path, image_size): 201 | try: 202 | with open(path, 'rb') as f: 203 | bins, issame_list = pickle.load(f) # py2 204 | except UnicodeDecodeError as e: 205 | with open(path, 'rb') as f: 206 | bins, issame_list = pickle.load(f, encoding='bytes') # py3 207 | data_list = [] 208 | for flip in [0, 1]: 209 | data = torch.empty((len(issame_list) * 2, 3, image_size[0], image_size[1])) 210 | data_list.append(data) 211 | for idx in range(len(issame_list) * 2): 212 | _bin = bins[idx] 213 | img = mx.image.imdecode(_bin) 214 | if img.shape[1] != image_size[0]: 215 | img = mx.image.resize_short(img, image_size[0]) 216 | img = nd.transpose(img, axes=(2, 0, 1)) 217 | for flip in [0, 1]: 218 | if flip == 1: 219 | img = mx.ndarray.flip(data=img, axis=2) 220 | data_list[flip][idx][:] = torch.from_numpy(img.asnumpy()) 221 | if idx % 1000 == 0: 222 | print('loading bin', idx) 223 | print(data_list[0].shape) 224 | return data_list, issame_list 225 | 226 | 
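`load_bin` above unpacks a `.bin` verification set (e.g. `lfw.bin`, `cfp_fp.bin`, `agedb_30.bin`) into raw and horizontally flipped image tensors, and the `test` routine that follows feeds both through the backbone and reports flip-fused accuracy plus the mean feature norm. A hedged standalone usage sketch, with placeholder checkpoint and data paths:

```python
# Illustrative only: the checkpoint and .bin paths are placeholders, and the
# module is assumed to be imported as eval.verification from the repo root.
import torch
from backbones import get_model
from eval.verification import load_bin, test

backbone = get_model("r50", fp16=False)
backbone.load_state_dict(torch.load("work_dirs/ms1mv3_r50/model.pt",
                                    map_location="cpu"))
backbone.eval()

lfw = load_bin("/data/faces_webface_112x112/lfw.bin", image_size=(112, 112))
acc1, std1, acc2, std2, xnorm, _ = test(lfw, backbone, batch_size=64, nfolds=10)
print(f"LFW accuracy (flip-fused): {acc2:.4f} +/- {std2:.4f}, "
      f"mean feature norm: {xnorm:.2f}")
```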
@torch.no_grad() 227 | def test(data_set, backbone, batch_size, nfolds=10): 228 | print('testing verification..') 229 | data_list = data_set[0] 230 | issame_list = data_set[1] 231 | embeddings_list = [] 232 | time_consumed = 0.0 233 | for i in range(len(data_list)): 234 | data = data_list[i] 235 | embeddings = None 236 | ba = 0 237 | while ba < data.shape[0]: 238 | bb = min(ba + batch_size, data.shape[0]) 239 | count = bb - ba 240 | _data = data[bb - batch_size: bb] 241 | time0 = datetime.datetime.now() 242 | img = ((_data / 255) - 0.5) / 0.5 243 | net_out: torch.Tensor = backbone(img) 244 | _embeddings = net_out.detach().cpu().numpy() 245 | time_now = datetime.datetime.now() 246 | diff = time_now - time0 247 | time_consumed += diff.total_seconds() 248 | if embeddings is None: 249 | embeddings = np.zeros((data.shape[0], _embeddings.shape[1])) 250 | embeddings[ba:bb, :] = _embeddings[(batch_size - count):, :] 251 | ba = bb 252 | embeddings_list.append(embeddings) 253 | 254 | _xnorm = 0.0 255 | _xnorm_cnt = 0 256 | for embed in embeddings_list: 257 | for i in range(embed.shape[0]): 258 | _em = embed[i] 259 | _norm = np.linalg.norm(_em) 260 | _xnorm += _norm 261 | _xnorm_cnt += 1 262 | _xnorm /= _xnorm_cnt 263 | 264 | acc1 = 0.0 265 | std1 = 0.0 266 | embeddings = embeddings_list[0] + embeddings_list[1] 267 | embeddings = sklearn.preprocessing.normalize(embeddings) 268 | print(embeddings.shape) 269 | print('infer time', time_consumed) 270 | _, _, accuracy, val, val_std, far = evaluate(embeddings, issame_list, nrof_folds=nfolds) 271 | acc2, std2 = np.mean(accuracy), np.std(accuracy) 272 | return acc1, std1, acc2, std2, _xnorm, embeddings_list 273 | 274 | 275 | def dumpR(data_set, 276 | backbone, 277 | batch_size, 278 | name='', 279 | data_extra=None, 280 | label_shape=None): 281 | print('dump verification embedding..') 282 | data_list = data_set[0] 283 | issame_list = data_set[1] 284 | embeddings_list = [] 285 | time_consumed = 0.0 286 | for i in range(len(data_list)): 287 | data = data_list[i] 288 | embeddings = None 289 | ba = 0 290 | while ba < data.shape[0]: 291 | bb = min(ba + batch_size, data.shape[0]) 292 | count = bb - ba 293 | 294 | _data = nd.slice_axis(data, axis=0, begin=bb - batch_size, end=bb) 295 | time0 = datetime.datetime.now() 296 | if data_extra is None: 297 | db = mx.io.DataBatch(data=(_data,), label=(_label,)) 298 | else: 299 | db = mx.io.DataBatch(data=(_data, _data_extra), 300 | label=(_label,)) 301 | model.forward(db, is_train=False) 302 | net_out = model.get_outputs() 303 | _embeddings = net_out[0].asnumpy() 304 | time_now = datetime.datetime.now() 305 | diff = time_now - time0 306 | time_consumed += diff.total_seconds() 307 | if embeddings is None: 308 | embeddings = np.zeros((data.shape[0], _embeddings.shape[1])) 309 | embeddings[ba:bb, :] = _embeddings[(batch_size - count):, :] 310 | ba = bb 311 | embeddings_list.append(embeddings) 312 | embeddings = embeddings_list[0] + embeddings_list[1] 313 | embeddings = sklearn.preprocessing.normalize(embeddings) 314 | actual_issame = np.asarray(issame_list) 315 | outname = os.path.join('temp.bin') 316 | with open(outname, 'wb') as f: 317 | pickle.dump((embeddings, issame_list), 318 | f, 319 | protocol=pickle.HIGHEST_PROTOCOL) 320 | 321 | 322 | # if __name__ == '__main__': 323 | # 324 | # parser = argparse.ArgumentParser(description='do verification') 325 | # # general 326 | # parser.add_argument('--data-dir', default='', help='') 327 | # parser.add_argument('--model', 328 | # default='../model/softmax,50', 329 | # help='path 
to load model.') 330 | # parser.add_argument('--target', 331 | # default='lfw,cfp_ff,cfp_fp,agedb_30', 332 | # help='test targets.') 333 | # parser.add_argument('--gpu', default=0, type=int, help='gpu id') 334 | # parser.add_argument('--batch-size', default=32, type=int, help='') 335 | # parser.add_argument('--max', default='', type=str, help='') 336 | # parser.add_argument('--mode', default=0, type=int, help='') 337 | # parser.add_argument('--nfolds', default=10, type=int, help='') 338 | # args = parser.parse_args() 339 | # image_size = [112, 112] 340 | # print('image_size', image_size) 341 | # ctx = mx.gpu(args.gpu) 342 | # nets = [] 343 | # vec = args.model.split(',') 344 | # prefix = args.model.split(',')[0] 345 | # epochs = [] 346 | # if len(vec) == 1: 347 | # pdir = os.path.dirname(prefix) 348 | # for fname in os.listdir(pdir): 349 | # if not fname.endswith('.params'): 350 | # continue 351 | # _file = os.path.join(pdir, fname) 352 | # if _file.startswith(prefix): 353 | # epoch = int(fname.split('.')[0].split('-')[1]) 354 | # epochs.append(epoch) 355 | # epochs = sorted(epochs, reverse=True) 356 | # if len(args.max) > 0: 357 | # _max = [int(x) for x in args.max.split(',')] 358 | # assert len(_max) == 2 359 | # if len(epochs) > _max[1]: 360 | # epochs = epochs[_max[0]:_max[1]] 361 | # 362 | # else: 363 | # epochs = [int(x) for x in vec[1].split('|')] 364 | # print('model number', len(epochs)) 365 | # time0 = datetime.datetime.now() 366 | # for epoch in epochs: 367 | # print('loading', prefix, epoch) 368 | # sym, arg_params, aux_params = mx.model.load_checkpoint(prefix, epoch) 369 | # # arg_params, aux_params = ch_dev(arg_params, aux_params, ctx) 370 | # all_layers = sym.get_internals() 371 | # sym = all_layers['fc1_output'] 372 | # model = mx.mod.Module(symbol=sym, context=ctx, label_names=None) 373 | # # model.bind(data_shapes=[('data', (args.batch_size, 3, image_size[0], image_size[1]))], label_shapes=[('softmax_label', (args.batch_size,))]) 374 | # model.bind(data_shapes=[('data', (args.batch_size, 3, image_size[0], 375 | # image_size[1]))]) 376 | # model.set_params(arg_params, aux_params) 377 | # nets.append(model) 378 | # time_now = datetime.datetime.now() 379 | # diff = time_now - time0 380 | # print('model loading time', diff.total_seconds()) 381 | # 382 | # ver_list = [] 383 | # ver_name_list = [] 384 | # for name in args.target.split(','): 385 | # path = os.path.join(args.data_dir, name + ".bin") 386 | # if os.path.exists(path): 387 | # print('loading.. 
', name) 388 | # data_set = load_bin(path, image_size) 389 | # ver_list.append(data_set) 390 | # ver_name_list.append(name) 391 | # 392 | # if args.mode == 0: 393 | # for i in range(len(ver_list)): 394 | # results = [] 395 | # for model in nets: 396 | # acc1, std1, acc2, std2, xnorm, embeddings_list = test( 397 | # ver_list[i], model, args.batch_size, args.nfolds) 398 | # print('[%s]XNorm: %f' % (ver_name_list[i], xnorm)) 399 | # print('[%s]Accuracy: %1.5f+-%1.5f' % (ver_name_list[i], acc1, std1)) 400 | # print('[%s]Accuracy-Flip: %1.5f+-%1.5f' % (ver_name_list[i], acc2, std2)) 401 | # results.append(acc2) 402 | # print('Max of [%s] is %1.5f' % (ver_name_list[i], np.max(results))) 403 | # elif args.mode == 1: 404 | # raise ValueError 405 | # else: 406 | # model = nets[0] 407 | # dumpR(ver_list[0], model, args.batch_size, args.target) 408 | -------------------------------------------------------------------------------- /eval_ijbc.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | 3 | import os 4 | import pickle 5 | 6 | import matplotlib 7 | import pandas as pd 8 | 9 | matplotlib.use('Agg') 10 | import matplotlib.pyplot as plt 11 | import timeit 12 | import sklearn 13 | import argparse 14 | import cv2 15 | import numpy as np 16 | import torch 17 | from skimage import transform as trans 18 | from backbones import get_model 19 | from sklearn.metrics import roc_curve, auc 20 | 21 | from menpo.visualize.viewmatplotlib import sample_colours_from_colourmap 22 | from prettytable import PrettyTable 23 | from pathlib import Path 24 | 25 | import sys 26 | import warnings 27 | 28 | sys.path.insert(0, "../") 29 | warnings.filterwarnings("ignore") 30 | 31 | parser = argparse.ArgumentParser(description='do ijb test') 32 | # general 33 | parser.add_argument('--model-prefix', default='', help='path to load model.') 34 | parser.add_argument('--image-path', default='', type=str, help='') 35 | parser.add_argument('--result-dir', default='.', type=str, help='') 36 | parser.add_argument('--batch-size', default=128, type=int, help='') 37 | parser.add_argument('--network', default='iresnet50', type=str, help='') 38 | parser.add_argument('--job', default='insightface', type=str, help='job name') 39 | parser.add_argument('--target', default='IJBC', type=str, help='target, set to IJBC or IJBB') 40 | args = parser.parse_args() 41 | 42 | target = args.target 43 | model_path = args.model_prefix 44 | image_path = args.image_path 45 | result_dir = args.result_dir 46 | gpu_id = None 47 | use_norm_score = True # if Ture, TestMode(N1) 48 | use_detector_score = True # if Ture, TestMode(D1) 49 | use_flip_test = True # if Ture, TestMode(F1) 50 | job = args.job 51 | batch_size = args.batch_size 52 | 53 | 54 | class Embedding(object): 55 | def __init__(self, prefix, data_shape, batch_size=1): 56 | image_size = (112, 112) 57 | self.image_size = image_size 58 | weight = torch.load(prefix) 59 | resnet = get_model(args.network, dropout=0, fp16=False).cuda() 60 | resnet.load_state_dict(weight) 61 | model = torch.nn.DataParallel(resnet) 62 | self.model = model 63 | self.model.eval() 64 | src = np.array([ 65 | [30.2946, 51.6963], 66 | [65.5318, 51.5014], 67 | [48.0252, 71.7366], 68 | [33.5493, 92.3655], 69 | [62.7299, 92.2041]], dtype=np.float32) 70 | src[:, 0] += 8.0 71 | self.src = src 72 | self.batch_size = batch_size 73 | self.data_shape = data_shape 74 | 75 | def get(self, rimg, landmark): 76 | 77 | assert landmark.shape[0] == 68 or landmark.shape[0] == 5 78 | assert 
landmark.shape[1] == 2 79 | if landmark.shape[0] == 68: 80 | landmark5 = np.zeros((5, 2), dtype=np.float32) 81 | landmark5[0] = (landmark[36] + landmark[39]) / 2 82 | landmark5[1] = (landmark[42] + landmark[45]) / 2 83 | landmark5[2] = landmark[30] 84 | landmark5[3] = landmark[48] 85 | landmark5[4] = landmark[54] 86 | else: 87 | landmark5 = landmark 88 | tform = trans.SimilarityTransform() 89 | tform.estimate(landmark5, self.src) 90 | M = tform.params[0:2, :] 91 | img = cv2.warpAffine(rimg, 92 | M, (self.image_size[1], self.image_size[0]), 93 | borderValue=0.0) 94 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 95 | img_flip = np.fliplr(img) 96 | img = np.transpose(img, (2, 0, 1)) # 3*112*112, RGB 97 | img_flip = np.transpose(img_flip, (2, 0, 1)) 98 | input_blob = np.zeros((2, 3, self.image_size[1], self.image_size[0]), dtype=np.uint8) 99 | input_blob[0] = img 100 | input_blob[1] = img_flip 101 | return input_blob 102 | 103 | @torch.no_grad() 104 | def forward_db(self, batch_data): 105 | imgs = torch.Tensor(batch_data).cuda() 106 | imgs.div_(255).sub_(0.5).div_(0.5) 107 | feat = self.model(imgs) 108 | feat = feat.reshape([self.batch_size, 2 * feat.shape[1]]) 109 | return feat.cpu().numpy() 110 | 111 | 112 | # 将一个list尽量均分成n份,限制len(list)==n,份数大于原list内元素个数则分配空list[] 113 | def divideIntoNstrand(listTemp, n): 114 | twoList = [[] for i in range(n)] 115 | for i, e in enumerate(listTemp): 116 | twoList[i % n].append(e) 117 | return twoList 118 | 119 | 120 | def read_template_media_list(path): 121 | # ijb_meta = np.loadtxt(path, dtype=str) 122 | ijb_meta = pd.read_csv(path, sep=' ', header=None).values 123 | templates = ijb_meta[:, 1].astype(np.int) 124 | medias = ijb_meta[:, 2].astype(np.int) 125 | return templates, medias 126 | 127 | 128 | # In[ ]: 129 | 130 | 131 | def read_template_pair_list(path): 132 | # pairs = np.loadtxt(path, dtype=str) 133 | pairs = pd.read_csv(path, sep=' ', header=None).values 134 | # print(pairs.shape) 135 | # print(pairs[:, 0].astype(np.int)) 136 | t1 = pairs[:, 0].astype(np.int) 137 | t2 = pairs[:, 1].astype(np.int) 138 | label = pairs[:, 2].astype(np.int) 139 | return t1, t2, label 140 | 141 | 142 | # In[ ]: 143 | 144 | 145 | def read_image_feature(path): 146 | with open(path, 'rb') as fid: 147 | img_feats = pickle.load(fid) 148 | return img_feats 149 | 150 | 151 | # In[ ]: 152 | 153 | 154 | def get_image_feature(img_path, files_list, model_path, epoch, gpu_id): 155 | batch_size = args.batch_size 156 | data_shape = (3, 112, 112) 157 | 158 | files = files_list 159 | print('files:', len(files)) 160 | rare_size = len(files) % batch_size 161 | faceness_scores = [] 162 | batch = 0 163 | img_feats = np.empty((len(files), 1024), dtype=np.float32) 164 | 165 | batch_data = np.empty((2 * batch_size, 3, 112, 112)) 166 | embedding = Embedding(model_path, data_shape, batch_size) 167 | for img_index, each_line in enumerate(files[:len(files) - rare_size]): 168 | name_lmk_score = each_line.strip().split(' ') 169 | img_name = os.path.join(img_path, name_lmk_score[0]) 170 | img = cv2.imread(img_name) 171 | lmk = np.array([float(x) for x in name_lmk_score[1:-1]], 172 | dtype=np.float32) 173 | lmk = lmk.reshape((5, 2)) 174 | input_blob = embedding.get(img, lmk) 175 | 176 | batch_data[2 * (img_index - batch * batch_size)][:] = input_blob[0] 177 | batch_data[2 * (img_index - batch * batch_size) + 1][:] = input_blob[1] 178 | if (img_index + 1) % batch_size == 0: 179 | print('batch', batch) 180 | img_feats[batch * batch_size:batch * batch_size + 181 | batch_size][:] = 
embedding.forward_db(batch_data) 182 | batch += 1 183 | faceness_scores.append(name_lmk_score[-1]) 184 | 185 | batch_data = np.empty((2 * rare_size, 3, 112, 112)) 186 | embedding = Embedding(model_path, data_shape, rare_size) 187 | for img_index, each_line in enumerate(files[len(files) - rare_size:]): 188 | name_lmk_score = each_line.strip().split(' ') 189 | img_name = os.path.join(img_path, name_lmk_score[0]) 190 | img = cv2.imread(img_name) 191 | lmk = np.array([float(x) for x in name_lmk_score[1:-1]], 192 | dtype=np.float32) 193 | lmk = lmk.reshape((5, 2)) 194 | input_blob = embedding.get(img, lmk) 195 | batch_data[2 * img_index][:] = input_blob[0] 196 | batch_data[2 * img_index + 1][:] = input_blob[1] 197 | if (img_index + 1) % rare_size == 0: 198 | print('batch', batch) 199 | img_feats[len(files) - 200 | rare_size:][:] = embedding.forward_db(batch_data) 201 | batch += 1 202 | faceness_scores.append(name_lmk_score[-1]) 203 | faceness_scores = np.array(faceness_scores).astype(np.float32) 204 | # img_feats = np.ones( (len(files), 1024), dtype=np.float32) * 0.01 205 | # faceness_scores = np.ones( (len(files), ), dtype=np.float32 ) 206 | return img_feats, faceness_scores 207 | 208 | 209 | # In[ ]: 210 | 211 | 212 | def image2template_feature(img_feats=None, templates=None, medias=None): 213 | # ========================================================== 214 | # 1. face image feature l2 normalization. img_feats:[number_image x feats_dim] 215 | # 2. compute media feature. 216 | # 3. compute template feature. 217 | # ========================================================== 218 | unique_templates = np.unique(templates) 219 | template_feats = np.zeros((len(unique_templates), img_feats.shape[1])) 220 | 221 | for count_template, uqt in enumerate(unique_templates): 222 | 223 | (ind_t,) = np.where(templates == uqt) 224 | face_norm_feats = img_feats[ind_t] 225 | face_medias = medias[ind_t] 226 | unique_medias, unique_media_counts = np.unique(face_medias, 227 | return_counts=True) 228 | media_norm_feats = [] 229 | for u, ct in zip(unique_medias, unique_media_counts): 230 | (ind_m,) = np.where(face_medias == u) 231 | if ct == 1: 232 | media_norm_feats += [face_norm_feats[ind_m]] 233 | else: # image features from the same video will be aggregated into one feature 234 | media_norm_feats += [ 235 | np.mean(face_norm_feats[ind_m], axis=0, keepdims=True) 236 | ] 237 | media_norm_feats = np.array(media_norm_feats) 238 | # media_norm_feats = media_norm_feats / np.sqrt(np.sum(media_norm_feats ** 2, -1, keepdims=True)) 239 | template_feats[count_template] = np.sum(media_norm_feats, axis=0) 240 | if count_template % 2000 == 0: 241 | print('Finish Calculating {} template features.'.format( 242 | count_template)) 243 | # template_norm_feats = template_feats / np.sqrt(np.sum(template_feats ** 2, -1, keepdims=True)) 244 | template_norm_feats = sklearn.preprocessing.normalize(template_feats) 245 | # print(template_norm_feats.shape) 246 | return template_norm_feats, unique_templates 247 | 248 | 249 | # In[ ]: 250 | 251 | 252 | def verification(template_norm_feats=None, 253 | unique_templates=None, 254 | p1=None, 255 | p2=None): 256 | # ========================================================== 257 | # Compute set-to-set Similarity Score. 
258 | # ========================================================== 259 | template2id = np.zeros((max(unique_templates) + 1, 1), dtype=int) 260 | for count_template, uqt in enumerate(unique_templates): 261 | template2id[uqt] = count_template 262 | 263 | score = np.zeros((len(p1),)) # save cosine distance between pairs 264 | 265 | total_pairs = np.array(range(len(p1))) 266 | batchsize = 100000 # small batchsize instead of all pairs in one batch due to the memory limiation 267 | sublists = [ 268 | total_pairs[i:i + batchsize] for i in range(0, len(p1), batchsize) 269 | ] 270 | total_sublists = len(sublists) 271 | for c, s in enumerate(sublists): 272 | feat1 = template_norm_feats[template2id[p1[s]]] 273 | feat2 = template_norm_feats[template2id[p2[s]]] 274 | similarity_score = np.sum(feat1 * feat2, -1) 275 | score[s] = similarity_score.flatten() 276 | if c % 10 == 0: 277 | print('Finish {}/{} pairs.'.format(c, total_sublists)) 278 | return score 279 | 280 | 281 | # In[ ]: 282 | def verification2(template_norm_feats=None, 283 | unique_templates=None, 284 | p1=None, 285 | p2=None): 286 | template2id = np.zeros((max(unique_templates) + 1, 1), dtype=int) 287 | for count_template, uqt in enumerate(unique_templates): 288 | template2id[uqt] = count_template 289 | score = np.zeros((len(p1),)) # save cosine distance between pairs 290 | total_pairs = np.array(range(len(p1))) 291 | batchsize = 100000 # small batchsize instead of all pairs in one batch due to the memory limiation 292 | sublists = [ 293 | total_pairs[i:i + batchsize] for i in range(0, len(p1), batchsize) 294 | ] 295 | total_sublists = len(sublists) 296 | for c, s in enumerate(sublists): 297 | feat1 = template_norm_feats[template2id[p1[s]]] 298 | feat2 = template_norm_feats[template2id[p2[s]]] 299 | similarity_score = np.sum(feat1 * feat2, -1) 300 | score[s] = similarity_score.flatten() 301 | if c % 10 == 0: 302 | print('Finish {}/{} pairs.'.format(c, total_sublists)) 303 | return score 304 | 305 | 306 | def read_score(path): 307 | with open(path, 'rb') as fid: 308 | img_feats = pickle.load(fid) 309 | return img_feats 310 | 311 | 312 | # # Step1: Load Meta Data 313 | 314 | # In[ ]: 315 | 316 | assert target == 'IJBC' or target == 'IJBB' 317 | 318 | # ============================================================= 319 | # load image and template relationships for template feature embedding 320 | # tid --> template id, mid --> media id 321 | # format: 322 | # image_name tid mid 323 | # ============================================================= 324 | start = timeit.default_timer() 325 | templates, medias = read_template_media_list( 326 | os.path.join('%s/meta' % image_path, 327 | '%s_face_tid_mid.txt' % target.lower())) 328 | stop = timeit.default_timer() 329 | print('Time: %.2f s. ' % (stop - start)) 330 | 331 | # In[ ]: 332 | 333 | # ============================================================= 334 | # load template pairs for template-to-template verification 335 | # tid : template id, label : 1/0 336 | # format: 337 | # tid_1 tid_2 label 338 | # ============================================================= 339 | start = timeit.default_timer() 340 | p1, p2, label = read_template_pair_list( 341 | os.path.join('%s/meta' % image_path, 342 | '%s_template_pair_label.txt' % target.lower())) 343 | stop = timeit.default_timer() 344 | print('Time: %.2f s. 
' % (stop - start)) 345 | 346 | # # Step 2: Get Image Features 347 | 348 | # In[ ]: 349 | 350 | # ============================================================= 351 | # load image features 352 | # format: 353 | # img_feats: [image_num x feats_dim] (227630, 512) 354 | # ============================================================= 355 | start = timeit.default_timer() 356 | img_path = '%s/loose_crop' % image_path 357 | img_list_path = '%s/meta/%s_name_5pts_score.txt' % (image_path, target.lower()) 358 | img_list = open(img_list_path) 359 | files = img_list.readlines() 360 | # files_list = divideIntoNstrand(files, rank_size) 361 | files_list = files 362 | 363 | # img_feats 364 | # for i in range(rank_size): 365 | img_feats, faceness_scores = get_image_feature(img_path, files_list, 366 | model_path, 0, gpu_id) 367 | stop = timeit.default_timer() 368 | print('Time: %.2f s. ' % (stop - start)) 369 | print('Feature Shape: ({} , {}) .'.format(img_feats.shape[0], 370 | img_feats.shape[1])) 371 | 372 | # # Step3: Get Template Features 373 | 374 | # In[ ]: 375 | 376 | # ============================================================= 377 | # compute template features from image features. 378 | # ============================================================= 379 | start = timeit.default_timer() 380 | # ========================================================== 381 | # Norm feature before aggregation into template feature? 382 | # Feature norm from embedding network and faceness score are able to decrease weights for noise samples (not face). 383 | # ========================================================== 384 | # 1. FaceScore (Feature Norm) 385 | # 2. FaceScore (Detector) 386 | 387 | if use_flip_test: 388 | # concat --- F1 389 | # img_input_feats = img_feats 390 | # add --- F2 391 | img_input_feats = img_feats[:, 0:img_feats.shape[1] // 392 | 2] + img_feats[:, img_feats.shape[1] // 2:] 393 | else: 394 | img_input_feats = img_feats[:, 0:img_feats.shape[1] // 2] 395 | 396 | if use_norm_score: 397 | img_input_feats = img_input_feats 398 | else: 399 | # normalise features to remove norm information 400 | img_input_feats = img_input_feats / np.sqrt( 401 | np.sum(img_input_feats ** 2, -1, keepdims=True)) 402 | 403 | if use_detector_score: 404 | print(img_input_feats.shape, faceness_scores.shape) 405 | img_input_feats = img_input_feats * faceness_scores[:, np.newaxis] 406 | else: 407 | img_input_feats = img_input_feats 408 | 409 | template_norm_feats, unique_templates = image2template_feature( 410 | img_input_feats, templates, medias) 411 | stop = timeit.default_timer() 412 | print('Time: %.2f s. ' % (stop - start)) 413 | 414 | # # Step 4: Get Template Similarity Scores 415 | 416 | # In[ ]: 417 | 418 | # ============================================================= 419 | # compute verification scores between template pairs. 420 | # ============================================================= 421 | start = timeit.default_timer() 422 | score = verification(template_norm_feats, unique_templates, p1, p2) 423 | stop = timeit.default_timer() 424 | print('Time: %.2f s. 
' % (stop - start)) 425 | 426 | # In[ ]: 427 | save_path = os.path.join(result_dir, args.job) 428 | # save_path = result_dir + '/%s_result' % target 429 | 430 | if not os.path.exists(save_path): 431 | os.makedirs(save_path) 432 | 433 | score_save_file = os.path.join(save_path, "%s.npy" % target.lower()) 434 | np.save(score_save_file, score) 435 | 436 | # # Step 5: Get ROC Curves and TPR@FPR Table 437 | 438 | # In[ ]: 439 | 440 | files = [score_save_file] 441 | methods = [] 442 | scores = [] 443 | for file in files: 444 | methods.append(Path(file).stem) 445 | scores.append(np.load(file)) 446 | 447 | methods = np.array(methods) 448 | scores = dict(zip(methods, scores)) 449 | colours = dict( 450 | zip(methods, sample_colours_from_colourmap(methods.shape[0], 'Set2'))) 451 | x_labels = [10 ** -6, 10 ** -5, 10 ** -4, 10 ** -3, 10 ** -2, 10 ** -1] 452 | tpr_fpr_table = PrettyTable(['Methods'] + [str(x) for x in x_labels]) 453 | fig = plt.figure() 454 | for method in methods: 455 | fpr, tpr, _ = roc_curve(label, scores[method]) 456 | roc_auc = auc(fpr, tpr) 457 | fpr = np.flipud(fpr) 458 | tpr = np.flipud(tpr) # select largest tpr at same fpr 459 | plt.plot(fpr, 460 | tpr, 461 | color=colours[method], 462 | lw=1, 463 | label=('[%s (AUC = %0.4f %%)]' % 464 | (method.split('-')[-1], roc_auc * 100))) 465 | tpr_fpr_row = [] 466 | tpr_fpr_row.append("%s-%s" % (method, target)) 467 | for fpr_iter in np.arange(len(x_labels)): 468 | _, min_index = min( 469 | list(zip(abs(fpr - x_labels[fpr_iter]), range(len(fpr))))) 470 | tpr_fpr_row.append('%.2f' % (tpr[min_index] * 100)) 471 | tpr_fpr_table.add_row(tpr_fpr_row) 472 | plt.xlim([10 ** -6, 0.1]) 473 | plt.ylim([0.3, 1.0]) 474 | plt.grid(linestyle='--', linewidth=1) 475 | plt.xticks(x_labels) 476 | plt.yticks(np.linspace(0.3, 1.0, 8, endpoint=True)) 477 | plt.xscale('log') 478 | plt.xlabel('False Positive Rate') 479 | plt.ylabel('True Positive Rate') 480 | plt.title('ROC on IJB') 481 | plt.legend(loc="lower right") 482 | fig.savefig(os.path.join(save_path, '%s.pdf' % target.lower())) 483 | print(tpr_fpr_table) 484 | -------------------------------------------------------------------------------- /face_masker.py: -------------------------------------------------------------------------------- 1 | """ 2 | @author: Yinglu Liu, Jun Wang 3 | @date: 20201012 4 | @contact: jun21wangustc@gmail.com 5 | """ 6 | 7 | import os 8 | from random import randint 9 | import warnings 10 | warnings.filterwarnings('ignore') 11 | import cv2 12 | import torch 13 | import numpy as np 14 | from skimage.io import imread, imsave 15 | from skimage.transform import estimate_transform, warp 16 | from utils import read_info 17 | from model.prnet import PRNet 18 | from utils.cython.render import render_cy 19 | import time 20 | 21 | class PRN: 22 | """Process of PRNet. 
23 | based on: 24 | https://github.com/YadiraF/PRNet/blob/master/api.py 25 | """ 26 | def __init__(self, model_path, device): 27 | self.resolution = 256 28 | self.MaxPos = self.resolution*1.1 29 | self.face_ind = np.loadtxt('Data/uv-data/face_ind.txt').astype(np.int32) 30 | self.triangles = np.loadtxt('Data/uv-data/triangles.txt').astype(np.int32) 31 | self.net = PRNet(3, 3) 32 | self.device = device 33 | state_dict = torch.load(model_path, map_location=self.device) 34 | self.net.load_state_dict(state_dict) 35 | self.net.eval() 36 | # if torch.cuda.is_available(): 37 | self.net.to(self.device) 38 | 39 | def process(self, image, image_info): 40 | if np.max(image_info.shape) > 4: # key points to get bounding box 41 | kpt = image_info 42 | if kpt.shape[0] > 3: 43 | kpt = kpt.T 44 | left = np.min(kpt[0, :]); right = np.max(kpt[0, :]); 45 | top = np.min(kpt[1,:]); bottom = np.max(kpt[1,:]) 46 | else: # bounding box 47 | bbox = image_info 48 | left = bbox[0]; right = bbox[1]; top = bbox[2]; bottom = bbox[3] 49 | old_size = (right - left + bottom - top)/2 50 | center = np.array([right - (right - left) / 2.0, bottom - (bottom - top) / 2.0]) 51 | size = int(old_size*1.6) 52 | # crop image 53 | src_pts = np.array([[center[0]-size/2, center[1]-size/2], 54 | [center[0] - size/2, center[1]+size/2], 55 | [center[0]+size/2, center[1]-size/2]]) 56 | DST_PTS = np.array([[0,0], [0,self.resolution - 1], [self.resolution - 1, 0]]) 57 | tform = estimate_transform('similarity', src_pts, DST_PTS) 58 | cropped_image = warp(image, tform.inverse, output_shape=(self.resolution, self.resolution)) 59 | cropped_image = np.transpose(cropped_image[np.newaxis, :,:,:], (0, 3, 1, 2)).astype(np.float32) 60 | cropped_image = torch.from_numpy(cropped_image) 61 | # if torch.cuda.is_available(): 62 | # cropped_image = cropped_image.cuda() 63 | with torch.no_grad(): 64 | cropped_image = cropped_image.to(self.device) 65 | cropped_pos = self.net(cropped_image) 66 | cropped_pos = cropped_pos.cpu().detach().numpy() 67 | cropped_pos = np.transpose(cropped_pos, (0, 2, 3, 1)).squeeze() * self.MaxPos 68 | # restore 69 | cropped_vertices = np.reshape(cropped_pos, [-1, 3]).T 70 | z = cropped_vertices[2,:].copy()/tform.params[0,0] 71 | cropped_vertices[2,:] = 1 72 | vertices = np.dot(np.linalg.inv(tform.params), cropped_vertices) 73 | vertices = np.vstack((vertices[:2,:], z)) 74 | pos = np.reshape(vertices.T, [self.resolution, self.resolution, 3]) 75 | return pos 76 | def get_vertices(self, pos): 77 | all_vertices = np.reshape(pos, [self.resolution ** 2, -1]) 78 | vertices = all_vertices[self.face_ind, :] 79 | return vertices 80 | def get_colors_from_texture(self, texture): 81 | all_colors = np.reshape(texture, [self.resolution**2, -1]) 82 | colors = all_colors[self.face_ind, :] 83 | return colors 84 | 85 | class FaceMasker: 86 | """Add a virtual mask in face. 87 | 88 | Attributes: 89 | uv_face_path(str): the path of uv_face. 90 | mask_template_folder(str): the directory where all mask template in. 91 | prn(object): PRN object, https://github.com/YadiraF/PRNet. 92 | template_name2ref_texture_src(dict): key is template name, value is the mask load by skimage.io. 93 | template_name2uv_mask_src(dict): key is template name, value is the uv_mask. 94 | is_aug(bool): whether or not to add some augmentaion operation on the mask. 95 | """ 96 | def __init__(self, is_aug, device): 97 | """init for FaceMasker 98 | 99 | Args: 100 | is_aug(bool): whether or not to add some augmentaion operation on the mask. 
101 | """ 102 | self.device = device 103 | self.uv_face_path = 'Data/uv-data/uv_face_mask.png' 104 | self.mask_template_folder = 'Data/mask-data' 105 | self.prn = PRN('model/prnet.pth', device = self.device) 106 | self.template_name2ref_texture_src, self.template_name2uv_mask_src = self.get_ref_texture_src() 107 | self.is_aug = is_aug 108 | 109 | 110 | def get_ref_texture_src(self): 111 | template_name2ref_texture_src = {} 112 | template_name2uv_mask_src = {} 113 | mask_template_list = os.listdir(self.mask_template_folder) 114 | uv_face = imread(self.uv_face_path, as_gray=True)/255. 115 | for mask_template in mask_template_list: 116 | # print('Create UV map for template: ', mask_template) 117 | mask_template_path = os.path.join(self.mask_template_folder, mask_template) 118 | ref_texture_src = imread(mask_template_path, as_gray=False)/255. 119 | if ref_texture_src.shape[2] == 4: # must 4 channel, how about 3 channel? 120 | uv_mask_src = ref_texture_src[:,:,3] 121 | ref_texture_src = ref_texture_src[:,:,:3] 122 | else: 123 | print('Fatal error!', mask_template_path) 124 | uv_mask_src[uv_face == 0] = 0 125 | template_name2ref_texture_src[mask_template] = ref_texture_src 126 | template_name2uv_mask_src[mask_template] = uv_mask_src 127 | return template_name2ref_texture_src, template_name2uv_mask_src 128 | 129 | def add_mask(self, face_root, image_name2lms, image_name2template_name, masked_face_root): 130 | for image_name, face_lms in image_name2lms.items(): 131 | image_path = os.path.join(face_root, image_name) 132 | masked_face_path = os.path.join(masked_face_root, image_name) 133 | template_name = image_name2template_name[image_name] 134 | self.add_mask_one(image_path, face_lms, template_name, masked_face_path) 135 | 136 | # you can speed it up by a c++ version. 137 | def render(self, vertices, new_colors, h, w): 138 | vis_colors = np.ones((vertices.shape[0], 1)) 139 | face_mask = render_texture(vertices.T, vis_colors.T, self.prn.triangles.T, h, w, c=1).astype(np.uint8) 140 | face_mask = np.squeeze(face_mask > 0) 141 | new_image = render_texture(vertices.T, new_colors.T, self.prn.triangles.T, h, w, c=3) 142 | return face_mask, new_image 143 | 144 | def add_mask_one(self, image, face_lms, template_name, masked_face_path, padded = None, write_image = True, pos_vertices = None): 145 | """Add mask to one image. 146 | 147 | Args: 148 | image_path(str): the image to add mask. 149 | face_lms(str): face landmarks, [x1, y1, x2, y2, ..., x106, y106] 150 | template_name(str): the mask template to be added on the current image, 151 | got to '/Data/mask-data' for all template. 152 | masked_face_path(str): the path to save masked image. 153 | """ 154 | # image = imread(image_path) 155 | # t1 = time.time() 156 | ref_texture_src = self.template_name2ref_texture_src[template_name] 157 | uv_mask_src = self.template_name2uv_mask_src[template_name] 158 | if image.ndim == 2: 159 | image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB) 160 | [h, w, c] = image.shape 161 | if c == 4: 162 | image = image[:,:,:3] 163 | if pos_vertices is None: 164 | pos, vertices = self.get_vertices(face_lms, image) #3d reconstruction -> get texture. 165 | else: 166 | print('Found exists vertices, use this params') 167 | pos, vertices = pos_vertices 168 | image = image/255. #!! 
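# Note: the PRNet position map `pos` stores, for each UV pixel, the (x, y)
# image coordinate of the corresponding face point, so the cv2.remap call below
# samples the input photo into a 256x256 UV texture that get_new_texture later
# blends with the selected mask template.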
169 | texture = cv2.remap(image, pos[:,:,:2].astype(np.float32), None, 170 | interpolation=cv2.INTER_NEAREST, 171 | borderMode=cv2.BORDER_CONSTANT,borderValue=(0)) 172 | # print(texture.shape) 173 | # imsave('texture.jpg', texture) 174 | # t2 = time.time() 175 | new_texture = self.get_new_texture(ref_texture_src, uv_mask_src, texture) 176 | new_colors = self.prn.get_colors_from_texture(new_texture) 177 | # print('Render cy') 178 | # t3 = time.time() 179 | # render 180 | face_mask, new_image = render_cy(np.ascontiguousarray(vertices.T), np.ascontiguousarray(new_colors.T), np.ascontiguousarray(self.prn.triangles.T.astype(np.int64)), h, w) 181 | # t4 = time.time() 182 | # imsave('face_mask.jpg', face_mask) 183 | # imsave('new_image.jpg', new_image) 184 | # print('Render done') 185 | face_mask = np.squeeze(np.floor(face_mask) > 0) 186 | 187 | tmp = new_image * face_mask[:, :, np.newaxis] 188 | new_image = image * (1 - face_mask[:, :, np.newaxis]) + new_image * face_mask[:, :, np.newaxis] 189 | new_image = np.clip(new_image, -1, 1) #must clip to (-1, 1)! 190 | t5 = time.time() 191 | # print('[FaceMasker] Time preprocess: ', t2 - t1) 192 | # print('[FaceMasker] Time feed: ', t3 - t2) 193 | # print('[FaceMasker] Time render: ', t4 - t3) 194 | # print('[FaceMasker] Time post-process: ', t5 - t4) 195 | if padded is not None: 196 | if write_image: 197 | imsave(masked_face_path, new_image[padded:-padded, padded:-padded, :]) 198 | return new_image[padded:-padded, padded:-padded, :] 199 | else: 200 | if write_image: 201 | imsave(masked_face_path, new_image) 202 | return new_image 203 | 204 | 205 | def create_mask_one(self, image, image_segment, face_lms, output): 206 | """ Create mask for single input image 207 | """ 208 | # image = imread(image_path) 209 | if image.ndim == 2: 210 | image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB) 211 | [h, w, c] = image.shape 212 | if c == 4: 213 | image = image[:,:,:3] 214 | pos, vertices = self.get_vertices(face_lms, image) #3d reconstruction -> get texture. 215 | image = image/255. #!! 216 | texture = cv2.remap(image_segment, pos[:,:,:2].astype(np.float32), None, 217 | interpolation=cv2.INTER_NEAREST, 218 | borderMode=cv2.BORDER_CONSTANT,borderValue=(0)) 219 | 220 | imsave(output, texture) 221 | 222 | def get_vertices(self, face_lms, image): 223 | """Get vertices 224 | 225 | Args: 226 | face_lms: face landmarks. 227 | image:[0, 255] 228 | """ 229 | lms_info = read_info.read_landmark_106_array(face_lms) 230 | pos = self.prn.process(image, lms_info) 231 | vertices = self.prn.get_vertices(pos) 232 | return pos, vertices 233 | 234 | def get_new_texture(self, ref_texture_src, uv_mask_src, texture): 235 | """Get new texture 236 | Mainly for data augmentation. 
237 | """ 238 | x_offset = 5 239 | y_offset = 5 240 | alpha = '0.5,0.8' 241 | beta = 0 242 | erode_iter = 5 243 | 244 | # random augmentation 245 | ref_texture = ref_texture_src.copy() 246 | uv_mask = uv_mask_src.copy() 247 | if self.is_aug: 248 | # random flip 249 | if np.random.rand()>0.5: 250 | ref_texture = cv2.flip(ref_texture, 1, dst=None) 251 | uv_mask = cv2.flip(uv_mask, 1, dst=None) 252 | # random scale, 253 | if np.random.rand()>0.5: 254 | x_offset = np.random.randint(x_offset) 255 | y_offset = np.random.randint(y_offset) 256 | ref_texture_temp = np.zeros_like(ref_texture) 257 | uv_mask_temp = np.zeros_like(uv_mask) 258 | target_size = (256-x_offset*2, 256-y_offset*2) 259 | ref_texture_temp[y_offset:256-y_offset, x_offset:256-x_offset,:] = cv2.resize(ref_texture, target_size) 260 | uv_mask_temp[y_offset:256-y_offset, x_offset:256-x_offset] = cv2.resize(uv_mask, target_size) 261 | ref_texture = ref_texture_temp 262 | uv_mask = uv_mask_temp 263 | # random erode 264 | if np.random.rand()>0.8: 265 | t = np.random.randint(erode_iter) 266 | kernel = np.ones((5,5),np.uint8) 267 | uv_mask = cv2.erode(uv_mask,kernel,iterations = t) 268 | # random contrast and brightness 269 | if np.random.rand()>0.5: 270 | alpha_r = [float(_) for _ in alpha.split(',')] 271 | alpha = (alpha_r[1] - alpha_r[0])*np.random.rand() + alpha_r[0] 272 | beta = beta 273 | img = ref_texture*255 274 | blank = np.zeros(img.shape, img.dtype) 275 | # dst = alpha * img + beta * blank 276 | dst = cv2.addWeighted(img, alpha, blank, 1-alpha, beta) 277 | ref_texture = dst.clip(0,255) / 255 278 | new_texture = texture*(1 - uv_mask[:,:,np.newaxis]) + ref_texture[:,:,:3]*uv_mask[:,:,np.newaxis] 279 | return new_texture 280 | 281 | class FaceMaskerMP: 282 | """Add a virtual mask in face. 283 | 284 | Attributes: 285 | uv_face_path(str): the path of uv_face. 286 | mask_template_folder(str): the directory where all mask template in. 287 | prn(object): PRN object, https://github.com/YadiraF/PRNet. 288 | template_name2ref_texture_src(dict): key is template name, value is the mask load by skimage.io. 289 | template_name2uv_mask_src(dict): key is template name, value is the uv_mask. 290 | is_aug(bool): whether or not to add some augmentaion operation on the mask. 291 | """ 292 | def __init__(self, is_aug, device, n_processes = 4, max_queue_len = 64): 293 | """init for FaceMasker 294 | 295 | Args: 296 | is_aug(bool): whether or not to add some augmentaion operation on the mask. 297 | """ 298 | self.device = device 299 | self.uv_face_path = 'Data/uv-data/uv_face_mask.png' 300 | self.mask_template_folder = 'Data/mask-data' 301 | self.prn = PRN('model/prnet.pth', device = self.device) 302 | self.template_name2ref_texture_src, self.template_name2uv_mask_src = self.get_ref_texture_src() 303 | self.is_aug = is_aug 304 | # self.n_processes = n_processes 305 | # self.q_in = [multiprocessing.Queue(max_queue_len) for i in range(self.n_processes)] 306 | # q_out = multiprocessing.Queue(max_queue_len) 307 | 308 | 309 | def get_ref_texture_src(self): 310 | template_name2ref_texture_src = {} 311 | template_name2uv_mask_src = {} 312 | mask_template_list = os.listdir(self.mask_template_folder) 313 | uv_face = imread(self.uv_face_path, as_gray=True)/255. 314 | for mask_template in mask_template_list: 315 | # print('Create UV map for template: ', mask_template) 316 | mask_template_path = os.path.join(self.mask_template_folder, mask_template) 317 | ref_texture_src = imread(mask_template_path, as_gray=False)/255. 
318 | if ref_texture_src.shape[2] == 4: # must 4 channel, how about 3 channel? 319 | uv_mask_src = ref_texture_src[:,:,3] 320 | ref_texture_src = ref_texture_src[:,:,:3] 321 | else: 322 | print('Fatal error!', mask_template_path) 323 | uv_mask_src[uv_face == 0] = 0 324 | template_name2ref_texture_src[mask_template] = ref_texture_src 325 | template_name2uv_mask_src[mask_template] = uv_mask_src 326 | return template_name2ref_texture_src, template_name2uv_mask_src 327 | 328 | def add_mask(self, face_root, image_name2lms, image_name2template_name, masked_face_root): 329 | for image_name, face_lms in image_name2lms.items(): 330 | image_path = os.path.join(face_root, image_name) 331 | masked_face_path = os.path.join(masked_face_root, image_name) 332 | template_name = image_name2template_name[image_name] 333 | self.add_mask_one(image_path, face_lms, template_name, masked_face_path) 334 | 335 | # you can speed it up by a c++ version. 336 | def render(self, vertices, new_colors, h, w): 337 | vis_colors = np.ones((vertices.shape[0], 1)) 338 | face_mask = render_texture(vertices.T, vis_colors.T, self.prn.triangles.T, h, w, c=1).astype(np.uint8) 339 | face_mask = np.squeeze(face_mask > 0) 340 | new_image = render_texture(vertices.T, new_colors.T, self.prn.triangles.T, h, w, c=3) 341 | return face_mask, new_image 342 | 343 | def add_mask_one(self, image, face_lms, template_name, masked_face_path, padded = None, write_image = True): 344 | """Add mask to one image. 345 | 346 | Args: 347 | image_path(str): the image to add mask. 348 | face_lms(str): face landmarks, [x1, y1, x2, y2, ..., x106, y106] 349 | template_name(str): the mask template to be added on the current image, 350 | got to '/Data/mask-data' for all template. 351 | masked_face_path(str): the path to save masked image. 352 | """ 353 | # image = imread(image_path) 354 | t1 = time.time() 355 | ref_texture_src = self.template_name2ref_texture_src[template_name] 356 | uv_mask_src = self.template_name2uv_mask_src[template_name] 357 | if image.ndim == 2: 358 | image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB) 359 | [h, w, c] = image.shape 360 | if c == 4: 361 | image = image[:,:,:3] 362 | pos, vertices = self.get_vertices(face_lms, image) #3d reconstruction -> get texture. 363 | image = image/255. #!! 364 | texture = cv2.remap(image, pos[:,:,:2].astype(np.float32), None, 365 | interpolation=cv2.INTER_NEAREST, 366 | borderMode=cv2.BORDER_CONSTANT,borderValue=(0)) 367 | # print(texture.shape) 368 | # imsave('texture.jpg', texture) 369 | t2 = time.time() 370 | new_texture = self.get_new_texture(ref_texture_src, uv_mask_src, texture) 371 | new_colors = self.prn.get_colors_from_texture(new_texture) 372 | # print('Render cy') 373 | t3 = time.time() 374 | # render 375 | face_mask, new_image = render_cy(np.ascontiguousarray(vertices.T), np.ascontiguousarray(new_colors.T), np.ascontiguousarray(self.prn.triangles.T.astype(np.int64)), h, w) 376 | t4 = time.time() 377 | # imsave('face_mask.jpg', face_mask) 378 | # imsave('new_image.jpg', new_image) 379 | # print('Render done') 380 | face_mask = np.squeeze(np.floor(face_mask) > 0) 381 | 382 | tmp = new_image * face_mask[:, :, np.newaxis] 383 | new_image = image * (1 - face_mask[:, :, np.newaxis]) + new_image * face_mask[:, :, np.newaxis] 384 | new_image = np.clip(new_image, -1, 1) #must clip to (-1, 1)! 
385 | t5 = time.time() 386 | # print('[FaceMasker] Time preprocess: ', t2 - t1) 387 | # print('[FaceMasker] Time feed: ', t3 - t2) 388 | # print('[FaceMasker] Time render: ', t4 - t3) 389 | # print('[FaceMasker] Time post-process: ', t5 - t4) 390 | if padded is not None: 391 | if write_image: 392 | imsave(masked_face_path, new_image[padded:-padded, padded:-padded, :]) 393 | return new_image[padded:-padded, padded:-padded, :] 394 | else: 395 | if write_image: 396 | imsave(masked_face_path, new_image) 397 | return new_image 398 | 399 | 400 | def mask_precompute(self, image, face_lms, template_name, masked_face_path, padded = None, write_image = True): 401 | """Add mask to one image. 402 | 403 | Args: 404 | image_path(str): the image to add mask. 405 | face_lms(str): face landmarks, [x1, y1, x2, y2, ..., x106, y106] 406 | template_name(str): the mask template to be added on the current image, 407 | got to '/Data/mask-data' for all template. 408 | masked_face_path(str): the path to save masked image. 409 | """ 410 | # image = imread(image_path) 411 | ref_texture_src = self.template_name2ref_texture_src[template_name] 412 | uv_mask_src = self.template_name2uv_mask_src[template_name] 413 | if image.ndim == 2: 414 | image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB) 415 | [h, w, c] = image.shape 416 | if c == 4: 417 | image = image[:,:,:3] 418 | t1 = time.time() 419 | pos, vertices = self.get_vertices(face_lms, image) #3d reconstruction -> get texture. 420 | t2 = time.time() 421 | print('get vertices: ', t2 - t1) 422 | image = image/255. #!! 423 | texture = cv2.remap(image, pos[:,:,:2].astype(np.float32), None, 424 | interpolation=cv2.INTER_NEAREST, 425 | borderMode=cv2.BORDER_CONSTANT,borderValue=(0)) 426 | # print(texture.shape) 427 | # imsave('texture.jpg', texture) 428 | new_texture = self.get_new_texture(ref_texture_src, uv_mask_src, texture) 429 | new_colors = self.prn.get_colors_from_texture(new_texture) 430 | # print('Render cy') 431 | return (image, vertices, new_colors, self.prn.triangles, h, w) 432 | 433 | def mask_render(self, image, vertices, new_colors, triangles, h, w): 434 | face_mask, new_image = render_cy(np.ascontiguousarray(vertices.T), np.ascontiguousarray(new_colors.T), np.ascontiguousarray(triangles.T.astype(np.int64)), h, w) 435 | face_mask = np.squeeze(np.floor(face_mask) > 0) 436 | new_image = image * (1 - face_mask[:, :, np.newaxis]) + new_image * face_mask[:, :, np.newaxis] 437 | new_image = np.clip(new_image, -1, 1) #must clip to (-1, 1)! 438 | return new_image 439 | 440 | 441 | def create_mask_one(self, image, image_segment, face_lms, output): 442 | """ Create mask for single input image 443 | """ 444 | # image = imread(image_path) 445 | if image.ndim == 2: 446 | image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB) 447 | [h, w, c] = image.shape 448 | if c == 4: 449 | image = image[:,:,:3] 450 | pos, vertices = self.get_vertices(face_lms, image) #3d reconstruction -> get texture. 451 | image = image/255. #!! 452 | texture = cv2.remap(image_segment, pos[:,:,:2].astype(np.float32), None, 453 | interpolation=cv2.INTER_NEAREST, 454 | borderMode=cv2.BORDER_CONSTANT,borderValue=(0)) 455 | 456 | imsave(output, texture) 457 | 458 | def get_vertices(self, face_lms, image): 459 | """Get vertices 460 | 461 | Args: 462 | face_lms: face landmarks. 
463 | image:[0, 255] 464 | """ 465 | lms_info = read_info.read_landmark_106_array(face_lms) 466 | pos = self.prn.process(image, lms_info) 467 | vertices = self.prn.get_vertices(pos) 468 | return pos, vertices 469 | 470 | def get_new_texture(self, ref_texture_src, uv_mask_src, texture): 471 | """Get new texture 472 | Mainly for data augmentation. 473 | """ 474 | x_offset = 5 475 | y_offset = 5 476 | alpha = '0.5,0.8' 477 | beta = 0 478 | erode_iter = 5 479 | 480 | # random augmentation 481 | ref_texture = ref_texture_src.copy() 482 | uv_mask = uv_mask_src.copy() 483 | if self.is_aug: 484 | # random flip 485 | if np.random.rand()>0.5: 486 | ref_texture = cv2.flip(ref_texture, 1, dst=None) 487 | uv_mask = cv2.flip(uv_mask, 1, dst=None) 488 | # random scale, 489 | if np.random.rand()>0.5: 490 | x_offset = np.random.randint(x_offset) 491 | y_offset = np.random.randint(y_offset) 492 | ref_texture_temp = np.zeros_like(ref_texture) 493 | uv_mask_temp = np.zeros_like(uv_mask) 494 | target_size = (256-x_offset*2, 256-y_offset*2) 495 | ref_texture_temp[y_offset:256-y_offset, x_offset:256-x_offset,:] = cv2.resize(ref_texture, target_size) 496 | uv_mask_temp[y_offset:256-y_offset, x_offset:256-x_offset] = cv2.resize(uv_mask, target_size) 497 | ref_texture = ref_texture_temp 498 | uv_mask = uv_mask_temp 499 | # random erode 500 | if np.random.rand()>0.8: 501 | t = np.random.randint(erode_iter) 502 | kernel = np.ones((5,5),np.uint8) 503 | uv_mask = cv2.erode(uv_mask,kernel,iterations = t) 504 | # random contrast and brightness 505 | if np.random.rand()>0.5: 506 | alpha_r = [float(_) for _ in alpha.split(',')] 507 | alpha = (alpha_r[1] - alpha_r[0])*np.random.rand() + alpha_r[0] 508 | beta = beta 509 | img = ref_texture*255 510 | blank = np.zeros(img.shape, img.dtype) 511 | # dst = alpha * img + beta * blank 512 | dst = cv2.addWeighted(img, alpha, blank, 1-alpha, beta) 513 | ref_texture = dst.clip(0,255) / 255 514 | new_texture = texture*(1 - uv_mask[:,:,np.newaxis]) + ref_texture[:,:,:3]*uv_mask[:,:,np.newaxis] 515 | return new_texture 516 | -------------------------------------------------------------------------------- /inference.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import cv2 4 | import numpy as np 5 | import torch 6 | 7 | from backbones import get_model 8 | 9 | 10 | @torch.no_grad() 11 | def inference(weight, name, img): 12 | if img is None: 13 | img = np.random.randint(0, 255, size=(112, 112, 3), dtype=np.uint8) 14 | else: 15 | img = cv2.imread(img) 16 | img = cv2.resize(img, (112, 112)) 17 | 18 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 19 | img = np.transpose(img, (2, 0, 1)) 20 | img = torch.from_numpy(img).unsqueeze(0).float() 21 | img.div_(255).sub_(0.5).div_(0.5) 22 | net = get_model(name, fp16=False) 23 | net.load_state_dict(torch.load(weight)) 24 | net.eval() 25 | feat = net(img).numpy() 26 | print(feat) 27 | 28 | 29 | if __name__ == "__main__": 30 | parser = argparse.ArgumentParser(description='PyTorch ArcFace Training') 31 | parser.add_argument('--network', type=str, default='r50', help='backbone network') 32 | parser.add_argument('--weight', type=str, default='') 33 | parser.add_argument('--img', type=str, default=None) 34 | args = parser.parse_args() 35 | inference(args.weight, args.network, args.img) 36 | -------------------------------------------------------------------------------- /losses.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 
| from torch import nn 4 | 5 | 6 | def get_loss(name): 7 | if name == "cosface": 8 | return CosFace() 9 | elif name == "arcface": 10 | return ArcFace() 11 | elif name == 'adaface': 12 | return AdaFace() 13 | else: 14 | raise ValueError() 15 | 16 | class AdaFace(torch.nn.Module): 17 | def __init__(self, 18 | embedding_size=512, 19 | m=0.4, 20 | h=0.333, 21 | s=64., 22 | t_alpha=1.0, 23 | ): 24 | super(AdaFace, self).__init__() 25 | 26 | # initial kernel 27 | # self.kernel.data.uniform_(-1, 1).renorm_(2,1,1e-5).mul_(1e5) 28 | self.m = m 29 | self.eps = 1e-3 30 | self.h = h 31 | self.s = s 32 | 33 | # ema prep 34 | self.t_alpha = t_alpha 35 | self.register_buffer('t', torch.zeros(1)) 36 | self.register_buffer('batch_mean', torch.ones(1)*(20)) 37 | self.register_buffer('batch_std', torch.ones(1)*100) 38 | 39 | print('\n\AdaFace with the following property') 40 | print('self.m', self.m) 41 | print('self.h', self.h) 42 | print('self.s', self.s) 43 | print('self.t_alpha', self.t_alpha) 44 | 45 | def forward(self, cosine, norms, label): 46 | # print(label) 47 | index_positive = torch.where(label != -1)[0] 48 | # target_logits = cosine[index_positive] 49 | # target_norms = norms[index_positive] 50 | # target_labels = label[index_positive] 51 | 52 | safe_norms = torch.clip(norms, min=0.001, max=100) # for stability 53 | safe_norms = safe_norms.clone().detach() 54 | 55 | # update batchmean batchstd 56 | with torch.no_grad(): 57 | mean = safe_norms.mean().detach() 58 | std = safe_norms.std().detach() 59 | self.batch_mean = mean * self.t_alpha + (1 - self.t_alpha) * self.batch_mean 60 | self.batch_std = std * self.t_alpha + (1 - self.t_alpha) * self.batch_std 61 | 62 | margin_scaler = (safe_norms - self.batch_mean) / (self.batch_std+self.eps) # 66% between -1, 1 63 | margin_scaler = margin_scaler * self.h # 68% between -0.333 ,0.333 when h:0.333 64 | margin_scaler = torch.clip(margin_scaler, -1, 1) 65 | # ex: m=0.5, h:0.333 66 | # range 67 | # (66% range) 68 | # -1 -0.333 0.333 1 (margin_scaler) 69 | # -0.5 -0.166 0.166 0.5 (m * margin_scaler) 70 | # g_angular 71 | m_arc = torch.zeros(index_positive.size()[0], cosine.size()[1], device=cosine.device) 72 | m_arc.scatter_(1, label[index_positive].view(-1, 1), 1.0) 73 | g_angular = self.m * margin_scaler[index_positive] * -1 74 | m_arc = m_arc * g_angular 75 | cosine.acos_() 76 | cosine[index_positive] = torch.clip(cosine[index_positive] + m_arc, min=self.eps, max=math.pi-self.eps) 77 | cosine.cos_() 78 | 79 | # g_additive 80 | 81 | m_cos = torch.zeros(index_positive.size()[0], cosine.size()[1], device=cosine.device) 82 | m_cos.scatter_(1, label[index_positive].view(-1, 1), 1.0) 83 | g_add = self.m + (self.m * margin_scaler[index_positive]) 84 | m_cos = m_cos * g_add 85 | cosine[index_positive] -= m_cos 86 | 87 | # scale 88 | ret = cosine * self.s 89 | return ret 90 | 91 | class CosFace(nn.Module): 92 | def __init__(self, s=64.0, m=0.40): 93 | super(CosFace, self).__init__() 94 | self.s = s 95 | self.m = m 96 | 97 | def forward(self, cosine, label): 98 | index = torch.where(label != -1)[0] 99 | m_hot = torch.zeros(index.size()[0], cosine.size()[1], device=cosine.device) 100 | m_hot.scatter_(1, label[index, None], self.m) 101 | cosine[index] -= m_hot 102 | ret = cosine * self.s 103 | return ret 104 | 105 | 106 | class ArcFace(nn.Module): 107 | def __init__(self, s=64.0, m=0.5): 108 | super(ArcFace, self).__init__() 109 | self.s = s 110 | self.m = m 111 | 112 | def forward(self, cosine: torch.Tensor, label): 113 | index = torch.where(label != -1)[0] 114 | 
m_hot = torch.zeros(index.size()[0], cosine.size()[1], device=cosine.device) 115 | m_hot.scatter_(1, label[index, None], self.m) 116 | cosine.acos_() 117 | cosine[index] += m_hot 118 | cosine.cos_().mul_(self.s) 119 | return cosine 120 | -------------------------------------------------------------------------------- /onnx_helper.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | import datetime 3 | import os 4 | import os.path as osp 5 | import glob 6 | import numpy as np 7 | import cv2 8 | import sys 9 | import onnxruntime 10 | import onnx 11 | import argparse 12 | from onnx import numpy_helper 13 | from insightface.data import get_image 14 | 15 | class ArcFaceORT: 16 | def __init__(self, model_path, cpu=False): 17 | self.model_path = model_path 18 | # providers = None will use available provider, for onnxruntime-gpu it will be "CUDAExecutionProvider" 19 | self.providers = ['CPUExecutionProvider'] if cpu else None 20 | 21 | #input_size is (w,h), return error message, return None if success 22 | def check(self, track='cfat', test_img = None): 23 | #default is cfat 24 | max_model_size_mb=1024 25 | max_feat_dim=512 26 | max_time_cost=15 27 | if track.startswith('ms1m'): 28 | max_model_size_mb=1024 29 | max_feat_dim=512 30 | max_time_cost=10 31 | elif track.startswith('glint'): 32 | max_model_size_mb=1024 33 | max_feat_dim=1024 34 | max_time_cost=20 35 | elif track.startswith('cfat'): 36 | max_model_size_mb = 1024 37 | max_feat_dim = 512 38 | max_time_cost = 15 39 | elif track.startswith('unconstrained'): 40 | max_model_size_mb=1024 41 | max_feat_dim=1024 42 | max_time_cost=30 43 | else: 44 | return "track not found" 45 | 46 | if not os.path.exists(self.model_path): 47 | return "model_path not exists" 48 | if not os.path.isdir(self.model_path): 49 | return "model_path should be directory" 50 | onnx_files = [] 51 | for _file in os.listdir(self.model_path): 52 | if _file.endswith('.onnx'): 53 | onnx_files.append(osp.join(self.model_path, _file)) 54 | if len(onnx_files)==0: 55 | return "do not have onnx files" 56 | self.model_file = sorted(onnx_files)[-1] 57 | print('use onnx-model:', self.model_file) 58 | try: 59 | session = onnxruntime.InferenceSession(self.model_file, providers=self.providers) 60 | except: 61 | return "load onnx failed" 62 | input_cfg = session.get_inputs()[0] 63 | input_shape = input_cfg.shape 64 | print('input-shape:', input_shape) 65 | if len(input_shape)!=4: 66 | return "length of input_shape should be 4" 67 | if not isinstance(input_shape[0], str): 68 | #return "input_shape[0] should be str to support batch-inference" 69 | print('reset input-shape[0] to None') 70 | model = onnx.load(self.model_file) 71 | model.graph.input[0].type.tensor_type.shape.dim[0].dim_param = 'None' 72 | new_model_file = osp.join(self.model_path, 'zzzzrefined.onnx') 73 | onnx.save(model, new_model_file) 74 | self.model_file = new_model_file 75 | print('use new onnx-model:', self.model_file) 76 | try: 77 | session = onnxruntime.InferenceSession(self.model_file, providers=self.providers) 78 | except: 79 | return "load onnx failed" 80 | input_cfg = session.get_inputs()[0] 81 | input_shape = input_cfg.shape 82 | print('new-input-shape:', input_shape) 83 | 84 | self.image_size = tuple(input_shape[2:4][::-1]) 85 | #print('image_size:', self.image_size) 86 | input_name = input_cfg.name 87 | outputs = session.get_outputs() 88 | output_names = [] 89 | for o in outputs: 90 | output_names.append(o.name) 91 | #print(o.name, o.shape) 92 
| if len(output_names)!=1: 93 | return "number of output nodes should be 1" 94 | self.session = session 95 | self.input_name = input_name 96 | self.output_names = output_names 97 | #print(self.output_names) 98 | model = onnx.load(self.model_file) 99 | graph = model.graph 100 | if len(graph.node)<8: 101 | return "too small onnx graph" 102 | 103 | input_size = (112,112) 104 | self.crop = None 105 | if track=='cfat': 106 | crop_file = osp.join(self.model_path, 'crop.txt') 107 | if osp.exists(crop_file): 108 | lines = open(crop_file,'r').readlines() 109 | if len(lines)!=6: 110 | return "crop.txt should contain 6 lines" 111 | lines = [int(x) for x in lines] 112 | self.crop = lines[:4] 113 | input_size = tuple(lines[4:6]) 114 | if input_size!=self.image_size: 115 | return "input-size is inconsistant with onnx model input, %s vs %s"%(input_size, self.image_size) 116 | 117 | self.model_size_mb = os.path.getsize(self.model_file) / float(1024*1024) 118 | if self.model_size_mb > max_model_size_mb: 119 | return "max model size exceed, given %.3f-MB"%self.model_size_mb 120 | 121 | input_mean = None 122 | input_std = None 123 | if track=='cfat': 124 | pn_file = osp.join(self.model_path, 'pixel_norm.txt') 125 | if osp.exists(pn_file): 126 | lines = open(pn_file,'r').readlines() 127 | if len(lines)!=2: 128 | return "pixel_norm.txt should contain 2 lines" 129 | input_mean = float(lines[0]) 130 | input_std = float(lines[1]) 131 | if input_mean is not None or input_std is not None: 132 | if input_mean is None or input_std is None: 133 | return "please set input_mean and input_std simultaneously" 134 | else: 135 | find_sub = False 136 | find_mul = False 137 | for nid, node in enumerate(graph.node[:8]): 138 | print(nid, node.name) 139 | if node.name.startswith('Sub') or node.name.startswith('_minus'): 140 | find_sub = True 141 | if node.name.startswith('Mul') or node.name.startswith('_mul') or node.name.startswith('Div'): 142 | find_mul = True 143 | if find_sub and find_mul: 144 | print("find sub and mul") 145 | #mxnet arcface model 146 | input_mean = 0.0 147 | input_std = 1.0 148 | else: 149 | input_mean = 127.5 150 | input_std = 127.5 151 | self.input_mean = input_mean 152 | self.input_std = input_std 153 | for initn in graph.initializer: 154 | weight_array = numpy_helper.to_array(initn) 155 | dt = weight_array.dtype 156 | if dt.itemsize<4: 157 | return 'invalid weight type - (%s:%s)' % (initn.name, dt.name) 158 | if test_img is None: 159 | test_img = get_image('Tom_Hanks_54745') 160 | test_img = cv2.resize(test_img, self.image_size) 161 | else: 162 | test_img = cv2.resize(test_img, self.image_size) 163 | feat, cost = self.benchmark(test_img) 164 | batch_result = self.check_batch(test_img) 165 | batch_result_sum = float(np.sum(batch_result)) 166 | if batch_result_sum in [float('inf'), -float('inf')] or batch_result_sum != batch_result_sum: 167 | print(batch_result) 168 | print(batch_result_sum) 169 | return "batch result output contains NaN!" 
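# The check above flags invalid outputs: a sum equal to +/-inf, or a sum not equal to itself (the NaN != NaN property), means the batched inference produced non-finite values.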
170 | 171 | if len(feat.shape) < 2: 172 | return "the shape of the feature must be two, but get {}".format(str(feat.shape)) 173 | 174 | if feat.shape[1] > max_feat_dim: 175 | return "max feat dim exceed, given %d"%feat.shape[1] 176 | self.feat_dim = feat.shape[1] 177 | cost_ms = cost*1000 178 | if cost_ms>max_time_cost: 179 | return "max time cost exceed, given %.4f"%cost_ms 180 | self.cost_ms = cost_ms 181 | print('check stat:, model-size-mb: %.4f, feat-dim: %d, time-cost-ms: %.4f, input-mean: %.3f, input-std: %.3f'%(self.model_size_mb, self.feat_dim, self.cost_ms, self.input_mean, self.input_std)) 182 | return None 183 | 184 | def check_batch(self, img): 185 | if not isinstance(img, list): 186 | imgs = [img, ] * 32 187 | if self.crop is not None: 188 | nimgs = [] 189 | for img in imgs: 190 | nimg = img[self.crop[1]:self.crop[3], self.crop[0]:self.crop[2], :] 191 | if nimg.shape[0] != self.image_size[1] or nimg.shape[1] != self.image_size[0]: 192 | nimg = cv2.resize(nimg, self.image_size) 193 | nimgs.append(nimg) 194 | imgs = nimgs 195 | blob = cv2.dnn.blobFromImages( 196 | images=imgs, scalefactor=1.0 / self.input_std, size=self.image_size, 197 | mean=(self.input_mean, self.input_mean, self.input_mean), swapRB=True) 198 | net_out = self.session.run(self.output_names, {self.input_name: blob})[0] 199 | return net_out 200 | 201 | 202 | def meta_info(self): 203 | return {'model-size-mb':self.model_size_mb, 'feature-dim':self.feat_dim, 'infer': self.cost_ms} 204 | 205 | 206 | def forward(self, imgs): 207 | if not isinstance(imgs, list): 208 | imgs = [imgs] 209 | input_size = self.image_size 210 | if self.crop is not None: 211 | nimgs = [] 212 | for img in imgs: 213 | nimg = img[self.crop[1]:self.crop[3],self.crop[0]:self.crop[2],:] 214 | if nimg.shape[0]!=input_size[1] or nimg.shape[1]!=input_size[0]: 215 | nimg = cv2.resize(nimg, input_size) 216 | nimgs.append(nimg) 217 | imgs = nimgs 218 | blob = cv2.dnn.blobFromImages(imgs, 1.0/self.input_std, input_size, (self.input_mean, self.input_mean, self.input_mean), swapRB=True) 219 | net_out = self.session.run(self.output_names, {self.input_name : blob})[0] 220 | return net_out 221 | 222 | def benchmark(self, img): 223 | input_size = self.image_size 224 | if self.crop is not None: 225 | nimg = img[self.crop[1]:self.crop[3],self.crop[0]:self.crop[2],:] 226 | if nimg.shape[0]!=input_size[1] or nimg.shape[1]!=input_size[0]: 227 | nimg = cv2.resize(nimg, input_size) 228 | img = nimg 229 | blob = cv2.dnn.blobFromImage(img, 1.0/self.input_std, input_size, (self.input_mean, self.input_mean, self.input_mean), swapRB=True) 230 | costs = [] 231 | for _ in range(50): 232 | ta = datetime.datetime.now() 233 | net_out = self.session.run(self.output_names, {self.input_name : blob})[0] 234 | tb = datetime.datetime.now() 235 | cost = (tb-ta).total_seconds() 236 | costs.append(cost) 237 | costs = sorted(costs) 238 | cost = costs[5] 239 | return net_out, cost 240 | 241 | 242 | if __name__ == '__main__': 243 | parser = argparse.ArgumentParser(description='') 244 | # general 245 | parser.add_argument('workdir', help='submitted work dir', type=str) 246 | parser.add_argument('--track', help='track name, for different challenge', type=str, default='cfat') 247 | args = parser.parse_args() 248 | handler = ArcFaceORT(args.workdir) 249 | err = handler.check(args.track) 250 | print('err:', err) 251 | -------------------------------------------------------------------------------- /onnx_ijbc.py: -------------------------------------------------------------------------------- 1 
| import argparse 2 | import os 3 | import pickle 4 | import timeit 5 | 6 | import cv2 7 | import mxnet as mx 8 | import numpy as np 9 | import pandas as pd 10 | import prettytable 11 | import skimage.transform 12 | from sklearn.metrics import roc_curve 13 | from sklearn.preprocessing import normalize 14 | 15 | from onnx_helper import ArcFaceORT 16 | 17 | SRC = np.array( 18 | [ 19 | [30.2946, 51.6963], 20 | [65.5318, 51.5014], 21 | [48.0252, 71.7366], 22 | [33.5493, 92.3655], 23 | [62.7299, 92.2041]] 24 | , dtype=np.float32) 25 | SRC[:, 0] += 8.0 26 | 27 | 28 | class AlignedDataSet(mx.gluon.data.Dataset): 29 | def __init__(self, root, lines, align=True): 30 | self.lines = lines 31 | self.root = root 32 | self.align = align 33 | 34 | def __len__(self): 35 | return len(self.lines) 36 | 37 | def __getitem__(self, idx): 38 | each_line = self.lines[idx] 39 | name_lmk_score = each_line.strip().split(' ') 40 | name = os.path.join(self.root, name_lmk_score[0]) 41 | img = cv2.cvtColor(cv2.imread(name), cv2.COLOR_BGR2RGB) 42 | landmark5 = np.array([float(x) for x in name_lmk_score[1:-1]], dtype=np.float32).reshape((5, 2)) 43 | st = skimage.transform.SimilarityTransform() 44 | st.estimate(landmark5, SRC) 45 | img = cv2.warpAffine(img, st.params[0:2, :], (112, 112), borderValue=0.0) 46 | img_1 = np.expand_dims(img, 0) 47 | img_2 = np.expand_dims(np.fliplr(img), 0) 48 | output = np.concatenate((img_1, img_2), axis=0).astype(np.float32) 49 | output = np.transpose(output, (0, 3, 1, 2)) 50 | output = mx.nd.array(output) 51 | return output 52 | 53 | 54 | def extract(model_root, dataset): 55 | model = ArcFaceORT(model_path=model_root) 56 | model.check() 57 | feat_mat = np.zeros(shape=(len(dataset), 2 * model.feat_dim)) 58 | 59 | def batchify_fn(data): 60 | return mx.nd.concat(*data, dim=0) 61 | 62 | data_loader = mx.gluon.data.DataLoader( 63 | dataset, 128, last_batch='keep', num_workers=4, 64 | thread_pool=True, prefetch=16, batchify_fn=batchify_fn) 65 | num_iter = 0 66 | for batch in data_loader: 67 | batch = batch.asnumpy() 68 | batch = (batch - model.input_mean) / model.input_std 69 | feat = model.session.run(model.output_names, {model.input_name: batch})[0] 70 | feat = np.reshape(feat, (-1, model.feat_dim * 2)) 71 | feat_mat[128 * num_iter: 128 * num_iter + feat.shape[0], :] = feat 72 | num_iter += 1 73 | if num_iter % 50 == 0: 74 | print(num_iter) 75 | return feat_mat 76 | 77 | 78 | def read_template_media_list(path): 79 | ijb_meta = pd.read_csv(path, sep=' ', header=None).values 80 | templates = ijb_meta[:, 1].astype(np.int) 81 | medias = ijb_meta[:, 2].astype(np.int) 82 | return templates, medias 83 | 84 | 85 | def read_template_pair_list(path): 86 | pairs = pd.read_csv(path, sep=' ', header=None).values 87 | t1 = pairs[:, 0].astype(np.int) 88 | t2 = pairs[:, 1].astype(np.int) 89 | label = pairs[:, 2].astype(np.int) 90 | return t1, t2, label 91 | 92 | 93 | def read_image_feature(path): 94 | with open(path, 'rb') as fid: 95 | img_feats = pickle.load(fid) 96 | return img_feats 97 | 98 | 99 | def image2template_feature(img_feats=None, 100 | templates=None, 101 | medias=None): 102 | unique_templates = np.unique(templates) 103 | template_feats = np.zeros((len(unique_templates), img_feats.shape[1])) 104 | for count_template, uqt in enumerate(unique_templates): 105 | (ind_t,) = np.where(templates == uqt) 106 | face_norm_feats = img_feats[ind_t] 107 | face_medias = medias[ind_t] 108 | unique_medias, unique_media_counts = np.unique(face_medias, return_counts=True) 109 | media_norm_feats = [] 110 | for u, ct 
in zip(unique_medias, unique_media_counts): 111 | (ind_m,) = np.where(face_medias == u) 112 | if ct == 1: 113 | media_norm_feats += [face_norm_feats[ind_m]] 114 | else: # image features from the same video will be aggregated into one feature 115 | media_norm_feats += [np.mean(face_norm_feats[ind_m], axis=0, keepdims=True), ] 116 | media_norm_feats = np.array(media_norm_feats) 117 | template_feats[count_template] = np.sum(media_norm_feats, axis=0) 118 | if count_template % 2000 == 0: 119 | print('Finish Calculating {} template features.'.format( 120 | count_template)) 121 | template_norm_feats = normalize(template_feats) 122 | return template_norm_feats, unique_templates 123 | 124 | 125 | def verification(template_norm_feats=None, 126 | unique_templates=None, 127 | p1=None, 128 | p2=None): 129 | template2id = np.zeros((max(unique_templates) + 1, 1), dtype=int) 130 | for count_template, uqt in enumerate(unique_templates): 131 | template2id[uqt] = count_template 132 | score = np.zeros((len(p1),)) 133 | total_pairs = np.array(range(len(p1))) 134 | batchsize = 100000 135 | sublists = [total_pairs[i: i + batchsize] for i in range(0, len(p1), batchsize)] 136 | total_sublists = len(sublists) 137 | for c, s in enumerate(sublists): 138 | feat1 = template_norm_feats[template2id[p1[s]]] 139 | feat2 = template_norm_feats[template2id[p2[s]]] 140 | similarity_score = np.sum(feat1 * feat2, -1) 141 | score[s] = similarity_score.flatten() 142 | if c % 10 == 0: 143 | print('Finish {}/{} pairs.'.format(c, total_sublists)) 144 | return score 145 | 146 | 147 | def verification2(template_norm_feats=None, 148 | unique_templates=None, 149 | p1=None, 150 | p2=None): 151 | template2id = np.zeros((max(unique_templates) + 1, 1), dtype=int) 152 | for count_template, uqt in enumerate(unique_templates): 153 | template2id[uqt] = count_template 154 | score = np.zeros((len(p1),)) # save cosine distance between pairs 155 | total_pairs = np.array(range(len(p1))) 156 | batchsize = 100000 # small batchsize instead of all pairs in one batch due to the memory limiation 157 | sublists = [total_pairs[i:i + batchsize] for i in range(0, len(p1), batchsize)] 158 | total_sublists = len(sublists) 159 | for c, s in enumerate(sublists): 160 | feat1 = template_norm_feats[template2id[p1[s]]] 161 | feat2 = template_norm_feats[template2id[p2[s]]] 162 | similarity_score = np.sum(feat1 * feat2, -1) 163 | score[s] = similarity_score.flatten() 164 | if c % 10 == 0: 165 | print('Finish {}/{} pairs.'.format(c, total_sublists)) 166 | return score 167 | 168 | 169 | def main(args): 170 | use_norm_score = True # if Ture, TestMode(N1) 171 | use_detector_score = True # if Ture, TestMode(D1) 172 | use_flip_test = True # if Ture, TestMode(F1) 173 | assert args.target == 'IJBC' or args.target == 'IJBB' 174 | 175 | start = timeit.default_timer() 176 | templates, medias = read_template_media_list( 177 | os.path.join('%s/meta' % args.image_path, '%s_face_tid_mid.txt' % args.target.lower())) 178 | stop = timeit.default_timer() 179 | print('Time: %.2f s. ' % (stop - start)) 180 | 181 | start = timeit.default_timer() 182 | p1, p2, label = read_template_pair_list( 183 | os.path.join('%s/meta' % args.image_path, 184 | '%s_template_pair_label.txt' % args.target.lower())) 185 | stop = timeit.default_timer() 186 | print('Time: %.2f s. 
' % (stop - start)) 187 | 188 | start = timeit.default_timer() 189 | img_path = '%s/loose_crop' % args.image_path 190 | img_list_path = '%s/meta/%s_name_5pts_score.txt' % (args.image_path, args.target.lower()) 191 | img_list = open(img_list_path) 192 | files = img_list.readlines() 193 | dataset = AlignedDataSet(root=img_path, lines=files, align=True) 194 | img_feats = extract(args.model_root, dataset) 195 | 196 | faceness_scores = [] 197 | for each_line in files: 198 | name_lmk_score = each_line.split() 199 | faceness_scores.append(name_lmk_score[-1]) 200 | faceness_scores = np.array(faceness_scores).astype(np.float32) 201 | stop = timeit.default_timer() 202 | print('Time: %.2f s. ' % (stop - start)) 203 | print('Feature Shape: ({} , {}) .'.format(img_feats.shape[0], img_feats.shape[1])) 204 | start = timeit.default_timer() 205 | 206 | if use_flip_test: 207 | img_input_feats = img_feats[:, 0:img_feats.shape[1] // 2] + img_feats[:, img_feats.shape[1] // 2:] 208 | else: 209 | img_input_feats = img_feats[:, 0:img_feats.shape[1] // 2] 210 | 211 | if use_norm_score: 212 | img_input_feats = img_input_feats 213 | else: 214 | img_input_feats = img_input_feats / np.sqrt(np.sum(img_input_feats ** 2, -1, keepdims=True)) 215 | 216 | if use_detector_score: 217 | print(img_input_feats.shape, faceness_scores.shape) 218 | img_input_feats = img_input_feats * faceness_scores[:, np.newaxis] 219 | else: 220 | img_input_feats = img_input_feats 221 | 222 | template_norm_feats, unique_templates = image2template_feature( 223 | img_input_feats, templates, medias) 224 | stop = timeit.default_timer() 225 | print('Time: %.2f s. ' % (stop - start)) 226 | 227 | start = timeit.default_timer() 228 | score = verification(template_norm_feats, unique_templates, p1, p2) 229 | stop = timeit.default_timer() 230 | print('Time: %.2f s. 
' % (stop - start)) 231 | save_path = os.path.join(args.result_dir, "{}_result".format(args.target)) 232 | if not os.path.exists(save_path): 233 | os.makedirs(save_path) 234 | score_save_file = os.path.join(save_path, "{}.npy".format(args.model_root)) 235 | np.save(score_save_file, score) 236 | files = [score_save_file] 237 | methods = [] 238 | scores = [] 239 | for file in files: 240 | methods.append(os.path.basename(file)) 241 | scores.append(np.load(file)) 242 | methods = np.array(methods) 243 | scores = dict(zip(methods, scores)) 244 | x_labels = [10 ** -6, 10 ** -5, 10 ** -4, 10 ** -3, 10 ** -2, 10 ** -1] 245 | tpr_fpr_table = prettytable.PrettyTable(['Methods'] + [str(x) for x in x_labels]) 246 | for method in methods: 247 | fpr, tpr, _ = roc_curve(label, scores[method]) 248 | fpr = np.flipud(fpr) 249 | tpr = np.flipud(tpr) 250 | tpr_fpr_row = [] 251 | tpr_fpr_row.append("%s-%s" % (method, args.target)) 252 | for fpr_iter in np.arange(len(x_labels)): 253 | _, min_index = min( 254 | list(zip(abs(fpr - x_labels[fpr_iter]), range(len(fpr))))) 255 | tpr_fpr_row.append('%.2f' % (tpr[min_index] * 100)) 256 | tpr_fpr_table.add_row(tpr_fpr_row) 257 | print(tpr_fpr_table) 258 | 259 | 260 | if __name__ == '__main__': 261 | parser = argparse.ArgumentParser(description='do ijb test') 262 | # general 263 | parser.add_argument('--model-root', default='', help='path to load model.') 264 | parser.add_argument('--image-path', default='', type=str, help='') 265 | parser.add_argument('--result-dir', default='.', type=str, help='') 266 | parser.add_argument('--target', default='IJBC', type=str, help='target, set to IJBC or IJBB') 267 | main(parser.parse_args()) 268 | -------------------------------------------------------------------------------- /partial_fc.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | 4 | import torch 5 | import torch.distributed as dist 6 | from torch.nn import Module 7 | from torch.nn.functional import normalize, linear 8 | from torch.nn.parameter import Parameter 9 | 10 | 11 | class PartialFC(Module): 12 | """ 13 | Author: {Xiang An, Yang Xiao, XuHan Zhu} in DeepGlint, 14 | Partial FC: Training 10 Million Identities on a Single Machine 15 | See the original paper: 16 | https://arxiv.org/abs/2010.05222 17 | """ 18 | 19 | @torch.no_grad() 20 | def __init__(self, rank, local_rank, world_size, batch_size, resume, 21 | margin_softmax, num_classes, sample_rate=1.0, embedding_size=512, prefix="./", loss_type = 'arcface'): 22 | """ 23 | rank: int 24 | Unique process(GPU) ID from 0 to world_size - 1. 25 | local_rank: int 26 | Unique process(GPU) ID within the server from 0 to 7. 27 | world_size: int 28 | Number of GPU. 29 | batch_size: int 30 | Batch size on current rank(GPU). 31 | resume: bool 32 | Select whether to restore the weight of softmax. 33 | margin_softmax: callable 34 | A function of margin softmax, eg: cosface, arcface. 35 | num_classes: int 36 | The number of class center storage in current rank(CPU/GPU), usually is total_classes // world_size, 37 | required. 38 | sample_rate: float 39 | The partial fc sampling rate, when the number of classes increases to more than 2 millions, Sampling 40 | can greatly speed up training, and reduce a lot of GPU memory, default is 1.0. 41 | embedding_size: int 42 | The feature dimension, default is 512. 43 | prefix: str 44 | Path for save checkpoint, default is './'. 
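loss_type: str
            Which margin loss the sampled logits are fed to ('arcface', 'cosface' or 'adaface'), default is 'arcface'.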
45 | """ 46 | super(PartialFC, self).__init__() 47 | # 48 | self.num_classes: int = num_classes 49 | self.rank: int = rank 50 | self.local_rank: int = local_rank 51 | self.device: torch.device = torch.device("cuda:{}".format(self.local_rank)) 52 | self.world_size: int = world_size 53 | self.batch_size: int = batch_size 54 | self.margin_softmax: callable = margin_softmax 55 | self.sample_rate: float = sample_rate 56 | self.embedding_size: int = embedding_size 57 | self.prefix: str = prefix 58 | self.num_local: int = num_classes // world_size + int(rank < num_classes % world_size) 59 | self.class_start: int = num_classes // world_size * rank + min(rank, num_classes % world_size) 60 | self.num_sample: int = int(self.sample_rate * self.num_local) 61 | self.loss_type = loss_type 62 | self.weight_name = os.path.join(self.prefix, "rank_{}_softmax_weight.pt".format(self.rank)) 63 | self.weight_mom_name = os.path.join(self.prefix, "rank_{}_softmax_weight_mom.pt".format(self.rank)) 64 | print(self.weight_name) 65 | print(self.weight_mom_name) 66 | if resume: 67 | try: 68 | self.weight: torch.Tensor = torch.load(self.weight_name) 69 | self.weight_mom: torch.Tensor = torch.load(self.weight_mom_name) 70 | if self.weight.shape[0] != self.num_local or self.weight_mom.shape[0] != self.num_local: 71 | raise IndexError 72 | logging.info("softmax weight resume successfully!") 73 | logging.info("softmax weight mom resume successfully!") 74 | except (FileNotFoundError, KeyError, IndexError): 75 | self.weight = torch.normal(0, 0.01, (self.num_local, self.embedding_size), device=self.device) 76 | self.weight_mom: torch.Tensor = torch.zeros_like(self.weight) 77 | logging.info("softmax weight init!") 78 | logging.info("softmax weight mom init!") 79 | except RuntimeError: 80 | print('Error loadding: ', self.weight_name) 81 | self.weight = torch.normal(0, 0.01, (self.num_local, self.embedding_size), device=self.device) 82 | self.weight_mom: torch.Tensor = torch.zeros_like(self.weight) 83 | else: 84 | self.weight = torch.normal(0, 0.01, (self.num_local, self.embedding_size), device=self.device) 85 | self.weight_mom: torch.Tensor = torch.zeros_like(self.weight) 86 | logging.info("softmax weight init successfully!") 87 | logging.info("softmax weight mom init successfully!") 88 | self.stream: torch.cuda.Stream = torch.cuda.Stream(local_rank) 89 | 90 | self.index = None 91 | if int(self.sample_rate) == 1: 92 | self.update = lambda: 0 93 | self.sub_weight = Parameter(self.weight) 94 | self.sub_weight_mom = self.weight_mom 95 | else: 96 | self.sub_weight = Parameter(torch.empty((0, 0)).cuda(local_rank)) 97 | 98 | def save_params(self, folder = None): 99 | """ Save softmax weight for each rank on prefix 100 | """ 101 | if folder is not None: 102 | weight_name = os.path.join(folder, "rank_{}_softmax_weight.pt".format(self.rank)) 103 | weight_mom_name = os.path.join(folder, "rank_{}_softmax_weight_mom.pt".format(self.rank)) 104 | torch.save(self.weight.data, weight_name) 105 | torch.save(self.weight_mom, weight_mom_name) 106 | else: 107 | torch.save(self.weight.data, self.weight_name) 108 | torch.save(self.weight_mom, self.weight_mom_name) 109 | 110 | @torch.no_grad() 111 | def sample(self, total_label): 112 | """ 113 | Sample all positive class centers in each rank, and random select neg class centers to filling a fixed 114 | `num_sample`. 115 | 116 | total_label: tensor 117 | Label after all gather, which cross all GPUs. 
118 | """ 119 | index_positive = (self.class_start <= total_label) & (total_label < self.class_start + self.num_local) 120 | total_label[~index_positive] = -1 121 | total_label[index_positive] -= self.class_start 122 | if int(self.sample_rate) != 1: 123 | positive = torch.unique(total_label[index_positive], sorted=True) 124 | if self.num_sample - positive.size(0) >= 0: 125 | perm = torch.rand(size=[self.num_local], device=self.device) 126 | perm[positive] = 2.0 127 | index = torch.topk(perm, k=self.num_sample)[1] 128 | index = index.sort()[0] 129 | else: 130 | index = positive 131 | self.index = index 132 | total_label[index_positive] = torch.searchsorted(index, total_label[index_positive]) 133 | self.sub_weight = Parameter(self.weight[index]) 134 | self.sub_weight_mom = self.weight_mom[index] 135 | 136 | def forward(self, total_features, norm_weight): 137 | """ Partial fc forward, `logits = X * sample(W)` 138 | """ 139 | torch.cuda.current_stream().wait_stream(self.stream) 140 | logits = linear(total_features, norm_weight) 141 | 142 | return logits 143 | 144 | @torch.no_grad() 145 | def update(self): 146 | """ Set updated weight and weight_mom to memory bank. 147 | """ 148 | self.weight_mom[self.index] = self.sub_weight_mom 149 | self.weight[self.index] = self.sub_weight 150 | 151 | def prepare(self, label, optimizer): 152 | """ 153 | get sampled class centers for cal softmax. 154 | 155 | label: tensor 156 | Label tensor on each rank. 157 | optimizer: opt 158 | Optimizer for partial fc, which need to get weight mom. 159 | """ 160 | with torch.cuda.stream(self.stream): 161 | total_label = torch.zeros( 162 | size=[self.batch_size * self.world_size], device=self.device, dtype=torch.long) 163 | dist.all_gather(list(total_label.chunk(self.world_size, dim=0)), label) 164 | # print('Local rank {}: total_label in {}'.format(self.local_rank, total_label), label.shape) 165 | self.sample(total_label) 166 | # print('Local rank {}: total_label out {}'.format(self.local_rank, total_label), label.shape) 167 | optimizer.state.pop(optimizer.param_groups[-1]['params'][0], None) 168 | optimizer.param_groups[-1]['params'][0] = self.sub_weight 169 | optimizer.state[self.sub_weight]['momentum_buffer'] = self.sub_weight_mom 170 | norm_weight = normalize(self.sub_weight) 171 | return total_label, norm_weight 172 | 173 | def forward_backward(self, label, features, optimizer, norms = None): 174 | """ 175 | Partial fc forward and backward with model parallel 176 | 177 | label: tensor 178 | Label tensor on each rank(GPU) 179 | features: tensor 180 | Features tensor on each rank(GPU) 181 | optimizer: optimizer 182 | Optimizer for partial fc 183 | 184 | Returns: 185 | -------- 186 | x_grad: tensor 187 | The gradient of features. 188 | loss_v: tensor 189 | Loss value for cross entropy. 
190 | """ 191 | total_label, norm_weight = self.prepare(label, optimizer) 192 | total_features = torch.zeros( 193 | size=[self.batch_size * self.world_size, self.embedding_size], device=self.device) 194 | dist.all_gather(list(total_features.chunk(self.world_size, dim=0)), features.data) 195 | total_features.requires_grad = True 196 | 197 | if self.loss_type == 'adaface': 198 | total_norms = torch.zeros( 199 | size=[self.batch_size * self.world_size, 1], device=self.device) 200 | dist.all_gather(list(total_norms.chunk(self.world_size, dim=0)), norms.data) 201 | total_norms.requires_grad = True 202 | 203 | # print('Local rank {}: norms {}'.format(self.local_rank, norms), norms.shape) 204 | # print('Local rank {}: total norms {}'.format(self.local_rank, total_norms), total_norms.shape) 205 | logits = self.forward(total_features, norm_weight) 206 | if self.loss_type == 'adaface': 207 | assert norms is not None, "Get no input for norms variable" 208 | logits = self.margin_softmax(logits, total_norms, total_label) 209 | else: 210 | logits = self.margin_softmax(logits, total_label) 211 | 212 | with torch.no_grad(): 213 | max_fc = torch.max(logits, dim=1, keepdim=True)[0] 214 | dist.all_reduce(max_fc, dist.ReduceOp.MAX) 215 | 216 | # calculate exp(logits) and all-reduce 217 | logits_exp = torch.exp(logits - max_fc) 218 | logits_sum_exp = logits_exp.sum(dim=1, keepdims=True) 219 | dist.all_reduce(logits_sum_exp, dist.ReduceOp.SUM) 220 | 221 | # calculate prob 222 | logits_exp.div_(logits_sum_exp) 223 | 224 | # get one-hot 225 | grad = logits_exp 226 | # print('total_label: ', total_label) 227 | # print('\n') 228 | index = torch.where(total_label != -1)[0] 229 | one_hot = torch.zeros(size=[index.size()[0], grad.size()[1]], device=grad.device) 230 | one_hot.scatter_(1, total_label[index, None], 1) 231 | 232 | # calculate loss 233 | loss = torch.zeros(grad.size()[0], 1, device=grad.device) 234 | loss[index] = grad[index].gather(1, total_label[index, None]) 235 | dist.all_reduce(loss, dist.ReduceOp.SUM) 236 | loss_v = loss.clamp_min_(1e-30).log_().mean() * (-1) 237 | 238 | # calculate grad 239 | grad[index] -= one_hot 240 | grad.div_(self.batch_size * self.world_size) 241 | 242 | logits.backward(grad) 243 | if total_features.grad is not None: 244 | total_features.grad.detach_() 245 | x_grad: torch.Tensor = torch.zeros_like(features, requires_grad=True) 246 | # feature gradient all-reduce 247 | dist.reduce_scatter(x_grad, list(total_features.grad.chunk(self.world_size, dim=0))) 248 | x_grad = x_grad * self.world_size 249 | # backward backbone 250 | return x_grad, loss_v 251 | -------------------------------------------------------------------------------- /requirement.txt: -------------------------------------------------------------------------------- 1 | tensorboard 2 | easydict 3 | mxnet 4 | onnx 5 | sklearn 6 | -------------------------------------------------------------------------------- /run.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python -m torch.distributed.launch --nproc_per_node=8 --nnodes=1 --node_rank=0 --master_addr="127.0.0.1" --master_port=1234 train.py configs/ms1mv3_r50 2 | ps -ef | grep "train" | grep -v grep | awk '{print "kill -9 "$2}' | sh 3 | -------------------------------------------------------------------------------- /torch2onnx.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import onnx 3 | import torch 4 | 5 | 6 | def 
convert_onnx(net, path_module, output, opset=11, simplify=False): 7 | assert isinstance(net, torch.nn.Module) 8 | img = np.random.randint(0, 255, size=(112, 112, 3), dtype=np.int32) 9 | img = img.astype(np.float) 10 | img = (img / 255. - 0.5) / 0.5 # torch style norm 11 | img = img.transpose((2, 0, 1)) 12 | img = torch.from_numpy(img).unsqueeze(0).float() 13 | 14 | weight = torch.load(path_module) 15 | net.load_state_dict(weight) 16 | net.eval() 17 | torch.onnx.export(net, 18 | img, 19 | output, 20 | export_params=True, 21 | do_constant_folding=True, 22 | verbose=False, 23 | opset_version=opset, 24 | input_names = ['input'], # the model's input names 25 | output_names = ['output'], # the model's output names 26 | dynamic_axes={'input' : {0 : 'batch_size'}, # variable length axes 27 | 'output' : {0 : 'batch_size'}} 28 | ) 29 | model = onnx.load(output) 30 | graph = model.graph 31 | graph.input[0].type.tensor_type.shape.dim[0].dim_param = 'None' 32 | if simplify: 33 | from onnxsim import simplify 34 | model, check = simplify(model) 35 | assert check, "Simplified ONNX model could not be validated" 36 | onnx.save(model, output) 37 | 38 | 39 | if __name__ == '__main__': 40 | import os 41 | import argparse 42 | from backbones import get_model 43 | 44 | parser = argparse.ArgumentParser(description='ArcFace PyTorch to onnx') 45 | parser.add_argument('input', type=str, help='input backbone.pth file or path') 46 | parser.add_argument('--output', type=str, default=None, help='output onnx path') 47 | parser.add_argument('--network', type=str, default=None, help='backbone network') 48 | parser.add_argument('--simplify', type=bool, default=False, help='onnx simplify') 49 | args = parser.parse_args() 50 | input_file = args.input 51 | if os.path.isdir(input_file): 52 | input_file = os.path.join(input_file, "backbone.pth") 53 | assert os.path.exists(input_file) 54 | model_name = os.path.basename(os.path.dirname(input_file)).lower() 55 | params = model_name.split("_") 56 | if len(params) >= 3 and params[1] in ('arcface', 'cosface'): 57 | if args.network is None: 58 | args.network = params[2] 59 | assert args.network is not None 60 | print(args) 61 | backbone_onnx = get_model(args.network, dropout=0) 62 | 63 | output_path = args.output 64 | if output_path is None: 65 | output_path = os.path.join(os.path.dirname(__file__), 'onnx') 66 | if not os.path.exists(output_path): 67 | os.makedirs(output_path) 68 | assert os.path.isdir(output_path) 69 | output_file = os.path.join(output_path, "%s.onnx" % model_name) 70 | convert_onnx(backbone_onnx, input_file, output_file, simplify=args.simplify) 71 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import logging 3 | import os 4 | import time 5 | import torch 6 | import torch.distributed as dist 7 | import torch.nn.functional as F 8 | import torch.utils.data.distributed 9 | from torch.nn.utils import clip_grad_norm_ 10 | 11 | import losses 12 | from backbones import get_model 13 | from dataset import MXFaceDataset, SyntheticDataset, DataLoaderX, AdaFaceDataset 14 | from partial_fc import PartialFC 15 | from utils.utils_amp import MaxClipGradScaler 16 | from utils.utils_callbacks import CallBackVerification, CallBackLogging, CallBackModelCheckpoint 17 | from utils.utils_config import get_config 18 | from utils.utils_logging import AverageMeter, init_logging 19 | from torch.multiprocessing import set_start_method 20 | 21 | 
def main(args): 22 | cfg = get_config(args.config) 23 | try: 24 | world_size = int(os.environ['WORLD_SIZE']) 25 | rank = int(os.environ['RANK']) 26 | dist.init_process_group('nccl') 27 | except KeyError: 28 | world_size = 1 29 | rank = 0 30 | dist.init_process_group(backend='nccl', init_method="tcp://127.0.0.1:12584", rank=rank, world_size=world_size) 31 | 32 | local_rank = args.local_rank 33 | torch.cuda.set_device(local_rank) 34 | os.makedirs(cfg.output, exist_ok=True) 35 | init_logging(rank, cfg.output) 36 | 37 | if cfg.rec == "synthetic": 38 | print('Using Synthetic dataloader') 39 | train_set = SyntheticDataset(local_rank=local_rank) 40 | elif cfg.loss == 'adaface' : 41 | print('Using AdaFace dataloader') 42 | train_set = AdaFaceDataset(root_dir=cfg.rec, local_rank=local_rank) 43 | else: 44 | print('Using ArcFace/CosFace/SphereFace dataloader') 45 | train_set = MXFaceDataset(root_dir=cfg.rec, local_rank=local_rank) 46 | 47 | train_sampler = torch.utils.data.distributed.DistributedSampler(train_set, shuffle=True) 48 | train_loader = DataLoaderX( 49 | local_rank=local_rank, dataset=train_set, batch_size=cfg.batch_size, 50 | sampler=train_sampler, num_workers=8, pin_memory=True, drop_last=True) 51 | backbone = get_model(cfg.network, dropout=0.0, fp16=cfg.fp16, num_features=cfg.embedding_size).to(local_rank) 52 | 53 | if cfg.resume: 54 | try: 55 | backbone_pth = os.path.join(cfg.output, "backbone.pth") 56 | print(backbone_pth) 57 | backbone.load_state_dict(torch.load(backbone_pth, map_location=torch.device(local_rank))) 58 | if rank == 0: 59 | logging.info("backbone resume successfully!") 60 | except (FileNotFoundError, KeyError, IndexError, RuntimeError): 61 | if rank == 0: 62 | logging.info("resume fail, backbone init successfully!") 63 | 64 | backbone = torch.nn.parallel.DistributedDataParallel( 65 | module=backbone, broadcast_buffers=False, device_ids=[local_rank]) 66 | backbone.train() 67 | margin_softmax = losses.get_loss(cfg.loss).cuda() 68 | module_partial_fc = PartialFC( 69 | rank=rank, local_rank=local_rank, world_size=world_size, resume=cfg.resume, 70 | batch_size=cfg.batch_size, margin_softmax=margin_softmax, num_classes=cfg.num_classes, 71 | sample_rate=cfg.sample_rate, embedding_size=cfg.embedding_size, prefix=cfg.output, loss_type = cfg.loss) 72 | 73 | opt_backbone = torch.optim.SGD( 74 | params=[{'params': backbone.parameters()}], 75 | lr=cfg.lr / 512 * cfg.batch_size * world_size, 76 | momentum=0.9, weight_decay=cfg.weight_decay) 77 | opt_pfc = torch.optim.SGD( 78 | params=[{'params': module_partial_fc.parameters()}], 79 | lr=cfg.lr / 512 * cfg.batch_size * world_size, 80 | momentum=0.9, weight_decay=cfg.weight_decay) 81 | 82 | num_image = len(train_set) 83 | total_batch_size = cfg.batch_size * world_size 84 | cfg.warmup_step = num_image // total_batch_size * cfg.warmup_epoch 85 | cfg.total_step = num_image // total_batch_size * cfg.num_epoch 86 | 87 | def lr_step_func(current_step): 88 | cfg.decay_step = [x * num_image // total_batch_size for x in cfg.decay_epoch] 89 | if current_step < cfg.warmup_step: 90 | return current_step / cfg.warmup_step 91 | else: 92 | return 0.1 ** len([m for m in cfg.decay_step if m <= current_step]) 93 | 94 | scheduler_backbone = torch.optim.lr_scheduler.LambdaLR( 95 | optimizer=opt_backbone, lr_lambda=lr_step_func) 96 | scheduler_pfc = torch.optim.lr_scheduler.LambdaLR( 97 | optimizer=opt_pfc, lr_lambda=lr_step_func) 98 | 99 | for key, value in cfg.items(): 100 | num_space = 25 - len(key) 101 | logging.info(": " + key + " " * num_space + 
str(value)) 102 | 103 | val_target = cfg.val_targets 104 | callback_verification = CallBackVerification(30000, rank, val_target, cfg.rec) 105 | callback_logging = CallBackLogging(50, rank, cfg.total_step, cfg.batch_size, world_size, None) 106 | callback_checkpoint = CallBackModelCheckpoint(rank, cfg.output) 107 | 108 | loss = AverageMeter() 109 | start_epoch = 0 110 | global_step = 0 111 | grad_amp = MaxClipGradScaler(cfg.batch_size, 128 * cfg.batch_size, growth_interval=100) if cfg.fp16 else None 112 | 113 | print('Start training') 114 | for epoch in range(start_epoch, cfg.num_epoch): 115 | train_sampler.set_epoch(epoch) 116 | for step, (img, label) in enumerate(train_loader): 117 | global_step += 1 118 | if cfg.loss == 'adaface': 119 | raw_features = backbone(img) 120 | norms = torch.norm(raw_features, 2, -1, keepdim=True) 121 | features = F.normalize(raw_features) 122 | x_grad, loss_v = module_partial_fc.forward_backward(label, features, opt_pfc, norms) 123 | else: 124 | features = F.normalize(backbone(img)) 125 | x_grad, loss_v = module_partial_fc.forward_backward(label, features, opt_pfc) 126 | if cfg.fp16: 127 | features.backward(grad_amp.scale(x_grad)) 128 | grad_amp.unscale_(opt_backbone) 129 | clip_grad_norm_(backbone.parameters(), max_norm=5, norm_type=2) 130 | grad_amp.step(opt_backbone) 131 | grad_amp.update() 132 | else: 133 | features.backward(x_grad) 134 | clip_grad_norm_(backbone.parameters(), max_norm=5, norm_type=2) 135 | opt_backbone.step() 136 | t3 = time.time() 137 | 138 | opt_pfc.step() 139 | module_partial_fc.update() 140 | opt_backbone.zero_grad() 141 | opt_pfc.zero_grad() 142 | loss.update(loss_v, 1) 143 | callback_logging(global_step, loss, epoch, cfg.fp16, scheduler_backbone.get_last_lr()[0], grad_amp) 144 | callback_verification(global_step, module_partial_fc, backbone) 145 | callback_checkpoint(global_step, backbone, module_partial_fc, frequent = 30000) 146 | scheduler_backbone.step() 147 | scheduler_pfc.step() 148 | dist.destroy_process_group() 149 | 150 | 151 | if __name__ == "__main__": 152 | torch.backends.cudnn.benchmark = True 153 | parser = argparse.ArgumentParser(description='PyTorch ArcFace Training') 154 | parser.add_argument('config', type=str, help='py config file') 155 | parser.add_argument('--local_rank', type=int, default=0, help='local_rank') 156 | main(parser.parse_args()) 157 | -------------------------------------------------------------------------------- /transforms.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from PIL import Image 3 | import cv2 4 | from io import BytesIO 5 | import random 6 | 7 | def transform_JPEGcompression(image, compress_range = (30, 100)): 8 | ''' 9 | Perform random JPEG Compression 10 | ''' 11 | if random.random() < 0.15: 12 | assert compress_range[0] < compress_range[1], "Lower and higher value not accepted: {} vs {}".format(compress_range[0], compress_range[1]) 13 | jpegcompress_value = random.randint(compress_range[0], compress_range[1]) 14 | out = BytesIO() 15 | image.save(out, 'JPEG', quality=jpegcompress_value) 16 | out.seek(0) 17 | rgb_image = Image.open(out) 18 | return rgb_image 19 | else: 20 | return image 21 | 22 | def transform_gaussian_noise(img_pil, mean = 0.0, var = 10.0): 23 | ''' 24 | Perform random gaussian noise 25 | ''' 26 | if random.random() < 0.15: 27 | img = np.array(img_pil) 28 | height, width, channels = img.shape 29 | sigma = var**0.5 30 | gauss = np.random.normal(mean, sigma,(height, width, channels)) 31 | noisy = img + 
gauss 32 | cv2.normalize(noisy, noisy, 0, 255, cv2.NORM_MINMAX, dtype=-1) 33 | noisy = noisy.astype(np.uint8) 34 | return Image.fromarray(noisy) 35 | else: 36 | return img_pil 37 | 38 | def transform_resize(image, resize_range = (32, 112), target_size = 112): 39 | if random.random() < 0.15: 40 | assert resize_range[0] < resize_range[1], "Lower and higher value not accepted: {} vs {}".format(resize_range[0], resize_range[1]) 41 | resize_value = random.randint(resize_range[0], resize_range[1]) 42 | resize_image = image.resize((resize_value, resize_value)) 43 | return resize_image.resize((target_size, target_size)) 44 | else: 45 | return image 46 | 47 | def transform_eraser(image): 48 | if random.random() < 0.15: 49 | mask_range = random.randint(0, 3) 50 | image_array = np.array(image, dtype=np.uint8) 51 | image_array[(7-mask_range)*16:, :, :] = 0 52 | return Image.fromarray(image_array) 53 | else: 54 | return image -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NNDam/adaface-partialfc/7630aa23cc792538cd18b758a3a796977e00fd17/utils/__init__.py -------------------------------------------------------------------------------- /utils/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NNDam/adaface-partialfc/7630aa23cc792538cd18b758a3a796977e00fd17/utils/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /utils/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NNDam/adaface-partialfc/7630aa23cc792538cd18b758a3a796977e00fd17/utils/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /utils/__pycache__/read_info.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NNDam/adaface-partialfc/7630aa23cc792538cd18b758a3a796977e00fd17/utils/__pycache__/read_info.cpython-36.pyc -------------------------------------------------------------------------------- /utils/__pycache__/read_info.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NNDam/adaface-partialfc/7630aa23cc792538cd18b758a3a796977e00fd17/utils/__pycache__/read_info.cpython-38.pyc -------------------------------------------------------------------------------- /utils/__pycache__/utils_amp.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NNDam/adaface-partialfc/7630aa23cc792538cd18b758a3a796977e00fd17/utils/__pycache__/utils_amp.cpython-36.pyc -------------------------------------------------------------------------------- /utils/__pycache__/utils_amp.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NNDam/adaface-partialfc/7630aa23cc792538cd18b758a3a796977e00fd17/utils/__pycache__/utils_amp.cpython-38.pyc -------------------------------------------------------------------------------- /utils/__pycache__/utils_callbacks.cpython-36.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/NNDam/adaface-partialfc/7630aa23cc792538cd18b758a3a796977e00fd17/utils/__pycache__/utils_callbacks.cpython-36.pyc -------------------------------------------------------------------------------- /utils/__pycache__/utils_callbacks.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NNDam/adaface-partialfc/7630aa23cc792538cd18b758a3a796977e00fd17/utils/__pycache__/utils_callbacks.cpython-38.pyc -------------------------------------------------------------------------------- /utils/__pycache__/utils_config.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NNDam/adaface-partialfc/7630aa23cc792538cd18b758a3a796977e00fd17/utils/__pycache__/utils_config.cpython-36.pyc -------------------------------------------------------------------------------- /utils/__pycache__/utils_config.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NNDam/adaface-partialfc/7630aa23cc792538cd18b758a3a796977e00fd17/utils/__pycache__/utils_config.cpython-38.pyc -------------------------------------------------------------------------------- /utils/__pycache__/utils_logging.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NNDam/adaface-partialfc/7630aa23cc792538cd18b758a3a796977e00fd17/utils/__pycache__/utils_logging.cpython-36.pyc -------------------------------------------------------------------------------- /utils/__pycache__/utils_logging.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NNDam/adaface-partialfc/7630aa23cc792538cd18b758a3a796977e00fd17/utils/__pycache__/utils_logging.cpython-38.pyc -------------------------------------------------------------------------------- /utils/cython/README.md: -------------------------------------------------------------------------------- 1 | ## What's this 2 | 3 | This is a Cython implementation that speeds up the mask-rendering process. 4 | 5 | ## Usage 6 | 7 | Just run `python setup.py build_ext -i` to build the shared-object file in place. 8 | 9 | ## Compare 10 | 11 | As an example, for **add_mask_one.py** on my own computer, the cumulative time (cumtime) spent in **face_masker.add_mask_one** drops from 10.885 s to 0.249 s compared with the original method. 12 | 13 | ## Other 14 | 15 | Author: [cbwces](https://github.com/cbwces) 16 | 17 | Mail: sknyqbcbw@gmail.com 18 | -------------------------------------------------------------------------------- /utils/cython/UNKNOWN.egg-info/PKG-INFO: -------------------------------------------------------------------------------- 1 | Metadata-Version: 1.0 2 | Name: UNKNOWN 3 | Version: 0.0.0 4 | Summary: UNKNOWN 5 | Home-page: UNKNOWN 6 | Author: UNKNOWN 7 | Author-email: UNKNOWN 8 | License: UNKNOWN 9 | Description: UNKNOWN 10 | Platform: UNKNOWN 11 | -------------------------------------------------------------------------------- /utils/cython/UNKNOWN.egg-info/SOURCES.txt: -------------------------------------------------------------------------------- 1 | README.md 2 | render.c 3 | setup.py 4 | UNKNOWN.egg-info/PKG-INFO 5 | UNKNOWN.egg-info/SOURCES.txt 6 | UNKNOWN.egg-info/dependency_links.txt 7 | UNKNOWN.egg-info/top_level.txt --------------------------------------------------------------------------------
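The Cython README above only gives the build command, so here is a minimal, hypothetical usage sketch of the compiled `render_cy` function, called the same way `face_masker.py` calls it; the toy arrays below are made-up placeholders for the real PRNet vertices, colors and triangles:

```python
import numpy as np
from render import render_cy  # available after `python setup.py build_ext -i`

# Toy placeholder geometry (real values come from PRNet inside face_masker.py).
h, w = 112, 112
vertices = np.random.rand(500, 3) * 100           # (n_vertices, 3) float64, x/y in pixel coordinates
new_colors = np.random.rand(500, 3)               # (n_vertices, 3) float64 colors in [0, 1]
triangles = np.random.randint(0, 500, (900, 3))   # (n_triangles, 3) vertex indices

# render_cy expects C-contiguous (3, n) double arrays and int64 triangle indices.
face_mask, new_image = render_cy(
    np.ascontiguousarray(vertices.T),
    np.ascontiguousarray(new_colors.T),
    np.ascontiguousarray(triangles.T.astype(np.int64)),
    h, w)
print(face_mask.shape, new_image.shape)  # (112, 112, 1) and (112, 112, 3)
```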
/utils/cython/UNKNOWN.egg-info/dependency_links.txt: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /utils/cython/UNKNOWN.egg-info/top_level.txt: -------------------------------------------------------------------------------- 1 | render 2 | -------------------------------------------------------------------------------- /utils/cython/build/lib.linux-x86_64-3.6/render.cpython-36m-x86_64-linux-gnu.so: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NNDam/adaface-partialfc/7630aa23cc792538cd18b758a3a796977e00fd17/utils/cython/build/lib.linux-x86_64-3.6/render.cpython-36m-x86_64-linux-gnu.so -------------------------------------------------------------------------------- /utils/cython/build/temp.linux-x86_64-3.6/render.o: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NNDam/adaface-partialfc/7630aa23cc792538cd18b758a3a796977e00fd17/utils/cython/build/temp.linux-x86_64-3.6/render.o -------------------------------------------------------------------------------- /utils/cython/dist/UNKNOWN-0.0.0-py3.6-linux-x86_64.egg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NNDam/adaface-partialfc/7630aa23cc792538cd18b758a3a796977e00fd17/utils/cython/dist/UNKNOWN-0.0.0-py3.6-linux-x86_64.egg -------------------------------------------------------------------------------- /utils/cython/render.cpython-36m-x86_64-linux-gnu.so: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NNDam/adaface-partialfc/7630aa23cc792538cd18b758a3a796977e00fd17/utils/cython/render.cpython-36m-x86_64-linux-gnu.so -------------------------------------------------------------------------------- /utils/cython/render.pyx: -------------------------------------------------------------------------------- 1 | ''' 2 | @author: cbwces 3 | @date: 20210419 4 | @contact: sknyqbcbw@gmail.com 5 | ''' 6 | cimport cython 7 | from cython.parallel import prange 8 | import numpy 9 | cimport numpy 10 | from libc.math cimport ceil, floor 11 | 12 | @cython.boundscheck(False) 13 | @cython.wraparound(False) 14 | @cython.nonecheck(False) 15 | cdef int MAX(int a, int b): 16 | if a > b: 17 | b = a 18 | return b 19 | 20 | @cython.boundscheck(False) 21 | @cython.wraparound(False) 22 | @cython.nonecheck(False) 23 | cdef int MIN(int a, int b): 24 | if a < b: 25 | b = a 26 | return b 27 | 28 | @cython.boundscheck(False) 29 | @cython.wraparound(False) 30 | @cython.nonecheck(False) 31 | cdef (double, double) minmax(numpy.ndarray[double, ndim=1, mode='c'] arr): 32 | cdef double min_ = 999999. 33 | cdef double max_ = -999999. 
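# The sentinel values above are overwritten on the first loop iteration; minmax returns the (min, max) of a 1-D array in a single pass and is used by render_texture to find each triangle's bounding box.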
34 | cdef Py_ssize_t i 35 | for i in range(arr.shape[0]): 36 | if arr[i] < min_: 37 | min_ = arr[i] 38 | if arr[i] > max_: 39 | max_ = arr[i] 40 | return min_, max_ 41 | 42 | @cython.boundscheck(False) 43 | @cython.wraparound(False) 44 | @cython.nonecheck(False) 45 | def render_cy(numpy.ndarray[double, ndim=2, mode='c'] vertices, numpy.ndarray[double, ndim=2, mode='c'] new_colors, numpy.ndarray[long, ndim=2, mode='c'] triangles, int h, int w): 46 | cdef Py_ssize_t vertices_shape0 = vertices.shape[1] 47 | cdef numpy.ndarray[double, ndim=2, mode='c'] vis_colors = numpy.ones((1, vertices_shape0)) 48 | cdef numpy.ndarray[double, ndim=3, mode='c'] face_mask = render_texture(vertices, vis_colors, triangles, h, w, 1) 49 | cdef numpy.ndarray[double, ndim=3, mode='c'] new_image = render_texture(vertices, new_colors, triangles, h, w, 3) 50 | return face_mask, new_image 51 | 52 | @cython.boundscheck(False) 53 | @cython.wraparound(False) 54 | @cython.nonecheck(False) 55 | cdef numpy.ndarray[double, ndim=3, mode='c'] render_texture(numpy.ndarray[double, ndim=2, mode='c'] vertices, numpy.ndarray[double, ndim=2, mode='c'] colors, numpy.ndarray[long, ndim=2, mode='c'] triangles, int h, int w, int c = 3): 56 | 57 | # cdef numpy.ndarray[double, ndim=3, mode='c'] image = numpy.empty((h, w, c), dtype=numpy.double) 58 | cdef numpy.ndarray[double, ndim=3, mode='c'] image = numpy.zeros((h, w, c), dtype=numpy.double) 59 | cdef numpy.ndarray[double, ndim=2, mode='c'] depth_buffer = numpy.zeros([h, w], dtype=numpy.double) - 999999. 60 | 61 | cdef Py_ssize_t triangles_size_0 = triangles.shape[0] 62 | cdef Py_ssize_t triangles_size_1 = triangles.shape[1] 63 | cdef Py_ssize_t triangles_size_0_ptr 64 | cdef Py_ssize_t triangles_size_1_ptr 65 | 66 | cdef Py_ssize_t colors_size = colors.shape[0] 67 | cdef Py_ssize_t colors_size_ptr 68 | 69 | cdef numpy.ndarray[double, ndim=1, mode='c'] tri_depth = numpy.empty((triangles_size_1), dtype=numpy.double) 70 | cdef numpy.ndarray[double, ndim=2, mode='c'] tri_tex = numpy.empty((colors_size, triangles_size_1), dtype=numpy.double) 71 | 72 | for triangles_size_1_ptr in prange(triangles_size_1, nogil=True): 73 | tri_depth[triangles_size_1_ptr] = (vertices[2, triangles[0, triangles_size_1_ptr]] + vertices[2, triangles[1, triangles_size_1_ptr]] + vertices[2, triangles[2, triangles_size_1_ptr]]) / 3. 74 | for colors_size_ptr in range(colors_size): 75 | tri_tex[colors_size_ptr, triangles_size_1_ptr] = (colors[colors_size_ptr, triangles[0, triangles_size_1_ptr]] + colors[colors_size_ptr, triangles[1, triangles_size_1_ptr]] + colors[colors_size_ptr, triangles[2, triangles_size_1_ptr]]) / 3. 
76 | 77 | cdef int umin 78 | cdef int vmin 79 | cdef int umax 80 | cdef int vmax 81 | cdef Py_ssize_t u 82 | cdef Py_ssize_t v 83 | cdef double relate_min 84 | cdef double relate_max 85 | cdef numpy.ndarray[long, ndim=1, mode='c'] tri = numpy.empty((triangles_size_0,), dtype=numpy.long) 86 | cdef Py_ssize_t c_channel_ptr 87 | cdef numpy.ndarray[double, ndim=2, mode='c'] vertices_idx_by_tri = numpy.empty((2, triangles_size_0), dtype=numpy.double) 88 | cdef bint ifisPointInTri 89 | 90 | for triangles_size_1_ptr in range(triangles_size_1): 91 | for triangles_size_0_ptr in range(triangles_size_0): 92 | tri[triangles_size_0_ptr] = triangles[triangles_size_0_ptr, triangles_size_1_ptr] 93 | vertices_idx_by_tri[0, triangles_size_0_ptr] = vertices[0, tri[triangles_size_0_ptr]] 94 | vertices_idx_by_tri[1, triangles_size_0_ptr] = vertices[1, tri[triangles_size_0_ptr]] 95 | 96 | relate_min, relate_max = minmax(vertices_idx_by_tri[0]) 97 | 98 | umin = MAX((ceil(relate_min)), 0) 99 | umax = MIN((floor(relate_max)), w-1) 100 | 101 | relate_min, relate_max = minmax(vertices_idx_by_tri[1]) 102 | vmin = MAX((ceil(relate_min)), 0) 103 | vmax = MIN((floor(relate_max)), h-1) 104 | 105 | if umax < umin or vmax < vmin: 106 | continue 107 | for u in range(umin, umax+1): 108 | for v in range(vmin, vmax+1): 109 | if tri_depth[triangles_size_1_ptr] > depth_buffer[v, u]: 110 | ifisPointInTri = isPointInTri(u, v, vertices_idx_by_tri) 111 | if ifisPointInTri: 112 | depth_buffer[v, u] = tri_depth[triangles_size_1_ptr] 113 | for c_channel_ptr in range(c): 114 | image[v, u, c_channel_ptr] = tri_tex[c_channel_ptr, triangles_size_1_ptr] 115 | return image 116 | 117 | @cython.boundscheck(False) 118 | @cython.wraparound(False) 119 | @cython.nonecheck(False) 120 | cdef bint isPointInTri(double point0, double point1, numpy.ndarray[double, ndim=2, mode='c'] tp0): 121 | 122 | cdef double dot00 = 0 123 | cdef double dot01 = 0 124 | cdef double dot02 = 0 125 | cdef double dot11 = 0 126 | cdef double dot12 = 0 127 | 128 | dot00 += (tp0[0, 2]-tp0[0, 0])*(tp0[0, 2]-tp0[0, 0]) 129 | dot00 += (tp0[1, 2]-tp0[1, 0])*(tp0[1, 2]-tp0[1, 0]) 130 | dot01 += (tp0[0, 2]-tp0[0, 0])*(tp0[0, 1]-tp0[0, 0]) 131 | dot01 += (tp0[1, 2]-tp0[1, 0])*(tp0[1, 1]-tp0[1, 0]) 132 | dot02 += (tp0[0, 2]-tp0[0, 0])*(point0-tp0[0, 0]) 133 | dot02 += (tp0[1, 2]-tp0[1, 0])*(point1-tp0[1, 0]) 134 | dot11 += (tp0[0, 1]-tp0[0, 0])*(tp0[0, 1]-tp0[0, 0]) 135 | dot11 += (tp0[1, 1]-tp0[1, 0])*(tp0[1, 1]-tp0[1, 0]) 136 | dot12 += (tp0[0, 1]-tp0[0, 0])*(point0-tp0[0, 0]) 137 | dot12 += (tp0[1, 1]-tp0[1, 0])*(point1-tp0[1, 0]) 138 | 139 | cdef double inverDeno 140 | 141 | if dot00*dot11 - dot01*dot01 == 0: 142 | inverDeno = 0.0 143 | else: 144 | inverDeno = 1.0/(dot00*dot11 - dot01*dot01) 145 | 146 | cdef double u = (dot11*dot02 - dot01*dot12)*inverDeno 147 | cdef double v = (dot00*dot12 - dot01*dot02)*inverDeno 148 | 149 | return (u >= 0) & (v >= 0) & (u + v < 1) 150 | -------------------------------------------------------------------------------- /utils/cython/setup.py: -------------------------------------------------------------------------------- 1 | ''' 2 | @author: cbwces 3 | @date: 20210419 4 | @contact: sknyqbcbw@gmail.com 5 | ''' 6 | from setuptools import setup, Extension 7 | from Cython.Build import cythonize 8 | import numpy 9 | 10 | ext_modules = [ 11 | Extension( 12 | "render", 13 | ["render.pyx"], 14 | extra_compile_args=['-fopenmp'], 15 | extra_link_args=['-fopenmp'], 16 | ) 17 | ] 18 | 19 | setup( 20 | ext_modules=cythonize(ext_modules), 21 | include_dirs=[numpy.get_include()] 22 | ) 23 | -------------------------------------------------------------------------------- /utils/plot.py:
-------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | 3 | import os 4 | from pathlib import Path 5 | 6 | import matplotlib.pyplot as plt 7 | import numpy as np 8 | import pandas as pd 9 | from menpo.visualize.viewmatplotlib import sample_colours_from_colourmap 10 | from prettytable import PrettyTable 11 | from sklearn.metrics import roc_curve, auc 12 | 13 | image_path = "/data/anxiang/IJB_release/IJBC" 14 | files = [ 15 | "./ms1mv3_arcface_r100/ms1mv3_arcface_r100/ijbc.npy" 16 | ] 17 | 18 | 19 | def read_template_pair_list(path): 20 | pairs = pd.read_csv(path, sep=' ', header=None).values 21 | t1 = pairs[:, 0].astype(np.int) 22 | t2 = pairs[:, 1].astype(np.int) 23 | label = pairs[:, 2].astype(np.int) 24 | return t1, t2, label 25 | 26 | 27 | p1, p2, label = read_template_pair_list( 28 | os.path.join('%s/meta' % image_path, 29 | '%s_template_pair_label.txt' % 'ijbc')) 30 | 31 | methods = [] 32 | scores = [] 33 | for file in files: 34 | methods.append(file.split('/')[-2]) 35 | scores.append(np.load(file)) 36 | 37 | methods = np.array(methods) 38 | scores = dict(zip(methods, scores)) 39 | colours = dict( 40 | zip(methods, sample_colours_from_colourmap(methods.shape[0], 'Set2'))) 41 | x_labels = [10 ** -6, 10 ** -5, 10 ** -4, 10 ** -3, 10 ** -2, 10 ** -1] 42 | tpr_fpr_table = PrettyTable(['Methods'] + [str(x) for x in x_labels]) 43 | fig = plt.figure() 44 | for method in methods: 45 | fpr, tpr, _ = roc_curve(label, scores[method]) 46 | roc_auc = auc(fpr, tpr) 47 | fpr = np.flipud(fpr) 48 | tpr = np.flipud(tpr) # select largest tpr at same fpr 49 | plt.plot(fpr, 50 | tpr, 51 | color=colours[method], 52 | lw=1, 53 | label=('[%s (AUC = %0.4f %%)]' % 54 | (method.split('-')[-1], roc_auc * 100))) 55 | tpr_fpr_row = [] 56 | tpr_fpr_row.append("%s-%s" % (method, "IJBC")) 57 | for fpr_iter in np.arange(len(x_labels)): 58 | _, min_index = min( 59 | list(zip(abs(fpr - x_labels[fpr_iter]), range(len(fpr))))) 60 | tpr_fpr_row.append('%.2f' % (tpr[min_index] * 100)) 61 | tpr_fpr_table.add_row(tpr_fpr_row) 62 | plt.xlim([10 ** -6, 0.1]) 63 | plt.ylim([0.3, 1.0]) 64 | plt.grid(linestyle='--', linewidth=1) 65 | plt.xticks(x_labels) 66 | plt.yticks(np.linspace(0.3, 1.0, 8, endpoint=True)) 67 | plt.xscale('log') 68 | plt.xlabel('False Positive Rate') 69 | plt.ylabel('True Positive Rate') 70 | plt.title('ROC on IJB') 71 | plt.legend(loc="lower right") 72 | print(tpr_fpr_table) 73 | -------------------------------------------------------------------------------- /utils/read_info.py: -------------------------------------------------------------------------------- 1 | """ 2 | @author: Yinglu Liu, Jun Wang 3 | @date: 20201012 4 | @contact: jun21wangustc@gmail.com 5 | """ 6 | 7 | import numpy as np 8 | 9 | def read_landmark_106_file(filepath): 10 | map = [[1,2],[3,4],[5,6],7,9,11,[12,13],14,16,18,[19,20],21,23,25,[26,27],[28,29],[30,31],33,34,35,36,37,42,43,44,45,46,51,52,53,54,58,59,60,61,62,66,67,69,70,71,73,75,76,78,79,80,82,84,85,86,87,88,89,90,91,92,93,94,95,96,97,98,99,100,101,102,103] 11 | line = open(filepath).readline().strip() 12 | pts1 = line.split(' ')[58:-1] 13 | assert(len(pts1) == 106*2) 14 | pts1 = np.array(pts1, dtype = np.float) 15 | pts1 = pts1.reshape((106, 2)) 16 | pts = np.zeros((68,2)) # map 106 to 68 17 | for ii in range(len(map)): 18 | if isinstance(map[ii],list): 19 | pts[ii] = np.mean(pts1[map[ii]], axis=0) 20 | else: 21 | pts[ii] = pts1[map[ii]] 22 | return pts 23 | 24 | def read_landmark_106_array(face_lms): 25 | map = 
[[1,2],[3,4],[5,6],7,9,11,[12,13],14,16,18,[19,20],21,23,25,[26,27],[28,29],[30,31],33,34,35,36,37,42,43,44,45,46,51,52,53,54,58,59,60,61,62,66,67,69,70,71,73,75,76,78,79,80,82,84,85,86,87,88,89,90,91,92,93,94,95,96,97,98,99,100,101,102,103] 26 | pts1 = np.array(face_lms, dtype = np.float) 27 | pts1 = pts1.reshape((106, 2)) 28 | pts = np.zeros((68,2)) # map 106 to 68 29 | for ii in range(len(map)): 30 | if isinstance(map[ii],list): 31 | pts[ii] = np.mean(pts1[map[ii]], axis=0) 32 | else: 33 | pts[ii] = pts1[map[ii]] 34 | return pts 35 | 36 | def read_landmark_106(filepath): 37 | map = [[1,2],[3,4],[5,6],7,9,11,[12,13],14,16,18,[19,20],21,23,25,[26,27],[28,29],[30,31],33,34,35,36,37,42,43,44,45,46,51,52,53,54,58,59,60,61,62,66,67,69,70,71,73,75,76,78,79,80,82,84,85,86,87,88,89,90,91,92,93,94,95,96,97,98,99,100,101,102,103] 38 | lines = open(filepath).readlines() # load landmarks 39 | pts1 = [_.strip().split() for _ in lines[1:107]] 40 | pts1 = np.array(pts1, dtype = np.float) 41 | pts = np.zeros((68,2)) # map 106 to 68 42 | for ii in range(len(map)): 43 | if isinstance(map[ii],list): 44 | pts[ii] = np.mean(pts1[map[ii]], axis=0) 45 | else: 46 | pts[ii] = pts1[map[ii]] 47 | return pts 48 | 49 | def read_bbox(filepath): 50 | lines = open(filepath).readlines() 51 | bbox = lines[0].strip().split() 52 | bbox = [int(float(_)) for _ in bbox] 53 | return np.array(bbox) 54 | 55 | -------------------------------------------------------------------------------- /utils/render.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Author: YadiraF 3 | Mail: fengyao@sjtu.edu.cn 4 | ''' 5 | import numpy as np 6 | 7 | def isPointInTri(point, tri_points): 8 | ''' Judge whether the point is in the triangle 9 | Method: 10 | http://blackpawn.com/texts/pointinpoly/ 11 | Args: 12 | point: [u, v] or [x, y] 13 | tri_points: three vertices(2d points) of a triangle. 2 coords x 3 vertices 14 | Returns: 15 | bool: true for in triangle 16 | ''' 17 | tp = tri_points 18 | 19 | # vectors 20 | v0 = tp[:,2] - tp[:,0] 21 | v1 = tp[:,1] - tp[:,0] 22 | v2 = point - tp[:,0] 23 | 24 | # dot products 25 | dot00 = np.dot(v0.T, v0) 26 | dot01 = np.dot(v0.T, v1) 27 | dot02 = np.dot(v0.T, v2) 28 | dot11 = np.dot(v1.T, v1) 29 | dot12 = np.dot(v1.T, v2) 30 | 31 | # barycentric coordinates 32 | if dot00*dot11 - dot01*dot01 == 0: 33 | inverDeno = 0 34 | else: 35 | inverDeno = 1/(dot00*dot11 - dot01*dot01) 36 | 37 | u = (dot11*dot02 - dot01*dot12)*inverDeno 38 | v = (dot00*dot12 - dot01*dot02)*inverDeno 39 | 40 | # check if point in triangle 41 | return (u >= 0) & (v >= 0) & (u + v < 1) 42 | 43 | def get_point_weight(point, tri_points): 44 | ''' Get the weights of the position 45 | Methods: https://gamedev.stackexchange.com/questions/23743/whats-the-most-efficient-way-to-find-barycentric-coordinates 46 | -m1.compute the area of the triangles formed by embedding the point P inside the triangle 47 | -m2.Christer Ericson's book "Real-Time Collision Detection". faster, so I used this. 48 | Args: 49 | point: [u, v] or [x, y] 50 | tri_points: three vertices(2d points) of a triangle. 
2 coords x 3 vertices 51 | Returns: 52 | w0: weight of v0 53 | w1: weight of v1 54 | w2: weight of v2 55 | ''' 56 | tp = tri_points 57 | # vectors 58 | v0 = tp[:,2] - tp[:,0] 59 | v1 = tp[:,1] - tp[:,0] 60 | v2 = point - tp[:,0] 61 | 62 | # dot products 63 | dot00 = np.dot(v0.T, v0) 64 | dot01 = np.dot(v0.T, v1) 65 | dot02 = np.dot(v0.T, v2) 66 | dot11 = np.dot(v1.T, v1) 67 | dot12 = np.dot(v1.T, v2) 68 | 69 | # barycentric coordinates 70 | if dot00*dot11 - dot01*dot01 == 0: 71 | inverDeno = 0 72 | else: 73 | inverDeno = 1/(dot00*dot11 - dot01*dot01) 74 | 75 | u = (dot11*dot02 - dot01*dot12)*inverDeno 76 | v = (dot00*dot12 - dot01*dot02)*inverDeno 77 | 78 | w0 = 1 - u - v 79 | w1 = v 80 | w2 = u 81 | 82 | return w0, w1, w2 83 | 84 | 85 | def render_texture(vertices, colors, triangles, h, w, c = 3): 86 | ''' render mesh by z buffer 87 | Args: 88 | vertices: 3 x nver 89 | colors: 3 x nver 90 | triangles: 3 x ntri 91 | h: height 92 | w: width 93 | ''' 94 | # initial 95 | image = np.zeros((h, w, c)) 96 | 97 | depth_buffer = np.zeros([h, w]) - 999999. 98 | # triangle depth: approximate the depth to the average value of z in each vertex(v0, v1, v2), since the vertices are close to each other 99 | tri_depth = (vertices[2, triangles[0,:]] + vertices[2,triangles[1,:]] + vertices[2, triangles[2,:]])/3. 100 | tri_tex = (colors[:, triangles[0,:]] + colors[:,triangles[1,:]] + colors[:, triangles[2,:]])/3. 101 | 102 | for i in range(triangles.shape[1]): 103 | tri = triangles[:, i] # 3 vertex indices 104 | 105 | # the inner bounding box 106 | umin = max(int(np.ceil(np.min(vertices[0,tri]))), 0) 107 | umax = min(int(np.floor(np.max(vertices[0,tri]))), w-1) 108 | 109 | vmin = max(int(np.ceil(np.min(vertices[1,tri]))), 0) 110 | vmax = min(int(np.floor(np.max(vertices[1,tri]))), h-1) 111 | 112 | if umax < umin or vmax < vmin: 113 | continue 114 | 115 | for u in range(umin, umax+1): 116 | for v in range(vmin, vmax+1): 117 | if tri_depth[i] > depth_buffer[v, u] and isPointInTri([u,v], vertices[:2, tri]): 118 | depth_buffer[v, u] = tri_depth[i] 119 | image[v, u, :] = tri_tex[:, i] 120 | return image 121 | 122 | 123 | def map_texture(src_image, src_vertices, dst_vertices, dst_triangle_buffer, triangles, h, w, c = 3, mapping_type = 'bilinear'): 124 | ''' 125 | Args: 126 | triangles: 3 x ntri 127 | 128 | # src 129 | src_image: height x width x nchannels 130 | src_vertices: 3 x nver 131 | 132 | # dst 133 | dst_vertices: 3 x nver 134 | dst_triangle_buffer: height x width. the triangle index of each pixel in dst image 135 | 136 | Returns: 137 | dst_image: height x width x nchannels 138 | 139 | ''' 140 | [sh, sw, sc] = src_image.shape 141 | dst_image = np.zeros((h, w, c)) 142 | for y in range(h): 143 | for x in range(w): 144 | tri_ind = dst_triangle_buffer[y,x] 145 | if tri_ind < 0: # no tri in dst image 146 | continue 147 | #if src_triangles_vis[tri_ind]: # the corresponding triangle in src image is invisible 148 | # continue 149 | 150 | # then. For this triangle index, map corresponding pixels(in triangles) in src image to dst image 151 | # Two Methods: 152 | # M1. Calculate the corresponding affine matrix from src triangle to dst triangle. Then find the corresponding src position of this dst pixel. 153 | # -- ToDo 154 | # M2. Calculate the relative position of three vertices in dst triangle, then find the corresponding src position relative to three src vertices.
155 | tri = triangles[:, tri_ind] 156 | # dst weight, here directly use the center to approximate because the tri is small 157 | # if tri_ind < 366: 158 | # print tri_ind 159 | w0, w1, w2 = get_point_weight([x, y], dst_vertices[:2, tri]) 160 | # else: 161 | # w0 = w1 = w2 = 1./3 162 | # src 163 | src_texel = w0*src_vertices[:2, tri[0]] + w1*src_vertices[:2, tri[1]] + w2*src_vertices[:2, tri[2]] # 164 | # 165 | if src_texel[0] < 0 or src_texel[0]> sw-1 or src_texel[1]<0 or src_texel[1] > sh-1: 166 | dst_image[y, x, :] = 0 167 | continue 168 | # As the coordinates of the transformed pixel in the image will most likely not lie on a texel, we have to choose how to 169 | # calculate the pixel colors depending on the next texels 170 | # there are three different texture interpolation methods: area, bilinear and nearest neighbour 171 | # print y, x, src_texel 172 | # nearest neighbour 173 | if mapping_type == 'nearest': 174 | dst_image[y, x, :] = src_image[int(round(src_texel[1])), int(round(src_texel[0])), :] 175 | # bilinear 176 | elif mapping_type == 'bilinear': 177 | # next 4 pixels 178 | ul = src_image[int(np.floor(src_texel[1])), int(np.floor(src_texel[0])), :] 179 | ur = src_image[int(np.floor(src_texel[1])), int(np.ceil(src_texel[0])), :] 180 | dl = src_image[int(np.ceil(src_texel[1])), int(np.floor(src_texel[0])), :] 181 | dr = src_image[int(np.ceil(src_texel[1])), int(np.ceil(src_texel[0])), :] 182 | 183 | yd = src_texel[1] - np.floor(src_texel[1]) 184 | xd = src_texel[0] - np.floor(src_texel[0]) 185 | dst_image[y, x, :] = ul*(1-xd)*(1-yd) + ur*xd*(1-yd) + dl*(1-xd)*yd + dr*xd*yd 186 | 187 | return dst_image 188 | 189 | 190 | def get_depth_buffer(vertices, triangles, h, w): 191 | ''' 192 | Args: 193 | vertices: 3 x nver 194 | triangles: 3 x ntri 195 | h: height 196 | w: width 197 | Returns: 198 | depth_buffer: height x width 199 | ToDo: 200 | whether to add x, y by 0.5? the center of the pixel? 201 | m3. like somewhere is wrong 202 | # Each triangle has 3 vertices & Each vertex has 3 coordinates x, y, z. 203 | # Here, the bigger the z, the fronter the point. 204 | ''' 205 | # initial 206 | depth_buffer = np.zeros([h, w]) - 999999. #+ np.min(vertices[2,:]) - 999999. # set the initial z to the farest position 207 | 208 | ## calculate the depth(z) of each triangle 209 | #-m1. z = the center of shpere(through 3 vertices) 210 | #center3d = (vertices[:, triangles[0,:]] + vertices[:,triangles[1,:]] + vertices[:, triangles[2,:]])/3. 211 | #tri_depth = np.sum(center3d**2, axis = 0) 212 | #-m2. z = the center of z(v0, v1, v2) 213 | tri_depth = (vertices[2, triangles[0,:]] + vertices[2,triangles[1,:]] + vertices[2, triangles[2,:]])/3. 214 | 215 | for i in range(triangles.shape[1]): 216 | tri = triangles[:, i] # 3 vertex indices 217 | 218 | # the inner bounding box 219 | umin = max(int(np.ceil(np.min(vertices[0,tri]))), 0) 220 | umax = min(int(np.floor(np.max(vertices[0,tri]))), w-1) 221 | 222 | vmin = max(int(np.ceil(np.min(vertices[1,tri]))), 0) 223 | vmax = min(int(np.floor(np.max(vertices[1,tri]))), h-1) 224 | 225 | if umax < umin or vmax < vmin: 226 | continue 227 | 228 | for u in range(umin, umax+1): 229 | for v in range(vmin, vmax+1): 230 | 231 | 232 | 233 | if tri_depth[i] > depth_buffer[v, u]: # and is_pointIntri([u,v], vertices[:2, tri]): 234 | depth_buffer[v, u] = tri_depth[i] 235 | 236 | return depth_buffer 237 | 238 | 239 | def get_triangle_buffer(vertices, triangles, h, w): 240 | ''' 241 | Args: 242 | vertices: 3 x nver 243 | triangles: 3 x ntri 244 | h: height 245 | w: width 246 | Returns: 247 | depth_buffer: height x width 248 | ToDo: 249 | whether to add x, y by 0.5? the center of the pixel? 250 | m3.
like somewhere is wrong 251 | # Each triangle has 3 vertices & Each vertex has 3 coordinates x, y, z. 252 | # Here, the bigger the z, the fronter the point. 253 | ''' 254 | # initial 255 | depth_buffer = np.zeros([h, w]) - 999999. #+ np.min(vertices[2,:]) - 999999. # set the initial z to the farest position 256 | triangle_buffer = np.zeros_like(depth_buffer, dtype = np.int32) - 1 # if -1, the pixel has no triangle correspondance 257 | 258 | ## calculate the depth(z) of each triangle 259 | #-m1. z = the center of shpere(through 3 vertices) 260 | #center3d = (vertices[:, triangles[0,:]] + vertices[:,triangles[1,:]] + vertices[:, triangles[2,:]])/3. 261 | #tri_depth = np.sum(center3d**2, axis = 0) 262 | #-m2. z = the center of z(v0, v1, v2) 263 | tri_depth = (vertices[2, triangles[0,:]] + vertices[2,triangles[1,:]] + vertices[2, triangles[2,:]])/3. 264 | 265 | for i in range(triangles.shape[1]): 266 | tri = triangles[:, i] # 3 vertex indices 267 | 268 | # the inner bounding box 269 | umin = max(int(np.ceil(np.min(vertices[0,tri]))), 0) 270 | umax = min(int(np.floor(np.max(vertices[0,tri]))), w-1) 271 | 272 | vmin = max(int(np.ceil(np.min(vertices[1,tri]))), 0) 273 | vmax = min(int(np.floor(np.max(vertices[1,tri]))), h-1) 274 | 275 | if umax < umin or vmax < vmin: 276 | continue 277 | 278 | for u in range(umin, umax+1): 279 | for v in range(vmin, vmax+1): 280 | 281 | 282 | 283 | if tri_depth[i] > depth_buffer[v, u] and isPointInTri([u,v], vertices[:2, tri]): 284 | depth_buffer[v, u] = tri_depth[i] 285 | triangle_buffer[v, u] = i 286 | 287 | return triangle_buffer 288 | 289 | 290 | def vis_of_vertices(vertices, triangles, h, w, depth_buffer = None): 291 | ''' 292 | Args: 293 | vertices: 3 x nver 294 | triangles: 3 x ntri 295 | depth_buffer: height x width 296 | Returns: 297 | vertices_vis: nver. the visibility of each vertex 298 | ''' 299 | if depth_buffer is None: 300 | depth_buffer = get_depth_buffer(vertices, triangles, h, w) 301 | 302 | vertices_vis = np.zeros(vertices.shape[1], dtype = bool) 303 | 304 | depth_tmp = np.zeros_like(depth_buffer) - 99999 305 | for i in range(vertices.shape[1]): 306 | vertex = vertices[:, i] 307 | 308 | if np.floor(vertex[0]) < 0 or np.ceil(vertex[0]) > w-1 or np.floor(vertex[1]) < 0 or np.ceil(vertex[1]) > h-1: 309 | continue 310 | 311 | # bilinear interp 312 | # ul = depth_buffer[int(np.floor(vertex[1])), int(np.floor(vertex[0]))] 313 | # ur = depth_buffer[int(np.floor(vertex[1])), int(np.ceil(vertex[0]))] 314 | # dl = depth_buffer[int(np.ceil(vertex[1])), int(np.floor(vertex[0]))] 315 | # dr = depth_buffer[int(np.ceil(vertex[1])), int(np.ceil(vertex[0]))] 316 | 317 | # yd = vertex[1] - np.floor(vertex[1]) 318 | # xd = vertex[0] - np.floor(vertex[0]) 319 | 320 | # vertex_depth = ul*(1-xd)*(1-yd) + ur*xd*(1-yd) + dl*(1-xd)*yd + dr*xd*yd 321 | 322 | # nearest 323 | px = int(np.round(vertex[0])) 324 | py = int(np.round(vertex[1])) 325 | 326 | # if (vertex[2] > depth_buffer[ul[0], ul[1]]) & (vertex[2] > depth_buffer[ur[0], ur[1]]) & (vertex[2] > depth_buffer[dl[0], dl[1]]) & (vertex[2] > depth_buffer[dr[0], dr[1]]): 327 | if vertex[2] < depth_tmp[py, px]: 328 | continue 329 | 330 | # if vertex[2] > depth_buffer[py, px]: 331 | # vertices_vis[i] = True 332 | # depth_tmp[py, px] = vertex[2] 333 | # elif np.abs(vertex[2] - depth_buffer[py, px]) < 1: 334 | # vertices_vis[i] = True 335 | 336 | threshold = 2 # need to be optimized.
337 | if np.abs(vertex[2] - depth_buffer[py, px]) < threshold: 338 | # if np.abs(vertex[2] - vertex_depth) < threshold: 339 | vertices_vis[i] = True 340 | depth_tmp[py, px] = vertex[2] 341 | 342 | return vertices_vis 343 | 344 | -------------------------------------------------------------------------------- /utils/utils_amp.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List 2 | 3 | import torch 4 | 5 | if torch.__version__ < '1.9': 6 | Iterable = torch._six.container_abcs.Iterable 7 | else: 8 | import collections 9 | 10 | Iterable = collections.abc.Iterable 11 | from torch.cuda.amp import GradScaler 12 | 13 | 14 | class _MultiDeviceReplicator(object): 15 | """ 16 | Lazily serves copies of a tensor to requested devices. Copies are cached per-device. 17 | """ 18 | 19 | def __init__(self, master_tensor: torch.Tensor) -> None: 20 | assert master_tensor.is_cuda 21 | self.master = master_tensor 22 | self._per_device_tensors: Dict[torch.device, torch.Tensor] = {} 23 | 24 | def get(self, device) -> torch.Tensor: 25 | retval = self._per_device_tensors.get(device, None) 26 | if retval is None: 27 | retval = self.master.to(device=device, non_blocking=True, copy=True) 28 | self._per_device_tensors[device] = retval 29 | return retval 30 | 31 | 32 | class MaxClipGradScaler(GradScaler): 33 | def __init__(self, init_scale, max_scale: float, growth_interval=100): 34 | GradScaler.__init__(self, init_scale=init_scale, growth_interval=growth_interval) 35 | self.max_scale = max_scale 36 | 37 | def scale_clip(self): 38 | if self.get_scale() == self.max_scale: 39 | self.set_growth_factor(1) 40 | elif self.get_scale() < self.max_scale: 41 | self.set_growth_factor(2) 42 | elif self.get_scale() > self.max_scale: 43 | self._scale.fill_(self.max_scale) 44 | self.set_growth_factor(1) 45 | 46 | def scale(self, outputs): 47 | """ 48 | Multiplies ('scales') a tensor or list of tensors by the scale factor. 49 | 50 | Returns scaled outputs. If this instance of :class:`GradScaler` is not enabled, outputs are returned 51 | unmodified. 52 | 53 | Arguments: 54 | outputs (Tensor or iterable of Tensors): Outputs to scale. 55 | """ 56 | if not self._enabled: 57 | return outputs 58 | self.scale_clip() 59 | # Short-circuit for the common case. 60 | if isinstance(outputs, torch.Tensor): 61 | assert outputs.is_cuda 62 | if self._scale is None: 63 | self._lazy_init_scale_growth_tracker(outputs.device) 64 | assert self._scale is not None 65 | return outputs * self._scale.to(device=outputs.device, non_blocking=True) 66 | 67 | # Invoke the more complex machinery only if we're treating multiple outputs. 
68 | stash: List[_MultiDeviceReplicator] = [] # holds a reference that can be overwritten by apply_scale 69 | 70 | def apply_scale(val): 71 | if isinstance(val, torch.Tensor): 72 | assert val.is_cuda 73 | if len(stash) == 0: 74 | if self._scale is None: 75 | self._lazy_init_scale_growth_tracker(val.device) 76 | assert self._scale is not None 77 | stash.append(_MultiDeviceReplicator(self._scale)) 78 | return val * stash[0].get(val.device) 79 | elif isinstance(val, Iterable): 80 | iterable = map(apply_scale, val) 81 | if isinstance(val, list) or isinstance(val, tuple): 82 | return type(val)(iterable) 83 | else: 84 | return iterable 85 | else: 86 | raise ValueError("outputs must be a Tensor or an iterable of Tensors") 87 | 88 | return apply_scale(outputs) 89 | -------------------------------------------------------------------------------- /utils/utils_callbacks.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import time 4 | from typing import List 5 | 6 | import torch 7 | 8 | from eval import verification 9 | from utils.utils_logging import AverageMeter 10 | 11 | 12 | class CallBackVerification(object): 13 | def __init__(self, frequent, rank, val_targets, rec_prefix, image_size=(112, 112)): 14 | self.frequent: int = frequent 15 | self.rank: int = rank 16 | self.highest_acc: float = 0.0 17 | self.highest_acc_list: List[float] = [0.0] * len(val_targets) 18 | self.ver_list: List[object] = [] 19 | self.ver_name_list: List[str] = [] 20 | self.current_highest = 0 21 | if self.rank is 0: 22 | self.init_dataset(val_targets=val_targets, data_dir=rec_prefix, image_size=image_size) 23 | 24 | def ver_test(self, module_partial_fc, backbone: torch.nn.Module, global_step: int): 25 | results = [] 26 | current_score = 0 27 | for i in range(len(self.ver_list)): 28 | acc1, std1, acc2, std2, xnorm, embeddings_list = verification.test( 29 | self.ver_list[i], backbone, 10, 10) 30 | logging.info('[%s][%d]XNorm: %f' % (self.ver_name_list[i], global_step, xnorm)) 31 | logging.info('[%s][%d]Accuracy-Flip: %1.5f+-%1.5f' % (self.ver_name_list[i], global_step, acc2, std2)) 32 | if acc2 > self.highest_acc_list[i]: 33 | self.highest_acc_list[i] = acc2 34 | logging.info( 35 | '[%s][%d]Accuracy-Highest: %1.5f' % (self.ver_name_list[i], global_step, self.highest_acc_list[i])) 36 | results.append(acc2) 37 | current_score += acc2 38 | logging.info( 39 | '[+][%d]Score / Score-Highest: %2.5f / %2.5f' % (global_step, current_score, self.current_highest)) 40 | if self.current_highest <= current_score: 41 | path_save = 'tmp/backbone_{}.pth'.format(current_score) 42 | if not os.path.exists('tmp'): 43 | os.mkdir('tmp') 44 | torch.save(backbone.module.state_dict(), path_save) 45 | print('Saved as best checkpoint to', path_save) 46 | # if global_step > 100 and module_partial_fc is not None: 47 | # module_partial_fc.save_params(folder = "tmp") 48 | self.current_highest = current_score 49 | 50 | def init_dataset(self, val_targets, data_dir, image_size): 51 | for name in val_targets: 52 | path = os.path.join(data_dir, name + ".bin") 53 | if os.path.exists(path): 54 | data_set = verification.load_bin(path, image_size) 55 | self.ver_list.append(data_set) 56 | self.ver_name_list.append(name) 57 | 58 | def __call__(self, num_update, module_partial_fc, backbone: torch.nn.Module): 59 | if self.rank is 0 and num_update > 0 and num_update % self.frequent == 0: 60 | backbone.eval() 61 | self.ver_test(module_partial_fc, backbone, num_update) 62 | backbone.train() 63 | 64 | 65 
| class CallBackLogging(object): 66 | def __init__(self, frequent, rank, total_step, batch_size, world_size, writer=None): 67 | self.frequent: int = frequent 68 | self.rank: int = rank 69 | self.time_start = time.time() 70 | self.total_step: int = total_step 71 | self.batch_size: int = batch_size 72 | self.world_size: int = world_size 73 | self.writer = writer 74 | 75 | self.init = False 76 | self.tic = 0 77 | 78 | def __call__(self, 79 | global_step: int, 80 | loss: AverageMeter, 81 | epoch: int, 82 | fp16: bool, 83 | learning_rate: float, 84 | grad_scaler: torch.cuda.amp.GradScaler): 85 | if self.rank == 0 and global_step > 0 and global_step % self.frequent == 0: 86 | if self.init: 87 | try: 88 | speed: float = self.frequent * self.batch_size / (time.time() - self.tic) 89 | speed_total = speed * self.world_size 90 | except ZeroDivisionError: 91 | speed_total = float('inf') 92 | 93 | time_now = (time.time() - self.time_start) / 3600 94 | time_total = time_now / ((global_step + 1) / self.total_step) 95 | time_for_end = time_total - time_now 96 | if self.writer is not None: 97 | self.writer.add_scalar('time_for_end', time_for_end, global_step) 98 | self.writer.add_scalar('learning_rate', learning_rate, global_step) 99 | self.writer.add_scalar('loss', loss.avg, global_step) 100 | if fp16: 101 | msg = "Speed %.2f samples/sec Loss %.4f LearningRate %.6f Epoch: %d Global Step: %d " \ 102 | "Fp16 Grad Scale: %2.f Required: %1.f hours" % ( 103 | speed_total, loss.avg, learning_rate, epoch, global_step, 104 | grad_scaler.get_scale(), time_for_end 105 | ) 106 | else: 107 | msg = "Speed %.2f samples/sec Loss %.4f LearningRate %.4f Epoch: %d Global Step: %d " \ 108 | "Required: %1.f hours" % ( 109 | speed_total, loss.avg, learning_rate, epoch, global_step, time_for_end 110 | ) 111 | logging.info(msg) 112 | loss.reset() 113 | self.tic = time.time() 114 | else: 115 | self.init = True 116 | self.tic = time.time() 117 | 118 | 119 | class CallBackModelCheckpoint(object): 120 | def __init__(self, rank, output="./"): 121 | self.rank: int = rank 122 | self.output: str = output 123 | 124 | def __call__(self, global_step, backbone, partial_fc, frequent = None): 125 | if frequent is not None: 126 | if global_step > 100 and self.rank == 0 and global_step % frequent == 0: 127 | path_module = os.path.join(self.output, "backbone.pth") 128 | torch.save(backbone.module.state_dict(), path_module) 129 | logging.info("Pytorch Model Saved in '{}'".format(path_module)) 130 | 131 | if global_step > 100 and partial_fc is not None and global_step % frequent == 0: 132 | partial_fc.save_params() 133 | else: 134 | if global_step > 100 and self.rank == 0: 135 | path_module = os.path.join(self.output, "backbone.pth") 136 | torch.save(backbone.module.state_dict(), path_module) 137 | logging.info("Pytorch Model Saved in '{}'".format(path_module)) 138 | 139 | if global_step > 100 and partial_fc is not None: 140 | partial_fc.save_params() 141 | -------------------------------------------------------------------------------- /utils/utils_config.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import os.path as osp 3 | 4 | 5 | def get_config(config_file): 6 | assert config_file.startswith('configs/'), 'config file setting must start with configs/' 7 | temp_config_name = osp.basename(config_file) 8 | temp_module_name = osp.splitext(temp_config_name)[0] 9 | config = importlib.import_module("configs.base") 10 | cfg = config.config 11 | config = 
importlib.import_module("configs.%s" % temp_module_name) 12 | job_cfg = config.config 13 | cfg.update(job_cfg) 14 | if cfg.output is None: 15 | cfg.output = osp.join('work_dirs', temp_module_name) 16 | return cfg -------------------------------------------------------------------------------- /utils/utils_logging.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import sys 4 | 5 | 6 | class AverageMeter(object): 7 | """Computes and stores the average and current value 8 | """ 9 | 10 | def __init__(self): 11 | self.val = None 12 | self.avg = None 13 | self.sum = None 14 | self.count = None 15 | self.reset() 16 | 17 | def reset(self): 18 | self.val = 0 19 | self.avg = 0 20 | self.sum = 0 21 | self.count = 0 22 | 23 | def update(self, val, n=1): 24 | self.val = val 25 | self.sum += val * n 26 | self.count += n 27 | self.avg = self.sum / self.count 28 | 29 | 30 | def init_logging(rank, models_root): 31 | if rank == 0: 32 | log_root = logging.getLogger() 33 | log_root.setLevel(logging.INFO) 34 | formatter = logging.Formatter("Training: %(asctime)s-%(message)s") 35 | handler_file = logging.FileHandler(os.path.join(models_root, "training.log")) 36 | handler_stream = logging.StreamHandler(sys.stdout) 37 | handler_file.setFormatter(formatter) 38 | handler_stream.setFormatter(formatter) 39 | log_root.addHandler(handler_file) 40 | log_root.addHandler(handler_stream) 41 | log_root.info('rank_id: %d' % rank) 42 | -------------------------------------------------------------------------------- /utils/utils_os.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NNDam/adaface-partialfc/7630aa23cc792538cd18b758a3a796977e00fd17/utils/utils_os.py --------------------------------------------------------------------------------
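Closing note: the sketch below is a minimal, hypothetical view of how the utilities above (`get_config`, `init_logging`, `AverageMeter` and the callbacks) are typically wired together; the real glue code lives in `train.py`, which is not reproduced here. The config attributes used (`output`, `rec`, `val_targets`, `num_image`, `num_epoch`, `batch_size`, `fp16`) are assumed to follow the insightface-style `configs/base.py`, so treat the names and values as placeholders rather than the project's exact API.

```python
# Hypothetical glue code: a compressed view of what train.py is expected to do.
import os

from utils.utils_config import get_config
from utils.utils_logging import init_logging, AverageMeter
from utils.utils_callbacks import (CallBackLogging, CallBackVerification,
                                   CallBackModelCheckpoint)

rank = int(os.environ.get("RANK", 0))
world_size = int(os.environ.get("WORLD_SIZE", 1))

cfg = get_config("configs/ms1mv3_r50-adaface.py")  # path must start with configs/
os.makedirs(cfg.output, exist_ok=True)
init_logging(rank, cfg.output)  # rank 0 logs to <output>/training.log and stdout

total_step = cfg.num_image // (cfg.batch_size * world_size) * cfg.num_epoch
callback_logging = CallBackLogging(50, rank, total_step, cfg.batch_size, world_size)
callback_verification = CallBackVerification(2000, rank, cfg.val_targets, cfg.rec)
callback_checkpoint = CallBackModelCheckpoint(rank, cfg.output)

loss_meter = AverageMeter()
# Inside the training loop, once per optimizer step (backbone is the DDP-wrapped
# model; module_partial_fc, grad_scaler and lr come from the PartialFC/AMP setup):
#   loss_meter.update(loss.item(), 1)
#   callback_logging(global_step, loss_meter, epoch, cfg.fp16, lr, grad_scaler)
#   callback_verification(global_step, module_partial_fc, backbone)
#   callback_checkpoint(global_step, backbone, module_partial_fc)
```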