├── .gitignore
├── README.md
├── dataset.py
├── device.py
├── imageaug.py
├── images
│   ├── center-loss.png
│   ├── obama_a.png
│   ├── obama_b.png
│   ├── result.jpg
│   ├── result.png
│   ├── roc.png
│   ├── softmax.png
│   ├── trump_a.png
│   └── trump_b.png
├── loss.py
├── main.py
├── metrics.py
├── models
│   ├── __init__.py
│   ├── base.py
│   └── resnet.py
├── tests
│   ├── __init__.py
│   └── center_loss_test.py
├── trainer.py
└── utils.py

/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__/
2 | logs
3 | *.pth
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # center-loss.pytorch
2 | Center loss implementation for face recognition in PyTorch. Paper: [A Discriminative Feature Learning Approach
3 | for Deep Face Recognition](https://ydwen.github.io/papers/WenECCV16.pdf)
4 | 
5 | ## Requirements
6 | 
7 | * Python 3.6
8 | * PyTorch 0.4
9 | 
10 | ## Usage
11 | 
12 | **Training** No need to download anything. The following commands will create the required directories and download everything automatically.
13 | 
14 | ```
15 | # For quick reference and small GPU RAM
16 | python3 main.py --arch resnet18 --batch_size 64 --epochs 50
17 | 
18 | # For a more solid model
19 | python3 main.py
20 | 
21 | # or
22 | python3 main.py --arch resnet50 --batch_size 256 --epochs 150
23 | ```
24 | 
25 | **Evaluation**
26 | 
27 | ```
28 | python3 main.py --evaluate ./logs/models/epoch_xx.pth.tar
29 | 
30 | # Model accuracy is 0.961722195148468
31 | # ROC curve generated at /home/louis/center-loss.pytorch/logs/roc.png
32 | ```
33 | 
34 | **More Options**
35 | 
36 | ```
37 | usage: main.py [-h] [--batch_size N] [--log_dir LOG_DIR] [--epochs N]
38 |                [--lr LR] [--arch ARCH] [--resume RESUME]
39 |                [--dataset_dir DATASET_DIR] [--weights WEIGHTS]
40 |                [--evaluate EVALUATE] [--pairs PAIRS] [--roc ROC]
41 |                [--verify-model VERIFY_MODEL] [--verify-images VERIFY_IMAGES]
42 | 
43 | center loss example
44 | 
45 | optional arguments:
46 |   -h, --help            show this help message and exit
47 |   --batch_size N        input batch size for training (default: 256)
48 |   --log_dir LOG_DIR     log directory
49 |   --epochs N            number of epochs to train (default: 100)
50 |   --lr LR               learning rate (default: 0.001)
51 |   --arch ARCH           network arch to use, supports resnet18 and resnet50
52 |                         (default: resnet50)
53 |   --resume RESUME       model path to resume training from
54 |   --dataset_dir DATASET_DIR
55 |                         directory with lfw dataset (default:
56 |                         $HOME/datasets/lfw)
57 |   --weights WEIGHTS     pretrained weights to load (default:
58 |                         $LOG_DIR/resnet18.pth)
59 |   --evaluate EVALUATE   evaluate specified model on lfw dataset
60 |   --pairs PAIRS         path of pairs.txt (default: $DATASET_DIR/pairs.txt)
61 |   --roc ROC             path where roc.png is generated (default:
62 |                         $LOG_DIR/roc.png)
63 |   --verify-model VERIFY_MODEL
64 |                         verify that 2 face images belong to the same person;
65 |                         the param is the model to use
66 |   --verify-images VERIFY_IMAGES
67 |                         verify that 2 face images belong to the same person;
68 |                         separate image paths with a comma
69 | ```
70 | 
71 | ## Experiments
72 | 
73 | We trained a model with the default configuration (resnet50 for 100 epochs). The model can be downloaded from [Baidu Yun](https://pan.baidu.com/s/138sQCpHqImPjevvZMsVz3w) or [Google Drive](https://drive.google.com/open?id=1cQFwneObMeRc1KZF8959YGVbbYGvUhkp).
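The downloaded checkpoint can also be used directly from Python. Here is a minimal sketch mirroring `verify()` in `main.py` (it assumes the checkpoint file above was placed under `logs/models/`; any face image path works in place of the example one):

```python
import torch

from models import Resnet50FaceModel
from imageaug import transform_for_infer
from utils import image_loader
from device import device

# build the model without a classifier head and load the trained weights
model = Resnet50FaceModel(False).to(device)
checkpoint = torch.load('logs/models/epoch_100.pth.tar')
model.load_state_dict(checkpoint['state_dict'], strict=False)
model.eval()

# preprocess one face image and extract its L2-normalized embedding
transform = transform_for_infer(Resnet50FaceModel.IMAGE_SHAPE)
image = transform(image_loader('images/obama_a.png'))
_, embedding = model(image.unsqueeze(0).to(device))  # shape: (1, 2048)
```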
74 | 
75 | Results are shown below:
76 | 
77 | ```
78 | python main.py --evaluate logs/models/epoch_100.pth.tar --batch_size 128
79 | 
80 | Model accuracy is 0.9628332853317261
81 | ROC curve generated at /home/louis/center-loss.pytorch/logs/roc.png
82 | ```
83 | 
84 | ![](images/roc.png)
85 | 
86 | ### Experiments with MNIST dataset
87 | 
88 | **softmax only**
89 | 
90 | ![](images/softmax.png)
91 | 
92 | **softmax + center loss**
93 | 
94 | ![](images/center-loss.png)
95 | 
96 | 
97 | ## Random People Verification
98 | 
99 | Take 2 images of Obama and 2 images of Trump, then verify the 4 pairs below with the model.
100 | 
101 | ```shell
102 | ➜ python main.py --verify-model logs/models/epoch_100.pth.tar --verify-images images/obama_a.png,images/obama_b.png
103 | distance: 0.9222122430801392
104 | ➜ python main.py --verify-model logs/models/epoch_100.pth.tar --verify-images images/trump_a.png,images/trump_b.png
105 | distance: 0.8140196800231934
106 | ➜ python main.py --verify-model logs/models/epoch_100.pth.tar --verify-images images/obama_a.png,images/trump_a.png
107 | distance: 1.2879283428192139
108 | ➜ python main.py --verify-model logs/models/epoch_100.pth.tar --verify-images images/obama_b.png,images/trump_b.png
109 | distance: 1.26639723777771
110 | ```
111 | 
112 | ![](images/result.png)
113 | 
114 | We can see that a threshold of 1.1 separates them perfectly.
115 | 
116 | The dataset is small, so this model is only a quick reference example. For production use, swap in a stronger feature extraction network and train on a larger dataset.
--------------------------------------------------------------------------------
/dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | import random
3 | import tarfile
4 | from math import ceil, floor
5 | 
6 | from torch.utils import data
7 | import numpy as np
8 | 
9 | from utils import image_loader, download
10 | 
11 | DATASET_TARBALL = "http://vis-www.cs.umass.edu/lfw/lfw-deepfunneled.tgz"
12 | PAIRS_TRAIN = "http://vis-www.cs.umass.edu/lfw/pairsDevTrain.txt"
13 | PAIRS_VAL = "http://vis-www.cs.umass.edu/lfw/pairsDevTest.txt"
14 | 
15 | def create_datasets(dataroot, train_val_split=0.9):
16 |     if not os.path.isdir(dataroot):
17 |         os.mkdir(dataroot)
18 | 
19 |     dataroot_files = os.listdir(dataroot)
20 |     data_tarball_file = DATASET_TARBALL.split('/')[-1]
21 |     data_dir_name = data_tarball_file.split('.')[0]
22 | 
23 |     if data_dir_name not in dataroot_files:
24 |         if data_tarball_file not in dataroot_files:
25 |             download(dataroot, DATASET_TARBALL)  # saved as dataroot/data_tarball_file
26 |         with tarfile.open(os.path.join(dataroot, data_tarball_file), 'r') as t:
27 |             t.extractall(dataroot)
28 | 
29 |     images_root = os.path.join(dataroot, 'lfw-deepfunneled')
30 |     names = os.listdir(images_root)
31 |     if len(names) == 0:
32 |         raise RuntimeError('Empty dataset')
33 | 
34 |     training_set = []
35 |     validation_set = []
36 |     for klass, name in enumerate(names):
37 |         def add_class(image):
38 |             image_path = os.path.join(images_root, name, image)
39 |             return (image_path, klass, name)
40 | 
41 |         images_of_person = os.listdir(os.path.join(images_root, name))
42 |         total = len(images_of_person)
43 | 
44 |         training_set += map(
45 |             add_class,
46 |             images_of_person[:ceil(total * train_val_split)])
47 |         validation_set += map(
48 |             add_class,
49 |             images_of_person[floor(total * train_val_split):])  # floor keeps every person in validation; the boundary image may land in both splits
50 | 
51 |     return training_set, validation_set, len(names)
52 | 
53 | 
54 | class Dataset(data.Dataset):
55 | 
56 |     def __init__(self, datasets, transform=None, target_transform=None):
57 |         self.datasets = datasets
58 |         self.num_classes = len(set(item[1] for item in datasets))  # distinct identities, not sample count
59 |         self.transform = transform
60 |         self.target_transform = target_transform
61 | 
62 |     def __len__(self):
63 |         return len(self.datasets)
64 | 
65 |     def __getitem__(self, index):
66 |         image = image_loader(self.datasets[index][0])
67 |         if self.transform:
68 |             image = self.transform(image)
69 |         return (image, self.datasets[index][1], self.datasets[index][2])
70 | 
71 | 
72 | class PairedDataset(data.Dataset):
73 | 
74 |     def __init__(self, dataroot, pairs_cfg, transform=None, loader=None):
75 |         self.dataroot = dataroot
76 |         self.pairs_cfg = pairs_cfg
77 |         self.transform = transform
78 |         self.loader = loader if loader else image_loader
79 | 
80 |         self.image_names_a = []
81 |         self.image_names_b = []
82 |         self.matches = []
83 | 
84 |         self._prepare_dataset()
85 | 
86 |     def __len__(self):
87 |         return len(self.matches)
88 | 
89 |     def __getitem__(self, index):
90 |         return (self.transform(self.loader(self.image_names_a[index])),
91 |                 self.transform(self.loader(self.image_names_b[index])),
92 |                 self.matches[index])
93 | 
94 |     def _prepare_dataset(self):
95 |         raise NotImplementedError
96 | 
97 | 
98 | class LFWPairedDataset(PairedDataset):
99 | 
100 |     def _prepare_dataset(self):
101 |         pairs = self._read_pairs(self.pairs_cfg)
102 | 
103 |         for pair in pairs:
104 |             if len(pair) == 3:
105 |                 match = True
106 |                 name1, name2, index1, index2 = \
107 |                     pair[0], pair[0], int(pair[1]), int(pair[2])
108 | 
109 |             else:
110 |                 match = False
111 |                 name1, name2, index1, index2 = \
112 |                     pair[0], pair[2], int(pair[1]), int(pair[3])
113 | 
114 |             self.image_names_a.append(os.path.join(
115 |                 self.dataroot, 'lfw-deepfunneled',
116 |                 name1, "{}_{:04d}.jpg".format(name1, index1)))
117 | 
118 |             self.image_names_b.append(os.path.join(
119 |                 self.dataroot, 'lfw-deepfunneled',
120 |                 name2, "{}_{:04d}.jpg".format(name2, index2)))
121 |             self.matches.append(match)
122 | 
123 |     def _read_pairs(self, pairs_filename):
124 |         pairs = []
125 |         with open(pairs_filename, 'r') as f:
126 |             for line in f.readlines()[1:]:
127 |                 pair = line.strip().split()
128 |                 pairs.append(pair)
129 |         return pairs
130 | 
--------------------------------------------------------------------------------
/device.py:
--------------------------------------------------------------------------------
1 | import torch
2 | 
3 | # fall back to the CPU when CUDA is unavailable
4 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
--------------------------------------------------------------------------------
/imageaug.py:
--------------------------------------------------------------------------------
1 | from torchvision import transforms
2 | 
3 | 
4 | def transform_for_training(image_shape):
5 |     return transforms.Compose(
6 |         [transforms.ToPILImage(),
7 |          transforms.Resize(image_shape),
8 |          transforms.RandomHorizontalFlip(),
9 |          transforms.ToTensor(),
10 |          transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]
11 |     )
12 | 
13 | 
14 | def transform_for_infer(image_shape):
15 |     # no random flip at inference time, so results are deterministic
16 |     return transforms.Compose(
17 |         [transforms.ToPILImage(),
18 |          transforms.Resize(image_shape),
19 |          transforms.ToTensor(),
20 |          transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]
21 |     )
--------------------------------------------------------------------------------
/images/center-loss.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/louis-she/center-loss.pytorch/5be899d1f622d24d7de0039dc50b54ce5a6b1151/images/center-loss.png
--------------------------------------------------------------------------------
/images/obama_a.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/louis-she/center-loss.pytorch/5be899d1f622d24d7de0039dc50b54ce5a6b1151/images/obama_a.png
--------------------------------------------------------------------------------
/images/obama_b.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/louis-she/center-loss.pytorch/5be899d1f622d24d7de0039dc50b54ce5a6b1151/images/obama_b.png
--------------------------------------------------------------------------------
/images/result.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/louis-she/center-loss.pytorch/5be899d1f622d24d7de0039dc50b54ce5a6b1151/images/result.jpg
--------------------------------------------------------------------------------
/images/result.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/louis-she/center-loss.pytorch/5be899d1f622d24d7de0039dc50b54ce5a6b1151/images/result.png
--------------------------------------------------------------------------------
/images/roc.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/louis-she/center-loss.pytorch/5be899d1f622d24d7de0039dc50b54ce5a6b1151/images/roc.png
--------------------------------------------------------------------------------
/images/softmax.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/louis-she/center-loss.pytorch/5be899d1f622d24d7de0039dc50b54ce5a6b1151/images/softmax.png
--------------------------------------------------------------------------------
/images/trump_a.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/louis-she/center-loss.pytorch/5be899d1f622d24d7de0039dc50b54ce5a6b1151/images/trump_a.png
--------------------------------------------------------------------------------
/images/trump_b.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/louis-she/center-loss.pytorch/5be899d1f622d24d7de0039dc50b54ce5a6b1151/images/trump_b.png
--------------------------------------------------------------------------------
/loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | 
3 | from device import device
4 | 
5 | 
6 | def compute_center_loss(features, centers, targets):
7 |     features = features.view(features.size(0), -1)
8 |     target_centers = centers[targets]  # look up each sample's class center
9 |     criterion = torch.nn.MSELoss()
10 |     center_loss = criterion(features, target_centers)
11 |     return center_loss
12 | 
13 | 
14 | def get_center_delta(features, centers, targets, alpha):
15 |     # implements equation (4) of the center-loss paper
16 |     features = features.view(features.size(0), -1)
17 |     targets, indices = torch.sort(targets)  # group samples of the same class together
18 |     target_centers = centers[targets]
19 |     features = features[indices]
20 | 
21 |     delta_centers = target_centers - features
22 |     uni_targets, indices = torch.unique(
23 |         targets.cpu(), sorted=True, return_inverse=True)
24 | 
25 |     uni_targets = uni_targets.to(device)
26 |     indices = indices.to(device)
27 | 
28 |     delta_centers = torch.zeros(
29 |         uni_targets.size(0), delta_centers.size(1)
30 |     ).to(device).index_add_(0, indices, delta_centers)  # one summed delta row per distinct class
31 | 
32 |     targets_repeat_num = uni_targets.size()[0]
33 |     uni_targets_repeat_num = targets.size()[0]
34 |     targets_repeat = targets.repeat(
35 |         targets_repeat_num).view(targets_repeat_num, -1)
36 |     uni_targets_repeat = uni_targets.unsqueeze(1).repeat(
37 |         1, uni_targets_repeat_num)
38 |     same_class_feature_count = torch.sum(
39 |         targets_repeat == uni_targets_repeat, dim=1).float().unsqueeze(1)
40 | 
41 |     delta_centers = delta_centers / (same_class_feature_count + 1.0) * alpha
42 |     result = torch.zeros_like(centers)
43 |     result[uni_targets, :] = delta_centers
44 |     return result
45 | 
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | 
4 | import torch
5 | from torch.utils.data import DataLoader
6 | from torchvision import transforms
7 | import numpy as np
8 | 
9 | from dataset import Dataset, create_datasets, LFWPairedDataset
10 | from loss import compute_center_loss, get_center_delta
11 | from models import Resnet50FaceModel, Resnet18FaceModel
12 | from device import device
13 | from trainer import Trainer
14 | from utils import download, generate_roc_curve, image_loader
15 | from metrics import compute_roc, select_threshold
16 | from imageaug import transform_for_infer, transform_for_training
17 | 
18 | 
19 | def main(args):
20 |     if args.evaluate:
21 |         evaluate(args)
22 |     elif args.verify_model:
23 |         verify(args)
24 |     else:
25 |         train(args)
26 | 
27 | 
28 | def get_dataset_dir(args):
29 |     home = os.path.expanduser("~")
30 |     dataset_dir = args.dataset_dir if args.dataset_dir else os.path.join(
31 |         home, 'datasets', 'lfw')
32 | 
33 |     if not os.path.isdir(dataset_dir):
34 |         os.makedirs(dataset_dir)  # makedirs also creates missing parent directories
35 | 
36 |     return dataset_dir
37 | 
38 | 
39 | def get_log_dir(args):
40 |     log_dir = args.log_dir if args.log_dir else os.path.join(
41 |         os.path.dirname(os.path.realpath(__file__)), 'logs')
42 | 
43 |     if not os.path.isdir(log_dir):
44 |         os.mkdir(log_dir)
45 | 
46 |     return log_dir
47 | 
48 | 
49 | def get_model_class(args):
50 |     if args.arch == 'resnet18':
51 |         model_class = Resnet18FaceModel
52 |     elif args.arch == 'resnet50':
53 |         model_class = Resnet50FaceModel
54 |     else:
55 |         raise ValueError('unsupported arch: {}'.format(args.arch))
56 | 
57 |     return model_class
58 | 
59 | 
60 | def train(args):
61 |     dataset_dir = get_dataset_dir(args)
62 |     log_dir = get_log_dir(args)
63 |     model_class = get_model_class(args)
64 | 
65 |     training_set, validation_set, num_classes = create_datasets(dataset_dir)
66 | 
67 |     training_dataset = Dataset(
68 |         training_set, transform_for_training(model_class.IMAGE_SHAPE))
69 |     validation_dataset = Dataset(
70 |         validation_set, transform_for_infer(model_class.IMAGE_SHAPE))
71 | 
72 |     training_dataloader = torch.utils.data.DataLoader(
73 |         training_dataset,
74 |         batch_size=args.batch_size,
75 |         num_workers=6,
76 |         shuffle=True
77 |     )
78 | 
79 |     validation_dataloader = torch.utils.data.DataLoader(
80 |         validation_dataset,
81 |         batch_size=args.batch_size,
82 |         num_workers=6,
83 |         shuffle=False
84 |     )
85 | 
86 |     model = model_class(num_classes).to(device)
87 | 
88 |     trainables_wo_bn = [param for name, param in model.named_parameters() if
89 |                         param.requires_grad and 'bn' not in name]
90 |     trainables_only_bn = [param for name, param in model.named_parameters() if
91 |                           param.requires_grad and 'bn' in name]
92 | 
93 |     optimizer = torch.optim.SGD([
94 |         {'params': trainables_wo_bn, 'weight_decay': 0.0001},
95 |         {'params': trainables_only_bn}
96 |     ], lr=args.lr, momentum=0.9)
97 | 
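# The two SGD parameter groups above apply weight decay (0.0001) to the
# conv/linear weights only, leaving batch-norm parameters undecayed.
# The Trainer constructed below combines the objectives for every batch as
#   loss = lamda * center_loss + cross_entropy_loss
# with lamda=0.03 and center update rate alpha=0.5 by default (see trainer.py).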
98 |     trainer = Trainer(
99 |         optimizer,
100 |         model,
101 |         training_dataloader,
102 |         validation_dataloader,
103 |         max_epoch=args.epochs,
104 |         resume=args.resume,
105 |         log_dir=log_dir
106 |     )
107 |     trainer.train()
108 | 
109 | 
110 | def evaluate(args):
111 |     dataset_dir = get_dataset_dir(args)
112 |     log_dir = get_log_dir(args)
113 |     model_class = get_model_class(args)
114 | 
115 |     pairs_path = args.pairs if args.pairs else \
116 |         os.path.join(dataset_dir, 'pairs.txt')
117 | 
118 |     if not os.path.isfile(pairs_path):
119 |         download(dataset_dir, 'http://vis-www.cs.umass.edu/lfw/pairs.txt')
120 | 
121 |     dataset = LFWPairedDataset(
122 |         dataset_dir, pairs_path, transform_for_infer(model_class.IMAGE_SHAPE))
123 |     dataloader = DataLoader(dataset, batch_size=args.batch_size, num_workers=4)
124 |     model = model_class(False).to(device)
125 | 
126 |     checkpoint = torch.load(args.evaluate)
127 |     model.load_state_dict(checkpoint['state_dict'], strict=False)
128 |     model.eval()
129 | 
130 |     embeddings_a = torch.zeros(len(dataset), model.FEATURE_DIM)
131 |     embeddings_b = torch.zeros(len(dataset), model.FEATURE_DIM)
132 |     matches = torch.zeros(len(dataset), dtype=torch.uint8)
133 | 
134 |     for iteration, (images_a, images_b, batched_matches) \
135 |             in enumerate(dataloader):
136 |         current_batch_size = len(batched_matches)
137 |         images_a = images_a.to(device)
138 |         images_b = images_b.to(device)
139 | 
140 |         _, batched_embeddings_a = model(images_a)
141 |         _, batched_embeddings_b = model(images_b)
142 | 
143 |         start = args.batch_size * iteration
144 |         end = start + current_batch_size
145 | 
146 |         embeddings_a[start:end, :] = batched_embeddings_a.data
147 |         embeddings_b[start:end, :] = batched_embeddings_b.data
148 |         matches[start:end] = batched_matches.data
149 | 
150 |     thresholds = np.arange(0, 4, 0.1)
151 |     distances = torch.sum(torch.pow(embeddings_a - embeddings_b, 2), dim=1)
152 | 
153 |     tpr, fpr, accuracy, best_thresholds = compute_roc(
154 |         distances,
155 |         matches,
156 |         thresholds
157 |     )
158 | 
159 |     roc_file = args.roc if args.roc else os.path.join(log_dir, 'roc.png')
160 |     generate_roc_curve(fpr, tpr, roc_file)
161 |     print('Model accuracy is {}'.format(accuracy))
162 |     print('ROC curve generated at {}'.format(roc_file))
163 | 
164 | 
165 | def verify(args):
166 |     dataset_dir = get_dataset_dir(args)
167 |     log_dir = get_log_dir(args)
168 |     model_class = get_model_class(args)
169 | 
170 |     model = model_class(False).to(device)
171 |     checkpoint = torch.load(args.verify_model)
172 |     model.load_state_dict(checkpoint['state_dict'], strict=False)
173 |     model.eval()
174 | 
175 |     image_a, image_b = args.verify_images.split(',')
176 |     image_a = transform_for_infer(
177 |         model_class.IMAGE_SHAPE)(image_loader(image_a))
178 |     image_b = transform_for_infer(
179 |         model_class.IMAGE_SHAPE)(image_loader(image_b))
180 |     images = torch.stack([image_a, image_b]).to(device)
181 | 
182 |     _, (embeddings_a, embeddings_b) = model(images)
183 | 
184 |     distance = torch.sum(torch.pow(embeddings_a - embeddings_b, 2)).item()
185 |     print("distance: {}".format(distance))
186 | 
187 | 
188 | if __name__ == '__main__':
189 | 
190 |     parser = argparse.ArgumentParser(description='center loss example')
191 |     parser.add_argument('--batch_size', type=int, default=256, metavar='N',
192 |                         help='input batch size for training (default: 256)')
193 |     parser.add_argument('--log_dir', type=str,
194 |                         help='log directory')
195 |     parser.add_argument('--epochs', type=int, default=100, metavar='N',
196 |                         help='number of epochs to train (default: 100)')
197 |     parser.add_argument('--lr', type=float, default=0.001,
198 |                         help='learning rate (default: 0.001)')
199 |     parser.add_argument('--arch', type=str, default='resnet50',
200 |                         help='network arch to use, supports resnet18 and '
201 |                              'resnet50 (default: resnet50)')
202 |     parser.add_argument('--resume', type=str,
203 |                         help='model path to resume training from',
204 |                         default=False)
205 |     parser.add_argument('--dataset_dir', type=str,
206 |                         help='directory with lfw dataset'
207 |                              ' (default: $HOME/datasets/lfw)')
208 |     parser.add_argument('--weights', type=str,
209 |                         help='pretrained weights to load '
210 |                              '(default: $LOG_DIR/resnet18.pth)')
211 |     parser.add_argument('--evaluate', type=str,
212 |                         help='evaluate specified model on lfw dataset')
213 |     parser.add_argument('--pairs', type=str,
214 |                         help='path of pairs.txt '
215 |                              '(default: $DATASET_DIR/pairs.txt)')
216 |     parser.add_argument('--roc', type=str,
217 |                         help='path where roc.png is generated '
218 |                              '(default: $LOG_DIR/roc.png)')
219 |     parser.add_argument('--verify-model', type=str,
220 |                         help='verify that 2 face images belong to the same '
221 |                              'person; the param is the model to use')
222 |     parser.add_argument('--verify-images', type=str,
223 |                         help='verify that 2 face images belong to the same '
224 |                              'person; separate image paths with a comma')
225 | 
226 |     args = parser.parse_args()
227 |     main(args)
228 | 
--------------------------------------------------------------------------------
/metrics.py:
--------------------------------------------------------------------------------
1 | from sklearn.model_selection import KFold
2 | import numpy as np
3 | import torch
4 | 
5 | 
6 | def select_threshold(distances, matches, thresholds):
7 |     best_threshold_true_predicts = 0
8 |     best_threshold = 0
9 |     for threshold in thresholds:
10 |         true_predicts = torch.sum((
11 |             distances < threshold
12 |         ) == matches)
13 | 
14 |         if true_predicts > best_threshold_true_predicts:
15 |             best_threshold_true_predicts = true_predicts
16 |             best_threshold = threshold
17 | 
18 |     return best_threshold
19 | 
20 | 
21 | def compute_roc(distances, matches, thresholds, fold_size=10):
22 |     assert(len(distances) == len(matches))
23 | 
24 |     kf = KFold(n_splits=fold_size, shuffle=False)
25 | 
26 |     tpr = torch.zeros(fold_size, len(thresholds))
27 |     fpr = torch.zeros(fold_size, len(thresholds))
28 |     accuracy = torch.zeros(fold_size)
29 |     best_thresholds = []
30 | 
31 |     for fold_index, (training_indices, val_indices) \
32 |             in enumerate(kf.split(range(len(distances)))):
33 | 
34 |         training_distances = distances[training_indices]
35 |         training_matches = matches[training_indices]
36 | 
37 |         # 1. find the best threshold for this fold using training set
38 |         best_threshold_true_predicts = 0
39 |         for threshold_index, threshold in enumerate(thresholds):
40 |             true_predicts = torch.sum((
41 |                 training_distances < threshold
42 |             ) == training_matches)
43 | 
44 |             if true_predicts > best_threshold_true_predicts:
45 |                 best_threshold = threshold
46 |                 best_threshold_true_predicts = true_predicts
47 | 
48 |         # 2.
calculate tpr, fpr on validation set 49 | val_distances = distances[val_indices] 50 | val_matches = matches[val_indices] 51 | for threshold_index, threshold in enumerate(thresholds): 52 | predicts = val_distances < threshold 53 | 54 | tp = torch.sum(predicts & val_matches).item() 55 | fp = torch.sum(predicts & ~val_matches).item() 56 | tn = torch.sum(~predicts & ~val_matches).item() 57 | fn = torch.sum(~predicts & val_matches).item() 58 | 59 | tpr[fold_index][threshold_index] = float(tp) / (tp + fn) 60 | fpr[fold_index][threshold_index] = float(fp) / (fp + tn) 61 | 62 | best_thresholds.append(best_threshold) 63 | accuracy[fold_index] = best_threshold_true_predicts.item() / float( 64 | len(training_indices)) 65 | 66 | # average fold 67 | tpr = torch.mean(tpr, dim=0).numpy() 68 | fpr = torch.mean(fpr, dim=0).numpy() 69 | accuracy = torch.mean(accuracy, dim=0).item() 70 | 71 | return tpr, fpr, accuracy, best_thresholds 72 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from .resnet import Resnet18FaceModel, Resnet50FaceModel 2 | -------------------------------------------------------------------------------- /models/base.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | from device import device 5 | 6 | 7 | class FaceModel(nn.Module): 8 | 9 | def __init__(self, num_classes, feature_dim): 10 | super().__init__() 11 | self.num_classes = num_classes 12 | self.feature_dim = feature_dim 13 | 14 | if num_classes: 15 | self.register_buffer('centers', ( 16 | torch.rand(num_classes, feature_dim).to(device) - 0.5) * 2) 17 | self.classifier = nn.Linear(self.feature_dim, num_classes) 18 | -------------------------------------------------------------------------------- /models/resnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torchvision.models import resnet18, resnet50 4 | 5 | from .base import FaceModel 6 | from device import device 7 | 8 | 9 | class ResnetFaceModel(FaceModel): 10 | 11 | IMAGE_SHAPE = (96, 128) 12 | 13 | def __init__(self, num_classes, feature_dim): 14 | super().__init__(num_classes, feature_dim) 15 | 16 | self.extract_feature = nn.Linear( 17 | self.feature_dim*4*3, self.feature_dim) 18 | self.num_classes = num_classes 19 | if self.num_classes: 20 | self.classifier = nn.Linear(self.feature_dim, num_classes) 21 | 22 | def forward(self, x): 23 | x = self.base.conv1(x) 24 | x = self.base.bn1(x) 25 | x = self.base.relu(x) 26 | x = self.base.maxpool(x) 27 | x = self.base.layer1(x) 28 | x = self.base.layer2(x) 29 | x = self.base.layer3(x) 30 | x = self.base.layer4(x) 31 | 32 | x = x.view(x.size(0), -1) 33 | feature = self.extract_feature(x) 34 | logits = self.classifier(feature) if self.num_classes else None 35 | 36 | feature_normed = feature.div( 37 | torch.norm(feature, p=2, dim=1, keepdim=True).expand_as(feature)) 38 | 39 | return logits, feature_normed 40 | 41 | 42 | class Resnet18FaceModel(ResnetFaceModel): 43 | 44 | FEATURE_DIM = 512 45 | 46 | def __init__(self, num_classes): 47 | super().__init__(num_classes, self.FEATURE_DIM) 48 | self.base = resnet18(pretrained=True) 49 | 50 | 51 | class Resnet50FaceModel(ResnetFaceModel): 52 | 53 | FEATURE_DIM = 2048 54 | 55 | def __init__(self, num_classes): 56 | super().__init__(num_classes, self.FEATURE_DIM) 57 | self.base = 
resnet50(pretrained=True) -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | import unittest 3 | 4 | def load_tests(loader, tests, pattern): 5 | suite = unittest.TestSuite() 6 | this_dir = os.path.dirname(__file__) 7 | package_tests = loader.discover(start_dir=this_dir, pattern='*_test.py') 8 | suite.addTests(package_tests) 9 | return suite 10 | 11 | if __name__ == '__main__': 12 | unittest.main() 13 | -------------------------------------------------------------------------------- /tests/center_loss_test.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import torch 4 | import numpy 5 | 6 | from loss import get_center_delta, compute_center_loss 7 | from device import device 8 | 9 | 10 | class CenterLossTest(unittest.TestCase): 11 | 12 | def setUp(self): 13 | # Mock features, centers and targets 14 | self.features = torch.tensor( 15 | ((1, 2, 3), (4, 5, 6), (7, 8, 9)) 16 | ).float().to(device) 17 | 18 | self.centers = torch.tensor( 19 | ((1, 1, 1), (2, 2, 2), (3, 3, 3), (5, 5, 5)) 20 | ).float().to(device) 21 | 22 | self.targets = torch.tensor((1, 3, 1)).to(device) 23 | self.alpha = 0.1 24 | 25 | def test_get_center_delta(self): 26 | result = get_center_delta( 27 | self.features, self.centers, self.targets, self.alpha) 28 | # size should match 29 | self.assertTrue(result.size() == self.centers.size()) 30 | # for class 1 31 | class1_result = -( 32 | (self.features[0] + self.features[2]) - 33 | 2 * self.centers[1]) / 3 * self.alpha 34 | 35 | self.assertEqual(3, torch.sum(result[1] == class1_result).item()) 36 | # for class 3 37 | class3_result = -(self.features[1] - self.centers[3]) / 2 * self.alpha 38 | self.assertEqual(3, torch.sum(result[3] == class3_result).item()) 39 | 40 | # others should all be zero 41 | sum_others = torch.sum(result[(0, 2), :]).item() 42 | self.assertEqual(0, sum_others) 43 | 44 | def test_compute_center_loss(self): 45 | 46 | loss = torch.mean( 47 | (self.features[(0, 2, 1), :] - self.centers[(1, 1, 3), :]) ** 2) 48 | 49 | self.assertEqual(loss, compute_center_loss( 50 | self.features, self.centers, self.targets)) 51 | 52 | if __name__ == '__main__': 53 | unittest.main() -------------------------------------------------------------------------------- /trainer.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | 5 | from device import device 6 | from loss import compute_center_loss, get_center_delta 7 | 8 | 9 | class Trainer(object): 10 | 11 | def __init__( 12 | self, optimizer, model, training_dataloader, 13 | validation_dataloader, log_dir=False, max_epoch=100, resume=False, 14 | persist_stride=5, lamda=0.03, alpha=0.5): 15 | 16 | self.log_dir = log_dir 17 | self.optimizer = optimizer 18 | self.model = model 19 | self.max_epoch = max_epoch 20 | self.resume = resume 21 | self.persist_stride = persist_stride 22 | self.training_dataloader = training_dataloader 23 | self.validation_dataloader = validation_dataloader 24 | self.training_losses = { 25 | 'center': [], 'cross_entropy': [], 26 | 'together': [], 'top3acc': [], 'top1acc': []} 27 | self.validation_losses = { 28 | 'center': [], 'cross_entropy': [], 29 | 'together': [], 'top3acc': [], 'top1acc': []} 30 | self.start_epoch = 1 31 | self.current_epoch = 1 32 | self.lamda = lamda 33 | self.alpha = alpha 34 | 35 | if not self.log_dir: 36 | 
self.log_dir = os.path.join(os.path.dirname(
37 |                 os.path.realpath(__file__)), 'logs')
38 |         if not os.path.isdir(self.log_dir):
39 |             os.mkdir(self.log_dir)
40 | 
41 |         if resume:
42 |             state_file = os.path.join(self.log_dir, 'models', resume)
43 |             if not os.path.isfile(state_file):
44 |                 raise RuntimeError(
45 |                     "resume file {} is not found".format(state_file))
46 |             print("loading checkpoint {}".format(state_file))
47 |             checkpoint = torch.load(state_file)
48 |             self.start_epoch = self.current_epoch = checkpoint['epoch'] + 1  # the saved epoch already finished
49 |             self.model.load_state_dict(checkpoint['state_dict'], strict=True)
50 |             self.optimizer.load_state_dict(checkpoint['optimizer'])
51 |             self.training_losses = checkpoint['training_losses']
52 |             self.validation_losses = checkpoint['validation_losses']
53 |             print("loaded checkpoint {} (epoch {})".format(
54 |                 state_file, self.current_epoch))
55 | 
56 |     def train(self):
57 |         for self.current_epoch in range(self.start_epoch, self.max_epoch+1):
58 |             self.run_epoch(mode='train')
59 |             self.run_epoch(mode='validate')
60 |             if not (self.current_epoch % self.persist_stride):
61 |                 self.persist()
62 | 
63 |     def run_epoch(self, mode):
64 |         if mode == 'train':
65 |             dataloader = self.training_dataloader
66 |             loss_recorder = self.training_losses
67 |             self.model.train()
68 |         else:
69 |             dataloader = self.validation_dataloader
70 |             loss_recorder = self.validation_losses
71 |             self.model.eval()
72 | 
73 |         total_cross_entropy_loss = 0
74 |         total_center_loss = 0
75 |         total_loss = 0
76 |         total_top1_matches = 0
77 |         total_top3_matches = 0
78 |         batch = 0
79 | 
80 |         with torch.set_grad_enabled(mode == 'train'):
81 |             for images, targets, names in dataloader:
82 |                 batch += 1
83 |                 targets = targets.to(device)
84 |                 images = images.to(device)
85 |                 centers = self.model.centers
86 | 
87 |                 logits, features = self.model(images)
88 | 
89 |                 cross_entropy_loss = torch.nn.functional.cross_entropy(
90 |                     logits, targets)
91 |                 center_loss = compute_center_loss(features, centers, targets)
92 |                 loss = self.lamda * center_loss + cross_entropy_loss
93 | 
94 |                 print("[{}:{}] cross entropy loss: {:.8f} - center loss: "
95 |                       "{:.8f} - total weighted loss: {:.8f}".format(
96 |                           mode, self.current_epoch,
97 |                           cross_entropy_loss.item(),
98 |                           center_loss.item(), loss.item()))
99 | 
100 |                 total_cross_entropy_loss += cross_entropy_loss.item()  # accumulate floats so graphs are freed
101 |                 total_center_loss += center_loss.item()
102 |                 total_loss += loss.item()
103 | 
104 |                 if mode == 'train':
105 |                     self.optimizer.zero_grad()
106 |                     loss.backward()
107 |                     self.optimizer.step()
108 | 
109 |                     # detach features from autograd here, otherwise the graph
110 |                     # is kept alive and memory leaks while updating the centers
111 |                     center_deltas = get_center_delta(
112 |                         features.data, centers, targets, self.alpha)
113 |                     self.model.centers = centers - center_deltas
114 | 
115 |                 # compute acc here
116 |                 total_top1_matches += self._get_matches(targets, logits, 1)
117 |                 total_top3_matches += self._get_matches(targets, logits, 3)
118 | 
119 |         center_loss = total_center_loss / batch
120 |         cross_entropy_loss = total_cross_entropy_loss / batch
121 |         loss = self.lamda * center_loss + cross_entropy_loss
122 |         top1_acc = total_top1_matches / len(dataloader.dataset)
123 |         top3_acc = total_top3_matches / len(dataloader.dataset)
124 | 
125 |         loss_recorder['center'].append(center_loss)
126 |         loss_recorder['cross_entropy'].append(cross_entropy_loss)
127 |         loss_recorder['together'].append(total_loss/batch)
128 |         loss_recorder['top1acc'].append(top1_acc)
129 |         loss_recorder['top3acc'].append(top3_acc)
130 | 
131 |         print(
132 |             "[{}:{}] finished. cross entropy loss: {:.8f} - "
133 |             "center loss: {:.8f} - together: {:.8f} - "
134 |             "top1 acc: {:.4f} % - top3 acc: {:.4f} %".format(
135 |                 mode, self.current_epoch, cross_entropy_loss,
136 |                 center_loss, loss,
137 |                 top1_acc*100, top3_acc*100))
138 | 
139 |     def _get_matches(self, targets, logits, n=1):
140 |         _, preds = logits.topk(n, dim=1)
141 |         targets_repeated = targets.view(-1, 1).repeat(1, n)
142 |         matches = torch.sum(preds == targets_repeated, dim=1) \
143 |             .nonzero().size()[0]
144 |         return matches
145 | 
146 |     def persist(self, is_best=False):
147 |         model_dir = os.path.join(self.log_dir, 'models')
148 |         if not os.path.isdir(model_dir):
149 |             os.mkdir(model_dir)
150 |         file_name = (
151 |             "epoch_{}_best.pth.tar" if is_best else "epoch_{}.pth.tar") \
152 |             .format(self.current_epoch)
153 | 
154 |         state = {
155 |             'epoch': self.current_epoch,
156 |             'state_dict': self.model.state_dict(),
157 |             'optimizer': self.optimizer.state_dict(),
158 |             'training_losses': self.training_losses,
159 |             'validation_losses': self.validation_losses
160 |         }
161 |         state_path = os.path.join(model_dir, file_name)
162 |         torch.save(state, state_path)
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | from math import ceil
3 | 
4 | from tqdm import tqdm
5 | import numpy as np
6 | import cv2
7 | import requests
8 | import matplotlib.pyplot as plt
9 | 
10 | 
11 | def download(dir, url, dist=None):
12 |     dist = dist if dist else url.split('/')[-1]
13 |     print('Downloading {} to {} from {}'.format(dist, dir, url))
14 |     download_path = os.path.join(dir, dist)
15 |     if os.path.isfile(download_path):
16 |         print('File {} already downloaded'.format(download_path))
17 |         return download_path
18 |     r = requests.get(url, stream=True)
19 |     total_size = int(r.headers.get('content-length', 0))
20 |     block_size = 1024 * 1024
21 | 
22 |     with open(download_path, 'wb') as f:
23 |         for data in tqdm(
24 |                 r.iter_content(block_size),
25 |                 total=ceil(total_size / block_size),  # true division; // would floor before ceil
26 |                 unit='MB', unit_scale=True):
27 |             f.write(data)
28 |     print('Downloaded {}'.format(dist))
29 |     return download_path
30 | 
31 | 
32 | def image_loader(image_path):
33 |     return cv2.imread(image_path)  # note: OpenCV loads images in BGR channel order
34 | 
35 | 
36 | def generate_roc_curve(fpr, tpr, path):
37 |     assert len(fpr) == len(tpr)
38 | 
39 |     fig = plt.figure()
40 |     plt.xlabel('FPR')
41 |     plt.ylabel('TPR')
42 |     plt.plot(fpr, tpr)
43 |     fig.savefig(path, dpi=fig.dpi)
--------------------------------------------------------------------------------
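For a quick sanity check of the evaluation utilities above, here is a minimal sketch that runs `compute_roc` on synthetic pair distances and plots the ROC curve. It assumes the PyTorch 0.4 environment pinned in the README (where `~` on a ByteTensor acts as a logical not); the toy distances and the `roc_toy.png` output path are made up for illustration:

```python
import torch
import numpy as np

from metrics import compute_roc
from utils import generate_roc_curve

# 20 synthetic pairs, interleaved so every sequential KFold fold holds one
# matched and one mismatched pair: matched pairs sit at distance 0.8,
# mismatched pairs at 1.3
distances = torch.tensor([0.8, 1.3] * 10)
matches = torch.tensor([1, 0] * 10, dtype=torch.uint8)
thresholds = np.arange(0, 4, 0.1)

tpr, fpr, accuracy, best_thresholds = compute_roc(distances, matches, thresholds)
generate_roc_curve(fpr, tpr, 'roc_toy.png')
print(accuracy)  # 1.0: the toy data is perfectly separable
```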