├── .gitignore
├── .style.yapf
├── polyu
│   ├── __init__.py
│   ├── aligned_images.py
│   ├── preprocess.py
│   ├── description.py
│   └── detection.py
├── validate
│   ├── __init__.py
│   ├── eer.py
│   ├── sift.py
│   ├── roc.py
│   ├── dpf.py
│   ├── description.py
│   └── matching.py
├── cpu-requirements.txt
├── gpu-requirements.txt
├── LICENSE
├── unit_test
│   ├── augmentation.py
│   ├── roc.py
│   ├── sift_patch.py
│   ├── rank_n.py
│   ├── find_correspondences.py
│   ├── pairwise_distances.py
│   └── restore_model.py
├── matching.py
├── batch_detect_pores.py
├── recognize.py
├── models
│   ├── detection.py
│   └── description.py
├── train.py
├── align.py
├── README.md
└── utils.py
/.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | *.pdf 3 | log/ 4 | *.txt 5 | *.png 6 | env 7 | -------------------------------------------------------------------------------- /.style.yapf: -------------------------------------------------------------------------------- 1 | [style] 2 | based_on_style = pep8 3 | indent_width = 2 4 | -------------------------------------------------------------------------------- /polyu/__init__.py: -------------------------------------------------------------------------------- 1 | import polyu.detection 2 | import polyu.description 3 | -------------------------------------------------------------------------------- /validate/__init__.py: -------------------------------------------------------------------------------- 1 | import validate.description 2 | import validate.matching 3 | -------------------------------------------------------------------------------- /cpu-requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==0.5.0 2 | astor==0.7.1 3 | gast==0.2.0 4 | grpcio==1.15.0 5 | Markdown==3.0 6 | numpy==1.15.2 7 | opencv-contrib-python==3.4.0.12 8 | protobuf==3.6.1 9 | six==1.11.0 10 | tensorboard==1.10.0 11 | tensorflow==1.10.1 12 | termcolor==1.1.0 13 | Werkzeug==0.14.1 14 | -------------------------------------------------------------------------------- /gpu-requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==0.5.0 2 | astor==0.7.1 3 | gast==0.2.0 4 | grpcio==1.15.0 5 | Markdown==3.0 6 | numpy==1.15.2 7 | opencv-contrib-python==3.4.0.12 8 | protobuf==3.6.1 9 | six==1.11.0 10 | tensorboard==1.10.0 11 | tensorflow-gpu==1.10.1 12 | termcolor==1.1.0 13 | Werkzeug==0.14.1 14 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2018, Gabriel Dahia Fernandes, Mauricio Pamplona Segundo, Universidade Federal da Bahia, Bahia, Brazil 2 | 3 | This work is licensed under the Attribution-NonCommercial-ShareAlike 4.0 International License. 4 | 5 | To view a copy of this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/ 6 | 7 | This program is distributed in the hope that it will be useful, 8 | but WITHOUT ANY WARRANTY; without even the implied warranty of 9 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
10 | -------------------------------------------------------------------------------- /validate/eer.py: -------------------------------------------------------------------------------- 1 | import utils 2 | 3 | if __name__ == '__main__': 4 | import sys 5 | 6 | # read comparisons file 7 | for path in sys.argv[1:]: 8 | if path.endswith('.txt'): 9 | pos = [] 10 | neg = [] 11 | with open(path, 'r') as f: 12 | for line in f: 13 | t, score = line.split() 14 | score = float(score) 15 | if int(t) == 1: 16 | pos.append(score) 17 | else: 18 | neg.append(score) 19 | 20 | print(path, utils.eer(pos, neg)) 21 | -------------------------------------------------------------------------------- /unit_test/augmentation.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | 3 | import utils 4 | import polyu 5 | 6 | 7 | def minibatch_transformation(dataset): 8 | patches_pl, labels_pl = utils.placeholder_inputs() 9 | feed_dict = utils.fill_feed_dict( 10 | dataset.train, patches_pl, labels_pl, 36, augment=True) 11 | 12 | for patch in feed_dict[patches_pl]: 13 | cv2.imshow('patch', patch) 14 | cv2.waitKey(0) 15 | 16 | 17 | if __name__ == '__main__': 18 | import sys 19 | 20 | print('Loading dataset...') 21 | dataset = polyu.description.Dataset(sys.argv[1]) 22 | print('Done') 23 | 24 | minibatch_transformation(dataset) 25 | -------------------------------------------------------------------------------- /validate/sift.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import polyu 4 | import utils 5 | from validate.matching import validation_eer 6 | 7 | FLAGS = None 8 | 9 | 10 | def main(): 11 | print('Loading dataset...') 12 | dataset = polyu.description.Dataset(FLAGS.dataset_path).val 13 | print('Done') 14 | 15 | # compute eer 16 | print('EER = {}'.format(validation_eer(dataset, utils.sift_descriptors))) 17 | 18 | 19 | if __name__ == '__main__': 20 | parser = argparse.ArgumentParser() 21 | parser.add_argument( 22 | '--dataset_path', required=True, type=str, help='path to dataset') 23 | FLAGS = parser.parse_args() 24 | 25 | main() 26 | -------------------------------------------------------------------------------- /unit_test/roc.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | 4 | import utils 5 | 6 | 7 | def random_roc(): 8 | # random data should give 0.5 eer 9 | # and random roc curve 10 | pos = np.random.random(1000) 11 | neg = np.random.random(1000) 12 | 13 | # compare eer versions 14 | print(utils.eer(pos, neg)) 15 | 16 | # plot curve 17 | fars, frrs = utils.roc(pos, neg) 18 | plt.plot(fars, frrs) 19 | plt.show() 20 | 21 | 22 | def separable_roc(): 23 | # separable data should give low 24 | # eer and convex roc curve 25 | pos = np.random.normal(1, 0.5, 1000) 26 | neg = np.random.normal(0, 0.5, 1000) 27 | 28 | # compare eer versions 29 | print(utils.eer(pos, neg)) 30 | 31 | # plot curve 32 | fars, frrs = utils.roc(pos, neg) 33 | plt.plot(fars, frrs) 34 | plt.show() 35 | 36 | 37 | if __name__ == '__main__': 38 | random_roc() 39 | separable_roc() 40 | -------------------------------------------------------------------------------- /validate/roc.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import matplotlib.pyplot as plt 3 | 4 | import utils 5 | 6 | if __name__ == '__main__': 7 | # parse arguments 8 | parser = argparse.ArgumentParser() 9 | 
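  # each input file is expected to hold one comparison per line, in the
  # format '<label> <score>', with label 1 for genuine pairs and 0 for
  # impostor pairs (see the parsing loop below)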
parser.add_argument( 10 | '--files', 11 | required=True, 12 | type=str, 13 | nargs='+', 14 | help='files containing results to plot') 15 | parser.add_argument( 16 | '--xrange', 17 | default=[0, 0.1], 18 | type=float, 19 | nargs=2, 20 | help='range to plot x axis') 21 | parser.add_argument( 22 | '--yrange', 23 | default=[0, 0.1], 24 | type=float, 25 | nargs=2, 26 | help='range to plot y axis') 27 | 28 | flags = parser.parse_args() 29 | 30 | for path in flags.files: 31 | if path.endswith('.txt'): 32 | # read comparisons file 33 | pos = [] 34 | neg = [] 35 | with open(path, 'r') as f: 36 | for line in f: 37 | t, score = line.split() 38 | score = float(score) 39 | if int(t) == 1: 40 | pos.append(score) 41 | else: 42 | neg.append(score) 43 | 44 | # compute roc 45 | fars, frrs = utils.roc(pos, neg) 46 | 47 | # plot roc 48 | plt.plot(fars, frrs, label=path) 49 | 50 | plt.legend(loc='upper right') 51 | plt.xlabel('FAR') 52 | plt.ylabel('FRR') 53 | plt.axis(flags.xrange + flags.yrange) 54 | plt.grid() 55 | plt.show() 56 | -------------------------------------------------------------------------------- /unit_test/sift_patch.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cv2 3 | 4 | import utils 5 | 6 | 7 | def sift_patch(img_path, pts_path, scale=4): 8 | # load image 9 | img = cv2.imread(img_path, 0) 10 | 11 | # load pts 12 | pts = utils.load_dets_txt(pts_path) 13 | 14 | # find closest pore to the center 15 | center = np.array(img.shape) / 2 16 | sqr_dist = np.sum((pts - center)**2, axis=1) 17 | closest_ind = np.argmin(sqr_dist) 18 | pt = pts[closest_ind][::-1] 19 | 20 | # improve image quality with median blur and clahe 21 | img = cv2.medianBlur(img, ksize=3) 22 | clahe = cv2.createCLAHE(clipLimit=3) 23 | img = clahe.apply(img) 24 | 25 | # extract original descriptor 26 | kpt = cv2.KeyPoint.convert([pt], size=scale) 27 | sift = cv2.xfeatures2d.SIFT_create() 28 | _, original = sift.compute(img, kpt) 29 | 30 | # test patch sizes 31 | for i in range(3, 113): 32 | # extract patch centered in 'pt'; 'pt' is in (x, y) order 33 | patch = img[pt[1] - i:pt[1] + i + 1, pt[0] - i:pt[0] + i + 1] 34 | 35 | # extract patch keypoint 36 | patch_kpt = cv2.KeyPoint.convert([(i, i)], size=scale) 37 | _, patched = sift.compute(patch, patch_kpt) 38 | 39 | if np.isclose(np.linalg.norm(original - patched), 0): 40 | return i 41 | 42 | return -1 43 | 44 | 45 | if __name__ == '__main__': 46 | import sys 47 | patch_size = sift_patch(sys.argv[1], sys.argv[2]) 48 | assert patch_size > 0 49 | print('[OK - Sift Patch ({})]'.format(2 * patch_size + 1)) 50 | -------------------------------------------------------------------------------- /validate/dpf.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | 4 | import utils 5 | 6 | 7 | def validate(pores_by_image, detections_by_image): 8 | # find correspondences between detections and pores 9 | total_pores = 0 10 | total_dets = 0 11 | true_dets = 0 12 | for i, pores in enumerate(pores_by_image): 13 | dets = detections_by_image[i] 14 | 15 | # update totals 16 | total_pores += len(pores) 17 | total_dets += len(dets) 18 | true_dets += len(utils.find_correspondences(pores, dets)) 19 | 20 | # compute tdr, fdr and f score 21 | eps = 1e-12 22 | tdr = true_dets / (total_pores + eps) 23 | fdr = (total_dets - true_dets) / (total_dets + eps) 24 | f_score = 2 * (tdr * (1 - fdr)) / (tdr + 1 - fdr) 25 | 26 | print('TDR = {}'.format(tdr)) 27 | print('FDR = {}'.format(fdr)) 28 |
print('F score = {}'.format(f_score)) 29 | 30 | 31 | def load_keypoints_from_txt(txt_folder_path): 32 | keypoints_by_image = [] 33 | for txt_path in sorted(os.listdir(txt_folder_path)): 34 | if txt_path.endswith('.txt'): 35 | keypoints = [] 36 | with open(os.path.join(txt_folder_path, txt_path)) as f: 37 | for line in f: 38 | x, y = [int(j) for j in line.split()] 39 | keypoints.append((x, y)) 40 | keypoints_by_image.append(np.array(keypoints)) 41 | 42 | return keypoints_by_image 43 | 44 | 45 | def main(pores_txt_folder_path, dets_txt_folder_path): 46 | pores = load_keypoints_from_txt(pores_txt_folder_path) 47 | dets = load_keypoints_from_txt(dets_txt_folder_path) 48 | validate(pores, dets) 49 | 50 | 51 | if __name__ == '__main__': 52 | import sys 53 | if len(sys.argv) < 3: 54 | sys.exit('Insufficient arguments') 55 | 56 | main(sys.argv[1], sys.argv[2]) 57 | -------------------------------------------------------------------------------- /unit_test/rank_n.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | 4 | import utils 5 | 6 | 7 | def random_rank(size): 8 | # random data should give 1/n_labels 9 | # rank-1 and linear curve 10 | instances = np.random.random((size, 128)) 11 | labels = np.random.randint(100, size=size) 12 | 13 | # get ranks and plot 14 | ranks = utils.rank_n(instances, labels, 100) 15 | print(ranks[0]) 16 | plt.plot(ranks) 17 | plt.show() 18 | 19 | 20 | def discernible_rank(size): 21 | # generate balanced labels 22 | labels = np.r_[:4] 23 | labels = np.repeat(labels, size // 4) 24 | 25 | # generate instances 26 | instances = [] 27 | cov = np.diag(np.repeat(0.1, 2)) 28 | for label in labels: 29 | # create mean of normal 30 | mean = np.zeros(2, dtype=np.float32) 31 | mean[label // 2] += (-1)**(1 + (label % 2)) 32 | 33 | # create instance 34 | instance = np.random.multivariate_normal(mean, cov) 35 | instances.append(instance) 36 | instances = np.array(instances) 37 | 38 | # get ranks and plot 39 | ranks = utils.rank_n(instances, labels, 100) 40 | print(ranks[0]) 41 | plt.plot(ranks) 42 | plt.show() 43 | 44 | 45 | def perfect_rank(size): 46 | # generate balanced labels 47 | labels = np.r_[:4] 48 | labels = np.repeat(labels, size // 4) 49 | 50 | # generate instances 51 | instances = [] 52 | for label in labels: 53 | # create instance 54 | instance = np.zeros(2, dtype=np.float32) 55 | instance[label // 2] += (-1)**(1 + (label % 2)) 56 | instances.append(instance) 57 | instances = np.array(instances) 58 | 59 | # get ranks and plot 60 | ranks = utils.rank_n(instances, labels, 100) 61 | print(ranks[0]) 62 | plt.plot(ranks) 63 | plt.show() 64 | 65 | 66 | if __name__ == '__main__': 67 | size = 1000 68 | random_rank(size) 69 | discernible_rank(size) 70 | perfect_rank(size) 71 | -------------------------------------------------------------------------------- /unit_test/find_correspondences.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import utils 4 | 5 | 6 | def random_recovery(): 7 | # get random instances 8 | instances1 = np.random.random((100, 32)) 9 | instances2 = np.random.random((100, 32)) 10 | 11 | # find correspondences 12 | pairs = utils.find_correspondences(instances1, instances2) 13 | 14 | # check uniqueness 15 | seen_indices1 = set() 16 | seen_indices2 = set() 17 | for i, j, _ in pairs: 18 | if i in seen_indices1 or j in seen_indices2: 19 | return False 20 | 21 | seen_indices1.add(i) 22 | seen_indices2.add(j) 23 | 24 |
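  # no index repeated on either side: the recovered correspondences
  # form a one-to-one pairing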
return True 25 | 26 | 27 | def perfect_recovery(): 28 | # every instance has a single perfect match 29 | instances = [] 30 | for i in range(32): 31 | instance = np.zeros(32, dtype=np.float32) 32 | instance[i] = 1 33 | instances.append(instance) 34 | 35 | # find correspondences 36 | instances = np.array(instances) 37 | pairs = utils.find_correspondences(instances, instances) 38 | 39 | # check correctness 40 | for i, j, d in pairs: 41 | if i != j or d != 0: 42 | return False 43 | 44 | return True 45 | 46 | 47 | def zero_recovery(): 48 | # every instance in one side has the same match 49 | # in the other side, leading to a single pair 50 | instances1 = [] 51 | for i in range(32): 52 | instance = np.zeros(32, dtype=np.float32) 53 | instance[i] = 1 54 | instances1.append(instance) 55 | 56 | instances2 = [instance + 1 for instance in instances1] 57 | instances2.append(np.zeros(32, dtype=np.float32)) 58 | 59 | # find correspondences 60 | instances1 = np.array(instances1) 61 | instances2 = np.array(instances2) 62 | pairs = utils.find_correspondences(instances1, instances2) 63 | 64 | if len(pairs) != 1: 65 | return False 66 | if pairs[0][1] != len(instances2) - 1: 67 | return False 68 | if pairs[0][2] != 1: 69 | return False 70 | 71 | return True 72 | 73 | 74 | if __name__ == '__main__': 75 | assert random_recovery() 76 | print('[OK - Random Recovery]') 77 | 78 | assert perfect_recovery() 79 | print('[OK - Perfect Recovery]') 80 | 81 | assert zero_recovery() 82 | print('[OK - Zero Recovery]') 83 | -------------------------------------------------------------------------------- /matching.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import itertools 3 | 4 | import utils 5 | 6 | 7 | def basic(descs1, descs2, pts1=None, pts2=None, thr=None): 8 | ''' 9 | Finds bidirectional correspondences between descriptors 10 | descs1 and descs2. If thr is provided, discards correspondences 11 | that fail a distance ratio check with threshold thr; in this 12 | case, returns correspondences satisfying SIFT's criterion. 13 | 14 | Args: 15 | descs1: [N, M] array of N descriptors of dimension M each. 16 | descs2: [P, M] array of P descriptors of dimension M each. 17 | pts1: sentinel argument for matching function signature 18 | standardization. 19 | pts2: sentinel argument for matching function signature 20 | standardization. 21 | thr: distance ratio check threshold. 22 | 23 | Returns: 24 | number of found bidirectional correspondences. If thr is not 25 | None, number of bidirectional correspondences that satisfy 26 | a distance ratio check. 27 | ''' 28 | if len(descs1) == 0 or len(descs2) == 0: 29 | return 0 30 | 31 | return len(utils.find_correspondences(descs1, descs2, thr=thr)) 32 | 33 | 34 | def spatial(descs1, descs2, pts1, pts2, thr=None): 35 | ''' 36 | Computes the matching score proposed by Pamplona Segundo & 37 | Lemes (Pore-based ridge reconstruction for fingerprint 38 | recognition, 2015) using bidirectional correspondences 39 | between descriptors descs1 and descs2. 40 | If thr is provided, correspondences that fail a distance 41 | ratio check with threshold thr are discarded. 42 | 43 | Args: 44 | descs1: [N, M] array of N descriptors of dimension M each. 45 | descs2: [P, M] array of P descriptors of dimension M each. 46 | pts1: [N, 2] array of coordinates from which each descriptor 47 | of descs1 was computed. 48 | pts2: [P, 2] array of coordinates from which each descriptor 49 | of descs2 was computed.
50 | thr: distance ratio check threshold. 51 | 52 | Returns: 53 | matching score between descs1 and descs2. 54 | ''' 55 | if len(descs1) == 0 or len(descs2) == 0: 56 | return 0 57 | 58 | pairs = utils.find_correspondences(descs1, descs2, thr=thr) 59 | 60 | pts1 = np.array(pts1) 61 | pts2 = np.array(pts2) 62 | score = 0 63 | for pair1, pair2 in itertools.combinations(pairs, 2): 64 | d1 = np.linalg.norm(pts1[pair1[0]] - pts1[pair2[0]]) 65 | d2 = np.linalg.norm(pts2[pair1[1]] - pts2[pair2[1]]) 66 | score += 1 / (1 + abs(d1 - d2)) 67 | 68 | return score 69 | -------------------------------------------------------------------------------- /validate/description.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from itertools import combinations 3 | 4 | import utils 5 | 6 | 7 | def _dataset_descriptors(patches_pl, session, descs_op, dataset, batch_size): 8 | ''' 9 | Computes descriptors with descs_op for the entire dataset, in batches of 10 | size batch_size. 11 | 12 | Args: 13 | patches_pl: patch input tf placeholder for descs_op. 14 | session: tf session with descs_op variables loaded. 15 | descs_op: tf op for describing patches in patches_pl. 16 | dataset: dataset for which descriptors will be computed. 17 | batch_size: size of batch to describe patches from dataset. 18 | 19 | Returns: 20 | descs: computed descriptors. 21 | labels: corresponding labels. 22 | ''' 23 | # extracting descriptors for entire dataset 24 | descs = [] 25 | labels = [] 26 | prev_epoch = dataset.epochs 27 | while prev_epoch == dataset.epochs: 28 | # sample next batch 29 | patches, batch_labels = dataset.next_batch(batch_size) 30 | feed_dict = {patches_pl: np.expand_dims(patches, axis=-1)} 31 | 32 | # describe batch 33 | batch_descs = session.run(descs_op, feed_dict=feed_dict) 34 | 35 | # add to overall 36 | descs.extend(batch_descs) 37 | labels.extend(batch_labels) 38 | 39 | # convert to np array and remove extra dims 40 | descs = np.squeeze(descs) 41 | labels = np.squeeze(labels) 42 | 43 | return descs, labels 44 | 45 | 46 | def dataset_eer(patches_pl, session, descs_op, dataset, batch_size): 47 | ''' 48 | Computes the Equal Error Rate (EER) of the descriptors computed with 49 | descs_op in the dataset dataset. 50 | 51 | Args: 52 | patches_pl: patch input tf placeholder for descs_op. 53 | session: tf session with descs_op variables loaded. 54 | descs_op: tf op for describing patches in patches_pl. 55 | dataset: dataset for which descriptors will be computed. 56 | batch_size: size of batch to describe patches from dataset. 57 | 58 | Returns: 59 | the computed EER. 
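    Note: pair similarity is taken to be the negative squared Euclidean
    distance between descriptors, so genuine pairs score higher; the EER
    is the error rate at the operating point where the false acceptance
    and false rejection rates are equal.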
60 | ''' 61 | # extracting descriptors for entire dataset 62 | descs, labels = _dataset_descriptors(patches_pl, session, descs_op, dataset, 63 | batch_size) 64 | 65 | # get pairwise comparisons 66 | examples = zip(descs, labels) 67 | pos = [] 68 | neg = [] 69 | for (desc1, label1), (desc2, label2) in combinations(examples, 2): 70 | dist = -np.sum((desc1 - desc2)**2) 71 | if label1 == label2: 72 | pos.append(dist) 73 | else: 74 | neg.append(dist) 75 | 76 | # compute eer 77 | eer = utils.eer(pos, neg) 78 | 79 | return eer 80 | 81 | 82 | def dataset_rank_1(patches_pl, session, descs_op, dataset, batch_size, 83 | sample_size): 84 | # extracting descriptors for entire dataset 85 | descs, labels = _dataset_descriptors(patches_pl, session, descs_op, dataset, 86 | batch_size) 87 | 88 | # compute ranks 89 | ranks = utils.rank_n(descs, labels, sample_size) 90 | 91 | return ranks[0] 92 | -------------------------------------------------------------------------------- /batch_detect_pores.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import argparse 3 | import os 4 | 5 | import utils 6 | from models import detection 7 | 8 | FLAGS = None 9 | 10 | 11 | def batch_detect(load_path, save_path, detect_fn): 12 | ''' 13 | Detects pores in all images in directory load_path 14 | using detect_fn and saves corresponding detections 15 | in save_path. 16 | 17 | Args: 18 | load_path: path to load image from. 19 | save_path: path to save detections. Will be created 20 | if non existent. 21 | detect_fn: function that receives an image and 22 | returns an array of shape [N, 2] of detections. 23 | ''' 24 | # load images from 'load_path' 25 | images, names = utils.load_images_with_names(load_path) 26 | 27 | # create 'save_path' directory tree 28 | if not os.path.exists(save_path): 29 | os.makedirs(save_path) 30 | 31 | # detect in each image and save it 32 | for image, name in zip(images, names): 33 | # detect pores 34 | detections = detect_fn(image) 35 | 36 | # save results 37 | filename = os.path.join(save_path, '{}.txt'.format(name)) 38 | utils.save_dets_txt(detections, filename) 39 | 40 | 41 | def main(): 42 | half_patch_size = FLAGS.patch_size // 2 43 | 44 | with tf.Graph().as_default(): 45 | image_pl, _ = utils.placeholder_inputs() 46 | 47 | print('Building graph...') 48 | net = detection.Net(image_pl, training=False) 49 | print('Done') 50 | 51 | with tf.Session() as sess: 52 | print('Restoring model in {}...'.format(FLAGS.model_dir_path)) 53 | utils.restore_model(sess, FLAGS.model_dir_path) 54 | print('Done') 55 | 56 | # capture arguments in lambda function 57 | def detect_pores(image): 58 | return utils.detect_pores(image, image_pl, net.predictions, 59 | half_patch_size, FLAGS.prob_thr, 60 | FLAGS.inter_thr, sess) 61 | 62 | # batch detect in dbi training 63 | print('Detecting pores in PolyU-HRF DBI Training images...') 64 | load_path = os.path.join(FLAGS.polyu_dir_path, 'DBI', 'Training') 65 | save_path = os.path.join(FLAGS.results_dir_path, 'DBI', 'Training') 66 | batch_detect(load_path, save_path, detect_pores) 67 | print('Done') 68 | 69 | # batch detect in dbi test 70 | print('Detecting pores in PolyU-HRF DBI Test images...') 71 | load_path = os.path.join(FLAGS.polyu_dir_path, 'DBI', 'Test') 72 | save_path = os.path.join(FLAGS.results_dir_path, 'DBI', 'Test') 73 | batch_detect(load_path, save_path, detect_pores) 74 | print('Done') 75 | 76 | # batch detect in dbii 77 | print('Detecting pores in PolyU-HRF DBII images...') 78 | load_path = 
os.path.join(FLAGS.polyu_dir_path, 'DBII') 79 | save_path = os.path.join(FLAGS.results_dir_path, 'DBII') 80 | batch_detect(load_path, save_path, detect_pores) 81 | print('Done') 82 | 83 | 84 | if __name__ == '__main__': 85 | parser = argparse.ArgumentParser() 86 | parser.add_argument( 87 | '--polyu_dir_path', 88 | required=True, 89 | type=str, 90 | help='path to PolyU-HRF dataset') 91 | parser.add_argument( 92 | '--model_dir_path', 93 | type=str, 94 | required=True, 95 | help='path from which to restore trained model') 96 | parser.add_argument( 97 | '--patch_size', type=int, default=17, help='pore patch size') 98 | parser.add_argument( 99 | '--results_dir_path', 100 | type=str, 101 | default='result', 102 | help='path to folder in which results should be saved') 103 | parser.add_argument( 104 | '--prob_thr', 105 | type=float, 106 | default=0.9, 107 | help='probability threshold to filter detections') 108 | parser.add_argument( 109 | '--inter_thr', 110 | type=float, 111 | default=0.1, 112 | help='nms intersection threshold') 113 | 114 | FLAGS = parser.parse_args() 115 | 116 | main() 117 | -------------------------------------------------------------------------------- /unit_test/pairwise_distances.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import utils 4 | 5 | 6 | def _close(D1, D2): 7 | return np.all(np.isclose(D1, D2)) 8 | 9 | 10 | def _naive_pairwise_distances(x1, x2): 11 | D = [] 12 | for x in x1: 13 | row = [] 14 | for y in x2: 15 | dist = np.sum((x - y)**2) 16 | row.append(dist) 17 | D.append(row) 18 | return np.array(D) 19 | 20 | 21 | def random_vectors(): 22 | # random vectors 23 | x1 = np.random.random((100, 32)) 24 | x2 = np.random.random((100, 32)) 25 | 26 | # compute pairwise distances 27 | D1 = utils.pairwise_distances(x1, x2) 28 | D2 = _naive_pairwise_distances(x1, x2) 29 | 30 | # check shape 31 | if D1.shape != D2.shape: 32 | return False 33 | 34 | # compare naive approach with employed 35 | if not _close(D1, D2): 36 | return False 37 | 38 | # check if it is non-negative 39 | if not _close(D1, np.abs(D1)): 40 | return False 41 | 42 | return True 43 | 44 | 45 | def orthogonal_vectors(): 46 | # get set of random vectors 47 | x1 = np.random.random((100, 32)) 48 | norm_x1 = [x / np.linalg.norm(x) for x in x1] 49 | 50 | # get set of vectors orthogonal to x1 51 | x2 = np.random.random((100, 32)) 52 | x2 = [x - np.dot(x, y) * y for x, y in zip(x2, norm_x1)] 53 | 54 | # compute pairwise distances 55 | x1 = np.array(x1) 56 | x2 = np.array(x2) 57 | D1 = utils.pairwise_distances(x1, x2) 58 | D2 = _naive_pairwise_distances(x1, x2) 59 | 60 | # check shape 61 | if D1.shape != D2.shape: 62 | return False 63 | 64 | # compare naive approach with employed 65 | if not _close(D1, D2): 66 | return False 67 | 68 | # check if it is non-negative 69 | if not _close(D1, np.abs(D1)): 70 | return False 71 | 72 | # check if distance is close to sum of squares of magnitudes 73 | for i, x in enumerate(x1): 74 | y = x2[i] 75 | dist = np.sum(x**2) + np.sum(y**2) 76 | if not _close(dist, D1[i, i]): 77 | return False 78 | 79 | return True 80 | 81 | 82 | def unitary_vectors(): 83 | # get set of random unitary vectors 84 | x1 = np.random.random((100, 32)) 85 | x1 = [x / np.linalg.norm(x) for x in x1] 86 | 87 | # get origin vector 88 | x2 = np.zeros((1, 32)) 89 | 90 | # compute pairwise distances 91 | x1 = np.array(x1) 92 | x2 = np.array(x2) 93 | D1 = utils.pairwise_distances(x1, x2) 94 | D2 = _naive_pairwise_distances(x1, x2) 95 | 96 | #
check shape 97 | if D1.shape != D2.shape: 98 | return False 99 | 100 | # compare naive approach with employed 101 | if not _close(D1, D2): 102 | return False 103 | 104 | # check if it is non-negative 105 | if not _close(D1, np.abs(D1)): 106 | return False 107 | 108 | # check if distance is close to 1 109 | for row in D1: 110 | for d in row: 111 | if not _close(d, 1): 112 | return False 113 | 114 | return True 115 | 116 | 117 | def same_vectors(): 118 | # single set of random vectors 119 | x1 = np.random.random((100, 32)) 120 | 121 | # compute pairwise distances 122 | D1 = utils.pairwise_distances(x1, x1) 123 | D2 = _naive_pairwise_distances(x1, x1) 124 | 125 | # check shape 126 | if D1.shape != D2.shape: 127 | return False 128 | 129 | # compare naive approach with employed 130 | if not _close(D1, D2): 131 | return False 132 | 133 | # check if it is non-negative 134 | if not _close(D1, np.abs(D1)): 135 | return False 136 | 137 | # check if main diagonal is zero 138 | for d in np.diag(D1): 139 | if not _close(d, 0): 140 | return False 141 | 142 | # check if it is symmetrical 143 | for i, row in enumerate(D1): 144 | for j, d in enumerate(row): 145 | if not _close(d, D1[j, i]): 146 | return False 147 | 148 | return True 149 | 150 | 151 | if __name__ == '__main__': 152 | assert random_vectors() 153 | print('[OK - Random Vectors]') 154 | 155 | assert orthogonal_vectors() 156 | print('[OK - Orthogonal Vectors]') 157 | 158 | assert unitary_vectors() 159 | print('[OK - Unitary Vectors]') 160 | 161 | assert same_vectors() 162 | print('[OK - Same Vectors]') 163 | -------------------------------------------------------------------------------- /polyu/aligned_images.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import align 4 | 5 | 6 | class _Transf: 7 | def __init__(self, s, A, b): 8 | self._s = s 9 | self._A = A 10 | self._b = b 11 | 12 | def __call__(self, x): 13 | return self._s * np.dot(x, self._A.T) + self._b 14 | 15 | 16 | def _find_alignments(all_imgs, all_pts): 17 | # find alignment between first and every other image 18 | img1 = all_imgs[0] 19 | transfs = [] 20 | for i, img2 in enumerate(all_imgs[1:]): 21 | # find minimum mse alignment iteratively 22 | A, b, s = align.iterative(img1, all_pts[0], img2, all_pts[i + 1]) 23 | 24 | # create mapping from img1 coordinates to img2 coordinates 25 | transfs.append(_Transf(s, A, b)) 26 | 27 | return transfs 28 | 29 | 30 | def _compute_valid_region(all_imgs, transfs, patch_size): 31 | # find patches in img1 that are in all images 32 | img1 = all_imgs[0] 33 | valid = np.ones_like(img1, dtype=np.bool) 34 | half = patch_size // 2 35 | for k, img2 in enumerate(all_imgs[1:]): 36 | # find pixels in both img1 and img2 37 | aligned = np.zeros_like(img1, dtype=np.bool) 38 | for i in range(half, valid.shape[0] - half): 39 | for j in range(half, valid.shape[1] - half): 40 | row, col = transfs[k]((i, j)) 41 | row = int(np.round(row)) 42 | col = int(np.round(col)) 43 | if 0 <= row - half < img2.shape[0] and 0 <= row + half < img2.shape[0]: 44 | if 0 <= col - half < img2.shape[1] and 0 <= col + half < img2.shape[1]: 45 | aligned[i, j] = True 46 | 47 | # update overall valid positions 48 | valid = np.logical_and(valid, aligned) 49 | 50 | return valid 51 | 52 | 53 | class Handler: 54 | ''' 55 | Handles alignable images, allowing the extraction of square patches of size 56 | patch_size in the coordinates specified by the user in every image. 
It does 57 | this by aligning all other images to the first one. Patches are then 58 | accessed sequentially via the Handler.__getitem__ method. 59 | ''' 60 | 61 | def __init__(self, all_imgs, all_pts, patch_size): 62 | ''' 63 | Aligns all images in all_imgs to the all_imgs[0] using the keypoints 64 | all_pts. Discards keypoints whose corresponding patches would be 65 | out of the images overlap. 66 | 67 | Args: 68 | all_imgs: alignable images. 69 | all_pts: keypoints in every image of all_imgs. 70 | patch_size: size of square patch to be extracted. 71 | ''' 72 | self._imgs = all_imgs 73 | self._patch_size = patch_size 74 | self._half = patch_size // 2 75 | 76 | # align all images to the first 77 | self._transfs = _find_alignments(self._imgs, all_pts) 78 | 79 | # find valid area for extracting patches 80 | mask = _compute_valid_region(self._imgs, self._transfs, self._patch_size) 81 | 82 | # get valid indices and store them for access 83 | self._inds = [] 84 | for pt in all_pts[0]: 85 | if mask[pt[0], pt[1]]: 86 | self._inds.append(pt) 87 | 88 | def __getitem__(self, val): 89 | if isinstance(val, slice): 90 | raise TypeError('Slicing indexing is not supported') 91 | else: 92 | # retrieve coordinates for given index 93 | i, j = self._inds[val] 94 | 95 | # adjust for odd patch sizes 96 | odd = 1 if self._patch_size % 2 != 0 else 0 97 | 98 | # add first image patch 99 | samples = [ 100 | self._imgs[0][i - self._half:i + self._half + odd, j - self._half:j + 101 | self._half + odd] 102 | ] 103 | 104 | # add remaining image patches 105 | for k, img in enumerate(self._imgs[1:]): 106 | # find transformed coordinates of patch 107 | ti, tj = self._transfs[k]((i, j)) 108 | 109 | # convert them to int 110 | ti = int(np.round(ti)) 111 | tj = int(np.round(tj)) 112 | 113 | # add to overall 114 | samples.append(img[ti - self._half:ti + self._half + odd, tj - 115 | self._half:tj + self._half + odd]) 116 | 117 | return np.array(samples) 118 | 119 | def __len__(self): 120 | return len(self._inds) 121 | -------------------------------------------------------------------------------- /recognize.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | import utils 4 | import matching 5 | from models import detection, description 6 | 7 | FLAGS = None 8 | 9 | 10 | def detect_pores(imgs): 11 | with tf.Graph().as_default(): 12 | # placeholder for image 13 | image_pl, _ = utils.placeholder_inputs() 14 | 15 | # build detection net 16 | print('Building detection net graph...') 17 | det_net = detection.Net(image_pl, training=False) 18 | print('Done') 19 | 20 | with tf.Session() as sess: 21 | print('Restoring detection model in {}...'.format(FLAGS.det_model_dir)) 22 | utils.restore_model(sess, FLAGS.det_model_dir) 23 | print('Done') 24 | 25 | # capture detection arguments in function 26 | def single_detect_pores(image): 27 | return utils.detect_pores( 28 | image, image_pl, det_net.predictions, FLAGS.det_patch_size // 2, 29 | FLAGS.det_prob_thr, FLAGS.nms_inter_thr, sess) 30 | 31 | # detect pores 32 | dets = [single_detect_pores(img) for img in imgs] 33 | 34 | return dets 35 | 36 | 37 | def describe_detections(imgs, dets): 38 | with tf.Graph().as_default(): 39 | # placeholder for image 40 | image_pl, _ = utils.placeholder_inputs() 41 | 42 | # build description net 43 | print('Building description net graph...') 44 | desc_net = description.Net(image_pl, training=False) 45 | print('Done') 46 | 47 | with tf.Session() as sess: 48 | print('Restoring description 
model in {}...'.format( 49 | FLAGS.desc_model_dir)) 50 | utils.restore_model(sess, FLAGS.desc_model_dir) 51 | print('Done') 52 | 53 | # capture description arguments in function 54 | def compute_descriptors(image, dets): 55 | return utils.trained_descriptors(image, dets, FLAGS.desc_patch_size, 56 | sess, image_pl, desc_net.descriptors) 57 | 58 | # compute descriptors 59 | descs = [] 60 | new_dets = [] 61 | for img, img_dets in zip(imgs, dets): 62 | img_descs, img_new_dets = compute_descriptors(img, img_dets) 63 | descs.append(img_descs) 64 | new_dets.append(img_new_dets) 65 | 66 | return descs, new_dets 67 | 68 | 69 | def main(): 70 | # load images 71 | imgs = [utils.load_image(path) for path in FLAGS.img_paths] 72 | 73 | dets = detect_pores(imgs) 74 | 75 | tf.reset_default_graph() 76 | 77 | descs, dets = describe_detections(imgs, dets) 78 | 79 | score = matching.basic(descs[0], descs[1], thr=0.7) 80 | print('similarity score = {}'.format(score)) 81 | if score > FLAGS.score_thr: 82 | print('genuine pair') 83 | else: 84 | print('impostor pair') 85 | 86 | 87 | if __name__ == '__main__': 88 | import argparse 89 | 90 | parser = argparse.ArgumentParser() 91 | parser.add_argument( 92 | '--img_paths', 93 | required=True, 94 | type=str, 95 | nargs=2, 96 | help='path to images to be recognized') 97 | parser.add_argument( 98 | '--det_model_dir', 99 | required=True, 100 | type=str, 101 | help='path to pore detection trained model') 102 | parser.add_argument( 103 | '--desc_model_dir', 104 | required=True, 105 | type=str, 106 | help='path to pore description trained model') 107 | parser.add_argument( 108 | '--score_thr', 109 | default=2, 110 | type=int, 111 | help='score threshold to determine if pair is genuine or impostor') 112 | parser.add_argument( 113 | '--det_patch_size', default=17, type=int, help='detection patch size') 114 | parser.add_argument( 115 | '--det_prob_thr', 116 | default=0.9, 117 | type=float, 118 | help='probability threshold for discarding detections') 119 | parser.add_argument( 120 | '--nms_inter_thr', 121 | default=0.1, 122 | type=float, 123 | help='NMS area intersection threshold') 124 | parser.add_argument( 125 | '--desc_patch_size', 126 | default=32, 127 | type=int, 128 | help='patch size around each detected keypoint to describe') 129 | 130 | FLAGS = parser.parse_args() 131 | 132 | main() 133 | -------------------------------------------------------------------------------- /models/detection.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | 4 | class Net: 5 | ''' 6 | Pore detection model. 7 | 8 | Net.predictions is the model's output op. It has shape 9 | [batch_size, N - 16, M - 16, 1] for an input with shape 10 | [batch_size, N, M, 1]. Net.loss and Net.train are 11 | respectively the model's loss op and training step op, 12 | if built with Net.build_loss and Net.build_train. 13 | ''' 14 | 15 | def __init__(self, 16 | inputs, 17 | dropout_rate=None, 18 | reuse=tf.AUTO_REUSE, 19 | training=True, 20 | scope='detection'): 21 | ''' 22 | Args: 23 | inputs: input placeholder of shape [None, None, None, 1]. 24 | dropout_rate: if None, applies no dropout. Otherwise, 25 | dropout rate of second to last layer is dropout_rate. 26 | reuse: whether to reuse net variables in scope. 27 | training: whether net is training. Required for dropout 28 | and batch normalization. 29 | scope: model's variable scope. 
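      Example (hypothetical usage sketch, assuming an 'image_pl'
      placeholder of shape [None, None, None, 1], e.g. one returned by
      utils.placeholder_inputs()):

        net = Net(image_pl, training=False)
        # net.predictions holds per-pixel pore probabilities, with shape
        # [batch_size, N - 16, M - 16, 1] for inputs of shape
        # [batch_size, N, M, 1]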
30 | ''' 31 | with tf.variable_scope(scope, reuse=reuse): 32 | # reduction convolutions 33 | net = inputs 34 | filters_ls = [32, 64, 128] 35 | i = 1 36 | for filters in filters_ls: 37 | # ith conv layer 38 | net = tf.layers.conv2d( 39 | net, 40 | filters=filters, 41 | kernel_size=3, 42 | strides=1, 43 | padding='valid', 44 | activation=tf.nn.relu, 45 | use_bias=False, 46 | name='conv_{}'.format(i), 47 | reuse=reuse) 48 | net = tf.layers.batch_normalization( 49 | net, training=training, name='batchnorm_{}'.format(i), reuse=reuse) 50 | net = tf.layers.max_pooling2d( 51 | net, pool_size=3, strides=1, name='maxpool_{}'.format(i)) 52 | 53 | i += 1 54 | 55 | # dropout 56 | if dropout_rate is not None: 57 | net = tf.layers.dropout(net, rate=dropout_rate, training=training) 58 | 59 | # logits 60 | net = tf.layers.conv2d( 61 | net, 62 | filters=1, 63 | kernel_size=5, 64 | strides=1, 65 | padding='valid', 66 | activation=None, 67 | use_bias=False, 68 | name='conv_{}'.format(i), 69 | reuse=reuse) 70 | net = tf.layers.batch_normalization( 71 | net, training=training, name='batchnorm_{}'.format(i), reuse=reuse) 72 | self.logits = tf.identity(net, name='logits') 73 | 74 | # build prediction op 75 | self.predictions = tf.nn.sigmoid(self.logits) 76 | 77 | def build_loss(self, labels): 78 | ''' 79 | Builds the model's loss node op. The loss is the 80 | cross-entropy between the model's predictions and 81 | the labels. 82 | 83 | Args: 84 | labels: labels placeholder of shape [batch_size, 1]. 85 | 86 | Returns: 87 | the loss node op. 88 | ''' 89 | # reshape labels to be compatible with logits 90 | labels = tf.reshape(labels, tf.shape(self.logits)) 91 | 92 | # cross entropy loss 93 | xentropy = tf.nn.sigmoid_cross_entropy_with_logits( 94 | labels=labels, logits=self.logits, name='xentropy') 95 | self.loss = tf.reduce_mean(xentropy, name='xentropy_mean') 96 | 97 | return self.loss 98 | 99 | def build_train(self, learning_rate): 100 | ''' 101 | Builds the model's training step op. It minimizes the 102 | model's loss with Stochastic Gradient Descent (SGD) with 103 | an exponentially decayed learning rate and updates the means 104 | and variances of batch normalization. 105 | 106 | Args: 107 | learning_rate: initial SGD learning rate. 
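    Returns:
      the training step node op.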
108 | ''' 109 | global_step = tf.Variable(1, name='global_step', trainable=False) 110 | learning_rate = tf.train.exponential_decay( 111 | learning_rate, 112 | global_step, 113 | decay_rate=0.96, 114 | decay_steps=2000, 115 | staircase=True) 116 | optimizer = tf.train.GradientDescentOptimizer(learning_rate) 117 | update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) 118 | with tf.control_dependencies(update_ops): 119 | self.train = optimizer.minimize(self.loss, global_step=global_step) 120 | 121 | return self.train 122 | -------------------------------------------------------------------------------- /unit_test/restore_model.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | 4 | import utils 5 | from models import description 6 | from models import detection 7 | 8 | 9 | def restore_description(): 10 | # create network graph 11 | inputs, _ = utils.placeholder_inputs() 12 | net = description.Net(inputs) 13 | 14 | # save random weights and keep them 15 | # in program's memory for comparison 16 | vars_ = [] 17 | saver = tf.train.Saver() 18 | with tf.Session() as sess: 19 | # initialize variables 20 | sess.run(tf.global_variables_initializer()) 21 | 22 | # assign random values to variables 23 | # and save those values for comparison 24 | for var in sorted(tf.global_variables(), key=lambda x: x.name): 25 | # create random values for variable 26 | var_val = np.random.random(var.shape) 27 | 28 | # save for later comparison 29 | vars_.append(var_val) 30 | 31 | # assign it to tf var 32 | assign = tf.assign(var, var_val) 33 | sess.run(assign) 34 | 35 | # save initialized model 36 | saver.save(sess, '/tmp/description/model.ckpt', global_step=0) 37 | 38 | # create new session to restore saved weights 39 | with tf.Session() as sess: 40 | # make new initialization of weights 41 | sess.run(tf.global_variables_initializer()) 42 | 43 | # assert weights are different 44 | i = 0 45 | for var in sorted(tf.global_variables(), key=lambda x: x.name): 46 | # get new var val 47 | var_val = sess.run(var) 48 | 49 | # compare with old one 50 | assert not np.isclose(np.sum(np.abs(var_val - vars_[i])), 0) 51 | 52 | i += 1 53 | 54 | # restore model 55 | utils.restore_model(sess, '/tmp/description') 56 | 57 | # check if weights are equal 58 | i = 0 59 | for var in sorted(tf.global_variables(), key=lambda x: x.name): 60 | # get new var val 61 | var_val = sess.run(var) 62 | 63 | # compare with old one 64 | if not np.allclose(var_val, vars_[i]): 65 | print(np.isclose(var_val, vars_[i])) 66 | print('Failed to load variable "{}"'.format(var.name)) 67 | return False 68 | 69 | i += 1 70 | 71 | return True 72 | 73 | 74 | def restore_detection(): 75 | # create network graph 76 | inputs, _ = utils.placeholder_inputs() 77 | net = detection.Net(inputs) 78 | 79 | # save random weights and keep them 80 | # in program's memory for comparison 81 | vars_ = [] 82 | saver = tf.train.Saver() 83 | with tf.Session() as sess: 84 | # initialize variables 85 | sess.run(tf.global_variables_initializer()) 86 | 87 | # assign random values to variables 88 | # and save those values for comparison 89 | for var in sorted(tf.global_variables(), key=lambda x: x.name): 90 | # create random values for variable 91 | var_val = np.random.random(var.shape) 92 | 93 | # save for later comparison 94 | vars_.append(var_val) 95 | 96 | # assign it to tf var 97 | assign = tf.assign(var, var_val) 98 | sess.run(assign) 99 | 100 | # save initialized model 101 | saver.save(sess,
'/tmp/detection/model.ckpt', global_step=0) 102 | 103 | # create new session to restore saved weights 104 | with tf.Session() as sess: 105 | # make new initialization of weights 106 | sess.run(tf.global_variables_initializer()) 107 | 108 | # assert weights are different 109 | i = 0 110 | for var in sorted(tf.global_variables(), key=lambda x: x.name): 111 | # get new var val 112 | var_val = sess.run(var) 113 | 114 | # compare with old one 115 | assert not np.isclose(np.sum(np.abs(var_val - vars_[i])), 0) 116 | 117 | i += 1 118 | 119 | # restore model 120 | utils.restore_model(sess, '/tmp/detection') 121 | 122 | # check if weights are equal 123 | i = 0 124 | for var in sorted(tf.global_variables(), key=lambda x: x.name): 125 | # get new var val 126 | var_val = sess.run(var) 127 | 128 | # compare with old one 129 | if not np.allclose(var_val, vars_[i]): 130 | print(np.isclose(var_val, vars_[i])) 131 | print('Failed to load variable "{}"'.format(var.name)) 132 | return False 133 | 134 | i += 1 135 | 136 | return True 137 | 138 | 139 | if __name__ == '__main__': 140 | assert restore_description() 141 | print('[OK - Description Model Restoration]') 142 | 143 | tf.reset_default_graph() 144 | 145 | assert restore_detection() 146 | print('[OK - Detection Model Restoration]') 147 | -------------------------------------------------------------------------------- /models/description.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | 4 | class Net: 5 | ''' 6 | Adapted HardNet (Working hard to know your neighbor's margins: 7 | Local descriptor learning loss, 2017) model. 8 | 9 | Differs from HardNet in its loss function, using instead 10 | triplet semi-hard loss (FaceNet: A Unified Embedding for 11 | Face Recognition and Clustering, 2015). 12 | 13 | Net.descriptors is the model's output op. It has shape 14 | [batch_size, 128]. Net.loss and Net.train are respectively the 15 | model's loss op and training step op, if built with 16 | Net.build_loss and Net.build_train. 17 | ''' 18 | 19 | def __init__(self, 20 | inputs, 21 | dropout_rate=None, 22 | reuse=tf.AUTO_REUSE, 23 | training=True, 24 | scope='description'): 25 | ''' 26 | Args: 27 | inputs: input placeholder of shape [None, None, None, 1]. 28 | dropout_rate: if None, applies no dropout. Otherwise, 29 | dropout rate of second to last layer is dropout_rate. 30 | reuse: whether to reuse net variables in scope. 31 | training: whether net is training. Required for dropout 32 | and batch normalization. 33 | scope: model's variable scope.
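      Example (hypothetical usage sketch, assuming a 'patches_pl' float
      placeholder for a batch of grayscale patches, e.g. of shape
      [None, 32, 32, 1]):

        net = Net(patches_pl, training=False)
        # net.descriptors holds L2-normalized descriptors of shape
        # [batch_size, 128]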
34 | ''' 35 | self.loss = None 36 | self.train = None 37 | self.validation = None 38 | 39 | # capture scope 40 | self.scope = scope 41 | 42 | with tf.variable_scope(scope, reuse=reuse): 43 | # conv layers 44 | net = inputs 45 | filters_ls = [32, 32, 64, 64, 128, 128] 46 | strides_ls = [1, 1, 2, 1, 2, 1] 47 | i = 1 48 | for filters, strides in zip(filters_ls, strides_ls): 49 | net = tf.layers.conv2d( 50 | net, 51 | filters=filters, 52 | kernel_size=3, 53 | strides=strides, 54 | padding='same', 55 | activation=tf.nn.relu, 56 | use_bias=False, 57 | name='conv_{}'.format(i), 58 | reuse=reuse) 59 | net = tf.layers.batch_normalization( 60 | net, training=training, name='batchnorm_{}'.format(i), reuse=reuse) 61 | 62 | i += 1 63 | 64 | # dropout 65 | if dropout_rate is not None: 66 | net = tf.layers.dropout(net, rate=dropout_rate, training=training) 67 | 68 | # last conv layer 69 | net = tf.layers.conv2d( 70 | net, 71 | filters=128, 72 | kernel_size=8, 73 | strides=1, 74 | padding='valid', 75 | activation=None, 76 | use_bias=False, 77 | name='conv_{}'.format(i), 78 | reuse=reuse) 79 | net = tf.layers.batch_normalization( 80 | net, training=training, name='batchnorm_{}'.format(i), reuse=reuse) 81 | 82 | # descriptors 83 | spatial_descriptors = tf.nn.l2_normalize( 84 | net, axis=-1, name='spatial_descriptors') 85 | self.descriptors = tf.reshape( 86 | spatial_descriptors, [-1, 128], name='descriptors') 87 | 88 | def build_loss(self, labels, decay_weight=None): 89 | ''' 90 | Builds the model's loss node op. If decay_weight is None, 91 | loss is simply the triplet semi-hard loss of the descriptors. 92 | Otherwise, it is the sum of the triplet semi-hard loss with 93 | an L2 weight decay regularization with weight decay_weight. 94 | 95 | Args: 96 | labels: labels placeholder of shape [batch_size, 1]. 97 | decay_weight: if not None, weight of L2 weight decay 98 | regularization. 99 | 100 | Returns: 101 | the loss node op. 102 | ''' 103 | with tf.variable_scope(self.scope, reuse=tf.AUTO_REUSE): 104 | with tf.name_scope('loss'): 105 | # triplet loss 106 | labels = tf.reshape(labels, (-1, )) 107 | self.loss = tf.contrib.losses.metric_learning.triplet_semihard_loss( 108 | labels, self.descriptors) 109 | 110 | # weight decay 111 | if decay_weight is not None: 112 | weight_decay = 0 113 | for var in tf.trainable_variables(self.scope): 114 | if 'kernel' in var.name: 115 | weight_decay += tf.nn.l2_loss(var) 116 | self.loss += decay_weight * weight_decay 117 | 118 | return self.loss 119 | 120 | def build_train(self, learning_rate): 121 | ''' 122 | Builds the model's training step node op. It minimizes 123 | the model's loss with Stochastic Gradient Descent (SGD) 124 | with a constant learning rate and updates the running 125 | mean and variances of batch normalization. 126 | 127 | Args: 128 | learning_rate: SGD learning rate. 
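    Returns:
      the training step node op.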
129 | ''' 130 | with tf.variable_scope(self.scope, reuse=tf.AUTO_REUSE): 131 | with tf.name_scope('train'): 132 | global_step = tf.Variable(1, name='global_step', trainable=False) 133 | optimizer = tf.train.GradientDescentOptimizer(learning_rate) 134 | update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) 135 | with tf.control_dependencies(update_ops): 136 | self.train = optimizer.minimize(self.loss, global_step=global_step) 137 | 138 | return self.train 139 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import numpy as np 4 | import tensorflow as tf 5 | 6 | from models import description 7 | import polyu 8 | import utils 9 | import validate 10 | 11 | FLAGS = None 12 | 13 | 14 | def train(dataset, log_dir): 15 | with tf.Graph().as_default(): 16 | # gets placeholders for images and labels 17 | images_pl, labels_pl = utils.placeholder_inputs() 18 | 19 | # build net graph 20 | net = description.Net(images_pl, FLAGS.dropout) 21 | 22 | # build training related ops 23 | net.build_loss(labels_pl, FLAGS.weight_decay) 24 | net.build_train(FLAGS.learning_rate) 25 | 26 | # builds validation graph 27 | val_net = description.Net(images_pl, training=False, reuse=True) 28 | 29 | # add summary to plot loss and rank 30 | eer_pl = tf.placeholder(tf.float32, shape=(), name='eer_pl') 31 | loss_pl = tf.placeholder(tf.float32, shape=(), name='loss_pl') 32 | eer_summary_op = tf.summary.scalar('eer', eer_pl) 33 | loss_summary_op = tf.summary.scalar('loss', loss_pl) 34 | 35 | # early stopping vars 36 | best_eer = 1 37 | faults = 0 38 | saver = tf.train.Saver() 39 | with tf.Session() as sess: 40 | # initialize summary and variables 41 | summary_writer = tf.summary.FileWriter(log_dir, sess.graph) 42 | sess.run(tf.global_variables_initializer()) 43 | 44 | # 'compute_descriptors' function for validation 45 | def compute_descriptors(img, pts): 46 | return utils.trained_descriptors( 47 | img, 48 | pts, 49 | patch_size=dataset.train.images_shape[1], 50 | session=sess, 51 | imgs_pl=images_pl, 52 | descs_op=val_net.descriptors) 53 | 54 | # train loop 55 | for step in range(1, FLAGS.steps + 1): 56 | # fill feed dict 57 | feed_dict = utils.fill_feed_dict(dataset.train, images_pl, labels_pl, 58 | FLAGS.batch_size, FLAGS.augment) 59 | # train step 60 | loss_value, _ = sess.run([net.loss, net.train], feed_dict=feed_dict) 61 | 62 | # write loss summary periodically 63 | if step % 100 == 0: 64 | print('Step {}: loss = {}'.format(step, loss_value)) 65 | 66 | # summarize loss 67 | loss_summary = sess.run( 68 | loss_summary_op, feed_dict={loss_pl: loss_value}) 69 | summary_writer.add_summary(loss_summary, step) 70 | 71 | # evaluate model periodically 72 | if step % 500 == 0 and dataset.val is not None: 73 | print('Validation:') 74 | eer = validate.matching.validation_eer(dataset.val, 75 | compute_descriptors) 76 | print('EER = {}'.format(eer)) 77 | 78 | # summarize eer 79 | eer_summary = sess.run(eer_summary_op, feed_dict={eer_pl: eer}) 80 | summary_writer.add_summary(eer_summary, global_step=step) 81 | 82 | # early stopping 83 | if eer < best_eer: 84 | # update early stopping vars 85 | best_eer = eer 86 | faults = 0 87 | saver.save( 88 | sess, os.path.join(log_dir, 'model.ckpt'), global_step=step) 89 | else: 90 | faults += 1 91 | if faults >= FLAGS.tolerance: 92 | print('Training stopped early') 93 | break 94 | 95 | # if no validation set, save model when training 
completes 96 | if dataset.val is None: 97 | saver.save(sess, os.path.join(log_dir, 'model.ckpt')) 98 | 99 | print('Finished') 100 | print('best EER = {}'.format(best_eer)) 101 | 102 | 103 | def main(): 104 | # create folders to save train resources 105 | log_dir = utils.create_dirs(FLAGS.log_dir_path, FLAGS.batch_size, 106 | FLAGS.learning_rate) 107 | 108 | # load dataset 109 | print('Loading description dataset...') 110 | dataset = polyu.description.Dataset(FLAGS.dataset_path) 111 | print('Loaded') 112 | 113 | # train 114 | train(dataset, log_dir) 115 | 116 | 117 | if __name__ == '__main__': 118 | parser = argparse.ArgumentParser() 119 | parser.add_argument( 120 | '--dataset_path', required=True, type=str, help='path to dataset') 121 | parser.add_argument( 122 | '--learning_rate', type=float, default=1e-1, help='learning rate') 123 | parser.add_argument( 124 | '--log_dir_path', type=str, default='log', help='logging directory') 125 | parser.add_argument( 126 | '--tolerance', type=int, default=5, help='early stopping tolerance') 127 | parser.add_argument('--batch_size', type=int, default=256, help='batch size') 128 | parser.add_argument( 129 | '--steps', type=int, default=100000, help='maximum training steps') 130 | parser.add_argument( 131 | '--augment', 132 | action='store_true', 133 | help='use this flag to perform dataset augmentation') 134 | parser.add_argument( 135 | '--dropout', type=float, help='dropout rate in last convolutional layer') 136 | parser.add_argument('--weight_decay', type=float, help='weight decay lambda') 137 | parser.add_argument('--seed', type=int, help='random seed') 138 | 139 | FLAGS = parser.parse_args() 140 | 141 | # set random seeds 142 | tf.set_random_seed(FLAGS.seed) 143 | np.random.seed(FLAGS.seed) 144 | 145 | main() 146 | -------------------------------------------------------------------------------- /polyu/preprocess.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import numpy as np 4 | import tensorflow as tf 5 | import cv2 6 | 7 | import utils 8 | from polyu import aligned_images 9 | 10 | FLAGS = None 11 | 12 | 13 | def load_detections(path): 14 | pts = [] 15 | for pts_path in sorted(os.listdir(path)): 16 | if pts_path.endswith('.txt'): 17 | pts.append(utils.load_dets_txt(os.path.join(path, pts_path))) 18 | 19 | return pts 20 | 21 | 22 | def group_by_label(imgs, pts, labels): 23 | grouped_imgs = [] 24 | grouped_pts = [] 25 | unique_labels = np.unique(labels) 26 | for label in unique_labels: 27 | indices = np.where(labels == label) 28 | grouped_imgs.append(imgs[indices]) 29 | grouped_pts.append(pts[indices]) 30 | labels = unique_labels 31 | 32 | return grouped_imgs, grouped_pts, labels 33 | 34 | 35 | def create_dirs(path, should_create_val): 36 | # create train path 37 | train_path = os.path.join(path, 'train') 38 | if not os.path.exists(train_path): 39 | os.makedirs(train_path) 40 | 41 | # create val path 42 | val_path = os.path.join(path, 'val') 43 | if not os.path.exists(val_path) and should_create_val: 44 | os.makedirs(val_path) 45 | 46 | return train_path, val_path 47 | 48 | 49 | def train_val_split(imgs, pts, labels, names, total_imgs, split): 50 | # build training set 51 | perm = np.random.permutation(len(imgs)) 52 | train_imgs = [] 53 | train_pts = [] 54 | train_labels = [] 55 | train_size = 0 56 | i = 0 57 | while train_size < split * total_imgs: 58 | # add 'perm[i]'th-element to train set 59 | train_size += len(imgs[perm[i]]) 60 | train_imgs.append(imgs[perm[i]]) 61 | 
train_pts.append(pts[perm[i]]) 62 | train_labels.append(labels[perm[i]]) 63 | i += 1 64 | 65 | # build validation set 66 | val_imgs = [] 67 | val_pts = [] 68 | val_labels = [] 69 | val_names = [] 70 | for j in perm[i:]: 71 | val_imgs.append(imgs[j]) 72 | val_pts.append(pts[j]) 73 | val_labels.append(labels[j]) 74 | val_names.append(names[j]) 75 | 76 | # assert that both sets do not have 77 | # overlap in subjects identities 78 | assert not set(train_labels).intersection(val_labels) 79 | 80 | train = (train_imgs, train_pts) 81 | val = (val_imgs, val_pts, val_names) 82 | 83 | return train, val 84 | 85 | 86 | def save_patches(grouped_imgs, grouped_pts, path, patch_size): 87 | # 'patch_index' is a unique patch identifier 88 | patch_index = 1 89 | for imgs, pts in zip(grouped_imgs, grouped_pts): 90 | # align images 91 | handler = aligned_images.Handler(imgs, pts, patch_size) 92 | 93 | # extract patches 94 | for patches in handler: 95 | for i, patch in enumerate(patches): 96 | patch_path = os.path.join(path, '{}_{}.png'.format(patch_index, i + 1)) 97 | cv2.imwrite(patch_path, 255 * patch) 98 | patch_index += 1 99 | 100 | 101 | def save_dataset(grouped_imgs, grouped_pts, grouped_names, path): 102 | # save images with name "subject_session_identifier" 103 | for imgs, all_pts, names in zip(grouped_imgs, grouped_pts, grouped_names): 104 | for img, pts, name in zip(imgs, all_pts, names): 105 | # save image 106 | img_path = os.path.join(path, name + '.png') 107 | cv2.imwrite(img_path, 255 * img) 108 | 109 | # save detections 110 | pts_path = os.path.join(path, name + '.txt') 111 | utils.save_dets_txt(pts, pts_path) 112 | 113 | 114 | def main(): 115 | # load detections 116 | print('Loading detections...') 117 | pts_dir_path = os.path.join(FLAGS.pts_dir_path, 'DBI', 'Training') 118 | pts = load_detections(pts_dir_path) 119 | print('Done') 120 | 121 | # load images with names and retrieve labels 122 | print('Loading images...') 123 | imgs_dir_path = os.path.join(FLAGS.polyu_dir_path, 'DBI', 'Training') 124 | imgs, names = utils.load_images_with_names(imgs_dir_path) 125 | name2label = utils.retrieve_label_from_image_path 126 | labels = [name2label(name) for name in names] 127 | print('Done') 128 | 129 | # convert to np array 130 | imgs = np.array(imgs) 131 | pts = np.array(pts) 132 | labels = np.array(labels) 133 | 134 | # group (imgs, pts, names) by label 135 | print('Grouping (images, detections) by label...') 136 | total_imgs = len(imgs) 137 | imgs, pts, labels = group_by_label(imgs, pts, labels) 138 | names = [[name for name in names if name2label(name) == label] 139 | for label in labels] 140 | print('Done') 141 | 142 | # create 'train' & 'val' folders 143 | print('Creating directory tree...') 144 | should_create_val = FLAGS.split < 1 145 | train_path, val_path = create_dirs(FLAGS.result_dir_path, should_create_val) 146 | print('Done') 147 | 148 | # split dataset into train/val 149 | print('Splitting dataset...') 150 | train, val = train_val_split(imgs, pts, labels, names, total_imgs, 151 | FLAGS.split) 152 | train_imgs, train_pts = train 153 | val_imgs, val_pts, val_names = val 154 | print('Done') 155 | 156 | # extract and save all patches from train images 157 | print('Creating training set patches...') 158 | save_patches(train_imgs, train_pts, train_path, FLAGS.patch_size) 159 | print('Done') 160 | 161 | # save validation images 162 | print('Saving validation images...') 163 | save_dataset(val_imgs, val_pts, val_names, val_path) 164 | print('Done') 165 | 166 | 167 | if __name__ == '__main__': 168 
| # parse args
169 | parser = argparse.ArgumentParser()
170 | parser.add_argument(
171 | '--polyu_dir_path',
172 | required=True,
173 | type=str,
174 | help='path to PolyU-HRF dataset')
175 | parser.add_argument(
176 | '--pts_dir_path',
177 | type=str,
178 | required=True,
179 | help='path to PolyU-HRF DBI Training dataset keypoints detections')
180 | parser.add_argument(
181 | '--patch_size',
182 | type=int,
183 | required=True,
184 | help='image patch size for descriptor')
185 | parser.add_argument(
186 | '--result_dir_path',
187 | type=str,
188 | required=True,
189 | help='path to save description dataset')
190 | parser.add_argument(
191 | '--split',
192 | default=0.6,
193 | type=float,
194 | help='floating point percentage of training set in train/val split')
195 | parser.add_argument('--seed', type=int, help='random seed')
196 |
197 | FLAGS = parser.parse_args()
198 |
199 | # set random seeds
200 | tf.set_random_seed(FLAGS.seed)
201 | np.random.seed(FLAGS.seed)
202 |
203 | main()
204 |
--------------------------------------------------------------------------------
/polyu/description.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 |
4 | import utils
5 |
6 |
7 | class TrainingSet:
8 | '''
9 | PolyU-HRF description training set handler. Manages images and
10 | corresponding labels in batches, possibly shuffled (shuffles
11 | the training data between epochs), balanced (same number of
12 | examples per label in each batch) and incomplete (if the batch
13 | size is greater than the number of available examples in the
14 | epoch, return only the examples in the current epoch).
15 | '''
16 |
17 | def __init__(self, images_by_labels, labels, should_shuffle,
18 | balanced_batches, incomplete_batches):
19 | '''
20 | Args:
21 | images_by_labels: list of images grouped by labels.
22 | labels: labels aligned to images_by_labels.
23 | should_shuffle: whether the training set should be
24 | shuffled between epochs.
25 | balanced_batches: whether batches should have the
26 | same number of examples per label.
27 | incomplete_batches: whether batches should avoid
28 | examples from different epochs.
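
Example (an illustrative sketch; in practice, Dataset, below,
builds these arguments from images loaded from disk):
  train = TrainingSet(images_by_labels, labels,
                      should_shuffle=True,
                      balanced_batches=True,
                      incomplete_batches=False)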
29 | ''' 30 | self.n_labels = len(labels) 31 | self._shuffle = should_shuffle 32 | self._balance = balanced_batches 33 | self._incomplete = incomplete_batches 34 | self.images_shape = images_by_labels[0][0].shape 35 | 36 | if self._balance: 37 | # images must be separated by labels 38 | self._images = np.array(images_by_labels) 39 | self._labels = np.array(labels) 40 | 41 | # count images 42 | self.n_images = 0 43 | for images in self._images: 44 | self.n_images += len(images) 45 | 46 | # images per label 47 | self._imgs_per_label = self.n_images // self.n_labels 48 | 49 | # shuffle for first epoch 50 | if self._shuffle: 51 | perm = np.random.permutation(self.n_labels) 52 | self._images = self._images[perm] 53 | self._labels = self._labels[perm] 54 | 55 | # initialize data pointers 56 | self.epochs = 0 57 | self._index = 0 58 | else: 59 | # images can be flattened 60 | self._images = np.reshape( 61 | images_by_labels, (-1, self.images_shape[0], self.images_shape[1])) 62 | self._labels = np.repeat(labels, len(self._images) // len(labels)) 63 | self.n_images = len(self._images) 64 | 65 | # shuffle for first epoch 66 | if self._shuffle: 67 | perm = np.random.permutation(self.n_images) 68 | self._images = self._images[perm] 69 | self._labels = self._labels[perm] 70 | 71 | # initialize dataset pointers 72 | self.epochs = 0 73 | self._index = 0 74 | 75 | def next_batch(self, batch_size): 76 | ''' 77 | Samples a mini-batch of size batch_size. If balanced_batches 78 | was True, then batches are sampled with equal label distribution. 79 | If incomplete_batches was True, batches are only sampled inside 80 | epochs, even if this means eventually sampling smaller batches. 81 | 82 | Args: 83 | batch_size: sampled mini-batch size. 84 | 85 | Returns: 86 | batch_images: sampled images. 87 | batch_labels: sampled labels. 
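
Example (an illustrative sketch, assuming `dataset` is a
description Dataset with a balanced, shuffled training set):
  batch_images, batch_labels = dataset.train.next_batch(36)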
88 | '''
89 | # adjust for balanced batch sampling
90 | if not self._balance:
91 | end = self.n_images
92 | else:
93 | batch_size = batch_size // self._imgs_per_label
94 | full_batch_size = batch_size * self._imgs_per_label
95 | end = self.n_labels
96 |
97 | if self._index + batch_size >= end:
98 | # finished epoch
99 | self.epochs += 1
100 |
101 | # get remainder of examples in this epoch
102 | start = self._index
103 | images_rest_part = self._images[start:]
104 | labels_rest_part = self._labels[start:]
105 | rest_num_images = end - start
106 |
107 | # shuffle the data
108 | if self._shuffle:
109 | perm = np.random.permutation(end)
110 | self._images = self._images[perm]
111 | self._labels = self._labels[perm]
112 |
113 | # handle incomplete batches
114 | if self._incomplete:
115 | # return incomplete batch
116 | self._index = 0
117 | batch_images = images_rest_part
118 | batch_labels = labels_rest_part
119 | else:
120 | # start next epoch
121 | self._index = batch_size - rest_num_images
122 |
123 | # retrieve observations in new epoch
124 | images_new_part = self._images[0:self._index]
125 | labels_new_part = self._labels[0:self._index]
126 |
127 | batch_images = np.concatenate(
128 | [images_rest_part, images_new_part], axis=0)
129 | batch_labels = np.concatenate(
130 | [labels_rest_part, labels_new_part], axis=0)
131 | else:
132 | start = self._index
133 | self._index += batch_size
134 | batch_images = self._images[start:self._index]
135 | batch_labels = self._labels[start:self._index]
136 |
137 | # reshape balanced batches
138 | if self._balance:
139 | batch_images = np.reshape(batch_images,
140 | (full_batch_size, ) + batch_images.shape[2:])
141 | batch_labels = np.repeat(batch_labels,
142 | len(batch_images) // len(batch_labels))
143 |
144 | return batch_images, batch_labels
145 |
146 |
147 | class ValidationSet:
148 | '''
149 | PolyU-HRF description validation set handler. Manages images
150 | and corresponding labels and detections in batches.
151 | ValidationSet.__getitem__ provides access to instances,
152 | returning aligned images, detections and labels.
153 | '''
154 |
155 | def __init__(self, images, detections, labels):
156 | '''
157 | Args:
158 | images: validation images.
159 | detections: corresponding detections.
160 | labels: corresponding labels.
161 | '''
162 | self._images = np.array(images)
163 | self._detections = np.array(detections)
164 | self._labels = np.array(labels)
165 |
166 | def __getitem__(self, val):
167 | return self._images[val], self._detections[val], self._labels[val]
168 |
169 |
170 | class Dataset:
171 | '''
172 | PolyU-HRF description dataset handler. Contains a TrainingSet, as
173 | Dataset.train, and a ValidationSet, as Dataset.val, if it exists
174 | in the provided dataset path.
175 | '''
176 |
177 | def __init__(self, path, should_shuffle=True, balanced_batches=True):
178 | '''
179 | Args:
180 | path: path to preprocessed polyu description dataset that has
181 | a train subfolder with properly annotated images.
182 | should_shuffle: whether TrainingSet should shuffle its data
183 | between epochs.
184 | balanced_batches: whether TrainingSet should sample batches
185 | with the same number of examples per label.
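
Example (a minimal sketch; 'log/patch_polyu' is the output
directory suggested for polyu.preprocess in the README):
  dataset = Dataset('log/patch_polyu')
  images, labels = dataset.train.next_batch(36)
  if dataset.val is not None:
    img, detections, label = dataset.val[0]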
186 | '''
187 | # split paths
188 | train_path = os.path.join(path, 'train')
189 | val_path = os.path.join(path, 'val')
190 |
191 | # load training set
192 | images, labels = utils.load_images_with_labels(train_path)
193 | images_by_labels, labels = self._group_images_by_labels(images, labels)
194 | self.train = TrainingSet(
195 | images_by_labels,
196 | labels,
197 | should_shuffle=should_shuffle,
198 | balanced_batches=balanced_batches,
199 | incomplete_batches=False)
200 |
201 | # load validation set, if any
202 | self.val = None
203 | if os.path.exists(val_path):
204 | images, detections, labels = self._load_validation_data(val_path)
205 | self.val = ValidationSet(images, detections, labels)
206 |
207 | def _group_images_by_labels(self, images, labels):
208 | # convert to np array
209 | images = np.array(images)
210 | labels = np.array(labels)
211 |
212 | grouped_images = []
213 | all_labels = np.unique(labels)
214 | for label in all_labels:
215 | indices = np.where(labels == label)
216 | grouped_images.append(images[indices])
217 | return grouped_images, all_labels
218 |
219 | def _load_validation_data(self, val_path):
220 | # load images with respective names
221 | images, names = utils.load_images_with_names(val_path)
222 |
223 | # convert 'names' to validation 'labels'
224 | # each 'name' in 'names' is 'subject-id_session-id_register-id'
225 | labels = [tuple(map(int, name.split('_'))) for name in names]
226 |
227 | # load detections, aligned with images and labels
228 | paths = map(lambda name: os.path.join(val_path, name), names)
229 | detections = [utils.load_dets_txt(path + '.txt') for path in paths]
230 |
231 | return images, detections, labels
232 |
--------------------------------------------------------------------------------
/align.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | import utils
4 |
5 |
6 | def _inside(img, pt):
7 | return 0 <= pt[0] < img.shape[0] and 0 <= pt[1] < img.shape[1]
8 |
9 |
10 | def _transf(pt, A, s, b):
11 | return s * np.dot(pt, A.T) + b
12 |
13 |
14 | def _inv_transf(pt, A, s, b):
15 | return np.dot(pt - b, A) / s
16 |
17 |
18 | def _horn(L, R, weights=None, scale=True):
19 | '''
20 | Use Horn's closed form absolute orientation
21 | method with orthogonal matrices to align a set
22 | of points L to a set of points R. L and R are
23 | sets in which each row is a point and the ith
24 | point of L corresponds to the ith point of R.
25 |
26 | Args:
27 | L: set of points to align, a point per row.
28 | R: set of points to align to, a point per row.
29 | weights: an array of weights in [0, 1] for
30 | each correspondence. If 'None',
31 | correspondences are unweighted.
32 | scale: whether to also solve for the scale. If
33 | 'False', s = 1.
34 |
35 | Returns:
36 | A, b, s: a transformation such that
37 | s * (L @ A.T) + b ~ R,
38 | making the sum of squares of errors the
39 | least possible. See the paper (Horn et al.,
40 | 1988) for details.
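
Example (an illustrative sketch of the returned convention,
matching _transf above):
  A, b, s = _horn(L, R)
  L_aligned = s * np.dot(L, A.T) + b  # approximately R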
41 | '''
42 | L = np.asarray(L)
43 | R = np.asarray(R)
44 | if weights is None:
45 | weights = np.ones((len(L), 1))
46 | else:
47 | weights = np.expand_dims(weights, axis=-1)
48 |
49 | # compute points' weighted centroids
50 | L_centroid = np.sum(weights * L, axis=0) / np.sum(weights)
51 | R_centroid = np.sum(weights * R, axis=0) / np.sum(weights)
52 |
53 | # translate points to make centroids the origin
54 | L_ = L - L_centroid
55 | R_ = R - R_centroid
56 |
57 | # find scale 's'
58 | if scale:
59 | Sr = np.sum(np.reshape(weights, -1) * np.sum(R_ * R_, axis=1), axis=0)
60 | Sl = np.sum(np.reshape(weights, -1) * np.sum(L_ * L_, axis=1), axis=0)
61 | s = np.sqrt(Sr / Sl)
62 | else:
63 | s = 1
64 |
65 | # find rotation 'A'
66 | M = np.dot(R_.T, weights * L_)
67 | MTM = np.dot(M.T, M)
68 | w, u = np.linalg.eigh(MTM)
69 | u = u.T
70 |
71 | # solve even if M is not full rank
72 | rank = M.shape[0] - np.sum(np.isclose(w, 0))
73 | if rank < M.shape[0] - 1:
74 | raise Exception(
75 | 'rank of M must be at least shape(M)[0] - 1 for a unique solution')
76 | elif rank == M.shape[0] - 1:
77 | # get non-zero eigenvalues and eigenvectors
78 | mask = np.logical_not(np.isclose(w, 0))
79 | u = u[mask]
80 | w = w[mask]
81 |
82 | # compute S pseudoinverse
83 | w = np.expand_dims(w, 1)
84 | lhs = np.expand_dims(u / np.sqrt(w), axis=-1)
85 | rhs = np.expand_dims(u, axis=1)
86 | S_pseudoinv = np.sum(np.matmul(lhs, rhs), axis=0)
87 |
88 | # compute MTM svd
89 | u, _, v = np.linalg.svd(MTM)
90 |
91 | # compute Tomasi's fix
92 | lhs = np.expand_dims(u[-1], axis=-1)
93 | rhs = np.expand_dims(v[-1], axis=0)
94 | tfix = np.dot(lhs, rhs)
95 |
96 | # determine whether to add or subtract the fix
97 | A_base = np.dot(M, S_pseudoinv)
98 | A = A_base + tfix
99 | if np.linalg.det(A) < 0:
100 | A = A_base - tfix
101 | else:
102 | w = np.expand_dims(w, 1)
103 | lhs = np.expand_dims(u / np.sqrt(w), axis=-1)
104 | rhs = np.expand_dims(u, axis=1)
105 | S_inv = np.sum(np.matmul(lhs, rhs), axis=0)
106 | A = np.dot(M, S_inv)
107 |
108 | # find translation 'b'
109 | b = R_centroid - s * np.dot(A, L_centroid)
110 |
111 | return A, b, s
112 |
113 |
114 | def iterative(img1,
115 | pts1,
116 | img2,
117 | pts2,
118 | descs1=None,
119 | descs2=None,
120 | euclidean_lambda=500,
121 | weighted=False,
122 | max_iter=10):
123 | '''
124 | Iteratively align image 'img1' to 'img2', using
125 | Horn's absolute orientation method to minimize
126 | the mean squared error between keypoints'
127 | correspondences in sets 'pts1' and 'pts2'.
128 | Correspondences between keypoints are found
129 | with the following metric:
130 |
131 | d(u, v) = ||SIFT(u) - SIFT(v)||^2 +
132 | (\lambda * ||u - v||^2) / MSE
133 |
134 | where '\lambda' is a user specified weight and
135 | 'MSE' is the mean squared error from the
136 | previous alignment. For the first iteration,
137 | MSE = Inf.
138 |
139 | Args:
140 | img1: np array with image to align.
141 | pts1: np array with img1 keypoints, one
142 | keypoint coordinate per row.
143 | img2: np array with image to align to.
144 | pts2: same as pts1, but for img2.
145 | descs1: precomputed descriptors for img1.
146 | descs2: same as descs1, but for img2.
147 | euclidean_lambda: \lambda in the above equation.
148 | weighted: whether to weight correspondences in
149 | Horn's method by their confidence, computed
150 | as one minus the correspondence distance
151 | normalized by the maximum distance.
152 | max_iter: maximum number of iterations.
153 |
154 | Returns:
155 | A, b, s: the found alignment. For further
156 | information, read _horn() documentation.
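
Example (an illustrative sketch; mirrors the __main__ block at
the end of this file):
  A, b, s = iterative(img1, pts1, img2, pts2)
  aligned_pts1 = [_transf(pt, A, s, b) for pt in pts1]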
157 | '''
158 | # initialize before first alignment
159 | mse = np.inf
160 | euclidean_weight = -1
161 | A = np.identity(2)
162 | s = 1
163 | b = np.array([0, 0])
164 |
165 | # precompute sift descriptors, if not given
166 | if descs1 is None:
167 | descs1 = utils.sift_descriptors(img1, pts1, scale=8)
168 | if descs2 is None:
169 | descs2 = utils.sift_descriptors(img2, pts2, scale=8)
170 |
171 | # iteratively align
172 | for _ in range(max_iter):
173 | # convergence criterion
174 | if np.isclose(mse * euclidean_weight, euclidean_lambda):
175 | break
176 |
177 | # compute weight of correspondences' euclidean distance
178 | euclidean_weight = euclidean_lambda / (mse + 1e-5)
179 |
180 | # find correspondences
181 | pairs = utils.find_correspondences(
182 | descs1,
183 | descs2,
184 | pts1=pts1,
185 | pts2=pts2,
186 | euclidean_weight=euclidean_weight,
187 | transf=lambda x: _transf(x, A, s, b),
188 | thr=0.8)
189 |
190 | # end alignment if no further correspondences are found
191 | if len(pairs) <= 1:
192 | break
193 |
194 | # make correspondence-aligned arrays
195 | if weighted:
196 | max_dist = np.max(np.asarray(pairs)[:, 2])
197 | w = []
198 | L = []
199 | R = []
200 | for pair in pairs:
201 | L.append(pts1[pair[0]])
202 | R.append(pts2[pair[1]])
203 | w.append((max_dist - pair[2]) / max_dist)
204 | else:
205 | w = None
206 | L = []
207 | R = []
208 | for pair in pairs:
209 | L.append(pts1[pair[0]])
210 | R.append(pts2[pair[1]])
211 |
212 | # find alignment transformation
213 | A, b, s = _horn(L, R, weights=w)
214 |
215 | # compute alignment mse
216 | L = np.array(L)
217 | R = np.array(R)
218 | error = R - (s * np.dot(L, A.T) + b)
219 | dists = np.sum(error * error, axis=1)
220 | mse = np.mean(dists)
221 |
222 | # filter points and corresponding descriptors
223 | # that are out of the images' overlap
224 | pts1_ = []
225 | descs1_ = []
226 | for i, pt in enumerate(pts1):
227 | t_pt = _transf(pt, A, s, b)
228 | if _inside(img2, t_pt):
229 | pts1_.append(pt)
230 | descs1_.append(descs1[i])
231 | pts1 = pts1_
232 | descs1 = np.array(descs1_)
233 |
234 | # same for second set
235 | pts2_ = []
236 | descs2_ = []
237 | for i, pt in enumerate(pts2):
238 | t_pt = _inv_transf(pt, A, s, b)
239 | if _inside(img1, t_pt):
240 | pts2_.append(pt)
241 | descs2_.append(descs2[i])
242 | pts2 = pts2_
243 | descs2 = np.array(descs2_)
244 |
245 | return A, b, s
246 |
247 |
248 | if __name__ == '__main__':
249 | import sys
250 | import cv2
251 |
252 | if len(sys.argv) != 5:
253 | raise Exception(
254 | 'Expected 4 arguments: <img1_path> <pts1_path> <img2_path> <pts2_path>, found {}'.
255 | format(len(sys.argv) - 1))
256 |
257 | img1_path, pts1_path, img2_path, pts2_path = sys.argv[1:]
258 |
259 | # load images
260 | img1 = cv2.imread(img1_path, 0)
261 | img2 = cv2.imread(img2_path, 0)
262 |
263 | # load detection points
264 | pts1 = utils.load_dets_txt(pts1_path)
265 | pts2 = utils.load_dets_txt(pts2_path)
266 |
267 | A, b, s = iterative(img1, pts1, img2, pts2)
268 |
269 | # generate aligned images
270 | aligned = np.zeros_like(img1, dtype=img1.dtype)
271 | for ref_row in range(img1.shape[0]):
272 | for ref_col in range(img1.shape[1]):
273 | t_row, t_col = np.dot(A.T, (np.array([ref_row, ref_col]) - b) / s)
274 | if 0 <= t_row < img1.shape[0] - 1 and 0 <= t_col < img1.shape[1] - 1:
275 | aligned[ref_row, ref_col] = utils.bilinear_interpolation(
276 | t_row, t_col, img1)
277 |
278 | # display current alignment
279 | diff = np.stack([img2, img2, aligned], axis=-1)
280 | cv2.imshow('prealignment', img1)
281 | cv2.imshow('target', img2)
282 | cv2.imshow('aligned', aligned)
283 | cv2.imshow('diff', diff)
284 | cv2.waitKey(0)
285 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # High-resolution fingerprint recognition
2 | This repository contains the original implementation of the fingerprint pore detection and description models from [Automatic Dataset Annotation to Learn CNN Pore Description for Fingerprint Recognition](https://arxiv.org/abs/1809.10229).
3 |
4 | ## PolyU-HRF dataset
5 | The Hong Kong Polytechnic University (PolyU) High-Resolution-Fingerprint (HRF) Database is a high-resolution fingerprint dataset for fingerprint recognition. We ran all of our experiments on the PolyU-HRF dataset, so it is required to reproduce them. PolyU-HRF can be obtained by following the instructions from its authors [here](http://www4.comp.polyu.edu.hk/~biometrics/HRF/HRF_old.htm).
6 |
7 | Assuming PolyU-HRF is inside a local directory named `polyu_hrf`, its internal organization must be as follows in order to reproduce our experiments with the code in this repository as it is:
8 | ```
9 | polyu_hrf/
10 | DBI/
11 | Training/
12 | Test/
13 | DBII/
14 | GroundTruth/
15 | PoreGroundTruth/
16 | PoreGroundTruthMarked/
17 | PoreGroundTruthSampleimage/
18 | ```
19 |
20 | ## Requirements
21 | The code in this repository was tested for Ubuntu 16.04 and Python 3.5.2, but we believe any newer version of both will do.
22 |
23 | We recommend installing Python's venv (tested for version 15.0.1) to run the experiments. To do it in Ubuntu 16.04:
24 | ```
25 | sudo apt install python3-venv
26 | ```
27 |
28 | Then, create and activate a venv:
29 | ```
30 | python3 -m venv env
31 | source env/bin/activate
32 | ```
33 |
34 | To install the requirements for CPU usage, run:
35 | ```
36 | pip install -r cpu-requirements.txt
37 | ```
38 | or, for GPU usage, which requires the [Tensorflow GPU dependencies](https://www.tensorflow.org/install/gpu), run:
39 | ```
40 | pip install -r gpu-requirements.txt
41 | ```
42 |
43 | ## Pore description
44 | ### Detecting pores in every image
45 | To train a pore descriptor, pore detections are required for every image.
To do this, run:
46 | ```
47 | python3 -m batch_detect_pores --polyu_dir_path polyu_hrf --model_dir_path log/detection/[det_model_dir] --results_dir_path log/pores
48 | ```
49 | This will detect pores for every image in PolyU-HRF and store them in `[image_name].txt` format inside `log/pores` subfolders `DBI/Training`, `DBI/Test` and `DBII`.
50 |
51 | The options for batch detecting pores are:
52 | ```
53 | usage: batch_detect_pores [-h] --polyu_dir_path POLYU_DIR_PATH
54 | --model_dir_path MODEL_DIR_PATH
55 | [--patch_size PATCH_SIZE]
56 | [--results_dir_path RESULTS_DIR_PATH]
57 | [--prob_thr PROB_THR] [--inter_thr INTER_THR]
58 |
59 | optional arguments:
60 | -h, --help show this help message and exit
61 | --polyu_dir_path POLYU_DIR_PATH
62 | path to PolyU-HRF dataset
63 | --model_dir_path MODEL_DIR_PATH
64 | path from which to restore trained model
65 | --patch_size PATCH_SIZE
66 | pore patch size
67 | --results_dir_path RESULTS_DIR_PATH
68 | path to folder in which results should be saved
69 | --prob_thr PROB_THR probability threshold to filter detections
70 | --inter_thr INTER_THR
71 | nms intersection threshold
72 |
73 | ```
74 |
75 | ### Generating pore identity annotations
76 | It is also required to have pore identity annotations to train a pore descriptor. To do this, run:
77 | ```
78 | python3 -m polyu.preprocess --polyu_dir_path polyu_hrf --pts_dir_path log/pores --patch_size 32 --result_dir_path log/patch_polyu
79 | ```
80 | This splits DBI Training into two subject-independent subsets, one for training and another for validation. The default split creates the training subset with 60% of the subject identities. All other identities go to the validation subset.
81 | `polyu.preprocess` separates these subsets into subfolders of `log/patch_polyu`: `train` and `val`.
82 |
83 | `train` contains pore patch images named `pore-id_register-id.png`. It has exactly 6 images for each pore identity that is visible in every image of the training subset.
84 |
85 | `val` contains the images of the validation subset, with their original names, and their corresponding pore detection files.
86 |
87 | The options for generating pore identity annotations are:
88 | ```
89 | usage: polyu.preprocess [-h] --polyu_dir_path POLYU_DIR_PATH --pts_dir_path
90 | PTS_DIR_PATH --patch_size PATCH_SIZE --result_dir_path
91 | RESULT_DIR_PATH [--split SPLIT] [--seed SEED]
92 |
93 | optional arguments:
94 | -h, --help show this help message and exit
95 | --polyu_dir_path POLYU_DIR_PATH
96 | path to PolyU-HRF dataset
97 | --pts_dir_path PTS_DIR_PATH
98 | path to PolyU-HRF DBI Training dataset keypoints
99 | detections
100 | --patch_size PATCH_SIZE
101 | image patch size for descriptor
102 | --result_dir_path RESULT_DIR_PATH
103 | path to save description dataset
104 | --split SPLIT floating point percentage of training set in train/val
105 | split
106 | --seed SEED random seed
107 | ```
108 |
109 | ### Training the model
110 | To train the pore description model, run:
111 | ```
112 | python3 -m train --dataset_path log/patch_polyu --log_dir_path log/description/ --augment --dropout 0.3
113 | ```
114 | This will train a description model with the hyper-parameters we used for the model in our paper, but we recommend tuning them manually by observing the EER in the validation set. The above values usually provide excellent results. However, if the model fails to achieve 0% EER in the validation set, you should probably investigate other values. Training without augmentation has disastrous results, so always train with it.
115 |
116 | Running the script above will create a folder inside `log/description` for the trained model's resources. We will call it `[desc_model_dir]` for the rest of the instructions.
117 |
118 | Options for training the description model are:
119 | ```
120 | usage: train [-h] --dataset_path DATASET_PATH
121 | [--learning_rate LEARNING_RATE] [--log_dir_path LOG_DIR_PATH]
122 | [--tolerance TOLERANCE] [--batch_size BATCH_SIZE]
123 | [--steps STEPS] [--augment] [--dropout DROPOUT]
124 | [--weight_decay WEIGHT_DECAY] [--seed SEED]
125 |
126 | optional arguments:
127 | -h, --help show this help message and exit
128 | --dataset_path DATASET_PATH
129 | path to dataset
130 | --learning_rate LEARNING_RATE
131 | learning rate
132 | --log_dir_path LOG_DIR_PATH
133 | logging directory
134 | --tolerance TOLERANCE
135 | early stopping tolerance
136 | --batch_size BATCH_SIZE
137 | batch size
138 | --steps STEPS maximum training steps
139 | --augment use this flag to perform dataset augmentation
140 | --dropout DROPOUT dropout rate in last convolutional layer
141 | --weight_decay WEIGHT_DECAY
142 | weight decay lambda
143 | --seed SEED random seed
144 | ```
145 |
146 | ## Fingerprint recognition experiments
147 | ### SIFT descriptors
148 | In order to reproduce the SIFT descriptors experiment in DBI Test, run:
149 | ```
150 | python3 -m validate.matching --polyu_dir_path polyu_hrf --pts_dir_path log/pores --descriptors sift --thr 0.7 --fold DBI-test
151 | ```
152 | For the DBII experiment:
153 | ```
154 | python3 -m validate.matching --polyu_dir_path polyu_hrf --pts_dir_path log/pores --descriptors sift --thr 0.7 --fold DBII
155 | ```
156 |
157 | ### DP descriptors
158 | For the DBI Test, run:
159 | ```
160 | python3 -m validate.matching --polyu_dir_path polyu_hrf --pts_dir_path log/pores --descriptors dp --thr 0.7 --fold DBI-test --patch_size 32
161 | ```
162 | For the DBII experiment:
163 | ```
164 | python3 -m validate.matching --polyu_dir_path polyu_hrf --pts_dir_path log/pores --descriptors dp --thr 0.7 --fold DBII --patch_size 32
165 | ```
166 |
167 | ### Trained descriptors
168 | For the DBI Test experiment:
169 | ```
170 | python3 -m validate.matching --polyu_dir_path polyu_hrf --pts_dir_path log/pores --descriptors trained --thr 0.7 --model_dir_path log/description/[desc_model_dir] --fold DBI-test --patch_size 32
171 | ```
172 | For the DBII experiment:
173 | ```
174 | python3 -m validate.matching --polyu_dir_path polyu_hrf --pts_dir_path log/pores --descriptors trained --thr 0.7 --model_dir_path log/description/[desc_model_dir] --fold DBII --patch_size 32
175 | ```
176 |
177 | Other options for `validate.matching` are:
178 | ```
179 | usage: validate.matching [-h] --polyu_dir_path POLYU_DIR_PATH --pts_dir_path
180 | PTS_DIR_PATH [--results_path RESULTS_PATH]
181 | [--descriptors DESCRIPTORS] [--mode MODE] [--thr THR]
182 | [--model_dir_path MODEL_DIR_PATH] [--patch_size PATCH_SIZE]
183 | [--fold FOLD] [--seed SEED]
184 |
185 | optional arguments:
186 | -h, --help show this help message and exit
187 | --polyu_dir_path POLYU_DIR_PATH
188 | path to PolyU-HRF dataset
189 | --pts_dir_path PTS_DIR_PATH
190 | path to chosen dataset keypoints detections
191 | --results_path RESULTS_PATH
192 | path to results file
193 | --descriptors DESCRIPTORS
194 | which descriptors to use. Can be "sift", "dp" or
195 | "trained"
196 | --mode MODE mode to match images.
Can be "basic" or "spatial"
197 | --thr THR distance ratio check threshold
198 | --model_dir_path MODEL_DIR_PATH
199 | trained model directory path
200 | --patch_size PATCH_SIZE
201 | pore patch size
202 | --fold FOLD choose what fold of polyu to use. Can be "DBI-train",
203 | "DBI-test" and "DBII"
204 | --seed SEED random seed
205 | ```
206 |
207 | ## Pre-trained models and reproducing paper results
208 | The pre-trained [description model](https://drive.google.com/open?id=16GiLG7xBj64SOjCJwlCfbBcb-DORzYg1) is required to ensure that you get the exact same results as those of the paper. It is also required to use the same pore detection model we did; this model is available to download [here](https://drive.google.com/open?id=1U9rm_5za2kRU2FsviCe-qrZoouwUGyzI). After downloading both models, follow the batch pore detection and fingerprint recognition steps replacing `[det_model_dir]` and `[desc_model_dir]` where appropriate.
209 |
210 | ## Recognizing fingerprints
211 | We also provide `recognize.py`, a script that, given two high-resolution fingerprint images, a model trained to detect pores, and another trained to describe them, determines whether the images come from the same subject. To use it, run:
212 | ```
213 | python3 -m recognize --img_paths [image01_path] [image02_path] --det_model_dir [det_model_dir] --desc_model_dir [desc_model_dir]
214 | ```
215 |
216 | There is also a command line parameter, `score_thr`, to control the minimum number of established correspondences to determine that the images belong to the same subject. Its default value is 2, the EER threshold for the partial fingerprints in DBI-test. For the full fingerprints of DBII, this value should be set to 9.
217 |
218 | Other options for this script are:
219 | ```
220 | usage: recognize [-h] --img_paths IMG_PATHS IMG_PATHS --det_model_dir
221 | DET_MODEL_DIR --desc_model_dir DESC_MODEL_DIR
222 | [--score_thr SCORE_THR] [--det_patch_size DET_PATCH_SIZE]
223 | [--det_prob_thr DET_PROB_THR]
224 | [--nms_inter_thr NMS_INTER_THR]
225 | [--desc_patch_size DESC_PATCH_SIZE]
226 |
227 | optional arguments:
228 | -h, --help show this help message and exit
229 | --img_paths IMG_PATHS IMG_PATHS
230 | path to images to be recognized
231 | --det_model_dir DET_MODEL_DIR
232 | path to pore detection trained model
233 | --desc_model_dir DESC_MODEL_DIR
234 | path to pore description trained model
235 | --score_thr SCORE_THR
236 | score threshold to determine if pair is genuine or
237 | impostor
238 | --det_patch_size DET_PATCH_SIZE
239 | detection patch size
240 | --det_prob_thr DET_PROB_THR
241 | probability threshold for discarding detections
242 | --nms_inter_thr NMS_INTER_THR
243 | NMS area intersection threshold
244 | --desc_patch_size DESC_PATCH_SIZE
245 | patch size around each detected keypoint to describe
246 | ```
247 |
248 | ## Reference
249 | If you find the code in this repository useful for your research, please consider citing:
250 | ```
251 | @article{dahia2018cnn,
252 | title={Automatic Dataset Annotation to Learn CNN Pore Description for Fingerprint Recognition},
253 | author={Dahia, Gabriel and Segundo, Maur{\'\i}cio Pamplona},
254 | journal={arXiv preprint arXiv:1809.10229},
255 | year={2018}
256 | }
257 | ```
258 |
259 | ## License
260 | See the [LICENSE](LICENSE) file for details.
261 |
--------------------------------------------------------------------------------
/validate/matching.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import argparse
4 | import numpy as np
5 | import tensorflow as tf
6 |
7 | import utils
8 | import matching
9 |
10 | FLAGS = None
11 |
12 |
13 | def validation_eer(dataset, compute_descriptors):
14 | '''
15 | Computes the validation Equal Error Rate (EER) in dataset
16 | using descriptors computed with compute_descriptors and
17 | matching them using SIFT's original criterion with a
18 | distance ratio check threshold of 0.7, following PolyU-HRF's
19 | recognition protocol.
20 |
21 | Args:
22 | dataset: dataset for which EER should be computed.
23 | compute_descriptors: function that receives image and
24 | keypoint detections and computes descriptors at these
25 | locations.
26 |
27 | Returns:
28 | computed EER.
29 | '''
30 | # describe patches with detections and get
31 | # subject and register ids from labels
32 | all_descs = []
33 | all_pts = []
34 | subject_ids = set()
35 | register_ids = set()
36 | id2index_dict = {}
37 | index = 0
38 | for img, pts, label in dataset:
39 | # add patch descriptors to all descriptors
40 | descs, new_pts = compute_descriptors(img, pts)
41 | all_descs.append(descs)
42 | all_pts.append(new_pts)
43 |
44 | # add ids to all ids
45 | subject_ids.add(label[0])
46 | register_ids.add(label[2])
47 |
48 | # make 'id2index' correspondence
49 | id2index_dict[tuple(label)] = index
50 | index += 1
51 |
52 | # convert dict into function
53 | id2index = lambda x: id2index_dict[tuple(x)]
54 |
55 | # convert sets into lists
56 | subject_ids = list(subject_ids)
57 | register_ids = list(register_ids)
58 |
59 | # match and compute eer
60 | pos, neg = polyu_match(
61 | all_descs,
62 | all_pts,
63 | subject_ids,
64 | register_ids,
65 | id2index,
66 | matching.basic,
67 | thr=0.7)
68 | eer = utils.eer(pos, neg)
69 |
70 | return eer
71 |
72 |
73 | def load_dataset(imgs_dir_path,
74 | pts_dir_path,
75 | subject_ids,
76 | session_ids,
77 | register_ids,
78 | compute_descriptors,
79 | patch_size=None):
80 | '''
81 | Loads the PolyU-HRF dataset with corresponding keypoint detections.
82 | Instead of keeping the images themselves, descriptors are
83 | computed at the keypoint locations and returned in their place.
84 |
85 | Images must be named 'subject-id_session-id_register-id.jpg'
86 | and corresponding keypoint files must have the same name with
87 | a '.txt' extension instead of '.jpg'. Every file name from the
88 | cartesian product of subject_ids with session_ids and
89 | register_ids must have corresponding image and keypoint
90 | files.
91 |
92 | Args:
93 | imgs_dir_path: images directory path.
94 | pts_dir_path: keypoint txt files directory path.
95 | subject_ids: list of subject ids.
96 | session_ids: list of session ids.
97 | register_ids: list of register ids.
98 | compute_descriptors: function to compute descriptors that
99 | takes as arguments an image and its corresponding
100 | keypoints.
101 | patch_size: if not None, discards keypoints that are close
102 | enough to the images' borders that they cannot have a
103 | patch of size patch_size centered on them.
104 | Returns:
105 | all_descs: list of computed descriptors for every image.
106 | all_pts: list of keypoint locations for every image.
107 | id2index: function that converts tuples with
108 | (subject_id, session_id, register_id) into the
109 | corresponding descriptor and keypoint index.
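
Example (an illustrative sketch; the id lists are hypothetical):
  all_descs, all_pts, id2index = load_dataset(
      imgs_dir_path, pts_dir_path,
      subject_ids=[1, 2], session_ids=[1, 2], register_ids=[1, 2, 3],
      compute_descriptors=utils.sift_descriptors)
  index = id2index((1, 2, 3))
  descs, pts = all_descs[index], all_pts[index]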
110 | '''
111 | id2index_dict = {}
112 | all_descs = []
113 | all_pts = []
114 | index = 0
115 | for subject_id in subject_ids:
116 | for session_id in session_ids:
117 | for register_id in register_ids:
118 | instance = '{}_{}_{}'.format(subject_id, session_id, register_id)
119 |
120 | # load image
121 | img_path = os.path.join(imgs_dir_path, '{}.jpg'.format(instance))
122 | img = utils.load_image(img_path)
123 |
124 | # load detections
125 | pts_path = os.path.join(pts_dir_path, '{}.txt'.format(instance))
126 | pts = utils.load_dets_txt(pts_path)
127 |
128 | # filter detections too close to the image border
129 | if patch_size is not None:
130 | half = patch_size // 2
131 | pts_ = []
132 | for pt in pts:
133 | if half <= pt[0] < img.shape[0] - half:
134 | if half <= pt[1] < img.shape[1] - half:
135 | pts_.append(pt)
136 | pts = pts_
137 |
138 | all_pts.append(pts)
139 |
140 | # compute image descriptors
141 | descs, *_ = compute_descriptors(img, pts)
142 | all_descs.append(descs)
143 |
144 | # make id2index correspondence
145 | id2index_dict[(subject_id, session_id, register_id)] = index
146 | index += 1
147 |
148 | # turn id2index into conversion function
149 | id2index = lambda x: id2index_dict[tuple(x)]
150 |
151 | return all_descs, all_pts, id2index
152 |
153 |
154 | def polyu_match(all_descs,
155 | all_pts,
156 | subject_ids,
157 | register_ids,
158 | id2index,
159 | match,
160 | thr=None):
161 | '''
162 | Implements PolyU-HRF recognition protocol comparisons with
163 | given descriptors, keypoints, ids and matching algorithm.
164 |
165 | Args:
166 | all_descs: list of descriptors for every image.
167 | all_pts: list of keypoint detections for every image.
168 | subject_ids: list with all subject ids.
169 | register_ids: list with all register ids.
170 | id2index: function that converts tuples with
171 | (subject_id, session_id, register_id) into descriptor
172 | and keypoint set index.
173 | match: function that receives as arguments two sets of
174 | descriptors, two sets of keypoint detections and a
175 | threshold and returns a similarity score between
176 | them.
177 | thr: distance ratio check threshold forwarded to match.
178 | Returns:
179 | pos: genuine comparison scores.
180 | neg: impostor comparison scores.
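
Example (an illustrative sketch; follows the usage in
validation_eer above):
  pos, neg = polyu_match(all_descs, all_pts, subject_ids,
                         register_ids, id2index, matching.basic,
                         thr=0.7)
  eer = utils.eer(pos, neg)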
181 | ''' 182 | pos = [] 183 | neg = [] 184 | 185 | # same subject comparisons 186 | for subject_id in subject_ids: 187 | for register_id1 in register_ids: 188 | index1 = id2index((subject_id, 1, register_id1)) 189 | descs1 = all_descs[index1] 190 | pts1 = all_pts[index1] 191 | for register_id2 in register_ids: 192 | index2 = id2index((subject_id, 2, register_id2)) 193 | descs2 = all_descs[index2] 194 | pts2 = all_pts[index2] 195 | pos.append(match(descs1, descs2, pts1, pts2, thr=thr)) 196 | 197 | # different subject comparisons 198 | for subject_id1 in subject_ids: 199 | for subject_id2 in subject_ids: 200 | if subject_id1 != subject_id2: 201 | index1 = id2index((subject_id1, 1, 1)) 202 | index2 = id2index((subject_id2, 2, 1)) 203 | 204 | descs1 = all_descs[index1] 205 | descs2 = all_descs[index2] 206 | pts1 = all_pts[index1] 207 | pts2 = all_pts[index2] 208 | 209 | neg.append(match(descs1, descs2, pts1, pts2, thr=thr)) 210 | 211 | return pos, neg 212 | 213 | 214 | def main(): 215 | # parse descriptor and adjust accordingly 216 | compute_descriptors = None 217 | if FLAGS.descriptors == 'sift': 218 | compute_descriptors = utils.sift_descriptors 219 | elif FLAGS.descriptors == 'dp': 220 | if FLAGS.patch_size is None: 221 | raise TypeError('Patch size is required when using dp descriptor') 222 | 223 | def compute_descriptors(img, pts): 224 | return utils.dp_descriptors(img, pts, FLAGS.patch_size) 225 | else: 226 | if FLAGS.model_dir_path is None: 227 | raise TypeError( 228 | 'Trained model path is required when using trained descriptor') 229 | if FLAGS.patch_size is None: 230 | raise TypeError('Patch size is required when using trained descriptor') 231 | 232 | # create net graph and restore saved model 233 | from models import description 234 | 235 | img_pl, _ = utils.placeholder_inputs() 236 | net = description.Net(img_pl, training=False) 237 | sess = tf.Session() 238 | print('Restoring model in {}...'.format(FLAGS.model_dir_path)) 239 | utils.restore_model(sess, FLAGS.model_dir_path) 240 | print('Done') 241 | 242 | def compute_descriptors(img, pts): 243 | return utils.trained_descriptors(img, pts, FLAGS.patch_size, sess, 244 | img_pl, net.descriptors) 245 | 246 | # parse matching mode and adjust accordingly 247 | if FLAGS.mode == 'basic': 248 | match = matching.basic 249 | else: 250 | match = matching.spatial 251 | 252 | # make dir path be full appropriate dir path 253 | imgs_dir_path = None 254 | pts_dir_path = None 255 | subject_ids = None 256 | register_ids = None 257 | session_ids = None 258 | if FLAGS.fold == 'DBI-train': 259 | # adjust paths for appropriate fold 260 | imgs_dir_path = os.path.join(FLAGS.polyu_dir_path, 'DBI', 'Training') 261 | pts_dir_path = os.path.join(FLAGS.pts_dir_path, 'DBI', 'Training') 262 | 263 | # adjust ids for appropriate fold 264 | subject_ids = [ 265 | 6, 9, 11, 13, 16, 18, 34, 41, 42, 47, 62, 67, 118, 186, 187, 188, 196, 266 | 198, 202, 207, 223, 225, 226, 228, 242, 271, 272, 278, 287, 293, 297, 267 | 307, 311, 321, 323 268 | ] 269 | register_ids = [1, 2, 3] 270 | session_ids = [1, 2] 271 | else: 272 | # adjust paths for appropriate fold 273 | if FLAGS.fold == 'DBI-test': 274 | imgs_dir_path = os.path.join(FLAGS.polyu_dir_path, 'DBI', 'Test') 275 | pts_dir_path = os.path.join(FLAGS.pts_dir_path, 'DBI', 'Test') 276 | else: 277 | imgs_dir_path = os.path.join(FLAGS.polyu_dir_path, 'DBII') 278 | pts_dir_path = os.path.join(FLAGS.pts_dir_path, 'DBII') 279 | 280 | # adjust ids for appropriate fold 281 | subject_ids = [ 282 | 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 
12, 13, 14, 15, 16, 17, 18, 19, 20, 283 | 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 284 | 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 285 | 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 286 | 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 287 | 93, 94, 95, 96, 97, 98, 99, 100, 105, 106, 107, 108, 109, 110, 111, 288 | 112, 113, 114, 115, 116, 117, 118, 119, 120, 125, 126, 127, 128, 129, 289 | 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 290 | 144, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168 291 | ] 292 | register_ids = [1, 2, 3, 4, 5] 293 | session_ids = [1, 2] 294 | 295 | # load images, points, compute descriptors and make indices correspondences 296 | print('Loading images and detections, and computing descriptors...') 297 | all_descs, all_pts, id2index = load_dataset( 298 | imgs_dir_path, pts_dir_path, subject_ids, session_ids, register_ids, 299 | compute_descriptors) 300 | print('Done') 301 | 302 | print('Matching...') 303 | pos, neg = polyu_match( 304 | all_descs, 305 | all_pts, 306 | subject_ids, 307 | register_ids, 308 | id2index, 309 | match, 310 | thr=FLAGS.thr) 311 | print('Done') 312 | 313 | # print equal error rate 314 | print('EER = {}'.format(utils.eer(pos, neg))) 315 | 316 | # save results to file 317 | if FLAGS.results_path is not None: 318 | print('Saving results to file {}...'.format(FLAGS.results_path)) 319 | 320 | # create directory tree, if non-existing 321 | dirname = os.path.dirname(FLAGS.results_path) 322 | dirname = os.path.abspath(dirname) 323 | if not os.path.exists(dirname): 324 | os.makedirs(dirname) 325 | 326 | # save comparisons 327 | with open(FLAGS.results_path, 'w') as f: 328 | # save same subject scores 329 | for score in pos: 330 | print(1, score, file=f) 331 | 332 | # save different subject scores 333 | for score in neg: 334 | print(0, score, file=f) 335 | 336 | # save invoking command string 337 | with open(FLAGS.results_path + '.cmd', 'w') as f: 338 | print(*sys.argv, file=f) 339 | 340 | print('Done') 341 | 342 | 343 | if __name__ == '__main__': 344 | parser = argparse.ArgumentParser() 345 | parser.add_argument( 346 | '--polyu_dir_path', 347 | required=True, 348 | type=str, 349 | help='path to PolyU-HRF dataset') 350 | parser.add_argument( 351 | '--pts_dir_path', 352 | type=str, 353 | required=True, 354 | help='path to chosen dataset keypoints detections') 355 | parser.add_argument('--results_path', type=str, help='path to results file') 356 | parser.add_argument( 357 | '--descriptors', 358 | type=str, 359 | default='sift', 360 | help='which descriptors to use. Can be "sift", "dp" or "trained"') 361 | parser.add_argument( 362 | '--mode', 363 | type=str, 364 | default='basic', 365 | help='mode to match images. Can be "basic" or "spatial"') 366 | parser.add_argument( 367 | '--thr', type=float, help='distance ratio check threshold') 368 | parser.add_argument( 369 | '--model_dir_path', type=str, help='trained model directory path') 370 | parser.add_argument('--patch_size', type=int, help='pore patch size') 371 | parser.add_argument( 372 | '--fold', 373 | type=str, 374 | default='DBI-train', 375 | help= 376 | 'choose what fold of polyu to use. 
Can be "DBI-train", "DBI-test" and "DBII"'
377 | )
378 | parser.add_argument('--seed', type=int, help='random seed')
379 |
380 | FLAGS = parser.parse_args()
381 |
382 | # set random seeds
383 | tf.set_random_seed(FLAGS.seed)
384 | np.random.seed(FLAGS.seed)
385 |
386 | main()
387 |
--------------------------------------------------------------------------------
/polyu/detection.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 |
4 | import utils
5 |
6 |
7 | class _Dataset:
8 | def __init__(self,
9 | images,
10 | labels,
11 | shuffle_behavior,
12 | incomplete_batches,
13 | patch_size=None,
14 | one_hot=False,
15 | label_mode=None,
16 | label_size=None):
17 | self._images = np.array(images, dtype=np.float32)
18 | self._labels = np.array(labels, dtype=np.float32)
19 | self.patch_size = patch_size
20 | self._shuffle_behavior = shuffle_behavior
21 | self._one_hot = one_hot
22 | self._incomplete_batches = incomplete_batches
23 |
24 | # patch batch pointers
25 | self._index_in_epoch = 0
26 | self._epochs_completed = 0
27 |
28 | # image batch pointers
29 | self._image_index_in_epoch = 0
30 | self._image_epochs_completed = 0
31 |
32 | self.num_images = len(self._images)
33 | self._image_rows = self._images[0].shape[0]
34 | self._image_cols = self._images[0].shape[1]
35 | if patch_size is not None:
36 | self.num_samples = self.num_images * (self._image_rows - patch_size) * (
37 | self._image_cols - patch_size)
38 |
39 | # create labels to be sampled by patches
40 | if label_mode == 'hard_bb':
41 | self._patch_labels = np.zeros_like(self._labels, dtype=np.float32)
42 |
43 | # draw 'label_size' bounding box around pores
44 | for k, i, j in np.ndindex(self._patch_labels.shape):
45 | i_ = max(i - label_size, 0)
46 | j_ = max(j - label_size, 0)
47 | self._patch_labels[k, i, j] = np.max(
48 | self._labels[k, i_:i + 1 + label_size, j_:j + 1 + label_size])
49 |
50 | elif label_mode == 'hard_l1' or label_mode == 'hard_l2':
51 | self._patch_labels = np.zeros_like(self._labels, dtype=np.float32)
52 |
53 | # define norm to use
54 | norm_l = int(label_mode[-1])
55 |
56 | # enqueue pores as their own anchors
57 | queue = []
58 | for pore in np.argwhere(self._labels >= 0.5):
59 | self._patch_labels[pore[0], pore[1], pore[2]] = self._labels[pore[
60 | 0], pore[1], pore[2]]
61 | queue.append((pore, pore))
62 |
63 | # BFS to draw an L-'norm_l' ball around pores
64 | while queue:
65 | # pop front
66 | coords, anchor = queue[0]
67 | queue = queue[1:]
68 |
69 | # propagate pore anchor label
70 | k, i, j = anchor
71 | val = self._patch_labels[k, i, j]
72 |
73 | # enqueue valid neighbors
74 | for d in [(0, 1, 0), (0, 0, 1), (0, -1, 0), (0, 0, -1)]:
75 | ngh = coords + d
76 | _, i, j = ngh
77 | if 0 <= i < self._patch_labels[k].shape[0] and \
78 | 0 <= j < self._patch_labels[k].shape[1] and \
79 | self._patch_labels[k, i, j] == 0 and \
80 | np.linalg.norm(ngh - anchor, norm_l) <= label_size:
81 | self._patch_labels[k, i, j] = val
82 | queue.append((ngh, anchor))
83 |
84 | def next_batch(self, batch_size, shuffle=None, incomplete=None):
85 | '''
86 | Sample the next batch, of size 'batch_size', of image patches.
87 |
88 | Args:
89 | batch_size: Size of batch to be sampled.
90 | shuffle: Overrides the dataset split's shuffle behavior, i.e., whether the data should be shuffled.
91 | incomplete: Overrides the dataset split's incomplete batch behavior, i.e., whether, upon completing an epoch, a batch with fewer samples than requested should be returned so as not to mix samples from different epochs.
92 |
93 | Returns:
94 | The sampled patch batch and corresponding labels as np arrays.
95 | '''
96 | # determine shuffle and incomplete behaviors
97 | if shuffle is None:
98 | shuffle = self._shuffle_behavior
99 |
100 | if incomplete is None:
101 | incomplete = self._incomplete_batches
102 |
103 | start = self._index_in_epoch
104 |
105 | # shuffle for first epoch
106 | if self._epochs_completed == 0 and start == 0:
107 | self._perm = np.arange(self.num_samples)
108 | if shuffle:
109 | np.random.shuffle(self._perm)
110 |
111 | # go to next epoch
112 | if start + batch_size >= self.num_samples:
113 | # finished epoch
114 | self._epochs_completed += 1
115 |
116 | # get the rest of the samples in this epoch
117 | rest_num_samples = self.num_samples - start
118 | images_rest_part, labels_rest_part = self._to_patches(self._perm[start:])
119 |
120 | # shuffle the data
121 | if shuffle:
122 | self._perm = np.arange(self.num_samples)
123 | np.random.shuffle(self._perm)
124 |
125 | # return incomplete batch
126 | if incomplete:
127 | self._index_in_epoch = 0
128 | return images_rest_part, labels_rest_part
129 |
130 | # start next epoch
131 | start = 0
132 | self._index_in_epoch = batch_size - rest_num_samples
133 | end = self._index_in_epoch
134 |
135 | # retrieve samples in the new epoch
136 | images_new_part, labels_new_part = self._to_patches(
137 | self._perm[start:end])
138 |
139 | return np.concatenate(
140 | (images_rest_part, images_new_part), axis=0), np.concatenate(
141 | (labels_rest_part, labels_new_part), axis=0)
142 | else:
143 | self._index_in_epoch += batch_size
144 | end = self._index_in_epoch
145 | return self._to_patches(self._perm[start:end])
146 |
147 | def next_image_batch(self, batch_size, shuffle=None, incomplete=None):
148 | '''
149 | Sample the next batch, of size 'batch_size', of images.
150 |
151 | Args:
152 | batch_size: Size of batch to be sampled.
153 | shuffle: Overrides the dataset split's shuffle behavior, i.e., whether the data should be shuffled.
154 | incomplete: Overrides the dataset split's incomplete batch behavior, i.e., whether, upon completing an epoch, a batch with fewer images than requested should be returned so as not to mix images from different epochs.
155 |
156 | Returns:
157 | The sampled image batch and corresponding labels as np arrays.
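
Example (an illustrative sketch; the batch size of 8 is
hypothetical, and `dataset` is assumed to be a
polyu.detection.Dataset):
  images, labels = dataset.val.next_image_batch(8)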
158 | '''
159 | # determine shuffle and incomplete behaviors
160 | if shuffle is None:
161 | shuffle = self._shuffle_behavior
162 |
163 | if incomplete is None:
164 | incomplete = self._incomplete_batches
165 |
166 | start = self._image_index_in_epoch
167 |
168 | # shuffle for first epoch
169 | if self._image_epochs_completed == 0 and start == 0:
170 | self._image_perm = np.arange(self.num_images)
171 | if shuffle:
172 | np.random.shuffle(self._image_perm)
173 |
174 | # go to next epoch
175 | if start + batch_size >= self.num_images:
176 | # finished epoch
177 | self._image_epochs_completed += 1
178 |
179 | # get the rest of the images in this epoch
180 | rest_num_images = self.num_images - start
181 | images_rest_part, labels_rest_part = self._images[start:], self._labels[
182 | start:]
183 |
184 | # shuffle the data
185 | if shuffle:
186 | self._image_perm = np.arange(self.num_images)
187 | np.random.shuffle(self._image_perm)
188 |
189 | # return incomplete batch
190 | if incomplete:
191 | self._image_index_in_epoch = 0
192 | return images_rest_part, labels_rest_part
193 |
194 | # start next epoch
195 | start = 0
196 | self._image_index_in_epoch = batch_size - rest_num_images
197 | end = self._image_index_in_epoch
198 |
199 | # retrieve images in the new epoch
200 | images_new_part, labels_new_part = self._images[start:end], self._labels[
201 | start:end]
202 |
203 | return np.concatenate(
204 | (images_rest_part, images_new_part), axis=0), np.concatenate(
205 | (labels_rest_part, labels_new_part), axis=0)
206 | else:
207 | self._image_index_in_epoch += batch_size
208 | end = self._image_index_in_epoch
209 | return self._images[start:end], self._labels[start:end]
210 |
211 | def _to_patches(self, indices):
212 | '''
213 | Retrieves image patches, on demand, corresponding to given indices.
214 | A patch of size SIZE has its top-left corner at the i-th row and j-th column of the k-th image (of dimensions ROWS x COLS) according to the given index I:
215 | k = I // ((ROWS - SIZE) * (COLS - SIZE))
216 | i = I // (COLS - SIZE) - k * (ROWS - SIZE)
217 | j = I mod (COLS - SIZE)
218 |
219 | Args:
220 | indices: Indices for which patches will be produced.
221 |
222 | Returns:
223 | patches: Windows corresponding to the given indices.
224 | labels: Labels of the returned patches.
225 |
226 | '''
227 | # patches (samples) and labels
228 | size = self.patch_size
229 | patches = np.empty([indices.shape[0], size, size], np.float32)
230 | if self._one_hot:
231 | labels = np.empty([indices.shape[0], 2], np.float32)
232 | else:
233 | labels = np.empty([indices.shape[0]], np.float32)
234 |
235 | for index in range(indices.shape[0]):
236 | image_index = indices[index]
237 |
238 | # retrieve image number, row and column from index
239 | k = image_index // (
240 | (self._image_rows - size) * (self._image_cols - size))
241 | i = image_index // (self._image_cols - size) - k * (
242 | self._image_rows - size)
243 | j = image_index % (self._image_cols - size)
244 |
245 | # generate patch
246 | patches[index] = self._images[k, i:i + size, j:j + size]
247 |
248 | # generate corresponding label
249 | center = size // 2
250 | if self._one_hot:
251 | labels[index, 0] = self._patch_labels[k, i + center, j + center]
252 | labels[index, 1] = 1 - labels[index, 0]
253 | else:
254 | labels[index] = self._patch_labels[k, i + center, j + center]
255 |
256 | return patches, labels
257 |
258 |
259 | class Dataset:
260 | '''
261 | PolyU-HRF detection dataset handler.
Contains a _Dataset for
262 | training and, depending on how it splits the entire dataset,
263 | another one for validation and another for testing. It
264 | converts ground truth coordinates into region labels according
265 | to label_mode and label_size.
266 | '''
267 |
268 | def __init__(self,
269 | images_folder_path,
270 | labels_folder_path,
271 | split,
272 | patch_size=None,
273 | label_mode='hard_bb',
274 | label_size=3,
275 | should_shuffle=True,
276 | one_hot=False):
277 | '''
278 | Args:
279 | images_folder_path: path from which to load images.
280 | labels_folder_path: path from which to load label txt files.
281 | Must have name correspondence with images in
282 | images_folder_path.
283 | split: tuple determining how to split the detection dataset.
284 | split[0] gives the number of training images, split[1]
285 | the number of validation images, and split[2] the number
286 | of test images. Images are split sequentially into these
287 | sets, i.e. the training set takes the first split[0] images,
288 | the validation set takes the next split[1] images etc.
289 | patch_size: if not None, allows the sampling of patches,
290 | instead of images, from the dataset. Patches are then
291 | of size patch_size by patch_size.
292 | label_mode: mode of converting detection coordinates to
293 | labels. Can be either 'hard_bb', 'hard_l2' or 'hard_l1'.
294 | 'hard_bb' draws a bounding box of size 'label_size' around
295 | each detection. 'hard_lX' draws an LX ball of radius
296 | 'label_size' around each detection.
297 | label_size: see above, in label_mode description, for
298 | meaning.
299 | should_shuffle: whether the training set should be shuffled
300 | between batches.
301 | one_hot: whether the labels should be provided as one hot
302 | vectors or with integer values.
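
Example (an illustrative sketch; the paths and the split and
patch sizes are hypothetical):
  dataset = Dataset(images_folder_path, labels_folder_path,
                    split=(15, 5, 10), patch_size=17,
                    one_hot=True)
  patches, labels = dataset.train.next_batch(64)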
303 | ''' 304 | self._images = utils.load_images(images_folder_path) 305 | self._labels = self._load_labels(labels_folder_path) 306 | 307 | # splits loaded according to given 'split' 308 | if split[0] > 0: 309 | self.train = _Dataset( 310 | self._images[:split[0]], 311 | self._labels[:split[0]], 312 | shuffle_behavior=should_shuffle, 313 | incomplete_batches=False, 314 | patch_size=patch_size, 315 | one_hot=one_hot, 316 | label_mode=label_mode, 317 | label_size=label_size) 318 | else: 319 | self.train = None 320 | 321 | if split[1] > 0: 322 | self.val = _Dataset( 323 | self._images[split[0]:split[0] + split[1]], 324 | self._labels[split[0]:split[0] + split[1]], 325 | shuffle_behavior=False, 326 | incomplete_batches=True, 327 | patch_size=patch_size, 328 | one_hot=one_hot, 329 | label_mode=label_mode, 330 | label_size=label_size) 331 | else: 332 | self.val = None 333 | 334 | if split[2] > 0: 335 | self.test = _Dataset( 336 | self._images[split[0] + split[1]:split[0] + split[1] + split[2]], 337 | self._labels[split[0] + split[1]:split[0] + split[1] + split[2]], 338 | shuffle_behavior=False, 339 | incomplete_batches=True, 340 | patch_size=patch_size, 341 | one_hot=one_hot, 342 | label_mode=label_mode, 343 | label_size=label_size) 344 | else: 345 | self.test = None 346 | 347 | def _load_labels(self, folder_path): 348 | labels = [] 349 | for img_index, label_path in enumerate(sorted(os.listdir(folder_path))): 350 | if label_path.endswith('.txt'): 351 | labels.append( 352 | self._load_txt_label( 353 | os.path.join(folder_path, label_path), img_index)) 354 | 355 | return labels 356 | 357 | def _load_txt_label(self, label_path, img_index): 358 | label = np.zeros(self._images[img_index].shape, np.float32) 359 | with open(label_path, 'r') as f: 360 | for line in f: 361 | row, col = [int(j) for j in line.split()] 362 | label[row - 1, col - 1] = 1 363 | 364 | return label 365 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import tensorflow as tf 3 | import numpy as np 4 | import cv2 5 | 6 | 7 | def placeholder_inputs(): 8 | images = tf.placeholder(tf.float32, [None, None, None, 1], name='images') 9 | labels = tf.placeholder(tf.float32, [None, 1], name='labels') 10 | return images, labels 11 | 12 | 13 | def fill_feed_dict(dataset, patches_pl, labels_pl, batch_size, augment=False): 14 | ''' 15 | Creates a tf feed_dict containing patches and corresponding labels from a 16 | mini-batch of size batch_size sampled from dataset, possibly transforming it 17 | for online dataset augmentation. 18 | 19 | Args: 20 | dataset: dataset object satisfying polyu._Dataset.next_batch signature. 21 | patches_pl: tf placeholder for patches. 22 | labels_pl: tf placeholder for labels. 23 | batch_size: size of mini-batch to be sampled. 24 | augment: whether or not this mini-batch should be transformed. 25 | 26 | Returns: 27 | tf feed_dict containing possibly augmented patches and corresponding labels. 28 | ''' 29 | patches_feed, labels_feed = dataset.next_batch(batch_size) 30 | 31 | if augment: 32 | patches_feed = _transform_mini_batch(patches_feed) 33 | 34 | feed_dict = { 35 | patches_pl: np.expand_dims(patches_feed, axis=-1), 36 | labels_pl: np.expand_dims(labels_feed, axis=-1) 37 | } 38 | 39 | return feed_dict 40 | 41 | 42 | def _transform_mini_batch(sample): 43 | ''' 44 | Transforms an image sample with contrast and brightness variations, 45 | translations and rotations. 
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import tensorflow as tf
3 | import numpy as np
4 | import cv2
5 | 
6 | 
7 | def placeholder_inputs():
8 |   images = tf.placeholder(tf.float32, [None, None, None, 1], name='images')
9 |   labels = tf.placeholder(tf.float32, [None, 1], name='labels')
10 |   return images, labels
11 | 
12 | 
13 | def fill_feed_dict(dataset, patches_pl, labels_pl, batch_size, augment=False):
14 |   '''
15 |   Creates a tf feed_dict containing patches and corresponding labels from a
16 |   mini-batch of size batch_size sampled from dataset, possibly transforming it
17 |   for online dataset augmentation.
18 | 
19 |   Args:
20 |     dataset: dataset object satisfying polyu._Dataset.next_batch signature.
21 |     patches_pl: tf placeholder for patches.
22 |     labels_pl: tf placeholder for labels.
23 |     batch_size: size of mini-batch to be sampled.
24 |     augment: whether or not this mini-batch should be transformed.
25 | 
26 |   Returns:
27 |     tf feed_dict containing possibly augmented patches and corresponding labels.
28 |   '''
29 |   patches_feed, labels_feed = dataset.next_batch(batch_size)
30 | 
31 |   if augment:
32 |     patches_feed = _transform_mini_batch(patches_feed)
33 | 
34 |   feed_dict = {
35 |       patches_pl: np.expand_dims(patches_feed, axis=-1),
36 |       labels_pl: np.expand_dims(labels_feed, axis=-1)
37 |   }
38 | 
39 |   return feed_dict
40 | 
41 | 
42 | def _transform_mini_batch(sample):
43 |   '''
44 |   Transforms an image sample with contrast and brightness variations,
45 |   translations and rotations. These transformations are sampled from a
46 |   multivariate normal distribution with diagonal covariance or, equivalently,
47 |   from independent normal distributions. Rotations and translations are
48 |   done as if each image of the sample were zero-padded and are obtained by
49 |   linearly interpolating the image.
50 | 
51 |   Args:
52 |     sample: array of images to be transformed.
53 | 
54 |   Returns:
55 |     randomly transformed sample.
56 |   '''
57 |   # contrast and brightness variations
58 |   contrast = np.random.normal(loc=1, scale=0.05, size=(sample.shape[0], 1, 1))
59 |   brightness = np.random.normal(
60 |       loc=0, scale=0.05, size=(sample.shape[0], 1, 1))
61 |   sample = contrast * sample + brightness
62 | 
63 |   # translation and rotation
64 |   transformed = []
65 |   for image in sample:
66 |     # random translation
67 |     dx = np.random.normal(loc=0, scale=1)
68 |     dy = np.random.normal(loc=0, scale=1)
69 |     A = np.array([[1, 0, dx], [0, 1, dy]])
70 | 
71 |     # random rotation
72 |     theta = np.random.normal(loc=0, scale=7.5)
73 |     center = (image.shape[1] // 2, image.shape[0] // 2)
74 |     B = cv2.getRotationMatrix2D(center, theta, 1)
75 | 
76 |     # transform image
77 |     image = cv2.warpAffine(image, A, image.shape[::-1], flags=cv2.INTER_LINEAR)
78 |     image = cv2.warpAffine(image, B, image.shape[::-1], flags=cv2.INTER_LINEAR)
79 | 
80 |     # add to batch images
81 |     transformed.append(image)
82 | 
83 |   return np.array(transformed)
84 | 
85 |
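# A minimal usage sketch, not part of the original module: it assumes
# `dataset.train` follows the polyu._Dataset.next_batch signature; the
# batch size of 36 is illustrative.
def _example_augmented_feed(dataset):
  patches_pl, labels_pl = placeholder_inputs()
  # sample a mini-batch of 36 patches and randomly transform it
  feed_dict = fill_feed_dict(
      dataset.train, patches_pl, labels_pl, 36, augment=True)
  return feed_dict, patches_pl, labels_pl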
86 | def create_dirs(log_dir_path,
87 |                 batch_size,
88 |                 learning_rate,
89 |                 batch_size2=None,
90 |                 learning_rate2=None):
91 |   import datetime
92 |   timestamp = datetime.datetime.now().strftime("%Y-%m-%d_%H%M%S")
93 |   if batch_size2 is None or learning_rate2 is None:
94 |     # individual training
95 |     log_dir = os.path.join(
96 |         log_dir_path, 'bs-{}_lr-{:.0e}_t-{}'.format(batch_size, learning_rate,
97 |                                                     timestamp))
98 |   else:
99 |     # approximate joint training
100 |     log_dir = os.path.join(
101 |         log_dir_path, 'bs-{}x{}_lr-{:.0e}x{}_t-{}'.format(
102 |             batch_size, batch_size2, learning_rate, learning_rate2, timestamp))
103 | 
104 |   tf.gfile.MakeDirs(log_dir)
105 | 
106 |   return log_dir
107 | 
108 | 
109 | def nms(centers, probs, bb_size, thr):
110 |   '''
111 |   Converts each center in centers into a bb_size x bb_size bounding box
112 |   centered on it and applies Non-Maximum Suppression (NMS) to the boxes.
113 | 
114 |   Args:
115 |     centers: centers of detection bounding boxes.
116 |     probs: probabilities for each of the detections.
117 |     bb_size: bounding box size.
118 |     thr: NMS discarding intersection threshold.
119 | 
120 |   Returns:
121 |     dets: np.array of NMS filtered detections.
122 |     det_probs: np.array of corresponding detection probabilities.
123 |   '''
124 |   area = bb_size * bb_size
125 |   half_bb_size = bb_size // 2
126 | 
127 |   xs, ys = np.transpose(centers)
128 |   x1 = xs - half_bb_size
129 |   x2 = xs + half_bb_size
130 |   y1 = ys - half_bb_size
131 |   y2 = ys + half_bb_size
132 | 
133 |   order = np.argsort(probs)[::-1]
134 | 
135 |   dets = []
136 |   det_probs = []
137 |   while len(order) > 0:
138 |     i = order[0]
139 |     order = order[1:]
140 |     dets.append(centers[i])
141 |     det_probs.append(probs[i])
142 | 
143 |     xx1 = np.maximum(x1[i], x1[order])
144 |     yy1 = np.maximum(y1[i], y1[order])
145 |     xx2 = np.minimum(x2[i], x2[order])
146 |     yy2 = np.minimum(y2[i], y2[order])
147 | 
148 |     w = np.maximum(0.0, xx2 - xx1 + 1)
149 |     h = np.maximum(0.0, yy2 - yy1 + 1)
150 |     inter = w * h
151 |     ovr = inter / (2 * area - inter)  # IoU, since all boxes have equal area
152 | 
153 |     inds = np.where(ovr <= thr)[0]
154 |     order = order[inds]
155 | 
156 |   return np.array(dets), np.array(det_probs)
157 | 
158 |
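# A toy check of the NMS behavior above, with made-up numbers: two detections
# 1 pixel apart overlap heavily for 7 x 7 boxes, so only the more probable of
# the two survives, while the distant third detection is kept.
def _example_nms():
  centers = np.array([[10, 10], [10, 11], [40, 40]])
  probs = np.array([0.9, 0.8, 0.7])
  dets, det_probs = nms(centers, probs, bb_size=7, thr=0.5)
  # dets == [[10, 10], [40, 40]] and det_probs == [0.9, 0.7]
  return dets, det_probs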
159 | def pairwise_distances(x1, x2):
160 |   # memory efficient implementation based on Yaroslav Bulatov's answer in
161 |   # https://stackoverflow.com/questions/37009647/compute-pairwise-distance-in-a-batch-without-replicating-tensor-in-tensorflow
162 |   sqr1 = np.sum(x1 * x1, axis=1, keepdims=True)
163 |   sqr2 = np.sum(x2 * x2, axis=1)
164 |   D = sqr1 - 2 * np.matmul(x1, x2.T) + sqr2
165 | 
166 |   return D
167 | 
168 | 
169 | def restore_model(sess, model_dir):
170 |   saver = tf.train.Saver()
171 |   ckpt = tf.train.get_checkpoint_state(model_dir)
172 |   if ckpt and ckpt.model_checkpoint_path:
173 |     saver.restore(sess, ckpt.model_checkpoint_path)
174 |   else:
175 |     raise IOError('No model found in {}.'.format(model_dir))
176 | 
177 | 
178 | def load_image(image_path):
179 |   image = cv2.imread(image_path, 0)
180 |   image = np.array(image, dtype=np.float32) / 255
181 |   return image
182 | 
183 | 
184 | def load_images(folder_path):
185 |   images = []
186 |   for image_path in sorted(os.listdir(folder_path)):
187 |     if image_path.endswith(('.jpg', '.png', '.bmp')):
188 |       images.append(load_image(os.path.join(folder_path, image_path)))
189 | 
190 |   return images
191 | 
192 | 
193 | def load_images_with_names(images_dir):
194 |   images = load_images(images_dir)
195 |   image_names = [
196 |       path.split('.')[0] for path in sorted(os.listdir(images_dir))
197 |       if path.endswith(('.jpg', '.bmp', '.png'))
198 |   ]
199 | 
200 |   return images, image_names
201 | 
202 | 
203 | def save_dets_txt(dets, filename):
204 |   with open(filename, 'w') as f:
205 |     for coord in dets:
206 |       print(coord[0] + 1, coord[1] + 1, file=f)
207 | 
208 | 
209 | def load_dets_txt(pts_path):
210 |   pts = []
211 |   with open(pts_path, 'r') as f:
212 |     for line in f:
213 |       row, col = [int(t) for t in line.split()]
214 |       pts.append((row - 1, col - 1))
215 | 
216 |   return pts
217 | 
218 | 
219 | def bilinear_interpolation(x, y, f):
220 |   x1 = int(x)
221 |   y1 = int(y)
222 |   x2 = x1 + 1
223 |   y2 = y1 + 1
224 | 
225 |   fq = [[f[x1, y1], f[x1, y2]], [f[x2, y1], f[x2, y2]]]  # corner values
226 |   lhs = [[x2 - x, x - x1]]
227 |   rhs = [y2 - y, y - y1]
228 | 
229 |   return np.dot(np.dot(lhs, fq), rhs)
230 | 
231 | 
232 | def sift_descriptors(img, pts, scale=4, normalize=True):
233 |   '''
234 |   Computes SIFT descriptors in pts keypoints of the image img.
235 |   If normalize is True, image is normalized with median blur and
236 |   CLAHE.
237 | 
238 |   Args:
239 |     img: image from which descriptors will be computed.
240 |     pts: [N, 2] coordinates of N keypoints in img from which
241 |       descriptors will be computed.
242 |     scale: SIFT scale.
243 |     normalize: whether to normalize image with median blur and
244 |       CLAHE prior to descriptor computation.
245 | 
246 |   Returns:
247 |     [N, 128] np.array of N SIFT descriptors.
248 |   '''
249 |   # empty detections set
250 |   if len(pts) == 0:
251 |     return []
252 | 
253 |   # convert float image to np uint8
254 |   if img.dtype == np.float32:
255 |     img = np.array(255 * img, dtype=np.uint8)
256 | 
257 |   # improve image quality with median blur and clahe
258 |   if normalize:
259 |     img = cv2.medianBlur(img, ksize=3)
260 |     clahe = cv2.createCLAHE(clipLimit=3)
261 |     img = clahe.apply(img)
262 | 
263 |   # convert points to cv2.keypoints
264 |   pts = list(np.asarray(pts)[:, [1, 0]])
265 |   kpts = cv2.KeyPoint.convert(pts, size=scale)
266 | 
267 |   # extract sift descriptors
268 |   sift = cv2.xfeatures2d.SIFT_create()
269 |   _, descs = sift.compute(img, kpts)
270 | 
271 |   return descs
272 | 
273 |
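# A minimal usage sketch, not part of the original module: the image and
# detections paths are hypothetical; load_image and load_dets_txt are the
# helpers defined above.
def _example_sift():
  img = load_image('path/to/fingerprint.png')
  pts = load_dets_txt('path/to/fingerprint_pores.txt')
  descs = sift_descriptors(img, pts, scale=8, normalize=True)
  return descs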
274 | def find_correspondences(descs1,
275 |                          descs2,
276 |                          pts1=None,
277 |                          pts2=None,
278 |                          euclidean_weight=0,
279 |                          transf=None,
280 |                          thr=None):
281 |   '''
282 |   Finds bidirectional correspondences between descs1 descriptors and
283 |   descs2 descriptors. If thr is provided, discards correspondences
284 |   that fail a distance ratio check with threshold thr. If pts1, pts2,
285 |   and transf are given, the metric considered when finding correspondences
286 |   is
287 |     d(i, j) = ||descs1(i) - descs2(j)||^2 +
288 |       euclidean_weight * ||transf(pts1(i)) - pts2(j)||^2
289 | 
290 |   Args:
291 |     descs1: [N, M] np.array of N descriptors of dimension M each.
292 |     descs2: [P, M] np.array of P descriptors of dimension M each.
293 |     pts1: [N, 2] np.array of N coordinates for each descriptor in descs1.
294 |     pts2: [P, 2] np.array of P coordinates for each descriptor in descs2.
295 |     euclidean_weight: weight given to spatial constraint in comparison
296 |       metric.
297 |     transf: alignment transformation that aligns pts1 to pts2.
298 |     thr: distance ratio check threshold.
299 | 
300 |   Returns:
301 |     list of correspondence tuples (j, i, d) in which index j of
302 |       descs1 corresponds with index i of descs2 with distance d.
303 |   '''
304 |   # compute descriptors' pairwise distances
305 |   D = pairwise_distances(descs1, descs2)
306 | 
307 |   # add points' euclidean distance
308 |   if euclidean_weight != 0:
309 |     assert transf is not None
310 |     assert pts1 is not None
311 |     assert pts2 is not None
312 | 
313 |     # assure pts are np array
314 |     pts1 = transf(np.array(pts1))
315 |     pts2 = np.array(pts2)
316 | 
317 |     # compute points' pairwise distances
318 |     euclidean_D = pairwise_distances(pts1, pts2)
319 | 
320 |     # add to overall keypoints distance
321 |     D += euclidean_weight * euclidean_D
322 | 
323 |   # find bidirectional corresponding points
324 |   pairs = []
325 |   if thr is None or len(descs1) == 1 or len(descs2) == 1:
326 |     # find the best correspondence of each element
327 |     # in 'descs2' to an element in 'descs1'
328 |     corrs2 = np.argmin(D, axis=0)
329 | 
330 |     # find the best correspondence of each element
331 |     # in 'descs1' to an element in 'descs2'
332 |     corrs1 = np.argmin(D, axis=1)
333 | 
334 |     # keep only bidirectional correspondences
335 |     for i, j in enumerate(corrs2):
336 |       if corrs1[j] == i:
337 |         pairs.append((j, i, D[j, i]))
338 |   else:
339 |     # find the 2 best correspondences of each
340 |     # element in 'descs2' to an element in 'descs1'
341 |     corrs2 = np.argpartition(D.T, [0, 1])[:, :2]
342 | 
343 |     # find the 2 best correspondences of each
344 |     # element in 'descs1' to an element in 'descs2'
345 |     corrs1 = np.argpartition(D, [0, 1])[:, :2]
346 | 
347 |     # find bidirectional corresponding points
348 |     # with second best correspondence 'thr'
349 |     # worse than best one
350 |     for i, (j, _) in enumerate(corrs2):
351 |       if corrs1[j, 0] == i:
352 |         # discard close best second correspondences
353 |         if D[j, i] < D[corrs2[i, 1], i] * thr:
354 |           if D[j, i] < D[j, corrs1[j, 1]] * thr:
355 |             pairs.append((j, i, D[j, i]))
356 | 
357 |   return pairs
358 | 
359 |
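# A minimal usage sketch, not part of the original module: the descriptors
# below are random stand-ins for real ones.
def _example_find_correspondences():
  descs1 = np.random.random((10, 128))
  descs2 = np.random.random((12, 128))
  # ratio check with threshold 0.8, no spatial constraint
  pairs = find_correspondences(descs1, descs2, thr=0.8)
  # each pair is (index in descs1, index in descs2, squared distance)
  return pairs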
360 | def load_images_with_labels(folder_path):
361 |   images = []
362 |   labels = []
363 |   for image_path in sorted(os.listdir(folder_path)):
364 |     if image_path.endswith(('.jpg', '.png', '.bmp')):
365 |       images.append(load_image(os.path.join(folder_path, image_path)))
366 |       labels.append(retrieve_label_from_image_path(image_path))
367 | 
368 |   return images, labels
369 | 
370 | 
371 | def retrieve_label_from_image_path(image_path):
372 |   return int(image_path.split('_')[0])
373 | 
374 | 
375 | def eer(pos, neg):
376 |   '''
377 |   Computes the Equal Error Rate of given comparison scores.
378 |   If the FAR and FRR crossing is not exact, linearly interpolates the ROC
379 |   and computes its intersection with the identity line f(x) = x.
380 | 
381 |   Args:
382 |     pos: scores of genuine comparisons.
383 |     neg: scores of impostor comparisons.
384 | 
385 |   Returns:
386 |     EER of comparisons.
387 |   '''
388 |   # compute roc curve
389 |   fars, frrs = roc(pos, neg)
390 | 
391 |   # iterate to find equal error rate
392 |   old_far = None
393 |   old_frr = None
394 |   for far, frr in zip(fars, frrs):
395 |     # if crossing happened, eer is found
396 |     if far >= frr:
397 |       break
398 |     else:
399 |       old_far = far
400 |       old_frr = frr
401 | 
402 |   # if crossing is precisely found, return it
403 |   # otherwise, approximate it through ROC linear
404 |   # interpolation and intersection with f(x) = x
405 |   if far == frr:
406 |     return far
407 |   else:
408 |     return (far * old_frr - old_far * frr) / (far - old_far - (frr - old_frr))
409 | 
410 | 
411 | def roc(pos, neg):
412 |   '''
413 |   Computes Receiver Operating Characteristic curve for given comparison scores.
414 | 
415 |   Args:
416 |     pos: scores of genuine comparisons.
417 |     neg: scores of impostor comparisons.
418 | 
419 |   Returns:
420 |     fars: False Acceptance Rates (FARs) over all possible thresholds.
421 |     frrs: False Rejection Rates (FRRs) over all possible thresholds.
422 |   '''
423 |   # sort comparisons arrays for efficiency
424 |   pos = sorted(pos, reverse=True)
425 |   neg = sorted(neg, reverse=True)
426 | 
427 |   # get all scores
428 |   scores = list(pos) + list(neg)
429 |   scores = np.unique(scores)
430 | 
431 |   # iterate to compute statistics
432 |   fars = [0.0]
433 |   frrs = [1.0]
434 |   pos_cursor = 0
435 |   neg_cursor = 0
436 |   for score in reversed(scores):
437 |     # find corresponding positive score
438 |     while pos_cursor < len(pos) and pos[pos_cursor] > score:
439 |       pos_cursor += 1
440 | 
441 |     # find corresponding negative score
442 |     while neg_cursor < len(neg) and neg[neg_cursor] > score:
443 |       neg_cursor += 1
444 | 
445 |     # compute metrics for score
446 |     far = neg_cursor / len(neg)
447 |     frr = 1 - pos_cursor / len(pos)
448 | 
449 |     # add to overall statistics
450 |     fars.append(far)
451 |     frrs.append(frr)
452 | 
453 |   # add last step
454 |   fars.append(1.0)
455 |   frrs.append(0.0)
456 | 
457 |   return fars, frrs
458 | 
459 | 
460 | def retrieval_rank(probe_instance, probe_label, instances, labels):
461 |   # compute distance of 'probe_instance' to
462 |   # every instance in 'instances'
463 |   dists = np.sum((instances - probe_instance)**2, axis=1)
464 | 
465 |   # sort labels according to instances distances
466 |   matches = np.argsort(dists)
467 |   labels = labels[matches]
468 | 
469 |   # find index of first instance with label 'probe_label'
470 |   index = np.argwhere(labels == probe_label)[0, 0]
471 | 
472 |   # compute retrieval rank
473 |   labels_up_to_index = np.unique(labels[:index + 1])
474 |   rank = len(labels_up_to_index)
475 | 
476 |   return rank
477 | 
478 |
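# A worked toy example with made-up numbers: the probe's nearest neighbor has
# the wrong label and its second nearest neighbor the right one, so two
# distinct labels are seen up to the first correct match and the rank is 2.
def _example_retrieval_rank():
  instances = np.array([[0.0, 0.0], [1.0, 1.0], [5.0, 5.0]])
  labels = np.array([2, 1, 3])
  rank = retrieval_rank(np.array([0.1, 0.1]), 1, instances, labels)
  return rank  # 2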
479 | def rank_n(instances, labels, sample_size):
480 |   # get unique labels
481 |   unique_labels = np.unique(labels)
482 | 
483 |   # initialize ranks
484 |   ranks = np.zeros_like(unique_labels, dtype=np.int32)
485 | 
486 |   # sort examples by labels
487 |   inds = np.argsort(labels)
488 |   instances = instances[inds]
489 |   labels = labels[inds]
490 | 
491 |   # compute rank following the protocol in Belongie et al.
492 |   examples = list(zip(instances, labels))
493 |   for i, (probe, probe_label) in enumerate(examples):
494 |     for target, target_label in examples[i + 1:]:
495 |       if probe_label != target_label:
496 |         break
497 |       else:
498 |         # mix examples of other labels
499 |         other_labels_inds = np.argwhere(labels != probe_label)
500 |         other_labels_inds = np.squeeze(other_labels_inds)
501 |         inds_to_pick = np.random.choice(
502 |             other_labels_inds, sample_size - 1, replace=False)
503 |         instances_to_mix = instances[inds_to_pick]
504 |         labels_to_mix = labels[inds_to_pick]
505 | 
506 |         # make set for retrieval
507 |         target = np.expand_dims(target, axis=0)
508 |         instance_set = np.concatenate([instances_to_mix, target], axis=0)
509 |         target_label = np.expand_dims(target_label, axis=0)
510 |         label_set = np.concatenate([labels_to_mix, target_label], axis=0)
511 | 
512 |         # compute retrieval rank for probe
513 |         rank = retrieval_rank(probe, probe_label, instance_set, label_set)
514 | 
515 |         # update ranks, indexed from 0
516 |         ranks[rank - 1] += 1
517 | 
518 |   # rank is cumulative
519 |   ranks = np.cumsum(ranks)
520 | 
521 |   # normalize rank to [0, 1] range
522 |   ranks = ranks / ranks[-1]
523 | 
524 |   return ranks
525 | 
526 | 
527 | def trained_descriptors(img, pts, patch_size, session, imgs_pl, descs_op):
528 |   '''
529 |   Computes descriptors according to the descs_op tf tensor operation for
530 |   pts keypoints, with patches of size patch_size, in the given image.
531 |   If patch_size is even, the last row and last column of the patch
532 |   centered on each keypoint are discarded before descriptor computation.
533 | 
534 |   Args:
535 |     img: image in which keypoints were detected.
536 |     pts: [N, 2] keypoint coordinates for which descriptors
537 |       should be computed.
538 |     patch_size: patch size for descriptor.
539 |     session: tf session with loaded descs_op variables.
540 |     imgs_pl: image input placeholder for descs_op.
541 |     descs_op: tf tensor op that describes images in imgs_pl.
542 | 
543 |   Returns:
544 |     descs: [N, M] np.array of N descriptors of descs_op of dimension
545 |       M.
546 |     new_pts: [N, 2] keypoint coordinates filtered so their indices match
547 |       those of `descs`.
548 |   '''
549 |   # adjust for odd patch sizes
550 |   odd = 1 if patch_size % 2 != 0 else 0
551 | 
552 |   # get patch locations at 'pts'
553 |   half = patch_size // 2
554 |   patches = []
555 |   new_pts = []
556 |   for pt in pts:
557 |     if half <= pt[0] < img.shape[0] - half - odd:
558 |       if half <= pt[1] < img.shape[1] - half - odd:
559 |         patch = img[pt[0] - half:pt[0] + half + odd, pt[1] - half:pt[1] +
560 |                     half + odd]
561 |         new_pts.append(pt)
562 |         patches.append(patch)
563 | 
564 |   # empty detections set
565 |   if len(patches) == 0:
566 |     return [], []
567 | 
568 |   # describe patches
569 |   feed_dict = {imgs_pl: np.reshape(patches, np.shape(patches) + (1, ))}
570 |   descs = session.run(descs_op, feed_dict=feed_dict)
571 | 
572 |   return descs, new_pts
573 | 
574 |
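# A minimal usage sketch, not part of the original module: it assumes `sess`
# holds a restored description model whose input placeholder is `imgs_pl`
# and whose output tensor is `descs_op`; these names, the patch size and the
# paths are hypothetical.
def _example_trained_descriptors(sess, imgs_pl, descs_op):
  img = load_image('path/to/fingerprint.png')
  pts = load_dets_txt('path/to/fingerprint_pores.txt')
  descs, kept_pts = trained_descriptors(img, pts, 32, sess, imgs_pl, descs_op)
  # keypoints too close to the image border are dropped, so
  # len(descs) == len(kept_pts) <= len(pts)
  return descs, kept_pts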
575 | def compute_orientation(img):
576 |   '''
577 |   Computes fingerprint ridge orientation for every spatial location using
578 |   the method proposed by Hong et al. (Fingerprint Image Enhancement:
579 |   Algorithm and Performance Evaluation, 1998).
580 | 
581 |   Args:
582 |     img: fingerprint image for which orientation will be computed.
583 | 
584 |   Returns:
585 |     matrix with ridge orientation per spatial location.
586 |   '''
587 |   # compute gradients
588 |   dx = cv2.Sobel(img, cv2.CV_64F, 1, 0, ksize=3)
589 |   dy = cv2.Sobel(img, cv2.CV_64F, 0, 1, ksize=3)
590 | 
591 |   # compute noisy orientation
592 |   nu_x = np.zeros_like(img, dtype=np.float32)
593 |   nu_y = np.zeros_like(img, dtype=np.float32)
594 |   for i in range(8, img.shape[0] - 8):
595 |     for j in range(8, img.shape[1] - 8):
596 |       sub_dx = np.reshape(dx[i - 8:i + 9, j - 8:j + 9], -1)
597 |       sub_dy = np.reshape(dy[i - 8:i + 9, j - 8:j + 9], -1)
598 | 
599 |       nu_x[i, j] = 2 * np.dot(sub_dx, sub_dy)
600 |       nu_y[i, j] = np.dot(sub_dx + sub_dy, sub_dx - sub_dy)
601 | 
602 |   # refine orientation
603 |   phi_x = cv2.GaussianBlur(nu_x, (5, 5), 0)
604 |   phi_y = cv2.GaussianBlur(nu_y, (5, 5), 0)
605 |   orientation = np.arctan2(phi_x, phi_y) / 2
606 | 
607 |   return orientation
608 | 
609 | 
610 | def dp_descriptors(img, pts, patch_size):
611 |   '''
612 |   Computes Direct Pore (DP) descriptors (Direct Pore Matching for
613 |   Fingerprint Recognition, 2009) for pts keypoints, with patches of
614 |   size patch_size, in the given image. If patch_size is even,
615 |   the last row and last column of the patch centered on each keypoint
616 |   are discarded before descriptor computation.
617 | 
618 |   Args:
619 |     img: image in which keypoints were detected.
620 |     pts: [N, 2] keypoint coordinates for which DP descriptors
621 |       should be computed.
622 |     patch_size: patch size for DP descriptor.
623 | 
624 |   Returns:
625 |     [N, M] np.array of N DP descriptors of dimension M.
626 |   '''
627 |   # adjust for odd patch sizes
628 |   odd = 1 if patch_size % 2 != 0 else 0
629 | 
630 |   # compute image's orientation per pixel
631 |   orientation = compute_orientation(img)
632 | 
633 |   # gaussian blur image
634 |   img = cv2.GaussianBlur(img, (3, 3), 0)
635 | 
636 |   # get patch locations at 'pts'
637 |   half = patch_size // 2
638 |   center = (half, half)
639 |   descs = []
640 |   for pt in pts:
641 |     if half <= pt[0] < img.shape[0] - half - odd:
642 |       if half <= pt[1] < img.shape[1] - half - odd:
643 |         # extract patch at 'pt'
644 |         patch = img[pt[0] - half:pt[0] + half + odd, pt[1] - half:pt[1] +
645 |                     half + odd]
646 | 
647 |         # normalize orientation
648 |         theta = orientation[pt[0], pt[1]] - np.pi / 2
649 |         rot_mat = cv2.getRotationMatrix2D(center, 180 * theta / np.pi, 1)
650 |         patch = cv2.warpAffine(
651 |             patch, rot_mat, patch.shape[::-1], flags=cv2.INTER_LINEAR)
652 | 
653 |         # circular mask
654 |         for i in range(patch_size):
655 |           for j in range(patch_size):
656 |             if np.hypot(i - half, j - half) > half:
657 |               patch[i, j] = 0
658 | 
659 |         # reshape and normalize
660 |         patch = np.reshape(patch, -1)
661 |         patch -= np.mean(patch)
662 |         patch = patch / np.linalg.norm(patch)
663 | 
664 |         descs.append(patch)
665 | 
666 |   return np.array(descs)
667 | 
668 |
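# A minimal usage sketch, not part of the original module: the image and
# detections paths and the patch size are hypothetical.
def _example_dp():
  img = load_image('path/to/fingerprint.png')
  pts = load_dets_txt('path/to/fingerprint_pores.txt')
  descs = dp_descriptors(img, pts, patch_size=17)
  # each descriptor is an orientation-normalized, circularly masked,
  # zero-mean and unit-norm flattened patch
  return descs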
669 | def detect_pores(image, image_pl, predictions, half_patch_size, prob_thr,
670 |                  inter_thr, sess):
671 |   '''
672 |   Detects pores in an image. First, a pore probability map is computed
673 |   with the tf predictions op. This probability map is then thresholded
674 |   and converted to coordinates, which are filtered with NMS.
675 | 
676 |   Args:
677 |     image: image in which to detect pores.
678 |     image_pl: tf placeholder holding the net's image input.
679 |     predictions: tf tensor op of the net's output.
680 |     half_patch_size: half the detection patch size. Used for padding the
681 |       predictions to the input's original dimensions.
682 |     prob_thr: probability threshold.
683 |     inter_thr: NMS intersection threshold.
684 |     sess: tf session.
685 | 
686 |   Returns:
687 |     np.array of detections for image, with shape [N, 2].
688 |   '''
689 |   # predict probability of pores
690 |   pred = sess.run(
691 |       predictions,
692 |       feed_dict={image_pl: np.reshape(image, (1, ) + image.shape + (1, ))})
693 | 
694 |   # add borders lost in convolution
695 |   pred = np.reshape(pred, pred.shape[1:-1])
696 |   pred = np.pad(pred, ((half_patch_size, half_patch_size),
697 |                        (half_patch_size, half_patch_size)), 'constant')
698 | 
699 |   # convert into coordinates
700 |   pick = pred > prob_thr
701 |   coords = np.argwhere(pick)
702 |   probs = pred[pick]
703 | 
704 |   # filter detections with nms
705 |   dets, _ = nms(coords, probs, 7, inter_thr)
706 | 
707 |   return dets
708 | 
--------------------------------------------------------------------------------
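# A minimal end-to-end usage sketch for utils.detect_pores (illustrative: it
# assumes `predictions` is the detection net's output tensor op over the
# placeholder `image_pl`, with variables already restored into `sess` via
# utils.restore_model; the paths, thresholds and half patch size below are
# hypothetical):
#
#   image = utils.load_image('path/to/fingerprint.png')
#   dets = utils.detect_pores(image, image_pl, predictions,
#                             half_patch_size=15, prob_thr=0.9,
#                             inter_thr=0.7, sess=sess)
#   utils.save_dets_txt(dets, 'path/to/detections.txt')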