├── .gitignore ├── README.md ├── architectures ├── __init__.py ├── bninception.py ├── googlenet.py ├── resnet50.py └── resnet50_diml.py ├── batchminer ├── __init__.py ├── distance.py ├── intra_random.py ├── lifted.py ├── npair.py ├── parametric.py ├── random.py ├── random_distance.py ├── rho_distance.py ├── semihard.py └── softhard.py ├── criteria ├── __init__.py ├── adversarial_separation.py ├── angular.py ├── arcface.py ├── contrastive.py ├── histogram.py ├── lifted.py ├── margin.py ├── margin_diml.py ├── multisimilarity.py ├── multisimilarity_diml.py ├── npair.py ├── proxynca.py ├── quadruplet.py ├── snr.py ├── softmax.py ├── softtriplet.py └── triplet.py ├── datasampler ├── __init__.py ├── class_random_sampler.py ├── d2_coreset_sampler.py ├── disthist_batchmatch_sampler.py ├── fid_batchmatch_sampler.py ├── greedy_coreset_sampler.py ├── random_sampler.py └── samplers.py ├── datasets ├── __init__.py ├── basic_dataset_scaffold.py ├── cars196.py ├── cub200.py └── stanford_online_products.py ├── evaluation ├── __init__.py ├── eval_diml.py └── metrics.py ├── figs └── intro.gif ├── parameters.py ├── scripts ├── baselines │ ├── cars_runs.sh │ ├── cub_runs.sh │ └── sop_runs.sh └── diml │ ├── test_diml.sh │ └── train_diml.sh ├── test_diml.py ├── train_baseline.py ├── train_diml.py └── utilities ├── __init__.py ├── diml.py ├── logger.py └── misc.py /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | data 3 | *.pyc 4 | Training_Results 5 | wandb 6 | diva_main.py 7 | cars196 8 | cub200 9 | online_products 10 | .vscode 11 | *.jpg 12 | .nvimlog 13 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # DIML 2 | Created by [Wenliang Zhao](https://wl-zhao.github.io/)\*, [Yongming Rao](https://raoyongming.github.io/)\*, [Ziyi Wang](https://github.com/LavenderLA), [Jiwen Lu](https://scholar.google.com/citations?user=TN8uDQoAAAAJ&hl=en&authuser=1), [Jie Zhou](https://scholar.google.com/citations?user=6a79aPwAAAAJ&hl=en&authuser=1) 3 | 4 | This repository contains PyTorch implementation for paper **Towards Interpretable Deep Metric Learning with Structural Matching** (ICCV 2021). 5 | 6 | We present a deep interpretable metric learning (DIML) that adopts a structural matching strategy to explicitly aligns the spatial embeddings by computing an optimal matching flow between feature maps of the two images. Our method enables deep models to learn metrics in a more human-friendly way, where the similarity of two images can be decomposed to several part-wise similarities and their contributions to the overall similarity. Our method is model-agnostic, which can be applied to off-the-shelf backbone networks and metric learning methods. 7 | 8 | ![intro](figs/intro.gif) 9 | 10 | [[arXiv](https://arxiv.org/abs/2108.05889)] 11 | ## Usage 12 | ### Requirement 13 | - python3 14 | - PyTorch 1.7 15 | 16 | ### Dataset Preparation 17 | Please follow the instruction in [RevisitDML](https://github.com/Confusezius/Revisiting_Deep_Metric_Learning_PyTorch) to download the datasets and put all the datasets in `data` folder. The structure should be: 18 | ``` 19 | data 20 | ├── cars196 21 | │   └── images 22 | ├── cub200 23 | │   └── images 24 | └── online_products 25 | ├── images 26 | └── Info_Files 27 | ``` 28 | 29 | ### Training & Evaluation 30 | To train the baseline models, run the scripts in `scripts/baselines`. For example: 31 | ```bash 32 | CUDA_VISIBLE_DEVICES=0 ./script/baselines/cub_runs.sh 33 | ``` 34 | The checkpoints are saved in Training_Results folder. 35 | 36 | To test the baseline models with our proposed DIML, first edit the checkpoint paths in `test_diml.py`, then run 37 | ```bash 38 | CUDA_VISIBLE_DEVICES=0 ./scripts/diml/test_diml.sh cub200 39 | ``` 40 | The results will be written to `test_results/test_diml_.csv` in CSV format. 41 | 42 | You can also incorporate DIML into the training objectives. We provide two examples which apply DIML to Margin and Multi-Similarity loss. To train DIML models, run 43 | ```bash 44 | # ./scripts/diml/train_diml.sh 45 | # where loss could be margin_diml or multisimilarity_diml 46 | # e.g. 47 | CUDA_VISIBLE_DEVICES=0 ./scripts/diml/train_diml.sh cub200 112 margin_diml 150 48 | ``` 49 | 50 | ## Acknowledgement 51 | The code is based on [RevisitDML](https://github.com/Confusezius/Revisiting_Deep_Metric_Learning_PyTorch). 52 | 53 | 54 | ## Citation 55 | If you find our work useful in your research, please consider citing: 56 | ``` 57 | @inproceedings{zhao2021towards, 58 | title={Towards Interpretable Deep Metric Learning with Structural Matching}, 59 | author={Zhao, Wenliang and Rao, Yongming and Wang, Ziyi and Lu, Jiwen and Zhou, Jie}, 60 | booktitle={ICCV}, 61 | year={2021} 62 | } 63 | ``` 64 | -------------------------------------------------------------------------------- /architectures/__init__.py: -------------------------------------------------------------------------------- 1 | import architectures.resnet50 2 | import architectures.googlenet 3 | import architectures.bninception 4 | import architectures.resnet50_diml 5 | 6 | def select(arch, opt): 7 | if 'resnet50_diml' in arch: 8 | return resnet50_diml.Network(opt) 9 | if 'resnet50' in arch: 10 | return resnet50.Network(opt) 11 | if 'googlenet' in arch: 12 | return googlenet.Network(opt) 13 | if 'bninception' in arch: 14 | return bninception.Network(opt) 15 | -------------------------------------------------------------------------------- /architectures/bninception.py: -------------------------------------------------------------------------------- 1 | """ 2 | The network architectures and weights are adapted and used from the great https://github.com/Cadene/pretrained-models.pytorch. 3 | """ 4 | import torch, torch.nn as nn, torch.nn.functional as F 5 | import pretrainedmodels as ptm 6 | 7 | """=============================================================""" 8 | class Network(torch.nn.Module): 9 | def __init__(self, opt, return_embed_dict=False): 10 | super(Network, self).__init__() 11 | 12 | self.pars = opt 13 | self.model = ptm.__dict__['bninception'](num_classes=1000, pretrained='imagenet') 14 | self.model.last_linear = torch.nn.Linear(self.model.last_linear.in_features, opt.embed_dim) 15 | if '_he' in opt.arch: 16 | torch.nn.init.kaiming_normal_(self.model.last_linear.weight, mode='fan_out') 17 | torch.nn.init.constant_(self.model.last_linear.bias, 0) 18 | 19 | if 'frozen' in opt.arch: 20 | for module in filter(lambda m: type(m) == nn.BatchNorm2d, self.model.modules()): 21 | module.eval() 22 | module.train = lambda _: None 23 | 24 | self.return_embed_dict = return_embed_dict 25 | 26 | self.pool_base = torch.nn.AdaptiveAvgPool2d(1) 27 | self.pool_aux = torch.nn.AdaptiveMaxPool2d(1) if 'double' in opt.arch else None 28 | 29 | self.name = opt.arch 30 | 31 | self.out_adjust = None 32 | 33 | def forward(self, x, warmup=False, **kwargs): 34 | x = self.model.features(x) 35 | y = self.pool_base(x) 36 | if self.pool_aux is not None: 37 | y += self.pool_aux(x) 38 | if warmup: 39 | y,x = y.detach(), x.detach() 40 | z = self.model.last_linear(y.view(len(x),-1)) 41 | if 'normalize' in self.name: 42 | z = F.normalize(z, dim=-1) 43 | if self.out_adjust and not self.training: 44 | z = self.out_adjust(z) 45 | return z,(y,x) 46 | 47 | def functional_forward(self, x): 48 | pass 49 | -------------------------------------------------------------------------------- /architectures/googlenet.py: -------------------------------------------------------------------------------- 1 | """ 2 | The network architectures and weights are adapted and used from the great https://github.com/Cadene/pretrained-models.pytorch. 3 | """ 4 | import torch, torch.nn as nn 5 | import torchvision.models as mod 6 | 7 | 8 | 9 | 10 | 11 | """=============================================================""" 12 | class Network(torch.nn.Module): 13 | def __init__(self, opt): 14 | super(Network, self).__init__() 15 | 16 | self.pars = opt 17 | self.model = mod.googlenet(pretrained=True) 18 | 19 | self.model.last_linear = torch.nn.Linear(self.model.fc.in_features, opt.embed_dim) 20 | self.model.fc = self.model.last_linear 21 | 22 | self.name = opt.arch 23 | 24 | def forward(self, x): 25 | x = self.model(x) 26 | if not 'normalize' in self.pars.arch: 27 | return x 28 | return torch.nn.functional.normalize(x, dim=-1) 29 | -------------------------------------------------------------------------------- /architectures/resnet50.py: -------------------------------------------------------------------------------- 1 | """ 2 | The network architectures and weights are adapted and used from the great https://github.com/Cadene/pretrained-models.pytorch. 3 | """ 4 | import torch, torch.nn as nn 5 | import pretrainedmodels as ptm 6 | 7 | """=============================================================""" 8 | class Network(torch.nn.Module): 9 | def __init__(self, opt): 10 | super(Network, self).__init__() 11 | 12 | self.pars = opt 13 | self.model = ptm.__dict__['resnet50'](num_classes=1000, pretrained='imagenet' if not opt.not_pretrained else None) 14 | 15 | self.name = opt.arch 16 | 17 | if 'frozen' in opt.arch: 18 | for module in filter(lambda m: type(m) == nn.BatchNorm2d, self.model.modules()): 19 | module.eval() 20 | module.train = lambda _: None 21 | 22 | self.model.last_linear = torch.nn.Linear(self.model.last_linear.in_features, opt.embed_dim) 23 | 24 | self.layer_blocks = nn.ModuleList([self.model.layer1, self.model.layer2, self.model.layer3, self.model.layer4]) 25 | 26 | self.out_adjust = None 27 | 28 | 29 | def forward(self, x, **kwargs): 30 | x = self.model.maxpool(self.model.relu(self.model.bn1(self.model.conv1(x)))) 31 | for layerblock in self.layer_blocks: 32 | x = layerblock(x) 33 | no_avg_feat = x 34 | x = self.model.avgpool(x) 35 | enc_out = x = x.view(x.size(0),-1) 36 | 37 | x = self.model.last_linear(x) 38 | 39 | if 'normalize' in self.pars.arch: 40 | x = torch.nn.functional.normalize(x, dim=-1) 41 | if self.out_adjust and not self.train: 42 | x = self.out_adjust(x) 43 | 44 | return x, (enc_out, no_avg_feat) 45 | -------------------------------------------------------------------------------- /architectures/resnet50_diml.py: -------------------------------------------------------------------------------- 1 | """ 2 | The network architectures and weights are adapted and used from the great https://github.com/Cadene/pretrained-models.pytorch. 3 | """ 4 | import torch, torch.nn as nn 5 | import pretrainedmodels as ptm 6 | 7 | """=============================================================""" 8 | class Network(torch.nn.Module): 9 | def __init__(self, opt): 10 | super(Network, self).__init__() 11 | 12 | self.pars = opt 13 | self.model = ptm.__dict__['resnet50'](num_classes=1000, pretrained='imagenet' if not opt.not_pretrained else None) 14 | 15 | self.name = opt.arch 16 | 17 | if 'frozen' in opt.arch: 18 | for module in filter(lambda m: type(m) == nn.BatchNorm2d, self.model.modules()): 19 | module.eval() 20 | module.train = lambda _: None 21 | 22 | self.model.last_linear = torch.nn.Conv2d(self.model.last_linear.in_features, opt.embed_dim, 1) 23 | 24 | self.layer_blocks = nn.ModuleList([self.model.layer1, self.model.layer2, self.model.layer3, self.model.layer4]) 25 | 26 | self.out_adjust = None 27 | 28 | def forward(self, x, **kwargs): 29 | x = self.model.maxpool(self.model.relu(self.model.bn1(self.model.conv1(x)))) 30 | for layerblock in self.layer_blocks: 31 | x = layerblock(x) 32 | no_avg_feat = x 33 | 34 | x = torch.nn.functional.upsample(x, size=(16, 16), mode='bilinear', align_corners=True) 35 | x = torch.nn.functional.adaptive_avg_pool2d(x, output_size=(4, 4)) 36 | 37 | per_point_pred = self.model.last_linear(x) 38 | 39 | x = self.model.avgpool(no_avg_feat) 40 | enc_out = x.view(x.size(0), -1) 41 | 42 | return per_point_pred, (enc_out, no_avg_feat) 43 | -------------------------------------------------------------------------------- /batchminer/__init__.py: -------------------------------------------------------------------------------- 1 | from batchminer import random_distance, intra_random 2 | from batchminer import lifted, rho_distance, softhard, npair, parametric, random, semihard, distance 3 | 4 | BATCHMINING_METHODS = {'random':random, 5 | 'semihard':semihard, 6 | 'softhard':softhard, 7 | 'distance':distance, 8 | 'rho_distance':rho_distance, 9 | 'npair':npair, 10 | 'parametric':parametric, 11 | 'lifted':lifted, 12 | 'random_distance': random_distance, 13 | 'intra_random': intra_random} 14 | 15 | 16 | def select(batchminername, opt): 17 | ##### 18 | if batchminername not in BATCHMINING_METHODS: raise NotImplementedError('Batchmining {} not available!'.format(batchminername)) 19 | 20 | batchmine_lib = BATCHMINING_METHODS[batchminername] 21 | 22 | return batchmine_lib.BatchMiner(opt) 23 | -------------------------------------------------------------------------------- /batchminer/distance.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch, torch.nn as nn, torch.nn.functional as F 3 | import batchminer 4 | 5 | 6 | class BatchMiner(): 7 | def __init__(self, opt): 8 | self.par = opt 9 | self.lower_cutoff = opt.miner_distance_lower_cutoff 10 | self.upper_cutoff = opt.miner_distance_upper_cutoff 11 | self.name = 'distance' 12 | 13 | def __call__(self, batch, labels, tar_labels=None, return_distances=False, distances=None): 14 | if isinstance(labels, torch.Tensor): labels = labels.detach().cpu().numpy() 15 | # bs, dim = batch.shape 16 | 17 | if distances is None: 18 | distances = self.pdist(batch.detach()).clamp(min=self.lower_cutoff) 19 | 20 | dim = 128 21 | bs = distances.size(0) 22 | sel_d = distances.shape[-1] 23 | 24 | positives, negatives = [],[] 25 | labels_visited = [] 26 | anchors = [] 27 | 28 | tar_labels = labels if tar_labels is None else tar_labels 29 | 30 | for i in range(bs): 31 | neg = tar_labels!=labels[i]; pos = tar_labels==labels[i] 32 | 33 | anchors.append(i) 34 | q_d_inv = self.inverse_sphere_distances(dim, bs, distances[i], tar_labels, labels[i]) 35 | negatives.append(np.random.choice(sel_d,p=q_d_inv)) 36 | 37 | if np.sum(pos)>0: 38 | #Sample positives randomly 39 | if np.sum(pos)>1: pos[i] = 0 40 | positives.append(np.random.choice(np.where(pos)[0])) 41 | #Sample negatives by distance 42 | 43 | sampled_triplets = [[a,p,n] for a,p,n in zip(anchors, positives, negatives)] 44 | 45 | if return_distances: 46 | return sampled_triplets, distances 47 | else: 48 | return sampled_triplets 49 | 50 | 51 | def inverse_sphere_distances(self, dim, bs, anchor_to_all_dists, labels, anchor_label): 52 | dists = anchor_to_all_dists 53 | 54 | #negated log-distribution of distances of unit sphere in dimension 55 | log_q_d_inv = ((2.0 - float(dim)) * torch.log(dists) - (float(dim-3) / 2) * torch.log(1.0 - 0.25 * (dists.pow(2)))) 56 | log_q_d_inv[np.where(labels==anchor_label)[0]] = 0 57 | 58 | q_d_inv = torch.exp(log_q_d_inv - torch.max(log_q_d_inv)) # - max(log) for stability 59 | q_d_inv[np.where(labels==anchor_label)[0]] = 0 60 | 61 | ### NOTE: Cutting of values with high distances made the results slightly worse. It can also lead to 62 | # errors where there are no available negatives (for high samples_per_class cases). 63 | # q_d_inv[np.where(dists.detach().cpu().numpy()>self.upper_cutoff)[0]] = 0 64 | 65 | q_d_inv = q_d_inv/q_d_inv.sum() 66 | return q_d_inv.detach().cpu().numpy() 67 | 68 | 69 | def pdist(self, A): 70 | prod = torch.mm(A, A.t()) 71 | norm = prod.diag().unsqueeze(1).expand_as(prod) 72 | res = (norm + norm.t() - 2 * prod).clamp(min = 0) 73 | return res.sqrt() 74 | -------------------------------------------------------------------------------- /batchminer/intra_random.py: -------------------------------------------------------------------------------- 1 | import numpy as np, torch 2 | import itertools as it 3 | import random 4 | 5 | class BatchMiner(): 6 | def __init__(self, opt): 7 | self.par = opt 8 | self.name = 'random' 9 | 10 | def __call__(self, batch, labels): 11 | if isinstance(labels, torch.Tensor): labels = labels.detach().cpu().numpy() 12 | unique_classes = np.unique(labels) 13 | indices = np.arange(len(batch)) 14 | class_dict = {i:indices[labels==i] for i in unique_classes} 15 | 16 | sampled_triplets = [] 17 | for cls in np.random.choice(list(class_dict.keys()), len(labels), replace=True): 18 | a,p,n = np.random.choice(class_dict[cls], 3, replace=True) 19 | sampled_triplets.append((a,p,n)) 20 | 21 | return sampled_triplets 22 | -------------------------------------------------------------------------------- /batchminer/lifted.py: -------------------------------------------------------------------------------- 1 | import numpy as np, torch 2 | 3 | class BatchMiner(): 4 | def __init__(self, opt): 5 | self.par = opt 6 | self.name = 'lifted' 7 | 8 | def __call__(self, batch, labels): 9 | if isinstance(labels, torch.Tensor): labels = labels.detach().cpu().numpy() 10 | 11 | ### 12 | anchors, positives, negatives = [], [], [] 13 | list(range(len(batch))) 14 | 15 | for i in range(len(batch)): 16 | anchor = i 17 | pos = labels==labels[anchor] 18 | 19 | ### 20 | if np.sum(pos)>1: 21 | anchors.append(anchor) 22 | positive_set = np.where(pos)[0] 23 | positive_set = positive_set[positive_set!=anchor] 24 | positives.append(positive_set) 25 | 26 | ### 27 | negatives = [] 28 | for anchor,positive_set in zip(anchors, positives): 29 | neg_idxs = [i for i in range(len(batch)) if i not in [anchor]+list(positive_set)] 30 | negative_set = np.arange(len(batch))[neg_idxs] 31 | negatives.append(negative_set) 32 | 33 | return anchors, positives, negatives 34 | -------------------------------------------------------------------------------- /batchminer/npair.py: -------------------------------------------------------------------------------- 1 | import numpy as np, torch 2 | class BatchMiner(): 3 | def __init__(self, opt): 4 | self.par = opt 5 | self.name = 'npair' 6 | 7 | def __call__(self, batch, labels): 8 | if isinstance(labels, torch.Tensor): labels = labels.detach().cpu().numpy() 9 | 10 | anchors, positives, negatives = [],[],[] 11 | 12 | for i in range(len(batch)): 13 | anchor = i 14 | pos = labels==labels[anchor] 15 | 16 | if np.sum(pos)>1: 17 | anchors.append(anchor) 18 | avail_positive = np.where(pos)[0] 19 | avail_positive = avail_positive[avail_positive!=anchor] 20 | positive = np.random.choice(avail_positive) 21 | positives.append(positive) 22 | 23 | ### 24 | negatives = [] 25 | for anchor,positive in zip(anchors, positives): 26 | # neg_idxs = [i for i in range(len(batch)) if i not in [anchor, positive] and labels[i] != labels[anchor]] 27 | neg_idxs = [i for i in range(len(batch)) if i not in [anchor, positive]] 28 | negative_set = np.arange(len(batch))[neg_idxs] 29 | negatives.append(negative_set) 30 | 31 | return anchors, positives, negatives 32 | -------------------------------------------------------------------------------- /batchminer/parametric.py: -------------------------------------------------------------------------------- 1 | import numpy as np, torch 2 | 3 | 4 | class BatchMiner(): 5 | def __init__(self, opt): 6 | self.par = opt 7 | self.mode = opt.miner_parametric_mode 8 | self.n_support = opt.miner_parametric_n_support 9 | self.support_lim = opt.miner_parametric_support_lim 10 | self.name = 'parametric' 11 | 12 | ### 13 | self.set_sample_distr() 14 | 15 | 16 | 17 | def __call__(self, batch, labels): 18 | bs = batch.shape[0] 19 | sample_distr = self.sample_distr 20 | 21 | if isinstance(labels, torch.Tensor): labels = labels.detach().cpu().numpy() 22 | 23 | ### 24 | distances = self.pdist(batch.detach()) 25 | 26 | p_assigns = np.sum((distances.cpu().numpy().reshape(-1)>self.support[1:-1].reshape(-1,1)).T,axis=1).reshape(distances.shape) 27 | outside_support_lim = (distances.cpu().numpy().reshape(-1)self.support_lim[1]) 28 | outside_support_lim = outside_support_lim.reshape(distances.shape) 29 | 30 | sample_ps = sample_distr[p_assigns] 31 | sample_ps[outside_support_lim] = 0 32 | 33 | ### 34 | anchors, labels_visited = [], [] 35 | positives, negatives = [],[] 36 | 37 | ### 38 | for i in range(bs): 39 | neg = labels!=labels[i]; pos = labels==labels[i] 40 | 41 | if np.sum(pos)>1: 42 | anchors.append(i) 43 | 44 | #Sample positives randomly 45 | pos[i] = 0 46 | positives.append(np.random.choice(np.where(pos)[0])) 47 | 48 | #Sample negatives by distance 49 | sample_p = sample_ps[i][neg] 50 | sample_p = sample_p/sample_p.sum() 51 | negatives.append(np.random.choice(np.arange(bs)[neg],p=sample_p)) 52 | 53 | sampled_triplets = [[a,p,n] for a,p,n in zip(anchors, positives, negatives)] 54 | return sampled_triplets 55 | 56 | 57 | 58 | def pdist(self, A, eps=1e-4): 59 | prod = torch.mm(A, A.t()) 60 | norm = prod.diag().unsqueeze(1).expand_as(prod) 61 | res = (norm + norm.t() - 2 * prod).clamp(min = 0) 62 | return res.clamp(min = eps).sqrt() 63 | 64 | 65 | def set_sample_distr(self): 66 | self.support = np.linspace(self.support_lim[0], self.support_lim[1], self.n_support) 67 | 68 | if self.mode == 'uniform': 69 | self.sample_distr = np.array([1.] * (self.n_support-1)) 70 | 71 | if self.mode == 'hards': 72 | self.sample_distr = self.support.copy() 73 | self.sample_distr[self.support<=0.5] = 1 74 | self.sample_distr[self.support>0.5] = 0 75 | 76 | if self.mode == 'semihards': 77 | self.sample_distr = self.support.copy() 78 | from IPython import embed; embed() 79 | self.sample_distr[(self.support<=0.7) * (self.support>=0.3)] = 1 80 | self.sample_distr[(self.support<0.3) * (self.support>0.7)] = 0 81 | 82 | if self.mode == 'veryhards': 83 | self.sample_distr = self.support.copy() 84 | self.sample_distr[self.support<=0.3] = 1 85 | self.sample_distr[self.support>0.3] = 0 86 | 87 | self.sample_distr = np.clip(self.sample_distr, 1e-15, 1) 88 | self.sample_distr = self.sample_distr/self.sample_distr.sum() 89 | -------------------------------------------------------------------------------- /batchminer/random.py: -------------------------------------------------------------------------------- 1 | import numpy as np, torch 2 | import itertools as it 3 | import random 4 | 5 | class BatchMiner(): 6 | def __init__(self, opt): 7 | self.par = opt 8 | self.name = 'random' 9 | 10 | def __call__(self, batch, labels): 11 | if isinstance(labels, torch.Tensor): labels = labels.detach().cpu().numpy() 12 | unique_classes = np.unique(labels) 13 | indices = np.arange(len(batch)) 14 | class_dict = {i:indices[labels==i] for i in unique_classes} 15 | 16 | sampled_triplets = [list(it.product([x],[x],[y for y in unique_classes if x!=y])) for x in unique_classes] 17 | sampled_triplets = [x for y in sampled_triplets for x in y] 18 | 19 | sampled_triplets = [[x for x in list(it.product(*[class_dict[j] for j in i])) if x[0]!=x[1]] for i in sampled_triplets] 20 | sampled_triplets = [x for y in sampled_triplets for x in y] 21 | 22 | #NOTE: The number of possible triplets is given by #unique_classes*(2*(samples_per_class-1)!)*(#unique_classes-1)*samples_per_class 23 | sampled_triplets = random.sample(sampled_triplets, batch.shape[0]) 24 | return sampled_triplets 25 | -------------------------------------------------------------------------------- /batchminer/random_distance.py: -------------------------------------------------------------------------------- 1 | import numpy as np, torch 2 | 3 | 4 | class BatchMiner(): 5 | def __init__(self, opt): 6 | self.par = opt 7 | self.lower_cutoff = opt.miner_distance_lower_cutoff 8 | self.upper_cutoff = opt.miner_distance_upper_cutoff 9 | self.name = 'distance' 10 | 11 | def __call__(self, batch, labels): 12 | if isinstance(labels, torch.Tensor): labels = labels.detach().cpu().numpy() 13 | labels = labels[np.random.choice(len(labels), len(labels), replace=False)] 14 | 15 | bs = batch.shape[0] 16 | distances = self.pdist(batch.detach()).clamp(min=self.lower_cutoff) 17 | 18 | positives, negatives = [],[] 19 | labels_visited = [] 20 | anchors = [] 21 | 22 | for i in range(bs): 23 | neg = labels!=labels[i]; pos = labels==labels[i] 24 | 25 | if np.sum(pos)>1: 26 | anchors.append(i) 27 | q_d_inv = self.inverse_sphere_distances(batch, distances[i], labels, labels[i]) 28 | #Sample positives randomly 29 | pos[i] = 0 30 | positives.append(np.random.choice(np.where(pos)[0])) 31 | #Sample negatives by distance 32 | negatives.append(np.random.choice(bs,p=q_d_inv)) 33 | 34 | sampled_triplets = [[a,p,n] for a,p,n in zip(anchors, positives, negatives)] 35 | return sampled_triplets 36 | 37 | 38 | def inverse_sphere_distances(self, batch, anchor_to_all_dists, labels, anchor_label): 39 | dists = anchor_to_all_dists 40 | bs,dim = len(dists),batch.shape[-1] 41 | 42 | #negated log-distribution of distances of unit sphere in dimension 43 | log_q_d_inv = ((2.0 - float(dim)) * torch.log(dists) - (float(dim-3) / 2) * torch.log(1.0 - 0.25 * (dists.pow(2)))) 44 | log_q_d_inv[np.where(labels==anchor_label)[0]] = 0 45 | 46 | q_d_inv = torch.exp(log_q_d_inv - torch.max(log_q_d_inv)) # - max(log) for stability 47 | q_d_inv[np.where(labels==anchor_label)[0]] = 0 48 | 49 | ### NOTE: Cutting of values with high distances made the results slightly worse. It can also lead to 50 | # errors where there are no available negatives (for high samples_per_class cases). 51 | # q_d_inv[np.where(dists.detach().cpu().numpy()>self.upper_cutoff)[0]] = 0 52 | 53 | q_d_inv = q_d_inv/q_d_inv.sum() 54 | return q_d_inv.detach().cpu().numpy() 55 | 56 | 57 | def pdist(self, A): 58 | prod = torch.mm(A, A.t()) 59 | norm = prod.diag().unsqueeze(1).expand_as(prod) 60 | res = (norm + norm.t() - 2 * prod).clamp(min = 0) 61 | return res.sqrt() 62 | -------------------------------------------------------------------------------- /batchminer/rho_distance.py: -------------------------------------------------------------------------------- 1 | import numpy as np, torch 2 | 3 | 4 | class BatchMiner(): 5 | def __init__(self, opt): 6 | self.par = opt 7 | self.lower_cutoff = opt.miner_rho_distance_lower_cutoff 8 | self.upper_cutoff = opt.miner_rho_distance_upper_cutoff 9 | self.contrastive_p = opt.miner_rho_distance_cp 10 | 11 | self.name = 'rho_distance' 12 | 13 | def __call__(self, batch, labels, return_distances=False): 14 | if isinstance(labels, torch.Tensor): labels = labels.detach().cpu().numpy() 15 | bs = batch.shape[0] 16 | distances = self.pdist(batch.detach()).clamp(min=self.lower_cutoff) 17 | 18 | positives, negatives = [],[] 19 | labels_visited = [] 20 | anchors = [] 21 | 22 | for i in range(bs): 23 | neg = labels!=labels[i]; pos = labels==labels[i] 24 | 25 | use_contr = np.random.choice(2, p=[1-self.contrastive_p, self.contrastive_p]) 26 | if np.sum(pos)>1: 27 | anchors.append(i) 28 | if use_contr: 29 | positives.append(i) 30 | #Sample negatives by distance 31 | pos[i] = 0 32 | negatives.append(np.random.choice(np.where(pos)[0])) 33 | else: 34 | q_d_inv = self.inverse_sphere_distances(batch, distances[i], labels, labels[i]) 35 | #Sample positives randomly 36 | pos[i] = 0 37 | positives.append(np.random.choice(np.where(pos)[0])) 38 | #Sample negatives by distance 39 | negatives.append(np.random.choice(bs,p=q_d_inv)) 40 | 41 | sampled_triplets = [[a,p,n] for a,p,n in zip(anchors, positives, negatives)] 42 | self.push_triplets = np.sum([m[1]==m[2] for m in labels[sampled_triplets]]) 43 | 44 | if return_distances: 45 | return sampled_triplets, distances 46 | else: 47 | return sampled_triplets 48 | 49 | 50 | def inverse_sphere_distances(self, batch, anchor_to_all_dists, labels, anchor_label): 51 | dists = anchor_to_all_dists 52 | bs,dim = len(dists),batch.shape[-1] 53 | 54 | #negated log-distribution of distances of unit sphere in dimension 55 | log_q_d_inv = ((2.0 - float(dim)) * torch.log(dists) - (float(dim-3) / 2) * torch.log(1.0 - 0.25 * (dists.pow(2)))) 56 | log_q_d_inv[np.where(labels==anchor_label)[0]] = 0 57 | 58 | q_d_inv = torch.exp(log_q_d_inv - torch.max(log_q_d_inv)) # - max(log) for stability 59 | q_d_inv[np.where(labels==anchor_label)[0]] = 0 60 | 61 | ### NOTE: Cutting of values with high distances made the results slightly worse. It can also lead to 62 | # errors where there are no available negatives (for high samples_per_class cases). 63 | # q_d_inv[np.where(dists.detach().cpu().numpy()>self.upper_cutoff)[0]] = 0 64 | 65 | q_d_inv = q_d_inv/q_d_inv.sum() 66 | return q_d_inv.detach().cpu().numpy() 67 | 68 | 69 | def pdist(self, A, eps=1e-4): 70 | prod = torch.mm(A, A.t()) 71 | norm = prod.diag().unsqueeze(1).expand_as(prod) 72 | res = (norm + norm.t() - 2 * prod).clamp(min = 0) 73 | return res.clamp(min = eps).sqrt() 74 | -------------------------------------------------------------------------------- /batchminer/semihard.py: -------------------------------------------------------------------------------- 1 | import numpy as np, torch 2 | 3 | 4 | class BatchMiner(): 5 | def __init__(self, opt): 6 | self.par = opt 7 | self.name = 'semihard' 8 | self.margin = vars(opt)['loss_'+opt.loss+'_margin'] 9 | 10 | def __call__(self, batch, labels, return_distances=False): 11 | if isinstance(labels, torch.Tensor): labels = labels.detach().numpy() 12 | bs = batch.size(0) 13 | #Return distance matrix for all elements in batch (BSxBS) 14 | distances = self.pdist(batch.detach()).detach().cpu().numpy() 15 | 16 | positives, negatives = [], [] 17 | anchors = [] 18 | for i in range(bs): 19 | l, d = labels[i], distances[i] 20 | neg = labels!=l; pos = labels==l 21 | 22 | anchors.append(i) 23 | pos[i] = 0 24 | p = np.random.choice(np.where(pos)[0]) 25 | positives.append(p) 26 | 27 | #Find negatives that violate tripet constraint semi-negatives 28 | neg_mask = np.logical_and(neg,d>d[p]) 29 | neg_mask = np.logical_and(neg_mask,d0: 31 | negatives.append(np.random.choice(np.where(neg_mask)[0])) 32 | else: 33 | negatives.append(np.random.choice(np.where(neg)[0])) 34 | 35 | sampled_triplets = [[a, p, n] for a, p, n in zip(anchors, positives, negatives)] 36 | 37 | if return_distances: 38 | return sampled_triplets, distances 39 | else: 40 | return sampled_triplets 41 | 42 | 43 | def pdist(self, A): 44 | prod = torch.mm(A, A.t()) 45 | norm = prod.diag().unsqueeze(1).expand_as(prod) 46 | res = (norm + norm.t() - 2 * prod).clamp(min = 0) 47 | return res.clamp(min = 0).sqrt() 48 | -------------------------------------------------------------------------------- /batchminer/softhard.py: -------------------------------------------------------------------------------- 1 | import numpy as np, torch 2 | 3 | 4 | class BatchMiner(): 5 | def __init__(self, opt): 6 | self.par = opt 7 | self.name = 'softhard' 8 | 9 | def __call__(self, batch, labels, return_distances=False): 10 | if isinstance(labels, torch.Tensor): labels = labels.detach().numpy() 11 | bs = batch.size(0) 12 | #Return distance matrix for all elements in batch (BSxBS) 13 | distances = self.pdist(batch.detach()).detach().cpu().numpy() 14 | 15 | positives, negatives = [], [] 16 | anchors = [] 17 | for i in range(bs): 18 | l, d = labels[i], distances[i] 19 | neg = labels!=l; pos = labels==l 20 | 21 | if np.sum(pos)>1: 22 | anchors.append(i) 23 | #1 for batchelements with label l 24 | #0 for current anchor 25 | pos[i] = False 26 | 27 | #Find negatives that violate triplet constraint in a hard fashion 28 | neg_mask = np.logical_and(neg,dd[np.where(neg)[0]].min()) 31 | 32 | if pos_mask.sum()>0: 33 | positives.append(np.random.choice(np.where(pos_mask)[0])) 34 | else: 35 | positives.append(np.random.choice(np.where(pos)[0])) 36 | 37 | if neg_mask.sum()>0: 38 | negatives.append(np.random.choice(np.where(neg_mask)[0])) 39 | else: 40 | negatives.append(np.random.choice(np.where(neg)[0])) 41 | 42 | sampled_triplets = [[a, p, n] for a, p, n in zip(anchors, positives, negatives)] 43 | if return_distances: 44 | return sampled_triplets, distances 45 | else: 46 | return sampled_triplets 47 | 48 | 49 | 50 | def pdist(self, A): 51 | prod = torch.mm(A, A.t()) 52 | norm = prod.diag().unsqueeze(1).expand_as(prod) 53 | res = (norm + norm.t() - 2 * prod).clamp(min = 0) 54 | return res.clamp(min = 0).sqrt() 55 | -------------------------------------------------------------------------------- /criteria/__init__.py: -------------------------------------------------------------------------------- 1 | ### Standard DML criteria 2 | from criteria import triplet, margin, proxynca, npair 3 | from criteria import lifted, contrastive, softmax 4 | from criteria import angular, snr, histogram, arcface 5 | from criteria import softtriplet, multisimilarity, quadruplet 6 | from criteria import margin_diml, multisimilarity_diml 7 | ### Non-Standard Criteria 8 | from criteria import adversarial_separation 9 | ### Basic Libs 10 | import copy 11 | 12 | 13 | """=================================================================================================""" 14 | def select(loss, opt, to_optim, batchminer=None): 15 | ##### 16 | losses = {'triplet': triplet, 17 | 'margin':margin, 18 | 'margin_diml': margin_diml, 19 | 'proxynca':proxynca, 20 | 'npair':npair, 21 | 'angular':angular, 22 | 'contrastive':contrastive, 23 | 'lifted':lifted, 24 | 'snr':snr, 25 | 'multisimilarity':multisimilarity, 26 | 'multisimilarity_diml':multisimilarity_diml, 27 | 'histogram':histogram, 28 | 'softmax':softmax, 29 | 'softtriplet':softtriplet, 30 | 'arcface':arcface, 31 | 'quadruplet':quadruplet, 32 | 'adversarial_separation':adversarial_separation} 33 | 34 | 35 | if loss not in losses: raise NotImplementedError('Loss {} not implemented!'.format(loss)) 36 | 37 | loss_lib = losses[loss] 38 | if loss_lib.REQUIRES_BATCHMINER: 39 | if batchminer is None: 40 | raise Exception('Loss {} requires one of the following batch mining methods: {}'.format(loss, loss_lib.ALLOWED_MINING_OPS)) 41 | else: 42 | if batchminer.name not in loss_lib.ALLOWED_MINING_OPS: 43 | raise Exception('{}-mining not allowed for {}-loss!'.format(batchminer.name, loss)) 44 | 45 | 46 | loss_par_dict = {'opt':opt} 47 | if loss_lib.REQUIRES_BATCHMINER: 48 | loss_par_dict['batchminer'] = batchminer 49 | 50 | criterion = loss_lib.Criterion(**loss_par_dict) 51 | 52 | if loss_lib.REQUIRES_OPTIM: 53 | if hasattr(criterion,'optim_dict_list') and criterion.optim_dict_list is not None: 54 | to_optim += criterion.optim_dict_list 55 | else: 56 | to_optim += [{'params':criterion.parameters(), 'lr':criterion.lr}] 57 | 58 | return criterion, to_optim 59 | -------------------------------------------------------------------------------- /criteria/adversarial_separation.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch, torch.nn as nn, torch.nn.functional as F 3 | import batchminer 4 | 5 | 6 | """=================================================================================================""" 7 | ALLOWED_MINING_OPS = list(batchminer.BATCHMINING_METHODS.keys()) 8 | REQUIRES_BATCHMINER = False 9 | REQUIRES_OPTIM = True 10 | 11 | ### MarginLoss with trainable class separation margin beta. Runs on Mini-batches as well. 12 | class Criterion(torch.nn.Module): 13 | def __init__(self, opt): 14 | """ 15 | Args: 16 | margin: Triplet Margin. 17 | nu: Regularisation Parameter for beta values if they are learned. 18 | beta: Class-Margin values. 19 | n_classes: Number of different classes during training. 20 | """ 21 | super().__init__() 22 | 23 | #### 24 | self.embed_dim = opt.embed_dim 25 | self.proj_dim = opt.diva_decorrnet_dim 26 | 27 | self.directions = opt.diva_decorrelations 28 | self.weights = opt.diva_rho_decorrelation 29 | 30 | self.name = 'adversarial_separation' 31 | 32 | #Projection network 33 | self.regressors = nn.ModuleDict() 34 | for direction in self.directions: 35 | self.regressors[direction] = torch.nn.Sequential(torch.nn.Linear(self.embed_dim, self.proj_dim), torch.nn.ReLU(), torch.nn.Linear(self.proj_dim, self.embed_dim)).to(torch.float).to(opt.device) 36 | 37 | #Learning Rate for Projection Network 38 | self.lr = opt.diva_decorrnet_lr 39 | 40 | 41 | #### 42 | self.ALLOWED_MINING_OPS = ALLOWED_MINING_OPS 43 | self.REQUIRES_BATCHMINER = REQUIRES_BATCHMINER 44 | self.REQUIRES_OPTIM = REQUIRES_OPTIM 45 | 46 | 47 | 48 | 49 | def forward(self, feature_dict): 50 | #Apply gradient reversal on input embeddings. 51 | adj_feature_dict = {key:torch.nn.functional.normalize(grad_reverse(features),dim=-1) for key, features in feature_dict.items()} 52 | #Project one embedding to the space of the other (with normalization), then compute the correlation. 53 | sim_loss = 0 54 | for weight, direction in zip(self.weights, self.directions): 55 | source, target = direction.split('-') 56 | sim_loss += -1.*weight*torch.mean(torch.mean((adj_feature_dict[target]*torch.nn.functional.normalize(self.regressors[direction](adj_feature_dict[source]),dim=-1))**2,dim=-1)) 57 | return sim_loss 58 | 59 | 60 | 61 | ### Gradient Reversal Layer 62 | class GradRev(torch.autograd.Function): 63 | """ 64 | Implements an autograd class to flip gradients during backward pass. 65 | """ 66 | def forward(self, x): 67 | """ 68 | Container which applies a simple identity function. 69 | 70 | Input: 71 | x: any torch tensor input. 72 | """ 73 | return x.view_as(x) 74 | 75 | def backward(self, grad_output): 76 | """ 77 | Container to reverse gradient signal during backward pass. 78 | 79 | Input: 80 | grad_output: any computed gradient. 81 | """ 82 | return (grad_output * -1.) 83 | 84 | ### Gradient reverse function 85 | def grad_reverse(x): 86 | """ 87 | Applies gradient reversal on input. 88 | 89 | Input: 90 | x: any torch tensor input. 91 | """ 92 | return GradRev()(x) 93 | -------------------------------------------------------------------------------- /criteria/angular.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch, torch.nn as nn, torch.nn.functional as F 3 | import batchminer 4 | 5 | 6 | """=================================================================================================""" 7 | ALLOWED_MINING_OPS = ['npair'] 8 | REQUIRES_BATCHMINER = True 9 | REQUIRES_OPTIM = False 10 | 11 | class Criterion(torch.nn.Module): 12 | def __init__(self, opt, batchminer): 13 | super(Criterion, self).__init__() 14 | 15 | self.tan_angular_margin = np.tan(np.pi/180*opt.loss_angular_alpha) 16 | self.lam = opt.loss_angular_npair_ang_weight 17 | self.l2_weight = opt.loss_angular_npair_l2 18 | self.batchminer = batchminer 19 | 20 | self.name = 'angular' 21 | 22 | #### 23 | self.ALLOWED_MINING_OPS = ALLOWED_MINING_OPS 24 | self.REQUIRES_BATCHMINER = REQUIRES_BATCHMINER 25 | self.REQUIRES_OPTIM = REQUIRES_OPTIM 26 | 27 | 28 | 29 | def forward(self, batch, labels, **kwargs): 30 | ####NOTE: Normalize Angular Loss, but dont normalize npair loss! 31 | anchors, positives, negatives = self.batchminer(batch, labels) 32 | # print(batch.shape, len(anchors.shape, positives.shape, negatives.shape) 33 | anchors, positives, negatives = batch[anchors], batch[positives], batch[negatives] 34 | n_anchors, n_positives, n_negatives = F.normalize(anchors, dim=1), F.normalize(positives, dim=1), F.normalize(negatives, dim=-1) 35 | 36 | is_term1 = 4*self.tan_angular_margin**2*(n_anchors + n_positives)[:,None,:].bmm(n_negatives.permute(0,2,1)) 37 | is_term2 = 2*(1+self.tan_angular_margin**2)*n_anchors[:,None,:].bmm(n_positives[:,None,:].permute(0,2,1)) 38 | is_term1 = is_term1.view(is_term1.shape[0], is_term1.shape[-1]) 39 | is_term2 = is_term2.view(-1, 1) 40 | 41 | inner_sum_ang = is_term1 - is_term2 42 | angular_loss = torch.mean(torch.log(torch.sum(torch.exp(inner_sum_ang), dim=1) + 1)) 43 | 44 | 45 | inner_sum_npair = anchors[:,None,:].bmm((negatives - positives[:,None,:]).permute(0,2,1)) 46 | inner_sum_npair = inner_sum_npair.view(inner_sum_npair.shape[0], inner_sum_npair.shape[-1]) 47 | npair_loss = torch.mean(torch.log(torch.sum(torch.exp(inner_sum_npair.clamp(max=50,min=-50)), dim=1) + 1)) 48 | 49 | loss = npair_loss + self.lam*angular_loss + self.l2_weight*torch.mean(torch.norm(batch, p=2, dim=1)) 50 | return loss 51 | -------------------------------------------------------------------------------- /criteria/arcface.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch, torch.nn as nn, torch.nn.functional as F 3 | import batchminer 4 | 5 | """=================================================================================================""" 6 | ALLOWED_MINING_OPS = None 7 | REQUIRES_BATCHMINER = False 8 | REQUIRES_OPTIM = True 9 | 10 | ### This implementation follows the pseudocode provided in the original paper. 11 | class Criterion(torch.nn.Module): 12 | def __init__(self, opt): 13 | super(Criterion, self).__init__() 14 | self.par = opt 15 | 16 | #### 17 | self.angular_margin = opt.loss_arcface_angular_margin 18 | self.feature_scale = opt.loss_arcface_feature_scale 19 | 20 | self.class_map = torch.nn.Parameter(torch.Tensor(opt.n_classes, opt.embed_dim)) 21 | stdv = 1. / np.sqrt(self.class_map.size(1)) 22 | self.class_map.data.uniform_(-stdv, stdv) 23 | 24 | self.name = 'arcface' 25 | 26 | self.lr = opt.loss_arcface_lr 27 | 28 | #### 29 | self.ALLOWED_MINING_OPS = ALLOWED_MINING_OPS 30 | self.REQUIRES_BATCHMINER = REQUIRES_BATCHMINER 31 | self.REQUIRES_OPTIM = REQUIRES_OPTIM 32 | 33 | 34 | 35 | 36 | def forward(self, batch, labels, **kwargs): 37 | bs, labels = len(batch), labels.to(self.par.device) 38 | 39 | class_map = torch.nn.functional.normalize(self.class_map, dim=1) 40 | #Note that the similarity becomes the cosine for normalized embeddings. Denoted as 'fc7' in the paper pseudocode. 41 | cos_similarity = batch.mm(class_map.T).clamp(min=1e-10, max=1-1e-10) 42 | 43 | pick = torch.zeros(bs, self.par.n_classes).bool().to(self.par.device) 44 | pick[torch.arange(bs), labels] = 1 45 | 46 | original_target_logit = cos_similarity[pick] 47 | 48 | theta = torch.acos(original_target_logit) 49 | marginal_target_logit = torch.cos(theta + self.angular_margin) 50 | 51 | class_pred = self.feature_scale * (cos_similarity + (marginal_target_logit-original_target_logit).unsqueeze(1)) 52 | loss = torch.nn.CrossEntropyLoss()(class_pred, labels) 53 | 54 | return loss 55 | -------------------------------------------------------------------------------- /criteria/contrastive.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch, torch.nn as nn, torch.nn.functional as F 3 | import batchminer 4 | 5 | """=================================================================================================""" 6 | ALLOWED_MINING_OPS = list(batchminer.BATCHMINING_METHODS.keys()) 7 | REQUIRES_BATCHMINER = True 8 | REQUIRES_OPTIM = False 9 | 10 | 11 | class Criterion(torch.nn.Module): 12 | def __init__(self, opt, batchminer): 13 | super(Criterion, self).__init__() 14 | self.pos_margin = opt.loss_contrastive_pos_margin 15 | self.neg_margin = opt.loss_contrastive_neg_margin 16 | self.batchminer = batchminer 17 | 18 | self.name = 'contrastive' 19 | 20 | #### 21 | self.ALLOWED_MINING_OPS = ALLOWED_MINING_OPS 22 | self.REQUIRES_BATCHMINER = REQUIRES_BATCHMINER 23 | self.REQUIRES_OPTIM = REQUIRES_OPTIM 24 | 25 | 26 | 27 | def forward(self, batch, labels, **kwargs): 28 | sampled_triplets = self.batchminer(batch, labels) 29 | 30 | anchors = [triplet[0] for triplet in sampled_triplets] 31 | positives = [triplet[1] for triplet in sampled_triplets] 32 | negatives = [triplet[2] for triplet in sampled_triplets] 33 | 34 | pos_dists = torch.mean(F.relu(nn.PairwiseDistance(p=2)(batch[anchors,:], batch[positives,:]) - self.pos_margin)) 35 | neg_dists = torch.mean(F.relu(self.neg_margin - nn.PairwiseDistance(p=2)(batch[anchors,:], batch[negatives,:]))) 36 | 37 | loss = pos_dists + neg_dists 38 | 39 | return loss 40 | -------------------------------------------------------------------------------- /criteria/histogram.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch, torch.nn as nn, torch.nn.functional as F 3 | import batchminer 4 | 5 | """=================================================================================================""" 6 | ALLOWED_MINING_OPS = None 7 | REQUIRES_BATCHMINER = False 8 | REQUIRES_OPTIM = False 9 | 10 | 11 | #NOTE: This implementation follows: https://github.com/valerystrizh/pytorch-histogram-loss 12 | class Criterion(torch.nn.Module): 13 | def __init__(self, opt): 14 | """ 15 | Args: 16 | margin: Triplet Margin. 17 | """ 18 | super(Criterion, self).__init__() 19 | self.par = opt 20 | 21 | self.nbins = opt.loss_histogram_nbins 22 | self.bin_width = 2/(self.nbins - 1) 23 | 24 | # We require a numpy and torch support as parts of the computation require numpy. 25 | self.support = np.linspace(-1,1,self.nbins).reshape(-1,1) 26 | self.support_torch = torch.linspace(-1,1,self.nbins).reshape(-1,1).to(opt.device) 27 | 28 | self.name = 'histogram' 29 | 30 | #### 31 | self.ALLOWED_MINING_OPS = ALLOWED_MINING_OPS 32 | self.REQUIRES_BATCHMINER = REQUIRES_BATCHMINER 33 | self.REQUIRES_OPTIM = REQUIRES_OPTIM 34 | 35 | 36 | def forward(self, batch, labels, **kwargs): 37 | #The original paper utilizes similarities instead of distances. 38 | similarity = batch.mm(batch.T) 39 | 40 | bs = labels.size()[0] 41 | 42 | ### We create a equality matrix for labels occuring in the batch 43 | label_eqs = (labels.repeat(bs, 1) == labels.view(-1, 1).repeat(1, bs)) 44 | 45 | ### Because the similarity matrix is symmetric, we will only utilise the upper triangular. 46 | ### These values are indexed by sim_inds 47 | sim_inds = torch.triu(torch.ones(similarity.size()), 1).bool().to(self.par.device) 48 | 49 | ### For the upper triangular similarity matrix, we want to know where our positives/anchors and negatives are: 50 | pos_inds = label_eqs[sim_inds].repeat(self.nbins, 1) 51 | neg_inds = ~label_eqs[sim_inds].repeat(self.nbins, 1) 52 | 53 | ### 54 | n_pos = pos_inds[0].sum() 55 | n_neg = neg_inds[0].sum() 56 | 57 | ### Extract upper triangular from the similarity matrix. (produces a one-dim vector) 58 | unique_sim = similarity[sim_inds].view(1, -1) 59 | 60 | ### We broadcast this vector to each histogram bin. Each bin entry requires a different summation in self.histogram() 61 | unique_sim_rep = unique_sim.repeat(self.nbins, 1) 62 | 63 | ### This assigns bin-values for float-similarities. The conversion to numpy is important to avoid rounding errors in torch. 64 | assigned_bin_values = ((unique_sim_rep.detach().cpu().numpy() + 1) / self.bin_width).astype(int) * self.bin_width - 1 65 | 66 | ### We now compute the histogram over distances 67 | hist_pos_sim = self.histogram(unique_sim_rep, assigned_bin_values, pos_inds, n_pos) 68 | hist_neg_sim = self.histogram(unique_sim_rep, assigned_bin_values, neg_inds, n_neg) 69 | 70 | ### Compute the CDF for the positive similarity histogram 71 | hist_pos_rep = hist_pos_sim.view(-1, 1).repeat(1, hist_pos_sim.size()[0]) 72 | hist_pos_inds = torch.tril(torch.ones(hist_pos_rep.size()), -1).bool() 73 | hist_pos_rep[hist_pos_inds] = 0 74 | hist_pos_cdf = hist_pos_rep.sum(0) 75 | 76 | loss = torch.sum(hist_neg_sim * hist_pos_cdf) 77 | 78 | return loss 79 | 80 | 81 | def histogram(self, unique_sim_rep, assigned_bin_values, idxs, n_elem): 82 | """ 83 | Compute the histogram over similarities. 84 | Args: 85 | unique_sim_rep: torch tensor of shape nbins x n_unique_neg_similarities. 86 | assigned_bin_values: Bin value for each similarity value in unique_sim_rep. 87 | idxs: positive/negative entry indices in unique_sim_rep 88 | n_elem: number of elements in unique_sim_rep. 89 | """ 90 | # Cloning is required because we change the similarity matrix in-place, but need it for the 91 | # positive AND negative histogram. Note that clone() allows for backprop. 92 | usr = unique_sim_rep.clone() 93 | # For each bin (and its lower neighbour bin) we find the distance values that belong. 94 | indsa = torch.tensor((assigned_bin_values==(self.support-self.bin_width) ) & idxs.detach().cpu().numpy()) 95 | indsb = torch.tensor((assigned_bin_values==self.support) & idxs.detach().cpu().numpy()) 96 | # Set all irrelevant similarities to 0 97 | usr[~(indsb|indsa)]=0 98 | # 99 | usr[indsa] = (usr - self.support_torch + self.bin_width)[indsa] / self.bin_width 100 | usr[indsb] = (-usr + self.support_torch + self.bin_width)[indsb] / self.bin_width 101 | 102 | return usr.sum(1)/n_elem 103 | -------------------------------------------------------------------------------- /criteria/lifted.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch, torch.nn as nn, torch.nn.functional as F 3 | import batchminer 4 | 5 | """=================================================================================================""" 6 | ALLOWED_MINING_OPS = ['lifted'] 7 | REQUIRES_BATCHMINER = True 8 | REQUIRES_OPTIM = False 9 | 10 | 11 | class Criterion(torch.nn.Module): 12 | def __init__(self, opt, batchminer): 13 | 14 | super(Criterion, self).__init__() 15 | self.margin = opt.loss_lifted_neg_margin 16 | self.l2_weight = opt.loss_lifted_l2 17 | self.batchminer = batchminer 18 | 19 | self.name = 'lifted' 20 | 21 | #### 22 | self.ALLOWED_MINING_OPS = ALLOWED_MINING_OPS 23 | self.REQUIRES_BATCHMINER = REQUIRES_BATCHMINER 24 | self.REQUIRES_OPTIM = REQUIRES_OPTIM 25 | 26 | 27 | 28 | def forward(self, batch, labels, **kwargs): 29 | anchors, positives, negatives = self.batchminer(batch, labels) 30 | 31 | loss = [] 32 | for anchor, positive_set, negative_set in zip(anchors, positives, negatives): 33 | anchor, positive_set, negative_set = batch[anchor, :].view(1,-1), batch[positive_set, :].view(1,len(positive_set),-1), batch[negative_set, :].view(1,len(negative_set),-1) 34 | pos_term = torch.logsumexp(nn.PairwiseDistance(p=2)(anchor[:,:,None], positive_set.permute(0,2,1)), dim=1) 35 | neg_term = torch.logsumexp(self.margin - nn.PairwiseDistance(p=2)(anchor[:,:,None], negative_set.permute(0,2,1)), dim=1) 36 | loss.append(F.relu(pos_term + neg_term)) 37 | 38 | loss = torch.mean(torch.stack(loss)) + self.l2_weight*torch.mean(torch.norm(batch, p=2, dim=1)) 39 | return loss 40 | -------------------------------------------------------------------------------- /criteria/margin.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch, torch.nn as nn, torch.nn.functional as F 3 | import batchminer 4 | 5 | """=================================================================================================""" 6 | ALLOWED_MINING_OPS = list(batchminer.BATCHMINING_METHODS.keys()) 7 | REQUIRES_BATCHMINER = True 8 | REQUIRES_OPTIM = True 9 | 10 | ### MarginLoss with trainable class separation margin beta. Runs on Mini-batches as well. 11 | class Criterion(torch.nn.Module): 12 | def __init__(self, opt, batchminer): 13 | super(Criterion, self).__init__() 14 | self.n_classes = opt.n_classes 15 | 16 | self.margin = opt.loss_margin_margin 17 | self.nu = opt.loss_margin_nu 18 | self.beta_constant = opt.loss_margin_beta_constant 19 | self.beta_val = opt.loss_margin_beta 20 | 21 | if opt.loss_margin_beta_constant: 22 | self.beta = opt.loss_margin_beta 23 | else: 24 | self.beta = torch.nn.Parameter(torch.ones(opt.n_classes)*opt.loss_margin_beta) 25 | 26 | self.batchminer = batchminer 27 | 28 | self.name = 'margin' 29 | 30 | self.lr = opt.loss_margin_beta_lr 31 | 32 | #### 33 | self.ALLOWED_MINING_OPS = ALLOWED_MINING_OPS 34 | self.REQUIRES_BATCHMINER = REQUIRES_BATCHMINER 35 | self.REQUIRES_OPTIM = REQUIRES_OPTIM 36 | 37 | 38 | 39 | def forward(self, batch, labels, **kwargs): 40 | sampled_triplets = self.batchminer(batch, labels) 41 | 42 | if len(sampled_triplets): 43 | d_ap, d_an = [],[] 44 | for triplet in sampled_triplets: 45 | train_triplet = {'Anchor': batch[triplet[0],:], 'Positive':batch[triplet[1],:], 'Negative':batch[triplet[2]]} 46 | 47 | pos_dist = ((train_triplet['Anchor']-train_triplet['Positive']).pow(2).sum()+1e-8).pow(1/2) 48 | neg_dist = ((train_triplet['Anchor']-train_triplet['Negative']).pow(2).sum()+1e-8).pow(1/2) 49 | 50 | d_ap.append(pos_dist) 51 | d_an.append(neg_dist) 52 | d_ap, d_an = torch.stack(d_ap), torch.stack(d_an) 53 | 54 | if self.beta_constant: 55 | beta = self.beta 56 | else: 57 | beta = torch.stack([self.beta[labels[triplet[0]]] for triplet in sampled_triplets]).to(torch.float).to(d_ap.device) 58 | 59 | pos_loss = torch.nn.functional.relu(d_ap-beta+self.margin) 60 | neg_loss = torch.nn.functional.relu(beta-d_an+self.margin) 61 | 62 | pair_count = torch.sum((pos_loss>0.)+(neg_loss>0.)).to(torch.float).to(d_ap.device) 63 | 64 | if pair_count == 0.: 65 | loss = torch.sum(pos_loss+neg_loss) 66 | else: 67 | loss = torch.sum(pos_loss+neg_loss)/pair_count 68 | 69 | if self.nu: loss = loss + beta_regularisation_loss.to(torch.float).to(d_ap.device) 70 | else: 71 | loss = torch.tensor(0.).to(torch.float).to(batch.device) 72 | 73 | return loss 74 | -------------------------------------------------------------------------------- /criteria/margin_diml.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch, torch.nn as nn, torch.nn.functional as F 3 | import batchminer 4 | 5 | """=================================================================================================""" 6 | ALLOWED_MINING_OPS = list(batchminer.BATCHMINING_METHODS.keys()) 7 | REQUIRES_BATCHMINER = True 8 | REQUIRES_OPTIM = True 9 | 10 | ### MarginLoss with trainable class separation margin beta. Runs on Mini-batches as well. 11 | class Criterion(torch.nn.Module): 12 | def __init__(self, opt, batchminer): 13 | super(Criterion, self).__init__() 14 | self.n_classes = opt.n_classes 15 | 16 | self.margin = opt.loss_margin_margin 17 | self.nu = opt.loss_margin_nu 18 | self.beta_constant = opt.loss_margin_beta_constant 19 | self.beta_val = opt.loss_margin_beta 20 | 21 | if opt.loss_margin_beta_constant: 22 | self.beta = opt.loss_margin_beta 23 | else: 24 | self.beta = torch.nn.Parameter(torch.ones(opt.n_classes)*opt.loss_margin_beta) 25 | 26 | self.batchminer = batchminer 27 | 28 | self.name = 'margin' 29 | 30 | self.lr = opt.loss_margin_beta_lr 31 | 32 | #### 33 | self.ALLOWED_MINING_OPS = ALLOWED_MINING_OPS 34 | self.REQUIRES_BATCHMINER = REQUIRES_BATCHMINER 35 | self.REQUIRES_OPTIM = REQUIRES_OPTIM 36 | 37 | self.eps = 0.05 38 | self.max_iter = 100 39 | self.use_uniform = opt.use_uniform 40 | 41 | def normalize_all(self, x, y, x_mean, y_mean): 42 | x = F.normalize(x, dim=1) 43 | y = F.normalize(y, dim=1) 44 | x_mean = F.normalize(x_mean, dim=1) 45 | y_mean = F.normalize(y_mean, dim=1) 46 | return x, y, x_mean, y_mean 47 | 48 | def cross_attention(self, x, y, x_mean, y_mean): 49 | N, C = x.shape[:2] 50 | x = x.view(N, C, -1) 51 | y = y.view(N, C, -1) 52 | 53 | att = F.relu(torch.einsum("nc,ncr->nr", x_mean, y)).view(N, -1) 54 | u = att / (att.sum(dim=1, keepdims=True) + 1e-5) 55 | att = F.relu(torch.einsum("nc,ncr->nr", y_mean, x)).view(N, -1) 56 | v = att / (att.sum(dim=1, keepdims=True) + 1e-5) 57 | return u, v 58 | 59 | def pair_wise_wdist(self, x, y): 60 | B, C, H, W = x.size() 61 | x = x.view(B, C, -1, 1) 62 | y = y.view(B, C, 1, -1) 63 | x_mean = x.mean([2, 3]) 64 | y_mean = y.mean([2, 3]) 65 | 66 | x, y, x_mean, y_mean = self.normalize_all(x, y, x_mean, y_mean) 67 | dist1 = torch.sqrt(torch.sum(torch.pow(x - y, 2), dim=1) + 1e-6).view(B, H*W, H*W) 68 | dist2 = torch.sqrt(torch.sum(torch.pow(x_mean - y_mean, 2), dim=1) + 1e-6).view(B) 69 | 70 | x = x.view(B, C, -1) 71 | y = y.view(B, C, -1) 72 | 73 | sim = torch.einsum('bcs, bcm->bsm', x, y).contiguous() 74 | if self.use_uniform: 75 | u = torch.zeros(B, H*W, dtype=sim.dtype, device=sim.device).fill_(1. / (H * W)) 76 | v = torch.zeros(B, H*W, dtype=sim.dtype, device=sim.device).fill_(1. / (H * W)) 77 | else: 78 | u, v = self.cross_attention(x, y, x_mean, y_mean) 79 | 80 | wdist = 1.0 - sim.view(B, H*W, H*W) 81 | 82 | with torch.no_grad(): 83 | K = torch.exp(-wdist / self.eps) 84 | T = self.Sinkhorn(K, u, v) 85 | 86 | if torch.isnan(T).any(): 87 | return None 88 | 89 | dist1 = torch.sum(T * dist1, dim=(1, 2)) 90 | dist = dist1 + dist2 91 | dist = dist / 2 92 | 93 | return dist 94 | 95 | def Sinkhorn(self, K, u, v): 96 | r = torch.ones_like(u) 97 | c = torch.ones_like(v) 98 | thresh = 1e-1 99 | for i in range(self.max_iter): 100 | r0 = r 101 | r = u / torch.matmul(K, c.unsqueeze(-1)).squeeze(-1) 102 | c = v / torch.matmul(K.permute(0, 2, 1).contiguous(), r.unsqueeze(-1)).squeeze(-1) 103 | err = (r - r0).abs().mean() 104 | if err.item() < thresh: 105 | break 106 | 107 | T = torch.matmul(r.unsqueeze(-1), c.unsqueeze(-2)) * K 108 | 109 | return T 110 | 111 | def forward(self, batch, labels, **kwargs): 112 | # sampled_triplets = self.batchminer(batch, labels) 113 | 114 | pooled_feature = batch.mean([2, 3]) 115 | pooled_feature = F.normalize(pooled_feature, dim=-1) 116 | 117 | sampled_triplets = self.batchminer(pooled_feature, labels) 118 | 119 | if len(sampled_triplets): 120 | d_ap, d_an = [],[] 121 | for triplet in sampled_triplets: 122 | train_triplet = {'Anchor': batch[triplet[0]], 'Positive':batch[triplet[1]], 'Negative':batch[triplet[2]]} 123 | 124 | pos_dist = self.pair_wise_wdist(train_triplet['Anchor'].unsqueeze(0), train_triplet['Positive'].unsqueeze(0)) 125 | neg_dist = self.pair_wise_wdist(train_triplet['Anchor'].unsqueeze(0), train_triplet['Negative'].unsqueeze(0)) 126 | 127 | if pos_dist is None or neg_dist is None: 128 | continue 129 | 130 | d_ap.append(pos_dist) 131 | d_an.append(neg_dist) 132 | d_ap, d_an = torch.stack(d_ap), torch.stack(d_an) 133 | 134 | if self.beta_constant: 135 | beta = self.beta 136 | else: 137 | beta = torch.stack([self.beta[labels[triplet[0]]] for triplet in sampled_triplets]).to(torch.float).to(d_ap.device) 138 | 139 | pos_loss = torch.nn.functional.relu(d_ap-beta+self.margin) 140 | neg_loss = torch.nn.functional.relu(beta-d_an+self.margin) 141 | 142 | pair_count = torch.sum((pos_loss>0.)+(neg_loss>0.)).to(torch.float).to(d_ap.device) 143 | 144 | if pair_count == 0.: 145 | loss = torch.sum(pos_loss+neg_loss) 146 | else: 147 | loss = torch.sum(pos_loss+neg_loss)/pair_count 148 | 149 | if self.nu: loss = loss + beta_regularisation_loss.to(torch.float).to(d_ap.device) 150 | 151 | else: 152 | loss = torch.tensor(0.).to(torch.float).to(batch.device) 153 | 154 | return loss 155 | -------------------------------------------------------------------------------- /criteria/multisimilarity.py: -------------------------------------------------------------------------------- 1 | import torch, torch.nn as nn 2 | 3 | 4 | 5 | """=================================================================================================""" 6 | ALLOWED_MINING_OPS = None 7 | REQUIRES_BATCHMINER = False 8 | REQUIRES_OPTIM = False 9 | 10 | class Criterion(torch.nn.Module): 11 | def __init__(self, opt): 12 | super(Criterion, self).__init__() 13 | self.n_classes = opt.n_classes 14 | 15 | self.pos_weight = opt.loss_multisimilarity_pos_weight 16 | self.neg_weight = opt.loss_multisimilarity_neg_weight 17 | self.margin = opt.loss_multisimilarity_margin 18 | self.thresh = opt.loss_multisimilarity_thresh 19 | 20 | self.name = 'multisimilarity' 21 | 22 | #### 23 | self.ALLOWED_MINING_OPS = ALLOWED_MINING_OPS 24 | self.REQUIRES_BATCHMINER = REQUIRES_BATCHMINER 25 | self.REQUIRES_OPTIM = REQUIRES_OPTIM 26 | 27 | 28 | def forward(self, batch, labels, **kwargs): 29 | similarity = batch.mm(batch.T) 30 | 31 | loss = [] 32 | for i in range(len(batch)): 33 | pos_idxs = labels==labels[i] 34 | pos_idxs[i] = 0 35 | neg_idxs = labels!=labels[i] 36 | 37 | anchor_pos_sim = similarity[i][pos_idxs] 38 | anchor_neg_sim = similarity[i][neg_idxs] 39 | 40 | ### This part doesn't really work, especially when you dont have a lot of positives in the batch... 41 | neg_idxs = (anchor_neg_sim + self.margin) > torch.min(anchor_pos_sim) 42 | pos_idxs = (anchor_pos_sim - self.margin) < torch.max(anchor_neg_sim) 43 | if not torch.sum(neg_idxs) or not torch.sum(pos_idxs): 44 | continue 45 | anchor_neg_sim = anchor_neg_sim[neg_idxs] 46 | anchor_pos_sim = anchor_pos_sim[pos_idxs] 47 | 48 | pos_term = 1./self.pos_weight * torch.log(1+torch.sum(torch.exp(-self.pos_weight* (anchor_pos_sim - self.thresh)))) 49 | neg_term = 1./self.neg_weight * torch.log(1+torch.sum(torch.exp(self.neg_weight * (anchor_neg_sim - self.thresh)))) 50 | 51 | loss.append(pos_term + neg_term) 52 | if len(loss) < 1: 53 | loss = 0 54 | else: 55 | loss = torch.mean(torch.stack(loss)) 56 | return loss 57 | -------------------------------------------------------------------------------- /criteria/multisimilarity_diml.py: -------------------------------------------------------------------------------- 1 | import torch, torch.nn as nn 2 | import torch.nn.functional as F 3 | 4 | 5 | 6 | """=================================================================================================""" 7 | ALLOWED_MINING_OPS = None 8 | REQUIRES_BATCHMINER = False 9 | REQUIRES_OPTIM = False 10 | 11 | class Criterion(torch.nn.Module): 12 | def __init__(self, opt): 13 | super(Criterion, self).__init__() 14 | self.n_classes = opt.n_classes 15 | 16 | self.pos_weight = opt.loss_multisimilarity_pos_weight 17 | self.neg_weight = opt.loss_multisimilarity_neg_weight 18 | self.margin = opt.loss_multisimilarity_margin 19 | self.thresh = opt.loss_multisimilarity_thresh 20 | 21 | self.name = 'multisimilarity_w' 22 | 23 | #### 24 | self.ALLOWED_MINING_OPS = ALLOWED_MINING_OPS 25 | self.REQUIRES_BATCHMINER = REQUIRES_BATCHMINER 26 | self.REQUIRES_OPTIM = REQUIRES_OPTIM 27 | 28 | self.eps = 0.05 29 | self.max_iter = 100 30 | self.use_uniform = opt.use_uniform 31 | 32 | def normalize_all(self, x, y, x_mean, y_mean): 33 | x = F.normalize(x, dim=1) 34 | y = F.normalize(y, dim=1) 35 | x_mean = F.normalize(x_mean, dim=1) 36 | y_mean = F.normalize(y_mean, dim=1) 37 | return x, y, x_mean, y_mean 38 | 39 | def cross_attention(self, x, y, x_mean, y_mean): 40 | N, C = x.shape[:2] 41 | x = x.view(N, C, -1) 42 | y = y.view(N, C, -1) 43 | 44 | att = F.relu(torch.einsum("nc,ncr->nr", x_mean, y)).view(N, -1) 45 | u = att / (att.sum(dim=1, keepdims=True) + 1e-5) 46 | att = F.relu(torch.einsum("nc,ncr->nr", y_mean, x)).view(N, -1) 47 | v = att / (att.sum(dim=1, keepdims=True) + 1e-5) 48 | return u, v 49 | 50 | def pair_wise_wdist(self, x, y): 51 | B, C, H, W = x.size() 52 | x_mean = x.mean([2, 3]) 53 | y_mean = y.mean([2, 3]) 54 | x = x.view(B, C, -1) 55 | y = y.view(B, C, -1) 56 | x, y, x_mean, y_mean = self.normalize_all(x, y, x_mean, y_mean) 57 | 58 | if self.use_uniform: 59 | u = torch.zeros(B, H * W, dtype=x.dtype, device=x.device).fill_(1. / (H * W)) 60 | v = torch.zeros(B, H * W, dtype=x.dtype, device=x.device).fill_(1. / (H * W)) 61 | else: 62 | u, v = self.cross_attention(x, y, x_mean, y_mean) 63 | 64 | sim1 = torch.einsum('bcs, bcm->bsm', x, y).contiguous() 65 | sim2 = torch.einsum('bc, bc->b', x_mean, y_mean).contiguous().reshape(B, 1, 1) 66 | 67 | wdist = 1.0 - sim1.view(B, H * W, H * W) 68 | 69 | with torch.no_grad(): 70 | K = torch.exp(-wdist / self.eps) 71 | T = self.Sinkhorn(K, u, v).detach() 72 | 73 | sim = (sim1 + sim2) / 2 74 | sim = torch.sum(T * sim, dim=(1, 2)) 75 | 76 | return sim 77 | 78 | def Sinkhorn(self, K, u, v): 79 | r = torch.ones_like(u) 80 | c = torch.ones_like(v) 81 | thresh = 1e-1 82 | for i in range(self.max_iter): 83 | r0 = r 84 | r = u / torch.matmul(K, c.unsqueeze(-1)).squeeze(-1) 85 | c = v / torch.matmul(K.permute(0, 2, 1).contiguous(), r.unsqueeze(-1)).squeeze(-1) 86 | err = (r - r0).abs().mean(dim=1) 87 | err = err[~torch.isnan(err)] 88 | 89 | if len(err) == 0 or torch.max(err).item() < thresh: 90 | break 91 | 92 | T = torch.matmul(r.unsqueeze(-1), c.unsqueeze(-2)) * K 93 | return T 94 | 95 | def forward(self, batch, labels, **kwargs): 96 | b, _, _, _ = batch.size() 97 | batch_repeat = torch.repeat_interleave(batch, b, dim=0) 98 | batch_cat = torch.cat([batch for _ in range(b)], dim=0) 99 | similarity = self.pair_wise_wdist(batch_repeat, batch_cat).view(b, b) 100 | if torch.isnan(similarity).any(): 101 | similarity.sum().backward() 102 | return None 103 | 104 | loss = [] 105 | for i in range(len(batch)): 106 | pos_idxs = labels==labels[i] 107 | pos_idxs[i] = 0 108 | neg_idxs = labels!=labels[i] 109 | 110 | anchor_pos_sim = similarity[i][pos_idxs] 111 | anchor_neg_sim = similarity[i][neg_idxs] 112 | 113 | # filter nan 114 | anchor_pos_sim = anchor_pos_sim[~torch.isnan(anchor_pos_sim)] 115 | anchor_neg_sim = anchor_neg_sim[~torch.isnan(anchor_neg_sim)] 116 | 117 | ### This part doesn't really work, especially when you dont have a lot of positives in the batch... 118 | if len(anchor_pos_sim) == 0 or len(anchor_neg_sim) == 0: 119 | print("all nan") 120 | continue 121 | 122 | neg_idxs = (anchor_neg_sim + self.margin) > torch.min(anchor_pos_sim) 123 | pos_idxs = (anchor_pos_sim - self.margin) < torch.max(anchor_neg_sim) 124 | if not torch.sum(neg_idxs) or not torch.sum(pos_idxs): 125 | continue 126 | anchor_neg_sim = anchor_neg_sim[neg_idxs] 127 | anchor_pos_sim = anchor_pos_sim[pos_idxs] 128 | 129 | pos_term = 1./self.pos_weight * torch.log(1+torch.sum(torch.exp(-self.pos_weight* (anchor_pos_sim - self.thresh)))) 130 | neg_term = 1./self.neg_weight * torch.log(1+torch.sum(torch.exp(self.neg_weight * (anchor_neg_sim - self.thresh)))) 131 | 132 | if torch.isnan(pos_term) or torch.isnan(neg_term): 133 | print("pos, neg:", pos_term, neg_term) 134 | print(anchor_pos_sim, anchor_neg_sim) 135 | 136 | loss.append(pos_term + neg_term) 137 | 138 | if len(loss) < 1: 139 | print("no loss") 140 | loss = None 141 | else: 142 | loss = torch.mean(torch.stack(loss)) 143 | return loss 144 | -------------------------------------------------------------------------------- /criteria/npair.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch, torch.nn as nn, torch.nn.functional as F 3 | import batchminer 4 | 5 | 6 | """=================================================================================================""" 7 | ALLOWED_MINING_OPS = ['npair'] 8 | REQUIRES_BATCHMINER = True 9 | REQUIRES_OPTIM = False 10 | 11 | class Criterion(torch.nn.Module): 12 | def __init__(self, opt, batchminer): 13 | """ 14 | Args: 15 | """ 16 | super(Criterion, self).__init__() 17 | self.pars = opt 18 | self.l2_weight = opt.loss_npair_l2 19 | self.batchminer = batchminer 20 | 21 | self.name = 'npair' 22 | 23 | #### 24 | self.ALLOWED_MINING_OPS = ALLOWED_MINING_OPS 25 | self.REQUIRES_BATCHMINER = REQUIRES_BATCHMINER 26 | self.REQUIRES_OPTIM = REQUIRES_OPTIM 27 | 28 | 29 | def forward(self, batch, labels, **kwargs): 30 | anchors, positives, negatives = self.batchminer(batch, labels) 31 | 32 | ## 33 | loss = 0 34 | if 'bninception' in self.pars.arch: 35 | ### clamping/value reduction to avoid initial overflow for high embedding dimensions! 36 | batch = batch/4 37 | for anchor, positive, negative_set in zip(anchors, positives, negatives): 38 | a_embs, p_embs, n_embs = batch[anchor:anchor+1], batch[positive:positive+1], batch[negative_set] 39 | inner_sum = a_embs[:,None,:].bmm((n_embs - p_embs[:,None,:]).permute(0,2,1)) 40 | inner_sum = inner_sum.view(inner_sum.shape[0], inner_sum.shape[-1]) 41 | loss = loss + torch.mean(torch.log(torch.sum(torch.exp(inner_sum), dim=1) + 1))/len(anchors) 42 | loss = loss + self.l2_weight*torch.mean(torch.norm(batch, p=2, dim=1))/len(anchors) 43 | 44 | 45 | return loss 46 | -------------------------------------------------------------------------------- /criteria/proxynca.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch, torch.nn as nn, torch.nn.functional as F 3 | import batchminer 4 | 5 | 6 | """=================================================================================================""" 7 | ALLOWED_MINING_OPS = None 8 | REQUIRES_BATCHMINER = False 9 | REQUIRES_OPTIM = True 10 | 11 | 12 | class Criterion(torch.nn.Module): 13 | def __init__(self, opt): 14 | """ 15 | Args: 16 | opt: Namespace containing all relevant parameters. 17 | """ 18 | super(Criterion, self).__init__() 19 | 20 | #### 21 | self.num_proxies = opt.n_classes 22 | self.embed_dim = opt.embed_dim 23 | 24 | self.proxies = torch.nn.Parameter(torch.randn(self.num_proxies, self.embed_dim)/8) 25 | self.class_idxs = torch.arange(self.num_proxies) 26 | 27 | self.name = 'proxynca' 28 | 29 | self.optim_dict_list = [{'params':self.proxies, 'lr':opt.lr * opt.loss_proxynca_lrmulti}] 30 | 31 | 32 | #### 33 | self.ALLOWED_MINING_OPS = ALLOWED_MINING_OPS 34 | self.REQUIRES_BATCHMINER = REQUIRES_BATCHMINER 35 | self.REQUIRES_OPTIM = REQUIRES_OPTIM 36 | 37 | 38 | 39 | def forward(self, batch, labels, **kwargs): 40 | #Empirically, multiplying the embeddings during the computation of the loss seem to allow for more stable training; 41 | #Acts as a temperature in the NCA objective. 42 | batch = 3*torch.nn.functional.normalize(batch, dim=1) 43 | proxies = 3*torch.nn.functional.normalize(self.proxies, dim=1) 44 | #Group required proxies 45 | pos_proxies = torch.stack([proxies[pos_label:pos_label+1,:] for pos_label in labels]) 46 | neg_proxies = torch.stack([torch.cat([self.class_idxs[:class_label],self.class_idxs[class_label+1:]]) for class_label in labels]) 47 | neg_proxies = torch.stack([proxies[neg_labels,:] for neg_labels in neg_proxies]) 48 | #Compute Proxy-distances 49 | dist_to_neg_proxies = torch.sum((batch[:,None,:]-neg_proxies).pow(2),dim=-1) 50 | dist_to_pos_proxies = torch.sum((batch[:,None,:]-pos_proxies).pow(2),dim=-1) 51 | #Compute final proxy-based NCA loss 52 | negative_log_proxy_nca_loss = torch.mean(dist_to_pos_proxies[:,0] + torch.logsumexp(-dist_to_neg_proxies, dim=1)) 53 | 54 | return negative_log_proxy_nca_loss 55 | -------------------------------------------------------------------------------- /criteria/quadruplet.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch, torch.nn as nn, torch.nn.functional as F 3 | import batchminer 4 | """=================================================================================================""" 5 | ALLOWED_MINING_OPS = list(batchminer.BATCHMINING_METHODS.keys()) 6 | REQUIRES_BATCHMINER = True 7 | REQUIRES_OPTIM = False 8 | 9 | 10 | class Criterion(torch.nn.Module): 11 | def __init__(self, opt, batchminer): 12 | super(Criterion, self).__init__() 13 | self.batchminer = batchminer 14 | 15 | self.name = 'quadruplet' 16 | 17 | self.margin_alpha_1 = opt.loss_quadruplet_margin_alpha_1 18 | self.margin_alpha_2 = opt.loss_quadruplet_margin_alpha_2 19 | 20 | #### 21 | self.ALLOWED_MINING_OPS = ALLOWED_MINING_OPS 22 | self.REQUIRES_BATCHMINER = REQUIRES_BATCHMINER 23 | self.REQUIRES_OPTIM = REQUIRES_OPTIM 24 | 25 | 26 | 27 | def triplet_distance(self, anchor, positive, negative): 28 | return torch.nn.functional.relu(torch.norm(anchor-positive, p=2, dim=-1)-torch.norm(anchor-negative, p=2, dim=-1)+self.margin_alpha_1) 29 | 30 | def quadruplet_distance(self, anchor, positive, negative, fourth_negative): 31 | return torch.nn.functional.relu(torch.norm(anchor-positive, p=2, dim=-1)-torch.norm(negative-fourth_negative, p=2, dim=-1)+self.margin_alpha_2) 32 | 33 | def forward(self, batch, labels, **kwargs): 34 | sampled_triplets = self.batchminer(batch, labels) 35 | 36 | anchors = np.array([triplet[0] for triplet in sampled_triplets]).reshape(-1,1) 37 | positives = np.array([triplet[1] for triplet in sampled_triplets]).reshape(-1,1) 38 | negatives = np.array([triplet[2] for triplet in sampled_triplets]).reshape(-1,1) 39 | 40 | fourth_negatives = negatives!=negatives.T 41 | fourth_negatives = [np.random.choice(np.arange(len(batch))[idxs]) for idxs in fourth_negatives] 42 | 43 | triplet_loss = self.triplet_distance(batch[anchors,:],batch[positives,:],batch[negatives,:]) 44 | quadruplet_loss = self.quadruplet_distance(batch[anchors,:],batch[positives,:],batch[negatives,:],batch[fourth_negatives,:]) 45 | 46 | return torch.mean(triplet_loss) + torch.mean(quadruplet_loss) 47 | -------------------------------------------------------------------------------- /criteria/snr.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch, torch.nn as nn, torch.nn.functional as F 3 | import batchminer 4 | 5 | """=================================================================================================""" 6 | ALLOWED_MINING_OPS = list(batchminer.BATCHMINING_METHODS.keys()) 7 | REQUIRES_BATCHMINER = True 8 | REQUIRES_OPTIM = False 9 | 10 | ### This implements the Signal-To-Noise Ratio Triplet Loss 11 | class Criterion(torch.nn.Module): 12 | def __init__(self, opt, batchminer): 13 | super(Criterion, self).__init__() 14 | self.margin = opt.loss_snr_margin 15 | self.reg_lambda = opt.loss_snr_reg_lambda 16 | self.batchminer = batchminer 17 | 18 | if self.batchminer.name=='distance': self.reg_lambda = 0 19 | 20 | self.name = 'snr' 21 | 22 | #### 23 | self.ALLOWED_MINING_OPS = ALLOWED_MINING_OPS 24 | self.REQUIRES_BATCHMINER = REQUIRES_BATCHMINER 25 | self.REQUIRES_OPTIM = REQUIRES_OPTIM 26 | 27 | 28 | 29 | def forward(self, batch, labels, **kwargs): 30 | sampled_triplets = self.batchminer(batch, labels) 31 | anchors = [triplet[0] for triplet in sampled_triplets] 32 | positives = [triplet[1] for triplet in sampled_triplets] 33 | negatives = [triplet[2] for triplet in sampled_triplets] 34 | 35 | pos_snr = torch.var(batch[anchors,:]-batch[positives,:], dim=1)/torch.var(batch[anchors,:], dim=1) 36 | neg_snr = torch.var(batch[anchors,:]-batch[negatives,:], dim=1)/torch.var(batch[anchors,:], dim=1) 37 | 38 | reg_loss = torch.mean(torch.abs(torch.sum(batch[anchors,:],dim=1))) 39 | 40 | snr_loss = torch.nn.functional.relu(pos_snr - neg_snr + self.margin) 41 | snr_loss = torch.sum(snr_loss)/torch.sum(snr_loss>0) 42 | 43 | loss = snr_loss + self.reg_lambda * reg_loss 44 | 45 | return loss 46 | -------------------------------------------------------------------------------- /criteria/softmax.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch, torch.nn as nn, torch.nn.functional as F 3 | import batchminer 4 | 5 | """=================================================================================================""" 6 | ALLOWED_MINING_OPS = None 7 | REQUIRES_BATCHMINER = False 8 | REQUIRES_OPTIM = True 9 | 10 | ### This Implementation follows: https://github.com/azgo14/classification_metric_learning 11 | 12 | class Criterion(torch.nn.Module): 13 | def __init__(self, opt): 14 | super(Criterion, self).__init__() 15 | self.par = opt 16 | 17 | self.temperature = opt.loss_softmax_temperature 18 | 19 | self.class_map = torch.nn.Parameter(torch.Tensor(opt.n_classes, opt.embed_dim)) 20 | stdv = 1. / np.sqrt(self.class_map.size(1)) 21 | self.class_map.data.uniform_(-stdv, stdv) 22 | 23 | self.name = 'softmax' 24 | 25 | self.lr = opt.loss_softmax_lr 26 | 27 | #### 28 | self.ALLOWED_MINING_OPS = ALLOWED_MINING_OPS 29 | self.REQUIRES_BATCHMINER = REQUIRES_BATCHMINER 30 | self.REQUIRES_OPTIM = REQUIRES_OPTIM 31 | 32 | 33 | def forward(self, batch, labels, **kwargs): 34 | class_mapped_batch = torch.nn.functional.linear(batch, torch.nn.functional.normalize(self.class_map, dim=1)) 35 | 36 | loss = torch.nn.CrossEntropyLoss()(class_mapped_batch/self.temperature, labels.to(torch.long).to(self.par.device)) 37 | 38 | return loss 39 | -------------------------------------------------------------------------------- /criteria/softtriplet.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch, torch.nn as nn, torch.nn.functional as F 3 | import batchminer 4 | 5 | """=================================================================================================""" 6 | ALLOWED_MINING_OPS = None 7 | REQUIRES_BATCHMINER = False 8 | REQUIRES_OPTIM = True 9 | 10 | ### This implementation follows https://github.com/idstcv/SoftTriple 11 | class Criterion(torch.nn.Module): 12 | def __init__(self, opt): 13 | super(Criterion, self).__init__() 14 | 15 | #### 16 | self.par = opt 17 | self.n_classes = opt.n_classes 18 | 19 | #### 20 | self.n_centroids = opt.loss_softtriplet_n_centroids # K 10 21 | # original: 10 22 | self.margin_delta = opt.loss_softtriplet_margin_delta # margin 0.01 23 | self.gamma = opt.loss_softtriplet_gamma # 1 ./ gamma = 10 24 | self.lam = opt.loss_softtriplet_lambda # la 20 25 | self.reg_weight = opt.loss_softtriplet_reg_weight # tau 0.2 26 | 27 | 28 | #### 29 | self.reg_norm = self.n_classes*self.n_centroids*(self.n_centroids-1) 30 | self.reg_indices = torch.zeros((self.n_classes*self.n_centroids, self.n_classes*self.n_centroids), dtype=torch.bool).to(opt.device) 31 | for i in range(0, self.n_classes): 32 | for j in range(0, self.n_centroids): 33 | self.reg_indices[i*self.n_centroids+j, i*self.n_centroids+j+1:(i+1)*self.n_centroids] = 1 34 | 35 | 36 | #### 37 | self.intra_class_centroids = torch.nn.Parameter(torch.Tensor(opt.embed_dim, self.n_classes*self.n_centroids)) 38 | stdv = 1. / np.sqrt(self.intra_class_centroids.size(1)) 39 | self.intra_class_centroids.data.uniform_(-stdv, stdv) 40 | 41 | self.name = 'softtriplet' 42 | 43 | self.lr = opt.lr*opt.loss_softtriplet_lrmulti 44 | 45 | #### 46 | self.ALLOWED_MINING_OPS = ALLOWED_MINING_OPS 47 | self.REQUIRES_BATCHMINER = REQUIRES_BATCHMINER 48 | self.REQUIRES_OPTIM = REQUIRES_OPTIM 49 | 50 | 51 | def forward(self, batch, labels, **kwargs): 52 | bs = batch.size(0) 53 | 54 | intra_class_centroids = torch.nn.functional.normalize(self.intra_class_centroids, dim=1) 55 | similarities_to_centroids = batch.mm(intra_class_centroids).reshape(-1, self.n_classes, self.n_centroids) 56 | 57 | soft_weight_over_centroids = torch.nn.Softmax(dim=1)(self.gamma*similarities_to_centroids) 58 | per_class_embed = torch.sum(soft_weight_over_centroids * similarities_to_centroids, dim=2) 59 | 60 | margin_delta = torch.zeros(per_class_embed.shape).to(self.par.device) 61 | margin_delta[torch.arange(0, bs), labels] = self.margin_delta 62 | 63 | centroid_classification_loss = torch.nn.CrossEntropyLoss()(self.lam*(per_class_embed-margin_delta), labels.to(torch.long).to(self.par.device)) 64 | 65 | inter_centroid_similarity = intra_class_centroids.T.mm(intra_class_centroids) 66 | regularisation_loss = torch.sum(torch.sqrt(2.00001-2*inter_centroid_similarity[self.reg_indices]))/self.reg_norm 67 | 68 | return centroid_classification_loss + self.reg_weight * regularisation_loss 69 | -------------------------------------------------------------------------------- /criteria/triplet.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch, torch.nn as nn, torch.nn.functional as F 3 | import batchminer 4 | 5 | """=================================================================================================""" 6 | ALLOWED_MINING_OPS = list(batchminer.BATCHMINING_METHODS.keys()) 7 | REQUIRES_BATCHMINER = True 8 | REQUIRES_OPTIM = False 9 | 10 | ### Standard Triplet Loss, finds triplets in Mini-batches. 11 | class Criterion(torch.nn.Module): 12 | def __init__(self, opt, batchminer): 13 | super(Criterion, self).__init__() 14 | self.margin = opt.loss_triplet_margin 15 | self.batchminer = batchminer 16 | self.name = 'triplet' 17 | 18 | #### 19 | self.ALLOWED_MINING_OPS = ALLOWED_MINING_OPS 20 | self.REQUIRES_BATCHMINER = REQUIRES_BATCHMINER 21 | self.REQUIRES_OPTIM = REQUIRES_OPTIM 22 | 23 | 24 | def triplet_distance(self, anchor, positive, negative): 25 | return torch.nn.functional.relu((anchor-positive).pow(2).sum()-(anchor-negative).pow(2).sum()+self.margin) 26 | 27 | def forward(self, batch, labels, **kwargs): 28 | if isinstance(labels, torch.Tensor): labels = labels.cpu().numpy() 29 | sampled_triplets = self.batchminer(batch, labels) 30 | loss = torch.stack([self.triplet_distance(batch[triplet[0],:],batch[triplet[1],:],batch[triplet[2],:]) for triplet in sampled_triplets]) 31 | 32 | return torch.mean(loss) 33 | -------------------------------------------------------------------------------- /datasampler/__init__.py: -------------------------------------------------------------------------------- 1 | import datasampler.class_random_sampler 2 | import datasampler.random_sampler 3 | import datasampler.greedy_coreset_sampler 4 | import datasampler.fid_batchmatch_sampler 5 | import datasampler.disthist_batchmatch_sampler 6 | import datasampler.d2_coreset_sampler 7 | 8 | 9 | def select(sampler, opt, image_dict, image_list=None, **kwargs): 10 | if 'batchmatch' in sampler: 11 | if sampler=='disthist_batchmatch': 12 | sampler_lib = disthist_batchmatch_sampler 13 | elif sampler=='fid_batchmatch': 14 | sampler_lib = spc_fid_batchmatch_sampler 15 | elif 'random' in sampler: 16 | if 'class' in sampler: 17 | sampler_lib = class_random_sampler 18 | elif 'full' in sampler: 19 | sampler_lib = random_sampler 20 | elif 'coreset' in sampler: 21 | if 'greedy' in sampler: 22 | sampler_lib = greedy_coreset_sampler 23 | elif 'd2' in sampler: 24 | sampler_lib = d2_coreset_sampler 25 | else: 26 | raise Exception('Minibatch sampler <{}> not available!'.format(sampler)) 27 | 28 | sampler = sampler_lib.Sampler(opt,image_dict=image_dict,image_list=image_list) 29 | 30 | return sampler 31 | -------------------------------------------------------------------------------- /datasampler/class_random_sampler.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch, torch.nn as nn, torch.nn.functional as F 3 | from tqdm import tqdm 4 | import random 5 | 6 | 7 | 8 | """======================================================""" 9 | REQUIRES_STORAGE = False 10 | 11 | ### 12 | class Sampler(torch.utils.data.sampler.Sampler): 13 | """ 14 | Plugs into PyTorch Batchsampler Package. 15 | """ 16 | def __init__(self, opt, image_dict, image_list, **kwargs): 17 | self.pars = opt 18 | 19 | ##### 20 | self.image_dict = image_dict 21 | self.image_list = image_list 22 | 23 | ##### 24 | self.classes = list(self.image_dict.keys()) 25 | 26 | #### 27 | self.batch_size = opt.bs 28 | self.samples_per_class = opt.samples_per_class 29 | self.sampler_length = len(image_list)//opt.bs 30 | assert self.batch_size%self.samples_per_class==0, '#Samples per class must divide batchsize!' 31 | 32 | self.name = 'class_random_sampler' 33 | self.requires_storage = False 34 | 35 | def __iter__(self): 36 | for _ in range(self.sampler_length): 37 | subset = [] 38 | ### Random Subset from Random classes 39 | draws = self.batch_size//self.samples_per_class 40 | 41 | for _ in range(draws): 42 | class_key = random.choice(self.classes) 43 | class_ix_list = [random.choice(self.image_dict[class_key])[-1] for _ in range(self.samples_per_class)] 44 | subset.extend(class_ix_list) 45 | 46 | yield subset 47 | 48 | def __len__(self): 49 | return self.sampler_length 50 | -------------------------------------------------------------------------------- /datasampler/d2_coreset_sampler.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch, torch.nn as nn, torch.nn.functional as F 3 | from tqdm import tqdm 4 | import random 5 | from scipy import linalg 6 | from scipy.stats import multivariate_normal 7 | 8 | """======================================================""" 9 | REQUIRES_STORAGE = True 10 | 11 | ### 12 | class Sampler(torch.utils.data.sampler.Sampler): 13 | """ 14 | Plugs into PyTorch Batchsampler Package. 15 | """ 16 | def __init__(self, opt, image_dict, image_list): 17 | self.image_dict = image_dict 18 | self.image_list = image_list 19 | 20 | self.batch_size = opt.bs 21 | self.samples_per_class = opt.samples_per_class 22 | self.sampler_length = len(image_list)//opt.bs 23 | assert self.batch_size%self.samples_per_class==0, '#Samples per class must divide batchsize!' 24 | 25 | self.name = 'greedy_coreset_sampler' 26 | self.requires_storage = True 27 | 28 | self.bigbs = opt.data_batchmatch_bigbs 29 | self.update_storage = not opt.data_storage_no_update 30 | self.num_batch_comps = opt.data_batchmatch_ncomps 31 | 32 | self.low_proj_dim = opt.data_sampler_lowproj_dim 33 | 34 | self.lam = opt.data_d2_coreset_lambda 35 | 36 | self.n_jobs = 16 37 | 38 | def __iter__(self): 39 | for i in range(self.sampler_length): 40 | yield self.epoch_indices[i] 41 | 42 | 43 | def precompute_indices(self): 44 | from joblib import Parallel, delayed 45 | import time 46 | 47 | ### Random Subset from Random classes 48 | bigb_idxs = np.random.choice(len(self.storage), self.bigbs, replace=True) 49 | bigbatch = self.storage[bigb_idxs] 50 | 51 | print('Precomputing Indices... ', end='') 52 | start = time.time() 53 | def batchfinder(n_calls, pos): 54 | idx_sets = self.d2_coreset(n_calls, pos) 55 | structured_batches = [list(bigb_idxs[idx_set]) for idx_set in idx_sets] 56 | # structured_batch = list(bigb_idxs[self.fid_match(bigbatch, batch_size=self.batch_size//self.samples_per_class)]) 57 | #Add random per-class fillers to ensure that the batch is build up correctly. 58 | for i in range(len(structured_batches)): 59 | class_idxs = [self.image_list[idx][-1] for idx in structured_batches[i]] 60 | for class_idx in class_idxs: 61 | structured_batches[i].extend([random.choice(self.image_dict[class_idx])[-1] for _ in range(self.samples_per_class-1)]) 62 | 63 | return structured_batches 64 | 65 | n_calls = int(np.ceil(self.sampler_length/self.n_jobs)) 66 | # self.epoch_indices = batchfinder(n_calls, 0) 67 | self.epoch_indices = Parallel(n_jobs = self.n_jobs)(delayed(batchfinder)(n_calls, i) for i in range(self.n_jobs)) 68 | self.epoch_indices = [x for y in self.epoch_indices for x in y] 69 | # self.epoch_indices = Parallel(n_jobs = self.n_jobs)(delayed(batchfinder)(self.storage[np.random.choice(len(self.storage), self.bigbs, replace=True)]) for _ in tqdm(range(self.sampler_length), desc='Precomputing Indices...')) 70 | 71 | print('Done in {0:3.4f}s.'.format(time.time()-start)) 72 | def replace_storage_entries(self, embeddings, indices): 73 | self.storage[indices] = embeddings 74 | 75 | def create_storage(self, dataloader, model, device): 76 | with torch.no_grad(): 77 | _ = model.eval() 78 | _ = model.to(device) 79 | 80 | embed_collect = [] 81 | for i,input_tuple in enumerate(tqdm(dataloader, 'Creating data storage...')): 82 | embed = model(input_tuple[1].type(torch.FloatTensor).to(device)) 83 | if isinstance(embed, tuple): embed = embed[0] 84 | embed = embed.cpu() 85 | embed_collect.append(embed) 86 | embed_collect = torch.cat(embed_collect, dim=0) 87 | self.storage = embed_collect 88 | 89 | 90 | def d2_coreset(self, calls, pos): 91 | """ 92 | """ 93 | coll = [] 94 | 95 | for _ in range(calls): 96 | bigbatch = self.storage[np.random.choice(len(self.storage), self.bigbs, replace=False)] 97 | batch_size = self.batch_size//self.samples_per_class 98 | 99 | if self.low_proj_dim>0: 100 | low_dim_proj = nn.Linear(bigbatch.shape[-1],self.low_proj_dim,bias=False) 101 | with torch.no_grad(): bigbatch = low_dim_proj(bigbatch) 102 | 103 | bigbatch = bigbatch.numpy() 104 | # emp_mean, emp_std = np.mean(bigbatch, axis=0), np.std(bigbatch, axis=0) 105 | emp_mean, emp_cov = np.mean(bigbatch, axis=0), np.cov(bigbatch.T) 106 | 107 | prod = np.matmul(bigbatch, bigbatch.T) 108 | sq = prod.diagonal().reshape(bigbatch.shape[0], 1) 109 | dist_matrix = np.clip(-2*prod + sq + sq.T, 0, None) 110 | 111 | start_anchor = np.random.multivariate_normal(emp_mean, emp_cov, 1).reshape(-1) 112 | start_dists = np.linalg.norm(bigbatch-start_anchor,axis=1) 113 | start_point = np.argmin(start_dists, axis=0) 114 | 115 | idxs = list(range(len(bigbatch))) 116 | del idxs[start_point] 117 | 118 | k, sampled_indices = 1, [start_point] 119 | dist_weights = dist_matrix[:,start_point] 120 | 121 | normal_weights = multivariate_normal.pdf(bigbatch,emp_mean,emp_cov) 122 | while k0: 118 | low_dim_proj = nn.Linear(bigbatch.shape[-1],self.low_proj_dim,bias=False) 119 | with torch.no_grad(): bigbatch = low_dim_proj(bigbatch) 120 | bigbatch = bigbatch.numpy() 121 | 122 | bigb_distmat_triu_idxs = np.triu_indices(len(bigbatch),1) 123 | bigb_distvals = self.get_distmat(bigbatch)[bigb_distmat_triu_idxs] 124 | 125 | bigb_disthist_range, bigb_disthist_bins = (np.min(bigb_distvals), np.max(bigb_distvals)), 50 126 | bigb_disthist, _ = np.histogram(bigb_distvals, bins=bigb_disthist_bins, range=bigb_disthist_range) 127 | bigb_disthist = bigb_disthist/np.sum(bigb_disthist) 128 | 129 | bigb_mu = np.mean(bigbatch, axis=0) 130 | bigb_std = np.std(bigbatch, axis=0) 131 | 132 | 133 | cost_collect, bigb_idxs = [], [] 134 | 135 | for _ in range(self.num_batch_comps): 136 | subset_idxs = [np.random.choice(bigb_dict[np.random.choice(list(bigb_dict.keys()))], self.samples_per_class, replace=False) for _ in range(self.batch_size//self.samples_per_class)] 137 | subset_idxs = [x for y in subset_idxs for x in y] 138 | # subset_idxs = sorted(np.random.choice(len(bigbatch), batch_size, replace=False)) 139 | bigb_idxs.append(subset_idxs) 140 | subset = bigbatch[subset_idxs,:] 141 | subset_distmat = self.get_distmat(subset) 142 | 143 | subset_distmat_triu_idxs = np.triu_indices(len(subset_distmat),1) 144 | subset_distvals = self.get_distmat(subset)[subset_distmat_triu_idxs] 145 | 146 | subset_disthist_range, subset_disthist_bins = (np.min(subset_distvals), np.max(subset_distvals)), 50 147 | subset_disthist, _ = np.histogram(subset_distvals, bins=bigb_disthist_bins, range=bigb_disthist_range) 148 | subset_disthist = subset_disthist/np.sum(subset_disthist) 149 | 150 | subset_mu = np.mean(subset, axis=0) 151 | subset_std = np.std(subset, axis=0) 152 | 153 | 154 | dist_wd = wasserstein_distance(bigb_disthist, subset_disthist)+wasserstein_distance(subset_disthist, bigb_disthist) 155 | cost = np.linalg.norm(bigb_mu - subset_mu) + np.linalg.norm(bigb_std - subset_std) + 75*dist_wd 156 | cost_collect.append(cost) 157 | 158 | bigb_ix = bigb_idxs[np.argmin(cost_collect)] 159 | bigb_data_ix = bigb_data_idxs[bigb_ix] 160 | coll.append(bigb_data_ix) 161 | 162 | return coll 163 | 164 | def __len__(self): 165 | return self.sampler_length 166 | -------------------------------------------------------------------------------- /datasampler/fid_batchmatch_sampler.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch, torch.nn as nn, torch.nn.functional as F 3 | from tqdm import tqdm 4 | import random 5 | from scipy import linalg 6 | 7 | 8 | """======================================================""" 9 | REQUIRES_STORAGE = True 10 | 11 | ### 12 | class Sampler(torch.utils.data.sampler.Sampler): 13 | """ 14 | Plugs into PyTorch Batchsampler Package. 15 | """ 16 | def __init__(self, opt, image_dict, image_list): 17 | self.image_dict = image_dict 18 | self.image_list = image_list 19 | 20 | self.batch_size = opt.bs 21 | self.samples_per_class = opt.samples_per_class 22 | self.sampler_length = len(image_list)//opt.bs 23 | assert self.batch_size%self.samples_per_class==0, '#Samples per class must divide batchsize!' 24 | 25 | self.name = 'spc_fid_batchmatch_sampler' 26 | self.requires_storage = True 27 | 28 | self.bigbs = opt.data_batchmatch_bigbs 29 | self.update_storage = not opt.data_storage_no_update 30 | self.num_batch_comps = opt.data_batchmatch_ncomps 31 | self.low_proj_dim = opt.data_sampler_lowproj_dim 32 | 33 | self.n_jobs = 16 34 | 35 | self.internal_image_dict = {self.image_list[i]:i for i in range(len(self.image_list))} 36 | 37 | 38 | def __iter__(self): 39 | for i in range(self.sampler_length): 40 | # ### Random Subset from Random classes 41 | # bigb_idxs = np.random.choice(len(self.storage), self.bigbs, replace=True) 42 | # bigbatch = self.storage[bigb_idxs] 43 | # 44 | # structured_batch = list(bigb_idxs[self.fid_match(bigbatch, batch_size=self.batch_size//self.samples_per_class)]) 45 | # #Add random per-class fillers to ensure that the batch is build up correctly. 46 | # 47 | # class_idxs = [self.image_list[idx][-1] for idx in structured_batch] 48 | # for class_idx in class_idxs: 49 | # structured_batch.extend([random.choice(self.image_dict[class_idx])[-1] for _ in range(self.samples_per_class-1)]) 50 | 51 | yield self.epoch_indices[i] 52 | 53 | 54 | def precompute_indices(self): 55 | from joblib import Parallel, delayed 56 | import time 57 | ### Random Subset from Random classes 58 | # self.disthist_match() 59 | print('Precomputing Indices... ', end='') 60 | start = time.time() 61 | n_calls = int(np.ceil(self.sampler_length/self.n_jobs)) 62 | self.epoch_indices = Parallel(n_jobs = self.n_jobs)(delayed(self.spc_fid_match)(n_calls, i) for i in range(self.n_jobs)) 63 | self.epoch_indices = [x for y in self.epoch_indices for x in y] 64 | print('Done in {0:3.4f}s.'.format(time.time()-start)) 65 | 66 | 67 | def replace_storage_entries(self, embeddings, indices): 68 | self.storage[indices] = embeddings 69 | 70 | def create_storage(self, dataloader, model, device): 71 | with torch.no_grad(): 72 | _ = model.eval() 73 | _ = model.to(device) 74 | 75 | embed_collect = [] 76 | for i,input_tuple in enumerate(tqdm(dataloader, 'Creating data storage...')): 77 | embed = model(input_tuple[1].type(torch.FloatTensor).to(device)) 78 | if isinstance(embed, tuple): embed = embed[0] 79 | embed = embed.cpu() 80 | embed_collect.append(embed) 81 | embed_collect = torch.cat(embed_collect, dim=0) 82 | self.storage = embed_collect 83 | 84 | 85 | def spc_batchfinder(self, n_samples): 86 | ### SpC-Sample big batch: 87 | subset, classes = [], [] 88 | ### Random Subset from Random classes 89 | for _ in range(n_samples//self.samples_per_class): 90 | class_key = random.choice(list(self.image_dict.keys())) 91 | # subset.extend([(class_key, random.choice(len(self.image_dict[class_key])) for _ in range(self.samples_per_class)]) 92 | subset.extend([random.choice(self.image_dict[class_key])[-1] for _ in range(self.samples_per_class)]) 93 | classes.extend([class_key]*self.samples_per_class) 94 | return np.array(subset), np.array(classes) 95 | 96 | 97 | def spc_fid_match(self, calls, pos): 98 | """ 99 | """ 100 | coll = [] 101 | 102 | for _ in range(calls): 103 | bigb_data_idxs, bigb_data_classes = self.spc_batchfinder(self.bigbs) 104 | bigb_dict = {} 105 | for i, bigb_cls in enumerate(bigb_data_classes): 106 | if bigb_cls not in bigb_dict: bigb_dict[bigb_cls] = [] 107 | bigb_dict[bigb_cls].append(i) 108 | 109 | bigbatch = self.storage[bigb_data_idxs] 110 | if self.low_proj_dim>0: 111 | low_dim_proj = nn.Linear(bigbatch.shape[-1],self.low_proj_dim,bias=False) 112 | with torch.no_grad(): bigbatch = low_dim_proj(bigbatch) 113 | bigbatch = bigbatch.numpy() 114 | 115 | bigbatch_mean = np.mean(bigbatch, axis=0).reshape(-1,1) 116 | bigbatch_cov = np.cov(bigbatch.T) 117 | 118 | 119 | fid_collect, bigb_idxs = [], [] 120 | 121 | for _ in range(self.num_batch_comps): 122 | subset_idxs = [np.random.choice(bigb_dict[np.random.choice(list(bigb_dict.keys()))], self.samples_per_class, replace=False) for _ in range(self.batch_size//self.samples_per_class)] 123 | subset_idxs = [x for y in subset_idxs for x in y] 124 | # subset_idxs = sorted(np.random.choice(len(bigbatch), batch_size, replace=False)) 125 | bigb_idxs.append(subset_idxs) 126 | subset = bigbatch[subset_idxs,:] 127 | 128 | subset_mean = np.mean(subset, axis=0).reshape(-1,1) 129 | subset_cov = np.cov(subset.T) 130 | 131 | diag_offset = np.eye(subset_cov.shape[0])*1e-8 132 | cov_sqrt = linalg.sqrtm((bigbatch_cov+diag_offset).dot((subset_cov+diag_offset)), disp=False)[0].real 133 | 134 | diff = bigbatch_mean-subset_mean 135 | fid = diff.T.dot(diff) + np.trace(bigbatch_cov) + np.trace(subset_cov) - 2*np.trace(cov_sqrt) 136 | 137 | fid_collect.append(fid) 138 | 139 | bigb_ix = bigb_idxs[np.argmin(fid_collect)] 140 | bigb_data_ix = bigb_data_idxs[bigb_ix] 141 | coll.append(bigb_data_ix) 142 | 143 | return coll 144 | 145 | 146 | 147 | def __len__(self): 148 | return self.sampler_length 149 | -------------------------------------------------------------------------------- /datasampler/greedy_coreset_sampler.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch, torch.nn as nn, torch.nn.functional as F 3 | from tqdm import tqdm 4 | import random 5 | from scipy import linalg 6 | 7 | 8 | """======================================================""" 9 | REQUIRES_STORAGE = True 10 | 11 | ### 12 | class Sampler(torch.utils.data.sampler.Sampler): 13 | """ 14 | Plugs into PyTorch Batchsampler Package. 15 | """ 16 | def __init__(self, opt, image_dict, image_list): 17 | self.image_dict = image_dict 18 | self.image_list = image_list 19 | 20 | self.batch_size = opt.bs 21 | self.samples_per_class = opt.samples_per_class 22 | self.sampler_length = len(image_list)//opt.bs 23 | assert self.batch_size%self.samples_per_class==0, '#Samples per class must divide batchsize!' 24 | 25 | self.name = 'greedy_coreset_sampler' 26 | self.requires_storage = True 27 | 28 | self.bigbs = opt.data_batchmatch_bigbs 29 | self.update_storage = not opt.data_storage_no_update 30 | self.num_batch_comps = opt.data_batchmatch_ncomps 31 | self.dist_lim = opt.data_gc_coreset_lim 32 | 33 | self.low_proj_dim = opt.data_sampler_lowproj_dim 34 | 35 | self.softened = opt.data_gc_softened 36 | 37 | self.n_jobs = 16 38 | 39 | def __iter__(self): 40 | for i in range(self.sampler_length): 41 | yield self.epoch_indices[i] 42 | 43 | 44 | def precompute_indices(self): 45 | from joblib import Parallel, delayed 46 | import time 47 | 48 | ### Random Subset from Random classes 49 | bigb_idxs = np.random.choice(len(self.storage), self.bigbs, replace=True) 50 | bigbatch = self.storage[bigb_idxs] 51 | 52 | print('Precomputing Indices... ', end='') 53 | start = time.time() 54 | def batchfinder(n_calls, pos): 55 | idx_sets = self.greedy_coreset(n_calls, pos) 56 | structured_batches = [list(bigb_idxs[idx_set]) for idx_set in idx_sets] 57 | # structured_batch = list(bigb_idxs[self.fid_match(bigbatch, batch_size=self.batch_size//self.samples_per_class)]) 58 | #Add random per-class fillers to ensure that the batch is build up correctly. 59 | for i in range(len(structured_batches)): 60 | class_idxs = [self.image_list[idx][-1] for idx in structured_batches[i]] 61 | for class_idx in class_idxs: 62 | structured_batches[i].extend([random.choice(self.image_dict[class_idx])[-1] for _ in range(self.samples_per_class-1)]) 63 | 64 | return structured_batches 65 | 66 | n_calls = int(np.ceil(self.sampler_length/self.n_jobs)) 67 | # self.epoch_indices = batchfinder(n_calls, 0) 68 | self.epoch_indices = Parallel(n_jobs = self.n_jobs)(delayed(batchfinder)(n_calls, i) for i in range(self.n_jobs)) 69 | self.epoch_indices = [x for y in self.epoch_indices for x in y] 70 | # self.epoch_indices = Parallel(n_jobs = self.n_jobs)(delayed(batchfinder)(self.storage[np.random.choice(len(self.storage), self.bigbs, replace=True)]) for _ in tqdm(range(self.sampler_length), desc='Precomputing Indices...')) 71 | 72 | print('Done in {0:3.4f}s.'.format(time.time()-start)) 73 | 74 | 75 | def replace_storage_entries(self, embeddings, indices): 76 | self.storage[indices] = embeddings 77 | 78 | def create_storage(self, dataloader, model, device): 79 | with torch.no_grad(): 80 | _ = model.eval() 81 | _ = model.to(device) 82 | 83 | embed_collect = [] 84 | for i,input_tuple in enumerate(tqdm(dataloader, 'Creating data storage...')): 85 | embed = model(input_tuple[1].type(torch.FloatTensor).to(device)) 86 | if isinstance(embed, tuple): embed = embed[0] 87 | embed = embed.cpu() 88 | embed_collect.append(embed) 89 | embed_collect = torch.cat(embed_collect, dim=0) 90 | self.storage = embed_collect 91 | 92 | 93 | def full_storage_update(self, dataloader, model, device): 94 | with torch.no_grad(): 95 | _ = model.eval() 96 | _ = model.to(device) 97 | 98 | embed_collect = [] 99 | for i,input_tuple in enumerate(tqdm(dataloader, 'Creating data storage...')): 100 | embed = model(input_tuple[1].type(torch.FloatTensor).to(device)) 101 | if isinstance(embed, tuple): embed = embed[0] 102 | embed = embed.cpu() 103 | embed_collect.append(embed) 104 | embed_collect = torch.cat(embed_collect, dim=0) 105 | if self.mb_mom>0: 106 | self.delta_storage = self.mb_mom*self.delta_storage + (1-self.mb_mom)*(embed_collect-self.storage) 107 | self.storage = embed_collect + self.mb_lr*self.delta_storage 108 | else: 109 | self.storage = embed_collect 110 | 111 | def greedy_coreset(self, calls, pos): 112 | """ 113 | """ 114 | coll = [] 115 | 116 | 117 | for _ in range(calls): 118 | bigbatch = self.storage[np.random.choice(len(self.storage), self.bigbs, replace=False)] 119 | batch_size = self.batch_size//self.samples_per_class 120 | 121 | if self.low_proj_dim>0: 122 | low_dim_proj = nn.Linear(bigbatch.shape[-1],self.low_proj_dim,bias=False) 123 | with torch.no_grad(): bigbatch = low_dim_proj(bigbatch) 124 | 125 | bigbatch = bigbatch.numpy() 126 | 127 | prod = np.matmul(bigbatch, bigbatch.T) 128 | sq = prod.diagonal().reshape(bigbatch.shape[0], 1) 129 | dist_matrix = np.clip(-2*prod + sq + sq.T, 0, None) 130 | coreset_anchor_dists = np.linalg.norm(dist_matrix, axis=1) 131 | 132 | k, sampled_indices = 0, [] 133 | 134 | while k=np.percentile(coreset_anchor_dists,97))[0]) 140 | else: 141 | no = np.argmax(coreset_anchor_dists) 142 | 143 | sampled_indices.append(no) 144 | add_d = dist_matrix[:, no:no+1] 145 | #If its closer to the remaining points than the new addition/additions, sample it. 146 | new_dj = np.concatenate([np.expand_dims(coreset_anchor_dists,-1), add_d], axis=1) 147 | coreset_anchor_dists = np.min(new_dj, axis=1) 148 | k += 1 149 | 150 | coll.append(sampled_indices) 151 | 152 | return coll 153 | 154 | 155 | def __len__(self): 156 | return self.sampler_length 157 | -------------------------------------------------------------------------------- /datasampler/random_sampler.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch, torch.nn as nn, torch.nn.functional as F 3 | from tqdm import tqdm 4 | import random 5 | 6 | 7 | 8 | """======================================================""" 9 | REQUIRES_STORAGE = False 10 | 11 | ### 12 | class Sampler(torch.utils.data.sampler.Sampler): 13 | """ 14 | Plugs into PyTorch Batchsampler Package. 15 | """ 16 | def __init__(self, opt, image_dict, image_list=None): 17 | self.image_dict = image_dict 18 | self.image_list = image_list 19 | 20 | self.batch_size = opt.bs 21 | self.samples_per_class = opt.samples_per_class 22 | self.sampler_length = len(image_list)//opt.bs 23 | assert self.batch_size%self.samples_per_class==0, '#Samples per class must divide batchsize!' 24 | 25 | self.name = 'random_sampler' 26 | self.requires_storage = False 27 | 28 | def __iter__(self): 29 | for _ in range(self.sampler_length): 30 | subset = [] 31 | ### Random Subset from Random classes 32 | for _ in range(self.batch_size-1): 33 | class_key = random.choice(list(self.image_dict.keys())) 34 | sample_idx = np.random.choice(len(self.image_dict[class_key])) 35 | subset.append(self.image_dict[class_key][sample_idx][-1]) 36 | # 37 | subset.append(random.choice(self.image_dict[self.image_list[random.choice(subset)][-1]])[-1]) 38 | yield subset 39 | 40 | def __len__(self): 41 | return self.sampler_length 42 | -------------------------------------------------------------------------------- /datasampler/samplers.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch, torch.nn as nn, torch.nn.functional as F 3 | from tqdm import tqdm 4 | import random 5 | 6 | 7 | """======================================================""" 8 | def sampler_parse_args(parser): 9 | parser.add_argument('--batch_selection', default='class_random', type=str, help='Selection of the data batch: Modes of Selection: random, greedy_coreset') 10 | parser.add_argument('--primary_subset_perc', default=0.1, type=float, help='Size of the randomly selected subset before application of coreset selection.') 11 | return parser 12 | 13 | 14 | 15 | """======================================================""" 16 | ### 17 | # Methods: Full random, Per-Class-Random, CoreSet 18 | class AdvancedSampler(torch.utils.data.sampler.Sampler): 19 | """ 20 | Plugs into PyTorch Batchsampler Package. 21 | """ 22 | def __init__(self, method='class_random', random_subset_perc=0.1, batch_size=128, samples_per_class=4): 23 | self.random_subset_perc = random_subset_perc 24 | self.batch_size = batch_size 25 | self.samples_per_class = samples_per_class 26 | 27 | self.method = method 28 | 29 | self.storage = None 30 | self.sampler_length = None 31 | 32 | self.methods_requiring_storage = ['greedy_class_coreset', 'greedy_semi_class_coreset', 'presampled_infobatch'] 33 | 34 | def create_storage(self, dataloader, model, device): 35 | self.image_dict = dataloader.dataset.image_dict 36 | self.image_list = dataloader.dataset.image_list 37 | 38 | self.sampler_length = len(dataloader.dataset)//self.batch_size 39 | 40 | if self.method in self.methods_requiring_storage: 41 | with torch.no_grad(): 42 | _ = model.eval() 43 | _ = model.to(device) 44 | 45 | embed_collect = [] 46 | for i,input_tuple in enumerate(tqdm(dataloader, 'Creating data storage...')): 47 | embed = model(input_tuple[1].type(torch.FloatTensor).to(device)).cpu() 48 | embed_collect.append(embed) 49 | embed_collect = torch.cat(embed_collect, dim=0) 50 | self.storage = embed_collect 51 | 52 | self.random_subset_len = int(self.random_subset_perc*len(self.storage)) 53 | 54 | def update_storage(self, embeddings, indices): 55 | if 'coreset' in self.method: 56 | self.storage[indices] = embeddings 57 | 58 | def __iter__(self): 59 | for _ in range(self.sampler_length): 60 | subset = [] 61 | if self.method=='greedy_class_coreset': 62 | for _ in range(self.batch_size//self.samples_per_class): 63 | class_key = random.choice(list(self.image_dict.keys())) 64 | class_indices = np.array([x[1] for x in self.image_dict[class_key]]) 65 | # print(class_indices) 66 | ### Coreset subset of subset 67 | subset.extend(class_indices[self.greedy_coreset(self.storage[class_indices], self.samples_per_class)]) 68 | # print([self.image_list[x][1] for x in subset]) 69 | elif self.method=='greedy_semi_class_coreset': 70 | ### Big random subset 71 | subset = np.random.randint(0,len(self.storage),self.random_subset_len) 72 | ### Coreset subset of subset of half the batch size 73 | subset = subset[self.greedy_coreset(self.storage[subset], self.batch_size//2)] 74 | ### Fill the rest of the batch with random samples from each coreset member class 75 | subset = list(subset)+[random.choice(self.image_dict[self.image_list[idx][-1]])[-1] for idx in subset] 76 | elif self.method=='presampled_infobatch': 77 | ### Big random subset 78 | subset = np.random.randint(0,len(self.storage),self.random_subset_len) 79 | classes = torch.tensor([self.image_list[idx][-1] for idx in subset]) 80 | ### Presampled Infobatch for subset of data. 81 | subset = subset[self.presample_infobatch(classes, self.storage[subset], self.batch_size//2)] 82 | ### Fill the rest of the batch with random samples from each member class 83 | subset = list(subset)+[random.choice(self.image_dict[self.image_list[idx][-1]])[-1] for idx in subset] 84 | elif self.method=='class_random': 85 | ### Random Subset from Random classes 86 | for _ in range(self.batch_size//self.samples_per_class): 87 | class_key = random.choice(list(self.image_dict.keys())) 88 | subset.extend([random.choice(self.image_dict[class_key])[-1] for _ in range(self.samples_per_class)]) 89 | elif self.method=='semi_class_random': 90 | ### Select half of the indices completely at random, and the other half corresponding to the classes. 91 | for _ in range(self.batch_size//2): 92 | rand_idx = np.random.randint(len(self.image_list)) 93 | class_idx = self.image_list[rand_idx][-1] 94 | rand_class_idx = random.choice(self.image_dict[class_idx])[-1] 95 | subset.extend([rand_idx, rand_class_idx]) 96 | else: 97 | raise NotImplementedError('Batch selection method {} not available!'.format(self.method)) 98 | yield subset 99 | 100 | def __len__(self): 101 | return self.sampler_length 102 | 103 | def pdistsq(self, A): 104 | prod = torch.mm(A, A.t()) 105 | diag = prod.diag().unsqueeze(1).expand_as(prod) 106 | return (-2*prod + diag + diag.T) 107 | 108 | def greedy_coreset(self, A, samples): 109 | dist_matrix = self.pdistsq(A) 110 | coreset_anchor_dists = torch.norm(dist_matrix, dim=1) 111 | 112 | sampled_indices, i = [], 0 113 | 114 | while i0 else '',metricname, metricval) 31 | full_result_str += '\n' 32 | 33 | print(full_result_str) 34 | 35 | 36 | ### 37 | for evaltype in evaltypes: 38 | for storage_metric in opt.storage_metrics: 39 | parent_metric = evaltype+'_{}'.format(storage_metric.split('@')[0]) 40 | if parent_metric not in LOG.progress_saver[log_key].groups.keys() or \ 41 | numeric_metrics[evaltype][storage_metric]>np.max(LOG.progress_saver[log_key].groups[parent_metric][storage_metric]['content']): 42 | print('Saved weights for best {}: {}\n'.format(log_key, parent_metric)) 43 | set_checkpoint(model, opt, LOG.progress_saver, LOG.prop.save_path+'/checkpoint_{}_{}_{}.pth.tar'.format(log_key, evaltype, storage_metric), aux=aux_store) 44 | 45 | 46 | ### 47 | if opt.log_online: 48 | for evaltype in histogr_metrics.keys(): 49 | for eval_metric, hist in histogr_metrics[evaltype].items(): 50 | import wandb, numpy 51 | wandb.log({log_key+': '+evaltype+'_{}'.format(eval_metric): wandb.Histogram(np_histogram=(list(hist),list(np.arange(len(hist)+1))))}, step=opt.epoch) 52 | wandb.log({log_key+': '+evaltype+'_LOG-{}'.format(eval_metric): wandb.Histogram(np_histogram=(list(np.log(hist)+20),list(np.arange(len(hist)+1))))}, step=opt.epoch) 53 | 54 | ### 55 | for evaltype in numeric_metrics.keys(): 56 | for eval_metric in numeric_metrics[evaltype].keys(): 57 | parent_metric = evaltype+'_{}'.format(eval_metric.split('@')[0]) 58 | LOG.progress_saver[log_key].log(eval_metric, numeric_metrics[evaltype][eval_metric], group=parent_metric) 59 | 60 | ### 61 | if make_recall_plot: 62 | recover_closest_standard(extra_infos[evaltype]['features'], 63 | extra_infos[evaltype]['image_paths'], 64 | LOG.prop.save_path+'/sample_recoveries.png') 65 | 66 | 67 | ########################### 68 | def set_checkpoint(model, opt, progress_saver, savepath, aux=None): 69 | if 'experiment' in vars(opt): 70 | import argparse 71 | save_opt = {key:item for key,item in vars(opt).items() if key!='experiment'} 72 | save_opt = argparse.Namespace(**save_opt) 73 | else: 74 | save_opt = opt 75 | 76 | torch.save({'state_dict':model.state_dict(), 'opt':save_opt, 'progress':progress_saver, 'aux':aux}, savepath) 77 | 78 | 79 | 80 | 81 | ########################## 82 | def recover_closest_standard(feature_matrix_all, image_paths, save_path, n_image_samples=10, n_closest=3): 83 | image_paths = np.array([x[0] for x in image_paths]) 84 | sample_idxs = np.random.choice(np.arange(len(feature_matrix_all)), n_image_samples) 85 | 86 | faiss_search_index = faiss.IndexFlatL2(feature_matrix_all.shape[-1]) 87 | faiss_search_index.add(feature_matrix_all) 88 | _, closest_feature_idxs = faiss_search_index.search(feature_matrix_all, n_closest+1) 89 | 90 | sample_paths = image_paths[closest_feature_idxs][sample_idxs] 91 | 92 | f,axes = plt.subplots(n_image_samples, n_closest+1) 93 | for i,(ax,plot_path) in enumerate(zip(axes.reshape(-1), sample_paths.reshape(-1))): 94 | ax.imshow(np.array(Image.open(plot_path))) 95 | ax.set_xticks([]) 96 | ax.set_yticks([]) 97 | if i%(n_closest+1): 98 | ax.axvline(x=0, color='g', linewidth=13) 99 | else: 100 | ax.axvline(x=0, color='r', linewidth=13) 101 | f.set_size_inches(10,20) 102 | f.tight_layout() 103 | f.savefig(save_path) 104 | plt.close() 105 | -------------------------------------------------------------------------------- /evaluation/eval_diml.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from tqdm import tqdm, trange 6 | 7 | from utilities.diml import Sinkhorn, calc_similarity 8 | from evaluation.metrics import get_metrics_rank, get_metrics 9 | 10 | def evaluate(model, dataloader, no_training=True, trunc_nums=None, use_uniform=False, grid_size=4): 11 | model.eval() 12 | with torch.no_grad(): 13 | if no_training: 14 | if 7 % grid_size == 0: 15 | resize = nn.AdaptiveAvgPool2d(grid_size) 16 | else: 17 | resize = nn.Sequential( 18 | nn.Upsample(grid_size * 4, mode='bilinear', align_corners=True), 19 | nn.AdaptiveAvgPool2d(grid_size), 20 | ) 21 | feature_bank_center = [] 22 | 23 | target_labels = [] 24 | feature_bank = [] 25 | 26 | labels = [] 27 | final_iter = tqdm(dataloader, desc='Embedding Data...') 28 | for idx, inp in enumerate(final_iter): 29 | input_img, target = inp[1], inp[0] 30 | target_labels.extend(target.numpy().tolist()) 31 | out = model(input_img.cuda()) 32 | if isinstance(out, tuple): out, aux_f = out 33 | 34 | if no_training: 35 | enc_out, no_avg_feat = aux_f 36 | no_avg_feat = no_avg_feat.transpose(1, 3) 37 | no_avg_feat = model.model.last_linear(no_avg_feat) 38 | no_avg_feat = no_avg_feat.transpose(1, 3) 39 | no_avg_feat = resize(no_avg_feat) 40 | feature_bank.append(no_avg_feat.data) 41 | feature_bank_center.append(out.data) 42 | else: 43 | feature_bank.append(out.data) 44 | 45 | labels.append(target) 46 | 47 | feature_bank = torch.cat(feature_bank, dim=0) 48 | labels = torch.cat(labels, dim=0) 49 | N, C, H, W = feature_bank.size() 50 | feature_bank = feature_bank.view(N, C, -1) 51 | 52 | if no_training: 53 | feature_bank_center = torch.cat(feature_bank_center, dim=0) 54 | else: 55 | feature_bank_center = feature_bank.mean(2) 56 | 57 | feature_bank = torch.nn.functional.normalize(feature_bank, p=2, dim=1) 58 | feature_bank_center = torch.nn.functional.normalize(feature_bank_center, p=2, dim=1) 59 | 60 | 61 | trunc_nums = trunc_nums or [0, 5, 10, 50, 100, 500, 1000] 62 | 63 | overall_r1 = {k: 0.0 for k in trunc_nums} 64 | overall_rp = {k: 0.0 for k in trunc_nums} 65 | overall_mapr = {k: 0.0 for k in trunc_nums} 66 | 67 | for idx in trange(len(feature_bank)): 68 | anchor_center = feature_bank_center[idx] 69 | approx_sim = calc_similarity(None, anchor_center, None, feature_bank_center, 0) 70 | approx_sim[idx] = -100 71 | 72 | approx_tops = torch.argsort(approx_sim, descending=True) 73 | 74 | if max(trunc_nums) > 0: 75 | top_inds = approx_tops[:max(trunc_nums)] 76 | 77 | anchor = feature_bank[idx] 78 | sim = calc_similarity(anchor, anchor_center, feature_bank[top_inds], feature_bank_center[top_inds], 1, use_uniform) 79 | rank_in_tops = torch.argsort(sim + approx_sim[top_inds], descending=True) 80 | 81 | for trunc_num in trunc_nums: 82 | if trunc_num == 0: 83 | final_tops = approx_tops 84 | else: 85 | rank_in_tops_real = top_inds[rank_in_tops][:trunc_num] 86 | 87 | final_tops = torch.cat([rank_in_tops_real, approx_tops[trunc_num:]], dim=0) 88 | 89 | # sim[idx] = -100 90 | r1, rp, mapr = get_metrics_rank(final_tops.data.cpu(), labels[idx], labels) 91 | 92 | overall_r1[trunc_num] += r1 93 | overall_rp[trunc_num] += rp 94 | overall_mapr[trunc_num] += mapr 95 | 96 | 97 | for trunc_num in trunc_nums: 98 | overall_r1[trunc_num] /= float(N / 100) 99 | overall_rp[trunc_num] /= float(N / 100) 100 | overall_mapr[trunc_num] /= float(N / 100) 101 | 102 | print("trunc_num: ", trunc_num) 103 | print('###########') 104 | print('Now rank-1 acc=%f, RP=%f, MAP@R=%f' % (overall_r1[trunc_num], overall_rp[trunc_num], overall_mapr[trunc_num])) 105 | 106 | data = { 107 | 'r1': [overall_r1[k] for k in trunc_nums], 108 | 'rp': [overall_rp[k] for k in trunc_nums], 109 | 'mapr': [overall_mapr[k] for k in trunc_nums], 110 | } 111 | return data -------------------------------------------------------------------------------- /evaluation/metrics.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | def get_metrics(sim, query_label, gallery_label): 4 | tops = torch.argsort(sim, descending=True) 5 | 6 | top1 = tops[0] 7 | r1 = 0.0 8 | if query_label == gallery_label[top1]: 9 | r1 = 1.0 10 | 11 | num_pos = torch.sum(gallery_label == query_label).item() 12 | 13 | rp = torch.sum(gallery_label[tops[0:num_pos]] == query_label).float() / float(num_pos) 14 | 15 | equality = gallery_label[tops[0:num_pos]] == query_label 16 | equality = equality.float() 17 | cumulative_correct = torch.cumsum(equality, dim=0) 18 | k_idx = torch.arange(num_pos) + 1 19 | precision_at_ks = (cumulative_correct * equality) / k_idx 20 | 21 | rp = rp.item() 22 | mapr = torch.mean(precision_at_ks).item() 23 | 24 | return r1, rp, mapr 25 | 26 | def get_metrics_rank(tops, query_label, gallery_label): 27 | 28 | # tops = torch.argsort(sim, descending=True) 29 | top1 = tops[0] 30 | r1 = 0.0 31 | if query_label == gallery_label[top1]: 32 | r1 = 1.0 33 | 34 | num_pos = torch.sum(gallery_label == query_label).item() 35 | 36 | rp = torch.sum(gallery_label[tops[0:num_pos]] == query_label).float() / float(num_pos) 37 | 38 | equality = gallery_label[tops[0:num_pos]] == query_label 39 | equality = equality.float() 40 | cumulative_correct = torch.cumsum(equality, dim=0) 41 | k_idx = torch.arange(num_pos) + 1 42 | precision_at_ks = (cumulative_correct * equality) / k_idx 43 | 44 | rp = rp.item() 45 | mapr = torch.mean(precision_at_ks).item() 46 | 47 | return r1, rp, mapr 48 | 49 | 50 | 51 | -------------------------------------------------------------------------------- /figs/intro.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wl-zhao/DIML/c15dbea696f68ddf889dcacfcaacd315d16a34ac/figs/intro.gif -------------------------------------------------------------------------------- /parameters.py: -------------------------------------------------------------------------------- 1 | import argparse, os 2 | 3 | 4 | ####################################### 5 | def basic_training_parameters(parser): 6 | ##### Dataset-related Parameters 7 | parser.add_argument('--dataset', default='cub200', type=str, help='Dataset to use. Currently supported: cub200, cars196, online_products.') 8 | parser.add_argument('--use_tv_split', action='store_true', help='Flag. If set, split the training set into a training/validation set.') 9 | parser.add_argument('--tv_split_by_samples', action='store_true', help='Flag. If set, create the validation set by taking a percentage of samples PER class. \ 10 | Otherwise, the validation set is create by taking a percentage of classes.') 11 | parser.add_argument('--tv_split_perc', default=0.8, type=float, help='Percentage with which the training dataset is split into training/validation.') 12 | parser.add_argument('--augmentation', default='base', type=str, help='Type of preprocessing/augmentation to use on the data. \ 13 | Available: base (standard), adv (with color/brightness changes), big (Images of size 256x256), red (No RandomResizedCrop).') 14 | 15 | ### General Training Parameters 16 | parser.add_argument('--lr', default=0.00001, type=float, help='Learning Rate for network parameters.') 17 | parser.add_argument('--fc_lr', default=-1, type=float, help='Optional. If not -1, sets the learning rate for the final linear embedding layer.') 18 | parser.add_argument('--decay', default=0.0004, type=float, help='Weight decay placed on network weights.') 19 | parser.add_argument('--n_epochs', default=150, type=int, help='Number of training epochs.') 20 | parser.add_argument('--kernels', default=6, type=int, help='Number of workers for pytorch dataloader.') 21 | parser.add_argument('--bs', default=112 , type=int, help='Mini-Batchsize to use.') 22 | parser.add_argument('--seed', default=1, type=int, help='Random seed for reproducibility.') 23 | parser.add_argument('--scheduler', default='step', type=str, help='Type of learning rate scheduling. Currently supported: step') 24 | parser.add_argument('--gamma', default=0.3, type=float, help='Learning rate reduction after tau epochs.') 25 | parser.add_argument('--tau', default=[1000], nargs='+',type=int , help='Stepsize before reducing learning rate.') 26 | parser.add_argument('--resume', default=None, type=str, help='resume path') 27 | 28 | ##### Loss-specific Settings 29 | parser.add_argument('--optim', default='adam', type=str, help='Optimization method to use. Currently supported: adam & sgd.') 30 | parser.add_argument('--loss', default='margin', type=str, help='Training criteria: For supported methods, please check criteria/__init__.py') 31 | parser.add_argument('--batch_mining', default='distance', type=str, help='Batchminer for tuple-based losses: For supported methods, please check batch_mining/__init__.py') 32 | 33 | ##### Network-related Flags 34 | parser.add_argument('--embed_dim', default=128, type=int, help='Embedding dimensionality of the network. Note: dim = 64, 128 or 512 is used in most papers, depending on the architecture.') 35 | parser.add_argument('--not_pretrained', action='store_true', help='Flag. If set, no ImageNet pretraining is used to initialize the network.') 36 | parser.add_argument('--arch', default='resnet50_frozen_normalize', type=str, help='Underlying network architecture. Frozen denotes that \ 37 | exisiting pretrained batchnorm layers are frozen, and normalize denotes normalization of the output embedding.') 38 | parser.add_argument('--use_uniform', default=False, action='store_true') 39 | 40 | 41 | ##### Evaluation Parameters 42 | parser.add_argument('--no_train_metrics', action='store_true', help='Flag. If set, evaluation metrics are not computed for the training data. Saves a forward pass over the full training dataset.') 43 | parser.add_argument('--evaluate_on_gpu', action='store_true', help='Flag. If set, all metrics, when possible, are computed on the GPU (requires Faiss-GPU).') 44 | parser.add_argument('--evaluation_metrics', nargs='+', default=['e_recall@1', 'e_recall@2', 'e_recall@4', 'nmi', 'f1', 'mAP_1000', 'mAP_lim', 'mAP_c', \ 45 | 'dists@intra', 'dists@inter', 'dists@intra_over_inter', 'rho_spectrum@0', \ 46 | 'rho_spectrum@-1', 'rho_spectrum@1', 'rho_spectrum@2', 'rho_spectrum@10'], type=str, help='Metrics to evaluate performance by.') 47 | 48 | parser.add_argument('--storage_metrics', nargs='+', default=['e_recall@1'], type=str, help='Improvement in these metrics on a dataset trigger checkpointing.') 49 | parser.add_argument('--evaltypes', nargs='+', default=['discriminative'], type=str, help='The network may produce multiple embeddings (ModuleDict, relevant for e.g. DiVA). If the key is listed here, the entry will be evaluated on the evaluation metrics.\ 50 | Note: One may use Combined_embed1_embed2_..._embedn-w1-w1-...-wn to compute evaluation metrics on weighted (normalized) combinations.') 51 | 52 | 53 | ##### Setup Parameters 54 | parser.add_argument('--savename', default='group_plus_seed', type=str, help='Run savename - if default, the savename will comprise the project and group name (see wandb_parameters()).') 55 | parser.add_argument('--source_path', default=os.getcwd()+'/data', type=str, help='Path to training data.') 56 | parser.add_argument('--save_path', default=os.getcwd()+'/Training_Results', type=str, help='Where to save everything.') 57 | parser.add_argument('--group', type=str, required=True) 58 | 59 | return parser 60 | 61 | 62 | ####################################### 63 | def loss_specific_parameters(parser): 64 | ### Contrastive Loss 65 | parser.add_argument('--loss_contrastive_pos_margin', default=0, type=float, help='positive margin for contrastive pairs.') 66 | parser.add_argument('--loss_contrastive_neg_margin', default=1, type=float, help='negative margin for contrastive pairs.') 67 | 68 | ### Triplet-based Losses 69 | parser.add_argument('--loss_triplet_margin', default=0.2, type=float, help='Margin for Triplet Loss') 70 | 71 | ### MarginLoss 72 | parser.add_argument('--loss_margin_margin', default=0.2, type=float, help='Triplet margin.') 73 | parser.add_argument('--loss_margin_beta_lr', default=0.0005, type=float, help='Learning Rate for learnable class margin parameters in MarginLoss') 74 | parser.add_argument('--loss_margin_beta', default=1.2, type=float, help='Initial Class Margin Parameter in Margin Loss') 75 | parser.add_argument('--loss_margin_nu', default=0, type=float, help='Regularisation value on betas in Margin Loss. Generally not needed.') 76 | parser.add_argument('--loss_margin_beta_constant',action='store_true', help='Flag. If set, beta-values are left untrained.') 77 | 78 | ### ProxyNCA 79 | parser.add_argument('--loss_proxynca_lrmulti', default=50, type=float, help='Learning Rate multiplier for Proxies in proxynca.') 80 | #NOTE: The number of proxies is determined by the number of data classes. 81 | 82 | ### NPair 83 | parser.add_argument('--loss_npair_l2', default=0.005, type=float, help='L2 weight in NPair. Note: Set to 0.02 in paper, but multiplied with 0.25 in their implementation.') 84 | 85 | ### Angular Loss 86 | parser.add_argument('--loss_angular_alpha', default=45, type=float, help='Angular margin in degrees.') 87 | parser.add_argument('--loss_angular_npair_ang_weight', default=2, type=float, help='Relative weighting between angular and npair contribution.') 88 | parser.add_argument('--loss_angular_npair_l2', default=0.005, type=float, help='L2 weight on NPair (as embeddings are not normalized).') 89 | 90 | ### Multisimilary Loss 91 | parser.add_argument('--loss_multisimilarity_pos_weight', default=2, type=float, help='Weighting on positive similarities.') 92 | parser.add_argument('--loss_multisimilarity_neg_weight', default=40, type=float, help='Weighting on negative similarities.') 93 | parser.add_argument('--loss_multisimilarity_margin', default=0.1, type=float, help='Distance margin for both positive and negative similarities.') 94 | parser.add_argument('--loss_multisimilarity_thresh', default=0.5, type=float, help='Exponential thresholding.') 95 | 96 | ### Lifted Structure Loss 97 | parser.add_argument('--loss_lifted_neg_margin', default=1, type=float, help='Margin placed on similarities.') 98 | parser.add_argument('--loss_lifted_l2', default=0.005, type=float, help='As embeddings are not normalized, they need to be placed under penalty.') 99 | 100 | ### Quadruplet Loss 101 | parser.add_argument('--loss_quadruplet_margin_alpha_1', default=0.2, type=float, help='Quadruplet Loss requires two margins. This is the first one.') 102 | parser.add_argument('--loss_quadruplet_margin_alpha_2', default=0.2, type=float, help='This is the second.') 103 | 104 | ### Soft-Triple Loss 105 | parser.add_argument('--loss_softtriplet_n_centroids', default=2, type=int, help='Number of proxies per class.') 106 | parser.add_argument('--loss_softtriplet_margin_delta', default=0.01, type=float, help='Margin placed on sample-proxy similarities.') 107 | parser.add_argument('--loss_softtriplet_gamma', default=0.1, type=float, help='Weight over sample-proxies within a class.') 108 | parser.add_argument('--loss_softtriplet_lambda', default=8, type=float, help='Serves as a temperature.') 109 | parser.add_argument('--loss_softtriplet_reg_weight', default=0.2, type=float, help='Regularization weight on the number of proxies.') 110 | parser.add_argument('--loss_softtriplet_lrmulti', default=1, type=float, help='Learning Rate multiplier for proxies.') 111 | 112 | ### Normalized Softmax Loss 113 | parser.add_argument('--loss_softmax_lr', default=0.00001, type=float, help='Learning rate on class proxies.') 114 | parser.add_argument('--loss_softmax_temperature', default=0.05, type=float, help='Temperature for NCA objective.') 115 | 116 | ### Histogram Loss 117 | parser.add_argument('--loss_histogram_nbins', default=65, type=int, help='Number of bins for histogram discretization.') 118 | 119 | ### SNR Triplet (with learnable margin) Loss 120 | parser.add_argument('--loss_snr_margin', default=0.2, type=float, help='Triplet margin.') 121 | parser.add_argument('--loss_snr_reg_lambda', default=0.005, type=float, help='Regularization of in-batch element sum.') 122 | 123 | ### ArcFace 124 | parser.add_argument('--loss_arcface_lr', default=0.0005, type=float, help='Learning rate on class proxies.') 125 | parser.add_argument('--loss_arcface_angular_margin', default=0.5, type=float, help='Angular margin in radians.') 126 | parser.add_argument('--loss_arcface_feature_scale', default=16, type=float, help='Inverse Temperature for NCA objective.') 127 | return parser 128 | 129 | 130 | 131 | ####################################### 132 | def batchmining_specific_parameters(parser): 133 | ### Distance-based Batchminer 134 | parser.add_argument('--miner_distance_lower_cutoff', default=0.5, type=float, help='Lower cutoff on distances - values below are sampled with equal prob.') 135 | parser.add_argument('--miner_distance_upper_cutoff', default=1.4, type=float, help='Upper cutoff on distances - values above are IGNORED.') 136 | ### Spectrum-Regularized Miner (as proposed in our paper) - utilizes a distance-based sampler that is regularized. 137 | parser.add_argument('--miner_rho_distance_lower_cutoff', default=0.5, type=float, help='Lower cutoff on distances - values below are sampled with equal prob.') 138 | parser.add_argument('--miner_rho_distance_upper_cutoff', default=1.4, type=float, help='Upper cutoff on distances - values above are IGNORED.') 139 | parser.add_argument('--miner_rho_distance_cp', default=0.2, type=float, help='Probability to replace a negative with a positive.') 140 | return parser 141 | 142 | 143 | ####################################### 144 | def batch_creation_parameters(parser): 145 | parser.add_argument('--data_sampler', default='class_random', type=str, help='How the batch is created. Available options: See datasampler/__init__.py.') 146 | parser.add_argument('--samples_per_class', default=2, type=int, help='Number of samples in one class drawn before choosing the next class. Set to >1 for tuple-based loss.') 147 | ### Batch-Sample Flags - Have no relevance to default SPC-N sampling 148 | parser.add_argument('--data_batchmatch_bigbs', default=512, type=int, help='Size of batch to be summarized into a smaller batch. For distillation/coreset-based methods.') 149 | parser.add_argument('--data_batchmatch_ncomps', default=10, type=int, help='Number of batch candidates that are evaluated, from which the best one is chosen.') 150 | parser.add_argument('--data_storage_no_update', action='store_true', help='Flag for methods that need a sample storage. If set, storage entries are NOT updated.') 151 | parser.add_argument('--data_d2_coreset_lambda', default=1, type=float, help='Regularisation for D2-coreset.') 152 | parser.add_argument('--data_gc_coreset_lim', default=1e-9, type=float, help='D2-coreset value limit.') 153 | parser.add_argument('--data_sampler_lowproj_dim', default=-1, type=int, help='Optionally project embeddings into a lower dimension to ensure that greedy coreset works better. Only makes a difference for large embedding dims.') 154 | parser.add_argument('--data_sim_measure', default='euclidean', type=str, help='Distance measure to use for batch selection.') 155 | parser.add_argument('--data_gc_softened', action='store_true', help='Flag. If set, use a soft version of greedy coreset.') 156 | parser.add_argument('--data_idx_full_prec', action='store_true', help='Deprecated.') 157 | parser.add_argument('--data_mb_mom', default=-1, type=float, help='For memory-bank based samplers - momentum term on storage entry updates.') 158 | parser.add_argument('--data_mb_lr', default=1, type=float, help='Deprecated.') 159 | 160 | return parser 161 | -------------------------------------------------------------------------------- /scripts/baselines/cars_runs.sh: -------------------------------------------------------------------------------- 1 | # """============= Baseline Runs --- CARS196 ====================""" 2 | main=train_baseline 3 | datapath=data 4 | gpu=${1:-0} 5 | 6 | python $main.py --dataset cars196 --kernels 6 --source $datapath --n_epochs 150 --group CARS_Npair --seed 0 --bs 112 --samples_per_class 2 --loss npair --batch_mining npair --arch resnet50_frozen 7 | 8 | python $main.py --dataset cars196 --kernels 6 --source $datapath --n_epochs 150 --group CARS_GenLifted --seed 0 --bs 112 --samples_per_class 2 --loss lifted --batch_mining lifted --arch resnet50_frozen 9 | 10 | python $main.py --dataset cars196 --kernels 6 --source $datapath --n_epochs 150 --group CARS_ProxyNCA --seed 0 --bs 112 --samples_per_class 2 --loss proxynca --arch resnet50_frozen_normalize 11 | 12 | python $main.py --dataset cars196 --kernels 6 --source $datapath --n_epochs 150 --group CARS_Histogram --seed 0 --bs 112 --samples_per_class 2 --loss histogram --arch resnet50_frozen_normalize --loss_histogram_nbins 65 13 | 14 | python $main.py --dataset cars196 --kernels 6 --source $datapath --n_epochs 150 --group CARS_Contrastive --seed 0 --bs 112 --samples_per_class 2 --loss contrastive --batch_mining distance --arch resnet50_frozen_normalize 15 | 16 | python $main.py --dataset cars196 --kernels 6 --source $datapath --n_epochs 150 --group CARS_SoftTriple --seed 0 --bs 112 --samples_per_class 2 --loss softtriplet --arch resnet50_frozen_normalize 17 | 18 | python $main.py --dataset cars196 --kernels 6 --source $datapath --n_epochs 150 --group CARS_Angular --seed 0 --bs 112 --samples_per_class 2 --loss angular --batch_mining npair --arch resnet50_frozen 19 | 20 | python $main.py --dataset cars196 --kernels 6 --source $datapath --n_epochs 150 --group CARS_ArcFace --seed 0 --bs 112 --samples_per_class 2 --loss arcface --arch resnet50_frozen_normalize 21 | 22 | python $main.py --dataset cars196 --kernels 6 --source $datapath --n_epochs 150 --group CARS_Triplet_Random --seed 0 --bs 112 --samples_per_class 2 --loss triplet --batch_mining random --arch resnet50_frozen_normalize 23 | 24 | python $main.py --dataset cars196 --kernels 6 --source $datapath --n_epochs 150 --group CARS_Triplet_Semihard --seed 0 --bs 112 --samples_per_class 2 --loss triplet --batch_mining semihard --arch resnet50_frozen_normalize 25 | 26 | python $main.py --dataset cars196 --kernels 6 --source $datapath --n_epochs 150 --group CARS_Triplet_Softhard --seed 0 --bs 112 --samples_per_class 2 --loss triplet --batch_mining softhard --arch resnet50_frozen_normalize 27 | 28 | python $main.py --dataset cars196 --kernels 6 --source $datapath --n_epochs 150 --group CARS_Triplet_Distance --seed 0 --bs 112 --samples_per_class 2 --loss triplet --batch_mining distance --arch resnet50_frozen_normalize 29 | 30 | python $main.py --dataset cars196 --kernels 6 --source $datapath --n_epochs 150 --group CARS_Quadruplet_Distance --seed 0 --bs 112 --samples_per_class 2 --loss quadruplet --batch_mining distance --arch resnet50_frozen_normalize 31 | 32 | python $main.py --dataset cars196 --kernels 6 --source $datapath --n_epochs 150 --group CARS_Margin_b06_Distance --loss_margin_beta 0.6 --seed 0 --bs 112 --samples_per_class 2 --loss margin --batch_mining distance --arch resnet50_frozen_normalize 33 | 34 | python $main.py --dataset cars196 --kernels 6 --source $datapath --n_epochs 150 --group CARS_Margin_b12_Distance --seed 0 --bs 112 --samples_per_class 2 --loss margin --batch_mining distance --arch resnet50_frozen_normalize 35 | 36 | python $main.py --dataset cars196 --kernels 6 --source $datapath --n_epochs 150 --group CARS_SNR_Distance --seed 0 --bs 112 --samples_per_class 2 --loss snr --batch_mining distance --arch resnet50_frozen_normalize 37 | 38 | python $main.py --dataset cars196 --kernels 6 --source $datapath --n_epochs 150 --group CARS_MS --seed 2 --bs 112 --samples_per_class 2 --loss multisimilarity --arch resnet50_frozen_normalize 39 | 40 | python $main.py --dataset cars196 --kernels 6 --source $datapath --n_epochs 150 --group CARS_Softmax --seed 0 --bs 112 --samples_per_class 2 --loss softmax --batch_mining distance --arch resnet50_frozen_normalize -------------------------------------------------------------------------------- /scripts/baselines/cub_runs.sh: -------------------------------------------------------------------------------- 1 | # """============= Baseline Runs --- CUB200-2011 ====================""" 2 | main=train_baseline 3 | datapath=data 4 | gpu=${1:-0} 5 | 6 | python $main.py --kernels 6 --source $datapath --n_epochs 150 --group CUB_Npair --seed 0 --bs 112 --samples_per_class 2 --loss npair --batch_mining npair --arch resnet50_frozen 7 | 8 | python $main.py --kernels 6 --source $datapath --n_epochs 150 --group CUB_GenLifted --seed 0 --bs 112 --samples_per_class 2 --loss lifted --batch_mining lifted --arch resnet50_frozen 9 | 10 | python $main.py --kernels 6 --source $datapath --n_epochs 150 --group CUB_ProxyNCA --seed 0 --bs 112 --samples_per_class 2 --loss proxynca --arch resnet50_frozen_normalize 11 | 12 | python $main.py --kernels 6 --source $datapath --n_epochs 150 --group CUB_Histogram --seed 0 --bs 112 --samples_per_class 2 --loss histogram --arch resnet50_frozen_normalize --loss_histogram_nbins 65 13 | 14 | python $main.py --kernels 6 --source $datapath --n_epochs 150 --group CUB_Contrastive --seed 0 --bs 112 --samples_per_class 2 --loss contrastive --batch_mining distance --arch resnet50_frozen_normalize 15 | 16 | python $main.py --kernels 6 --source $datapath --n_epochs 150 --group CUB_SoftTriple --seed 0 --bs 112 --samples_per_class 2 --loss softtriplet --arch resnet50_frozen_normalize 17 | 18 | python $main.py --kernels 6 --source $datapath --n_epochs 150 --group CUB_Angular --seed 0 --bs 112 --samples_per_class 2 --loss angular --batch_mining npair --arch resnet50_frozen 19 | 20 | python $main.py --kernels 6 --source $datapath --n_epochs 150 --group CUB_ArcFace --seed 0 --bs 112 --samples_per_class 2 --loss arcface --arch resnet50_frozen_normalize 21 | 22 | python $main.py --kernels 6 --source $datapath --n_epochs 150 --group CUB_Triplet_Random --seed 0 --bs 112 --samples_per_class 2 --loss triplet --batch_mining random --arch resnet50_frozen_normalize 23 | 24 | python $main.py --kernels 6 --source $datapath --n_epochs 150 --group CUB_Triplet_Semihard --seed 0 --bs 112 --samples_per_class 2 --loss triplet --batch_mining semihard --arch resnet50_frozen_normalize 25 | 26 | python $main.py --kernels 6 --source $datapath --n_epochs 150 --group CUB_Triplet_Softhard --seed 0 --bs 112 --samples_per_class 2 --loss triplet --batch_mining softhard --arch resnet50_frozen_normalize 27 | 28 | python $main.py --kernels 6 --source $datapath --n_epochs 150 --group CUB_Triplet_Distance --seed 0 --bs 112 --samples_per_class 2 --loss triplet --batch_mining distance --arch resnet50_frozen_normalize 29 | 30 | python $main.py --kernels 6 --source $datapath --n_epochs 150 --group CUB_Quadruplet_Distance --seed 0 --bs 112 --samples_per_class 2 --loss quadruplet --batch_mining distance --arch resnet50_frozen_normalize 31 | 32 | python $main.py --kernels 6 --source $datapath --n_epochs 150 --group CUB_Margin_b06_Distance --loss_margin_beta 0.6 --seed 0 --bs 112 --samples_per_class 2 --loss margin --batch_mining distance --arch resnet50_frozen_normalize 33 | 34 | python $main.py --kernels 6 --source $datapath --n_epochs 150 --group CUB_Margin_b12_Distance --seed 0 --bs 112 --samples_per_class 2 --loss margin --batch_mining distance --arch resnet50_frozen_normalize 35 | 36 | python $main.py --kernels 6 --source $datapath --n_epochs 150 --group CUB_SNR_Distance --seed 0 --bs 112 --samples_per_class 2 --loss snr --batch_mining distance --arch resnet50_frozen_normalize 37 | 38 | python $main.py --kernels 6 --source $datapath --n_epochs 150 --group CUB_MS --seed 0 --bs 112 --samples_per_class 2 --loss multisimilarity --arch resnet50_frozen_normalize 39 | 40 | python $main.py --kernels 6 --source $datapath --n_epochs 150 --group CUB_Softmax --seed 0 --bs 112 --samples_per_class 2 --loss softmax --batch_mining distance --arch resnet50_frozen_normalize -------------------------------------------------------------------------------- /scripts/baselines/sop_runs.sh: -------------------------------------------------------------------------------- 1 | 2 | # """============= Baseline Runs --- Online Products ====================""" 3 | main=train_baseline 4 | datapath=data 5 | gpu=${1:-0} 6 | 7 | python $main.py --dataset online_products --kernels 6 --source $datapath --n_epochs 100 --group SOP_Npair --seed 0 --bs 112 --samples_per_class 2 --loss npair --batch_mining npair --arch resnet50_frozen 8 | 9 | python $main.py --dataset online_products --kernels 6 --source $datapath --n_epochs 100 --group SOP_GenLifted --seed 0 --bs 112 --samples_per_class 2 --loss lifted --batch_mining lifted --arch resnet50_frozen 10 | 11 | python $main.py --dataset online_products --kernels 6 --source $datapath --n_epochs 100 --group SOP_Histogram --seed 0 --bs 112 --samples_per_class 2 --loss histogram --arch resnet50_frozen_normalize --loss_histogram_nbins 11 12 | 13 | python $main.py --dataset online_products --kernels 6 --source $datapath --n_epochs 100 --group SOP_Contrastive --seed 0 --bs 112 --samples_per_class 2 --loss contrastive --batch_mining distance --arch resnet50_frozen_normalize 14 | 15 | python $main.py --dataset online_products --kernels 6 --source $datapath --n_epochs 100 --group SOP_Angular --seed 0 --bs 112 --samples_per_class 2 --loss angular --batch_mining npair --arch resnet50_frozen 16 | 17 | python $main.py --dataset online_products --kernels 6 --source $datapath --n_epochs 100 --group SOP_ArcFace --seed 0 --bs 112 --samples_per_class 2 --loss arcface --arch resnet50_frozen_normalize 18 | 19 | python $main.py --dataset online_products --kernels 6 --source $datapath --n_epochs 100 --group SOP_Triplet_Random --seed 0 --bs 112 --samples_per_class 2 --loss triplet --batch_mining random --arch resnet50_frozen_normalize 20 | 21 | python $main.py --dataset online_products --kernels 6 --source $datapath --n_epochs 100 --group SOP_Triplet_Semihard --seed 0 --bs 112 --samples_per_class 2 --loss triplet --batch_mining semihard --arch resnet50_frozen_normalize 22 | 23 | python $main.py --dataset online_products --kernels 6 --source $datapath --n_epochs 100 --group SOP_Triplet_Softhard --seed 0 --bs 112 --samples_per_class 2 --loss triplet --batch_mining softhard --arch resnet50_frozen_normalize 24 | 25 | python $main.py --dataset online_products --kernels 6 --source $datapath --n_epochs 100 --group SOP_Triplet_Distance --seed 0 --bs 112 --samples_per_class 2 --loss triplet --batch_mining distance --arch resnet50_frozen_normalize 26 | 27 | python $main.py --dataset online_products --kernels 6 --source $datapath --n_epochs 100 --group SOP_Quadruplet_Distance --seed 0 --bs 112 --samples_per_class 2 --loss quadruplet --batch_mining distance --arch resnet50_frozen_normalize 28 | 29 | python $main.py --dataset online_products --kernels 6 --source $datapath --n_epochs 100 --group SOP_Margin_b06_Distance --loss_margin_beta 0.6 --seed 0 --bs 112 --samples_per_class 2 --loss margin --batch_mining distance --arch resnet50_frozen_normalize 30 | 31 | python $main.py --dataset online_products --kernels 2 --source $datapath --n_epochs 100 --group SOP_Margin_b12_Distance --seed 0 --bs 112 --samples_per_class 2 --loss margin --batch_mining distance --arch resnet50_frozen_normalize 32 | 33 | python $main.py --dataset online_products --kernels 2 --source $datapath --n_epochs 100 --group SOP_SNR_Distance --seed 0 --bs 112 --samples_per_class 2 --loss snr --batch_mining distance --arch resnet50_frozen_normalize 34 | 35 | python $main.py --dataset online_products --kernels 2 --source $datapath --n_epochs 100 --group SOP_MS --seed 0 --gpu 0 --bs 112 --samples_per_class 2 --loss multisimilarity --arch resnet50_frozen_normalize 36 | 37 | python $main.py --dataset online_products --kernels 2 --source $datapath --n_epochs 100 --group SOP_Softmax --seed 0 --bs 112 --samples_per_class 2 --loss softmax --batch_mining distance --arch resnet50_frozen_normalize --loss_softmax_lr 0.002 38 | 39 | python $main.py --dataset online_products --kernels 2 --source $datapath --n_epochs 100 --group SOP_ProxyNCA --seed 0 --bs 112 --samples_per_class 2 --loss proxynca --arch resnet50_frozen_normalize 40 | 41 | python $main.py --dataset online_products --kernels 2 --source $datapath --n_epochs 100 --group SOP_SoftTriple --seed 0 --bs 32 --samples_per_class 2 --loss softtriplet --arch resnet50_frozen_normalize --loss_softtriplet_gamma 10 --loss_softtriplet_lambda 20 --loss_softtriplet_lrmulti 10 --lr 1e-1 -------------------------------------------------------------------------------- /scripts/diml/test_diml.sh: -------------------------------------------------------------------------------- 1 | dataset=${1:-cub200} 2 | embed_dim=${2:-128} 3 | arch=${3:-resnet50_frozen_normalize} 4 | 5 | python test_diml.py --dataset $dataset \ 6 | --seed 0 --bs 16 --data_sampler class_random --samples_per_class 2\ 7 | --arch $arch --group diml_test \ 8 | --embed_dim $embed_dim --evaluate_on_gpu \ 9 | -------------------------------------------------------------------------------- /scripts/diml/train_diml.sh: -------------------------------------------------------------------------------- 1 | dataset=${1:-cub200} 2 | bs=${2:-112} 3 | loss=${3:-margin_diml} 4 | epochs=${4:-150} 5 | seed=${5:-0} 6 | 7 | python train_diml.py --dataset $dataset --loss $loss --batch_mining distance \ 8 | --group ${dataset}_$loss --seed $seed \ 9 | --bs $bs --data_sampler class_random --samples_per_class 2 \ 10 | --arch resnet50_diml_frozen_normalize --n_epochs $epochs \ 11 | --lr 0.00001 --embed_dim 128 --evaluate_on_gpu 12 | -------------------------------------------------------------------------------- /test_diml.py: -------------------------------------------------------------------------------- 1 | """===================================================================================================""" 2 | ################### LIBRARIES ################### 3 | ### Basic Libraries 4 | import comet_ml 5 | import warnings 6 | warnings.filterwarnings("ignore") 7 | 8 | import os, sys, numpy as np, argparse, imp, datetime, pandas as pd, copy 9 | import time, pickle as pkl, random, json, collections 10 | import matplotlib 11 | matplotlib.use('agg') 12 | import matplotlib.pyplot as plt 13 | import torch 14 | 15 | from tqdm import tqdm, trange 16 | from utilities.misc import load_checkpoint 17 | import torch.nn.functional as F 18 | import shutil 19 | 20 | import parameters as par 21 | 22 | 23 | """===================================================================================================""" 24 | ################### INPUT ARGUMENTS ################### 25 | parser = argparse.ArgumentParser() 26 | 27 | parser = par.basic_training_parameters(parser) 28 | parser = par.batch_creation_parameters(parser) 29 | parser = par.batchmining_specific_parameters(parser) 30 | parser = par.loss_specific_parameters(parser) 31 | 32 | ##### Read in parameters 33 | opt = parser.parse_args() 34 | 35 | 36 | ### Load Remaining Libraries that neeed to be loaded after comet_ml 37 | import torch, torch.nn as nn 38 | import torch.multiprocessing 39 | torch.multiprocessing.set_sharing_strategy('file_system') 40 | import architectures as archs 41 | import datasampler as dsamplers 42 | import datasets as datasets 43 | import criteria as criteria 44 | import batchminer as bmine 45 | import evaluation as eval 46 | from utilities import misc 47 | from utilities import logger 48 | 49 | 50 | """===================================================================================================""" 51 | opt.source_path += '/'+opt.dataset 52 | opt.save_path += '/'+opt.dataset 53 | print(opt.save_path) 54 | 55 | #Assert that the construction of the batch makes sense, i.e. the division into class-subclusters. 56 | assert not opt.bs%opt.samples_per_class, 'Batchsize needs to fit number of samples per class for distance sampling and margin/triplet loss!' 57 | 58 | opt.pretrained = not opt.not_pretrained 59 | 60 | 61 | 62 | 63 | """===================================================================================================""" 64 | ################### GPU SETTINGS ########################### 65 | os.environ["CUDA_DEVICE_ORDER"] ="PCI_BUS_ID" 66 | 67 | 68 | 69 | """===================================================================================================""" 70 | #################### SEEDS FOR REPROD. ##################### 71 | torch.backends.cudnn.deterministic=True; np.random.seed(opt.seed); random.seed(opt.seed) 72 | torch.manual_seed(opt.seed); torch.cuda.manual_seed(opt.seed); torch.cuda.manual_seed_all(opt.seed) 73 | 74 | 75 | 76 | """===================================================================================================""" 77 | ##################### NETWORK SETUP ################## 78 | opt.device = torch.device('cuda') 79 | 80 | 81 | """============================================================================""" 82 | #################### DATALOADER SETUPS ################## 83 | dataloaders = {} 84 | datasets = datasets.select(opt.dataset, opt, opt.source_path) 85 | 86 | dataloaders['testing'] = torch.utils.data.DataLoader(datasets['testing'], num_workers=opt.kernels, batch_size=opt.bs, shuffle=False) 87 | opt.n_classes = len(dataloaders['testing'].dataset.avail_classes) 88 | model = archs.select(opt.arch, opt) 89 | _ = model.to(opt.device) 90 | 91 | 92 | """============================================================================""" 93 | ################### Summary #########################3 94 | data_text = 'Dataset:\t {}'.format(opt.dataset.upper()) 95 | setup_text = 'Objective:\t {}'.format(opt.loss.upper()) 96 | arch_text = 'Backbone:\t {} (#weights: {})'.format(opt.arch.upper(), misc.gimme_params(model)) 97 | summary = data_text+'\n'+setup_text+'\n'+arch_text 98 | print(summary) 99 | 100 | """============================================================================""" 101 | ################### SCRIPT MAIN ########################## 102 | print('\n-----\n') 103 | 104 | iter_count = 0 105 | loss_args = {'batch': None, 'labels':None, 'batch_features':None, 'f_embed':None} 106 | 107 | 108 | # prepare path 109 | CUB_LOGS = { 110 | 'Angular': ['CUB_Angular_s0'], 111 | 'Arcface': ['CUB_ArcFace_s0'], 112 | 'Contrasitive': ['CUB_Contrastive_s0'], 113 | 'NPair': ['CUB_Npair_s0'], 114 | 'GenLifted': ['CUB_GenLifted_s0'], 115 | 'ProxyNCA': ['CUB_ProxyNCA_s0'], 116 | 'Histogram': ['CUB_Histogram_s0'], 117 | 'Quadruplet': ['CUB_Quadruplet_Distance_s0'], 118 | 'SNR': ['CUB_SNR_Distance_s0'], 119 | 'Softmax': ['CUB_Softmax_s0'], 120 | 'Triplet_Random': ['CUB_Triplet_Random_s0'], 121 | 'Triplet_Semihard': ['CUB_Triplet_Semihard_s0'], 122 | 'Triplet_Softhard': ['CUB_Triplet_Softhard_s0'], 123 | 'Triplet_Distance': ['CUB_Triplet_Distance_s0'], 124 | 'Margin_b12_64': ['CUB_Margin_b12_Distance_64_s0'], 125 | 'Margin_b12': ['CUB_Margin_b12_Distance_s0_3'], 126 | 'Margin_b12_512': ['CUB_Margin_b12_Distance_512_s0'], 127 | 'Margin_b06': ['CUB_Margin_b06_Distance_s0'], 128 | 'Multisimilarity_64': ['CUB_MS_64_s0'], 129 | 'Multisimilarity': ['CUB_MS_s0'], 130 | 'Multisimilarity_512': ['CUB_MS_512_s0'], 131 | } 132 | 133 | CARS_LOGS = { 134 | 'Angular': ['CARS_Angular_s0'], 135 | 'Arcface': ['CARS_ArcFace_s0'], 136 | 'Contrasitive': ['CARS_Contrastive_s0'], 137 | 'NPair': ['CARS_Npair_s0'], 138 | 'GenLifted': ['CARS_GenLifted_s0'], 139 | 'ProxyNCA': ['CARS_ProxyNCA_s0'], 140 | 'Histogram': ['CARS_Histogram_s0'], 141 | 'Quadruplet': ['CARS_Quadruplet_Distance_s0'], 142 | 'SNR': ['CARS_SNR_Distance_s0'], 143 | 'Softmax': ['CARS_Softmax_s0'], 144 | 'Triplet_Random': ['CARS_Triplet_Random_s0'], 145 | 'Triplet_Semihard': ['CARS_Triplet_Semihard_s0'], 146 | 'Triplet_Softhard': ['CARS_Triplet_Softhard_s0'], 147 | 'Triplet_Distance': ['CARS_Triplet_Distance_s0'], 148 | 'Margin_b12': ['CARS_Margin_b12_Distance_s0'], 149 | 'Margin_b06': ['CARS_Margin_b06_Distance_s0'], 150 | 'Multisimilarity': ['CARS_MS_s0'], 151 | 'Margin_b12_64': ['CARS_Margin_b12_Distance_64_s0_1'], 152 | 'Multisimilarity_64': ['CARS_MS_64_s2_1'], 153 | 'Margin_b12_512': ['CARS_Margin_b12_Distance_512_s0_1'], 154 | 'Multisimilarity_512': ['CARS_MS_512_s2_1'], 155 | } 156 | 157 | SOP_LOGS = { 158 | 'Angular': ['SOP_Angular_s0'], 159 | 'Arcface': ['SOP_ArcFace_s0'], 160 | 'Contrasitive': ['SOP_Contrastive_s0'], 161 | 'NPair': ['SOP_Npair_s0'], 162 | 'GenLifted': ['SOP_GenLifted_s0'], 163 | 'Histogram': ['SOP_Histogram_s0'], 164 | 'Quadruplet': ['SOP_Quadruplet_Distance_s0'], 165 | 'SNR': ['SOP_SNR_Distance_s0'], 166 | 'Softmax': ['SOP_Softmax_s0'], 167 | 'Triplet_Random': ['SOP_Triplet_Random_s0'], 168 | 'Triplet_Semihard': ['SOP_Triplet_Semihard_s0'], 169 | 'Triplet_Softhard': ['SOP_Triplet_Softhard_s0'], 170 | 'Triplet_Distance': ['SOP_Triplet_Distance_s0'], 171 | 'ProxyNCA': ['SOP_ProxyNCA_s0'], 172 | 'Margin_b12': ['SOP_Margin_b12_Distance_s0'], 173 | 'Margin_b06': ['SOP_Margin_b06_Distance_s0'], 174 | 'Multisimilarity': ['SOP_MS_s0'], 175 | } 176 | 177 | if opt.dataset == 'cub200': 178 | LOGS = CUB_LOGS 179 | elif opt.dataset == 'cars196': 180 | LOGS = CARS_LOGS 181 | else: 182 | LOGS = SOP_LOGS 183 | 184 | 185 | from evaluation.eval_diml import evaluate 186 | 187 | results = [] 188 | methods = [] 189 | data = { 190 | k: [] 191 | for k in ['method', 'r1', 'rp', 'mapr'] 192 | } 193 | 194 | trunc_nums = [0, 100] 195 | 196 | for method, info in LOGS.items(): 197 | path = f'Training_Results/{opt.dataset}/{info[0]}/best.pth' 198 | best_metrics = load_checkpoint(model, None, path) 199 | print(best_metrics) 200 | 201 | result = evaluate(model, dataloaders['testing'], True, trunc_nums, use_uniform=False, grid_size=4) 202 | 203 | print(result) 204 | result['method'] = [f'{method} + ours ({trunc})' for trunc in trunc_nums] 205 | for k, v in data.items(): 206 | v.extend(result[k]) 207 | 208 | df = pd.DataFrame(data) 209 | os.makedirs('test_results', exist_ok=True) 210 | df.to_csv(f'test_results/test_diml_{opt.dataset}.csv') 211 | -------------------------------------------------------------------------------- /train_baseline.py: -------------------------------------------------------------------------------- 1 | """===================================================================================================""" 2 | ################### LIBRARIES ################### 3 | ### Basic Libraries 4 | import comet_ml 5 | import warnings 6 | import logging 7 | warnings.filterwarnings("ignore") 8 | 9 | import os, sys, numpy as np, argparse, imp, datetime, pandas as pd, copy 10 | import time, pickle as pkl, random, json, collections 11 | import matplotlib 12 | matplotlib.use('agg') 13 | import matplotlib.pyplot as plt 14 | import torch 15 | 16 | from evaluation.metrics import get_metrics 17 | from tqdm import tqdm, trange 18 | import shutil 19 | import torch.nn as nn 20 | 21 | import parameters as par 22 | 23 | 24 | """===================================================================================================""" 25 | ################### INPUT ARGUMENTS ################### 26 | parser = argparse.ArgumentParser() 27 | 28 | parser = par.basic_training_parameters(parser) 29 | parser = par.batch_creation_parameters(parser) 30 | parser = par.batchmining_specific_parameters(parser) 31 | parser = par.loss_specific_parameters(parser) 32 | 33 | ##### Read in parameters 34 | opt = parser.parse_args() 35 | 36 | 37 | """===================================================================================================""" 38 | opt.savename = opt.group + '_s{}'.format(opt.seed) 39 | 40 | """===================================================================================================""" 41 | ### Load Remaining Libraries that neeed to be loaded after comet_ml 42 | import torch, torch.nn as nn 43 | import torch.multiprocessing 44 | torch.multiprocessing.set_sharing_strategy('file_system') 45 | import architectures as archs 46 | import datasampler as dsamplers 47 | import datasets as datasets 48 | import criteria as criteria 49 | import batchminer as bmine 50 | import evaluation as eval 51 | from utilities import misc 52 | from utilities import logger 53 | 54 | 55 | 56 | """===================================================================================================""" 57 | full_training_start_time = time.time() 58 | 59 | 60 | 61 | """===================================================================================================""" 62 | opt.source_path += '/'+opt.dataset 63 | opt.save_path += '/'+opt.dataset 64 | print(opt.save_path) 65 | 66 | #Assert that the construction of the batch makes sense, i.e. the division into class-subclusters. 67 | assert not opt.bs%opt.samples_per_class, 'Batchsize needs to fit number of samples per class for distance sampling and margin/triplet loss!' 68 | 69 | opt.pretrained = not opt.not_pretrained 70 | 71 | 72 | 73 | 74 | """===================================================================================================""" 75 | ################### GPU SETTINGS ########################### 76 | os.environ["CUDA_DEVICE_ORDER"] ="PCI_BUS_ID" 77 | 78 | 79 | """===================================================================================================""" 80 | #################### SEEDS FOR REPROD. ##################### 81 | torch.backends.cudnn.deterministic=True; np.random.seed(opt.seed); random.seed(opt.seed) 82 | torch.manual_seed(opt.seed); torch.cuda.manual_seed(opt.seed); torch.cuda.manual_seed_all(opt.seed) 83 | 84 | 85 | 86 | """===================================================================================================""" 87 | ##################### NETWORK SETUP ################## 88 | opt.device = torch.device('cuda') 89 | model = archs.select(opt.arch, opt) 90 | 91 | if opt.fc_lr<0: 92 | to_optim = [{'params':model.parameters(),'lr':opt.lr,'weight_decay':opt.decay}] 93 | else: 94 | all_but_fc_params = [x[-1] for x in list(filter(lambda x: 'last_linear' not in x[0], model.named_parameters()))] 95 | fc_params = model.model.last_linear.parameters() 96 | to_optim = [{'params':all_but_fc_params,'lr':opt.lr,'weight_decay':opt.decay}, 97 | {'params':fc_params,'lr':opt.fc_lr,'weight_decay':opt.decay}] 98 | 99 | _ = model.to(opt.device) 100 | 101 | 102 | model = nn.DataParallel(model) 103 | 104 | """============================================================================""" 105 | 106 | dataloaders, train_data_sampler = datasets.build_dataset(opt, model) 107 | opt.n_classes = len(dataloaders['training'].dataset.avail_classes) 108 | 109 | 110 | """============================================================================""" 111 | #################### CREATE LOGGING FILES ############### 112 | sub_loggers = ['Train', 'Test', 'Model Grad'] 113 | if opt.use_tv_split: sub_loggers.append('Val') 114 | LOG = logger.LOGGER(opt, sub_loggers=sub_loggers, start_new=True, log_online=False) 115 | 116 | """============================================================================""" 117 | #################### LOSS SETUP #################### 118 | batchminer = bmine.select(opt.batch_mining, opt) 119 | criterion, to_optim = criteria.select(opt.loss, opt, to_optim, batchminer) 120 | _ = criterion.to(opt.device) 121 | 122 | if 'criterion' in train_data_sampler.name: 123 | train_data_sampler.internal_criterion = criterion 124 | 125 | """============================================================================""" 126 | #################### OPTIM SETUP #################### 127 | if opt.optim == 'adam': 128 | optimizer = torch.optim.Adam(to_optim) 129 | elif opt.optim == 'sgd': 130 | optimizer = torch.optim.SGD(to_optim, momentum=0.9) 131 | else: 132 | raise Exception('Optimizer <{}> not available!'.format(opt.optim)) 133 | scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=opt.tau, gamma=opt.gamma) 134 | 135 | 136 | """============================================================================""" 137 | #################### METRIC COMPUTER #################### 138 | opt.rho_spectrum_embed_dim = opt.embed_dim 139 | 140 | """============================================================================""" 141 | ################### Summary #########################3 142 | data_text = 'Dataset:\t {}'.format(opt.dataset.upper()) 143 | setup_text = 'Objective:\t {}'.format(opt.loss.upper()) 144 | miner_text = 'Batchminer:\t {}'.format(opt.batch_mining if criterion.REQUIRES_BATCHMINER else 'N/A') 145 | arch_text = 'Backbone:\t {} (#weights: {})'.format(opt.arch.upper(), misc.gimme_params(model)) 146 | summary = data_text+'\n'+setup_text+'\n'+miner_text+'\n'+arch_text 147 | print(summary) 148 | 149 | 150 | """============================================================================""" 151 | ################### SCRIPT MAIN ########################## 152 | print('\n-----\n') 153 | 154 | iter_count = 0 155 | loss_args = {'batch':None, 'labels':None, 'batch_features':None, 'f_embed':None} 156 | 157 | 158 | best_r1 = 0 159 | best_rp = 0 160 | best_mapr = 0 161 | 162 | file_handler = logging.FileHandler(filename=os.path.join(opt.save_path, 'log.txt')) 163 | stdout_handler = logging.StreamHandler(sys.stdout) 164 | handlers = [file_handler, stdout_handler] 165 | 166 | logging.basicConfig( 167 | level=logging.INFO, 168 | format='[%(asctime)s] %(levelname)s - %(message)s', 169 | handlers=handlers 170 | ) 171 | 172 | logger = logging.getLogger('root') 173 | 174 | 175 | for epoch in range(opt.n_epochs): 176 | epoch_start_time = time.time() 177 | 178 | if epoch>0 and opt.data_idx_full_prec and train_data_sampler.requires_storage: 179 | train_data_sampler.full_storage_update(dataloaders['evaluation'], model, opt.device) 180 | 181 | opt.epoch = epoch 182 | ### Scheduling Changes specifically for cosine scheduling 183 | if opt.scheduler!='none': print('Running with learning rates {}...'.format(' | '.join('{}'.format(x) for x in scheduler.get_lr()))) 184 | 185 | """=======================================""" 186 | if train_data_sampler.requires_storage: 187 | train_data_sampler.precompute_indices() 188 | 189 | 190 | """=======================================""" 191 | ### Train one epoch 192 | start = time.time() 193 | _ = model.train() 194 | 195 | 196 | loss_collect = [] 197 | data_iterator = tqdm(dataloaders['training'], desc='Epoch {} Training...'.format(epoch)) 198 | logger.info(f"Epoch {epoch} start") 199 | 200 | print(opt.save_path) 201 | 202 | for i,out in enumerate(data_iterator): 203 | class_labels, input, input_indices = out 204 | 205 | ### Compute Embedding 206 | input = input.to(opt.device) 207 | model_args = {'x':input.to(opt.device)} 208 | # Needed for MixManifold settings. 209 | if 'mix' in opt.arch: model_args['labels'] = class_labels 210 | embeds = model(**model_args) 211 | if isinstance(embeds, tuple): embeds, (avg_features, features) = embeds 212 | 213 | ### Compute Loss 214 | loss_args['batch'] = embeds 215 | loss_args['labels'] = class_labels 216 | # loss_args['f_embed'] = model.module.model.last_linear 217 | loss_args['batch_features'] = features 218 | loss = criterion(**loss_args) 219 | 220 | ### 221 | optimizer.zero_grad() 222 | loss.backward() 223 | 224 | ### Compute Model Gradients and log them! 225 | grads = np.concatenate([p.grad.detach().cpu().numpy().flatten() for p in model.parameters() if p.grad is not None]) 226 | grad_l2, grad_max = np.mean(np.sqrt(np.mean(np.square(grads)))), np.mean(np.max(np.abs(grads))) 227 | LOG.progress_saver['Model Grad'].log('Grad L2', grad_l2, group='L2') 228 | LOG.progress_saver['Model Grad'].log('Grad Max', grad_max, group='Max') 229 | 230 | ### Update network weights! 231 | optimizer.step() 232 | 233 | ### 234 | loss_collect.append(loss.item()) 235 | 236 | ### 237 | iter_count += 1 238 | 239 | if i==len(dataloaders['training'])-1: data_iterator.set_description('Epoch (Train) {0}: Mean Loss [{1:.4f}]'.format(epoch, np.mean(loss_collect))) 240 | 241 | """=======================================""" 242 | if train_data_sampler.requires_storage and train_data_sampler.update_storage: 243 | train_data_sampler.replace_storage_entries(embeds.detach().cpu(), input_indices) 244 | 245 | result_metrics = {'loss': np.mean(loss_collect)} 246 | 247 | #### 248 | LOG.progress_saver['Train'].log('epochs', epoch) 249 | for metricname, metricval in result_metrics.items(): 250 | LOG.progress_saver['Train'].log(metricname, metricval) 251 | LOG.progress_saver['Train'].log('time', np.round(time.time()-start, 4)) 252 | 253 | 254 | 255 | """=======================================""" 256 | ### Evaluate Metric for Training & Test (& Validation) 257 | _ = model.eval() 258 | dataloader = dataloaders['testing'] 259 | 260 | with torch.no_grad(): 261 | target_labels = [] 262 | feature_bank = [] 263 | labels = [] 264 | final_iter = tqdm(dataloader, desc='Embedding Data...') 265 | image_paths = [x[0] for x in dataloader.dataset.image_list] 266 | for idx, inp in enumerate(final_iter): 267 | input_img, target = inp[1], inp[0] 268 | target_labels.extend(target.numpy().tolist()) 269 | out = model(input_img.to(opt.device)) 270 | if isinstance(out, tuple): out, aux_f = out 271 | feature_bank.append(out.data) 272 | labels.append(target) 273 | feature_bank = torch.cat(feature_bank, dim=0) 274 | labels = torch.cat(labels, dim=0) 275 | 276 | N, C = feature_bank.size() 277 | feature_bank = torch.nn.functional.normalize(feature_bank, p=2, dim=1) 278 | 279 | overall_r1 = 0.0 280 | overall_rp = 0.0 281 | overall_mapr = 0.0 282 | 283 | for idx in range(len(feature_bank)): 284 | feature = feature_bank[idx] 285 | sim = torch.sum(feature_bank * feature.unsqueeze(0), dim=1) 286 | sim[idx] = -100 287 | r1, rp, mapr = get_metrics(sim.data.cpu(), labels[idx], labels) 288 | overall_r1 += r1 289 | overall_rp += rp 290 | overall_mapr += mapr 291 | 292 | overall_r1 = overall_r1 / float(N) 293 | overall_rp = overall_rp / float(N) 294 | overall_mapr = overall_mapr / float(N) 295 | 296 | is_best = False 297 | if overall_r1 > best_r1: 298 | best_r1 = overall_r1 299 | is_best = True 300 | if overall_rp > best_rp: 301 | best_rp = overall_rp 302 | if overall_mapr > best_mapr: 303 | best_mapr = overall_mapr 304 | 305 | all_metrics = { 306 | 'r1': overall_r1, 307 | 'rp': overall_rp, 308 | 'mapr': overall_mapr, 309 | } 310 | best_metrics = { 311 | 'r1': best_r1, 312 | 'rp': best_rp, 313 | 'mapr': best_mapr, 314 | } 315 | 316 | for k, v in all_metrics.items(): 317 | LOG.progress_saver['Test'].log(k, v) 318 | 319 | logger.info('saving checkpoint...') 320 | misc.save_checkpoint(model, optimizer, os.path.join(opt.save_path, 'latest.pth'), all_metrics, best_metrics, epoch) 321 | if is_best: 322 | logger.info('saving best checkpoint...') 323 | shutil.copy2(os.path.join(opt.save_path, 'latest.pth'), os.path.join(opt.save_path, 'best.pth')) 324 | 325 | 326 | print('###########') 327 | logger.info('Now rank-1 acc=%f, RP=%f, MAP@R=%f' % (overall_r1, overall_rp, overall_mapr)) 328 | logger.info('Best rank-1 acc=%f, RP=%f, MAP@R=%f' % (best_r1, best_rp, best_mapr)) 329 | 330 | LOG.update(all=True) 331 | 332 | 333 | """=======================================""" 334 | ### Learning Rate Scheduling Step 335 | if opt.scheduler != 'none': 336 | scheduler.step() 337 | 338 | print('Total Epoch Runtime: {0:4.2f}s'.format(time.time()-epoch_start_time)) 339 | print('\n-----\n') 340 | 341 | 342 | 343 | 344 | """=======================================================""" 345 | ### CREATE A SUMMARY TEXT FILE 346 | summary_text = '' 347 | full_training_time = time.time()-full_training_start_time 348 | summary_text += 'Training Time: {} min.\n'.format(np.round(full_training_time/60,2)) 349 | 350 | summary_text += '---------------\n' 351 | for sub_logger in LOG.sub_loggers: 352 | metrics = LOG.graph_writer[sub_logger].ov_title 353 | summary_text += '{} metrics: {}\n'.format(sub_logger.upper(), metrics) 354 | 355 | with open(opt.save_path+'/training_summary.txt','w') as summary_file: 356 | summary_file.write(summary_text) 357 | -------------------------------------------------------------------------------- /utilities/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wl-zhao/DIML/c15dbea696f68ddf889dcacfcaacd315d16a34ac/utilities/__init__.py -------------------------------------------------------------------------------- /utilities/diml.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | def Sinkhorn(K, u, v): 5 | r = torch.ones_like(u) 6 | c = torch.ones_like(v) 7 | thresh = 1e-1 8 | for _ in range(100): 9 | r0 = r 10 | r = u / torch.matmul(K, c.unsqueeze(-1)).squeeze(-1) 11 | c = v / torch.matmul(K.permute(0, 2, 1).contiguous(), r.unsqueeze(-1)).squeeze(-1) 12 | err = (r - r0).abs().mean() 13 | if err.item() < thresh: 14 | break 15 | T = torch.matmul(r.unsqueeze(-1), c.unsqueeze(-2)) * K 16 | return T 17 | 18 | def calc_similarity(anchor, anchor_center, fb, fb_center, stage, use_uniform=False): 19 | if stage == 0: 20 | sim = torch.einsum('c,nc->n', anchor_center, fb_center) 21 | else: 22 | N, _, R = fb.size() 23 | 24 | sim = torch.einsum('cm,ncs->nsm', anchor, fb).contiguous().view(N, R, R) 25 | dis = 1.0 - sim 26 | K = torch.exp(-dis / 0.05) 27 | 28 | if use_uniform: 29 | u = torch.zeros(N, R, dtype=sim.dtype, device=sim.device).fill_(1. / R) 30 | v = torch.zeros(N, R, dtype=sim.dtype, device=sim.device).fill_(1. / R) 31 | else: 32 | att = F.relu(torch.einsum("c,ncr->nr", anchor_center, fb)).view(N, R) 33 | u = att / (att.sum(dim=1, keepdims=True) + 1e-5) 34 | 35 | att = F.relu(torch.einsum("cr,nc->nr", anchor, fb_center)).view(N, R) 36 | v = att / (att.sum(dim=1, keepdims=True) + 1e-5) 37 | 38 | T = Sinkhorn(K, u, v) 39 | sim = torch.sum(T * sim, dim=(1, 2)) 40 | return sim 41 | 42 | -------------------------------------------------------------------------------- /utilities/logger.py: -------------------------------------------------------------------------------- 1 | import datetime, csv, os, numpy as np 2 | from matplotlib import pyplot as plt 3 | import pickle as pkl 4 | from utilities.misc import gimme_save_string 5 | 6 | """=============================================================================================================""" 7 | ################## WRITE TO CSV FILE ##################### 8 | class CSV_Writer(): 9 | def __init__(self, save_path): 10 | self.save_path = save_path 11 | self.written = [] 12 | self.n_written_lines = {} 13 | 14 | def log(self, group, segments, content): 15 | if group not in self.n_written_lines.keys(): 16 | self.n_written_lines[group] = 0 17 | 18 | with open(self.save_path+'_'+group+'.csv', "a") as csv_file: 19 | writer = csv.writer(csv_file, delimiter=",") 20 | if group not in self.written: writer.writerow(segments) 21 | for line in content: 22 | writer.writerow(line) 23 | self.n_written_lines[group] += 1 24 | 25 | self.written.append(group) 26 | 27 | 28 | 29 | ################## PLOT SUMMARY IMAGE ##################### 30 | class InfoPlotter(): 31 | def __init__(self, save_path, title='Training Log', figsize=(25,19)): 32 | self.save_path = save_path 33 | self.title = title 34 | self.figsize = figsize 35 | self.colors = ['r','g','b','y','m','c','orange','darkgreen','lightblue'] 36 | 37 | def make_plot(self, base_title, title_append, sub_plots, sub_plots_data): 38 | sub_plots = list(sub_plots) 39 | if 'epochs' not in sub_plots: 40 | x_data = range(len(sub_plots_data[0])) 41 | else: 42 | x_data = range(sub_plots_data[np.where(np.array(sub_plots)=='epochs')[0][0]][-1]+1) 43 | 44 | self.ov_title = [(sub_plot,sub_plot_data) for sub_plot, sub_plot_data in zip(sub_plots,sub_plots_data) if sub_plot not in ['epoch','epochs','time']] 45 | self.ov_title = [(x[0],np.max(x[1])) if 'loss' not in x[0] else (x[0],np.min(x[1])) for x in self.ov_title] 46 | self.ov_title = title_append +': '+ ' | '.join('{0}: {1:.4f}'.format(x[0],x[1]) for x in self.ov_title) 47 | sub_plots_data = [x for x,y in zip(sub_plots_data, sub_plots)] 48 | sub_plots = [x for x in sub_plots] 49 | 50 | plt.style.use('ggplot') 51 | f,ax = plt.subplots(1) 52 | ax.set_title(self.ov_title, fontsize=22) 53 | for i,(data, title) in enumerate(zip(sub_plots_data, sub_plots)): 54 | ax.plot(x_data, data, '-{}'.format(self.colors[i]), linewidth=1.7, label=base_title+' '+title) 55 | ax.tick_params(axis='both', which='major', labelsize=18) 56 | ax.tick_params(axis='both', which='minor', labelsize=18) 57 | ax.legend(loc=2, prop={'size': 16}) 58 | f.set_size_inches(self.figsize[0], self.figsize[1]) 59 | f.savefig(self.save_path+'_'+title_append+'.svg') 60 | plt.close() 61 | 62 | 63 | ################## GENERATE LOGGING FOLDER/FILES ####################### 64 | def set_logging(opt): 65 | checkfolder = opt.save_path+'/'+opt.savename 66 | if opt.savename == '': 67 | date = datetime.datetime.now() 68 | time_string = '{}-{}-{}-{}-{}-{}'.format(date.year, date.month, date.day, date.hour, date.minute, date.second) 69 | checkfolder = opt.save_path+'/{}_{}_'.format(opt.dataset.upper(), opt.arch.upper())+time_string 70 | counter = 1 71 | while os.path.exists(checkfolder): 72 | checkfolder = opt.save_path+'/'+opt.savename+'_'+str(counter) 73 | counter += 1 74 | os.makedirs(checkfolder) 75 | opt.save_path = checkfolder 76 | 77 | if 'experiment' in vars(opt): 78 | import argparse 79 | save_opt = {key:item for key,item in vars(opt).items() if key!='experiment'} 80 | save_opt = argparse.Namespace(**save_opt) 81 | else: 82 | save_opt = opt 83 | 84 | with open(save_opt.save_path+'/Parameter_Info.txt','w') as f: 85 | f.write(gimme_save_string(save_opt)) 86 | pkl.dump(save_opt,open(save_opt.save_path+"/hypa.pkl","wb")) 87 | 88 | 89 | class Progress_Saver(): 90 | def __init__(self): 91 | self.groups = {} 92 | 93 | def log(self, segment, content, group=None): 94 | if group is None: group = segment 95 | if group not in self.groups.keys(): 96 | self.groups[group] = {} 97 | 98 | if segment not in self.groups[group].keys(): 99 | self.groups[group][segment] = {'content':[],'saved_idx':0} 100 | 101 | self.groups[group][segment]['content'].append(content) 102 | 103 | 104 | class LOGGER(): 105 | def __init__(self, opt, sub_loggers=[], prefix=None, start_new=True, log_online=False): 106 | """ 107 | LOGGER Internal Structure: 108 | 109 | self.progress_saver: Contains multiple Progress_Saver instances to log metrics for main metric subsets (e.g. "Train" for training metrics) 110 | ['main_subset_name']: Name of each main subset (-> e.g. "Train") 111 | .groups: Dictionary of subsets belonging to one of the main subsets, e.g. ["Recall", "NMI", ...] 112 | ['specific_metric_name']: Specific name of the metric of interest, e.g. Recall@1. 113 | """ 114 | self.prop = opt 115 | self.prefix = '{}_'.format(prefix) if prefix is not None else '' 116 | self.sub_loggers = sub_loggers 117 | 118 | ### Make Logging Directories 119 | if start_new: set_logging(opt) 120 | 121 | ### Set Graph and CSV writer 122 | self.csv_writer, self.graph_writer, self.progress_saver = {},{},{} 123 | for sub_logger in sub_loggers: 124 | csv_savepath = opt.save_path+'/CSV_Logs' 125 | if not os.path.exists(csv_savepath): os.makedirs(csv_savepath) 126 | self.csv_writer[sub_logger] = CSV_Writer(csv_savepath+'/Data_{}{}'.format(self.prefix, sub_logger)) 127 | 128 | prgs_savepath = opt.save_path+'/Progression_Plots' 129 | if not os.path.exists(prgs_savepath): os.makedirs(prgs_savepath) 130 | self.graph_writer[sub_logger] = InfoPlotter(prgs_savepath+'/Graph_{}{}'.format(self.prefix, sub_logger)) 131 | self.progress_saver[sub_logger] = Progress_Saver() 132 | 133 | 134 | ### WandB Init 135 | self.save_path = opt.save_path 136 | self.log_online = log_online 137 | 138 | 139 | def update(self, *sub_loggers, all=False): 140 | online_content = [] 141 | 142 | if all: sub_loggers = self.sub_loggers 143 | 144 | for sub_logger in list(sub_loggers): 145 | for group in self.progress_saver[sub_logger].groups.keys(): 146 | pgs = self.progress_saver[sub_logger].groups[group] 147 | segments = pgs.keys() 148 | per_seg_saved_idxs = [pgs[segment]['saved_idx'] for segment in segments] 149 | per_seg_contents = [pgs[segment]['content'][idx:] for segment,idx in zip(segments, per_seg_saved_idxs)] 150 | per_seg_contents_all = [pgs[segment]['content'] for segment,idx in zip(segments, per_seg_saved_idxs)] 151 | 152 | #Adjust indexes 153 | for content,segment in zip(per_seg_contents, segments): 154 | self.progress_saver[sub_logger].groups[group][segment]['saved_idx'] += len(content) 155 | 156 | tupled_seg_content = [list(seg_content_slice) for seg_content_slice in zip(*per_seg_contents)] 157 | 158 | self.csv_writer[sub_logger].log(group, segments, tupled_seg_content) 159 | self.graph_writer[sub_logger].make_plot(sub_logger, group, segments, per_seg_contents_all) 160 | 161 | for i,segment in enumerate(segments): 162 | if group == segment: 163 | name = sub_logger+': '+group 164 | else: 165 | name = sub_logger+': '+group+': '+segment 166 | online_content.append((name,per_seg_contents[i])) 167 | 168 | if self.log_online: 169 | if self.prop.online_backend=='wandb': 170 | import wandb 171 | for i,item in enumerate(online_content): 172 | if isinstance(item[1], list): 173 | wandb.log({item[0]:np.mean(item[1])}, step=self.prop.epoch) 174 | else: 175 | wandb.log({item[0]:item[1]}, step=self.prop.epoch) 176 | elif self.prop.online_backend=='comet_ml': 177 | for i,item in enumerate(online_content): 178 | if isinstance(item[1], list): 179 | self.prop.experiment.log_metric(item[0],np.mean(item[1]), self.prop.epoch) 180 | else: 181 | self.prop.experiment.log_metric(item[0],item[1],self.prop.epoch) 182 | -------------------------------------------------------------------------------- /utilities/misc.py: -------------------------------------------------------------------------------- 1 | """=============================================================================================================""" 2 | ######## LIBRARIES ##################### 3 | import numpy as np 4 | 5 | 6 | 7 | """=============================================================================================================""" 8 | ################# ACQUIRE NUMBER OF WEIGHTS ################# 9 | def gimme_params(model): 10 | model_parameters = filter(lambda p: p.requires_grad, model.parameters()) 11 | params = sum([np.prod(p.size()) for p in model_parameters]) 12 | return params 13 | 14 | 15 | ################# SAVE TRAINING PARAMETERS IN NICE STRING ################# 16 | def gimme_save_string(opt): 17 | varx = vars(opt) 18 | base_str = '' 19 | for key in varx: 20 | base_str += str(key) 21 | if isinstance(varx[key],dict): 22 | for sub_key, sub_item in varx[key].items(): 23 | base_str += '\n\t'+str(sub_key)+': '+str(sub_item) 24 | else: 25 | base_str += '\n\t'+str(varx[key]) 26 | base_str+='\n\n' 27 | return base_str 28 | 29 | 30 | ############################################################################# 31 | import torch, torch.nn as nn 32 | 33 | class DataParallel(nn.Module): 34 | def __init__(self, model, device_ids, dim): 35 | super().__init__() 36 | self.model = model.model 37 | self.network = nn.DataParallel(model, device_ids, dim) 38 | 39 | def forward(self, x): 40 | return self.network(x) 41 | 42 | def save_checkpoint(model, optimizer, save_path, metrics, best_metrics, epoch): 43 | print('Save checkpoint to', save_path) 44 | save_dict = { 45 | 'model': model.state_dict(), 46 | 'optimizer': optimizer.state_dict(), 47 | 'metrics': metrics, 48 | 'best_metrics': best_metrics, 49 | 'epoch': epoch, 50 | } 51 | torch.save(save_dict, save_path) 52 | 53 | 54 | def load_checkpoint(model, optimizer, save_path): 55 | print('Load checkpoint from', save_path) 56 | state_dict = torch.load(save_path) 57 | model_state_dict = {} 58 | for k, v in state_dict['model'].items(): 59 | if k.startswith('module'): 60 | model_state_dict[k[7:]] = v 61 | else: 62 | model_state_dict[k] = v 63 | 64 | model.load_state_dict(model_state_dict) 65 | if optimizer is not None: 66 | optimizer.load_state_dict(state_dict['optimizer']) 67 | best_metrics = state_dict['best_metrics'] 68 | epoch = state_dict.get('epoch', 0) 69 | return best_metrics, epoch 70 | --------------------------------------------------------------------------------