├── .gitignore
├── .idea
│   ├── Face_Pytorch.iml
│   ├── codeStyles
│   │   └── codeStyleConfig.xml
│   ├── deployment.xml
│   ├── encodings.xml
│   ├── misc.xml
│   ├── modules.xml
│   ├── vcs.xml
│   └── workspace.xml
├── LICENSE
├── README.md
├── backbone
│   ├── __init__.py
│   ├── arcfacenet.py
│   ├── attention.py
│   ├── cbam.py
│   ├── mobilefacenet.py
│   ├── resnet.py
│   └── spherenet.py
├── cppapi
│   ├── README.md
│   └── pytorch2torchscript.py
├── dataset
│   ├── __init__.py
│   ├── agedb.py
│   ├── casia_webface.py
│   ├── cfp.py
│   ├── lfw.py
│   ├── lfw_2.py
│   └── megaface.py
├── eval_agedb30.py
├── eval_cfp.py
├── eval_deepglint_merge.py
├── eval_lfw.py
├── eval_lfw_blufr.py
├── eval_megaface.py
├── lossfunctions
│   ├── __init__.py
│   ├── agentcenterloss.py
│   └── centerloss.py
├── margin
│   ├── ArcMarginProduct.py
│   ├── CosineMarginProduct.py
│   ├── InnerProduct.py
│   ├── MultiMarginProduct.py
│   ├── SphereMarginProduct.py
│   └── __init__.py
├── model
│   ├── MSCeleb_MOBILEFACE_20181228_170458
│   │   └── log.log
│   └── MSCeleb_SERES50_IR_20181229_211407
│       └── log.log
├── result
│   ├── softmax.gif
│   ├── softmax_center.gif
│   └── visualization.jpg
├── train.py
├── train_center.py
├── train_softmax.py
└── utils
    ├── README.md
    ├── __init__.py
    ├── load_images_from_bin.py
    ├── logging.py
    ├── plot_logit.py
    ├── plot_theta.py
    ├── theta_distribution_hist.jpg
    └── visualize.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | *.egg-info/
24 | .installed.cfg
25 | *.egg
26 | MANIFEST
27 |
28 | # PyInstaller
29 | # Usually these files are written by a python script from a template
30 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
31 | *.manifest
32 | *.spec
33 |
34 | # Installer logs
35 | pip-log.txt
36 | pip-delete-this-directory.txt
37 |
38 | # Unit test / coverage reports
39 | htmlcov/
40 | .tox/
41 | .coverage
42 | .coverage.*
43 | .cache
44 | nosetests.xml
45 | coverage.xml
46 | *.cover
47 | .hypothesis/
48 | .pytest_cache/
49 |
50 | # Translations
51 | *.mo
52 | *.pot
53 |
54 | # Django stuff:
55 | local_settings.py
56 | db.sqlite3
57 |
58 | # Flask stuff:
59 | instance/
60 | .webassets-cache
61 |
62 | # Scrapy stuff:
63 | .scrapy
64 |
65 | # Sphinx documentation
66 | docs/_build/
67 |
68 | # PyBuilder
69 | target/
70 |
71 | # Jupyter Notebook
72 | .ipynb_checkpoints
73 |
74 | # pyenv
75 | .python-version
76 |
77 | # celery beat schedule file
78 | celerybeat-schedule
79 |
80 | # SageMath parsed files
81 | *.sage.py
82 |
83 | # Environments
84 | .env
85 | .venv
86 | env/
87 | venv/
88 | ENV/
89 | env.bak/
90 | venv.bak/
91 |
92 | # Spyder project settings
93 | .spyderproject
94 | .spyproject
95 |
96 | # Rope project settings
97 | .ropeproject
98 |
99 | # mkdocs documentation
100 | /site
101 |
102 | # mypy
103 | .mypy_cache/
104 |
105 | # model or feature
106 | *.mat
107 | *.ckpt
--------------------------------------------------------------------------------
/.idea/Face_Pytorch.iml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/.idea/codeStyles/codeStyleConfig.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/.idea/deployment.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/.idea/encodings.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/.idea/misc.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/.idea/modules.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/.idea/vcs.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | >2019.07.14
2 | >Currently, I have graduated and am now doing a different kind of job, so this project may not be updated again.
3 |
4 | ## Face_Pytorch
5 | Implementations of popular face recognition algorithms in the PyTorch framework, including ArcFace, CosFace, SphereFace and so on.
6 |
7 | All code was evaluated on PyTorch 0.4.0 with Python 3.6, Ubuntu 16.04, CUDA 9.1 and cuDNN 7.1, and partially evaluated on PyTorch 1.0.
8 |
9 |
10 | ## Data Preparation
11 | For CNN training, I use CASIA-WebFace and a cleaned MS-Celeb-1M, both aligned by MTCNN to a size of 112x112. For performance testing, I report results on LFW, AgeDB-30, CFP-FP, and MegaFace rank-1 identification and verification.
12 |
13 | For AgeDB-30 and CFP-FP, the aligned images and evaluation pairs are recovered from the mxnet binary files provided by [insightface](https://github.com/deepinsight/insightface); the conversion tools are available in this repository (see utils/load_images_from_bin.py). Install mxnet (the CPU build is enough) first for the image parsing: ' **pip install mxnet** '. A minimal parsing sketch is shown below.
14 | [LFW @ BaiduNetdisk](https://pan.baidu.com/s/1Rue4FBmGvdGMPkyy2ZqcdQ), [AgeDB-30 @ BaiduNetdisk](https://pan.baidu.com/s/1sdw1lO5JfP6Ja99O7zprUg), [CFP_FP @ BaiduNetdisk](https://pan.baidu.com/s/1gyFAAy427weUd2G-ozMgEg)
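For reference, here is a minimal sketch of unpacking one of these .bin files, assuming the standard insightface format (a pickled tuple of JPEG-encoded byte strings plus a same/different label list). The repository's own tool is utils/load_images_from_bin.py; file names below are illustrative.

```python
# Hedged sketch: decode an insightface-style verification .bin (e.g. agedb_30.bin).
import pickle
import mxnet as mx
import numpy as np

def load_bin(path, image_size=(112, 112)):
    with open(path, 'rb') as f:
        bins, issame_list = pickle.load(f, encoding='bytes')  # encoded images + pair labels
    images = []
    for b in bins:
        img = mx.image.imdecode(b)                 # HWC, RGB, uint8 NDArray
        img = mx.image.imresize(img, *image_size)  # resize to 112x112 if needed
        images.append(img.asnumpy())
    return np.stack(images), np.array(issame_list, dtype=bool)

# images, issame = load_bin('faces/agedb_30.bin')
```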
15 |
16 | ## Results
17 | > MobileFaceNet: structure described in the MobileFaceNet paper
18 | > ResNet50: original ResNet structure
19 | > ResNet50-IR: CNN described in the ArcFace paper
20 | > SEResNet50-IR: CNN described in the ArcFace paper, with SE blocks (see the sketch below for the corresponding classes)
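For reference, a minimal sketch of how these names map onto the backbone classes in this repository (constructor arguments follow the definitions under backbone/; the feature dimensions are simply the defaults used here):

```python
# Rough mapping from the model names above to the backbone classes in backbone/.
from backbone.mobilefacenet import MobileFaceNet
from backbone.resnet import ResNet50
from backbone.cbam import CBAMResNet

mobilefacenet = MobileFaceNet(feature_dim=128)                  # MobileFaceNet
resnet50      = ResNet50()                                      # original ResNet structure
resnet50_ir   = CBAMResNet(50, feature_dim=512, mode='ir')      # ResNet50-IR
seresnet50_ir = CBAMResNet(50, feature_dim=512, mode='ir_se')   # SEResNet50-IR
```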
21 | ### Verification results on LFW, AgeDB-30 and CFP_FP
22 | Small protocol: trained on CASIA-WebFace (453,580 images / 10,575 identities)
23 | Large protocol: trained on DeepGlint MS-Celeb-1M (3,923,399 images / 86,876 identities)
24 |
25 | Model Type | Loss | LFW | AgeDB-30 | CFP-FP | Model Size | protocol
26 | :--------------:|:---------:|:-------:|:--------:|:-------|:----------:|:--------:
27 | MobileFaceNet | ArcFace | 99.23 | 93.26 | 94.34 | 4MB | small
28 | ResNet50-IR | ArcFace | 99.42 | 94.45 | 95.34 | 170MB | small
29 | SEResNet50-IR | ArcFace | 99.43 | 94.50 | 95.43 | 171MB | small
30 | MobileFaceNet | ArcFace | 99.58 | 96.57 | 92.90 | 4MB | large
31 | ResNet50-IR | ArcFace | 99.82 | 98.07 | 95.34 | 170MB | large
32 | SEResNet50-IR | ArcFace | 99.80 | 98.13 | 95.60 | 171MB | large
33 | ResNet100-IR | ArcFace | 99.83 | 98.28 | 96.41 | 256MB | large
34 |
35 | One odd result: under the small protocol, CFP-FP performs better than AgeDB-30, while with the large-scale training set CFP-FP performs worse than AgeDB-30.
36 |
37 | ### MegaFace rank 1 identification accuracy and verification@FPR=1e-6 results
38 |
39 | Model Type | Loss | MF Acc. | MF Ver. | MF Acc.@R | MF Ver.@R | SIZE | protocol
40 | :--------------:|:---------:|:-------:|:-------:|:---------:|:---------:|:-----:|:-------:
41 | MobileFaceNet | ArcFace | 69.10 | 84.23 | 81.15 | 85.86 | 4MB | small
42 | ResNet50-IR | ArcFace | 74.31 | 88.23 | 87.44 | 89.56 | 170MB | small
43 | SEResNet50-IR | ArcFace | 74.37 | 88.32 | 88.30 | 89.65 | 171MB | small
44 | MobileFaceNet | ArcFace | 74.95 | 88.77 | 89.47 | 91.03 | 4MB | large
45 | ResNet50-IR | ArcFace | 79.61 | 96.02 | 96.58 | 96.78 | 170MB | large
46 | SEResNet50-IR | ArcFace | 79.91 | 96.10 | 97.01 | 97.60 | 171MB | large
47 | ResNet100-IR | ArcFace | 80.40 | 96.94 | 97.60 | 98.05 | 256MB | large
48 |
49 |
50 |
51 | ## Usage
52 | 1. Download the source code to your machine.
53 | 2. Prepare the training dataset and train list, as well as the test dataset and test verification pairs.
54 | 3. Set your own dataset path and any other parameters in train.py.
55 | 4. Run train.py; test accuracy is printed to the log file during training.
56 | ---
57 | 5. Each evaluation file can be run independently to test a trained model; just set your own args in the file (a minimal verification sketch follows below).
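As an illustration only (the eval_*.py scripts remain the reference), verification on a pair list reduces to extracting embeddings and thresholding their cosine similarity. A minimal sketch with stand-in tensors:

```python
# Hedged sketch of pair verification: embed two aligned 112x112 crops, compare by
# cosine similarity, and accept the pair above a dataset-dependent threshold.
import torch
import torch.nn.functional as F
from backbone.mobilefacenet import MobileFaceNet

net = MobileFaceNet(feature_dim=128).eval()

with torch.no_grad():
    img_a = torch.randn(1, 3, 112, 112)   # stand-ins for two preprocessed face crops
    img_b = torch.randn(1, 3, 112, 112)
    feat_a = F.normalize(net(img_a))       # L2-normalized embeddings
    feat_b = F.normalize(net(img_b))
    similarity = (feat_a * feat_b).sum().item()   # cosine similarity in [-1, 1]
    is_same = similarity > 0.3                    # threshold chosen on a validation fold
```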
58 |
59 | ## Visualization
60 | Visdom is supported for plotting loss and accuracy during training (a minimal logging sketch is shown below).
61 | ![](result/visualization.jpg)
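A minimal sketch of this kind of Visdom logging during training (window and environment names here are illustrative):

```python
# Hedged sketch: push a training curve to a running Visdom server (`python -m visdom.server`).
import numpy as np
import visdom

vis = visdom.Visdom(env='face_pytorch')

for step in range(100):
    loss = float(np.exp(-step / 30.0))    # stand-in for the real training loss
    vis.line(X=np.array([step]), Y=np.array([loss]),
             win='train_loss', update='append' if step > 0 else None,
             opts=dict(title='Training Loss', xlabel='iteration', ylabel='loss'))
```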
62 |
63 |
64 | Softmax Loss vs Softmax_Center Loss. Left: softmax training set. Right: softmax + center loss training set.
65 | ![](result/softmax.gif) ![](result/softmax_center.gif)
66 |
67 |
68 |
69 |
70 | ## References
71 | [MuggleWang/CosFace_pytorch](https://github.com/MuggleWang/CosFace_pytorch)
72 | [Xiaoccer/MobileFaceNet_Pytorch](https://github.com/Xiaoccer/MobileFaceNet_Pytorch)
73 | [TreB1eN/InsightFace_Pytorch](https://github.com/TreB1eN/InsightFace_Pytorch)
74 | [deepinsight/insightface](https://github.com/deepinsight/insightface)
75 | [KaiyangZhou/pytorch-center-loss](https://github.com/KaiyangZhou/pytorch-center-loss)
76 | [tengshaofeng/ResidualAttentionNetwork-pytorch](https://github.com/tengshaofeng/ResidualAttentionNetwork-pytorch)
77 |
78 | ## Todo
79 | 1. Report the test results on DeepGlint Trillion Pairs Challenge.
80 | 2. Add a C++ API for fast deployment with PyTorch 1.0 (see the tracing sketch after this list).
81 | 3. Train the ResNet100-based model.
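For the C++ deployment item, a minimal sketch of exporting a trained backbone to TorchScript (the repository's own converter is cppapi/pytorch2torchscript.py; the output path below is illustrative):

```python
# Hedged sketch: trace a backbone with an example input and save it for torch::jit::load in C++.
import torch
from backbone.mobilefacenet import MobileFaceNet

net = MobileFaceNet(feature_dim=128).eval()
example = torch.randn(1, 3, 112, 112)
traced = torch.jit.trace(net, example)   # record the forward graph with an example input
traced.save('mobilefacenet_script.pt')   # load from C++ via torch::jit::load("...")
```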
82 |
--------------------------------------------------------------------------------
/backbone/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # encoding: utf-8
3 | '''
4 | @author: wujiyang
5 | @contact: wujiyang@hust.edu.cn
6 | @file: __init__.py.py
7 | @time: 2018/12/21 15:30
8 | @desc:
9 | '''
--------------------------------------------------------------------------------
/backbone/arcfacenet.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # encoding: utf-8
3 | '''
4 | @author: wujiyang
5 | @contact: wujiyang@hust.edu.cn
6 | @file: arcfacenet.py
7 | @time: 2018/12/26 10:15
 8 | @desc: Network structures used in the ArcFace paper, including ResNet50-IR, ResNet100-IR, SEResNet50-IR and SEResNet100-IR
9 |
10 | ''''''
11 | Update: this file has been deprecated; all the models built in it have been rebuilt in cbam.py.
12 | Yet the code in this file still works.
13 | '''
14 |
15 |
16 | import torch
17 | from torch import nn
18 | from collections import namedtuple
19 |
20 | class Flatten(nn.Module):
21 | def forward(self, input):
22 | return input.view(input.size(0), -1)
23 |
24 |
25 | class SEModule(nn.Module):
26 | def __init__(self, channels, reduction):
27 | super(SEModule, self).__init__()
28 | self.avg_pool = nn.AdaptiveAvgPool2d(1)
29 | self.fc1 = nn.Conv2d(channels, channels // reduction, kernel_size=1, padding=0, bias=False)
30 | self.relu = nn.ReLU(inplace=True)
31 | self.fc2 = nn.Conv2d(channels // reduction, channels, kernel_size=1, padding=0, bias=False)
32 | self.sigmoid = nn.Sigmoid()
33 |
34 | def forward(self, x):
35 | input = x
36 | x = self.avg_pool(x)
37 | x = self.fc1(x)
38 | x = self.relu(x)
39 | x = self.fc2(x)
40 | x = self.sigmoid(x)
41 |
42 | return input * x
43 |
44 |
45 | class BottleNeck_IR(nn.Module):
46 | def __init__(self, in_channel, out_channel, stride):
47 | super(BottleNeck_IR, self).__init__()
48 | if in_channel == out_channel:
49 | self.shortcut_layer = nn.MaxPool2d(1, stride)
50 | else:
51 | self.shortcut_layer = nn.Sequential(
52 | nn.Conv2d(in_channel, out_channel, kernel_size=(1, 1), stride=stride, bias=False),
53 | nn.BatchNorm2d(out_channel)
54 | )
55 |
56 | self.res_layer = nn.Sequential(nn.BatchNorm2d(in_channel),
57 | nn.Conv2d(in_channel, out_channel, (3, 3), 1, 1, bias=False),
58 | nn.BatchNorm2d(out_channel),
59 | nn.PReLU(out_channel),
60 | nn.Conv2d(out_channel, out_channel, (3, 3), stride, 1, bias=False),
61 | nn.BatchNorm2d(out_channel))
62 |
63 | def forward(self, x):
64 | shortcut = self.shortcut_layer(x)
65 | res = self.res_layer(x)
66 |
67 | return shortcut + res
68 |
69 | class BottleNeck_IR_SE(nn.Module):
70 | def __init__(self, in_channel, out_channel, stride):
71 | super(BottleNeck_IR_SE, self).__init__()
72 | if in_channel == out_channel:
73 | self.shortcut_layer = nn.MaxPool2d(1, stride)
74 | else:
75 | self.shortcut_layer = nn.Sequential(
76 | nn.Conv2d(in_channel, out_channel, kernel_size=(1, 1), stride=stride, bias=False),
77 | nn.BatchNorm2d(out_channel)
78 | )
79 |
80 | self.res_layer = nn.Sequential(nn.BatchNorm2d(in_channel),
81 | nn.Conv2d(in_channel, out_channel, (3, 3), 1, 1, bias=False),
82 | nn.BatchNorm2d(out_channel),
83 | nn.PReLU(out_channel),
84 | nn.Conv2d(out_channel, out_channel, (3, 3), stride, 1, bias=False),
85 | nn.BatchNorm2d(out_channel),
86 | SEModule(out_channel, 16))
87 |
88 | def forward(self, x):
89 | shortcut = self.shortcut_layer(x)
90 | res = self.res_layer(x)
91 |
92 | return shortcut + res
93 |
94 |
95 | class Bottleneck(namedtuple('Block', ['in_channel', 'out_channel', 'stride'])):
96 | '''A named tuple describing a ResNet block.'''
97 |
98 |
99 | def get_block(in_channel, out_channel, num_units, stride=2):
100 | return [Bottleneck(in_channel, out_channel, stride)] + [Bottleneck(out_channel, out_channel, 1) for i in range(num_units - 1)]
101 |
102 |
103 | def get_blocks(num_layers):
104 | if num_layers == 50:
105 | blocks = [
106 | get_block(in_channel=64, out_channel=64, num_units=3),
107 | get_block(in_channel=64, out_channel=128, num_units=4),
108 | get_block(in_channel=128, out_channel=256, num_units=14),
109 | get_block(in_channel=256, out_channel=512, num_units=3)
110 | ]
111 | elif num_layers == 100:
112 | blocks = [
113 | get_block(in_channel=64, out_channel=64, num_units=3),
114 | get_block(in_channel=64, out_channel=128, num_units=13),
115 | get_block(in_channel=128, out_channel=256, num_units=30),
116 | get_block(in_channel=256, out_channel=512, num_units=3)
117 | ]
118 | elif num_layers == 152:
119 | blocks = [
120 | get_block(in_channel=64, out_channel=64, num_units=3),
121 | get_block(in_channel=64, out_channel=128, num_units=8),
122 | get_block(in_channel=128, out_channel=256, num_units=36),
123 | get_block(in_channel=256, out_channel=512, num_units=3)
124 | ]
125 | return blocks
126 |
127 |
128 | class SEResNet_IR(nn.Module):
129 | def __init__(self, num_layers, feature_dim=512, drop_ratio=0.4, mode = 'ir'):
130 | super(SEResNet_IR, self).__init__()
131 | assert num_layers in [50, 100, 152], 'num_layers should be 50, 100 or 152'
132 | assert mode in ['ir', 'se_ir'], 'mode should be ir or se_ir'
133 | blocks = get_blocks(num_layers)
134 | if mode == 'ir':
135 | unit_module = BottleNeck_IR
136 | elif mode == 'se_ir':
137 | unit_module = BottleNeck_IR_SE
138 | self.input_layer = nn.Sequential(nn.Conv2d(3, 64, (3, 3), 1, 1, bias=False),
139 | nn.BatchNorm2d(64),
140 | nn.PReLU(64))
141 |
142 | self.output_layer = nn.Sequential(nn.BatchNorm2d(512),
143 | nn.Dropout(drop_ratio),
144 | Flatten(),
145 | nn.Linear(512 * 7 * 7, feature_dim),
146 | nn.BatchNorm1d(feature_dim))
147 | modules = []
148 | for block in blocks:
149 | for bottleneck in block:
150 | modules.append(
151 | unit_module(bottleneck.in_channel,
152 | bottleneck.out_channel,
153 | bottleneck.stride))
154 | self.body = nn.Sequential(*modules)
155 |
156 | def forward(self, x):
157 | x = self.input_layer(x)
158 | x = self.body(x)
159 | x = self.output_layer(x)
160 |
161 | return x
162 |
163 |
164 | if __name__ == '__main__':
165 | input = torch.Tensor(2, 3, 112, 112)
166 | net = SEResNet_IR(100, mode='se_ir')
167 | print(net)
168 |
169 | x = net(input)
170 | print(x.shape)
--------------------------------------------------------------------------------
/backbone/attention.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # encoding: utf-8
3 | '''
4 | @author: wujiyang
5 | @contact: wujiyang@hust.edu.cn
6 | @file: attention.py
7 | @time: 2019/2/14 14:12
8 | @desc: Residual Attention Network for Image Classification, CVPR 2017.
9 | Attention 56 and Attention 92.
10 | '''
11 |
12 |
13 | import torch
14 | import torch.nn as nn
15 | import numpy as np
16 |
17 | class Flatten(nn.Module):
18 | def forward(self, input):
19 | return input.view(input.size(0), -1)
20 |
21 | class ResidualBlock(nn.Module):
22 |
23 | def __init__(self, in_channel, out_channel, stride=1):
24 | super(ResidualBlock, self).__init__()
25 | self.in_channel = in_channel
26 | self.out_channel = out_channel
27 | self.stride = stride
28 |
29 | self.res_bottleneck = nn.Sequential(nn.BatchNorm2d(in_channel),
30 | nn.ReLU(inplace=True),
31 | nn.Conv2d(in_channel, out_channel//4, 1, 1, bias=False),
32 | nn.BatchNorm2d(out_channel//4),
33 | nn.ReLU(inplace=True),
34 | nn.Conv2d(out_channel//4, out_channel//4, 3, stride, padding=1, bias=False),
35 | nn.BatchNorm2d(out_channel//4),
36 | nn.ReLU(inplace=True),
37 | nn.Conv2d(out_channel//4, out_channel, 1, 1, bias=False))
38 | self.shortcut = nn.Conv2d(in_channel, out_channel, 1, stride, bias=False)
39 |
40 | def forward(self, x):
41 | res = x
42 | out = self.res_bottleneck(x)
43 | if self.in_channel != self.out_channel or self.stride != 1:
44 | res = self.shortcut(x)
45 |
46 | out += res
47 | return out
48 |
49 | class AttentionModule_stage1(nn.Module):
50 |
51 | # input size is 56*56
52 | def __init__(self, in_channel, out_channel, size1=(56, 56), size2=(28, 28), size3=(14, 14)):
53 | super(AttentionModule_stage1, self).__init__()
54 | self.share_residual_block = ResidualBlock(in_channel, out_channel)
55 | self.trunk_branches = nn.Sequential(ResidualBlock(in_channel, out_channel),
56 | ResidualBlock(in_channel, out_channel))
57 |
58 | self.mpool1 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
59 | self.mask_block1 = ResidualBlock(in_channel, out_channel)
60 | self.skip_connect1 = ResidualBlock(in_channel, out_channel)
61 |
62 | self.mpool2 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
63 | self.mask_block2 = ResidualBlock(in_channel, out_channel)
64 | self.skip_connect2 = ResidualBlock(in_channel, out_channel)
65 |
66 | self.mpool3 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
67 | self.mask_block3 = nn.Sequential(ResidualBlock(in_channel, out_channel),
68 | ResidualBlock(in_channel, out_channel))
69 |
70 | self.interpolation3 = nn.UpsamplingBilinear2d(size=size3)
71 | self.mask_block4 = ResidualBlock(in_channel, out_channel)
72 |
73 | self.interpolation2 = nn.UpsamplingBilinear2d(size=size2)
74 | self.mask_block5 = ResidualBlock(in_channel, out_channel)
75 |
76 | self.interpolation1 = nn.UpsamplingBilinear2d(size=size1)
77 | self.mask_block6 = nn.Sequential(nn.BatchNorm2d(out_channel),
78 | nn.ReLU(inplace=True),
79 | nn.Conv2d(out_channel, out_channel, 1, 1, bias=False),
80 | nn.BatchNorm2d(out_channel),
81 | nn.ReLU(inplace=True),
82 | nn.Conv2d(out_channel, out_channel, 1, 1, bias=False),
83 | nn.Sigmoid())
84 |
85 | self.last_block = ResidualBlock(in_channel, out_channel)
86 |
87 | def forward(self, x):
88 | x = self.share_residual_block(x)
89 | out_trunk = self.trunk_branches(x)
90 |
91 | out_pool1 = self.mpool1(x)
92 | out_block1 = self.mask_block1(out_pool1)
93 | out_skip_connect1 = self.skip_connect1(out_block1)
94 |
95 | out_pool2 = self.mpool2(out_block1)
96 | out_block2 = self.mask_block2(out_pool2)
97 | out_skip_connect2 = self.skip_connect2(out_block2)
98 |
99 | out_pool3 = self.mpool3(out_block2)
100 | out_block3 = self.mask_block3(out_pool3)
101 | #
102 | out_inter3 = self.interpolation3(out_block3) + out_block2
103 | out = out_inter3 + out_skip_connect2
104 | out_block4 = self.mask_block4(out)
105 |
106 | out_inter2 = self.interpolation2(out_block4) + out_block1
107 | out = out_inter2 + out_skip_connect1
108 | out_block5 = self.mask_block5(out)
109 |
110 | out_inter1 = self.interpolation1(out_block5) + out_trunk
111 | out_block6 = self.mask_block6(out_inter1)
112 |
113 |         out = (1 + out_block6) * out_trunk  # apply the attention mask as (1 + M(x)) * T(x), consistent with stages 2 and 3
114 | out_last = self.last_block(out)
115 |
116 | return out_last
117 |
118 | class AttentionModule_stage2(nn.Module):
119 |
120 | # input image size is 28*28
121 | def __init__(self, in_channels, out_channels, size1=(28, 28), size2=(14, 14)):
122 | super(AttentionModule_stage2, self).__init__()
123 | self.first_residual_blocks = ResidualBlock(in_channels, out_channels)
124 |
125 | self.trunk_branches = nn.Sequential(
126 | ResidualBlock(in_channels, out_channels),
127 | ResidualBlock(in_channels, out_channels)
128 | )
129 |
130 | self.mpool1 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
131 | self.softmax1_blocks = ResidualBlock(in_channels, out_channels)
132 | self.skip1_connection_residual_block = ResidualBlock(in_channels, out_channels)
133 |
134 | self.mpool2 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
135 | self.softmax2_blocks = nn.Sequential(
136 | ResidualBlock(in_channels, out_channels),
137 | ResidualBlock(in_channels, out_channels)
138 | )
139 |
140 | self.interpolation2 = nn.UpsamplingBilinear2d(size=size2)
141 | self.softmax3_blocks = ResidualBlock(in_channels, out_channels)
142 | self.interpolation1 = nn.UpsamplingBilinear2d(size=size1)
143 | self.softmax4_blocks = nn.Sequential(
144 | nn.BatchNorm2d(out_channels),
145 | nn.ReLU(inplace=True),
146 | nn.Conv2d(out_channels, out_channels, kernel_size=1, stride=1, bias=False),
147 | nn.BatchNorm2d(out_channels),
148 | nn.ReLU(inplace=True),
149 | nn.Conv2d(out_channels, out_channels, kernel_size=1, stride=1, bias=False),
150 | nn.Sigmoid()
151 | )
152 | self.last_blocks = ResidualBlock(in_channels, out_channels)
153 |
154 | def forward(self, x):
155 | x = self.first_residual_blocks(x)
156 | out_trunk = self.trunk_branches(x)
157 | out_mpool1 = self.mpool1(x)
158 | out_softmax1 = self.softmax1_blocks(out_mpool1)
159 | out_skip1_connection = self.skip1_connection_residual_block(out_softmax1)
160 |
161 | out_mpool2 = self.mpool2(out_softmax1)
162 | out_softmax2 = self.softmax2_blocks(out_mpool2)
163 |
164 | out_interp2 = self.interpolation2(out_softmax2) + out_softmax1
165 | out = out_interp2 + out_skip1_connection
166 |
167 | out_softmax3 = self.softmax3_blocks(out)
168 | out_interp1 = self.interpolation1(out_softmax3) + out_trunk
169 | out_softmax4 = self.softmax4_blocks(out_interp1)
170 | out = (1 + out_softmax4) * out_trunk
171 | out_last = self.last_blocks(out)
172 |
173 | return out_last
174 |
175 | class AttentionModule_stage3(nn.Module):
176 |
177 | # input image size is 14*14
178 | def __init__(self, in_channels, out_channels, size1=(14, 14)):
179 | super(AttentionModule_stage3, self).__init__()
180 | self.first_residual_blocks = ResidualBlock(in_channels, out_channels)
181 |
182 | self.trunk_branches = nn.Sequential(
183 | ResidualBlock(in_channels, out_channels),
184 | ResidualBlock(in_channels, out_channels)
185 | )
186 |
187 | self.mpool1 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
188 | self.softmax1_blocks = nn.Sequential(
189 | ResidualBlock(in_channels, out_channels),
190 | ResidualBlock(in_channels, out_channels)
191 | )
192 |
193 | self.interpolation1 = nn.UpsamplingBilinear2d(size=size1)
194 |
195 | self.softmax2_blocks = nn.Sequential(
196 | nn.BatchNorm2d(out_channels),
197 | nn.ReLU(inplace=True),
198 | nn.Conv2d(out_channels, out_channels, kernel_size=1, stride=1, bias=False),
199 | nn.BatchNorm2d(out_channels),
200 | nn.ReLU(inplace=True),
201 | nn.Conv2d(out_channels, out_channels, kernel_size=1, stride=1, bias=False),
202 | nn.Sigmoid()
203 | )
204 |
205 | self.last_blocks = ResidualBlock(in_channels, out_channels)
206 |
207 | def forward(self, x):
208 | x = self.first_residual_blocks(x)
209 | out_trunk = self.trunk_branches(x)
210 | out_mpool1 = self.mpool1(x)
211 | out_softmax1 = self.softmax1_blocks(out_mpool1)
212 |
213 | out_interp1 = self.interpolation1(out_softmax1) + out_trunk
214 | out_softmax2 = self.softmax2_blocks(out_interp1)
215 | out = (1 + out_softmax2) * out_trunk
216 | out_last = self.last_blocks(out)
217 |
218 | return out_last
219 |
220 | class ResidualAttentionNet_56(nn.Module):
221 |
222 | # for input size 112
223 | def __init__(self, feature_dim=512, drop_ratio=0.4):
224 | super(ResidualAttentionNet_56, self).__init__()
225 | self.conv1 = nn.Sequential(
226 | nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias = False),
227 | nn.BatchNorm2d(64),
228 | nn.ReLU(inplace=True)
229 | )
230 | self.mpool1 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
231 | self.residual_block1 = ResidualBlock(64, 256)
232 | self.attention_module1 = AttentionModule_stage1(256, 256)
233 | self.residual_block2 = ResidualBlock(256, 512, 2)
234 | self.attention_module2 = AttentionModule_stage2(512, 512)
235 | self.residual_block3 = ResidualBlock(512, 512, 2)
236 | self.attention_module3 = AttentionModule_stage3(512, 512)
237 | self.residual_block4 = ResidualBlock(512, 512, 2)
238 | self.residual_block5 = ResidualBlock(512, 512)
239 | self.residual_block6 = ResidualBlock(512, 512)
240 | self.output_layer = nn.Sequential(nn.BatchNorm2d(512),
241 | nn.Dropout(drop_ratio),
242 | Flatten(),
243 | nn.Linear(512 * 7 * 7, feature_dim),
244 | nn.BatchNorm1d(feature_dim))
245 |
246 | def forward(self, x):
247 | out = self.conv1(x)
248 | out = self.mpool1(out)
249 | # print(out.data)
250 | out = self.residual_block1(out)
251 | out = self.attention_module1(out)
252 | out = self.residual_block2(out)
253 | out = self.attention_module2(out)
254 | out = self.residual_block3(out)
255 | # print(out.data)
256 | out = self.attention_module3(out)
257 | out = self.residual_block4(out)
258 | out = self.residual_block5(out)
259 | out = self.residual_block6(out)
260 | out = self.output_layer(out)
261 |
262 | return out
263 |
264 | class ResidualAttentionNet_92(nn.Module):
265 |
266 | # for input size 112
267 | def __init__(self, feature_dim=512, drop_ratio=0.4):
268 | super(ResidualAttentionNet_92, self).__init__()
269 | self.conv1 = nn.Sequential(
270 | nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias = False),
271 | nn.BatchNorm2d(64),
272 | nn.ReLU(inplace=True)
273 | )
274 | self.mpool1 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
275 | self.residual_block1 = ResidualBlock(64, 256)
276 | self.attention_module1 = AttentionModule_stage1(256, 256)
277 | self.residual_block2 = ResidualBlock(256, 512, 2)
278 | self.attention_module2 = AttentionModule_stage2(512, 512)
279 | self.attention_module2_2 = AttentionModule_stage2(512, 512) # tbq add
280 | self.residual_block3 = ResidualBlock(512, 1024, 2)
281 | self.attention_module3 = AttentionModule_stage3(1024, 1024)
282 | self.attention_module3_2 = AttentionModule_stage3(1024, 1024) # tbq add
283 | self.attention_module3_3 = AttentionModule_stage3(1024, 1024) # tbq add
284 | self.residual_block4 = ResidualBlock(1024, 2048, 2)
285 | self.residual_block5 = ResidualBlock(2048, 2048)
286 | self.residual_block6 = ResidualBlock(2048, 2048)
287 | self.output_layer = nn.Sequential(nn.BatchNorm2d(2048),
288 | nn.Dropout(drop_ratio),
289 | Flatten(),
290 | nn.Linear(2048 * 7 * 7, feature_dim),
291 | nn.BatchNorm1d(feature_dim))
292 |
293 | def forward(self, x):
294 | out = self.conv1(x)
295 | out = self.mpool1(out)
296 | # print(out.data)
297 | out = self.residual_block1(out)
298 | out = self.attention_module1(out)
299 | out = self.residual_block2(out)
300 | out = self.attention_module2(out)
301 | out = self.attention_module2_2(out)
302 | out = self.residual_block3(out)
303 | # print(out.data)
304 | out = self.attention_module3(out)
305 | out = self.attention_module3_2(out)
306 | out = self.attention_module3_3(out)
307 | out = self.residual_block4(out)
308 | out = self.residual_block5(out)
309 | out = self.residual_block6(out)
310 | out = self.output_layer(out)
311 |
312 | return out
313 |
314 |
315 | if __name__ == '__main__':
316 | input = torch.Tensor(2, 3, 112, 112)
317 | net = ResidualAttentionNet_56()
318 | print(net)
319 |
320 | x = net(input)
321 | print(x.shape)
--------------------------------------------------------------------------------
/backbone/cbam.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # encoding: utf-8
3 | '''
4 | @author: wujiyang
5 | @contact: wujiyang@hust.edu.cn
6 | @file: cbam.py
7 | @time: 2019/1/14 15:33
8 | @desc: Convolutional Block Attention Module in ECCV 2018, including channel attention module and spatial attention module.
9 | '''
10 |
11 | import torch
12 | from torch import nn
13 | import time
14 |
15 | class Flatten(nn.Module):
16 | def forward(self, input):
17 | return input.view(input.size(0), -1)
18 |
19 | class SEModule(nn.Module):
20 | '''Squeeze and Excitation Module'''
21 | def __init__(self, channels, reduction):
22 | super(SEModule, self).__init__()
23 | self.avg_pool = nn.AdaptiveAvgPool2d(1)
24 | self.fc1 = nn.Conv2d(channels, channels // reduction, kernel_size=1, padding=0, bias=False)
25 | self.relu = nn.ReLU(inplace=True)
26 | self.fc2 = nn.Conv2d(channels // reduction, channels, kernel_size=1, padding=0, bias=False)
27 | self.sigmoid = nn.Sigmoid()
28 |
29 | def forward(self, x):
30 | input = x
31 | x = self.avg_pool(x)
32 | x = self.fc1(x)
33 | x = self.relu(x)
34 | x = self.fc2(x)
35 | x = self.sigmoid(x)
36 |
37 | return input * x
38 |
39 | class CAModule(nn.Module):
40 | '''Channel Attention Module'''
41 | def __init__(self, channels, reduction):
42 | super(CAModule, self).__init__()
43 | self.avg_pool = nn.AdaptiveAvgPool2d(1)
44 | self.max_pool = nn.AdaptiveMaxPool2d(1)
45 | self.shared_mlp = nn.Sequential(nn.Conv2d(channels, channels // reduction, kernel_size=1, padding=0, bias=False),
46 | nn.ReLU(inplace=True),
47 | nn.Conv2d(channels // reduction, channels, kernel_size=1, padding=0, bias=False))
48 | self.sigmoid = nn.Sigmoid()
49 |
50 | def forward(self, x):
51 | input = x
52 | avg_pool = self.avg_pool(x)
53 | max_pool = self.max_pool(x)
54 | x = self.shared_mlp(avg_pool) + self.shared_mlp(max_pool)
55 | x = self.sigmoid(x)
56 |
57 | return input * x
58 |
59 | class SAModule(nn.Module):
60 | '''Spatial Attention Module'''
61 | def __init__(self):
62 | super(SAModule, self).__init__()
63 | self.conv = nn.Conv2d(2, 1, kernel_size=3, padding=1, bias=False)
64 | self.sigmoid = nn.Sigmoid()
65 |
66 | def forward(self, x):
67 | input = x
68 | avg_c = torch.mean(x, 1, True)
69 | max_c, _ = torch.max(x, 1, True)
70 | x = torch.cat((avg_c, max_c), 1)
71 | x = self.conv(x)
72 | x = self.sigmoid(x)
73 | return input * x
74 |
75 | class BottleNeck_IR(nn.Module):
76 | '''Improved Residual Bottlenecks'''
77 | def __init__(self, in_channel, out_channel, stride, dim_match):
78 | super(BottleNeck_IR, self).__init__()
79 | self.res_layer = nn.Sequential(nn.BatchNorm2d(in_channel),
80 | nn.Conv2d(in_channel, out_channel, (3, 3), 1, 1, bias=False),
81 | nn.BatchNorm2d(out_channel),
82 | nn.PReLU(out_channel),
83 | nn.Conv2d(out_channel, out_channel, (3, 3), stride, 1, bias=False),
84 | nn.BatchNorm2d(out_channel))
85 | if dim_match:
86 | self.shortcut_layer = None
87 | else:
88 | self.shortcut_layer = nn.Sequential(
89 | nn.Conv2d(in_channel, out_channel, kernel_size=(1, 1), stride=stride, bias=False),
90 | nn.BatchNorm2d(out_channel)
91 | )
92 |
93 | def forward(self, x):
94 | shortcut = x
95 | res = self.res_layer(x)
96 |
97 | if self.shortcut_layer is not None:
98 | shortcut = self.shortcut_layer(x)
99 |
100 | return shortcut + res
101 |
102 | class BottleNeck_IR_SE(nn.Module):
103 | '''Improved Residual Bottlenecks with Squeeze and Excitation Module'''
104 | def __init__(self, in_channel, out_channel, stride, dim_match):
105 | super(BottleNeck_IR_SE, self).__init__()
106 | self.res_layer = nn.Sequential(nn.BatchNorm2d(in_channel),
107 | nn.Conv2d(in_channel, out_channel, (3, 3), 1, 1, bias=False),
108 | nn.BatchNorm2d(out_channel),
109 | nn.PReLU(out_channel),
110 | nn.Conv2d(out_channel, out_channel, (3, 3), stride, 1, bias=False),
111 | nn.BatchNorm2d(out_channel),
112 | SEModule(out_channel, 16))
113 | if dim_match:
114 | self.shortcut_layer = None
115 | else:
116 | self.shortcut_layer = nn.Sequential(
117 | nn.Conv2d(in_channel, out_channel, kernel_size=(1, 1), stride=stride, bias=False),
118 | nn.BatchNorm2d(out_channel)
119 | )
120 |
121 | def forward(self, x):
122 | shortcut = x
123 | res = self.res_layer(x)
124 |
125 | if self.shortcut_layer is not None:
126 | shortcut = self.shortcut_layer(x)
127 |
128 | return shortcut + res
129 |
130 | class BottleNeck_IR_CAM(nn.Module):
131 | '''Improved Residual Bottlenecks with Channel Attention Module'''
132 | def __init__(self, in_channel, out_channel, stride, dim_match):
133 | super(BottleNeck_IR_CAM, self).__init__()
134 | self.res_layer = nn.Sequential(nn.BatchNorm2d(in_channel),
135 | nn.Conv2d(in_channel, out_channel, (3, 3), 1, 1, bias=False),
136 | nn.BatchNorm2d(out_channel),
137 | nn.PReLU(out_channel),
138 | nn.Conv2d(out_channel, out_channel, (3, 3), stride, 1, bias=False),
139 | nn.BatchNorm2d(out_channel),
140 | CAModule(out_channel, 16))
141 | if dim_match:
142 | self.shortcut_layer = None
143 | else:
144 | self.shortcut_layer = nn.Sequential(
145 | nn.Conv2d(in_channel, out_channel, kernel_size=(1, 1), stride=stride, bias=False),
146 | nn.BatchNorm2d(out_channel)
147 | )
148 |
149 | def forward(self, x):
150 | shortcut = x
151 | res = self.res_layer(x)
152 |
153 | if self.shortcut_layer is not None:
154 | shortcut = self.shortcut_layer(x)
155 |
156 | return shortcut + res
157 |
158 | class BottleNeck_IR_SAM(nn.Module):
159 | '''Improved Residual Bottlenecks with Spatial Attention Module'''
160 | def __init__(self, in_channel, out_channel, stride, dim_match):
161 | super(BottleNeck_IR_SAM, self).__init__()
162 | self.res_layer = nn.Sequential(nn.BatchNorm2d(in_channel),
163 | nn.Conv2d(in_channel, out_channel, (3, 3), 1, 1, bias=False),
164 | nn.BatchNorm2d(out_channel),
165 | nn.PReLU(out_channel),
166 | nn.Conv2d(out_channel, out_channel, (3, 3), stride, 1, bias=False),
167 | nn.BatchNorm2d(out_channel),
168 | SAModule())
169 | if dim_match:
170 | self.shortcut_layer = None
171 | else:
172 | self.shortcut_layer = nn.Sequential(
173 | nn.Conv2d(in_channel, out_channel, kernel_size=(1, 1), stride=stride, bias=False),
174 | nn.BatchNorm2d(out_channel)
175 | )
176 |
177 | def forward(self, x):
178 | shortcut = x
179 | res = self.res_layer(x)
180 |
181 | if self.shortcut_layer is not None:
182 | shortcut = self.shortcut_layer(x)
183 |
184 | return shortcut + res
185 |
186 | class BottleNeck_IR_CBAM(nn.Module):
187 | '''Improved Residual Bottleneck with Channel Attention Module and Spatial Attention Module'''
188 | def __init__(self, in_channel, out_channel, stride, dim_match):
189 | super(BottleNeck_IR_CBAM, self).__init__()
190 | self.res_layer = nn.Sequential(nn.BatchNorm2d(in_channel),
191 | nn.Conv2d(in_channel, out_channel, (3, 3), 1, 1, bias=False),
192 | nn.BatchNorm2d(out_channel),
193 | nn.PReLU(out_channel),
194 | nn.Conv2d(out_channel, out_channel, (3, 3), stride, 1, bias=False),
195 | nn.BatchNorm2d(out_channel),
196 | CAModule(out_channel, 16),
197 | SAModule()
198 | )
199 | if dim_match:
200 | self.shortcut_layer = None
201 | else:
202 | self.shortcut_layer = nn.Sequential(
203 | nn.Conv2d(in_channel, out_channel, kernel_size=(1, 1), stride=stride, bias=False),
204 | nn.BatchNorm2d(out_channel)
205 | )
206 |
207 | def forward(self, x):
208 | shortcut = x
209 | res = self.res_layer(x)
210 |
211 | if self.shortcut_layer is not None:
212 | shortcut = self.shortcut_layer(x)
213 |
214 | return shortcut + res
215 |
216 |
217 | filter_list = [64, 64, 128, 256, 512]
218 | def get_layers(num_layers):
219 | if num_layers == 50:
220 | return [3, 4, 14, 3]
221 | elif num_layers == 100:
222 | return [3, 13, 30, 3]
223 | elif num_layers == 152:
224 | return [3, 8, 36, 3]
225 |
226 | class CBAMResNet(nn.Module):
227 | def __init__(self, num_layers, feature_dim=512, drop_ratio=0.4, mode='ir',filter_list=filter_list):
228 | super(CBAMResNet, self).__init__()
229 | assert num_layers in [50, 100, 152], 'num_layers should be 50, 100 or 152'
230 | assert mode in ['ir', 'ir_se', 'ir_cam', 'ir_sam', 'ir_cbam'], 'mode should be ir, ir_se, ir_cam, ir_sam or ir_cbam'
231 | layers = get_layers(num_layers)
232 | if mode == 'ir':
233 | block = BottleNeck_IR
234 | elif mode == 'ir_se':
235 | block = BottleNeck_IR_SE
236 | elif mode == 'ir_cam':
237 | block = BottleNeck_IR_CAM
238 | elif mode == 'ir_sam':
239 | block = BottleNeck_IR_SAM
240 | elif mode == 'ir_cbam':
241 | block = BottleNeck_IR_CBAM
242 |
243 | self.input_layer = nn.Sequential(nn.Conv2d(3, 64, (3, 3), stride=1, padding=1, bias=False),
244 | nn.BatchNorm2d(64),
245 | nn.PReLU(64))
246 | self.layer1 = self._make_layer(block, filter_list[0], filter_list[1], layers[0], stride=2)
247 | self.layer2 = self._make_layer(block, filter_list[1], filter_list[2], layers[1], stride=2)
248 | self.layer3 = self._make_layer(block, filter_list[2], filter_list[3], layers[2], stride=2)
249 | self.layer4 = self._make_layer(block, filter_list[3], filter_list[4], layers[3], stride=2)
250 |
251 | self.output_layer = nn.Sequential(nn.BatchNorm2d(512),
252 | nn.Dropout(drop_ratio),
253 | Flatten(),
254 | nn.Linear(512 * 7 * 7, feature_dim),
255 | nn.BatchNorm1d(feature_dim))
256 |
257 | # weight initialization
258 | for m in self.modules():
259 | if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
260 | nn.init.xavier_uniform_(m.weight)
261 | if m.bias is not None:
262 | nn.init.constant_(m.bias, 0.0)
263 | elif isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.BatchNorm1d):
264 | nn.init.constant_(m.weight, 1)
265 | nn.init.constant_(m.bias, 0)
266 |
267 | def _make_layer(self, block, in_channel, out_channel, blocks, stride):
268 | layers = []
269 | layers.append(block(in_channel, out_channel, stride, False))
270 | for i in range(1, blocks):
271 | layers.append(block(out_channel, out_channel, 1, True))
272 |
273 | return nn.Sequential(*layers)
274 |
275 | def forward(self, x):
276 | x = self.input_layer(x)
277 | x = self.layer1(x)
278 | x = self.layer2(x)
279 | x = self.layer3(x)
280 | x = self.layer4(x)
281 | x = self.output_layer(x)
282 |
283 | return x
284 |
285 | if __name__ == '__main__':
286 | input = torch.Tensor(2, 3, 112, 112)
287 | net = CBAMResNet(50, mode='ir')
288 |
289 | out = net(input)
290 | print(out.shape)
291 |
--------------------------------------------------------------------------------
/backbone/mobilefacenet.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # encoding: utf-8
3 | '''
4 | @author: wujiyang
5 | @contact: wujiyang@hust.edu.cn
6 | @file: mobilefacenet.py
7 | @time: 2018/12/21 15:45
8 | @desc: mobilefacenet backbone
9 | '''
10 |
11 | import torch
12 | from torch import nn
13 | import math
14 |
15 | MobileFaceNet_BottleNeck_Setting = [
16 |     # t: expansion factor, c: output channels, n: number of repeated blocks, s: stride of the first block
17 | [2, 64, 5, 2],
18 | [4, 128, 1, 2],
19 | [2, 128, 6, 1],
20 | [4, 128, 1, 2],
21 | [2, 128, 2, 1]
22 | ]
23 |
24 | class BottleNeck(nn.Module):
25 | def __init__(self, inp, oup, stride, expansion):
26 | super(BottleNeck, self).__init__()
27 | self.connect = stride == 1 and inp == oup
28 |
29 | self.conv = nn.Sequential(
30 | # 1*1 conv
31 | nn.Conv2d(inp, inp * expansion, 1, 1, 0, bias=False),
32 | nn.BatchNorm2d(inp * expansion),
33 | nn.PReLU(inp * expansion),
34 |
35 | # 3*3 depth wise conv
36 | nn.Conv2d(inp * expansion, inp * expansion, 3, stride, 1, groups=inp * expansion, bias=False),
37 | nn.BatchNorm2d(inp * expansion),
38 | nn.PReLU(inp * expansion),
39 |
40 | # 1*1 conv
41 | nn.Conv2d(inp * expansion, oup, 1, 1, 0, bias=False),
42 | nn.BatchNorm2d(oup),
43 | )
44 |
45 | def forward(self, x):
46 | if self.connect:
47 | return x + self.conv(x)
48 | else:
49 | return self.conv(x)
50 |
51 |
52 | class ConvBlock(nn.Module):
53 | def __init__(self, inp, oup, k, s, p, dw=False, linear=False):
54 | super(ConvBlock, self).__init__()
55 | self.linear = linear
56 | if dw:
57 | self.conv = nn.Conv2d(inp, oup, k, s, p, groups=inp, bias=False)
58 | else:
59 | self.conv = nn.Conv2d(inp, oup, k, s, p, bias=False)
60 |
61 | self.bn = nn.BatchNorm2d(oup)
62 | if not linear:
63 | self.prelu = nn.PReLU(oup)
64 |
65 | def forward(self, x):
66 | x = self.conv(x)
67 | x = self.bn(x)
68 | if self.linear:
69 | return x
70 | else:
71 | return self.prelu(x)
72 |
73 |
74 | class MobileFaceNet(nn.Module):
75 | def __init__(self, feature_dim=128, bottleneck_setting=MobileFaceNet_BottleNeck_Setting):
76 | super(MobileFaceNet, self).__init__()
77 | self.conv1 = ConvBlock(3, 64, 3, 2, 1)
78 | self.dw_conv1 = ConvBlock(64, 64, 3, 1, 1, dw=True)
79 |
80 | self.cur_channel = 64
81 | block = BottleNeck
82 | self.blocks = self._make_layer(block, bottleneck_setting)
83 |
84 | self.conv2 = ConvBlock(128, 512, 1, 1, 0)
85 | self.linear7 = ConvBlock(512, 512, 7, 1, 0, dw=True, linear=True)
86 | self.linear1 = ConvBlock(512, feature_dim, 1, 1, 0, linear=True)
87 |
88 | for m in self.modules():
89 | if isinstance(m, nn.Conv2d):
90 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
91 | m.weight.data.normal_(0, math.sqrt(2. / n))
92 | elif isinstance(m, nn.BatchNorm2d):
93 | m.weight.data.fill_(1)
94 | m.bias.data.zero_()
95 |
96 | def _make_layer(self, block, setting):
97 | layers = []
98 | for t, c, n, s in setting:
99 | for i in range(n):
100 | if i == 0:
101 | layers.append(block(self.cur_channel, c, s, t))
102 | else:
103 | layers.append(block(self.cur_channel, c, 1, t))
104 | self.cur_channel = c
105 |
106 | return nn.Sequential(*layers)
107 |
108 | def forward(self, x):
109 | x = self.conv1(x)
110 | x = self.dw_conv1(x)
111 | x = self.blocks(x)
112 | x = self.conv2(x)
113 | x = self.linear7(x)
114 | x = self.linear1(x)
115 | x = x.view(x.size(0), -1)
116 |
117 | return x
118 |
119 |
120 | if __name__ == "__main__":
121 | input = torch.Tensor(2, 3, 112, 112)
122 | net = MobileFaceNet()
123 | print(net)
124 |
125 | x = net(input)
126 | print(x.shape)
--------------------------------------------------------------------------------
/backbone/resnet.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # encoding: utf-8
3 | '''
4 | @author: wujiyang
5 | @contact: wujiyang@hust.edu.cn
6 | @file: resnet.py
7 | @time: 2018/12/24 14:40
8 | @desc: Original ResNet backbone, including ResNet18, ResNet34, ResNet50, ResNet101 and ResNet152, we removed the last global average pooling layer
9 | and replaced it with a fully connected layer with dimension of 512. BN is used for fast convergence.
10 | '''
11 | import torch
12 | import torch.nn as nn
13 |
14 | def ResNet18():
15 | model = ResNet(BasicBlock, [2, 2, 2, 2])
16 | return model
17 |
18 | def ResNet34():
19 | model = ResNet(BasicBlock, [3, 4, 6, 3])
20 | return model
21 |
22 | def ResNet50():
23 | model = ResNet(Bottleneck, [3, 4, 6, 3])
24 | return model
25 |
26 | def ResNet101():
27 | model = ResNet(Bottleneck, [3, 4, 23, 3])
28 | return model
29 |
30 | def ResNet152():
31 | model = ResNet(Bottleneck, [3, 8, 36, 3])
32 | return model
33 |
34 | __all__ = ['ResNet', 'ResNet18', 'ResNet34', 'ResNet50', 'ResNet101', 'ResNet152']
35 |
36 |
37 | def conv3x3(in_planes, out_planes, stride=1):
38 | """3x3 convolution with padding"""
39 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)
40 |
41 |
42 | def conv1x1(in_planes, out_planes, stride=1):
43 | """1x1 convolution"""
44 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
45 |
46 |
47 | class BasicBlock(nn.Module):
48 | expansion = 1
49 |
50 | def __init__(self, inplanes, planes, stride=1, downsample=None):
51 | super(BasicBlock, self).__init__()
52 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
53 | self.bn1 = nn.BatchNorm2d(planes)
54 | self.relu = nn.ReLU(inplace=True)
55 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
56 | self.bn2 = nn.BatchNorm2d(planes)
57 | self.downsample = downsample
58 | self.stride = stride
59 |
60 | def forward(self, x):
61 | identity = x
62 |
63 | out = self.conv1(x)
64 | out = self.bn1(out)
65 | out = self.relu(out)
66 | out = self.conv2(out)
67 | out = self.bn2(out)
68 |
69 | if self.downsample is not None:
70 | identity = self.downsample(x)
71 |
72 | out += identity
73 | out = self.relu(out)
74 |
75 | return out
76 |
77 |
78 | class Bottleneck(nn.Module):
79 | expansion = 4
80 | def __init__(self, inplanes, planes, stride=1, downsample=None):
81 | super(Bottleneck, self).__init__()
82 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, stride=1, bias=False)
83 | self.bn1 = nn.BatchNorm2d(planes)
84 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
85 | self.bn2 = nn.BatchNorm2d(planes)
86 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, stride=1, bias=False)
87 | self.bn3 = nn.BatchNorm2d(planes * self.expansion)
88 | self.relu = nn.ReLU(inplace=True)
89 | self.downsample = downsample
90 | self.stride = stride
91 |
92 | def forward(self, x):
93 | identity = x
94 |
95 | out = self.conv1(x)
96 | out = self.bn1(out)
97 | out = self.relu(out)
98 |
99 | out = self.conv2(out)
100 | out = self.bn2(out)
101 | out = self.relu(out)
102 |
103 | out = self.conv3(out)
104 | out = self.bn3(out)
105 |
106 | if self.downsample is not None:
107 | identity = self.downsample(x)
108 |
109 | out += identity
110 | out = self.relu(out)
111 |
112 | return out
113 |
114 |
115 | class Flatten(nn.Module):
116 | def forward(self, input):
117 | return input.view(input.size(0), -1)
118 |
119 |
120 | class ResNet(nn.Module):
121 |
122 | def __init__(self, block, layers, feature_dim=512, drop_ratio=0.4, zero_init_residual=False):
123 | super(ResNet, self).__init__()
124 | self.inplanes = 64
125 | self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
126 | self.bn1 = nn.BatchNorm2d(64)
127 | self.relu = nn.ReLU(inplace=True)
128 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
129 | self.layer1 = self._make_layer(block, 64, layers[0])
130 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
131 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
132 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
133 |
134 | self.output_layer = nn.Sequential(nn.BatchNorm2d(512 * block.expansion),
135 | nn.Dropout(drop_ratio),
136 | Flatten(),
137 | nn.Linear(512 * block.expansion * 7 * 7, feature_dim),
138 | nn.BatchNorm1d(feature_dim))
139 |
140 | for m in self.modules():
141 | if isinstance(m, nn.Conv2d):
142 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
143 | elif isinstance(m, nn.BatchNorm2d):
144 | nn.init.constant_(m.weight, 1)
145 | nn.init.constant_(m.bias, 0)
146 |
147 | # Zero-initialize the last BN in each residual branch,
148 | # so that the residual branch starts with zeros, and each residual block behaves like an identity.
149 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
150 | if zero_init_residual:
151 | for m in self.modules():
152 | if isinstance(m, Bottleneck):
153 | nn.init.constant_(m.bn3.weight, 0)
154 | elif isinstance(m, BasicBlock):
155 | nn.init.constant_(m.bn2.weight, 0)
156 |
157 | def _make_layer(self, block, planes, blocks, stride=1):
158 | downsample = None
159 | if stride != 1 or self.inplanes != planes * block.expansion:
160 | downsample = nn.Sequential(
161 | conv1x1(self.inplanes, planes * block.expansion, stride),
162 | nn.BatchNorm2d(planes * block.expansion),
163 | )
164 |
165 | layers = []
166 | layers.append(block(self.inplanes, planes, stride, downsample))
167 | self.inplanes = planes * block.expansion
168 | for _ in range(1, blocks):
169 | layers.append(block(self.inplanes, planes))
170 |
171 | return nn.Sequential(*layers)
172 |
173 | def forward(self, x):
174 | x = self.conv1(x)
175 | x = self.bn1(x)
176 | x = self.relu(x)
177 | x = self.maxpool(x)
178 |
179 | x = self.layer1(x)
180 | x = self.layer2(x)
181 | x = self.layer3(x)
182 | x = self.layer4(x)
183 |
184 | x = self.output_layer(x)
185 |
186 | return x
187 |
188 |
189 | if __name__ == "__main__":
190 | input = torch.Tensor(2, 3, 112, 112)
191 | net = ResNet50()
192 | print(net)
193 |
194 | x = net(input)
195 | print(x.shape)
--------------------------------------------------------------------------------
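A quick sanity check of the hard-coded 7 * 7 in output_layer: with a 112x112 input, conv1 keeps the resolution at 112 (stride 1), maxpool halves it to 56, layer1 keeps 56, and layer2 through layer4 halve it again to 28, 14 and 7. The tensor reaching output_layer is therefore (batch, 512 * block.expansion, 7, 7), and the flattened size 512 * block.expansion * 7 * 7 matches the nn.Linear above.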
/backbone/spherenet.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # encoding: utf-8
3 | '''
4 | @author: wujiyang
5 | @contact: wujiyang@hust.edu.cn
6 | @file: spherenet.py
7 | @time: 2018/12/26 10:14
8 | @desc: A 64-layer residual network structure used in SphereFace and CosFace (a 20-layer variant is also supported); for fast convergence, BN is added after every Conv layer.
9 | '''
10 |
11 | import torch
12 | import torch.nn as nn
13 |
14 | class Block(nn.Module):
15 | def __init__(self, channels):
16 | super(Block, self).__init__()
17 | self.conv1 = nn.Conv2d(channels, channels, 3, 1, 1, bias=False)
18 | self.bn1 = nn.BatchNorm2d(channels)
19 | self.prelu1 = nn.PReLU(channels)
20 | self.conv2 = nn.Conv2d(channels, channels, 3, 1, 1, bias=False)
21 | self.bn2 = nn.BatchNorm2d(channels)
22 | self.prelu2 = nn.PReLU(channels)
23 |
24 | def forward(self, x):
25 | short_cut = x
26 | x = self.conv1(x)
27 | x = self.bn1(x)
28 | x = self.prelu1(x)
29 | x = self.conv2(x)
30 | x = self.bn2(x)
31 | x = self.prelu2(x)
32 |
33 | return x + short_cut
34 |
35 |
36 | class SphereNet(nn.Module):
37 | def __init__(self, num_layers = 20, feature_dim=512):
38 | super(SphereNet, self).__init__()
39 | assert num_layers in [20, 64], 'SphereNet num_layers should be 20 or 64'
40 | if num_layers == 20:
41 | layers = [1, 2, 4, 1]
42 | elif num_layers == 64:
43 | layers = [3, 7, 16, 3]
44 | else:
45 | raise ValueError('sphere' + str(num_layers) + " IS NOT SUPPORTED! (sphere20 or sphere64)")
46 |
47 | filter_list = [3, 64, 128, 256, 512]
48 | block = Block
49 | self.layer1 = self._make_layer(block, filter_list[0], filter_list[1], layers[0], stride=2)
50 | self.layer2 = self._make_layer(block, filter_list[1], filter_list[2], layers[1], stride=2)
51 | self.layer3 = self._make_layer(block, filter_list[2], filter_list[3], layers[2], stride=2)
52 | self.layer4 = self._make_layer(block, filter_list[3], filter_list[4], layers[3], stride=2)
53 | self.fc = nn.Linear(512 * 7 * 7, feature_dim)
54 | self.last_bn = nn.BatchNorm1d(feature_dim)
55 |
56 | for m in self.modules():
57 | if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
58 | if m.bias is not None:
59 | nn.init.xavier_uniform_(m.weight)
60 | nn.init.constant_(m.bias, 0)
61 | else:
62 | nn.init.normal_(m.weight, 0, 0.01)
63 |
64 | def _make_layer(self, block, inplanes, planes, num_units, stride):
65 | layers = []
66 | layers.append(nn.Conv2d(inplanes, planes, 3, stride, 1))
67 | layers.append(nn.BatchNorm2d(planes))
68 | layers.append(nn.PReLU(planes))
69 | for i in range(num_units):
70 | layers.append(block(planes))
71 |
72 | return nn.Sequential(*layers)
73 |
74 |
75 | def forward(self, x):
76 | x = self.layer1(x)
77 | x = self.layer2(x)
78 | x = self.layer3(x)
79 | x = self.layer4(x)
80 |
81 | x = x.view(x.size(0), -1)
82 | x = self.fc(x)
83 | x = self.last_bn(x)
84 |
85 | return x
86 |
87 |
88 | if __name__ == '__main__':
89 | input = torch.Tensor(2, 3, 112, 112)
90 | net = SphereNet(num_layers=64, feature_dim=512)
91 |
92 | out = net(input)
93 | print(out.shape)
94 |
95 |
--------------------------------------------------------------------------------
/cppapi/README.md:
--------------------------------------------------------------------------------
1 | ## CppAPI
2 |
3 | libtorch for C++ deployment
4 |
5 |
--------------------------------------------------------------------------------
/cppapi/pytorch2torchscript.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # encoding: utf-8
3 | '''
4 | @author: wujiyang
5 | @contact: wujiyang@hust.edu.cn
6 | @file: pytorch2torchscript.py
7 | @time: 2019/2/18 17:45
8 | @desc: convert your pytorch model to torch script and save to file
9 | '''
10 |
11 |
--------------------------------------------------------------------------------
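pytorch2torchscript.py above contains only its header docstring. A minimal sketch (not the author's implementation) of what the described conversion could look like, assuming a MobileFaceNet checkpoint saved with the 'net_state_dict' key used by the training and evaluation scripts in this repo; the checkpoint path and output filename are placeholders, and the saved .pt file is what the libtorch C++ API mentioned in cppapi/README.md would load via torch::jit::load:

import torch
from backbone import mobilefacenet

net = mobilefacenet.MobileFaceNet()
ckpt = torch.load('path/to/Iter_xxx_net.ckpt', map_location='cpu')  # placeholder checkpoint path
net.load_state_dict(ckpt['net_state_dict'])
net.eval()

example = torch.rand(1, 3, 112, 112)       # an aligned 112x112 face crop
traced = torch.jit.trace(net, example)     # record the forward pass as a TorchScript graph
traced.save('mobilefacenet_script.pt')     # loadable from C++ with torch::jit::load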
/dataset/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # encoding: utf-8
3 | '''
4 | @author: wujiyang
5 | @contact: wujiyang@hust.edu.cn
6 | @file: __init__.py
7 | @time: 2018/12/21 15:31
8 | @desc:
9 | '''
--------------------------------------------------------------------------------
/dataset/agedb.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # encoding: utf-8
3 | '''
4 | @author: wujiyang
5 | @contact: wujiyang@hust.edu.cn
6 | @file: agedb.py
7 | @time: 2018/12/25 18:43
8 | @desc: AgeDB-30 test data loader; the AgeDB test protocol is the same as LFW
9 | '''
10 |
11 | import numpy as np
12 | import cv2
13 | import os
14 | import torch.utils.data as data
15 |
16 | import torch
17 | import torchvision.transforms as transforms
18 |
19 | def img_loader(path):
20 | try:
21 | with open(path, 'rb') as f:
22 | img = cv2.imread(path)
23 | if len(img.shape) == 2:
24 | img = np.stack([img] * 3, 2)
25 | return img
26 | except IOError:
27 | print('Cannot load image ' + path)
28 |
29 | class AgeDB30(data.Dataset):
30 | def __init__(self, root, file_list, transform=None, loader=img_loader):
31 |
32 | self.root = root
33 | self.file_list = file_list
34 | self.transform = transform
35 | self.loader = loader
36 | self.nameLs = []
37 | self.nameRs = []
38 | self.folds = []
39 | self.flags = []
40 |
41 | with open(file_list) as f:
42 | pairs = f.read().splitlines()
43 | for i, p in enumerate(pairs):
44 | p = p.split(' ')
45 | nameL = p[0]
46 | nameR = p[1]
47 | fold = i // 600
48 | flag = int(p[2])
49 |
50 | self.nameLs.append(nameL)
51 | self.nameRs.append(nameR)
52 | self.folds.append(fold)
53 | self.flags.append(flag)
54 |
55 | def __getitem__(self, index):
56 |
57 | img_l = self.loader(os.path.join(self.root, self.nameLs[index]))
58 | img_r = self.loader(os.path.join(self.root, self.nameRs[index]))
59 | imglist = [img_l, cv2.flip(img_l, 1), img_r, cv2.flip(img_r, 1)]
60 |
61 | if self.transform is not None:
62 | for i in range(len(imglist)):
63 | imglist[i] = self.transform(imglist[i])
64 |
65 | imgs = imglist
66 | return imgs
67 | else:
68 | imgs = [torch.from_numpy(i) for i in imglist]
69 | return imgs
70 |
71 | def __len__(self):
72 | return len(self.nameLs)
73 |
74 |
75 | if __name__ == '__main__':
76 | root = '/media/sda/AgeDB-30/agedb30_align_112'
77 | file_list = '/media/sda/AgeDB-30/agedb_30_pair.txt'
78 |
79 | transform = transforms.Compose([
80 | transforms.ToTensor(), # range [0, 255] -> [0.0,1.0]
81 | transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)) # range [0.0, 1.0] -> [-1.0,1.0]
82 | ])
83 |
84 | dataset = AgeDB30(root, file_list, transform=transform)
85 | trainloader = data.DataLoader(dataset, batch_size=64, shuffle=False, num_workers=2, drop_last=False)
86 | for data in trainloader:
87 | for d in data:
88 | print(d[0].shape)
--------------------------------------------------------------------------------
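For reference, the pair list consumed by AgeDB30 (and by the similar CFP_FP and LFW_2 loaders further down) is a plain text file with one space-separated pair per line: two image paths relative to root followed by an integer flag, 600 consecutive lines per fold for AgeDB-30 and LFW, 700 for CFP-FP. The evaluation scripts count flag 1 as a genuine pair and flag -1 as an impostor pair. Two illustrative lines (the file names are made up):

PersonA/PersonA_0001.jpg PersonA/PersonA_0002.jpg 1
PersonA/PersonA_0001.jpg PersonB/PersonB_0003.jpg -1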
/dataset/casia_webface.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # encoding: utf-8
3 | '''
4 | @author: wujiyang
5 | @contact: wujiyang@hust.edu.cn
6 | @file: casia_webface.py
7 | @time: 2018/12/21 19:09
8 | @desc: CASIA-WebFace dataset loader
9 | '''
10 |
11 | import torchvision.transforms as transforms
12 | import torch.utils.data as data
13 | import numpy as np
14 | import cv2
15 | import os
16 | import torch
17 |
18 |
19 | def img_loader(path):
20 | try:
21 | with open(path, 'rb') as f:
22 | img = cv2.imread(path)
23 | if len(img.shape) == 2:
24 | img = np.stack([img] * 3, 2)
25 | return img
26 | except IOError:
27 | print('Cannot load image ' + path)
28 |
29 |
30 | class CASIAWebFace(data.Dataset):
31 | def __init__(self, root, file_list, transform=None, loader=img_loader):
32 |
33 | self.root = root
34 | self.transform = transform
35 | self.loader = loader
36 |
37 | image_list = []
38 | label_list = []
39 | with open(file_list) as f:
40 | img_label_list = f.read().splitlines()
41 | for info in img_label_list:
42 | image_path, label_name = info.split(' ')
43 | image_list.append(image_path)
44 | label_list.append(int(label_name))
45 |
46 | self.image_list = image_list
47 | self.label_list = label_list
48 | self.class_nums = len(np.unique(self.label_list))
49 | print("dataset size: ", len(self.image_list), '/', self.class_nums)
50 |
51 | def __getitem__(self, index):
52 | img_path = self.image_list[index]
53 | label = self.label_list[index]
54 |
55 | img = self.loader(os.path.join(self.root, img_path))
56 |
57 | # random flip with ratio of 0.5
58 | flip = np.random.choice(2) * 2 - 1
59 | if flip == 1:
60 | img = cv2.flip(img, 1)
61 |
62 | if self.transform is not None:
63 | img = self.transform(img)
64 | else:
65 | img = torch.from_numpy(img)
66 |
67 | return img, label
68 |
69 | def __len__(self):
70 | return len(self.image_list)
71 |
72 |
73 | if __name__ == '__main__':
74 | root = 'D:/data/webface_align_112'
75 | file_list = 'D:/data/webface_align_train.list'
76 |
77 | transform = transforms.Compose([
78 | transforms.ToTensor(), # range [0, 255] -> [0.0,1.0]
79 | transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)) # range [0.0, 1.0] -> [-1.0,1.0]
80 | ])
81 | dataset = CASIAWebFace(root, file_list, transform=transform)
82 | trainloader = data.DataLoader(dataset, batch_size=64, shuffle=True, num_workers=2, drop_last=False)
83 | print(len(dataset))
84 | for data in trainloader:
85 | print(data[0].shape)
--------------------------------------------------------------------------------
/dataset/cfp.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # encoding: utf-8
3 | '''
4 | @author: wujiyang
5 | @contact: wujiyang@hust.edu.cn
6 | @file: cfp.py
7 | @time: 2018/12/26 16:19
8 | @desc: the CFP-FP test dataset loader; it is similar to LFW and AgeDB, except that it has 700 pairs per fold
9 | '''
10 |
11 |
12 | import numpy as np
13 | import cv2
14 | import os
15 | import torch.utils.data as data
16 |
17 | import torch
18 | import torchvision.transforms as transforms
19 |
20 | def img_loader(path):
21 | try:
22 | with open(path, 'rb') as f:
23 | img = cv2.imread(path)
24 | if len(img.shape) == 2:
25 | img = np.stack([img] * 3, 2)
26 | return img
27 | except IOError:
28 | print('Cannot load image ' + path)
29 |
30 | class CFP_FP(data.Dataset):
31 | def __init__(self, root, file_list, transform=None, loader=img_loader):
32 |
33 | self.root = root
34 | self.file_list = file_list
35 | self.transform = transform
36 | self.loader = loader
37 | self.nameLs = []
38 | self.nameRs = []
39 | self.folds = []
40 | self.flags = []
41 |
42 | with open(file_list) as f:
43 | pairs = f.read().splitlines()
44 | for i, p in enumerate(pairs):
45 | p = p.split(' ')
46 | nameL = p[0]
47 | nameR = p[1]
48 | fold = i // 700
49 | flag = int(p[2])
50 |
51 | self.nameLs.append(nameL)
52 | self.nameRs.append(nameR)
53 | self.folds.append(fold)
54 | self.flags.append(flag)
55 |
56 | def __getitem__(self, index):
57 |
58 | img_l = self.loader(os.path.join(self.root, self.nameLs[index]))
59 | img_r = self.loader(os.path.join(self.root, self.nameRs[index]))
60 | imglist = [img_l, cv2.flip(img_l, 1), img_r, cv2.flip(img_r, 1)]
61 |
62 | if self.transform is not None:
63 | for i in range(len(imglist)):
64 | imglist[i] = self.transform(imglist[i])
65 |
66 | imgs = imglist
67 | return imgs
68 | else:
69 | imgs = [torch.from_numpy(i) for i in imglist]
70 | return imgs
71 |
72 | def __len__(self):
73 | return len(self.nameLs)
74 |
75 |
76 | if __name__ == '__main__':
77 | root = '/media/sda/CFP-FP/CFP_FP_aligned_112'
78 | file_list = '/media/sda/CFP-FP/cfp-fp-pair.txt'
79 |
80 | transform = transforms.Compose([
81 | transforms.ToTensor(), # range [0, 255] -> [0.0,1.0]
82 | transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)) # range [0.0, 1.0] -> [-1.0,1.0]
83 | ])
84 |
85 | dataset = CFP_FP(root, file_list, transform=transform)
86 | trainloader = data.DataLoader(dataset, batch_size=64, shuffle=False, num_workers=2, drop_last=False)
87 | for data in trainloader:
88 | for d in data:
89 | print(d[0].shape)
--------------------------------------------------------------------------------
/dataset/lfw.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # encoding: utf-8
3 | '''
4 | @author: wujiyang
5 | @contact: wujiyang@hust.edu.cn
6 | @file: lfw.py
7 | @time: 2018/12/22 10:00
8 | @desc: lfw dataset loader
9 | '''
10 |
11 | import numpy as np
12 | import cv2
13 | import os
14 | import torch.utils.data as data
15 |
16 | import torch
17 | import torchvision.transforms as transforms
18 |
19 | def img_loader(path):
20 | try:
21 | with open(path, 'rb') as f:
22 | img = cv2.imread(path)
23 | if len(img.shape) == 2:
24 | img = np.stack([img] * 3, 2)
25 | return img
26 | except IOError:
27 | print('Cannot load image ' + path)
28 |
29 | class LFW(data.Dataset):
30 | def __init__(self, root, file_list, transform=None, loader=img_loader):
31 |
32 | self.root = root
33 | self.file_list = file_list
34 | self.transform = transform
35 | self.loader = loader
36 | self.nameLs = []
37 | self.nameRs = []
38 | self.folds = []
39 | self.flags = []
40 |
41 | with open(file_list) as f:
42 | pairs = f.read().splitlines()[1:]
43 | for i, p in enumerate(pairs):
44 | p = p.split('\t')
45 | if len(p) == 3:
46 | nameL = p[0] + '/' + p[0] + '_' + '{:04}.jpg'.format(int(p[1]))
47 | nameR = p[0] + '/' + p[0] + '_' + '{:04}.jpg'.format(int(p[2]))
48 | fold = i // 600
49 | flag = 1
50 | elif len(p) == 4:
51 | nameL = p[0] + '/' + p[0] + '_' + '{:04}.jpg'.format(int(p[1]))
52 | nameR = p[2] + '/' + p[2] + '_' + '{:04}.jpg'.format(int(p[3]))
53 | fold = i // 600
54 | flag = -1
55 | self.nameLs.append(nameL)
56 | self.nameRs.append(nameR)
57 | self.folds.append(fold)
58 | self.flags.append(flag)
59 |
60 | def __getitem__(self, index):
61 |
62 | img_l = self.loader(os.path.join(self.root, self.nameLs[index]))
63 | img_r = self.loader(os.path.join(self.root, self.nameRs[index]))
64 | imglist = [img_l, cv2.flip(img_l, 1), img_r, cv2.flip(img_r, 1)]
65 |
66 | if self.transform is not None:
67 | for i in range(len(imglist)):
68 | imglist[i] = self.transform(imglist[i])
69 |
70 | imgs = imglist
71 | return imgs
72 | else:
73 | imgs = [torch.from_numpy(i) for i in imglist]
74 | return imgs
75 |
76 | def __len__(self):
77 | return len(self.nameLs)
78 |
79 |
80 | if __name__ == '__main__':
81 | root = 'D:/data/lfw_align_112'
82 | file_list = 'D:/data/pairs.txt'
83 |
84 | transform = transforms.Compose([
85 | transforms.ToTensor(), # range [0, 255] -> [0.0,1.0]
86 | transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)) # range [0.0, 1.0] -> [-1.0,1.0]
87 | ])
88 |
89 | dataset = LFW(root, file_list, transform=transform)
90 | #dataset = LFW(root, file_list)
91 | trainloader = data.DataLoader(dataset, batch_size=64, shuffle=False, num_workers=2, drop_last=False)
92 | print(len(dataset))
93 | for data in trainloader:
94 | for d in data:
95 | print(d[0].shape)
--------------------------------------------------------------------------------
/dataset/lfw_2.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # encoding: utf-8
3 | '''
4 | @author: wujiyang
5 | @contact: wujiyang@hust.edu.cn
6 | @file: lfw_2.py
7 | @time: 2019/2/19 16:59
8 | @desc: LFW dataset in the InsightFace format, just like AgeDB and CFP-FP
9 | '''
10 |
11 |
12 | import numpy as np
13 | import cv2
14 | import os
15 | import torch.utils.data as data
16 |
17 | import torch
18 | import torchvision.transforms as transforms
19 |
20 | def img_loader(path):
21 | try:
22 | with open(path, 'rb') as f:
23 | img = cv2.imread(path)
24 | if len(img.shape) == 2:
25 | img = np.stack([img] * 3, 2)
26 | return img
27 | except IOError:
28 | print('Cannot load image ' + path)
29 |
30 | class LFW_2(data.Dataset):
31 | def __init__(self, root, file_list, transform=None, loader=img_loader):
32 |
33 | self.root = root
34 | self.file_list = file_list
35 | self.transform = transform
36 | self.loader = loader
37 | self.nameLs = []
38 | self.nameRs = []
39 | self.folds = []
40 | self.flags = []
41 |
42 | with open(file_list) as f:
43 | pairs = f.read().splitlines()
44 | for i, p in enumerate(pairs):
45 | p = p.split(' ')
46 | nameL = p[0]
47 | nameR = p[1]
48 | fold = i // 600
49 | flag = int(p[2])
50 |
51 | self.nameLs.append(nameL)
52 | self.nameRs.append(nameR)
53 | self.folds.append(fold)
54 | self.flags.append(flag)
55 |
56 | def __getitem__(self, index):
57 |
58 | img_l = self.loader(os.path.join(self.root, self.nameLs[index]))
59 | img_r = self.loader(os.path.join(self.root, self.nameRs[index]))
60 | imglist = [img_l, cv2.flip(img_l, 1), img_r, cv2.flip(img_r, 1)]
61 |
62 | if self.transform is not None:
63 | for i in range(len(imglist)):
64 | imglist[i] = self.transform(imglist[i])
65 |
66 | imgs = imglist
67 | return imgs
68 | else:
69 | imgs = [torch.from_numpy(i) for i in imglist]
70 | return imgs
71 |
72 | def __len__(self):
73 | return len(self.nameLs)
74 |
75 |
76 | if __name__ == '__main__':
77 | root = '/media/sda/insightface_emore/lfw'
78 | file_list = '/media/sda/insightface_emore/pair_lfw.txt'
79 |
80 | transform = transforms.Compose([
81 | transforms.ToTensor(), # range [0, 255] -> [0.0,1.0]
82 | transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)) # range [0.0, 1.0] -> [-1.0,1.0]
83 | ])
84 |
85 | dataset = LFW_2(root, file_list, transform=transform)
86 | trainloader = data.DataLoader(dataset, batch_size=64, shuffle=False, num_workers=2, drop_last=False)
87 | for data in trainloader:
88 | for d in data:
89 | print(d[0].shape)
--------------------------------------------------------------------------------
/dataset/megaface.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # encoding: utf-8
3 | '''
4 | @author: wujiyang
5 | @contact: wujiyang@hust.edu.cn
6 | @file: megaface.py
7 | @time: 2018/12/24 16:29
8 | @desc:
9 | '''
10 |
11 | import torchvision.transforms as transforms
12 | import torch.utils.data as data
13 | import numpy as np
14 | import cv2
15 | import os
16 | import torch
17 |
18 | def img_loader(path):
19 | try:
20 | with open(path, 'rb') as f:
21 | img = cv2.imread(path)
22 | if len(img.shape) == 2:
23 | img = np.stack([img] * 3, 2)
24 | return img
25 | except IOError:
26 | print('Cannot load image ' + path)
27 |
28 |
29 | class MegaFace(data.Dataset):
30 | def __init__(self, facescrub_dir, megaface_dir, transform=None, loader=img_loader):
31 |
32 | self.transform = transform
33 | self.loader = loader
34 |
35 | test_image_file_list = []
36 | print('Scanning files under facescrub and megaface...')
37 | for root, dirs, files in os.walk(facescrub_dir):
38 | for e in files:
39 | filename = os.path.join(root, e)
40 | ext = os.path.splitext(filename)[1].lower()
41 | if ext in ('.png', '.bmp', '.jpg', '.jpeg'):
42 | test_image_file_list.append(filename)
43 | for root, dirs, files in os.walk(megaface_dir):
44 | for e in files:
45 | filename = os.path.join(root, e)
46 | ext = os.path.splitext(filename)[1].lower()
47 | if ext in ('.png', '.bmp', '.jpg', '.jpeg'):
48 | test_image_file_list.append(filename)
49 |
50 | self.image_list = test_image_file_list
51 |
52 | def __getitem__(self, index):
53 | img_path = self.image_list[index]
54 | img = self.loader(img_path)
55 |
56 | # horizontally flip the image (disabled)
57 | #img = cv2.flip(img, 1)
58 |
59 | if self.transform is not None:
60 | img = self.transform(img)
61 | else:
62 | img = torch.from_numpy(img)
63 |
64 | return img, img_path
65 |
66 | def __len__(self):
67 | return len(self.image_list)
68 |
69 |
70 | if __name__ == '__main__':
71 | facescrub = '/media/sda/megaface_test_kit/facescrub_align_112/'
72 | megaface = '/media/sda/megaface_test_kit/megaface_align_112/'
73 |
74 | transform = transforms.Compose([
75 | transforms.ToTensor(), # range [0, 255] -> [0.0,1.0]
76 | transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)) # range [0.0, 1.0] -> [-1.0,1.0]
77 | ])
78 | dataset = MegaFace(facescrub, megaface, transform=transform)
79 | trainloader = data.DataLoader(dataset, batch_size=64, shuffle=False, num_workers=2, drop_last=False)
80 | print(len(dataset))
81 | for data in trainloader:
82 | print(data.shape)
--------------------------------------------------------------------------------
/eval_agedb30.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # encoding: utf-8
3 | '''
4 | @author: wujiyang
5 | @contact: wujiyang@hust.edu.cn
6 | @file: eval_agedb30.py
7 | @time: 2018/12/25 19:05
8 | @desc: The AgeDB-30 test protocol is the same as LFW, so the code is copied from eval_lfw.py
9 | '''
10 |
11 |
12 | import numpy as np
13 | import scipy.io
14 | import os
15 | import torch.utils.data
16 | from backbone import mobilefacenet, resnet, arcfacenet, cbam
17 | from dataset.agedb import AgeDB30
18 | import torchvision.transforms as transforms
19 | from torch.nn import DataParallel
20 | import argparse
21 |
22 | def getAccuracy(scores, flags, threshold):
23 | p = np.sum(scores[flags == 1] > threshold)
24 | n = np.sum(scores[flags == -1] < threshold)
25 | return 1.0 * (p + n) / len(scores)
26 |
27 | def getThreshold(scores, flags, thrNum):
28 | accuracys = np.zeros((2 * thrNum + 1, 1))
29 | thresholds = np.arange(-thrNum, thrNum + 1) * 1.0 / thrNum
30 | for i in range(2 * thrNum + 1):
31 | accuracys[i] = getAccuracy(scores, flags, thresholds[i])
32 | max_index = np.squeeze(accuracys == np.max(accuracys))
33 | bestThreshold = np.mean(thresholds[max_index])
34 | return bestThreshold
35 |
36 | def evaluation_10_fold(feature_path='./result/cur_epoch_agedb_result.mat'):
37 | ACCs = np.zeros(10)
38 | result = scipy.io.loadmat(feature_path)
39 | for i in range(10):
40 | fold = result['fold']
41 | flags = result['flag']
42 | featureLs = result['fl']
43 | featureRs = result['fr']
44 |
45 | valFold = fold != i
46 | testFold = fold == i
47 | flags = np.squeeze(flags)
48 |
49 | mu = np.mean(np.concatenate((featureLs[valFold[0], :], featureRs[valFold[0], :]), 0), 0)
50 | mu = np.expand_dims(mu, 0)
51 | featureLs = featureLs - mu
52 | featureRs = featureRs - mu
53 | featureLs = featureLs / np.expand_dims(np.sqrt(np.sum(np.power(featureLs, 2), 1)), 1)
54 | featureRs = featureRs / np.expand_dims(np.sqrt(np.sum(np.power(featureRs, 2), 1)), 1)
55 |
56 | scores = np.sum(np.multiply(featureLs, featureRs), 1)
57 | threshold = getThreshold(scores[valFold[0]], flags[valFold[0]], 10000)
58 | ACCs[i] = getAccuracy(scores[testFold[0]], flags[testFold[0]], threshold)
59 |
60 | return ACCs
61 |
62 | def loadModel(data_root, file_list, backbone_net, gpus='0', resume=None):
63 |
64 | if backbone_net == 'MobileFace':
65 | net = mobilefacenet.MobileFaceNet()
66 | elif backbone_net == 'CBAM_50':
67 | net = cbam.CBAMResNet(50, feature_dim=args.feature_dim, mode='ir')
68 | elif backbone_net == 'CBAM_50_SE':
69 | net = cbam.CBAMResNet(50, feature_dim=args.feature_dim, mode='ir_se')
70 | elif backbone_net == 'CBAM_100':
71 | net = cbam.CBAMResNet(100, feature_dim=args.feature_dim, mode='ir')
72 | elif backbone_net == 'CBAM_100_SE':
73 | net = cbam.CBAMResNet(100, feature_dim=args.feature_dim, mode='ir_se')
74 | else:
75 | print(backbone_net, ' is not available!')
76 |
77 | # gpu init
78 | multi_gpus = False
79 | if len(gpus.split(',')) > 1:
80 | multi_gpus = True
81 | os.environ['CUDA_VISIBLE_DEVICES'] = gpus
82 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
83 |
84 | net.load_state_dict(torch.load(resume)['net_state_dict'])
85 |
86 | if multi_gpus:
87 | net = DataParallel(net).to(device)
88 | else:
89 | net = net.to(device)
90 |
91 | transform = transforms.Compose([
92 | transforms.ToTensor(), # range [0, 255] -> [0.0,1.0]
93 | transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)) # range [0.0, 1.0] -> [-1.0,1.0]
94 | ])
95 | agedb_dataset = AgeDB30(data_root, file_list, transform=transform)
96 | agedb_loader = torch.utils.data.DataLoader(agedb_dataset, batch_size=128,
97 | shuffle=False, num_workers=2, drop_last=False)
98 |
99 | return net.eval(), device, agedb_dataset, agedb_loader
100 |
101 | def getFeatureFromTorch(feature_save_dir, net, device, data_set, data_loader):
102 | featureLs = None
103 | featureRs = None
104 | count = 0
105 | for data in data_loader:
106 | for i in range(len(data)):
107 | data[i] = data[i].to(device)
108 | count += data[0].size(0)
109 | #print('extracing deep features from the face pair {}...'.format(count))
110 | with torch.no_grad():
111 | res = [net(d).data.cpu().numpy() for d in data]
112 | featureL = np.concatenate((res[0], res[1]), 1)
113 | featureR = np.concatenate((res[2], res[3]), 1)
114 | # print(featureL.shape, featureR.shape)
115 | if featureLs is None:
116 | featureLs = featureL
117 | else:
118 | featureLs = np.concatenate((featureLs, featureL), 0)
119 | if featureRs is None:
120 | featureRs = featureR
121 | else:
122 | featureRs = np.concatenate((featureRs, featureR), 0)
123 | # print(featureLs.shape, featureRs.shape)
124 |
125 | result = {'fl': featureLs, 'fr': featureRs, 'fold': data_set.folds, 'flag': data_set.flags}
126 | scipy.io.savemat(feature_save_dir, result)
127 |
128 |
129 | if __name__ == '__main__':
130 | parser = argparse.ArgumentParser(description='Testing')
131 | parser.add_argument('--root', type=str, default='/media/sda/AgeDB-30/agedb30_align_112', help='The path of the AgeDB-30 data')
132 | parser.add_argument('--file_list', type=str, default='/media/sda/AgeDB-30/agedb_30_pair.txt', help='The path of the AgeDB-30 pair list')
133 | parser.add_argument('--resume', type=str, default='./model/SERES100_SERES100_IR_20190528_132635/Iter_342000_net.ckpt', help='The path of the saved model')
134 | parser.add_argument('--backbone_net', type=str, default='CBAM_100_SE', help='MobileFace, CBAM_50, CBAM_50_SE, CBAM_100, CBAM_100_SE')
135 | parser.add_argument('--feature_dim', type=int, default=512, help='feature dimension')
136 | parser.add_argument('--feature_save_path', type=str, default='./result/cur_epoch_agedb_result.mat',
137 | help='The path of the extract features save, must be .mat file')
138 | parser.add_argument('--gpus', type=str, default='2,3', help='gpu list')
139 | args = parser.parse_args()
140 |
141 | net, device, agedb_dataset, agedb_loader = loadModel(args.root, args.file_list, args.backbone_net, args.gpus, args.resume)
142 | getFeatureFromTorch(args.feature_save_path, net, device, agedb_dataset, agedb_loader)
143 | ACCs = evaluation_10_fold(args.feature_save_path)
144 | for i in range(len(ACCs)):
145 | print('{} {:.2f}'.format(i + 1, ACCs[i] * 100))
146 | print('--------')
147 | print('AVE {:.4f}'.format(np.mean(ACCs) * 100))
148 |
149 |
--------------------------------------------------------------------------------
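The getThreshold / getAccuracy pair above (duplicated in eval_cfp.py and eval_lfw.py) performs a brute-force threshold sweep on the nine training folds and then applies the best threshold to the held-out fold. With $s_k$ the cosine score of pair $k$ and $f_k \in \{+1, -1\}$ its flag, the swept quantity is

$$\mathrm{acc}(t) = \frac{\#\{k : f_k = +1,\ s_k > t\} + \#\{k : f_k = -1,\ s_k < t\}}{N}, \qquad t \in \{-1,\ -1 + \tfrac{1}{10000},\ \dots,\ 1\}$$

and the reported number is the mean of the ten per-fold accuracies.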
/eval_cfp.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # encoding: utf-8
3 | '''
4 | @author: wujiyang
5 | @contact: wujiyang@hust.edu.cn
6 | @file: eval_cfp.py
7 | @time: 2018/12/26 16:23
8 | @desc: this code is very similar to eval_lfw.py and eval_agedb30.py
9 | '''
10 |
11 |
12 | import numpy as np
13 | import scipy.io
14 | import os
15 | import torch.utils.data
16 | from backbone import mobilefacenet, resnet, arcfacenet, cbam
17 | from dataset.cfp import CFP_FP
18 | import torchvision.transforms as transforms
19 | from torch.nn import DataParallel
20 | import argparse
21 |
22 | def getAccuracy(scores, flags, threshold):
23 | p = np.sum(scores[flags == 1] > threshold)
24 | n = np.sum(scores[flags == -1] < threshold)
25 | return 1.0 * (p + n) / len(scores)
26 |
27 | def getThreshold(scores, flags, thrNum):
28 | accuracys = np.zeros((2 * thrNum + 1, 1))
29 | thresholds = np.arange(-thrNum, thrNum + 1) * 1.0 / thrNum
30 | for i in range(2 * thrNum + 1):
31 | accuracys[i] = getAccuracy(scores, flags, thresholds[i])
32 | max_index = np.squeeze(accuracys == np.max(accuracys))
33 | bestThreshold = np.mean(thresholds[max_index])
34 | return bestThreshold
35 |
36 | def evaluation_10_fold(feature_path='./result/cur_epoch_cfp_result.mat'):
37 | ACCs = np.zeros(10)
38 | result = scipy.io.loadmat(feature_path)
39 | for i in range(10):
40 | fold = result['fold']
41 | flags = result['flag']
42 | featureLs = result['fl']
43 | featureRs = result['fr']
44 |
45 | valFold = fold != i
46 | testFold = fold == i
47 | flags = np.squeeze(flags)
48 |
49 | mu = np.mean(np.concatenate((featureLs[valFold[0], :], featureRs[valFold[0], :]), 0), 0)
50 | mu = np.expand_dims(mu, 0)
51 | featureLs = featureLs - mu
52 | featureRs = featureRs - mu
53 | featureLs = featureLs / np.expand_dims(np.sqrt(np.sum(np.power(featureLs, 2), 1)), 1)
54 | featureRs = featureRs / np.expand_dims(np.sqrt(np.sum(np.power(featureRs, 2), 1)), 1)
55 |
56 | scores = np.sum(np.multiply(featureLs, featureRs), 1)
57 | threshold = getThreshold(scores[valFold[0]], flags[valFold[0]], 10000)
58 | ACCs[i] = getAccuracy(scores[testFold[0]], flags[testFold[0]], threshold)
59 |
60 | return ACCs
61 |
62 | def loadModel(data_root, file_list, backbone_net, gpus='0', resume=None):
63 |
64 | if backbone_net == 'MobileFace':
65 | net = mobilefacenet.MobileFaceNet()
66 | elif backbone_net == 'CBAM_50':
67 | net = cbam.CBAMResNet(50, feature_dim=args.feature_dim, mode='ir')
68 | elif backbone_net == 'CBAM_50_SE':
69 | net = cbam.CBAMResNet(50, feature_dim=args.feature_dim, mode='ir_se')
70 | elif backbone_net == 'CBAM_100':
71 | net = cbam.CBAMResNet(100, feature_dim=args.feature_dim, mode='ir')
72 | elif backbone_net == 'CBAM_100_SE':
73 | net = cbam.CBAMResNet(100, feature_dim=args.feature_dim, mode='ir_se')
74 | else:
75 | print(backbone_net, ' is not available!')
76 |
77 | # gpu init
78 | multi_gpus = False
79 | if len(gpus.split(',')) > 1:
80 | multi_gpus = True
81 | os.environ['CUDA_VISIBLE_DEVICES'] = gpus
82 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
83 |
84 | net.load_state_dict(torch.load(resume)['net_state_dict'])
85 |
86 | if multi_gpus:
87 | net = DataParallel(net).to(device)
88 | else:
89 | net = net.to(device)
90 |
91 | transform = transforms.Compose([
92 | transforms.ToTensor(), # range [0, 255] -> [0.0,1.0]
93 | transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)) # range [0.0, 1.0] -> [-1.0,1.0]
94 | ])
95 | cfp_dataset = CFP_FP(data_root, file_list, transform=transform)
96 | cfp_loader = torch.utils.data.DataLoader(cfp_dataset, batch_size=128,
97 | shuffle=False, num_workers=4, drop_last=False)
98 |
99 | return net.eval(), device, cfp_dataset, cfp_loader
100 |
101 | def getFeatureFromTorch(feature_save_dir, net, device, data_set, data_loader):
102 | featureLs = None
103 | featureRs = None
104 | count = 0
105 | for data in data_loader:
106 | for i in range(len(data)):
107 | data[i] = data[i].to(device)
108 | count += data[0].size(0)
109 | #print('extracing deep features from the face pair {}...'.format(count))
110 | with torch.no_grad():
111 | res = [net(d).data.cpu().numpy() for d in data]
112 | featureL = np.concatenate((res[0], res[1]), 1)
113 | featureR = np.concatenate((res[2], res[3]), 1)
114 | # print(featureL.shape, featureR.shape)
115 | if featureLs is None:
116 | featureLs = featureL
117 | else:
118 | featureLs = np.concatenate((featureLs, featureL), 0)
119 | if featureRs is None:
120 | featureRs = featureR
121 | else:
122 | featureRs = np.concatenate((featureRs, featureR), 0)
123 | # print(featureLs.shape, featureRs.shape)
124 |
125 | result = {'fl': featureLs, 'fr': featureRs, 'fold': data_set.folds, 'flag': data_set.flags}
126 | scipy.io.savemat(feature_save_dir, result)
127 |
128 |
129 | if __name__ == '__main__':
130 | parser = argparse.ArgumentParser(description='Testing')
131 | parser.add_argument('--root', type=str, default='/media/sda/CFP-FP/cfp_fp_aligned_112', help='The path of the CFP-FP data')
132 | parser.add_argument('--file_list', type=str, default='/media/sda/CFP-FP/cfp_fp_pair.txt', help='The path of the CFP-FP pair list')
133 | parser.add_argument('--resume', type=str, default='./model/SERES100_SERES100_IR_20190528_132635/Iter_342000_net.ckpt', help='The path of the saved model')
134 | parser.add_argument('--backbone_net', type=str, default='CBAM_100_SE', help='MobileFace, CBAM_50, CBAM_50_SE, CBAM_100, CBAM_100_SE')
135 | parser.add_argument('--feature_dim', type=int, default=512, help='feature dimension')
136 | parser.add_argument('--feature_save_path', type=str, default='./result/cur_epoch_cfp_result.mat',
137 | help='The path of the extract features save, must be .mat file')
138 | parser.add_argument('--gpus', type=str, default='2,3', help='gpu list')
139 | args = parser.parse_args()
140 |
141 | net, device, cfp_dataset, cfp_loader = loadModel(args.root, args.file_list, args.backbone_net, args.gpus, args.resume)
142 | getFeatureFromTorch(args.feature_save_path, net, device, cfp_dataset, cfp_loader)
143 | ACCs = evaluation_10_fold(args.feature_save_path)
144 | for i in range(len(ACCs)):
145 | print('{} {:.2f}'.format(i + 1, ACCs[i] * 100))
146 | print('--------')
147 | print('AVE {:.4f}'.format(np.mean(ACCs) * 100))
--------------------------------------------------------------------------------
/eval_deepglint_merge.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # encoding: utf-8
3 | '''
4 | @author: wujiyang
5 | @contact: wujiyang@hust.edu.cn
6 | @file: eval_deepglint_merge.py
7 | @time: 2019/3/21 11:09
8 | @desc: merge the features of the DeepGlint test data into one file. The original DeepGlint features are generated with the MegaFace protocol.
9 | '''
10 |
11 | """
12 | We use the same format as Megaface(http://megaface.cs.washington.edu)
13 | except that we merge all files into a single binary file.
14 |
15 | for example:
16 | 
17 | megaface: N separate (512, 1) matrices, one .feat file per image
18 | deepglint: a single (N, 512) matrix
19 |
20 | """
21 | import struct
22 | import numpy as np
23 | import sys, os
24 | import argparse
25 |
26 | cv_type_to_dtype = {
27 | 5: np.dtype('float32')
28 | }
29 |
30 | dtype_to_cv_type = {v: k for k, v in cv_type_to_dtype.items()}
31 |
32 |
33 | def write_mat(f, m):
34 | """Write mat m to file f"""
35 | if len(m.shape) == 1:
36 | rows = m.shape[0]
37 | cols = 1
38 | else:
39 | rows, cols = m.shape
40 | header = struct.pack('iiii', rows, cols, cols * 4, dtype_to_cv_type[m.dtype])
41 | f.write(header)
42 | f.write(m.data)
43 |
44 |
45 | def read_mat(f):
46 | """
47 | Reads an OpenCV mat from the given file opened in binary mode
48 | """
49 | rows, cols, stride, type_ = struct.unpack('iiii', f.read(4 * 4))
50 | mat = np.fromstring(f.read(rows * stride), dtype=cv_type_to_dtype[type_])
51 | return mat.reshape(rows, cols)
52 |
53 |
54 | def load_mat(filename):
55 | """
56 | Reads a OpenCV Mat from the given filename
57 | """
58 | return read_mat(open(filename, 'rb'))
59 |
60 |
61 | def save_mat(filename, m):
62 | """Saves mat m to the given filename"""
63 | return write_mat(open(filename, 'wb'), m)
64 |
65 |
66 |
67 | def main(args):
68 |
69 | deepglint_features = args.deepglint_features_path
70 | # merge all features into one file
71 | total_feature = []
72 | total_files = []
73 | for root, dirs, files in os.walk(deepglint_features):
74 | for file in files:
75 | filename = os.path.join(root, file)
76 | ext = os.path.splitext(filename)[1]
77 | ext = ext.lower()
78 | if ext == '.feat':  # ('.feat') is a plain string, so the original membership test was a substring check
79 | total_files.append(filename)
80 |
81 | assert len(total_files) == 1862120
82 | total_files.sort() # important
83 |
84 | for i in range(len(total_files)):
85 | filename = total_files[i]
86 | tmp_feature = load_mat(filename)
87 | # print(filename)
88 | # print(tmp_feature.shape)
89 | tmp_feature = tmp_feature.T
90 | total_feature.append(tmp_feature)
91 | print(i + 1, tmp_feature.shape)
92 | # write_mat(feature_path_out, feature_fusion)
93 |
94 | print('total feature number: ', len(total_feature))
95 | total_feature = np.array(total_feature).squeeze()
96 | print(total_feature.shape, total_feature.dtype, type(total_feature))
97 | save_mat('deepglint_test_feature.bin', total_feature)
98 |
99 |
100 | if __name__ == '__main__':
101 | parser = argparse.ArgumentParser()
102 | parser.add_argument("--deepglint_features_path", type=str, default="/home/wujiyang/deepglint/deepglint_feature_ir+ws/")
103 | args = parser.parse_args()
104 |
105 | main(args)
106 |
--------------------------------------------------------------------------------
/eval_lfw.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # encoding: utf-8
3 | '''
4 | @author: wujiyang
5 | @contact: wujiyang@hust.edu.cn
6 | @file: eval_lfw.py
7 | @time: 2018/12/22 9:47
8 | @desc:
9 | '''
10 |
11 | import numpy as np
12 | import scipy.io
13 | import os
14 | import json
15 | import torch.utils.data
16 | from backbone import mobilefacenet, resnet, arcfacenet, cbam
17 | from dataset.lfw import LFW
18 | import torchvision.transforms as transforms
19 | from torch.nn import DataParallel
20 | import argparse
21 |
22 | def getAccuracy(scores, flags, threshold):
23 | p = np.sum(scores[flags == 1] > threshold)
24 | n = np.sum(scores[flags == -1] < threshold)
25 | return 1.0 * (p + n) / len(scores)
26 |
27 | def getThreshold(scores, flags, thrNum):
28 | accuracys = np.zeros((2 * thrNum + 1, 1))
29 | thresholds = np.arange(-thrNum, thrNum + 1) * 1.0 / thrNum
30 | for i in range(2 * thrNum + 1):
31 | accuracys[i] = getAccuracy(scores, flags, thresholds[i])
32 | max_index = np.squeeze(accuracys == np.max(accuracys))
33 | bestThreshold = np.mean(thresholds[max_index])
34 | return bestThreshold
35 |
36 | def evaluation_10_fold(feature_path='./result/cur_epoch_result.mat'):
37 | ACCs = np.zeros(10)
38 | result = scipy.io.loadmat(feature_path)
39 | for i in range(10):
40 | fold = result['fold']
41 | flags = result['flag']
42 | featureLs = result['fl']
43 | featureRs = result['fr']
44 |
45 | valFold = fold != i
46 | testFold = fold == i
47 | flags = np.squeeze(flags)
48 |
49 | mu = np.mean(np.concatenate((featureLs[valFold[0], :], featureRs[valFold[0], :]), 0), 0)
50 | mu = np.expand_dims(mu, 0)
51 | featureLs = featureLs - mu
52 | featureRs = featureRs - mu
53 | featureLs = featureLs / np.expand_dims(np.sqrt(np.sum(np.power(featureLs, 2), 1)), 1)
54 | featureRs = featureRs / np.expand_dims(np.sqrt(np.sum(np.power(featureRs, 2), 1)), 1)
55 |
56 | scores = np.sum(np.multiply(featureLs, featureRs), 1)
57 | threshold = getThreshold(scores[valFold[0]], flags[valFold[0]], 10000)
58 | ACCs[i] = getAccuracy(scores[testFold[0]], flags[testFold[0]], threshold)
59 |
60 | return ACCs
61 |
62 | def loadModel(data_root, file_list, backbone_net, gpus='0', resume=None):
63 |
64 | if backbone_net == 'MobileFace':
65 | net = mobilefacenet.MobileFaceNet()
66 | elif backbone_net == 'CBAM_50':
67 | net = cbam.CBAMResNet(50, feature_dim=args.feature_dim, mode='ir')
68 | elif backbone_net == 'CBAM_50_SE':
69 | net = cbam.CBAMResNet(50, feature_dim=args.feature_dim, mode='ir_se')
70 | elif backbone_net == 'CBAM_100':
71 | net = cbam.CBAMResNet(100, feature_dim=args.feature_dim, mode='ir')
72 | elif backbone_net == 'CBAM_100_SE':
73 | net = cbam.CBAMResNet(100, feature_dim=args.feature_dim, mode='ir_se')
74 | else:
75 | print(backbone_net, ' is not available!')
76 |
77 | # gpu init
78 | multi_gpus = False
79 | if len(gpus.split(',')) > 1:
80 | multi_gpus = True
81 | os.environ['CUDA_VISIBLE_DEVICES'] = gpus
82 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
83 |
84 | net.load_state_dict(torch.load(resume)['net_state_dict'])
85 |
86 | if multi_gpus:
87 | net = DataParallel(net).to(device)
88 | else:
89 | net = net.to(device)
90 |
91 | transform = transforms.Compose([
92 | transforms.ToTensor(), # range [0, 255] -> [0.0,1.0]
93 | transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)) # range [0.0, 1.0] -> [-1.0,1.0]
94 | ])
95 | lfw_dataset = LFW(data_root, file_list, transform=transform)
96 | lfw_loader = torch.utils.data.DataLoader(lfw_dataset, batch_size=128,
97 | shuffle=False, num_workers=2, drop_last=False)
98 |
99 | return net.eval(), device, lfw_dataset, lfw_loader
100 |
101 | def getFeatureFromTorch(feature_save_dir, net, device, data_set, data_loader):
102 | featureLs = None
103 | featureRs = None
104 | count = 0
105 | for data in data_loader:
106 | for i in range(len(data)):
107 | data[i] = data[i].to(device)
108 | count += data[0].size(0)
109 | #print('extracing deep features from the face pair {}...'.format(count))
110 | with torch.no_grad():
111 | res = [net(d).data.cpu().numpy() for d in data]
112 | featureL = np.concatenate((res[0], res[1]), 1)
113 | featureR = np.concatenate((res[2], res[3]), 1)
114 | # print(featureL.shape, featureR.shape)
115 | if featureLs is None:
116 | featureLs = featureL
117 | else:
118 | featureLs = np.concatenate((featureLs, featureL), 0)
119 | if featureRs is None:
120 | featureRs = featureR
121 | else:
122 | featureRs = np.concatenate((featureRs, featureR), 0)
123 | # print(featureLs.shape, featureRs.shape)
124 |
125 | result = {'fl': featureLs, 'fr': featureRs, 'fold': data_set.folds, 'flag': data_set.flags}
126 | scipy.io.savemat(feature_save_dir, result)
127 |
128 | if __name__ == '__main__':
129 | parser = argparse.ArgumentParser(description='Testing')
130 | parser.add_argument('--root', type=str, default='/media/sda/lfw/lfw_align_112', help='The path of lfw data')
131 | parser.add_argument('--file_list', type=str, default='/media/sda/lfw/pairs.txt', help='The path of the lfw pairs.txt file')
132 | parser.add_argument('--backbone_net', type=str, default='CBAM_100_SE', help='MobileFace, CBAM_50, CBAM_50_SE, CBAM_100, CBAM_100_SE')
133 | parser.add_argument('--feature_dim', type=int, default=512, help='feature dimension')
134 | parser.add_argument('--resume', type=str, default='./model/SERES100_SERES100_IR_20190528_132635/Iter_342000_net.ckpt',
135 | help='The path of the saved model')
136 | parser.add_argument('--feature_save_path', type=str, default='./result/cur_epoch_lfw_result.mat',
137 | help='The path of the extract features save, must be .mat file')
138 | parser.add_argument('--gpus', type=str, default='1,3', help='gpu list')
139 | args = parser.parse_args()
140 |
141 | net, device, lfw_dataset, lfw_loader = loadModel(args.root, args.file_list, args.backbone_net, args.gpus, args.resume)
142 | getFeatureFromTorch(args.feature_save_path, net, device, lfw_dataset, lfw_loader)
143 | ACCs = evaluation_10_fold(args.feature_save_path)
144 | for i in range(len(ACCs)):
145 | print('{} {:.2f}'.format(i+1, ACCs[i] * 100))
146 | print('--------')
147 | print('AVE {:.4f}'.format(np.mean(ACCs) * 100))
148 |
--------------------------------------------------------------------------------
/eval_lfw_blufr.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # encoding: utf-8
3 | '''
4 | @author: wujiyang
5 | @contact: wujiyang@hust.edu.cn
6 | @file: eval_lfw_blufr.py
7 | @time: 2019/1/17 15:52
8 | @desc: test lfw accuracy on blufr protocol
9 | '''
10 | '''
11 | LFW BLUFR TEST PROTOCOL
12 |
13 | Official Website: http://www.cbsr.ia.ac.cn/users/scliao/projects/blufr/
14 |
15 | When I try to do this, I find that the blufr_lfw_config.mat file provided by above site is too old.
16 | Some image files listed in the mat have been removed in lfw pairs.txt
17 | So this work is suspended for now...
18 | '''
19 |
20 | import scipy.io as sio
21 | import argparse
22 |
23 | def readName(file='pairs.txt'):
24 | name_list = []
25 | f = open(file, 'r')
26 | lines = f.readlines()
27 |
28 | for line in lines[1:]:
29 | line_split = line.rstrip().split()
30 | if len(line_split) == 3:
31 | name_list.append(line_split[0])
32 | elif len(line_split) == 4:
33 | name_list.append(line_split[0])
34 | name_list.append(line_split[2])
35 | else:
36 | print('wrong file, please check again')
37 |
38 | return list(set(name_list))
39 |
40 |
41 | def main(args):
42 | blufr_info = sio.loadmat(args.lfw_blufr_file)
43 | #print(blufr_info)
44 | name_list = readName(args.lfw_pairs_file)
45 |
46 | image = blufr_info['imageList']
47 | missing_files = []
48 | for i in range(image.shape[0]):
49 | name = image[i][0][0]
50 | index = name.rfind('_')
51 | name = name[0:index]
52 | if name not in name_list:
53 | print(name)
54 | missing_files.append(name)
55 | print('lfw pairs.txt total persons: ', len(name_list))
56 | print('blufr_mat_missing persons: ', len(missing_files))
57 |
58 | '''
59 | Some of the missing file:
60 | Zdravko_Mucic
61 | Zelma_Novelo
62 | Zeng_Qinghong
63 | Zumrati_Juma
64 | lfw pairs.txt total persons: 4281
65 | blufr_mat_missing persons: 1549
66 |
67 | '''
68 |
69 | if __name__ == '__main__':
70 | parser = argparse.ArgumentParser(description='lfw blufr test')
71 | parser.add_argument('--lfw_blufr_file', type=str, default='./blufr_lfw_config.mat', help='path to the blufr_lfw_config.mat file')
72 | parser.add_argument('--lfw_pairs_file', type=str, default='./pairs.txt', help='path to the lfw pairs.txt file')
73 | parser.add_argument('--gpus', type=str, default='2,3', help='gpu list')
74 | args = parser.parse_args()
75 |
76 | main(args)
--------------------------------------------------------------------------------
/eval_megaface.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # encoding: utf-8
3 | '''
4 | @author: wujiyang
5 | @contact: wujiyang@hust.edu.cn
6 | @file: eval_megaface.py
7 | @time: 2018/12/24 16:28
8 | @desc: megaface feature extractor
9 | '''
10 | import numpy as np
11 | import struct
12 | import os
13 | import torch.utils.data
14 | from backbone import mobilefacenet, cbam  # self_attention removed: no such module exists under backbone/, and only these two are used below
15 | from dataset.megaface import MegaFace
16 | import torchvision.transforms as transforms
17 | import argparse
18 | from torch.nn import DataParallel
19 |
20 |
21 | cv_type_to_dtype = {5: np.dtype('float32'), 6: np.dtype('float64')}
22 | dtype_to_cv_type = {v: k for k, v in cv_type_to_dtype.items()}
23 |
24 | def write_mat(filename, m):
25 | """Write mat m to file f"""
26 | if len(m.shape) == 1:
27 | rows = m.shape[0]
28 | cols = 1
29 | else:
30 | rows, cols = m.shape
31 | header = struct.pack('iiii', rows, cols, cols * 4, dtype_to_cv_type[m.dtype])
32 |
33 | with open(filename, 'wb') as outfile:
34 | outfile.write(header)
35 | outfile.write(m.data)
36 |
37 |
38 | def read_mat(filename):
39 | """
40 | Reads an OpenCV mat from the given file opened in binary mode
41 | """
42 | with open(filename, 'rb') as fin:
43 | rows, cols, stride, type_ = struct.unpack('iiii', fin.read(4 * 4))
44 | mat = np.fromstring(fin.read(rows * stride), dtype=cv_type_to_dtype[type_])  # str() wrapper removed: it corrupts the raw bytes under Python 3
45 | return mat.reshape(rows, cols)
46 |
47 |
48 | def extract_feature(model_path, backbone_net, face_scrub_path, megaface_path, batch_size=32, gpus='0', do_norm=False):
49 |
50 | if backbone_net == 'MobileFace':
51 | net = mobilefacenet.MobileFaceNet()
52 | elif backbone_net == 'CBAM_50':
53 | net = cbam.CBAMResNet(50, feature_dim=args.feature_dim, mode='ir')
54 | elif backbone_net == 'CBAM_50_SE':
55 | net = cbam.CBAMResNet(50, feature_dim=args.feature_dim, mode='ir_se')
56 | elif backbone_net == 'CBAM_100':
57 | net = cbam.CBAMResNet(100, feature_dim=args.feature_dim, mode='ir')
58 | elif backbone_net == 'CBAM_100_SE':
59 | net = cbam.CBAMResNet(100, feature_dim=args.feature_dim, mode='ir_se')
60 | else:
61 | print(backbone_net, ' is not available!')
62 |
63 | multi_gpus = False
64 | if len(gpus.split(',')) > 1:
65 | multi_gpus = True
66 | os.environ['CUDA_VISIBLE_DEVICES'] = gpus
67 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
68 |
69 | net.load_state_dict(torch.load(model_path)['net_state_dict'])
70 | if multi_gpus:
71 | net = DataParallel(net).to(device)
72 | else:
73 | net = net.to(device)
74 | net.eval()
75 |
76 | transform = transforms.Compose([
77 | transforms.ToTensor(), # range [0, 255] -> [0.0,1.0]
78 | transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)) # range [0.0, 1.0] -> [-1.0,1.0]
79 | ])
80 | megaface_dataset = MegaFace(face_scrub_path, megaface_path, transform=transform)
81 | megaface_loader = torch.utils.data.DataLoader(megaface_dataset, batch_size=batch_size,
82 | shuffle=False, num_workers=12, drop_last=False)
83 |
84 | for data in megaface_loader:
85 | img, img_path= data[0].to(device), data[1]
86 | with torch.no_grad():
87 | output = net(img).data.cpu().numpy()
88 |
89 | if do_norm is False:
90 | for i in range(len(img_path)):
91 | abs_path = img_path[i] + '.feat'
92 | write_mat(abs_path, output[i])
93 | print('extract 1 batch...without feature normalization')
94 | else:
95 | for i in range(len(img_path)):
96 | abs_path = img_path[i] + '.feat'
97 | feat = output[i]
98 | feat = feat / np.sqrt((np.dot(feat, feat)))
99 | write_mat(abs_path, feat)
100 | print('extract 1 batch...with feature normalization')
101 | print('all images have been processed!')
102 |
103 |
104 | if __name__ == '__main__':
105 | parser = argparse.ArgumentParser(description='Testing')
106 | parser.add_argument('--model_path', type=str, default='./model/RES100_RES100_IR_20190423_100728/Iter_333000_net.ckpt', help='The path of trained model')
107 | parser.add_argument('--backbone_net', type=str, default='CBAM_100', help='MobileFace, CBAM_50, CBAM_50_SE, CBAM_100, CBAM_100_SE')
108 | parser.add_argument('--facescrub_dir', type=str, default='/media/sda/megaface_test_kit/facescrub_align_112/', help='facescrub data')
109 | parser.add_argument('--megaface_dir', type=str, default='/media/sda/megaface_test_kit/megaface_align_112/', help='megaface data')
110 | parser.add_argument('--batch_size', type=int, default=1024, help='batch size')
111 | parser.add_argument('--feature_dim', type=int, default=512, help='feature dimension')
112 | parser.add_argument('--gpus', type=str, default='0,1,2,3', help='gpu list')
113 | parser.add_argument("--do_norm", type=int, default=1, help="1 if normalize feature, 0 do nothing(Default case)")
114 | args = parser.parse_args()
115 |
116 | extract_feature(args.model_path, args.backbone_net, args.facescrub_dir, args.megaface_dir, args.batch_size, args.gpus, args.do_norm)
--------------------------------------------------------------------------------
/lossfunctions/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # encoding: utf-8
3 | '''
4 | @author: wujiyang
5 | @contact: wujiyang@hust.edu.cn
6 | @file: __init__.py
7 | @time: 2019/1/4 15:24
8 | @desc:
9 | '''
--------------------------------------------------------------------------------
/lossfunctions/agentcenterloss.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # encoding: utf-8
3 | '''
4 | @author: wujiyang
5 | @contact: wujiyang@hust.edu.cn
6 | @file: agentcenterloss.py
7 | @time: 2019/1/7 10:53
8 | @desc: a variant of center loss, which uses the class weight as the class center and normalizes both the weight and the feature,
9 | so that the cosine distance between weight and feature can be used as the supervision signal.
10 | It's similar to torch.nn.CosineEmbeddingLoss, where x_1 corresponds to weight_i and x_2 to feature_i.
11 | '''
12 |
13 | import torch
14 | import torch.nn as nn
15 | import torch.nn.functional as F
16 |
17 | class AgentCenterLoss(nn.Module):
18 |
19 | def __init__(self, num_classes, feat_dim, scale):
20 | super(AgentCenterLoss, self).__init__()
21 | self.num_classes = num_classes
22 | self.feat_dim = feat_dim
23 | self.scale = scale
24 |
25 | self.centers = nn.Parameter(torch.randn(self.num_classes, self.feat_dim))
26 |
27 | def forward(self, x, labels):
28 | '''
29 | Parameters:
30 | x: input tensor with shape (batch_size, feat_dim)
31 | labels: ground truth label with shape (batch_size)
32 | Return:
33 | loss of centers
34 | '''
35 | cos_dis = F.linear(F.normalize(x), F.normalize(self.centers)) * self.scale
36 |
37 | one_hot = torch.zeros_like(cos_dis)
38 | one_hot.scatter_(1, labels.view(-1, 1), 1.0)
39 |
40 | # loss = 1 - cosine(i)
41 | loss = one_hot * self.scale - (one_hot * cos_dis)
42 |
43 | return loss.mean()
--------------------------------------------------------------------------------
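A hedged usage sketch for AgentCenterLoss: it is meant as an auxiliary term next to a classification loss, taking the backbone embedding and the ground-truth labels. The batch size, class count, scale and the 0.01 weight below are illustrative and not taken from the training scripts:

import torch
from lossfunctions.agentcenterloss import AgentCenterLoss

criterion_cls = torch.nn.CrossEntropyLoss()
criterion_agent = AgentCenterLoss(num_classes=10575, feat_dim=512, scale=32.0)

features = torch.randn(8, 512, requires_grad=True)    # embeddings from the backbone
logits = torch.randn(8, 10575, requires_grad=True)    # classifier outputs
labels = torch.randint(0, 10575, (8,))

loss = criterion_cls(logits, labels) + 0.01 * criterion_agent(features, labels)
loss.backward()    # gradients reach the features, the logits and the learnable class centers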
/lossfunctions/centerloss.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # encoding: utf-8
3 | '''
4 | @author: wujiyang
5 | @contact: wujiyang@hust.edu.cn
6 | @file: centerloss.py
7 | @time: 2019/1/4 15:24
8 | @desc: the implementation of center loss
9 | '''
10 |
11 | import torch
12 | import torch.nn as nn
13 |
14 | class CenterLoss(nn.Module):
15 |
16 | def __init__(self, num_classes, feat_dim):
17 | super(CenterLoss, self).__init__()
18 | self.num_classes = num_classes
19 | self.feat_dim = feat_dim
20 |
21 | self.centers = nn.Parameter(torch.randn(self.num_classes, self.feat_dim))
22 |
23 | def forward(self, x, labels):
24 | '''
25 | Parameters:
26 | x: input tensor with shape (batch_size, feat_dim)
27 | labels: ground truth label with shape (batch_size)
28 | Return:
29 | loss of centers
30 | '''
31 | # compute the distance of (x-center)^2
32 | batch_size = x.size(0)
33 | distmat = torch.pow(x, 2).sum(dim=1, keepdim=True).expand(batch_size, self.num_classes) + \
34 | torch.pow(self.centers, 2).sum(dim=1, keepdim=True).expand(self.num_classes, batch_size).t()
35 | distmat.addmm_(1, -2, x, self.centers.t())
36 |
37 | # get one_hot matrix
38 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
39 | classes = torch.arange(self.num_classes).long().to(device)
40 | labels = labels.unsqueeze(1).expand(batch_size, self.num_classes)
41 | mask = labels.eq(classes.expand(batch_size, self.num_classes))
42 |
43 | dist = []
44 | for i in range(batch_size):
45 | value = distmat[i][mask[i]]
46 | value = value.clamp(min=1e-12, max=1e+12) # for numerical stability
47 | dist.append(value)
48 | dist = torch.cat(dist)
49 | loss = dist.mean()
50 |
51 | return loss
--------------------------------------------------------------------------------
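The distance matrix above uses the expansion ||x_i - c_j||^2 = ||x_i||^2 + ||c_j||^2 - 2 x_i . c_j. A small sketch checking the resulting loss against a direct computation (sizes are illustrative):

import torch
from lossfunctions.centerloss import CenterLoss

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
criterion = CenterLoss(num_classes=10, feat_dim=128).to(device)
x = torch.randn(4, 128, device=device)
labels = torch.tensor([1, 2, 2, 9], device=device)

loss = criterion(x, labels)
# same quantity computed directly: mean squared distance of each sample to its own class center
direct = ((x - criterion.centers[labels]) ** 2).sum(dim=1).mean()
print(torch.allclose(loss, direct, atol=1e-4))  # expected: True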
/margin/ArcMarginProduct.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # encoding: utf-8
3 | '''
4 | @author: wujiyang
5 | @contact: wujiyang@hust.edu.cn
6 | @file: ArcMarginProduct.py
7 | @time: 2018/12/25 9:13
8 | @desc: additive angular margin for arcface/insightface
9 | '''
10 |
11 | import math
12 | import torch
13 | from torch import nn
14 | from torch.nn import Parameter
15 | import torch.nn.functional as F
16 |
17 | class ArcMarginProduct(nn.Module):
18 | def __init__(self, in_feature=128, out_feature=10575, s=32.0, m=0.50, easy_margin=False):
19 | super(ArcMarginProduct, self).__init__()
20 | self.in_feature = in_feature
21 | self.out_feature = out_feature
22 | self.s = s
23 | self.m = m
24 | self.weight = Parameter(torch.Tensor(out_feature, in_feature))
25 | nn.init.xavier_uniform_(self.weight)
26 |
27 | self.easy_margin = easy_margin
28 | self.cos_m = math.cos(m)
29 | self.sin_m = math.sin(m)
30 |
31 | # keep cos(theta+m) monotonically decreasing for theta in [0°, 180°]
32 | self.th = math.cos(math.pi - m)
33 | self.mm = math.sin(math.pi - m) * m
34 |
35 | def forward(self, x, label):
36 | # cos(theta)
37 | cosine = F.linear(F.normalize(x), F.normalize(self.weight))
38 | # cos(theta + m)
39 | sine = torch.sqrt(1.0 - torch.pow(cosine, 2))
40 | phi = cosine * self.cos_m - sine * self.sin_m
41 |
42 | if self.easy_margin:
43 | phi = torch.where(cosine > 0, phi, cosine)
44 | else:
45 | phi = torch.where((cosine - self.th) > 0, phi, cosine - self.mm)
46 |
47 | #one_hot = torch.zeros(cosine.size(), device='cuda' if torch.cuda.is_available() else 'cpu')
48 | one_hot = torch.zeros_like(cosine)
49 | one_hot.scatter_(1, label.view(-1, 1), 1)
50 | output = (one_hot * phi) + ((1.0 - one_hot) * cosine)
51 | output = output * self.s
52 |
53 | return output
54 |
55 |
56 | if __name__ == '__main__':
57 | pass
--------------------------------------------------------------------------------
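The forward pass expands cos(theta + m) = cos(theta)cos(m) - sin(theta)sin(m) and substitutes it only at the ground-truth logit before scaling by s; the result is an ordinary logit matrix for CrossEntropyLoss. A minimal sketch with illustrative sizes:

import torch
from margin.ArcMarginProduct import ArcMarginProduct

margin = ArcMarginProduct(in_feature=128, out_feature=100, s=32.0, m=0.5)
embeddings = torch.randn(4, 128)            # backbone output
labels = torch.tensor([5, 0, 42, 7])
logits = margin(embeddings, labels)         # shape (4, 100)
loss = torch.nn.CrossEntropyLoss()(logits, labels)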
/margin/CosineMarginProduct.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # encoding: utf-8
3 | '''
4 | @author: wujiyang
5 | @contact: wujiyang@hust.edu.cn
6 | @file: CosineMarginProduct.py
7 | @time: 2018/12/25 9:13
8 | @desc: additive cosine margin for cosface
9 | '''
10 |
11 | import torch
12 | import torch.nn as nn
13 | import torch.nn.functional as F
14 | from torch.nn import Parameter
15 |
16 |
17 | class CosineMarginProduct(nn.Module):
18 | def __init__(self, in_feature=128, out_feature=10575, s=30.0, m=0.35):
19 | super(CosineMarginProduct, self).__init__()
20 | self.in_feature = in_feature
21 | self.out_feature = out_feature
22 | self.s = s
23 | self.m = m
24 | self.weight = Parameter(torch.Tensor(out_feature, in_feature))
25 | nn.init.xavier_uniform_(self.weight)
26 |
27 |
28 | def forward(self, input, label):
29 | cosine = F.linear(F.normalize(input), F.normalize(self.weight))
30 | # one_hot = torch.zeros(cosine.size(), device='cuda' if torch.cuda.is_available() else 'cpu')
31 | one_hot = torch.zeros_like(cosine)
32 | one_hot.scatter_(1, label.view(-1, 1), 1.0)
33 |
34 | output = self.s * (cosine - one_hot * self.m)
35 | return output
36 |
37 |
38 | if __name__ == '__main__':
39 | pass
--------------------------------------------------------------------------------
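CosFace applies the margin in cosine space rather than angle space: the target logit becomes s * (cos(theta) - m) while all other logits stay s * cos(theta). A hand-written sketch equivalent to the forward pass above (weights drawn at random, sizes illustrative):

import torch
import torch.nn.functional as F

s, m = 30.0, 0.35
weight = torch.randn(100, 128)              # stand-in for the learned class weights
x = torch.randn(4, 128)
labels = torch.tensor([3, 3, 11, 99])

cosine = F.linear(F.normalize(x), F.normalize(weight))
one_hot = torch.zeros_like(cosine).scatter_(1, labels.view(-1, 1), 1.0)
logits = s * (cosine - one_hot * m)         # matches CosineMarginProduct.forward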
/margin/InnerProduct.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # encoding: utf-8
3 | '''
4 | @author: wujiyang
5 | @contact: wujiyang@hust.edu.cn
6 | @file: InnerProduct.py
7 | @time: 2019/1/4 16:54
8 | @desc: just normal inner product as fully connected layer do.
9 | '''
10 | import torch
11 | import torch.nn as nn
12 | import torch.nn.functional as F
13 | from torch.nn import Parameter
14 |
15 | class InnerProduct(nn.Module):
16 | def __init__(self, in_feature=128, out_feature=10575):
17 | super(InnerProduct, self).__init__()
18 | self.in_feature = in_feature
19 | self.out_feature = out_feature
20 |
21 | self.weight = Parameter(torch.Tensor(out_feature, in_feature))
22 | nn.init.xavier_uniform_(self.weight)
23 |
24 |
25 | def forward(self, input, label):
26 | # label not used
27 | output = F.linear(input, self.weight)
28 | return output
29 |
30 |
31 | if __name__ == '__main__':
32 | pass
--------------------------------------------------------------------------------
/margin/MultiMarginProduct.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # encoding: utf-8
3 | '''
4 | @author: wujiyang
5 | @contact: wujiyang@hust.edu.cn
6 | @file: MultiMarginProduct.py
7 | @time: 2019/3/30 10:09
8 | @desc: Combination of additive angular margin and additive cosine margin
9 | '''
10 |
11 | import math
12 | import torch
13 | from torch import nn
14 | from torch.nn import Parameter
15 | import torch.nn.functional as F
16 |
17 | class MultiMarginProduct(nn.Module):
18 | def __init__(self, in_feature=128, out_feature=10575, s=32.0, m1=0.20, m2=0.35, easy_margin=False):
19 | super(MultiMarginProduct, self).__init__()
20 | self.in_feature = in_feature
21 | self.out_feature = out_feature
22 | self.s = s
23 | self.m1 = m1
24 | self.m2 = m2
25 | self.weight = Parameter(torch.Tensor(out_feature, in_feature))
26 | nn.init.xavier_uniform_(self.weight)
27 |
28 | self.easy_margin = easy_margin
29 | self.cos_m1 = math.cos(m1)
30 | self.sin_m1 = math.sin(m1)
31 |
32 | # keep cos(theta+m1) monotonically decreasing for theta in [0°, 180°]
33 | self.th = math.cos(math.pi - m1)
34 | self.mm = math.sin(math.pi - m1) * m1
35 |
36 | def forward(self, x, label):
37 | # cos(theta)
38 | cosine = F.linear(F.normalize(x), F.normalize(self.weight))
39 | # cos(theta + m1)
40 | sine = torch.sqrt(1.0 - torch.pow(cosine, 2))
41 | phi = cosine * self.cos_m1 - sine * self.sin_m1
42 |
43 | if self.easy_margin:
44 | phi = torch.where(cosine > 0, phi, cosine)
45 | else:
46 | phi = torch.where((cosine - self.th) > 0, phi, cosine - self.mm)
47 |
48 |
49 | one_hot = torch.zeros_like(cosine)
50 | one_hot.scatter_(1, label.view(-1, 1), 1)
51 | output = (one_hot * phi) + ((1.0 - one_hot) * cosine) # additive angular margin
52 | output = output - one_hot * self.m2 # additive cosine margin
53 | output = output * self.s
54 |
55 | return output
56 |
57 |
58 | if __name__ == '__main__':
59 | pass
--------------------------------------------------------------------------------
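MultiMargin stacks both margins on the target logit: first the angular margin cos(theta + m1), then the additive cosine margin m2, before scaling by s. A sketch of how it plugs into a training step (sizes illustrative):

import torch
from margin.MultiMarginProduct import MultiMarginProduct

margin = MultiMarginProduct(in_feature=128, out_feature=100, s=32.0, m1=0.20, m2=0.35)
criterion = torch.nn.CrossEntropyLoss()

features = torch.randn(8, 128)
labels = torch.randint(0, 100, (8,))
loss = criterion(margin(features, labels), labels)
loss.backward()                             # gradients flow into margin.weight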
/margin/SphereMarginProduct.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # encoding: utf-8
3 | '''
4 | @author: wujiyang
5 | @contact: wujiyang@hust.edu.cn
6 | @file: SphereMarginProduct.py
7 | @time: 2018/12/25 9:19
8 | @desc: multiplicative angular margin for sphereface
9 | '''
10 |
11 | import torch
12 | import torch.nn as nn
13 | import torch.nn.functional as F
14 | from torch.nn import Parameter
15 | import math
16 |
17 | class SphereMarginProduct(nn.Module):
18 | def __init__(self, in_feature, out_feature, m=4, base=1000.0, gamma=0.0001, power=2, lambda_min=5.0, iter=0):
19 | super(SphereMarginProduct, self).__init__()
20 | assert m in [1, 2, 3, 4], 'margin should be 1, 2, 3 or 4'
21 | self.in_feature, self.out_feature = in_feature, out_feature
22 | self.m = m
23 | self.base = base
24 | self.gamma = gamma
25 | self.power = power
26 | self.lambda_min = lambda_min
27 | self.iter = iter  # start from the given iteration count
28 | self.weight = Parameter(torch.Tensor(out_feature, in_feature))
29 | nn.init.xavier_uniform_(self.weight)
30 |
31 | # duplication formula
32 | self.margin_formula = [
33 | lambda x : x ** 0,
34 | lambda x : x ** 1,
35 | lambda x : 2 * x ** 2 - 1,
36 | lambda x : 4 * x ** 3 - 3 * x,
37 | lambda x : 8 * x ** 4 - 8 * x ** 2 + 1,
38 | lambda x : 16 * x ** 5 - 20 * x ** 3 + 5 * x
39 | ]
40 |
41 | def forward(self, input, label):
42 | self.iter += 1
43 | self.cur_lambda = max(self.lambda_min, self.base * (1 + self.gamma * self.iter) ** (-1 * self.power))
44 |
45 | cos_theta = F.linear(F.normalize(input), F.normalize(self.weight))
46 | cos_theta = cos_theta.clamp(-1, 1)
47 |
48 | cos_m_theta = self.margin_formula[self.m](cos_theta)
49 | theta = cos_theta.data.acos()
50 | k = ((self.m * theta) / math.pi).floor()
51 | phi_theta = ((-1.0) ** k) * cos_m_theta - 2 * k
52 | phi_theta_ = (self.cur_lambda * cos_theta + phi_theta) / (1 + self.cur_lambda)
53 | norm_of_feature = torch.norm(input, 2, 1)
54 |
55 | one_hot = torch.zeros_like(cos_theta)
56 | one_hot.scatter_(1, label.view(-1, 1), 1)
57 |
58 | output = one_hot * phi_theta_ + (1 - one_hot) * cos_theta
59 | output *= norm_of_feature.view(-1, 1)
60 |
61 | return output
62 |
63 |
64 | if __name__ == '__main__':
65 | pass
--------------------------------------------------------------------------------
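The margin_formula table lists the multiple-angle identities cos(m*theta) as polynomials in cos(theta) (Chebyshev polynomials of the first kind), and the annealed lambda blends phi(theta) with plain cos(theta) early in training. A quick numeric check of the m = 4 entry (values are illustrative):

import torch

theta = torch.tensor([0.3, 1.1, 2.0])
c = torch.cos(theta)
cos_4theta = 8 * c ** 4 - 8 * c ** 2 + 1                            # margin_formula[4]
print(torch.allclose(cos_4theta, torch.cos(4 * theta), atol=1e-6))  # expected: True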
/margin/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # encoding: utf-8
3 | '''
4 | @author: wujiyang
5 | @contact: wujiyang@hust.edu.cn
6 | @file: __init__.py
7 | @time: 2018/12/25 9:12
8 | @desc:
9 | '''
--------------------------------------------------------------------------------
/model/MSCeleb_MOBILEFACE_20181228_170458/log.log:
--------------------------------------------------------------------------------
1 | 20181228-17:05:06 Train Epoch: 1/30 ...
2 | 20181228-17:35:50 Saving checkpoint: 5000
3 | 20181228-17:36:10 LFW Ave Accuracy: 96.4333
4 | 20181228-17:36:30 AgeDB-30 Ave Accuracy: 82.7000
5 | 20181228-17:36:52 CFP-FP Ave Accuracy: 78.9286
6 | 20181228-17:36:52 Current Best Accuracy: LFW: 96.4333 in iters: 5000, AgeDB-30: 82.7000 in iters: 5000 and CFP-FP: 78.9286 in iters: 5000
7 | 20181228-18:06:33 Train Epoch: 2/30 ...
8 | 20181228-18:07:49 Saving checkpoint: 10000
9 | 20181228-18:08:09 LFW Ave Accuracy: 98.2000
10 | 20181228-18:08:29 AgeDB-30 Ave Accuracy: 89.2667
11 | 20181228-18:08:52 CFP-FP Ave Accuracy: 83.9714
12 | 20181228-18:08:52 Current Best Accuracy: LFW: 98.2000 in iters: 10000, AgeDB-30: 89.2667 in iters: 10000 and CFP-FP: 83.9714 in iters: 10000
13 | 20181228-18:39:39 Saving checkpoint: 15000
14 | 20181228-18:39:59 LFW Ave Accuracy: 98.6167
15 | 20181228-18:40:18 AgeDB-30 Ave Accuracy: 90.2500
16 | 20181228-18:40:39 CFP-FP Ave Accuracy: 85.6143
17 | 20181228-18:40:39 Current Best Accuracy: LFW: 98.6167 in iters: 15000, AgeDB-30: 90.2500 in iters: 15000 and CFP-FP: 85.6143 in iters: 15000
18 | 20181228-19:09:10 Train Epoch: 3/30 ...
19 | 20181228-19:11:37 Saving checkpoint: 20000
20 | 20181228-19:11:57 LFW Ave Accuracy: 98.5833
21 | 20181228-19:12:16 AgeDB-30 Ave Accuracy: 91.2667
22 | 20181228-19:12:37 CFP-FP Ave Accuracy: 86.1429
23 | 20181228-19:12:37 Current Best Accuracy: LFW: 98.6167 in iters: 15000, AgeDB-30: 91.2667 in iters: 20000 and CFP-FP: 86.1429 in iters: 20000
24 | 20181228-19:43:26 Saving checkpoint: 25000
25 | 20181228-19:43:46 LFW Ave Accuracy: 98.8667
26 | 20181228-19:44:06 AgeDB-30 Ave Accuracy: 91.6333
27 | 20181228-19:44:27 CFP-FP Ave Accuracy: 85.9714
28 | 20181228-19:44:27 Current Best Accuracy: LFW: 98.8667 in iters: 25000, AgeDB-30: 91.6333 in iters: 25000 and CFP-FP: 86.1429 in iters: 20000
29 | 20181228-20:11:45 Train Epoch: 4/30 ...
30 | 20181228-20:15:23 Saving checkpoint: 30000
31 | 20181228-20:15:43 LFW Ave Accuracy: 98.7167
32 | 20181228-20:16:02 AgeDB-30 Ave Accuracy: 91.4500
33 | 20181228-20:16:25 CFP-FP Ave Accuracy: 85.4714
34 | 20181228-20:16:25 Current Best Accuracy: LFW: 98.8667 in iters: 25000, AgeDB-30: 91.6333 in iters: 25000 and CFP-FP: 86.1429 in iters: 20000
35 | 20181228-20:47:10 Saving checkpoint: 35000
36 | 20181228-20:47:30 LFW Ave Accuracy: 98.9333
37 | 20181228-20:47:49 AgeDB-30 Ave Accuracy: 92.4167
38 | 20181228-20:48:11 CFP-FP Ave Accuracy: 86.4714
39 | 20181228-20:48:11 Current Best Accuracy: LFW: 98.9333 in iters: 35000, AgeDB-30: 92.4167 in iters: 35000 and CFP-FP: 86.4714 in iters: 35000
40 | 20181228-21:14:18 Train Epoch: 5/30 ...
41 | 20181228-21:19:06 Saving checkpoint: 40000
42 | 20181228-21:19:25 LFW Ave Accuracy: 98.8667
43 | 20181228-21:19:44 AgeDB-30 Ave Accuracy: 92.2167
44 | 20181228-21:20:07 CFP-FP Ave Accuracy: 87.1429
45 | 20181228-21:20:07 Current Best Accuracy: LFW: 98.9333 in iters: 35000, AgeDB-30: 92.4167 in iters: 35000 and CFP-FP: 87.1429 in iters: 40000
46 | 20181228-21:50:54 Saving checkpoint: 45000
47 | 20181228-21:51:14 LFW Ave Accuracy: 98.8333
48 | 20181228-21:51:33 AgeDB-30 Ave Accuracy: 91.9167
49 | 20181228-21:51:54 CFP-FP Ave Accuracy: 86.6714
50 | 20181228-21:51:54 Current Best Accuracy: LFW: 98.9333 in iters: 35000, AgeDB-30: 92.4167 in iters: 35000 and CFP-FP: 87.1429 in iters: 40000
51 | 20181228-22:16:47 Train Epoch: 6/30 ...
52 | 20181228-22:22:46 Saving checkpoint: 50000
53 | 20181228-22:23:06 LFW Ave Accuracy: 99.1500
54 | 20181228-22:23:25 AgeDB-30 Ave Accuracy: 92.7333
55 | 20181228-22:23:48 CFP-FP Ave Accuracy: 87.3714
56 | 20181228-22:23:48 Current Best Accuracy: LFW: 99.1500 in iters: 50000, AgeDB-30: 92.7333 in iters: 50000 and CFP-FP: 87.3714 in iters: 50000
57 | 20181228-22:54:35 Saving checkpoint: 55000
58 | 20181228-22:54:56 LFW Ave Accuracy: 98.8833
59 | 20181228-22:55:16 AgeDB-30 Ave Accuracy: 92.3333
60 | 20181228-22:55:37 CFP-FP Ave Accuracy: 86.7857
61 | 20181228-22:55:37 Current Best Accuracy: LFW: 99.1500 in iters: 50000, AgeDB-30: 92.7333 in iters: 50000 and CFP-FP: 87.3714 in iters: 50000
62 | 20181228-23:19:21 Train Epoch: 7/30 ...
63 | 20181228-23:26:29 Saving checkpoint: 60000
64 | 20181228-23:26:49 LFW Ave Accuracy: 98.9000
65 | 20181228-23:27:08 AgeDB-30 Ave Accuracy: 92.9333
66 | 20181228-23:27:30 CFP-FP Ave Accuracy: 87.4857
67 | 20181228-23:27:30 Current Best Accuracy: LFW: 99.1500 in iters: 50000, AgeDB-30: 92.9333 in iters: 60000 and CFP-FP: 87.4857 in iters: 60000
68 | 20181228-23:58:25 Saving checkpoint: 65000
69 | 20181228-23:58:45 LFW Ave Accuracy: 99.0500
70 | 20181228-23:59:04 AgeDB-30 Ave Accuracy: 92.4000
71 | 20181228-23:59:26 CFP-FP Ave Accuracy: 86.9857
72 | 20181228-23:59:26 Current Best Accuracy: LFW: 99.1500 in iters: 50000, AgeDB-30: 92.9333 in iters: 60000 and CFP-FP: 87.4857 in iters: 60000
73 | 20181229-00:22:10 Train Epoch: 8/30 ...
74 | 20181229-00:30:28 Saving checkpoint: 70000
75 | 20181229-00:30:47 LFW Ave Accuracy: 98.9167
76 | 20181229-00:31:06 AgeDB-30 Ave Accuracy: 93.1000
77 | 20181229-00:31:28 CFP-FP Ave Accuracy: 87.3857
78 | 20181229-00:31:28 Current Best Accuracy: LFW: 99.1500 in iters: 50000, AgeDB-30: 93.1000 in iters: 70000 and CFP-FP: 87.4857 in iters: 60000
79 | 20181229-01:02:13 Saving checkpoint: 75000
80 | 20181229-01:02:33 LFW Ave Accuracy: 98.9000
81 | 20181229-01:02:52 AgeDB-30 Ave Accuracy: 92.3167
82 | 20181229-01:03:14 CFP-FP Ave Accuracy: 87.5857
83 | 20181229-01:03:14 Current Best Accuracy: LFW: 99.1500 in iters: 50000, AgeDB-30: 93.1000 in iters: 70000 and CFP-FP: 87.5857 in iters: 75000
84 | 20181229-01:24:34 Train Epoch: 9/30 ...
85 | 20181229-01:34:02 Saving checkpoint: 80000
86 | 20181229-01:34:22 LFW Ave Accuracy: 98.9667
87 | 20181229-01:34:41 AgeDB-30 Ave Accuracy: 93.3833
88 | 20181229-01:35:02 CFP-FP Ave Accuracy: 87.8714
89 | 20181229-01:35:02 Current Best Accuracy: LFW: 99.1500 in iters: 50000, AgeDB-30: 93.3833 in iters: 80000 and CFP-FP: 87.8714 in iters: 80000
90 | 20181229-02:05:40 Saving checkpoint: 85000
91 | 20181229-02:06:01 LFW Ave Accuracy: 99.0667
92 | 20181229-02:06:20 AgeDB-30 Ave Accuracy: 92.8000
93 | 20181229-02:06:41 CFP-FP Ave Accuracy: 87.2714
94 | 20181229-02:06:41 Current Best Accuracy: LFW: 99.1500 in iters: 50000, AgeDB-30: 93.3833 in iters: 80000 and CFP-FP: 87.8714 in iters: 80000
95 | 20181229-02:26:49 Train Epoch: 10/30 ...
96 | 20181229-02:37:26 Saving checkpoint: 90000
97 | 20181229-02:37:46 LFW Ave Accuracy: 98.7167
98 | 20181229-02:38:05 AgeDB-30 Ave Accuracy: 92.3500
99 | 20181229-02:38:26 CFP-FP Ave Accuracy: 88.0143
100 | 20181229-02:38:26 Current Best Accuracy: LFW: 99.1500 in iters: 50000, AgeDB-30: 93.3833 in iters: 80000 and CFP-FP: 88.0143 in iters: 90000
101 | 20181229-03:09:07 Saving checkpoint: 95000
102 | 20181229-03:09:27 LFW Ave Accuracy: 98.8667
103 | 20181229-03:09:46 AgeDB-30 Ave Accuracy: 93.2333
104 | 20181229-03:10:08 CFP-FP Ave Accuracy: 87.6286
105 | 20181229-03:10:08 Current Best Accuracy: LFW: 99.1500 in iters: 50000, AgeDB-30: 93.3833 in iters: 80000 and CFP-FP: 88.0143 in iters: 90000
106 | 20181229-03:29:07 Train Epoch: 11/30 ...
107 | 20181229-03:40:58 Saving checkpoint: 100000
108 | 20181229-03:41:17 LFW Ave Accuracy: 99.4333
109 | 20181229-03:41:36 AgeDB-30 Ave Accuracy: 95.0500
110 | 20181229-03:41:57 CFP-FP Ave Accuracy: 90.5000
111 | 20181229-03:41:57 Current Best Accuracy: LFW: 99.4333 in iters: 100000, AgeDB-30: 95.0500 in iters: 100000 and CFP-FP: 90.5000 in iters: 100000
112 | 20181229-04:12:37 Saving checkpoint: 105000
113 | 20181229-04:12:58 LFW Ave Accuracy: 99.3500
114 | 20181229-04:13:17 AgeDB-30 Ave Accuracy: 95.3500
115 | 20181229-04:13:38 CFP-FP Ave Accuracy: 91.1429
116 | 20181229-04:13:38 Current Best Accuracy: LFW: 99.4333 in iters: 100000, AgeDB-30: 95.3500 in iters: 105000 and CFP-FP: 91.1429 in iters: 105000
117 | 20181229-04:31:24 Train Epoch: 12/30 ...
118 | 20181229-04:44:22 Saving checkpoint: 110000
119 | 20181229-04:44:42 LFW Ave Accuracy: 99.3500
120 | 20181229-04:45:01 AgeDB-30 Ave Accuracy: 95.2000
121 | 20181229-04:45:23 CFP-FP Ave Accuracy: 90.8286
122 | 20181229-04:45:23 Current Best Accuracy: LFW: 99.4333 in iters: 100000, AgeDB-30: 95.3500 in iters: 105000 and CFP-FP: 91.1429 in iters: 105000
123 | 20181229-05:15:59 Saving checkpoint: 115000
124 | 20181229-05:16:19 LFW Ave Accuracy: 99.3167
125 | 20181229-05:16:38 AgeDB-30 Ave Accuracy: 95.4167
126 | 20181229-05:17:00 CFP-FP Ave Accuracy: 90.5286
127 | 20181229-05:17:00 Current Best Accuracy: LFW: 99.4333 in iters: 100000, AgeDB-30: 95.4167 in iters: 115000 and CFP-FP: 91.1429 in iters: 105000
128 | 20181229-05:33:34 Train Epoch: 13/30 ...
129 | 20181229-05:47:47 Saving checkpoint: 120000
130 | 20181229-05:48:07 LFW Ave Accuracy: 99.3833
131 | 20181229-05:48:27 AgeDB-30 Ave Accuracy: 95.6500
132 | 20181229-05:48:50 CFP-FP Ave Accuracy: 90.4714
133 | 20181229-05:48:50 Current Best Accuracy: LFW: 99.4333 in iters: 100000, AgeDB-30: 95.6500 in iters: 120000 and CFP-FP: 91.1429 in iters: 105000
134 | 20181229-06:19:36 Saving checkpoint: 125000
135 | 20181229-06:19:56 LFW Ave Accuracy: 99.3333
136 | 20181229-06:20:16 AgeDB-30 Ave Accuracy: 95.3333
137 | 20181229-06:20:38 CFP-FP Ave Accuracy: 91.3143
138 | 20181229-06:20:38 Current Best Accuracy: LFW: 99.4333 in iters: 100000, AgeDB-30: 95.6500 in iters: 120000 and CFP-FP: 91.3143 in iters: 125000
139 | 20181229-06:36:08 Train Epoch: 14/30 ...
140 | 20181229-06:51:23 Saving checkpoint: 130000
141 | 20181229-06:51:44 LFW Ave Accuracy: 99.2667
142 | 20181229-06:52:03 AgeDB-30 Ave Accuracy: 95.2500
143 | 20181229-06:52:24 CFP-FP Ave Accuracy: 90.7571
144 | 20181229-06:52:24 Current Best Accuracy: LFW: 99.4333 in iters: 100000, AgeDB-30: 95.6500 in iters: 120000 and CFP-FP: 91.3143 in iters: 125000
145 | 20181229-07:22:58 Saving checkpoint: 135000
146 | 20181229-07:23:18 LFW Ave Accuracy: 99.4333
147 | 20181229-07:23:37 AgeDB-30 Ave Accuracy: 95.3667
148 | 20181229-07:23:59 CFP-FP Ave Accuracy: 90.9000
149 | 20181229-07:23:59 Current Best Accuracy: LFW: 99.4333 in iters: 100000, AgeDB-30: 95.6500 in iters: 120000 and CFP-FP: 91.3143 in iters: 125000
150 | 20181229-07:38:14 Train Epoch: 15/30 ...
151 | 20181229-07:54:39 Saving checkpoint: 140000
152 | 20181229-07:54:59 LFW Ave Accuracy: 99.4167
153 | 20181229-07:55:18 AgeDB-30 Ave Accuracy: 95.5167
154 | 20181229-07:55:40 CFP-FP Ave Accuracy: 91.2286
155 | 20181229-07:55:40 Current Best Accuracy: LFW: 99.4333 in iters: 100000, AgeDB-30: 95.6500 in iters: 120000 and CFP-FP: 91.3143 in iters: 125000
156 | 20181229-08:26:12 Saving checkpoint: 145000
157 | 20181229-08:26:32 LFW Ave Accuracy: 99.3667
158 | 20181229-08:26:51 AgeDB-30 Ave Accuracy: 95.5000
159 | 20181229-08:27:13 CFP-FP Ave Accuracy: 91.3571
160 | 20181229-08:27:13 Current Best Accuracy: LFW: 99.4333 in iters: 100000, AgeDB-30: 95.6500 in iters: 120000 and CFP-FP: 91.3571 in iters: 145000
161 | 20181229-08:40:15 Train Epoch: 16/30 ...
162 | 20181229-08:57:48 Saving checkpoint: 150000
163 | 20181229-08:58:07 LFW Ave Accuracy: 99.3167
164 | 20181229-08:58:27 AgeDB-30 Ave Accuracy: 95.5500
165 | 20181229-08:58:48 CFP-FP Ave Accuracy: 91.1857
166 | 20181229-08:58:48 Current Best Accuracy: LFW: 99.4333 in iters: 100000, AgeDB-30: 95.6500 in iters: 120000 and CFP-FP: 91.3571 in iters: 145000
167 | 20181229-09:29:19 Saving checkpoint: 155000
168 | 20181229-09:29:39 LFW Ave Accuracy: 99.3500
169 | 20181229-09:29:59 AgeDB-30 Ave Accuracy: 95.5000
170 | 20181229-09:30:22 CFP-FP Ave Accuracy: 91.3000
171 | 20181229-09:30:22 Current Best Accuracy: LFW: 99.4333 in iters: 100000, AgeDB-30: 95.6500 in iters: 120000 and CFP-FP: 91.3571 in iters: 145000
172 | 20181229-09:42:12 Train Epoch: 17/30 ...
173 | 20181229-10:00:53 Saving checkpoint: 160000
174 | 20181229-10:01:13 LFW Ave Accuracy: 99.5167
175 | 20181229-10:01:35 AgeDB-30 Ave Accuracy: 95.3333
176 | 20181229-10:01:56 CFP-FP Ave Accuracy: 91.6714
177 | 20181229-10:01:56 Current Best Accuracy: LFW: 99.5167 in iters: 160000, AgeDB-30: 95.6500 in iters: 120000 and CFP-FP: 91.6714 in iters: 160000
178 | 20181229-10:32:23 Saving checkpoint: 165000
179 | 20181229-10:32:44 LFW Ave Accuracy: 99.4000
180 | 20181229-10:33:03 AgeDB-30 Ave Accuracy: 95.4167
181 | 20181229-10:33:24 CFP-FP Ave Accuracy: 90.5714
182 | 20181229-10:33:24 Current Best Accuracy: LFW: 99.5167 in iters: 160000, AgeDB-30: 95.6500 in iters: 120000 and CFP-FP: 91.6714 in iters: 160000
183 | 20181229-10:44:05 Train Epoch: 18/30 ...
184 | 20181229-11:03:53 Saving checkpoint: 170000
185 | 20181229-11:04:13 LFW Ave Accuracy: 99.4000
186 | 20181229-11:04:34 AgeDB-30 Ave Accuracy: 95.6000
187 | 20181229-11:04:55 CFP-FP Ave Accuracy: 90.8571
188 | 20181229-11:04:55 Current Best Accuracy: LFW: 99.5167 in iters: 160000, AgeDB-30: 95.6500 in iters: 120000 and CFP-FP: 91.6714 in iters: 160000
189 | 20181229-11:35:21 Saving checkpoint: 175000
190 | 20181229-11:35:40 LFW Ave Accuracy: 99.4167
191 | 20181229-11:35:59 AgeDB-30 Ave Accuracy: 95.3333
192 | 20181229-11:36:22 CFP-FP Ave Accuracy: 91.1000
193 | 20181229-11:36:22 Current Best Accuracy: LFW: 99.5167 in iters: 160000, AgeDB-30: 95.6500 in iters: 120000 and CFP-FP: 91.6714 in iters: 160000
194 | 20181229-11:45:52 Train Epoch: 19/30 ...
195 | 20181229-12:06:48 Saving checkpoint: 180000
196 | 20181229-12:07:08 LFW Ave Accuracy: 99.5000
197 | 20181229-12:07:27 AgeDB-30 Ave Accuracy: 96.2167
198 | 20181229-12:07:48 CFP-FP Ave Accuracy: 92.1286
199 | 20181229-12:07:48 Current Best Accuracy: LFW: 99.5167 in iters: 160000, AgeDB-30: 96.2167 in iters: 180000 and CFP-FP: 92.1286 in iters: 180000
200 | 20181229-12:38:08 Saving checkpoint: 185000
201 | 20181229-12:38:28 LFW Ave Accuracy: 99.5333
202 | 20181229-12:38:47 AgeDB-30 Ave Accuracy: 96.3833
203 | 20181229-12:39:10 CFP-FP Ave Accuracy: 92.2714
204 | 20181229-12:39:10 Current Best Accuracy: LFW: 99.5333 in iters: 185000, AgeDB-30: 96.3833 in iters: 185000 and CFP-FP: 92.2714 in iters: 185000
205 | 20181229-12:47:30 Train Epoch: 20/30 ...
206 | 20181229-13:09:32 Saving checkpoint: 190000
207 | 20181229-13:09:52 LFW Ave Accuracy: 99.5667
208 | 20181229-13:10:12 AgeDB-30 Ave Accuracy: 96.2167
209 | 20181229-13:10:35 CFP-FP Ave Accuracy: 92.5571
210 | 20181229-13:10:35 Current Best Accuracy: LFW: 99.5667 in iters: 190000, AgeDB-30: 96.3833 in iters: 185000 and CFP-FP: 92.5571 in iters: 190000
211 | 20181229-13:40:57 Saving checkpoint: 195000
212 | 20181229-13:41:16 LFW Ave Accuracy: 99.5500
213 | 20181229-13:41:37 AgeDB-30 Ave Accuracy: 96.3667
214 | 20181229-13:41:58 CFP-FP Ave Accuracy: 92.5857
215 | 20181229-13:41:58 Current Best Accuracy: LFW: 99.5667 in iters: 190000, AgeDB-30: 96.3833 in iters: 185000 and CFP-FP: 92.5857 in iters: 195000
216 | 20181229-13:49:09 Train Epoch: 21/30 ...
217 | 20181229-14:12:22 Saving checkpoint: 200000
218 | 20181229-14:12:42 LFW Ave Accuracy: 99.4667
219 | 20181229-14:13:01 AgeDB-30 Ave Accuracy: 96.3500
220 | 20181229-14:13:22 CFP-FP Ave Accuracy: 92.5571
221 | 20181229-14:13:22 Current Best Accuracy: LFW: 99.5667 in iters: 190000, AgeDB-30: 96.3833 in iters: 185000 and CFP-FP: 92.5857 in iters: 195000
222 | 20181229-14:43:41 Saving checkpoint: 205000
223 | 20181229-14:44:00 LFW Ave Accuracy: 99.5333
224 | 20181229-14:44:19 AgeDB-30 Ave Accuracy: 96.1833
225 | 20181229-14:44:40 CFP-FP Ave Accuracy: 92.7429
226 | 20181229-14:44:40 Current Best Accuracy: LFW: 99.5667 in iters: 190000, AgeDB-30: 96.3833 in iters: 185000 and CFP-FP: 92.7429 in iters: 205000
227 | 20181229-14:50:40 Train Epoch: 22/30 ...
228 | 20181229-15:15:03 Saving checkpoint: 210000
229 | 20181229-15:15:23 LFW Ave Accuracy: 99.5167
230 | 20181229-15:15:44 AgeDB-30 Ave Accuracy: 96.1333
231 | 20181229-15:16:06 CFP-FP Ave Accuracy: 92.7571
232 | 20181229-15:16:06 Current Best Accuracy: LFW: 99.5667 in iters: 190000, AgeDB-30: 96.3833 in iters: 185000 and CFP-FP: 92.7571 in iters: 210000
233 | 20181229-15:46:26 Saving checkpoint: 215000
234 | 20181229-15:46:46 LFW Ave Accuracy: 99.4833
235 | 20181229-15:47:07 AgeDB-30 Ave Accuracy: 96.4500
236 | 20181229-15:47:28 CFP-FP Ave Accuracy: 92.6571
237 | 20181229-15:47:28 Current Best Accuracy: LFW: 99.5667 in iters: 190000, AgeDB-30: 96.4500 in iters: 215000 and CFP-FP: 92.7571 in iters: 210000
238 | 20181229-15:52:19 Train Epoch: 23/30 ...
239 | 20181229-16:17:53 Saving checkpoint: 220000
240 | 20181229-16:18:12 LFW Ave Accuracy: 99.5500
241 | 20181229-16:18:31 AgeDB-30 Ave Accuracy: 96.0833
242 | 20181229-16:18:52 CFP-FP Ave Accuracy: 92.6000
243 | 20181229-16:18:52 Current Best Accuracy: LFW: 99.5667 in iters: 190000, AgeDB-30: 96.4500 in iters: 215000 and CFP-FP: 92.7571 in iters: 210000
244 | 20181229-16:49:11 Saving checkpoint: 225000
245 | 20181229-16:49:30 LFW Ave Accuracy: 99.5833
246 | 20181229-16:49:50 AgeDB-30 Ave Accuracy: 96.3500
247 | 20181229-16:50:11 CFP-FP Ave Accuracy: 92.9000
248 | 20181229-16:50:11 Current Best Accuracy: LFW: 99.5833 in iters: 225000, AgeDB-30: 96.4500 in iters: 215000 and CFP-FP: 92.9000 in iters: 225000
249 | 20181229-16:53:53 Train Epoch: 24/30 ...
250 | 20181229-17:20:36 Saving checkpoint: 230000
251 | 20181229-17:20:57 LFW Ave Accuracy: 99.5167
252 | 20181229-17:21:16 AgeDB-30 Ave Accuracy: 96.5667
253 | 20181229-17:21:37 CFP-FP Ave Accuracy: 92.6714
254 | 20181229-17:21:37 Current Best Accuracy: LFW: 99.5833 in iters: 225000, AgeDB-30: 96.5667 in iters: 230000 and CFP-FP: 92.9000 in iters: 225000
255 | 20181229-17:52:01 Saving checkpoint: 235000
256 | 20181229-17:52:21 LFW Ave Accuracy: 99.5833
257 | 20181229-17:52:41 AgeDB-30 Ave Accuracy: 96.2500
258 | 20181229-17:53:03 CFP-FP Ave Accuracy: 92.7571
259 | 20181229-17:53:03 Current Best Accuracy: LFW: 99.5833 in iters: 225000, AgeDB-30: 96.5667 in iters: 230000 and CFP-FP: 92.9000 in iters: 225000
260 | 20181229-17:55:35 Train Epoch: 25/30 ...
261 | 20181229-18:23:30 Saving checkpoint: 240000
262 | 20181229-18:23:50 LFW Ave Accuracy: 99.5333
263 | 20181229-18:24:10 AgeDB-30 Ave Accuracy: 96.2833
264 | 20181229-18:24:32 CFP-FP Ave Accuracy: 92.4143
265 | 20181229-18:24:32 Current Best Accuracy: LFW: 99.5833 in iters: 225000, AgeDB-30: 96.5667 in iters: 230000 and CFP-FP: 92.9000 in iters: 225000
266 | 20181229-18:54:59 Saving checkpoint: 245000
267 | 20181229-18:55:19 LFW Ave Accuracy: 99.5167
268 | 20181229-18:55:39 AgeDB-30 Ave Accuracy: 96.1667
269 | 20181229-18:56:00 CFP-FP Ave Accuracy: 92.8143
270 | 20181229-18:56:00 Current Best Accuracy: LFW: 99.5833 in iters: 225000, AgeDB-30: 96.5667 in iters: 230000 and CFP-FP: 92.9000 in iters: 225000
271 | 20181229-18:57:22 Train Epoch: 26/30 ...
272 | 20181229-19:26:30 Saving checkpoint: 250000
273 | 20181229-19:26:50 LFW Ave Accuracy: 99.5667
274 | 20181229-19:27:10 AgeDB-30 Ave Accuracy: 96.4500
275 | 20181229-19:27:32 CFP-FP Ave Accuracy: 92.6000
276 | 20181229-19:27:32 Current Best Accuracy: LFW: 99.5833 in iters: 225000, AgeDB-30: 96.5667 in iters: 230000 and CFP-FP: 92.9000 in iters: 225000
277 |
--------------------------------------------------------------------------------
/result/softmax.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wujiyang/Face_Pytorch/3afd941c01cae3eb73b66a48fa984f41bd6662fc/result/softmax.gif
--------------------------------------------------------------------------------
/result/softmax_center.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wujiyang/Face_Pytorch/3afd941c01cae3eb73b66a48fa984f41bd6662fc/result/softmax_center.gif
--------------------------------------------------------------------------------
/result/visualization.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wujiyang/Face_Pytorch/3afd941c01cae3eb73b66a48fa984f41bd6662fc/result/visualization.jpg
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # encoding: utf-8
3 | '''
4 | @author: wujiyang
5 | @contact: wujiyang@hust.edu.cn
6 | @file: train.py
7 | @time: 2018/12/21 17:37
8 | @desc: train script for deep face recognition
9 | '''
10 |
11 | import os
12 | import torch.utils.data
13 | from torch.nn import DataParallel
14 | from datetime import datetime
15 | from backbone.mobilefacenet import MobileFaceNet
16 | from backbone.cbam import CBAMResNet
17 | from backbone.attention import ResidualAttentionNet_56, ResidualAttentionNet_92
18 | from margin.ArcMarginProduct import ArcMarginProduct
19 | from margin.MultiMarginProduct import MultiMarginProduct
20 | from margin.CosineMarginProduct import CosineMarginProduct
21 | from margin.InnerProduct import InnerProduct
22 | from utils.visualize import Visualizer
23 | from utils.logging import init_log
24 | from dataset.casia_webface import CASIAWebFace
25 | from dataset.lfw import LFW
26 | from dataset.agedb import AgeDB30
27 | from dataset.cfp import CFP_FP
28 | from torch.optim import lr_scheduler
29 | import torch.optim as optim
30 | import time
31 | from eval_lfw import evaluation_10_fold, getFeatureFromTorch
32 | import numpy as np
33 | import torchvision.transforms as transforms
34 | import argparse
35 |
36 |
37 | def train(args):
38 | # gpu init
39 | multi_gpus = False
40 | if len(args.gpus.split(',')) > 1:
41 | multi_gpus = True
42 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpus
43 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
44 |
45 | # log init
46 | save_dir = os.path.join(args.save_dir, args.model_pre + args.backbone.upper() + '_' + datetime.now().strftime('%Y%m%d_%H%M%S'))
47 | if os.path.exists(save_dir):
48 | raise NameError('model dir exists!')
49 | os.makedirs(save_dir)
50 | logging = init_log(save_dir)
51 | _print = logging.info
52 |
53 | # dataset loader
54 | transform = transforms.Compose([
55 | transforms.ToTensor(), # range [0, 255] -> [0.0,1.0]
56 | transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)) # range [0.0, 1.0] -> [-1.0,1.0]
57 | ])
58 | # train dataset
59 | trainset = CASIAWebFace(args.train_root, args.train_file_list, transform=transform)
60 | trainloader = torch.utils.data.DataLoader(trainset, batch_size=args.batch_size,
61 | shuffle=True, num_workers=8, drop_last=False)
62 | # test dataset
63 | lfwdataset = LFW(args.lfw_test_root, args.lfw_file_list, transform=transform)
64 | lfwloader = torch.utils.data.DataLoader(lfwdataset, batch_size=128,
65 | shuffle=False, num_workers=4, drop_last=False)
66 | agedbdataset = AgeDB30(args.agedb_test_root, args.agedb_file_list, transform=transform)
67 | agedbloader = torch.utils.data.DataLoader(agedbdataset, batch_size=128,
68 | shuffle=False, num_workers=4, drop_last=False)
69 | cfpfpdataset = CFP_FP(args.cfpfp_test_root, args.cfpfp_file_list, transform=transform)
70 | cfpfploader = torch.utils.data.DataLoader(cfpfpdataset, batch_size=128,
71 | shuffle=False, num_workers=4, drop_last=False)
72 |
73 | # define backbone and margin layer
74 | if args.backbone == 'MobileFace':
75 | net = MobileFaceNet()
76 | elif args.backbone == 'Res50_IR':
77 | net = CBAMResNet(50, feature_dim=args.feature_dim, mode='ir')
78 | elif args.backbone == 'SERes50_IR':
79 | net = CBAMResNet(50, feature_dim=args.feature_dim, mode='ir_se')
80 | elif args.backbone == 'Res100_IR':
81 | net = CBAMResNet(100, feature_dim=args.feature_dim, mode='ir')
82 | elif args.backbone == 'SERes100_IR':
83 | net = CBAMResNet(100, feature_dim=args.feature_dim, mode='ir_se')
84 | elif args.backbone == 'Attention_56':
85 | net = ResidualAttentionNet_56(feature_dim=args.feature_dim)
86 | elif args.backbone == 'Attention_92':
87 | net = ResidualAttentionNet_92(feature_dim=args.feature_dim)
88 | else:
89 | print(args.backbone, ' is not available!')
90 |
91 | if args.margin_type == 'ArcFace':
92 | margin = ArcMarginProduct(args.feature_dim, trainset.class_nums, s=args.scale_size)
93 | elif args.margin_type == 'MultiMargin':
94 | margin = MultiMarginProduct(args.feature_dim, trainset.class_nums, s=args.scale_size)
95 | elif args.margin_type == 'CosFace':
96 | margin = CosineMarginProduct(args.feature_dim, trainset.class_nums, s=args.scale_size)
97 | elif args.margin_type == 'Softmax':
98 | margin = InnerProduct(args.feature_dim, trainset.class_nums)
99 | elif args.margin_type == 'SphereFace':
100 | pass
101 | else:
102 | print(args.margin_type, 'is not available!')
103 |
104 | if args.resume:
105 | print('resume the model parameters from: ', args.net_path, args.margin_path)
106 | net.load_state_dict(torch.load(args.net_path)['net_state_dict'])
107 | margin.load_state_dict(torch.load(args.margin_path)['net_state_dict'])
108 |
109 | # define optimizers for different layer
110 | criterion = torch.nn.CrossEntropyLoss().to(device)
111 | optimizer_ft = optim.SGD([
112 | {'params': net.parameters(), 'weight_decay': 5e-4},
113 | {'params': margin.parameters(), 'weight_decay': 5e-4}
114 | ], lr=0.1, momentum=0.9, nesterov=True)
115 | exp_lr_scheduler = lr_scheduler.MultiStepLR(optimizer_ft, milestones=[6, 11, 16], gamma=0.1)
116 |
117 | if multi_gpus:
118 | net = DataParallel(net).to(device)
119 | margin = DataParallel(margin).to(device)
120 | else:
121 | net = net.to(device)
122 | margin = margin.to(device)
123 |
124 |
125 | best_lfw_acc = 0.0
126 | best_lfw_iters = 0
127 | best_agedb30_acc = 0.0
128 | best_agedb30_iters = 0
129 | best_cfp_fp_acc = 0.0
130 | best_cfp_fp_iters = 0
131 | total_iters = 0
132 | vis = Visualizer(env=args.model_pre + args.backbone)
133 | for epoch in range(1, args.total_epoch + 1):
134 | exp_lr_scheduler.step()
135 | # train model
136 | _print('Train Epoch: {}/{} ...'.format(epoch, args.total_epoch))
137 | net.train()
138 |
139 | since = time.time()
140 | for data in trainloader:
141 | img, label = data[0].to(device), data[1].to(device)
142 | optimizer_ft.zero_grad()
143 |
144 | raw_logits = net(img)
145 | output = margin(raw_logits, label)
146 | total_loss = criterion(output, label)
147 | total_loss.backward()
148 | optimizer_ft.step()
149 |
150 | total_iters += 1
151 | # print train information
152 | if total_iters % 100 == 0:
153 | # current training accuracy
154 | _, predict = torch.max(output.data, 1)
155 | total = label.size(0)
156 | correct = (np.array(predict.cpu()) == np.array(label.data.cpu())).sum()
157 | time_cur = (time.time() - since) / 100
158 | since = time.time()
159 | vis.plot_curves({'softmax loss': total_loss.item()}, iters=total_iters, title='train loss',
160 | xlabel='iters', ylabel='train loss')
161 | vis.plot_curves({'train accuracy': correct / total}, iters=total_iters, title='train accuracy', xlabel='iters',
162 | ylabel='train accuracy')
163 |
164 | _print("Iters: {:0>6d}/[{:0>2d}], loss: {:.4f}, train_accuracy: {:.4f}, time: {:.2f} s/iter, learning rate: {}".format(total_iters, epoch, total_loss.item(), correct/total, time_cur, exp_lr_scheduler.get_lr()[0]))
165 |
166 | # save model
167 | if total_iters % args.save_freq == 0:
168 | msg = 'Saving checkpoint: {}'.format(total_iters)
169 | _print(msg)
170 | if multi_gpus:
171 | net_state_dict = net.module.state_dict()
172 | margin_state_dict = margin.module.state_dict()
173 | else:
174 | net_state_dict = net.state_dict()
175 | margin_state_dict = margin.state_dict()
176 | if not os.path.exists(save_dir):
177 | os.mkdir(save_dir)
178 | torch.save({
179 | 'iters': total_iters,
180 | 'net_state_dict': net_state_dict},
181 | os.path.join(save_dir, 'Iter_%06d_net.ckpt' % total_iters))
182 | torch.save({
183 | 'iters': total_iters,
184 | 'net_state_dict': margin_state_dict},
185 | os.path.join(save_dir, 'Iter_%06d_margin.ckpt' % total_iters))
186 |
187 | # test accuracy
188 | if total_iters % args.test_freq == 0:
189 |
190 | # test model on lfw
191 | net.eval()
192 | getFeatureFromTorch('./result/cur_lfw_result.mat', net, device, lfwdataset, lfwloader)
193 | lfw_accs = evaluation_10_fold('./result/cur_lfw_result.mat')
194 | _print('LFW Ave Accuracy: {:.4f}'.format(np.mean(lfw_accs) * 100))
195 | if best_lfw_acc <= np.mean(lfw_accs) * 100:
196 | best_lfw_acc = np.mean(lfw_accs) * 100
197 | best_lfw_iters = total_iters
198 |
199 | # test model on AgeDB30
200 | getFeatureFromTorch('./result/cur_agedb30_result.mat', net, device, agedbdataset, agedbloader)
201 | age_accs = evaluation_10_fold('./result/cur_agedb30_result.mat')
202 | _print('AgeDB-30 Ave Accuracy: {:.4f}'.format(np.mean(age_accs) * 100))
203 | if best_agedb30_acc <= np.mean(age_accs) * 100:
204 | best_agedb30_acc = np.mean(age_accs) * 100
205 | best_agedb30_iters = total_iters
206 |
207 | # test model on CFP-FP
208 | getFeatureFromTorch('./result/cur_cfpfp_result.mat', net, device, cfpfpdataset, cfpfploader)
209 | cfp_accs = evaluation_10_fold('./result/cur_cfpfp_result.mat')
210 | _print('CFP-FP Ave Accuracy: {:.4f}'.format(np.mean(cfp_accs) * 100))
211 | if best_cfp_fp_acc <= np.mean(cfp_accs) * 100:
212 | best_cfp_fp_acc = np.mean(cfp_accs) * 100
213 | best_cfp_fp_iters = total_iters
214 | _print('Current Best Accuracy: LFW: {:.4f} in iters: {}, AgeDB-30: {:.4f} in iters: {} and CFP-FP: {:.4f} in iters: {}'.format(
215 | best_lfw_acc, best_lfw_iters, best_agedb30_acc, best_agedb30_iters, best_cfp_fp_acc, best_cfp_fp_iters))
216 |
217 | vis.plot_curves({'lfw': np.mean(lfw_accs), 'agedb-30': np.mean(age_accs), 'cfp-fp': np.mean(cfp_accs)}, iters=total_iters,
218 | title='test accuracy', xlabel='iters', ylabel='test accuracy')
219 | net.train()
220 |
221 | _print('Finally Best Accuracy: LFW: {:.4f} in iters: {}, AgeDB-30: {:.4f} in iters: {} and CFP-FP: {:.4f} in iters: {}'.format(
222 | best_lfw_acc, best_lfw_iters, best_agedb30_acc, best_agedb30_iters, best_cfp_fp_acc, best_cfp_fp_iters))
223 | print('finishing training')
224 |
225 |
226 | if __name__ == '__main__':
227 | parser = argparse.ArgumentParser(description='PyTorch for deep face recognition')
228 | parser.add_argument('--train_root', type=str, default='/media/ramdisk/msra_align_112', help='train image root')
229 | parser.add_argument('--train_file_list', type=str, default='/media/ramdisk/msra_align_train.list', help='train list')
230 | parser.add_argument('--lfw_test_root', type=str, default='/media/sda/lfw/lfw_align_112', help='lfw image root')
231 | parser.add_argument('--lfw_file_list', type=str, default='/media/sda/lfw/pairs.txt', help='lfw pair file list')
232 | parser.add_argument('--agedb_test_root', type=str, default='/media/sda/AgeDB-30/agedb30_align_112', help='agedb image root')
233 | parser.add_argument('--agedb_file_list', type=str, default='/media/sda/AgeDB-30/agedb_30_pair.txt', help='agedb pair file list')
234 | parser.add_argument('--cfpfp_test_root', type=str, default='/media/sda/CFP-FP/cfp_fp_aligned_112', help='cfp-fp image root')
235 | parser.add_argument('--cfpfp_file_list', type=str, default='/media/sda/CFP-FP/cfp_fp_pair.txt', help='cfp-fp pair file list')
236 |
237 | parser.add_argument('--backbone', type=str, default='SERes100_IR', help='MobileFace, Res50_IR, SERes50_IR, Res100_IR, SERes100_IR, Attention_56, Attention_92')
238 | parser.add_argument('--margin_type', type=str, default='ArcFace', help='ArcFace, CosFace, SphereFace, MultiMargin, Softmax')
239 | parser.add_argument('--feature_dim', type=int, default=512, help='feature dimension, 128 or 512')
240 | parser.add_argument('--scale_size', type=float, default=32.0, help='scale size')
241 | parser.add_argument('--batch_size', type=int, default=200, help='batch size')
242 | parser.add_argument('--total_epoch', type=int, default=18, help='total epochs')
243 |
244 | parser.add_argument('--save_freq', type=int, default=3000, help='save frequency')
245 | parser.add_argument('--test_freq', type=int, default=3000, help='test frequency')
246 | parser.add_argument('--resume', type=int, default=False, help='resume model')
247 | parser.add_argument('--net_path', type=str, default='', help='resume model')
248 | parser.add_argument('--margin_path', type=str, default='', help='resume model')
249 | parser.add_argument('--save_dir', type=str, default='./model', help='model save dir')
250 | parser.add_argument('--model_pre', type=str, default='SERES100_', help='model prefix')
251 | parser.add_argument('--gpus', type=str, default='0,1,2,3', help='gpu ids to use')
252 |
253 | args = parser.parse_args()
254 |
255 | train(args)
256 |
257 |
258 |
--------------------------------------------------------------------------------
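The script separates the backbone (images to embeddings) from the margin head (embeddings plus labels to logits). A condensed sketch of one optimization step under the same wiring, assuming 112x112 aligned inputs and the 128-d embedding MobileFaceNet produces in this repository (the batch contents are random stand-ins):

import torch
from backbone.mobilefacenet import MobileFaceNet
from margin.ArcMarginProduct import ArcMarginProduct

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
net = MobileFaceNet().to(device)
margin = ArcMarginProduct(in_feature=128, out_feature=10575, s=32.0).to(device)
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(list(net.parameters()) + list(margin.parameters()),
                            lr=0.1, momentum=0.9, nesterov=True, weight_decay=5e-4)

img = torch.randn(4, 3, 112, 112, device=device)       # stand-in for a training batch
label = torch.randint(0, 10575, (4,), device=device)

optimizer.zero_grad()
loss = criterion(margin(net(img), label), label)
loss.backward()
optimizer.step()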
/train_center.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # encoding: utf-8
3 | '''
4 | @author: wujiyang
5 | @contact: wujiyang@hust.edu.cn
6 | @file: train_center.py
7 | @time: 2019/1/3 11:12
8 | @desc: train script for my attention net and center loss
9 | '''
10 |
11 | '''
12 | Please use train.py for your training process.
13 | '''
14 |
15 | import os
16 | import torch.utils.data
17 | from torch.nn import DataParallel
18 | from datetime import datetime
19 | from backbone.mobilefacenet import MobileFaceNet
20 | from backbone.resnet import ResNet50, ResNet101
21 | from backbone.arcfacenet import SEResNet_IR
22 | from backbone.spherenet import SphereNet
23 | from margin.ArcMarginProduct import ArcMarginProduct
24 | from margin.InnerProduct import InnerProduct
25 | from lossfunctions.centerloss import CenterLoss
26 | from utils.logging import init_log
27 | from dataset.casia_webface import CASIAWebFace
28 | from dataset.lfw import LFW
29 | from dataset.agedb import AgeDB30
30 | from dataset.cfp import CFP_FP
31 | from utils.visualize import Visualizer
32 | from torch.optim import lr_scheduler
33 | import torch.optim as optim
34 | import time
35 | from eval_lfw import evaluation_10_fold, getFeatureFromTorch
36 | import numpy as np
37 | import torchvision.transforms as transforms
38 | import argparse
39 |
40 |
41 | def train(args):
42 | # gpu init
43 | multi_gpus = False
44 | if len(args.gpus.split(',')) > 1:
45 | multi_gpus = True
46 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpus
47 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
48 |
49 | # log init
50 | save_dir = os.path.join(args.save_dir, args.model_pre + args.backbone.upper() + '_' + datetime.now().strftime('%Y%m%d_%H%M%S'))
51 | if os.path.exists(save_dir):
52 | raise NameError('model dir exists!')
53 | os.makedirs(save_dir)
54 | logging = init_log(save_dir)
55 | _print = logging.info
56 |
57 | # dataset loader
58 | transform = transforms.Compose([
59 | transforms.ToTensor(), # range [0, 255] -> [0.0,1.0]
60 | transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)) # range [0.0, 1.0] -> [-1.0,1.0]
61 | ])
62 | # train dataset
63 | trainset = CASIAWebFace(args.train_root, args.train_file_list, transform=transform)
64 | trainloader = torch.utils.data.DataLoader(trainset, batch_size=args.batch_size,
65 | shuffle=True, num_workers=8, drop_last=False)
66 | # test dataset
67 | lfwdataset = LFW(args.lfw_test_root, args.lfw_file_list, transform=transform)
68 | lfwloader = torch.utils.data.DataLoader(lfwdataset, batch_size=128,
69 | shuffle=False, num_workers=4, drop_last=False)
70 | agedbdataset = AgeDB30(args.agedb_test_root, args.agedb_file_list, transform=transform)
71 | agedbloader = torch.utils.data.DataLoader(agedbdataset, batch_size=128,
72 | shuffle=False, num_workers=4, drop_last=False)
73 | cfpfpdataset = CFP_FP(args.cfpfp_test_root, args.cfpfp_file_list, transform=transform)
74 | cfpfploader = torch.utils.data.DataLoader(cfpfpdataset, batch_size=128,
75 | shuffle=False, num_workers=4, drop_last=False)
76 |
77 | # define backbone and margin layer
78 | if args.backbone == 'MobileFace':
79 | net = MobileFaceNet()
80 | elif args.backbone == 'Res50':
81 | net = ResNet50()
82 | elif args.backbone == 'Res101':
83 | net = ResNet101()
84 | elif args.backbone == 'Res50_IR':
85 | net = SEResNet_IR(50, feature_dim=args.feature_dim, mode='ir')
86 | elif args.backbone == 'SERes50_IR':
87 | net = SEResNet_IR(50, feature_dim=args.feature_dim, mode='se_ir')
88 | elif args.backbone == 'SphereNet':
89 | net = SphereNet(num_layers=64, feature_dim=args.feature_dim)
90 | else:
91 | print(args.backbone, ' is not available!')
92 |
93 | if args.margin_type == 'ArcFace':
94 | margin = ArcMarginProduct(args.feature_dim, trainset.class_nums, s=args.scale_size)
95 | elif args.margin_type == 'CosFace':
96 | pass
97 | elif args.margin_type == 'SphereFace':
98 | pass
99 | elif args.margin_type == 'InnerProduct':
100 | margin = InnerProduct(args.feature_dim, trainset.class_nums)
101 | else:
102 | print(args.margin_type, 'is not available!')
103 |
104 | if args.resume:
105 | print('resume the model parameters from: ', args.net_path, args.margin_path)
106 | net.load_state_dict(torch.load(args.net_path)['net_state_dict'])
107 | margin.load_state_dict(torch.load(args.margin_path)['net_state_dict'])
108 |
109 | # define optimizers for different layers
110 | criterion_classi = torch.nn.CrossEntropyLoss().to(device)
111 | optimizer_classi = optim.SGD([
112 | {'params': net.parameters(), 'weight_decay': 5e-4},
113 | {'params': margin.parameters(), 'weight_decay': 5e-4}
114 | ], lr=0.1, momentum=0.9, nesterov=True)
115 |
116 | #criterion_center = CenterLoss(trainset.class_nums, args.feature_dim).to(device)
117 | #optimizer_center = optim.SGD(criterion_center.parameters(), lr=0.5)
118 |
119 | scheduler_classi = lr_scheduler.MultiStepLR(optimizer_classi, milestones=[25, 50, 65], gamma=0.1)
120 |
121 | if multi_gpus:
122 | net = DataParallel(net).to(device)
123 | margin = DataParallel(margin).to(device)
124 | else:
125 | net = net.to(device)
126 | margin = margin.to(device)
127 |
128 | best_lfw_acc = 0.0
129 | best_lfw_iters = 0
130 | best_agedb30_acc = 0.0
131 | best_agedb30_iters = 0
132 | best_cfp_fp_acc = 0.0
133 | best_cfp_fp_iters = 0
134 | total_iters = 0
135 | #vis = Visualizer(env='softmax_center_xavier')
136 | for epoch in range(1, args.total_epoch + 1):
137 | scheduler_classi.step()
138 | # train model
139 | _print('Train Epoch: {}/{} ...'.format(epoch, args.total_epoch))
140 | net.train()
141 |
142 | since = time.time()
143 | for data in trainloader:
144 | img, label = data[0].to(device), data[1].to(device)
145 | feature = net(img)
146 | output = margin(feature, label)
147 | loss_classi = criterion_classi(output, label)
148 | #loss_center = criterion_center(feature, label)
149 | total_loss = loss_classi #+ loss_center * args.weight_center
150 |
151 | optimizer_classi.zero_grad()
152 | #optimizer_center.zero_grad()
153 | total_loss.backward()
154 | optimizer_classi.step()
155 | #optimizer_center.step()
156 |
157 | total_iters += 1
158 | # print train information
159 | if total_iters % 100 == 0:
160 | # current training accuracy
161 | _, predict = torch.max(output.data, 1)
162 | total = label.size(0)
163 | correct = (np.array(predict.cpu()) == np.array(label.data.cpu())).sum()
164 | time_cur = (time.time() - since) / 100
165 | since = time.time()
166 | #vis.plot_curves({'softmax loss': loss_classi.item(), 'center loss': loss_center.item()}, iters=total_iters, title='train loss', xlabel='iters', ylabel='train loss')
167 | #vis.plot_curves({'train accuracy': correct / total}, iters=total_iters, title='train accuracy', xlabel='iters', ylabel='train accuracy')
168 | print("Iters: {:0>6d}/[{:0>2d}], loss_classi: {:.4f}, loss_center: {:.4f}, train_accuracy: {:.4f}, time: {:.2f} s/iter, learning rate: {}".format(total_iters,
169 | epoch,
170 | loss_classi.item(),
171 | loss_center.item(),
172 | correct/total,
173 | time_cur,
174 | scheduler_classi.get_lr()[
175 | 0]))
176 | # save model
177 | if total_iters % args.save_freq == 0:
178 | msg = 'Saving checkpoint: {}'.format(total_iters)
179 | _print(msg)
180 | if multi_gpus:
181 | net_state_dict = net.module.state_dict()
182 | margin_state_dict = margin.module.state_dict()
183 | else:
184 | net_state_dict = net.state_dict()
185 | margin_state_dict = margin.state_dict()
186 |
187 | if not os.path.exists(save_dir):
188 | os.mkdir(save_dir)
189 | torch.save({
190 | 'iters': total_iters,
191 | 'net_state_dict': net_state_dict},
192 | os.path.join(save_dir, 'Iter_%06d_net.ckpt' % total_iters))
193 | torch.save({
194 | 'iters': total_iters,
195 | 'net_state_dict': margin_state_dict},
196 | os.path.join(save_dir, 'Iter_%06d_margin.ckpt' % total_iters))
197 | #torch.save({
198 | # 'iters': total_iters,
199 | # 'net_state_dict': criterion_center.state_dict()},
200 | # os.path.join(save_dir, 'Iter_%06d_center.ckpt' % total_iters))
201 |
202 | # test accuracy
203 | if total_iters % args.test_freq == 0:
204 |
205 | # test model on lfw
206 | net.eval()
207 | getFeatureFromTorch('./result/cur_lfw_result.mat', net, device, lfwdataset, lfwloader)
208 | lfw_accs = evaluation_10_fold('./result/cur_lfw_result.mat')
209 | _print('LFW Ave Accuracy: {:.4f}'.format(np.mean(lfw_accs) * 100))
210 | if best_lfw_acc < np.mean(lfw_accs) * 100:
211 | best_lfw_acc = np.mean(lfw_accs) * 100
212 | best_lfw_iters = total_iters
213 |
214 | # test model on AgeDB30
215 | getFeatureFromTorch('./result/cur_agedb30_result.mat', net, device, agedbdataset, agedbloader)
216 | age_accs = evaluation_10_fold('./result/cur_agedb30_result.mat')
217 | _print('AgeDB-30 Ave Accuracy: {:.4f}'.format(np.mean(age_accs) * 100))
218 | if best_agedb30_acc < np.mean(age_accs) * 100:
219 | best_agedb30_acc = np.mean(age_accs) * 100
220 | best_agedb30_iters = total_iters
221 |
222 | # test model on CFP-FP
223 | getFeatureFromTorch('./result/cur_cfpfp_result.mat', net, device, cfpfpdataset, cfpfploader)
224 | cfp_accs = evaluation_10_fold('./result/cur_cfpfp_result.mat')
225 | _print('CFP-FP Ave Accuracy: {:.4f}'.format(np.mean(cfp_accs) * 100))
226 | if best_cfp_fp_acc < np.mean(cfp_accs) * 100:
227 | best_cfp_fp_acc = np.mean(cfp_accs) * 100
228 | best_cfp_fp_iters = total_iters
229 | _print('Current Best Accuracy: LFW: {:.4f} in iters: {}, AgeDB-30: {:.4f} in iters: {} and CFP-FP: {:.4f} in iters: {}'.format(
230 | best_lfw_acc, best_lfw_iters, best_agedb30_acc, best_agedb30_iters, best_cfp_fp_acc, best_cfp_fp_iters))
231 |
232 | #vis.plot_curves({'lfw': np.mean(lfw_accs), 'agedb-30': np.mean(age_accs), 'cfp-fp': np.mean(cfp_accs)}, iters=total_iters,
233 | # title='test accuracy', xlabel='iters', ylabel='test accuracy')
234 | net.train()
235 |
236 | _print('Finally Best Accuracy: LFW: {:.4f} in iters: {}, AgeDB-30: {:.4f} in iters: {} and CFP-FP: {:.4f} in iters: {}'.format(
237 | best_lfw_acc, best_lfw_iters, best_agedb30_acc, best_agedb30_iters, best_cfp_fp_acc, best_cfp_fp_iters))
238 | print('finishing training')
239 |
240 |
241 | if __name__ == '__main__':
242 | parser = argparse.ArgumentParser(description='PyTorch for deep face recognition')
243 | parser.add_argument('--train_root', type=str, default='/media/ramdisk/webface_align_112', help='train image root')
244 | parser.add_argument('--train_file_list', type=str, default='/media/ramdisk/webface_align_train.list', help='train list')
245 | parser.add_argument('--lfw_test_root', type=str, default='/media/ramdisk/lfw_align_112', help='lfw image root')
246 | parser.add_argument('--lfw_file_list', type=str, default='/media/ramdisk/pairs.txt', help='lfw pair file list')
247 | parser.add_argument('--agedb_test_root', type=str, default='/media/sda/AgeDB-30/agedb30_align_112', help='agedb image root')
248 | parser.add_argument('--agedb_file_list', type=str, default='/media/sda/AgeDB-30/agedb_30_pair.txt', help='agedb pair file list')
249 | parser.add_argument('--cfpfp_test_root', type=str, default='/media/sda/CFP-FP/cfp_fp_aligned_112', help='cfp-fp image root')
250 | parser.add_argument('--cfpfp_file_list', type=str, default='/media/sda/CFP-FP/cfp_fp_pair.txt', help='cfp-fp pair file list')
251 |
252 | parser.add_argument('--backbone', type=str, default='MobileFace', help='MobileFace, Res50, Res101, Res50_IR, SERes50_IR, SphereNet')
253 | parser.add_argument('--margin_type', type=str, default='InnerProduct', help='InnerProduct, ArcFace, CosFace, SphereFace')
254 | parser.add_argument('--feature_dim', type=int, default=128, help='feature dimension, 128 or 512')
255 | parser.add_argument('--scale_size', type=float, default=32.0, help='scale size')
256 | parser.add_argument('--batch_size', type=int, default=256, help='batch size')
257 | parser.add_argument('--total_epoch', type=int, default=80, help='total epochs')
258 | parser.add_argument('--weight_center', type=float, default=0.01, help='center loss weight')
259 |
260 | parser.add_argument('--save_freq', type=int, default=2000, help='save frequency')
261 | parser.add_argument('--test_freq', type=int, default=2000, help='test frequency')
262 | parser.add_argument('--resume', type=int, default=False, help='resume model')
263 | parser.add_argument('--net_path', type=str, default='', help='resume model')
264 | parser.add_argument('--margin_path', type=str, default='', help='resume model')
265 | parser.add_argument('--save_dir', type=str, default='./model', help='model save dir')
266 | parser.add_argument('--model_pre', type=str, default='Softmax_Center_', help='model prefix')
267 | parser.add_argument('--gpus', type=str, default='0,1', help='gpu ids to use')
268 |
269 | args = parser.parse_args()
270 |
271 | train(args)
272 |
273 |
274 |
--------------------------------------------------------------------------------
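The commented-out lines above describe the intended joint objective, total loss = softmax loss + lambda * center loss, with a separate optimizer for the centers. A minimal sketch of that combination; lambda, the sizes and the stand-in tensors are illustrative, not repository defaults:

import torch
from lossfunctions.centerloss import CenterLoss

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
feat_dim, num_classes, weight_center = 128, 10, 0.01

criterion_classi = torch.nn.CrossEntropyLoss()
criterion_center = CenterLoss(num_classes, feat_dim).to(device)
optimizer_center = torch.optim.SGD(criterion_center.parameters(), lr=0.5)

feature = torch.randn(8, feat_dim, device=device, requires_grad=True)    # stand-in for net(img)
output = torch.randn(8, num_classes, device=device, requires_grad=True)  # stand-in for margin(feature, label)
label = torch.randint(0, num_classes, (8,), device=device)

total_loss = criterion_classi(output, label) + weight_center * criterion_center(feature, label)
optimizer_center.zero_grad()
total_loss.backward()
optimizer_center.step()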
/train_softmax.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # encoding: utf-8
3 | '''
4 | @author: wujiyang
5 | @contact: wujiyang@hust.edu.cn
6 | @file: train_softmax.py
7 | @time: 2019/1/7 8:33
8 | @desc: original softmax training with Casia-Webface
9 | '''
10 | '''
11 | Please use train.py for your training process.
12 | '''
13 |
14 | import os
15 | import torch.utils.data
16 | from torch.nn import DataParallel
17 | from datetime import datetime
18 | from backbone.mobilefacenet import MobileFaceNet
19 | from backbone.resnet import ResNet50, ResNet101
20 | from backbone.arcfacenet import SEResNet_IR
21 | from backbone.spherenet import SphereNet
22 | from margin.ArcMarginProduct import ArcMarginProduct
23 | from margin.InnerProduct import InnerProduct
24 | from utils.visualize import Visualizer
25 | from utils.logging import init_log
26 | from dataset.casia_webface import CASIAWebFace
27 | from dataset.lfw import LFW
28 | from dataset.agedb import AgeDB30
29 | from dataset.cfp import CFP_FP
30 | from torch.optim import lr_scheduler
31 | import torch.optim as optim
32 | import time
33 | from eval_lfw import evaluation_10_fold, getFeatureFromTorch
34 | import numpy as np
35 | import torchvision.transforms as transforms
36 | import argparse
37 |
38 | def train(args):
39 | # gpu init
40 | multi_gpus = False
41 | if len(args.gpus.split(',')) > 1:
42 | multi_gpus = True
43 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpus
44 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
45 |
46 | # log init
47 | save_dir = os.path.join(args.save_dir, args.model_pre + args.backbone.upper() + '_' + datetime.now().strftime('%Y%m%d_%H%M%S'))
48 | if os.path.exists(save_dir):
49 | raise NameError('model dir exists!')
50 | os.makedirs(save_dir)
51 | logging = init_log(save_dir)
52 | _print = logging.info
53 |
54 | # dataset loader
55 | transform = transforms.Compose([
56 | transforms.ToTensor(), # range [0, 255] -> [0.0,1.0]
57 | transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)) # range [0.0, 1.0] -> [-1.0,1.0]
58 | ])
59 |     # train dataset
60 | trainset = CASIAWebFace(args.train_root, args.train_file_list, transform=transform)
61 | trainloader = torch.utils.data.DataLoader(trainset, batch_size=args.batch_size,
62 | shuffle=True, num_workers=8, drop_last=False)
63 | # test dataset
64 | lfwdataset = LFW(args.lfw_test_root, args.lfw_file_list, transform=transform)
65 | lfwloader = torch.utils.data.DataLoader(lfwdataset, batch_size=128,
66 | shuffle=False, num_workers=4, drop_last=False)
67 | agedbdataset = AgeDB30(args.agedb_test_root, args.agedb_file_list, transform=transform)
68 | agedbloader = torch.utils.data.DataLoader(agedbdataset, batch_size=128,
69 | shuffle=False, num_workers=4, drop_last=False)
70 | cfpfpdataset = CFP_FP(args.cfpfp_test_root, args.cfpfp_file_list, transform=transform)
71 | cfpfploader = torch.utils.data.DataLoader(cfpfpdataset, batch_size=128,
72 | shuffle=False, num_workers=4, drop_last=False)
73 |
74 | # define backbone and margin layer
75 | if args.backbone == 'MobileFace':
76 | net = MobileFaceNet()
77 | elif args.backbone == 'Res50':
78 | net = ResNet50()
79 | elif args.backbone == 'Res101':
80 | net = ResNet101()
81 | elif args.backbone == 'Res50_IR':
82 | net = SEResNet_IR(50, feature_dim=args.feature_dim, mode='ir')
83 | elif args.backbone == 'SERes50_IR':
84 | net = SEResNet_IR(50, feature_dim=args.feature_dim, mode='se_ir')
85 | elif args.backbone == 'SphereNet':
86 | net = SphereNet(num_layers=64, feature_dim=args.feature_dim)
87 | else:
88 |         raise ValueError(args.backbone + ' is not available!')
89 |
90 | if args.margin_type == 'ArcFace':
91 | margin = ArcMarginProduct(args.feature_dim, trainset.class_nums, s=args.scale_size)
92 |     elif args.margin_type == 'CosFace':
93 |         raise NotImplementedError('CosFace is not implemented in this script, please use train.py')
94 |     elif args.margin_type == 'SphereFace':
95 |         raise NotImplementedError('SphereFace is not implemented in this script, please use train.py')
96 | elif args.margin_type == 'InnerProduct':
97 | margin = InnerProduct(args.feature_dim, trainset.class_nums)
98 | else:
99 |         raise ValueError(args.margin_type + ' is not available!')
100 |
101 | if args.resume:
102 | print('resume the model parameters from: ', args.net_path, args.margin_path)
103 | net.load_state_dict(torch.load(args.net_path)['net_state_dict'])
104 | margin.load_state_dict(torch.load(args.margin_path)['net_state_dict'])
105 |
106 | # define optimizers for different layer
107 |
108 | criterion_classi = torch.nn.CrossEntropyLoss().to(device)
109 | optimizer_classi = optim.SGD([
110 | {'params': net.parameters(), 'weight_decay': 5e-4},
111 | {'params': margin.parameters(), 'weight_decay': 5e-4}
112 | ], lr=0.1, momentum=0.9, nesterov=True)
113 | scheduler_classi = lr_scheduler.MultiStepLR(optimizer_classi, milestones=[20, 35, 45], gamma=0.1)
114 |
115 | if multi_gpus:
116 | net = DataParallel(net).to(device)
117 | margin = DataParallel(margin).to(device)
118 | else:
119 | net = net.to(device)
120 | margin = margin.to(device)
121 |
122 | best_lfw_acc = 0.0
123 | best_lfw_iters = 0
124 | best_agedb30_acc = 0.0
125 | best_agedb30_iters = 0
126 | best_cfp_fp_acc = 0.0
127 | best_cfp_fp_iters = 0
128 | total_iters = 0
129 | vis = Visualizer(env='softmax_train')
130 | for epoch in range(1, args.total_epoch + 1):
131 | scheduler_classi.step()
132 | # train model
133 | _print('Train Epoch: {}/{} ...'.format(epoch, args.total_epoch))
134 | net.train()
135 |
136 | since = time.time()
137 | for data in trainloader:
138 | img, label = data[0].to(device), data[1].to(device)
139 | feature = net(img)
140 | output = margin(feature)
141 | loss_classi = criterion_classi(output, label)
142 | total_loss = loss_classi
143 |
144 | optimizer_classi.zero_grad()
145 | total_loss.backward()
146 | optimizer_classi.step()
147 |
148 | total_iters += 1
149 | # print train information
150 | if total_iters % 100 == 0:
151 | #current training accuracy
152 | _, predict = torch.max(output.data, 1)
153 | total = label.size(0)
154 |                 correct = (predict == label).sum().item()
155 | time_cur = (time.time() - since) / 100
156 | since = time.time()
157 | vis.plot_curves({'train loss': loss_classi.item()}, iters=total_iters, title='train loss', xlabel='iters', ylabel='train loss')
158 | vis.plot_curves({'train accuracy': correct/total}, iters=total_iters, title='train accuracy', xlabel='iters', ylabel='train accuracy')
159 | print("Iters: {:0>6d}/[{:0>2d}], loss_classi: {:.4f}, train_accuracy: {:.4f}, time: {:.2f} s/iter, learning rate: {}".format(total_iters,
160 | epoch,
161 | loss_classi.item(),
162 | correct/total,
163 | time_cur,
164 | scheduler_classi.get_lr()[
165 | 0]))
166 | # save model
167 | if total_iters % args.save_freq == 0:
168 | msg = 'Saving checkpoint: {}'.format(total_iters)
169 | _print(msg)
170 | if multi_gpus:
171 | net_state_dict = net.module.state_dict()
172 | margin_state_dict = margin.module.state_dict()
173 | else:
174 | net_state_dict = net.state_dict()
175 | margin_state_dict = margin.state_dict()
176 |
177 | if not os.path.exists(save_dir):
178 | os.mkdir(save_dir)
179 | torch.save({
180 | 'iters': total_iters,
181 | 'net_state_dict': net_state_dict},
182 | os.path.join(save_dir, 'Iter_%06d_net.ckpt' % total_iters))
183 | torch.save({
184 | 'iters': total_iters,
185 | 'net_state_dict': margin_state_dict},
186 | os.path.join(save_dir, 'Iter_%06d_margin.ckpt' % total_iters))
187 |
188 | # test accuracy
189 | if total_iters % args.test_freq == 0:
190 | # test model on lfw
191 | net.eval()
192 | getFeatureFromTorch('./result/cur_lfw_result.mat', net, device, lfwdataset, lfwloader)
193 | lfw_accs = evaluation_10_fold('./result/cur_lfw_result.mat')
194 | _print('LFW Ave Accuracy: {:.4f}'.format(np.mean(lfw_accs) * 100))
195 | if best_lfw_acc < np.mean(lfw_accs) * 100:
196 | best_lfw_acc = np.mean(lfw_accs) * 100
197 | best_lfw_iters = total_iters
198 | # test model on AgeDB30
199 | getFeatureFromTorch('./result/cur_agedb30_result.mat', net, device, agedbdataset, agedbloader)
200 | age_accs = evaluation_10_fold('./result/cur_agedb30_result.mat')
201 | _print('AgeDB-30 Ave Accuracy: {:.4f}'.format(np.mean(age_accs) * 100))
202 | if best_agedb30_acc < np.mean(age_accs) * 100:
203 | best_agedb30_acc = np.mean(age_accs) * 100
204 | best_agedb30_iters = total_iters
205 | # test model on CFP-FP
206 | getFeatureFromTorch('./result/cur_cfpfp_result.mat', net, device, cfpfpdataset, cfpfploader)
207 | cfp_accs = evaluation_10_fold('./result/cur_cfpfp_result.mat')
208 | _print('CFP-FP Ave Accuracy: {:.4f}'.format(np.mean(cfp_accs) * 100))
209 | if best_cfp_fp_acc < np.mean(cfp_accs) * 100:
210 | best_cfp_fp_acc = np.mean(cfp_accs) * 100
211 | best_cfp_fp_iters = total_iters
212 | _print('Current Best Accuracy: LFW: {:.4f} in iters: {}, AgeDB-30: {:.4f} in iters: {} and CFP-FP: {:.4f} in iters: {}'.format(
213 | best_lfw_acc, best_lfw_iters, best_agedb30_acc, best_agedb30_iters, best_cfp_fp_acc, best_cfp_fp_iters))
214 | vis.plot_curves({'lfw': np.mean(lfw_accs), 'agedb-30': np.mean(age_accs), 'cfp-fp': np.mean(cfp_accs)}, iters=total_iters, title='test accuracy', xlabel='iters', ylabel='test accuracy')
215 | net.train()
216 |
217 | _print('Finally Best Accuracy: LFW: {:.4f} in iters: {}, AgeDB-30: {:.4f} in iters: {} and CFP-FP: {:.4f} in iters: {}'.format(
218 | best_lfw_acc, best_lfw_iters, best_agedb30_acc, best_agedb30_iters, best_cfp_fp_acc, best_cfp_fp_iters))
219 |     print('finished training')
220 |
221 |
222 | if __name__ == '__main__':
223 | parser = argparse.ArgumentParser(description='PyTorch for deep face recognition')
224 | parser.add_argument('--train_root', type=str, default='/media/ramdisk/webface_align_112', help='train image root')
225 | parser.add_argument('--train_file_list', type=str, default='/media/ramdisk/webface_align_train.list', help='train list')
226 | parser.add_argument('--lfw_test_root', type=str, default='/media/ramdisk/lfw_align_112', help='lfw image root')
227 | parser.add_argument('--lfw_file_list', type=str, default='/media/ramdisk/pairs.txt', help='lfw pair file list')
228 | parser.add_argument('--agedb_test_root', type=str, default='/media/sda/AgeDB-30/agedb30_align_112', help='agedb image root')
229 | parser.add_argument('--agedb_file_list', type=str, default='/media/sda/AgeDB-30/agedb_30_pair.txt', help='agedb pair file list')
230 | parser.add_argument('--cfpfp_test_root', type=str, default='/media/sda/CFP-FP/cfp_fp_aligned_112', help='cfp-fp image root')
231 | parser.add_argument('--cfpfp_file_list', type=str, default='/media/sda/CFP-FP/cfp_fp_pair.txt', help='cfp-fp pair file list')
232 |
233 | parser.add_argument('--backbone', type=str, default='MobileFace', help='MobileFace, Res50, Res101, Res50_IR, SERes50_IR, SphereNet')
234 | parser.add_argument('--margin_type', type=str, default='InnerProduct', help='InnerProduct, ArcFace, CosFace, SphereFace')
235 | parser.add_argument('--feature_dim', type=int, default=128, help='feature dimension, 128 or 512')
236 | parser.add_argument('--scale_size', type=float, default=32.0, help='scale size')
237 | parser.add_argument('--batch_size', type=int, default=256, help='batch size')
238 | parser.add_argument('--total_epoch', type=int, default=50, help='total epochs')
239 | parser.add_argument('--weight_center', type=float, default=1.0, help='center loss weight')
240 |
241 | parser.add_argument('--save_freq', type=int, default=2000, help='save frequency')
242 | parser.add_argument('--test_freq', type=int, default=2000, help='test frequency')
243 | parser.add_argument('--resume', type=int, default=False, help='resume model')
244 | parser.add_argument('--net_path', type=str, default='', help='path of the backbone checkpoint to resume from')
245 | parser.add_argument('--margin_path', type=str, default='', help='path of the margin checkpoint to resume from')
246 | parser.add_argument('--save_dir', type=str, default='./model', help='model save dir')
247 | parser.add_argument('--model_pre', type=str, default='Softmax_', help='model prefix')
248 | parser.add_argument('--gpus', type=str, default='2,3', help='gpu ids, e.g. 2,3')
249 |
250 | args = parser.parse_args()
251 |
252 | train(args)
253 |
254 |
255 |
--------------------------------------------------------------------------------
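To see the core of the training loop above in isolation, here is a minimal, self-contained sketch of the same forward/backward step: a backbone produces an embedding, an inner-product (plain softmax) head maps it to class logits, and both are optimized with cross-entropy under SGD. The `TinyBackbone` module and the random tensors are placeholders standing in for MobileFaceNet and CASIA-WebFace; they are not part of the repository.

```python
import torch
import torch.nn as nn
import torch.optim as optim

# Toy stand-in for a face-recognition backbone (MobileFaceNet etc. in the repo).
class TinyBackbone(nn.Module):
    def __init__(self, feature_dim=128):
        super().__init__()
        self.conv = nn.Conv2d(3, 8, kernel_size=3, stride=2, padding=1)
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(8, feature_dim)

    def forward(self, x):
        x = torch.relu(self.conv(x))
        x = self.pool(x).flatten(1)
        return self.fc(x)

feature_dim, num_classes = 128, 10
net = TinyBackbone(feature_dim)
margin = nn.Linear(feature_dim, num_classes, bias=False)  # InnerProduct-style head

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD([
    {'params': net.parameters(), 'weight_decay': 5e-4},
    {'params': margin.parameters(), 'weight_decay': 5e-4},
], lr=0.1, momentum=0.9, nesterov=True)

# One training step on random data shaped like 112x112 aligned face crops.
img = torch.randn(4, 3, 112, 112)
label = torch.randint(0, num_classes, (4,))

feature = net(img)               # backbone -> embedding
output = margin(feature)         # embedding -> class logits
loss = criterion(output, label)  # softmax cross-entropy

optimizer.zero_grad()
loss.backward()
optimizer.step()

predict = output.argmax(dim=1)
accuracy = (predict == label).float().mean().item()
print('loss: {:.4f}, accuracy: {:.4f}'.format(loss.item(), accuracy))
```

The real script wraps exactly this step with multi-GPU handling, a MultiStepLR schedule, periodic checkpointing, and LFW/AgeDB-30/CFP-FP evaluation.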
/utils/README.md:
--------------------------------------------------------------------------------
1 |
2 | ## MXNET binary tools
3 |
4 | Tools for restoring the aligned images from the mxnet binary files provided by [insightface](https://github.com/deepinsight/insightface).
5 | 
6 | You should install mxnet-cpu first for image parsing; a plain ' **pip install mxnet** ' is enough.
7 | 
8 | The processed images are listed below:
9 | [LFW @ BaiduNetdisk](https://pan.baidu.com/s/1Rue4FBmGvdGMPkyy2ZqcdQ), [AgeDB-30 @ BaiduNetdisk](https://pan.baidu.com/s/1sdw1lO5JfP6Ja99O7zprUg), [CFP_FP @ BaiduNetdisk](https://pan.baidu.com/s/1gyFAAy427weUd2G-ozMgEg)
10 |
11 |
--------------------------------------------------------------------------------
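As a quick usage sketch of the tools described in the README (the paths below are placeholders for wherever the insightface archives live, not files shipped with the repo):

```python
# Assumes mxnet is installed and the script is run from the repository root.
from utils.load_images_from_bin import load_image_from_bin, load_mx_rec

# Restore a verification set (e.g. lfw.bin) into individual .jpg files
# and write the corresponding pair list next to the output directory.
load_image_from_bin('faces_webface_112x112/lfw.bin', 'faces_webface_112x112/lfw')

# Unpack the training .rec/.idx archive into per-identity image folders.
load_mx_rec('faces_emore')
```

Both helpers are defined in `utils/load_images_from_bin.py` below.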
/utils/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # encoding: utf-8
3 | '''
4 | @author: wujiyang
5 | @contact: wujiyang@hust.edu.cn
6 | @file: __init__.py.py
7 | @time: 2018/12/22 9:41
8 | @desc:
9 | '''
--------------------------------------------------------------------------------
/utils/load_images_from_bin.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # encoding: utf-8
3 | '''
4 | @author: wujiyang
5 | @contact: wujiyang@hust.edu.cn
6 | @file: load_images_from_bin.py
7 | @time: 2018/12/25 19:21
8 | @desc: For the AgeDB-30 and CFP-FP test datasets, we use the mxnet binary files provided by insightface; this tool restores
9 | the aligned images from the mxnet binary files.
10 | You should install mxnet-cpu first; 'pip install mxnet==1.2.1' is enough.
11 | '''
12 |
13 | from PIL import Image
14 | import cv2
15 | import os
16 | import pickle
17 | import mxnet as mx
18 | from tqdm import tqdm
19 |
20 | '''
21 | For the train dataset, insightface provides an mxnet .rec file; just install mxnet-cpu to extract the images.
22 | '''
23 |
24 | def load_mx_rec(rec_path):
25 | save_path = os.path.join(rec_path, 'emore_images_2')
26 | if not os.path.exists(save_path):
27 | os.makedirs(save_path)
28 |
29 | imgrec = mx.recordio.MXIndexedRecordIO(os.path.join(rec_path, 'train.idx'), os.path.join(rec_path, 'train.rec'), 'r')
30 | img_info = imgrec.read_idx(0)
31 | header,_ = mx.recordio.unpack(img_info)
32 | max_idx = int(header.label[0])
33 | for idx in tqdm(range(1,max_idx)):
34 | img_info = imgrec.read_idx(idx)
35 | header, img = mx.recordio.unpack_img(img_info)
36 | label = int(header.label)
37 | #img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
38 | #img = Image.fromarray(img)
39 | label_path = os.path.join(save_path, str(label).zfill(6))
40 | if not os.path.exists(label_path):
41 | os.makedirs(label_path)
42 | #img.save(os.path.join(label_path, str(idx).zfill(8) + '.jpg'), quality=95)
43 | cv2.imwrite(os.path.join(label_path, str(idx).zfill(8) + '.jpg'), img)
44 |
45 |
46 | def load_image_from_bin(bin_path, save_dir):
47 | if not os.path.exists(save_dir):
48 | os.makedirs(save_dir)
49 | file = open(os.path.join(save_dir, '../', 'lfw_pair.txt'), 'w')
50 | bins, issame_list = pickle.load(open(bin_path, 'rb'), encoding='bytes')
51 | for idx in tqdm(range(len(bins))):
52 | _bin = bins[idx]
53 | img = mx.image.imdecode(_bin).asnumpy()
54 | img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
55 | cv2.imwrite(os.path.join(save_dir, str(idx+1).zfill(5)+'.jpg'), img)
56 | if idx % 2 == 0:
57 | label = 1 if issame_list[idx//2] == True else -1
58 | file.write(str(idx+1).zfill(5) + '.jpg' + ' ' + str(idx+2).zfill(5) +'.jpg' + ' ' + str(label) + '\n')
59 |
60 |
61 | if __name__ == '__main__':
62 | #bin_path = 'D:/face_data_emore/faces_webface_112x112/lfw.bin'
63 | #save_dir = 'D:/face_data_emore/faces_webface_112x112/lfw'
64 | rec_path = 'D:/face_data_emore/faces_emore'
65 | load_mx_rec(rec_path)
66 | #load_image_from_bin(bin_path, save_dir)
67 |
--------------------------------------------------------------------------------
/utils/logging.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # encoding: utf-8
3 | '''
4 | @author: wujiyang
5 | @contact: wujiyang@hust.edu.cn
6 | @file: logging.py
7 | @time: 2018/12/22 9:42
8 | @desc: logging tools
9 | '''
10 |
11 | from __future__ import print_function
12 | import os
13 | import logging
14 |
15 |
16 | def init_log(output_dir):
17 | logging.basicConfig(level=logging.DEBUG,
18 | format='%(asctime)s %(message)s',
19 | datefmt='%Y%m%d-%H:%M:%S',
20 | filename=os.path.join(output_dir, 'log.log'),
21 | filemode='w')
22 | console = logging.StreamHandler()
23 | console.setLevel(logging.INFO)
24 | logging.getLogger('').addHandler(console)
25 | return logging
26 |
27 |
28 | if __name__ == '__main__':
29 | pass
30 |
--------------------------------------------------------------------------------
/utils/plot_logit.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # encoding: utf-8
3 | '''
4 | @author: wujiyang
5 | @contact: wujiyang@hust.edu.cn
6 | @file: plot_logit.py
7 | @time: 2019/3/29 14:21
8 | @desc: plot the target logit curves for sphereface, cosface, arcface and so on.
9 | '''
10 |
11 | import math
12 | import torch
13 | import matplotlib.pyplot as plt
14 | import numpy as np
15 |
16 | def softmax(theta):
17 | return torch.cos(theta)
18 |
19 | def sphereface(theta, m=4):
20 | return (torch.cos(m * theta) + 20 * torch.cos(theta)) / (20 + 1)
21 |
22 | def cosface(theta, m):
23 | return torch.cos(theta) - m
24 |
25 | def arcface(theta, m):
26 | return torch.cos(theta + m)
27 |
28 | def multimargin(theta, m1, m2):
29 | return torch.cos(theta + m1) - m2
30 |
31 |
32 | theta = torch.arange(0, math.pi, 0.001)
33 | print(theta.dtype)
34 |
35 | x = theta.numpy()
36 | y_softmax = softmax(theta).numpy()
37 | y_cosface = cosface(theta, 0.35).numpy()
38 | y_arcface = arcface(theta, 0.5).numpy()
39 |
40 | y_multimargin_1 = multimargin(theta, 0.2, 0.3).numpy()
41 | y_multimargin_2 = multimargin(theta, 0.2, 0.4).numpy()
42 | y_multimargin_3 = multimargin(theta, 0.3, 0.2).numpy()
43 | y_multimargin_4 = multimargin(theta, 0.3, 0.3).numpy()
44 | y_multimargin_5 = multimargin(theta, 0.4, 0.2).numpy()
45 | y_multimargin_6 = multimargin(theta, 0.4, 0.3).numpy()
46 |
47 | plt.plot(x, y_softmax, x, y_cosface, x, y_arcface, x, y_multimargin_1, x, y_multimargin_2, x, y_multimargin_3, x, y_multimargin_4, x, y_multimargin_5, x, y_multimargin_6)
48 | plt.legend(['Softmax(0.00, 0.00)', 'CosFace(0.00, 0.35)', 'ArcFace(0.50, 0.00)', 'MultiMargin(0.20, 0.30)', 'MultiMargin(0.20, 0.40)', 'MultiMargin(0.30, 0.20)', 'MultiMargin(0.30, 0.30)', 'MultiMargin(0.40, 0.20)', 'MultiMargin(0.40, 0.30)'])
49 | plt.grid(False)
50 | plt.xlim((0, 3/4*math.pi))
51 | plt.ylim((-1.2, 1.2))
52 |
53 | plt.xticks(np.arange(0, 2.4, 0.3))
54 | plt.yticks(np.arange(-1.2, 1.2, 0.2))
55 | plt.xlabel('Angle between the Feature and the Target Center (Radian: 0 - 3/4 Pi)')
56 | plt.ylabel('Target Logit')
57 |
58 | plt.savefig('target logits')
--------------------------------------------------------------------------------
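For reference, the curves plotted by this script are the following target-logit functions of the angle θ between a feature and its class weight; the SphereFace curve uses the annealed form with λ = 20, exactly as written in the code above:

```latex
\psi_{\mathrm{softmax}}(\theta)     = \cos\theta \\
\psi_{\mathrm{SphereFace}}(\theta)  = \frac{\cos(m\theta) + \lambda\cos\theta}{1 + \lambda}, \quad m = 4,\ \lambda = 20 \\
\psi_{\mathrm{CosFace}}(\theta)     = \cos\theta - m \\
\psi_{\mathrm{ArcFace}}(\theta)     = \cos(\theta + m) \\
\psi_{\mathrm{MultiMargin}}(\theta) = \cos(\theta + m_1) - m_2
```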
/utils/plot_theta.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # encoding: utf-8
3 | '''
4 | @author: wujiyang
5 | @contact: wujiyang@hust.edu.cn
6 | @file: plot_theta.py
7 | @time: 2019/1/2 19:08
8 | @desc: plot theta distribution between weight and feature vector
9 | '''
10 |
11 | from matplotlib import pyplot as plt
12 | plt.switch_backend('agg')
13 |
14 | import argparse
15 | from backbone.mobilefacenet import MobileFaceNet
16 | from margin.ArcMarginProduct import ArcMarginProduct
17 | from torch.utils.data import DataLoader
18 | import torch
19 |
20 | from torchvision import transforms
21 | import torch.nn.functional as F
22 | import os
23 | import numpy as np
24 | from dataset.casia_webface import CASIAWebFace
25 |
26 |
27 | def get_train_loader(img_folder, filelist):
28 | print('Loading dataset...')
29 | transform = transforms.Compose([
30 | transforms.ToTensor(),
31 | transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
32 | ])
33 | trainset = CASIAWebFace(img_folder, filelist, transform=transform)
34 | trainloader = torch.utils.data.DataLoader(trainset, batch_size=100,
35 | shuffle=False, num_workers=8, drop_last=False)
36 | return trainloader
37 |
38 | def load_model(backbone_state_dict, margin_state_dict, device):
39 |
40 | # load model
41 | net = MobileFaceNet()
42 | net.load_state_dict(torch.load(backbone_state_dict)['net_state_dict'])
43 | margin = ArcMarginProduct(in_feature=128, out_feature=10575)
44 | margin.load_state_dict(torch.load(margin_state_dict)['net_state_dict'])
45 |
46 | net = net.to(device)
47 | margin = margin.to(device)
48 |
49 | return net.eval(), margin.eval()
50 |
51 |
52 | if __name__ == '__main__':
53 | parser = argparse.ArgumentParser(description='plot theta distribution of trained model')
54 | parser.add_argument('--img_root', type=str, default='/media/ramdisk/webface_align_112', help='train image root')
55 | parser.add_argument('--file_list', type=str, default='/media/ramdisk/webface_align_train.list', help='train list')
56 | parser.add_argument('--backbone_file', type=str, default='../model/Paper_MOBILEFACE_20190103_111830/Iter_088000_net.ckpt', help='backbone state dict file')
57 | parser.add_argument('--margin_file', type=str, default='../model/Paper_MOBILEFACE_20190103_111830/Iter_088000_margin.ckpt', help='backbone state dict file')
58 | parser.add_argument('--gpus', type=str, default='0', help='model prefix, single gpu only')
59 | args = parser.parse_args()
60 |
61 | # gpu init
62 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpus
63 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
64 |
65 | # load pretrain model
66 | trained_net, trained_margin = load_model(args.backbone_file, args.margin_file, device)
67 |
68 | # initial model
69 | initial_net = MobileFaceNet()
70 | initial_margin = ArcMarginProduct()
71 | initial_net = initial_net.to(device).eval()
72 | initial_margin = initial_margin.to(device).eval()
73 |
74 | # image dataloader
75 | image_loader = get_train_loader(args.img_root, args.file_list)
76 | theta_trained = []
77 | theta_initial = []
78 | for data in image_loader:
79 | img, label = data[0].to(device), data[1].to(device)
80 | # pretrained
81 | embedding = trained_net(img)
82 | cos_theta = F.linear(F.normalize(embedding), F.normalize(trained_margin.weight))
83 | cos_theta = cos_theta.clamp(-1, 1).detach().cpu().numpy()
84 | for i in range(img.shape[0]):
85 | cos_trget = cos_theta[i][label[i]]
86 | theta_trained.append(np.arccos(cos_trget) / np.pi * 180)
87 | # initial
88 | embedding = initial_net(img)
89 | cos_theta = F.linear(F.normalize(embedding), F.normalize(initial_margin.weight))
90 | cos_theta = cos_theta.clamp(-1, 1).detach().cpu().numpy()
91 | for i in range(img.shape[0]):
92 | cos_trget = cos_theta[i][label[i]]
93 | theta_initial.append(np.arccos(cos_trget) / np.pi * 180)
94 | '''
95 | # write theta list to txt file
96 | trained_theta_file = open('arcface_theta.txt', 'w')
97 | initial_theta_file = open('initial_theta.txt', 'w')
98 | for item in theta_trained:
99 | trained_theta_file.write(str(item))
100 | trained_theta_file.write('\n')
101 | for item in theta_initial:
102 | initial_theta_file.write(str(item))
103 | initial_theta_file.write('\n')
104 |
105 | # plot the theta, read theta from txt first
106 | theta_trained = []
107 | theta_initial = []
108 | trained_theta_file = open('arcface_theta.txt', 'r')
109 | initial_theta_file = open('initial_theta.txt', 'r')
110 | lines = trained_theta_file.readlines()
111 | for line in lines:
112 |         theta_trained.append(float(line.split('\n')[0]))
113 | lines = initial_theta_file.readlines()
114 | for line in lines:
115 | theta_initial.append(float(line.split('\n')[0]))
116 | '''
117 | print(len(theta_trained), len(theta_initial))
118 | plt.figure()
119 | plt.xlabel('Theta')
120 | plt.ylabel('Numbers')
121 | plt.title('Theta Distribution')
122 |     plt.hist(theta_trained, bins=180, density=False)
123 |     plt.hist(theta_initial, bins=180, density=False)
124 | plt.legend(['trained theta distribution', 'initial theta distribution'])
125 | plt.savefig('theta_distribution_hist.jpg')
126 |
--------------------------------------------------------------------------------
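In short, for every training sample the script above computes the angle, in degrees, between its embedding and the ArcMarginProduct weight vector of its ground-truth class, once with the trained model and once with a freshly initialized one, and then plots both angle histograms:

```latex
\theta_i = \frac{180}{\pi}\,
\arccos\!\left(
  \frac{x_i^{\top} W_{y_i}}{\lVert x_i \rVert_2 \, \lVert W_{y_i} \rVert_2}
\right)
```

where \(x_i\) is the embedding of sample \(i\) and \(W_{y_i}\) is the weight row of its label \(y_i\); the cosine is clamped to [-1, 1] before taking the arccos, as in the code.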
/utils/theta_distribution_hist.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wujiyang/Face_Pytorch/3afd941c01cae3eb73b66a48fa984f41bd6662fc/utils/theta_distribution_hist.jpg
--------------------------------------------------------------------------------
/utils/visualize.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # encoding: utf-8
3 | '''
4 | @author: wujiyang
5 | @contact: wujiyang@hust.edu.cn
6 | @file: visualize.py
7 | @time: 2019/1/7 16:07
8 | @desc: visualize tools
9 | '''
10 |
11 | import visdom
12 | import numpy as np
13 | import time
14 |
15 | class Visualizer():
16 | def __init__(self, env='default', **kwargs):
17 | self.vis = visdom.Visdom(env=env, **kwargs)
18 | self.index = 1
19 |
20 | def plot_curves(self, d, iters, title='loss', xlabel='iters', ylabel='accuracy'):
21 | name = list(d.keys())
22 | val = list(d.values())
23 | if len(val) == 1:
24 | y = np.array(val)
25 | else:
26 | y = np.array(val).reshape(-1, len(val))
27 | self.vis.line(Y=y,
28 | X=np.array([self.index]),
29 | win=title,
30 | opts=dict(legend=name, title = title, xlabel=xlabel, ylabel=ylabel),
31 | update=None if self.index == 0 else 'append')
32 | self.index = iters
33 |
34 |
35 | if __name__ == '__main__':
36 | vis = Visualizer(env='test')
37 | for i in range(10):
38 | x = i
39 | y = 2 * i
40 | z = 4 * i
41 | vis.plot_curves({'train': x, 'test': y}, iters=i, title='train')
42 | vis.plot_curves({'train': z, 'test': y, 'val': i}, iters=i, title='test')
43 | time.sleep(1)
--------------------------------------------------------------------------------