├── .gitignore
├── OOD-Study_FewShot.sh
├── OOD-Study_ZeroShot.sh
├── README.md
├── architectures
│   ├── __init__.py
│   ├── bit.py
│   ├── bninception.py
│   ├── efficientb0.py
│   ├── multiembed_bninception.py
│   ├── multifeature_bit.py
│   ├── multifeature_bninception.py
│   ├── multifeature_efficientb0.py
│   ├── multifeature_resnet101.py
│   ├── multifeature_resnet18.py
│   ├── multifeature_resnet50.py
│   ├── resnet101.py
│   ├── resnet18.py
│   ├── resnet50.py
│   └── resnext101.py
├── batchminer
│   ├── __init__.py
│   ├── distance.py
│   ├── diva_shared_distance_an.py
│   ├── diva_shared_distance_apn.py
│   ├── easypositive.py
│   ├── epshn.py
│   ├── intra_random.py
│   ├── lifted.py
│   ├── npair.py
│   ├── parametric.py
│   ├── random.py
│   ├── random_distance.py
│   ├── rho_distance.py
│   ├── semihard.py
│   └── softhard.py
├── create_dataset_splits.py
├── criteria
│   ├── __init__.py
│   ├── adversarial_separation.py
│   ├── angular.py
│   ├── arcface.py
│   ├── base_criterion.py
│   ├── contrastive.py
│   ├── fast_moco.py
│   ├── imrot.py
│   ├── margin.py
│   ├── moco.py
│   ├── multisimilarity.py
│   ├── oproxy.py
│   ├── proxynca.py
│   ├── quadruplet.py
│   ├── s2sd.py
│   ├── shared_margin.py
│   ├── shared_triplet.py
│   └── triplet.py
├── datasampler
│   ├── __init__.py
│   ├── class_random_sampler.py
│   └── random_sampler.py
├── datasets
│   ├── __init__.py
│   ├── basic_dataset_scaffold.py
│   ├── cars196.py
│   ├── cub200.py
│   └── stanford_online_products.py
├── datasplits
│   ├── cars196_splits.pkl
│   ├── cub200_splits.pkl
│   └── online_products_splits.pkl
├── evaluation
│   └── __init__.py
├── fewshot_diva_ood_main.py
├── fewshot_ood_main.py
├── images
│   ├── AUC_Comp.png
│   ├── fewshot.png
│   ├── progression.png
│   ├── progression_comp.png
│   └── umaps.png
├── metrics
│   ├── __init__.py
│   ├── a_recall.py
│   ├── compute_stack.py
│   ├── dists.py
│   ├── e_recall.py
│   ├── f1.py
│   ├── mAP.py
│   ├── mAP_1000.py
│   ├── mAP_c.py
│   ├── nmi.py
│   └── rho_spectrum.py
├── ood_diva_main.py
├── ood_main.py
├── parameters.py
├── split_helpers.py
└── utilities
    ├── __init__.py
    ├── finetune_utils.py
    ├── logger.py
    └── misc.py

/.gitignore:
--------------------------------------------------------------------------------
1 | *.pyc
2 | __pycache__
--------------------------------------------------------------------------------
/architectures/__init__.py:
--------------------------------------------------------------------------------
1 | import architectures.resnet50
2 | import architectures.resnext101
3 | import architectures.bninception
4 | import architectures.multifeature_resnet18
5 | import architectures.multifeature_resnet50
6 | import architectures.multifeature_resnet101
7 | import architectures.multifeature_bninception
8 | import architectures.multifeature_bit
9 | import architectures.multifeature_efficientb0
10 | import architectures.resnet18
11 | import architectures.resnet101
12 | import architectures.bit
13 | import architectures.efficientb0
14 |
15 | def select(arch, opt):
16 |     if 'multifeature_resnet50' in arch:
17 |         return multifeature_resnet50.Network(opt)
18 |     if 'multifeature_resnet18' in arch:
19 |         return multifeature_resnet18.Network(opt)
20 |     if 'multifeature_resnet101' in arch:
21 |         return multifeature_resnet101.Network(opt)
22 |     if 'multifeature_bninception' in arch:
23 |         return multifeature_bninception.Network(opt)
24 |     if 'multifeature_bit' in arch:
25 |         return multifeature_bit.Network(opt)
26 |     if 'multifeature_efficientb0' in arch:
27 |         return multifeature_efficientb0.Network(opt)
28 |     if 'resnet50' in arch:
29 |         return resnet50.Network(opt)
30 |     if 'resnet18' in arch:
31 |         return resnet18.Network(opt)
32 |     if 'resnet101' in arch:
33 |         return resnet101.Network(opt)
34 |     if 'resnext101' in arch:
35 |         return resnext101.Network(opt)
36 |     if 'googlenet' in arch:
37 |         raise NotImplementedError('No googlenet module is included in this repository.')
38 |     if 'bninception' in arch:
39 |         return bninception.Network(opt)
40 |     if 'bit' in arch:
41 |         return bit.Network(opt)
42 |     if 'efficientb0' in arch:
43 |         return efficientb0.Network(opt)
44 |
--------------------------------------------------------------------------------
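Usage sketch for the architecture registry above (illustrative only, not part of the repository): select dispatches on substring matches against opt.arch, so modifiers such as 'frozen', 'normalize' or 'double' can simply be appended to a base architecture name. The SimpleNamespace below merely stands in for the full argparse namespace built in parameters.py; the field values are assumptions.

import torch
from types import SimpleNamespace
import architectures

# Minimal stand-in for the options object from parameters.py (assumed fields only).
opt = SimpleNamespace(arch='resnet50_frozen_normalize', embed_dim=128, not_pretrained=False)

net = architectures.select(opt.arch, opt)   # substring match picks resnet50.Network
out = net(torch.randn(2, 3, 224, 224))      # every Network returns a dict of outputs
print(out['embeds'].shape)                  # torch.Size([2, 128]); 'normalize' in arch -> unit norm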
/architectures/bit.py:
--------------------------------------------------------------------------------
1 | """
2 | The network architectures and weights are adapted and used from the great https://github.com/Cadene/pretrained-models.pytorch.
3 | """
4 | import torch, torch.nn as nn
5 | import timm
6 | import torchvision
7 |
8 |
9 |
10 | """============================================================="""
11 | class Network(torch.nn.Module):
12 |     def __init__(self, opt):
13 |         super(Network, self).__init__()
14 |
15 |         self.pars = opt
16 |
17 |         self.name = opt.arch
18 |
19 |         self.model = timm.create_model('resnetv2_50x1_bitm', pretrained=True)
20 |
21 |         if 'frozen' in opt.arch:
22 |             for module in filter(lambda m: type(m) == timm.models.layers.norm_act.GroupNormAct, self.model.modules()):
23 |                 module.eval()
24 |                 module.train = lambda _: None
25 |
26 |         self.model.last_linear = torch.nn.Linear(self.model.head.fc.in_channels, opt.embed_dim)
27 |
28 |         self.out_adjust = None
29 |         self.extra_out = None
30 |
31 |         self.pool_base = torch.nn.AdaptiveAvgPool2d(1)
32 |         self.pool_aux = torch.nn.AdaptiveMaxPool2d(1) if 'double' in opt.arch else None
33 |
34 |         self.specific_normalization = torchvision.transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
35 |
36 |
37 |     def forward(self, x, warmup=False, **kwargs):
38 |         if warmup:
39 |             with torch.no_grad():
40 |                 x = self.model.forward_features(x)
41 |                 prepool_y = x
42 |                 if self.pool_aux is not None:
43 |                     y = self.pool_aux(x) + self.pool_base(x)
44 |                 else:
45 |                     y = self.pool_base(x)
46 |                 y = y.view(y.size(0),-1)
47 |                 x,y,prepool_y = x.detach(), y.detach(), prepool_y.detach()
48 |         else:
49 |             x = self.model.forward_features(x)
50 |             prepool_y = x
51 |             if self.pool_aux is not None:
52 |                 y = self.pool_aux(x) + self.pool_base(x)
53 |             else:
54 |                 y = self.pool_base(x)
55 |             y = y.view(y.size(0),-1)
56 |
57 |         z = self.model.last_linear(y)
58 |
59 |         if 'normalize' in self.pars.arch:
60 |             z = torch.nn.functional.normalize(z, dim=-1)
61 |
62 |         return {'embeds':z, 'avg_features':y, 'features':x, 'extra_embeds': prepool_y}
63 |
--------------------------------------------------------------------------------
/architectures/bninception.py:
--------------------------------------------------------------------------------
1 | """
2 | The network architectures and weights are adapted and used from the great https://github.com/Cadene/pretrained-models.pytorch.
3 | """ 4 | import torch, torch.nn as nn, torch.nn.functional as F 5 | import pretrainedmodels as ptm 6 | 7 | 8 | 9 | """=============================================================""" 10 | class Network(torch.nn.Module): 11 | def __init__(self, opt, return_embed_dict=False): 12 | super(Network, self).__init__() 13 | 14 | self.pars = opt 15 | self.model = ptm.__dict__['bninception'](num_classes=1000, pretrained='imagenet') 16 | self.model.last_linear = torch.nn.Linear(self.model.last_linear.in_features, opt.embed_dim) 17 | if '_he' in opt.arch: 18 | torch.nn.init.kaiming_normal_(self.model.last_linear.weight, mode='fan_out') 19 | torch.nn.init.constant_(self.model.last_linear.bias, 0) 20 | 21 | if 'frozen' in opt.arch: 22 | for module in filter(lambda m: type(m) == nn.BatchNorm2d, self.model.modules()): 23 | module.eval() 24 | module.train = lambda _: None 25 | 26 | self.return_embed_dict = return_embed_dict 27 | 28 | self.pool_base = F.avg_pool2d 29 | self.pool_aux = F.max_pool2d if 'double' in opt.arch else None 30 | 31 | self.name = opt.arch 32 | 33 | self.out_adjust = None 34 | self.extra_out = None 35 | 36 | 37 | def forward(self, x, warmup=False, **kwargs): 38 | if warmup: 39 | with torch.no_grad(): 40 | x = self.model.features(x) 41 | prepool_y = y = self.pool_base(x,kernel_size=x.shape[-1]) 42 | if self.pool_aux is not None: 43 | y += self.pool_aux(x, kernel_size=x.shape[-1]) 44 | if 'lp2' in self.pars.arch: 45 | y += F.lp_pool2d(x, 2, kernel_size=x.shape[-1]) 46 | if 'lp3' in self.pars.arch: 47 | y += F.lp_pool2d(x, 3, kernel_size=x.shape[-1]) 48 | 49 | y = y.view(len(x),-1) 50 | 51 | x,y,prepool_y = x.detach(), y.detach(), prepool_y.detach() 52 | else: 53 | x = self.model.features(x) 54 | prepool_y = y = self.pool_base(x,kernel_size=x.shape[-1]) 55 | if self.pool_aux is not None: 56 | y += self.pool_aux(x, kernel_size=x.shape[-1]) 57 | if 'lp2' in self.pars.arch: 58 | y += F.lp_pool2d(x, 2, kernel_size=x.shape[-1]) 59 | if 'lp3' in self.pars.arch: 60 | y += F.lp_pool2d(x, 3, kernel_size=x.shape[-1]) 61 | 62 | y = y.view(len(x),-1) 63 | 64 | z = self.model.last_linear(y) 65 | if 'normalize' in self.name: 66 | z = F.normalize(z, dim=-1) 67 | 68 | return {'embeds':z, 'avg_features':y, 'features':x, 'extra_embeds': prepool_y} 69 | -------------------------------------------------------------------------------- /architectures/efficientb0.py: -------------------------------------------------------------------------------- 1 | """ 2 | The network architectures and weights are adapted and used from the great https://github.com/Cadene/pretrained-models.pytorch. 
3 | """ 4 | import torch, torch.nn as nn 5 | import timm 6 | 7 | 8 | 9 | 10 | """=============================================================""" 11 | class Network(torch.nn.Module): 12 | def __init__(self, opt): 13 | super(Network, self).__init__() 14 | 15 | self.pars = opt 16 | self.model = timm.create_model('efficientnet_b0', pretrained=True) 17 | 18 | self.name = opt.arch 19 | 20 | if 'frozen' in opt.arch: 21 | for module in filter(lambda m: type(m) == nn.BatchNorm2d, self.model.modules()): 22 | module.eval() 23 | module.train = lambda _: None 24 | 25 | self.model.last_linear = torch.nn.Linear(self.model.classifier.in_features, opt.embed_dim) 26 | 27 | self.out_adjust = None 28 | self.extra_out = None 29 | 30 | self.pool_base = torch.nn.AdaptiveAvgPool2d(1) 31 | self.pool_aux = torch.nn.AdaptiveMaxPool2d(1) if 'double' in opt.arch else None 32 | 33 | 34 | 35 | def forward(self, x, warmup=False, **kwargs): 36 | if warmup: 37 | with torch.no_grad(): 38 | x = self.model.forward_features(x) 39 | prepool_y = x 40 | if self.pool_aux is not None: 41 | y = self.pool_aux(x) + self.pool_base(x) 42 | else: 43 | y = self.pool_base(x) 44 | y = y.view(y.size(0),-1) 45 | x,y,prepool_y = x.detach(), y.detach(), prepool_y.detach() 46 | else: 47 | x = self.model.forward_features(x) 48 | prepool_y = x 49 | if self.pool_aux is not None: 50 | y = self.pool_aux(x) + self.pool_base(x) 51 | else: 52 | y = self.pool_base(x) 53 | y = y.view(y.size(0),-1) 54 | 55 | z = self.model.last_linear(y) 56 | 57 | if 'normalize' in self.pars.arch: 58 | z = torch.nn.functional.normalize(z, dim=-1) 59 | 60 | return {'embeds':z, 'avg_features':y, 'features':x, 'extra_embeds': prepool_y} 61 | -------------------------------------------------------------------------------- /architectures/multiembed_bninception.py: -------------------------------------------------------------------------------- 1 | """ 2 | The network architectures and weights are adapted and used from the great https://github.com/Cadene/pretrained-models.pytorch. 
3 | """ 4 | import torch, torch.nn as nn 5 | import pretrainedmodels as ptm 6 | 7 | 8 | 9 | 10 | 11 | """=============================================================""" 12 | class Network(torch.nn.Module): 13 | def __init__(self, opt): 14 | super(Network, self).__init__() 15 | 16 | self.pars = opt 17 | self.model = ptm.__dict__['bninception'](num_classes=1000, pretrained='imagenet') 18 | self.model.last_linear = torch.nn.Linear(self.model.last_linear.in_features, opt.embed_dim) 19 | self.name = opt.arch 20 | 21 | if 'frozen' in opt.arch: 22 | for module in filter(lambda m: type(m) == nn.BatchNorm2d, self.model.modules()): 23 | module.eval() 24 | module.train = lambda _: None 25 | 26 | self.feature_dim = self.model.last_linear.in_features 27 | out_dict = nn.ModuleDict() 28 | for mode in opt.diva_features: 29 | out_dict[mode] = torch.nn.Linear(self.feature_dim, opt.embed_dim) 30 | 31 | self.model.last_linear = out_dict 32 | 33 | def forward(self, x): 34 | x = self.model.features(x) 35 | x = nn.functional.avg_pool2d(x, kernel_size=x.shape[2]) 36 | x = x.view(x.size(0), -1) 37 | 38 | out_dict = {} 39 | for key,linear_map in self.model.last_linear.items(): 40 | if not 'normalize' in self.pars.arch: 41 | out_dict[key] = linear_map(x) 42 | else: 43 | out_dict[key] = torch.nn.functional.normalize(linear_map(x), dim=-1) 44 | 45 | return out_dict, x 46 | -------------------------------------------------------------------------------- /architectures/multifeature_bit.py: -------------------------------------------------------------------------------- 1 | """ 2 | The network architectures and weights are adapted and used from the great https://github.com/Cadene/pretrained-models.pytorch. 3 | """ 4 | import torch, torch.nn as nn 5 | import timm 6 | 7 | 8 | 9 | 10 | """=============================================================""" 11 | class Network(torch.nn.Module): 12 | def __init__(self, opt): 13 | super(Network, self).__init__() 14 | 15 | self.pars = opt 16 | self.model = timm.create_model('resnetv2_50x1_bitm', pretrained=True) 17 | 18 | self.name = 'multifeature_'+opt.arch 19 | 20 | if 'frozen' in opt.arch: 21 | for module in filter(lambda m: type(m) == timm.models.layers.norm_act.GroupNormAct, self.model.modules()): 22 | module.eval() 23 | module.train = lambda _: None 24 | 25 | self.model.last_linear = torch.nn.Linear(self.model.head.fc.in_channels, opt.embed_dim) 26 | self.feature_dim = self.model.last_linear.in_features 27 | out_dict = nn.ModuleDict() 28 | for mode in opt.diva_features: 29 | out_dict[mode] = torch.nn.Linear(self.feature_dim, opt.embed_dim) 30 | self.has_merged = False 31 | 32 | self.pool_base = torch.nn.AdaptiveAvgPool2d(1) 33 | self.pool_aux = torch.nn.AdaptiveMaxPool2d(1) if 'double' in opt.arch else None 34 | 35 | self.model.last_linear = out_dict 36 | 37 | def merge_branches(self, weights=None): 38 | if weights is None: 39 | pass 40 | else: 41 | pass 42 | self.has_merged = True 43 | 44 | def forward(self, x, warmup=False, **kwargs): 45 | z_dict = {} 46 | if warmup: 47 | with torch.no_grad(): 48 | x = self.model.forward_features(x) 49 | prepool_y = x 50 | if self.pool_aux is not None: 51 | y = self.pool_aux(x) + self.pool_base(x) 52 | else: 53 | y = self.pool_base(x) 54 | y = y.view(y.size(0),-1) 55 | x,y,prepool_y = x.detach(), y.detach(), prepool_y.detach() 56 | else: 57 | x = self.model.forward_features(x) 58 | prepool_y = x 59 | if self.pool_aux is not None: 60 | y = self.pool_aux(x) + self.pool_base(x) 61 | else: 62 | y = self.pool_base(x) 63 | y = 
y.view(y.size(0),-1) 64 | 65 | for key,embed in self.model.last_linear.items(): 66 | z = embed(y) 67 | if 'normalize' in self.pars.arch: 68 | z = torch.nn.functional.normalize(z, dim=-1) 69 | z_dict[key] = z 70 | 71 | return {'embeds':z_dict, 'avg_features':y, 'features':x, 'extra_embeds': prepool_y} 72 | -------------------------------------------------------------------------------- /architectures/multifeature_bninception.py: -------------------------------------------------------------------------------- 1 | """ 2 | The network architectures and weights are adapted and used from the great https://github.com/Cadene/pretrained-models.pytorch. 3 | """ 4 | import torch, torch.nn as nn 5 | import pretrainedmodels as ptm 6 | 7 | 8 | 9 | 10 | 11 | """=============================================================""" 12 | class Network(torch.nn.Module): 13 | def __init__(self, opt): 14 | super(Network, self).__init__() 15 | 16 | self.pars = opt 17 | self.model = ptm.__dict__['bninception'](num_classes=1000, pretrained='imagenet') 18 | self.model.last_linear = torch.nn.Linear(self.model.last_linear.in_features, opt.embed_dim) 19 | self.name = 'multifeature_'+opt.arch 20 | 21 | if 'frozen' in opt.arch: 22 | for module in filter(lambda m: type(m) == nn.BatchNorm2d, self.model.modules()): 23 | module.eval() 24 | module.train = lambda _: None 25 | 26 | self.feature_dim = self.model.last_linear.in_features 27 | out_dict = nn.ModuleDict() 28 | for mode in opt.diva_features: 29 | out_dict[mode] = torch.nn.Linear(self.feature_dim, opt.embed_dim) 30 | 31 | self.model.last_linear = out_dict 32 | 33 | def forward(self, x, **kwargs): 34 | prepool_y = x = self.model.features(x) 35 | x = nn.functional.avg_pool2d(x, kernel_size=x.shape[2]) 36 | y = x.view(x.size(0), -1) 37 | 38 | z_dict = {} 39 | 40 | for key,embed in self.model.last_linear.items(): 41 | z = embed(y) 42 | if 'normalize' in self.pars.arch: 43 | z = torch.nn.functional.normalize(z, dim=-1) 44 | z_dict[key] = z 45 | 46 | return {'embeds':z_dict, 'avg_features':y, 'features':x, 'extra_embeds': prepool_y} 47 | -------------------------------------------------------------------------------- /architectures/multifeature_efficientb0.py: -------------------------------------------------------------------------------- 1 | """ 2 | The network architectures and weights are adapted and used from the great https://github.com/Cadene/pretrained-models.pytorch. 
3 | """ 4 | import torch, torch.nn as nn 5 | import timm 6 | 7 | 8 | 9 | 10 | """=============================================================""" 11 | class Network(torch.nn.Module): 12 | def __init__(self, opt): 13 | super(Network, self).__init__() 14 | 15 | self.pars = opt 16 | self.model = timm.create_model('efficientnet_b0', pretrained=True) 17 | 18 | self.name = 'multifeature_'+opt.arch 19 | 20 | if 'frozen' in opt.arch: 21 | for module in filter(lambda m: type(m) == nn.BatchNorm2d, self.model.modules()): 22 | module.eval() 23 | module.train = lambda _: None 24 | 25 | self.feature_dim = self.model.classifier.in_features 26 | out_dict = nn.ModuleDict() 27 | for mode in opt.diva_features: 28 | out_dict[mode] = torch.nn.Linear(self.feature_dim, opt.embed_dim) 29 | self.has_merged = False 30 | self.model.last_linear = out_dict 31 | 32 | self.pool_base = torch.nn.AdaptiveAvgPool2d(1) 33 | self.pool_aux = torch.nn.AdaptiveMaxPool2d(1) if 'double' in opt.arch else None 34 | 35 | 36 | def merge_branches(self, weights=None): 37 | if weights is None: 38 | pass 39 | else: 40 | pass 41 | self.has_merged = True 42 | 43 | def forward(self, x, warmup=False, **kwargs): 44 | z_dict = {} 45 | if warmup: 46 | with torch.no_grad(): 47 | x = self.model.forward_features(x) 48 | prepool_y = x 49 | if self.pool_aux is not None: 50 | y = self.pool_aux(x) + self.pool_base(x) 51 | else: 52 | y = self.pool_base(x) 53 | y = y.view(y.size(0),-1) 54 | x,y,prepool_y = x.detach(), y.detach(), prepool_y.detach() 55 | else: 56 | x = self.model.forward_features(x) 57 | prepool_y = x 58 | if self.pool_aux is not None: 59 | y = self.pool_aux(x) + self.pool_base(x) 60 | else: 61 | y = self.pool_base(x) 62 | y = y.view(y.size(0),-1) 63 | 64 | for key,embed in self.model.last_linear.items(): 65 | z = embed(y) 66 | if 'normalize' in self.pars.arch: 67 | z = torch.nn.functional.normalize(z, dim=-1) 68 | z_dict[key] = z 69 | 70 | return {'embeds':z_dict, 'avg_features':y, 'features':x, 'extra_embeds': prepool_y} 71 | -------------------------------------------------------------------------------- /architectures/multifeature_resnet101.py: -------------------------------------------------------------------------------- 1 | """ 2 | The network architectures and weights are adapted and used from the great https://github.com/Cadene/pretrained-models.pytorch. 
3 | """ 4 | import torch, torch.nn as nn 5 | import pretrainedmodels as ptm 6 | from torchvision import models as models 7 | 8 | 9 | 10 | 11 | """=============================================================""" 12 | class Network(torch.nn.Module): 13 | def __init__(self, opt): 14 | super(Network, self).__init__() 15 | 16 | self.pars = opt 17 | self.model = models.resnet101(pretrained=not opt.not_pretrained) 18 | 19 | self.name = 'multifeature_'+opt.arch 20 | 21 | if 'frozen' in opt.arch: 22 | for module in filter(lambda m: type(m) == nn.BatchNorm2d, self.model.modules()): 23 | module.eval() 24 | module.train = lambda _: None 25 | 26 | self.feature_dim = self.model.fc.in_features 27 | out_dict = nn.ModuleDict() 28 | for mode in opt.diva_features: 29 | out_dict[mode] = torch.nn.Linear(self.feature_dim, opt.embed_dim) 30 | 31 | self.model.last_linear = out_dict 32 | 33 | self.layer_blocks = nn.ModuleList([self.model.layer1, self.model.layer2, self.model.layer3, self.model.layer4]) 34 | 35 | 36 | def forward(self, x, **kwargs): 37 | x = self.model.maxpool(self.model.relu(self.model.bn1(self.model.conv1(x)))) 38 | for layerblock in self.layer_blocks: 39 | x = layerblock(x) 40 | prepool_y = x 41 | y = nn.functional.avg_pool2d(x, kernel_size=x.shape[2]) 42 | y = y.view(y.size(0),-1) 43 | 44 | z_dict = {} 45 | 46 | for key,embed in self.model.last_linear.items(): 47 | z = embed(y) 48 | if 'normalize' in self.pars.arch: 49 | z = torch.nn.functional.normalize(z, dim=-1) 50 | z_dict[key] = z 51 | 52 | return {'embeds':z_dict, 'avg_features':y, 'features':x, 'extra_embeds': prepool_y} 53 | -------------------------------------------------------------------------------- /architectures/multifeature_resnet18.py: -------------------------------------------------------------------------------- 1 | """ 2 | The network architectures and weights are adapted and used from the great https://github.com/Cadene/pretrained-models.pytorch. 
3 | """ 4 | import torch, torch.nn as nn 5 | import pretrainedmodels as ptm 6 | from torchvision import models as models 7 | 8 | 9 | 10 | 11 | """=============================================================""" 12 | class Network(torch.nn.Module): 13 | def __init__(self, opt): 14 | super(Network, self).__init__() 15 | 16 | self.pars = opt 17 | self.model = models.resnet18(pretrained=not opt.not_pretrained) 18 | 19 | self.name = 'multifeature_'+opt.arch 20 | 21 | if 'frozen' in opt.arch: 22 | for module in filter(lambda m: type(m) == nn.BatchNorm2d, self.model.modules()): 23 | module.eval() 24 | module.train = lambda _: None 25 | 26 | self.feature_dim = self.model.fc.in_features 27 | out_dict = nn.ModuleDict() 28 | for mode in opt.diva_features: 29 | out_dict[mode] = torch.nn.Linear(self.feature_dim, opt.embed_dim) 30 | 31 | self.model.last_linear = out_dict 32 | 33 | self.layer_blocks = nn.ModuleList([self.model.layer1, self.model.layer2, self.model.layer3, self.model.layer4]) 34 | 35 | 36 | def forward(self, x, **kwargs): 37 | x = self.model.maxpool(self.model.relu(self.model.bn1(self.model.conv1(x)))) 38 | for layerblock in self.layer_blocks: 39 | x = layerblock(x) 40 | prepool_y = x 41 | y = nn.functional.avg_pool2d(x, kernel_size=x.shape[2]) 42 | y = y.view(y.size(0),-1) 43 | 44 | z_dict = {} 45 | 46 | for key,embed in self.model.last_linear.items(): 47 | z = embed(y) 48 | if 'normalize' in self.pars.arch: 49 | z = torch.nn.functional.normalize(z, dim=-1) 50 | z_dict[key] = z 51 | 52 | return {'embeds':z_dict, 'avg_features':y, 'features':x, 'extra_embeds': prepool_y} 53 | -------------------------------------------------------------------------------- /architectures/multifeature_resnet50.py: -------------------------------------------------------------------------------- 1 | """ 2 | The network architectures and weights are adapted and used from the great https://github.com/Cadene/pretrained-models.pytorch. 
3 | """ 4 | import torch, torch.nn as nn 5 | import pretrainedmodels as ptm 6 | 7 | 8 | 9 | 10 | 11 | """=============================================================""" 12 | class Network(torch.nn.Module): 13 | def __init__(self, opt): 14 | super(Network, self).__init__() 15 | 16 | self.pars = opt 17 | self.model = ptm.__dict__['resnet50'](num_classes=1000, pretrained='imagenet') 18 | 19 | self.name = 'multifeature_'+opt.arch 20 | 21 | if 'frozen' in opt.arch: 22 | for module in filter(lambda m: type(m) == nn.BatchNorm2d, self.model.modules()): 23 | module.eval() 24 | module.train = lambda _: None 25 | 26 | self.feature_dim = self.model.last_linear.in_features 27 | out_dict = nn.ModuleDict() 28 | for mode in opt.diva_features: 29 | out_dict[mode] = torch.nn.Linear(self.feature_dim, opt.embed_dim) 30 | self.has_merged = False 31 | 32 | self.model.last_linear = out_dict 33 | 34 | self.layer_blocks = nn.ModuleList([self.model.layer1, self.model.layer2, self.model.layer3, self.model.layer4]) 35 | 36 | def merge_branches(self, weights=None): 37 | from IPython import embed; embed() 38 | if weights is None: 39 | pass 40 | else: 41 | pass 42 | self.has_merged = True 43 | 44 | def forward(self, x, warmup=False, **kwargs): 45 | z_dict = {} 46 | if warmup: 47 | with torch.no_grad(): 48 | x = self.model.maxpool(self.model.relu(self.model.bn1(self.model.conv1(x)))) 49 | for layerblock in self.layer_blocks: 50 | x = layerblock(x) 51 | prepool_y = x 52 | y = nn.functional.avg_pool2d(x, kernel_size=x.shape[2]) 53 | y = y.view(y.size(0),-1) 54 | else: 55 | x = self.model.maxpool(self.model.relu(self.model.bn1(self.model.conv1(x)))) 56 | for layerblock in self.layer_blocks: 57 | x = layerblock(x) 58 | prepool_y = x 59 | y = nn.functional.avg_pool2d(x, kernel_size=x.shape[2]) 60 | y = y.view(y.size(0),-1) 61 | 62 | for key,embed in self.model.last_linear.items(): 63 | z = embed(y) 64 | if 'normalize' in self.pars.arch: 65 | z = torch.nn.functional.normalize(z, dim=-1) 66 | z_dict[key] = z 67 | 68 | return {'embeds':z_dict, 'avg_features':y, 'features':x, 'extra_embeds': prepool_y} 69 | -------------------------------------------------------------------------------- /architectures/resnet101.py: -------------------------------------------------------------------------------- 1 | """ 2 | The network architectures and weights are adapted and used from the great https://github.com/Cadene/pretrained-models.pytorch. 
3 | """ 4 | import torch, torch.nn as nn 5 | import pretrainedmodels as ptm 6 | from torchvision import models as models 7 | 8 | 9 | 10 | 11 | """=============================================================""" 12 | class Network(torch.nn.Module): 13 | def __init__(self, opt): 14 | super(Network, self).__init__() 15 | 16 | self.pars = opt 17 | self.model = models.resnet101(pretrained=not opt.not_pretrained) 18 | 19 | self.name = opt.arch 20 | 21 | if 'frozen' in opt.arch: 22 | for module in filter(lambda m: type(m) == nn.BatchNorm2d, self.model.modules()): 23 | module.eval() 24 | module.train = lambda _: None 25 | 26 | self.model.last_linear = torch.nn.Linear(self.model.fc.in_features, opt.embed_dim) 27 | 28 | self.layer_blocks = nn.ModuleList([self.model.layer1, self.model.layer2, self.model.layer3, self.model.layer4]) 29 | 30 | self.out_adjust = None 31 | self.extra_out = None 32 | 33 | self.pool_base = torch.nn.AdaptiveAvgPool2d(1) 34 | self.pool_aux = torch.nn.AdaptiveMaxPool2d(1) if 'double' in opt.arch else None 35 | 36 | 37 | 38 | def forward(self, x, warmup=False, **kwargs): 39 | if warmup: 40 | with torch.no_grad(): 41 | x = self.model.maxpool(self.model.relu(self.model.bn1(self.model.conv1(x)))) 42 | for layerblock in self.layer_blocks: 43 | x = layerblock(x) 44 | prepool_y = x 45 | if self.pool_aux is not None: 46 | y = self.pool_aux(x) + self.pool_base(x) 47 | else: 48 | y = self.pool_base(x) 49 | y = y.view(y.size(0),-1) 50 | 51 | x,y,prepool_y = x.detach(), y.detach(), prepool_y.detach() 52 | else: 53 | x = self.model.maxpool(self.model.relu(self.model.bn1(self.model.conv1(x)))) 54 | for layerblock in self.layer_blocks: 55 | x = layerblock(x) 56 | prepool_y = x 57 | if self.pool_aux is not None: 58 | y = self.pool_aux(x) + self.pool_base(x) 59 | else: 60 | y = self.pool_base(x) 61 | y = y.view(y.size(0),-1) 62 | 63 | z = self.model.last_linear(y) 64 | 65 | if 'normalize' in self.pars.arch: 66 | z = torch.nn.functional.normalize(z, dim=-1) 67 | 68 | return {'embeds':z, 'avg_features':y, 'features':x, 'extra_embeds': prepool_y} 69 | -------------------------------------------------------------------------------- /architectures/resnet18.py: -------------------------------------------------------------------------------- 1 | """ 2 | The network architectures and weights are adapted and used from the great https://github.com/Cadene/pretrained-models.pytorch. 
3 | """ 4 | import torch, torch.nn as nn 5 | import pretrainedmodels as ptm 6 | from torchvision import models as models 7 | 8 | 9 | 10 | 11 | """=============================================================""" 12 | class Network(torch.nn.Module): 13 | def __init__(self, opt): 14 | super(Network, self).__init__() 15 | 16 | self.pars = opt 17 | self.model = models.resnet18(pretrained=not opt.not_pretrained) 18 | 19 | self.name = opt.arch 20 | 21 | if 'frozen' in opt.arch: 22 | for module in filter(lambda m: type(m) == nn.BatchNorm2d, self.model.modules()): 23 | module.eval() 24 | module.train = lambda _: None 25 | 26 | self.model.last_linear = torch.nn.Linear(self.model.fc.in_features, opt.embed_dim) 27 | 28 | self.layer_blocks = nn.ModuleList([self.model.layer1, self.model.layer2, self.model.layer3, self.model.layer4]) 29 | 30 | self.out_adjust = None 31 | self.extra_out = None 32 | 33 | self.pool_base = torch.nn.AdaptiveAvgPool2d(1) 34 | self.pool_aux = torch.nn.AdaptiveMaxPool2d(1) if 'double' in opt.arch else None 35 | 36 | 37 | 38 | def forward(self, x, warmup=False, **kwargs): 39 | if warmup: 40 | with torch.no_grad(): 41 | x = self.model.maxpool(self.model.relu(self.model.bn1(self.model.conv1(x)))) 42 | for layerblock in self.layer_blocks: 43 | x = layerblock(x) 44 | prepool_y = x 45 | if self.pool_aux is not None: 46 | y = self.pool_aux(x) + self.pool_base(x) 47 | else: 48 | y = self.pool_base(x) 49 | y = y.view(y.size(0),-1) 50 | 51 | x,y,prepool_y = x.detach(), y.detach(), prepool_y.detach() 52 | else: 53 | x = self.model.maxpool(self.model.relu(self.model.bn1(self.model.conv1(x)))) 54 | for layerblock in self.layer_blocks: 55 | x = layerblock(x) 56 | prepool_y = x 57 | if self.pool_aux is not None: 58 | y = self.pool_aux(x) + self.pool_base(x) 59 | else: 60 | y = self.pool_base(x) 61 | y = y.view(y.size(0),-1) 62 | 63 | z = self.model.last_linear(y) 64 | 65 | if 'normalize' in self.pars.arch: 66 | z = torch.nn.functional.normalize(z, dim=-1) 67 | 68 | return {'embeds':z, 'avg_features':y, 'features':x, 'extra_embeds': prepool_y} 69 | -------------------------------------------------------------------------------- /architectures/resnet50.py: -------------------------------------------------------------------------------- 1 | """ 2 | The network architectures and weights are adapted and used from the great https://github.com/Cadene/pretrained-models.pytorch. 
3 | """ 4 | import torch, torch.nn as nn 5 | import pretrainedmodels as ptm 6 | 7 | 8 | 9 | 10 | 11 | """=============================================================""" 12 | class Network(torch.nn.Module): 13 | def __init__(self, opt): 14 | super(Network, self).__init__() 15 | 16 | self.pars = opt 17 | self.model = ptm.__dict__['resnet50'](num_classes=1000, pretrained='imagenet' if not opt.not_pretrained else None) 18 | 19 | self.name = opt.arch 20 | 21 | if 'frozen' in opt.arch: 22 | for module in filter(lambda m: type(m) == nn.BatchNorm2d, self.model.modules()): 23 | module.eval() 24 | module.train = lambda _: None 25 | 26 | self.model.last_linear = torch.nn.Linear(self.model.last_linear.in_features, opt.embed_dim) 27 | 28 | self.layer_blocks = nn.ModuleList([self.model.layer1, self.model.layer2, self.model.layer3, self.model.layer4]) 29 | 30 | self.out_adjust = None 31 | self.extra_out = None 32 | 33 | self.pool_base = torch.nn.AdaptiveAvgPool2d(1) 34 | self.pool_aux = torch.nn.AdaptiveMaxPool2d(1) if 'double' in opt.arch else None 35 | 36 | 37 | 38 | def forward(self, x, warmup=False, **kwargs): 39 | if warmup: 40 | with torch.no_grad(): 41 | x = self.model.maxpool(self.model.relu(self.model.bn1(self.model.conv1(x)))) 42 | for layerblock in self.layer_blocks: 43 | x = layerblock(x) 44 | prepool_y = x 45 | if self.pool_aux is not None: 46 | y = self.pool_aux(x) + self.pool_base(x) 47 | else: 48 | y = self.pool_base(x) 49 | y = y.view(y.size(0),-1) 50 | 51 | x,y,prepool_y = x.detach(), y.detach(), prepool_y.detach() 52 | else: 53 | x = self.model.maxpool(self.model.relu(self.model.bn1(self.model.conv1(x)))) 54 | for layerblock in self.layer_blocks: 55 | x = layerblock(x) 56 | prepool_y = x 57 | if self.pool_aux is not None: 58 | y = self.pool_aux(x) + self.pool_base(x) 59 | else: 60 | y = self.pool_base(x) 61 | y = y.view(y.size(0),-1) 62 | 63 | z = self.model.last_linear(y) 64 | 65 | if 'normalize' in self.pars.arch: 66 | z = torch.nn.functional.normalize(z, dim=-1) 67 | 68 | return {'embeds':z, 'avg_features':y, 'features':x, 'extra_embeds': prepool_y} 69 | -------------------------------------------------------------------------------- /architectures/resnext101.py: -------------------------------------------------------------------------------- 1 | """ 2 | The network architectures and weights are adapted and used from the great https://github.com/Cadene/pretrained-models.pytorch. 
3 | """ 4 | import torch, torch.nn as nn 5 | import torchvision.models as models 6 | 7 | 8 | """=============================================================""" 9 | class Network(torch.nn.Module): 10 | def __init__(self, opt): 11 | super(Network, self).__init__() 12 | 13 | self.pars = opt 14 | self.model = models.resnext101_32x8d(pretrained=not opt.not_pretrained) 15 | 16 | self.name = opt.arch 17 | 18 | if 'frozen' in opt.arch: 19 | for module in filter(lambda m: type(m) == nn.BatchNorm2d, self.model.modules()): 20 | module.eval() 21 | module.train = lambda _: None 22 | 23 | self.model.fc = torch.nn.Linear(self.model.fc.in_features, opt.embed_dim) 24 | 25 | from IPython import embed; embed() 26 | self.layer_blocks = nn.ModuleList([self.model.layer1, self.model.layer2, self.model.layer3, self.model.layer4]) 27 | 28 | self.out_adjust = None 29 | self.extra_out = None 30 | 31 | self.pool_base = torch.nn.AdaptiveAvgPool2d(1) 32 | self.pool_aux = torch.nn.AdaptiveMaxPool2d(1) if 'double' in opt.arch else None 33 | 34 | 35 | 36 | def forward(self, x, warmup=False, **kwargs): 37 | x = self.model.maxpool(self.model.relu(self.model.bn1(self.model.conv1(x)))) 38 | for layerblock in self.layer_blocks: 39 | x = layerblock(x) 40 | prepool_y = x 41 | if self.pool_aux is not None: 42 | y = self.pool_aux(x) + self.pool_base(x) 43 | else: 44 | y = self.pool_base(x) 45 | y = y.view(y.size(0),-1) 46 | 47 | if warmup: 48 | x,y,prepool_y = x.detach(), y.detach(), prepool_y.detach() 49 | 50 | z = self.model.last_linear(y) 51 | 52 | if 'normalize' in self.pars.arch: 53 | z = torch.nn.functional.normalize(z, dim=-1) 54 | 55 | return {'embeds':z, 'avg_features':y, 'features':x, 'extra_embeds': prepool_y} 56 | -------------------------------------------------------------------------------- /batchminer/__init__.py: -------------------------------------------------------------------------------- 1 | from batchminer import random_distance, diva_shared_distance_apn, diva_shared_distance_an, intra_random 2 | from batchminer import lifted, rho_distance, softhard, npair, parametric, random, semihard, distance 3 | from batchminer import epshn 4 | 5 | BATCHMINING_METHODS = {'random':random, 6 | 'semihard':semihard, 7 | 'softhard':softhard, 8 | 'distance':distance, 9 | 'rho_distance':rho_distance, 10 | 'npair':npair, 11 | 'parametric':parametric, 12 | 'lifted':lifted, 13 | 'random_distance': random_distance, 14 | 'intra_random': intra_random, 15 | 'epshn': epshn, 16 | 'shared_full_distance': diva_shared_distance_apn, 17 | 'shared_neg_distance': diva_shared_distance_an} 18 | 19 | 20 | def select(batchminername, opt): 21 | if batchminername not in BATCHMINING_METHODS: raise NotImplementedError('Batchmining {} not available!'.format(batchminername)) 22 | 23 | batchmine_lib = BATCHMINING_METHODS[batchminername] 24 | 25 | return batchmine_lib.BatchMiner(opt) 26 | -------------------------------------------------------------------------------- /batchminer/distance.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch, torch.nn as nn, torch.nn.functional as F 3 | import batchminer 4 | 5 | 6 | class BatchMiner(): 7 | def __init__(self, opt): 8 | self.par = opt 9 | self.lower_cutoff = opt.miner_distance_lower_cutoff 10 | self.upper_cutoff = opt.miner_distance_upper_cutoff 11 | self.name = 'distance' 12 | 13 | def __call__(self, batch, labels, tar_labels=None, return_distances=False, distances=None): 14 | if isinstance(labels, torch.Tensor): labels = 
/batchminer/distance.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch, torch.nn as nn, torch.nn.functional as F
3 | import batchminer
4 |
5 |
6 | class BatchMiner():
7 |     def __init__(self, opt):
8 |         self.par = opt
9 |         self.lower_cutoff = opt.miner_distance_lower_cutoff
10 |         self.upper_cutoff = opt.miner_distance_upper_cutoff
11 |         self.name = 'distance'
12 |
13 |     def __call__(self, batch, labels, tar_labels=None, return_distances=False, distances=None):
14 |         if isinstance(labels, torch.Tensor): labels = labels.detach().cpu().numpy()
15 |         bs, dim = batch.shape
16 |
17 |         if distances is None:
18 |             distances = self.pdist(batch.detach()).clamp(min=self.lower_cutoff)
19 |         sel_d = distances.shape[-1]
20 |
21 |         positives, negatives = [],[]
22 |         labels_visited = []
23 |         anchors = []
24 |
25 |         tar_labels = labels if tar_labels is None else tar_labels
26 |         #
27 |         # neg_all = labels.reshape(-1, 1) != tar_labels.reshape(1, -1)
28 |         # pos_all = labels.reshape(-1, 1) == tar_labels.reshape(1, -1)
29 |         #
30 |         # log_q_d_inv = ((2.0 - float(dim)) * torch.log(distances) - (float(dim-3) / 2) * torch.log(1.0 - 0.25 * (distances.pow(2))))
31 |         # log_q_d_inv[pos_all] = 0
32 |         # cop_q_d_inv = torch.exp(log_q_d_inv - log_q_d_inv.max(dim=1).values.reshape(-1, 1)) # - max(log) for stability
33 |         # cop_q_d_inv[pos_all] = 0
34 |         # cop_q_d_inv = cop_q_d_inv/cop_q_d_inv.sum(dim=1).reshape(-1, 1)
35 |         # cop_q_d_inv = cop_q_d_inv.detach().cpu().numpy()
36 |
37 |         for i in range(bs):
38 |             neg = tar_labels!=labels[i]; pos = tar_labels==labels[i]
39 |
40 |             anchors.append(i)
41 |             q_d_inv = self.inverse_sphere_distances(dim, bs, distances[i], tar_labels, labels[i])
42 |             negatives.append(np.random.choice(sel_d,p=q_d_inv))
43 |
44 |             if np.sum(pos)>0:
45 |                 #Sample positives randomly
46 |                 if np.sum(pos)>1: pos[i] = 0
47 |                 positives.append(np.random.choice(np.where(pos)[0]))
48 |             #Sample negatives by distance
49 |
50 |         sampled_triplets = [[a,p,n] for a,p,n in zip(anchors, positives, negatives)]
51 |
52 |         if return_distances:
53 |             return sampled_triplets, distances
54 |         else:
55 |             return sampled_triplets
56 |
57 |     # def __call__(self, batch, labels, tar_labels=None, return_distances=False, distances=None):
58 |     #     # if isinstance(labels, torch.Tensor): labels = labels.detach().cpu().numpy()
59 |     #     bs, dim = batch.shape
60 |     #
61 |     #     import time
62 |     #     start = time.time()
63 |     #     if distances is None:
64 |     #         distances = self.pdist(batch.detach()).clamp(min=self.lower_cutoff)
65 |     #     sel_d = distances.shape[-1]
66 |     #     print('A', time.time() - start)
67 |     #     start = time.time()
68 |     #     positives, negatives = [],[]
69 |     #     labels_visited = []
70 |     #     anchors = []
71 |     #
72 |     #     tar_labels = labels if tar_labels is None else tar_labels
73 |     #
74 |     #     pos_all = labels.view(-1, 1) == tar_labels.view(1, -1)
75 |     #     pos_all_mul = ~pos_all
76 |     #
77 |     #     log_q_d_inv = ((2.0 - float(dim)) * torch.log(distances) - (float(dim-3) / 2) * torch.log(1.0 - 0.25 * (distances.pow(2))))
78 |     #
79 |     #     print('A1', time.time() - start)
80 |     #     start = time.time()
81 |     #     log_q_d_inv = pos_all_mul * log_q_d_inv
82 |     #     print('A2', time.time() - start)
83 |     #     start = time.time()
84 |     #     q_d_inv = torch.exp(log_q_d_inv - log_q_d_inv.max(dim=1).values.reshape(-1, 1)) # - max(log) for stability
85 |     #     print('A3', time.time() - start)
86 |     #     start = time.time()
87 |     #     q_d_inv[pos_all] = 0
88 |     #     print('A4', time.time() - start)
89 |     #     start = time.time()
90 |     #     q_d_inv = q_d_inv/q_d_inv.sum(dim=1).reshape(-1, 1)
91 |     #     q_d_inv = q_d_inv.detach().cpu().numpy()
92 |     #     print('A5', time.time() - start)
93 |     #     start = time.time()
94 |     #
95 |     #     pos_all = pos_all.detach().cpu().numpy()
96 |     #     for i in range(bs):
97 |     #         pos = pos_all[i]
98 |     #         anchors.append(i)
99 |     #         negatives.append(np.random.choice(sel_d, p=q_d_inv[i]))
100 |     #         if np.sum(pos)>0:
101 |     #             #Sample positives randomly
102 |     #             if np.sum(pos)>1: pos[i] = 0
103 |     #             positives.append(np.random.choice(np.where(pos)[0]))
104 |     #         #Sample negatives by distance
105 |     #
106 |     #     sampled_triplets = [[a,p,n] for a,p,n in zip(anchors, 
positives, negatives)] 107 | # 108 | # print('B', time.time() - start) 109 | # 110 | # if return_distances: 111 | # return sampled_triplets, distances 112 | # else: 113 | # return sampled_triplets 114 | 115 | 116 | def inverse_sphere_distances(self, dim, bs, anchor_to_all_dists, labels, anchor_label): 117 | dists = anchor_to_all_dists 118 | 119 | #negated log-distribution of distances of unit sphere in dimension 120 | log_q_d_inv = ((2.0 - float(dim)) * torch.log(dists) - (float(dim-3) / 2) * torch.log(1.0 - 0.25 * (dists.pow(2)))) 121 | log_q_d_inv[np.where(labels==anchor_label)[0]] = 0 122 | 123 | q_d_inv = torch.exp(log_q_d_inv - torch.max(log_q_d_inv)) # - max(log) for stability 124 | q_d_inv[np.where(labels==anchor_label)[0]] = 0 125 | 126 | ### NOTE: Cutting of values with high distances made the results slightly worse. It can also lead to 127 | # errors where there are no available negatives (for high samples_per_class cases). 128 | # q_d_inv[np.where(dists.detach().cpu().numpy()>self.upper_cutoff)[0]] = 0 129 | 130 | q_d_inv = q_d_inv/q_d_inv.sum() 131 | return q_d_inv.detach().cpu().numpy() 132 | 133 | 134 | def pdist(self, A): 135 | prod = torch.mm(A, A.t()) 136 | norm = prod.diag().unsqueeze(1).expand_as(prod) 137 | res = (norm + norm.t() - 2 * prod).clamp(min = 0) 138 | return res.sqrt() 139 | -------------------------------------------------------------------------------- /batchminer/diva_shared_distance_an.py: -------------------------------------------------------------------------------- 1 | import numpy as np, torch 2 | 3 | 4 | class BatchMiner(): 5 | def __init__(self, opt): 6 | self.par = opt 7 | self.lower_cutoff = opt.miner_distance_lower_cutoff 8 | self.upper_cutoff = opt.miner_distance_upper_cutoff 9 | self.name = 'distance' 10 | 11 | def __call__(self, batch, labels): 12 | if isinstance(labels, torch.Tensor): labels = labels.detach().cpu().numpy() 13 | bs = batch.shape[0] 14 | distances = self.pdist(batch.detach()).clamp(min=self.lower_cutoff) 15 | 16 | positives, negatives = [],[] 17 | labels_visited = [] 18 | anchors = [] 19 | 20 | for i in range(bs): 21 | anchors.append(i) 22 | 23 | neg = labels!=labels[i] 24 | q_d_inv = self.inverse_sphere_distances(batch, distances[i], neg, labels[i]) 25 | neg_idx = np.random.choice(bs,p=q_d_inv) 26 | negatives.append(neg_idx) 27 | 28 | pos = np.logical_and(neg, labels!=labels[neg_idx]) 29 | pos_idx = np.random.choice(np.arange(bs)[pos]) 30 | positives.append(pos_idx) 31 | 32 | sampled_triplets = [[a,p,n] for a,p,n in zip(anchors, positives, negatives)] 33 | return sampled_triplets 34 | 35 | 36 | def inverse_sphere_distances(self, batch, anchor_to_all_dists, labels, anchor_label): 37 | dists = anchor_to_all_dists 38 | bs,dim = len(dists),batch.shape[-1] 39 | 40 | #negated log-distribution of distances of unit sphere in dimension 41 | log_q_d_inv = ((2.0 - float(dim)) * torch.log(dists) - (float(dim-3) / 2) * torch.log(1.0 - 0.25 * (dists.pow(2)))) 42 | log_q_d_inv[np.where(1-labels)[0]] = 0 43 | 44 | q_d_inv = torch.exp(log_q_d_inv - torch.max(log_q_d_inv)) # - max(log) for stability 45 | q_d_inv[np.where(1-labels)[0]] = 0 46 | 47 | ### NOTE: Cutting of values with high distances made the results slightly worse. It can also lead to 48 | # errors where there are no available negatives (for high samples_per_class cases). 
49 | # q_d_inv[np.where(dists.detach().cpu().numpy()>self.upper_cutoff)[0]] = 0 50 | 51 | q_d_inv = q_d_inv/q_d_inv.sum() 52 | return q_d_inv.detach().cpu().numpy() 53 | 54 | 55 | def pdist(self, A): 56 | prod = torch.mm(A, A.t()) 57 | norm = prod.diag().unsqueeze(1).expand_as(prod) 58 | res = (norm + norm.t() - 2 * prod).clamp(min = 0) 59 | return res.sqrt() 60 | -------------------------------------------------------------------------------- /batchminer/diva_shared_distance_apn.py: -------------------------------------------------------------------------------- 1 | import numpy as np, torch 2 | 3 | 4 | class BatchMiner(): 5 | def __init__(self, opt): 6 | self.par = opt 7 | self.lower_cutoff = opt.miner_distance_lower_cutoff 8 | self.upper_cutoff = opt.miner_distance_upper_cutoff 9 | self.name = 'distance' 10 | 11 | def __call__(self, batch, labels): 12 | if isinstance(labels, torch.Tensor): labels = labels.detach().cpu().numpy() 13 | bs = batch.shape[0] 14 | distances = self.pdist(batch.detach()).clamp(min=self.lower_cutoff) 15 | 16 | positives, negatives = [],[] 17 | labels_visited = [] 18 | anchors = [] 19 | 20 | for i in range(bs): 21 | anchors.append(i) 22 | 23 | neg = labels!=labels[i] 24 | q_d_inv = self.inverse_sphere_distances(batch, distances[i], neg, labels[i]) 25 | neg_idx = np.random.choice(bs,p=q_d_inv) 26 | negatives.append(neg_idx) 27 | 28 | pos = np.logical_and(neg, labels!=labels[neg_idx]) 29 | q_d_inv = self.inverse_sphere_distances(batch, distances[i], pos, labels[i]) 30 | pos_idx = np.random.choice(bs,p=q_d_inv) 31 | positives.append(pos_idx) 32 | 33 | sampled_triplets = [[a,p,n] for a,p,n in zip(anchors, positives, negatives)] 34 | return sampled_triplets 35 | 36 | 37 | def inverse_sphere_distances(self, batch, anchor_to_all_dists, labels, anchor_label): 38 | dists = anchor_to_all_dists 39 | bs,dim = len(dists),batch.shape[-1] 40 | 41 | #negated log-distribution of distances of unit sphere in dimension 42 | log_q_d_inv = ((2.0 - float(dim)) * torch.log(dists) - (float(dim-3) / 2) * torch.log(1.0 - 0.25 * (dists.pow(2)))) 43 | log_q_d_inv[np.where(1-labels)[0]] = 0 44 | 45 | q_d_inv = torch.exp(log_q_d_inv - torch.max(log_q_d_inv)) # - max(log) for stability 46 | q_d_inv[np.where(1-labels)[0]] = 0 47 | 48 | ### NOTE: Cutting of values with high distances made the results slightly worse. It can also lead to 49 | # errors where there are no available negatives (for high samples_per_class cases). 
50 | # q_d_inv[np.where(dists.detach().cpu().numpy()>self.upper_cutoff)[0]] = 0 51 | 52 | q_d_inv = q_d_inv/q_d_inv.sum() 53 | return q_d_inv.detach().cpu().numpy() 54 | 55 | 56 | def pdist(self, A): 57 | prod = torch.mm(A, A.t()) 58 | norm = prod.diag().unsqueeze(1).expand_as(prod) 59 | res = (norm + norm.t() - 2 * prod).clamp(min = 0) 60 | return res.sqrt() 61 | -------------------------------------------------------------------------------- /batchminer/easypositive.py: -------------------------------------------------------------------------------- 1 | import numpy as np, torch 2 | 3 | 4 | class BatchMiner(): 5 | def __init__(self, opt): 6 | self.par = opt 7 | self.lower_cutoff = opt.miner_distance_lower_cutoff 8 | self.upper_cutoff = opt.miner_distance_upper_cutoff 9 | self.name = 'distance' 10 | 11 | def __call__(self, batch, labels): 12 | if isinstance(labels, torch.Tensor): labels = labels.detach().cpu().numpy() 13 | bs = batch.shape[0] 14 | distances = self.pdist(batch.detach()).clamp(min=self.lower_cutoff) 15 | 16 | positives, negatives = [],[] 17 | labels_visited = [] 18 | anchors = [] 19 | 20 | for i in range(bs): 21 | neg = labels!=labels[i]; pos = labels==labels[i] 22 | 23 | if np.sum(pos)>1: 24 | anchors.append(i) 25 | q_d_inv = self.inverse_sphere_distances(batch, distances[i], labels, labels[i]) 26 | #Sample positives randomly 27 | pos[i] = 0 28 | positives.append(np.random.choice(np.where(pos)[0])) 29 | #Sample negatives by distance 30 | negatives.append(np.random.choice(bs,p=q_d_inv)) 31 | 32 | sampled_triplets = [[a,p,n] for a,p,n in zip(anchors, positives, negatives)] 33 | return sampled_triplets 34 | 35 | 36 | def inverse_sphere_distances(self, batch, anchor_to_all_dists, labels, anchor_label): 37 | dists = anchor_to_all_dists 38 | bs,dim = len(dists),batch.shape[-1] 39 | 40 | #negated log-distribution of distances of unit sphere in dimension 41 | log_q_d_inv = ((2.0 - float(dim)) * torch.log(dists) - (float(dim-3) / 2) * torch.log(1.0 - 0.25 * (dists.pow(2)))) 42 | log_q_d_inv[np.where(labels==anchor_label)[0]] = 0 43 | 44 | q_d_inv = torch.exp(log_q_d_inv - torch.max(log_q_d_inv)) # - max(log) for stability 45 | q_d_inv[np.where(labels==anchor_label)[0]] = 0 46 | 47 | ### NOTE: Cutting of values with high distances made the results slightly worse. It can also lead to 48 | # errors where there are no available negatives (for high samples_per_class cases). 
49 | # q_d_inv[np.where(dists.detach().cpu().numpy()>self.upper_cutoff)[0]] = 0
50 |
51 |         q_d_inv = q_d_inv/q_d_inv.sum()
52 |         return q_d_inv.detach().cpu().numpy()
53 |
54 |
55 |     def pdist(self, A):
56 |         prod = torch.mm(A, A.t())
57 |         norm = prod.diag().unsqueeze(1).expand_as(prod)
58 |         res = (norm + norm.t() - 2 * prod).clamp(min = 0)
59 |         return res.sqrt()
60 |
--------------------------------------------------------------------------------
/batchminer/epshn.py:
--------------------------------------------------------------------------------
1 | import numpy as np, torch
2 |
3 |
4 | class BatchMiner():
5 |     def __init__(self, opt):
6 |         self.par = opt
7 |         self.name = 'epshn'
8 |         self.margin = vars(opt)['loss_'+opt.loss+'_margin']  # assumed to mirror semihard.py; required for the semi-hard band below
9 |
10 |     def __call__(self, batch, labels, return_distances=False):
11 |         if isinstance(labels, torch.Tensor): labels = labels.detach().cpu().numpy()
12 |         bs = batch.size(0)
13 |         #Return distance matrix for all elements in batch (BSxBS)
14 |         distances = self.pdist(batch.detach()).detach().cpu().numpy()
15 |
16 |         positives, negatives = [], []
17 |         anchors = []
18 |         for i in range(bs):
19 |             l, d = labels[i], distances[i]
20 |             neg = labels!=l; pos = labels==l
21 |
22 |             anchors.append(i)
23 |             pos[i] = False
24 |             p = np.where(pos)[0][np.argmin(d[pos])]; positives.append(p)  # easy positive: the closest same-class sample
25 |             #Find negatives that violate the triplet constraint (semi-hard negatives)
26 |             neg_mask = np.logical_and(neg,d>d[p])
27 |             neg_mask = np.logical_and(neg_mask,d<self.margin+d[p])
28 |             if neg_mask.sum()>0:
29 |                 negatives.append(np.random.choice(np.where(neg_mask)[0]))
30 |             else:
31 |                 negatives.append(np.random.choice(np.where(neg)[0]))
32 |
33 |         sampled_triplets = [[a, p, n] for a, p, n in zip(anchors, positives, negatives)]
34 |
35 |         if return_distances:
36 |             return sampled_triplets, distances
37 |         else:
38 |             return sampled_triplets
39 |
40 |
41 |     def pdist(self, A):
42 |         prod = torch.mm(A, A.t())
43 |         norm = prod.diag().unsqueeze(1).expand_as(prod)
44 |         res = (norm + norm.t() - 2 * prod).clamp(min = 0)
45 |         return res.clamp(min = 0).sqrt()
46 |
--------------------------------------------------------------------------------
/batchminer/intra_random.py:
--------------------------------------------------------------------------------
1 | import numpy as np, torch
2 | import itertools as it
3 | import random
4 |
5 | class BatchMiner():
6 |     def __init__(self, opt):
7 |         self.par = opt
8 |         self.name = 'random'
9 |
10 |     def __call__(self, batch, labels):
11 |         if isinstance(labels, torch.Tensor): labels = labels.detach().cpu().numpy()
12 |         unique_classes = np.unique(labels)
13 |         indices = np.arange(len(batch))
14 |         class_dict = {i:indices[labels==i] for i in unique_classes}
15 |
16 |         sampled_triplets = []
17 |         for cls in np.random.choice(list(class_dict.keys()), len(labels), replace=True):
18 |             a,p,n = np.random.choice(class_dict[cls], 3, replace=True)
19 |             sampled_triplets.append((a,p,n))
20 |
21 |         return sampled_triplets
22 |
--------------------------------------------------------------------------------
/batchminer/lifted.py:
--------------------------------------------------------------------------------
1 | import numpy as np, torch
2 |
3 | class BatchMiner():
4 |     def __init__(self, opt):
5 |         self.par = opt
6 |         self.name = 'lifted'
7 |
8 |     def __call__(self, batch, labels):
9 |         if isinstance(labels, torch.Tensor): labels = labels.detach().cpu().numpy()
10 |
11 |         ###
12 |         anchors, positives, negatives = [], [], []
13 |
14 |
15 |         for i in range(len(batch)):
16 |             anchor = i
17 |             pos = labels==labels[anchor]
18 |
19 |             ###
20 |             if np.sum(pos)>1:
21 |                 anchors.append(anchor)
22 |                 positive_set = np.where(pos)[0]
23 |                 positive_set = positive_set[positive_set!=anchor]
24 |                 positives.append(positive_set)
25 |
26 |         ###
27 |         negatives = []
28 |         for anchor,positive_set in zip(anchors, positives):
29 |             neg_idxs = [i for i in range(len(batch)) if i not in [anchor]+list(positive_set)]
30 |             negative_set = np.arange(len(batch))[neg_idxs]
31 |             negatives.append(negative_set)
32 |
33 |         return anchors, positives, negatives
34 |
--------------------------------------------------------------------------------
/batchminer/npair.py:
--------------------------------------------------------------------------------
1 | import numpy as np, torch
2 | class BatchMiner():
3 |     def __init__(self, opt):
4 |         self.par = opt
5 |         self.name = 'npair'
6 |
7 |     def __call__(self, batch, labels):
8 |         if isinstance(labels, torch.Tensor): labels = labels.detach().cpu().numpy()
9 |
10 |         anchors, positives, negatives = [],[],[]
11 |
12 |         for i in range(len(batch)):
13 |             anchor = i
14 |             pos = labels==labels[anchor]
15 |
16 |             if np.sum(pos)>1:
17 |                 anchors.append(anchor)
18 |                 avail_positive = np.where(pos)[0]
19 |                 avail_positive = avail_positive[avail_positive!=anchor]
20 |                 positive = np.random.choice(avail_positive)
21 |                 positives.append(positive)
22 |
23 |         ###
24 |         negatives = []
25 |         for anchor,positive in zip(anchors, positives):
26 |             neg_idxs = [i for i in range(len(batch)) if i not in [anchor, positive] and labels[i] != labels[anchor]]
27 |             # neg_idxs = [i for i in range(len(batch)) if i not in [anchor, positive]]
28 |             negative_set = np.arange(len(batch))[neg_idxs]
29 |             negatives.append(negative_set)
30 |
31 |         return anchors, positives, negatives
32 |
--------------------------------------------------------------------------------
/batchminer/parametric.py:
--------------------------------------------------------------------------------
1 | import numpy as np, torch
2 |
3 |
4 | class BatchMiner():
5 |     def __init__(self, opt):
6 |         self.par = opt
7 |         self.mode = opt.miner_parametric_mode
8 |         self.n_support = opt.miner_parametric_n_support
9 |         self.support_lim = opt.miner_parametric_support_lim
10 |         self.name = 'parametric'
11 |
12 |         ###
13 |         self.set_sample_distr()
14 |
15 |
16 |
17 |     def __call__(self, batch, labels):
18 |         bs = batch.shape[0]
19 |         sample_distr = self.sample_distr
20 |
21 |         if isinstance(labels, torch.Tensor): labels = labels.detach().cpu().numpy()
22 |
23 |         ###
24 |         distances = self.pdist(batch.detach())
25 |
26 |         p_assigns = np.sum((distances.cpu().numpy().reshape(-1)>self.support[1:-1].reshape(-1,1)).T,axis=1).reshape(distances.shape)
27 |         outside_support_lim = (distances.cpu().numpy().reshape(-1)<self.support_lim[0]) + (distances.cpu().numpy().reshape(-1)>self.support_lim[1])  # '+' acts as OR on boolean arrays
28 |         outside_support_lim = outside_support_lim.reshape(distances.shape)
29 |
30 |         sample_ps = sample_distr[p_assigns]
31 |         sample_ps[outside_support_lim] = 0
32 |
33 |         ###
34 |         anchors, labels_visited = [], []
35 |         positives, negatives = [],[]
36 |
37 |         ###
38 |         for i in range(bs):
39 |             neg = labels!=labels[i]; pos = labels==labels[i]
40 |
41 |             if np.sum(pos)>1:
42 |                 anchors.append(i)
43 |
44 |                 #Sample positives randomly
45 |                 pos[i] = 0
46 |                 positives.append(np.random.choice(np.where(pos)[0]))
47 |
48 |                 #Sample negatives by distance
49 |                 sample_p = sample_ps[i][neg]
50 |                 sample_p = sample_p/sample_p.sum()
51 |                 negatives.append(np.random.choice(np.arange(bs)[neg],p=sample_p))
52 |
53 |         sampled_triplets = [[a,p,n] for a,p,n in zip(anchors, positives, negatives)]
54 |         return sampled_triplets
55 |
56 |
57 |
58 |     def pdist(self, A, eps=1e-4):
59 |         prod = torch.mm(A, A.t())
60 |         norm = prod.diag().unsqueeze(1).expand_as(prod)
61 |         res = (norm + norm.t() - 2 * prod).clamp(min = 0)
62 |         return res.clamp(min = eps).sqrt()
63 |
64 |
65 |     def set_sample_distr(self):
66 |         self.support = np.linspace(self.support_lim[0], self.support_lim[1], self.n_support)
67 |
68 |         if self.mode == 'uniform':
69 |             self.sample_distr = np.array([1.] * (self.n_support-1))
70 |
71 |         if self.mode == 'hards':
72 |             self.sample_distr = self.support.copy()
73 |             self.sample_distr[self.support<=0.5] = 1
74 |             self.sample_distr[self.support>0.5] = 0
75 |
76 |         if self.mode == 'semihards':
77 |             self.sample_distr = self.support.copy()
78 |             # keep only the band of moderate distances
79 |             self.sample_distr[(self.support<=0.7) * (self.support>=0.3)] = 1
80 |             self.sample_distr[(self.support<0.3) + (self.support>0.7)] = 0  # '+' acts as OR; '*' (AND) would never match
81 |
82 |         if self.mode == 'veryhards':
83 |             self.sample_distr = self.support.copy()
84 |             self.sample_distr[self.support<=0.3] = 1
85 |             self.sample_distr[self.support>0.3] = 0
86 |
87 |         self.sample_distr = np.clip(self.sample_distr, 1e-15, 1)
88 |         self.sample_distr = self.sample_distr/self.sample_distr.sum()
89 |
--------------------------------------------------------------------------------
/batchminer/random.py:
--------------------------------------------------------------------------------
1 | import numpy as np, torch
2 | import itertools as it
3 | import random
4 |
5 | class BatchMiner():
6 |     def __init__(self, opt):
7 |         self.par = opt
8 |         self.name = 'random'
9 |
10 |     def __call__(self, batch, labels):
11 |         if isinstance(labels, torch.Tensor): labels = labels.detach().cpu().numpy()
12 |         unique_classes = np.unique(labels)
13 |         indices = np.arange(len(batch))
14 |         class_dict = {i:indices[labels==i] for i in unique_classes}
15 |
16 |         sampled_triplets = [list(it.product([x],[x],[y for y in unique_classes if x!=y])) for x in unique_classes]
17 |         sampled_triplets = [x for y in sampled_triplets for x in y]
18 |
19 |         sampled_triplets = [[x for x in list(it.product(*[class_dict[j] for j in i])) if x[0]!=x[1]] for i in sampled_triplets]
20 |         sampled_triplets = [x for y in sampled_triplets for x in y]
21 |
22 |         #NOTE: The number of possible triplets is given by #unique_classes*(2*(samples_per_class-1)!)*(#unique_classes-1)*samples_per_class
23 |         sampled_triplets = random.sample(sampled_triplets, batch.shape[0])
24 |         return sampled_triplets
25 |
--------------------------------------------------------------------------------
/batchminer/random_distance.py:
--------------------------------------------------------------------------------
1 | import numpy as np, torch
2 |
3 |
4 | class BatchMiner():
5 |     def __init__(self, opt):
6 |         self.par = opt
7 |         self.lower_cutoff = opt.miner_distance_lower_cutoff
8 |         self.upper_cutoff = opt.miner_distance_upper_cutoff
9 |         self.name = 'distance'
10 |
11 |     def __call__(self, batch, labels):
12 |         if isinstance(labels, torch.Tensor): labels = labels.detach().cpu().numpy()
13 |         labels = labels[np.random.choice(len(labels), len(labels), replace=False)]
14 |
15 |         bs = batch.shape[0]
16 |         distances = self.pdist(batch.detach()).clamp(min=self.lower_cutoff)
17 |
18 |         positives, negatives = [],[]
19 |         labels_visited = []
20 |         anchors = []
21 |
22 |         for i in range(bs):
23 |             neg = labels!=labels[i]; pos = labels==labels[i]
24 |
25 |             if np.sum(pos)>1:
26 |                 anchors.append(i)
27 |                 q_d_inv = self.inverse_sphere_distances(batch, distances[i], labels, labels[i])
28 |                 #Sample positives randomly
29 |                 pos[i] = 0
30 |                 positives.append(np.random.choice(np.where(pos)[0]))
31 |                 #Sample negatives by 
distance 32 | negatives.append(np.random.choice(bs,p=q_d_inv)) 33 | 34 | sampled_triplets = [[a,p,n] for a,p,n in zip(anchors, positives, negatives)] 35 | return sampled_triplets 36 | 37 | 38 | def inverse_sphere_distances(self, batch, anchor_to_all_dists, labels, anchor_label): 39 | dists = anchor_to_all_dists 40 | bs,dim = len(dists),batch.shape[-1] 41 | 42 | #negated log-distribution of distances on the unit sphere in dimension dim 43 | log_q_d_inv = ((2.0 - float(dim)) * torch.log(dists) - (float(dim-3) / 2) * torch.log(1.0 - 0.25 * (dists.pow(2)))) 44 | log_q_d_inv[np.where(labels==anchor_label)[0]] = 0 45 | 46 | q_d_inv = torch.exp(log_q_d_inv - torch.max(log_q_d_inv)) # - max(log) for stability 47 | q_d_inv[np.where(labels==anchor_label)[0]] = 0 48 | 49 | ### NOTE: Cutting off values with high distances made the results slightly worse. It can also lead to 50 | # errors where there are no available negatives (for high samples_per_class cases). 51 | # q_d_inv[np.where(dists.detach().cpu().numpy()>self.upper_cutoff)[0]] = 0 52 | 53 | q_d_inv = q_d_inv/q_d_inv.sum() 54 | return q_d_inv.detach().cpu().numpy() 55 | 56 | 57 | def pdist(self, A): 58 | prod = torch.mm(A, A.t()) 59 | norm = prod.diag().unsqueeze(1).expand_as(prod) 60 | res = (norm + norm.t() - 2 * prod).clamp(min = 0) 61 | return res.sqrt() 62 | -------------------------------------------------------------------------------- /batchminer/rho_distance.py: -------------------------------------------------------------------------------- 1 | import numpy as np, torch 2 | 3 | 4 | class BatchMiner(): 5 | def __init__(self, opt): 6 | self.par = opt 7 | self.lower_cutoff = opt.miner_rho_distance_lower_cutoff 8 | self.upper_cutoff = opt.miner_rho_distance_upper_cutoff 9 | self.contrastive_p = opt.miner_rho_distance_cp 10 | 11 | self.name = 'rho_distance' 12 | 13 | def __call__(self, batch, labels, return_distances=False): 14 | if isinstance(labels, torch.Tensor): labels = labels.detach().cpu().numpy() 15 | bs = batch.shape[0] 16 | distances = self.pdist(batch.detach()).clamp(min=self.lower_cutoff) 17 | 18 | positives, negatives = [],[] 19 | labels_visited = [] 20 | anchors = [] 21 | 22 | for i in range(bs): 23 | neg = labels!=labels[i]; pos = labels==labels[i] 24 | 25 | use_contr = np.random.choice(2, p=[1-self.contrastive_p, self.contrastive_p]) 26 | if np.sum(pos)>1: 27 | anchors.append(i) 28 | if use_contr: 29 | positives.append(i) 30 | #Sample a same-class element as 'negative' (uniformly at random) to purposely inject intra-class push triplets 31 | pos[i] = 0 32 | negatives.append(np.random.choice(np.where(pos)[0])) 33 | else: 34 | q_d_inv = self.inverse_sphere_distances(batch, distances[i], labels, labels[i]) 35 | #Sample positives randomly 36 | pos[i] = 0 37 | positives.append(np.random.choice(np.where(pos)[0])) 38 | #Sample negatives by distance 39 | negatives.append(np.random.choice(bs,p=q_d_inv)) 40 | 41 | sampled_triplets = [[a,p,n] for a,p,n in zip(anchors, positives, negatives)] 42 | self.push_triplets = np.sum([m[1]==m[2] for m in labels[sampled_triplets]]) 43 | 44 | if return_distances: 45 | return sampled_triplets, distances 46 | else: 47 | return sampled_triplets 48 | 49 | 50 | def inverse_sphere_distances(self, batch, anchor_to_all_dists, labels, anchor_label): 51 | dists = anchor_to_all_dists 52 | bs,dim = len(dists),batch.shape[-1] 53 | 54 | #negated log-distribution of distances on the unit sphere in dimension dim 55 | log_q_d_inv = ((2.0 - float(dim)) * torch.log(dists) - (float(dim-3) / 2) * torch.log(1.0 - 0.25 * (dists.pow(2)))) 56 | log_q_d_inv[np.where(labels==anchor_label)[0]] = 0 57 | 58 | q_d_inv = 
torch.exp(log_q_d_inv - torch.max(log_q_d_inv)) # - max(log) for stability 59 | q_d_inv[np.where(labels==anchor_label)[0]] = 0 60 | 61 | ### NOTE: Cutting off values with high distances made the results slightly worse. It can also lead to 62 | # errors where there are no available negatives (for high samples_per_class cases). 63 | # q_d_inv[np.where(dists.detach().cpu().numpy()>self.upper_cutoff)[0]] = 0 64 | 65 | q_d_inv = q_d_inv/q_d_inv.sum() 66 | return q_d_inv.detach().cpu().numpy() 67 | 68 | 69 | def pdist(self, A, eps=1e-4): 70 | prod = torch.mm(A, A.t()) 71 | norm = prod.diag().unsqueeze(1).expand_as(prod) 72 | res = (norm + norm.t() - 2 * prod).clamp(min = 0) 73 | return res.clamp(min = eps).sqrt() 74 | -------------------------------------------------------------------------------- /batchminer/semihard.py: -------------------------------------------------------------------------------- 1 | import numpy as np, torch 2 | 3 | 4 | class BatchMiner(): 5 | def __init__(self, opt): 6 | self.par = opt 7 | self.name = 'semihard' 8 | self.margin = vars(opt)['loss_'+opt.loss+'_margin'] 9 | 10 | def __call__(self, batch, labels, return_distances=False): 11 | if isinstance(labels, torch.Tensor): labels = labels.detach().cpu().numpy() 12 | bs = batch.size(0) 13 | #Return distance matrix for all elements in batch (BSxBS) 14 | distances = self.pdist(batch.detach()).detach().cpu().numpy() 15 | 16 | positives, negatives = [], [] 17 | anchors = [] 18 | for i in range(bs): 19 | l, d = labels[i], distances[i] 20 | neg = labels!=l; pos = labels==l 21 | 22 | anchors.append(i) 23 | pos[i] = 0 24 | p = np.random.choice(np.where(pos)[0]) 25 | positives.append(p) 26 | 27 | #Find negatives that violate the triplet constraint (semi-hard negatives) 28 | neg_mask = np.logical_and(neg,d>d[p]) 29 | neg_mask = np.logical_and(neg_mask,d<d[p]+self.margin) 30 | if neg_mask.sum()>0: 31 | negatives.append(np.random.choice(np.where(neg_mask)[0])) 32 | else: 33 | negatives.append(np.random.choice(np.where(neg)[0])) 34 | 35 | sampled_triplets = [[a, p, n] for a, p, n in zip(anchors, positives, negatives)] 36 | 37 | if return_distances: 38 | return sampled_triplets, distances 39 | else: 40 | return sampled_triplets 41 | 42 | 43 | def pdist(self, A): 44 | prod = torch.mm(A, A.t()) 45 | norm = prod.diag().unsqueeze(1).expand_as(prod) 46 | res = (norm + norm.t() - 2 * prod).clamp(min = 0) 47 | return res.clamp(min = 0).sqrt() 48 | -------------------------------------------------------------------------------- /batchminer/softhard.py: -------------------------------------------------------------------------------- 1 | import numpy as np, torch 2 | 3 | 4 | class BatchMiner(): 5 | def __init__(self, opt): 6 | self.par = opt 7 | self.name = 'softhard' 8 | 9 | def __call__(self, batch, labels, return_distances=False): 10 | if isinstance(labels, torch.Tensor): labels = labels.detach().cpu().numpy() 11 | bs = batch.size(0) 12 | #Return distance matrix for all elements in batch (BSxBS) 13 | distances = self.pdist(batch.detach()).detach().cpu().numpy() 14 | 15 | positives, negatives = [], [] 16 | anchors = [] 17 | for i in range(bs): 18 | l, d = labels[i], distances[i] 19 | neg = labels!=l; pos = labels==l 20 | 21 | if np.sum(pos)>1: 22 | anchors.append(i) 23 | #1 for batchelements with label l 24 | #0 for current anchor 25 | pos[i] = False 26 | 27 | #Find negatives that violate triplet constraint in a hard fashion 28 | neg_mask = np.logical_and(neg,d<d[np.where(pos)[0]].max()) 29 | #Find positives that violate triplet constraint in a hard fashion 30 | pos_mask = np.logical_and(pos,d>d[np.where(neg)[0]].min()) 31 | 32 | if pos_mask.sum()>0: 33 | positives.append(np.random.choice(np.where(pos_mask)[0])) 34 | else: 35 | 
positives.append(np.random.choice(np.where(pos)[0])) 36 | 37 | if neg_mask.sum()>0: 38 | negatives.append(np.random.choice(np.where(neg_mask)[0])) 39 | else: 40 | negatives.append(np.random.choice(np.where(neg)[0])) 41 | 42 | sampled_triplets = [[a, p, n] for a, p, n in zip(anchors, positives, negatives)] 43 | if return_distances: 44 | return sampled_triplets, distances 45 | else: 46 | return sampled_triplets 47 | 48 | 49 | 50 | def pdist(self, A): 51 | prod = torch.mm(A, A.t()) 52 | norm = prod.diag().unsqueeze(1).expand_as(prod) 53 | res = (norm + norm.t() - 2 * prod).clamp(min = 0) 54 | return res.clamp(min = 0).sqrt() 55 | -------------------------------------------------------------------------------- /create_dataset_splits.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | warnings.filterwarnings("ignore") 3 | import os, sys, numpy as np, argparse, imp, datetime, pandas as pd, copy 4 | sys.path.insert(0, '..') 5 | import time, pickle as pkl, random, json, collections 6 | 7 | import matplotlib 8 | import matplotlib.pyplot as plt 9 | from sklearn.metrics import pairwise_distances 10 | import torch, torch.nn as nn 11 | import torch.multiprocessing 12 | torch.multiprocessing.set_sharing_strategy('file_system') 13 | from tqdm import tqdm 14 | 15 | import architectures as archs 16 | import datasets as datasets 17 | import metrics as metrics 18 | from utilities import misc 19 | import parameters as par 20 | import utilities.misc as misc 21 | import split_helpers as helper 22 | 23 | 24 | """===================================================================================================""" 25 | parser = argparse.ArgumentParser() 26 | parser = par.basic_training_parameters(parser) 27 | parser = par.batch_creation_parameters(parser) 28 | parser = par.batchmining_specific_parameters(parser) 29 | parser = par.loss_specific_parameters(parser) 30 | parser = par.wandb_parameters(parser) 31 | parser.add_argument('--n_swaps', default=25, type=int) 32 | parser.add_argument('--swaps_iter', default=2, type=int) 33 | parser.add_argument('--load', action='store_true') 34 | parser.add_argument('--super', action='store_true') 35 | ##### Read in parameters 36 | # Run with e.g. python create_dataset_splits.py --dataset cub200 [cars196, online_products]. 37 | opt = parser.parse_args() 38 | # Note: For SOP, set e.g. opt.swaps_iter = 1000 and opt.n_swaps = 20. 
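# Example invocations (a sketch; '--dataset' is defined in parameters.py, the swap flags
# in the argparse arguments added above):
#   python create_dataset_splits.py --dataset cub200
#   python create_dataset_splits.py --dataset cars196
#   python create_dataset_splits.py --dataset online_products --swaps_iter 1000 --n_swaps 20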
39 | 40 | 41 | """===================================================================================================""" 42 | def set_seed(seed): 43 | torch.backends.cudnn.deterministic=True; 44 | np.random.seed(seed); random.seed(seed) 45 | torch.manual_seed(seed); torch.cuda.manual_seed(seed); torch.cuda.manual_seed_all(seed) 46 | set_seed(opt.seed) 47 | 48 | os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" 49 | os.environ["CUDA_VISIBLE_DEVICES"]= str(opt.gpu[0]) 50 | 51 | opt.device = torch.device('cuda') 52 | model = archs.select(opt.arch, opt) 53 | _ = model.to(opt.device) 54 | 55 | dataloaders = {} 56 | if opt.dataset=='online_products': 57 | opt.source_path += '/'+opt.dataset 58 | datasets = datasets.select(opt.dataset, opt, opt.source_path) 59 | else: 60 | datasets = datasets.select(opt.dataset, opt, opt.source_path+'/'+opt.dataset) 61 | dataloaders['training'] = torch.utils.data.DataLoader(datasets['evaluation'], num_workers=opt.kernels, batch_size=opt.bs, shuffle=False) 62 | dataloaders['testing'] = torch.utils.data.DataLoader(datasets['testing'], num_workers=opt.kernels, batch_size=opt.bs, shuffle=False) 63 | 64 | 65 | """===================================================================================================""" 66 | info_dict = {} 67 | feat_collect = [] 68 | class_labels_collect = [] 69 | img_paths_collect = [] 70 | 71 | # These are the splits and FIDs used in our experiments. 72 | # Note that due to internal differences between the initial script and this public one, 73 | # there won't be an EXACT matching, but shifts/splits will be very close. 74 | splits_to_use = { 75 | 'cub200': { 76 | 'id': ['-20', '-10', '0', '6', '10', '30', 'R22', 'R48', 'R66'], 77 | 'base_id': ['-20', '-10', '0', '6', '10', '30'], 78 | 'fid': [19.16, 28.49, 52.62, 72.20, 92.48, 120.38, 136.45, 152.04, 173.94]}, 79 | 'cars196': { 80 | 'id': ['0', '6', '16', '20', '30', 'R18', 'R42', 'R66'], 81 | 'base_id': ['0', '6', '16', '20', '30'], 82 | 'fid': [8.59, 14.33, 32.18, 43.58, 63.29, 86.48, 101.17, 123.03]}, 83 | 'online_products': { 84 | 'id': ['0', '1000', '2000', '3000', '4000', '5000', 'R2000', 'R6000'], 85 | 'base_id': ['0', '1000', '2000', '3000', '4000', '5000'], 86 | 'fid': [3.43, 24.59, 53.47, 99.38, 135.53, 155.25, 189.81, 235.10]} 87 | } 88 | 89 | if not opt.load: 90 | info_dict = helper.get_features(model, dataloaders, opt.dataset, opt.device) 91 | # Save dictionaries of features and classmeans. 92 | pkl.dump(info_dict,open('{}_dict.pkl'.format(opt.dataset),'wb')) 93 | else: 94 | # If chosen, load pretrained embedding dictionaries. 95 | info_dict = pkl.load(open('{}_dict.pkl'.format(opt.dataset),'rb')) 96 | print("Data loaded!\n") 97 | 98 | 99 | """===============================================================""" 100 | # If opt.super is set, swap classes by superclass context. 
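# For reference, the feature dictionary returned by helper.get_features is assumed
# (from the accesses below) to be laid out as:
#   info_dict[<'training'|'testing'>]['classmeans']['feats']   -> per-class mean embeddings
#   info_dict[<'training'|'testing'>]['classmeans']['classes'] -> matching class ids
# with an analogous 'classmeans_super' entry holding means grouped by superclass.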
101 | if opt.super: 102 | train_classmean_feats = info_dict['training']['classmeans_super']['feats'] 103 | train_classmean_cls = info_dict['training']['classmeans_super']['classes'] 104 | test_classmean_feats = info_dict['testing']['classmeans_super']['feats'] 105 | test_classmean_cls = info_dict['testing']['classmeans_super']['classes'] 106 | else: 107 | train_classmean_feats = info_dict['training']['classmeans']['feats'] 108 | train_classmean_cls = info_dict['training']['classmeans']['classes'] 109 | test_classmean_feats = info_dict['testing']['classmeans']['feats'] 110 | test_classmean_cls = info_dict['testing']['classmeans']['classes'] 111 | 112 | 113 | # Generate harder (more OOD) splits with class swapping. 114 | hard_SPLITS, hard_fids, hard_final_feats = helper.split_maker( 115 | copy.deepcopy(train_classmean_feats), copy.deepcopy(train_classmean_cls), 116 | copy.deepcopy(test_classmean_feats), copy.deepcopy(test_classmean_cls), 117 | N_SWAPS=opt.n_swaps, SWAPS_PER_ITER=opt.swaps_iter, HISTORY=5, inverse=False 118 | ) 119 | 120 | # Generate harder splits via class removal. 121 | hard_removed_SPLITS, hard_removed_fids = helper.split_maker_with_class_removal( 122 | copy.deepcopy(hard_final_feats['train']), copy.deepcopy(hard_SPLITS[48]['train']), 123 | copy.deepcopy(hard_final_feats['test']), copy.deepcopy(hard_SPLITS[48]['test']), 124 | N_SWAPS=opt.n_swaps+10, SWAPS_PER_ITER=opt.swaps_iter, HISTORY=5, inverse=False 125 | ) 126 | 127 | # Generate easier (less OOD) splits with class swapping. 128 | if opt.dataset == 'cub200': 129 | easy_SPLITS, easy_fids, _ = helper.split_maker( 130 | copy.deepcopy(train_classmean_feats), copy.deepcopy(train_classmean_cls), 131 | copy.deepcopy(test_classmean_feats), copy.deepcopy(test_classmean_cls), 132 | N_SWAPS=opt.n_swaps-10, SWAPS_PER_ITER=opt.swaps_iter, HISTORY=30, inverse=True 133 | ) 134 | 135 | SPLITS = {} 136 | for key in hard_SPLITS.keys(): 137 | SPLITS[key] = {} 138 | SPLITS[key]['train'] = sorted(hard_SPLITS[key]['train']) 139 | SPLITS[key]['test'] = sorted(hard_SPLITS[key]['test']) 140 | SPLITS[key]['fid'] = hard_SPLITS[key]['fid'] 141 | if opt.dataset == 'cub200': 142 | for key in easy_SPLITS.keys(): 143 | if key not in SPLITS.keys(): 144 | SPLITS[key] = {} 145 | SPLITS[key]['train'] = sorted(easy_SPLITS[key]['train']) 146 | SPLITS[key]['test'] = sorted(easy_SPLITS[key]['test']) 147 | SPLITS[key]['fid'] = easy_SPLITS[key]['fid'] 148 | 149 | 150 | """===============================================================""" 151 | # Only select the splits that are going to be used for the experiments. 152 | merged_dict = {} 153 | for i, idx in enumerate(splits_to_use[opt.dataset]['id']): 154 | if 'R' not in idx: 155 | if '-' in idx: 156 | idx = int(idx) 157 | merged_dict[i+1] = easy_SPLITS[idx] 158 | else: 159 | idx = int(idx) 160 | merged_dict[i+1] = hard_SPLITS[idx] 161 | else: 162 | idx = idx.replace('R', '') 163 | idx = int(idx) 164 | merged_dict[i+1] = hard_removed_SPLITS[idx] 165 | 166 | # Save complete split dictionary. 
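# A minimal sketch of re-loading the saved splits downstream (keys follow merged_dict
# as built above; split ids run from 1 to len(splits_to_use[opt.dataset]['id'])):
#   splits = pkl.load(open('cub200_splits.pkl', 'rb'))
#   train_classes, test_classes = splits[1]['train'], splits[1]['test']
#   split_fid = splits[1]['fid']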
167 | pkl.dump(merged_dict, open('{}{}_splits.pkl'.format('super_' if opt.super else '', opt.dataset),'wb')) 168 | -------------------------------------------------------------------------------- /criteria/__init__.py: -------------------------------------------------------------------------------- 1 | ### Standard DML criteria 2 | from criteria import triplet, margin, proxynca 3 | from criteria import contrastive, angular, arcface 4 | from criteria import multisimilarity, quadruplet, oproxy 5 | ### Non-Standard Criteria 6 | from criteria import moco, adversarial_separation, fast_moco, imrot, s2sd 7 | ### Basic Libs 8 | import copy 9 | 10 | 11 | """=================================================================================================""" 12 | def select(loss, opt, to_optim=None, batchminer=None): 13 | losses = {'triplet': triplet, 14 | 'margin':margin, 15 | 'proxynca':proxynca, 16 | 's2sd':s2sd, 17 | 'angular':angular, 18 | 'contrastive':contrastive, 19 | 'oproxy':oproxy, 20 | 'multisimilarity':multisimilarity, 21 | 'arcface':arcface, 22 | 'quadruplet':quadruplet, 23 | 'adversarial_separation':adversarial_separation, 24 | 'moco': moco, 25 | 'imrot':imrot, 26 | 'fast_moco':fast_moco} 27 | 28 | 29 | if loss not in losses: raise NotImplementedError('Loss {} not implemented!'.format(loss)) 30 | 31 | loss_lib = losses[loss] 32 | if loss_lib.REQUIRES_BATCHMINER: 33 | if batchminer is None: 34 | raise Exception('Loss {} requires one of the following batch mining methods: {}'.format(loss, loss_lib.ALLOWED_MINING_OPS)) 35 | else: 36 | if batchminer.name not in loss_lib.ALLOWED_MINING_OPS: 37 | raise Exception('{}-mining not allowed for {}-loss!'.format(batchminer.name, loss)) 38 | 39 | loss_par_dict = {'opt':opt} 40 | if loss_lib.REQUIRES_BATCHMINER: 41 | loss_par_dict['batchminer'] = batchminer 42 | 43 | criterion = loss_lib.Criterion(**loss_par_dict) 44 | 45 | if to_optim is not None: 46 | if loss_lib.REQUIRES_OPTIM: 47 | if hasattr(criterion,'optim_dict_list') and criterion.optim_dict_list is not None: 48 | to_optim += criterion.optim_dict_list 49 | else: 50 | to_optim += [{'params':criterion.parameters(), 'lr':criterion.lr}] 51 | 52 | return criterion, to_optim 53 | else: 54 | return criterion 55 | -------------------------------------------------------------------------------- /criteria/adversarial_separation.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch, torch.nn as nn, torch.nn.functional as F 3 | import batchminer 4 | 5 | 6 | """=================================================================================================""" 7 | ALLOWED_MINING_OPS = list(batchminer.BATCHMINING_METHODS.keys()) 8 | REQUIRES_BATCHMINER = False 9 | REQUIRES_OPTIM = True 10 | 11 | ### Adversarial separation loss to decorrelate multiple embedding spaces (used in DiVA). 12 | class Criterion(torch.nn.Module): 13 | def __init__(self, opt): 14 | """ 15 | Args: 16 | opt: Namespace containing all relevant parameters, in particular 17 | embed_dim, diva_decorrnet_dim, diva_decorrelations, 18 | diva_rho_decorrelation and the projection-network learning 19 | rate diva_decorrnet_lr. 
20 | """ 21 | super().__init__() 22 | 23 | #### 24 | self.ALLOWED_MINING_OPS = ALLOWED_MINING_OPS 25 | self.REQUIRES_BATCHMINER = REQUIRES_BATCHMINER 26 | self.REQUIRES_OPTIM = REQUIRES_OPTIM 27 | 28 | #### 29 | self.embed_dim = opt.embed_dim 30 | self.proj_dim = opt.diva_decorrnet_dim 31 | 32 | self.directions = opt.diva_decorrelations 33 | self.weights = opt.diva_rho_decorrelation 34 | 35 | self.name = 'adversarial_separation' 36 | 37 | #Projection network 38 | self.regressors = nn.ModuleDict() 39 | for direction in self.directions: 40 | self.regressors[direction] = torch.nn.Sequential(torch.nn.Linear(self.embed_dim, self.proj_dim), torch.nn.ReLU(), torch.nn.Linear(self.proj_dim, self.embed_dim)).to(torch.float).to(opt.device) 41 | 42 | #Learning Rate for Projection Network 43 | self.lr = opt.diva_decorrnet_lr 44 | 45 | 46 | def forward(self, feature_dict): 47 | #Apply gradient reversal on input embeddings. 48 | adj_feature_dict = {key:torch.nn.functional.normalize(grad_reverse(features),dim=-1) for key, features in feature_dict.items()} 49 | #Project one embedding to the space of the other (with normalization), then compute the correlation. 50 | sim_loss = 0 51 | for weight, direction in zip(self.weights, self.directions): 52 | source, target = direction.split('-') 53 | sim_loss += -1.*weight*torch.mean(torch.mean((adj_feature_dict[target]*torch.nn.functional.normalize(self.regressors[direction](adj_feature_dict[source]),dim=-1))**2,dim=-1)) 54 | return sim_loss 55 | 56 | 57 | 58 | ### Gradient Reversal Layer 59 | class GradRev(torch.autograd.Function): 60 | @staticmethod 61 | def forward(ctx, x): 62 | return x.view_as(x) 63 | @staticmethod 64 | def backward(ctx, grad_output): 65 | return (grad_output * -1.) 66 | 67 | ### Gradient reverse function 68 | def grad_reverse(x): 69 | return GradRev.apply(x) 70 | -------------------------------------------------------------------------------- /criteria/angular.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch, torch.nn as nn, torch.nn.functional as F 3 | import batchminer 4 | 5 | 6 | """=================================================================================================""" 7 | ALLOWED_MINING_OPS = ['npair'] 8 | REQUIRES_BATCHMINER = True 9 | REQUIRES_OPTIM = False 10 | 11 | ### MarginLoss with trainable class separation margin beta. Runs on Mini-batches as well. 12 | class Criterion(torch.nn.Module): 13 | def __init__(self, opt, batchminer): 14 | """ 15 | Args: 16 | margin: Triplet Margin. 17 | nu: Regularisation Parameter for beta values if they are learned. 18 | beta: Class-Margin values. 19 | n_classes: Number of different classes during training. 20 | """ 21 | super(Criterion, self).__init__() 22 | 23 | self.tan_angular_margin = np.tan(np.pi/180*opt.loss_angular_alpha) 24 | self.lam = opt.loss_angular_npair_ang_weight 25 | self.l2_weight = opt.loss_angular_npair_l2 26 | self.batchminer = batchminer 27 | 28 | self.name = 'angular' 29 | 30 | #### 31 | self.ALLOWED_MINING_OPS = ALLOWED_MINING_OPS 32 | self.REQUIRES_BATCHMINER = REQUIRES_BATCHMINER 33 | self.REQUIRES_OPTIM = REQUIRES_OPTIM 34 | 35 | def forward(self, batch, labels, **kwargs): 36 | """ 37 | Args: 38 | batch: torch.Tensor: Input of embeddings with size (BS x DIM) 39 | labels: nparray/list: For each element of the batch assigns a class [0,...,C-1], shape: (BS x 1) 40 | """ 41 | ####NOTE: Normalize Angular Loss, but not normalize npair loss! 
42 | anchors, positives, negatives = self.batchminer(batch, labels) 43 | anchors, positives, negatives = batch[anchors], batch[positives], batch[negatives] 44 | n_anchors, n_positives, n_negatives = F.normalize(anchors, dim=1), F.normalize(positives, dim=1), F.normalize(negatives, dim=-1) 45 | 46 | is_term1 = 4*self.tan_angular_margin**2*(n_anchors + n_positives)[:,None,:].bmm(n_negatives.permute(0,2,1)) 47 | is_term2 = 2*(1+self.tan_angular_margin**2)*n_anchors[:,None,:].bmm(n_positives[:,None,:].permute(0,2,1)) 48 | is_term1 = is_term1.view(is_term1.shape[0], is_term1.shape[-1]) 49 | is_term2 = is_term2.view(-1, 1) 50 | 51 | inner_sum_ang = is_term1 - is_term2 52 | angular_loss = torch.mean(torch.log(torch.sum(torch.exp(inner_sum_ang), dim=1) + 1)) 53 | 54 | 55 | inner_sum_npair = anchors[:,None,:].bmm((negatives - positives[:,None,:]).permute(0,2,1)) 56 | inner_sum_npair = inner_sum_npair.view(inner_sum_npair.shape[0], inner_sum_npair.shape[-1]) 57 | npair_loss = torch.mean(torch.log(torch.sum(torch.exp(inner_sum_npair.clamp(max=50,min=-50)), dim=1) + 1)) 58 | 59 | loss = npair_loss + self.lam*angular_loss + self.l2_weight*torch.mean(torch.norm(batch, p=2, dim=1)) 60 | return loss 61 | -------------------------------------------------------------------------------- /criteria/arcface.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch, torch.nn as nn, torch.nn.functional as F 3 | import batchminer 4 | 5 | """=================================================================================================""" 6 | ALLOWED_MINING_OPS = None 7 | REQUIRES_BATCHMINER = False 8 | REQUIRES_OPTIM = True 9 | 10 | ### This implementation follows the pseudocode provided in the original paper. 11 | class Criterion(torch.nn.Module): 12 | def __init__(self, opt): 13 | """ 14 | Args: 15 | margin: Triplet Margin. 16 | """ 17 | super(Criterion, self).__init__() 18 | self.par = opt 19 | 20 | #### 21 | self.ALLOWED_MINING_OPS = ALLOWED_MINING_OPS 22 | self.REQUIRES_BATCHMINER = REQUIRES_BATCHMINER 23 | self.REQUIRES_OPTIM = REQUIRES_OPTIM 24 | 25 | #### 26 | self.angular_margin = opt.loss_arcface_angular_margin 27 | self.feature_scale = opt.loss_arcface_feature_scale 28 | 29 | self.class_map = torch.nn.Parameter(torch.Tensor(opt.n_classes, opt.embed_dim)) 30 | stdv = 1. / np.sqrt(self.class_map.size(1)) 31 | self.class_map.data.uniform_(-stdv, stdv) 32 | 33 | self.name = 'arcface' 34 | 35 | self.lr = opt.loss_arcface_lr 36 | 37 | def forward(self, batch, labels, **kwargs): 38 | """ 39 | Args: 40 | batch: torch.Tensor: Input of embeddings with size (BS x DIM) 41 | labels: nparray/list: For each element of the batch assigns a class [0,...,C-1], shape: (BS x 1) 42 | """ 43 | bs, labels = len(batch), labels.to(self.par.device) 44 | 45 | class_map = torch.nn.functional.normalize(self.class_map, dim=1) 46 | #Note that the similarity becomes the cosine for normalized embeddings. Denoted as 'fc7' in the paper pseudocode. 
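# The margin step below amounts to: theta = arccos(<e, w_y>), logit_y -> cos(theta + m),
# with all logits then scaled by s. As a worked example with an angular margin of m = 0.5:
# an original target cosine of 0.8 gives theta ~ 0.64 rad, and cos(0.64 + 0.5) ~ 0.41, so
# the target logit is lowered and the decision boundary around each class weight tightens.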
47 | cos_similarity = batch.mm(class_map.T).clamp(min=1e-10, max=1-1e-10) 48 | 49 | pick = torch.zeros(bs, self.par.n_classes).bool().to(self.par.device) 50 | pick[torch.arange(bs), labels] = 1 51 | 52 | original_target_logit = cos_similarity[pick.type(torch.bool)] 53 | 54 | theta = torch.acos(original_target_logit) 55 | marginal_target_logit = torch.cos(theta + self.angular_margin) 56 | 57 | class_pred = self.feature_scale * (cos_similarity + pick * (marginal_target_logit-original_target_logit).unsqueeze(1)) 58 | loss = torch.nn.CrossEntropyLoss()(class_pred, labels) 59 | 60 | return loss 61 | -------------------------------------------------------------------------------- /criteria/base_criterion.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch, torch.nn as nn, torch.nn.functional as F 3 | import batchminer 4 | import criteria 5 | 6 | """=================================================================================================""" 7 | ALLOWED_MINING_OPS = None 8 | REQUIRES_BATCHMINER = False 9 | REQUIRES_OPTIM = True 10 | 11 | 12 | class Criterion(torch.nn.Module): 13 | def __init__(self, opt): 14 | """ 15 | Args: 16 | opt: Namespace containing all relevant parameters. 17 | """ 18 | super(Criterion, self).__init__() 19 | 20 | #### 21 | self.ALLOWED_MINING_OPS = ALLOWED_MINING_OPS 22 | self.REQUIRES_BATCHMINER = REQUIRES_BATCHMINER 23 | self.REQUIRES_OPTIM = REQUIRES_OPTIM 24 | 25 | #### 26 | self.num_proxies = opt.n_classes 27 | self.embed_dim = opt.embed_dim 28 | self.proxy_div = opt.loss_oproxy_proxy_div 29 | 30 | 31 | self.proxy_init_dev = opt.loss_oproxy_init_dev 32 | self.proxies = torch.randn(self.num_proxies, self.embed_dim)/self.proxy_div 33 | self.proxies = torch.nn.Parameter(self.proxies-self.proxies.max(dim=0)[0]*self.proxy_init_dev) 34 | self.optim_dict_list = [{'params':self.proxies, 'lr':opt.lr * opt.loss_oproxy_lrmulti}] 35 | 36 | self.class_idxs = torch.arange(self.num_proxies) 37 | 38 | self.name = 'oproxy' 39 | 40 | self.pars = {'pos_alpha':opt.loss_oproxy_pos_alpha, 41 | 'pos_delta':opt.loss_oproxy_pos_delta, 42 | 'neg_alpha':opt.loss_oproxy_neg_alpha, 43 | 'neg_delta':opt.loss_oproxy_neg_delta} 44 | 45 | self.learn_hyper = opt.loss_oproxy_learn_hyper 46 | if self.learn_hyper: 47 | self.pars = torch.nn.ParameterDict(self.pars) 48 | self.optim_dict_list.append({'params':self.pars, 'lr':opt.lr*opt.loss_oproxy_lrmulti_hyper}) 49 | 50 | ### 51 | self.mode = opt.loss_oproxy_mode 52 | self.detach_proxies = opt.loss_oproxy_detach_proxies 53 | self.euclidean = opt.loss_oproxy_euclidean 54 | self.d_mode = 'euclidean' if self.euclidean else 'cosine' 55 | 56 | ### 57 | self.delta_flip = opt.loss_oproxy_delta_flip 58 | self.prob_mode = opt.loss_oproxy_prob_mode 59 | self.nca_clean = opt.loss_oproxy_nca_clean 60 | self.msim_style = opt.loss_oproxy_msim_style 61 | self.unique = opt.loss_oproxy_unique 62 | 63 | def prep(self, thing): 64 | return 1.*torch.nn.functional.normalize(thing, dim=1) 65 | 66 | 67 | def forward(self, batch, labels, aux_batch=None): 68 | """ 69 | Args: 70 | batch: torch.Tensor: Input of embeddings with size (BS x DIM) 71 | labels: nparray/list: For each element of the batch assigns a class [0,...,C-1], shape: (BS x 1) 72 | """ 73 | ### 74 | bs = len(batch) 75 | batch = self.prep(batch) 76 | self.labels = labels.unsqueeze(1) 77 | 78 | ### 79 | if self.unique: 80 | self.u_labels = torch.unique(self.labels.view(-1)) 81 | else: 82 | self.u_labels, self.freq = self.labels.view(-1), 
None 83 | self.same_labels = (self.labels.T == self.u_labels.view(-1,1)).to(batch.device).T 84 | self.diff_labels = (self.class_idxs.unsqueeze(1) != self.labels.T).to(torch.float).to(batch.device).T 85 | if self.prob_mode: self.diff_labels = torch.ones_like(self.diff_labels).to(torch.float).to(batch.device) 86 | 87 | ### 88 | if self.mode == "anchor": 89 | self.dim = 0 90 | elif self.mode == "nca": 91 | self.dim = 1 92 | 93 | ### 94 | loss = self.compute_proxyloss(batch, detach_proxies=self.detach_proxies) 95 | 96 | ### 97 | return loss 98 | 99 | ### 100 | def compute_proxyloss(self, batch, detach_proxies=False): 101 | proxies = self.prep(self.proxies) 102 | if detach_proxies: proxies = proxies.detach() 103 | pars = {k:-p if self.euclidean and 'alpha' in k else p for k,p in self.pars.items()} 104 | ### 105 | pos_sims = self.smat(batch, proxies[self.u_labels], mode=self.d_mode) 106 | sims = self.smat(batch, proxies, mode=self.d_mode) 107 | ### 108 | w_pos_sims = -pars['pos_alpha']*(pos_sims-pars['pos_delta']) 109 | w_neg_sims = pars['neg_alpha']*(sims-pars['neg_delta']) 110 | ### 111 | pos_s = self.masked_logsumexp(w_pos_sims,mask=self.same_labels,dim=self.dim,max=True if self.d_mode=='euclidean' else False) 112 | neg_s = self.masked_logsumexp(w_neg_sims,mask=self.diff_labels,dim=self.dim,max=False if self.d_mode=='euclidean' else True) 113 | 114 | if not self.nca_clean: 115 | pos_s = torch.nn.Softplus()(pos_s) 116 | neg_s = torch.nn.Softplus()(neg_s) 117 | 118 | if self.msim_style: 119 | pos_s = 1/pars['pos_alpha']*pos_s 120 | neg_s = 1/pars['neg_alpha']*neg_s 121 | 122 | pos_s, neg_s = pos_s.mean(), neg_s.mean() 123 | loss = pos_s + neg_s 124 | return loss 125 | 126 | ### 127 | def smat(self, A, B, mode='cosine'): 128 | if mode=='cosine': 129 | return A.mm(B.T) 130 | elif mode=='euclidean': 131 | As, Bs = A.shape, B.shape 132 | return (A.view(As[0],1,As[1])-B.view(1,Bs[0],Bs[1])).pow(2).sum(-1).sqrt() 133 | 134 | ### 135 | def masked_logsumexp(self, sims, dim=0, mask=None, max=True): 136 | if mask is None: 137 | return torch.logsumexp(sims, dim=dim) 138 | else: 139 | if not max: 140 | ref_v = (sims*mask).min(dim=dim, keepdim=True)[0] 141 | else: 142 | ref_v = (sims*mask).max(dim=dim, keepdim=True)[0] 143 | 144 | nz_entries = (sims*mask) 145 | nz_entries = nz_entries.max(dim=dim,keepdim=True)[0]+nz_entries.min(dim=dim,keepdim=True)[0] 146 | nz_entries = torch.where(nz_entries.view(-1))[0].view(-1) 147 | 148 | if not len(nz_entries): 149 | return torch.tensor(0).to(torch.float).to(sims.device) 150 | else: 151 | return torch.log((torch.sum(torch.exp(sims-ref_v.detach())*mask,dim=dim)).view(-1)[nz_entries])+ref_v.detach().view(-1)[nz_entries] 152 | 153 | # return torch.log((torch.sum(torch.exp(sims)*mask,dim=dim)).view(-1))[nz_entries] 154 | -------------------------------------------------------------------------------- /criteria/contrastive.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch, torch.nn as nn, torch.nn.functional as F 3 | import batchminer 4 | 5 | """=================================================================================================""" 6 | ALLOWED_MINING_OPS = list(batchminer.BATCHMINING_METHODS.keys()) 7 | REQUIRES_BATCHMINER = True 8 | REQUIRES_OPTIM = False 9 | 10 | ### Standard Triplet Loss, finds triplets in Mini-batches. 11 | class Criterion(torch.nn.Module): 12 | def __init__(self, opt, batchminer): 13 | """ 14 | Args: 15 | margin: Triplet Margin. 
16 | """ 17 | super(Criterion, self).__init__() 18 | self.pos_margin = opt.loss_contrastive_pos_margin 19 | self.neg_margin = opt.loss_contrastive_neg_margin 20 | self.batchminer = batchminer 21 | 22 | self.name = 'contrastive' 23 | 24 | #### 25 | self.ALLOWED_MINING_OPS = ALLOWED_MINING_OPS 26 | self.REQUIRES_BATCHMINER = REQUIRES_BATCHMINER 27 | self.REQUIRES_OPTIM = REQUIRES_OPTIM 28 | 29 | def forward(self, batch, labels, **kwargs): 30 | """ 31 | Args: 32 | batch: torch.Tensor: Input of embeddings with size (BS x DIM) 33 | labels: nparray/list: For each element of the batch assigns a class [0,...,C-1], shape: (BS x 1) 34 | """ 35 | sampled_triplets = self.batchminer(batch, labels) 36 | 37 | anchors = [triplet[0] for triplet in sampled_triplets] 38 | positives = [triplet[1] for triplet in sampled_triplets] 39 | negatives = [triplet[2] for triplet in sampled_triplets] 40 | 41 | pos_dists = torch.mean(F.relu(nn.PairwiseDistance(p=2)(batch[anchors,:], batch[positives,:]) - self.pos_margin)) 42 | neg_dists = torch.mean(F.relu(self.neg_margin - nn.PairwiseDistance(p=2)(batch[anchors,:], batch[negatives,:]))) 43 | 44 | loss = pos_dists + neg_dists 45 | 46 | return loss 47 | -------------------------------------------------------------------------------- /criteria/fast_moco.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch, torch.nn as nn, torch.nn.functional as F 3 | import batchminer 4 | from tqdm import tqdm 5 | 6 | """=================================================================================================""" 7 | ALLOWED_MINING_OPS = list(batchminer.BATCHMINING_METHODS.keys()) 8 | REQUIRES_BATCHMINER = False 9 | REQUIRES_OPTIM = True 10 | REQUIRES_EMA_NETWORK = True 11 | 12 | ### MarginLoss with trainable class separation margin beta. Runs on Mini-batches as well. 13 | class Criterion(torch.nn.Module): 14 | def __init__(self, opt): 15 | """ 16 | Args: 17 | margin: Triplet Margin. 18 | nu: Regularisation Parameter for beta values if they are learned. 19 | beta: Class-Margin values. 20 | n_classes: Number of different classes during training. 
21 | """ 22 | super(Criterion, self).__init__() 23 | 24 | self.temperature = opt.diva_moco_temperature 25 | self.momentum = opt.diva_moco_momentum 26 | self.n_key_batches = opt.diva_moco_n_key_batches 27 | 28 | 29 | if opt.diva_moco_trainable_temp: 30 | self.temperature = torch.nn.Parameter(torch.tensor(self.temperature).to(torch.float)) 31 | self.lr = opt.diva_moco_temp_lr 32 | 33 | self.name = 'fast_moco' 34 | self.reference_labels = torch.zeros(opt.bs).to(torch.long).to(opt.device) 35 | 36 | self.lower_cutoff = opt.diva_moco_lower_cutoff 37 | self.upper_cutoff = opt.diva_moco_upper_cutoff 38 | 39 | #### 40 | self.ALLOWED_MINING_OPS = ALLOWED_MINING_OPS 41 | self.REQUIRES_BATCHMINER = REQUIRES_BATCHMINER 42 | self.REQUIRES_OPTIM = REQUIRES_OPTIM 43 | 44 | 45 | def update_memory_queue(self, embeddings): 46 | self.memory_queue = self.memory_queue[len(embeddings):,:] 47 | self.memory_queue = torch.cat([self.memory_queue, embeddings], dim=0) 48 | 49 | def create_memory_queue(self, model, dataloader, device, opt_key=None): 50 | with torch.no_grad(): 51 | _ = model.eval() 52 | _ = model.to(device) 53 | 54 | self.memory_queue = [] 55 | counter = 0 56 | load_count = 0 57 | total_count = self.n_key_batches//len(dataloader) + int(self.n_key_batches%len(dataloader)!=0) 58 | while counter=self.n_key_batches: 74 | break 75 | 76 | self.memory_queue = torch.cat(self.memory_queue, dim=0).to(device) 77 | 78 | self.n_keys = len(self.memory_queue) 79 | 80 | def shuffleBN(self, bs): 81 | forward_inds = torch.randperm(bs).long().cuda() 82 | backward_inds = torch.zeros(bs).long().cuda() 83 | value = torch.arange(bs).long().cuda() 84 | backward_inds.index_copy_(0, forward_inds, value) 85 | return forward_inds, backward_inds 86 | 87 | 88 | def forward(self, query_batch, key_batch, **kwargs): 89 | """ 90 | Args: 91 | batch: torch.Tensor: Input of embeddings with size (BS x DIM) 92 | labels: nparray/list: For each element of the batch assigns a class [0,...,C-1], shape: (BS x 1) 93 | """ 94 | bs = len(query_batch) 95 | 96 | l_pos = query_batch.view(bs, 1, -1).bmm(key_batch.view(bs, -1, 1)).squeeze(-1) 97 | l_neg = query_batch.view(bs, -1).mm(self.memory_queue.T) 98 | 99 | ### Compute Distance Matrix 100 | bs,dim = len(query_batch),query_batch.shape[-1] 101 | 102 | ab = torch.mm(query_batch, self.memory_queue.T).detach() 103 | a2 = torch.nn.CosineSimilarity()(query_batch, query_batch).unsqueeze(1).expand_as(ab).detach() 104 | b2 = torch.nn.CosineSimilarity()(self.memory_queue, self.memory_queue).unsqueeze(0).expand_as(ab).detach() 105 | #Euclidean Distances 106 | distance_weighting = (-2.*ab+a2+b2).clamp(min=0).sqrt() 107 | distances = (-2.*ab+a2+b2).clamp(min=0).sqrt().clamp(min=self.lower_cutoff) 108 | 109 | #Likelihood Weighting 110 | distance_weighting = ((2.0 - float(dim)) * torch.log(distances) - (float(dim-3) / 2) * torch.log(1.0 - 0.25 * (distances.pow(2)))) 111 | distance_weighting = torch.exp(distance_weighting - torch.max(distance_weighting)) 112 | distance_weighting[distances>self.upper_cutoff] = 0 113 | distance_weighting = distance_weighting.clamp(min=1e-45) 114 | distance_weighting = distance_weighting/torch.sum(distance_weighting, dim=0) 115 | 116 | ### 117 | l_neg = l_neg*distance_weighting 118 | 119 | ### INCLUDE SHUFFLE BN 120 | logits = torch.cat([l_pos, l_neg], dim=1) 121 | 122 | if isinstance(self.temperature, torch.Tensor): 123 | loss = torch.nn.CrossEntropyLoss()(logits/self.temperature.clamp(min=1e-8, max=1e4), self.reference_labels) 124 | else: 125 | loss = 
torch.nn.CrossEntropyLoss()(logits/self.temperature, self.reference_labels) 126 | 127 | return loss 128 | -------------------------------------------------------------------------------- /criteria/imrot.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch, torch.nn as nn, torch.nn.functional as F 3 | import batchminer 4 | from tqdm import tqdm 5 | 6 | 7 | """=================================================================================================""" 8 | ALLOWED_MINING_OPS = list(batchminer.BATCHMINING_METHODS.keys()) 9 | REQUIRES_BATCHMINER = False 10 | REQUIRES_OPTIM = True 11 | REQUIRES_EMA_NETWORK = True 12 | 13 | ### MarginLoss with trainable class separation margin beta. Runs on Mini-batches as well. 14 | class Criterion(torch.nn.Module): 15 | def __init__(self, opt): 16 | """ 17 | Args: 18 | margin: Triplet Margin. 19 | nu: Regularisation Parameter for beta values if they are learned. 20 | beta: Class-Margin values. 21 | n_classes: Number of different classes during training. 22 | """ 23 | super(Criterion, self).__init__() 24 | self.classifier = torch.nn.Linear(opt.network_feature_dim, 4, bias=False).to(opt.device) 25 | self.lr = opt.lr * 10 26 | self.name = 'imrot' 27 | 28 | #### 29 | self.ALLOWED_MINING_OPS = ALLOWED_MINING_OPS 30 | self.REQUIRES_BATCHMINER = REQUIRES_BATCHMINER 31 | self.REQUIRES_OPTIM = REQUIRES_OPTIM 32 | 33 | def forward(self, feature_batch, imrot_labels, **kwargs): 34 | """ 35 | Args: 36 | batch: torch.Tensor: Input of embeddings with size (BS x DIM) 37 | labels: nparray/list: For each element of the batch assigns a class [0,...,C-1], shape: (BS x 1) 38 | """ 39 | pred_batch = self.classifier(feature_batch) 40 | loss = torch.nn.CrossEntropyLoss()(pred_batch, imrot_labels.to(pred_batch.device)) 41 | return loss 42 | -------------------------------------------------------------------------------- /criteria/margin.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch, torch.nn as nn, torch.nn.functional as F 3 | import batchminer 4 | 5 | """=================================================================================================""" 6 | ALLOWED_MINING_OPS = list(batchminer.BATCHMINING_METHODS.keys()) 7 | REQUIRES_BATCHMINER = True 8 | REQUIRES_OPTIM = True 9 | 10 | ### MarginLoss with trainable class separation margin beta. Runs on Mini-batches as well. 11 | class Criterion(torch.nn.Module): 12 | def __init__(self, opt, batchminer): 13 | """ 14 | Args: 15 | margin: Triplet Margin. 16 | nu: Regularisation Parameter for beta values if they are learned. 17 | beta: Class-Margin values. 18 | n_classes: Number of different classes during training. 
19 | """ 20 | super(Criterion, self).__init__() 21 | self.n_classes = opt.n_classes 22 | 23 | self.margin = opt.loss_margin_margin 24 | self.nu = opt.loss_margin_nu 25 | self.beta_constant = opt.loss_margin_beta_constant 26 | self.beta_val = opt.loss_margin_beta 27 | 28 | if opt.loss_margin_beta_constant: 29 | self.beta = opt.loss_margin_beta 30 | else: 31 | self.beta = torch.nn.Parameter(torch.ones(opt.n_classes)*opt.loss_margin_beta) 32 | 33 | self.batchminer = batchminer 34 | 35 | self.name = 'margin' 36 | 37 | self.lr = opt.loss_margin_beta_lr 38 | 39 | #### 40 | self.ALLOWED_MINING_OPS = ALLOWED_MINING_OPS 41 | self.REQUIRES_BATCHMINER = REQUIRES_BATCHMINER 42 | self.REQUIRES_OPTIM = REQUIRES_OPTIM 43 | 44 | def forward(self, batch, labels, **kwargs): 45 | """ 46 | Args: 47 | batch: torch.Tensor: Input of embeddings with size (BS x DIM) 48 | labels: nparray/list: For each element of the batch assigns a class [0,...,C-1], shape: (BS x 1) 49 | """ 50 | sampled_triplets = self.batchminer(batch, labels) 51 | 52 | if len(sampled_triplets): 53 | d_ap, d_an = [],[] 54 | for triplet in sampled_triplets: 55 | train_triplet = {'Anchor': batch[triplet[0],:], 'Positive':batch[triplet[1],:], 'Negative':batch[triplet[2]]} 56 | 57 | pos_dist = ((train_triplet['Anchor']-train_triplet['Positive']).pow(2).sum()+1e-8).pow(1/2) 58 | neg_dist = ((train_triplet['Anchor']-train_triplet['Negative']).pow(2).sum()+1e-8).pow(1/2) 59 | 60 | d_ap.append(pos_dist) 61 | d_an.append(neg_dist) 62 | d_ap, d_an = torch.stack(d_ap), torch.stack(d_an) 63 | 64 | if self.beta_constant: 65 | beta = self.beta 66 | else: 67 | beta = torch.stack([self.beta[labels[triplet[0]]] for triplet in sampled_triplets]).to(torch.float).to(d_ap.device) 68 | 69 | pos_loss = torch.nn.functional.relu(d_ap-beta+self.margin) 70 | neg_loss = torch.nn.functional.relu(beta-d_an+self.margin) 71 | 72 | pair_count = torch.sum((pos_loss>0.)+(neg_loss>0.)).to(torch.float).to(d_ap.device) 73 | 74 | if pair_count == 0.: 75 | loss = torch.sum(pos_loss+neg_loss) 76 | else: 77 | loss = torch.sum(pos_loss+neg_loss)/pair_count 78 | 79 | if self.nu: loss = loss + beta_regularisation_loss.to(torch.float).to(d_ap.device) 80 | else: 81 | loss = torch.tensor(0.).to(torch.float).to(batch.device) 82 | 83 | return loss 84 | -------------------------------------------------------------------------------- /criteria/moco.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch, torch.nn as nn, torch.nn.functional as F 3 | import batchminer 4 | from tqdm import tqdm 5 | 6 | 7 | """=================================================================================================""" 8 | ALLOWED_MINING_OPS = list(batchminer.BATCHMINING_METHODS.keys()) 9 | REQUIRES_BATCHMINER = False 10 | REQUIRES_OPTIM = True 11 | REQUIRES_EMA_NETWORK = True 12 | 13 | ### MarginLoss with trainable class separation margin beta. Runs on Mini-batches as well. 14 | class Criterion(torch.nn.Module): 15 | def __init__(self, opt): 16 | """ 17 | Args: 18 | margin: Triplet Margin. 19 | nu: Regularisation Parameter for beta values if they are learned. 20 | beta: Class-Margin values. 21 | n_classes: Number of different classes during training. 
22 | """ 23 | super(Criterion, self).__init__() 24 | 25 | #### 26 | self.ALLOWED_MINING_OPS = ALLOWED_MINING_OPS 27 | self.REQUIRES_BATCHMINER = REQUIRES_BATCHMINER 28 | self.REQUIRES_OPTIM = REQUIRES_OPTIM 29 | 30 | #### 31 | self.temperature = opt.diva_moco_temperature 32 | if opt.diva_moco_trainable_temp: 33 | self.temperature = torch.nn.Parameter(torch.tensor(self.temperature).to(torch.float)) 34 | 35 | self.lr = opt.diva_moco_temp_lr 36 | self.momentum = opt.diva_moco_momentum 37 | self.n_key_batches = opt.diva_moco_n_key_batches 38 | 39 | self.name = 'moco' 40 | self.reference_labels = torch.zeros(opt.bs).to(torch.long).to(opt.device) 41 | 42 | #### 43 | self.ALLOWED_MINING_OPS = ALLOWED_MINING_OPS 44 | self.REQUIRES_BATCHMINER = REQUIRES_BATCHMINER 45 | self.REQUIRES_OPTIM = REQUIRES_OPTIM 46 | 47 | 48 | def update_memory_queue(self, embeddings): 49 | self.memory_queue = self.memory_queue[len(embeddings):,:] 50 | self.memory_queue = torch.cat([self.memory_queue, embeddings], dim=0) 51 | 52 | def create_memory_queue(self, model, dataloader, device, opt_key=None): 53 | with torch.no_grad(): 54 | _ = model.eval() 55 | _ = model.to(device) 56 | 57 | self.memory_queue = [] 58 | counter = 0 59 | load_count = 0 60 | total_count = self.n_key_batches//len(dataloader) + int(self.n_key_batches%len(dataloader)!=0) 61 | while counter=self.n_key_batches: 76 | break 77 | 78 | self.memory_queue = torch.cat(self.memory_queue, dim=0).to(device) 79 | 80 | self.n_keys = len(self.memory_queue) 81 | 82 | def shuffleBN(self, bs): 83 | forward_inds = torch.randperm(bs).long().cuda() 84 | backward_inds = torch.zeros(bs).long().cuda() 85 | value = torch.arange(bs).long().cuda() 86 | backward_inds.index_copy_(0, forward_inds, value) 87 | return forward_inds, backward_inds 88 | 89 | 90 | def forward(self, query_batch, key_batch, **kwargs): 91 | """ 92 | Args: 93 | batch: torch.Tensor: Input of embeddings with size (BS x DIM) 94 | labels: nparray/list: For each element of the batch assigns a class [0,...,C-1], shape: (BS x 1) 95 | """ 96 | bs = len(query_batch) 97 | 98 | l_pos = query_batch.view(bs, 1, -1).bmm(key_batch.view(bs, -1, 1)).squeeze(-1) 99 | l_neg = query_batch.view(bs, -1).mm(self.memory_queue.T) 100 | 101 | ### INCLUDE SHUFFLE BN 102 | logits = torch.cat([l_pos, l_neg], dim=1) 103 | 104 | if isinstance(self.temperature, torch.Tensor): 105 | loss = torch.nn.CrossEntropyLoss()(logits/self.temperature.clamp(min=1e-8, max=1e4), self.reference_labels) 106 | else: 107 | loss = torch.nn.CrossEntropyLoss()(logits/self.temperature, self.reference_labels) 108 | 109 | return loss 110 | -------------------------------------------------------------------------------- /criteria/multisimilarity.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import copy 3 | import torch, torch.nn as nn, torch.nn.functional as F 4 | import batchminer 5 | import copy 6 | 7 | 8 | """=================================================================================================""" 9 | ALLOWED_MINING_OPS = None 10 | REQUIRES_BATCHMINER = False 11 | REQUIRES_OPTIM = False 12 | 13 | ### MarginLoss with trainable class separation margin beta. Runs on Mini-batches as well. 14 | class Criterion(torch.nn.Module): 15 | def __init__(self, opt, **kwargs): 16 | """ 17 | Args: 18 | margin: Triplet Margin. 19 | nu: Regularisation Parameter for beta values if they are learned. 20 | beta: Class-Margin values. 21 | n_classes: Number of different classes during training. 
22 | """ 23 | super(Criterion, self).__init__() 24 | self.pars = opt 25 | 26 | 27 | self.n_classes = opt.n_classes 28 | 29 | self.pos_weight = opt.loss_multisimilarity_pos_weight 30 | self.neg_weight = opt.loss_multisimilarity_neg_weight 31 | self.margin = opt.loss_multisimilarity_margin 32 | self.pos_thresh = opt.loss_multisimilarity_pos_thresh 33 | self.neg_thresh = opt.loss_multisimilarity_neg_thresh 34 | self.d_mode = opt.loss_multisimilarity_d_mode 35 | self.name = 'multisimilarity' 36 | 37 | self.lr = opt.lr 38 | 39 | #### 40 | self.ALLOWED_MINING_OPS = ALLOWED_MINING_OPS 41 | self.REQUIRES_BATCHMINER = REQUIRES_BATCHMINER 42 | self.REQUIRES_OPTIM = REQUIRES_OPTIM 43 | 44 | 45 | def forward(self, batch, labels, **kwargs): 46 | """ 47 | Args: 48 | batch: torch.Tensor: Input of embeddings with size (BS x DIM) 49 | labels: nparray/list: For each element of the batch assigns a class [0,...,C-1], shape: (BS x 1) 50 | """ 51 | bs = len(batch) 52 | self.dim = 0 53 | self.embed_dim = batch.shape[-1] 54 | self.similarity = self.smat(batch, batch, self.d_mode) 55 | 56 | ### 57 | if self.d_mode=='euclidean': 58 | pos_weight = -1.*self.pos_weight 59 | neg_weight = -1.*self.neg_weight 60 | else: 61 | pos_weight = self.pos_weight 62 | neg_weight = self.neg_weight 63 | 64 | ### 65 | w_pos_sims = -pos_weight*(self.similarity-self.pos_thresh) 66 | w_neg_sims = neg_weight*(self.similarity-self.neg_thresh) 67 | 68 | ### 69 | labels = labels.unsqueeze(1) 70 | self.bsame_labels = (labels.T == labels.view(-1,1)).to(batch.device).T 71 | self.bdiff_labels = (labels.T != labels.view(-1,1)).to(batch.device).T 72 | 73 | ### Compute MultiSimLoss 74 | pos_mask, neg_mask = self.sample_mask(self.similarity) 75 | self.pos_mask, self.neg_mask = pos_mask, neg_mask 76 | 77 | pos_s = self.masked_logsumexp(w_pos_sims, mask=pos_mask, dim=self.dim, max=True if self.d_mode=='euclidean' else False) 78 | neg_s = self.masked_logsumexp(w_neg_sims, mask=neg_mask, dim=self.dim, max=False if self.d_mode=='euclidean' else True) 79 | 80 | ### 81 | pos_s, neg_s = 1./np.abs(pos_weight)*torch.nn.Softplus()(pos_s), 1./np.abs(neg_weight)*torch.nn.Softplus()(neg_s) 82 | pos_s, neg_s = pos_s.mean(), neg_s.mean() 83 | loss = pos_s + neg_s 84 | 85 | 86 | return loss 87 | 88 | 89 | ### 90 | def sample_mask(self, sims): 91 | ### Get Indices/Sampling Bounds 92 | bsame_labels = copy.deepcopy(self.bsame_labels) 93 | bdiff_labels = copy.deepcopy(self.bdiff_labels) 94 | pos_bound, neg_bound = [], [] 95 | bound = [] 96 | for i in range(len(sims)): 97 | pos_ixs = bsame_labels[i] 98 | neg_ixs = bdiff_labels[i] 99 | pos_ixs[i] = False 100 | pos_bsims = self.similarity[i][pos_ixs] 101 | neg_bsims = self.similarity[i][neg_ixs] 102 | if self.d_mode=='euclidean': 103 | pos_bound.append(pos_bsims.max()) 104 | neg_bound.append(neg_bsims.min()) 105 | else: 106 | pos_bound.append(pos_bsims.min()) 107 | neg_bound.append(neg_bsims.max()) 108 | pos_bound, neg_bound = torch.stack(pos_bound), torch.stack(neg_bound) 109 | ### Get LogSumExp-Masks 110 | if self.d_mode=='euclidean': 111 | self.neg_mask = neg_mask = self.bdiff_labels*((self.similarity - self.margin) < pos_bound) 112 | self.pos_mask = pos_mask = self.bsame_labels*((self.similarity + self.margin) > neg_bound) 113 | else: 114 | self.neg_mask = neg_mask = self.bdiff_labels*((self.similarity + self.margin) > pos_bound) 115 | self.pos_mask = pos_mask = self.bsame_labels*((self.similarity - self.margin) < neg_bound) 116 | 117 | return pos_mask, neg_mask 118 | 119 | 120 | ### 121 | def smat(self, A, B, 
mode='cosine'): 122 | if mode=='cosine': 123 | return A.mm(B.T) 124 | elif mode=='euclidean': 125 | return (A.mm(A.T).diag().unsqueeze(-1)+B.mm(B.T).diag().unsqueeze(0)-2*A.mm(B.T)).clamp(min=1e-20).sqrt() 126 | 127 | 128 | ### 129 | def masked_logsumexp(self, sims, dim=0, mask=None, max=True): 130 | if mask is None: 131 | return torch.logsumexp(sims, dim=dim) 132 | else: 133 | if not max: 134 | ref_v = (sims*mask).min(dim=dim, keepdim=True)[0] 135 | else: 136 | ref_v = (sims*mask).max(dim=dim, keepdim=True)[0] 137 | 138 | nz_entries = (sims*mask) 139 | nz_entries = nz_entries.max(dim=dim,keepdim=True)[0]+nz_entries.min(dim=dim,keepdim=True)[0] 140 | nz_entries = torch.where(nz_entries.view(-1))[0].view(-1) 141 | 142 | if not len(nz_entries): 143 | return torch.tensor(0).to(torch.float).to(sims.device) 144 | else: 145 | return torch.log((torch.sum(torch.exp(sims-ref_v.detach())*mask,dim=dim)).view(-1)[nz_entries])+ref_v.detach().view(-1)[nz_entries] 146 | -------------------------------------------------------------------------------- /criteria/oproxy.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch, torch.nn as nn, torch.nn.functional as F 3 | import batchminer 4 | import criteria 5 | 6 | """=================================================================================================""" 7 | ALLOWED_MINING_OPS = None 8 | REQUIRES_BATCHMINER = False 9 | REQUIRES_OPTIM = True 10 | 11 | 12 | class Criterion(torch.nn.Module): 13 | def __init__(self, opt): 14 | """ 15 | Args: 16 | opt: Namespace containing all relevant parameters. 17 | """ 18 | super(Criterion, self).__init__() 19 | 20 | self.pars = opt 21 | 22 | #### 23 | self.ALLOWED_MINING_OPS = ALLOWED_MINING_OPS 24 | self.REQUIRES_BATCHMINER = REQUIRES_BATCHMINER 25 | self.REQUIRES_OPTIM = REQUIRES_OPTIM 26 | 27 | #### 28 | self.num_proxies = opt.n_classes 29 | self.embed_dim = opt.embed_dim 30 | 31 | self.proxies = torch.randn(self.num_proxies, self.embed_dim)/8 32 | self.proxies = torch.nn.Parameter(self.proxies) 33 | proxy_optim_dict = {'params':self.proxies, 'lr':opt.lr * opt.loss_oproxy_lrmulti} 34 | 35 | self.optim_dict_list = [] 36 | self.optim_dict_list.append(proxy_optim_dict) 37 | 38 | ### 39 | self.class_idxs = torch.arange(self.num_proxies) 40 | 41 | self.name = 'oproxy' 42 | 43 | pars = {'pos_alpha':opt.loss_oproxy_pos_alpha, 44 | 'pos_delta':opt.loss_oproxy_pos_delta, 45 | 'neg_alpha':opt.loss_oproxy_neg_alpha, 46 | 'neg_delta':opt.loss_oproxy_neg_delta} 47 | self.pars = pars 48 | 49 | ### 50 | self.mode = opt.loss_oproxy_mode 51 | self.detach_proxies = opt.loss_oproxy_detach_proxies 52 | self.euclidean = opt.loss_oproxy_euclidean 53 | self.d_mode = 'euclidean' if self.euclidean else 'cosine' 54 | 55 | ### 56 | self.f_soft = torch.nn.Softplus() 57 | self.optim_dict_list.append({'params':self.f_soft.parameters(), 'lr':opt.lr*opt.loss_oproxy_lrmulti}) 58 | 59 | ### 60 | self.warmup_it = opt.loss_oproxy_warmup_it 61 | self.it_count = 0 62 | 63 | def prep(self, thing): 64 | return 1.*torch.nn.functional.normalize(thing, dim=1) 65 | 66 | 67 | def forward(self, batch, labels, **kwargs): 68 | """ 69 | Args: 70 | batch: torch.Tensor: Input of embeddings with size (BS x DIM) 71 | labels: nparray/list: For each element of the batch assigns a class [0,...,C-1], shape: (BS x 1) 72 | """ 73 | ### 74 | bs = len(batch) 75 | batch = self.prep(batch) 76 | 77 | self.labels = labels.unsqueeze(1) 78 | 79 | ### 80 | self.u_labels, self.freq = self.labels.view(-1), 
None 81 | self.same_labels = (self.labels.T == self.u_labels.view(-1,1)).to(batch.device).T 82 | self.diff_labels = (self.class_idxs.unsqueeze(1) != self.labels.T).to(torch.float).to(batch.device).T 83 | 84 | ### 85 | if self.mode == "anchor": 86 | self.dim = 0 87 | elif self.mode == "nca": 88 | self.dim = 1 89 | 90 | ### 91 | loss = self.compute_proxyloss(batch, detach_proxies=self.detach_proxies) 92 | self.it_count += 1 93 | 94 | ### 95 | return loss 96 | 97 | ### 98 | def compute_proxyloss(self, batch, detach_proxies=False): 99 | proxies = self.prep(self.proxies) 100 | if detach_proxies: proxies = proxies.detach() 101 | pars = {k:-p if self.euclidean and 'alpha' in k else p for k,p in self.pars.items()} 102 | ### 103 | pos_sims = self.smat(batch, proxies[self.u_labels], mode=self.d_mode) 104 | sims = self.smat(batch, proxies, mode=self.d_mode) 105 | ### 106 | w_pos_sims = -pars['pos_alpha']*(pos_sims-pars['pos_delta']) 107 | w_neg_sims = pars['neg_alpha']*(sims-pars['neg_delta']) 108 | ### 109 | # self.label_smooth = 0.1 110 | # same_labs = utils.one_hot(labels_spt.reshape(-1), self.num_proxies) 111 | # same_labs = same_labs * (1 - self.label_smoot) + (1 - same_labs) * self.label_smoot / (self.num_proxies - 1) 112 | 113 | # pos_s = self.masked_logsumexp(w_pos_sims,mask=self.label_smooth,dim=self.dim,max=True if self.d_mode=='euclidean' else False) 114 | pos_s = self.masked_logsumexp(w_pos_sims,mask=self.same_labels,dim=self.dim,max=True if self.d_mode=='euclidean' else False) 115 | neg_s = self.masked_logsumexp(w_neg_sims,mask=self.diff_labels,dim=self.dim,max=False if self.d_mode=='euclidean' else True) 116 | 117 | pos_s = self.f_soft(pos_s) 118 | neg_s = self.f_soft(neg_s) 119 | 120 | pos_s, neg_s = pos_s.mean(), neg_s.mean() 121 | loss = pos_s + neg_s 122 | return loss 123 | 124 | ### 125 | def smat(self, A, B, mode='cosine'): 126 | if mode=='cosine': 127 | return A.mm(B.T) 128 | elif mode=='euclidean': 129 | As, Bs = A.shape, B.shape 130 | return (A.mm(A.T).diag().unsqueeze(-1)+B.mm(B.T).diag().unsqueeze(0)-2*A.mm(B.T)).clamp(min=1e-20).sqrt() 131 | 132 | ### 133 | def masked_logsumexp(self, sims, dim=0, mask=None, max=True): 134 | if mask is None: 135 | return torch.logsumexp(sims, dim=dim) 136 | else: 137 | if not max: 138 | ref_v = (sims*mask).min(dim=dim, keepdim=True)[0] 139 | else: 140 | ref_v = (sims*mask).max(dim=dim, keepdim=True)[0] 141 | 142 | nz_entries = (sims*mask) 143 | nz_entries = nz_entries.max(dim=dim,keepdim=True)[0]+nz_entries.min(dim=dim,keepdim=True)[0] 144 | nz_entries = torch.where(nz_entries.view(-1))[0].view(-1) 145 | 146 | if not len(nz_entries): 147 | return torch.tensor(0).to(torch.float).to(sims.device) 148 | else: 149 | return torch.log((torch.sum(torch.exp(sims-ref_v.detach())*mask,dim=dim)).view(-1)[nz_entries])+ref_v.detach().view(-1)[nz_entries] 150 | 151 | # return torch.log((torch.sum(torch.exp(sims)*mask,dim=dim)).view(-1))[nz_entries] 152 | -------------------------------------------------------------------------------- /criteria/proxynca.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch, torch.nn as nn, torch.nn.functional as F 3 | import batchminer 4 | 5 | 6 | """=================================================================================================""" 7 | ALLOWED_MINING_OPS = None 8 | REQUIRES_BATCHMINER = False 9 | REQUIRES_OPTIM = True 10 | 11 | 12 | class Criterion(torch.nn.Module): 13 | def __init__(self, opt): 14 | """ 15 | Args: 16 | opt: Namespace 
containing all relevant parameters. 17 | """ 18 | super(Criterion, self).__init__() 19 | 20 | #### 21 | self.ALLOWED_MINING_OPS = ALLOWED_MINING_OPS 22 | self.REQUIRES_BATCHMINER = REQUIRES_BATCHMINER 23 | self.REQUIRES_OPTIM = REQUIRES_OPTIM 24 | 25 | #### 26 | self.num_proxies = opt.n_classes 27 | self.embed_dim = opt.embed_dim 28 | 29 | self.proxies = torch.nn.Parameter(torch.randn(self.num_proxies, self.embed_dim)/8) 30 | self.class_idxs = torch.arange(self.num_proxies) 31 | 32 | self.name = 'proxynca' 33 | 34 | self.lr = opt.lr * opt.loss_proxynca_lrmulti 35 | 36 | self.sphereradius = opt.loss_proxynca_sphereradius 37 | self.T = opt.loss_proxynca_temperature 38 | self.convert_to_p = opt.loss_proxynca_convert_to_p 39 | self.cosine = opt.loss_proxynca_cosine_dist 40 | self.sq_dist = opt.loss_proxynca_sq_dist 41 | 42 | 43 | def forward(self, batch, labels, **kwargs): 44 | """ 45 | Args: 46 | batch: torch.Tensor: Input of embeddings with size (BS x DIM) 47 | labels: nparray/list: For each element of the batch assigns a class [0,...,C-1], shape: (BS x 1) 48 | """ 49 | #Empirically, multiplying the embeddings during the computation of the loss seem to allow for more stable training; 50 | #presumably due to increased loss value. 51 | batch = self.sphereradius*torch.nn.functional.normalize(batch, dim=1) 52 | proxies = self.sphereradius*torch.nn.functional.normalize(self.proxies, dim=1) 53 | 54 | #Loss based on distance to positive proxies 55 | if self.cosine: 56 | dist_to_pos_proxies = batch.unsqueeze(1).bmm(proxies[labels].unsqueeze(2)).squeeze(-1).squeeze(-1) 57 | else: 58 | if self.sq_dist: 59 | dist_to_pos_proxies = -(batch-proxies[labels]).pow(2).sum(-1).sqrt() 60 | else: 61 | dist_to_pos_proxies = -(batch-proxies[labels]).pow(2).sum(-1) 62 | 63 | loss_pos = torch.mean(-dist_to_pos_proxies/self.T) 64 | 65 | #Loss based on distance to negative (or all) proxies 66 | if not self.convert_to_p: 67 | batch_neg_idxs = labels.unsqueeze(1) != self.class_idxs.unsqueeze(1).T 68 | else: 69 | batch_neg_idxs = torch.ones((len(batch),self.num_proxies)).bool().to(labels.device) 70 | 71 | loss_neg = 0 72 | for neg_idxs, sample in zip(batch_neg_idxs, batch): 73 | if self.cosine: 74 | dist_to_neg_proxies = -sample.unsqueeze(0).mm(proxies[neg_idxs,:].T).squeeze(0) 75 | else: 76 | if self.sq_dist: 77 | dist_to_neg_proxies = (sample.unsqueeze(0)-proxies[neg_idxs,:]).pow(2).sum(1).sqrt() 78 | else: 79 | dist_to_neg_proxies = (sample.unsqueeze(0)-proxies[neg_idxs,:]).pow(2).sum(1) 80 | 81 | 82 | loss_neg += torch.logsumexp(-dist_to_neg_proxies, dim=-1) 83 | loss_neg /= len(batch) 84 | 85 | loss = loss_pos + loss_neg 86 | # neg_proxies = torch.stack([torch.cat([self.class_idxs[:class_label],self.class_idxs[class_label+1:]]) for class_label in labels]) 87 | # neg_proxies = torch.stack([proxies[neg_labels,:] for neg_labels in neg_proxies]) 88 | # else: 89 | # neg_proxies = torch.stack(.) 
90 | # #Compute Proxy-distances 91 | # dist_to_neg_proxies = torch.sum((batch[:,None,:]-neg_proxies).pow(2),dim=-1) 92 | # #Compute final proxy-based NCA loss 93 | # negative_log_proxy_nca_loss = torch.mean(dist_to_pos_proxies[:,0]/self.T + torch.logsumexp(-dist_to_neg_proxies/self.T, dim=1)) 94 | # else: 95 | # norm_proxies = 96 | return loss 97 | -------------------------------------------------------------------------------- /criteria/quadruplet.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch, torch.nn as nn, torch.nn.functional as F 3 | import batchminer 4 | """=================================================================================================""" 5 | ALLOWED_MINING_OPS = list(batchminer.BATCHMINING_METHODS.keys()) 6 | REQUIRES_BATCHMINER = True 7 | REQUIRES_OPTIM = False 8 | 9 | 10 | class Criterion(torch.nn.Module): 11 | def __init__(self, opt, batchminer): 12 | """ 13 | Args: 14 | margin: Triplet Margin. 15 | """ 16 | super(Criterion, self).__init__() 17 | self.batchminer = batchminer 18 | 19 | self.name = 'quadruplet' 20 | 21 | self.margin_alpha_1 = opt.loss_quadruplet_margin_alpha_1 22 | self.margin_alpha_2 = opt.loss_quadruplet_margin_alpha_2 23 | 24 | #### 25 | self.ALLOWED_MINING_OPS = ALLOWED_MINING_OPS 26 | self.REQUIRES_BATCHMINER = REQUIRES_BATCHMINER 27 | self.REQUIRES_OPTIM = REQUIRES_OPTIM 28 | 29 | def triplet_distance(self, anchor, positive, negative): 30 | return torch.nn.functional.relu(torch.norm(anchor-positive, p=2, dim=-1)-torch.norm(anchor-negative, p=2, dim=-1)+self.margin_alpha_1) 31 | 32 | def quadruplet_distance(self, anchor, positive, negative, fourth_negative): 33 | return torch.nn.functional.relu(torch.norm(anchor-positive, p=2, dim=-1)-torch.norm(negative-fourth_negative, p=2, dim=-1)+self.margin_alpha_2) 34 | 35 | def forward(self, batch, labels, **kwargs): 36 | """ 37 | Args: 38 | batch: torch.Tensor: Input of embeddings with size (BS x DIM) 39 | labels: nparray/list: For each element of the batch assigns a class [0,...,C-1], shape: (BS x 1) 40 | """ 41 | sampled_triplets = self.batchminer(batch, labels) 42 | 43 | anchors = np.array([triplet[0] for triplet in sampled_triplets]).reshape(-1,1) 44 | positives = np.array([triplet[1] for triplet in sampled_triplets]).reshape(-1,1) 45 | negatives = np.array([triplet[2] for triplet in sampled_triplets]).reshape(-1,1) 46 | 47 | fourth_negatives = negatives!=negatives.T 48 | fourth_negatives = [np.random.choice(np.arange(len(batch))[idxs]) for idxs in fourth_negatives] 49 | 50 | triplet_loss = self.triplet_distance(batch[anchors,:],batch[positives,:],batch[negatives,:]) 51 | quadruplet_loss = self.quadruplet_distance(batch[anchors,:],batch[positives,:],batch[negatives,:],batch[fourth_negatives,:]) 52 | # triplet_loss = torch.stack([self.triplet_distance(batch[anchor,:],batch[positive,:],batch[negative,:]) for anchor,positive,negative in zip(anchors, positives, negatives)]) 53 | # quadruplet_loss = torch.stack([self.quadruplet_distance(batch[anchor,:],batch[positive,:],batch[negative,:],batch[fourth_negative,:]) for anchor,positive,negative,fourth_negative in zip(anchors, positives, negatives, fourth_negatives)]) 54 | 55 | return torch.mean(triplet_loss) + torch.mean(quadruplet_loss) 56 | -------------------------------------------------------------------------------- /criteria/s2sd.py: -------------------------------------------------------------------------------- 1 | import numpy as np, copy 2 | import torch, torch.nn as 
nn, torch.nn.functional as F 3 | import batchminer as bmine 4 | import criteria 5 | 6 | """=================================================================================================""" 7 | ALLOWED_MINING_OPS = None 8 | REQUIRES_BATCHMINER = False 9 | REQUIRES_OPTIM = True 10 | 11 | 12 | class Criterion(torch.nn.Module): 13 | def __init__(self, opt): 14 | """ 15 | Args: 16 | opt: Namespace containing all relevant parameters. 17 | """ 18 | super(Criterion, self).__init__() 19 | 20 | self.opt = opt 21 | 22 | #### Some base flags and parameters 23 | self.ALLOWED_MINING_OPS = ALLOWED_MINING_OPS 24 | self.REQUIRES_BATCHMINER = REQUIRES_BATCHMINER 25 | self.REQUIRES_OPTIM = REQUIRES_OPTIM 26 | self.name = 'S2SD' 27 | self.d_mode = 'cosine' 28 | self.iter_count = 0 29 | self.embed_dim = opt.embed_dim 30 | 31 | ### Will contain all parameters to be optimized, e.g. the target MLPs and 32 | ### potential parameters of training criteria. 33 | self.optim_dict_list = [] 34 | 35 | ### All S2SD-specific Parameters 36 | self.T = opt.loss_s2sd_T 37 | self.w = opt.loss_s2sd_w 38 | self.feat_w = opt.loss_s2sd_feat_w 39 | self.pool_aggr = opt.loss_s2sd_pool_aggr 40 | self.match_feats = opt.loss_s2sd_feat_distill 41 | self.max_feat_iter = opt.loss_s2sd_feat_distill_delay 42 | 43 | ### Initialize all target networks as two-layer MLPs 44 | if 'resnet50' in opt.arch: 45 | f_dim = 2048 46 | elif 'resnet18' in opt.arch: 47 | f_dim = 512 48 | elif 'bninception' in opt.arch: 49 | f_dim = 1024 50 | elif 'efficient' in opt.arch: 51 | f_dim = 1280 52 | else: 53 | f_dim = 2048 54 | 55 | self.target_nets = torch.nn.ModuleList([nn.Sequential(nn.Linear(f_dim, t_dim), nn.ReLU(), nn.Linear(t_dim, t_dim)) for t_dim in opt.loss_s2sd_target_dims]) 56 | self.optim_dict_list.append({'params':self.target_nets.parameters(), 'lr':opt.lr}) 57 | 58 | ### Initialize all target criteria. As each criterion may require its separate set of 59 | ### trainable parameters, several instances have to be created. 60 | old_embed_dim = copy.deepcopy(opt.embed_dim) 61 | self.target_criteria = nn.ModuleList() 62 | for t_dim in opt.loss_s2sd_target_dims: 63 | opt.embed_dim = t_dim 64 | 65 | batchminer = bmine.select(opt.batch_mining, opt) 66 | target_criterion = criteria.select(opt.loss_s2sd_target, opt, batchminer=batchminer) 67 | self.target_criteria.append(target_criterion) 68 | 69 | if hasattr(target_criterion, 'optim_dict_list'): 70 | self.optim_dict_list.extend(target_criterion.optim_dict_list) 71 | else: 72 | self.optim_dict_list.append({'params':target_criterion.parameters(), 'lr':opt.lr}) 73 | 74 | ### Initialize the source objective. 
By default the same as the target objective(s). 75 | opt.embed_dim = old_embed_dim 76 | batchminer = bmine.select(opt.batch_mining, opt) 77 | self.source_criterion = criteria.select(opt.loss_s2sd_source, opt, batchminer=batchminer) 78 | 79 | if hasattr(self.source_criterion, 'optim_dict_list'): 80 | self.optim_dict_list.extend(self.source_criterion.optim_dict_list) 81 | else: 82 | self.optim_dict_list.append({'params':self.source_criterion.parameters(), 'lr':opt.lr}) 83 | 84 | 85 | 86 | def prep(self, thing): 87 | return 1.*torch.nn.functional.normalize(thing, dim=1) 88 | 89 | 90 | def forward(self, batch, labels, batch_features, avg_batch_features, f_embed, **kwargs): 91 | """ 92 | Args: 93 | batch: torch.Tensor: Input of embeddings with size (BS x DIM) 94 | labels: nparray/list: For each element of the batch assigns a class [0,...,C-1], shape: (BS x 1) 95 | """ 96 | ### 97 | bs = len(batch) 98 | batch = self.prep(batch) 99 | self.labels = labels.unsqueeze(1) 100 | 101 | ### Compute the loss on the base/source embedding space as well as the similarity matrix source_smat of all base embeddings. 102 | source_loss = self.source_criterion(batch, labels, batch_features=batch_features, f_embed=f_embed, **kwargs) 103 | source_smat = self.smat(batch, batch, mode=self.d_mode) 104 | loss = source_loss 105 | 106 | ### If required, use combined global max- and average pooling to produce the feature space. 107 | if self.pool_aggr: 108 | avg_batch_features = nn.AdaptiveAvgPool2d(1)(batch_features).view(bs,-1)+nn.AdaptiveMaxPool2d(1)(batch_features).view(bs,-1) 109 | else: 110 | avg_batch_features = avg_batch_features.view(bs,-1) 111 | 112 | ### Key Segment (1): For each target branch, compute the respective loss and similarity matrix target_smat. 113 | ### These are used as the distillation signal by computing the KL-divergence to the source similarity matrix source_smat. 114 | kl_divs, target_losses = [], [] 115 | for i,out_net in enumerate(self.target_nets): 116 | target_batch = F.normalize(out_net(avg_batch_features.view(bs, -1)), dim=-1) 117 | target_loss = self.target_criteria[i](target_batch, labels, batch_features=batch_features, f_embed=f_embed, **kwargs) 118 | target_smat = self.smat(target_batch, target_batch, mode=self.d_mode) 119 | 120 | kl_divs.append(self.kl_div(source_smat, target_smat.detach())) 121 | target_losses.append(target_loss) 122 | 123 | loss = (torch.mean(torch.stack(target_losses)) + loss)/2. + self.w*torch.mean(torch.stack(kl_divs)) 124 | 125 | ### If enough iterations have passed, start applying feature space distillation to bridge the 126 | ### dimensionality bottleneck. 127 | if self.match_feats and self.iter_count>=self.max_feat_iter: 128 | n_avg_batch_features = F.normalize(avg_batch_features, dim=-1).detach() 129 | avg_feat_smat = self.smat(n_avg_batch_features, n_avg_batch_features, mode=self.d_mode) 130 | avg_batch_kl_div = self.kl_div(source_smat, avg_feat_smat.detach()) 131 | loss += self.feat_w*avg_batch_kl_div 132 | 133 | ### Update the iteration counter after every training iteration. 134 | self.iter_count+=1 135 | 136 | return loss 137 | 138 | 139 | 140 | ### Apply relation distillation over similarity vectors. 141 | def kl_div(self, A, B): 142 | log_p_A = F.log_softmax(A/self.T, dim=-1) 143 | p_B = F.softmax(B/self.T, dim=-1) 144 | kl_div = F.kl_div(log_p_A, p_B, reduction='sum') * (self.T**2) / A.shape[0] 145 | return kl_div 146 | 147 | 148 | ### Computes similarity matrices. 
149 | def smat(self, A, B, mode='cosine'): 150 | if mode=='cosine': 151 | return A.mm(B.T) 152 | elif mode=='euclidean': 153 | As, Bs = A.shape, B.shape 154 | return (A.mm(A.T).diag().unsqueeze(-1)+B.mm(B.T).diag().unsqueeze(0)-2*A.mm(B.T)).clamp(min=1e-20).sqrt() 155 | -------------------------------------------------------------------------------- /criteria/shared_margin.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch, torch.nn as nn, torch.nn.functional as F 3 | import batchminer 4 | 5 | 6 | """=================================================================================================""" 7 | ALLOWED_MINING_OPS = list(batchminer.BATCHMINING_METHODS.keys()) 8 | REQUIRES_BATCHMINER = True 9 | REQUIRES_OPTIM = True 10 | 11 | ### MarginLoss with trainable class separation margin beta. Runs on Mini-batches as well. 12 | class Criterion(torch.nn.Module): 13 | def __init__(self, opt, batchminer): 14 | """ 15 | Args: 16 | margin: Triplet Margin. 17 | nu: Regularisation Parameter for beta values if they are learned. 18 | beta: Class-Margin values. 19 | n_classes: Number of different classes during training. 20 | """ 21 | super(Criterion, self).__init__() 22 | self.n_classes = opt.n_classes 23 | 24 | self.margin = opt.loss_margin_margin 25 | self.nu = opt.loss_margin_nu 26 | self.beta_constant = opt.loss_margin_beta_constant 27 | self.beta_val = opt.loss_margin_beta 28 | 29 | if opt.loss_margin_beta_constant: 30 | self.beta = opt.loss_margin_beta 31 | else: 32 | self.beta = torch.nn.Parameter(torch.ones(opt.n_classes)*opt.loss_margin_beta) 33 | 34 | self.batchminer = batchminer 35 | 36 | self.name = 'margin' 37 | 38 | self.lr = opt.loss_margin_beta_lr 39 | 40 | #### 41 | self.ALLOWED_MINING_OPS = ALLOWED_MINING_OPS 42 | self.REQUIRES_BATCHMINER = REQUIRES_BATCHMINER 43 | self.REQUIRES_OPTIM = REQUIRES_OPTIM 44 | 45 | def forward(self, batch, labels, **kwargs): 46 | """ 47 | Args: 48 | batch: torch.Tensor: Input of embeddings with size (BS x DIM) 49 | labels: nparray/list: For each element of the batch assigns a class [0,...,C-1], shape: (BS x 1) 50 | """ 51 | sampled_triplets = self.batchminer(batch, labels) 52 | 53 | if len(sampled_triplets): 54 | d_ap, d_an = [],[] 55 | for triplet in sampled_triplets: 56 | train_triplet = {'Anchor': batch[triplet[0],:], 'Positive':batch[triplet[1],:], 'Negative':batch[triplet[2]]} 57 | 58 | pos_dist = ((train_triplet['Anchor']-train_triplet['Positive']).pow(2).sum()+1e-8).pow(1/2) 59 | neg_dist = ((train_triplet['Anchor']-train_triplet['Negative']).pow(2).sum()+1e-8).pow(1/2) 60 | 61 | d_ap.append(pos_dist) 62 | d_an.append(neg_dist) 63 | d_ap, d_an = torch.stack(d_ap), torch.stack(d_an) 64 | 65 | if self.beta_constant: 66 | beta = self.beta 67 | else: 68 | beta = torch.stack([self.beta[labels[triplet[0]]] for triplet in sampled_triplets]).to(torch.float).to(d_ap.device) 69 | 70 | pos_loss = torch.nn.functional.relu(d_ap-beta+self.margin) 71 | neg_loss = torch.nn.functional.relu(beta-d_an+self.margin) 72 | 73 | pair_count = torch.sum((pos_loss>0.)+(neg_loss>0.)).to(torch.float).to(d_ap.device) 74 | 75 | if pair_count == 0.: 76 | loss = torch.sum(pos_loss+neg_loss) 77 | else: 78 | loss = torch.sum(pos_loss+neg_loss)/pair_count 79 | 80 | if self.nu: loss = loss + self.nu*torch.norm(beta, p=1).to(torch.float).to(d_ap.device) #L1 regularisation of the learned beta values, scaled by nu. 81 | else: 82 | loss = torch.tensor(0.).to(torch.float).to(batch.device) 83 | 84 | return loss 85 | 
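The hinge structure of the margin losses above is easy to verify in isolation. The following is a minimal, self-contained sketch — a toy setup with hard-coded hyperparameters and a single hand-picked triplet, not the repository's training path — showing the two ReLU hinges around a trainable class boundary beta and the nu-scaled L1 regularisation used in shared_margin.py:

import torch

# Toy setup (hypothetical values for illustration only).
torch.manual_seed(0)
embed_dim, n_classes = 8, 4
margin, nu = 0.2, 0.01

batch = torch.nn.functional.normalize(torch.randn(3, embed_dim), dim=1)
labels = torch.tensor([0, 0, 1])                       # anchor & positive: class 0, negative: class 1
beta = torch.nn.Parameter(torch.ones(n_classes)*1.2)   # one trainable boundary per class

anchor, positive, negative = batch[0], batch[1], batch[2]
d_ap = ((anchor - positive).pow(2).sum() + 1e-8).sqrt()
d_an = ((anchor - negative).pow(2).sum() + 1e-8).sqrt()

b = beta[labels[0]]                                    # boundary of the anchor class
pos_loss = torch.nn.functional.relu(d_ap - b + margin) # pulls positives to within b - margin
neg_loss = torch.nn.functional.relu(b - d_an + margin) # pushes negatives beyond b + margin

pair_count = ((pos_loss > 0.).float() + (neg_loss > 0.).float()).clamp(min=1.)
loss = (pos_loss + neg_loss)/pair_count + nu*torch.norm(beta, p=1)
loss.backward()
print(float(loss), beta.grad)                          # beta receives both hinge and L1 gradient

Giving beta its own learning rate (opt.loss_margin_beta_lr above) is the usual design choice here, since the boundary parameters adapt on a different scale than the embedding network.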
-------------------------------------------------------------------------------- /criteria/shared_triplet.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch, torch.nn as nn, torch.nn.functional as F 3 | import batchminer 4 | 5 | """=================================================================================================""" 6 | ALLOWED_MINING_OPS = list(batchminer.BATCHMINING_METHODS.keys()) 7 | REQUIRES_BATCHMINER = True 8 | REQUIRES_OPTIM = False 9 | 10 | ### Standard Triplet Loss, finds triplets in Mini-batches. 11 | class Criterion(torch.nn.Module): 12 | def __init__(self, opt, batchminer): 13 | """ 14 | Args: 15 | margin: Triplet Margin. 16 | """ 17 | super(Criterion, self).__init__() 18 | self.margin = opt.loss_triplet_margin 19 | self.batchminer = batchminer 20 | 21 | self.name = 'triplet' 22 | 23 | #### 24 | self.ALLOWED_MINING_OPS = ALLOWED_MINING_OPS 25 | self.REQUIRES_BATCHMINER = REQUIRES_BATCHMINER 26 | self.REQUIRES_OPTIM = REQUIRES_OPTIM 27 | 28 | def triplet_distance(self, anchor, positive, negative): 29 | return torch.nn.functional.relu((anchor-positive).pow(2).sum()-(anchor-negative).pow(2).sum()+self.margin) 30 | 31 | def forward(self, batch, labels, **kwargs): 32 | """ 33 | Args: 34 | batch: torch.Tensor: Input of embeddings with size (BS x DIM) 35 | labels: nparray/list: For each element of the batch assigns a class [0,...,C-1], shape: (BS x 1) 36 | """ 37 | sampled_triplets = self.batchminer(batch, labels) 38 | loss = torch.stack([self.triplet_distance(batch[triplet[0],:],batch[triplet[1],:],batch[triplet[2],:]) for triplet in sampled_triplets]) 39 | 40 | return torch.mean(loss) 41 | -------------------------------------------------------------------------------- /criteria/triplet.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch, torch.nn as nn, torch.nn.functional as F 3 | import batchminer 4 | 5 | """=================================================================================================""" 6 | ALLOWED_MINING_OPS = list(batchminer.BATCHMINING_METHODS.keys()) 7 | REQUIRES_BATCHMINER = True 8 | REQUIRES_OPTIM = False 9 | 10 | ### Standard Triplet Loss, finds triplets in Mini-batches. 11 | class Criterion(torch.nn.Module): 12 | def __init__(self, opt, batchminer): 13 | """ 14 | Args: 15 | margin: Triplet Margin. 
16 | """ 17 | super(Criterion, self).__init__() 18 | self.margin = opt.loss_triplet_margin 19 | self.batchminer = batchminer 20 | self.name = 'triplet' 21 | 22 | #### 23 | self.ALLOWED_MINING_OPS = ALLOWED_MINING_OPS 24 | self.REQUIRES_BATCHMINER = REQUIRES_BATCHMINER 25 | self.REQUIRES_OPTIM = REQUIRES_OPTIM 26 | 27 | 28 | def triplet_distance(self, anchor, positive, negative): 29 | return torch.nn.functional.relu((anchor-positive).pow(2).sum()-(anchor-negative).pow(2).sum()+self.margin) 30 | 31 | def forward(self, batch, labels, **kwargs): 32 | """ 33 | Args: 34 | batch: torch.Tensor: Input of embeddings with size (BS x DIM) 35 | labels: nparray/list: For each element of the batch assigns a class [0,...,C-1], shape: (BS x 1) 36 | """ 37 | if isinstance(labels, torch.Tensor): labels = labels.cpu().numpy() 38 | sampled_triplets = self.batchminer(batch, labels) 39 | loss = torch.stack([self.triplet_distance(batch[triplet[0],:],batch[triplet[1],:],batch[triplet[2],:]) for triplet in sampled_triplets]) 40 | 41 | return torch.mean(loss) 42 | -------------------------------------------------------------------------------- /datasampler/__init__.py: -------------------------------------------------------------------------------- 1 | import datasampler.class_random_sampler 2 | import datasampler.random_sampler 3 | 4 | def select(sampler, opt, image_dict, image_list=None, **kwargs): 5 | if 'class' in sampler: 6 | sampler_lib = class_random_sampler 7 | elif 'full' in sampler: 8 | sampler_lib = random_sampler 9 | else: 10 | raise Exception('Minibatch sampler <{}> not available!'.format(sampler)) 11 | 12 | sampler = sampler_lib.Sampler(opt,image_dict=image_dict,image_list=image_list) 13 | 14 | return sampler 15 | -------------------------------------------------------------------------------- /datasampler/class_random_sampler.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch, torch.nn as nn, torch.nn.functional as F 3 | from tqdm import tqdm 4 | import random 5 | 6 | 7 | 8 | """======================================================""" 9 | REQUIRES_STORAGE = False 10 | 11 | ### 12 | class Sampler(torch.utils.data.sampler.Sampler): 13 | """ 14 | Plugs into PyTorch Batchsampler Package. 15 | """ 16 | def __init__(self, opt, image_dict, image_list, **kwargs): 17 | self.pars = opt 18 | 19 | ##### 20 | self.image_dict = image_dict 21 | self.image_list = image_list 22 | 23 | ##### 24 | self.internal_split = opt.internal_split 25 | self.use_meta_split = self.internal_split!=1 26 | self.classes = list(self.image_dict.keys()) 27 | self.tv_split = int(len(self.classes)*self.internal_split) 28 | self.train_classes = self.classes[:self.tv_split] 29 | self.val_classes = self.classes[self.tv_split:] 30 | 31 | #### 32 | self.batch_size = opt.bs 33 | self.samples_per_class = opt.samples_per_class 34 | self.sampler_length = len(image_list)//opt.bs 35 | assert self.batch_size%self.samples_per_class==0, '#Samples per class must divide batchsize!' 
36 | 37 | self.name = 'class_random_sampler' 38 | self.requires_storage = False 39 | 40 | self.random_gen = random.Random(opt.seed) 41 | 42 | 43 | def __iter__(self): 44 | for _ in range(self.sampler_length): 45 | subset = [] 46 | ### Random Subset from Random classes 47 | if self.use_meta_split: 48 | train_draws = int((self.batch_size//self.samples_per_class)*self.internal_split) 49 | val_draws = self.batch_size//self.samples_per_class-train_draws 50 | else: 51 | train_draws = self.batch_size//self.samples_per_class 52 | val_draws = None 53 | 54 | if self.pars.data_ssl_set: 55 | for _ in range(train_draws//2): 56 | class_key = random.choice(self.train_classes) 57 | subset.extend([random.choice(self.image_dict[class_key])[-1] for _ in range(self.samples_per_class)]) 58 | subset = subset + subset 59 | else: 60 | for _ in range(train_draws): 61 | class_key = random.choice(self.train_classes) 62 | class_ix_list = [random.choice(self.image_dict[class_key])[-1] for _ in range(self.samples_per_class)] 63 | subset.extend(class_ix_list) 64 | 65 | if self.use_meta_split: 66 | for _ in range(val_draws): 67 | class_key = random.choice(self.val_classes) 68 | subset.extend([random.choice(self.image_dict[class_key])[-1] for _ in range(self.samples_per_class)]) 69 | yield subset 70 | 71 | def __len__(self): 72 | return self.sampler_length 73 | -------------------------------------------------------------------------------- /datasampler/random_sampler.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch, torch.nn as nn, torch.nn.functional as F 3 | from tqdm import tqdm 4 | import random 5 | 6 | 7 | 8 | """======================================================""" 9 | REQUIRES_STORAGE = False 10 | 11 | ### 12 | class Sampler(torch.utils.data.sampler.Sampler): 13 | """ 14 | Plugs into PyTorch Batchsampler Package. 15 | """ 16 | def __init__(self, opt, image_dict, image_list=None): 17 | self.image_dict = image_dict 18 | self.image_list = image_list 19 | 20 | self.batch_size = opt.bs 21 | self.samples_per_class = opt.samples_per_class 22 | self.sampler_length = len(image_list)//opt.bs 23 | assert self.batch_size%self.samples_per_class==0, '#Samples per class must divide batchsize!' 
24 | 25 | self.name = 'random_sampler' 26 | self.requires_storage = False 27 | 28 | def __iter__(self): 29 | for _ in range(self.sampler_length): 30 | subset = [] 31 | ### Random Subset from Random classes 32 | for _ in range(self.batch_size-1): 33 | class_key = random.choice(list(self.image_dict.keys())) 34 | sample_idx = np.random.choice(len(self.image_dict[class_key])) 35 | subset.append(self.image_dict[class_key][sample_idx][-1]) 36 | # 37 | subset.append(random.choice(self.image_dict[self.image_list[random.choice(subset)][-1]])[-1]) 38 | yield subset 39 | 40 | def __len__(self): 41 | return self.sampler_length 42 | -------------------------------------------------------------------------------- /datasets/__init__.py: -------------------------------------------------------------------------------- 1 | import datasets.cub200 2 | import datasets.cars196 3 | import datasets.stanford_online_products 4 | 5 | 6 | def select(dataset, opt, data_path, splitpath=None): 7 | if splitpath is None: 8 | if 'cub200' in dataset: 9 | return cub200.DefaultDatasets(opt, data_path) 10 | if 'cars196' in dataset: 11 | return cars196.DefaultDatasets(opt, data_path) 12 | if 'online_products' in dataset: 13 | return stanford_online_products.DefaultDatasets(opt, data_path) 14 | else: 15 | if 'cub200' in dataset: 16 | return cub200.OODatasets(opt, data_path, splitpath) 17 | if 'cars196' in dataset: 18 | return cars196.OODatasets(opt, data_path, splitpath) 19 | if 'online_products' in dataset: 20 | return stanford_online_products.OODatasets(opt, data_path, splitpath) 21 | 22 | raise NotImplementedError('A dataset for {} is currently not implemented.\n\ 23 | Currently available are : cub200, cars196 & stanford_online_products!'.format(dataset)) 24 | -------------------------------------------------------------------------------- /datasets/basic_dataset_scaffold.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset 2 | import torchvision.transforms as transforms 3 | import numpy as np 4 | from PIL import Image 5 | 6 | 7 | """===================================================================================================""" 8 | ################## BASIC PYTORCH DATASET USED FOR ALL DATASETS ################################## 9 | class BaseDataset(Dataset): 10 | def __init__(self, image_dict, opt, is_validation=False): 11 | self.is_validation = is_validation 12 | self.pars = opt 13 | 14 | ##### 15 | self.image_dict = image_dict 16 | 17 | ##### 18 | self.init_setup() 19 | 20 | ##### 21 | if 'bninception' not in opt.arch: 22 | self.f_norm = normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225]) 23 | else: 24 | # normalize = transforms.Normalize(mean=[0.502, 0.4588, 0.4078],std=[1., 1., 1.]) 25 | self.f_norm = normalize = transforms.Normalize(mean=[0.502, 0.4588, 0.4078],std=[0.0039, 0.0039, 0.0039]) 26 | 27 | transf_list = [] 28 | 29 | self.crop_size = crop_im_size = 224 if 'googlenet' not in opt.arch else 227 30 | if opt.augmentation=='big': 31 | crop_im_size = 256 32 | 33 | ############# 34 | self.normal_transform = [] 35 | if not self.is_validation: 36 | if opt.augmentation=='base' or opt.augmentation=='big': 37 | self.normal_transform.extend([transforms.RandomResizedCrop(size=crop_im_size), transforms.RandomHorizontalFlip(0.5)]) 38 | elif opt.augmentation=='adv': 39 | self.normal_transform.extend([transforms.RandomResizedCrop(size=crop_im_size), transforms.RandomGrayscale(p=0.2), 40 | transforms.ColorJitter(0.2, 
0.2, 0.2, 0.2), transforms.RandomHorizontalFlip(0.5)]) 41 | elif opt.augmentation=='red': 42 | self.normal_transform.extend([transforms.Resize(size=256), transforms.RandomCrop(crop_im_size), transforms.RandomHorizontalFlip(0.5)]) 43 | else: 44 | self.normal_transform.extend([transforms.Resize(256), transforms.CenterCrop(crop_im_size)]) 45 | self.normal_transform.extend([transforms.ToTensor(), normalize]) 46 | self.normal_transform_list = self.normal_transform 47 | self.normal_transform = transforms.Compose(self.normal_transform) 48 | 49 | ##### 50 | self.include_aux_augmentations = False 51 | self.predict_rotations = None 52 | 53 | 54 | def init_setup(self): 55 | self.n_files = np.sum([len(self.image_dict[key]) for key in self.image_dict.keys()]) 56 | self.avail_classes = sorted(list(self.image_dict.keys())) 57 | 58 | counter = 0 59 | temp_image_dict = {} 60 | for i,key in enumerate(self.avail_classes): 61 | temp_image_dict[key] = [] 62 | for path in self.image_dict[key]: 63 | temp_image_dict[key].append([path, counter]) 64 | counter += 1 65 | 66 | self.image_dict = temp_image_dict 67 | self.image_list = [[(x[0],key) for x in self.image_dict[key]] for key in self.image_dict.keys()] 68 | self.image_list = [x for y in self.image_list for x in y] 69 | 70 | self.image_paths = self.image_list 71 | 72 | self.is_init = True 73 | 74 | 75 | def ensure_3dim(self, img): 76 | if len(img.size)==2: 77 | img = img.convert('RGB') 78 | return img 79 | 80 | 81 | def __getitem__(self, idx): 82 | input_image = self.ensure_3dim(Image.open(self.image_list[idx][0])) 83 | imrot_class = -1 84 | 85 | if self.include_aux_augmentations: 86 | im_a = self.normal_transform(input_image) 87 | 88 | if not self.predict_rotations: 89 | im_b = self.normal_transform(input_image) 90 | else: 91 | class ImRotTrafo: 92 | def __init__(self, angle): 93 | self.angle = angle 94 | def __call__(self, x): 95 | return transforms.functional.rotate(x, self.angle) 96 | 97 | imrot_class = idx%4 98 | angle = np.array([0,90,180,270])[imrot_class] 99 | imrot_aug = [ImRotTrafo(angle), transforms.Resize((256,256)), transforms.RandomCrop(size=self.crop_size), 100 | transforms.ToTensor(), self.f_norm] 101 | imrot_aug = transforms.Compose(imrot_aug) 102 | im_b = imrot_aug(input_image) 103 | 104 | if 'bninception' in self.pars.arch: 105 | im_a, im_b = im_a[range(3)[::-1],:], im_b[range(3)[::-1],:] 106 | 107 | return (self.image_list[idx][-1], im_a, idx, im_b, imrot_class) 108 | else: 109 | im_a = self.normal_transform(input_image) 110 | if 'bninception' in self.pars.arch: 111 | im_a = im_a[range(3)[::-1],:] 112 | return self.image_list[idx][-1], im_a, idx 113 | 114 | 115 | def __len__(self): 116 | return self.n_files 117 | -------------------------------------------------------------------------------- /datasets/cars196.py: -------------------------------------------------------------------------------- 1 | from datasets.basic_dataset_scaffold import BaseDataset 2 | import os, copy 3 | 4 | 5 | def give_info_dict(source, classes): 6 | image_list = [[(i,source + '/' + classname +'/'+x) for x in sorted(os.listdir(source + '/' + classname)) if '._' not in x] for i,classname in enumerate(classes)] 7 | image_list = [x for y in image_list for x in y] 8 | 9 | idx_to_class_conversion = {i:classname for i,classname in enumerate(classes)} 10 | 11 | image_dict = {} 12 | for key,img_path in image_list: 13 | if not key in image_dict.keys(): 14 | image_dict[key] = [] 15 | image_dict[key].append(img_path) 16 | 17 | return image_list, image_dict, 
idx_to_class_conversion 18 | 19 | 20 | def OODatasets(opt, datapath, splitpath=None): 21 | import pickle as pkl 22 | splitpath_base = os.getcwd() if splitpath is None else splitpath 23 | split_dict = pkl.load(open(splitpath_base+'/datasplits/cars196_splits.pkl', 'rb')) 24 | train_classes, test_classes, fid = split_dict[opt.data_hardness]['train'], split_dict[opt.data_hardness]['test'], split_dict[opt.data_hardness]['fid'] 25 | print('\nLoaded Data Split with FID-Hardness: {0:4.4f}'.format(fid)) 26 | 27 | ### 28 | image_sourcepath = datapath + '/images' 29 | 30 | ### 31 | if opt.use_tv_split: 32 | if not opt.tv_split_perc: 33 | train_classes, val_classes = split_dict[opt.data_hardness]['split_train'], split_dict[opt.data_hardness]['split_val'] 34 | else: 35 | train_val_split = int(len(train_classes)*opt.tv_split_perc) 36 | train_classes, val_classes = train_classes[:train_val_split], train_classes[train_val_split:] 37 | val_image_list, val_image_dict, val_conversion = give_info_dict(image_sourcepath, val_classes) 38 | val_dataset = BaseDataset(val_image_dict, opt, is_validation=True) 39 | val_dataset.conversion = val_conversion 40 | else: 41 | val_dataset, val_image_dict = None, None 42 | 43 | ### 44 | train_image_list, train_image_dict, train_conversion = give_info_dict(image_sourcepath, train_classes) 45 | test_image_list, test_image_dict, test_conversion = give_info_dict(image_sourcepath, test_classes) 46 | 47 | ### 48 | print('\nDataset Setup:\nUsing Train-Val Split: {0}\n#Classes: Train ({1}) | Val ({2}) | Test ({3})\n'.format(opt.use_tv_split, len(train_image_dict), len(val_image_dict) if val_image_dict is not None else 'X', len(test_image_dict))) 49 | 50 | ### 51 | train_dataset = BaseDataset(train_image_dict, opt) 52 | test_dataset = BaseDataset(test_image_dict, opt, is_validation=True) 53 | train_eval_dataset = BaseDataset(train_image_dict, opt, is_validation=True) 54 | 55 | ### 56 | reverse_train_conversion = {item: key for key, item in train_conversion.items()} 57 | reverse_test_conversion = {item: key for key, item in test_conversion.items()} 58 | 59 | train_dataset.conversion = train_conversion 60 | test_dataset.conversion = test_conversion 61 | train_eval_dataset.conversion = train_conversion 62 | 63 | few_shot_datasets = None 64 | episode_context = None 65 | if hasattr(opt, 'few_shot_evaluate'): 66 | test_episodes = split_dict[opt.data_hardness]['test_episodes'] 67 | shots = list(test_episodes.keys()) 68 | episode_idxs = list(test_episodes[opt.finetune_shots].keys()) 69 | classnames = list(test_episodes[opt.finetune_shots][episode_idxs[0]].keys()) 70 | conv_classnames = [reverse_test_conversion[classname] for classname in classnames] 71 | 72 | episode_context = {} 73 | for ep_idx in episode_idxs: 74 | ref_dict = copy.deepcopy(test_image_dict) 75 | test_support_image_dict = {} 76 | test_query_image_dict = {} 77 | for conv_classname, classname in zip(conv_classnames, classnames): 78 | samples_to_use = test_episodes[opt.finetune_shots][ep_idx][classname] 79 | base_path = '/'.join(ref_dict[conv_classname][0].split('/')[:-1]) 80 | support_samples_to_use = [base_path + '/' + x for x in samples_to_use] 81 | query_samples_to_use = [x for x in ref_dict[conv_classname] if x not in support_samples_to_use] 82 | test_query_image_dict[conv_classname] = query_samples_to_use 83 | test_support_image_dict[conv_classname] = support_samples_to_use 84 | 85 | test_support_dataset = BaseDataset(test_support_image_dict, opt) 86 | test_query_dataset = BaseDataset(test_query_image_dict, opt, 
is_validation=True) 87 | 88 | episode_context[ep_idx] = {'support': test_support_dataset, 'query': test_query_dataset} 89 | 90 | return {'training':train_dataset, 'validation':val_dataset, 'testing':test_dataset, 'evaluation':train_eval_dataset, 'fewshot_episodes': episode_context} 91 | 92 | 93 | 94 | def DefaultDatasets(opt, datapath): 95 | image_sourcepath = datapath+'/images' 96 | image_classes = sorted([x for x in os.listdir(image_sourcepath)]) 97 | total_conversion = {i:x for i,x in enumerate(image_classes)} 98 | image_list = {i:sorted([image_sourcepath+'/'+key+'/'+x for x in os.listdir(image_sourcepath+'/'+key)]) for i,key in enumerate(image_classes)} 99 | image_list = [[(key,img_path) for img_path in image_list[key]] for key in image_list.keys()] 100 | image_list = [x for y in image_list for x in y] 101 | 102 | ### Dictionary of structure class:list_of_samples_with_said_class 103 | image_dict = {} 104 | for key, img_path in image_list: 105 | if not key in image_dict.keys(): 106 | image_dict[key] = [] 107 | image_dict[key].append(img_path) 108 | 109 | ### Use the first half of the sorted data as training and the second half as test set 110 | keys = sorted(list(image_dict.keys())) 111 | train,test = keys[:len(keys)//2], keys[len(keys)//2:] 112 | 113 | ### If required, split the training data into a train/val setup either by or per class. 114 | if opt.use_tv_split: 115 | if not opt.tv_split_by_samples: 116 | train_val_split = int(len(train)*opt.tv_split_perc) 117 | train, val = train[:train_val_split], train[train_val_split:] 118 | ### 119 | train_image_dict = {i:image_dict[key] for i,key in enumerate(train)} 120 | val_image_dict = {i:image_dict[key] for i,key in enumerate(val)} 121 | test_image_dict = {i:image_dict[key] for i,key in enumerate(test)} 122 | else: 123 | val = train 124 | train_image_dict, val_image_dict = {},{} 125 | for key in train: 126 | train_ixs = np.random.choice(len(image_dict[key]), int(len(image_dict[key])*opt.tv_split_perc), replace=False) 127 | val_ixs = np.array([x for x in range(len(image_dict[key])) if x not in train_ixs]) 128 | train_image_dict[key] = np.array(image_dict[key])[train_ixs] 129 | val_image_dict[key] = np.array(image_dict[key])[val_ixs] 130 | val_dataset = BaseDataset(val_image_dict, opt, is_validation=True) 131 | val_conversion = {i:total_conversion[key] for i,key in enumerate(val)} 132 | ### 133 | val_dataset.conversion = val_conversion 134 | else: 135 | train_image_dict = {key:image_dict[key] for key in train} 136 | val_image_dict = None 137 | val_dataset = None 138 | 139 | ### 140 | train_conversion = {i:total_conversion[key] for i,key in enumerate(train)} 141 | test_conversion = {i:total_conversion[key] for i,key in enumerate(test)} 142 | 143 | ### 144 | test_image_dict = {key:image_dict[key] for key in test} 145 | 146 | ### 147 | print('\nDataset Setup:\nUsing Train-Val Split: {0}\n#Classes: Train ({1}) | Val ({2}) | Test ({3})\n'.format(opt.use_tv_split, len(train_image_dict), len(val_image_dict) if val_image_dict else 'X', len(test_image_dict))) 148 | 149 | ### 150 | train_dataset = BaseDataset(train_image_dict, opt) 151 | test_dataset = BaseDataset(test_image_dict, opt, is_validation=True) 152 | eval_dataset = BaseDataset(train_image_dict, opt, is_validation=True) 153 | eval_train_dataset = BaseDataset(train_image_dict, opt, is_validation=False) 154 | train_dataset.conversion = train_conversion 155 | test_dataset.conversion = test_conversion 156 | eval_dataset.conversion = test_conversion 157 | eval_train_dataset.conversion = 
train_conversion 158 | 159 | return {'training':train_dataset, 'validation':val_dataset, 'testing':test_dataset, 'evaluation':eval_dataset, 'evaluation_train':eval_train_dataset} 160 | -------------------------------------------------------------------------------- /datasets/cub200.py: -------------------------------------------------------------------------------- 1 | from datasets.basic_dataset_scaffold import BaseDataset 2 | import os, copy 3 | 4 | def give_info_dict(source, classes): 5 | image_list = [[(i,source + '/' + classname +'/'+x) for x in sorted(os.listdir(source + '/' + classname)) if '._' not in x] for i,classname in enumerate(classes)] 6 | image_list = [x for y in image_list for x in y] 7 | 8 | idx_to_class_conversion = {i:classname for i,classname in enumerate(classes)} 9 | 10 | image_dict = {} 11 | for key,img_path in image_list: 12 | if not key in image_dict.keys(): 13 | image_dict[key] = [] 14 | image_dict[key].append(img_path) 15 | 16 | return image_list, image_dict, idx_to_class_conversion 17 | 18 | 19 | def OODatasets(opt, datapath, splitpath=None): 20 | import pickle as pkl 21 | splitpath_base = os.getcwd() if splitpath is None else splitpath 22 | split_dict = pkl.load(open(splitpath_base+'/datasplits/cub200_splits.pkl', 'rb')) 23 | train_classes, test_classes, fid = split_dict[opt.data_hardness]['train'], split_dict[opt.data_hardness]['test'], split_dict[opt.data_hardness]['fid'] 24 | print('\nLoaded Data Split with FID-Hardness: {0:4.4f}'.format(fid)) 25 | 26 | ### 27 | image_sourcepath = datapath + '/images' 28 | 29 | ### 30 | if opt.use_tv_split: 31 | if not opt.tv_split_perc: 32 | train_classes, val_classes = split_dict[opt.data_hardness]['split_train'], split_dict[opt.data_hardness]['split_val'] 33 | else: 34 | train_val_split = int(len(train_classes)*opt.tv_split_perc) 35 | train_classes, val_classes = train_classes[:train_val_split], train_classes[train_val_split:] 36 | val_image_list, val_image_dict, val_conversion = give_info_dict(image_sourcepath, val_classes) 37 | val_dataset = BaseDataset(val_image_dict, opt, is_validation=True) 38 | val_dataset.conversion = val_conversion 39 | else: 40 | val_dataset, val_image_dict = None, None 41 | 42 | ### 43 | train_image_list, train_image_dict, train_conversion = give_info_dict(image_sourcepath, train_classes) 44 | test_image_list, test_image_dict, test_conversion = give_info_dict(image_sourcepath, test_classes) 45 | 46 | ### 47 | print('\nDataset Setup:\nUsing Train-Val Split: {0}\n#Classes: Train ({1}) | Val ({2}) | Test ({3})\n'.format(opt.use_tv_split, len(train_image_dict), len(val_image_dict) if val_image_dict is not None else 'X', len(test_image_dict))) 48 | 49 | ### 50 | train_dataset = BaseDataset(train_image_dict, opt) 51 | test_dataset = BaseDataset(test_image_dict, opt, is_validation=True) 52 | train_eval_dataset = BaseDataset(train_image_dict, opt, is_validation=True) 53 | 54 | ### 55 | reverse_train_conversion = {item: key for key, item in train_conversion.items()} 56 | reverse_test_conversion = {item: key for key, item in test_conversion.items()} 57 | 58 | train_dataset.conversion = train_conversion 59 | test_dataset.conversion = test_conversion 60 | train_eval_dataset.conversion = train_conversion 61 | 62 | ### 63 | few_shot_datasets = None 64 | episode_context = None 65 | if hasattr(opt, 'few_shot_evaluate'): 66 | test_episodes = split_dict[opt.data_hardness]['test_episodes'] 67 | shots = list(test_episodes.keys()) 68 | episode_idxs = list(test_episodes[opt.finetune_shots].keys()) 69 | classnames = 
list(test_episodes[opt.finetune_shots][episode_idxs[0]].keys()) 70 | conv_classnames = [reverse_test_conversion[classname] for classname in classnames] 71 | 72 | episode_context = {} 73 | for ep_idx in episode_idxs: 74 | ref_dict = copy.deepcopy(test_image_dict) 75 | test_support_image_dict = {} 76 | test_query_image_dict = {} 77 | for conv_classname, classname in zip(conv_classnames, classnames): 78 | samples_to_use = test_episodes[opt.finetune_shots][ep_idx][classname] 79 | base_path = '/'.join(ref_dict[conv_classname][0].split('/')[:-1]) 80 | support_samples_to_use = [base_path + '/' + x for x in samples_to_use] 81 | query_samples_to_use = [x for x in ref_dict[conv_classname] if x not in support_samples_to_use] 82 | test_query_image_dict[conv_classname] = query_samples_to_use 83 | test_support_image_dict[conv_classname] = support_samples_to_use 84 | 85 | test_support_dataset = BaseDataset(test_support_image_dict, opt) 86 | test_query_dataset = BaseDataset(test_query_image_dict, opt, is_validation=True) 87 | 88 | episode_context[ep_idx] = {'support': test_support_dataset, 'query': test_query_dataset} 89 | 90 | return {'training':train_dataset, 'validation':val_dataset, 'testing':test_dataset, 'evaluation':train_eval_dataset, 'fewshot_episodes': episode_context} 91 | 92 | 93 | 94 | 95 | 96 | def DefaultDatasets(opt, datapath): 97 | image_sourcepath = datapath+'/images' 98 | image_classes = sorted([x for x in os.listdir(image_sourcepath) if '._' not in x], key=lambda x: int(x.split('.')[0])) 99 | total_conversion = {int(x.split('.')[0])-1:x.split('.')[-1] for x in image_classes} 100 | image_list = {int(key.split('.')[0])-1:sorted([image_sourcepath+'/'+key+'/'+x for x in os.listdir(image_sourcepath+'/'+key) if '._' not in x]) for key in image_classes} 101 | image_list = [[(key,img_path) for img_path in image_list[key]] for key in image_list.keys()] 102 | image_list = [x for y in image_list for x in y] 103 | 104 | ### Dictionary of structure class:list_of_samples_with_said_class 105 | image_dict = {} 106 | for key, img_path in image_list: 107 | if not key in image_dict.keys(): 108 | image_dict[key] = [] 109 | image_dict[key].append(img_path) 110 | 111 | ### Use the first half of the sorted data as training and the second half as test set 112 | keys = sorted(list(image_dict.keys())) 113 | 114 | train,test = keys[:len(keys)//2], keys[len(keys)//2:] 115 | 116 | ### If required, split the training data into a train/val setup either by or per class. 
117 | # from IPython import embed; embed() 118 | if opt.use_tv_split: 119 | if not opt.tv_split_by_samples: 120 | train_val_split = int(len(train)*opt.tv_split_perc) 121 | train, val = train[:train_val_split], train[train_val_split:] 122 | ### 123 | train_image_dict = {i:image_dict[key] for i,key in enumerate(train)} 124 | val_image_dict = {i:image_dict[key] for i,key in enumerate(val)} 125 | test_image_dict = {i:image_dict[key] for i,key in enumerate(test)} 126 | else: 127 | val = train 128 | train_image_dict, val_image_dict = {},{} 129 | for key in train: 130 | train_ixs = np.array(list(set(np.round(np.linspace(0,len(image_dict[key])-1,int(len(image_dict[key])*opt.tv_split_perc)))))).astype(int) 131 | val_ixs = np.array([x for x in range(len(image_dict[key])) if x not in train_ixs]) 132 | train_image_dict[key] = np.array(image_dict[key])[train_ixs] 133 | val_image_dict[key] = np.array(image_dict[key])[val_ixs] 134 | val_dataset = BaseDataset(val_image_dict, opt, is_validation=True) 135 | val_conversion = {i:total_conversion[key] for i,key in enumerate(val)} 136 | ### 137 | val_dataset.conversion = val_conversion 138 | else: 139 | train_image_dict = {key:image_dict[key] for key in train} 140 | val_image_dict = None 141 | val_dataset = None 142 | 143 | ### 144 | train_conversion = {i:total_conversion[key] for i,key in enumerate(train)} 145 | test_conversion = {i:total_conversion[key] for i,key in enumerate(test)} 146 | 147 | ### 148 | test_image_dict = {key:image_dict[key] for key in test} 149 | 150 | ### 151 | print('\nDataset Setup:\nUsing Train-Val Split: {0}\n#Classes: Train ({1}) | Val ({2}) | Test ({3})\n'.format(opt.use_tv_split, len(train_image_dict), len(val_image_dict) if val_image_dict else 'X', len(test_image_dict))) 152 | 153 | ### 154 | train_dataset = BaseDataset(train_image_dict, opt) 155 | test_dataset = BaseDataset(test_image_dict, opt, is_validation=True) 156 | eval_dataset = BaseDataset(train_image_dict, opt, is_validation=True) 157 | eval_train_dataset = BaseDataset(train_image_dict, opt, is_validation=False) 158 | train_dataset.conversion = train_conversion 159 | test_dataset.conversion = test_conversion 160 | eval_dataset.conversion = test_conversion 161 | eval_train_dataset.conversion = train_conversion 162 | 163 | 164 | return {'training':train_dataset, 'validation':val_dataset, 'testing':test_dataset, 'evaluation':eval_dataset, 'evaluation_train':eval_train_dataset} 165 | -------------------------------------------------------------------------------- /datasplits/cars196_splits.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Confusezius/Characterizing_Generalization_in_DeepMetricLearning/e8a4171cfce083ef91073dbefd3a299ca294df02/datasplits/cars196_splits.pkl -------------------------------------------------------------------------------- /datasplits/cub200_splits.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Confusezius/Characterizing_Generalization_in_DeepMetricLearning/e8a4171cfce083ef91073dbefd3a299ca294df02/datasplits/cub200_splits.pkl -------------------------------------------------------------------------------- /datasplits/online_products_splits.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Confusezius/Characterizing_Generalization_in_DeepMetricLearning/e8a4171cfce083ef91073dbefd3a299ca294df02/datasplits/online_products_splits.pkl 
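The split pickles above are consumed by the OODatasets loaders in datasets/cars196.py and datasets/cub200.py. Below is a short sketch of how to inspect one of them; the path is relative to the repository root, and picking the first hardness key is an assumption — check split_dict.keys() for the values actually shipped:

import os, pickle as pkl

# Load one of the datasplit pickles and report its structure.
splitpath = os.getcwd() + '/datasplits/cub200_splits.pkl'
with open(splitpath, 'rb') as f:
    split_dict = pkl.load(f)

hardness = sorted(split_dict.keys())[0]    # assumption: just take the first available hardness
split = split_dict[hardness]
print('FID-Hardness: {0:4.4f}'.format(split['fid']))
print('#Classes: Train ({0}) | Test ({1})'.format(len(split['train']), len(split['test'])))

# Few-shot episodes are nested as: shots -> episode index -> classname -> support filenames.
for shots, episodes in split.get('test_episodes', {}).items():
    first_episode = next(iter(episodes.values()))
    n_support = len(next(iter(first_episode.values())))
    print('{0}-shot: {1} episodes, {2} support samples per class'.format(shots, len(episodes), n_support))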
-------------------------------------------------------------------------------- /evaluation/__init__.py: -------------------------------------------------------------------------------- 1 | import faiss, matplotlib.pyplot as plt, os, numpy as np, torch 2 | from PIL import Image 3 | 4 | 5 | 6 | ####################### 7 | def evaluate(dataset, LOG, metric_computer, dataloaders, model, opt, evaltypes, device, 8 | aux_store=None, make_recall_plot=False, store_checkpoints=True, log_key='Test', 9 | compute_metrics_only=False, print_text=True): 10 | """ 11 | Parent-Function to compute evaluation metrics, print summary string and store checkpoint files/plot sample recall plots. 12 | """ 13 | computed_metrics, extra_infos = metric_computer.compute_standard(opt, model, dataloaders[0], evaltypes, device, mode=log_key) 14 | 15 | numeric_metrics = {} 16 | histogr_metrics = {} 17 | for main_key in computed_metrics.keys(): 18 | for name,value in computed_metrics[main_key].items(): 19 | if isinstance(value, np.ndarray): 20 | if main_key not in histogr_metrics: histogr_metrics[main_key] = {} 21 | histogr_metrics[main_key][name] = value 22 | else: 23 | if main_key not in numeric_metrics: numeric_metrics[main_key] = {} 24 | numeric_metrics[main_key][name] = value 25 | 26 | ### 27 | full_result_str = '' 28 | for evaltype in numeric_metrics.keys(): 29 | full_result_str += 'Out-Type: {}:\n'.format(evaltype) 30 | for i,(metricname, metricval) in enumerate(numeric_metrics[evaltype].items()): 31 | full_result_str += '{0}{1}: {2:4.4f}'.format(' | ' if i>0 else '',metricname, metricval) 32 | full_result_str += '\n' 33 | 34 | if print_text: 35 | print(full_result_str) 36 | 37 | if not compute_metrics_only: 38 | ### Log Histogram-Style data with W&Bs 39 | if opt.log_online: 40 | for evaltype in histogr_metrics.keys(): 41 | for eval_metric, hist in histogr_metrics[evaltype].items(): 42 | import wandb, numpy 43 | hist = hist[:500] 44 | wandb.log({log_key+': '+evaltype+'_{}'.format(eval_metric): wandb.Histogram(np_histogram=(list(hist),list(np.arange(len(hist)+1))))}, step=opt.epoch) 45 | wandb.log({log_key+': '+evaltype+'_LOG-{}'.format(eval_metric): wandb.Histogram(np_histogram=(list(np.log(hist)+20),list(np.arange(len(hist)+1))))}, step=opt.epoch) 46 | 47 | ### 48 | for evaltype in numeric_metrics.keys(): 49 | for eval_metric in numeric_metrics[evaltype].keys(): 50 | parent_metric = evaltype+'_{}'.format(eval_metric.split('@')[0]) 51 | LOG.progress_saver[log_key].log(eval_metric, numeric_metrics[evaltype][eval_metric], group=parent_metric) 52 | 53 | 54 | ### 55 | if make_recall_plot: 56 | if opt.dataset!='inshop': 57 | recover_closest_standard(extra_infos[evaltype]['features'], 58 | extra_infos[evaltype]['image_paths'], 59 | LOG.prop.save_path+'/sample_recoveries.png') 60 | else: 61 | recover_closest_query_gallery(extra_infos[evaltype]['query_features'], 62 | extra_infos[evaltype]['gallery_features'], 63 | extra_infos[evaltype]['gallery_image_paths'], 64 | extra_infos[evaltype]['query_image_paths'], 65 | LOG.prop.save_path+'/sample_recoveries.png') 66 | 67 | ### 68 | for evaltype in evaltypes: 69 | for storage_metric in opt.storage_metrics: 70 | parent_metric = evaltype+'_{}'.format(storage_metric.split('@')[0]) 71 | ref_mets = LOG.progress_saver[log_key].groups[parent_metric][storage_metric]['content'] 72 | if not len(ref_mets): ref_mets = [-np.inf] 73 | if numeric_metrics[evaltype][storage_metric]>np.max(ref_mets): 74 | print('Saved improved checkpoint for {}: {}\n'.format(log_key, parent_metric)) 75 | 
set_checkpoint(model, opt, LOG, LOG.prop.save_path+'/checkpoint_{}_{}_{}.pth.tar'.format(log_key, evaltype, storage_metric), aux=aux_store) 76 | else: 77 | return numeric_metrics 78 | 79 | 80 | ########################### 81 | def set_checkpoint(model, opt, progress_saver, savepath, aux=None): 82 | torch.save({'state_dict':model.state_dict(), 'opt':opt, 'progress':progress_saver, 'aux':aux}, savepath) 83 | 84 | 85 | 86 | 87 | ########################## 88 | def recover_closest_standard(feature_matrix_all, image_paths, save_path, n_image_samples=10, n_closest=3): 89 | image_paths = np.array([x[0] for x in image_paths]) 90 | sample_idxs = np.random.choice(np.arange(len(feature_matrix_all)), n_image_samples) 91 | 92 | faiss_search_index = faiss.IndexFlatL2(feature_matrix_all.shape[-1]) 93 | faiss_search_index.add(feature_matrix_all) 94 | _, closest_feature_idxs = faiss_search_index.search(feature_matrix_all, n_closest+1) 95 | 96 | sample_paths = image_paths[closest_feature_idxs][sample_idxs] 97 | 98 | f,axes = plt.subplots(n_image_samples, n_closest+1) 99 | for i,(ax,plot_path) in enumerate(zip(axes.reshape(-1), sample_paths.reshape(-1))): 100 | ax.imshow(np.array(Image.open(plot_path))) 101 | ax.set_xticks([]) 102 | ax.set_yticks([]) 103 | if i%(n_closest+1): 104 | ax.axvline(x=0, color='g', linewidth=13) 105 | else: 106 | ax.axvline(x=0, color='r', linewidth=13) 107 | f.set_size_inches(10,20) 108 | f.tight_layout() 109 | f.savefig(save_path) 110 | plt.close() 111 | 112 | 113 | 114 | 115 | ####### RECOVER CLOSEST EXAMPLE IMAGES ####### 116 | def recover_closest_query_gallery(query_feature_matrix_all, gallery_feature_matrix_all, query_image_paths, gallery_image_paths, \ 117 | save_path, n_image_samples=10, n_closest=3): 118 | query_image_paths, gallery_image_paths = np.array(query_image_paths), np.array(gallery_image_paths) 119 | sample_idxs = np.random.choice(np.arange(len(gallery_feature_matrix_all)), n_image_samples) 120 | 121 | faiss_search_index = faiss.IndexFlatL2(gallery_feature_matrix_all.shape[-1]) 122 | faiss_search_index.add(gallery_feature_matrix_all) 123 | _, closest_feature_idxs = faiss_search_index.search(query_feature_matrix_all, n_closest) 124 | 125 | ### TODO: EXAMINE THIS SECTION HERE FOR INSHOP-NEAREST SAMPLE RETRIEVAL 126 | image_paths = gallery_image_paths[closest_feature_idxs] 127 | image_paths = np.concatenate([query_image_paths.reshape(-1,1), image_paths],axis=-1) 128 | 129 | sample_paths = image_paths[closest_feature_idxs][sample_idxs] 130 | 131 | f,axes = plt.subplots(n_image_samples, n_closest+1) 132 | for i,(ax,plot_path) in enumerate(zip(axes.reshape(-1), sample_paths.reshape(-1))): 133 | ax.imshow(np.array(Image.open(plot_path))) 134 | ax.set_xticks([]) 135 | ax.set_yticks([]) 136 | if i%(n_closest+1): 137 | ax.axvline(x=0, color='g', linewidth=13) 138 | else: 139 | ax.axvline(x=0, color='r', linewidth=13) 140 | f.set_size_inches(10,20) 141 | f.tight_layout() 142 | f.savefig(save_path) 143 | # plt.show() 144 | plt.close() 145 | -------------------------------------------------------------------------------- /images/AUC_Comp.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Confusezius/Characterizing_Generalization_in_DeepMetricLearning/e8a4171cfce083ef91073dbefd3a299ca294df02/images/AUC_Comp.png -------------------------------------------------------------------------------- /images/fewshot.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Confusezius/Characterizing_Generalization_in_DeepMetricLearning/e8a4171cfce083ef91073dbefd3a299ca294df02/images/fewshot.png -------------------------------------------------------------------------------- /images/progression.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Confusezius/Characterizing_Generalization_in_DeepMetricLearning/e8a4171cfce083ef91073dbefd3a299ca294df02/images/progression.png -------------------------------------------------------------------------------- /images/progression_comp.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Confusezius/Characterizing_Generalization_in_DeepMetricLearning/e8a4171cfce083ef91073dbefd3a299ca294df02/images/progression_comp.png -------------------------------------------------------------------------------- /images/umaps.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Confusezius/Characterizing_Generalization_in_DeepMetricLearning/e8a4171cfce083ef91073dbefd3a299ca294df02/images/umaps.png -------------------------------------------------------------------------------- /metrics/a_recall.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | class Metric(): 4 | def __init__(self, k, **kwargs): 5 | self.k = k 6 | self.requires = ['kmeans', 'nearest'] 7 | 8 | def compute(self, target_labels, k_closest_classes): 9 | recall_at_k = np.sum([1 for target, recalled_predictions in zip(target_labels, k_closest_classes) if target in recalled_predictions[:self.k]])/len(target_labels) 10 | return recall_at_k 11 | -------------------------------------------------------------------------------- /metrics/compute_stack.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /metrics/dists.py: -------------------------------------------------------------------------------- 1 | from scipy.spatial import distance 2 | from sklearn.preprocessing import normalize 3 | import numpy as np 4 | import torch 5 | 6 | class Metric(): 7 | def __init__(self, mode, **kwargs): 8 | self.mode = mode 9 | self.requires = ['features', 'target_labels'] 10 | self.name = 'dists@{}'.format(mode) 11 | 12 | def __call__(self, features, target_labels): 13 | features_locs = [] 14 | for lab in np.unique(target_labels): 15 | features_locs.append(np.where(target_labels==lab)[0]) 16 | 17 | if 'intra' in self.mode: 18 | if isinstance(features, torch.Tensor): 19 | intrafeatures = features.detach().cpu().numpy() 20 | else: 21 | intrafeatures = features 22 | 23 | intra_dists = [] 24 | for loc in features_locs: 25 | c_dists = distance.cdist(intrafeatures[loc], intrafeatures[loc], 'cosine') 26 | c_dists = np.sum(c_dists)/(len(c_dists)**2-len(c_dists)) 27 | intra_dists.append(c_dists) 28 | intra_dists = np.array(intra_dists) 29 | maxval = np.max(intra_dists[~np.isnan(intra_dists)]) 30 | intra_dists[np.isnan(intra_dists)] = maxval 31 | intra_dists[np.isinf(intra_dists)] = maxval 32 | dist_metric = dist_metric_intra = np.mean(intra_dists) 33 | 34 | if 'inter' in self.mode: 35 | if not isinstance(features, torch.Tensor): 36 | coms = [] 37 | for loc in features_locs: 38 | com = 
/metrics/compute_stack.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/metrics/dists.py:
--------------------------------------------------------------------------------
from scipy.spatial import distance
from sklearn.preprocessing import normalize
import numpy as np
import torch

class Metric():
    def __init__(self, mode, **kwargs):
        self.mode = mode
        self.requires = ['features', 'target_labels']
        self.name = 'dists@{}'.format(mode)

    def __call__(self, features, target_labels):
        features_locs = []
        for lab in np.unique(target_labels):
            features_locs.append(np.where(target_labels==lab)[0])

        if 'intra' in self.mode:
            if isinstance(features, torch.Tensor):
                intrafeatures = features.detach().cpu().numpy()
            else:
                intrafeatures = features

            intra_dists = []
            for loc in features_locs:
                c_dists = distance.cdist(intrafeatures[loc], intrafeatures[loc], 'cosine')
                # Mean over off-diagonal entries; the diagonal (self-distance) is zero.
                c_dists = np.sum(c_dists)/(len(c_dists)**2-len(c_dists))
                intra_dists.append(c_dists)
            intra_dists = np.array(intra_dists)
            # Replace NaN/Inf entries (e.g. from single-sample classes) with the largest finite value.
            maxval = np.max(intra_dists[np.isfinite(intra_dists)])
            intra_dists[np.isnan(intra_dists)] = maxval
            intra_dists[np.isinf(intra_dists)] = maxval
            dist_metric = dist_metric_intra = np.mean(intra_dists)

        if 'inter' in self.mode:
            if not isinstance(features, torch.Tensor):
                coms = []
                for loc in features_locs:
                    com = normalize(np.mean(features[loc], axis=0).reshape(1, -1)).reshape(-1)
                    coms.append(com)
                mean_inter_dist = distance.cdist(np.array(coms), np.array(coms), 'cosine')
                dist_metric = dist_metric_inter = np.sum(mean_inter_dist)/(len(mean_inter_dist)**2-len(mean_inter_dist))
            else:
                coms = []
                for loc in features_locs:
                    com = torch.nn.functional.normalize(torch.mean(features[loc], dim=0).reshape(1, -1), dim=-1).reshape(1, -1)
                    coms.append(com)
                mean_inter_dist = 1-torch.cat(coms, dim=0).mm(torch.cat(coms, dim=0).T).detach().cpu().numpy()
                dist_metric = dist_metric_inter = np.sum(mean_inter_dist)/(len(mean_inter_dist)**2-len(mean_inter_dist))

        if self.mode=='intra_over_inter':
            dist_metric = dist_metric_intra/np.clip(dist_metric_inter, 1e-8, None)

        return dist_metric
--------------------------------------------------------------------------------
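
A small sketch of how the dists Metric above is typically driven (a feature matrix plus integer labels); the values are illustrative only:

import numpy as np

# Hypothetical embeddings: 4 samples, 2 classes.
features = np.array([[1.0, 0.0], [0.9, 0.1], [0.0, 1.0], [0.1, 0.9]])
labels   = np.array([0, 0, 1, 1])
intra = Metric(mode='intra')(features, labels)            # mean within-class cosine distance
ratio = Metric(mode='intra_over_inter')(features, labels) # embedding compactness vs. class separation
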
/metrics/e_recall.py:
--------------------------------------------------------------------------------
import numpy as np

class Metric():
    def __init__(self, k, **kwargs):
        self.k = k
        self.requires = ['nearest_features', 'target_labels']
        self.name = 'e_recall@{}'.format(k)

    def __call__(self, target_labels, k_closest_classes):
        recall_at_k = np.sum([1 for target, recalled_predictions in zip(target_labels, k_closest_classes)
                              if target in recalled_predictions[:self.k]])/len(target_labels)
        return recall_at_k
--------------------------------------------------------------------------------
/metrics/f1.py:
--------------------------------------------------------------------------------
import numpy as np
from scipy.special import comb
import torch

class Metric():
    def __init__(self, **kwargs):
        self.requires = ['kmeans', 'kmeans_nearest', 'features', 'target_labels']
        self.name = 'f1'

    def __call__(self, target_labels, computed_cluster_labels, features, centroids):
        if isinstance(features, torch.Tensor):
            features = features.detach().cpu().numpy()

        # Distance of every sample to its assigned cluster centroid.
        d = np.zeros(len(features))
        for i in range(len(features)):
            d[i] = np.linalg.norm(features[i, :] - centroids[computed_cluster_labels[i], :])

        # Relabel each cluster by the index of its sample closest to the centroid.
        labels_pred = np.zeros(len(features))
        for i in np.unique(computed_cluster_labels):
            index = np.where(computed_cluster_labels == i)[0]
            ind = np.argmin(d[index])
            cid = index[ind]
            labels_pred[index] = cid

        N = len(target_labels)

        # Available ground-truth classes.
        avail_labels = np.unique(target_labels)
        n_labels = len(avail_labels)

        # Count the number of objects in each class.
        count_cluster = np.zeros(n_labels)
        for i in range(n_labels):
            count_cluster[i] = len(np.where(target_labels == avail_labels[i])[0])

        # Build a mapping from item_id to item index.
        keys = np.unique(labels_pred)
        num_item = len(keys)
        values = range(num_item)
        item_map = dict()
        for i in range(len(keys)):
            item_map.update([(keys[i], values[i])])

        # Count the number of objects of each item (predicted cluster).
        count_item = np.zeros(num_item)
        for i in range(N):
            index = item_map[labels_pred[i]]
            count_item[index] = count_item[index] + 1

        # Compute True Positive (TP) plus False Positive (FP): all same-class sample pairs.
        tp_fp = comb(count_cluster, 2).sum()

        # Compute True Positive (TP): same-class pairs that also share a predicted cluster.
        tp = 0
        for k in range(n_labels):
            member = np.where(target_labels == avail_labels[k])[0]
            member_ids = labels_pred[member]
            count = np.zeros(num_item)
            for j in range(len(member)):
                index = item_map[member_ids[j]]
                count[index] = count[index] + 1
            tp += comb(count, 2).sum()

        # False Positive (FP).
        fp = tp_fp - tp

        # Compute False Negative (FN) from all same-cluster pairs.
        count = comb(count_item, 2).sum()
        fn = count - tp

        # Compute F measure.
        P = tp / (tp + fp)
        R = tp / (tp + fn)
        beta = 1
        F = (beta*beta + 1) * P * R / (beta*beta * P + R)
        return F
--------------------------------------------------------------------------------
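
The pair-counting behind the F-measure above, checked on a toy assignment (true classes [0,0,1,1], predicted clusters [0,0,0,1]; all numbers invented):

import numpy as np
from scipy.special import comb

count_cluster = np.array([2., 2.])  # samples per true class
count_item    = np.array([3., 1.])  # samples per predicted cluster
tp_fp = comb(count_cluster, 2).sum()       # 2.0 same-class pairs
tp    = comb(np.array([2., 0.]), 2).sum() \
      + comb(np.array([1., 1.]), 2).sum()  # 1.0 same-class pair also shares a cluster
fp    = tp_fp - tp                         # 1.0
fn    = comb(count_item, 2).sum() - tp     # 2.0
P, R  = tp/(tp+fp), tp/(tp+fn)
F     = 2*P*R/(P+R)                        # 0.4
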
/metrics/mAP.py:
--------------------------------------------------------------------------------
import torch
import numpy as np
import faiss


class Metric():
    def __init__(self, **kwargs):
        self.requires = ['features', 'target_labels']
        self.name = 'mAP'

    def __call__(self, target_labels, features):
        labels, freqs = np.unique(target_labels, return_counts=True)
        # This variant ranks the full gallery for every query.
        R = len(features)-1
        if isinstance(features, torch.Tensor):
            features = features.detach().cpu().numpy()
        faiss_search_index = faiss.IndexFlatL2(features.shape[-1])
        faiss_search_index.add(features)

        # Search R+1 neighbours so the query itself (always rank 0) can be dropped.
        nearest_neighbours = faiss_search_index.search(features, int(R+1))[1][:, 1:]

        target_labels = target_labels.reshape(-1)
        nn_labels = target_labels[nearest_neighbours]

        avg_r_precisions = []
        for label, freq in zip(labels, freqs):
            rows_with_label = np.where(target_labels==label)[0]
            for row in rows_with_label:
                n_recalled_samples = np.arange(1, R+1)
                target_label_occ_in_row = nn_labels[row, :]==label
                cumsum_target_label_freq_row = np.cumsum(target_label_occ_in_row)
                avg_r_pr_row = np.sum(cumsum_target_label_freq_row*target_label_occ_in_row/n_recalled_samples)/freq
                avg_r_precisions.append(avg_r_pr_row)

        return np.mean(avg_r_precisions)
--------------------------------------------------------------------------------
/metrics/mAP_1000.py:
--------------------------------------------------------------------------------
import torch
import numpy as np
import faiss


class Metric():
    def __init__(self, **kwargs):
        self.requires = ['features', 'target_labels']
        self.name = 'mAP_1000'

    def __call__(self, target_labels, features):
        labels, freqs = np.unique(target_labels, return_counts=True)
        # For all benchmarks, there is really no purpose to go beyond a recall of 1000.
        # In addition, faiss on GPU only supports k up to 1024.
        R = min(1000, len(features)-1)  # cap at the gallery size for small datasets
        faiss_search_index = faiss.IndexFlatL2(features.shape[-1])
        if isinstance(features, torch.Tensor):
            features = features.detach().cpu().numpy()
            res = faiss.StandardGpuResources()
            faiss_search_index = faiss.index_cpu_to_gpu(res, 0, faiss_search_index)
        faiss_search_index.add(features)

        nearest_neighbours = faiss_search_index.search(features, int(R+1))[1][:, 1:]

        target_labels = target_labels.reshape(-1)
        nn_labels = target_labels[nearest_neighbours]

        avg_r_precisions = []
        for label, freq in zip(labels, freqs):
            rows_with_label = np.where(target_labels==label)[0]
            for row in rows_with_label:
                n_recalled_samples = np.arange(1, R+1)
                target_label_occ_in_row = nn_labels[row, :]==label
                cumsum_target_label_freq_row = np.cumsum(target_label_occ_in_row)
                avg_r_pr_row = np.sum(cumsum_target_label_freq_row*target_label_occ_in_row/n_recalled_samples)/freq
                avg_r_precisions.append(avg_r_pr_row)

        return np.mean(avg_r_precisions)
--------------------------------------------------------------------------------
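
The vectorised average-precision step shared by the mAP variants above, checked on a toy ranking (the hit pattern is invented):

import numpy as np

hits  = np.array([1, 0, 1, 1], dtype=float)  # whether the r-th retrieved neighbour matches the query class
ranks = np.arange(1, len(hits)+1)
ap = np.sum(np.cumsum(hits)*hits/ranks)/hits.sum()  # (1/1 + 2/3 + 3/4)/3 ~ 0.806
# Note: the metrics above normalise by the full class frequency 'freq' rather than by hits.sum().
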
/metrics/mAP_c.py:
--------------------------------------------------------------------------------
import torch
import numpy as np
import faiss


class Metric():
    def __init__(self, **kwargs):
        self.requires = ['features', 'target_labels']
        self.name = 'mAP_c'

    def __call__(self, target_labels, features):
        labels, freqs = np.unique(target_labels, return_counts=True)
        # Retrieval depth is set by the largest class; each class is evaluated at its own size.
        R = np.max(freqs)

        faiss_search_index = faiss.IndexFlatL2(features.shape[-1])
        if isinstance(features, torch.Tensor):
            features = features.detach().cpu().numpy()
            res = faiss.StandardGpuResources()
            faiss_search_index = faiss.index_cpu_to_gpu(res, 0, faiss_search_index)
        faiss_search_index.add(features)
        nearest_neighbours = faiss_search_index.search(features, int(R+1))[1][:, 1:]

        target_labels = target_labels.reshape(-1)
        nn_labels = target_labels[nearest_neighbours]

        avg_r_precisions = []
        for label, freq in zip(labels, freqs):
            rows_with_label = np.where(target_labels==label)[0]
            for row in rows_with_label:
                n_recalled_samples = np.arange(1, freq+1)
                target_label_occ_in_row = nn_labels[row, :freq]==label
                cumsum_target_label_freq_row = np.cumsum(target_label_occ_in_row)
                avg_r_pr_row = np.sum(cumsum_target_label_freq_row*target_label_occ_in_row/n_recalled_samples)/freq
                avg_r_precisions.append(avg_r_pr_row)

        return np.mean(avg_r_precisions)
--------------------------------------------------------------------------------
/metrics/nmi.py:
--------------------------------------------------------------------------------
from sklearn import metrics

class Metric():
    def __init__(self, **kwargs):
        self.requires = ['kmeans_nearest', 'target_labels']
        self.name = 'nmi'

    def __call__(self, target_labels, computed_cluster_labels):
        NMI = metrics.cluster.normalized_mutual_info_score(computed_cluster_labels.reshape(-1), target_labels.reshape(-1))
        return NMI
--------------------------------------------------------------------------------
/metrics/rho_spectrum.py:
--------------------------------------------------------------------------------
from scipy.stats import entropy
from sklearn.decomposition import TruncatedSVD
import numpy as np
import torch


class Metric():
    def __init__(self, embed_dim, mode, **kwargs):
        self.mode = mode
        self.embed_dim = embed_dim
        self.requires = ['features']
        self.name = 'rho_spectrum@'+str(mode)

    def __call__(self, features):
        if isinstance(features, torch.Tensor):
            _, s, _ = torch.svd(features)
            s = s.cpu().numpy()
        else:
            # Singular values need to be clipped due to the maximum histogram length for W&B of 512.
            svd = TruncatedSVD(n_components=np.clip(np.clip(self.embed_dim-1, None, features.shape[-1]-1), None, 511), n_iter=7, random_state=42)
            svd.fit(features)
            s = svd.singular_values_

        # For mode != 0, the |mode|-1 leading singular values are discarded before normalisation.
        if self.mode!=0:
            s = s[np.abs(self.mode)-1:]
        s_norm = s/np.sum(s)
        uniform = np.ones(len(s))/(len(s))

        if self.mode<0:
            kl = entropy(s_norm, uniform)
        if self.mode>0:
            kl = entropy(uniform, s_norm)
        if self.mode==0:
            kl = s_norm

        return kl
--------------------------------------------------------------------------------
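
What rho_spectrum measures, in miniature: the KL divergence between the normalised singular-value spectrum of the embedding matrix and a uniform spectrum. A self-contained sketch on random stand-in features, mirroring the mode>0 branch above:

import numpy as np
from scipy.stats import entropy

features = np.random.randn(512, 128)           # stand-in embedding matrix
s = np.linalg.svd(features, compute_uv=False)  # singular-value spectrum
s_norm  = s/np.sum(s)
uniform = np.ones(len(s))/len(s)
rho = entropy(uniform, s_norm)  # KL(uniform || spectrum); lower values = flatter spectrum
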
/utilities/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Confusezius/Characterizing_Generalization_in_DeepMetricLearning/e8a4171cfce083ef91073dbefd3a299ca294df02/utilities/__init__.py
--------------------------------------------------------------------------------
/utilities/finetune_utils.py:
--------------------------------------------------------------------------------
from utilities import misc
import numpy as np
import random
import tqdm
import torch
import criteria as criteria
import batchminer as bmine

def finetune_model(opt, model, dataloader, device, finetune_params, seed=None, reweight=True):
    if seed is not None:
        misc.set_seed(seed)
    else:
        misc.set_seed(opt.seed)

    if finetune_params['optim'] == 'sgd':
        optim_f = torch.optim.SGD
    elif finetune_params['optim'] == 'adam':
        optim_f = torch.optim.Adam
    else:
        raise ValueError('Unknown finetuning optimizer: {}'.format(finetune_params['optim']))

    # Per-dataset weighting of the (discriminative, shared, selfsimilarity, intra) subembeddings.
    if opt.dataset == 'cub200':
        weights = [0.75, 1.25, 1.25, 1.25]
    elif opt.dataset == 'cars196':
        weights = [0.5, 1.5, 1.5, 1.5]
    elif opt.dataset == 'online_products':
        weights = [1., 1., 1., 1.]
    else:
        # Assumed fallback: uniform weighting for datasets not listed above.
        weights = [1., 1., 1., 1.]

    if finetune_params['only_last']:
        finetune_optim = optim_f(model.model.last_linear.parameters(), lr=finetune_params['lr'])
        _ = model.eval()
        _ = model.model.last_linear.train()
    else:
        finetune_optim = optim_f(model.parameters(), lr=finetune_params['lr'])
        _ = model.train()

    batchminer = bmine.select('random', opt)
    if finetune_params['criterion'] == 'multisimilarity':
        if 'margin' in opt.loss:
            opt.loss_multisimilarity_d_mode = 'euclidean'
        criterion, _ = criteria.select('multisimilarity', opt, [], batchminer)
    elif finetune_params['criterion'] == 'margin':
        batchminer = bmine.select('distance', opt)
        criterion, _ = criteria.select('margin', opt, [], batchminer)
    elif finetune_params['criterion'] == 'triplet':
        criterion, _ = criteria.select('triplet', opt, [], batchminer)

    loss_collect = []
    finetune_iterator = tqdm.tqdm(range(finetune_params['iter']), total=finetune_params['iter'], desc='Finetuning...')
    count = 0
    for i in range(finetune_params['iter']):
        for inp in dataloader:
            input_img, target = inp[1], inp[0]
            out_dict = model(input_img.to(device), warmup=finetune_params['only_last'])
            if 'multifeature' in model.name:
                if reweight:
                    weighted_subfeatures = [weights[i]*out_dict['embeds'][subevaltype] for i, subevaltype in enumerate(['discriminative', 'shared', 'selfsimilarity', 'intra'])]
                else:
                    weighted_subfeatures = [out_dict['embeds'][subevaltype] for i, subevaltype in enumerate(['discriminative', 'shared', 'selfsimilarity', 'intra'])]
                if 'normalize' in model.name:
                    out_dict['embeds'] = torch.nn.functional.normalize(torch.cat(weighted_subfeatures, dim=-1), dim=-1)
                else:
                    out_dict['embeds'] = torch.cat(weighted_subfeatures, dim=-1)

            loss_args = {'batch': out_dict['embeds'], 'labels': target}
            finetune_optim.zero_grad()
            loss = criterion(**loss_args)
            loss.backward()
            loss_collect.append(loss.item())
            finetune_optim.step()
            finetune_iterator.set_postfix_str('Loss: {0:3.5f}'.format(np.mean(loss_collect)))

            count += 1
            finetune_iterator.update(1)
            # Stop after exactly finetune_params['iter'] update steps, regardless of epoch boundaries.
            if count == finetune_params['iter']:
                break
        if count == finetune_params['iter']:
            break


def nonredundant_finetuner(opt, backbone, dataloader, device,
                           finetune_lr, finetune_iter, finetune_criterion='margin', finetune_optim='adam',
                           head=None, seed=None, optim_head_only=False):
    if seed is not None:
        misc.set_seed(seed)
    else:
        misc.set_seed(opt.seed)

    if finetune_optim == 'sgd':
        optim_f = torch.optim.SGD
    elif finetune_optim == 'adam':
        optim_f = torch.optim.Adam
    else:
        raise ValueError('Unknown finetuning optimizer: {}'.format(finetune_optim))

    if optim_head_only:
        finetune_optim = optim_f(head.parameters(), lr=finetune_lr)
        _ = backbone.eval()
        _ = head.train()
    else:
        finetune_optim = optim_f(
            [{'params': backbone.parameters()},
             {'params': head.parameters()}], lr=finetune_lr
        )
        _ = backbone.train()
        _ = head.train()

    batchminer = bmine.select('random', opt)
    if finetune_criterion == 'multisimilarity':
        if 'margin' in opt.loss:
            opt.loss_multisimilarity_d_mode = 'euclidean'
        criterion, _ = criteria.select('multisimilarity', opt, [], batchminer)
    elif finetune_criterion == 'margin':
        batchminer = bmine.select('distance', opt)
        criterion, _ = criteria.select('margin', opt, [], batchminer)
    elif finetune_criterion == 'triplet':
        criterion, _ = criteria.select('triplet', opt, [], batchminer)

    loss_collect = []
    finetune_iterator = tqdm.tqdm(range(finetune_iter), total=finetune_iter, desc='Finetuning...')
    count = 0
    for i in range(finetune_iter):
        for inp in dataloader:
            input_img, target = inp[1], inp[0]
            # With a frozen backbone, features are extracted without building a graph.
            if optim_head_only:
                with torch.no_grad():
                    out = backbone(input_img.to(device))
            else:
                out = backbone(input_img.to(device))
            if isinstance(out, dict):
                out = out['avg_features']
            if head is not None:
                out = head(out.to(torch.float).to(device))

            loss_args = {'batch': out, 'labels': target}
            finetune_optim.zero_grad()
            loss = criterion(**loss_args)
            loss.backward()
            loss_collect.append(loss.item())
            finetune_optim.step()
            finetune_iterator.set_postfix_str('Loss: {0:3.5f}'.format(np.mean(loss_collect)))

            count += 1
            finetune_iterator.update(1)
            if count == finetune_iter:
                break
        if count == finetune_iter:
            break
--------------------------------------------------------------------------------
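
For reference, a hypothetical finetune_params dictionary wired the way finetune_model consumes it; the key names are taken from the lookups above, while the values are placeholders:

finetune_params = {
    'optim': 'adam',        # 'adam' or 'sgd'
    'lr': 1e-5,
    'only_last': True,      # optimise only model.model.last_linear, backbone frozen
    'criterion': 'margin',  # 'multisimilarity', 'margin' or 'triplet'
    'iter': 100,            # number of finetuning update steps
}
# finetune_model(opt, model, dataloader, device, finetune_params)
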
/utilities/logger.py:
--------------------------------------------------------------------------------
import datetime, csv, os, numpy as np
from matplotlib import pyplot as plt
import pickle as pkl
from utilities.misc import gimme_save_string

"""============================================================================================================="""
################## WRITE TO CSV FILE #####################
class CSV_Writer():
    def __init__(self, save_path):
        self.save_path = save_path
        self.written = []
        self.n_written_lines = {}

    def log(self, group, segments, content):
        if group not in self.n_written_lines.keys():
            self.n_written_lines[group] = 0

        with open(self.save_path+'_'+group+'.csv', "a") as csv_file:
            writer = csv.writer(csv_file, delimiter=",")
            # Write the header row only the first time this group is logged.
            if group not in self.written: writer.writerow(segments)
            for line in content:
                writer.writerow(line)
                self.n_written_lines[group] += 1

        self.written.append(group)


################## PLOT SUMMARY IMAGE #####################
class InfoPlotter():
    def __init__(self, save_path, title='Training Log', figsize=(25,19)):
        self.save_path = save_path
        self.title = title
        self.figsize = figsize
        self.colors = ['r','g','b','y','m','c','orange','darkgreen','lightblue']

    def make_plot(self, base_title, title_append, sub_plots, sub_plots_data):
        sub_plots = list(sub_plots)
        if 'epochs' not in sub_plots:
            x_data = range(len(sub_plots_data[0]))
        else:
            x_data = range(sub_plots_data[np.where(np.array(sub_plots)=='epochs')[0][0]][-1]+1)

        # Summary line: best value per metric (max by default, min for losses).
        self.ov_title = [(sub_plot, sub_plot_data) for sub_plot, sub_plot_data in zip(sub_plots, sub_plots_data) if sub_plot not in ['epoch', 'epochs', 'time']]
        self.ov_title = [(x[0], np.max(x[1])) if 'loss' not in x[0] else (x[0], np.min(x[1])) for x in self.ov_title]
        self.ov_title = title_append+': '+' | '.join('{0}: {1:.4f}'.format(x[0], x[1]) for x in self.ov_title)

        plt.style.use('ggplot')
        f, ax = plt.subplots(1)
        ax.set_title(self.ov_title, fontsize=22)
        for i, (data, title) in enumerate(zip(sub_plots_data, sub_plots)):
            ax.plot(x_data, data, '-{}'.format(self.colors[i]), linewidth=1.7, label=base_title+' '+title)
        ax.tick_params(axis='both', which='major', labelsize=18)
        ax.tick_params(axis='both', which='minor', labelsize=18)
        ax.legend(loc=2, prop={'size': 16})
        f.set_size_inches(self.figsize[0], self.figsize[1])
        f.savefig(self.save_path+'_'+title_append+'.svg')
        plt.close()


################## GENERATE LOGGING FOLDER/FILES #######################
def set_logging(opt):
    checkfolder = opt.save_path+'/'+opt.savename
    if opt.savename == '':
        date = datetime.datetime.now()
        time_string = '{}-{}-{}-{}-{}-{}'.format(date.year, date.month, date.day, date.hour, date.minute, date.second)
        checkfolder = opt.save_path+'/{}_{}_'.format(opt.dataset.upper(), opt.arch.upper())+time_string
    # Append a counter to the chosen folder name until it no longer collides.
    base_checkfolder, counter = checkfolder, 1
    while os.path.exists(checkfolder):
        checkfolder = base_checkfolder+'_'+str(counter)
        counter += 1
    os.makedirs(checkfolder)
    opt.save_path = checkfolder

    with open(opt.save_path+'/Parameter_Info.txt', 'w') as f:
        f.write(gimme_save_string(opt))
    pkl.dump(opt, open(opt.save_path+"/hypa.pkl", "wb"))


class Progress_Saver():
    def __init__(self):
        self.groups = {}

    def log(self, segment, content, group=None):
        if group is None: group = segment
        if group not in self.groups.keys():
            self.groups[group] = {}

        if segment not in self.groups[group].keys():
            self.groups[group][segment] = {'content': [], 'saved_idx': 0}

        self.groups[group][segment]['content'].append(content)
108 | """ 109 | self.prop = opt 110 | self.prefix = '{}_'.format(prefix) if prefix is not None else '' 111 | self.sub_loggers = sub_loggers 112 | 113 | ### Make Logging Directories 114 | if start_new: set_logging(opt) 115 | 116 | ### Set Graph and CSV writer 117 | self.csv_writer, self.graph_writer, self.progress_saver = {},{},{} 118 | for sub_logger in sub_loggers: 119 | csv_savepath = opt.save_path+'/CSV_Logs' 120 | if not os.path.exists(csv_savepath): os.makedirs(csv_savepath) 121 | self.csv_writer[sub_logger] = CSV_Writer(csv_savepath+'/Data_{}{}'.format(self.prefix, sub_logger)) 122 | 123 | prgs_savepath = opt.save_path+'/Progression_Plots' 124 | if not os.path.exists(prgs_savepath): os.makedirs(prgs_savepath) 125 | self.graph_writer[sub_logger] = InfoPlotter(prgs_savepath+'/Graph_{}{}'.format(self.prefix, sub_logger)) 126 | self.progress_saver[sub_logger] = Progress_Saver() 127 | 128 | 129 | ### WandB Init 130 | self.save_path = opt.save_path 131 | self.log_online = log_online 132 | 133 | 134 | def update(self, *sub_loggers, all=False): 135 | online_content = [] 136 | 137 | if all: sub_loggers = self.sub_loggers 138 | 139 | for sub_logger in list(sub_loggers): 140 | for group in self.progress_saver[sub_logger].groups.keys(): 141 | pgs = self.progress_saver[sub_logger].groups[group] 142 | segments = pgs.keys() 143 | per_seg_saved_idxs = [pgs[segment]['saved_idx'] for segment in segments] 144 | per_seg_contents = [pgs[segment]['content'][idx:] for segment,idx in zip(segments, per_seg_saved_idxs)] 145 | per_seg_contents_all = [pgs[segment]['content'] for segment,idx in zip(segments, per_seg_saved_idxs)] 146 | 147 | #Adjust indexes 148 | for content,segment in zip(per_seg_contents, segments): 149 | self.progress_saver[sub_logger].groups[group][segment]['saved_idx'] += len(content) 150 | 151 | tupled_seg_content = [list(seg_content_slice) for seg_content_slice in zip(*per_seg_contents)] 152 | 153 | self.csv_writer[sub_logger].log(group, segments, tupled_seg_content) 154 | self.graph_writer[sub_logger].make_plot(sub_logger, group, segments, per_seg_contents_all) 155 | 156 | for i,segment in enumerate(segments): 157 | if group == segment: 158 | name = sub_logger+': '+group 159 | else: 160 | name = sub_logger+': '+group+': '+segment 161 | online_content.append((name,per_seg_contents[i])) 162 | 163 | if self.log_online: 164 | if self.prop.online_backend=='wandb': 165 | import wandb 166 | for i,item in enumerate(online_content): 167 | if isinstance(item[1], list): 168 | wandb.log({item[0]:np.mean(item[1])}, step=self.prop.epoch) 169 | else: 170 | wandb.log({item[0]:item[1]}, step=self.prop.epoch) 171 | elif self.prop.online_backend=='comet_ml': 172 | for i,item in enumerate(online_content): 173 | if isinstance(item[1], list): 174 | self.prop.experiment.log_metric(item[0],np.mean(item[1]), self.prop.epoch) 175 | else: 176 | self.prop.experiment.log_metric(item[0],item[1],self.prop.epoch) 177 | -------------------------------------------------------------------------------- /utilities/misc.py: -------------------------------------------------------------------------------- 1 | """=============================================================================================================""" 2 | ######## LIBRARIES ##################### 3 | import numpy as np 4 | import torch 5 | import numpy as np 6 | import random 7 | 8 | def set_seed(seed): 9 | torch.backends.cudnn.deterministic=True; 10 | np.random.seed(seed); random.seed(seed) 11 | torch.manual_seed(seed); torch.cuda.manual_seed(seed); 
/utilities/misc.py:
--------------------------------------------------------------------------------
"""============================================================================================================="""
######## LIBRARIES #####################
import numpy as np
import random
import torch, torch.nn as nn


def set_seed(seed):
    # Seed the python, numpy and torch RNGs for reproducible runs.
    torch.backends.cudnn.deterministic = True
    np.random.seed(seed)
    random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


"""============================================================================================================="""
################# ACQUIRE NUMBER OF WEIGHTS #################
def gimme_params(model):
    model_parameters = filter(lambda p: p.requires_grad, model.parameters())
    params = sum([np.prod(p.size()) for p in model_parameters])
    return params


################# SAVE TRAINING PARAMETERS IN NICE STRING #################
def gimme_save_string(opt):
    varx = vars(opt)
    base_str = ''
    for key in varx:
        base_str += str(key)
        if isinstance(varx[key], dict):
            for sub_key, sub_item in varx[key].items():
                base_str += '\n\t'+str(sub_key)+': '+str(sub_item)
        else:
            base_str += '\n\t'+str(varx[key])
        base_str += '\n\n'
    return base_str


#############################################################################
class DataParallel(nn.Module):
    def __init__(self, model, device_ids, dim):
        super().__init__()
        # Keep a direct handle to the wrapped inner module so attributes such as
        # model.last_linear remain reachable after wrapping.
        self.model = model.model
        self.network = nn.DataParallel(model, device_ids, dim)

    def forward(self, x):
        return self.network(x)
--------------------------------------------------------------------------------