├── FocusFace_Overview.png ├── metrics.py ├── main.py ├── LICENSE ├── .gitignore ├── model.py ├── README.md ├── loader.py ├── iresnet.py └── trainer.py /FocusFace_Overview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NetoPedro/FocusFace/HEAD/FocusFace_Overview.png -------------------------------------------------------------------------------- /metrics.py: -------------------------------------------------------------------------------- 1 | from pyeer.eer_info import get_eer_stats 2 | from pyeer.report import generate_eer_report, export_error_rates 3 | from pyeer.plot import plot_eer_stats 4 | 5 | 6 | 7 | def calculate_metrics(gen_scores,fake_scores,epoch): 8 | metrics = get_eer_stats(gen_scores, fake_scores) 9 | generate_eer_report([metrics], ['A'], 'pyeer_report_'+str(epoch)+'.html') 10 | return metrics.fmr0,metrics.fmr100,metrics.fmr1000,metrics.gmean,metrics.imean,metrics.auc 11 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import loader 2 | import model 3 | import trainer 4 | import torch 5 | import torch.nn as nn 6 | 7 | identities = 85742 8 | torch.backends.cudnn.benchmark = True 9 | 10 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 11 | 12 | validation_split = .01 13 | batch_size= 480 14 | workers = 8 15 | 16 | train_loader, validation_loader, classes = loader.get_train_loader(batch_size,workers,validation_split) 17 | 18 | 19 | net = model.FocusFace(identities=identities) 20 | 21 | net = nn.DataParallel(net, device_ids=[0,1,2,3]).to(device) 22 | 23 | trainer.train(net,trainloader=train_loader,validationloader=validation_loader,n_epochs=500,lr=0.01) 24 | 25 | 26 | 27 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Pedro Neto 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow
 95 | __pypackages__/
 96 |
 97 | # Celery stuff
 98 | celerybeat-schedule
 99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
1 | import math, torch, iresnet
2 | from torch import nn
3 | import numpy as np
4 |
5 | class FocusFace(nn.Module):
6 |     def __init__(self, identities=1000):
7 |         super(FocusFace, self).__init__()
8 |         self.model = iresnet.iresnet100()
9 |         self.model.fc = EmbeddingHead(512, 32)
10 |         self.fc = ArcMarginProduct(512, identities, s=64, m=0.5)  # m=0.35)
11 |         self.fc2 = nn.Linear(32, 2)
12 |         self.relu = nn.ReLU()
13 |
14 |     def forward(self, x, label=None, inference=False):
15 |         e1, e2 = self.model(x)
16 |         y = None
17 |         if not inference:
18 |             y = self.fc(e1.view(e1.shape[0], -1), label)
19 |             y2 = self.fc2(e2.view(e2.shape[0], -1))
20 |         e2 = e2.view(e2.shape[0], -1)
21 |         e1 = e1.view(e1.shape[0], -1)
22 |         if inference:
23 |             y2 = self.fc2(e2.view(e2.shape[0], -1))
24 |         if not inference:
25 |             return y, e1, e2, y2
26 |         return None, e1, None, torch.nn.functional.softmax(y2, dim=1)[:, 1]
27 |
28 |
29 |
30 | class EmbeddingHead(nn.Module):
31 |     def __init__(self, c1=512, c2=256):
32 |         super(EmbeddingHead, self).__init__()
33 |         self.conv1 = nn.Conv2d(512, c1, kernel_size=(7, 7), stride=(1, 1), bias=False)
34 |         self.conv2 = nn.Conv2d(512, c2, kernel_size=(7, 7), stride=(1, 1), bias=False)
35 |         self.bn1 = nn.BatchNorm2d(c1, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
36 |         self.bn2 = nn.BatchNorm2d(c2, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
37 |         self.relu = nn.ReLU6(inplace=True)
38 |
39 |     def forward(self, x):
40 |         size = int(np.sqrt(x.shape[1] / 512))
41 |         x = x.view((x.shape[0], -1, size, size))
42 |         return self.bn1(self.conv1(x)), self.relu(self.bn2(self.conv2(x)))
43 |
44 |
45 | class ArcMarginProduct(nn.Module):
46 |     r"""Implementation of the large margin arc distance (ArcFace):
47 |     Args:
48 |         in_features: size of each input sample
49 |         out_features: size of each output sample
50 |         s: norm of input feature
51 |         m: margin
52 |         cos(theta + m)
53 |     """
54 |     def __init__(self, in_features, out_features, s=30.0, m=0.50, easy_margin=False):
55 |         super(ArcMarginProduct, self).__init__()
56 |         self.in_features = in_features
57 |         self.out_features = out_features
58 |         self.s = s
59 |         self.m = m
60 |         self.weight = nn.Parameter(torch.FloatTensor(out_features, in_features))
61 |         nn.init.xavier_uniform_(self.weight)
62 |
63 |         self.easy_margin = easy_margin
64 |         self.cos_m = math.cos(m)
65 |         self.sin_m = math.sin(m)
66 |         self.th = math.cos(math.pi - m)
67 |         self.mm = math.sin(math.pi - m) * m
68 |
69 |     def forward(self, input, label):
70 |         # --------------------------- cos(theta) & phi(theta) ---------------------------
71 |         cosine = nn.functional.linear(nn.functional.normalize(input), nn.functional.normalize(self.weight))
72 |         sine = torch.sqrt(torch.clamp((1.0 - torch.pow(cosine, 2)), 1e-9, 1))
73 |         phi = cosine * self.cos_m - sine * self.sin_m
74 |         if self.easy_margin:
75 |             phi = torch.where(cosine > 0, phi, cosine)
76 |         else:
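            # cos(theta + m) is only a monotonically decreasing penalty while theta + m <= pi.
            # Below the threshold self.th = cos(pi - m), the line that follows falls back to the
            # linear penalty cos(theta) - m*sin(m) (self.mm) instead of cos(theta + m).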
77 |             phi = torch.where(cosine > self.th, phi, cosine - self.mm)
78 |         # --------------------------- convert label to one-hot ---------------------------
79 |         # one_hot = torch.zeros(cosine.size(), requires_grad=True, device='cuda')
80 |         one_hot = torch.zeros(cosine.size(), device=label.device)
81 |         one_hot.scatter_(1, label.view(-1, 1).long(), 1)
82 |         # -------------torch.where(out_i = {x_i if condition_i else y_i) -------------
83 |         output = (one_hot * phi) + ((1.0 - one_hot) * cosine)  # you can use torch.where if your torch.__version__ is 0.4
84 |         output *= self.s
85 |         # print(output)
86 |
87 |         return output
88 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # FocusFace
2 |
3 | This is the official repository of "FocusFace: Multi-task Contrastive Learning for Masked Face Recognition", accepted at the *IEEE International Conference on Automatic Face and Gesture Recognition 2021 (FG 2021)*.
4 |
5 |
6 |
7 |
8 | Research Paper at:
9 |
10 | * [arXiv](https://arxiv.org/abs/2110.14940)
11 | * [IEEE Xplore](https://ieeexplore.ieee.org/abstract/document/9666792)
12 |
13 | ## Table of Contents
14 |
15 | - [Abstract](#abstract)
16 | - [Data](#data)
17 | - [Citing](#citing)
18 | - [Acknowledgement](#acknowledgement)
19 | - [License](#license)
20 |
21 | ### Abstract ###
22 |
23 | SARS-CoV-2 has presented direct and indirect challenges to the scientific community. One of the most prominent indirect challenges arises from the mandatory use of face masks in a large number of countries. Face recognition methods struggle to perform identity verification with similar accuracy on masked and unmasked individuals. It has been shown that the performance of these methods drops considerably in the presence of face masks, especially if the reference image is unmasked. We propose FocusFace, a multi-task architecture that uses contrastive learning to accurately perform masked face recognition. The proposed architecture is designed to be trained from scratch or to work on top of state-of-the-art face recognition methods without sacrificing the capabilities of existing models on conventional face recognition tasks. We also explore different approaches to design the contrastive learning module. Results are presented in terms of masked-masked (M-M) and unmasked-masked (U-M) face verification performance. For both settings, the results are on par with published methods, and for M-M specifically, the proposed method outperforms all the solutions it was compared to. We further show that when our method is used on top of already existing methods, the training computational cost decreases significantly while similar performance is retained.
24 |
25 | ## Data ##
26 |
27 | ### Datasets ###
28 | The LFW and "MS1M-ArcFace (85K ids/5.8M images)" datasets can be downloaded [here](https://github.com/deepinsight/insightface/tree/master/recognition/_datasets_).
29 |
30 | For all the datasets above, please strictly follow their respective licence and distribution terms.
31 |
32 | ### Masks ###
33 | The synthetic masks used to create the masked training and evaluation data were generated with [MaskTheFace](https://github.com/aqeelanwar/MaskTheFace).
34 |
35 |
36 | ### Trained Models ###
37 | Our trained models can be downloaded from [HuggingFace](https://huggingface.co/netopedro/FocusFace).
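
A minimal loading sketch (the checkpoint filename below is an assumption; check the repository for the actual file name). `trainer.py` saves plain `FocusFace` state dicts, so a checkpoint can be loaded directly into the model:

```python
import torch
from huggingface_hub import hf_hub_download

import model

# Hypothetical filename inside the HuggingFace repo; adjust to the real checkpoint name.
ckpt_path = hf_hub_download(repo_id="netopedro/FocusFace", filename="focusface.mdl")

net = model.FocusFace(identities=85742)
net.load_state_dict(torch.load(ckpt_path, map_location="cpu"))
net.eval()

# Inference mode returns (None, identity_embedding, None, mask_probability).
_, embedding, _, mask_prob = net(torch.randn(1, 3, 112, 112), inference=True)
```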
38 |
39 | ### To-do ###
40 | - [X] Add pretrained models
41 | - [X] Add train script
42 | - [ ] Add evaluation script
43 |
44 | ## Citing ##
45 | If you use any of the code or models provided in this repository, please cite the following paper:
46 | ```
47 | @inproceedings{neto2021focusface,
48 |   title={FocusFace: Multi-task Contrastive Learning for Masked Face Recognition},
49 |   author={Neto, Pedro C and Boutros, Fadi and Pinto, Jo{\~a}o Ribeiro and Damer, Naser and Sequeira, Ana F and Cardoso, Jaime S},
50 |   booktitle={2021 16th IEEE International Conference on Automatic Face and Gesture Recognition (FG 2021)},
51 |   pages={01--08},
52 |   year={2021},
53 |   organization={IEEE}
54 | }
55 | ```
56 |
57 | ## Acknowledgement ##
58 |
59 | This work was financed by National Funds through the Portuguese funding agency, FCT - Fundação para a Ciência e a Tecnologia, within project UIDB/50014/2020 and within the PhD grants "2021.06872.BD" and "SFRH/BD/137720/2018". This research work has also been funded by the German Federal Ministry of Education and Research and the Hessen State Ministry for Higher Education, Research and the Arts within their joint support of the National Research Center for Applied Cybersecurity ATHENE.
60 |
61 | ## License ##
62 |
63 | This project is licensed under the terms of the MIT License.
64 |
--------------------------------------------------------------------------------
/loader.py:
--------------------------------------------------------------------------------
1 | from torch.utils.data import DataLoader
2 | import torchvision.transforms as transforms
3 | import numpy as np
4 | import torch
5 | from torchvision.datasets import ImageFolder
6 | import torchvision.datasets as datasets
7 | import torchvision.transforms as transforms
8 | from torch.utils.data.sampler import SubsetRandomSampler
9 | import torchvision.transforms.functional as TF
10 |
11 | def wif(id):
12 |     #np.random.seed((id + torch.initial_seed()) % np.iinfo(np.int32).max)
13 |     worker_seed = torch.initial_seed() % 2**32
14 |     np.random.seed(worker_seed)
15 |
16 | class FaceDatasetVal(ImageFolder):
17 |
18 |     def __init__(self, root, transform=None, loader=datasets.folder.default_loader, is_valid_file=None, prob=1.0):
19 |         super(FaceDatasetVal, self).__init__(root, transform=transform, is_valid_file=is_valid_file)
20 |         self.imgs = self.samples
21 |         self.prob = prob
22 |     def __getitem__(self, index):
23 |         """
24 |         Args:
25 |             index (int): Index
26 |         Returns:
27 |             dict: {'image': transformed image, 'identity': class index, 'mask': 1 if a synthetic mask was applied, 0 otherwise}.
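            Note: if no pre-generated masked copy exists under ``imgs_masked2`` (for any of the
            mask suffixes tried below), the unmasked original image is loaded instead and 'mask'
            falls back to 0.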
28 | """ 29 | path, target = self.samples[index] 30 | original_path = path 31 | rand = np.random.uniform() 32 | add_mask = False 33 | if rand < self.prob: 34 | add_mask = True 35 | path = path.replace("imgs","imgs_masked2") 36 | else: 37 | add_mask = False 38 | 39 | try: 40 | sample = self.loader(path) 41 | except: 42 | try: 43 | sample = self.loader(path.replace(".jpg","_surgical.jpg")) 44 | except: 45 | try: 46 | sample = self.loader(path.replace(".jpg","_cloth.jpg")) 47 | except: 48 | try: 49 | sample = self.loader(path.replace(".jpg","_N95.jpg")) 50 | except: 51 | try: 52 | sample = self.loader(path.replace(".jpg","_KN95.jpg")) 53 | except: 54 | add_mask = False 55 | sample = self.loader(original_path) 56 | 57 | if self.transform is not None: 58 | sample = self.transform(sample) 59 | if self.target_transform is not None: 60 | target = self.target_transform(target) 61 | 62 | mask = 0 63 | if add_mask: 64 | mask = 1 65 | sample = {'image': sample, 'identity': target,'mask':mask} 66 | return sample 67 | 68 | class FaceDataset(ImageFolder): 69 | 70 | def __init__(self, root, transform=None, loader=datasets.folder.default_loader, is_valid_file=None,prob = 1.0): 71 | super(FaceDataset, self).__init__(root, transform=transform,is_valid_file=is_valid_file) 72 | self.imgs = self.samples 73 | self.prob = prob 74 | self.transforms2 = transforms.Compose([ 75 | transforms.ToTensor(), 76 | transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]) 77 | def __getitem__(self, index): 78 | """ 79 | Args: 80 | index (int): Index 81 | Returns: 82 | tuple: (sample, target) where target is class_index of the target class. 83 | """ 84 | path, target = self.samples[index] 85 | original_path = path 86 | 87 | add_mask = False 88 | path = path.replace("imgs","imgs_masked2") 89 | 90 | try: 91 | sample = self.loader(path) 92 | except: 93 | try: 94 | sample = self.loader(path.replace(".jpg","_surgical.jpg")) 95 | except: 96 | try: 97 | sample = self.loader(path.replace(".jpg","_cloth.jpg")) 98 | except: 99 | try: 100 | sample = self.loader(path.replace(".jpg","_N95.jpg")) 101 | except: 102 | try: 103 | sample = self.loader(path.replace(".jpg","_KN95.jpg")) 104 | except: 105 | add_mask = True 106 | sample = self.loader(original_path) 107 | 108 | unmasked_sample = self.loader(original_path) 109 | 110 | if self.transform is not None: 111 | sample = self.transform(sample) 112 | unmasked_sample = self.transform(unmasked_sample) 113 | if np.random.uniform() > 0.5: 114 | sample = TF.hflip(sample) 115 | unmasked_sample = TF.hflip(unmasked_sample) 116 | sample = self.transforms2(sample) 117 | unmasked_sample = self.transforms2(unmasked_sample) 118 | if self.target_transform is not None: 119 | target = self.target_transform(target) 120 | 121 | mask = 1 122 | if add_mask: 123 | mask = 0 124 | 125 | sample = {'image_masked': sample, 'identity': target,'mask':mask,'image':unmasked_sample} 126 | return sample 127 | 128 | #"python mask_the_face.py --path ../Masked-Face-Recognition2/faces_emore_mask/ --code cloth, surgical-#adff2f, surgical-#87cefa, KN95, N95" 129 | #"surgical green, surgical blue, N95, cloth, and KN95" 130 | 131 | def get_train_dataset(imgs_folder): 132 | train_transform = transforms.Compose([ 133 | transforms.Resize((112,112)), 134 | transforms.CenterCrop((112,112)) 135 | ]) 136 | ds = FaceDataset(imgs_folder, train_transform,prob=0.55) 137 | class_num = ds[-1]["identity"] + 1 138 | return ds, class_num 139 | 140 | def get_valid_dataset(imgs_folder): 141 | valid_transform = transforms.Compose([ 142 | 
transforms.ToTensor(), 143 | transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]) 144 | ]) 145 | ds = FaceDatasetVal(imgs_folder, valid_transform,prob=0) 146 | class_num = ds[-1]["identity"] + 1 147 | return ds, class_num 148 | 149 | 150 | def get_train_loader(batch_size,workers,validation_split): 151 | 152 | ds, class_num = get_train_dataset("/home/pcarneiro/Masked-Face-Recognition2/faces_emore/imgs") 153 | ds_val, class_num = get_valid_dataset("/home/pcarneiro/Masked-Face-Recognition2/faces_emore/imgs") 154 | shuffle_dataset = True 155 | np.random.seed(25) 156 | dataset_size = len(ds) 157 | 158 | indices = list(range(dataset_size)) 159 | split = int(np.floor(validation_split * dataset_size)) 160 | 161 | if shuffle_dataset : 162 | np.random.shuffle(indices) 163 | 164 | _, val_indices = indices[split:], indices[:split] 165 | valid_sampler = SubsetRandomSampler(val_indices) 166 | train_loader = DataLoader(ds, batch_size=batch_size, 167 | shuffle=True,num_workers=workers,pin_memory=True,worker_init_fn=wif) 168 | 169 | validation_loader = DataLoader(ds_val, batch_size=batch_size, 170 | sampler=valid_sampler,num_workers=workers,shuffle=False,pin_memory=True,worker_init_fn=wif) 171 | return train_loader,validation_loader, class_num 172 | 173 | 174 | 175 | -------------------------------------------------------------------------------- /iresnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | __all__ = ['iresnet18', 'iresnet34', 'iresnet50', 'iresnet100'] 5 | 6 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): 7 | """3x3 convolution with padding""" 8 | return nn.Conv2d(in_planes, 9 | out_planes, 10 | kernel_size=3, 11 | stride=stride, 12 | padding=dilation, 13 | groups=groups, 14 | bias=False, 15 | dilation=dilation) 16 | 17 | 18 | def conv1x1(in_planes, out_planes, stride=1): 19 | """1x1 convolution""" 20 | return nn.Conv2d(in_planes, 21 | out_planes, 22 | kernel_size=1, 23 | stride=stride, 24 | bias=False) 25 | class SEModule(nn.Module): 26 | def __init__(self, channels, reduction): 27 | super(SEModule, self).__init__() 28 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 29 | self.fc1 = nn.Conv2d(channels, channels // reduction, kernel_size=1, padding=0, bias=False) 30 | self.relu = nn.ReLU(inplace=True) 31 | self.fc2 = nn.Conv2d(channels // reduction, channels, kernel_size=1, padding=0, bias=False) 32 | self.sigmoid = nn.Sigmoid() 33 | 34 | def forward(self, x): 35 | input = x 36 | x = self.avg_pool(x) 37 | x = self.fc1(x) 38 | x = self.relu(x) 39 | x = self.fc2(x) 40 | x = self.sigmoid(x) 41 | 42 | return input * x 43 | 44 | class IBasicBlock(nn.Module): 45 | expansion = 1 46 | def __init__(self, inplanes, planes, stride=1, downsample=None, 47 | groups=1, base_width=64, dilation=1,use_se=False): 48 | super(IBasicBlock, self).__init__() 49 | if groups != 1 or base_width != 64: 50 | raise ValueError('BasicBlock only supports groups=1 and base_width=64') 51 | if dilation > 1: 52 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock") 53 | self.bn1 = nn.BatchNorm2d(inplanes, eps=1e-05,) 54 | self.conv1 = conv3x3(inplanes, planes) 55 | self.bn2 = nn.BatchNorm2d(planes, eps=1e-05,) 56 | self.prelu = nn.PReLU(planes) 57 | self.conv2 = conv3x3(planes, planes, stride) 58 | self.bn3 = nn.BatchNorm2d(planes, eps=1e-05,) 59 | self.downsample = downsample 60 | self.stride = stride 61 | self.use_se=use_se 62 | if (use_se): 63 | self.se_block=SEModule(planes,16) 64 | 65 | def forward(self, 
x): 66 | identity = x 67 | out = self.bn1(x) 68 | out = self.conv1(out) 69 | out = self.bn2(out) 70 | out = self.prelu(out) 71 | out = self.conv2(out) 72 | out = self.bn3(out) 73 | if(self.use_se): 74 | out=self.se_block(out) 75 | if self.downsample is not None: 76 | identity = self.downsample(x) 77 | out += identity 78 | return out 79 | 80 | 81 | class IResNet(nn.Module): 82 | fc_scale = 7 * 7 83 | def __init__(self, 84 | block, layers, dropout=0, num_features=512, zero_init_residual=False, 85 | groups=1, width_per_group=64, replace_stride_with_dilation=None, fp16=False): 86 | super(IResNet, self).__init__() 87 | self.fp16 = fp16 88 | self.inplanes = 64 89 | self.dilation = 1 90 | if replace_stride_with_dilation is None: 91 | replace_stride_with_dilation = [False, False, False] 92 | if len(replace_stride_with_dilation) != 3: 93 | raise ValueError("replace_stride_with_dilation should be None " 94 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) 95 | self.groups = groups 96 | self.base_width = width_per_group 97 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False) 98 | self.bn1 = nn.BatchNorm2d(self.inplanes, eps=1e-05) 99 | self.prelu = nn.PReLU(self.inplanes) 100 | self.layer1 = self._make_layer(block, 64, layers[0], stride=2) 101 | self.layer2 = self._make_layer(block, 102 | 128, 103 | layers[1], 104 | stride=2, 105 | dilate=replace_stride_with_dilation[0]) 106 | self.layer3 = self._make_layer(block, 107 | 256, 108 | layers[2], 109 | stride=2, 110 | dilate=replace_stride_with_dilation[1]) 111 | self.layer4 = self._make_layer(block, 112 | 512, 113 | layers[3], 114 | stride=2, 115 | dilate=replace_stride_with_dilation[2]) 116 | self.bn2 = nn.BatchNorm2d(512 * block.expansion, eps=1e-05,) 117 | self.dropout = nn.Dropout(p=dropout, inplace=True) 118 | self.fc = nn.Linear(512 * block.expansion * self.fc_scale, num_features) 119 | self.features = nn.BatchNorm1d(num_features, eps=1e-05) 120 | nn.init.constant_(self.features.weight, 1.0) 121 | self.features.weight.requires_grad = False 122 | 123 | for m in self.modules(): 124 | if isinstance(m, nn.Conv2d): 125 | nn.init.normal_(m.weight, 0, 0.1) 126 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 127 | nn.init.constant_(m.weight, 1) 128 | nn.init.constant_(m.bias, 0) 129 | 130 | if zero_init_residual: 131 | for m in self.modules(): 132 | if isinstance(m, IBasicBlock): 133 | nn.init.constant_(m.bn2.weight, 0) 134 | 135 | def _make_layer(self, block, planes, blocks, stride=1, dilate=False): 136 | downsample = None 137 | previous_dilation = self.dilation 138 | if dilate: 139 | self.dilation *= stride 140 | stride = 1 141 | if stride != 1 or self.inplanes != planes * block.expansion: 142 | downsample = nn.Sequential( 143 | conv1x1(self.inplanes, planes * block.expansion, stride), 144 | nn.BatchNorm2d(planes * block.expansion, eps=1e-05, ), 145 | ) 146 | layers = [] 147 | layers.append( 148 | block(self.inplanes, planes, stride, downsample, self.groups, 149 | self.base_width, previous_dilation)) 150 | self.inplanes = planes * block.expansion 151 | for _ in range(1, blocks): 152 | layers.append( 153 | block(self.inplanes, 154 | planes, 155 | groups=self.groups, 156 | base_width=self.base_width, 157 | dilation=self.dilation)) 158 | 159 | return nn.Sequential(*layers) 160 | 161 | def forward(self, x): 162 | with torch.cuda.amp.autocast(self.fp16): 163 | x = self.conv1(x) 164 | x = self.bn1(x) 165 | x = self.prelu(x) 166 | x = self.layer1(x) 167 | x = self.layer2(x) 168 | x = 
self.layer3(x) 169 | x = self.layer4(x) 170 | x = self.bn2(x) 171 | x = torch.flatten(x, 1) 172 | x = self.dropout(x) 173 | x = self.fc(x.float() if self.fp16 else x) 174 | #x = self.features(x) 175 | return x 176 | 177 | 178 | def _iresnet(arch, block, layers, pretrained, progress, **kwargs): 179 | model = IResNet(block, layers, **kwargs) 180 | if pretrained: 181 | raise ValueError() 182 | return model 183 | 184 | 185 | def iresnet18(pretrained=False, progress=True, **kwargs): 186 | return _iresnet('iresnet18', IBasicBlock, [2, 2, 2, 2], pretrained, 187 | progress, **kwargs) 188 | 189 | 190 | def iresnet34(pretrained=False, progress=True, **kwargs): 191 | return _iresnet('iresnet34', IBasicBlock, [3, 4, 6, 3], pretrained, 192 | progress, **kwargs) 193 | 194 | 195 | def iresnet50(pretrained=False, progress=True, **kwargs): 196 | return _iresnet('iresnet50', IBasicBlock, [3, 4, 14, 3], pretrained, 197 | progress, **kwargs) 198 | 199 | 200 | def iresnet100(pretrained=False, progress=True, **kwargs): 201 | return _iresnet('iresnet100', IBasicBlock, [3, 13, 30, 3], pretrained, 202 | progress, **kwargs) 203 | -------------------------------------------------------------------------------- /trainer.py: -------------------------------------------------------------------------------- 1 | from metrics import calculate_metrics 2 | import torch 3 | from torch import nn 4 | import sklearn 5 | import sklearn.metrics 6 | import numpy as np 7 | from tqdm import tqdm 8 | import wandb 9 | import datetime 10 | import pickle 11 | from PIL import Image 12 | import PIL 13 | from collections import defaultdict 14 | import mxnet as mx 15 | from mxnet import ndarray as nd 16 | 17 | #device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 18 | device = torch.device("cpu" if torch.cuda.is_available() else "cpu") 19 | 20 | 21 | class MetricMonitor: 22 | def __init__(self, float_precision=5): 23 | self.float_precision = float_precision 24 | self.reset() 25 | 26 | def reset(self): 27 | self.metrics = defaultdict(lambda: {"val": 0, "count": 0, "avg": 0}) 28 | 29 | def update(self, metric_name, val): 30 | metric = self.metrics[metric_name] 31 | 32 | metric["val"] += val 33 | metric["count"] += 1 34 | metric["avg"] = metric["val"] / metric["count"] 35 | 36 | def __str__(self): 37 | return " | ".join( 38 | [ 39 | "{metric_name}: {avg:.{float_precision}f}".format( 40 | metric_name=metric_name, avg=metric["avg"], float_precision=self.float_precision 41 | ) 42 | for (metric_name, metric) in self.metrics.items() 43 | ] 44 | ) 45 | 46 | 47 | 48 | def train(net,trainloader,validationloader,n_epochs=10,lr=0.1): 49 | MSE = torch.nn.MSELoss() 50 | data_set = load_bin("faces_emore/lfw.bin", (112,112)) 51 | wandb.init(project='', entity='') 52 | wandb.config.lr1 = 0.005 53 | wandb.config.lr2 = 0.1 54 | net.to("cuda:0") 55 | net.train() 56 | criterion = nn.CrossEntropyLoss() 57 | param2 = list(net.module.model.parameters()) + list(net.module.fc.parameters()) + list(net.module.fc2.parameters()) 58 | optimizer2 = torch.optim.SGD(param2, lr=wandb.config.lr2,weight_decay=5e-4,momentum=0.9) 59 | iteration = 0 60 | 61 | best_score = 100 62 | 63 | rate_decrease=1 64 | patience = 1 65 | 66 | for epoch in range(0,n_epochs): 67 | 68 | metric_monitor = MetricMonitor() 69 | stream = tqdm(trainloader) 70 | for _, sample in enumerate(stream, 0): 71 | net.train() 72 | inputs = sample['image'] 73 | inputs_masked = sample['image_masked'] 74 | labels = sample['identity'] 75 | labels2 = sample['mask'] 76 | inputs,inputs_masked, 
labels,labels2 = inputs.to("cuda:0"),inputs_masked.to("cuda:0"), labels.to("cuda:0"),labels2.to("cuda:0") 77 | 78 | 79 | optimizer2.zero_grad() 80 | outputs,e1,e2,mask = net(inputs,label=labels) 81 | loss = (criterion(outputs, labels)) + 0.1 * criterion(mask*0,labels2) 82 | outputs,e1_,e2,mask = net(inputs_masked,label=labels) 83 | loss += (criterion(outputs, labels)) + 0.1 * criterion(mask,labels2) 84 | loss /= 2 85 | loss += MSE(e1,e1_)/3 86 | loss.backward() 87 | optimizer2.step() 88 | 89 | 90 | 91 | metric_monitor.update("Loss P", loss.item()) 92 | wandb.log({"Loss P":loss.item()}) 93 | 94 | iteration +=1 95 | stream.set_description("Epoch: {epoch}. Train. {metric_monitor}".format(epoch=epoch, metric_monitor=metric_monitor)) 96 | fmr100 = validate(net,data_set,str(epoch)) 97 | 98 | if fmr100 < best_score: 99 | best_score = fmr100 100 | torch.save(net.module.state_dict(), "uai_batch" + str(epoch+1) +".mdl") 101 | print("SAVED THE MODEL") 102 | patience = 1 103 | else: 104 | if patience == 0: 105 | patience = 1 106 | rate_decrease /= 10 107 | optimizer2 = torch.optim.SGD(param2, lr=wandb.config.lr2 * rate_decrease,weight_decay=5e-4,momentum=0.9) 108 | print("New Learning Rate") 109 | print(wandb.config.lr2 * rate_decrease) 110 | else: patience -= 1 111 | print('Finished Training') 112 | 113 | 114 | def validate(net,data_set,epoch): 115 | net.eval() 116 | with torch.no_grad(): 117 | metrics = test(data_set, net, 128,epoch) 118 | print("FMR100 = " + str(metrics[1]*100)) 119 | wandb.log({"FMR100":metrics[1]*100}) 120 | print("AUC = " + str(metrics[5])) 121 | wandb.log({"AUC":metrics[5]}) 122 | wandb.log({"GMean":metrics[3]}) 123 | wandb.log({"IMean":metrics[4]}) 124 | return metrics[1] 125 | 126 | masked_labels = [] 127 | @torch.no_grad() 128 | def load_bin(path, image_size): 129 | try: 130 | with open(path, 'rb') as f: 131 | bins, issame_list = pickle.load(f) # py2 132 | except UnicodeDecodeError as e: 133 | with open(path, 'rb') as f: 134 | bins, issame_list = pickle.load(f, encoding='bytes') # py3 135 | 136 | 137 | #print(len(issame_list)) 138 | data_list = [] 139 | for flip in [0, 1]: 140 | data = torch.empty((len(issame_list) * 2, 3, image_size[0], image_size[1])) 141 | data_list.append(data) 142 | for idx in range(len(issame_list) * 2): 143 | #pdb.set_trace() 144 | #im = Image.fromarray(img.asnumpy()) 145 | #im.save("new_dataset/"+str(idx)+".jpg") 146 | if idx % 2 == 0: 147 | try: 148 | im = Image.open("new_dataset_masked2/"+str(idx)+".jpg") 149 | R, G, B = im.split() 150 | im = PIL.Image.merge("RGB", (B, G, R)) 151 | img = mx.nd.array(np.array(im)) 152 | masked_labels.append(1) 153 | except: 154 | im = Image.open("new_dataset/"+str(idx)+".jpg") 155 | R, G, B = im.split() 156 | im = PIL.Image.merge("RGB", (B, G, R)) 157 | img = mx.nd.array(np.array(im)) 158 | masked_labels.append(0) 159 | else: 160 | #_bin = bins[idx] 161 | #img = mx.image.imdecode(_bin) 162 | im = Image.open("new_dataset/"+str(idx)+".jpg") 163 | R, G, B = im.split() 164 | im = PIL.Image.merge("RGB", (B, G, R)) 165 | img = mx.nd.array(np.array(im)) 166 | masked_labels.append(0) 167 | 168 | #if img.shape[1] != image_size[0]: 169 | # img = mx.image.resize_short(img, image_size[0]) 170 | 171 | img = nd.transpose(img, axes=(2, 0, 1)) 172 | for flip in [0, 1]: 173 | if flip == 1: 174 | img = mx.ndarray.flip(data=img, axis=2) 175 | data_list[flip][idx][:] = torch.from_numpy(img.asnumpy()) 176 | 177 | if idx % 1000 == 0: 178 | print('loading bin', idx) 179 | print(data_list[0].shape) 180 | return data_list, issame_list 181 | 
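# test() implements the verification protocol used by validate():
#   * embeddings are extracted in batches for the original and horizontally flipped
#     copies of every image in data_set,
#   * the mask-branch outputs are written to mask_prediction.txt together with the
#     ground-truth masked_labels collected in load_bin(),
#   * the two embedding sets (original + flipped) are summed and L2-normalised, and
#     consecutive (even, odd) rows form the comparison pairs,
#   * each pair is scored with 1 - euclidean_distance / 2, and the genuine/impostor
#     score lists are handed to calculate_metrics().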
182 | @torch.no_grad() 183 | def test(data_set, backbone, batch_size,epoch): 184 | print('testing verification..') 185 | data_list = data_set[0] 186 | issame_list = data_set[1] 187 | embeddings_list = [] 188 | time_consumed = 0.0 189 | masked = [] 190 | for i in range(len(data_list)): 191 | data = data_list[i] 192 | embeddings = None 193 | ba = 0 194 | print(i) 195 | while ba < data.shape[0]: 196 | bb = min(ba + batch_size, data.shape[0]) 197 | count = bb - ba 198 | _data = data[bb - batch_size: bb] 199 | time0 = datetime.datetime.now() 200 | img = ((_data / 255) - 0.5) / 0.5 201 | img = img.to(device) 202 | _,net_out,_,y2 = backbone(img,inference = True) 203 | masked.append((i,y2.detach().cpu().numpy())) 204 | del img 205 | 206 | _embeddings = net_out.detach().cpu().numpy() 207 | time_now = datetime.datetime.now() 208 | diff = time_now - time0 209 | time_consumed += diff.total_seconds() 210 | if embeddings is None: 211 | embeddings = np.zeros((data.shape[0], _embeddings.shape[1])) 212 | embeddings[ba:bb, :] = _embeddings[(batch_size - count):, :] 213 | ba = bb 214 | embeddings_list.append(embeddings) 215 | if i % 1 == 0: 216 | print('loading bin', i) 217 | print(time_consumed) 218 | 219 | masked2 = [] 220 | i = 0 221 | with open("mask_prediction.txt","w") as w: 222 | for mask in masked: 223 | label = mask[0] 224 | for mask2 in mask[1]: 225 | mask2=mask2.item() 226 | 227 | w.write(str(label) + "," + str(masked_labels[i]) + "," + str(mask2) + "\n") 228 | i+=1 229 | 230 | 231 | _xnorm = 0.0 232 | _xnorm_cnt = 0 233 | print("Normalizing") 234 | for embed in embeddings_list: 235 | for i in range(embed.shape[0]): 236 | _em = embed[i] 237 | _norm = np.linalg.norm(_em) 238 | _xnorm += _norm 239 | _xnorm_cnt += 1 240 | _xnorm /= _xnorm_cnt 241 | 242 | 243 | 244 | embeddings = embeddings_list[0].copy() 245 | embeddings = sklearn.preprocessing.normalize(embeddings) 246 | 247 | embeddings = embeddings_list[0] + embeddings_list[1] 248 | embeddings = sklearn.preprocessing.normalize(embeddings) 249 | 250 | embeddings1 = embeddings[0::2] 251 | embeddings2 = embeddings[1::2] 252 | positives = [] 253 | negatives = [] 254 | 255 | print(len(issame_list)) 256 | 257 | for embedding1, embedding2,label in zip(embeddings1,embeddings2,issame_list): 258 | dist = 1- torch.cdist(torch.from_numpy(embedding1).view(1, -1), torch.from_numpy(embedding2).view(1, -1))/2 259 | if label == 1: 260 | positives.append(dist) 261 | else: 262 | negatives.append(dist) 263 | return calculate_metrics(positives,negatives,epoch) 264 | 265 | --------------------------------------------------------------------------------