├── .gitignore ├── .travis.yml ├── LICENSE ├── README.md ├── __init__.py ├── configs └── mobilefacenet_vgg2.yml ├── datasets ├── __init__.py ├── casia.py ├── celeba.py ├── imdbface.py ├── lfw.py ├── megaface.py ├── ms_celeb1m.py ├── ndg.py ├── trillion_pairs.py └── vggface2.py ├── demo ├── README.md ├── demo.png └── run_demo.py ├── devtools └── pylint.rc ├── dump_features.py ├── evaluate_landmarks.py ├── evaluate_lfw.py ├── init_venv.sh ├── losses ├── __init__.py ├── alignment.py ├── am_softmax.py ├── centroid_based.py ├── metric_losses.py └── regularizer.py ├── model ├── __init__.py ├── backbones │ ├── __init__.py │ ├── resnet.py │ ├── rmnet.py │ ├── se_resnet.py │ ├── se_resnext.py │ └── shufflenet_v2.py ├── blocks │ ├── __init__.py │ ├── mobilenet_v2_blocks.py │ ├── resnet_blocks.py │ ├── rmnet_blocks.py │ ├── se_resnet_blocks.py │ ├── se_resnext_blocks.py │ ├── shared_blocks.py │ └── shufflenet_v2_blocks.py ├── common.py ├── landnet.py ├── mobilefacenet.py ├── resnet_angular.py ├── rmnet_angular.py ├── se_resnet_angular.py └── shufflenet_v2_angular.py ├── requirements.txt ├── scripts ├── __init__.py ├── accuracy_check.py ├── align_images.py ├── count_flops.py ├── matio.py ├── plot_roc_curves_lfw.py └── pytorch2onnx.py ├── tests ├── __init__.py ├── test_alignment.py ├── test_models.py └── test_utils.py ├── train.py ├── train_landmarks.py └── utils ├── __init__.py ├── augmentation.py ├── face_align.py ├── ie_tools.py ├── landmarks_augmentation.py ├── parser_yaml.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | __pycache__ 3 | .idea/ 4 | *.iml 5 | **/venv 6 | data/test 7 | external/cocoapi 8 | tensorflow_toolkit/tests/models 9 | tensorflow_toolkit/**/model -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | language: python 2 | sudo: required 3 | dist: xenial 4 | 5 | python: 6 | - "3.5" 7 | cache: pip 8 | 9 | install: 10 | - bash ./init_venv.sh 11 | 12 | jobs: 13 | include: 14 | - stage: Tests 15 | script: 16 | - . venv/bin/activate 17 | - python -m unittest 18 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 
25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. 
If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. 
Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "{}" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright 2018 algo 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Face Recognition in PyTorch 2 | [![License](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](https://opensource.org/licenses/Apache-2.0) 3 | [![Build Status](https://travis-ci.com/grib0ed0v/face_recognition.pytorch.svg?branch=develop)](https://travis-ci.com/grib0ed0v/face_recognition.pytorch) 4 | 5 | By [Alexey Gruzdev](https://www.linkedin.com/in/alexey-gruzdev-454399128/) and [Vladislav Sovrasov](https://www.linkedin.com/in/%D0%B2%D0%BB%D0%B0%D0%B4%D0%B8%D1%81%D0%BB%D0%B0%D0%B2-%D1%81%D0%BE%D0%B2%D1%80%D0%B0%D1%81%D0%BE%D0%B2-173b23104/) 6 | 7 | ## Introduction 8 | 9 | *A repository for different experimental Face Recognition models such as [CosFace](https://arxiv.org/pdf/1801.09414.pdf), [ArcFace](https://arxiv.org/pdf/1801.07698.pdf), [SphereFace](https://arxiv.org/pdf/1704.08063.pdf), [SV-Softmax](https://arxiv.org/pdf/1812.11317.pdf), etc.* 10 | 11 | ## Contents 12 | 1. [Installation](#installation) 13 | 2. [Preparation](#preparation) 14 | 3. [Train/Eval](#traineval) 15 | 4. [Models](#models) 16 | 5. [Face Recognition Demo](#demo) 17 | 18 | 19 | ## Installation 20 | 1. Create and activate a virtual Python environment 21 | ```bash 22 | bash init_venv.sh 23 | . venv/bin/activate 24 | ``` 25 | 26 | 27 | 28 | 29 | ## Preparation 30 | 31 | 1. For Face Recognition training you should download [VGGFace2](http://www.robots.ox.ac.uk/~vgg/data/vgg_face2/) data. We will refer to this folder as `$VGGFace2_ROOT`. 32 | 2. For Face Recognition evaluation you need to download [LFW](http://vis-www.cs.umass.edu/lfw/) data and [LFW landmarks](https://github.com/clcarwin/sphereface_pytorch/blob/master/data/lfw_landmark.txt). Place everything in one folder, which will be `$LFW_ROOT`. 33 | 34 | 35 | 36 | 37 | ## Train/Eval 38 | 1. Go to the `$FR_ROOT` folder 39 | ```bash 40 | cd $FR_ROOT/ 41 | ``` 42 | 43 | 2. To start training an FR model: 44 | ```bash 45 | python train.py --train_data_root $VGGFace2_ROOT/train/ --train_list $VGGFace2_ROOT/meta/train_list.txt 46 | --train_landmarks $VGGFace2_ROOT/bb_landmark/ --val_data_root $LFW_ROOT/lfw/ --val_list $LFW_ROOT/pairs.txt 47 | --val_landmarks $LFW_ROOT/lfw_landmark.txt --train_batch_size 200 --snap_prefix mobilenet_256 --lr 0.35 48 | --embed_size 256 --model mobilenet --device 1 49 | ``` 50 | 51 | 3. To evaluate an FR snapshot (say, a MobileNet with a 256-dimensional embedding trained for 300k iterations): 52 | ```bash 53 | python evaluate_lfw.py --val_data_root $LFW_ROOT/lfw/ --val_list $LFW_ROOT/pairs.txt 54 | --val_landmarks $LFW_ROOT/lfw_landmark.txt --snap /path/to/snapshot/mobilenet_256_300000.pt --model mobilenet --embed_size 256 55 | ``` 56 | 57 | ## Configuration files 58 | Besides passing all the required parameters via the command line, the training script can read them from a `yaml` configuration file. 59 | Each line of such a file should contain a valid description of one parameter in the `yaml` format. 60 | Example: 61 | ```yml 62 | #optimizer parameters 63 | lr: 0.4 64 | train_batch_size: 256 65 | #loss options 66 | margin_type: cos 67 | s: 30 68 | m: 0.35 69 | #model parameters 70 | model: mobilenet 71 | embed_size: 256 72 | #misc 73 | snap_prefix: MobileFaceNet 74 | devices: [0, 1] 75 | #datasets 76 | train_dataset: vgg 77 | train_data_root: $VGGFace2_ROOT/train/ 78 | #... and so on 79 | ``` 80 | The path to the config file can be passed to the training script via the command line. If any other arguments are passed before the config, they will be overwritten by the values from the config file. 81 | ```bash 82 | python train.py -m 0.35 @./my_config.yml # here m will be overwritten with the value from my_config.yml 83 | ```
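The following is a minimal, self-contained sketch of how such `@config.yml` handling can be wired into `argparse`. The repository's actual implementation lives in `utils/parser_yaml.py` and may differ in details, so the `YamlArgumentParser` class and the single-key config file assumed below are illustrative only (PyYAML is assumed to be installed).

```python
# Hypothetical sketch of @-style YAML config support for a train.py-like script.
import argparse
import yaml


class YamlArgumentParser(argparse.ArgumentParser):
    """Treats every line of an @-file as one 'key: value' YAML parameter."""

    def convert_arg_line_to_args(self, arg_line):
        line = arg_line.split('#')[0].strip()   # drop comments and blank lines
        if not line:
            return []
        item = yaml.safe_load(line)             # e.g. 'lr: 0.4' -> {'lr': 0.4}
        if not isinstance(item, dict):
            return []
        key, value = next(iter(item.items()))
        return ['--{}'.format(key), str(value)]


parser = YamlArgumentParser(fromfile_prefix_chars='@')
parser.add_argument('--lr', type=float, default=0.35)

# Assume ./my_config.yml contains the single line 'lr: 0.4'. The @-file is
# expanded in place, so its value overrides the option given before it.
args = parser.parse_args(['--lr', '0.1', '@./my_config.yml'])
print(args.lr)  # 0.4
```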
84 | 85 | 86 | 87 | ## Models 88 | 89 | 1. You can also download a pretrained model from the fileshare: https://download.01.org/openvinotoolkit/open_model_zoo/training_toolbox_pytorch/models/fr/Mobilenet_se_focal_121000.pt 90 | ```bash 91 | cd $FR_ROOT 92 | python evaluate_lfw.py --val_data_root $LFW_ROOT/lfw/ --val_list $LFW_ROOT/pairs.txt --val_landmarks $LFW_ROOT/lfw_landmark.txt 93 | --snap /path/to/snapshot/Mobilenet_se_focal_121000.pt --model mobilenet --embed_size 256 94 | ``` 95 | 96 | 2. You should get the following output: 97 | ``` 98 | I1114 09:33:37.846870 10544 evaluate_lfw.py:242] Accuracy/Val_same_accuracy mean: 0.9923 99 | I1114 09:33:37.847019 10544 evaluate_lfw.py:243] Accuracy/Val_diff_accuracy mean: 0.9970 100 | I1114 09:33:37.847069 10544 evaluate_lfw.py:244] Accuracy/Val_accuracy mean: 0.9947 101 | I1114 09:33:37.847179 10544 evaluate_lfw.py:245] Accuracy/Val_accuracy std dev: 0.0035 102 | I1114 09:33:37.847229 10544 evaluate_lfw.py:246] AUC: 0.9995 103 | I1114 09:33:37.847305 10544 evaluate_lfw.py:247] Estimated threshold: 0.7241 104 | ``` 105 | 106 | ## Demo 107 | 108 | 1. To set up the demo, please go to [Face Recognition demo with OpenVINO Toolkit](./demo/README.md) 109 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/grib0ed0v/face_recognition.pytorch/05cb9b30e8220445fcb27988926d88f330091c12/__init__.py -------------------------------------------------------------------------------- /configs/mobilefacenet_vgg2.yml: -------------------------------------------------------------------------------- 1 | #optimizer parameters 2 | lr: 0.4 3 | train_batch_size: 256 4 | #loss options 5 | margin_type: cos 6 | s: 30 7 | m: 0.35 8 | mining_type: sv 9 | t: 1.1 10 | #model parameters 11 | model: mobilenet 12 | embed_size: 256 13 | 14 | train_dataset: vgg 15 | snap_prefix: MobileFaceNet 16 | devices: [0, 1] 17 | -------------------------------------------------------------------------------- /datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .lfw import LFW 2 | from .vggface2 import VGGFace2 3 | from .ms_celeb1m import MSCeleb1M 4 | from .trillion_pairs import TrillionPairs 5 | from .imdbface import IMDBFace 6 | 7 | from .celeba import CelebA 8 | from .ndg import NDG 9 | 10 | __all__ = ['LFW', 'VGGFace2', 'MSCeleb1M', 'TrillionPairs', 'IMDBFace', 'CelebA', 'NDG'] 11 | -------------------------------------------------------------------------------- /datasets/casia.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2018 Intel Corporation 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | http://www.apache.org/licenses/LICENSE-2.0 7 | Unless required by applicable law or agreed to in writing, software 8 | distributed under the License is distributed on an "AS IS" BASIS, 9 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10 | See the License for the specific language governing permissions and 11 | limitations under the License. 12 | """ 13 | 14 | import os.path as osp 15 | 16 | from tqdm import tqdm 17 | from torch.utils.data import Dataset 18 | import cv2 as cv 19 | 20 | from utils.face_align import FivePointsAligner 21 | 22 | class CASIA(Dataset): 23 | """CASIA Dataset compatible with PyTorch DataLoader.""" 24 | def __init__(self, images_root_path, image_list_path, transform, use_landmarks=True): 25 | self.image_list_path = image_list_path 26 | self.images_root_path = images_root_path 27 | self.identities = {} 28 | self.use_landmarks = use_landmarks 29 | self.samples_info = self._read_samples_info() 30 | self.transform = transform 31 | 32 | def _read_samples_info(self): 33 | """Reads annotation of the dataset""" 34 | samples = [] 35 | with open(self.image_list_path, 'r') as f: 36 | for line in tqdm(f.readlines(), 'Preparing CASIA dataset'): 37 | sample = line.split() 38 | sample_id = sample[1] 39 | landmarks = [[sample[i], sample[i+1]] for i in range(2, 12, 2)] 40 | self.identities[sample_id] = [1] 41 | samples.append((osp.join(self.images_root_path, sample[0]), sample_id, landmarks)) 42 | 43 | return samples 44 | 45 | def get_num_classes(self): 46 | """Returns total number of identities""" 47 | return len(self.identities) 48 | 49 | def __len__(self): 50 | """Returns total number of samples""" 51 | return len(self.samples_info) 52 | 53 | def __getitem__(self, idx): 54 | img = cv.imread(self.samples_info[idx][0]) 55 | if self.use_landmarks: 56 | img = FivePointsAligner.align(img, self.samples_info[idx][2], 57 | d_size=(200, 200), normalized=True, show=False) 58 | 59 | if self.transform: 60 | img = self.transform(img) 61 | return {'img': img, 'label': int(self.samples_info[idx][1])} 62 | -------------------------------------------------------------------------------- /datasets/celeba.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2018 Intel Corporation 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | http://www.apache.org/licenses/LICENSE-2.0 7 | Unless required by applicable law or agreed to in writing, software 8 | distributed under the License is distributed on an "AS IS" BASIS, 9 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 10 | See the License for the specific language governing permissions and 11 | limitations under the License. 
12 | """ 13 | 14 | import os.path as osp 15 | 16 | import numpy as np 17 | from tqdm import tqdm 18 | from torch.utils.data import Dataset 19 | import cv2 as cv 20 | 21 | 22 | class CelebA(Dataset): 23 | """CelebA Dataset compatible with PyTorch DataLoader.""" 24 | def __init__(self, images_root_path, landmarks_folder_path, transform=None, test=False): 25 | self.test = test 26 | self.have_landmarks = True 27 | self.images_root_path = images_root_path 28 | bb_file_name = 'list_bbox_celeba.txt' 29 | landmarks_file_name = 'list_landmarks_celeba.txt' 30 | self.detections_file = open(osp.join(landmarks_folder_path, bb_file_name), 'r') 31 | self.landmarks_file = open(osp.join(landmarks_folder_path, landmarks_file_name), 'r') 32 | self.samples_info = self._read_samples_info() 33 | self.transform = transform 34 | 35 | def _read_samples_info(self): 36 | """Reads annotation of the dataset""" 37 | samples = [] 38 | 39 | detections_file_lines = self.detections_file.readlines()[2:] 40 | landmarks_file_lines = self.landmarks_file.readlines()[2:] 41 | assert len(detections_file_lines) == len(landmarks_file_lines) 42 | 43 | if self.test: 44 | images_range = range(182638, len(landmarks_file_lines)) 45 | else: 46 | images_range = range(182637) 47 | 48 | for i in tqdm(images_range): 49 | line = detections_file_lines[i].strip() 50 | img_name = line.split(' ')[0] 51 | img_path = osp.join(self.images_root_path, img_name) 52 | 53 | bbox = list(filter(bool, line.split(' ')[1:])) 54 | bbox = [int(coord) for coord in bbox] 55 | if bbox[2] == 0 or bbox[3] == 0: 56 | continue 57 | 58 | line_landmarks = landmarks_file_lines[i].strip().split(' ')[1:] 59 | landmarks = list(filter(bool, line_landmarks)) 60 | landmarks = [float(coord) for coord in landmarks] 61 | samples.append((img_path, bbox, landmarks)) 62 | 63 | return samples 64 | 65 | def __len__(self): 66 | """Returns total number of samples""" 67 | return len(self.samples_info) 68 | 69 | def __getitem__(self, idx): 70 | """Returns sample (image, landmarks) by index""" 71 | img = cv.imread(self.samples_info[idx][0], cv.IMREAD_COLOR) 72 | bbox = self.samples_info[idx][1] 73 | landmarks = self.samples_info[idx][2] 74 | 75 | img = img[bbox[1]:bbox[1] + bbox[3], bbox[0]:bbox[0] + bbox[2]] 76 | landmarks = np.array([(float(landmarks[2*i]-bbox[0]) / bbox[2], 77 | float(landmarks[2*i + 1]-bbox[1])/ bbox[3]) \ 78 | for i in range(len(landmarks)//2)]).reshape(-1) 79 | data = {'img': img, 'landmarks': landmarks} 80 | if self.transform: 81 | data = self.transform(data) 82 | return data 83 | -------------------------------------------------------------------------------- /datasets/imdbface.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2018 Intel Corporation 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | http://www.apache.org/licenses/LICENSE-2.0 7 | Unless required by applicable law or agreed to in writing, software 8 | distributed under the License is distributed on an "AS IS" BASIS, 9 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 10 | See the License for the specific language governing permissions and 11 | limitations under the License. 
12 | """ 13 | 14 | import os.path as osp 15 | 16 | import cv2 as cv 17 | from tqdm import tqdm 18 | from torch.utils.data import Dataset 19 | 20 | from utils.face_align import FivePointsAligner 21 | 22 | 23 | class IMDBFace(Dataset): 24 | """IMDBFace Dataset compatible with PyTorch DataLoader.""" 25 | def __init__(self, images_root_path, image_list_path, transform=None): 26 | self.image_list_path = image_list_path 27 | self.images_root_path = images_root_path 28 | self.identities = {} 29 | 30 | assert osp.isfile(image_list_path) 31 | self.have_landmarks = True 32 | 33 | self.all_samples_info = self._read_samples_info() 34 | self.samples_info = self.all_samples_info 35 | self.transform = transform 36 | 37 | def _read_samples_info(self): 38 | """Reads annotation of the dataset""" 39 | samples = [] 40 | 41 | with open(self.image_list_path, 'r') as f: 42 | images_file_lines = f.readlines() 43 | last_class_id = -1 44 | 45 | for i in tqdm(range(len(images_file_lines))): 46 | line = images_file_lines[i] 47 | terms = line.split('|') 48 | if len(terms) < 3: 49 | continue # FD has failed on this imsage 50 | path, landmarks, _ = terms 51 | image_id, _ = path.rsplit('/', 1) 52 | 53 | if image_id in self.identities: 54 | self.identities[image_id].append(len(samples)) 55 | else: 56 | last_class_id += 1 57 | self.identities[image_id] = [len(samples)] 58 | 59 | landmarks = [float(coord) for coord in landmarks.strip().split(' ')] 60 | assert len(landmarks) == 10 61 | samples.append((osp.join(self.images_root_path, path).strip(), last_class_id, image_id, landmarks)) 62 | 63 | return samples 64 | 65 | def get_weights(self): 66 | """Computes weights of the each identity in dataset according to frequency of it's occurance""" 67 | weights = [0.]*len(self.all_samples_info) 68 | for i, sample in enumerate(self.all_samples_info): 69 | weights[i] = float(len(self.all_samples_info)) / len(self.identities[sample[2]]) 70 | return weights 71 | 72 | def get_num_classes(self): 73 | """Returns total number of identities""" 74 | return len(self.identities) 75 | 76 | def __len__(self): 77 | """Returns total number of samples""" 78 | return len(self.samples_info) 79 | 80 | def __getitem__(self, idx): 81 | """Returns sample (image, class id, image id) by index""" 82 | img = cv.imread(self.samples_info[idx][0], cv.IMREAD_COLOR) 83 | landmarks = self.samples_info[idx][-1] 84 | img = FivePointsAligner.align(img, landmarks, d_size=(200, 200), normalized=True, show=False) 85 | 86 | if self.transform: 87 | img = self.transform(img) 88 | 89 | return {'img': img, 'label': self.samples_info[idx][1], 'instance': self.samples_info[idx][2]} 90 | -------------------------------------------------------------------------------- /datasets/lfw.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2018 Intel Corporation 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | http://www.apache.org/licenses/LICENSE-2.0 7 | Unless required by applicable law or agreed to in writing, software 8 | distributed under the License is distributed on an "AS IS" BASIS, 9 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 10 | See the License for the specific language governing permissions and 11 | limitations under the License. 
12 | """ 13 | 14 | import os.path as osp 15 | import cv2 as cv 16 | import numpy as np 17 | from torch.utils.data import Dataset 18 | 19 | from utils.face_align import FivePointsAligner 20 | 21 | 22 | class LFW(Dataset): 23 | """LFW Dataset compatible with PyTorch DataLoader.""" 24 | def __init__(self, images_root_path, pairs_path, landmark_file_path='', transform=None): 25 | self.pairs_path = pairs_path 26 | self.images_root_path = images_root_path 27 | self.landmark_file_path = landmark_file_path 28 | self.use_landmarks = len(self.landmark_file_path) > 0 29 | if self.use_landmarks: 30 | self.landmarks = self._read_landmarks() 31 | self.pairs = self._read_pairs() 32 | self.transform = transform 33 | 34 | def _read_landmarks(self): 35 | """Reads landmarks of the dataset""" 36 | landmarks = {} 37 | with open(self.landmark_file_path, 'r') as f: 38 | for line in f.readlines(): 39 | sp = line.split() 40 | key = sp[0][sp[0].rfind('/')+1:] 41 | landmarks[key] = [[int(sp[i]), int(sp[i+1])] for i in range(1, 11, 2)] 42 | 43 | return landmarks 44 | 45 | def _read_pairs(self): 46 | """Reads annotation of the dataset""" 47 | pairs = [] 48 | with open(self.pairs_path, 'r') as f: 49 | for line in f.readlines()[1:]: # skip header 50 | pair = line.strip().split() 51 | pairs.append(pair) 52 | 53 | file_ext = 'jpg' 54 | lfw_dir = self.images_root_path 55 | path_list = [] 56 | 57 | for pair in pairs: 58 | if len(pair) == 3: 59 | path0 = osp.join(lfw_dir, pair[0], pair[0] + '_' + '%04d' % int(pair[1]) + '.' + file_ext) 60 | id0 = pair[0] 61 | path1 = osp.join(lfw_dir, pair[0], pair[0] + '_' + '%04d' % int(pair[2]) + '.' + file_ext) 62 | id1 = pair[0] 63 | issame = True 64 | elif len(pair) == 4: 65 | path0 = osp.join(lfw_dir, pair[0], pair[0] + '_' + '%04d' % int(pair[1]) + '.' + file_ext) 66 | id0 = pair[0] 67 | path1 = osp.join(lfw_dir, pair[2], pair[2] + '_' + '%04d' % int(pair[3]) + '.' 
+ file_ext) 68 | id1 = pair[0] 69 | issame = False 70 | 71 | path_list.append((path0, path1, issame, id0, id1)) 72 | 73 | return path_list 74 | 75 | def _load_img(self, img_path): 76 | """Loads an image from dist, then performs face alignment and applies transform""" 77 | img = cv.imread(img_path, cv.IMREAD_COLOR) 78 | 79 | if self.use_landmarks: 80 | landmarks = np.array(self.landmarks[img_path[img_path.rfind('/')+1:]]).reshape(-1) 81 | img = FivePointsAligner.align(img, landmarks, show=False) 82 | 83 | if self.transform is None: 84 | return img 85 | 86 | return self.transform(img) 87 | 88 | def show_item(self, index): 89 | """Saves a pair with a given index to disk""" 90 | path_1, path_2, _, _, _ = self.pairs[index] 91 | img1 = cv.imread(path_1) 92 | img2 = cv.imread(path_2) 93 | if self.use_landmarks: 94 | landmarks1 = np.array(self.landmarks[path_1[path_1.rfind('/')+1:]]).reshape(-1) 95 | landmarks2 = np.array(self.landmarks[path_2[path_2.rfind('/')+1:]]).reshape(-1) 96 | img1 = FivePointsAligner.align(img1, landmarks1) 97 | img2 = FivePointsAligner.align(img2, landmarks2) 98 | else: 99 | img1 = cv.resize(img1, (400, 400)) 100 | img2 = cv.resize(img2, (400, 400)) 101 | cv.imwrite('misclassified_{}.jpg'.format(index), np.hstack([img1, img2])) 102 | 103 | def __getitem__(self, index): 104 | """Returns a pair of images and similarity flag by index""" 105 | (path_1, path_2, is_same, id0, id1) = self.pairs[index] 106 | img1, img2 = self._load_img(path_1), self._load_img(path_2) 107 | 108 | return {'img1': img1, 'img2': img2, 'is_same': is_same, 'id0': id0, 'id1': id1} 109 | 110 | def __len__(self): 111 | """Returns total number of samples""" 112 | return len(self.pairs) 113 | -------------------------------------------------------------------------------- /datasets/megaface.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2018 Intel Corporation 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | http://www.apache.org/licenses/LICENSE-2.0 7 | Unless required by applicable law or agreed to in writing, software 8 | distributed under the License is distributed on an "AS IS" BASIS, 9 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 10 | See the License for the specific language governing permissions and 11 | limitations under the License. 
12 | """ 13 | 14 | import numpy as np 15 | from torch.utils.data import Dataset 16 | import cv2 as cv 17 | 18 | from utils.face_align import FivePointsAligner 19 | 20 | 21 | class MegaFace(Dataset): 22 | """MegaFace Dataset compatible with PyTorch DataLoader.""" 23 | def __init__(self, images_list, transform=None): 24 | self.samples_info = images_list 25 | self.transform = transform 26 | 27 | def __len__(self): 28 | """Returns total number of samples""" 29 | return len(self.samples_info) 30 | 31 | def __getitem__(self, idx): 32 | """Returns sample (image, index)""" 33 | img = None 34 | try: 35 | img = cv.imread(self.samples_info[idx]['path'], cv.IMREAD_COLOR) 36 | bbox = self.samples_info[idx]['bbox'] 37 | landmarks = self.samples_info[idx]['landmarks'] 38 | 39 | if bbox is not None or landmarks is not None: 40 | if landmarks is not None: 41 | landmarks = np.array(landmarks).reshape(5, -1) 42 | landmarks[:,0] = landmarks[:,0]*bbox[2] + bbox[0] 43 | landmarks[:,1] = landmarks[:,1]*bbox[3] + bbox[1] 44 | img = FivePointsAligner.align(img, landmarks.reshape(-1), d_size=(bbox[2], bbox[3]), 45 | normalized=False, show=False) 46 | if bbox is not None and landmarks is None: 47 | img = img[bbox[1]:bbox[1] + bbox[3], bbox[0]:bbox[0] + bbox[2]] 48 | except BaseException: 49 | print('Corrupted image!', self.samples_info[idx]) 50 | img = np.zeros((128, 128, 3), dtype='uint8') 51 | 52 | if self.transform: 53 | img = self.transform(img) 54 | 55 | return {'img': img, 'idx': idx} 56 | -------------------------------------------------------------------------------- /datasets/ms_celeb1m.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2018 Intel Corporation 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | http://www.apache.org/licenses/LICENSE-2.0 7 | Unless required by applicable law or agreed to in writing, software 8 | distributed under the License is distributed on an "AS IS" BASIS, 9 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 10 | See the License for the specific language governing permissions and 11 | limitations under the License.
12 | """ 13 | 14 | import os.path as osp 15 | 16 | import cv2 as cv 17 | from tqdm import tqdm 18 | from torch.utils.data import Dataset 19 | 20 | from utils.face_align import FivePointsAligner 21 | 22 | 23 | class MSCeleb1M(Dataset): 24 | """MSCeleb1M Dataset compatible with PyTorch DataLoader.""" 25 | def __init__(self, images_root_path, image_list_path, transform=None): 26 | self.image_list_path = image_list_path 27 | self.images_root_path = images_root_path 28 | self.identities = {} 29 | 30 | assert osp.isfile(image_list_path) 31 | self.have_landmarks = True 32 | 33 | self.all_samples_info = self._read_samples_info() 34 | self.samples_info = self.all_samples_info 35 | self.transform = transform 36 | 37 | def _read_samples_info(self): 38 | """Reads annotation of the dataset""" 39 | samples = [] 40 | 41 | with open(self.image_list_path, 'r') as f: 42 | images_file_lines = f.readlines() 43 | last_class_id = -1 44 | 45 | for i in tqdm(range(len(images_file_lines))): 46 | line = images_file_lines[i] 47 | terms = line.split('|') 48 | if len(terms) < 3: 49 | continue # FD has failed on this imsage 50 | path, landmarks, bbox = terms 51 | image_id, _ = path.split('/') 52 | 53 | if image_id in self.identities: 54 | self.identities[image_id].append(len(samples)) 55 | else: 56 | last_class_id += 1 57 | self.identities[image_id] = [len(samples)] 58 | 59 | bbox = [max(int(coord), 0) for coord in bbox.strip().split(' ')] 60 | landmarks = [float(coord) for coord in landmarks.strip().split(' ')] 61 | assert len(bbox) == 4 62 | assert len(landmarks) == 10 63 | samples.append((osp.join(self.images_root_path, path).strip(), 64 | last_class_id, image_id, bbox, landmarks)) 65 | 66 | return samples 67 | 68 | def get_weights(self): 69 | """Computes weights of the each identity in dataset according to frequency of it's occurance""" 70 | weights = [0.]*len(self.all_samples_info) 71 | for i, sample in enumerate(self.all_samples_info): 72 | weights[i] = float(len(self.all_samples_info)) / len(self.identities[sample[2]]) 73 | return weights 74 | 75 | def get_num_classes(self): 76 | """Returns total number of identities""" 77 | return len(self.identities) 78 | 79 | def __len__(self): 80 | """Returns total number of samples""" 81 | return len(self.samples_info) 82 | 83 | def __getitem__(self, idx): 84 | """Returns sample (image, class id, image id) by index""" 85 | img = cv.imread(self.samples_info[idx][0], cv.IMREAD_COLOR) 86 | bbox = self.samples_info[idx][-2] 87 | landmarks = self.samples_info[idx][-1] 88 | 89 | img = img[bbox[1]:bbox[1] + bbox[3], bbox[0]:bbox[0] + bbox[2]] 90 | img = FivePointsAligner.align(img, landmarks, d_size=(200, 200), normalized=True, show=False) 91 | 92 | if self.transform: 93 | img = self.transform(img) 94 | 95 | return {'img': img, 'label': self.samples_info[idx][1], 'instance': self.samples_info[idx][2]} 96 | -------------------------------------------------------------------------------- /datasets/ndg.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2018 Intel Corporation 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | http://www.apache.org/licenses/LICENSE-2.0 7 | Unless required by applicable law or agreed to in writing, software 8 | distributed under the License is distributed on an "AS IS" BASIS, 9 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
10 | See the License for the specific language governing permissions and 11 | limitations under the License. 12 | """ 13 | 14 | import os.path as osp 15 | import json 16 | import numpy as np 17 | from tqdm import tqdm 18 | from torch.utils.data import Dataset 19 | import cv2 as cv 20 | 21 | 22 | class NDG(Dataset): 23 | """NDG Dataset compatible with PyTorch DataLoader.""" 24 | def __init__(self, images_root_path, annotation_list, transform=None, test=False): 25 | self.test = test 26 | self.have_landmarks = True 27 | self.images_root_path = images_root_path 28 | self.landmarks_file = open(annotation_list, 'r') 29 | self.samples_info = self._read_samples_info() 30 | self.transform = transform 31 | 32 | def _read_samples_info(self): 33 | """Reads annotation of the dataset""" 34 | samples = [] 35 | data = json.load(self.landmarks_file) 36 | 37 | for image_info in tqdm(data): 38 | img_name = image_info['path'] 39 | img_path = osp.join(self.images_root_path, img_name) 40 | landmarks = image_info['lm'] 41 | samples.append((img_path, landmarks)) 42 | 43 | return samples 44 | 45 | def __len__(self): 46 | """Returns total number of samples""" 47 | return len(self.samples_info) 48 | 49 | def __getitem__(self, idx): 50 | """Returns sample (image, landmarks) by index""" 51 | img = cv.imread(self.samples_info[idx][0], cv.IMREAD_COLOR) 52 | landmarks = self.samples_info[idx][1] 53 | width, height = img.shape[1], img.shape[0] 54 | landmarks = np.array([(float(landmarks[i][0]) / width, 55 | float(landmarks[i][1]) / height) for i in range(len(landmarks))]).reshape(-1) 56 | data = {'img': img, 'landmarks': landmarks} 57 | if self.transform: 58 | data = self.transform(data) 59 | return data 60 | -------------------------------------------------------------------------------- /datasets/trillion_pairs.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2018 Intel Corporation 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | http://www.apache.org/licenses/LICENSE-2.0 7 | Unless required by applicable law or agreed to in writing, software 8 | distributed under the License is distributed on an "AS IS" BASIS, 9 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 10 | See the License for the specific language governing permissions and 11 | limitations under the License. 12 | """ 13 | 14 | import os.path as osp 15 | 16 | import cv2 as cv 17 | from tqdm import tqdm 18 | from torch.utils.data import Dataset 19 | 20 | from utils.face_align import FivePointsAligner 21 | 22 | 23 | class TrillionPairs(Dataset): 24 | """TrillionPairs Dataset compatible with PyTorch DataLoader. 
For details visit http://trillionpairs.deepglint.com/data""" 25 | def __init__(self, images_root_path, image_list_path, test_mode=False, transform=None): 26 | self.image_list_path = image_list_path 27 | self.images_root_path = images_root_path 28 | self.test_mode = test_mode 29 | self.identities = {} 30 | 31 | assert osp.isfile(image_list_path) 32 | self.have_landmarks = True 33 | 34 | self.all_samples_info = self._read_samples_info() 35 | self.samples_info = self.all_samples_info 36 | self.transform = transform 37 | 38 | def _read_samples_info(self): 39 | """Reads annotation of the dataset""" 40 | samples = [] 41 | 42 | with open(self.image_list_path, 'r') as f: 43 | images_file_lines = f.readlines() 44 | 45 | for i in tqdm(range(len(images_file_lines))): 46 | line = images_file_lines[i].strip() 47 | terms = line.split(' ') 48 | path = terms[0] 49 | if not self.test_mode: 50 | label = int(terms[1]) 51 | landmarks = terms[2:] 52 | if label in self.identities: 53 | self.identities[label].append(len(samples)) 54 | else: 55 | self.identities[label] = [len(samples)] 56 | else: 57 | label = 0 58 | landmarks = terms[1:] 59 | 60 | landmarks = [float(coord) for coord in landmarks] 61 | assert(len(landmarks) == 10) 62 | samples.append((osp.join(self.images_root_path, path).strip(), 63 | label, landmarks)) 64 | 65 | return samples 66 | 67 | def get_weights(self): 68 | """Computes weights of the each identity in dataset according to frequency of it's occurance""" 69 | weights = [0.]*len(self.all_samples_info) 70 | for i, sample in enumerate(self.all_samples_info): 71 | weights[i] = float(len(self.all_samples_info)) / len(self.identities[sample[1]]) 72 | return weights 73 | 74 | def get_num_classes(self): 75 | """Returns total number of identities""" 76 | return len(self.identities) 77 | 78 | def __len__(self): 79 | """Returns total number of samples""" 80 | return len(self.samples_info) 81 | 82 | def __getitem__(self, idx): 83 | """Returns sample (image, class id, image id) by index""" 84 | img = cv.imread(self.samples_info[idx][0], cv.IMREAD_COLOR) 85 | landmarks = self.samples_info[idx][-1] 86 | 87 | img = FivePointsAligner.align(img, landmarks, d_size=(200, 200), normalized=False, show=False) 88 | 89 | if self.transform: 90 | img = self.transform(img) 91 | 92 | return {'img': img, 'label': self.samples_info[idx][1], 'idx': idx} 93 | -------------------------------------------------------------------------------- /datasets/vggface2.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2018 Intel Corporation 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | http://www.apache.org/licenses/LICENSE-2.0 7 | Unless required by applicable law or agreed to in writing, software 8 | distributed under the License is distributed on an "AS IS" BASIS, 9 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 10 | See the License for the specific language governing permissions and 11 | limitations under the License. 
12 | """ 13 | 14 | import os.path as osp 15 | import cv2 as cv 16 | from tqdm import tqdm 17 | import numpy as np 18 | from torch.utils.data import Dataset 19 | 20 | from utils.face_align import FivePointsAligner 21 | 22 | 23 | class VGGFace2(Dataset): 24 | """VGGFace2 Dataset compatible with PyTorch DataLoader.""" 25 | def __init__(self, images_root_path, image_list_path, landmarks_folder_path='', 26 | transform=None, landmarks_training=False): 27 | self.image_list_path = image_list_path 28 | self.images_root_path = images_root_path 29 | self.identities = {} 30 | 31 | self.landmarks_file = None 32 | self.detections_file = None 33 | if osp.isdir(landmarks_folder_path): 34 | if 'train' in image_list_path: 35 | bb_file_name = 'loose_landmark_train.csv' 36 | landmarks_file_name = 'loose_bb_train.csv' 37 | elif 'test' in image_list_path: 38 | bb_file_name = 'loose_landmark_test.csv' 39 | landmarks_file_name = 'loose_bb_test.csv' 40 | else: 41 | bb_file_name = 'loose_landmark_all.csv' 42 | landmarks_file_name = 'loose_bb_all.csv' 43 | self.landmarks_file = open(osp.join(landmarks_folder_path, bb_file_name), 'r') 44 | self.detections_file = open(osp.join(landmarks_folder_path, landmarks_file_name), 'r') 45 | self.have_landmarks = not self.landmarks_file is None 46 | self.landmarks_training = landmarks_training 47 | if self.landmarks_training: 48 | assert self.have_landmarks is True 49 | 50 | self.samples_info = self._read_samples_info() 51 | 52 | self.transform = transform 53 | 54 | def _read_samples_info(self): 55 | """Reads annotation of the dataset""" 56 | samples = [] 57 | 58 | with open(self.image_list_path, 'r') as f: 59 | last_class_id = -1 60 | images_file_lines = f.readlines() 61 | 62 | if self.have_landmarks: 63 | detections_file_lines = self.detections_file.readlines()[1:] 64 | landmarks_file_lines = self.landmarks_file.readlines()[1:] 65 | assert len(detections_file_lines) == len(landmarks_file_lines) 66 | assert len(images_file_lines) == len(detections_file_lines) 67 | 68 | for i in tqdm(range(len(images_file_lines))): 69 | sample = images_file_lines[i].strip() 70 | sample_id = int(sample.split('/')[0][1:]) 71 | frame_id = int(sample.split('/')[1].split('_')[0]) 72 | if sample_id in self.identities: 73 | self.identities[sample_id].append(len(samples)) 74 | else: 75 | last_class_id += 1 76 | self.identities[sample_id] = [len(samples)] 77 | if not self.have_landmarks: 78 | samples.append((osp.join(self.images_root_path, sample), last_class_id, frame_id)) 79 | else: 80 | _, bbox = detections_file_lines[i].split('",') 81 | bbox = [max(int(coord), 0) for coord in bbox.split(',')] 82 | _, landmarks = landmarks_file_lines[i].split('",') 83 | landmarks = [float(coord) for coord in landmarks.split(',')] 84 | samples.append((osp.join(self.images_root_path, sample), last_class_id, sample_id, bbox, landmarks)) 85 | 86 | return samples 87 | 88 | def get_weights(self): 89 | """Computes weights of the each identity in dataset according to frequency of it's occurance""" 90 | weights = [0.]*len(self.samples_info) 91 | for i, sample in enumerate(self.samples_info): 92 | weights[i] = len(self.samples_info) / float(len(self.identities[sample[2]])) 93 | 94 | return weights 95 | 96 | def get_num_classes(self): 97 | """Returns total number of identities""" 98 | return len(self.identities) 99 | 100 | def __len__(self): 101 | """Returns total number of samples""" 102 | return len(self.samples_info) 103 | 104 | def __getitem__(self, idx): 105 | """Returns sample (image, class id, image id) by index""" 106 | 
img = cv.imread(self.samples_info[idx][0], cv.IMREAD_COLOR) 107 | if self.landmarks_training: 108 | landmarks = self.samples_info[idx][-1] 109 | bbox = self.samples_info[idx][-2] 110 | img = img[bbox[1]:bbox[1] + bbox[3], bbox[0]:bbox[0] + bbox[2]] 111 | landmarks = [(float(landmarks[2*i]-bbox[0]) / bbox[2], 112 | float(landmarks[2*i + 1]-bbox[1])/ bbox[3]) for i in range(len(landmarks)//2)] 113 | data = {'img': img, 'landmarks': np.array(landmarks)} 114 | if self.transform: 115 | data = self.transform(data) 116 | return data 117 | 118 | if self.have_landmarks: 119 | landmarks = self.samples_info[idx][-1] 120 | img = FivePointsAligner.align(img, landmarks, d_size=(200, 200), normalized=False) 121 | 122 | if self.transform: 123 | img = self.transform(img) 124 | 125 | return {'img': img, 'label': self.samples_info[idx][1], 'instance': self.samples_info[idx][2]} 126 | -------------------------------------------------------------------------------- /demo/README.md: -------------------------------------------------------------------------------- 1 | # Face Recognition demo with [OpenVINO™ Toolkit](https://software.intel.com/en-us/openvino-toolkit) 2 | 3 | ![](./demo.png) 4 | 5 | ## Demo Preparation 6 | 7 | 1. Install **OpenVINO Toolkit** - [Linux installation guide](https://software.intel.com/en-us/articles/OpenVINO-Install-Linux) 8 | 9 | 2. Create virtual python environment: 10 | ```bash 11 | mkvirtualenv fr --python=python3 12 | ``` 13 | 3. Install dependencies: 14 | ```bash 15 | pip install -r requirements.txt 16 | ``` 17 | 4. Initialize OpenVINO environment: 18 | ```bash 19 | source /opt/intel/computer_vision_sdk/bin/setupvars.sh 20 | ``` 21 | 22 | ## Deep Face Recognition 23 | 1. Set up `PATH_TO_GALLERY` variable to point to folder with gallery images (faces to be recognized): 24 | ```bash 25 | export PATH_TO_GALLERY=/path/to/gallery/with/images/ 26 | ``` 27 | 2. For using OpenVINO pretrained models, please specify `IR_MODELS_ROOT`, otherwise you need to modify running command. 28 | ```bash 29 | export IR_MODELS_ROOT=$INTEL_CVSDK_DIR/deployment_tools/intel_models/ 30 | ``` 31 | 3. If you are running from pure console, you need to specify `PYTHONPATH` variable: 32 | ```bash 33 | export PYTHONPATH=`pwd`:$PYTHONPATH 34 | ``` 35 | 4. Run Face Recognition demo: 36 | ```bash 37 | python demo/run_demo.py --path_to_gallery $PATH_TO_GALLERY --cam_id 0 \ 38 | --fd_model $IR_MODELS_ROOT/face-detection-retail-0004/FP32/face-detection-retail-0004.xml \ 39 | --fr_model $IR_MODELS_ROOT/face-reidentification-retail-0095/FP32/face-reidentification-retail-0095.xml \ 40 | --ld_model $IR_MODELS_ROOT/landmarks-regression-retail-0009/FP32/landmarks-regression-retail-0009.xml \ 41 | -l libcpu_extension_avx2.so 42 | ``` 43 | *Note:* `libcpu_extension_avx2.so` is located at the `$INTEL_CVSDK_DIR/inference_engine/lib//intel64/` folder. 44 | Here the `` is a name detected by the OpenVINO. It can be for example `ubuntu_16.04` if you are running the demo under Ubuntu 16.04 system. The folder with CPU extensions is already in `LD_LIBRARY_PATH` after initialization of the OpenVINO environment, that's why it can be omitted in the launch command. 
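Before the demo sources that follow, here is a distilled illustration of the matching rule behind `--fr_thresh`: each probe embedding is compared to every gallery embedding by cosine distance, and the closest gallery entry is accepted only when that distance falls below the threshold (cf. `find_nearest` and `match_embeddings` in `demo/run_demo.py`). The `match` helper, the names, and the random 256-dimensional vectors below are toy stand-ins, not part of the repository.

```python
# Minimal sketch of gallery matching by cosine distance; data here is synthetic.
import numpy as np
from scipy.spatial.distance import cosine


def match(probe, gallery, thr=0.6):
    """Returns the closest gallery name, or 'Unknown' if no distance beats thr."""
    names = list(gallery.keys())
    dists = np.array([cosine(probe, gallery[name]) for name in names])
    best = int(dists.argmin())
    return names[best] if dists[best] < thr else 'Unknown'


rng = np.random.RandomState(0)
gallery = {'person_a': rng.rand(256), 'person_b': rng.rand(256)}
probe = gallery['person_a'] + 0.01 * rng.rand(256)   # nearly identical embedding
print(match(probe, gallery, thr=0.6))                # -> 'person_a'
```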
45 | -------------------------------------------------------------------------------- /demo/demo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/grib0ed0v/face_recognition.pytorch/05cb9b30e8220445fcb27988926d88f330091c12/demo/demo.png -------------------------------------------------------------------------------- /demo/run_demo.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2018 Intel Corporation 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | http://www.apache.org/licenses/LICENSE-2.0 7 | Unless required by applicable law or agreed to in writing, software 8 | distributed under the License is distributed on an "AS IS" BASIS, 9 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 10 | See the License for the specific language governing permissions and 11 | limitations under the License. 12 | """ 13 | 14 | import argparse 15 | import os 16 | import os.path as osp 17 | 18 | import glog as log 19 | import cv2 as cv 20 | import numpy as np 21 | from scipy.spatial.distance import cosine 22 | 23 | from utils import face_align 24 | from utils.ie_tools import load_ie_model 25 | 26 | 27 | class FaceDetector: 28 | """Wrapper class for face detector""" 29 | def __init__(self, model_path, conf=.6, device='CPU', ext_path=''): 30 | self.net = load_ie_model(model_path, device, None, ext_path) 31 | self.confidence = conf 32 | self.expand_ratio = (1.1, 1.05) 33 | 34 | def get_detections(self, frame): 35 | """Returns all detections on frame""" 36 | _, _, h, w = self.net.get_input_shape().shape 37 | out = self.net.forward(cv.resize(frame, (w, h))) 38 | detections = self.__decode_detections(out, frame.shape) 39 | return detections 40 | 41 | def __decode_detections(self, out, frame_shape): 42 | """Decodes raw SSD output""" 43 | detections = [] 44 | 45 | for detection in out[0, 0]: 46 | confidence = detection[2] 47 | if confidence > self.confidence: 48 | left = int(max(detection[3], 0) * frame_shape[1]) 49 | top = int(max(detection[4], 0) * frame_shape[0]) 50 | right = int(max(detection[5], 0) * frame_shape[1]) 51 | bottom = int(max(detection[6], 0) * frame_shape[0]) 52 | if self.expand_ratio != (1., 1.): 53 | w = (right - left) 54 | h = (bottom - top) 55 | dw = w * (self.expand_ratio[0] - 1.) / 2 56 | dh = h * (self.expand_ratio[1] - 1.) 
/ 2 57 | left = max(int(left - dw), 0) 58 | right = int(right + dw) 59 | top = max(int(top - dh), 0) 60 | bottom = int(bottom + dh) 61 | 62 | detections.append(((left, top, right, bottom), confidence)) 63 | 64 | if len(detections) > 1: 65 | detections.sort(key=lambda x: x[1], reverse=True) 66 | 67 | return detections 68 | 69 | 70 | class VectorCNN: 71 | """Wrapper class for a nework returning a vector""" 72 | def __init__(self, model_path, device='CPU'): 73 | self.net = load_ie_model(model_path, device, None) 74 | 75 | def forward(self, batch): 76 | """Performs forward of the underlying network on a given batch""" 77 | _, _, h, w = self.net.get_input_shape().shape 78 | outputs = [self.net.forward(cv.resize(frame, (w, h))) for frame in batch] 79 | return outputs 80 | 81 | 82 | def get_embeddings(frame, detections, face_reid, landmarks_predictor): 83 | """Get embeddings for all detected faces on the frame""" 84 | rois = [] 85 | embeddings = [] 86 | for rect, _ in detections: 87 | left, top, right, bottom = rect 88 | rois.append(frame[top:bottom, left:right]) 89 | 90 | if rois: 91 | landmarks = landmarks_predictor.forward(rois) 92 | assert len(landmarks) == len(rois) 93 | 94 | for i, _ in enumerate(rois): 95 | roi_keypoints = landmarks[i].reshape(-1) 96 | rois[i] = face_align.FivePointsAligner.align(rois[i], roi_keypoints, 97 | d_size=(rois[i].shape[1], rois[i].shape[0]), 98 | normalized=True, show=False) 99 | embeddings = face_reid.forward(rois) 100 | assert len(rois) == len(embeddings) 101 | 102 | return embeddings 103 | 104 | 105 | def find_nearest(x, gallery, thr): 106 | """Finds the nearest to a given embedding in the gallery""" 107 | if gallery: 108 | diffs = np.array([cosine(x, y) for y in gallery.values()]) 109 | min_pos = diffs.argmin() 110 | min_dist = diffs[min_pos] 111 | if min_dist < thr: 112 | return min_pos, list(gallery.keys())[min_pos] 113 | return None, None 114 | 115 | 116 | def match_embeddings(embeds, gallery, thr): 117 | """Matches input embeddings with ones in the gallery""" 118 | indexes = [] 119 | for emb in embeds: 120 | _, name = find_nearest(emb, gallery, thr) 121 | if name is not None: 122 | indexes.append(name) 123 | else: 124 | indexes.append('Unknown') 125 | 126 | return indexes, gallery 127 | 128 | 129 | def draw_detections(frame, detections, indexes): 130 | """Draws detections and labels""" 131 | for i, rect in enumerate(detections): 132 | left, top, right, bottom = rect[0] 133 | cv.rectangle(frame, (left, top), (right, bottom), (0, 255, 0), thickness=2) 134 | label = str(indexes[i]) 135 | label_size, base_line = cv.getTextSize(label, cv.FONT_HERSHEY_SIMPLEX, 1, 1) 136 | top = max(top, label_size[1]) 137 | cv.rectangle(frame, (left, top - label_size[1]), (left + label_size[0], top + base_line), 138 | (255, 255, 255), cv.FILLED) 139 | cv.putText(frame, label, (left, top), cv.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 0)) 140 | 141 | return frame 142 | 143 | 144 | def load_gallery(path_to_gallery, face_det, landmarks_detector, face_recognizer): 145 | """Computes embeddings for gallery""" 146 | gallery = {} 147 | files = os.listdir(path_to_gallery) 148 | files = [file for file in files if file.endswith('.png') or file.endswith('.jpg')] 149 | for file in files: 150 | img = cv.imread(osp.join(path_to_gallery, file)) 151 | detections = face_det.get_detections(img) 152 | 153 | if not detections: 154 | detections = [[0, 0, img.shape[0], img.shape[1]], 0] 155 | log.warn('Warning: failed to detect face on the image ' + file) 156 | 157 | embed = get_embeddings(img, detections, 
face_recognizer, landmarks_detector) 158 | gallery[file.replace('.png', '').replace('.jpg', '')] = embed[0] 159 | return gallery 160 | 161 | 162 | def run(params, capture, face_det, face_recognizer, landmarks_detector): 163 | """Starts the face recognition demo""" 164 | win_name = 'Deep Face Recognition' 165 | gallery = load_gallery(params.path_to_gallery, face_det, landmarks_detector, face_recognizer) 166 | 167 | while cv.waitKey(1) != 27: 168 | has_frame, frame = capture.read() 169 | if not has_frame: 170 | return 171 | 172 | detections = face_det.get_detections(frame) 173 | embeds = get_embeddings(frame, detections, face_recognizer, landmarks_detector) 174 | ids, gallery = match_embeddings(embeds, gallery, params.fr_thresh) 175 | frame = draw_detections(frame, detections, ids) 176 | cv.imshow(win_name, frame) 177 | 178 | def main(): 179 | """Prepares data for the face recognition demo""" 180 | parser = argparse.ArgumentParser(description='Face recognition live demo script') 181 | parser.add_argument('--video', type=str, default=None, help='Input video') 182 | parser.add_argument('--cam_id', type=int, default=-1, help='Input cam') 183 | 184 | parser.add_argument('--fd_model', type=str, required=True) 185 | parser.add_argument('--fd_thresh', type=float, default=0.6, help='Threshold for FD') 186 | 187 | parser.add_argument('--fr_model', type=str, required=True) 188 | parser.add_argument('--fr_thresh', type=float, default=0.6, help='Threshold for FR') 189 | 190 | parser.add_argument('--path_to_gallery', type=str, required=True, help='Path to gallery with subjects') 191 | 192 | parser.add_argument('--ld_model', type=str, default='', help='Path to a snapshots with landmarks detection model') 193 | 194 | parser.add_argument('--device', type=str, default='CPU') 195 | parser.add_argument('-l', '--cpu_extension', 196 | help='MKLDNN (CPU)-targeted custom layers.Absolute path to a shared library with the kernels ' 197 | 'impl.', type=str, default=None) 198 | 199 | args = parser.parse_args() 200 | 201 | if args.cam_id >= 0: 202 | log.info('Reading from cam {}'.format(args.cam_id)) 203 | cap = cv.VideoCapture(args.cam_id) 204 | cap.set(cv.CAP_PROP_FRAME_WIDTH, 1280) 205 | cap.set(cv.CAP_PROP_FRAME_HEIGHT, 720) 206 | cap.set(cv.CAP_PROP_FOURCC, cv.VideoWriter_fourcc('M', 'J', 'P', 'G')) 207 | else: 208 | assert args.video 209 | log.info('Reading from {}'.format(args.video)) 210 | cap = cv.VideoCapture(args.video) 211 | assert cap.isOpened() 212 | 213 | face_detector = FaceDetector(args.fd_model, args.fd_thresh, args.device, args.cpu_extension) 214 | face_recognizer = VectorCNN(args.fr_model, args.device) 215 | landmarks_detector = VectorCNN(args.ld_model, args.device) 216 | run(args, cap, face_detector, face_recognizer, landmarks_detector) 217 | 218 | if __name__ == '__main__': 219 | main() 220 | -------------------------------------------------------------------------------- /devtools/pylint.rc: -------------------------------------------------------------------------------- 1 | [MASTER] 2 | 3 | # Specify a configuration file. 4 | #rcfile= 5 | 6 | # Python code to execute, usually for sys.path manipulation such as 7 | # pygtk.require(). 8 | #init-hook= 9 | 10 | # Profiled execution. 11 | profile=no 12 | 13 | # Add to the black list. It should be a base name, not a 14 | # path. You may set this option multiple times. 15 | ignore=CVS 16 | 17 | # Pickle collected data for later comparisons. 
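# A minimal, self-contained sketch of the gallery matching logic used by run_demo.py
# above (find_nearest / match_embeddings): embeddings are compared with cosine distance
# and a match is accepted only below a threshold. The vectors and the 0.6 threshold
# here are made-up illustration values, not outputs of the actual models.
import numpy as np
from scipy.spatial.distance import cosine

gallery = {'alice': np.array([1.0, 0.0, 0.0]), 'bob': np.array([0.0, 1.0, 0.0])}
query = np.array([0.9, 0.1, 0.0])

distances = {name: cosine(query, emb) for name, emb in gallery.items()}
best_name, best_dist = min(distances.items(), key=lambda kv: kv[1])
label = best_name if best_dist < 0.6 else 'Unknown'
print(label, best_dist)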
18 | persistent=yes 19 | 20 | # List of plugins (as comma separated values of python modules names) to load, 21 | # usually to register additional checkers. 22 | load-plugins= 23 | 24 | 25 | [MESSAGES CONTROL] 26 | 27 | # Enable the message, report, category or checker with the given id(s). You can 28 | # either give multiple identifier separated by comma (,) or put this option 29 | # multiple time. 30 | #enable= 31 | 32 | # Disable the message, report, category or checker with the given id(s). You 33 | # can either give multiple identifier separated by comma (,) or put this option 34 | # multiple time. 35 | disable=R0903, W0221 36 | 37 | 38 | [REPORTS] 39 | 40 | # Set the output format. Available formats are text, parseable, colorized, msvs 41 | # (visual studio) and html 42 | output-format=text 43 | 44 | # Include message's id in output 45 | include-ids=no 46 | 47 | # Put messages in a separate file for each module / package specified on the 48 | # command line instead of printing them on stdout. Reports (if any) will be 49 | # written in a file name "pylint_global.[txt|html]". 50 | files-output=no 51 | 52 | # Tells whether to display a full report or only the messages 53 | reports=yes 54 | 55 | # Python expression which should return a note less than 10 (10 is the highest 56 | # note). You have access to the variables errors warning, statement which 57 | # respectively contain the number of errors / warnings messages and the total 58 | # number of statements analyzed. This is used by the global evaluation report 59 | # (R0004). 60 | evaluation=10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10) 61 | 62 | # Add a comment according to your evaluation note. This is used by the global 63 | # evaluation report (R0004). 64 | comment=no 65 | 66 | 67 | [VARIABLES] 68 | 69 | # Tells whether we should check for unused import in __init__ files. 70 | init-import=no 71 | 72 | # A regular expression matching names used for dummy variables (i.e. not used). 73 | dummy-variables-rgx=_|dummy 74 | 75 | # List of additional names supposed to be defined in builtins. Remember that 76 | # you should avoid to define new builtins when possible. 
77 | additional-builtins= 78 | 79 | 80 | [BASIC] 81 | 82 | # Required attributes for module, separated by a comma 83 | required-attributes= 84 | 85 | # List of builtins function names that should not be used, separated by a comma 86 | bad-functions=map,filter,apply,input 87 | 88 | # Regular expression which should only match correct module names 89 | module-rgx=(([a-z_][a-z0-9_]*)|([A-Z][a-zA-Z0-9]+))$ 90 | 91 | # Regular expression which should only match correct module level names 92 | const-rgx=(([A-Z_][A-Z0-9_]*)|(__.*__))$ 93 | 94 | # Regular expression which should only match correct class names 95 | class-rgx=[A-Z_][a-zA-Z0-9]+$ 96 | 97 | # Regular expression which should only match correct function names 98 | function-rgx=[a-z_][a-z0-9_]{2,40}$ 99 | 100 | # Regular expression which should only match correct method names 101 | method-rgx=[a-z_][a-z0-9_]{2,30}$ 102 | 103 | # Regular expression which should only match correct instance attribute names 104 | attr-rgx=[a-z_][a-z0-9_]{0,30}$ 105 | 106 | # Regular expression which should only match correct argument names 107 | argument-rgx=[a-z_][a-z0-9_]{0,30}$ 108 | 109 | # Regular expression which should only match correct variable names 110 | variable-rgx=[a-z_][a-z0-9_]{0,30}$ 111 | 112 | # Regular expression which should only match correct list comprehension / 113 | # generator expression variable names 114 | inlinevar-rgx=[A-Za-z_][A-Za-z0-9_]*$ 115 | 116 | # Good variable names which should always be accepted, separated by a comma 117 | good-names=i,j,k,ex,Run,_ 118 | 119 | # Bad variable names which should always be refused, separated by a comma 120 | bad-names=foo,bar,baz,toto,tutu,tata 121 | 122 | # Regular expression which should only match functions or classes name which do 123 | # not require a docstring 124 | no-docstring-rgx=__.*__ 125 | 126 | 127 | [MISCELLANEOUS] 128 | 129 | # List of note tags to take in consideration, separated by a comma. 130 | notes=FIXME,XXX,TODO 131 | 132 | 133 | [FORMAT] 134 | 135 | # Maximum number of characters on a single line. 136 | max-line-length=120 137 | 138 | # Maximum number of lines in a module 139 | max-module-lines=1000 140 | 141 | # String used as indentation unit. This is usually " " (4 spaces) or "\t" (1 142 | # tab). 143 | indent-string=' ' 144 | indent-after-paren=4 145 | 146 | 147 | [SIMILARITIES] 148 | 149 | # Minimum lines number of a similarity. 150 | min-similarity-lines=4 151 | 152 | # Ignore comments when computing similarities. 153 | ignore-comments=yes 154 | 155 | # Ignore docstrings when computing similarities. 156 | ignore-docstrings=yes 157 | 158 | 159 | [TYPECHECK] 160 | 161 | # Tells whether missing members accessed in mixin class should be ignored. A 162 | # mixin class is detected if its name ends with "mixin" (case insensitive). 163 | ignore-mixin-members=yes 164 | 165 | # List of classes names for which member attributes should not be checked 166 | # (useful for classes with attributes dynamically set). 167 | ignored-classes=SQLObject 168 | 169 | # When zope mode is activated, add a predefined set of Zope acquired attributes 170 | # to generated-members. 171 | zope=no 172 | 173 | # List of members which are set dynamically and missed by pylint inference 174 | # system, and so shouldn't trigger E0201 when accessed. 175 | generated-members=REQUEST,acl_users,aq_parent,torch,cv 176 | 177 | 178 | [DESIGN] 179 | 180 | # Maximum number of arguments for function / method 181 | max-args=5 182 | 183 | # Argument names that match this expression will be ignored. 
Default to name 184 | # with leading underscore 185 | ignored-argument-names=_.* 186 | 187 | # Maximum number of locals for function / method body 188 | max-locals=15 189 | 190 | # Maximum number of return / yield for function / method body 191 | max-returns=6 192 | 193 | # Maximum number of branch for function / method body 194 | max-branchs=12 195 | 196 | # Maximum number of statements in function / method body 197 | max-statements=50 198 | 199 | # Maximum number of parents for a class (see R0901). 200 | max-parents=7 201 | 202 | # Maximum number of attributes for a class (see R0902). 203 | max-attributes=7 204 | 205 | # Minimum number of public methods for a class (see R0903). 206 | min-public-methods=2 207 | 208 | # Maximum number of public methods for a class (see R0904). 209 | max-public-methods=20 210 | 211 | 212 | [IMPORTS] 213 | 214 | # Deprecated modules which should not be used, separated by a comma 215 | deprecated-modules=regsub,string,TERMIOS,Bastion,rexec 216 | 217 | # Create a graph of every (i.e. internal and external) dependencies in the 218 | # given file (report RP0402 must not be disabled) 219 | import-graph= 220 | 221 | # Create a graph of external dependencies in the given file (report RP0402 must 222 | # not be disabled) 223 | ext-import-graph= 224 | 225 | # Create a graph of internal dependencies in the given file (report RP0402 must 226 | # not be disabled) 227 | int-import-graph= 228 | 229 | extension-pkg-whitelist=cv2 230 | 231 | [CLASSES] 232 | 233 | # List of interface methods to ignore, separated by a comma. This is used for 234 | # instance to not check methods defines in Zope's Interface base class. 235 | ignore-iface-methods=isImplementedBy,deferred,extends,names,namesAndDescriptions,queryDescriptionFor,getBases,getDescriptionFor,getDoc,getName,getTaggedValue,getTaggedValueTags,isEqualOrExtendedBy,setTaggedValue,isImplementedByInstancesOf,adaptWith,is_implemented_by 236 | 237 | # List of method names used to declare (i.e. assign) instance attributes. 238 | defining-attr-methods=__init__,__new__,setUp 239 | -------------------------------------------------------------------------------- /dump_features.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2018 Intel Corporation 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | http://www.apache.org/licenses/LICENSE-2.0 7 | Unless required by applicable law or agreed to in writing, software 8 | distributed under the License is distributed on an "AS IS" BASIS, 9 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 10 | See the License for the specific language governing permissions and 11 | limitations under the License. 
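# The [MESSAGES CONTROL], naming and [DESIGN] limits above are consumed by pylint through
# its --rcfile option. A small, hypothetical helper for running the same check locally
# (the target list is only an example; point it at whatever modules you want to lint):
import subprocess

def run_pylint(targets=('train.py', 'utils')):
    cmd = ['pylint', '--rcfile=devtools/pylint.rc', *targets]
    return subprocess.call(cmd)  # returns pylint's bit-encoded exit status

if __name__ == '__main__':
    run_pylint()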
12 | """ 13 | 14 | import sys 15 | import argparse 16 | import os 17 | import os.path as osp 18 | 19 | from tqdm import tqdm 20 | import numpy as np 21 | import glog as log 22 | import torch 23 | import torch.nn.functional as F 24 | from torch.utils.data import DataLoader 25 | from torchvision import transforms as t 26 | 27 | from scripts.matio import save_mat 28 | from model.common import models_backbones 29 | from datasets.megaface import MegaFace 30 | from datasets.trillion_pairs import TrillionPairs 31 | from utils.utils import load_model_state 32 | from utils.augmentation import ResizeNumpy, NumpyToTensor 33 | 34 | 35 | def clean_megaface(filenames, features, noises_list_path): 36 | """Filters megaface from outliers""" 37 | with open(noises_list_path, 'r') as f: 38 | noises_list = f.readlines() 39 | noises_list = [line.strip() for line in noises_list] 40 | clean_features = np.zeros((features.shape[0], features.shape[1] + 1), dtype=np.float32) 41 | 42 | for i, filename in enumerate(tqdm(filenames)): 43 | clean_features[i, 0: features.shape[1]] = features[i, :] 44 | for line in noises_list: 45 | if line in filename: 46 | clean_features[i, features.shape[1]] = 100.0 47 | break 48 | 49 | return clean_features 50 | 51 | 52 | def clean_facescrub(filenames, features, noises_list_path): 53 | """Replaces wrong instances of identities from the Facescrub with the centroids of these identities""" 54 | clean_feature_size = features.shape[1] + 1 55 | with open(noises_list_path, 'r') as f: 56 | noises_list = f.readlines() 57 | noises_list = [osp.splitext(line.strip())[0] for line in noises_list] 58 | clean_features = np.zeros((features.shape[0], clean_feature_size), dtype=np.float32) 59 | 60 | centroids = {} 61 | for i, filename in enumerate(tqdm(filenames)): 62 | clean_features[i, 0: features.shape[1]] = features[i, :] 63 | id_name = osp.basename(filename).split('_')[0] 64 | if not id_name in centroids: 65 | centroids[id_name] = np.zeros(clean_feature_size, dtype=np.float32) 66 | centroids[id_name] += clean_features[i, :] 67 | 68 | for i, file_path in enumerate(tqdm(filenames)): 69 | filename = osp.basename(file_path) 70 | for line in noises_list: 71 | if line in filename.replace(' ', '_'): 72 | id_name = filename.split('_')[0] 73 | clean_features[i, :] = centroids[id_name] + np.random.uniform(-0.001, 0.001, clean_feature_size) 74 | clean_features[i, :] /= np.linalg.norm(clean_features[i, :]) 75 | break 76 | 77 | return clean_features 78 | 79 | 80 | @torch.no_grad() 81 | def main(args): 82 | input_filenames = [] 83 | output_filenames = [] 84 | input_dir = os.path.abspath(args.input_dir) 85 | output_dir = os.path.abspath(args.output_dir) 86 | 87 | if not args.trillion_format: 88 | log.info('Reading info...') 89 | with open(os.path.join(args.input_dir, os.path.basename(args.input_list)), 'r') as f: 90 | lines = f.readlines() 91 | 92 | for line in tqdm(lines): 93 | info = line.strip().split('|') 94 | file = info[0].strip() 95 | filename = os.path.join(input_dir, file) 96 | 97 | path, _ = osp.split(filename) 98 | out_folder = path.replace(input_dir, output_dir) 99 | if not osp.isdir(out_folder): 100 | os.makedirs(out_folder) 101 | 102 | landmarks = None 103 | bbox = None 104 | 105 | if len(info) > 2: 106 | landmarks = info[1].strip().split(' ') 107 | landmarks = [float(x) for x in landmarks] 108 | bbox = info[2].strip().split(' ') 109 | bbox = [int(float(x)) for x in bbox] 110 | outname = filename.replace(input_dir, output_dir) + args.file_ending 111 | input_filenames.append({'path': filename, 
'landmarks': landmarks, 'bbox': bbox}) 112 | output_filenames += [outname] 113 | 114 | nrof_images = len(input_filenames) 115 | log.info("Total number of images: ", nrof_images) 116 | dataset = MegaFace(input_filenames) 117 | else: 118 | dataset = TrillionPairs(args.input_dir, osp.join(args.input_dir, 'testdata_lmk.txt'), test_mode=True) 119 | nrof_images = len(dataset) 120 | 121 | emb_array = np.zeros((nrof_images, args.embedding_size), dtype=np.float32) 122 | 123 | dataset.transform = t.Compose([ResizeNumpy(models_backbones[args.model].get_input_res()), 124 | NumpyToTensor(switch_rb=True)]) 125 | val_loader = DataLoader(dataset, batch_size=args.batch_size, num_workers=5, shuffle=False) 126 | 127 | model = models_backbones[args.model](embedding_size=args.embedding_size, feature=True) 128 | assert args.snap is not None 129 | log.info('Snapshot ' + args.snap + ' ...') 130 | log.info('Extracting embeddings ...') 131 | model = load_model_state(model, args.snap, args.devices[0], eval_state=True) 132 | model = torch.nn.DataParallel(model, device_ids=args.devices, output_device=args.devices[0]) 133 | 134 | f_output_filenames = [] 135 | 136 | with torch.cuda.device(args.devices[0]): 137 | for i, data in enumerate(tqdm(val_loader), 0): 138 | idxs, imgs = data['idx'], data['img'] 139 | batch_embeddings = F.normalize(model(imgs), p=2, dim=1).data.cpu().numpy() 140 | batch_embeddings = batch_embeddings.reshape(batch_embeddings.shape[0], -1) 141 | path_indices = idxs.data.cpu().numpy() 142 | 143 | start_index = i*args.batch_size 144 | end_index = min((i+1)*args.batch_size, nrof_images) 145 | assert start_index == path_indices[0] 146 | assert end_index == path_indices[-1] + 1 147 | assert emb_array[start_index:end_index, :].shape == batch_embeddings.shape 148 | emb_array[start_index:end_index, :] = batch_embeddings 149 | 150 | if not args.trillion_format: 151 | for index in path_indices: 152 | f_output_filenames.append(output_filenames[index]) 153 | 154 | assert len(output_filenames) == len(output_filenames) 155 | 156 | log.info('Extracting features Done.') 157 | 158 | if args.trillion_format: 159 | save_mat(args.file_ending, emb_array) 160 | else: 161 | if 'megaface_noises.txt' in args.noises_list: 162 | log.info('Cleaning Megaface features') 163 | emb_array = clean_megaface(f_output_filenames, emb_array, args.noises_list) 164 | elif 'facescrub_noises.txt' in args.noises_list: 165 | log.info('Cleaning Facescrub features') 166 | emb_array = clean_facescrub(f_output_filenames, emb_array, args.noises_list) 167 | else: 168 | log.info('Megaface features are not cleaned up.') 169 | 170 | log.info('Saving features to files...') 171 | for i in tqdm(range(len(f_output_filenames))): 172 | save_mat(f_output_filenames[i], emb_array[i, :]) 173 | 174 | 175 | def parse_argument(argv): 176 | parser = argparse.ArgumentParser(description='Save embeddings to MegaFace features files') 177 | parser.add_argument('--model', choices=models_backbones.keys(), type=str, default='rmnet', help='Model type.') 178 | parser.add_argument('input_dir', help='Path to MegaFace Features') 179 | parser.add_argument('output_dir', help='Path to FaceScrub Features') 180 | parser.add_argument('--input_list', default='list.txt', type=str, required=False) 181 | parser.add_argument('--batch_size', type=int, default=128) 182 | parser.add_argument('--embedding_size', type=int, default=128) 183 | parser.add_argument('--devices', type=int, nargs='+', default=[0], help='CUDA devices to use.') 184 | parser.add_argument('--snap', type=str, 
required=True, help='Snapshot to evaluate.') 185 | parser.add_argument('--noises_list', type=str, default='', required=False, help='A list of the Megaface or Facescrub noises produced by insightface. \ 186 | See https://github.com/deepinsight/insightface/blob/master/src/megaface/README.md') 187 | parser.add_argument('--file_ending', help='Ending appended to original photo files. i.e.\ 188 | 11084833664_0.jpg_LBP_100x100.bin => _LBP_100x100.bin', default='_rmnet.bin') 189 | parser.add_argument('--trillion_format', action='store_true') 190 | return parser.parse_args(argv) 191 | 192 | if __name__ == '__main__': 193 | main(parse_argument(sys.argv[1:])) 194 | -------------------------------------------------------------------------------- /evaluate_landmarks.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2018 Intel Corporation 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | http://www.apache.org/licenses/LICENSE-2.0 7 | Unless required by applicable law or agreed to in writing, software 8 | distributed under the License is distributed on an "AS IS" BASIS, 9 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 10 | See the License for the specific language governing permissions and 11 | limitations under the License. 12 | """ 13 | 14 | import argparse 15 | 16 | import glog as log 17 | import torch 18 | import torch.backends.cudnn as cudnn 19 | from torch.utils.data import DataLoader 20 | from torchvision.transforms import transforms as t 21 | from tqdm import tqdm 22 | 23 | from datasets import VGGFace2, CelebA, NDG 24 | 25 | from model.common import models_landmarks 26 | from utils.landmarks_augmentation import Rescale, ToTensor 27 | from utils.utils import load_model_state 28 | 29 | 30 | def evaluate(val_loader, model): 31 | """Calculates average error""" 32 | total_loss = 0. 33 | total_pp_error = 0. 
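# dump_features.py (above) L2-normalizes every embedding and, for MegaFace, appends one
# extra column that is set to a large value (100.0) for images listed in the published
# noise files, so the MegaFace devkit can discard them. A tiny illustration of that
# flagging with made-up 4-D "embeddings":
import numpy as np

features = np.random.rand(3, 4).astype(np.float32)
features /= np.linalg.norm(features, axis=1, keepdims=True)     # L2-normalize rows

noisy = np.array([False, True, False])                          # rows known to be noise
flagged = np.concatenate([features, np.zeros((3, 1), np.float32)], axis=1)
flagged[noisy, -1] = 100.0                                      # same flag value as clean_megaface()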
34 | failures_num = 0 35 | items_num = 0 36 | for _, data in enumerate(tqdm(val_loader), 0): 37 | data, gt_landmarks = data['img'].cuda(), data['landmarks'].cuda() 38 | predicted_landmarks = model(data) 39 | loss = predicted_landmarks - gt_landmarks 40 | items_num += loss.shape[0] 41 | n_points = loss.shape[1] // 2 42 | per_point_error = loss.data.view(-1, n_points, 2) 43 | per_point_error = torch.norm(per_point_error, p=2, dim=2) 44 | avg_error = torch.sum(per_point_error, 1) / n_points 45 | eyes_dist = torch.norm(gt_landmarks[:, 0:2] - gt_landmarks[:, 2:4], p=2, dim=1).reshape(-1) 46 | 47 | per_point_error = torch.div(per_point_error, eyes_dist.view(-1, 1)) 48 | total_pp_error += torch.sum(per_point_error, 0) 49 | 50 | avg_error = torch.div(avg_error, eyes_dist) 51 | failures_num += torch.nonzero(avg_error > 0.1).shape[0] 52 | total_loss += torch.sum(avg_error) 53 | 54 | return total_loss / items_num, (total_pp_error / items_num).data.cpu().numpy(), float(failures_num) / items_num 55 | 56 | 57 | def start_evaluation(args): 58 | """Launches the evaluation process""" 59 | 60 | if args.dataset == 'vgg': 61 | dataset = VGGFace2(args.val, args.v_list, args.v_land, landmarks_training=True) 62 | elif args.dataset == 'celeb': 63 | dataset = CelebA(args.val, args.v_land, test=True) 64 | else: 65 | dataset = NDG(args.val, args.v_land) 66 | 67 | if dataset.have_landmarks: 68 | log.info('Use alignment for the train data') 69 | dataset.transform = t.Compose([Rescale((48, 48)), ToTensor(switch_rb=True)]) 70 | else: 71 | exit() 72 | 73 | val_loader = DataLoader(dataset, batch_size=args.val_batch_size, num_workers=4, shuffle=False, pin_memory=True) 74 | 75 | model = models_landmarks['landnet'] 76 | assert args.snapshot is not None 77 | log.info('Testing snapshot ' + args.snapshot + ' ...') 78 | model = load_model_state(model, args.snapshot, args.device, eval_state=True) 79 | model.eval() 80 | cudnn.benchmark = True 81 | model = torch.nn.DataParallel(model, device_ids=[args.device], ) 82 | 83 | log.info('Face landmarks model:') 84 | log.info(model) 85 | 86 | avg_err, per_point_avg_err, failures_rate = evaluate(val_loader, model) 87 | 88 | log.info('Avg RMSE error: {}'.format(avg_err)) 89 | log.info('Per landmark RMSE error: {}'.format(per_point_avg_err)) 90 | log.info('Failure rate: {}'.format(failures_rate)) 91 | 92 | 93 | def main(): 94 | """Creates a cl parser""" 95 | parser = argparse.ArgumentParser(description='Evaluation script for landmarks detection network') 96 | parser.add_argument('--device', '-d', default=0, type=int) 97 | parser.add_argument('--val_data_root', dest='val', required=True, type=str, help='Path to val data.') 98 | parser.add_argument('--val_list', dest='v_list', required=False, type=str, help='Path to test data image list.') 99 | parser.add_argument('--val_landmarks', dest='v_land', default='', required=False, type=str, 100 | help='Path to landmarks for test images.') 101 | parser.add_argument('--val_batch_size', type=int, default=1, help='Validation batch size.') 102 | parser.add_argument('--snapshot', type=str, default=None, help='Snapshot to evaluate.') 103 | parser.add_argument('--dataset', choices=['vgg', 'celeb', 'ngd'], type=str, default='vgg', help='Dataset.') 104 | arguments = parser.parse_args() 105 | 106 | with torch.cuda.device(arguments.device): 107 | start_evaluation(arguments) 108 | 109 | if __name__ == '__main__': 110 | main() 111 | -------------------------------------------------------------------------------- /init_venv.sh: 
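# The evaluation above reports the landmark error normalized by the inter-ocular distance
# and counts a failure whenever that normalized error exceeds 0.1. A minimal sketch of the
# same metric for a single face with five (x, y) points (coordinates are made up):
import torch

pred = torch.tensor([[0.30, 0.30, 0.70, 0.30, 0.50, 0.55, 0.35, 0.75, 0.65, 0.75]])
gt   = torch.tensor([[0.31, 0.29, 0.69, 0.31, 0.50, 0.56, 0.36, 0.74, 0.66, 0.76]])

diff = (pred - gt).view(-1, 5, 2)
per_point = torch.norm(diff, p=2, dim=2)                  # L2 error per landmark
eye_dist = torch.norm(gt[:, 0:2] - gt[:, 2:4], p=2, dim=1)
nme = per_point.mean(dim=1) / eye_dist                    # normalized mean error
is_failure = (nme > 0.1)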
-------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | work_dir=$(realpath "$(dirname $0)") 4 | 5 | cd ${work_dir} 6 | if [[ -e venv ]]; then 7 | echo "Please remove a previously virtual environment folder '${work_dir}/venv'." 8 | exit 9 | fi 10 | 11 | # Create virtual environment 12 | virtualenv venv -p python3 --prompt="(deep=fr) " 13 | echo "export PYTHONPATH=\$PYTHONPATH:${work_dir}" >> venv/bin/activate 14 | . venv/bin/activate 15 | pip install -r ${work_dir}/requirements.txt 16 | 17 | 18 | echo 19 | echo "====================================================" 20 | echo "To start to work, you need to activate a virtualenv:" 21 | echo "$ . venv/bin/activate" 22 | echo "====================================================" 23 | -------------------------------------------------------------------------------- /losses/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/grib0ed0v/face_recognition.pytorch/05cb9b30e8220445fcb27988926d88f330091c12/losses/__init__.py -------------------------------------------------------------------------------- /losses/alignment.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2018 Intel Corporation 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | http://www.apache.org/licenses/LICENSE-2.0 7 | Unless required by applicable law or agreed to in writing, software 8 | distributed under the License is distributed on an "AS IS" BASIS, 9 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 10 | See the License for the specific language governing permissions and 11 | limitations under the License. 12 | """ 13 | 14 | import math 15 | import torch 16 | import torch.nn as nn 17 | 18 | VALID_CORE_FUNC_TYPES = ['l1', 'l2', 'wing'] 19 | 20 | 21 | def wing_core(abs_x, w, eps): 22 | """Calculates the wing function from https://arxiv.org/pdf/1711.06753.pdf""" 23 | return w*math.log(1. + abs_x / eps) 24 | 25 | class AlignmentLoss(nn.Module): 26 | """Regression loss to train landmarks model""" 27 | def __init__(self, loss_type='l2'): 28 | super(AlignmentLoss, self).__init__() 29 | assert loss_type in VALID_CORE_FUNC_TYPES 30 | self.uniform_weights = True 31 | self.weights = None 32 | self.core_func_type = loss_type 33 | self.eps = 0.031 34 | self.w = 0.156 35 | 36 | def set_weights(self, weights): 37 | """Set weights for the each landmark point in loss""" 38 | self.uniform_weights = False 39 | self.weights = torch.FloatTensor(weights).cuda() 40 | 41 | def forward(self, input_values, target): 42 | bs = input_values.shape[0] 43 | loss = input_values - target 44 | n_points = loss.shape[1] // 2 45 | loss = loss.view(-1, n_points, 2) 46 | 47 | if self.core_func_type == 'l2': 48 | loss = torch.norm(loss, p=2, dim=2) 49 | loss = loss.pow(2) 50 | eyes_dist = (torch.norm(target[:, 0:2] - target[:, 2:4], p=2, dim=1).reshape(-1)).pow_(2) 51 | elif self.core_func_type == 'l1': 52 | loss = torch.norm(loss, p=1, dim=2) 53 | eyes_dist = (torch.norm(target[:, 0:2] - target[:, 2:4], p=1, dim=1).reshape(-1)) 54 | elif self.core_func_type == 'wing': 55 | wing_const = self.w - wing_core(self.w, self.w, self.eps) 56 | loss = torch.abs(loss) 57 | loss[loss < wing_const] = self.w*torch.log(1. 
+ loss[loss < wing_const] / self.eps) 58 | loss[loss >= wing_const] -= wing_const 59 | loss = torch.sum(loss, 2) 60 | eyes_dist = (torch.norm(target[:, 0:2] - target[:, 2:4], p=1, dim=1).reshape(-1)) 61 | 62 | if self.uniform_weights: 63 | loss = torch.sum(loss, 1) 64 | loss /= n_points 65 | else: 66 | assert self.weights.shape[0] == loss.shape[1] 67 | loss = torch.mul(loss, self.weights) 68 | loss = torch.sum(loss, 1) 69 | 70 | loss = torch.div(loss, eyes_dist) 71 | loss = torch.sum(loss) 72 | return loss / (2.*bs) 73 | -------------------------------------------------------------------------------- /losses/am_softmax.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2018 Intel Corporation 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | http://www.apache.org/licenses/LICENSE-2.0 7 | Unless required by applicable law or agreed to in writing, software 8 | distributed under the License is distributed on an "AS IS" BASIS, 9 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 10 | See the License for the specific language governing permissions and 11 | limitations under the License. 12 | """ 13 | 14 | import math 15 | 16 | import torch 17 | import torch.nn as nn 18 | import torch.nn.functional as F 19 | from torch.nn import Parameter 20 | 21 | 22 | class AngleSimpleLinear(nn.Module): 23 | """Computes cos of angles between input vectors and weights vectors""" 24 | def __init__(self, in_features, out_features): 25 | super(AngleSimpleLinear, self).__init__() 26 | self.in_features = in_features 27 | self.out_features = out_features 28 | self.weight = Parameter(torch.Tensor(in_features, out_features)) 29 | self.weight.data.uniform_(-1, 1).renorm_(2, 1, 1e-5).mul_(1e5) 30 | 31 | def forward(self, x): 32 | cos_theta = F.normalize(x, dim=1).mm(F.normalize(self.weight, dim=0)) 33 | return cos_theta.clamp(-1, 1) 34 | 35 | 36 | def focal_loss(input_values, gamma): 37 | """Computes the focal loss""" 38 | p = torch.exp(-input_values) 39 | loss = (1 - p) ** gamma * input_values 40 | return loss.mean() 41 | 42 | 43 | class AMSoftmaxLoss(nn.Module): 44 | """Computes the AM-Softmax loss with cos or arc margin""" 45 | margin_types = ['cos', 'arc'] 46 | 47 | def __init__(self, margin_type='cos', gamma=0., m=0.5, s=30, t=1.): 48 | super(AMSoftmaxLoss, self).__init__() 49 | assert margin_type in AMSoftmaxLoss.margin_types 50 | self.margin_type = margin_type 51 | assert gamma >= 0 52 | self.gamma = gamma 53 | assert m > 0 54 | self.m = m 55 | assert s > 0 56 | self.s = s 57 | self.cos_m = math.cos(m) 58 | self.sin_m = math.sin(m) 59 | self.th = math.cos(math.pi - m) 60 | assert t >= 1 61 | self.t = t 62 | 63 | def forward(self, cos_theta, target): 64 | if self.margin_type == 'cos': 65 | phi_theta = cos_theta - self.m 66 | else: 67 | sine = torch.sqrt(1.0 - torch.pow(cos_theta, 2)) 68 | phi_theta = cos_theta * self.cos_m - sine * self.sin_m #cos(theta+m) 69 | phi_theta = torch.where(cos_theta > self.th, phi_theta, cos_theta - self.sin_m * self.m) 70 | 71 | index = torch.zeros_like(cos_theta, dtype=torch.uint8) 72 | index.scatter_(1, target.data.view(-1, 1), 1) 73 | output = torch.where(index, phi_theta, cos_theta) 74 | 75 | if self.gamma == 0 and self.t == 1.: 76 | return F.cross_entropy(self.s*output, target) 77 | 78 | if self.t > 1: 79 | h_theta = self.t - 1 + self.t*cos_theta 80 | 
support_vecs_mask = (1 - index) * \ 81 | torch.lt(torch.masked_select(phi_theta, index).view(-1, 1).repeat(1, h_theta.shape[1]) - cos_theta, 0) 82 | output = torch.where(support_vecs_mask, h_theta, output) 83 | return F.cross_entropy(self.s*output, target) 84 | 85 | return focal_loss(F.cross_entropy(self.s*output, target, reduction='none'), self.gamma) 86 | -------------------------------------------------------------------------------- /losses/centroid_based.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2018 Intel Corporation 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | http://www.apache.org/licenses/LICENSE-2.0 7 | Unless required by applicable law or agreed to in writing, software 8 | distributed under the License is distributed on an "AS IS" BASIS, 9 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 10 | See the License for the specific language governing permissions and 11 | limitations under the License. 12 | """ 13 | 14 | import torch.nn as nn 15 | import torch 16 | import torch.nn.functional as F 17 | import numpy as np 18 | 19 | 20 | class CenterLoss(nn.Module): 21 | """Implements the Center loss from https://ydwen.github.io/papers/WenECCV16.pdf""" 22 | def __init__(self, num_classes, embed_size, cos_dist=True): 23 | super().__init__() 24 | self.cos_dist = cos_dist 25 | self.num_classes = num_classes 26 | self.centers = nn.Parameter(torch.randn(self.num_classes, embed_size).cuda()) 27 | self.embed_size = embed_size 28 | self.mse = nn.MSELoss(reduction='elementwise_mean') 29 | 30 | def get_centers(self): 31 | """Returns estimated centers""" 32 | return self.centers 33 | 34 | def forward(self, features, labels): 35 | features = F.normalize(features) 36 | batch_size = labels.size(0) 37 | features_dim = features.size(1) 38 | assert features_dim == self.embed_size 39 | 40 | if self.cos_dist: 41 | self.centers.data = F.normalize(self.centers.data, p=2, dim=1) 42 | 43 | centers_batch = self.centers[labels, :] 44 | 45 | if self.cos_dist: 46 | cos_sim = nn.CosineSimilarity() 47 | cos_diff = 1. - cos_sim(features, centers_batch) 48 | center_loss = torch.sum(cos_diff) / batch_size 49 | else: 50 | center_loss = self.mse(centers_batch, features) 51 | 52 | return center_loss 53 | 54 | 55 | class MinimumMargin(nn.Module): 56 | """Implements the Minimum margin loss from https://arxiv.org/abs/1805.06741""" 57 | def __init__(self, margin=.6): 58 | super().__init__() 59 | self.margin = margin 60 | 61 | def forward(self, centers, labels): 62 | loss_value = 0 63 | 64 | batch_centers = centers[labels, :] 65 | labels = labels.cpu().data.numpy() 66 | 67 | all_pairs = labels.reshape([-1, 1]) != labels.reshape([1, -1]) 68 | valid_pairs = (all_pairs * np.tri(*all_pairs.shape, k=-1, dtype=np.bool)).astype(np.float32) 69 | losses = 1. 
- torch.mm(batch_centers, torch.t(batch_centers)) - self.margin 70 | 71 | valid_pairs *= (losses.cpu().data.numpy() > 0.0) 72 | num_valid = float(np.sum(valid_pairs)) 73 | 74 | if num_valid > 0: 75 | loss_value = torch.sum(losses * torch.from_numpy(valid_pairs).cuda()) 76 | else: 77 | return loss_value 78 | 79 | return loss_value / num_valid 80 | 81 | 82 | class GlobalPushPlus(nn.Module): 83 | """Implements the Global Push Plus loss""" 84 | def __init__(self, margin=.6): 85 | super().__init__() 86 | self.min_margin = 0.15 87 | self.max_margin = margin 88 | self.num_calls = 0 89 | 90 | def forward(self, features, centers, labels): 91 | self.num_calls += 1 92 | features = F.normalize(features) 93 | loss_value = 0 94 | batch_centers = centers[labels, :] 95 | labels = labels.cpu().data.numpy() 96 | assert len(labels.shape) == 1 97 | 98 | center_ids = np.arange(centers.shape[0], dtype=np.int32) 99 | different_class_pairs = labels.reshape([-1, 1]) != center_ids.reshape([1, -1]) 100 | 101 | pos_distances = 1.0 - torch.sum(features * batch_centers, dim=1) 102 | neg_distances = 1.0 - torch.mm(features, torch.t(centers)) 103 | 104 | margin = self.min_margin + float(self.num_calls) / float(40000) * (self.max_margin - self.min_margin) 105 | margin = min(margin, self.max_margin) 106 | 107 | losses = margin + pos_distances.view(-1, 1) - neg_distances 108 | 109 | valid_pairs = (different_class_pairs * (losses.cpu().data.numpy() > 0.0)).astype(np.float32) 110 | num_valid = float(np.sum(valid_pairs)) 111 | 112 | if num_valid > 0: 113 | loss_value = torch.sum(losses * torch.from_numpy(valid_pairs).cuda()) 114 | else: 115 | return loss_value 116 | 117 | return loss_value / num_valid 118 | 119 | 120 | class PushPlusLoss(nn.Module): 121 | """Implements the Push Plus loss""" 122 | def __init__(self, margin=.7): 123 | super().__init__() 124 | self.margin = margin 125 | 126 | def forward(self, features, centers, labels): 127 | features = F.normalize(features) 128 | loss_value = 0 129 | batch_centers = centers[labels, :] 130 | labels = labels.cpu().data.numpy() 131 | assert len(labels.shape) == 1 132 | 133 | all_pairs = labels.reshape([-1, 1]) != labels.reshape([1, -1]) 134 | pos_distances = 1.0 - torch.sum(features * batch_centers, dim=1) 135 | neg_distances = 1.0 - torch.mm(features, torch.t(features)) 136 | 137 | losses = self.margin + pos_distances.view(-1, 1) - neg_distances 138 | valid_pairs = (all_pairs * (losses.cpu().data.numpy() > 0.0)).astype(np.float32) 139 | num_valid = float(np.sum(valid_pairs)) 140 | 141 | if num_valid > 0: 142 | loss_value = torch.sum(losses * torch.from_numpy(valid_pairs).cuda()) 143 | else: 144 | return loss_value 145 | 146 | return loss_value / num_valid 147 | 148 | 149 | class PushLoss(nn.Module): 150 | """Implements the Push loss""" 151 | def __init__(self, soft=True, margin=0.5): 152 | super().__init__() 153 | self.soft = soft 154 | self.margin = margin 155 | 156 | def forward(self, features, labels): 157 | features = F.normalize(features) 158 | loss_value = 0 159 | labels = labels.cpu().data.numpy() 160 | assert len(labels.shape) == 1 161 | 162 | all_pairs = labels.reshape([-1, 1]) != labels.reshape([1, -1]) 163 | valid_pairs = (all_pairs * np.tri(*all_pairs.shape, k=-1, dtype=np.bool)).astype(np.float32) 164 | 165 | if self.soft: 166 | losses = torch.log(1. + torch.exp(torch.mm(features, torch.t(features)) - 1)) 167 | num_valid = float(np.sum(valid_pairs)) 168 | else: 169 | losses = self.margin - (1. 
- torch.mm(features, torch.t(features))) 170 | valid_pairs *= (losses.cpu().data.numpy() > 0.0) 171 | num_valid = float(np.sum(valid_pairs)) 172 | 173 | if num_valid > 0: 174 | loss_value = torch.sum(losses * torch.from_numpy(valid_pairs).cuda()) 175 | else: 176 | return loss_value 177 | 178 | return loss_value / num_valid 179 | -------------------------------------------------------------------------------- /losses/metric_losses.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2018 Intel Corporation 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | http://www.apache.org/licenses/LICENSE-2.0 7 | Unless required by applicable law or agreed to in writing, software 8 | distributed under the License is distributed on an "AS IS" BASIS, 9 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 10 | See the License for the specific language governing permissions and 11 | limitations under the License. 12 | """ 13 | 14 | import torch 15 | from losses.centroid_based import CenterLoss, PushLoss, MinimumMargin, PushPlusLoss, GlobalPushPlus 16 | 17 | 18 | class MetricLosses: 19 | """Class-aggregator for all metric-learning losses""" 20 | def __init__(self, classes_num, embed_size, writer): 21 | self.writer = writer 22 | self.center_loss = CenterLoss(classes_num, embed_size, cos_dist=True) 23 | self.optimizer_centloss = torch.optim.SGD(self.center_loss.parameters(), lr=0.5) 24 | self.center_coeff = 0.0 25 | 26 | self.push_loss = PushLoss(soft=False, margin=0.7) 27 | self.push_loss_coeff = 0.0 28 | 29 | self.push_plus_loss = PushPlusLoss(margin=0.7) 30 | self.push_plus_loss_coeff = 0.0 31 | 32 | self.glob_push_plus_loss = GlobalPushPlus(margin=0.7) 33 | self.glob_push_plus_loss_coeff = 0.0 34 | 35 | self.min_margin_loss = MinimumMargin(margin=.7) 36 | self.min_margin_loss_coeff = 0.0 37 | 38 | def __call__(self, features, labels, epoch_num, iteration): 39 | log_string = '' 40 | 41 | center_loss_val = 0 42 | if self.center_coeff > 0.: 43 | center_loss_val = self.center_loss(features, labels) 44 | self.writer.add_scalar('Loss/center_loss', center_loss_val, iteration) 45 | log_string += ' Center loss: %.4f' % center_loss_val 46 | 47 | push_loss_val = 0 48 | if self.push_loss_coeff > 0.0: 49 | push_loss_val = self.push_loss(features, labels) 50 | self.writer.add_scalar('Loss/push_loss', push_loss_val, iteration) 51 | log_string += ' Push loss: %.4f' % push_loss_val 52 | 53 | push_plus_loss_val = 0 54 | if self.push_plus_loss_coeff > 0.0 and self.center_coeff > 0.0: 55 | push_plus_loss_val = self.push_plus_loss(features, self.center_loss.get_centers(), labels) 56 | self.writer.add_scalar('Loss/push_plus_loss', push_plus_loss_val, iteration) 57 | log_string += ' Push Plus loss: %.4f' % push_plus_loss_val 58 | 59 | glob_push_plus_loss_val = 0 60 | if self.glob_push_plus_loss_coeff > 0.0 and self.center_coeff > 0.0: 61 | glob_push_plus_loss_val = self.glob_push_plus_loss(features, self.center_loss.get_centers(), labels) 62 | self.writer.add_scalar('Loss/global_push_plus_loss', glob_push_plus_loss_val, iteration) 63 | log_string += ' Global Push Plus loss: %.4f' % glob_push_plus_loss_val 64 | 65 | min_margin_loss_val = 0 66 | if self.min_margin_loss_coeff > 0.0 and self.center_coeff > 0.0: 67 | min_margin_loss_val = self.min_margin_loss(self.center_loss.get_centers(), labels) 68 | 
self.writer.add_scalar('Loss/min_margin_loss', min_margin_loss_val, iteration) 69 | log_string += ' Min margin loss: %.4f' % min_margin_loss_val 70 | 71 | loss_value = self.center_coeff * center_loss_val + self.push_loss_coeff * push_loss_val + \ 72 | self.push_plus_loss_coeff * push_plus_loss_val + self.min_margin_loss_coeff * min_margin_loss_val \ 73 | + self.glob_push_plus_loss_coeff * glob_push_plus_loss_val 74 | 75 | if self.min_margin_loss_coeff + self.center_coeff + self.push_loss_coeff + self.push_plus_loss_coeff > 0.: 76 | self.writer.add_scalar('Loss/AUX_losses', loss_value, iteration) 77 | 78 | return loss_value, log_string 79 | 80 | def init_iteration(self): 81 | """Initializes a training iteration""" 82 | if self.center_coeff > 0.: 83 | self.optimizer_centloss.zero_grad() 84 | 85 | def end_iteration(self): 86 | """Finalizes a training iteration""" 87 | if self.center_coeff > 0.: 88 | for param in self.center_loss.parameters(): 89 | param.grad.data *= (1. / self.center_coeff) 90 | self.optimizer_centloss.step() 91 | -------------------------------------------------------------------------------- /losses/regularizer.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2018 Intel Corporation 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | http://www.apache.org/licenses/LICENSE-2.0 7 | Unless required by applicable law or agreed to in writing, software 8 | distributed under the License is distributed on an "AS IS" BASIS, 9 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 10 | See the License for the specific language governing permissions and 11 | limitations under the License. 12 | """ 13 | 14 | import torch 15 | import torch.nn.functional as F 16 | 17 | 18 | def l2_reg_ortho(mdl): 19 | """ 20 | Function used for Orthogonal Regularization. 
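# A sketch of wiring the MetricLosses aggregator above into a training step. All
# coefficients default to 0, so only the terms you enable contribute. This assumes a
# CUDA device (CenterLoss allocates its centers on the GPU) and any writer object with
# an add_scalar(tag, value, step) method, e.g. a TensorBoard SummaryWriter.
import torch
from torch.utils.tensorboard import SummaryWriter
from losses.metric_losses import MetricLosses

writer = SummaryWriter(log_dir='/tmp/metric_losses_demo')
aux_losses = MetricLosses(classes_num=100, embed_size=128, writer=writer)
aux_losses.center_coeff = 0.1                       # enable the center-loss term only

features = torch.randn(32, 128, device='cuda', requires_grad=True)
labels = torch.randint(0, 100, (32,), device='cuda')

aux_losses.init_iteration()
loss, log_string = aux_losses(features, labels, epoch_num=0, iteration=0)
loss.backward()
aux_losses.end_iteration()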
21 | """ 22 | l2_reg = None 23 | for w in mdl.parameters(): 24 | if w.ndimension() < 2: 25 | continue 26 | else: 27 | cols = w[0].numel() 28 | w1 = w.view(-1, cols) 29 | wt = torch.transpose(w1, 0, 1) 30 | m = torch.matmul(wt, w1) 31 | ident = torch.eye(cols, cols).cuda() 32 | 33 | w_tmp = (m - ident) 34 | height = w_tmp.size(0) 35 | u = F.normalize(w_tmp.new_empty(height).normal_(0, 1), dim=0, eps=1e-12) 36 | v = F.normalize(torch.matmul(w_tmp.t(), u), dim=0, eps=1e-12) 37 | u = F.normalize(torch.matmul(w_tmp, v), dim=0, eps=1e-12) 38 | sigma = torch.dot(u, torch.matmul(w_tmp, v)) 39 | 40 | if l2_reg is None: 41 | l2_reg = (torch.norm(sigma, 2))**2 42 | else: 43 | l2_reg += (torch.norm(sigma, 2))**2 44 | return l2_reg 45 | 46 | 47 | class ODecayScheduler(): 48 | """Scheduler for the decay of the orthogonal regularizer""" 49 | def __init__(self, schedule, initial_decay, mult_factor): 50 | assert len(schedule) > 1 51 | self.schedule = schedule 52 | self.epoch_num = 0 53 | self.mult_factor = mult_factor 54 | self.decay = initial_decay 55 | 56 | def step(self): 57 | """Switches to the next step""" 58 | self.epoch_num += 1 59 | if self.epoch_num in self.schedule: 60 | self.decay *= self.mult_factor 61 | if self.epoch_num == self.schedule[-1]: 62 | self.decay = 0.0 63 | 64 | def get_decay(self): 65 | """Returns the current value of decay according to th schedule""" 66 | return self.decay 67 | -------------------------------------------------------------------------------- /model/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/grib0ed0v/face_recognition.pytorch/05cb9b30e8220445fcb27988926d88f330091c12/model/__init__.py -------------------------------------------------------------------------------- /model/backbones/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/grib0ed0v/face_recognition.pytorch/05cb9b30e8220445fcb27988926d88f330091c12/model/backbones/__init__.py -------------------------------------------------------------------------------- /model/backbones/resnet.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2018 Intel Corporation 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | http://www.apache.org/licenses/LICENSE-2.0 7 | Unless required by applicable law or agreed to in writing, software 8 | distributed under the License is distributed on an "AS IS" BASIS, 9 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 10 | See the License for the specific language governing permissions and 11 | limitations under the License. 
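# A sketch of plugging the orthogonality regularizer above into a training loop: the
# penalty from l2_reg_ortho() is scaled by a decay value that ODecayScheduler shrinks
# over epochs and eventually zeroes out. The model, schedule and decay values are only
# examples; a CUDA device is assumed because l2_reg_ortho builds its identity matrix
# on the GPU.
import torch.nn as nn
from losses.regularizer import l2_reg_ortho, ODecayScheduler

model = nn.Sequential(nn.Conv2d(3, 16, 3), nn.ReLU(), nn.Conv2d(16, 32, 3)).cuda()
ortho_decay = ODecayScheduler(schedule=[10, 20, 30], initial_decay=1e-2, mult_factor=0.1)

for epoch in range(3):                               # shortened loop for illustration
    # ... the ordinary task loss would be computed here ...
    penalty = ortho_decay.get_decay() * l2_reg_ortho(model)
    # total_loss = task_loss + penalty; total_loss.backward(); optimizer.step()
    ortho_decay.step()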
12 | """ 13 | 14 | import torch.nn as nn 15 | 16 | from model.blocks.resnet_blocks import Bottleneck, BasicBlock 17 | from model.blocks.shared_blocks import make_activation 18 | 19 | 20 | class ResNet(nn.Module): 21 | def __init__(self, block, layers, num_classes=1000, activation=nn.ReLU): 22 | self.inplanes = 64 23 | super(ResNet, self).__init__() 24 | self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, 25 | bias=False) 26 | self.bn1 = nn.BatchNorm2d(64) 27 | self.relu = make_activation(nn.ReLU) 28 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 29 | self.layer1 = self._make_layer(block, 64, layers[0], activation=activation) 30 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2, activation=activation) 31 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2, activation=activation) 32 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2, activation=activation) 33 | self.avgpool = nn.Conv2d(512 * block.expansion, 512 * block.expansion, 7, 34 | groups=512 * block.expansion, bias=False) 35 | self.fc = nn.Conv2d(512 * block.expansion, num_classes, 1, stride=1, padding=0, bias=False) 36 | 37 | for m in self.modules(): 38 | if isinstance(m, nn.Conv2d): 39 | nn.init.kaiming_normal_(m.weight, mode='fan_out') 40 | elif isinstance(m, nn.BatchNorm2d): 41 | nn.init.constant_(m.weight, 1) 42 | nn.init.constant_(m.bias, 0) 43 | 44 | def _make_layer(self, block, planes, blocks, stride=1, activation=nn.ReLU): 45 | downsample = None 46 | if stride != 1 or self.inplanes != planes * block.expansion: 47 | downsample = nn.Sequential( 48 | nn.Conv2d(self.inplanes, planes * block.expansion, 49 | kernel_size=1, stride=stride, bias=False), 50 | nn.BatchNorm2d(planes * block.expansion), 51 | ) 52 | 53 | layers = [] 54 | layers.append(block(self.inplanes, planes, stride, downsample, activation=activation)) 55 | self.inplanes = planes * block.expansion 56 | for _ in range(1, blocks): 57 | layers.append(block(self.inplanes, planes, activation=activation)) 58 | 59 | return nn.Sequential(*layers) 60 | 61 | def forward(self, x): 62 | x = self.conv1(x) 63 | x = self.bn1(x) 64 | x = self.relu(x) 65 | x = self.maxpool(x) 66 | 67 | x = self.layer1(x) 68 | x = self.layer2(x) 69 | x = self.layer3(x) 70 | x = self.layer4(x) 71 | 72 | x = self.avgpool(x) 73 | x = self.fc(x) 74 | 75 | return x 76 | 77 | 78 | def resnet50(**kwargs): 79 | model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) 80 | return model 81 | 82 | 83 | def resnet34(**kwargs): 84 | model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs) 85 | return model 86 | -------------------------------------------------------------------------------- /model/backbones/rmnet.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2018 Intel Corporation 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | http://www.apache.org/licenses/LICENSE-2.0 7 | Unless required by applicable law or agreed to in writing, software 8 | distributed under the License is distributed on an "AS IS" BASIS, 9 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 10 | See the License for the specific language governing permissions and 11 | limitations under the License. 
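# The ResNet variant above is fully convolutional: global pooling is a 7x7 depthwise
# convolution and the classifier is a 1x1 convolution, so the output stays 4-D. With
# the stride-1 stem the network downscales by 16, so a 112x112 input becomes a 7x7 map
# before pooling. A quick shape check (the 256-D output size is an arbitrary example):
import torch
from model.backbones.resnet import resnet50

net = resnet50(num_classes=256)
out = net(torch.randn(2, 3, 112, 112))
print(out.shape)        # torch.Size([2, 256, 1, 1])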
12 | """ 13 | 14 | from collections import OrderedDict 15 | 16 | import torch.nn as nn 17 | from ..blocks.rmnet_blocks import RMBlock 18 | 19 | 20 | class RMNetBody(nn.Module): 21 | def __init__(self, block=RMBlock, blocks_per_stage=(None, 4, 8, 10, 11), trunk_width=(32, 32, 64, 128, 256), 22 | bottleneck_width=(None, 8, 16, 32, 64)): 23 | super(RMNetBody, self).__init__() 24 | assert len(blocks_per_stage) == len(trunk_width) == len(bottleneck_width) 25 | self.dim_out = trunk_width[-1] 26 | 27 | stages = [nn.Sequential(OrderedDict([ 28 | ('data_bn', nn.BatchNorm2d(3)), 29 | ('conv1', nn.Conv2d(3, trunk_width[0], kernel_size=3, stride=2, padding=1, bias=False)), 30 | ('bn1', nn.BatchNorm2d(trunk_width[0])), 31 | ('relu1', nn.ReLU(inplace=True))])), ] 32 | 33 | for i, (blocks_num, w, wb) in enumerate(zip(blocks_per_stage, trunk_width, bottleneck_width)): 34 | # Zeroth stage is already added. 35 | if i == 0: 36 | continue 37 | stage = [] 38 | # Do not downscale input to the first stage. 39 | if i > 1: 40 | stage.append(block(trunk_width[i - 1], wb, w, downsample=True)) 41 | for _ in range(blocks_num): 42 | stage.append(block(w, wb, w)) 43 | stages.append(nn.Sequential(*stage)) 44 | 45 | self.stages = nn.Sequential(OrderedDict([('stage_{}'.format(i), stage) for i, stage in enumerate(stages)])) 46 | 47 | self.init_weights() 48 | 49 | def init_weights(self): 50 | m = self.stages[0][0] # ['data_bn'] 51 | nn.init.constant_(m.weight, 1) 52 | nn.init.constant_(m.bias, 0) 53 | m = self.stages[0][1] # ['conv1'] 54 | nn.init.kaiming_normal_(m.weight, mode='fan_out') 55 | m = self.stages[0][2] # ['bn1'] 56 | nn.init.constant_(m.weight, 1) 57 | nn.init.constant_(m.bias, 0) 58 | # All other blocks should be initialized internally during instantiation. 59 | 60 | def forward(self, x): 61 | return self.stages(x) 62 | -------------------------------------------------------------------------------- /model/backbones/se_resnet.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2018 Intel Corporation 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | http://www.apache.org/licenses/LICENSE-2.0 7 | Unless required by applicable law or agreed to in writing, software 8 | distributed under the License is distributed on an "AS IS" BASIS, 9 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 10 | See the License for the specific language governing permissions and 11 | limitations under the License. 
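# RMNetBody above is a plain feature extractor: the stem convolution is stride-2 and
# every stage except the first starts with a downsampling block, so (assuming the usual
# stride-2 downsampling inside RMBlock) the input is reduced roughly 16x spatially and
# mapped to trunk_width[-1] = 256 channels. A small sanity check with an arbitrary
# input resolution:
import torch
from model.backbones.rmnet import RMNetBody

body = RMNetBody()
feat = body(torch.randn(1, 3, 128, 128))
print(body.dim_out, feat.shape)     # 256, and a 256-channel feature map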
12 | """ 13 | 14 | import math 15 | 16 | import torch.nn as nn 17 | 18 | from model.blocks.se_resnet_blocks import SEBottleneck 19 | 20 | 21 | class SEResNet(nn.Module): 22 | def __init__(self, block, layers, num_classes=1000, activation=nn.ReLU): 23 | self.inplanes = 64 24 | super(SEResNet, self).__init__() 25 | self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, 26 | bias=False) 27 | self.bn1 = nn.BatchNorm2d(64) 28 | self.relu = nn.ReLU(inplace=True) 29 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 30 | self.layer1 = self._make_layer(block, 64, layers[0], activation=activation) 31 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2, activation=activation) 32 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2, activation=activation) 33 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2, activation=activation) 34 | self.avgpool = nn.Conv2d(512 * block.expansion, 512 * block.expansion, 7, 35 | groups=512 * block.expansion, bias=False) 36 | self.fc = nn.Conv2d(512 * block.expansion, num_classes, 1, stride=1, padding=0, bias=False) 37 | 38 | for m in self.modules(): 39 | if isinstance(m, nn.Conv2d): 40 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 41 | m.weight.data.normal_(0, math.sqrt(2. / n)) 42 | elif isinstance(m, nn.BatchNorm2d): 43 | m.weight.data.fill_(1) 44 | m.bias.data.zero_() 45 | 46 | def _make_layer(self, block, planes, blocks, stride=1, activation=nn.ReLU): 47 | downsample = None 48 | if stride != 1 or self.inplanes != planes * block.expansion: 49 | downsample = nn.Sequential( 50 | nn.Conv2d(self.inplanes, planes * block.expansion, 51 | kernel_size=1, stride=stride, bias=False), 52 | nn.BatchNorm2d(planes * block.expansion), 53 | ) 54 | 55 | layers = [] 56 | layers.append(block(self.inplanes, planes, stride, downsample, activation=activation)) 57 | self.inplanes = planes * block.expansion 58 | for _ in range(1, blocks): 59 | layers.append(block(self.inplanes, planes, activation=activation)) 60 | 61 | return nn.Sequential(*layers) 62 | 63 | def forward(self, x): 64 | x = self.conv1(x) 65 | x = self.bn1(x) 66 | x = self.relu(x) 67 | x = self.maxpool(x) 68 | 69 | x = self.layer1(x) 70 | x = self.layer2(x) 71 | x = self.layer3(x) 72 | x = self.layer4(x) 73 | 74 | x = self.avgpool(x) 75 | x = self.fc(x) 76 | 77 | return x 78 | 79 | 80 | def se_resnet50(**kwargs): 81 | model = SEResNet(SEBottleneck, [3, 4, 6, 3], **kwargs) 82 | return model 83 | 84 | 85 | def se_resnet101(**kwargs): 86 | model = SEResNet(SEBottleneck, [3, 4, 23, 3], **kwargs) 87 | return model 88 | 89 | 90 | def se_resnet152(**kwargs): 91 | model = SEResNet(SEBottleneck, [3, 8, 36, 3], **kwargs) 92 | return model 93 | -------------------------------------------------------------------------------- /model/backbones/se_resnext.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2018 Intel Corporation 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | http://www.apache.org/licenses/LICENSE-2.0 7 | Unless required by applicable law or agreed to in writing, software 8 | distributed under the License is distributed on an "AS IS" BASIS, 9 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 10 | See the License for the specific language governing permissions and 11 | limitations under the License. 
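# SEResNet mirrors the ResNet above but uses squeeze-and-excitation bottlenecks and lets
# you choose the activation passed into them. A hedged instantiation sketch: PReLU and
# the 128-D output are arbitrary choices, and the input is again 112x112 because global
# pooling is a fixed 7x7 grouped convolution.
import torch
import torch.nn as nn
from model.backbones.se_resnet import se_resnet50

net = se_resnet50(num_classes=128, activation=nn.PReLU)
print(net(torch.randn(1, 3, 112, 112)).shape)   # torch.Size([1, 128, 1, 1])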
12 | """ 13 | 14 | import math 15 | import torch.nn as nn 16 | 17 | from model.blocks.se_resnext_blocks import SEBottleneckX 18 | 19 | 20 | class SEResNeXt(nn.Module): 21 | 22 | def __init__(self, block, layers, cardinality=32, num_classes=1000): 23 | super(SEResNeXt, self).__init__() 24 | self.cardinality = cardinality 25 | self.inplanes = 64 26 | 27 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, 28 | bias=False) 29 | self.bn1 = nn.BatchNorm2d(64) 30 | self.relu = nn.ReLU(inplace=True) 31 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 32 | 33 | self.layer1 = self._make_layer(block, 64, layers[0]) 34 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 35 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 36 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) 37 | 38 | self.avgpool = nn.AdaptiveAvgPool2d(1) 39 | self.fc = nn.Linear(512 * block.expansion, num_classes) 40 | 41 | for m in self.modules(): 42 | if isinstance(m, nn.Conv2d): 43 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 44 | m.weight.data.normal_(0, math.sqrt(2. / n)) 45 | if m.bias is not None: 46 | m.bias.data.zero_() 47 | elif isinstance(m, nn.BatchNorm2d): 48 | m.weight.data.fill_(1) 49 | m.bias.data.zero_() 50 | 51 | def _make_layer(self, block, planes, blocks, stride=1): 52 | downsample = None 53 | if stride != 1 or self.inplanes != planes * block.expansion: 54 | downsample = nn.Sequential( 55 | nn.Conv2d(self.inplanes, planes * block.expansion, 56 | kernel_size=1, stride=stride, bias=False), 57 | nn.BatchNorm2d(planes * block.expansion), 58 | ) 59 | 60 | layers = [] 61 | layers.append(block(self.inplanes, planes, self.cardinality, stride, downsample)) 62 | self.inplanes = planes * block.expansion 63 | for _ in range(1, blocks): 64 | layers.append(block(self.inplanes, planes, self.cardinality)) 65 | 66 | return nn.Sequential(*layers) 67 | 68 | def forward(self, x): 69 | x = self.conv1(x) 70 | x = self.bn1(x) 71 | x = self.relu(x) 72 | x = self.maxpool(x) 73 | 74 | x = self.layer1(x) 75 | x = self.layer2(x) 76 | x = self.layer3(x) 77 | x = self.layer4(x) 78 | 79 | x = self.avgpool(x) 80 | x = x.view(x.size(0), -1) 81 | 82 | x = self.fc(x) 83 | 84 | return x 85 | 86 | 87 | def se_resnext50(**kwargs): 88 | model = SEResNeXt(SEBottleneckX, [3, 4, 6, 3], **kwargs) 89 | return model 90 | 91 | 92 | def se_resnext101(**kwargs): 93 | model = SEResNeXt(SEBottleneckX, [3, 4, 23, 3], **kwargs) 94 | return model 95 | 96 | 97 | def se_resnext152(**kwargs): 98 | model = SEResNeXt(SEBottleneckX, [3, 8, 36, 3], **kwargs) 99 | return model 100 | -------------------------------------------------------------------------------- /model/backbones/shufflenet_v2.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2018 Intel Corporation 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | http://www.apache.org/licenses/LICENSE-2.0 7 | Unless required by applicable law or agreed to in writing, software 8 | distributed under the License is distributed on an "AS IS" BASIS, 9 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 10 | See the License for the specific language governing permissions and 11 | limitations under the License. 
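# Unlike the two backbones above, SEResNeXt keeps the classic ImageNet layout: a 7x7
# stride-2 stem, adaptive average pooling and a Linear classifier, plus grouped
# (cardinality) bottlenecks. A short sketch; the cardinality and num_classes values are
# just examples:
import torch
from model.backbones.se_resnext import se_resnext50

net = se_resnext50(cardinality=32, num_classes=1000)
print(net(torch.randn(1, 3, 224, 224)).shape)   # torch.Size([1, 1000])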
12 | """ 13 | 14 | import torch.nn as nn 15 | 16 | from model.blocks.shufflenet_v2_blocks import ShuffleInvertedResidual, conv_bn, conv_1x1_bn 17 | 18 | 19 | class ShuffleNetV2Body(nn.Module): 20 | def __init__(self, input_size=224, width_mult=1.): 21 | super(ShuffleNetV2Body, self).__init__() 22 | 23 | assert input_size % 32 == 0 24 | 25 | self.stage_repeats = [4, 8, 4] 26 | if width_mult == 0.5: 27 | self.stage_out_channels = [-1, 24, 48, 96, 192, 1024] 28 | elif width_mult == 1.0: 29 | self.stage_out_channels = [-1, 24, 116, 232, 464, 1024] 30 | elif width_mult == 1.5: 31 | self.stage_out_channels = [-1, 24, 176, 352, 704, 1024] 32 | elif width_mult == 2.0: 33 | self.stage_out_channels = [-1, 24, 224, 488, 976, 2048] 34 | else: 35 | raise ValueError("Unsupported width multiplier") 36 | 37 | # building first layer 38 | self.bn_first = nn.BatchNorm2d(3) 39 | input_channel = self.stage_out_channels[1] 40 | self.conv1 = conv_bn(3, input_channel, 2) 41 | 42 | self.features = [] 43 | 44 | # building inverted residual blocks 45 | for idxstage in range(len(self.stage_repeats)): 46 | numrepeat = self.stage_repeats[idxstage] 47 | output_channel = self.stage_out_channels[idxstage+2] 48 | for i in range(numrepeat): 49 | if i == 0: 50 | self.features.append(ShuffleInvertedResidual(input_channel, output_channel, 51 | 2, 2, activation=nn.PReLU)) 52 | else: 53 | self.features.append(ShuffleInvertedResidual(input_channel, output_channel, 54 | 1, 1, activation=nn.PReLU)) 55 | input_channel = output_channel 56 | 57 | self.features = nn.Sequential(*self.features) 58 | self.conv_last = conv_1x1_bn(input_channel, self.stage_out_channels[-1], activation=nn.PReLU) 59 | self.init_weights() 60 | 61 | @staticmethod 62 | def get_downscale_factor(): 63 | return 16 64 | 65 | def init_weights(self): 66 | m = self.bn_first 67 | nn.init.constant_(m.weight, 1) 68 | nn.init.constant_(m.bias, 0) 69 | 70 | def get_num_output_channels(self): 71 | return self.stage_out_channels[-1] 72 | 73 | def forward(self, x): 74 | x = self.conv1(self.bn_first(x)) 75 | x = self.features(x) 76 | x = self.conv_last(x) 77 | return x 78 | -------------------------------------------------------------------------------- /model/blocks/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/grib0ed0v/face_recognition.pytorch/05cb9b30e8220445fcb27988926d88f330091c12/model/blocks/__init__.py -------------------------------------------------------------------------------- /model/blocks/mobilenet_v2_blocks.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2018 Intel Corporation 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | http://www.apache.org/licenses/LICENSE-2.0 7 | Unless required by applicable law or agreed to in writing, software 8 | distributed under the License is distributed on an "AS IS" BASIS, 9 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 10 | See the License for the specific language governing permissions and 11 | limitations under the License. 
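A quick shape-check sketch, assuming the repository root is importable. ShuffleNetV2Body downsamples by its reported factor of 16 and finishes with the 1x1 conv_last, so a 128x128 input produces an 8x8 map with stage_out_channels[-1] channels:

import torch
from model.backbones.shufflenet_v2 import ShuffleNetV2Body

body = ShuffleNetV2Body(input_size=128, width_mult=1.).eval()
x = torch.rand(1, 3, 128, 128)
with torch.no_grad():
    feat = body(x)
print(feat.shape)                         # torch.Size([1, 1024, 8, 8])
print(body.get_num_output_channels())     # 1024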
12 | """ 13 | 14 | import torch.nn as nn 15 | 16 | from model.blocks.shared_blocks import SELayer 17 | 18 | 19 | class InvertedResidual(nn.Module): 20 | """Implementation of the modified Inverted residual block""" 21 | def __init__(self, in_channels, out_channels, stride, expand_ratio, outp_size=None): 22 | super(InvertedResidual, self).__init__() 23 | self.stride = stride 24 | assert stride in [1, 2] 25 | 26 | self.use_res_connect = self.stride == 1 and in_channels == out_channels 27 | 28 | self.inv_block = nn.Sequential( 29 | nn.Conv2d(in_channels, in_channels * expand_ratio, 1, 1, 0, bias=False), 30 | nn.BatchNorm2d(in_channels * expand_ratio), 31 | nn.PReLU(), 32 | 33 | nn.Conv2d(in_channels * expand_ratio, in_channels * expand_ratio, 3, stride, 1, 34 | groups=in_channels * expand_ratio, bias=False), 35 | nn.BatchNorm2d(in_channels * expand_ratio), 36 | nn.PReLU(), 37 | 38 | nn.Conv2d(in_channels * expand_ratio, out_channels, 1, 1, 0, bias=False), 39 | nn.BatchNorm2d(out_channels), 40 | SELayer(out_channels, 8, nn.PReLU, outp_size) 41 | ) 42 | 43 | def forward(self, x): 44 | if self.use_res_connect: 45 | return x + self.inv_block(x) 46 | 47 | return self.inv_block(x) 48 | -------------------------------------------------------------------------------- /model/blocks/resnet_blocks.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2018 Intel Corporation 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | http://www.apache.org/licenses/LICENSE-2.0 7 | Unless required by applicable law or agreed to in writing, software 8 | distributed under the License is distributed on an "AS IS" BASIS, 9 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 10 | See the License for the specific language governing permissions and 11 | limitations under the License. 
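A short illustration of the block above: the residual connection is only taken when stride is 1 and the channel count is unchanged, and the SE layer gates the output channel-wise. A minimal sketch with assumed toy sizes:

import torch
from model.blocks.mobilenet_v2_blocks import InvertedResidual

block = InvertedResidual(in_channels=64, out_channels=64, stride=1, expand_ratio=2).eval()
x = torch.rand(1, 64, 32, 32)
with torch.no_grad():
    y = block(x)          # use_res_connect is True here, so y = x + inv_block(x)
print(y.shape)            # torch.Size([1, 64, 32, 32])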
12 | """ 13 | 14 | import torch.nn as nn 15 | 16 | from model.blocks.shared_blocks import make_activation 17 | 18 | 19 | class Bottleneck(nn.Module): 20 | expansion = 4 21 | 22 | def __init__(self, inplanes, planes, stride=1, downsample=None, activation=nn.ReLU): 23 | super(Bottleneck, self).__init__() 24 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 25 | self.bn1 = nn.BatchNorm2d(planes) 26 | self.act1 = make_activation(activation) 27 | 28 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 29 | self.bn2 = nn.BatchNorm2d(planes) 30 | self.act2 = make_activation(activation) 31 | 32 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False) 33 | self.bn3 = nn.BatchNorm2d(planes * self.expansion) 34 | self.act3 = make_activation(activation) 35 | 36 | self.downsample = downsample 37 | self.stride = stride 38 | 39 | def forward(self, x): 40 | residual = x 41 | 42 | out = self.conv1(x) 43 | out = self.bn1(out) 44 | out = self.act1(out) 45 | 46 | out = self.conv2(out) 47 | out = self.bn2(out) 48 | out = self.act2(out) 49 | 50 | out = self.conv3(out) 51 | out = self.bn3(out) 52 | 53 | if self.downsample is not None: 54 | residual = self.downsample(x) 55 | 56 | out += residual 57 | out = self.act3(out) 58 | 59 | return out 60 | 61 | 62 | class BasicBlock(nn.Module): 63 | expansion = 1 64 | 65 | def __init__(self, inplanes, planes, stride=1, downsample=None, activation=nn.ReLU): 66 | super(BasicBlock, self).__init__() 67 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 68 | self.bn1 = nn.BatchNorm2d(planes) 69 | self.relu = make_activation(activation) 70 | self.conv2 = nn.Conv2d(inplanes, planes, kernel_size=3, stride=1, padding=1, bias=False) 71 | self.bn2 = nn.BatchNorm2d(planes) 72 | self.downsample = downsample 73 | self.stride = stride 74 | 75 | def forward(self, x): 76 | residual = x 77 | 78 | out = self.conv1(x) 79 | out = self.bn1(out) 80 | out = self.relu(out) 81 | 82 | out = self.conv2(out) 83 | out = self.bn2(out) 84 | 85 | if self.downsample is not None: 86 | residual = self.downsample(x) 87 | 88 | out += residual 89 | out = self.relu(out) 90 | 91 | return out 92 | -------------------------------------------------------------------------------- /model/blocks/rmnet_blocks.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2018 Intel Corporation 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | http://www.apache.org/licenses/LICENSE-2.0 7 | Unless required by applicable law or agreed to in writing, software 8 | distributed under the License is distributed on an "AS IS" BASIS, 9 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 10 | See the License for the specific language governing permissions and 11 | limitations under the License. 
12 | """ 13 | 14 | import torch.nn as nn 15 | import torch.nn.functional as F 16 | 17 | from model.blocks.shared_blocks import make_activation 18 | 19 | 20 | class RMBlock(nn.Module): 21 | def __init__(self, input_planes, squeeze_planes, output_planes, downsample=False, dropout_ratio=0.1, 22 | activation=nn.ELU): 23 | super(RMBlock, self).__init__() 24 | self.downsample = downsample 25 | self.input_planes = input_planes 26 | self.output_planes = output_planes 27 | 28 | self.squeeze_conv = nn.Conv2d(input_planes, squeeze_planes, kernel_size=1, bias=False) 29 | self.squeeze_bn = nn.BatchNorm2d(squeeze_planes) 30 | 31 | self.dw_conv = nn.Conv2d(squeeze_planes, squeeze_planes, groups=squeeze_planes, kernel_size=3, padding=1, 32 | stride=2 if downsample else 1, bias=False) 33 | self.dw_bn = nn.BatchNorm2d(squeeze_planes) 34 | 35 | self.expand_conv = nn.Conv2d(squeeze_planes, output_planes, kernel_size=1, bias=False) 36 | self.expand_bn = nn.BatchNorm2d(output_planes) 37 | 38 | self.activation = make_activation(activation) 39 | 40 | self.dropout_ratio = dropout_ratio 41 | 42 | if self.downsample: 43 | self.skip_conv = nn.Conv2d(input_planes, output_planes, kernel_size=1, bias=False) 44 | self.skip_conv_bn = nn.BatchNorm2d(output_planes) 45 | 46 | self.init_weights() 47 | 48 | def init_weights(self): 49 | for m in self.children(): 50 | if isinstance(m, nn.Conv2d): 51 | nn.init.kaiming_normal_(m.weight, mode='fan_out') 52 | elif isinstance(m, nn.BatchNorm2d): 53 | nn.init.constant_(m.weight, 1) 54 | nn.init.constant_(m.bias, 0) 55 | 56 | def forward(self, x): 57 | residual = x 58 | out = self.activation(self.squeeze_bn(self.squeeze_conv(x))) 59 | out = self.activation(self.dw_bn(self.dw_conv(out))) 60 | out = self.expand_bn(self.expand_conv(out)) 61 | if self.dropout_ratio > 0: 62 | out = F.dropout(out, p=self.dropout_ratio, training=self.training, inplace=True) 63 | if self.downsample: 64 | residual = F.max_pool2d(x, kernel_size=2, stride=2, padding=0) 65 | residual = self.skip_conv(residual) 66 | residual = self.skip_conv_bn(residual) 67 | out += residual 68 | return self.activation(out) 69 | -------------------------------------------------------------------------------- /model/blocks/se_resnet_blocks.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2018 Intel Corporation 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | http://www.apache.org/licenses/LICENSE-2.0 7 | Unless required by applicable law or agreed to in writing, software 8 | distributed under the License is distributed on an "AS IS" BASIS, 9 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 10 | See the License for the specific language governing permissions and 11 | limitations under the License. 
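A minimal sketch of RMBlock in its downsampling configuration, with assumed toy sizes: the main branch applies a stride-2 depth-wise convolution, while the skip branch max-pools and projects through a 1x1 convolution before the two are summed:

import torch
from model.blocks.rmnet_blocks import RMBlock

block = RMBlock(input_planes=32, squeeze_planes=8, output_planes=64, downsample=True).eval()
x = torch.rand(1, 32, 16, 16)
with torch.no_grad():
    y = block(x)          # eval() keeps the internal F.dropout inactive
print(y.shape)            # torch.Size([1, 64, 8, 8])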
12 | """ 13 | 14 | import torch.nn as nn 15 | 16 | from model.blocks.shared_blocks import make_activation 17 | 18 | 19 | class SEBottleneck(nn.Module): 20 | expansion = 4 21 | 22 | def __init__(self, inplanes, planes, stride=1, downsample=None, activation=nn.ReLU): 23 | super(SEBottleneck, self).__init__() 24 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 25 | self.bn1 = nn.BatchNorm2d(planes) 26 | 27 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 28 | self.bn2 = nn.BatchNorm2d(planes) 29 | 30 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 31 | self.bn3 = nn.BatchNorm2d(planes * 4) 32 | 33 | self.relu = make_activation(activation) 34 | 35 | # SE 36 | self.global_pool = nn.AdaptiveAvgPool2d(1) 37 | self.conv_down = nn.Conv2d(planes * 4, planes // 4, kernel_size=1, bias=False) 38 | self.conv_up = nn.Conv2d(planes // 4, planes * 4, kernel_size=1, bias=False) 39 | self.sig = nn.Sigmoid() 40 | 41 | # Downsample 42 | self.downsample = downsample 43 | self.stride = stride 44 | 45 | def forward(self, x): 46 | residual = x 47 | 48 | out = self.conv1(x) 49 | out = self.bn1(out) 50 | out = self.relu(out) 51 | 52 | out = self.conv2(out) 53 | out = self.bn2(out) 54 | out = self.relu(out) 55 | 56 | out = self.conv3(out) 57 | out = self.bn3(out) 58 | 59 | out1 = self.global_pool(out) 60 | out1 = self.conv_down(out1) 61 | out1 = self.relu(out1) 62 | out1 = self.conv_up(out1) 63 | out1 = self.sig(out1) 64 | 65 | if self.downsample is not None: 66 | residual = self.downsample(x) 67 | 68 | res = out1 * out + residual 69 | res = self.relu(res) 70 | 71 | return res 72 | -------------------------------------------------------------------------------- /model/blocks/se_resnext_blocks.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2018 Intel Corporation 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | http://www.apache.org/licenses/LICENSE-2.0 7 | Unless required by applicable law or agreed to in writing, software 8 | distributed under the License is distributed on an "AS IS" BASIS, 9 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 10 | See the License for the specific language governing permissions and 11 | limitations under the License. 
12 | """ 13 | 14 | import torch.nn as nn 15 | 16 | from model.blocks.shared_blocks import SELayer 17 | 18 | 19 | class SEBottleneckX(nn.Module): 20 | expansion = 4 21 | 22 | def __init__(self, inplanes, planes, cardinality, stride=1, downsample=None): 23 | super(SEBottleneckX, self).__init__() 24 | self.conv1 = nn.Conv2d(inplanes, planes * 2, kernel_size=1, bias=False) 25 | self.bn1 = nn.BatchNorm2d(planes * 2) 26 | 27 | self.conv2 = nn.Conv2d(planes * 2, planes * 2, kernel_size=3, stride=stride, 28 | padding=1, groups=cardinality, bias=False) 29 | self.bn2 = nn.BatchNorm2d(planes * 2) 30 | 31 | self.conv3 = nn.Conv2d(planes * 2, planes * 4, kernel_size=1, bias=False) 32 | self.bn3 = nn.BatchNorm2d(planes * 4) 33 | 34 | self.selayer = SELayer(planes * 4, 16, nn.ReLU) 35 | 36 | self.relu = nn.ReLU(inplace=True) 37 | self.downsample = downsample 38 | self.stride = stride 39 | 40 | def forward(self, x): 41 | residual = x 42 | 43 | out = self.conv1(x) 44 | out = self.bn1(out) 45 | out = self.relu(out) 46 | 47 | out = self.conv2(out) 48 | out = self.bn2(out) 49 | out = self.relu(out) 50 | 51 | out = self.conv3(out) 52 | out = self.bn3(out) 53 | 54 | out = self.selayer(out) 55 | 56 | if self.downsample is not None: 57 | residual = self.downsample(x) 58 | 59 | out += residual 60 | out = self.relu(out) 61 | 62 | return out 63 | -------------------------------------------------------------------------------- /model/blocks/shared_blocks.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2018 Intel Corporation 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | http://www.apache.org/licenses/LICENSE-2.0 7 | Unless required by applicable law or agreed to in writing, software 8 | distributed under the License is distributed on an "AS IS" BASIS, 9 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 10 | See the License for the specific language governing permissions and 11 | limitations under the License. 
12 | """ 13 | 14 | import torch 15 | import torch.nn as nn 16 | 17 | 18 | def make_activation(activation): 19 | """Factory for activation functions""" 20 | if activation != nn.PReLU: 21 | return activation(inplace=True) 22 | 23 | return activation() 24 | 25 | 26 | class SELayer(nn.Module): 27 | """Implementation of the Squeeze-Excitaion layer from https://arxiv.org/abs/1709.01507""" 28 | def __init__(self, inplanes, squeeze_ratio=8, activation=nn.PReLU, size=None): 29 | super(SELayer, self).__init__() 30 | assert squeeze_ratio >= 1 31 | assert inplanes > 0 32 | if size is not None: 33 | self.global_avgpool = nn.AvgPool2d(size) 34 | else: 35 | self.global_avgpool = nn.AdaptiveAvgPool2d(1) 36 | self.conv1 = nn.Conv2d(inplanes, int(inplanes / squeeze_ratio), kernel_size=1, stride=1) 37 | self.conv2 = nn.Conv2d(int(inplanes / squeeze_ratio), inplanes, kernel_size=1, stride=1) 38 | self.relu = make_activation(activation) 39 | self.sigmoid = nn.Sigmoid() 40 | 41 | def forward(self, x): 42 | out = self.global_avgpool(x) 43 | out = self.conv1(out) 44 | out = self.relu(out) 45 | out = self.conv2(out) 46 | out = self.sigmoid(out) 47 | return x * out 48 | 49 | 50 | class ScaleFilter(nn.Module): 51 | """Implementaion of the ScaleFilter regularizer""" 52 | def __init__(self, q): 53 | super(ScaleFilter, self).__init__() 54 | assert 0 < q < 1 55 | self.q = q 56 | 57 | def forward(self, x): 58 | if not self.training: 59 | return x 60 | 61 | scale_factors = 1. + self.q \ 62 | - 2*self.q*torch.rand(x.shape[1], 1, 1, dtype=torch.float32, requires_grad=False).to(x.device) 63 | return x * scale_factors 64 | -------------------------------------------------------------------------------- /model/blocks/shufflenet_v2_blocks.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2018 Intel Corporation 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | http://www.apache.org/licenses/LICENSE-2.0 7 | Unless required by applicable law or agreed to in writing, software 8 | distributed under the License is distributed on an "AS IS" BASIS, 9 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 10 | See the License for the specific language governing permissions and 11 | limitations under the License. 
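A minimal illustration of SELayer with assumed sizes: it squeezes the input to a per-channel descriptor, turns it into gates in [0, 1] through the sigmoid, and rescales the input, so the output shape always matches the input:

import torch
import torch.nn as nn
from model.blocks.shared_blocks import SELayer

se = SELayer(inplanes=64, squeeze_ratio=8, activation=nn.PReLU).eval()
x = torch.rand(2, 64, 16, 16)
with torch.no_grad():
    y = se(x)
print(y.shape)            # torch.Size([2, 64, 16, 16])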
12 | """ 13 | 14 | import torch 15 | import torch.nn as nn 16 | from model.blocks.shared_blocks import make_activation 17 | 18 | 19 | def conv_bn(inp, oup, stride, activation=nn.ReLU): 20 | conv = nn.Sequential( 21 | nn.Conv2d(inp, oup, 3, stride, 1, bias=False), 22 | nn.BatchNorm2d(oup), 23 | make_activation(activation) 24 | ) 25 | nn.init.kaiming_normal_(conv[0].weight, mode='fan_out') 26 | return conv 27 | 28 | 29 | def conv_1x1_bn(inp, oup, activation=nn.ReLU): 30 | conv = nn.Sequential( 31 | nn.Conv2d(inp, oup, 1, 1, 0, bias=False), 32 | nn.BatchNorm2d(oup), 33 | make_activation(activation) 34 | ) 35 | nn.init.kaiming_normal_(conv[0].weight, mode='fan_out') 36 | return conv 37 | 38 | 39 | def channel_shuffle(x, groups): 40 | batchsize, num_channels, height, width = x.data.size() 41 | channels_per_group = num_channels // groups 42 | # reshape 43 | x = x.view(batchsize, groups, channels_per_group, height, width) 44 | x = torch.transpose(x, 1, 2).contiguous() 45 | # flatten 46 | x = x.view(batchsize, -1, height, width) 47 | return x 48 | 49 | 50 | class ShuffleInvertedResidual(nn.Module): 51 | def __init__(self, inp, oup, stride, benchmodel, activation=nn.ReLU): 52 | super(ShuffleInvertedResidual, self).__init__() 53 | self.benchmodel = benchmodel 54 | self.stride = stride 55 | assert stride in [1, 2] 56 | 57 | oup_inc = oup//2 58 | 59 | if self.benchmodel == 1: 60 | # assert inp == oup_inc 61 | self.branch2 = nn.Sequential( 62 | # pw 63 | nn.Conv2d(oup_inc, oup_inc, 1, 1, 0, bias=False), 64 | nn.BatchNorm2d(oup_inc), 65 | make_activation(activation), 66 | # dw 67 | nn.Conv2d(oup_inc, oup_inc, 3, stride, 1, groups=oup_inc, bias=False), 68 | nn.BatchNorm2d(oup_inc), 69 | # pw-linear 70 | nn.Conv2d(oup_inc, oup_inc, 1, 1, 0, bias=False), 71 | nn.BatchNorm2d(oup_inc), 72 | make_activation(activation), 73 | ) 74 | else: 75 | self.branch1 = nn.Sequential( 76 | # dw 77 | nn.Conv2d(inp, inp, 3, stride, 1, groups=inp, bias=False), 78 | nn.BatchNorm2d(inp), 79 | # pw-linear 80 | nn.Conv2d(inp, oup_inc, 1, 1, 0, bias=False), 81 | nn.BatchNorm2d(oup_inc), 82 | make_activation(activation), 83 | ) 84 | 85 | self.branch2 = nn.Sequential( 86 | # pw 87 | nn.Conv2d(inp, oup_inc, 1, 1, 0, bias=False), 88 | nn.BatchNorm2d(oup_inc), 89 | make_activation(activation), 90 | # dw 91 | nn.Conv2d(oup_inc, oup_inc, 3, stride, 1, groups=oup_inc, bias=False), 92 | nn.BatchNorm2d(oup_inc), 93 | # pw-linear 94 | nn.Conv2d(oup_inc, oup_inc, 1, 1, 0, bias=False), 95 | nn.BatchNorm2d(oup_inc), 96 | make_activation(activation), 97 | ) 98 | self.init_weights() 99 | 100 | @staticmethod 101 | def _concat(x, out): 102 | # concatenate along channel axis 103 | return torch.cat((x, out), 1) 104 | 105 | def init_weights(self): 106 | for m in self.children(): 107 | if isinstance(m, nn.Conv2d): 108 | nn.init.kaiming_normal_(m.weight, mode='fan_out') 109 | elif isinstance(m, nn.BatchNorm2d): 110 | nn.init.constant_(m.weight, 1) 111 | nn.init.constant_(m.bias, 0) 112 | 113 | def forward(self, x): 114 | if self.benchmodel == 1: 115 | x1 = x[:, :(x.shape[1]//2), :, :] 116 | x2 = x[:, (x.shape[1]//2):, :, :] 117 | out = self._concat(x1, self.branch2(x2)) 118 | elif self.benchmodel == 2: 119 | out = self._concat(self.branch1(x), self.branch2(x)) 120 | 121 | return channel_shuffle(out, 2) 122 | -------------------------------------------------------------------------------- /model/common.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2018 Intel Corporation 3 | Licensed under 
the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | http://www.apache.org/licenses/LICENSE-2.0 7 | Unless required by applicable law or agreed to in writing, software 8 | distributed under the License is distributed on an "AS IS" BASIS, 9 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 10 | See the License for the specific language governing permissions and 11 | limitations under the License. 12 | """ 13 | from abc import abstractmethod 14 | import torch.nn as nn 15 | 16 | 17 | class ModelInterface(nn.Module): 18 | """Abstract class for models""" 19 | 20 | @abstractmethod 21 | def set_dropout_ratio(self, ratio): 22 | """Sets dropout ratio of the model""" 23 | 24 | @abstractmethod 25 | def get_input_res(self): 26 | """Returns input resolution""" 27 | 28 | 29 | from .rmnet_angular import RMNetAngular 30 | from .mobilefacenet import MobileFaceNet 31 | from .landnet import LandmarksNet 32 | from .resnet_angular import ResNetAngular 33 | from .se_resnet_angular import SEResNetAngular 34 | from .shufflenet_v2_angular import ShuffleNetV2Angular 35 | 36 | 37 | models_backbones = {'rmnet': RMNetAngular, 'mobilenet': MobileFaceNet, 'resnet': ResNetAngular, 38 | 'shufflenetv2': ShuffleNetV2Angular, 'se_resnet': SEResNetAngular} 39 | 40 | models_landmarks = {'landnet': LandmarksNet} 41 | -------------------------------------------------------------------------------- /model/landnet.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2018 Intel Corporation 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | http://www.apache.org/licenses/LICENSE-2.0 7 | Unless required by applicable law or agreed to in writing, software 8 | distributed under the License is distributed on an "AS IS" BASIS, 9 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 10 | See the License for the specific language governing permissions and 11 | limitations under the License. 
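The two dictionaries above serve as the model registry used by the training, evaluation and conversion scripts below. A minimal usage sketch, assuming the repository root is on PYTHONPATH:

from model.common import models_backbones, models_landmarks

backbone = models_backbones['mobilenet'](embedding_size=128, feature=True)
print(backbone.get_input_res())        # (128, 128)

landmarks_net = models_landmarks['landnet']()
print(landmarks_net.get_input_res())   # (48, 48)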
12 | """ 13 | 14 | import torch.nn as nn 15 | 16 | from .common import ModelInterface 17 | 18 | 19 | class LandmarksNet(ModelInterface): 20 | """Facial landmarks localization network""" 21 | def __init__(self): 22 | super(LandmarksNet, self).__init__() 23 | self.bn_first = nn.BatchNorm2d(3) 24 | activation = nn.PReLU 25 | self.landnet = nn.Sequential( 26 | nn.Conv2d(3, 16, kernel_size=3, padding=1), 27 | activation(), 28 | nn.MaxPool2d(2, stride=2), 29 | nn.BatchNorm2d(16), 30 | nn.Conv2d(16, 32, kernel_size=3, padding=1), 31 | activation(), 32 | nn.MaxPool2d(2, stride=2), 33 | nn.BatchNorm2d(32), 34 | nn.Conv2d(32, 64, kernel_size=3, padding=1), 35 | activation(), 36 | nn.MaxPool2d(2, stride=2), 37 | nn.BatchNorm2d(64), 38 | nn.Conv2d(64, 64, kernel_size=3, padding=1), 39 | activation(), 40 | nn.BatchNorm2d(64), 41 | nn.Conv2d(64, 128, kernel_size=3, padding=1), 42 | activation(), 43 | nn.BatchNorm2d(128) 44 | ) 45 | # dw pooling 46 | self.bottleneck_size = 256 47 | self.pool = nn.Sequential( 48 | nn.Conv2d(128, 128, kernel_size=6, padding=0, groups=128), 49 | activation(), 50 | nn.BatchNorm2d(128), 51 | nn.Conv2d(128, self.bottleneck_size, kernel_size=1, padding=0), 52 | activation(), 53 | nn.BatchNorm2d(self.bottleneck_size), 54 | ) 55 | # Regressor for 5 landmarks (10 coordinates) 56 | self.fc_loc = nn.Sequential( 57 | nn.Conv2d(self.bottleneck_size, 64, kernel_size=1), 58 | activation(), 59 | nn.Conv2d(64, 10, kernel_size=1), 60 | nn.Sigmoid() 61 | ) 62 | 63 | def forward(self, x): 64 | xs = self.landnet(self.bn_first(x)) 65 | xs = self.pool(xs) 66 | xs = self.fc_loc(xs) 67 | return xs 68 | 69 | def get_input_res(self): 70 | return 48, 48 71 | 72 | def set_dropout_ratio(self, ratio): 73 | pass 74 | -------------------------------------------------------------------------------- /model/mobilefacenet.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2018 Intel Corporation 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | http://www.apache.org/licenses/LICENSE-2.0 7 | Unless required by applicable law or agreed to in writing, software 8 | distributed under the License is distributed on an "AS IS" BASIS, 9 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 10 | See the License for the specific language governing permissions and 11 | limitations under the License. 
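A small sketch of the regressor above: it consumes a 48x48 face crop and emits ten values in [0, 1], i.e. five (x, y) landmark coordinates relative to the crop:

import torch
from model.landnet import LandmarksNet

net = LandmarksNet().eval()
x = torch.rand(1, 3, *net.get_input_res())   # 48 x 48 crop of a detected face
with torch.no_grad():
    pts = net(x)
print(pts.shape)           # torch.Size([1, 10, 1, 1])
print(pts.view(-1, 2))     # five (x, y) pairs, kept in [0, 1] by the final Sigmoid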
12 | """ 13 | 14 | import math 15 | import torch.nn as nn 16 | 17 | from losses.am_softmax import AngleSimpleLinear 18 | from model.blocks.mobilenet_v2_blocks import InvertedResidual 19 | from model.blocks.shared_blocks import make_activation 20 | from .common import ModelInterface 21 | 22 | 23 | def init_block(in_channels, out_channels, stride, activation=nn.PReLU): 24 | """Builds the first block of the MobileFaceNet""" 25 | return nn.Sequential( 26 | nn.BatchNorm2d(3), 27 | nn.Conv2d(in_channels, out_channels, 3, stride, 1, bias=False), 28 | nn.BatchNorm2d(out_channels), 29 | make_activation(activation) 30 | ) 31 | 32 | 33 | class MobileFaceNet(ModelInterface): 34 | """Implements modified MobileFaceNet from https://arxiv.org/abs/1804.07573""" 35 | def __init__(self, embedding_size=128, num_classes=1, width_multiplier=1., feature=True): 36 | super(MobileFaceNet, self).__init__() 37 | assert embedding_size > 0 38 | assert num_classes > 0 39 | assert width_multiplier > 0 40 | self.feature = feature 41 | 42 | # Set up of inverted residual blocks 43 | inverted_residual_setting = [ 44 | # t, c, n, s 45 | [2, 64, 5, 2], 46 | [4, 128, 1, 2], 47 | [2, 128, 6, 1], 48 | [4, 128, 1, 2], 49 | [2, 128, 2, 1] 50 | ] 51 | 52 | first_channel_num = 64 53 | last_channel_num = 512 54 | self.features = [init_block(3, first_channel_num, 2)] 55 | 56 | self.features.append(nn.Conv2d(first_channel_num, first_channel_num, 3, 1, 1, 57 | groups=first_channel_num, bias=False)) 58 | self.features.append(nn.BatchNorm2d(64)) 59 | self.features.append(nn.PReLU()) 60 | 61 | # Inverted Residual Blocks 62 | in_channel_num = first_channel_num 63 | size_h, size_w = MobileFaceNet.get_input_res() 64 | size_h, size_w = size_h // 2, size_w // 2 65 | for t, c, n, s in inverted_residual_setting: 66 | output_channel = int(c * width_multiplier) 67 | for i in range(n): 68 | if i == 0: 69 | size_h, size_w = size_h // s, size_w // s 70 | self.features.append(InvertedResidual(in_channel_num, output_channel, 71 | s, t, outp_size=(size_h, size_w))) 72 | else: 73 | self.features.append(InvertedResidual(in_channel_num, output_channel, 74 | 1, t, outp_size=(size_h, size_w))) 75 | in_channel_num = output_channel 76 | 77 | # 1x1 expand block 78 | self.features.append(nn.Sequential(nn.Conv2d(in_channel_num, last_channel_num, 1, 1, 0, bias=False), 79 | nn.BatchNorm2d(last_channel_num), 80 | nn.PReLU())) 81 | self.features = nn.Sequential(*self.features) 82 | 83 | # Depth-wise pooling 84 | k_size = (MobileFaceNet.get_input_res()[0] // 16, MobileFaceNet.get_input_res()[1] // 16) 85 | self.dw_pool = nn.Conv2d(last_channel_num, last_channel_num, k_size, 86 | groups=last_channel_num, bias=False) 87 | self.dw_bn = nn.BatchNorm2d(last_channel_num) 88 | self.conv1_extra = nn.Conv2d(last_channel_num, embedding_size, 1, stride=1, padding=0, bias=False) 89 | 90 | if not self.feature: 91 | self.fc_angular = AngleSimpleLinear(embedding_size, num_classes) 92 | 93 | self.init_weights() 94 | 95 | def forward(self, x): 96 | x = self.features(x) 97 | x = self.dw_bn(self.dw_pool(x)) 98 | x = self.conv1_extra(x) 99 | 100 | if self.feature or not self.training: 101 | return x 102 | 103 | x = x.view(x.size(0), -1) 104 | y = self.fc_angular(x) 105 | 106 | return x, y 107 | 108 | @staticmethod 109 | def get_input_res(): 110 | return 128, 128 111 | 112 | def set_dropout_ratio(self, ratio): 113 | assert 0 <= ratio < 1. 
114 | 115 | def init_weights(self): 116 | """Initializes weights of the model before training""" 117 | for m in self.modules(): 118 | if isinstance(m, nn.Conv2d): 119 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 120 | m.weight.data.normal_(0, math.sqrt(2. / n)) 121 | if m.bias is not None: 122 | m.bias.data.zero_() 123 | elif isinstance(m, nn.BatchNorm2d): 124 | m.weight.data.fill_(1) 125 | m.bias.data.zero_() 126 | elif isinstance(m, nn.Linear): 127 | n = m.weight.size(1) 128 | m.weight.data.normal_(0, 0.01) 129 | m.bias.data.zero_() 130 | -------------------------------------------------------------------------------- /model/resnet_angular.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2018 Intel Corporation 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | http://www.apache.org/licenses/LICENSE-2.0 7 | Unless required by applicable law or agreed to in writing, software 8 | distributed under the License is distributed on an "AS IS" BASIS, 9 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 10 | See the License for the specific language governing permissions and 11 | limitations under the License. 12 | """ 13 | 14 | import torch.nn as nn 15 | from losses.am_softmax import AngleSimpleLinear 16 | from model.backbones.resnet import resnet50 17 | from .common import ModelInterface 18 | 19 | 20 | class ResNetAngular(ModelInterface): 21 | """Face reid head for the ResNet architecture""" 22 | def __init__(self, embedding_size=128, num_classes=0, feature=True): 23 | super(ResNetAngular, self).__init__() 24 | 25 | self.bn_first = nn.BatchNorm2d(3) 26 | self.feature = feature 27 | self.model = resnet50(num_classes=embedding_size, activation=nn.PReLU) 28 | self.embedding_size = embedding_size 29 | 30 | if not self.feature: 31 | self.fc_angular = AngleSimpleLinear(self.embedding_size, num_classes) 32 | 33 | def forward(self, x): 34 | 35 | x = self.bn_first(x) 36 | x = self.model(x) 37 | 38 | if self.feature or not self.training: 39 | return x 40 | 41 | x = x.view(x.size(0), -1) 42 | y = self.fc_angular(x) 43 | 44 | return x, y 45 | 46 | @staticmethod 47 | def get_input_res(): 48 | return 112, 112 49 | 50 | def set_dropout_ratio(self, ratio): 51 | assert 0 <= ratio < 1. 52 | -------------------------------------------------------------------------------- /model/rmnet_angular.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2018 Intel Corporation 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | http://www.apache.org/licenses/LICENSE-2.0 7 | Unless required by applicable law or agreed to in writing, software 8 | distributed under the License is distributed on an "AS IS" BASIS, 9 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 10 | See the License for the specific language governing permissions and 11 | limitations under the License. 
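A hedged usage sketch for the angular heads such as ResNetAngular above: with feature=True (the default) the forward pass returns only the embedding, while feature=False together with num_classes adds the AngleSimpleLinear head and, in train() mode, returns the embeddings alongside the angular logits:

import torch
from model.resnet_angular import ResNetAngular

net = ResNetAngular(embedding_size=128, feature=True).eval()
x = torch.rand(1, 3, *ResNetAngular.get_input_res())   # 112 x 112
with torch.no_grad():
    emb = net(x)
print(emb.shape)           # torch.Size([1, 128, 1, 1])

# Training configuration (sketch): ResNetAngular(embedding_size=128, num_classes=N, feature=False)
# returns (embeddings, angular_logits) when the module is in train() mode.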
12 | """ 13 | 14 | import torch.nn as nn 15 | 16 | from losses.am_softmax import AngleSimpleLinear 17 | from model.backbones.rmnet import RMNetBody 18 | from model.blocks.rmnet_blocks import RMBlock 19 | from .common import ModelInterface 20 | 21 | 22 | class RMNetAngular(ModelInterface): 23 | """Face reid head for the ResMobNet architecture. See https://arxiv.org/pdf/1812.02465.pdf for details 24 | about the ResMobNet backbone.""" 25 | def __init__(self, embedding_size, num_classes=0, feature=True, body=RMNetBody): 26 | super(RMNetAngular, self).__init__() 27 | self.feature = feature 28 | self.backbone = body() 29 | self.global_pooling = nn.MaxPool2d((8, 8)) 30 | self.conv1_extra = nn.Conv2d(256, embedding_size, 1, stride=1, padding=0, bias=False) 31 | if not feature: 32 | self.fc_angular = AngleSimpleLinear(embedding_size, num_classes) 33 | 34 | def forward(self, x): 35 | x = self.backbone(x) 36 | x = self.global_pooling(x) 37 | x = self.conv1_extra(x) 38 | 39 | if self.feature or not self.training: 40 | return x 41 | 42 | x = x.view(x.size(0), -1) 43 | y = self.fc_angular(x) 44 | 45 | return x, y 46 | 47 | def set_dropout_ratio(self, ratio): 48 | assert 0 <= ratio < 1. 49 | 50 | for m in self.backbone.modules(): 51 | if isinstance(m, RMBlock): 52 | m.dropout_ratio = ratio 53 | 54 | @staticmethod 55 | def get_input_res(): 56 | return 128, 128 57 | -------------------------------------------------------------------------------- /model/se_resnet_angular.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2018 Intel Corporation 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | http://www.apache.org/licenses/LICENSE-2.0 7 | Unless required by applicable law or agreed to in writing, software 8 | distributed under the License is distributed on an "AS IS" BASIS, 9 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 10 | See the License for the specific language governing permissions and 11 | limitations under the License. 12 | """ 13 | 14 | import torch.nn as nn 15 | 16 | from losses.am_softmax import AngleSimpleLinear 17 | from model.backbones.se_resnet import se_resnet50 18 | from .common import ModelInterface 19 | 20 | 21 | class SEResNetAngular(ModelInterface): 22 | """Face reid head for the SE ResNet architecture""" 23 | def __init__(self, embedding_size=128, num_classes=0, feature=True): 24 | super(SEResNetAngular, self).__init__() 25 | 26 | self.bn_first = nn.BatchNorm2d(3) 27 | self.feature = feature 28 | self.model = se_resnet50(num_classes=embedding_size, activation=nn.PReLU) 29 | self.embedding_size = embedding_size 30 | 31 | if not self.feature: 32 | self.fc_angular = AngleSimpleLinear(self.embedding_size, num_classes) 33 | 34 | def forward(self, x): 35 | x = self.bn_first(x) 36 | x = self.model(x) 37 | 38 | if self.feature or not self.training: 39 | return x 40 | 41 | x = x.view(x.size(0), -1) 42 | y = self.fc_angular(x) 43 | 44 | return x, y 45 | 46 | @staticmethod 47 | def get_input_res(): 48 | return 112, 112 49 | 50 | def set_dropout_ratio(self, ratio): 51 | assert 0 <= ratio < 1. 
52 | -------------------------------------------------------------------------------- /model/shufflenet_v2_angular.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2018 Intel Corporation 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | http://www.apache.org/licenses/LICENSE-2.0 7 | Unless required by applicable law or agreed to in writing, software 8 | distributed under the License is distributed on an "AS IS" BASIS, 9 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 10 | See the License for the specific language governing permissions and 11 | limitations under the License. 12 | """ 13 | 14 | import torch.nn as nn 15 | 16 | from losses.am_softmax import AngleSimpleLinear 17 | from model.backbones.shufflenet_v2 import ShuffleNetV2Body 18 | from .common import ModelInterface 19 | 20 | 21 | class ShuffleNetV2Angular(ModelInterface): 22 | """Face reid head for the ShuffleNetV2 architecture""" 23 | def __init__(self, embedding_size, num_classes=0, feature=True, body=ShuffleNetV2Body, **kwargs): 24 | super(ShuffleNetV2Angular, self).__init__() 25 | self.feature = feature 26 | kwargs['input_size'] = ShuffleNetV2Angular.get_input_res()[0] 27 | kwargs['width_mult'] = 1. 28 | self.backbone = body(**kwargs) 29 | k_size = int(kwargs['input_size'] / self.backbone.get_downscale_factor()) 30 | self.global_pool = nn.Conv2d(self.backbone.stage_out_channels[-1], self.backbone.stage_out_channels[-1], 31 | (k_size, k_size), groups=self.backbone.stage_out_channels[-1], bias=False) 32 | self.conv1_extra = nn.Conv2d(self.backbone.get_num_output_channels(), embedding_size, 1, padding=0, bias=False) 33 | if not feature: 34 | self.fc_angular = AngleSimpleLinear(embedding_size, num_classes) 35 | 36 | def forward(self, x): 37 | x = self.backbone(x) 38 | x = self.global_pool(x) 39 | x = self.conv1_extra(x) 40 | 41 | if self.feature or not self.training: 42 | return x 43 | 44 | x = x.view(x.size(0), -1) 45 | y = self.fc_angular(x) 46 | 47 | return x, y 48 | 49 | def set_dropout_ratio(self, ratio): 50 | assert 0 <= ratio < 1. 51 | 52 | @staticmethod 53 | def get_input_res(): 54 | res = 128 55 | return res, res 56 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | glog==0.3.1 2 | numpy==1.15.4 3 | opencv-python==3.4.4.19 4 | Pillow==5.3.0 5 | protobuf==3.6.1 6 | python-gflags==3.1.2 7 | scipy==1.1.0 8 | six==1.11.0 9 | tensorboardX==1.4 10 | torch==0.4.1 11 | torchvision==0.2.1 12 | tqdm==4.28.1 13 | pyyaml>=3.12 14 | ptflops==0.1 15 | -------------------------------------------------------------------------------- /scripts/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/grib0ed0v/face_recognition.pytorch/05cb9b30e8220445fcb27988926d88f330091c12/scripts/__init__.py -------------------------------------------------------------------------------- /scripts/accuracy_check.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2018 Intel Corporation 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 
5 | You may obtain a copy of the License at 6 | http://www.apache.org/licenses/LICENSE-2.0 7 | Unless required by applicable law or agreed to in writing, software 8 | distributed under the License is distributed on an "AS IS" BASIS, 9 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 10 | See the License for the specific language governing permissions and 11 | limitations under the License. 12 | """ 13 | 14 | import argparse 15 | import glog as log 16 | 17 | import numpy as np 18 | import torch 19 | from tqdm import tqdm 20 | import cv2 as cv 21 | 22 | from utils.utils import load_model_state 23 | from utils.ie_tools import load_ie_model 24 | from model.common import models_backbones, models_landmarks 25 | 26 | def main(): 27 | """Runs the accuracy check""" 28 | parser = argparse.ArgumentParser(description='Accuracy check script (pt vs caffe)') 29 | parser.add_argument('--embed_size', type=int, default=128, help='Size of the face embedding.') 30 | parser.add_argument('--snap', type=str, required=True, help='Snapshot to convert.') 31 | parser.add_argument('--device', '-d', default=0, type=int, help='Device for model placement.') 32 | parser.add_argument('--model', choices=list(models_backbones.keys()) + list(models_landmarks.keys()), type=str, 33 | default='rmnet') 34 | 35 | # IE-related options 36 | parser.add_argument('--ie_model', type=str, required=True) 37 | parser.add_argument("-l", "--cpu_extension", 38 | help="MKLDNN (CPU)-targeted custom layers.Absolute path to a shared library with the kernels " 39 | "impl.", type=str, default=None) 40 | parser.add_argument("-pp", "--plugin_dir", help="Path to a plugin folder", type=str, default=None) 41 | parser.add_argument("-d_ie", "--device_ie", 42 | help="Specify the target device to infer on; CPU, GPU, FPGA or MYRIAD is acceptable. Sample " 43 | "will look for a suitable plugin for device specified (CPU by default)", default="CPU", 44 | type=str) 45 | 46 | args = parser.parse_args() 47 | 48 | max_err = 0. 49 | with torch.cuda.device(args.device): 50 | if args.model in models_landmarks.keys(): 51 | pt_model = models_landmarks[args.model] 52 | else: 53 | pt_model = models_backbones[args.model](embedding_size=args.embed_size, feature=True) 54 | pt_model = load_model_state(pt_model, args.snap, args.device) 55 | 56 | ie_model = load_ie_model(args.ie_model, args.device_ie, args.plugin_dir, args.cpu_extension) 57 | np.random.seed(0) 58 | 59 | for _ in tqdm(range(100)): 60 | input_img = np.random.randint(0, high=255, size=(*pt_model.get_input_res(), 3), dtype=np.uint8) 61 | input_bgr = cv.cvtColor(input_img, cv.COLOR_BGR2RGB) 62 | 63 | input_pt = torch.unsqueeze(torch.from_numpy(input_img.transpose(2, 0, 1).astype('float32') / 255.).cuda(), 64 | dim=0) 65 | pt_output = (pt_model(input_pt)).data.cpu().numpy().reshape(1, -1) 66 | ie_output = ie_model.forward(input_bgr).reshape(1, -1) 67 | 68 | max_err = max(np.linalg.norm(pt_output - ie_output, np.inf), max_err) 69 | 70 | log.info('Max l_inf error: %e', max_err) 71 | 72 | if __name__ == '__main__': 73 | main() 74 | -------------------------------------------------------------------------------- /scripts/align_images.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2018 Intel Corporation 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 
5 | You may obtain a copy of the License at 6 | http://www.apache.org/licenses/LICENSE-2.0 7 | Unless required by applicable law or agreed to in writing, software 8 | distributed under the License is distributed on an "AS IS" BASIS, 9 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 10 | See the License for the specific language governing permissions and 11 | limitations under the License. 12 | """ 13 | 14 | import argparse 15 | import os 16 | import os.path as osp 17 | import json 18 | 19 | import cv2 as cv 20 | import torch 21 | from tqdm import tqdm 22 | from torchvision.transforms import transforms 23 | 24 | from model import landnet 25 | from utils import utils 26 | from utils import augmentation 27 | from utils.face_align import FivePointsAligner 28 | 29 | class LandnetPT: 30 | """Wrapper for landmarks regression model""" 31 | def __init__(self, model): 32 | self.net = model 33 | self.transformer = transforms.Compose( 34 | [augmentation.ResizeNumpy((48, 48)), augmentation.NumpyToTensor(switch_rb=True)]) 35 | 36 | def get_landmarks(self, batch): 37 | converted_batch = [] 38 | for item in batch: 39 | converted_batch.append(self.transformer(item)) 40 | pt_blob = torch.stack(converted_batch).cuda() 41 | landmarks = self.net(pt_blob) 42 | return landmarks.data.cpu().numpy() 43 | 44 | 45 | class FaceDetector: 46 | """Wrapper class for face detector""" 47 | def __init__(self, proto, model, conf=.6, expand_ratio=(1.1, 1.05), size=(300, 300)): 48 | self.net = cv.dnn.readNetFromCaffe(proto, model) 49 | self.net.setPreferableBackend(cv.dnn.DNN_BACKEND_DEFAULT) 50 | self.net.setPreferableTarget(cv.dnn.DNN_TARGET_CPU) 51 | last_layer_id = self.net.getLayerId(self.net.getLayerNames()[-1]) 52 | last_layer = self.net.getLayer(last_layer_id) 53 | assert last_layer.type == 'DetectionOutput' 54 | 55 | self.confidence = conf 56 | self.expand_ratio = expand_ratio 57 | self.det_res = size 58 | 59 | def __decode_detections(self, out, frame_shape): 60 | """Decodes raw SSD output""" 61 | frame_height = frame_shape[0] 62 | frame_width = frame_shape[1] 63 | detections = [] 64 | 65 | for detection in out[0, 0]: 66 | confidence = detection[2] 67 | if confidence > self.confidence: 68 | left = int(max(detection[3], 0) * frame_width) 69 | top = int(max(detection[4], 0) * frame_height) 70 | right = int(max(detection[5], 0) * frame_width) 71 | bottom = int(max(detection[6], 0) * frame_height) 72 | if self.expand_ratio != (1., 1.): 73 | w = (right - left) 74 | h = (bottom - top) 75 | dw = w * (self.expand_ratio[0] - 1.) / 2 76 | dh = h * (self.expand_ratio[1] - 1.) 
/ 2 77 | left = max(int(left - dw), 0) 78 | right = int(right + dw) 79 | top = max(int(top - dh), 0) 80 | bottom = int(bottom + dh) 81 | 82 | # classId = int(detection[1]) - 1 # Skip background label 83 | detections.append(((left, top, right, bottom), confidence)) 84 | 85 | if len(detections) > 1: 86 | detections.sort(key=lambda x: x[1], reverse=True) 87 | 88 | return detections 89 | 90 | def get_detections(self, frame): 91 | """Returns all detections on frame""" 92 | blob = cv.dnn.blobFromImage(frame, 1., (self.det_res[0], self.det_res[1]), crop=False) 93 | self.net.setInput(blob) 94 | out = self.net.forward() 95 | detections = self.__decode_detections(out, frame.shape) 96 | return detections 97 | 98 | 99 | def draw_detections(frame, detections, landmarks): 100 | """Draw detections and landmarks on a frame""" 101 | for _, rect in enumerate(detections): 102 | left, top, right, bottom = rect 103 | cv.rectangle(frame, (left, top), (right, bottom), (0, 255, 0), thickness=2) 104 | for point in landmarks.reshape(-1, 2): 105 | point = (int(left + point[0] * (right - left)), int(top + point[1] * (bottom - top))) 106 | cv.circle(frame, point, 5, (255, 0, 0), -1) 107 | 108 | return frame 109 | 110 | 111 | def run_dumping(images_list, face_det, landmarks_regressor, vis_flag): 112 | """Dumps detections and landmarks from images""" 113 | detected_num = 0 114 | data = [] 115 | for path in tqdm(images_list, 'Dumping data'): 116 | image = cv.imread(path, cv.IMREAD_COLOR) 117 | if image is None: 118 | continue 119 | 120 | detections = face_det.get_detections(image) 121 | landmarks = None 122 | if detections: 123 | left, top, right, bottom = detections[0][0] 124 | roi = image[top:bottom, left:right] 125 | landmarks = landmarks_regressor.get_landmarks([roi]).reshape(-1) 126 | data.append({'path': path, 'bbox': detections[0][0], 'landmarks': landmarks}) 127 | detected_num += 1 128 | if vis_flag: 129 | FivePointsAligner.align(roi, landmarks, 130 | d_size=(200,200), normalize=False, show=True) 131 | else: 132 | data.append({'path': path, 'bbox': None, 'landmarks': None}) 133 | 134 | print('Detection ratio: ', float(detected_num) / len(data)) 135 | 136 | return data 137 | 138 | 139 | def create_images_list(images_root, imgs_list): 140 | input_filenames = [] 141 | input_dir = os.path.abspath(images_root) 142 | 143 | if imgs_list is None: 144 | stop = False 145 | for path, _, files in os.walk(input_dir): 146 | if stop: 147 | break 148 | for name in files: 149 | if name.lower().endswith('.jpg') or name.lower().endswith('.png') \ 150 | or name.lower().endswith('.jpeg') or name.lower().endswith('.gif') \ 151 | or not '.' 
in name: 152 | filename = os.path.join(path, name) 153 | input_filenames.append(filename) 154 | else: 155 | with open(imgs_list) as f: 156 | data = json.load(f) 157 | for path in data['path']: 158 | filename = osp.join(images_root, path) 159 | input_filenames.append(filename) 160 | 161 | return input_filenames 162 | 163 | 164 | def save_data(data, filename, root_dir): 165 | print('Saving data...') 166 | with open(filename, 'w') as f: 167 | for instance in data: 168 | line = osp.relpath(instance['path'], start=root_dir) + ' | ' 169 | if instance['bbox'] is not None: 170 | for x in instance['landmarks']: 171 | line += str(x) + ' ' 172 | line += ' | ' 173 | left, top, right, bottom = instance['bbox'] 174 | line += str(left) + ' ' + str(top) + ' ' + str(right - left) + ' ' + str(bottom - top) 175 | 176 | f.write(line.strip() + '\n') 177 | 178 | def main(): 179 | parser = argparse.ArgumentParser(description='') 180 | parser.add_argument('--images_root', type=str, default=None, required=True) 181 | parser.add_argument('--images_list', type=str, default=None, required=False) 182 | parser.add_argument('--fd_proto', type=str, default='../demo/face_detector/deploy_fd.prototxt', help='') 183 | parser.add_argument('--fd_model', type=str, default='../demo/face_detector/sq_300x300_iter_120000.caffemodel', 184 | help='') 185 | parser.add_argument('--fr_thresh', type=float, default=0.1) 186 | parser.add_argument('--det_res', type=int, nargs=2, default=[300, 300], help='Detection net input resolution.') 187 | parser.add_argument('--landnet_model', type=str) 188 | parser.add_argument('--device', type=int, default=0) 189 | parser.add_argument('--visualize', action='store_true') 190 | args = parser.parse_args() 191 | 192 | face_detector = FaceDetector(args.fd_proto, args.fd_model, conf=args.fr_thresh, size=args.det_res) 193 | 194 | with torch.cuda.device(args.device): 195 | landmarks_regressor = utils.load_model_state(landnet.LandmarksNet(), args.landnet_model, args.device) 196 | data = run_dumping(create_images_list(args.images_root, args.images_list), face_detector, 197 | LandnetPT(landmarks_regressor), args.visualize) 198 | save_data(data, osp.join(args.images_root, 'list.txt'), args.images_root) 199 | 200 | if __name__ == '__main__': 201 | main() 202 | -------------------------------------------------------------------------------- /scripts/count_flops.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2018 Intel Corporation 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | http://www.apache.org/licenses/LICENSE-2.0 7 | Unless required by applicable law or agreed to in writing, software 8 | distributed under the License is distributed on an "AS IS" BASIS, 9 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 10 | See the License for the specific language governing permissions and 11 | limitations under the License. 
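For reference, each line written by save_data above follows the pattern below (the fields are placeholders, not real data); images without a detection are written with the relative path and a bare separator only:

<relative/image/path.jpg> | x1 y1 x2 y2 x3 y3 x4 y4 x5 y5 | left top width height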
12 | """ 13 | 14 | import argparse 15 | import torch 16 | 17 | from model.common import models_backbones, models_landmarks 18 | from ptflops import get_model_complexity_info 19 | 20 | 21 | def main(): 22 | """Runs flops counter""" 23 | parser = argparse.ArgumentParser(description='Evaluation script for Face Recognition in PyTorch') 24 | parser.add_argument('--embed_size', type=int, default=128, help='Size of the face embedding.') 25 | parser.add_argument('--model', choices=list(models_backbones.keys()) + list(models_landmarks.keys()), type=str, 26 | default='rmnet') 27 | args = parser.parse_args() 28 | 29 | with torch.no_grad(): 30 | if args.model in models_landmarks.keys(): 31 | model = models_landmarks[args.model]() 32 | else: 33 | model = models_backbones[args.model](embedding_size=args.embed_size, feature=True) 34 | 35 | flops, params = get_model_complexity_info(model, model.get_input_res(), 36 | as_strings=True, print_per_layer_stat=True) 37 | print('Flops: {}'.format(flops)) 38 | print('Params: {}'.format(params)) 39 | 40 | 41 | if __name__ == '__main__': 42 | main() 43 | -------------------------------------------------------------------------------- /scripts/matio.py: -------------------------------------------------------------------------------- 1 | # pylint: skip-file 2 | import struct 3 | import numpy as np 4 | 5 | cv_type_to_dtype = { 6 | 5 : np.dtype('float32'), 7 | 6 : np.dtype('float64') 8 | } 9 | 10 | dtype_to_cv_type = {v : k for k,v in cv_type_to_dtype.items()} 11 | 12 | def write_mat(f, m): 13 | """Write mat m to file f""" 14 | if len(m.shape) == 1: 15 | rows = m.shape[0] 16 | cols = 1 17 | else: 18 | rows, cols = m.shape 19 | header = struct.pack('iiii', rows, cols, cols * 4, dtype_to_cv_type[m.dtype]) 20 | f.write(header) 21 | f.write(m.data) 22 | 23 | 24 | def read_mat(f): 25 | """ 26 | Reads an OpenCV mat from the given file opened in binary mode 27 | """ 28 | rows, cols, stride, type_ = struct.unpack('iiii', f.read(4*4)) 29 | mat = np.fromstring(f.read(rows*stride),dtype=cv_type_to_dtype[type_]) 30 | return mat.reshape(rows,cols) 31 | 32 | def read_mkl_vec(f): 33 | """ 34 | Reads an OpenCV mat from the given file opened in binary mode 35 | """ 36 | # Read past the header information 37 | f.read(4*4) 38 | 39 | length, stride, type_ = struct.unpack('iii', f.read(3*4)) 40 | mat = np.fromstring(f.read(length*4),dtype=np.float32) 41 | return mat 42 | 43 | def load_mkl_vec(filename): 44 | """ 45 | Reads a OpenCV Mat from the given filename 46 | """ 47 | return read_mkl_vec(open(filename,'rb')) 48 | 49 | def load_mat(filename): 50 | """ 51 | Reads a OpenCV Mat from the given filename 52 | """ 53 | return read_mat(open(filename,'rb')) 54 | 55 | def save_mat(filename, m): 56 | """Saves mat m to the given filename""" 57 | return write_mat(open(filename,'wb'), m) 58 | 59 | def main(): 60 | f = open('1_to_0.bin','rb') 61 | vx = read_mat(f) 62 | vy = read_mat(f) 63 | 64 | if __name__ == '__main__': 65 | main() 66 | -------------------------------------------------------------------------------- /scripts/plot_roc_curves_lfw.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2018 Intel Corporation 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 
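A round-trip sketch for the OpenCV-style binary matrices handled by matio.py, assuming it is run from the repository root and using a placeholder file name:

import numpy as np
from scripts.matio import save_mat, load_mat

features = np.random.rand(4, 128).astype(np.float32)
save_mat('features.bin', features)    # 16-byte header (rows, cols, stride, cv dtype) + raw float32 data
restored = load_mat('features.bin')
assert np.allclose(features, restored)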
5 | You may obtain a copy of the License at 6 | http://www.apache.org/licenses/LICENSE-2.0 7 | Unless required by applicable law or agreed to in writing, software 8 | distributed under the License is distributed on an "AS IS" BASIS, 9 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 10 | See the License for the specific language governing permissions and 11 | limitations under the License. 12 | """ 13 | 14 | import argparse 15 | 16 | import matplotlib.pyplot as plt 17 | from evaluate_lfw import get_auc 18 | 19 | def main(): 20 | parser = argparse.ArgumentParser(description='') 21 | parser.add_argument('rocs', metavar='ROCs', type=str, nargs='+', 22 | help='paths to roc curves') 23 | 24 | args = parser.parse_args() 25 | 26 | plt.xlabel("False Positive Rate") 27 | plt.ylabel("True Positive Rate") 28 | plt.grid(b=True, which='major', color='k', linestyle='-') 29 | plt.grid(b=True, which='minor', color='k', linestyle='-', alpha=0.2) 30 | plt.minorticks_on() 31 | 32 | for curve_file in args.rocs: 33 | fprs = [] 34 | tprs = [] 35 | with open(curve_file, 'r') as f: 36 | for line in f.readlines(): 37 | values = line.strip().split() 38 | fprs.append(float(values[1])) 39 | tprs.append(float(values[0])) 40 | 41 | curve_name = curve_file.split('/')[-1].split('.')[0] 42 | plt.plot(fprs, tprs, label=curve_name) 43 | plt.legend(loc='best', fontsize=10) 44 | 45 | print('AUC for {}: {}'.format(curve_name, get_auc(fprs, tprs))) 46 | 47 | plt.show() 48 | 49 | if __name__ == '__main__': 50 | main() 51 | -------------------------------------------------------------------------------- /scripts/pytorch2onnx.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2018 Intel Corporation 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | http://www.apache.org/licenses/LICENSE-2.0 7 | Unless required by applicable law or agreed to in writing, software 8 | distributed under the License is distributed on an "AS IS" BASIS, 9 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 10 | See the License for the specific language governing permissions and 11 | limitations under the License. 
12 | """ 13 | 14 | import argparse 15 | import torch 16 | 17 | from utils.utils import load_model_state 18 | from model.common import models_backbones, models_landmarks 19 | 20 | def main(): 21 | parser = argparse.ArgumentParser(description='Conversion script for FR models from PyTorch to ONNX') 22 | parser.add_argument('--embed_size', type=int, default=128, help='Size of the face embedding.') 23 | parser.add_argument('--snap', type=str, required=True, help='Snapshot to convert.') 24 | parser.add_argument('--device', '-d', default=-1, type=int, help='Device for model placement.') 25 | parser.add_argument('--output_dir', default='./', type=str, help='Output directory.') 26 | parser.add_argument('--model', choices=list(models_backbones.keys()) + list(models_landmarks.keys()), 27 | type=str, default='rmnet') 28 | 29 | args = parser.parse_args() 30 | 31 | if args.model in models_landmarks.keys(): 32 | model = models_landmarks[args.model]() 33 | else: 34 | model = models_backbones[args.model](embedding_size=args.embed_size, feature=True) 35 | 36 | model = load_model_state(model, args.snap, args.device, eval_state=True) 37 | input_var = torch.rand(1, 3, *model.get_input_res()) 38 | dump_name = args.snap[args.snap.rfind('/') + 1:-3] 39 | 40 | torch.onnx.export(model, input_var, dump_name + '.onnx', verbose=True, export_params=True) 41 | 42 | if __name__ == '__main__': 43 | main() 44 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/grib0ed0v/face_recognition.pytorch/05cb9b30e8220445fcb27988926d88f330091c12/tests/__init__.py -------------------------------------------------------------------------------- /tests/test_alignment.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2018 Intel Corporation 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | http://www.apache.org/licenses/LICENSE-2.0 7 | Unless required by applicable law or agreed to in writing, software 8 | distributed under the License is distributed on an "AS IS" BASIS, 9 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 10 | See the License for the specific language governing permissions and 11 | limitations under the License. 12 | """ 13 | 14 | import unittest 15 | import cv2 as cv 16 | import numpy as np 17 | 18 | from utils.face_align import FivePointsAligner 19 | from utils.landmarks_augmentation import RandomRotate 20 | 21 | 22 | class FaceAlignmentTests(unittest.TestCase): 23 | """Tests for alignment methods""" 24 | def test_align_image(self): 25 | """Synthetic test for alignment function""" 26 | image = np.zeros((128, 128, 3), dtype=np.float32) 27 | for point in FivePointsAligner.ref_landmarks: 28 | point_scaled = point * [128, 128] 29 | cv.circle(image, tuple(point_scaled.astype(np.int)), 5, (255, 255, 255), cv.FILLED) 30 | 31 | transform = RandomRotate(40., p=1.) 
32 | rotated_data = transform({'img': image, 'landmarks': FivePointsAligner.ref_landmarks}) 33 | aligned_image = FivePointsAligner.align(rotated_data['img'], \ 34 | rotated_data['landmarks'].reshape(-1), 35 | d_size=(128, 128), normalized=True) 36 | 37 | for point in FivePointsAligner.ref_landmarks: 38 | point_scaled = (point * [128, 128]).astype(np.int) 39 | check_sum = np.mean(aligned_image[point_scaled[1] - 3 : point_scaled[1] + 3, 40 | point_scaled[0] - 3 : point_scaled[0] + 3]) 41 | self.assertGreaterEqual(check_sum, 220) 42 | 43 | if __name__ == '__main__': 44 | unittest.main() 45 | -------------------------------------------------------------------------------- /tests/test_models.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2018 Intel Corporation 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | http://www.apache.org/licenses/LICENSE-2.0 7 | Unless required by applicable law or agreed to in writing, software 8 | distributed under the License is distributed on an "AS IS" BASIS, 9 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 10 | See the License for the specific language governing permissions and 11 | limitations under the License. 12 | """ 13 | 14 | import unittest 15 | import os 16 | import torch 17 | 18 | from model.common import models_backbones, models_landmarks 19 | from utils.utils import save_model_cpu, load_model_state 20 | 21 | 22 | class BackbonesTests(unittest.TestCase): 23 | """Tests for backbones""" 24 | def test_output_shape(self): 25 | """Checks output shape""" 26 | embed_size = 256 27 | for model_type in models_backbones.values(): 28 | model = model_type(embedding_size=embed_size, feature=True).eval() 29 | batch = torch.Tensor(1, 3, *model.get_input_res()).uniform_() 30 | output = model(batch) 31 | self.assertEqual(list(output.shape), list((1, embed_size, 1, 1))) 32 | 33 | def test_save_load_snap(self): 34 | """Checks an ability to save and load model correctly""" 35 | embed_size = 256 36 | snap_name = os.path.join(os.getcwd(), 'test_snap.pt') 37 | for model_type in models_backbones.values(): 38 | model = model_type(embedding_size=embed_size, feature=True).eval() 39 | batch = torch.Tensor(1, 3, *model.get_input_res()).uniform_() 40 | output = model(batch) 41 | save_model_cpu(model, None, snap_name, 0, write_solverstate=False) 42 | 43 | model_loaded = model_type(embedding_size=embed_size, feature=True) 44 | load_model_state(model_loaded, snap_name, -1, eval_state=True) 45 | 46 | output_loaded = model_loaded(batch) 47 | 48 | self.assertEqual(torch.norm(output - output_loaded), 0) 49 | 50 | 51 | class LandnetTests(unittest.TestCase): 52 | """Tests for landmark regressor""" 53 | def test_output_shape(self): 54 | """Checks output shape""" 55 | model = models_landmarks['landnet']().eval() 56 | batch = torch.Tensor(1, 3, *model.get_input_res()) 57 | output = model(batch) 58 | self.assertEqual(list(output.shape), list((1, 10, 1, 1))) 59 | 60 | def test_save_load_snap(self): 61 | """Checks an ability to save and load model correctly""" 62 | snap_name = os.path.join(os.getcwd(), 'test_snap.pt') 63 | model = models_landmarks['landnet']().eval() 64 | batch = torch.Tensor(1, 3, *model.get_input_res()).uniform_() 65 | output = model(batch) 66 | save_model_cpu(model, None, snap_name, 0, write_solverstate=False) 67 | 68 | model_loaded = 
models_landmarks['landnet']() 69 | load_model_state(model_loaded, snap_name, -1, eval_state=True) 70 | 71 | output_loaded = model_loaded(batch) 72 | 73 | self.assertEqual(torch.norm(output - output_loaded), 0) 74 | 75 | if __name__ == '__main__': 76 | unittest.main() 77 | -------------------------------------------------------------------------------- /tests/test_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2018 Intel Corporation 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | http://www.apache.org/licenses/LICENSE-2.0 7 | Unless required by applicable law or agreed to in writing, software 8 | distributed under the License is distributed on an "AS IS" BASIS, 9 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 10 | See the License for the specific language governing permissions and 11 | limitations under the License. 12 | """ 13 | 14 | import unittest 15 | 16 | import torch 17 | from utils.utils import get_model_parameters_number 18 | 19 | 20 | class UtilsTests(unittest.TestCase): 21 | """Tests for utils""" 22 | def test_parameters_counter(self): 23 | """Checks output of get_model_parameters_number""" 24 | class ParamsHolder(torch.nn.Module): 25 | """Dummy parameters holder""" 26 | def __init__(self, n_params): 27 | super(ParamsHolder, self).__init__() 28 | self.p1 = torch.nn.Parameter(torch.Tensor(n_params // 2)) 29 | self.p2 = torch.nn.Parameter(torch.Tensor(n_params // 2)) 30 | self.dummy = -1 31 | 32 | params_num = 1000 33 | module = ParamsHolder(params_num) 34 | estimated_params = get_model_parameters_number(module, as_string=False) 35 | self.assertEqual(estimated_params, params_num) 36 | 37 | if __name__ == '__main__': 38 | unittest.main() 39 | -------------------------------------------------------------------------------- /train_landmarks.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2018 Intel Corporation 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | http://www.apache.org/licenses/LICENSE-2.0 7 | Unless required by applicable law or agreed to in writing, software 8 | distributed under the License is distributed on an "AS IS" BASIS, 9 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 10 | See the License for the specific language governing permissions and 11 | limitations under the License. 
12 | """ 13 | 14 | import argparse 15 | import datetime 16 | import os.path as osp 17 | 18 | import numpy as np 19 | import glog as log 20 | from tensorboardX import SummaryWriter 21 | import torch 22 | import torch.backends.cudnn as cudnn 23 | import torch.optim as optim 24 | from torch.utils.data import DataLoader 25 | from torchvision.transforms import transforms 26 | 27 | from datasets import VGGFace2, CelebA, NDG 28 | 29 | from model.common import models_landmarks 30 | from utils import landmarks_augmentation 31 | from utils.utils import save_model_cpu, load_model_state 32 | from losses.alignment import AlignmentLoss 33 | from evaluate_landmarks import evaluate 34 | 35 | 36 | def train(args): 37 | """Launches training of landmark regression model""" 38 | if args.dataset == 'vgg': 39 | drops_schedule = [1, 6, 9, 13] 40 | dataset = VGGFace2(args.train, args.t_list, args.t_land, landmarks_training=True) 41 | elif args.dataset == 'celeba': 42 | drops_schedule = [10, 20] 43 | dataset = CelebA(args.train, args.t_land) 44 | else: 45 | drops_schedule = [90, 140, 200] 46 | dataset = NDG(args.train, args.t_land) 47 | 48 | if dataset.have_landmarks: 49 | log.info('Use alignment for the train data') 50 | dataset.transform = transforms.Compose([landmarks_augmentation.Rescale((56, 56)), 51 | landmarks_augmentation.Blur(k=3, p=.2), 52 | landmarks_augmentation.HorizontalFlip(p=.5), 53 | landmarks_augmentation.RandomRotate(50), 54 | landmarks_augmentation.RandomScale(.8, .9, p=.4), 55 | landmarks_augmentation.RandomCrop(48), 56 | landmarks_augmentation.ToTensor(switch_rb=True)]) 57 | else: 58 | log.info('Error: training dataset has no landmarks data') 59 | exit() 60 | 61 | train_loader = DataLoader(dataset, batch_size=args.train_batch_size, num_workers=4, shuffle=True) 62 | writer = SummaryWriter('./logs_landm/{:%Y_%m_%d_%H_%M}_'.format(datetime.datetime.now()) + args.snap_prefix) 63 | model = models_landmarks['landnet']() 64 | 65 | if args.snap_to_resume is not None: 66 | log.info('Resuming snapshot ' + args.snap_to_resume + ' ...') 67 | model = load_model_state(model, args.snap_to_resume, args.device, eval_state=False) 68 | model = torch.nn.DataParallel(model, device_ids=[args.device]) 69 | else: 70 | model = torch.nn.DataParallel(model, device_ids=[args.device]) 71 | model.cuda() 72 | model.train() 73 | cudnn.enabled = True 74 | cudnn.benchmark = True 75 | 76 | log.info('Face landmarks model:') 77 | log.info(model) 78 | 79 | criterion = AlignmentLoss('wing') 80 | optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay) 81 | scheduler = optim.lr_scheduler.MultiStepLR(optimizer, drops_schedule) 82 | for epoch_num in range(args.epoch_total_num): 83 | scheduler.step() 84 | if epoch_num > 5: 85 | model.module.set_dropout_ratio(0.)
86 | for i, data in enumerate(train_loader, 0): 87 | iteration = epoch_num * len(train_loader) + i 88 | 89 | data, gt_landmarks = data['img'].cuda(), data['landmarks'].cuda() 90 | predicted_landmarks = model(data) 91 | 92 | optimizer.zero_grad() 93 | loss = criterion(predicted_landmarks, gt_landmarks) 94 | loss.backward() 95 | optimizer.step() 96 | 97 | if i % 10 == 0: 98 | log.info('Iteration %d, Loss: %.4f' % (iteration, loss)) 99 | log.info('Learning rate: %f' % scheduler.get_lr()[0]) 100 | writer.add_scalar('Loss/train_loss', loss.item(), iteration) 101 | writer.add_scalar('Learning_rate', scheduler.get_lr()[0], iteration) 102 | 103 | if iteration % args.val_step == 0: 104 | snapshot_name = osp.join(args.snap_folder, args.snap_prefix + '_{0}.pt'.format(iteration)) 105 | log.info('Saving Snapshot: ' + snapshot_name) 106 | save_model_cpu(model, optimizer, snapshot_name, epoch_num) 107 | 108 | model.eval() 109 | log.info('Evaluating Snapshot: ' + snapshot_name) 110 | avg_err, per_point_avg_err, failures_rate = evaluate(train_loader, model) 111 | weights = per_point_avg_err / np.sum(per_point_avg_err) 112 | criterion.set_weights(weights) 113 | log.info(str(weights)) 114 | log.info('Avg train error: {}'.format(avg_err)) 115 | log.info('Train failure rate: {}'.format(failures_rate)) 116 | writer.add_scalar('Quality/Avg_error', avg_err, iteration) 117 | writer.add_scalar('Quality/Failure_rate', failures_rate, iteration) 118 | model.train() 119 | 120 | def main(): 121 | """Creates a command line parser""" 122 | parser = argparse.ArgumentParser(description='Training Landmarks detector in PyTorch') 123 | parser.add_argument('--train_data_root', dest='train', required=True, type=str, help='Path to train data.') 124 | parser.add_argument('--train_list', dest='t_list', required=False, type=str, help='Path to train data image list.') 125 | parser.add_argument('--train_landmarks', default='', dest='t_land', required=False, type=str, 126 | help='Path to landmarks for the train images.') 127 | parser.add_argument('--train_batch_size', type=int, default=170, help='Train batch size.') 128 | parser.add_argument('--epoch_total_num', type=int, default=30, help='Number of epochs to train.') 129 | parser.add_argument('--lr', type=float, default=0.4, help='Learning rate.') 130 | parser.add_argument('--momentum', type=float, default=0.9, help='Momentum.') 131 | parser.add_argument('--val_step', type=int, default=2000, help='Evaluate model each val_step during each epoch.') 132 | parser.add_argument('--weight_decay', type=float, default=0.0001, help='Weight decay.') 133 | parser.add_argument('--device', '-d', default=0, type=int) 134 | parser.add_argument('--snap_folder', type=str, default='./snapshots/', help='Folder to save snapshots.') 135 | parser.add_argument('--snap_prefix', type=str, default='LandmarksNet', help='Prefix for snapshots.') 136 | parser.add_argument('--snap_to_resume', type=str, default=None, help='Snapshot to resume.') 137 | parser.add_argument('--dataset', choices=['vgg', 'celeba', 'ndg'], type=str, default='vgg', help='Dataset.') 138 | arguments = parser.parse_args() 139 | 140 | with torch.cuda.device(arguments.device): 141 | train(arguments) 142 | 143 | if __name__ == '__main__': 144 | main() 145 | -------------------------------------------------------------------------------- /utils/__init__.py: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/grib0ed0v/face_recognition.pytorch/05cb9b30e8220445fcb27988926d88f330091c12/utils/__init__.py -------------------------------------------------------------------------------- /utils/augmentation.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2018 Intel Corporation 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | http://www.apache.org/licenses/LICENSE-2.0 7 | Unless required by applicable law or agreed to in writing, software 8 | distributed under the License is distributed on an "AS IS" BASIS, 9 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 10 | See the License for the specific language governing permissions and 11 | limitations under the License. 12 | """ 13 | 14 | import math 15 | import torch 16 | import numpy as np 17 | import cv2 as cv 18 | 19 | try: 20 | from .face_align import FivePointsAligner 21 | except (ImportError, SystemError) as exp: 22 | from face_align import FivePointsAligner 23 | 24 | 25 | class HorizontalFlipNumpy: 26 | """Horizontal flip augmentation with probability p""" 27 | def __init__(self, p=.5): 28 | assert 0 <= p <= 1. 29 | self.p = p 30 | 31 | def __call__(self, img): 32 | if float(torch.FloatTensor(1).uniform_()) < self.p: 33 | return cv.flip(img, 1) 34 | return img 35 | 36 | 37 | class ShowTransform: 38 | """Show image using opencv""" 39 | def __call__(self, sample): 40 | img = np.array(sample) 41 | cv.imshow('image', img) 42 | cv.waitKey() 43 | return sample 44 | 45 | 46 | class NumpyToTensor: 47 | """Converts a numpy array to torch.Tensor with optionally swapping R and B channels""" 48 | def __init__(self, switch_rb=False): 49 | self.switch_rb = switch_rb 50 | 51 | def __call__(self, image): 52 | # swap color axis because 53 | # numpy image: H x W x C 54 | # torch image: C X H X W 55 | if self.switch_rb: 56 | image = cv.cvtColor(image, cv.COLOR_RGB2BGR) 57 | image = image.transpose((2, 0, 1)) 58 | return torch.from_numpy(image).type(torch.FloatTensor) / 255. 59 | 60 | 61 | class RandomShiftNumpy: 62 | """Shifts an image by a randomly generated offset along x and y axes""" 63 | def __init__(self, max_rel_shift, p=.5): 64 | self.p = p 65 | self.max_rel_shift = max_rel_shift 66 | 67 | def __call__(self, image): 68 | if float(torch.FloatTensor(1).uniform_()) < self.p: 69 | rel_shift = 2 * (torch.FloatTensor(1).uniform_() - .5) * self.max_rel_shift 70 | h, w = image.shape[:2] 71 | shift_w = w * rel_shift 72 | shift_h = h * rel_shift 73 | transl_mat = np.array([[1., 0., shift_w], [0., 1., shift_h]]) 74 | image = cv.warpAffine(image, transl_mat, (w, h)) 75 | 76 | return image 77 | 78 | 79 | class RandomRotationNumpy: 80 | """Rotates an image around it's center by a randomly generated angle""" 81 | def __init__(self, max_angle, p=.5): 82 | self.max_angle = max_angle 83 | self.p = p 84 | 85 | def __call__(self, image): 86 | if float(torch.FloatTensor(1).uniform_()) < self.p: 87 | angle = 2 * (torch.FloatTensor(1).uniform_() - .5) * self.max_angle 88 | h, w = image.shape[:2] 89 | rot_mat = cv.getRotationMatrix2D((w * 0.5, h * 0.5), angle, 1.) 
90 | image = cv.warpAffine(image, rot_mat, (w, h), flags=cv.INTER_LANCZOS4) 91 | 92 | return image 93 | 94 | 95 | class ResizeNumpy: 96 | """Resizes an image in numpy format""" 97 | def __init__(self, output_size): 98 | assert isinstance(output_size, (int, tuple)) 99 | self.output_size = output_size 100 | 101 | def __call__(self, image): 102 | h, w = image.shape[:2] 103 | if isinstance(self.output_size, int): 104 | if h > w: 105 | new_h, new_w = self.output_size * h / w, self.output_size 106 | else: 107 | new_h, new_w = self.output_size, self.output_size * w / h 108 | else: 109 | new_h, new_w = self.output_size 110 | 111 | new_h, new_w = int(new_h), int(new_w) 112 | img = cv.resize(image, (new_h, new_w)) 113 | return img 114 | 115 | 116 | class CenterCropNumpy: 117 | """Performs a center crop of an images""" 118 | def __init__(self, output_size): 119 | assert isinstance(output_size, (int, tuple)) 120 | self.output_size = output_size 121 | 122 | def __call__(self, image): 123 | h, w = image.shape[:2] 124 | if isinstance(self.output_size, int): 125 | new_h, new_w = self.output_size, self.output_size 126 | else: 127 | new_h, new_w = self.output_size 128 | 129 | s_h = int(h / 2 - new_h / 2) 130 | s_w = int(w / 2 - new_w / 2) 131 | image = image[s_h: s_h + new_h, s_w: s_w + new_w] 132 | return image 133 | 134 | 135 | class BlurNumpy: 136 | """Blurs an image with the given sigma and probability""" 137 | def __init__(self, p, k): 138 | self.p = p 139 | assert k % 2 == 1 140 | self.k = k 141 | 142 | def __call__(self, img): 143 | if float(torch.FloatTensor(1).uniform_()) < self.p: 144 | img = cv.blur(img, (self.k, self.k)) 145 | return img 146 | 147 | 148 | class CutOutWithPrior: 149 | """Cuts rectangular patches from an image around pre-defined landmark locations""" 150 | def __init__(self, p, max_area): 151 | self.p = p 152 | self.max_area = max_area 153 | 154 | # use after resize transform 155 | def __call__(self, img): 156 | height, width = img.shape[:2] 157 | keypoints_ref = np.zeros((5, 2), dtype=np.float32) 158 | keypoints_ref[:, 0] = FivePointsAligner.ref_landmarks[:, 0] * width 159 | keypoints_ref[:, 1] = FivePointsAligner.ref_landmarks[:, 1] * height 160 | 161 | if float(torch.FloatTensor(1).uniform_()) < self.p: 162 | erase_num = torch.LongTensor(1).random_(1, 4) 163 | erase_ratio = torch.FloatTensor(1).uniform_(self.max_area / 2, self.max_area) 164 | erase_h = math.sqrt(erase_ratio) / float(erase_num) * height 165 | erase_w = math.sqrt(erase_ratio) / float(erase_num) * width 166 | 167 | erased_idx = [] 168 | for _ in range(erase_num): 169 | erase_pos = int(torch.LongTensor(1).random_(0, 5)) 170 | while erase_pos in erased_idx: 171 | erase_pos = int(torch.LongTensor(1).random_(0, 5)) 172 | 173 | left_corner = ( 174 | int(keypoints_ref[erase_pos][0] - erase_h / 2), int(keypoints_ref[erase_pos][1] - erase_w / 2)) 175 | right_corner = ( 176 | int(keypoints_ref[erase_pos][0] + erase_h / 2), int(keypoints_ref[erase_pos][1] + erase_w / 2)) 177 | 178 | cv.rectangle(img, tuple(left_corner), tuple(right_corner), (0, 0, 0), thickness=-1) 179 | erased_idx.append(erase_pos) 180 | 181 | return img 182 | -------------------------------------------------------------------------------- /utils/face_align.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2018 Intel Corporation 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 
5 | You may obtain a copy of the License at 6 | http://www.apache.org/licenses/LICENSE-2.0 7 | Unless required by applicable law or agreed to in writing, software 8 | distributed under the License is distributed on an "AS IS" BASIS, 9 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 10 | See the License for the specific language governing permissions and 11 | limitations under the License. 12 | """ 13 | 14 | import cv2 as cv 15 | import numpy as np 16 | 17 | 18 | class FivePointsAligner(): 19 | """This class performs face alignment by five reference points""" 20 | ref_landmarks = np.array([30.2946 / 96, 51.6963 / 112, 21 | 65.5318 / 96, 51.5014 / 112, 22 | 48.0252 / 96, 71.7366 / 112, 23 | 33.5493 / 96, 92.3655 / 112, 24 | 62.7299 / 96, 92.2041 / 112], dtype=np.float64).reshape(5, 2) 25 | @staticmethod 26 | def align(img, landmarks, d_size=(400, 400), normalized=False, show=False): 27 | """Transforms given image in such a way that landmarks are located near ref_landmarks after transformation""" 28 | assert len(landmarks) == 10 29 | assert isinstance(img, np.ndarray) 30 | landmarks = np.array(landmarks).reshape(5, 2) 31 | dw, dh = d_size 32 | 33 | keypoints = landmarks.copy().astype(np.float64) 34 | if normalized: 35 | keypoints[:, 0] *= img.shape[1] 36 | keypoints[:, 1] *= img.shape[0] 37 | 38 | keypoints_ref = np.zeros((5, 2), dtype=np.float64) 39 | keypoints_ref[:, 0] = FivePointsAligner.ref_landmarks[:, 0] * dw 40 | keypoints_ref[:, 1] = FivePointsAligner.ref_landmarks[:, 1] * dh 41 | 42 | transform_matrix = transformation_from_points(keypoints_ref, keypoints) 43 | output_im = cv.warpAffine(img, transform_matrix, d_size, flags=cv.WARP_INVERSE_MAP) 44 | 45 | if show: 46 | tmp_output = output_im.copy() 47 | for point in keypoints_ref: 48 | cv.circle(tmp_output, (int(point[0]), int(point[1])), 5, (255, 0, 0), -1) 49 | for point in keypoints: 50 | cv.circle(img, (int(point[0]), int(point[1])), 5, (255, 0, 0), -1) 51 | img = cv.resize(img, d_size) 52 | cv.imshow('source/warped', np.hstack((img, tmp_output))) 53 | cv.waitKey() 54 | 55 | return output_im 56 | 57 | 58 | def transformation_from_points(points1, points2): 59 | """Builds an affine transformation matrix from points1 to points2""" 60 | points1 = points1.astype(np.float64) 61 | points2 = points2.astype(np.float64) 62 | 63 | c1 = np.mean(points1, axis=0) 64 | c2 = np.mean(points2, axis=0) 65 | points1 -= c1 66 | points2 -= c2 67 | 68 | s1 = np.std(points1) 69 | s2 = np.std(points2) 70 | points1 /= s1 71 | points2 /= s2 72 | 73 | u, _, vt = np.linalg.svd(np.matmul(points1.T, points2)) 74 | r = np.matmul(u, vt).T 75 | 76 | return np.hstack(((s2 / s1) * r, (c2.T - (s2 / s1) * np.matmul(r, c1.T)).reshape(2, -1))) 77 | -------------------------------------------------------------------------------- /utils/ie_tools.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2018 Intel Corporation 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | http://www.apache.org/licenses/LICENSE-2.0 7 | Unless required by applicable law or agreed to in writing, software 8 | distributed under the License is distributed on an "AS IS" BASIS, 9 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 10 | See the License for the specific language governing permissions and 11 | limitations under the License.
12 | """ 13 | 14 | import sys 15 | import os 16 | 17 | import glog as log 18 | import numpy as np 19 | from openvino.inference_engine import IENetwork, IEPlugin # pylint: disable=import-error,E0611 20 | 21 | class IEModel: 22 | """Class for inference of models in the Inference Engine format""" 23 | def __init__(self, exec_net, inputs_info, input_key, output_key): 24 | self.net = exec_net 25 | self.inputs_info = inputs_info 26 | self.input_key = input_key 27 | self.output_key = output_key 28 | 29 | def forward(self, img): 30 | """Performs forward pass of the wrapped IE model""" 31 | res = self.net.infer(inputs={self.input_key: np.expand_dims(img.transpose(2, 0, 1), axis=0)}) 32 | return np.copy(res[self.output_key]) 33 | 34 | def get_input_shape(self): 35 | """Returns an input shape of the wrapped IE model""" 36 | return self.inputs_info[self.input_key] 37 | 38 | 39 | def load_ie_model(model_xml, device, plugin_dir, cpu_extension=''): 40 | """Loads a model in the Inference Engine format""" 41 | model_bin = os.path.splitext(model_xml)[0] + ".bin" 42 | # Plugin initialization for specified device and load extensions library if specified 43 | plugin = IEPlugin(device=device, plugin_dirs=plugin_dir) 44 | if cpu_extension and 'CPU' in device: 45 | plugin.add_cpu_extension(cpu_extension) 46 | # Read IR 47 | log.info("Loading network files:\n\t%s\n\t%s", model_xml, model_bin) 48 | net = IENetwork(model=model_xml, weights=model_bin) 49 | 50 | if "CPU" in plugin.device: 51 | supported_layers = plugin.get_supported_layers(net) 52 | not_supported_layers = [l for l in net.layers.keys() if l not in supported_layers] 53 | if not_supported_layers: 54 | log.error("Following layers are not supported by the plugin for specified device %s:\n %s", 55 | plugin.device, ', '.join(not_supported_layers)) 56 | log.error("Please try to specify cpu extensions library path in sample's command line parameters using -l " 57 | "or --cpu_extension command line argument") 58 | sys.exit(1) 59 | 60 | assert len(net.inputs.keys()) == 1, "Checker supports only single input topologies" 61 | assert len(net.outputs) == 1, "Checker supports only single output topologies" 62 | 63 | log.info("Preparing input blobs") 64 | input_blob = next(iter(net.inputs)) 65 | out_blob = next(iter(net.outputs)) 66 | net.batch_size = 1 67 | 68 | # Loading model to the plugin 69 | log.info("Loading model to the plugin") 70 | exec_net = plugin.load(network=net) 71 | model = IEModel(exec_net, net.inputs, input_blob, out_blob) 72 | del net 73 | return model 74 | -------------------------------------------------------------------------------- /utils/landmarks_augmentation.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2018 Intel Corporation 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | http://www.apache.org/licenses/LICENSE-2.0 7 | Unless required by applicable law or agreed to in writing, software 8 | distributed under the License is distributed on an "AS IS" BASIS, 9 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 10 | See the License for the specific language governing permissions and 11 | limitations under the License. 
12 | """ 13 | 14 | import cv2 as cv 15 | import numpy as np 16 | import torch 17 | 18 | 19 | class Rescale: 20 | """Resizes an image and corresponding landmarks""" 21 | def __init__(self, output_size): 22 | assert isinstance(output_size, (int, tuple)) 23 | self.output_size = output_size 24 | 25 | def __call__(self, sample): 26 | image, landmarks = sample['img'], sample['landmarks'] 27 | 28 | h, w = image.shape[:2] 29 | if isinstance(self.output_size, int): 30 | if w > h: 31 | new_h, new_w = self.output_size, self.output_size * w / h 32 | else: 33 | new_h, new_w = self.output_size * h / w, self.output_size 34 | else: 35 | new_h, new_w = self.output_size 36 | new_h, new_w = int(new_h), int(new_w) 37 | img = cv.resize(image, (new_h, new_w)) 38 | return {'img': img, 'landmarks': landmarks} 39 | 40 | 41 | class RandomCrop: 42 | """Makes a random crop from the source image with corresponding transformation of landmarks""" 43 | def __init__(self, output_size): 44 | assert isinstance(output_size, (int, tuple)) 45 | if isinstance(output_size, int): 46 | self.output_size = (output_size, output_size) 47 | else: 48 | assert len(output_size) == 2 49 | self.output_size = output_size 50 | 51 | def __call__(self, sample): 52 | image, landmarks = sample['img'], sample['landmarks'].reshape(-1, 2) 53 | 54 | h, w = image.shape[:2] 55 | new_h, new_w = self.output_size 56 | 57 | top = np.random.randint(0, h - new_h) 58 | left = np.random.randint(0, w - new_w) 59 | 60 | image = image[top: top + new_h, 61 | left: left + new_w] 62 | 63 | landmarks = landmarks - [left / float(w), top / float(h)] 64 | for point in landmarks: 65 | point[0] *= float(h) / new_h 66 | point[1] *= float(w) / new_w 67 | 68 | return {'img': image, 'landmarks': landmarks} 69 | 70 | 71 | class HorizontalFlip: 72 | """Flips an input image and landmarks horizontally with a given probability""" 73 | def __init__(self, p=.5): 74 | self.p = p 75 | 76 | def __call__(self, sample): 77 | image, landmarks = sample['img'], sample['landmarks'].reshape(-1, 2) 78 | 79 | if float(torch.FloatTensor(1).uniform_()) < self.p: 80 | image = cv.flip(image, 1) 81 | landmarks = landmarks.reshape(5, 2) 82 | landmarks[:, 0] = 1. - landmarks[:, 0] 83 | tmp = np.copy(landmarks[0]) 84 | landmarks[0] = landmarks[1] 85 | landmarks[1] = tmp 86 | 87 | tmp = np.copy(landmarks[3]) 88 | landmarks[3] = landmarks[4] 89 | landmarks[4] = tmp 90 | 91 | return {'img': image, 'landmarks': landmarks} 92 | 93 | 94 | class Blur: 95 | """Blurs an image with the given sigma and probability""" 96 | def __init__(self, p, k): 97 | self.p = p 98 | assert k % 2 == 1 99 | self.k = k 100 | 101 | def __call__(self, sample): 102 | image, landmarks = sample['img'], sample['landmarks'] 103 | 104 | if float(torch.FloatTensor(1).uniform_()) < self.p: 105 | image = cv.blur(image, (self.k, self.k)) 106 | 107 | return {'img': image, 'landmarks': landmarks} 108 | 109 | 110 | class Show: 111 | """Show image using opencv""" 112 | def __call__(self, sample): 113 | image, landmarks = sample['img'].copy(), sample['landmarks'].reshape(-1, 2) 114 | h, w = image.shape[:2] 115 | for point in landmarks: 116 | cv.circle(image, (int(point[0]*w), int(point[1]*h)), 3, (255, 0, 0), -1) 117 | cv.imshow('image', image) 118 | cv.waitKey() 119 | return sample 120 | 121 | 122 | class RandomRotate: 123 | """ 124 | Rotates an image around it's center by a randomly generated angle. 125 | Also performs the same transformation with landmark points. 
126 | """ 127 | def __init__(self, max_angle, p=.5): 128 | self.max_angle = max_angle 129 | self.p = p 130 | 131 | def __call__(self, sample): 132 | image, landmarks = sample['img'], sample['landmarks'] 133 | 134 | if float(torch.FloatTensor(1).uniform_()) < self.p: 135 | angle = 2*(torch.FloatTensor(1).uniform_() - .5)*self.max_angle 136 | h, w = image.shape[:2] 137 | rot_mat = cv.getRotationMatrix2D((w*0.5, h*0.5), angle, 1.) 138 | image = cv.warpAffine(image, rot_mat, (w, h), flags=cv.INTER_LANCZOS4) 139 | rot_mat_l = cv.getRotationMatrix2D((0.5, 0.5), angle, 1.) 140 | landmarks = cv.transform(landmarks.reshape(1, 5, 2), rot_mat_l).reshape(5, 2) 141 | 142 | return {'img': image, 'landmarks': landmarks} 143 | 144 | 145 | class ToTensor: 146 | """Convert ndarrays in sample to Tensors.""" 147 | def __init__(self, switch_rb=False): 148 | self.switch_rb = switch_rb 149 | 150 | def __call__(self, sample): 151 | image, landmarks = sample['img'], sample['landmarks'] 152 | # swap color axis because 153 | # numpy image: H x W x C 154 | # torch image: C X H X W 155 | if self.switch_rb: 156 | image = cv.cvtColor(image, cv.COLOR_RGB2BGR) 157 | image = image.transpose((2, 0, 1)) 158 | return {'img': torch.from_numpy(image).type(torch.FloatTensor) / 255, 159 | 'landmarks': torch.from_numpy(landmarks).type(torch.FloatTensor).view(-1, 1, 1)} 160 | 161 | 162 | class RandomScale: 163 | """Performs uniform scale with a random magnitude""" 164 | def __init__(self, max_scale, min_scale, p=.5): 165 | self.max_scale = max_scale 166 | self.min_scale = min_scale 167 | self.p = p 168 | 169 | def __call__(self, sample): 170 | image, landmarks = sample['img'], sample['landmarks'] 171 | 172 | if float(torch.FloatTensor(1).uniform_()) < self.p: 173 | scale = self.min_scale + torch.FloatTensor(1).uniform_()*(self.max_scale - self.min_scale) 174 | h, w = image.shape[:2] 175 | rot_mat = cv.getRotationMatrix2D((w*0.5, h*0.5), 0, scale) 176 | image = cv.warpAffine(image, rot_mat, (w, h), flags=cv.INTER_LANCZOS4) 177 | rot_mat_l = cv.getRotationMatrix2D((0.5, 0.5), 0, scale) 178 | landmarks = cv.transform(landmarks.reshape(1, 5, 2), rot_mat_l).reshape(5, 2) 179 | 180 | return {'img': image, 'landmarks': landmarks} 181 | -------------------------------------------------------------------------------- /utils/parser_yaml.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2018 Intel Corporation 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | http://www.apache.org/licenses/LICENSE-2.0 7 | Unless required by applicable law or agreed to in writing, software 8 | distributed under the License is distributed on an "AS IS" BASIS, 9 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 10 | See the License for the specific language governing permissions and 11 | limitations under the License. 
12 | """ 13 | 14 | from argparse import ArgumentParser 15 | import yaml 16 | 17 | class ArgumentParserWithYaml(ArgumentParser): 18 | """ 19 | Attention: this will work with simple yaml files only, and only if there is no action=store_false 20 | """ 21 | @staticmethod 22 | def _check_arg_line_repr_None(arg_line, k, v): 23 | """ The method is required, since by default python prints the None value as None, whereas yaml expects null """ 24 | s = arg_line.strip() 25 | prefixes = [k, "'" + k + "'", '"' + k + '"'] 26 | is_ok = False 27 | for prefix in prefixes: 28 | if s.startswith(prefix): 29 | s = s[len(prefix):] 30 | is_ok = True 31 | break 32 | if not is_ok: 33 | raise RuntimeError("Unknown prefix in line '{}', k = '{}', v = '{}'".format(arg_line, k, v)) 34 | s = s.strip() 35 | assert s.startswith(':'), "Bad format of line '{}', k = '{}', v = '{}'".format(arg_line, k, v) 36 | s = s[1:] 37 | s = s.strip() 38 | #print("arg line '{}' repr None = {}, s = '{}'".format(arg_line, s == "None", s)) 39 | 40 | return s == "None" #note that 'None' will be a string, whereas just None will be None 41 | 42 | def convert_arg_line_to_args(self, arg_line): 43 | arg_line = arg_line.strip() 44 | if not arg_line: 45 | return [] 46 | if arg_line.endswith(','): 47 | arg_line = arg_line[:-1] 48 | 49 | data = yaml.safe_load(arg_line) 50 | if data is None: 51 | return [] 52 | assert type(data) is dict 53 | assert len(data) == 1 54 | 55 | res = [] 56 | for k, v in data.items(): 57 | if v == 'None': # default value is None -- skipping 58 | if self._check_arg_line_repr_None(arg_line, k, v): #additional check that somebody passed string "None" 59 | continue 60 | else: 61 | print("WARNING: DURING PARSING ARGUMENTS FILE: possible error in the argument line '{}' -- probably None value is missed".format(arg_line)) 62 | 63 | if type(v) is list: 64 | res.append('--' + str(k)) 65 | [res.append(str(item)) for item in v] 66 | continue 67 | 68 | if type(v) is bool: # special case, action=store_true, do not use store_false! 69 | if v: 70 | res.append('--' + str(k)) 71 | continue 72 | 73 | # attention, there may be small issue with converting float -> string -> float -> string 74 | res.extend(['--' + str(k), str(v)]) 75 | 76 | return res 77 | -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2018 Intel Corporation 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | http://www.apache.org/licenses/LICENSE-2.0 7 | Unless required by applicable law or agreed to in writing, software 8 | distributed under the License is distributed on an "AS IS" BASIS, 9 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 10 | See the License for the specific language governing permissions and 11 | limitations under the License.
12 | """ 13 | 14 | from collections import OrderedDict 15 | 16 | import torch 17 | import torch.backends.cudnn as cudnn 18 | 19 | 20 | def save_model_cpu(net, optim, ckpt_fname, epoch, write_solverstate=False): 21 | """Saves model weights and optimizer state (optionally) to a file""" 22 | state_dict = net.state_dict() 23 | for key in state_dict.keys(): 24 | state_dict[key] = state_dict[key].cpu() 25 | snapshot_dict = { 26 | 'epoch': epoch, 27 | 'state_dict': state_dict} 28 | 29 | if write_solverstate: 30 | snapshot_dict['optimizer'] = optim 31 | 32 | torch.save(snapshot_dict, ckpt_fname) 33 | 34 | 35 | def get_model_parameters_number(model, as_string=True): 36 | """Returns a total number of trainable parameters in a specified model""" 37 | params_num = sum(p.numel() for p in model.parameters() if p.requires_grad) 38 | if not as_string: 39 | return params_num 40 | 41 | if params_num // 10 ** 6 > 0: 42 | flops_str = str(round(params_num / 10. ** 6, 2)) + 'M' 43 | elif params_num // 10 ** 3 > 0: 44 | flops_str = str(round(params_num / 10. ** 3, 2)) + 'k' 45 | else: 46 | flops_str = str(params_num) 47 | return flops_str 48 | 49 | 50 | def load_model_state(model, snap, device_id, eval_state=True): 51 | """Loads model weight from a file produced by save_model_cpu""" 52 | if device_id != -1: 53 | location = 'cuda:' + str(device_id) 54 | else: 55 | location = 'cpu' 56 | state_dict = torch.load(snap, map_location=location)['state_dict'] 57 | 58 | new_state_dict = OrderedDict() 59 | for k, v in state_dict.items(): 60 | head = k[:7] 61 | if head == 'module.': 62 | name = k[7:] # remove `module.` 63 | else: 64 | name = k 65 | new_state_dict[name] = v 66 | 67 | model.load_state_dict(new_state_dict, strict=False) 68 | 69 | if device_id != -1: 70 | model.cuda(device_id) 71 | cudnn.benchmark = True 72 | 73 | if eval_state: 74 | model.eval() 75 | else: 76 | model.train() 77 | 78 | return model 79 | 80 | 81 | def flip_tensor(x, dim): 82 | """Flips a tensor along the specified axis""" 83 | xsize = x.size() 84 | dim = x.dim() + dim if dim < 0 else dim 85 | x = x.view(-1, *xsize[dim:]) 86 | x = x.view(x.size(0), x.size(1), -1)[:, getattr(torch.arange(x.size(1) - 1, -1, -1), 87 | ('cpu', 'cuda')[x.is_cuda])().long(), :] 88 | return x.view(xsize) 89 | --------------------------------------------------------------------------------
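A minimal usage sketch for FivePointsAligner.align from utils/face_align.py, assuming it is run from the repository root so the utils package is importable; the image path and the five landmark coordinates below are placeholder values for illustration, not data from the repository.

import cv2 as cv
from utils.face_align import FivePointsAligner

# Hypothetical face crop and its five landmarks (eye centers, nose tip, mouth corners)
# given as a flat [x1, y1, ..., x5, y5] list in normalized [0, 1] coordinates.
image = cv.imread('face.jpg')
landmarks = [0.34, 0.46, 0.66, 0.46, 0.50, 0.64, 0.37, 0.82, 0.63, 0.82]

# Warp the crop so that the landmarks end up near FivePointsAligner.ref_landmarks.
aligned = FivePointsAligner.align(image, landmarks, d_size=(112, 112), normalized=True)
cv.imwrite('face_aligned.jpg', aligned)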
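A similar sketch of the snapshot round trip built on save_model_cpu and load_model_state from utils/utils.py, mirroring what tests/test_models.py does; the 'rmnet' backbone key is taken from the defaults of the conversion scripts above, the snapshot file name is illustrative, and device id -1 keeps everything on the CPU.

import torch
from model.common import models_backbones
from utils.utils import save_model_cpu, load_model_state, get_model_parameters_number

# Build a backbone in feature-extraction mode and report its parameter count.
model = models_backbones['rmnet'](embedding_size=128, feature=True).eval()
print('Trainable parameters:', get_model_parameters_number(model))

# Save CPU weights without optimizer state, then restore them into a fresh instance.
save_model_cpu(model, None, 'rmnet_test.pt', 0, write_solverstate=False)
restored = models_backbones['rmnet'](embedding_size=128, feature=True)
restored = load_model_state(restored, 'rmnet_test.pt', -1, eval_state=True)

# The restored model should reproduce the original embeddings exactly.
batch = torch.Tensor(1, 3, *model.get_input_res()).uniform_()
with torch.no_grad():
    diff = (model(batch) - restored(batch)).abs().max().item()
print('Max abs difference:', diff)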