├── .gitignore
├── LICENSE
├── README.md
├── deep-learning.md
├── resources
│   ├── bw.png
│   ├── contrastive_loss_function.png
│   ├── dataset.png
│   ├── match.png
│   ├── network_structure.png
│   └── siamese_network.jpg
├── result
│   ├── 0.jpg
│   ├── 1.jpg
│   ├── 2.jpg
│   ├── 3.jpg
│   ├── PR-curve.png
│   ├── confusion_mat.png
│   ├── confusion_matrix_tradition.png
│   ├── main.py
│   └── pr-curve_tradition.png
├── src
│   ├── deep_learning
│   │   ├── criterion.py
│   │   ├── data.py
│   │   ├── evluation.py
│   │   ├── main.py
│   │   ├── model.pth
│   │   ├── model.py
│   │   ├── train.py
│   │   └── utils.py
│   └── traditional_method
│       ├── evalutaion.py
│       ├── frequest.py
│       ├── image_enhance.py
│       ├── main.py
│       ├── ridge_filter.py
│       ├── ridge_freq.py
│       ├── ridge_orient.py
│       └── ridge_segment.py
└── traditional-method.md
/.gitignore:
--------------------------------------------------------------------------------
1 | .idea/
2 | data/
3 | *.pkl
4 | *.pyc
5 | *.txt
6 | *.docx
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Fingerprint Verification
2 | ## 1. Introduction
3 | Fingerprint recognition is the task of identifying human fingerprints, and it can be understood in two ways:
4 |
5 | 1. Since the labels are discrete, fingerprint recognition can be treated as a classification task, using one-hot encoding and a cross-entropy loss. This formulation targets a closed-set fingerprint system: once the model is trained, no new subjects can be added.
6 | 2. Fingerprint recognition can be treated as an image retrieval task: given two fingerprint images (one of which is a query image), output a similarity measure between them. This is the common formulation, since subjects can be added or removed at any time; it underlies fingerprint verification systems and fingerprint retrieval systems.
7 |
8 | This repository implements the second formulation, fingerprint verification, with both [deep learning](deep-learning.md) and a [traditional method](traditional-method.md); in both cases the final decision reduces to thresholding a distance score, as sketched below.
9 |
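A minimal sketch of that shared decision rule (the `embed` function below is a trivial stand-in, not this repository's model; in practice it is replaced by the Siamese CNN embedding of [deep-learning.md](deep-learning.md) or by the ORB-descriptor matching of [traditional-method.md](traditional-method.md), and the threshold is tuned on validation data):

```python
import numpy as np

THRESHOLD = 1.0  # hypothetical distance threshold; tuned on a validation set in practice


def embed(image: np.ndarray) -> np.ndarray:
    # Trivial stand-in embedding: flatten and L2-normalise the image.
    v = image.astype(np.float32).ravel()
    return v / (np.linalg.norm(v) + 1e-8)


def verify(img_query: np.ndarray, img_ref: np.ndarray) -> bool:
    # Declare "same finger" if the embeddings are closer than the threshold.
    distance = np.linalg.norm(embed(img_query) - embed(img_ref))
    return distance < THRESHOLD


# Toy usage with random arrays; real inputs are 96 x 96 grayscale fingerprint images.
a = np.random.rand(96, 96)
print(verify(a, a))  # True: identical images have distance 0
```
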
10 | ## 二、数据集
11 | This experiment uses the Sokoto Coventry fingerprint dataset (SOCO-Fing) [1], released in 2018. The original part of SOCO-Fing consists of 6,000 fingerprint images from 600 African subjects (4,000 for training, 1,000 for validation, 1,000 for testing). In addition to the original images, the dataset provides synthetically altered versions of each image, including Z-cut, obliteration and central rotation (shown below). These alterations divide the task into three difficulty levels (easy, medium, hard); in total the dataset contains 55,273 fingerprint images.
12 |
13 | 
14 |
15 |
16 | SOCO-Fing dataset: [https://www.kaggle.com/ruizgara/socofing/home](https://www.kaggle.com/ruizgara/socofing/home)
17 |
18 | ## 3. References
19 | [1] Shehu Y. I., Ruiz-Garcia A., Palade V., et al. Sokoto Coventry Fingerprint Dataset. 2018.
20 |
--------------------------------------------------------------------------------
/deep-learning.md:
--------------------------------------------------------------------------------
1 | # Fingerprint Recognition: Deep Learning Method
2 | ## 1. Model
3 | This experiment uses a Siamese network. Positive and negative sample pairs are constructed for training, and the network embeds each sample into a metric space so that samples with the same semantics (fingerprints of the same finger of the same person) are pulled close together and samples with different semantics are pushed apart.
4 |
5 | 
6 |
7 |
8 |
9 | The detailed architecture is shown in the figure below: 4 convolution + pooling stages followed by 4 fully connected layers. Every convolution uses a 3 x 3 kernel with s = 1, p = 1 (matching `model.py`), and every convolution stage uses the CBR structure (Conv + BN + ReLU).
10 | Pooling is average pooling (k = 2, s = 2, p = 0). The fully connected layers finally embed the data into a 30-dimensional vector space; a quick shape check is sketched after the figure below.
11 |
12 |
13 |
14 | 
15 |
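As a quick sanity check on the sizes above (a minimal sketch, assuming the 96 x 96 grayscale inputs used in this repo): each 3 x 3 convolution with s = 1, p = 1 preserves the spatial size and each 2 x 2 average pooling halves it, so 96 → 48 → 24 → 12 → 6, and the flattened feature entering the fully connected head has 32 * 6 * 6 = 1152 dimensions.

```python
import torch
import torch.nn as nn


def block(cin, cout):
    # Conv(3x3, s=1, p=1) keeps H x W unchanged; AvgPool(2x2, s=2) halves it.
    return nn.Sequential(
        nn.Conv2d(cin, cout, kernel_size=3, stride=1, padding=1),
        nn.BatchNorm2d(cout),
        nn.ReLU(inplace=True),
        nn.AvgPool2d(kernel_size=2, stride=2),
    )


backbone = nn.Sequential(block(1, 4), block(4, 8), block(8, 16), block(16, 32))
x = torch.randn(1, 1, 96, 96)  # dummy grayscale fingerprint
print(backbone(x).shape)  # torch.Size([1, 32, 6, 6]) -> 32 * 6 * 6 = 1152 flattened features
```
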
16 | ## 2. Loss Function
17 | The loss function is the contrastive loss proposed by Hadsell, Chopra and LeCun in 2006 [1]. Its design idea is as follows (the loss itself is written out right after this list):
18 |
19 | * The distance between similar samples should be as small as possible.
20 | * If the distance between dissimilar samples is smaller than m, they repel each other so that their distance approaches m.
21 |
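In the notation of [1], and matching `criterion.py` in this repo (label y = 0 for a positive pair, y = 1 for a negative pair, margin m = 2.0), the per-pair loss is

    L(d, y) = (1 - y) * d^2 + y * max(0, m - d)^2

where d is the Euclidean distance between the two embeddings.
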
22 | The effect of this loss function can be illustrated by the figure below.
23 |
24 |
25 |
26 | 
27 |
28 |
29 |
30 | ## 3. Results
31 |
32 | The model is evaluated on the test data (9,982 positive pairs, 15,964 negative pairs), giving the PR curve shown below. Precision stays at 100% across all recall levels, which means there is a margin between the distances of positive pairs and those of negative pairs, i.e. the two do not overlap. Taking pairs with distance greater than 1.148 as negative and pairs with distance less than 1.148 as positive separates the positive and negative test pairs perfectly (precision = 100%, recall = 100%).
33 |
34 | 
35 |
36 | With the threshold set to 1.99, the model's confusion matrix is shown below:
37 |
38 | 
39 |
40 |
41 |
42 | This gives Precision = 99.87%, Recall = 100%, Accuracy = 99.87%. Below are the 32 negative pairs that were misclassified as positive pairs:
43 |
44 | 
45 |
46 | 
47 |
48 | 
49 |
50 | 
51 |
52 | False-positive (FP) sample pairs
53 |
54 | The fingerprint pairs misclassified as positive show a high degree of consistency in their spatial distribution, which suggests the Siamese network has indeed captured spatial-distribution information of the fingerprints. In addition, at least one image in each misclassified pair contains noise, which suggests the model can be disturbed by noise; its robustness could be further improved with data augmentation.
55 |
56 | **Note**: this project also implements a fingerprint recognition algorithm based on traditional pattern recognition; see [traditional-method.md](./traditional-method.md).
57 |
58 | ## 4. References
59 |
60 | [1] Raia Hadsell, Sumit Chopra, Yann LeCun. [Dimensionality Reduction by Learning an Invariant Mapping](http://yann.lecun.com/exdb/publis/pdf/hadsell-chopra-lecun-06.pdf). 2006.
61 |
62 |
--------------------------------------------------------------------------------
/resources/bw.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GCaptainNemo/Fingerprint-Verfication/5d19eba32c8f413696dca4f39a0031b868e04914/resources/bw.png
--------------------------------------------------------------------------------
/resources/contrastive_loss_function.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GCaptainNemo/Fingerprint-Verfication/5d19eba32c8f413696dca4f39a0031b868e04914/resources/contrastive_loss_function.png
--------------------------------------------------------------------------------
/resources/dataset.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GCaptainNemo/Fingerprint-Verfication/5d19eba32c8f413696dca4f39a0031b868e04914/resources/dataset.png
--------------------------------------------------------------------------------
/resources/match.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GCaptainNemo/Fingerprint-Verfication/5d19eba32c8f413696dca4f39a0031b868e04914/resources/match.png
--------------------------------------------------------------------------------
/resources/network_structure.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GCaptainNemo/Fingerprint-Verfication/5d19eba32c8f413696dca4f39a0031b868e04914/resources/network_structure.png
--------------------------------------------------------------------------------
/resources/siamese_network.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GCaptainNemo/Fingerprint-Verfication/5d19eba32c8f413696dca4f39a0031b868e04914/resources/siamese_network.jpg
--------------------------------------------------------------------------------
/result/0.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GCaptainNemo/Fingerprint-Verfication/5d19eba32c8f413696dca4f39a0031b868e04914/result/0.jpg
--------------------------------------------------------------------------------
/result/1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GCaptainNemo/Fingerprint-Verfication/5d19eba32c8f413696dca4f39a0031b868e04914/result/1.jpg
--------------------------------------------------------------------------------
/result/2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GCaptainNemo/Fingerprint-Verfication/5d19eba32c8f413696dca4f39a0031b868e04914/result/2.jpg
--------------------------------------------------------------------------------
/result/3.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GCaptainNemo/Fingerprint-Verfication/5d19eba32c8f413696dca4f39a0031b868e04914/result/3.jpg
--------------------------------------------------------------------------------
/result/PR-curve.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GCaptainNemo/Fingerprint-Verfication/5d19eba32c8f413696dca4f39a0031b868e04914/result/PR-curve.png
--------------------------------------------------------------------------------
/result/confusion_mat.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GCaptainNemo/Fingerprint-Verfication/5d19eba32c8f413696dca4f39a0031b868e04914/result/confusion_mat.png
--------------------------------------------------------------------------------
/result/confusion_matrix_tradition.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GCaptainNemo/Fingerprint-Verfication/5d19eba32c8f413696dca4f39a0031b868e04914/result/confusion_matrix_tradition.png
--------------------------------------------------------------------------------
/result/main.py:
--------------------------------------------------------------------------------
1 |
2 | import matplotlib.pyplot as plt
3 |
4 | # parse the training log (one "loss:" line every 50 iterations, cf. src/deep_learning/train.py)
5 | with open("./result.txt", "r") as f:
6 | lines_lst = f.readlines()
7 |
8 | loss_lst = []
9 | iterations = []
10 | for i, line in enumerate(lines_lst):
11 | loss = float(line.strip().split("loss:")[-1])
12 | loss_lst.append(loss)
13 | iterations.append(i * 50)  # the loss is averaged and logged every 50 iterations
14 |
15 |
16 | plt.figure(1)
17 | plt.plot(iterations, loss_lst)
18 | plt.title("Loss-Iteration Curve")
19 | plt.xlabel("Iteration")
20 | plt.ylabel("Loss")
21 | plt.show()
22 |
23 |
--------------------------------------------------------------------------------
/result/pr-curve_tradition.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GCaptainNemo/Fingerprint-Verfication/5d19eba32c8f413696dca4f39a0031b868e04914/result/pr-curve_tradition.png
--------------------------------------------------------------------------------
/src/deep_learning/criterion.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch
3 | import torch.nn.functional as F
4 |
5 |
6 | class Criterion(nn.Module):
7 | """
8 | Contrastive loss function.
9 | Based on: http://yann.lecun.com/exdb/publis/pdf/hadsell-chopra-lecun-06.pdf
10 | """
11 |
12 | def __init__(self, margin=2.0):
13 | super(Criterion, self).__init__()
14 | self.margin = margin
15 |
16 | def forward(self, output1, output2, label):
17 | # N x 1 distance
18 | euclidean_distance = F.pairwise_distance(output1, output2, keepdim=True)
19 | # positive pair: label = 0, the loss pulls the embeddings together (distance -> 0)
20 | # negative pair: label = 1, the loss pushes the embeddings apart
21 | # until their distance reaches at least self.margin
22 | loss = torch.mean((1 - label) * torch.pow(euclidean_distance, 2) +
23 | label * torch.pow(torch.clamp(self.margin - euclidean_distance, min=0.0), 2))
24 | # cosine_similarity = torch.cosine_similarity(output1, output2, dim=1)
25 | # loss = torch.mean((cosine_similarity - label) ** 2)
26 | return loss
27 |
28 |
29 |
30 |
--------------------------------------------------------------------------------
/src/deep_learning/data.py:
--------------------------------------------------------------------------------
1 | import torch.utils.data.dataset as dataset
2 | import pickle
3 | import numpy as np
4 | import torch
5 | import torchvision.transforms as transforms
6 | import cv2
7 |
8 | # training_sample_count = 4000
9 | # validation_sample_count = 1000
10 | # test_sample_count = 1000
11 |
12 |
13 | class DataSampler:
14 | def __init__(self, train_num, batch_size, pkl_path):
15 | self.train_num = train_num
16 | self.load_dataset(pkl_path)
17 | self.batch_size = batch_size
18 |
19 | def load_dataset(self, pkl_path):
20 | with open(pkl_path, "rb") as f:
21 | self.files_lst = pickle.load(f)
22 |
23 | def sample(self):
24 | """
25 | sample a batch: self.batch_size positive pairs followed by self.batch_size negative pairs
26 | """
27 | # ############################################################################
28 | # sample postive samples
29 | # ############################################################################
30 | # positive pairs (same finger of the same person) get label 0
31 | pos_label = torch.zeros([self.batch_size, 1], dtype=torch.float32)
32 |
33 | pos_index = np.random.randint(0, self.train_num, [self.batch_size, 1])
34 | each_index = np.random.randint(0, len(self.files_lst[pos_index[0, 0]]), [1, 2])
35 |
36 | address_1 = self.files_lst[pos_index[0, 0]][each_index[0, 0]]
37 | # print(address_1)
38 | siamese_1_img_batch = torch.unsqueeze(transforms.ToTensor()(
39 | cv2.imread(address_1, cv2.IMREAD_GRAYSCALE))
40 | , dim=0)
41 |
42 | address_2 = self.files_lst[pos_index[0, 0]][each_index[0, 1]]
43 | siamese_2_img_batch = torch.unsqueeze(transforms.ToTensor()(
44 | cv2.imread(address_2, cv2.IMREAD_GRAYSCALE))
45 | , dim=0)
46 |
47 | for i in range(1, self.batch_size):
48 | each_index = np.random.randint(0, len(self.files_lst[pos_index[i, 0]]), [1, 2])
49 | address_1 = self.files_lst[pos_index[i, 0]][each_index[0, 0]]
50 | siamese_1_pos = torch.unsqueeze(transforms.ToTensor()(
51 | cv2.imread(address_1, cv2.IMREAD_GRAYSCALE))
52 | , dim=0)
53 | # ToTensor: normalized to [0, 1] division 255
54 | address_2 = self.files_lst[pos_index[i, 0]][each_index[0, 1]]
55 | siamese_2_pos = torch.unsqueeze(transforms.ToTensor()(
56 | cv2.imread(address_2, cv2.IMREAD_GRAYSCALE))
57 | , dim=0)
58 | # print("address_1 = ", address_1)
59 | # print("address_2 = ", address_2)
60 |
61 | # print(siamese_1_pos.shape)
62 | # print(siamese_2_pos.shape)
63 |
64 | siamese_1_img_batch = torch.cat([siamese_1_img_batch, siamese_1_pos], dim=0)
65 | siamese_2_img_batch = torch.cat([siamese_2_img_batch, siamese_2_pos], dim=0)
66 |
67 | # ############################################################################
68 | # sample negative samples
69 | # ############################################################################
70 | # negative pairs (different fingers or different persons) get label 1
71 | neg_label = torch.ones([self.batch_size, 1], dtype=torch.float32)
72 | label_tensor = torch.cat([pos_label, neg_label], dim=0)
73 | for i in range(self.batch_size):
74 | while True:
75 | neg_index = np.random.randint(0, self.train_num, [1, 2])
76 | if neg_index[0, 0] != neg_index[0, 1]:
77 | break
78 | index_1 = np.random.randint(0, len(self.files_lst[neg_index[0, 0]]), [1])
79 | address_1 = self.files_lst[neg_index[0, 0]][index_1[0]]
80 | siamese_1_neg = torch.unsqueeze(transforms.ToTensor()(
81 | cv2.imread(address_1, cv2.IMREAD_GRAYSCALE))
82 | , dim=0)
83 | index_2 = np.random.randint(0, len(self.files_lst[neg_index[0, 1]]), [1])
84 |
85 | address_2 = self.files_lst[neg_index[0, 1]][index_2[0]]
86 | siamese_2_neg = torch.unsqueeze(transforms.ToTensor()(
87 | cv2.imread(address_2, cv2.IMREAD_GRAYSCALE))
88 | , dim=0)
89 | # print("address_1 = ", address_1)
90 | # print("address_2 = ", address_2)
91 |
92 | # print(siamese_1_neg.shape)
93 | # print(siamese_2_neg.shape)
94 |
95 | siamese_1_img_batch = torch.cat([siamese_1_img_batch, siamese_1_neg], dim=0)
96 | siamese_2_img_batch = torch.cat([siamese_2_img_batch, siamese_2_neg], dim=0)
97 |
98 | return siamese_1_img_batch, siamese_2_img_batch, label_tensor
99 |
100 |
101 |
102 |
103 |
104 |
105 |
106 |
--------------------------------------------------------------------------------
/src/deep_learning/evluation.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import pickle
3 | from torch.utils.data import Dataset
4 | from torch.utils.data import DataLoader
5 | import cv2
6 | import torchvision.transforms as transforms
7 | import torch.nn.functional as functional
8 | import matplotlib.pyplot as plt
9 | import numpy as np
10 | from sklearn.metrics import confusion_matrix
11 | import itertools
12 | from torchvision import transforms
13 |
14 |
15 | class InferenceDataset(Dataset):
16 | def __init__(self, inference_txt_address):
17 | super(InferenceDataset, self).__init__()
18 | self.load_txt(inference_txt_address)
19 |
20 | def load_txt(self, inference_txt_address):
21 | with open(inference_txt_address, "r+") as f:
22 | self.content = f.readlines()
23 |
24 | def __len__(self):
25 | return len(self.content)
26 |
27 | def __getitem__(self, item):
28 | content = self.content[item].strip()
29 | lst = content.split(" ")
30 | address1 = lst[0]
31 | # print("address1 = ", address1)
32 | img1 = transforms.ToTensor()(cv2.imread(address1, cv2.IMREAD_GRAYSCALE))
33 | address2 = lst[1]
34 | img2 = transforms.ToTensor()(cv2.imread(address2, cv2.IMREAD_GRAYSCALE))
35 | label = torch.tensor([int(lst[2])], dtype=torch.float32)
36 | print("label = ", lst[2])
37 | return img1, img2, label
38 |
39 |
40 | def confusion_matrix(preds, labels, conf_matrix):
41 | # confusion matrix
42 | # yaxis - gt; xaxis - pred
43 | for gt, pred in zip(labels, preds):
44 | conf_matrix[int(round(gt.item())), int(round(pred.item()))] += 1
45 | return conf_matrix
46 |
47 |
48 | # plot the confusion matrix
49 | def plot_confusion_matrix(cm, classes, normalize=False, title='Confusion matrix', cmap=plt.cm.Blues):
50 | """
51 | plots the confusion matrix.
52 | Normalization can be applied by setting `normalize=True`.
53 | Input
54 | - cm : confusion matrix
55 | - classes : the class label of each row/column of the confusion matrix
56 | - normalize : True:percentage, False:Num
57 | """
58 | if normalize:
59 | cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
60 | print("Normalized confusion matrix")
61 | else:
62 | print('Confusion matrix, without normalization')
63 | print(cm)
64 | plt.imshow(cm, interpolation='nearest', cmap=cmap)
65 | plt.title(title)
66 | plt.colorbar()
67 | tick_marks = np.arange(len(classes))
68 | plt.xticks(tick_marks, classes)
69 | plt.yticks(tick_marks, classes)
70 |
71 | plt.axis("equal")
72 | ax = plt.gca()  # get the current axes
73 | left, right = plt.xlim()  # get the x-axis limits
74 | ax.spines['left'].set_position(('data', left))
75 | ax.spines['right'].set_position(('data', right))
76 | for edge_i in ['top', 'bottom', 'right', 'left']:
77 | ax.spines[edge_i].set_edgecolor("white")
78 | thresh = cm.max() / 2.
79 | for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
80 | num = '{:.2f}'.format(cm[i, j]) if normalize else int(cm[i, j])
81 | plt.text(j, i, num,
82 | verticalalignment='center',
83 | horizontalalignment="center",
84 | color="white" if cm[i, j] > thresh else "black")  # compare the numeric value, not the formatted string
85 | plt.tight_layout()
86 | plt.ylabel('Actual label')
87 | plt.xlabel('Predict label')
88 | plt.show()
89 |
90 | def plot_wrong_img_pairs(address):
91 | with open(address, "rb+") as f:
92 | wrong_lst = pickle.load(f)
93 | print(len(wrong_lst))
94 | print(type(wrong_lst[0][0]))
95 | pil = transforms.ToPILImage()
96 | for i, img_pair in enumerate(wrong_lst):
97 | if i % 8 == 0:
98 | imgs = np.vstack([img_pair[0], img_pair[1]])
99 | else:
100 | img = np.vstack([img_pair[0], img_pair[1]])
101 | imgs = np.hstack([imgs, img])
102 | if i % 8 == 7:
103 | # imgs = (imgs * 255).astype(int)
104 | print(imgs)
105 | cv2.imshow("img", imgs)
106 | # imgs = pil(imgs)
107 | cv2.waitKey(0)
108 | # if imgs.mode == "F":
109 | # imgs = imgs.convert('L')
110 | # imgs.save("../../result/{}.jpg".format(i // 8))
111 | cv2.imwrite("../../result/{}.jpg".format(i // 8), imgs * 255)
112 |
113 |
114 | if __name__ == "__main__":
115 | siamese_nn = torch.load("model.pth")
116 | siamese_nn = siamese_nn.eval()
117 | inference_dataset = InferenceDataset("../pos_samples.txt")
118 | batch_size = 16
119 | dataloader = DataLoader(inference_dataset, batch_size=batch_size, shuffle=False)
120 | epsilon = 0.01
121 | margin = 2.0
122 | true_num = 0
123 | total_num = 0
124 | # construct a confusion matrix
125 | conf_matrix = torch.zeros(2, 2)
126 | wrong_img_pairs = []
127 | to_pil = transforms.ToPILImage()
128 | dis_label_lst = []
129 | for i, data in enumerate(dataloader):
130 | img0, img1, label = data
131 | img0, img1, label = img0.cuda(), img1.cuda(), label.cuda()  # move the data to the GPU
132 | dis_label = []
133 | # dim = batchsize x 30 (embedding space dimension)
134 | output1, output2 = siamese_nn(img0, img1)
135 | euclidean_dis = functional.pairwise_distance(output1, output2, keepdim=True)
136 | dis_label.append(euclidean_dis.detach().cpu())
137 | dis_label.append(label.detach().cpu())
138 | dis_label_lst.append(dis_label)
139 | # set all are positive samples
140 | prediction = torch.zeros([img0.shape[0], 1], dtype=torch.int).cuda()
141 | # distance > margin - epsilon: negative samples, label=1
142 | # prediction[torch.where(euclidean_dis > margin )] = 1
143 | prediction[torch.where(euclidean_dis > 1.148 )] = 1
144 |
145 | # print((prediction - label).shape)
146 | true_num += len(torch.where(torch.abs(prediction - label) < epsilon)[0])  # number of correct predictions
147 | wrong_num = len(torch.where(torch.abs(prediction - label) > epsilon)[0])
148 |
149 | if wrong_num > 0:
150 | print("wrong exist ", wrong_num)
151 | lst = []
152 | for i in range(wrong_num):
153 | lst.append(to_pil(img0[torch.where(torch.abs(prediction - label) > epsilon)[0][i], :, :, :].cpu().squeeze(0)))
154 | lst.append(to_pil(img1[torch.where(torch.abs(prediction - label) > epsilon)[0][i], :, :, :].cpu().squeeze(0)))
155 | wrong_img_pairs.append(lst)
156 | # print("true num = ", true_num)
157 | total_num += img0.shape[0]  # total number of evaluated pairs
158 | # print("total_num = ", total_num)
159 | conf_matrix = confusion_matrix(prediction, labels=label, conf_matrix=conf_matrix)
160 |
161 | with open("dis_label.pkl", "wb+") as f:
162 | pickle.dump(dis_label_lst, f)
163 |
164 | print("TF+TN:", true_num, "\n")
165 | print("Total:", total_num, "\n")
166 | print("Accuracy:", true_num / total_num)
167 | # conf_matrix = torch.tensor([[9982, 32],
168 | # [0, 15932]])
169 | plot_confusion_matrix(conf_matrix.numpy(), classes=["positive", "negative"], normalize=False,
170 | title='Confusion Matrix')
171 | #
172 | # with open("wrong_img_pairs.pkl", "wb+") as f:
173 | # pickle.dump(wrong_img_pairs, f)
174 | # plot_wrong_img_pairs("wrong_img_pairs.pkl")
175 |
--------------------------------------------------------------------------------
/src/deep_learning/main.py:
--------------------------------------------------------------------------------
1 | from src.deep_learning.data import DataSampler
2 | from src.deep_learning.model import SiameseNetwork
3 | from src.deep_learning.criterion import Criterion
4 | import torch
5 |
6 | from src.deep_learning.train import Train
7 |
8 | pkl_path = "../../data/files_dump.pkl"
9 | data_sampler = DataSampler(train_num=4000, batch_size=32, pkl_path=pkl_path)
10 | siamese_nn = SiameseNetwork().cuda()
11 | criterion = Criterion()
12 | train = Train(data_sampler=data_sampler, model=siamese_nn, criterion=criterion)
13 | train.train(iterations=50000, lr=3e-4)
14 |
15 |
16 | with open("model.pth", "wb+") as f:
17 | torch.save(siamese_nn, f)
18 |
19 |
20 |
21 |
--------------------------------------------------------------------------------
/src/deep_learning/model.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GCaptainNemo/Fingerprint-Verfication/5d19eba32c8f413696dca4f39a0031b868e04914/src/deep_learning/model.pth
--------------------------------------------------------------------------------
/src/deep_learning/model.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 |
3 | # 96 x 96
4 |
5 |
6 | class SiameseNetwork(nn.Module):
7 | def __init__(self):
8 | super().__init__()
9 | self.cnn1 = nn.Sequential(
10 | nn.Conv2d(1, 4, kernel_size=3, stride=1, padding=1),
11 | nn.BatchNorm2d(4),
12 | nn.ReLU(inplace=True),  # inplace=True modifies the tensor in place to save memory
13 | nn.AvgPool2d(kernel_size=2, stride=2, padding=0),
14 | # 4 x 48 x 48
15 |
16 | nn.Conv2d(4, 8, kernel_size=3, stride=1, padding=1),
17 | nn.BatchNorm2d(8),
18 | nn.ReLU(inplace=True),
19 | nn.AvgPool2d(kernel_size=2, stride=2, padding=0),
20 | # 8 x 24 x 24
21 |
22 | nn.Conv2d(8, 16, kernel_size=3, stride=1, padding=1),
23 | nn.BatchNorm2d(16),
24 | nn.ReLU(inplace=True),
25 | nn.AvgPool2d(kernel_size=2, stride=2, padding=0),
26 | # 16 x 12 x 12
27 |
28 | nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1),
29 | nn.BatchNorm2d(32),
30 | nn.ReLU(inplace=True),
31 | nn.AvgPool2d(kernel_size=2, stride=2, padding=0),
32 | # 32 x 6 x 6
33 | )
34 |
35 | self.fc1 = nn.Sequential(
36 | nn.Linear(32 * 6 * 6, 500),
37 | nn.ReLU(inplace=True),
38 | nn.Linear(500, 500),
39 | nn.ReLU(inplace=True),
40 | nn.Linear(500, 100),
41 | nn.ReLU(inplace=True),
42 | nn.Linear(100, 30)
43 | )
44 |
45 |
46 | def forward_once(self, x):
47 | output = self.cnn1(x)
48 | # print("output.shape = ", output.shape)
49 | # reshape N x d feature
50 | output = output.view(output.size()[0], -1)
51 | output = self.fc1(output)
52 | return output
53 |
54 | def forward(self, input1, input2):
55 | output1 = self.forward_once(input1)
56 | output2 = self.forward_once(input2)
57 | return output1, output2
58 |
59 |
60 |
61 |
62 |
63 |
64 |
65 |
66 |
--------------------------------------------------------------------------------
/src/deep_learning/train.py:
--------------------------------------------------------------------------------
1 | import torch.optim as optim
2 |
3 |
4 | class Train:
5 | def __init__(self, data_sampler, model, criterion):
6 | self.data_sampler = data_sampler
7 | self.criterion = criterion
8 | self.model = model
9 |
10 | def train(self, iterations, lr):
11 | optimizer = optim.Adam(self.model.parameters(), lr)
12 | avg_loss = 0
13 | for e in range(iterations):
14 | siamese_1, siamese_2, label = self.data_sampler.sample()
15 | siamese_1, siamese_2, label = siamese_1.cuda(), siamese_2.cuda(), label.cuda()
16 | optimizer.zero_grad()
17 | output1, output2 = self.model(siamese_1, siamese_2)
18 | loss = self.criterion(output1, output2, label)
19 |
20 | avg_loss = avg_loss + float(loss.item())
21 | loss.backward()
22 | optimizer.step()
23 | if e % 50 == 49:
24 | loss = avg_loss / 50
25 | print("Step {} - lr {} - loss: {}".format(e, lr, loss))
26 | avg_loss = 0
27 |
28 | # error = self.siamese_nn.loss_func(2 ** 8)
29 | # self.siamese_nn.append(error.detach())
30 |
31 |
32 |
33 |
34 |
--------------------------------------------------------------------------------
/src/deep_learning/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pickle
3 | import cv2
4 | import torch
5 | import torchvision.transforms as transforms
6 | from sklearn.metrics import precision_recall_curve
7 | import matplotlib.pyplot as plt
8 | import numpy as np
9 |
10 | hand_encode_dict = {"Left": 0, "Right": 1}
11 | finger_encode_dict = {"thumb": 0, "index": 1, "middle": 2, "ring": 3, "little": 4}
12 |
13 |
14 | def dump_files(input_path, output_path):
15 | """
16 | output a 6000 x d (600 people x 10 fingers) dimension list
17 | 6000 x d, each dimension is positive sample, otherwise it's negative
18 | """
19 | six_thousand_lst = [[] for _ in range(6000)]
20 | for root, dirs, files in os.walk(input_path):
21 | for file in files:
22 | lst = file.split("_")
23 | index = (int(lst[0]) - 1) * 10 + hand_encode_dict[lst[3]] * 5 + finger_encode_dict[lst[4]]
24 | address = root + "/" + file
25 | print("address = ", address)
26 | six_thousand_lst[index].append(address)
27 |
28 | with open(output_path + "/files_dump.pkl", "wb") as f:
29 | pickle.dump(six_thousand_lst, f)
30 |
31 |
32 | def load_pkl(pkl_dir):
33 | with open(pkl_dir + "/files_dump.pkl", "rb") as f:
34 | files_name_lst = pickle.load(f)
35 | print(files_name_lst[0])
36 | print(max(len(files_name) for files_name in files_name_lst))
37 | print(min(len(files_name) for files_name in files_name_lst))
38 |
39 | img = cv2.imread(files_name_lst[0][0], cv2.IMREAD_GRAYSCALE)
40 | print(transforms.ToTensor()(img).dtype)
41 | print(type(img))
42 | print(img.shape)
43 | cv2.imshow("test", img)
44 | cv2.waitKey(0)
45 |
46 |
47 | def construct_pos_neg_samples(train_num):
48 | with open("../../data/files_dump.pkl", "rb") as f:
49 | files_name_lst = pickle.load(f)
50 | with open("../traditional_method/pos_samples.txt", "w+") as txt:
51 | # pos samples
52 | pos_num = 0
53 | for i in range(train_num):
54 | for j in range(len(files_name_lst[i])):
55 | # self-self is positive samples
56 | for k in range(j, len(files_name_lst[i])):
57 | pos_string = files_name_lst[i][j] + " " + files_name_lst[i][k] + " 1\n"
58 | txt.writelines(pos_string)
59 | pos_num += 1
60 | print("there are ", pos_num, "positive samples")
61 |
62 | # neg samples
63 | neg_num = 0
64 | for i in range(train_num):
65 | for neg_i in range(len(files_name_lst[i])):
66 | for j in range(i + 1, train_num):
67 | for neg_j in range(len(files_name_lst[j])):
68 | neg_num += 1
69 | neg_string = files_name_lst[i][neg_i] + " " + files_name_lst[j][neg_j] + " 0\n"
70 | txt.writelines(neg_string)
71 | print("there are ", neg_num, "negative samples")
72 |
73 |
74 | def construct_pos_neg_samples_test():
75 | with open("../../data/files_dump.pkl", "rb") as f:
76 | files_name_lst = pickle.load(f)
77 | files_name_lst_test = files_name_lst[5000:6000]
78 |
79 | with open("../traditional_method/pos_samples.txt", "w+") as txt:
80 | # pos samples
81 | pos_num = 0
82 | total_num = 0
83 | for i in range(len(files_name_lst_test)):
84 | total_num += len(files_name_lst_test[i])
85 | for j in range(len(files_name_lst_test[i])):
86 | # self-self is positive samples
87 | for k in range(j, len(files_name_lst_test[i])):
88 | pos_string = files_name_lst_test[i][j] + " " + files_name_lst_test[i][k] + " 0\n"
89 | txt.writelines(pos_string)
90 | pos_num += 1
91 | print("total num = ", total_num)
92 | print("there are ", pos_num, "positive samples")
93 |
94 | # neg samples
95 | neg_num = 0
96 | for i in range(1000):
97 | for neg_i in range(len(files_name_lst_test[i])):
98 | for j in range(i + 1, 1000):
99 | for neg_j in range(len(files_name_lst_test[j])):
100 | neg_num += 1
101 | neg_string = files_name_lst_test[i][neg_i] + " " + files_name_lst_test[j][neg_j] + " 1\n"
102 | txt.writelines(neg_string)
103 | if neg_num > 10000:
104 | break
105 | print("there are ", neg_num, "negative samples")
106 |
107 |
108 | def pre_process(file_path):
109 | # Load the image grayscale
110 | img = cv2.imread(file_path, cv2.IMREAD_GRAYSCALE)
111 | # Get rid of the excess pixels
112 | # img = img[2:-4, 2:-4]
113 | # all the images are of the same size (96 * 96)
114 | img = cv2.resize(img, (96, 96))
115 | return img
116 |
117 |
118 | def traversal_total_dir(input_path, output_path):
119 | for root, dirs, files in os.walk(input_path):
120 | for name in files:
121 | address = os.path.join(root, name)
122 | print("address = ", address)
123 | out_img = pre_process(address)
124 | cv2.imwrite(output_path + name, out_img)
125 |
126 |
127 | def plot_precision_recall_curve(address):
128 | with open(address, "rb+") as f:
129 | dis_label_lst = pickle.load(f)
130 | print(len(dis_label_lst))
131 | dis_total = dis_label_lst[0][0]
132 | label_total = dis_label_lst[0][1]  # labels of the first batch (index 1), not the distances
133 |
134 | for i, dis_label in enumerate(dis_label_lst):
135 | if i == 0:
136 | continue
137 | dis = dis_label[0]
138 | label = dis_label[1]
139 | dis_total = torch.cat([dis_total, dis], dim=0)
140 | label_total = torch.cat([label_total, label], dim=0)
141 | label_total = label_total.numpy().astype(int)
142 | dis_total = dis_total.numpy()
143 | max_dis = np.max(dis_total)
144 | # dis_total = 1 - dis_total / max_dis
145 | precision, recall, thresh = precision_recall_curve(label_total, dis_total, pos_label=1)
146 | print(precision)
147 | print(recall)
148 | print(thresh)
149 | # print((1 - thresh) * max_dis)
150 | plt.figure(1)  # create figure 1
151 | plt.title('Precision/Recall Curve') # give plot a title
152 | plt.xlabel('Recall') # make axis labels
153 | plt.ylabel('Precision')
154 | plt.plot(recall, precision)
155 | plt.show()
156 |
157 |
158 | if __name__ == "__main__":
159 | # dump_files("../../data/process", "../../data/")
160 | # load_pkl("../../data")
161 | # construct_pos_neg_samples_test()
162 | plot_precision_recall_curve("dis_label.pkl")
--------------------------------------------------------------------------------
/src/traditional_method/evalutaion.py:
--------------------------------------------------------------------------------
1 | from sklearn.metrics import precision_recall_curve
2 | import pickle
3 | import numpy as np
4 | import matplotlib.pyplot as plt
5 | import itertools
6 |
7 |
8 | def plot_precision_recall_curve(dst_lst_address):
9 | with open(dst_lst_address, "rb+") as f:
10 | dis_total = pickle.load(f)
11 | print(len(dis_total))
12 | with open("pos_samples.txt", "r") as f:
13 | sample_lst = f.readlines()
14 | label_total = []
15 | for sample in sample_lst:
16 | label_total.append(int(sample.strip().split(" ")[-1]))
17 |
18 | label_total = np.array(label_total)
19 | dis_total = -np.array(dis_total)  # negate so that a larger score indicates a positive pair
20 | precision, recall, thresh = precision_recall_curve(label_total, dis_total, pos_label=0)
21 | print("precision = ", precision)
22 | print("recall = ", recall)
23 | print("thresh = ", thresh)
24 | for i, value in enumerate(recall):
25 | if recall[i] > 0.85:
26 | print(thresh[i])
27 | break
28 | # print((1 - thresh) * max_dis)
29 | plt.figure(1)  # create figure 1
30 | plt.title('Precision/Recall Curve') # give plot a title
31 | plt.xlabel('Recall') # make axis labels
32 | plt.ylabel('Precision')
33 | plt.plot(recall, precision)
34 | plt.show()
35 |
36 |
37 | def cal_confusion_matrix(predicted_label_address):
38 | with open("pos_samples.txt", "r") as f:
39 | sample_lst = f.readlines()
40 | label_total = []
41 | for sample in sample_lst:
42 | label_total.append(int(sample.strip().split(" ")[-1]))
43 | actual_label = np.array(label_total)
44 | conf_matrix = np.zeros([2, 2])
45 | with open(predicted_label_address, "rb+") as f:
46 | predicted_label_address = np.array(pickle.load(f))
47 | for i in range(len(predicted_label_address)):
48 | conf_matrix[int(actual_label[i]), int(predicted_label_address[i])] += 1
49 | plot_confusion_matrix(conf_matrix, ["positive", "negative"])
50 |
51 |
52 |
53 |
54 | def plot_confusion_matrix(cm, classes, normalize=False, title='Confusion matrix', cmap=plt.cm.Blues):
55 | """
56 | plots the confusion matrix.
57 | Normalization can be applied by setting `normalize=True`.
58 | Input
59 | - cm : confusion matrix
60 | - classes : the class label of each row/column of the confusion matrix
61 | - normalize : True:percentage, False:Num
62 | """
63 | if normalize:
64 | cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
65 | print("Normalized confusion matrix")
66 | else:
67 | print('Confusion matrix, without normalization')
68 | print(cm)
69 | plt.imshow(cm, interpolation='nearest', cmap=cmap)
70 | plt.title(title)
71 | plt.colorbar()
72 | tick_marks = np.arange(len(classes))
73 | plt.xticks(tick_marks, classes)
74 | plt.yticks(tick_marks, classes)
75 |
76 | plt.axis("equal")
77 | ax = plt.gca()  # get the current axes
78 | left, right = plt.xlim()  # get the x-axis limits
79 | ax.spines['left'].set_position(('data', left))
80 | ax.spines['right'].set_position(('data', right))
81 | for edge_i in ['top', 'bottom', 'right', 'left']:
82 | ax.spines[edge_i].set_edgecolor("white")
83 | thresh = cm.max() / 2.
84 | for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
85 | num = '{:.2f}'.format(cm[i, j]) if normalize else int(cm[i, j])
86 | plt.text(j, i, num,
87 | verticalalignment='center',
88 | horizontalalignment="center",
89 | color="white" if cm[i, j] > thresh else "black")  # compare the numeric value, not the formatted string
90 | plt.tight_layout()
91 | plt.ylabel('Actual label')
92 | plt.xlabel('Predict label')
93 | plt.show()
94 |
95 |
96 | if __name__ == "__main__":
97 | plot_precision_recall_curve("dist_lst.pkl")
98 | # cal_confusion_matrix("predicted.pkl")
--------------------------------------------------------------------------------
/src/traditional_method/frequest.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import math
3 | import scipy.ndimage
4 |
5 | def frequest(im,orientim,windsze,minWaveLength,maxWaveLength):
6 | rows,cols = np.shape(im)
7 |
8 | cosorient = np.mean(np.cos(2*orientim))
9 | sinorient = np.mean(np.sin(2*orientim))
10 | orient = math.atan2(sinorient,cosorient)/2
11 |
12 | rotim = scipy.ndimage.rotate(im,orient/np.pi*180 + 90,axes=(1,0),reshape = False,order = 3,mode = 'nearest');
13 |
14 | cropsze = int(np.fix(rows/np.sqrt(2)))
15 | offset = int(np.fix((rows-cropsze)/2))
16 | rotim = rotim[offset:offset+cropsze][:,offset:offset+cropsze]
17 |
18 | proj = np.sum(rotim,axis = 0)
19 | dilation = scipy.ndimage.grey_dilation(proj, windsze,structure=np.ones(windsze));
20 |
21 | temp = np.abs(dilation - proj)
22 |
23 | peak_thresh = 2;
24 |
25 | maxpts = (temp < peak_thresh) & (proj > np.mean(proj));
26 | maxind = np.where(maxpts)
27 |
28 | rows_maxind,cols_maxind = np.shape(maxind)
29 |
30 | if(cols_maxind<2):
31 | freqim = np.zeros(im.shape)
32 | else:
33 | NoOfPeaks = cols_maxind
34 | waveLength = (maxind[0][cols_maxind-1] - maxind[0][0])/(NoOfPeaks - 1)
35 | if waveLength>=minWaveLength and waveLength<=maxWaveLength:
36 | freqim = 1/np.double(waveLength) * np.ones(im.shape);
37 | else:
38 | freqim = np.zeros(im.shape)
39 |
40 | return(freqim)
41 |
--------------------------------------------------------------------------------
/src/traditional_method/image_enhance.py:
--------------------------------------------------------------------------------
1 | from .ridge_segment import ridge_segment
2 | from .ridge_orient import ridge_orient
3 | from .ridge_freq import ridge_freq
4 | from .ridge_filter import ridge_filter
5 |
6 |
7 | def image_enhance(img):
8 | blksze = 16
9 | thresh = 0.1
10 | normim, mask = ridge_segment(img, blksze, thresh)
11 |
12 |
13 | gradientsigma = 1
14 | blocksigma = 7
15 | orientsmoothsigma = 7
16 | orientim = ridge_orient(normim, gradientsigma, blocksigma, orientsmoothsigma)
17 |
18 |
19 | blksze = 38
20 | windsze = 5
21 | minWaveLength = 5
22 | maxWaveLength = 15
23 | freq, medfreq = ridge_freq(normim, mask, orientim, blksze, windsze, minWaveLength, maxWaveLength)
24 |
25 |
26 | freq = medfreq * mask
27 | # print(medfreq)
28 | kx = 0.65
29 | ky = 0.65
30 | newim = ridge_filter(normim, orientim, freq, kx, ky)
31 |
32 | return newim < -3  # threshold the Gabor-filtered image to obtain a binary image
--------------------------------------------------------------------------------
/src/traditional_method/main.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import os
3 | import sys
4 | import numpy
5 | import matplotlib.pyplot as plt
6 | import src.traditional_method.image_enhance as image_enhance
7 | from skimage.morphology import skeletonize, thin
8 | import numpy as np
9 | import pickle
10 |
11 |
12 | address_lst = os.listdir("../../data/orb_pkl/")
13 | name_set = set(address_lst)
14 |
15 |
16 | def removedot(invertThin):
17 | temp0 = numpy.array(invertThin[:])
18 | temp0 = numpy.array(temp0)
19 | temp1 = temp0/255
20 | temp2 = numpy.array(temp1)
21 | temp3 = numpy.array(temp2)
22 | enhanced_img = numpy.array(temp0)
23 | filter0 = numpy.zeros((10,10))
24 | W, H = temp0.shape[:2]
25 | filtersize = 6
26 |
27 | for i in range(W - filtersize):
28 | for j in range(H - filtersize):
29 | filter0 = temp1[i:i + filtersize, j:j + filtersize]
30 |
31 | flag = 0
32 | if sum(filter0[:, 0]) == 0:
33 | flag += 1
34 | if sum(filter0[:, filtersize - 1]) == 0:
35 | flag += 1
36 | if sum(filter0[0, :]) == 0:
37 | flag += 1
38 | if sum(filter0[filtersize - 1, :]) == 0:
39 | flag += 1
40 | if flag > 3:
41 | temp2[i:i + filtersize, j:j + filtersize] = numpy.zeros((filtersize, filtersize))
42 |
43 | return temp2
44 |
45 |
46 | def get_descriptors(img):
47 | # origin_img = img
48 | clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8,8))
49 | img = clahe.apply(img)
50 | img = image_enhance.image_enhance(img)
51 |
52 | img = numpy.array(img, dtype=numpy.uint8)
53 | # Threshold
54 | ret, img = cv2.threshold(img, 127, 255, cv2.THRESH_BINARY_INV | cv2.THRESH_OTSU)
55 | # Normalize to 0 and 1 range
56 | img[img == 255] = 1
57 |
58 | # plot_img = np.hstack((origin_img, img * 255))
59 | # cv2.imshow("img", plot_img)
60 | # cv2.waitKey(0)
61 | # cv2.imwrite("enhance.bmp", plot_img)
62 |
63 | # Thinning
64 | # skeleton = skeletonize(img)
65 | # skeleton = numpy.array(skeleton, dtype=numpy.uint8)
66 | # skeleton = removedot(skeleton)
67 |
68 | # Harris corners
69 | harris_corners = cv2.cornerHarris(img, 3, 3, 0.04)
70 | harris_normalized = cv2.normalize(harris_corners, 0, 255, norm_type=cv2.NORM_MINMAX, dtype=cv2.CV_32FC1)
71 | threshold_harris = 125
72 | # Extract keypoints
73 | keypoints = []
74 | for x in range(0, harris_normalized.shape[0]):
75 | for y in range(0, harris_normalized.shape[1]):
76 | if harris_normalized[x][y] > threshold_harris:
77 | keypoints.append(cv2.KeyPoint(y, x, 1))
78 | # Define descriptor
79 | orb = cv2.ORB_create()
80 | # Compute descriptors
81 | _, des = orb.compute(img, keypoints)
82 | return keypoints, des
83 |
84 |
85 | def load_predicted_pkl():
86 | with open("predicted.pkl", "rb+") as f:
87 | lst = pickle.load(f)
88 | print(lst)
89 |
90 | def load_dist_lst_pkl():
91 | with open("dist_lst.pkl", "rb+") as f:
92 | dst_lst = pickle.load(f)
93 | print(dst_lst)
94 |
95 |
96 | def main(txt_address):
97 | predicted_lst = []
98 | with open(txt_address, "r+") as f:
99 | samples_lst = f.readlines()
100 | total_num = 0
101 | true_num = 0
102 | dist_lst = []
103 | for i, samples in enumerate(samples_lst):
104 | # if i > 10:
105 | # break
106 | lst = samples.strip().split(" ")
107 | name_1 = lst[0]
108 | name_2 = lst[1]
109 | is_match, avg = match(name_1, name_2)
110 | dist_lst.append(avg)
111 | if is_match:
112 | # pos samples
113 | predicted_lst.append(0)
114 | if int(lst[2]) == 0:
115 | true_num += 1
116 | else:
117 | predicted_lst.append(1)
118 | if int(lst[2]) == 1:
119 | true_num += 1
120 | total_num += 1
121 | print("total:true = ", total_num, true_num)
122 | with open("predicted.pkl", "wb+") as f:
123 | pickle.dump(predicted_lst, f)
124 | with open("dist_lst.pkl", "wb+") as f:
125 | pickle.dump(dist_lst, f)
126 |
127 |
128 | def match(image_name1, image_name2):
129 | name1 = image_name1.split("/")[-1].split(".")[0] + ".pkl"
130 | if name1 not in name_set:
131 | name_set.add(name1)
132 | img1 = cv2.imread(image_name1, cv2.IMREAD_GRAYSCALE)
133 | img1 = cv2.resize(img1, dsize=(245, 372))
134 | kp1, des1 = get_descriptors(img1)
135 | with open("../../data/orb_pkl/{}".format(name1), "wb+") as f:
136 | pickle.dump(des1, f)
137 | else:
138 | with open("../../data/orb_pkl/{}".format(name1), "rb+") as f:
139 | des1 = pickle.load(f)
140 | name2 = image_name2.split("/")[-1].split(".")[0] + ".pkl"
141 |
142 | if name2 not in name_set:
143 | name_set.add(name2)
144 | img2 = cv2.imread(image_name2, cv2.IMREAD_GRAYSCALE)
145 | img2 = cv2.resize(img2, dsize=(245, 372))
146 | kp2, des2 = get_descriptors(img2)
147 | with open("../../data/orb_pkl/{}".format(name2), "wb+") as f:
148 | pickle.dump(des2, f)
149 | else:
150 | with open("../../data/orb_pkl/{}".format(name2), "rb+") as f:
151 | des2 = pickle.load(f)
152 | # Matching between descriptors
153 | # Brute force match
154 | bf = cv2.BFMatcher(cv2.NORM_HAMMING, crossCheck=True)
155 | matches = sorted(bf.match(des1, des2), key=lambda match: match.distance)
156 | # Plot keypoints
157 | # img4 = cv2.drawKeypoints(img1, kp1, outImage=None)
158 | # img5 = cv2.drawKeypoints(img2, kp2, outImage=None)
159 | # f, axarr = plt.subplots(1, 2)
160 | # axarr[0].imshow(img4)
161 | # axarr[1].imshow(img5)
162 | # plt.show()
163 | # Plot matches
164 | # img3 = cv2.drawMatches(img1, kp1, img2, kp2, matches, flags=2, outImg=None)
165 | # plt.imshow(img3)
166 | # plt.show()
167 |
168 | # Calculate score
169 | score = 0
170 | for match in matches:
171 | score += match.distance
172 | score_threshold = 2.78
173 | #print(score)
174 | #print(len(matches))
175 | avg = score / len(matches)
176 | if avg < score_threshold:
177 | # print("Fingerprint matches.")
178 | return True, avg
179 | else:
180 | # print("Fingerprint does not match.")
181 | return False, avg
182 |
183 |
184 | if __name__ == "__main__":
185 | main("pos_samples.txt")
186 | # load_pkl()
187 |
188 |
--------------------------------------------------------------------------------
/src/traditional_method/ridge_filter.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import scipy
3 |
4 |
5 | def ridge_filter(im, orient, freq, kx, ky):
6 | angleInc = 3
7 | im = np.double(im)
8 | rows, cols = im.shape
9 | newim = np.zeros((rows, cols))
10 |
11 | freq_1d = np.reshape(freq, (1, rows*cols))
12 | ind = np.where(freq_1d > 0)
13 |
14 | ind = np.array(ind)
15 | ind = ind[1, :]
16 |
17 | non_zero_elems_in_freq = freq_1d[0][ind]
18 | non_zero_elems_in_freq = np.double(np.round((non_zero_elems_in_freq*100)))/100
19 |
20 | unfreq = np.unique(non_zero_elems_in_freq)
21 | # print(ind)
22 |
23 | sigmax = 1/unfreq[0]*kx
24 | sigmay = 1/unfreq[0]*ky
25 |
26 | sze = int(np.round(3*np.max([sigmax, sigmay])))
27 |
28 | x, y = np.meshgrid(np.linspace(-sze, sze, (2*sze + 1)), np.linspace(-sze, sze, (2*sze + 1)))
29 |
30 | reffilter = np.exp(-(( (np.power(x,2))/(sigmax*sigmax) + (np.power(y,2))/(sigmay*sigmay)))) * np.cos(2*np.pi*unfreq[0]*x)
31 |
32 | filt_rows, filt_cols = reffilter.shape
33 |
34 | gabor_filter = np.array(np.zeros((int(180 / angleInc), int(filt_rows), int(filt_cols))))
35 |
36 | for o in range(0, int(180/angleInc)):
37 |
38 | rot_filt = scipy.ndimage.rotate(reffilter, -(o*angleInc + 90), reshape=False)
39 | gabor_filter[o] = rot_filt
40 |
41 | maxsze = int(sze)
42 |
43 | temp = freq > 0
44 | validr, validc = np.where(temp)
45 |
46 | temp1 = validr > maxsze
47 | temp2 = validr < rows - maxsze
48 | temp3 = validc > maxsze
49 | temp4 = validc < cols - maxsze
50 |
51 | final_temp = temp1 & temp2 & temp3 & temp4
52 |
53 | finalind = np.where(final_temp)
54 |
55 | maxorientindex = np.round(180/angleInc)
56 | orientindex = np.round(orient/np.pi*180/angleInc)
57 |
58 | for i in range(0, rows):
59 | for j in range(0, cols):
60 | if(orientindex[i][j] < 1):
61 | orientindex[i][j] = orientindex[i][j] + maxorientindex
62 | if(orientindex[i][j] > maxorientindex):
63 | orientindex[i][j] = orientindex[i][j] - maxorientindex
64 | finalind_rows, finalind_cols = np.shape(finalind)
65 | sze = int(sze)
66 | for k in range(0, finalind_cols):
67 | r = validr[finalind[0][k]]
68 | c = validc[finalind[0][k]]
69 |
70 | img_block = im[r-sze:r+sze + 1][:, c-sze:c+sze + 1]
71 |
72 | newim[r][c] = np.sum(img_block * gabor_filter[int(orientindex[r][c]) - 1])
73 |
74 | return newim
--------------------------------------------------------------------------------
/src/traditional_method/ridge_freq.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from .frequest import frequest
3 |
4 |
5 | def ridge_freq(im, mask, orient, blksze, windsze,minWaveLength, maxWaveLength):
6 | rows, cols = im.shape
7 | freq = np.zeros((rows, cols))
8 |
9 | for r in range(0, rows-blksze, blksze):
10 | for c in range(0, cols-blksze, blksze):
11 | blkim = im[r:r+blksze][:, c:c+blksze]
12 | blkor = orient[r:r+blksze][:, c:c+blksze]
13 | freq[r:r+blksze][:, c:c+blksze] = frequest(blkim, blkor, windsze, minWaveLength,
14 | maxWaveLength)
15 |
16 | freq = freq*mask
17 | freq_1d = np.reshape(freq, (1, rows*cols))
18 | ind = np.where(freq_1d > 0)
19 |
20 | ind = np.array(ind)
21 | ind = ind[1, :]
22 |
23 | non_zero_elems_in_freq = freq_1d[0][ind]
24 | # print("non_zero_elems_in_freq = ", freq_1d)
25 |
26 | meanfreq = np.mean(non_zero_elems_in_freq)
27 | medianfreq = np.median(non_zero_elems_in_freq)
28 | return freq, meanfreq
29 |
--------------------------------------------------------------------------------
/src/traditional_method/ridge_orient.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import cv2
3 | from scipy import ndimage
4 | from scipy import signal
5 |
6 | def ridge_orient(im, gradientsigma, blocksigma, orientsmoothsigma):
7 | rows,cols = im.shape
8 | sze = np.fix(6*gradientsigma)
9 | if np.remainder(sze,2) == 0:
10 | sze = sze+1
11 | gauss = cv2.getGaussianKernel(int(sze), gradientsigma)  # int() instead of the removed np.int
12 | f = gauss * gauss.T
13 |
14 | fy,fx = np.gradient(f);
15 |
16 | Gx = signal.convolve2d(im,fx,mode='same')
17 | Gy = signal.convolve2d(im,fy,mode='same')
18 |
19 | Gxx = np.power(Gx,2)
20 | Gyy = np.power(Gy,2)
21 | Gxy = Gx*Gy
22 |
23 | sze = np.fix(6*blocksigma)
24 |
25 | gauss = cv2.getGaussianKernel(int(sze), blocksigma)
26 | f = gauss * gauss.T
27 |
28 | Gxx = ndimage.convolve(Gxx, f)
29 | Gyy = ndimage.convolve(Gyy, f)
30 | Gxy = 2*ndimage.convolve(Gxy, f)
31 |
32 | denom = np.sqrt(np.power(Gxy, 2) + np.power((Gxx - Gyy), 2)) + np.finfo(float).eps
33 |
34 | sin2theta = Gxy/denom
35 | cos2theta = (Gxx-Gyy)/denom
36 |
37 |
38 | if orientsmoothsigma:
39 | sze = np.fix(6*orientsmoothsigma)
40 | if np.remainder(sze,2) == 0:
41 | sze = sze+1
42 | gauss = cv2.getGaussianKernel(int(sze), orientsmoothsigma)
43 | f = gauss * gauss.T
44 | cos2theta = ndimage.convolve(cos2theta,f)
45 | sin2theta = ndimage.convolve(sin2theta,f)
46 |
47 | orientim = np.pi/2 + np.arctan2(sin2theta,cos2theta)/2
48 | return(orientim)
--------------------------------------------------------------------------------
/src/traditional_method/ridge_segment.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
4 | def normalise(img, mean, std):
5 | normed = (img - np.mean(img))/(np.std(img))
6 | return normed
7 |
8 |
9 | def ridge_segment(im, blksze, thresh):
10 | rows, cols = im.shape
11 | im = normalise(im, 0, 1)
12 | new_rows = int(blksze * np.ceil(float(rows) / float(blksze)))  # int()/float() instead of the removed np.int/np.float
13 | new_cols = int(blksze * np.ceil(float(cols) / float(blksze)))
14 |
15 | padded_img = np.zeros((new_rows, new_cols))
16 | stddevim = np.zeros((new_rows, new_cols))
17 |
18 | padded_img[0:rows][:, 0:cols] = im
19 |
20 | for i in range(0, new_rows, blksze):
21 | for j in range(0, new_cols, blksze):
22 | block = padded_img[i:i+blksze][:, j:j+blksze]
23 |
24 | stddevim[i:i+blksze][:, j:j+blksze] = np.std(block)*np.ones(block.shape)
25 |
26 | stddevim = stddevim[0:rows][:,0:cols]
27 |
28 | mask = stddevim > thresh
29 |
30 | mean_val = np.mean(im[mask])
31 |
32 | std_val = np.std(im[mask])
33 |
34 | normim = (im - mean_val) / std_val
35 |
36 | return normim, mask
37 |
--------------------------------------------------------------------------------
/traditional-method.md:
--------------------------------------------------------------------------------
1 | # Traditional Fingerprint Recognition Method
2 |
3 | ## 1. Introduction
4 |
5 | The traditional fingerprint recognition algorithm consists of the following steps:
6 |
7 | * fingerprint image enhancement
8 | * binarization
9 | * thinning
10 | * extraction of feature points and computation of their descriptors
11 | * matching
12 |
13 |
14 | These five steps follow the standard pattern recognition pipeline: data preprocessing, feature extraction, classification.
15 |
16 |
17 |
18 | ## 2. Procedure
19 |
20 | The fingerprint image is first normalised, then the ridge orientation field is computed, Gabor filtering is applied for enhancement, and the result is binarised into a binary image. The final effect is shown in the figure below (a minimal usage sketch of the enhancement step follows the figure).
21 |
22 | 
23 |
24 | Original image vs. binarised image
25 |
26 |
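A minimal usage sketch of that enhancement chain (assuming it is run from the repository root, as `src/traditional_method/main.py` does; the input path is a hypothetical placeholder). `image_enhance` internally chains `ridge_segment`, `ridge_orient`, `ridge_freq` and `ridge_filter` and returns a boolean binary image:

```python
import cv2
from src.traditional_method.image_enhance import image_enhance

# "some_fingerprint.png" is a placeholder path; use any SOCO-Fing grayscale image here.
img = cv2.imread("some_fingerprint.png", cv2.IMREAD_GRAYSCALE)
binary = image_enhance(img)  # boolean array from thresholding the Gabor-filtered image
cv2.imwrite("enhanced.png", binary.astype("uint8") * 255)
```
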
27 |
28 | Next, Harris corners are extracted from the binary image, ORB descriptors are computed at those corners, and the descriptors are matched. The final decision uses a distance criterion: if the average distance between the matched descriptors is below a threshold, the two fingerprints are declared a match, otherwise not. A condensed sketch of this matching step is given after the figure below.
29 |
30 | 
31 |
32 | Feature point matching
33 |
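A condensed sketch of that matching step (essentially what `src/traditional_method/main.py` does; the Harris response threshold 125 and the average-distance threshold 2.78 are the values used there, and this is a simplified illustration rather than a drop-in replacement):

```python
import cv2
import numpy as np


def orb_at_harris(binary_img):
    # Harris corner response on the 0/1 uint8 binary image, normalised to [0, 255].
    harris = cv2.cornerHarris(binary_img, 3, 3, 0.04)
    harris = cv2.normalize(harris, None, 0, 255, norm_type=cv2.NORM_MINMAX)
    # Keep the strong corners as keypoints and describe them with ORB.
    keypoints = [cv2.KeyPoint(float(x), float(y), 1)
                 for y, x in zip(*np.where(harris > 125))]
    orb = cv2.ORB_create()
    return orb.compute(binary_img, keypoints)[1]


def is_same_finger(des1, des2, threshold=2.78):
    # Brute-force Hamming matching with cross check, then threshold the average distance.
    bf = cv2.BFMatcher(cv2.NORM_HAMMING, crossCheck=True)
    matches = bf.match(des1, des2)
    avg = sum(m.distance for m in matches) / len(matches)
    return avg < threshold, avg
```
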
34 | ## 3. Results
35 |
36 | The method is evaluated on the test data (9,982 positive pairs, 15,964 negative pairs), giving the PR curve shown below. As recall grows, precision stays close to 1, and only drops to about 75% as recall approaches 1. This shows that ORB descriptors at Harris corners separate positive from negative pairs fairly well, with only a small amount of overlap between them.
37 |
38 | 
39 |
40 | Taking an average Hamming distance of 2.78 between matched descriptors as the decision boundary between positive and negative pairs (average distance below 2.78 means positive, otherwise negative) gives the confusion matrix below.
41 |
42 | 
43 |
44 | This gives Accuracy = 79.04%, Precision = 99.98%, Recall = 45.53%, F1 = 62.57%.
45 |
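As a quick consistency check derived only from the reported precision and recall and the 9,982 / 15,964 pair counts: F1 = 2PR / (P + R) = 2 × 0.9998 × 0.4553 / (0.9998 + 0.4553) ≈ 0.6257 = 62.57%, and with TP ≈ 0.4553 × 9,982 ≈ 4,545 true positives and roughly one false positive, Accuracy ≈ (4,545 + 15,963) / 25,946 ≈ 79.04%, matching the values above.
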
--------------------------------------------------------------------------------