├── .gitignore ├── .idea ├── Face_Pytorch.iml ├── codeStyles │ └── codeStyleConfig.xml ├── deployment.xml ├── encodings.xml ├── misc.xml ├── modules.xml ├── vcs.xml └── workspace.xml ├── LICENSE ├── README.md ├── backbone ├── __init__.py ├── arcfacenet.py ├── attention.py ├── cbam.py ├── mobilefacenet.py ├── resnet.py └── spherenet.py ├── cppapi ├── README.md └── pytorch2torchscript.py ├── dataset ├── __init__.py ├── agedb.py ├── casia_webface.py ├── cfp.py ├── lfw.py ├── lfw_2.py └── megaface.py ├── eval_agedb30.py ├── eval_cfp.py ├── eval_deepglint_merge.py ├── eval_lfw.py ├── eval_lfw_blufr.py ├── eval_megaface.py ├── lossfunctions ├── __init__.py ├── agentcenterloss.py └── centerloss.py ├── margin ├── ArcMarginProduct.py ├── CosineMarginProduct.py ├── InnerProduct.py ├── MultiMarginProduct.py ├── SphereMarginProduct.py └── __init__.py ├── model ├── MSCeleb_MOBILEFACE_20181228_170458 │ └── log.log └── MSCeleb_SERES50_IR_20181229_211407 │ └── log.log ├── result ├── softmax.gif ├── softmax_center.gif └── visualization.jpg ├── train.py ├── train_center.py ├── train_softmax.py └── utils ├── README.md ├── __init__.py ├── load_images_from_bin.py ├── logging.py ├── plot_logit.py ├── plot_theta.py ├── theta_distribution_hist.jpg └── visualize.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | local_settings.py 56 | db.sqlite3 57 | 58 | # Flask stuff: 59 | instance/ 60 | .webassets-cache 61 | 62 | # Scrapy stuff: 63 | .scrapy 64 | 65 | # Sphinx documentation 66 | docs/_build/ 67 | 68 | # PyBuilder 69 | target/ 70 | 71 | # Jupyter Notebook 72 | .ipynb_checkpoints 73 | 74 | # pyenv 75 | .python-version 76 | 77 | # celery beat schedule file 78 | celerybeat-schedule 79 | 80 | # SageMath parsed files 81 | *.sage.py 82 | 83 | # Environments 84 | .env 85 | .venv 86 | env/ 87 | venv/ 88 | ENV/ 89 | env.bak/ 90 | venv.bak/ 91 | 92 | # Spyder project settings 93 | .spyderproject 94 | .spyproject 95 | 96 | # Rope project settings 97 | .ropeproject 98 | 99 | # mkdocs documentation 100 | /site 101 | 102 | # mypy 103 | .mypy_cache/ 104 | 105 | # model or feature 106 | .mat 107 | .ckpt -------------------------------------------------------------------------------- /.idea/Face_Pytorch.iml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 14 | -------------------------------------------------------------------------------- /.idea/codeStyles/codeStyleConfig.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 5 | -------------------------------------------------------------------------------- /.idea/deployment.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | -------------------------------------------------------------------------------- /.idea/encodings.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | -------------------------------------------------------------------------------- /.idea/misc.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 6 | 7 | -------------------------------------------------------------------------------- /.idea/modules.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /.idea/vcs.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. 
For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 
134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 
193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | >2019.07.14 2 | >I have now graduated and moved on to a different kind of work, so this project may not be updated again. 3 | 4 | ## Face_Pytorch 5 | The implementation of popular face recognition algorithms in the PyTorch framework, including ArcFace, CosFace, SphereFace and so on. 6 | 7 | All code was evaluated on PyTorch 0.4.0 with Python 3.6, Ubuntu 16.04.10, CUDA 9.1 and cuDNN 7.1, and partially evaluated on PyTorch 1.0. 8 | 9 | 10 | ## Data Preparation 11 | For CNN training, I use CASIA-WebFace and the cleaned MS-Celeb-1M, aligned by MTCNN to a size of 112x112. For performance testing, I report results on LFW, AgeDB-30, CFP-FP, and MegaFace rank-1 identification and verification. 12 | 13 | For AgeDB-30 and CFP-FP, the aligned images and evaluation image pairs are recovered from the mxnet binary files provided by [insightface](https://github.com/deepinsight/insightface); the conversion tools are available in this repository. You should install a CPU build of mxnet first for the image parsing; ' **pip install mxnet** ' is enough. 14 | [LFW @ BaiduNetdisk](https://pan.baidu.com/s/1Rue4FBmGvdGMPkyy2ZqcdQ), [AgeDB-30 @ BaiduNetdisk](https://pan.baidu.com/s/1sdw1lO5JfP6Ja99O7zprUg), [CFP_FP @ BaiduNetdisk](https://pan.baidu.com/s/1gyFAAy427weUd2G-ozMgEg) 15 | 16 | ## Results 17 | > MobileFaceNet: structure described in the MobileFaceNet paper 18 | > ResNet50: original ResNet structure 19 | > ResNet50-IR: CNN described in the ArcFace paper 20 | > SEResNet50-IR: CNN described in the ArcFace paper 21 | ### Verification results on LFW, AgeDB-30 and CFP_FP 22 | Small Protocol: trained on CASIA-WebFace (453580 images / 10575 identities) 23 | Large Protocol: trained on DeepGlint MS-Celeb-1M (3923399 images / 86876 identities) 24 | 25 | Model Type | Loss | LFW | AgeDB-30 | CFP-FP | Model Size | protocol 26 | :--------------:|:---------:|:-------:|:--------:|:-------:|:----------:|:--------: 27 | MobileFaceNet | ArcFace | 99.23 | 93.26 | 94.34 | 4MB | small 28 | ResNet50-IR | ArcFace | 99.42 | 94.45 | 95.34 | 170MB | small 29 | SEResNet50-IR | ArcFace | 99.43 | 94.50 | 95.43 | 171MB | small 30 | MobileFaceNet | ArcFace | 99.58 | 96.57 | 92.90 | 4MB | large 31 | ResNet50-IR | ArcFace | 99.82 | 98.07 | 95.34 | 170MB | large 32 | SEResNet50-IR | ArcFace | 99.80 | 98.13 | 95.60 | 171MB | large 33 | ResNet100-IR | ArcFace | 99.83 | 98.28 | 96.41 | 256MB | large 34 | 35 | One odd result: under the small protocol, CFP-FP performs better than AgeDB-30, while with the large-scale dataset, CFP-FP performs worse than AgeDB-30. 36 | 37 | ### MegaFace rank-1 identification accuracy and verification@FPR=1e-6 results 38 |
39 | Model Type | Loss | MF Acc. | MF Ver. | MF Acc.@R | MF Ver.@R | SIZE | protocol 40 | :--------------:|:---------:|:-------:|:-------:|:---------:|:---------:|:-----:|:-------: 41 | MobileFaceNet | ArcFace | 69.10 | 84.23 | 81.15 | 85.86 | 4MB | small 42 | ResNet50-IR | ArcFace | 74.31 | 88.23 | 87.44 | 89.56 | 170MB | small 43 | SEResNet50-IR | ArcFace | 74.37 | 88.32 | 88.30 | 89.65 | 171MB | small 44 | MobileFaceNet | ArcFace | 74.95 | 88.77 | 89.47 | 91.03 | 4MB | large 45 | ResNet50-IR | ArcFace | 79.61 | 96.02 | 96.58 | 96.78 | 170MB | large 46 | SEResNet50-IR | ArcFace | 79.91 | 96.10 | 97.01 | 97.60 | 171MB | large 47 | ResNet100-IR | ArcFace | 80.40 | 96.94 | 97.60 | 98.05 | 256MB | large 48 | 49 | 50 | 51 | ## Usage 52 | 1. Download the source code to your machine. 53 | 2. Prepare the training dataset and training list, plus the test dataset and test verification pairs. 54 | 3. Set your own dataset path and any other parameters in train.py. 55 | 4. Run train.py; the test accuracy is printed to the log file during training. 56 | --- 57 | 5. Every evaluation script can also run independently for model testing; just set your own args in the file (a minimal standalone sketch follows the visualization section below). 58 | 59 | ## Visualization 60 | Visdom is supported for visualizing loss and accuracy during training. 61 | ![avatar](result/visualization.jpg) 62 | 63 | 64 | Softmax loss vs. softmax + center loss. Left: training set with softmax only. Right: training set with softmax + center loss. 65 |
66 | ![train with softmax](result/softmax.gif) 67 | ![train with softmax + center loss](result/softmax_center.gif) 68 |
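As a complement to the usage steps above, here is a minimal, self-contained sketch of loading a trained backbone and comparing two aligned 112x112 face crops by cosine similarity. It is not one of the repository's scripts: the checkpoint path, the saved `state_dict` layout, and the `(x - 127.5) / 128` preprocessing are assumptions that may need adapting to your own training run.

```python
import torch
import numpy as np
from PIL import Image
from backbone.cbam import CBAMResNet

CKPT = 'model/your_backbone.ckpt'  # hypothetical path; point this at your own checkpoint

net = CBAMResNet(50, feature_dim=512, mode='ir')            # ResNet50-IR backbone from backbone/cbam.py
net.load_state_dict(torch.load(CKPT, map_location='cpu'))   # assumes a plain state_dict was saved
net.eval()

def embed(path):
    # Assumed preprocessing: 112x112 RGB crop scaled to roughly [-1, 1].
    img = np.asarray(Image.open(path).convert('RGB').resize((112, 112)), dtype=np.float32)
    img = (img - 127.5) / 128.0
    tensor = torch.from_numpy(img).permute(2, 0, 1).unsqueeze(0)  # HWC -> 1x3x112x112
    with torch.no_grad():
        feat = net(tensor)                                        # 1x512 embedding
    return torch.nn.functional.normalize(feat, dim=1)

# Cosine similarity of two faces; higher means more likely the same identity.
sim = (embed('face_a.jpg') * embed('face_b.jpg')).sum().item()
print(sim)
```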
69 | 70 | ## References 71 | [MuggleWang/CosFace_pytorch](https://github.com/MuggleWang/CosFace_pytorch) 72 | [Xiaoccer/MobileFaceNet_Pytorch](https://github.com/Xiaoccer/MobileFaceNet_Pytorch) 73 | [TreB1eN/InsightFace_Pytorch](https://github.com/TreB1eN/InsightFace_Pytorch) 74 | [deepinsight/insightface](https://github.com/deepinsight/insightface) 75 | [KaiyangZhou/pytorch-center-loss](https://github.com/KaiyangZhou/pytorch-center-loss) 76 | [tengshaofeng/ResidualAttentionNetwork-pytorch](https://github.com/tengshaofeng/ResidualAttentionNetwork-pytorch) 77 | 78 | ## Todo 79 | 1. Report the test results on DeepGlint Trillion Pairs Challenge. 80 | 2. Add C++ api for fast deployment with pytorch 1.0. 81 | 3. Train the ResNet100-based model. 82 | -------------------------------------------------------------------------------- /backbone/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # encoding: utf-8 3 | ''' 4 | @author: wujiyang 5 | @contact: wujiyang@hust.edu.cn 6 | @file: __init__.py.py 7 | @time: 2018/12/21 15:30 8 | @desc: 9 | ''' -------------------------------------------------------------------------------- /backbone/arcfacenet.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # encoding: utf-8 3 | ''' 4 | @author: wujiyang 5 | @contact: wujiyang@hust.edu.cn 6 | @file: arcfacenet.py 7 | @time: 2018/12/26 10:15 8 | @desc: Network structures used in the arcface paper, including ResNet50-IR, ResNet101-IR, SEResNet50-IR, SEResNet101-IR 9 | 10 | '''''' 11 | Update: This file has been deprecated, all the models build in this class have been rebuild in cbam.py 12 | Yet the code in this file still works. 13 | ''' 14 | 15 | 16 | import torch 17 | from torch import nn 18 | from collections import namedtuple 19 | 20 | class Flatten(nn.Module): 21 | def forward(self, input): 22 | return input.view(input.size(0), -1) 23 | 24 | 25 | class SEModule(nn.Module): 26 | def __init__(self, channels, reduction): 27 | super(SEModule, self).__init__() 28 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 29 | self.fc1 = nn.Conv2d(channels, channels // reduction, kernel_size=1, padding=0, bias=False) 30 | self.relu = nn.ReLU(inplace=True) 31 | self.fc2 = nn.Conv2d(channels // reduction, channels, kernel_size=1, padding=0, bias=False) 32 | self.sigmoid = nn.Sigmoid() 33 | 34 | def forward(self, x): 35 | input = x 36 | x = self.avg_pool(x) 37 | x = self.fc1(x) 38 | x = self.relu(x) 39 | x = self.fc2(x) 40 | x = self.sigmoid(x) 41 | 42 | return input * x 43 | 44 | 45 | class BottleNeck_IR(nn.Module): 46 | def __init__(self, in_channel, out_channel, stride): 47 | super(BottleNeck_IR, self).__init__() 48 | if in_channel == out_channel: 49 | self.shortcut_layer = nn.MaxPool2d(1, stride) 50 | else: 51 | self.shortcut_layer = nn.Sequential( 52 | nn.Conv2d(in_channel, out_channel, kernel_size=(1, 1), stride=stride, bias=False), 53 | nn.BatchNorm2d(out_channel) 54 | ) 55 | 56 | self.res_layer = nn.Sequential(nn.BatchNorm2d(in_channel), 57 | nn.Conv2d(in_channel, out_channel, (3, 3), 1, 1, bias=False), 58 | nn.BatchNorm2d(out_channel), 59 | nn.PReLU(out_channel), 60 | nn.Conv2d(out_channel, out_channel, (3, 3), stride, 1, bias=False), 61 | nn.BatchNorm2d(out_channel)) 62 | 63 | def forward(self, x): 64 | shortcut = self.shortcut_layer(x) 65 | res = self.res_layer(x) 66 | 67 | return shortcut + res 68 | 69 | class BottleNeck_IR_SE(nn.Module): 70 | def __init__(self, in_channel, out_channel, 
stride): 71 | super(BottleNeck_IR_SE, self).__init__() 72 | if in_channel == out_channel: 73 | self.shortcut_layer = nn.MaxPool2d(1, stride) 74 | else: 75 | self.shortcut_layer = nn.Sequential( 76 | nn.Conv2d(in_channel, out_channel, kernel_size=(1, 1), stride=stride, bias=False), 77 | nn.BatchNorm2d(out_channel) 78 | ) 79 | 80 | self.res_layer = nn.Sequential(nn.BatchNorm2d(in_channel), 81 | nn.Conv2d(in_channel, out_channel, (3, 3), 1, 1, bias=False), 82 | nn.BatchNorm2d(out_channel), 83 | nn.PReLU(out_channel), 84 | nn.Conv2d(out_channel, out_channel, (3, 3), stride, 1, bias=False), 85 | nn.BatchNorm2d(out_channel), 86 | SEModule(out_channel, 16)) 87 | 88 | def forward(self, x): 89 | shortcut = self.shortcut_layer(x) 90 | res = self.res_layer(x) 91 | 92 | return shortcut + res 93 | 94 | 95 | class Bottleneck(namedtuple('Block', ['in_channel', 'out_channel', 'stride'])): 96 | '''A named tuple describing a ResNet block.''' 97 | 98 | 99 | def get_block(in_channel, out_channel, num_units, stride=2): 100 | return [Bottleneck(in_channel, out_channel, stride)] + [Bottleneck(out_channel, out_channel, 1) for i in range(num_units - 1)] 101 | 102 | 103 | def get_blocks(num_layers): 104 | if num_layers == 50: 105 | blocks = [ 106 | get_block(in_channel=64, out_channel=64, num_units=3), 107 | get_block(in_channel=64, out_channel=128, num_units=4), 108 | get_block(in_channel=128, out_channel=256, num_units=14), 109 | get_block(in_channel=256, out_channel=512, num_units=3) 110 | ] 111 | elif num_layers == 100: 112 | blocks = [ 113 | get_block(in_channel=64, out_channel=64, num_units=3), 114 | get_block(in_channel=64, out_channel=128, num_units=13), 115 | get_block(in_channel=128, out_channel=256, num_units=30), 116 | get_block(in_channel=256, out_channel=512, num_units=3) 117 | ] 118 | elif num_layers == 152: 119 | blocks = [ 120 | get_block(in_channel=64, out_channel=64, num_units=3), 121 | get_block(in_channel=64, out_channel=128, num_units=8), 122 | get_block(in_channel=128, out_channel=256, num_units=36), 123 | get_block(in_channel=256, out_channel=512, num_units=3) 124 | ] 125 | return blocks 126 | 127 | 128 | class SEResNet_IR(nn.Module): 129 | def __init__(self, num_layers, feature_dim=512, drop_ratio=0.4, mode = 'ir'): 130 | super(SEResNet_IR, self).__init__() 131 | assert num_layers in [50, 100, 152], 'num_layers should be 50, 100 or 152' 132 | assert mode in ['ir', 'se_ir'], 'mode should be ir or se_ir' 133 | blocks = get_blocks(num_layers) 134 | if mode == 'ir': 135 | unit_module = BottleNeck_IR 136 | elif mode == 'se_ir': 137 | unit_module = BottleNeck_IR_SE 138 | self.input_layer = nn.Sequential(nn.Conv2d(3, 64, (3, 3), 1, 1, bias=False), 139 | nn.BatchNorm2d(64), 140 | nn.PReLU(64)) 141 | 142 | self.output_layer = nn.Sequential(nn.BatchNorm2d(512), 143 | nn.Dropout(drop_ratio), 144 | Flatten(), 145 | nn.Linear(512 * 7 * 7, feature_dim), 146 | nn.BatchNorm1d(feature_dim)) 147 | modules = [] 148 | for block in blocks: 149 | for bottleneck in block: 150 | modules.append( 151 | unit_module(bottleneck.in_channel, 152 | bottleneck.out_channel, 153 | bottleneck.stride)) 154 | self.body = nn.Sequential(*modules) 155 | 156 | def forward(self, x): 157 | x = self.input_layer(x) 158 | x = self.body(x) 159 | x = self.output_layer(x) 160 | 161 | return x 162 | 163 | 164 | if __name__ == '__main__': 165 | input = torch.Tensor(2, 3, 112, 112) 166 | net = SEResNet_IR(100, mode='se_ir') 167 | print(net) 168 | 169 | x = net(input) 170 | print(x.shape) 
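    # Sanity-check sketch (an assumption-level illustration, not part of the training code):
    # get_blocks(50) expands to 3 + 4 + 14 + 3 = 24 bottleneck units, and the stride-2 first
    # unit of each of the four stages reduces a 112x112 input to the 7x7 feature map that the
    # 512 * 7 * 7 linear layer in the output head expects.
    blocks = get_blocks(50)
    print(sum(len(stage) for stage in blocks))   # 24 bottleneck units in total
    spatial = 112
    for stage in blocks:
        spatial //= stage[0].stride              # only the first unit of a stage downsamples
    print(spatial)                               # 7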
-------------------------------------------------------------------------------- /backbone/attention.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # encoding: utf-8 3 | ''' 4 | @author: wujiyang 5 | @contact: wujiyang@hust.edu.cn 6 | @file: attention.py 7 | @time: 2019/2/14 14:12 8 | @desc: Residual Attention Network for Image Classification, CVPR 2017. 9 | Attention 56 and Attention 92. 10 | ''' 11 | 12 | 13 | import torch 14 | import torch.nn as nn 15 | import numpy as np 16 | 17 | class Flatten(nn.Module): 18 | def forward(self, input): 19 | return input.view(input.size(0), -1) 20 | 21 | class ResidualBlock(nn.Module): 22 | 23 | def __init__(self, in_channel, out_channel, stride=1): 24 | super(ResidualBlock, self).__init__() 25 | self.in_channel = in_channel 26 | self.out_channel = out_channel 27 | self.stride = stride 28 | 29 | self.res_bottleneck = nn.Sequential(nn.BatchNorm2d(in_channel), 30 | nn.ReLU(inplace=True), 31 | nn.Conv2d(in_channel, out_channel//4, 1, 1, bias=False), 32 | nn.BatchNorm2d(out_channel//4), 33 | nn.ReLU(inplace=True), 34 | nn.Conv2d(out_channel//4, out_channel//4, 3, stride, padding=1, bias=False), 35 | nn.BatchNorm2d(out_channel//4), 36 | nn.ReLU(inplace=True), 37 | nn.Conv2d(out_channel//4, out_channel, 1, 1, bias=False)) 38 | self.shortcut = nn.Conv2d(in_channel, out_channel, 1, stride, bias=False) 39 | 40 | def forward(self, x): 41 | res = x 42 | out = self.res_bottleneck(x) 43 | if self.in_channel != self.out_channel or self.stride != 1: 44 | res = self.shortcut(x) 45 | 46 | out += res 47 | return out 48 | 49 | class AttentionModule_stage1(nn.Module): 50 | 51 | # input size is 56*56 52 | def __init__(self, in_channel, out_channel, size1=(56, 56), size2=(28, 28), size3=(14, 14)): 53 | super(AttentionModule_stage1, self).__init__() 54 | self.share_residual_block = ResidualBlock(in_channel, out_channel) 55 | self.trunk_branches = nn.Sequential(ResidualBlock(in_channel, out_channel), 56 | ResidualBlock(in_channel, out_channel)) 57 | 58 | self.mpool1 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 59 | self.mask_block1 = ResidualBlock(in_channel, out_channel) 60 | self.skip_connect1 = ResidualBlock(in_channel, out_channel) 61 | 62 | self.mpool2 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 63 | self.mask_block2 = ResidualBlock(in_channel, out_channel) 64 | self.skip_connect2 = ResidualBlock(in_channel, out_channel) 65 | 66 | self.mpool3 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 67 | self.mask_block3 = nn.Sequential(ResidualBlock(in_channel, out_channel), 68 | ResidualBlock(in_channel, out_channel)) 69 | 70 | self.interpolation3 = nn.UpsamplingBilinear2d(size=size3) 71 | self.mask_block4 = ResidualBlock(in_channel, out_channel) 72 | 73 | self.interpolation2 = nn.UpsamplingBilinear2d(size=size2) 74 | self.mask_block5 = ResidualBlock(in_channel, out_channel) 75 | 76 | self.interpolation1 = nn.UpsamplingBilinear2d(size=size1) 77 | self.mask_block6 = nn.Sequential(nn.BatchNorm2d(out_channel), 78 | nn.ReLU(inplace=True), 79 | nn.Conv2d(out_channel, out_channel, 1, 1, bias=False), 80 | nn.BatchNorm2d(out_channel), 81 | nn.ReLU(inplace=True), 82 | nn.Conv2d(out_channel, out_channel, 1, 1, bias=False), 83 | nn.Sigmoid()) 84 | 85 | self.last_block = ResidualBlock(in_channel, out_channel) 86 | 87 | def forward(self, x): 88 | x = self.share_residual_block(x) 89 | out_trunk = self.trunk_branches(x) 90 | 91 | out_pool1 = self.mpool1(x) 92 | out_block1 = self.mask_block1(out_pool1) 93 | 
out_skip_connect1 = self.skip_connect1(out_block1) 94 | 95 | out_pool2 = self.mpool2(out_block1) 96 | out_block2 = self.mask_block2(out_pool2) 97 | out_skip_connect2 = self.skip_connect2(out_block2) 98 | 99 | out_pool3 = self.mpool3(out_block2) 100 | out_block3 = self.mask_block3(out_pool3) 101 | # 102 | out_inter3 = self.interpolation3(out_block3) + out_block2 103 | out = out_inter3 + out_skip_connect2 104 | out_block4 = self.mask_block4(out) 105 | 106 | out_inter2 = self.interpolation2(out_block4) + out_block1 107 | out = out_inter2 + out_skip_connect1 108 | out_block5 = self.mask_block5(out) 109 | 110 | out_inter1 = self.interpolation1(out_block5) + out_trunk 111 | out_block6 = self.mask_block6(out_inter1) 112 | 113 | out = (1 + out_block6) + out_trunk 114 | out_last = self.last_block(out) 115 | 116 | return out_last 117 | 118 | class AttentionModule_stage2(nn.Module): 119 | 120 | # input image size is 28*28 121 | def __init__(self, in_channels, out_channels, size1=(28, 28), size2=(14, 14)): 122 | super(AttentionModule_stage2, self).__init__() 123 | self.first_residual_blocks = ResidualBlock(in_channels, out_channels) 124 | 125 | self.trunk_branches = nn.Sequential( 126 | ResidualBlock(in_channels, out_channels), 127 | ResidualBlock(in_channels, out_channels) 128 | ) 129 | 130 | self.mpool1 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 131 | self.softmax1_blocks = ResidualBlock(in_channels, out_channels) 132 | self.skip1_connection_residual_block = ResidualBlock(in_channels, out_channels) 133 | 134 | self.mpool2 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 135 | self.softmax2_blocks = nn.Sequential( 136 | ResidualBlock(in_channels, out_channels), 137 | ResidualBlock(in_channels, out_channels) 138 | ) 139 | 140 | self.interpolation2 = nn.UpsamplingBilinear2d(size=size2) 141 | self.softmax3_blocks = ResidualBlock(in_channels, out_channels) 142 | self.interpolation1 = nn.UpsamplingBilinear2d(size=size1) 143 | self.softmax4_blocks = nn.Sequential( 144 | nn.BatchNorm2d(out_channels), 145 | nn.ReLU(inplace=True), 146 | nn.Conv2d(out_channels, out_channels, kernel_size=1, stride=1, bias=False), 147 | nn.BatchNorm2d(out_channels), 148 | nn.ReLU(inplace=True), 149 | nn.Conv2d(out_channels, out_channels, kernel_size=1, stride=1, bias=False), 150 | nn.Sigmoid() 151 | ) 152 | self.last_blocks = ResidualBlock(in_channels, out_channels) 153 | 154 | def forward(self, x): 155 | x = self.first_residual_blocks(x) 156 | out_trunk = self.trunk_branches(x) 157 | out_mpool1 = self.mpool1(x) 158 | out_softmax1 = self.softmax1_blocks(out_mpool1) 159 | out_skip1_connection = self.skip1_connection_residual_block(out_softmax1) 160 | 161 | out_mpool2 = self.mpool2(out_softmax1) 162 | out_softmax2 = self.softmax2_blocks(out_mpool2) 163 | 164 | out_interp2 = self.interpolation2(out_softmax2) + out_softmax1 165 | out = out_interp2 + out_skip1_connection 166 | 167 | out_softmax3 = self.softmax3_blocks(out) 168 | out_interp1 = self.interpolation1(out_softmax3) + out_trunk 169 | out_softmax4 = self.softmax4_blocks(out_interp1) 170 | out = (1 + out_softmax4) * out_trunk 171 | out_last = self.last_blocks(out) 172 | 173 | return out_last 174 | 175 | class AttentionModule_stage3(nn.Module): 176 | 177 | # input image size is 14*14 178 | def __init__(self, in_channels, out_channels, size1=(14, 14)): 179 | super(AttentionModule_stage3, self).__init__() 180 | self.first_residual_blocks = ResidualBlock(in_channels, out_channels) 181 | 182 | self.trunk_branches = nn.Sequential( 183 | ResidualBlock(in_channels, 
out_channels), 184 | ResidualBlock(in_channels, out_channels) 185 | ) 186 | 187 | self.mpool1 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 188 | self.softmax1_blocks = nn.Sequential( 189 | ResidualBlock(in_channels, out_channels), 190 | ResidualBlock(in_channels, out_channels) 191 | ) 192 | 193 | self.interpolation1 = nn.UpsamplingBilinear2d(size=size1) 194 | 195 | self.softmax2_blocks = nn.Sequential( 196 | nn.BatchNorm2d(out_channels), 197 | nn.ReLU(inplace=True), 198 | nn.Conv2d(out_channels, out_channels, kernel_size=1, stride=1, bias=False), 199 | nn.BatchNorm2d(out_channels), 200 | nn.ReLU(inplace=True), 201 | nn.Conv2d(out_channels, out_channels, kernel_size=1, stride=1, bias=False), 202 | nn.Sigmoid() 203 | ) 204 | 205 | self.last_blocks = ResidualBlock(in_channels, out_channels) 206 | 207 | def forward(self, x): 208 | x = self.first_residual_blocks(x) 209 | out_trunk = self.trunk_branches(x) 210 | out_mpool1 = self.mpool1(x) 211 | out_softmax1 = self.softmax1_blocks(out_mpool1) 212 | 213 | out_interp1 = self.interpolation1(out_softmax1) + out_trunk 214 | out_softmax2 = self.softmax2_blocks(out_interp1) 215 | out = (1 + out_softmax2) * out_trunk 216 | out_last = self.last_blocks(out) 217 | 218 | return out_last 219 | 220 | class ResidualAttentionNet_56(nn.Module): 221 | 222 | # for input size 112 223 | def __init__(self, feature_dim=512, drop_ratio=0.4): 224 | super(ResidualAttentionNet_56, self).__init__() 225 | self.conv1 = nn.Sequential( 226 | nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias = False), 227 | nn.BatchNorm2d(64), 228 | nn.ReLU(inplace=True) 229 | ) 230 | self.mpool1 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 231 | self.residual_block1 = ResidualBlock(64, 256) 232 | self.attention_module1 = AttentionModule_stage1(256, 256) 233 | self.residual_block2 = ResidualBlock(256, 512, 2) 234 | self.attention_module2 = AttentionModule_stage2(512, 512) 235 | self.residual_block3 = ResidualBlock(512, 512, 2) 236 | self.attention_module3 = AttentionModule_stage3(512, 512) 237 | self.residual_block4 = ResidualBlock(512, 512, 2) 238 | self.residual_block5 = ResidualBlock(512, 512) 239 | self.residual_block6 = ResidualBlock(512, 512) 240 | self.output_layer = nn.Sequential(nn.BatchNorm2d(512), 241 | nn.Dropout(drop_ratio), 242 | Flatten(), 243 | nn.Linear(512 * 7 * 7, feature_dim), 244 | nn.BatchNorm1d(feature_dim)) 245 | 246 | def forward(self, x): 247 | out = self.conv1(x) 248 | out = self.mpool1(out) 249 | # print(out.data) 250 | out = self.residual_block1(out) 251 | out = self.attention_module1(out) 252 | out = self.residual_block2(out) 253 | out = self.attention_module2(out) 254 | out = self.residual_block3(out) 255 | # print(out.data) 256 | out = self.attention_module3(out) 257 | out = self.residual_block4(out) 258 | out = self.residual_block5(out) 259 | out = self.residual_block6(out) 260 | out = self.output_layer(out) 261 | 262 | return out 263 | 264 | class ResidualAttentionNet_92(nn.Module): 265 | 266 | # for input size 112 267 | def __init__(self, feature_dim=512, drop_ratio=0.4): 268 | super(ResidualAttentionNet_92, self).__init__() 269 | self.conv1 = nn.Sequential( 270 | nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias = False), 271 | nn.BatchNorm2d(64), 272 | nn.ReLU(inplace=True) 273 | ) 274 | self.mpool1 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 275 | self.residual_block1 = ResidualBlock(64, 256) 276 | self.attention_module1 = AttentionModule_stage1(256, 256) 277 | self.residual_block2 = ResidualBlock(256, 512, 2) 278 | 
self.attention_module2 = AttentionModule_stage2(512, 512) 279 | self.attention_module2_2 = AttentionModule_stage2(512, 512) # tbq add 280 | self.residual_block3 = ResidualBlock(512, 1024, 2) 281 | self.attention_module3 = AttentionModule_stage3(1024, 1024) 282 | self.attention_module3_2 = AttentionModule_stage3(1024, 1024) # tbq add 283 | self.attention_module3_3 = AttentionModule_stage3(1024, 1024) # tbq add 284 | self.residual_block4 = ResidualBlock(1024, 2048, 2) 285 | self.residual_block5 = ResidualBlock(2048, 2048) 286 | self.residual_block6 = ResidualBlock(2048, 2048) 287 | self.output_layer = nn.Sequential(nn.BatchNorm2d(2048), 288 | nn.Dropout(drop_ratio), 289 | Flatten(), 290 | nn.Linear(2048 * 7 * 7, feature_dim), 291 | nn.BatchNorm1d(feature_dim)) 292 | 293 | def forward(self, x): 294 | out = self.conv1(x) 295 | out = self.mpool1(out) 296 | # print(out.data) 297 | out = self.residual_block1(out) 298 | out = self.attention_module1(out) 299 | out = self.residual_block2(out) 300 | out = self.attention_module2(out) 301 | out = self.attention_module2_2(out) 302 | out = self.residual_block3(out) 303 | # print(out.data) 304 | out = self.attention_module3(out) 305 | out = self.attention_module3_2(out) 306 | out = self.attention_module3_3(out) 307 | out = self.residual_block4(out) 308 | out = self.residual_block5(out) 309 | out = self.residual_block6(out) 310 | out = self.output_layer(out) 311 | 312 | return out 313 | 314 | 315 | if __name__ == '__main__': 316 | input = torch.Tensor(2, 3, 112, 112) 317 | net = ResidualAttentionNet_56() 318 | print(net) 319 | 320 | x = net(input) 321 | print(x.shape) -------------------------------------------------------------------------------- /backbone/cbam.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # encoding: utf-8 3 | ''' 4 | @author: wujiyang 5 | @contact: wujiyang@hust.edu.cn 6 | @file: cbam.py 7 | @time: 2019/1/14 15:33 8 | @desc: Convolutional Block Attention Module in ECCV 2018, including channel attention module and spatial attention module. 
9 | ''' 10 | 11 | import torch 12 | from torch import nn 13 | import time 14 | 15 | class Flatten(nn.Module): 16 | def forward(self, input): 17 | return input.view(input.size(0), -1) 18 | 19 | class SEModule(nn.Module): 20 | '''Squeeze and Excitation Module''' 21 | def __init__(self, channels, reduction): 22 | super(SEModule, self).__init__() 23 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 24 | self.fc1 = nn.Conv2d(channels, channels // reduction, kernel_size=1, padding=0, bias=False) 25 | self.relu = nn.ReLU(inplace=True) 26 | self.fc2 = nn.Conv2d(channels // reduction, channels, kernel_size=1, padding=0, bias=False) 27 | self.sigmoid = nn.Sigmoid() 28 | 29 | def forward(self, x): 30 | input = x 31 | x = self.avg_pool(x) 32 | x = self.fc1(x) 33 | x = self.relu(x) 34 | x = self.fc2(x) 35 | x = self.sigmoid(x) 36 | 37 | return input * x 38 | 39 | class CAModule(nn.Module): 40 | '''Channel Attention Module''' 41 | def __init__(self, channels, reduction): 42 | super(CAModule, self).__init__() 43 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 44 | self.max_pool = nn.AdaptiveMaxPool2d(1) 45 | self.shared_mlp = nn.Sequential(nn.Conv2d(channels, channels // reduction, kernel_size=1, padding=0, bias=False), 46 | nn.ReLU(inplace=True), 47 | nn.Conv2d(channels // reduction, channels, kernel_size=1, padding=0, bias=False)) 48 | self.sigmoid = nn.Sigmoid() 49 | 50 | def forward(self, x): 51 | input = x 52 | avg_pool = self.avg_pool(x) 53 | max_pool = self.max_pool(x) 54 | x = self.shared_mlp(avg_pool) + self.shared_mlp(max_pool) 55 | x = self.sigmoid(x) 56 | 57 | return input * x 58 | 59 | class SAModule(nn.Module): 60 | '''Spatial Attention Module''' 61 | def __init__(self): 62 | super(SAModule, self).__init__() 63 | self.conv = nn.Conv2d(2, 1, kernel_size=3, padding=1, bias=False) 64 | self.sigmoid = nn.Sigmoid() 65 | 66 | def forward(self, x): 67 | input = x 68 | avg_c = torch.mean(x, 1, True) 69 | max_c, _ = torch.max(x, 1, True) 70 | x = torch.cat((avg_c, max_c), 1) 71 | x = self.conv(x) 72 | x = self.sigmoid(x) 73 | return input * x 74 | 75 | class BottleNeck_IR(nn.Module): 76 | '''Improved Residual Bottlenecks''' 77 | def __init__(self, in_channel, out_channel, stride, dim_match): 78 | super(BottleNeck_IR, self).__init__() 79 | self.res_layer = nn.Sequential(nn.BatchNorm2d(in_channel), 80 | nn.Conv2d(in_channel, out_channel, (3, 3), 1, 1, bias=False), 81 | nn.BatchNorm2d(out_channel), 82 | nn.PReLU(out_channel), 83 | nn.Conv2d(out_channel, out_channel, (3, 3), stride, 1, bias=False), 84 | nn.BatchNorm2d(out_channel)) 85 | if dim_match: 86 | self.shortcut_layer = None 87 | else: 88 | self.shortcut_layer = nn.Sequential( 89 | nn.Conv2d(in_channel, out_channel, kernel_size=(1, 1), stride=stride, bias=False), 90 | nn.BatchNorm2d(out_channel) 91 | ) 92 | 93 | def forward(self, x): 94 | shortcut = x 95 | res = self.res_layer(x) 96 | 97 | if self.shortcut_layer is not None: 98 | shortcut = self.shortcut_layer(x) 99 | 100 | return shortcut + res 101 | 102 | class BottleNeck_IR_SE(nn.Module): 103 | '''Improved Residual Bottlenecks with Squeeze and Excitation Module''' 104 | def __init__(self, in_channel, out_channel, stride, dim_match): 105 | super(BottleNeck_IR_SE, self).__init__() 106 | self.res_layer = nn.Sequential(nn.BatchNorm2d(in_channel), 107 | nn.Conv2d(in_channel, out_channel, (3, 3), 1, 1, bias=False), 108 | nn.BatchNorm2d(out_channel), 109 | nn.PReLU(out_channel), 110 | nn.Conv2d(out_channel, out_channel, (3, 3), stride, 1, bias=False), 111 | nn.BatchNorm2d(out_channel), 112 | SEModule(out_channel, 
16)) 113 | if dim_match: 114 | self.shortcut_layer = None 115 | else: 116 | self.shortcut_layer = nn.Sequential( 117 | nn.Conv2d(in_channel, out_channel, kernel_size=(1, 1), stride=stride, bias=False), 118 | nn.BatchNorm2d(out_channel) 119 | ) 120 | 121 | def forward(self, x): 122 | shortcut = x 123 | res = self.res_layer(x) 124 | 125 | if self.shortcut_layer is not None: 126 | shortcut = self.shortcut_layer(x) 127 | 128 | return shortcut + res 129 | 130 | class BottleNeck_IR_CAM(nn.Module): 131 | '''Improved Residual Bottlenecks with Channel Attention Module''' 132 | def __init__(self, in_channel, out_channel, stride, dim_match): 133 | super(BottleNeck_IR_CAM, self).__init__() 134 | self.res_layer = nn.Sequential(nn.BatchNorm2d(in_channel), 135 | nn.Conv2d(in_channel, out_channel, (3, 3), 1, 1, bias=False), 136 | nn.BatchNorm2d(out_channel), 137 | nn.PReLU(out_channel), 138 | nn.Conv2d(out_channel, out_channel, (3, 3), stride, 1, bias=False), 139 | nn.BatchNorm2d(out_channel), 140 | CAModule(out_channel, 16)) 141 | if dim_match: 142 | self.shortcut_layer = None 143 | else: 144 | self.shortcut_layer = nn.Sequential( 145 | nn.Conv2d(in_channel, out_channel, kernel_size=(1, 1), stride=stride, bias=False), 146 | nn.BatchNorm2d(out_channel) 147 | ) 148 | 149 | def forward(self, x): 150 | shortcut = x 151 | res = self.res_layer(x) 152 | 153 | if self.shortcut_layer is not None: 154 | shortcut = self.shortcut_layer(x) 155 | 156 | return shortcut + res 157 | 158 | class BottleNeck_IR_SAM(nn.Module): 159 | '''Improved Residual Bottlenecks with Spatial Attention Module''' 160 | def __init__(self, in_channel, out_channel, stride, dim_match): 161 | super(BottleNeck_IR_SAM, self).__init__() 162 | self.res_layer = nn.Sequential(nn.BatchNorm2d(in_channel), 163 | nn.Conv2d(in_channel, out_channel, (3, 3), 1, 1, bias=False), 164 | nn.BatchNorm2d(out_channel), 165 | nn.PReLU(out_channel), 166 | nn.Conv2d(out_channel, out_channel, (3, 3), stride, 1, bias=False), 167 | nn.BatchNorm2d(out_channel), 168 | SAModule()) 169 | if dim_match: 170 | self.shortcut_layer = None 171 | else: 172 | self.shortcut_layer = nn.Sequential( 173 | nn.Conv2d(in_channel, out_channel, kernel_size=(1, 1), stride=stride, bias=False), 174 | nn.BatchNorm2d(out_channel) 175 | ) 176 | 177 | def forward(self, x): 178 | shortcut = x 179 | res = self.res_layer(x) 180 | 181 | if self.shortcut_layer is not None: 182 | shortcut = self.shortcut_layer(x) 183 | 184 | return shortcut + res 185 | 186 | class BottleNeck_IR_CBAM(nn.Module): 187 | '''Improved Residual Bottleneck with Channel Attention Module and Spatial Attention Module''' 188 | def __init__(self, in_channel, out_channel, stride, dim_match): 189 | super(BottleNeck_IR_CBAM, self).__init__() 190 | self.res_layer = nn.Sequential(nn.BatchNorm2d(in_channel), 191 | nn.Conv2d(in_channel, out_channel, (3, 3), 1, 1, bias=False), 192 | nn.BatchNorm2d(out_channel), 193 | nn.PReLU(out_channel), 194 | nn.Conv2d(out_channel, out_channel, (3, 3), stride, 1, bias=False), 195 | nn.BatchNorm2d(out_channel), 196 | CAModule(out_channel, 16), 197 | SAModule() 198 | ) 199 | if dim_match: 200 | self.shortcut_layer = None 201 | else: 202 | self.shortcut_layer = nn.Sequential( 203 | nn.Conv2d(in_channel, out_channel, kernel_size=(1, 1), stride=stride, bias=False), 204 | nn.BatchNorm2d(out_channel) 205 | ) 206 | 207 | def forward(self, x): 208 | shortcut = x 209 | res = self.res_layer(x) 210 | 211 | if self.shortcut_layer is not None: 212 | shortcut = self.shortcut_layer(x) 213 | 214 | return shortcut + res 
215 | 216 | 217 | filter_list = [64, 64, 128, 256, 512] 218 | def get_layers(num_layers): 219 | if num_layers == 50: 220 | return [3, 4, 14, 3] 221 | elif num_layers == 100: 222 | return [3, 13, 30, 3] 223 | elif num_layers == 152: 224 | return [3, 8, 36, 3] 225 | 226 | class CBAMResNet(nn.Module): 227 | def __init__(self, num_layers, feature_dim=512, drop_ratio=0.4, mode='ir',filter_list=filter_list): 228 | super(CBAMResNet, self).__init__() 229 | assert num_layers in [50, 100, 152], 'num_layers should be 50, 100 or 152' 230 | assert mode in ['ir', 'ir_se', 'ir_cam', 'ir_sam', 'ir_cbam'], 'mode should be ir, ir_se, ir_cam, ir_sam or ir_cbam' 231 | layers = get_layers(num_layers) 232 | if mode == 'ir': 233 | block = BottleNeck_IR 234 | elif mode == 'ir_se': 235 | block = BottleNeck_IR_SE 236 | elif mode == 'ir_cam': 237 | block = BottleNeck_IR_CAM 238 | elif mode == 'ir_sam': 239 | block = BottleNeck_IR_SAM 240 | elif mode == 'ir_cbam': 241 | block = BottleNeck_IR_CBAM 242 | 243 | self.input_layer = nn.Sequential(nn.Conv2d(3, 64, (3, 3), stride=1, padding=1, bias=False), 244 | nn.BatchNorm2d(64), 245 | nn.PReLU(64)) 246 | self.layer1 = self._make_layer(block, filter_list[0], filter_list[1], layers[0], stride=2) 247 | self.layer2 = self._make_layer(block, filter_list[1], filter_list[2], layers[1], stride=2) 248 | self.layer3 = self._make_layer(block, filter_list[2], filter_list[3], layers[2], stride=2) 249 | self.layer4 = self._make_layer(block, filter_list[3], filter_list[4], layers[3], stride=2) 250 | 251 | self.output_layer = nn.Sequential(nn.BatchNorm2d(512), 252 | nn.Dropout(drop_ratio), 253 | Flatten(), 254 | nn.Linear(512 * 7 * 7, feature_dim), 255 | nn.BatchNorm1d(feature_dim)) 256 | 257 | # weight initialization 258 | for m in self.modules(): 259 | if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear): 260 | nn.init.xavier_uniform_(m.weight) 261 | if m.bias is not None: 262 | nn.init.constant_(m.bias, 0.0) 263 | elif isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.BatchNorm1d): 264 | nn.init.constant_(m.weight, 1) 265 | nn.init.constant_(m.bias, 0) 266 | 267 | def _make_layer(self, block, in_channel, out_channel, blocks, stride): 268 | layers = [] 269 | layers.append(block(in_channel, out_channel, stride, False)) 270 | for i in range(1, blocks): 271 | layers.append(block(out_channel, out_channel, 1, True)) 272 | 273 | return nn.Sequential(*layers) 274 | 275 | def forward(self, x): 276 | x = self.input_layer(x) 277 | x = self.layer1(x) 278 | x = self.layer2(x) 279 | x = self.layer3(x) 280 | x = self.layer4(x) 281 | x = self.output_layer(x) 282 | 283 | return x 284 | 285 | if __name__ == '__main__': 286 | input = torch.Tensor(2, 3, 112, 112) 287 | net = CBAMResNet(50, mode='ir') 288 | 289 | out = net(input) 290 | print(out.shape) 291 | -------------------------------------------------------------------------------- /backbone/mobilefacenet.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # encoding: utf-8 3 | ''' 4 | @author: wujiyang 5 | @contact: wujiyang@hust.edu.cn 6 | @file: mobilefacenet.py 7 | @time: 2018/12/21 15:45 8 | @desc: mobilefacenet backbone 9 | ''' 10 | 11 | import torch 12 | from torch import nn 13 | import math 14 | 15 | MobileFaceNet_BottleNeck_Setting = [ 16 | # t, c , n ,s 17 | [2, 64, 5, 2], 18 | [4, 128, 1, 2], 19 | [2, 128, 6, 1], 20 | [4, 128, 1, 2], 21 | [2, 128, 2, 1] 22 | ] 23 | 24 | class BottleNeck(nn.Module): 25 | def __init__(self, inp, oup, stride, expansion): 26 | super(BottleNeck, 
self).__init__() 27 | self.connect = stride == 1 and inp == oup 28 | 29 | self.conv = nn.Sequential( 30 | # 1*1 conv 31 | nn.Conv2d(inp, inp * expansion, 1, 1, 0, bias=False), 32 | nn.BatchNorm2d(inp * expansion), 33 | nn.PReLU(inp * expansion), 34 | 35 | # 3*3 depth wise conv 36 | nn.Conv2d(inp * expansion, inp * expansion, 3, stride, 1, groups=inp * expansion, bias=False), 37 | nn.BatchNorm2d(inp * expansion), 38 | nn.PReLU(inp * expansion), 39 | 40 | # 1*1 conv 41 | nn.Conv2d(inp * expansion, oup, 1, 1, 0, bias=False), 42 | nn.BatchNorm2d(oup), 43 | ) 44 | 45 | def forward(self, x): 46 | if self.connect: 47 | return x + self.conv(x) 48 | else: 49 | return self.conv(x) 50 | 51 | 52 | class ConvBlock(nn.Module): 53 | def __init__(self, inp, oup, k, s, p, dw=False, linear=False): 54 | super(ConvBlock, self).__init__() 55 | self.linear = linear 56 | if dw: 57 | self.conv = nn.Conv2d(inp, oup, k, s, p, groups=inp, bias=False) 58 | else: 59 | self.conv = nn.Conv2d(inp, oup, k, s, p, bias=False) 60 | 61 | self.bn = nn.BatchNorm2d(oup) 62 | if not linear: 63 | self.prelu = nn.PReLU(oup) 64 | 65 | def forward(self, x): 66 | x = self.conv(x) 67 | x = self.bn(x) 68 | if self.linear: 69 | return x 70 | else: 71 | return self.prelu(x) 72 | 73 | 74 | class MobileFaceNet(nn.Module): 75 | def __init__(self, feature_dim=128, bottleneck_setting=MobileFaceNet_BottleNeck_Setting): 76 | super(MobileFaceNet, self).__init__() 77 | self.conv1 = ConvBlock(3, 64, 3, 2, 1) 78 | self.dw_conv1 = ConvBlock(64, 64, 3, 1, 1, dw=True) 79 | 80 | self.cur_channel = 64 81 | block = BottleNeck 82 | self.blocks = self._make_layer(block, bottleneck_setting) 83 | 84 | self.conv2 = ConvBlock(128, 512, 1, 1, 0) 85 | self.linear7 = ConvBlock(512, 512, 7, 1, 0, dw=True, linear=True) 86 | self.linear1 = ConvBlock(512, feature_dim, 1, 1, 0, linear=True) 87 | 88 | for m in self.modules(): 89 | if isinstance(m, nn.Conv2d): 90 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 91 | m.weight.data.normal_(0, math.sqrt(2. / n)) 92 | elif isinstance(m, nn.BatchNorm2d): 93 | m.weight.data.fill_(1) 94 | m.bias.data.zero_() 95 | 96 | def _make_layer(self, block, setting): 97 | layers = [] 98 | for t, c, n, s in setting: 99 | for i in range(n): 100 | if i == 0: 101 | layers.append(block(self.cur_channel, c, s, t)) 102 | else: 103 | layers.append(block(self.cur_channel, c, 1, t)) 104 | self.cur_channel = c 105 | 106 | return nn.Sequential(*layers) 107 | 108 | def forward(self, x): 109 | x = self.conv1(x) 110 | x = self.dw_conv1(x) 111 | x = self.blocks(x) 112 | x = self.conv2(x) 113 | x = self.linear7(x) 114 | x = self.linear1(x) 115 | x = x.view(x.size(0), -1) 116 | 117 | return x 118 | 119 | 120 | if __name__ == "__main__": 121 | input = torch.Tensor(2, 3, 112, 112) 122 | net = MobileFaceNet() 123 | print(net) 124 | 125 | x = net(input) 126 | print(x.shape) -------------------------------------------------------------------------------- /backbone/resnet.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # encoding: utf-8 3 | ''' 4 | @author: wujiyang 5 | @contact: wujiyang@hust.edu.cn 6 | @file: resnet.py 7 | @time: 2018/12/24 14:40 8 | @desc: Original ResNet backbone, including ResNet18, ResNet34, ResNet50, ResNet101 and ResNet152, we removed the last global average pooling layer 9 | and replaced it with a fully connected layer with dimension of 512. BN is used for fast convergence. 
10 | ''' 11 | import torch 12 | import torch.nn as nn 13 | 14 | def ResNet18(): 15 | model = ResNet(BasicBlock, [2, 2, 2, 2]) 16 | return model 17 | 18 | def ResNet34(): 19 | model = ResNet(BasicBlock, [3, 4, 6, 3]) 20 | return model 21 | 22 | def ResNet50(): 23 | model = ResNet(Bottleneck, [3, 4, 6, 3]) 24 | return model 25 | 26 | def ResNet101(): 27 | model = ResNet(Bottleneck, [3, 4, 23, 3]) 28 | return model 29 | 30 | def ResNet152(): 31 | model = ResNet(Bottleneck, [3, 8, 36, 3]) 32 | return model 33 | 34 | __all__ = ['ResNet', 'ResNet18', 'ResNet34', 'ResNet50', 'ResNet101', 'ResNet152'] 35 | 36 | 37 | def conv3x3(in_planes, out_planes, stride=1): 38 | """3x3 convolution with padding""" 39 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False) 40 | 41 | 42 | def conv1x1(in_planes, out_planes, stride=1): 43 | """1x1 convolution""" 44 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 45 | 46 | 47 | class BasicBlock(nn.Module): 48 | expansion = 1 49 | 50 | def __init__(self, inplanes, planes, stride=1, downsample=None): 51 | super(BasicBlock, self).__init__() 52 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 53 | self.bn1 = nn.BatchNorm2d(planes) 54 | self.relu = nn.ReLU(inplace=True) 55 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) 56 | self.bn2 = nn.BatchNorm2d(planes) 57 | self.downsample = downsample 58 | self.stride = stride 59 | 60 | def forward(self, x): 61 | identity = x 62 | 63 | out = self.conv1(x) 64 | out = self.bn1(out) 65 | out = self.relu(out) 66 | out = self.conv2(out) 67 | out = self.bn2(out) 68 | 69 | if self.downsample is not None: 70 | identity = self.downsample(x) 71 | 72 | out += identity 73 | out = self.relu(out) 74 | 75 | return out 76 | 77 | 78 | class Bottleneck(nn.Module): 79 | expansion = 4 80 | def __init__(self, inplanes, planes, stride=1, downsample=None): 81 | super(Bottleneck, self).__init__() 82 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, stride=1, bias=False) 83 | self.bn1 = nn.BatchNorm2d(planes) 84 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 85 | self.bn2 = nn.BatchNorm2d(planes) 86 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, stride=1, bias=False) 87 | self.bn3 = nn.BatchNorm2d(planes * self.expansion) 88 | self.relu = nn.ReLU(inplace=True) 89 | self.downsample = downsample 90 | self.stride = stride 91 | 92 | def forward(self, x): 93 | identity = x 94 | 95 | out = self.conv1(x) 96 | out = self.bn1(out) 97 | out = self.relu(out) 98 | 99 | out = self.conv2(out) 100 | out = self.bn2(out) 101 | out = self.relu(out) 102 | 103 | out = self.conv3(out) 104 | out = self.bn3(out) 105 | 106 | if self.downsample is not None: 107 | identity = self.downsample(x) 108 | 109 | out += identity 110 | out = self.relu(out) 111 | 112 | return out 113 | 114 | 115 | class Flatten(nn.Module): 116 | def forward(self, input): 117 | return input.view(input.size(0), -1) 118 | 119 | 120 | class ResNet(nn.Module): 121 | 122 | def __init__(self, block, layers, feature_dim=512, drop_ratio=0.4, zero_init_residual=False): 123 | super(ResNet, self).__init__() 124 | self.inplanes = 64 125 | self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False) 126 | self.bn1 = nn.BatchNorm2d(64) 127 | self.relu = nn.ReLU(inplace=True) 128 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 129 | self.layer1 = 
self._make_layer(block, 64, layers[0]) 130 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 131 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 132 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) 133 | 134 | self.output_layer = nn.Sequential(nn.BatchNorm2d(512 * block.expansion), 135 | nn.Dropout(drop_ratio), 136 | Flatten(), 137 | nn.Linear(512 * block.expansion * 7 * 7, feature_dim), 138 | nn.BatchNorm1d(feature_dim)) 139 | 140 | for m in self.modules(): 141 | if isinstance(m, nn.Conv2d): 142 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 143 | elif isinstance(m, nn.BatchNorm2d): 144 | nn.init.constant_(m.weight, 1) 145 | nn.init.constant_(m.bias, 0) 146 | 147 | # Zero-initialize the last BN in each residual branch, 148 | # so that the residual branch starts with zeros, and each residual block behaves like an identity. 149 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 150 | if zero_init_residual: 151 | for m in self.modules(): 152 | if isinstance(m, Bottleneck): 153 | nn.init.constant_(m.bn3.weight, 0) 154 | elif isinstance(m, BasicBlock): 155 | nn.init.constant_(m.bn2.weight, 0) 156 | 157 | def _make_layer(self, block, planes, blocks, stride=1): 158 | downsample = None 159 | if stride != 1 or self.inplanes != planes * block.expansion: 160 | downsample = nn.Sequential( 161 | conv1x1(self.inplanes, planes * block.expansion, stride), 162 | nn.BatchNorm2d(planes * block.expansion), 163 | ) 164 | 165 | layers = [] 166 | layers.append(block(self.inplanes, planes, stride, downsample)) 167 | self.inplanes = planes * block.expansion 168 | for _ in range(1, blocks): 169 | layers.append(block(self.inplanes, planes)) 170 | 171 | return nn.Sequential(*layers) 172 | 173 | def forward(self, x): 174 | x = self.conv1(x) 175 | x = self.bn1(x) 176 | x = self.relu(x) 177 | x = self.maxpool(x) 178 | 179 | x = self.layer1(x) 180 | x = self.layer2(x) 181 | x = self.layer3(x) 182 | x = self.layer4(x) 183 | 184 | x = self.output_layer(x) 185 | 186 | return x 187 | 188 | 189 | if __name__ == "__main__": 190 | input = torch.Tensor(2, 3, 112, 112) 191 | net = ResNet50() 192 | print(net) 193 | 194 | x = net(input) 195 | print(x.shape) -------------------------------------------------------------------------------- /backbone/spherenet.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # encoding: utf-8 3 | ''' 4 | @author: wujiyang 5 | @contact: wujiyang@hust.edu.cn 6 | @file: spherenet.py 7 | @time: 2018/12/26 10:14 8 | @desc: A 64 layer residual network struture used in sphereface and cosface, for fast convergence, I add BN after every Conv layer. 
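Note: all four stages downsample by stride 2, so a 112x112 input is reduced to a 7x7x512 map that matches the 512 * 7 * 7 fully connected layer below; num_layers=20 stacks [1, 2, 4, 1] residual blocks per stage and num_layers=64 stacks [3, 7, 16, 3].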
9 | ''' 10 | 11 | import torch 12 | import torch.nn as nn 13 | 14 | class Block(nn.Module): 15 | def __init__(self, channels): 16 | super(Block, self).__init__() 17 | self.conv1 = nn.Conv2d(channels, channels, 3, 1, 1, bias=False) 18 | self.bn1 = nn.BatchNorm2d(channels) 19 | self.prelu1 = nn.PReLU(channels) 20 | self.conv2 = nn.Conv2d(channels, channels, 3, 1, 1, bias=False) 21 | self.bn2 = nn.BatchNorm2d(channels) 22 | self.prelu2 = nn.PReLU(channels) 23 | 24 | def forward(self, x): 25 | short_cut = x 26 | x = self.conv1(x) 27 | x = self.bn1(x) 28 | x = self.prelu1(x) 29 | x = self.conv2(x) 30 | x = self.bn2(x) 31 | x = self.prelu2(x) 32 | 33 | return x + short_cut 34 | 35 | 36 | class SphereNet(nn.Module): 37 | def __init__(self, num_layers = 20, feature_dim=512): 38 | super(SphereNet, self).__init__() 39 | assert num_layers in [20, 64], 'SphereNet num_layers should be 20 or 64' 40 | if num_layers == 20: 41 | layers = [1, 2, 4, 1] 42 | elif num_layers == 64: 43 | layers = [3, 7, 16, 3] 44 | else: 45 | raise ValueError('sphere' + str(num_layers) + " IS NOT SUPPORTED! (sphere20 or sphere64)") 46 | 47 | filter_list = [3, 64, 128, 256, 512] 48 | block = Block 49 | self.layer1 = self._make_layer(block, filter_list[0], filter_list[1], layers[0], stride=2) 50 | self.layer2 = self._make_layer(block, filter_list[1], filter_list[2], layers[1], stride=2) 51 | self.layer3 = self._make_layer(block, filter_list[2], filter_list[3], layers[2], stride=2) 52 | self.layer4 = self._make_layer(block, filter_list[3], filter_list[4], layers[3], stride=2) 53 | self.fc = nn.Linear(512 * 7 * 7, feature_dim) 54 | self.last_bn = nn.BatchNorm1d(feature_dim) 55 | 56 | for m in self.modules(): 57 | if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear): 58 | if m.bias is not None: 59 | nn.init.xavier_uniform_(m.weight) 60 | nn.init.constant_(m.bias, 0) 61 | else: 62 | nn.init.normal_(m.weight, 0, 0.01) 63 | 64 | def _make_layer(self, block, inplanes, planes, num_units, stride): 65 | layers = [] 66 | layers.append(nn.Conv2d(inplanes, planes, 3, stride, 1)) 67 | layers.append(nn.BatchNorm2d(planes)) 68 | layers.append(nn.PReLU(planes)) 69 | for i in range(num_units): 70 | layers.append(block(planes)) 71 | 72 | return nn.Sequential(*layers) 73 | 74 | 75 | def forward(self, x): 76 | x = self.layer1(x) 77 | x = self.layer2(x) 78 | x = self.layer3(x) 79 | x = self.layer4(x) 80 | 81 | x = x.view(x.size(0), -1) 82 | x = self.fc(x) 83 | x = self.last_bn(x) 84 | 85 | return x 86 | 87 | 88 | if __name__ == '__main__': 89 | input = torch.Tensor(2, 3, 112, 112) 90 | net = SphereNet(num_layers=64, feature_dim=512) 91 | 92 | out = net(input) 93 | print(out.shape) 94 | 95 | -------------------------------------------------------------------------------- /cppapi/README.md: -------------------------------------------------------------------------------- 1 | ## CppAPI 2 | 3 | libtorch for C++ deployment 4 | 5 | -------------------------------------------------------------------------------- /cppapi/pytorch2torchscript.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # encoding: utf-8 3 | ''' 4 | @author: wujiyang 5 | @contact: wujiyang@hust.edu.cn 6 | @file: pytorch2torchscript.py 7 | @time: 2019/2/18 17:45 8 | @desc: convert your pytorch model to torch script and save to file 9 | ''' 10 | 11 | -------------------------------------------------------------------------------- /dataset/__init__.py: 
-------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # encoding: utf-8 3 | ''' 4 | @author: wujiyang 5 | @contact: wujiyang@hust.edu.cn 6 | @file: __init__.py.py 7 | @time: 2018/12/21 15:31 8 | @desc: 9 | ''' -------------------------------------------------------------------------------- /dataset/agedb.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # encoding: utf-8 3 | ''' 4 | @author: wujiyang 5 | @contact: wujiyang@hust.edu.cn 6 | @file: agedb.py.py 7 | @time: 2018/12/25 18:43 8 | @desc: AgeDB-30 test data loader, agedb test protocol is the same with lfw 9 | ''' 10 | 11 | import numpy as np 12 | import cv2 13 | import os 14 | import torch.utils.data as data 15 | 16 | import torch 17 | import torchvision.transforms as transforms 18 | 19 | def img_loader(path): 20 | try: 21 | with open(path, 'rb') as f: 22 | img = cv2.imread(path) 23 | if len(img.shape) == 2: 24 | img = np.stack([img] * 3, 2) 25 | return img 26 | except IOError: 27 | print('Cannot load image ' + path) 28 | 29 | class AgeDB30(data.Dataset): 30 | def __init__(self, root, file_list, transform=None, loader=img_loader): 31 | 32 | self.root = root 33 | self.file_list = file_list 34 | self.transform = transform 35 | self.loader = loader 36 | self.nameLs = [] 37 | self.nameRs = [] 38 | self.folds = [] 39 | self.flags = [] 40 | 41 | with open(file_list) as f: 42 | pairs = f.read().splitlines() 43 | for i, p in enumerate(pairs): 44 | p = p.split(' ') 45 | nameL = p[0] 46 | nameR = p[1] 47 | fold = i // 600 48 | flag = int(p[2]) 49 | 50 | self.nameLs.append(nameL) 51 | self.nameRs.append(nameR) 52 | self.folds.append(fold) 53 | self.flags.append(flag) 54 | 55 | def __getitem__(self, index): 56 | 57 | img_l = self.loader(os.path.join(self.root, self.nameLs[index])) 58 | img_r = self.loader(os.path.join(self.root, self.nameRs[index])) 59 | imglist = [img_l, cv2.flip(img_l, 1), img_r, cv2.flip(img_r, 1)] 60 | 61 | if self.transform is not None: 62 | for i in range(len(imglist)): 63 | imglist[i] = self.transform(imglist[i]) 64 | 65 | imgs = imglist 66 | return imgs 67 | else: 68 | imgs = [torch.from_numpy(i) for i in imglist] 69 | return imgs 70 | 71 | def __len__(self): 72 | return len(self.nameLs) 73 | 74 | 75 | if __name__ == '__main__': 76 | root = '/media/sda/AgeDB-30/agedb30_align_112' 77 | file_list = '/media/sda/AgeDB-30/agedb_30_pair.txt' 78 | 79 | transform = transforms.Compose([ 80 | transforms.ToTensor(), # range [0, 255] -> [0.0,1.0] 81 | transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)) # range [0.0, 1.0] -> [-1.0,1.0] 82 | ]) 83 | 84 | dataset = AgeDB30(root, file_list, transform=transform) 85 | trainloader = data.DataLoader(dataset, batch_size=64, shuffle=False, num_workers=2, drop_last=False) 86 | for data in trainloader: 87 | for d in data: 88 | print(d[0].shape) -------------------------------------------------------------------------------- /dataset/casia_webface.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # encoding: utf-8 3 | ''' 4 | @author: wujiyang 5 | @contact: wujiyang@hust.edu.cn 6 | @file: casia_webface.py 7 | @time: 2018/12/21 19:09 8 | @desc: CASIA-WebFace dataset loader 9 | ''' 10 | 11 | import torchvision.transforms as transforms 12 | import torch.utils.data as data 13 | import numpy as np 14 | import cv2 15 | import os 16 | import torch 17 | 18 | 19 | def img_loader(path): 20 | try: 21 | with 
open(path, 'rb') as f: 22 | img = cv2.imread(path) 23 | if len(img.shape) == 2: 24 | img = np.stack([img] * 3, 2) 25 | return img 26 | except IOError: 27 | print('Cannot load image ' + path) 28 | 29 | 30 | class CASIAWebFace(data.Dataset): 31 | def __init__(self, root, file_list, transform=None, loader=img_loader): 32 | 33 | self.root = root 34 | self.transform = transform 35 | self.loader = loader 36 | 37 | image_list = [] 38 | label_list = [] 39 | with open(file_list) as f: 40 | img_label_list = f.read().splitlines() 41 | for info in img_label_list: 42 | image_path, label_name = info.split(' ') 43 | image_list.append(image_path) 44 | label_list.append(int(label_name)) 45 | 46 | self.image_list = image_list 47 | self.label_list = label_list 48 | self.class_nums = len(np.unique(self.label_list)) 49 | print("dataset size: ", len(self.image_list), '/', self.class_nums) 50 | 51 | def __getitem__(self, index): 52 | img_path = self.image_list[index] 53 | label = self.label_list[index] 54 | 55 | img = self.loader(os.path.join(self.root, img_path)) 56 | 57 | # random flip with ratio of 0.5 58 | flip = np.random.choice(2) * 2 - 1 59 | if flip == 1: 60 | img = cv2.flip(img, 1) 61 | 62 | if self.transform is not None: 63 | img = self.transform(img) 64 | else: 65 | img = torch.from_numpy(img) 66 | 67 | return img, label 68 | 69 | def __len__(self): 70 | return len(self.image_list) 71 | 72 | 73 | if __name__ == '__main__': 74 | root = 'D:/data/webface_align_112' 75 | file_list = 'D:/data/webface_align_train.list' 76 | 77 | transform = transforms.Compose([ 78 | transforms.ToTensor(), # range [0, 255] -> [0.0,1.0] 79 | transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)) # range [0.0, 1.0] -> [-1.0,1.0] 80 | ]) 81 | dataset = CASIAWebFace(root, file_list, transform=transform) 82 | trainloader = data.DataLoader(dataset, batch_size=64, shuffle=True, num_workers=2, drop_last=False) 83 | print(len(dataset)) 84 | for data in trainloader: 85 | print(data[0].shape) -------------------------------------------------------------------------------- /dataset/cfp.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # encoding: utf-8 3 | ''' 4 | @author: wujiyang 5 | @contact: wujiyang@hust.edu.cn 6 | @file: cfp.py 7 | @time: 2018/12/26 16:19 8 | @desc: the CFP-FP test dataset loader, it's similar with lfw and adedb, except that it has 700 pairs every fold 9 | ''' 10 | 11 | 12 | import numpy as np 13 | import cv2 14 | import os 15 | import torch.utils.data as data 16 | 17 | import torch 18 | import torchvision.transforms as transforms 19 | 20 | def img_loader(path): 21 | try: 22 | with open(path, 'rb') as f: 23 | img = cv2.imread(path) 24 | if len(img.shape) == 2: 25 | img = np.stack([img] * 3, 2) 26 | return img 27 | except IOError: 28 | print('Cannot load image ' + path) 29 | 30 | class CFP_FP(data.Dataset): 31 | def __init__(self, root, file_list, transform=None, loader=img_loader): 32 | 33 | self.root = root 34 | self.file_list = file_list 35 | self.transform = transform 36 | self.loader = loader 37 | self.nameLs = [] 38 | self.nameRs = [] 39 | self.folds = [] 40 | self.flags = [] 41 | 42 | with open(file_list) as f: 43 | pairs = f.read().splitlines() 44 | for i, p in enumerate(pairs): 45 | p = p.split(' ') 46 | nameL = p[0] 47 | nameR = p[1] 48 | fold = i // 700 49 | flag = int(p[2]) 50 | 51 | self.nameLs.append(nameL) 52 | self.nameRs.append(nameR) 53 | self.folds.append(fold) 54 | self.flags.append(flag) 55 | 56 | def __getitem__(self, 
index): 57 | 58 | img_l = self.loader(os.path.join(self.root, self.nameLs[index])) 59 | img_r = self.loader(os.path.join(self.root, self.nameRs[index])) 60 | imglist = [img_l, cv2.flip(img_l, 1), img_r, cv2.flip(img_r, 1)] 61 | 62 | if self.transform is not None: 63 | for i in range(len(imglist)): 64 | imglist[i] = self.transform(imglist[i]) 65 | 66 | imgs = imglist 67 | return imgs 68 | else: 69 | imgs = [torch.from_numpy(i) for i in imglist] 70 | return imgs 71 | 72 | def __len__(self): 73 | return len(self.nameLs) 74 | 75 | 76 | if __name__ == '__main__': 77 | root = '/media/sda/CFP-FP/CFP_FP_aligned_112' 78 | file_list = '/media/sda/CFP-FP/cfp-fp-pair.txt' 79 | 80 | transform = transforms.Compose([ 81 | transforms.ToTensor(), # range [0, 255] -> [0.0,1.0] 82 | transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)) # range [0.0, 1.0] -> [-1.0,1.0] 83 | ]) 84 | 85 | dataset = CFP_FP(root, file_list, transform=transform) 86 | trainloader = data.DataLoader(dataset, batch_size=64, shuffle=False, num_workers=2, drop_last=False) 87 | for data in trainloader: 88 | for d in data: 89 | print(d[0].shape) -------------------------------------------------------------------------------- /dataset/lfw.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # encoding: utf-8 3 | ''' 4 | @author: wujiyang 5 | @contact: wujiyang@hust.edu.cn 6 | @file: lfw.py.py 7 | @time: 2018/12/22 10:00 8 | @desc: lfw dataset loader 9 | ''' 10 | 11 | import numpy as np 12 | import cv2 13 | import os 14 | import torch.utils.data as data 15 | 16 | import torch 17 | import torchvision.transforms as transforms 18 | 19 | def img_loader(path): 20 | try: 21 | with open(path, 'rb') as f: 22 | img = cv2.imread(path) 23 | if len(img.shape) == 2: 24 | img = np.stack([img] * 3, 2) 25 | return img 26 | except IOError: 27 | print('Cannot load image ' + path) 28 | 29 | class LFW(data.Dataset): 30 | def __init__(self, root, file_list, transform=None, loader=img_loader): 31 | 32 | self.root = root 33 | self.file_list = file_list 34 | self.transform = transform 35 | self.loader = loader 36 | self.nameLs = [] 37 | self.nameRs = [] 38 | self.folds = [] 39 | self.flags = [] 40 | 41 | with open(file_list) as f: 42 | pairs = f.read().splitlines()[1:] 43 | for i, p in enumerate(pairs): 44 | p = p.split('\t') 45 | if len(p) == 3: 46 | nameL = p[0] + '/' + p[0] + '_' + '{:04}.jpg'.format(int(p[1])) 47 | nameR = p[0] + '/' + p[0] + '_' + '{:04}.jpg'.format(int(p[2])) 48 | fold = i // 600 49 | flag = 1 50 | elif len(p) == 4: 51 | nameL = p[0] + '/' + p[0] + '_' + '{:04}.jpg'.format(int(p[1])) 52 | nameR = p[2] + '/' + p[2] + '_' + '{:04}.jpg'.format(int(p[3])) 53 | fold = i // 600 54 | flag = -1 55 | self.nameLs.append(nameL) 56 | self.nameRs.append(nameR) 57 | self.folds.append(fold) 58 | self.flags.append(flag) 59 | 60 | def __getitem__(self, index): 61 | 62 | img_l = self.loader(os.path.join(self.root, self.nameLs[index])) 63 | img_r = self.loader(os.path.join(self.root, self.nameRs[index])) 64 | imglist = [img_l, cv2.flip(img_l, 1), img_r, cv2.flip(img_r, 1)] 65 | 66 | if self.transform is not None: 67 | for i in range(len(imglist)): 68 | imglist[i] = self.transform(imglist[i]) 69 | 70 | imgs = imglist 71 | return imgs 72 | else: 73 | imgs = [torch.from_numpy(i) for i in imglist] 74 | return imgs 75 | 76 | def __len__(self): 77 | return len(self.nameLs) 78 | 79 | 80 | if __name__ == '__main__': 81 | root = 'D:/data/lfw_align_112' 82 | file_list = 'D:/data/pairs.txt' 83 | 84 | 
transform = transforms.Compose([ 85 | transforms.ToTensor(), # range [0, 255] -> [0.0,1.0] 86 | transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)) # range [0.0, 1.0] -> [-1.0,1.0] 87 | ]) 88 | 89 | dataset = LFW(root, file_list, transform=transform) 90 | #dataset = LFW(root, file_list) 91 | trainloader = data.DataLoader(dataset, batch_size=64, shuffle=False, num_workers=2, drop_last=False) 92 | print(len(dataset)) 93 | for data in trainloader: 94 | for d in data: 95 | print(d[0].shape) -------------------------------------------------------------------------------- /dataset/lfw_2.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # encoding: utf-8 3 | ''' 4 | @author: wujiyang 5 | @contact: wujiyang@hust.edu.cn 6 | @file: lfw_2.py 7 | @time: 2019/2/19 16:59 8 | @desc: lfw dataset from insightface ,just like agedb and cfp-fp 9 | ''' 10 | 11 | 12 | import numpy as np 13 | import cv2 14 | import os 15 | import torch.utils.data as data 16 | 17 | import torch 18 | import torchvision.transforms as transforms 19 | 20 | def img_loader(path): 21 | try: 22 | with open(path, 'rb') as f: 23 | img = cv2.imread(path) 24 | if len(img.shape) == 2: 25 | img = np.stack([img] * 3, 2) 26 | return img 27 | except IOError: 28 | print('Cannot load image ' + path) 29 | 30 | class LFW_2(data.Dataset): 31 | def __init__(self, root, file_list, transform=None, loader=img_loader): 32 | 33 | self.root = root 34 | self.file_list = file_list 35 | self.transform = transform 36 | self.loader = loader 37 | self.nameLs = [] 38 | self.nameRs = [] 39 | self.folds = [] 40 | self.flags = [] 41 | 42 | with open(file_list) as f: 43 | pairs = f.read().splitlines() 44 | for i, p in enumerate(pairs): 45 | p = p.split(' ') 46 | nameL = p[0] 47 | nameR = p[1] 48 | fold = i // 600 49 | flag = int(p[2]) 50 | 51 | self.nameLs.append(nameL) 52 | self.nameRs.append(nameR) 53 | self.folds.append(fold) 54 | self.flags.append(flag) 55 | 56 | def __getitem__(self, index): 57 | 58 | img_l = self.loader(os.path.join(self.root, self.nameLs[index])) 59 | img_r = self.loader(os.path.join(self.root, self.nameRs[index])) 60 | imglist = [img_l, cv2.flip(img_l, 1), img_r, cv2.flip(img_r, 1)] 61 | 62 | if self.transform is not None: 63 | for i in range(len(imglist)): 64 | imglist[i] = self.transform(imglist[i]) 65 | 66 | imgs = imglist 67 | return imgs 68 | else: 69 | imgs = [torch.from_numpy(i) for i in imglist] 70 | return imgs 71 | 72 | def __len__(self): 73 | return len(self.nameLs) 74 | 75 | 76 | if __name__ == '__main__': 77 | root = '/media/sda/insightface_emore/lfw' 78 | file_list = '/media/sda/insightface_emore/pair_lfw.txt' 79 | 80 | transform = transforms.Compose([ 81 | transforms.ToTensor(), # range [0, 255] -> [0.0,1.0] 82 | transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)) # range [0.0, 1.0] -> [-1.0,1.0] 83 | ]) 84 | 85 | dataset = LFW_2(root, file_list, transform=transform) 86 | trainloader = data.DataLoader(dataset, batch_size=64, shuffle=False, num_workers=2, drop_last=False) 87 | for data in trainloader: 88 | for d in data: 89 | print(d[0].shape) -------------------------------------------------------------------------------- /dataset/megaface.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # encoding: utf-8 3 | ''' 4 | @author: wujiyang 5 | @contact: wujiyang@hust.edu.cn 6 | @file: megaface.py 7 | @time: 2018/12/24 16:29 8 | @desc: 9 | ''' 10 | 11 | import torchvision.transforms 
as transforms 12 | import torch.utils.data as data 13 | import numpy as np 14 | import cv2 15 | import os 16 | import torch 17 | 18 | def img_loader(path): 19 | try: 20 | with open(path, 'rb') as f: 21 | img = cv2.imread(path) 22 | if len(img.shape) == 2: 23 | img = np.stack([img] * 3, 2) 24 | return img 25 | except IOError: 26 | print('Cannot load image ' + path) 27 | 28 | 29 | class MegaFace(data.Dataset): 30 | def __init__(self, facescrub_dir, megaface_dir, transform=None, loader=img_loader): 31 | 32 | self.transform = transform 33 | self.loader = loader 34 | 35 | test_image_file_list = [] 36 | print('Scanning files under facescrub and megaface...') 37 | for root, dirs, files in os.walk(facescrub_dir): 38 | for e in files: 39 | filename = os.path.join(root, e) 40 | ext = os.path.splitext(filename)[1].lower() 41 | if ext in ('.png', '.bmp', '.jpg', '.jpeg'): 42 | test_image_file_list.append(filename) 43 | for root, dirs, files in os.walk(megaface_dir): 44 | for e in files: 45 | filename = os.path.join(root, e) 46 | ext = os.path.splitext(filename)[1].lower() 47 | if ext in ('.png', '.bmp', '.jpg', '.jpeg'): 48 | test_image_file_list.append(filename) 49 | 50 | self.image_list = test_image_file_list 51 | 52 | def __getitem__(self, index): 53 | img_path = self.image_list[index] 54 | img = self.loader(img_path) 55 | 56 | #水平翻转图像 57 | #img = cv2.flip(img, 1) 58 | 59 | if self.transform is not None: 60 | img = self.transform(img) 61 | else: 62 | img = torch.from_numpy(img) 63 | 64 | return img, img_path 65 | 66 | def __len__(self): 67 | return len(self.image_list) 68 | 69 | 70 | if __name__ == '__main__': 71 | facescrub = '/media/sda/megaface_test_kit/facescrub_align_112/' 72 | megaface = '/media/sda/megaface_test_kit/megaface_align_112/' 73 | 74 | transform = transforms.Compose([ 75 | transforms.ToTensor(), # range [0, 255] -> [0.0,1.0] 76 | transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)) # range [0.0, 1.0] -> [-1.0,1.0] 77 | ]) 78 | dataset = MegaFace(facescrub, megaface, transform=transform) 79 | trainloader = data.DataLoader(dataset, batch_size=64, shuffle=False, num_workers=2, drop_last=False) 80 | print(len(dataset)) 81 | for data in trainloader: 82 | print(data.shape) -------------------------------------------------------------------------------- /eval_agedb30.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # encoding: utf-8 3 | ''' 4 | @author: wujiyang 5 | @contact: wujiyang@hust.edu.cn 6 | @file: eval_agedb30.py 7 | @time: 2018/12/25 19:05 8 | @desc: The AgeDB-30 test protocol is same with LFW, so I just copy the code from eval_lfw.py 9 | ''' 10 | 11 | 12 | import numpy as np 13 | import scipy.io 14 | import os 15 | import torch.utils.data 16 | from backbone import mobilefacenet, resnet, arcfacenet, cbam 17 | from dataset.agedb import AgeDB30 18 | import torchvision.transforms as transforms 19 | from torch.nn import DataParallel 20 | import argparse 21 | 22 | def getAccuracy(scores, flags, threshold): 23 | p = np.sum(scores[flags == 1] > threshold) 24 | n = np.sum(scores[flags == -1] < threshold) 25 | return 1.0 * (p + n) / len(scores) 26 | 27 | def getThreshold(scores, flags, thrNum): 28 | accuracys = np.zeros((2 * thrNum + 1, 1)) 29 | thresholds = np.arange(-thrNum, thrNum + 1) * 1.0 / thrNum 30 | for i in range(2 * thrNum + 1): 31 | accuracys[i] = getAccuracy(scores, flags, thresholds[i]) 32 | max_index = np.squeeze(accuracys == np.max(accuracys)) 33 | bestThreshold = np.mean(thresholds[max_index]) 
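    # Note: with the thrNum=10000 used below, the sweep covers cosine thresholds in [-1, 1] with a step of 1e-4;
    # ties at the best accuracy are resolved by averaging the tied thresholds.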
34 | return bestThreshold 35 | 36 | def evaluation_10_fold(feature_path='./result/cur_epoch_agedb_result.mat'): 37 | ACCs = np.zeros(10) 38 | result = scipy.io.loadmat(feature_path) 39 | for i in range(10): 40 | fold = result['fold'] 41 | flags = result['flag'] 42 | featureLs = result['fl'] 43 | featureRs = result['fr'] 44 | 45 | valFold = fold != i 46 | testFold = fold == i 47 | flags = np.squeeze(flags) 48 | 49 | mu = np.mean(np.concatenate((featureLs[valFold[0], :], featureRs[valFold[0], :]), 0), 0) 50 | mu = np.expand_dims(mu, 0) 51 | featureLs = featureLs - mu 52 | featureRs = featureRs - mu 53 | featureLs = featureLs / np.expand_dims(np.sqrt(np.sum(np.power(featureLs, 2), 1)), 1) 54 | featureRs = featureRs / np.expand_dims(np.sqrt(np.sum(np.power(featureRs, 2), 1)), 1) 55 | 56 | scores = np.sum(np.multiply(featureLs, featureRs), 1) 57 | threshold = getThreshold(scores[valFold[0]], flags[valFold[0]], 10000) 58 | ACCs[i] = getAccuracy(scores[testFold[0]], flags[testFold[0]], threshold) 59 | 60 | return ACCs 61 | 62 | def loadModel(data_root, file_list, backbone_net, gpus='0', resume=None): 63 | 64 | if backbone_net == 'MobileFace': 65 | net = mobilefacenet.MobileFaceNet() 66 | elif backbone_net == 'CBAM_50': 67 | net = cbam.CBAMResNet(50, feature_dim=args.feature_dim, mode='ir') 68 | elif backbone_net == 'CBAM_50_SE': 69 | net = cbam.CBAMResNet(50, feature_dim=args.feature_dim, mode='ir_se') 70 | elif backbone_net == 'CBAM_100': 71 | net = cbam.CBAMResNet(100, feature_dim=args.feature_dim, mode='ir') 72 | elif backbone_net == 'CBAM_100_SE': 73 | net = cbam.CBAMResNet(100, feature_dim=args.feature_dim, mode='ir_se') 74 | else: 75 | print(backbone_net, ' is not available!') 76 | 77 | # gpu init 78 | multi_gpus = False 79 | if len(gpus.split(',')) > 1: 80 | multi_gpus = True 81 | os.environ['CUDA_VISIBLE_DEVICES'] = gpus 82 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 83 | 84 | net.load_state_dict(torch.load(resume)['net_state_dict']) 85 | 86 | if multi_gpus: 87 | net = DataParallel(net).to(device) 88 | else: 89 | net = net.to(device) 90 | 91 | transform = transforms.Compose([ 92 | transforms.ToTensor(), # range [0, 255] -> [0.0,1.0] 93 | transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)) # range [0.0, 1.0] -> [-1.0,1.0] 94 | ]) 95 | agedb_dataset = AgeDB30(data_root, file_list, transform=transform) 96 | agedb_loader = torch.utils.data.DataLoader(agedb_dataset, batch_size=128, 97 | shuffle=False, num_workers=2, drop_last=False) 98 | 99 | return net.eval(), device, agedb_dataset, agedb_loader 100 | 101 | def getFeatureFromTorch(feature_save_dir, net, device, data_set, data_loader): 102 | featureLs = None 103 | featureRs = None 104 | count = 0 105 | for data in data_loader: 106 | for i in range(len(data)): 107 | data[i] = data[i].to(device) 108 | count += data[0].size(0) 109 | #print('extracing deep features from the face pair {}...'.format(count)) 110 | with torch.no_grad(): 111 | res = [net(d).data.cpu().numpy() for d in data] 112 | featureL = np.concatenate((res[0], res[1]), 1) 113 | featureR = np.concatenate((res[2], res[3]), 1) 114 | # print(featureL.shape, featureR.shape) 115 | if featureLs is None: 116 | featureLs = featureL 117 | else: 118 | featureLs = np.concatenate((featureLs, featureL), 0) 119 | if featureRs is None: 120 | featureRs = featureR 121 | else: 122 | featureRs = np.concatenate((featureRs, featureR), 0) 123 | # print(featureLs.shape, featureRs.shape) 124 | 125 | result = {'fl': featureLs, 'fr': featureRs, 'fold': data_set.folds, 
'flag': data_set.flags} 126 | scipy.io.savemat(feature_save_dir, result) 127 | 128 | 129 | if __name__ == '__main__': 130 | parser = argparse.ArgumentParser(description='Testing') 131 | parser.add_argument('--root', type=str, default='/media/sda/AgeDB-30/agedb30_align_112', help='The path of lfw data') 132 | parser.add_argument('--file_list', type=str, default='/media/sda/AgeDB-30/agedb_30_pair.txt', help='The path of lfw data') 133 | parser.add_argument('--resume', type=str, default='./model/SERES100_SERES100_IR_20190528_132635/Iter_342000_net.ckpt', help='The path pf save model') 134 | parser.add_argument('--backbone_net', type=str, default='CBAM_100_SE', help='MobileFace, CBAM_50, CBAM_50_SE, CBAM_100, CBAM_100_SE') 135 | parser.add_argument('--feature_dim', type=int, default=512, help='feature dimension') 136 | parser.add_argument('--feature_save_path', type=str, default='./result/cur_epoch_agedb_result.mat', 137 | help='The path of the extract features save, must be .mat file') 138 | parser.add_argument('--gpus', type=str, default='2,3', help='gpu list') 139 | args = parser.parse_args() 140 | 141 | net, device, agedb_dataset, agedb_loader = loadModel(args.root, args.file_list, args.backbone_net, args.gpus, args.resume) 142 | getFeatureFromTorch(args.feature_save_path, net, device, agedb_dataset, agedb_loader) 143 | ACCs = evaluation_10_fold(args.feature_save_path) 144 | for i in range(len(ACCs)): 145 | print('{} {:.2f}'.format(i + 1, ACCs[i] * 100)) 146 | print('--------') 147 | print('AVE {:.4f}'.format(np.mean(ACCs) * 100)) 148 | 149 | -------------------------------------------------------------------------------- /eval_cfp.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # encoding: utf-8 3 | ''' 4 | @author: wujiyang 5 | @contact: wujiyang@hust.edu.cn 6 | @file: eval_cfp.py 7 | @time: 2018/12/26 16:23 8 | @desc: this code is very similar with eval_lfw.py and eval_agedb30.py 9 | ''' 10 | 11 | 12 | import numpy as np 13 | import scipy.io 14 | import os 15 | import torch.utils.data 16 | from backbone import mobilefacenet, resnet, arcfacenet, cbam 17 | from dataset.cfp import CFP_FP 18 | import torchvision.transforms as transforms 19 | from torch.nn import DataParallel 20 | import argparse 21 | 22 | def getAccuracy(scores, flags, threshold): 23 | p = np.sum(scores[flags == 1] > threshold) 24 | n = np.sum(scores[flags == -1] < threshold) 25 | return 1.0 * (p + n) / len(scores) 26 | 27 | def getThreshold(scores, flags, thrNum): 28 | accuracys = np.zeros((2 * thrNum + 1, 1)) 29 | thresholds = np.arange(-thrNum, thrNum + 1) * 1.0 / thrNum 30 | for i in range(2 * thrNum + 1): 31 | accuracys[i] = getAccuracy(scores, flags, thresholds[i]) 32 | max_index = np.squeeze(accuracys == np.max(accuracys)) 33 | bestThreshold = np.mean(thresholds[max_index]) 34 | return bestThreshold 35 | 36 | def evaluation_10_fold(feature_path='./result/cur_epoch_cfp_result.mat'): 37 | ACCs = np.zeros(10) 38 | result = scipy.io.loadmat(feature_path) 39 | for i in range(10): 40 | fold = result['fold'] 41 | flags = result['flag'] 42 | featureLs = result['fl'] 43 | featureRs = result['fr'] 44 | 45 | valFold = fold != i 46 | testFold = fold == i 47 | flags = np.squeeze(flags) 48 | 49 | mu = np.mean(np.concatenate((featureLs[valFold[0], :], featureRs[valFold[0], :]), 0), 0) 50 | mu = np.expand_dims(mu, 0) 51 | featureLs = featureLs - mu 52 | featureRs = featureRs - mu 53 | featureLs = featureLs / np.expand_dims(np.sqrt(np.sum(np.power(featureLs, 2), 1)), 
1) 54 | featureRs = featureRs / np.expand_dims(np.sqrt(np.sum(np.power(featureRs, 2), 1)), 1) 55 | 56 | scores = np.sum(np.multiply(featureLs, featureRs), 1) 57 | threshold = getThreshold(scores[valFold[0]], flags[valFold[0]], 10000) 58 | ACCs[i] = getAccuracy(scores[testFold[0]], flags[testFold[0]], threshold) 59 | 60 | return ACCs 61 | 62 | def loadModel(data_root, file_list, backbone_net, gpus='0', resume=None): 63 | 64 | if backbone_net == 'MobileFace': 65 | net = mobilefacenet.MobileFaceNet() 66 | elif backbone_net == 'CBAM_50': 67 | net = cbam.CBAMResNet(50, feature_dim=args.feature_dim, mode='ir') 68 | elif backbone_net == 'CBAM_50_SE': 69 | net = cbam.CBAMResNet(50, feature_dim=args.feature_dim, mode='ir_se') 70 | elif backbone_net == 'CBAM_100': 71 | net = cbam.CBAMResNet(100, feature_dim=args.feature_dim, mode='ir') 72 | elif backbone_net == 'CBAM_100_SE': 73 | net = cbam.CBAMResNet(100, feature_dim=args.feature_dim, mode='ir_se') 74 | else: 75 | print(backbone_net, ' is not available!') 76 | 77 | # gpu init 78 | multi_gpus = False 79 | if len(gpus.split(',')) > 1: 80 | multi_gpus = True 81 | os.environ['CUDA_VISIBLE_DEVICES'] = gpus 82 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 83 | 84 | net.load_state_dict(torch.load(resume)['net_state_dict']) 85 | 86 | if multi_gpus: 87 | net = DataParallel(net).to(device) 88 | else: 89 | net = net.to(device) 90 | 91 | transform = transforms.Compose([ 92 | transforms.ToTensor(), # range [0, 255] -> [0.0,1.0] 93 | transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)) # range [0.0, 1.0] -> [-1.0,1.0] 94 | ]) 95 | cfp_dataset = CFP_FP(data_root, file_list, transform=transform) 96 | cfp_loader = torch.utils.data.DataLoader(cfp_dataset, batch_size=128, 97 | shuffle=False, num_workers=4, drop_last=False) 98 | 99 | return net.eval(), device, cfp_dataset, cfp_loader 100 | 101 | def getFeatureFromTorch(feature_save_dir, net, device, data_set, data_loader): 102 | featureLs = None 103 | featureRs = None 104 | count = 0 105 | for data in data_loader: 106 | for i in range(len(data)): 107 | data[i] = data[i].to(device) 108 | count += data[0].size(0) 109 | #print('extracing deep features from the face pair {}...'.format(count)) 110 | with torch.no_grad(): 111 | res = [net(d).data.cpu().numpy() for d in data] 112 | featureL = np.concatenate((res[0], res[1]), 1) 113 | featureR = np.concatenate((res[2], res[3]), 1) 114 | # print(featureL.shape, featureR.shape) 115 | if featureLs is None: 116 | featureLs = featureL 117 | else: 118 | featureLs = np.concatenate((featureLs, featureL), 0) 119 | if featureRs is None: 120 | featureRs = featureR 121 | else: 122 | featureRs = np.concatenate((featureRs, featureR), 0) 123 | # print(featureLs.shape, featureRs.shape) 124 | 125 | result = {'fl': featureLs, 'fr': featureRs, 'fold': data_set.folds, 'flag': data_set.flags} 126 | scipy.io.savemat(feature_save_dir, result) 127 | 128 | 129 | if __name__ == '__main__': 130 | parser = argparse.ArgumentParser(description='Testing') 131 | parser.add_argument('--root', type=str, default='/media/sda/CFP-FP/cfp_fp_aligned_112', help='The path of lfw data') 132 | parser.add_argument('--file_list', type=str, default='/media/sda/CFP-FP/cfp_fp_pair.txt', help='The path of lfw data') 133 | parser.add_argument('--resume', type=str, default='./model/SERES100_SERES100_IR_20190528_132635/Iter_342000_net.ckpt', help='The path pf save model') 134 | parser.add_argument('--backbone_net', type=str, default='CBAM_100_SE', help='MobileFace, CBAM_50, CBAM_50_SE, 
CBAM_100, CBAM_100_SE') 135 | parser.add_argument('--feature_dim', type=int, default=512, help='feature dimension') 136 | parser.add_argument('--feature_save_path', type=str, default='./result/cur_epoch_cfp_result.mat', 137 | help='The path of the extract features save, must be .mat file') 138 | parser.add_argument('--gpus', type=str, default='2,3', help='gpu list') 139 | args = parser.parse_args() 140 | 141 | net, device, agedb_dataset, agedb_loader = loadModel(args.root, args.file_list, args.backbone_net, args.gpus, args.resume) 142 | getFeatureFromTorch(args.feature_save_path, net, device, agedb_dataset, agedb_loader) 143 | ACCs = evaluation_10_fold(args.feature_save_path) 144 | for i in range(len(ACCs)): 145 | print('{} {:.2f}'.format(i + 1, ACCs[i] * 100)) 146 | print('--------') 147 | print('AVE {:.4f}'.format(np.mean(ACCs) * 100)) -------------------------------------------------------------------------------- /eval_deepglint_merge.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # encoding: utf-8 3 | ''' 4 | @author: wujiyang 5 | @contact: wujiyang@hust.edu.cn 6 | @file: eval_deepglint_merge.py.py 7 | @time: 2019/3/21 11:09 8 | @desc: merge the feature of deepglint test data to one file. original deepglint feature is generated by the protocol of megaface. 9 | ''' 10 | 11 | """ 12 | We use the same format as Megaface(http://megaface.cs.washington.edu) 13 | except that we merge all files into a single binary file. 14 | 15 | for examples: 16 | 17 | when megaface: N * (512, 1) 18 | while deepglint:(N, 512) 19 | 20 | """ 21 | import struct 22 | import numpy as np 23 | import sys, os 24 | import argparse 25 | 26 | cv_type_to_dtype = { 27 | 5: np.dtype('float32') 28 | } 29 | 30 | dtype_to_cv_type = {v: k for k, v in cv_type_to_dtype.items()} 31 | 32 | 33 | def write_mat(f, m): 34 | """Write mat m to file f""" 35 | if len(m.shape) == 1: 36 | rows = m.shape[0] 37 | cols = 1 38 | else: 39 | rows, cols = m.shape 40 | header = struct.pack('iiii', rows, cols, cols * 4, dtype_to_cv_type[m.dtype]) 41 | f.write(header) 42 | f.write(m.data) 43 | 44 | 45 | def read_mat(f): 46 | """ 47 | Reads an OpenCV mat from the given file opened in binary mode 48 | """ 49 | rows, cols, stride, type_ = struct.unpack('iiii', f.read(4 * 4)) 50 | mat = np.fromstring(f.read(rows * stride), dtype=cv_type_to_dtype[type_]) 51 | return mat.reshape(rows, cols) 52 | 53 | 54 | def load_mat(filename): 55 | """ 56 | Reads a OpenCV Mat from the given filename 57 | """ 58 | return read_mat(open(filename, 'rb')) 59 | 60 | 61 | def save_mat(filename, m): 62 | """Saves mat m to the given filename""" 63 | return write_mat(open(filename, 'wb'), m) 64 | 65 | 66 | 67 | def main(args): 68 | 69 | deepglint_features = args.deepglint_features_path 70 | # merge all features into one file 71 | total_feature = [] 72 | total_files = [] 73 | for root, dirs, files in os.walk(deepglint_features): 74 | for file in files: 75 | filename = os.path.join(root, file) 76 | ext = os.path.splitext(filename)[1] 77 | ext = ext.lower() 78 | if ext in ('.feat'): 79 | total_files.append(filename) 80 | 81 | assert len(total_files) == 1862120 82 | total_files.sort() # important 83 | 84 | for i in range(len(total_files)): 85 | filename = total_files[i] 86 | tmp_feature = load_mat(filename) 87 | # print(filename) 88 | # print(tmp_feature.shape) 89 | tmp_feature = tmp_feature.T 90 | total_feature.append(tmp_feature) 91 | print(i + 1, tmp_feature.shape) 92 | # write_mat(feature_path_out, 
feature_fusion) 93 | 94 | print('total feature number: ', len(total_feature)) 95 | total_feature = np.array(total_feature).squeeze() 96 | print(total_feature.shape, total_feature.dtype, type(total_feature)) 97 | save_mat('deepglint_test_feature.bin', total_feature) 98 | 99 | 100 | if __name__ == '__main__': 101 | parser = argparse.ArgumentParser() 102 | parser.add_argument("--deepglint_features_path", type=str, default="/home/wujiyang/deepglint/deepglint_feature_ir+ws/") 103 | args = parser.parse_args() 104 | 105 | main(args) 106 | -------------------------------------------------------------------------------- /eval_lfw.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # encoding: utf-8 3 | ''' 4 | @author: wujiyang 5 | @contact: wujiyang@hust.edu.cn 6 | @file: eval_lfw.py 7 | @time: 2018/12/22 9:47 8 | @desc: 9 | ''' 10 | 11 | import numpy as np 12 | import scipy.io 13 | import os 14 | import json 15 | import torch.utils.data 16 | from backbone import mobilefacenet, resnet, arcfacenet, cbam 17 | from dataset.lfw import LFW 18 | import torchvision.transforms as transforms 19 | from torch.nn import DataParallel 20 | import argparse 21 | 22 | def getAccuracy(scores, flags, threshold): 23 | p = np.sum(scores[flags == 1] > threshold) 24 | n = np.sum(scores[flags == -1] < threshold) 25 | return 1.0 * (p + n) / len(scores) 26 | 27 | def getThreshold(scores, flags, thrNum): 28 | accuracys = np.zeros((2 * thrNum + 1, 1)) 29 | thresholds = np.arange(-thrNum, thrNum + 1) * 1.0 / thrNum 30 | for i in range(2 * thrNum + 1): 31 | accuracys[i] = getAccuracy(scores, flags, thresholds[i]) 32 | max_index = np.squeeze(accuracys == np.max(accuracys)) 33 | bestThreshold = np.mean(thresholds[max_index]) 34 | return bestThreshold 35 | 36 | def evaluation_10_fold(feature_path='./result/cur_epoch_result.mat'): 37 | ACCs = np.zeros(10) 38 | result = scipy.io.loadmat(feature_path) 39 | for i in range(10): 40 | fold = result['fold'] 41 | flags = result['flag'] 42 | featureLs = result['fl'] 43 | featureRs = result['fr'] 44 | 45 | valFold = fold != i 46 | testFold = fold == i 47 | flags = np.squeeze(flags) 48 | 49 | mu = np.mean(np.concatenate((featureLs[valFold[0], :], featureRs[valFold[0], :]), 0), 0) 50 | mu = np.expand_dims(mu, 0) 51 | featureLs = featureLs - mu 52 | featureRs = featureRs - mu 53 | featureLs = featureLs / np.expand_dims(np.sqrt(np.sum(np.power(featureLs, 2), 1)), 1) 54 | featureRs = featureRs / np.expand_dims(np.sqrt(np.sum(np.power(featureRs, 2), 1)), 1) 55 | 56 | scores = np.sum(np.multiply(featureLs, featureRs), 1) 57 | threshold = getThreshold(scores[valFold[0]], flags[valFold[0]], 10000) 58 | ACCs[i] = getAccuracy(scores[testFold[0]], flags[testFold[0]], threshold) 59 | 60 | return ACCs 61 | 62 | def loadModel(data_root, file_list, backbone_net, gpus='0', resume=None): 63 | 64 | if backbone_net == 'MobileFace': 65 | net = mobilefacenet.MobileFaceNet() 66 | elif backbone_net == 'CBAM_50': 67 | net = cbam.CBAMResNet(50, feature_dim=args.feature_dim, mode='ir') 68 | elif backbone_net == 'CBAM_50_SE': 69 | net = cbam.CBAMResNet(50, feature_dim=args.feature_dim, mode='ir_se') 70 | elif backbone_net == 'CBAM_100': 71 | net = cbam.CBAMResNet(100, feature_dim=args.feature_dim, mode='ir') 72 | elif backbone_net == 'CBAM_100_SE': 73 | net = cbam.CBAMResNet(100, feature_dim=args.feature_dim, mode='ir_se') 74 | else: 75 | print(backbone_net, ' is not available!') 76 | 77 | # gpu init 78 | multi_gpus = False 79 | if len(gpus.split(',')) > 
1: 80 | multi_gpus = True 81 | os.environ['CUDA_VISIBLE_DEVICES'] = gpus 82 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 83 | 84 | net.load_state_dict(torch.load(resume)['net_state_dict']) 85 | 86 | if multi_gpus: 87 | net = DataParallel(net).to(device) 88 | else: 89 | net = net.to(device) 90 | 91 | transform = transforms.Compose([ 92 | transforms.ToTensor(), # range [0, 255] -> [0.0,1.0] 93 | transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)) # range [0.0, 1.0] -> [-1.0,1.0] 94 | ]) 95 | lfw_dataset = LFW(data_root, file_list, transform=transform) 96 | lfw_loader = torch.utils.data.DataLoader(lfw_dataset, batch_size=128, 97 | shuffle=False, num_workers=2, drop_last=False) 98 | 99 | return net.eval(), device, lfw_dataset, lfw_loader 100 | 101 | def getFeatureFromTorch(feature_save_dir, net, device, data_set, data_loader): 102 | featureLs = None 103 | featureRs = None 104 | count = 0 105 | for data in data_loader: 106 | for i in range(len(data)): 107 | data[i] = data[i].to(device) 108 | count += data[0].size(0) 109 | #print('extracing deep features from the face pair {}...'.format(count)) 110 | with torch.no_grad(): 111 | res = [net(d).data.cpu().numpy() for d in data] 112 | featureL = np.concatenate((res[0], res[1]), 1) 113 | featureR = np.concatenate((res[2], res[3]), 1) 114 | # print(featureL.shape, featureR.shape) 115 | if featureLs is None: 116 | featureLs = featureL 117 | else: 118 | featureLs = np.concatenate((featureLs, featureL), 0) 119 | if featureRs is None: 120 | featureRs = featureR 121 | else: 122 | featureRs = np.concatenate((featureRs, featureR), 0) 123 | # print(featureLs.shape, featureRs.shape) 124 | 125 | result = {'fl': featureLs, 'fr': featureRs, 'fold': data_set.folds, 'flag': data_set.flags} 126 | scipy.io.savemat(feature_save_dir, result) 127 | 128 | if __name__ == '__main__': 129 | parser = argparse.ArgumentParser(description='Testing') 130 | parser.add_argument('--root', type=str, default='/media/sda/lfw/lfw_align_112', help='The path of lfw data') 131 | parser.add_argument('--file_list', type=str, default='/media/sda/lfw/pairs.txt', help='The path of lfw data') 132 | parser.add_argument('--backbone_net', type=str, default='CBAM_100_SE', help='MobileFace, CBAM_50, CBAM_50_SE, CBAM_100, CBAM_100_SE') 133 | parser.add_argument('--feature_dim', type=int, default=512, help='feature dimension') 134 | parser.add_argument('--resume', type=str, default='./model/SERES100_SERES100_IR_20190528_132635/Iter_342000_net.ckpt', 135 | help='The path pf save model') 136 | parser.add_argument('--feature_save_path', type=str, default='./result/cur_epoch_lfw_result.mat', 137 | help='The path of the extract features save, must be .mat file') 138 | parser.add_argument('--gpus', type=str, default='1,3', help='gpu list') 139 | args = parser.parse_args() 140 | 141 | net, device, lfw_dataset, lfw_loader = loadModel(args.root, args.file_list, args.backbone_net, args.gpus, args.resume) 142 | getFeatureFromTorch(args.feature_save_path, net, device, lfw_dataset, lfw_loader) 143 | ACCs = evaluation_10_fold(args.feature_save_path) 144 | for i in range(len(ACCs)): 145 | print('{} {:.2f}'.format(i+1, ACCs[i] * 100)) 146 | print('--------') 147 | print('AVE {:.4f}'.format(np.mean(ACCs) * 100)) 148 | -------------------------------------------------------------------------------- /eval_lfw_blufr.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # encoding: utf-8 3 | ''' 4 | @author: wujiyang 5 | 
@contact: wujiyang@hust.edu.cn 6 | @file: eval_lfw_blufr.py 7 | @time: 2019/1/17 15:52 8 | @desc: test lfw accuracy on blufr protocol 9 | ''' 10 | ''' 11 | LFW BLUFR TEST PROTOCOL 12 | 13 | Official Website: http://www.cbsr.ia.ac.cn/users/scliao/projects/blufr/ 14 | 15 | When I try to do this, I find that the blufr_lfw_config.mat file provided by above site is too old. 16 | Some image files listed in the mat have been removed in lfw pairs.txt 17 | So this work is suspended for now... 18 | ''' 19 | 20 | import scipy.io as sio 21 | import argparse 22 | 23 | def readName(file='pairs.txt'): 24 | name_list = [] 25 | f = open(file, 'r') 26 | lines = f.readlines() 27 | 28 | for line in lines[1:]: 29 | line_split = line.rstrip().split() 30 | if len(line_split) == 3: 31 | name_list.append(line_split[0]) 32 | elif len(line_split) == 4: 33 | name_list.append(line_split[0]) 34 | name_list.append(line_split[2]) 35 | else: 36 | print('wrong file, please check again') 37 | 38 | return list(set(name_list)) 39 | 40 | 41 | def main(args): 42 | blufr_info = sio.loadmat(args.lfw_blufr_file) 43 | #print(blufr_info) 44 | name_list = readName() 45 | 46 | image = blufr_info['imageList'] 47 | missing_files = [] 48 | for i in range(image.shape[0]): 49 | name = image[i][0][0] 50 | index = name.rfind('_') 51 | name = name[0:index] 52 | if name not in name_list: 53 | print(name) 54 | missing_files.append(name) 55 | print('lfw pairs.txt total persons: ', len(name_list)) 56 | print('blufr_mat_missing persons: ', len(missing_files)) 57 | 58 | ''' 59 | Some of the missing file: 60 | Zdravko_Mucic 61 | Zelma_Novelo 62 | Zeng_Qinghong 63 | Zumrati_Juma 64 | lfw pairs.txt total persons: 4281 65 | blufr_mat_missing persons: 1549 66 | 67 | ''' 68 | 69 | if __name__ == '__main__': 70 | parser = argparse.ArgumentParser(description='lfw blufr test') 71 | parser.add_argument('--lfw_blufr_file', type=str, default='./blufr_lfw_config.mat', help='feature dimension') 72 | parser.add_argument('--lfw_pairs.txt', type=str, default='./pairs.txt', help='feature dimension') 73 | parser.add_argument('--gpus', type=str, default='2,3', help='gpu list') 74 | args = parser.parse_args() 75 | 76 | main(args) -------------------------------------------------------------------------------- /eval_megaface.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # encoding: utf-8 3 | ''' 4 | @author: wujiyang 5 | @contact: wujiyang@hust.edu.cn 6 | @file: eval_megaface.py 7 | @time: 2018/12/24 16:28 8 | @desc: megaface feature extractor 9 | ''' 10 | import numpy as np 11 | import struct 12 | import os 13 | import torch.utils.data 14 | from backbone import mobilefacenet, cbam, self_attention 15 | from dataset.megaface import MegaFace 16 | import torchvision.transforms as transforms 17 | import argparse 18 | from torch.nn import DataParallel 19 | 20 | 21 | cv_type_to_dtype = {5: np.dtype('float32'), 6: np.dtype('float64')} 22 | dtype_to_cv_type = {v: k for k, v in cv_type_to_dtype.items()} 23 | 24 | def write_mat(filename, m): 25 | """Write mat m to file f""" 26 | if len(m.shape) == 1: 27 | rows = m.shape[0] 28 | cols = 1 29 | else: 30 | rows, cols = m.shape 31 | header = struct.pack('iiii', rows, cols, cols * 4, dtype_to_cv_type[m.dtype]) 32 | 33 | with open(filename, 'wb') as outfile: 34 | outfile.write(header) 35 | outfile.write(m.data) 36 | 37 | 38 | def read_mat(filename): 39 | """ 40 | Reads an OpenCV mat from the given file opened in binary mode 41 | """ 42 | with open(filename, 'rb') as 
fin: 43 | rows, cols, stride, type_ = struct.unpack('iiii', fin.read(4 * 4)) 44 | mat = np.fromstring(str(fin.read(rows * stride)), dtype=cv_type_to_dtype[type_]) 45 | return mat.reshape(rows, cols) 46 | 47 | 48 | def extract_feature(model_path, backbone_net, face_scrub_path, megaface_path, batch_size=32, gpus='0', do_norm=False): 49 | 50 | if backbone_net == 'MobileFace': 51 | net = mobilefacenet.MobileFaceNet() 52 | elif backbone_net == 'CBAM_50': 53 | net = cbam.CBAMResNet(50, feature_dim=args.feature_dim, mode='ir') 54 | elif backbone_net == 'CBAM_50_SE': 55 | net = cbam.CBAMResNet(50, feature_dim=args.feature_dim, mode='ir_se') 56 | elif backbone_net == 'CBAM_100': 57 | net = cbam.CBAMResNet(100, feature_dim=args.feature_dim, mode='ir') 58 | elif backbone_net == 'CBAM_100_SE': 59 | net = cbam.CBAMResNet(100, feature_dim=args.feature_dim, mode='ir_se') 60 | else: 61 | print(args.backbone, ' is not available!') 62 | 63 | multi_gpus = False 64 | if len(gpus.split(',')) > 1: 65 | multi_gpus = True 66 | os.environ['CUDA_VISIBLE_DEVICES'] = gpus 67 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 68 | 69 | net.load_state_dict(torch.load(model_path)['net_state_dict']) 70 | if multi_gpus: 71 | net = DataParallel(net).to(device) 72 | else: 73 | net = net.to(device) 74 | net.eval() 75 | 76 | transform = transforms.Compose([ 77 | transforms.ToTensor(), # range [0, 255] -> [0.0,1.0] 78 | transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)) # range [0.0, 1.0] -> [-1.0,1.0] 79 | ]) 80 | megaface_dataset = MegaFace(face_scrub_path, megaface_path, transform=transform) 81 | megaface_loader = torch.utils.data.DataLoader(megaface_dataset, batch_size=batch_size, 82 | shuffle=False, num_workers=12, drop_last=False) 83 | 84 | for data in megaface_loader: 85 | img, img_path= data[0].to(device), data[1] 86 | with torch.no_grad(): 87 | output = net(img).data.cpu().numpy() 88 | 89 | if do_norm is False: 90 | for i in range(len(img_path)): 91 | abs_path = img_path[i] + '.feat' 92 | write_mat(abs_path, output[i]) 93 | print('extract 1 batch...without feature normalization') 94 | else: 95 | for i in range(len(img_path)): 96 | abs_path = img_path[i] + '.feat' 97 | feat = output[i] 98 | feat = feat / np.sqrt((np.dot(feat, feat))) 99 | write_mat(abs_path, feat) 100 | print('extract 1 batch...with feature normalization') 101 | print('all images have been processed!') 102 | 103 | 104 | if __name__ == '__main__': 105 | parser = argparse.ArgumentParser(description='Testing') 106 | parser.add_argument('--model_path', type=str, default='./model/RES100_RES100_IR_20190423_100728/Iter_333000_net.ckpt', help='The path of trained model') 107 | parser.add_argument('--backbone_net', type=str, default='CBAM_100', help='MobileFace, CBAM_50, CBAM_50_SE, CBAM_100, CBAM_100_SE') 108 | parser.add_argument('--facescrub_dir', type=str, default='/media/sda/megaface_test_kit/facescrub_align_112/', help='facescrub data') 109 | parser.add_argument('--megaface_dir', type=str, default='/media/sda/megaface_test_kit/megaface_align_112/', help='megaface data') 110 | parser.add_argument('--batch_size', type=int, default=1024, help='batch size') 111 | parser.add_argument('--feature_dim', type=int, default=512, help='feature dimension') 112 | parser.add_argument('--gpus', type=str, default='0,1,2,3', help='gpu list') 113 | parser.add_argument("--do_norm", type=int, default=1, help="1 if normalize feature, 0 do nothing(Default case)") 114 | args = parser.parse_args() 115 | 116 | extract_feature(args.model_path, 
args.backbone_net, args.facescrub_dir, args.megaface_dir, args.batch_size, args.gpus, args.do_norm) -------------------------------------------------------------------------------- /lossfunctions/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # encoding: utf-8 3 | ''' 4 | @author: wujiyang 5 | @contact: wujiyang@hust.edu.cn 6 | @file: __init__.py.py 7 | @time: 2019/1/4 15:24 8 | @desc: 9 | ''' -------------------------------------------------------------------------------- /lossfunctions/agentcenterloss.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # encoding: utf-8 3 | ''' 4 | @author: wujiyang 5 | @contact: wujiyang@hust.edu.cn 6 | @file: agentcenterloss.py 7 | @time: 2019/1/7 10:53 8 | @desc: the variety of center loss, which use the class weight as the class center and normalize both the weight and feature, 9 | in this way, the cos distance of weight and feature can be used as the supervised signal. 10 | It's similar with torch.nn.CosineEmbeddingLoss, x_1 means weight_i, x_2 means feature_i. 11 | ''' 12 | 13 | import torch 14 | import torch.nn as nn 15 | import torch.nn.functional as F 16 | 17 | class AgentCenterLoss(nn.Module): 18 | 19 | def __init__(self, num_classes, feat_dim, scale): 20 | super(AgentCenterLoss, self).__init__() 21 | self.num_classes = num_classes 22 | self.feat_dim = feat_dim 23 | self.scale = scale 24 | 25 | self.centers = nn.Parameter(torch.randn(self.num_classes, self.feat_dim)) 26 | 27 | def forward(self, x, labels): 28 | ''' 29 | Parameters: 30 | x: input tensor with shape (batch_size, feat_dim) 31 | labels: ground truth label with shape (batch_size) 32 | Return: 33 | loss of centers 34 | ''' 35 | cos_dis = F.linear(F.normalize(x), F.normalize(self.centers)) * self.scale 36 | 37 | one_hot = torch.zeros_like(cos_dis) 38 | one_hot.scatter_(1, labels.view(-1, 1), 1.0) 39 | 40 | # loss = 1 - cosine(i) 41 | loss = one_hot * self.scale - (one_hot * cos_dis) 42 | 43 | return loss.mean() -------------------------------------------------------------------------------- /lossfunctions/centerloss.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # encoding: utf-8 3 | ''' 4 | @author: wujiyang 5 | @contact: wujiyang@hust.edu.cn 6 | @file: centerloss.py 7 | @time: 2019/1/4 15:24 8 | @desc: the implementation of center loss 9 | ''' 10 | 11 | import torch 12 | import torch.nn as nn 13 | 14 | class CenterLoss(nn.Module): 15 | 16 | def __init__(self, num_classes, feat_dim): 17 | super(CenterLoss, self).__init__() 18 | self.num_classes = num_classes 19 | self.feat_dim = feat_dim 20 | 21 | self.centers = nn.Parameter(torch.randn(self.num_classes, self.feat_dim)) 22 | 23 | def forward(self, x, labels): 24 | ''' 25 | Parameters: 26 | x: input tensor with shape (batch_size, feat_dim) 27 | labels: ground truth label with shape (batch_size) 28 | Return: 29 | loss of centers 30 | ''' 31 | # compute the distance of (x-center)^2 32 | batch_size = x.size(0) 33 | distmat = torch.pow(x, 2).sum(dim=1, keepdim=True).expand(batch_size, self.num_classes) + \ 34 | torch.pow(self.centers, 2).sum(dim=1, keepdim=True).expand(self.num_classes, batch_size).t() 35 | distmat.addmm_(1, -2, x, self.centers.t()) 36 | 37 | # get one_hot matrix 38 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 39 | classes = torch.arange(self.num_classes).long().to(device) 40 | labels = 
labels.unsqueeze(1).expand(batch_size, self.num_classes) 41 | mask = labels.eq(classes.expand(batch_size, self.num_classes)) 42 | 43 | dist = [] 44 | for i in range(batch_size): 45 | value = distmat[i][mask[i]] 46 | value = value.clamp(min=1e-12, max=1e+12) # for numerical stability 47 | dist.append(value) 48 | dist = torch.cat(dist) 49 | loss = dist.mean() 50 | 51 | return loss -------------------------------------------------------------------------------- /margin/ArcMarginProduct.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # encoding: utf-8 3 | ''' 4 | @author: wujiyang 5 | @contact: wujiyang@hust.edu.cn 6 | @file: ArcMarginProduct.py 7 | @time: 2018/12/25 9:13 8 | @desc: additive angular margin for arcface/insightface 9 | ''' 10 | 11 | import math 12 | import torch 13 | from torch import nn 14 | from torch.nn import Parameter 15 | import torch.nn.functional as F 16 | 17 | class ArcMarginProduct(nn.Module): 18 | def __init__(self, in_feature=128, out_feature=10575, s=32.0, m=0.50, easy_margin=False): 19 | super(ArcMarginProduct, self).__init__() 20 | self.in_feature = in_feature 21 | self.out_feature = out_feature 22 | self.s = s 23 | self.m = m 24 | self.weight = Parameter(torch.Tensor(out_feature, in_feature)) 25 | nn.init.xavier_uniform_(self.weight) 26 | 27 | self.easy_margin = easy_margin 28 | self.cos_m = math.cos(m) 29 | self.sin_m = math.sin(m) 30 | 31 | # make the function cos(theta+m) monotonic decreasing while theta in [0°,180°] 32 | self.th = math.cos(math.pi - m) 33 | self.mm = math.sin(math.pi - m) * m 34 | 35 | def forward(self, x, label): 36 | # cos(theta) 37 | cosine = F.linear(F.normalize(x), F.normalize(self.weight)) 38 | # cos(theta + m) 39 | sine = torch.sqrt(1.0 - torch.pow(cosine, 2)) 40 | phi = cosine * self.cos_m - sine * self.sin_m 41 | 42 | if self.easy_margin: 43 | phi = torch.where(cosine > 0, phi, cosine) 44 | else: 45 | phi = torch.where((cosine - self.th) > 0, phi, cosine - self.mm) 46 | 47 | #one_hot = torch.zeros(cosine.size(), device='cuda' if torch.cuda.is_available() else 'cpu') 48 | one_hot = torch.zeros_like(cosine) 49 | one_hot.scatter_(1, label.view(-1, 1), 1) 50 | output = (one_hot * phi) + ((1.0 - one_hot) * cosine) 51 | output = output * self.s 52 | 53 | return output 54 | 55 | 56 | if __name__ == '__main__': 57 | pass -------------------------------------------------------------------------------- /margin/CosineMarginProduct.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # encoding: utf-8 3 | ''' 4 | @author: wujiyang 5 | @contact: wujiyang@hust.edu.cn 6 | @file: CosineMarginProduct.py 7 | @time: 2018/12/25 9:13 8 | @desc: additive cosine margin for cosface 9 | ''' 10 | 11 | import torch 12 | import torch.nn as nn 13 | import torch.nn.functional as F 14 | from torch.nn import Parameter 15 | 16 | 17 | class CosineMarginProduct(nn.Module): 18 | def __init__(self, in_feature=128, out_feature=10575, s=30.0, m=0.35): 19 | super(CosineMarginProduct, self).__init__() 20 | self.in_feature = in_feature 21 | self.out_feature = out_feature 22 | self.s = s 23 | self.m = m 24 | self.weight = Parameter(torch.Tensor(out_feature, in_feature)) 25 | nn.init.xavier_uniform_(self.weight) 26 | 27 | 28 | def forward(self, input, label): 29 | cosine = F.linear(F.normalize(input), F.normalize(self.weight)) 30 | # one_hot = torch.zeros(cosine.size(), device='cuda' if torch.cuda.is_available() else 'cpu') 31 | one_hot = 
torch.zeros_like(cosine) 32 | one_hot.scatter_(1, label.view(-1, 1), 1.0) 33 | 34 | output = self.s * (cosine - one_hot * self.m) 35 | return output 36 | 37 | 38 | if __name__ == '__main__': 39 | pass -------------------------------------------------------------------------------- /margin/InnerProduct.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # encoding: utf-8 3 | ''' 4 | @author: wujiyang 5 | @contact: wujiyang@hust.edu.cn 6 | @file: InnerProduct.py 7 | @time: 2019/1/4 16:54 8 | @desc: just normal inner product as fully connected layer do. 9 | ''' 10 | import torch 11 | import torch.nn as nn 12 | import torch.nn.functional as F 13 | from torch.nn import Parameter 14 | 15 | class InnerProduct(nn.Module): 16 | def __init__(self, in_feature=128, out_feature=10575): 17 | super(InnerProduct, self).__init__() 18 | self.in_feature = in_feature 19 | self.out_feature = out_feature 20 | 21 | self.weight = Parameter(torch.Tensor(out_feature, in_feature)) 22 | nn.init.xavier_uniform_(self.weight) 23 | 24 | 25 | def forward(self, input, label): 26 | # label not used 27 | output = F.linear(input, self.weight) 28 | return output 29 | 30 | 31 | if __name__ == '__main__': 32 | pass -------------------------------------------------------------------------------- /margin/MultiMarginProduct.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # encoding: utf-8 3 | ''' 4 | @author: wujiyang 5 | @contact: wujiyang@hust.edu.cn 6 | @file: MultiMarginProduct.py 7 | @time: 2019/3/30 10:09 8 | @desc: Combination of additive angular margin and additive cosine margin 9 | ''' 10 | 11 | import math 12 | import torch 13 | from torch import nn 14 | from torch.nn import Parameter 15 | import torch.nn.functional as F 16 | 17 | class MultiMarginProduct(nn.Module): 18 | def __init__(self, in_feature=128, out_feature=10575, s=32.0, m1=0.20, m2=0.35, easy_margin=False): 19 | super(MultiMarginProduct, self).__init__() 20 | self.in_feature = in_feature 21 | self.out_feature = out_feature 22 | self.s = s 23 | self.m1 = m1 24 | self.m2 = m2 25 | self.weight = Parameter(torch.Tensor(out_feature, in_feature)) 26 | nn.init.xavier_uniform_(self.weight) 27 | 28 | self.easy_margin = easy_margin 29 | self.cos_m1 = math.cos(m1) 30 | self.sin_m1 = math.sin(m1) 31 | 32 | # make the function cos(theta+m) monotonic decreasing while theta in [0°,180°] 33 | self.th = math.cos(math.pi - m1) 34 | self.mm = math.sin(math.pi - m1) * m1 35 | 36 | def forward(self, x, label): 37 | # cos(theta) 38 | cosine = F.linear(F.normalize(x), F.normalize(self.weight)) 39 | # cos(theta + m1) 40 | sine = torch.sqrt(1.0 - torch.pow(cosine, 2)) 41 | phi = cosine * self.cos_m1 - sine * self.sin_m1 42 | 43 | if self.easy_margin: 44 | phi = torch.where(cosine > 0, phi, cosine) 45 | else: 46 | phi = torch.where((cosine - self.th) > 0, phi, cosine - self.mm) 47 | 48 | 49 | one_hot = torch.zeros_like(cosine) 50 | one_hot.scatter_(1, label.view(-1, 1), 1) 51 | output = (one_hot * phi) + ((1.0 - one_hot) * cosine) # additive angular margin 52 | output = output - one_hot * self.m2 # additive cosine margin 53 | output = output * self.s 54 | 55 | return output 56 | 57 | 58 | if __name__ == '__main__': 59 | pass -------------------------------------------------------------------------------- /margin/SphereMarginProduct.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # encoding: 
utf-8 3 | ''' 4 | @author: wujiyang 5 | @contact: wujiyang@hust.edu.cn 6 | @file: SphereMarginProduct.py 7 | @time: 2018/12/25 9:19 8 | @desc: multiplicative angular margin for sphereface 9 | ''' 10 | 11 | import torch 12 | import torch.nn as nn 13 | import torch.nn.functional as F 14 | from torch.nn import Parameter 15 | import math 16 | 17 | class SphereMarginProduct(nn.Module): 18 | def __init__(self, in_feature, out_feature, m=4, base=1000.0, gamma=0.0001, power=2, lambda_min=5.0, iter=0): 19 | super(SphereMarginProduct, self).__init__(); assert m in [1, 2, 3, 4], 'margin should be 1, 2, 3 or 4' 20 | self.in_feature = in_feature 21 | self.out_feature = out_feature 22 | self.m = m 23 | self.base = base 24 | self.gamma = gamma 25 | self.power = power 26 | self.lambda_min = lambda_min 27 | self.iter = 0 28 | self.weight = Parameter(torch.Tensor(out_feature, in_feature)) 29 | nn.init.xavier_uniform_(self.weight) 30 | 31 | # duplication formula 32 | self.margin_formula = [ 33 | lambda x : x ** 0, 34 | lambda x : x ** 1, 35 | lambda x : 2 * x ** 2 - 1, 36 | lambda x : 4 * x ** 3 - 3 * x, 37 | lambda x : 8 * x ** 4 - 8 * x ** 2 + 1, 38 | lambda x : 16 * x ** 5 - 20 * x ** 3 + 5 * x 39 | ] 40 | 41 | def forward(self, input, label): 42 | self.iter += 1 43 | self.cur_lambda = max(self.lambda_min, self.base * (1 + self.gamma * self.iter) ** (-1 * self.power)) 44 | 45 | cos_theta = F.linear(F.normalize(input), F.normalize(self.weight)) 46 | cos_theta = cos_theta.clamp(-1, 1) 47 | 48 | cos_m_theta = self.margin_formula[self.m](cos_theta) 49 | theta = cos_theta.data.acos() 50 | k = ((self.m * theta) / math.pi).floor() 51 | phi_theta = ((-1.0) ** k) * cos_m_theta - 2 * k 52 | phi_theta_ = (self.cur_lambda * cos_theta + phi_theta) / (1 + self.cur_lambda) 53 | norm_of_feature = torch.norm(input, 2, 1) 54 | 55 | one_hot = torch.zeros_like(cos_theta) 56 | one_hot.scatter_(1, label.view(-1, 1), 1) 57 | 58 | output = one_hot * phi_theta_ + (1 - one_hot) * cos_theta 59 | output *= norm_of_feature.view(-1, 1) 60 | 61 | return output 62 | 63 | 64 | if __name__ == '__main__': 65 | pass -------------------------------------------------------------------------------- /margin/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # encoding: utf-8 3 | ''' 4 | @author: wujiyang 5 | @contact: wujiyang@hust.edu.cn 6 | @file: __init__.py.py 7 | @time: 2018/12/25 9:12 8 | @desc: 9 | ''' -------------------------------------------------------------------------------- /model/MSCeleb_MOBILEFACE_20181228_170458/log.log: -------------------------------------------------------------------------------- 1 | 20181228-17:05:06 Train Epoch: 1/30 ... 2 | 20181228-17:35:50 Saving checkpoint: 5000 3 | 20181228-17:36:10 LFW Ave Accuracy: 96.4333 4 | 20181228-17:36:30 AgeDB-30 Ave Accuracy: 82.7000 5 | 20181228-17:36:52 CFP-FP Ave Accuracy: 78.9286 6 | 20181228-17:36:52 Current Best Accuracy: LFW: 96.4333 in iters: 5000, AgeDB-30: 82.7000 in iters: 5000 and CFP-FP: 78.9286 in iters: 5000 7 | 20181228-18:06:33 Train Epoch: 2/30 ...
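All of the margin heads in the margin/ directory share one calling convention: forward(embedding, label) returns class logits with the margin and the scale s already applied, so a plain torch.nn.CrossEntropyLoss can be used on top. A small sanity check for ArcMarginProduct, meant to be run from the repository root; the batch size, feature dimension and class count below are arbitrary toy values:

```python
import torch
import torch.nn.functional as F

from margin.ArcMarginProduct import ArcMarginProduct

torch.manual_seed(0)
head = ArcMarginProduct(in_feature=8, out_feature=5, s=32.0, m=0.5)
feat = torch.randn(4, 8)               # toy embeddings: batch_size=4, feat_dim=8
label = torch.tensor([0, 1, 2, 3])     # toy ground-truth classes

logits = head(feat, label)             # s * cos(theta + m) for the target class, s * cos(theta) elsewhere
loss = F.cross_entropy(logits, label)  # standard softmax cross-entropy on top of the margin head

# The additive angular margin can only shrink the target logit compared with a plain
# normalized inner-product head, which is what makes the training task harder.
plain = 32.0 * F.linear(F.normalize(feat), F.normalize(head.weight))
assert torch.all(logits.gather(1, label.view(-1, 1)) <= plain.gather(1, label.view(-1, 1)) + 1e-4)
print(logits.shape, loss.item())
```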
8 | 20181228-18:07:49 Saving checkpoint: 10000 9 | 20181228-18:08:09 LFW Ave Accuracy: 98.2000 10 | 20181228-18:08:29 AgeDB-30 Ave Accuracy: 89.2667 11 | 20181228-18:08:52 CFP-FP Ave Accuracy: 83.9714 12 | 20181228-18:08:52 Current Best Accuracy: LFW: 98.2000 in iters: 10000, AgeDB-30: 89.2667 in iters: 10000 and CFP-FP: 83.9714 in iters: 10000 13 | 20181228-18:39:39 Saving checkpoint: 15000 14 | 20181228-18:39:59 LFW Ave Accuracy: 98.6167 15 | 20181228-18:40:18 AgeDB-30 Ave Accuracy: 90.2500 16 | 20181228-18:40:39 CFP-FP Ave Accuracy: 85.6143 17 | 20181228-18:40:39 Current Best Accuracy: LFW: 98.6167 in iters: 15000, AgeDB-30: 90.2500 in iters: 15000 and CFP-FP: 85.6143 in iters: 15000 18 | 20181228-19:09:10 Train Epoch: 3/30 ... 19 | 20181228-19:11:37 Saving checkpoint: 20000 20 | 20181228-19:11:57 LFW Ave Accuracy: 98.5833 21 | 20181228-19:12:16 AgeDB-30 Ave Accuracy: 91.2667 22 | 20181228-19:12:37 CFP-FP Ave Accuracy: 86.1429 23 | 20181228-19:12:37 Current Best Accuracy: LFW: 98.6167 in iters: 15000, AgeDB-30: 91.2667 in iters: 20000 and CFP-FP: 86.1429 in iters: 20000 24 | 20181228-19:43:26 Saving checkpoint: 25000 25 | 20181228-19:43:46 LFW Ave Accuracy: 98.8667 26 | 20181228-19:44:06 AgeDB-30 Ave Accuracy: 91.6333 27 | 20181228-19:44:27 CFP-FP Ave Accuracy: 85.9714 28 | 20181228-19:44:27 Current Best Accuracy: LFW: 98.8667 in iters: 25000, AgeDB-30: 91.6333 in iters: 25000 and CFP-FP: 86.1429 in iters: 20000 29 | 20181228-20:11:45 Train Epoch: 4/30 ... 30 | 20181228-20:15:23 Saving checkpoint: 30000 31 | 20181228-20:15:43 LFW Ave Accuracy: 98.7167 32 | 20181228-20:16:02 AgeDB-30 Ave Accuracy: 91.4500 33 | 20181228-20:16:25 CFP-FP Ave Accuracy: 85.4714 34 | 20181228-20:16:25 Current Best Accuracy: LFW: 98.8667 in iters: 25000, AgeDB-30: 91.6333 in iters: 25000 and CFP-FP: 86.1429 in iters: 20000 35 | 20181228-20:47:10 Saving checkpoint: 35000 36 | 20181228-20:47:30 LFW Ave Accuracy: 98.9333 37 | 20181228-20:47:49 AgeDB-30 Ave Accuracy: 92.4167 38 | 20181228-20:48:11 CFP-FP Ave Accuracy: 86.4714 39 | 20181228-20:48:11 Current Best Accuracy: LFW: 98.9333 in iters: 35000, AgeDB-30: 92.4167 in iters: 35000 and CFP-FP: 86.4714 in iters: 35000 40 | 20181228-21:14:18 Train Epoch: 5/30 ... 41 | 20181228-21:19:06 Saving checkpoint: 40000 42 | 20181228-21:19:25 LFW Ave Accuracy: 98.8667 43 | 20181228-21:19:44 AgeDB-30 Ave Accuracy: 92.2167 44 | 20181228-21:20:07 CFP-FP Ave Accuracy: 87.1429 45 | 20181228-21:20:07 Current Best Accuracy: LFW: 98.9333 in iters: 35000, AgeDB-30: 92.4167 in iters: 35000 and CFP-FP: 87.1429 in iters: 40000 46 | 20181228-21:50:54 Saving checkpoint: 45000 47 | 20181228-21:51:14 LFW Ave Accuracy: 98.8333 48 | 20181228-21:51:33 AgeDB-30 Ave Accuracy: 91.9167 49 | 20181228-21:51:54 CFP-FP Ave Accuracy: 86.6714 50 | 20181228-21:51:54 Current Best Accuracy: LFW: 98.9333 in iters: 35000, AgeDB-30: 92.4167 in iters: 35000 and CFP-FP: 87.1429 in iters: 40000 51 | 20181228-22:16:47 Train Epoch: 6/30 ... 
52 | 20181228-22:22:46 Saving checkpoint: 50000 53 | 20181228-22:23:06 LFW Ave Accuracy: 99.1500 54 | 20181228-22:23:25 AgeDB-30 Ave Accuracy: 92.7333 55 | 20181228-22:23:48 CFP-FP Ave Accuracy: 87.3714 56 | 20181228-22:23:48 Current Best Accuracy: LFW: 99.1500 in iters: 50000, AgeDB-30: 92.7333 in iters: 50000 and CFP-FP: 87.3714 in iters: 50000 57 | 20181228-22:54:35 Saving checkpoint: 55000 58 | 20181228-22:54:56 LFW Ave Accuracy: 98.8833 59 | 20181228-22:55:16 AgeDB-30 Ave Accuracy: 92.3333 60 | 20181228-22:55:37 CFP-FP Ave Accuracy: 86.7857 61 | 20181228-22:55:37 Current Best Accuracy: LFW: 99.1500 in iters: 50000, AgeDB-30: 92.7333 in iters: 50000 and CFP-FP: 87.3714 in iters: 50000 62 | 20181228-23:19:21 Train Epoch: 7/30 ... 63 | 20181228-23:26:29 Saving checkpoint: 60000 64 | 20181228-23:26:49 LFW Ave Accuracy: 98.9000 65 | 20181228-23:27:08 AgeDB-30 Ave Accuracy: 92.9333 66 | 20181228-23:27:30 CFP-FP Ave Accuracy: 87.4857 67 | 20181228-23:27:30 Current Best Accuracy: LFW: 99.1500 in iters: 50000, AgeDB-30: 92.9333 in iters: 60000 and CFP-FP: 87.4857 in iters: 60000 68 | 20181228-23:58:25 Saving checkpoint: 65000 69 | 20181228-23:58:45 LFW Ave Accuracy: 99.0500 70 | 20181228-23:59:04 AgeDB-30 Ave Accuracy: 92.4000 71 | 20181228-23:59:26 CFP-FP Ave Accuracy: 86.9857 72 | 20181228-23:59:26 Current Best Accuracy: LFW: 99.1500 in iters: 50000, AgeDB-30: 92.9333 in iters: 60000 and CFP-FP: 87.4857 in iters: 60000 73 | 20181229-00:22:10 Train Epoch: 8/30 ... 74 | 20181229-00:30:28 Saving checkpoint: 70000 75 | 20181229-00:30:47 LFW Ave Accuracy: 98.9167 76 | 20181229-00:31:06 AgeDB-30 Ave Accuracy: 93.1000 77 | 20181229-00:31:28 CFP-FP Ave Accuracy: 87.3857 78 | 20181229-00:31:28 Current Best Accuracy: LFW: 99.1500 in iters: 50000, AgeDB-30: 93.1000 in iters: 70000 and CFP-FP: 87.4857 in iters: 60000 79 | 20181229-01:02:13 Saving checkpoint: 75000 80 | 20181229-01:02:33 LFW Ave Accuracy: 98.9000 81 | 20181229-01:02:52 AgeDB-30 Ave Accuracy: 92.3167 82 | 20181229-01:03:14 CFP-FP Ave Accuracy: 87.5857 83 | 20181229-01:03:14 Current Best Accuracy: LFW: 99.1500 in iters: 50000, AgeDB-30: 93.1000 in iters: 70000 and CFP-FP: 87.5857 in iters: 75000 84 | 20181229-01:24:34 Train Epoch: 9/30 ... 85 | 20181229-01:34:02 Saving checkpoint: 80000 86 | 20181229-01:34:22 LFW Ave Accuracy: 98.9667 87 | 20181229-01:34:41 AgeDB-30 Ave Accuracy: 93.3833 88 | 20181229-01:35:02 CFP-FP Ave Accuracy: 87.8714 89 | 20181229-01:35:02 Current Best Accuracy: LFW: 99.1500 in iters: 50000, AgeDB-30: 93.3833 in iters: 80000 and CFP-FP: 87.8714 in iters: 80000 90 | 20181229-02:05:40 Saving checkpoint: 85000 91 | 20181229-02:06:01 LFW Ave Accuracy: 99.0667 92 | 20181229-02:06:20 AgeDB-30 Ave Accuracy: 92.8000 93 | 20181229-02:06:41 CFP-FP Ave Accuracy: 87.2714 94 | 20181229-02:06:41 Current Best Accuracy: LFW: 99.1500 in iters: 50000, AgeDB-30: 93.3833 in iters: 80000 and CFP-FP: 87.8714 in iters: 80000 95 | 20181229-02:26:49 Train Epoch: 10/30 ... 
96 | 20181229-02:37:26 Saving checkpoint: 90000 97 | 20181229-02:37:46 LFW Ave Accuracy: 98.7167 98 | 20181229-02:38:05 AgeDB-30 Ave Accuracy: 92.3500 99 | 20181229-02:38:26 CFP-FP Ave Accuracy: 88.0143 100 | 20181229-02:38:26 Current Best Accuracy: LFW: 99.1500 in iters: 50000, AgeDB-30: 93.3833 in iters: 80000 and CFP-FP: 88.0143 in iters: 90000 101 | 20181229-03:09:07 Saving checkpoint: 95000 102 | 20181229-03:09:27 LFW Ave Accuracy: 98.8667 103 | 20181229-03:09:46 AgeDB-30 Ave Accuracy: 93.2333 104 | 20181229-03:10:08 CFP-FP Ave Accuracy: 87.6286 105 | 20181229-03:10:08 Current Best Accuracy: LFW: 99.1500 in iters: 50000, AgeDB-30: 93.3833 in iters: 80000 and CFP-FP: 88.0143 in iters: 90000 106 | 20181229-03:29:07 Train Epoch: 11/30 ... 107 | 20181229-03:40:58 Saving checkpoint: 100000 108 | 20181229-03:41:17 LFW Ave Accuracy: 99.4333 109 | 20181229-03:41:36 AgeDB-30 Ave Accuracy: 95.0500 110 | 20181229-03:41:57 CFP-FP Ave Accuracy: 90.5000 111 | 20181229-03:41:57 Current Best Accuracy: LFW: 99.4333 in iters: 100000, AgeDB-30: 95.0500 in iters: 100000 and CFP-FP: 90.5000 in iters: 100000 112 | 20181229-04:12:37 Saving checkpoint: 105000 113 | 20181229-04:12:58 LFW Ave Accuracy: 99.3500 114 | 20181229-04:13:17 AgeDB-30 Ave Accuracy: 95.3500 115 | 20181229-04:13:38 CFP-FP Ave Accuracy: 91.1429 116 | 20181229-04:13:38 Current Best Accuracy: LFW: 99.4333 in iters: 100000, AgeDB-30: 95.3500 in iters: 105000 and CFP-FP: 91.1429 in iters: 105000 117 | 20181229-04:31:24 Train Epoch: 12/30 ... 118 | 20181229-04:44:22 Saving checkpoint: 110000 119 | 20181229-04:44:42 LFW Ave Accuracy: 99.3500 120 | 20181229-04:45:01 AgeDB-30 Ave Accuracy: 95.2000 121 | 20181229-04:45:23 CFP-FP Ave Accuracy: 90.8286 122 | 20181229-04:45:23 Current Best Accuracy: LFW: 99.4333 in iters: 100000, AgeDB-30: 95.3500 in iters: 105000 and CFP-FP: 91.1429 in iters: 105000 123 | 20181229-05:15:59 Saving checkpoint: 115000 124 | 20181229-05:16:19 LFW Ave Accuracy: 99.3167 125 | 20181229-05:16:38 AgeDB-30 Ave Accuracy: 95.4167 126 | 20181229-05:17:00 CFP-FP Ave Accuracy: 90.5286 127 | 20181229-05:17:00 Current Best Accuracy: LFW: 99.4333 in iters: 100000, AgeDB-30: 95.4167 in iters: 115000 and CFP-FP: 91.1429 in iters: 105000 128 | 20181229-05:33:34 Train Epoch: 13/30 ... 129 | 20181229-05:47:47 Saving checkpoint: 120000 130 | 20181229-05:48:07 LFW Ave Accuracy: 99.3833 131 | 20181229-05:48:27 AgeDB-30 Ave Accuracy: 95.6500 132 | 20181229-05:48:50 CFP-FP Ave Accuracy: 90.4714 133 | 20181229-05:48:50 Current Best Accuracy: LFW: 99.4333 in iters: 100000, AgeDB-30: 95.6500 in iters: 120000 and CFP-FP: 91.1429 in iters: 105000 134 | 20181229-06:19:36 Saving checkpoint: 125000 135 | 20181229-06:19:56 LFW Ave Accuracy: 99.3333 136 | 20181229-06:20:16 AgeDB-30 Ave Accuracy: 95.3333 137 | 20181229-06:20:38 CFP-FP Ave Accuracy: 91.3143 138 | 20181229-06:20:38 Current Best Accuracy: LFW: 99.4333 in iters: 100000, AgeDB-30: 95.6500 in iters: 120000 and CFP-FP: 91.3143 in iters: 125000 139 | 20181229-06:36:08 Train Epoch: 14/30 ... 
140 | 20181229-06:51:23 Saving checkpoint: 130000 141 | 20181229-06:51:44 LFW Ave Accuracy: 99.2667 142 | 20181229-06:52:03 AgeDB-30 Ave Accuracy: 95.2500 143 | 20181229-06:52:24 CFP-FP Ave Accuracy: 90.7571 144 | 20181229-06:52:24 Current Best Accuracy: LFW: 99.4333 in iters: 100000, AgeDB-30: 95.6500 in iters: 120000 and CFP-FP: 91.3143 in iters: 125000 145 | 20181229-07:22:58 Saving checkpoint: 135000 146 | 20181229-07:23:18 LFW Ave Accuracy: 99.4333 147 | 20181229-07:23:37 AgeDB-30 Ave Accuracy: 95.3667 148 | 20181229-07:23:59 CFP-FP Ave Accuracy: 90.9000 149 | 20181229-07:23:59 Current Best Accuracy: LFW: 99.4333 in iters: 100000, AgeDB-30: 95.6500 in iters: 120000 and CFP-FP: 91.3143 in iters: 125000 150 | 20181229-07:38:14 Train Epoch: 15/30 ... 151 | 20181229-07:54:39 Saving checkpoint: 140000 152 | 20181229-07:54:59 LFW Ave Accuracy: 99.4167 153 | 20181229-07:55:18 AgeDB-30 Ave Accuracy: 95.5167 154 | 20181229-07:55:40 CFP-FP Ave Accuracy: 91.2286 155 | 20181229-07:55:40 Current Best Accuracy: LFW: 99.4333 in iters: 100000, AgeDB-30: 95.6500 in iters: 120000 and CFP-FP: 91.3143 in iters: 125000 156 | 20181229-08:26:12 Saving checkpoint: 145000 157 | 20181229-08:26:32 LFW Ave Accuracy: 99.3667 158 | 20181229-08:26:51 AgeDB-30 Ave Accuracy: 95.5000 159 | 20181229-08:27:13 CFP-FP Ave Accuracy: 91.3571 160 | 20181229-08:27:13 Current Best Accuracy: LFW: 99.4333 in iters: 100000, AgeDB-30: 95.6500 in iters: 120000 and CFP-FP: 91.3571 in iters: 145000 161 | 20181229-08:40:15 Train Epoch: 16/30 ... 162 | 20181229-08:57:48 Saving checkpoint: 150000 163 | 20181229-08:58:07 LFW Ave Accuracy: 99.3167 164 | 20181229-08:58:27 AgeDB-30 Ave Accuracy: 95.5500 165 | 20181229-08:58:48 CFP-FP Ave Accuracy: 91.1857 166 | 20181229-08:58:48 Current Best Accuracy: LFW: 99.4333 in iters: 100000, AgeDB-30: 95.6500 in iters: 120000 and CFP-FP: 91.3571 in iters: 145000 167 | 20181229-09:29:19 Saving checkpoint: 155000 168 | 20181229-09:29:39 LFW Ave Accuracy: 99.3500 169 | 20181229-09:29:59 AgeDB-30 Ave Accuracy: 95.5000 170 | 20181229-09:30:22 CFP-FP Ave Accuracy: 91.3000 171 | 20181229-09:30:22 Current Best Accuracy: LFW: 99.4333 in iters: 100000, AgeDB-30: 95.6500 in iters: 120000 and CFP-FP: 91.3571 in iters: 145000 172 | 20181229-09:42:12 Train Epoch: 17/30 ... 173 | 20181229-10:00:53 Saving checkpoint: 160000 174 | 20181229-10:01:13 LFW Ave Accuracy: 99.5167 175 | 20181229-10:01:35 AgeDB-30 Ave Accuracy: 95.3333 176 | 20181229-10:01:56 CFP-FP Ave Accuracy: 91.6714 177 | 20181229-10:01:56 Current Best Accuracy: LFW: 99.5167 in iters: 160000, AgeDB-30: 95.6500 in iters: 120000 and CFP-FP: 91.6714 in iters: 160000 178 | 20181229-10:32:23 Saving checkpoint: 165000 179 | 20181229-10:32:44 LFW Ave Accuracy: 99.4000 180 | 20181229-10:33:03 AgeDB-30 Ave Accuracy: 95.4167 181 | 20181229-10:33:24 CFP-FP Ave Accuracy: 90.5714 182 | 20181229-10:33:24 Current Best Accuracy: LFW: 99.5167 in iters: 160000, AgeDB-30: 95.6500 in iters: 120000 and CFP-FP: 91.6714 in iters: 160000 183 | 20181229-10:44:05 Train Epoch: 18/30 ... 
184 | 20181229-11:03:53 Saving checkpoint: 170000 185 | 20181229-11:04:13 LFW Ave Accuracy: 99.4000 186 | 20181229-11:04:34 AgeDB-30 Ave Accuracy: 95.6000 187 | 20181229-11:04:55 CFP-FP Ave Accuracy: 90.8571 188 | 20181229-11:04:55 Current Best Accuracy: LFW: 99.5167 in iters: 160000, AgeDB-30: 95.6500 in iters: 120000 and CFP-FP: 91.6714 in iters: 160000 189 | 20181229-11:35:21 Saving checkpoint: 175000 190 | 20181229-11:35:40 LFW Ave Accuracy: 99.4167 191 | 20181229-11:35:59 AgeDB-30 Ave Accuracy: 95.3333 192 | 20181229-11:36:22 CFP-FP Ave Accuracy: 91.1000 193 | 20181229-11:36:22 Current Best Accuracy: LFW: 99.5167 in iters: 160000, AgeDB-30: 95.6500 in iters: 120000 and CFP-FP: 91.6714 in iters: 160000 194 | 20181229-11:45:52 Train Epoch: 19/30 ... 195 | 20181229-12:06:48 Saving checkpoint: 180000 196 | 20181229-12:07:08 LFW Ave Accuracy: 99.5000 197 | 20181229-12:07:27 AgeDB-30 Ave Accuracy: 96.2167 198 | 20181229-12:07:48 CFP-FP Ave Accuracy: 92.1286 199 | 20181229-12:07:48 Current Best Accuracy: LFW: 99.5167 in iters: 160000, AgeDB-30: 96.2167 in iters: 180000 and CFP-FP: 92.1286 in iters: 180000 200 | 20181229-12:38:08 Saving checkpoint: 185000 201 | 20181229-12:38:28 LFW Ave Accuracy: 99.5333 202 | 20181229-12:38:47 AgeDB-30 Ave Accuracy: 96.3833 203 | 20181229-12:39:10 CFP-FP Ave Accuracy: 92.2714 204 | 20181229-12:39:10 Current Best Accuracy: LFW: 99.5333 in iters: 185000, AgeDB-30: 96.3833 in iters: 185000 and CFP-FP: 92.2714 in iters: 185000 205 | 20181229-12:47:30 Train Epoch: 20/30 ... 206 | 20181229-13:09:32 Saving checkpoint: 190000 207 | 20181229-13:09:52 LFW Ave Accuracy: 99.5667 208 | 20181229-13:10:12 AgeDB-30 Ave Accuracy: 96.2167 209 | 20181229-13:10:35 CFP-FP Ave Accuracy: 92.5571 210 | 20181229-13:10:35 Current Best Accuracy: LFW: 99.5667 in iters: 190000, AgeDB-30: 96.3833 in iters: 185000 and CFP-FP: 92.5571 in iters: 190000 211 | 20181229-13:40:57 Saving checkpoint: 195000 212 | 20181229-13:41:16 LFW Ave Accuracy: 99.5500 213 | 20181229-13:41:37 AgeDB-30 Ave Accuracy: 96.3667 214 | 20181229-13:41:58 CFP-FP Ave Accuracy: 92.5857 215 | 20181229-13:41:58 Current Best Accuracy: LFW: 99.5667 in iters: 190000, AgeDB-30: 96.3833 in iters: 185000 and CFP-FP: 92.5857 in iters: 195000 216 | 20181229-13:49:09 Train Epoch: 21/30 ... 217 | 20181229-14:12:22 Saving checkpoint: 200000 218 | 20181229-14:12:42 LFW Ave Accuracy: 99.4667 219 | 20181229-14:13:01 AgeDB-30 Ave Accuracy: 96.3500 220 | 20181229-14:13:22 CFP-FP Ave Accuracy: 92.5571 221 | 20181229-14:13:22 Current Best Accuracy: LFW: 99.5667 in iters: 190000, AgeDB-30: 96.3833 in iters: 185000 and CFP-FP: 92.5857 in iters: 195000 222 | 20181229-14:43:41 Saving checkpoint: 205000 223 | 20181229-14:44:00 LFW Ave Accuracy: 99.5333 224 | 20181229-14:44:19 AgeDB-30 Ave Accuracy: 96.1833 225 | 20181229-14:44:40 CFP-FP Ave Accuracy: 92.7429 226 | 20181229-14:44:40 Current Best Accuracy: LFW: 99.5667 in iters: 190000, AgeDB-30: 96.3833 in iters: 185000 and CFP-FP: 92.7429 in iters: 205000 227 | 20181229-14:50:40 Train Epoch: 22/30 ... 
228 | 20181229-15:15:03 Saving checkpoint: 210000 229 | 20181229-15:15:23 LFW Ave Accuracy: 99.5167 230 | 20181229-15:15:44 AgeDB-30 Ave Accuracy: 96.1333 231 | 20181229-15:16:06 CFP-FP Ave Accuracy: 92.7571 232 | 20181229-15:16:06 Current Best Accuracy: LFW: 99.5667 in iters: 190000, AgeDB-30: 96.3833 in iters: 185000 and CFP-FP: 92.7571 in iters: 210000 233 | 20181229-15:46:26 Saving checkpoint: 215000 234 | 20181229-15:46:46 LFW Ave Accuracy: 99.4833 235 | 20181229-15:47:07 AgeDB-30 Ave Accuracy: 96.4500 236 | 20181229-15:47:28 CFP-FP Ave Accuracy: 92.6571 237 | 20181229-15:47:28 Current Best Accuracy: LFW: 99.5667 in iters: 190000, AgeDB-30: 96.4500 in iters: 215000 and CFP-FP: 92.7571 in iters: 210000 238 | 20181229-15:52:19 Train Epoch: 23/30 ... 239 | 20181229-16:17:53 Saving checkpoint: 220000 240 | 20181229-16:18:12 LFW Ave Accuracy: 99.5500 241 | 20181229-16:18:31 AgeDB-30 Ave Accuracy: 96.0833 242 | 20181229-16:18:52 CFP-FP Ave Accuracy: 92.6000 243 | 20181229-16:18:52 Current Best Accuracy: LFW: 99.5667 in iters: 190000, AgeDB-30: 96.4500 in iters: 215000 and CFP-FP: 92.7571 in iters: 210000 244 | 20181229-16:49:11 Saving checkpoint: 225000 245 | 20181229-16:49:30 LFW Ave Accuracy: 99.5833 246 | 20181229-16:49:50 AgeDB-30 Ave Accuracy: 96.3500 247 | 20181229-16:50:11 CFP-FP Ave Accuracy: 92.9000 248 | 20181229-16:50:11 Current Best Accuracy: LFW: 99.5833 in iters: 225000, AgeDB-30: 96.4500 in iters: 215000 and CFP-FP: 92.9000 in iters: 225000 249 | 20181229-16:53:53 Train Epoch: 24/30 ... 250 | 20181229-17:20:36 Saving checkpoint: 230000 251 | 20181229-17:20:57 LFW Ave Accuracy: 99.5167 252 | 20181229-17:21:16 AgeDB-30 Ave Accuracy: 96.5667 253 | 20181229-17:21:37 CFP-FP Ave Accuracy: 92.6714 254 | 20181229-17:21:37 Current Best Accuracy: LFW: 99.5833 in iters: 225000, AgeDB-30: 96.5667 in iters: 230000 and CFP-FP: 92.9000 in iters: 225000 255 | 20181229-17:52:01 Saving checkpoint: 235000 256 | 20181229-17:52:21 LFW Ave Accuracy: 99.5833 257 | 20181229-17:52:41 AgeDB-30 Ave Accuracy: 96.2500 258 | 20181229-17:53:03 CFP-FP Ave Accuracy: 92.7571 259 | 20181229-17:53:03 Current Best Accuracy: LFW: 99.5833 in iters: 225000, AgeDB-30: 96.5667 in iters: 230000 and CFP-FP: 92.9000 in iters: 225000 260 | 20181229-17:55:35 Train Epoch: 25/30 ... 261 | 20181229-18:23:30 Saving checkpoint: 240000 262 | 20181229-18:23:50 LFW Ave Accuracy: 99.5333 263 | 20181229-18:24:10 AgeDB-30 Ave Accuracy: 96.2833 264 | 20181229-18:24:32 CFP-FP Ave Accuracy: 92.4143 265 | 20181229-18:24:32 Current Best Accuracy: LFW: 99.5833 in iters: 225000, AgeDB-30: 96.5667 in iters: 230000 and CFP-FP: 92.9000 in iters: 225000 266 | 20181229-18:54:59 Saving checkpoint: 245000 267 | 20181229-18:55:19 LFW Ave Accuracy: 99.5167 268 | 20181229-18:55:39 AgeDB-30 Ave Accuracy: 96.1667 269 | 20181229-18:56:00 CFP-FP Ave Accuracy: 92.8143 270 | 20181229-18:56:00 Current Best Accuracy: LFW: 99.5833 in iters: 225000, AgeDB-30: 96.5667 in iters: 230000 and CFP-FP: 92.9000 in iters: 225000 271 | 20181229-18:57:22 Train Epoch: 26/30 ... 
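Logs like the one above follow a fixed layout (a timestamp, then a "Train Epoch", "Saving checkpoint" or "<benchmark> Ave Accuracy" message), so the evaluation curves can be recovered from the text directly. A minimal sketch that pulls the LFW, AgeDB-30 and CFP-FP accuracies out of such a log; the path below is the MobileFaceNet log shipped in this repository, substitute the log of your own run:

```python
import re
from collections import defaultdict

pattern = re.compile(r'(LFW|AgeDB-30|CFP-FP) Ave Accuracy: ([0-9.]+)')
curves = defaultdict(list)

# any log written by train.py works here
with open('model/MSCeleb_MOBILEFACE_20181228_170458/log.log') as f:
    for line in f:
        match = pattern.search(line)
        if match:
            curves[match.group(1)].append(float(match.group(2)))

for name, accs in curves.items():
    best = max(accs)
    print('{}: best {:.4f} at evaluation #{}'.format(name, best, accs.index(best) + 1))
```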
272 | 20181229-19:26:30 Saving checkpoint: 250000 273 | 20181229-19:26:50 LFW Ave Accuracy: 99.5667 274 | 20181229-19:27:10 AgeDB-30 Ave Accuracy: 96.4500 275 | 20181229-19:27:32 CFP-FP Ave Accuracy: 92.6000 276 | 20181229-19:27:32 Current Best Accuracy: LFW: 99.5833 in iters: 225000, AgeDB-30: 96.5667 in iters: 230000 and CFP-FP: 92.9000 in iters: 225000 277 | -------------------------------------------------------------------------------- /result/softmax.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wujiyang/Face_Pytorch/3afd941c01cae3eb73b66a48fa984f41bd6662fc/result/softmax.gif -------------------------------------------------------------------------------- /result/softmax_center.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wujiyang/Face_Pytorch/3afd941c01cae3eb73b66a48fa984f41bd6662fc/result/softmax_center.gif -------------------------------------------------------------------------------- /result/visualization.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wujiyang/Face_Pytorch/3afd941c01cae3eb73b66a48fa984f41bd6662fc/result/visualization.jpg -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # encoding: utf-8 3 | ''' 4 | @author: wujiyang 5 | @contact: wujiyang@hust.edu.cn 6 | @file: train.py.py 7 | @time: 2018/12/21 17:37 8 | @desc: train script for deep face recognition 9 | ''' 10 | 11 | import os 12 | import torch.utils.data 13 | from torch.nn import DataParallel 14 | from datetime import datetime 15 | from backbone.mobilefacenet import MobileFaceNet 16 | from backbone.cbam import CBAMResNet 17 | from backbone.attention import ResidualAttentionNet_56, ResidualAttentionNet_92 18 | from margin.ArcMarginProduct import ArcMarginProduct 19 | from margin.MultiMarginProduct import MultiMarginProduct 20 | from margin.CosineMarginProduct import CosineMarginProduct 21 | from margin.InnerProduct import InnerProduct 22 | from utils.visualize import Visualizer 23 | from utils.logging import init_log 24 | from dataset.casia_webface import CASIAWebFace 25 | from dataset.lfw import LFW 26 | from dataset.agedb import AgeDB30 27 | from dataset.cfp import CFP_FP 28 | from torch.optim import lr_scheduler 29 | import torch.optim as optim 30 | import time 31 | from eval_lfw import evaluation_10_fold, getFeatureFromTorch 32 | import numpy as np 33 | import torchvision.transforms as transforms 34 | import argparse 35 | 36 | 37 | def train(args): 38 | # gpu init 39 | multi_gpus = False 40 | if len(args.gpus.split(',')) > 1: 41 | multi_gpus = True 42 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpus 43 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 44 | 45 | # log init 46 | save_dir = os.path.join(args.save_dir, args.model_pre + args.backbone.upper() + '_' + datetime.now().strftime('%Y%m%d_%H%M%S')) 47 | if os.path.exists(save_dir): 48 | raise NameError('model dir exists!') 49 | os.makedirs(save_dir) 50 | logging = init_log(save_dir) 51 | _print = logging.info 52 | 53 | # dataset loader 54 | transform = transforms.Compose([ 55 | transforms.ToTensor(), # range [0, 255] -> [0.0,1.0] 56 | transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)) # range [0.0, 1.0] -> [-1.0,1.0] 57 | ]) 58 | # validation dataset 
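The training and evaluation loaders created next all reuse this transform: images are converted to tensors and mapped into the [-1, 1] range used throughout the repository. A quick way to confirm what the composed transform produces; the solid-color 112x112 image is only a stand-in for an aligned face crop:

```python
import torchvision.transforms as transforms
from PIL import Image

transform = transforms.Compose([
    transforms.ToTensor(),                                            # uint8 [0, 255] -> float [0.0, 1.0]
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))   # [0.0, 1.0] -> [-1.0, 1.0]
])

img = Image.new('RGB', (112, 112), color=(255, 128, 0))  # stand-in for an aligned 112x112 face crop
x = transform(img)
print(x.shape)                          # torch.Size([3, 112, 112])
print(x.min().item(), x.max().item())   # approximately -1.0 and 1.0
```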
59 | trainset = CASIAWebFace(args.train_root, args.train_file_list, transform=transform) 60 | trainloader = torch.utils.data.DataLoader(trainset, batch_size=args.batch_size, 61 | shuffle=True, num_workers=8, drop_last=False) 62 | # test dataset 63 | lfwdataset = LFW(args.lfw_test_root, args.lfw_file_list, transform=transform) 64 | lfwloader = torch.utils.data.DataLoader(lfwdataset, batch_size=128, 65 | shuffle=False, num_workers=4, drop_last=False) 66 | agedbdataset = AgeDB30(args.agedb_test_root, args.agedb_file_list, transform=transform) 67 | agedbloader = torch.utils.data.DataLoader(agedbdataset, batch_size=128, 68 | shuffle=False, num_workers=4, drop_last=False) 69 | cfpfpdataset = CFP_FP(args.cfpfp_test_root, args.cfpfp_file_list, transform=transform) 70 | cfpfploader = torch.utils.data.DataLoader(cfpfpdataset, batch_size=128, 71 | shuffle=False, num_workers=4, drop_last=False) 72 | 73 | # define backbone and margin layer 74 | if args.backbone == 'MobileFace': 75 | net = MobileFaceNet() 76 | elif args.backbone == 'Res50_IR': 77 | net = CBAMResNet(50, feature_dim=args.feature_dim, mode='ir') 78 | elif args.backbone == 'SERes50_IR': 79 | net = CBAMResNet(50, feature_dim=args.feature_dim, mode='ir_se') 80 | elif args.backbone == 'Res100_IR': 81 | net = CBAMResNet(100, feature_dim=args.feature_dim, mode='ir') 82 | elif args.backbone == 'SERes100_IR': 83 | net = CBAMResNet(100, feature_dim=args.feature_dim, mode='ir_se') 84 | elif args.backbone == 'Attention_56': 85 | net = ResidualAttentionNet_56(feature_dim=args.feature_dim) 86 | elif args.backbone == 'Attention_92': 87 | net = ResidualAttentionNet_92(feature_dim=args.feature_dim) 88 | else: 89 | print(args.backbone, ' is not available!') 90 | 91 | if args.margin_type == 'ArcFace': 92 | margin = ArcMarginProduct(args.feature_dim, trainset.class_nums, s=args.scale_size) 93 | elif args.margin_type == 'MultiMargin': 94 | margin = MultiMarginProduct(args.feature_dim, trainset.class_nums, s=args.scale_size) 95 | elif args.margin_type == 'CosFace': 96 | margin = CosineMarginProduct(args.feature_dim, trainset.class_nums, s=args.scale_size) 97 | elif args.margin_type == 'Softmax': 98 | margin = InnerProduct(args.feature_dim, trainset.class_nums) 99 | elif args.margin_type == 'SphereFace': 100 | pass 101 | else: 102 | print(args.margin_type, 'is not available!') 103 | 104 | if args.resume: 105 | print('resume the model parameters from: ', args.net_path, args.margin_path) 106 | net.load_state_dict(torch.load(args.net_path)['net_state_dict']) 107 | margin.load_state_dict(torch.load(args.margin_path)['net_state_dict']) 108 | 109 | # define optimizers for different layer 110 | criterion = torch.nn.CrossEntropyLoss().to(device) 111 | optimizer_ft = optim.SGD([ 112 | {'params': net.parameters(), 'weight_decay': 5e-4}, 113 | {'params': margin.parameters(), 'weight_decay': 5e-4} 114 | ], lr=0.1, momentum=0.9, nesterov=True) 115 | exp_lr_scheduler = lr_scheduler.MultiStepLR(optimizer_ft, milestones=[6, 11, 16], gamma=0.1) 116 | 117 | if multi_gpus: 118 | net = DataParallel(net).to(device) 119 | margin = DataParallel(margin).to(device) 120 | else: 121 | net = net.to(device) 122 | margin = margin.to(device) 123 | 124 | 125 | best_lfw_acc = 0.0 126 | best_lfw_iters = 0 127 | best_agedb30_acc = 0.0 128 | best_agedb30_iters = 0 129 | best_cfp_fp_acc = 0.0 130 | best_cfp_fp_iters = 0 131 | total_iters = 0 132 | vis = Visualizer(env=args.model_pre + args.backbone) 133 | for epoch in range(1, args.total_epoch + 1): 134 | exp_lr_scheduler.step() 135 | # train 
model 136 | _print('Train Epoch: {}/{} ...'.format(epoch, args.total_epoch)) 137 | net.train() 138 | 139 | since = time.time() 140 | for data in trainloader: 141 | img, label = data[0].to(device), data[1].to(device) 142 | optimizer_ft.zero_grad() 143 | 144 | raw_logits = net(img) 145 | output = margin(raw_logits, label) 146 | total_loss = criterion(output, label) 147 | total_loss.backward() 148 | optimizer_ft.step() 149 | 150 | total_iters += 1 151 | # print train information 152 | if total_iters % 100 == 0: 153 | # current training accuracy 154 | _, predict = torch.max(output.data, 1) 155 | total = label.size(0) 156 | correct = (np.array(predict.cpu()) == np.array(label.data.cpu())).sum() 157 | time_cur = (time.time() - since) / 100 158 | since = time.time() 159 | vis.plot_curves({'softmax loss': total_loss.item()}, iters=total_iters, title='train loss', 160 | xlabel='iters', ylabel='train loss') 161 | vis.plot_curves({'train accuracy': correct / total}, iters=total_iters, title='train accuracy', xlabel='iters', 162 | ylabel='train accuracy') 163 | 164 | _print("Iters: {:0>6d}/[{:0>2d}], loss: {:.4f}, train_accuracy: {:.4f}, time: {:.2f} s/iter, learning rate: {}".format(total_iters, epoch, total_loss.item(), correct/total, time_cur, exp_lr_scheduler.get_lr()[0])) 165 | 166 | # save model 167 | if total_iters % args.save_freq == 0: 168 | msg = 'Saving checkpoint: {}'.format(total_iters) 169 | _print(msg) 170 | if multi_gpus: 171 | net_state_dict = net.module.state_dict() 172 | margin_state_dict = margin.module.state_dict() 173 | else: 174 | net_state_dict = net.state_dict() 175 | margin_state_dict = margin.state_dict() 176 | if not os.path.exists(save_dir): 177 | os.mkdir(save_dir) 178 | torch.save({ 179 | 'iters': total_iters, 180 | 'net_state_dict': net_state_dict}, 181 | os.path.join(save_dir, 'Iter_%06d_net.ckpt' % total_iters)) 182 | torch.save({ 183 | 'iters': total_iters, 184 | 'net_state_dict': margin_state_dict}, 185 | os.path.join(save_dir, 'Iter_%06d_margin.ckpt' % total_iters)) 186 | 187 | # test accuracy 188 | if total_iters % args.test_freq == 0: 189 | 190 | # test model on lfw 191 | net.eval() 192 | getFeatureFromTorch('./result/cur_lfw_result.mat', net, device, lfwdataset, lfwloader) 193 | lfw_accs = evaluation_10_fold('./result/cur_lfw_result.mat') 194 | _print('LFW Ave Accuracy: {:.4f}'.format(np.mean(lfw_accs) * 100)) 195 | if best_lfw_acc <= np.mean(lfw_accs) * 100: 196 | best_lfw_acc = np.mean(lfw_accs) * 100 197 | best_lfw_iters = total_iters 198 | 199 | # test model on AgeDB30 200 | getFeatureFromTorch('./result/cur_agedb30_result.mat', net, device, agedbdataset, agedbloader) 201 | age_accs = evaluation_10_fold('./result/cur_agedb30_result.mat') 202 | _print('AgeDB-30 Ave Accuracy: {:.4f}'.format(np.mean(age_accs) * 100)) 203 | if best_agedb30_acc <= np.mean(age_accs) * 100: 204 | best_agedb30_acc = np.mean(age_accs) * 100 205 | best_agedb30_iters = total_iters 206 | 207 | # test model on CFP-FP 208 | getFeatureFromTorch('./result/cur_cfpfp_result.mat', net, device, cfpfpdataset, cfpfploader) 209 | cfp_accs = evaluation_10_fold('./result/cur_cfpfp_result.mat') 210 | _print('CFP-FP Ave Accuracy: {:.4f}'.format(np.mean(cfp_accs) * 100)) 211 | if best_cfp_fp_acc <= np.mean(cfp_accs) * 100: 212 | best_cfp_fp_acc = np.mean(cfp_accs) * 100 213 | best_cfp_fp_iters = total_iters 214 | _print('Current Best Accuracy: LFW: {:.4f} in iters: {}, AgeDB-30: {:.4f} in iters: {} and CFP-FP: {:.4f} in iters: {}'.format( 215 | best_lfw_acc, best_lfw_iters, best_agedb30_acc, 
best_agedb30_iters, best_cfp_fp_acc, best_cfp_fp_iters)) 216 | 217 | vis.plot_curves({'lfw': np.mean(lfw_accs), 'agedb-30': np.mean(age_accs), 'cfp-fp': np.mean(cfp_accs)}, iters=total_iters, 218 | title='test accuracy', xlabel='iters', ylabel='test accuracy') 219 | net.train() 220 | 221 | _print('Finally Best Accuracy: LFW: {:.4f} in iters: {}, AgeDB-30: {:.4f} in iters: {} and CFP-FP: {:.4f} in iters: {}'.format( 222 | best_lfw_acc, best_lfw_iters, best_agedb30_acc, best_agedb30_iters, best_cfp_fp_acc, best_cfp_fp_iters)) 223 | print('finishing training') 224 | 225 | 226 | if __name__ == '__main__': 227 | parser = argparse.ArgumentParser(description='PyTorch for deep face recognition') 228 | parser.add_argument('--train_root', type=str, default='/media/ramdisk/msra_align_112', help='train image root') 229 | parser.add_argument('--train_file_list', type=str, default='/media/ramdisk/msra_align_train.list', help='train list') 230 | parser.add_argument('--lfw_test_root', type=str, default='/media/sda/lfw/lfw_align_112', help='lfw image root') 231 | parser.add_argument('--lfw_file_list', type=str, default='/media/sda/lfw/pairs.txt', help='lfw pair file list') 232 | parser.add_argument('--agedb_test_root', type=str, default='/media/sda/AgeDB-30/agedb30_align_112', help='agedb image root') 233 | parser.add_argument('--agedb_file_list', type=str, default='/media/sda/AgeDB-30/agedb_30_pair.txt', help='agedb pair file list') 234 | parser.add_argument('--cfpfp_test_root', type=str, default='/media/sda/CFP-FP/cfp_fp_aligned_112', help='agedb image root') 235 | parser.add_argument('--cfpfp_file_list', type=str, default='/media/sda/CFP-FP/cfp_fp_pair.txt', help='agedb pair file list') 236 | 237 | parser.add_argument('--backbone', type=str, default='SERes100_IR', help='MobileFace, Res50_IR, SERes50_IR, Res100_IR, SERes100_IR, Attention_56, Attention_92') 238 | parser.add_argument('--margin_type', type=str, default='ArcFace', help='ArcFace, CosFace, SphereFace, MultiMargin, Softmax') 239 | parser.add_argument('--feature_dim', type=int, default=512, help='feature dimension, 128 or 512') 240 | parser.add_argument('--scale_size', type=float, default=32.0, help='scale size') 241 | parser.add_argument('--batch_size', type=int, default=200, help='batch size') 242 | parser.add_argument('--total_epoch', type=int, default=18, help='total epochs') 243 | 244 | parser.add_argument('--save_freq', type=int, default=3000, help='save frequency') 245 | parser.add_argument('--test_freq', type=int, default=3000, help='test frequency') 246 | parser.add_argument('--resume', type=int, default=False, help='resume model') 247 | parser.add_argument('--net_path', type=str, default='', help='resume model') 248 | parser.add_argument('--margin_path', type=str, default='', help='resume model') 249 | parser.add_argument('--save_dir', type=str, default='./model', help='model save dir') 250 | parser.add_argument('--model_pre', type=str, default='SERES100_', help='model prefix') 251 | parser.add_argument('--gpus', type=str, default='0,1,2,3', help='model prefix') 252 | 253 | args = parser.parse_args() 254 | 255 | train(args) 256 | 257 | 258 | -------------------------------------------------------------------------------- /train_center.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # encoding: utf-8 3 | ''' 4 | @author: wujiyang 5 | @contact: wujiyang@hust.edu.cn 6 | @file: train_center.py 7 | @time: 2019/1/3 11:12 8 | @desc: train script for my attention net and center 
loss 9 | ''' 10 | 11 | ''' 12 | Pleause use the train.py for your training process. 13 | ''' 14 | 15 | import os 16 | import torch.utils.data 17 | from torch.nn import DataParallel 18 | from datetime import datetime 19 | from backbone.mobilefacenet import MobileFaceNet 20 | from backbone.resnet import ResNet50, ResNet101 21 | from backbone.arcfacenet import SEResNet_IR 22 | from backbone.spherenet import SphereNet 23 | from margin.ArcMarginProduct import ArcMarginProduct 24 | from margin.InnerProduct import InnerProduct 25 | from lossfunctions.centerloss import CenterLoss 26 | from utils.logging import init_log 27 | from dataset.casia_webface import CASIAWebFace 28 | from dataset.lfw import LFW 29 | from dataset.agedb import AgeDB30 30 | from dataset.cfp import CFP_FP 31 | from utils.visualize import Visualizer 32 | from torch.optim import lr_scheduler 33 | import torch.optim as optim 34 | import time 35 | from eval_lfw import evaluation_10_fold, getFeatureFromTorch 36 | import numpy as np 37 | import torchvision.transforms as transforms 38 | import argparse 39 | 40 | 41 | def train(args): 42 | # gpu init 43 | multi_gpus = False 44 | if len(args.gpus.split(',')) > 1: 45 | multi_gpus = True 46 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpus 47 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 48 | 49 | # log init 50 | save_dir = os.path.join(args.save_dir, args.model_pre + args.backbone.upper() + '_' + datetime.now().strftime('%Y%m%d_%H%M%S')) 51 | if os.path.exists(save_dir): 52 | raise NameError('model dir exists!') 53 | os.makedirs(save_dir) 54 | logging = init_log(save_dir) 55 | _print = logging.info 56 | 57 | # dataset loader 58 | transform = transforms.Compose([ 59 | transforms.ToTensor(), # range [0, 255] -> [0.0,1.0] 60 | transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)) # range [0.0, 1.0] -> [-1.0,1.0] 61 | ]) 62 | # validation dataset 63 | trainset = CASIAWebFace(args.train_root, args.train_file_list, transform=transform) 64 | trainloader = torch.utils.data.DataLoader(trainset, batch_size=args.batch_size, 65 | shuffle=True, num_workers=8, drop_last=False) 66 | # test dataset 67 | lfwdataset = LFW(args.lfw_test_root, args.lfw_file_list, transform=transform) 68 | lfwloader = torch.utils.data.DataLoader(lfwdataset, batch_size=128, 69 | shuffle=False, num_workers=4, drop_last=False) 70 | agedbdataset = AgeDB30(args.agedb_test_root, args.agedb_file_list, transform=transform) 71 | agedbloader = torch.utils.data.DataLoader(agedbdataset, batch_size=128, 72 | shuffle=False, num_workers=4, drop_last=False) 73 | cfpfpdataset = CFP_FP(args.cfpfp_test_root, args.cfpfp_file_list, transform=transform) 74 | cfpfploader = torch.utils.data.DataLoader(cfpfpdataset, batch_size=128, 75 | shuffle=False, num_workers=4, drop_last=False) 76 | 77 | # define backbone and margin layer 78 | if args.backbone == 'MobileFace': 79 | net = MobileFaceNet() 80 | elif args.backbone == 'Res50': 81 | net = ResNet50() 82 | elif args.backbone == 'Res101': 83 | net = ResNet101() 84 | elif args.backbone == 'Res50_IR': 85 | net = SEResNet_IR(50, feature_dim=args.feature_dim, mode='ir') 86 | elif args.backbone == 'SERes50_IR': 87 | net = SEResNet_IR(50, feature_dim=args.feature_dim, mode='se_ir') 88 | elif args.backbone == 'SphereNet': 89 | net = SphereNet(num_layers=64, feature_dim=args.feature_dim) 90 | else: 91 | print(args.backbone, ' is not available!') 92 | 93 | if args.margin_type == 'ArcFace': 94 | margin = ArcMarginProduct(args.feature_dim, trainset.class_nums, s=args.scale_size) 
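The distinguishing piece of this script, the center-loss term, is commented out further down while the progress print still references loss_center, so as committed the first progress print would fail with a NameError; either restore the center-loss lines or drop that field from the print. A sketch of the intended joint-supervision step, assembled from the commented-out code in this file (net, margin, trainloader, trainset, device and args are the objects already set up by the script):

```python
# classification loss on the margin head plus a weighted center loss on the raw features
criterion_classi = torch.nn.CrossEntropyLoss().to(device)
criterion_center = CenterLoss(trainset.class_nums, args.feature_dim).to(device)

optimizer_classi = optim.SGD([
    {'params': net.parameters(), 'weight_decay': 5e-4},
    {'params': margin.parameters(), 'weight_decay': 5e-4}
], lr=0.1, momentum=0.9, nesterov=True)
optimizer_center = optim.SGD(criterion_center.parameters(), lr=0.5)  # the class centers get their own optimizer

for data in trainloader:
    img, label = data[0].to(device), data[1].to(device)
    feature = net(img)
    output = margin(feature, label)        # the margin heads expect the label as a second argument
    loss_classi = criterion_classi(output, label)
    loss_center = criterion_center(feature, label)
    total_loss = loss_classi + loss_center * args.weight_center

    optimizer_classi.zero_grad()
    optimizer_center.zero_grad()
    total_loss.backward()
    optimizer_classi.step()
    optimizer_center.step()
```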
95 | elif args.margin_type == 'CosFace': 96 | pass 97 | elif args.margin_type == 'SphereFace': 98 | pass 99 | elif args.margin_type == 'InnerProduct': 100 | margin = InnerProduct(args.feature_dim, trainset.class_nums) 101 | else: 102 | print(args.margin_type, 'is not available!') 103 | 104 | if args.resume: 105 | print('resume the model parameters from: ', args.net_path, args.margin_path) 106 | net.load_state_dict(torch.load(args.net_path)['net_state_dict']) 107 | margin.load_state_dict(torch.load(args.margin_path)['net_state_dict']) 108 | 109 | # define optimizers for different layers 110 | criterion_classi = torch.nn.CrossEntropyLoss().to(device) 111 | optimizer_classi = optim.SGD([ 112 | {'params': net.parameters(), 'weight_decay': 5e-4}, 113 | {'params': margin.parameters(), 'weight_decay': 5e-4} 114 | ], lr=0.1, momentum=0.9, nesterov=True) 115 | 116 | #criterion_center = CenterLoss(trainset.class_nums, args.feature_dim).to(device) 117 | #optimizer_center = optim.SGD(criterion_center.parameters(), lr=0.5) 118 | 119 | scheduler_classi = lr_scheduler.MultiStepLR(optimizer_classi, milestones=[25, 50, 65], gamma=0.1) 120 | 121 | if multi_gpus: 122 | net = DataParallel(net).to(device) 123 | margin = DataParallel(margin).to(device) 124 | else: 125 | net = net.to(device) 126 | margin = margin.to(device) 127 | 128 | best_lfw_acc = 0.0 129 | best_lfw_iters = 0 130 | best_agedb30_acc = 0.0 131 | best_agedb30_iters = 0 132 | best_cfp_fp_acc = 0.0 133 | best_cfp_fp_iters = 0 134 | total_iters = 0 135 | #vis = Visualizer(env='softmax_center_xavier') 136 | for epoch in range(1, args.total_epoch + 1): 137 | scheduler_classi.step() 138 | # train model 139 | _print('Train Epoch: {}/{} ...'.format(epoch, args.total_epoch)) 140 | net.train() 141 | 142 | since = time.time() 143 | for data in trainloader: 144 | img, label = data[0].to(device), data[1].to(device) 145 | feature = net(img) 146 | output = margin(feature) 147 | loss_classi = criterion_classi(output, label) 148 | #loss_center = criterion_center(feature, label) 149 | total_loss = loss_classi #+ loss_center * args.weight_center 150 | 151 | optimizer_classi.zero_grad() 152 | #optimizer_center.zero_grad() 153 | total_loss.backward() 154 | optimizer_classi.step() 155 | #optimizer_center.step() 156 | 157 | total_iters += 1 158 | # print train information 159 | if total_iters % 100 == 0: 160 | # current training accuracy 161 | _, predict = torch.max(output.data, 1) 162 | total = label.size(0) 163 | correct = (np.array(predict) == np.array(label.data)).sum() 164 | time_cur = (time.time() - since) / 100 165 | since = time.time() 166 | #vis.plot_curves({'softmax loss': loss_classi.item(), 'center loss': loss_center.item()}, iters=total_iters, title='train loss', xlabel='iters', ylabel='train loss') 167 | #vis.plot_curves({'train accuracy': correct / total}, iters=total_iters, title='train accuracy', xlabel='iters', ylabel='train accuracy') 168 | print("Iters: {:0>6d}/[{:0>2d}], loss_classi: {:.4f}, loss_center: {:.4f}, train_accuracy: {:.4f}, time: {:.2f} s/iter, learning rate: {}".format(total_iters, 169 | epoch, 170 | loss_classi.item(), 171 | loss_center.item(), 172 | correct/total, 173 | time_cur, 174 | scheduler_classi.get_lr()[ 175 | 0])) 176 | # save model 177 | if total_iters % args.save_freq == 0: 178 | msg = 'Saving checkpoint: {}'.format(total_iters) 179 | _print(msg) 180 | if multi_gpus: 181 | net_state_dict = net.module.state_dict() 182 | margin_state_dict = margin.module.state_dict() 183 | else: 184 | net_state_dict = net.state_dict() 185 | 
margin_state_dict = margin.state_dict() 186 | 187 | if not os.path.exists(save_dir): 188 | os.mkdir(save_dir) 189 | torch.save({ 190 | 'iters': total_iters, 191 | 'net_state_dict': net_state_dict}, 192 | os.path.join(save_dir, 'Iter_%06d_net.ckpt' % total_iters)) 193 | torch.save({ 194 | 'iters': total_iters, 195 | 'net_state_dict': margin_state_dict}, 196 | os.path.join(save_dir, 'Iter_%06d_margin.ckpt' % total_iters)) 197 | #torch.save({ 198 | # 'iters': total_iters, 199 | # 'net_state_dict': criterion_center.state_dict()}, 200 | # os.path.join(save_dir, 'Iter_%06d_center.ckpt' % total_iters)) 201 | 202 | # test accuracy 203 | if total_iters % args.test_freq == 0: 204 | 205 | # test model on lfw 206 | net.eval() 207 | getFeatureFromTorch('./result/cur_lfw_result.mat', net, device, lfwdataset, lfwloader) 208 | lfw_accs = evaluation_10_fold('./result/cur_lfw_result.mat') 209 | _print('LFW Ave Accuracy: {:.4f}'.format(np.mean(lfw_accs) * 100)) 210 | if best_lfw_acc < np.mean(lfw_accs) * 100: 211 | best_lfw_acc = np.mean(lfw_accs) * 100 212 | best_lfw_iters = total_iters 213 | 214 | # test model on AgeDB30 215 | getFeatureFromTorch('./result/cur_agedb30_result.mat', net, device, agedbdataset, agedbloader) 216 | age_accs = evaluation_10_fold('./result/cur_agedb30_result.mat') 217 | _print('AgeDB-30 Ave Accuracy: {:.4f}'.format(np.mean(age_accs) * 100)) 218 | if best_agedb30_acc < np.mean(age_accs) * 100: 219 | best_agedb30_acc = np.mean(age_accs) * 100 220 | best_agedb30_iters = total_iters 221 | 222 | # test model on CFP-FP 223 | getFeatureFromTorch('./result/cur_cfpfp_result.mat', net, device, cfpfpdataset, cfpfploader) 224 | cfp_accs = evaluation_10_fold('./result/cur_cfpfp_result.mat') 225 | _print('CFP-FP Ave Accuracy: {:.4f}'.format(np.mean(cfp_accs) * 100)) 226 | if best_cfp_fp_acc < np.mean(cfp_accs) * 100: 227 | best_cfp_fp_acc = np.mean(cfp_accs) * 100 228 | best_cfp_fp_iters = total_iters 229 | _print('Current Best Accuracy: LFW: {:.4f} in iters: {}, AgeDB-30: {:.4f} in iters: {} and CFP-FP: {:.4f} in iters: {}'.format( 230 | best_lfw_acc, best_lfw_iters, best_agedb30_acc, best_agedb30_iters, best_cfp_fp_acc, best_cfp_fp_iters)) 231 | 232 | #vis.plot_curves({'lfw': np.mean(lfw_accs), 'agedb-30': np.mean(age_accs), 'cfp-fp': np.mean(cfp_accs)}, iters=total_iters, 233 | # title='test accuracy', xlabel='iters', ylabel='test accuracy') 234 | net.train() 235 | 236 | _print('Finally Best Accuracy: LFW: {:.4f} in iters: {}, AgeDB-30: {:.4f} in iters: {} and CFP-FP: {:.4f} in iters: {}'.format( 237 | best_lfw_acc, best_lfw_iters, best_agedb30_acc, best_agedb30_iters, best_cfp_fp_acc, best_cfp_fp_iters)) 238 | print('finishing training') 239 | 240 | 241 | if __name__ == '__main__': 242 | parser = argparse.ArgumentParser(description='PyTorch for deep face recognition') 243 | parser.add_argument('--train_root', type=str, default='/media/ramdisk/webface_align_112', help='train image root') 244 | parser.add_argument('--train_file_list', type=str, default='/media/ramdisk/webface_align_train.list', help='train list') 245 | parser.add_argument('--lfw_test_root', type=str, default='/media/ramdisk/lfw_align_112', help='lfw image root') 246 | parser.add_argument('--lfw_file_list', type=str, default='/media/ramdisk/pairs.txt', help='lfw pair file list') 247 | parser.add_argument('--agedb_test_root', type=str, default='/media/sda/AgeDB-30/agedb30_align_112', help='agedb image root') 248 | parser.add_argument('--agedb_file_list', type=str, default='/media/sda/AgeDB-30/agedb_30_pair.txt', help='agedb 
pair file list') 249 | parser.add_argument('--cfpfp_test_root', type=str, default='/media/sda/CFP-FP/cfp_fp_aligned_112', help='agedb image root') 250 | parser.add_argument('--cfpfp_file_list', type=str, default='/media/sda/CFP-FP/cfp_fp_pair.txt', help='agedb pair file list') 251 | 252 | parser.add_argument('--backbone', type=str, default='MobileFace', help='MobileFace, Res50, Res101, Res50_IR, SERes50_IR, SphereNet') 253 | parser.add_argument('--margin_type', type=str, default='InnerProduct', help='InnerProduct, ArcFace, CosFace, SphereFace') 254 | parser.add_argument('--feature_dim', type=int, default=128, help='feature dimension, 128 or 512') 255 | parser.add_argument('--scale_size', type=float, default=32.0, help='scale size') 256 | parser.add_argument('--batch_size', type=int, default=256, help='batch size') 257 | parser.add_argument('--total_epoch', type=int, default=80, help='total epochs') 258 | parser.add_argument('--weight_center', type=float, default=0.01, help='center loss weight') 259 | 260 | parser.add_argument('--save_freq', type=int, default=2000, help='save frequency') 261 | parser.add_argument('--test_freq', type=int, default=2000, help='test frequency') 262 | parser.add_argument('--resume', type=int, default=False, help='resume model') 263 | parser.add_argument('--net_path', type=str, default='', help='resume model') 264 | parser.add_argument('--margin_path', type=str, default='', help='resume model') 265 | parser.add_argument('--save_dir', type=str, default='./model', help='model save dir') 266 | parser.add_argument('--model_pre', type=str, default='Softmax_Center_', help='model prefix') 267 | parser.add_argument('--gpus', type=str, default='0,1', help='model prefix') 268 | 269 | args = parser.parse_args() 270 | 271 | train(args) 272 | 273 | 274 | -------------------------------------------------------------------------------- /train_softmax.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # encoding: utf-8 3 | ''' 4 | @author: wujiyang 5 | @contact: wujiyang@hust.edu.cn 6 | @file: train_softmax.py 7 | @time: 2019/1/7 8:33 8 | @desc: original softmax training with Casia-Webface 9 | ''' 10 | ''' 11 | Pleause use the train.py for your training process. 
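Both legacy scripts (train_center.py and train_softmax.py) defer to train.py, whose behaviour is controlled entirely by the argparse flags defined at the bottom of each file. A typical invocation therefore looks like `python train.py --backbone MobileFace --margin_type ArcFace --feature_dim 128 --batch_size 256 --gpus 0,1 --train_root /path/to/webface_align_112 --train_file_list /path/to/webface_align_train.list`, where the paths are placeholders for your own data layout. Checkpoints are written every --save_freq iterations to ./model/<model_pre><BACKBONE>_<timestamp>/, and an interrupted run can be continued with `--resume 1 --net_path <...>_net.ckpt --margin_path <...>_margin.ckpt`.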
12 | ''' 13 | 14 | import os 15 | import torch.utils.data 16 | from torch.nn import DataParallel 17 | from datetime import datetime 18 | from backbone.mobilefacenet import MobileFaceNet 19 | from backbone.resnet import ResNet50, ResNet101 20 | from backbone.arcfacenet import SEResNet_IR 21 | from backbone.spherenet import SphereNet 22 | from margin.ArcMarginProduct import ArcMarginProduct 23 | from margin.InnerProduct import InnerProduct 24 | from utils.visualize import Visualizer 25 | from utils.logging import init_log 26 | from dataset.casia_webface import CASIAWebFace 27 | from dataset.lfw import LFW 28 | from dataset.agedb import AgeDB30 29 | from dataset.cfp import CFP_FP 30 | from torch.optim import lr_scheduler 31 | import torch.optim as optim 32 | import time 33 | from eval_lfw import evaluation_10_fold, getFeatureFromTorch 34 | import numpy as np 35 | import torchvision.transforms as transforms 36 | import argparse 37 | 38 | def train(args): 39 | # gpu init 40 | multi_gpus = False 41 | if len(args.gpus.split(',')) > 1: 42 | multi_gpus = True 43 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpus 44 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 45 | 46 | # log init 47 | save_dir = os.path.join(args.save_dir, args.model_pre + args.backbone.upper() + '_' + datetime.now().strftime('%Y%m%d_%H%M%S')) 48 | if os.path.exists(save_dir): 49 | raise NameError('model dir exists!') 50 | os.makedirs(save_dir) 51 | logging = init_log(save_dir) 52 | _print = logging.info 53 | 54 | # dataset loader 55 | transform = transforms.Compose([ 56 | transforms.ToTensor(), # range [0, 255] -> [0.0,1.0] 57 | transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)) # range [0.0, 1.0] -> [-1.0,1.0] 58 | ]) 59 | # validation dataset 60 | trainset = CASIAWebFace(args.train_root, args.train_file_list, transform=transform) 61 | trainloader = torch.utils.data.DataLoader(trainset, batch_size=args.batch_size, 62 | shuffle=True, num_workers=8, drop_last=False) 63 | # test dataset 64 | lfwdataset = LFW(args.lfw_test_root, args.lfw_file_list, transform=transform) 65 | lfwloader = torch.utils.data.DataLoader(lfwdataset, batch_size=128, 66 | shuffle=False, num_workers=4, drop_last=False) 67 | agedbdataset = AgeDB30(args.agedb_test_root, args.agedb_file_list, transform=transform) 68 | agedbloader = torch.utils.data.DataLoader(agedbdataset, batch_size=128, 69 | shuffle=False, num_workers=4, drop_last=False) 70 | cfpfpdataset = CFP_FP(args.cfpfp_test_root, args.cfpfp_file_list, transform=transform) 71 | cfpfploader = torch.utils.data.DataLoader(cfpfpdataset, batch_size=128, 72 | shuffle=False, num_workers=4, drop_last=False) 73 | 74 | # define backbone and margin layer 75 | if args.backbone == 'MobileFace': 76 | net = MobileFaceNet() 77 | elif args.backbone == 'Res50': 78 | net = ResNet50() 79 | elif args.backbone == 'Res101': 80 | net = ResNet101() 81 | elif args.backbone == 'Res50_IR': 82 | net = SEResNet_IR(50, feature_dim=args.feature_dim, mode='ir') 83 | elif args.backbone == 'SERes50_IR': 84 | net = SEResNet_IR(50, feature_dim=args.feature_dim, mode='se_ir') 85 | elif args.backbone == 'SphereNet': 86 | net = SphereNet(num_layers=64, feature_dim=args.feature_dim) 87 | else: 88 | print(args.backbone, ' is not available!') 89 | 90 | if args.margin_type == 'ArcFace': 91 | margin = ArcMarginProduct(args.feature_dim, trainset.class_nums, s=args.scale_size) 92 | elif args.margin_type == 'CosFace': 93 | pass 94 | elif args.margin_type == 'SphereFace': 95 | pass 96 | elif args.margin_type == 
'InnerProduct': 97 | margin = InnerProduct(args.feature_dim, trainset.class_nums) 98 | else: 99 | print(args.margin_type, 'is not available!') 100 | 101 | if args.resume: 102 | print('resume the model parameters from: ', args.net_path, args.margin_path) 103 | net.load_state_dict(torch.load(args.net_path)['net_state_dict']) 104 | margin.load_state_dict(torch.load(args.margin_path)['net_state_dict']) 105 | 106 | # define optimizers for different layer 107 | 108 | criterion_classi = torch.nn.CrossEntropyLoss().to(device) 109 | optimizer_classi = optim.SGD([ 110 | {'params': net.parameters(), 'weight_decay': 5e-4}, 111 | {'params': margin.parameters(), 'weight_decay': 5e-4} 112 | ], lr=0.1, momentum=0.9, nesterov=True) 113 | scheduler_classi = lr_scheduler.MultiStepLR(optimizer_classi, milestones=[20, 35, 45], gamma=0.1) 114 | 115 | if multi_gpus: 116 | net = DataParallel(net).to(device) 117 | margin = DataParallel(margin).to(device) 118 | else: 119 | net = net.to(device) 120 | margin = margin.to(device) 121 | 122 | best_lfw_acc = 0.0 123 | best_lfw_iters = 0 124 | best_agedb30_acc = 0.0 125 | best_agedb30_iters = 0 126 | best_cfp_fp_acc = 0.0 127 | best_cfp_fp_iters = 0 128 | total_iters = 0 129 | vis = Visualizer(env='softmax_train') 130 | for epoch in range(1, args.total_epoch + 1): 131 | scheduler_classi.step() 132 | # train model 133 | _print('Train Epoch: {}/{} ...'.format(epoch, args.total_epoch)) 134 | net.train() 135 | 136 | since = time.time() 137 | for data in trainloader: 138 | img, label = data[0].to(device), data[1].to(device) 139 | feature = net(img) 140 | output = margin(feature) 141 | loss_classi = criterion_classi(output, label) 142 | total_loss = loss_classi 143 | 144 | optimizer_classi.zero_grad() 145 | total_loss.backward() 146 | optimizer_classi.step() 147 | 148 | total_iters += 1 149 | # print train information 150 | if total_iters % 100 == 0: 151 | #current training accuracy 152 | _, predict = torch.max(output.data, 1) 153 | total = label.size(0) 154 | correct = (np.array(predict) == np.array(label.data)).sum() 155 | time_cur = (time.time() - since) / 100 156 | since = time.time() 157 | vis.plot_curves({'train loss': loss_classi.item()}, iters=total_iters, title='train loss', xlabel='iters', ylabel='train loss') 158 | vis.plot_curves({'train accuracy': correct/total}, iters=total_iters, title='train accuracy', xlabel='iters', ylabel='train accuracy') 159 | print("Iters: {:0>6d}/[{:0>2d}], loss_classi: {:.4f}, train_accuracy: {:.4f}, time: {:.2f} s/iter, learning rate: {}".format(total_iters, 160 | epoch, 161 | loss_classi.item(), 162 | correct/total, 163 | time_cur, 164 | scheduler_classi.get_lr()[ 165 | 0])) 166 | # save model 167 | if total_iters % args.save_freq == 0: 168 | msg = 'Saving checkpoint: {}'.format(total_iters) 169 | _print(msg) 170 | if multi_gpus: 171 | net_state_dict = net.module.state_dict() 172 | margin_state_dict = margin.module.state_dict() 173 | else: 174 | net_state_dict = net.state_dict() 175 | margin_state_dict = margin.state_dict() 176 | 177 | if not os.path.exists(save_dir): 178 | os.mkdir(save_dir) 179 | torch.save({ 180 | 'iters': total_iters, 181 | 'net_state_dict': net_state_dict}, 182 | os.path.join(save_dir, 'Iter_%06d_net.ckpt' % total_iters)) 183 | torch.save({ 184 | 'iters': total_iters, 185 | 'net_state_dict': margin_state_dict}, 186 | os.path.join(save_dir, 'Iter_%06d_margin.ckpt' % total_iters)) 187 | 188 | # test accuracy 189 | if total_iters % args.test_freq == 0: 190 | # test model on lfw 191 | net.eval() 192 | 
getFeatureFromTorch('./result/cur_lfw_result.mat', net, device, lfwdataset, lfwloader) 193 | lfw_accs = evaluation_10_fold('./result/cur_lfw_result.mat') 194 | _print('LFW Ave Accuracy: {:.4f}'.format(np.mean(lfw_accs) * 100)) 195 | if best_lfw_acc < np.mean(lfw_accs) * 100: 196 | best_lfw_acc = np.mean(lfw_accs) * 100 197 | best_lfw_iters = total_iters 198 | # test model on AgeDB30 199 | getFeatureFromTorch('./result/cur_agedb30_result.mat', net, device, agedbdataset, agedbloader) 200 | age_accs = evaluation_10_fold('./result/cur_agedb30_result.mat') 201 | _print('AgeDB-30 Ave Accuracy: {:.4f}'.format(np.mean(age_accs) * 100)) 202 | if best_agedb30_acc < np.mean(age_accs) * 100: 203 | best_agedb30_acc = np.mean(age_accs) * 100 204 | best_agedb30_iters = total_iters 205 | # test model on CFP-FP 206 | getFeatureFromTorch('./result/cur_cfpfp_result.mat', net, device, cfpfpdataset, cfpfploader) 207 | cfp_accs = evaluation_10_fold('./result/cur_cfpfp_result.mat') 208 | _print('CFP-FP Ave Accuracy: {:.4f}'.format(np.mean(cfp_accs) * 100)) 209 | if best_cfp_fp_acc < np.mean(cfp_accs) * 100: 210 | best_cfp_fp_acc = np.mean(cfp_accs) * 100 211 | best_cfp_fp_iters = total_iters 212 | _print('Current Best Accuracy: LFW: {:.4f} in iters: {}, AgeDB-30: {:.4f} in iters: {} and CFP-FP: {:.4f} in iters: {}'.format( 213 | best_lfw_acc, best_lfw_iters, best_agedb30_acc, best_agedb30_iters, best_cfp_fp_acc, best_cfp_fp_iters)) 214 | vis.plot_curves({'lfw': np.mean(lfw_accs), 'agedb-30': np.mean(age_accs), 'cfp-fp': np.mean(cfp_accs)}, iters=total_iters, title='test accuracy', xlabel='iters', ylabel='test accuracy') 215 | net.train() 216 | 217 | _print('Final Best Accuracy: LFW: {:.4f} in iters: {}, AgeDB-30: {:.4f} in iters: {} and CFP-FP: {:.4f} in iters: {}'.format( 218 | best_lfw_acc, best_lfw_iters, best_agedb30_acc, best_agedb30_iters, best_cfp_fp_acc, best_cfp_fp_iters)) 219 | print('finished training') 220 | 221 | 222 | if __name__ == '__main__': 223 | parser = argparse.ArgumentParser(description='PyTorch for deep face recognition') 224 | parser.add_argument('--train_root', type=str, default='/media/ramdisk/webface_align_112', help='train image root') 225 | parser.add_argument('--train_file_list', type=str, default='/media/ramdisk/webface_align_train.list', help='train list') 226 | parser.add_argument('--lfw_test_root', type=str, default='/media/ramdisk/lfw_align_112', help='lfw image root') 227 | parser.add_argument('--lfw_file_list', type=str, default='/media/ramdisk/pairs.txt', help='lfw pair file list') 228 | parser.add_argument('--agedb_test_root', type=str, default='/media/sda/AgeDB-30/agedb30_align_112', help='agedb image root') 229 | parser.add_argument('--agedb_file_list', type=str, default='/media/sda/AgeDB-30/agedb_30_pair.txt', help='agedb pair file list') 230 | parser.add_argument('--cfpfp_test_root', type=str, default='/media/sda/CFP-FP/cfp_fp_aligned_112', help='cfp-fp image root') 231 | parser.add_argument('--cfpfp_file_list', type=str, default='/media/sda/CFP-FP/cfp_fp_pair.txt', help='cfp-fp pair file list') 232 | 233 | parser.add_argument('--backbone', type=str, default='MobileFace', help='MobileFace, Res50, Res101, Res50_IR, SERes50_IR, SphereNet') 234 | parser.add_argument('--margin_type', type=str, default='InnerProduct', help='InnerProduct, ArcFace, CosFace, SphereFace') 235 | parser.add_argument('--feature_dim', type=int, default=128, help='feature dimension, 128 or 512') 236 | parser.add_argument('--scale_size', type=float, default=32.0, help='scale size') 237 | 
parser.add_argument('--batch_size', type=int, default=256, help='batch size') 238 | parser.add_argument('--total_epoch', type=int, default=50, help='total epochs') 239 | parser.add_argument('--weight_center', type=float, default=1.0, help='center loss weight') 240 | 241 | parser.add_argument('--save_freq', type=int, default=2000, help='save frequency') 242 | parser.add_argument('--test_freq', type=int, default=2000, help='test frequency') 243 | parser.add_argument('--resume', type=int, default=False, help='resume model') 244 | parser.add_argument('--net_path', type=str, default='', help='resume model') 245 | parser.add_argument('--margin_path', type=str, default='', help='resume model') 246 | parser.add_argument('--save_dir', type=str, default='./model', help='model save dir') 247 | parser.add_argument('--model_pre', type=str, default='Softmax_', help='model prefix') 248 | parser.add_argument('--gpus', type=str, default='2,3', help='gpu ids to use, e.g. 2,3') 249 | 250 | args = parser.parse_args() 251 | 252 | train(args) 253 | 254 | 255 | -------------------------------------------------------------------------------- /utils/README.md: -------------------------------------------------------------------------------- 1 | 2 | ## MXNET binary tools 3 | 4 | Tools to restore the aligned images from the mxnet binary files provided by [insightface](https://github.com/deepinsight/insightface). 5 | 6 | You should install mxnet-cpu first for the image parsing; a plain ' **pip install mxnet** ' is enough. 7 | 8 | The processed images are listed below: 9 | [LFW @ BaiduNetdisk](https://pan.baidu.com/s/1Rue4FBmGvdGMPkyy2ZqcdQ), [AgeDB-30 @ BaiduNetdisk](https://pan.baidu.com/s/1sdw1lO5JfP6Ja99O7zprUg), [CFP_FP @ BaiduNetdisk](https://pan.baidu.com/s/1gyFAAy427weUd2G-ozMgEg) 10 | 11 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # encoding: utf-8 3 | ''' 4 | @author: wujiyang 5 | @contact: wujiyang@hust.edu.cn 6 | @file: __init__.py 7 | @time: 2018/12/22 9:41 8 | @desc: 9 | ''' -------------------------------------------------------------------------------- /utils/load_images_from_bin.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # encoding: utf-8 3 | ''' 4 | @author: wujiyang 5 | @contact: wujiyang@hust.edu.cn 6 | @file: load_images_from_bin.py 7 | @time: 2018/12/25 19:21 8 | @desc: For the AgeDB-30 and CFP-FP test datasets, we use the mxnet binary files provided by insightface; this is the tool to restore 9 | the aligned images from the mxnet binary file. 10 | You should install mxnet-cpu first; a plain 'pip install mxnet==1.2.1' is enough. 
11 | ''' 12 | 13 | from PIL import Image 14 | import cv2 15 | import os 16 | import pickle 17 | import mxnet as mx 18 | from tqdm import tqdm 19 | 20 | ''' 21 | For the training dataset, insightface provides an mxnet .rec file; just install mxnet-cpu to extract the images 22 | ''' 23 | 24 | def load_mx_rec(rec_path): 25 | save_path = os.path.join(rec_path, 'emore_images_2') 26 | if not os.path.exists(save_path): 27 | os.makedirs(save_path) 28 | 29 | imgrec = mx.recordio.MXIndexedRecordIO(os.path.join(rec_path, 'train.idx'), os.path.join(rec_path, 'train.rec'), 'r') 30 | img_info = imgrec.read_idx(0) 31 | header,_ = mx.recordio.unpack(img_info) 32 | max_idx = int(header.label[0]) 33 | for idx in tqdm(range(1,max_idx)): 34 | img_info = imgrec.read_idx(idx) 35 | header, img = mx.recordio.unpack_img(img_info) 36 | label = int(header.label) 37 | #img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR) 38 | #img = Image.fromarray(img) 39 | label_path = os.path.join(save_path, str(label).zfill(6)) 40 | if not os.path.exists(label_path): 41 | os.makedirs(label_path) 42 | #img.save(os.path.join(label_path, str(idx).zfill(8) + '.jpg'), quality=95) 43 | cv2.imwrite(os.path.join(label_path, str(idx).zfill(8) + '.jpg'), img) 44 | 45 | 46 | def load_image_from_bin(bin_path, save_dir): 47 | if not os.path.exists(save_dir): 48 | os.makedirs(save_dir) 49 | file = open(os.path.join(save_dir, '../', 'lfw_pair.txt'), 'w') 50 | bins, issame_list = pickle.load(open(bin_path, 'rb'), encoding='bytes') 51 | for idx in tqdm(range(len(bins))): 52 | _bin = bins[idx] 53 | img = mx.image.imdecode(_bin).asnumpy() 54 | img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR) 55 | cv2.imwrite(os.path.join(save_dir, str(idx+1).zfill(5)+'.jpg'), img) 56 | if idx % 2 == 0: 57 | label = 1 if issame_list[idx//2] == True else -1 58 | file.write(str(idx+1).zfill(5) + '.jpg' + ' ' + str(idx+2).zfill(5) +'.jpg' + ' ' + str(label) + '\n') 59 | 60 | 61 | if __name__ == '__main__': 62 | #bin_path = 'D:/face_data_emore/faces_webface_112x112/lfw.bin' 63 | #save_dir = 'D:/face_data_emore/faces_webface_112x112/lfw' 64 | rec_path = 'D:/face_data_emore/faces_emore' 65 | load_mx_rec(rec_path) 66 | #load_image_from_bin(bin_path, save_dir) 67 | -------------------------------------------------------------------------------- /utils/logging.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # encoding: utf-8 3 | ''' 4 | @author: wujiyang 5 | @contact: wujiyang@hust.edu.cn 6 | @file: logging.py 7 | @time: 2018/12/22 9:42 8 | @desc: logging tools 9 | ''' 10 | 11 | from __future__ import print_function 12 | import os 13 | import logging 14 | 15 | 16 | def init_log(output_dir): 17 | logging.basicConfig(level=logging.DEBUG, 18 | format='%(asctime)s %(message)s', 19 | datefmt='%Y%m%d-%H:%M:%S', 20 | filename=os.path.join(output_dir, 'log.log'), 21 | filemode='w') 22 | console = logging.StreamHandler() 23 | console.setLevel(logging.INFO) 24 | logging.getLogger('').addHandler(console) 25 | return logging 26 | 27 | 28 | if __name__ == '__main__': 29 | pass 30 | -------------------------------------------------------------------------------- /utils/plot_logit.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # encoding: utf-8 3 | ''' 4 | @author: wujiyang 5 | @contact: wujiyang@hust.edu.cn 6 | @file: plot_logit.py 7 | @time: 2019/3/29 14:21 8 | @desc: plot the target logit curves corresponding to sphereface, cosface, arcface and so on. 
9 | ''' 10 | 11 | import math 12 | import torch 13 | import matplotlib.pyplot as plt 14 | import numpy as np 15 | 16 | def softmax(theta): 17 | return torch.cos(theta) 18 | 19 | def sphereface(theta, m=4): 20 | return (torch.cos(m * theta) + 20 * torch.cos(theta)) / (20 + 1) 21 | 22 | def cosface(theta, m): 23 | return torch.cos(theta) - m 24 | 25 | def arcface(theta, m): 26 | return torch.cos(theta + m) 27 | 28 | def multimargin(theta, m1, m2): 29 | return torch.cos(theta + m1) - m2 30 | 31 | 32 | theta = torch.arange(0, math.pi, 0.001) 33 | print(theta.type) 34 | 35 | x = theta.numpy() 36 | y_softmax = softmax(theta).numpy() 37 | y_cosface = cosface(theta, 0.35).numpy() 38 | y_arcface = arcface(theta, 0.5).numpy() 39 | 40 | y_multimargin_1 = multimargin(theta, 0.2, 0.3).numpy() 41 | y_multimargin_2 = multimargin(theta, 0.2, 0.4).numpy() 42 | y_multimargin_3 = multimargin(theta, 0.3, 0.2).numpy() 43 | y_multimargin_4 = multimargin(theta, 0.3, 0.3).numpy() 44 | y_multimargin_5 = multimargin(theta, 0.4, 0.2).numpy() 45 | y_multimargin_6 = multimargin(theta, 0.4, 0.3).numpy() 46 | 47 | plt.plot(x, y_softmax, x, y_cosface, x, y_arcface, x, y_multimargin_1, x, y_multimargin_2, x, y_multimargin_3, x, y_multimargin_4, x, y_multimargin_5, x, y_multimargin_6) 48 | plt.legend(['Softmax(0.00, 0.00)', 'CosFace(0.00, 0.35)', 'ArcFace(0.50, 0.00)', 'MultiMargin(0.20, 0.30)', 'MultiMargin(0.20, 0.40)', 'MultiMargin(0.30, 0.20)', 'MultiMargin(0.30, 0.30)', 'MultiMargin(0.40, 0.20)', 'MultiMargin(0.40, 0.30)']) 49 | plt.grid(False) 50 | plt.xlim((0, 3/4*math.pi)) 51 | plt.ylim((-1.2, 1.2)) 52 | 53 | plt.xticks(np.arange(0, 2.4, 0.3)) 54 | plt.yticks(np.arange(-1.2, 1.2, 0.2)) 55 | plt.xlabel('Angular between the Feature and Target Center (Radian: 0 - 3/4 Pi)') 56 | plt.ylabel('Target Logit') 57 | 58 | plt.savefig('target logits') -------------------------------------------------------------------------------- /utils/plot_theta.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # encoding: utf-8 3 | ''' 4 | @author: wujiyang 5 | @contact: wujiyang@hust.edu.cn 6 | @file: plot_theta.py 7 | @time: 2019/1/2 19:08 8 | @desc: plot theta distribution between weight and feature vector 9 | ''' 10 | 11 | from matplotlib import pyplot as plt 12 | plt.switch_backend('agg') 13 | 14 | import argparse 15 | from backbone.mobilefacenet import MobileFaceNet 16 | from margin.ArcMarginProduct import ArcMarginProduct 17 | from torch.utils.data import DataLoader 18 | import torch 19 | 20 | from torchvision import transforms 21 | import torch.nn.functional as F 22 | import os 23 | import numpy as np 24 | from dataset.casia_webface import CASIAWebFace 25 | 26 | 27 | def get_train_loader(img_folder, filelist): 28 | print('Loading dataset...') 29 | transform = transforms.Compose([ 30 | transforms.ToTensor(), 31 | transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)) 32 | ]) 33 | trainset = CASIAWebFace(img_folder, filelist, transform=transform) 34 | trainloader = torch.utils.data.DataLoader(trainset, batch_size=100, 35 | shuffle=False, num_workers=8, drop_last=False) 36 | return trainloader 37 | 38 | def load_model(backbone_state_dict, margin_state_dict, device): 39 | 40 | # load model 41 | net = MobileFaceNet() 42 | net.load_state_dict(torch.load(backbone_state_dict)['net_state_dict']) 43 | margin = ArcMarginProduct(in_feature=128, out_feature=10575) 44 | margin.load_state_dict(torch.load(margin_state_dict)['net_state_dict']) 45 | 46 | net = net.to(device) 
47 | margin = margin.to(device) 48 | 49 | return net.eval(), margin.eval() 50 | 51 | 52 | if __name__ == '__main__': 53 | parser = argparse.ArgumentParser(description='plot theta distribution of trained model') 54 | parser.add_argument('--img_root', type=str, default='/media/ramdisk/webface_align_112', help='train image root') 55 | parser.add_argument('--file_list', type=str, default='/media/ramdisk/webface_align_train.list', help='train list') 56 | parser.add_argument('--backbone_file', type=str, default='../model/Paper_MOBILEFACE_20190103_111830/Iter_088000_net.ckpt', help='backbone state dict file') 57 | parser.add_argument('--margin_file', type=str, default='../model/Paper_MOBILEFACE_20190103_111830/Iter_088000_margin.ckpt', help='backbone state dict file') 58 | parser.add_argument('--gpus', type=str, default='0', help='model prefix, single gpu only') 59 | args = parser.parse_args() 60 | 61 | # gpu init 62 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpus 63 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 64 | 65 | # load pretrain model 66 | trained_net, trained_margin = load_model(args.backbone_file, args.margin_file, device) 67 | 68 | # initial model 69 | initial_net = MobileFaceNet() 70 | initial_margin = ArcMarginProduct() 71 | initial_net = initial_net.to(device).eval() 72 | initial_margin = initial_margin.to(device).eval() 73 | 74 | # image dataloader 75 | image_loader = get_train_loader(args.img_root, args.file_list) 76 | theta_trained = [] 77 | theta_initial = [] 78 | for data in image_loader: 79 | img, label = data[0].to(device), data[1].to(device) 80 | # pretrained 81 | embedding = trained_net(img) 82 | cos_theta = F.linear(F.normalize(embedding), F.normalize(trained_margin.weight)) 83 | cos_theta = cos_theta.clamp(-1, 1).detach().cpu().numpy() 84 | for i in range(img.shape[0]): 85 | cos_trget = cos_theta[i][label[i]] 86 | theta_trained.append(np.arccos(cos_trget) / np.pi * 180) 87 | # initial 88 | embedding = initial_net(img) 89 | cos_theta = F.linear(F.normalize(embedding), F.normalize(initial_margin.weight)) 90 | cos_theta = cos_theta.clamp(-1, 1).detach().cpu().numpy() 91 | for i in range(img.shape[0]): 92 | cos_trget = cos_theta[i][label[i]] 93 | theta_initial.append(np.arccos(cos_trget) / np.pi * 180) 94 | ''' 95 | # write theta list to txt file 96 | trained_theta_file = open('arcface_theta.txt', 'w') 97 | initial_theta_file = open('initial_theta.txt', 'w') 98 | for item in theta_trained: 99 | trained_theta_file.write(str(item)) 100 | trained_theta_file.write('\n') 101 | for item in theta_initial: 102 | initial_theta_file.write(str(item)) 103 | initial_theta_file.write('\n') 104 | 105 | # plot the theta, read theta from txt first 106 | theta_trained = [] 107 | theta_initial = [] 108 | trained_theta_file = open('arcface_theta.txt', 'r') 109 | initial_theta_file = open('initial_theta.txt', 'r') 110 | lines = trained_theta_file.readlines() 111 | for line in lines: 112 | theta_trained.append(float(line.strip('\n')[0])) 113 | lines = initial_theta_file.readlines() 114 | for line in lines: 115 | theta_initial.append(float(line.split('\n')[0])) 116 | ''' 117 | print(len(theta_trained), len(theta_initial)) 118 | plt.figure() 119 | plt.xlabel('Theta') 120 | plt.ylabel('Numbers') 121 | plt.title('Theta Distribution') 122 | plt.hist(theta_trained, bins=180, normed=0) 123 | plt.hist(theta_initial, bins=180, normed=0) 124 | plt.legend(['trained theta distribution', 'initial theta distribution']) 125 | plt.savefig('theta_distribution_hist.jpg') 126 | 
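The loop above boils down to one step per sample: theta_i = arccos of the cosine similarity between the embedding and the weight vector of its own class. Below is a minimal sketch of that step factored into a standalone helper; the name target_angles is hypothetical, and it assumes net and margin have been loaded as in load_model() and that img and label already live on device.

import numpy as np
import torch
import torch.nn.functional as F

def target_angles(net, margin, img, label):
    # angle in degrees between each embedding and its own class weight,
    # i.e. the per-sample theta that the histogram above is built from
    with torch.no_grad():
        embedding = net(img)                                      # (B, feature_dim)
        cos_theta = F.linear(F.normalize(embedding),
                             F.normalize(margin.weight))          # (B, num_classes)
        cos_target = cos_theta.clamp(-1, 1).gather(1, label.view(-1, 1)).squeeze(1)
    return np.arccos(cos_target.cpu().numpy()) / np.pi * 180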
-------------------------------------------------------------------------------- /utils/theta_distribution_hist.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wujiyang/Face_Pytorch/3afd941c01cae3eb73b66a48fa984f41bd6662fc/utils/theta_distribution_hist.jpg -------------------------------------------------------------------------------- /utils/visualize.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # encoding: utf-8 3 | ''' 4 | @author: wujiyang 5 | @contact: wujiyang@hust.edu.cn 6 | @file: visualize.py 7 | @time: 2019/1/7 16:07 8 | @desc: visualize tools 9 | ''' 10 | 11 | import visdom 12 | import numpy as np 13 | import time 14 | 15 | class Visualizer(): 16 | def __init__(self, env='default', **kwargs): 17 | self.vis = visdom.Visdom(env=env, **kwargs) 18 | self.index = 1 19 | 20 | def plot_curves(self, d, iters, title='loss', xlabel='iters', ylabel='accuracy'): 21 | name = list(d.keys()) 22 | val = list(d.values()) 23 | if len(val) == 1: 24 | y = np.array(val) 25 | else: 26 | y = np.array(val).reshape(-1, len(val)) 27 | self.vis.line(Y=y, 28 | X=np.array([self.index]), 29 | win=title, 30 | opts=dict(legend=name, title = title, xlabel=xlabel, ylabel=ylabel), 31 | update=None if self.index == 0 else 'append') 32 | self.index = iters 33 | 34 | 35 | if __name__ == '__main__': 36 | vis = Visualizer(env='test') 37 | for i in range(10): 38 | x = i 39 | y = 2 * i 40 | z = 4 * i 41 | vis.plot_curves({'train': x, 'test': y}, iters=i, title='train') 42 | vis.plot_curves({'train': z, 'test': y, 'val': i}, iters=i, title='test') 43 | time.sleep(1) --------------------------------------------------------------------------------
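Visualizer is a thin wrapper around visdom, so the curves only show up if a visdom server is already running (start it with 'python -m visdom.server'; visdom serves on port 8097 by default). A minimal usage sketch, assuming the repository root is on PYTHONPATH; the loss values here are placeholders:

from utils.visualize import Visualizer

vis = Visualizer(env='softmax_train')   # same env name that train_softmax.py uses
for it in range(100, 1100, 100):
    # plot_curves takes a dict of named scalar series plus the current iteration count
    vis.plot_curves({'train loss': 1.0 / it}, iters=it, title='train loss',
                    xlabel='iters', ylabel='train loss')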