├── .gitignore
├── LFWDataset.py
├── LICENSE
├── TripletFaceDataset.py
├── dataset
│   ├── clean_msceleb_using_openface.py
│   ├── download_vgg_face_dataset.py
│   └── extract_msceleb.py
├── eval_metrics.py
├── lfw_pairs.txt
├── logger.py
├── model.py
├── train_center.py
├── train_triplet.py
├── utils.py
└── vis.py

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
logs/
__pycache__/
align/

--------------------------------------------------------------------------------
/LFWDataset.py:
--------------------------------------------------------------------------------
import torchvision.datasets as datasets
import os
import numpy as np
from tqdm import tqdm

class LFWDataset(datasets.ImageFolder):
    '''
    LFW verification dataset: yields the image pairs listed in an LFW pairs
    file together with a boolean "same identity" label.
    '''
    def __init__(self, dir, pairs_path, transform=None):

        super(LFWDataset, self).__init__(dir, transform)

        self.pairs_path = pairs_path

        # LFW dir contains 2 folders: faces and lists
        self.validation_images = self.get_lfw_paths(dir)

    def read_lfw_pairs(self, pairs_filename):
        pairs = []
        with open(pairs_filename, 'r') as f:
            for line in f.readlines()[1:]:
                pair = line.strip().split()
                pairs.append(pair)
        return np.array(pairs)

    def get_lfw_paths(self, lfw_dir, file_ext="jpg"):

        pairs = self.read_lfw_pairs(self.pairs_path)

        nrof_skipped_pairs = 0
        path_list = []
        issame_list = []

        for i in tqdm(range(len(pairs))):
            pair = pairs[i]
            if len(pair) == 3:
                path0 = os.path.join(lfw_dir, pair[0], pair[0] + '_' + '%04d' % int(pair[1]) + '.' + file_ext)
                path1 = os.path.join(lfw_dir, pair[0], pair[0] + '_' + '%04d' % int(pair[2]) + '.' + file_ext)
                issame = True
            elif len(pair) == 4:
                path0 = os.path.join(lfw_dir, pair[0], pair[0] + '_' + '%04d' % int(pair[1]) + '.' + file_ext)
                path1 = os.path.join(lfw_dir, pair[2], pair[2] + '_' + '%04d' % int(pair[3]) + '.' + file_ext)
                issame = False
            else:
                # Malformed row: skip it rather than reuse paths from the previous pair
                nrof_skipped_pairs += 1
                continue
            if os.path.exists(path0) and os.path.exists(path1):  # Only add the pair if both paths exist
                path_list.append((path0, path1, issame))
                issame_list.append(issame)
            else:
                nrof_skipped_pairs += 1
        if nrof_skipped_pairs > 0:
            print('Skipped %d image pairs' % nrof_skipped_pairs)

        return path_list

    def __getitem__(self, index):
        '''
        Args:
            index: Index of the pair - not of a single image

        Returns:
            (img1, img2, issame) for the pair at that index
        '''

        def transform(img_path):
            """Load the image with the ImageFolder loader (a PIL Image,
            consistent with the other datasets) and apply the transform.
            """

            img = self.loader(img_path)
            return self.transform(img)
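A minimal usage sketch for LFWDataset (the dataset path is a placeholder; the transform mirrors the one the training scripts below build):

import torch
import torchvision.transforms as transforms
from LFWDataset import LFWDataset

transform = transforms.Compose([
    transforms.Scale(96),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])

lfw = LFWDataset(dir='/path/to/lfw-aligned', pairs_path='lfw_pairs.txt', transform=transform)
loader = torch.utils.data.DataLoader(lfw, batch_size=64, shuffle=False)
for img1, img2, issame in loader:
    pass  # img1/img2: (B, 3, H, W) tensors; issame: (B,) same/different labels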
        (path_1, path_2, issame) = self.validation_images[index]
        img1, img2 = transform(path_1), transform(path_2)
        return img1, img2, issame

    def __len__(self):
        return len(self.validation_images)

--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2017

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

--------------------------------------------------------------------------------
/TripletFaceDataset.py:
--------------------------------------------------------------------------------
from __future__ import print_function

import torchvision.datasets as datasets
import os
import numpy as np
from tqdm import tqdm

class TripletFaceDataset(datasets.ImageFolder):

    def __init__(self, dir, n_triplets, transform=None, *arg, **kw):
        super(TripletFaceDataset, self).__init__(dir, transform)

        self.n_triplets = n_triplets

        print('Generating {} triplets'.format(self.n_triplets))
        self.training_triplets = self.generate_triplets(self.imgs, self.n_triplets, len(self.classes))

    @staticmethod
    def generate_triplets(imgs, num_triplets, n_classes):
        def create_indices(_imgs):
            inds = dict()
            for idx, (img_path, label) in enumerate(_imgs):
                if label not in inds:
                    inds[label] = []
                inds[label].append(img_path)
            return inds

        triplets = []
        # indices maps each class label to the list of its image paths
        indices = create_indices(imgs)

        # np.random.randint's upper bound is exclusive, so sample with
        # high=n_classes / high=len(...): the previous n_classes - 1 could
        # never pick the last class or image
        for _ in tqdm(range(num_triplets)):
            c1 = np.random.randint(0, n_classes)
            c2 = np.random.randint(0, n_classes)
            while len(indices[c1]) < 2:
                c1 = np.random.randint(0, n_classes)

            while c1 == c2:
                c2 = np.random.randint(0, n_classes)
            if len(indices[c1]) == 2:  # hack to speed up process
                n1, n2 = 0, 1
            else:
                n1 = np.random.randint(0, len(indices[c1]))
                n2 = np.random.randint(0, len(indices[c1]))
                while n1 == n2:
                    n2 = np.random.randint(0, len(indices[c1]))
            if len(indices[c2]) == 1:
                n3 = 0
            else:
                n3 = np.random.randint(0, len(indices[c2]))

            triplets.append([indices[c1][n1], indices[c1][n2], indices[c2][n3], c1, c2])
        return triplets

    def __getitem__(self, index):
        '''
        Args:
            index: Index of the triplet - not of a single image

        Returns:
            (anchor, positive, negative, c1, c2) for the triplet at that index
        '''
        def transform(img_path):
            """Load the image with the ImageFolder loader (a PIL Image,
            consistent with the other datasets) and apply the transform.
            """

            img = self.loader(img_path)
            return self.transform(img)

        # Get the path of each image in the triplet
        a, p, n, c1, c2 = self.training_triplets[index]

        # transform images if required
        img_a, img_p, img_n = transform(a), transform(p), transform(n)
        return img_a, img_p, img_n, c1, c2

    def __len__(self):
        return len(self.training_triplets)
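A matching sketch for TripletFaceDataset (path and triplet count are placeholders; `transform` as in the LFWDataset sketch above):

import torch
from TripletFaceDataset import TripletFaceDataset

train_set = TripletFaceDataset(dir='/path/to/train-faces', n_triplets=1000, transform=transform)
train_loader = torch.utils.data.DataLoader(train_set, batch_size=64, shuffle=False)
for img_a, img_p, img_n, c1, c2 in train_loader:
    pass  # anchor and positive share class c1; the negative is drawn from class c2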
--------------------------------------------------------------------------------
/dataset/clean_msceleb_using_openface.py:
--------------------------------------------------------------------------------
import numpy as np
import shutil
import os

reps_path = '/media/lior/LinuxHDD/datasets/msceleb_rep.txt'

from_path = '/media/lior/LinuxHDD/datasets/MsCeleb-aligned'
to_path = '/media/lior/LinuxHDD/datasets/MSCeleb-cleaned'

# Parse the OpenFace representations file: each line is
# "<folder>/<file>+<comma-separated embedding vector>"
reps = {}
with open(reps_path) as f:
    for line in f:
        path = line.split('+')[0]
        vector = np.array([float(x) for x in line.split('+')[1].strip().split(',')])
        folder = path.split('/')[0]
        file = path.split('/')[1]
        if folder not in reps:
            reps[folder] = {}

        reps[folder][file] = vector

def get_value_by_index(d, ix):
    return next(v for i, v in enumerate(d.items()) if i == ix)

for _dir in reps:
    X = []
    print(_dir)
    for _file in reps[_dir]:
        X.append(reps[_dir][_file])

    # Filter using mean + std of the squared distance to the identity centroid
    i = 0
    mean = np.array(X).mean(axis=0)

    diff = np.array(X) - mean
    res = []
    for d in diff:
        res.append(np.dot(d, d))
    avg_dist = np.array(res).mean()
    std_dist = np.array(res).std()
    print("Average Distance {}, Std: {}".format(avg_dist, std_dist))
    if avg_dist > 0.5:
        # The whole identity folder is too spread out - drop it entirely
        print("BAD DIR: {}".format(_dir))
        continue

    for d in diff:
        if np.dot(d, d) > avg_dist + std_dist * 2:
            print("BAD IMAGE: {}".format(get_value_by_index(reps[_dir], i)[0]))
        else:
            img_name = get_value_by_index(reps[_dir], i)[0]

            os.makedirs(os.path.join(to_path, _dir), exist_ok=True)
            shutil.copy(os.path.join(from_path, _dir, img_name), os.path.join(to_path, _dir, img_name))
        i += 1
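The script keeps an image when its squared distance to the identity centroid is within two standard deviations of the average, and drops an identity outright when the average squared distance exceeds 0.5. A self-contained sketch of that rule on synthetic embeddings:

import numpy as np

X = np.random.randn(20, 128) * 0.1   # 20 embeddings of one identity
X[0] += 2.0                          # one corrupted face
sq = ((X - X.mean(axis=0)) ** 2).sum(axis=1)
keep = sq <= sq.mean() + 2 * sq.std()
print("dropped:", np.where(~keep)[0])  # typically [0]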
--------------------------------------------------------------------------------
/dataset/download_vgg_face_dataset.py:
--------------------------------------------------------------------------------
"""Download the VGG face dataset from URLs given by http://www.robots.ox.ac.uk/~vgg/data/vgg_face/vgg_face_dataset.tar.gz
"""
# MIT License
#
# Copyright (c) 2016 David Sandberg
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from scipy import misc
import numpy as np
from skimage import io
import sys
import argparse
import os
import socket

def main(args):

    limit = 100  # maximum number of images to download per identity

    socket.setdefaulttimeout(30)
    textfile_names = os.listdir(args.dataset_descriptor + "/files")
    for textfile_name in textfile_names:
        if textfile_name.endswith('.txt'):
            with open(os.path.join(args.dataset_descriptor, "files", textfile_name), 'rt') as f:
                lines = f.readlines()
            dir_name = textfile_name.split('.')[0]
            class_path = os.path.join(args.dataset_descriptor, "cropped", dir_name)
            if not os.path.exists(class_path):
                os.makedirs(class_path)

            for idx, line in enumerate(lines):

                x = line.split(' ')
                filename = x[0]
                url = x[1]
                box = [int(float(val)) for val in x[2:6]]  # x1,y1,x2,y2
                image_path = os.path.join(args.dataset_descriptor, "cropped", dir_name, filename + '.' + args.output_format)
                # keep the error marker alongside the image it refers to
                error_path = os.path.join(args.dataset_descriptor, "cropped", dir_name, filename + '.err')
                if not os.path.exists(image_path) and not os.path.exists(error_path):
                    try:
                        img = io.imread(url, mode='RGB')
                    except Exception as e:
                        error_message = '{}: {}'.format(url, e)
                        save_error_message_file(error_path, error_message)
                    else:
                        try:
                            if img.ndim == 2:
                                img = to_rgb(img)
                            if img.ndim != 3:
                                raise ValueError('Wrong number of image dimensions')
                            hist = np.histogram(img, 255, density=True)
                            if hist[0][0] > 0.9 and hist[0][254] > 0.9:
                                raise ValueError('Image is mainly black or white')
                            else:
                                # Crop image according to dataset descriptor
                                img_cropped = img[box[1]:box[3], box[0]:box[2], :]
                                # Scale to image_size x image_size
                                img_resized = misc.imresize(img_cropped, (args.image_size, args.image_size))
                                # Save the cropped, resized image
                                misc.imsave(image_path, img_resized)
                        except ValueError as e:
                            error_message = '{}: {}'.format(url, e)
                            save_error_message_file(error_path, error_message)
                if idx > limit:
                    break

def save_error_message_file(filename, error_message):
    print(error_message)
    # with open(filename, "w") as textfile:
    #     textfile.write(error_message)

def to_rgb(img):
    w, h = img.shape
    ret = np.empty((w, h, 3), dtype=np.uint8)
    ret[:, :, 0] = ret[:, :, 1] = ret[:, :, 2] = img
    return ret

def parse_arguments(argv):
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset_descriptor', type=str,
                        help='Directory containing the text files with the image URLs. Image files will also be placed in this directory.',
                        default="/media/lior/LinuxHDD/datasets/vgg_face_dataset")
    parser.add_argument('--output_format', type=str, help='Format of the output images', default='jpg', choices=['jpg'])
    parser.add_argument('--image_size', type=int,
                        help='Image size (height, width) in pixels.', default=96)
    return parser.parse_args(argv)

if __name__ == '__main__':
    main(parse_arguments(sys.argv[1:]))
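Assuming the VGG descriptor archive has been unpacked so the per-identity URL lists sit under <dataset_descriptor>/files, an illustrative invocation is:

python dataset/download_vgg_face_dataset.py --dataset_descriptor /path/to/vgg_face_dataset --image_size 96

Cropped 96x96 JPEGs are written under <dataset_descriptor>/cropped/<identity>/.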
--------------------------------------------------------------------------------
/dataset/extract_msceleb.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python3
#
# Copyright 2016 Carnegie Mellon University
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
#
# This script extracts the MS-Celeb-1M TSV file into images on disk.
# For more information, see http://arxiv.org/abs/1607.08221 and http://msceleb.com
#
# Brandon Amos
# 2016-07-29

import argparse
import base64
import csv
import os
# import magic  # Detect image type from buffer contents (disabled, all are jpg)

parser = argparse.ArgumentParser()
parser.add_argument('--croppedTSV', type=str, default='/media/lior/LinuxHDD/datasets/MsCelebV1-Faces-Aligned.part.04.tsv')
parser.add_argument('--outputDir', type=str, default='/media/lior/LinuxHDD/datasets/MsCeleb')
args = parser.parse_args()

with open(args.croppedTSV, 'r') as tsvF:
    reader = csv.reader(tsvF, delimiter='\t')
    i = 0
    for row in reader:
        MID, imgSearchRank, faceID, data = row[0], row[1], row[4], base64.b64decode(row[-1])

        saveDir = os.path.join(args.outputDir, MID)
        savePath = os.path.join(saveDir, "{}-{}.jpg".format(imgSearchRank, faceID))

        os.makedirs(saveDir, exist_ok=True)
        with open(savePath, 'wb') as f:
            f.write(data)

        i += 1

        if i % 1000 == 0:
            print("Extracted {} images.".format(i))
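An illustrative invocation (the TSV name is a placeholder for whichever aligned-faces part was downloaded):

python dataset/extract_msceleb.py --croppedTSV /path/to/MsCelebV1-Faces-Aligned.part.00.tsv --outputDir /path/to/MsCeleb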
--------------------------------------------------------------------------------
/eval_metrics.py:
--------------------------------------------------------------------------------
import numpy as np
from sklearn.model_selection import KFold
from scipy import interpolate

def evaluate(distances, labels, nrof_folds=10):
    # Calculate evaluation metrics
    thresholds = np.arange(0, 30, 0.01)
    tpr, fpr, accuracy = calculate_roc(thresholds, distances,
                                       labels, nrof_folds=nrof_folds)
    thresholds = np.arange(0, 30, 0.001)
    val, val_std, far = calculate_val(thresholds, distances,
                                      labels, 1e-3, nrof_folds=nrof_folds)
    return tpr, fpr, accuracy, val, val_std, far

def calculate_roc(thresholds, distances, labels, nrof_folds=10):

    nrof_pairs = min(len(labels), len(distances))
    nrof_thresholds = len(thresholds)
    k_fold = KFold(n_splits=nrof_folds, shuffle=False)

    tprs = np.zeros((nrof_folds, nrof_thresholds))
    fprs = np.zeros((nrof_folds, nrof_thresholds))
    accuracy = np.zeros((nrof_folds))

    indices = np.arange(nrof_pairs)

    for fold_idx, (train_set, test_set) in enumerate(k_fold.split(indices)):

        # Find the best threshold for the fold
        acc_train = np.zeros((nrof_thresholds))
        for threshold_idx, threshold in enumerate(thresholds):
            _, _, acc_train[threshold_idx] = calculate_accuracy(threshold, distances[train_set], labels[train_set])
        best_threshold_index = np.argmax(acc_train)
        for threshold_idx, threshold in enumerate(thresholds):
            tprs[fold_idx, threshold_idx], fprs[fold_idx, threshold_idx], _ = calculate_accuracy(threshold, distances[test_set], labels[test_set])
        _, _, accuracy[fold_idx] = calculate_accuracy(thresholds[best_threshold_index], distances[test_set], labels[test_set])

    tpr = np.mean(tprs, 0)
    fpr = np.mean(fprs, 0)
    return tpr, fpr, accuracy


def calculate_accuracy(threshold, dist, actual_issame):
    predict_issame = np.less(dist, threshold)
    tp = np.sum(np.logical_and(predict_issame, actual_issame))
    fp = np.sum(np.logical_and(predict_issame, np.logical_not(actual_issame)))
    tn = np.sum(np.logical_and(np.logical_not(predict_issame), np.logical_not(actual_issame)))
    fn = np.sum(np.logical_and(np.logical_not(predict_issame), actual_issame))

    tpr = 0 if (tp + fn == 0) else float(tp) / float(tp + fn)
    fpr = 0 if (fp + tn == 0) else float(fp) / float(fp + tn)
    acc = float(tp + tn) / dist.size
    return tpr, fpr, acc


def calculate_val(thresholds, distances, labels, far_target=1e-3, nrof_folds=10):
    nrof_pairs = min(len(labels), len(distances))
    nrof_thresholds = len(thresholds)
    k_fold = KFold(n_splits=nrof_folds, shuffle=False)

    val = np.zeros(nrof_folds)
    far = np.zeros(nrof_folds)

    indices = np.arange(nrof_pairs)

    for fold_idx, (train_set, test_set) in enumerate(k_fold.split(indices)):

        # Find the threshold that gives FAR = far_target
        far_train = np.zeros(nrof_thresholds)
        for threshold_idx, threshold in enumerate(thresholds):
            _, far_train[threshold_idx] = calculate_val_far(threshold, distances[train_set], labels[train_set])
        if np.max(far_train) >= far_target:
            f = interpolate.interp1d(far_train, thresholds, kind='slinear')
            threshold = f(far_target)
        else:
            threshold = 0.0

        val[fold_idx], far[fold_idx] = calculate_val_far(threshold, distances[test_set], labels[test_set])

    val_mean = np.mean(val)
    far_mean = np.mean(far)
    val_std = np.std(val)
    return val_mean, val_std, far_mean


def calculate_val_far(threshold, dist, actual_issame):
    predict_issame = np.less(dist, threshold)
    true_accept = np.sum(np.logical_and(predict_issame, actual_issame))
    false_accept = np.sum(np.logical_and(predict_issame, np.logical_not(actual_issame)))
    n_same = np.sum(actual_issame)
    n_diff = np.sum(np.logical_not(actual_issame))
    if n_diff == 0:
        n_diff = 1
    if n_same == 0:
        return 0, 0
    val = float(true_accept) / float(n_same)
    far = float(false_accept) / float(n_diff)
    return val, far
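A quick smoke test of evaluate() on synthetic verification scores, with matching pairs drawn at smaller distances than non-matching ones (a sketch, not part of the repo):

import numpy as np
from eval_metrics import evaluate

rng = np.random.RandomState(0)
labels = rng.rand(600) > 0.5                                    # True = same identity
distances = np.where(labels, rng.normal(8, 2, 600), rng.normal(14, 2, 600))
tpr, fpr, accuracy, val, val_std, far = evaluate(distances, labels)
print('mean 10-fold accuracy:', accuracy.mean())                # well above chance here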
--------------------------------------------------------------------------------
/logger.py:
--------------------------------------------------------------------------------
import os
from tensorboard_logger import configure, log_value

class Logger(object):
    def __init__(self, log_dir):
        # clean previous logged data under the same directory name
        self._remove(log_dir)

        # configure the project
        configure(log_dir)

        self.global_step = 0

    def log_value(self, name, value):
        log_value(name, value, self.global_step)
        return self

    def step(self):
        self.global_step += 1

    @staticmethod
    def _remove(path):
        """`path` may be relative or absolute."""
        if os.path.isfile(path):
            os.remove(path)  # remove the file
        elif os.path.isdir(path):
            import shutil
            shutil.rmtree(path)  # remove dir and all it contains

--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
from torchvision.models import resnet18
import math
import torch.nn.functional as F
from torch.autograd import Variable
import numpy as np

class FaceModel(nn.Module):
    def __init__(self, embedding_size, num_classes, pretrained=False):
        super(FaceModel, self).__init__()

        self.model = resnet18(pretrained)

        self.embedding_size = embedding_size

        # For 96x96 inputs the final resnet18 feature map is 512 x 3 x 3
        self.model.fc = nn.Linear(512*3*3, self.embedding_size)

        self.model.classifier = nn.Linear(self.embedding_size, num_classes)

    def l2_norm(self, input):
        input_size = input.size()
        buffer = torch.pow(input, 2)

        normp = torch.sum(buffer, 1).add_(1e-10)
        norm = torch.sqrt(normp)

        _output = torch.div(input, norm.view(-1, 1).expand_as(input))

        output = _output.view(input_size)

        return output

    def forward(self, x):

        x = self.model.conv1(x)
        x = self.model.bn1(x)
        x = self.model.relu(x)
        x = self.model.maxpool(x)
        x = self.model.layer1(x)
        x = self.model.layer2(x)
        x = self.model.layer3(x)
        x = self.model.layer4(x)
        x = x.view(x.size(0), -1)
        x = self.model.fc(x)
        self.features = self.l2_norm(x)
        # Multiply by alpha = 10 as suggested in https://arxiv.org/pdf/1703.09507.pdf
        alpha = 10
        self.features = self.features * alpha

        #x = self.model.classifier(self.features)
        return self.features

    def forward_classifier(self, x):
        features = self.forward(x)
        res = self.model.classifier(features)
        return res


from torch.nn.parameter import Parameter

class FaceModelCenter(nn.Module):
    def __init__(self, embedding_size, num_classes, checkpoint=None):
        super(FaceModelCenter, self).__init__()
        self.model = resnet18()
        self.model.avgpool = None
        self.model.fc1 = nn.Linear(512*3*3, 512)
        self.model.fc2 = nn.Linear(512, embedding_size)
        self.model.classifier = nn.Linear(embedding_size, num_classes)
        self.centers = torch.zeros(num_classes, embedding_size).type(torch.FloatTensor)
        self.num_classes = num_classes

        self.apply(self.weights_init)

        if checkpoint is not None:
            # Check if there are the same number of classes
            if list(checkpoint['state_dict'].values())[-1].size(0) == num_classes:
                self.load_state_dict(checkpoint['state_dict'])
                self.centers = checkpoint['centers']
            else:
                # Load everything except the classifier head, whose shape differs
                own_state = self.state_dict()
                for name, param in checkpoint['state_dict'].items():
                    if "classifier" not in name:
                        if isinstance(param, Parameter):
                            # backwards compatibility for serialized parameters
                            param = param.data
                        own_state[name].copy_(param)

    def weights_init(self, m):
        classname = m.__class__.__name__
        if classname.find('Conv') != -1:
            # He initialization for conv layers
            n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            m.weight.data.normal_(0, math.sqrt(2. / n))
            if m.bias is not None:
                m.bias.data.zero_()
        elif classname.find('BatchNorm') != -1:
            m.weight.data.fill_(1)
            m.bias.data.zero_()
        elif classname.find('Linear') != -1:
            n = m.weight.size(1)
            m.weight.data.normal_(0, 0.01)
            m.bias.data.zero_()

    def get_center_loss(self, target, alpha):
        batch_size = target.size(0)
        features_dim = self.features.size(1)

        target_expand = target.view(batch_size, 1).expand(batch_size, features_dim)

        centers_var = Variable(self.centers)
        centers_batch = centers_var.gather(0, target_expand).cuda()

        criterion = nn.MSELoss()
        center_loss = criterion(self.features, centers_batch)

        diff = centers_batch - self.features

        unique_label, unique_reverse, unique_count = np.unique(target.cpu().data.numpy(), return_inverse=True, return_counts=True)

        appear_times = torch.from_numpy(unique_count).gather(0, torch.from_numpy(unique_reverse))

        appear_times_expand = appear_times.view(-1, 1).expand(batch_size, features_dim).type(torch.FloatTensor)

        diff_cpu = diff.cpu().data / appear_times_expand.add(1e-6)

        # Δc_j = (Σ_{i=1}^m δ(y_i = j)(c_j − x_i)) / (1 + Σ_{i=1}^m δ(y_i = j))
        diff_cpu = alpha * diff_cpu

        for i in range(batch_size):
            # Update the parameters c_j for each j by c_j^(t+1) = c_j^t − α · Δc_j^t
            self.centers[target.data[i]] -= diff_cpu[i].type(self.centers.type())

        return center_loss, self.centers

    def l2_norm(self, input):
        input_size = input.size()
        buffer = torch.pow(input, 2)

        normp = torch.sum(buffer, 1).add_(1e-10)
        norm = torch.sqrt(normp)

        _output = torch.div(input, norm.view(-1, 1).expand_as(input))

        output = _output.view(input_size)

        return output

    def forward(self, x):
        x = self.model.conv1(x)
        x = self.model.bn1(x)
        x = self.model.relu(x)
        x = self.model.maxpool(x)
        x = self.model.layer1(x)
        x = self.model.layer2(x)
        x = self.model.layer3(x)
        x = self.model.layer4(x)
        x = x.view(x.size(0), -1)
        x = self.model.fc1(x)
        # feature for center loss
        x = self.model.fc2(x)
        self.features = x
        self.features_norm = self.l2_norm(x)
        return self.features_norm

    def forward_classifier(self, x):
        features_norm = self.forward(x)
        x = self.model.classifier(features_norm)
        return F.log_softmax(x, dim=1)
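l2_norm above is, up to its 1e-10 epsilon, the same operation as F.normalize, and FaceModel then scales the unit embedding by alpha = 10; a small sketch of that equivalence:

import torch
import torch.nn.functional as F

x = torch.randn(4, 256)
features = F.normalize(x, p=2, dim=1) * 10   # alpha = 10, per https://arxiv.org/pdf/1703.09507.pdf
print(features.norm(p=2, dim=1))             # ~10.0 for every row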
--------------------------------------------------------------------------------
/train_center.py:
--------------------------------------------------------------------------------
from __future__ import print_function
import argparse
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torchvision.datasets import ImageFolder

from torch.autograd import Variable
from torch.autograd import Function
import torch.backends.cudnn as cudnn
import os
import numpy as np
from tqdm import tqdm
from model import FaceModel, FaceModelCenter
from eval_metrics import evaluate
from logger import Logger
from LFWDataset import LFWDataset
from PIL import Image
from utils import PairwiseDistance, display_triplet_distance, display_triplet_distance_test
import collections

# Training settings
parser = argparse.ArgumentParser(description='PyTorch Face Recognition')
# Model options
parser.add_argument('--dataroot', type=str, default='/media/lior/LinuxHDD/datasets/MSCeleb-cleaned',  # default='/media/lior/LinuxHDD/datasets/vgg_face_dataset/aligned'
                    help='path to dataset')
parser.add_argument('--lfw-dir', type=str, default='/media/lior/LinuxHDD/datasets/lfw-aligned-mtcnn',
                    help='path to dataset')
parser.add_argument('--lfw-pairs-path', type=str, default='lfw_pairs.txt',
                    help='path to pairs file')

parser.add_argument('--log-dir', default='/media/lior/LinuxHDD/pytorch_face_logs',
                    help='folder to output model checkpoints')

parser.add_argument('--resume',
                    default='/media/lior/LinuxHDD/pytorch_face_logs/run-optim_adam-lr0.001-wd0.0-embeddings512-center0.5-MSCeleb/checkpoint_11.pth',
                    type=str, metavar='PATH',
                    help='path to latest checkpoint (default: none)')
parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
                    help='manual epoch number (useful on restarts)')
parser.add_argument('--epochs', type=int, default=10, metavar='E',
                    help='number of epochs to train (default: 10)')
# Training options
# parser.add_argument('--embedding-size', type=int, default=256, metavar='ES',
#                     help='Dimensionality of the embedding')

parser.add_argument('--center_loss_weight', type=float, default=0.5, help='weight for center loss')
parser.add_argument('--alpha', type=float, default=0.5, help='learning rate of the centers')
parser.add_argument('--embedding-size', type=int, default=512, metavar='ES',
                    help='Dimensionality of the embedding')

parser.add_argument('--batch-size', type=int, default=64, metavar='BS',
                    help='input batch size for training (default: 64)')
parser.add_argument('--test-batch-size', type=int, default=64, metavar='BST',
                    help='input batch size for testing (default: 64)')

parser.add_argument('--lr', type=float, default=0.001, metavar='LR',
                    help='learning rate (default: 0.001)')
parser.add_argument('--beta1', type=float, default=0.5, help='beta1 for adam (default: 0.5)')

parser.add_argument('--lr-decay', default=1e-4, type=float, metavar='LRD',
                    help='learning rate decay ratio (default: 1e-4)')
parser.add_argument('--wd', default=0.0, type=float,
                    metavar='W', help='weight decay (default: 0.0)')
parser.add_argument('--optimizer', default='adam', type=str,
                    metavar='OPT', help='The optimizer to use (default: adam)')
# Device options
parser.add_argument('--no-cuda', action='store_true', default=False,
                    help='disables CUDA training')
parser.add_argument('--gpu-id', default='0', type=str,
                    help='id(s) for CUDA_VISIBLE_DEVICES')
parser.add_argument('--seed', type=int, default=0, metavar='S',
                    help='random seed (default: 0)')
parser.add_argument('--log-interval', type=int, default=10, metavar='LI',
                    help='how many batches to wait before logging training status')

args = parser.parse_args()
# set the device to use by setting CUDA_VISIBLE_DEVICES env variable in
# order to prevent any memory allocation on unused GPUs
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_id

args.cuda = not args.no_cuda and torch.cuda.is_available()
np.random.seed(args.seed)

if not os.path.exists(args.log_dir):
    os.makedirs(args.log_dir)

if args.cuda:
    cudnn.benchmark = True

LOG_DIR = args.log_dir + '/run-optim_{}-lr{}-wd{}-embeddings{}-center{}-MSCeleb'.format(args.optimizer, args.lr, args.wd, args.embedding_size, args.center_loss_weight)

# create logger
logger = Logger(LOG_DIR)

kwargs = {'num_workers': 2, 'pin_memory': True} if args.cuda else {}
l2_dist = PairwiseDistance(2)

transform = transforms.Compose([
    transforms.Scale(96),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5],
                         std=[0.5, 0.5, 0.5])
])

train_dir = ImageFolder(args.dataroot, transform=transform)
train_loader = torch.utils.data.DataLoader(train_dir,
    batch_size=args.batch_size, shuffle=True, **kwargs)

test_loader = torch.utils.data.DataLoader(
    LFWDataset(dir=args.lfw_dir, pairs_path=args.lfw_pairs_path,
               transform=transform),
    batch_size=args.batch_size, shuffle=False, **kwargs)
def main():
    test_display_triplet_distance = True
    # print the experiment configuration
    print('\nparsed options:\n{}\n'.format(vars(args)))
    print('\nNumber of Classes:\n{}\n'.format(len(train_dir.classes)))

    # optionally resume from a checkpoint
    checkpoint = None  # train from scratch when no usable checkpoint is given
    if args.resume:
        if os.path.isfile(args.resume):
            print('=> loading checkpoint {}'.format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
        else:
            print('=> no checkpoint found at {}'.format(args.resume))

    # instantiate model and initialize weights
    model = FaceModelCenter(embedding_size=args.embedding_size, num_classes=len(train_dir.classes),
                            checkpoint=checkpoint)

    if args.cuda:
        model.cuda()

    optimizer = create_optimizer(model, args.lr)

    start = args.start_epoch
    end = start + args.epochs

    for epoch in range(start, end):
        train(train_loader, model, optimizer, epoch)
        test(test_loader, model, epoch)
        if test_display_triplet_distance:
            display_triplet_distance_test(model, test_loader, LOG_DIR + "/test_{}".format(epoch))


def accuracy(output, target, topk=(1,)):
    """Computes the precision@k for the specified values of k"""
    maxk = max(topk)
    batch_size = target.size(0)

    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        correct_k = correct[:k].view(-1).float().sum(0)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res


class AverageMeter(object):
    """Computes and stores the average and current value"""
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


def train(train_loader, model, optimizer, epoch):
    # switch to train mode
    model.train()

    pbar = tqdm(enumerate(train_loader))

    top1 = AverageMeter()

    for batch_idx, (data, label) in pbar:

        data_v = Variable(data.cuda())
        target_var = Variable(label)

        # compute output
        prediction = model.forward_classifier(data_v)

        center_loss, model.centers = model.get_center_loss(target_var, args.alpha)

        criterion = nn.CrossEntropyLoss()

        cross_entropy_loss = criterion(prediction.cuda(), target_var.cuda())

        loss = args.center_loss_weight * center_loss + cross_entropy_loss

        # compute gradient and update weights
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # update the optimizer learning rate
        adjust_learning_rate(optimizer)

        # log loss value
        # logger.log_value('cross_entropy_loss', cross_entropy_loss.data[0]).step()
        logger.log_value('total_loss', loss.data[0]).step()

        prec = accuracy(prediction.data, label.cuda(), topk=(1,))
        top1.update(prec[0], data_v.size(0))

        if batch_idx % args.log_interval == 0:
            pbar.set_description(
                'Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}\t'
                'Train Prec@1 {:.2f} ({:.2f})'.format(
                    epoch, batch_idx * len(data_v), len(train_loader.dataset),
                    100. * batch_idx / len(train_loader),
                    loss.data[0], float(top1.val[0]), float(top1.avg[0])))

    logger.log_value('Train Prec@1 ', float(top1.avg[0]))

    # do checkpointing
    torch.save({'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'centers': model.centers},
               '{}/checkpoint_{}.pth'.format(LOG_DIR, epoch))


def test(test_loader, model, epoch):
    # switch to evaluate mode
    model.eval()

    labels, distances = [], []

    pbar = tqdm(enumerate(test_loader))
    for batch_idx, (data_a, data_p, label) in pbar:
        if args.cuda:
            data_a, data_p = data_a.cuda(), data_p.cuda()
        data_a, data_p, label = Variable(data_a, volatile=True), \
                                Variable(data_p, volatile=True), Variable(label)

        # compute output
        out_a, out_p = model(data_a), model(data_p)
        dists = l2_dist.forward(out_a, out_p)  # torch.sqrt(torch.sum((out_a - out_p) ** 2, 1)) # euclidean distance
        distances.append(dists.data.cpu().numpy())
        labels.append(label.data.cpu().numpy())

        if batch_idx % args.log_interval == 0:
            pbar.set_description('Test Epoch: {} [{}/{} ({:.0f}%)]'.format(
                epoch, batch_idx * len(data_a), len(test_loader.dataset),
                100. * batch_idx / len(test_loader)))

    labels = np.array([sublabel for label in labels for sublabel in label])
    distances = np.array([subdist[0] for dist in distances for subdist in dist])

    tpr, fpr, accuracy, val, val_std, far = evaluate(distances, labels)
    print('\33[91mTest set: Accuracy: {:.8f}\n\33[0m'.format(np.mean(accuracy)))
    logger.log_value('Test Accuracy', np.mean(accuracy))

    plot_roc(fpr, tpr, figure_name="roc_test_epoch_{}.png".format(epoch))


def plot_roc(fpr, tpr, figure_name="roc.png"):
    import matplotlib.pyplot as plt
    from sklearn.metrics import roc_curve, auc
    roc_auc = auc(fpr, tpr)
    fig = plt.figure()
    lw = 2
    plt.plot(fpr, tpr, color='darkorange',
             lw=lw, label='ROC curve (area = %0.2f)' % roc_auc)
    plt.plot([0, 1], [0, 1], color='navy', lw=lw, linestyle='--')
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title('Receiver operating characteristic')
    plt.legend(loc="lower right")
    fig.savefig(os.path.join(LOG_DIR, figure_name), dpi=fig.dpi)


def adjust_learning_rate(optimizer):
    """Updates the learning rate given the learning rate decay.
    The routine has been implemented according to the original Lua SGD optimizer
    """
    for group in optimizer.param_groups:
        if 'step' not in group:
            group['step'] = 0
        group['step'] += 1

        group['lr'] = args.lr / (1 + group['step'] * args.lr_decay)


def create_optimizer(model, new_lr):
    # setup optimizer
    if args.optimizer == 'sgd':
        optimizer = optim.SGD(model.parameters(), lr=new_lr,
                              momentum=0.9, dampening=0.9,
                              weight_decay=args.wd)
    elif args.optimizer == 'adam':
        optimizer = optim.Adam(model.parameters(), lr=new_lr,
                               weight_decay=args.wd, betas=(args.beta1, 0.999))
    elif args.optimizer == 'adagrad':
        optimizer = optim.Adagrad(model.parameters(),
                                  lr=new_lr,
                                  lr_decay=args.lr_decay,
                                  weight_decay=args.wd)
    return optimizer


if __name__ == '__main__':
    main()
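An illustrative run (all paths are placeholders; pass --resume '' to train from scratch):

python train_center.py --dataroot /path/to/train-faces --lfw-dir /path/to/lfw-aligned-mtcnn --log-dir ./logs --resume ''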
--------------------------------------------------------------------------------
/train_triplet.py:
--------------------------------------------------------------------------------
from __future__ import print_function
import argparse
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms

from torch.autograd import Variable
from torch.autograd import Function
import torch.backends.cudnn as cudnn
import os
import numpy as np
from tqdm import tqdm
from model import FaceModel
from eval_metrics import evaluate
from logger import Logger
from TripletFaceDataset import TripletFaceDataset
from LFWDataset import LFWDataset
from PIL import Image
from utils import PairwiseDistance, display_triplet_distance, display_triplet_distance_test
import collections

# Training settings
parser = argparse.ArgumentParser(description='PyTorch Face Recognition')
# Model options
parser.add_argument('--dataroot', type=str, default='/media/lior/LinuxHDD/datasets/MSCeleb-cleaned',
                    help='path to dataset')
parser.add_argument('--lfw-dir', type=str, default='/media/lior/LinuxHDD/datasets/lfw-aligned-mtcnn',
                    help='path to dataset')
parser.add_argument('--lfw-pairs-path', type=str, default='lfw_pairs.txt',
                    help='path to pairs file')

parser.add_argument('--log-dir', default='/media/lior/LinuxHDD/pytorch_face_logs',
                    help='folder to output model checkpoints')

parser.add_argument('--resume',
                    default='/media/lior/LinuxHDD/pytorch_face_logs/run-optim_adagrad-n1000000-lr0.1-wd0.0-m0.5-embeddings256-msceleb-alpha10/checkpoint_1.pth',
                    type=str, metavar='PATH',
                    help='path to latest checkpoint (default: none)')
parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
                    help='manual epoch number (useful on restarts)')
parser.add_argument('--epochs', type=int, default=10, metavar='E',
                    help='number of epochs to train (default: 10)')
# Training options
parser.add_argument('--embedding-size', type=int, default=256, metavar='ES',
                    help='Dimensionality of the embedding')

parser.add_argument('--batch-size', type=int, default=64, metavar='BS',
                    help='input batch size for training (default: 64)')
parser.add_argument('--test-batch-size', type=int, default=64, metavar='BST',
                    help='input batch size for testing (default: 64)')
parser.add_argument('--n-triplets', type=int, default=1000000, metavar='N',
                    help='how many triplets to generate from the dataset')

parser.add_argument('--margin', type=float, default=0.5, metavar='MARGIN',
                    help='the margin value for the triplet loss function (default: 0.5)')

parser.add_argument('--lr', type=float, default=0.1, metavar='LR',
                    help='learning rate (default: 0.1)')
parser.add_argument('--lr-decay', default=1e-4, type=float, metavar='LRD',
                    help='learning rate decay ratio (default: 1e-4)')
parser.add_argument('--wd', default=0.0, type=float,
                    metavar='W', help='weight decay (default: 0.0)')
parser.add_argument('--optimizer', default='adagrad', type=str,
                    metavar='OPT', help='The optimizer to use (default: adagrad)')
# Device options
parser.add_argument('--no-cuda', action='store_true', default=False,
                    help='disables CUDA training')
parser.add_argument('--gpu-id', default='0', type=str,
                    help='id(s) for CUDA_VISIBLE_DEVICES')
parser.add_argument('--seed', type=int, default=0, metavar='S',
                    help='random seed (default: 0)')
parser.add_argument('--log-interval', type=int, default=10, metavar='LI',
                    help='how many batches to wait before logging training status')

args = parser.parse_args()
# set the device to use by setting CUDA_VISIBLE_DEVICES env variable in
# order to prevent any memory allocation on unused GPUs
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_id

args.cuda = not args.no_cuda and torch.cuda.is_available()
np.random.seed(args.seed)

if not os.path.exists(args.log_dir):
    os.makedirs(args.log_dir)

if args.cuda:
    cudnn.benchmark = True

LOG_DIR = args.log_dir + '/run-optim_{}-n{}-lr{}-wd{}-m{}-embeddings{}-msceleb-alpha10'\
    .format(args.optimizer, args.n_triplets, args.lr, args.wd,
            args.margin, args.embedding_size)

# create logger
logger = Logger(LOG_DIR)


class TripletMarginLoss(Function):
    """Triplet loss function:
    mean(max(0, margin + d(anchor, positive) - d(anchor, negative)))
    """
    def __init__(self, margin):
        super(TripletMarginLoss, self).__init__()
        self.margin = margin
        self.pdist = PairwiseDistance(2)  # norm 2

    def forward(self, anchor, positive, negative):
        d_p = self.pdist.forward(anchor, positive)
        d_n = self.pdist.forward(anchor, negative)

        dist_hinge = torch.clamp(self.margin + d_p - d_n, min=0.0)
        loss = torch.mean(dist_hinge)
        return loss


class Scale(object):
    """Rescales the input PIL.Image to the given 'size'.
    If 'size' is a 2-element tuple or list in the order of (width, height), it will be the exact output size.
    If 'size' is a number, it will indicate the size of the smaller edge.
    For example, if height > width, then the image will be
    rescaled to (size, size * height / width).
    size: the exact output size or the size of the smaller edge
    interpolation: Default: PIL.Image.BILINEAR
    """

    def __init__(self, size, interpolation=Image.BILINEAR):
        assert isinstance(size, int) or (isinstance(size, collections.Iterable) and len(size) == 2)
        self.size = size
        self.interpolation = interpolation

    def __call__(self, img):
        if isinstance(self.size, int):
            w, h = img.size
            if (w <= h and w == self.size) or (h <= w and h == self.size):
                return img
            if w < h:
                ow = self.size
                oh = int(self.size * h / w)
                return img.resize((ow, oh), self.interpolation)
            else:
                oh = self.size
                ow = int(self.size * w / h)
                return img.resize((ow, oh), self.interpolation)
        else:
            return img.resize(self.size, self.interpolation)
kwargs = {'num_workers': 2, 'pin_memory': True} if args.cuda else {}
l2_dist = PairwiseDistance(2)

transform = transforms.Compose([
    Scale(96),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5],
                         std=[0.5, 0.5, 0.5])
])

train_dir = TripletFaceDataset(dir=args.dataroot, n_triplets=args.n_triplets, transform=transform)
train_loader = torch.utils.data.DataLoader(train_dir,
    batch_size=args.batch_size, shuffle=False, **kwargs)

test_loader = torch.utils.data.DataLoader(
    LFWDataset(dir=args.lfw_dir, pairs_path=args.lfw_pairs_path,
               transform=transform),
    batch_size=args.batch_size, shuffle=False, **kwargs)


def main():
    # Views the training images and displays the distance on anchor-negative and anchor-positive
    test_display_triplet_distance = True

    # print the experiment configuration
    print('\nparsed options:\n{}\n'.format(vars(args)))
    print('\nNumber of Classes:\n{}\n'.format(len(train_dir.classes)))

    # instantiate model and initialize weights
    model = FaceModel(embedding_size=args.embedding_size,
                      num_classes=len(train_dir.classes),
                      pretrained=False)

    if args.cuda:
        model.cuda()

    optimizer = create_optimizer(model, args.lr)

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print('=> loading checkpoint {}'.format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            model.load_state_dict(checkpoint['state_dict'])
        else:
            print('=> no checkpoint found at {}'.format(args.resume))

    start = args.start_epoch
    end = start + args.epochs

    for epoch in range(start, end):
        train(train_loader, model, optimizer, epoch)
        test(test_loader, model, epoch)

        if test_display_triplet_distance:
            display_triplet_distance(model, train_loader, LOG_DIR + "/train_{}".format(epoch))
            display_triplet_distance_test(model, test_loader, LOG_DIR + "/test_{}".format(epoch))


def train(train_loader, model, optimizer, epoch):
    # switch to train mode
    model.train()

    pbar = tqdm(enumerate(train_loader))
    labels, distances = [], []

    for batch_idx, (data_a, data_p, data_n, label_p, label_n) in pbar:

        data_a, data_p, data_n = data_a.cuda(), data_p.cuda(), data_n.cuda()
        data_a, data_p, data_n = Variable(data_a), Variable(data_p), \
                                 Variable(data_n)

        # compute output
        out_a, out_p, out_n = model(data_a), model(data_p), model(data_n)

        # Choose the hard negatives: triplets that still violate the margin
        d_p = l2_dist.forward(out_a, out_p)
        d_n = l2_dist.forward(out_a, out_n)
        hard_mask = (d_n - d_p < args.margin).cpu().data.numpy().flatten()  # renamed from `all` to avoid shadowing the builtin
        hard_triplets = np.where(hard_mask == 1)
        if len(hard_triplets[0]) == 0:
            continue
        out_selected_a = Variable(torch.from_numpy(out_a.cpu().data.numpy()[hard_triplets]).cuda())
        out_selected_p = Variable(torch.from_numpy(out_p.cpu().data.numpy()[hard_triplets]).cuda())
        out_selected_n = Variable(torch.from_numpy(out_n.cpu().data.numpy()[hard_triplets]).cuda())

        selected_data_a = Variable(torch.from_numpy(data_a.cpu().data.numpy()[hard_triplets]).cuda())
        selected_data_p = Variable(torch.from_numpy(data_p.cpu().data.numpy()[hard_triplets]).cuda())
        selected_data_n = Variable(torch.from_numpy(data_n.cpu().data.numpy()[hard_triplets]).cuda())

        selected_label_p = torch.from_numpy(label_p.cpu().numpy()[hard_triplets])
        selected_label_n = torch.from_numpy(label_n.cpu().numpy()[hard_triplets])
        triplet_loss = TripletMarginLoss(args.margin).forward(out_selected_a, out_selected_p, out_selected_n)

        cls_a = model.forward_classifier(selected_data_a)
        cls_p = model.forward_classifier(selected_data_p)
        cls_n = model.forward_classifier(selected_data_n)

        criterion = nn.CrossEntropyLoss()
        predicted_labels = torch.cat([cls_a, cls_p, cls_n])
        true_labels = torch.cat([Variable(selected_label_p.cuda()), Variable(selected_label_p.cuda()), Variable(selected_label_n.cuda())])

        cross_entropy_loss = criterion(predicted_labels.cuda(), true_labels.cuda())

        loss = cross_entropy_loss + triplet_loss
        # compute gradient and update weights
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # update the optimizer learning rate
        adjust_learning_rate(optimizer)

        # log loss value
        logger.log_value('triplet_loss', triplet_loss.data[0]).step()
        logger.log_value('cross_entropy_loss', cross_entropy_loss.data[0]).step()
        logger.log_value('total_loss', loss.data[0]).step()
        if batch_idx % args.log_interval == 0:
            pbar.set_description(
                'Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f} \t # of Selected Triplets: {}'.format(
                    epoch, batch_idx * len(data_a), len(train_loader.dataset),
                    100. * batch_idx / len(train_loader),
                    loss.data[0], len(hard_triplets[0])))

        dists = l2_dist.forward(out_selected_a, out_selected_n)  # torch.sqrt(torch.sum((out_a - out_n) ** 2, 1)) # euclidean distance
        distances.append(dists.data.cpu().numpy())
        labels.append(np.zeros(dists.size(0)))

        dists = l2_dist.forward(out_selected_a, out_selected_p)  # torch.sqrt(torch.sum((out_a - out_p) ** 2, 1)) # euclidean distance
        distances.append(dists.data.cpu().numpy())
        labels.append(np.ones(dists.size(0)))

    labels = np.array([sublabel for label in labels for sublabel in label])
    distances = np.array([subdist[0] for dist in distances for subdist in dist])

    tpr, fpr, accuracy, val, val_std, far = evaluate(distances, labels)
    print('\33[91mTrain set: Accuracy: {:.8f}\n\33[0m'.format(np.mean(accuracy)))
    logger.log_value('Train Accuracy', np.mean(accuracy))

    plot_roc(fpr, tpr, figure_name="roc_train_epoch_{}.png".format(epoch))

    # do checkpointing
    torch.save({'epoch': epoch + 1, 'state_dict': model.state_dict()},
               '{}/checkpoint_{}.pth'.format(LOG_DIR, epoch))


def test(test_loader, model, epoch):
    # switch to evaluate mode
    model.eval()

    labels, distances = [], []

    pbar = tqdm(enumerate(test_loader))
    for batch_idx, (data_a, data_p, label) in pbar:
        if args.cuda:
            data_a, data_p = data_a.cuda(), data_p.cuda()
        data_a, data_p, label = Variable(data_a, volatile=True), \
                                Variable(data_p, volatile=True), Variable(label)

        # compute output
        out_a, out_p = model(data_a), model(data_p)
        dists = l2_dist.forward(out_a, out_p)  # torch.sqrt(torch.sum((out_a - out_p) ** 2, 1)) # euclidean distance
        distances.append(dists.data.cpu().numpy())
        labels.append(label.data.cpu().numpy())

        if batch_idx % args.log_interval == 0:
            pbar.set_description('Test Epoch: {} [{}/{} ({:.0f}%)]'.format(
                epoch, batch_idx * len(data_a), len(test_loader.dataset),
                100. * batch_idx / len(test_loader)))

    labels = np.array([sublabel for label in labels for sublabel in label])
    distances = np.array([subdist[0] for dist in distances for subdist in dist])

    tpr, fpr, accuracy, val, val_std, far = evaluate(distances, labels)
    print('\33[91mTest set: Accuracy: {:.8f}\n\33[0m'.format(np.mean(accuracy)))
    logger.log_value('Test Accuracy', np.mean(accuracy))

    plot_roc(fpr, tpr, figure_name="roc_test_epoch_{}.png".format(epoch))


def plot_roc(fpr, tpr, figure_name="roc.png"):
    import matplotlib.pyplot as plt
    from sklearn.metrics import roc_curve, auc
    roc_auc = auc(fpr, tpr)
    fig = plt.figure()
    lw = 2
    plt.plot(fpr, tpr, color='darkorange',
             lw=lw, label='ROC curve (area = %0.2f)' % roc_auc)
    plt.plot([0, 1], [0, 1], color='navy', lw=lw, linestyle='--')
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title('Receiver operating characteristic')
    plt.legend(loc="lower right")
    fig.savefig(os.path.join(LOG_DIR, figure_name), dpi=fig.dpi)


def adjust_learning_rate(optimizer):
    """Updates the learning rate given the learning rate decay.
    The routine has been implemented according to the original Lua SGD optimizer
    """
    for group in optimizer.param_groups:
        if 'step' not in group:
            group['step'] = 0
        group['step'] += 1

        group['lr'] = args.lr / (1 + group['step'] * args.lr_decay)


def create_optimizer(model, new_lr):
    # setup optimizer
    if args.optimizer == 'sgd':
        optimizer = optim.SGD(model.parameters(), lr=new_lr,
                              momentum=0.9, dampening=0.9,
                              weight_decay=args.wd)
    elif args.optimizer == 'adam':
        optimizer = optim.Adam(model.parameters(), lr=new_lr,
                               weight_decay=args.wd)
    elif args.optimizer == 'adagrad':
        optimizer = optim.Adagrad(model.parameters(),
                                  lr=new_lr,
                                  lr_decay=args.lr_decay,
                                  weight_decay=args.wd)
    return optimizer


if __name__ == '__main__':
    main()
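The custom TripletMarginLoss above computes mean(max(0, margin + d(a,p) - d(a,n))); a sketch checking it against PyTorch's built-in functional form:

import torch
import torch.nn.functional as F

anc, pos, neg = torch.randn(8, 256), torch.randn(8, 256), torch.randn(8, 256)
print(F.triplet_margin_loss(anc, pos, neg, margin=0.5, p=2))
# should closely match TripletMarginLoss(0.5).forward(anc, pos, neg);
# tiny differences come from the eps inside PairwiseDistance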
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
import matplotlib.pyplot as plt
import torch
from torch.autograd import Variable, Function


class PairwiseDistance(Function):
    def __init__(self, p):
        super(PairwiseDistance, self).__init__()
        self.norm = p

    def forward(self, x1, x2):
        assert x1.size() == x2.size()
        eps = 1e-4 / x1.size(1)
        diff = torch.abs(x1 - x2)
        out = torch.pow(diff, self.norm).sum(dim=1)
        return torch.pow(out + eps, 1. / self.norm)


def denormalize(tens):
    # invert Normalize(mean=[0.5]*3, std=[0.5]*3) and put channels last for imshow
    mean = [0.5, 0.5, 0.5]
    std = [0.5, 0.5, 0.5]

    img_1 = tens.clone()
    for t, m, s in zip(img_1, mean, std):
        t.mul_(s).add_(m)
    img_1 = img_1.numpy().transpose(1, 2, 0)
    return img_1


def display_triplet_distance(model, train_loader, name):
    f, axarr = plt.subplots(3, figsize=(10, 10))
    f.tight_layout()
    l2_dist = PairwiseDistance(2)

    for batch_idx, (data_a, data_p, data_n, c1, c2) in enumerate(train_loader):

        try:
            data_a_c, data_p_c, data_n_c = data_a.cuda(), data_p.cuda(), data_n.cuda()
            data_a_v, data_p_v, data_n_v = Variable(data_a_c, volatile=True), \
                                           Variable(data_p_c, volatile=True), \
                                           Variable(data_n_c, volatile=True)

            out_a, out_p, out_n = model(data_a_v), model(data_p_v), model(data_n_v)
        except Exception as ex:
            print(ex)
            print("ERROR at: {}".format(batch_idx))
            break

        print("Distance (anchor-positive): {}".format(l2_dist.forward(out_a, out_p).data[0][0]))
        print("Distance (anchor-negative): {}".format(l2_dist.forward(out_a, out_n).data[0][0]))

        axarr[0].imshow(denormalize(data_a[0]))
        axarr[1].imshow(denormalize(data_p[0]))
        axarr[2].imshow(denormalize(data_n[0]))
        axarr[0].set_title("Distance (anchor-positive): {}".format(l2_dist.forward(out_a, out_p).data[0][0]))
        axarr[2].set_title("Distance (anchor-negative): {}".format(l2_dist.forward(out_a, out_n).data[0][0]))

        break
    f.savefig("{}.png".format(name))
    #plt.show()

from sklearn.decomposition import PCA
import numpy as np

def display_triplet_distance_test(model, test_loader, name):
    f, axarr = plt.subplots(5, 2, figsize=(10, 10))
    f.tight_layout()
    l2_dist = PairwiseDistance(2)

    for batch_idx, (data_a, data_n, label) in enumerate(test_loader):

        if np.all(label.cpu().numpy()):
            # skip batches that contain no negative pairs
            continue

        try:
            data_a_c, data_n_c = data_a.cuda(), data_n.cuda()
            data_a_v, data_n_v = Variable(data_a_c, volatile=True), \
                                 Variable(data_n_c, volatile=True)

            out_a, out_n = model(data_a_v), model(data_n_v)

        except Exception as ex:
            print(ex)
            print("ERROR at: {}".format(batch_idx))
            break

        for i in range(5):
            # np.random.randint's upper bound is exclusive
            rand_index = np.random.randint(0, label.size(0))
            if i % 2 == 0:
                for j in range(label.size(0)):
                    # Choose a pair with label == 0 (different identities)
                    rand_index = np.random.randint(0, label.size(0))
                    if label[rand_index] == 0:
                        break

            distance = l2_dist.forward(out_a, out_n).data[rand_index][0]
            print("Distance: {}".format(distance))
            #distance_pca = l2_dist.forward(PCA(128).fit_transform(out_a.data[i].cpu().numpy()), PCA(128).fit_transform(out_n.data[i].cpu().numpy())).data[0]
            #print("Distance(PCA): {}".format(distance_pca))

            axarr[i][0].imshow(denormalize(data_a[rand_index]))
            axarr[i][1].imshow(denormalize(data_n[rand_index]))
            plt.figtext(0.5, i / 5.0 + 0.1, "Distance : {}, Label: {}\n".format(distance, label[rand_index]), ha='center', va='center')

        break
    plt.subplots_adjust(hspace=0.5)

    f.savefig("{}.png".format(name))
    #plt.show()
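PairwiseDistance(2) above is a row-wise Euclidean distance with a small eps added for numerical stability; a sketch of the quantity it computes:

import torch

x1, x2 = torch.randn(5, 128), torch.randn(5, 128)
print((x1 - x2).pow(2).sum(1).sqrt())  # what PairwiseDistance(2).forward(x1, x2) returns, up to eps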
--------------------------------------------------------------------------------
/vis.py:
--------------------------------------------------------------------------------
from __future__ import print_function
import argparse
import torch
from torch.autograd import Variable
from torchvision import datasets, transforms
from model import FaceModel
from torchvision.datasets import ImageFolder
from TripletFaceDataset import TripletFaceDataset  # the module exports TripletFaceDataset (there is no FaceDataset)
import math
import os
import seaborn as sns
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patheffects as PathEffects

parser = argparse.ArgumentParser(description='PyTorch face recognition Example')
parser.add_argument('--test_batch_size', type=int, default=64, metavar='N',
                    help='input batch size for testing (default: 64)')
parser.add_argument('--seed', type=int, default=1, metavar='S',
                    help='random seed (default: 1)')
parser.add_argument('--root', type=str,
                    help='path to the data directory containing aligned face patches. Multiple directories are separated with colon.',
                    default='/media/lior/LinuxHDD/datasets/vgg_face_dataset/aligned')
parser.add_argument('--resume', type=str,
                    help='model path to resume training from',
                    default='/home/lior/dev/workspace/face_recognition_seminar/facenet_pytorch/logs/run-optim_adagrad-n1000000-lr0.125-wd0.0-m0.5/checkpoint_1.pth')

def visual_feature_space(features, labels, num_classes, name_dict):
    num = len(labels)

    title_font = {'fontname': 'Arial', 'size': '20', 'color': 'black', 'weight': 'normal',
                  'verticalalignment': 'bottom'}  # Bottom vertical alignment for more space
    axis_font = {'fontname': 'Arial', 'size': '20'}

    # draw
    palette = np.array(sns.color_palette("hls", num_classes))

    # We create a scatter plot.
    f = plt.figure(figsize=(8, 8))
    ax = plt.subplot(aspect='equal')
    sc = ax.scatter(features[:, 0], features[:, 1], lw=0, s=40,
                    c=palette[labels.astype(np.int)])
    # ax.axis('off')
    # ax.axis('tight')

    # We add the labels for each class.
    txts = []
    for i in range(num_classes):
        # Position of each label: median of the same two dimensions the scatter uses
        xtext, ytext = np.median(features[labels == i, :2], axis=0)
        txt = ax.text(xtext, ytext, name_dict[i])
        txt.set_path_effects([
            PathEffects.Stroke(linewidth=5, foreground="w"),
            PathEffects.Normal()])
        txts.append(txt)
    ax.set_xlabel('Activation of the 1st neuron', **axis_font)
    ax.set_ylabel('Activation of the 2nd neuron', **axis_font)
    ax.set_title('softmax_loss + center_loss', **title_font)
    ax.set_axis_bgcolor('grey')
    f.savefig('center_loss.png')
    plt.show()
    return f, ax, sc, txts

def validation_iterator(dataLoader):
    for data, target in dataLoader:
        yield data, target
def main():
    args = parser.parse_args()

    cuda = torch.cuda.is_available()
    torch.manual_seed(args.seed)
    if cuda:
        torch.cuda.manual_seed(args.seed)

    # 1. dataset
    root = args.root
    kwargs = {'num_workers': 4, 'pin_memory': True} if cuda else {}
    test_transforms = transforms.Compose([transforms.Scale(96),
                                          transforms.ToTensor(),
                                          transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])
    test_dataset = ImageFolder(root, transform=test_transforms)
    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=args.test_batch_size, shuffle=False, **kwargs)

    val_iterator = validation_iterator(test_loader)

    # 2. model
    #train_dir = TripletFaceDataset(dir='/media/lior/LinuxHDD/datasets/MSCeleb-cleaned', n_triplets=10)

    print('construct model')
    model = FaceModel(embedding_size=128,
                      num_classes=3367,
                      pretrained=False)

    model = model.cuda()

    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)

            model.load_state_dict(checkpoint['state_dict'])
            print("=> loaded checkpoint '{}'".format(args.resume))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    # extract feature
    print('extracting feature')
    embeds = []
    labels = []
    for data, target in val_iterator:
        if cuda:
            data, target = data.cuda(), target.cuda(non_blocking=True)  # `async=True` was renamed `non_blocking=True` in torch 0.4
        data_var = Variable(data, volatile=True)
        # compute output
        output = model(data_var)

        embeds.append(output.data.cpu().numpy())
        labels.append(target.cpu().numpy())

    embeds = np.vstack(embeds)
    labels = np.hstack(labels)

    print('embeds shape is ', embeds.shape)
    print('labels shape is ', labels.shape)

    # prepare dict for display: one name per class in the dataset
    namedict = dict()
    for i in range(len(test_dataset.classes)):
        namedict[i] = str(i)

    visual_feature_space(embeds, labels, len(test_dataset.classes), namedict)

if __name__ == '__main__':
    main()

--------------------------------------------------------------------------------
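Finally, an illustrative vis.py run (paths are placeholders; the checkpoint must match the embedding_size=128 / num_classes=3367 model built in main):

python vis.py --root /path/to/aligned-faces --resume /path/to/checkpoint_1.pth

Note that the scatter plot draws only the first two embedding dimensions, so it is most informative when the model is trained with a 2-dimensional embedding.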