├── roadmap
│   ├── __init__.py
│   ├── utils
│   │   ├── logger.py
│   │   ├── freeze_pos_embedding.py
│   │   ├── count_parameters.py
│   │   ├── set_lr.py
│   │   ├── expand_path.py
│   │   ├── set_initial_lr.py
│   │   ├── get_lr.py
│   │   ├── list_or_tuple.py
│   │   ├── format_time.py
│   │   ├── freeze_batch_norm.py
│   │   ├── str_to_bool.py
│   │   ├── rgb_to_bgr.py
│   │   ├── get_gradient_norm.py
│   │   ├── override_config.py
│   │   ├── average_meter.py
│   │   ├── extract_progress.py
│   │   ├── moving_average.py
│   │   ├── create_label_matrix.py
│   │   ├── dict_average.py
│   │   ├── get_set_random_state.py
│   │   └── __init__.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── create_projection_head.py
│   │   └── net.py
│   ├── config
│   │   ├── loss
│   │   │   ├── fastap.yaml
│   │   │   ├── smoothap.yaml
│   │   │   ├── pair_loss.yaml
│   │   │   ├── blackboxap.yaml
│   │   │   ├── softbinap.yaml
│   │   │   ├── affineap.yaml
│   │   │   ├── supap.yaml
│   │   │   ├── supap_inat.yaml
│   │   │   ├── roadmap.yaml
│   │   │   └── roadmap_inat.yaml
│   │   ├── memory
│   │   │   ├── cub.yaml
│   │   │   ├── default.yaml
│   │   │   ├── sop.yaml
│   │   │   └── inaturalist.yaml
│   │   ├── optimizer
│   │   │   ├── cub.yaml
│   │   │   ├── cub_deit.yaml
│   │   │   ├── inshop_deit.yaml
│   │   │   ├── sfm120k_deit.yaml
│   │   │   ├── sop_deit.yaml
│   │   │   ├── inaturalist_deit.yaml
│   │   │   ├── inshop.yaml
│   │   │   ├── sop.yaml
│   │   │   └── inaturalist.yaml
│   │   ├── dataset
│   │   │   ├── cub.yaml
│   │   │   ├── inaturalist.yaml
│   │   │   ├── sop.yaml
│   │   │   ├── inshop.yaml
│   │   │   └── sfm120k.yaml
│   │   ├── model
│   │   │   ├── resnet.yaml
│   │   │   ├── deit.yaml
│   │   │   └── resnet_max_ln.yaml
│   │   ├── hydra
│   │   │   └── launcher
│   │   │       └── ray_launcher.yaml
│   │   ├── transform
│   │   │   ├── sfm120k.yaml
│   │   │   ├── cub.yaml
│   │   │   ├── cub_big.yaml
│   │   │   ├── inaturalist.yaml
│   │   │   ├── inshop_big.yaml
│   │   │   ├── sop.yaml
│   │   │   └── sop_big.yaml
│   │   ├── default.yaml
│   │   └── experience
│   │       ├── landmarks.yaml
│   │       └── default.yaml
│   ├── losses
│   │   ├── fast_ap.py
│   │   ├── __init__.py
│   │   ├── calibration_loss.py
│   │   ├── pair_loss.py
│   │   ├── blackbox_ap.py
│   │   ├── softbin_ap.py
│   │   └── smooth_rank_ap.py
│   ├── samplers
│   │   ├── __init__.py
│   │   ├── random_sampler.py
│   │   ├── m_per_class_sampler.py
│   │   └── hierarchical_sampler.py
│   ├── datasets
│   │   ├── __init__.py
│   │   ├── sop.py
│   │   ├── sfm120k.py
│   │   ├── inshop.py
│   │   ├── cub200.py
│   │   ├── inaturalist.py
│   │   ├── revisited_dataset.py
│   │   └── base_dataset.py
│   ├── engine
│   │   ├── make_subset.py
│   │   ├── __init__.py
│   │   ├── chepoint.py
│   │   ├── get_knn.py
│   │   ├── memory.py
│   │   ├── base_update.py
│   │   ├── evaluate.py
│   │   ├── train.py
│   │   ├── cross_validation_splits.py
│   │   ├── accuracy_calculator.py
│   │   └── landmark_evaluation.py
│   ├── single_experiment_runner.py
│   ├── evaluate.py
│   ├── getter.py
│   └── run.py
├── picture
│   └── outline.png
├── dev_requirements.txt
├── .gitignore
├── requirements.txt
├── setup.py
├── LICENSE
└── README.md

/roadmap/__init__.py: -------------------------------------------------------------------------------- 1 | --------------------------------------------------------------------------------
/picture/outline.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/elias-ramzi/ROADMAP/HEAD/picture/outline.png --------------------------------------------------------------------------------
/roadmap/utils/logger.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | LOGGER = logging.getLogger("ROADMAP") 4 | --------------------------------------------------------------------------------
/roadmap/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .net import RetrievalNet 2 | 3 | 4 | __all__ = [ 5 | 'RetrievalNet', 6 | ] 7 | --------------------------------------------------------------------------------
/roadmap/config/loss/fastap.yaml:
-------------------------------------------------------------------------------- 1 | # @package _group_ 2 | - name: FastAP 3 | weight: 1.0 4 | kwargs: 5 | num_bins: 10 6 | -------------------------------------------------------------------------------- /roadmap/config/loss/smoothap.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | - name: SmoothAP 3 | weight: 1.0 4 | kwargs: 5 | tau: 0.01 6 | -------------------------------------------------------------------------------- /roadmap/config/loss/pair_loss.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | - name: PairLoss 3 | weight: 1.0 4 | kwargs: 5 | margin: 0.5 6 | -------------------------------------------------------------------------------- /roadmap/utils/freeze_pos_embedding.py: -------------------------------------------------------------------------------- 1 | def freeze_pos_embedding(net): 2 | net.pos_embed.requires_grad_(False) 3 | return net 4 | -------------------------------------------------------------------------------- /roadmap/utils/count_parameters.py: -------------------------------------------------------------------------------- 1 | def count_parameters(model): 2 | return sum(p.numel() for p in model.parameters() if p.requires_grad) 3 | -------------------------------------------------------------------------------- /roadmap/utils/set_lr.py: -------------------------------------------------------------------------------- 1 | def set_lr(optimizer, lr): 2 | for param_group in optimizer.param_groups: 3 | param_group['lr'] = lr 4 | -------------------------------------------------------------------------------- /roadmap/losses/fast_ap.py: -------------------------------------------------------------------------------- 1 | from pytorch_metric_learning import losses 2 | 3 | 4 | class FastAP(losses.FastAPLoss): 5 | takes_embeddings = True 6 | -------------------------------------------------------------------------------- /roadmap/config/loss/blackboxap.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | - name: BlackBoxAP 3 | weight: 1.0 4 | kwargs: 5 | lambda_val: 4.0 6 | margin: 0.02 7 | -------------------------------------------------------------------------------- /roadmap/config/loss/softbinap.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | - name: SoftBinAP 3 | weight: 1.0 4 | kwargs: 5 | nq: 20 6 | min: -1 7 | max: 1 8 | -------------------------------------------------------------------------------- /roadmap/config/memory/cub.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | name: XBM 3 | 4 | activate_after: -1 5 | weight: 1.0 6 | kwargs: 7 | size: 5824 8 | unique: True 9 | -------------------------------------------------------------------------------- /roadmap/config/memory/default.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | name: 3 | 4 | activate_after: -1 5 | weight: 1.0 6 | kwargs: 7 | size: null 8 | unique: null 9 | -------------------------------------------------------------------------------- /roadmap/config/memory/sop.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | name: XBM 3 | 4 | activate_after: -1 5 | weight: 1.0 6 | kwargs: 7 | size: 59551 8 | unique: 
True 9 | -------------------------------------------------------------------------------- /roadmap/config/loss/affineap.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | - name: AffineAP 3 | weight: 1.0 4 | kwargs: 5 | theta: 0.5 6 | mu_n: 0.025 7 | mu_p: 0.025 8 | -------------------------------------------------------------------------------- /roadmap/config/memory/inaturalist.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | name: XBM 3 | 4 | activate_after: -1 5 | weight: 1.0 6 | kwargs: 7 | size: 60000 8 | unique: False 9 | -------------------------------------------------------------------------------- /roadmap/utils/expand_path.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | 4 | def expand_path(pth): 5 | pth = os.path.expandvars(pth) 6 | pth = os.path.expanduser(pth) 7 | return pth 8 | -------------------------------------------------------------------------------- /dev_requirements.txt: -------------------------------------------------------------------------------- 1 | wheel 2 | six 3 | appdirs 4 | ordered_set 5 | ipython 6 | jupyter 7 | ipdb 8 | autopep8 9 | flake8 10 | pylint 11 | isort 12 | jedi==0.17.1 13 | -------------------------------------------------------------------------------- /roadmap/utils/set_initial_lr.py: -------------------------------------------------------------------------------- 1 | def set_initial_lr(optimizer): 2 | for param_group in optimizer.param_groups: 3 | param_group['initial_lr'] = param_group['lr'] 4 | -------------------------------------------------------------------------------- /roadmap/config/loss/supap.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | - name: SupAP 3 | weight: 1.0 4 | kwargs: 5 | tau: 0.01 6 | rho: 100.0 7 | offset: 1.44 8 | delta: 0.05 9 | -------------------------------------------------------------------------------- /roadmap/utils/get_lr.py: -------------------------------------------------------------------------------- 1 | def get_lr(optimizer): 2 | all_lr = {} 3 | for i, param_group in enumerate(optimizer.param_groups): 4 | all_lr[i] = param_group['lr'] 5 | return all_lr 6 | -------------------------------------------------------------------------------- /roadmap/config/loss/supap_inat.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | - name: SupAP 3 | weight: 1.0 4 | kwargs: 5 | tau: 0.01 6 | rho: 100.0 7 | offset: 1.0 8 | delta: 0.05 9 | start: 0.0 10 | -------------------------------------------------------------------------------- /roadmap/utils/list_or_tuple.py: -------------------------------------------------------------------------------- 1 | from omegaconf.listconfig import ListConfig 2 | 3 | 4 | def list_or_tuple(lst): 5 | if isinstance(lst, (ListConfig, tuple, list)): 6 | return True 7 | 8 | return False 9 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .venv 2 | __pycache__ 3 | .ipynb_checkpoints 4 | .pytest_cache 5 | .DS_Store 6 | 7 | build/ 8 | dist/ 9 | *.egg-info 10 | *.so 11 | *.c 12 | 13 | tmp 14 | personal_experiment 15 | .idea 16 | outputs 17 | scripts 18 | -------------------------------------------------------------------------------- 
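Each loss config above (supap.yaml, supap_inat.yaml, etc.) is a list of entries, each carrying a `name`, a `weight` and a `kwargs` block, where the names correspond to classes exported by roadmap.losses (e.g. SupAP, FastAP). The wiring that turns these entries into loss objects lives in roadmap/getter.py, which is not reproduced in this dump; the sketch below shows the assumed pattern, with `build_criterion` being a hypothetical helper name:

    import roadmap.losses as losses

    def build_criterion(loss_cfg):
        # Each entry names a class exported by roadmap.losses; the weights
        # are used later to combine the individual losses into a total loss.
        criterion = []
        for entry in loss_cfg:
            loss_cls = getattr(losses, entry['name'])
            criterion.append((loss_cls(**(entry['kwargs'] or {})), entry['weight']))
        return criterion

A config with two entries, such as roadmap.yaml, therefore yields a weighted sum of CalibrationLoss and SupAP.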
/roadmap/config/optimizer/cub.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | - name: Adam 3 | params: 4 | kwargs: 5 | lr: 0.000001 6 | weight_decay: 0.0004 7 | scheduler_on_epoch: 8 | scheduler_on_step: 9 | scheduler_on_val: 10 | -------------------------------------------------------------------------------- /roadmap/config/optimizer/cub_deit.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | - name: AdamW 3 | params: 4 | kwargs: 5 | lr: 0.000001 6 | weight_decay: 0.0005 7 | scheduler_on_epoch: 8 | scheduler_on_step: 9 | scheduler_on_val: 10 | -------------------------------------------------------------------------------- /roadmap/config/optimizer/inshop_deit.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | - name: AdamW 3 | params: 4 | kwargs: 5 | lr: 0.00001 6 | weight_decay: 0.0005 7 | scheduler_on_epoch: 8 | scheduler_on_step: 9 | scheduler_on_val: 10 | -------------------------------------------------------------------------------- /roadmap/config/optimizer/sfm120k_deit.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | - name: AdamW 3 | params: 4 | kwargs: 5 | lr: 0.00001 6 | weight_decay: 0.0005 7 | scheduler_on_epoch: 8 | scheduler_on_step: 9 | scheduler_on_val: 10 | -------------------------------------------------------------------------------- /roadmap/utils/format_time.py: -------------------------------------------------------------------------------- 1 | def format_time(seconds): 2 | seconds = int(seconds) 3 | minutes = seconds // 60 4 | hours = minutes // 60 5 | minutes = minutes % 60 6 | rseconds = seconds % 60 7 | return f"{hours}h{minutes}m{rseconds}s" 8 | -------------------------------------------------------------------------------- /roadmap/utils/freeze_batch_norm.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | 4 | def freeze_batch_norm(model): 5 | for module in filter(lambda m: type(m) == nn.BatchNorm2d, model.modules()): 6 | module.eval() 7 | module.train = lambda _: None 8 | return model 9 | -------------------------------------------------------------------------------- /roadmap/config/dataset/cub.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | name: Cub200Dataset 3 | kwargs: 4 | data_dir: /local/DEEPLEARNING/image_retrieval/CUB_200_2011 5 | 6 | sampler: 7 | name: MPerClassSampler 8 | kwargs: 9 | batch_size: 128 10 | samples_per_class: 4 11 | -------------------------------------------------------------------------------- /roadmap/utils/str_to_bool.py: -------------------------------------------------------------------------------- 1 | def str_to_bool(condition): 2 | if isinstance(condition, str): 3 | if condition.lower() == 'true': 4 | condition = True 5 | if condition.lower() == 'false': 6 | condition = False 7 | 8 | return condition 9 | -------------------------------------------------------------------------------- /roadmap/config/dataset/inaturalist.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | name: INaturalistDataset 3 | kwargs: 4 | data_dir: /local/DEEPLEARNING/image_retrieval/Inaturalist 5 | 6 | sampler: 7 | name: MPerClassSampler 8 | kwargs: 9 | batch_size: 128 10 | samples_per_class: 4 11 | 
-------------------------------------------------------------------------------- /roadmap/config/model/resnet.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | name: RetrievalNet 3 | freeze_batch_norm: True 4 | freeze_pos_embedding: False 5 | kwargs: 6 | backbone_name: resnet50 7 | embed_dim: 512 8 | norm_features: False 9 | without_fc: False 10 | with_autocast: True 11 | -------------------------------------------------------------------------------- /roadmap/config/model/deit.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | name: RetrievalNet 3 | freeze_batch_norm: False 4 | freeze_pos_embedding: False 5 | kwargs: 6 | backbone_name: vit_deit_distilled 7 | embed_dim: 384 8 | norm_features: False 9 | without_fc: True 10 | with_autocast: True 11 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy==1.20.3 2 | torch==1.8.1 3 | torchvision==0.9.1 4 | faiss-gpu==1.6.5 5 | pillow==8.2.0 6 | pandas==1.2.4 7 | tqdm==4.61.0 8 | tensorboard==2.5.0 9 | matplotlib==3.4.2 10 | pytorch-metric-learning==0.9.99 11 | timm==0.4.12 12 | hydra-core==1.0.6 13 | hydra_colorlog==1.0.1 14 | -------------------------------------------------------------------------------- /roadmap/config/hydra/launcher/ray_launcher.yaml: -------------------------------------------------------------------------------- 1 | # @package hydra.launcher 2 | _target_: hydra_plugins.hydra_ray_launcher.ray_launcher.RayLauncher 3 | ray: 4 | init: 5 | address: null 6 | _temp_dir: ${env:HOME}/experiments/tmp 7 | 8 | remote: 9 | num_gpus: 1 10 | num_cpus: 16 11 | -------------------------------------------------------------------------------- /roadmap/config/loss/roadmap.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | - name: CalibrationLoss 3 | weight: 1.0 4 | kwargs: 5 | pos_margin: 0.9 6 | neg_margin: 0.6 7 | 8 | - name: SupAP 9 | weight: 1.0 10 | kwargs: 11 | tau: 0.01 12 | rho: 100.0 13 | offset: 1.44 14 | delta: 0.05 15 | -------------------------------------------------------------------------------- /roadmap/samplers/__init__.py: -------------------------------------------------------------------------------- 1 | from .hierarchical_sampler import HierarchicalSampler 2 | from .m_per_class_sampler import MPerClassSampler 3 | from .random_sampler import RandomSampler 4 | 5 | 6 | __all__ = [ 7 | 'HierarchicalSampler', 8 | 'MPerClassSampler', 9 | 'RandomSampler', 10 | ] 11 | -------------------------------------------------------------------------------- /roadmap/config/model/resnet_max_ln.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | name: RetrievalNet 3 | freeze_batch_norm: True 4 | freeze_pos_embedding: False 5 | kwargs: 6 | backbone_name: resnet50 7 | embed_dim: 512 8 | norm_features: True 9 | without_fc: False 10 | with_autocast: True 11 | pooling: max 12 | -------------------------------------------------------------------------------- /roadmap/config/dataset/sop.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | name: SOPDataset 3 | kwargs: 4 | data_dir: /local/DEEPLEARNING/image_retrieval/Stanford_Online_Products 5 | 6 | sampler: 7 | name: HierarchicalSampler 8 | kwargs: 9 | 
batch_size: 128 10 | samples_per_class: 4 11 | batches_per_super_pair: 10 12 | -------------------------------------------------------------------------------- /roadmap/config/loss/roadmap_inat.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | - name: CalibrationLoss 3 | weight: 1.0 4 | kwargs: 5 | pos_margin: 0.9 6 | neg_margin: 0.6 7 | 8 | - name: SupAP 9 | weight: 1.0 10 | kwargs: 11 | tau: 0.01 12 | rho: 100.0 13 | offset: 1.0 14 | delta: 0.05 15 | start: 0.0 16 | -------------------------------------------------------------------------------- /roadmap/config/optimizer/sop_deit.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | - name: AdamW 3 | params: 4 | kwargs: 5 | lr: 0.00001 6 | weight_decay: 0.0005 7 | scheduler_on_epoch: 8 | name: MultiStepLR 9 | kwargs: 10 | milestones: [25, 50] 11 | gamma: 0.3 12 | scheduler_on_step: 13 | scheduler_on_val: 14 | -------------------------------------------------------------------------------- /roadmap/config/dataset/inshop.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | name: InShopDataset 3 | kwargs: 4 | data_dir: /local/DEEPLEARNING/image_retrieval/inshop 5 | hierarchy_mode: 'all' 6 | 7 | sampler: 8 | name: HierarchicalSampler 9 | kwargs: 10 | batch_size: 128 11 | samples_per_class: 8 12 | batches_per_super_pair: 4 13 | nb_categories: 2 14 | -------------------------------------------------------------------------------- /roadmap/config/optimizer/inaturalist_deit.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | - name: AdamW 3 | params: 4 | kwargs: 5 | lr: 0.00001 6 | weight_decay: 0.0005 7 | scheduler_on_epoch: 8 | name: MultiStepLR 9 | kwargs: 10 | milestones: [30, 70] 11 | gamma: 0.3 12 | last_epoch: -1 13 | scheduler_on_step: 14 | scheduler_on_val: 15 | -------------------------------------------------------------------------------- /roadmap/utils/rgb_to_bgr.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | 3 | 4 | class RGBToBGR(): 5 | def __call__(self, im): 6 | assert im.mode == 'RGB' 7 | r, g, b = [im.getchannel(i) for i in range(3)] 8 | # RGB mode also for BGR, `3x8-bit pixels, true color`, see PIL doc 9 | im = Image.merge('RGB', [b, g, r]) 10 | return im 11 | -------------------------------------------------------------------------------- /roadmap/utils/get_gradient_norm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def get_gradient_norm(net): 5 | if torch.cuda.device_count() > 1: 6 | net = net.module 7 | 8 | if hasattr(net, 'fc'): 9 | final_layer = net.fc 10 | 11 | elif hasattr(net, 'blocks'): 12 | final_layer = net.blocks[-1].mlp.fc2 13 | 14 | return torch.norm(list(final_layer.parameters())[0].grad, 2) 15 | -------------------------------------------------------------------------------- /roadmap/config/optimizer/inshop.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | - name: Adam 3 | params: backbone 4 | kwargs: 5 | lr: 0.00001 6 | weight_decay: 0.0001 7 | scheduler_on_epoch: 8 | scheduler_on_step: 9 | scheduler_on_val: 10 | 11 | - name: Adam 12 | params: fc 13 | kwargs: 14 | lr: 0.00002 15 | weight_decay: 0.0001 16 | scheduler_on_epoch: 17 | scheduler_on_step: 18 | 
scheduler_on_val: 19 | -------------------------------------------------------------------------------- /roadmap/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .cub200 import Cub200Dataset 2 | from .inaturalist import INaturalistDataset 3 | from .inshop import InShopDataset 4 | from .revisited_dataset import RevisitedDataset 5 | from .sfm120k import SfM120kDataset 6 | from .sop import SOPDataset 7 | 8 | 9 | __all__ = [ 10 | 'Cub200Dataset', 11 | 'INaturalistDataset', 12 | 'InShopDataset', 13 | 'RevisitedDataset', 14 | 'SfM120kDataset', 15 | 'SOPDataset', 16 | ] 17 | -------------------------------------------------------------------------------- /roadmap/config/transform/sfm120k.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | train: 3 | RandomResizedCrop: 4 | scale: [0.16, 1] 5 | ratio: [0.75, 1.33] 6 | size: 224 7 | RandomHorizontalFlip: 8 | p: 0.5 9 | ToTensor: {} 10 | Normalize: 11 | mean: [0.485, 0.456, 0.406] 12 | std: [0.229, 0.224, 0.225] 13 | 14 | test: 15 | Resize: 16 | size: [224,224] 17 | ToTensor: {} 18 | Normalize: 19 | mean: [0.485, 0.456, 0.406] 20 | std: [0.229, 0.224, 0.225] 21 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import setuptools 2 | 3 | with open('requirements.txt') as open_file: 4 | install_requires = open_file.read() 5 | 6 | setuptools.setup( 7 | name='roadmap', 8 | version='0.0.0', 9 | packages=[''], 10 | url='https://github.com/elias-ramzi/ROADMAP', 11 | license='', 12 | author='Elias Ramzi', 13 | author_email='elias.ramzi@lecnam.net', 14 | description='', 15 | python_requires='>=3.6', 16 | install_requires=install_requires 17 | ) 18 | -------------------------------------------------------------------------------- /roadmap/config/transform/cub.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | train: 3 | RandomResizedCrop: 4 | scale: [0.16, 1] 5 | ratio: [0.75, 1.33] 6 | size: 224 7 | RandomHorizontalFlip: 8 | p: 0.5 9 | ToTensor: {} 10 | Normalize: 11 | mean: [0.485, 0.456, 0.406] 12 | std: [0.229, 0.224, 0.225] 13 | 14 | test: 15 | Resize: 16 | size: 256 17 | CenterCrop: 18 | size: 224 19 | ToTensor: {} 20 | Normalize: 21 | mean: [0.485, 0.456, 0.406] 22 | std: [0.229, 0.224, 0.225] 23 | -------------------------------------------------------------------------------- /roadmap/config/transform/cub_big.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | train: 3 | RandomResizedCrop: 4 | scale: [0.16, 1] 5 | ratio: [0.75, 1.33] 6 | size: 256 7 | RandomHorizontalFlip: 8 | p: 0.5 9 | ToTensor: {} 10 | Normalize: 11 | mean: [0.485, 0.456, 0.406] 12 | std: [0.229, 0.224, 0.225] 13 | 14 | test: 15 | Resize: 16 | size: 288 17 | CenterCrop: 18 | size: 256 19 | ToTensor: {} 20 | Normalize: 21 | mean: [0.485, 0.456, 0.406] 22 | std: [0.229, 0.224, 0.225] 23 | -------------------------------------------------------------------------------- /roadmap/utils/override_config.py: -------------------------------------------------------------------------------- 1 | def set_attribute(cfg, key, value): 2 | all_key = key.split('.') 3 | obj = cfg 4 | for k in all_key[:-1]: 5 | try: 6 | obj = obj[int(k)] 7 | except ValueError: 8 | obj = getattr(obj, k) 9 | setattr(obj, all_key[-1], value) 10 | 
return cfg 11 | 12 | 13 | def override_config(hyperparameters, config): 14 | for k, v in hyperparameters.items(): 15 | config = set_attribute(config, k, v) 16 | 17 | return config 18 | -------------------------------------------------------------------------------- /roadmap/config/transform/inaturalist.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | train: 3 | RandomResizedCrop: 4 | scale: [0.16, 1] 5 | ratio: [0.75, 1.33] 6 | size: 224 7 | RandomHorizontalFlip: 8 | p: 0.5 9 | ToTensor: {} 10 | Normalize: 11 | mean: [0.485, 0.456, 0.406] 12 | std: [0.229, 0.224, 0.225] 13 | 14 | test: 15 | Resize: 16 | size: 256 17 | CenterCrop: 18 | size: 224 19 | ToTensor: {} 20 | Normalize: 21 | mean: [0.485, 0.456, 0.406] 22 | std: [0.229, 0.224, 0.225] 23 | -------------------------------------------------------------------------------- /roadmap/config/transform/inshop_big.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | train: 3 | RandomResizedCrop: 4 | scale: [0.16, 1] 5 | ratio: [0.75, 1.33] 6 | size: 256 7 | RandomHorizontalFlip: 8 | p: 0.5 9 | ToTensor: {} 10 | Normalize: 11 | mean: [0.485, 0.456, 0.406] 12 | std: [0.229, 0.224, 0.225] 13 | 14 | test: 15 | Resize: 16 | size: [288,288] 17 | CenterCrop: 18 | size: 256 19 | ToTensor: {} 20 | Normalize: 21 | mean: [0.485, 0.456, 0.406] 22 | std: [0.229, 0.224, 0.225] 23 | -------------------------------------------------------------------------------- /roadmap/losses/__init__.py: -------------------------------------------------------------------------------- 1 | from .blackbox_ap import BlackBoxAP 2 | from .calibration_loss import CalibrationLoss 3 | from .fast_ap import FastAP 4 | from .softbin_ap import SoftBinAP 5 | from .pair_loss import PairLoss 6 | from .smooth_rank_ap import ( 7 | HeavisideAP, 8 | SmoothAP, 9 | SupAP, 10 | ) 11 | 12 | 13 | __all__ = [ 14 | 'BlackBoxAP', 15 | 'CalibrationLoss', 16 | 'FastAP', 17 | 'SoftBinAP', 18 | 'PairLoss', 19 | 'HeavisideAP', 20 | 'SmoothAP', 21 | 'SupAP', 22 | ] 23 | -------------------------------------------------------------------------------- /roadmap/config/transform/sop.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | train: 3 | Resize: 4 | size: 256 5 | RandomResizedCrop: 6 | scale: [0.16, 1] 7 | ratio: [0.75, 1.33] 8 | size: 224 9 | RandomHorizontalFlip: 10 | p: 0.5 11 | ToTensor: {} 12 | Normalize: 13 | mean: [0.485, 0.456, 0.406] 14 | std: [0.229, 0.224, 0.225] 15 | 16 | test: 17 | Resize: 18 | size: [256, 256] 19 | CenterCrop: 20 | size: 224 21 | ToTensor: {} 22 | Normalize: 23 | mean: [0.485, 0.456, 0.406] 24 | std: [0.229, 0.224, 0.225] 25 | -------------------------------------------------------------------------------- /roadmap/config/dataset/sfm120k.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | name: SfM120kDataset 3 | kwargs: 4 | data_dir: /local/DEEPLEARNING/image_retrieval/landmarks/sfm120k 5 | 6 | sampler: 7 | name: MPerClassSampler 8 | kwargs: 9 | batch_size: 128 10 | samples_per_class: 4 11 | 12 | evaluation: 13 | - name: RevisitedDataset 14 | kwargs: 15 | data_dir: /local/DEEPLEARNING/image_retrieval/landmarks/rparis6k 16 | 17 | - name: RevisitedDataset 18 | kwargs: 19 | data_dir: /local/DEEPLEARNING/image_retrieval/landmarks/roxford5k 20 | -------------------------------------------------------------------------------- 
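The transform configs (e.g. transform/sop.yaml just above) map torchvision class names to their keyword arguments, in application order. The construction again happens in getter.py, which is not shown here; a minimal sketch of the assumed pattern, with `build_transform` a hypothetical helper:

    from torchvision import transforms

    def build_transform(group_cfg):
        # group_cfg is one of the `train` / `test` groups, e.g.
        # {'Resize': {'size': 256}, 'CenterCrop': {'size': 224}, ...};
        # custom transforms such as roadmap.utils.RGBToBGR would need a special case.
        ops = [getattr(transforms, name)(**kwargs) for name, kwargs in group_cfg.items()]
        return transforms.Compose(ops)

Key order matters here (Resize before CenterCrop before ToTensor), which OmegaConf dictionaries preserve.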
/roadmap/config/transform/sop_big.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | train: 3 | Resize: 4 | size: 288 5 | RandomResizedCrop: 6 | scale: [0.16, 1] 7 | ratio: [0.75, 1.33] 8 | size: 256 9 | RandomHorizontalFlip: 10 | p: 0.5 11 | ToTensor: {} 12 | Normalize: 13 | mean: [0.485, 0.456, 0.406] 14 | std: [0.229, 0.224, 0.225] 15 | 16 | test: 17 | Resize: 18 | size: [288, 288] 19 | CenterCrop: 20 | size: 256 21 | ToTensor: {} 22 | Normalize: 23 | mean: [0.485, 0.456, 0.406] 24 | std: [0.229, 0.224, 0.225] 25 | --------------------------------------------------------------------------------
/roadmap/config/default.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | 3 | - experience: default 4 | 5 | - dataset: sop 6 | 7 | - loss: roadmap 8 | 9 | - memory: default 10 | 11 | - model: resnet 12 | 13 | - optimizer: sop 14 | 15 | - transform: sop 16 | 17 | - hydra/job_logging: colorlog 18 | 19 | - hydra/hydra_logging: colorlog 20 | 21 | hydra: 22 | run: 23 | dir: ${experience.log_dir}/${experience.experiment_name}/outputs 24 | 25 | sweep: 26 | dir: ${experience.log_dir} 27 | subdir: ${experience.experiment_name}/outputs 28 | --------------------------------------------------------------------------------
/roadmap/utils/average_meter.py: -------------------------------------------------------------------------------- 1 | def _handle_types(value): 2 | if hasattr(value, "detach"): 3 | return value.detach().item() 4 | else: 5 | return value 6 | 7 | 8 | class AverageMeter: 9 | def __init__(self) -> None: 10 | self.val = 0 11 | self.avg = 0 12 | self.sum = 0 13 | self.count = 0 14 | 15 | def update(self, val, n=1) -> None: 16 | val = _handle_types(val) 17 | self.val = val 18 | self.sum += val * n 19 | self.count += n 20 | self.avg = self.sum / self.count 21 | --------------------------------------------------------------------------------
/roadmap/utils/extract_progress.py: -------------------------------------------------------------------------------- 1 | import zipfile 2 | import tarfile 3 | 4 | from tqdm import tqdm 5 | 6 | from .logger import LOGGER 7 | 8 | 9 | def extract_progress(compressed_obj): 10 | LOGGER.info("Extracting dataset") 11 | if isinstance(compressed_obj, tarfile.TarFile): 12 | iterable = compressed_obj 13 | length = len(compressed_obj.getmembers()) 14 | elif isinstance(compressed_obj, zipfile.ZipFile): 15 | iterable = compressed_obj.namelist() 16 | length = len(iterable) 17 | # `tqdm` is imported directly, so it is called as `tqdm(...)` 18 | for member in tqdm(iterable, total=length): 19 | yield member 20 | --------------------------------------------------------------------------------
/roadmap/config/optimizer/sop.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | - name: Adam 3 | params: backbone 4 | kwargs: 5 | lr: 0.00001 6 | weight_decay: 0.0001 7 | scheduler_on_epoch: 8 | name: MultiStepLR 9 | kwargs: 10 | milestones: [30,70] 11 | gamma: 0.3 12 | scheduler_on_step: 13 | scheduler_on_val: 14 | 15 | - name: Adam 16 | params: fc 17 | kwargs: 18 | lr: 0.00002 19 | weight_decay: 0.0001 20 | scheduler_on_epoch: 21 | name: MultiStepLR 22 | kwargs: 23 | milestones: [30,70] 24 | gamma: 0.3 25 | scheduler_on_step: 26 | scheduler_on_val: 27 | --------------------------------------------------------------------------------
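optimizer/sop.yaml above attaches one optimizer per named parameter group: the trunk (`params: backbone`) and the embedding head (`params: fc`) train with different learning rates. Assuming RetrievalNet exposes `backbone` and `fc` submodules, the pattern looks roughly as follows (`build_optimizers` is a hypothetical name; the real construction is in getter.py, not shown):

    import torch

    def build_optimizers(net, optimizer_cfg):
        # One torch.optim optimizer per entry, keyed by the submodule it updates;
        # an empty `params:` (as in optimizer/cub.yaml) targets the whole net.
        optimizers = {}
        for entry in optimizer_cfg:
            target = getattr(net, entry['params']) if entry['params'] else net
            opt_cls = getattr(torch.optim, entry['name'])
            optimizers[entry['params'] or 'net'] = opt_cls(target.parameters(), **entry['kwargs'])
        return optimizers

This dict-of-optimizers shape is consistent with engine/chepoint.py further down, which checkpoints `optimizer` as `{key: opt.state_dict() for key, opt in optimizer.items()}`.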
/roadmap/engine/make_subset.py: -------------------------------------------------------------------------------- 1 | from copy import deepcopy 2 | 3 | 4 | def make_subset(dts, idxs, transform=None, mode=None): 5 | dts = deepcopy(dts) 6 | 7 | dts.paths = [dts.paths[x] for x in idxs] 8 | dts.labels = [dts.labels[x] for x in idxs] 9 | 10 | if hasattr(dts, 'super_labels') and dts.super_labels is not None: 11 | dts.super_labels = [dts.super_labels[x] for x in idxs] 12 | 13 | dts.get_instance_dict() 14 | dts.get_super_dict() 15 | 16 | if transform is not None: 17 | dts.transform = transform 18 | 19 | if mode is not None: 20 | dts.mode = mode 21 | 22 | return dts 23 | --------------------------------------------------------------------------------
/roadmap/config/optimizer/inaturalist.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | - name: Adam 3 | params: backbone 4 | kwargs: 5 | lr: 0.00001 6 | weight_decay: 0.0004 7 | scheduler_on_epoch: 8 | name: MultiStepLR 9 | kwargs: 10 | milestones: [30, 70] 11 | gamma: 0.3 12 | last_epoch: -1 13 | scheduler_on_step: 14 | scheduler_on_val: 15 | 16 | - name: Adam 17 | params: fc 18 | kwargs: 19 | lr: 0.00002 20 | weight_decay: 0.0004 21 | scheduler_on_epoch: 22 | name: MultiStepLR 23 | kwargs: 24 | milestones: [30, 70] 25 | gamma: 0.3 26 | last_epoch: -1 27 | scheduler_on_step: 28 | scheduler_on_val: 29 | --------------------------------------------------------------------------------
/roadmap/utils/moving_average.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | class MovingAverage: 5 | 6 | def __init__(self, ws): 7 | # `ws` is the sliding window size 8 | self.ws = ws 9 | self.data = [] 10 | 11 | def __call__(self, value): 12 | # evict the oldest value once the window is full 13 | if len(self.data) >= self.ws: 14 | self.data.pop(0) 15 | 16 | self.data.append(value) 17 | return np.mean(self.data) 18 | 19 | def mean_first(self, value): 20 | # return the mean of the current window, then slide it by one value 21 | if self.data: 22 | mean = np.mean(self.data) 23 | if len(self.data) >= self.ws: 24 | self.data.pop(0) 25 | self.data.append(value) 26 | return mean 27 | 28 | else: 29 | self.data.append(value) 30 | return value 31 | --------------------------------------------------------------------------------
/roadmap/config/experience/landmarks.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | experiment_name: ??? 3 | log_dir: /share/DEEPLEARNING/datasets/image_retrieval/experiments/ 4 | seed: ??? 5 | resume: null 6 | maybe_resume: False 7 | force_lr: null 8 | 9 | warm_up: -1 10 | warm_up_key: fc 11 | 12 | max_iter: 50 13 | 14 | train_eval_freq: -1 15 | val_eval_freq: -1 16 | test_eval_freq: 5 17 | eval_bs: 96 18 | save_model: 50 19 | eval_split: rparis6k 20 | principal_metric: mapH 21 | log_grad: False 22 | with_AP: False 23 | landmarks: True 24 | 25 | split: null 26 | kfold: null 27 | split_random_state: null 28 | with_super_labels: False 29 | 30 | num_workers: 16 31 | pin_memory: True 32 | 33 | sub_batch: 128 34 | update_type: base_update 35 | --------------------------------------------------------------------------------
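MovingAverage above keeps a sliding window of the last `ws` values; a short usage sketch:

    ma = MovingAverage(ws=3)
    ma(1.0)  # -> 1.0   window: [1.0]
    ma(2.0)  # -> 1.5   window: [1.0, 2.0]
    ma(4.0)  # -> 2.33  window: [1.0, 2.0, 4.0]
    ma(8.0)  # -> 4.67  window: [2.0, 4.0, 8.0], 1.0 evicted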
/roadmap/config/experience/default.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | experiment_name: ??? 3 | log_dir: /share/DEEPLEARNING/datasets/image_retrieval/experiments/ 4 | seed: ??? 5 | resume: null 6 | maybe_resume: False 7 | force_lr: null 8 | 9 | warm_up: -1 10 | warm_up_key: fc 11 | 12 | max_iter: 50 13 | 14 | train_eval_freq: -1 15 | val_eval_freq: -1 16 | test_eval_freq: 5 17 | eval_bs: 96 18 | save_model: 50 19 | eval_split: test 20 | principal_metric: mean_average_precision_at_r_level0 21 | log_grad: False 22 | with_AP: False 23 | landmarks: False 24 | 25 | split: null 26 | kfold: null 27 | split_random_state: null 28 | with_super_labels: False 29 | 30 | num_workers: 16 31 | pin_memory: True 32 | 33 | sub_batch: 128 34 | update_type: base_update 35 | --------------------------------------------------------------------------------
/roadmap/utils/create_label_matrix.py: -------------------------------------------------------------------------------- 1 | def create_label_matrix(labels, other_labels=None): 2 | labels = labels.squeeze() 3 | 4 | if labels.ndim == 1: 5 | if other_labels is None: 6 | return (labels.view(-1, 1) == labels.t()).float() 7 | 8 | return (labels.view(-1, 1) == other_labels.t()).float() 9 | 10 | elif labels.ndim == 2: 11 | size = labels.size(0) 12 | if other_labels is None: 13 | return (labels.view(size, size, 1) == labels.view(size, 1, size)).float() 14 | 15 | raise NotImplementedError(f"Function for tensor of dimension {labels.ndim} compared to tensor of dimension {other_labels.ndim} not implemented") 16 | 17 | raise NotImplementedError(f"Function for tensor dimension {labels.ndim} not implemented") 18 | --------------------------------------------------------------------------------
/roadmap/engine/__init__.py: -------------------------------------------------------------------------------- 1 | from .accuracy_calculator import CustomCalculator 2 | from .base_update import base_update 3 | from .chepoint import checkpoint 4 | from .cross_validation_splits import ( 5 | get_class_disjoint_splits, 6 | get_hierarchical_class_disjoint_splits, 7 | get_closed_set_splits, 8 | get_splits, 9 | ) 10 | from .evaluate import evaluate, get_tester 11 | from .landmark_evaluation import landmark_evaluation 12 | from .make_subset import make_subset 13 | from .memory import XBM 14 | from .train import train 15 | 16 | 17 | __all__ = [ 18 | 'CustomCalculator', 19 | 'base_update', 20 | 'checkpoint', 21 | 'get_class_disjoint_splits', 22 | 'get_hierarchical_class_disjoint_splits', 23 | 'get_closed_set_splits', 24 | 'get_splits', 25 | 'evaluate', 26 | 'get_tester', 27 | 'landmark_evaluation', 28 | 'make_subset', 29 | 'XBM', 30 | 'train', 31 | ] 32 | --------------------------------------------------------------------------------
/roadmap/utils/dict_average.py: -------------------------------------------------------------------------------- 1 | from .average_meter import AverageMeter 2 | 3 | 4 | class DictAverage: 5 | 6 | def __init__(self,) -> None: 7 | self.dict_avg = {} 8 | 9 | def update(self, dict_values: dict) -> None: 10 | for key, item in dict_values.items(): 11 | try: 12 | self.dict_avg[key].update(item) 13 | except KeyError: 14 | self.dict_avg[key] = AverageMeter() 15 | self.dict_avg[key].update(item) 16 | 17 | def keys(self,): 18 | return self.dict_avg.keys() 19 | 20 | def __getitem__(self, name): 21 | return self.dict_avg[name] 22 | 23 | def get(self, name, other=None): 24 | try: 25 | return self.dict_avg[name] 26 | except KeyError: 27 | return self.dict_avg[other] 28 | 29 | @property 30 | def avg(self,) -> dict: 31 | return {key: item.avg for key, item in self.dict_avg.items()} 32 | --------------------------------------------------------------------------------
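engine/__init__.py above exports XBM, the cross-batch memory defined in engine/memory.py at the end of this dump. The losses that accept `ref_embeddings`/`ref_labels` (CalibrationLoss, PairLoss) can be evaluated against that memory. The real training logic lives in engine/base_update.py, which is not included here; the fragment below is only a sketch of the pattern, with `net`, `criterion` and the dataset-index `keys` assumed:

    # memory sized for the SOP train set, as in config/memory/sop.yaml
    xbm = XBM(size=59551, weight=1.0, activate_after=-1, unique=True)

    def step(net, criterion, images, labels, keys, epoch):
        embeddings = net(images)
        loss = criterion(embeddings, labels)
        if epoch > xbm.activate_after:
            # store the current batch, then compare it against every
            # embedding accumulated over previous batches
            xbm.add_with_keys(embeddings.detach(), labels, keys)
            ref_embeddings, ref_labels = xbm.get_occupied_storage()
            loss = loss + xbm.weight * criterion(embeddings, labels, ref_embeddings, ref_labels)
        return loss

With `unique=True` the memory is keyed by dataset index, so each example is stored at most once and is updated in place when its embedding refreshes.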
/roadmap/samplers/random_sampler.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from torch.utils.data import BatchSampler 3 | 4 | import roadmap.utils as lib 5 | 6 | 7 | class RandomSampler(BatchSampler): 8 | def __init__( 9 | self, 10 | dataset, 11 | batch_size, 12 | ): 13 | self.batch_size = batch_size 14 | 15 | self.length = len(dataset) 16 | self.reshuffle() 17 | 18 | def __iter__(self,): 19 | self.reshuffle() 20 | for batch in self.batches: 21 | yield batch 22 | 23 | def __len__(self,): 24 | return len(self.batches) 25 | 26 | def __repr__(self,): 27 | return f"{self.__class__.__name__}(batch_size={self.batch_size})" 28 | 29 | def reshuffle(self): 30 | lib.LOGGER.info("Shuffling data") 31 | idxs = list(range(self.length)) 32 | np.random.shuffle(idxs) 33 | self.batches = [] 34 | for i in range(self.length // self.batch_size): 35 | self.batches.append(idxs[i*self.batch_size:(i+1)*self.batch_size]) 36 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Elias Ramzi 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /roadmap/models/create_projection_head.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | import roadmap.utils as lib 4 | 5 | 6 | def create_projection_head( 7 | input_dimension=2048, 8 | layer_dim=512, 9 | normalization_layer='none', 10 | ): 11 | if isinstance(layer_dim, int): 12 | return nn.Linear(input_dimension, layer_dim) 13 | 14 | elif lib.list_or_tuple(layer_dim): 15 | layers = [] 16 | prev_dim = input_dimension 17 | for i, dim in enumerate(layer_dim): 18 | layers.append(nn.Linear(prev_dim, dim)) 19 | prev_dim = dim 20 | if i < len(layer_dim) - 1: 21 | if normalization_layer == 'bn': 22 | layers.append(nn.BatchNorm1d(dim)) 23 | elif normalization_layer == 'ln': 24 | layers.append(nn.LayerNorm(dim)) 25 | elif normalization_layer == 'none': 26 | pass 27 | else: 28 | raise ValueError(f"Unknown normalization layer : {normalization_layer}") 29 | layers.append(nn.ReLU(inplace=True)) 30 | 31 | return nn.Sequential(*layers) 32 | -------------------------------------------------------------------------------- /roadmap/utils/get_set_random_state.py: -------------------------------------------------------------------------------- 1 | import random 2 | from functools import wraps 3 | 4 | import numpy as np 5 | import torch 6 | 7 | from .logger import LOGGER 8 | 9 | 10 | def get_random_state(): 11 | LOGGER.debug("Getting random state") 12 | RANDOM_STATE = {} 13 | RANDOM_STATE["RANDOM_STATE"] = random.getstate() 14 | RANDOM_STATE["NP_STATE"] = np.random.get_state() 15 | RANDOM_STATE["TORCH_STATE"] = torch.random.get_rng_state() 16 | RANDOM_STATE["TORCH_CUDA_STATE"] = torch.cuda.get_rng_state_all() 17 | return RANDOM_STATE 18 | 19 | 20 | def set_random_state(RANDOM_STATE): 21 | LOGGER.debug("Setting random state") 22 | random.setstate(RANDOM_STATE["RANDOM_STATE"]) 23 | np.random.set_state(RANDOM_STATE["NP_STATE"]) 24 | torch.random.set_rng_state(RANDOM_STATE["TORCH_STATE"]) 25 | torch.cuda.set_rng_state_all(RANDOM_STATE["TORCH_CUDA_STATE"]) 26 | 27 | 28 | def get_set_random_state(func): 29 | @wraps(func) 30 | def wrapper(*args, **kwargs): 31 | RANDOM_STATE = get_random_state() 32 | output = func(*args, **kwargs) 33 | set_random_state(RANDOM_STATE) 34 | return output 35 | return wrapper 36 | -------------------------------------------------------------------------------- /roadmap/datasets/sop.py: -------------------------------------------------------------------------------- 1 | from os.path import join 2 | 3 | import pandas as pd 4 | 5 | from .base_dataset import BaseDataset 6 | 7 | 8 | class SOPDataset(BaseDataset): 9 | 10 | def __init__(self, data_dir, mode, transform=None, **kwargs): 11 | super().__init__(**kwargs) 12 | 13 | self.data_dir = data_dir 14 | self.mode = mode 15 | self.transform = transform 16 | 17 | if mode == 'train': 18 | mode = ['train'] 19 | elif mode == 'test': 20 | mode = ['test'] 21 | elif mode == 'all': 22 | mode = ['train', 'test'] 23 | else: 24 | raise ValueError(f"Mode unrecognized {mode}") 25 | 26 | self.paths = [] 27 | self.labels = [] 28 | self.super_labels = [] 29 | for splt in mode: 30 | gt = pd.read_csv(join(self.data_dir, f'Ebay_{splt}.txt'), sep=' ') 31 | self.paths.extend(gt["path"].apply(lambda x: join(self.data_dir, x)).tolist()) 32 | self.labels.extend((gt["class_id"] - 1).tolist()) 33 | self.super_labels.extend((gt["super_class_id"] - 1).tolist()) 34 | 35 | self.get_instance_dict() 36 | 
self.get_super_dict() 37 | -------------------------------------------------------------------------------- /roadmap/datasets/sfm120k.py: -------------------------------------------------------------------------------- 1 | from os.path import join 2 | 3 | from scipy.io import loadmat 4 | 5 | from .base_dataset import BaseDataset 6 | 7 | 8 | def cid2filename(cid, prefix): 9 | """ 10 | https://github.com/filipradenovic/cnnimageretrieval-pytorch 11 | 12 | Creates a training image path out of its CID name 13 | 14 | Arguments 15 | --------- 16 | cid : name of the image 17 | prefix : root directory where images are saved 18 | 19 | Returns 20 | ------- 21 | filename : full image filename 22 | """ 23 | return join(prefix, cid[-2:], cid[-4:-2], cid[-6:-4], cid) 24 | 25 | 26 | class SfM120kDataset(BaseDataset): 27 | 28 | def __init__(self, data_dir, mode, transform=None, **kwargs): 29 | super().__init__(**kwargs) 30 | 31 | self.data_dir = data_dir 32 | self.mode = mode 33 | self.transform = transform 34 | 35 | db = loadmat(join(self.data_dir, "retrieval-SfM-120k-imagenames-clusterids.mat")) 36 | 37 | cids = [x[0] for x in db['cids'][0]] 38 | self.paths = [cid2filename(x, join(self.data_dir, "ims")) for x in cids] 39 | self.labels = [int(x) for x in db['cluster'][0]] 40 | 41 | self.get_instance_dict() 42 | -------------------------------------------------------------------------------- /roadmap/losses/calibration_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from pytorch_metric_learning import losses, distances 3 | from pytorch_metric_learning.utils import common_functions as c_f 4 | from pytorch_metric_learning.utils import loss_and_miner_utils as lmu 5 | 6 | 7 | class CalibrationLoss(losses.ContrastiveLoss): 8 | takes_embeddings = True 9 | 10 | def get_default_distance(self): 11 | return distances.DotProductSimilarity() 12 | 13 | def forward(self, embeddings, labels, ref_embeddings=None, ref_labels=None): 14 | if ref_embeddings is None: 15 | return super().forward(embeddings, labels) 16 | 17 | indices_tuple = self.create_indices_tuple( 18 | embeddings.size(0), 19 | embeddings, 20 | labels, 21 | ref_embeddings, 22 | ref_labels, 23 | ) 24 | 25 | combined_embeddings = torch.cat([embeddings, ref_embeddings], dim=0) 26 | combined_labels = torch.cat([labels, ref_labels], dim=0) 27 | return super().forward(combined_embeddings, combined_labels, indices_tuple) 28 | 29 | def create_indices_tuple( 30 | self, 31 | batch_size, 32 | embeddings, 33 | labels, 34 | E_mem, 35 | L_mem, 36 | ): 37 | indices_tuple = lmu.get_all_pairs_indices(labels, L_mem) 38 | indices_tuple = c_f.shift_indices_tuple(indices_tuple, batch_size) 39 | return indices_tuple 40 | -------------------------------------------------------------------------------- /roadmap/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .average_meter import AverageMeter 2 | from .count_parameters import count_parameters 3 | from .create_label_matrix import create_label_matrix 4 | from .dict_average import DictAverage 5 | from .expand_path import expand_path 6 | from .extract_progress import extract_progress 7 | from .format_time import format_time 8 | from .freeze_batch_norm import freeze_batch_norm 9 | from .freeze_pos_embedding import freeze_pos_embedding 10 | from .get_gradient_norm import get_gradient_norm 11 | from .get_lr import get_lr 12 | from .get_set_random_state import get_random_state, set_random_state, get_set_random_state 13 
| from .list_or_tuple import list_or_tuple 14 | from .logger import LOGGER 15 | from .moving_average import MovingAverage 16 | from .override_config import override_config 17 | from .rgb_to_bgr import RGBToBGR 18 | from .set_initial_lr import set_initial_lr 19 | from .set_lr import set_lr 20 | from .str_to_bool import str_to_bool 21 | 22 | 23 | __all__ = [ 24 | 'AverageMeter', 25 | 'count_parameters', 26 | 'create_label_matrix', 27 | 'DictAverage', 28 | 'expand_path', 29 | 'extract_progress', 30 | 'format_time', 31 | 'freeze_batch_norm', 32 | 'freeze_pos_embedding', 33 | 'get_gradient_norm', 34 | 'get_random_state', 35 | 'set_random_state', 36 | 'get_set_random_state', 37 | 'get_lr', 38 | 'list_or_tuple', 39 | 'LOGGER', 40 | 'MovingAverage', 41 | 'override_config', 42 | 'RGBToBGR', 43 | 'set_initial_lr', 44 | 'set_lr', 45 | 'str_to_bool', 46 | ] 47 | -------------------------------------------------------------------------------- /roadmap/datasets/inshop.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from .base_dataset import BaseDataset 4 | 5 | 6 | class InShopDataset(BaseDataset): 7 | 8 | def __init__(self, data_dir, mode, transform=None, hierarchy_mode='all', **kwargs): 9 | super().__init__(**kwargs) 10 | 11 | assert mode in ["train", "query", "gallery"], f"Mode : {mode} unknown" 12 | assert hierarchy_mode in ['1', '2', 'all'], f"Hierarchy mode : {hierarchy_mode} unknown" 13 | self.data_dir = data_dir 14 | self.mode = mode 15 | self.transform = transform 16 | 17 | with open(os.path.join(data_dir, "list_eval_partition.txt")) as f: 18 | db = f.read().split("\n")[2:-1] 19 | 20 | paths = [] 21 | labels = [] 22 | super_labels_name = [] 23 | for line in db: 24 | line = line.split(" ") 25 | line = list(filter(lambda x: x, line)) 26 | if line[2] == mode: 27 | paths.append(os.path.join(data_dir, line[0])) 28 | labels.append(int(line[1].split("_")[-1])) 29 | if hierarchy_mode == '2': 30 | super_labels_name.append(line[0].split("/")[2]) 31 | elif hierarchy_mode == '1': 32 | super_labels_name.append(line[0].split("/")[1]) 33 | elif hierarchy_mode == 'all': 34 | super_labels_name.append('/'.join(line[0].split("/")[1:3])) 35 | 36 | self.paths = paths 37 | self.labels = labels 38 | 39 | slb_to_id = {slb: i for i, slb in enumerate(set(super_labels_name))} 40 | self.super_labels = [slb_to_id[slb] for slb in super_labels_name] 41 | self.super_labels_name = super_labels_name 42 | 43 | self.get_instance_dict() 44 | self.get_super_dict() 45 | -------------------------------------------------------------------------------- /roadmap/losses/pair_loss.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Malong Technologies Co., Ltd. 2 | # All rights reserved. 
3 | # 4 | # Contact: github@malong.com 5 | # 6 | # This source code is licensed under the LICENSE file in 7 | # https://github.com/msight-tech/research-xbm/blob/master/LICENSE 8 | import torch 9 | from torch import nn 10 | 11 | 12 | class PairLoss(nn.Module): 13 | takes_embeddings = True 14 | 15 | def __init__(self, margin=0.5): 16 | super().__init__() 17 | self.margin = margin 18 | 19 | def compute_loss(self, inputs_col, targets_col, inputs_row, target_row): 20 | 21 | n = inputs_col.size(0) 22 | # Compute similarity matrix 23 | sim_mat = torch.matmul(inputs_col, inputs_row.t()) 24 | epsilon = 1e-5 25 | loss = list() 26 | 27 | neg_count = list() 28 | for i in range(n): 29 | pos_pair_ = torch.masked_select(sim_mat[i], targets_col[i] == target_row) 30 | pos_pair_ = torch.masked_select(pos_pair_, pos_pair_ < 1 - epsilon) 31 | neg_pair_ = torch.masked_select(sim_mat[i], targets_col[i] != target_row) 32 | 33 | neg_pair = torch.masked_select(neg_pair_, neg_pair_ > self.margin) 34 | 35 | pos_loss = torch.sum(-pos_pair_ + 1) 36 | if len(neg_pair) > 0: 37 | neg_loss = torch.sum(neg_pair) 38 | neg_count.append(len(neg_pair)) 39 | else: 40 | neg_loss = 0 41 | 42 | loss.append(pos_loss + neg_loss) 43 | 44 | loss = sum(loss) / n 45 | return loss 46 | 47 | def forward(self, embeddings, labels, ref_embeddings=None, ref_labels=None): 48 | if ref_embeddings is None: 49 | return self.compute_loss(embeddings, labels, embeddings, labels) 50 | 51 | return self.compute_loss(embeddings, labels, ref_embeddings, ref_labels) 52 | 53 | def extra_repr(self,): 54 | return f"margin={self.margin}" 55 | -------------------------------------------------------------------------------- /roadmap/datasets/cub200.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import numpy as np 4 | from torchvision import datasets 5 | 6 | from .base_dataset import BaseDataset 7 | 8 | 9 | class Cub200Dataset(BaseDataset): 10 | 11 | def __init__(self, data_dir, mode, transform=None, load_super_labels=False, **kwargs): 12 | super().__init__(**kwargs) 13 | self.data_dir = data_dir 14 | self.mode = mode 15 | self.transform = transform 16 | self.load_super_labels = load_super_labels 17 | 18 | dataset = datasets.ImageFolder(os.path.join(self.data_dir, 'images')) 19 | paths = np.array([a for (a, b) in dataset.imgs]) 20 | labels = np.array([b for (a, b) in dataset.imgs]) 21 | 22 | sorted_lb = list(sorted(set(labels))) 23 | if mode == 'train': 24 | set_labels = set(sorted_lb[:len(sorted_lb) // 2]) 25 | elif mode == 'test': 26 | set_labels = set(sorted_lb[len(sorted_lb) // 2:]) 27 | elif mode == 'all': 28 | set_labels = sorted_lb 29 | 30 | self.paths = [] 31 | self.labels = [] 32 | for lb, pth in zip(labels, paths): 33 | if lb in set_labels: 34 | self.paths.append(pth) 35 | self.labels.append(lb) 36 | 37 | self.super_labels = None 38 | if self.load_super_labels: 39 | with open(os.path.join(self.data_dir, "classes.txt")) as f: 40 | lines = f.read().split("\n") 41 | lines.remove("") 42 | labels_id = list(map(lambda x: int(x.split(" ")[0])-1, lines)) 43 | super_labels_name = list(map(lambda x: x.split(" ")[2], lines)) 44 | slb_names_to_id = {x: i for i, x in enumerate(sorted(set(super_labels_name)))} 45 | super_labels = [slb_names_to_id[x] for x in super_labels_name] 46 | labels_to_super_labels = {lb: slb for lb, slb in zip(labels_id, super_labels)} 47 | self.super_labels = [labels_to_super_labels[x] for x in self.labels] 48 | 49 | self.get_instance_dict() 50 | 
-------------------------------------------------------------------------------- /roadmap/engine/chepoint.py: -------------------------------------------------------------------------------- 1 | from os.path import join 2 | 3 | import torch 4 | 5 | import roadmap.utils as lib 6 | 7 | 8 | def checkpoint( 9 | log_dir, 10 | save_checkpoint, 11 | net, 12 | optimizer, 13 | scheduler, 14 | scaler, 15 | epoch, 16 | seed, 17 | args, 18 | score, 19 | best_model, 20 | best_score, 21 | ): 22 | state_dict = {} 23 | if torch.cuda.device_count() > 1: 24 | state_dict["net_state"] = net.module.state_dict() 25 | else: 26 | state_dict["net_state"] = net.state_dict() 27 | 28 | state_dict["optimizer_state"] = {key: opt.state_dict() for key, opt in optimizer.items()} 29 | 30 | state_dict["scheduler_on_epoch_state"] = [sch.state_dict() for sch in scheduler["on_epoch"]] 31 | state_dict["scheduler_on_step_state"] = [sch.state_dict() for sch in scheduler["on_step"]] 32 | state_dict["scheduler_on_val_state"] = [sch.state_dict() for sch, _ in scheduler["on_val"]] 33 | 34 | if scaler is not None: 35 | state_dict["scaler_state"] = scaler.state_dict() 36 | 37 | state_dict["epoch"] = epoch 38 | state_dict["seed"] = seed 39 | state_dict["config"] = args 40 | state_dict["score"] = score 41 | state_dict["best_score"] = best_score 42 | state_dict["best_model"] = f"{best_model}.ckpt" 43 | 44 | RANDOM_STATE = lib.get_random_state() 45 | state_dict.update(RANDOM_STATE) 46 | 47 | if log_dir is None: 48 | from ray import tune 49 | torch.save(state_dict, join(tune.get_trial_dir(), "rolling.ckpt")) 50 | if save_checkpoint: 51 | with tune.checkpoint_dir(step=epoch) as checkpoint_dir: 52 | lib.LOGGER.info(f"Checkpoint of epoch {epoch} created") 53 | torch.save(state_dict, join(checkpoint_dir, f"epoch_{epoch}.ckpt")) 54 | 55 | else: 56 | torch.save(state_dict, join(log_dir, 'weights', "rolling.ckpt")) 57 | if save_checkpoint: 58 | lib.LOGGER.info(f"Checkpoint of epoch {epoch} created") 59 | torch.save(state_dict, join(log_dir, 'weights', f"epoch_{epoch}.ckpt")) 60 | -------------------------------------------------------------------------------- /roadmap/engine/get_knn.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import faiss 4 | import pytorch_metric_learning.utils.common_functions as c_f 5 | 6 | import roadmap.utils as lib 7 | 8 | 9 | def get_knn(references, queries, num_k, embeddings_come_from_same_source, with_faiss=True): 10 | num_k += embeddings_come_from_same_source 11 | 12 | lib.LOGGER.info("running k-nn with k=%d" % num_k) 13 | lib.LOGGER.info("embedding dimensionality is %d" % references.size(-1)) 14 | 15 | if with_faiss: 16 | distances, indices = get_knn_faiss(references, queries, num_k) 17 | else: 18 | distances, indices = get_knn_torch(references, queries, num_k) 19 | 20 | if embeddings_come_from_same_source: 21 | return indices[:, 1:], distances[:, 1:] 22 | 23 | return indices, distances 24 | 25 | 26 | def get_knn_faiss(references, queries, num_k): 27 | lib.LOGGER.debug("Computing k-nn with faiss") 28 | 29 | d = references.size(-1) 30 | device = references.device 31 | references = c_f.to_numpy(references).astype(np.float32) 32 | queries = c_f.to_numpy(queries).astype(np.float32) 33 | 34 | index = faiss.IndexFlatL2(d) 35 | try: 36 | if torch.cuda.device_count() > 1: 37 | co = faiss.GpuMultipleClonerOptions() 38 | co.shards = True 39 | index = faiss.index_cpu_to_all_gpus(index, co) 40 | else: 41 | res = faiss.StandardGpuResources() 42 | 
index = faiss.index_cpu_to_gpu(res, 0, index) 43 | except AttributeError: 44 | # Only faiss CPU is installed 45 | pass 46 | 47 | index.add(references) 48 | distances, indices = index.search(queries, num_k) 49 | distances = c_f.to_device(torch.from_numpy(distances), device=device) 50 | indices = c_f.to_device(torch.from_numpy(indices), device=device) 51 | index.reset() 52 | return distances, indices 53 | 54 | 55 | def get_knn_torch(references, queries, num_k): 56 | lib.LOGGER.debug("Computing k-nn with torch") 57 | 58 | scores = queries @ references.t() 59 | distances, indices = torch.topk(scores, num_k) 60 | return distances, indices 61 | -------------------------------------------------------------------------------- /roadmap/samplers/m_per_class_sampler.py: -------------------------------------------------------------------------------- 1 | """ 2 | adapted from : 3 | https://github.com/Andrew-Brown1/Smooth_AP/blob/master/src/datasets.py 4 | """ 5 | import copy 6 | 7 | import numpy as np 8 | 9 | import roadmap.utils as lib 10 | 11 | 12 | def flatten(list_): 13 | return [item for sublist in list_ for item in sublist] 14 | 15 | 16 | class MPerClassSampler: 17 | def __init__( 18 | self, 19 | dataset, 20 | batch_size, 21 | samples_per_class=4, 22 | ): 23 | """ 24 | Args: 25 | image_dict: two-level dict, `super_dict[super_class_id][class_id]` gives the list of 26 | image idxs having the same super-label and class label 27 | """ 28 | assert samples_per_class > 1 29 | assert batch_size % samples_per_class == 0 30 | self.image_dict = dataset.instance_dict.copy() 31 | self.batch_size = batch_size 32 | self.samples_per_class = samples_per_class 33 | 34 | self.reshuffle() 35 | 36 | def __iter__(self,): 37 | for batch in self.batches: 38 | yield batch 39 | 40 | def __len__(self,): 41 | return len(self.batches) 42 | 43 | def __repr__(self,): 44 | return ( 45 | f"{self.__class__.__name__}(\n" 46 | f" batch_size={self.batch_size},\n" 47 | f" samples_per_class={self.samples_per_class}\n)" 48 | ) 49 | 50 | def reshuffle(self): 51 | lib.LOGGER.info("Shuffling data") 52 | image_dict = copy.deepcopy(self.image_dict) 53 | for sub in image_dict: 54 | np.random.shuffle(image_dict[sub]) 55 | 56 | classes = [*image_dict] 57 | np.random.shuffle(classes) 58 | total_batches = [] 59 | batch = [] 60 | finished = 0 61 | while finished == 0: 62 | for sub_class in classes: 63 | if (len(image_dict[sub_class]) >= self.samples_per_class) and (len(batch) < self.batch_size/self.samples_per_class): 64 | batch.append(image_dict[sub_class][:self.samples_per_class]) 65 | image_dict[sub_class] = image_dict[sub_class][self.samples_per_class:] 66 | 67 | if len(batch) == self.batch_size/self.samples_per_class: 68 | batch = flatten(batch) 69 | np.random.shuffle(batch) 70 | total_batches.append(batch) 71 | batch = [] 72 | else: 73 | finished = 1 74 | 75 | np.random.shuffle(total_batches) 76 | self.batches = total_batches 77 | -------------------------------------------------------------------------------- /roadmap/datasets/inaturalist.py: -------------------------------------------------------------------------------- 1 | import json 2 | from os.path import join 3 | 4 | from .base_dataset import BaseDataset 5 | 6 | 7 | class INaturalistDataset(BaseDataset): 8 | 9 | def __init__(self, data_dir, mode, transform=None, **kwargs): 10 | super().__init__(**kwargs) 11 | 12 | self.data_dir = data_dir 13 | self.mode = mode 14 | self.transform = transform 15 | 16 | if mode == 'train': 17 | mode = ['train'] 18 | elif mode == 'test': 19 | mode = 
['test'] 20 | elif mode == 'all': 21 | mode = ['train', 'test'] 22 | else: 23 | raise ValueError(f"Mode unrecognized {mode}") 24 | 25 | self.paths = [] 26 | for splt in mode: 27 | with open(join(self.data_dir, f'Inat_dataset_splits/Inaturalist_{splt}_set1.txt')) as f: 28 | paths = f.read().split("\n") 29 | paths.remove("") 30 | self.paths.extend([join(self.data_dir, pth) for pth in paths]) 31 | 32 | with open(join(self.data_dir, 'train2018.json')) as f: 33 | db = json.load(f)['categories'] 34 | self.db = {} 35 | for x in db: 36 | _ = x.pop("name") 37 | id_ = x.pop("id") 38 | x["species"] = id_ 39 | self.db[id_] = x 40 | 41 | self.labels_name = [int(x.split("/")[-2]) for x in self.paths] 42 | self.labels_to_id = {cl: i for i, cl in enumerate(sorted(set(self.labels_name)))} 43 | self.labels = [self.labels_to_id[x] for x in self.labels_name] 44 | 45 | self.hierarchy_name = {} 46 | for x in self.labels_name: 47 | for key, val in self.db[x].items(): 48 | try: 49 | self.hierarchy_name[key].append(val) 50 | except KeyError: 51 | self.hierarchy_name[key] = [val] 52 | 53 | self.hierarchy_name_to_id = {} 54 | self.hierarchy_labels = {} 55 | for key, lst in self.hierarchy_name.items(): 56 | self.hierarchy_name_to_id[key] = {cl: i for i, cl in enumerate(sorted(set(lst)))} 57 | self.hierarchy_labels[key] = [self.hierarchy_name_to_id[key][x] for x in lst] 58 | 59 | self.super_labels_name = [x.split("/")[-3] for x in self.paths] 60 | self.super_labels_to_id = {scl: i for i, scl in enumerate(sorted(set(self.super_labels_name)))} 61 | self.super_labels = [self.super_labels_to_id[x] for x in self.super_labels_name] 62 | 63 | self.get_instance_dict() 64 | self.get_super_dict() 65 | -------------------------------------------------------------------------------- /roadmap/engine/memory.py: -------------------------------------------------------------------------------- 1 | from collections import deque 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | 7 | def get_mask(lst): 8 | return [i == lst.index(x) for i, x in enumerate(lst)] 9 | 10 | 11 | class XBM(nn.Module): 12 | 13 | def __init__(self, size=None, weight=1.0, activate_after=-1, unique=True): 14 | super().__init__() 15 | self.size = size 16 | self.unique = unique 17 | self.weight = weight 18 | self.activate_after = activate_after 19 | 20 | if self.unique: 21 | self.features_memory = {} 22 | self.labels_memory = {} 23 | else: 24 | self.features_memory = deque() 25 | self.labels_memory = deque() 26 | 27 | def add_without_keys(self, features, labels): 28 | bs = features.size(0) 29 | while len(self.features_memory) + bs > self.size: 30 | self.features_memory.popleft() 31 | self.labels_memory.popleft() 32 | 33 | for feat, lb in zip(features, labels): 34 | self.features_memory.append(feat) 35 | self.labels_memory.append(lb) 36 | 37 | def add_with_keys(self, features, labels, keys): 38 | for k, feat, lb in zip(keys, features, labels): 39 | self.features_memory[k] = feat 40 | self.labels_memory[k] = lb 41 | 42 | def get_occupied_storage(self,): 43 | if not self.features_memory: 44 | return torch.tensor([]), torch.tensor([]) 45 | 46 | if self.unique: 47 | return torch.stack(list(self.features_memory.values())), torch.stack(list(self.labels_memory.values())) 48 | 49 | return torch.stack(list(self.features_memory)), torch.stack(list(self.labels_memory)) 50 | 51 | def forward(self, features, labels, keys=None): 52 | 53 | if self.unique: 54 | assert keys is not None 55 | self.add_with_keys(features, labels, keys) 56 | else: 57 | self.add_without_keys(features, 
labels) 58 | 59 | mem_features, mem_labels = self.get_occupied_storage() 60 | return mem_features, mem_labels 61 | 62 | def extra_repr(self,): 63 | return f"size={self.size}, unique={self.unique}" 64 | 65 | 66 | if __name__ == '__main__': 67 | mem = XBM((56, 128), unique=False) 68 | 69 | mem(torch.ones(32, 128), torch.ones(32,), torch.arange(32, 64)) 70 | mem(torch.ones(32, 128), torch.ones(32,), torch.arange(32)) 71 | 72 | # mem(torch.ones(32, 128), torch.ones(32,)) 73 | # features, labels = mem(torch.ones(32, 128), torch.ones(32,)) 74 | print(mem.index) 75 | -------------------------------------------------------------------------------- /roadmap/datasets/revisited_dataset.py: -------------------------------------------------------------------------------- 1 | from os.path import join 2 | 3 | import torch 4 | import pickle 5 | from PIL import Image 6 | 7 | from .base_dataset import BaseDataset 8 | 9 | 10 | def imresize(img, imsize): 11 | img.thumbnail((imsize, imsize), Image.ANTIALIAS) 12 | return img 13 | 14 | 15 | def path_to_label(pth): 16 | return "_".join(pth.split("/")[-1].split(".")[0].split("_")[:-1]) 17 | 18 | 19 | class RevisitedDataset(BaseDataset): 20 | 21 | def __init__(self, data_dir, mode, imsize=None, transform=None, **kwargs): 22 | super().__init__(**kwargs) 23 | assert mode in ["query", "gallery"] 24 | 25 | self.data_dir = data_dir 26 | self.mode = mode 27 | self.imsize = imsize 28 | self.transform = transform 29 | self.city = self.data_dir.split('/') 30 | self.city = self.city[-1] if self.city[-1] else self.city[-2] 31 | 32 | with open(join(self.data_dir, f"gnd_{self.city}.pkl"), "rb") as f: 33 | db = pickle.load(f) 34 | 35 | self.paths = [join(self.data_dir, "jpg", f"{x}.jpg") for x in db["qimlist" if self.mode == "query" else "imlist"]] 36 | self.labels_name = [path_to_label(x) for x in self.paths] 37 | labels_name_to_id = {lb: i for i, lb in enumerate(sorted(set(self.labels_name)))} 38 | self.labels = [labels_name_to_id[x] for x in self.labels_name] 39 | 40 | if self.mode == "query": 41 | self.bbx = [x["bbx"] for x in db["gnd"]] 42 | self.easy = [x["easy"] for x in db["gnd"]] 43 | self.hard = [x["hard"] for x in db["gnd"]] 44 | self.junk = [x["junk"] for x in db["gnd"]] 45 | 46 | self.get_instance_dict() 47 | 48 | def __getitem__(self, idx): 49 | img = Image.open(self.paths[idx]) 50 | imfullsize = max(img.size) 51 | 52 | if self.mode == 'query': 53 | img = img.crop(self.bbx[idx]) 54 | 55 | if self.imsize is not None: 56 | if self.mode == 'query': 57 | img = imresize(img, self.imsize * max(img.size) / imfullsize) 58 | else: 59 | img = imresize(img, self.imsize) 60 | 61 | if self.transform is not None: 62 | img = self.transform(img) 63 | 64 | out = {"image": img, "label": torch.tensor([self.labels[idx]])} 65 | if self.mode == 'query': 66 | out["easy"] = self.easy[idx] 67 | out["hard"] = self.hard[idx] 68 | out["junk"] = self.junk[idx] 69 | 70 | return out 71 | 72 | def __repr__(self,): 73 | return f"{self.city.title()}Dataset(mode={self.mode}, imsize={self.imsize}, len={len(self)})" 74 | -------------------------------------------------------------------------------- /roadmap/single_experiment_runner.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import hydra 4 | import torch 5 | 6 | import roadmap.run as run 7 | import roadmap.utils as lib 8 | 9 | 10 | @hydra.main(config_path='config', config_name='default') 11 | def single_experiment_runner(cfg): 12 | """ 13 | Parses hydra config, check for potential 
resuming of training 14 | and launches training 15 | """ 16 | 17 | try: 18 | try: 19 | import ray 20 | ray.cluster_resources() 21 | lib.LOGGER.info("Experiment running with ray : deactivating TQDM") 22 | os.environ['TQDM_DISABLE'] = "1" 23 | except Exception: 24 | pass 25 | except ray.exceptions.RaySystemError: 26 | pass 27 | 28 | cfg.experience.log_dir = lib.expand_path(cfg.experience.log_dir) 29 | 30 | if cfg.experience.resume is not None: 31 | if os.path.isfile(lib.expand_path(cfg.experience.resume)): 32 | resume = lib.expand_path(cfg.experience.resume) 33 | else: 34 | resume = os.path.join(cfg.experience.log_dir, cfg.experience.experiment_name, 'weights', cfg.experience.resume) 35 | 36 | if not os.path.isfile(resume): 37 | lib.LOGGER.warning("Checkpoint does not exists") 38 | return 39 | 40 | at_epoch = torch.load(resume, map_location='cpu')["epoch"] 41 | if at_epoch >= cfg.experience.max_iter: 42 | lib.LOGGER.warning(f"Exiting trial, experiment {cfg.experience.experiment_name} already finished") 43 | return 44 | 45 | elif cfg.experience.maybe_resume: 46 | state_path = os.path.join(cfg.experience.log_dir, cfg.experience.experiment_name, 'weights', 'rolling.ckpt') 47 | if os.path.isfile(state_path): 48 | resume = state_path 49 | lib.LOGGER.warning(f"Resuming experience because weights were found @ {resume}") 50 | at_epoch = torch.load(resume, map_location='cpu')["epoch"] 51 | if at_epoch >= cfg.experience.max_iter: 52 | lib.LOGGER.warning(f"Exiting trial, experiment {cfg.experience.experiment_name} already finished") 53 | return 54 | else: 55 | resume = None 56 | 57 | else: 58 | resume = None 59 | if os.path.isdir(os.path.join(cfg.experience.log_dir, cfg.experience.experiment_name, 'weights')): 60 | lib.LOGGER.warning(f"Exiting trial, experiment {cfg.experience.experiment_name} already exists") 61 | return 62 | 63 | metrics = run.run( 64 | config=cfg, 65 | base_config=None, 66 | checkpoint_dir=resume, 67 | ) 68 | 69 | if metrics is not None: 70 | return metrics[cfg.experience.eval_split][cfg.experience.principal_metric] 71 | 72 | 73 | if __name__ == '__main__': 74 | single_experiment_runner() 75 | -------------------------------------------------------------------------------- /roadmap/samplers/hierarchical_sampler.py: -------------------------------------------------------------------------------- 1 | import itertools 2 | 3 | import numpy as np 4 | from torch.utils.data.sampler import BatchSampler 5 | 6 | import roadmap.utils as lib 7 | 8 | 9 | def safe_random_choice(input_data, size): 10 | replace = len(input_data) < size 11 | return np.random.choice(input_data, size=size, replace=replace) 12 | 13 | 14 | # Inspired by 15 | # https://github.com/kunhe/Deep-Metric-Learning-Baselines/blob/master/datasets.py 16 | class HierarchicalSampler(BatchSampler): 17 | def __init__( 18 | self, 19 | dataset, 20 | batch_size, 21 | samples_per_class, 22 | batches_per_super_pair, 23 | nb_categories=2, 24 | ): 25 | """ 26 | labels: 2D array, where rows correspond to elements, and columns correspond to the hierarchical labels 27 | batch_size: because this is a BatchSampler the batch size must be specified 28 | samples_per_class: number of instances to sample for a specific class. 
set to 0 if all element in a class 29 | batches_per_super_pairs: number of batches to create for a pair of categories (or super labels) 30 | inner_label: columns index corresponding to classes 31 | outer_label: columns index corresponding to the level of hierarchy for the pairs 32 | """ 33 | self.batch_size = int(batch_size) 34 | self.batches_per_super_pair = int(batches_per_super_pair) 35 | self.samples_per_class = int(samples_per_class) 36 | self.nb_categories = int(nb_categories) 37 | 38 | # checks 39 | assert self.batch_size % self.nb_categories == 0, f"batch_size should be a multiple of {self.nb_categories}" 40 | self.sub_batch_len = self.batch_size // self.nb_categories 41 | if self.samples_per_class > 0: 42 | assert self.sub_batch_len % self.samples_per_class == 0, "batch_size not a multiple of samples_per_class" 43 | else: 44 | self.samples_per_class = None 45 | 46 | self.super_image_lists = dataset.super_dict.copy() 47 | self.super_pairs = list(itertools.combinations(set(dataset.super_labels), self.nb_categories)) 48 | self.reshuffle() 49 | 50 | def __iter__(self,): 51 | self.reshuffle() 52 | for batch in self.batches: 53 | yield batch 54 | 55 | def __len__(self,): 56 | return len(self.batches) 57 | 58 | def __repr__(self,): 59 | return ( 60 | f"{self.__class__.__name__}(\n" 61 | f" batch_size={self.batch_size},\n" 62 | f" samples_per_class={self.samples_per_class},\n" 63 | f" batches_per_super_pair={self.batches_per_super_pair},\n" 64 | f" nb_categories={self.nb_categories}\n)" 65 | ) 66 | 67 | def reshuffle(self): 68 | lib.LOGGER.info("Shuffling data") 69 | batches = [] 70 | for combinations in self.super_pairs: 71 | 72 | for b in range(self.batches_per_super_pair): 73 | 74 | batch = [] 75 | for slb in combinations: 76 | 77 | sub_batch = [] 78 | all_classes = list(self.super_image_lists[slb].keys()) 79 | np.random.shuffle(all_classes) 80 | for cl in all_classes: 81 | instances = self.super_image_lists[slb][cl] 82 | samples_per_class = self.samples_per_class if self.samples_per_class else len(instances) 83 | if len(sub_batch) + samples_per_class > self.sub_batch_len: 84 | continue 85 | sub_batch.extend(safe_random_choice(instances, size=samples_per_class)) 86 | 87 | batch.extend(sub_batch) 88 | 89 | np.random.shuffle(batch) 90 | batches.append(batch) 91 | 92 | np.random.shuffle(batches) 93 | self.batches = batches 94 | -------------------------------------------------------------------------------- /roadmap/evaluate.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import argparse 3 | 4 | import torch 5 | import numpy as np 6 | 7 | from roadmap.getter import Getter 8 | import roadmap.utils as lib 9 | import roadmap.engine as eng 10 | 11 | 12 | def load_and_evaluate( 13 | path, 14 | set, 15 | bs, 16 | nw, 17 | data_dir=None, 18 | ): 19 | lib.LOGGER.info(f"Evaluating : \033[92m{path}\033[0m") 20 | state = torch.load(lib.expand_path(path), map_location='cpu') 21 | cfg = state["config"] 22 | 23 | lib.LOGGER.info("Loading model...") 24 | cfg.model.kwargs.with_autocast = True 25 | net = Getter().get_model(cfg.model) 26 | net.load_state_dict(state["net_state"]) 27 | if torch.cuda.device_count() > 1: 28 | net = torch.nn.DataParallel(net) 29 | net.cuda() 30 | net.eval() 31 | 32 | if data_dir is not None: 33 | cfg.dataset.kwargs.data_dir = lib.expand_path(data_dir) 34 | 35 | getter = Getter() 36 | transform = getter.get_transform(cfg.transform.test) 37 | if hasattr(cfg.experience, 'split') and (cfg.experience.split is not None): 38 
| assert isinstance(cfg.experience.split, int) 39 | dts = getter.get_dataset(None, 'all', cfg.dataset) 40 | splits = eng.get_splits(dts.labels, dts.super_labels, cfg.experience.kfold, random_state=cfg.experience.split_random_state) 41 | dts = eng.make_subset(dts, splits[cfg.experience.split]['train' if set == 'train' else 'val'], transform, set) 42 | lib.LOGGER.info(dts) 43 | else: 44 | dts = getter.get_dataset(transform, set, cfg.dataset) 45 | 46 | lib.LOGGER.info("Dataset created...") 47 | 48 | metrics = eng.evaluate( 49 | net=net, 50 | test_dataset=dts, 51 | epoch=state["epoch"], 52 | batch_size=bs, 53 | num_workers=nw, 54 | exclude=['mean_average_precision'], 55 | ) 56 | 57 | lib.LOGGER.info("Evaluation completed...") 58 | for split, mtrc in metrics.items(): 59 | for k, v in mtrc.items(): 60 | if k == 'epoch': 61 | continue 62 | lib.LOGGER.info(f"{split} --> {k} : {np.around(v*100, decimals=2)}") 63 | 64 | return metrics 65 | 66 | 67 | if __name__ == '__main__': 68 | 69 | parser = argparse.ArgumentParser() 70 | parser.add_argument("--config", type=str, required=True, nargs='+', help='Path.s to checkpoint') 71 | parser.add_argument("--parse-file", default=False, action='store_true', help='allows to pass a .txt file with several models to evaluate') 72 | parser.add_argument("--set", type=str, default='test', help='Set on which to evaluate') 73 | parser.add_argument("--bs", type=int, default=128, help='Batch size for DataLoader') 74 | parser.add_argument("--nw", type=int, default=10, help='Num workers for DataLoader') 75 | parser.add_argument("--data-dir", type=str, default=None, help='Possible override of the datadir in the dataset config') 76 | parser.add_argument("--metric-dir", type=str, default=None, help='Path in which to store the metrics') 77 | args = parser.parse_args() 78 | 79 | logging.basicConfig( 80 | format='%(asctime)s - %(levelname)s - %(message)s', 81 | datefmt='%m/%d/%Y %I:%M:%S %p', 82 | level=logging.INFO, 83 | ) 84 | 85 | if args.parse_file: 86 | with open(args.config[0], 'r') as f: 87 | paths = f.read().split('\n') 88 | paths.remove("") 89 | args.config = paths 90 | 91 | for path in args.config: 92 | metrics = load_and_evaluate( 93 | path=path, 94 | set=args.set, 95 | bs=args.bs, 96 | nw=args.nw, 97 | data_dir=args.data_dir, 98 | ) 99 | print() 100 | print() 101 | 102 | if args.metric_dir is not None: 103 | with open(args.metric_dir, 'a') as f: 104 | f.write(path) 105 | f.write("\n") 106 | for split, mtrc in metrics.items(): 107 | for k, v in mtrc.items(): 108 | if k == 'epoch': 109 | continue 110 | f.write(f"{split} --> {k} : {np.around(v*100, decimals=2)}\n") 111 | f.write("\n\n") 112 | -------------------------------------------------------------------------------- /roadmap/losses/blackbox_ap.py: -------------------------------------------------------------------------------- 1 | # MIT License 2 | # 3 | # Copyright (c) 2019 Autonomous Learning Group 4 | # 5 | # Permission is hereby granted, free of charge, to any person obtaining a copy 6 | # of this software and associated documentation files (the "Software"), to deal 7 | # in the Software without restriction, including without limitation the rights 8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | # copies of the Software, and to permit persons to whom the Software is 10 | # furnished to do so, subject to the following conditions: 11 | # 12 | # The above copyright notice and this permission notice shall be included in all 13 | # copies or substantial portions of the Software. 
14 | # 15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | # SOFTWARE. 22 | import torch 23 | 24 | 25 | def rank(seq): 26 | return torch.argsort(torch.argsort(seq).flip(1)) 27 | 28 | 29 | def rank_normalised(seq): 30 | return (rank(seq) + 1).float() / seq.size()[1] 31 | 32 | 33 | class TrueRanker(torch.autograd.Function): 34 | @staticmethod 35 | @torch.cuda.amp.custom_fwd 36 | def forward(ctx, sequence, lambda_val): 37 | rank = rank_normalised(sequence) 38 | ctx.lambda_val = lambda_val 39 | ctx.save_for_backward(sequence, rank) 40 | return rank 41 | 42 | @staticmethod 43 | @torch.cuda.amp.custom_bwd 44 | def backward(ctx, grad_output): 45 | sequence, rank = ctx.saved_tensors 46 | assert grad_output.shape == rank.shape 47 | sequence_prime = sequence + ctx.lambda_val * grad_output 48 | rank_prime = rank_normalised(sequence_prime) 49 | gradient = -(rank - rank_prime) / (ctx.lambda_val + 1e-8) 50 | return gradient, None 51 | 52 | 53 | class BlackBoxAP(torch.nn.Module): 54 | """ Torch module for computing recall-based loss as in 'Blackbox differentiation of Ranking-based Metrics' """ 55 | def __init__(self, 56 | lambda_val=4., 57 | margin=0.02, 58 | return_type='1-mAP', 59 | ): 60 | """ 61 | :param lambda_val: hyperparameter of black-box backprop 62 | :param margin: margin to be enforced between positives and negatives (alpha in the paper) 63 | :param interclass_coef: coefficient for interclass loss (beta in paper) 64 | :param batch_memory: how many batches should be in memory 65 | """ 66 | super().__init__() 67 | assert return_type in ["AP", "mAP", "1-mAP", "1-AP"] 68 | self.lambda_val = lambda_val 69 | self.margin = margin 70 | self.return_type = return_type 71 | 72 | def raw_map_computation(self, scores, targets): 73 | """ 74 | :param scores: NxM predicted similarity scores 75 | :param targets: NxM ground truth relevances 76 | """ 77 | # Compute map 78 | HIGH_CONSTANT = 2.0 79 | epsilon = 1e-5 80 | deviations = torch.abs(torch.randn_like(targets, device=scores.device, dtype=scores.dtype)) * (targets - 0.5) 81 | 82 | scores = scores - self.margin * deviations 83 | ranks_of_positive = TrueRanker.apply(scores, self.lambda_val) 84 | scores_for_ranking_positives = -ranks_of_positive + HIGH_CONSTANT * targets 85 | ranks_within_positive = rank_normalised(scores_for_ranking_positives) 86 | ranks_within_positive.requires_grad = False 87 | assert torch.all(ranks_within_positive * targets < ranks_of_positive * targets + epsilon) 88 | 89 | sum_of_precisions_at_j_per_class = ((ranks_within_positive / ranks_of_positive) * targets).sum(dim=1) 90 | precisions_per_class = sum_of_precisions_at_j_per_class / (targets.sum(dim=1) + epsilon) 91 | 92 | if self.return_type == '1-mAP': 93 | return 1.0 - precisions_per_class.mean() 94 | elif self.return_type == '1-AP': 95 | return 1.0 - precisions_per_class 96 | elif self.return_type == 'mAP': 97 | return precisions_per_class.mean() 98 | elif self.return_type == 'AP': 99 | return precisions_per_class 100 | 101 | def forward(self, output, target): 102 | return self.raw_map_computation(output, target.type(output.dtype)) 103 | 104 | 
def extra_repr(self,): 105 | return f"lambda_val={self.lambda_val}, margin={self.margin}, return_type={self.return_type}" 106 | -------------------------------------------------------------------------------- /roadmap/losses/softbin_ap.py: -------------------------------------------------------------------------------- 1 | # BSD 3-Clause License 2 | # 3 | # Copyright (c) 2019, NAVER LABS 4 | # All rights reserved. 5 | # 6 | # Redistribution and use in source and binary forms, with or without 7 | # modification, are permitted provided that the following conditions are met: 8 | # 9 | # 1. Redistributions of source code must retain the above copyright notice, this 10 | # list of conditions and the following disclaimer. 11 | # 12 | # 2. Redistributions in binary form must reproduce the above copyright notice, 13 | # this list of conditions and the following disclaimer in the documentation 14 | # and/or other materials provided with the distribution. 15 | # 16 | # 3. Neither the name of the copyright holder nor the names of its 17 | # contributors may be used to endorse or promote products derived from 18 | # this software without specific prior written permission. 19 | # 20 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 21 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 22 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 23 | # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 24 | # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 25 | # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 26 | # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 27 | # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 28 | # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 30 | import numpy as np 31 | import torch 32 | import torch.nn as nn 33 | 34 | 35 | class SoftBinAP (nn.Module): 36 | """ Differentiable AP loss, through quantization. 
From the paper: 37 | 38 | Learning with Average Precision: Training Image Retrieval with a Listwise Loss 39 | Jerome Revaud, Jon Almazan, Rafael Sampaio de Rezende, Cesar de Souza 40 | https://arxiv.org/abs/1906.07589 41 | 42 | Input: (N, M) values in [min, max] 43 | label: (N, M) values in {0, 1} 44 | 45 | Returns: 1 - mAP (mean AP for each n in {1..N}) 46 | Note: typically, this is what you wanna minimize 47 | """ 48 | def __init__( 49 | self, 50 | nq=20, 51 | min=-1, 52 | max=1, 53 | return_type='1-mAP', 54 | ): 55 | super().__init__() 56 | assert isinstance(nq, int) and 2 <= nq <= 100 57 | assert return_type in ["1-mAP", "AP", "1-AP", "mAP", "debug"] 58 | self.nq = nq 59 | self.min = min 60 | self.max = max 61 | self.return_type = return_type 62 | 63 | gap = max - min 64 | assert gap > 0 65 | # Initialize quantizer as non-trainable convolution 66 | self.quantizer = q = nn.Conv1d(1, 2*nq, kernel_size=1, bias=True) 67 | q.weight = nn.Parameter(q.weight.detach(), requires_grad=False) 68 | q.bias = nn.Parameter(q.bias.detach(), requires_grad=False) 69 | a = (nq-1) / gap 70 | # First half equal to lines passing to (min+x,1) and (min+x+1/a,0) 71 | # with x = {nq-1..0}*gap/(nq-1) 72 | q.weight[:nq] = -a 73 | q.bias[:nq] = torch.from_numpy(a*min + np.arange(nq, 0, -1)) # b = 1 + a*(min+x) 74 | # First half equal to lines passing to (min+x,1) and (min+x-1/a,0) 75 | # with x = {nq-1..0}*gap/(nq-1) 76 | q.weight[nq:] = a 77 | q.bias[nq:] = torch.from_numpy(np.arange(2-nq, 2, 1) - a*min) # b = 1 - a*(min+x) 78 | # First and last one as a horizontal straight line 79 | q.weight[0] = q.weight[-1] = 0 80 | q.bias[0] = q.bias[-1] = 1 81 | 82 | def forward(self, x, label, qw=None): 83 | assert x.shape == label.shape # N x M 84 | N, M = x.shape 85 | # Quantize all predictions 86 | q = self.quantizer(x.unsqueeze(1)) 87 | q = torch.min(q[:, :self.nq], q[:, self.nq:]).clamp(min=0) # N x Q x M 88 | 89 | nbs = q.sum(dim=-1) # number of samples N x Q = c 90 | rec = (q * label.view(N, 1, M).float()).sum(dim=-1) # number of correct samples = c+ N x Q 91 | prec = rec.cumsum(dim=-1) / (1e-16 + nbs.cumsum(dim=-1)) # precision 92 | rec /= rec.sum(dim=-1).unsqueeze(1) # norm in [0,1] 93 | 94 | ap = (prec * rec).sum(dim=-1) # per-image AP 95 | 96 | if self.return_type == '1-mAP': 97 | if qw is not None: 98 | ap *= qw # query weights 99 | loss = 1 - ap.mean() 100 | return loss 101 | elif self.return_type == 'AP': 102 | assert qw is None 103 | return ap 104 | elif self.return_type == 'mAP': 105 | assert qw is None 106 | return ap.mean() 107 | elif self.return_type == '1-AP': 108 | return 1 - ap 109 | elif self.return_type == 'debug': 110 | return prec, rec 111 | 112 | def extra_repr(self,): 113 | return f"nq={self.nq}, min={self.min}, max={self.max}, return_type={self.return_type}" 114 | -------------------------------------------------------------------------------- /roadmap/engine/base_update.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | import numpy as np 5 | from tqdm import tqdm 6 | 7 | import roadmap.utils as lib 8 | 9 | 10 | def _batch_optimization( 11 | config, 12 | net, 13 | batch, 14 | criterion, 15 | optimizer, 16 | scaler, 17 | epoch, 18 | memory 19 | ): 20 | with torch.cuda.amp.autocast(enabled=(scaler is not None)): 21 | di = net(batch["image"].cuda()) 22 | labels = batch["label"].cuda() 23 | scores = torch.mm(di, di.t()) 24 | label_matrix = lib.create_label_matrix(labels) 25 | 26 | if memory: 27 | memory_embeddings, 
memory_labels = memory(di.detach(), labels, batch["path"]) 28 | if epoch >= config.memory.activate_after: 29 | memory_scores = torch.mm(di, memory_embeddings.t()) 30 | memory_label_matrix = lib.create_label_matrix(labels, memory_labels) 31 | 32 | logs = {} 33 | losses = [] 34 | for crit, weight in criterion: 35 | if hasattr(crit, 'takes_embeddings'): 36 | loss = crit(di, labels.view(-1)) 37 | if memory: 38 | if epoch >= config.memory.activate_after: 39 | mem_loss = crit(di, labels.view(-1), memory_embeddings, memory_labels.view(-1)) 40 | 41 | else: 42 | loss = crit(scores, label_matrix) 43 | if memory: 44 | if epoch >= config.memory.activate_after: 45 | mem_loss = crit(memory_scores, memory_label_matrix) 46 | 47 | loss = loss.mean() 48 | if weight == 'adaptative': 49 | losses.append(loss) 50 | else: 51 | losses.append(weight * loss) 52 | 53 | logs[crit.__class__.__name__] = loss.item() 54 | if memory: 55 | if epoch >= config.memory.activate_after: 56 | mem_loss = mem_loss.mean() 57 | if weight == 'adaptative': 58 | losses.append(config.memory.weight * mem_loss) 59 | else: 60 | losses.append(weight * config.memory.weight * mem_loss) 61 | logs[f"memory_{crit.__class__.__name__}"] = mem_loss.item() 62 | 63 | if weight == 'adaptative': 64 | grads = [] 65 | for i, lss in enumerate(losses): 66 | g = torch.autograd.grad(lss, net.fc.parameters(), retain_graph=True) 67 | grads.append(torch.norm(g[0]).item()) 68 | mean_grad = np.mean(grads) 69 | weights = [mean_grad / g for g in grads] 70 | losses = [w * lss for w, lss in zip(weights, losses)] 71 | logs.update({ 72 | f"weight_{crit.__class__.__name__}": w for (crit, _), w in zip(criterion, weights) 73 | }) 74 | logs.update({ 75 | f"grad_{crit.__class__.__name__}": w for (crit, _), w in zip(criterion, grads) 76 | }) 77 | 78 | total_loss = sum(losses) 79 | if scaler is None: 80 | total_loss.backward() 81 | else: 82 | scaler.scale(total_loss).backward() 83 | 84 | logs["total_loss"] = total_loss.item() 85 | _ = [loss.detach_() for loss in losses] 86 | total_loss.detach_() 87 | return logs 88 | 89 | 90 | def base_update( 91 | config, 92 | net, 93 | loader, 94 | criterion, 95 | optimizer, 96 | scheduler, 97 | scaler, 98 | epoch, 99 | memory=None, 100 | ): 101 | meter = lib.DictAverage() 102 | net.train() 103 | net.zero_grad() 104 | 105 | iterator = tqdm(loader, disable=os.getenv('TQDM_DISABLE')) 106 | for i, batch in enumerate(iterator): 107 | logs = _batch_optimization( 108 | config, 109 | net, 110 | batch, 111 | criterion, 112 | optimizer, 113 | scaler, 114 | epoch, 115 | memory, 116 | ) 117 | 118 | if config.experience.log_grad: 119 | grad_norm = lib.get_gradient_norm(net) 120 | logs["grad_norm"] = grad_norm.item() 121 | 122 | for key, opt in optimizer.items(): 123 | if epoch < config.experience.warm_up and key != config.experience.warm_up_key: 124 | lib.LOGGER.info(f"Warming up @epoch {epoch}") 125 | continue 126 | if scaler is None: 127 | opt.step() 128 | else: 129 | scaler.step(opt) 130 | 131 | net.zero_grad() 132 | _ = [crit.zero_grad() for crit, w in criterion] 133 | 134 | for sch in scheduler["on_step"]: 135 | sch.step() 136 | 137 | if scaler is not None: 138 | scaler.update() 139 | 140 | meter.update(logs) 141 | if not os.getenv('TQDM_DISABLE'): 142 | iterator.set_postfix(meter.avg) 143 | else: 144 | if (i + 1) % 50 == 0: 145 | lib.LOGGER.info(f'Iteration : {i}/{len(loader)}') 146 | for k, v in logs.items(): 147 | lib.LOGGER.info(f'Loss: {k}: {v} ') 148 | 149 | for crit, _ in criterion: 150 | if hasattr(crit, 'step'): 151 | crit.step() 152 | 153 
| return meter.avg 154 | -------------------------------------------------------------------------------- /roadmap/getter.py: -------------------------------------------------------------------------------- 1 | from torch import optim 2 | import torchvision.transforms as transforms 3 | 4 | from roadmap import losses 5 | from roadmap import samplers 6 | from roadmap import datasets 7 | from roadmap import models 8 | from roadmap import engine 9 | from roadmap import utils as lib 10 | 11 | 12 | class Getter: 13 | """ 14 | This class allows to create differents object (model, loss functions, optimizer...) 15 | based on the config 16 | """ 17 | 18 | def get(self, obj, *args, **kwargs): 19 | return getattr(self, f"get_{obj}")(*args, **kwargs) 20 | 21 | def get_transform(self, config): 22 | t_list = [] 23 | for k, v in config.items(): 24 | t_list.append(getattr(transforms, k)(**v)) 25 | 26 | transform = transforms.Compose(t_list) 27 | lib.LOGGER.info(transform) 28 | return transform 29 | 30 | def get_optimizer(self, net, config): 31 | optimizers = {} 32 | schedulers = { 33 | "on_epoch": [], 34 | "on_step": [], 35 | "on_val": [], 36 | } 37 | for opt in config: 38 | optimizer = getattr(optim, opt.name) 39 | if opt.params is not None: 40 | optimizer = optimizer(getattr(net, opt.params).parameters(), **opt.kwargs) 41 | optimizers[opt.params] = optimizer 42 | else: 43 | optimizer = optimizer(net.parameters(), **opt.kwargs) 44 | optimizers["net"] = optimizer 45 | lib.LOGGER.info(optimizer) 46 | if opt.scheduler_on_epoch is not None: 47 | schedulers["on_epoch"].append(self.get_scheduler(optimizer, opt.scheduler_on_epoch)) 48 | if opt.scheduler_on_step is not None: 49 | schedulers["on_step"].append(self.get_scheduler(optimizer, opt.scheduler_on_step)) 50 | if opt.scheduler_on_val is not None: 51 | schedulers["on_val"].append( 52 | (self.get_scheduler(optimizer, opt.scheduler_on_val), opt.scheduler_on_val.key) 53 | ) 54 | 55 | return optimizers, schedulers 56 | 57 | def get_scheduler(self, opt, config): 58 | sch = getattr(optim.lr_scheduler, config.name)(opt, **config.kwargs) 59 | lib.LOGGER.info(sch) 60 | return sch 61 | 62 | def get_loss(self, config): 63 | criterion = [] 64 | for crit in config: 65 | loss = getattr(losses, crit.name)(**crit.kwargs) 66 | weight = crit.weight 67 | lib.LOGGER.info(f"{loss} with weight {weight}") 68 | criterion.append((loss, weight)) 69 | return criterion 70 | 71 | def get_sampler(self, dataset, config): 72 | sampler = getattr(samplers, config.name)(dataset, **config.kwargs) 73 | lib.LOGGER.info(sampler) 74 | return sampler 75 | 76 | def get_dataset(self, transform, mode, config): 77 | if (config.name == "InShopDataset") and (mode == "test"): 78 | dataset = { 79 | "test": getattr(datasets, config.name)(transform=transform, mode="query", **config.kwargs), 80 | "gallery": getattr(datasets, config.name)(transform=transform, mode="gallery", **config.kwargs), 81 | } 82 | lib.LOGGER.info(dataset) 83 | return dataset 84 | elif (config.name == "DyMLDataset") and mode.startswith("test"): 85 | dataset = { 86 | "test": getattr(datasets, config.name)(transform=transform, mode="test_query_fine", **config.kwargs), 87 | "distractor": getattr(datasets, config.name)(transform=transform, mode="test_gallery_fine", **config.kwargs), 88 | } 89 | lib.LOGGER.info(dataset) 90 | return dataset 91 | elif (config.name == "SfM120kDataset") and (mode == "test"): 92 | dataset = [] 93 | for dts in config.evaluation: 94 | test_dts = getattr(datasets, dts.name)(transform=transform, mode="query", 
**dts.kwargs) 95 | dataset.append({ 96 | f"query_{test_dts.city}": test_dts, 97 | f"gallery_{test_dts.city}": getattr(datasets, dts.name)(transform=transform, mode="gallery", **dts.kwargs), 98 | }) 99 | lib.LOGGER.info(dataset) 100 | return dataset 101 | else: 102 | dataset = getattr(datasets, config.name)( 103 | transform=transform, 104 | mode=mode, 105 | **config.kwargs, 106 | ) 107 | lib.LOGGER.info(dataset) 108 | return dataset 109 | 110 | def get_model(self, config): 111 | net = getattr(models, config.name)(**config.kwargs) 112 | if config.freeze_batch_norm: 113 | lib.LOGGER.info("Freezing batch norm") 114 | net = lib.freeze_batch_norm(net) 115 | if config.freeze_pos_embedding: 116 | lib.LOGGER.info("Freezing pos embeddings") 117 | net.backbone = lib.freeze_pos_embedding(net.backbone) 118 | return net 119 | 120 | def get_memory(self, config): 121 | memory = getattr(engine, config.name)(**config.kwargs) 122 | lib.LOGGER.info(memory) 123 | return memory 124 | -------------------------------------------------------------------------------- /roadmap/datasets/base_dataset.py: -------------------------------------------------------------------------------- 1 | import random 2 | from collections import Counter 3 | 4 | import numpy as np 5 | import torch 6 | from torch.utils.data import Dataset 7 | import torchvision.transforms as transforms 8 | from PIL import Image 9 | from PIL import ImageFilter 10 | 11 | 12 | class BaseDataset(Dataset): 13 | 14 | def __init__( 15 | self, 16 | multi_crop=False, 17 | size_crops=[224, 96], 18 | nmb_crops=[2, 6], 19 | min_scale_crops=[0.14, 0.05], 20 | max_scale_crops=[1., 0.14], 21 | size_dataset=-1, 22 | return_label='none', 23 | ): 24 | super().__init__() 25 | 26 | if not multi_crop: 27 | self.get_fn = self.simple_get 28 | else: 29 | # adapted from 30 | # https://github.com/facebookresearch/swav/blob/master/src/multicropdataset.py 31 | self.get_fn = self.multiple_crop_get 32 | 33 | self.return_label = return_label 34 | assert self.return_label in ["none", "real", "hash"] 35 | 36 | color_transform = [get_color_distortion(), PILRandomGaussianBlur()] 37 | mean = [0.485, 0.456, 0.406] 38 | std = [0.228, 0.224, 0.225] 39 | trans = [] 40 | for i in range(len(size_crops)): 41 | randomresizedcrop = transforms.RandomResizedCrop( 42 | size_crops[i], 43 | scale=(min_scale_crops[i], max_scale_crops[i]), 44 | ) 45 | trans.extend([transforms.Compose([ 46 | randomresizedcrop, 47 | transforms.RandomHorizontalFlip(p=0.5), 48 | transforms.Compose(color_transform), 49 | transforms.ToTensor(), 50 | transforms.Normalize(mean=mean, std=std)]) 51 | ] * nmb_crops[i]) 52 | self.trans = trans 53 | 54 | def __len__(self,): 55 | return len(self.paths) 56 | 57 | @property 58 | def my_at_R(self,): 59 | if not hasattr(self, '_at_R'): 60 | self._at_R = max(Counter(self.labels).values()) 61 | return self._at_R 62 | 63 | def get_instance_dict(self,): 64 | self.instance_dict = {cl: [] for cl in set(self.labels)} 65 | for idx, cl in enumerate(self.labels): 66 | self.instance_dict[cl].append(idx) 67 | 68 | def get_super_dict(self,): 69 | if hasattr(self, 'super_labels') and self.super_labels is not None: 70 | self.super_dict = {ct: {} for ct in set(self.super_labels)} 71 | for idx, cl, ct in zip(range(len(self.labels)), self.labels, self.super_labels): 72 | try: 73 | self.super_dict[ct][cl].append(idx) 74 | except KeyError: 75 | self.super_dict[ct][cl] = [idx] 76 | 77 | def simple_get(self, idx): 78 | pth = self.paths[idx] 79 | img = Image.open(pth).convert('RGB') 80 | if self.transform: 81 | 
img = self.transform(img) 82 | 83 | label = self.labels[idx] 84 | label = torch.tensor([label]) 85 | out = {"image": img, "label": label, "path": pth} 86 | 87 | if hasattr(self, 'super_labels') and self.super_labels is not None: 88 | super_label = self.super_labels[idx] 89 | super_label = torch.tensor([super_label]) 90 | out['super_label'] = super_label 91 | 92 | return out 93 | 94 | def multiple_crop_get(self, idx): 95 | pth = self.paths[idx] 96 | image = Image.open(pth).convert('RGB') 97 | multi_crops = list(map(lambda trans: trans(image), self.trans)) 98 | 99 | if self.return_label == 'real': 100 | label = self.labels[idx] 101 | labels = [label] * len(multi_crops) 102 | return {"image": multi_crops, "label": labels, "path": pth} 103 | 104 | if self.return_label == 'hash': 105 | label = abs(hash(pth)) 106 | labels = [label] * len(multi_crops) 107 | return {"image": multi_crops, "label": labels, "path": pth} 108 | 109 | return {"image": multi_crops, "path": pth} 110 | 111 | def __getitem__(self, idx): 112 | return self.get_fn(idx) 113 | 114 | def __repr__(self,): 115 | return f"{self.__class__.__name__}(mode={self.mode}, len={len(self)})" 116 | 117 | 118 | class PILRandomGaussianBlur(object): 119 | """ 120 | Apply Gaussian Blur to the PIL image. Take the radius and probability of 121 | application as the parameter. 122 | This transform was used in SimCLR - https://arxiv.org/abs/2002.05709 123 | """ 124 | 125 | def __init__(self, p=0.5, radius_min=0.1, radius_max=2.): 126 | self.prob = p 127 | self.radius_min = radius_min 128 | self.radius_max = radius_max 129 | 130 | def __call__(self, img): 131 | do_it = np.random.rand() <= self.prob 132 | if not do_it: 133 | return img 134 | 135 | return img.filter( 136 | ImageFilter.GaussianBlur( 137 | radius=random.uniform(self.radius_min, self.radius_max) 138 | ) 139 | ) 140 | 141 | 142 | def get_color_distortion(s=1.0): 143 | # s is the strength of color distortion. 
144 | color_jitter = transforms.ColorJitter(0.8*s, 0.8*s, 0.8*s, 0.2*s) 145 | rnd_color_jitter = transforms.RandomApply([color_jitter], p=0.8) 146 | rnd_gray = transforms.RandomGrayscale(p=0.2) 147 | color_distort = transforms.Compose([rnd_color_jitter, rnd_gray]) 148 | return color_distort 149 | -------------------------------------------------------------------------------- /roadmap/engine/evaluate.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | from pytorch_metric_learning import testers 5 | import pytorch_metric_learning.utils.common_functions as c_f 6 | from tqdm import tqdm 7 | 8 | import roadmap.utils as lib 9 | from .accuracy_calculator import get_accuracy_calculator 10 | 11 | 12 | class GlobalEmbeddingSpaceTester(testers.GlobalEmbeddingSpaceTester): 13 | 14 | def label_levels_to_evaluate(self, query_labels): 15 | num_levels_available = query_labels.shape[1] 16 | if self.label_hierarchy_level == "all": 17 | return range(num_levels_available) 18 | elif isinstance(self.label_hierarchy_level, int): 19 | assert self.label_hierarchy_level < num_levels_available 20 | return [self.label_hierarchy_level] 21 | elif c_f.is_list_or_tuple(self.label_hierarchy_level): 22 | # assert max(self.label_hierarchy_level) < num_levels_available 23 | return self.label_hierarchy_level 24 | 25 | def compute_all_embeddings(self, dataloader, trunk_model, embedder_model): 26 | s, e = 0, 0 27 | with torch.no_grad(): 28 | lib.LOGGER.info("Computing embeddings") 29 | # added the option of disabling TQDM 30 | for i, data in enumerate(tqdm(dataloader, disable=os.getenv('TQDM_DISABLE'))): 31 | img, label = self.data_and_label_getter(data) 32 | label = c_f.process_label(label, "all", self.label_mapper) 33 | q = self.get_embeddings_for_eval(trunk_model, embedder_model, img) 34 | if label.dim() == 1: 35 | label = label.unsqueeze(1) 36 | if i == 0: 37 | labels = torch.zeros( 38 | len(dataloader.dataset), 39 | label.size(1), 40 | device=self.data_device, 41 | dtype=label.dtype, 42 | ) 43 | all_q = torch.zeros( 44 | len(dataloader.dataset), 45 | q.size(1), 46 | device=self.data_device, 47 | dtype=q.dtype, 48 | ) 49 | e = s + q.size(0) 50 | all_q[s:e] = q 51 | labels[s:e] = label 52 | s = e 53 | return all_q, labels 54 | 55 | 56 | def get_tester( 57 | normalize_embeddings=False, 58 | batch_size=64, 59 | num_workers=16, 60 | pca=None, 61 | exclude_ranks=None, 62 | k=2047, 63 | **kwargs, 64 | ): 65 | calculator = get_accuracy_calculator( 66 | exclude_ranks=exclude_ranks, 67 | k=k, 68 | **kwargs, 69 | ) 70 | 71 | return GlobalEmbeddingSpaceTester( 72 | normalize_embeddings=normalize_embeddings, 73 | data_and_label_getter=lambda batch: (batch["image"], batch["label"]), 74 | batch_size=batch_size, 75 | dataloader_num_workers=num_workers, 76 | accuracy_calculator=calculator, 77 | data_device=None, 78 | pca=pca, 79 | ) 80 | 81 | 82 | @lib.get_set_random_state 83 | def evaluate( 84 | net, 85 | train_dataset=None, 86 | val_dataset=None, 87 | test_dataset=None, 88 | epoch=None, 89 | tester=None, 90 | custom_eval=None, 91 | **kwargs 92 | ): 93 | at_R = 0 94 | 95 | dataset_dict = {} 96 | splits_to_eval = [] 97 | if train_dataset is not None: 98 | dataset_dict["train"] = train_dataset 99 | splits_to_eval.append(('train', ['train'])) 100 | at_R = max(at_R, train_dataset.my_at_R) 101 | 102 | if val_dataset is not None: 103 | dataset_dict["val"] = val_dataset 104 | splits_to_eval.append(('val', ['val'])) 105 | at_R = max(at_R, val_dataset.my_at_R) 106 | 107 | if 
test_dataset is not None: 108 | if isinstance(test_dataset, dict): 109 | if 'gallery' in test_dataset: 110 | dataset_dict.update(test_dataset) 111 | splits_to_eval.append(('test', ['gallery'])) 112 | at_R = max(at_R, test_dataset['test'].my_at_R, test_dataset['gallery'].my_at_R) 113 | elif 'distractor' in test_dataset: 114 | dataset_dict.update(test_dataset) 115 | splits_to_eval.append(('test', ['test', 'distractor'])) 116 | at_R = max(at_R, test_dataset['test'].my_at_R, test_dataset['distractor'].my_at_R) 117 | elif isinstance(test_dataset, list): 118 | for dts in test_dataset: 119 | dataset_dict.update(dts) 120 | names = list(dts.keys()) 121 | at_R = max(at_R, list(dts.values())[0].my_at_R, list(dts.values())[1].my_at_R) 122 | splits_to_eval.append(( 123 | names[0] if names[0].startswith("query") else names[1], 124 | [names[0] if names[0].startswith("gallery") else names[1]] 125 | )) 126 | else: 127 | dataset_dict["test"] = test_dataset 128 | splits_to_eval.append(('test', ['test'])) 129 | at_R = max(at_R, test_dataset.my_at_R) 130 | 131 | if custom_eval is not None: 132 | dataset_dict = custom_eval["dataset"] 133 | splits_to_eval = custom_eval["splits"] 134 | 135 | if tester is None: 136 | # next lines usefull when computing only the mAP@R and small recall values 137 | # if ('k' not in kwargs) and (at_R != 0): 138 | # kwargs["k"] = at_R + 1 139 | tester = get_tester(**kwargs) 140 | 141 | return tester.test( 142 | dataset_dict=dataset_dict, 143 | epoch=f"{epoch}", 144 | trunk_model=net, 145 | splits_to_eval=splits_to_eval, 146 | ) 147 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Robust And Decomposable Average Precision for Image Retrieval (NeurIPS 2021) 2 | 3 | This repository contains the source code for our [ROADMAP paper (NeurIPS 2021)](https://arxiv.org/abs/2110.01445). 4 | 5 | ![outline](https://github.com/elias-ramzi/ROADMAP/blob/main/picture/outline.png) 6 | 7 | ## Use ROADMAP 8 | 9 | ``` 10 | python3 -m venv .venv 11 | source .venv/bin/activate 12 | pip install -e . 13 | ``` 14 | 15 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/robust-and-decomposable-average-precision-for/image-retrieval-on-inaturalist)](https://paperswithcode.com/sota/image-retrieval-on-inaturalist?p=robust-and-decomposable-average-precision-for) 16 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/robust-and-decomposable-average-precision-for/image-retrieval-on-sop)](https://paperswithcode.com/sota/image-retrieval-on-sop?p=robust-and-decomposable-average-precision-for) 17 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/robust-and-decomposable-average-precision-for/image-retrieval-on-cub-200-2011)](https://paperswithcode.com/sota/image-retrieval-on-cub-200-2011?p=robust-and-decomposable-average-precision-for) 18 | 19 | ## Datasets 20 | 21 | We use the following datasets for our submission 22 | 23 | - CUB-200-2011 (download link available on this website : http://www.vision.caltech.edu/visipedia/CUB-200.html) 24 | - Stanford Online Products (you can download it here : https://cvgl.stanford.edu/projects/lifted_struct/) 25 | - INaturalist-2018 (obtained from here https://github.com/visipedia/inat_comp/tree/master/2018#Data) 26 | 27 | 28 | ## Run the code 29 | 30 |
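Every command below goes through the single Hydra entry point `roadmap/single_experiment_runner.py` shown earlier: bare `key=value` arguments (`optimizer=sop`, `loss=roadmap`, ...) select YAML files from the corresponding config groups, while dotted keys such as `dataset.sampler.kwargs.batch_size=128` override individual fields of the composed config.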
31 | ### SOP
32 | 33 | The following command reproduces our results for Table 4. 34 | 35 | ``` 36 | CUDA_VISIBLE_DEVICES=0 python roadmap/single_experiment_runner.py \ 37 | 'experience.experiment_name=sop_ROADMAP_${dataset.sampler.kwargs.batch_size}_sota' \ 38 | experience.seed=333 \ 39 | experience.max_iter=100 \ 40 | 'experience.log_dir=${env:HOME}/experiments/ROADMAP' \ 41 | optimizer=sop \ 42 | model=resnet \ 43 | transform=sop_big \ 44 | dataset=sop \ 45 | dataset.sampler.kwargs.batch_size=128 \ 46 | dataset.sampler.kwargs.batches_per_super_pair=10 \ 47 | loss=roadmap 48 | ``` 49 | 50 | With the transformer backbone: 51 | 52 | ``` 53 | CUDA_VISIBLE_DEVICES=0 python roadmap/single_experiment_runner.py \ 54 | 'experience.experiment_name=sop_ROADMAP_${dataset.sampler.kwargs.batch_size}_DeiT' \ 55 | experience.seed=333 \ 56 | experience.max_iter=75 \ 57 | 'experience.log_dir=${env:HOME}/experiments/ROADMAP' \ 58 | optimizer=sop_deit \ 59 | model=deit \ 60 | transform=sop \ 61 | dataset=sop \ 62 | dataset.sampler.kwargs.batch_size=128 \ 63 | dataset.sampler.kwargs.batches_per_super_pair=10 \ 64 | loss=roadmap 65 | ``` 66 |
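Checkpoints land in `<experience.log_dir>/<experience.experiment_name>/weights/` (`rolling.ckpt` plus optional `epoch_*.ckpt` files). A finished run can then be scored with the evaluation entry point shown earlier, for instance `python roadmap/evaluate.py --config ~/experiments/ROADMAP/sop_ROADMAP_128_sota/weights/rolling.ckpt --set test --bs 128 --nw 10`; the exact checkpoint path depends on your `log_dir`.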
67 | 68 | 69 |
70 | ### INaturalist
71 | 72 | For ROADMAP sota results: 73 | 74 | ``` 75 | CUDA_VISIBLE_DEVICES='0,1,2' python roadmap/single_experiment_runner.py \ 76 | 'experience.experiment_name=inat_ROADMAP_${dataset.sampler.kwargs.batch_size}_sota' \ 77 | experience.seed=333 \ 78 | experience.max_iter=90 \ 79 | 'experience.log_dir=${env:HOME}/experiments/ROADMAP' \ 80 | optimizer=inaturalist \ 81 | model=resnet \ 82 | transform=inaturalist \ 83 | dataset=inaturalist \ 84 | dataset.sampler.kwargs.batch_size=384 \ 85 | loss=roadmap_inat 86 | ``` 87 |
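Note the quoted `CUDA_VISIBLE_DEVICES='0,1,2'`: when more than one GPU is visible, `roadmap/run.py` (below) wraps the network in `torch.nn.DataParallel`, so the batch of 384 is split across the three devices.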
88 | 89 | 90 |
91 | ### CUB-200-2011
92 | 93 | For ROADMAP sota results: 94 | 95 | ``` 96 | CUDA_VISIBLE_DEVICES=0 python roadmap/single_experiment_runner.py \ 97 | 'experience.experiment_name=cub_ROADMAP_${dataset.sampler.kwargs.batch_size}_sota' \ 98 | experience.seed=333 \ 99 | experience.max_iter=200 \ 100 | 'experience.log_dir=${env:HOME}/experiments/ROADMAP' \ 101 | optimizer=cub \ 102 | model=resnet_max_ln \ 103 | transform=cub_big \ 104 | dataset=cub \ 105 | dataset.sampler.kwargs.batch_size=128 \ 106 | loss=roadmap 107 | ``` 108 | 109 | ``` 110 | CUDA_VISIBLE_DEVICES=0 python roadmap/single_experiment_runner.py \ 111 | 'experience.experiment_name=cub_ROADMAP_${dataset.sampler.kwargs.batch_size}_sota_DeiT' \ 112 | experience.seed=333 \ 113 | experience.max_iter=150 \ 114 | 'experience.log_dir=${env:HOME}/experiments/ROADMAP' \ 115 | optimizer=cub_deit \ 116 | model=deit \ 117 | transform=cub \ 118 | dataset=cub \ 119 | dataset.sampler.kwargs.batch_size=128 \ 120 | loss=roadmap 121 | ``` 122 | 123 |
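Each `loss=...` group above deserializes to a list of `{name, weight, kwargs}` entries that `Getter.get_loss` (earlier in this repository) instantiates by name into `(criterion, weight)` pairs. A minimal sketch of that mechanism, with plain dicts standing in for the OmegaConf nodes and `PairLoss` as an illustrative entry:

```
import roadmap.losses as losses

config = [{"name": "PairLoss", "weight": 1.0, "kwargs": {"margin": 0.5}}]  # illustrative entry

criterion = [
    (getattr(losses, crit["name"])(**crit["kwargs"]), crit["weight"])  # instantiate by name, keep the weight
    for crit in config
]
```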
124 | 125 | 126 | The results are not exactly the same as in the paper, as the code has changed a bit since (for instance, the random seeds are not the same). 127 | 128 | 129 | ## Contacts 130 | 131 | If you have any questions, don't hesitate to create an issue on this repository, or send me an email at elias.ramzi@lecnam.net. 132 | 133 | Don't hesitate to cite our work: 134 | ``` 135 | @inproceedings{ 136 | ramzi2021robust, 137 | title={Robust and Decomposable Average Precision for Image Retrieval}, 138 | author={Elias Ramzi and Nicolas THOME and Cl{\'e}ment Rambour and Nicolas Audebert and Xavier Bitot}, 139 | booktitle={Thirty-Fifth Conference on Neural Information Processing Systems}, 140 | year={2021}, 141 | url={https://openreview.net/forum?id=VjQw3v3FpJx} 142 | } 143 | ``` 144 | 145 | 146 | ## Resources 147 | - Pytorch Metric Learning (PML): https://github.com/KevinMusgrave/pytorch-metric-learning 148 | - SmoothAP: https://github.com/Andrew-Brown1/Smooth_AP 149 | - Blackbox: https://github.com/martius-lab/blackbox-backprop 150 | - FastAP: https://github.com/kunhe/FastAP-metric-learning 151 | - SoftBinAP: https://github.com/naver/deep-image-retrieval 152 | - timm: https://github.com/rwightman/pytorch-image-models 153 | - PyTorch: https://github.com/pytorch/pytorch 154 | - Hydra: https://github.com/facebookresearch/hydra 155 | - Faiss: https://github.com/facebookresearch/faiss 156 | - Ray: https://github.com/ray-project/ray 157 | -------------------------------------------------------------------------------- /roadmap/run.py: -------------------------------------------------------------------------------- 1 | import os 2 | from os.path import join 3 | import logging 4 | import random 5 | import numpy as np 6 | 7 | import torch 8 | import torch.nn as nn 9 | from torch.utils.tensorboard import SummaryWriter 10 | 11 | import roadmap.utils as lib 12 | import roadmap.engine as eng 13 | from roadmap.getter import Getter 14 | 15 | 16 | def run(config, base_config=None, checkpoint_dir=None, splits=None): 17 | """ 18 | creates all objects required to launch a training 19 | """ 20 | # """""""""""""""""" Handle Config """""""""""""""""""""""""" 21 | if base_config is not None: 22 | log_dir = None 23 | config = lib.override_config( 24 | hyperparameters=config, 25 | config=base_config, 26 | ) 27 | 28 | else: 29 | log_dir = lib.expand_path(config.experience.log_dir) 30 | log_dir = join(log_dir, config.experience.experiment_name) 31 | os.makedirs(join(log_dir, 'logs'), exist_ok=True) 32 | os.makedirs(join(log_dir, 'weights'), exist_ok=True) 33 | 34 | # """""""""""""""""" Handle Logging """""""""""""""""""""""""" 35 | logging.basicConfig( 36 | format='%(asctime)s - %(levelname)s - %(message)s', 37 | datefmt='%m/%d/%Y %I:%M:%S', 38 | level=logging.INFO, 39 | ) 40 | 41 | if checkpoint_dir is None: 42 | state = None 43 | restore_epoch = 0 44 | else: 45 | lib.LOGGER.info(f"Resuming from state : {checkpoint_dir}") 46 | state = torch.load(checkpoint_dir, map_location='cpu') 47 | restore_epoch = state['epoch'] 48 | 49 | if log_dir is None: 50 | from ray import tune 51 | writer = SummaryWriter(join(tune.get_trial_dir(), "logs"), purge_step=restore_epoch) 52 | else: 53 | writer = SummaryWriter(join(log_dir, "logs"), purge_step=restore_epoch) 54 | 55 | lib.LOGGER.info(f"Training with seed {config.experience.seed}") 56 | random.seed(config.experience.seed) 57 | np.random.seed(config.experience.seed) 58 | torch.manual_seed(config.experience.seed) 59 | torch.cuda.manual_seed_all(config.experience.seed) 60 | torch.backends.cudnn.deterministic =
True 61 | torch.backends.cudnn.benchmark = False 62 | 63 | getter = Getter() 64 | 65 | # """""""""""""""""" Create Data """""""""""""""""""""""""" 66 | train_transform = getter.get_transform(config.transform.train) 67 | test_transform = getter.get_transform(config.transform.test) 68 | if config.experience.split is not None: 69 | assert isinstance(config.experience.split, int) 70 | dts = getter.get_dataset(None, 'all', config.dataset) 71 | splits = eng.get_splits( 72 | dts.labels, dts.super_labels, 73 | config.experience.kfold, 74 | random_state=config.experience.split_random_state, 75 | with_super_labels=config.experience.with_super_labels) 76 | train_dts = eng.make_subset(dts, splits[config.experience.split]['train'], train_transform, 'train') 77 | test_dts = eng.make_subset(dts, splits[config.experience.split]['val'], test_transform, 'test') 78 | val_dts = None 79 | lib.LOGGER.info(train_dts) 80 | lib.LOGGER.info(test_dts) 81 | else: 82 | train_dts = getter.get_dataset(train_transform, 'train', config.dataset) 83 | test_dts = getter.get_dataset(test_transform, 'test', config.dataset) 84 | val_dts = None 85 | 86 | sampler = getter.get_sampler(train_dts, config.dataset.sampler) 87 | 88 | # """""""""""""""""" Create Network """""""""""""""""""""""""" 89 | net = getter.get_model(config.model) 90 | 91 | scaler = None 92 | if config.model.kwargs.with_autocast: 93 | scaler = torch.cuda.amp.GradScaler() 94 | if checkpoint_dir: 95 | scaler.load_state_dict(state['scaler_state']) 96 | 97 | if checkpoint_dir: 98 | net.load_state_dict(state['net_state']) 99 | net.cuda() 100 | 101 | # """""""""""""""""" Create Optimizer & Scheduler """""""""""""""""""""""""" 102 | optimizer, scheduler = getter.get_optimizer(net, config.optimizer) 103 | 104 | if checkpoint_dir: 105 | for key, opt in optimizer.items(): 106 | opt.load_state_dict(state['optimizer_state'][key]) 107 | 108 | if config.experience.force_lr is not None: 109 | _ = [lib.set_lr(opt, config.experience.force_lr) for opt in optimizer.values()] 110 | lib.LOGGER.info(optimizer) 111 | 112 | if checkpoint_dir: 113 | for key, schs in scheduler.items(): 114 | for sch, sch_state in zip(schs, state[f'scheduler_{key}_state']): 115 | sch.load_state_dict(sch_state) 116 | 117 | # """""""""""""""""" Create Criterion """""""""""""""""""""""""" 118 | criterion = getter.get_loss(config.loss) 119 | 120 | # """""""""""""""""" Create Memory """""""""""""""""""""""""" 121 | memory = None 122 | if config.memory.name is not None: 123 | lib.LOGGER.info("Using cross batch memory") 124 | memory = getter.get_memory(config.memory) 125 | memory.cuda() 126 | 127 | # """""""""""""""""" Handle Cuda """""""""""""""""""""""""" 128 | if torch.cuda.device_count() > 1: 129 | lib.LOGGER.info("Model is parallelized") 130 | net = nn.DataParallel(net) 131 | 132 | net.cuda() 133 | _ = [crit.cuda() for crit, _ in criterion] 134 | 135 | # """""""""""""""""" Handle RANDOM_STATE """""""""""""""""""""""""" 136 | if state is not None: 137 | # set random NumPy and Torch random states 138 | lib.set_random_state(state) 139 | 140 | return eng.train( 141 | config=config, 142 | log_dir=log_dir, 143 | net=net, 144 | criterion=criterion, 145 | optimizer=optimizer, 146 | scheduler=scheduler, 147 | scaler=scaler, 148 | memory=memory, 149 | train_dts=train_dts, 150 | val_dts=val_dts, 151 | test_dts=test_dts, 152 | sampler=sampler, 153 | writer=writer, 154 | restore_epoch=restore_epoch, 155 | ) 156 | 157 | 158 | if __name__ == '__main__': 159 | run() 160 | 
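The resume logic above mirrors the state dictionary written by `roadmap/engine/chepoint.py`: network weights, one optimizer state per parameter group, scheduler and scaler states, and the last completed epoch. A self-contained sketch of that save/load contract, using a toy model and a subset of the keys:

```
import torch
from torch import nn, optim

net = nn.Linear(8, 4)  # toy stand-in for RetrievalNet
optimizer = {"net": optim.Adam(net.parameters(), lr=1e-4)}

# Save with the same keys as roadmap/engine/chepoint.py (subset).
state = {
    "net_state": net.state_dict(),
    "optimizer_state": {key: opt.state_dict() for key, opt in optimizer.items()},
    "epoch": 10,
}
torch.save(state, "rolling.ckpt")

# Load the way run() does when checkpoint_dir is given.
state = torch.load("rolling.ckpt", map_location="cpu")
net.load_state_dict(state["net_state"])
for key, opt in optimizer.items():
    opt.load_state_dict(state["optimizer_state"][key])
restore_epoch = state["epoch"]  # train() then restarts at epoch restore_epoch + 1
```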
-------------------------------------------------------------------------------- /roadmap/engine/train.py: -------------------------------------------------------------------------------- 1 | import random 2 | from time import time 3 | 4 | import numpy as np 5 | import torch 6 | from torch.utils.data import DataLoader 7 | 8 | import roadmap.utils as lib 9 | from .base_update import base_update 10 | from .evaluate import evaluate 11 | from .landmark_evaluation import landmark_evaluation 12 | from . import checkpoint 13 | 14 | 15 | def train( 16 | config, 17 | log_dir, 18 | net, 19 | criterion, 20 | optimizer, 21 | scheduler, 22 | scaler, 23 | memory, 24 | train_dts, 25 | val_dts, 26 | test_dts, 27 | sampler, 28 | writer, 29 | restore_epoch, 30 | ): 31 | # """""""""""""""""" Iter over epochs """""""""""""""""""""""""" 32 | lib.LOGGER.info(f"Training of model {config.experience.experiment_name}") 33 | best_score = 0. 34 | best_model = None 35 | 36 | metrics = None 37 | for e in range(1 + restore_epoch, config.experience.max_iter + 1): 38 | 39 | lib.LOGGER.info(f"Training : @epoch #{e} for model {config.experience.experiment_name}") 40 | start_time = time() 41 | 42 | # """""""""""""""""" Training Loop """""""""""""""""""""""""" 43 | sampler.reshuffle() 44 | loader = DataLoader( 45 | train_dts, 46 | batch_sampler=sampler, 47 | num_workers=config.experience.num_workers, 48 | pin_memory=config.experience.pin_memory, 49 | ) 50 | logs = base_update( 51 | config=config, 52 | net=net, 53 | loader=loader, 54 | criterion=criterion, 55 | optimizer=optimizer, 56 | scheduler=scheduler, 57 | scaler=scaler, 58 | epoch=e, 59 | memory=memory, 60 | ) 61 | 62 | for sch in scheduler["on_epoch"]: 63 | sch.step() 64 | 65 | end_train_time = time() 66 | 67 | dataset_dict = {} 68 | if (config.experience.train_eval_freq > -1) and ((e % config.experience.train_eval_freq == 0) or (e == config.experience.max_iter)): 69 | dataset_dict["train_dataset"] = train_dts 70 | 71 | if (config.experience.val_eval_freq > -1) and ((e % config.experience.val_eval_freq == 0) or (e == config.experience.max_iter)): 72 | dataset_dict["val_dataset"] = val_dts 73 | 74 | if (config.experience.test_eval_freq > -1) and ((e % config.experience.test_eval_freq == 0) or (e == config.experience.max_iter)): 75 | dataset_dict["test_dataset"] = test_dts 76 | 77 | metrics = None 78 | if dataset_dict: 79 | RANDOM_STATE = random.getstate() 80 | NP_STATE = np.random.get_state() 81 | TORCH_STATE = torch.random.get_rng_state() 82 | TORCH_CUDA_STATE = torch.cuda.get_rng_state_all() 83 | 84 | lib.LOGGER.info(f"Evaluation : @epoch #{e} for model {config.experience.experiment_name}") 85 | torch.cuda.empty_cache() 86 | if config.experience.landmarks: 87 | metrics = landmark_evaluation( 88 | net=net, 89 | datasets=test_dts, 90 | batch_size=config.experience.eval_bs, 91 | num_workers=config.experience.num_workers, 92 | ) 93 | else: 94 | metrics = evaluate( 95 | net, 96 | epoch=e, 97 | batch_size=config.experience.eval_bs, 98 | num_workers=config.experience.num_workers, 99 | with_AP=config.experience.with_AP, 100 | **dataset_dict, 101 | ) 102 | torch.cuda.empty_cache() 103 | 104 | random.setstate(RANDOM_STATE) 105 | np.random.set_state(NP_STATE) 106 | torch.random.set_rng_state(TORCH_STATE) 107 | torch.cuda.set_rng_state_all(TORCH_CUDA_STATE) 108 | 109 | # """""""""""""""""" Evaluate Model """""""""""""""""""""""""" 110 | score = None 111 | if metrics is not None: 112 | score = metrics[config.experience.eval_split][config.experience.principal_metric] 113 | if score 
> best_score: 114 | best_model = f"epoch_{e}" 115 | best_score = score 116 | 117 | if log_dir is None: 118 | from ray import tune 119 | tune.report(**metrics[config.experience.eval_split]) 120 | 121 | for sch, key in scheduler["on_val"]: 122 | sch.step(metrics[config.experience.eval_split][key]) 123 | 124 | # """""""""""""""""" Logging Step """""""""""""""""""""""""" 125 | for grp, opt in optimizer.items(): 126 | writer.add_scalar(f"LR/{grp}", list(lib.get_lr(opt).values())[0], e) 127 | 128 | for k, v in logs.items(): 129 | lib.LOGGER.info(f"{k} : {v:.4f}") 130 | writer.add_scalar(f"Train/{k}", v, e) 131 | 132 | if metrics is not None: 133 | for split, mtrc in metrics.items(): 134 | for k, v in mtrc.items(): 135 | if k == 'epoch': 136 | continue 137 | lib.LOGGER.info(f"{split} --> {k} : {np.around(v*100, decimals=2)}") 138 | writer.add_scalar(f"{split.title()}/Evaluation/{k}", v, e) 139 | print() 140 | 141 | end_time = time() 142 | 143 | elapsed_time = lib.format_time(end_time - start_time) 144 | elapsed_time_train = lib.format_time(end_train_time - start_time) 145 | elapsed_time_eval = lib.format_time(end_time - end_train_time) 146 | 147 | lib.LOGGER.info(f"Epoch took : {elapsed_time}") 148 | lib.LOGGER.info(f"Training loop took : {elapsed_time_train}") 149 | if metrics is not None: 150 | lib.LOGGER.info(f"Evaluation step took : {elapsed_time_eval}") 151 | 152 | print() 153 | print() 154 | 155 | # """""""""""""""""" Checkpointing """""""""""""""""""""""""" 156 | checkpoint( 157 | log_dir=log_dir, 158 | save_checkpoint=(e % config.experience.save_model == 0), 159 | net=net, 160 | optimizer=optimizer, 161 | scheduler=scheduler, 162 | scaler=scaler, 163 | epoch=e, 164 | seed=config.experience.seed, 165 | args=config, 166 | score=score, 167 | best_model=best_model, 168 | best_score=best_score, 169 | ) 170 | 171 | return metrics 172 | -------------------------------------------------------------------------------- /roadmap/engine/cross_validation_splits.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | import numpy as np 4 | from sklearn.model_selection import StratifiedKFold 5 | 6 | import roadmap.utils as lib 7 | 8 | 9 | @lib.get_set_random_state 10 | def get_class_disjoint_splits(labels, kfold, random_state=None): 11 | if random_state is not None: 12 | random.seed(random_state) 13 | 14 | unique_labels = list(set(labels)) 15 | n = len(unique_labels) 16 | random.shuffle(unique_labels) 17 | 18 | classes = [] 19 | for i in range(kfold): 20 | num_to_sample = n // kfold + (0 if i > n % kfold else 1) 21 | classes.append(set(unique_labels[:num_to_sample])) 22 | del unique_labels[:num_to_sample] 23 | 24 | indexes = [[] for _ in range(kfold)] 25 | for idx, cl in enumerate(labels): 26 | for i in range(kfold): 27 | if cl in classes[i]: 28 | indexes[i].append(idx) 29 | break 30 | 31 | splits = [{'train': [], 'val': []} for _ in range(kfold)] 32 | for i in range(kfold): 33 | tmp = indexes.copy() 34 | splits[i]['val'] = tmp[i] 35 | del tmp[i] 36 | splits[i]['train'] = [idx for sublist in tmp for idx in sublist] 37 | 38 | return splits 39 | 40 | 41 | @lib.get_set_random_state 42 | def get_hierarchical_class_disjoint_splits(labels, super_labels, kfold, random_state=None): 43 | if random_state is not None: 44 | random.seed(random_state) 45 | 46 | unique_super_labels = sorted(set(super_labels)) 47 | super_dict = {slb: set() for slb in unique_super_labels} 48 | for slb, lb in zip(super_labels, labels): 49 | super_dict[slb].add(lb) 50 | 51 | 
super_dict = {slb: sorted(lb) for slb, lb in super_dict.items()} 52 | for slb in unique_super_labels: 53 | _ = random.shuffle(super_dict[slb]) 54 | classes = [[] for _ in range(kfold)] 55 | for slb in unique_super_labels: 56 | unique_labels = super_dict[slb].copy() 57 | n = len(unique_labels) 58 | for i in range(kfold): 59 | num_to_sample = n // kfold + (0 if i > n % kfold else 1) 60 | classes[i].extend(unique_labels[:num_to_sample]) 61 | del unique_labels[:num_to_sample] 62 | 63 | indexes = [[] for _ in range(kfold)] 64 | for idx, cl in enumerate(labels): 65 | for i in range(kfold): 66 | if cl in classes[i]: 67 | indexes[i].append(idx) 68 | break 69 | 70 | splits = [{'train': [], 'val': []} for _ in range(kfold)] 71 | for i in range(kfold): 72 | tmp = indexes.copy() 73 | splits[i]['val'] = tmp[i] 74 | del tmp[i] 75 | splits[i]['train'] = [idx for sublist in tmp for idx in sublist] 76 | 77 | return splits 78 | 79 | 80 | @lib.get_set_random_state 81 | def get_closed_set_splits(labels, kfold, random_state=None): 82 | split_generator = StratifiedKFold(n_splits=kfold, shuffle=True, random_state=random_state) 83 | 84 | splits = [{'train': [], 'val': []} for _ in range(kfold)] 85 | for i, (train_index, test_index) in enumerate(split_generator.split(np.zeros(len(labels)), labels)): 86 | splits[i]['train'] = train_index 87 | splits[i]['val'] = test_index 88 | 89 | return splits 90 | 91 | 92 | def get_splits(labels, super_labels, kfold, random_state=None, with_super_labels=False, open_set=True): 93 | if open_set: 94 | if super_labels is None: 95 | return get_class_disjoint_splits(labels, kfold, random_state) 96 | elif with_super_labels: 97 | return get_class_disjoint_splits(super_labels, kfold, random_state) 98 | else: 99 | return get_hierarchical_class_disjoint_splits(labels, super_labels, kfold, random_state) 100 | else: 101 | return get_closed_set_splits(labels, kfold, random_state) 102 | 103 | 104 | if __name__ == '__main__': 105 | random.seed(10) 106 | num_superlabels = 12 107 | dataset_size = 15023 108 | num_labels_per_super_labels = 17 109 | num_labels = num_superlabels * num_labels_per_super_labels 110 | kfold = 4 111 | labels = [random.randint(0, num_labels - 1) for _ in range(dataset_size)] 112 | splits = get_class_disjoint_splits(labels, kfold) 113 | 114 | for spl in splits: 115 | label_train = set([labels[x] for x in spl['train']]) 116 | label_val = set([labels[x] for x in spl['val']]) 117 | assert (len(set(spl['train'])) + len(set(spl['val']))) == dataset_size 118 | assert not label_train.intersection(label_val) 119 | 120 | num_superlabels = 12 121 | dataset_size = 15023 122 | num_labels_per_super_labels = 17 123 | num_labels = num_superlabels * num_labels_per_super_labels 124 | labels = [random.randint(0, num_labels - 1) for _ in range(dataset_size)] 125 | super_labels = [0] * len(labels) 126 | for slb in range(num_superlabels): 127 | mask = [] 128 | for idx, lb in enumerate(labels): 129 | if (lb >= slb * num_labels_per_super_labels) & (lb < (slb + 1) * num_labels_per_super_labels): 130 | mask.append(idx) 131 | 132 | for idx in mask: 133 | super_labels[idx] = slb 134 | 135 | assert len(set(zip(super_labels, labels))) == num_labels 136 | 137 | h_splits_1 = get_hierarchical_class_disjoint_splits(labels, super_labels, kfold, random_state=1) 138 | h_splits_1_p = get_hierarchical_class_disjoint_splits(labels, super_labels, kfold, random_state=1) 139 | h_splits_2 = get_hierarchical_class_disjoint_splits(labels, super_labels, kfold, random_state=2) 140 | # import ipdb; ipdb.set_trace() 141 | 
for spl in h_splits_1: 142 | label_train = set([labels[x] for x in spl['train']]) 143 | label_val = set([labels[x] for x in spl['val']]) 144 | assert (len(set(spl['train'])) + len(set(spl['val']))) == dataset_size 145 | assert not label_train.intersection(label_val) 146 | 147 | for spl in h_splits_2: 148 | label_train = set([labels[x] for x in spl['train']]) 149 | label_val = set([labels[x] for x in spl['val']]) 150 | assert (len(set(spl['train'])) + len(set(spl['val']))) == dataset_size 151 | assert not label_train.intersection(label_val) 152 | 153 | for spl_1, spl_1_p, spl_2 in zip(h_splits_1, h_splits_1_p, h_splits_2): 154 | assert set(spl_1['train']) == set(spl_1_p['train']) 155 | assert set(spl_1['val']) == set(spl_1_p['val']) 156 | 157 | assert set(spl_1['train']) != set(spl_2['train']) 158 | assert set(spl_1['val']) != set(spl_2['val']) 159 | -------------------------------------------------------------------------------- /roadmap/engine/accuracy_calculator.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import pytorch_metric_learning.utils.common_functions as c_f 3 | from pytorch_metric_learning.utils.accuracy_calculator import ( 4 | AccuracyCalculator, 5 | get_label_match_counts, 6 | get_lone_query_labels, 7 | ) 8 | 9 | import roadmap.utils as lib 10 | from .get_knn import get_knn 11 | 12 | 13 | EQUALITY = torch.eq 14 | 15 | 16 | class CustomCalculator(AccuracyCalculator): 17 | 18 | def __init__( 19 | self, 20 | *args, 21 | with_faiss=True, 22 | **kwargs 23 | ): 24 | super().__init__(*args, **kwargs) 25 | self.with_faiss = with_faiss 26 | 27 | def recall_at_k(self, knn_labels, query_labels, k): 28 | recall = self.label_comparison_fn(query_labels, knn_labels[:, :k]) 29 | return recall.any(1).float().mean().item() 30 | 31 | def calculate_recall_at_1(self, knn_labels, query_labels, **kwargs): 32 | return self.recall_at_k( 33 | knn_labels, 34 | query_labels[:, None], 35 | 1, 36 | ) 37 | 38 | def calculate_recall_at_2(self, knn_labels, query_labels, **kwargs): 39 | return self.recall_at_k( 40 | knn_labels, 41 | query_labels[:, None], 42 | 2, 43 | ) 44 | 45 | def calculate_recall_at_4(self, knn_labels, query_labels, **kwargs): 46 | return self.recall_at_k( 47 | knn_labels, 48 | query_labels[:, None], 49 | 4, 50 | ) 51 | 52 | def calculate_recall_at_8(self, knn_labels, query_labels, **kwargs): 53 | return self.recall_at_k( 54 | knn_labels, 55 | query_labels[:, None], 56 | 8, 57 | ) 58 | 59 | def calculate_recall_at_10(self, knn_labels, query_labels, **kwargs): 60 | return self.recall_at_k( 61 | knn_labels, 62 | query_labels[:, None], 63 | 10, 64 | ) 65 | 66 | def calculate_recall_at_16(self, knn_labels, query_labels, **kwargs): 67 | return self.recall_at_k( 68 | knn_labels, 69 | query_labels[:, None], 70 | 16, 71 | ) 72 | 73 | def calculate_recall_at_20(self, knn_labels, query_labels, **kwargs): 74 | return self.recall_at_k( 75 | knn_labels, 76 | query_labels[:, None], 77 | 20, 78 | ) 79 | 80 | def calculate_recall_at_30(self, knn_labels, query_labels, **kwargs): 81 | return self.recall_at_k( 82 | knn_labels, 83 | query_labels[:, None], 84 | 30, 85 | ) 86 | 87 | def calculate_recall_at_32(self, knn_labels, query_labels, **kwargs): 88 | return self.recall_at_k( 89 | knn_labels, 90 | query_labels[:, None], 91 | 32, 92 | ) 93 | 94 | def calculate_recall_at_100(self, knn_labels, query_labels, **kwargs): 95 | return self.recall_at_k( 96 | knn_labels, 97 | query_labels[:, None], 98 | 100, 99 | ) 100 | 101 | def 
calculate_recall_at_1000(self, knn_labels, query_labels, **kwargs): 102 | return self.recall_at_k( 103 | knn_labels, 104 | query_labels[:, None], 105 | 1000, 106 | ) 107 | 108 | def requires_knn(self): 109 | return super().requires_knn() + ["recall_classic"] 110 | 111 | def get_accuracy( 112 | self, 113 | query, 114 | reference, 115 | query_labels, 116 | reference_labels, 117 | embeddings_come_from_same_source, 118 | include=(), 119 | exclude=(), 120 | return_indices=False, 121 | ): 122 | [query, reference, query_labels, reference_labels] = [ 123 | c_f.numpy_to_torch(x) 124 | for x in [query, reference, query_labels, reference_labels] 125 | ] 126 | 127 | self.curr_function_dict = self.get_function_dict(include, exclude) 128 | 129 | kwargs = { 130 | "query": query, 131 | "reference": reference, 132 | "query_labels": query_labels, 133 | "reference_labels": reference_labels, 134 | "embeddings_come_from_same_source": embeddings_come_from_same_source, 135 | "label_comparison_fn": self.label_comparison_fn, 136 | } 137 | 138 | if any(x in self.requires_knn() for x in self.get_curr_metrics()): 139 | label_counts = get_label_match_counts( 140 | query_labels, reference_labels, self.label_comparison_fn, 141 | ) 142 | 143 | lone_query_labels, not_lone_query_mask = get_lone_query_labels( 144 | query_labels, 145 | label_counts, 146 | embeddings_come_from_same_source, 147 | self.label_comparison_fn, 148 | ) 149 | 150 | num_k = self.determine_k( 151 | label_counts[1], len(reference), embeddings_come_from_same_source 152 | ) 153 | 154 | # USE OUR OWN KNN SEARCH 155 | knn_indices, knn_distances = get_knn( 156 | reference, query, num_k, embeddings_come_from_same_source, 157 | with_faiss=self.with_faiss, 158 | ) 159 | torch.cuda.empty_cache() 160 | 161 | knn_labels = reference_labels[knn_indices] 162 | if not any(not_lone_query_mask): 163 | lib.LOGGER.warning("None of the query labels are in the reference set.") 164 | kwargs["label_counts"] = label_counts 165 | kwargs["knn_labels"] = knn_labels 166 | kwargs["knn_distances"] = knn_distances 167 | kwargs["lone_query_labels"] = lone_query_labels 168 | kwargs["not_lone_query_mask"] = not_lone_query_mask 169 | 170 | if any(x in self.requires_clustering() for x in self.get_curr_metrics()): 171 | kwargs["cluster_labels"] = self.get_cluster_labels(**kwargs) 172 | 173 | if return_indices: 174 | # ADDED 175 | return knn_indices, self._get_accuracy(self.curr_function_dict, **kwargs) 176 | return self._get_accuracy(self.curr_function_dict, **kwargs) 177 | 178 | 179 | def get_accuracy_calculator( 180 | exclude_ranks=None, 181 | k=2047, 182 | with_AP=False, 183 | **kwargs, 184 | ): 185 | exclude = kwargs.pop('exclude', []) 186 | if with_AP: 187 | exclude.extend(['NMI', 'AMI']) 188 | else: 189 | exclude.extend(['NMI', 'AMI', 'mean_average_precision', 'mean_average_precision_at_r']) 190 | 191 | if exclude_ranks: 192 | for r in exclude_ranks: 193 | exclude.append(f'recall_at_{r}') 194 | 195 | return CustomCalculator( 196 | exclude=exclude, 197 | k=k, 198 | **kwargs, 199 | ) 200 | -------------------------------------------------------------------------------- /roadmap/engine/landmark_evaluation.py: -------------------------------------------------------------------------------- 1 | """ 2 | Adapted from: 3 | https://github.com/filipradenovic/cnnimageretrieval-pytorch/blob/master/cirtorch/utils/evaluate.py 4 | """ 5 | import os 6 | 7 | import numpy as np 8 | import torch 9 | from torch.utils.data import DataLoader 10 | from tqdm import tqdm 11 | 12 | import roadmap.utils as lib 
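# Annotation (not in the original source): a small worked example of the
# trapezoidal AP computed by compute_ap below. For positives at zero-based
# ranks [0, 2] with nres = 2 (so recall_step = 0.5):
#   j=0, rank=0: precision_0 = 1.0, precision_1 = 1/1 -> ap += (1.0 + 1.0) * 0.5 / 2 = 0.5
#   j=1, rank=2: precision_0 = 1/2, precision_1 = 2/3 -> ap += (0.5 + 0.6667) * 0.5 / 2 ~ 0.2917
# giving ap ~ 0.79, slightly below the standard AP of (1/1 + 2/3) / 2 ~ 0.83.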
13 | 14 | 15 | def compute_ap(ranks, nres): 16 | """ 17 | Computes average precision for given ranked indexes. 18 | 19 | Arguments 20 | --------- 21 | ranks : zero-based ranks of positive images 22 | nres : number of positive images 23 | 24 | Returns 25 | ------- 26 | ap : average precision 27 | """ 28 | 29 | # number of images ranked by the system 30 | nimgranks = len(ranks) 31 | 32 | # accumulate trapezoids in PR-plot 33 | ap = 0 34 | 35 | recall_step = 1. / nres 36 | 37 | for j in np.arange(nimgranks): 38 | rank = ranks[j] 39 | 40 | if rank == 0: 41 | precision_0 = 1. 42 | else: 43 | precision_0 = float(j) / rank 44 | 45 | precision_1 = float(j + 1) / (rank + 1) 46 | 47 | ap += (precision_0 + precision_1) * recall_step / 2. 48 | 49 | return ap 50 | 51 | 52 | def compute_map(ranks, gnd, kappas=[]): 53 | """ 54 | Computes the mAP for a given set of returned results. 55 | Usage: 56 | map = compute_map(ranks, gnd) 57 | computes mean average precision (map) only 58 | 59 | map, aps, pr, prs = compute_map(ranks, gnd, kappas) 60 | computes mean average precision (map), average precision (aps) for each query 61 | computes mean precision at kappas (pr), precision at kappas (prs) for each query 62 | 63 | Notes: 64 | 1) ranks starts from 0, ranks.shape = db_size X #queries 65 | 2) The junk results (e.g., the query itself) should be declared in the gnd struct array 66 | 3) If there are no positive images for some query, that query is excluded from the evaluation 67 | """ 68 | 69 | map = 0. 70 | nq = len(gnd) # number of queries 71 | aps = np.zeros(nq) 72 | # pr = np.zeros(len(kappas)) 73 | # prs = np.zeros((nq, len(kappas))) 74 | nempty = 0 75 | 76 | for i in np.arange(nq): 77 | qgnd = np.array(gnd[i]['ok']) 78 | 79 | # no positive images, skip from the average 80 | if qgnd.shape[0] == 0: 81 | aps[i] = float('nan') 82 | # prs[i, :] = float('nan') 83 | nempty += 1 84 | continue 85 | 86 | try: 87 | qgndj = np.array(gnd[i]['junk']) 88 | except Exception: 89 | qgndj = np.empty(0) 90 | 91 | # sorted positions of positive and junk images (0 based) 92 | pos = np.arange(ranks.shape[0])[np.in1d(ranks[:, i], qgnd)] 93 | junk = np.arange(ranks.shape[0])[np.in1d(ranks[:, i], qgndj)] 94 | 95 | k = 0 96 | ij = 0 97 | if len(junk): 98 | # decrease positions of positives based on the number of 99 | # junk images appearing before them 100 | ip = 0 101 | while (ip < len(pos)): 102 | while (ij < len(junk) and pos[ip] > junk[ij]): 103 | k += 1 104 | ij += 1 105 | pos[ip] = pos[ip] - k 106 | ip += 1 107 | 108 | # compute ap 109 | ap = compute_ap(pos, len(qgnd)) 110 | map = map + ap 111 | aps[i] = ap 112 | 113 | # # compute precision @ k 114 | # pos += 1 # get it to 1-based 115 | # for j in np.arange(len(kappas)): 116 | # kq = min(max(pos), kappas[j]) 117 | # prs[i, j] = (pos <= kq).sum() / kq 118 | # pr = pr + prs[i, :] 119 | 120 | map = map / (nq - nempty) 121 | # pr = pr / (nq - nempty) 122 | 123 | return map # , aps, pr, prs 124 | 125 | 126 | def compute_map_M_and_H(ranks, gnd, kappas=[]): 127 | 128 | # gnd_t = [] 129 | # for i in range(len(gnd)): 130 | # g = {} 131 | # g['ok'] = np.concatenate([gnd[i]['easy']]) 132 | # g['junk'] = np.concatenate([gnd[i]['junk'], gnd[i]['hard']]) 133 | # gnd_t.append(g) 134 | # mapE, *_ = compute_map(ranks, gnd_t, kappas) 135 | 136 | gnd_t = [] 137 | for i in range(len(gnd)): 138 | g = {} 139 | g['ok'] = np.concatenate([gnd[i]['easy'], gnd[i]['hard']]) 140 | g['junk'] = np.concatenate([gnd[i]['junk']]) 141 | gnd_t.append(g) 142 | mapM = compute_map(ranks, gnd_t, kappas) 143 | 144 |
gnd_t = [] 145 | for i in range(len(gnd)): 146 | g = {} 147 | g['ok'] = np.concatenate([gnd[i]['hard']]) 148 | g['junk'] = np.concatenate([gnd[i]['junk'], gnd[i]['easy']]) 149 | gnd_t.append(g) 150 | mapH = compute_map(ranks, gnd_t, kappas) 151 | 152 | return {"mapM": mapM, "mapH": mapH} 153 | 154 | 155 | def evaluate_a_city(net, query, gallery, batch_size, num_workers): 156 | def collate_fn(batch): 157 | out = {} 158 | out["image"] = torch.stack([b["image"] for b in batch], dim=0) 159 | out["label"] = torch.cat([b["label"] for b in batch]) 160 | return out 161 | 162 | features_query = [] 163 | features_gallery = [] 164 | loader_gallery = DataLoader(gallery, batch_size=batch_size, num_workers=num_workers, pin_memory=True) 165 | loader_query = DataLoader(query, batch_size=batch_size, num_workers=num_workers, pin_memory=True, collate_fn=collate_fn) 166 | 167 | lib.LOGGER.info("Computing embeddings") 168 | for batch in tqdm(loader_gallery, disable=os.getenv('TQDM_DISABLE')): 169 | with torch.no_grad(): 170 | X = net(batch["image"].cuda()) 171 | features_gallery.append(X) 172 | 173 | features_gallery = torch.cat(features_gallery) 174 | 175 | for batch in tqdm(loader_query, disable=os.getenv('TQDM_DISABLE')): 176 | with torch.no_grad(): 177 | X = net(batch["image"].cuda()) 178 | features_query.append(X) 179 | 180 | features_query = torch.cat(features_query) 181 | 182 | features_gallery = features_gallery.cpu().numpy() 183 | features_query = features_query.cpu().numpy() 184 | 185 | # search, rank, and print 186 | scores = np.dot(features_gallery, features_query.T) 187 | ranks = np.argsort(-scores, axis=0) 188 | 189 | return compute_map_M_and_H(ranks, query) 190 | 191 | 192 | def landmark_evaluation( 193 | net, 194 | datasets, 195 | batch_size, 196 | num_workers, 197 | ): 198 | metrics = {} 199 | for dts in datasets: 200 | city_name = list(dts.keys())[0].split('_')[-1] 201 | names = list(dts.keys()) 202 | dts_ = list(dts.values()) 203 | metrics[city_name] = evaluate_a_city( 204 | net=net, 205 | query=dts_[0] if names[0].startswith("query") else dts_[1], 206 | gallery=dts_[0] if names[0].startswith("gallery") else dts_[1], 207 | batch_size=batch_size, 208 | num_workers=num_workers, 209 | ) 210 | 211 | return metrics 212 | -------------------------------------------------------------------------------- /roadmap/losses/smooth_rank_ap.py: -------------------------------------------------------------------------------- 1 | """ 2 | inspired from 3 | https://github.com/Andrew-Brown1/Smooth_AP 4 | """ 5 | from functools import partial 6 | 7 | import torch 8 | import torch.nn as nn 9 | from tqdm.auto import tqdm 10 | 11 | import roadmap.utils as lib 12 | 13 | 14 | def heaviside(tens, val=1., target=None, general=None): 15 | return torch.heaviside(tens, values=torch.tensor(val, device=tens.device, dtype=tens.dtype)) 16 | 17 | 18 | def tau_sigmoid(tensor, tau, target=None, general=None): 19 | """ temperature controlled sigmoid 20 | takes as input a torch tensor (tensor) and passes it 21 | through a sigmoid, controlled by temperature: temp 22 | """ 23 | exponent = -tensor / tau 24 | # clamp the input tensor for stability 25 | exponent = 1. 
+ exponent.clamp(-50, 50).exp() 26 | return 1.0 / exponent 27 | 28 | 29 | def step_rank(tens, tau, rho, offset=None, delta=None, start=0.5, target=None, general=False): 30 | target = target.squeeze() 31 | if general: 32 | target = target.view(1, -1).repeat(tens.size(0), 1) 33 | else: 34 | mask = target.unsqueeze(target.ndim - 1).bool() 35 | target = lib.create_label_matrix(target).bool() * mask 36 | pos_mask = (tens > 0).bool() 37 | neg_mask = ~pos_mask 38 | 39 | if isinstance(tau, str): 40 | tau_n, tau_p = tau.split("_") 41 | else: 42 | tau_n = tau_p = tau 43 | 44 | if delta is None: 45 | tens[~target & pos_mask] = rho * tens[~target & pos_mask] + offset 46 | else: 47 | margin_mask = tens > delta 48 | tens[~target & pos_mask & ~margin_mask] = start + tau_sigmoid(tens[~target & pos_mask & ~margin_mask], tau_p).type(tens.dtype) 49 | if offset is None: 50 | offset = tau_sigmoid(torch.tensor([delta], device=tens.device), tau_p).type(tens.dtype) + start 51 | tens[~target & pos_mask & margin_mask] = rho * (tens[~target & pos_mask & margin_mask] - delta) + offset 52 | 53 | tens[~target & neg_mask] = tau_sigmoid(tens[~target & neg_mask], tau_n).type(tens.dtype) 54 | 55 | tens[target] = torch.heaviside(tens[target], values=torch.tensor(1., device=tens.device, dtype=tens.dtype)) 56 | 57 | return tens 58 | 59 | 60 | class SmoothRankAP(nn.Module): 61 | def __init__( 62 | self, 63 | rank_approximation, 64 | return_type='1-mAP', 65 | ): 66 | super().__init__() 67 | self.rank_approximation = rank_approximation 68 | self.return_type = return_type 69 | assert return_type in ["1-mAP", "1-AP", "AP", 'mAP'] 70 | 71 | def general_forward(self, scores, target, verbose=False): 72 | batch_size = target.size(0) 73 | nb_instances = target.size(1) 74 | device = scores.device 75 | 76 | ap_score = [] 77 | mask = (1 - torch.eye(nb_instances, device=device)) 78 | iterator = range(batch_size) 79 | if verbose: 80 | iterator = tqdm(iterator, leave=None) 81 | for idx in iterator: 82 | # shape M 83 | query = scores[idx] 84 | pos_mask = target[idx].bool() 85 | 86 | # shape M x M 87 | query = query.view(1, -1) - query[pos_mask].view(-1, 1) 88 | query = self.rank_approximation(query, target=pos_mask, general=True) * mask[pos_mask] 89 | 90 | # shape M 91 | rk = 1 + query.sum(-1) 92 | 93 | # shape M 94 | pos_rk = 1 + (query * pos_mask.view(1, -1)).sum(-1) 95 | 96 | # shape 1 97 | ap = (pos_rk / rk).sum(-1) / pos_mask.sum() 98 | ap_score.append(ap) 99 | 100 | # shape N 101 | ap_score = torch.stack(ap_score) 102 | 103 | return ap_score 104 | 105 | def quick_forward(self, scores, target): 106 | batch_size = target.size(0) 107 | device = scores.device 108 | 109 | # ------ differentiable ranking of all retrieval set ------ 110 | # compute the mask which ignores the relevance score of the query to itself 111 | mask = 1.0 - torch.eye(batch_size, device=device).unsqueeze(0) 112 | # compute the relevance scores via cosine similarity of the CNN-produced embedding vectors 113 | # compute the difference matrix 114 | sim_diff = scores.unsqueeze(1) - scores.unsqueeze(1).permute(0, 2, 1) 115 | 116 | # pass through the sigmoid 117 | sim_diff_sigmoid = self.rank_approximation(sim_diff, target=target) 118 | 119 | sim_sg = sim_diff_sigmoid * mask 120 | # compute the rankings 121 | sim_all_rk = torch.sum(sim_sg, dim=-1) + 1 122 | 123 | # ------ differentiable ranking of only positive set in retrieval set ------ 124 | # compute the mask which only gives non-zero weights to the positive set 125 | pos_mask = (target - torch.eye(batch_size).to(device)) 
126 | sim_pos_sg = sim_diff_sigmoid * pos_mask 127 | sim_pos_rk = (torch.sum(sim_pos_sg, dim=-1) + target) * target 128 | # compute the rankings of the positive set 129 | 130 | ap = ((sim_pos_rk / sim_all_rk).sum(1) * (1 / target.sum(1))) 131 | return ap 132 | 133 | def forward(self, scores, target, force_general=False, verbose=False): 134 | assert scores.shape == target.shape 135 | assert len(scores.shape) == 2 136 | 137 | if (scores.size(0) == scores.size(1)) and not force_general: 138 | ap = self.quick_forward(scores, target) 139 | else: 140 | ap = self.general_forward(scores, target, verbose=verbose) 141 | 142 | if self.return_type == 'AP': 143 | return ap 144 | elif self.return_type == 'mAP': 145 | return ap.mean() 146 | elif self.return_type == '1-AP': 147 | return 1 - ap 148 | elif self.return_type == '1-mAP': 149 | return 1 - ap.mean() 150 | 151 | @property 152 | def my_repr(self,): 153 | repr = f"return_type={self.return_type}" 154 | return repr 155 | 156 | 157 | class HeavisideAP(SmoothRankAP): 158 | """here for testing purposes""" 159 | 160 | def __init__(self, **kwargs): 161 | rank_approximation = partial(heaviside) 162 | super().__init__(rank_approximation, **kwargs) 163 | 164 | def extra_repr(self,): 165 | repr = self.my_repr 166 | return repr 167 | 168 | 169 | class SmoothAP(SmoothRankAP): 170 | 171 | def __init__(self, tau=0.01, **kwargs): 172 | rank_approximation = partial(tau_sigmoid, tau=tau) 173 | super().__init__(rank_approximation, **kwargs) 174 | self.tau = tau 175 | 176 | def extra_repr(self,): 177 | repr = f"tau={self.tau}, {self.my_repr}" 178 | return repr 179 | 180 | 181 | class SupAP(SmoothRankAP): 182 | 183 | def __init__(self, tau=0.01, rho=100, offset=None, delta=0.05, start=0.5, **kwargs): 184 | rank_approximation = partial(step_rank, tau=tau, rho=rho, offset=offset, delta=delta, start=start) 185 | super().__init__(rank_approximation, **kwargs) 186 | self.tau = tau 187 | self.rho = rho 188 | self.offset = offset 189 | self.delta = delta 190 | 191 | def extra_repr(self,): 192 | repr = f"tau={self.tau}, rho={self.rho}, offset={self.offset}, delta={self.delta}, {self.my_repr}" 193 | return repr 194 | -------------------------------------------------------------------------------- /roadmap/models/net.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import torchvision.models as models 5 | import timm 6 | 7 | import roadmap.utils as lib 8 | 9 | from .create_projection_head import create_projection_head 10 | 11 | 12 | def get_backbone(name, pretrained=True): 13 | if name == 'resnet18': 14 | lib.LOGGER.info("using ResNet-18") 15 | out_dim = 512 16 | backbone = models.resnet18(pretrained=pretrained) 17 | backbone = nn.Sequential(*list(backbone.children())[:-2]) 18 | pooling = nn.AdaptiveAvgPool2d(output_size=(1, 1)) 19 | elif name == 'resnet50': 20 | lib.LOGGER.info("using ResNet-50") 21 | out_dim = 2048 22 | backbone = models.resnet50(pretrained=pretrained) 23 | backbone = nn.Sequential(*list(backbone.children())[:-2]) 24 | pooling = nn.AdaptiveAvgPool2d(output_size=(1, 1)) 25 | elif name == 'resnet101': 26 | lib.LOGGER.info("using ResNet-101") 27 | out_dim = 2048 28 | backbone = models.resnet101(pretrained=pretrained) 29 | backbone = nn.Sequential(*list(backbone.children())[:-2]) 30 | pooling = nn.AdaptiveAvgPool2d(output_size=(1, 1)) 31 | elif name == 'vit': 32 | lib.LOGGER.info("using ViT-S") 33 | out_dim = 768 34 | backbone = 
timm.create_model('vit_small_patch16_224', pretrained=pretrained) 35 | backbone.reset_classifier(-1) 36 | pooling = nn.Identity() 37 | elif name == 'vit_deit': 38 | lib.LOGGER.info("using DeiT-S") 39 | out_dim = 384 40 | try: 41 | backbone = timm.create_model('vit_deit_small_patch16_224', pretrained=pretrained) 42 | except RuntimeError: 43 | backbone = timm.create_model('deit_small_patch16_224', pretrained=pretrained) 44 | backbone.reset_classifier(-1) 45 | pooling = nn.Identity() 46 | elif name == 'vit_deit_distilled': 47 | lib.LOGGER.info("using DeiT-S distilled") 48 | try: 49 | deit = timm.create_model('vit_deit_small_patch16_224') 50 | deit_distilled = timm.create_model('vit_deit_small_distilled_patch16_224', pretrained=pretrained) 51 | except RuntimeError: 52 | deit = timm.create_model('deit_small_patch16_224') 53 | deit_distilled = timm.create_model('deit_small_distilled_patch16_224', pretrained=pretrained) 54 | deit_distilled.pos_embed = torch.nn.Parameter(torch.cat((deit_distilled.pos_embed[:, :1], deit_distilled.pos_embed[:, 2:]), dim=1)) 55 | deit.load_state_dict(deit_distilled.state_dict(), strict=False) 56 | backbone = deit 57 | backbone.reset_classifier(-1) 58 | out_dim = 384 59 | pooling = nn.Identity() 60 | elif name == 'vit_deit_base': 61 | lib.LOGGER.info("using DeiT-B") 62 | out_dim = 768 63 | try: 64 | backbone = timm.create_model('vit_deit_base_patch16_224', pretrained=pretrained) 65 | except RuntimeError: 66 | backbone = timm.create_model('deit_base_patch16_224', pretrained=pretrained) 67 | backbone.reset_classifier(-1) 68 | pooling = nn.Identity() 69 | elif name == 'vit_deit_base_distilled': 70 | lib.LOGGER.info("using DeiT-B distilled") 71 | try: 72 | deit = timm.create_model('deit_base_patch16_224') 73 | deit_distilled = timm.create_model('deit_base_distilled_patch16_224', pretrained=pretrained) 74 | except RuntimeError: 75 | deit = timm.create_model('deit_base_patch16_224') 76 | deit_distilled = timm.create_model('deit_base_distilled_patch16_224', pretrained=pretrained) 77 | deit_distilled.pos_embed = torch.nn.Parameter(torch.cat((deit_distilled.pos_embed[:, :1], deit_distilled.pos_embed[:, 2:]), dim=1)) 78 | deit.load_state_dict(deit_distilled.state_dict(), strict=False) 79 | backbone = deit 80 | backbone.reset_classifier(-1) 81 | out_dim = 768 82 | pooling = nn.Identity() 83 | elif name == 'vit_deit_base_384': 84 | lib.LOGGER.info("using DeiT-B 384") 85 | out_dim = 768 86 | try: 87 | backbone = timm.create_model('vit_deit_base_patch16_384', pretrained=pretrained) 88 | except RuntimeError: 89 | backbone = timm.create_model('deit_base_patch16_384', pretrained=pretrained) 90 | backbone.reset_classifier(-1) 91 | pooling = nn.Identity() 92 | elif name == 'vit_deit_base_384_distilled': 93 | lib.LOGGER.info("using DeiT-B 384 distilled") 94 | try: 95 | deit = timm.create_model('deit_base_patch16_384') 96 | deit_distilled = timm.create_model('deit_base_distilled_patch16_384', pretrained=pretrained) 97 | except RuntimeError: 98 | deit = timm.create_model('deit_base_patch16_384') 99 | deit_distilled = timm.create_model('deit_base_distilled_patch16_384', pretrained=pretrained) 100 | deit_distilled.pos_embed = torch.nn.Parameter(torch.cat((deit_distilled.pos_embed[:, :1], deit_distilled.pos_embed[:, 2:]), dim=1)) 101 | deit.load_state_dict(deit_distilled.state_dict(), strict=False) 102 | backbone = deit 103 | backbone.reset_classifier(-1) 104 | out_dim = 768 105 | pooling = nn.Identity() 106 | else: 107 | raise ValueError(f"{name} is not recognized") 108 | 109 | return 
(backbone, pooling, out_dim) 110 | 111 | 112 | class RetrievalNet(nn.Module): 113 | 114 | def __init__( 115 | self, 116 | backbone_name, 117 | embed_dim=512, 118 | norm_features=False, 119 | without_fc=False, 120 | with_autocast=False, 121 | pooling='default', 122 | projection_normalization_layer='none', 123 | pretrained=True, 124 | ): 125 | super().__init__() 126 | 127 | norm_features = lib.str_to_bool(norm_features) 128 | without_fc = lib.str_to_bool(without_fc) 129 | with_autocast = lib.str_to_bool(with_autocast) 130 | 131 | assert isinstance(without_fc, bool) 132 | assert isinstance(norm_features, bool) 133 | assert isinstance(with_autocast, bool) 134 | self.norm_features = norm_features 135 | self.without_fc = without_fc 136 | self.with_autocast = with_autocast 137 | if with_autocast: 138 | lib.LOGGER.info("Using mixed precision") 139 | 140 | self.backbone, default_pooling, out_features = get_backbone(backbone_name, pretrained=pretrained) 141 | if pooling == 'default': 142 | self.pooling = default_pooling 143 | elif pooling == 'none': 144 | self.pooling = nn.Identity() 145 | elif pooling == 'max': 146 | self.pooling = nn.AdaptiveMaxPool2d(output_size=(1, 1)) 147 | elif pooling == 'avg': 148 | self.pooling = nn.AdaptiveAvgPool2d(output_size=(1, 1)) 149 | lib.LOGGER.info(f"Pooling is {self.pooling}") 150 | 151 | if self.norm_features: 152 | lib.LOGGER.info("Using a LayerNorm layer") 153 | self.standardize = nn.LayerNorm(out_features, elementwise_affine=False) 154 | else: 155 | self.standardize = nn.Identity() 156 | 157 | if not self.without_fc: 158 | self.fc = create_projection_head(out_features, embed_dim, projection_normalization_layer) 159 | lib.LOGGER.info(f"Projection head : \n{self.fc}") 160 | else: 161 | self.fc = nn.Identity() 162 | lib.LOGGER.info("Not using a linear projection layer") 163 | 164 | def forward(self, X): 165 | with torch.cuda.amp.autocast(enabled=self.with_autocast): 166 | X = self.backbone(X) 167 | X = self.pooling(X) 168 | 169 | X = X.view(X.size(0), -1) 170 | X = self.standardize(X) 171 | X = self.fc(X) 172 | X = F.normalize(X, p=2, dim=1) 173 | return X 174 | --------------------------------------------------------------------------------
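To close, a minimal sketch of how `RetrievalNet` can be exercised end to end. The backbone choice, embedding size, and input shape are illustrative assumptions (`pretrained=False` avoids a weight download), and it assumes a torchvision/timm version compatible with the code above.

```python
# Minimal sketch, assuming a torchvision version where models.resnet50
# accepts `pretrained`; all values below are illustrative, not mandated.
import torch
from roadmap.models import RetrievalNet

net = RetrievalNet('resnet50', embed_dim=512, pretrained=False).eval()

with torch.no_grad():
    images = torch.randn(2, 3, 224, 224)  # dummy batch of two RGB images
    embeddings = net(images)

print(embeddings.shape)        # torch.Size([2, 512])
print(embeddings.norm(dim=1))  # ~1.0 per row: forward() L2-normalizes
```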