├── .gitignore ├── LICENSE ├── README.md ├── data_loader.py ├── datasets.py ├── inference.py ├── losses.py ├── main.py ├── networks.py ├── senet.py ├── setup.py └── trainer.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Sangwon Lee, Hyong-Keun Kook, and Seung-Wook Kim 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # PyTorch Image Retrieval 2 | A PyTorch framework for an image retrieval task, including implementations of [N-pair Loss (NIPS 2016)](http://papers.nips.cc/paper/6199-improved-deep-metric-learning-with-multi-class-n-pair-loss-objective) and [Angular Loss (ICCV 2017)](https://arxiv.org/pdf/1708.01682.pdf). 3 | 4 | ### Loss functions 5 | We implemented the following loss functions to train the network for image retrieval. 6 | The batch sampler for these loss functions is borrowed from [here](https://github.com/adambielski/siamese-triplet). 7 | - [N-pair Loss (NIPS 2016)](http://papers.nips.cc/paper/6199-improved-deep-metric-learning-with-multi-class-n-pair-loss-objective): Sohn, Kihyuk. "Improved Deep Metric Learning with Multi-class N-pair Loss Objective," Advances in Neural Information 8 | Processing Systems. 2016. 9 | - [Angular Loss (ICCV 2017)](https://arxiv.org/pdf/1708.01682.pdf): Wang, Jian. "Deep Metric Learning with Angular Loss," ICCV, 2017. 10 | 11 | ### Self-attention module 12 | We attached the self-attention module of the [Self-Attention GAN](https://arxiv.org/abs/1805.08318) to conventional classification networks (e.g., DenseNet, ResNet, or SENet). 13 | The implementation of the module is borrowed from [here](https://github.com/heykeetae/Self-Attention-GAN). 14 | 15 | ### Data augmentation 16 | We adopted data augmentation techniques used in [Single Shot MultiBox Detector](https://arxiv.org/abs/1512.02325). 17 | 18 | ### Post-processing 19 | We utilized the following post-processing techniques in the inference phase. 20 | - Moving the origin of the feature space to the center of the feature vectors 21 | - L2-normalization 22 | - [Average query expansion](https://www.robots.ox.ac.uk/~vgg/publications/papers/chum07b.pdf) 23 | - [Database-side feature augmentation](https://arxiv.org/pdf/1610.07940.pdf) 24 | -------------------------------------------------------------------------------- /data_loader.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf_8 -*- 2 | from __future__ import absolute_import 3 | from __future__ import division 4 | from __future__ import print_function 5 | 6 | import os 7 | from torch.utils.data import Dataset 8 | from torchvision import transforms, datasets 9 | from PIL import Image 10 | 11 | 12 | def train_data_loader(data_path, img_size, use_augment=False): 13 | if use_augment: 14 | data_transforms = transforms.Compose([ 15 | transforms.RandomOrder([ 16 | transforms.RandomApply([transforms.ColorJitter(contrast=0.5)], .5), 17 | transforms.Compose([ 18 | transforms.RandomApply([transforms.ColorJitter(saturation=0.5)], .5), 19 | transforms.RandomApply([transforms.ColorJitter(hue=0.1)], .5), 20 | ]) 21 | ]), 22 | transforms.RandomApply([transforms.ColorJitter(brightness=0.125)], .5), 23 | transforms.RandomApply([transforms.RandomRotation(15)], .5), 24 | transforms.RandomResizedCrop(img_size), 25 | transforms.RandomHorizontalFlip(), 26 | transforms.ToTensor(), 27 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 28 | ]) 29 | else: 30 | data_transforms = transforms.Compose([ 31 | transforms.RandomResizedCrop(img_size), 32 | transforms.RandomHorizontalFlip(), 33 | transforms.ToTensor(), 34 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 35 | ]) 36 | 37 | image_dataset = datasets.ImageFolder(data_path, data_transforms) 38
| 39 | return image_dataset 40 | 41 | 42 | def test_data_loader(data_path): 43 | 44 | # return full path 45 | queries_path = [os.path.join(data_path, 'query', path) for path in os.listdir(os.path.join(data_path, 'query'))] 46 | references_path = [os.path.join(data_path, 'reference', path) for path in 47 | os.listdir(os.path.join(data_path, 'reference'))] 48 | 49 | return queries_path, references_path 50 | 51 | 52 | def test_data_generator(data_path, img_size): 53 | img_size = (img_size, img_size) 54 | data_transforms = transforms.Compose([ 55 | transforms.Resize(img_size), 56 | transforms.ToTensor(), 57 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 58 | ]) 59 | 60 | test_image_dataset = TestDataset(data_path, data_transforms) 61 | 62 | return test_image_dataset 63 | 64 | 65 | class TestDataset(Dataset): 66 | def __init__(self, img_path_list, transform=None): 67 | self.img_path_list = img_path_list 68 | self.transform = transform 69 | 70 | def __getitem__(self, index): 71 | img_path = self.img_path_list[index] 72 | img = Image.open(img_path).convert('RGB')  # ensure 3 channels, matching ImageFolder's RGB loader 73 | if self.transform is not None: 74 | img = self.transform(img) 75 | return img_path, img 76 | 77 | def __len__(self): 78 | return len(self.img_path_list) 79 | 80 | 81 | if __name__ == '__main__': 82 | query, refer = test_data_loader('./') 83 | print(query) 84 | print(refer) 85 | -------------------------------------------------------------------------------- /datasets.py: -------------------------------------------------------------------------------- 1 | """ 2 | Original source: https://github.com/adambielski/siamese-triplet 3 | """ 4 | 5 | import numpy as np 6 | import torch 7 | 8 | from torch.utils.data import DataLoader 9 | from torch.utils.data.sampler import BatchSampler 10 | 11 | 12 | class BalancedBatchSampler(BatchSampler): 13 | """ 14 | BatchSampler - from a MNIST-like dataset, samples n_classes and within these classes samples n_samples.
15 | Returns batches of size n_classes * n_samples 16 | """ 17 | 18 | def __init__(self, dataset, n_classes, n_samples): 19 | loader = DataLoader(dataset) 20 | self.labels_list = [] 21 | for _, label in loader: 22 | self.labels_list.append(label) 23 | self.labels = torch.LongTensor(self.labels_list) 24 | self.labels_set = list(set(self.labels.numpy())) 25 | self.label_to_indices = {label: np.where(self.labels.numpy() == label)[0] 26 | for label in self.labels_set} 27 | for l in self.labels_set: 28 | np.random.shuffle(self.label_to_indices[l]) 29 | self.used_label_indices_count = {label: 0 for label in self.labels_set} 30 | self.count = 0 31 | self.n_classes = n_classes 32 | self.n_samples = n_samples 33 | self.dataset = dataset 34 | self.batch_size = self.n_samples * self.n_classes 35 | 36 | def __iter__(self): 37 | self.count = 0 38 | while self.count + self.batch_size < len(self.dataset): 39 | classes = np.random.choice(self.labels_set, self.n_classes, replace=False) 40 | indices = [] 41 | for class_ in classes: 42 | indices.extend(self.label_to_indices[class_][ 43 | self.used_label_indices_count[class_]:self.used_label_indices_count[ 44 | class_] + self.n_samples]) 45 | self.used_label_indices_count[class_] += self.n_samples 46 | if self.used_label_indices_count[class_] + self.n_samples > len(self.label_to_indices[class_]): 47 | np.random.shuffle(self.label_to_indices[class_]) 48 | self.used_label_indices_count[class_] = 0 49 | yield indices 50 | self.count += self.n_classes * self.n_samples 51 | 52 | def __len__(self): 53 | return len(self.dataset) // self.batch_size 54 | -------------------------------------------------------------------------------- /inference.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf_8 -*- 2 | from __future__ import absolute_import 3 | from __future__ import division 4 | from __future__ import print_function 5 | 6 | from torch.utils.data import DataLoader 7 | from data_loader import test_data_generator 8 | 9 | import numpy as np 10 | 11 | 12 | def retrieve(model, queries, db, img_size, infer_batch_size): 13 | 14 | query_paths = queries 15 | reference_paths = db 16 | 17 | query_img_dataset = test_data_generator(queries, img_size=img_size) 18 | reference_img_dataset = test_data_generator(db, img_size=img_size) 19 | 20 | query_loader = DataLoader(query_img_dataset, batch_size=infer_batch_size, shuffle=False, num_workers=4, 21 | pin_memory=True) 22 | reference_loader = DataLoader(reference_img_dataset, batch_size=infer_batch_size, shuffle=False, num_workers=4, 23 | pin_memory=True) 24 | 25 | model.eval() 26 | model.cuda() 27 | 28 | query_paths, query_vecs = batch_process(model, query_loader) 29 | reference_paths, reference_vecs = batch_process(model, reference_loader) 30 | 31 | assert query_paths == queries and reference_paths == db, "order of paths should be the same" 32 | 33 | # DBA and AQE 34 | query_vecs, reference_vecs = db_augmentation(query_vecs, reference_vecs, top_k=10) 35 | query_vecs, reference_vecs = average_query_expansion(query_vecs, reference_vecs, top_k=5) 36 | 37 | sim_matrix = calculate_sim_matrix(query_vecs, reference_vecs) 38 | 39 | indices = np.argsort(sim_matrix, axis=1) 40 | indices = np.flip(indices, axis=1) 41 | 42 | retrieval_results = {} 43 | 44 | # Evaluation: mean average precision (mAP) 45 | # You can change this part to fit your evaluation scheme 46 | for (i, query) in enumerate(query_paths): 47 | query = query.split('/')[-1].split('.')[0] 48 | ranked_list =
[reference_paths[k].split('/')[-1].split('.')[0] for k in indices[i]] 49 | ranked_list = ranked_list[:1000] 50 | 51 | retrieval_results[query] = ranked_list 52 | 53 | return retrieval_results 54 | 55 | 56 | def db_augmentation(query_vecs, reference_vecs, top_k=10): 57 | """ 58 | Database-side feature augmentation (DBA) 59 | Albert Gordo, et al. "End-to-end Learning of Deep Visual Representations for Image Retrieval," 60 | International Journal of Computer Vision. 2017. 61 | https://link.springer.com/article/10.1007/s11263-017-1016-8 62 | """ 63 | weights = np.logspace(0, -2., top_k+1) 64 | 65 | # Query augmentation 66 | sim_mat = calculate_sim_matrix(query_vecs, reference_vecs) 67 | indices = np.argsort(-sim_mat, axis=1) 68 | 69 | top_k_ref = reference_vecs[indices[:, :top_k], :] 70 | query_vecs = np.tensordot(weights, np.concatenate([np.expand_dims(query_vecs, 1), top_k_ref], axis=1), axes=(0, 1)) 71 | 72 | # Reference augmentation 73 | sim_mat = calculate_sim_matrix(reference_vecs, reference_vecs) 74 | indices = np.argsort(-sim_mat, axis=1) 75 | 76 | top_k_ref = reference_vecs[indices[:, :top_k+1], :] 77 | reference_vecs = np.tensordot(weights, top_k_ref, axes=(0, 1)) 78 | 79 | return query_vecs, reference_vecs 80 | 81 | 82 | def average_query_expansion(query_vecs, reference_vecs, top_k=5): 83 | """ 84 | Average Query Expansion (AQE) 85 | Ondrej Chum, et al. "Total Recall: Automatic Query Expansion with a Generative Feature Model for Object Retrieval," 86 | International Conference on Computer Vision. 2007. 87 | https://www.robots.ox.ac.uk/~vgg/publications/papers/chum07b.pdf 88 | """ 89 | # Query augmentation 90 | sim_mat = calculate_sim_matrix(query_vecs, reference_vecs) 91 | indices = np.argsort(-sim_mat, axis=1) 92 | 93 | top_k_ref_mean = np.mean(reference_vecs[indices[:, :top_k], :], axis=1) 94 | query_vecs = np.concatenate([query_vecs, top_k_ref_mean], axis=1) 95 | 96 | # Reference augmentation 97 | sim_mat = calculate_sim_matrix(reference_vecs, reference_vecs) 98 | indices = np.argsort(-sim_mat, axis=1) 99 | 100 | top_k_ref_mean = np.mean(reference_vecs[indices[:, 1:top_k+1], :], axis=1) 101 | reference_vecs = np.concatenate([reference_vecs, top_k_ref_mean], axis=1) 102 | 103 | return query_vecs, reference_vecs 104 | 105 | 106 | def calculate_sim_matrix(query_vecs, reference_vecs): 107 | query_vecs, reference_vecs = postprocess(query_vecs, reference_vecs) 108 | return np.dot(query_vecs, reference_vecs.T) 109 | 110 | 111 | def batch_process(model, loader): 112 | feature_vecs = [] 113 | img_paths = [] 114 | for data in loader: 115 | paths, inputs = data 116 | feature_vec = _get_feature(model, inputs.cuda()) 117 | feature_vec = feature_vec.detach().cpu().numpy() # (batch_size, channels) 118 | for i in range(feature_vec.shape[0]): 119 | feature_vecs.append(feature_vec[i]) 120 | img_paths = img_paths + paths 121 | 122 | return img_paths, np.asarray(feature_vecs) 123 | 124 | 125 | def _get_features_from(model, x, feature_names): 126 | features = {} 127 | 128 | def save_feature(name): 129 | def hook(m, i, o): 130 | features[name] = o.data 131 | 132 | return hook 133 | 134 | handles = [] 135 | for name, module in model.named_modules(): 136 | _name = name.split('.')[-1] 137 | if _name in feature_names: 138 | handles.append(module.register_forward_hook(save_feature(_name))) 139 | model(x) 140 | for handle in handles: handle.remove()  # detach the hooks so repeated calls do not accumulate them 141 | return features 142 | 143 | 144 | def _get_feature(model, x): 145 | model_name = model.__class__.__name__ 146 | 147 | if model_name == 'EmbeddingNetwork': 148 | feature = model(x) 149 | elif model_name == 'ResNet': 150 |
features = _get_features_from(model, x, ['fc']) 151 | feature = features['fc'] 152 | elif model_name == 'DenseNet': 153 | features = _get_features_from(model, x, ['classifier']) 154 | feature = features['classifier'] 155 | else: 156 | raise ValueError("Invalid model name: {}".format(model_name)) 157 | 158 | return feature 159 | 160 | 161 | def postprocess(query_vecs, reference_vecs): 162 | """ 163 | Postprocessing: 164 | 1) Moving the origin of the feature space to the center of the feature vectors 165 | 2) L2-normalization 166 | """ 167 | # centerize 168 | query_vecs, reference_vecs = _centerize(query_vecs, reference_vecs) 169 | 170 | # l2 normalization 171 | query_vecs = _l2_normalize(query_vecs) 172 | reference_vecs = _l2_normalize(reference_vecs) 173 | 174 | return query_vecs, reference_vecs 175 | 176 | 177 | def _centerize(v1, v2): 178 | concat = np.concatenate([v1, v2], axis=0) 179 | center = np.mean(concat, axis=0) 180 | return v1-center, v2-center 181 | 182 | 183 | def _l2_normalize(v): 184 | norm = np.expand_dims(np.linalg.norm(v, axis=1), axis=1) 185 | if np.any(norm == 0): 186 | return v 187 | return v / norm 188 | -------------------------------------------------------------------------------- /losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | 5 | 6 | # Constants 7 | N_PAIR = 'n-pair' 8 | ANGULAR = 'angular' 9 | N_PAIR_ANGULAR = 'n-pair-angular' 10 | MAIN_LOSS_CHOICES = (N_PAIR, ANGULAR, N_PAIR_ANGULAR) 11 | 12 | CROSS_ENTROPY = 'cross-entropy' 13 | 14 | 15 | class BlendedLoss(object): 16 | def __init__(self, main_loss_type, cross_entropy_flag): 17 | super(BlendedLoss, self).__init__() 18 | self.main_loss_type = main_loss_type 19 | assert main_loss_type in MAIN_LOSS_CHOICES, "invalid main loss: %s" % main_loss_type 20 | 21 | if self.main_loss_type == N_PAIR: 22 | self.main_loss_fn = NPairLoss() 23 | elif self.main_loss_type == ANGULAR: 24 | self.main_loss_fn = AngularLoss() 25 | elif self.main_loss_type == N_PAIR_ANGULAR: 26 | self.main_loss_fn = NPairAngularLoss() 27 | else: 28 | raise ValueError 29 | 30 | self.cross_entropy_flag = cross_entropy_flag 31 | self.lambda_blending = 0 32 | if cross_entropy_flag: 33 | self.cross_entropy_loss_fn = nn.CrossEntropyLoss() 34 | self.lambda_blending = 0.3 35 | 36 | def calculate_loss(self, target, output_embedding, output_cross_entropy=None): 37 | if target is not None: 38 | target = (target,) 39 | 40 | loss_dict = {} 41 | blended_loss = 0 42 | if self.cross_entropy_flag: 43 | assert output_cross_entropy is not None, "Outputs for cross entropy loss is needed" 44 | 45 | loss_inputs = self._gen_loss_inputs(target, output_cross_entropy) 46 | cross_entropy_loss = self.cross_entropy_loss_fn(*loss_inputs) 47 | blended_loss += self.lambda_blending * cross_entropy_loss 48 | loss_dict[CROSS_ENTROPY + '-loss'] = [cross_entropy_loss.item()] 49 | 50 | loss_inputs = self._gen_loss_inputs(target, output_embedding) 51 | main_loss_outputs = self.main_loss_fn(*loss_inputs) 52 | main_loss = main_loss_outputs[0] if type(main_loss_outputs) in (tuple, list) else main_loss_outputs 53 | blended_loss += (1-self.lambda_blending) * main_loss 54 | loss_dict[self.main_loss_type+'-loss'] = [main_loss.item()] 55 | 56 | return blended_loss, loss_dict 57 | 58 | @staticmethod 59 | def _gen_loss_inputs(target, embedding): 60 | if type(embedding) not in (tuple, list): 61 | embedding = (embedding,) 62 | loss_inputs = embedding 63 | if target is not None: 64 | if 
type(target) not in (tuple, list): 65 | target = (target,) 66 | loss_inputs += target 67 | return loss_inputs 68 | 69 | 70 | class NPairLoss(nn.Module): 71 | """ 72 | N-Pair loss 73 | Sohn, Kihyuk. "Improved Deep Metric Learning with Multi-class N-pair Loss Objective," Advances in Neural Information 74 | Processing Systems. 2016. 75 | http://papers.nips.cc/paper/6199-improved-deep-metric-learning-with-multi-class-n-pair-loss-objective 76 | """ 77 | 78 | def __init__(self, l2_reg=0.02): 79 | super(NPairLoss, self).__init__() 80 | self.l2_reg = l2_reg 81 | 82 | def forward(self, embeddings, target): 83 | n_pairs, n_negatives = self.get_n_pairs(target) 84 | 85 | if embeddings.is_cuda: 86 | n_pairs = n_pairs.cuda() 87 | n_negatives = n_negatives.cuda() 88 | 89 | anchors = embeddings[n_pairs[:, 0]] # (n, embedding_size) 90 | positives = embeddings[n_pairs[:, 1]] # (n, embedding_size) 91 | negatives = embeddings[n_negatives] # (n, n-1, embedding_size) 92 | 93 | losses = self.n_pair_loss(anchors, positives, negatives) \ 94 | + self.l2_reg * self.l2_loss(anchors, positives) 95 | 96 | return losses 97 | 98 | @staticmethod 99 | def get_n_pairs(labels): 100 | """ 101 | Get index of n-pairs and n-negatives 102 | :param labels: label vector of mini-batch 103 | :return: A tuple of n_pairs (n, 2) 104 | and n_negatives (n, n-1) 105 | """ 106 | labels = labels.cpu().data.numpy() 107 | n_pairs = [] 108 | 109 | for label in set(labels): 110 | label_mask = (labels == label) 111 | label_indices = np.where(label_mask)[0] 112 | if len(label_indices) < 2: 113 | continue 114 | anchor, positive = np.random.choice(label_indices, 2, replace=False) 115 | n_pairs.append([anchor, positive]) 116 | 117 | n_pairs = np.array(n_pairs) 118 | 119 | n_negatives = [] 120 | for i in range(len(n_pairs)): 121 | negative = np.concatenate([n_pairs[:i, 1], n_pairs[i+1:, 1]]) 122 | n_negatives.append(negative) 123 | 124 | n_negatives = np.array(n_negatives) 125 | 126 | return torch.LongTensor(n_pairs), torch.LongTensor(n_negatives) 127 | 128 | @staticmethod 129 | def n_pair_loss(anchors, positives, negatives): 130 | """ 131 | Calculates N-Pair loss 132 | :param anchors: A torch.Tensor, (n, embedding_size) 133 | :param positives: A torch.Tensor, (n, embedding_size) 134 | :param negatives: A torch.Tensor, (n, n-1, embedding_size) 135 | :return: A scalar 136 | """ 137 | anchors = torch.unsqueeze(anchors, dim=1) # (n, 1, embedding_size) 138 | positives = torch.unsqueeze(positives, dim=1) # (n, 1, embedding_size) 139 | 140 | x = torch.matmul(anchors, (negatives - positives).transpose(1, 2)) # (n, 1, n-1) 141 | x = torch.sum(torch.exp(x), 2) # (n, 1) 142 | loss = torch.mean(torch.log(1+x)) 143 | return loss 144 | 145 | @staticmethod 146 | def l2_loss(anchors, positives): 147 | """ 148 | Calculates L2 norm regularization loss 149 | :param anchors: A torch.Tensor, (n, embedding_size) 150 | :param positives: A torch.Tensor, (n, embedding_size) 151 | :return: A scalar 152 | """ 153 | return torch.sum(anchors ** 2 + positives ** 2) / anchors.shape[0] 154 | 155 | 156 | class AngularLoss(NPairLoss): 157 | """ 158 | Angular loss 159 | Wang, Jian. 
"Deep Metric Learning with Angular Loss," ICCV, 2017 160 | https://arxiv.org/pdf/1708.01682.pdf 161 | """ 162 | 163 | def __init__(self, l2_reg=0.02, angle_bound=1., lambda_ang=2): 164 | super(AngularLoss, self).__init__() 165 | self.l2_reg = l2_reg 166 | self.angle_bound = angle_bound 167 | self.lambda_ang = lambda_ang 168 | self.softplus = nn.Softplus() 169 | 170 | def forward(self, embeddings, target): 171 | n_pairs, n_negatives = self.get_n_pairs(target) 172 | 173 | if embeddings.is_cuda: 174 | n_pairs = n_pairs.cuda() 175 | n_negatives = n_negatives.cuda() 176 | 177 | anchors = embeddings[n_pairs[:, 0]] # (n, embedding_size) 178 | positives = embeddings[n_pairs[:, 1]] # (n, embedding_size) 179 | negatives = embeddings[n_negatives] # (n, n-1, embedding_size) 180 | 181 | losses = self.angular_loss(anchors, positives, negatives, self.angle_bound) \ 182 | + self.l2_reg * self.l2_loss(anchors, positives) 183 | 184 | return losses 185 | 186 | @staticmethod 187 | def angular_loss(anchors, positives, negatives, angle_bound=1.): 188 | """ 189 | Calculates angular loss 190 | :param anchors: A torch.Tensor, (n, embedding_size) 191 | :param positives: A torch.Tensor, (n, embedding_size) 192 | :param negatives: A torch.Tensor, (n, n-1, embedding_size) 193 | :param angle_bound: tan^2 angle 194 | :return: A scalar 195 | """ 196 | anchors = torch.unsqueeze(anchors, dim=1) # (n, 1, embedding_size) 197 | positives = torch.unsqueeze(positives, dim=1) # (n, 1, embedding_size) 198 | 199 | x = 4. * angle_bound * torch.matmul((anchors + positives), negatives.transpose(1, 2)) \ 200 | - 2. * (1. + angle_bound) * torch.matmul(anchors, positives.transpose(1, 2)) # (n, 1, n-1) 201 | 202 | # Preventing overflow 203 | with torch.no_grad(): 204 | t = torch.max(x, dim=2)[0] 205 | 206 | x = torch.exp(x - t.unsqueeze(dim=1)) 207 | x = torch.log(torch.exp(-t) + torch.sum(x, 2)) 208 | loss = torch.mean(t + x) 209 | 210 | return loss 211 | 212 | 213 | class NPairAngularLoss(AngularLoss): 214 | """ 215 | Angular loss 216 | Wang, Jian. 
"Deep Metric Learning with Angular Loss," ICCV, 2017 217 | https://arxiv.org/pdf/1708.01682.pdf 218 | """ 219 | 220 | def __init__(self, l2_reg=0.02, angle_bound=1., lambda_ang=2): 221 | super(NPairAngularLoss, self).__init__() 222 | self.l2_reg = l2_reg 223 | self.angle_bound = angle_bound 224 | self.lambda_ang = lambda_ang 225 | 226 | def forward(self, embeddings, target): 227 | n_pairs, n_negatives = self.get_n_pairs(target) 228 | 229 | if embeddings.is_cuda: 230 | n_pairs = n_pairs.cuda() 231 | n_negatives = n_negatives.cuda() 232 | 233 | anchors = embeddings[n_pairs[:, 0]] # (n, embedding_size) 234 | positives = embeddings[n_pairs[:, 1]] # (n, embedding_size) 235 | negatives = embeddings[n_negatives] # (n, n-1, embedding_size) 236 | 237 | losses = self.n_pair_angular_loss(anchors, positives, negatives, self.angle_bound) \ 238 | + self.l2_reg * self.l2_loss(anchors, positives) 239 | 240 | return losses 241 | 242 | def n_pair_angular_loss(self, anchors, positives, negatives, angle_bound=1.): 243 | """ 244 | Calculates N-Pair angular loss 245 | :param anchors: A torch.Tensor, (n, embedding_size) 246 | :param positives: A torch.Tensor, (n, embedding_size) 247 | :param negatives: A torch.Tensor, (n, n-1, embedding_size) 248 | :param angle_bound: tan^2 angle 249 | :return: A scalar, n-pair_loss + lambda * angular_loss 250 | """ 251 | n_pair = self.n_pair_loss(anchors, positives, negatives) 252 | angular = self.angular_loss(anchors, positives, negatives, angle_bound) 253 | 254 | return (n_pair + self.lambda_ang * angular) / (1+self.lambda_ang) 255 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf_8 -*- 2 | from __future__ import absolute_import 3 | from __future__ import division 4 | from __future__ import print_function 5 | 6 | import argparse 7 | 8 | import torch 9 | import torch.nn as nn 10 | import torch.optim as optim 11 | 12 | from data_loader import train_data_loader, test_data_loader 13 | 14 | # Load initial models 15 | from networks import EmbeddingNetwork 16 | 17 | # Load batch sampler and train loss 18 | from datasets import BalancedBatchSampler 19 | from losses import BlendedLoss, MAIN_LOSS_CHOICES 20 | 21 | from trainer import fit 22 | from inference import retrieve 23 | 24 | 25 | def load(file_path): 26 | model.load_state_dict(torch.load(file_path)) 27 | print('model loaded!') 28 | return model 29 | 30 | 31 | def infer(model, queries, db): 32 | retrieval_results = retrieve(model, queries, db, input_size, infer_batch_size) 33 | 34 | return list(zip(range(len(retrieval_results)), retrieval_results.items())) 35 | 36 | 37 | def get_arguments(): 38 | args = argparse.ArgumentParser() 39 | 40 | args.add_argument('--dataset-path', type=str) 41 | args.add_argument('--model-save-dir', type=str) 42 | args.add_argument('--model-to-test', type=str) 43 | 44 | # Hyperparameters 45 | args.add_argument('--epochs', type=int, default=20) 46 | args.add_argument('--model', type=str, 47 | choices=['densenet161', 'resnet101', 'inceptionv3', 'seresnext'], 48 | default='densenet161') 49 | args.add_argument('--input-size', type=int, default=224, help='size of input image') 50 | args.add_argument('--num-classes', type=int, default=64, help='number of classes for batch sampler') 51 | args.add_argument('--num-samples', type=int, default=4, help='number of samples per class for batch sampler') 52 | args.add_argument('--embedding-dim', type=int, default=128, 
help='size of embedding dimension') 53 | args.add_argument('--feature-extracting', type=bool, default=False) 54 | args.add_argument('--use-pretrained', type=bool, default=True) 55 | args.add_argument('--lr', type=float, default=1e-4) 56 | args.add_argument('--scheduler', type=str, choices=['StepLR', 'MultiStepLR']) 57 | args.add_argument('--attention', action='store_true') 58 | args.add_argument('--loss-type', type=str, choices=MAIN_LOSS_CHOICES) 59 | args.add_argument('--cross-entropy', action='store_true') 60 | args.add_argument('--use-augmentation', action='store_true') 61 | 62 | # Mode selection 63 | args.add_argument('--mode', type=str, default='train', help='mode selection: train or test.') 64 | 65 | return args.parse_args() 66 | 67 | 68 | if __name__ == '__main__': 69 | config = get_arguments() 70 | 71 | dataset_path = config.dataset_path 72 | 73 | # Model parameters 74 | model_name = config.model 75 | input_size = config.input_size 76 | embedding_dim = config.embedding_dim 77 | feature_extracting = config.feature_extracting 78 | use_pretrained = config.use_pretrained 79 | attention_flag = config.attention 80 | 81 | # Training parameters 82 | nb_epoch = config.epochs 83 | loss_type = config.loss_type 84 | cross_entropy_flag = config.cross_entropy 85 | scheduler_name = config.scheduler 86 | lr = config.lr 87 | 88 | # Mini-batch parameters 89 | num_classes = config.num_classes 90 | num_samples = config.num_samples 91 | use_augmentation = config.use_augmentation 92 | 93 | infer_batch_size = 64 94 | log_interval = 50 95 | 96 | """ Model """ 97 | model = EmbeddingNetwork(model_name=model_name, 98 | embedding_dim=embedding_dim, 99 | feature_extracting=feature_extracting, 100 | use_pretrained=use_pretrained, 101 | attention_flag=attention_flag, 102 | cross_entropy_flag=cross_entropy_flag) 103 | 104 | if torch.cuda.device_count() > 1: 105 | model = nn.DataParallel(model) 106 | 107 | if config.mode == 'train': 108 | 109 | """ Load data """ 110 | print('dataset path', dataset_path) 111 | train_dataset_path = dataset_path + '/train/train_data' 112 | 113 | img_dataset = train_data_loader(data_path=train_dataset_path, img_size=input_size, 114 | use_augment=use_augmentation) 115 | 116 | # Balanced batch sampler and online train loader 117 | train_batch_sampler = BalancedBatchSampler(img_dataset, n_classes=num_classes, n_samples=num_samples) 118 | online_train_loader = torch.utils.data.DataLoader(img_dataset, 119 | batch_sampler=train_batch_sampler, 120 | num_workers=4, 121 | pin_memory=True) 122 | 123 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 124 | 125 | # Gather the parameters to be optimized/updated. 
126 | params_to_update = model.parameters() 127 | print("Params to learn:") 128 | if feature_extracting: 129 | params_to_update = [] 130 | for name, param in model.named_parameters(): 131 | if param.requires_grad: 132 | params_to_update.append(param) 133 | print("\t", name) 134 | else: 135 | for name, param in model.named_parameters(): 136 | if param.requires_grad: 137 | print("\t", name) 138 | 139 | # Send the model to GPU 140 | model = model.to(device) 141 | 142 | optimizer = optim.Adam(params_to_update, lr=lr, weight_decay=1e-4) 143 | if scheduler_name == 'StepLR': 144 | scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=8, gamma=0.1) 145 | elif scheduler_name == 'MultiStepLR': 146 | if use_augmentation: 147 | scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[20, 30], gamma=0.1) 148 | else: 149 | scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[10, 15, 20], gamma=0.1) 150 | else: 151 | raise ValueError('Invalid scheduler') 152 | 153 | # Loss function 154 | loss_fn = BlendedLoss(loss_type, cross_entropy_flag) 155 | 156 | # Train (fine-tune) model 157 | fit(online_train_loader, model, loss_fn, optimizer, scheduler, nb_epoch, 158 | device=device, log_interval=log_interval, save_model_to=config.model_save_dir) 159 | 160 | elif config.mode == 'test': 161 | test_dataset_path = dataset_path + '/test/test_data' 162 | queries, db = test_data_loader(test_dataset_path) 163 | model = load(file_path=config.model_to_test) 164 | result_dict = infer(model, queries, db) 165 | -------------------------------------------------------------------------------- /networks.py: -------------------------------------------------------------------------------- 1 | """ 2 | SE-ResNet, SE_ResNeXt codes are gently borrowed from 3 | https://github.com/Cadene/pretrained-models.pytorch/blob/master/pretrainedmodels/models/senet.py 4 | """ 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | from torchvision import models 9 | 10 | from senet import se_resnext101_32x4d 11 | 12 | 13 | class BaseNetwork(nn.Module): 14 | """ Load Pretrained Module """ 15 | 16 | def __init__(self, model_name, embedding_dim, feature_extracting, use_pretrained): 17 | super(BaseNetwork, self).__init__() 18 | self.model_name = model_name 19 | self.embedding_dim = embedding_dim 20 | self.feature_extracting = feature_extracting 21 | self.use_pretrained = use_pretrained 22 | 23 | self.model_ft = initialize_model(self.model_name, 24 | self.embedding_dim, 25 | self.feature_extracting, 26 | self.use_pretrained) 27 | 28 | def forward(self, x): 29 | out = self.model_ft(x) 30 | return out 31 | 32 | 33 | class SelfAttention(nn.Module): 34 | """ Self attention Layer 35 | https://github.com/heykeetae/Self-Attention-GAN""" 36 | 37 | def __init__(self, in_dim, activation): 38 | super(SelfAttention, self).__init__() 39 | self.chanel_in = in_dim 40 | self.activation = activation 41 | 42 | self.query_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim // 8, kernel_size=1) 43 | self.key_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim // 8, kernel_size=1) 44 | self.value_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim, kernel_size=1) 45 | self.gamma = nn.Parameter(torch.zeros(1)) 46 | 47 | self.softmax = nn.Softmax(dim=-1) 48 | 49 | def forward(self, x): 50 | """ 51 | inputs : 52 | x : input feature maps( B X C X W X H) 53 | returns : 54 | out : self attention value + input feature 55 | attention: B X N X N (N is Width*Height) 56 | """ 57 | m_batchsize, C, width, height =
x.size() 58 | proj_query = self.query_conv(x).view(m_batchsize, -1, width * height).permute(0, 2, 1) # B X CX(N) 59 | proj_key = self.key_conv(x).view(m_batchsize, -1, width * height) # B X C x (*W*H) 60 | energy = torch.bmm(proj_query, proj_key) # transpose check 61 | attention = self.softmax(energy) # BX (N) X (N) 62 | 63 | proj_value = self.value_conv(x).view(m_batchsize, -1, width * height) # B X C X N 64 | 65 | out = torch.bmm(proj_value, attention.permute(0, 2, 1)) 66 | out = out.view(m_batchsize, C, width, height) 67 | 68 | out = self.gamma * out + x 69 | return out 70 | 71 | 72 | class EmbeddingNetwork(BaseNetwork): 73 | """ Wrapping Modules to the BaseNetwork """ 74 | 75 | def __init__(self, model_name, embedding_dim, feature_extracting, use_pretrained, 76 | attention_flag=False, cross_entropy_flag=False, edge_cutting=False): 77 | super(EmbeddingNetwork, self).__init__(model_name, embedding_dim, feature_extracting, use_pretrained) 78 | self.attention_flag = attention_flag 79 | self.cross_entropy_flag = cross_entropy_flag 80 | self.edge_cutting = edge_cutting 81 | 82 | self.model_ft_convs = nn.Sequential(*list(self.model_ft.children())[:-1]) 83 | self.model_ft_embedding = nn.Sequential(*list(self.model_ft.children())[-1:]) 84 | 85 | if self.attention_flag: 86 | if self.model_name == 'densenet161': 87 | self.attention = SelfAttention(2208, 'relu') 88 | elif self.model_name == 'resnet101': 89 | self.attention = SelfAttention(2048, 'relu') 90 | elif self.model_name == 'inceptionv3': 91 | self.attention = SelfAttention(2048, 'relu') 92 | elif self.model_name == 'seresnext': 93 | self.attention = SelfAttention(2048, 'relu') 94 | 95 | if self.cross_entropy_flag: 96 | self.fc_cross_entropy = nn.Linear(self.model_ft.classifier.in_features, 1000) 97 | 98 | def forward(self, x): 99 | x = self.model_ft_convs(x) 100 | x = F.relu(x, inplace=True) 101 | 102 | if self.attention_flag: 103 | x = self.attention(x) 104 | 105 | if self.edge_cutting: 106 | x = F.adaptive_avg_pool2d(x[:, :, 1:-1, 1:-1], output_size=1).view(x.size(0), -1) 107 | else: 108 | x = F.adaptive_avg_pool2d(x, output_size=1).view(x.size(0), -1) 109 | # x = gem(x).view(x.size(0), -1) 110 | out_embedding = self.model_ft_embedding(x) 111 | 112 | if self.cross_entropy_flag: 113 | out_cross_entropy = self.fc_cross_entropy(x) 114 | return out_embedding, out_cross_entropy 115 | else: 116 | return out_embedding 117 | 118 | 119 | def set_parameter_requires_grad(model, feature_extracting): 120 | if feature_extracting: 121 | for param in model.parameters(): 122 | param.requires_grad = False 123 | 124 | 125 | def initialize_model(model_name, embedding_dim, feature_extracting, use_pretrained=True): 126 | if model_name == "densenet161": 127 | model_ft = models.densenet161(pretrained=use_pretrained) 128 | set_parameter_requires_grad(model_ft, feature_extracting) 129 | num_features = model_ft.classifier.in_features 130 | model_ft.classifier = nn.Linear(num_features, embedding_dim) 131 | elif model_name == "resnet101": 132 | model_ft = models.resnet101(pretrained=use_pretrained) 133 | set_parameter_requires_grad(model_ft, feature_extracting) 134 | num_features = model_ft.fc.in_features 135 | model_ft.fc = nn.Linear(num_features, embedding_dim) 136 | elif model_name == "inceptionv3": 137 | model_ft = models.inception_v3(pretrained=use_pretrained) 138 | set_parameter_requires_grad(model_ft, feature_extracting) 139 | num_features = model_ft.fc.in_features 140 | model_ft.fc = nn.Linear(num_features, embedding_dim) 141 | elif model_name == 
"seresnext": 142 | model_ft = se_resnext101_32x4d(num_classes=1000) 143 | set_parameter_requires_grad(model_ft, feature_extracting) 144 | num_features = model_ft.last_linear.in_features 145 | model_ft.last_linear = nn.Linear(num_features, embedding_dim) 146 | else: 147 | raise ValueError 148 | 149 | return model_ft 150 | 151 | 152 | # GeM Pooling 153 | def gem(x, p=3, eps=1e-6): 154 | return F.adaptive_avg_pool2d(x.clamp(min=eps).pow(p), output_size=1).pow(1. / p) 155 | -------------------------------------------------------------------------------- /senet.py: -------------------------------------------------------------------------------- 1 | """ 2 | ResNet code gently borrowed from 3 | https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py 4 | """ 5 | from __future__ import print_function, division, absolute_import 6 | from collections import OrderedDict 7 | import math 8 | 9 | import torch.nn as nn 10 | from torch.utils import model_zoo 11 | 12 | import torch 13 | 14 | __all__ = ['SENet', 'senet154', 'se_resnet50', 'se_resnet101', 'se_resnet152', 15 | 'se_resnext50_32x4d', 'se_resnext101_32x4d'] 16 | 17 | pretrained_settings = { 18 | 'senet154': { 19 | 'imagenet': { 20 | 'url': 'http://data.lip6.fr/cadene/pretrainedmodels/senet154-c7b49a05.pth', 21 | 'input_space': 'RGB', 22 | 'input_size': [3, 224, 224], 23 | 'input_range': [0, 1], 24 | 'mean': [0.485, 0.456, 0.406], 25 | 'std': [0.229, 0.224, 0.225], 26 | 'num_classes': 1000 27 | } 28 | }, 29 | 'se_resnet50': { 30 | 'imagenet': { 31 | 'url': 'http://data.lip6.fr/cadene/pretrainedmodels/se_resnet50-ce0d4300.pth', 32 | 'input_space': 'RGB', 33 | 'input_size': [3, 224, 224], 34 | 'input_range': [0, 1], 35 | 'mean': [0.485, 0.456, 0.406], 36 | 'std': [0.229, 0.224, 0.225], 37 | 'num_classes': 1000 38 | } 39 | }, 40 | 'se_resnet101': { 41 | 'imagenet': { 42 | 'url': 'http://data.lip6.fr/cadene/pretrainedmodels/se_resnet101-7e38fcc6.pth', 43 | 'input_space': 'RGB', 44 | 'input_size': [3, 224, 224], 45 | 'input_range': [0, 1], 46 | 'mean': [0.485, 0.456, 0.406], 47 | 'std': [0.229, 0.224, 0.225], 48 | 'num_classes': 1000 49 | } 50 | }, 51 | 'se_resnet152': { 52 | 'imagenet': { 53 | 'url': 'http://data.lip6.fr/cadene/pretrainedmodels/se_resnet152-d17c99b7.pth', 54 | 'input_space': 'RGB', 55 | 'input_size': [3, 224, 224], 56 | 'input_range': [0, 1], 57 | 'mean': [0.485, 0.456, 0.406], 58 | 'std': [0.229, 0.224, 0.225], 59 | 'num_classes': 1000 60 | } 61 | }, 62 | 'se_resnext50_32x4d': { 63 | 'imagenet': { 64 | 'url': 'http://data.lip6.fr/cadene/pretrainedmodels/se_resnext50_32x4d-a260b3a4.pth', 65 | 'input_space': 'RGB', 66 | 'input_size': [3, 224, 224], 67 | 'input_range': [0, 1], 68 | 'mean': [0.485, 0.456, 0.406], 69 | 'std': [0.229, 0.224, 0.225], 70 | 'num_classes': 1000 71 | } 72 | }, 73 | 'se_resnext101_32x4d': { 74 | 'imagenet': { 75 | 'url': 'http://data.lip6.fr/cadene/pretrainedmodels/se_resnext101_32x4d-3b2fe3d8.pth', 76 | 'input_space': 'RGB', 77 | 'input_size': [3, 224, 224], 78 | 'input_range': [0, 1], 79 | 'mean': [0.485, 0.456, 0.406], 80 | 'std': [0.229, 0.224, 0.225], 81 | 'num_classes': 1000 82 | } 83 | }, 84 | } 85 | 86 | 87 | class SEModule(nn.Module): 88 | 89 | def __init__(self, channels, reduction): 90 | super(SEModule, self).__init__() 91 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 92 | self.fc1 = nn.Conv2d(channels, channels // reduction, kernel_size=1, 93 | padding=0) 94 | self.relu = nn.ReLU(inplace=True) 95 | self.fc2 = nn.Conv2d(channels // reduction, channels, kernel_size=1, 96 | padding=0) 97 | 
self.sigmoid = nn.Sigmoid() 98 | 99 | def forward(self, x): 100 | module_input = x 101 | x = self.avg_pool(x) 102 | x = self.fc1(x) 103 | x = self.relu(x) 104 | x = self.fc2(x) 105 | x = self.sigmoid(x) 106 | return module_input * x 107 | 108 | 109 | class Bottleneck(nn.Module): 110 | """ 111 | Base class for bottlenecks that implements `forward()` method. 112 | """ 113 | def forward(self, x): 114 | residual = x 115 | 116 | out = self.conv1(x) 117 | out = self.bn1(out) 118 | out = self.relu(out) 119 | 120 | out = self.conv2(out) 121 | out = self.bn2(out) 122 | out = self.relu(out) 123 | 124 | out = self.conv3(out) 125 | out = self.bn3(out) 126 | 127 | if self.downsample is not None: 128 | residual = self.downsample(x) 129 | 130 | out = self.se_module(out) + residual 131 | out = self.relu(out) 132 | 133 | return out 134 | 135 | 136 | class SEBottleneck(Bottleneck): 137 | """ 138 | Bottleneck for SENet154. 139 | """ 140 | expansion = 4 141 | 142 | def __init__(self, inplanes, planes, groups, reduction, stride=1, 143 | downsample=None): 144 | super(SEBottleneck, self).__init__() 145 | self.conv1 = nn.Conv2d(inplanes, planes * 2, kernel_size=1, bias=False) 146 | self.bn1 = nn.BatchNorm2d(planes * 2) 147 | self.conv2 = nn.Conv2d(planes * 2, planes * 4, kernel_size=3, 148 | stride=stride, padding=1, groups=groups, 149 | bias=False) 150 | self.bn2 = nn.BatchNorm2d(planes * 4) 151 | self.conv3 = nn.Conv2d(planes * 4, planes * 4, kernel_size=1, 152 | bias=False) 153 | self.bn3 = nn.BatchNorm2d(planes * 4) 154 | self.relu = nn.ReLU(inplace=True) 155 | self.se_module = SEModule(planes * 4, reduction=reduction) 156 | self.downsample = downsample 157 | self.stride = stride 158 | 159 | 160 | class SEResNetBottleneck(Bottleneck): 161 | """ 162 | ResNet bottleneck with a Squeeze-and-Excitation module. It follows Caffe 163 | implementation and uses `stride=stride` in `conv1` and not in `conv2` 164 | (the latter is used in the torchvision implementation of ResNet). 165 | """ 166 | expansion = 4 167 | 168 | def __init__(self, inplanes, planes, groups, reduction, stride=1, 169 | downsample=None): 170 | super(SEResNetBottleneck, self).__init__() 171 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False, 172 | stride=stride) 173 | self.bn1 = nn.BatchNorm2d(planes) 174 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1, 175 | groups=groups, bias=False) 176 | self.bn2 = nn.BatchNorm2d(planes) 177 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 178 | self.bn3 = nn.BatchNorm2d(planes * 4) 179 | self.relu = nn.ReLU(inplace=True) 180 | self.se_module = SEModule(planes * 4, reduction=reduction) 181 | self.downsample = downsample 182 | self.stride = stride 183 | 184 | 185 | class SEResNeXtBottleneck(Bottleneck): 186 | """ 187 | ResNeXt bottleneck type C with a Squeeze-and-Excitation module. 
188 | """ 189 | expansion = 4 190 | 191 | def __init__(self, inplanes, planes, groups, reduction, stride=1, 192 | downsample=None, base_width=4): 193 | super(SEResNeXtBottleneck, self).__init__() 194 | width = math.floor(planes * (base_width / 64)) * groups 195 | self.conv1 = nn.Conv2d(inplanes, width, kernel_size=1, bias=False, 196 | stride=1) 197 | self.bn1 = nn.BatchNorm2d(width) 198 | self.conv2 = nn.Conv2d(width, width, kernel_size=3, stride=stride, 199 | padding=1, groups=groups, bias=False) 200 | self.bn2 = nn.BatchNorm2d(width) 201 | self.conv3 = nn.Conv2d(width, planes * 4, kernel_size=1, bias=False) 202 | self.bn3 = nn.BatchNorm2d(planes * 4) 203 | self.relu = nn.ReLU(inplace=True) 204 | self.se_module = SEModule(planes * 4, reduction=reduction) 205 | self.downsample = downsample 206 | self.stride = stride 207 | 208 | 209 | class SENet(nn.Module): 210 | 211 | def __init__(self, block, layers, groups, reduction, dropout_p=0.2, 212 | inplanes=128, input_3x3=True, downsample_kernel_size=3, 213 | downsample_padding=1, num_classes=1000): 214 | """ 215 | Parameters 216 | ---------- 217 | block (nn.Module): Bottleneck class. 218 | - For SENet154: SEBottleneck 219 | - For SE-ResNet models: SEResNetBottleneck 220 | - For SE-ResNeXt models: SEResNeXtBottleneck 221 | layers (list of ints): Number of residual blocks for 4 layers of the 222 | network (layer1...layer4). 223 | groups (int): Number of groups for the 3x3 convolution in each 224 | bottleneck block. 225 | - For SENet154: 64 226 | - For SE-ResNet models: 1 227 | - For SE-ResNeXt models: 32 228 | reduction (int): Reduction ratio for Squeeze-and-Excitation modules. 229 | - For all models: 16 230 | dropout_p (float or None): Drop probability for the Dropout layer. 231 | If `None` the Dropout layer is not used. 232 | - For SENet154: 0.2 233 | - For SE-ResNet models: None 234 | - For SE-ResNeXt models: None 235 | inplanes (int): Number of input channels for layer1. 236 | - For SENet154: 128 237 | - For SE-ResNet models: 64 238 | - For SE-ResNeXt models: 64 239 | input_3x3 (bool): If `True`, use three 3x3 convolutions instead of 240 | a single 7x7 convolution in layer0. 241 | - For SENet154: True 242 | - For SE-ResNet models: False 243 | - For SE-ResNeXt models: False 244 | downsample_kernel_size (int): Kernel size for downsampling convolutions 245 | in layer2, layer3 and layer4. 246 | - For SENet154: 3 247 | - For SE-ResNet models: 1 248 | - For SE-ResNeXt models: 1 249 | downsample_padding (int): Padding for downsampling convolutions in 250 | layer2, layer3 and layer4. 251 | - For SENet154: 1 252 | - For SE-ResNet models: 0 253 | - For SE-ResNeXt models: 0 254 | num_classes (int): Number of outputs in `last_linear` layer. 
255 | - For all models: 1000 256 | """ 257 | super(SENet, self).__init__() 258 | self.inplanes = inplanes 259 | if input_3x3: 260 | layer0_modules = [ 261 | ('conv1', nn.Conv2d(3, 64, 3, stride=2, padding=1, 262 | bias=False)), 263 | ('bn1', nn.BatchNorm2d(64)), 264 | ('relu1', nn.ReLU(inplace=True)), 265 | ('conv2', nn.Conv2d(64, 64, 3, stride=1, padding=1, 266 | bias=False)), 267 | ('bn2', nn.BatchNorm2d(64)), 268 | ('relu2', nn.ReLU(inplace=True)), 269 | ('conv3', nn.Conv2d(64, inplanes, 3, stride=1, padding=1, 270 | bias=False)), 271 | ('bn3', nn.BatchNorm2d(inplanes)), 272 | ('relu3', nn.ReLU(inplace=True)), 273 | ] 274 | else: 275 | layer0_modules = [ 276 | ('conv1', nn.Conv2d(3, inplanes, kernel_size=7, stride=2, 277 | padding=3, bias=False)), 278 | ('bn1', nn.BatchNorm2d(inplanes)), 279 | ('relu1', nn.ReLU(inplace=True)), 280 | ] 281 | # To preserve compatibility with Caffe weights `ceil_mode=True` 282 | # is used instead of `padding=1`. 283 | layer0_modules.append(('pool', nn.MaxPool2d(3, stride=2, 284 | ceil_mode=True))) 285 | self.layer0 = nn.Sequential(OrderedDict(layer0_modules)) 286 | self.layer1 = self._make_layer( 287 | block, 288 | planes=64, 289 | blocks=layers[0], 290 | groups=groups, 291 | reduction=reduction, 292 | downsample_kernel_size=1, 293 | downsample_padding=0 294 | ) 295 | self.layer2 = self._make_layer( 296 | block, 297 | planes=128, 298 | blocks=layers[1], 299 | stride=2, 300 | groups=groups, 301 | reduction=reduction, 302 | downsample_kernel_size=downsample_kernel_size, 303 | downsample_padding=downsample_padding 304 | ) 305 | self.layer3 = self._make_layer( 306 | block, 307 | planes=256, 308 | blocks=layers[2], 309 | stride=2, 310 | groups=groups, 311 | reduction=reduction, 312 | downsample_kernel_size=downsample_kernel_size, 313 | downsample_padding=downsample_padding 314 | ) 315 | self.layer4 = self._make_layer( 316 | block, 317 | planes=512, 318 | blocks=layers[3], 319 | stride=2, 320 | groups=groups, 321 | reduction=reduction, 322 | downsample_kernel_size=downsample_kernel_size, 323 | downsample_padding=downsample_padding 324 | ) 325 | self.avg_pool = nn.AvgPool2d(7, stride=1) 326 | self.dropout = nn.Dropout(dropout_p) if dropout_p is not None else None 327 | self.last_linear = nn.Linear(512 * block.expansion, num_classes) 328 | 329 | def _make_layer(self, block, planes, blocks, groups, reduction, stride=1, 330 | downsample_kernel_size=1, downsample_padding=0): 331 | downsample = None 332 | if stride != 1 or self.inplanes != planes * block.expansion: 333 | downsample = nn.Sequential( 334 | nn.Conv2d(self.inplanes, planes * block.expansion, 335 | kernel_size=downsample_kernel_size, stride=stride, 336 | padding=downsample_padding, bias=False), 337 | nn.BatchNorm2d(planes * block.expansion), 338 | ) 339 | 340 | layers = [] 341 | layers.append(block(self.inplanes, planes, groups, reduction, stride, 342 | downsample)) 343 | self.inplanes = planes * block.expansion 344 | for i in range(1, blocks): 345 | layers.append(block(self.inplanes, planes, groups, reduction)) 346 | 347 | return nn.Sequential(*layers) 348 | 349 | def features(self, x): 350 | x = self.layer0(x) 351 | x = self.layer1(x) 352 | x = self.layer2(x) 353 | x = self.layer3(x) 354 | x = self.layer4(x) 355 | #print(x.size()) 356 | return x 357 | 358 | def logits(self, x): 359 | x = self.avg_pool(x) 360 | if self.dropout is not None: 361 | x = self.dropout(x) 362 | x = x.view(x.size(0), -1) 363 | x = self.last_linear(x) 364 | return x 365 | 366 | def forward(self, x): 367 | x = self.features(x) 368 
| x = self.logits(x) 369 | return x 370 | 371 | 372 | def initialize_pretrained_model(model, num_classes, settings): 373 | assert num_classes == settings['num_classes'], \ 374 | 'num_classes should be {}, but is {}'.format( 375 | settings['num_classes'], num_classes) 376 | model.load_state_dict(model_zoo.load_url(settings['url'])) 377 | model.input_space = settings['input_space'] 378 | model.input_size = settings['input_size'] 379 | model.input_range = settings['input_range'] 380 | model.mean = settings['mean'] 381 | model.std = settings['std'] 382 | 383 | 384 | def senet154(num_classes=1000, pretrained='imagenet'): 385 | model = SENet(SEBottleneck, [3, 8, 36, 3], groups=64, reduction=16, 386 | dropout_p=0.2, num_classes=num_classes) 387 | if pretrained is not None: 388 | settings = pretrained_settings['senet154'][pretrained] 389 | initialize_pretrained_model(model, num_classes, settings) 390 | return model 391 | 392 | 393 | def se_resnet50(num_classes=1000, pretrained='imagenet'): 394 | model = SENet(SEResNetBottleneck, [3, 4, 6, 3], groups=1, reduction=16, 395 | dropout_p=None, inplanes=64, input_3x3=False, 396 | downsample_kernel_size=1, downsample_padding=0, 397 | num_classes=num_classes) 398 | if pretrained is not None: 399 | settings = pretrained_settings['se_resnet50'][pretrained] 400 | initialize_pretrained_model(model, num_classes, settings) 401 | return model 402 | 403 | 404 | def se_resnet101(num_classes=1000, pretrained='imagenet'): 405 | model = SENet(SEResNetBottleneck, [3, 4, 23, 3], groups=1, reduction=16, 406 | dropout_p=None, inplanes=64, input_3x3=False, 407 | downsample_kernel_size=1, downsample_padding=0, 408 | num_classes=num_classes) 409 | if pretrained is not None: 410 | settings = pretrained_settings['se_resnet101'][pretrained] 411 | initialize_pretrained_model(model, num_classes, settings) 412 | return model 413 | 414 | 415 | def se_resnet152(num_classes=1000, pretrained='imagenet'): 416 | model = SENet(SEResNetBottleneck, [3, 8, 36, 3], groups=1, reduction=16, 417 | dropout_p=None, inplanes=64, input_3x3=False, 418 | downsample_kernel_size=1, downsample_padding=0, 419 | num_classes=num_classes) 420 | if pretrained is not None: 421 | settings = pretrained_settings['se_resnet152'][pretrained] 422 | initialize_pretrained_model(model, num_classes, settings) 423 | return model 424 | 425 | 426 | def se_resnext50_32x4d(num_classes=1000, pretrained='imagenet'): 427 | model = SENet(SEResNeXtBottleneck, [3, 4, 6, 3], groups=32, reduction=16, 428 | dropout_p=None, inplanes=64, input_3x3=False, 429 | downsample_kernel_size=1, downsample_padding=0, 430 | num_classes=num_classes) 431 | if pretrained is not None: 432 | settings = pretrained_settings['se_resnext50_32x4d'][pretrained] 433 | initialize_pretrained_model(model, num_classes, settings) 434 | return model 435 | 436 | 437 | def se_resnext101_32x4d(num_classes=1000, pretrained='imagenet'): 438 | model = SENet(SEResNeXtBottleneck, [3, 4, 23, 3], groups=32, reduction=16, 439 | dropout_p=None, inplanes=64, input_3x3=False, 440 | downsample_kernel_size=1, downsample_padding=0, 441 | num_classes=num_classes) 442 | if pretrained is not None: 443 | settings = pretrained_settings['se_resnext101_32x4d'][pretrained] 444 | initialize_pretrained_model(model, num_classes, settings) 445 | return model 446 | 447 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from distutils.core import setup 2 | 3 | setup( 
4 | version='1.0', 5 | description='image retrieval', 6 | install_requires=[ 7 | 'torch==1.0.0', 8 | 'torchvision==0.2.1', 9 | ] 10 | ) 11 | -------------------------------------------------------------------------------- /trainer.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | import numpy as np 5 | 6 | 7 | def save(model, ckpt_num, dir_name): 8 | os.makedirs(dir_name, exist_ok=True) 9 | if torch.cuda.device_count() > 1: 10 | torch.save(model.module.state_dict(), os.path.join(dir_name, 'model_%s' % ckpt_num)) 11 | else: 12 | torch.save(model.state_dict(), os.path.join(dir_name, 'model_%s' % ckpt_num)) 13 | print('model saved!') 14 | 15 | 16 | def fit(train_loader, model, loss_fn, optimizer, scheduler, nb_epoch, 17 | device, log_interval, start_epoch=0, save_model_to='/tmp/save_model_to'): 18 | """ 19 | Loaders, model, loss function and metrics should work together for a given task, 20 | i.e. The model should be able to process data output of loaders, 21 | loss function should process target output of loaders and outputs from the model 22 | 23 | Examples: Classification: batch loader, classification model, NLL loss, accuracy metric 24 | Siamese network: Siamese loader, siamese model, contrastive loss 25 | Online triplet learning: batch loader, embedding model, online triplet loss 26 | """ 27 | 28 | # Save pre-trained model 29 | save(model, 0, save_model_to) 30 | 31 | for epoch in range(0, start_epoch): 32 | scheduler.step() 33 | 34 | for epoch in range(start_epoch, nb_epoch): 35 | scheduler.step() 36 | 37 | # Train stage 38 | train_loss = train_epoch(train_loader, model, loss_fn, optimizer, device, log_interval) 39 | 40 | log_dict = {'epoch': epoch + 1, 41 | 'epoch_total': nb_epoch, 42 | 'loss': float(train_loss), 43 | } 44 | 45 | message = 'Epoch: {}/{}. Train set: Average loss: {:.4f}'.format(epoch + 1, nb_epoch, train_loss) 46 | 47 | print(message) 48 | print(log_dict) 49 | if (epoch + 1) % 5 == 0: 50 | save(model, epoch + 1, save_model_to) 51 | 52 | 53 | def train_epoch(train_loader, model, loss_fn, optimizer, device, log_interval): 54 | model.train() 55 | total_loss = 0 56 | 57 | for batch_idx, (data, target) in enumerate(train_loader): 58 | target = target if len(target) > 0 else None 59 | if not type(data) in (tuple, list): 60 | data = (data,) 61 | 62 | data = tuple(d.to(device) for d in data) 63 | if target is not None: 64 | target = target.to(device) 65 | 66 | optimizer.zero_grad() 67 | if loss_fn.cross_entropy_flag: 68 | output_embedding, output_cross_entropy = model(*data) 69 | blended_loss, losses = loss_fn.calculate_loss(target, output_embedding, output_cross_entropy) 70 | else: 71 | output_embedding = model(*data) 72 | blended_loss, losses = loss_fn.calculate_loss(target, output_embedding) 73 | total_loss += blended_loss.item() 74 | blended_loss.backward() 75 | 76 | optimizer.step() 77 | 78 | # Print log 79 | if batch_idx % log_interval == 0: 80 | message = 'Train: [{}/{} ({:.0f}%)]'.format( 81 | batch_idx * len(data[0]), len(train_loader.dataset), 100. * batch_idx / len(train_loader)) 82 | for name, value in losses.items(): 83 | message += '\t{}: {:.6f}'.format(name, np.mean(value)) 84 | 85 | print(message) 86 | 87 | total_loss /= (batch_idx + 1) 88 | return total_loss 89 | --------------------------------------------------------------------------------
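The sketches below are minimal usage notes on how the modules above compose; every shape, size, and value in them is an illustrative assumption, not a fixture of the repository. First, the embedding network from networks.py maps an image batch to fixed-size descriptors (`use_pretrained=False` here only so a quick shape check does not download ImageNet weights):

```python
import torch

from networks import EmbeddingNetwork

# Hypothetical configuration mirroring the defaults in main.py.
model = EmbeddingNetwork(model_name='densenet161', embedding_dim=128,
                         feature_extracting=False, use_pretrained=False,
                         attention_flag=True, cross_entropy_flag=False)

with torch.no_grad():
    embeddings = model(torch.randn(2, 3, 224, 224))  # two 224x224 RGB images

print(embeddings.shape)  # torch.Size([2, 128])
```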
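The self-attention block starts out as an identity mapping: `gamma` is initialized to zero, so `out = gamma * attention_output + x` equals the input until training moves `gamma`. A sketch with an assumed DenseNet-161 feature map (2208 channels, 7x7 spatial):

```python
import torch

from networks import SelfAttention

attention = SelfAttention(in_dim=2208, activation='relu')
x = torch.randn(2, 2208, 7, 7)

out = attention(x)
print(out.shape)               # torch.Size([2, 2208, 7, 7]) -- spatial shape is preserved
print(torch.allclose(out, x))  # True: gamma == 0 at initialization
```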
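Training batches come from BalancedBatchSampler (datasets.py), which yields `n_classes * n_samples` indices per batch, a few samples from each of a few classes, mirroring how main.py wires it into a DataLoader. A toy run on an assumed small TensorDataset:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

from datasets import BalancedBatchSampler

# Toy dataset: 40 samples over 8 classes, 5 samples per class.
labels = torch.tensor([i // 5 for i in range(40)])
dataset = TensorDataset(torch.randn(40, 16), labels)

sampler = BalancedBatchSampler(dataset, n_classes=4, n_samples=2)
loader = DataLoader(dataset, batch_sampler=sampler)

for _, batch_labels in loader:
    print(batch_labels.tolist())  # e.g. [6, 6, 1, 1, 3, 3, 0, 0] -- 4 classes x 2 samples
```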
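With such a batch, BlendedLoss (losses.py) wires the sampled labels into the chosen metric loss. For the N-pair loss, one (anchor, positive) pair is drawn per class and the positives of the other classes act as negatives, giving log(1 + sum_k exp(a·n_k - a·p)) averaged over anchors, plus an L2 penalty on the embeddings. A self-contained step on random embeddings (toy sizes):

```python
import torch

from losses import BlendedLoss, N_PAIR

embeddings = torch.randn(8, 128, requires_grad=True)  # 4 classes x 2 samples
target = torch.tensor([0, 0, 1, 1, 2, 2, 3, 3])

loss_fn = BlendedLoss(N_PAIR, cross_entropy_flag=False)
blended_loss, loss_dict = loss_fn.calculate_loss(target, embeddings)
blended_loss.backward()

print(loss_dict)  # {'n-pair-loss': [<scalar>]}
```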
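The overflow guard inside AngularLoss.angular_loss is the standard log-sum-exp shift, log(1 + sum_k exp(x_k)) = t + log(exp(-t) + sum_k exp(x_k - t)) with t = max_k x_k; subtracting the row maximum keeps exp() in range. A quick numerical check with made-up logits:

```python
import torch

x = torch.tensor([[5., 80., 90.]])  # similarity terms for one anchor (toy values)

naive = torch.log(1 + torch.exp(x).sum(dim=1))  # exp(90) overflows float32 -> inf
t = x.max(dim=1)[0]
stable = t + torch.log(torch.exp(-t) + torch.exp(x - t.unsqueeze(1)).sum(dim=1))

print(naive.item(), stable.item())  # inf vs. ~90.00005
```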
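At inference time, postprocess (inference.py) centers the pooled descriptors and L2-normalizes them, so the dot products in calculate_sim_matrix are cosine similarities of the centered vectors. A toy ranking, with arbitrary sizes:

```python
import numpy as np

from inference import calculate_sim_matrix, postprocess

rng = np.random.RandomState(0)
query_vecs = rng.randn(3, 128)       # 3 queries
reference_vecs = rng.randn(10, 128)  # 10 reference images

q, r = postprocess(query_vecs, reference_vecs)
print(np.allclose(np.linalg.norm(q, axis=1), 1.0),
      np.allclose(np.linalg.norm(r, axis=1), 1.0))  # True True -- unit-length rows

sim_matrix = calculate_sim_matrix(query_vecs, reference_vecs)  # shape (3, 10), values in [-1, 1]
ranking = np.flip(np.argsort(sim_matrix, axis=1), axis=1)      # best reference first, as in retrieve()
```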
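The two expansion steps reshape the descriptors before ranking: db_augmentation replaces every vector with a log-weighted sum of itself and its top-k neighbours, while average_query_expansion concatenates each vector with the mean of its top-k neighbours (doubling the descriptor dimension). The DBA weights decay geometrically from 1 to 0.01:

```python
import numpy as np

# Same weight schedule as db_augmentation in inference.py.
top_k = 10
weights = np.logspace(0, -2., top_k + 1)
print(weights.round(3))
# -> roughly [1., 0.631, 0.398, 0.251, 0.158, 0.1, 0.063, 0.04, 0.025, 0.016, 0.01]
# Index 0 weights the vector itself; indices 1..k give geometrically
# decaying weight to the nearest neighbours.
```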
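Finally, networks.py ships a GeM pooling helper that is left commented out in EmbeddingNetwork.forward. GeM raises activations to the power p before average pooling, interpolating between average pooling (p = 1) and max pooling (p -> infinity):

```python
import torch
import torch.nn.functional as F

from networks import gem

x = torch.rand(2, 2048, 7, 7) + 0.1  # positive activations, toy shape

print(gem(x, p=3).shape)  # torch.Size([2, 2048, 1, 1])
print(torch.allclose(gem(x, p=1), F.adaptive_avg_pool2d(x, 1)))  # True: p = 1 is plain average pooling
```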