├── utils ├── __init__.py └── benchmark_helpers.py ├── fer2013 ├── __init__.py ├── fer.py └── fer_loader.py ├── imagenet ├── __init__.py ├── imagenet.py.bak └── evaluation.py ├── .gitignore ├── LICENSE.md ├── README.md ├── run_fer_benchmarks.py ├── run_imagenet_benchmarks.py ├── lfw_eval.py └── matlab_cp2tform.py /utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /fer2013/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /imagenet/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | __pycache__ 3 | .nfs* 4 | scratch 5 | res_cache 6 | -------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Samuel Albanie 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 6 | 7 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 8 | 9 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 10 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ### pytorch-benchmark 2 | 3 | Some scripts for validating models on common benchmarks. Assumes at least Python3 and PyTorch 4.0. 4 | 5 | 6 | ### Supported datasets: 7 | 8 | * **ImageNet** (this is essentially just a cut-down version of the [official example](https://github.com/pytorch/examples/tree/master/imagenet)) 9 | * **Fer2013** - A dataset of greyscale faces labelled with emotions. 10 | 11 | 12 | 13 | ### References 14 | 15 | **ImageNet**: [paper](https://arxiv.org/abs/1409.0575) 16 | 17 | ``` 18 | @article{ILSVRC15, 19 | Author = {Olga Russakovsky and Jia Deng and Hao Su and Jonathan Krause and Sanjeev Satheesh and Sean Ma and Zhiheng Huang and Andrej Karpathy and Aditya Khosla and Michael Bernstein and Alexander C. 
Berg and Li Fei-Fei}, 20 | Title = {{ImageNet Large Scale Visual Recognition Challenge}}, 21 | Year = {2015}, 22 | journal = {International Journal of Computer Vision (IJCV)}, 23 | doi = {10.1007/s11263-015-0816-y}, 24 | volume={115}, 25 | number={3}, 26 | pages={211-252} 27 | } 28 | ``` 29 | 30 | **FER2013**: [paper](https://arxiv.org/abs/1307.0414) 31 | 32 | ``` 33 | @inproceedings{goodfellow2013challenges, 34 | title={Challenges in representation learning: A report on three machine learning contests}, 35 | author={Goodfellow, Ian J and Erhan, Dumitru and Carrier, Pierre Luc and Courville, Aaron and Mirza, Mehdi and Hamner, Ben and Cukierski, Will and Tang, Yichuan and Thaler, David and Lee, Dong-Hyun and others}, 36 | booktitle={International Conference on Neural Information Processing}, 37 | pages={117--124}, 38 | year={2013}, 39 | organization={Springer} 40 | } 41 | ``` 42 | 43 | -------------------------------------------------------------------------------- /utils/benchmark_helpers.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """Utilties shared among the benchmarking protocols 3 | """ 4 | import os 5 | import sys 6 | import six 7 | 8 | import torchvision.transforms as transforms 9 | 10 | 11 | def compose_transforms(meta, resize=256, center_crop=True, 12 | override_meta_imsize=False): 13 | """Compose preprocessing transforms for model 14 | 15 | The imported models use a range of different preprocessing options, 16 | depending on how they were originally trained. Models trained in MatConvNet 17 | typically require input images that have been scaled to [0,255], rather 18 | than the [0,1] range favoured by PyTorch. 19 | 20 | Args: 21 | meta (dict): model preprocessing requirements 22 | resize (int) [256]: resize the input image to this size 23 | center_crop (bool) [True]: whether to center crop the image 24 | override_meta_imsize (bool) [False]: if true, use the value of `resize` 25 | to select the image input size, rather than the properties contained 26 | in meta (this option only applies when center cropping is not used. 
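    Example (a minimal sketch; the meta values below are illustrative
        placeholders, the real values are attached to each imported model):

            meta = {'mean': [129.2, 104.8, 93.6],   # per-channel means in [0, 255]
                    'std': [1, 1, 1],               # triggers the x255 rescaling below
                    'imageSize': [224, 224, 3]}
            transform = compose_transforms(meta, center_crop=False)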
27 | 28 | Return: 29 | (transforms.Compose): Composition of preprocessing transforms 30 | """ 31 | normalize = transforms.Normalize(mean=meta['mean'], std=meta['std']) 32 | im_size = meta['imageSize'] 33 | assert im_size[0] == im_size[1], 'expected square image size' 34 | if center_crop: 35 | transform_list = [transforms.Resize(resize), 36 | transforms.CenterCrop(size=(im_size[0], im_size[1]))] 37 | else: 38 | if override_meta_imsize: 39 | im_size = (resize, resize) 40 | transform_list = [transforms.Resize(size=(im_size[0], im_size[1]))] 41 | transform_list += [transforms.ToTensor()] 42 | if meta['std'] == [1, 1, 1]: # common amongst mcn models 43 | transform_list += [lambda x: x * 255.0] 44 | transform_list.append(normalize) 45 | return transforms.Compose(transform_list) 46 | 47 | 48 | def load_module_2or3(model_name, model_def_path): 49 | """Load model definition module in a manner that is compatible with 50 | both Python2 and Python3 51 | 52 | Args: 53 | model_name: The name of the model to be loaded 54 | model_def_path: The filepath of the module containing the definition 55 | 56 | Return: 57 | The loaded python module.""" 58 | if six.PY3: 59 | import importlib.util 60 | spec = importlib.util.spec_from_file_location(model_name, model_def_path) 61 | mod = importlib.util.module_from_spec(spec) 62 | spec.loader.exec_module(mod) 63 | else: 64 | import importlib 65 | dirname = os.path.dirname(model_def_path) 66 | sys.path.insert(0, dirname) 67 | module_name = os.path.splitext(os.path.basename(model_def_path))[0] 68 | mod = importlib.import_module(module_name) 69 | return mod 70 | -------------------------------------------------------------------------------- /run_fer_benchmarks.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """This module evaluates imported PyTorch models on fer2013 3 | """ 4 | 5 | import os 6 | import argparse 7 | from os.path import join as pjoin 8 | from fer2013.fer import fer2013_benchmark 9 | from utils.benchmark_helpers import load_module_2or3 10 | 11 | MODEL_DIR = os.path.expanduser('~/data/models/pytorch/mcn_imports') 12 | FER_DIR = os.path.expanduser('~/data/datasets/fer2013+') 13 | CACHE_DIR = 'res_cache/fer2013+' 14 | 15 | def load_model(model_name): 16 | """Load imoprted PyTorch model by name 17 | 18 | Args: 19 | model_name (str): the name of the model to be loaded 20 | 21 | Return: 22 | nn.Module: the loaded network 23 | """ 24 | model_def_path = pjoin(MODEL_DIR, model_name + '.py') 25 | weights_path = pjoin(MODEL_DIR, model_name + '.pth') 26 | mod = load_module_2or3(model_name, model_def_path) 27 | func = getattr(mod, model_name) 28 | net = func(weights_path=weights_path) 29 | return net 30 | 31 | def run_benchmarks(gpus, refresh, fer_plus): 32 | """Run bencmarks for imported models 33 | 34 | Args: 35 | gpus (str): comma separated gpu device identifiers 36 | refresh (bool): whether to overwrite the results of existing runs 37 | fer_plus (bool): whether to evaluate on the ferplus benchmark, 38 | rather than the standard fer benchmark. 39 | """ 40 | 41 | # Select models (and their batch sizes) to include in the benchmark. 
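    # Each entry pairs a model definition name (which load_model resolves to a
    # <name>.py / <name>.pth pair under MODEL_DIR) with its evaluation batch size.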
42 | if fer_plus: 43 | model_list = [ 44 | ('resnet50_ferplus_dag', 32), 45 | ('senet50_ferplus_dag', 32), 46 | ] 47 | else: 48 | model_list = [ 49 | ('alexnet_face_fer_bn_dag', 32), 50 | ('vgg_m_face_bn_fer_dag', 32), 51 | ('vgg_vd_face_fer_dag', 32), 52 | ] 53 | 54 | if not os.path.exists(CACHE_DIR): 55 | os.makedirs(CACHE_DIR) 56 | os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID' 57 | os.environ['CUDA_VISIBLE_DEVICES'] = str(gpus) 58 | 59 | opts = {'data_dir': FER_DIR, 'refresh_cache': refresh} 60 | 61 | for model_name, batch_size in model_list: 62 | cache_name = model_name 63 | if fer_plus: 64 | cache_name = cache_name + 'fer_plus' 65 | opts['res_cache'] = '{}/{}.pth'.format(CACHE_DIR, cache_name) 66 | opts['fer_plus'] = fer_plus 67 | model = load_model(model_name) 68 | print('benchmarking {}'.format(model_name)) 69 | fer2013_benchmark(model, batch_size=batch_size, **opts) 70 | 71 | parser = argparse.ArgumentParser(description='Run PyTorch benchmarks.') 72 | parser.add_argument('--gpus', nargs='?', dest='gpus', 73 | help='select gpu device id') 74 | parser.add_argument('--refresh', dest='refresh', action='store_true', 75 | help='refresh results cache') 76 | parser.add_argument('--ferplus', dest='ferplus', action='store_true', 77 | help='run ferplus (rather than fer) benchmarks') 78 | parser.set_defaults(gpus=None) 79 | parser.set_defaults(refresh=False) 80 | parsed = parser.parse_args() 81 | 82 | if __name__ == '__main__': 83 | run_benchmarks(parsed.gpus, parsed.refresh, parsed.ferplus) 84 | -------------------------------------------------------------------------------- /fer2013/fer.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """Fer2013 benchmark 3 | 4 | The module evaluates the performance of a pytorch model on the FER2013 5 | benchmark. 
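Example (a minimal usage sketch; the paths are hypothetical and `model` is an
imported network exposing a `meta` dict, as returned by the loaders in
run_fer_benchmarks.py):

    import os
    from fer2013.fer import fer2013_benchmark

    fer2013_benchmark(model,
                      data_dir=os.path.expanduser('~/data/datasets/fer2013+'),
                      res_cache='res_cache/fer2013+/demo.pth',
                      refresh_cache=True,
                      batch_size=32)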
6 | """ 7 | 8 | from __future__ import division 9 | 10 | import os 11 | import time 12 | 13 | import torch 14 | import numpy as np 15 | import torch.utils.data 16 | import torch.backends.cudnn as cudnn 17 | from fer2013.fer_loader import Fer2013Dataset, Fer2013PlusDataset 18 | from utils.benchmark_helpers import compose_transforms 19 | 20 | def fer2013_benchmark(model, data_dir, res_cache, refresh_cache, 21 | batch_size=256, num_workers=2, fer_plus=False): 22 | if not refresh_cache: # load result from cache, if available 23 | if os.path.isfile(res_cache): 24 | res = torch.load(res_cache) 25 | prec1_val, prec1_test = res['prec1_val'], res['prec1_test'] 26 | print("=> loaded results from '{}'".format(res_cache)) 27 | info = (prec1_val, prec1_test, res['speed']) 28 | msg = 'val acc: {:.2f}, test acc: {:.2f}, Speed: {:.1f}Hz' 29 | print(msg.format(*info)) 30 | return 31 | 32 | meta = model.meta 33 | cudnn.benchmark = True 34 | model = torch.nn.DataParallel(model).cuda() 35 | preproc_transforms = compose_transforms(meta, center_crop=False) 36 | if fer_plus: 37 | dataset = Fer2013PlusDataset 38 | else: 39 | dataset = Fer2013Dataset 40 | speeds = [] 41 | res = {} 42 | for mode in 'val', 'test': 43 | loader = torch.utils.data.DataLoader( 44 | dataset(data_dir, mode=mode, transform=preproc_transforms), 45 | batch_size=batch_size, shuffle=False, 46 | num_workers=num_workers, pin_memory=True) 47 | prec1, speed = validate(loader, model, mode) 48 | res['prec1_{}'.format(mode)] = prec1 49 | speeds.append(speed) 50 | res['speed'] = np.mean(speed) 51 | torch.save(res, res_cache) 52 | 53 | def validate(val_loader, model, mode): 54 | model.eval() 55 | top1 = AverageMeter() 56 | speed = WarmupAverageMeter() 57 | end = time.time() 58 | with torch.no_grad(): 59 | for ii, (ims, target) in enumerate(val_loader): 60 | target = target.cuda(async=True) 61 | output = model(ims) # compute output 62 | prec1, = accuracy(output.data, target, topk=(1,)) 63 | top1.update(prec1[0], ims.size(0)) 64 | speed.update(time.time() - end, ims.size(0)) 65 | end = time.time() 66 | if ii % 10 == 0: 67 | msg = ('{0}: [{1}/{2}]\tSpeed {speed.current:.1f}Hz\t' 68 | '({speed.avg:.1f})Hz\tPrec@1 {top1.avg:.3f}') 69 | print(msg.format(mode, ii, len(val_loader), 70 | speed=speed, top1=top1)) 71 | print(' * Accuracy {0:.3f}'.format(top1.avg)) 72 | return top1.avg, speed.avg 73 | 74 | class WarmupAverageMeter(object): 75 | """Computes and stores the average and current value, after a fixed 76 | warmup period (useful for approximate benchmarking) 77 | 78 | Args: 79 | warmup (int) [3]: The number of updates to be ignored before the 80 | average starts to be computed. 
81 | """ 82 | def __init__(self, warmup=3): 83 | self.reset() 84 | self.warmup = warmup 85 | 86 | def reset(self): 87 | self.avg = 0 88 | self.current = 0 89 | self.delta_sum = 0 90 | self.count = 0 91 | self.warmup_count = 0 92 | 93 | def update(self, delta, n): 94 | self.warmup_count = self.warmup_count + 1 95 | if self.warmup_count >= self.warmup: 96 | self.current = n / delta 97 | self.delta_sum += delta 98 | self.count += n 99 | self.avg = self.count / self.delta_sum 100 | 101 | class AverageMeter(object): 102 | """Computes and stores the average and current value""" 103 | def __init__(self): 104 | self.reset() 105 | 106 | def reset(self): 107 | self.val = 0 108 | self.avg = 0 109 | self.sum = 0 110 | self.count = 0 111 | 112 | def update(self, val, n=1): 113 | self.val = val 114 | self.sum += val * n 115 | self.count += n 116 | self.avg = self.sum / self.count 117 | 118 | def accuracy(output, target, topk=(1,)): 119 | """Computes the precision@k for the specified values of k""" 120 | maxk = max(topk) 121 | batch_size = target.size(0) 122 | 123 | _, pred = output.topk(maxk, 1, True, True) 124 | pred = pred.t() 125 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 126 | 127 | res = [] 128 | for k in topk: 129 | correct_k = correct[:k].view(-1).float().sum(0, keepdim=True) 130 | res.append(correct_k.mul_(100.0 / batch_size)) 131 | return res 132 | -------------------------------------------------------------------------------- /imagenet/imagenet.py.bak: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """Imagenet validation set benchmark 3 | 4 | The module evaluates the performance of a pytorch model on the ILSVRC 2012 5 | validation set. 6 | 7 | Based on PyTorch imagenet example: 8 | https://github.com/pytorch/examples/tree/master/imagenet 9 | """ 10 | 11 | from __future__ import division 12 | 13 | import os 14 | import time 15 | 16 | from PIL import ImageFile 17 | import torch 18 | import torch.nn.parallel 19 | import torch.utils.data 20 | import torch.backends.cudnn as cudnn 21 | import torchvision.datasets as datasets 22 | from utils.benchmark_helpers import compose_transforms 23 | 24 | ImageFile.LOAD_TRUNCATED_IMAGES = True 25 | 26 | def imagenet_benchmark(model, data_dir, res_cache, refresh_cache, 27 | batch_size=256, num_workers=20, 28 | remove_blacklist=False, center_crop=True): 29 | if not refresh_cache: # load result from cache, if available 30 | if os.path.isfile(res_cache): 31 | res = torch.load(res_cache) 32 | prec1, prec5, speed = res['prec1'], res['prec5'], res['speed'] 33 | print("=> loaded results from '{}'".format(res_cache)) 34 | info = (100 - prec1, 100 - prec5, speed) 35 | msg = 'Top 1 err: {:.2f}, Top 5 err: {:.2f}, Speed: {:.1f}Hz' 36 | print(msg.format(*info)) 37 | return 38 | 39 | meta = model.meta 40 | cudnn.benchmark = True 41 | model = torch.nn.DataParallel(model).cuda() 42 | if remove_blacklist: 43 | subset = 'val_blacklisted' 44 | else: 45 | subset = 'val' 46 | valdir = os.path.join(data_dir, subset) 47 | preproc_transforms = compose_transforms(meta, center_crop=center_crop) 48 | val_loader = torch.utils.data.DataLoader( 49 | datasets.ImageFolder(valdir, preproc_transforms), 50 | batch_size=batch_size, shuffle=False, 51 | num_workers=num_workers, pin_memory=True) 52 | prec1, prec5, speed = validate(val_loader, model) 53 | torch.save({'prec1': prec1, 'prec5': prec5, 'speed': speed}, res_cache) 54 | 55 | def validate(val_loader, model): 56 | model.eval() 57 | top1 = AverageMeter() 58 | top5 
= AverageMeter() 59 | speed = WarmupAverageMeter() 60 | end = time.time() 61 | with torch.no_grad(): 62 | for ii, (ims, target) in enumerate(val_loader): 63 | target = target.cuda(async=True) 64 | # ims_var = torch.autograd.Variable(ims, volatile=True) 65 | output = model(ims) # compute output 66 | prec1, prec5 = accuracy(output.data, target, topk=(1, 5)) 67 | top1.update(prec1[0], ims.size(0)) 68 | top5.update(prec5[0], ims.size(0)) 69 | speed.update(time.time() - end, ims.size(0)) 70 | end = time.time() 71 | if ii % 10 == 0: 72 | msg = ('Test: [{0}/{1}]\tSpeed {speed.current:.1f}Hz\t' 73 | '({speed.avg:.1f})Hz\tPrec@1 {top1.avg:.3f} ' 74 | '{top5.avg:.3f}') 75 | print(msg.format(ii, len(val_loader), speed=speed, 76 | top1=top1, top5=top5)) 77 | top1_err, top5_err = 100 - top1.avg, 100 - top5.avg 78 | print(' * Err@1 {0:.3f} Err@5 {1:.3f}'.format(top1_err, top5_err)) 79 | 80 | return top1.avg, top5.avg, speed.avg 81 | 82 | class WarmupAverageMeter(object): 83 | """Computes and stores the average and current value, after a fixed 84 | warmup period (useful for approximate benchmarking) 85 | 86 | Args: 87 | warmup (int) [3]: The number of updates to be ignored before the 88 | average starts to be computed. 89 | """ 90 | def __init__(self, warmup=3): 91 | self.reset() 92 | self.warmup = warmup 93 | 94 | def reset(self): 95 | self.avg = 0 96 | self.current = 0 97 | self.delta_sum = 0 98 | self.count = 0 99 | self.warmup_count = 0 100 | 101 | def update(self, delta, n): 102 | self.warmup_count = self.warmup_count + 1 103 | if self.warmup_count >= self.warmup: 104 | self.current = n / delta 105 | self.delta_sum += delta 106 | self.count += n 107 | self.avg = self.count / self.delta_sum 108 | 109 | class AverageMeter(object): 110 | """Computes and stores the average and current value""" 111 | def __init__(self): 112 | self.reset() 113 | 114 | def reset(self): 115 | self.val = 0 116 | self.avg = 0 117 | self.sum = 0 118 | self.count = 0 119 | 120 | def update(self, val, n=1): 121 | self.val = val 122 | self.sum += val * n 123 | self.count += n 124 | self.avg = self.sum / self.count 125 | 126 | def accuracy(output, target, topk=(1,)): 127 | """Computes the precision@k for the specified values of k""" 128 | maxk = max(topk) 129 | batch_size = target.size(0) 130 | 131 | _, pred = output.topk(maxk, 1, True, True) 132 | pred = pred.t() 133 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 134 | 135 | res = [] 136 | for k in topk: 137 | correct_k = correct[:k].view(-1).float().sum(0, keepdim=True) 138 | res.append(correct_k.mul_(100.0 / batch_size)) 139 | return res 140 | -------------------------------------------------------------------------------- /imagenet/evaluation.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """Imagenet validation set benchmark 3 | 4 | The module evaluates the performance of a pytorch model on the ILSVRC 2012 5 | validation set. 
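The validation images are read with torchvision's ImageFolder, so `data_dir` is
expected to contain a `val` (or, when the 2014 blacklist is removed, a
`val_blacklisted`) subdirectory laid out with one folder per synset, e.g.
val/<wnid>/<image>.JPEG.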
6 | 7 | Based on PyTorch imagenet example: 8 | https://github.com/pytorch/examples/tree/master/imagenet 9 | """ 10 | 11 | from __future__ import division 12 | 13 | import os 14 | import time 15 | 16 | from PIL import ImageFile 17 | import torch 18 | import torch.nn.parallel 19 | import torch.utils.data 20 | import torch.backends.cudnn as cudnn 21 | import torchvision.datasets as datasets 22 | from utils.benchmark_helpers import compose_transforms 23 | 24 | ImageFile.LOAD_TRUNCATED_IMAGES = True 25 | 26 | 27 | def imagenet_benchmark(model, data_dir, res_cache, refresh_cache, batch_size=256, 28 | num_workers=20, remove_blacklist=False, center_crop=True, 29 | override_meta_imsize=False): 30 | if not refresh_cache: # load result from cache, if available 31 | if os.path.isfile(res_cache): 32 | res = torch.load(res_cache) 33 | prec1, prec5, speed = res['prec1'], res['prec5'], res['speed'] 34 | print("=> loaded results from '{}'".format(res_cache)) 35 | info = (100 - prec1, 100 - prec5, speed) 36 | msg = 'Top 1 err: {:.2f}, Top 5 err: {:.2f}, Speed: {:.1f}Hz' 37 | print(msg.format(*info)) 38 | return 39 | 40 | meta = model.meta 41 | cudnn.benchmark = True 42 | 43 | if override_meta_imsize: # NOTE REMOVE THIS LATER! 44 | import torch.nn as nn 45 | model.features_8 = nn.AdaptiveAvgPool2d(1) 46 | 47 | model = torch.nn.DataParallel(model).cuda() 48 | if remove_blacklist: 49 | subset = 'val_blacklisted' 50 | else: 51 | subset = 'val' 52 | valdir = os.path.join(data_dir, subset) 53 | preproc_transforms = compose_transforms(meta, resize=256, center_crop=center_crop, 54 | override_meta_imsize=override_meta_imsize) 55 | val_loader = torch.utils.data.DataLoader( 56 | datasets.ImageFolder(valdir, preproc_transforms), batch_size=batch_size, 57 | shuffle=False, num_workers=num_workers, pin_memory=True) 58 | prec1, prec5, speed = validate(val_loader, model) 59 | torch.save({'prec1': prec1, 'prec5': prec5, 'speed': speed}, res_cache) 60 | 61 | 62 | def validate(val_loader, model): 63 | model.eval() 64 | top1 = AverageMeter() 65 | top5 = AverageMeter() 66 | speed = WarmupAverageMeter() 67 | end = time.time() 68 | with torch.no_grad(): 69 | for ii, (ims, target) in enumerate(val_loader): 70 | target = target.cuda() 71 | # ims_var = torch.autograd.Variable(ims, volatile=True) 72 | output = model(ims) # compute output 73 | prec1, prec5 = accuracy(output.data, target, topk=(1, 5)) 74 | top1.update(prec1[0], ims.size(0)) 75 | top5.update(prec5[0], ims.size(0)) 76 | speed.update(time.time() - end, ims.size(0)) 77 | end = time.time() 78 | if ii % 10 == 0: 79 | msg = ('Test: [{0}/{1}]\tSpeed {speed.current:.1f}Hz\t' 80 | '({speed.avg:.1f})Hz\tPrec@1 {top1.avg:.3f} ' 81 | '{top5.avg:.3f}') 82 | print(msg.format(ii, len(val_loader), speed=speed, top1=top1, 83 | top5=top5)) 84 | top1_err, top5_err = 100 - top1.avg, 100 - top5.avg 85 | print(' * Err@1 {0:.3f} Err@5 {1:.3f}'.format(top1_err, top5_err)) 86 | 87 | return top1.avg, top5.avg, speed.avg 88 | 89 | 90 | class WarmupAverageMeter(object): 91 | """Computes and stores the average and current value, after a fixed 92 | warmup period (useful for approximate benchmarking) 93 | 94 | Args: 95 | warmup (int) [3]: The number of updates to be ignored before the 96 | average starts to be computed. 
97 | """ 98 | 99 | def __init__(self, warmup=3): 100 | self.reset() 101 | self.warmup = warmup 102 | 103 | def reset(self): 104 | self.avg = 0 105 | self.current = 0 106 | self.delta_sum = 0 107 | self.count = 0 108 | self.warmup_count = 0 109 | 110 | def update(self, delta, n): 111 | self.warmup_count = self.warmup_count + 1 112 | if self.warmup_count >= self.warmup: 113 | self.current = n / delta 114 | self.delta_sum += delta 115 | self.count += n 116 | self.avg = self.count / self.delta_sum 117 | 118 | 119 | class AverageMeter(object): 120 | """Computes and stores the average and current value""" 121 | 122 | def __init__(self): 123 | self.reset() 124 | 125 | def reset(self): 126 | self.val = 0 127 | self.avg = 0 128 | self.sum = 0 129 | self.count = 0 130 | 131 | def update(self, val, n=1): 132 | self.val = val 133 | self.sum += val * n 134 | self.count += n 135 | self.avg = self.sum / self.count 136 | 137 | 138 | def accuracy(output, target, topk=(1, )): 139 | """Computes the precision@k for the specified values of k""" 140 | maxk = max(topk) 141 | batch_size = target.size(0) 142 | 143 | _, pred = output.topk(maxk, 1, True, True) 144 | pred = pred.t() 145 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 146 | 147 | res = [] 148 | for k in topk: 149 | correct_k = correct[:k].view(-1).float().sum(0, keepdim=True) 150 | res.append(correct_k.mul_(100.0 / batch_size)) 151 | return res 152 | -------------------------------------------------------------------------------- /run_imagenet_benchmarks.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """This script evaluates imported PyTorch models on the 3 | ImageNet validation set 4 | 5 | e.g. 6 | python run_imagenet_benchmarks.py --model_subset pt_tpu --gpus 2 7 | """ 8 | 9 | import os 10 | import argparse 11 | from torchvision.models import densenet 12 | from imagenet.evaluation import imagenet_benchmark 13 | from pathlib import Path 14 | from utils.benchmark_helpers import load_module_2or3 15 | 16 | # directory containing imported pytorch models 17 | MODEL_DIR = os.path.expanduser('~/data/models/pytorch/mcn_imports/') 18 | 19 | # imagenet directory 20 | ILSVRC_DIR = os.path.expanduser('~/data/shared-datasets/ILSVRC2012-pytorch-val') 21 | 22 | # results cache directory 23 | CACHE_DIR = 'res_cache/imagenet' 24 | 25 | 26 | def load_torchvision_model(model_name): 27 | if 'densenet' in model_name: 28 | func = getattr(densenet, model_name) 29 | net = func(pretrained=True) 30 | net.meta = { 31 | 'mean': [0.485, 0.456, 0.406], 32 | 'std': [0.229, 0.224, 0.225], 33 | 'imageSize': [224, 224]} 34 | return net 35 | 36 | 37 | def load_tpu_converted_model(model_name, model_def_path, weights_path, **kwargs): 38 | mod = load_module_2or3(model_name, model_def_path) 39 | func = getattr(mod, model_name) 40 | mean, std = getattr(mod, "MEAN_RGB"), getattr(mod, "STDDEV_RGB") 41 | net = func(pretrained=weights_path) 42 | net.meta = { 43 | 'mean': mean, 44 | 'std': std, 45 | 'imageSize': [224, 224]} 46 | return net 47 | 48 | 49 | def load_model(model_name): 50 | """Load imoprted PyTorch model by name 51 | 52 | Args: 53 | model_name (str): the name of the model to be loaded 54 | 55 | Return: 56 | nn.Module: the loaded network 57 | """ 58 | if 'tv_' in model_name: 59 | import ipdb 60 | ipdb.set_trace() 61 | net = load_torchvision_model(model_name) 62 | else: 63 | model_def_path = os.path.join(MODEL_DIR, model_name + ".py") 64 | weights_path = os.path.join(MODEL_DIR, model_name + ".pth") 65 | mod 
= load_module_2or3(model_name, model_def_path) 66 | func = getattr(mod, model_name) 67 | net = func(weights_path=weights_path) 68 | return net 69 | 70 | 71 | def run_benchmarks(gpus, refresh, remove_blacklist, workers, no_center_crop, 72 | override_meta_imsize, model_subset, tpu_model_dir, tpu_weights_dir): 73 | """Run bencmarks for imported models 74 | 75 | Args: 76 | gpus (str): comma separated gpu device identifiers 77 | refresh (bool): whether to overwrite the results of existing runs 78 | remove_blacklist (bool): whether to remove images from the 2014 ILSVRC 79 | blacklist from the validation images used in the benchmark 80 | workers (int): the number of workers 81 | """ 82 | model_loader = load_model 83 | 84 | # Select models (and their batch sizes) to include in the benchmark. 85 | if model_subset == "all": 86 | raise NotImplementedError("TODO: update to use config dicts") 87 | model_list = [ 88 | ('alexnet_pt_mcn', 256), 89 | ('squeezenet1_0_pt_mcn', 128), 90 | ('squeezenet1_1_pt_mcn', 128), 91 | ('vgg11_pt_mcn', 128), 92 | ('vgg13_pt_mcn', 92), 93 | ('vgg16_pt_mcn', 32), 94 | ('vgg19_pt_mcn', 24), 95 | ('resnet18_pt_mcn', 50), 96 | ('resnet34_pt_mcn', 50), 97 | ('resnet50_pt_mcn', 32), 98 | ('resnet101_pt_mcn', 24), 99 | ('resnet152_pt_mcn', 20), 100 | ('inception_v3_pt_mcn', 64), 101 | ("densenet121_pt_mcn", 50), 102 | ("densenet161_pt_mcn", 32), 103 | ("densenet169_pt_mcn", 32), 104 | ("densenet201_pt_mcn", 32), 105 | ('imagenet_matconvnet_alex', 256), 106 | ('imagenet_matconvnet_vgg_f_dag', 128), 107 | ('imagenet_matconvnet_vgg_m_dag', 128), 108 | ('imagenet_matconvnet_vgg_verydeep_16_dag', 32), 109 | ] 110 | elif model_subset == "pt_mcn": 111 | model_list = [('resnet50_pt_mcn', 32)] 112 | elif model_subset == "pt_tpu": 113 | model_loader = load_tpu_converted_model 114 | model_list = [{ 115 | "model_name": 'resnet50', 116 | "cache_name": 'resnet50_from_tpu', 117 | "batch_size": 32, 118 | "model_def_path": Path(tpu_model_dir) / "resnet_models.py", 119 | "weights_path": Path(tpu_weights_dir) / "resnet50_ported.pth"}, 120 | ] 121 | 122 | if not os.path.exists(CACHE_DIR): 123 | os.makedirs(CACHE_DIR) 124 | os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID' 125 | os.environ['CUDA_VISIBLE_DEVICES'] = str(gpus) 126 | 127 | opts = { 128 | 'data_dir': ILSVRC_DIR, 129 | 'refresh_cache': refresh, 130 | 'remove_blacklist': remove_blacklist, 131 | 'num_workers': workers, 132 | 'center_crop': not no_center_crop, 133 | 'override_meta_imsize': override_meta_imsize} 134 | 135 | for model_config in model_list: 136 | model_name = model_config["model_name"] 137 | cache_name = '{}/{}'.format(CACHE_DIR, model_config["cache_name"]) 138 | if no_center_crop: 139 | cache_name = '{}-no-center-crop'.format(cache_name) 140 | opts['res_cache'] = '{}.pth'.format(cache_name) 141 | model = model_loader(**model_config) 142 | print('benchmarking {}'.format(model_name)) 143 | imagenet_benchmark(model, batch_size=model_config["batch_size"], **opts) 144 | 145 | 146 | parser = argparse.ArgumentParser(description='Run PyTorch benchmarks.') 147 | parser.add_argument('--gpus', nargs='?', dest='gpus', help='select gpu device id') 148 | parser.add_argument('--workers', type=int, default=4, dest='workers', 149 | help='select number of workers') 150 | parser.add_argument('--model_subset', type=str, default="all", help='eval subset') 151 | parser.add_argument('--refresh', dest='refresh', action='store_true', 152 | help='refresh results cache') 153 | parser.add_argument('--no_center_crop', dest='no_center_crop', 
action='store_true', 154 | help='prevent center cropping') 155 | parser.add_argument('--override_meta_imsize', dest='override_meta_imsize', 156 | action='store_true', help='allow arbitrary resizing of input image') 157 | parser.add_argument( 158 | '--remove-blacklist', dest='remove_blacklist', action='store_true', 159 | help=('evaluate on 2012 validation subset without including' 160 | 'the 2014 list of blacklisted images (only applies to' 161 | 'imagenet models)')) 162 | parser.add_argument( 163 | "--tpu_model_dir", 164 | default=Path.home() / "coding/libs/tf/tpu-fork/models/official/resnet/tf2pytorch", 165 | ) 166 | parser.add_argument( 167 | "--tpu_weights_dir", 168 | default=Path.home() / "data/models/tensorflow/tpu/conversion_dir", 169 | ) 170 | parser.set_defaults(gpus=None) 171 | parser.set_defaults(refresh=False) 172 | parser.set_defaults(remove_blacklist=False) 173 | parser.set_defaults(no_center_crop=False) 174 | parser.set_defaults(override_meta_imsize=False) 175 | args = parser.parse_args() 176 | 177 | if __name__ == '__main__': 178 | run_benchmarks( 179 | gpus=args.gpus, 180 | refresh=args.refresh, 181 | remove_blacklist=args.remove_blacklist, 182 | workers=args.workers, 183 | no_center_crop=args.no_center_crop, 184 | override_meta_imsize=args.override_meta_imsize, 185 | model_subset=args.model_subset, 186 | tpu_model_dir=args.tpu_model_dir, 187 | tpu_weights_dir=args.tpu_weights_dir, 188 | ) 189 | -------------------------------------------------------------------------------- /fer2013/fer_loader.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """Contains two data loaders. One is for the Fer 2013 emotion dataset 3 | described in the paper: 4 | 5 | Goodfellow, Ian J., et al. "Challenges in representation learning: 6 | A report on three machine learning contests." International Conference on 7 | Neural Information Processing. Springer, Berlin, Heidelberg, 2013. 8 | https://arxiv.org/abs/1307.0414 9 | 10 | The second is for the "Fer 2013 plus" dataset, described in the paper: 11 | 12 | Barsoum, Emad, Cha Zhang, Cristian Canton Ferrer, and Zhengyou Zhang. 13 | "Training deep networks for facial expression recognition with 14 | crowd-sourced label distribution." In Proceedings of the 18th ACM 15 | International Conference on Multimodal Interaction, pp. 279-283. ACM, 2016. 16 | https://arxiv.org/abs/1608.01041 17 | 18 | """ 19 | 20 | import os 21 | import csv 22 | import tqdm 23 | import torch 24 | import pickle 25 | import numpy as np 26 | from copy import deepcopy 27 | import PIL.Image 28 | 29 | from os.path import join as pjoin 30 | 31 | class Fer2013Dataset(torch.utils.data.Dataset): 32 | """Dataset class helper for the Fer2013 dataset. Converts the csv 33 | files used to distribute the dataset into a pickle format 34 | 35 | Args: 36 | data_dir (str): Directory where the original csv files distributed 37 | with the dataset are found. 38 | mode (str): The subset of the dataset to use 39 | transform (torch.transforms): a transformaton that can be applied 40 | to images on loading 41 | include_train (bool) [False]: whether to include the training set 42 | in the loader (it's not required for benchmarking purposes). 
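    Example (a minimal sketch; the directory is hypothetical and must contain
        the original fer2013.csv):

            dataset = Fer2013Dataset('/path/to/fer2013', mode='val')
            image, label = dataset[0]  # PIL image (no transform applied) and int label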
43 | """ 44 | def __init__(self, data_dir, mode='val', transform=None, 45 | include_train=False): 46 | self.data_dir = data_dir 47 | self.mode = mode 48 | self.include_train = include_train 49 | self._transform = transform 50 | self.pkl_path = pjoin(data_dir, 'pytorch', 'data.pkl') 51 | 52 | if not os.path.isfile(self.pkl_path): 53 | self.prepare_data() 54 | 55 | with open(self.pkl_path, 'rb') as f: 56 | self.data = pickle.load(f) 57 | 58 | def __getitem__(self, index): 59 | """Retreive the sample at the given index. 60 | 61 | Args: 62 | index (int): the index of the sample to be retrieved 63 | 64 | Returns: 65 | (torch.Tensor): the image 66 | (int): the label 67 | """ 68 | im_data = self.data['images'][self.mode][index].astype('uint8') 69 | image = PIL.Image.fromarray(im_data) 70 | label = self.data['labels'][self.mode][index] 71 | if self._transform is not None: 72 | image = self._transform(image) 73 | return image, label 74 | 75 | def prepare_data(self): 76 | """Transform raw data from csv format into a dict. 77 | 78 | Args: 79 | phase, str: 'train'/'val'/'test'. 80 | size, int. Size of the dataset. 81 | """ 82 | print('preparing data...') 83 | with open(pjoin(self.data_dir, 'fer2013.csv'), 'r') as f: 84 | reader = csv.reader(f, delimiter=',') 85 | next(reader) # skip header 86 | rows = [row for row in reader] 87 | 88 | train_ims, val_ims, test_ims = [], [], [] 89 | train_labels, val_labels, test_labels = [], [], [] 90 | for row in tqdm.tqdm(rows): 91 | subset = row[2] 92 | raw_im = np.array([int(x) for x in row[1].split(' ')]) 93 | im = np.repeat(raw_im.reshape(48,48)[:,:,np.newaxis], 3, axis=2) 94 | if subset == 'Training': 95 | train_labels.append(int(row[0])) 96 | train_ims.append(im) 97 | elif subset == 'PublicTest': 98 | val_labels.append(int(row[0])) 99 | val_ims.append(im) 100 | elif subset == 'PrivateTest': 101 | test_labels.append(int(row[0])) 102 | test_ims.append(im) 103 | else: 104 | raise ValueError('unrecognised subset: {}'.format(subset)) 105 | 106 | data = {'labels': {}, 'images': {}} 107 | data['labels']['val'] = np.array(val_labels) 108 | data['labels']['test'] = np.array(test_labels) 109 | 110 | data['images']['val'] = np.array(val_ims) 111 | data['images']['test'] = np.array(test_ims) 112 | 113 | if self.include_train: 114 | data['labels']['train'] = np.array(train_labels) 115 | data['images']['train'] = np.array(train_ims) 116 | 117 | for key in 'images', 'labels': 118 | assert len(data[key]['val']) == 3589, 'unexpected length' 119 | assert len(data[key]['test']) == 3589, 'unexpected length' 120 | if self.include_train: 121 | assert len(data[key]['train']) == 28709, 'unexpected length' 122 | 123 | if not os.path.exists(os.path.dirname(self.pkl_path)): 124 | os.makedirs(os.path.dirname(self.pkl_path)) 125 | 126 | with open(self.pkl_path, 'wb') as f: 127 | pickle.dump(data, f) 128 | 129 | def __len__(self): 130 | """Return the total number of images in the datset. 131 | 132 | Return: 133 | (int) the number of images. 134 | """ 135 | return self.data['labels'][self.mode].size 136 | 137 | class Fer2013PlusDataset(Fer2013Dataset): 138 | """Dataset class helper for the Fer2013plus dataset. 
Converts the csv 139 | files used to distribute the dataset into a pickle format 140 | """ 141 | 142 | def __init__(self, *args, **kwargs): 143 | super(Fer2013PlusDataset, self).__init__(*args, **kwargs) 144 | self.update_labels() 145 | 146 | def update_labels(self): 147 | """Update dataset to use FerPlus labels, rather Fer2013 dataset labels 148 | 149 | Aim to reproduce the Microsoft CNTK cleaning process. These are based 150 | on some heuristics about the level of ambiguity in the annotator labels 151 | that should be tolerated to ensure that the dataset is moderately 152 | clearn. We generate hard labels, rather than soft ones for evaluation. 153 | """ 154 | with open(pjoin(self.data_dir, 'fer2013new.csv'), 'r') as f: 155 | reader = csv.reader(f, delimiter=',') 156 | next(reader) # skip header 157 | rows = [row for row in reader] 158 | 159 | set_map = {'Training': 1, 'PublicTest': 2, 'PrivateTest': 3} 160 | sets = np.atleast_2d([set_map[x[0]] for x in rows]).T 161 | labels = [np.atleast_2d([int(x) for x in r[2:]]) for r in rows] 162 | labels = np.concatenate(labels, axis=0) 163 | orig_labels = deepcopy(labels) 164 | outliers = (labels <=1) 165 | labels[outliers] = 0 # drop outliers 166 | dropped = 1 - (labels.sum() / orig_labels.sum()) 167 | print('dropped {:.1f}%% of votes as outliers'.format(dropped * 100)) 168 | num_votes = np.sum(labels, 1) 169 | # following CNTK processing - there are three reasons to drop examples: 170 | # (1) If the majority votes for either "unknown-face" or "not-face" 171 | # (2) If more than three votes share the maximum voting value 172 | # (3) If the max votes do not account for more than half of the votes 173 | to_drop = np.zeros((labels.shape[0], 1)) 174 | for ii in tqdm.tqdm(range(labels.shape[0])): 175 | max_vote = np.max(labels[ii,:]) 176 | max_vote_emos = np.where(labels[ii,:] == max_vote)[0] 177 | drop = any([x in [8, 9] for x in max_vote_emos]) 178 | num_max_votes = max_vote_emos.size 179 | drop = drop or num_max_votes >= 3 180 | drop = drop or (num_max_votes * max_vote <= 0.5 * num_votes[ii]) 181 | to_drop[ii] = drop 182 | 183 | # TODO(samuel): verify that this is correct 184 | assert to_drop.sum() == 3079, 'unexpected number of dropped votes' 185 | # NOTE: use slightly different "keep" indicies, depending on how data 186 | # is accessed. 187 | val_keep_ims = np.logical_not(to_drop[sets == 2]) 188 | test_keep_ims = np.logical_not(to_drop[sets == 3]) 189 | val_keep_labels = np.logical_and(sets == 2, 190 | np.logical_not(to_drop)).flatten() 191 | test_keep_labels = np.logical_and(sets == 3, 192 | np.logical_not(to_drop)).flatten() 193 | val_labels = labels[val_keep_labels, :] 194 | test_labels = labels[test_keep_labels, :] 195 | 196 | # update images in place 197 | self.data['images']['val'] = \ 198 | self.data['images']['val'][val_keep_ims,:,:,:] 199 | self.data['images']['test'] = \ 200 | self.data['images']['test'][test_keep_ims,:,:,:] 201 | 202 | # create "hard labels" with voting 203 | self.data['labels']['val'] = np.argmax(val_labels, 1) 204 | self.data['labels']['test'] = np.argmax(test_labels, 1) 205 | -------------------------------------------------------------------------------- /lfw_eval.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """LFW benchmark for face verification. This is designed to be used as a 3 | sanity check for imported models. 
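The script aligns each LFW image pair with a similarity transform, extracts
embeddings from the selected model, and reports 10-fold verification accuracy
together with the equal error rate.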
4 | 5 | Example Invocation: 6 | ipy lfw_eval.py 7 | ipy lfw_eval.py -- --limit 200 --model_name vgg_face_dag 8 | ipy lfw_eval.py -- --model_name vgg_m_face_bn_dag 9 | 10 | This code is primarily based on the code of https://github.com/clcarwin. The 11 | original code can be found here: 12 | https://github.com/clcarwin/sphereface_pytorch 13 | 14 | License from original codebase: 15 | 16 | MIT License 17 | 18 | Copyright (c) 2017 carwin 19 | 20 | Permission is hereby granted, free of charge, to any person obtaining a copy 21 | of this software and associated documentation files (the "Software"), to deal 22 | in the Software without restriction, including without limitation the rights 23 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 24 | copies of the Software, and to permit persons to whom the Software is 25 | furnished to do so, subject to the following conditions: 26 | 27 | The above copyright notice and this permission notice shall be included in all 28 | copies or substantial portions of the Software. 29 | 30 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 31 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 32 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 33 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 34 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 35 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 36 | SOFTWARE. 37 | 38 | """ 39 | from __future__ import print_function 40 | import torch 41 | import sys 42 | from PIL import Image 43 | import tqdm 44 | import cv2 45 | import argparse 46 | import numpy as np 47 | from scipy.optimize import brentq 48 | from scipy.interpolate import interp1d 49 | from sklearn.metrics import roc_curve 50 | import zipfile 51 | import os 52 | import six 53 | import utils.benchmark_helpers 54 | from matlab_cp2tform import get_similarity_transform_for_cv2 55 | 56 | from torch.autograd import Variable 57 | 58 | torch.backends.cudnn.bencmark = True 59 | MODEL_DIR = os.path.expanduser("~/data/models/pytorch/mcn_imports/") 60 | 61 | # import matplotlib 62 | # matplotlib.use('Agg') 63 | # import net_sphere 64 | 65 | 66 | def alignment(src_img, src_pts, output_size=(96, 112)): 67 | """Warp a face image so that its features align with a canoncial 68 | reference set of landmarks. The alignment is performed with an 69 | affine warp 70 | 71 | Args: 72 | src_img (ndarray): an HxWx3 RGB containing a face 73 | src_pts (ndarray): a 5x2 array of landmark locations 74 | output_size (tuple): the dimensions (oH, oW) of the output image 75 | 76 | Returns: 77 | (ndarray): an (oH x oW x 3) warped RGB image. 
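    Example (a minimal sketch; the landmark coordinates are illustrative, in
        this script they are read from data/lfw_landmark.txt):

            src_pts = [106, 114, 147, 112, 126, 140, 108, 165, 146, 163]  # x1,y1,...,x5,y5
            warped = alignment(face_img, src_pts)  # face crop with canonical landmark positions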
78 | """ 79 | ref_pts = [ 80 | [30.2946, 51.6963], 81 | [65.5318, 51.5014], 82 | [48.0252, 71.7366], 83 | [33.5493, 92.3655], 84 | [62.7299, 92.2041], 85 | ] 86 | src_pts = np.array(src_pts).reshape(5, 2) 87 | s = np.array(src_pts).astype(np.float32) 88 | r = np.array(ref_pts).astype(np.float32) 89 | tfm = get_similarity_transform_for_cv2(s, r) 90 | face_img = cv2.warpAffine(src_img, tfm, output_size) 91 | return face_img 92 | 93 | 94 | def KFold(num_pairs, n_folds, shuffle=False): 95 | folds = [] 96 | base = list(range(num_pairs)) 97 | for i in range(n_folds): 98 | test = base[i * num_pairs // n_folds : (i + 1) * num_pairs // n_folds] 99 | train = list(set(base) - set(test)) 100 | folds.append([train, test]) 101 | return folds 102 | 103 | 104 | def eval_acc(threshold, diff): 105 | y_true = [] 106 | y_predict = [] 107 | for d in diff: 108 | same = 1 if float(d[2]) > threshold else 0 109 | y_predict.append(same) 110 | y_true.append(int(d[3])) 111 | y_true = np.array(y_true) 112 | y_predict = np.array(y_predict) 113 | accuracy = 1.0 * np.count_nonzero(y_true == y_predict) / len(y_true) 114 | return accuracy 115 | 116 | 117 | def find_best_threshold(thresholds, predicts): 118 | best_threshold = best_acc = 0 119 | for threshold in thresholds: 120 | accuracy = eval_acc(threshold, predicts) 121 | if accuracy >= best_acc: 122 | best_acc = accuracy 123 | best_threshold = threshold 124 | return best_threshold 125 | 126 | 127 | def modify_to_return_embeddings(net, model_name): 128 | """Modify the structure of the network (if necessary), to ensure that it 129 | returns embeddings (the features from the penultimate layer of the network, 130 | just before the classifier) 131 | 132 | Args: 133 | net (nn.Module): the network to be modified 134 | model_name (str): the name of the network 135 | 136 | Return: 137 | (nn.Module): the modified network 138 | 139 | NOTE: 140 | We use `nn.Sequential` to simluate Identity (i.e. no-op). 
141 | """ 142 | if model_name in ["vgg_face_dag", "vgg_m_face_bn_dag"]: 143 | net.fc8 = torch.nn.Sequential() 144 | else: 145 | msg = "{} not yet supported".format(model_name) 146 | raise NotImplementedError(msg) 147 | return net 148 | 149 | 150 | def load_model(model_name): 151 | """Load imoprted PyTorch model by name 152 | 153 | Args: 154 | model_name (str): the name of the model to be loaded 155 | 156 | Return: 157 | nn.Module: the loaded network 158 | """ 159 | model_def_path = os.path.join(MODEL_DIR, model_name + ".py") 160 | weights_path = os.path.join(MODEL_DIR, model_name + ".pth") 161 | if six.PY3: 162 | import importlib.util 163 | 164 | spec = importlib.util.spec_from_file_location(model_name, 165 | model_def_path) 166 | mod = importlib.util.module_from_spec(spec) 167 | spec.loader.exec_module(mod) 168 | else: 169 | import importlib 170 | dirname = os.path.dirname(model_def_path) 171 | sys.path.insert(0, dirname) 172 | module_name = os.path.splitext(os.path.basename(model_def_path))[0] 173 | mod = importlib.import_module(module_name) 174 | func = getattr(mod, model_name) 175 | net = func(weights_path=weights_path) 176 | net = modify_to_return_embeddings(net, model_name) 177 | return net 178 | 179 | 180 | parser = argparse.ArgumentParser(description="PyTorch sphereface lfw") 181 | parser.add_argument("--net", default="sphere20a", type=str) 182 | parser.add_argument("--lfw", default="data/lfw.zip", type=str) 183 | parser.add_argument("--limit", default=None, type=int) 184 | parser.add_argument("--model_name", default="resnet50_scratch_dag", type=str) 185 | parser.add_argument("--use_flipped", action="store_true") 186 | args = parser.parse_args() 187 | 188 | predicts = [] 189 | net = load_model(args.model_name) 190 | net.cuda() 191 | net.eval() 192 | net.feature = True 193 | 194 | zfile = zipfile.ZipFile(args.lfw) 195 | 196 | landmark = {} 197 | with open("data/lfw_landmark.txt") as f: 198 | landmark_lines = f.readlines() 199 | for line in landmark_lines: 200 | ll = line.replace("\n", "").split("\t") 201 | landmark[ll[0]] = [int(k) for k in ll[1:]] 202 | 203 | with open("data/pairs.txt") as f: 204 | next(f) 205 | pairs_lines = f.readlines() 206 | 207 | orig_pairs = 6000 208 | if args.limit: 209 | num_pairs = min(orig_pairs, args.limit) 210 | else: 211 | num_pairs = orig_pairs 212 | 213 | 214 | def extract_features(net, ims): 215 | """Extract penultimate features from network 216 | 217 | Args: 218 | net (nn.Module): the network to be used to compute features 219 | ims (torch.Tensor): the data to be processed 220 | 221 | NOTE: 222 | Pretrained networks often vary in the manner in which their outputs 223 | are returned. For example, some return the penultimate features as 224 | a second argument, while others need to be modified directly and will 225 | return these features as their only output. 
226 | """ 227 | outs = net(ims) 228 | if isinstance(outs, list): 229 | outs = outs[1] 230 | features = outs.data 231 | return features 232 | 233 | 234 | for i in tqdm.tqdm(range(num_pairs)): 235 | p = pairs_lines[i].replace("\n", "").split("\t") 236 | 237 | if 3 == len(p): 238 | sameflag = 1 239 | name1 = p[0] + "/" + p[0] + "_" + "{:04}.jpg".format(int(p[1])) 240 | name2 = p[0] + "/" + p[0] + "_" + "{:04}.jpg".format(int(p[2])) 241 | if 4 == len(p): 242 | sameflag = 0 243 | name1 = p[0] + "/" + p[0] + "_" + "{:04}.jpg".format(int(p[1])) 244 | name2 = p[2] + "/" + p[2] + "_" + "{:04}.jpg".format(int(p[3])) 245 | 246 | im1 = cv2.imdecode(np.frombuffer(zfile.read(name1), np.uint8), 1) 247 | img1_aligned = alignment(im1, landmark[name1]) 248 | im2 = cv2.imdecode(np.frombuffer(zfile.read(name2), np.uint8), 1) 249 | img2_aligned = alignment(im2, landmark[name2]) 250 | 251 | # convert images to PIL to use builtin transforms 252 | # import matplotlib.pyplot as plt 253 | img1 = cv2.cvtColor(img1_aligned, cv2.COLOR_BGR2RGB) 254 | img1 = Image.fromarray(img1) 255 | img2 = cv2.cvtColor(img2_aligned, cv2.COLOR_BGR2RGB) 256 | img2 = Image.fromarray(img2) 257 | 258 | meta = net.meta 259 | preproc_transforms = utils.benchmark_helpers.compose_transforms( 260 | meta=meta, 261 | center_crop=False 262 | ) 263 | 264 | imglist = [ 265 | img1, 266 | img1.transpose(Image.FLIP_LEFT_RIGHT), 267 | img2, 268 | img2.transpose(Image.FLIP_LEFT_RIGHT), 269 | ] 270 | 271 | for i in range(len(imglist)): 272 | imglist[i] = preproc_transforms(imglist[i]) 273 | 274 | ims = torch.stack(imglist, dim=0) 275 | ims = Variable(ims).cuda() 276 | outs = net(ims) 277 | features = extract_features(net, ims) 278 | f1, f2 = features[0].squeeze(), features[2].squeeze() 279 | cosdistance = f1.dot(f2) / (f1.norm() * f2.norm() + 1e-5) 280 | pred = [name1, name2, cosdistance.item(), sameflag] 281 | predicts.append(pred) 282 | 283 | 284 | def compute_eer(labels, scores): 285 | """Compute the Equal Error Rate (EER) from the predictions and scores. 286 | 287 | Args: 288 | labels (list[int]): values indicating whether the ground truth 289 | value is positive (1) or negative (0). 290 | scores (list[float]): the confidence of the prediction that the 291 | given sample is a positive. 292 | 293 | Return: 294 | (float, thresh): the Equal Error Rate and the corresponding threshold 295 | 296 | NOTES: 297 | The EER corresponds to the point on the ROC curve that intersects 298 | the line given by the equation 1 = FPR + TPR. 299 | 300 | The implementation of the function was taken from here: 301 | https://yangcha.github.io/EER-ROC/ 302 | """ 303 | fpr, tpr, thresholds = roc_curve(labels, scores, pos_label=1) 304 | eer = brentq(lambda x : 1. - x - interp1d(fpr, tpr)(x), 0., 1.) 
305 | thresh = interp1d(fpr, thresholds)(eer) 306 | return eer, thresh 307 | 308 | 309 | accuracy = [] 310 | thd = [] 311 | folds = KFold(num_pairs=num_pairs, n_folds=10, shuffle=False) 312 | thresholds = np.arange(-1.0, 1.0, 0.005) 313 | predicts = np.array(predicts) 314 | eers = [] 315 | eer_thresholds = [] 316 | for idx, (train, test) in tqdm.tqdm(enumerate(folds)): 317 | best_thresh = find_best_threshold(thresholds, predicts[train]) 318 | accuracy.append(eval_acc(best_thresh, predicts[test])) 319 | thd.append(best_thresh) 320 | scores = [float(x[2]) for x in predicts[test]] 321 | labels = [int(x[3]) for x in predicts[test]] 322 | eer, thresh = compute_eer(labels=labels, scores=scores) 323 | eers.append(eer) 324 | eer_thresholds.append(thresh) 325 | 326 | msg = "LFWACC={:.4f} std={:.4f} thd={:.4f}" 327 | print(msg.format(np.mean(accuracy), np.std(accuracy), np.mean(thd))) 328 | msg = "EER={:.4f} std={:.4f} thd={:.4f}" 329 | print(msg.format(np.mean(eers), np.std(eers), np.mean(eer_thresholds))) 330 | 331 | # Add blanks to prevent tqdm from swallowing the summary 332 | print("") 333 | -------------------------------------------------------------------------------- /matlab_cp2tform.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Tue Jul 11 06:54:28 2017 4 | 5 | @author: zhaoyafei 6 | """ 7 | 8 | import numpy as np 9 | from numpy.linalg import inv, norm, lstsq 10 | from numpy.linalg import matrix_rank as rank 11 | 12 | 13 | """ 14 | Introduction: 15 | ---------- 16 | numpy implemetation form matlab function CP2TFORM(...) 17 | with 'transformtype': 18 | 1) 'nonreflective similarity' 19 | 2) 'similarity' 20 | """ 21 | """ 22 | 23 | 24 | MATLAB code: 25 | ---------- 26 | %-------------------------------------- 27 | % Function findNonreflectiveSimilarity 28 | % 29 | function [trans, output] = findNonreflectiveSimilarity(uv,xy,options) 30 | % 31 | % For a nonreflective similarity: 32 | % 33 | % let sc = s*cos(theta) 34 | % let ss = s*sin(theta) 35 | % 36 | % [ sc -ss 37 | % [u v] = [x y 1] * ss sc 38 | % tx ty] 39 | % 40 | % There are 4 unknowns: sc,ss,tx,ty. 41 | % 42 | % Another way to write this is: 43 | % 44 | % u = [x y 1 0] * [sc 45 | % ss 46 | % tx 47 | % ty] 48 | % 49 | % v = [y -x 0 1] * [sc 50 | % ss 51 | % tx 52 | % ty] 53 | % 54 | % With 2 or more correspondence points we can combine the u equations and 55 | % the v equations for one linear system to solve for sc,ss,tx,ty. 56 | """ 57 | """ 58 | % 59 | % [ u1 ] = [ x1 y1 1 0 ] * [sc] 60 | % [ u2 ] [ x2 y2 1 0 ] [ss] 61 | % [ ... ] [ ... ] [tx] 62 | % [ un ] [ xn yn 1 0 ] [ty] 63 | % [ v1 ] [ y1 -x1 0 1 ] 64 | % [ v2 ] [ y2 -x2 0 1 ] 65 | % [ ... ] [ ... ] 66 | % [ vn ] [ yn -xn 0 1 ] 67 | % 68 | % Or rewriting the above matrix equation: 69 | % U = X * r, where r = [sc ss tx ty]' 70 | """ 71 | """ 72 | so r = X \ U. 
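(In the NumPy port below, this backslash solve is performed with
numpy.linalg.lstsq(X, U).)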
73 | """ 74 | """ 75 | 76 | K = options.K; 77 | M = size(xy,1); 78 | x = xy(:,1); 79 | y = xy(:,2); 80 | X = [x y ones(M,1) zeros(M,1); 81 | y -x zeros(M,1) ones(M,1) ]; 82 | 83 | u = uv(:,1); 84 | v = uv(:,2); 85 | U = [u; v]; 86 | 87 | % We know that X * r = U 88 | if rank(X) >= 2*K 89 | r = X \ U; 90 | else 91 | error(message('images:cp2tform:twoUniquePointsReq')) 92 | end 93 | 94 | sc = r(1); 95 | ss = r(2); 96 | tx = r(3); 97 | ty = r(4); 98 | 99 | Tinv = [sc -ss 0; 100 | ss sc 0; 101 | tx ty 1]; 102 | 103 | T = inv(Tinv); 104 | T(:,3) = [0 0 1]'; 105 | 106 | trans = maketform('affine', T); 107 | output = []; 108 | 109 | %------------------------- 110 | % Function findSimilarity 111 | % 112 | function [trans, output] = findSimilarity(uv,xy,options) 113 | % 114 | % The similarities are a superset of the nonreflective similarities as they may 115 | % also include reflection. 116 | % 117 | % let sc = s*cos(theta) 118 | % let ss = s*sin(theta) 119 | % 120 | % [ sc -ss 121 | % [u v] = [x y 1] * ss sc 122 | % tx ty] 123 | % 124 | % OR 125 | % 126 | % [ sc ss 127 | % [u v] = [x y 1] * ss -sc 128 | % tx ty] 129 | % 130 | % Algorithm: 131 | % 1) Solve for trans1, a nonreflective similarity. 132 | % 2) Reflect the xy data across the Y-axis, 133 | % and solve for trans2r, also a nonreflective similarity. 134 | % 3) Transform trans2r to trans2, undoing the reflection done in step 2. 135 | % 4) Use TFORMFWD to transform uv using both trans1 and trans2, 136 | % and compare the results, Returnsing the transformation corresponding 137 | % to the smaller L2 norm. 138 | 139 | % Need to reset options.K to prepare for calls to findNonreflectiveSimilarity. 140 | % This is safe because we already checked that there are enough point pairs. 141 | options.K = 2; 142 | 143 | % Solve for trans1 144 | [trans1, output] = findNonreflectiveSimilarity(uv,xy,options); 145 | 146 | 147 | % Solve for trans2 148 | 149 | % manually reflect the xy data across the Y-axis 150 | xyR = xy; 151 | xyR(:,1) = -1*xyR(:,1); 152 | 153 | trans2r = findNonreflectiveSimilarity(uv,xyR,options); 154 | 155 | % manually reflect the tform to undo the reflection done on xyR 156 | TreflectY = [-1 0 0; 157 | 0 1 0; 158 | 0 0 1]; 159 | trans2 = maketform('affine', trans2r.tdata.T * TreflectY); 160 | 161 | 162 | % Figure out if trans1 or trans2 is better 163 | xy1 = tformfwd(trans1,uv); 164 | norm1 = norm(xy1-xy); 165 | 166 | xy2 = tformfwd(trans2,uv); 167 | norm2 = norm(xy2-xy); 168 | 169 | if norm1 <= norm2 170 | trans = trans1; 171 | else 172 | trans = trans2; 173 | end 174 | """ 175 | 176 | 177 | class MatlabCp2tormException(Exception): 178 | def __str__(self): 179 | return "In File {}:{}".format(__file__, super.__str__(self)) 180 | 181 | 182 | def tformfwd(trans, uv): 183 | """ 184 | Function: 185 | ---------- 186 | apply affine transform 'trans' to uv 187 | 188 | Parameters: 189 | ---------- 190 | @trans: 3x3 np.array 191 | transform matrix 192 | @uv: Kx2 np.array 193 | each row is a pair of coordinates (x, y) 194 | 195 | Returns: 196 | ---------- 197 | @xy: Kx2 np.array 198 | each row is a pair of transformed coordinates (x, y) 199 | """ 200 | uv = np.hstack((uv, np.ones((uv.shape[0], 1)))) 201 | xy = np.dot(uv, trans) 202 | xy = xy[:, 0:-1] 203 | return xy 204 | 205 | 206 | def tforminv(trans, uv): 207 | """ 208 | Function: 209 | ---------- 210 | apply the inverse of affine transform 'trans' to uv 211 | 212 | Parameters: 213 | ---------- 214 | @trans: 3x3 np.array 215 | transform matrix 216 | @uv: Kx2 np.array 217 | each row is a pair of 
coordinates (x, y) 218 | 219 | Returns: 220 | ---------- 221 | @xy: Kx2 np.array 222 | each row is a pair of inverse-transformed coordinates (x, y) 223 | """ 224 | Tinv = inv(trans) 225 | xy = tformfwd(Tinv, uv) 226 | return xy 227 | 228 | 229 | def findNonreflectiveSimilarity(uv, xy, options=None): 230 | """ 231 | Function: 232 | ---------- 233 | Find Non-reflective Similarity Transform Matrix 'trans': 234 | u = uv[:, 0] 235 | v = uv[:, 1] 236 | x = xy[:, 0] 237 | y = xy[:, 1] 238 | [x, y, 1] = [u, v, 1] * trans 239 | 240 | Parameters: 241 | ---------- 242 | @uv: Kx2 np.array 243 | source points each row is a pair of coordinates (x, y) 244 | @xy: Kx2 np.array 245 | each row is a pair of inverse-transformed 246 | @option: not used, keep it as None 247 | 248 | Returns: 249 | @trans: 3x3 np.array 250 | transform matrix from uv to xy 251 | @trans_inv: 3x3 np.array 252 | inverse of trans, transform matrix from xy to uv 253 | 254 | Matlab: 255 | ---------- 256 | % For a nonreflective similarity: 257 | % 258 | % let sc = s*cos(theta) 259 | % let ss = s*sin(theta) 260 | % 261 | % [ sc -ss 262 | % [u v] = [x y 1] * ss sc 263 | % tx ty] 264 | % 265 | % There are 4 unknowns: sc,ss,tx,ty. 266 | % 267 | % Another way to write this is: 268 | % 269 | % u = [x y 1 0] * [sc 270 | % ss 271 | % tx 272 | % ty] 273 | % 274 | % v = [y -x 0 1] * [sc 275 | % ss 276 | % tx 277 | % ty] 278 | % 279 | % With 2 or more correspondence points we can combine the u equations and 280 | % the v equations for one linear system to solve for sc,ss,tx,ty. 281 | % 282 | % [ u1 ] = [ x1 y1 1 0 ] * [sc] 283 | % [ u2 ] [ x2 y2 1 0 ] [ss] 284 | % [ ... ] [ ... ] [tx] 285 | % [ un ] [ xn yn 1 0 ] [ty] 286 | % [ v1 ] [ y1 -x1 0 1 ] 287 | % [ v2 ] [ y2 -x2 0 1 ] 288 | % [ ... ] [ ... ] 289 | % [ vn ] [ yn -xn 0 1 ] 290 | % 291 | % Or rewriting the above matrix equation: 292 | % U = X * r, where r = [sc ss tx ty]' 293 | % so r = X \ U. 
294 |     %
295 |     """
296 |     options = {"K": 2}
297 | 
298 |     K = options["K"]
299 |     M = xy.shape[0]
300 |     x = xy[:, 0].reshape((-1, 1))  # use reshape to keep a column vector
301 |     y = xy[:, 1].reshape((-1, 1))  # use reshape to keep a column vector
302 |     # print('--->x, y:\n', x, y)
303 | 
304 |     tmp1 = np.hstack((x, y, np.ones((M, 1)), np.zeros((M, 1))))
305 |     tmp2 = np.hstack((y, -x, np.zeros((M, 1)), np.ones((M, 1))))
306 |     X = np.vstack((tmp1, tmp2))
307 |     # print('--->X.shape: ', X.shape)
308 |     # print('X:\n', X)
309 | 
310 |     u = uv[:, 0].reshape((-1, 1))  # use reshape to keep a column vector
311 |     v = uv[:, 1].reshape((-1, 1))  # use reshape to keep a column vector
312 |     U = np.vstack((u, v))
313 |     # print('--->U.shape: ', U.shape)
314 |     # print('U:\n', U)
315 | 
316 |     # We know that X * r = U
317 |     if rank(X) >= 2 * K:
318 |         r, _, _, _ = lstsq(X, U)
319 |         r = np.squeeze(r)
320 |     else:
321 |         raise Exception("cp2tform:twoUniquePointsReq")
322 | 
323 |     # print('--->r:\n', r)
324 | 
325 |     sc = r[0]
326 |     ss = r[1]
327 |     tx = r[2]
328 |     ty = r[3]
329 | 
330 |     Tinv = np.array([[sc, -ss, 0], [ss, sc, 0], [tx, ty, 1]])
331 | 
332 |     # print('--->Tinv:\n', Tinv)
333 | 
334 |     T = inv(Tinv)
335 |     # print('--->T:\n', T)
336 | 
337 |     T[:, 2] = np.array([0, 0, 1])
338 | 
339 |     return T, Tinv
340 | 
341 | 
342 | def findSimilarity(uv, xy, options=None):
343 |     """
344 |     Function:
345 |     ----------
346 |     Find Reflective Similarity Transform Matrix 'trans':
347 |         u = uv[:, 0]
348 |         v = uv[:, 1]
349 |         x = xy[:, 0]
350 |         y = xy[:, 1]
351 |         [x, y, 1] = [u, v, 1] * trans
352 | 
353 |     Parameters:
354 |     ----------
355 |     @uv: Kx2 np.array
356 |         source points, each row is a pair of coordinates (x, y)
357 |     @xy: Kx2 np.array
358 |         destination points, each row is a pair of coordinates (x, y)
359 |     @options: not used, keep it as None
360 | 
361 |     Returns:
362 |     ----------
363 |     @trans: 3x3 np.array
364 |         transform matrix from uv to xy
365 |     @trans_inv: 3x3 np.array
366 |         inverse of trans, transform matrix from xy to uv
367 | 
368 |     Matlab:
369 |     ----------
370 |     % The similarities are a superset of the nonreflective similarities as they may
371 |     % also include reflection.
372 |     %
373 |     % let sc = s*cos(theta)
374 |     % let ss = s*sin(theta)
375 |     %
376 |     %                   [ sc -ss
377 |     % [u v] = [x y 1] *   ss  sc
378 |     %                     tx  ty]
379 |     %
380 |     % OR
381 |     %
382 |     %                   [ sc  ss
383 |     % [u v] = [x y 1] *   ss -sc
384 |     %                     tx  ty]
385 |     %
386 |     % Algorithm:
387 |     % 1) Solve for trans1, a nonreflective similarity.
388 |     % 2) Reflect the xy data across the Y-axis,
389 |     %    and solve for trans2r, also a nonreflective similarity.
390 |     % 3) Transform trans2r to trans2, undoing the reflection done in step 2.
391 |     % 4) Use TFORMFWD to transform uv using both trans1 and trans2,
392 |     %    and compare the results, returning the transformation corresponding
393 |     %    to the smaller L2 norm.
394 | 
395 |     % Need to reset options.K to prepare for calls to
396 |     % findNonreflectiveSimilarity.
397 |     % This is safe because we already checked that there are enough point
398 |     % pairs.
399 | """ 400 | options = {"K": 2} 401 | 402 | # uv = np.array(uv) 403 | # xy = np.array(xy) 404 | 405 | # Solve for trans1 406 | trans1, trans1_inv = findNonreflectiveSimilarity(uv, xy, options) 407 | 408 | # Solve for trans2 409 | 410 | # manually reflect the xy data across the Y-axis 411 | xyR = xy 412 | xyR[:, 0] = -1 * xyR[:, 0] 413 | 414 | trans2r, trans2r_inv = findNonreflectiveSimilarity(uv, xyR, options) 415 | 416 | # manually reflect the tform to undo the reflection done on xyR 417 | TreflectY = np.array([[-1, 0, 0], [0, 1, 0], [0, 0, 1]]) 418 | 419 | trans2 = np.dot(trans2r, TreflectY) 420 | 421 | # Figure out if trans1 or trans2 is better 422 | xy1 = tformfwd(trans1, uv) 423 | norm1 = norm(xy1 - xy) 424 | 425 | xy2 = tformfwd(trans2, uv) 426 | norm2 = norm(xy2 - xy) 427 | 428 | if norm1 <= norm2: 429 | return trans1, trans1_inv 430 | else: 431 | trans2_inv = inv(trans2) 432 | return trans2, trans2_inv 433 | 434 | 435 | def get_similarity_transform(src_pts, dst_pts, reflective=True): 436 | """ 437 | Function: 438 | ---------- 439 | Find Similarity Transform Matrix 'trans': 440 | u = src_pts[:, 0] 441 | v = src_pts[:, 1] 442 | x = dst_pts[:, 0] 443 | y = dst_pts[:, 1] 444 | [x, y, 1] = [u, v, 1] * trans 445 | 446 | Parameters: 447 | ---------- 448 | @src_pts: Kx2 np.array 449 | source points, each row is a pair of coordinates (x, y) 450 | @dst_pts: Kx2 np.array 451 | destination points, each row is a pair of transformed 452 | coordinates (x, y) 453 | @reflective: True or False 454 | if True: 455 | use reflective similarity transform 456 | else: 457 | use non-reflective similarity transform 458 | 459 | Returns: 460 | ---------- 461 | @trans: 3x3 np.array 462 | transform matrix from uv to xy 463 | trans_inv: 3x3 np.array 464 | inverse of trans, transform matrix from xy to uv 465 | """ 466 | 467 | if reflective: 468 | trans, trans_inv = findSimilarity(src_pts, dst_pts) 469 | else: 470 | trans, trans_inv = findNonreflectiveSimilarity(src_pts, dst_pts) 471 | 472 | return trans, trans_inv 473 | 474 | 475 | def cvt_tform_mat_for_cv2(trans): 476 | """ 477 | Function: 478 | ---------- 479 | Convert Transform Matrix 'trans' into 'cv2_trans' which could be 480 | directly used by cv2.warpAffine(): 481 | u = src_pts[:, 0] 482 | v = src_pts[:, 1] 483 | x = dst_pts[:, 0] 484 | y = dst_pts[:, 1] 485 | [x, y].T = cv_trans * [u, v, 1].T 486 | 487 | Parameters: 488 | ---------- 489 | @trans: 3x3 np.array 490 | transform matrix from uv to xy 491 | 492 | Returns: 493 | ---------- 494 | @cv2_trans: 2x3 np.array 495 | transform matrix from src_pts to dst_pts, could be directly used 496 | for cv2.warpAffine() 497 | """ 498 | cv2_trans = trans[:, 0:2].T 499 | 500 | return cv2_trans 501 | 502 | 503 | def get_similarity_transform_for_cv2(src_pts, dst_pts, reflective=True): 504 | """ 505 | Function: 506 | ---------- 507 | Find Similarity Transform Matrix 'cv2_trans' which could be 508 | directly used by cv2.warpAffine(): 509 | u = src_pts[:, 0] 510 | v = src_pts[:, 1] 511 | x = dst_pts[:, 0] 512 | y = dst_pts[:, 1] 513 | [x, y].T = cv_trans * [u, v, 1].T 514 | 515 | Parameters: 516 | ---------- 517 | @src_pts: Kx2 np.array 518 | source points, each row is a pair of coordinates (x, y) 519 | @dst_pts: Kx2 np.array 520 | destination points, each row is a pair of transformed 521 | coordinates (x, y) 522 | reflective: True or False 523 | if True: 524 | use reflective similarity transform 525 | else: 526 | use non-reflective similarity transform 527 | 528 | Returns: 529 | ---------- 530 | @cv2_trans: 2x3 np.array 
531 |         transform matrix from src_pts to dst_pts, which can be directly used
532 |         by cv2.warpAffine()
533 |     """
534 |     trans, trans_inv = get_similarity_transform(src_pts, dst_pts, reflective)
535 |     cv2_trans = cvt_tform_mat_for_cv2(trans)
536 | 
537 |     return cv2_trans
538 | 
539 | 
540 | if __name__ == "__main__":
541 |     """
542 |     u = [0, 6, -2]
543 |     v = [0, 3, 5]
544 |     x = [-1, 0, 4]
545 |     y = [-1, -10, 4]
546 | 
547 |     # In Matlab, run:
548 |     #
549 |     #   uv = [u'; v'];
550 |     #   xy = [x'; y'];
551 |     #   tform_sim=cp2tform(uv,xy,'similarity');
552 |     #
553 |     #   trans = tform_sim.tdata.T
554 |     #   ans =
555 |     #       -0.0764   -1.6190         0
556 |     #        1.6190   -0.0764         0
557 |     #       -3.2156    0.0290    1.0000
558 |     #   trans_inv = tform_sim.tdata.Tinv
559 |     #   ans =
560 |     #
561 |     #       -0.0291    0.6163         0
562 |     #       -0.6163   -0.0291         0
563 |     #       -0.0756    1.9826    1.0000
564 |     #   xy_m=tformfwd(tform_sim, u,v)
565 |     #
566 |     #   xy_m =
567 |     #
568 |     #       -3.2156    0.0290
569 |     #        1.1833   -9.9143
570 |     #        5.0323    2.8853
571 |     #   uv_m=tforminv(tform_sim, x,y)
572 |     #
573 |     #   uv_m =
574 |     #
575 |     #        0.5698    1.3953
576 |     #        6.0872    2.2733
577 |     #       -2.6570    4.3314
578 |     """
579 |     u = [0, 6, -2]
580 |     v = [0, 3, 5]
581 |     x = [-1, 0, 4]
582 |     y = [-1, -10, 4]
583 | 
584 |     uv = np.array((u, v)).T
585 |     xy = np.array((x, y)).T
586 | 
587 |     print("\n--->uv:")
588 |     print(uv)
589 |     print("\n--->xy:")
590 |     print(xy)
591 | 
592 |     trans, trans_inv = get_similarity_transform(uv, xy)
593 | 
594 |     print("\n--->trans matrix:")
595 |     print(trans)
596 | 
597 |     print("\n--->trans_inv matrix:")
598 |     print(trans_inv)
599 | 
600 |     print("\n---> apply transform to uv")
601 |     print("\nxy_m = uv_augmented * trans")
602 |     uv_aug = np.hstack((uv, np.ones((uv.shape[0], 1))))
603 |     xy_m = np.dot(uv_aug, trans)
604 |     print(xy_m)
605 | 
606 |     print("\nxy_m = tformfwd(trans, uv)")
607 |     xy_m = tformfwd(trans, uv)
608 |     print(xy_m)
609 | 
610 |     print("\n---> apply inverse transform to xy")
611 |     print("\nuv_m = xy_augmented * trans_inv")
612 |     xy_aug = np.hstack((xy, np.ones((xy.shape[0], 1))))
613 |     uv_m = np.dot(xy_aug, trans_inv)
614 |     print(uv_m)
615 | 
616 |     print("\nuv_m = tformfwd(trans_inv, xy)")
617 |     uv_m = tformfwd(trans_inv, xy)
618 |     print(uv_m)
619 | 
620 |     uv_m = tforminv(trans, xy)
621 |     print("\nuv_m = tforminv(trans, xy)")
622 |     print(uv_m)
623 | 
--------------------------------------------------------------------------------
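Usage note: the functions in `matlab_cp2tform.py` are typically paired with `cv2.warpAffine` to align a face crop to a canonical landmark template before evaluation (e.g. on LFW). The sketch below illustrates that flow under stated assumptions: the 96x112 five-point template is a commonly used convention rather than a value defined in this repository, the image path and detected landmarks are placeholders, and OpenCV (`cv2`) is assumed to be installed.

```python
# Minimal alignment sketch (assumptions: placeholder landmarks/paths, cv2 available).
import cv2
import numpy as np

from matlab_cp2tform import get_similarity_transform_for_cv2

# Commonly used 5-point template for a 96x112 face crop:
# left eye, right eye, nose tip, left mouth corner, right mouth corner.
REFERENCE_PTS = np.array([
    [30.2946, 51.6963],
    [65.5318, 51.5014],
    [48.0252, 71.7366],
    [33.5493, 92.3655],
    [62.7299, 92.2041],
], dtype=np.float32)


def align_face(img, landmarks, out_size=(96, 112)):
    """Warp `img` so that the 5x2 `landmarks` map onto REFERENCE_PTS."""
    src_pts = np.asarray(landmarks, dtype=np.float32)
    # 2x3 affine matrix in the layout cv2.warpAffine expects
    tfm = get_similarity_transform_for_cv2(src_pts, REFERENCE_PTS)
    return cv2.warpAffine(img, tfm, out_size)


if __name__ == "__main__":
    image = cv2.imread("face.jpg")                    # placeholder path
    detected = [[105, 112], [150, 110], [128, 135],   # placeholder landmarks
                [110, 160], [148, 158]]
    aligned = align_face(image, detected)
    cv2.imwrite("face_aligned.jpg", aligned)
```

The conversion in `cvt_tform_mat_for_cv2` takes the first two columns of the 3x3 matrix and transposes them because the module solves `[x, y, 1] = [u, v, 1] * trans` (row-vector convention), whereas `cv2.warpAffine` expects a 2x3 matrix applied to column vectors, i.e. `[x, y].T = M * [u, v, 1].T`.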