├── .gitignore ├── LICENSE ├── README.md ├── checkpoints └── .gitignore ├── data └── .gitignore ├── dataset.py ├── evaluation ├── __init__.py ├── compare.ipynb ├── result │ ├── eer_irr.png │ ├── eer_irr_SD.png │ ├── eer_irr_TJ.png │ ├── eer_irr_hd_SD.png │ ├── eer_irr_hd_TJ.png │ ├── hd_SD.png │ ├── heatmap.png │ ├── network.png │ ├── result.pkl │ └── result_tj.pkl ├── test-cx1.ipynb └── test-cx2.ipynb ├── log └── .gitignore ├── model ├── __init__.py ├── loss.py ├── pretrained │ ├── 1203_202301_MobileNetV2_Lite_CX1.pth │ └── 1211_202056_MobileNetV2_Lite_CX2.pth └── quality_model.py ├── test.py ├── train.py └── util ├── IrisQualityEvaluation.cpp ├── IrisQualityEvaluation.h ├── eye_quality_fact.py └── fmeasure.py /.gitignore: -------------------------------------------------------------------------------- 1 | # File created using '.gitignore Generator' for Visual Studio Code: https://bit.ly/vscode-gig 2 | 3 | # Created by https://www.gitignore.io/api/visualstudiocode,linux,jupyternotebooks,python 4 | # Edit at https://www.gitignore.io/?templates=visualstudiocode,linux,jupyternotebooks,python 5 | 6 | ### JupyterNotebooks ### 7 | # gitignore template for Jupyter Notebooks 8 | # website: http://jupyter.org/ 9 | 10 | .ipynb_checkpoints 11 | */.ipynb_checkpoints/* 12 | 13 | # IPython 14 | profile_default/ 15 | ipython_config.py 16 | 17 | # Remove previous ipynb_checkpoints 18 | # git rm -r .ipynb_checkpoints/ 19 | 20 | ### Linux ### 21 | *~ 22 | 23 | # temporary files which can be created if a process still has a handle open of a deleted file 24 | .fuse_hidden* 25 | 26 | # KDE directory preferences 27 | .directory 28 | 29 | # Linux trash folder which might appear on any partition or disk 30 | .Trash-* 31 | 32 | # .nfs files are created when an open file is removed but is still being accessed 33 | .nfs* 34 | 35 | ### Python ### 36 | # Byte-compiled / optimized / DLL files 37 | __pycache__/ 38 | *.py[cod] 39 | *$py.class 40 | 41 | # C extensions 42 | *.so 43 | 44 | # 
Distribution / packaging 45 | .Python 46 | build/ 47 | develop-eggs/ 48 | dist/ 49 | downloads/ 50 | eggs/ 51 | .eggs/ 52 | lib/ 53 | lib64/ 54 | parts/ 55 | sdist/ 56 | var/ 57 | wheels/ 58 | pip-wheel-metadata/ 59 | share/python-wheels/ 60 | *.egg-info/ 61 | .installed.cfg 62 | *.egg 63 | MANIFEST 64 | 65 | # PyInstaller 66 | # Usually these files are written by a python script from a template 67 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 68 | *.manifest 69 | *.spec 70 | 71 | # Installer logs 72 | pip-log.txt 73 | pip-delete-this-directory.txt 74 | 75 | # Unit test / coverage reports 76 | htmlcov/ 77 | .tox/ 78 | .nox/ 79 | .coverage 80 | .coverage.* 81 | .cache 82 | nosetests.xml 83 | coverage.xml 84 | *.cover 85 | .hypothesis/ 86 | .pytest_cache/ 87 | 88 | # Translations 89 | *.mo 90 | *.pot 91 | 92 | # Scrapy stuff: 93 | .scrapy 94 | 95 | # Sphinx documentation 96 | docs/_build/ 97 | 98 | # PyBuilder 99 | target/ 100 | 101 | # pyenv 102 | .python-version 103 | 104 | # pipenv 105 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 106 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 107 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 108 | # install all needed dependencies. 
109 | #Pipfile.lock 110 | 111 | # celery beat schedule file 112 | celerybeat-schedule 113 | 114 | # SageMath parsed files 115 | *.sage.py 116 | 117 | # Spyder project settings 118 | .spyderproject 119 | .spyproject 120 | 121 | # Rope project settings 122 | .ropeproject 123 | 124 | # Mr Developer 125 | .mr.developer.cfg 126 | .project 127 | .pydevproject 128 | 129 | # mkdocs documentation 130 | /site 131 | 132 | # mypy 133 | .mypy_cache/ 134 | .dmypy.json 135 | dmypy.json 136 | 137 | # Pyre type checker 138 | .pyre/ 139 | 140 | ### VisualStudioCode ### 141 | .vscode/* 142 | 143 | ### VisualStudioCode Patch ### 144 | # Ignore all local history of files 145 | .history 146 | 147 | # End of https://www.gitignore.io/api/visualstudiocode,linux,jupyternotebooks,python 148 | 149 | # Custom rules (everything added below won't be overriden by 'Generate .gitignore File' if you use 'Update' option) 150 | feature 151 | .idea/* -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. 
For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 
47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. 
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. 
We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Recognition Oriented Iris Image Quality Assessment in the Feature Space 2 | 3 | ![network](evaluation/result/network.png) 4 | 5 | [Paper](https://arxiv.org/abs/2009.00294) [Dataset](http://www.cripacsir.cn/dataset/casia-iris-degradation/) 6 | 7 | ## Prerequisites 8 | 9 | - pytorch 1.0 10 | - torchvision 0.2 11 | - opencv 3.4 12 | - scipy 13 | - thop 14 | 15 | ## Citing 16 | 17 | If DFSNet is useful for your research, please consider citing: 18 | 19 | ```bibtex 20 | @InProceedings{wang-ijcb2020, 21 | author = {Wang, Leyuan and Zhang, Kunbo and Ren, Min and Wang, Yunlong and Sun, Zhenan}, 22 | title = {Recognition Oriented Iris Image Quality Assessment in the Feature Space}, 23 | booktitle = {International Joint Conference on Biometrics 2020 (IJCB2020)}, 24 | year = {2020} 25 | } 26 | ``` 27 | -------------------------------------------------------------------------------- /checkpoints/.gitignore:
import os

import cv2
import numpy as np
import scipy.io as sio
import torch
from PIL import Image
from torch.utils import data
from torchvision.transforms import transforms


def _read_records(list_file):
    """Read a space-separated list file into a list of field tuples.

    Blank lines are skipped: a stray empty line would otherwise become the
    one-field record ``('',)`` and only fail later, when a record is
    unpacked or indexed.
    """
    with open(list_file, 'r') as f:
        return [tuple(line.strip().split(' ')) for line in f if line.strip()]


class BaseDataset(data.Dataset):
    """Image/score dataset driven by ``train.txt`` / ``test.txt`` list files.

    Each non-blank line of a list file is split on single spaces into a
    tuple of fields. ``__getitem__`` of this base class expects exactly two
    fields per record (image name, score); subclasses may use other layouts.

    Args:
        path: Dataset root containing the list files and an ``Image``
            directory.
        mode: ``'train'`` or ``'val'`` read ``train.txt`` and take the first
            80% / last 20% of the shuffled records; any other value reads
            ``test.txt`` unshuffled.
        debug_data: If true, use a small subset instead (first 10% for
            train and test, last 5% for val).
        size: Target size handed to ``transforms.Resize``.
            NOTE(review): torchvision reads a pair as (height, width); this
            default is the transpose of ``monoSimDataset``'s ``(480, 640)``
            -- confirm which one is intended.
        seed: Seed for NumPy's *global* RNG, so that the train/val split is
            reproducible. Note this reseeds ``np.random`` process-wide.
    """

    def __init__(self,
                 path,
                 mode='train',
                 debug_data=False,
                 size=(640, 480),
                 seed=3141):
        # Seeding the global RNG is kept on purpose: train and val instances
        # must shuffle identically for their slices to be disjoint.
        np.random.seed(seed)

        self.path = path
        self.mode = mode

        if self.mode in ('train', 'val'):
            img_list = _read_records(os.path.join(path, 'train.txt'))
            np.random.shuffle(img_list)
            total = len(img_list)
            if self.mode == 'train':
                stop = int((0.1 if debug_data else 0.8) * total)
                self.img_list = img_list[:stop]
            else:
                start = int((0.95 if debug_data else 0.8) * total)
                self.img_list = img_list[start:]
        else:
            img_list = _read_records(os.path.join(path, 'test.txt'))
            if debug_data:
                self.img_list = img_list[:int(0.1 * len(img_list))]
            else:
                self.img_list = img_list
        # Full (un-sliced) record list; subclasses use it for lookups.
        self.all_imglist = img_list

        self.transform = transforms.Compose([
            transforms.Resize(size),
            transforms.ToTensor(),
        ])
        # Single mean/std pair -- presumably for single-channel images;
        # TODO confirm against the data.
        self.Norm = transforms.Normalize(mean=[0.480],
                                         std=[0.200],
                                         inplace=False)

    def __len__(self):
        """Return the number of records selected for this mode."""
        return len(self.img_list)

    def __getitem__(self, item):
        """Return ``(image_tensor, score)`` for record ``item``.

        The score is a float tensor of shape ``(1,)``. The image is resized
        and converted to a tensor but not normalized here.
        """
        img_name, score = self.img_list[item]
        score = torch.tensor(float(score), dtype=torch.float).view(-1)

        img = Image.open(os.path.join(self.path, 'Image', img_name))
        img = self.transform(img)
        return img, score


class monoSimDataset(BaseDataset):
    """Dataset whose target is the similarity of an image to its gallery image.

    Records are expected to have three fields: image name, label, and the
    image's row/column index into the matrix stored under key ``'sim'`` in
    ``sim.mat``. ``gallery.txt`` lines are ``<image name> <label>`` and
    name one gallery image per label.

    Args:
        path, mode, debug_data, size, seed: See ``BaseDataset``.
        upsample: Accepted for interface compatibility; not used by this
            class.
    """

    def __init__(self,
                 path,
                 mode='train',
                 debug_data=False,
                 size=(480, 640),
                 seed=3141,
                 upsample=False):
        super().__init__(path, mode, debug_data, size, seed)

        # Rescale from [-1, 1] to [0, 1] (presumably cosine similarity --
        # TODO confirm).
        self.sim = sio.loadmat(os.path.join(path, 'sim.mat'))['sim']
        self.sim = (self.sim + 1) / 2

        # label -> name of that label's gallery image
        self.gallery_dict = {
            record[1]: record[0]
            for record in _read_records(os.path.join(path, 'gallery.txt'))
        }
        # image name -> index into the similarity matrix
        self.index = {record[0]: record[2] for record in self.all_imglist}

        # Masks are resized to a quarter of the image size per dimension.
        self.transmask = transforms.Compose([
            transforms.Resize((size[0] // 4, size[1] // 4)),
            transforms.ToTensor(),
        ])

    def __getitem__(self, item):
        """Return ``(img, mask, score)``, plus ``img_name`` in test mode.

        ``score`` is the rescaled similarity between this image and the
        gallery image of its label, as a float tensor of shape ``(1,)``.
        ``img`` is normalized and then repeated three times along dim 0.
        Outside test mode the mask is loaded from ``Mask/<stem>.png``; in
        test mode an all-ones tensor shaped like the pre-repeat image is
        returned instead.
        """
        img_name, img_label, x = self.img_list[item]
        y = self.index[self.gallery_dict[img_label]]
        score = self.sim[int(x), int(y)]
        score = torch.tensor(float(score), dtype=torch.float).view(-1)

        img = Image.open(os.path.join(self.path, 'Image', img_name))
        img = self.transform(img)
        img = self.Norm(img)
        if self.mode != 'test':
            # Mask shares the image's name up to the first '.'.
            mask_name = img_name.split('.')[0] + '.png'
            mask = Image.open(os.path.join(self.path, 'Mask', mask_name))
            mask = self.transmask(mask)
        else:
            mask = torch.ones_like(img)
        img = torch.cat((img, img, img))

        if self.mode == 'test':
            return img, mask, score, img_name
        return img, mask, score


if __name__ == '__main__':
    # Named `dataset`, not `data`, so the imported torch.utils.data module
    # is not shadowed.
    dataset = monoSimDataset(path='data/cx2', debug_data=False)
    print(len(dataset))
    for x, y, z in dataset:
        print(x.shape, y.shape, z)
-------------------------------------------------------------------------------- /evaluation/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Debatrix/DFSNet/4e9d18ed8fe7f92dd89f3ea968389108d96638cb/evaluation/__init__.py -------------------------------------------------------------------------------- /evaluation/compare.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 44, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import os\n", 10 | "import time\n", 11 | "\n", 12 | "import numpy as np\n", 13 | "import pickle\n", 14 | "import scipy.io as scio\n", 15 | "import torch\n", 16 | "import torchvision.transforms.functional as transforms\n", 17 | "from scipy import stats\n", 18 | "from torch import nn\n", 19 | "from torch.utils.data import DataLoader\n", 20 | "from tqdm import tqdm\n", 21 | "from thop import clever_format\n", 22 | "from thop import profile\n", 23 | "from sklearn import metrics\n", 24 | "from sklearn.preprocessing import normalize\n", 25 | "import matplotlib.pyplot as plt\n", 26 | "import matplotlib.mlab as mlab \n", 27 | "\n", 28 | "from dataset import monoSimDataset\n", 29 | "from model.loss import FocalLoss, CrossEntropy2d\n", 30 | "from model.quality_model import MobileNetV2_Lite_shower, MobileNetV3_wA, MobileNetV3_Lite, MobileNetV2_Lite" 31 | ] 32 | }, 33 | { 34 | "cell_type": "code", 35 | "execution_count": 142, 36 | "metadata": {}, 37 | "outputs": [], 38 | "source": [ 39 | "def cal_DET(sim, labels):\n", 40 | " label_num = len(labels)\n", 41 | " sim = sim[~np.eye(label_num, dtype=np.bool)]\n", 42 | " if labels.max() == 1 and labels.min() == 0 and labels.shape[0] != labels.shape[1]:\n", 43 | " label = np.dot(labels, labels.T)\n", 44 | " else:\n", 45 | " label = np.zeros((label_num, label_num))\n", 46 | " for x in range(label_num):\n", 47 | " for y 
in range(label_num):\n", 48 | " label[x, y] = labels[x] == labels[y]\n", 49 | " label = label[~np.eye(label_num, dtype=np.bool)].astype(np.bool)\n", 50 | "\n", 51 | " fpr, tpr, thresholds = metrics.roc_curve(label, sim)\n", 52 | " fnr = 1 - tpr\n", 53 | "\n", 54 | " eer = fpr[np.argmin(np.abs(fpr - fnr))]\n", 55 | " roc_auc = metrics.auc(fpr, tpr)\n", 56 | " return eer, fnr, fpr, roc_auc, thresholds" 57 | ] 58 | }, 59 | { 60 | "cell_type": "code", 61 | "execution_count": 67, 62 | "metadata": {}, 63 | "outputs": [], 64 | "source": [ 65 | "tj_dataset_path = 'data/cx2'\n", 66 | "sd_dataset_path = 'data/cx1'\n", 67 | "\n", 68 | "cp_path = cp_path = \"checkpoints/1203_202301_MobileNetV2_Lite/421_1.3395e-03.pth\"\n", 69 | "model_name = 'MobileNetV2_Lite'\n", 70 | "seed = 2248\n", 71 | "device = 'cuda:0'" 72 | ] 73 | }, 74 | { 75 | "cell_type": "code", 76 | "execution_count": 8, 77 | "metadata": {}, 78 | "outputs": [], 79 | "source": [ 80 | "with open(tj_dataset_path + '/test.txt','r') as f:\n", 81 | " tj_p_list = [x.strip().split(' ') for x in f.readlines()]\n", 82 | "tj_index = {x[0]:(x[1],x[2]) for x in tj_p_list}\n", 83 | "tj_sim = (scio.loadmat(tj_dataset_path + '/sim.mat')['sim'] +1)/2\n", 84 | "\n", 85 | "with open(sd_dataset_path + '/test.txt','r') as f:\n", 86 | " sd_p_list = [x.strip().split(' ') for x in f.readlines()]\n", 87 | "sd_index = {x[0]:(x[1],x[2]) for x in sd_p_list}\n", 88 | "sd_sim = (scio.loadmat(sd_dataset_path + '/sim.mat')['sim'] +1)/2" 89 | ] 90 | }, 91 | { 92 | "cell_type": "code", 93 | "execution_count": 16, 94 | "metadata": {}, 95 | "outputs": [ 96 | { 97 | "name": "stdout", 98 | "output_type": "stream", 99 | "text": [ 100 | "TJ All Test EER:28.59% FNMR:100.00%\n", 101 | "SD All Test EER:12.41% FNMR:100.00%\n" 102 | ] 103 | } 104 | ], 105 | "source": [ 106 | "tj_test_sim = np.zeros((len(tj_index), len(tj_index)))\n", 107 | "for tx , d in enumerate(tj_p_list):\n", 108 | " sx = int(d[2])\n", 109 | " for ty, d in enumerate(tj_p_list):\n", 110 | 
" sy = int(d[2])\n", 111 | " tj_test_sim[tx, ty] = tj_sim[sx,sy]\n", 112 | "tj_test_labels = np.array([int(x[1]) for x in tj_p_list])\n", 113 | "eer, fnr, fpr, roc_auc, thresholds = cal_DET(tj_test_sim, tj_test_labels)\n", 114 | "fnmr = fnr[np.argmin(np.abs(fpr - 1e-5))]\n", 115 | "print('TJ All Test EER:{:.2f}% FNMR:{:.2f}%'.format(eer*100, fnmr*100))\n", 116 | "\n", 117 | "sd_test_sim = np.zeros((len(sd_index), len(sd_index)))\n", 118 | "for tx , d in enumerate(sd_p_list):\n", 119 | " sx = int(d[2])\n", 120 | " for ty, d in enumerate(sd_p_list):\n", 121 | " sy = int(d[2])\n", 122 | " sd_test_sim[tx, ty] = sd_sim[sx,sy]\n", 123 | "sd_test_labels = np.array([int(x[1]) for x in sd_p_list])\n", 124 | "eer, fnr, fpr, roc_auc, thresholds = cal_DET(sd_test_sim, sd_test_labels)\n", 125 | "fnmr = fnr[np.argmin(np.abs(fpr - 1e-5))]\n", 126 | "print('SD All Test EER:{:.2f}% FNMR:{:.2f}%'.format(eer*100, fnmr*100))" 127 | ] 128 | }, 129 | { 130 | "cell_type": "code", 131 | "execution_count": 15, 132 | "metadata": {}, 133 | "outputs": [ 134 | { 135 | "name": "stdout", 136 | "output_type": "stream", 137 | "text": [ 138 | "SD All Test EER:1.28% FNMR:100.00%\n" 139 | ] 140 | } 141 | ], 142 | "source": [ 143 | "with open(sd_dataset_path + '/hq.txt','r') as f:\n", 144 | " sd_hq_set = set([x.split(' ')[0] for x in f.readlines()])\n", 145 | "sd_hq_names = [x[0] for x in sd_p_list if x[0] in sd_hq_set]\n", 146 | "\n", 147 | "sd_test_sim = np.zeros((len(sd_hq_names), len(sd_hq_names)))\n", 148 | "for tx , d in enumerate(sd_hq_names):\n", 149 | " sx = int(sd_index[d][1])\n", 150 | " for ty, d in enumerate(sd_hq_names):\n", 151 | " sy = int(sd_index[d][1])\n", 152 | " sd_test_sim[tx, ty] = sd_sim[sx,sy]\n", 153 | "sd_test_labels = np.array([int(sd_index[x][0]) for x in sd_hq_names])\n", 154 | "eer, fnr, fpr, roc_auc, thresholds = cal_DET(sd_test_sim, sd_test_labels)\n", 155 | "fnmr = fnr[np.argmin(np.abs(fpr - 1e-5))]\n", 156 | "print('SD All Test EER:{:.2f}% 
FNMR:{:.2f}%'.format(eer*100, fnmr*100))" 157 | ] 158 | }, 159 | { 160 | "cell_type": "code", 161 | "execution_count": 18, 162 | "metadata": {}, 163 | "outputs": [ 164 | { 165 | "name": "stdout", 166 | "output_type": "stream", 167 | "text": [ 168 | "92\n", 169 | "TJ All Test EER:16.46% FNMR:100.00%\n" 170 | ] 171 | } 172 | ], 173 | "source": [ 174 | "tj_hq_names = [x[0] for x in tj_p_list if x[0][8] == '1']\n", 175 | "print(len(tj_hq_names))\n", 176 | "\n", 177 | "tj_test_sim = np.zeros((len(tj_hq_names), len(tj_hq_names)))\n", 178 | "for tx , d in enumerate(tj_hq_names):\n", 179 | " sx = int(tj_index[d][1])\n", 180 | " for ty, d in enumerate(tj_hq_names):\n", 181 | " sy = int(tj_index[d][1])\n", 182 | " tj_test_sim[tx, ty] = tj_sim[sx,sy]\n", 183 | "tj_test_labels = np.array([int(tj_index[x][0]) for x in tj_hq_names])\n", 184 | "eer, fnr, fpr, roc_auc, thresholds = cal_DET(tj_test_sim, tj_test_labels)\n", 185 | "fnmr = fnr[np.argmin(np.abs(fpr - 1e-5))]\n", 186 | "print('TJ All Test EER:{:.2f}% FNMR:{:.2f}%'.format(eer*100, fnmr*100))" 187 | ] 188 | }, 189 | { 190 | "cell_type": "code", 191 | "execution_count": 154, 192 | "metadata": {}, 193 | "outputs": [ 194 | { 195 | "name": "stdout", 196 | "output_type": "stream", 197 | "text": [ 198 | "92\n" 199 | ] 200 | }, 201 | { 202 | "data": { 203 | "image/png": "\n", 204 | "text/plain": [ 205 | "
" 206 | ] 207 | }, 208 | "metadata": { 209 | "needs_background": "light" 210 | }, 211 | "output_type": "display_data" 212 | } 213 | ], 214 | "source": [ 215 | "plt.figure()\n", 216 | "tj_test_sim = np.zeros((len(tj_index), len(tj_index)))\n", 217 | "for tx , d in enumerate(tj_p_list):\n", 218 | " sx = int(d[2])\n", 219 | " for ty, d in enumerate(tj_p_list):\n", 220 | " sy = int(d[2])\n", 221 | " tj_test_sim[tx, ty] = tj_sim[sx,sy]\n", 222 | "tj_test_labels = np.array([int(x[1]) for x in tj_p_list])\n", 223 | "eer1, fnr1, fpr1, roc_auc, thresholds = cal_DET(tj_test_sim, tj_test_labels)\n", 224 | "\n", 225 | "\n", 226 | "tj_hq_names = [x[0] for x in tj_p_list if x[0][8] == '1']\n", 227 | "print(len(tj_hq_names))\n", 228 | "\n", 229 | "tj_test_sim = np.zeros((len(tj_hq_names), len(tj_hq_names)))\n", 230 | "for tx , d in enumerate(tj_hq_names):\n", 231 | " sx = int(tj_index[d][1])\n", 232 | " for ty, d in enumerate(tj_hq_names):\n", 233 | " sy = int(tj_index[d][1])\n", 234 | " tj_test_sim[tx, ty] = tj_sim[sx,sy]\n", 235 | "tj_test_labels = np.array([int(tj_index[x][0]) for x in tj_hq_names])\n", 236 | "eer2, fnr2, fpr2, roc_auc, thresholds = cal_DET(tj_test_sim, tj_test_labels)\n", 237 | "\n", 238 | "plt.plot(fpr1,fnr1,c='r',label='Non-ideal eer={:.2f}%'.format(eer1*100))\n", 239 | "plt.plot(fpr2,fnr2,c='g',label='Ideal eer={:.2f}%'.format(eer2*100))\n", 240 | "# plt.yscale('log')\n", 241 | "plt.xscale('log')\n", 242 | "plt.xlabel('FPR')\n", 243 | "plt.ylabel('FNR')\n", 244 | "plt.legend(loc = 'upper right')\n", 245 | "plt.plot(np.arange(0,1,0.01),np.arange(0,1,0.01),linestyle='-.',c='gray',label='eer')\n", 246 | "plt.show()" 247 | ] 248 | }, 249 | { 250 | "cell_type": "code", 251 | "execution_count": 155, 252 | "metadata": {}, 253 | "outputs": [ 254 | { 255 | "data": { 256 | "image/png": "\n", 257 | "text/plain": [ 258 | "
" 259 | ] 260 | }, 261 | "metadata": { 262 | "needs_background": "light" 263 | }, 264 | "output_type": "display_data" 265 | } 266 | ], 267 | "source": [ 268 | "sd_test_sim = np.zeros((len(sd_index), len(sd_index)))\n", 269 | "for tx , d in enumerate(sd_p_list):\n", 270 | " sx = int(d[2])\n", 271 | " for ty, d in enumerate(sd_p_list):\n", 272 | " sy = int(d[2])\n", 273 | " sd_test_sim[tx, ty] = sd_sim[sx,sy]\n", 274 | "sd_test_labels = np.array([int(x[1]) for x in sd_p_list])\n", 275 | "eer1, fnr1, fpr1, roc_auc, thresholds = cal_DET(sd_test_sim, sd_test_labels)\n", 276 | "\n", 277 | "sd_test_sim = np.zeros((len(sd_hq_names), len(sd_hq_names)))\n", 278 | "for tx , d in enumerate(sd_hq_names):\n", 279 | " sx = int(sd_index[d][1])\n", 280 | " for ty, d in enumerate(sd_hq_names):\n", 281 | " sy = int(sd_index[d][1])\n", 282 | " sd_test_sim[tx, ty] = sd_sim[sx,sy]\n", 283 | "sd_test_labels = np.array([int(sd_index[x][0]) for x in sd_hq_names])\n", 284 | "eer2, fnr2, fpr2, roc_auc, thresholds = cal_DET(sd_test_sim, sd_test_labels)\n", 285 | "\n", 286 | "plt.plot(fpr1,fnr1,c='r',label='Non-ideal eer={:.2f}%'.format(eer1*100))\n", 287 | "plt.plot(fpr2,fnr2,c='g',label='Ideal eer={:.2f}%'.format(eer2*100))\n", 288 | "plt.xscale('log')\n", 289 | "plt.xlabel('FPR')\n", 290 | "plt.ylabel('FNR')\n", 291 | "plt.legend(loc = 'upper right')\n", 292 | "plt.plot(np.arange(0,1,0.01),np.arange(0,1,0.01),linestyle='-.',c='gray',label='eer')\n", 293 | "plt.show()" 294 | ] 295 | }, 296 | { 297 | "cell_type": "code", 298 | "execution_count": 19, 299 | "metadata": {}, 300 | "outputs": [], 301 | "source": [ 302 | "with open('tj_train_quality.txt', 'r') as f:\n", 303 | " qdata = [x.strip() for x in f.readlines()]\n", 304 | "qdict = {x.split(',')[0]:[float(y) for y in x.split(',')[1].split(' ')[1:]] for x in qdata}\n", 305 | "\n", 306 | "qarray = np.zeros((len(qdict),5))\n", 307 | "for idx,n in enumerate(qdict.keys()):\n", 308 | " qarray[idx,:] = np.array(qdict[n])\n", 309 | "\n", 310 | 
"tj_fm = qarray[:, 0]\n", 311 | "tj_size = qarray[:, 1]\n", 312 | "tj_dilation = qarray[:, 2]\n", 313 | "tj_gls = qarray[:, 3]\n", 314 | "tj_uar = qarray[:, 4]\n", 315 | "\n", 316 | "with open('sd_train_quality.txt', 'r') as f:\n", 317 | " qdata = [x.strip() for x in f.readlines()]\n", 318 | "qdict = {x.split(',')[0]:[float(y) for y in x.split(',')[1].split(' ')[1:]] for x in qdata}\n", 319 | "\n", 320 | "qarray = np.zeros((len(qdict),5))\n", 321 | "for idx,n in enumerate(qdict.keys()):\n", 322 | " qarray[idx,:] = np.array(qdict[n])\n", 323 | "\n", 324 | "sd_fm = qarray[:, 0]\n", 325 | "sd_size = qarray[:, 1]\n", 326 | "sd_dilation = qarray[:, 2]\n", 327 | "sd_gls = qarray[:, 3]\n", 328 | "sd_uar = qarray[:, 4]" 329 | ] 330 | }, 331 | { 332 | "cell_type": "code", 333 | "execution_count": null, 334 | "metadata": {}, 335 | "outputs": [], 336 | "source": [] 337 | }, 338 | { 339 | "cell_type": "code", 340 | "execution_count": 36, 341 | "metadata": {}, 342 | "outputs": [ 343 | { 344 | "data": { 345 | "text/plain": [ 346 | "" 347 | ] 348 | }, 349 | "execution_count": 36, 350 | "metadata": {}, 351 | "output_type": "execute_result" 352 | }, 353 | { 354 | "data": { 355 | "image/png": "\n", 356 | "text/plain": [ 357 | "
" 358 | ] 359 | }, 360 | "metadata": { 361 | "needs_background": "light" 362 | }, 363 | "output_type": "display_data" 364 | } 365 | ], 366 | "source": [ 367 | "sd_mu = np.mean(sd_fm)\n", 368 | "sd_sigma = np.std(sd_fm)\n", 369 | "tj_mu = np.mean(tj_fm)\n", 370 | "tj_sigma = np.std(tj_fm)\n", 371 | "\n", 372 | "tj_count, tj_bins, _ = plt.hist(tj_fm,bins=20, edgecolor='k', density=True,alpha=0.95, label='distant')\n", 373 | "sd_count, sd_bins, _ = plt.hist(sd_fm,bins=20, edgecolor='k', density=True,alpha=0.75, label='close')\n", 374 | "\n", 375 | "plt.xlabel('Sharpness', fontsize=15)\n", 376 | "plt.legend(fontsize=15)\n", 377 | "\n", 378 | "# plt.plot(tj_bins, 1./(np.sqrt(2*np.pi)*tj_sigma)*np.exp(-(tj_bins-tj_mu)**2/(2*tj_sigma**2)), lw=2, c='b')\n", 379 | "# plt.plot(sd_bins, 1./(np.sqrt(2*np.pi)*sd_sigma)*np.exp(-(sd_bins-tj_mu)**2/(2*sd_sigma**2)), lw=2, c='r')" 380 | ] 381 | }, 382 | { 383 | "cell_type": "code", 384 | "execution_count": 37, 385 | "metadata": {}, 386 | "outputs": [ 387 | { 388 | "data": { 389 | "text/plain": [ 390 | "" 391 | ] 392 | }, 393 | "execution_count": 37, 394 | "metadata": {}, 395 | "output_type": "execute_result" 396 | }, 397 | { 398 | "data": { 399 | "image/png": "\n", 400 | "text/plain": [ 401 | "
" 402 | ] 403 | }, 404 | "metadata": { 405 | "needs_background": "light" 406 | }, 407 | "output_type": "display_data" 408 | } 409 | ], 410 | "source": [ 411 | "tj_count, tj_bins, _ = plt.hist(tj_size,bins=20, edgecolor='k', density=True,alpha=0.95, label='distant')\n", 412 | "sd_count, sd_bins, _ = plt.hist(sd_size,bins=20, edgecolor='k', density=True,alpha=0.75, label='close')\n", 413 | "\n", 414 | "plt.xlabel('Iris size', fontsize=15)\n", 415 | "plt.legend(fontsize=15)" 416 | ] 417 | }, 418 | { 419 | "cell_type": "code", 420 | "execution_count": 38, 421 | "metadata": {}, 422 | "outputs": [ 423 | { 424 | "data": { 425 | "text/plain": [ 426 | "" 427 | ] 428 | }, 429 | "execution_count": 38, 430 | "metadata": {}, 431 | "output_type": "execute_result" 432 | }, 433 | { 434 | "data": { 435 | "image/png": "\n", 436 | "text/plain": [ 437 | "
" 438 | ] 439 | }, 440 | "metadata": { 441 | "needs_background": "light" 442 | }, 443 | "output_type": "display_data" 444 | } 445 | ], 446 | "source": [ 447 | "tj_count, tj_bins, _ = plt.hist(tj_dilation,bins=20, edgecolor='k', density=True,alpha=0.95, label='distant')\n", 448 | "sd_count, sd_bins, _ = plt.hist(sd_dilation,bins=20, edgecolor='k', density=True,alpha=0.75, label='close')\n", 449 | "\n", 450 | "plt.xlabel('Dilation', fontsize=15)\n", 451 | "plt.legend(fontsize=15)" 452 | ] 453 | }, 454 | { 455 | "cell_type": "code", 456 | "execution_count": 39, 457 | "metadata": {}, 458 | "outputs": [ 459 | { 460 | "data": { 461 | "text/plain": [ 462 | "" 463 | ] 464 | }, 465 | "execution_count": 39, 466 | "metadata": {}, 467 | "output_type": "execute_result" 468 | }, 469 | { 470 | "data": { 471 | "image/png": "\n", 472 | "text/plain": [ 473 | "
" 474 | ] 475 | }, 476 | "metadata": { 477 | "needs_background": "light" 478 | }, 479 | "output_type": "display_data" 480 | } 481 | ], 482 | "source": [ 483 | "tj_count, tj_bins, _ = plt.hist(tj_gls,bins=20, edgecolor='k', density=True,alpha=0.95, label='distant')\n", 484 | "sd_count, sd_bins, _ = plt.hist(sd_gls,bins=20, edgecolor='k', density=True,alpha=0.75, label='close')\n", 485 | "\n", 486 | "plt.xlabel('GLS', fontsize=15)\n", 487 | "plt.legend(fontsize=15)" 488 | ] 489 | }, 490 | { 491 | "cell_type": "code", 492 | "execution_count": 40, 493 | "metadata": {}, 494 | "outputs": [ 495 | { 496 | "data": { 497 | "text/plain": [ 498 | "" 499 | ] 500 | }, 501 | "execution_count": 40, 502 | "metadata": {}, 503 | "output_type": "execute_result" 504 | }, 505 | { 506 | "data": { 507 | "image/png": "\n", 508 | "text/plain": [ 509 | "
" 510 | ] 511 | }, 512 | "metadata": { 513 | "needs_background": "light" 514 | }, 515 | "output_type": "display_data" 516 | } 517 | ], 518 | "source": [ 519 | "tj_count, tj_bins, _ = plt.hist(tj_uar,bins=20, edgecolor='k', density=True,alpha=0.95, label='distant')\n", 520 | "sd_count, sd_bins, _ = plt.hist(sd_uar,bins=20, edgecolor='k', density=True,alpha=0.75, label='close')\n", 521 | "\n", 522 | "plt.xlabel('Usable area', fontsize=15)\n", 523 | "plt.legend(fontsize=15)" 524 | ] 525 | }, 526 | { 527 | "cell_type": "code", 528 | "execution_count": 68, 529 | "metadata": {}, 530 | "outputs": [ 531 | { 532 | "name": "stdout", 533 | "output_type": "stream", 534 | "text": [ 535 | "info=v2 multi_linear sim 0~1 adam l1loss\n", 536 | "dataset_path=/home/dl/wangleyuan/dataset/cx1\n", 537 | "cp_path=\n", 538 | "cp_num=5\n", 539 | "visible=True\n", 540 | "model=MobileNetV2_Lite\n", 541 | "seed=2248\n", 542 | "debug=False\n", 543 | "mask_learn_rate=0.003125\n", 544 | "mask_lr_decay=0.5\n", 545 | "upout=True\n", 546 | "batch_size=24\n", 547 | "device=cuda:2\n", 548 | "num_workers=2\n", 549 | "max_epochs=500\n", 550 | "lr=0.0004\n", 551 | "momentum=0.9\n", 552 | "weight_decay=0.0005\n", 553 | "name=1203_202301_MobileNetV2_Lite\n", 554 | "\n" 555 | ] 556 | } 557 | ], 558 | "source": [ 559 | "test_data = monoSimDataset(path=sd_dataset_path, mode='test', seed=seed, debug_data=False,upsample=True)\n", 560 | "\n", 561 | "model = MobileNetV2_Lite_shower(True, True, 0.5)\n", 562 | "assert model is not None\n", 563 | "model.to('cpu')\n", 564 | "assert cp_path is not ''\n", 565 | "cp_data = torch.load(cp_path, map_location=device)\n", 566 | "try:\n", 567 | " model.load_state_dict(cp_data['model'])\n", 568 | "except Exception as e:\n", 569 | " model.load_state_dict(cp_data['model'], strict=False)\n", 570 | " print(e)\n", 571 | "\n", 572 | "cp_data['cfg'] = '' if 'cfg' not in cp_data else cp_data['cfg']\n", 573 | "print(cp_data['cfg'])" 574 | ] 575 | }, 576 | { 577 | "cell_type": "code", 
578 | "execution_count": 69, 579 | "metadata": {}, 580 | "outputs": [ 581 | { 582 | "name": "stdout", 583 | "output_type": "stream", 584 | "text": [ 585 | "torch.Size([1, 3, 480, 640])\n", 586 | "0002_2_1_2_22_005.bmp tensor([0.7646])\n" 587 | ] 588 | } 589 | ], 590 | "source": [ 591 | "img, mask, target, name = test_data[15]\n", 592 | "img_tensor = torch.unsqueeze(img, 0).to('cpu')\n", 593 | "mask_tensor = torch.unsqueeze(mask, 0)\n", 594 | "print(img_tensor.shape)\n", 595 | "print(name,target)\n", 596 | "\n", 597 | "# img = transforms.to_pil_image(img)\n", 598 | "# mask = transforms.to_pil_image(mask*255)\n", 599 | "# plt.imshow(mask, cmap ='gray')\n", 600 | "# mask_tensor.max()" 601 | ] 602 | }, 603 | { 604 | "cell_type": "code", 605 | "execution_count": 125, 606 | "metadata": {}, 607 | "outputs": [ 608 | { 609 | "name": "stdout", 610 | "output_type": "stream", 611 | "text": [ 612 | "torch.Size([1, 1]) torch.Size([1, 2, 120, 160]) torch.Size([1, 32, 60, 80]) torch.Size([1, 1280, 15, 20])\n" 613 | ] 614 | } 615 | ], 616 | "source": [ 617 | "pred, heatmap, short, feat = model(img_tensor)\n", 618 | "model.eval()\n", 619 | "print(pred.shape, heatmap.shape, short.shape, feat.shape)" 620 | ] 621 | }, 622 | { 623 | "cell_type": "code", 624 | "execution_count": 127, 625 | "metadata": {}, 626 | "outputs": [ 627 | { 628 | "data": { 629 | "text/plain": [ 630 | "(-0.5, 159.5, 119.5, -0.5)" 631 | ] 632 | }, 633 | "execution_count": 127, 634 | "metadata": {}, 635 | "output_type": "execute_result" 636 | }, 637 | { 638 | "data": { 639 | "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAATAAAADnCAYAAACZtwrQAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAATRElEQVR4nO3dWXMTVxeF4aXBNtiGEBNMDElRKQoqY+Ui//9P5CIpSCVABRKHDwizp0j6LlyrdXTU3ZJsTVt+nxtZky217a3Vu8853ej1egKAiJqLfgEAcFYUMABhUcAAhEUBAxAWBQxAWO26O5vNJocoASxUt9ttVN1HAgMQFgUMQFgUMABhUcAAhEUBAxAWBQxAWBQwAGFRwACERQEDEBYFDEBYFDAAYVHAAIRFAQMQFgUMQFgUMABhUcAAhEUBAxAWBQxAWBQwAGFRwACERQEDEBYFDEBYFDAAYVHAAIRFAQMQFgUMQFgUMABhUcAAhEUBAxAWBQxAWBQwAGFRwACERQEDEBYFDEBYFDAAYVHAAIRFAQMQFgUMQFgUMABhUcAAhEUBAxAWBQxAWBQwAGFRwACERQEDEBYFDEBYFDAAYVHAAIRFAQMQFgUMQFgUMABhUcAAhEUBAxAWBQxAWBQwAGFRwACERQEDEBYFDEBYFDAAYVHAAIRFAQMQFgUMQFgUMABhUcAAhEUBAxAWBQxAWBQwAGFRwACERQEDEBYFDEBYFDAAYVHAAITVXvQLAOap1WoNXO90Ogt6JZgGChguhLxw1d1OUYuDXUgAYZHAsPKq0tekjyeZLR8SGICwSGBYWWVJqtks/8zudrtn+n5VSGvzQQIDEBYJDBdCnryqkthZ5QmOo5vzQQIDEFaj1+tV3tlsNqvvBJZUnn7StOWvG43GxN+33e7vsPz333+SpPz/J01io/pqJLLxdLvdyl8WCQxAWPTALoBJx0FJq5cOnLzW1tYGrqeqUlnZY71NvZ2cyMoeW5XEWq3Wym3neaOArZhJilX6z1bVhI78D+b312w2i+Lk2/z+0u3lr/1YX/dz0l3I4+NjSdLR0ZEk6eTkRNJgQcuLWVkhY27m+bALCSAsEtgSOssuX52qIQN1QwmcFpZtNydvmo/bjHd68rbd2NgYuN5ut4d2M/2c9fX1gdubzabevXsnSfrw4YMk6fDwUFI/mUnVu5V1zf1VSL7zRAIDEBYJbMHmkbby29LU4pRRNyxgkamgbpjPuBqNxlDyunz5sqR+qlpfXy/u86WT16VLl4ae8/z5c0nS69evJUlv3rwpflbO/THz74Mkdn4kMABhkcAWZFrJa5z+llNBehTNj8kTQ1nfJu2HSfNJBWdJXlUDWNvtdpG0nKK2trYk9dPW1tZWcd/m5qYkaXt7W5J05cqVgct2u108Nu+lvXr1qnj93m75NvXtzWZz5GDXs27zi7IkEAkMQFgksDmrS17TmGCcTpVx4sqnz/j6+vp6cZvHM5lTwyKct+9VNnHb28LJ6erVq5L6SWx7e1uffPLJwH2+fu3atYHrab8sT7Xebt1ut0g7vsx7YelrHTeJnVf6fVYhjVHAlki6a1Gl6r58NzHdPcyHBaTNat/mYQEHBweS+sMC0mIyzppZk5hGg75OOiA13VWU+kUqLVZ5ofJ9vt3X2+120bx/8eKFpH7DP90lzwtE2fYbd5uOW2zGKXSrULiMXUgAYZHAAihryFvVbuLa2trQcABfejdqe3u7SCT7+/uS+k3o9FPauz7TSGCzTl2pdPCqt4Ub8zs7O5IG05W/zpNY3sRvNBrFY3ybv7/TbirfHZ9kO06allYpXY2DBAYgLBJYIHWN+bI+lwdgOnU4baXNaScIyxvO5/1En3Xi6vV6I6cTpcMenJw+++wzSf1tsbOzU5m43Pvy9my1WkMHAZxuy4afeBuMk7wuWoI6LxIYgLBIYEssT1nphGMfbcr7W+nRNieINGVI0vXr14vrTiROXB5
O4aOQR0dHQ8MvxkkJ8+x1VUknZzspeRs4gaV9L6cqbzc/xwNbneKk/jZ1uvX29++l2+2OtQ1IXOdDAgMQFgkskGazWZm88ukvV69e1Y0bNyT108bu7q6kfgq5efNmkSQ+fvwoqT8ezOOc3r9/P9QPq5veskzJK02uTlP5tkj7XnnycuLypbd1+hhvb2//cVZkJXVNDwkMQFgksDk5z1QQH1lsNBpDS7z4Mh/DtLu7W6SNW7duSVKRyNIk5uTg5PXy5UtJ0j///CPpdJkYp7No0ilTfp+ffvrpwKUT2ObmZpGmqpJXOsbL273uKGSO5DV9S1fAVmmN8PPOX8ub5+mcPv8DufHsf8ibN29Kkvb29rS3tydJ+vzzzyUN70JeuXKl+D7v37+XJH355ZeSpCdPnkg6LWT52vBVwwGWYfcx5dd96dKloeET6TaQTotVvu5XPt8xnfeYr2rh59adZwDTxy4kgLAWnsBGpZRlW5O9zLRXVc05ia2vrxe7KvmUGCevL774QtLpbuPt27cHHuPhE2nq8Pfz7qWf79T26NGj4jHRdiXTAx5+z94W3nVMU1e+SmuevHx7t9sdWtnVu6hlU4lIYrNDAgMQ1sITWGSzTl65RqNRecYc92J8mQ4LcNrIhwesr68X3y8fZuDG/+7ubrFkjM/A40SRJ+N0Ss8i+mH5WYWciq5fv14kyzyFphPeq5JX2UqvVWcaSie+L3JNtYuCBAYgLBLYOeQDO+fBn/T5Zd2SN2VHM6XT152nFh+t89HIO3fuFAnMix2mq46m0kRWNcF62sms7IxDTpw+Cvv111/r7t27kvpHYp3A0vefb5+q3lWn0ynO/+hU6rMS+fZl79uuChIYgLAWnsAWkWKmbdbvwaml1+sNLXfjNFSWwPKJ33nCSPs47qV5XJl7YPfu3SsWOcyXmXbq8PWys+zU9cmmodVqDfXxfAT1/v37kqRvv/1Wd+7ckdRPZ05rk5yHwO+l2+0WadQJzNc9Gb7X6y3duLhVtPACZhSyyfifI99lKStgdYUrL3Ju8LvZ/dVXXxWFKj/xh4uR/3k7nc7QCXLLdsdG7V6Ns/3S153vMt67d0+S9OOPP0qS7t69WxycyIdK1L2mfDfdjo+Pi/fsGQwuZN5G7ELOB7uQAMJamgS2SvJP32klsk6nMzSEoeoylaes9PY8lXlX0s38W7duDSWvfLiBE9rBwUHx2HwFCyezTqdTeoLXqtdsZWddkk6b8d5ldPL66aefJEnffPONpNPBufmwibJTr1UNDyk7YOJBvfku5LRWssV4SGAAwlq6BLaKn1ydTmdqKaxqGEXZ0Ia64RPp9fRrpyv3itw7Sh/j6TdONT6j0Zs3b4p1xKqa2+kQhLJ1u3x71Vr/+fW9vb2iWf/DDz9Ikr777jtJKhr3165dK95PfiLaOlVDVY6Pj4sE5vTpyfBpL7Iq0WF6SGAAwlq6BFZnGilmVT4N8+EUZb2XfCpM2VFI821+Trr+e36fk5cHhf7111+STpPY8+fPJUn/+9//JElv376VNHi0Lu+TlZ1RKT+nZb5Gl6cJ3b9/X99//70kFYNVPW3Ifbytra2x1uuyqgGsTrkHBwfF+3EC8xATvyeGUMwHCQxAWKES2DSkKW6eaew8Y8TSXkrVuQbz8VeputRRdV+axPIjlE5g7o952tHz58+Lflh+6WT24sWLIr04ieULCW5vbxfLBVWdLdsLON69e7f4+fk0oXTSelX6LEtbVY/xNv7w4UPxHtzjGzXNCrNBAgMQ1oVLYKm6s+ssg7I0lY+4r7qUysc6lV2WPcbSc1GmPSqpP+3IiwTu7u4WU5CcuNwf8+WzZ8+K/piPVPr7OGXt7OwUMwGc8py48iS2u7tbfF21FHTZ8jdV71eqnqTuxPjhw4ei9+Uen+9zAmMpnfm40AVsEYWr6mfW7Vr6n6HZbA4NKs2/X7oeWH6SiryRPc4/b9lr9GPy9flT/t7piXal08Lj4pYXsPREGy5
KVSebTdc2KxsWMq608Hs7p4VKUjEX1EX48ePHevDggSTp6dOnkoab+b1ej93IOWAXEkBYoRLYtCZLL+MuYzrYNf/kdrI4Pj4uUk/ezDenrStXrgyctFXqJyUPW1hbWxsa2Jn/7HGa3GXrwOfDMdLGv6f+eDCok2KarspWj02vp8MsRg1OLdutLhsI7OTlBv2///4rSfrzzz8lSb///rsk6ddffy0SmE8/5/dSNqkes0MCAxBW7UfXsje5z2KZ30ueMMuSmAdKWrqUjdQfQLq1tVWknrx/lDe5R8lfh19f2eRw35YnsLQH5sZ8PpDVCXFjY6PyjEBlt1dNVi+bLF6VME9OToqeV568nLZ++eUXSadJ7O+//5Y0OJE9fU+YDxIYgLDG+giu6zkt8kjeJL2wZU5eo5T1U6o+6Z0I9vf3i+k2HnqQ98LSIRI2yTbNJ4CnS+Xkvbp0Ani6tE76M9Pn1i0BlF9OMj0o74H5tRweHhbbzunqt99+k9RPXk5i+/v7QwNZPUE97a1F/puLggQGIKxzH4Us+8Tmk+d8qhZErDuy5SNoXobm9evXxRGyR48eSdJQT+zy5ctDvaV0zJkvR6Wg/Pb0tfq9+ChiusxMbpyBtbn0tVU9tuznlY358vi0x48fS+onLiexZ8+eSTrdtvk5CTj6uBgzGUaR/8HPqqCNsyu5CsU0fQ+jiln6D+nToeWnG/Pl5ubm0KoPeUFLf+Y4hSu/rexAwbz+yeua+L70LuCrV6+KgaouXB424du91tnBwcHQXNRRJzPBbLALCSCsuQxkXcXhGIsy7jY8PDws5ul5KMOTJ08kDSaxfBBpug6YdJqg/DPHSV5V6h57lkRW95y6wbj+Op8utL+/XySuP/74Q9LgLqM0OGSFxLUcSGAAwprrVKJZJbGyNecv2idi/n5PTk4G+jtSfyhDusaWk5cHl1ZNY6pz3seMs0bXOKqGSqRDGvLelwetPn36tEhenqCdn9A3XwUXi0cCAxDWQiZz0xObvW63W/R5nDZevnwpqd8LS4dRuBfkVU29tE1ZGvJzxjm/Y/p6xn3MJPelwzLyoRvpYNW89+Xk5W3x4MGD4msnr3x6EEMklg8JDEBYC11Op9VqkcKmLN2e6TI8Un+ZGE+VkfrpwpPAb9++LUnFkjfpiqdla81Lg+PERq18Wre0jZUdNcyvpynLg0nzwaVpAnUfy6nKfa6ff/5ZkvTw4UOWxgmIBAYgrIUvaDiNfti0znq9apwcnMDM0416vV5xn6fR+CxCTmK3bt0q0li+Pn2+PE+awKrOBp7elr9OS8dZ5Wf58d9JmrKcqpyy3M/z5bt374bO4+ilch4+fCjpNJGN6n2xt7B8Fl7AjMb+9OWFIS9k6SqkHqzp5nZa0HyiDl+60e9L71K22+3K9bvSlSuqCli6C5jvDvq1+/W6WH38+HGoOPnS7+nNmzdDt/mAht/n27dvi13HfGUJLC92IQGE1ag7Bfra2tqZzo9e1ZydRNkE5rM+/6IatYrE2tpa8RgPck3X7ZJOB7veuHFDknTz5k1JGkpkfuzGxsbAqqrpZTpZ3F+b05ZT1eHh4dAQEF96t/D9+/eSTpOT05SnTjltOZm9e/eueF7+/fx9Tk5OhnYZ2XVcDt1ut1F1HwkMQFi1CazZbPakyRPQtKaGnAefmn1Vv7+06e7GvhNTmsjSNCapOOns3t6eJA2cWNbPy9fC9/XLly8PnZ/Syctp6OPHjwMNeGm4Me/L169fF4krvy9dLdWJzikrPxiQnseR5LVcSGAAVtJYCczGSWLTXjblrPj0HK1saENZIsv7Y05VOzs7A9fX19cHklZ6n5Pd5uZmcZt/Vp620p6VL/OBqGkPy/c5ZfmybLDrqIUI0+dhOZDAAKykiRKYNDqFTXKGmFniU/RsypaPdhpz76ruLNke/1X1mEuXLhW35QnMPbCjo6NiTFZ+3su8l3V0dDQ05ads8jULEMZVl8AWUsCSFzb2YyfFH+h
0tFqtyt3LfC5kq9UqBq7mJ7hNC1s+jMIDR9PBqvlAVl/PdwFPTk5GzpdEbOxCAlhJC5lKNIvkxaft7FSt0urE5JTV6XSKpJTPgUyTWX6mIqes9BRlVYmrqhlf9jqx+khgAMKaaw/sLJ+QJKvFGzUlyT0xafg8kPljyk6U699x2fpbeeIyGvQXBz0wACtp4gRmdUnsPFOJ+PRcPnVTkequS4PpTCo/U7eVpa26tfBz/O2sJhIYgJU0k6OQ9LpWS/67cSI7y++5LvGXfb9xfgZ/OxfXwldk5Y8vHv/O8l3LcYrNOCf1GOdnAxK7kAACO3MT385zQg0+TS+eur8X/h5QhiY+gJW0kB4Yn7QXF797TBMJDEBY505gVUek6h4LANNAAgMQFgUMQFgUMABhTe0oZF0vjN4XgFmYegJLi1Wn06F4AZgZdiEBhDWTgaykLgDzQAIDEBYFDEBYFDAAYVHAAIRFAQMQFgUMQFgUMABhUcAAhEUBAxAWBQxAWBQwAGFRwACERQEDEBYFDEBYFDAAYVHAAIRFAQMQFgUMQFgUMABhzaSAtVqt0tOrAcA0kcAAhDXTAkYSAzBLJDAAYU29gJG4AMwLCQxAWBQwAGFNvYB1Op2Br9PrADBNJDAAYbVn8U1JXQDmgQQGICwKGICwKGAAwqKAAQiLAgYgLAoYgLAoYADCooABCIsCBiAsChiAsChgAMKigAEIq9Hr9Rb9GgDgTEhgAMKigAEIiwIGICwKGICwKGAAwqKAAQjr/6WT5uRoRce4AAAAAElFTkSuQmCC\n", 640 | "text/plain": [ 641 | "
" 642 | ] 643 | }, 644 | "metadata": { 645 | "needs_background": "light" 646 | }, 647 | "output_type": "display_data" 648 | } 649 | ], 650 | "source": [ 651 | "heatmap = torch.softmax(heatmap, 1)\n", 652 | "heatmap_img = transforms.to_pil_image(heatmap[0,0,:,:])\n", 653 | "plt.imshow(heatmap_img, cmap='gray_r')\n", 654 | "plt.axis('off')" 655 | ] 656 | }, 657 | { 658 | "cell_type": "code", 659 | "execution_count": 133, 660 | "metadata": {}, 661 | "outputs": [ 662 | { 663 | "data": { 664 | "text/plain": [ 665 | "(-0.5, 79.5, 59.5, -0.5)" 666 | ] 667 | }, 668 | "execution_count": 133, 669 | "metadata": {}, 670 | "output_type": "execute_result" 671 | }, 672 | { 673 | "data": { 674 | "image/png": "\n", 675 | "text/plain": [ 676 | "
" 677 | ] 678 | }, 679 | "metadata": { 680 | "needs_background": "light" 681 | }, 682 | "output_type": "display_data" 683 | } 684 | ], 685 | "source": [ 686 | "idx = 5\n", 687 | "short1 = transforms.to_pil_image(short[0,idx:idx+3,:,:])\n", 688 | "plt.imshow(short1)\n", 689 | "plt.axis('off')" 690 | ] 691 | }, 692 | { 693 | "cell_type": "code", 694 | "execution_count": 138, 695 | "metadata": {}, 696 | "outputs": [ 697 | { 698 | "name": "stdout", 699 | "output_type": "stream", 700 | "text": [ 701 | "tensor([965, 846, 533])\n" 702 | ] 703 | }, 704 | { 705 | "data": { 706 | "text/plain": [ 707 | "(-0.5, 19.5, 14.5, -0.5)" 708 | ] 709 | }, 710 | "execution_count": 138, 711 | "metadata": {}, 712 | "output_type": "execute_result" 713 | }, 714 | { 715 | "data": { 716 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAATAAAADnCAYAAACZtwrQAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAEQ0lEQVR4nO3cP6vWZRjA8fPPTFxCo8HhOOSfJpNoCqFJBKeWlpRWEXwJ7r6BIHAUmnwBDg1Bg0vQ0hKJvgD1CEpHNHrO4y6pcHny8MXPZ/1xcd/T97nhgWt1uVyuABSt7fUFAKYEDMgSMCBLwIAsAQOyNl738ezat/6iBPbUzzs3V1/1zQsMyBIwIEvAgCwBA7IEDMgSMCBLwIAsAQOyBAzIEjAgS8CALAEDsgQMyBIwIEvAgCwBA7IEDMgSMCBLwIAsAQOyBAzIEjAgS8CALAEDsgQMyBIwIEvAgCwBA7IEDMgSMCBLwIAsAQOyBAzIEjAgS8CALAEDsgQMyBIwIEvAgCwBA7IEDMgSMCBLwIAsAQOyBAzIEjAgS8CALAEDsgQMyBIwIEvAgCwBA7IEDMgSMCBLwIAsAQOyBAzIEjAgS8CALAEDsgQMyBIwIEvAgCwBA7IEDMgSMCBLwIAsAQOyBAzIEjAgS8CALAEDsjb2+gLwX9YPHxrP/vvZ5mhu44974zMXT56MZ5nzAgOyBAzIEjAgS8CALAEDsgQMyBIwIEvAgCwBA7IEDMgSMCBLwIAsAQOybKN4j6wdPDia29ne3uWbvNli69F49u7lo6O5ne2T4zNPXPptPMucFxiQJWBAloABWQIGZAkYkCVgQJaAAVkCBmQJGJAlYECWgAFZAgZkCRiQJWBAlnU6Mesnj41nN67/PZp7/vW7X6fzNvbd+3A0t3nr6S7fhP+bFxiQJWBAloABWQIGZAkYkCVgQJaAAVkCBmQJGJAlYECWgAFZAgZkCRiQZRtFzcNH49HFd7MtDXdufDE+8+hPs9/I/VvPxmde/OaX0dyZC3+Nz7z26anxLHNeYECWgAFZAgZkCRiQJWBAloABWQIGZAkYkCVgQJaAAVkCBmQJGJAlYECWgAFZ1unELLbm63Sma3GO/bgYn7k4sDqa++ej/eMzr37852ju/Odnx2eurDx4i1mmvMCALAEDsgQMyBIwIEvAgCwBA7IEDMgSMCBLwIAsAQ
OyBAzIEjAgS8CALNso3iPHv/99NHf/ylfjMx9/+Xw0d/jXD8Znnjtyejhpo0SNFxiQJWBAloABWQIGZAkYkCVgQJaAAVkCBmQJGJAlYECWgAFZAgZkCRiQJWBAlnU6vNEnP9yez+7iPeBlXmBAloABWQIGZAkYkCVgQJaAAVkCBmQJGJAlYECWgAFZAgZkCRiQJWBAloABWQIGZAkYkCVgQJaAAVkCBmQJGJAlYECWgAFZAgZkCRiQJWBAloABWQIGZAkYkCVgQJaAAVkCBmQJGJAlYECWgAFZAgZkCRiQJWBAloABWQIGZAkYkCVgQJaAAVkCBmQJGJAlYECWgAFZAgZkCRiQJWBAloABWQIGZAkYkCVgQJaAAVkCBmQJGJAlYEDW6nK53Os7AIx4gQFZAgZkCRiQJWBAloABWQIGZL0ANMtC29HXiAAAAAAASUVORK5CYII=\n", 717 | "text/plain": [ 718 | "
" 719 | ] 720 | }, 721 | "metadata": { 722 | "needs_background": "light" 723 | }, 724 | "output_type": "display_data" 725 | } 726 | ], 727 | "source": [ 728 | "\n", 729 | "featsum,index = torch.topk(feat[0].sum((1,2)), 7)\n", 730 | "print(index[-3:])\n", 731 | "idx = index[-1]\n", 732 | "feat1 = transforms.to_pil_image(feat[0,idx,:,:])\n", 733 | "plt.imshow(feat1)\n", 734 | "plt.axis('off')" 735 | ] 736 | }, 737 | { 738 | "cell_type": "code", 739 | "execution_count": null, 740 | "metadata": {}, 741 | "outputs": [], 742 | "source": [] 743 | } 744 | ], 745 | "metadata": { 746 | "kernelspec": { 747 | "display_name": "Python 3", 748 | "language": "python", 749 | "name": "python3" 750 | }, 751 | "language_info": { 752 | "codemirror_mode": { 753 | "name": "ipython", 754 | "version": 3 755 | }, 756 | "file_extension": ".py", 757 | "mimetype": "text/x-python", 758 | "name": "python", 759 | "nbconvert_exporter": "python", 760 | "pygments_lexer": "ipython3", 761 | "version": "3.7.3" 762 | } 763 | }, 764 | "nbformat": 4, 765 | "nbformat_minor": 2 766 | } -------------------------------------------------------------------------------- /evaluation/result/eer_irr.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Debatrix/DFSNet/4e9d18ed8fe7f92dd89f3ea968389108d96638cb/evaluation/result/eer_irr.png -------------------------------------------------------------------------------- /evaluation/result/eer_irr_SD.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Debatrix/DFSNet/4e9d18ed8fe7f92dd89f3ea968389108d96638cb/evaluation/result/eer_irr_SD.png -------------------------------------------------------------------------------- /evaluation/result/eer_irr_TJ.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Debatrix/DFSNet/4e9d18ed8fe7f92dd89f3ea968389108d96638cb/evaluation/result/eer_irr_TJ.png -------------------------------------------------------------------------------- /evaluation/result/eer_irr_hd_SD.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Debatrix/DFSNet/4e9d18ed8fe7f92dd89f3ea968389108d96638cb/evaluation/result/eer_irr_hd_SD.png -------------------------------------------------------------------------------- /evaluation/result/eer_irr_hd_TJ.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Debatrix/DFSNet/4e9d18ed8fe7f92dd89f3ea968389108d96638cb/evaluation/result/eer_irr_hd_TJ.png -------------------------------------------------------------------------------- /evaluation/result/hd_SD.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Debatrix/DFSNet/4e9d18ed8fe7f92dd89f3ea968389108d96638cb/evaluation/result/hd_SD.png -------------------------------------------------------------------------------- /evaluation/result/heatmap.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Debatrix/DFSNet/4e9d18ed8fe7f92dd89f3ea968389108d96638cb/evaluation/result/heatmap.png -------------------------------------------------------------------------------- /evaluation/result/network.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Debatrix/DFSNet/4e9d18ed8fe7f92dd89f3ea968389108d96638cb/evaluation/result/network.png -------------------------------------------------------------------------------- /evaluation/result/result.pkl: -------------------------------------------------------------------------------- 
# https://raw.githubusercontent.com/Debatrix/DFSNet/4e9d18ed8fe7f92dd89f3ea968389108d96638cb/evaluation/result/result.pkl
# --------------------------------------------------------------------------------
# /evaluation/result/result_tj.pkl:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/Debatrix/DFSNet/4e9d18ed8fe7f92dd89f3ea968389108d96638cb/evaluation/result/result_tj.pkl
# --------------------------------------------------------------------------------
# /log/.gitignore:
# --------------------------------------------------------------------------------
# 1 | *
# 2 | !.gitignore
# --------------------------------------------------------------------------------
# /model/__init__.py:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/Debatrix/DFSNet/4e9d18ed8fe7f92dd89f3ea968389108d96638cb/model/__init__.py
# --------------------------------------------------------------------------------
# /model/loss.py:
# --------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F


class CrossEntropy2d(nn.Module):
    """Per-pixel cross entropy that skips pixels carrying an ignore label.

    ``size_average`` is forwarded unchanged as the ``reduction`` argument of
    ``F.cross_entropy``. Pixels whose label is negative or equal to
    ``ignore_label`` do not contribute to the loss.
    """

    def __init__(self, size_average="mean", ignore_label=255):
        super().__init__()
        self.size_average = size_average
        self.ignore_label = ignore_label

    def forward(self, predict, target, weight=None):
        """Compute the cross entropy over all valid pixels.

        :param predict: logits of shape (n, c, h, w)
        :param target: labels of shape (n, h, w) or (n, 1, h, w)
        :param weight: optional per-class rescaling weights (size ``c``),
            passed straight to ``F.cross_entropy``
        :return: the reduced loss, or a zero tensor of shape (1,) on the
            device/dtype of ``predict`` when no pixel is valid
        :raises AssertionError: if the shapes of ``predict`` and ``target``
            do not match or ``target`` requires grad
        """
        # Drop only the optional channel axis. A bare ``squeeze()`` would also
        # remove a batch (or spatial) axis of size 1 and trip the rank check.
        if target.dim() == 4 and target.size(1) == 1:
            target = target.squeeze(1)
        assert not target.requires_grad
        assert predict.dim() == 4
        assert target.dim() == 3
        assert predict.size(0) == target.size(0), "{0} vs {1} ".format(
            predict.size(0), target.size(0))
        assert predict.size(2) == target.size(1), "{0} vs {1} ".format(
            predict.size(2), target.size(1))
        # target is 3-D here, so its width lives at index 2 (index 3 would
        # raise IndexError while building the message).
        assert predict.size(3) == target.size(2), "{0} vs {1} ".format(
            predict.size(3), target.size(2))

        valid = (target >= 0) & (target != self.ignore_label)
        labels = target[valid].to(torch.long)
        if labels.numel() == 0:
            # Nothing to score: avoid the NaN a mean over zero elements gives.
            return predict.new_zeros(1)

        # Channels-last view, then keep the logits of valid pixels -> (k, c).
        logits = predict.permute(0, 2, 3, 1)[valid]
        return F.cross_entropy(logits,
                               labels,
                               weight=weight,
                               reduction=self.size_average)


class FocalLoss(nn.Module):
    """Multi-label focal loss on logits.

    Multi-labels Focal loss formula:
        FL = -alpha * (z-p)^gamma * log(p) -(1-alpha) * p^gamma * log(1-p)
    where alpha = 0.25 and gamma = 2 by default, p = sigmoid(x) and
    z = target tensor.
    """

    def __init__(self, gamma=2, alpha=0.25, size_average=True):
        super().__init__()
        self.gamma = gamma
        self.alpha = alpha
        self.size_average = size_average

    def forward(self, p, target):
        """Compute the focal loss.

        :param p: logits, (n, c, h, w); a sigmoid is applied internally
        :param target: targets broadcastable to ``p``; entries equal to 0 are
            treated as negatives, everything else as positives
        :return: mean of the element-wise loss if ``size_average`` is true,
            otherwise its sum
        """
        p = torch.sigmoid(p)
        negative = (target == 0)

        # Modulating factor of the positive term, zeroed on negatives.
        pos_p = target.type(torch.float) - p
        pos_p = torch.where(negative, torch.zeros_like(pos_p), pos_p)
        # Modulating factor of the negative term, zeroed on positives.
        # NOTE(review): the original built this with ``p.new_tensor(p.data)``,
        # i.e. a copy cut off from autograd; ``detach()`` keeps exactly that
        # behaviour without the copy-construct warning. Confirm whether the
        # gradient was meant to flow through this factor.
        neg_p = torch.where(negative, p.detach(), torch.zeros_like(p))

        # 1e-8 keeps log() finite when the sigmoid saturates.
        loss = -1 * self.alpha * pos_p ** self.gamma * (p + 1e-8).log() \
               - (1 - self.alpha) * neg_p ** self.gamma * (1 - p + 1e-8).log()

        if self.size_average:
            return loss.mean()
        return loss.sum()

# --------------------------------------------------------------------------------
# /model/pretrained/1203_202301_MobileNetV2_Lite_CX1.pth:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/Debatrix/DFSNet/4e9d18ed8fe7f92dd89f3ea968389108d96638cb/model/pretrained/1203_202301_MobileNetV2_Lite_CX1.pth
# --------------------------------------------------------------------------------
# /model/pretrained/1211_202056_MobileNetV2_Lite_CX2.pth:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/Debatrix/DFSNet/4e9d18ed8fe7f92dd89f3ea968389108d96638cb/model/pretrained/1211_202056_MobileNetV2_Lite_CX2.pth
# --------------------------------------------------------------------------------
# /model/quality_model.py:
# --------------------------------------------------------------------------------
# -*- coding: utf-8 -*-

import torch
import torch.nn as nn
from torch.nn import init
import torch.nn.functional as F


class MobileNetV2_encoder(nn.Module):
    """MobileNetV2 feature extractor split into a low and a high stage.

    ``forward`` returns the output of ``features[:5]`` and the output of the
    remaining ``features[5:]`` applied on top of it.
    """

    def __init__(self, pretrained=True):
        super().__init__()
        # Imported here so the losses and the decoder in this module stay
        # importable on machines without torchvision.
        import torchvision

        # Honour the caller's choice; the flag used to be hard-coded to True.
        model = torchvision.models.mobilenet_v2(pretrained=pretrained)
        self.low_feature = model.features[:5]
        self.high_feature = model.features[5:]

    def forward(self, input):
        out1 = self.low_feature(input)
        out2 = self.high_feature(out1)
        return out1, out2


class LRASPPV2(nn.Module):
    """Lite R-ASPP segmentation head.

    Takes a 1280-channel feature map ``x`` and a 32-channel shortcut feature
    map ``y`` and produces ``nclass`` logits at the resolution of ``x``.
    """

    def __init__(self, nclass=2):
        super().__init__()
        self.b0 = nn.Sequential(nn.Conv2d(1280, 128, 1, bias=False),
                                nn.BatchNorm2d(128), nn.ReLU(True))
        # Global-pooled gating branch: one sigmoid weight per channel.
        self.b1 = nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Conv2d(1280, 128, 1, bias=False),
            nn.Sigmoid(),
        )

        self.project = nn.Conv2d(128, nclass, 1)
        self.shortcut = nn.Conv2d(32, nclass, 1)

        self.init_params()

    def init_params(self):
        """Initialise every conv, batch-norm and linear layer of the head."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                init.kaiming_normal_(m.weight, mode='fan_out')
                if m.bias is not None:
                    init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                init.constant_(m.weight, 1)
                init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                init.normal_(m.weight, std=0.001)
                if m.bias is not None:
                    init.constant_(m.bias, 0)

    def forward(self, x, y):
        size = x.shape[2:]
        feat1 = self.b0(x)
        feat2 = self.b1(x)
        # Expand the 1x1 gate back to the spatial size of ``x``.
        feat2 = F.interpolate(feat2, size, mode='bilinear', align_corners=True)
        x = feat1 * feat2  # check it
        x = self.project(x)
        y = self.shortcut(y)
        # Pool the shortcut logits down to ``size`` before the residual add.
        out = F.adaptive_avg_pool2d(y, size) + x
        return out


class MobileNetV2_Lite(nn.Module):
    """Quality regressor: MobileNetV2 encoder, LR-ASPP mask head, MLP head.

    ``forward`` returns ``(pred, out_mask)``: a (n, 1) score computed from the
    mask-attended high-level features, and the mask logits resized to a
    quarter of the input resolution.

    ``mask_learn_rate`` only selects what is trainable: ``1`` freezes the
    regression MLP, ``0`` freezes encoder and decoder, any other value leaves
    everything trainable.
    """

    def __init__(self, pretrained=True, mask_learn_rate=0.5):
        super().__init__()
        self.encoder = MobileNetV2_encoder(pretrained)

        self.decoder = LRASPPV2()

        self.linear = nn.Sequential(
            nn.Linear(1280, 512, True),
            nn.Dropout(0.5),
            nn.ReLU(),
            nn.Linear(512, 64, True),
            nn.Dropout(0.5),
            nn.ReLU(),
            nn.Linear(64, 1),
        )
        self.init_params(self.linear)

        for p in self.parameters():
            p.requires_grad = True
        if mask_learn_rate == 1:
            for p in self.linear.parameters():
                p.requires_grad = False
        elif mask_learn_rate == 0:
            for p in self.encoder.parameters():
                p.requires_grad = False
            for p in self.decoder.parameters():
                p.requires_grad = False

    def init_params(self, target):
        """Initialise the layers yielded by iterating over ``target``."""
        for m in target:
            if isinstance(m, nn.Conv2d):
                init.kaiming_normal_(m.weight, mode='fan_out')
                if m.bias is not None:
                    init.constant_(m.bias, 0)
            elif isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d)):
                init.constant_(m.weight, 1)
                init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                init.normal_(m.weight, std=0.001)
                if m.bias is not None:
                    init.constant_(m.bias, 0)

    def forward(self, input):
        low_fea, high_fea = self.encoder(input)

        mask = self.decoder(high_fea, low_fea)
        # Probability of class 1, kept as a single-channel attention map.
        att_mask = torch.unsqueeze(torch.softmax(mask, 1)[:, 1, :, :], 1)
        out_mask = nn.functional.interpolate(
            mask, (input.shape[2] // 4, input.shape[3] // 4),
            mode='bilinear',
            align_corners=True)

        # Attention-weighted spatial average; 1e-8 guards an all-zero mask.
        pred = torch.sum(high_fea * att_mask,
                         dim=(2, 3)) / (torch.sum(att_mask, dim=(2, 3)) + 1e-8)
        pred = pred.view(pred.size(0), -1)
        pred = self.linear(pred)
        return pred, out_mask


if __name__ == '__main__':
    # NOTE(review): ``False`` compares equal to 0, so this freezes the encoder
    # and decoder; harmless for a FLOPs/params count.
    vn = MobileNetV2_Lite(True, False)
    c = torch.randn(2, 3, 640, 480)
    out = vn(c)

    print(out[0], out[1].shape)
    from thop import profile, clever_format

    flops, params = profile(vn, inputs=(c, ))
    flops, params = clever_format([flops, params], "%.3f")
    print('flops:{} params:{}'.format(flops, params))

# --------------------------------------------------------------------------------
# /test.py:
-------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | from argparse import ArgumentParser 3 | 4 | import pickle 5 | import numpy as np 6 | import os.path as osp 7 | from PIL import Image 8 | from glob import glob 9 | from tqdm import tqdm 10 | from scipy import stats 11 | 12 | import torch 13 | from torch import nn 14 | from torch.utils.data import DataLoader 15 | from torchvision.transforms import transforms 16 | 17 | from dataset import monoSimDataset 18 | from model.quality_model import MobileNetV2_Lite 19 | 20 | 21 | class LoadConfig(object): 22 | def __init__(self): 23 | 24 | self.mode = 'test' 25 | self.dataset_path = 'data/cx1' 26 | self.model_path = "model/pretrained/1211_202056_MobileNetV2_Lite_cx2.pth" 27 | self.cfg.result_path = "" 28 | 29 | self.seed = 2248 30 | 31 | self.batch_size = 24 32 | self.device = "cuda:2" 33 | self.num_workers = 2 34 | 35 | self._change_cfg() 36 | 37 | def _change_cfg(self): 38 | parser = ArgumentParser() 39 | for name, value in vars(self).items(): 40 | parser.add_argument('--' + name, type=type(value), default=value) 41 | args = parser.parse_args() 42 | 43 | for name, value in vars(args).items(): 44 | if self.__dict__[name] != value: 45 | self.__dict__[name] = value 46 | 47 | 48 | def test(cfg): 49 | # cpu or gpu? 
def predict(cfg):
    """Score every image found directly inside ``cfg.dataset_path``.

    Images with a ``.bmp``, ``.png`` or ``.jpg`` extension are run through the
    model one at a time.  When ``cfg.result_path`` is non-empty, the
    channel-1 softmax heatmap of each image is written there as
    ``<name>_heatmap.png`` (scaled to 8-bit) and all predictions are pickled
    to ``prediction.pkl``.

    Args:
        cfg: configuration object providing ``device``, ``dataset_path``,
            ``model_path`` and ``result_path``.  An optional ``size``
            attribute, if present and truthy, is passed to
            ``transforms.Resize``.

    Returns:
        dict mapping the image name (file name up to its first ``.``) to a
        flat numpy array holding the model prediction.
    """
    # cpu or gpu?
    if torch.cuda.is_available() and cfg.device is not None:
        device = torch.device(cfg.device)
    else:
        if not torch.cuda.is_available():
            print("hey man, buy a GPU!")
        device = torch.device("cpu")

    # data
    img_list = []
    for pattern in ('*.bmp', '*.png', '*.jpg'):
        img_list += glob(osp.join(cfg.dataset_path, pattern))

    # The original code referenced an undefined name ``size`` here.  The input
    # resolution the model expects is not visible from this file, so resizing
    # is only applied when the config explicitly provides one.
    # NOTE(review): confirm the intended size and add it to LoadConfig.
    steps = []
    size = getattr(cfg, 'size', None)
    if size:
        steps.append(transforms.Resize(size))
    steps.append(transforms.ToTensor())
    steps.append(
        transforms.Normalize(mean=[0.480], std=[0.200], inplace=False))
    transform = transforms.Compose(steps)

    # configure model
    print('Loading Model')
    model = MobileNetV2_Lite()
    model.to(device)
    if cfg.model_path:
        cp_data = torch.load(cfg.model_path, map_location=device)
        try:
            model.load_state_dict(cp_data['model'])
        except Exception as e:
            # Fall back to a non-strict load, but report what did not match.
            model.load_state_dict(cp_data['model'], strict=False)
            print(e)

        cp_data['cfg'] = '' if 'cfg' not in cp_data else cp_data['cfg']
        print(cp_data['cfg'])

    if cfg.result_path:
        import os
        os.makedirs(cfg.result_path, exist_ok=True)

    # Start!
    model.eval()
    prediction = {}
    with torch.no_grad():
        for path in tqdm(img_list):
            img_name = osp.basename(path).split('.')[0]
            with Image.open(path) as pil_img:
                img = transform(pil_img)
            # The model indexes input.shape[2] / input.shape[3], i.e. it needs
            # a leading batch dimension.
            img = img.unsqueeze(0).to(device)
            pred, heatmap = model(img)
            prediction[img_name] = pred.cpu().numpy().reshape((-1))
            if cfg.result_path:
                # heatmap is (batch, class, H, W): take the single sample and
                # normalise over the class axis.
                heatmap = torch.softmax(heatmap[0], 0)[1, :, :].cpu().numpy()
                # A float array becomes a mode-"F" image, which cannot be
                # saved as PNG; store the probabilities as 8-bit grey values.
                heatmap = Image.fromarray((heatmap * 255).astype(np.uint8))
                heatmap.save(
                    osp.join(cfg.result_path, img_name + '_heatmap.png'))
    if cfg.result_path:
        # pickle.dump needs a writable binary file object, not a path.
        with open(osp.join(cfg.result_path, 'prediction.pkl'), 'wb') as f:
            pickle.dump(prediction, f)
    return prediction
def train(cfg):
    """Train ``MobileNetV2_Lite`` and return the trained model on the CPU.

    The loss blends MSE on the predicted score with ``CrossEntropy2d`` on the
    predicted mask, weighted by ``cfg.mask_learn_rate`` (0 = score loss only,
    1 = mask loss only).  Every 5 epochs (every epoch when ``cfg.debug``) the
    validation split is scored with MSE, Pearson LCC and Spearman SROCC.

    Side effects, keyed by a timestamped run name that is also stored in
    ``cfg.name``:
      * ``checkpoints/<run>/`` - ``config.txt``, the ``cfg.cp_num``
        checkpoints with the lowest training loss, and periodic snapshots in
        ``history/`` (only when ``cfg.cp_num > 0``);
      * ``log/<run>/`` - TensorBoard scalars and images (only when
        ``cfg.visible``).
    ``cfg.mask_learn_rate`` is decayed in place while training.

    Args:
        cfg: configuration object (see ``LoadConfig``).

    Returns:
        The trained model, moved to the CPU.
    """
    # configure train
    train_name = time.strftime("%m%d_%H%M%S", time.localtime(
    )) + '_' + cfg.model + '_' + os.path.basename(cfg.dataset_path)
    cfg.name = train_name
    log_interval = int(np.ceil(cfg.max_epochs * 0.1))
    # The mask weight decays ten times per run; clamp to 1 so that runs
    # shorter than 10 epochs do not compute ``epoch % 0``.
    decay_interval = max(1, cfg.max_epochs // 10)
    print(cfg)

    # cpu or gpu?
    if torch.cuda.is_available() and cfg.device is not None:
        device = torch.device(cfg.device)
    else:
        if not torch.cuda.is_available():
            print("hey man, buy a GPU!")
        device = torch.device("cpu")

    # data
    print('Loading Data')
    train_data = monoSimDataset(path=cfg.dataset_path,
                                mode='train',
                                seed=cfg.seed,
                                debug_data=cfg.debug)
    train_data_loader = DataLoader(train_data,
                                   cfg.batch_size,
                                   drop_last=True,
                                   shuffle=True,
                                   num_workers=cfg.num_workers)
    val_data = monoSimDataset(path=cfg.dataset_path,
                              mode='val',
                              seed=cfg.seed,
                              debug_data=cfg.debug)
    val_data_loader = DataLoader(val_data,
                                 cfg.batch_size,
                                 shuffle=False,
                                 drop_last=True,
                                 num_workers=cfg.num_workers)

    # configure model
    print('Loading Model')
    model = MobileNetV2_Lite(True, cfg.mask_learn_rate)
    model.to(device)
    if cfg.cp_path:
        cp_data = torch.load(cfg.cp_path, map_location=device)
        try:
            model.load_state_dict(cp_data['model'])
        except Exception as e:
            # Fall back to a non-strict load, but report what did not match.
            model.load_state_dict(cp_data['model'], strict=False)
            print(e)

        cp_data['cfg'] = '' if 'cfg' not in cp_data else cp_data['cfg']
        print(cp_data['cfg'])

    # criterion and optimizer
    optimizer = torch.optim.Adam(
        filter(lambda p: p.requires_grad, model.parameters()),
        lr=cfg.lr,
        weight_decay=cfg.weight_decay)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                           'min',
                                                           factor=0.5,
                                                           verbose=True)

    pred_criterion = nn.MSELoss()
    mask_criterion = CrossEntropy2d()

    def save_checkpoint(path):
        """Write the current config and model weights to ``path``."""
        torch.save(dict(cfg=str(cfg), model=model.state_dict()), path)

    # checkpoint
    if cfg.cp_num > 0:
        cp_dir_path = os.path.normcase(os.path.join('checkpoints', train_name))
        history_dir_path = os.path.normcase(
            os.path.join(cp_dir_path, 'history'))
        # makedirs also creates 'checkpoints/' itself when it is missing.
        os.makedirs(history_dir_path)
        # [file name, training loss] pairs, kept sorted by ascending loss.
        best_cp = []
        with open(os.path.normcase(os.path.join(cp_dir_path, 'config.txt')),
                  'w') as f:
            info = str(cfg) + '#' * 30 + '\npre_cfg:\n' + str(
                cp_data['cfg']) if cfg.cp_path else str(cfg)
            f.write(info)

    # visible
    if cfg.visible:
        log_writer = SummaryWriter(os.path.join("log", train_name))
        log_writer.add_text('cur_cfg', cfg.__str__())
        if cfg.cp_path:
            log_writer.add_text('pre_cfg', cp_data['cfg'].__str__())

    # Start!
    print("Start training!\n")
    for epoch in range(1, cfg.max_epochs + 1):
        if epoch % decay_interval == 0 and cfg.mask_lr_decay < 1:
            cfg.mask_learn_rate *= cfg.mask_lr_decay
            print("[{}] Mask learn rate: {:.4e}".format(
                epoch, cfg.mask_learn_rate))

        # train
        model.train()
        epoch_loss = 0
        for img, mask, target in tqdm(
                train_data_loader,
                desc='[{}] mini_batch'.format(epoch),
                bar_format='{desc}: {n_fmt}/{total_fmt} -{percentage:3.0f}%'):
            img = img.to(device)
            mask = mask.to(device)
            target = target.to(device)
            optimizer.zero_grad()
            pred, heatmap = model(img)
            if cfg.mask_learn_rate == 0:
                # Score regression only.
                loss = pred_criterion(pred, target)
            elif cfg.mask_learn_rate == 1:
                # Mask segmentation only.  (This branch used to test ``== 0``
                # a second time and was therefore unreachable.)
                loss = mask_criterion(heatmap, mask)
            else:
                loss = (1 - cfg.mask_learn_rate) * pred_criterion(
                    pred, target) + cfg.mask_learn_rate * mask_criterion(
                        heatmap, mask)
            epoch_loss += loss.item()
            loss.backward()
            optimizer.step()
        train_loss = epoch_loss / len(train_data_loader)
        scheduler.step(train_loss)

        print("[{}] Training - loss: {:.4e}".format(epoch, train_loss))
        if cfg.visible:
            log_writer.add_scalar('Train/Loss', train_loss, epoch)
            log_writer.add_scalar('Train/lr', optimizer.param_groups[0]['lr'],
                                  epoch)

        # val
        if epoch % 5 == 0 or cfg.debug:
            if cfg.model.split('_')[0] == 'MobileNetV3':
                model.train()
            else:
                model.eval()
            with torch.no_grad():
                val_pred_loss = 0
                val_samples = 0
                # Both arrays start with a dummy element that is skipped
                # ([1:]) when the correlations are computed.
                scores = np.zeros((1))
                prediction = np.zeros((1))
                for img, mask, target in tqdm(
                        val_data_loader,
                        desc='[{}] val_batch'.format(epoch),
                        bar_format=
                        '{desc}: {n_fmt}/{total_fmt} -{percentage:3.0f}%'):
                    img = img.to(device)
                    mask = mask.to(device)
                    target = target.to(device)
                    pred, heatmap = model(img)
                    val_pred_loss += nn.functional.mse_loss(pred,
                                                            target,
                                                            reduction='sum')
                    val_samples += img.shape[0]
                    scores = np.append(scores,
                                       target.cpu().numpy().reshape((-1)))
                    prediction = np.append(prediction,
                                           pred.cpu().numpy().reshape((-1)))
                # The loader drops the last incomplete batch, so normalise by
                # the samples actually evaluated rather than len(val_data).
                val_pred_loss = val_pred_loss / max(val_samples, 1)
                prediction = np.nan_to_num(prediction)
                srocc = stats.spearmanr(prediction[1:], scores[1:])[0]
                lcc = stats.pearsonr(prediction[1:], scores[1:])[0]

                print("[{}] Val - MSE: {:.4e}".format(epoch, val_pred_loss))
                print("[{}] Val - LCC: {:.4f}, SROCC: {:.4f}".format(
                    epoch, lcc, srocc))
                if cfg.visible:
                    # Log one random sample of the last validation batch.
                    idx = np.random.randint(0, mask.shape[0])
                    heatmap_s = torch.softmax(heatmap, 1)[idx, 1, :, :]
                    log_writer.add_scalar('Val/MSE', val_pred_loss, epoch)
                    log_writer.add_scalar('Val/LCC', lcc, epoch)
                    log_writer.add_scalar('Val/SROCC', srocc, epoch)
                    log_writer.add_image('Val/img', img[idx], epoch)
                    log_writer.add_image('Val/mask',
                                         torch.squeeze(mask[idx]),
                                         epoch,
                                         dataformats='HW')
                    log_writer.add_image('Val/heatmap',
                                         torch.squeeze(heatmap_s),
                                         epoch,
                                         dataformats='HW')

        # checkpoint
        if cfg.cp_num > 0:
            cp_name = "{}_{:.4e}.pth".format(epoch, train_loss)

            if epoch < cfg.cp_num + 1:
                # The first cp_num epochs are always kept.
                best_cp.append([cp_name, train_loss])
                best_cp.sort(key=lambda x: x[1])
                save_checkpoint(
                    os.path.normcase(os.path.join(cp_dir_path, cp_name)))
            elif train_loss < best_cp[-1][1]:
                # Better than the worst kept checkpoint: replace it.
                os.remove(
                    os.path.normcase(
                        os.path.join(cp_dir_path, best_cp[-1][0])))
                best_cp[-1] = [cp_name, train_loss]
                best_cp.sort(key=lambda x: x[1])
                save_checkpoint(
                    os.path.normcase(os.path.join(cp_dir_path, cp_name)))

            if ((log_interval > 0) and (epoch % log_interval == 0 or epoch % 100 == 0)) or \
                    (epoch == cfg.max_epochs):
                save_checkpoint(
                    os.path.normcase(
                        os.path.join(history_dir_path, cp_name)))

    if cfg.visible:
        # Flush pending TensorBoard events before returning.
        log_writer.close()

    return model.cpu()
| cv::sqrt(roiImage, roiImage); 47 | fm = cv::mean(roiImage)[0]; 48 | 49 | roiImage.release(); 50 | gradientX.release(); 51 | gradientY.release(); 52 | 53 | return fm; 54 | } 55 | 56 | cv::Rect qeFaceLocation(cv::Mat &srcImage, int threshold, int downsample_factor, double ystart, double hrange) 57 | { 58 | 59 | assert(downsample_factor > 0); 60 | assert(srcImage.channels() == 1); 61 | 62 | int width = srcImage.cols; 63 | int height = srcImage.rows; 64 | 65 | int x_range = int(width / downsample_factor); 66 | int y_range = int(height / downsample_factor); 67 | 68 | std::vector w_sum; 69 | std::vector h_sum; 70 | 71 | for (int x = 0; x < x_range; x++) 72 | { 73 | w_sum.push_back(0); 74 | } 75 | for (int y = 0; y < y_range; y++) 76 | { 77 | h_sum.push_back(0); 78 | } 79 | 80 | double color_count[256]; 81 | for (int i = 0; i < 256; i++) 82 | { 83 | color_count[i] = 0; 84 | } 85 | int value = 0; 86 | for (int x = 0; x < x_range; x++) 87 | { 88 | for (int y = 0; y < y_range; y++) 89 | { 90 | value = srcImage.data[(y * downsample_factor) * width + (x * downsample_factor)]; 91 | color_count[value] += 1; 92 | } 93 | } 94 | 95 | 96 | // 分别沿xy轴统计亮度大于阈值的像素数 97 | //int value = 0; 98 | for (int x = 0; x < x_range; x++) 99 | { 100 | for (int y = 0; y < y_range; y++) 101 | { 102 | value = srcImage.data[(y * downsample_factor) * width + (x * downsample_factor)]; 103 | //value = int(srcImage.at(y * downsample_factor, x * downsample_factor)); 104 | if (value > threshold) 105 | { 106 | w_sum[x] += 1; 107 | h_sum[y] += 1; 108 | } 109 | } 110 | } 111 | 112 | // 搜索选择区域范围 113 | int w_start = 0; 114 | int w_end = 0; 115 | for (int x = 1; x < x_range; x++) 116 | { 117 | if (w_sum[x] > 0) 118 | { 119 | w_start = x; 120 | break; 121 | } 122 | } 123 | for (int x = x_range - 1; x >= 0; x--) 124 | { 125 | if (w_sum[x] > 0) 126 | { 127 | w_end = x; 128 | break; 129 | } 130 | } 131 | int h_start = 0; 132 | int h_end = 0; 133 | for (int y = 1; y < y_range; y++) 134 | { 135 | if (h_sum[y] > 0) 136 | 
{ 137 | h_start = y; 138 | break; 139 | } 140 | } 141 | for (int y = y_range - 1; y >= 0; y--) 142 | { 143 | if (h_sum[y] > 0) 144 | { 145 | h_end = y; 146 | break; 147 | } 148 | } 149 | 150 | // 计算roi 151 | int roi_x = w_start * downsample_factor; 152 | int roi_w = (w_end - w_start) * downsample_factor; 153 | int roi_y = int((h_start + (h_end - h_start) * ystart) * downsample_factor); 154 | int roi_h = int((h_end - h_start) * hrange * downsample_factor); 155 | 156 | cv::Rect roi = cv::Rect(roi_x, roi_y, roi_w, roi_h); 157 | 158 | return roi; 159 | } 160 | 161 | void qeMaskDenoise(cv::Mat &srcMask) 162 | { 163 | assert(srcMask.channels() == 1); 164 | 165 | cv::Mat kernel = cv::getStructuringElement(cv::MORPH_RECT, cv::Size(5, 5)); 166 | cv::morphologyEx(srcMask, srcMask, cv::MORPH_OPEN, kernel); 167 | cv::morphologyEx(srcMask, srcMask, cv::MORPH_CLOSE, kernel); 168 | 169 | kernel.release(); 170 | } 171 | 172 | cv::Rect qeIrisLocation(cv::Mat &srcIris) 173 | { 174 | int width = srcIris.cols; 175 | int height = srcIris.rows; 176 | 177 | cv::Rect roi = qeFaceLocation(srcIris, 128, 4, 0, 1); 178 | //roi = roi + cv::Point(-8, -5); 179 | //roi = roi + cv::Size(16, 10); 180 | 181 | return roi; 182 | } 183 | 184 | double qeImageEntropy(cv::Mat& srcImage, int downsample_factor) 185 | { 186 | 187 | assert(srcImage.channels() == 1); 188 | 189 | double ent = 0.0; 190 | int imgValue = 0; 191 | double p = 0.0; 192 | 193 | int width = srcImage.cols; 194 | int height = srcImage.rows; 195 | 196 | int x_range = int(width / downsample_factor); 197 | int y_range = int(height / downsample_factor); 198 | 199 | double color_count[256]; 200 | for (int i = 0; i < 256; i++) 201 | { 202 | color_count[i] = 0; 203 | } 204 | 205 | for (int x = 0; x < x_range; x++) 206 | { 207 | for (int y = 0; y < y_range; y++) 208 | { 209 | imgValue = srcImage.data[(y * downsample_factor) * width + (x * downsample_factor)]; 210 | color_count[imgValue] += 1; 211 | } 212 | } 213 | for (int i = 0; i < 256; i++) 
214 | { 215 | p = color_count[i] / (x_range * y_range); 216 | if (p > 0) 217 | { 218 | ent -= p * log(p); 219 | } 220 | 221 | } 222 | 223 | return ent; 224 | } 225 | 226 | 227 | void qeMinAreaRect(cv::Mat& srcMask) 228 | { 229 | std::vector> contours; 230 | std::vector hierarchy; 231 | cv::Mat binaryImage; 232 | cv::threshold(srcMask, binaryImage, 100, 255, CV_THRESH_BINARY_INV); 233 | cv::findContours(binaryImage, contours, hierarchy, CV_RETR_TREE, CV_CHAIN_APPROX_SIMPLE, cv::Point()); 234 | cv::RotatedRect rect = cv::minAreaRect(contours); 235 | 236 | } 237 | 238 | 239 | std::vector qeIrisQuality(cv::Mat &srcImage, cv::Mat &srcMask, cv::Mat &srcIris, cv::Mat &srcPupil) 240 | { 241 | //vector &quality[focus measure, gray, usable aera, dilation, iris shape, pupil shape] 242 | 243 | assert(srcImage.channels() == 1); 244 | assert(srcMask.channels() == 1); 245 | assert(srcIris.channels() == 1); 246 | assert(srcPupil.channels() == 1); 247 | 248 | cv::Mat roiImage, roiMask, roiIris, roiPupil; 249 | // cv::Mat gradientX, gradientY, gradient; 250 | 251 | srcIris.copyTo(roiIris); 252 | qeMaskDenoise(roiIris); 253 | cv::Rect IrisRect = qeIrisLocation(roiIris); 254 | double qeIrisradius = sqrt(pow(IrisRect.tl().x - IrisRect.br().x, 2) + pow(IrisRect.tl().y - IrisRect.br().y, 2)) / 2; 255 | cv::Point IrisCenter; 256 | IrisCenter.x = (IrisRect.tl().x + IrisRect.br().x) / 2; 257 | IrisCenter.y = (IrisRect.tl().y + IrisRect.br().y) / 2; 258 | 259 | srcPupil.copyTo(roiPupil); 260 | qeMaskDenoise(roiPupil); 261 | cv::Rect PupilRect = qeIrisLocation(roiPupil); 262 | cv::Point PupilCenter; 263 | PupilCenter.x = (PupilRect.tl().x + PupilRect.br().x) / 2; 264 | PupilCenter.y = (PupilRect.tl().y + PupilRect.br().y) / 2; 265 | 266 | double qeConcentricity = 1 - (sqrt(pow(IrisCenter.x - PupilCenter.x, 2) + pow(IrisCenter.y - PupilCenter.y, 2)) / qeIrisradius); 267 | 268 | // Margin adequacy 269 | double LM = (IrisCenter.x - qeIrisradius) / qeIrisradius; 270 | double RM = (srcImage.rows - 
(IrisCenter.x + qeIrisradius)) / qeIrisradius; 271 | double UM = (IrisCenter.y - qeIrisradius) / qeIrisradius; 272 | double DM = (srcImage.cols - (IrisCenter.y + qeIrisradius)) / qeIrisradius; 273 | double LEFT_MARGIN = maximum(0, minimum(1, LM / 0.6)); 274 | double RIGHT_MARGIN = maximum(0, minimum(1, RM / 0.6)); 275 | double UP_MARGIN = maximum(0, minimum(1, UM / 0.2)); 276 | double DOWN_MARGIN = maximum(0, minimum(1, DM / 0.2)); 277 | double qeMargin = minimum(minimum(LEFT_MARGIN, RIGHT_MARGIN), minimum(UP_MARGIN, DOWN_MARGIN)); 278 | 279 | srcImage(IrisRect).copyTo(roiImage); 280 | srcMask(IrisRect).copyTo(roiMask); 281 | srcIris(IrisRect).copyTo(roiIris); 282 | srcPupil(IrisRect).copyTo(roiPupil); 283 | 284 | 285 | const int width = roiImage.cols; 286 | const int height = roiImage.rows; 287 | 288 | 289 | double color_count[256]; 290 | for (int i = 0; i < 256; i++) 291 | { 292 | color_count[i] = 0; 293 | } 294 | 295 | double num_mask_pix = 0; 296 | double num_iris_pix = 0; 297 | double num_pupil_pix = 0; 298 | double focus_count = 0; 299 | 300 | int imgValue = 0; 301 | int maskValue = 0; 302 | int irisValue = 0; 303 | int pupilValue = 0; 304 | for (int x = 0; x < width; x++) 305 | { 306 | for (int y = 0; y < height; y++) 307 | { 308 | imgValue = roiImage.data[y * width + x]; 309 | maskValue = roiMask.data[y * width + x]; 310 | irisValue = roiIris.data[y * width + x]; 311 | pupilValue = roiPupil.data[y * width + x]; 312 | 313 | if (irisValue > 0) 314 | { 315 | num_iris_pix += 1; 316 | if (pupilValue < 1 && maskValue > 0) 317 | { 318 | num_mask_pix += 1; 319 | color_count[imgValue] += 1; 320 | } 321 | } 322 | if (pupilValue > 0) 323 | { 324 | num_pupil_pix += 1; 325 | } 326 | 327 | } 328 | } 329 | 330 | double p = 0; 331 | double ent = 0.0; 332 | for (int i = 0; i < 256; i++) 333 | { 334 | p = color_count[i] / num_mask_pix; 335 | if (p > 0) 336 | { 337 | ent -= p * log(p); 338 | } 339 | 340 | } 341 | 342 | std::vector quality; 343 | 344 | 
quality.push_back(qeFocusMeasure(srcImage, IrisRect, 1)); 345 | quality.push_back(ent); 346 | quality.push_back(num_mask_pix /(num_iris_pix - num_pupil_pix)); 347 | quality.push_back(sqrt(num_pupil_pix / num_iris_pix)); 348 | quality.push_back(qeIrisradius); 349 | quality.push_back(qeConcentricity); 350 | quality.push_back(qeMargin); 351 | 352 | 353 | roiImage.release(); 354 | roiMask.release(); 355 | roiIris.release(); 356 | roiPupil.release(); 357 | 358 | return quality; 359 | } 360 | 361 | //************************************************************************************************************************* 362 | 363 | int qeFaceValidCheck(cv::Mat &srcImage, cv::Rect &face_roi) 364 | { 365 | int flag = qeSuccess; 366 | try 367 | { 368 | face_roi = qeFaceLocation(srcImage, qeFaceBinarizationThreshold, qeFaceLocationDownsampleFactor, qeFaceLocationYAxisStart, qeFaceLocationHeightRange); 369 | 370 | } 371 | catch (...) 372 | { 373 | flag = qeFaceLocationError; 374 | } 375 | if (flag == qeSuccess) 376 | { 377 | if (face_roi.area() < qeMinFaceAera) { flag = qeFaceAeraTooSmall; } 378 | if (face_roi.area() > qeMaxFaceAera) { flag = qeFaceAeraTooLarge; } 379 | 380 | if (qeDebug == 1) { std::cout << face_roi.area() << ';'; } 381 | } 382 | 383 | return flag; 384 | } 385 | 386 | int qeFaceQualityCheck(cv::Mat &srcImage, cv::Rect &face_roi) 387 | { 388 | int flag = qeSuccess; 389 | double focus_score = -1; 390 | double ent = -1; 391 | try 392 | { 393 | focus_score = qeFocusMeasure(srcImage, face_roi, qeFaceFocusMeasureDownsampleFactor); 394 | if (qeDebug == 1) { std::cout << focus_score << ','; } 395 | } 396 | catch (...) 
397 | { 398 | flag = qeFocusMeasureError; 399 | } 400 | if (flag == qeSuccess) 401 | { 402 | if (focus_score < qeMaxFaceFocusScore) { flag = qeFaceDefocusBlur; } 403 | if (focus_score > qeMinFaceFocusScore) { flag = qeFaceMotionBlur; } 404 | } 405 | try 406 | { 407 | ent = qeImageEntropy(srcImage, 8); 408 | if (qeDebug == 1) { std::cout << ent << ';'; } 409 | } 410 | catch (...) 411 | { 412 | flag = qeEntropyCalculateError; 413 | } 414 | if (flag == qeSuccess) 415 | { 416 | if (ent < qeMinFaceEntropy) { flag = qeFaceLowEntropy; } 417 | } 418 | return flag; 419 | } 420 | 421 | int qeIrisQualityCheck(cv::Mat &srcImage, cv::Mat &srcMask, cv::Mat &srcIris, cv::Mat &srcPupil) 422 | { 423 | int flag = qeSuccess; 424 | std::vector quality; 425 | try 426 | { 427 | quality = qeIrisQuality(srcImage, srcMask, srcIris, srcPupil); 428 | if (qeDebug == 1) { 429 | std::cout << quality[0] << ',' << quality[1] << ',' << quality[2] << ',' << quality[3] << ',' << quality[4] << ','; 430 | std::cout << quality[5] << ';' << quality[6] << ';'; 431 | } 432 | } 433 | catch (...) 
//***************************************** Parameter settings **************************************************************

// Control settings
// When qeDebug is 1 the check functions print the measured scores to std::cout.
#define qeDebug 1

// General parameters
// Passed by qeFaceValidCheck to qeFaceLocation (threshold, downsample factor,
// ystart, hrange) and by qeFaceQualityCheck to qeFocusMeasure (downsample factor).
#define qeFaceBinarizationThreshold 180
#define qeFaceLocationDownsampleFactor 16
#define qeFaceLocationYAxisStart 0.15
#define qeFaceLocationHeightRange 0.5
#define qeFaceFocusMeasureDownsampleFactor 2

// Thresholds
// Accept/reject limits compared against the measured values in
// qeFaceValidCheck, qeFaceQualityCheck and qeIrisQualityCheck.
#define qeMinFaceAera 500000
#define qeMaxFaceAera 6266880
#define qeMinFaceEntropy 3.4
#define qeMaxFaceFocusScore 38.0
#define qeMinFaceFocusScore 20.0
#define qeMaxIrisFocusScore 47.0
#define qeMinIrisFocusScore 28.0
#define qeMinIrisEntropy 6
#define qeMaxIrisUsableAera 1.0
#define qeMinIrisUsableAera 0.7
#define qeMaxPupilDilation 0.7
#define qeMinPupilDilation 0.2
#define qeMinIrisradius 80
#define qeMinConcentricity 0.9
#define qeMinMargin 0.8

//********************************************* Error codes **************************************************************

// Success
#define qeSuccess 0

// Runtime errors (an exception was caught while computing a measure)
#define qeFocusMeasureError -1001
#define qeFaceLocationError -1002
#define qeEntropyCalculateError -1003

// Low-quality image
#define qeFaceAeraTooSmall -1011
#define qeFaceAeraTooLarge -1012
#define qeFaceLowEntropy -1013
#define qeFaceDefocusBlur -1021
#define qeFaceMotionBlur -1022
#define qeIrisDefocusBlur -1023
#define qeIrisMotionBlur -1024
#define qeIrisLowEntropy -1031
#define qeIrisSegmentError -1032
#define qeIrisAeraTooSmall -1033
#define qePupilAeraTooSmall -1034
#define qePupilAeraTooBig -1035
#define qeIrisradiusTooSmall -1036
#define qeConcentricityTooSmall -1037
#define qeMarginTooSmall -1038
82 | **********************************************************************************************************/ 83 | int qeFaceValidCheck(cv::Mat& srcImage, cv::Rect& face_roi); 84 | 85 | /********************************************************************************************************* 86 | 函数类型:int 87 | 函数参数: 88 | 输入 89 | 1.cv::Mat& srcImage 90 | 含义:单通道脸部图像 91 | 92 | 2.cv::Rect& face_roi 93 | 含义:可能含有眼部的区域,大小不定 94 | 95 | 输出 96 | 97 | 返回值类型:int 98 | 含义:状态值,参见错误代码 99 | 100 | 函数功能:排除脸部区域模糊和欠曝/过曝图像 101 | **********************************************************************************************************/ 102 | int qeFaceQualityCheck(cv::Mat& srcImage, cv::Rect& face_roi); 103 | 104 | /********************************************************************************************************* 105 | 函数类型:int 106 | 函数参数: 107 | 输入 108 | 1. cv::Mat &srcImage 109 | 含义:单通道眼部图像 110 | 111 | 2. cv::Mat &srcMask 112 | 含义:单通道mask,0为黑色(遮挡),255为白色 113 | 114 | 3. cv::Mat &srcIris 115 | 含义:单通道Iris mask,0为黑色(遮挡),255为白色 116 | 117 | 4. cv::Mat &srcPupil 118 | 含义:单通道Pupil mask,0为黑色(遮挡),255为白色 119 | 120 | 输出 121 | 122 | 返回值类型:int 123 | 含义:状态值,参见错误代码 124 | 125 | 函数功能:排除虹膜区域模糊,欠曝/过曝,有效面积比小和瞳孔缩放过度图像 126 | 备注: 输入的图像需要为原始比例(不能被缩放为方形) 127 | **********************************************************************************************************/ 128 | int qeIrisQualityCheck(cv::Mat& srcImage, cv::Mat& srcMask, cv::Mat& srcIris, cv::Mat& srcPupil); 129 | 130 | 131 | 132 | /********************************************************************************************************* 133 | 函数类型:double 134 | 函数参数: 135 | 输入 136 | 1. 
cv::Mat &srcImage 137 | 含义:原始图像,未经过缩放等操作。 138 | 139 | 2.cv::Rect Roi 140 | 含义:进行质量评价的区域。 141 | 142 | 3.int downsample_factor 143 | 含义:图像缩小的倍数,不应当小于1。 144 | 145 | 输出 146 | 返回值类型:double 147 | 含义:模糊度分数。 148 | 149 | 函数功能: 150 | 计算图像离焦模糊的程度。 151 | 152 | 备注:图像中包含大量头发区域和运动模糊等因素会导致得分过高。 153 | **********************************************************************************************************/ 154 | double qeFocusMeasure(cv::Mat &srcImage, cv::Rect Roi, int downsample_factor = 2); 155 | 156 | /********************************************************************************************************* 157 | 函数类型:cv::Rect 158 | 函数参数: 159 | 输入 160 | 1. cv::Mat &srcImage 161 | 含义:原始图像,未经过缩放等操作。 162 | 163 | 2.uchar threshold 164 | 含义:二值化阈值,实验测试80可以有效确定脸部区域。 165 | 166 | 3.double downsample_factor 167 | 含义:图像缩小的倍数,不应当小于1。计算时是隔downsample_factor行/列遍历像素统计,设置到16以上可以提升性能。 168 | 169 | 4.double ystart, double hrange 170 | 含义:roi左上角下移比例和高度缩小比例,用于调整roi区域位置和大小。ystart=0.15,hrange=0.5可以有效确定脸部区域。 171 | 172 | 输出 173 | 174 | 1. cv::Rect &roi 175 | 含义:脸部区域 176 | 177 | 返回值类型:cv::Rect 178 | 含义:脸部区域roi. 179 | 180 | 函数功能: 181 | 根据图像二值化结果粗估计人脸位置和大小。 182 | 备注:当ystart=0.15,hrange=0.5时,roi尺寸小于800*400时,可认为图像中不包含有效人脸。 183 | **********************************************************************************************************/ 184 | cv::Rect qeFaceLocation(cv::Mat &srcImage, int threshold = 180, int downsample_factor = 16, double ystart = 0.15, double hrange = 0.5); 185 | 186 | /********************************************************************************************************* 187 | 函数类型:std::vector 188 | 函数参数: 189 | 输入 190 | 1. cv::Mat &srcImage 191 | 含义:单通道眼部图像 192 | 193 | 2. cv::Mat &srcMask 194 | 含义:单通道mask,0为黑色(遮挡),255为白色 195 | 196 | 3. cv::Mat &srcIris 197 | 含义:单通道Iris mask,0为黑色(遮挡),255为白色 198 | 199 | 4. 
cv::Mat &srcPupil 200 | 含义:单通道Pupil mask,0为黑色(遮挡),255为白色 201 | 202 | 输出 203 | 204 | 返回值类型:std::vector 205 | 含义:质量分数,共7个元素,分别为虹膜区域清晰度,虹膜区域灰度分布,虹膜区域有效面积比,瞳孔缩放度,虹膜半径,虹膜-瞳孔同心度,边界余量 206 | 注:输入的vector &quality必须为空 207 | 208 | 函数功能:在虹膜定位分割的基础上对虹膜进行质量评价 209 | 210 | 备注:质量分数可能会增加其他元素,但已有的顺序不变 211 | **********************************************************************************************************/ 212 | std::vector qeIrisQuality(cv::Mat &srcImage, cv::Mat &srcMask, cv::Mat &srcIris, cv::Mat &srcPupil); 213 | 214 | /********************************************************************************************************* 215 | 函数类型:double 216 | 函数参数: 217 | 输入 218 | 1. cv::Mat &srcImage 219 | 含义:单通道图像 220 | 221 | 2.int downsample_factor 222 | 含义:图像缩小的倍数,不应当小于1。 223 | 224 | 输出 225 | 226 | 返回值类型:double 227 | 含义:图像熵值 228 | 229 | 函数功能:评估整个图像的灰度(亮度)分布. 230 | **********************************************************************************************************/ 231 | double qeImageEntropy(cv::Mat& srcImage, int downsample_factor = 1); 232 | 233 | /********************************************************************************************************* 234 | 函数类型:void 235 | 函数参数: 236 | 输入 237 | cv::Mat &srcMask 238 | 含义:单通道掩膜图像 239 | 240 | 输出 241 | 无 242 | 243 | 函数功能:去除掩膜的噪点与填补小的漏洞,通过一次开运算与闭运算完成. 244 | 245 | 注:它是原地修改的. 246 | **********************************************************************************************************/ 247 | void qeMaskDenoise(cv::Mat &srcMask); 248 | 249 | /********************************************************************************************************* 250 | 函数类型:cv::Rect 251 | 函数参数: 252 | 输入 253 | cv::Mat &srcIris 254 | 含义:单通道图像 255 | 256 | 输出 257 | 258 | 返回值类型:cv::Rect 259 | 含义:补全的虹膜区域 260 | 261 | 函数功能:根据掩膜图像确定虹膜区域,减少计算量. 
def ellipse2circle(param):
    """Collapse ellipse parameters into circle parameters.

    A 5-element ``param`` becomes ``(param[0], param[1], r)`` where ``r`` is
    the truncated mean of ``param[2]`` and ``param[3]``. A 3-element
    ``param`` is returned unchanged. Any other length yields ``()``.
    """
    if len(param) == 3:
        return param
    if len(param) == 5:
        mean_radius = int((param[2] + param[3]) / 2)
        return (param[0], param[1], mean_radius)
    return ()


def _mask_coordinates(mask):
    """Return ``(xs, ys)`` index arrays of the pixels where ``mask == True``."""
    # ``== True`` is kept deliberately instead of plain truthiness: it keeps
    # the original selection rule for non-boolean masks (only the value 1
    # counts as set, 255 does not).
    ys, xs = np.nonzero(np.asarray(mask) == True)  # noqa: E712
    return xs, ys


def _squared_distance(xs, ys, circle):
    """Squared distance of every ``(x, y)`` to the centre of ``circle``."""
    return (xs - circle[0]) ** 2 + (ys - circle[1]) ** 2


def _point_set(xs, ys, keep):
    """Build a set of plain-int ``(x, y)`` tuples from the selected indices."""
    return set(zip(xs[keep].tolist(), ys[keep].tolist()))


def points_in_circle(mask, circle):
    """Return the set of masked ``(x, y)`` points inside or on ``circle``.

    ``mask`` is a 2-D array; ``circle`` is ``(cx, cy, r)``.

    The previous implementation iterated over the tuple returned by
    ``np.where`` and unpacked each whole index array into ``y, x``, which
    raised ``ValueError`` for any mask without exactly two set pixels.
    """
    xs, ys = _mask_coordinates(mask)
    inside = _squared_distance(xs, ys, circle) <= circle[2] ** 2
    return _point_set(xs, ys, inside)


def points_out_circle(mask, circle):
    """Return the set of masked ``(x, y)`` points strictly outside ``circle``.

    ``mask`` is a 2-D array; ``circle`` is ``(cx, cy, r)``.
    """
    xs, ys = _mask_coordinates(mask)
    outside = _squared_distance(xs, ys, circle) > circle[2] ** 2
    return _point_set(xs, ys, outside)


def points_between_circle(mask, circle1, circle2):
    """Return masked ``(x, y)`` points outside ``circle1`` and within ``circle2``.

    A point is kept when it is set in ``mask``, lies strictly outside
    ``circle1`` and inside or on ``circle2``. Both circles are
    ``(cx, cy, r)``.
    """
    xs, ys = _mask_coordinates(mask)
    beyond_inner = _squared_distance(xs, ys, circle1) > circle1[2] ** 2
    within_outer = _squared_distance(xs, ys, circle2) <= circle2[2] ** 2
    return _point_set(xs, ys, beyond_inner & within_outer)


# #############################################################################


def sharpness(img):
    """Return the mean Sobel gradient magnitude of ``img``.

    Fixes two defects of the previous version: the vertical derivative was
    computed with ``dx=1, dy=0`` (a copy of the horizontal one), and the
    ``CV_16U`` output depth clipped negative gradients and overflowed when
    squared. Gradients are now computed in ``CV_64F``.
    """
    grad_x = cv2.Sobel(img, cv2.CV_64F, 1, 0)
    grad_y = cv2.Sobel(img, cv2.CV_64F, 0, 1)
    return np.mean(np.sqrt(grad_x ** 2 + grad_y ** 2))
def iris_size(iris_param):
    """Return the iris radius in pixels.

    A 3-element ``iris_param`` yields its third element; a 5-element one
    yields the truncated mean of elements 2 and 3; any other length gives 0.
    """
    count = len(iris_param)
    if count == 3:
        return iris_param[2]
    if count == 5:
        return int((iris_param[2] + iris_param[3]) / 2)
    return 0


def dilation(iris_param, pupil_param):
    """Return the pupil/iris radius ratio in percent, or -1 if unavailable.

    Both parameter lists must have the same form (3 or 5 elements). The
    result is -1 when the forms differ or either radius is 0.
    """
    if len(iris_param) == 3 and len(pupil_param) == 3:
        iris_r = iris_param[2]
        pupil_r = pupil_param[2]
    elif len(iris_param) == 5 and len(pupil_param) == 5:
        iris_r = int((iris_param[2] + iris_param[3]) / 2)
        pupil_r = int((pupil_param[2] + pupil_param[3]) / 2)
    else:
        iris_r, pupil_r = 0, 0

    if iris_r == 0 or pupil_r == 0:
        return -1
    return pupil_r / iris_r * 100


def gray_level_spread(img, mask):
    """Return the entropy (in bits) of gray levels 0..255 under ``mask``.

    Pixels where ``mask != True`` are ignored. ``img`` is not modified.
    Returns 0.0 when the mask selects no pixels (the previous version
    divided by zero and returned NaN).

    ``np.int`` was removed in NumPy 1.24, so the cast now uses ``np.int64``.
    """
    levels = np.asarray(img).astype(np.int64)  # astype copies the input
    levels[mask != True] = -1  # noqa: E712 - keeps the original mask rule
    usable_pix_num = np.sum(mask)
    if usable_pix_num == 0:
        return 0.0

    # One histogram pass instead of 256 full-image comparisons.
    in_range = levels[(levels >= 0) & (levels <= 255)]
    counts = np.bincount(in_range, minlength=256)

    ent = 0.0
    for count in counts:
        if count:
            p = count / usable_pix_num
            ent -= p * np.log2(p)
    return ent


def usable_area(mask, iris_param, pupil_param):
    """Return the usable iris area as a percentage of the iris annulus.

    The annulus area is ``pi * (iris_param[2]**2 - pupil_param[2]**2)`` and
    the usable area is the sum of ``mask``.
    """
    # NOTE(review): only element 2 of each parameter list is used, even when
    # a 5-element (ellipse) description is passed - confirm this is intended.
    usable_pixels = mask.sum()
    annulus_area = np.pi * (iris_param[2] ** 2 - pupil_param[2] ** 2)
    return usable_pixels / annulus_area * 100


def iris_sclera_contrast(img, mask, iris_param, pupil_param):
    """Placeholder: not implemented, always returns ``None``."""
    pass


def iris_pupil_contrast(img, mask, iris_param, pupil_param):
    """Placeholder: not implemented, always returns ``None``."""
    pass
def ini_reader(filepath):
    """Read iris and pupil parameters from a segmentation ``.ini`` file.

    The file layout is recognised by its number of lines:

    * 15 lines - ellipse: iris values on lines 2..6, pupil on lines 10..14
    * 11 lines - circle: iris values on lines 2..4, pupil on lines 8..10

    Each value line has the form ``key=value`` and the value is parsed as
    ``float``.

    Returns:
        ``(flag, iris_param, pupil_param)`` where ``flag`` is ``'ellipse'``
        or ``'circle'``; all three are ``None`` for any other line count.
    """
    with open(filepath, 'r') as f:
        data = [x.strip() for x in f.readlines()]

    layouts = {
        15: ('ellipse', slice(2, 7), slice(10, 15)),
        11: ('circle', slice(2, 5), slice(8, 11)),
    }
    layout = layouts.get(len(data))
    if layout is None:
        return None, None, None

    flag, iris_rows, pupil_rows = layout
    iris_param = [float(x.split('=')[1]) for x in data[iris_rows]]
    pupil_param = [float(x.split('=')[1]) for x in data[pupil_rows]]
    return flag, iris_param, pupil_param


def _write_quality_file(list_path, image_fmt, mask_fmt, param_fmt, out_path):
    """Compute the quality factors of every listed sample and save them.

    ``list_path`` names a text file with one image filename per line.
    ``image_fmt``, ``mask_fmt`` and ``param_fmt`` are ``str.format``
    templates that receive the filename without its extension. One line per
    successfully processed sample is written to ``out_path``.
    """
    from tqdm import tqdm  # only needed when the module runs as a script

    with open(list_path, 'r') as f:
        # Skip blank lines so they do not become bogus sample names.
        namelist = [x.strip().split('.')[0] for x in f if x.strip()]

    quality_map = []
    for filename in tqdm(namelist, ncols=79, ascii=True):
        # Best effort: a broken sample is reported and skipped so that one
        # bad file does not abort the whole export.
        try:
            img = cv2.imread(image_fmt.format(filename), 0)
            mask = cv2.imread(mask_fmt.format(filename), 0)
            if img is None or mask is None:
                # cv2.imread signals failure by returning None.
                raise IOError('cannot read image or mask')
            if mask.shape != img.shape:
                mask = cv2.resize(mask, (img.shape[1], img.shape[0]))
            # Builtin ``bool``: the ``np.bool`` alias does not exist in
            # NumPy 1.24-1.26 and made every sample fail silently.
            mask = mask.astype(bool)
            _, iris_param, pupil_param = ini_reader(
                param_fmt.format(filename))

            fm = sharpness(img)
            ir = iris_size(iris_param)
            dr = dilation(iris_param, pupil_param)
            gls = gray_level_spread(img, mask)
            uar = usable_area(mask, iris_param, pupil_param)
            quality_map.append((filename + '.bmp', fm, ir, dr, gls, uar))
        except Exception as e:
            print('{}: {}'.format(filename, e))

    with open(out_path, 'w') as f:
        for line in quality_map:
            f.write('{}, {} {} {} {} {}\n'.format(*line))


if __name__ == "__main__":
    _write_quality_file(
        'data/cx2/train.txt',
        'data/cx2/Image/{}.bmp',
        'data/cx2/Result/Mask/{}.png',
        'data/cx2/Result/seg_param/{}.ini',
        'cx2_train_quality.txt')

    _write_quality_file(
        'data/cx1/train.txt',
        'data/cx1/Image/{}.bmp',
        'data/cx1/Mask/{}.png',
        'data/cx1/seg_param/{}.ini',
        'cx1_train_quality.txt')
def _forward_differences(img):
    """Return ``(img_x, img_y)`` forward differences of ``img``.

    The last column of ``img_x`` and the last row of ``img_y`` keep the
    original pixel values, exactly as the previous inline code did.
    """
    img_x = img.copy()
    img_y = img.copy()
    img_y[:-1, :] = np.diff(img, axis=0)
    img_x[:, :-1] = np.diff(img, axis=1)
    return img_x, img_y


def fmeasure(img, measure='GRAS', roi=None):
    """Compute a focus measure of ``img``.

    Args:
        img: 2-D image array.
        measure: one of ``'GRAT'``, ``'GLVN'``, ``'GRAE'``, ``'GRAS'``,
            ``'LAPV'``, ``'LAPM'``, ``'TENG'``, ``'TENV'``, ``'SFRQ'``,
            ``'ISO'``.
        roi: optional ``(x, y, width, height)`` crop applied before
            measuring.

    Returns:
        The focus score (an ``int`` for ``'ISO'``, a float otherwise).

    Raises:
        ValueError: if ``measure`` is not a known name.

    Changes from the previous version: ``TENG``/``TENV`` now use the
    vertical Sobel derivative for the y component (it duplicated the x
    derivative); ``LAPM`` uses a real row kernel and its transpose
    (transposing a 1-D array was a no-op, so both passes were identical);
    filter kernels are float64; ``GRAT`` returns 0.0 instead of NaN when no
    pixel has a non-zero gradient; the ``ValueError`` names the bad measure.
    """
    if roi is not None:
        img = img[roi[1]:roi[1] + roi[3], roi[0]:roi[0] + roi[2]]

    img = img.astype(np.double)

    if measure == 'GRAT':
        # >5.5
        threshold = 0
        img_x, img_y = _forward_differences(img)
        fm = np.maximum(np.abs(img_x), np.abs(img_y))
        fm[fm < threshold] = 0
        active = np.sum(fm != 0)
        fm = np.sum(fm / active) if active else 0.0

    elif measure == 'GLVN':
        # < 17
        fm = np.std(img) ** 2 / np.mean(img)

    elif measure == 'GRAE':
        img_x, img_y = _forward_differences(img)
        fm = np.mean(img_x ** 2 + img_y ** 2)

    elif measure == 'GRAS':
        # >20.5
        img_x = np.diff(img, axis=1)
        img_x[img_x < 0] = 0
        fm = np.mean(img_x ** 2)

    elif measure == 'LAPV':
        # Variance of Laplacian (LAP4): measures the amount of edges.
        fm = np.std(cv2.Laplacian(img, cv2.CV_64F)) ** 2

    elif measure == 'LAPM':
        # Modified Laplacian (LAP2): horizontal pass with the row kernel,
        # vertical pass with its transpose.
        kernel = np.array([[-1, 2, -1]], dtype=np.float64)
        laplacian_x = np.abs(cv2.filter2D(img, -1, kernel))
        laplacian_y = np.abs(cv2.filter2D(img, -1, kernel.T))
        fm = np.mean(laplacian_x + laplacian_y)

    elif measure == 'TENG':
        # Tenengrad: mean squared Sobel gradient magnitude.
        grad_x = cv2.Sobel(img, cv2.CV_64F, 1, 0)
        grad_y = cv2.Sobel(img, cv2.CV_64F, 0, 1)
        fm = np.mean(grad_x ** 2 + grad_y ** 2)

    elif measure == 'TENV':
        # Variance of the squared Sobel gradient magnitude.
        grad_x = cv2.Sobel(img, cv2.CV_64F, 1, 0)
        grad_y = cv2.Sobel(img, cv2.CV_64F, 0, 1)
        fm = np.std(grad_x ** 2 + grad_y ** 2) ** 2

    elif measure == 'SFRQ':
        img_x, img_y = _forward_differences(img)
        fm = np.mean(np.sqrt(img_x ** 2 + img_y ** 2))

    elif measure == 'ISO':
        kernel = np.array([
            [0, 1, 1, 2, 2, 2, 1, 1, 0],
            [1, 2, 4, 5, 5, 5, 4, 2, 1],
            [1, 4, 5, 3, 0, 3, 5, 4, 1],
            [2, 5, 3, -12, -24, -12, 3, 5, 2],
            [2, 5, 0, -24, -40, -24, 0, 5, 2],
            [2, 5, 3, -12, -24, -12, 3, 5, 2],
            [1, 4, 5, 3, 0, 3, 5, 4, 1],
            [1, 2, 4, 5, 5, 5, 4, 2, 1],
            [0, 1, 1, 2, 2, 2, 1, 1, 0],
        ], dtype=np.float64)
        fm = cv2.filter2D(img, -1, kernel, borderType=cv2.BORDER_REPLICATE)
        fm = np.mean(fm * fm)
        fm = int(np.round(fm))

    else:
        raise ValueError('unknown focus measure: {!r}'.format(measure))

    return fm


if __name__ == "__main__":
    pass