├── .gitignore
├── README.md
├── __init__.py
├── _init_paths.py
├── config
│   ├── __init__.py
│   ├── base_config.py
│   └── test.yaml
├── dataloader
│   ├── __init__.py
│   ├── base_dset.py
│   ├── custom_dset.py
│   ├── mnist.py
│   ├── triplet_img_loader.py
│   └── vggface2.py
├── images
│   └── tSNE_mnist.jpg
├── model
│   ├── __init__.py
│   ├── embedding.py
│   ├── net.py
│   └── resnet.py
├── requirements.txt
├── train.py
├── tsne.py
└── utils
    ├── gen_utils.py
    └── vis_utils.py
/.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | 106 | .idea 107 | data -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Deep metric learning using Triplet network in PyTorch 2 | 3 | This repository contains code for training a Triplet Network in PyTorch. 4 | Siamese and Triplet networks make use of a similarity metric with the aim of bringing similar images closer in the embedding space while separating dissimilar ones.
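As a quick illustration of that objective, here is a minimal sketch of the margin-based formulation this repository uses in `train.py` (a `MarginRankingLoss` over pairwise embedding distances; the tensors below are random stand-ins, not real embeddings):

```python
import torch
import torch.nn.functional as F

# Random stand-ins for anchor / positive / negative embedding batches
anchor, positive, negative = (torch.randn(8, 64) for _ in range(3))

d_ap = F.pairwise_distance(anchor, positive, 2)  # anchor-positive distances
d_an = F.pairwise_distance(anchor, negative, 2)  # anchor-negative distances

# With target = -1, MarginRankingLoss penalises max(0, d_ap - d_an + margin),
# i.e. it pushes d_ap below d_an by at least the margin.
target = -torch.ones_like(d_ap)
loss = torch.nn.MarginRankingLoss(margin=1.0)(d_ap, d_an, target)
```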
5 | Popular uses of such networks include: 6 | * Face Verification / Classification 7 | * Learning deep embeddings for other tasks like classification / detection / segmentation 8 | 9 | Paper - [Deep metric learning using Triplet network](http://arxiv.org/abs/1412.6622) 10 | 11 | ## Installation 12 | 13 | Install [PyTorch](https://pytorch.org/get-started/locally/), then install the remaining dependencies: 14 | ``` 15 | pip install -r requirements.txt 16 | ``` 17 | 18 | ## Demo 19 | 20 | Colab notebook with pretrained weights 21 | 22 | ## Training 23 | 24 | ``` 25 | python train.py --cuda 26 | ``` 27 | By default this trains on the MNIST dataset. 28 | 29 | ### MNIST / FashionMNIST 30 | 31 | ``` 32 | python train.py --result_dir results --exp_name MNIST_exp1 --cuda --dataset <mnist / fmnist> 33 | ``` 34 | To create a **tSNE** visualisation: 35 | ``` 36 | python tsne.py --ckp <path to checkpoint> 37 | ``` 38 | The embeddings and the labels are stored in the experiment folder as a pickle file, so you do not have to run the model every time you create a visualisation. Just pass the saved embeddings via the --pkl parameter: 39 | ``` 40 | python tsne.py --pkl <path to embeddings pickle> 41 | ``` 42 | Sample tSNE visualisation on MNIST: 43 | ![tSNE](images/tSNE_mnist.jpg "tSNE visualisation on MNIST") 44 | 45 | 46 | ### [VGGFace2](http://www.robots.ox.ac.uk/~vgg/data/vgg_face2/) 47 | 48 | Specify the location of the dataset in `config/test.yaml`. 49 | The directory should have the following structure: 50 | ```buildoutcfg 51 | +-- root 52 | | +-- train 53 | | +-- class1 54 | | +-- img1.jpg 55 | | +-- img2.jpg 56 | | +-- img3.jpg 57 | | +-- class2 58 | | +-- class3 59 | | +-- test 60 | | +-- class4 61 | | +-- class5 62 | ``` 63 | 64 | ``` 65 | python train.py --result_dir results --exp_name VGGFace2_exp1 --cuda --epochs 50 --ckp_freq 5 --dataset vggface2 --num_train_samples 32000 --num_test_samples 5000 --train_log_step 50 66 | ``` 67 | 68 | ### Custom Dataset 69 | 70 | Specify the location of the dataset in `config/test.yaml`. 71 | The directory should have the following structure: 72 | ```buildoutcfg 73 | +-- root 74 | | +-- train 75 | | +-- class1 76 | | +-- img1.jpg 77 | | +-- img2.jpg 78 | | +-- img3.jpg 79 | | +-- class2 80 | | +-- class3 81 | | +-- test 82 | | +-- class4 83 | | +-- class5 84 | ``` 85 | 86 | ``` 87 | python train.py --result_dir results --exp_name Custom_exp1 --cuda --epochs 50 --ckp_freq 5 --dataset custom --num_train_samples 32000 --num_test_samples 5000 --train_log_step 50 88 | ``` 89 | 90 | ## TODO 91 | 92 | - [x] Train on MNIST / FashionMNIST 93 | - [x] Train on a public dataset 94 | - [x] Multi GPU Training 95 | - [x] Custom Dataset 96 | - [ ] Include popular models - ResNeXt / ResNet / VGG / Inception --------------------------------------------------------------------------------
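For reference, `config/test.yaml` is merged into the defaults defined in `config/base_config.py` at startup; any key it sets overrides the default, and unknown keys raise an error. A minimal sketch of that behaviour:

```python
from config.base_config import cfg, cfg_from_file

print(cfg.RESNET.FIXED_BLOCKS)   # 1, the default from base_config.py
cfg_from_file("config/test.yaml")
print(cfg.RESNET.FIXED_BLOCKS)   # 2, overridden by test.yaml

# Dataset roots are read the same way, e.g. in dataloader/custom_dset.py:
print(cfg.DATASETS.CUSTOM.HOME)  # whatever HOME you set under DATASETS.CUSTOM
```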
/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/avilash/pytorch-siamese-triplet/c4f88b4f84cff02792118e90aabd8eedff603c2e/__init__.py -------------------------------------------------------------------------------- /_init_paths.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | import sys 3 | 4 | 5 | def add_path(res_path): 6 | if res_path not in sys.path: 7 | sys.path.insert(0, res_path) 8 | 9 | 10 | this_dir = osp.dirname(__file__) 11 | 12 | add_path(osp.join(this_dir)) 13 | 14 | path = osp.join(this_dir, 'config/') 15 | add_path(path) 16 | 17 | path = osp.join(this_dir, 'dataloader/') 18 | add_path(path) 19 | 20 | path = osp.join(this_dir, 'model/') 21 | add_path(path) 22 | 23 | path = osp.join(this_dir, 'utils/') 24 | add_path(path) 25 | -------------------------------------------------------------------------------- /config/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/avilash/pytorch-siamese-triplet/c4f88b4f84cff02792118e90aabd8eedff603c2e/config/__init__.py -------------------------------------------------------------------------------- /config/base_config.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict as edict 2 | import numpy as np 3 | __C = edict() 4 | 5 | cfg = __C 6 | 7 | __C.RESNET = edict() 8 | __C.RESNET.FIXED_BLOCKS = 1 9 | 10 | __C.DATASETS = edict() 11 | __C.DATASETS.VGGFACE2 = edict() 12 | __C.DATASETS.VGGFACE2.HOME = "" 13 | __C.DATASETS.CUSTOM = edict() 14 | __C.DATASETS.CUSTOM.HOME = "" 15 | 16 | 17 | def _merge_a_into_b(a, b): 18 | """Merge config dictionary a into config dictionary b, clobbering the 19 | options in b whenever they are also specified in a. 20 | """ 21 | if type(a) is not edict: 22 | return 23 | 24 | for k, v in a.items(): 25 | # a must specify keys that are in b 26 | if k not in b: 27 | raise KeyError('{} is not a valid config key'.format(k)) 28 | 29 | # the types must match, too 30 | old_type = type(b[k]) 31 | if old_type is not type(v): 32 | if isinstance(b[k], np.ndarray): 33 | v = np.array(v, dtype=b[k].dtype) 34 | else: 35 | raise ValueError(('Type mismatch ({} vs. {}) ' 36 | 'for config key: {}').format(type(b[k]), 37 | type(v), k)) 38 | 39 | # recursively merge dicts 40 | if type(v) is edict: 41 | try: 42 | _merge_a_into_b(a[k], b[k]) 43 | except Exception: 44 | print('Error under config key: {}'.format(k)) 45 | raise 46 | else: 47 | b[k] = v 48 | 49 | 50 | def cfg_from_file(filename): 51 | """Load a config file and merge it into the default options.""" 52 | import yaml 53 | with open(filename, 'r') as f: 54 | yaml_cfg = edict(yaml.load(f, Loader=yaml.FullLoader)) 55 | 56 | _merge_a_into_b(yaml_cfg, __C) 57 | 58 | 59 | def cfg_from_list(cfg_list): 60 | """Set config keys via list (e.g., from command line).""" 61 | from ast import literal_eval 62 | assert len(cfg_list) % 2 == 0 63 | for k, v in zip(cfg_list[0::2], cfg_list[1::2]): 64 | key_list = k.split('.') 65 | d = __C 66 | for subkey in key_list[:-1]: 67 | assert subkey in d 68 | d = d[subkey] 69 | subkey = key_list[-1] 70 | assert subkey in d 71 | try: 72 | value = literal_eval(v) 73 | except Exception: 74 | # handle the case when v is a string literal 75 | value = v 76 | assert type(value) == type(d[subkey]), \ 77 | 'type {} does not match original type {}'.format( 78 | type(value), type(d[subkey])) 79 | d[subkey] = value 80 | -------------------------------------------------------------------------------- /config/test.yaml: -------------------------------------------------------------------------------- 1 | RESNET: 2 | FIXED_BLOCKS: 2 3 | 4 | DATASETS: 5 | VGGFACE2: 6 | HOME: "" 7 | CUSTOM: 8 | HOME: "" 9 | 10 | -------------------------------------------------------------------------------- /dataloader/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/avilash/pytorch-siamese-triplet/c4f88b4f84cff02792118e90aabd8eedff603c2e/dataloader/__init__.py --------------------------------------------------------------------------------
/dataloader/base_dset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | 4 | 5 | class BaseDset(object): 6 | 7 | def __init__(self): 8 | self.__base_path = "" 9 | 10 | self.__train_set = {} 11 | self.__test_set = {} 12 | self.__train_keys = [] 13 | self.__test_keys = [] 14 | 15 | def load(self, base_path): 16 | self.__base_path = base_path 17 | train_dir = os.path.join(self.__base_path, 'train') 18 | test_dir = os.path.join(self.__base_path, 'test') 19 | 20 | self.__train_set = {} 21 | self.__test_set = {} 22 | self.__train_keys = [] 23 | self.__test_keys = [] 24 | 25 | for class_id in os.listdir(train_dir): 26 | class_dir = os.path.join(train_dir, class_id) 27 | self.__train_set[class_id] = [] 28 | self.__train_keys.append(class_id) 29 | for img_name in os.listdir(class_dir): 30 | img_path = os.path.join(class_dir, img_name) 31 | self.__train_set[class_id].append(img_path) 32 | 33 | for class_id in os.listdir(test_dir): 34 | class_dir = os.path.join(test_dir, class_id) 35 | self.__test_set[class_id] = [] 36 | self.__test_keys.append(class_id) 37 | for img_name in os.listdir(class_dir): 38 | img_path = os.path.join(class_dir, img_name) 39 | self.__test_set[class_id].append(img_path) 40 | 41 | return len(self.__train_keys), len(self.__test_keys) 42 | 43 | def getTriplet(self, split='train'): 44 | if split == 'train': 45 | dataset = self.__train_set 46 | keys = self.__train_keys 47 | else: 48 | dataset = self.__test_set 49 | keys = self.__test_keys 50 | 51 | pos_idx = 0 52 | neg_idx = 0 53 | pos_anchor_img_idx = 0 54 | pos_img_idx = 0 55 | neg_img_idx = 0 56 | 57 | pos_idx = random.randint(0, len(keys) - 1)  # pick the positive class at random 58 | while True:  # pick a different class as the negative 59 | neg_idx = random.randint(0, len(keys) - 1) 60 | if pos_idx != neg_idx: 61 | break 62 | 63 | pos_anchor_img_idx = random.randint(0, len(dataset[keys[pos_idx]]) - 1)  # anchor image from the positive class 64 | while True:  # pick a positive image distinct from the anchor 65 | pos_img_idx = random.randint(0, len(dataset[keys[pos_idx]]) - 1) 66 | if pos_anchor_img_idx != pos_img_idx: 67 | break 68 | 69 | neg_img_idx = random.randint(0, len(dataset[keys[neg_idx]]) - 1)  # any image from the negative class 70 | 71 | pos_anchor_img = dataset[keys[pos_idx]][pos_anchor_img_idx] 72 | pos_img = dataset[keys[pos_idx]][pos_img_idx] 73 | neg_img = dataset[keys[neg_idx]][neg_img_idx] 74 | 75 | return pos_anchor_img, pos_img, neg_img 76 | -------------------------------------------------------------------------------- /dataloader/custom_dset.py: -------------------------------------------------------------------------------- 1 | from config.base_config import cfg 2 | from dataloader.base_dset import BaseDset 3 | 4 | 5 | class Custom(BaseDset): 6 | 7 | def __init__(self): 8 | super(Custom, self).__init__() 9 | 10 | def load(self): 11 | base_path = cfg.DATASETS.CUSTOM.HOME 12 | return super(Custom, self).load(base_path) 13 | --------------------------------------------------------------------------------
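`custom_dset.py` (and `vggface2.py` below) show the intended extension pattern: subclass `BaseDset` and point `load()` at a config-driven root. A hypothetical sketch for wiring in a new dataset — note the `MYSET` key does not exist in `base_config.py`; you would add it there alongside `CUSTOM`:

```python
from config.base_config import cfg
from dataloader.base_dset import BaseDset


class MySet(BaseDset):
    """Hypothetical dataset rooted at cfg.DATASETS.MYSET.HOME."""

    def load(self):
        # Expects <HOME>/train/<class>/*.jpg and <HOME>/test/<class>/*.jpg,
        # the same layout described in the README.
        return super(MySet, self).load(cfg.DATASETS.MYSET.HOME)
```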
/dataloader/mnist.py: -------------------------------------------------------------------------------- 1 | import random 2 | import numpy as np 3 | 4 | 5 | class MNIST_DS(object): 6 | 7 | def __init__(self, train_dataset, test_dataset): 8 | self.__train_labels_idx_map = {} 9 | self.__test_labels_idx_map = {} 10 | 11 | self.__train_data = train_dataset.data 12 | self.__test_data = test_dataset.data 13 | self.__train_labels = train_dataset.targets 14 | self.__test_labels = test_dataset.targets 15 | 16 | self.__train_labels_np = self.__train_labels.numpy() 17 | self.__train_unique_labels = np.unique(self.__train_labels_np) 18 | 19 | self.__test_labels_np = self.__test_labels.numpy() 20 | self.__test_unique_labels = np.unique(self.__test_labels_np) 21 | 22 | def load(self): 23 | self.__train_labels_idx_map = {} 24 | for label in self.__train_unique_labels: 25 | self.__train_labels_idx_map[label] = np.where(self.__train_labels_np == label)[0] 26 | 27 | self.__test_labels_idx_map = {} 28 | for label in self.__test_unique_labels: 29 | self.__test_labels_idx_map[label] = np.where(self.__test_labels_np == label)[0] 30 | 31 | def getTriplet(self, split="train"): 32 | pos_label = 0 33 | neg_label = 0 34 | label_idx_map = None 35 | data = None 36 | 37 | if split == 'train': 38 | pos_label = self.__train_unique_labels[random.randint(0, len(self.__train_unique_labels) - 1)] 39 | neg_label = pos_label 40 | while neg_label == pos_label:  # '==', not 'is': these are numpy scalars 41 | neg_label = self.__train_unique_labels[random.randint(0, len(self.__train_unique_labels) - 1)] 42 | label_idx_map = self.__train_labels_idx_map 43 | data = self.__train_data 44 | else: 45 | pos_label = self.__test_unique_labels[random.randint(0, len(self.__test_unique_labels) - 1)] 46 | neg_label = pos_label 47 | while neg_label == pos_label: 48 | neg_label = self.__test_unique_labels[random.randint(0, len(self.__test_unique_labels) - 1)] 49 | label_idx_map = self.__test_labels_idx_map 50 | data = self.__test_data 51 | 52 | pos_label_idx_map = label_idx_map[pos_label] 53 | pos_img_anchor_idx = pos_label_idx_map[random.randint(0, len(pos_label_idx_map) - 1)] 54 | pos_img_idx = pos_img_anchor_idx 55 | while pos_img_idx == pos_img_anchor_idx: 56 | pos_img_idx = pos_label_idx_map[random.randint(0, len(pos_label_idx_map) - 1)] 57 | 58 | neg_label_idx_map = label_idx_map[neg_label] 59 | neg_img_idx = neg_label_idx_map[random.randint(0, len(neg_label_idx_map) - 1)] 60 | 61 | pos_anchor_img = data[pos_img_anchor_idx].numpy() 62 | pos_img = data[pos_img_idx].numpy() 63 | neg_img = data[neg_img_idx].numpy() 64 | 65 | return pos_anchor_img, pos_img, neg_img 66 | -------------------------------------------------------------------------------- /dataloader/triplet_img_loader.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import numpy as np 4 | 5 | import torch.utils.data 6 | from torchvision import transforms 7 | from torchvision.datasets import MNIST, FashionMNIST 8 | 9 | from dataloader import mnist, vggface2, custom_dset 10 | 11 | 12 | class BaseLoader(torch.utils.data.Dataset): 13 | def __init__(self, triplets, transform=None): 14 | self.triplets = triplets 15 | self.transform = transform 16 | 17 | def __getitem__(self, index): 18 | img1_pth, img2_pth, img3_pth = self.triplets[index] 19 | img1 = cv2.imread(img1_pth) 20 | img2 = cv2.imread(img2_pth) 21 | img3 = cv2.imread(img3_pth) 22 | 23 | try: 24 | img1 = cv2.resize(img1, (228, 228)) 25 | except Exception: 26 | img1 = np.zeros((228, 228, 3), dtype=np.uint8)  # unreadable image -> blank placeholder 27 | 28 | try: 29 | img2 = cv2.resize(img2, (228, 228)) 30 | except Exception: 31 | img2 = np.zeros((228, 228, 3), dtype=np.uint8) 32 | 33 | try: 34 | img3 = cv2.resize(img3, (228, 228)) 35 | except Exception: 36 | img3 = np.zeros((228, 228, 3), dtype=np.uint8) 37 | 38 | if self.transform is not None: 39 | img1 = self.transform(img1) 40 | img2 = self.transform(img2) 41 | img3 = self.transform(img3) 42 | 43 | return img1, img2, img3 44 | 45 | def __len__(self): 46 | return len(self.triplets) 47 | 48 | 49 | class TripletMNISTLoader(BaseLoader): 50 | def __init__(self, triplets, transform=None): 51 | super(TripletMNISTLoader, self).__init__(triplets, transform=transform) 52 | 53 | def __getitem__(self, index): 54 | img1, img2, img3 = self.triplets[index] 55 | img1 = np.expand_dims(img1, axis=2) 56 | img2 = np.expand_dims(img2, axis=2) 57 | img3 = np.expand_dims(img3, axis=2) 58 | 59 | if self.transform is not None: 60 | img1 = self.transform(img1) 61 | img2 = self.transform(img2) 62 | img3 = self.transform(img3) 63 | 64 | return img1, img2, img3 65 | 66 | 67 | def get_loader(args): 68 | train_data_loader = None 69 | test_data_loader = None 70 | 71 | kwargs = {'num_workers': 1, 'pin_memory': True} if args.cuda else {} 72 | 73 | train_triplets = [] 74 | test_triplets = [] 75 | 76 | dset_obj = None 77 | loader = BaseLoader 78 | means = (0.485, 0.456, 0.406) 79 | stds = (0.229, 0.224, 0.225) 80 | 81 | if args.dataset == 'vggface2': 82 | dset_obj = vggface2.VGGFace2() 83 | elif args.dataset == 'custom': 84 | dset_obj = custom_dset.Custom() 85 | elif (args.dataset == 'mnist') or (args.dataset == 'fmnist'): 86 | train_dataset, test_dataset = None, None 87 | if args.dataset == 'mnist': 88 | train_dataset = MNIST(os.path.join(args.result_dir, "MNIST"), train=True, download=True) 89 | test_dataset = MNIST(os.path.join(args.result_dir, "MNIST"), train=False, download=True) 90 | if args.dataset == 'fmnist': 91 | train_dataset = FashionMNIST(os.path.join(args.result_dir, "FashionMNIST"), train=True, download=True) 92 | test_dataset = FashionMNIST(os.path.join(args.result_dir, "FashionMNIST"), train=False, download=True) 93 | dset_obj = mnist.MNIST_DS(train_dataset, test_dataset) 94 | loader = TripletMNISTLoader 95 | means = (0.485,) 96 | stds = (0.229,) 97 | 98 | dset_obj.load() 99 | for i in range(args.num_train_samples): 100 | pos_anchor_img, pos_img, neg_img = dset_obj.getTriplet() 101 | train_triplets.append([pos_anchor_img, pos_img, neg_img]) 102 | for i in range(args.num_test_samples): 103 | pos_anchor_img, pos_img, neg_img = dset_obj.getTriplet(split='test') 104 | test_triplets.append([pos_anchor_img, pos_img, neg_img]) 105 | 106 | train_data_loader = torch.utils.data.DataLoader( 107 | loader(train_triplets, 108 | transform=transforms.Compose([ 109 | transforms.ToTensor(), 110 | transforms.Normalize(means, stds) 111 | ])), 112 | batch_size=args.batch_size, shuffle=True, **kwargs) 113 | test_data_loader = torch.utils.data.DataLoader( 114 | loader(test_triplets, 115 | transform=transforms.Compose([ 116 | transforms.ToTensor(), 117 | transforms.Normalize(means, stds) 118 | ])), 119 | batch_size=args.batch_size, shuffle=True, **kwargs) 120 | 121 | return train_data_loader, test_data_loader 122 | --------------------------------------------------------------------------------
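A hedged sketch of calling `get_loader` directly, using a `SimpleNamespace` as a stand-in for the argparse namespace that `train.py` normally supplies (sample counts here are small illustrative values):

```python
from types import SimpleNamespace
from dataloader.triplet_img_loader import get_loader

# Hypothetical stand-in for train.py's parsed arguments
args = SimpleNamespace(cuda=False, dataset='mnist', result_dir='data',
                       batch_size=64, num_train_samples=1000, num_test_samples=200)

train_loader, test_loader = get_loader(args)

# Each batch is a list of three tensors: anchor, positive, negative,
# each of shape (batch_size, 1, 28, 28) for MNIST
anchor, pos, neg = next(iter(train_loader))
print(anchor.shape, pos.shape, neg.shape)
```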
/dataloader/vggface2.py: -------------------------------------------------------------------------------- 1 | from config.base_config import cfg 2 | from dataloader.base_dset import BaseDset 3 | 4 | 5 | class VGGFace2(BaseDset): 6 | 7 | def __init__(self): 8 | super(VGGFace2, self).__init__() 9 | 10 | def load(self): 11 | base_path = cfg.DATASETS.VGGFACE2.HOME 12 | return super(VGGFace2, self).load(base_path) 13 | -------------------------------------------------------------------------------- /images/tSNE_mnist.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/avilash/pytorch-siamese-triplet/c4f88b4f84cff02792118e90aabd8eedff603c2e/images/tSNE_mnist.jpg -------------------------------------------------------------------------------- /model/__init__.py: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/avilash/pytorch-siamese-triplet/c4f88b4f84cff02792118e90aabd8eedff603c2e/model/__init__.py -------------------------------------------------------------------------------- /model/embedding.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | from model.resnet import resnet50 4 | from config.base_config import cfg 5 | 6 | 7 | class EmbeddingResnet(nn.Module): 8 | def __init__(self): 9 | super(EmbeddingResnet, self).__init__() 10 | 11 | resnet = resnet50(pretrained=True) 12 | self.features = nn.Sequential(resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool, resnet.layer1, 13 | resnet.layer2, resnet.layer3, resnet.layer4, resnet.avgpool) 14 | # Fix blocks 15 | for p in self.features[0].parameters(): 16 | p.requires_grad = False 17 | for p in self.features[1].parameters(): 18 | p.requires_grad = False 19 | if cfg.RESNET.FIXED_BLOCKS >= 3: 20 | for p in self.features[6].parameters(): 21 | p.requires_grad = False 22 | if cfg.RESNET.FIXED_BLOCKS >= 2: 23 | for p in self.features[5].parameters(): 24 | p.requires_grad = False 25 | if cfg.RESNET.FIXED_BLOCKS >= 1: 26 | for p in self.features[4].parameters(): 27 | p.requires_grad = False 28 | 29 | def set_bn_fix(m): 30 | classname = m.__class__.__name__ 31 | if classname.find('BatchNorm') != -1: 32 | for p in m.parameters(): p.requires_grad = False 33 | 34 | self.features.apply(set_bn_fix) 35 | 36 | def forward(self, x): 37 | features = self.features.forward(x) 38 | features = features.view(features.size(0), -1) 39 | features = F.normalize(features, p=2, dim=1) 40 | return features 41 | 42 | 43 | class EmbeddingLeNet(nn.Module): 44 | def __init__(self): 45 | super(EmbeddingLeNet, self).__init__() 46 | 47 | self.convnet = nn.Sequential(nn.Conv2d(1, 32, 5), nn.PReLU(), 48 | nn.MaxPool2d(2, stride=2), 49 | nn.Conv2d(32, 64, 5), nn.PReLU(), 50 | nn.MaxPool2d(2, stride=2)) 51 | 52 | self.fc = nn.Sequential(nn.Linear(64 * 4 * 4, 256), 53 | nn.PReLU(), 54 | nn.Linear(256, 128), 55 | nn.PReLU(), 56 | nn.Linear(128, 64) 57 | ) 58 | 59 | def forward(self, x): 60 | output = self.convnet(x) 61 | output = output.view(output.size()[0], -1) 62 | output = self.fc(output) 63 | output = F.normalize(output, p=2, dim=1) 64 | return output 65 | -------------------------------------------------------------------------------- /model/net.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.nn as nn 4 | 5 | from model import embedding 6 | 7 | 8 | class TripletNet(nn.Module): 9 | def __init__(self, embeddingNet): 10 | super(TripletNet, self).__init__() 11 | self.embeddingNet = embeddingNet 12 | 13 | def forward(self, i1, i2, i3): 14 | E1 = self.embeddingNet(i1) 15 | E2 = self.embeddingNet(i2) 16 | E3 = self.embeddingNet(i3) 17 | return E1, E2, E3 18 | 19 | 20 | def get_model(args, device): 21 | # Model 22 | embeddingNet = None 23 | if (args.dataset == 'custom') or (args.dataset == 'vggface2'): 24 | embeddingNet = embedding.EmbeddingResnet() 25 | elif (args.dataset == 'mnist') or (args.dataset == 'fmnist'): 26 | embeddingNet = embedding.EmbeddingLeNet() 27 | else: 28 | print("Dataset %s not supported " % args.dataset) 29 | return None 30 | 31 | model = TripletNet(embeddingNet) 32 | model = nn.DataParallel(model, device_ids=args.gpu_devices) 33 | model = model.to(device) 34 | 35 | # Load weights if provided 36 | if args.ckp: 37 | if os.path.isfile(args.ckp): 38 | 
print("=> Loading checkpoint '{}'".format(args.ckp)) 39 | checkpoint = torch.load(args.ckp) 40 | model.load_state_dict(checkpoint['state_dict']) 41 | print("=> Loaded checkpoint '{}'".format(args.ckp)) 42 | else: 43 | print("=> No checkpoint found at '{}'".format(args.ckp)) 44 | 45 | return model 46 | -------------------------------------------------------------------------------- /model/resnet.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.utils.model_zoo as model_zoo 3 | 4 | 5 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 6 | 'resnet152'] 7 | 8 | 9 | model_urls = { 10 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 11 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 12 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 13 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 14 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 15 | } 16 | 17 | 18 | def conv3x3(in_planes, out_planes, stride=1): 19 | """3x3 convolution with padding""" 20 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 21 | padding=1, bias=False) 22 | 23 | 24 | class BasicBlock(nn.Module): 25 | expansion = 1 26 | 27 | def __init__(self, inplanes, planes, stride=1, downsample=None): 28 | super(BasicBlock, self).__init__() 29 | self.conv1 = conv3x3(inplanes, planes, stride) 30 | self.bn1 = nn.BatchNorm2d(planes) 31 | self.relu = nn.ReLU(inplace=True) 32 | self.conv2 = conv3x3(planes, planes) 33 | self.bn2 = nn.BatchNorm2d(planes) 34 | self.downsample = downsample 35 | self.stride = stride 36 | 37 | def forward(self, x): 38 | residual = x 39 | 40 | out = self.conv1(x) 41 | out = self.bn1(out) 42 | out = self.relu(out) 43 | 44 | out = self.conv2(out) 45 | out = self.bn2(out) 46 | 47 | if self.downsample is not None: 48 | residual = self.downsample(x) 49 | 50 | out += residual 51 | out = self.relu(out) 52 | 53 | return out 54 | 55 | 56 | class Bottleneck(nn.Module): 57 | expansion = 4 58 | 59 | def __init__(self, inplanes, planes, stride=1, downsample=None): 60 | super(Bottleneck, self).__init__() 61 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 62 | self.bn1 = nn.BatchNorm2d(planes) 63 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 64 | padding=1, bias=False) 65 | self.bn2 = nn.BatchNorm2d(planes) 66 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False) 67 | self.bn3 = nn.BatchNorm2d(planes * self.expansion) 68 | self.relu = nn.ReLU(inplace=True) 69 | self.downsample = downsample 70 | self.stride = stride 71 | 72 | def forward(self, x): 73 | residual = x 74 | 75 | out = self.conv1(x) 76 | out = self.bn1(out) 77 | out = self.relu(out) 78 | 79 | out = self.conv2(out) 80 | out = self.bn2(out) 81 | out = self.relu(out) 82 | 83 | out = self.conv3(out) 84 | out = self.bn3(out) 85 | 86 | if self.downsample is not None: 87 | residual = self.downsample(x) 88 | 89 | out += residual 90 | out = self.relu(out) 91 | 92 | return out 93 | 94 | 95 | class ResNet(nn.Module): 96 | 97 | def __init__(self, block, layers, num_classes=1000): 98 | self.inplanes = 64 99 | super(ResNet, self).__init__() 100 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, 101 | bias=False) 102 | self.bn1 = nn.BatchNorm2d(64) 103 | self.relu = nn.ReLU(inplace=True) 104 | self.maxpool = nn.MaxPool2d(kernel_size=3, 
stride=2, padding=1) 105 | self.layer1 = self._make_layer(block, 64, layers[0]) 106 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 107 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 108 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) 109 | self.avgpool = nn.AvgPool2d(7, stride=1) 110 | self.fc = nn.Linear(512 * block.expansion, num_classes) 111 | 112 | for m in self.modules(): 113 | if isinstance(m, nn.Conv2d): 114 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 115 | elif isinstance(m, nn.BatchNorm2d): 116 | nn.init.constant_(m.weight, 1) 117 | nn.init.constant_(m.bias, 0) 118 | 119 | def _make_layer(self, block, planes, blocks, stride=1): 120 | downsample = None 121 | if stride != 1 or self.inplanes != planes * block.expansion: 122 | downsample = nn.Sequential( 123 | nn.Conv2d(self.inplanes, planes * block.expansion, 124 | kernel_size=1, stride=stride, bias=False), 125 | nn.BatchNorm2d(planes * block.expansion), 126 | ) 127 | 128 | layers = list() 129 | layers.append(block(self.inplanes, planes, stride, downsample)) 130 | self.inplanes = planes * block.expansion 131 | for i in range(1, blocks): 132 | layers.append(block(self.inplanes, planes)) 133 | 134 | return nn.Sequential(*layers) 135 | 136 | def forward(self, x): 137 | x = self.conv1(x) 138 | x = self.bn1(x) 139 | x = self.relu(x) 140 | x = self.maxpool(x) 141 | 142 | x = self.layer1(x) 143 | x = self.layer2(x) 144 | x = self.layer3(x) 145 | x = self.layer4(x) 146 | 147 | x = self.avgpool(x) 148 | x = x.view(x.size(0), -1) 149 | x = self.fc(x) 150 | 151 | return x 152 | 153 | 154 | def resnet18(pretrained=False, **kwargs): 155 | """Constructs a ResNet-18 model. 156 | 157 | Args: 158 | pretrained (bool): If True, returns a model pre-trained on ImageNet 159 | """ 160 | model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs) 161 | if pretrained: 162 | model.load_state_dict(model_zoo.load_url(model_urls['resnet18'])) 163 | return model 164 | 165 | 166 | def resnet34(pretrained=False, **kwargs): 167 | """Constructs a ResNet-34 model. 168 | 169 | Args: 170 | pretrained (bool): If True, returns a model pre-trained on ImageNet 171 | """ 172 | model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs) 173 | if pretrained: 174 | model.load_state_dict(model_zoo.load_url(model_urls['resnet34'])) 175 | return model 176 | 177 | 178 | def resnet50(pretrained=False, **kwargs): 179 | """Constructs a ResNet-50 model. 180 | 181 | Args: 182 | pretrained (bool): If True, returns a model pre-trained on ImageNet 183 | """ 184 | model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) 185 | if pretrained: 186 | model.load_state_dict(model_zoo.load_url(model_urls['resnet50'])) 187 | return model 188 | 189 | 190 | def resnet101(pretrained=False, **kwargs): 191 | """Constructs a ResNet-101 model. 192 | 193 | Args: 194 | pretrained (bool): If True, returns a model pre-trained on ImageNet 195 | """ 196 | model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs) 197 | if pretrained: 198 | model.load_state_dict(model_zoo.load_url(model_urls['resnet101'])) 199 | return model 200 | 201 | 202 | def resnet152(pretrained=False, **kwargs): 203 | """Constructs a ResNet-152 model. 
204 | 205 | Args: 206 | pretrained (bool): If True, returns a model pre-trained on ImageNet 207 | """ 208 | model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs) 209 | if pretrained: 210 | model.load_state_dict(model_zoo.load_url(model_urls['resnet152'])) 211 | return model 212 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | easydict 2 | matplotlib 3 | numpy 4 | opencv-python 5 | Pillow 6 | pyaml 7 | scikit-learn 8 | scipy 9 | tqdm 10 | torch 11 | torchvision -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import _init_paths 2 | import os 3 | import argparse 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | import torch.optim as optim 9 | 10 | from torch.autograd import Variable 11 | import torch.backends.cudnn as cudnn 12 | 13 | from model.net import get_model 14 | from dataloader.triplet_img_loader import get_loader 15 | from utils.gen_utils import make_dir_if_not_exist 16 | from utils.vis_utils import vis_with_paths, vis_with_paths_and_bboxes 17 | 18 | from config.base_config import cfg, cfg_from_file 19 | 20 | 21 | def main(): 22 | torch.manual_seed(1) 23 | if args.cuda: 24 | torch.cuda.manual_seed(1) 25 | cudnn.benchmark = True 26 | 27 | exp_dir = os.path.join(args.result_dir, args.exp_name) 28 | make_dir_if_not_exist(exp_dir) 29 | 30 | # Build Model 31 | model = get_model(args, device) 32 | if model is None: 33 | return 34 | 35 | # Criterion and Optimizer 36 | params = [] 37 | for key, value in dict(model.named_parameters()).items(): 38 | if value.requires_grad: 39 | params += [{'params': [value]}] 40 | criterion = torch.nn.MarginRankingLoss(margin=args.margin) 41 | optimizer = optim.Adam(params, lr=args.lr) 42 | 43 | # Train Test Loop 44 | for epoch in range(1, args.epochs + 1): 45 | # Init data loaders (triplets are re-sampled every epoch) 46 | train_data_loader, test_data_loader = get_loader(args) 47 | # Evaluate on the test split, then train for one epoch 48 | test(test_data_loader, model, criterion) 49 | train(train_data_loader, model, criterion, optimizer, epoch) 50 | # Save model 51 | model_to_save = { 52 | "epoch": epoch + 1, 53 | 'state_dict': model.state_dict(), 54 | } 55 | if epoch % args.ckp_freq == 0: 56 | file_name = os.path.join(exp_dir, "checkpoint_" + str(epoch) + ".pth") 57 | save_checkpoint(model_to_save, file_name) 58 | 59 | 60 | def train(data, model, criterion, optimizer, epoch): 61 | print("******** Training ********") 62 | total_loss = 0 63 | model.train() 64 | for batch_idx, img_triplet in enumerate(data): 65 | anchor_img, pos_img, neg_img = img_triplet 66 | anchor_img, pos_img, neg_img = anchor_img.to(device), pos_img.to(device), neg_img.to(device) 67 | anchor_img, pos_img, neg_img = Variable(anchor_img), Variable(pos_img), Variable(neg_img) 68 | E1, E2, E3 = model(anchor_img, pos_img, neg_img) 69 | dist_E1_E2 = F.pairwise_distance(E1, E2, 2) 70 | dist_E1_E3 = F.pairwise_distance(E1, E3, 2) 71 | 72 | target = torch.FloatTensor(dist_E1_E2.size()).fill_(-1) 73 | target = target.to(device) 74 | target = Variable(target) 75 | loss = criterion(dist_E1_E2, dist_E1_E3, target) 76 | total_loss += loss.item()  # .item() detaches the value; accumulating the tensor would retain graphs 77 | 78 | optimizer.zero_grad() 79 | loss.backward() 80 | optimizer.step() 81 | 82 | log_step = args.train_log_step 83 | if (batch_idx % log_step == 0) and (batch_idx != 0): 84 | print('Train Epoch: {} [{}/{}] \t Loss: {:.4f}'.format(epoch, batch_idx, len(data), total_loss / log_step)) 85 | total_loss = 0 86 | print("****************") 87 | 88 | 89 | def test(data, model, criterion): 90 | print("******** Testing ********") 91 | with torch.no_grad(): 92 | model.eval() 93 | accuracies = [0, 0, 0] 94 | acc_threshes = [0, 0.2, 0.5] 95 | total_loss = 0 96 | for batch_idx, img_triplet in enumerate(data): 97 | anchor_img, pos_img, neg_img = img_triplet 98 | anchor_img, pos_img, neg_img = anchor_img.to(device), pos_img.to(device), neg_img.to(device) 99 | anchor_img, pos_img, neg_img = Variable(anchor_img), Variable(pos_img), Variable(neg_img) 100 | E1, E2, E3 = model(anchor_img, pos_img, neg_img) 101 | dist_E1_E2 = F.pairwise_distance(E1, E2, 2) 102 | dist_E1_E3 = F.pairwise_distance(E1, E3, 2) 103 | 104 | target = torch.FloatTensor(dist_E1_E2.size()).fill_(-1) 105 | target = target.to(device) 106 | target = Variable(target) 107 | 108 | loss = criterion(dist_E1_E2, dist_E1_E3, target) 109 | total_loss += loss.item() 110 | 111 | for i in range(len(accuracies)): 112 | prediction = (dist_E1_E3 - dist_E1_E2 - args.margin * acc_threshes[i]).cpu().data 113 | prediction = prediction.view(prediction.numel()) 114 | prediction = (prediction > 0).float() 115 | batch_acc = prediction.sum() * 1.0 / prediction.numel() 116 | accuracies[i] += batch_acc 117 | print('Test Loss: {}'.format(total_loss / len(data))) 118 | for i in range(len(accuracies)): 119 | print( 120 | 'Test Accuracy with diff = {}% of margin: {}'.format(acc_threshes[i] * 100, accuracies[i] / len(data))) 121 | print("****************") 122 |
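`test()` counts a triplet as correct when the anchor-negative distance beats the anchor-positive distance by a given fraction of the margin. A small worked sketch of that predicate (toy numbers, not from a real run):

```python
import torch

margin = 1.0
d_ap = torch.tensor([0.3, 0.9, 0.5])   # anchor-positive distances (toy values)
d_an = torch.tensor([0.8, 0.7, 1.6])   # anchor-negative distances (toy values)

for thresh in (0, 0.2, 0.5):           # the acc_threshes used in test()
    correct = (d_an - d_ap - margin * thresh) > 0
    print(thresh, correct.float().mean().item())
# thresh=0   -> 2/3 correct (the second triplet has d_an < d_ap)
# thresh=0.5 -> 1/3 correct (only the third triplet clears a 0.5-margin gap)
```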
123 | 124 | def save_checkpoint(state, file_name): 125 | torch.save(state, file_name) 126 | 127 | 128 | if __name__ == '__main__': 129 | parser = argparse.ArgumentParser(description='PyTorch Siamese Example') 130 | parser.add_argument('--result_dir', default='data', type=str, 131 | help='Directory to store results') 132 | parser.add_argument('--exp_name', default='exp0', type=str, 133 | help='name of experiment') 134 | parser.add_argument('--cuda', action='store_true', default=False, 135 | help='enables CUDA training') 136 | parser.add_argument("--gpu_devices", type=int, nargs='+', default=None, 137 | help="List of GPU Devices to train on") 138 | parser.add_argument('--epochs', type=int, default=10, metavar='N', 139 | help='number of epochs to train (default: 10)') 140 | parser.add_argument('--ckp_freq', type=int, default=1, metavar='N', 141 | help='Checkpoint Frequency (default: 1)') 142 | parser.add_argument('--batch_size', type=int, default=64, metavar='N', 143 | help='input batch size for training (default: 64)') 144 | parser.add_argument('--lr', type=float, default=0.0001, metavar='LR', 145 | help='learning rate (default: 0.0001)') 146 | parser.add_argument('--margin', type=float, default=1.0, metavar='M', 147 | help='margin for triplet loss (default: 1.0)') 148 | parser.add_argument('--ckp', default=None, type=str, 149 | help='path to load checkpoint') 150 | 151 | parser.add_argument('--dataset', type=str, default='mnist', metavar='M', 152 | help='Dataset (default: mnist)') 153 | 154 | parser.add_argument('--num_train_samples', type=int, default=50000, metavar='M', 155 | help='number of training samples (default: 50000)') 156 | parser.add_argument('--num_test_samples', type=int, default=10000, metavar='M', 157 | help='number of test samples (default: 10000)') 158 | 159 | parser.add_argument('--train_log_step', type=int, default=100, metavar='M', 160 | help='Number of iterations after which to log the loss') 161 | 162 | global args, device 163 | args = parser.parse_args() 164 | args.cuda = args.cuda and torch.cuda.is_available() 165 | cfg_from_file("config/test.yaml") 166 | 167 | if args.cuda: 168 | device = 'cuda' 169 | if args.gpu_devices is None: 170 | args.gpu_devices = [0] 171 | else: 172 | device = 'cpu' 173 | main() 174 | --------------------------------------------------------------------------------
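Checkpoints saved above are dicts with `epoch` and `state_dict` keys, and the state-dict keys carry the `nn.DataParallel` prefix (`module.embeddingNet....`). `tsne.py` below strips that prefix before loading the bare embedding net; a condensed sketch of the same idea (the checkpoint path is hypothetical):

```python
import torch
from model import embedding

ckpt = torch.load('results/MNIST_exp1/checkpoint_10.pth', map_location='cpu')  # hypothetical path
print(ckpt['epoch'])

# Keys look like 'module.embeddingNet.convnet.0.weight'; drop the first two parts
state = {'.'.join(k.split('.')[2:]): v for k, v in ckpt['state_dict'].items()}

model = embedding.EmbeddingLeNet()  # the MNIST / FashionMNIST embedding net
model.load_state_dict(state)
model.eval()
```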
/tsne.py: -------------------------------------------------------------------------------- 1 | import _init_paths 2 | import os 3 | import argparse 4 | import pickle 5 | from tqdm import tqdm 6 | 7 | import torch 8 | from torchvision import datasets, transforms 9 | from torchvision.datasets import MNIST, FashionMNIST 10 | from torch.autograd import Variable 11 | import torch.backends.cudnn as cudnn 12 | 13 | from model import net, embedding 14 | 15 | from utils.gen_utils import make_dir_if_not_exist 16 | 17 | from config.base_config import cfg, cfg_from_file 18 | 19 | import cv2 20 | import numpy as np 21 | from sklearn.manifold import TSNE 22 | import matplotlib as mpl 23 | import matplotlib.pyplot as plt 24 | 25 | 26 | def main(): 27 | torch.manual_seed(1) 28 | if args.cuda: 29 | torch.cuda.manual_seed(1) 30 | 31 | exp_dir = os.path.join("data", args.exp_name) 32 | make_dir_if_not_exist(exp_dir) 33 | 34 | if args.pkl is not None: 35 | input_file = open(args.pkl, 'rb') 36 | final_data = pickle.load(input_file) 37 | input_file.close() 38 | embeddings = final_data['embeddings'] 39 | labels = final_data['labels'] 40 | vis_tSNE(embeddings, labels) 41 | else: 42 | embeddingNet = None 43 | if (args.dataset == 'custom') or (args.dataset == 'vggface2'): 44 | embeddingNet = embedding.EmbeddingResnet() 45 | elif (args.dataset == 'mnist') or (args.dataset == 'fmnist'): 46 | embeddingNet = embedding.EmbeddingLeNet() 47 | else: 48 | print("Dataset {} not supported ".format(args.dataset)) 49 | return 50 | 51 | model_dict = None 52 | if args.ckp is not None: 53 | if os.path.isfile(args.ckp): 54 | print("=> Loading checkpoint '{}'".format(args.ckp)) 55 | try: 56 | model_dict = torch.load(args.ckp)['state_dict'] 57 | except Exception: 58 | model_dict = torch.load(args.ckp, map_location='cpu')['state_dict'] 59 | print("=> Loaded checkpoint '{}'".format(args.ckp)) 60 | else: 61 | print("=> No checkpoint found at '{}'".format(args.ckp)) 62 | return 63 | else: 64 | print("Please specify a model") 65 | return 66 | 67 | model_dict_mod = {} 68 | for key, value in model_dict.items(): 69 | new_key = '.'.join(key.split('.')[2:]) 70 | model_dict_mod[new_key] = value 71 | model = embeddingNet.to(device) 72 | model.load_state_dict(model_dict_mod) 73 | 74 | data_loader = None 75 | kwargs = {'num_workers': 1, 'pin_memory': True} if args.cuda else {} 76 | if (args.dataset == 'mnist') or (args.dataset == 'fmnist'): 77 | transform = transforms.Compose([ 78 | transforms.ToTensor(), 79 | transforms.Normalize((0.1307,), (0.3081,)) 80 | ]) 81 | train_dataset = None 82 | if args.dataset == 'mnist': 83 | train_dataset = MNIST('data/MNIST', train=True, download=True, transform=transform) 84 | if args.dataset == 'fmnist': 85 | train_dataset = FashionMNIST('data/FashionMNIST', train=True, download=True, transform=transform) 86 | data_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True, **kwargs) 87 | else: 88 | print("Dataset {} not supported ".format(args.dataset)) 89 | return 90 | 91 | embeddings, labels = generate_embeddings(data_loader, model) 92 | 93 | final_data = { 94 | 'embeddings': embeddings, 95 | 'labels': labels 96 | } 97 |
98 | dst_dir = os.path.join('data', args.exp_name, 'tSNE') 99 | make_dir_if_not_exist(dst_dir) 100 | 101 | output_file = open(os.path.join(dst_dir, 'tSNE.pkl'), 'wb') 102 | pickle.dump(final_data, output_file) 103 | output_file.close() 104 | 105 | vis_tSNE(embeddings, labels) 106 | 107 | 108 | def generate_embeddings(data_loader, model): 109 | with torch.no_grad(): 110 | model.eval() 111 | labels = None 112 | embeddings = None 113 | for batch_idx, data in tqdm(enumerate(data_loader)): 114 | batch_imgs, batch_labels = data 115 | batch_labels = batch_labels.numpy() 116 | batch_imgs = Variable(batch_imgs.to(device)) 117 | batch_E = model(batch_imgs) 118 | batch_E = batch_E.data.cpu().numpy() 119 | embeddings = np.concatenate((embeddings, batch_E), axis=0) if embeddings is not None else batch_E 120 | labels = np.concatenate((labels, batch_labels), axis=0) if labels is not None else batch_labels 121 | return embeddings, labels 122 | 123 | 124 | def vis_tSNE(embeddings, labels): 125 | num_samples = args.tSNE_ns if args.tSNE_ns < embeddings.shape[0] else embeddings.shape[0] 126 | X_embedded = TSNE(n_components=2).fit_transform(embeddings[0:num_samples, :]) 127 | 128 | fig, ax = plt.subplots() 129 | 130 | x, y = X_embedded[:, 0], X_embedded[:, 1] 131 | colors = plt.cm.rainbow(np.linspace(0, 1, 10)) 132 | sc = ax.scatter(x, y, c=labels[0:num_samples], cmap=mpl.colors.ListedColormap(colors)) 133 | plt.colorbar(sc) 134 | plt.savefig(os.path.join('data', args.exp_name, 'tSNE', 'tSNE_' + str(num_samples) + '.jpg')) 135 | plt.show() 136 | 137 | 138 | if __name__ == '__main__': 139 | parser = argparse.ArgumentParser(description='PyTorch Siamese Example') 140 | parser.add_argument('--exp_name', default='exp0', type=str, 141 | help='name of experiment') 142 | parser.add_argument('--cuda', action='store_true', default=False, 143 | help='enables CUDA training') 144 | parser.add_argument('--ckp', default=None, type=str, 145 | help='path to load checkpoint') 146 | parser.add_argument('--dataset', type=str, default='mnist', metavar='M', 147 | help='Dataset (default: mnist)') 148 | parser.add_argument('--gpu_devices', type=int, nargs='+', default=None, help='List of GPU Devices') 149 | parser.add_argument('--pkl', default=None, type=str, 150 | help='Path to load embeddings') 151 | 152 | parser.add_argument('--tSNE_ns', default=5000, type=int, 153 | help='Num samples to create a tSNE visualisation') 154 | 155 | global args, device 156 | args = parser.parse_args() 157 | args.cuda = args.cuda and torch.cuda.is_available() 158 | cfg_from_file("config/test.yaml") 159 | 160 | if args.cuda: 161 | device = 'cuda' 162 | if args.gpu_devices is None: 163 | args.gpu_devices = [0] 164 | else: 165 | device = 'cpu' 166 | main() 167 | --------------------------------------------------------------------------------
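Beyond t-SNE, the saved pickle can be probed directly. A hedged sketch (assuming the default `exp_name` of `exp0` and that `tSNE.pkl` already exists at the path used above):

```python
import pickle
import numpy as np

# Load the embeddings dumped by tsne.py
with open('data/exp0/tSNE/tSNE.pkl', 'rb') as f:
    final_data = pickle.load(f)
emb, labels = final_data['embeddings'], final_data['labels']

# Nearest neighbour of the first sample in embedding space; the embeddings are
# L2-normalised in model/embedding.py, so Euclidean distance tracks cosine similarity
d = np.linalg.norm(emb - emb[0], axis=1)
nn = np.argsort(d)[1]  # index 0 is the query itself
print('query label:', labels[0], '-> neighbour label:', labels[nn])
```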
/utils/gen_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | 4 | def make_dir_if_not_exist(path): 5 | # exist_ok=True avoids a race between an existence check and creation 6 | os.makedirs(path, exist_ok=True) 7 | -------------------------------------------------------------------------------- /utils/vis_utils.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import matplotlib.pyplot as plt 3 | 4 | 5 | def visualise(imgs, txts, dst): 6 | f, axs = plt.subplots(1, len(imgs), figsize=(24, 9)) 7 | f.tight_layout() 8 | for ax, img, txt in zip(axs, imgs, txts): 9 | ax.imshow(img) 10 | ax.set_title(txt, fontsize=20) 11 | plt.subplots_adjust(left=0., right=1, top=0.95, bottom=0.) 12 | if dst != "": 13 | plt.savefig(dst) 14 | plt.show() 15 | 16 | 17 | def vis_with_paths(img_paths, txts, dst): 18 | imgs = [] 19 | for img_path in img_paths: 20 | imgs.append(cv2.imread(img_path)) 21 | visualise(imgs, txts, dst) 22 | 23 | 24 | def vis_with_paths_and_bboxes(img_details, txts, dst): 25 | imgs = [] 26 | for img_path, bbox in img_details: 27 | img = cv2.imread(img_path) 28 | if bbox is not None: 29 | img = img[bbox['top']:bbox['top'] + bbox['height'], bbox['left']:bbox['left'] + bbox['width']] 30 | imgs.append(img) 31 | visualise(imgs, txts, dst) 32 | --------------------------------------------------------------------------------
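As a closing usage note, these helpers pair naturally with the triplet samplers above. A hedged sketch, assuming a custom dataset has been configured in `config/test.yaml`:

```python
import _init_paths
from config.base_config import cfg_from_file
from dataloader.custom_dset import Custom
from utils.vis_utils import vis_with_paths

cfg_from_file("config/test.yaml")  # must set DATASETS.CUSTOM.HOME

dset = Custom()
dset.load()
anchor, pos, neg = dset.getTriplet(split='train')

# An empty dst means "show, don't save" in visualise(); note that cv2.imread
# returns BGR, so colours will look channel-swapped in matplotlib
vis_with_paths([anchor, pos, neg], ['anchor', 'positive', 'negative'], "")
```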